ocf-data-sampler 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/model.py +84 -87
- ocf_data_sampler/load/load_dataset.py +55 -0
- ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
- ocf_data_sampler/load/site.py +30 -0
- ocf_data_sampler/numpy_batch/__init__.py +1 -0
- ocf_data_sampler/numpy_batch/site.py +29 -0
- ocf_data_sampler/select/__init__.py +8 -1
- ocf_data_sampler/select/dropout.py +2 -1
- ocf_data_sampler/select/geospatial.py +43 -1
- ocf_data_sampler/select/select_spatial_slice.py +8 -2
- ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
- ocf_data_sampler/select/time_slice_for_dataset.py +124 -0
- ocf_data_sampler/time_functions.py +11 -0
- ocf_data_sampler/torch_datasets/process_and_combine.py +153 -0
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +8 -418
- ocf_data_sampler/torch_datasets/site.py +196 -0
- ocf_data_sampler/torch_datasets/valid_time_periods.py +108 -0
- {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/RECORD +28 -16
- {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/top_level.txt +1 -0
- scripts/refactor_site.py +50 -0
- tests/config/test_config.py +9 -6
- tests/conftest.py +62 -0
- tests/load/test_load_sites.py +14 -0
- tests/torch_datasets/test_pvnet_uk_regional.py +4 -4
- tests/torch_datasets/test_site.py +85 -0
- {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
""" Slice datasets by time"""
|
|
2
|
+
import pandas as pd
|
|
3
|
+
|
|
4
|
+
from ocf_data_sampler.config import Configuration
|
|
5
|
+
from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
|
|
6
|
+
from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice
|
|
7
|
+
from ocf_data_sampler.time_functions import minutes
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def slice_datasets_by_time(
    datasets_dict: dict,
    t0: pd.Timestamp,
    config: Configuration,
) -> dict:
    """Slice the dictionary of input data sources around a given t0 time

    Args:
        datasets_dict: Dictionary of the input data sources. The keys "nwp", "sat", "gsp"
            and "site" are handled; any other key is ignored.
        t0: The init-time
        config: Configuration object.

    Returns:
        A new dictionary holding the time-sliced data for each source present in
        `datasets_dict`. "nwp" maps to a dictionary with one entry per NWP key, and a "gsp"
        input yields two entries: "gsp" (history up to t0) and "gsp_future".
    """

    sliced_datasets_dict = {}

    if "nwp" in datasets_dict:

        sliced_datasets_dict["nwp"] = {}

        for nwp_key, da_nwp in datasets_dict["nwp"].items():

            # Each NWP source has its own entry in the config, looked up by the same key
            nwp_config = config.input_data.nwp[nwp_key]

            # Dropout settings are passed straight into the NWP slicing helper
            sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
                da_nwp,
                t0,
                sample_period_duration=minutes(nwp_config.time_resolution_minutes),
                history_duration=minutes(nwp_config.history_minutes),
                forecast_duration=minutes(nwp_config.forecast_minutes),
                dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
                dropout_frac=nwp_config.dropout_fraction,
                accum_channels=nwp_config.accum_channels,
            )

    if "sat" in datasets_dict:

        sat_config = config.input_data.satellite

        # The satellite window runs from `history_minutes` before t0 up to
        # `live_delay_minutes` before t0
        sliced_datasets_dict["sat"] = select_time_slice(
            datasets_dict["sat"],
            t0,
            sample_period_duration=minutes(sat_config.time_resolution_minutes),
            interval_start=minutes(-sat_config.history_minutes),
            interval_end=minutes(-sat_config.live_delay_minutes),
            max_steps_gap=2,
        )

        # Randomly sample dropout
        sat_dropout_time = draw_dropout_time(
            t0,
            dropout_timedeltas=minutes(sat_config.dropout_timedeltas_minutes),
            dropout_frac=sat_config.dropout_fraction,
        )

        # Apply the dropout
        sliced_datasets_dict["sat"] = apply_dropout_time(
            sliced_datasets_dict["sat"],
            sat_dropout_time,
        )

    if "gsp" in datasets_dict:
        gsp_config = config.input_data.gsp
        gsp_step = minutes(gsp_config.time_resolution_minutes)

        # The future window starts one sample period after t0 (the history window below
        # ends at t0). Using the configured resolution rather than a fixed 30 minutes keeps
        # the two windows contiguous for any GSP time resolution.
        sliced_datasets_dict["gsp_future"] = select_time_slice(
            datasets_dict["gsp"],
            t0,
            sample_period_duration=gsp_step,
            interval_start=gsp_step,
            interval_end=minutes(gsp_config.forecast_minutes),
        )

        sliced_datasets_dict["gsp"] = select_time_slice(
            datasets_dict["gsp"],
            t0,
            sample_period_duration=gsp_step,
            interval_start=-minutes(gsp_config.history_minutes),
            interval_end=minutes(0),
        )

        # Dropout on the GSP, but not the future GSP
        gsp_dropout_time = draw_dropout_time(
            t0,
            dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
            dropout_frac=gsp_config.dropout_fraction,
        )

        sliced_datasets_dict["gsp"] = apply_dropout_time(
            sliced_datasets_dict["gsp"], gsp_dropout_time
        )

    if "site" in datasets_dict:
        site_config = config.input_data.site

        # Unlike GSP, site history and forecast are taken as a single slice
        sliced_datasets_dict["site"] = select_time_slice(
            datasets_dict["site"],
            t0,
            sample_period_duration=minutes(site_config.time_resolution_minutes),
            interval_start=-minutes(site_config.history_minutes),
            interval_end=minutes(site_config.forecast_minutes),
        )

        # Randomly sample dropout
        site_dropout_time = draw_dropout_time(
            t0,
            dropout_timedeltas=minutes(site_config.dropout_timedeltas_minutes),
            dropout_frac=site_config.dropout_fraction,
        )

        # Apply the dropout
        # NOTE(review): this is applied to the combined history + forecast slice, whereas
        # the GSP branch leaves its future slice untouched - confirm this is intended
        sliced_datasets_dict["site"] = apply_dropout_time(
            sliced_datasets_dict["site"],
            site_dropout_time,
        )

    return sliced_datasets_dict
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def minutes(minutes: int | list[float]) -> pd.Timedelta | pd.TimedeltaIndex:
|
|
5
|
+
"""Timedelta minutes
|
|
6
|
+
|
|
7
|
+
Args:
|
|
8
|
+
minutes: the number of minutes, single value or list
|
|
9
|
+
"""
|
|
10
|
+
minutes_delta = pd.to_timedelta(minutes, unit="m")
|
|
11
|
+
return minutes_delta
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import xarray as xr
|
|
4
|
+
|
|
5
|
+
from ocf_data_sampler.config import Configuration
|
|
6
|
+
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
|
|
7
|
+
from ocf_data_sampler.numpy_batch import (
|
|
8
|
+
convert_nwp_to_numpy_batch,
|
|
9
|
+
convert_satellite_to_numpy_batch,
|
|
10
|
+
convert_gsp_to_numpy_batch,
|
|
11
|
+
make_sun_position_numpy_batch,
|
|
12
|
+
convert_site_to_numpy_batch,
|
|
13
|
+
)
|
|
14
|
+
from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
|
|
15
|
+
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
|
|
16
|
+
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
|
|
17
|
+
from ocf_data_sampler.select.location import Location
|
|
18
|
+
from ocf_data_sampler.time_functions import minutes
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def process_and_combine_datasets(
    dataset_dict: dict,
    config: Configuration,
    t0: pd.Timestamp,
    location: Location,
    sun_position_key: str = 'gsp'
) -> dict:
    """Normalize and convert data to numpy arrays

    Args:
        dataset_dict: Dictionary of sliced input data. The keys "nwp" (a dictionary of
            arrays by NWP key), "sat", "gsp" (together with "gsp_future") and "site" are
            processed when present.
        config: Configuration object, read for the NWP providers and for the GSP / site
            history, forecast and time-resolution settings.
        t0: The init-time the sample is centred on
        location: Location of the sample; its `id`, `x` and `y` are used
        sun_position_key: Which config section defines the sun-position time axis, either
            'gsp' or 'site'

    Returns:
        A single dictionary merging every converted modality, passed through
        `fill_nans_in_arrays`.

    Raises:
        ValueError: If `sun_position_key` is neither 'gsp' nor 'site'.
    """

    numpy_modalities = []

    if "nwp" in dataset_dict:

        nwp_numpy_modalities = {}

        for nwp_key, da_nwp in dataset_dict["nwp"].items():
            # Standardise using the per-provider statistics
            provider = config.input_data.nwp[nwp_key].provider
            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
            # Convert to NumpyBatch
            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)

        # Combine the NWPs into NumpyBatch
        numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})

    if "sat" in dataset_dict:
        # Satellite is already in the range [0-1] so no need to standardise
        da_sat = dataset_dict["sat"]

        # Convert to NumpyBatch
        numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))

    gsp_config = config.input_data.gsp

    if "gsp" in dataset_dict:
        # Rejoin the history and future slices along time, then scale by capacity
        da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
        da_gsp = da_gsp / da_gsp.effective_capacity_mwp

        numpy_modalities.append(
            convert_gsp_to_numpy_batch(
                da_gsp, t0_idx=gsp_config.history_minutes // gsp_config.time_resolution_minutes
            )
        )

        # Add coordinate data
        # TODO: Do we need all of these?
        numpy_modalities.append(
            {
                GSPBatchKey.gsp_id: location.id,
                GSPBatchKey.x_osgb: location.x,
                GSPBatchKey.y_osgb: location.y,
            }
        )

    if "site" in dataset_dict:
        site_config = config.input_data.site
        da_sites = dataset_dict["site"]
        da_sites = da_sites / da_sites.capacity_kwp

        numpy_modalities.append(
            convert_site_to_numpy_batch(
                # Floor division keeps t0_idx an integer, matching the GSP branch above
                da_sites, t0_idx=site_config.history_minutes // site_config.time_resolution_minutes
            )
        )

    if sun_position_key == 'gsp':
        # Make sun coords NumpyBatch
        datetimes = pd.date_range(
            t0 - minutes(gsp_config.history_minutes),
            t0 + minutes(gsp_config.forecast_minutes),
            freq=minutes(gsp_config.time_resolution_minutes),
        )

        lon, lat = osgb_to_lon_lat(location.x, location.y)
        key_prefix = "gsp"

    elif sun_position_key == 'site':
        # Read the site config here rather than relying on the `"site" in dataset_dict`
        # branch above having run, so this path works without site data in the sample
        site_config = config.input_data.site

        # Make sun coords NumpyBatch
        datetimes = pd.date_range(
            t0 - minutes(site_config.history_minutes),
            t0 + minutes(site_config.forecast_minutes),
            freq=minutes(site_config.time_resolution_minutes),
        )

        # NOTE(review): the location is used as lon/lat directly here, with no OSGB
        # conversion - confirm the coordinate system of site locations
        lon, lat = location.x, location.y
        key_prefix = "site"

    else:
        raise ValueError(
            f"sun_position_key must be 'gsp' or 'site', got {sun_position_key!r}"
        )

    numpy_modalities.append(
        make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=key_prefix)
    )

    # Combine all the modalities and fill NaNs
    combined_sample = merge_dicts(numpy_modalities)
    combined_sample = fill_nans_in_arrays(combined_sample)

    return combined_sample
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def merge_dicts(list_of_dicts: list[dict]) -> dict:
    """Flatten a list of dictionaries into one new dictionary

    The inputs are applied in list order, so when a key appears in more than one
    dictionary the value from the last one is kept.
    """
    # TODO: Duplicate keys are not detected; later values silently replace earlier ones
    merged = {}
    for mapping in list_of_dicts:
        merged |= mapping
    return merged
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def fill_nans_in_arrays(batch: dict) -> dict:
    """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.

    Operation is performed in-place on the batch. Only NaNs are replaced: infinite values
    are left as they are, whether or not the array also contains NaNs. Values that are not
    numeric arrays or dictionaries are left untouched.

    Args:
        batch: Dictionary whose values may be numpy arrays or nested dictionaries of them

    Returns:
        The same `batch` object that was passed in
    """
    for k, v in batch.items():
        if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
            if np.isnan(v).any():
                # By default nan_to_num also swaps +/-inf for very large finite numbers;
                # passing the infinities back in restricts the change to NaNs only
                batch[k] = np.nan_to_num(
                    v, copy=False, nan=0.0, posinf=np.inf, neginf=-np.inf
                )

        # Recursion is included to reach NWP arrays in subdict
        elif isinstance(v, dict):
            fill_nans_in_arrays(v)

    return batch
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def compute(xarray_dict: dict) -> dict:
    """Eagerly load every entry of a (possibly nested) dictionary of xarray DataArrays

    Nested dictionaries are processed recursively. Each other value is replaced by the
    result of its `.compute` call using the single-threaded scheduler. The dictionary is
    updated in place and also returned.
    """
    for key in xarray_dict:
        item = xarray_dict[key]
        if isinstance(item, dict):
            loaded = compute(item)
        else:
            loaded = item.compute(scheduler="single-threaded")
        xarray_dict[key] = loaded
    return xarray_dict
|