anemoi-datasets 0.4.4__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/cleanup.py +44 -0
  3. anemoi/datasets/commands/create.py +50 -20
  4. anemoi/datasets/commands/finalise-additions.py +45 -0
  5. anemoi/datasets/commands/finalise.py +39 -0
  6. anemoi/datasets/commands/init-additions.py +45 -0
  7. anemoi/datasets/commands/init.py +67 -0
  8. anemoi/datasets/commands/inspect.py +1 -1
  9. anemoi/datasets/commands/load-additions.py +47 -0
  10. anemoi/datasets/commands/load.py +47 -0
  11. anemoi/datasets/commands/patch.py +39 -0
  12. anemoi/datasets/create/__init__.py +961 -146
  13. anemoi/datasets/create/check.py +5 -3
  14. anemoi/datasets/create/config.py +53 -2
  15. anemoi/datasets/create/functions/sources/xarray/__init__.py +12 -2
  16. anemoi/datasets/create/functions/sources/xarray/coordinates.py +7 -0
  17. anemoi/datasets/create/functions/sources/xarray/field.py +1 -1
  18. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +0 -2
  19. anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -1
  20. anemoi/datasets/create/functions/sources/xarray/metadata.py +27 -29
  21. anemoi/datasets/create/functions/sources/xarray/time.py +63 -30
  22. anemoi/datasets/create/functions/sources/xarray/variable.py +15 -38
  23. anemoi/datasets/create/input.py +23 -22
  24. anemoi/datasets/create/statistics/__init__.py +39 -23
  25. anemoi/datasets/create/utils.py +3 -2
  26. anemoi/datasets/data/__init__.py +1 -0
  27. anemoi/datasets/data/concat.py +46 -2
  28. anemoi/datasets/data/dataset.py +109 -34
  29. anemoi/datasets/data/forwards.py +17 -8
  30. anemoi/datasets/data/grids.py +17 -3
  31. anemoi/datasets/data/interpolate.py +133 -0
  32. anemoi/datasets/data/misc.py +56 -66
  33. anemoi/datasets/data/missing.py +240 -0
  34. anemoi/datasets/data/select.py +7 -1
  35. anemoi/datasets/data/stores.py +3 -3
  36. anemoi/datasets/data/subset.py +47 -5
  37. anemoi/datasets/data/unchecked.py +20 -22
  38. anemoi/datasets/data/xy.py +125 -0
  39. anemoi/datasets/dates/__init__.py +13 -66
  40. anemoi/datasets/dates/groups.py +2 -2
  41. anemoi/datasets/grids.py +66 -48
  42. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/METADATA +5 -5
  43. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/RECORD +47 -37
  44. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/WHEEL +1 -1
  45. anemoi/datasets/create/loaders.py +0 -936
  46. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/LICENSE +0 -0
  47. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/entry_points.txt +0 -0
  48. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/top_level.txt +0 -0
@@ -79,6 +79,37 @@ def to_datetimes(dates):
79
79
  return [to_datetime(d) for d in dates]
80
80
 
81
81
 
82
def fix_variance(x, name, count, sums, squares):
    """Sanitize a variance value that may have gone slightly negative.

    Negative variances can appear because ``squares / count - mean * mean``
    is numerically unstable when the variance is tiny compared to the mean.

    Parameters
    ----------
    x : float
        The aggregated variance for the variable ``name``.
    count, sums, squares : np.ndarray
        Per-chunk statistics (same shape) used to diagnose the problem.

    Returns
    -------
    float
        ``x`` unchanged when non-negative, ``0.0`` when every per-chunk
        variance is non-negative (pure rounding error), otherwise ``x``
        (a genuinely suspicious value, logged as an error).
    """
    assert count.shape == sums.shape == squares.shape
    assert isinstance(x, float)

    mean = sums / count
    assert mean.shape == count.shape

    if x >= 0:
        return x

    LOG.warning(f"Negative variance for {name=}, variance={x}")
    magnitude = np.sqrt((squares / count + mean * mean) / 2)
    LOG.warning(f"square / count - mean * mean = {squares/count} - {mean*mean} = {squares/count - mean*mean}")
    LOG.warning(f"Variable span order of magnitude is {magnitude}.")
    LOG.warning(f"Count is {count}.")

    variances = squares / count - mean * mean
    assert variances.shape == squares.shape == mean.shape
    # np.all works for any array rank (builtin all only handles 1-d arrays).
    if np.all(variances >= 0):
        LOG.warning(f"All individual variances for {name} are positive, setting variance to 0.")
        # Return a float to honour the `isinstance(x, float)` contract above.
        return 0.0

    LOG.warning(f"ERROR at least one individual variance is negative ({np.nanmin(variances)}).")
    return x
