anemoi-datasets 0.5.26__py3-none-any.whl → 0.5.28__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (116)
  1. anemoi/datasets/__init__.py +1 -2
  2. anemoi/datasets/_version.py +16 -3
  3. anemoi/datasets/commands/check.py +1 -1
  4. anemoi/datasets/commands/copy.py +1 -2
  5. anemoi/datasets/commands/create.py +1 -1
  6. anemoi/datasets/commands/inspect.py +27 -35
  7. anemoi/datasets/commands/recipe/__init__.py +93 -0
  8. anemoi/datasets/commands/recipe/format.py +55 -0
  9. anemoi/datasets/commands/recipe/migrate.py +555 -0
  10. anemoi/datasets/commands/validate.py +59 -0
  11. anemoi/datasets/compute/recentre.py +3 -6
  12. anemoi/datasets/create/__init__.py +64 -26
  13. anemoi/datasets/create/check.py +10 -12
  14. anemoi/datasets/create/chunks.py +1 -2
  15. anemoi/datasets/create/config.py +5 -6
  16. anemoi/datasets/create/input/__init__.py +44 -65
  17. anemoi/datasets/create/input/action.py +296 -238
  18. anemoi/datasets/create/input/context/__init__.py +71 -0
  19. anemoi/datasets/create/input/context/field.py +54 -0
  20. anemoi/datasets/create/input/data_sources.py +7 -9
  21. anemoi/datasets/create/input/misc.py +2 -75
  22. anemoi/datasets/create/input/repeated_dates.py +11 -130
  23. anemoi/datasets/{utils → create/input/result}/__init__.py +10 -1
  24. anemoi/datasets/create/input/{result.py → result/field.py} +36 -120
  25. anemoi/datasets/create/input/trace.py +1 -1
  26. anemoi/datasets/create/patch.py +1 -2
  27. anemoi/datasets/create/persistent.py +3 -5
  28. anemoi/datasets/create/size.py +1 -3
  29. anemoi/datasets/create/sources/accumulations.py +120 -145
  30. anemoi/datasets/create/sources/accumulations2.py +20 -53
  31. anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
  32. anemoi/datasets/create/sources/constants.py +39 -40
  33. anemoi/datasets/create/sources/empty.py +22 -19
  34. anemoi/datasets/create/sources/fdb.py +133 -0
  35. anemoi/datasets/create/sources/forcings.py +29 -29
  36. anemoi/datasets/create/sources/grib.py +94 -78
  37. anemoi/datasets/create/sources/grib_index.py +57 -55
  38. anemoi/datasets/create/sources/hindcasts.py +57 -59
  39. anemoi/datasets/create/sources/legacy.py +10 -62
  40. anemoi/datasets/create/sources/mars.py +121 -149
  41. anemoi/datasets/create/sources/netcdf.py +28 -25
  42. anemoi/datasets/create/sources/opendap.py +28 -26
  43. anemoi/datasets/create/sources/patterns.py +4 -6
  44. anemoi/datasets/create/sources/recentre.py +46 -48
  45. anemoi/datasets/create/sources/repeated_dates.py +44 -0
  46. anemoi/datasets/create/sources/source.py +26 -51
  47. anemoi/datasets/create/sources/tendencies.py +68 -98
  48. anemoi/datasets/create/sources/xarray.py +4 -6
  49. anemoi/datasets/create/sources/xarray_support/__init__.py +40 -36
  50. anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -12
  51. anemoi/datasets/create/sources/xarray_support/field.py +20 -16
  52. anemoi/datasets/create/sources/xarray_support/fieldlist.py +11 -15
  53. anemoi/datasets/create/sources/xarray_support/flavour.py +42 -42
  54. anemoi/datasets/create/sources/xarray_support/grid.py +15 -9
  55. anemoi/datasets/create/sources/xarray_support/metadata.py +19 -128
  56. anemoi/datasets/create/sources/xarray_support/patch.py +4 -6
  57. anemoi/datasets/create/sources/xarray_support/time.py +10 -13
  58. anemoi/datasets/create/sources/xarray_support/variable.py +21 -21
  59. anemoi/datasets/create/sources/xarray_zarr.py +28 -25
  60. anemoi/datasets/create/sources/zenodo.py +43 -41
  61. anemoi/datasets/create/statistics/__init__.py +3 -6
  62. anemoi/datasets/create/testing.py +4 -0
  63. anemoi/datasets/create/typing.py +1 -2
  64. anemoi/datasets/create/utils.py +0 -43
  65. anemoi/datasets/create/zarr.py +7 -2
  66. anemoi/datasets/data/__init__.py +15 -6
  67. anemoi/datasets/data/complement.py +7 -12
  68. anemoi/datasets/data/concat.py +5 -8
  69. anemoi/datasets/data/dataset.py +48 -47
  70. anemoi/datasets/data/debug.py +7 -9
  71. anemoi/datasets/data/ensemble.py +4 -6
  72. anemoi/datasets/data/fill_missing.py +7 -10
  73. anemoi/datasets/data/forwards.py +22 -26
  74. anemoi/datasets/data/grids.py +12 -168
  75. anemoi/datasets/data/indexing.py +9 -12
  76. anemoi/datasets/data/interpolate.py +7 -15
  77. anemoi/datasets/data/join.py +8 -12
  78. anemoi/datasets/data/masked.py +6 -11
  79. anemoi/datasets/data/merge.py +5 -9
  80. anemoi/datasets/data/misc.py +41 -45
  81. anemoi/datasets/data/missing.py +11 -16
  82. anemoi/datasets/data/observations/__init__.py +8 -14
  83. anemoi/datasets/data/padded.py +3 -5
  84. anemoi/datasets/data/records/backends/__init__.py +2 -2
  85. anemoi/datasets/data/rescale.py +5 -12
  86. anemoi/datasets/data/rolling_average.py +141 -0
  87. anemoi/datasets/data/select.py +13 -16
  88. anemoi/datasets/data/statistics.py +4 -7
  89. anemoi/datasets/data/stores.py +22 -29
  90. anemoi/datasets/data/subset.py +8 -11
  91. anemoi/datasets/data/unchecked.py +7 -11
  92. anemoi/datasets/data/xy.py +25 -21
  93. anemoi/datasets/dates/__init__.py +15 -18
  94. anemoi/datasets/dates/groups.py +7 -10
  95. anemoi/datasets/dumper.py +76 -0
  96. anemoi/datasets/grids.py +4 -185
  97. anemoi/datasets/schemas/recipe.json +131 -0
  98. anemoi/datasets/testing.py +93 -7
  99. anemoi/datasets/validate.py +598 -0
  100. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/METADATA +7 -4
  101. anemoi_datasets-0.5.28.dist-info/RECORD +134 -0
  102. anemoi/datasets/create/filter.py +0 -48
  103. anemoi/datasets/create/input/concat.py +0 -164
  104. anemoi/datasets/create/input/context.py +0 -89
  105. anemoi/datasets/create/input/empty.py +0 -54
  106. anemoi/datasets/create/input/filter.py +0 -118
  107. anemoi/datasets/create/input/function.py +0 -233
  108. anemoi/datasets/create/input/join.py +0 -130
  109. anemoi/datasets/create/input/pipe.py +0 -66
  110. anemoi/datasets/create/input/step.py +0 -177
  111. anemoi/datasets/create/input/template.py +0 -162
  112. anemoi_datasets-0.5.26.dist-info/RECORD +0 -131
  113. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/WHEEL +0 -0
  114. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/entry_points.txt +0 -0
  115. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/licenses/LICENSE +0 -0
  116. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/top_level.txt +0 -0
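Nearly every hunk below applies the same mechanical modernisation: aliases from the typing module (List, Dict, Set, Tuple, Optional) become the built-in generics of PEP 585 and the X | None union syntax of PEP 604, and Callable moves to collections.abc. A minimal before/after sketch of the pattern (illustrative names only, not code from the package):

    # Before: aliases from the typing module (pre-Python 3.9 style)
    from typing import Dict, List, Optional

    def lookup(index: Dict[str, int], names: Optional[List[str]] = None) -> List[str]:
        return list(index) if names is None else names

    # After: PEP 585 built-in generics, PEP 604 unions (requires Python >= 3.10)
    def lookup(index: dict[str, int], names: list[str] | None = None) -> list[str]:
        return list(index) if names is None else names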
anemoi/datasets/data/debug.py
@@ -11,12 +11,10 @@
 import logging
 import os
 import textwrap
+from collections.abc import Callable
 from functools import wraps
 from typing import TYPE_CHECKING
 from typing import Any
-from typing import Callable
-from typing import List
-from typing import Optional
 
 from anemoi.utils.text import Tree
 from numpy.typing import NDArray
@@ -56,7 +54,7 @@ def css(name: str) -> str:
 class Node:
     """A class to represent a node in a dataset tree."""
 
-    def __init__(self, dataset: "Dataset", kids: List[Any], **kwargs: Any) -> None:
+    def __init__(self, dataset: "Dataset", kids: list[Any], **kwargs: Any) -> None:
         """Initializes a Node object.
 
         Parameters
@@ -72,7 +70,7 @@ class Node:
         self.kids = kids
         self.kwargs = kwargs
 
-    def _put(self, indent: int, result: List[str]) -> None:
+    def _put(self, indent: int, result: list[str]) -> None:
         """Helper method to add the node representation to the result list.
 
         Parameters
@@ -103,11 +101,11 @@ class Node:
         str
             String representation of the node.
         """
-        result: List[str] = []
+        result: list[str] = []
         self._put(0, result)
         return "\n".join(result)
 
-    def graph(self, digraph: List[str], nodes: dict) -> None:
+    def graph(self, digraph: list[str], nodes: dict) -> None:
        """Generates a graph representation of the node.
 
         Parameters
@@ -170,7 +168,7 @@ class Node:
         digraph.append("}")
         return "\n".join(digraph)
 
-    def _html(self, indent: str, rows: List[List[str]]) -> None:
+    def _html(self, indent: str, rows: list[list[str]]) -> None:
         """Helper method to add the node representation to the HTML rows.
 
         Parameters
@@ -273,7 +271,7 @@ class Node:
 class Source:
     """A class used to follow the provenance of a data point."""
 
-    def __init__(self, dataset: Any, index: int, source: Optional[Any] = None, info: Optional[Any] = None) -> None:
+    def __init__(self, dataset: Any, index: int, source: Any | None = None, info: Any | None = None) -> None:
         """Initializes a Source object.
 
         Parameters
anemoi/datasets/data/ensemble.py
@@ -10,8 +10,6 @@
 
 import logging
 from typing import Any
-from typing import Dict
-from typing import Tuple
 
 import numpy as np
 from numpy.typing import NDArray
@@ -105,7 +103,7 @@ class Number(Forwards):
         """
         return Node(self, [self.forward.tree()], numbers=[n + 1 for n in self.members])
 
-    def metadata_specific(self, **kwargs: Any) -> Dict[str, Any]:
+    def metadata_specific(self, **kwargs: Any) -> dict[str, Any]:
         """Returns metadata specific to the Number object.
 
         Parameters
@@ -122,7 +120,7 @@ class Number(Forwards):
             "numbers": [n + 1 for n in self.members],
         }
 
-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Returns metadata specific to the Number object."""
         return {}
 
@@ -140,7 +138,7 @@ class Ensemble(GivenAxis):
         """
         return Node(self, [d.tree() for d in self.datasets])
 
-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Get the metadata specific to the forwards subclass.
 
         Returns:
@@ -149,7 +147,7 @@ class Ensemble(GivenAxis):
         return {}
 
 
-def ensemble_factory(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Ensemble:
+def ensemble_factory(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Ensemble:
     """Factory function to create an Ensemble object.
 
     Parameters
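Note the numbers=[n + 1 for n in self.members] expressions in the Number hunks above: members are zero-based array indices internally, while the reported ensemble numbers are one-based. A one-line illustration:

    members = [0, 1, 9]                 # zero-based indices used for slicing
    print([n + 1 for n in members])     # [1, 2, 10] -- one-based ensemble numbers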
anemoi/datasets/data/fill_missing.py
@@ -10,9 +10,6 @@
 
 import logging
 from typing import Any
-from typing import Dict
-from typing import Optional
-from typing import Set
 
 import numpy as np
 from numpy.typing import NDArray
@@ -46,7 +43,7 @@ class MissingDatesFill(Forwards):
         """
         super().__init__(dataset)
         self._missing = set(dataset.missing)
-        self._warnings: Set[int] = set()
+        self._warnings: set[int] = set()
 
     @debug_indexing
     @expand_list_indexing
@@ -84,7 +81,7 @@ class MissingDatesFill(Forwards):
         return np.stack([self[i] for i in range(*s.indices(self._len))])
 
     @property
-    def missing(self) -> Set[int]:
+    def missing(self) -> set[int]:
         """Get the set of missing dates."""
         return set()
 
@@ -153,7 +150,7 @@ class MissingDatesClosest(MissingDatesFill):
         self.closest = closest
         self._closest = {}
 
-    def _fill_missing(self, n: int, a: Optional[int], b: Optional[int]) -> NDArray[Any]:
+    def _fill_missing(self, n: int, a: int | None, b: int | None) -> NDArray[Any]:
         """Fill the missing date at the given index.
 
         Parameters
@@ -189,7 +186,7 @@ class MissingDatesClosest(MissingDatesFill):
 
         return self.forward[self._closest[n]]
 
-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Get metadata specific to the subclass.
 
         Returns
@@ -224,7 +221,7 @@ class MissingDatesInterpolate(MissingDatesFill):
         super().__init__(dataset)
         self._alpha = {}
 
-    def _fill_missing(self, n: int, a: Optional[int], b: Optional[int]) -> NDArray[Any]:
+    def _fill_missing(self, n: int, a: int | None, b: int | None) -> NDArray[Any]:
         """Fill the missing date at the given index using interpolation.
 
         Parameters
@@ -264,7 +261,7 @@ class MissingDatesInterpolate(MissingDatesFill):
         alpha = self._alpha[n]
         return self.forward[a] * (1 - alpha) + self.forward[b] * alpha
 
-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Get metadata specific to the subclass.
 
         Returns
@@ -285,7 +282,7 @@ class MissingDatesInterpolate(MissingDatesFill):
         return Node(self, [self.forward.tree()])
 
 
-def fill_missing_dates_factory(dataset: Any, method: str, kwargs: Dict[str, Any]) -> Dataset:
+def fill_missing_dates_factory(dataset: Any, method: str, kwargs: dict[str, Any]) -> Dataset:
     """Factory function to create an instance of a class to fill missing dates.
 
     Parameters
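MissingDatesInterpolate._fill_missing (partially shown above) blends the nearest valid neighbours a and b with a cached weight alpha. A self-contained sketch of that blend, assuming alpha is the fractional position of n between a and b (the class itself caches these weights in self._alpha):

    import numpy as np

    def interpolate_missing(forward, n, a, b):
        # Weight by how far the missing index n sits between valid indices a and b
        alpha = (n - a) / (b - a)
        return forward[a] * (1 - alpha) + forward[b] * alpha

    forward = np.array([[0.0, 0.0], [np.nan, np.nan], [3.0, 6.0]])
    print(interpolate_missing(forward, 1, 0, 2))  # [1.5 3. ] -- index 1 rebuilt from 0 and 2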
anemoi/datasets/data/forwards.py
@@ -14,10 +14,6 @@ import warnings
 from abc import abstractmethod
 from functools import cached_property
 from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Set
 
 import numpy as np
 from numpy.typing import NDArray
@@ -75,7 +71,7 @@ class Forwards(Dataset):
         return self.forward[n]
 
     @property
-    def name(self) -> Optional[str]:
+    def name(self) -> str | None:
         """Returns the name of the forward dataset."""
         if self._name is not None:
             return self._name
@@ -112,26 +108,26 @@ class Forwards(Dataset):
         return self.forward.longitudes
 
     @property
-    def name_to_index(self) -> Dict[str, int]:
+    def name_to_index(self) -> dict[str, int]:
         """Returns a dictionary mapping variable names to their indices."""
         return self.forward.name_to_index
 
     @property
-    def variables(self) -> List[str]:
+    def variables(self) -> list[str]:
         """Returns the variables of the forward dataset."""
         return self.forward.variables
 
     @property
-    def variables_metadata(self) -> Dict[str, Any]:
+    def variables_metadata(self) -> dict[str, Any]:
         """Returns the metadata of the variables in the forward dataset."""
         return self.forward.variables_metadata
 
     @property
-    def statistics(self) -> Dict[str, NDArray[Any]]:
+    def statistics(self) -> dict[str, NDArray[Any]]:
         """Returns the statistics of the forward dataset."""
         return self.forward.statistics
 
-    def statistics_tendencies(self, delta: Optional[datetime.timedelta] = None) -> Dict[str, NDArray[Any]]:
+    def statistics_tendencies(self, delta: datetime.timedelta | None = None) -> dict[str, NDArray[Any]]:
         """Returns the statistics tendencies of the forward dataset.
 
         Parameters
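The Forwards hunks above are all instances of one delegation pattern: the wrapper holds a self.forward dataset and forwards each property to it unchanged, so subclasses only override what they transform. A stripped-down sketch of the idea (hypothetical class, not the real one):

    class Forwarder:
        """Wrap a dataset and delegate attribute access to it."""

        def __init__(self, forward):
            self.forward = forward

        @property
        def variables(self):
            # Subclasses that rename or drop variables override this
            return self.forward.variables

        def __getitem__(self, n):
            return self.forward[n]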
@@ -159,7 +155,7 @@ class Forwards(Dataset):
         return self.forward.dtype
 
     @property
-    def missing(self) -> Set[int]:
+    def missing(self) -> set[int]:
         """Returns the missing data information of the forward dataset."""
         return self.forward.missing
 
@@ -168,7 +164,7 @@ class Forwards(Dataset):
         """Returns the grids of the forward dataset."""
         return self.forward.grids
 
-    def metadata_specific(self, **kwargs: Any) -> Dict[str, Any]:
+    def metadata_specific(self, **kwargs: Any) -> dict[str, Any]:
         """Returns metadata specific to the forward dataset.
 
         Parameters
@@ -187,7 +183,7 @@ class Forwards(Dataset):
             **kwargs,
         )
 
-    def collect_supporting_arrays(self, collected: List[Any], *path: Any) -> None:
+    def collect_supporting_arrays(self, collected: list[Any], *path: Any) -> None:
         """Collects supporting arrays from the forward dataset.
 
         Parameters
@@ -199,7 +195,7 @@ class Forwards(Dataset):
         """
         self.forward.collect_supporting_arrays(collected, *path)
 
-    def collect_input_sources(self, collected: List[Any]) -> None:
+    def collect_input_sources(self, collected: list[Any]) -> None:
         """Collects input sources from the forward dataset.
 
         Parameters
@@ -225,11 +221,11 @@ class Forwards(Dataset):
         return self.forward.source(index)
 
     @abstractmethod
-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Returns metadata specific to the subclass."""
         pass
 
-    def get_dataset_names(self, names: Set[str]) -> None:
+    def get_dataset_names(self, names: set[str]) -> None:
         """Collects the names of the datasets.
 
         Parameters
@@ -240,7 +236,7 @@ class Forwards(Dataset):
         self.forward.get_dataset_names(names)
 
     @property
-    def constant_fields(self) -> List[str]:
+    def constant_fields(self) -> list[str]:
         """Returns the constant fields of the forward dataset."""
         return self.forward.constant_fields
 
@@ -248,7 +244,7 @@ class Forwards(Dataset):
 class Combined(Forwards):
     """A class to combine multiple datasets into a single dataset."""
 
-    def __init__(self, datasets: List[Dataset]) -> None:
+    def __init__(self, datasets: list[Dataset]) -> None:
         """Initializes a Combined object.
 
         Parameters
@@ -466,7 +462,7 @@ class Combined(Forwards):
         self.check_same_variables(d1, d2)
         self.check_same_dates(d1, d2)
 
-    def provenance(self) -> List[Any]:
+    def provenance(self) -> list[Any]:
         """Returns the provenance of the combined datasets.
 
         Returns
@@ -487,7 +483,7 @@ class Combined(Forwards):
         lst = ", ".join(repr(d) for d in self.datasets)
         return f"{self.__class__.__name__}({lst})"
 
-    def metadata_specific(self, **kwargs: Any) -> Dict[str, Any]:
+    def metadata_specific(self, **kwargs: Any) -> dict[str, Any]:
         """Returns metadata specific to the combined datasets.
 
         Parameters
@@ -508,7 +504,7 @@ class Combined(Forwards):
             **kwargs,
         )
 
-    def collect_supporting_arrays(self, collected: List[Any], *path: Any) -> None:
+    def collect_supporting_arrays(self, collected: list[Any], *path: Any) -> None:
         """Collects supporting arrays from the combined datasets.
 
         Parameters
@@ -524,7 +520,7 @@ class Combined(Forwards):
             d.collect_supporting_arrays(collected, *path, name)
 
     @property
-    def missing(self) -> Set[int]:
+    def missing(self) -> set[int]:
         """Returns the missing data information of the combined datasets.
 
         Raises
@@ -534,7 +530,7 @@ class Combined(Forwards):
         """
         raise NotImplementedError("missing() not implemented for Combined")
 
-    def get_dataset_names(self, names: Set[str]) -> None:
+    def get_dataset_names(self, names: set[str]) -> None:
         """Collects the names of the combined datasets.
 
         Parameters
@@ -549,7 +545,7 @@ class Combined(Forwards):
 class GivenAxis(Combined):
     """A class to combine datasets along a given axis."""
 
-    def __init__(self, datasets: List[Any], axis: int) -> None:
+    def __init__(self, datasets: list[Any], axis: int) -> None:
         """Initializes a GivenAxis object.
 
         Parameters
@@ -656,10 +652,10 @@ class GivenAxis(Combined):
         return np.concatenate([d[n] for d in self.datasets], axis=self.axis - 1)
 
     @cached_property
-    def missing(self) -> Set[int]:
+    def missing(self) -> set[int]:
         """Returns the missing data information of the combined dataset along the given axis."""
         offset = 0
-        result: Set[int] = set()
+        result: set[int] = set()
         for d in self.datasets:
             result.update(offset + m for m in d.missing)
             if self.axis == 0:  # Advance if axis is time
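GivenAxis.missing (shown above) shifts each child dataset's missing indices by a running offset, advancing the offset only when concatenating along axis 0 (time). A toy version of the same logic (hypothetical standalone function):

    def combined_missing(parts, axis):
        # parts: (length, missing-set) pairs for each child dataset
        offset, result = 0, set()
        for length, missing in parts:
            result.update(offset + m for m in missing)
            if axis == 0:  # only the time axis shifts indices
                offset += length
        return result

    print(combined_missing([(5, {1}), (5, {2})], axis=0))  # {1, 7}
    print(combined_missing([(5, {1}), (5, {2})], axis=3))  # {1, 2}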
anemoi/datasets/data/grids.py
@@ -11,10 +11,6 @@
 import logging
 from functools import cached_property
 from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
 
 import numpy as np
 from numpy.typing import NDArray
@@ -25,171 +21,19 @@ from .dataset import FullIndex
 from .dataset import Shape
 from .dataset import TupleIndex
 from .debug import Node
-from .debug import debug_indexing
-from .forwards import Combined
 from .forwards import GivenAxis
 from .indexing import apply_index_to_slices_changes
-from .indexing import expand_list_indexing
 from .indexing import index_to_slices
-from .indexing import length_to_slices
-from .indexing import update_tuple
 from .misc import _auto_adjust
 from .misc import _open
 
 LOG = logging.getLogger(__name__)
 
 
-class Concat(Combined):
-    """A class to represent concatenated datasets."""
-
-    def __len__(self) -> int:
-        """Returns the total length of the concatenated datasets.
-
-        Returns
-        -------
-        int
-            Total length of the concatenated datasets.
-        """
-        return sum(len(i) for i in self.datasets)
-
-    @debug_indexing
-    @expand_list_indexing
-    def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
-        """Retrieves a tuple of data from the concatenated datasets based on the given index.
-
-        Parameters
-        ----------
-        index : TupleIndex
-            Index specifying the data to retrieve.
-
-        Returns
-        -------
-        NDArray[Any]
-            Concatenated data array from the specified index.
-        """
-        index, changes = index_to_slices(index, self.shape)
-        # print(index, changes)
-        lengths = [d.shape[0] for d in self.datasets]
-        slices = length_to_slices(index[0], lengths)
-        # print("slies", slices)
-        result = [d[update_tuple(index, 0, i)[0]] for (d, i) in zip(self.datasets, slices) if i is not None]
-        result = np.concatenate(result, axis=0)
-        return apply_index_to_slices_changes(result, changes)
-
-    @debug_indexing
-    def __getitem__(self, n: FullIndex) -> NDArray[Any]:
-        """Retrieves data from the concatenated datasets based on the given index.
-
-        Parameters
-        ----------
-        n : FullIndex
-            Index specifying the data to retrieve.
-
-        Returns
-        -------
-        NDArray[Any]
-            Data array from the concatenated datasets based on the index.
-        """
-        if isinstance(n, tuple):
-            return self._get_tuple(n)
-
-        if isinstance(n, slice):
-            return self._get_slice(n)
-
-        # TODO: optimize
-        k = 0
-        while n >= self.datasets[k]._len:
-            n -= self.datasets[k]._len
-            k += 1
-        return self.datasets[k][n]
-
-    @debug_indexing
-    def _get_slice(self, s: slice) -> NDArray[Any]:
-        """Retrieves a slice of data from the concatenated datasets.
-
-        Parameters
-        ----------
-        s : slice
-            Slice object specifying the range of data to retrieve.
-
-        Returns
-        -------
-        NDArray[Any]
-            Concatenated data array from the specified slice.
-        """
-        result = []
-
-        lengths = [d.shape[0] for d in self.datasets]
-        slices = length_to_slices(s, lengths)
-
-        result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None]
-
-        return np.concatenate(result)
-
-    def check_compatibility(self, d1: Dataset, d2: Dataset) -> None:
-        """Check the compatibility of two datasets for concatenation.
-
-        Parameters
-        ----------
-        d1 : Dataset
-            The first dataset.
-        d2 : Dataset
-            The second dataset.
-        """
-        super().check_compatibility(d1, d2)
-        self.check_same_sub_shapes(d1, d2, drop_axis=0)
-
-    def check_same_lengths(self, d1: Dataset, d2: Dataset) -> None:
-        """Check if the lengths of two datasets are the same.
-
-        Parameters
-        ----------
-        d1 : Dataset
-            The first dataset.
-        d2 : Dataset
-            The second dataset.
-        """
-        # Turned off because we are concatenating along the first axis
-        pass
-
-    def check_same_dates(self, d1: Dataset, d2: Dataset) -> None:
-        """Check if the dates of two datasets are the same.
-
-        Parameters
-        ----------
-        d1 : Dataset
-            The first dataset.
-        d2 : Dataset
-            The second dataset.
-        """
-        # Turned off because we are concatenating along the dates axis
-        pass
-
-    @property
-    def dates(self) -> NDArray[np.datetime64]:
-        """Returns the concatenated dates of all datasets."""
-        return np.concatenate([d.dates for d in self.datasets])
-
-    @property
-    def shape(self) -> Shape:
-        """Returns the shape of the concatenated datasets."""
-        return (len(self),) + self.datasets[0].shape[1:]
-
-    def tree(self) -> Node:
-        """Generates a hierarchical tree structure for the concatenated datasets.
-
-        Returns
-        -------
-        Node
-            A Node object representing the concatenated datasets.
-        """
-        return Node(self, [d.tree() for d in self.datasets])
-
-
 class GridsBase(GivenAxis):
     """A base class for handling grids in datasets."""
 
-    def __init__(self, datasets: List[Any], axis: int) -> None:
+    def __init__(self, datasets: list[Any], axis: int) -> None:
         """Initializes a GridsBase object.
 
         Parameters
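This hunk deletes grids.py's local Concat class wholesale; anemoi/datasets/data/concat.py (touched above) also provides a Concat, so the change reads as deduplication rather than removed functionality. The removed __getitem__ dispatched an integer index by walking the child datasets and subtracting lengths; a standalone sketch of that dispatch (hypothetical helper mirroring the removed code):

    import numpy as np

    def concat_getitem(datasets, n):
        # Subtract each child's length until n falls inside one of them
        k = 0
        while n >= len(datasets[k]):
            n -= len(datasets[k])
            k += 1
        return datasets[k][n]

    parts = [np.arange(2), np.arange(10, 13), np.arange(20, 22)]
    print(concat_getitem(parts, 4))  # 12 -- global index 4 is parts[1][2]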
@@ -229,7 +73,7 @@ class GridsBase(GivenAxis):
         # We don't check the resolution, because we want to be able to combine
         pass
 
-    def metadata_specific(self, **kwargs: Any) -> Dict[str, Any]:
+    def metadata_specific(self, **kwargs: Any) -> dict[str, Any]:
         """Returns metadata specific to the GridsBase object.
 
         Parameters
@@ -246,7 +90,7 @@ class GridsBase(GivenAxis):
             multi_grids=True,
         )
 
-    def collect_input_sources(self, collected: List[Any]) -> None:
+    def collect_input_sources(self, collected: list[Any]) -> None:
         """Collects input sources from the datasets.
 
         Parameters
@@ -275,7 +119,7 @@ class Grids(GridsBase):
         return np.concatenate([d.longitudes for d in self.datasets])
 
     @property
-    def grids(self) -> Tuple[Any, ...]:
+    def grids(self) -> tuple[Any, ...]:
         """Returns the grids of all datasets."""
         result = []
         for d in self.datasets:
@@ -292,7 +136,7 @@ class Grids(GridsBase):
         """
         return Node(self, [d.tree() for d in self.datasets], mode="concat")
 
-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Get the metadata specific to the forwards subclass.
 
         Returns:
@@ -306,12 +150,12 @@ class Cutout(GridsBase):
 
     def __init__(
         self,
-        datasets: List[Any],
+        datasets: list[Any],
         axis: int = 3,
         cropping_distance: float = 2.0,
         neighbours: int = 5,
-        min_distance_km: Optional[float] = None,
-        plot: Optional[bool] = None,
+        min_distance_km: float | None = None,
+        plot: bool | None = None,
     ) -> None:
         """Initializes a Cutout object for hierarchical management of Limited Area
         Models (LAMs) and a global dataset, handling overlapping regions.
@@ -487,7 +331,7 @@ class Cutout(GridsBase):
 
         return apply_index_to_slices_changes(result, changes)
 
-    def collect_supporting_arrays(self, collected: List[Any], *path: Any) -> None:
+    def collect_supporting_arrays(self, collected: list[Any], *path: Any) -> None:
         """Collect supporting arrays, including masks for each LAM and the global dataset.
 
         Parameters
@@ -577,7 +421,7 @@ class Cutout(GridsBase):
         """
         return Node(self, [d.tree() for d in self.datasets])
 
-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Returns metadata specific to the Cutout object.
 
         Returns
@@ -588,7 +432,7 @@ class Cutout(GridsBase):
         return {}
 
 
-def grids_factory(args: Tuple[Any, ...], kwargs: dict) -> Dataset:
+def grids_factory(args: tuple[Any, ...], kwargs: dict) -> Dataset:
     """Factory function to create a Grids object.
 
     Parameters
@@ -618,7 +462,7 @@ def grids_factory(args: Tuple[Any, ...], kwargs: dict) -> Dataset:
     return Grids(datasets, axis=axis)._subset(**kwargs)
 
 
-def cutout_factory(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dataset:
+def cutout_factory(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Dataset:
     """Factory function to create a Cutout object.
 
     Parameters