anemoi-datasets 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (64)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/cleanup.py +44 -0
  3. anemoi/datasets/commands/create.py +52 -21
  4. anemoi/datasets/commands/finalise-additions.py +45 -0
  5. anemoi/datasets/commands/finalise.py +39 -0
  6. anemoi/datasets/commands/init-additions.py +45 -0
  7. anemoi/datasets/commands/init.py +67 -0
  8. anemoi/datasets/commands/inspect.py +1 -1
  9. anemoi/datasets/commands/load-additions.py +47 -0
  10. anemoi/datasets/commands/load.py +47 -0
  11. anemoi/datasets/commands/patch.py +39 -0
  12. anemoi/datasets/create/__init__.py +959 -146
  13. anemoi/datasets/create/check.py +5 -3
  14. anemoi/datasets/create/config.py +54 -2
  15. anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +57 -0
  16. anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +57 -0
  17. anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +54 -0
  18. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +59 -0
  19. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +115 -0
  20. anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +390 -0
  21. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +77 -0
  22. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +55 -0
  23. anemoi/datasets/create/functions/sources/grib.py +86 -1
  24. anemoi/datasets/create/functions/sources/hindcasts.py +14 -73
  25. anemoi/datasets/create/functions/sources/mars.py +9 -3
  26. anemoi/datasets/create/functions/sources/xarray/__init__.py +12 -2
  27. anemoi/datasets/create/functions/sources/xarray/coordinates.py +7 -0
  28. anemoi/datasets/create/functions/sources/xarray/field.py +8 -2
  29. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +0 -2
  30. anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -1
  31. anemoi/datasets/create/functions/sources/xarray/metadata.py +40 -40
  32. anemoi/datasets/create/functions/sources/xarray/time.py +63 -30
  33. anemoi/datasets/create/functions/sources/xarray/variable.py +15 -38
  34. anemoi/datasets/create/input.py +62 -39
  35. anemoi/datasets/create/persistent.py +1 -1
  36. anemoi/datasets/create/statistics/__init__.py +39 -23
  37. anemoi/datasets/create/utils.py +6 -2
  38. anemoi/datasets/data/__init__.py +1 -0
  39. anemoi/datasets/data/concat.py +46 -2
  40. anemoi/datasets/data/dataset.py +119 -34
  41. anemoi/datasets/data/debug.py +5 -1
  42. anemoi/datasets/data/forwards.py +17 -8
  43. anemoi/datasets/data/grids.py +17 -3
  44. anemoi/datasets/data/interpolate.py +133 -0
  45. anemoi/datasets/data/masked.py +2 -2
  46. anemoi/datasets/data/misc.py +56 -66
  47. anemoi/datasets/data/missing.py +240 -0
  48. anemoi/datasets/data/rescale.py +147 -0
  49. anemoi/datasets/data/select.py +7 -1
  50. anemoi/datasets/data/stores.py +23 -10
  51. anemoi/datasets/data/subset.py +47 -5
  52. anemoi/datasets/data/unchecked.py +20 -22
  53. anemoi/datasets/data/xy.py +125 -0
  54. anemoi/datasets/dates/__init__.py +124 -95
  55. anemoi/datasets/dates/groups.py +85 -20
  56. anemoi/datasets/grids.py +66 -48
  57. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/METADATA +8 -17
  58. anemoi_datasets-0.5.0.dist-info/RECORD +105 -0
  59. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/WHEEL +1 -1
  60. anemoi/datasets/create/loaders.py +0 -936
  61. anemoi_datasets-0.4.4.dist-info/RECORD +0 -86
  62. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/LICENSE +0 -0
  63. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/entry_points.txt +0 -0
  64. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/top_level.txt +0 -0
anemoi/datasets/data/misc.py

@@ -8,7 +8,6 @@
 import calendar
 import datetime
 import logging
-import re
 from pathlib import PurePath
 
 import numpy as np
@@ -39,26 +38,21 @@ def add_dataset_path(path):
     config["datasets"]["path"].append(path)
 
 
-def _frequency_to_hours(frequency):
-    if isinstance(frequency, int):
-        return frequency
-
-    if isinstance(frequency, float):
-        assert int(frequency) == frequency
-        return int(frequency)
-
-    m = re.match(r"(\d+)([dh])?", frequency)
-    if m is None:
-        raise ValueError("Invalid frequency: " + frequency)
-
-    frequency = int(m.group(1))
-    if m.group(2) == "h":
-        return frequency
-
-    if m.group(2) == "d":
-        return frequency * 24
+def round_datetime(d, dates, up):
+    """Round up (or down) a datetime to the nearest date in a list of dates"""
+    if dates is None or len(dates) == 0:
+        return d
 
-    raise NotImplementedError()
+    for i, date in enumerate(dates):
+        if date == d:
+            return date
+        if date > d:
+            if up:
+                return date
+            if i > 0:
+                return dates[i - 1]
+            return date
+    return dates[-1]
 
 
 def _as_date(d, dates, last):
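Note: a quick sketch of how the new round_datetime helper snaps a datetime onto a
list of available dates (values illustrative):

    import numpy as np

    dates = [np.datetime64("2020-01-01T00:00"), np.datetime64("2020-01-01T06:00")]

    round_datetime(np.datetime64("2020-01-01T03:00"), dates, up=True)   # -> ...T06:00
    round_datetime(np.datetime64("2020-01-01T03:00"), dates, up=False)  # -> ...T00:00
    round_datetime(np.datetime64("2020-01-01T06:00"), dates, up=True)   # -> ...T06:00 (exact match)
    round_datetime(np.datetime64("2021-01-01T00:00"), dates, up=False)  # -> ...T06:00 (past the end: dates[-1])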
@@ -67,7 +61,8 @@ def _as_date(d, dates, last):
     # so we need to check for datetime.datetime first
 
     if isinstance(d, (np.datetime64, datetime.datetime)):
-        return d
+        d = round_datetime(d, dates, up=not last)
+        return np.datetime64(d)
 
     if isinstance(d, datetime.date):
         d = d.year * 10_000 + d.month * 100 + d.day
@@ -81,27 +76,27 @@ def _as_date(d, dates, last):
     if len(str(d)) == 4:
         year = d
         if last:
-            return np.datetime64(f"{year:04}-12-31T23:59:59")
+            return _as_date(np.datetime64(f"{year:04}-12-31T23:59:59"), dates, last)
         else:
-            return np.datetime64(f"{year:04}-01-01T00:00:00")
+            return _as_date(np.datetime64(f"{year:04}-01-01T00:00:00"), dates, last)
 
     if len(str(d)) == 6:
         year = d // 100
         month = d % 100
         if last:
             _, last_day = calendar.monthrange(year, month)
-            return np.datetime64(f"{year:04}-{month:02}-{last_day:02}T23:59:59")
+            return _as_date(np.datetime64(f"{year:04}-{month:02}-{last_day:02}T23:59:59"), dates, last)
         else:
-            return np.datetime64(f"{year:04}-{month:02}-01T00:00:00")
+            return _as_date(np.datetime64(f"{year:04}-{month:02}-01T00:00:00"), dates, last)
 
     if len(str(d)) == 8:
         year = d // 10000
         month = (d % 10000) // 100
         day = d % 100
         if last:
-            return np.datetime64(f"{year:04}-{month:02}-{day:02}T23:59:59")
+            return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T23:59:59"), dates, last)
         else:
-            return np.datetime64(f"{year:04}-{month:02}-{day:02}T00:00:00")
+            return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T00:00:00"), dates, last)
 
     if isinstance(d, str):
 
@@ -109,7 +104,11 @@ def _as_date(d, dates, last):
             date, time = d.replace(" ", "T").split("T")
             year, month, day = [int(_) for _ in date.split("-")]
             hour, minute, second = [int(_) for _ in time.split(":")]
-            return np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}")
+            return _as_date(
+                np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}"),
+                dates,
+                last,
+            )
 
         if "-" in d:
             assert ":" not in d
@@ -121,11 +120,8 @@ def _as_date(d, dates, last):
                 return _as_date(int(bits[0]) * 100 + int(bits[1]), dates, last)
 
             if len(bits) == 3:
-                return _as_date(
-                    int(bits[0]) * 10000 + int(bits[1]) * 100 + int(bits[2]),
-                    dates,
-                    last,
-                )
+                return _as_date(int(bits[0]) * 10000 + int(bits[1]) * 100 + int(bits[2]), dates, last)
+
         if ":" in d:
             assert len(d) == 5
             hour, minute = d.split(":")
@@ -136,7 +132,7 @@ def _as_date(d, dates, last):
             month = first.month
             day = first.day
 
-            return np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour}:00:00")
+            return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour}:00:00"), dates, last)
 
     raise NotImplementedError(f"Unsupported date: {d} ({type(d)})")
 
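Note: every literal branch of _as_date now re-enters the function, so all paths end
in the np.datetime64 branch above, where round_datetime clamps the bound onto the
dataset's actual dates (up=not last: a start bound rounds up to the first available
date, an end bound rounds down to the last one). The spellings accepted, per the code:

    _as_date(2020, dates, last=False)                  # -> 2020-01-01T00:00:00
    _as_date(2020, dates, last=True)                   # -> 2020-12-31T23:59:59
    _as_date(202006, dates, last=True)                 # -> 2020-06-30T23:59:59
    _as_date(20200615, dates, last=False)              # -> 2020-06-15T00:00:00
    _as_date("2020-06-15 12:00:00", dates, last=False)
    _as_date("12:00", dates, last=False)               # that hour on the first date (inferred)

each result then being rounded onto `dates` as above.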
@@ -163,28 +159,10 @@ def _concat_or_join(datasets, kwargs):
 
         return Join(datasets)._overlay(), kwargs
 
-    # Make sure the dates are disjoint
-    for i in range(len(ranges)):
-        r = ranges[i]
-        for j in range(i + 1, len(ranges)):
-            s = ranges[j]
-            if r[0] <= s[0] <= r[1] or r[0] <= s[1] <= r[1]:
-                raise ValueError(f"Overlapping dates: {r} and {s} ({datasets[i]} {datasets[j]})")
-
-    # For now we should have the datasets in order with no gaps
-
-    frequency = _frequency_to_hours(datasets[0].frequency)
-
-    for i in range(len(ranges) - 1):
-        r = ranges[i]
-        s = ranges[i + 1]
-        if r[1] + datetime.timedelta(hours=frequency) != s[0]:
-            raise ValueError(
-                "Datasets must be sorted by dates, with no gaps: " f"{r} and {s} ({datasets[i]} {datasets[i+1]})"
-            )
-
     from .concat import Concat
 
+    Concat.check_dataset_compatibility(datasets)
+
     return Concat(datasets), kwargs
 
 
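Note: the overlap/ordering/gap validation that used to be inlined here (together with
the _frequency_to_hours helper it relied on) now lives behind
Concat.check_dataset_compatibility, so concat.py owns its own invariants. Judging by
the removed code, the constraint is presumably unchanged: concatenated datasets must
cover disjoint, consecutive date ranges at the same frequency (paths hypothetical):

    from anemoi.datasets import open_dataset

    # Valid only if the two datasets are back-to-back, with no gap or overlap.
    ds = open_dataset(concat=["era5-2019.zarr", "era5-2020.zarr"])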
@@ -193,7 +171,7 @@ def _open(a):
     from .stores import zarr_lookup
 
     if isinstance(a, Dataset):
-        return a
+        return a.mutate()
 
     if isinstance(a, zarr.hierarchy.Group):
         return Zarr(a).mutate()
@@ -202,13 +180,13 @@ def _open(a):
         return Zarr(zarr_lookup(a)).mutate()
 
     if isinstance(a, PurePath):
-        return _open(str(a))
+        return _open(str(a)).mutate()
 
     if isinstance(a, dict):
-        return _open_dataset(**a)
+        return _open_dataset(**a).mutate()
 
     if isinstance(a, (list, tuple)):
-        return _open_dataset(*a)
+        return _open_dataset(*a).mutate()
 
     raise NotImplementedError(f"Unsupported argument: {type(a)}")
 
@@ -288,47 +266,59 @@ def _open_dataset(*args, **kwargs):
     for a in args:
         sets.append(_open(a))
 
+    if "xy" in kwargs:
+        from .xy import xy_factory
+
+        assert not sets, sets
+        return xy_factory(args, kwargs).mutate()
+
+    if "x" in kwargs and "y" in kwargs:
+        from .xy import xy_factory
+
+        assert not sets, sets
+        return xy_factory(args, kwargs).mutate()
+
     if "zip" in kwargs:
-        from .unchecked import zip_factory
+        from .xy import zip_factory
 
         assert not sets, sets
-        return zip_factory(args, kwargs)
+        return zip_factory(args, kwargs).mutate()
 
     if "chain" in kwargs:
         from .unchecked import chain_factory
 
         assert not sets, sets
-        return chain_factory(args, kwargs)
+        return chain_factory(args, kwargs).mutate()
 
     if "join" in kwargs:
         from .join import join_factory
 
         assert not sets, sets
-        return join_factory(args, kwargs)
+        return join_factory(args, kwargs).mutate()
 
     if "concat" in kwargs:
         from .concat import concat_factory
 
         assert not sets, sets
-        return concat_factory(args, kwargs)
+        return concat_factory(args, kwargs).mutate()
 
     if "ensemble" in kwargs:
         from .ensemble import ensemble_factory
 
         assert not sets, sets
-        return ensemble_factory(args, kwargs)
+        return ensemble_factory(args, kwargs).mutate()
 
     if "grids" in kwargs:
         from .grids import grids_factory
 
         assert not sets, sets
-        return grids_factory(args, kwargs)
+        return grids_factory(args, kwargs).mutate()
 
     if "cutout" in kwargs:
         from .grids import cutout_factory
 
         assert not sets, sets
-        return cutout_factory(args, kwargs)
+        return cutout_factory(args, kwargs).mutate()
 
     for name in ("datasets", "dataset"):
         if name in kwargs:
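Note: two changes run through this dispatcher. First, zip_factory moved from
.unchecked to the new .xy module, which also provides the xy and x/y entry points.
Second, every factory result (and every _open result above) is now passed through
.mutate(), the hook that lets a wrapper swap itself for a better representation
(see the Select.mutate addition at the end of this diff). A hedged sketch of the
call forms, with hypothetical paths and the x/y semantics inferred from the new
xy module:

    from anemoi.datasets import open_dataset

    ds = open_dataset(join=["fields.zarr", "forcings.zarr"])
    ds = open_dataset(ensemble=["member-01.zarr", "member-02.zarr"])

    # New in 0.5.0: pair two datasets (e.g. input/target style training).
    ds = open_dataset(x="coarse.zarr", y="fine.zarr")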
anemoi/datasets/data/missing.py (new file)

@@ -0,0 +1,240 @@
+# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import logging
+from functools import cached_property
+
+import numpy as np
+
+from anemoi.datasets.create.utils import to_datetime
+from anemoi.datasets.data import MissingDateError
+
+from .debug import Node
+from .debug import debug_indexing
+from .forwards import Forwards
+from .indexing import expand_list_indexing
+from .indexing import update_tuple
+
+LOG = logging.getLogger(__name__)
+
+
+class MissingDates(Forwards):
+    # TODO: Use that class instead of ZarrMissing
+
+    def __init__(self, dataset, missing_dates):
+        super().__init__(dataset)
+        self.missing_dates = []
+
+        self._missing = set()
+
+        other = []
+        for date in missing_dates:
+            if isinstance(date, int):
+                self._missing.add(date)
+                self.missing_dates.append(dataset.dates[date])
+            else:
+                date = to_datetime(date)
+                other.append(date)
+
+        if other:
+            for i, date in enumerate(dataset.dates):
+                if date in other:
+                    self._missing.add(i)
+                    self.missing_dates.append(date)
+
+        n = self.forward._len
+        self._missing = set(i for i in self._missing if 0 <= i < n)
+        self.missing_dates = sorted(to_datetime(x) for x in self.missing_dates)
+
+        assert len(self._missing), "No dates to force missing"
+
+    @cached_property
+    def missing(self):
+        return self._missing.union(self.forward.missing)
+
+    @debug_indexing
+    @expand_list_indexing
+    def __getitem__(self, n):
+        if isinstance(n, int):
+            if n in self.missing:
+                self._report_missing(n)
+            return self.forward[n]
+
+        if isinstance(n, slice):
+            common = set(range(*n.indices(len(self)))) & self.missing
+            if common:
+                self._report_missing(list(common)[0])
+            return self.forward[n]
+
+        if isinstance(n, tuple):
+            first = n[0]
+            if isinstance(first, int):
+                if first in self.missing:
+                    self._report_missing(first)
+                return self.forward[n]
+
+            if isinstance(first, slice):
+                common = set(range(*first.indices(len(self)))) & self.missing
+                if common:
+                    self._report_missing(list(common)[0])
+                return self.forward[n]
+
+            if isinstance(first, (list, tuple)):
+                common = set(first) & self.missing
+                if common:
+                    self._report_missing(list(common)[0])
+                return self.forward[n]
+
+        raise TypeError(f"Unsupported index {n} {type(n)}")
+
+    def _report_missing(self, n):
+        raise MissingDateError(f"Date {self.forward.dates[n]} is missing (index={n})")
+
+    @property
+    def reason(self):
+        return {"missing_dates": self.missing_dates}
+
+    def tree(self):
+        return Node(self, [self.forward.tree()], **self.reason)
+
+    def subclass_metadata_specific(self):
+        return {"missing_dates": self.missing_dates}
+
+
+class SkipMissingDates(Forwards):
+
+    def __init__(self, dataset, expected_access):
+        super().__init__(dataset)
+
+        # if isinstance(expected_access, (tuple, list)):
+        #     expected_access = slice(*expected_access)
+
+        if isinstance(expected_access, int):
+            expected_access = slice(0, expected_access)
+
+        assert isinstance(expected_access, slice), f"Expected access must be a slice, got {expected_access}"
+
+        expected_access = slice(*expected_access.indices(dataset._len))
+        missing = dataset.missing.copy()
+
+        size = (expected_access.stop - expected_access.start) // expected_access.step
+        indices = []
+
+        for i in range(dataset._len):
+            s = slice(expected_access.start + i, expected_access.stop + i, expected_access.step)
+            p = set(range(*s.indices(dataset._len)))
+            if p.intersection(missing):
+                continue
+
+            if len(p) != size:
+                continue
+
+            indices.append(tuple(sorted(p)))
+
+        self.expected_access = expected_access
+        self.indices = indices
+
+    def __len__(self):
+        return len(self.indices)
+
+    @property
+    def start_date(self):
+        return self.forward.start_date
+
+    @property
+    def end_date(self):
+        return self.forward.end_date
+
+    @property
+    def dates(self):
+        raise NotImplementedError("SkipMissingDates.dates")
+
+    @debug_indexing
+    @expand_list_indexing
+    def _get_tuple(self, index):
+
+        def _get_one(n):
+            result = []
+            for i in self.indices[n]:
+                s, _ = update_tuple(index, 0, i)
+                result.append(self.forward[s])
+
+            return tuple(result)
+
+        first = index[0]
+        if isinstance(first, int):
+            return _get_one(first)
+
+        assert isinstance(first, slice), f"SkipMissingDates._get_tuple {index}"
+
+        values = [_get_one(i) for i in range(*first.indices(self._len))]
+
+        result = [_ for _ in zip(*values)]
+        return tuple(np.stack(_) for _ in result)
+
+    @debug_indexing
+    def _get_slice(self, s):
+        values = [self[i] for i in range(*s.indices(self._len))]
+        result = [_ for _ in zip(*values)]
+        return tuple(np.stack(_) for _ in result)
+
+    @debug_indexing
+    def __getitem__(self, n):
+        if isinstance(n, tuple):
+            return self._get_tuple(n)
+
+        if isinstance(n, slice):
+            return self._get_slice(n)
+
+        return tuple(self.forward[i] for i in self.indices[n])
+
+    @property
+    def frequency(self):
+        return self.forward.frequency
+
+    def tree(self):
+        return Node(self, [self.forward.tree()], expected_access=self.expected_access)
+
+    def subclass_metadata_specific(self):
+        return {"expected_access": self.expected_access}
+
+
+class MissingDataset(Forwards):
+
+    def __init__(self, dataset, start, end):
+        super().__init__(dataset)
+        self.start = start
+        self.end = end
+
+        dates = []
+        date = start
+        while date <= end:
+            dates.append(date)
+            date += dataset.frequency
+
+        self._dates = np.array(dates, dtype="datetime64")
+        self._missing = set(range(len(dates)))
+
+    def __len__(self):
+        return len(self._dates)
+
+    @property
+    def dates(self):
+        return self._dates
+
+    @property
+    def missing(self):
+        return self._missing
+
+    def __getitem__(self, n):
+        raise MissingDateError(f"Date {self.dates[n]} is missing (index={n})")
+
+    def tree(self):
+        return Node(self, [self.forward.tree()], start=self.start, end=self.end)
+
+    def subclass_metadata_specific(self):
+        return {"start": self.start, "end": self.end}
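Note: a hedged sketch of how these wrappers surface to users, assuming the keyword
wiring in the open_dataset dispatcher matches the class names (MissingDateError is
imported from anemoi.datasets.data, as in the file itself):

    from anemoi.datasets import open_dataset
    from anemoi.datasets.data import MissingDateError

    ds = open_dataset("dataset.zarr")  # hypothetical path

    try:
        x = ds[i]
    except MissingDateError:
        ...  # index i falls on a missing date

    # SkipMissingDates takes the opposite approach: item n becomes a tuple of
    # `expected_access` consecutive samples, and any window that touches a
    # missing date is silently dropped from the index.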
anemoi/datasets/data/rescale.py (new file)

@@ -0,0 +1,147 @@
+# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import logging
+from functools import cached_property
+
+import numpy as np
+
+from .debug import Node
+from .debug import debug_indexing
+from .forwards import Forwards
+from .indexing import apply_index_to_slices_changes
+from .indexing import expand_list_indexing
+from .indexing import index_to_slices
+from .indexing import update_tuple
+
+LOG = logging.getLogger(__name__)
+
+
+def make_rescale(variable, rescale):
+
+    if isinstance(rescale, (tuple, list)):
+
+        assert len(rescale) == 2, rescale
+
+        if isinstance(rescale[0], (int, float)):
+            return rescale
+
+        from cfunits import Units
+
+        u0 = Units(rescale[0])
+        u1 = Units(rescale[1])
+
+        x1, x2 = 0.0, 1.0
+        y1, y2 = Units.conform([x1, x2], u0, u1)
+
+        a = (y2 - y1) / (x2 - x1)
+        b = y1 - a * x1
+
+        return a, b
+
+        return rescale
+
+    if isinstance(rescale, dict):
+        assert "scale" in rescale, rescale
+        assert "offset" in rescale, rescale
+        return rescale["scale"], rescale["offset"]
+
+    assert False
+
+
+class Rescale(Forwards):
+    def __init__(self, dataset, rescale):
+        super().__init__(dataset)
+        for n in rescale:
+            assert n in dataset.variables, n
+
+        variables = dataset.variables
+
+        self._a = np.ones(len(variables))
+        self._b = np.zeros(len(variables))
+
+        self.rescale = {}
+        for i, v in enumerate(variables):
+            if v in rescale:
+                a, b = make_rescale(v, rescale[v])
+                self.rescale[v] = a, b
+                self._a[i], self._b[i] = a, b
+
+        self._a = self._a[np.newaxis, :, np.newaxis, np.newaxis]
+        self._b = self._b[np.newaxis, :, np.newaxis, np.newaxis]
+
+        self._a = self._a.astype(self.forward.dtype)
+        self._b = self._b.astype(self.forward.dtype)
+
+    def tree(self):
+        return Node(self, [self.forward.tree()], rescale=self.rescale)
+
+    def subclass_metadata_specific(self):
+        return dict(rescale=self.rescale)
+
+    @debug_indexing
+    @expand_list_indexing
+    def _get_tuple(self, index):
+        index, changes = index_to_slices(index, self.shape)
+        index, previous = update_tuple(index, 1, slice(None))
+        result = self.forward[index]
+        result = result * self._a + self._b
+        result = result[:, previous]
+        result = apply_index_to_slices_changes(result, changes)
+        return result
+
+    @debug_indexing
+    def __get_slice_(self, n):
+        data = self.forward[n]
+        return data * self._a + self._b
+
+    @debug_indexing
+    def __getitem__(self, n):
+
+        if isinstance(n, tuple):
+            return self._get_tuple(n)
+
+        if isinstance(n, slice):
+            return self.__get_slice_(n)
+
+        data = self.forward[n]
+
+        return data * self._a[0] + self._b[0]
+
+    @cached_property
+    def statistics(self):
+        result = {}
+        a = self._a.squeeze()
+        assert np.all(a >= 0)
+
+        b = self._b.squeeze()
+        for k, v in self.forward.statistics.items():
+            if k in ("maximum", "minimum", "mean"):
+                result[k] = v * a + b
+                continue
+
+            if k in ("stdev",):
+                result[k] = v * a
+                continue
+
+            raise NotImplementedError("rescale statistics", k)
+
+        return result
+
+    def statistics_tendencies(self, delta=None):
+        result = {}
+        a = self._a.squeeze()
+        assert np.all(a >= 0)
+
+        for k, v in self.forward.statistics_tendencies(delta).items():
+            if k in ("maximum", "minimum", "mean", "stdev"):
+                result[k] = v * a
+                continue
+
+            raise NotImplementedError("rescale tendencies statistics", k)
+
+        return result
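Note: make_rescale turns a (from_unit, to_unit) pair into an affine (scale, offset)
by conforming the points 0.0 and 1.0 between units with cfunits. For ("K", "degC"),
Units.conform([0, 1], K, degC) gives [-273.15, -272.15], hence a = 1.0 and
b = -273.15. A hedged usage sketch, assuming the rescale keyword is wired through
open_dataset like select/rename (path hypothetical):

    from anemoi.datasets import open_dataset

    # Three equivalent ways to get 2t in Celsius:
    ds = open_dataset("dataset.zarr", rescale={"2t": ("K", "degC")})
    ds = open_dataset("dataset.zarr", rescale={"2t": (1.0, -273.15)})
    ds = open_dataset("dataset.zarr", rescale={"2t": {"scale": 1.0, "offset": -273.15}})

Statistics are remapped consistently: minimum/maximum/mean become a*v + b and stdev
becomes a*v, which is why the code asserts that all scales are non-negative.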
anemoi/datasets/data/select.py

@@ -40,6 +40,12 @@ class Select(Forwards):
         # Forward other properties to the main dataset
         super().__init__(dataset)
 
+    def clone(self, dataset):
+        return self.__class__(dataset, self.indices, self.reason).mutate()
+
+    def mutate(self):
+        return self.forward.swap_with_parent(parent=self)
+
     @debug_indexing
     @expand_list_indexing
     def _get_tuple(self, index):
@@ -101,7 +107,7 @@ class Rename(Forwards):
     def __init__(self, dataset, rename):
         super().__init__(dataset)
         for n in rename:
-            assert n in dataset.variables
+            assert n in dataset.variables, n
         self._variables = [rename.get(v, v) for v in dataset.variables]
         self.rename = rename
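Note: the clone/mutate pair on Select is the counterpart of the .mutate() calls added
in _open/_open_dataset above: mutate() asks the wrapped dataset to swap_with_parent,
presumably so equivalent wrapper trees can be normalised to one shape. The Rename
change simply makes a failing assert name the offending variable, e.g. (path
hypothetical, rename keyword as wired through open_dataset):

    from anemoi.datasets import open_dataset

    # Every key of `rename` must be an existing variable; the assert now
    # reports which key is not.
    ds = open_dataset("dataset.zarr", rename={"2t": "t2m"})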