ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (77)
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +86 -72
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/constants.py +140 -12
  5. ocf_data_sampler/load/gsp.py +6 -5
  6. ocf_data_sampler/load/load_dataset.py +5 -6
  7. ocf_data_sampler/load/nwp/nwp.py +17 -5
  8. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  9. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  10. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  11. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  12. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  13. ocf_data_sampler/load/satellite.py +27 -36
  14. ocf_data_sampler/load/site.py +11 -7
  15. ocf_data_sampler/load/utils.py +21 -16
  16. ocf_data_sampler/numpy_sample/collate.py +10 -9
  17. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  18. ocf_data_sampler/numpy_sample/gsp.py +15 -13
  19. ocf_data_sampler/numpy_sample/nwp.py +17 -23
  20. ocf_data_sampler/numpy_sample/satellite.py +17 -14
  21. ocf_data_sampler/numpy_sample/site.py +8 -7
  22. ocf_data_sampler/numpy_sample/sun_position.py +19 -25
  23. ocf_data_sampler/sample/__init__.py +0 -7
  24. ocf_data_sampler/sample/base.py +23 -44
  25. ocf_data_sampler/sample/site.py +25 -69
  26. ocf_data_sampler/sample/uk_regional.py +52 -103
  27. ocf_data_sampler/select/dropout.py +42 -27
  28. ocf_data_sampler/select/fill_time_periods.py +15 -3
  29. ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
  30. ocf_data_sampler/select/geospatial.py +63 -54
  31. ocf_data_sampler/select/location.py +16 -51
  32. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  33. ocf_data_sampler/select/select_time_slice.py +71 -58
  34. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  35. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  36. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
  37. ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
  41. ocf_data_sampler/utils.py +3 -1
  42. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
  43. ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
  44. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
  45. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
  46. scripts/refactor_site.py +62 -33
  47. utils/compute_icon_mean_stddev.py +72 -0
  48. ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
  49. ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
  50. tests/__init__.py +0 -0
  51. tests/config/test_config.py +0 -113
  52. tests/config/test_load.py +0 -7
  53. tests/config/test_save.py +0 -28
  54. tests/conftest.py +0 -286
  55. tests/load/test_load_gsp.py +0 -15
  56. tests/load/test_load_nwp.py +0 -21
  57. tests/load/test_load_satellite.py +0 -17
  58. tests/load/test_load_sites.py +0 -14
  59. tests/numpy_sample/test_collate.py +0 -21
  60. tests/numpy_sample/test_datetime_features.py +0 -37
  61. tests/numpy_sample/test_gsp.py +0 -38
  62. tests/numpy_sample/test_nwp.py +0 -52
  63. tests/numpy_sample/test_satellite.py +0 -40
  64. tests/numpy_sample/test_sun_position.py +0 -81
  65. tests/select/test_dropout.py +0 -75
  66. tests/select/test_fill_time_periods.py +0 -28
  67. tests/select/test_find_contiguous_time_periods.py +0 -202
  68. tests/select/test_location.py +0 -67
  69. tests/select/test_select_spatial_slice.py +0 -154
  70. tests/select/test_select_time_slice.py +0 -275
  71. tests/test_sample/test_base.py +0 -164
  72. tests/test_sample/test_site_sample.py +0 -195
  73. tests/test_sample/test_uk_regional_sample.py +0 -163
  74. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  75. tests/torch_datasets/test_pvnet_uk.py +0 -167
  76. tests/torch_datasets/test_site.py +0 -226
  77. tests/torch_datasets/test_validate_channels_utils.py +0 -78
@@ -1,25 +1,26 @@
1
- """ Slice datasets by time"""
1
+ """Slice datasets by time."""
2
+
2
3
  import pandas as pd
3
4
  import xarray as xr
4
5
 
5
6
  from ocf_data_sampler.config import Configuration
6
- from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
7
- from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice
7
+ from ocf_data_sampler.select.dropout import apply_dropout_time, draw_dropout_time
8
+ from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
8
9
  from ocf_data_sampler.utils import minutes
9
10
 
11
+
10
12
  def slice_datasets_by_time(
11
13
  datasets_dict: dict,
12
14
  t0: pd.Timestamp,
13
15
  config: Configuration,
14
16
  ) -> dict:
