ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (77)
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +86 -72
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/constants.py +140 -12
  5. ocf_data_sampler/load/gsp.py +6 -5
  6. ocf_data_sampler/load/load_dataset.py +5 -6
  7. ocf_data_sampler/load/nwp/nwp.py +17 -5
  8. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  9. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  10. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  11. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  12. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  13. ocf_data_sampler/load/satellite.py +27 -36
  14. ocf_data_sampler/load/site.py +11 -7
  15. ocf_data_sampler/load/utils.py +21 -16
  16. ocf_data_sampler/numpy_sample/collate.py +10 -9
  17. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  18. ocf_data_sampler/numpy_sample/gsp.py +15 -13
  19. ocf_data_sampler/numpy_sample/nwp.py +17 -23
  20. ocf_data_sampler/numpy_sample/satellite.py +17 -14
  21. ocf_data_sampler/numpy_sample/site.py +8 -7
  22. ocf_data_sampler/numpy_sample/sun_position.py +19 -25
  23. ocf_data_sampler/sample/__init__.py +0 -7
  24. ocf_data_sampler/sample/base.py +23 -44
  25. ocf_data_sampler/sample/site.py +25 -69
  26. ocf_data_sampler/sample/uk_regional.py +52 -103
  27. ocf_data_sampler/select/dropout.py +42 -27
  28. ocf_data_sampler/select/fill_time_periods.py +15 -3
  29. ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
  30. ocf_data_sampler/select/geospatial.py +63 -54
  31. ocf_data_sampler/select/location.py +16 -51
  32. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  33. ocf_data_sampler/select/select_time_slice.py +71 -58
  34. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  35. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  36. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
  37. ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
  41. ocf_data_sampler/utils.py +3 -1
  42. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
  43. ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
  44. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
  45. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
  46. scripts/refactor_site.py +62 -33
  47. utils/compute_icon_mean_stddev.py +72 -0
  48. ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
  49. ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
  50. tests/__init__.py +0 -0
  51. tests/config/test_config.py +0 -113
  52. tests/config/test_load.py +0 -7
  53. tests/config/test_save.py +0 -28
  54. tests/conftest.py +0 -286
  55. tests/load/test_load_gsp.py +0 -15
  56. tests/load/test_load_nwp.py +0 -21
  57. tests/load/test_load_satellite.py +0 -17
  58. tests/load/test_load_sites.py +0 -14
  59. tests/numpy_sample/test_collate.py +0 -21
  60. tests/numpy_sample/test_datetime_features.py +0 -37
  61. tests/numpy_sample/test_gsp.py +0 -38
  62. tests/numpy_sample/test_nwp.py +0 -52
  63. tests/numpy_sample/test_satellite.py +0 -40
  64. tests/numpy_sample/test_sun_position.py +0 -81
  65. tests/select/test_dropout.py +0 -75
  66. tests/select/test_fill_time_periods.py +0 -28
  67. tests/select/test_find_contiguous_time_periods.py +0 -202
  68. tests/select/test_location.py +0 -67
  69. tests/select/test_select_spatial_slice.py +0 -154
  70. tests/select/test_select_time_slice.py +0 -275
  71. tests/test_sample/test_base.py +0 -164
  72. tests/test_sample/test_site_sample.py +0 -195
  73. tests/test_sample/test_uk_regional_sample.py +0 -163
  74. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  75. tests/torch_datasets/test_pvnet_uk.py +0 -167
  76. tests/torch_datasets/test_site.py +0 -226
  77. tests/torch_datasets/test_validate_channels_utils.py +0 -78
@@ -1,19 +1,18 @@
1
- """Select spatial slices"""
1
+ """Select spatial slices."""
2
2
 
3
3
  import logging
4
4
 
5
5
  import numpy as np
6
6
  import xarray as xr
7
7
 
8
- from ocf_data_sampler.select.location import Location
9
8
  from ocf_data_sampler.select.geospatial import (
10
- lon_lat_to_osgb,
11
9
  lon_lat_to_geostationary_area_coords,
10
+ lon_lat_to_osgb,
12
11
  osgb_to_geostationary_area_coords,
13
12
  osgb_to_lon_lat,
14
13
  spatial_coord_type,
15
14
  )
16
-
15
+ from ocf_data_sampler.select.location import Location
17
16
 
