ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +146 -64
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +9 -10
- ocf_data_sampler/load/site.py +10 -6
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +12 -14
- ocf_data_sampler/numpy_sample/nwp.py +12 -12
- ocf_data_sampler/numpy_sample/satellite.py +9 -9
- ocf_data_sampler/numpy_sample/site.py +5 -8
- ocf_data_sampler/numpy_sample/sun_position.py +16 -21
- ocf_data_sampler/sample/base.py +15 -17
- ocf_data_sampler/sample/site.py +13 -20
- ocf_data_sampler/sample/uk_regional.py +29 -35
- ocf_data_sampler/select/dropout.py +16 -14
- ocf_data_sampler/select/fill_time_periods.py +15 -5
- ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +140 -131
- ocf_data_sampler/torch_datasets/datasets/site.py +152 -112
- ocf_data_sampler/torch_datasets/utils/__init__.py +3 -0
- ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.17.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +63 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler/constants.py +0 -222
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -82
- ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -319
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -13
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -69
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -165
- tests/test_sample/test_uk_regional_sample.py +0 -136
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -154
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
|
@@ -1,41 +1,37 @@
|
|
|
1
|
-
"""Torch dataset for UK PVNet"""
|
|
1
|
+
"""Torch dataset for UK PVNet."""
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
from importlib.resources import files
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
7
|
import xarray as xr
|
|
8
8
|
from torch.utils.data import Dataset
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
9
11
|
from ocf_data_sampler.config import Configuration, load_yaml_configuration
|
|
10
12
|
from ocf_data_sampler.load.load_dataset import get_dataset_dict
|
|
11
|
-
from ocf_data_sampler.select import (
|
|
12
|
-
fill_time_periods,
|
|
13
|
-
Location,
|
|
14
|
-
slice_datasets_by_space,
|
|
15
|
-
slice_datasets_by_time,
|
|
16
|
-
)
|
|
17
|
-
from ocf_data_sampler.utils import minutes
|
|
18
|
-
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
|
|
19
13
|
from ocf_data_sampler.numpy_sample import (
|
|
14
|
+
convert_gsp_to_numpy_sample,
|
|
20
15
|
convert_nwp_to_numpy_sample,
|
|
21
16
|
convert_satellite_to_numpy_sample,
|
|
22
|
-
convert_gsp_to_numpy_sample,
|
|
23
17
|
make_sun_position_numpy_sample,
|
|
24
18
|
)
|
|
19
|
+
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
|
|
25
20
|
from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
|
|
26
21
|
from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
|
|
27
|
-
from ocf_data_sampler.
|
|
22
|
+
from ocf_data_sampler.select import (
|
|
23
|
+
Location,
|
|
24
|
+
fill_time_periods,
|
|
25
|
+
slice_datasets_by_space,
|
|
26
|
+
slice_datasets_by_time,
|
|
27
|
+
)
|
|
28
28
|
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
|
|
29
|
-
from ocf_data_sampler.torch_datasets.utils
|
|
29
|
+
from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray, find_valid_time_periods
|
|
30
30
|
from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
|
|
31
|
-
merge_dicts,
|
|
32
31
|
fill_nans_in_arrays,
|
|
32
|
+
merge_dicts,
|
|
33
33
|
)
|
|
34
|
-
from ocf_data_sampler.
|
|
35
|
-
validate_nwp_channels,
|
|
36
|
-
validate_satellite_channels,
|
|
37
|
-
)
|
|
38
|
-
|
|
34
|
+
from ocf_data_sampler.utils import minutes
|
|
39
35
|
|
|
40
36
|
xr.set_options(keep_attrs=True)
|
|
41
37
|
|
|
@@ -45,20 +41,26 @@ def process_and_combine_datasets(
|
|
|
45
41
|
config: Configuration,
|
|
46
42
|
t0: pd.Timestamp,
|
|
47
43
|
location: Location,
|
|
48
|
-
target_key: str = 'gsp'
|
|
49
44
|
) -> dict:
|
|
50
|
-
|
|
51
|
-
"""Normalise and convert data to numpy arrays"""
|
|
45
|
+
"""Normalise and convert data to numpy arrays."""
|
|
52
46
|
numpy_modalities = []
|
|
53
47
|
|
|
54
48
|
if "nwp" in dataset_dict:
|
|
55
|
-
nwp_numpy_modalities =
|
|
49
|
+
nwp_numpy_modalities = {}
|
|
56
50
|
|
|
57
51
|
for nwp_key, da_nwp in dataset_dict["nwp"].items():
|
|
58
|
-
provider = config.input_data.nwp[nwp_key].provider
|
|
59
52
|
|
|
60
53
|
# Standardise and convert to NumpyBatch
|
|
61
|
-
|
|
54
|
+
|
|
55
|
+
da_channel_means = channel_dict_to_dataarray(
|
|
56
|
+
config.input_data.nwp[nwp_key].channel_means,
|
|
57
|
+
)
|
|
58
|
+
da_channel_stds = channel_dict_to_dataarray(
|
|
59
|
+
config.input_data.nwp[nwp_key].channel_stds,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
da_nwp = (da_nwp - da_channel_means) / da_channel_stds
|
|
63
|
+
|
|
62
64
|
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
|
|
63
65
|
|
|
64
66
|
# Combine the NWPs into NumpyBatch
|
|
@@ -68,44 +70,57 @@ def process_and_combine_datasets(
|
|
|
68
70
|
da_sat = dataset_dict["sat"]
|
|
69
71
|
|
|
70
72
|
# Standardise and convert to NumpyBatch
|
|
71
|
-
|
|
72
|
-
|
|
73
|
+
da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
|
|
74
|
+
da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
|
|
73
75
|
|
|
74
|
-
|
|
76
|
+
da_sat = (da_sat - da_channel_means) / da_channel_stds
|
|
77
|
+
|
|
78
|
+
numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
|
|
75
79
|
|
|
76
80
|
if "gsp" in dataset_dict:
|
|
81
|
+
gsp_config = config.input_data.gsp
|
|
77
82
|
da_gsp = dataset_dict["gsp"]
|
|
78
83
|
da_gsp = da_gsp / da_gsp.effective_capacity_mwp
|
|
79
|
-
|
|
84
|
+
|
|
80
85
|
# Convert to NumpyBatch
|
|
81
86
|
numpy_modalities.append(
|
|
82
87
|
convert_gsp_to_numpy_sample(
|
|
83
|
-
da_gsp,
|
|
84
|
-
t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
|
|
85
|
-
)
|
|
88
|
+
da_gsp,
|
|
89
|
+
t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes,
|
|
90
|
+
),
|
|
86
91
|
)
|
|
87
92
|
|
|
88
|
-
|
|
89
|
-
|
|
93
|
+
# Add GSP location data
|
|
94
|
+
numpy_modalities.append(
|
|
95
|
+
{
|
|
96
|
+
GSPSampleKey.gsp_id: location.id,
|
|
97
|
+
GSPSampleKey.x_osgb: location.x,
|
|
98
|
+
GSPSampleKey.y_osgb: location.y,
|
|
99
|
+
},
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
# Only add solar position if explicitly configured
|
|
103
|
+
has_solar_config = (
|
|
104
|
+
hasattr(config.input_data, "solar_position") and
|
|
105
|
+
config.input_data.solar_position is not None
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
if has_solar_config:
|
|
109
|
+
solar_config = config.input_data.solar_position
|
|
110
|
+
|
|
111
|
+
# Create datetime range for solar position calculation
|
|
90
112
|
datetimes = pd.date_range(
|
|
91
|
-
t0+minutes(
|
|
92
|
-
t0+minutes(
|
|
93
|
-
freq=minutes(
|
|
113
|
+
t0 + minutes(solar_config.interval_start_minutes),
|
|
114
|
+
t0 + minutes(solar_config.interval_end_minutes),
|
|
115
|
+
freq=minutes(solar_config.time_resolution_minutes),
|
|
94
116
|
)
|
|
95
117
|
|
|
118
|
+
# Convert OSGB coordinates to lon/lat
|
|
96
119
|
lon, lat = osgb_to_lon_lat(location.x, location.y)
|
|
97
120
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
GSPSampleKey.x_osgb: location.x,
|
|
102
|
-
GSPSampleKey.y_osgb: location.y,
|
|
103
|
-
}
|
|
104
|
-
)
|
|
105
|
-
|
|
106
|
-
numpy_modalities.append(
|
|
107
|
-
make_sun_position_numpy_sample(datetimes, lon, lat, key_prefix=target_key)
|
|
108
|
-
)
|
|
121
|
+
# Calculate solar positions and add to modalities
|
|
122
|
+
solar_positions = make_sun_position_numpy_sample(datetimes, lon, lat)
|
|
123
|
+
numpy_modalities.append(solar_positions)
|
|
109
124
|
|
|
110
125
|
# Combine all the modalities and fill NaNs
|
|
111
126
|
combined_sample = merge_dicts(numpy_modalities)
|
|
@@ -115,7 +130,7 @@ def process_and_combine_datasets(
|
|
|
115
130
|
|
|
116
131
|
|
|
117
132
|
def compute(xarray_dict: dict) -> dict:
|
|
118
|
-
"""Eagerly load a nested dictionary of xarray DataArrays"""
|
|
133
|
+
"""Eagerly load a nested dictionary of xarray DataArrays."""
|
|
119
134
|
for k, v in xarray_dict.items():
|
|
120
135
|
if isinstance(v, dict):
|
|
121
136
|
xarray_dict[k] = compute(v)
|
|
@@ -125,59 +140,58 @@ def compute(xarray_dict: dict) -> dict:
|
|
|
125
140
|
|
|
126
141
|
|
|
127
142
|
def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
|
|
128
|
-
"""Find the t0 times where all of the requested input data is available
|
|
143
|
+
"""Find the t0 times where all of the requested input data is available.
|
|
129
144
|
|
|
130
145
|
Args:
|
|
131
146
|
datasets_dict: A dictionary of input datasets
|
|
132
147
|
config: Configuration file
|
|
133
148
|
"""
|
|
134
|
-
|
|
135
149
|
valid_time_periods = find_valid_time_periods(datasets_dict, config)
|
|
136
150
|
|
|
137
151
|
# Fill out the contiguous time periods to get the t0 times
|
|
138
152
|
valid_t0_times = fill_time_periods(
|
|
139
|
-
valid_time_periods,
|
|
140
|
-
freq=minutes(config.input_data.gsp.time_resolution_minutes)
|
|
153
|
+
valid_time_periods,
|
|
154
|
+
freq=minutes(config.input_data.gsp.time_resolution_minutes),
|
|
141
155
|
)
|
|
142
|
-
|
|
143
156
|
return valid_t0_times
|
|
144
157
|
|
|
145
158
|
|
|
146
159
|
def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
|
|
147
|
-
"""Get list of locations of all GSPs"""
|
|
148
|
-
|
|
160
|
+
"""Get list of locations of all GSPs."""
|
|
149
161
|
if gsp_ids is None:
|
|
150
|
-
gsp_ids =
|
|
151
|
-
|
|
162
|
+
gsp_ids = list(range(1, 318))
|
|
163
|
+
|
|
152
164
|
locations = []
|
|
153
165
|
|
|
154
166
|
# Load UK GSP locations
|
|
155
167
|
df_gsp_loc = pd.read_csv(
|
|
156
|
-
|
|
168
|
+
files("ocf_data_sampler.data").joinpath("uk_gsp_locations.csv"),
|
|
157
169
|
index_col="gsp_id",
|
|
158
170
|
)
|
|
159
171
|
|
|
160
172
|
for gsp_id in gsp_ids:
|
|
161
173
|
locations.append(
|
|
162
174
|
Location(
|
|
163
|
-
coordinate_system
|
|
175
|
+
coordinate_system="osgb",
|
|
164
176
|
x=df_gsp_loc.loc[gsp_id].x_osgb,
|
|
165
177
|
y=df_gsp_loc.loc[gsp_id].y_osgb,
|
|
166
178
|
id=gsp_id,
|
|
167
|
-
)
|
|
179
|
+
),
|
|
168
180
|
)
|
|
169
181
|
return locations
|
|
170
182
|
|
|
171
183
|
|
|
172
184
|
class PVNetUKRegionalDataset(Dataset):
|
|
185
|
+
"""A torch Dataset for creating PVNet UK regional samples."""
|
|
186
|
+
|
|
173
187
|
def __init__(
|
|
174
|
-
self,
|
|
175
|
-
config_filename: str,
|
|
188
|
+
self,
|
|
189
|
+
config_filename: str,
|
|
176
190
|
start_time: str | None = None,
|
|
177
191
|
end_time: str | None = None,
|
|
178
192
|
gsp_ids: list[int] | None = None,
|
|
179
|
-
):
|
|
180
|
-
"""A torch Dataset for creating PVNet UK GSP samples
|
|
193
|
+
) -> None:
|
|
194
|
+
"""A torch Dataset for creating PVNet UK GSP samples.
|
|
181
195
|
|
|
182
196
|
Args:
|
|
183
197
|
config_filename: Path to the configuration file
|
|
@@ -185,31 +199,28 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
185
199
|
end_time: Limit the init-times to be before this
|
|
186
200
|
gsp_ids: List of GSP IDs to create samples for. Defaults to all
|
|
187
201
|
"""
|
|
188
|
-
|
|
189
202
|
# config = load_yaml_configuration(config_filename)
|
|
190
203
|
config: Configuration = load_yaml_configuration(config_filename)
|
|
191
|
-
validate_nwp_channels(config)
|
|
192
|
-
validate_satellite_channels(config)
|
|
193
204
|
|
|
194
205
|
datasets_dict = get_dataset_dict(config.input_data)
|
|
195
|
-
|
|
206
|
+
|
|
196
207
|
# Get t0 times where all input data is available
|
|
197
208
|
valid_t0_times = find_valid_t0_times(datasets_dict, config)
|
|
198
209
|
|
|
199
210
|
# Filter t0 times to given range
|
|
200
211
|
if start_time is not None:
|
|
201
|
-
valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
|
|
202
|
-
|
|
212
|
+
valid_t0_times = valid_t0_times[valid_t0_times >= pd.Timestamp(start_time)]
|
|
213
|
+
|
|
203
214
|
if end_time is not None:
|
|
204
|
-
valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
|
|
215
|
+
valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
|
|
205
216
|
|
|
206
217
|
# Construct list of locations to sample from
|
|
207
218
|
locations = get_gsp_locations(gsp_ids)
|
|
208
219
|
|
|
209
220
|
# Construct a lookup for locations - useful for users to construct sample by GSP ID
|
|
210
221
|
location_lookup = {loc.id: loc for loc in locations}
|
|
211
|
-
|
|
212
|
-
#
|
|
222
|
+
|
|
223
|
+
# Construct indices for sampling
|
|
213
224
|
t_index, loc_index = np.meshgrid(
|
|
214
225
|
np.arange(len(valid_t0_times)),
|
|
215
226
|
np.arange(len(locations)),
|
|
@@ -217,7 +228,7 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
217
228
|
|
|
218
229
|
# Make array of all possible (t0, location) coordinates. Each row is a single coordinate
|
|
219
230
|
index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T
|
|
220
|
-
|
|
231
|
+
|
|
221
232
|
# Assign coords and indices to self
|
|
222
233
|
self.valid_t0_times = valid_t0_times
|
|
223
234
|
self.locations = locations
|
|
@@ -227,15 +238,14 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
227
238
|
# Assign config and input data to self
|
|
228
239
|
self.datasets_dict = datasets_dict
|
|
229
240
|
self.config = config
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
def __len__(self):
|
|
241
|
+
|
|
242
|
+
@override
|
|
243
|
+
def __len__(self) -> int:
|
|
233
244
|
return len(self.index_pairs)
|
|
234
|
-
|
|
235
|
-
|
|
245
|
+
|
|
236
246
|
def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
|
|
237
|
-
"""Generate the PVNet sample for given coordinates
|
|
238
|
-
|
|
247
|
+
"""Generate the PVNet sample for given coordinates.
|
|
248
|
+
|
|
239
249
|
Args:
|
|
240
250
|
t0: init-time for sample
|
|
241
251
|
location: location for sample
|
|
@@ -245,49 +255,51 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
245
255
|
sample_dict = compute(sample_dict)
|
|
246
256
|
|
|
247
257
|
sample = process_and_combine_datasets(sample_dict, self.config, t0, location)
|
|
248
|
-
|
|
258
|
+
|
|
249
259
|
return sample
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
def __getitem__(self, idx):
|
|
253
|
-
|
|
260
|
+
|
|
261
|
+
@override
|
|
262
|
+
def __getitem__(self, idx: int) -> dict:
|
|
254
263
|
# Get the coordinates of the sample
|
|
255
264
|
t_index, loc_index = self.index_pairs[idx]
|
|
256
265
|
location = self.locations[loc_index]
|
|
257
266
|
t0 = self.valid_t0_times[t_index]
|
|
258
|
-
|
|
267
|
+
|
|
259
268
|
# Generate the sample
|
|
260
269
|
return self._get_sample(t0, location)
|
|
261
|
-
|
|
262
270
|
|
|
263
271
|
def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
|
|
264
|
-
"""Generate a sample for the given coordinates.
|
|
265
|
-
|
|
272
|
+
"""Generate a sample for the given coordinates.
|
|
273
|
+
|
|
266
274
|
Useful for users to generate specific samples.
|
|
267
|
-
|
|
275
|
+
|
|
268
276
|
Args:
|
|
269
277
|
t0: init-time for sample
|
|
270
278
|
gsp_id: GSP ID
|
|
271
279
|
"""
|
|
272
280
|
# Check the user has asked for a sample which we have the data for
|
|
273
|
-
|
|
274
|
-
|
|
281
|
+
if t0 not in self.valid_t0_times:
|
|
282
|
+
raise ValueError(f"Input init time '{t0!s}' not in valid times")
|
|
283
|
+
if gsp_id not in self.location_lookup:
|
|
284
|
+
raise ValueError(f"Input GSP '{gsp_id}' not known")
|
|
275
285
|
|
|
276
286
|
location = self.location_lookup[gsp_id]
|
|
277
|
-
|
|
287
|
+
|
|
278
288
|
return self._get_sample(t0, location)
|
|
279
|
-
|
|
280
|
-
|
|
289
|
+
|
|
290
|
+
|
|
281
291
|
class PVNetUKConcurrentDataset(Dataset):
|
|
292
|
+
"""A torch Dataset for creating concurrent PVNet UK regional samples."""
|
|
293
|
+
|
|
282
294
|
def __init__(
|
|
283
|
-
self,
|
|
284
|
-
config_filename: str,
|
|
295
|
+
self,
|
|
296
|
+
config_filename: str,
|
|
285
297
|
start_time: str | None = None,
|
|
286
298
|
end_time: str | None = None,
|
|
287
299
|
gsp_ids: list[int] | None = None,
|
|
288
|
-
):
|
|
289
|
-
"""A torch Dataset for creating concurrent samples of PVNet UK regional data
|
|
290
|
-
|
|
300
|
+
) -> None:
|
|
301
|
+
"""A torch Dataset for creating concurrent samples of PVNet UK regional data.
|
|
302
|
+
|
|
291
303
|
Each concurrent sample includes the data from all GSPs for a single t0 time
|
|
292
304
|
|
|
293
305
|
Args:
|
|
@@ -296,28 +308,23 @@ class PVNetUKConcurrentDataset(Dataset):
|
|
|
296
308
|
end_time: Limit the init-times to be before this
|
|
297
309
|
gsp_ids: List of all GSP IDs included in each sample. Defaults to all
|
|
298
310
|
"""
|
|
299
|
-
|
|
300
311
|
config = load_yaml_configuration(config_filename)
|
|
301
312
|
|
|
302
|
-
# Validate channels for NWP and satellite data
|
|
303
|
-
validate_nwp_channels(config)
|
|
304
|
-
validate_satellite_channels(config)
|
|
305
|
-
|
|
306
313
|
datasets_dict = get_dataset_dict(config.input_data)
|
|
307
|
-
|
|
314
|
+
|
|
308
315
|
# Get t0 times where all input data is available
|
|
309
316
|
valid_t0_times = find_valid_t0_times(datasets_dict, config)
|
|
310
317
|
|
|
311
318
|
# Filter t0 times to given range
|
|
312
319
|
if start_time is not None:
|
|
313
|
-
valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
|
|
314
|
-
|
|
320
|
+
valid_t0_times = valid_t0_times[valid_t0_times >= pd.Timestamp(start_time)]
|
|
321
|
+
|
|
315
322
|
if end_time is not None:
|
|
316
|
-
valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
|
|
323
|
+
valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
|
|
317
324
|
|
|
318
325
|
# Construct list of locations to sample from
|
|
319
326
|
locations = get_gsp_locations(gsp_ids)
|
|
320
|
-
|
|
327
|
+
|
|
321
328
|
# Assign coords and indices to self
|
|
322
329
|
self.valid_t0_times = valid_t0_times
|
|
323
330
|
self.locations = locations
|
|
@@ -325,48 +332,50 @@ class PVNetUKConcurrentDataset(Dataset):
|
|
|
325
332
|
# Assign config and input data to self
|
|
326
333
|
self.datasets_dict = datasets_dict
|
|
327
334
|
self.config = config
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
def __len__(self):
|
|
335
|
+
|
|
336
|
+
@override
|
|
337
|
+
def __len__(self) -> int:
|
|
331
338
|
return len(self.valid_t0_times)
|
|
332
|
-
|
|
333
|
-
|
|
339
|
+
|
|
334
340
|
def _get_sample(self, t0: pd.Timestamp) -> dict:
|
|
335
|
-
"""Generate a concurrent PVNet sample for given init-time
|
|
336
|
-
|
|
341
|
+
"""Generate a concurrent PVNet sample for given init-time.
|
|
342
|
+
|
|
337
343
|
Args:
|
|
338
344
|
t0: init-time for sample
|
|
339
345
|
"""
|
|
340
346
|
# Slice by time then load to avoid loading the data multiple times from disk
|
|
341
347
|
sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
|
|
342
348
|
sample_dict = compute(sample_dict)
|
|
343
|
-
|
|
349
|
+
|
|
344
350
|
gsp_samples = []
|
|
345
|
-
|
|
351
|
+
|
|
346
352
|
# Prepare sample for each GSP
|
|
347
353
|
for location in self.locations:
|
|
348
354
|
gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
|
|
349
355
|
gsp_numpy_sample = process_and_combine_datasets(
|
|
350
|
-
gsp_sample_dict,
|
|
356
|
+
gsp_sample_dict,
|
|
357
|
+
self.config,
|
|
358
|
+
t0,
|
|
359
|
+
location,
|
|
351
360
|
)
|
|
352
361
|
gsp_samples.append(gsp_numpy_sample)
|
|
353
|
-
|
|
362
|
+
|
|
354
363
|
# Stack GSP samples
|
|
355
364
|
return stack_np_samples_into_batch(gsp_samples)
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
def __getitem__(self, idx):
|
|
365
|
+
|
|
366
|
+
@override
|
|
367
|
+
def __getitem__(self, idx: int) -> dict:
|
|
359
368
|
return self._get_sample(self.valid_t0_times[idx])
|
|
360
|
-
|
|
361
369
|
|
|
362
370
|
def get_sample(self, t0: pd.Timestamp) -> dict:
|
|
363
|
-
"""Generate a sample for the given init-time.
|
|
364
|
-
|
|
371
|
+
"""Generate a sample for the given init-time.
|
|
372
|
+
|
|
365
373
|
Useful for users to generate specific samples.
|
|
366
|
-
|
|
374
|
+
|
|
367
375
|
Args:
|
|
368
376
|
t0: init-time for sample
|
|
369
377
|
"""
|
|
370
378
|
# Check data is available for init-time t0
|
|
371
|
-
|
|
379
|
+
if t0 not in self.valid_t0_times:
|
|
380
|
+
raise ValueError(f"Input init time '{t0!s}' not in valid times")
|
|
372
381
|
return self._get_sample(t0)
|