15
- """Slice the dictionary of input data sources around a given t0 time
17
+ """Slice the dictionary of input data sources around a given t0 time.
16
18
 
17
19
  Args:
18
20
  datasets_dict: Dictionary of the input data sources
19
21
  t0: The init-time
20
22
  config: Configuration object.
21
23
  """
22
-
23
24
  sliced_datasets_dict = {}
24
25
 
25
26
  if "nwp" in datasets_dict:
@@ -31,7 +32,7 @@ def slice_datasets_by_time(
31
32
  sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
32
33
  da_nwp,
33
34
  t0,
34
- sample_period_duration=minutes(nwp_config.time_resolution_minutes),
35
+ time_resolution=minutes(nwp_config.time_resolution_minutes),
35
36
  interval_start=minutes(nwp_config.interval_start_minutes),
36
37
  interval_end=minutes(nwp_config.interval_end_minutes),
37
38
  dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
@@ -45,7 +46,7 @@ def slice_datasets_by_time(
45
46
  sliced_datasets_dict["sat"] = select_time_slice(
46
47
  datasets_dict["sat"],
47
48
  t0,
48
- sample_period_duration=minutes(sat_config.time_resolution_minutes),
49
+ time_resolution=minutes(sat_config.time_resolution_minutes),
49
50
  interval_start=minutes(sat_config.interval_start_minutes),
50
51
  interval_end=minutes(sat_config.interval_end_minutes),
51
52
  )
@@ -65,11 +66,11 @@ def slice_datasets_by_time(
65
66
 
66
67
  if "gsp" in datasets_dict:
67
68
  gsp_config = config.input_data.gsp
68
-
69
+
69
70
  da_gsp_past = select_time_slice(
70
71
  datasets_dict["gsp"],
71
72
  t0,
72
- sample_period_duration=minutes(gsp_config.time_resolution_minutes),
73
+ time_resolution=minutes(gsp_config.time_resolution_minutes),
73
74
  interval_start=minutes(gsp_config.interval_start_minutes),
74
75
  interval_end=minutes(0),
75
76
  )
@@ -82,18 +83,18 @@ def slice_datasets_by_time(
82
83
  )
83
84
 
84
85
  da_gsp_past = apply_dropout_time(
85
- da_gsp_past,
86
- gsp_dropout_time
86
+ da_gsp_past,
87
+ gsp_dropout_time,
87
88
  )
88
-
89
+
89
90
  da_gsp_future = select_time_slice(
90
91
  datasets_dict["gsp"],
91
92
  t0,
92
- sample_period_duration=minutes(gsp_config.time_resolution_minutes),
93
+ time_resolution=minutes(gsp_config.time_resolution_minutes),
93
94
  interval_start=minutes(gsp_config.time_resolution_minutes),
94
95
  interval_end=minutes(gsp_config.interval_end_minutes),
95
96
  )
96
-
97
+
97
98
  sliced_datasets_dict["gsp"] = xr.concat([da_gsp_past, da_gsp_future], dim="time_utc")
98
99
 
99
100
  if "site" in datasets_dict:
@@ -102,7 +103,7 @@ def slice_datasets_by_time(
102
103
  sliced_datasets_dict["site"] = select_time_slice(
103
104
  datasets_dict["site"],
104
105
  t0,
105
- sample_period_duration=minutes(site_config.time_resolution_minutes),
106
+ time_resolution=minutes(site_config.time_resolution_minutes),
106
107
  interval_start=minutes(site_config.interval_start_minutes),
107
108
  interval_end=minutes(site_config.interval_end_minutes),
108
109
  )
@@ -120,4 +121,4 @@ def slice_datasets_by_time(
120
121
  site_dropout_time,
121
122
  )
122
123
 
123
- return sliced_datasets_dict
124
+ return sliced_datasets_dict
@@ -1,41 +1,42 @@
1
- """Torch dataset for UK PVNet"""
1
+ """Torch dataset for UK PVNet."""
2
2
 
3
- import pkg_resources
3
+ from importlib.resources import files
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
7
7
  import xarray as xr
8
8
  from torch.utils.data import Dataset
9
+ from typing_extensions import override
10
+
9
11
  from ocf_data_sampler.config import Configuration, load_yaml_configuration
10
- from ocf_data_sampler.load.load_dataset import get_dataset_dict
11
- from ocf_data_sampler.select import (
12
- fill_time_periods,
13
- Location,
14
- slice_datasets_by_space,
15
- slice_datasets_by_time,
16
- )
17
- from ocf_data_sampler.utils import minutes
18
12
  from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
13
+ from ocf_data_sampler.load.load_dataset import get_dataset_dict
19
14
  from ocf_data_sampler.numpy_sample import (
15
+ convert_gsp_to_numpy_sample,
20
16
  convert_nwp_to_numpy_sample,
21
17
  convert_satellite_to_numpy_sample,
22
- convert_gsp_to_numpy_sample,
23
18
  make_sun_position_numpy_sample,
24
19
  )
20
+ from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
25
21
  from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
26
22
  from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
27
- from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
23
+ from ocf_data_sampler.select import (
24
+ Location,
25
+ fill_time_periods,
26
+ slice_datasets_by_space,
27
+ slice_datasets_by_time,
28
+ )
28
29
  from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
29
- from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
30
30
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
31
- merge_dicts,
32
31
  fill_nans_in_arrays,
32
+ merge_dicts,
33
33
  )
34
+ from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
34
35
  from ocf_data_sampler.torch_datasets.utils.validate_channels import (
35
36
  validate_nwp_channels,
36
37
  validate_satellite_channels,
37
38
  )
38
-
39
+ from ocf_data_sampler.utils import minutes
39
40
 
40
41
  xr.set_options(keep_attrs=True)
41
42
 
@@ -45,14 +46,12 @@ def process_and_combine_datasets(
45
46
  config: Configuration,
46
47
  t0: pd.Timestamp,
47
48
  location: Location,
48
- target_key: str = 'gsp'
49
49
  ) -> dict:
50
-
51
- """Normalise and convert data to numpy arrays"""
50
+ """Normalise and convert data to numpy arrays."""
52
51
  numpy_modalities = []
53
52
 
54
53
  if "nwp" in dataset_dict:
55
- nwp_numpy_modalities = dict()
54
+ nwp_numpy_modalities = {}
56
55
 
57
56
  for nwp_key, da_nwp in dataset_dict["nwp"].items():
58
57
  provider = config.input_data.nwp[nwp_key].provider
@@ -71,41 +70,50 @@ def process_and_combine_datasets(
71
70
  da_sat = (da_sat - RSS_MEAN) / RSS_STD
72
71
  numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
73
72
 
74
- gsp_config = config.input_data.gsp
75
-
76
73
  if "gsp" in dataset_dict:
74
+ gsp_config = config.input_data.gsp
77
75
  da_gsp = dataset_dict["gsp"]
78
76
  da_gsp = da_gsp / da_gsp.effective_capacity_mwp
79
-
77
+
80
78
  # Convert to NumpyBatch
81
79
  numpy_modalities.append(
82
80
  convert_gsp_to_numpy_sample(
83
- da_gsp,
84
- t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
85
- )
81
+ da_gsp,
82
+ t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes,
83
+ ),
86
84
  )
87
85
 
88
- if target_key == 'gsp':
89
- # Make sun coords NumpySample
86
+ # Add GSP location data
87
+ numpy_modalities.append(
88
+ {
89
+ GSPSampleKey.gsp_id: location.id,
90
+ GSPSampleKey.x_osgb: location.x,
91
+ GSPSampleKey.y_osgb: location.y,
92
+ },
93
+ )
94
+
95
+ # Only add solar position if explicitly configured
96
+ has_solar_config = (
97
+ hasattr(config.input_data, "solar_position") and
98
+ config.input_data.solar_position is not None
99
+ )
100
+
101
+ if has_solar_config:
102
+ solar_config = config.input_data.solar_position
103
+
104
+ # Create datetime range for solar position calculation
90
105
  datetimes = pd.date_range(
91
- t0+minutes(gsp_config.interval_start_minutes),
92
- t0+minutes(gsp_config.interval_end_minutes),
93
- freq=minutes(gsp_config.time_resolution_minutes),
106
+ t0 + minutes(solar_config.interval_start_minutes),
107
+ t0 + minutes(solar_config.interval_end_minutes),
108
+ freq=minutes(solar_config.time_resolution_minutes),
94
109
  )
95
110
 
111
+ # Convert OSGB coordinates to lon/lat
96
112
  lon, lat = osgb_to_lon_lat(location.x, location.y)
97
113
 
98
- numpy_modalities.append(
99
- {
100
- GSPSampleKey.gsp_id: location.id,
101
- GSPSampleKey.x_osgb: location.x,
102
- GSPSampleKey.y_osgb: location.y,
103
- }
104
- )
105
-
106
- numpy_modalities.append(
107
- make_sun_position_numpy_sample(datetimes, lon, lat, key_prefix=target_key)
108
- )
114
+ # Calculate solar positions and add to modalities
115
+ solar_positions = make_sun_position_numpy_sample(datetimes, lon, lat)
116
+ numpy_modalities.append(solar_positions)
109
117
 
110
118
  # Combine all the modalities and fill NaNs
111
119
  combined_sample = merge_dicts(numpy_modalities)
@@ -115,7 +123,7 @@ def process_and_combine_datasets(
115
123
 
116
124
 
117
125
  def compute(xarray_dict: dict) -> dict:
118
- """Eagerly load a nested dictionary of xarray DataArrays"""
126
+ """Eagerly load a nested dictionary of xarray DataArrays."""
119
127
  for k, v in xarray_dict.items():
120
128
  if isinstance(v, dict):
121
129
  xarray_dict[k] = compute(v)
@@ -125,59 +133,58 @@ def compute(xarray_dict: dict) -> dict:
125
133
 
126
134
 
127
135
  def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
128
- """Find the t0 times where all of the requested input data is available
136
+ """Find the t0 times where all of the requested input data is available.
129
137
 
130
138
  Args:
131
139
  datasets_dict: A dictionary of input datasets
132
140
  config: Configuration file
133
141
  """
