anemoi-datasets 0.5.12__py3-none-any.whl → 0.5.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/create.py +4 -2
  3. anemoi/datasets/create/__init__.py +22 -6
  4. anemoi/datasets/create/check.py +1 -1
  5. anemoi/datasets/create/functions/__init__.py +15 -1
  6. anemoi/datasets/create/functions/filters/orog_to_z.py +58 -0
  7. anemoi/datasets/create/functions/filters/sum.py +71 -0
  8. anemoi/datasets/create/functions/filters/wz_to_w.py +79 -0
  9. anemoi/datasets/create/functions/sources/accumulations.py +7 -2
  10. anemoi/datasets/create/functions/sources/eccc_fstd.py +16 -0
  11. anemoi/datasets/create/functions/sources/mars.py +5 -1
  12. anemoi/datasets/create/functions/sources/xarray/__init__.py +3 -3
  13. anemoi/datasets/create/functions/sources/xarray/field.py +5 -1
  14. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +10 -1
  15. anemoi/datasets/create/functions/sources/xarray/metadata.py +5 -11
  16. anemoi/datasets/create/functions/sources/xarray/patch.py +44 -0
  17. anemoi/datasets/create/functions/sources/xarray/time.py +15 -0
  18. anemoi/datasets/create/functions/sources/xarray/variable.py +18 -2
  19. anemoi/datasets/create/input/repeated_dates.py +18 -0
  20. anemoi/datasets/create/input/result.py +1 -1
  21. anemoi/datasets/create/statistics/__init__.py +7 -4
  22. anemoi/datasets/create/utils.py +4 -0
  23. anemoi/datasets/data/complement.py +164 -0
  24. anemoi/datasets/data/dataset.py +68 -5
  25. anemoi/datasets/data/ensemble.py +55 -0
  26. anemoi/datasets/data/join.py +1 -2
  27. anemoi/datasets/data/merge.py +3 -0
  28. anemoi/datasets/data/misc.py +34 -1
  29. anemoi/datasets/grids.py +29 -10
  30. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/METADATA +2 -2
  31. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/RECORD +35 -29
  32. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/WHEEL +1 -1
  33. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/LICENSE +0 -0
  34. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/entry_points.txt +0 -0
  35. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/top_level.txt +0 -0
@@ -37,7 +37,7 @@ class Variable:
37
37
  self.coordinates = coordinates
38
38
 
39
39
  self._metadata = metadata.copy()
40
- self._metadata.update({"variable": variable.name})
40
+ self._metadata.update({"variable": variable.name, "param": variable.name})
41
41
 
42
42
  self.time = time
43
43
 
@@ -45,6 +45,9 @@ class Variable:
45
45
  self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid}
46
46
  self.by_name = {c.variable.name: c for c in coordinates}
47
47
 
48
+ # We need that alias for the time dimension
49
+ self._aliases = dict(valid_datetime="time")
50
+
48
51
  self.length = math.prod(self.shape)
49
52
 
50
53
  @property
@@ -96,15 +99,28 @@ class Variable:
96
99
 
97
100
  k, v = kwargs.popitem()
98
101
 
102
+ user_provided_k = k
103
+
104
+ if k == "valid_datetime":
105
+ # Ask the Time object to select the valid datetime
106
+ k = self.time.select_valid_datetime(self)
107
+ if k is None:
108
+ return None
109
+
99
110
  c = self.by_name.get(k)
100
111
 
112
+ # assert c is not None, f"Could not find coordinate {k} in {self.variable.name} {self.coordinates} {list(self.by_name)}"
113
+
101
114
  if c is None:
102
115
  missing[k] = v
103
116
  return self.sel(missing, **kwargs)
104
117
 
105
118
  i = c.index(v)
106
119
  if i is None:
107
- LOG.warning(f"Could not find {k}={v} in {c}")
120
+ if k != user_provided_k:
121
+ LOG.warning(f"Could not find {user_provided_k}={v} in {c} (alias of {k})")
122
+ else:
123
+ LOG.warning(f"Could not find {k}={v} in {c}")
108
124
  return None
109
125
 
110
126
  coordinates = [x.reduced(i) if c is x else x for x in self.coordinates]
@@ -72,6 +72,11 @@ class DateMapperClosest(DateMapper):
72
72
  end += self.frequency
