anemoi-datasets 0.5.27__py3-none-any.whl → 0.5.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/recipe/__init__.py +93 -0
  3. anemoi/datasets/commands/recipe/format.py +55 -0
  4. anemoi/datasets/commands/recipe/migrate.py +555 -0
  5. anemoi/datasets/create/__init__.py +42 -1
  6. anemoi/datasets/create/config.py +2 -0
  7. anemoi/datasets/create/input/__init__.py +43 -63
  8. anemoi/datasets/create/input/action.py +296 -236
  9. anemoi/datasets/create/input/context/__init__.py +71 -0
  10. anemoi/datasets/create/input/context/field.py +54 -0
  11. anemoi/datasets/create/input/data_sources.py +2 -1
  12. anemoi/datasets/create/input/misc.py +0 -71
  13. anemoi/datasets/create/input/repeated_dates.py +0 -114
  14. anemoi/datasets/create/input/result/__init__.py +17 -0
  15. anemoi/datasets/create/input/{result.py → result/field.py} +9 -89
  16. anemoi/datasets/create/sources/accumulations.py +74 -94
  17. anemoi/datasets/create/sources/accumulations2.py +16 -45
  18. anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
  19. anemoi/datasets/create/sources/constants.py +39 -38
  20. anemoi/datasets/create/sources/empty.py +26 -22
  21. anemoi/datasets/create/sources/forcings.py +29 -28
  22. anemoi/datasets/create/sources/grib.py +92 -72
  23. anemoi/datasets/create/sources/grib_index.py +46 -42
  24. anemoi/datasets/create/sources/hindcasts.py +56 -55
  25. anemoi/datasets/create/sources/legacy.py +10 -62
  26. anemoi/datasets/create/sources/mars.py +107 -131
  27. anemoi/datasets/create/sources/netcdf.py +28 -24
  28. anemoi/datasets/create/sources/opendap.py +28 -24
  29. anemoi/datasets/create/sources/recentre.py +42 -41
  30. anemoi/datasets/create/sources/repeated_dates.py +44 -0
  31. anemoi/datasets/create/sources/source.py +26 -48
  32. anemoi/datasets/create/sources/tendencies.py +67 -94
  33. anemoi/datasets/create/sources/xarray_support/__init__.py +29 -24
  34. anemoi/datasets/create/sources/xarray_support/field.py +4 -4
  35. anemoi/datasets/create/sources/xarray_zarr.py +28 -24
  36. anemoi/datasets/create/sources/zenodo.py +43 -39
  37. anemoi/datasets/create/utils.py +0 -42
  38. anemoi/datasets/data/dataset.py +6 -0
  39. anemoi/datasets/data/grids.py +0 -152
  40. anemoi/datasets/data/rolling_average.py +141 -0
  41. anemoi/datasets/data/stores.py +7 -9
  42. anemoi/datasets/dates/__init__.py +2 -0
  43. anemoi/datasets/dumper.py +76 -0
  44. anemoi/datasets/grids.py +1 -178
  45. anemoi/datasets/schemas/recipe.json +131 -0
  46. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.28.dist-info}/METADATA +5 -2
  47. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.28.dist-info}/RECORD +51 -51
  48. anemoi/datasets/create/filter.py +0 -47
  49. anemoi/datasets/create/input/concat.py +0 -161
  50. anemoi/datasets/create/input/context.py +0 -86
  51. anemoi/datasets/create/input/empty.py +0 -53
  52. anemoi/datasets/create/input/filter.py +0 -117
  53. anemoi/datasets/create/input/function.py +0 -232
  54. anemoi/datasets/create/input/join.py +0 -129
  55. anemoi/datasets/create/input/pipe.py +0 -66
  56. anemoi/datasets/create/input/step.py +0 -173
  57. anemoi/datasets/create/input/template.py +0 -161
  58. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.28.dist-info}/WHEEL +0 -0
  59. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.28.dist-info}/entry_points.txt +0 -0
  60. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.28.dist-info}/licenses/LICENSE +0 -0
  61. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.28.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/utils.py CHANGED
@@ -96,48 +96,6 @@ def to_datetime(*args: Any, **kwargs: Any) -> datetime.datetime:
      return to_datetime_(*args, **kwargs)


- def make_list_int(value: str | list | tuple | int) -> list[int]:
-     """Convert a string, list, tuple, or integer to a list of integers.
-
-     Parameters
-     ----------
-     value : str or list or tuple or int
-         The value to convert.
-
-     Returns
-     -------
-     list[int]
-         A list of integers.
-
-     Raises
-     ------
-     ValueError
-         If the value cannot be converted to a list of integers.
-     """
-     # Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers.
-     # Moved to anemoi.utils.humanize
-     # replace with from anemoi.utils.humanize import make_list_int
-     # when anemoi-utils is released and pyproject.toml is updated
-     if isinstance(value, str):
-         if "/" not in value:
-             return [value]
-         bits = value.split("/")
-         if len(bits) == 3 and bits[1].lower() == "to":
-             value = list(range(int(bits[0]), int(bits[2]) + 1, 1))
-
-         elif len(bits) == 5 and bits[1].lower() == "to" and bits[3].lower() == "by":
-             value = list(range(int(bits[0]), int(bits[2]) + int(bits[4]), int(bits[4])))
-
-     if isinstance(value, list):
-         return value
-     if isinstance(value, tuple):
-         return value
-     if isinstance(value, int):
-         return [value]
-
-     raise ValueError(f"Cannot make list from {value}")
-
-
  def normalize_and_check_dates(
      dates: list[datetime.datetime],
      start: datetime.datetime,
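The removed helper expanded MARS-style range strings such as "1/to/10/by/2"; its own comments name the replacement import. A hedged sketch of the expected behaviour, assuming the anemoi.utils.humanize version mirrors the branches of the removed code above (not verified here):

```python
from anemoi.utils.humanize import make_list_int  # replacement named in the removed comments

assert make_list_int("1/to/3") == [1, 2, 3]                  # "a/to/b" is an inclusive range
assert make_list_int("1/to/10/by/2") == [1, 3, 5, 7, 9, 11]  # "a/to/b/by/s" steps by s
assert make_list_int(7) == [7]                               # scalars become one-element lists
```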
anemoi/datasets/data/dataset.py CHANGED
@@ -293,6 +293,12 @@ class Dataset(ABC, Sized):
          if skip_missing_dates:
              return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate()

+         if "rolling_average" in kwargs:
+             from .rolling_average import RollingAverage
+
+             rolling_average = kwargs.pop("rolling_average")
+             return RollingAverage(self, rolling_average)._subset(**kwargs).mutate()
+
          if "interpolate_frequency" in kwargs:
              from .interpolate import InterpolateFrequency

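A usage sketch for the new option, assuming `rolling_average` is passed through `open_dataset` keyword arguments like the neighbouring subset options (`skip_missing_dates`, `interpolate_frequency`); the dataset path is a placeholder:

```python
from anemoi.datasets import open_dataset

# (-2, 2, "freq") follows the (start, end, unit) convention of the new
# RollingAverage class: a centred, inclusive five-step window.
ds = open_dataset("example.zarr", rolling_average=(-2, 2, "freq"))
```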
anemoi/datasets/data/grids.py CHANGED
@@ -21,167 +21,15 @@ from .dataset import FullIndex
  from .dataset import Shape
  from .dataset import TupleIndex
  from .debug import Node
- from .debug import debug_indexing
- from .forwards import Combined
  from .forwards import GivenAxis
  from .indexing import apply_index_to_slices_changes
- from .indexing import expand_list_indexing
  from .indexing import index_to_slices
- from .indexing import length_to_slices
- from .indexing import update_tuple
  from .misc import _auto_adjust
  from .misc import _open

  LOG = logging.getLogger(__name__)


- class Concat(Combined):
-     """A class to represent concatenated datasets."""
-
-     def __len__(self) -> int:
-         """Returns the total length of the concatenated datasets.
-
-         Returns
-         -------
-         int
-             Total length of the concatenated datasets.
-         """
-         return sum(len(i) for i in self.datasets)
-
-     @debug_indexing
-     @expand_list_indexing
-     def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
-         """Retrieves a tuple of data from the concatenated datasets based on the given index.
-
-         Parameters
-         ----------
-         index : TupleIndex
-             Index specifying the data to retrieve.
-
-         Returns
-         -------
-         NDArray[Any]
-             Concatenated data array from the specified index.
-         """
-         index, changes = index_to_slices(index, self.shape)
-         # print(index, changes)
-         lengths = [d.shape[0] for d in self.datasets]
-         slices = length_to_slices(index[0], lengths)
-         # print("slies", slices)
-         result = [d[update_tuple(index, 0, i)[0]] for (d, i) in zip(self.datasets, slices) if i is not None]
-         result = np.concatenate(result, axis=0)
-         return apply_index_to_slices_changes(result, changes)
-
-     @debug_indexing
-     def __getitem__(self, n: FullIndex) -> NDArray[Any]:
-         """Retrieves data from the concatenated datasets based on the given index.
-
-         Parameters
-         ----------
-         n : FullIndex
-             Index specifying the data to retrieve.
-
-         Returns
-         -------
-         NDArray[Any]
-             Data array from the concatenated datasets based on the index.
-         """
-         if isinstance(n, tuple):
-             return self._get_tuple(n)
-
-         if isinstance(n, slice):
-             return self._get_slice(n)
-
-         # TODO: optimize
-         k = 0
-         while n >= self.datasets[k]._len:
-             n -= self.datasets[k]._len
-             k += 1
-         return self.datasets[k][n]
-
-     @debug_indexing
-     def _get_slice(self, s: slice) -> NDArray[Any]:
-         """Retrieves a slice of data from the concatenated datasets.
-
-         Parameters
-         ----------
-         s : slice
-             Slice object specifying the range of data to retrieve.
-
-         Returns
-         -------
-         NDArray[Any]
-             Concatenated data array from the specified slice.
-         """
-         result = []
-
-         lengths = [d.shape[0] for d in self.datasets]
-         slices = length_to_slices(s, lengths)
-
-         result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None]
-
-         return np.concatenate(result)
-
-     def check_compatibility(self, d1: Dataset, d2: Dataset) -> None:
-         """Check the compatibility of two datasets for concatenation.
-
-         Parameters
-         ----------
-         d1 : Dataset
-             The first dataset.
-         d2 : Dataset
-             The second dataset.
-         """
-         super().check_compatibility(d1, d2)
-         self.check_same_sub_shapes(d1, d2, drop_axis=0)
-
-     def check_same_lengths(self, d1: Dataset, d2: Dataset) -> None:
-         """Check if the lengths of two datasets are the same.
-
-         Parameters
-         ----------
-         d1 : Dataset
-             The first dataset.
-         d2 : Dataset
-             The second dataset.
-         """
-         # Turned off because we are concatenating along the first axis
-         pass
-
-     def check_same_dates(self, d1: Dataset, d2: Dataset) -> None:
-         """Check if the dates of two datasets are the same.
-
-         Parameters
-         ----------
-         d1 : Dataset
-             The first dataset.
-         d2 : Dataset
-             The second dataset.
-         """
-         # Turned off because we are concatenating along the dates axis
-         pass
-
-     @property
-     def dates(self) -> NDArray[np.datetime64]:
-         """Returns the concatenated dates of all datasets."""
-         return np.concatenate([d.dates for d in self.datasets])
-
-     @property
-     def shape(self) -> Shape:
-         """Returns the shape of the concatenated datasets."""
-         return (len(self),) + self.datasets[0].shape[1:]
-
-     def tree(self) -> Node:
-         """Generates a hierarchical tree structure for the concatenated datasets.
-
-         Returns
-         -------
-         Node
-             A Node object representing the concatenated datasets.
-         """
-         return Node(self, [d.tree() for d in self.datasets])
-
-
  class GridsBase(GivenAxis):
      """A base class for handling grids in datasets."""

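The removed `_get_slice` relied on `length_to_slices` from `anemoi.datasets.data.indexing` to split one global slice into per-dataset slices. A minimal re-implementation of that idea, illustrative only and restricted to step-1 slices:

```python
import numpy as np

def split_slice(s: slice, lengths: list[int]) -> list[slice | None]:
    """Map a global step-1 slice onto per-dataset slices (None = dataset skipped)."""
    out, offset = [], 0
    start, stop, _ = s.indices(sum(lengths))
    for n in lengths:
        lo, hi = max(start, offset), min(stop, offset + n)
        out.append(slice(lo - offset, hi - offset) if lo < hi else None)
        offset += n
    return out

parts = [np.arange(4), np.arange(6)]        # two "datasets" along the dates axis
slices = split_slice(slice(2, 8), [4, 6])   # [slice(2, 4), slice(0, 4)]
merged = np.concatenate([p[i] for p, i in zip(parts, slices) if i is not None])
assert merged.tolist() == [2, 3, 0, 1, 2, 3]
```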
anemoi/datasets/data/rolling_average.py ADDED
@@ -0,0 +1,141 @@
+ # (C) Copyright 2025 Anemoi contributors.
+ #
+ # This software is licensed under the terms of the Apache Licence Version 2.0
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+ #
+ # In applying this licence, ECMWF does not waive the privileges and immunities
+ # granted to it by virtue of its status as an intergovernmental organisation
+ # nor does it submit to any jurisdiction.
+
+
+ import logging
+ from functools import cached_property
+ from typing import Any
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from anemoi.datasets.data.indexing import expand_list_indexing
+
+ from .dataset import Dataset
+ from .dataset import FullIndex
+ from .debug import Node
+ from .debug import debug_indexing
+ from .forwards import Forwards
+
+ LOG = logging.getLogger(__name__)
+
+
+ class RollingAverage(Forwards):
+     """A class to represent a dataset averaged over a rolling window."""
+
+     def __init__(self, dataset: Dataset, window: str | tuple[int, int, str]) -> None:
+         """Initialize the RollingAverage class.
+
+         Parameters
+         ----------
+         dataset : Dataset
+             The dataset to be averaged with a rolling window.
+         window : (int, int, str)
+             The rolling average window (start, end, 'freq').
+             'freq' means the window is expressed in number of time steps of the dataset.
+             Both start and end are inclusive, i.e. window = (-2, 2, 'freq') means a window of 5 time steps.
+             For now, only 'freq' is supported; other units may be supported in the future.
+             Windows such as "[-2h, +2h]" are not supported yet.
+         """
+         super().__init__(dataset)
+         if not (isinstance(window, (list, tuple)) and len(window) == 3):
+             raise ValueError(f"Window must be (int, int, str), got {window}")
+         if not isinstance(window[0], int) or not isinstance(window[1], int) or not isinstance(window[2], str):
+             raise ValueError(f"Window must be (int, int, str), got {window}")
+         if window[2] not in ["freq", "frequency"]:
+             raise NotImplementedError(f"Window must be (int, int, 'freq'), got {window}")
+
+         # window = (0, 0, 'freq') means no change
+         self.i_start = -window[0]
+         self.i_end = window[1] + 1
+         if self.i_start <= 0:
+             raise ValueError(f"Window start must be negative, got {window}")
+         if self.i_end <= 0:
+             raise ValueError(f"Window end must be positive, got {window}")
+
+         self.window_str = f"-{self.i_start}-to-{self.i_end}"
+
+     @property
+     def shape(self):
+         shape = list(self.forward.shape)
+         shape[0] = len(self)
+         return tuple(shape)
+
+     @debug_indexing
+     @expand_list_indexing
+     def __getitem__(self, n: FullIndex) -> NDArray[Any]:
+         def f(array):
+             return np.nanmean(array, axis=0)
+
+         if isinstance(n, slice):
+             n = (n,)
+
+         if isinstance(n, tuple):
+             first = n[0]
+             if len(n) > 1:
+                 rest = n[1:]
+             else:
+                 rest = ()
+
+             if isinstance(first, int):
+                 slice_ = slice(first, first + self.i_start + self.i_end)
+                 data = self.forward[(slice_,) + rest]
+                 return f(data)
+
+             if isinstance(first, slice):
+                 first = list(range(first.start or 0, first.stop or len(self), first.step or 1))
+
+             if isinstance(first, (list, tuple)):
+                 first = [i if i >= 0 else len(self) + i for i in first]
+                 if any(i >= len(self) for i in first):
+                     raise IndexError(f"Index out of range: {first}")
+                 slices = [slice(i, i + self.i_start + self.i_end) for i in first]
+                 data = [self.forward[(slice_,) + rest] for slice_ in slices]
+                 res = [f(d) for d in data]
+                 return np.array(res)
+
+             assert False, f"Expected int, slice, list or tuple as first element of tuple, got {type(first)}"
+
+         assert isinstance(n, int), f"Expected int, slice, tuple, got {type(n)}"
+
+         if n < 0:
+             n = len(self) + n
+         if n >= len(self):
+             raise IndexError(f"Index out of range: {n}")
+
+         slice_ = slice(n, n + self.i_start + self.i_end)
+         data = self.forward[slice_]
+         return f(data)
+
+     def __len__(self) -> int:
+         return len(self.forward) - (self.i_end + self.i_start - 1)
+
+     @cached_property
+     def dates(self) -> NDArray[np.datetime64]:
+         """Get the dates of the rolling-average dataset."""
+         dates = self.forward.dates
+         return dates[self.i_start : len(dates) - self.i_end + 1]
+
+     def tree(self) -> Node:
+         return Node(self, [self.forward.tree()], window=self.window_str)
+
+     @cached_property
+     def missing(self) -> set[int]:
+         """Get the missing data indices."""
+         result = []
+
+         for i in self.forward.missing:
+             for j in range(0, self.i_end + self.i_start):
+                 result.append(i + j)
+
+         result = {x for x in result if x < self._len}
+         return result
+
+     def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
+         return {}
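To make the window arithmetic concrete, a plain-numpy sketch (no anemoi imports) of what window = (-2, 2, "freq") does, matching `i_start = 2`, `i_end = 3` above:

```python
import numpy as np

x = np.arange(10, dtype=float)  # stand-in for the time axis of a dataset
i_start, i_end = 2, 3           # from window = (-2, 2, "freq")
width = i_start + i_end         # 5 consecutive steps, both ends inclusive

out = np.array([np.nanmean(x[i : i + width]) for i in range(len(x) - (width - 1))])
assert len(out) == len(x) - 4          # the 5-step window drops 4 dates
assert out[0] == x[0:5].mean() == 2.0  # first value is centred on original index 2
```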
anemoi/datasets/data/stores.py CHANGED
@@ -85,22 +85,20 @@ class S3Store(ReadOnlyStore):
      options using the anemoi configs.
      """

-     def __init__(self, url: str, region: str | None = None) -> None:
-         """Initialize the S3Store with a URL and optional region."""
-         from anemoi.utils.remote.s3 import s3_client
+     def __init__(self, url: str) -> None:
+         """Initialize the S3Store with a URL."""

-         _, _, self.bucket, self.key = url.split("/", 3)
-         self.s3 = s3_client(self.bucket, region=region)
+         self.url = url

      def __getitem__(self, key: str) -> bytes:
          """Retrieve an item from the store."""
+         from anemoi.utils.remote.s3 import get_object
+
          try:
-             response = self.s3.get_object(Bucket=self.bucket, Key=self.key + "/" + key)
-         except self.s3.exceptions.NoSuchKey:
+             return get_object(os.path.join(self.url, key))
+         except FileNotFoundError:
              raise KeyError(key)

-         return response["Body"].read()
-

  class DebugStore(ReadOnlyStore):
      """A store to debug the zarr loading."""
anemoi/datasets/dates/__init__.py CHANGED
@@ -58,6 +58,8 @@ def extend(x: str | list[Any] | tuple[Any, ...]) -> Iterator[datetime.datetime]:
  class DatesProvider:
      """Base class for date generation.

+     Examples
+     --------
      >>> DatesProvider.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-02 00:00", "frequency": "1d"}).values
      [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 2, 0, 0)]

anemoi/datasets/dumper.py ADDED
@@ -0,0 +1,76 @@
+ # (C) Copyright 2025 Anemoi contributors.
+ #
+ # This software is licensed under the terms of the Apache Licence Version 2.0
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+ #
+ # In applying this licence, ECMWF does not waive the privileges and immunities
+ # granted to it by virtue of its status as an intergovernmental organisation
+ # nor does it submit to any jurisdiction.
+
+ import datetime
+ import io
+ import logging
+
+ import ruamel.yaml
+
+ LOG = logging.getLogger(__name__)
+
+
+ def represent_date(dumper, data):
+
+     if isinstance(data, datetime.datetime):
+         if data.tzinfo is None:
+             data = data.replace(tzinfo=datetime.timezone.utc)
+         data = data.astimezone(datetime.timezone.utc)
+         iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z"
+     else:
+         iso_str = data.isoformat()
+
+     return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str)
+
+
+ # --- Represent multiline strings with | style ---
+ def represent_multiline_str(dumper, data):
+     if "\n" in data:
+         return dumper.represent_scalar("tag:yaml.org,2002:str", data.strip(), style="|")
+     return dumper.represent_scalar("tag:yaml.org,2002:str", data)
+
+
+ # --- Represent short lists inline (flow style) ---
+ def represent_inline_list(dumper, data):
+
+     if not all(isinstance(i, (str, int, float, bool, type(None))) for i in data):
+         return dumper.represent_sequence("tag:yaml.org,2002:seq", data)
+
+     return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True)
+
+
+ def yaml_dump(obj, order=None, stream=None, **kwargs):
+
+     if order:
+
+         def _ordering(k):
+             return order.index(k) if k in order else len(order)
+
+         obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))}
+
+     yaml = ruamel.yaml.YAML()
+     yaml.width = 120  # wrap long flow sequences
+
+     yaml.Representer.add_representer(datetime.date, represent_date)
+     yaml.Representer.add_representer(datetime.datetime, represent_date)
+     yaml.Representer.add_representer(str, represent_multiline_str)
+     yaml.Representer.add_representer(list, represent_inline_list)
+
+     data = ruamel.yaml.comments.CommentedMap()
+     for i, (k, v) in enumerate(obj.items()):
+         data[k] = v
+         if i > 0:
+             data.yaml_set_comment_before_after_key(key=k, before="\n")
+
+     if stream:
+         yaml.dump(data, stream=stream, **kwargs)
+
+     stream = io.StringIO()
+     yaml.dump(data, stream=stream, **kwargs)
+     return stream.getvalue()
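A usage sketch of the new helper; `yaml_dump` and its `order`/`stream` parameters are as defined above, while the recipe dictionary is invented for illustration:

```python
import datetime

from anemoi.datasets.dumper import yaml_dump

recipe = {
    "input": {"grib": {"path": "data.grib"}},
    "name": "example",
    "dates": {"start": datetime.datetime(2023, 1, 1)},
}

# Keys are emitted in the requested order, top-level keys are separated by
# blank lines, and datetimes should render as UTC timestamps ("...T00:00:00Z").
print(yaml_dump(recipe, order=["name", "dates", "input"]))
```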
anemoi/datasets/grids.py CHANGED
@@ -8,11 +8,11 @@
  # nor does it submit to any jurisdiction.


- import base64
  import logging
  from typing import Any

  import numpy as np
+ from anemoi.utils.grids import latlon_to_xyz
  from numpy.typing import NDArray

  LOG = logging.getLogger(__name__)
@@ -88,71 +88,6 @@ def plot_mask(
      plt.savefig(path + "-global-zoomed.png")


- # TODO: Use the one from anemoi.utils.grids instead
- # from anemoi.utils.grids import ...
- def xyz_to_latlon(x: NDArray[Any], y: NDArray[Any], z: NDArray[Any]) -> tuple[NDArray[Any], NDArray[Any]]:
-     """Convert Cartesian coordinates to latitude and longitude.
-
-     Parameters
-     ----------
-     x : NDArray[Any]
-         X coordinates.
-     y : NDArray[Any]
-         Y coordinates.
-     z : NDArray[Any]
-         Z coordinates.
-
-     Returns
-     -------
-     Tuple[NDArray[Any], NDArray[Any]]
-         Latitude and longitude coordinates.
-     """
-     return (
-         np.rad2deg(np.arcsin(np.minimum(1.0, np.maximum(-1.0, z)))),
-         np.rad2deg(np.arctan2(y, x)),
-     )
-
-
- # TODO: Use the one from anemoi.utils.grids instead
- # from anemoi.utils.grids import ...
- def latlon_to_xyz(
-     lat: NDArray[Any], lon: NDArray[Any], radius: float = 1.0
- ) -> tuple[NDArray[Any], NDArray[Any], NDArray[Any]]:
-     """Convert latitude and longitude to Cartesian coordinates.
-
-     Parameters
-     ----------
-     lat : NDArray[Any]
-         Latitude coordinates.
-     lon : NDArray[Any]
-         Longitude coordinates.
-     radius : float, optional
-         Radius of the sphere. Defaults to 1.0.
-
-     Returns
-     -------
-     Tuple[NDArray[Any], NDArray[Any], NDArray[Any]]
-         X, Y, and Z coordinates.
-     """
-     # https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates
-     # We assume that the Earth is a sphere of radius 1 so N(phi) = 1
-     # We assume h = 0
-     #
-     phi = np.deg2rad(lat)
-     lda = np.deg2rad(lon)
-
-     cos_phi = np.cos(phi)
-     cos_lda = np.cos(lda)
-     sin_phi = np.sin(phi)
-     sin_lda = np.sin(lda)
-
-     x = cos_phi * cos_lda * radius
-     y = cos_phi * sin_lda * radius
-     z = sin_phi * radius
-
-     return x, y, z
-
-
  class Triangle3D:
      """A class to represent a 3D triangle and perform intersection tests with rays."""
@@ -509,92 +444,6 @@ def outline(lats: NDArray[Any], lons: NDArray[Any], neighbours: int = 5) -> list
      return outside


- def deserialise_mask(encoded: str) -> NDArray[Any]:
-     """Deserialise a mask from a base64 encoded string.
-
-     Parameters
-     ----------
-     encoded : str
-         Base64 encoded string.
-
-     Returns
-     -------
-     NDArray[Any]
-         Deserialised mask array.
-     """
-     import pickle
-     import zlib
-
-     packed = pickle.loads(zlib.decompress(base64.b64decode(encoded)))
-
-     mask = []
-     value = False
-     for count in packed:
-         mask.extend([value] * count)
-         value = not value
-     return np.array(mask, dtype=bool)
-
-
- def _serialise_mask(mask: NDArray[Any]) -> str:
-     """Serialise a mask to a base64 encoded string.
-
-     Parameters
-     ----------
-     mask : NDArray[Any]
-         Mask array.
-
-     Returns
-     -------
-     str
-         Base64 encoded string.
-     """
-     import pickle
-     import zlib
-
-     assert len(mask.shape) == 1
-     assert len(mask)
-
-     packed = []
-     last = mask[0]
-     count = 1
-
-     for value in mask[1:]:
-         if value == last:
-             count += 1
-         else:
-             packed.append(count)
-             last = value
-             count = 1
-
-     packed.append(count)
-
-     # We always start with an 'off' value
-     # So if the first value is 'on', we need to add a zero
-     if mask[0]:
-         packed.insert(0, 0)
-
-     return base64.b64encode(zlib.compress(pickle.dumps(packed))).decode("utf-8")
-
-
- def serialise_mask(mask: NDArray[Any]) -> str:
-     """Serialise a mask and ensure it can be deserialised.
-
-     Parameters
-     ----------
-     mask : NDArray[Any]
-         Mask array.
-
-     Returns
-     -------
-     str
-         Base64 encoded string.
-     """
-     result = _serialise_mask(mask)
-     # Make sure we can deserialise it
-     assert np.all(mask == deserialise_mask(result))
-     return result
-
-
  def nearest_grid_points(
      source_latitudes: NDArray[Any],
      source_longitudes: NDArray[Any],
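The removed serialisation is a run-length encoding whose runs always start with an "off" value (hence the leading zero when the mask begins with True), pickled, zlib-compressed, and base64-encoded. A self-contained illustration of the run-length part:

```python
import numpy as np

mask = np.array([True, True, False, True])
packed = [0, 2, 1, 1]  # leading 0 because the mask starts with an "on" run

value, decoded = False, []
for count in packed:   # same loop as the removed deserialise_mask
    decoded.extend([value] * count)
    value = not value

assert np.array_equal(np.array(decoded, dtype=bool), mask)
```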
@@ -640,29 +489,3 @@ def nearest_grid_points(
      else:
          distances, indices = cKDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance)
      return distances, indices
-
-
- if __name__ == "__main__":
-     global_lats, global_lons = np.meshgrid(
-         np.linspace(90, -90, 90),
-         np.linspace(-180, 180, 180),
-     )
-     global_lats = global_lats.flatten()
-     global_lons = global_lons.flatten()
-
-     lats, lons = np.meshgrid(
-         np.linspace(50, 40, 100),
-         np.linspace(-10, 15, 100),
-     )
-     lats = lats.flatten()
-     lons = lons.flatten()
-
-     mask = cutout_mask(lats, lons, global_lats, global_lons, cropping_distance=5.0)
-
-     import matplotlib.pyplot as plt
-
-     fig = plt.figure(figsize=(10, 5))
-     plt.scatter(global_lons, global_lats, s=0.01, marker="o", c="r")
-     plt.scatter(global_lons[mask], global_lats[mask], s=0.1, c="k")
-     # plt.scatter(lons, lats, s=0.01)
-     plt.savefig("cutout.png")