111
+
112
+
82
113
  def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
83
114
  if (x >= 0).all():
84
115
  return
@@ -292,39 +323,24 @@ class StatAggregator:
292
323
  def aggregate(self):
293
324
  minimum = np.nanmin(self.minimum, axis=0)
294
325
  maximum = np.nanmax(self.maximum, axis=0)
326
+
295
327
  sums = np.nansum(self.sums, axis=0)
296
328
  squares = np.nansum(self.squares, axis=0)
297
329
  count = np.nansum(self.count, axis=0)
298
330
  has_nans = np.any(self.has_nans, axis=0)
299
- mean = sums / count
331
+ assert sums.shape == count.shape == squares.shape == minimum.shape == maximum.shape
300
332
 
301
- assert sums.shape == count.shape == squares.shape == mean.shape == minimum.shape == maximum.shape
333
+ mean = sums / count
334
+ assert mean.shape == minimum.shape
302
335
 
303
336
  x = squares / count - mean * mean
304
-
305
- # def fix_variance(x, name, minimum, maximum, mean, count, sums, squares):
306
- # assert x.shape == minimum.shape == maximum.shape == mean.shape == count.shape == sums.shape == squares.shape
307
- # assert x.shape == (1,)
308
- # x, minimum, maximum, mean, count, sums, squares = x[0], minimum[0], maximum[0], mean[0], count[0], sums[0], squares[0]
309
- # if x >= 0:
310
- # return x
311
- #
312
- # order = np.sqrt((squares / count + mean * mean)/2)
313
- # range = maximum - minimum
314
- # LOG.warning(f"Negative variance for {name=}, variance={x}")
315
- # LOG.warning(f"square / count - mean * mean = {squares / count} - {mean * mean} = {squares / count - mean * mean}")
316
- # LOG.warning(f"Variable order of magnitude is {order}.")
317
- # LOG.warning(f"Range is {range} ({maximum=} - {minimum=}).")
318
- # LOG.warning(f"Count is {count}.")
319
- # if abs(x) < order * 1e-6 and abs(x) < range * 1e-6:
320
- # LOG.warning(f"Variance is negative but very small, setting to 0.")
321
- # return x*0
322
- # return x
337
+ assert x.shape == minimum.shape
323
338
 
324
339
  for i, name in enumerate(self.variables_names):
325
340
  # remove negative variance due to numerical errors