73
73
 
74
74
  to_try = sorted(to_try - self.tried)
75
+ info = {k: "no-data" for k in to_try}
76
+
77
+ if not to_try:
78
+ LOG.warning(f"No new dates to try for {group_of_dates} in {self.source}")
79
+ # return []
75
80
 
76
81
  if to_try:
77
82
  result = self.source.select(
@@ -82,19 +87,32 @@ class DateMapperClosest(DateMapper):
82
87
  )
83
88
  )
84
89
 
90
+ cnt = 0
85
91
  for f in result.datasource:
92
+ cnt += 1
86
93
  # We could keep the fields in a dictionary, but we don't want to keep the fields in memory
87
94
  date = as_datetime(f.metadata("valid_datetime"))
88
95
 
89
96
  if self.skip_all_nans:
90
97
  if np.isnan(f.to_numpy()).all():
91
98
  LOG.warning(f"Skipping {date} because all values are NaN")
99
+ info[date] = "all-nans"
92
100
  continue
93
101
 
102
+ info[date] = "ok"
94
103
  self.found.add(date)
95
104
 
105
+ if cnt == 0:
106
+ raise ValueError(f"No data found for {group_of_dates} in {self.source}")
107
+
96
108
  self.tried.update(to_try)
97
109
 
110
+ if not self.found:
111
+ for k, v in info.items():
112
+ LOG.warning(f"{k}: {v}")
113
+
114
+ raise ValueError(f"No matching data found for {asked_dates} in {self.source}")
115
+
98
116
  new_dates = defaultdict(list)
99
117
 
100
118
  for date in asked_dates:
@@ -459,7 +459,7 @@ class Result:
459
459
  if self.group_of_dates is not None:
460
460
  dates = f" {len(self.group_of_dates)} dates"
461
461
  dates += " ("
462
- dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.group_of_dates)
462
+ dates += "/".join(d.strftime("%Y-%m-%dT%H:%M") for d in self.group_of_dates)
463
463
  if len(dates) > 100:
464
464
  dates = dates[:100] + "..."
465
465
  dates += ")"
@@ -18,6 +18,7 @@ import shutil
18
18
  import socket
19
19
 
20
20
  import numpy as np
21
+ import tqdm
21
22
  from anemoi.utils.provenance import gather_provenance_info
22
23
 
23
24
  from ..check import check_data_values
@@ -98,7 +99,7 @@ def fix_variance(x, name, count, sums, squares):
98
99
 
99
100
  variances = squares / count - mean * mean
100
101
  assert variances.shape == squares.shape == mean.shape
101
- if all(variances >= 0):
102
+ if np.all(variances >= 0):
102
103
  LOG.warning(f"All individual variances for {name} are positive, setting variance to 0.")
103
104
  return 0
104
105
 
@@ -108,7 +109,7 @@ def fix_variance(x, name, count, sums, squares):
108
109
  # return 0
109
110
 
110
111
  LOG.warning(f"ERROR at least one individual variance is negative ({np.nanmin(variances)}).")
111
- return x
112
+ return 0
112
113
 
113
114
 
114
115
  def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
@@ -134,7 +135,7 @@ def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squa
134
135
 
135
136
  def compute_statistics(array, check_variables_names=None, allow_nans=False):
136
137
  """Compute statistics for a given array, provides minimum, maximum, sum, squares, count and has_nans as a dictionary."""
137
-
138
+ LOG.info(f"Computing statistics for {array.shape} array")
138
139
  nvars = array.shape[1]
139
140
 
140
141
  LOG.debug(f"Stats {nvars}, {array.shape}, {check_variables_names}")
@@ -149,7 +150,7 @@ def compute_statistics(array, check_variables_names=None, allow_nans=False):
149
150
  maximum = np.zeros(stats_shape, dtype=np.float64)
150
151
  has_nans = np.zeros(stats_shape, dtype=np.bool_)
151
152
 
152
- for i, chunk in enumerate(array):
153
+ for i, chunk in tqdm.tqdm(enumerate(array), delay=1, total=array.shape[0], desc="Computing statistics"):
153
154
  values = chunk.reshape((nvars, -1))
154
155
 
155
156
  for j, name in enumerate(check_variables_names):
