ocf-data-sampler 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

@@ -0,0 +1,358 @@
1
+ """Select spatial slices"""
2
+
3
+ import logging
4
+
5
+ import numpy as np
6
+ import xarray as xr
7
+
8
+ from ocf_datapipes.utils import Location
9
+ from ocf_datapipes.utils.geospatial import (
10
+ lon_lat_to_geostationary_area_coords,
11
+ lon_lat_to_osgb,
12
+ osgb_to_geostationary_area_coords,
13
+ osgb_to_lon_lat,
14
+ spatial_coord_type,
15
+ )
16
+ from ocf_datapipes.utils.utils import searchsorted
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ # -------------------------------- utility functions --------------------------------
22
+
23
+
24
def convert_coords_to_match_xarray(
    x: float | np.ndarray,
    y: float | np.ndarray,
    from_coords: str,
    da: xr.DataArray,
):
    """Express x and y in the spatial coordinate system used by an xarray object

    Args:
        x: Float or array-like x-coordinate(s)
        y: Float or array-like y-coordinate(s)
        from_coords: Coordinate system `x` and `y` are given in, "osgb" or "lon_lat"
        da: DataArray whose coordinate system (as reported by `spatial_coord_type`) is the
            target of the conversion

    Returns:
        Tuple `(x, y)` in the coordinate system of `da`. The inputs are handed back
        unchanged when they are already in that system.
    """

    # Only the first element of the result (the coordinate system name) is needed here
    target_coords, *_ = spatial_coord_type(da)

    assert from_coords in ["osgb", "lon_lat"]
    assert target_coords in ["geostationary", "osgb", "lon_lat"]

    # Branch on the source system first, then on the target. When source and target are the
    # same system no branch fires and the values fall through untouched.
    if from_coords == "osgb":
        if target_coords == "geostationary":
            x, y = osgb_to_geostationary_area_coords(x, y, da)
        elif target_coords == "lon_lat":
            x, y = osgb_to_lon_lat(x, y)

    elif from_coords == "lon_lat":
        if target_coords == "geostationary":
            x, y = lon_lat_to_geostationary_area_coords(x, y, da)
        elif target_coords == "osgb":
            x, y = lon_lat_to_osgb(x, y)

    return x, y
64
+
65
+ # TODO: This function and _get_idx_of_pixel_closest_to_poi_geostationary() should not be separate
66
+ # We should combine them, and consider making a Coord class to help with this
67
def _get_idx_of_pixel_closest_to_poi(
    da: xr.DataArray,
    location: Location,
) -> Location:
    """
    Find the x and y pixel indices in `da` that lie nearest to a point of interest.

    Args:
        da: xarray DataArray with "osgb" or "lon_lat" spatial coordinates
        location: Point of interest. Its coordinates are converted to the coordinate
            system of `da` before the lookup
    Returns:
        A Location in the "idx" coordinate system holding the nearest x and y indices
    Raises:
        NotImplementedError: If `da` uses any coordinate system other than "osgb" or "lon_lat"
        AssertionError: If the point is not strictly inside the x/y coordinate range of `da`
    """
    coord_system, x_dim, y_dim = spatial_coord_type(da)

    if coord_system not in ("osgb", "lon_lat"):
        raise NotImplementedError(
            f"Only 'osgb' and 'lon_lat' are supported - not '{coord_system}'"
        )

    # Bring the point of interest into the same coordinate system as the data
    poi_x, poi_y = convert_coords_to_match_xarray(
        location.x,
        location.y,
        from_coords=location.coordinate_system,
        da=da,
    )

    # The point must fall strictly between the smallest and largest coordinate on each axis
    x_coords = da[x_dim]
    assert x_coords.min() < poi_x < x_coords.max()
    y_coords = da[y_dim]
    assert y_coords.min() < poi_y < y_coords.max()

    def _nearest_position(dim, value):
        # Integer position of the index label closest to `value` along `dim`
        return da.get_index(dim).get_indexer([value], method="nearest")[0]

    return Location(
        x=_nearest_position(x_dim, poi_x),
        y=_nearest_position(y_dim, poi_y),
        coordinate_system="idx",
    )
