ocf-data-sampler 0.1.1__tar.gz → 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- {ocf_data_sampler-0.1.1/ocf_data_sampler.egg-info → ocf_data_sampler-0.1.3}/PKG-INFO +1 -1
- ocf_data_sampler-0.1.3/ocf_data_sampler/numpy_sample/collate.py +64 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/fill_time_periods.py +1 -1
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/time_slice_for_dataset.py +16 -13
- ocf_data_sampler-0.1.3/ocf_data_sampler/torch_datasets/datasets/__init__.py +6 -0
- ocf_data_sampler-0.1.1/ocf_data_sampler/torch_datasets/datasets/pvnet_uk_regional.py → ocf_data_sampler-0.1.3/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +114 -16
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3/ocf_data_sampler.egg-info}/PKG-INFO +1 -1
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler.egg-info/SOURCES.txt +2 -3
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/pyproject.toml +1 -1
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/conftest.py +69 -70
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/load/test_load_satellite.py +3 -3
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/numpy_sample/test_collate.py +4 -9
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/torch_datasets/test_merge_and_fill_utils.py +0 -2
- ocf_data_sampler-0.1.3/tests/torch_datasets/test_pvnet_uk.py +166 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/torch_datasets/test_site.py +47 -36
- ocf_data_sampler-0.1.1/ocf_data_sampler/numpy_sample/collate.py +0 -75
- ocf_data_sampler-0.1.1/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -11
- ocf_data_sampler-0.1.1/tests/torch_datasets/conftest.py +0 -18
- ocf_data_sampler-0.1.1/tests/torch_datasets/test_pvnet_uk_regional.py +0 -136
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/LICENSE +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/MANIFEST.in +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/README.md +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/__init__.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/config/__init__.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/config/load.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/config/model.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/config/save.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/constants.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/data/uk_gsp_locations.csv +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/__init__.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/gsp.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/load_dataset.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/nwp/__init__.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/nwp/nwp.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/satellite.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/site.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/load/utils.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/numpy_sample/nwp.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/numpy_sample/site.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/sample/__init__.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/sample/base.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/sample/site.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/sample/uk_regional.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/__init__.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/dropout.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/geospatial.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/location.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/select_time_slice.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/spatial_slice_for_dataset.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/torch_datasets/datasets/site.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/utils.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler.egg-info/requires.txt +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler.egg-info/top_level.txt +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/scripts/refactor_site.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/setup.cfg +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/__init__.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/config/test_config.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/config/test_save.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/load/test_load_gsp.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/load/test_load_nwp.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/load/test_load_sites.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/numpy_sample/test_datetime_features.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/numpy_sample/test_gsp.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/numpy_sample/test_nwp.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/numpy_sample/test_satellite.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/numpy_sample/test_sun_position.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/select/test_dropout.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/select/test_fill_time_periods.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/select/test_find_contiguous_time_periods.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/select/test_location.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/select/test_select_spatial_slice.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/select/test_select_time_slice.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/test_sample/test_base.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/test_sample/test_site_sample.py +0 -0
- {ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/test_sample/test_uk_regional_sample.py +0 -0
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
|
|
5
|
+
"""Stacks list of dict samples into a dict where all samples are joined along a new axis
|
|
6
|
+
|
|
7
|
+
Args:
|
|
8
|
+
dict_list: A list of dict-like samples to stack
|
|
9
|
+
|
|
10
|
+
Returns:
|
|
11
|
+
Dict of the samples stacked with new batch dimension on axis 0
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
batch = {}
|
|
15
|
+
|
|
16
|
+
keys = list(dict_list[0].keys())
|
|
17
|
+
|
|
18
|
+
for key in keys:
|
|
19
|
+
# NWP is nested so treat separately
|
|
20
|
+
if key == "nwp":
|
|
21
|
+
batch["nwp"] = {}
|
|
22
|
+
|
|
23
|
+
# Unpack NWP provider keys
|
|
24
|
+
nwp_providers = list(dict_list[0]["nwp"].keys())
|
|
25
|
+
|
|
26
|
+
for nwp_provider in nwp_providers:
|
|
27
|
+
# Keys can be different for different NWPs
|
|
28
|
+
nwp_keys = list(dict_list[0]["nwp"][nwp_provider].keys())
|
|
29
|
+
|
|
30
|
+
# Create dict to store NWP batch for this provider
|
|
31
|
+
nwp_provider_batch = {}
|
|
32
|
+
|
|
33
|
+
for nwp_key in nwp_keys:
|
|
34
|
+
# Stack values under each NWP key for this provider
|
|
35
|
+
nwp_provider_batch[nwp_key] = stack_data_list(
|
|
36
|
+
[d["nwp"][nwp_provider][nwp_key] for d in dict_list],
|
|
37
|
+
nwp_key,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
batch["nwp"][nwp_provider] = nwp_provider_batch
|
|
41
|
+
|
|
42
|
+
else:
|
|
43
|
+
batch[key] = stack_data_list([d[key] for d in dict_list], key)
|
|
44
|
+
|
|
45
|
+
return batch
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _key_is_constant(key: str):
|
|
49
|
+
return key.endswith("t0_idx") or key.endswith("channel_names")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def stack_data_list(data_list: list, key: str):
|
|
53
|
+
"""Stack a sequence of data elements along a new axis
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
data_list: List of data elements to combine
|
|
57
|
+
key: string identifying the data type
|
|
58
|
+
"""
|
|
59
|
+
if _key_is_constant(key):
|
|
60
|
+
# These are always the same for all examples.
|
|
61
|
+
return data_list[0]
|
|
62
|
+
else:
|
|
63
|
+
return np.stack(data_list)
|
|
64
|
+
|
{ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/fill_time_periods.py
RENAMED
|
@@ -4,7 +4,7 @@ import pandas as pd
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta):
|
|
7
|
+
def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
|
|
8
8
|
start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
|
|
9
9
|
end_dts = pd.to_datetime(time_periods["end_dt"].values)
|
|
10
10
|
date_ranges = [pd.date_range(start_dt, end_dt, freq=freq) for start_dt, end_dt in zip(start_dts, end_dts)]
|
{ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/ocf_data_sampler/select/time_slice_for_dataset.py
RENAMED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
""" Slice datasets by time"""
|
|
2
2
|
import pandas as pd
|
|
3
|
+
import xarray as xr
|
|
3
4
|
|
|
4
5
|
from ocf_data_sampler.config import Configuration
|
|
5
6
|
from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
|
|
@@ -64,16 +65,8 @@ def slice_datasets_by_time(
|
|
|
64
65
|
|
|
65
66
|
if "gsp" in datasets_dict:
|
|
66
67
|
gsp_config = config.input_data.gsp
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
datasets_dict["gsp"],
|
|
70
|
-
t0,
|
|
71
|
-
sample_period_duration=minutes(gsp_config.time_resolution_minutes),
|
|
72
|
-
interval_start=minutes(gsp_config.time_resolution_minutes),
|
|
73
|
-
interval_end=minutes(gsp_config.interval_end_minutes),
|
|
74
|
-
)
|
|
75
|
-
|
|
76
|
-
sliced_datasets_dict["gsp"] = select_time_slice(
|
|
68
|
+
|
|
69
|
+
da_gsp_past = select_time_slice(
|
|
77
70
|
datasets_dict["gsp"],
|
|
78
71
|
t0,
|
|
79
72
|
sample_period_duration=minutes(gsp_config.time_resolution_minutes),
|
|
@@ -81,17 +74,27 @@ def slice_datasets_by_time(
|
|
|
81
74
|
interval_end=minutes(0),
|
|
82
75
|
)
|
|
83
76
|
|
|
84
|
-
# Dropout on the GSP, but not the future GSP
|
|
77
|
+
# Dropout on the past GSP, but not the future GSP
|
|
85
78
|
gsp_dropout_time = draw_dropout_time(
|
|
86
79
|
t0,
|
|
87
80
|
dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
|
|
88
81
|
dropout_frac=gsp_config.dropout_fraction,
|
|
89
82
|
)
|
|
90
83
|
|
|
91
|
-
|
|
92
|
-
|
|
84
|
+
da_gsp_past = apply_dropout_time(
|
|
85
|
+
da_gsp_past,
|
|
93
86
|
gsp_dropout_time
|
|
94
87
|
)
|
|
88
|
+
|
|
89
|
+
da_gsp_future = select_time_slice(
|
|
90
|
+
datasets_dict["gsp"],
|
|
91
|
+
t0,
|
|
92
|
+
sample_period_duration=minutes(gsp_config.time_resolution_minutes),
|
|
93
|
+
interval_start=minutes(gsp_config.time_resolution_minutes),
|
|
94
|
+
interval_end=minutes(gsp_config.interval_end_minutes),
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
sliced_datasets_dict["gsp"] = xr.concat([da_gsp_past, da_gsp_future], dim="time_utc")
|
|
95
98
|
|
|
96
99
|
if "site" in datasets_dict:
|
|
97
100
|
site_config = config.input_data.site
|
|
@@ -1,15 +1,20 @@
|
|
|
1
|
-
"""Torch dataset for PVNet"""
|
|
1
|
+
"""Torch dataset for UK PVNet"""
|
|
2
|
+
|
|
3
|
+
import pkg_resources
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
4
6
|
import pandas as pd
|
|
5
|
-
import pkg_resources
|
|
6
7
|
import xarray as xr
|
|
7
8
|
from torch.utils.data import Dataset
|
|
8
9
|
from ocf_data_sampler.config import Configuration, load_yaml_configuration
|
|
9
10
|
from ocf_data_sampler.load.load_dataset import get_dataset_dict
|
|
10
|
-
from ocf_data_sampler.select import
|
|
11
|
+
from ocf_data_sampler.select import (
|
|
12
|
+
fill_time_periods,
|
|
13
|
+
Location,
|
|
14
|
+
slice_datasets_by_space,
|
|
15
|
+
slice_datasets_by_time,
|
|
16
|
+
)
|
|
11
17
|
from ocf_data_sampler.utils import minutes
|
|
12
|
-
from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
|
|
13
18
|
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
|
|
14
19
|
from ocf_data_sampler.numpy_sample import (
|
|
15
20
|
convert_nwp_to_numpy_sample,
|
|
@@ -17,13 +22,16 @@ from ocf_data_sampler.numpy_sample import (
|
|
|
17
22
|
convert_gsp_to_numpy_sample,
|
|
18
23
|
make_sun_position_numpy_sample,
|
|
19
24
|
)
|
|
25
|
+
from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
|
|
26
|
+
from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
|
|
27
|
+
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
|
|
28
|
+
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
|
|
29
|
+
from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
|
|
20
30
|
from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
|
|
21
31
|
merge_dicts,
|
|
22
32
|
fill_nans_in_arrays,
|
|
23
33
|
)
|
|
24
|
-
|
|
25
|
-
from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
|
|
26
|
-
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
|
|
34
|
+
|
|
27
35
|
|
|
28
36
|
xr.set_options(keep_attrs=True)
|
|
29
37
|
|
|
@@ -65,9 +73,10 @@ def process_and_combine_datasets(
|
|
|
65
73
|
gsp_config = config.input_data.gsp
|
|
66
74
|
|
|
67
75
|
if "gsp" in dataset_dict:
|
|
68
|
-
da_gsp =
|
|
76
|
+
da_gsp = dataset_dict["gsp"]
|
|
69
77
|
da_gsp = da_gsp / da_gsp.effective_capacity_mwp
|
|
70
|
-
|
|
78
|
+
|
|
79
|
+
# Convert to NumpyBatch
|
|
71
80
|
numpy_modalities.append(
|
|
72
81
|
convert_gsp_to_numpy_sample(
|
|
73
82
|
da_gsp,
|
|
@@ -105,6 +114,7 @@ def process_and_combine_datasets(
|
|
|
105
114
|
|
|
106
115
|
return combined_sample
|
|
107
116
|
|
|
117
|
+
|
|
108
118
|
def compute(xarray_dict: dict) -> dict:
|
|
109
119
|
"""Eagerly load a nested dictionary of xarray DataArrays"""
|
|
110
120
|
for k, v in xarray_dict.items():
|
|
@@ -114,10 +124,8 @@ def compute(xarray_dict: dict) -> dict:
|
|
|
114
124
|
xarray_dict[k] = v.compute(scheduler="single-threaded")
|
|
115
125
|
return xarray_dict
|
|
116
126
|
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
config: Configuration,
|
|
120
|
-
):
|
|
127
|
+
|
|
128
|
+
def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
|
|
121
129
|
"""Find the t0 times where all of the requested input data is available
|
|
122
130
|
|
|
123
131
|
Args:
|
|
@@ -167,7 +175,7 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
167
175
|
self,
|
|
168
176
|
config_filename: str,
|
|
169
177
|
start_time: str | None = None,
|
|
170
|
-
end_time: str| None = None,
|
|
178
|
+
end_time: str | None = None,
|
|
171
179
|
gsp_ids: list[int] | None = None,
|
|
172
180
|
):
|
|
173
181
|
"""A torch Dataset for creating PVNet UK GSP samples
|
|
@@ -253,7 +261,7 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
253
261
|
def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
|
|
254
262
|
"""Generate a sample for the given coordinates.
|
|
255
263
|
|
|
256
|
-
Useful for users to generate samples
|
|
264
|
+
Useful for users to generate specific samples.
|
|
257
265
|
|
|
258
266
|
Args:
|
|
259
267
|
t0: init-time for sample
|
|
@@ -265,4 +273,94 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
265
273
|
|
|
266
274
|
location = self.location_lookup[gsp_id]
|
|
267
275
|
|
|
268
|
-
return self._get_sample(t0, location)
|
|
276
|
+
return self._get_sample(t0, location)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
class PVNetUKConcurrentDataset(Dataset):
|
|
280
|
+
def __init__(
|
|
281
|
+
self,
|
|
282
|
+
config_filename: str,
|
|
283
|
+
start_time: str | None = None,
|
|
284
|
+
end_time: str | None = None,
|
|
285
|
+
gsp_ids: list[int] | None = None,
|
|
286
|
+
):
|
|
287
|
+
"""A torch Dataset for creating concurrent samples of PVNet UK regional data
|
|
288
|
+
|
|
289
|
+
Each concurrent sample includes the data from all GSPs for a single t0 time
|
|
290
|
+
|
|
291
|
+
Args:
|
|
292
|
+
config_filename: Path to the configuration file
|
|
293
|
+
start_time: Limit the init-times to be after this
|
|
294
|
+
end_time: Limit the init-times to be before this
|
|
295
|
+
gsp_ids: List of all GSP IDs included in each sample. Defaults to all
|
|
296
|
+
"""
|
|
297
|
+
|
|
298
|
+
config = load_yaml_configuration(config_filename)
|
|
299
|
+
|
|
300
|
+
datasets_dict = get_dataset_dict(config)
|
|
301
|
+
|
|
302
|
+
# Get t0 times where all input data is available
|
|
303
|
+
valid_t0_times = find_valid_t0_times(datasets_dict, config)
|
|
304
|
+
|
|
305
|
+
# Filter t0 times to given range
|
|
306
|
+
if start_time is not None:
|
|
307
|
+
valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
|
|
308
|
+
|
|
309
|
+
if end_time is not None:
|
|
310
|
+
valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
|
|
311
|
+
|
|
312
|
+
# Construct list of locations to sample from
|
|
313
|
+
locations = get_gsp_locations(gsp_ids)
|
|
314
|
+
|
|
315
|
+
# Assign coords and indices to self
|
|
316
|
+
self.valid_t0_times = valid_t0_times
|
|
317
|
+
self.locations = locations
|
|
318
|
+
|
|
319
|
+
# Assign config and input data to self
|
|
320
|
+
self.datasets_dict = datasets_dict
|
|
321
|
+
self.config = config
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def __len__(self):
|
|
325
|
+
return len(self.valid_t0_times)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def _get_sample(self, t0: pd.Timestamp) -> dict:
|
|
329
|
+
"""Generate a concurrent PVNet sample for given init-time
|
|
330
|
+
|
|
331
|
+
Args:
|
|
332
|
+
t0: init-time for sample
|
|
333
|
+
"""
|
|
334
|
+
# Slice by time then load to avoid loading the data multiple times from disk
|
|
335
|
+
sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
|
|
336
|
+
sample_dict = compute(sample_dict)
|
|
337
|
+
|
|
338
|
+
gsp_samples = []
|
|
339
|
+
|
|
340
|
+
# Prepare sample for each GSP
|
|
341
|
+
for location in self.locations:
|
|
342
|
+
gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
|
|
343
|
+
gsp_numpy_sample = process_and_combine_datasets(
|
|
344
|
+
gsp_sample_dict, self.config, t0, location
|
|
345
|
+
)
|
|
346
|
+
gsp_samples.append(gsp_numpy_sample)
|
|
347
|
+
|
|
348
|
+
# Stack GSP samples
|
|
349
|
+
return stack_np_samples_into_batch(gsp_samples)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def __getitem__(self, idx):
|
|
353
|
+
return self._get_sample(self.valid_t0_times[idx])
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def get_sample(self, t0: pd.Timestamp) -> dict:
|
|
357
|
+
"""Generate a sample for the given init-time.
|
|
358
|
+
|
|
359
|
+
Useful for users to generate specific samples.
|
|
360
|
+
|
|
361
|
+
Args:
|
|
362
|
+
t0: init-time for sample
|
|
363
|
+
"""
|
|
364
|
+
# Check data is available for init-time t0
|
|
365
|
+
assert t0 in self.valid_t0_times
|
|
366
|
+
return self._get_sample(t0)
|
|
@@ -50,7 +50,7 @@ ocf_data_sampler/select/select_time_slice.py
|
|
|
50
50
|
ocf_data_sampler/select/spatial_slice_for_dataset.py
|
|
51
51
|
ocf_data_sampler/select/time_slice_for_dataset.py
|
|
52
52
|
ocf_data_sampler/torch_datasets/datasets/__init__.py
|
|
53
|
-
ocf_data_sampler/torch_datasets/datasets/
|
|
53
|
+
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
|
|
54
54
|
ocf_data_sampler/torch_datasets/datasets/site.py
|
|
55
55
|
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py
|
|
56
56
|
ocf_data_sampler/torch_datasets/utils/valid_time_periods.py
|
|
@@ -78,7 +78,6 @@ tests/select/test_select_time_slice.py
|
|
|
78
78
|
tests/test_sample/test_base.py
|
|
79
79
|
tests/test_sample/test_site_sample.py
|
|
80
80
|
tests/test_sample/test_uk_regional_sample.py
|
|
81
|
-
tests/torch_datasets/conftest.py
|
|
82
81
|
tests/torch_datasets/test_merge_and_fill_utils.py
|
|
83
|
-
tests/torch_datasets/
|
|
82
|
+
tests/torch_datasets/test_pvnet_uk.py
|
|
84
83
|
tests/torch_datasets/test_site.py
|
|
@@ -1,14 +1,15 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
1
3
|
import os
|
|
2
4
|
import numpy as np
|
|
3
5
|
import pandas as pd
|
|
4
|
-
import pytest
|
|
5
6
|
import xarray as xr
|
|
6
|
-
import
|
|
7
|
-
from typing import Generator
|
|
7
|
+
import dask.array
|
|
8
8
|
|
|
9
9
|
from ocf_data_sampler.config.model import Site
|
|
10
10
|
from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
|
|
11
11
|
|
|
12
|
+
|
|
12
13
|
_top_test_directory = os.path.dirname(os.path.realpath(__file__))
|
|
13
14
|
|
|
14
15
|
@pytest.fixture()
|
|
@@ -18,40 +19,27 @@ def test_config_filename():
|
|
|
18
19
|
|
|
19
20
|
@pytest.fixture(scope="session")
|
|
20
21
|
def config_filename():
|
|
21
|
-
return f"{
|
|
22
|
+
return f"{_top_test_directory}/test_data/configs/pvnet_test_config.yaml"
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
@pytest.fixture(scope="session")
|
|
25
|
-
def
|
|
26
|
-
|
|
27
|
-
# Load dataset which only contains coordinates, but no data
|
|
28
|
-
ds = xr.open_zarr(
|
|
29
|
-
f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.zarr.zip"
|
|
30
|
-
).compute()
|
|
31
|
-
|
|
32
|
-
# Add time coord
|
|
33
|
-
ds = ds.assign_coords(time=pd.date_range("2023-01-01 00:00", "2023-01-02 23:55", freq="5min"))
|
|
34
|
-
|
|
35
|
-
# Add data to dataset
|
|
36
|
-
ds["data"] = xr.DataArray(
|
|
37
|
-
np.zeros([len(ds[c]) for c in ds.coords], dtype=np.float32),
|
|
38
|
-
coords=ds.coords,
|
|
39
|
-
)
|
|
40
|
-
|
|
41
|
-
# Transpose to variables, time, y, x (just in case)
|
|
42
|
-
ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary")
|
|
26
|
+
def session_tmp_path(tmp_path_factory):
|
|
27
|
+
return tmp_path_factory.mktemp("data")
|
|
43
28
|
|
|
44
|
-
# add 100,000 to x_geostationary, this to make sure the fix index is within the satellite image
|
|
45
|
-
ds["x_geostationary"] = ds["x_geostationary"] - 200_000
|
|
46
29
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
#
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
30
|
+
@pytest.fixture(scope="session")
|
|
31
|
+
def sat_zarr_path(session_tmp_path):
|
|
32
|
+
|
|
33
|
+
# Define coords for satellite-like dataset
|
|
34
|
+
variables = [
|
|
35
|
+
'IR_016', 'IR_039', 'IR_087', 'IR_097', 'IR_108', 'IR_120',
|
|
36
|
+
'IR_134', 'VIS006', 'VIS008', 'WV_062', 'WV_073',
|
|
37
|
+
]
|
|
38
|
+
x = np.linspace(start=15002, stop=-1824245, num=100)
|
|
39
|
+
y = np.linspace(start=4191563, stop=5304712, num=100)
|
|
40
|
+
times = pd.date_range("2023-01-01 00:00", "2023-01-01 23:55", freq="5min")
|
|
41
|
+
|
|
42
|
+
area_string = (
|
|
55
43
|
"""msg_seviri_rss_3km:
|
|
56
44
|
description: MSG SEVIRI Rapid Scanning Service area definition with 3 km resolution
|
|
57
45
|
projection:
|
|
@@ -73,16 +61,31 @@ def sat_zarr_path():
|
|
|
73
61
|
units: m
|
|
74
62
|
"""
|
|
75
63
|
)
|
|
76
|
-
|
|
77
|
-
#
|
|
78
|
-
|
|
64
|
+
|
|
65
|
+
# Create satellite-like data with some NaNs
|
|
66
|
+
data = dask.array.zeros(
|
|
67
|
+
shape=(len(variables), len(times), len(y), len(x)),
|
|
68
|
+
chunks=(-1, 10, -1, -1),
|
|
69
|
+
dtype=np.float32
|
|
70
|
+
)
|
|
71
|
+
data [:, 10, :, :] = np.nan
|
|
72
|
+
|
|
73
|
+
ds = xr.DataArray(
|
|
74
|
+
data=data,
|
|
75
|
+
coords=dict(
|
|
76
|
+
variable=variables,
|
|
77
|
+
time=times,
|
|
78
|
+
y_geostationary=y,
|
|
79
|
+
x_geostationary=x,
|
|
80
|
+
),
|
|
81
|
+
attrs=dict(area=area_string),
|
|
82
|
+
).to_dataset(name="data")
|
|
79
83
|
|
|
80
84
|
# Save temporarily as a zarr
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
ds.to_zarr(zarr_path)
|
|
85
|
+
zarr_path = session_tmp_path / "test_sat.zarr"
|
|
86
|
+
ds.to_zarr(zarr_path)
|
|
84
87
|
|
|
85
|
-
|
|
88
|
+
yield zarr_path
|
|
86
89
|
|
|
87
90
|
|
|
88
91
|
@pytest.fixture(scope="session")
|
|
@@ -112,7 +115,7 @@ def ds_nwp_ukv():
|
|
|
112
115
|
|
|
113
116
|
|
|
114
117
|
@pytest.fixture(scope="session")
|
|
115
|
-
def nwp_ukv_zarr_path(ds_nwp_ukv):
|
|
118
|
+
def nwp_ukv_zarr_path(session_tmp_path, ds_nwp_ukv):
|
|
116
119
|
ds = ds_nwp_ukv.chunk(
|
|
117
120
|
{
|
|
118
121
|
"init_time": 1,
|
|
@@ -122,10 +125,9 @@ def nwp_ukv_zarr_path(ds_nwp_ukv):
|
|
|
122
125
|
"y": 50,
|
|
123
126
|
}
|
|
124
127
|
)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
yield filename
|
|
128
|
+
zarr_path = session_tmp_path / "ukv_nwp.zarr"
|
|
129
|
+
ds.to_zarr(zarr_path)
|
|
130
|
+
yield zarr_path
|
|
129
131
|
|
|
130
132
|
|
|
131
133
|
@pytest.fixture(scope="session")
|
|
@@ -155,7 +157,7 @@ def ds_nwp_ecmwf():
|
|
|
155
157
|
|
|
156
158
|
|
|
157
159
|
@pytest.fixture(scope="session")
|
|
158
|
-
def nwp_ecmwf_zarr_path(ds_nwp_ecmwf):
|
|
160
|
+
def nwp_ecmwf_zarr_path(session_tmp_path, ds_nwp_ecmwf):
|
|
159
161
|
ds = ds_nwp_ecmwf.chunk(
|
|
160
162
|
{
|
|
161
163
|
"init_time": 1,
|
|
@@ -165,10 +167,10 @@ def nwp_ecmwf_zarr_path(ds_nwp_ecmwf):
|
|
|
165
167
|
"latitude": 50,
|
|
166
168
|
}
|
|
167
169
|
)
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
170
|
+
|
|
171
|
+
zarr_path = session_tmp_path / "ukv_ecmwf.zarr"
|
|
172
|
+
ds.to_zarr(zarr_path)
|
|
173
|
+
yield zarr_path
|
|
172
174
|
|
|
173
175
|
|
|
174
176
|
@pytest.fixture(scope="session")
|
|
@@ -201,7 +203,7 @@ def ds_uk_gsp():
|
|
|
201
203
|
|
|
202
204
|
|
|
203
205
|
@pytest.fixture(scope="session")
|
|
204
|
-
def data_sites() ->
|
|
206
|
+
def data_sites(session_tmp_path) -> Site:
|
|
205
207
|
"""
|
|
206
208
|
Make fake data for sites
|
|
207
209
|
Returns: filename for netcdf file, and csv metadata
|
|
@@ -245,30 +247,27 @@ def data_sites() -> Generator[Site, None, None]:
|
|
|
245
247
|
"generation_kw": da_gen,
|
|
246
248
|
})
|
|
247
249
|
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
)
|
|
250
|
+
filename = f"{session_tmp_path}/sites.netcdf"
|
|
251
|
+
filename_csv = f"{session_tmp_path}/sites_metadata.csv"
|
|
252
|
+
generation.to_netcdf(filename)
|
|
253
|
+
meta_df.to_csv(filename_csv)
|
|
254
|
+
|
|
255
|
+
site = Site(
|
|
256
|
+
file_path=filename,
|
|
257
|
+
metadata_file_path=filename_csv,
|
|
258
|
+
interval_start_minutes=-30,
|
|
259
|
+
interval_end_minutes=60,
|
|
260
|
+
time_resolution_minutes=30,
|
|
261
|
+
)
|
|
261
262
|
|
|
262
|
-
|
|
263
|
+
yield site
|
|
263
264
|
|
|
264
265
|
|
|
265
266
|
@pytest.fixture(scope="session")
|
|
266
|
-
def uk_gsp_zarr_path(ds_uk_gsp):
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
ds_uk_gsp.to_zarr(filename)
|
|
271
|
-
yield filename
|
|
267
|
+
def uk_gsp_zarr_path(session_tmp_path, ds_uk_gsp):
|
|
268
|
+
zarr_path = session_tmp_path / "uk_gsp.zarr"
|
|
269
|
+
ds_uk_gsp.to_zarr(zarr_path)
|
|
270
|
+
yield zarr_path
|
|
272
271
|
|
|
273
272
|
|
|
274
273
|
@pytest.fixture()
|
|
@@ -8,10 +8,10 @@ def test_open_satellite(sat_zarr_path):
|
|
|
8
8
|
|
|
9
9
|
assert isinstance(da, xr.DataArray)
|
|
10
10
|
assert da.dims == ("time_utc", "channel", "x_geostationary", "y_geostationary")
|
|
11
|
-
#
|
|
11
|
+
# 288 is 1 day of data at 5-minute intervals, 12 * 24
|
|
12
12
|
# There are 11 channels
|
|
13
|
-
# There are
|
|
14
|
-
assert da.shape == (
|
|
13
|
+
# There are 100 x 100 pixels
|
|
14
|
+
assert da.shape == (288, 11, 100, 100)
|
|
15
15
|
assert np.issubdtype(da.dtype, np.number)
|
|
16
16
|
|
|
17
17
|
|
|
@@ -1,17 +1,12 @@
|
|
|
1
|
-
from ocf_data_sampler.numpy_sample import GSPSampleKey, SatelliteSampleKey
|
|
2
1
|
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
|
|
3
|
-
from ocf_data_sampler.torch_datasets.datasets.
|
|
2
|
+
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
|
|
4
3
|
|
|
5
4
|
|
|
6
|
-
def
|
|
5
|
+
def test_stack_np_samples_into_batch(pvnet_config_filename):
|
|
7
6
|
|
|
8
7
|
# Create dataset object
|
|
9
8
|
dataset = PVNetUKRegionalDataset(pvnet_config_filename)
|
|
10
9
|
|
|
11
|
-
assert len(dataset.locations) == 317
|
|
12
|
-
assert len(dataset.valid_t0_times) == 39
|
|
13
|
-
assert len(dataset) == 317 * 39
|
|
14
|
-
|
|
15
10
|
# Generate 2 samples
|
|
16
11
|
sample1 = dataset[0]
|
|
17
12
|
sample2 = dataset[1]
|
|
@@ -22,5 +17,5 @@ def test_pvnet(pvnet_config_filename):
|
|
|
22
17
|
assert "nwp" in batch
|
|
23
18
|
assert isinstance(batch["nwp"], dict)
|
|
24
19
|
assert "ukv" in batch["nwp"]
|
|
25
|
-
assert
|
|
26
|
-
assert
|
|
20
|
+
assert "gsp" in batch
|
|
21
|
+
assert "satellite_actual" in batch
|
{ocf_data_sampler-0.1.1 → ocf_data_sampler-0.1.3}/tests/torch_datasets/test_merge_and_fill_utils.py
RENAMED
|
@@ -33,9 +33,7 @@ def test_fill_nans_in_arrays():
|
|
|
33
33
|
|
|
34
34
|
result = fill_nans_in_arrays(nested_dict)
|
|
35
35
|
|
|
36
|
-
assert not np.isnan(result["array1"]).any()
|
|
37
36
|
assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
|
|
38
|
-
assert not np.isnan(result["nested"]["array2"]).any()
|
|
39
37
|
assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
|
|
40
38
|
assert result["string_key"] == "not_an_array"
|
|
41
39
|
|