18
17
  logger = logging.getLogger(__name__)
19
18
 
@@ -22,12 +21,12 @@ logger = logging.getLogger(__name__)
22
21
 
23
22
 
24
23
  def convert_coords_to_match_xarray(
25
- x: float | np.ndarray,
26
- y: float | np.ndarray,
27
- from_coords: str,
28
- da: xr.DataArray
29
- ):
30
- """Convert x and y coords to cooridnate system matching xarray data
24
+ x: float | np.ndarray,
25
+ y: float | np.ndarray,
26
+ from_coords: str,
27
+ da: xr.DataArray,
28
+ ) -> tuple[float | np.ndarray, float | np.ndarray]:
29
+ """Convert x and y coords to coordinate system matching xarray data.
31
30
 
32
31
  Args:
33
32
  x: Float or array-like
@@ -35,38 +34,42 @@ def convert_coords_to_match_xarray(
35
34
  from_coords: String describing coordinate system of x and y
36
35
  da: DataArray to which coordinates should be matched
37
36
  """
38
-
39
37
  target_coords, *_ = spatial_coord_type(da)
40
38
 
41
- assert from_coords in ["osgb", "lon_lat"]
42
- assert target_coords in ["geostationary", "osgb", "lon_lat"]
43
-
44
- if target_coords == "geostationary":
45
- if from_coords == "osgb":
39
+ match (from_coords, target_coords):
40
+ case ("osgb", "geostationary"):
46
41
  x, y = osgb_to_geostationary_area_coords(x, y, da)
47
42
 
48
- elif target_coords == "lon_lat":
49
- if from_coords == "osgb":
43
+ case ("osgb", "lon_lat"):
50
44
  x, y = osgb_to_lon_lat(x, y)
51
45
 
52
- # else the from_coords=="lon_lat" and we don't need to convert
46
+ case ("osgb", "osgb"):
47
+ pass
53
48
 
54
- elif target_coords == "osgb":
55
- if from_coords == "lon_lat":
49
+ case ("lon_lat", "osgb"):
56
50
  x, y = lon_lat_to_osgb(x, y)
57
51
 
58
- # else the from_coords=="osgb" and we don't need to convert
52
+ case ("lon_lat", "geostationary"):
53
+ x, y = lon_lat_to_geostationary_area_coords(x, y, da)
54
+
55
+ case ("lon_lat", "lon_lat"):
56
+ pass
57
+
58
+ case (_, _):
59
+ raise NotImplementedError(
60
+ f"Conversion from {from_coords} to {target_coords} is not supported",
61
+ )
59
62
 
60
63
  return x, y
61
64
 
62
- # TODO: This function and _get_idx_of_pixel_closest_to_poi_geostationary() should not be separate
65
+
66
+ # TODO: This function and _get_idx_of_pixel_closest_to_poi_geostationary() should not be separate
63
67
  # We should combine them, and consider making a Coord class to help with this
64
68
  def _get_idx_of_pixel_closest_to_poi(
65
69
  da: xr.DataArray,
66
70
  location: Location,
67
71
  ) -> Location:
68
- """
69
- Return x and y index location of pixel at center of region of interest.
72
+ """Return x and y index location of pixel at center of region of interest.
70
73
 
71
74
  Args:
72
75
  da: xarray DataArray
@@ -88,8 +91,14 @@ def _get_idx_of_pixel_closest_to_poi(
88
91
  )
89
92
 
90
93
  # Check that the requested point lies within the data
91
- assert da[x_dim].min() < x < da[x_dim].max()
92
- assert da[y_dim].min() < y < da[y_dim].max()
94
+ if not (da[x_dim].min() < x < da[x_dim].max()):
95
+ raise ValueError(
96
+ f"{x} is not in the interval {da[x_dim].min().values}: {da[x_dim].max().values}",
97
+ )
98
+ if not (da[y_dim].min() < y < da[y_dim].max()):
99
+ raise ValueError(
100
+ f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}",
101
+ )
93
102
 
94
103
  x_index = da.get_index(x_dim)
95
104
  y_index = da.get_index(y_dim)
@@ -104,32 +113,38 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
104
113
  da: xr.DataArray,
105
114
  center: Location,
106
115
  ) -> Location:
107
- """
108
- Return x and y index location of pixel at center of region of interest.
116
+ """Return x and y index location of pixel at center of region of interest.
109
117
 
110
118
  Args:
111
119
  da: xarray DataArray
112
- center_osgb: Center in OSGB coordinates
120
+ center: Center in OSGB coordinates
113
121
 
114
122
  Returns:
115
123
  Location for the center pixel in geostationary coordinates
116
124
  """
117
-
118
125
  _, x_dim, y_dim = spatial_coord_type(da)
119
126
 
120
- if center.coordinate_system == 'osgb':
127
+ if center.coordinate_system == "osgb":
121
128
  x, y = osgb_to_geostationary_area_coords(x=center.x, y=center.y, xr_data=da)
122
- elif center.coordinate_system == 'lon_lat':
123
- x, y = lon_lat_to_geostationary_area_coords(longitude=center.x, latitude=center.y, xr_data=da)
129
+ elif center.coordinate_system == "lon_lat":
130
+ x, y = lon_lat_to_geostationary_area_coords(
131
+ longitude=center.x,
132
+ latitude=center.y,
133
+ xr_data=da,
134
+ )
124
135
  else:
125
- x,y = center.x, center.y
136
+ x, y = center.x, center.y
126
137
  center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")
127
138
 
128
139
  # Check that the requested point lies within the data
129
- assert da[x_dim].min() < x < da[x_dim].max(), \
130
- f"{x} is not in the interval {da[x_dim].min().values}: {da[x_dim].max().values}"
131
- assert da[y_dim].min() < y < da[y_dim].max(), \
132
- f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}"
140
+ if not (da[x_dim].min() < x < da[x_dim].max()):
141
+ raise ValueError(
142
+ f"{x} is not in the interval {da[x_dim].min().values}: {da[x_dim].max().values}",
143
+ )
144
+ if not (da[y_dim].min() < y < da[y_dim].max()):
145
+ raise ValueError(
146
+ f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}",
147
+ )
133
148
 