134
-
135
142
  valid_time_periods = find_valid_time_periods(datasets_dict, config)
136
143
 
137
144
  # Fill out the contiguous time periods to get the t0 times
138
145
  valid_t0_times = fill_time_periods(
139
- valid_time_periods,
140
- freq=minutes(config.input_data.gsp.time_resolution_minutes)
146
+ valid_time_periods,
147
+ freq=minutes(config.input_data.gsp.time_resolution_minutes),
141
148
  )
142
-
143
149
  return valid_t0_times
144
150
 
145
151
 
146
152
  def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
147
- """Get list of locations of all GSPs"""
148
-
153
+ """Get list of locations of all GSPs."""
149
154
  if gsp_ids is None:
150
- gsp_ids = [i for i in range(1, 318)]
151
-
155
+ gsp_ids = list(range(1, 318))
156
+
152
157
  locations = []
153
158
 
154
159
  # Load UK GSP locations
155
160
  df_gsp_loc = pd.read_csv(
156
- pkg_resources.resource_filename(__name__, "../../data/uk_gsp_locations.csv"),
161
+ files("ocf_data_sampler.data").joinpath("uk_gsp_locations.csv"),
157
162
  index_col="gsp_id",
158
163
  )
159
164
 
160
165
  for gsp_id in gsp_ids:
161
166
  locations.append(
162
167
  Location(
163
- coordinate_system = "osgb",
168
+ coordinate_system="osgb",
164
169
  x=df_gsp_loc.loc[gsp_id].x_osgb,
165
170
  y=df_gsp_loc.loc[gsp_id].y_osgb,
166
171
  id=gsp_id,
167
- )
172
+ ),
168
173
  )
169
174
  return locations
170
175
 
171
176
 
172
177
  class PVNetUKRegionalDataset(Dataset):
178
+ """A torch Dataset for creating PVNet UK regional samples."""
179
+
173
180
  def __init__(
174
- self,
175
- config_filename: str,
181
+ self,
182
+ config_filename: str,
176
183
  start_time: str | None = None,
177
184
  end_time: str | None = None,
178
185
  gsp_ids: list[int] | None = None,
179
- ):
180
- """A torch Dataset for creating PVNet UK GSP samples
186
+ ) -> None:
187
+ """A torch Dataset for creating PVNet UK GSP samples.
181
188
 
182
189
  Args:
183
190
  config_filename: Path to the configuration file
@@ -185,31 +192,30 @@ class PVNetUKRegionalDataset(Dataset):
185
192
  end_time: Limit the init-times to be before this
186
193
  gsp_ids: List of GSP IDs to create samples for. Defaults to all