104
+
105
+
106
def _get_idx_of_pixel_closest_to_poi_geostationary(
    da: xr.DataArray,
    center_osgb: Location,
) -> Location:
    """
    Return x and y index location of pixel at center of region of interest.

    The parameter keeps its historical name `center_osgb`, but the coordinate system declared
    on the location is honoured rather than assumed, so a lon/lat location is no longer
    silently interpreted as OSGB eastings/northings.

    Args:
        da: xarray DataArray (the caller uses this function for geostationary data)
        center_osgb: Center of the region of interest. "osgb" and "lon_lat" locations are
            converted to the geostationary area coordinates of `da`; "geostationary"
            locations are used as they are

    Returns:
        Location for the center pixel in index ("idx") coordinates

    Raises:
        ValueError: If the location uses any other coordinate system
        AssertionError: If the converted point is not strictly inside the x/y coordinate
            range of `da`
    """

    _, x_dim, y_dim = spatial_coord_type(da)

    source_coords = center_osgb.coordinate_system
    if source_coords == "osgb":
        x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=da)
    elif source_coords == "lon_lat":
        x, y = lon_lat_to_geostationary_area_coords(center_osgb.x, center_osgb.y, da)
    elif source_coords == "geostationary":
        x, y = center_osgb.x, center_osgb.y
    else:
        raise ValueError(
            "Location must be in 'osgb', 'lon_lat' or 'geostationary' coordinates - "
            f"not '{source_coords}'"
        )

    center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")

    # Check that the requested point lies within the data
    assert da[x_dim].min() < x < da[x_dim].max()
    assert da[y_dim].min() < y < da[y_dim].max()

    # Locate the center along each axis with the project's `searchsorted` helper, which is
    # told to treat the coordinate values as ascending
    x_index_at_center = searchsorted(
        da[x_dim].values, center_geostationary.x, assume_ascending=True
    )

    y_index_at_center = searchsorted(
        da[y_dim].values, center_geostationary.y, assume_ascending=True
    )

    return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx")
