ocf-data-sampler 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/__init__.py +1 -0
- ocf_data_sampler/data/uk_gsp_locations.csv +319 -0
- ocf_data_sampler/numpy_batch/__init__.py +7 -0
- ocf_data_sampler/numpy_batch/gsp.py +23 -0
- ocf_data_sampler/numpy_batch/nwp.py +33 -0
- ocf_data_sampler/numpy_batch/satellite.py +23 -0
- ocf_data_sampler/numpy_batch/sun_position.py +66 -0
- ocf_data_sampler/select/__init__.py +1 -0
- ocf_data_sampler/select/dropout.py +38 -0
- ocf_data_sampler/select/fill_time_periods.py +11 -0
- ocf_data_sampler/select/find_contiguous_time_periods.py +301 -0
- ocf_data_sampler/select/select_spatial_slice.py +358 -0
- ocf_data_sampler/select/select_time_slice.py +184 -0
- ocf_data_sampler/torch_datasets/__init__.py +1 -0
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +538 -0
- {ocf_data_sampler-0.0.6.dist-info → ocf_data_sampler-0.0.8.dist-info}/METADATA +1 -1
- ocf_data_sampler-0.0.8.dist-info/RECORD +22 -0
- ocf_data_sampler-0.0.8.dist-info/top_level.txt +2 -0
- ocf_data_sampler-0.0.6.dist-info/RECORD +0 -7
- ocf_data_sampler-0.0.6.dist-info/top_level.txt +0 -1
- {ocf_data_sampler-0.0.6.dist-info → ocf_data_sampler-0.0.8.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.0.6.dist-info → ocf_data_sampler-0.0.8.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
"""Select spatial slices"""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import xarray as xr
|
|
7
|
+
|
|
8
|
+
from ocf_datapipes.utils import Location
|
|
9
|
+
from ocf_datapipes.utils.geospatial import (
|
|
10
|
+
lon_lat_to_geostationary_area_coords,
|
|
11
|
+
lon_lat_to_osgb,
|
|
12
|
+
osgb_to_geostationary_area_coords,
|
|
13
|
+
osgb_to_lon_lat,
|
|
14
|
+
spatial_coord_type,
|
|
15
|
+
)
|
|
16
|
+
from ocf_datapipes.utils.utils import searchsorted
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# -------------------------------- utility functions --------------------------------
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def convert_coords_to_match_xarray(
    x: float | np.ndarray,
    y: float | np.ndarray,
    from_coords: str,
    da: xr.DataArray
):
    """Convert x and y coords to the coordinate system used by the xarray data

    Args:
        x: Float or array-like
        y: Float or array-like
        from_coords: String describing coordinate system of x and y. Either "osgb" or "lon_lat"
        da: DataArray to which coordinates should be matched

    Returns:
        The pair (x, y) in the coordinate system of `da`. The inputs are passed through
        unchanged when they are already in that coordinate system.
    """

    target_coords, *_ = spatial_coord_type(da)

    assert from_coords in ["osgb", "lon_lat"]
    assert target_coords in ["geostationary", "osgb", "lon_lat"]

    # Every (source, target) pair which needs converting maps to the call which does it.
    # Pairs missing from this table are already in matching coordinate systems.
    conversions = {
        ("osgb", "geostationary"): lambda: osgb_to_geostationary_area_coords(x, y, da),
        ("lon_lat", "geostationary"): lambda: lon_lat_to_geostationary_area_coords(x, y, da),
        ("osgb", "lon_lat"): lambda: osgb_to_lon_lat(x, y),
        ("lon_lat", "osgb"): lambda: lon_lat_to_osgb(x, y),
    }

    convert = conversions.get((from_coords, target_coords))
    if convert is not None:
        x, y = convert()

    return x, y
|
|
64
|
+
|
|
65
|
+
# TODO: This function and _get_idx_of_pixel_closest_to_poi_geostationary() should not be separate
|
|
66
|
+
# We should combine them, and consider making a Coord class to help with this
|
|
67
|
+
def _get_idx_of_pixel_closest_to_poi(
    da: xr.DataArray,
    location: Location,
) -> Location:
    """
    Find the x and y index of the pixel nearest to a location in OSGB or lon-lat data.

    Args:
        da: xarray DataArray
        location: Location to find index of

    Returns:
        A Location in the "idx" coordinate system holding the x and y index of the pixel
        nearest to the requested location

    Raises:
        NotImplementedError: If the spatial coordinates of `da` are neither 'osgb' nor 'lon_lat'
    """
    xr_coords, x_dim, y_dim = spatial_coord_type(da)

    if xr_coords not in ["osgb", "lon_lat"]:
        raise NotImplementedError(f"Only 'osgb' and 'lon_lat' are supported - not '{xr_coords}'")

    # Express the requested point in the same coordinate system as the data
    x, y = convert_coords_to_match_xarray(
        location.x,
        location.y,
        from_coords=location.coordinate_system,
        da=da,
    )

    # The requested point must lie strictly inside the extent of the data
    assert da[x_dim].min() < x < da[x_dim].max()
    assert da[y_dim].min() < y < da[y_dim].max()

    # Nearest-neighbour lookup along each spatial axis
    closest_x = da.get_index(x_dim).get_indexer([x], method="nearest")[0]
    closest_y = da.get_index(y_dim).get_indexer([y], method="nearest")[0]

    return Location(x=closest_x, y=closest_y, coordinate_system="idx")
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _get_idx_of_pixel_closest_to_poi_geostationary(
    da: xr.DataArray,
    center_osgb: Location,
) -> Location:
    """
    Find the x and y index of the centre pixel in geostationary data.

    Args:
        da: xarray DataArray
        center_osgb: Center in OSGB coordinates

    Returns:
        A Location in the "idx" coordinate system holding the x and y index found for the
        centre along each spatial dimension of `da`
    """

    _, x_dim, y_dim = spatial_coord_type(da)

    # Move the OSGB centre into the geostationary area coordinates used by the data
    x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=da)
    center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")

    # The requested point must lie strictly inside the extent of the data
    assert da[x_dim].min() < x < da[x_dim].max()
    assert da[y_dim].min() < y < da[y_dim].max()

    # Look up the centre along x and then y, treating the coordinate values as ascending
    x_index_at_center, y_index_at_center = (
        searchsorted(da[dim].values, center, assume_ascending=True)
        for dim, center in ((x_dim, center_geostationary.x), (y_dim, center_geostationary.y))
    )

    return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx")
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
# ---------------------------- sub-functions for slicing ----------------------------
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _select_partial_spatial_slice_pixels(
    da,
    left_idx,
    right_idx,
    bottom_idx,
    top_idx,
    left_pad_pixels,
    right_pad_pixels,
    bottom_pad_pixels,
    top_pad_pixels,
    x_dim,
    y_dim,
):
    """Return spatial window of given pixel size when window partially overlaps input data

    The part of the window which overlaps `da` is selected by index. The part which falls
    outside is created by extending the coordinates with the median coordinate spacing and
    reindexing onto them, so the padded pixels hold the fill value used by `reindex`.

    Args:
        da: xarray object to slice from
        left_idx: Index of the left edge of the window along `x_dim`
        right_idx: Index one past the right edge of the window along `x_dim`
        bottom_idx: Index of the bottom edge of the window along `y_dim`
        top_idx: Index one past the top edge of the window along `y_dim`
        left_pad_pixels: Number of pixels the window extends beyond the start of `x_dim`
        right_pad_pixels: Number of pixels the window extends beyond the end of `x_dim`
        bottom_pad_pixels: Number of pixels the window extends beyond the start of `y_dim`
        top_pad_pixels: Number of pixels the window extends beyond the end of `y_dim`
        x_dim: Name of the x-dimension in `da`
        y_dim: Name of the y-dimension in `da`

    Returns:
        The selected and padded window
    """

    # We should never be padding on both sides of a window. This would mean our desired window is
    # larger than the size of the input data
    assert left_pad_pixels == 0 or right_pad_pixels == 0
    assert bottom_pad_pixels == 0 or top_pad_pixels == 0

    # Coordinate spacing used to extrapolate the coordinates of the padded pixels
    dx = np.median(np.diff(da[x_dim].values))
    dy = np.median(np.diff(da[y_dim].values))

    # Pad the left of the window
    if left_pad_pixels > 0:
        x_sel = np.concatenate(
            [
                da[x_dim].values[0] + np.arange(-left_pad_pixels, 0) * dx,
                da[x_dim].values[0:right_idx],
            ]
        )
        da = da.isel({x_dim: slice(0, right_idx)}).reindex({x_dim: x_sel})

    # Pad the right of the window
    elif right_pad_pixels > 0:
        x_sel = np.concatenate(
            [
                da[x_dim].values[left_idx:],
                da[x_dim].values[-1] + np.arange(1, right_pad_pixels + 1) * dx,
            ]
        )
        da = da.isel({x_dim: slice(left_idx, None)}).reindex({x_dim: x_sel})

    # No left-right padding required
    else:
        da = da.isel({x_dim: slice(left_idx, right_idx)})

    # Pad the bottom of the window
    if bottom_pad_pixels > 0:
        y_sel = np.concatenate(
            [
                da[y_dim].values[0] + np.arange(-bottom_pad_pixels, 0) * dy,
                da[y_dim].values[0:top_idx],
            ]
        )
        da = da.isel({y_dim: slice(0, top_idx)}).reindex({y_dim: y_sel})

    # Pad the top of the window
    elif top_pad_pixels > 0:
        y_sel = np.concatenate(
            [
                da[y_dim].values[bottom_idx:],
                da[y_dim].values[-1] + np.arange(1, top_pad_pixels + 1) * dy,
            ]
        )
        # The y-dimension must be sliced from `bottom_idx` so the selected rows match `y_sel`.
        # Slicing from the x-dimension index here would select the wrong rows
        da = da.isel({y_dim: slice(bottom_idx, None)}).reindex({y_dim: y_sel})

    # No bottom-top padding required
    else:
        da = da.isel({y_dim: slice(bottom_idx, top_idx)})

    return da
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _select_spatial_slice_pixels(
    da: xr.DataArray,
    center_idx: Location,
    width_pixels: int,
    height_pixels: int,
    x_dim: str,
    y_dim: str,
    allow_partial_slice: bool,
):
    """Cut a window of fixed pixel size out of an xarray object around a centre index

    Args:
        da: xarray DataArray to slice from
        center_idx: Location object describing the centre of the window with index coordinates
        width_pixels: Window width in pixels. Must be even
        height_pixels: Window height in pixels. Must be even
        x_dim: Name of the x-dimension in `da`
        y_dim: Name of the y-dimension in `da`
        allow_partial_slice: Whether to allow a partially filled window

    Returns:
        The window of `da` with `width_pixels` along `x_dim` and `height_pixels` along `y_dim`

    Raises:
        ValueError: If the window extends beyond the data and `allow_partial_slice` is False
    """

    assert center_idx.coordinate_system == "idx"
    # TODO: It shouldn't take much effort to allow height and width to be odd
    assert width_pixels % 2 == 0, "Width must be an even number"
    assert height_pixels % 2 == 0, "Height must be an even number"

    half_width, half_height = width_pixels // 2, height_pixels // 2

    # Window edges as indices into the data. These may fall outside of the data
    left_idx = int(center_idx.x - half_width)
    right_idx = int(center_idx.x + half_width)
    bottom_idx = int(center_idx.y - half_height)
    top_idx = int(center_idx.y + half_height)

    data_width_pixels = len(da[x_dim])
    data_height_pixels = len(da[y_dim])

    # Which edges of the window fall outside of the data
    left_pad_required = left_idx < 0
    right_pad_required = right_idx > data_width_pixels
    bottom_pad_required = bottom_idx < 0
    top_pad_required = top_idx > data_height_pixels

    pad_required = any(
        [left_pad_required, right_pad_required, bottom_pad_required, top_pad_required]
    )

    if not pad_required:
        # The window lies fully inside of the data so plain index slicing is enough
        da = da.isel(
            {
                x_dim: slice(left_idx, right_idx),
                y_dim: slice(bottom_idx, top_idx),
            }
        )

    elif allow_partial_slice:
        # The number of pixels by which each edge overshoots the data, or zero if it does not
        da = _select_partial_spatial_slice_pixels(
            da=da,
            left_idx=left_idx,
            right_idx=right_idx,
            bottom_idx=bottom_idx,
            top_idx=top_idx,
            left_pad_pixels=max(-left_idx, 0),
            right_pad_pixels=max(right_idx - data_width_pixels, 0),
            bottom_pad_pixels=max(-bottom_idx, 0),
            top_pad_pixels=max(top_idx - data_height_pixels, 0),
            x_dim=x_dim,
            y_dim=y_dim,
        )

    else:
        raise ValueError(
            f"Window for location {center_idx} not available. Missing (left, right, bottom, "
            f"top) pixels = ({left_pad_required}, {right_pad_required}, "
            f"{bottom_pad_required}, {top_pad_required}). "
            f"You may wish to set `allow_partial_slice=True`"
        )

    # Sanity check the shape of the selected window
    assert len(da[x_dim]) == width_pixels, (
        f"Expected x-dim len {width_pixels} got {len(da[x_dim])} "
        f"for location {center_idx} for slice {left_idx}:{right_idx}"
    )
    assert len(da[y_dim]) == height_pixels, (
        f"Expected y-dim len {height_pixels} got {len(da[y_dim])} "
        f"for location {center_idx} for slice {bottom_idx}:{top_idx}"
    )

    return da
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
# ---------------------------- main functions for slicing ---------------------------
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def select_spatial_slice_pixels(
    da: xr.DataArray,
    location: Location,
    width_pixels: int,
    height_pixels: int,
    allow_partial_slice: bool = False,
):
    """
    Select a spatial window of a given pixel size centred on a location of interest

    If `allow_partial_slice` is set to True, then slices may be made which intersect the border
    of the input data. The additional x and y coordinates that would be required for this slice
    are extrapolated based on the median spacing of these coordinates in the input data.
    However, currently slices cannot be made where the centre of the window is outside of the
    input data.

    Args:
        da: xarray DataArray to slice from
        location: Location of interest
        width_pixels: Width of the slice in pixels
        height_pixels: Height of the slice in pixels
        allow_partial_slice: Whether to allow a partial slice.

    Returns:
        The selected window of `da`
    """

    xr_coords, x_dim, y_dim = spatial_coord_type(da)

    # Geostationary data has its own routine for locating the centre pixel
    find_center_idx = (
        _get_idx_of_pixel_closest_to_poi_geostationary
        if xr_coords == "geostationary"
        else _get_idx_of_pixel_closest_to_poi
    )
    center_idx: Location = find_center_idx(da, location)

    return _select_spatial_slice_pixels(
        da,
        center_idx,
        width_pixels,
        height_pixels,
        x_dim,
        y_dim,
        allow_partial_slice=allow_partial_slice,
    )
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _sel_fillnan(
    da: xr.DataArray,
    start_dt: pd.Timestamp,
    end_dt: pd.Timestamp,
    sample_period_duration: pd.Timedelta,
) -> xr.DataArray:
    """Select a time slice from a DataArray, filling missing times with NaNs.

    Args:
        da: Data with a `time_utc` coordinate
        start_dt: First timestamp to select
        end_dt: Last timestamp to select
        sample_period_duration: Spacing between consecutive requested timestamps

    Returns:
        The data reindexed onto the regularly spaced times from `start_dt` to `end_dt`
    """
    return da.reindex(
        time_utc=pd.date_range(start=start_dt, end=end_dt, freq=sample_period_duration)
    )
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _sel_default(
    da: xr.DataArray,
    start_dt: pd.Timestamp,
    end_dt: pd.Timestamp,
    sample_period_duration: pd.Timedelta,
) -> xr.DataArray:
    """Select a time slice from a DataArray, without filling missing times.

    Args:
        da: Data with a `time_utc` coordinate
        start_dt: Start of the time window
        end_dt: End of the time window
        sample_period_duration: Unused here. Accepted so that all of the selection functions
            share the same signature

    Returns:
        The data whose `time_utc` labels fall in the slice from `start_dt` to `end_dt`
    """
    time_window = slice(start_dt, end_dt)
    return da.sel(time_utc=time_window)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# TODO either implement this or remove it, which would tidy up the code
