ocf-data-sampler 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- ocf_data_sampler/config/model.py +34 -0
- ocf_data_sampler/load/load_dataset.py +55 -0
- ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
- ocf_data_sampler/load/site.py +30 -0
- ocf_data_sampler/numpy_batch/__init__.py +4 -3
- ocf_data_sampler/numpy_batch/gsp.py +12 -12
- ocf_data_sampler/numpy_batch/nwp.py +14 -14
- ocf_data_sampler/numpy_batch/satellite.py +8 -8
- ocf_data_sampler/numpy_batch/site.py +29 -0
- ocf_data_sampler/select/__init__.py +8 -1
- ocf_data_sampler/select/dropout.py +2 -1
- ocf_data_sampler/select/geospatial.py +43 -1
- ocf_data_sampler/select/select_spatial_slice.py +8 -2
- ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
- ocf_data_sampler/select/time_slice_for_dataset.py +124 -0
- ocf_data_sampler/time_functions.py +11 -0
- ocf_data_sampler/torch_datasets/process_and_combine.py +153 -0
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +8 -418
- ocf_data_sampler/torch_datasets/site.py +196 -0
- ocf_data_sampler/torch_datasets/valid_time_periods.py +108 -0
- {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/METADATA +1 -1
- ocf_data_sampler-0.0.25.dist-info/RECORD +66 -0
- {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/top_level.txt +1 -0
- scripts/refactor_site.py +50 -0
- tests/conftest.py +62 -0
- tests/load/test_load_sites.py +14 -0
- tests/numpy_batch/test_gsp.py +1 -2
- tests/numpy_batch/test_nwp.py +1 -3
- tests/numpy_batch/test_satellite.py +1 -3
- tests/numpy_batch/test_sun_position.py +7 -7
- tests/torch_datasets/test_pvnet_uk_regional.py +4 -6
- tests/torch_datasets/test_site.py +85 -0
- ocf_data_sampler-0.0.23.dist-info/RECORD +0 -54
- {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/LICENSE +0 -0
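The substance of this release is a refactor of the monolithic pvnet_uk_regional.py: dataset loading moves to load/load_dataset.py, sample assembly to torch_datasets/process_and_combine.py, t0-validity logic to torch_datasets/valid_time_periods.py, spatial and temporal slicing to the select package, and a parallel site dataset is added. Below is a minimal sketch of how the relocated pieces compose for the UK regional case, assuming the import paths and call signatures shown in the hunks that follow; the config path is a placeholder and the GSP id is illustrative.

from ocf_data_sampler.config import load_yaml_configuration
from ocf_data_sampler.load.load_dataset import get_dataset_dict
from ocf_data_sampler.select import slice_datasets_by_space, slice_datasets_by_time
from ocf_data_sampler.torch_datasets.process_and_combine import (
    compute,
    process_and_combine_datasets,
)
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import (
    find_valid_t0_times,
    get_gsp_locations,
)

# Placeholder path; the YAML must point at real GSP/NWP/satellite zarr inputs
config = load_yaml_configuration("config.yaml")
datasets = get_dataset_dict(config)

# One valid init-time and one GSP location define a sample
t0 = find_valid_t0_times(datasets, config)[0]
location = get_gsp_locations(gsp_ids=[1])[0]

# Slice spatially and temporally, load eagerly, then normalise into numpy arrays
sliced = slice_datasets_by_space(datasets, location, config)
sliced = slice_datasets_by_time(sliced, t0, config)
sample = process_and_combine_datasets(compute(sliced), config, t0, location)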
ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -0,0 +1,153 @@
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+from ocf_data_sampler.config import Configuration
+from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
+from ocf_data_sampler.numpy_batch import (
+    convert_nwp_to_numpy_batch,
+    convert_satellite_to_numpy_batch,
+    convert_gsp_to_numpy_batch,
+    make_sun_position_numpy_batch,
+    convert_site_to_numpy_batch,
+)
+from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
+from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
+from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
+from ocf_data_sampler.select.location import Location
+from ocf_data_sampler.time_functions import minutes
+
+
+def process_and_combine_datasets(
+    dataset_dict: dict,
+    config: Configuration,
+    t0: pd.Timestamp,
+    location: Location,
+    sun_position_key: str = 'gsp'
+) -> dict:
+    """Normalize and convert data to numpy arrays"""
+
+    numpy_modalities = []
+
+    if "nwp" in dataset_dict:
+
+        nwp_numpy_modalities = dict()
+
+        for nwp_key, da_nwp in dataset_dict["nwp"].items():
+            # Standardise
+            provider = config.input_data.nwp[nwp_key].nwp_provider
+            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+            # Convert to NumpyBatch
+            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
+
+        # Combine the NWPs into NumpyBatch
+        numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})
+
+    if "sat" in dataset_dict:
+        # Satellite is already in the range [0-1] so no need to standardise
+        da_sat = dataset_dict["sat"]
+
+        # Convert to NumpyBatch
+        numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
+
+    gsp_config = config.input_data.gsp
+
+    if "gsp" in dataset_dict:
+        da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
+        da_gsp = da_gsp / da_gsp.effective_capacity_mwp
+
+        numpy_modalities.append(
+            convert_gsp_to_numpy_batch(
+                da_gsp, t0_idx=gsp_config.history_minutes // gsp_config.time_resolution_minutes
+            )
+        )
+
+        # Add coordinate data
+        # TODO: Do we need all of these?
+        numpy_modalities.append(
+            {
+                GSPBatchKey.gsp_id: location.id,
+                GSPBatchKey.x_osgb: location.x,
+                GSPBatchKey.y_osgb: location.y,
+            }
+        )
+
+
+    if "site" in dataset_dict:
+        site_config = config.input_data.site
+        da_sites = dataset_dict["site"]
+        da_sites = da_sites / da_sites.capacity_kwp
+
+        numpy_modalities.append(
+            convert_site_to_numpy_batch(
+                da_sites, t0_idx=site_config.history_minutes / site_config.time_resolution_minutes
+            )
+        )
+
+    if sun_position_key == 'gsp':
+        # Make sun coords NumpyBatch
+        datetimes = pd.date_range(
+            t0 - minutes(gsp_config.history_minutes),
+            t0 + minutes(gsp_config.forecast_minutes),
+            freq=minutes(gsp_config.time_resolution_minutes),
+        )
+
+        lon, lat = osgb_to_lon_lat(location.x, location.y)
+        key_prefix = "gsp"
+
+    elif sun_position_key == 'site':
+        # Make sun coords NumpyBatch
+        datetimes = pd.date_range(
+            t0 - minutes(site_config.history_minutes),
+            t0 + minutes(site_config.forecast_minutes),
+            freq=minutes(site_config.time_resolution_minutes),
+        )
+
+        lon, lat = location.x, location.y
+        key_prefix = "site"
+
+    numpy_modalities.append(
+        make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=key_prefix)
+    )
+
+    # Combine all the modalities and fill NaNs
+    combined_sample = merge_dicts(numpy_modalities)
+    combined_sample = fill_nans_in_arrays(combined_sample)
+
+    return combined_sample
+
+
+def merge_dicts(list_of_dicts: list[dict]) -> dict:
+    """Merge a list of dictionaries into a single dictionary"""
+    # TODO: This doesn't account for duplicate keys, which will be overwritten
+    combined_dict = {}
+    for d in list_of_dicts:
+        combined_dict.update(d)
+    return combined_dict
+
+
+def fill_nans_in_arrays(batch: dict) -> dict:
+    """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
+
+    Operation is performed in-place on the batch.
+    """
+    for k, v in batch.items():
+        if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
+            if np.isnan(v).any():
+                batch[k] = np.nan_to_num(v, copy=False, nan=0.0)
+
+        # Recursion is included to reach NWP arrays in subdict
+        elif isinstance(v, dict):
+            fill_nans_in_arrays(v)
+
+    return batch
+
+
+def compute(xarray_dict: dict) -> dict:
+    """Eagerly load a nested dictionary of xarray DataArrays"""
+    for k, v in xarray_dict.items():
+        if isinstance(v, dict):
+            xarray_dict[k] = compute(v)
+        else:
+            xarray_dict[k] = v.compute(scheduler="single-threaded")
+    return xarray_dict
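Since merge_dicts and fill_nans_in_arrays above carry the final sample assembly, a small sketch of their combined behaviour may help; the module path is the one added in this release, and, as the TODO notes, duplicate keys across modalities are silently overwritten.

import numpy as np

from ocf_data_sampler.torch_datasets.process_and_combine import (
    fill_nans_in_arrays,
    merge_dicts,
)

modalities = [
    {"sat": np.array([0.5, np.nan])},
    {"nwp": {"ukv": np.array([np.nan, 1.0])}},
]

batch = merge_dicts(modalities)      # flat update; later duplicate keys would win
batch = fill_nans_in_arrays(batch)   # NaN -> 0.0, recursing into the nested NWP dict

assert batch["sat"][1] == 0.0
assert batch["nwp"]["ukv"][0] == 0.0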
ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
@@ -2,100 +2,20 @@
 
 import numpy as np
 import pandas as pd
+import pkg_resources
 import xarray as xr
 from torch.utils.data import Dataset
-import pkg_resources
-
-from ocf_data_sampler.load.gsp import open_gsp
-from ocf_data_sampler.load.nwp import open_nwp
-from ocf_data_sampler.load.satellite import open_sat_data
-
-from ocf_data_sampler.select.find_contiguous_time_periods import (
-    find_contiguous_t0_periods, find_contiguous_t0_periods_nwp,
-    intersection_of_multiple_dataframes_of_periods,
-)
-from ocf_data_sampler.select.fill_time_periods import fill_time_periods
-from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
-from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
-from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
-
-from ocf_data_sampler.numpy_batch import (
-    convert_gsp_to_numpy_batch,
-    convert_nwp_to_numpy_batch,
-    convert_satellite_to_numpy_batch,
-    make_sun_position_numpy_batch,
-)
-
 
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
-from ocf_data_sampler.
-from ocf_data_sampler.
-
-from ocf_data_sampler.
-from ocf_data_sampler.
-
-from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
-
-
-
+from ocf_data_sampler.load.load_dataset import get_dataset_dict
+from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time
+from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
+from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
 
 xr.set_options(keep_attrs=True)
 
 
-
-def minutes(minutes: list[float]):
-    """Timedelta minutes
-
-    Args:
-        m: minutes
-    """
-    return pd.to_timedelta(minutes, unit="m")
-
-
-def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataArray]]:
-    """Construct dictionary of all of the input data sources
-
-    Args:
-        config: Configuration file
-    """
-
-    in_config = config.input_data
-
-    datasets_dict = {}
-
-    # Load GSP data unless the path is None
-    if in_config.gsp.gsp_zarr_path:
-        da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path).compute()
-
-        # Remove national GSP
-        datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
-
-    # Load NWP data if in config
-    if in_config.nwp:
-
-        datasets_dict["nwp"] = {}
-        for nwp_source, nwp_config in in_config.nwp.items():
-
-            da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
-
-            da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))
-
-            datasets_dict["nwp"][nwp_source] = da_nwp
-
-    # Load satellite data if in config
-    if in_config.satellite:
-        sat_config = config.input_data.satellite
-
-        da_sat = open_sat_data(sat_config.satellite_zarr_path)
-
-        da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))
-
-        datasets_dict["sat"] = da_sat
-
-    return datasets_dict
-
-
-
 def find_valid_t0_times(
     datasets_dict: dict,
     config: Configuration,
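The inline minutes helper deleted above now lives in ocf_data_sampler.time_functions (the +11-line file in the list at the top). Assuming its behaviour is unchanged from the removed body, pd.to_timedelta(minutes, unit="m"), it accepts both scalars and sequences:

import pandas as pd

from ocf_data_sampler.time_functions import minutes

assert minutes(30) == pd.Timedelta(minutes=30)
# A sequence maps to a TimedeltaIndex, handy for lists of dropout offsets
assert list(minutes([-30, 0])) == [pd.Timedelta(minutes=-30), pd.Timedelta(0)]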
@@ -103,96 +23,11 @@ def find_valid_t0_times(
     """Find the t0 times where all of the requested input data is available
 
     Args:
-        datasets_dict: A dictionary of input datasets
+        datasets_dict: A dictionary of input datasets
         config: Configuration file
     """
 
-
-
-    contiguous_time_periods: dict[str: pd.DataFrame] = {}  # Used to store contiguous time periods from each data source
-
-    if "nwp" in datasets_dict:
-        for nwp_key, nwp_config in config.input_data.nwp.items():
-
-            da = datasets_dict["nwp"][nwp_key]
-
-            if nwp_config.dropout_timedeltas_minutes is None:
-                max_dropout = minutes(0)
-            else:
-                max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes)))
-
-            if nwp_config.max_staleness_minutes is None:
-                max_staleness = None
-            else:
-                max_staleness = minutes(nwp_config.max_staleness_minutes)
-
-            # The last step of the forecast is lost if we have to diff channels
-            if len(nwp_config.nwp_accum_channels) > 0:
-                end_buffer = minutes(nwp_config.time_resolution_minutes)
-            else:
-                end_buffer = minutes(0)
-
-            # This is the max staleness we can use considering the max step of the input data
-            max_possible_staleness = (
-                pd.Timedelta(da["step"].max().item())
-                - minutes(nwp_config.forecast_minutes)
-                - end_buffer
-            )
-
-            # Default to use max possible staleness unless specified in config
-            if max_staleness is None:
-                max_staleness = max_possible_staleness
-            else:
-                # Make sure the max acceptable staleness isn't longer than the max possible
-                assert max_staleness <= max_possible_staleness
-
-            time_periods = find_contiguous_t0_periods_nwp(
-                datetimes=pd.DatetimeIndex(da["init_time_utc"]),
-                history_duration=minutes(nwp_config.history_minutes),
-                max_staleness=max_staleness,
-                max_dropout=max_dropout,
-            )
-
-            contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods
-
-    if "sat" in datasets_dict:
-        sat_config = config.input_data.satellite
-
-        time_periods = find_contiguous_t0_periods(
-            pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
-            sample_period_duration=minutes(sat_config.time_resolution_minutes),
-            history_duration=minutes(sat_config.history_minutes),
-            forecast_duration=minutes(sat_config.forecast_minutes),
-        )
-
-        contiguous_time_periods['sat'] = time_periods
-
-    if "gsp" in datasets_dict:
-        gsp_config = config.input_data.gsp
-
-        time_periods = find_contiguous_t0_periods(
-            pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
-            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            history_duration=minutes(gsp_config.history_minutes),
-            forecast_duration=minutes(gsp_config.forecast_minutes),
-        )
-
-        contiguous_time_periods['gsp'] = time_periods
-
-    # just get the values (not the keys)
-    contiguous_time_periods_values = list(contiguous_time_periods.values())
-
-    # Find joint overlapping contiguous time periods
-    if len(contiguous_time_periods_values) > 1:
-        valid_time_periods = intersection_of_multiple_dataframes_of_periods(
-            contiguous_time_periods_values
-        )
-    else:
-        valid_time_periods = contiguous_time_periods_values[0]
-
-    # check there are some valid time periods
-    if len(valid_time_periods) == 0:
-        raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")
+    valid_time_periods = find_valid_time_periods(datasets_dict, config)
 
     # Fill out the contiguous time periods to get the t0 times
     valid_t0_times = fill_time_periods(
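The staleness and contiguity logic deleted in this hunk moves wholesale into the new torch_datasets/valid_time_periods.py (+108 lines in the file list), leaving the single call shown. The extracted function can also be called directly; a sketch under the same (datasets_dict, config) signature seen above, with a placeholder config path:

from ocf_data_sampler.config import load_yaml_configuration
from ocf_data_sampler.load.load_dataset import get_dataset_dict
from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods

config = load_yaml_configuration("config.yaml")  # placeholder path
datasets = get_dataset_dict(config)

# Contiguous periods during which every requested input source is available
valid_time_periods = find_valid_time_periods(datasets, config)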
@@ -203,250 +38,6 @@ def find_valid_t0_times(
     return valid_t0_times
 
 
-def slice_datasets_by_space(
-    datasets_dict: dict,
-    location: Location,
-    config: Configuration,
-) -> dict:
-    """Slice a dictionaries of input data sources around a given location
-
-    Args:
-        datasets_dict: Dictionary of the input data sources
-        location: The location to sample around
-        config: Configuration object.
-    """
-
-    assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})
-
-    sliced_datasets_dict = {}
-
-    if "nwp" in datasets_dict:
-
-        sliced_datasets_dict["nwp"] = {}
-
-        for nwp_key, nwp_config in config.input_data.nwp.items():
-
-            sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
-                datasets_dict["nwp"][nwp_key],
-                location,
-                height_pixels=nwp_config.nwp_image_size_pixels_height,
-                width_pixels=nwp_config.nwp_image_size_pixels_width,
-            )
-
-    if "sat" in datasets_dict:
-        sat_config = config.input_data.satellite
-
-        sliced_datasets_dict["sat"] = select_spatial_slice_pixels(
-            datasets_dict["sat"],
-            location,
-            height_pixels=sat_config.satellite_image_size_pixels_height,
-            width_pixels=sat_config.satellite_image_size_pixels_width,
-        )
-
-    if "gsp" in datasets_dict:
-        sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)
-
-    return sliced_datasets_dict
-
-
-def slice_datasets_by_time(
-    datasets_dict: dict,
-    t0: pd.Timedelta,
-    config: Configuration,
-) -> dict:
-    """Slice a dictionaries of input data sources around a given t0 time
-
-    Args:
-        datasets_dict: Dictionary of the input data sources
-        t0: The init-time
-        config: Configuration object.
-    """
-
-    sliced_datasets_dict = {}
-
-    if "nwp" in datasets_dict:
-
-        sliced_datasets_dict["nwp"] = {}
-
-        for nwp_key, da_nwp in datasets_dict["nwp"].items():
-
-            nwp_config = config.input_data.nwp[nwp_key]
-
-            sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
-                da_nwp,
-                t0,
-                sample_period_duration=minutes(nwp_config.time_resolution_minutes),
-                history_duration=minutes(nwp_config.history_minutes),
-                forecast_duration=minutes(nwp_config.forecast_minutes),
-                dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
-                dropout_frac=nwp_config.dropout_fraction,
-                accum_channels=nwp_config.nwp_accum_channels,
-            )
-
-    if "sat" in datasets_dict:
-
-        sat_config = config.input_data.satellite
-
-        sliced_datasets_dict["sat"] = select_time_slice(
-            datasets_dict["sat"],
-            t0,
-            sample_period_duration=minutes(sat_config.time_resolution_minutes),
-            interval_start=minutes(-sat_config.history_minutes),
-            interval_end=minutes(-sat_config.live_delay_minutes),
-            max_steps_gap=2,
-        )
-
-        # Randomly sample dropout
-        sat_dropout_time = draw_dropout_time(
-            t0,
-            dropout_timedeltas=minutes(sat_config.dropout_timedeltas_minutes),
-            dropout_frac=sat_config.dropout_fraction,
-        )
-
-        # Apply the dropout
-        sliced_datasets_dict["sat"] = apply_dropout_time(
-            sliced_datasets_dict["sat"],
-            sat_dropout_time,
-        )
-
-    if "gsp" in datasets_dict:
-        gsp_config = config.input_data.gsp
-
-        sliced_datasets_dict["gsp_future"] = select_time_slice(
-            datasets_dict["gsp"],
-            t0,
-            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start=minutes(30),
-            interval_end=minutes(gsp_config.forecast_minutes),
-        )
-
-        sliced_datasets_dict["gsp"] = select_time_slice(
-            datasets_dict["gsp"],
-            t0,
-            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start=-minutes(gsp_config.history_minutes),
-            interval_end=minutes(0),
-        )
-
-        # Dropout on the GSP, but not the future GSP
-        gsp_dropout_time = draw_dropout_time(
-            t0,
-            dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
-            dropout_frac=gsp_config.dropout_fraction,
-        )
-
-        sliced_datasets_dict["gsp"] = apply_dropout_time(sliced_datasets_dict["gsp"], gsp_dropout_time)
-
-    return sliced_datasets_dict
-
-
-def fill_nans_in_arrays(batch: dict) -> dict:
-    """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
-
-    Operation is performed in-place on the batch.
-    """
-    for k, v in batch.items():
-        if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
-            if np.isnan(v).any():
-                batch[k] = np.nan_to_num(v, copy=False, nan=0.0)
-
-        # Recursion is included to reach NWP arrays in subdict
-        elif isinstance(v, dict):
-            fill_nans_in_arrays(v)
-
-    return batch
-
-
-
-def merge_dicts(list_of_dicts: list[dict]) -> dict:
-    """Merge a list of dictionaries into a single dictionary"""
-    # TODO: This doesn't account for duplicate keys, which will be overwritten
-    combined_dict = {}
-    for d in list_of_dicts:
-        combined_dict.update(d)
-    return combined_dict
-
-
-def process_and_combine_datasets(
-    dataset_dict: dict,
-    config: Configuration,
-    t0: pd.Timedelta,
-    location: Location,
-) -> dict:
-    """Normalize and convert data to numpy arrays"""
-
-    numpy_modalities = []
-
-    if "nwp" in dataset_dict:
-
-        nwp_numpy_modalities = dict()
-
-        for nwp_key, da_nwp in dataset_dict["nwp"].items():
-            # Standardise
-            provider = config.input_data.nwp[nwp_key].nwp_provider
-            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
-            # Convert to NumpyBatch
-            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
-
-        # Combine the NWPs into NumpyBatch
-        numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})
-
-    if "sat" in dataset_dict:
-        # Satellite is already in the range [0-1] so no need to standardise
-        da_sat = dataset_dict["sat"]
-
-        # Convert to NumpyBatch
-        numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
-
-    gsp_config = config.input_data.gsp
-
-    if "gsp" in dataset_dict:
-        da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
-        da_gsp = da_gsp / da_gsp.effective_capacity_mwp
-
-        numpy_modalities.append(
-            convert_gsp_to_numpy_batch(
-                da_gsp,
-                t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes
-            )
-        )
-
-    # Make sun coords NumpyBatch
-    datetimes = pd.date_range(
-        t0-minutes(gsp_config.history_minutes),
-        t0+minutes(gsp_config.forecast_minutes),
-        freq=minutes(gsp_config.time_resolution_minutes),
-    )
-
-    lon, lat = osgb_to_lon_lat(location.x, location.y)
-
-    numpy_modalities.append(make_sun_position_numpy_batch(datetimes, lon, lat))
-
-    # Add coordinate data
-    # TODO: Do we need all of these?
-    numpy_modalities.append({
-        GSPBatchKey.gsp_id: location.id,
-        GSPBatchKey.gsp_x_osgb: location.x,
-        GSPBatchKey.gsp_y_osgb: location.y,
-    })
-
-    # Combine all the modalities and fill NaNs
-    combined_sample = merge_dicts(numpy_modalities)
-    combined_sample = fill_nans_in_arrays(combined_sample)
-
-    return combined_sample
-
-
-def compute(xarray_dict: dict) -> dict:
-    """Eagerly load a nested dictionary of xarray DataArrays"""
-    for k, v in xarray_dict.items():
-        if isinstance(v, dict):
-            xarray_dict[k] = compute(v)
-        else:
-            xarray_dict[k] = v.compute(scheduler="single-threaded")
-    return xarray_dict
-
-
 def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
     """Get list of locations of all GSPs"""
 
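The dropout calls removed in the hunk above come from ocf_data_sampler.select.dropout, which this release touches only lightly (+2 -1 in the file list). A hedged sketch of the two-step pattern the deleted code used, with illustrative offsets; the exact return semantics of draw_dropout_time (presumably a cutoff timestamp, or no dropout for the sample) are assumed here, not confirmed by this diff:

import pandas as pd

from ocf_data_sampler.select.dropout import apply_dropout_time, draw_dropout_time
from ocf_data_sampler.time_functions import minutes

t0 = pd.Timestamp("2023-01-01 12:00")

# Step 1: randomly draw a dropout cutoff relative to t0 (offsets are illustrative)
dropout_time = draw_dropout_time(
    t0,
    dropout_timedeltas=minutes([-30, -60]),
    dropout_frac=0.5,
)

# Step 2: mask an already-sliced DataArray after the cutoff, e.g.
# da_sat = apply_dropout_time(da_sat, dropout_time)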
@@ -473,7 +64,6 @@ def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
     return locations
 
 
-
 class PVNetUKRegionalDataset(Dataset):
     def __init__(
         self,
|