ocf-data-sampler 0.0.18__py3-none-any.whl → 0.0.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ocf-data-sampler might be problematic.
- ocf_data_sampler/config/__init__.py +5 -0
- ocf_data_sampler/config/load.py +33 -0
- ocf_data_sampler/config/model.py +246 -0
- ocf_data_sampler/config/save.py +73 -0
- ocf_data_sampler/constants.py +173 -0
- ocf_data_sampler/load/load_dataset.py +55 -0
- ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
- ocf_data_sampler/load/site.py +30 -0
- ocf_data_sampler/numpy_sample/__init__.py +8 -0
- ocf_data_sampler/numpy_sample/collate.py +77 -0
- ocf_data_sampler/numpy_sample/gsp.py +34 -0
- ocf_data_sampler/numpy_sample/nwp.py +42 -0
- ocf_data_sampler/numpy_sample/satellite.py +30 -0
- ocf_data_sampler/numpy_sample/site.py +30 -0
- ocf_data_sampler/{numpy_batch → numpy_sample}/sun_position.py +9 -10
- ocf_data_sampler/select/__init__.py +8 -1
- ocf_data_sampler/select/dropout.py +4 -3
- ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
- ocf_data_sampler/select/geospatial.py +160 -0
- ocf_data_sampler/select/location.py +62 -0
- ocf_data_sampler/select/select_spatial_slice.py +13 -16
- ocf_data_sampler/select/select_time_slice.py +24 -33
- ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
- ocf_data_sampler/select/time_slice_for_dataset.py +125 -0
- ocf_data_sampler/torch_datasets/__init__.py +2 -1
- ocf_data_sampler/torch_datasets/process_and_combine.py +131 -0
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +19 -427
- ocf_data_sampler/torch_datasets/site.py +405 -0
- ocf_data_sampler/torch_datasets/valid_time_periods.py +116 -0
- ocf_data_sampler/utils.py +10 -0
- ocf_data_sampler-0.0.42.dist-info/METADATA +153 -0
- ocf_data_sampler-0.0.42.dist-info/RECORD +71 -0
- {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.42.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.42.dist-info}/top_level.txt +1 -0
- scripts/refactor_site.py +50 -0
- tests/config/test_config.py +161 -0
- tests/config/test_save.py +37 -0
- tests/conftest.py +86 -1
- tests/load/test_load_gsp.py +15 -0
- tests/load/test_load_nwp.py +21 -0
- tests/load/test_load_satellite.py +17 -0
- tests/load/test_load_sites.py +14 -0
- tests/numpy_sample/test_collate.py +26 -0
- tests/numpy_sample/test_gsp.py +38 -0
- tests/numpy_sample/test_nwp.py +52 -0
- tests/numpy_sample/test_satellite.py +40 -0
- tests/numpy_sample/test_sun_position.py +81 -0
- tests/select/test_dropout.py +75 -0
- tests/select/test_fill_time_periods.py +28 -0
- tests/select/test_find_contiguous_time_periods.py +202 -0
- tests/select/test_location.py +67 -0
- tests/select/test_select_spatial_slice.py +154 -0
- tests/select/test_select_time_slice.py +272 -0
- tests/torch_datasets/conftest.py +18 -0
- tests/torch_datasets/test_process_and_combine.py +126 -0
- tests/torch_datasets/test_pvnet_uk_regional.py +59 -0
- tests/torch_datasets/test_site.py +129 -0
- ocf_data_sampler/numpy_batch/__init__.py +0 -7
- ocf_data_sampler/numpy_batch/gsp.py +0 -20
- ocf_data_sampler/numpy_batch/nwp.py +0 -33
- ocf_data_sampler/numpy_batch/satellite.py +0 -23
- ocf_data_sampler-0.0.18.dist-info/METADATA +0 -22
- ocf_data_sampler-0.0.18.dist-info/RECORD +0 -32
- {ocf_data_sampler-0.0.18.dist-info → ocf_data_sampler-0.0.42.dist-info}/LICENSE +0 -0

ocf_data_sampler/torch_datasets/site.py (new file)
@@ -0,0 +1,405 @@

```python
"""Torch dataset for sites"""
import logging

import pandas as pd
import xarray as xr
from typing import Tuple

from torch.utils.data import Dataset

from ocf_data_sampler.config import Configuration, load_yaml_configuration
from ocf_data_sampler.load.load_dataset import get_dataset_dict
from ocf_data_sampler.select import (
    Location,
    fill_time_periods,
    find_contiguous_t0_periods,
    intersection_of_multiple_dataframes_of_periods,
    slice_datasets_by_time,
    slice_datasets_by_space,
)
from ocf_data_sampler.utils import minutes
from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
from ocf_data_sampler.torch_datasets.process_and_combine import merge_dicts, fill_nans_in_arrays
from ocf_data_sampler.numpy_sample import (
    NWPSampleKey,
    convert_site_to_numpy_sample,
    convert_satellite_to_numpy_sample,
    convert_nwp_to_numpy_sample,
)
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS

xr.set_options(keep_attrs=True)


class SitesDataset(Dataset):
    def __init__(
        self,
        config_filename: str,
        start_time: str | None = None,
        end_time: str | None = None,
    ):
        """A torch Dataset for creating PVNet site samples.

        Args:
            config_filename: Path to the configuration file
            start_time: Limit the init-times to be after this
            end_time: Limit the init-times to be before this
        """
        config: Configuration = load_yaml_configuration(config_filename)
        datasets_dict = get_dataset_dict(config)

        # Assign config and input data to self
        self.datasets_dict = datasets_dict
        self.config = config

        # Get all site locations
        self.locations = self.get_locations(datasets_dict["site"])

        # Get t0 times where all input data is available
        valid_t0_and_site_ids = self.find_valid_t0_and_site_ids(datasets_dict)

        # Filter t0 times to the given range
        if start_time is not None:
            valid_t0_and_site_ids = valid_t0_and_site_ids[
                valid_t0_and_site_ids["t0"] >= pd.Timestamp(start_time)
            ]

        if end_time is not None:
            valid_t0_and_site_ids = valid_t0_and_site_ids[
                valid_t0_and_site_ids["t0"] <= pd.Timestamp(end_time)
            ]

        # Assign coords and indices to self
        self.valid_t0_and_site_ids = valid_t0_and_site_ids

    def __len__(self):
        return len(self.valid_t0_and_site_ids)

    def __getitem__(self, idx):
        # Get the coordinates of the sample
        t0, site_id = self.valid_t0_and_site_ids.iloc[idx]

        # Get the location from the site id
        location = self.get_location_from_site_id(site_id)

        # Generate the sample
        return self._get_sample(t0, location)

    def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
        """Generate the PVNet sample for the given coordinates.

        Args:
            t0: init-time for sample
            location: location for sample
        """
        sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
        sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)

        sample = self.process_and_combine_site_sample_dict(sample_dict)
        sample = sample.compute()
        return sample

    def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
        """Generate a sample for a given site id and t0.

        Useful for users to generate samples by t0 and site id.

        Args:
            t0: init-time for sample
            site_id: site id as int
        """
        location = self.get_location_from_site_id(site_id)
        return self._get_sample(t0, location)

    def get_location_from_site_id(self, site_id):
        """Get the location for a given site id."""
        locations = [loc for loc in self.locations if loc.id == site_id]
        if len(locations) == 0:
            raise ValueError(f"Location not found for site_id {site_id}")

        if len(locations) > 1:
            logging.warning(
                f"Multiple locations found for site_id {site_id}, but will take the first"
            )

        return locations[0]
```
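For context, a minimal usage sketch of the class above. The config path and dates are placeholders, and the package is assumed to be installed so that the module is importable:

```python
import pandas as pd

from ocf_data_sampler.torch_datasets.site import SitesDataset

# "site_config.yaml" is a hypothetical path to a config describing the
# NWP / satellite / site inputs
dataset = SitesDataset(
    "site_config.yaml",
    start_time="2023-01-01",  # keep only t0 times on or after this
    end_time="2023-12-31",    # keep only t0 times on or before this
)

print(len(dataset))  # number of valid (t0, site_id) pairs

# Index like any torch Dataset...
sample = dataset[0]

# ...or pick a specific site and init-time directly
sample = dataset.get_sample(t0=pd.Timestamp("2023-06-01 12:00"), site_id=1)
```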
site.py (continued):

```python
    def find_valid_t0_and_site_ids(
        self,
        datasets_dict: dict,
    ) -> pd.DataFrame:
        """Find the t0 times where all of the requested input data is available.

        The idea is to:
        1. Get the valid time periods for NWP and satellite
        2. For each site location, find the valid periods for that location

        Args:
            datasets_dict: A dictionary of input datasets
        """
        # 1. Get the valid time periods for NWP and satellite
        datasets_without_site = {k: v for k, v in datasets_dict.items() if k != "site"}
        valid_time_periods = find_valid_time_periods(datasets_without_site, self.config)

        # 2. Loop over each site and find its valid periods
        sites = datasets_dict["site"]
        site_ids = sites.site_id.values
        site_config = self.config.input_data.site
        valid_t0_and_site_ids = []
        for site_id in site_ids:
            site = sites.sel(site_id=site_id)

            # Drop any NaN values
            # TODO: check this is right; should we have a different option
            # if there are no NaNs?
            site = site.dropna(dim="time_utc")

            # Get the valid time periods for this location
            time_periods = find_contiguous_t0_periods(
                pd.DatetimeIndex(site["time_utc"]),
                sample_period_duration=minutes(site_config.time_resolution_minutes),
                interval_start=minutes(site_config.interval_start_minutes),
                interval_end=minutes(site_config.interval_end_minutes),
            )
            valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods(
                [valid_time_periods, time_periods]
            )

            # Fill out the contiguous time periods to get the t0 times
            valid_t0_times_per_site = fill_time_periods(
                valid_time_periods_per_site,
                freq=minutes(site_config.time_resolution_minutes),
            )

            valid_t0_per_site = pd.DataFrame(index=valid_t0_times_per_site)
            valid_t0_per_site["site_id"] = site_id
            valid_t0_and_site_ids.append(valid_t0_per_site)

        valid_t0_and_site_ids = pd.concat(valid_t0_and_site_ids)
        valid_t0_and_site_ids.index.name = "t0"
        valid_t0_and_site_ids.reset_index(inplace=True)

        return valid_t0_and_site_ids
```
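To make the shape of the returned table concrete, here is a small self-contained pandas sketch of the same concat / reset_index pattern, with invented timestamps:

```python
import pandas as pd

# Invented valid t0 times for two sites
t0s_per_site = {
    1: pd.date_range("2023-06-01 00:00", periods=2, freq="30min"),
    2: pd.date_range("2023-06-01 01:00", periods=2, freq="30min"),
}

frames = []
for site_id, t0s in t0s_per_site.items():
    df = pd.DataFrame(index=t0s)
    df["site_id"] = site_id
    frames.append(df)

valid_t0_and_site_ids = pd.concat(frames)
valid_t0_and_site_ids.index.name = "t0"
valid_t0_and_site_ids = valid_t0_and_site_ids.reset_index()

print(valid_t0_and_site_ids)
#                    t0  site_id
# 0 2023-06-01 00:00:00        1
# 1 2023-06-01 00:30:00        1
# 2 2023-06-01 01:00:00        2
# 3 2023-06-01 01:30:00        2
```

Each row is one trainable (t0, site_id) pair, which is what `__len__` and `__getitem__` index into.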
site.py (continued):

```python
    def get_locations(self, site_xr: xr.Dataset):
        """Get the list of locations of all sites."""
        locations = []
        for site_id in site_xr.site_id.values:
            site = site_xr.sel(site_id=site_id)
            location = Location(
                id=site_id,
                x=site.longitude.values,
                y=site.latitude.values,
                coordinate_system="lon_lat",
            )
            locations.append(location)

        return locations

    def process_and_combine_site_sample_dict(
        self,
        dataset_dict: dict,
    ) -> xr.Dataset:
        """Normalise and combine data into a single xr.Dataset.

        Args:
            dataset_dict: dict containing sliced xr.DataArrays

        Returns:
            xr.Dataset: A merged Dataset with NaNs filled in.
        """
        data_arrays = []

        if "nwp" in dataset_dict:
            for nwp_key, da_nwp in dataset_dict["nwp"].items():
                # Standardise
                provider = self.config.input_data.nwp[nwp_key].provider
                da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
                data_arrays.append((f"nwp-{provider}", da_nwp))

        if "sat" in dataset_dict:
            # TODO: add some satellite normalisation
            da_sat = dataset_dict["sat"]
            data_arrays.append(("satellite", da_sat))

        if "site" in dataset_dict:
            # Normalise site generation by capacity
            da_sites = dataset_dict["site"]
            da_sites = da_sites / da_sites.capacity_kwp
            data_arrays.append(("site", da_sites))

        combined_sample_dataset = self.merge_data_arrays(data_arrays)

        # TODO: add solar + time features for sites

        # Fill any NaN values
        return combined_sample_dataset.fillna(0.0)
```
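The NWP branch above standardises each channel with per-provider statistics. Below is a toy xarray illustration of the same arithmetic, with invented channel names and statistics (the real values live in ocf_data_sampler.constants.NWP_MEANS / NWP_STDS, keyed by provider):

```python
import numpy as np
import xarray as xr

# Invented per-channel statistics, standing in for NWP_MEANS[provider] / NWP_STDS[provider]
means = xr.DataArray([280.0, 0.5], dims="channel", coords={"channel": ["t2m", "tcc"]})
stds = xr.DataArray([15.0, 0.3], dims="channel", coords={"channel": ["t2m", "tcc"]})

# A toy (channel, step) forecast array
da_nwp = xr.DataArray(
    np.array([[275.0, 280.0, 285.0], [0.2, 0.5, 0.8]]),
    dims=("channel", "step"),
    coords={"channel": ["t2m", "tcc"], "step": [0, 1, 2]},
)

# Same operation as in process_and_combine_site_sample_dict: broadcasting
# subtracts the per-channel mean and divides by the per-channel std
da_norm = (da_nwp - means) / stds
print(da_norm.sel(channel="t2m").values)  # [-0.33333333  0.          0.33333333]
```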
site.py (continued):

```python
    def merge_data_arrays(
        self, normalised_data_arrays: list[Tuple[str, xr.DataArray]]
    ) -> xr.Dataset:
        """Combine a list of DataArrays into a single Dataset with unique naming conventions.

        Args:
            normalised_data_arrays: List of tuples where each tuple contains:
                - A string (key name).
                - An xarray.DataArray.

        Returns:
            xr.Dataset: A merged Dataset with uniquely named variables, coordinates,
                and dimensions.
        """
        datasets = []

        for key, data_array in normalised_data_arrays:
            # Ensure all attributes are strings for consistency
            data_array = data_array.assign_attrs(
                {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
            )

            # Convert DataArray to Dataset with the variable name as the key
            dataset = data_array.to_dataset(name=key)

            # Prepend the key name to all dimension and coordinate names for uniqueness
            dataset = dataset.rename(
                {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
            )
            dataset = dataset.rename(
                {coord: f"{key}__{coord}" for coord in dataset.coords}
            )

            # Handle the concatenation dimension if applicable
            concat_dim = (
                f"{key}__target_time_utc"
                if f"{key}__target_time_utc" in dataset.coords
                else f"{key}__time_utc"
            )

            if f"{key}__init_time_utc" in dataset.coords:
                init_coord = f"{key}__init_time_utc"
                if dataset[init_coord].ndim == 0:  # Check if scalar
                    expanded_init_times = [dataset[init_coord].values] * len(dataset[concat_dim])
                    dataset = dataset.assign_coords({init_coord: (concat_dim, expanded_init_times)})

            datasets.append(dataset)

        # Ensure all datasets are valid xarray.Dataset objects
        for ds in datasets:
            assert isinstance(ds, xr.Dataset), f"Object is not an xr.Dataset: {type(ds)}"

        # Merge all prepared datasets
        combined_dataset = xr.merge(datasets)

        return combined_dataset
```
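The `<key>__<name>` prefixing is what lets differently shaped inputs coexist in one Dataset. A small standalone demonstration of the renaming step, using a toy satellite-like array:

```python
import numpy as np
import pandas as pd
import xarray as xr

key = "satellite"

# A toy array with a time dimension and a channel dimension
da = xr.DataArray(
    np.zeros((3, 2)),
    dims=("time_utc", "channel"),
    coords={
        "time_utc": pd.date_range("2023-06-01", periods=3, freq="5min"),
        "channel": ["IR_016", "VIS008"],
    },
)

ds = da.to_dataset(name=key)

# Same two renames as merge_data_arrays: prefix dims without coords first,
# then all coords (renaming a dimension coordinate also renames the dimension)
ds = ds.rename({d: f"{key}__{d}" for d in ds.dims if d not in ds.coords})
ds = ds.rename({c: f"{key}__{c}" for c in ds.coords})

print(sorted(ds.dims))  # ['satellite__channel', 'satellite__time_utc']
```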
site.py (continued):

```python
# ----- Functions to load presaved samples -----

def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]:
    """Convert a combined sample dataset to a dict of datasets for each input.

    Args:
        combined_dataset: The combined NetCDF dataset

    Returns:
        The uncombined datasets as a dict of xr.Datasets
    """
    # Split into datasets by the "<key>__" prefix added in merge_data_arrays
    datasets = {}
    for key, dataset in combined_dataset.items():
        # Drop any coordinate that doesn't carry this key's "<key>__" prefix
        dataset_dims = list(dataset.coords)
        for dim in dataset_dims:
            if f"{key}__" not in dim:
                dataset: xr.Dataset = dataset.drop(dim)
        # Strip the prefix from the remaining dimensions and coordinates
        dataset = dataset.rename(
            {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords}
        )
        dataset: xr.Dataset = dataset.rename(
            {coord: coord.split(f"{key}__")[1] for coord in dataset.coords}
        )
        datasets[key] = dataset

    # Un-flatten any NWP data
    datasets = nest_nwp_source_dict(datasets, sep="-")
    return datasets


def nest_nwp_source_dict(d: dict, sep: str = "/") -> dict:
    """Re-nest a dictionary where the NWP values are stored under flat 'nwp<sep><key>' keys."""
    nwp_prefix = f"nwp{sep}"
    new_dict = {k: v for k, v in d.items() if not k.startswith(nwp_prefix)}
    nwp_keys = [k for k in d.keys() if k.startswith(nwp_prefix)]
    if len(nwp_keys) > 0:
        nwp_subdict = {k.removeprefix(nwp_prefix): d[k] for k in nwp_keys}
        new_dict["nwp"] = nwp_subdict
    return new_dict
```
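A quick illustration of the re-nesting; the values here are placeholder strings where the real code holds xarray objects:

```python
from ocf_data_sampler.torch_datasets.site import nest_nwp_source_dict

flat = {
    "site": "site_data",      # placeholder values; the real dict holds xr objects
    "nwp-ecmwf": "ecmwf_data",
    "nwp-ukv": "ukv_data",
}

nested = nest_nwp_source_dict(flat, sep="-")
print(nested)
# {'site': 'site_data', 'nwp': {'ecmwf': 'ecmwf_data', 'ukv': 'ukv_data'}}
```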
site.py (continued):

```python
def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
    """Convert a NetCDF dataset to a numpy sample."""
    # Convert the single dataset to a dict of arrays
    sample_dict = convert_from_dataset_to_dict_datasets(ds)

    if "satellite" in sample_dict:
        # Rename "satellite" to "sat"  # TODO: this could be improved
        sample_dict["sat"] = sample_dict.pop("satellite")

    # Process and combine the datasets
    sample = convert_to_numpy_and_combine(
        dataset_dict=sample_dict,
    )

    # TODO: think about normalisation; maybe it should happen after sample
    # creation rather than during it, to keep things flexible

    return sample


def convert_to_numpy_and_combine(
    dataset_dict: dict,
) -> dict:
    """Convert input data in a dict to numpy arrays."""
    numpy_modalities = []

    if "nwp" in dataset_dict:
        nwp_numpy_modalities = dict()
        for nwp_key, da_nwp in dataset_dict["nwp"].items():
            # Convert to NumpySample
            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)

        # Combine the NWPs into a NumpySample
        numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})

    if "sat" in dataset_dict:
        # Satellite is already in the range [0, 1] so no need to standardise
        da_sat = dataset_dict["sat"]

        # Convert to NumpySample
        numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))

    if "site" in dataset_dict:
        da_sites = dataset_dict["site"]

        # Convert to NumpySample
        numpy_modalities.append(convert_site_to_numpy_sample(da_sites))

    # Combine all the modalities and fill NaNs
    combined_sample = merge_dicts(numpy_modalities)
    combined_sample = fill_nans_in_arrays(combined_sample)

    return combined_sample
```
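Putting the pieces together, a hedged sketch of the intended presaved-sample round trip: build a merged sample, persist it to NetCDF, then restore it as a numpy sample. The config path and output filename are placeholders:

```python
import xarray as xr

from ocf_data_sampler.torch_datasets.site import (
    SitesDataset,
    convert_netcdf_to_numpy_sample,
)

dataset = SitesDataset("site_config.yaml")  # hypothetical config path

# One merged xr.Dataset sample, with prefixed dims/coords and NaNs filled
sample: xr.Dataset = dataset[0]

# Persist to disk, then restore later as numpy arrays for training
sample.to_netcdf("sample_0.nc")
reloaded = xr.open_dataset("sample_0.nc")
numpy_sample = convert_netcdf_to_numpy_sample(reloaded)
```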
ocf_data_sampler/torch_datasets/valid_time_periods.py (new file)
@@ -0,0 +1,116 @@

```python
import numpy as np
import pandas as pd

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.select.find_contiguous_time_periods import (
    find_contiguous_t0_periods_nwp,
    find_contiguous_t0_periods,
    intersection_of_multiple_dataframes_of_periods,
)
from ocf_data_sampler.utils import minutes


def find_valid_time_periods(
    datasets_dict: dict,
    config: Configuration,
):
    """Find the t0 times where all of the requested input data is available.

    Args:
        datasets_dict: A dictionary of input datasets
        config: Configuration file
    """
    assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})

    # Used to store contiguous time periods from each data source
    contiguous_time_periods: dict[str, pd.DataFrame] = {}

    if "nwp" in datasets_dict:
        for nwp_key, nwp_config in config.input_data.nwp.items():

            da = datasets_dict["nwp"][nwp_key]

            if nwp_config.dropout_timedeltas_minutes is None:
                max_dropout = minutes(0)
            else:
                max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes)))

            if nwp_config.max_staleness_minutes is None:
                max_staleness = None
            else:
                max_staleness = minutes(nwp_config.max_staleness_minutes)

            # The last step of the forecast is lost if we have to diff channels
            if len(nwp_config.accum_channels) > 0:
                end_buffer = minutes(nwp_config.time_resolution_minutes)
            else:
                end_buffer = minutes(0)

            # This is the max staleness we can use considering the max step of the input data
            max_possible_staleness = (
                pd.Timedelta(da["step"].max().item())
                - minutes(nwp_config.interval_end_minutes)
                - end_buffer
            )

            # Default to the max possible staleness unless specified in the config
            if max_staleness is None:
                max_staleness = max_possible_staleness
            else:
                # Make sure the max acceptable staleness isn't longer than the max possible
                assert max_staleness <= max_possible_staleness

            # Find the first forecast step
            first_forecast_step = pd.Timedelta(da["step"].min().item())

            time_periods = find_contiguous_t0_periods_nwp(
                init_times=pd.DatetimeIndex(da["init_time_utc"]),
                interval_start=minutes(nwp_config.interval_start_minutes),
                max_staleness=max_staleness,
                max_dropout=max_dropout,
                first_forecast_step=first_forecast_step,
            )

            contiguous_time_periods[f"nwp_{nwp_key}"] = time_periods
```
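The staleness arithmetic above is worth a worked example. Assuming `minutes` simply wraps pd.Timedelta (consistent with how it is used here), with an invented 48-hour forecast horizon and a 180-minute forecast interval:

```python
import pandas as pd

def minutes(m):
    # Stand-in for ocf_data_sampler.utils.minutes
    return pd.Timedelta(minutes=m)

max_step = pd.Timedelta("48h")   # furthest step in the NWP data (invented)
interval_end = minutes(180)      # sample needs forecasts out to t0 + 3 h
end_buffer = minutes(0)          # no accumulated channels being diffed

# t0 can be at most this far past the forecast init time
max_possible_staleness = max_step - interval_end - end_buffer
print(max_possible_staleness)    # 1 days 21:00:00, i.e. 45 hours
```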
valid_time_periods.py (continued):

```python
    if "sat" in datasets_dict:
        sat_config = config.input_data.satellite

        time_periods = find_contiguous_t0_periods(
            pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
            sample_period_duration=minutes(sat_config.time_resolution_minutes),
            interval_start=minutes(sat_config.interval_start_minutes),
            interval_end=minutes(sat_config.interval_end_minutes),
        )

        contiguous_time_periods["sat"] = time_periods

    if "gsp" in datasets_dict:
        gsp_config = config.input_data.gsp

        time_periods = find_contiguous_t0_periods(
            pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
            interval_start=minutes(gsp_config.interval_start_minutes),
            interval_end=minutes(gsp_config.interval_end_minutes),
        )

        contiguous_time_periods["gsp"] = time_periods

    # Just get the values (not the keys)
    contiguous_time_periods_values = list(contiguous_time_periods.values())

    # Find the jointly overlapping contiguous time periods
    if len(contiguous_time_periods_values) > 1:
        valid_time_periods = intersection_of_multiple_dataframes_of_periods(
            contiguous_time_periods_values
        )
    else:
        valid_time_periods = contiguous_time_periods_values[0]

    # Check there are some valid time periods
    if len(valid_time_periods) == 0:
        raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")

    return valid_time_periods
```
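Finally, the intersection step: each source contributes a table of contiguous periods, and only time ranges covered by every source survive. Below is a minimal hand-rolled sketch for two sources, assuming a (start_dt, end_dt) period layout; the library's intersection_of_multiple_dataframes_of_periods handles the general n-source case:

```python
import pandas as pd

# Invented period tables for two sources
nwp_periods = pd.DataFrame(
    {"start_dt": [pd.Timestamp("2023-06-01 00:00")],
     "end_dt": [pd.Timestamp("2023-06-01 12:00")]}
)
sat_periods = pd.DataFrame(
    {"start_dt": [pd.Timestamp("2023-06-01 06:00")],
     "end_dt": [pd.Timestamp("2023-06-01 18:00")]}
)

rows = []
for _, a in nwp_periods.iterrows():
    for _, b in sat_periods.iterrows():
        start = max(a.start_dt, b.start_dt)
        end = min(a.end_dt, b.end_dt)
        if start <= end:  # keep only genuinely overlapping spans
            rows.append({"start_dt": start, "end_dt": end})

overlap = pd.DataFrame(rows)
print(overlap)  # one period: 06:00 -> 12:00
```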