ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic.

Files changed (78)
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +146 -64
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/load/gsp.py +6 -5
  5. ocf_data_sampler/load/load_dataset.py +5 -6
  6. ocf_data_sampler/load/nwp/nwp.py +17 -5
  7. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  8. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  9. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  10. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  11. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  12. ocf_data_sampler/load/satellite.py +9 -10
  13. ocf_data_sampler/load/site.py +10 -6
  14. ocf_data_sampler/load/utils.py +21 -16
  15. ocf_data_sampler/numpy_sample/collate.py +10 -9
  16. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  17. ocf_data_sampler/numpy_sample/gsp.py +12 -14
  18. ocf_data_sampler/numpy_sample/nwp.py +12 -12
  19. ocf_data_sampler/numpy_sample/satellite.py +9 -9
  20. ocf_data_sampler/numpy_sample/site.py +5 -8
  21. ocf_data_sampler/numpy_sample/sun_position.py +16 -21
  22. ocf_data_sampler/sample/base.py +15 -17
  23. ocf_data_sampler/sample/site.py +13 -20
  24. ocf_data_sampler/sample/uk_regional.py +29 -35
  25. ocf_data_sampler/select/dropout.py +16 -14
  26. ocf_data_sampler/select/fill_time_periods.py +15 -5
  27. ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
  28. ocf_data_sampler/select/geospatial.py +63 -54
  29. ocf_data_sampler/select/location.py +16 -51
  30. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  31. ocf_data_sampler/select/select_time_slice.py +71 -58
  32. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  33. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  34. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +140 -131
  35. ocf_data_sampler/torch_datasets/datasets/site.py +152 -112
  36. ocf_data_sampler/torch_datasets/utils/__init__.py +3 -0
  37. ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/utils.py +3 -1
  41. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/METADATA +7 -18
  42. ocf_data_sampler-0.1.17.dist-info/RECORD +56 -0
  43. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/WHEEL +1 -1
  44. {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/top_level.txt +1 -1
  45. scripts/refactor_site.py +63 -33
  46. utils/compute_icon_mean_stddev.py +72 -0
  47. ocf_data_sampler/constants.py +0 -222
  48. ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -82
  49. ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
  50. ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
  51. tests/__init__.py +0 -0
  52. tests/config/test_config.py +0 -113
  53. tests/config/test_load.py +0 -7
  54. tests/config/test_save.py +0 -28
  55. tests/conftest.py +0 -319
  56. tests/load/test_load_gsp.py +0 -15
  57. tests/load/test_load_nwp.py +0 -21
  58. tests/load/test_load_satellite.py +0 -17
  59. tests/load/test_load_sites.py +0 -14
  60. tests/numpy_sample/test_collate.py +0 -21
  61. tests/numpy_sample/test_datetime_features.py +0 -37
  62. tests/numpy_sample/test_gsp.py +0 -38
  63. tests/numpy_sample/test_nwp.py +0 -13
  64. tests/numpy_sample/test_satellite.py +0 -40
  65. tests/numpy_sample/test_sun_position.py +0 -81
  66. tests/select/test_dropout.py +0 -69
  67. tests/select/test_fill_time_periods.py +0 -28
  68. tests/select/test_find_contiguous_time_periods.py +0 -202
  69. tests/select/test_location.py +0 -67
  70. tests/select/test_select_spatial_slice.py +0 -154
  71. tests/select/test_select_time_slice.py +0 -275
  72. tests/test_sample/test_base.py +0 -164
  73. tests/test_sample/test_site_sample.py +0 -165
  74. tests/test_sample/test_uk_regional_sample.py +0 -136
  75. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  76. tests/torch_datasets/test_pvnet_uk.py +0 -154
  77. tests/torch_datasets/test_site.py +0 -226
  78. tests/torch_datasets/test_validate_channels_utils.py +0 -78
--- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py (0.1.11)
+++ ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py (0.1.17)
@@ -1,41 +1,37 @@
-"""Torch dataset for UK PVNet"""
+"""Torch dataset for UK PVNet."""
 
-import pkg_resources
+from importlib.resources import files
 
 import numpy as np
 import pandas as pd
 import xarray as xr
 from torch.utils.data import Dataset
+from typing_extensions import override
+
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
-from ocf_data_sampler.select import (
-    fill_time_periods,
-    Location,
-    slice_datasets_by_space,
-    slice_datasets_by_time,
-)
-from ocf_data_sampler.utils import minutes
-from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
 from ocf_data_sampler.numpy_sample import (
+    convert_gsp_to_numpy_sample,
     convert_nwp_to_numpy_sample,
     convert_satellite_to_numpy_sample,
-    convert_gsp_to_numpy_sample,
     make_sun_position_numpy_sample,
 )
+from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
 from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
 from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
-from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
+from ocf_data_sampler.select import (
+    Location,
+    fill_time_periods,
+    slice_datasets_by_space,
+    slice_datasets_by_time,
+)
 from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
-from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
+from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray, find_valid_time_periods
 from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
-    merge_dicts,
     fill_nans_in_arrays,
+    merge_dicts,
 )
-from ocf_data_sampler.torch_datasets.utils.validate_channels import (
-    validate_nwp_channels,
-    validate_satellite_channels,
-)
-
+from ocf_data_sampler.utils import minutes
 
 xr.set_options(keep_attrs=True)
 
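
The new typing_extensions.override import is applied to the dunder methods of both dataset classes further down. As a quick illustration of the decorator (PEP 698), using a made-up base class rather than anything from this package: at runtime it is a no-op, and the override check is performed by static type checkers such as mypy or pyright.

    from typing_extensions import override

    class Base:
        def describe(self) -> str:
            return "base"

    class Child(Base):
        @override  # a type checker verifies that Base really defines describe()
        def describe(self) -> str:
            return "child"
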
@@ -45,20 +41,26 @@ def process_and_combine_datasets(
     config: Configuration,
     t0: pd.Timestamp,
     location: Location,
-    target_key: str = 'gsp'
 ) -> dict:
-
-    """Normalise and convert data to numpy arrays"""
+    """Normalise and convert data to numpy arrays."""
     numpy_modalities = []
 
     if "nwp" in dataset_dict:
-        nwp_numpy_modalities = dict()
+        nwp_numpy_modalities = {}
 
         for nwp_key, da_nwp in dataset_dict["nwp"].items():
-            provider = config.input_data.nwp[nwp_key].provider
 
             # Standardise and convert to NumpyBatch
-            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+
+            da_channel_means = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_means,
+            )
+            da_channel_stds = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_stds,
+            )
+
+            da_nwp = (da_nwp - da_channel_means) / da_channel_stds
+
             nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
 
         # Combine the NWPs into NumpyBatch
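
The normalisation constants formerly imported from ocf_data_sampler.constants (NWP_MEANS, NWP_STDS) now come from per-config channel_means/channel_stds dictionaries, converted with the new channel_dict_to_dataarray helper (added in ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py, 11 lines in the file list above). Below is a plausible sketch of such a helper, an assumption about the implementation rather than a copy of it; the point is that a DataArray indexed by a "channel" coordinate broadcasts per channel in the (da - means) / stds expression.

    import numpy as np
    import xarray as xr

    def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
        # Hypothetical reimplementation: map {channel: value} to a 1-D
        # DataArray with a "channel" coordinate so arithmetic broadcasts.
        return xr.DataArray(
            list(channel_dict.values()),
            coords={"channel": list(channel_dict.keys())},
            dims=["channel"],
        )

    # Example with a fake (time, channel) array and per-channel stats
    da = xr.DataArray(
        np.random.rand(4, 2),
        coords={"time": range(4), "channel": ["t2m", "dswrf"]},
        dims=["time", "channel"],
    )
    da_means = channel_dict_to_dataarray({"t2m": 0.5, "dswrf": 0.2})
    da_stds = channel_dict_to_dataarray({"t2m": 0.1, "dswrf": 0.3})
    da_norm = (da - da_means) / da_stds  # normalised channel-wise
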
@@ -68,44 +70,57 @@ def process_and_combine_datasets(
         da_sat = dataset_dict["sat"]
 
         # Standardise and convert to NumpyBatch
-        da_sat = (da_sat - RSS_MEAN) / RSS_STD
-        numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
+        da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
+        da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
 
-    gsp_config = config.input_data.gsp
+        da_sat = (da_sat - da_channel_means) / da_channel_stds
+
+        numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
 
     if "gsp" in dataset_dict:
+        gsp_config = config.input_data.gsp
         da_gsp = dataset_dict["gsp"]
         da_gsp = da_gsp / da_gsp.effective_capacity_mwp
-
+
         # Convert to NumpyBatch
         numpy_modalities.append(
             convert_gsp_to_numpy_sample(
-                da_gsp,
-                t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
-            )
+                da_gsp,
+                t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes,
+            ),
         )
 
-    if target_key == 'gsp':
-        # Make sun coords NumpySample
+    # Add GSP location data
+    numpy_modalities.append(
+        {
+            GSPSampleKey.gsp_id: location.id,
+            GSPSampleKey.x_osgb: location.x,
+            GSPSampleKey.y_osgb: location.y,
+        },
+    )
+
+    # Only add solar position if explicitly configured
+    has_solar_config = (
+        hasattr(config.input_data, "solar_position") and
+        config.input_data.solar_position is not None
+    )
+
+    if has_solar_config:
+        solar_config = config.input_data.solar_position
+
+        # Create datetime range for solar position calculation
         datetimes = pd.date_range(
-            t0+minutes(gsp_config.interval_start_minutes),
-            t0+minutes(gsp_config.interval_end_minutes),
-            freq=minutes(gsp_config.time_resolution_minutes),
+            t0 + minutes(solar_config.interval_start_minutes),
+            t0 + minutes(solar_config.interval_end_minutes),
+            freq=minutes(solar_config.time_resolution_minutes),
         )
 
+        # Convert OSGB coordinates to lon/lat
         lon, lat = osgb_to_lon_lat(location.x, location.y)
 
-        numpy_modalities.append(
-            {
-                GSPSampleKey.gsp_id: location.id,
-                GSPSampleKey.x_osgb: location.x,
-                GSPSampleKey.y_osgb: location.y,
-            }
-        )
-
-        numpy_modalities.append(
-            make_sun_position_numpy_sample(datetimes, lon, lat, key_prefix=target_key)
-        )
+        # Calculate solar positions and add to modalities
+        solar_positions = make_sun_position_numpy_sample(datetimes, lon, lat)
+        numpy_modalities.append(solar_positions)
 
     # Combine all the modalities and fill NaNs
     combined_sample = merge_dicts(numpy_modalities)
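
Solar position is now opt-in: it is only computed when config.input_data.solar_position is set, and the datetime grid uses the solar section's own interval fields rather than the GSP ones. A minimal sketch of the datetime construction, with illustrative interval values and assuming the package's minutes helper simply wraps pd.Timedelta:

    import pandas as pd

    def minutes(m: float) -> pd.Timedelta:
        # assumed behaviour of ocf_data_sampler.utils.minutes
        return pd.Timedelta(minutes=m)

    t0 = pd.Timestamp("2024-06-01 12:00")
    interval_start_minutes = -60   # illustrative values, not from any real config
    interval_end_minutes = 120
    time_resolution_minutes = 30

    datetimes = pd.date_range(
        t0 + minutes(interval_start_minutes),
        t0 + minutes(interval_end_minutes),
        freq=minutes(time_resolution_minutes),
    )
    # 7 timestamps, 11:00 through 14:00 at 30-minute steps, which are then
    # passed to make_sun_position_numpy_sample(datetimes, lon, lat)
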
@@ -115,7 +130,7 @@ def process_and_combine_datasets(
 
 
 def compute(xarray_dict: dict) -> dict:
-    """Eagerly load a nested dictionary of xarray DataArrays"""
+    """Eagerly load a nested dictionary of xarray DataArrays."""
     for k, v in xarray_dict.items():
         if isinstance(v, dict):
             xarray_dict[k] = compute(v)
@@ -125,59 +140,58 @@ def compute(xarray_dict: dict) -> dict:
 
 
 def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
-    """Find the t0 times where all of the requested input data is available
+    """Find the t0 times where all of the requested input data is available.
 
     Args:
         datasets_dict: A dictionary of input datasets
         config: Configuration file
     """
-
     valid_time_periods = find_valid_time_periods(datasets_dict, config)
 
     # Fill out the contiguous time periods to get the t0 times
     valid_t0_times = fill_time_periods(
-        valid_time_periods,
-        freq=minutes(config.input_data.gsp.time_resolution_minutes)
+        valid_time_periods,
+        freq=minutes(config.input_data.gsp.time_resolution_minutes),
     )
-
     return valid_t0_times
 
 
 def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
-    """Get list of locations of all GSPs"""
-
+    """Get list of locations of all GSPs."""
     if gsp_ids is None:
-        gsp_ids = [i for i in range(1, 318)]
-
+        gsp_ids = list(range(1, 318))
+
     locations = []
 
     # Load UK GSP locations
     df_gsp_loc = pd.read_csv(
-        pkg_resources.resource_filename(__name__, "../../data/uk_gsp_locations.csv"),
+        files("ocf_data_sampler.data").joinpath("uk_gsp_locations.csv"),
        index_col="gsp_id",
    )
 
     for gsp_id in gsp_ids:
         locations.append(
             Location(
-                coordinate_system = "osgb",
+                coordinate_system="osgb",
                 x=df_gsp_loc.loc[gsp_id].x_osgb,
                 y=df_gsp_loc.loc[gsp_id].y_osgb,
                 id=gsp_id,
-            )
+            ),
         )
     return locations
 
 
 class PVNetUKRegionalDataset(Dataset):
+    """A torch Dataset for creating PVNet UK regional samples."""
+
     def __init__(
-        self,
-        config_filename: str,
+        self,
+        config_filename: str,
         start_time: str | None = None,
         end_time: str | None = None,
         gsp_ids: list[int] | None = None,
-    ):
-        """A torch Dataset for creating PVNet UK GSP samples
+    ) -> None:
+        """A torch Dataset for creating PVNet UK GSP samples.
 
         Args:
             config_filename: Path to the configuration file
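
The get_gsp_locations hunk above also completes the move off the deprecated pkg_resources API seen in the imports: the bundled GSP locations CSV is now addressed as a package resource instead of a path relative to the module. Side by side, quoting the two lines from the diff (running the snippet assumes the ocf_data_sampler package is installed):

    import pkg_resources                   # old: deprecated setuptools API
    from importlib.resources import files  # new: stdlib since Python 3.9

    # old
    path = pkg_resources.resource_filename(__name__, "../../data/uk_gsp_locations.csv")
    # new
    path = files("ocf_data_sampler.data").joinpath("uk_gsp_locations.csv")
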
@@ -185,31 +199,28 @@ class PVNetUKRegionalDataset(Dataset):
             end_time: Limit the init-times to be before this
             gsp_ids: List of GSP IDs to create samples for. Defaults to all
         """
-
         # config = load_yaml_configuration(config_filename)
         config: Configuration = load_yaml_configuration(config_filename)
-        validate_nwp_channels(config)
-        validate_satellite_channels(config)
 
         datasets_dict = get_dataset_dict(config.input_data)
-
+
         # Get t0 times where all input data is available
         valid_t0_times = find_valid_t0_times(datasets_dict, config)
 
         # Filter t0 times to given range
         if start_time is not None:
-            valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
-
+            valid_t0_times = valid_t0_times[valid_t0_times >= pd.Timestamp(start_time)]
+
         if end_time is not None:
-            valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
+            valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
 
         # Construct list of locations to sample from
         locations = get_gsp_locations(gsp_ids)
 
         # Construct a lookup for locations - useful for users to construct sample by GSP ID
         location_lookup = {loc.id: loc for loc in locations}
-
-        # Construct indices for sampling
+
+        # Construct indices for sampling
         t_index, loc_index = np.meshgrid(
             np.arange(len(valid_t0_times)),
             np.arange(len(locations)),
@@ -217,7 +228,7 @@ class PVNetUKRegionalDataset(Dataset):
 
         # Make array of all possible (t0, location) coordinates. Each row is a single coordinate
         index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T
-
+
         # Assign coords and indices to self
         self.valid_t0_times = valid_t0_times
         self.locations = locations
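
The meshgrid construction above enumerates every (t0, location) pair; a small worked example makes the resulting index_pairs layout concrete:

    import numpy as np

    # 3 valid t0 times and 2 locations -> 6 samples
    t_index, loc_index = np.meshgrid(np.arange(3), np.arange(2))
    index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T
    print(index_pairs.shape)  # (6, 2): each row is (t0 index, location index)
    print(index_pairs.tolist())
    # [[0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1]]
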
@@ -227,15 +238,14 @@ class PVNetUKRegionalDataset(Dataset):
         # Assign config and input data to self
         self.datasets_dict = datasets_dict
         self.config = config
-
-
-    def __len__(self):
+
+    @override
+    def __len__(self) -> int:
         return len(self.index_pairs)
-
-
+
     def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
-        """Generate the PVNet sample for given coordinates
-
+        """Generate the PVNet sample for given coordinates.
+
         Args:
             t0: init-time for sample
             location: location for sample
@@ -245,49 +255,51 @@ class PVNetUKRegionalDataset(Dataset):
         sample_dict = compute(sample_dict)
 
         sample = process_and_combine_datasets(sample_dict, self.config, t0, location)
-
+
         return sample
-
-
-    def __getitem__(self, idx):
-
+
+    @override
+    def __getitem__(self, idx: int) -> dict:
         # Get the coordinates of the sample
         t_index, loc_index = self.index_pairs[idx]
         location = self.locations[loc_index]
         t0 = self.valid_t0_times[t_index]
-
+
         # Generate the sample
         return self._get_sample(t0, location)
-
 
     def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
-        """Generate a sample for the given coordinates.
-
+        """Generate a sample for the given coordinates.
+
         Useful for users to generate specific samples.
-
+
         Args:
             t0: init-time for sample
             gsp_id: GSP ID
         """
         # Check the user has asked for a sample which we have the data for
-        assert t0 in self.valid_t0_times
-        assert gsp_id in self.location_lookup
+        if t0 not in self.valid_t0_times:
+            raise ValueError(f"Input init time '{t0!s}' not in valid times")
+        if gsp_id not in self.location_lookup:
+            raise ValueError(f"Input GSP '{gsp_id}' not known")
 
         location = self.location_lookup[gsp_id]
-
+
         return self._get_sample(t0, location)
-
-
+
+
 class PVNetUKConcurrentDataset(Dataset):
+    """A torch Dataset for creating concurrent PVNet UK regional samples."""
+
     def __init__(
-        self,
-        config_filename: str,
+        self,
+        config_filename: str,
         start_time: str | None = None,
         end_time: str | None = None,
         gsp_ids: list[int] | None = None,
-    ):
-        """A torch Dataset for creating concurrent samples of PVNet UK regional data
-
+    ) -> None:
+        """A torch Dataset for creating concurrent samples of PVNet UK regional data.
+
         Each concurrent sample includes the data from all GSPs for a single t0 time
 
         Args:
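
Taken together, the PVNetUKRegionalDataset changes are docstrings, type annotations, @override markers, and ValueError in place of bare asserts in get_sample. A hedged usage sketch, assuming a valid YAML configuration at the hypothetical path below together with the input data it references:

    from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset

    dataset = PVNetUKRegionalDataset(
        "pvnet_config.yaml",          # hypothetical config path
        start_time="2023-01-01",
        end_time="2023-12-31",
        gsp_ids=[1, 2, 3],
    )
    print(len(dataset))               # (number of valid t0 times) x (number of GSPs)
    sample = dataset[0]               # dict of numpy arrays for one (t0, GSP) pair

    # Unknown t0 or GSP ID now raises ValueError instead of AssertionError
    sample = dataset.get_sample(t0=dataset.valid_t0_times[0], gsp_id=1)
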
@@ -296,28 +308,23 @@ class PVNetUKConcurrentDataset(Dataset):
             end_time: Limit the init-times to be before this
             gsp_ids: List of all GSP IDs included in each sample. Defaults to all
         """
-
         config = load_yaml_configuration(config_filename)
 
-        # Validate channels for NWP and satellite data
-        validate_nwp_channels(config)
-        validate_satellite_channels(config)
-
         datasets_dict = get_dataset_dict(config.input_data)
-
+
         # Get t0 times where all input data is available
         valid_t0_times = find_valid_t0_times(datasets_dict, config)
 
         # Filter t0 times to given range
         if start_time is not None:
-            valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
-
+            valid_t0_times = valid_t0_times[valid_t0_times >= pd.Timestamp(start_time)]
+
         if end_time is not None:
-            valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
+            valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
 
         # Construct list of locations to sample from
         locations = get_gsp_locations(gsp_ids)
-
+
         # Assign coords and indices to self
         self.valid_t0_times = valid_t0_times
         self.locations = locations
@@ -325,48 +332,50 @@ class PVNetUKConcurrentDataset(Dataset):
         # Assign config and input data to self
         self.datasets_dict = datasets_dict
         self.config = config
-
-
-    def __len__(self):
+
+    @override
+    def __len__(self) -> int:
         return len(self.valid_t0_times)
-
-
+
     def _get_sample(self, t0: pd.Timestamp) -> dict:
-        """Generate a concurrent PVNet sample for given init-time
-
+        """Generate a concurrent PVNet sample for given init-time.
+
         Args:
             t0: init-time for sample
         """
         # Slice by time then load to avoid loading the data multiple times from disk
         sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
         sample_dict = compute(sample_dict)
-
+
         gsp_samples = []
-
+
         # Prepare sample for each GSP
         for location in self.locations:
             gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
             gsp_numpy_sample = process_and_combine_datasets(
-                gsp_sample_dict, self.config, t0, location
+                gsp_sample_dict,
+                self.config,
+                t0,
+                location,
             )
             gsp_samples.append(gsp_numpy_sample)
-
+
         # Stack GSP samples
         return stack_np_samples_into_batch(gsp_samples)
-
-
-    def __getitem__(self, idx):
+
+    @override
+    def __getitem__(self, idx: int) -> dict:
         return self._get_sample(self.valid_t0_times[idx])
-
 
     def get_sample(self, t0: pd.Timestamp) -> dict:
-        """Generate a sample for the given init-time.
-
+        """Generate a sample for the given init-time.
+
         Useful for users to generate specific samples.
-
+
         Args:
             t0: init-time for sample
         """
-        # Check data is available for init-time t0
-        assert t0 in self.valid_t0_times
+        # Check data is available for init-time t0
+        if t0 not in self.valid_t0_times:
+            raise ValueError(f"Input init time '{t0!s}' not in valid times")
         return self._get_sample(t0)
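
And the concurrent counterpart, which yields one batch per init-time with all requested GSPs stacked via stack_np_samples_into_batch (same hypothetical config assumption as above):

    from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset

    dataset = PVNetUKConcurrentDataset("pvnet_config.yaml", gsp_ids=[1, 2, 3])
    batch = dataset[0]                                      # all GSPs at the first valid t0
    batch = dataset.get_sample(dataset.valid_t0_times[0])   # or select by init-time
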