134
149
  # Get the index into x and y nearest to x_center_geostationary and y_center_geostationary:
135
150
  x_index_at_center = np.searchsorted(da[x_dim].values, center_geostationary.x)
@@ -142,24 +157,25 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
142
157
 
143
158
 
144
159
  def _select_partial_spatial_slice_pixels(
145
- da,
146
- left_idx,
147
- right_idx,
148
- bottom_idx,
149
- top_idx,
150
- left_pad_pixels,
151
- right_pad_pixels,
152
- bottom_pad_pixels,
153
- top_pad_pixels,
154
- x_dim,
155
- y_dim,
156
- ):
157
- """Return spatial window of given pixel size when window partially overlaps input data"""
158
-
159
- # We should never be padding on both sides of a window. This would mean our desired window is
160
+ da: xr.DataArray,
161
+ left_idx: int,
162
+ right_idx: int,
163
+ bottom_idx: int,
164
+ top_idx: int,
165
+ left_pad_pixels: int,
166
+ right_pad_pixels: int,
167
+ bottom_pad_pixels: int,
168
+ top_pad_pixels: int,
169
+ x_dim: str,
170
+ y_dim: str,
171
+ ) -> xr.DataArray:
172
+ """Return spatial window of given pixel size when window partially overlaps input data."""
173
+ # We should never be padding on both sides of a window. This would mean our desired window is
160
174
  # larger than the size of the input data
161
- assert left_pad_pixels==0 or right_pad_pixels==0
162
- assert bottom_pad_pixels==0 or top_pad_pixels==0
175
+ if (left_pad_pixels != 0 and right_pad_pixels != 0) or (
176
+ bottom_pad_pixels != 0 and top_pad_pixels != 0
177
+ ):
178
+ raise ValueError("Cannot pad both sides of the window")
163
179
 
164
180
  dx = np.median(np.diff(da[x_dim].values))
165
181
  dy = np.median(np.diff(da[y_dim].values))
@@ -170,7 +186,7 @@ def _select_partial_spatial_slice_pixels(
170
186
  [
171
187
  da[x_dim].values[0] + np.arange(-left_pad_pixels, 0) * dx,
172
188
  da[x_dim].values[0:right_idx],
173
- ]
189
+ ],
174
190
  )
