ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +73 -61
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/constants.py +140 -12
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +9 -10
- ocf_data_sampler/load/site.py +10 -6
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +12 -14
- ocf_data_sampler/numpy_sample/nwp.py +12 -12
- ocf_data_sampler/numpy_sample/satellite.py +9 -9
- ocf_data_sampler/numpy_sample/site.py +5 -8
- ocf_data_sampler/numpy_sample/sun_position.py +16 -21
- ocf_data_sampler/sample/base.py +15 -17
- ocf_data_sampler/sample/site.py +13 -20
- ocf_data_sampler/sample/uk_regional.py +29 -35
- ocf_data_sampler/select/dropout.py +16 -14
- ocf_data_sampler/select/fill_time_periods.py +15 -5
- ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
- ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +62 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -319
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -13
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -69
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -165
- tests/test_sample/test_uk_regional_sample.py +0 -136
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -154
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
ocf_data_sampler/torch_datasets/datasets/site.py

@@ -1,59 +1,62 @@
-"""Torch dataset for sites"""
+"""Torch dataset for sites."""
 
 import logging
+
 import numpy as np
 import pandas as pd
 import xarray as xr
-from typing import Tuple
-
 from torch.utils.data import Dataset
+from typing_extensions import override
 
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
+from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
+from ocf_data_sampler.numpy_sample import (
+    NWPSampleKey,
+    convert_nwp_to_numpy_sample,
+    convert_satellite_to_numpy_sample,
+    convert_site_to_numpy_sample,
+    make_datetime_numpy_dict,
+    make_sun_position_numpy_sample,
+)
 from ocf_data_sampler.select import (
     Location,
     fill_time_periods,
     find_contiguous_t0_periods,
     intersection_of_multiple_dataframes_of_periods,
-
+    slice_datasets_by_space,
+    slice_datasets_by_time,
 )
-from ocf_data_sampler.utils import
-
-
-
-from ocf_data_sampler.numpy_sample import (
-    convert_site_to_numpy_sample,
-    convert_satellite_to_numpy_sample,
-    convert_nwp_to_numpy_sample,
-    make_datetime_numpy_dict,
-    make_sun_position_numpy_sample,
+from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
+    fill_nans_in_arrays,
+    merge_dicts,
 )
-from ocf_data_sampler.
-from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
-
+from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
 from ocf_data_sampler.torch_datasets.utils.validate_channels import (
     validate_nwp_channels,
     validate_satellite_channels,
 )
+from ocf_data_sampler.utils import minutes
 
 xr.set_options(keep_attrs=True)
 
 
 class SitesDataset(Dataset):
+    """A torch Dataset for creating PVNet Site samples."""
+
     def __init__(
         self,
         config_filename: str,
         start_time: str | None = None,
         end_time: str | None = None,
-    ):
-        """A torch Dataset for creating PVNet Site samples
+    ) -> None:
+        """A torch Dataset for creating PVNet Site samples.
 
         Args:
             config_filename: Path to the configuration file
             start_time: Limit the init-times to be after this
             end_time: Limit the init-times to be before this
         """
-
        config: Configuration = load_yaml_configuration(config_filename)
        validate_nwp_channels(config)
        validate_satellite_channels(config)
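The reworked import block and typed constructor above settle the public surface of this dataset. A minimal usage sketch, assuming a hypothetical config path: SitesDataset is a standard torch Dataset, so it can be indexed directly or wrapped in a DataLoader.

from torch.utils.data import DataLoader

from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset

dataset = SitesDataset(
    config_filename="configs/site_example.yaml",  # hypothetical path
    start_time="2023-01-01",
    end_time="2023-12-31",
)
print(len(dataset))  # number of valid (t0, site_id) pairs

sample = dataset[0]  # an xr.Dataset, already computed and NaN-filled

# batch_size=None because each sample is an xr.Dataset rather than a tensor
loader = DataLoader(dataset, batch_size=None)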
@@ -65,28 +68,31 @@ class SitesDataset(Dataset):
         self.config = config
 
         # get all locations
-        self.locations = self.get_locations(datasets_dict[
+        self.locations = self.get_locations(datasets_dict["site"])
 
         # Get t0 times where all input data is available
         valid_t0_and_site_ids = self.find_valid_t0_and_site_ids(datasets_dict)
 
         # Filter t0 times to given range
         if start_time is not None:
-            valid_t0_and_site_ids
-
+            valid_t0_and_site_ids = valid_t0_and_site_ids[
+                valid_t0_and_site_ids["t0"] >= pd.Timestamp(start_time)
+            ]
 
         if end_time is not None:
-            valid_t0_and_site_ids
-
+            valid_t0_and_site_ids = valid_t0_and_site_ids[
+                valid_t0_and_site_ids["t0"] <= pd.Timestamp(end_time)
+            ]
 
         # Assign coords and indices to self
         self.valid_t0_and_site_ids = valid_t0_and_site_ids
 
-    def __len__(self):
+    @override
+    def __len__(self) -> int:
         return len(self.valid_t0_and_site_ids)
-
-    def __getitem__(self, idx):
 
+    @override
+    def __getitem__(self, idx: int) -> dict:
         # Get the coordinates of the sample
         t0, site_id = self.valid_t0_and_site_ids.iloc[idx]
 
@@ -97,7 +103,7 @@ class SitesDataset(Dataset):
         return self._get_sample(t0, location)
 
     def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
-        """Generate the PVNet sample for given coordinates
+        """Generate the PVNet sample for given coordinates.
 
         Args:
             t0: init-time for sample
@@ -106,7 +112,7 @@ class SitesDataset(Dataset):
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
 
-        sample = self.process_and_combine_site_sample_dict(sample_dict)
+        sample = self.process_and_combine_site_sample_dict(sample_dict, t0)
         sample = sample.compute()
         return sample
 
@@ -119,20 +125,20 @@ class SitesDataset(Dataset):
             t0: init-time for sample
             site_id: site id as int
         """
-
         location = self.get_location_from_site_id(site_id)
 
         return self._get_sample(t0, location)
-
-    def get_location_from_site_id(self, site_id):
-        """Get location from system id"""
 
+    def get_location_from_site_id(self, site_id: int) -> Location:
+        """Get location from system id."""
         locations = [loc for loc in self.locations if loc.id == site_id]
         if len(locations) == 0:
             raise ValueError(f"Location not found for site_id {site_id}")
 
         if len(locations) > 1:
-            logging.warning(
+            logging.warning(
+                f"Multiple locations found for site_id {site_id}, but will take the first",
+            )
 
         return locations[0]
 
@@ -140,7 +146,7 @@ class SitesDataset(Dataset):
         self,
         datasets_dict: dict,
     ) -> pd.DataFrame:
-        """Find the t0 times where all of the requested input data is available
+        """Find the t0 times where all of the requested input data is available.
 
         The idea is to
         1. Get valid time period for nwp and satellite
@@ -150,9 +156,8 @@ class SitesDataset(Dataset):
             datasets_dict: A dictionary of input datasets
             config: Configuration file
         """
-
         # 1. Get valid time period for nwp and satellite
-        datasets_without_site = {k:v for k, v in datasets_dict.items() if k!="site"}
+        datasets_without_site = {k: v for k, v in datasets_dict.items() if k != "site"}
         valid_time_periods = find_valid_time_periods(datasets_without_site, self.config)
 
         # 2. Now lets loop over each location in system id and find the valid periods
@@ -166,39 +171,37 @@ class SitesDataset(Dataset):
 
             # drop any nan values
             # not sure this is right?
-            site = site.dropna(dim=
+            site = site.dropna(dim="time_utc")
 
             # Get the valid time periods for this location
             time_periods = find_contiguous_t0_periods(
                 pd.DatetimeIndex(site["time_utc"]),
-
+                time_resolution=minutes(site_config.time_resolution_minutes),
                 interval_start=minutes(site_config.interval_start_minutes),
                 interval_end=minutes(site_config.interval_end_minutes),
             )
             valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods(
-                [valid_time_periods, time_periods]
+                [valid_time_periods, time_periods],
             )
 
             # Fill out the contiguous time periods to get the t0 times
             valid_t0_times_per_site = fill_time_periods(
                 valid_time_periods_per_site,
-                freq=minutes(site_config.time_resolution_minutes)
+                freq=minutes(site_config.time_resolution_minutes),
             )
 
             valid_t0_per_site = pd.DataFrame(index=valid_t0_times_per_site)
-            valid_t0_per_site[
+            valid_t0_per_site["site_id"] = site_id
             valid_t0_and_site_ids.append(valid_t0_per_site)
 
         valid_t0_and_site_ids = pd.concat(valid_t0_and_site_ids)
-        valid_t0_and_site_ids.index.name =
+        valid_t0_and_site_ids.index.name = "t0"
         valid_t0_and_site_ids.reset_index(inplace=True)
 
         return valid_t0_and_site_ids
 
-
-
-        """Get list of locations of all sites"""
-
+    def get_locations(self, site_xr: xr.Dataset) -> list[Location]:
+        """Get list of locations of all sites."""
         locations = []
         for site_id in site_xr.site_id.values:
             site = site_xr.sel(site_id=site_id)
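To make the period-filling step above concrete, here is an illustrative toy version using plain pandas. The real helpers (find_contiguous_t0_periods, fill_time_periods) live in ocf_data_sampler.select and handle more edge cases; the start_dt/end_dt column layout is an assumption made for this sketch.

import pandas as pd

# One contiguous valid period (assumed start_dt/end_dt layout)
periods = pd.DataFrame(
    {"start_dt": [pd.Timestamp("2023-01-01 00:00")],
     "end_dt": [pd.Timestamp("2023-01-01 02:00")]},
)

# "Filling" the period at a 30-minute resolution yields the candidate t0s
t0s = pd.date_range(periods["start_dt"][0], periods["end_dt"][0], freq="30min")

valid_t0_per_site = pd.DataFrame(index=t0s)
valid_t0_per_site["site_id"] = 1  # as in the loop above
valid_t0_per_site.index.name = "t0"
print(valid_t0_per_site.reset_index())  # 5 rows: 00:00, 00:30, ..., 02:00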
@@ -206,7 +209,7 @@ class SitesDataset(Dataset):
                 id=site_id,
                 x=site.longitude.values,
                 y=site.latitude.values,
-                coordinate_system="lon_lat"
+                coordinate_system="lon_lat",
             )
             locations.append(location)
 
@@ -215,29 +218,29 @@ class SitesDataset(Dataset):
     def process_and_combine_site_sample_dict(
         self,
         dataset_dict: dict,
+        t0: pd.Timestamp,
     ) -> xr.Dataset:
-        """
-        Normalize and combine data into a single xr Dataset
+        """Normalize and combine data into a single xr Dataset.
 
         Args:
             dataset_dict: dict containing sliced xr DataArrays
             config: Configuration for the model
+            t0: The initial timestamp of the sample
 
         Returns:
             xr.Dataset: A merged Dataset with nans filled in.
-
-        """
 
+        """
         data_arrays = []
 
         if "nwp" in dataset_dict:
             for nwp_key, da_nwp in dataset_dict["nwp"].items():
                 provider = self.config.input_data.nwp[nwp_key].provider
-
+
                 # Standardise
                 da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
                 data_arrays.append((f"nwp-{provider}", da_nwp))
-
+
         if "sat" in dataset_dict:
             da_sat = dataset_dict["sat"]
 
@@ -257,33 +260,57 @@ class SitesDataset(Dataset):
         datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
         datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site_")
         combined_sample_dataset = combined_sample_dataset.assign_coords(
-            {k: ("site__time_utc", v) for k, v in datetime_features.items()}
+            {k: ("site__time_utc", v) for k, v in datetime_features.items()},
         )
 
-        # add
-
-
-
-            lat=combined_sample_dataset.site__latitude.values,
-            key_prefix="site_",
-        )
-        combined_sample_dataset = combined_sample_dataset.assign_coords(
-            {k: ("site__time_utc", v) for k, v in sun_position_features.items()}
+        # Only add solar position if explicitly configured
+        has_solar_config = (
+            hasattr(self.config.input_data, "solar_position") and
+            self.config.input_data.solar_position is not None
         )
 
-
+        if has_solar_config:
+            solar_config = self.config.input_data.solar_position
+
+            # Datetime range - solar config params
+            solar_datetimes = pd.date_range(
+                t0 + minutes(solar_config.interval_start_minutes),
+                t0 + minutes(solar_config.interval_end_minutes),
+                freq=minutes(solar_config.time_resolution_minutes),
+            )
+
+            # Calculate sun position features
+            sun_position_features = make_sun_position_numpy_sample(
+                datetimes=solar_datetimes,
+                lon=combined_sample_dataset.site__longitude.values,
+                lat=combined_sample_dataset.site__latitude.values,
+            )
+
+            # Dimension state for solar position data
+            solar_dim_name = "solar_time_utc"
+            combined_sample_dataset = combined_sample_dataset.assign_coords(
+                {solar_dim_name: solar_datetimes},
+            )
+
+            # Assign solar position values
+            for key, values in sun_position_features.items():
+                combined_sample_dataset = combined_sample_dataset.assign_coords(
+                    {key: (solar_dim_name, values)},
+                )
+
+        # TODO include t0_index in xr dataset?
 
         # Fill any nan values
         return combined_sample_dataset.fillna(0.0)
 
     def merge_data_arrays(
-        self,
+        self,
+        normalised_data_arrays: list[tuple[str, xr.DataArray]],
     ) -> xr.Dataset:
-        """
-        Combine a list of DataArrays into a single Dataset with unique naming conventions.
+        """Combine a list of DataArrays into a single Dataset with unique naming conventions.
 
         Args:
-
+            normalised_data_arrays: List of tuples where each tuple contains:
                 - A string (key name).
                 - An xarray.DataArray.
 
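The solar-position hunk above builds its own datetime axis from the solar_position config. A small sketch of that arithmetic with made-up interval values; minutes appears to simply wrap pd.Timedelta, so a local stand-in is used here.

import pandas as pd

def minutes(m: int) -> pd.Timedelta:  # stand-in for ocf_data_sampler.utils.minutes
    return pd.Timedelta(minutes=m)

t0 = pd.Timestamp("2023-06-01 12:00")
interval_start_minutes, interval_end_minutes, time_resolution_minutes = -30, 60, 30

solar_datetimes = pd.date_range(
    t0 + minutes(interval_start_minutes),
    t0 + minutes(interval_end_minutes),
    freq=minutes(time_resolution_minutes),
)
print(solar_datetimes)  # 11:30, 12:00, 12:30, 13:00 -> one azimuth/elevation per timestamp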
@@ -295,7 +322,7 @@ class SitesDataset(Dataset):
         for key, data_array in normalised_data_arrays:
             # Ensure all attributes are strings for consistency
             data_array = data_array.assign_attrs(
-                {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
+                {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()},
             )
 
             # Convert DataArray to Dataset with the variable name as the key
@@ -303,15 +330,16 @@ class SitesDataset(Dataset):
 
             # Prepend key name to all dimension and coordinate names for uniqueness
             dataset = dataset.rename(
-                {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
+                {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords},
             )
             dataset = dataset.rename(
-                {coord: f"{key}__{coord}" for coord in dataset.coords}
+                {coord: f"{key}__{coord}" for coord in dataset.coords},
             )
 
             # Handle concatenation dimension if applicable
             concat_dim = (
-                f"{key}__target_time_utc"
+                f"{key}__target_time_utc"
+                if f"{key}__target_time_utc" in dataset.coords
                 else f"{key}__time_utc"
             )
 
@@ -325,20 +353,22 @@ class SitesDataset(Dataset):
 
         # Ensure all datasets are valid xarray.Dataset objects
         for ds in datasets:
-
+            if not isinstance(ds, xr.Dataset):
+                raise ValueError(f"Object is not an xr.Dataset: {type(ds)}")
 
         # Merge all prepared datasets
         combined_dataset = xr.merge(datasets)
 
         return combined_dataset
 
+
 # ----- functions to load presaved samples ------
 
-def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
-    """Convert a netcdf dataset to a numpy sample"""
 
+def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
+    """Convert a netcdf dataset to a numpy sample."""
     # convert the single dataset to a dict of arrays
-    sample_dict = convert_from_dataset_to_dict_datasets(ds)
+    sample_dict = convert_from_dataset_to_dict_datasets(ds)
 
     if "satellite" in sample_dict:
         # rename satellite to satellite actual # TODO this could be improves
@@ -349,14 +379,21 @@ def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
         dataset_dict=sample_dict,
     )
 
-    #
-
+    # Extraction of solar position coords
+    solar_keys = ["solar_azimuth", "solar_elevation"]
+    for key in solar_keys:
+        if key in ds.coords:
+            sample[key] = ds.coords[key].values
+
+    # TODO think about normalization:
+    # * maybe its done not in sample creation, maybe its done afterwards,
+    # to allow it to be flexible
 
     return sample
 
+
 def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]:
-    """
-    Convert a combined sample dataset to a dict of datasets for each input
+    """Convert a combined sample dataset to a dict of datasets for each input.
 
     Args:
         combined_dataset: The combined NetCDF dataset
@@ -374,10 +411,10 @@ def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]:
             if f"{key}__" not in dim:
                 dataset: xr.Dataset = dataset.drop(dim)
         dataset = dataset.rename(
-            {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords}
+            {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords},
         )
         dataset: xr.Dataset = dataset.rename(
-            {coord: coord.split(f"{key}__")[1] for coord in dataset.coords}
+            {coord: coord.split(f"{key}__")[1] for coord in dataset.coords},
         )
         # Split the dataset by the prefix
         datasets[key] = dataset
@@ -391,22 +428,21 @@ def nest_nwp_source_dict(d: dict, sep: str = "/") -> dict:
     """Re-nest a dictionary where the NWP values are nested under keys 'nwp/<key>'."""
     nwp_prefix = f"nwp{sep}"
     new_dict = {k: v for k, v in d.items() if not k.startswith(nwp_prefix)}
-    nwp_keys = [k for k in d
+    nwp_keys = [k for k in d if k.startswith(nwp_prefix)]
     if len(nwp_keys) > 0:
         nwp_subdict = {k.removeprefix(nwp_prefix): d[k] for k in nwp_keys}
         new_dict["nwp"] = nwp_subdict
     return new_dict
 
+
 def convert_to_numpy_and_combine(
     dataset_dict: dict,
 ) -> dict:
-    """Convert input data in a dict to numpy arrays"""
-
+    """Convert input data in a dict to numpy arrays."""
     numpy_modalities = []
 
     if "nwp" in dataset_dict:
-
-        nwp_numpy_modalities = dict()
+        nwp_numpy_modalities = {}
         for nwp_key, da_nwp in dataset_dict["nwp"].items():
             # Convert to NumpySample
             nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
@@ -427,7 +463,7 @@ def convert_to_numpy_and_combine(
         numpy_modalities.append(
             convert_site_to_numpy_sample(
                 da_sites,
-            )
+            ),
        )
 
     # Combine all the modalities and fill NaNs
@@ -437,25 +473,23 @@ def convert_to_numpy_and_combine(
     return combined_sample
 
 
-def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float=0.1):
-    """
-
-
+def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float = 0.1) -> xr.Dataset:
+    """Coarsen the data to a specified resolution in degrees.
+
     Args:
         xr_data: xarray dataset to coarsen
         coarsen_to_deg: resolution to coarsen to in degrees
     """
-
     if "latitude" in xr_data.coords and "longitude" in xr_data.coords:
-        step = np.abs(xr_data.latitude.values[1]-xr_data.latitude.values[0])
-        step = np.round(step,4)
-        coarsen_factor = int(coarsen_to_deg/step)
+        step = np.abs(xr_data.latitude.values[1] - xr_data.latitude.values[0])
+        step = np.round(step, 4)
+        coarsen_factor = int(coarsen_to_deg / step)
         if coarsen_factor > 1:
             xr_data = xr_data.coarsen(
                 latitude=coarsen_factor,
                 longitude=coarsen_factor,
                 boundary="pad",
-                coord_func="min"
+                coord_func="min",
             ).mean()
-
-    return xr_data
+
+    return xr_data
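To close out coarsen_data: a sketch of the coarsening arithmetic on synthetic data. A 0.05-degree grid coarsened to the default 0.1 degrees gives coarsen_factor == 2, so 2x2 blocks of cells are averaged.

import numpy as np
import xarray as xr

lat = np.arange(50.0, 51.0, 0.05)  # 20 points at 0.05-degree spacing
lon = np.arange(-1.0, 0.0, 0.05)
ds = xr.Dataset(
    {"t2m": (("latitude", "longitude"), np.random.rand(lat.size, lon.size))},
    coords={"latitude": lat, "longitude": lon},
)

step = np.round(np.abs(ds.latitude.values[1] - ds.latitude.values[0]), 4)  # 0.05
coarsen_factor = int(0.1 / step)  # 2
coarse = ds.coarsen(
    latitude=coarsen_factor,
    longitude=coarsen_factor,
    boundary="pad",
    coord_func="min",
).mean()
print(dict(ds.sizes), "->", dict(coarse.sizes))  # 20x20 -> 10x10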
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py

@@ -1,13 +1,17 @@
+"""Utility functions for merging dictionaries and filling NaNs in arrays."""
+
 import numpy as np
 
+
 def merge_dicts(list_of_dicts: list[dict]) -> dict:
-    """Merge a list of dictionaries into a single dictionary"""
+    """Merge a list of dictionaries into a single dictionary."""
     # TODO: This doesn't account for duplicate keys, which will be overwritten
     combined_dict = {}
     for d in list_of_dicts:
         combined_dict.update(d)
     return combined_dict
 
+
 def fill_nans_in_arrays(sample: dict) -> dict:
     """Fills all NaN values in each np.ndarray in the sample dictionary with zeros.
 
@@ -22,4 +26,4 @@ def fill_nans_in_arrays(sample: dict) -> dict:
         elif isinstance(v, dict):
             fill_nans_in_arrays(v)
 
-    return sample
+    return sample
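A short usage sketch for the two helpers above: merge_dicts is a shallow merge (later keys win, as the TODO notes), and fill_nans_in_arrays zero-fills NaNs recursively, including in nested dicts.

import numpy as np

from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
    fill_nans_in_arrays,
    merge_dicts,
)

sample = merge_dicts([
    {"a": np.array([1.0, np.nan])},
    {"b": {"c": np.array([np.nan])}},
])
sample = fill_nans_in_arrays(sample)
print(sample["a"], sample["b"]["c"])  # [1. 0.] [0.]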
ocf_data_sampler/torch_datasets/utils/valid_time_periods.py

@@ -1,34 +1,31 @@
+"""Functions pertaining to finding valid time periods for the input data."""
+
 import numpy as np
 import pandas as pd
 
 from ocf_data_sampler.config import Configuration
 from ocf_data_sampler.select.find_contiguous_time_periods import (
+    find_contiguous_t0_periods,
     find_contiguous_t0_periods_nwp,
-    find_contiguous_t0_periods,
     intersection_of_multiple_dataframes_of_periods,
 )
 from ocf_data_sampler.utils import minutes
 
 
-def find_valid_time_periods(
-
-    datasets_dict: dict,
-    config: Configuration,
-):
-    """Find the t0 times where all of the requested input data is available
+def find_valid_time_periods(datasets_dict: dict, config: Configuration) -> pd.DataFrame:
+    """Find the t0 times where all of the requested input data is available.
 
     Args:
         datasets_dict: A dictionary of input datasets
         config: Configuration file
     """
+    if not set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"}):
+        raise ValueError(f"Invalid keys in datasets_dict: {datasets_dict.keys()}")
 
-
-
-    contiguous_time_periods: dict[str: pd.DataFrame] = {}  # Used to store contiguous time periods from each data source
-
+    # Used to store contiguous time periods from each data source
+    contiguous_time_periods: dict[str : pd.DataFrame] = {}
     if "nwp" in datasets_dict:
         for nwp_key, nwp_config in config.input_data.nwp.items():
-
             da = datasets_dict["nwp"][nwp_key]
 
             if nwp_config.dropout_timedeltas_minutes is None:

@@ -59,8 +56,12 @@ def find_valid_time_periods(datasets_dict: dict, config: Configuration) -> pd.DataFrame:
                 max_staleness = max_possible_staleness
             else:
                 # Make sure the max acceptable staleness isn't longer than the max possible
-
-
+                if max_staleness > max_possible_staleness:
+                    raise ValueError(
+                        f"max_staleness_minutes is too long for the input data, "
+                        f"{max_staleness=}, {max_possible_staleness=}",
+                    )
+
             # Find the first forecast step
             first_forecast_step = pd.Timedelta(da["step"].min().item())
 
@@ -69,34 +70,34 @@ def find_valid_time_periods(datasets_dict: dict, config: Configuration) -> pd.DataFrame:
                 interval_start=minutes(nwp_config.interval_start_minutes),
                 max_staleness=max_staleness,
                 max_dropout=max_dropout,
-                first_forecast_step
+                first_forecast_step=first_forecast_step,
             )
 
-            contiguous_time_periods[f
+            contiguous_time_periods[f"nwp_{nwp_key}"] = time_periods
 
     if "sat" in datasets_dict:
         sat_config = config.input_data.satellite
 
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
-
+            time_resolution=minutes(sat_config.time_resolution_minutes),
             interval_start=minutes(sat_config.interval_start_minutes),
             interval_end=minutes(sat_config.interval_end_minutes),
         )
 
-        contiguous_time_periods[
+        contiguous_time_periods["sat"] = time_periods
 
     if "gsp" in datasets_dict:
         gsp_config = config.input_data.gsp
 
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
-
+            time_resolution=minutes(gsp_config.time_resolution_minutes),
             interval_start=minutes(gsp_config.interval_start_minutes),
             interval_end=minutes(gsp_config.interval_end_minutes),
        )
 
-        contiguous_time_periods[
+        contiguous_time_periods["gsp"] = time_periods
 
     # just get the values (not the keys)
     contiguous_time_periods_values = list(contiguous_time_periods.values())

@@ -104,7 +105,7 @@ def find_valid_time_periods(datasets_dict: dict, config: Configuration) -> pd.DataFrame:
     # Find joint overlapping contiguous time periods
     if len(contiguous_time_periods_values) > 1:
         valid_time_periods = intersection_of_multiple_dataframes_of_periods(
-            contiguous_time_periods_values
+            contiguous_time_periods_values,
        )
     else:
         valid_time_periods = contiguous_time_periods_values[0]

@@ -113,4 +114,4 @@ def find_valid_time_periods(datasets_dict: dict, config: Configuration) -> pd.DataFrame:
     if len(valid_time_periods) == 0:
         raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")
 
-    return valid_time_periods
+    return valid_time_periods
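One behavioural consequence of the new key guard is that it runs before any other work, so a bad input dict now fails fast. A sketch; since the guard raises immediately, the config argument is never reached and can be anything here.

from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods

# A hypothetical "pv" input is rejected up front rather than being silently ignored
find_valid_time_periods({"pv": None}, config=None)
# ValueError: Invalid keys in datasets_dict: dict_keys(['pv'])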