ocf-data-sampler 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -4,5 +4,3 @@ from .find_contiguous_time_periods import (
  intersection_of_multiple_dataframes_of_periods,
  )
  from .location import Location
- from .spatial_slice_for_dataset import slice_datasets_by_space
- from .time_slice_for_dataset import slice_datasets_by_time
@@ -17,80 +17,64 @@ from ocf_data_sampler.select.location import Location
  logger = logging.getLogger(__name__)


- # -------------------------------- utility functions --------------------------------
-
-
- def convert_coords_to_match_xarray(
+ def convert_coordinates(
+ from_coords: str,
  x: float | np.ndarray,
  y: float | np.ndarray,
- from_coords: str,
  da: xr.DataArray,
  ) -> tuple[float | np.ndarray, float | np.ndarray]:
- """Convert x and y coords to cooridnate system matching xarray data.
+ """Convert x and y coordinates to coordinate system matching xarray data.

  Args:
- x: Float or array-like
- y: Float or array-like
- from_coords: String describing coordinate system of x and y
- da: DataArray to which coordinates should be matched
+ from_coords: The coordinate system to convert from.
+ x: The x-coordinate to convert.
+ y: The y-coordinate to convert.
+ da: The xarray DataArray used for context (e.g., for geostationary conversion).
+
+ Returns:
+ The converted (x, y) coordinates.
  """
  target_coords, *_ = spatial_coord_type(da)

  match (from_coords, target_coords):
  case ("osgb", "geostationary"):
  x, y = osgb_to_geostationary_area_coords(x, y, da)
-
  case ("osgb", "lon_lat"):
  x, y = osgb_to_lon_lat(x, y)
-
  case ("osgb", "osgb"):
  pass
-
  case ("lon_lat", "osgb"):
  x, y = lon_lat_to_osgb(x, y)
-
  case ("lon_lat", "geostationary"):
  x, y = lon_lat_to_geostationary_area_coords(x, y, da)
-
  case ("lon_lat", "lon_lat"):
  pass
-
  case (_, _):
  raise NotImplementedError(
- f"Conversion from {from_coords} to {target_coords} is not supported",
+ f"Conversion from {from_coords} to "
+ f"{target_coords} is not supported",
  )
-
  return x, y


- # TODO: This function and _get_idx_of_pixel_closest_to_poi_geostationary() should not be separate
- # We should combine them, and consider making a Coord class to help with this
- def _get_idx_of_pixel_closest_to_poi(
- da: xr.DataArray,
- location: Location,
- ) -> Location:
- """Return x and y index location of pixel at center of region of interest.
+ def _get_pixel_index_location(da: xr.DataArray, location: Location) -> Location:
+ """Find pixel index location closest to given Location.

  Args:
- da: xarray DataArray
- location: Location to find index of
+ da: The xarray DataArray.
+ location: The Location object representing the point of interest.
+
  Returns:
- The Location for the center pixel
+ A Location object with x and y attributes representing the pixel indices.
+
+ Raises:
+ ValueError: If the location is outside the bounds of the DataArray.
  """
  xr_coords, x_dim, y_dim = spatial_coord_type(da)

- if xr_coords not in ["osgb", "lon_lat"]:
- raise NotImplementedError(f"Only 'osgb' and 'lon_lat' are supported - not '{xr_coords}'")
+ x, y = convert_coordinates(location.coordinate_system, location.x, location.y, da)

- # Convert location coords to match xarray data
- x, y = convert_coords_to_match_xarray(
- location.x,
- location.y,
- from_coords=location.coordinate_system,
- da=da,
- )
-
- # Check that the requested point lies within the data
+ # Check that requested point lies within the data
  if not (da[x_dim].min() < x < da[x_dim].max()):
  raise ValueError(
  f"{x} is not in the interval {da[x_dim].min().values}: {da[x_dim].max().values}",
@@ -102,84 +86,53 @@ def _get_idx_of_pixel_closest_to_poi(

  x_index = da.get_index(x_dim)
  y_index = da.get_index(y_dim)
-
  closest_x = x_index.get_indexer([x], method="nearest")[0]
  closest_y = y_index.get_indexer([y], method="nearest")[0]

  return Location(x=closest_x, y=closest_y, coordinate_system="idx")


- def _get_idx_of_pixel_closest_to_poi_geostationary(
- da: xr.DataArray,
- center: Location,
- ) -> Location:
- """Return x and y index location of pixel at center of region of interest.
-
- Args:
- da: xarray DataArray
- center: Center in OSGB coordinates
-
- Returns:
- Location for the center pixel in geostationary coordinates
- """
- _, x_dim, y_dim = spatial_coord_type(da)
-
- if center.coordinate_system == "osgb":
- x, y = osgb_to_geostationary_area_coords(x=center.x, y=center.y, xr_data=da)
- elif center.coordinate_system == "lon_lat":
- x, y = lon_lat_to_geostationary_area_coords(
- longitude=center.x,
- latitude=center.y,
- xr_data=da,
- )
- else:
- x, y = center.x, center.y
- center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")
-
- # Check that the requested point lies within the data
- if not (da[x_dim].min() < x < da[x_dim].max()):
- raise ValueError(
- f"{x} is not in the interval {da[x_dim].min().values}: {da[x_dim].max().values}",
- )
- if not (da[y_dim].min() < y < da[y_dim].max()):
- raise ValueError(
- f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}",
- )
-
- # Get the index into x and y nearest to x_center_geostationary and y_center_geostationary:
- x_index_at_center = np.searchsorted(da[x_dim].values, center_geostationary.x)
- y_index_at_center = np.searchsorted(da[y_dim].values, center_geostationary.y)
-
- return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx")
-
-
- # ---------------------------- sub-functions for slicing ----------------------------
-
-
- def _select_partial_spatial_slice_pixels(
+ def _select_padded_slice(
  da: xr.DataArray,
  left_idx: int,
  right_idx: int,
  bottom_idx: int,
  top_idx: int,
- left_pad_pixels: int,
- right_pad_pixels: int,
- bottom_pad_pixels: int,
- top_pad_pixels: int,
  x_dim: str,
  y_dim: str,
  ) -> xr.DataArray:
- """Return spatial window of given pixel size when window partially overlaps input data."""
- # We should never be padding on both sides of a window. This would mean our desired window is
- # larger than the size of the input data
- if (left_pad_pixels != 0 and right_pad_pixels != 0) or (
- bottom_pad_pixels != 0 and top_pad_pixels != 0
+ """Selects spatial slice - padding where necessary if indices are out of bounds.
+
+ Args:
+ da: xarray DataArray.
+ left_idx: The leftmost index of the slice.
+ right_idx: The rightmost index of the slice.
+ bottom_idx: The bottommost index of the slice.
+ top_idx: The topmost index of the slice.
+ x_dim: Name of the x dimension.
+ y_dim: Name of the y dimension.
+
+ Returns:
+ An xarray DataArray with padding, if necessary.
+ """
+ data_width_pixels = len(da[x_dim])
+ data_height_pixels = len(da[y_dim])
+
+ left_pad_pixels = max(0, -left_idx)
+ right_pad_pixels = max(0, right_idx - data_width_pixels)
+ bottom_pad_pixels = max(0, -bottom_idx)
+ top_pad_pixels = max(0, top_idx - data_height_pixels)
+
+ if (left_pad_pixels > 0 and right_pad_pixels > 0) or (
+ bottom_pad_pixels > 0 and top_pad_pixels > 0
  ):
  raise ValueError("Cannot pad both sides of the window")

  dx = np.median(np.diff(da[x_dim].values))
  dy = np.median(np.diff(da[y_dim].values))

+ # Create a new DataArray which has indices which go outside
+ # the original DataArray
  # Pad the left of the window
  if left_pad_pixels > 0:
  x_sel = np.concatenate(
@@ -222,7 +175,7 @@ def _select_partial_spatial_slice_pixels(
  da[y_dim].values[-1] + np.arange(1, top_pad_pixels + 1) * dy,
  ],
  )
- da = da.isel({y_dim: slice(left_idx, None)}).reindex({y_dim: y_sel})
+ da = da.isel({y_dim: slice(bottom_idx, None)}).reindex({y_dim: y_sel})

  # No bottom-top padding required
  else:
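A worked example of the pad-width arithmetic that _select_padded_slice now derives internally (the numbers are hypothetical):

    data_width_pixels = 100        # len(da[x_dim])
    left_idx, right_idx = -3, 97   # requested window hangs 3 px off the left edge

    left_pad_pixels = max(0, -left_idx)                       # 3
    right_pad_pixels = max(0, right_idx - data_width_pixels)  # 0
    # Three extra x coordinates (the first value minus 1..3 * dx) are prepended,
    # then the window is selected via reindex, which fills the out-of-bounds
    # pixels with NaN.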
@@ -231,34 +184,38 @@ def _select_partial_spatial_slice_pixels(
  return da


- def _select_spatial_slice_pixels(
+ def select_spatial_slice_pixels(
  da: xr.DataArray,
- center_idx: Location,
+ location: Location,
  width_pixels: int,
  height_pixels: int,
- x_dim: str,
- y_dim: str,
- allow_partial_slice: bool,
+ allow_partial_slice: bool = False,
  ) -> xr.DataArray:
- """Select a spatial slice from an xarray object.
+ """Select spatial slice based off pixels from location point of interest.

  Args:
  da: xarray DataArray to slice from
- center_idx: Location object describing the centre of the window with index coordinates
- width_pixels: Window with in pixels
- height_pixels: Window height in pixels
- x_dim: Name of the x-dimension in `da`
- y_dim: Name of the y-dimension in `da`
- allow_partial_slice: Whether to allow a partially filled window
+ location: Location of interest that will be the center of the returned slice
+ height_pixels: Height of the slice in pixels
+ width_pixels: Width of the slice in pixels
+ allow_partial_slice: Whether to allow a partial slice.
+
+ Returns:
+ The selected DataArray slice.
+
+ Raises:
+ ValueError: If the dimensions are not even or the slice is not allowed
+ when padding is required.
+
  """
- if center_idx.coordinate_system != "idx":
- raise ValueError(f"Expected center_idx to be in 'idx' coordinates, got '{center_idx}'")
- # TODO: It shouldn't take much effort to allow height and width to be odd
  if (width_pixels % 2) != 0:
  raise ValueError("Width must be an even number")
  if (height_pixels % 2) != 0:
  raise ValueError("Height must be an even number")

+ _, x_dim, y_dim = spatial_coord_type(da)
+ center_idx = _get_pixel_index_location(da, location)
+
  half_width = width_pixels // 2
  half_height = height_pixels // 2

@@ -270,104 +227,29 @@ def _select_spatial_slice_pixels(
  data_width_pixels = len(da[x_dim])
  data_height_pixels = len(da[y_dim])

- left_pad_required = left_idx < 0
- right_pad_required = right_idx > data_width_pixels
- bottom_pad_required = bottom_idx < 0
- top_pad_required = top_idx > data_height_pixels
-
- pad_required = left_pad_required | right_pad_required | bottom_pad_required | top_pad_required
+ # Padding checks
+ pad_required = (
+ left_idx < 0
+ or right_idx > data_width_pixels
+ or bottom_idx < 0
+ or top_idx > data_height_pixels
+ )

  if pad_required:
  if allow_partial_slice:
- left_pad_pixels = (-left_idx) if left_pad_required else 0
- right_pad_pixels = (right_idx - data_width_pixels) if right_pad_required else 0
-
- bottom_pad_pixels = (-bottom_idx) if bottom_pad_required else 0
- top_pad_pixels = (top_idx - data_height_pixels) if top_pad_required else 0
-
- da = _select_partial_spatial_slice_pixels(
- da,
- left_idx,
- right_idx,
- bottom_idx,
- top_idx,
- left_pad_pixels,
- right_pad_pixels,
- bottom_pad_pixels,
- top_pad_pixels,
- x_dim,
- y_dim,
- )
+ da = _select_padded_slice(da, left_idx, right_idx, bottom_idx, top_idx, x_dim, y_dim)
  else:
  raise ValueError(
- f"Window for location {center_idx} not available. Missing (left, right, bottom, "
- f"top) pixels = ({left_pad_required}, {right_pad_required}, "
- f"{bottom_pad_required}, {top_pad_required}). "
- f"You may wish to set `allow_partial_slice=True`",
+ f"Window for location {location} not available. Padding required. "
+ "You may wish to set `allow_partial_slice=True`",
  )
-
  else:
- da = da.isel(
- {
- x_dim: slice(left_idx, right_idx),
- y_dim: slice(bottom_idx, top_idx),
- },
- )
+ # Standard selection - without padding
+ da = da.isel({x_dim: slice(left_idx, right_idx), y_dim: slice(bottom_idx, top_idx)})

  if len(da[x_dim]) != width_pixels:
- raise ValueError(
- f"Expected x-dim len {width_pixels} got {len(da[x_dim])} "
- f"for location {center_idx} for slice {left_idx}:{right_idx}",
- )
+ raise ValueError(f"x-dim has size {len(da[x_dim])}, expected {width_pixels}")
  if len(da[y_dim]) != height_pixels:
- raise ValueError(
- f"Expected y-dim len {height_pixels} got {len(da[y_dim])} "
- f"for location {center_idx} for slice {bottom_idx}:{top_idx}",
- )
+ raise ValueError(f"y-dim has size {len(da[y_dim])}, expected {height_pixels}")

  return da
-
-
- # ---------------------------- main functions for slicing ---------------------------
-
-
- def select_spatial_slice_pixels(
- da: xr.DataArray,
- location: Location,
- width_pixels: int,
- height_pixels: int,
- allow_partial_slice: bool = False,
- ) -> xr.DataArray:
- """Select spatial slice based off pixels from location point of interest.
-
- If `allow_partial_slice` is set to True, then slices may be made which intersect the border
- of the input data. The additional x and y cordinates that would be required for this slice
- are extrapolated based on the average spacing of these coordinates in the input data.
- However, currently slices cannot be made where the centre of the window is outside of the
- input data.
-
- Args:
- da: xarray DataArray to slice from
- location: Location of interest
- height_pixels: Height of the slice in pixels
- width_pixels: Width of the slice in pixels
- allow_partial_slice: Whether to allow a partial slice.
- """
- xr_coords, x_dim, y_dim = spatial_coord_type(da)
-
- if xr_coords == "geostationary":
- center_idx: Location = _get_idx_of_pixel_closest_to_poi_geostationary(da, location)
- else:
- center_idx: Location = _get_idx_of_pixel_closest_to_poi(da, location)
-
- selected = _select_spatial_slice_pixels(
- da,
- center_idx,
- width_pixels,
- height_pixels,
- x_dim,
- y_dim,
- allow_partial_slice=allow_partial_slice,
- )
-
- return selected
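A sketch of the consolidated public entry point after this refactor. The import paths are inferred from the package layout in the RECORD below, and `da` stands in for a real satellite or NWP DataArray loaded elsewhere; the coordinate values are hypothetical.

    from ocf_data_sampler.select.location import Location
    from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels

    location = Location(x=530_000, y=180_000, coordinate_system="osgb")
    da_slice = select_spatial_slice_pixels(
        da,                        # DataArray with recognised spatial coordinates
        location,
        width_pixels=24,           # must be even
        height_pixels=24,          # must be even
        allow_partial_slice=True,  # pad with extrapolated coords/NaNs at borders
    )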
@@ -17,16 +17,17 @@ from ocf_data_sampler.numpy_sample import (
  make_sun_position_numpy_sample,
  )
  from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
+ from ocf_data_sampler.numpy_sample.common_types import NumpyBatch, NumpySample
  from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
  from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
- from ocf_data_sampler.select import (
- Location,
- fill_time_periods,
+ from ocf_data_sampler.select import Location, fill_time_periods
+ from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
+ from ocf_data_sampler.torch_datasets.utils import (
+ channel_dict_to_dataarray,
+ find_valid_time_periods,
  slice_datasets_by_space,
  slice_datasets_by_time,
  )
- from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
- from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray, find_valid_time_periods
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
  fill_nans_in_arrays,
  merge_dicts,
@@ -36,99 +37,6 @@ from ocf_data_sampler.utils import minutes
  xr.set_options(keep_attrs=True)


- def process_and_combine_datasets(
- dataset_dict: dict,
- config: Configuration,
- t0: pd.Timestamp,
- location: Location,
- ) -> dict:
- """Normalise and convert data to numpy arrays."""
- numpy_modalities = []
-
- if "nwp" in dataset_dict:
- nwp_numpy_modalities = {}
-
- for nwp_key, da_nwp in dataset_dict["nwp"].items():
-
- # Standardise and convert to NumpyBatch
-
- da_channel_means = channel_dict_to_dataarray(
- config.input_data.nwp[nwp_key].channel_means,
- )
- da_channel_stds = channel_dict_to_dataarray(
- config.input_data.nwp[nwp_key].channel_stds,
- )
-
- da_nwp = (da_nwp - da_channel_means) / da_channel_stds
-
- nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
-
- # Combine the NWPs into NumpyBatch
- numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
-
- if "sat" in dataset_dict:
- da_sat = dataset_dict["sat"]
-
- # Standardise and convert to NumpyBatch
- da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
- da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
-
- da_sat = (da_sat - da_channel_means) / da_channel_stds
-
- numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
-
- if "gsp" in dataset_dict:
- gsp_config = config.input_data.gsp
- da_gsp = dataset_dict["gsp"]
- da_gsp = da_gsp / da_gsp.effective_capacity_mwp
-
- # Convert to NumpyBatch
- numpy_modalities.append(
- convert_gsp_to_numpy_sample(
- da_gsp,
- t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes,
- ),
- )
-
- # Add GSP location data
- numpy_modalities.append(
- {
- GSPSampleKey.gsp_id: location.id,
- GSPSampleKey.x_osgb: location.x,
- GSPSampleKey.y_osgb: location.y,
- },
- )
-
- # Only add solar position if explicitly configured
- has_solar_config = (
- hasattr(config.input_data, "solar_position") and
- config.input_data.solar_position is not None
- )
-
- if has_solar_config:
- solar_config = config.input_data.solar_position
-
- # Create datetime range for solar position calculation
- datetimes = pd.date_range(
- t0 + minutes(solar_config.interval_start_minutes),
- t0 + minutes(solar_config.interval_end_minutes),
- freq=minutes(solar_config.time_resolution_minutes),
- )
-
- # Convert OSGB coordinates to lon/lat
- lon, lat = osgb_to_lon_lat(location.x, location.y)
-
- # Calculate solar positions and add to modalities
- solar_positions = make_sun_position_numpy_sample(datetimes, lon, lat)
- numpy_modalities.append(solar_positions)
-
- # Combine all the modalities and fill NaNs
- combined_sample = merge_dicts(numpy_modalities)
- combined_sample = fill_nans_in_arrays(combined_sample)
-
- return combined_sample
-
-
  def compute(xarray_dict: dict) -> dict:
  """Eagerly load a nested dictionary of xarray DataArrays."""
  for k, v in xarray_dict.items():
@@ -139,25 +47,12 @@ def compute(xarray_dict: dict) -> dict:
  return xarray_dict


- def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
- """Find the t0 times where all of the requested input data is available.
+ def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
+ """Get list of locations of all GSPs.

  Args:
- datasets_dict: A dictionary of input datasets
- config: Configuration file
+ gsp_ids: List of GSP IDs to include. Defaults to all
  """
- valid_time_periods = find_valid_time_periods(datasets_dict, config)
-
- # Fill out the contiguous time periods to get the t0 times
- valid_t0_times = fill_time_periods(
- valid_time_periods,
- freq=minutes(config.input_data.gsp.time_resolution_minutes),
- )
- return valid_t0_times
-
-
- def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
- """Get list of locations of all GSPs."""
  if gsp_ids is None:
  gsp_ids = list(range(1, 318))
@@ -181,8 +76,8 @@ def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
  return locations


- class PVNetUKRegionalDataset(Dataset):
- """A torch Dataset for creating PVNet UK regional samples."""
+ class AbstractPVNetUKDataset(Dataset):
+ """Abstract class for PVNet UK datasets."""

  def __init__(
  self,
@@ -191,7 +86,7 @@ class PVNetUKRegionalDataset(Dataset):
  end_time: str | None = None,
  gsp_ids: list[int] | None = None,
  ) -> None:
- """A torch Dataset for creating PVNet UK GSP samples.
+ """A torch Dataset for creating PVNet UK samples.

  Args:
  config_filename: Path to the configuration file
@@ -199,13 +94,11 @@ class PVNetUKRegionalDataset(Dataset):
  end_time: Limit the init-times to be before this
  gsp_ids: List of GSP IDs to create samples for. Defaults to all
  """
- # config = load_yaml_configuration(config_filename)
- config: Configuration = load_yaml_configuration(config_filename)
-
+ config = load_yaml_configuration(config_filename)
  datasets_dict = get_dataset_dict(config.input_data)

  # Get t0 times where all input data is available
- valid_t0_times = find_valid_t0_times(datasets_dict, config)
+ valid_t0_times = self.find_valid_t0_times(datasets_dict, config)

  # Filter t0 times to given range
  if start_time is not None:
@@ -215,35 +108,167 @@ class PVNetUKRegionalDataset(Dataset):
  valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]

  # Construct list of locations to sample from
- locations = get_gsp_locations(gsp_ids)
+ self.locations = get_gsp_locations(gsp_ids)
+ self.valid_t0_times = valid_t0_times
+
+ # Assign config and input data to self
+ self.config = config
+ self.datasets_dict = datasets_dict
+
+
+ @staticmethod
+ def process_and_combine_datasets(
+ dataset_dict: dict,
+ config: Configuration,
+ t0: pd.Timestamp,
+ location: Location,
+ ) -> NumpySample:
+ """Normalise and convert data to numpy arrays.
+
+ Args:
+ dataset_dict: Dictionary of xarray datasets
+ config: Configuration object
+ t0: init-time for sample
+ location: location of the sample
+ """
+ numpy_modalities = []
+
+ if "nwp" in dataset_dict:
+ nwp_numpy_modalities = {}
+
+ for nwp_key, da_nwp in dataset_dict["nwp"].items():
+
+ # Standardise and convert to NumpyBatch
+
+ da_channel_means = channel_dict_to_dataarray(
+ config.input_data.nwp[nwp_key].channel_means,
+ )
+ da_channel_stds = channel_dict_to_dataarray(
+ config.input_data.nwp[nwp_key].channel_stds,
+ )
+
+ da_nwp = (da_nwp - da_channel_means) / da_channel_stds
+
+ nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
+
+ # Combine the NWPs into NumpyBatch
+ numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
+
+ if "sat" in dataset_dict:
+ da_sat = dataset_dict["sat"]
+
+ # Standardise and convert to NumpyBatch
+ da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
+ da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+
+ da_sat = (da_sat - da_channel_means) / da_channel_stds
+
+ numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
+
+ if "gsp" in dataset_dict:
+ gsp_config = config.input_data.gsp
+ da_gsp = dataset_dict["gsp"]
+ da_gsp = da_gsp / da_gsp.effective_capacity_mwp
+
+ # Convert to NumpyBatch
+ numpy_modalities.append(
+ convert_gsp_to_numpy_sample(
+ da_gsp,
+ t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes,
+ ),
+ )
+
+ # Add GSP location data
+ numpy_modalities.append(
+ {
+ GSPSampleKey.gsp_id: location.id,
+ GSPSampleKey.x_osgb: location.x,
+ GSPSampleKey.y_osgb: location.y,
+ },
+ )
+
+ # Only add solar position if explicitly configured
+ has_solar_config = (
+ hasattr(config.input_data, "solar_position") and
+ config.input_data.solar_position is not None
+ )
+
+ if has_solar_config:
+ solar_config = config.input_data.solar_position
+
+ # Create datetime range for solar position calculation
+ datetimes = pd.date_range(
+ t0 + minutes(solar_config.interval_start_minutes),
+ t0 + minutes(solar_config.interval_end_minutes),
+ freq=minutes(solar_config.time_resolution_minutes),
+ )
+
+ # Convert OSGB coordinates to lon/lat
+ lon, lat = osgb_to_lon_lat(location.x, location.y)
+
+ # Calculate solar positions and add to modalities
+ numpy_modalities.append(make_sun_position_numpy_sample(datetimes, lon, lat))
+
+ # Combine all the modalities and fill NaNs
+ combined_sample = merge_dicts(numpy_modalities)
+ combined_sample = fill_nans_in_arrays(combined_sample)
+
+ return combined_sample
+
+ @staticmethod
+ def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
+ """Find the t0 times where all of the requested input data is available.
+
+ Args:
+ datasets_dict: A dictionary of input datasets
+ config: Configuration file
+ """
+ valid_time_periods = find_valid_time_periods(datasets_dict, config)
+
+ # Fill out the contiguous time periods to get the t0 times
+ valid_t0_times = fill_time_periods(
+ valid_time_periods,
+ freq=minutes(config.input_data.gsp.time_resolution_minutes),
+ )
+ return valid_t0_times
+
+
+
+ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
+ """A torch Dataset for creating PVNet UK regional samples."""
+
+ @override
+ def __init__(
+ self,
+ config_filename: str,
+ start_time: str | None = None,
+ end_time: str | None = None,
+ gsp_ids: list[int] | None = None,
+ ) -> None:
+
+ super().__init__(config_filename, start_time, end_time, gsp_ids)

  # Construct a lookup for locations - useful for users to construct sample by GSP ID
- location_lookup = {loc.id: loc for loc in locations}
+ location_lookup = {loc.id: loc for loc in self.locations}

  # Construct indices for sampling
  t_index, loc_index = np.meshgrid(
- np.arange(len(valid_t0_times)),
- np.arange(len(locations)),
+ np.arange(len(self.valid_t0_times)),
+ np.arange(len(self.locations)),
  )

  # Make array of all possible (t0, location) coordinates. Each row is a single coordinate
  index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T

  # Assign coords and indices to self
- self.valid_t0_times = valid_t0_times
- self.locations = locations
  self.location_lookup = location_lookup
  self.index_pairs = index_pairs

- # Assign config and input data to self
- self.datasets_dict = datasets_dict
- self.config = config
-
  @override
  def __len__(self) -> int:
  return len(self.index_pairs)

- def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
+ def _get_sample(self, t0: pd.Timestamp, location: Location) -> NumpySample:
  """Generate the PVNet sample for given coordinates.

  Args:
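The standardisation step in process_and_combine_datasets relies on xarray broadcasting; a self-contained sketch with made-up channel names and statistics:

    import numpy as np
    import xarray as xr

    channels = ["t2m", "dswrf", "lcc"]
    da_nwp = xr.DataArray(
        np.random.rand(2, 3, 4, 4),
        dims=("time", "channel", "y", "x"),
        coords={"channel": channels},
    )
    means = xr.DataArray([280.0, 150.0, 0.4], coords={"channel": channels})
    stds = xr.DataArray([10.0, 80.0, 0.3], coords={"channel": channels})

    # Aligns on the shared "channel" coordinate and broadcasts over time/y/x --
    # the same shape as `(da_nwp - da_channel_means) / da_channel_stds` above.
    da_norm = (da_nwp - means) / stds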
@@ -254,21 +279,18 @@ class PVNetUKRegionalDataset(Dataset):
  sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
  sample_dict = compute(sample_dict)

- sample = process_and_combine_datasets(sample_dict, self.config, t0, location)
-
- return sample
+ return self.process_and_combine_datasets(sample_dict, self.config, t0, location)

  @override
- def __getitem__(self, idx: int) -> dict:
+ def __getitem__(self, idx: int) -> NumpySample:
  # Get the coordinates of the sample
  t_index, loc_index = self.index_pairs[idx]
  location = self.locations[loc_index]
  t0 = self.valid_t0_times[t_index]

- # Generate the sample
  return self._get_sample(t0, location)

- def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
+ def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpySample:
  """Generate a sample for the given coordinates.

  Useful for users to generate specific samples.
@@ -288,56 +310,14 @@ class PVNetUKRegionalDataset(Dataset):
  return self._get_sample(t0, location)


- class PVNetUKConcurrentDataset(Dataset):
+ class PVNetUKConcurrentDataset(AbstractPVNetUKDataset):
  """A torch Dataset for creating concurrent PVNet UK regional samples."""

- def __init__(
- self,
- config_filename: str,
- start_time: str | None = None,
- end_time: str | None = None,
- gsp_ids: list[int] | None = None,
- ) -> None:
- """A torch Dataset for creating concurrent samples of PVNet UK regional data.
-
- Each concurrent sample includes the data from all GSPs for a single t0 time
-
- Args:
- config_filename: Path to the configuration file
- start_time: Limit the init-times to be after this
- end_time: Limit the init-times to be before this
- gsp_ids: List of all GSP IDs included in each sample. Defaults to all
- """
- config = load_yaml_configuration(config_filename)
-
- datasets_dict = get_dataset_dict(config.input_data)
-
- # Get t0 times where all input data is available
- valid_t0_times = find_valid_t0_times(datasets_dict, config)
-
- # Filter t0 times to given range
- if start_time is not None:
- valid_t0_times = valid_t0_times[valid_t0_times >= pd.Timestamp(start_time)]
-
- if end_time is not None:
- valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
-
- # Construct list of locations to sample from
- locations = get_gsp_locations(gsp_ids)
-
- # Assign coords and indices to self
- self.valid_t0_times = valid_t0_times
- self.locations = locations
-
- # Assign config and input data to self
- self.datasets_dict = datasets_dict
- self.config = config
-
  @override
  def __len__(self) -> int:
  return len(self.valid_t0_times)

- def _get_sample(self, t0: pd.Timestamp) -> dict:
+ def _get_sample(self, t0: pd.Timestamp) -> NumpyBatch:
  """Generate a concurrent PVNet sample for given init-time.

  Args:
@@ -352,7 +332,7 @@ class PVNetUKConcurrentDataset(Dataset):
  # Prepare sample for each GSP
  for location in self.locations:
  gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
- gsp_numpy_sample = process_and_combine_datasets(
+ gsp_numpy_sample = self.process_and_combine_datasets(
  gsp_sample_dict,
  self.config,
  t0,
@@ -364,10 +344,10 @@ class PVNetUKConcurrentDataset(Dataset):
  return stack_np_samples_into_batch(gsp_samples)

  @override
- def __getitem__(self, idx: int) -> dict:
+ def __getitem__(self, idx: int) -> NumpyBatch:
  return self._get_sample(self.valid_t0_times[idx])

- def get_sample(self, t0: pd.Timestamp) -> dict:
+ def get_sample(self, t0: pd.Timestamp) -> NumpyBatch:
  """Generate a sample for the given init-time.

  Useful for users to generate specific samples.
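Hypothetical end-to-end usage of the two dataset classes after the refactor; the config path, timestamps, and GSP IDs are placeholders.

    import pandas as pd
    from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import (
        PVNetUKConcurrentDataset,
        PVNetUKRegionalDataset,
    )

    # One sample per (t0, GSP) pair; returns a NumpySample.
    regional = PVNetUKRegionalDataset("config.yaml", gsp_ids=[1, 2, 3])
    sample = regional.get_sample(t0=pd.Timestamp("2024-06-01 12:00"), gsp_id=1)

    # One sample per t0 covering all requested GSPs; returns a NumpyBatch.
    concurrent = PVNetUKConcurrentDataset("config.yaml", gsp_ids=[1, 2, 3])
    batch = concurrent.get_sample(t0=pd.Timestamp("2024-06-01 12:00"))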
@@ -1,14 +1,12 @@
  """Torch dataset for sites."""

- import logging
-
  import numpy as np
  import pandas as pd
  import xarray as xr
  from torch.utils.data import Dataset
  from typing_extensions import override

- from ocf_data_sampler.config import Configuration, load_yaml_configuration
+ from ocf_data_sampler.config import load_yaml_configuration
  from ocf_data_sampler.load.load_dataset import get_dataset_dict
  from ocf_data_sampler.numpy_sample import (
  NWPSampleKey,
@@ -18,15 +16,19 @@ from ocf_data_sampler.numpy_sample import (
  make_datetime_numpy_dict,
  make_sun_position_numpy_sample,
  )
+ from ocf_data_sampler.numpy_sample.common_types import NumpySample
  from ocf_data_sampler.select import (
  Location,
  fill_time_periods,
  find_contiguous_t0_periods,
  intersection_of_multiple_dataframes_of_periods,
+ )
+ from ocf_data_sampler.torch_datasets.utils import (
+ channel_dict_to_dataarray,
+ find_valid_time_periods,
  slice_datasets_by_space,
  slice_datasets_by_time,
  )
- from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray, find_valid_time_periods
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
  fill_nans_in_arrays,
  merge_dicts,
@@ -52,7 +54,7 @@ class SitesDataset(Dataset):
  start_time: Limit the init-times to be after this
  end_time: Limit the init-times to be before this
  """
- config: Configuration = load_yaml_configuration(config_filename)
+ config = load_yaml_configuration(config_filename)
  datasets_dict = get_dataset_dict(config.input_data)

  # Assign config and input data to self
@@ -61,6 +63,7 @@ class SitesDataset(Dataset):

  # get all locations
  self.locations = self.get_locations(datasets_dict["site"])
+ self.location_lookup = {loc.id: loc for loc in self.locations}

  # Get t0 times where all input data is available
  valid_t0_and_site_ids = self.find_valid_t0_and_site_ids(datasets_dict)
@@ -89,7 +92,7 @@ class SitesDataset(Dataset):
  t0, site_id = self.valid_t0_and_site_ids.iloc[idx]

  # get location from site id
- location = self.get_location_from_site_id(site_id)
+ location = self.location_lookup[site_id]

  # Generate the sample
  return self._get_sample(t0, location)
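The removed get_location_from_site_id() linear scan is replaced by a dict built once in __init__; the pattern, as used above:

    # O(1) lookup per sample instead of scanning self.locations each time.
    # An unknown site_id now raises KeyError (the old helper raised ValueError).
    location_lookup = {loc.id: loc for loc in self.locations}
    location = location_lookup[site_id]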
@@ -105,8 +108,7 @@ class SitesDataset(Dataset):
  sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)

  sample = self.process_and_combine_site_sample_dict(sample_dict, t0)
- sample = sample.compute()
- return sample
+ return sample.compute()

  def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
  """Generate a sample for a given site id and t0.
@@ -117,22 +119,10 @@ class SitesDataset(Dataset):
  t0: init-time for sample
  site_id: site id as int
  """
- location = self.get_location_from_site_id(site_id)
+ location = self.location_lookup[site_id]

  return self._get_sample(t0, location)

- def get_location_from_site_id(self, site_id: int) -> Location:
- """Get location from system id."""
- locations = [loc for loc in self.locations if loc.id == site_id]
- if len(locations) == 0:
- raise ValueError(f"Location not found for site_id {site_id}")
-
- if len(locations) > 1:
- logging.warning(
- f"Multiple locations found for site_id {site_id}, but will take the first",
- )
-
- return locations[0]

  def find_valid_t0_and_site_ids(
  self,
@@ -148,24 +138,21 @@ class SitesDataset(Dataset):
  datasets_dict: A dictionary of input datasets
  config: Configuration file
  """
- # 1. Get valid time period for nwp and satellite
+ # Get valid time period for nwp and satellite
  datasets_without_site = {k: v for k, v in datasets_dict.items() if k != "site"}
  valid_time_periods = find_valid_time_periods(datasets_without_site, self.config)

- # 2. Now lets loop over each location in system id and find the valid periods
- # Should we have a different option if there are not nans
+ # Loop over each location in system id and obtain valid periods
  sites = datasets_dict["site"]
  site_ids = sites.site_id.values
  site_config = self.config.input_data.site
  valid_t0_and_site_ids = []
  for site_id in site_ids:
  site = sites.sel(site_id=site_id)
-
- # drop any nan values
- # not sure this is right?
+ # Drop NaN values
  site = site.dropna(dim="time_utc")

- # Get the valid time periods for this location
+ # Obtain valid time periods for this location
  time_periods = find_contiguous_t0_periods(
  pd.DatetimeIndex(site["time_utc"]),
  time_resolution=minutes(site_config.time_resolution_minutes),
@@ -176,7 +163,7 @@ class SitesDataset(Dataset):
  [valid_time_periods, time_periods],
  )

- # Fill out the contiguous time periods to get the t0 times
+ # Fill out contiguous time periods to get t0 times
  valid_t0_times_per_site = fill_time_periods(
  valid_time_periods_per_site,
  freq=minutes(site_config.time_resolution_minutes),
@@ -188,12 +175,15 @@ class SitesDataset(Dataset):

  valid_t0_and_site_ids = pd.concat(valid_t0_and_site_ids)
  valid_t0_and_site_ids.index.name = "t0"
- valid_t0_and_site_ids.reset_index(inplace=True)
+ return valid_t0_and_site_ids.reset_index()

- return valid_t0_and_site_ids

  def get_locations(self, site_xr: xr.Dataset) -> list[Location]:
- """Get list of locations of all sites."""
+ """Get list of locations of all sites.
+
+ Args:
+ site_xr: xarray Dataset of site data
+ """
  locations = []
  for site_id in site_xr.site_id.values:
  site = site_xr.sel(site_id=site_id)
@@ -220,7 +210,6 @@ class SitesDataset(Dataset):

  Returns:
  xr.Dataset: A merged Dataset with nans filled in.
-
  """
  data_arrays = []

@@ -228,7 +217,6 @@ class SitesDataset(Dataset):
  for nwp_key, da_nwp in dataset_dict["nwp"].items():
  provider = self.config.input_data.nwp[nwp_key].provider

- # Standardise
  da_channel_means = channel_dict_to_dataarray(
  self.config.input_data.nwp[nwp_key].channel_means,
  )
@@ -237,7 +225,6 @@ class SitesDataset(Dataset):
  )

  da_nwp = (da_nwp - da_channel_means) / da_channel_stds
-
  data_arrays.append((f"nwp-{provider}", da_nwp))

  if "sat" in dataset_dict:
@@ -251,11 +238,9 @@ class SitesDataset(Dataset):
  )

  da_sat = (da_sat - da_channel_means) / da_channel_stds
-
  data_arrays.append(("satellite", da_sat))

  if "site" in dataset_dict:
- # site_config = config.input_data.site
  da_sites = dataset_dict["site"]
  da_sites = da_sites / da_sites.capacity_kwp
  data_arrays.append(("site", da_sites))
@@ -372,12 +357,16 @@ class SitesDataset(Dataset):


  def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
- """Convert a netcdf dataset to a numpy sample."""
+ """Convert a netcdf dataset to a numpy sample.
+
+ Args:
+ ds: xarray Dataset
+ """
  # convert the single dataset to a dict of arrays
  sample_dict = convert_from_dataset_to_dict_datasets(ds)

  if "satellite" in sample_dict:
- # rename satellite to satellite actual # TODO this could be improves
+ # rename satellite to sat # TODO this could be improved
  sample_dict["sat"] = sample_dict.pop("satellite")

  # process and combine the datasets
@@ -408,43 +397,52 @@ def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[
  The uncombined datasets as a dict of xr.Datasets
  """
  # Split into datasets by splitting by the prefix added in combine_to_netcdf
- datasets = {}
+ datasets: dict[str, xr.DataArray] = {}
+
  # Go through each data variable and split it into a dataset
  for key, dataset in combined_dataset.items():
- # If 'key_' doesn't exist in a dim or coordinate, remove it
- dataset_dims = list(dataset.coords)
- for dim in dataset_dims:
+ # If 'key__' doesn't exist in a dim or coordinate, remove it
+ for dim in list(dataset.coords):
  if f"{key}__" not in dim:
- dataset: xr.Dataset = dataset.drop(dim)
+ dataset = dataset.drop_vars(dim)
  dataset = dataset.rename(
  {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords},
  )
- dataset: xr.Dataset = dataset.rename(
+ dataset = dataset.rename(
  {coord: coord.split(f"{key}__")[1] for coord in dataset.coords},
  )
  # Split the dataset by the prefix
  datasets[key] = dataset

  # Unflatten any NWP data
- datasets = nest_nwp_source_dict(datasets, sep="-")
- return datasets
+ return nest_nwp_source_dict(datasets, sep="-")
+

+ def nest_nwp_source_dict(
+ dataset_dict: dict[xr.Dataset],
+ sep: str = "-",
+ ) -> dict[str, xr.Dataset | dict[xr.Dataset]]:
+ """Re-nest a dictionary where the NWP values are nested under keys 'nwp-<key>'.

- def nest_nwp_source_dict(d: dict, sep: str = "/") -> dict:
- """Re-nest a dictionary where the NWP values are nested under keys 'nwp/<key>'."""
+ Args:
+ dataset_dict: Dictionary of datasets
+ sep: Separator to use to nest NWP keys
+ """
  nwp_prefix = f"nwp{sep}"
- new_dict = {k: v for k, v in d.items() if not k.startswith(nwp_prefix)}
- nwp_keys = [k for k in d if k.startswith(nwp_prefix)]
+ new_dict = {k: v for k, v in dataset_dict.items() if not k.startswith(nwp_prefix)}
+ nwp_keys = [k for k in dataset_dict if k.startswith(nwp_prefix)]
  if len(nwp_keys) > 0:
- nwp_subdict = {k.removeprefix(nwp_prefix): d[k] for k in nwp_keys}
+ nwp_subdict = {k.removeprefix(nwp_prefix): dataset_dict[k] for k in nwp_keys}
  new_dict["nwp"] = nwp_subdict
  return new_dict


- def convert_to_numpy_and_combine(
- dataset_dict: dict,
- ) -> dict:
- """Convert input data in a dict to numpy arrays."""
+ def convert_to_numpy_and_combine(dataset_dict: dict[xr.Dataset]) -> NumpySample:
+ """Convert input data in a dict to numpy arrays.
+
+ Args:
+ dataset_dict: Dictionary of xarray Datasets
+ """
  numpy_modalities = []

  if "nwp" in dataset_dict:
@@ -474,9 +472,7 @@ def convert_to_numpy_and_combine(

  # Combine all the modalities and fill NaNs
  combined_sample = merge_dicts(numpy_modalities)
- combined_sample = fill_nans_in_arrays(combined_sample)
-
- return combined_sample
+ return fill_nans_in_arrays(combined_sample)


  def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float = 0.1) -> xr.Dataset:
@@ -0,0 +1,3 @@
+ from .base import SampleBase
+ from .uk_regional import UKRegionalSample
+ from .site import SiteSample
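With this new __init__, the sample classes are importable from their new home under torch_datasets (the old top-level ocf_data_sampler.sample package is deleted at the end of this diff); a sketch of the updated import:

    from ocf_data_sampler.torch_datasets.sample import (
        SampleBase,
        SiteSample,
        UKRegionalSample,
    )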
@@ -4,9 +4,10 @@ import xarray as xr
  from typing_extensions import override

  from ocf_data_sampler.numpy_sample.common_types import NumpySample
- from ocf_data_sampler.sample.base import SampleBase
  from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample

+ from .base import SampleBase
+

  class SiteSample(SampleBase):
  """Handles PVNet site specific netCDF operations."""
@@ -9,7 +9,8 @@ from ocf_data_sampler.numpy_sample import (
  SatelliteSampleKey,
  )
  from ocf_data_sampler.numpy_sample.common_types import NumpySample
- from ocf_data_sampler.sample.base import SampleBase
+
+ from .base import SampleBase


  class UKRegionalSample(SampleBase):
@@ -1,3 +1,5 @@
  from .channel_dict_to_dataarray import channel_dict_to_dataarray
  from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
- from .valid_time_periods import find_valid_time_periods
+ from .valid_time_periods import find_valid_time_periods
+ from .spatial_slice_for_dataset import slice_datasets_by_space
+ from .time_slice_for_dataset import slice_datasets_by_time
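slice_datasets_by_space and slice_datasets_by_time moved from ocf_data_sampler.select to torch_datasets.utils, so downstream imports change accordingly:

    # 0.2.8: from ocf_data_sampler.select import slice_datasets_by_space, slice_datasets_by_time
    from ocf_data_sampler.torch_datasets.utils import (
        slice_datasets_by_space,
        slice_datasets_by_time,
    )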
@@ -1,10 +1,17 @@
- """Converts a dictionary of channel values to a DataArray."""
+ """Utility function for converting channel dictionaries to xarray DataArrays."""

  import xarray as xr


  def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
- """Converts a dictionary of channel values to a DataArray."""
+ """Converts a dictionary of channel values to a DataArray.
+
+ Args:
+ channel_dict: Dictionary mapping channel names (str) to their values (float).
+
+ Returns:
+ xr.DataArray: A 1D DataArray with channels as coordinates.
+ """
  return xr.DataArray(
  list(channel_dict.values()),
  coords={"channel": list(channel_dict.keys())},
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ocf-data-sampler
- Version: 0.2.8
+ Version: 0.2.10
  Author: James Fulton, Peter Dudfield
  Author-email: Open Climate Fix team <info@openclimatefix.org>
  License: MIT License
@@ -35,6 +35,7 @@ Requires-Dist: numpy
  Requires-Dist: pandas
  Requires-Dist: xarray
  Requires-Dist: zarr==2.18.3
+ Requires-Dist: numcodecs<0.16
  Requires-Dist: dask
  Requires-Dist: matplotlib
  Requires-Dist: ocf_blosc2
@@ -29,30 +29,30 @@ ocf_data_sampler/numpy_sample/nwp.py,sha256=X9T5XZLVucXX8QAUhdeTnomNBPrsfvsO8I4S
  ocf_data_sampler/numpy_sample/satellite.py,sha256=RaYzYIcB1AmDrKeiqSpn4QVfBH-QMe26F1P5t1az2Jg,1111
  ocf_data_sampler/numpy_sample/site.py,sha256=zfYBjK3CJrIaKH1QdKXU7gwOxTqONt527y3nJ9TRnwc,1325
  ocf_data_sampler/numpy_sample/sun_position.py,sha256=5tt-zNm6aRuZMsxZPaAxyg7HeikswfZCeHWXTHuO2K0,1555
- ocf_data_sampler/sample/__init__.py,sha256=zdS73NTnxFX_j8uh9tT-IXiURB6635wbneM1koWYV1o,169
- ocf_data_sampler/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
- ocf_data_sampler/sample/site.py,sha256=BhQPygeLUncXJGN_Yd2CL050kN6ktZlobaJw0O0RagI,1290
- ocf_data_sampler/sample/uk_regional.py,sha256=VOby07RnZYvzszExwqoZRVwZ1EbCclRpXq1e9CL16CE,2463
- ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
+ ocf_data_sampler/select/__init__.py,sha256=mK7Wu_-j9IXGTYrOuDf5yDDuU5a306b0iGKTAooNg_s,210
  ocf_data_sampler/select/dropout.py,sha256=WVOCweTGfIjufAlnfmYiPofz6X38TxQgzkLwtiB3TrU,1712
  ocf_data_sampler/select/fill_time_periods.py,sha256=TlGxp1xiAqnhdWfLy0pv3FuZc00dtimjWdLzr4JoTGA,865
  ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=cEXrQDzk8pXknxB0q3v5DakosagHMoLDAj302B8Xpw0,11537
  ocf_data_sampler/select/geospatial.py,sha256=CDExkl36eZOKmdJPzUr_K0Wn3axHqv5nYo-EkSiINcc,5032
  ocf_data_sampler/select/location.py,sha256=AZvGR8y62opiW7zACGXjoOtBEWRfSLOZIA73O5Deu0c,1037
- ocf_data_sampler/select/select_spatial_slice.py,sha256=qY2Ll00EPA80oBtzwMoR5nk0UIpoWZF9oXl22YwWr0Q,12341
+ ocf_data_sampler/select/select_spatial_slice.py,sha256=liAqIa-Amj58pOqx5r16i99HURj9oQ41j7gnPgRDQP4,8201
  ocf_data_sampler/select/select_time_slice.py,sha256=HeHbwZ0CP03x0-LaJtpbSdtpLufwVTR73p6wH6O_PS8,5513
- ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=Hvz0wHSWMYYamf2oHNiGlzJcM4cAH6pL_7ZEvIBL2dE,1882
- ocf_data_sampler/select/time_slice_for_dataset.py,sha256=1DN6VsWWdLvkpJxodZtBRDUgC4vJE2td_RP5J3ZqPNw,4268
  ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=ZkXm0IQEIzZUi8O-qJJz2PbJr9T4ZvutL424yRQUJhc,12878
- ocf_data_sampler/torch_datasets/datasets/site.py,sha256=j29cWPIcksRbge014MxR0_OgJqoskdki6KqvtoHtxpY,18023
- ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=7Yt4anQVU9y27nj4Wx1tRLqbAQLbzW0ED71UL65LvxA,187
- ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py,sha256=MGylKhXxXLQC2fYv-8L_GVoYhov3LcEwC0Q21xItDSk,353
+ ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=tx5Sg64eknhU6VIcONiAaG2PurN6Y8Te6rE3AaWg8t4,12338
+ ocf_data_sampler/torch_datasets/datasets/site.py,sha256=nRUlhXQQGVrTuBmE1QnwXAUsPTXz0dsezlQjwK71jIQ,17641
+ ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
+ ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
+ ocf_data_sampler/torch_datasets/sample/site.py,sha256=ZUEgn50g-GmqujOEtezNILF7wjokF80sDAA4OOldcRI,1268
+ ocf_data_sampler/torch_datasets/sample/uk_regional.py,sha256=zpCeUw3eljOnoJTSUYW2R4kiWrY6hbuXjK8igJrXgPg,2441
+ ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=N7i_hHtWUDiJqsiJoDx4T_QuiYOuvIyulPrn6xEA4TY,309
+ ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py,sha256=un2IiyoAmTDIymdeMiPU899_86iCDMD-oIifjHlNyqw,555
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=we7BTxRH7B7jKayDT7YfNyfI3zZClz2Bk-HXKQIokgU,956
+ ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py,sha256=Hvz0wHSWMYYamf2oHNiGlzJcM4cAH6pL_7ZEvIBL2dE,1882
+ ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py,sha256=1DN6VsWWdLvkpJxodZtBRDUgC4vJE2td_RP5J3ZqPNw,4268
  ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=xcy75cVxl0WrglnX5YUAFjXXlO2GwEBHWyqo8TDuiOA,4714
  scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
  utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
- ocf_data_sampler-0.2.8.dist-info/METADATA,sha256=vfGBLPNsG5G5dPeZmdt0H38EK1LIQexvh2-BEwmi2dc,11594
- ocf_data_sampler-0.2.8.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- ocf_data_sampler-0.2.8.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
- ocf_data_sampler-0.2.8.dist-info/RECORD,,
+ ocf_data_sampler-0.2.10.dist-info/METADATA,sha256=CEhASIN7vsyVYY8ZQzIXbdZrI3VhDip_gX7Hwct2p-M,11625
+ ocf_data_sampler-0.2.10.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ ocf_data_sampler-0.2.10.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
+ ocf_data_sampler-0.2.10.dist-info/RECORD,,
@@ -1,3 +0,0 @@
- from ocf_data_sampler.sample.base import SampleBase
- from ocf_data_sampler.sample.uk_regional import UKRegionalSample
- from ocf_data_sampler.sample.site import SiteSample