175
191
  da = da.isel({x_dim: slice(0, right_idx)}).reindex({x_dim: x_sel})
176
192
 
@@ -180,7 +196,7 @@ def _select_partial_spatial_slice_pixels(
180
196
  [
181
197
  da[x_dim].values[left_idx:],
182
198
  da[x_dim].values[-1] + np.arange(1, right_pad_pixels + 1) * dx,
183
- ]
199
+ ],
184
200
  )
185
201
  da = da.isel({x_dim: slice(left_idx, None)}).reindex({x_dim: x_sel})
186
202
 
@@ -194,7 +210,7 @@ def _select_partial_spatial_slice_pixels(
194
210
  [
195
211
  da[y_dim].values[0] + np.arange(-bottom_pad_pixels, 0) * dy,
196
212
  da[y_dim].values[0:top_idx],
197
- ]
213
+ ],
198
214
  )
199
215
  da = da.isel({y_dim: slice(0, top_idx)}).reindex({y_dim: y_sel})
200
216
 
@@ -204,7 +220,7 @@ def _select_partial_spatial_slice_pixels(
204
220
  [
205
221
  da[y_dim].values[bottom_idx:],
206
222
  da[y_dim].values[-1] + np.arange(1, top_pad_pixels + 1) * dy,
207
- ]
223
+ ],
208
224
  )
209
225
  da = da.isel({y_dim: slice(left_idx, None)}).reindex({y_dim: y_sel})
210
226
 
