ocf-data-sampler 0.0.18__py3-none-any.whl → 0.0.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ocf-data-sampler might be problematic.

Files changed (64)
  1. ocf_data_sampler/config/__init__.py +5 -0
  2. ocf_data_sampler/config/load.py +33 -0
  3. ocf_data_sampler/config/model.py +246 -0
  4. ocf_data_sampler/config/save.py +73 -0
  5. ocf_data_sampler/constants.py +173 -0
  6. ocf_data_sampler/load/load_dataset.py +55 -0
  7. ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
  8. ocf_data_sampler/load/site.py +30 -0
  9. ocf_data_sampler/numpy_sample/__init__.py +8 -0
  10. ocf_data_sampler/numpy_sample/collate.py +77 -0
  11. ocf_data_sampler/numpy_sample/gsp.py +34 -0
  12. ocf_data_sampler/numpy_sample/nwp.py +42 -0
  13. ocf_data_sampler/numpy_sample/satellite.py +30 -0
  14. ocf_data_sampler/numpy_sample/site.py +30 -0
  15. ocf_data_sampler/{numpy_batch → numpy_sample}/sun_position.py +9 -10
  16. ocf_data_sampler/select/__init__.py +8 -1
  17. ocf_data_sampler/select/dropout.py +4 -3
  18. ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
  19. ocf_data_sampler/select/geospatial.py +160 -0
  20. ocf_data_sampler/select/location.py +62 -0
  21. ocf_data_sampler/select/select_spatial_slice.py +13 -16
  22. ocf_data_sampler/select/select_time_slice.py +24 -33
  23. ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
  24. ocf_data_sampler/select/time_slice_for_dataset.py +125 -0
  25. ocf_data_sampler/torch_datasets/__init__.py +2 -1
  26. ocf_data_sampler/torch_datasets/process_and_combine.py +131 -0
  27. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +19 -427
  28. ocf_data_sampler/torch_datasets/site.py +405 -0
  29. ocf_data_sampler/torch_datasets/valid_time_periods.py +116 -0
  30. ocf_data_sampler/utils.py +10 -0
  31. ocf_data_sampler-0.0.42.dist-info/METADATA +153 -0
  32. ocf_data_sampler-0.0.42.dist-info/RECORD +71 -0
  33. {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.42.dist-info}/WHEEL +1 -1
  34. {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.42.dist-info}/top_level.txt +1 -0
  35. scripts/refactor_site.py +50 -0
  36. tests/config/test_config.py +161 -0
  37. tests/config/test_save.py +37 -0
  38. tests/conftest.py +86 -1
  39. tests/load/test_load_gsp.py +15 -0
  40. tests/load/test_load_nwp.py +21 -0
  41. tests/load/test_load_satellite.py +17 -0
  42. tests/load/test_load_sites.py +14 -0
  43. tests/numpy_sample/test_collate.py +26 -0
  44. tests/numpy_sample/test_gsp.py +38 -0
  45. tests/numpy_sample/test_nwp.py +52 -0
  46. tests/numpy_sample/test_satellite.py +40 -0
  47. tests/numpy_sample/test_sun_position.py +81 -0
  48. tests/select/test_dropout.py +75 -0
  49. tests/select/test_fill_time_periods.py +28 -0
  50. tests/select/test_find_contiguous_time_periods.py +202 -0
  51. tests/select/test_location.py +67 -0
  52. tests/select/test_select_spatial_slice.py +154 -0
  53. tests/select/test_select_time_slice.py +272 -0
  54. tests/torch_datasets/conftest.py +18 -0
  55. tests/torch_datasets/test_process_and_combine.py +126 -0
  56. tests/torch_datasets/test_pvnet_uk_regional.py +59 -0
  57. tests/torch_datasets/test_site.py +129 -0
  58. ocf_data_sampler/numpy_batch/__init__.py +0 -7
  59. ocf_data_sampler/numpy_batch/gsp.py +0 -20
  60. ocf_data_sampler/numpy_batch/nwp.py +0 -33
  61. ocf_data_sampler/numpy_batch/satellite.py +0 -23
  62. ocf_data_sampler-0.0.18.dist-info/METADATA +0 -22
  63. ocf_data_sampler-0.0.18.dist-info/RECORD +0 -32
  64. {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.42.dist-info}/LICENSE +0 -0
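
Taken together, the file list shows the shape of the 0.0.42 refactor: the package drops its ocf_datapipes dependencies, renames numpy_batch to numpy_sample, adds a site dataset, and splits the previously monolithic pvnet_uk_regional.py across config, load, select, and torch_datasets modules. A minimal usage sketch of the refactored dataset, assuming only what the diff below shows (the config path is a placeholder):

    import pandas as pd
    from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset

    # "pvnet_config.yaml" is a hypothetical path; the config schema lives in
    # ocf_data_sampler/config/model.py
    dataset = PVNetUKRegionalDataset(
        "pvnet_config.yaml",
        start_time="2023-01-01",
        end_time="2023-12-31",
        gsp_ids=[1, 2, 3],  # new in 0.0.42: restrict sampling to a subset of GSPs
    )

    print(len(dataset))  # number of available (t0, GSP) pairs

    # Samples are now plain dicts of numpy arrays rather than NumpyBatch objects
    sample = dataset.get_sample(t0=pd.Timestamp("2023-06-01 12:00"), gsp_id=1)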
@@ -2,104 +2,20 @@
 
  import numpy as np
  import pandas as pd
+ import pkg_resources
  import xarray as xr
  from torch.utils.data import Dataset
- import pkg_resources
-
- from ocf_data_sampler.load.gsp import open_gsp
- from ocf_data_sampler.load.nwp import open_nwp
- from ocf_data_sampler.load.satellite import open_sat_data
-
- from ocf_data_sampler.select.find_contiguous_time_periods import (
-     find_contiguous_t0_periods, find_contiguous_t0_periods_nwp,
-     intersection_of_multiple_dataframes_of_periods,
- )
- from ocf_data_sampler.select.fill_time_periods import fill_time_periods
- from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
- from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
- from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
-
- from ocf_data_sampler.numpy_batch import (
-     convert_gsp_to_numpy_batch,
-     convert_nwp_to_numpy_batch,
-     convert_satellite_to_numpy_batch,
-     make_sun_position_numpy_batch,
- )
-
-
- from ocf_datapipes.config.model import Configuration
- from ocf_datapipes.config.load import load_yaml_configuration
- from ocf_datapipes.batch import BatchKey, NumpyBatch
-
- from ocf_datapipes.utils.location import Location
- from ocf_datapipes.utils.geospatial import osgb_to_lon_lat
-
- from ocf_datapipes.utils.consts import (
-     NWP_MEANS,
-     NWP_STDS,
- )
-
- from ocf_datapipes.training.common import concat_xr_time_utc, normalize_gsp
-
 
+ from ocf_data_sampler.config import Configuration, load_yaml_configuration
+ from ocf_data_sampler.load.load_dataset import get_dataset_dict
+ from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time
+ from ocf_data_sampler.utils import minutes
+ from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
+ from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
 
  xr.set_options(keep_attrs=True)
 
 
-
- def minutes(minutes: list[float]):
-     """Timedelta minutes
-
-     Args:
-         m: minutes
-     """
-     return pd.to_timedelta(minutes, unit="m")
-
-
- def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataArray]]:
-     """Construct dictionary of all of the input data sources
-
-     Args:
-         config: Configuration file
-     """
-
-     in_config = config.input_data
-
-     datasets_dict = {}
-
-     # Load GSP data unless the path is None
-     if in_config.gsp.gsp_zarr_path:
-         da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path).compute()
-
-         # Remove national GSP
-         datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
-
-     # Load NWP data if in config
-     if in_config.nwp:
-
-         datasets_dict["nwp"] = {}
-         for nwp_source, nwp_config in in_config.nwp.items():
-
-             da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
-
-             da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))
-
-             datasets_dict["nwp"][nwp_source] = da_nwp
-
-     # Load satellite data if in config
-     if in_config.satellite:
-         sat_config = config.input_data.satellite
-
-         da_sat = open_sat_data(sat_config.satellite_zarr_path)
-
-         da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))
-
-         datasets_dict["sat"] = da_sat
-
-     return datasets_dict
-
-
-
  def find_valid_t0_times(
      datasets_dict: dict,
      config: Configuration,
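
The deleted minutes helper now lives in ocf_data_sampler.utils (see ocf_data_sampler/utils.py +10 -0 in the file list) and is imported at the top of the module. Assuming the moved version matches the deleted body, it is a thin wrapper around pandas:

    import pandas as pd

    def minutes(minutes):
        """Convert minutes (a number, or a list of numbers) to pandas timedeltas."""
        return pd.to_timedelta(minutes, unit="m")

    minutes(30)            # Timedelta('0 days 00:30:00')
    minutes([-60, 0, 60])  # TimedeltaIndex(['-1 days +23:00:00', '0 days 00:00:00', '0 days 01:00:00'])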
@@ -107,96 +23,11 @@ def find_valid_t0_times(
      """Find the t0 times where all of the requested input data is available
 
      Args:
-         datasets_dict: A dictionary of input datasets
+         datasets_dict: A dictionary of input datasets
          config: Configuration file
      """
 
-     assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})
-
-     contiguous_time_periods: dict[str: pd.DataFrame] = {}  # Used to store contiguous time periods from each data source
-
-     if "nwp" in datasets_dict:
-         for nwp_key, nwp_config in config.input_data.nwp.items():
-
-             da = datasets_dict["nwp"][nwp_key]
-
-             if nwp_config.dropout_timedeltas_minutes is None:
-                 max_dropout = minutes(0)
-             else:
-                 max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes)))
-
-             if nwp_config.max_staleness_minutes is None:
-                 max_staleness = None
-             else:
-                 max_staleness = minutes(nwp_config.max_staleness_minutes)
-
-             # The last step of the forecast is lost if we have to diff channels
-             if len(nwp_config.nwp_accum_channels) > 0:
-                 end_buffer = minutes(nwp_config.time_resolution_minutes)
-             else:
-                 end_buffer = minutes(0)
-
-             # This is the max staleness we can use considering the max step of the input data
-             max_possible_staleness = (
-                 pd.Timedelta(da["step"].max().item())
-                 - minutes(nwp_config.forecast_minutes)
-                 - end_buffer
-             )
-
-             # Default to use max possible staleness unless specified in config
-             if max_staleness is None:
-                 max_staleness = max_possible_staleness
-             else:
-                 # Make sure the max acceptable staleness isn't longer than the max possible
-                 assert max_staleness <= max_possible_staleness
-
-             time_periods = find_contiguous_t0_periods_nwp(
-                 datetimes=pd.DatetimeIndex(da["init_time_utc"]),
-                 history_duration=minutes(nwp_config.history_minutes),
-                 max_staleness=max_staleness,
-                 max_dropout=max_dropout,
-             )
-
-             contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods
-
-     if "sat" in datasets_dict:
-         sat_config = config.input_data.satellite
-
-         time_periods = find_contiguous_t0_periods(
-             pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
-             sample_period_duration=minutes(sat_config.time_resolution_minutes),
-             history_duration=minutes(sat_config.history_minutes),
-             forecast_duration=minutes(sat_config.forecast_minutes),
-         )
-
-         contiguous_time_periods['sat'] = time_periods
-
-     if "gsp" in datasets_dict:
-         gsp_config = config.input_data.gsp
-
-         time_periods = find_contiguous_t0_periods(
-             pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
-             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-             history_duration=minutes(gsp_config.history_minutes),
-             forecast_duration=minutes(gsp_config.forecast_minutes),
-         )
-
-         contiguous_time_periods['gsp'] = time_periods
-
-     # just get the values (not the keys)
-     contiguous_time_periods_values = list(contiguous_time_periods.values())
-
-     # Find joint overlapping contiguous time periods
-     if len(contiguous_time_periods_values) > 1:
-         valid_time_periods = intersection_of_multiple_dataframes_of_periods(
-             contiguous_time_periods_values
-         )
-     else:
-         valid_time_periods = contiguous_time_periods_values[0]
-
-     # check there are some valid time periods
-     if len(valid_time_periods) == 0:
-         raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")
+     valid_time_periods = find_valid_time_periods(datasets_dict, config)
 
      # Fill out the contiguous time periods to get the t0 times
      valid_t0_times = fill_time_periods(
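
The NWP bookkeeping removed above (relocated to find_valid_time_periods in valid_time_periods.py) caps how stale a forecast may be at a given t0: the maximum usable staleness is the data's largest step minus the forecast horizon, minus one extra step when accumulated channels must be diffed. A worked sketch of that arithmetic, with hypothetical numbers:

    import pandas as pd

    def minutes(m):
        return pd.to_timedelta(m, unit="m")

    # Hypothetical NWP source: steps out to 48 h at 3 h resolution, an
    # 8 h forecast horizon, and at least one accumulated channel to diff.
    max_step = pd.Timedelta("48h")
    forecast_duration = minutes(8 * 60)
    end_buffer = minutes(180)  # last step lost to diffing accumulated channels

    max_possible_staleness = max_step - forecast_duration - end_buffer
    print(max_possible_staleness)  # 1 days 13:00:00; staler init-times cannot cover the horizon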
@@ -207,252 +38,12 @@ def find_valid_t0_times(
      return valid_t0_times
 
 
- def slice_datasets_by_space(
-     datasets_dict: dict,
-     location: Location,
-     config: Configuration,
- ) -> dict:
-     """Slice a dictionaries of input data sources around a given location
-
-     Args:
-         datasets_dict: Dictionary of the input data sources
-         location: The location to sample around
-         config: Configuration object.
-     """
-
-     assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})
-
-     sliced_datasets_dict = {}
-
-     if "nwp" in datasets_dict:
-
-         sliced_datasets_dict["nwp"] = {}
-
-         for nwp_key, nwp_config in config.input_data.nwp.items():
-
-             sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
-                 datasets_dict["nwp"][nwp_key],
-                 location,
-                 height_pixels=nwp_config.nwp_image_size_pixels_height,
-                 width_pixels=nwp_config.nwp_image_size_pixels_width,
-             )
-
-     if "sat" in datasets_dict:
-         sat_config = config.input_data.satellite
-
-         sliced_datasets_dict["sat"] = select_spatial_slice_pixels(
-             datasets_dict["sat"],
-             location,
-             height_pixels=sat_config.satellite_image_size_pixels_height,
-             width_pixels=sat_config.satellite_image_size_pixels_width,
-         )
-
-     if "gsp" in datasets_dict:
-         sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)
-
-     return sliced_datasets_dict
-
-
- def slice_datasets_by_time(
-     datasets_dict: dict,
-     t0: pd.Timedelta,
-     config: Configuration,
- ) -> dict:
-     """Slice a dictionaries of input data sources around a given t0 time
-
-     Args:
-         datasets_dict: Dictionary of the input data sources
-         t0: The init-time
-         config: Configuration object.
-     """
-
-     sliced_datasets_dict = {}
-
-     if "nwp" in datasets_dict:
-
-         sliced_datasets_dict["nwp"] = {}
-
-         for nwp_key, da_nwp in datasets_dict["nwp"].items():
-
-             nwp_config = config.input_data.nwp[nwp_key]
-
-             sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
-                 da_nwp,
-                 t0,
-                 sample_period_duration=minutes(nwp_config.time_resolution_minutes),
-                 history_duration=minutes(nwp_config.history_minutes),
-                 forecast_duration=minutes(nwp_config.forecast_minutes),
-                 dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
-                 dropout_frac=nwp_config.dropout_fraction,
-                 accum_channels=nwp_config.nwp_accum_channels,
-             )
-
-     if "sat" in datasets_dict:
-
-         sat_config = config.input_data.satellite
-
-         sliced_datasets_dict["sat"] = select_time_slice(
-             datasets_dict["sat"],
-             t0,
-             sample_period_duration=minutes(sat_config.time_resolution_minutes),
-             interval_start=minutes(-sat_config.history_minutes),
-             interval_end=minutes(-sat_config.live_delay_minutes),
-             max_steps_gap=2,
-         )
-
-         # Randomly sample dropout
-         sat_dropout_time = draw_dropout_time(
-             t0,
-             dropout_timedeltas=minutes(sat_config.dropout_timedeltas_minutes),
-             dropout_frac=sat_config.dropout_fraction,
-         )
-
-         # Apply the dropout
-         sliced_datasets_dict["sat"] = apply_dropout_time(
-             sliced_datasets_dict["sat"],
-             sat_dropout_time,
-         )
-
-     if "gsp" in datasets_dict:
-         gsp_config = config.input_data.gsp
-
-         sliced_datasets_dict["gsp_future"] = select_time_slice(
-             datasets_dict["gsp"],
-             t0,
-             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-             interval_start=minutes(30),
-             interval_end=minutes(gsp_config.forecast_minutes),
-         )
-
-         sliced_datasets_dict["gsp"] = select_time_slice(
-             datasets_dict["gsp"],
-             t0,
-             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-             interval_start=-minutes(gsp_config.history_minutes),
-             interval_end=minutes(0),
-         )
-
-         # Dropout on the GSP, but not the future GSP
-         gsp_dropout_time = draw_dropout_time(
-             t0,
-             dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
-             dropout_frac=gsp_config.dropout_fraction,
-         )
-
-         sliced_datasets_dict["gsp"] = apply_dropout_time(sliced_datasets_dict["gsp"], gsp_dropout_time)
-
-     return sliced_datasets_dict
-
-
- def fill_nans_in_arrays(batch: NumpyBatch) -> NumpyBatch:
-     """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
-
-     Operation is performed in-place on the batch.
-     """
-     for k, v in batch.items():
-         if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
-             if np.isnan(v).any():
-                 batch[k] = np.nan_to_num(v, copy=False, nan=0.0)
-
-         # Recursion is included to reach NWP arrays in subdict
-         elif isinstance(v, dict):
-             fill_nans_in_arrays(v)
-
-     return batch
-
-
- def merge_dicts(list_of_dicts: list[dict]) -> dict:
-     """Merge a list of dictionaries into a single dictionary"""
-     # TODO: This doesn't account for duplicate keys, which will be overwritten
-     combined_dict = {}
-     for d in list_of_dicts:
-         combined_dict.update(d)
-     return combined_dict
-
-
- def process_and_combine_datasets(
-     dataset_dict: dict,
-     config: Configuration,
-     t0: pd.Timedelta,
-     location: Location,
- ) -> NumpyBatch:
-     """Normalize and convert data to numpy arrays"""
-
-     numpy_modalities = []
-
-     if "nwp" in dataset_dict:
-
-         nwp_numpy_modalities = dict()
-
-         for nwp_key, da_nwp in dataset_dict["nwp"].items():
-             # Standardise
-             provider = config.input_data.nwp[nwp_key].nwp_provider
-             da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
-             # Convert to NumpyBatch
-             nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
-
-         # Combine the NWPs into NumpyBatch
-         numpy_modalities.append({BatchKey.nwp: nwp_numpy_modalities})
-
-     if "sat" in dataset_dict:
-         # Satellite is already in the range [0-1] so no need to standardise
-         da_sat = dataset_dict["sat"]
-
-         # Convert to NumpyBatch
-         numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
-
-     gsp_config = config.input_data.gsp
+ def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
+     """Get list of locations of all GSPs"""
 
-     if "gsp" in dataset_dict:
-         da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
-         da_gsp = normalize_gsp(da_gsp)
-
-         numpy_modalities.append(
-             convert_gsp_to_numpy_batch(
-                 da_gsp,
-                 t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes
-             )
-         )
-
-     # Make sun coords NumpyBatch
-     datetimes = pd.date_range(
-         t0-minutes(gsp_config.history_minutes),
-         t0+minutes(gsp_config.forecast_minutes),
-         freq=minutes(gsp_config.time_resolution_minutes),
-     )
-
-     lon, lat = osgb_to_lon_lat(location.x, location.y)
-
-     numpy_modalities.append(make_sun_position_numpy_batch(datetimes, lon, lat))
+     if gsp_ids is None:
+         gsp_ids = [i for i in range(1, 318)]
 
-     # Add coordinate data
-     # TODO: Do we need all of these?
-     numpy_modalities.append({
-         BatchKey.gsp_id: location.id,
-         BatchKey.gsp_x_osgb: location.x,
-         BatchKey.gsp_y_osgb: location.y,
-     })
-
-     # Combine all the modalities and fill NaNs
-     combined_sample = merge_dicts(numpy_modalities)
-     combined_sample = fill_nans_in_arrays(combined_sample)
-
-     return combined_sample
-
-
- def compute(xarray_dict: dict) -> dict:
-     """Eagerly load a nested dictionary of xarray DataArrays"""
-     for k, v in xarray_dict.items():
-         if isinstance(v, dict):
-             xarray_dict[k] = compute(v)
-         else:
-             xarray_dict[k] = v.compute(scheduler="single-threaded")
-     return xarray_dict
-
-
- def get_gsp_locations() -> list[Location]:
-     """Get list of locations of all GSPs"""
 
      locations = []
      # Load UK GSP locations
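
Everything deleted in this hunk resurfaces elsewhere in the package: the slicing helpers in spatial_slice_for_dataset.py and time_slice_for_dataset.py (re-exported from ocf_data_sampler.select), and fill_nans_in_arrays, merge_dicts, process_and_combine_datasets, and compute in torch_datasets/process_and_combine.py. Judging from the new imports at the top of the module, the per-sample flow composes them roughly as follows; this is a sketch, not the literal _get_sample body:

    from ocf_data_sampler.select import slice_datasets_by_space, slice_datasets_by_time
    from ocf_data_sampler.torch_datasets.process_and_combine import (
        compute,
        process_and_combine_datasets,
    )

    def make_sample(config, datasets_dict, t0, location) -> dict:
        """Hypothetical composition of the relocated helpers.

        datasets_dict is the output of get_dataset_dict(config), typically
        built once when the Dataset is constructed.
        """
        sample_dict = slice_datasets_by_space(datasets_dict, location, config)
        sample_dict = slice_datasets_by_time(sample_dict, t0, config)
        sample_dict = compute(sample_dict)  # eagerly load the lazy xarray slices
        return process_and_combine_datasets(sample_dict, config, t0, location)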
@@ -461,7 +52,7 @@ def get_gsp_locations() -> list[Location]:
          index_col="gsp_id",
      )
 
-     for gsp_id in np.arange(1, 318):
+     for gsp_id in gsp_ids:
          locations.append(
              Location(
                  coordinate_system = "osgb",
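
With gsp_ids threaded through get_gsp_locations (and the dataset constructor in the next hunk), sampling can be restricted to a subset of the 317 UK GSPs. A short sketch of the expected behaviour:

    from ocf_data_sampler.torch_datasets.pvnet_uk_regional import get_gsp_locations

    locs_all = get_gsp_locations()           # GSP IDs 1..317, the previous behaviour
    locs_some = get_gsp_locations([12, 13])  # new: only the requested GSPs

    assert [loc.id for loc in locs_some] == [12, 13]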
@@ -473,13 +64,13 @@ def get_gsp_locations() -> list[Location]:
      return locations
 
 
-
  class PVNetUKRegionalDataset(Dataset):
      def __init__(
          self,
          config_filename: str,
          start_time: str | None = None,
          end_time: str| None = None,
+         gsp_ids: list[int] | None = None,
      ):
          """A torch Dataset for creating PVNet UK GSP samples
 
@@ -487,6 +78,7 @@ class PVNetUKRegionalDataset(Dataset):
              config_filename: Path to the configuration file
              start_time: Limit the init-times to be after this
              end_time: Limit the init-times to be before this
+             gsp_ids: List of GSP IDs to create samples for. Defaults to all
          """
 
          config = load_yaml_configuration(config_filename)
@@ -504,7 +96,7 @@ class PVNetUKRegionalDataset(Dataset):
              valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
 
          # Construct list of locations to sample from
-         locations = get_gsp_locations()
+         locations = get_gsp_locations(gsp_ids)
 
          # Construct a lookup for locations - useful for users to construct sample by GSP ID
          location_lookup = {loc.id: loc for loc in locations}
@@ -533,7 +125,7 @@ class PVNetUKRegionalDataset(Dataset):
          return len(self.index_pairs)
 
 
-     def _get_sample(self, t0: pd.Timestamp, location: Location) -> NumpyBatch:
+     def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
          """Generate the PVNet sample for given coordinates
 
          Args:
@@ -560,7 +152,7 @@ class PVNetUKRegionalDataset(Dataset):
          return self._get_sample(t0, location)
 
 
-     def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpyBatch:
+     def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
          """Generate a sample for the given coordinates.
 
          Useful for users to generate samples by GSP ID.
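
Both public sample methods now return a plain dict instead of ocf_datapipes' NumpyBatch, so downstream code indexes samples by key rather than by BatchKey enum members. Continuing the earlier sketch (the actual key names are not shown in this diff):

    # 0.0.18: from ocf_datapipes.batch import BatchKey; sample[BatchKey.gsp_id]
    # 0.0.42: samples are plain dicts of numpy arrays
    sample = dataset.get_sample(t0=pd.Timestamp("2023-06-01 12:00"), gsp_id=1)
    for key, value in sample.items():
        print(key, getattr(value, "shape", value))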