326
- # Not needed for now, fix_variance is disabled
327
- # x[i] = fix_variance(x[i:i+1], name, minimum[i:i+1], maximum[i:i+1], mean[i:i+1], count[i:i+1], sums[i:i+1], squares[i:i+1])
341
+ x[i] = fix_variance(x[i], name, self.count[i : i + 1], self.sums[i : i + 1], self.squares[i : i + 1])
342
+
343
+ for i, name in enumerate(self.variables_names):
328
344
  check_variance(
329
345
  x[i : i + 1],
330
346
  [name],
@@ -7,6 +7,7 @@
7
7
  # nor does it submit to any jurisdiction.
8
8
  #
9
9
 
10
+ import datetime
10
11
  import os
11
12
  from contextlib import contextmanager
12
13
 
@@ -61,10 +62,10 @@ def make_list_int(value):
61
62
 
62
63
 
63
64
  def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"):
64
- assert isinstance(frequency, int), frequency
65
+ assert isinstance(frequency, datetime.timedelta), frequency
65
66
  start = np.datetime64(start)
66
67
  end = np.datetime64(end)
67
- delta = np.timedelta64(frequency, "h")
68
+ delta = np.timedelta64(frequency)
68
69
 
69
70
  res = []
70
71
  while start <= end:
@@ -27,6 +27,7 @@ class MissingDateError(Exception):
27
27
 
28
28
def open_dataset(*args, **kwargs):
    """Open a dataset, let it simplify itself, and record how it was opened."""
    dataset = _open_dataset(*args, **kwargs).mutate()
    dataset.arguments = {"args": args, "kwargs": kwargs}
    dataset._check()
    return dataset
@@ -9,6 +9,7 @@ import logging
9
9
  from functools import cached_property
10
10
 
11
11
  import numpy as np
12
+ from anemoi.utils.dates import frequency_to_timedelta
12
13
 
13
14
  from .debug import Node
14
15
  from .debug import debug_indexing
@@ -102,20 +103,63 @@ class Concat(ConcatMixin, Combined):
102
103
  def tree(self):
103
104
  return Node(self, [d.tree() for d in self.datasets])
104
105
 
106
+ @classmethod
107
+ def check_dataset_compatibility(cls, datasets, fill_missing_gaps=False):
108
+ # Study the dates
109
+ ranges = [(d.dates[0].astype(object), d.dates[-1].astype(object)) for d in datasets]
105
110
 
106
- def concat_factory(args, kwargs, zarr_root):
111
+ # Make sure the dates are disjoint
112
+ for i in range(len(ranges)):
113
+ r = ranges[i]
114
+ for j in range(i + 1, len(ranges)):
115
+ s = ranges[j]
116
+ if r[0] <= s[0] <= r[1] or r[0] <= s[1] <= r[1]:
117
+ raise ValueError(f"Overlapping dates: {r} and {s} ({datasets[i]} {datasets[j]})")
118
+
119
+ # For now we should have the datasets in order with no gaps
120
+
121
+ frequency = frequency_to_timedelta(datasets[0].frequency)
122
+ result = []
123
+
124
+ for i in range(len(ranges) - 1):
125
+ result.append(datasets[i])
126
+ r = ranges[i]
127
+ s = ranges[i + 1]
128
+ if r[1] + frequency != s[0]:
129
+ if fill_missing_gaps:
130
+ from .missing import MissingDataset
131
+
132
+ result.append(MissingDataset(datasets[i], r[1] + frequency, s[0] - frequency))
133
+ else:
134
+ r = [str(e) for e in r]
135
+ s = [str(e) for e in s]
136
+ raise ValueError(
137
+ "Datasets must be sorted by dates, with no gaps: "
138
+ f"{r} and {s} ({datasets[i]} {datasets[i+1]})"
139
+ )
140
+
141
+ result.append(datasets[-1])
142
+ assert len(result) >= len(datasets), (len(result), len(datasets))
143
+
144
+ return result
145
+
146
+
147
def concat_factory(args, kwargs):
    """Build a dataset from the ``concat=[...]`` arguments of ``open_dataset``.

    Opens each entry, auto-adjusts them to be compatible, verifies their
    date ranges are disjoint and consecutive (optionally filling gaps when
    ``fill_missing_gaps`` is given), and wraps them in a ``Concat``.
    """
    datasets = kwargs.pop("concat")
    fill_missing_gaps = kwargs.pop("fill_missing_gaps", False)

    # Note: the original asserted `isinstance(datasets, (list, tuple))` twice.
    assert isinstance(datasets, (list, tuple))
    assert len(args) == 0

    datasets = [_open(e) for e in datasets]

    if len(datasets) == 1:
        # Nothing to concatenate
        return datasets[0]._subset(**kwargs)

    datasets, kwargs = _auto_adjust(datasets, kwargs)

    datasets = Concat.check_dataset_compatibility(datasets, fill_missing_gaps)

    return Concat(datasets)._subset(**kwargs)
@@ -5,24 +5,37 @@
5
5
  # granted to it by virtue of its status as an intergovernmental organisation
6
6
  # nor does it submit to any jurisdiction.
7
7
 
8
+ import datetime
9
+ import json
8
10
  import logging
9
11
  import os
12
+ import pprint
10
13
  import warnings
11
14
  from functools import cached_property
12
15
 
16
+ from anemoi.utils.dates import frequency_to_seconds
17
+ from anemoi.utils.dates import frequency_to_string
18
+ from anemoi.utils.dates import frequency_to_timedelta
19
+
13
20
  LOG = logging.getLogger(__name__)
14
21
 
15
22
 
16
23
  class Dataset:
17
24
  arguments = {}
18
25
 
26
+ def mutate(self):
27
+ return self
28
+
29
+ def swap_with_parent(self, parent):
30
+ return parent
31
+
19
32
  @cached_property
20
33
  def _len(self):
21
34
  return len(self)
22
35
 
23
36
  def _subset(self, **kwargs):
24
37
  if not kwargs:
25
- return self
38
+ return self.mutate()
26
39
 
27
40
  if "start" in kwargs or "end" in kwargs:
28
41
  start = kwargs.pop("start", None)
@@ -30,37 +43,52 @@ class Dataset:
30
43
 
31
44
  from .subset import Subset
32
45
 
33
- return Subset(self, self._dates_to_indices(start, end), dict(start=start, end=end))._subset(**kwargs)
46
+ return (
47
+ Subset(self, self._dates_to_indices(start, end), dict(start=start, end=end))._subset(**kwargs).mutate()
48
+ )
34
49
 
35
50
  if "frequency" in kwargs:
36
51
  from .subset import Subset
37
52
 
53
+ if "interpolate_frequency" in kwargs:
54
+ raise ValueError("Cannot use both `frequency` and `interpolate_frequency`")
55
+
38
56
  frequency = kwargs.pop("frequency")
39
- return Subset(self, self._frequency_to_indices(frequency), dict(frequency=frequency))._subset(**kwargs)
57
+ return (
58
+ Subset(self, self._frequency_to_indices(frequency), dict(frequency=frequency))
59
+ ._subset(**kwargs)
60
+ .mutate()
61
+ )
62
+
63
+ if "interpolate_frequency" in kwargs:
64
+ from .interpolate import InterpolateFrequency
65
+
66
+ interpolate_frequency = kwargs.pop("interpolate_frequency")
67
+ return InterpolateFrequency(self, interpolate_frequency)._subset(**kwargs).mutate()
40
68
 
41
69
  if "select" in kwargs:
42
70
  from .select import Select
43
71
 
44
72
  select = kwargs.pop("select")
45
- return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs)
73
+ return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs).mutate()
46
74
 
