ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +86 -72
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/constants.py +140 -12
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +27 -36
- ocf_data_sampler/load/site.py +11 -7
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +15 -13
- ocf_data_sampler/numpy_sample/nwp.py +17 -23
- ocf_data_sampler/numpy_sample/satellite.py +17 -14
- ocf_data_sampler/numpy_sample/site.py +8 -7
- ocf_data_sampler/numpy_sample/sun_position.py +19 -25
- ocf_data_sampler/sample/__init__.py +0 -7
- ocf_data_sampler/sample/base.py +23 -44
- ocf_data_sampler/sample/site.py +25 -69
- ocf_data_sampler/sample/uk_regional.py +52 -103
- ocf_data_sampler/select/dropout.py +42 -27
- ocf_data_sampler/select/fill_time_periods.py +15 -3
- ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
- ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +62 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -286
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -52
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -75
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -195
- tests/test_sample/test_uk_regional_sample.py +0 -163
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -167
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
|
@@ -1,25 +1,26 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Slice datasets by time."""
|
|
2
|
+
|
|
2
3
|
import pandas as pd
|
|
3
4
|
import xarray as xr
|
|
4
5
|
|
|
5
6
|
from ocf_data_sampler.config import Configuration
|
|
6
|
-
from ocf_data_sampler.select.dropout import
|
|
7
|
-
from ocf_data_sampler.select.select_time_slice import
|
|
7
|
+
from ocf_data_sampler.select.dropout import apply_dropout_time, draw_dropout_time
|
|
8
|
+
from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
|
|
8
9
|
from ocf_data_sampler.utils import minutes
|
|
9
10
|
|
|
11
|
+
|
|
10
12
|
def slice_datasets_by_time(
|
|
11
13
|
datasets_dict: dict,
|
|
12
14
|
t0: pd.Timestamp,
|
|
13
15
|
config: Configuration,
|
|
14
16
|
) -> dict:
|
|
15
|
-
"""Slice the dictionary of input data sources around a given t0 time
|
|
17
|
+
"""Slice the dictionary of input data sources around a given t0 time.
|
|
16
18
|
|
|
17
19
|
Args:
|
|
18
20
|
datasets_dict: Dictionary of the input data sources
|
|
19
21
|
t0: The init-time
|
|
20
22
|
config: Configuration object.
|
|
21
23
|
"""
|
|
22
|
-
|
|
23
24
|
sliced_datasets_dict = {}
|
|
24
25
|
|
|
25
26
|
if "nwp" in datasets_dict:
|
|
@@ -31,7 +32,7 @@ def slice_datasets_by_time(
|
|
|
31
32
|
sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
|
|
32
33
|
da_nwp,
|
|
33
34
|
t0,
|
|
34
|
-
|
|
35
|
+
time_resolution=minutes(nwp_config.time_resolution_minutes),
|
|
35
36
|
interval_start=minutes(nwp_config.interval_start_minutes),
|
|
36
37
|
interval_end=minutes(nwp_config.interval_end_minutes),
|
|
37
38
|
dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
|
|
@@ -45,7 +46,7 @@ def slice_datasets_by_time(
|
|
|
45
46
|
sliced_datasets_dict["sat"] = select_time_slice(
|
|
46
47
|
datasets_dict["sat"],
|
|
47
48
|
t0,
|
|
48
|
-
|
|
49
|
+
time_resolution=minutes(sat_config.time_resolution_minutes),
|
|
49
50
|
interval_start=minutes(sat_config.interval_start_minutes),
|
|
50
51
|
interval_end=minutes(sat_config.interval_end_minutes),
|
|
51
52
|
)
|
|
@@ -65,11 +66,11 @@ def slice_datasets_by_time(
|
|
|
65
66
|
|
|
66
67
|
if "gsp" in datasets_dict:
|
|
67
68
|
gsp_config = config.input_data.gsp
|
|
68
|
-
|
|
69
|
+
|
|
69
70
|
da_gsp_past = select_time_slice(
|
|
70
71
|
datasets_dict["gsp"],
|
|
71
72
|
t0,
|
|
72
|
-
|
|
73
|
+
time_resolution=minutes(gsp_config.time_resolution_minutes),
|
|
73
74
|
interval_start=minutes(gsp_config.interval_start_minutes),
|
|
74
75
|
interval_end=minutes(0),
|
|
75
76
|
)
|
|
@@ -82,18 +83,18 @@ def slice_datasets_by_time(
|
|
|
82
83
|
)
|
|
83
84
|
|
|
84
85
|
da_gsp_past = apply_dropout_time(
|
|
85
|
-
da_gsp_past,
|
|
86
|
-
gsp_dropout_time
|
|
86
|
+
da_gsp_past,
|
|
87
|
+
gsp_dropout_time,
|
|
87
88
|
)
|
|
88
|
-
|
|
89
|
+
|
|
89
90
|
da_gsp_future = select_time_slice(
|
|
90
91
|
datasets_dict["gsp"],
|
|
91
92
|
t0,
|
|
92
|
-
|
|
93
|
+
time_resolution=minutes(gsp_config.time_resolution_minutes),
|
|
93
94
|
interval_start=minutes(gsp_config.time_resolution_minutes),
|
|
94
95
|
interval_end=minutes(gsp_config.interval_end_minutes),
|
|
95
96
|
)
|
|
96
|
-
|
|
97
|
+
|
|
97
98
|
sliced_datasets_dict["gsp"] = xr.concat([da_gsp_past, da_gsp_future], dim="time_utc")
|
|
98
99
|
|
|
99
100
|
if "site" in datasets_dict:
|
|
@@ -102,7 +103,7 @@ def slice_datasets_by_time(
|
|
|
102
103
|
sliced_datasets_dict["site"] = select_time_slice(
|
|
103
104
|
datasets_dict["site"],
|
|
104
105
|
t0,
|
|
105
|
-
|
|
106
|
+
time_resolution=minutes(site_config.time_resolution_minutes),
|
|
106
107
|
interval_start=minutes(site_config.interval_start_minutes),
|
|
107
108
|
interval_end=minutes(site_config.interval_end_minutes),
|
|
108
109
|
)
|
|
@@ -120,4 +121,4 @@ def slice_datasets_by_time(
|
|
|
120
121
|
site_dropout_time,
|
|
121
122
|
)
|
|
122
123
|
|
|
123
|
-
return sliced_datasets_dict
|
|
124
|
+
return sliced_datasets_dict
|
|
@@ -1,41 +1,42 @@
|
|
|
1
|
-
"""Torch dataset for UK PVNet"""
|
|
1
|
+
"""Torch dataset for UK PVNet."""
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
from importlib.resources import files
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
7
|
import xarray as xr
|
|
8
8
|
from torch.utils.data import Dataset
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
9
11
|
from ocf_data_sampler.config import Configuration, load_yaml_configuration
|
|
10
|
-
from ocf_data_sampler.load.load_dataset import get_dataset_dict
|
|
11
|
-
from ocf_data_sampler.select import (
|
|
12
|
-
fill_time_periods,
|
|
13
|
-
Location,
|
|
14
|
-
slice_datasets_by_space,
|
|
15
|
-
slice_datasets_by_time,
|
|
16
|
-
)
|
|
17
|
-
from ocf_data_sampler.utils import minutes
|
|
18
12
|
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
|
|
13
|
+
from ocf_data_sampler.load.load_dataset import get_dataset_dict
|
|
19
14
|
from ocf_data_sampler.numpy_sample import (
|
|
15
|
+
convert_gsp_to_numpy_sample,
|
|
20
16
|
convert_nwp_to_numpy_sample,
|
|
21
17
|
convert_satellite_to_numpy_sample,
|
|
22
|
-
convert_gsp_to_numpy_sample,
|
|
23
18
|
make_sun_position_numpy_sample,
|
|
24
19
|
)
|
|
20
|
+
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
|
|
25
21
|
from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
|
|
26
22
|
from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
|
|
27
|
-
from ocf_data_sampler.
|
|
23
|
+
from ocf_data_sampler.select import (
|
|
24
|
+
Location,
|
|
25
|
+
fill_time_periods,
|
|
26
|
+
slice_datasets_by_space,
|
|
27
|
+
slice_datasets_by_time,
|
|
28
|
+
)
|
|
28
29
|
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
|
|
29
|
-
from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
|
|
30
30
|
from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
|
|
31
|
-
merge_dicts,
|
|
32
31
|
fill_nans_in_arrays,
|
|
32
|
+
merge_dicts,
|
|
33
33
|
)
|
|
34
|
+
from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
|
|
34
35
|
from ocf_data_sampler.torch_datasets.utils.validate_channels import (
|
|
35
36
|
validate_nwp_channels,
|
|
36
37
|
validate_satellite_channels,
|
|
37
38
|
)
|
|
38
|
-
|
|
39
|
+
from ocf_data_sampler.utils import minutes
|
|
39
40
|
|
|
40
41
|
xr.set_options(keep_attrs=True)
|
|
41
42
|
|
|
@@ -45,14 +46,12 @@ def process_and_combine_datasets(
|
|
|
45
46
|
config: Configuration,
|
|
46
47
|
t0: pd.Timestamp,
|
|
47
48
|
location: Location,
|
|
48
|
-
target_key: str = 'gsp'
|
|
49
49
|
) -> dict:
|
|
50
|
-
|
|
51
|
-
"""Normalise and convert data to numpy arrays"""
|
|
50
|
+
"""Normalise and convert data to numpy arrays."""
|
|
52
51
|
numpy_modalities = []
|
|
53
52
|
|
|
54
53
|
if "nwp" in dataset_dict:
|
|
55
|
-
nwp_numpy_modalities =
|
|
54
|
+
nwp_numpy_modalities = {}
|
|
56
55
|
|
|
57
56
|
for nwp_key, da_nwp in dataset_dict["nwp"].items():
|
|
58
57
|
provider = config.input_data.nwp[nwp_key].provider
|
|
@@ -71,41 +70,50 @@ def process_and_combine_datasets(
|
|
|
71
70
|
da_sat = (da_sat - RSS_MEAN) / RSS_STD
|
|
72
71
|
numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
|
|
73
72
|
|
|
74
|
-
gsp_config = config.input_data.gsp
|
|
75
|
-
|
|
76
73
|
if "gsp" in dataset_dict:
|
|
74
|
+
gsp_config = config.input_data.gsp
|
|
77
75
|
da_gsp = dataset_dict["gsp"]
|
|
78
76
|
da_gsp = da_gsp / da_gsp.effective_capacity_mwp
|
|
79
|
-
|
|
77
|
+
|
|
80
78
|
# Convert to NumpyBatch
|
|
81
79
|
numpy_modalities.append(
|
|
82
80
|
convert_gsp_to_numpy_sample(
|
|
83
|
-
da_gsp,
|
|
84
|
-
t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
|
|
85
|
-
)
|
|
81
|
+
da_gsp,
|
|
82
|
+
t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes,
|
|
83
|
+
),
|
|
86
84
|
)
|
|
87
85
|
|
|
88
|
-
|
|
89
|
-
|
|
86
|
+
# Add GSP location data
|
|
87
|
+
numpy_modalities.append(
|
|
88
|
+
{
|
|
89
|
+
GSPSampleKey.gsp_id: location.id,
|
|
90
|
+
GSPSampleKey.x_osgb: location.x,
|
|
91
|
+
GSPSampleKey.y_osgb: location.y,
|
|
92
|
+
},
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
# Only add solar position if explicitly configured
|
|
96
|
+
has_solar_config = (
|
|
97
|
+
hasattr(config.input_data, "solar_position") and
|
|
98
|
+
config.input_data.solar_position is not None
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
if has_solar_config:
|
|
102
|
+
solar_config = config.input_data.solar_position
|
|
103
|
+
|
|
104
|
+
# Create datetime range for solar position calculation
|
|
90
105
|
datetimes = pd.date_range(
|
|
91
|
-
t0+minutes(
|
|
92
|
-
t0+minutes(
|
|
93
|
-
freq=minutes(
|
|
106
|
+
t0 + minutes(solar_config.interval_start_minutes),
|
|
107
|
+
t0 + minutes(solar_config.interval_end_minutes),
|
|
108
|
+
freq=minutes(solar_config.time_resolution_minutes),
|
|
94
109
|
)
|
|
95
110
|
|
|
111
|
+
# Convert OSGB coordinates to lon/lat
|
|
96
112
|
lon, lat = osgb_to_lon_lat(location.x, location.y)
|
|
97
113
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
GSPSampleKey.x_osgb: location.x,
|
|
102
|
-
GSPSampleKey.y_osgb: location.y,
|
|
103
|
-
}
|
|
104
|
-
)
|
|
105
|
-
|
|
106
|
-
numpy_modalities.append(
|
|
107
|
-
make_sun_position_numpy_sample(datetimes, lon, lat, key_prefix=target_key)
|
|
108
|
-
)
|
|
114
|
+
# Calculate solar positions and add to modalities
|
|
115
|
+
solar_positions = make_sun_position_numpy_sample(datetimes, lon, lat)
|
|
116
|
+
numpy_modalities.append(solar_positions)
|
|
109
117
|
|
|
110
118
|
# Combine all the modalities and fill NaNs
|
|
111
119
|
combined_sample = merge_dicts(numpy_modalities)
|
|
@@ -115,7 +123,7 @@ def process_and_combine_datasets(
|
|
|
115
123
|
|
|
116
124
|
|
|
117
125
|
def compute(xarray_dict: dict) -> dict:
|
|
118
|
-
"""Eagerly load a nested dictionary of xarray DataArrays"""
|
|
126
|
+
"""Eagerly load a nested dictionary of xarray DataArrays."""
|
|
119
127
|
for k, v in xarray_dict.items():
|
|
120
128
|
if isinstance(v, dict):
|
|
121
129
|
xarray_dict[k] = compute(v)
|
|
@@ -125,59 +133,58 @@ def compute(xarray_dict: dict) -> dict:
|
|
|
125
133
|
|
|
126
134
|
|
|
127
135
|
def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
|
|
128
|
-
"""Find the t0 times where all of the requested input data is available
|
|
136
|
+
"""Find the t0 times where all of the requested input data is available.
|
|
129
137
|
|
|
130
138
|
Args:
|
|
131
139
|
datasets_dict: A dictionary of input datasets
|
|
132
140
|
config: Configuration file
|
|
133
141
|
"""
|
|
134
|
-
|
|
135
142
|
valid_time_periods = find_valid_time_periods(datasets_dict, config)
|
|
136
143
|
|
|
137
144
|
# Fill out the contiguous time periods to get the t0 times
|
|
138
145
|
valid_t0_times = fill_time_periods(
|
|
139
|
-
valid_time_periods,
|
|
140
|
-
freq=minutes(config.input_data.gsp.time_resolution_minutes)
|
|
146
|
+
valid_time_periods,
|
|
147
|
+
freq=minutes(config.input_data.gsp.time_resolution_minutes),
|
|
141
148
|
)
|
|
142
|
-
|
|
143
149
|
return valid_t0_times
|
|
144
150
|
|
|
145
151
|
|
|
146
152
|
def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
|
|
147
|
-
"""Get list of locations of all GSPs"""
|
|
148
|
-
|
|
153
|
+
"""Get list of locations of all GSPs."""
|
|
149
154
|
if gsp_ids is None:
|
|
150
|
-
gsp_ids =
|
|
151
|
-
|
|
155
|
+
gsp_ids = list(range(1, 318))
|
|
156
|
+
|
|
152
157
|
locations = []
|
|
153
158
|
|
|
154
159
|
# Load UK GSP locations
|
|
155
160
|
df_gsp_loc = pd.read_csv(
|
|
156
|
-
|
|
161
|
+
files("ocf_data_sampler.data").joinpath("uk_gsp_locations.csv"),
|
|
157
162
|
index_col="gsp_id",
|
|
158
163
|
)
|
|
159
164
|
|
|
160
165
|
for gsp_id in gsp_ids:
|
|
161
166
|
locations.append(
|
|
162
167
|
Location(
|
|
163
|
-
coordinate_system
|
|
168
|
+
coordinate_system="osgb",
|
|
164
169
|
x=df_gsp_loc.loc[gsp_id].x_osgb,
|
|
165
170
|
y=df_gsp_loc.loc[gsp_id].y_osgb,
|
|
166
171
|
id=gsp_id,
|
|
167
|
-
)
|
|
172
|
+
),
|
|
168
173
|
)
|
|
169
174
|
return locations
|
|
170
175
|
|
|
171
176
|
|
|
172
177
|
class PVNetUKRegionalDataset(Dataset):
|
|
178
|
+
"""A torch Dataset for creating PVNet UK regional samples."""
|
|
179
|
+
|
|
173
180
|
def __init__(
|
|
174
|
-
self,
|
|
175
|
-
config_filename: str,
|
|
181
|
+
self,
|
|
182
|
+
config_filename: str,
|
|
176
183
|
start_time: str | None = None,
|
|
177
184
|
end_time: str | None = None,
|
|
178
185
|
gsp_ids: list[int] | None = None,
|
|
179
|
-
):
|
|
180
|
-
"""A torch Dataset for creating PVNet UK GSP samples
|
|
186
|
+
) -> None:
|
|
187
|
+
"""A torch Dataset for creating PVNet UK GSP samples.
|
|
181
188
|
|
|
182
189
|
Args:
|
|
183
190
|
config_filename: Path to the configuration file
|
|
@@ -185,31 +192,30 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
185
192
|
end_time: Limit the init-times to be before this
|
|
186
193
|
gsp_ids: List of GSP IDs to create samples for. Defaults to all
|
|
187
194
|
"""
|
|
188
|
-
|
|
189
195
|
# config = load_yaml_configuration(config_filename)
|
|
190
196
|
config: Configuration = load_yaml_configuration(config_filename)
|
|
191
197
|
validate_nwp_channels(config)
|
|
192
198
|
validate_satellite_channels(config)
|
|
193
199
|
|
|
194
200
|
datasets_dict = get_dataset_dict(config.input_data)
|
|
195
|
-
|
|
201
|
+
|
|
196
202
|
# Get t0 times where all input data is available
|
|
197
203
|
valid_t0_times = find_valid_t0_times(datasets_dict, config)
|
|
198
204
|
|
|
199
205
|
# Filter t0 times to given range
|
|
200
206
|
if start_time is not None:
|
|
201
|
-
valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
|
|
202
|
-
|
|
207
|
+
valid_t0_times = valid_t0_times[valid_t0_times >= pd.Timestamp(start_time)]
|
|
208
|
+
|
|
203
209
|
if end_time is not None:
|
|
204
|
-
valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
|
|
210
|
+
valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
|
|
205
211
|
|
|
206
212
|
# Construct list of locations to sample from
|
|
207
213
|
locations = get_gsp_locations(gsp_ids)
|
|
208
214
|
|
|
209
215
|
# Construct a lookup for locations - useful for users to construct sample by GSP ID
|
|
210
216
|
location_lookup = {loc.id: loc for loc in locations}
|
|
211
|
-
|
|
212
|
-
#
|
|
217
|
+
|
|
218
|
+
# Construct indices for sampling
|
|
213
219
|
t_index, loc_index = np.meshgrid(
|
|
214
220
|
np.arange(len(valid_t0_times)),
|
|
215
221
|
np.arange(len(locations)),
|
|
@@ -217,7 +223,7 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
217
223
|
|
|
218
224
|
# Make array of all possible (t0, location) coordinates. Each row is a single coordinate
|
|
219
225
|
index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T
|
|
220
|
-
|
|
226
|
+
|
|
221
227
|
# Assign coords and indices to self
|
|
222
228
|
self.valid_t0_times = valid_t0_times
|
|
223
229
|
self.locations = locations
|
|
@@ -227,15 +233,14 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
227
233
|
# Assign config and input data to self
|
|
228
234
|
self.datasets_dict = datasets_dict
|
|
229
235
|
self.config = config
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
def __len__(self):
|
|
236
|
+
|
|
237
|
+
@override
|
|
238
|
+
def __len__(self) -> int:
|
|
233
239
|
return len(self.index_pairs)
|
|
234
|
-
|
|
235
|
-
|
|
240
|
+
|
|
236
241
|
def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
|
|
237
|
-
"""Generate the PVNet sample for given coordinates
|
|
238
|
-
|
|
242
|
+
"""Generate the PVNet sample for given coordinates.
|
|
243
|
+
|
|
239
244
|
Args:
|
|
240
245
|
t0: init-time for sample
|
|
241
246
|
location: location for sample
|
|
@@ -245,49 +250,51 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
245
250
|
sample_dict = compute(sample_dict)
|
|
246
251
|
|
|
247
252
|
sample = process_and_combine_datasets(sample_dict, self.config, t0, location)
|
|
248
|
-
|
|
253
|
+
|
|
249
254
|
return sample
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
def __getitem__(self, idx):
|
|
253
|
-
|
|
255
|
+
|
|
256
|
+
@override
|
|
257
|
+
def __getitem__(self, idx: int) -> dict:
|
|
254
258
|
# Get the coordinates of the sample
|
|
255
259
|
t_index, loc_index = self.index_pairs[idx]
|
|
256
260
|
location = self.locations[loc_index]
|
|
257
261
|
t0 = self.valid_t0_times[t_index]
|
|
258
|
-
|
|
262
|
+
|
|
259
263
|
# Generate the sample
|
|
260
264
|
return self._get_sample(t0, location)
|
|
261
|
-
|
|
262
265
|
|
|
263
266
|
def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
|
|
264
|
-
"""Generate a sample for the given coordinates.
|
|
265
|
-
|
|
267
|
+
"""Generate a sample for the given coordinates.
|
|
268
|
+
|
|
266
269
|
Useful for users to generate specific samples.
|
|
267
|
-
|
|
270
|
+
|
|
268
271
|
Args:
|
|
269
272
|
t0: init-time for sample
|
|
270
273
|
gsp_id: GSP ID
|
|
271
274
|
"""
|
|
272
275
|
# Check the user has asked for a sample which we have the data for
|
|
273
|
-
|
|
274
|
-
|
|
276
|
+
if t0 not in self.valid_t0_times:
|
|
277
|
+
raise ValueError(f"Input init time '{t0!s}' not in valid times")
|
|
278
|
+
if gsp_id not in self.location_lookup:
|
|
279
|
+
raise ValueError(f"Input GSP '{gsp_id}' not known")
|
|
275
280
|
|
|
276
281
|
location = self.location_lookup[gsp_id]
|
|
277
|
-
|
|
282
|
+
|
|
278
283
|
return self._get_sample(t0, location)
|
|
279
|
-
|
|
280
|
-
|
|
284
|
+
|
|
285
|
+
|
|
281
286
|
class PVNetUKConcurrentDataset(Dataset):
|
|
287
|
+
"""A torch Dataset for creating concurrent PVNet UK regional samples."""
|
|
288
|
+
|
|
282
289
|
def __init__(
|
|
283
|
-
self,
|
|
284
|
-
config_filename: str,
|
|
290
|
+
self,
|
|
291
|
+
config_filename: str,
|
|
285
292
|
start_time: str | None = None,
|
|
286
293
|
end_time: str | None = None,
|
|
287
294
|
gsp_ids: list[int] | None = None,
|
|
288
|
-
):
|
|
289
|
-
"""A torch Dataset for creating concurrent samples of PVNet UK regional data
|
|
290
|
-
|
|
295
|
+
) -> None:
|
|
296
|
+
"""A torch Dataset for creating concurrent samples of PVNet UK regional data.
|
|
297
|
+
|
|
291
298
|
Each concurrent sample includes the data from all GSPs for a single t0 time
|
|
292
299
|
|
|
293
300
|
Args:
|
|
@@ -296,7 +303,6 @@ class PVNetUKConcurrentDataset(Dataset):
|
|
|
296
303
|
end_time: Limit the init-times to be before this
|
|
297
304
|
gsp_ids: List of all GSP IDs included in each sample. Defaults to all
|
|
298
305
|
"""
|
|
299
|
-
|
|
300
306
|
config = load_yaml_configuration(config_filename)
|
|
301
307
|
|
|
302
308
|
# Validate channels for NWP and satellite data
|
|
@@ -304,20 +310,20 @@ class PVNetUKConcurrentDataset(Dataset):
|
|
|
304
310
|
validate_satellite_channels(config)
|
|
305
311
|
|
|
306
312
|
datasets_dict = get_dataset_dict(config.input_data)
|
|
307
|
-
|
|
313
|
+
|
|
308
314
|
# Get t0 times where all input data is available
|
|
309
315
|
valid_t0_times = find_valid_t0_times(datasets_dict, config)
|
|
310
316
|
|
|
311
317
|
# Filter t0 times to given range
|
|
312
318
|
if start_time is not None:
|
|
313
|
-
valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
|
|
314
|
-
|
|
319
|
+
valid_t0_times = valid_t0_times[valid_t0_times >= pd.Timestamp(start_time)]
|
|
320
|
+
|
|
315
321
|
if end_time is not None:
|
|
316
|
-
valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
|
|
322
|
+
valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
|
|
317
323
|
|
|
318
324
|
# Construct list of locations to sample from
|
|
319
325
|
locations = get_gsp_locations(gsp_ids)
|
|
320
|
-
|
|
326
|
+
|
|
321
327
|
# Assign coords and indices to self
|
|
322
328
|
self.valid_t0_times = valid_t0_times
|
|
323
329
|
self.locations = locations
|
|
@@ -325,48 +331,50 @@ class PVNetUKConcurrentDataset(Dataset):
|
|
|
325
331
|
# Assign config and input data to self
|
|
326
332
|
self.datasets_dict = datasets_dict
|
|
327
333
|
self.config = config
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
def __len__(self):
|
|
334
|
+
|
|
335
|
+
@override
|
|
336
|
+
def __len__(self) -> int:
|
|
331
337
|
return len(self.valid_t0_times)
|
|
332
|
-
|
|
333
|
-
|
|
338
|
+
|
|
334
339
|
def _get_sample(self, t0: pd.Timestamp) -> dict:
|
|
335
|
-
"""Generate a concurrent PVNet sample for given init-time
|
|
336
|
-
|
|
340
|
+
"""Generate a concurrent PVNet sample for given init-time.
|
|
341
|
+
|
|
337
342
|
Args:
|
|
338
343
|
t0: init-time for sample
|
|
339
344
|
"""
|
|
340
345
|
# Slice by time then load to avoid loading the data multiple times from disk
|
|
341
346
|
sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
|
|
342
347
|
sample_dict = compute(sample_dict)
|
|
343
|
-
|
|
348
|
+
|
|
344
349
|
gsp_samples = []
|
|
345
|
-
|
|
350
|
+
|
|
346
351
|
# Prepare sample for each GSP
|
|
347
352
|
for location in self.locations:
|
|
348
353
|
gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
|
|
349
354
|
gsp_numpy_sample = process_and_combine_datasets(
|
|
350
|
-
gsp_sample_dict,
|
|
355
|
+
gsp_sample_dict,
|
|
356
|
+
self.config,
|
|
357
|
+
t0,
|
|
358
|
+
location,
|
|
351
359
|
)
|
|
352
360
|
gsp_samples.append(gsp_numpy_sample)
|
|
353
|
-
|
|
361
|
+
|
|
354
362
|
# Stack GSP samples
|
|
355
363
|
return stack_np_samples_into_batch(gsp_samples)
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
def __getitem__(self, idx):
|
|
364
|
+
|
|
365
|
+
@override
|
|
366
|
+
def __getitem__(self, idx: int) -> dict:
|
|
359
367
|
return self._get_sample(self.valid_t0_times[idx])
|
|
360
|
-
|
|
361
368
|
|
|
362
369
|
def get_sample(self, t0: pd.Timestamp) -> dict:
|
|
363
|
-
"""Generate a sample for the given init-time.
|
|
364
|
-
|
|
370
|
+
"""Generate a sample for the given init-time.
|
|
371
|
+
|
|
365
372
|
Useful for users to generate specific samples.
|
|
366
|
-
|
|
373
|
+
|
|
367
374
|
Args:
|
|
368
375
|
t0: init-time for sample
|
|
369
376
|
"""
|
|
370
377
|
# Check data is availablle for init-time t0
|
|
371
|
-
|
|
378
|
+
if t0 not in self.valid_t0_times:
|
|
379
|
+
raise ValueError(f"Input init time '{t0!s}' not in valid times")
|
|
372
380
|
return self._get_sample(t0)
|