140
+
141
+
142
+ # ---------------------------- sub-functions for slicing ----------------------------
143
+
144
+
145
def _select_partial_spatial_slice_pixels(
    da: xr.DataArray,
    left_idx: int,
    right_idx: int,
    bottom_idx: int,
    top_idx: int,
    left_pad_pixels: int,
    right_pad_pixels: int,
    bottom_pad_pixels: int,
    top_pad_pixels: int,
    x_dim: str,
    y_dim: str,
):
    """Return spatial window of given pixel size when window partially overlaps input data

    The part of the window that overlaps `da` is selected by position, then reindexed onto
    coordinates that are extended past the edge of the data using the median coordinate
    spacing. Positions that have no data behind them are left to `reindex` to fill
    (missing values by default in xarray).

    Args:
        da: xarray DataArray to slice from
        left_idx: Index of the left edge of the window along `x_dim` (may be negative)
        right_idx: Index one past the right edge of the window along `x_dim`
        bottom_idx: Index of the bottom edge of the window along `y_dim` (may be negative)
        top_idx: Index one past the top edge of the window along `y_dim`
        left_pad_pixels: Number of pixels the window extends beyond the start of `x_dim`
        right_pad_pixels: Number of pixels the window extends beyond the end of `x_dim`
        bottom_pad_pixels: Number of pixels the window extends beyond the start of `y_dim`
        top_pad_pixels: Number of pixels the window extends beyond the end of `y_dim`
        x_dim: Name of the x-dimension in `da`
        y_dim: Name of the y-dimension in `da`

    Returns:
        The padded spatial window
    """

    # We should never be padding on both sides of a window. This would mean our desired window is
    # larger than the size of the input data
    assert left_pad_pixels == 0 or right_pad_pixels == 0
    assert bottom_pad_pixels == 0 or top_pad_pixels == 0

    # Typical coordinate spacing, used to extrapolate coordinates beyond the data
    dx = np.median(np.diff(da[x_dim].values))
    dy = np.median(np.diff(da[y_dim].values))

    # Pad the left of the window
    if left_pad_pixels > 0:
        x_sel = np.concatenate(
            [
                da[x_dim].values[0] + np.arange(-left_pad_pixels, 0) * dx,
                da[x_dim].values[0:right_idx],
            ]
        )
        da = da.isel({x_dim: slice(0, right_idx)}).reindex({x_dim: x_sel})

    # Pad the right of the window
    elif right_pad_pixels > 0:
        x_sel = np.concatenate(
            [
                da[x_dim].values[left_idx:],
                da[x_dim].values[-1] + np.arange(1, right_pad_pixels + 1) * dx,
            ]
        )
        da = da.isel({x_dim: slice(left_idx, None)}).reindex({x_dim: x_sel})

    # No left-right padding required
    else:
        da = da.isel({x_dim: slice(left_idx, right_idx)})

    # Pad the bottom of the window
    if bottom_pad_pixels > 0:
        y_sel = np.concatenate(
            [
                da[y_dim].values[0] + np.arange(-bottom_pad_pixels, 0) * dy,
                da[y_dim].values[0:top_idx],
            ]
        )
        da = da.isel({y_dim: slice(0, top_idx)}).reindex({y_dim: y_sel})

    # Pad the top of the window
    elif top_pad_pixels > 0:
        y_sel = np.concatenate(
            [
                da[y_dim].values[bottom_idx:],
                da[y_dim].values[-1] + np.arange(1, top_pad_pixels + 1) * dy,
            ]
        )
        # The y-dimension must be sliced from `bottom_idx`. Slicing it from the x-dimension's
        # `left_idx` drops rows (or picks the wrong ones) whenever the two indices differ.
        da = da.isel({y_dim: slice(bottom_idx, None)}).reindex({y_dim: y_sel})

    # No bottom-top padding required
    else:
        da = da.isel({y_dim: slice(bottom_idx, top_idx)})

    return da
217
+
218
+
219
def _select_spatial_slice_pixels(
    da: xr.DataArray,
    center_idx: Location,
    width_pixels: int,
    height_pixels: int,
    x_dim: str,
    y_dim: str,
    allow_partial_slice: bool,
):
    """Cut a window of fixed pixel size out of an xarray object

    Args:
        da: xarray DataArray to slice from
        center_idx: Location object giving the centre of the window in index coordinates
        width_pixels: Window width in pixels (must be even)
        height_pixels: Window height in pixels (must be even)
        x_dim: Name of the x-dimension in `da`
        y_dim: Name of the y-dimension in `da`
        allow_partial_slice: Whether a window reaching beyond the edge of `da` may be
            returned (padded) instead of raising

    Returns:
        The selected window, `width_pixels` long in `x_dim` and `height_pixels` in `y_dim`

    Raises:
        ValueError: If the window reaches beyond the data and `allow_partial_slice` is False
    """

    assert center_idx.coordinate_system == "idx"
    # TODO: It shouldn't take much effort to allow height and width to be odd
    assert (width_pixels % 2) == 0, "Width must be an even number"
    assert (height_pixels % 2) == 0, "Height must be an even number"

    half_w, half_h = width_pixels // 2, height_pixels // 2

    # Window edges as integer positions; these may fall outside the data
    left_idx, right_idx = int(center_idx.x - half_w), int(center_idx.x + half_w)
    bottom_idx, top_idx = int(center_idx.y - half_h), int(center_idx.y + half_h)

    n_x = len(da[x_dim])
    n_y = len(da[y_dim])

    left_pad_required = left_idx < 0
    right_pad_required = right_idx > n_x
    bottom_pad_required = bottom_idx < 0
    top_pad_required = top_idx > n_y

    fits_inside_data = not (
        left_pad_required or right_pad_required or bottom_pad_required or top_pad_required
    )

    if fits_inside_data:
        window = da.isel(
            {
                x_dim: slice(left_idx, right_idx),
                y_dim: slice(bottom_idx, top_idx),
            }
        )

    elif allow_partial_slice:
        # Amount by which the window overshoots each edge (zero where it does not)
        window = _select_partial_spatial_slice_pixels(
            da,
            left_idx,
            right_idx,
            bottom_idx,
            top_idx,
            max(-left_idx, 0),
            max(right_idx - n_x, 0),
            max(-bottom_idx, 0),
            max(top_idx - n_y, 0),
            x_dim,
            y_dim,
        )

    else:
        raise ValueError(
            f"Window for location {center_idx} not available. Missing (left, right, bottom, "
            f"top) pixels = ({left_pad_required}, {right_pad_required}, "
            f"{bottom_pad_required}, {top_pad_required}). "
            f"You may wish to set `allow_partial_slice=True`"
        )

    # Sanity-check the shape of what was selected
    assert len(window[x_dim]) == width_pixels, (
        f"Expected x-dim len {width_pixels} got {len(window[x_dim])} "
        f"for location {center_idx} for slice {left_idx}:{right_idx}"
    )
    assert len(window[y_dim]) == height_pixels, (
        f"Expected y-dim len {height_pixels} got {len(window[y_dim])} "
        f"for location {center_idx} for slice {bottom_idx}:{top_idx}"
    )

    return window