47
75
  if "drop" in kwargs:
48
76
  from .select import Select
49
77
 
50
78
  drop = kwargs.pop("drop")
51
- return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs)
79
+ return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs).mutate()
52
80
 
53
81
  if "reorder" in kwargs:
54
82
  from .select import Select
55
83
 
56
84
  reorder = kwargs.pop("reorder")
57
- return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs)
85
+ return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate()
58
86
 
59
87
  if "rename" in kwargs:
60
88
  from .select import Rename
61
89
 
62
90
  rename = kwargs.pop("rename")
63
- return Rename(self, rename)._subset(**kwargs)
91
+ return Rename(self, rename)._subset(**kwargs).mutate()
64
92
 
65
93
  if "statistics" in kwargs:
66
94
  from ..data import open_dataset
@@ -68,20 +96,38 @@ class Dataset:
68
96
 
69
97
  statistics = kwargs.pop("statistics")
70
98
 
71
- return Statistics(self, open_dataset(statistics))._subset(**kwargs)
99
+ return Statistics(self, open_dataset(statistics))._subset(**kwargs).mutate()
72
100
 
73
101
  if "thinning" in kwargs:
74
102
  from .masked import Thinning
75
103
 
76
104
  thinning = kwargs.pop("thinning")
77
105
  method = kwargs.pop("method", "every-nth")
78
- return Thinning(self, thinning, method)._subset(**kwargs)
106
+ return Thinning(self, thinning, method)._subset(**kwargs).mutate()
79
107
 
80
108
  if "area" in kwargs:
81
109
  from .masked import Cropping
82
110
 
83
111
  bbox = kwargs.pop("area")
84
- return Cropping(self, bbox)._subset(**kwargs)
112
+ return Cropping(self, bbox)._subset(**kwargs).mutate()
113
+
114
+ if "missing_dates" in kwargs:
115
+ from .missing import MissingDates
116
+
117
+ missing_dates = kwargs.pop("missing_dates")
118
+ return MissingDates(self, missing_dates)._subset(**kwargs).mutate()
119
+
120
+ if "skip_missing_dates" in kwargs:
121
+ from .missing import SkipMissingDates
122
+
123
+ if "expected_access" not in kwargs:
124
+ raise ValueError("`expected_access` is required with `skip_missing_dates`")
125
+
126
+ skip_missing_dates = kwargs.pop("skip_missing_dates")
127
+ expected_access = kwargs.pop("expected_access")
128
+
129
+ if skip_missing_dates:
130
+ return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate()
85
131
 
86
132
  # Keep last
87
133
  if "shuffle" in kwargs:
@@ -90,15 +136,14 @@ class Dataset:
90
136
  shuffle = kwargs.pop("shuffle")
91
137
 
92
138
  if shuffle:
93
- return Subset(self, self._shuffle_indices(), dict(shuffle=True))._subset(**kwargs)
139
+ return Subset(self, self._shuffle_indices(), dict(shuffle=True))._subset(**kwargs).mutate()
94
140
 
95
141
  raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs))
96
142
 
97
143
  def _frequency_to_indices(self, frequency):
98
- from .misc import _frequency_to_hours
99
144
 
100
- requested_frequency = _frequency_to_hours(frequency)
101
- dataset_frequency = _frequency_to_hours(self.frequency)
145
+ requested_frequency = frequency_to_seconds(frequency)
146
+ dataset_frequency = frequency_to_seconds(self.frequency)
102
147
  assert requested_frequency % dataset_frequency == 0
103
148
  # Question: where do we start? first date, or first date that is a multiple of the frequency?
104
149
  step = requested_frequency // dataset_frequency
@@ -171,37 +216,71 @@ class Dataset:
171
216
  import anemoi
172
217
 
173
218
  def tidy(v):
174
- if isinstance(v, (list, tuple)):
219
+ if isinstance(v, (list, tuple, set)):
175
220
  return [tidy(i) for i in v]
176
221
  if isinstance(v, dict):
177
222
  return {k: tidy(v) for k, v in v.items()}
178
223
  if isinstance(v, str) and v.startswith("/"):
179
224
  return os.path.basename(v)
225
+ if isinstance(v, datetime.datetime):
226
+ return v.isoformat()
227
+ if isinstance(v, datetime.date):
228
+ return v.isoformat()
229
+ if isinstance(v, datetime.timedelta):
230
+ return frequency_to_string(v)
231
+
232
+ if isinstance(v, Dataset):
233
+ # That can happen in the `arguments`
234
+ # if a dataset is passed as an argument
235
+ return repr(v)
236
+
237
+ if isinstance(v, slice):
238
+ return (v.start, v.stop, v.step)
239
+
180
240
  return v
181
241
 
182
- return tidy(
183
- dict(
184
- version=anemoi.datasets.__version__,
185
- shape=self.shape,
186
- arguments=self.arguments,
187
- specific=self.metadata_specific(),
188
- frequency=self.frequency,
189
- variables=self.variables,
190
- start_date=self.dates[0].astype(str),
191
- end_date=self.dates[-1].astype(str),
192
- )
242
+ md = dict(
243
+ version=anemoi.datasets.__version__,
244
+ arguments=self.arguments,
245
+ **self.dataset_metadata(),
246
+ )
247
+
248
+ try:
249
+ return json.loads(json.dumps(tidy(md)))
250
+ except Exception:
251
+ LOG.exception("Failed to serialize metadata")
252
+ pprint.pprint(md)
253
+
254
+ raise
255
+
256
+ @property
257
+ def start_date(self):
258
+ return self.dates[0]
259
+
260
+ @property
261
+ def end_date(self):
262
+ return self.dates[-1]
263
+
    def dataset_metadata(self):
        """Return the metadata describing the dataset itself (shape, dates,
        variables, frequency), as opposed to the arguments used to open it."""
        return dict(
            specific=self.metadata_specific(),
            frequency=self.frequency,
            variables=self.variables,
            shape=self.shape,
            start_date=self.start_date.astype(str),
            end_date=self.end_date.astype(str),
        )
194
273
 
195
274
  def metadata_specific(self, **kwargs):
196
275
  action = self.__class__.__name__.lower()
197
- assert isinstance(self.frequency, int), (self.frequency, self, action)
276
+ # assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action)
198
277
  return dict(
199
278
  action=action,
200
279
  variables=self.variables,
201
280
  shape=self.shape,
202
- frequency=self.frequency,
203
- start_date=self.dates[0].astype(str),
204
- end_date=self.dates[-1].astype(str),
281
+ frequency=frequency_to_string(frequency_to_timedelta(self.frequency)),
282
+ start_date=self.start_date.astype(str),
283
+ end_date=self.end_date.astype(str),
205
284
  **kwargs,
206
285
  )
207
286
 
@@ -220,10 +299,6 @@ class Dataset:
220
299
  if n.startswith("_") and not n.startswith("__"):
221
300
  warnings.warn(f"Private method {n} is overriden in {ds.__class__.__name__}")
222
301
 
223
- # for n in ('metadata_specific', 'tree', 'source'):
224
- # if n not in overriden:
225
- # warnings.warn(f"Method {n} is not overriden in {ds.__class__.__name__}")
226
-
227
302
  def _repr_html_(self):
228
303
  return self.tree().html()
229
304
 
@@ -23,7 +23,7 @@ LOG = logging.getLogger(__name__)
23
23
 
24
24
  class Forwards(Dataset):
25
25
  def __init__(self, forward):
26
- self.forward = forward
26
+ self.forward = forward.mutate()
27
27
 
28
28
  def __len__(self):
29
29
  return len(self.forward)
@@ -118,6 +118,9 @@ class Combined(Forwards):
118
118
  # Forward most properties to the first dataset
119
119
  super().__init__(datasets[0])
120
120
 
121
+ def mutate(self):
122
+ return self
123
+
121
124
  def check_same_resolution(self, d1, d2):
122
125
  if d1.resolution != d2.resolution:
123
126
  raise ValueError(f"Incompatible resolutions: {d1.resolution} and {d2.resolution} ({d1} {d2})")
@@ -187,14 +190,9 @@ class Combined(Forwards):
187
190
  **kwargs,
188
191
  )