@@ -216,15 +232,15 @@ def _select_partial_spatial_slice_pixels(
216
232
 
217
233
 
218
234
  def _select_spatial_slice_pixels(
219
- da: xr.DataArray,
220
- center_idx: Location,
221
- width_pixels: int,
222
- height_pixels: int,
223
- x_dim: str,
224
- y_dim: str,
235
+ da: xr.DataArray,
236
+ center_idx: Location,
237
+ width_pixels: int,
238
+ height_pixels: int,
239
+ x_dim: str,
240
+ y_dim: str,
225
241
  allow_partial_slice: bool,
226
- ):
227
- """Select a spatial slice from an xarray object
242
+ ) -> xr.DataArray:
243
+ """Select a spatial slice from an xarray object.
228
244
 
229
245
  Args:
230
246
  da: xarray DataArray to slice from
@@ -235,11 +251,13 @@ def _select_spatial_slice_pixels(
235
251
  y_dim: Name of the y-dimension in `da`
236
252
  allow_partial_slice: Whether to allow a partially filled window
237
253
  """
238
-
239
- assert center_idx.coordinate_system == "idx"
254
+ if center_idx.coordinate_system != "idx":
255
+ raise ValueError(f"Expected center_idx to be in 'idx' coordinates, got '{center_idx}'")
240
256
  # TODO: It shouldn't take much effort to allow height and width to be odd
241
- assert (width_pixels % 2)==0, "Width must be an even number"
242
- assert (height_pixels % 2)==0, "Height must be an even number"
257
+ if (width_pixels % 2) != 0:
258
+ raise ValueError("Width must be an even number")
259
+ if (height_pixels % 2) != 0:
260
+ raise ValueError("Height must be an even number")
243
261
 
244
262
  half_width = width_pixels // 2
245
263
  half_height = height_pixels // 2
@@ -261,14 +279,12 @@ def _select_spatial_slice_pixels(
261
279
 
262
280
  if pad_required:
263
281
  if allow_partial_slice:
264
-
265
282
  left_pad_pixels = (-left_idx) if left_pad_required else 0
266
283
  right_pad_pixels = (right_idx - data_width_pixels) if right_pad_required else 0
267
284
 
268
285
  bottom_pad_pixels = (-bottom_idx) if bottom_pad_required else 0
269
286
  top_pad_pixels = (top_idx - data_height_pixels) if top_pad_required else 0
270
287
 
271
-
272
288
  da = _select_partial_spatial_slice_pixels(
273
289
  da,
274
290
  left_idx,
@@ -287,7 +303,7 @@ def _select_spatial_slice_pixels(
287
303
  f"Window for location {center_idx} not available. Missing (left, right, bottom, "
288
304
  f"top) pixels = ({left_pad_required}, {right_pad_required}, "
289
305
  f"{bottom_pad_required}, {top_pad_required}). "
290
- f"You may wish to set `allow_partial_slice=True`"
306
+ f"You may wish to set `allow_partial_slice=True`",
291
307
  )
292
308
 
293
309
  else:
@@ -295,17 +311,19 @@ def _select_spatial_slice_pixels(
295
311
  {
296
312
  x_dim: slice(left_idx, right_idx),
297
313
  y_dim: slice(bottom_idx, top_idx),
298
- }
314
+ },
299
315
  )
300
316
 
301
- assert len(da[x_dim]) == width_pixels, (
302
- f"Expected x-dim len {width_pixels} got {len(da[x_dim])} "
303
- f"for location {center_idx} for slice {left_idx}:{right_idx}"
304
- )
305
- assert len(da[y_dim]) == height_pixels, (
306
- f"Expected y-dim len {height_pixels} got {len(da[y_dim])} "
307
- f"for location {center_idx} for slice {bottom_idx}:{top_idx}"
308
- )
317
+ if len(da[x_dim]) != width_pixels:
318
+ raise ValueError(
319
+ f"Expected x-dim len {width_pixels} got {len(da[x_dim])} "
320
+ f"for location {center_idx} for slice {left_idx}:{right_idx}",
321
+ )
322
+ if len(da[y_dim]) != height_pixels:
323
+ raise ValueError(
324
+ f"Expected y-dim len {height_pixels} got {len(da[y_dim])} "
325
+ f"for location {center_idx} for slice {bottom_idx}:{top_idx}",
326
+ )
309
327
 
310
328
  return da
311
329
 
@@ -319,9 +337,8 @@ def select_spatial_slice_pixels(
319
337
  width_pixels: int,
320
338
  height_pixels: int,
321
339
  allow_partial_slice: bool = False,
322
- ):
323
- """
324
- Select spatial slice based off pixels from location point of interest
340
+ ) -> xr.DataArray:
341
+ """Select spatial slice based off pixels from location point of interest.
325
342
 
326
343
  If `allow_partial_slice` is set to True, then slices may be made which intersect the border
327
344
of the input data. The additional x and y coordinates that would be required for this slice
@@ -336,7 +353,6 @@ def select_spatial_slice_pixels(
336
353
  width_pixels: Width of the slice in pixels
337
354
  allow_partial_slice: Whether to allow a partial slice.
338
355
  """
339
-
340
356
  xr_coords, x_dim, y_dim = spatial_coord_type(da)
341
357
 
342
358
  if xr_coords == "geostationary":
@@ -354,4 +370,4 @@ def select_spatial_slice_pixels(
354
370
  allow_partial_slice=allow_partial_slice,
355
371
  )
356
372
 
357
- return selected
373
+ return selected
@@ -1,57 +1,80 @@
1
- import xarray as xr
2
- import pandas as pd
1
+ """Select a time slice from a Dataset or DataArray."""
2
+
3
3
  import numpy as np
4
+ import pandas as pd
5
+ import xarray as xr
6
+
4
7
 
5
8
  def select_time_slice(
6
- ds: xr.DataArray,
9
+ da: xr.DataArray,
7
10
  t0: pd.Timestamp,
8
11
  interval_start: pd.Timedelta,
9
12
  interval_end: pd.Timedelta,
10
- sample_period_duration: pd.Timedelta,
11
- ):
12
- """Select a time slice from a Dataset or DataArray."""
13
- t0_datetime_utc = pd.Timestamp(t0)
14
- start_dt = t0_datetime_utc + interval_start
15
- end_dt = t0_datetime_utc + interval_end
13
+ time_resolution: pd.Timedelta,
14
+ ) -> xr.DataArray:
15
+ """Select a time slice from a DataArray.
16
+
17
+ Args:
18
+ da: The DataArray to slice from
19
+ t0: The init-time
20
+ interval_start: The start of the interval with respect to t0
21
+ interval_end: The end of the interval with respect to t0
22
+ time_resolution: Distance between neighbouring timestamps
23
+ """
24
+ start_dt = t0 + interval_start
25
+ end_dt = t0 + interval_end
16
26
 
17
- start_dt = start_dt.ceil(sample_period_duration)
18
- end_dt = end_dt.ceil(sample_period_duration)
27
+ start_dt = start_dt.ceil(time_resolution)
28
+ end_dt = end_dt.ceil(time_resolution)
29
+
30
+ return da.sel(time_utc=slice(start_dt, end_dt))
19
31
 
