ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +86 -72
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/constants.py +140 -12
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +27 -36
- ocf_data_sampler/load/site.py +11 -7
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +15 -13
- ocf_data_sampler/numpy_sample/nwp.py +17 -23
- ocf_data_sampler/numpy_sample/satellite.py +17 -14
- ocf_data_sampler/numpy_sample/site.py +8 -7
- ocf_data_sampler/numpy_sample/sun_position.py +19 -25
- ocf_data_sampler/sample/__init__.py +0 -7
- ocf_data_sampler/sample/base.py +23 -44
- ocf_data_sampler/sample/site.py +25 -69
- ocf_data_sampler/sample/uk_regional.py +52 -103
- ocf_data_sampler/select/dropout.py +42 -27
- ocf_data_sampler/select/fill_time_periods.py +15 -3
- ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
- ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +62 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -286
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -52
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -75
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -195
- tests/test_sample/test_uk_regional_sample.py +0 -163
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -167
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
|
@@ -1,19 +1,18 @@
|
|
|
1
|
-
"""Select spatial slices"""
|
|
1
|
+
"""Select spatial slices."""
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import xarray as xr
|
|
7
7
|
|
|
8
|
-
from ocf_data_sampler.select.location import Location
|
|
9
8
|
from ocf_data_sampler.select.geospatial import (
|
|
10
|
-
lon_lat_to_osgb,
|
|
11
9
|
lon_lat_to_geostationary_area_coords,
|
|
10
|
+
lon_lat_to_osgb,
|
|
12
11
|
osgb_to_geostationary_area_coords,
|
|
13
12
|
osgb_to_lon_lat,
|
|
14
13
|
spatial_coord_type,
|
|
15
14
|
)
|
|
16
|
-
|
|
15
|
+
from ocf_data_sampler.select.location import Location
|
|
17
16
|
|
|
18
17
|
logger = logging.getLogger(__name__)
|
|
19
18
|
|
|
@@ -22,12 +21,12 @@ logger = logging.getLogger(__name__)
|
|
|
22
21
|
|
|
23
22
|
|
|
24
23
|
def convert_coords_to_match_xarray(
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
"""Convert x and y coords to cooridnate system matching xarray data
|
|
24
|
+
x: float | np.ndarray,
|
|
25
|
+
y: float | np.ndarray,
|
|
26
|
+
from_coords: str,
|
|
27
|
+
da: xr.DataArray,
|
|
28
|
+
) -> tuple[float | np.ndarray, float | np.ndarray]:
|
|
29
|
+
"""Convert x and y coords to cooridnate system matching xarray data.
|
|
31
30
|
|
|
32
31
|
Args:
|
|
33
32
|
x: Float or array-like
|
|
@@ -35,38 +34,42 @@ def convert_coords_to_match_xarray(
|
|
|
35
34
|
from_coords: String describing coordinate system of x and y
|
|
36
35
|
da: DataArray to which coordinates should be matched
|
|
37
36
|
"""
|
|
38
|
-
|
|
39
37
|
target_coords, *_ = spatial_coord_type(da)
|
|
40
38
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
if target_coords == "geostationary":
|
|
45
|
-
if from_coords == "osgb":
|
|
39
|
+
match (from_coords, target_coords):
|
|
40
|
+
case ("osgb", "geostationary"):
|
|
46
41
|
x, y = osgb_to_geostationary_area_coords(x, y, da)
|
|
47
42
|
|
|
48
|
-
|
|
49
|
-
if from_coords == "osgb":
|
|
43
|
+
case ("osgb", "lon_lat"):
|
|
50
44
|
x, y = osgb_to_lon_lat(x, y)
|
|
51
45
|
|
|
52
|
-
|
|
46
|
+
case ("osgb", "osgb"):
|
|
47
|
+
pass
|
|
53
48
|
|
|
54
|
-
|
|
55
|
-
if from_coords == "lon_lat":
|
|
49
|
+
case ("lon_lat", "osgb"):
|
|
56
50
|
x, y = lon_lat_to_osgb(x, y)
|
|
57
51
|
|
|
58
|
-
|
|
52
|
+
case ("lon_lat", "geostationary"):
|
|
53
|
+
x, y = lon_lat_to_geostationary_area_coords(x, y, da)
|
|
54
|
+
|
|
55
|
+
case ("lon_lat", "lon_lat"):
|
|
56
|
+
pass
|
|
57
|
+
|
|
58
|
+
case (_, _):
|
|
59
|
+
raise NotImplementedError(
|
|
60
|
+
f"Conversion from {from_coords} to {target_coords} is not supported",
|
|
61
|
+
)
|
|
59
62
|
|
|
60
63
|
return x, y
|
|
61
64
|
|
|
62
|
-
|
|
65
|
+
|
|
66
|
+
# TODO: This function and _get_idx_of_pixel_closest_to_poi_geostationary() should not be separate
|
|
63
67
|
# We should combine them, and consider making a Coord class to help with this
|
|
64
68
|
def _get_idx_of_pixel_closest_to_poi(
|
|
65
69
|
da: xr.DataArray,
|
|
66
70
|
location: Location,
|
|
67
71
|
) -> Location:
|
|
68
|
-
"""
|
|
69
|
-
Return x and y index location of pixel at center of region of interest.
|
|
72
|
+
"""Return x and y index location of pixel at center of region of interest.
|
|
70
73
|
|
|
71
74
|
Args:
|
|
72
75
|
da: xarray DataArray
|
|
@@ -88,8 +91,14 @@ def _get_idx_of_pixel_closest_to_poi(
|
|
|
88
91
|
)
|
|
89
92
|
|
|
90
93
|
# Check that the requested point lies within the data
|
|
91
|
-
|
|
92
|
-
|
|
94
|
+
if not (da[x_dim].min() < x < da[x_dim].max()):
|
|
95
|
+
raise ValueError(
|
|
96
|
+
f"{x} is not in the interval {da[x_dim].min().values}: {da[x_dim].max().values}",
|
|
97
|
+
)
|
|
98
|
+
if not (da[y_dim].min() < y < da[y_dim].max()):
|
|
99
|
+
raise ValueError(
|
|
100
|
+
f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}",
|
|
101
|
+
)
|
|
93
102
|
|
|
94
103
|
x_index = da.get_index(x_dim)
|
|
95
104
|
y_index = da.get_index(y_dim)
|
|
@@ -104,32 +113,38 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
|
|
|
104
113
|
da: xr.DataArray,
|
|
105
114
|
center: Location,
|
|
106
115
|
) -> Location:
|
|
107
|
-
"""
|
|
108
|
-
Return x and y index location of pixel at center of region of interest.
|
|
116
|
+
"""Return x and y index location of pixel at center of region of interest.
|
|
109
117
|
|
|
110
118
|
Args:
|
|
111
119
|
da: xarray DataArray
|
|
112
|
-
|
|
120
|
+
center: Center in OSGB coordinates
|
|
113
121
|
|
|
114
122
|
Returns:
|
|
115
123
|
Location for the center pixel in geostationary coordinates
|
|
116
124
|
"""
|
|
117
|
-
|
|
118
125
|
_, x_dim, y_dim = spatial_coord_type(da)
|
|
119
126
|
|
|
120
|
-
if center.coordinate_system ==
|
|
127
|
+
if center.coordinate_system == "osgb":
|
|
121
128
|
x, y = osgb_to_geostationary_area_coords(x=center.x, y=center.y, xr_data=da)
|
|
122
|
-
elif center.coordinate_system ==
|
|
123
|
-
x, y = lon_lat_to_geostationary_area_coords(
|
|
129
|
+
elif center.coordinate_system == "lon_lat":
|
|
130
|
+
x, y = lon_lat_to_geostationary_area_coords(
|
|
131
|
+
longitude=center.x,
|
|
132
|
+
latitude=center.y,
|
|
133
|
+
xr_data=da,
|
|
134
|
+
)
|
|
124
135
|
else:
|
|
125
|
-
x,y = center.x, center.y
|
|
136
|
+
x, y = center.x, center.y
|
|
126
137
|
center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")
|
|
127
138
|
|
|
128
139
|
# Check that the requested point lies within the data
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
140
|
+
if not (da[x_dim].min() < x < da[x_dim].max()):
|
|
141
|
+
raise ValueError(
|
|
142
|
+
f"{x} is not in the interval {da[x_dim].min().values}: {da[x_dim].max().values}",
|
|
143
|
+
)
|
|
144
|
+
if not (da[y_dim].min() < y < da[y_dim].max()):
|
|
145
|
+
raise ValueError(
|
|
146
|
+
f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}",
|
|
147
|
+
)
|
|
133
148
|
|
|
134
149
|
# Get the index into x and y nearest to x_center_geostationary and y_center_geostationary:
|
|
135
150
|
x_index_at_center = np.searchsorted(da[x_dim].values, center_geostationary.x)
|
|
@@ -142,24 +157,25 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
|
|
|
142
157
|
|
|
143
158
|
|
|
144
159
|
def _select_partial_spatial_slice_pixels(
|
|
145
|
-
da,
|
|
146
|
-
left_idx,
|
|
147
|
-
right_idx,
|
|
148
|
-
bottom_idx,
|
|
149
|
-
top_idx,
|
|
150
|
-
left_pad_pixels,
|
|
151
|
-
right_pad_pixels,
|
|
152
|
-
bottom_pad_pixels,
|
|
153
|
-
top_pad_pixels,
|
|
154
|
-
x_dim,
|
|
155
|
-
y_dim,
|
|
156
|
-
):
|
|
157
|
-
"""Return spatial window of given pixel size when window partially overlaps input data"""
|
|
158
|
-
|
|
159
|
-
# We should never be padding on both sides of a window. This would mean our desired window is
|
|
160
|
+
da: xr.DataArray,
|
|
161
|
+
left_idx: int,
|
|
162
|
+
right_idx: int,
|
|
163
|
+
bottom_idx: int,
|
|
164
|
+
top_idx: int,
|
|
165
|
+
left_pad_pixels: int,
|
|
166
|
+
right_pad_pixels: int,
|
|
167
|
+
bottom_pad_pixels: int,
|
|
168
|
+
top_pad_pixels: int,
|
|
169
|
+
x_dim: str,
|
|
170
|
+
y_dim: str,
|
|
171
|
+
) -> xr.DataArray:
|
|
172
|
+
"""Return spatial window of given pixel size when window partially overlaps input data."""
|
|
173
|
+
# We should never be padding on both sides of a window. This would mean our desired window is
|
|
160
174
|
# larger than the size of the input data
|
|
161
|
-
|
|
162
|
-
|
|
175
|
+
if (left_pad_pixels != 0 and right_pad_pixels != 0) or (
|
|
176
|
+
bottom_pad_pixels != 0 and top_pad_pixels != 0
|
|
177
|
+
):
|
|
178
|
+
raise ValueError("Cannot pad both sides of the window")
|
|
163
179
|
|
|
164
180
|
dx = np.median(np.diff(da[x_dim].values))
|
|
165
181
|
dy = np.median(np.diff(da[y_dim].values))
|
|
@@ -170,7 +186,7 @@ def _select_partial_spatial_slice_pixels(
|
|
|
170
186
|
[
|
|
171
187
|
da[x_dim].values[0] + np.arange(-left_pad_pixels, 0) * dx,
|
|
172
188
|
da[x_dim].values[0:right_idx],
|
|
173
|
-
]
|
|
189
|
+
],
|
|
174
190
|
)
|
|
175
191
|
da = da.isel({x_dim: slice(0, right_idx)}).reindex({x_dim: x_sel})
|
|
176
192
|
|
|
@@ -180,7 +196,7 @@ def _select_partial_spatial_slice_pixels(
|
|
|
180
196
|
[
|
|
181
197
|
da[x_dim].values[left_idx:],
|
|
182
198
|
da[x_dim].values[-1] + np.arange(1, right_pad_pixels + 1) * dx,
|
|
183
|
-
]
|
|
199
|
+
],
|
|
184
200
|
)
|
|
185
201
|
da = da.isel({x_dim: slice(left_idx, None)}).reindex({x_dim: x_sel})
|
|
186
202
|
|
|
@@ -194,7 +210,7 @@ def _select_partial_spatial_slice_pixels(
|
|
|
194
210
|
[
|
|
195
211
|
da[y_dim].values[0] + np.arange(-bottom_pad_pixels, 0) * dy,
|
|
196
212
|
da[y_dim].values[0:top_idx],
|
|
197
|
-
]
|
|
213
|
+
],
|
|
198
214
|
)
|
|
199
215
|
da = da.isel({y_dim: slice(0, top_idx)}).reindex({y_dim: y_sel})
|
|
200
216
|
|
|
@@ -204,7 +220,7 @@ def _select_partial_spatial_slice_pixels(
|
|
|
204
220
|
[
|
|
205
221
|
da[y_dim].values[bottom_idx:],
|
|
206
222
|
da[y_dim].values[-1] + np.arange(1, top_pad_pixels + 1) * dy,
|
|
207
|
-
]
|
|
223
|
+
],
|
|
208
224
|
)
|
|
209
225
|
da = da.isel({y_dim: slice(left_idx, None)}).reindex({y_dim: y_sel})
|
|
210
226
|
|
|
@@ -216,15 +232,15 @@ def _select_partial_spatial_slice_pixels(
|
|
|
216
232
|
|
|
217
233
|
|
|
218
234
|
def _select_spatial_slice_pixels(
|
|
219
|
-
da: xr.DataArray,
|
|
220
|
-
center_idx: Location,
|
|
221
|
-
width_pixels: int,
|
|
222
|
-
height_pixels: int,
|
|
223
|
-
x_dim: str,
|
|
224
|
-
y_dim: str,
|
|
235
|
+
da: xr.DataArray,
|
|
236
|
+
center_idx: Location,
|
|
237
|
+
width_pixels: int,
|
|
238
|
+
height_pixels: int,
|
|
239
|
+
x_dim: str,
|
|
240
|
+
y_dim: str,
|
|
225
241
|
allow_partial_slice: bool,
|
|
226
|
-
):
|
|
227
|
-
"""Select a spatial slice from an xarray object
|
|
242
|
+
) -> xr.DataArray:
|
|
243
|
+
"""Select a spatial slice from an xarray object.
|
|
228
244
|
|
|
229
245
|
Args:
|
|
230
246
|
da: xarray DataArray to slice from
|
|
@@ -235,11 +251,13 @@ def _select_spatial_slice_pixels(
|
|
|
235
251
|
y_dim: Name of the y-dimension in `da`
|
|
236
252
|
allow_partial_slice: Whether to allow a partially filled window
|
|
237
253
|
"""
|
|
238
|
-
|
|
239
|
-
|
|
254
|
+
if center_idx.coordinate_system != "idx":
|
|
255
|
+
raise ValueError(f"Expected center_idx to be in 'idx' coordinates, got '{center_idx}'")
|
|
240
256
|
# TODO: It shouldn't take much effort to allow height and width to be odd
|
|
241
|
-
|
|
242
|
-
|
|
257
|
+
if (width_pixels % 2) != 0:
|
|
258
|
+
raise ValueError("Width must be an even number")
|
|
259
|
+
if (height_pixels % 2) != 0:
|
|
260
|
+
raise ValueError("Height must be an even number")
|
|
243
261
|
|
|
244
262
|
half_width = width_pixels // 2
|
|
245
263
|
half_height = height_pixels // 2
|
|
@@ -261,14 +279,12 @@ def _select_spatial_slice_pixels(
|
|
|
261
279
|
|
|
262
280
|
if pad_required:
|
|
263
281
|
if allow_partial_slice:
|
|
264
|
-
|
|
265
282
|
left_pad_pixels = (-left_idx) if left_pad_required else 0
|
|
266
283
|
right_pad_pixels = (right_idx - data_width_pixels) if right_pad_required else 0
|
|
267
284
|
|
|
268
285
|
bottom_pad_pixels = (-bottom_idx) if bottom_pad_required else 0
|
|
269
286
|
top_pad_pixels = (top_idx - data_height_pixels) if top_pad_required else 0
|
|
270
287
|
|
|
271
|
-
|
|
272
288
|
da = _select_partial_spatial_slice_pixels(
|
|
273
289
|
da,
|
|
274
290
|
left_idx,
|
|
@@ -287,7 +303,7 @@ def _select_spatial_slice_pixels(
|
|
|
287
303
|
f"Window for location {center_idx} not available. Missing (left, right, bottom, "
|
|
288
304
|
f"top) pixels = ({left_pad_required}, {right_pad_required}, "
|
|
289
305
|
f"{bottom_pad_required}, {top_pad_required}). "
|
|
290
|
-
f"You may wish to set `allow_partial_slice=True`"
|
|
306
|
+
f"You may wish to set `allow_partial_slice=True`",
|
|
291
307
|
)
|
|
292
308
|
|
|
293
309
|
else:
|
|
@@ -295,17 +311,19 @@ def _select_spatial_slice_pixels(
|
|
|
295
311
|
{
|
|
296
312
|
x_dim: slice(left_idx, right_idx),
|
|
297
313
|
y_dim: slice(bottom_idx, top_idx),
|
|
298
|
-
}
|
|
314
|
+
},
|
|
299
315
|
)
|
|
300
316
|
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
317
|
+
if len(da[x_dim]) != width_pixels:
|
|
318
|
+
raise ValueError(
|
|
319
|
+
f"Expected x-dim len {width_pixels} got {len(da[x_dim])} "
|
|
320
|
+
f"for location {center_idx} for slice {left_idx}:{right_idx}",
|
|
321
|
+
)
|
|
322
|
+
if len(da[y_dim]) != height_pixels:
|
|
323
|
+
raise ValueError(
|
|
324
|
+
f"Expected y-dim len {height_pixels} got {len(da[y_dim])} "
|
|
325
|
+
f"for location {center_idx} for slice {bottom_idx}:{top_idx}",
|
|
326
|
+
)
|
|
309
327
|
|
|
310
328
|
return da
|
|
311
329
|
|
|
@@ -319,9 +337,8 @@ def select_spatial_slice_pixels(
|
|
|
319
337
|
width_pixels: int,
|
|
320
338
|
height_pixels: int,
|
|
321
339
|
allow_partial_slice: bool = False,
|
|
322
|
-
):
|
|
323
|
-
"""
|
|
324
|
-
Select spatial slice based off pixels from location point of interest
|
|
340
|
+
) -> xr.DataArray:
|
|
341
|
+
"""Select spatial slice based off pixels from location point of interest.
|
|
325
342
|
|
|
326
343
|
If `allow_partial_slice` is set to True, then slices may be made which intersect the border
|
|
327
344
|
of the input data. The additional x and y cordinates that would be required for this slice
|
|
@@ -336,7 +353,6 @@ def select_spatial_slice_pixels(
|
|
|
336
353
|
width_pixels: Width of the slice in pixels
|
|
337
354
|
allow_partial_slice: Whether to allow a partial slice.
|
|
338
355
|
"""
|
|
339
|
-
|
|
340
356
|
xr_coords, x_dim, y_dim = spatial_coord_type(da)
|
|
341
357
|
|
|
342
358
|
if xr_coords == "geostationary":
|
|
@@ -354,4 +370,4 @@ def select_spatial_slice_pixels(
|
|
|
354
370
|
allow_partial_slice=allow_partial_slice,
|
|
355
371
|
)
|
|
356
372
|
|
|
357
|
-
return selected
|
|
373
|
+
return selected
|
|
@@ -1,57 +1,80 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
"""Select a time slice from a Dataset or DataArray."""
|
|
2
|
+
|
|
3
3
|
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import xarray as xr
|
|
6
|
+
|
|
4
7
|
|
|
5
8
|
def select_time_slice(
|
|
6
|
-
|
|
9
|
+
da: xr.DataArray,
|
|
7
10
|
t0: pd.Timestamp,
|
|
8
11
|
interval_start: pd.Timedelta,
|
|
9
12
|
interval_end: pd.Timedelta,
|
|
10
|
-
|
|
11
|
-
):
|
|
12
|
-
"""Select a time slice from a
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
13
|
+
time_resolution: pd.Timedelta,
|
|
14
|
+
) -> xr.DataArray:
|
|
15
|
+
"""Select a time slice from a DataArray.
|
|
16
|
+
|
|
17
|
+
Args:
|
|
18
|
+
da: The DataArray to slice from
|
|
19
|
+
t0: The init-time
|
|
20
|
+
interval_start: The start of the interval with respect to t0
|
|
21
|
+
interval_end: The end of the interval with respect to t0
|
|
22
|
+
time_resolution: Distance between neighbouring timestamps
|
|
23
|
+
"""
|
|
24
|
+
start_dt = t0 + interval_start
|
|
25
|
+
end_dt = t0 + interval_end
|
|
16
26
|
|
|
17
|
-
start_dt = start_dt.ceil(
|
|
18
|
-
end_dt = end_dt.ceil(
|
|
27
|
+
start_dt = start_dt.ceil(time_resolution)
|
|
28
|
+
end_dt = end_dt.ceil(time_resolution)
|
|
29
|
+
|
|
30
|
+
return da.sel(time_utc=slice(start_dt, end_dt))
|
|
19
31
|
|
|
20
|
-
return ds.sel(time_utc=slice(start_dt, end_dt))
|
|
21
32
|
|
|
22
33
|
def select_time_slice_nwp(
|
|
23
34
|
da: xr.DataArray,
|
|
24
35
|
t0: pd.Timestamp,
|
|
25
36
|
interval_start: pd.Timedelta,
|
|
26
37
|
interval_end: pd.Timedelta,
|
|
27
|
-
|
|
38
|
+
time_resolution: pd.Timedelta,
|
|
28
39
|
dropout_timedeltas: list[pd.Timedelta] | None = None,
|
|
29
40
|
dropout_frac: float | None = 0,
|
|
30
|
-
accum_channels: list[str] =
|
|
31
|
-
|
|
32
|
-
|
|
41
|
+
accum_channels: list[str] | None = None,
|
|
42
|
+
) -> xr.DataArray:
|
|
43
|
+
"""Select a time slice from an NWP DataArray.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
da: The DataArray to slice from
|
|
47
|
+
t0: The init-time
|
|
48
|
+
interval_start: The start of the interval with respect to t0
|
|
49
|
+
interval_end: The end of the interval with respect to t0
|
|
50
|
+
time_resolution: Distance between neighbouring timestamps
|
|
51
|
+
dropout_timedeltas: List of possible timedeltas before t0 where data availability may start
|
|
52
|
+
dropout_frac: Probability to apply dropout
|
|
53
|
+
accum_channels: Channels which are accumulated and need to be differenced
|
|
54
|
+
"""
|
|
55
|
+
if accum_channels is None:
|
|
56
|
+
accum_channels = []
|
|
57
|
+
|
|
33
58
|
if dropout_timedeltas is not None:
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
)
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
59
|
+
if not all(t < pd.Timedelta(0) for t in dropout_timedeltas):
|
|
60
|
+
raise ValueError("dropout timedeltas must be negative")
|
|
61
|
+
if len(dropout_timedeltas) < 1:
|
|
62
|
+
raise ValueError("dropout timedeltas must have at least one element")
|
|
63
|
+
|
|
64
|
+
if not (0 <= dropout_frac <= 1):
|
|
65
|
+
raise ValueError("dropout_frac must be between 0 and 1")
|
|
40
66
|
|
|
41
|
-
|
|
42
|
-
accum_channels = np.intersect1d(
|
|
43
|
-
da[channel_dim_name].values, accum_channels
|
|
44
|
-
)
|
|
45
|
-
non_accum_channels = np.setdiff1d(
|
|
46
|
-
da[channel_dim_name].values, accum_channels
|
|
47
|
-
)
|
|
67
|
+
consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
|
|
48
68
|
|
|
49
|
-
|
|
50
|
-
|
|
69
|
+
# The accumatated and non-accumulated channels
|
|
70
|
+
accum_channels = np.intersect1d(da.channel.values, accum_channels)
|
|
71
|
+
non_accum_channels = np.setdiff1d(da.channel.values, accum_channels)
|
|
51
72
|
|
|
52
|
-
|
|
73
|
+
start_dt = (t0 + interval_start).ceil(time_resolution)
|
|
74
|
+
end_dt = (t0 + interval_end).ceil(time_resolution)
|
|
75
|
+
target_times = pd.date_range(start_dt, end_dt, freq=time_resolution)
|
|
53
76
|
|
|
54
|
-
#
|
|
77
|
+
# Potentially apply NWP dropout
|
|
55
78
|
if consider_dropout and (np.random.uniform() < dropout_frac):
|
|
56
79
|
dt = np.random.choice(dropout_timedeltas)
|
|
57
80
|
t0_available = t0 + dt
|
|
@@ -59,9 +82,7 @@ def select_time_slice_nwp(
|
|
|
59
82
|
t0_available = t0
|
|
60
83
|
|
|
61
84
|
# Forecasts made up to and including t0
|
|
62
|
-
available_init_times = da.init_time_utc.sel(
|
|
63
|
-
init_time_utc=slice(None, t0_available)
|
|
64
|
-
)
|
|
85
|
+
available_init_times = da.init_time_utc.sel(init_time_utc=slice(None, t0_available))
|
|
65
86
|
|
|
66
87
|
# Find the most recent available init times for all target times
|
|
67
88
|
selected_init_times = available_init_times.sel(
|
|
@@ -74,10 +95,10 @@ def select_time_slice_nwp(
|
|
|
74
95
|
|
|
75
96
|
# We want one timestep for each target_time_hourly (obviously!) If we simply do
|
|
76
97
|
# nwp.sel(init_time=init_times, step=steps) then we'll get the *product* of
|
|
77
|
-
# init_times and steps, which is not what
|
|
78
|
-
#
|
|
98
|
+
# init_times and steps, which is not what we want! Instead, we use xarray's
|
|
99
|
+
# vectorised-indexing mode via using a DataArray indexer. See the last example here:
|
|
79
100
|
# https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing
|
|
80
|
-
|
|
101
|
+
|
|
81
102
|
coords = {"target_time_utc": target_times}
|
|
82
103
|
init_time_indexer = xr.DataArray(selected_init_times, coords=coords)
|
|
83
104
|
step_indexer = xr.DataArray(steps, coords=coords)
|
|
@@ -90,38 +111,30 @@ def select_time_slice_nwp(
|
|
|
90
111
|
unique_init_times = np.unique(selected_init_times)
|
|
91
112
|
# - find the min and max steps we slice over. Max is extended due to diff
|
|
92
113
|
min_step = min(steps)
|
|
93
|
-
max_step = max(steps) +
|
|
114
|
+
max_step = max(steps) + time_resolution
|
|
94
115
|
|
|
95
|
-
da_min = da.sel(
|
|
96
|
-
{
|
|
97
|
-
"init_time_utc": unique_init_times,
|
|
98
|
-
"step": slice(min_step, max_step),
|
|
99
|
-
}
|
|
100
|
-
)
|
|
116
|
+
da_min = da.sel(init_time_utc=unique_init_times, step=slice(min_step, max_step))
|
|
101
117
|
|
|
102
118
|
# Slice out the data which does not need to be diffed
|
|
103
|
-
da_non_accum = da_min.sel(
|
|
104
|
-
da_sel_non_accum = da_non_accum.sel(
|
|
105
|
-
step=step_indexer, init_time_utc=init_time_indexer
|
|
106
|
-
)
|
|
119
|
+
da_non_accum = da_min.sel(channel=non_accum_channels)
|
|
120
|
+
da_sel_non_accum = da_non_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
|
|
107
121
|
|
|
108
122
|
# Slice out the channels which need to be diffed
|
|
109
|
-
da_accum = da_min.sel(
|
|
110
|
-
|
|
123
|
+
da_accum = da_min.sel(channel=accum_channels)
|
|
124
|
+
|
|
111
125
|
# Take the diff and slice requested data
|
|
112
126
|
da_accum = da_accum.diff(dim="step", label="lower")
|
|
113
127
|
da_sel_accum = da_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
|
|
114
128
|
|
|
115
129
|
# Join diffed and non-diffed variables
|
|
116
|
-
da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim=
|
|
117
|
-
|
|
130
|
+
da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim="channel")
|
|
131
|
+
|
|
118
132
|
# Reorder the variable back to the original order
|
|
119
|
-
da_sel = da_sel.sel(
|
|
133
|
+
da_sel = da_sel.sel(channel=da.channel.values)
|
|
120
134
|
|
|
121
135
|
# Rename the diffed channels
|
|
122
|
-
da_sel[
|
|
123
|
-
f"diff_{v}" if v in accum_channels else v
|
|
124
|
-
for v in da_sel[channel_dim_name].values
|
|
136
|
+
da_sel["channel"] = [
|
|
137
|
+
f"diff_{v}" if v in accum_channels else v for v in da_sel.channel.values
|
|
125
138
|
]
|
|
126
139
|
|
|
127
140
|
return da_sel
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Functions for selecting data around a given location."""
|
|
2
|
+
|
|
2
3
|
from ocf_data_sampler.config import Configuration
|
|
3
4
|
from ocf_data_sampler.select.location import Location
|
|
4
5
|
from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
|
|
@@ -9,24 +10,24 @@ def slice_datasets_by_space(
|
|
|
9
10
|
location: Location,
|
|
10
11
|
config: Configuration,
|
|
11
12
|
) -> dict:
|
|
12
|
-
"""Slice the dictionary of input data sources around a given location
|
|
13
|
+
"""Slice the dictionary of input data sources around a given location.
|
|
13
14
|
|
|
14
15
|
Args:
|
|
15
16
|
datasets_dict: Dictionary of the input data sources
|
|
16
17
|
location: The location to sample around
|
|
17
18
|
config: Configuration object.
|
|
18
19
|
"""
|
|
19
|
-
|
|
20
|
-
|
|
20
|
+
if not set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp", "site"}):
|
|
21
|
+
raise ValueError(
|
|
22
|
+
"'datasets_dict' should only contain keys 'nwp', 'sat', 'gsp', 'site'",
|
|
23
|
+
)
|
|
21
24
|
|
|
22
25
|
sliced_datasets_dict = {}
|
|
23
26
|
|
|
24
27
|
if "nwp" in datasets_dict:
|
|
25
|
-
|
|
26
28
|
sliced_datasets_dict["nwp"] = {}
|
|
27
29
|
|
|
28
30
|
for nwp_key, nwp_config in config.input_data.nwp.items():
|
|
29
|
-
|
|
30
31
|
sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
|
|
31
32
|
datasets_dict["nwp"][nwp_key],
|
|
32
33
|
location,
|