187
194
  """
188
-
189
195
  # config = load_yaml_configuration(config_filename)
190
196
  config: Configuration = load_yaml_configuration(config_filename)
191
197
  validate_nwp_channels(config)
192
198
  validate_satellite_channels(config)
193
199
 
194
200
  datasets_dict = get_dataset_dict(config.input_data)
195
-
201
+
196
202
  # Get t0 times where all input data is available
197
203
  valid_t0_times = find_valid_t0_times(datasets_dict, config)
198
204
 
199
205
  # Filter t0 times to given range
200
206
  if start_time is not None:
201
- valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
202
-
207
+ valid_t0_times = valid_t0_times[valid_t0_times >= pd.Timestamp(start_time)]
208
+
203
209
  if end_time is not None:
204
- valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
210
+ valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
205
211
 
206
212
  # Construct list of locations to sample from
207
213
  locations = get_gsp_locations(gsp_ids)
208
214
 
209
215
  # Construct a lookup for locations - useful for users to construct sample by GSP ID
210
216
  location_lookup = {loc.id: loc for loc in locations}
211
-
212
- # Construct indices for sampling
217
+
218
+ # Construct indices for sampling
213
219
  t_index, loc_index = np.meshgrid(
214
220
  np.arange(len(valid_t0_times)),
215
221
  np.arange(len(locations)),
@@ -217,7 +223,7 @@ class PVNetUKRegionalDataset(Dataset):
217
223
 
218
224
  # Make array of all possible (t0, location) coordinates. Each row is a single coordinate
219
225
  index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T
220
-
226
+
221
227
  # Assign coords and indices to self
222
228
  self.valid_t0_times = valid_t0_times
223
229
  self.locations = locations
@@ -227,15 +233,14 @@ class PVNetUKRegionalDataset(Dataset):
227
233
  # Assign config and input data to self
228
234
  self.datasets_dict = datasets_dict
229
235
  self.config = config
230
-
231
-
232
- def __len__(self):
236
+
237
+ @override
238
+ def __len__(self) -> int:
233
239
  return len(self.index_pairs)
234
-
235
-
240
+
236
241
  def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
237
- """Generate the PVNet sample for given coordinates
238
-
242
+ """Generate the PVNet sample for given coordinates.
243
+
239
244
  Args:
240
245
  t0: init-time for sample
241
246
  location: location for sample
@@ -245,49 +250,51 @@ class PVNetUKRegionalDataset(Dataset):
245
250
  sample_dict = compute(sample_dict)
246
251
 
247
252
  sample = process_and_combine_datasets(sample_dict, self.config, t0, location)
248
-
253
+
249
254
  return sample
250
-
251
-
252
- def __getitem__(self, idx):
253
-
255
+
256
+ @override
257
+ def __getitem__(self, idx: int) -> dict:
254
258
  # Get the coordinates of the sample
255
259
  t_index, loc_index = self.index_pairs[idx]
256
260
  location = self.locations[loc_index]
257
261
  t0 = self.valid_t0_times[t_index]
258
-
262
+
259
263
  # Generate the sample
260
264
  return self._get_sample(t0, location)
261
-
262
265
 
263
266
  def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
264
- """Generate a sample for the given coordinates.
265
-
267
+ """Generate a sample for the given coordinates.
268
+
266
269
  Useful for users to generate specific samples.
267
-
270
+
268
271
  Args:
269
272
  t0: init-time for sample
270
273
  gsp_id: GSP ID
271
274
  """
272
275
  # Check the user has asked for a sample which we have the data for
273
- assert t0 in self.valid_t0_times
274
- assert gsp_id in self.location_lookup
276
+ if t0 not in self.valid_t0_times:
277
+ raise ValueError(f"Input init time '{t0!s}' not in valid times")
278
+ if gsp_id not in self.location_lookup:
279
+ raise ValueError(f"Input GSP '{gsp_id}' not known")
275
280
 
276
281
  location = self.location_lookup[gsp_id]
277
-
282
+
278
283
  return self._get_sample(t0, location)
279
-
280
-
284
+
285
+
281
286
  class PVNetUKConcurrentDataset(Dataset):