20
- return ds.sel(time_utc=slice(start_dt, end_dt))
21
32
 
22
33
  def select_time_slice_nwp(
23
34
  da: xr.DataArray,
24
35
  t0: pd.Timestamp,
25
36
  interval_start: pd.Timedelta,
26
37
  interval_end: pd.Timedelta,
27
- sample_period_duration: pd.Timedelta,
38
+ time_resolution: pd.Timedelta,
28
39
  dropout_timedeltas: list[pd.Timedelta] | None = None,
29
40
  dropout_frac: float | None = 0,
30
- accum_channels: list[str] = [],
31
- channel_dim_name: str = "channel",
32
- ):
41
+ accum_channels: list[str] | None = None,
42
+ ) -> xr.DataArray:
43
+ """Select a time slice from an NWP DataArray.
44
+
45
+ Args:
46
+ da: The DataArray to slice from
47
+ t0: The init-time
48
+ interval_start: The start of the interval with respect to t0
49
+ interval_end: The end of the interval with respect to t0
50
+ time_resolution: Distance between neighbouring timestamps
51
+ dropout_timedeltas: List of possible timedeltas before t0 where data availability may start
52
+ dropout_frac: Probability to apply dropout
53
+ accum_channels: Channels which are accumulated and need to be differenced
54
+ """
55
+ if accum_channels is None:
56
+ accum_channels = []
57
+
33
58
  if dropout_timedeltas is not None:
34
- assert all(
35
- [t < pd.Timedelta(0) for t in dropout_timedeltas]
36
- ), "dropout timedeltas must be negative"
37
- assert len(dropout_timedeltas) >= 1
38
- assert 0 <= dropout_frac <= 1
39
- consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
59
+ if not all(t < pd.Timedelta(0) for t in dropout_timedeltas):
60
+ raise ValueError("dropout timedeltas must be negative")
61
+ if len(dropout_timedeltas) < 1:
62
+ raise ValueError("dropout timedeltas must have at least one element")
63
+
64
+ if not (0 <= dropout_frac <= 1):
65
+ raise ValueError("dropout_frac must be between 0 and 1")
40
66
 
41
- # The accumatation and non-accumulation channels
42
- accum_channels = np.intersect1d(
43
- da[channel_dim_name].values, accum_channels
44
- )
45
- non_accum_channels = np.setdiff1d(
46
- da[channel_dim_name].values, accum_channels
47
- )
67
+ consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
48
68
 
49
- start_dt = (t0 + interval_start).ceil(sample_period_duration)
50
- end_dt = (t0 + interval_end).ceil(sample_period_duration)
69
+ # The accumulated and non-accumulated channels
70
+ accum_channels = np.intersect1d(da.channel.values, accum_channels)
71
+ non_accum_channels = np.setdiff1d(da.channel.values, accum_channels)
51
72
 
52
- target_times = pd.date_range(start_dt, end_dt, freq=sample_period_duration)
73
+ start_dt = (t0 + interval_start).ceil(time_resolution)
74
+ end_dt = (t0 + interval_end).ceil(time_resolution)
75
+ target_times = pd.date_range(start_dt, end_dt, freq=time_resolution)
53
76
 
54
- # Maybe apply NWP dropout
77
+ # Potentially apply NWP dropout
55
78
  if consider_dropout and (np.random.uniform() < dropout_frac):
56
79
  dt = np.random.choice(dropout_timedeltas)
57
80
  t0_available = t0 + dt
