ocf-data-sampler 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff shows the contents of publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of ocf-data-sampler might be problematic.
- ocf_data_sampler/config/model.py +85 -122
- ocf_data_sampler/load/load_dataset.py +6 -6
- ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
- ocf_data_sampler/select/select_time_slice.py +24 -33
- ocf_data_sampler/select/spatial_slice_for_dataset.py +4 -4
- ocf_data_sampler/select/time_slice_for_dataset.py +18 -17
- ocf_data_sampler/torch_datasets/process_and_combine.py +13 -14
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +1 -1
- ocf_data_sampler/torch_datasets/site.py +10 -10
- ocf_data_sampler/torch_datasets/valid_time_periods.py +20 -12
- ocf_data_sampler/{time_functions.py → utils.py} +1 -2
- {ocf_data_sampler-0.0.25.dist-info → ocf_data_sampler-0.0.27.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.0.25.dist-info → ocf_data_sampler-0.0.27.dist-info}/RECORD +22 -22
- {ocf_data_sampler-0.0.25.dist-info → ocf_data_sampler-0.0.27.dist-info}/WHEEL +1 -1
- tests/config/test_config.py +23 -14
- tests/conftest.py +7 -5
- tests/select/test_find_contiguous_time_periods.py +8 -8
- tests/select/test_select_time_slice.py +31 -43
- tests/torch_datasets/test_pvnet_uk_regional.py +4 -4
- tests/torch_datasets/test_site.py +2 -2
- {ocf_data_sampler-0.0.25.dist-info → ocf_data_sampler-0.0.27.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.0.25.dist-info → ocf_data_sampler-0.0.27.dist-info}/top_level.txt +0 -0
ocf_data_sampler/select/spatial_slice_for_dataset.py CHANGED

@@ -30,8 +30,8 @@ def slice_datasets_by_space(
             sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
                 datasets_dict["nwp"][nwp_key],
                 location,
-                height_pixels=nwp_config.
-                width_pixels=nwp_config.
+                height_pixels=nwp_config.image_size_pixels_height,
+                width_pixels=nwp_config.image_size_pixels_width,
             )

     if "sat" in datasets_dict:
@@ -40,8 +40,8 @@ def slice_datasets_by_space(
         sliced_datasets_dict["sat"] = select_spatial_slice_pixels(
             datasets_dict["sat"],
             location,
-            height_pixels=sat_config.
-            width_pixels=sat_config.
+            height_pixels=sat_config.image_size_pixels_height,
+            width_pixels=sat_config.image_size_pixels_width,
         )

     if "gsp" in datasets_dict:
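Most of the changes in this release are field renames on the input-data configs. As a quick reference, the NWP config fields that appear throughout this diff are collected below; the field names come from the + lines in the hunks, while every value is invented purely for illustration.

# Quick reference for the 0.0.27 NWP config fields visible in this diff.
# Field names are taken from the hunks; the values here are made up.
from types import SimpleNamespace

nwp_config = SimpleNamespace(
    provider="ukv",                    # selects NWP_MEANS / NWP_STDS for standardisation
    image_size_pixels_height=24,       # height of the spatial crop, in pixels
    image_size_pixels_width=24,        # width of the spatial crop, in pixels
    time_resolution_minutes=60,        # spacing of the time slice
    interval_start_minutes=-120,       # window start relative to t0 (history is negative)
    interval_end_minutes=480,          # window end relative to t0 (forecast horizon)
    max_staleness_minutes=720,         # how old an NWP init time is allowed to be
    dropout_timedeltas_minutes=[-30],  # candidate dropout offsets
    dropout_fraction=0.1,              # probability of applying dropout
    accum_channels=["dswrf"],          # channels that must be de-accumulated (diffed)
)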
ocf_data_sampler/select/time_slice_for_dataset.py CHANGED

@@ -4,7 +4,7 @@ import pandas as pd
 from ocf_data_sampler.config import Configuration
 from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
 from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice
-from ocf_data_sampler.
+from ocf_data_sampler.utils import minutes


 def slice_datasets_by_time(
@@ -23,22 +23,22 @@ def slice_datasets_by_time(
     sliced_datasets_dict = {}

     if "nwp" in datasets_dict:
-
+
         sliced_datasets_dict["nwp"] = {}
-
+
         for nwp_key, da_nwp in datasets_dict["nwp"].items():
-
+
             nwp_config = config.input_data.nwp[nwp_key]

             sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
                 da_nwp,
                 t0,
                 sample_period_duration=minutes(nwp_config.time_resolution_minutes),
-
-
+                interval_start=minutes(nwp_config.interval_start_minutes),
+                interval_end=minutes(nwp_config.interval_end_minutes),
                 dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
                 dropout_frac=nwp_config.dropout_fraction,
-                accum_channels=nwp_config.
+                accum_channels=nwp_config.accum_channels,
             )

     if "sat" in datasets_dict:
@@ -49,8 +49,8 @@ def slice_datasets_by_time(
             datasets_dict["sat"],
             t0,
             sample_period_duration=minutes(sat_config.time_resolution_minutes),
-            interval_start=minutes(
-            interval_end=minutes(
+            interval_start=minutes(sat_config.interval_start_minutes),
+            interval_end=minutes(sat_config.interval_end_minutes),
             max_steps_gap=2,
         )

@@ -74,15 +74,15 @@ def slice_datasets_by_time(
             datasets_dict["gsp"],
             t0,
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start=minutes(
-            interval_end=minutes(gsp_config.
+            interval_start=minutes(gsp_config.time_resolution_minutes),
+            interval_end=minutes(gsp_config.interval_end_minutes),
         )
-
+
         sliced_datasets_dict["gsp"] = select_time_slice(
             datasets_dict["gsp"],
             t0,
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start
+            interval_start=minutes(gsp_config.interval_start_minutes),
             interval_end=minutes(0),
         )

@@ -94,9 +94,10 @@ def slice_datasets_by_time(
         )

         sliced_datasets_dict["gsp"] = apply_dropout_time(
-            sliced_datasets_dict["gsp"],
+            sliced_datasets_dict["gsp"],
+            gsp_dropout_time
         )
-
+
     if "site" in datasets_dict:
         site_config = config.input_data.site

@@ -104,8 +105,8 @@ def slice_datasets_by_time(
             datasets_dict["site"],
             t0,
             sample_period_duration=minutes(site_config.time_resolution_minutes),
-            interval_start
-            interval_end=minutes(site_config.
+            interval_start=minutes(site_config.interval_start_minutes),
+            interval_end=minutes(site_config.interval_end_minutes),
         )

         # Randomly sample dropout
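Every time-slicing call above now pulls its minutes helper from ocf_data_sampler.utils, the renamed time_functions module. The helper itself is not shown in this diff; the sketch below is only an assumption about its behaviour, inferred from the fact that it is applied both to single integers (time_resolution_minutes) and to lists (dropout_timedeltas_minutes).

# Assumed behaviour only -- the real ocf_data_sampler/utils.py is not shown in this diff.
import pandas as pd

def minutes(value):
    """Convert minutes (an int, or a list of ints) into pandas Timedelta(s)."""
    return pd.to_timedelta(value, unit="min")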
ocf_data_sampler/torch_datasets/process_and_combine.py CHANGED

@@ -15,7 +15,7 @@ from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
 from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
 from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 from ocf_data_sampler.select.location import Location
-from ocf_data_sampler.
+from ocf_data_sampler.utils import minutes


 def process_and_combine_datasets(
@@ -23,7 +23,7 @@ def process_and_combine_datasets(
     config: Configuration,
     t0: pd.Timestamp,
     location: Location,
-
+    target_key: str = 'gsp'
 ) -> dict:
     """Normalize and convert data to numpy arrays"""

@@ -35,7 +35,7 @@ def process_and_combine_datasets(

     for nwp_key, da_nwp in dataset_dict["nwp"].items():
         # Standardise
-        provider = config.input_data.nwp[nwp_key].
+        provider = config.input_data.nwp[nwp_key].provider
         da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
         # Convert to NumpyBatch
         nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
@@ -58,7 +58,8 @@ def process_and_combine_datasets(

         numpy_modalities.append(
             convert_gsp_to_numpy_batch(
-                da_gsp,
+                da_gsp,
+                t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
             )
         )

@@ -80,34 +81,32 @@ def process_and_combine_datasets(

         numpy_modalities.append(
             convert_site_to_numpy_batch(
-                da_sites, t0_idx
+                da_sites, t0_idx=-site_config.interval_start_minutes / site_config.time_resolution_minutes
             )
         )

-    if
+    if target_key == 'gsp':
         # Make sun coords NumpyBatch
         datetimes = pd.date_range(
-            t0
-            t0
+            t0+minutes(gsp_config.interval_start_minutes),
+            t0+minutes(gsp_config.interval_end_minutes),
             freq=minutes(gsp_config.time_resolution_minutes),
         )

         lon, lat = osgb_to_lon_lat(location.x, location.y)
-        key_prefix = "gsp"

-    elif
+    elif target_key == 'site':
         # Make sun coords NumpyBatch
         datetimes = pd.date_range(
-            t0
-            t0
+            t0+minutes(site_config.interval_start_minutes),
+            t0+minutes(site_config.interval_end_minutes),
             freq=minutes(site_config.time_resolution_minutes),
         )

         lon, lat = location.x, location.y
-        key_prefix = "site"

     numpy_modalities.append(
-        make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=
+        make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=target_key)
     )

     # Combine all the modalities and fill NaNs
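The t0_idx handed to the numpy-batch converters is now derived from the interval settings. A worked example of that arithmetic, with assumed values of a 120-minute history at 30-minute resolution:

# Worked example of the t0_idx expression used above (values are assumed).
interval_start_minutes = -120      # two hours of history before t0
time_resolution_minutes = 30

t0_idx = -interval_start_minutes / time_resolution_minutes
print(t0_idx)  # 4.0 -> the slice holds steps at -120, -90, -60, -30, 0, so t0 sits at index 4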
ocf_data_sampler/torch_datasets/pvnet_uk_regional.py CHANGED

@@ -9,7 +9,7 @@ from torch.utils.data import Dataset
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
 from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time
-from ocf_data_sampler.
+from ocf_data_sampler.utils import minutes
 from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
 from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods

ocf_data_sampler/torch_datasets/site.py CHANGED

@@ -14,7 +14,7 @@ from ocf_data_sampler.select import (
     intersection_of_multiple_dataframes_of_periods,
     slice_datasets_by_time, slice_datasets_by_space
 )
-from ocf_data_sampler.
+from ocf_data_sampler.utils import minutes
 from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
 from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods

@@ -22,8 +22,8 @@ xr.set_options(keep_attrs=True)


 def find_valid_t0_and_site_ids(
-
-
+    datasets_dict: dict,
+    config: Configuration,
 ) -> pd.DataFrame:
     """Find the t0 times where all of the requested input data is available

@@ -57,8 +57,8 @@ def find_valid_t0_and_site_ids(
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(site["time_utc"]),
             sample_period_duration=minutes(site_config.time_resolution_minutes),
-
-
+            interval_start=minutes(site_config.interval_start_minutes),
+            interval_end=minutes(site_config.interval_end_minutes),
         )
         valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods(
             [valid_time_periods, time_periods]
@@ -100,10 +100,10 @@ def get_locations(site_xr: xr.Dataset):

 class SitesDataset(Dataset):
     def __init__(
-
-
-
-
+        self,
+        config_filename: str,
+        start_time: str | None = None,
+        end_time: str | None = None,
     ):
         """A torch Dataset for creating PVNet Site samples

@@ -154,7 +154,7 @@ class SitesDataset(Dataset):
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
         sample_dict = compute(sample_dict)

-        sample = process_and_combine_datasets(sample_dict, self.config, t0, location,
+        sample = process_and_combine_datasets(sample_dict, self.config, t0, location, target_key='site')

         return sample

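Putting the site-side changes together, here is a hedged usage sketch of the new SitesDataset signature shown above. The import path mirrors the module layout listed in the RECORD further down; the config path and dates are placeholders.

# Usage sketch only -- file path and dates are placeholders.
from ocf_data_sampler.torch_datasets.site import SitesDataset

dataset = SitesDataset(
    "site_config.yaml",        # config_filename
    start_time="2023-01-01",   # optional time-range filter (assumed)
    end_time="2023-12-31",     # optional
)

# Standard torch Dataset indexing; internally this slices by space and time,
# then calls process_and_combine_datasets(..., target_key='site').
sample = dataset[0]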
ocf_data_sampler/torch_datasets/valid_time_periods.py CHANGED

@@ -2,9 +2,13 @@ import numpy as np
 import pandas as pd

 from ocf_data_sampler.config import Configuration
-from ocf_data_sampler.select.find_contiguous_time_periods import
-
-
+from ocf_data_sampler.select.find_contiguous_time_periods import (
+    find_contiguous_t0_periods_nwp,
+    find_contiguous_t0_periods,
+    intersection_of_multiple_dataframes_of_periods,
+)
+from ocf_data_sampler.utils import minutes
+


 def find_valid_time_periods(
@@ -38,7 +42,7 @@ def find_valid_time_periods(
             max_staleness = minutes(nwp_config.max_staleness_minutes)

             # The last step of the forecast is lost if we have to diff channels
-            if len(nwp_config.
+            if len(nwp_config.accum_channels) > 0:
                 end_buffer = minutes(nwp_config.time_resolution_minutes)
             else:
                 end_buffer = minutes(0)
@@ -46,7 +50,7 @@ def find_valid_time_periods(
             # This is the max staleness we can use considering the max step of the input data
             max_possible_staleness = (
                 pd.Timedelta(da["step"].max().item())
-                - minutes(nwp_config.
+                - minutes(nwp_config.interval_end_minutes)
                 - end_buffer
             )

@@ -56,12 +60,16 @@ def find_valid_time_periods(
             else:
                 # Make sure the max acceptable staleness isn't longer than the max possible
                 assert max_staleness <= max_possible_staleness
+
+            # Find the first forecast step
+            first_forecast_step = pd.Timedelta(da["step"].min().item())

             time_periods = find_contiguous_t0_periods_nwp(
-
-
+                init_times=pd.DatetimeIndex(da["init_time_utc"]),
+                interval_start=minutes(nwp_config.interval_start_minutes),
                 max_staleness=max_staleness,
                 max_dropout=max_dropout,
+                first_forecast_step = first_forecast_step,
             )

             contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods
@@ -72,8 +80,8 @@ def find_valid_time_periods(
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
             sample_period_duration=minutes(sat_config.time_resolution_minutes),
-
-
+            interval_start=minutes(sat_config.interval_start_minutes),
+            interval_end=minutes(sat_config.interval_end_minutes),
         )

         contiguous_time_periods['sat'] = time_periods
@@ -84,8 +92,8 @@ def find_valid_time_periods(
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-
-
+            interval_start=minutes(gsp_config.interval_start_minutes),
+            interval_end=minutes(gsp_config.interval_end_minutes),
         )

         contiguous_time_periods['gsp'] = time_periods
@@ -105,4 +113,4 @@ def find_valid_time_periods(
     if len(valid_time_periods) == 0:
         raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")

-    return valid_time_periods
+    return valid_time_periods
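To make the staleness bookkeeping above concrete, here is the max_possible_staleness calculation worked through with assumed numbers: a 48-hour maximum forecast step, a 180-minute interval_end, and hourly NWP data with accumulated channels (so one step is lost to differencing).

import pandas as pd

# Assumed inputs, for illustration only
max_step = pd.Timedelta("48h")          # pd.Timedelta(da["step"].max().item())
interval_end = pd.Timedelta("180min")   # minutes(nwp_config.interval_end_minutes)
end_buffer = pd.Timedelta("60min")      # one step lost when diffing accum_channels

max_possible_staleness = max_step - interval_end - end_buffer
print(max_possible_staleness)           # 1 days 20:00:00, i.e. 44 hours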
{ocf_data_sampler-0.0.25.dist-info → ocf_data_sampler-0.0.27.dist-info}/RECORD CHANGED

@@ -1,14 +1,14 @@
 ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 ocf_data_sampler/constants.py,sha256=tUwHrsGShqIn5Izze4i32_xB6X0v67rvQwIYB-P5PJQ,3355
-ocf_data_sampler/
+ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
 ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
 ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
-ocf_data_sampler/config/model.py,sha256=
+ocf_data_sampler/config/model.py,sha256=sXmh7IadwXDT-7lxEl5_b3vjovZgZYR77EXy4GHaf4w,7276
 ocf_data_sampler/config/save.py,sha256=wKdctbv0dxIIiQtcRHLRxpWQVhEFQ_FCWg-oNaRLIps,1093
 ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
 ocf_data_sampler/load/__init__.py,sha256=MjgfxilTzyz1RYFoBEeAXmE9hyjknLvdmlHPmlAoiQY,44
 ocf_data_sampler/load/gsp.py,sha256=Gcr1JVUOPKhFRDCSHtfPDjxx0BtyyEhXrZvGEKLPJ5I,759
-ocf_data_sampler/load/load_dataset.py,sha256=
+ocf_data_sampler/load/load_dataset.py,sha256=Ua3RaUg4PIYJkD9BKqTfN8IWUbezbhThJGgEkd9PcaE,1587
 ocf_data_sampler/load/satellite.py,sha256=3KlA1fx4SwxdzM-jC1WRaONXO0D6m0WxORnEnwUnZrA,2967
 ocf_data_sampler/load/site.py,sha256=ROif2XXIIgBz-JOOiHymTq1CMXswJ3AzENU9DJmYpcU,782
 ocf_data_sampler/load/utils.py,sha256=EQGvVWlGMoSOdbDYuMfVAa0v6wmAOPmHIAemdrTB5v4,1406
@@ -27,22 +27,22 @@ ocf_data_sampler/numpy_batch/sun_position.py,sha256=zw2bjtcjsm_tvKk0r_MZmgfYUJLH
 ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
 ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
 ocf_data_sampler/select/fill_time_periods.py,sha256=iTtMjIPFYG5xtUYYedAFBLjTWWUa7t7WQ0-yksWf0-E,440
-ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=
+ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=q7IaNfX95A3z9XHqbhgtkZ4Js1gn5K9Qyp6DVLbsL-Q,11093
 ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
 ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
 ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
-ocf_data_sampler/select/select_time_slice.py,sha256=
-ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=
-ocf_data_sampler/select/time_slice_for_dataset.py,sha256=
+ocf_data_sampler/select/select_time_slice.py,sha256=D5P_cSvnv8Qs49K5au7lPxDr9U_VmDn42s5leMzHt0k,6122
+ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
+ocf_data_sampler/select/time_slice_for_dataset.py,sha256=LMw8KnOCKnPjD0m4UubAWERpaiQtzRKkI2cSh5a0A-M,4335
 ocf_data_sampler/torch_datasets/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=
-ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=
-ocf_data_sampler/torch_datasets/site.py,sha256=
-ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=
+ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=4k6f6PlMqrg3luMwGw3764iOyfuUNUePKyoikYGaRMI,4953
+ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=QRFqbdfNchVWj4y70n-rJdFvFGvQj-WpZLdFqWjnOTw,5543
+ocf_data_sampler/torch_datasets/site.py,sha256=lo2ULurfWNu9vzBC6H4pdKMMpUMIT8_FWC1l_1mgIOM,6596
+ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
 scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tests/conftest.py,sha256=
-tests/config/test_config.py,sha256=
+tests/conftest.py,sha256=N-_XgXpWeTRhkwP_NVh2mBORt2LKkM4mbkm-O62RN5I,7363
+tests/config/test_config.py,sha256=eaye_F7-el4tTP4n2vRME8qlV0b2jaKUX4HhgOUpa7E,5203
 tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
 tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
 tests/load/test_load_satellite.py,sha256=STX5AqqmOAgUgE9R1xyq_sM3P1b8NKdGjO-hDhayfxM,524
@@ -53,14 +53,14 @@ tests/numpy_batch/test_satellite.py,sha256=8a4ZwMLpsOmYKmwI1oW_su_hwkCNYMEJAEfa0
 tests/numpy_batch/test_sun_position.py,sha256=FYQ7KtlN0V5LlEjgI-cKjTMtGHUCxiMvxkRYTdMAgEE,2485
 tests/select/test_dropout.py,sha256=kiycl7RxAQYMCZJlokmx6Da5h_oBpSs8Is8pmSW4gOU,2413
 tests/select/test_fill_time_periods.py,sha256=o59f2YRe5b0vJrG3B0aYZkYeHnpNk4s6EJxdXZluNQg,907
-tests/select/test_find_contiguous_time_periods.py,sha256=
+tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM3agOhsvZYx8inXtUn1PM,5976
 tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
 tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
-tests/select/test_select_time_slice.py,sha256=
-tests/torch_datasets/test_pvnet_uk_regional.py,sha256=
-tests/torch_datasets/test_site.py,sha256=
-ocf_data_sampler-0.0.
-ocf_data_sampler-0.0.
-ocf_data_sampler-0.0.
-ocf_data_sampler-0.0.
-ocf_data_sampler-0.0.
+tests/select/test_select_time_slice.py,sha256=QOhoR3qsr7RBGze4yohcViZ-ad1zYQzIKzxlnf0ymnU,9603
+tests/torch_datasets/test_pvnet_uk_regional.py,sha256=8gxjJO8FhY-ImX6eGnihDFsa8fhU2Zb4bVJaToJwuwo,2653
+tests/torch_datasets/test_site.py,sha256=yTv6tAT6lha5yLYJiC8DNms1dct8o_ObPV97dHZyT7I,2719
+ocf_data_sampler-0.0.27.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ocf_data_sampler-0.0.27.dist-info/METADATA,sha256=bMOcVYluH-m7tyVm2J0Vz2T3ZLqNtEoX0HUwUvZMfEw,5269
+ocf_data_sampler-0.0.27.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
+ocf_data_sampler-0.0.27.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
+ocf_data_sampler-0.0.27.dist-info/RECORD,,
tests/config/test_config.py CHANGED

@@ -10,13 +10,13 @@ from ocf_data_sampler.config import (
 )


-def
+def test_default_configuration():
     """Test default pydantic class"""

     _ = Configuration()


-def
+def test_load_yaml_configuration(test_config_filename):
     """
     Test that yaml loading works for 'test_config.yaml'
     and fails for an empty .yaml file
@@ -56,7 +56,7 @@ def test_yaml_save(test_config_filename):
     assert test_config == tmp_config


-def
+def test_extra_field_error():
     """
     Check an extra parameters in config causes error
     """
@@ -68,27 +68,33 @@ def test_extra_field():
     _ = Configuration(**configuration_dict)


-def
+def test_incorrect_interval_start_minutes(test_config_filename):
     """
-    Check a
+    Check a history length not divisible by time resolution causes error
     """

     configuration = load_yaml_configuration(test_config_filename)

-    configuration.input_data.nwp['ukv'].
-    with pytest.raises(
+    configuration.input_data.nwp['ukv'].interval_start_minutes = -1111
+    with pytest.raises(
+        ValueError,
+        match="interval_start_minutes must be divisible by time_resolution_minutes"
+    ):
         _ = Configuration(**configuration.model_dump())


-def
+def test_incorrect_interval_end_minutes(test_config_filename):
     """
-    Check a
+    Check a forecast length not divisible by time resolution causes error
     """

     configuration = load_yaml_configuration(test_config_filename)

-    configuration.input_data.nwp['ukv'].
-    with pytest.raises(
+    configuration.input_data.nwp['ukv'].interval_end_minutes = 1111
+    with pytest.raises(
+        ValueError,
+        match="interval_end_minutes must be divisible by time_resolution_minutes"
+    ):
         _ = Configuration(**configuration.model_dump())


@@ -99,10 +105,11 @@ def test_incorrect_nwp_provider(test_config_filename):

     configuration = load_yaml_configuration(test_config_filename)

-    configuration.input_data.nwp['ukv'].
+    configuration.input_data.nwp['ukv'].provider = "unexpected_provider"
     with pytest.raises(Exception, match="NWP provider"):
         _ = Configuration(**configuration.model_dump())

+
 def test_incorrect_dropout(test_config_filename):
     """
     Check a dropout timedelta over 0 causes error and 0 doesn't
@@ -119,6 +126,7 @@ def test_incorrect_dropout(test_config_filename):
     configuration.input_data.nwp['ukv'].dropout_timedeltas_minutes = [0]
     _ = Configuration(**configuration.model_dump())

+
 def test_incorrect_dropout_fraction(test_config_filename):
     """
     Check dropout fraction outside of range causes error
@@ -127,11 +135,12 @@ def test_incorrect_dropout_fraction(test_config_filename):
     configuration = load_yaml_configuration(test_config_filename)

     configuration.input_data.nwp['ukv'].dropout_fraction= 1.1
-
+
+    with pytest.raises(ValidationError, match="Input should be less than or equal to 1"):
         _ = Configuration(**configuration.model_dump())

     configuration.input_data.nwp['ukv'].dropout_fraction= -0.1
-    with pytest.raises(
+    with pytest.raises(ValidationError, match="Input should be greater than or equal to 0"):
         _ = Configuration(**configuration.model_dump())

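The two new tests above expect ValueError messages of the form "interval_start_minutes must be divisible by time_resolution_minutes". The validator that raises them lives in ocf_data_sampler/config/model.py, which is not included in this diff; the snippet below is only a sketch of the kind of pydantic check those tests imply, with a made-up class name.

# Sketch only -- the real config/model.py is not shown in this diff.
from pydantic import BaseModel, model_validator

class TimeWindowMixin(BaseModel):  # hypothetical name
    time_resolution_minutes: int
    interval_start_minutes: int
    interval_end_minutes: int

    @model_validator(mode="after")
    def _intervals_divisible_by_resolution(self):
        # Both window bounds must land on the data's time grid.
        for field in ("interval_start_minutes", "interval_end_minutes"):
            if getattr(self, field) % self.time_resolution_minutes != 0:
                raise ValueError(
                    f"{field} must be divisible by time_resolution_minutes"
                )
        return self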
tests/conftest.py CHANGED

@@ -250,11 +250,13 @@ def data_sites() -> Site:
     generation.to_netcdf(filename)
     meta_df.to_csv(filename_csv)

-    site = Site(
-
-
-
-
+    site = Site(
+        file_path=filename,
+        metadata_file_path=filename_csv,
+        interval_start_minutes=-30,
+        interval_end_minutes=60,
+        time_resolution_minutes=30,
+    )

     yield site

tests/select/test_find_contiguous_time_periods.py CHANGED

@@ -11,8 +11,8 @@ def test_find_contiguous_t0_periods():

     # Create 5-minutely data timestamps
     freq = pd.Timedelta(5, "min")
-
-
+    interval_start = pd.Timedelta(-60, "min")
+    interval_end = pd.Timedelta(15, "min")

     datetimes = (
         pd.date_range("2023-01-01 12:00", "2023-01-01 17:00", freq=freq)
@@ -21,8 +21,8 @@ def test_find_contiguous_t0_periods():

     periods = find_contiguous_t0_periods(
         datetimes=datetimes,
-
-
+        interval_start=interval_start,
+        interval_end=interval_end,
         sample_period_duration=freq,
     )

@@ -135,7 +135,7 @@ def test_find_contiguous_t0_periods_nwp():
     # Create 3-hourly init times with a few time stamps missing
     freq = pd.Timedelta(3, "h")

-
+    init_times = (
         pd.date_range("2023-01-01 03:00", "2023-01-02 21:00", freq=freq)
         .delete([1, 4, 5, 6, 7, 9, 10])
     )
@@ -146,13 +146,13 @@ def test_find_contiguous_t0_periods_nwp():
     max_dropouts_hr = [0, 0, 0, 0, 3]

     for i in range(len(expected_results)):
-
+        interval_start = pd.Timedelta(-history_durations_hr[i], "h")
         max_staleness = pd.Timedelta(max_stalenesses_hr[i], "h")
         max_dropout = pd.Timedelta(max_dropouts_hr[i], "h")

         time_periods = find_contiguous_t0_periods_nwp(
-
-
+            init_times=init_times,
+            interval_start=interval_start,
             max_staleness=max_staleness,
             max_dropout=max_dropout,
         )
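Seen from the test side, the rename reads as a change of convention: what used to be passed as a history/forecast duration pair is now a single window expressed relative to t0, with the history entering as a negative interval_start. A small illustration of that equivalence, using the same magnitudes as the first test above (the old-style names are assumed, for illustration only):

import pandas as pd

# Old-style durations (names assumed)
history_duration = pd.Timedelta(60, "min")
forecast_duration = pd.Timedelta(15, "min")

# Equivalent 0.0.27-style arguments for find_contiguous_t0_periods
interval_start = -history_duration   # 60 minutes before t0, expressed as a negative offset
interval_end = forecast_duration     # 15 minutes after t0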