|
|
29
|
+
def _sel_fillinterp(
|
|
30
|
+
da: xr.DataArray,
|
|
31
|
+
start_dt: pd.Timestamp,
|
|
32
|
+
end_dt: pd.Timestamp,
|
|
33
|
+
sample_period_duration: pd.Timedelta,
|
|
34
|
+
) -> xr.DataArray:
|
|
35
|
+
"""Select a time slice from a DataArray, filling missing times with linear interpolation."""
|
|
36
|
+
return NotImplemented
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def select_time_slice(
    ds: xr.Dataset | xr.DataArray,
    t0: pd.Timestamp,
    sample_period_duration: pd.Timedelta,
    history_duration: pd.Timedelta | None = None,
    forecast_duration: pd.Timedelta | None = None,
    interval_start: pd.Timedelta | None = None,
    interval_end: pd.Timedelta | None = None,
    fill_selection: bool = False,
    max_steps_gap: int = 0,
):
    """Select a time slice from a Dataset or DataArray.

    The window is given relative to `t0`, either as a history and forecast duration or as an
    interval start and end. Exactly one of these two pairs must be supplied.

    Args:
        ds: Data with a `time_utc` coordinate
        t0: Reference time which the window is defined relative to
        sample_period_duration: Period the window edges are rounded up to, and the spacing of
            the requested times when filling is used
        history_duration: Length of time before `t0` to include
        forecast_duration: Length of time after `t0` to include
        interval_start: Offset of the window start relative to `t0`
        interval_end: Offset of the window end relative to `t0`
        fill_selection: Whether missing timestamps should be filled
        max_steps_gap: Must be >= 0. With `fill_selection`, zero selects NaN filling and a
            positive value selects interpolated filling

    Returns:
        The result of the chosen selection function for the window
    """
    used_duration = history_duration is not None and forecast_duration is not None
    used_intervals = interval_start is not None and interval_end is not None
    assert used_duration ^ used_intervals, "Either durations, or intervals must be supplied"
    assert max_steps_gap >= 0, "max_steps_gap must be >= 0 "

    # Durations are just another way of expressing the interval around t0
    if used_duration:
        interval_start, interval_end = -history_duration, forecast_duration

    # Choose how the selection treats missing timestamps
    if not fill_selection:
        _sel = _sel_default
    elif max_steps_gap == 0:
        _sel = _sel_fillnan
    elif max_steps_gap > 0:
        _sel = _sel_fillinterp
    else:
        _sel = _sel_default

    t0_datetime_utc = pd.Timestamp(t0)

    # Round both edges of the window up to a whole number of sample periods
    start_dt = (t0_datetime_utc + interval_start).ceil(sample_period_duration)
    end_dt = (t0_datetime_utc + interval_end).ceil(sample_period_duration)

    return _sel(ds, start_dt, end_dt, sample_period_duration)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def select_time_slice_nwp(
    ds: xr.Dataset | xr.DataArray,
    t0: pd.Timestamp,
    sample_period_duration: pd.Timedelta,
    history_duration: pd.Timedelta,
    forecast_duration: pd.Timedelta,
    dropout_timedeltas: list[pd.Timedelta] | None = None,
    dropout_frac: float | None = 0,
    accum_channels: list[str] | None = None,
    channel_dim_name: str = "channel",
):
    """Select a time slice of NWP data, using the most recent available forecast per time.

    For each target time between `t0 - history_duration` and `t0 + forecast_duration`, the
    most recent init time which is no later than the target time and no later than the
    latest available time is used, together with the step which reaches the target time.

    Args:
        ds: NWP data with `init_time_utc`, `step` and channel coordinates
        t0: Reference time of the sample
        sample_period_duration: Spacing between consecutive target times
        history_duration: Length of time before `t0` to include
        forecast_duration: Length of time after `t0` to include
        dropout_timedeltas: Negative timedeltas. When dropout is applied, one is drawn at
            random and added to `t0` to give the latest available init time
        dropout_frac: Probability of applying dropout. None is treated as 0
        accum_channels: Channels to difference along the step dimension. Those present in
            `ds` are returned with a `diff_` prefix. None is treated as no channels
        channel_dim_name: Name of the channel dimension in `ds`

    Returns:
        The selected data indexed by `target_time_utc`
    """
    # The signature allows None for these, so map None onto the equivalent "off" values
    if dropout_frac is None:
        dropout_frac = 0
    if accum_channels is None:
        accum_channels = []

    if dropout_timedeltas is not None:
        assert all(
            [t < pd.Timedelta(0) for t in dropout_timedeltas]
        ), "dropout timedeltas must be negative"
        assert len(dropout_timedeltas) >= 1
    assert 0 <= dropout_frac <= 1
    _consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0

    # The accumulation and non-accumulation channels
    accum_channels = np.intersect1d(
        ds[channel_dim_name].values, accum_channels
    )
    non_accum_channels = np.setdiff1d(
        ds[channel_dim_name].values, accum_channels
    )

    start_dt = (t0 - history_duration).ceil(sample_period_duration)
    end_dt = (t0 + forecast_duration).ceil(sample_period_duration)

    target_times = pd.date_range(start_dt, end_dt, freq=sample_period_duration)

    # Maybe apply NWP dropout
    if _consider_dropout and (np.random.uniform() < dropout_frac):
        dt = np.random.choice(dropout_timedeltas)
        t0_available = t0 + dt
    else:
        t0_available = t0

    # Forecasts made up to and including t0
    available_init_times = ds.init_time_utc.sel(
        init_time_utc=slice(None, t0_available)
    )

    # Find the most recent available init times for all target times
    selected_init_times = available_init_times.sel(
        init_time_utc=target_times,
        method="ffill",  # forward fill from init times to target times
    ).values

    # Find the required steps for all target times
    steps = target_times - selected_init_times

    # We want one timestep for each target time. If we simply do
    # nwp.sel(init_time=init_times, step=steps) then we'll get the *product* of
    # init_times and steps, which is not what we want! Instead, we use xarray's
    # vectorized-indexing mode by using a DataArray indexer. See the last example here:
    # https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing
    coords = {"target_time_utc": target_times}
    init_time_indexer = xr.DataArray(selected_init_times, coords=coords)
    step_indexer = xr.DataArray(steps, coords=coords)

    if len(accum_channels) == 0:
        xr_sel = ds.sel(step=step_indexer, init_time_utc=init_time_indexer)

    else:
        # First minimise the size of the dataset we are diffing
        # - find the init times we are slicing from
        unique_init_times = np.unique(selected_init_times)
        # - find the min and max steps we slice over. Max is extended due to diff
        min_step = min(steps)
        max_step = max(steps) + sample_period_duration

        xr_min = ds.sel(
            {
                "init_time_utc": unique_init_times,
                "step": slice(min_step, max_step),
            }
        )

        # Slice out the data which does not need to be diffed
        xr_non_accum = xr_min.sel({channel_dim_name: non_accum_channels})
        xr_sel_non_accum = xr_non_accum.sel(
            step=step_indexer, init_time_utc=init_time_indexer
        )

        # Slice out the channels which need to be diffed
        xr_accum = xr_min.sel({channel_dim_name: accum_channels})

        # Take the diff and slice requested data
        xr_accum = xr_accum.diff(dim="step", label="lower")
        xr_sel_accum = xr_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)

        # Join diffed and non-diffed variables
        xr_sel = xr.concat([xr_sel_non_accum, xr_sel_accum], dim=channel_dim_name)

        # Reorder the variable back to the original order
        xr_sel = xr_sel.sel({channel_dim_name: ds[channel_dim_name].values})

        # Rename the diffed channels
        xr_sel[channel_dim_name] = [
            f"diff_{v}" if v in accum_channels else v
            for v in xr_sel[channel_dim_name].values
        ]

    return xr_sel
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|