312
+
313
+
314
+ # ---------------------------- main functions for slicing ---------------------------
315
+
316
+
317
def select_spatial_slice_pixels(
    da: xr.DataArray,
    location: Location,
    width_pixels: int,
    height_pixels: int,
    allow_partial_slice: bool = False,
):
    """
    Select a spatial window, sized in pixels, centred on a location of interest

    If `allow_partial_slice` is set to True, the window is allowed to cross the border of
    the input data. The extra x and y coordinates needed for such a window are extrapolated
    from the typical spacing of the coordinates in the input data. Windows whose centre lies
    outside of the input data are currently not supported.

    Args:
        da: xarray DataArray to slice from
        location: Location of interest
        width_pixels: Width of the slice in pixels
        height_pixels: Height of the slice in pixels
        allow_partial_slice: Whether to allow a partial slice.

    Returns:
        The selected window of `da`
    """

    xr_coords, x_dim, y_dim = spatial_coord_type(da)

    # Geostationary data has its own centre-pixel lookup; everything else shares one
    find_center_pixel = (
        _get_idx_of_pixel_closest_to_poi_geostationary
        if xr_coords == "geostationary"
        else _get_idx_of_pixel_closest_to_poi
    )
    center_idx: Location = find_center_pixel(da, location)

    return _select_spatial_slice_pixels(
        da,
        center_idx,
        width_pixels,
        height_pixels,
        x_dim,
        y_dim,
        allow_partial_slice=allow_partial_slice,
    )
@@ -0,0 +1,184 @@
1
+ import xarray as xr
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+
6
+
7
def _sel_fillnan(
    da: xr.DataArray,
    start_dt: pd.Timestamp,
    end_dt: pd.Timestamp,
    sample_period_duration: pd.Timedelta,
) -> xr.DataArray:
    """Reindex `da` onto a regular `time_utc` grid between two datetimes.

    The grid runs from `start_dt` to `end_dt` in steps of `sample_period_duration`. Grid
    times that are absent from `da` are filled by `reindex` (NaN by default in xarray).
    """
    return da.reindex(
        time_utc=pd.date_range(start_dt, end_dt, freq=sample_period_duration)
    )
16
+
17
+
18
def _sel_default(
    da: xr.DataArray,
    start_dt: pd.Timestamp,
    end_dt: pd.Timestamp,
    sample_period_duration: pd.Timedelta,
) -> xr.DataArray:
    """Slice `da` along `time_utc` from `start_dt` to `end_dt`, keeping only existing times.

    `sample_period_duration` is accepted so that all selectors share one signature; it is
    not used here.
    """
    time_window = slice(start_dt, end_dt)
    return da.sel(time_utc=time_window)
