ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release. This version of ocf-data-sampler might be problematic.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +146 -64
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +9 -10
- ocf_data_sampler/load/site.py +10 -6
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +12 -14
- ocf_data_sampler/numpy_sample/nwp.py +12 -12
- ocf_data_sampler/numpy_sample/satellite.py +9 -9
- ocf_data_sampler/numpy_sample/site.py +5 -8
- ocf_data_sampler/numpy_sample/sun_position.py +16 -21
- ocf_data_sampler/sample/base.py +15 -17
- ocf_data_sampler/sample/site.py +13 -20
- ocf_data_sampler/sample/uk_regional.py +29 -35
- ocf_data_sampler/select/dropout.py +16 -14
- ocf_data_sampler/select/fill_time_periods.py +15 -5
- ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +140 -131
- ocf_data_sampler/torch_datasets/datasets/site.py +152 -112
- ocf_data_sampler/torch_datasets/utils/__init__.py +3 -0
- ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.17.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +63 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler/constants.py +0 -222
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -82
- ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -319
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -13
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -69
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -165
- tests/test_sample/test_uk_regional_sample.py +0 -136
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -154
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
ocf_data_sampler/torch_datasets/datasets/site.py +152 -112

@@ -1,63 +1,58 @@
-"""Torch dataset for sites"""
+"""Torch dataset for sites."""
 
 import logging
+
 import numpy as np
 import pandas as pd
 import xarray as xr
-from typing import Tuple
-
 from torch.utils.data import Dataset
+from typing_extensions import override
 
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
+from ocf_data_sampler.numpy_sample import (
+    NWPSampleKey,
+    convert_nwp_to_numpy_sample,
+    convert_satellite_to_numpy_sample,
+    convert_site_to_numpy_sample,
+    make_datetime_numpy_dict,
+    make_sun_position_numpy_sample,
+)
 from ocf_data_sampler.select import (
     Location,
     fill_time_periods,
     find_contiguous_t0_periods,
     intersection_of_multiple_dataframes_of_periods,
-
-
-from ocf_data_sampler.utils import minutes
-from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
-from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts, fill_nans_in_arrays
-
-from ocf_data_sampler.numpy_sample import (
-    convert_site_to_numpy_sample,
-    convert_satellite_to_numpy_sample,
-    convert_nwp_to_numpy_sample,
-    make_datetime_numpy_dict,
-    make_sun_position_numpy_sample,
+    slice_datasets_by_space,
+    slice_datasets_by_time,
 )
-from ocf_data_sampler.
-from ocf_data_sampler.
-
-
-    validate_nwp_channels,
-    validate_satellite_channels,
+from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray, find_valid_time_periods
+from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
+    fill_nans_in_arrays,
+    merge_dicts,
 )
+from ocf_data_sampler.utils import minutes
 
 xr.set_options(keep_attrs=True)
 
 
 class SitesDataset(Dataset):
+    """A torch Dataset for creating PVNet Site samples."""
+
     def __init__(
         self,
         config_filename: str,
         start_time: str | None = None,
         end_time: str | None = None,
-    ):
-        """A torch Dataset for creating PVNet Site samples
+    ) -> None:
+        """A torch Dataset for creating PVNet Site samples.
 
         Args:
             config_filename: Path to the configuration file
             start_time: Limit the init-times to be after this
             end_time: Limit the init-times to be before this
         """
-
         config: Configuration = load_yaml_configuration(config_filename)
-        validate_nwp_channels(config)
-        validate_satellite_channels(config)
-
         datasets_dict = get_dataset_dict(config.input_data)
 
         # Assign config and input data to self

@@ -65,28 +60,31 @@ class SitesDataset(Dataset):
         self.config = config
 
         # get all locations
-        self.locations = self.get_locations(datasets_dict[
+        self.locations = self.get_locations(datasets_dict["site"])
 
         # Get t0 times where all input data is available
         valid_t0_and_site_ids = self.find_valid_t0_and_site_ids(datasets_dict)
 
         # Filter t0 times to given range
         if start_time is not None:
-            valid_t0_and_site_ids
-
+            valid_t0_and_site_ids = valid_t0_and_site_ids[
+                valid_t0_and_site_ids["t0"] >= pd.Timestamp(start_time)
+            ]
 
         if end_time is not None:
-            valid_t0_and_site_ids
-
+            valid_t0_and_site_ids = valid_t0_and_site_ids[
+                valid_t0_and_site_ids["t0"] <= pd.Timestamp(end_time)
+            ]
 
         # Assign coords and indices to self
         self.valid_t0_and_site_ids = valid_t0_and_site_ids
 
-
+    @override
+    def __len__(self) -> int:
         return len(self.valid_t0_and_site_ids)
-
-    def __getitem__(self, idx):
 
+    @override
+    def __getitem__(self, idx: int) -> dict:
         # Get the coordinates of the sample
         t0, site_id = self.valid_t0_and_site_ids.iloc[idx]
 

@@ -97,7 +95,7 @@ class SitesDataset(Dataset):
         return self._get_sample(t0, location)
 
     def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
-        """Generate the PVNet sample for given coordinates
+        """Generate the PVNet sample for given coordinates.
 
         Args:
             t0: init-time for sample

@@ -106,7 +104,7 @@ class SitesDataset(Dataset):
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
 
-        sample = self.process_and_combine_site_sample_dict(sample_dict)
+        sample = self.process_and_combine_site_sample_dict(sample_dict, t0)
         sample = sample.compute()
         return sample
 

@@ -119,20 +117,20 @@ class SitesDataset(Dataset):
             t0: init-time for sample
             site_id: site id as int
         """
-
         location = self.get_location_from_site_id(site_id)
 
         return self._get_sample(t0, location)
-
-    def get_location_from_site_id(self, site_id):
-        """Get location from system id"""
 
+    def get_location_from_site_id(self, site_id: int) -> Location:
+        """Get location from system id."""
         locations = [loc for loc in self.locations if loc.id == site_id]
         if len(locations) == 0:
             raise ValueError(f"Location not found for site_id {site_id}")
 
         if len(locations) > 1:
-            logging.warning(
+            logging.warning(
+                f"Multiple locations found for site_id {site_id}, but will take the first",
+            )
 
         return locations[0]
 

@@ -140,7 +138,7 @@ class SitesDataset(Dataset):
         self,
         datasets_dict: dict,
     ) -> pd.DataFrame:
-        """Find the t0 times where all of the requested input data is available
+        """Find the t0 times where all of the requested input data is available.
 
         The idea is to
         1. Get valid time period for nwp and satellite

@@ -150,9 +148,8 @@ class SitesDataset(Dataset):
             datasets_dict: A dictionary of input datasets
             config: Configuration file
         """
-
         # 1. Get valid time period for nwp and satellite
-        datasets_without_site = {k:v for k, v in datasets_dict.items() if k!="site"}
+        datasets_without_site = {k: v for k, v in datasets_dict.items() if k != "site"}
         valid_time_periods = find_valid_time_periods(datasets_without_site, self.config)
 
         # 2. Now lets loop over each location in system id and find the valid periods

@@ -166,39 +163,37 @@ class SitesDataset(Dataset):
 
             # drop any nan values
             # not sure this is right?
-            site = site.dropna(dim=
+            site = site.dropna(dim="time_utc")
 
             # Get the valid time periods for this location
             time_periods = find_contiguous_t0_periods(
                 pd.DatetimeIndex(site["time_utc"]),
-
+                time_resolution=minutes(site_config.time_resolution_minutes),
                 interval_start=minutes(site_config.interval_start_minutes),
                 interval_end=minutes(site_config.interval_end_minutes),
             )
             valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods(
-                [valid_time_periods, time_periods]
+                [valid_time_periods, time_periods],
             )
 
             # Fill out the contiguous time periods to get the t0 times
             valid_t0_times_per_site = fill_time_periods(
                 valid_time_periods_per_site,
-                freq=minutes(site_config.time_resolution_minutes)
+                freq=minutes(site_config.time_resolution_minutes),
             )
 
             valid_t0_per_site = pd.DataFrame(index=valid_t0_times_per_site)
-            valid_t0_per_site[
+            valid_t0_per_site["site_id"] = site_id
             valid_t0_and_site_ids.append(valid_t0_per_site)
 
         valid_t0_and_site_ids = pd.concat(valid_t0_and_site_ids)
-        valid_t0_and_site_ids.index.name =
+        valid_t0_and_site_ids.index.name = "t0"
         valid_t0_and_site_ids.reset_index(inplace=True)
 
         return valid_t0_and_site_ids
 
-
-
-        """Get list of locations of all sites"""
-
+    def get_locations(self, site_xr: xr.Dataset) -> list[Location]:
+        """Get list of locations of all sites."""
         locations = []
         for site_id in site_xr.site_id.values:
             site = site_xr.sel(site_id=site_id)

@@ -206,7 +201,7 @@ class SitesDataset(Dataset):
                 id=site_id,
                 x=site.longitude.values,
                 y=site.latitude.values,
-                coordinate_system="lon_lat"
+                coordinate_system="lon_lat",
             )
             locations.append(location)
 

@@ -215,34 +210,48 @@ class SitesDataset(Dataset):
     def process_and_combine_site_sample_dict(
         self,
         dataset_dict: dict,
+        t0: pd.Timestamp,
     ) -> xr.Dataset:
-        """
-        Normalize and combine data into a single xr Dataset
+        """Normalize and combine data into a single xr Dataset.
 
         Args:
             dataset_dict: dict containing sliced xr DataArrays
-
+            t0: The initial timestamp of the sample
 
         Returns:
             xr.Dataset: A merged Dataset with nans filled in.
-
-        """
 
+        """
         data_arrays = []
 
         if "nwp" in dataset_dict:
             for nwp_key, da_nwp in dataset_dict["nwp"].items():
                 provider = self.config.input_data.nwp[nwp_key].provider
-
+
                 # Standardise
-
+                da_channel_means = channel_dict_to_dataarray(
+                    self.config.input_data.nwp[nwp_key].channel_means,
+                )
+                da_channel_stds = channel_dict_to_dataarray(
+                    self.config.input_data.nwp[nwp_key].channel_stds,
+                )
+
+                da_nwp = (da_nwp - da_channel_means) / da_channel_stds
+
                 data_arrays.append((f"nwp-{provider}", da_nwp))
-
+
         if "sat" in dataset_dict:
             da_sat = dataset_dict["sat"]
 
-
-
+            da_channel_means = channel_dict_to_dataarray(
+                self.config.input_data.satellite.channel_means,
+            )
+            da_channel_stds = channel_dict_to_dataarray(
+                self.config.input_data.satellite.channel_stds,
+            )
+
+            da_sat = (da_sat - da_channel_means) / da_channel_stds
+
             data_arrays.append(("satellite", da_sat))
 
         if "site" in dataset_dict:

@@ -257,33 +266,57 @@ class SitesDataset(Dataset):
         datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
         datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site_")
         combined_sample_dataset = combined_sample_dataset.assign_coords(
-            {k: ("site__time_utc", v) for k, v in datetime_features.items()}
+            {k: ("site__time_utc", v) for k, v in datetime_features.items()},
         )
 
-        # add
-
-
-
-            lat=combined_sample_dataset.site__latitude.values,
-            key_prefix="site_",
-        )
-        combined_sample_dataset = combined_sample_dataset.assign_coords(
-            {k: ("site__time_utc", v) for k, v in sun_position_features.items()}
+        # Only add solar position if explicitly configured
+        has_solar_config = (
+            hasattr(self.config.input_data, "solar_position") and
+            self.config.input_data.solar_position is not None
         )
 
-
+        if has_solar_config:
+            solar_config = self.config.input_data.solar_position
+
+            # Datetime range - solar config params
+            solar_datetimes = pd.date_range(
+                t0 + minutes(solar_config.interval_start_minutes),
+                t0 + minutes(solar_config.interval_end_minutes),
+                freq=minutes(solar_config.time_resolution_minutes),
+            )
+
+            # Calculate sun position features
+            sun_position_features = make_sun_position_numpy_sample(
+                datetimes=solar_datetimes,
+                lon=combined_sample_dataset.site__longitude.values,
+                lat=combined_sample_dataset.site__latitude.values,
+            )
+
+            # Dimension state for solar position data
+            solar_dim_name = "solar_time_utc"
+            combined_sample_dataset = combined_sample_dataset.assign_coords(
+                {solar_dim_name: solar_datetimes},
+            )
+
+            # Assign solar position values
+            for key, values in sun_position_features.items():
+                combined_sample_dataset = combined_sample_dataset.assign_coords(
+                    {key: (solar_dim_name, values)},
+                )
+
+        # TODO include t0_index in xr dataset?
 
         # Fill any nan values
         return combined_sample_dataset.fillna(0.0)
 
     def merge_data_arrays(
-        self,
+        self,
+        normalised_data_arrays: list[tuple[str, xr.DataArray]],
     ) -> xr.Dataset:
-        """
-        Combine a list of DataArrays into a single Dataset with unique naming conventions.
+        """Combine a list of DataArrays into a single Dataset with unique naming conventions.
 
         Args:
-
+            normalised_data_arrays: List of tuples where each tuple contains:
                 - A string (key name).
                 - An xarray.DataArray.
 

@@ -295,7 +328,7 @@ class SitesDataset(Dataset):
         for key, data_array in normalised_data_arrays:
             # Ensure all attributes are strings for consistency
             data_array = data_array.assign_attrs(
-                {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
+                {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()},
             )
 
             # Convert DataArray to Dataset with the variable name as the key

@@ -303,15 +336,16 @@ class SitesDataset(Dataset):
 
             # Prepend key name to all dimension and coordinate names for uniqueness
             dataset = dataset.rename(
-                {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
+                {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords},
             )
             dataset = dataset.rename(
-                {coord: f"{key}__{coord}" for coord in dataset.coords}
+                {coord: f"{key}__{coord}" for coord in dataset.coords},
             )
 
             # Handle concatenation dimension if applicable
             concat_dim = (
-                f"{key}__target_time_utc"
+                f"{key}__target_time_utc"
+                if f"{key}__target_time_utc" in dataset.coords
                 else f"{key}__time_utc"
             )
 

@@ -325,20 +359,22 @@ class SitesDataset(Dataset):
 
         # Ensure all datasets are valid xarray.Dataset objects
         for ds in datasets:
-
+            if not isinstance(ds, xr.Dataset):
+                raise ValueError(f"Object is not an xr.Dataset: {type(ds)}")
 
         # Merge all prepared datasets
         combined_dataset = xr.merge(datasets)
 
         return combined_dataset
 
+
 # ----- functions to load presaved samples ------
 
-def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
-    """Convert a netcdf dataset to a numpy sample"""
 
+def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
+    """Convert a netcdf dataset to a numpy sample."""
     # convert the single dataset to a dict of arrays
-    sample_dict = convert_from_dataset_to_dict_datasets(ds)
+    sample_dict = convert_from_dataset_to_dict_datasets(ds)
 
     if "satellite" in sample_dict:
         # rename satellite to satellite actual # TODO this could be improves

@@ -349,14 +385,21 @@ def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
         dataset_dict=sample_dict,
     )
 
-    #
-
+    # Extraction of solar position coords
+    solar_keys = ["solar_azimuth", "solar_elevation"]
+    for key in solar_keys:
+        if key in ds.coords:
+            sample[key] = ds.coords[key].values
+
+    # TODO think about normalization:
+    # * maybe its done not in sample creation, maybe its done afterwards,
+    # to allow it to be flexible
 
     return sample
 
+
 def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]:
-    """
-    Convert a combined sample dataset to a dict of datasets for each input
+    """Convert a combined sample dataset to a dict of datasets for each input.
 
     Args:
         combined_dataset: The combined NetCDF dataset

@@ -374,10 +417,10 @@ def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[
             if f"{key}__" not in dim:
                 dataset: xr.Dataset = dataset.drop(dim)
         dataset = dataset.rename(
-            {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords}
+            {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords},
         )
         dataset: xr.Dataset = dataset.rename(
-            {coord: coord.split(f"{key}__")[1] for coord in dataset.coords}
+            {coord: coord.split(f"{key}__")[1] for coord in dataset.coords},
         )
         # Split the dataset by the prefix
         datasets[key] = dataset

@@ -391,22 +434,21 @@ def nest_nwp_source_dict(d: dict, sep: str = "/") -> dict:
     """Re-nest a dictionary where the NWP values are nested under keys 'nwp/<key>'."""
     nwp_prefix = f"nwp{sep}"
     new_dict = {k: v for k, v in d.items() if not k.startswith(nwp_prefix)}
-    nwp_keys = [k for k in d
+    nwp_keys = [k for k in d if k.startswith(nwp_prefix)]
     if len(nwp_keys) > 0:
         nwp_subdict = {k.removeprefix(nwp_prefix): d[k] for k in nwp_keys}
         new_dict["nwp"] = nwp_subdict
     return new_dict
 
+
 def convert_to_numpy_and_combine(
     dataset_dict: dict,
 ) -> dict:
-    """Convert input data in a dict to numpy arrays"""
-
+    """Convert input data in a dict to numpy arrays."""
     numpy_modalities = []
 
     if "nwp" in dataset_dict:
-
-        nwp_numpy_modalities = dict()
+        nwp_numpy_modalities = {}
         for nwp_key, da_nwp in dataset_dict["nwp"].items():
             # Convert to NumpySample
             nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)

@@ -427,7 +469,7 @@ def convert_to_numpy_and_combine(
         numpy_modalities.append(
             convert_site_to_numpy_sample(
                 da_sites,
-            )
+            ),
        )
 
     # Combine all the modalities and fill NaNs

@@ -437,25 +479,23 @@ def convert_to_numpy_and_combine(
     return combined_sample
 
 
-def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float=0.1):
-    """
-
-
+def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float = 0.1) -> xr.Dataset:
+    """Coarsen the data to a specified resolution in degrees.
+
     Args:
         xr_data: xarray dataset to coarsen
         coarsen_to_deg: resolution to coarsen to in degrees
     """
-
     if "latitude" in xr_data.coords and "longitude" in xr_data.coords:
-        step = np.abs(xr_data.latitude.values[1]-xr_data.latitude.values[0])
-        step = np.round(step,4)
-        coarsen_factor = int(coarsen_to_deg/step)
+        step = np.abs(xr_data.latitude.values[1] - xr_data.latitude.values[0])
+        step = np.round(step, 4)
+        coarsen_factor = int(coarsen_to_deg / step)
         if coarsen_factor > 1:
             xr_data = xr_data.coarsen(
                 latitude=coarsen_factor,
                 longitude=coarsen_factor,
                 boundary="pad",
-                coord_func="min"
+                coord_func="min",
             ).mean()
-
-    return xr_data
+
+    return xr_data
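The solar-position block added to process_and_combine_site_sample_dict derives its datetime grid from t0 plus the configured interval, independently of the site time axis. A minimal sketch of that arithmetic, using hypothetical interval values in place of the YAML-configured solar_position settings, and pd.Timedelta where the code calls the minutes() helper (which its use as a date_range freq suggests is a thin Timedelta wrapper):

import pandas as pd

# Hypothetical solar_position settings - the real values come from the YAML config
interval_start_minutes = -30
interval_end_minutes = 60
time_resolution_minutes = 30

t0 = pd.Timestamp("2024-06-01 12:00")

# Mirrors the new block: t0 + interval_start .. t0 + interval_end at the configured resolution
solar_datetimes = pd.date_range(
    t0 + pd.Timedelta(minutes=interval_start_minutes),
    t0 + pd.Timedelta(minutes=interval_end_minutes),
    freq=pd.Timedelta(minutes=time_resolution_minutes),
)
print(solar_datetimes)
# DatetimeIndex(['2024-06-01 11:30:00', '2024-06-01 12:00:00',
#                '2024-06-01 12:30:00', '2024-06-01 13:00:00'], ...)

make_sun_position_numpy_sample is then evaluated on this grid, and the resulting azimuth/elevation arrays are attached as coordinates on the new solar_time_utc dimension, which is also where convert_netcdf_to_numpy_sample's solar_azimuth/solar_elevation extraction reads them back from.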
ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0 (new file)

@@ -0,0 +1,11 @@
+"""Converts a dictionary of channel values to a DataArray."""
+
+import xarray as xr
+
+
+def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
+    """Converts a dictionary of channel values to a DataArray."""
+    return xr.DataArray(
+        list(channel_dict.values()),
+        coords={"channel": list(channel_dict.keys())},
+    )
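This small helper is what lets site.py above move from hard-coded normalisation statistics (constants.py is deleted in this release, and utils/compute_icon_mean_stddev.py presumably produces the replacement per-channel values) to config-driven ones: the returned DataArray is indexed by a "channel" coordinate, so xarray aligns it against a data array's channel dimension and (da - means) / stds standardises each channel independently. A self-contained sketch of that broadcasting, with toy numbers rather than real channel statistics:

import numpy as np
import xarray as xr

def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
    # As in the new module: values become the data, keys become the "channel" coordinate
    return xr.DataArray(
        list(channel_dict.values()),
        coords={"channel": list(channel_dict.keys())},
    )

# Toy data array: 2 channels x 3 time steps (channel names are illustrative)
da = xr.DataArray(
    np.array([[10.0, 12.0, 14.0], [200.0, 220.0, 240.0]]),
    dims=("channel", "time_utc"),
    coords={"channel": ["t2m", "dswrf"]},
)

means = channel_dict_to_dataarray({"t2m": 12.0, "dswrf": 220.0})
stds = channel_dict_to_dataarray({"t2m": 2.0, "dswrf": 20.0})

# Aligns on the shared "channel" coordinate - the exact pattern used in site.py
da_norm = (da - means) / stds
# each channel is now centred and scaled by its own statistics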
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2

@@ -1,13 +1,17 @@
+"""Utility functions for merging dictionaries and filling NaNs in arrays."""
+
 import numpy as np
 
+
 def merge_dicts(list_of_dicts: list[dict]) -> dict:
-    """Merge a list of dictionaries into a single dictionary"""
+    """Merge a list of dictionaries into a single dictionary."""
     # TODO: This doesn't account for duplicate keys, which will be overwritten
     combined_dict = {}
     for d in list_of_dicts:
         combined_dict.update(d)
     return combined_dict
 
+
 def fill_nans_in_arrays(sample: dict) -> dict:
     """Fills all NaN values in each np.ndarray in the sample dictionary with zeros.
 

@@ -22,4 +26,4 @@ def fill_nans_in_arrays(sample: dict) -> dict:
         elif isinstance(v, dict):
             fill_nans_in_arrays(v)
 
-    return sample
+    return sample