287
+ """A torch Dataset for creating concurrent PVNet UK regional samples."""
288
+
282
289
  def __init__(
283
- self,
284
- config_filename: str,
290
+ self,
291
+ config_filename: str,
285
292
  start_time: str | None = None,
286
293
  end_time: str | None = None,
287
294
  gsp_ids: list[int] | None = None,
288
- ):
289
- """A torch Dataset for creating concurrent samples of PVNet UK regional data
290
-
295
+ ) -> None:
296
+ """A torch Dataset for creating concurrent samples of PVNet UK regional data.
297
+
291
298
  Each concurrent sample includes the data from all GSPs for a single t0 time
292
299
 
293
300
  Args:
@@ -296,7 +303,6 @@ class PVNetUKConcurrentDataset(Dataset):
296
303
  end_time: Limit the init-times to be before this
297
304
  gsp_ids: List of all GSP IDs included in each sample. Defaults to all
298
305
  """
299
-
300
306
  config = load_yaml_configuration(config_filename)
301
307
 
302
308
  # Validate channels for NWP and satellite data
@@ -304,20 +310,20 @@ class PVNetUKConcurrentDataset(Dataset):
304
310
  validate_satellite_channels(config)
305
311
 
306
312
  datasets_dict = get_dataset_dict(config.input_data)
307
-
313
+
308
314
  # Get t0 times where all input data is available
309
315
  valid_t0_times = find_valid_t0_times(datasets_dict, config)
310
316
 
311
317
  # Filter t0 times to given range
312
318
  if start_time is not None:
313
- valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
314
-
319
+ valid_t0_times = valid_t0_times[valid_t0_times >= pd.Timestamp(start_time)]
320
+
315
321
  if end_time is not None:
316
- valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
322
+ valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
317
323
 
318
324
  # Construct list of locations to sample from
319
325
  locations = get_gsp_locations(gsp_ids)
320
-
326
+
321
327
  # Assign coords and indices to self
322
328
  self.valid_t0_times = valid_t0_times
323
329
  self.locations = locations
@@ -325,48 +331,50 @@ class PVNetUKConcurrentDataset(Dataset):
325
331
  # Assign config and input data to self
326
332
  self.datasets_dict = datasets_dict
327
333
  self.config = config
328
-
329
-
330
- def __len__(self):
334
+
335
+ @override
336
+ def __len__(self) -> int:
331
337
  return len(self.valid_t0_times)
332
-
333
-
338
+
334
339
  def _get_sample(self, t0: pd.Timestamp) -> dict:
335
- """Generate a concurrent PVNet sample for given init-time
336
-
340
+ """Generate a concurrent PVNet sample for given init-time.
341
+
337
342
  Args:
338
343
  t0: init-time for sample
339
344
  """
340
345
  # Slice by time then load to avoid loading the data multiple times from disk
341
346
  sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
342
347
  sample_dict = compute(sample_dict)
343
-
348
+
344
349
  gsp_samples = []
345
-
350
+
346
351
  # Prepare sample for each GSP
347
352
  for location in self.locations:
348
353
  gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
349
354
  gsp_numpy_sample = process_and_combine_datasets(
350
- gsp_sample_dict, self.config, t0, location
355
+ gsp_sample_dict,
356
+ self.config,
357
+ t0,
358
+ location,
351
359
  )
352
360
  gsp_samples.append(gsp_numpy_sample)
353
-
361
+
354
362
  # Stack GSP samples
355
363
  return stack_np_samples_into_batch(gsp_samples)
356
-
357
-
358
- def __getitem__(self, idx):
364
+
365
+ @override
366
+ def __getitem__(self, idx: int) -> dict:
359
367
  return self._get_sample(self.valid_t0_times[idx])
360
-
361
368
 
362
369
  def get_sample(self, t0: pd.Timestamp) -> dict:
363
- """Generate a sample for the given init-time.
364
-
370
+ """Generate a sample for the given init-time.
371
+
365
372
  Useful for users to generate specific samples.
366
-
373
+
367
374
  Args:
368
375
  t0: init-time for sample
369
376
  """
370
377
  # Check data is available for init-time t0
371
- assert t0 in self.valid_t0_times
378
+ if t0 not in self.valid_t0_times:
379
+ raise ValueError(f"Input init time '{t0!s}' not in valid times")
372
380
  return self._get_sample(t0)