26
+
27
+
28
+ # TODO either implement this or remove it, which would tidy up the code
29
+ def _sel_fillinterp(
30
+ da: xr.DataArray,
31
+ start_dt: pd.Timestamp,
32
+ end_dt: pd.Timestamp,
33
+ sample_period_duration: pd.Timedelta,
34
+ ) -> xr.DataArray:
35
+ """Select a time slice from a DataArray, filling missing times with linear interpolation."""
36
+ return NotImplemented
37
+
38
+
39
def select_time_slice(
    ds: xr.Dataset | xr.DataArray,
    t0: pd.Timestamp,
    sample_period_duration: pd.Timedelta,
    history_duration: pd.Timedelta | None = None,
    forecast_duration: pd.Timedelta | None = None,
    interval_start: pd.Timedelta | None = None,
    interval_end: pd.Timedelta | None = None,
    fill_selection: bool = False,
    max_steps_gap: int = 0,
):
    """Select a window of times around `t0` from a Dataset or DataArray.

    The window is given either as a history/forecast duration pair or as an explicit
    interval start/end pair, both relative to `t0`. Exactly one of the two pairs must be
    supplied in full.

    Args:
        ds: Object to select from, indexed by `time_utc`
        t0: Reference time
        sample_period_duration: Time resolution that the window edges are rounded up to
        history_duration: Length of the window before `t0`
        forecast_duration: Length of the window after `t0`
        interval_start: Offset of the window start relative to `t0`
        interval_end: Offset of the window end relative to `t0`
        fill_selection: Whether missing times should be filled in
        max_steps_gap: With `fill_selection`, 0 selects NaN filling and a positive value
            selects the interpolating selector
    """
    has_durations = (history_duration is not None) and (forecast_duration is not None)
    has_intervals = (interval_start is not None) and (interval_end is not None)
    assert has_durations ^ has_intervals, "Either durations, or intervals must be supplied"
    assert max_steps_gap >= 0, "max_steps_gap must be >= 0 "

    if has_durations:
        interval_start, interval_end = -history_duration, forecast_duration

    # Choose the selection strategy; plain slicing unless filling was requested
    selector = _sel_default
    if fill_selection:
        if max_steps_gap == 0:
            selector = _sel_fillnan
        elif max_steps_gap > 0:
            selector = _sel_fillinterp

    reference_time = pd.Timestamp(t0)

    # Both window edges are rounded up to the sampling resolution
    window_start = (reference_time + interval_start).ceil(sample_period_duration)
    window_end = (reference_time + interval_end).ceil(sample_period_duration)

    return selector(ds, window_start, window_end, sample_period_duration)
