ocf-data-sampler 0.0.44__py3-none-any.whl → 0.0.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic.
- ocf_data_sampler/config/save.py +22 -11
- ocf_data_sampler/numpy_sample/__init__.py +1 -0
- ocf_data_sampler/numpy_sample/datetime_features.py +46 -0
- ocf_data_sampler/torch_datasets/{pvnet_uk_regional.py → datasets/pvnet_uk_regional.py} +102 -4
- ocf_data_sampler/torch_datasets/{site.py → datasets/site.py} +23 -5
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +25 -0
- {ocf_data_sampler-0.0.44.dist-info → ocf_data_sampler-0.0.46.dist-info}/METADATA +3 -2
- {ocf_data_sampler-0.0.44.dist-info → ocf_data_sampler-0.0.46.dist-info}/RECORD +20 -18
- tests/config/test_config.py +25 -27
- tests/conftest.py +2 -2
- tests/numpy_sample/test_collate.py +1 -1
- tests/numpy_sample/test_datetime_features.py +47 -0
- tests/torch_datasets/test_merge_and_fill_utils.py +42 -0
- tests/torch_datasets/test_pvnet_uk_regional.py +80 -3
- tests/torch_datasets/test_site.py +4 -4
- ocf_data_sampler/torch_datasets/process_and_combine.py +0 -131
- tests/torch_datasets/test_process_and_combine.py +0 -126
- /ocf_data_sampler/torch_datasets/{__init__.py → datasets/__init__.py} +0 -0
- /ocf_data_sampler/torch_datasets/{valid_time_periods.py → utils/valid_time_periods.py} +0 -0
- {ocf_data_sampler-0.0.44.dist-info → ocf_data_sampler-0.0.46.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.0.44.dist-info → ocf_data_sampler-0.0.46.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.0.44.dist-info → ocf_data_sampler-0.0.46.dist-info}/top_level.txt +0 -0
ocf_data_sampler/config/save.py
CHANGED
@@ -9,7 +9,6 @@ Example:
 """
 
 import json
-
 from pathlib import Path
 from typing import Union
 
@@ -18,7 +17,6 @@ import yaml
 
 from ocf_data_sampler.config import Configuration
 
-
 def save_yaml_configuration(
     configuration: Configuration,
     filename: Union[str, Path],
@@ -35,7 +33,7 @@ def save_yaml_configuration(
         Path: The path where the configuration was saved
 
     Raises:
-        ValueError: If filename is None or if writing to the specified path fails
+        ValueError: If filename is None, directory doesn't exist, or if writing to the specified path fails
        TypeError: If the configuration cannot be serialized
    """
    if filename is None:
@@ -50,24 +48,37 @@ def save_yaml_configuration(
 
         filepath = Path(filename)
 
-        # For local
+        # For local paths, check if parent directory exists before attempting to create
         if filepath.is_absolute():
-            directory = filepath.parent
-            if not directory.exists():
+            if not filepath.parent.exists():
                 raise ValueError("Directory does not exist")
+
+            # Only try to create directory if it's in a writable location
+            try:
+                filepath.parent.mkdir(parents=True, exist_ok=True)
+            except PermissionError:
+                raise ValueError(f"Permission denied when accessing directory {filepath.parent}")
 
         # Serialize configuration to JSON-compatible dictionary
         config_dict = json.loads(configuration.model_dump_json())
 
-        #
-
-
+        # Write to file directly for local paths
+        if filepath.is_absolute():
+            try:
+                with open(filepath, 'w') as f:
+                    yaml.safe_dump(config_dict, f, default_flow_style=False)
+            except PermissionError:
+                raise ValueError(f"Permission denied when writing to {filename}")
+        else:
+            # Use fsspec for cloud storage
+            with fsspec.open(str(filepath), mode='w') as yaml_file:
+                yaml.safe_dump(config_dict, yaml_file, default_flow_style=False)
 
         return filepath
 
     except json.JSONDecodeError as e:
         raise TypeError(f"Failed to serialize configuration: {str(e)}") from e
-    except PermissionError as e:
-        raise ValueError(f"Permission denied when writing to {filename}") from e
     except (IOError, OSError) as e:
+        if "Permission denied" in str(e):
+            raise ValueError(f"Permission denied when writing to {filename}") from e
         raise ValueError(f"Failed to write configuration to {filename}: {str(e)}") from e
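Note on the new behaviour: save_yaml_configuration now branches on whether the target path is absolute — absolute local paths are validated (the parent directory must already exist) and written with plain open(), while everything else is handed to fsspec. A minimal usage sketch, with a hypothetical input file and bucket name:

import tempfile
from pathlib import Path

from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration

config = load_yaml_configuration("existing_config.yaml")  # hypothetical input file

# Absolute local path: parent directory must exist, otherwise ValueError is raised
with tempfile.TemporaryDirectory() as d:
    save_yaml_configuration(config, Path(d) / "config.yaml")

# Non-absolute paths (e.g. cloud URLs) are written through fsspec instead
save_yaml_configuration(config, "s3://my-bucket/config.yaml")  # hypothetical bucket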
ocf_data_sampler/numpy_sample/__init__.py
CHANGED

@@ -1,5 +1,6 @@
 """Conversion from Xarray to NumpySample"""
 
+from .datetime_features import make_datetime_numpy_dict
 from .gsp import convert_gsp_to_numpy_sample, GSPSampleKey
 from .nwp import convert_nwp_to_numpy_sample, NWPSampleKey
 from .satellite import convert_satellite_to_numpy_sample, SatelliteSampleKey
ocf_data_sampler/numpy_sample/datetime_features.py
ADDED

@@ -0,0 +1,46 @@
+"""Functions to create trigonometric date and time inputs"""
+
+import numpy as np
+import pandas as pd
+from numpy.typing import NDArray
+
+
+def _get_date_time_in_pi(
+    dt: pd.DatetimeIndex,
+) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
+    """
+    Change the datetimes, into time and date scaled in radians
+    """
+
+    day_of_year = dt.dayofyear
+    minute_of_day = dt.minute + dt.hour * 60
+
+    # converting into positions on sin-cos circle
+    time_in_pi = (2 * np.pi) * (minute_of_day / (24 * 60))
+    date_in_pi = (2 * np.pi) * (day_of_year / 365)
+
+    return date_in_pi, time_in_pi
+
+
+def make_datetime_numpy_dict(datetimes: pd.DatetimeIndex, key_prefix: str = "wind") -> dict:
+    """ Make dictionary of datetime features"""
+
+    if datetimes.empty:
+        raise ValueError("Input datetimes is empty for 'make_datetime_numpy_dict' function")
+
+    time_numpy_sample = {}
+
+    date_in_pi, time_in_pi = _get_date_time_in_pi(datetimes)
+
+    # Store
+    date_sin_batch_key = key_prefix + "_date_sin"
+    date_cos_batch_key = key_prefix + "_date_cos"
+    time_sin_batch_key = key_prefix + "_time_sin"
+    time_cos_batch_key = key_prefix + "_time_cos"
+
+    time_numpy_sample[date_sin_batch_key] = np.sin(date_in_pi)
+    time_numpy_sample[date_cos_batch_key] = np.cos(date_in_pi)
+    time_numpy_sample[time_sin_batch_key] = np.sin(time_in_pi)
+    time_numpy_sample[time_cos_batch_key] = np.cos(time_in_pi)
+
+    return time_numpy_sample
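For intuition on the scaling above: minute 720 of the day gives time_in_pi = 2π · 720/1440 = π, so the time features land on (sin, cos) = (0, −1) at noon, and day 365 wraps back close to the same angle as day 0. A minimal sketch exercising the new helper:

import numpy as np
import pandas as pd

from ocf_data_sampler.numpy_sample import make_datetime_numpy_dict

# Noon: time_in_pi is exactly pi
features = make_datetime_numpy_dict(pd.DatetimeIndex(["2024-03-20 12:00"]), key_prefix="site")

assert np.isclose(features["site_time_sin"][0], 0.0)   # sin(pi) ~ 0
assert np.isclose(features["site_time_cos"][0], -1.0)  # cos(pi) = -1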
ocf_data_sampler/torch_datasets/{pvnet_uk_regional.py → datasets/pvnet_uk_regional.py}
RENAMED

@@ -5,16 +5,114 @@ import pandas as pd
 import pkg_resources
 import xarray as xr
 from torch.utils.data import Dataset
-
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
 from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time
 from ocf_data_sampler.utils import minutes
-from ocf_data_sampler.torch_datasets.
-from ocf_data_sampler.
+from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
+from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
+from ocf_data_sampler.numpy_sample import (
+    convert_nwp_to_numpy_sample,
+    convert_satellite_to_numpy_sample,
+    convert_gsp_to_numpy_sample,
+    make_sun_position_numpy_sample,
+)
+from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
+    merge_dicts,
+    fill_nans_in_arrays,
+)
+from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
+from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
+from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 
 xr.set_options(keep_attrs=True)
 
+def process_and_combine_datasets(
+    dataset_dict: dict,
+    config: Configuration,
+    t0: pd.Timestamp,
+    location: Location,
+    target_key: str = 'gsp'
+) -> dict:
+
+    """Normalise and convert data to numpy arrays"""
+    numpy_modalities = []
+
+    if "nwp" in dataset_dict:
+
+        nwp_numpy_modalities = dict()
+
+        for nwp_key, da_nwp in dataset_dict["nwp"].items():
+            # Standardise
+            provider = config.input_data.nwp[nwp_key].provider
+            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+
+            # Convert to NumpyBatch
+            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
+
+        # Combine the NWPs into NumpyBatch
+        numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
+
+
+    if "sat" in dataset_dict:
+        # Standardise
+        da_sat = dataset_dict["sat"]
+        da_sat = (da_sat - RSS_MEAN) / RSS_STD
+
+        # Convert to NumpyBatch
+        numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
+
+    gsp_config = config.input_data.gsp
+
+    if "gsp" in dataset_dict:
+        da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
+        da_gsp = da_gsp / da_gsp.effective_capacity_mwp
+
+        numpy_modalities.append(
+            convert_gsp_to_numpy_sample(
+                da_gsp,
+                t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
+            )
+        )
+
+    # Add coordinate data
+    # TODO: Do we need all of these?
+    numpy_modalities.append(
+        {
+            GSPSampleKey.gsp_id: location.id,
+            GSPSampleKey.x_osgb: location.x,
+            GSPSampleKey.y_osgb: location.y,
+        }
+    )
+
+    if target_key == 'gsp':
+        # Make sun coords NumpySample
+        datetimes = pd.date_range(
+            t0+minutes(gsp_config.interval_start_minutes),
+            t0+minutes(gsp_config.interval_end_minutes),
+            freq=minutes(gsp_config.time_resolution_minutes),
+        )
+
+        lon, lat = osgb_to_lon_lat(location.x, location.y)
+
+        numpy_modalities.append(
+            make_sun_position_numpy_sample(datetimes, lon, lat, key_prefix=target_key)
+        )
+
+    # Combine all the modalities and fill NaNs
+    combined_sample = merge_dicts(numpy_modalities)
+    combined_sample = fill_nans_in_arrays(combined_sample)
+
+    return combined_sample
+
+def compute(xarray_dict: dict) -> dict:
+    """Eagerly load a nested dictionary of xarray DataArrays"""
+    for k, v in xarray_dict.items():
+        if isinstance(v, dict):
+            xarray_dict[k] = compute(v)
+        else:
+            xarray_dict[k] = v.compute(scheduler="single-threaded")
+    return xarray_dict
 
 def find_valid_t0_times(
     datasets_dict: dict,
@@ -48,7 +146,7 @@ def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
 
     # Load UK GSP locations
     df_gsp_loc = pd.read_csv(
-        pkg_resources.resource_filename(__name__, "
+        pkg_resources.resource_filename(__name__, "../../data/uk_gsp_locations.csv"),
         index_col="gsp_id",
     )
 
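Two bits of arithmetic in the moved process_and_combine_datasets are easy to miss. A short sketch with hypothetical values standing in for the gsp_config fields it actually reads:

# Hypothetical config values, for illustration only
interval_start_minutes = -120   # two hours of history
time_resolution_minutes = 30

# t0's index in the concatenated past+future GSP window: 120 / 30 = 4
t0_idx = -interval_start_minutes / time_resolution_minutes
assert t0_idx == 4.0

# NWP standardisation is a per-provider z-score, i.e.
# da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]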
ocf_data_sampler/torch_datasets/{site.py → datasets/site.py}
RENAMED

@@ -17,12 +17,14 @@ from ocf_data_sampler.select import (
     slice_datasets_by_time, slice_datasets_by_space
 )
 from ocf_data_sampler.utils import minutes
-from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
-from ocf_data_sampler.torch_datasets.
+from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
+from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts, fill_nans_in_arrays
 from ocf_data_sampler.numpy_sample import (
     convert_site_to_numpy_sample,
     convert_satellite_to_numpy_sample,
-    convert_nwp_to_numpy_sample
+    convert_nwp_to_numpy_sample,
+    make_datetime_numpy_dict,
+    make_sun_position_numpy_sample,
 )
 from ocf_data_sampler.numpy_sample import NWPSampleKey
 from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
@@ -234,10 +236,26 @@ class SitesDataset(Dataset):
         da_sites = dataset_dict["site"]
         da_sites = da_sites / da_sites.capacity_kwp
         data_arrays.append(("site", da_sites))
-
+
         combined_sample_dataset = self.merge_data_arrays(data_arrays)
 
-        #
+        # add datetime features
+        datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
+        datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site")
+        datetime_features_xr = xr.Dataset(datetime_features, coords={"site__time_utc": datetimes})
+        combined_sample_dataset = xr.merge([combined_sample_dataset, datetime_features_xr])
+
+        # add sun features
+        sun_position_features = make_sun_position_numpy_sample(
+            datetimes=datetimes,
+            lon=combined_sample_dataset.site__longitude.values,
+            lat=combined_sample_dataset.site__latitude.values,
+            key_prefix="site",
+        )
+        sun_position_features_xr = xr.Dataset(
+            sun_position_features, coords={"site__time_utc": datetimes}
+        )
+        combined_sample_dataset = xr.merge([combined_sample_dataset, sun_position_features_xr])
 
         # Fill any nan values
         return combined_sample_dataset.fillna(0.0)
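For orientation, the new site.py block turns each per-timestep feature dict into an xarray Dataset on the site__time_utc axis before merging it into the combined sample. A minimal sketch with hypothetical times (dims written explicitly here for clarity):

import pandas as pd
import xarray as xr

from ocf_data_sampler.numpy_sample import make_datetime_numpy_dict

datetimes = pd.date_range("2024-06-20 12:00", periods=3, freq="30min")  # hypothetical
features = make_datetime_numpy_dict(datetimes, key_prefix="site")

# Each 1-D feature array lines up with the sample's time axis
features_xr = xr.Dataset(
    {k: ("site__time_utc", v) for k, v in features.items()},
    coords={"site__time_utc": datetimes},
)
# xr.merge([combined_sample_dataset, features_xr]) then aligns on site__time_utc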
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py
ADDED

@@ -0,0 +1,25 @@
+import numpy as np
+
+def merge_dicts(list_of_dicts: list[dict]) -> dict:
+    """Merge a list of dictionaries into a single dictionary"""
+    # TODO: This doesn't account for duplicate keys, which will be overwritten
+    combined_dict = {}
+    for d in list_of_dicts:
+        combined_dict.update(d)
+    return combined_dict
+
+def fill_nans_in_arrays(sample: dict) -> dict:
+    """Fills all NaN values in each np.ndarray in the sample dictionary with zeros.
+
+    Operation is performed in-place on the sample.
+    """
+    for k, v in sample.items():
+        if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
+            if np.isnan(v).any():
+                sample[k] = np.nan_to_num(v, copy=False, nan=0.0)
+
+        # Recursion is included to reach NWP arrays in subdict
+        elif isinstance(v, dict):
+            fill_nans_in_arrays(v)
+
+    return sample
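One detail behind the "in-place" wording in fill_nans_in_arrays: np.nan_to_num with copy=False mutates the existing float array rather than allocating a new one, so any other reference to the same array sees the filled values too. A quick sketch:

import numpy as np

v = np.array([1.0, np.nan, 3.0])
out = np.nan_to_num(v, copy=False, nan=0.0)

assert out is v        # same array object, modified in place
assert v[1] == 0.0     # the NaN was overwritten with zero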
{ocf_data_sampler-0.0.44.dist-info → ocf_data_sampler-0.0.46.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: ocf_data_sampler
-Version: 0.0.44
+Version: 0.0.46
 Summary: Sample from weather data for renewable energy prediction
 Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
 Author-email: info@openclimatefix.org
@@ -56,7 +56,7 @@ Requires-Dist: mkdocs-material>=8.0; extra == "docs"
 # ocf-data-sampler
 
 <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
-[![All Contributors](https://img.shields.io/badge/all_contributors-20-orange.svg?style=flat-square)](#contributors-)
+[![All Contributors](https://img.shields.io/badge/all_contributors-…-orange.svg?style=flat-square)](#contributors-)
 <!-- ALL-CONTRIBUTORS-BADGE:END -->
 
 [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/ocf-data-sampler?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/ocf-data-sampler/tags)
@@ -135,6 +135,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
       <td align="center" valign="top" width="14.28%"><a href="https://github.com/felix-e-h-p"><img src="https://avatars.githubusercontent.com/u/137530077?v=4?s=100" width="100px;" alt="Felix"/><br /><sub><b>Felix</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=felix-e-h-p" title="Code">💻</a></td>
       <td align="center" valign="top" width="14.28%"><a href="https://timothyajaniportfolio-b6v3zq29k-timthegreat.vercel.app/"><img src="https://avatars.githubusercontent.com/u/60073728?v=4?s=100" width="100px;" alt="Ajani Timothy"/><br /><sub><b>Ajani Timothy</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=Tim1119" title="Code">💻</a></td>
       <td align="center" valign="top" width="14.28%"><a href="https://rupeshmangalam.vercel.app/"><img src="https://avatars.githubusercontent.com/u/91172425?v=4?s=100" width="100px;" alt="Rupesh Mangalam"/><br /><sub><b>Rupesh Mangalam</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=RupeshMangalam21" title="Code">💻</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="http://siddharth7113.github.io"><img src="https://avatars.githubusercontent.com/u/114160268?v=4?s=100" width="100px;" alt="Siddharth"/><br /><sub><b>Siddharth</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=siddharth7113" title="Code">💻</a></td>
     </tr>
   </tbody>
 </table>
{ocf_data_sampler-0.0.44.dist-info → ocf_data_sampler-0.0.46.dist-info}/RECORD
CHANGED

@@ -4,7 +4,7 @@ ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
 ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
 ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
 ocf_data_sampler/config/model.py,sha256=sXmh7IadwXDT-7lxEl5_b3vjovZgZYR77EXy4GHaf4w,7276
-ocf_data_sampler/config/save.py,sha256=
+ocf_data_sampler/config/save.py,sha256=gB44isAZWUlCe3L6VBkLkngWC9GFpcCfAM57gy-0dkg,3156
 ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
 ocf_data_sampler/load/__init__.py,sha256=MjgfxilTzyz1RYFoBEeAXmE9hyjknLvdmlHPmlAoiQY,44
 ocf_data_sampler/load/gsp.py,sha256=Gcr1JVUOPKhFRDCSHtfPDjxx0BtyyEhXrZvGEKLPJ5I,759
@@ -18,8 +18,9 @@ ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQ
 ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=2iR1Iy542lo51rC6XFLV-3pbUE68dWjlHa6TVJzx3ac,1280
 ocf_data_sampler/load/nwp/providers/ukv.py,sha256=79Bm7q-K_GJPYMy62SUIZbRWRF4-tIaB1dYPEgLD9vo,1207
 ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3txG1_bBmBdchlc-yA,848
-ocf_data_sampler/numpy_sample/__init__.py,sha256=
+ocf_data_sampler/numpy_sample/__init__.py,sha256=nY5C6CcuxiWZ_jrXRzWtN7WyKXhJImSiVTIG6Rz4B_4,401
 ocf_data_sampler/numpy_sample/collate.py,sha256=y8QFUhskaAfOMP22aVkexwyGAwLHbNE-q1pOZ6txWKA,2227
+ocf_data_sampler/numpy_sample/datetime_features.py,sha256=U-9uRplfZ7VYFA4qBduI8OkG2x_65RYIP8wrLG4i-Nw,1441
 ocf_data_sampler/numpy_sample/gsp.py,sha256=5UaWO_aGRRVQo82wnDaT4zBKHihOnIsXiwgPjM8vGFM,1005
 ocf_data_sampler/numpy_sample/nwp.py,sha256=_seQNWsut3IzPsrpipqImjnaM3XNHZCy5_5be6syivk,1297
 ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBunbhDN4Pjo0Grxs,910
@@ -35,21 +36,22 @@ ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejD
 ocf_data_sampler/select/select_time_slice.py,sha256=D5P_cSvnv8Qs49K5au7lPxDr9U_VmDn42s5leMzHt0k,6122
 ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
 ocf_data_sampler/select/time_slice_for_dataset.py,sha256=LMw8KnOCKnPjD0m4UubAWERpaiQtzRKkI2cSh5a0A-M,4335
-ocf_data_sampler/torch_datasets/__init__.py,sha256=nJUa2KzVa84ZoM0PT2AbDz26ennmAYc7M7WJVfypPMs,85
-ocf_data_sampler/torch_datasets/
-ocf_data_sampler/torch_datasets/
-ocf_data_sampler/torch_datasets/
-ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
+ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=nJUa2KzVa84ZoM0PT2AbDz26ennmAYc7M7WJVfypPMs,85
+ocf_data_sampler/torch_datasets/datasets/pvnet_uk_regional.py,sha256=xxeX4Js9LQpydehi3BS7k9psqkYGzgJuM17uTYux40M,8742
+ocf_data_sampler/torch_datasets/datasets/site.py,sha256=7gTtXG3DFzs_0XlYK0oleFPT-Gena_NSngcG_FAnY54,15394
+ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
+ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
 scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tests/conftest.py,sha256=
-tests/config/test_config.py,sha256=
+tests/conftest.py,sha256=DfrH0Pm552Tnl35eZn2UHCfOn2lHRiEQCcUCJIhycSU,8021
+tests/config/test_config.py,sha256=Vq_kTL5tJcwEP-hXD_Nah5O6cgafo99iX6Fw1AN5NDY,5288
 tests/config/test_save.py,sha256=rA_XVxP1pOxB--5Ebujz4T5o-VbcrCbg2VSlSq2iI0o,1318
 tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
 tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
 tests/load/test_load_satellite.py,sha256=STX5AqqmOAgUgE9R1xyq_sM3P1b8NKdGjO-hDhayfxM,524
 tests/load/test_load_sites.py,sha256=T9lSEnGPI8FQISudVYHHNTHeplNS62Vrx48jaZ6J_Jo,364
-tests/numpy_sample/test_collate.py,sha256=
+tests/numpy_sample/test_collate.py,sha256=ngbJ8vIewnAvkXx-PpfuSMVNM82_SYaZPLhJkZZw7s0,867
+tests/numpy_sample/test_datetime_features.py,sha256=o4t3KeKFdGrOBQ77rNFcDuDMQSD23ileCS5T5AP3wG4,1769
 tests/numpy_sample/test_gsp.py,sha256=FLlq4SlJ-9cSRAepf4_ksA6PsUVKegnKEAc5pUojCJ0,1458
 tests/numpy_sample/test_nwp.py,sha256=yf4u7mAU0E3FQ4xAH6YjuHuHBzzFoXjHSFNkOVJUdSM,1455
 tests/numpy_sample/test_satellite.py,sha256=cCqtn5See-uSNfh89COGTUQNuFm6sIZ8QmBVHsuUeRI,1189
@@ -61,11 +63,11 @@ tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts
 tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
 tests/select/test_select_time_slice.py,sha256=K1EJR5TwZa9dJf_YTEHxGtvs398iy1xS2lr1BgJZkoo,9603
 tests/torch_datasets/conftest.py,sha256=eRCzHE7cxS4AoskExkCGFDBeqItktAYNAdkfpMoFCeE,629
-tests/torch_datasets/
-tests/torch_datasets/test_pvnet_uk_regional.py,sha256=
-tests/torch_datasets/test_site.py,sha256=
-ocf_data_sampler-0.0.
-ocf_data_sampler-0.0.
-ocf_data_sampler-0.0.
-ocf_data_sampler-0.0.
-ocf_data_sampler-0.0.
+tests/torch_datasets/test_merge_and_fill_utils.py,sha256=ueA0A7gZaWEgNdsU8p3CnKuvSnlleTUjEhSw2HUUROM,1229
+tests/torch_datasets/test_pvnet_uk_regional.py,sha256=FCiFueeFqrsXe7gWguSjBz5ZeUrvyhGbGw81gaVvkHM,5087
+tests/torch_datasets/test_site.py,sha256=0tnjgx6z4VlzjoF_V2p3Y2t2Z1d0o_07Vwb-FH_c3tU,4640
+ocf_data_sampler-0.0.46.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ocf_data_sampler-0.0.46.dist-info/METADATA,sha256=S8ScJ8z3O0O5qhgGZmdI0Ugan2Yz4dH0nGj9R8N1sgs,11788
+ocf_data_sampler-0.0.46.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ocf_data_sampler-0.0.46.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
+ocf_data_sampler-0.0.46.dist-info/RECORD,,
tests/config/test_config.py
CHANGED
@@ -2,7 +2,7 @@ import tempfile
 
 import pytest
 from pydantic import ValidationError
-
+from pathlib import Path
 from ocf_data_sampler.config import (
     load_yaml_configuration,
     Configuration,
@@ -21,39 +21,37 @@ def test_load_yaml_configuration(test_config_filename):
     Test that yaml loading works for 'test_config.yaml'
     and fails for an empty .yaml file
     """
-
-
-
-
-
-    #
+    # Create temporary directory instead of file
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Create path for empty file
+        empty_file = Path(temp_dir) / "empty.yaml"
+
+        # Create an empty file
+        empty_file.touch()
+
+        # Test loading empty file
         with pytest.raises(TypeError):
-            _ = load_yaml_configuration(
-
-    # test can load test_config.yaml
-    config = load_yaml_configuration(test_config_filename)
-
-    assert isinstance(config, Configuration)
-
+            _ = load_yaml_configuration(str(empty_file))
 
 def test_yaml_save(test_config_filename):
     """
     Check configuration can be saved to a .yaml file
     """
-
     test_config = load_yaml_configuration(test_config_filename)
-
-    with tempfile.
-
-
-
-
-
-
-
-
-
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Create path for config file
+        config_path = Path(temp_dir) / "test_config.yaml"
+
+        # Save configuration
+        saved_path = save_yaml_configuration(test_config, config_path)
+
+        # Verify file exists
+        assert saved_path.exists()
+
+        # Test loading saved configuration
+        loaded_config = load_yaml_configuration(str(saved_path))
+        assert loaded_config == test_config
 
 
 def test_extra_field_error():
tests/conftest.py
CHANGED
@@ -1,10 +1,10 @@
 import os
-
 import numpy as np
 import pandas as pd
 import pytest
 import xarray as xr
 import tempfile
+from typing import Generator
 
 from ocf_data_sampler.config.model import Site
 from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
@@ -201,7 +201,7 @@ def ds_uk_gsp():
 
 
 @pytest.fixture(scope="session")
-def data_sites() -> Site:
+def data_sites() -> Generator[Site, None, None]:
     """
     Make fake data for sites
     Returns: filename for netcdf file, and csv metadata
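The annotation change above reflects that data_sites is a yielding fixture: a function containing yield returns a generator, so Generator[Site, None, None] is the accurate return hint. A generic sketch of the pattern (the body here is hypothetical, not the real fixture):

from typing import Generator

import pytest

@pytest.fixture(scope="session")
def example_fixture() -> Generator[str, None, None]:
    value = "fake_sites.nc"   # hypothetical setup
    yield value               # tests run while suspended here
    # teardown (e.g. removing temp files) runs after the yield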
tests/numpy_sample/test_collate.py
CHANGED

@@ -1,6 +1,6 @@
 from ocf_data_sampler.numpy_sample import GSPSampleKey, SatelliteSampleKey
 from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
-from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
+from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
 
 
 def test_pvnet(pvnet_config_filename):
tests/numpy_sample/test_datetime_features.py
ADDED

@@ -0,0 +1,47 @@
+import numpy as np
+import pandas as pd
+import pytest
+
+from ocf_data_sampler.numpy_sample.datetime_features import make_datetime_numpy_dict
+
+
+def test_calculate_azimuth_and_elevation():
+
+    # Pick the day of the summer solstice
+    datetimes = pd.to_datetime(["2024-06-20 12:00", "2024-06-20 12:30", "2024-06-20 13:00"])
+
+    # Calculate sun angles
+    datetime_features = make_datetime_numpy_dict(datetimes)
+
+    assert len(datetime_features) == 4
+
+    assert len(datetime_features["wind_date_sin"]) == len(datetimes)
+    assert (datetime_features["wind_date_cos"] != datetime_features["wind_date_sin"]).all()
+
+    # assert all values are between -1 and 1
+    assert all(np.abs(datetime_features["wind_date_sin"]) <= 1)
+    assert all(np.abs(datetime_features["wind_date_cos"]) <= 1)
+    assert all(np.abs(datetime_features["wind_time_sin"]) <= 1)
+    assert all(np.abs(datetime_features["wind_time_cos"]) <= 1)
+
+
+def test_make_datetime_numpy_batch_custom_key_prefix():
+    # Test function correctly applies custom prefix to dict keys
+    datetimes = pd.to_datetime(["2024-06-20 12:00", "2024-06-20 12:30", "2024-06-20 13:00"])
+    key_prefix = "solar"
+
+    datetime_features = make_datetime_numpy_dict(datetimes, key_prefix=key_prefix)
+
+    # Assert dict contains expected quantity of keys and verify starting with custom prefix
+    assert len(datetime_features) == 4
+    assert all(key.startswith(key_prefix) for key in datetime_features.keys())
+
+
+def test_make_datetime_numpy_batch_empty_input():
+    # Verification that function raises error for empty input
+    datetimes = pd.DatetimeIndex([])
+
+    with pytest.raises(
+        ValueError, match="Input datetimes is empty for 'make_datetime_numpy_dict' function"
+    ):
+        make_datetime_numpy_dict(datetimes)
tests/torch_datasets/test_merge_and_fill_utils.py
ADDED

@@ -0,0 +1,42 @@
+import numpy as np
+
+from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
+    merge_dicts,
+    fill_nans_in_arrays,
+)
+
+def test_merge_dicts():
+    """Test merge_dicts function"""
+    dict1 = {"a": 1, "b": 2}
+    dict2 = {"c": 3, "d": 4}
+    dict3 = {"e": 5}
+
+    result = merge_dicts([dict1, dict2, dict3])
+    assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
+
+    # Test key overwriting
+    dict4 = {"a": 10, "f": 6}
+    result = merge_dicts([dict1, dict4])
+    assert result["a"] == 10
+
+
+def test_fill_nans_in_arrays():
+    """Test the fill_nans_in_arrays function"""
+    array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
+    nested_dict = {
+        "array1": array_with_nans,
+        "nested": {
+            "array2": np.array([np.nan, 2.0, np.nan, 4.0])
+        },
+        "string_key": "not_an_array"
+    }
+
+    result = fill_nans_in_arrays(nested_dict)
+
+    assert not np.isnan(result["array1"]).any()
+    assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
+    assert not np.isnan(result["nested"]["array2"]).any()
+    assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
+    assert result["string_key"] == "not_an_array"
+
+
tests/torch_datasets/test_pvnet_uk_regional.py
CHANGED

@@ -1,11 +1,88 @@
-import 
+import numpy as np
+import pandas as pd
+import xarray as xr
+import dask.array as da
 import tempfile
 
-from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
-from ocf_data_sampler.config import 
+from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
+from ocf_data_sampler.config.save import save_yaml_configuration
+from ocf_data_sampler.config.load import load_yaml_configuration
 from ocf_data_sampler.numpy_sample import NWPSampleKey, GSPSampleKey, SatelliteSampleKey
+from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import process_and_combine_datasets, compute
+from ocf_data_sampler.select.location import Location
 
+def test_process_and_combine_datasets(pvnet_config_filename):
 
+    # Load in config for function and define location
+    config = load_yaml_configuration(pvnet_config_filename)
+    t0 = pd.Timestamp("2024-01-01 00:00")
+    location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
+
+    nwp_data = xr.DataArray(
+        np.random.rand(4, 2, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
+            "channel": ["t2m", "dswrf"],
+            "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
+            "init_time_utc": pd.Timestamp("2024-01-01 00:00")
+        }
+    )
+
+    sat_data = xr.DataArray(
+        np.random.rand(7, 1, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
+            "channel": ["HRV"],
+            "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
+            "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
+        }
+    )
+
+    # Combine as dict
+    dataset_dict = {
+        "nwp": {"ukv": nwp_data},
+        "sat": sat_data
+    }
+
+    # Call relevant function
+    result = process_and_combine_datasets(dataset_dict, config, t0, location)
+
+    # Assert result is dict - check and validate
+    assert isinstance(result, dict)
+    assert NWPSampleKey.nwp in result
+    assert result[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
+    assert result[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
+
+def test_compute():
+    """Test compute function with dask array"""
+    da_dask = xr.DataArray(da.random.random((5, 5)))
+
+    # Create a nested dictionary with dask array
+    nested_dict = {
+        "array1": da_dask,
+        "nested": {
+            "array2": da_dask
+        }
+    }
+
+    # Ensure initial data is lazy - i.e. not yet computed
+    assert not isinstance(nested_dict["array1"].data, np.ndarray)
+    assert not isinstance(nested_dict["nested"]["array2"].data, np.ndarray)
+
+    # Call the compute function
+    result = compute(nested_dict)
+
+    # Assert that the result is an xarray DataArray and no longer lazy
+    assert isinstance(result["array1"], xr.DataArray)
+    assert isinstance(result["nested"]["array2"], xr.DataArray)
+    assert isinstance(result["array1"].data, np.ndarray)
+    assert isinstance(result["nested"]["array2"].data, np.ndarray)
+
+    # Ensure there no NaN values in computed data
+    assert not np.isnan(result["array1"].data).any()
+    assert not np.isnan(result["nested"]["array2"].data).any()
 
 def test_pvnet(pvnet_config_filename):
 
tests/torch_datasets/test_site.py
CHANGED

@@ -1,8 +1,6 @@
 import pandas as pd
-
-from ocf_data_sampler.torch_datasets import SitesDataset
-from ocf_data_sampler.torch_datasets.site import convert_from_dataset_to_dict_datasets
 import numpy as np
+from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset, convert_from_dataset_to_dict_datasets
 from xarray import Dataset, DataArray
 
 
@@ -22,7 +20,9 @@ def test_site(site_config_filename):
     # Expected dimensions and data variables
     expected_dims = {'satellite__x_geostationary', 'site__time_utc', 'nwp-ukv__target_time_utc',
                      'nwp-ukv__x_osgb', 'satellite__channel', 'satellite__y_geostationary',
-                     'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb'
+                     'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb', 'site_solar_azimuth',
+                     'site_solar_elevation', 'site_date_cos', 'site_time_cos', 'site_time_sin', 'site_date_sin'}
+
     expected_data_vars = {"nwp-ukv", "satellite", "site"}
 
     # Check dimensions
ocf_data_sampler/torch_datasets/process_and_combine.py
DELETED

@@ -1,131 +0,0 @@
-import numpy as np
-import pandas as pd
-import xarray as xr
-from typing import Optional
-
-from ocf_data_sampler.config import Configuration
-from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS,RSS_MEAN,RSS_STD
-from ocf_data_sampler.numpy_sample import (
-    convert_nwp_to_numpy_sample,
-    convert_satellite_to_numpy_sample,
-    convert_gsp_to_numpy_sample,
-    make_sun_position_numpy_sample,
-)
-from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
-from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
-from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
-from ocf_data_sampler.select.location import Location
-from ocf_data_sampler.utils import minutes
-
-
-def process_and_combine_datasets(
-    dataset_dict: dict,
-    config: Configuration,
-    t0: Optional[pd.Timestamp] = None,
-    location: Optional[Location] = None,
-    target_key: str = 'gsp'
-) -> dict:
-
-    """Normalise and convert data to numpy arrays"""
-    numpy_modalities = []
-
-    if "nwp" in dataset_dict:
-
-        nwp_numpy_modalities = dict()
-
-        for nwp_key, da_nwp in dataset_dict["nwp"].items():
-            # Standardise
-            provider = config.input_data.nwp[nwp_key].provider
-            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
-            # Convert to NumpySample
-            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
-
-        # Combine the NWPs into NumpySample
-        numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
-
-
-    if "sat" in dataset_dict:
-        # Standardise
-        da_sat = dataset_dict["sat"]
-        da_sat = (da_sat - RSS_MEAN) / RSS_STD
-
-        # Convert to NumpySample
-        numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
-
-
-    gsp_config = config.input_data.gsp
-
-    if "gsp" in dataset_dict:
-        da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
-        da_gsp = da_gsp / da_gsp.effective_capacity_mwp
-
-        numpy_modalities.append(
-            convert_gsp_to_numpy_sample(
-                da_gsp,
-                t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
-            )
-        )
-
-    # Add coordinate data
-    # TODO: Do we need all of these?
-    numpy_modalities.append(
-        {
-            GSPSampleKey.gsp_id: location.id,
-            GSPSampleKey.x_osgb: location.x,
-            GSPSampleKey.y_osgb: location.y,
-        }
-    )
-
-    if target_key == 'gsp':
-        # Make sun coords NumpySample
-        datetimes = pd.date_range(
-            t0+minutes(gsp_config.interval_start_minutes),
-            t0+minutes(gsp_config.interval_end_minutes),
-            freq=minutes(gsp_config.time_resolution_minutes),
-        )
-
-        lon, lat = osgb_to_lon_lat(location.x, location.y)
-
-        numpy_modalities.append(
-            make_sun_position_numpy_sample(datetimes, lon, lat, key_prefix=target_key)
-        )
-
-    # Combine all the modalities and fill NaNs
-    combined_sample = merge_dicts(numpy_modalities)
-    combined_sample = fill_nans_in_arrays(combined_sample)
-
-    return combined_sample
-
-def merge_dicts(list_of_dicts: list[dict]) -> dict:
-    """Merge a list of dictionaries into a single dictionary"""
-    # TODO: This doesn't account for duplicate keys, which will be overwritten
-    combined_dict = {}
-    for d in list_of_dicts:
-        combined_dict.update(d)
-    return combined_dict
-
-def fill_nans_in_arrays(sample: dict) -> dict:
-    """Fills all NaN values in each np.ndarray in the sample dictionary with zeros.
-
-    Operation is performed in-place on the sample.
-    """
-    for k, v in sample.items():
-        if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
-            if np.isnan(v).any():
-                sample[k] = np.nan_to_num(v, copy=False, nan=0.0)
-
-        # Recursion is included to reach NWP arrays in subdict
-        elif isinstance(v, dict):
-            fill_nans_in_arrays(v)
-
-    return sample
-
-
-def compute(xarray_dict: dict) -> dict:
-    """Eagerly load a nested dictionary of xarray DataArrays"""
-    for k, v in xarray_dict.items():
-        if isinstance(v, dict):
-            xarray_dict[k] = compute(v)
-        else:
-            xarray_dict[k] = v.compute(scheduler="single-threaded")
-    return xarray_dict
tests/torch_datasets/test_process_and_combine.py
DELETED

@@ -1,126 +0,0 @@
-import numpy as np
-import pandas as pd
-import xarray as xr
-import dask.array as da
-
-from ocf_data_sampler.config import load_yaml_configuration
-from ocf_data_sampler.select.location import Location
-from ocf_data_sampler.numpy_sample import NWPSampleKey, GSPSampleKey, SatelliteSampleKey
-from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
-
-from ocf_data_sampler.torch_datasets.process_and_combine import (
-    process_and_combine_datasets,
-    merge_dicts,
-    fill_nans_in_arrays,
-    compute,
-)
-
-
-def test_process_and_combine_datasets(pvnet_config_filename):
-
-    # Load in config for function and define location
-    config = load_yaml_configuration(pvnet_config_filename)
-    t0 = pd.Timestamp("2024-01-01 00:00")
-    location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
-
-    nwp_data = xr.DataArray(
-        np.random.rand(4, 2, 2, 2),
-        dims=["time_utc", "channel", "y", "x"],
-        coords={
-            "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
-            "channel": ["t2m", "dswrf"],
-            "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
-            "init_time_utc": pd.Timestamp("2024-01-01 00:00")
-        }
-    )
-
-    sat_data = xr.DataArray(
-        np.random.rand(7, 1, 2, 2),
-        dims=["time_utc", "channel", "y", "x"],
-        coords={
-            "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
-            "channel": ["HRV"],
-            "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
-            "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
-        }
-    )
-
-    # Combine as dict
-    dataset_dict = {
-        "nwp": {"ukv": nwp_data},
-        "sat": sat_data
-    }
-
-    # Call relevant function
-    result = process_and_combine_datasets(dataset_dict, config, t0, location)
-
-    # Assert result is dict - check and validate
-    assert isinstance(result, dict)
-    assert NWPSampleKey.nwp in result
-    assert result[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
-    assert result[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
-
-
-def test_merge_dicts():
-    """Test merge_dicts function"""
-    dict1 = {"a": 1, "b": 2}
-    dict2 = {"c": 3, "d": 4}
-    dict3 = {"e": 5}
-
-    result = merge_dicts([dict1, dict2, dict3])
-    assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
-
-    # Test key overwriting
-    dict4 = {"a": 10, "f": 6}
-    result = merge_dicts([dict1, dict4])
-    assert result["a"] == 10
-
-
-def test_fill_nans_in_arrays():
-    """Test the fill_nans_in_arrays function"""
-    array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
-    nested_dict = {
-        "array1": array_with_nans,
-        "nested": {
-            "array2": np.array([np.nan, 2.0, np.nan, 4.0])
-        },
-        "string_key": "not_an_array"
-    }
-
-    result = fill_nans_in_arrays(nested_dict)
-
-    assert not np.isnan(result["array1"]).any()
-    assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
-    assert not np.isnan(result["nested"]["array2"]).any()
-    assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
-    assert result["string_key"] == "not_an_array"
-
-
-def test_compute():
-    """Test compute function with dask array"""
-    da_dask = xr.DataArray(da.random.random((5, 5)))
-
-    # Create a nested dictionary with dask array
-    nested_dict = {
-        "array1": da_dask,
-        "nested": {
-            "array2": da_dask
-        }
-    }
-
-    # Ensure initial data is lazy - i.e. not yet computed
-    assert not isinstance(nested_dict["array1"].data, np.ndarray)
-    assert not isinstance(nested_dict["nested"]["array2"].data, np.ndarray)
-
-    # Call the compute function
-    result = compute(nested_dict)
-
-    # Assert that the result is an xarray DataArray and no longer lazy
-    assert isinstance(result["array1"], xr.DataArray)
-    assert isinstance(result["nested"]["array2"], xr.DataArray)
-    assert isinstance(result["array1"].data, np.ndarray)
-    assert isinstance(result["nested"]["array2"].data, np.ndarray)
-
-    # Ensure there no NaN values in computed data
-    assert not np.isnan(result["array1"].data).any()
-    assert not np.isnan(result["nested"]["array2"].data).any()
/ocf_data_sampler/torch_datasets/{__init__.py → datasets/__init__.py}
File without changes

/ocf_data_sampler/torch_datasets/{valid_time_periods.py → utils/valid_time_periods.py}
File without changes

{ocf_data_sampler-0.0.44.dist-info → ocf_data_sampler-0.0.46.dist-info}/LICENSE
File without changes

{ocf_data_sampler-0.0.44.dist-info → ocf_data_sampler-0.0.46.dist-info}/WHEEL
File without changes

{ocf_data_sampler-0.0.44.dist-info → ocf_data_sampler-0.0.46.dist-info}/top_level.txt
File without changes