@@ -59,9 +82,7 @@ def select_time_slice_nwp(
59
82
  t0_available = t0
60
83
 
61
84
  # Forecasts made up to and including t0
62
- available_init_times = da.init_time_utc.sel(
63
- init_time_utc=slice(None, t0_available)
64
- )
85
+ available_init_times = da.init_time_utc.sel(init_time_utc=slice(None, t0_available))
65
86
 
66
87
  # Find the most recent available init times for all target times
67
88
  selected_init_times = available_init_times.sel(
@@ -74,10 +95,10 @@ def select_time_slice_nwp(
74
95
 
75
96
  # We want one timestep for each target_time_hourly (obviously!) If we simply do
76
97
  # nwp.sel(init_time=init_times, step=steps) then we'll get the *product* of
77
- # init_times and steps, which is not what # we want! Instead, we use xarray's
78
- # vectorized-indexing mode by using a DataArray indexer. See the last example here:
98
+ # init_times and steps, which is not what we want! Instead, we use xarray's
99
+ # vectorised-indexing mode via using a DataArray indexer. See the last example here:
79
100
  # https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing
80
-
101
+
81
102
  coords = {"target_time_utc": target_times}
82
103
  init_time_indexer = xr.DataArray(selected_init_times, coords=coords)
83
104
  step_indexer = xr.DataArray(steps, coords=coords)
@@ -90,38 +111,30 @@ def select_time_slice_nwp(
90
111
  unique_init_times = np.unique(selected_init_times)
91
112
  # - find the min and max steps we slice over. Max is extended due to diff
92
113
  min_step = min(steps)
93
- max_step = max(steps) + sample_period_duration
114
+ max_step = max(steps) + time_resolution
94
115
 
95
- da_min = da.sel(
96
- {
97
- "init_time_utc": unique_init_times,
98
- "step": slice(min_step, max_step),
99
- }
100
- )
116
+ da_min = da.sel(init_time_utc=unique_init_times, step=slice(min_step, max_step))
101
117
 
102
118
  # Slice out the data which does not need to be diffed
103
- da_non_accum = da_min.sel({channel_dim_name: non_accum_channels})
104
- da_sel_non_accum = da_non_accum.sel(
105
- step=step_indexer, init_time_utc=init_time_indexer
106
- )
119
+ da_non_accum = da_min.sel(channel=non_accum_channels)
120
+ da_sel_non_accum = da_non_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
107
121
 
108
122
  # Slice out the channels which need to be diffed
109
- da_accum = da_min.sel({channel_dim_name: accum_channels})
110
-
123
+ da_accum = da_min.sel(channel=accum_channels)
124
+
111
125
  # Take the diff and slice requested data
112
126
  da_accum = da_accum.diff(dim="step", label="lower")
113
127
  da_sel_accum = da_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
114
128
 
115
129
  # Join diffed and non-diffed variables
116
- da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim=channel_dim_name)
117
-
130
+ da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim="channel")
131
+
118
132
  # Reorder the variable back to the original order
119
- da_sel = da_sel.sel({channel_dim_name: da[channel_dim_name].values})
133
+ da_sel = da_sel.sel(channel=da.channel.values)
120
134
 
121
135
  # Rename the diffed channels
122
- da_sel[channel_dim_name] = [
123
- f"diff_{v}" if v in accum_channels else v
124
- for v in da_sel[channel_dim_name].values
136
+ da_sel["channel"] = [
137
+ f"diff_{v}" if v in accum_channels else v for v in da_sel.channel.values
125
138
  ]
126
139
 
127
140
  return da_sel
@@ -1,4 +1,5 @@
1
- """ Functions for selecting data around a given location """
1
+ """Functions for selecting data around a given location."""
2
+
2
3
  from ocf_data_sampler.config import Configuration
3
4
  from ocf_data_sampler.select.location import Location
4
5
  from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
@@ -9,24 +10,24 @@ def slice_datasets_by_space(
9
10
  location: Location,
10
11
  config: Configuration,
11
12
  ) -> dict:
12
- """Slice the dictionary of input data sources around a given location
13
+ """Slice the dictionary of input data sources around a given location.
13
14
 
14
15
  Args:
15
16
  datasets_dict: Dictionary of the input data sources
16
17
  location: The location to sample around
17
18
  config: Configuration object.
18
19
  """
19
-
20
- assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp", "site"})
20
+ if not set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp", "site"}):
21
+ raise ValueError(
22
+ "'datasets_dict' should only contain keys 'nwp', 'sat', 'gsp', 'site'",
23
+ )
21
24
 
22
25
  sliced_datasets_dict = {}
23
26
 
24
27
  if "nwp" in datasets_dict:
25
-
26
28
  sliced_datasets_dict["nwp"] = {}
27
29
 
28
30
  for nwp_key, nwp_config in config.input_data.nwp.items():
29
-
30
31
  sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
31
32
  datasets_dict["nwp"][nwp_key],
32
33
  location,