75
+
76
+
77
def select_time_slice_nwp(
    ds: xr.Dataset | xr.DataArray,
    t0: pd.Timestamp,
    sample_period_duration: pd.Timedelta,
    history_duration: pd.Timedelta,
    forecast_duration: pd.Timedelta,
    dropout_timedeltas: list[pd.Timedelta] | None = None,
    dropout_frac: float | None = 0,
    accum_channels: list[str] | None = None,
    channel_dim_name: str = "channel",
):
    """Select NWP data for a window of target times around `t0`.

    For every target time the most recent init time that is available (at or before `t0`,
    or before an earlier cut-off when dropout is applied) is used, together with the step
    that reaches that target time.

    Args:
        ds: Object to select from, indexed by `init_time_utc`, `step` and the channel
            dimension
        t0: Reference time
        sample_period_duration: Spacing of the target times
        history_duration: Length of the window before `t0`
        forecast_duration: Length of the window after `t0`
        dropout_timedeltas: Negative offsets from `t0`. When dropout is applied, one of
            them is drawn at random and init times after `t0 + offset` are ignored
        dropout_frac: Probability of applying dropout, within [0, 1]. `None` is treated
            as 0 (no dropout)
        accum_channels: Names of channels to difference along `step`. Names that are not
            present in `ds` are ignored. `None` means no such channels
        channel_dim_name: Name of the channel dimension in `ds`

    Returns:
        The selection, with one entry per target time along `target_time_utc`. Differenced
        channels are renamed to `diff_<name>`.
    """
    # Normalise the optional arguments. `dropout_frac=None` used to crash in the range check
    # below, and the channel list no longer uses a shared mutable default.
    if dropout_frac is None:
        dropout_frac = 0
    if accum_channels is None:
        accum_channels = []

    if dropout_timedeltas is not None:
        assert all(
            t < pd.Timedelta(0) for t in dropout_timedeltas
        ), "dropout timedeltas must be negative"
        assert len(dropout_timedeltas) >= 1
    assert 0 <= dropout_frac <= 1
    _consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0

    # The accumulation and non-accumulation channels
    accum_channels = np.intersect1d(
        ds[channel_dim_name].values, accum_channels
    )
    non_accum_channels = np.setdiff1d(
        ds[channel_dim_name].values, accum_channels
    )

    start_dt = (t0 - history_duration).ceil(sample_period_duration)
    end_dt = (t0 + forecast_duration).ceil(sample_period_duration)

    target_times = pd.date_range(start_dt, end_dt, freq=sample_period_duration)

    # Maybe apply NWP dropout
    if _consider_dropout and (np.random.uniform() < dropout_frac):
        dt = np.random.choice(dropout_timedeltas)
        t0_available = t0 + dt
    else:
        t0_available = t0

    # Forecasts made up to and including t0_available
    available_init_times = ds.init_time_utc.sel(
        init_time_utc=slice(None, t0_available)
    )

    # Find the most recent available init times for all target times
    selected_init_times = available_init_times.sel(
        init_time_utc=target_times,
        method="ffill",  # forward fill from init times to target times
    ).values

    # Find the required steps for all target times
    steps = target_times - selected_init_times

    # We want one timestep for each target time. If we simply do
    # nwp.sel(init_time=init_times, step=steps) then we'll get the *product* of
    # init_times and steps, which is not what we want! Instead, we use xarray's
    # vectorized-indexing mode by using a DataArray indexer. See the last example here:
    # https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing
    coords = {"target_time_utc": target_times}
    init_time_indexer = xr.DataArray(selected_init_times, coords=coords)
    step_indexer = xr.DataArray(steps, coords=coords)

    if len(accum_channels) == 0:
        xr_sel = ds.sel(step=step_indexer, init_time_utc=init_time_indexer)

    else:
        # First minimise the size of the dataset we are diffing
        # - find the init times we are slicing from
        unique_init_times = np.unique(selected_init_times)
        # - find the min and max steps we slice over. Max is extended due to diff
        min_step = min(steps)
        max_step = max(steps) + sample_period_duration

        xr_min = ds.sel(
            {
                "init_time_utc": unique_init_times,
                "step": slice(min_step, max_step),
            }
        )

        # Slice out the data which does not need to be diffed
        xr_non_accum = xr_min.sel({channel_dim_name: non_accum_channels})
        xr_sel_non_accum = xr_non_accum.sel(
            step=step_indexer, init_time_utc=init_time_indexer
        )

        # Slice out the channels which need to be diffed
        xr_accum = xr_min.sel({channel_dim_name: accum_channels})

        # Take the diff and slice requested data. With label="lower" each difference is
        # labelled with the earlier of the two steps it was computed from
        xr_accum = xr_accum.diff(dim="step", label="lower")
        xr_sel_accum = xr_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)

        # Join diffed and non-diffed variables
        xr_sel = xr.concat([xr_sel_non_accum, xr_sel_accum], dim=channel_dim_name)

        # Reorder the variable back to the original order
        xr_sel = xr_sel.sel({channel_dim_name: ds[channel_dim_name].values})

        # Rename the diffed channels
        xr_sel[channel_dim_name] = [
            f"diff_{v}" if v in accum_channels else v
            for v in xr_sel[channel_dim_name].values
        ]

    return xr_sel
+ return xr_sel