@@ -166,6 +167,8 @@ def compute_statistics(array, check_variables_names=None, allow_nans=False):
166
167
  count[i] = np.sum(~np.isnan(values), axis=1)
167
168
  has_nans[i] = np.isnan(values).any()
168
169
 
170
+ LOG.info(f"Statistics computed for {nvars} variables.")
171
+
169
172
  return {
170
173
  "minimum": minimum,
171
174
  "maximum": maximum,
@@ -54,6 +54,10 @@ def to_datetime(*args, **kwargs):
54
54
 
55
55
 
56
56
  def make_list_int(value):
57
+ # Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers.
58
+ # Moved to anemoi.utils.humanize
59
+ # replace with from anemoi.utils.humanize import make_list_int
60
+ # when anemoi-utils is released and pyproject.toml is updated
57
61
  if isinstance(value, str):
58
62
  if "/" not in value:
59
63
  return [value]
@@ -0,0 +1,164 @@
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+
11
+ import logging
12
+ from functools import cached_property
13
+
14
+ from ..grids import nearest_grid_points
15
+ from .debug import Node
16
+ from .forwards import Combined
17
+ from .indexing import apply_index_to_slices_changes
18
+ from .indexing import index_to_slices
19
+ from .indexing import update_tuple
20
+ from .misc import _auto_adjust
21
+ from .misc import _open
22
+
23
+ LOG = logging.getLogger(__name__)
24
+
25
+
26
+ class Complement(Combined):
27
+
28
+ def __init__(self, target, source, what="variables", interpolation="nearest"):
29
+ super().__init__([target, source])
30
+
31
+ # We add the variables of dataset[1] to dataset[0],
32
+ # interpolated on the grid of dataset[0]
33
+
34
+ self.target = target
35
+ self.source = source
36
+
37
+ self._variables = []
38
+
39
+ # Keep the same order as the original dataset
40
+ for v in self.source.variables:
41
+ if v not in self.target.variables:
42
+ self._variables.append(v)
43
+
44
+ if not self._variables:
45
+ raise ValueError("Augment: no missing variables")
46
+
47
+ @property
48
+ def variables(self):
49
+ return self._variables
50
+
51
+ @property
52
+ def name_to_index(self):
53
+ return {v: i for i, v in enumerate(self.variables)}
54
+
55
+ @property
56
+ def shape(self):
57
+ shape = self.target.shape
58
+ return (shape[0], len(self._variables)) + shape[2:]
59
+
60
+ @property
61
+ def variables_metadata(self):
62
+ return {k: v for k, v in self.source.variables_metadata.items() if k in self._variables}
63
+
64
+ def check_same_variables(self, d1, d2):
65
+ pass
66
+
67
+ @cached_property
68
+ def missing(self):
69
+ missing = self.source.missing.copy()
70
+ missing = missing | self.target.missing
71
+ return set(missing)
72
+
73
+ def tree(self):
74
+ """Generates a hierarchical tree structure for the `Complement` instance and
75
+ its associated datasets.
76
+
77
+ Returns:
78
+ Node: A `Node` object representing the `Complement` instance as the root
79
+ node, with the target and source datasets represented as child
80
+ nodes.
81
+ """
82
+ return Node(self, [d.tree() for d in (self.target, self.source)])
83
+
84
+ def __getitem__(self, index):
85
+ if isinstance(index, (int, slice)):
86
+ index = (index, slice(None), slice(None), slice(None))
87
+ return self._get_tuple(index)
88
+
89
+
90
+ class ComplementNone(Complement):
91
+
92
+ def __init__(self, target, source):
93
+ super().__init__(target, source)
94
+
95
+ def _get_tuple(self, index):
96
+ index, changes = index_to_slices(index, self.shape)
97
+ result = self.source[index]
98
+ return apply_index_to_slices_changes(result, changes)
99
+
100
+
101
+ class ComplementNearest(Complement):
102
+
103
+ def __init__(self, target, source):
104
+ super().__init__(target, source)
105
+
106
+ self._nearest_grid_points = nearest_grid_points(
107
+ self.source.latitudes,
108
+ self.source.longitudes,
109
+ self.target.latitudes,
110
+ self.target.longitudes,
111
+ )
112
+
113
+ def check_compatibility(self, d1, d2):
114
+ pass
115
+
116
+ def _get_tuple(self, index):
117
+ variable_index = 1
118
+ index, changes = index_to_slices(index, self.shape)
119
+ index, previous = update_tuple(index, variable_index, slice(None))
120
+ source_index = [self.source.name_to_index[x] for x in self.variables[previous]]
121
+ source_data = self.source[index[0], source_index, index[2], ...]
122
+ target_data = source_data[..., self._nearest_grid_points]
123
+
124
+ result = target_data[..., index[3]]
125
+
126
+ return apply_index_to_slices_changes(result, changes)
127
+
128
+
129
+ def complement_factory(args, kwargs):
130
+ from .select import Select
131
+
132
+ assert len(args) == 0, args
133
+
134
+ source = kwargs.pop("source")
135
+ target = kwargs.pop("complement")
136
+ what = kwargs.pop("what", "variables")
137
+ interpolation = kwargs.pop("interpolation", "none")
138
+
139
+ if what != "variables":
140
+ raise NotImplementedError(f"Complement what={what} not implemented")
141
+
142
+ if interpolation not in ("none", "nearest"):
143
+ raise NotImplementedError(f"Complement method={interpolation} not implemented")
144
+
145
+ source = _open(source)
146
+ target = _open(target)
147
+ # `select` is the same as `variables`
148
+ (source, target), kwargs = _auto_adjust([source, target], kwargs, exclude=["select"])
149
+
150
+ Class = {
151
+ None: ComplementNone,
152
+ "none": ComplementNone,
153
+ "nearest": ComplementNearest,
154
+ }[interpolation]
155
+
156
+ complement = Class(target=target, source=source)._subset(**kwargs)
157
+
158
+ # Will join the datasets along the variables axis
159
+ reorder = source.variables
160
+ complemented = _open([target, complement])
161
+ ordered = (
162
+ Select(complemented, complemented._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate()
163
+ )
164
+ return ordered
@@ -168,6 +168,16 @@ class Dataset:
168
168
  bbox = kwargs.pop("area")
169
169
  return Cropping(self, bbox)._subset(**kwargs).mutate()
170
170
 
171
+ if "number" in kwargs or "numbers" or "member" in kwargs or "members" in kwargs:
172
+ from .ensemble import Number
173
+
174
+ members = {}
175
+ for key in ["number", "numbers", "member", "members"]:
176
+ if key in kwargs:
177
+ members[key] = kwargs.pop(key)
178
+
179
+ return Number(self, **members)._subset(**kwargs).mutate()
180
+
171
181
  if "set_missing_dates" in kwargs:
172
182
  from .missing import MissingDates
173
183
 
@@ -251,13 +261,19 @@ class Dataset:
251
261
  return sorted([v for k, v in self.name_to_index.items() if k not in vars])
252
262
 
253
263
  def _reorder_to_columns(self, vars):
264
+ if isinstance(vars, str) and vars == "sort":
265
+ # Sorting the variables alphabetically.
266
+ # This is crucial for pre-training then transfer learning in combination with
267
+ # cutout and adjust = 'all'
268
+
269
+ indices = [self.name_to_index[k] for k, v in sorted(self.name_to_index.items(), key=lambda x: x[0])]
270
+ assert set(indices) == set(range(len(self.name_to_index)))
271
+ return indices
272
+
254
273
  if isinstance(vars, (list, tuple)):
255
274
  vars = {k: i for i, k in enumerate(vars)}
256
275
 
257
- indices = []
258
-
259
- for k, v in sorted(vars.items(), key=lambda x: x[1]):
260
- indices.append(self.name_to_index[k])
276
+ indices = [self.name_to_index[k] for k, v in sorted(vars.items(), key=lambda x: x[1])]
261
277
 
262
278
  # Make sure we don't forget any variables
263
279
  assert set(indices) == set(range(len(self.name_to_index)))
@@ -469,7 +485,7 @@ class Dataset:
469
485
  sample_count = min(4, len(indices))
470
486
  count = len(indices)
471
487
 
472
- p = slice(0, count, count // (sample_count - 1))
488
+ p = slice(0, count, count // max(1, sample_count - 1))
473
489
  samples = list(range(*p.indices(count)))
474
490
 
475
491
  samples.append(count - 1) # Add last
@@ -502,3 +518,50 @@ class Dataset:
502
518
  result.append(v)
503
519
 
504
520
  return result
521
+
522
+ def plot(self, date, variable, member=0, **kwargs):
523
+ """For debugging purposes, plot a field.
524
+
525
+ Parameters
526
+ ----------
527
+ date : int or datetime.datetime or numpy.datetime64 or str
528
+ The date to plot.
529
+ variable : int or str
530
+ The variable to plot.
531
+ member : int, optional
532
+ The ensemble member to plot.
533
+
534
+ **kwargs:
535
+ Additional arguments to pass to matplotlib.pyplot.tricontourf
536
+
537
+
538
+ Returns
539
+ -------
540
+ matplotlib.pyplot.Axes
541
+ """
542
+
543
+ from anemoi.utils.devtools import plot_values
544
+ from earthkit.data.utils.dates import to_datetime
545
+
546
+ if not isinstance(date, int):
547
+ date = np.datetime64(to_datetime(date)).astype(self.dates[0].dtype)
548
+ index = np.where(self.dates == date)[0]
549
+ if len(index) == 0:
550
+ raise ValueError(
551
+ f"Date {date} not found in the dataset {self.dates[0]} to {self.dates[-1]} by {self.frequency}"
552
+ )
553
+ date_index = index[0]
554
+ else:
555
+ date_index = date
556
+
557
+ if isinstance(variable, int):
558
+ variable_index = variable
559
+ else:
560
+ if variable not in self.variables:
561
+ raise ValueError(f"Unknown variable {variable} (available: {self.variables})")
562
+
563
+ variable_index = self.name_to_index[variable]
564
+
565
+ values = self[date_index, variable_index, member]
566
+
567
+ return plot_values(values, self.latitudes, self.longitudes, **kwargs)
@@ -10,13 +10,68 @@
10
10
 
11
11
  import logging
12
12
 
13
+ import numpy as np
14
+
13
15
  from .debug import Node
16
+ from .forwards import Forwards
14
17
  from .forwards import GivenAxis
18
+ from .indexing import apply_index_to_slices_changes
19
+ from .indexing import index_to_slices
20
+ from .indexing import update_tuple
15
21
  from .misc import _auto_adjust
16
22
  from .misc import _open
17
23
 
18
24
  LOG = logging.getLogger(__name__)
19
25
 
26
+ OFFSETS = dict(number=1, numbers=1, member=0, members=0)
27
+
28
+
29
+ class Number(Forwards):
30
+ def __init__(self, forward, **kwargs):
31
+ super().__init__(forward)
32
+
33
+ self.members = []
34
+ for key, values in kwargs.items():
35
+ if not isinstance(values, (list, tuple)):
36
+ values = [values]
37
+ self.members.extend([int(v) - OFFSETS[key] for v in values])
38
+
39
+ self.members = sorted(set(self.members))
40
+ for n in self.members:
41
+ if not (0 <= n < forward.shape[2]):
42
+ raise ValueError(f"Member {n} is out of range. `number(s)` is one-based, `member(s)` is zero-based.")
43
+
44
+ self.mask = np.array([n in self.members for n in range(forward.shape[2])], dtype=bool)
45
+ self._shape, _ = update_tuple(forward.shape, 2, len(self.members))
46
+
47
+ @property
48
+ def shape(self):
49
+ return self._shape
50
+
51
+ def __getitem__(self, index):
52
+ if isinstance(index, int):
53
+ result = self.forward[index]
54
+ result = result[:, self.mask, :]
55
+ return result
56
+
57
+ if isinstance(index, slice):
58
+ result = self.forward[index]
59
+ result = result[:, :, self.mask, :]
60
+ return result
61
+
62
+ index, changes = index_to_slices(index, self.shape)
63
+ result = self.forward[index]
64
+ result = result[:, :, self.mask, :]
65
+ return apply_index_to_slices_changes(result, changes)
66
+
67
+ def tree(self):
68
+ return Node(self, [self.forward.tree()], numbers=[n + 1 for n in self.members])
69
+
70
+ def metadata_specific(self):
71
+ return {
72
+ "numbers": [n + 1 for n in self.members],
73
+ }
74
+
20
75
 
21
76
  class Ensemble(GivenAxis):
22
77
  def tree(self):
@@ -118,6 +118,7 @@ class Join(Combined):
118
118
  def variables_metadata(self):
119
119
  result = {}
120
120
  variables = [v for v in self.variables if not (v.startswith("(") and v.endswith(")"))]
121
+
121
122
  for d in self.datasets:
122
123
  md = d.variables_metadata
123
124
  for v in variables:
@@ -130,8 +131,6 @@ class Join(Combined):
130
131
  if v not in result:
131
132
  LOG.error("Missing metadata for %r.", v)
132
133
 
133
- raise ValueError("Some variables are missing metadata.")
134
-
135
134
  return result
136
135
 
137
136
  @cached_property
@@ -134,6 +134,9 @@ class Merge(Combined):
134
134
  def tree(self):
135
135
  return Node(self, [d.tree() for d in self.datasets], allow_gaps_in_dates=self.allow_gaps_in_dates)
136
136
 
137
+ def metadata_specific(self):
138
+ return {"allow_gaps_in_dates": self.allow_gaps_in_dates}
139
+
137
140
  @debug_indexing
138
141
  def __getitem__(self, n):
139
142
  if isinstance(n, tuple):
@@ -103,6 +103,30 @@ def _as_date(d, dates, last):
103
103
 
104
104
  if isinstance(d, str):
105
105
 
106
+ def isfloat(s):
107
+ try:
108
+ float(s)
109
+ return True
110
+ except ValueError:
111
+ return False
112
+
113
+ if d.endswith("%") and isfloat(d[:-1]):
114
+ x = float(d[:-1])
115
+ if not 0 <= x <= 100:
116
+ raise ValueError(f"Invalid date: {d}")
117
+ i_float = x * len(dates) / 100
118
+
119
+ epsilon = 2 ** (-30)
120
+ if len(dates) > 1 / epsilon:
121
+ LOG.warning("Too many dates to use percentage, one date may be lost in rounding")
122
+
123
+ if last:
124
+ index = int(i_float + epsilon) - 1
125
+ else:
126
+ index = int(i_float - epsilon)
127
+ index = max(0, min(len(dates) - 1, index))
128
+ return dates[index]
129
+
106
130
  if "-" in d and ":" in d:
107
131
  date, time = d.replace(" ", "T").split("T")
108
132
  year, month, day = [int(_) for _ in date.split("-")]
@@ -194,7 +218,7 @@ def _open(a):
194
218
  raise NotImplementedError(f"Unsupported argument: {type(a)}")
195
219
 
196
220
 
197
- def _auto_adjust(datasets, kwargs):
221
+ def _auto_adjust(datasets, kwargs, exclude=None):
198
222
 
199
223
  if "adjust" not in kwargs:
200
224
  return datasets, kwargs
@@ -214,6 +238,9 @@ def _auto_adjust(datasets, kwargs):
214
238
  for a in adjust_list:
215
239
  adjust_set.update(ALIASES.get(a, [a]))
216
240
 
241
+ if exclude is not None:
242
+ adjust_set -= set(exclude)
243
+
217
244
  extra = set(adjust_set) - set(ALIASES["all"])
218
245
  if extra:
219
246
  raise ValueError(f"Invalid adjust keys: {extra}")
@@ -335,6 +362,12 @@ def _open_dataset(*args, **kwargs):
335
362
  assert not sets, sets
336
363
  return cutout_factory(args, kwargs).mutate()
337
364
 
365
+ if "complement" in kwargs:
366
+ from .complement import complement_factory
367
+
368
+ assert not sets, sets
369
+ return complement_factory(args, kwargs).mutate()
370
+
338
371
  for name in ("datasets", "dataset"):
339
372
  if name in kwargs:
340
373
  datasets = kwargs.pop(name)
anemoi/datasets/grids.py CHANGED
@@ -62,6 +62,8 @@ def plot_mask(path, mask, lats, lons, global_lats, global_lons):
62
62
  plt.savefig(path + "-global-zoomed.png")
63
63
 
64
64
 
65
+ # TODO: Use the one from anemoi.utils.grids instead
66
+ # from anemoi.utils.grids import ...
65
67
  def xyz_to_latlon(x, y, z):
66
68
  return (
67
69
  np.rad2deg(np.arcsin(np.minimum(1.0, np.maximum(-1.0, z)))),
@@ -69,6 +71,8 @@ def xyz_to_latlon(x, y, z):
69
71
  )
70
72
 
71
73
 
74
+ # TODO: Use the one from anemoi.utils.grids instead
75
+ # from anemoi.utils.grids import ...
72
76
  def latlon_to_xyz(lat, lon, radius=1.0):
73
77
  # https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates
74
78
  # We assume that the Earth is a sphere of radius 1 so N(phi) = 1
@@ -152,7 +156,7 @@ def cutout_mask(
152
156
  plot=None,
153
157
  ):
154
158
  """Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]"""
155
- from scipy.spatial import KDTree
159
+ from scipy.spatial import cKDTree
156
160
 
157
161
  # TODO: transform min_distance from lat/lon to xyz
158
162
 
@@ -195,13 +199,13 @@ def cutout_mask(
195
199
  min_distance = min_distance_km / 6371.0
196
200
  else:
197
201
  points = {"lam": lam_points, "global": global_points, None: global_points}[min_distance_km]
198
- distances, _ = KDTree(points).query(points, k=2)
202
+ distances, _ = cKDTree(points).query(points, k=2)
199
203
  min_distance = np.min(distances[:, 1])
200
204
 
201
205
  LOG.info(f"cutout_mask using min_distance = {min_distance * 6371.0} km")
202
206
 
203
- # Use a KDTree to find the nearest points
204
- distances, indices = KDTree(lam_points).query(global_points, k=neighbours)
207
+ # Use a cKDTree to find the nearest points
208
+ distances, indices = cKDTree(lam_points).query(global_points, k=neighbours)
205
209
 
206
210
  # Centre of the Earth
207
211
  zero = np.array([0.0, 0.0, 0.0])
@@ -255,7 +259,7 @@ def thinning_mask(
255
259
  cropping_distance=2.0,
256
260
  ):
257
261
  """Return the list of points in [lats, lons] closest to [global_lats, global_lons]"""
258
- from scipy.spatial import KDTree
262
+ from scipy.spatial import cKDTree
259
263
 
260
264
  assert global_lats.ndim == 1
261
265
  assert global_lons.ndim == 1
@@ -291,20 +295,20 @@ def thinning_mask(
291
295
  xyx = latlon_to_xyz(lats, lons)
292
296
  points = np.array(xyx).transpose()
293
297
 
294
- # Use a KDTree to find the nearest points
295
- _, indices = KDTree(points).query(global_points, k=1)
298
+ # Use a cKDTree to find the nearest points
299
+ _, indices = cKDTree(points).query(global_points, k=1)
296
300
 
297
301
  return np.array([i for i in indices])
298
302
 
299
303
 
300
304
  def outline(lats, lons, neighbours=5):
301
- from scipy.spatial import KDTree
305
+ from scipy.spatial import cKDTree
302
306
 
303
307
  xyx = latlon_to_xyz(lats, lons)
304
308
  grid_points = np.array(xyx).transpose()
305
309
 
306
- # Use a KDTree to find the nearest points
307
- _, indices = KDTree(grid_points).query(grid_points, k=neighbours)
310
+ # Use a cKDTree to find the nearest points
311
+ _, indices = cKDTree(grid_points).query(grid_points, k=neighbours)
308
312
 
309
313
  # Centre of the Earth
310
314
  zero = np.array([0.0, 0.0, 0.0])
@@ -379,6 +383,21 @@ def serialise_mask(mask):
379
383
  return result
380
384
 
381
385
 
386
+ def nearest_grid_points(source_latitudes, source_longitudes, target_latitudes, target_longitudes):
387
+ # TODO: Use the one from anemoi.utils.grids instead
388
+ # from anemoi.utils.grids import ...
389
+ from scipy.spatial import cKDTree
390
+
391
+ source_xyz = latlon_to_xyz(source_latitudes, source_longitudes)
392
+ source_points = np.array(source_xyz).transpose()
393
+
394
+ target_xyz = latlon_to_xyz(target_latitudes, target_longitudes)
395
+ target_points = np.array(target_xyz).transpose()
396
+
397
+ _, indices = cKDTree(source_points).query(target_points, k=1)
398
+ return indices
399
+
400
+
382
401
  if __name__ == "__main__":
383
402
  global_lats, global_lons = np.meshgrid(
384
403
  np.linspace(90, -90, 90),