189
192
 
190
- @cached_property
193
+ @property
191
194
  def missing(self):
192
- offset = 0
193
- result = set()
194
- for d in self.datasets:
195
- result.update(offset + m for m in d.missing)
196
- offset += len(d)
197
- return result
195
+ raise NotImplementedError("missing() not implemented for Combined")
198
196
 
199
197
  def get_dataset_names(self, names):
200
198
  for d in self.datasets:
@@ -249,3 +247,14 @@ class GivenAxis(Combined):
249
247
  return self._get_slice(n)
250
248
 
251
249
  return np.concatenate([d[n] for d in self.datasets], axis=self.axis - 1)
250
+
251
+ @cached_property
252
+ def missing(self):
253
+ offset = 0
254
+ result = set()
255
+ for d in self.datasets:
256
+ print("--->", d.missing, d)
257
+ result.update(offset + m for m in d.missing)
258
+ if self.axis == 0: # Advance if axis is time
259
+ offset += len(d)
260
+ return result
@@ -128,7 +128,7 @@ class Grids(GridsBase):
128
128
 
129
129
 
130
130
  class Cutout(GridsBase):
131
- def __init__(self, datasets, axis):
131
+ def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, neighbours=5, plot=False):
132
132
  from anemoi.datasets.grids import cutout_mask
133
133
 
134
134
  super().__init__(datasets, axis)
@@ -144,7 +144,10 @@ class Cutout(GridsBase):
144
144
  self.lam.longitudes,
145
145
  self.globe.latitudes,
146
146
  self.globe.longitudes,
147
- # plot="cutout",
147
+ plot=plot,
148
+ min_distance_km=min_distance_km,
149
+ cropping_distance=cropping_distance,
150
+ neighbours=neighbours,
148
151
  )
149
152
  assert len(self.mask) == self.globe.shape[3], (
150
153
  len(self.mask),
@@ -229,6 +232,10 @@ def cutout_factory(args, kwargs):
229
232
 
230
233
  cutout = kwargs.pop("cutout")
231
234
  axis = kwargs.pop("axis", 3)
235
+ plot = kwargs.pop("plot", None)
236
+ min_distance_km = kwargs.pop("min_distance_km", None)
237
+ cropping_distance = kwargs.pop("cropping_distance", 2.0)
238
+ neighbours = kwargs.pop("neighbours", 5)
232
239
 
233
240
  assert len(args) == 0
234
241
  assert isinstance(cutout, (list, tuple))
@@ -236,4 +243,11 @@ def cutout_factory(args, kwargs):
236
243
  datasets = [_open(e) for e in cutout]
237
244
  datasets, kwargs = _auto_adjust(datasets, kwargs)
238
245
 
239
- return Cutout(datasets, axis=axis)._subset(**kwargs)
246
+ return Cutout(
247
+ datasets,
248
+ axis=axis,
249
+ neighbours=neighbours,
250
+ min_distance_km=min_distance_km,
251
+ cropping_distance=cropping_distance,
252
+ plot=plot,
253
+ )._subset(**kwargs)
@@ -0,0 +1,133 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+ import logging
9
+ from functools import cached_property
10
+
11
+ import numpy as np
12
+ from anemoi.utils.dates import frequency_to_timedelta
13
+
14
+ from .debug import Node
15
+ from .debug import debug_indexing
16
+ from .forwards import Forwards
17
+ from .indexing import apply_index_to_slices_changes
18
+ from .indexing import expand_list_indexing
19
+ from .indexing import index_to_slices
20
+ from .indexing import update_tuple
21
+
22
+ LOG = logging.getLogger(__name__)
23
+
24
+
25
+ class InterpolateFrequency(Forwards):
26
+
27
+ def __init__(self, dataset, frequency):
28
+ super().__init__(dataset)
29
+ self._frequency = frequency_to_timedelta(frequency)
30
+
31
+ self.seconds = self._frequency.total_seconds()
32
+ other_seconds = dataset.frequency.total_seconds()
33
+
34
+ self.seconds = int(self.seconds)
35
+ assert self.seconds == self._frequency.total_seconds()
36
+
37
+ other_seconds = int(other_seconds)
38
+ assert other_seconds == dataset.frequency.total_seconds()
39
+
40
+ if self.seconds >= other_seconds:
41
+ raise ValueError(
42
+ f"Interpolate frequency {self._frequency} must be more frequent than dataset frequency {dataset.frequency}"
43
+ )
44
+
45
+ if other_seconds % self.seconds != 0:
46
+ raise ValueError(
47
+ f"Interpolate frequency {self._frequency} must be a multiple of the dataset frequency {dataset.frequency}"
48
+ )
49
+
50
+ self.ratio = other_seconds // self.seconds
51
+ self.alphas = np.linspace(0, 1, self.ratio + 1)
52
+ self.other_len = len(dataset)
53
+
54
+ @debug_indexing
55
+ @expand_list_indexing
56
+ def _get_tuple(self, index):
57
+ index, changes = index_to_slices(index, self.shape)
58
+ index, previous = update_tuple(index, 0, slice(None))
59
+ result = self._get_slice(previous)
60
+ return apply_index_to_slices_changes(result[index], changes)
61
+
62
+ def _get_slice(self, s):
63
+ return np.stack([self[i] for i in range(*s.indices(self._len))])
64
+
65
+ @debug_indexing
66
+ def __getitem__(self, n):
67
+ if isinstance(n, tuple):
68
+ return self._get_tuple(n)
69
+
70
+ if isinstance(n, slice):
71
+ return self._get_slice(n)
72
+
73
+ if n < 0:
74
+ n += self._len
75
+
76
+ if n == self._len - 1:
77
+ # Special case for the last element
78
+ return self.forward[-1]
79
+
80
+ i = n // self.ratio
81
+ x = n % self.ratio
82
+
83
+ if x == 0:
84
+ # No interpolation needed
85
+ return self.forward[i]
86
+
87
+ alpha = self.alphas[x]
88
+
89
+ assert 0 < alpha < 1, alpha
90
+ return self.forward[i] * (1 - alpha) + self.forward[i + 1] * alpha
91
+
92
+ def __len__(self):
93
+ return (self.other_len - 1) * self.ratio + 1
94
+
95
+ @property
96
+ def frequency(self):
97
+ return self._frequency
98
+
99
+ @cached_property
100
+ def dates(self):
101
+ result = []
102
+ deltas = [np.timedelta64(self.seconds * i, "s") for i in range(self.ratio)]
103
+ for d in self.forward.dates[:-1]:
104
+ for i in deltas:
105
+ result.append(d + i)
106
+ result.append(self.forward.dates[-1])
107
+ return np.array(result)
108
+
109
+ @property
110
+ def shape(self):
111
+ return (self._len,) + self.forward.shape[1:]
112
+
113
+ def tree(self):
114
+ return Node(self, [self.forward.tree()], frequency=self.frequency)
115
+
116
+ @cached_property
117
+ def missing(self):
118
+ result = []
119
+ j = 0
120
+ for i in range(self.other_len):
121
+ missing = i in self.forward.missing
122
+ for _ in range(self.ratio):
123
+ if missing:
124
+ result.append(j)
125
+ j += 1
126
+
127
+ result = set(x for x in result if x < self._len)
128
+ return result
129
+
130
+ def subclass_metadata_specific(self):
131
+ return {
132
+ # "frequency": frequency_to_string(self._frequency),
133
+ }