ocf-data-sampler 0.0.19__py3-none-any.whl → 0.0.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/__init__.py +5 -0
- ocf_data_sampler/config/load.py +33 -0
- ocf_data_sampler/config/model.py +246 -0
- ocf_data_sampler/config/save.py +73 -0
- ocf_data_sampler/constants.py +173 -0
- ocf_data_sampler/load/load_dataset.py +55 -0
- ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
- ocf_data_sampler/load/site.py +30 -0
- ocf_data_sampler/numpy_sample/__init__.py +8 -0
- ocf_data_sampler/numpy_sample/collate.py +77 -0
- ocf_data_sampler/numpy_sample/gsp.py +34 -0
- ocf_data_sampler/numpy_sample/nwp.py +42 -0
- ocf_data_sampler/numpy_sample/satellite.py +30 -0
- ocf_data_sampler/numpy_sample/site.py +30 -0
- ocf_data_sampler/{numpy_batch → numpy_sample}/sun_position.py +9 -10
- ocf_data_sampler/select/__init__.py +8 -1
- ocf_data_sampler/select/dropout.py +4 -3
- ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
- ocf_data_sampler/select/geospatial.py +160 -0
- ocf_data_sampler/select/location.py +62 -0
- ocf_data_sampler/select/select_spatial_slice.py +13 -16
- ocf_data_sampler/select/select_time_slice.py +24 -33
- ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
- ocf_data_sampler/select/time_slice_for_dataset.py +125 -0
- ocf_data_sampler/torch_datasets/__init__.py +2 -1
- ocf_data_sampler/torch_datasets/process_and_combine.py +131 -0
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +11 -425
- ocf_data_sampler/torch_datasets/site.py +405 -0
- ocf_data_sampler/torch_datasets/valid_time_periods.py +116 -0
- ocf_data_sampler/utils.py +10 -0
- ocf_data_sampler-0.0.42.dist-info/METADATA +153 -0
- ocf_data_sampler-0.0.42.dist-info/RECORD +71 -0
- {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.42.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.42.dist-info}/top_level.txt +1 -0
- scripts/refactor_site.py +50 -0
- tests/config/test_config.py +161 -0
- tests/config/test_save.py +37 -0
- tests/conftest.py +86 -1
- tests/load/test_load_gsp.py +15 -0
- tests/load/test_load_nwp.py +21 -0
- tests/load/test_load_satellite.py +17 -0
- tests/load/test_load_sites.py +14 -0
- tests/numpy_sample/test_collate.py +26 -0
- tests/numpy_sample/test_gsp.py +38 -0
- tests/numpy_sample/test_nwp.py +52 -0
- tests/numpy_sample/test_satellite.py +40 -0
- tests/numpy_sample/test_sun_position.py +81 -0
- tests/select/test_dropout.py +75 -0
- tests/select/test_fill_time_periods.py +28 -0
- tests/select/test_find_contiguous_time_periods.py +202 -0
- tests/select/test_location.py +67 -0
- tests/select/test_select_spatial_slice.py +154 -0
- tests/select/test_select_time_slice.py +272 -0
- tests/torch_datasets/conftest.py +18 -0
- tests/torch_datasets/test_process_and_combine.py +126 -0
- tests/torch_datasets/test_pvnet_uk_regional.py +59 -0
- tests/torch_datasets/test_site.py +129 -0
- ocf_data_sampler/numpy_batch/__init__.py +0 -7
- ocf_data_sampler/numpy_batch/gsp.py +0 -20
- ocf_data_sampler/numpy_batch/nwp.py +0 -33
- ocf_data_sampler/numpy_batch/satellite.py +0 -23
- ocf_data_sampler-0.0.19.dist-info/METADATA +0 -22
- ocf_data_sampler-0.0.19.dist-info/RECORD +0 -32
- {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.42.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""location"""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from pydantic import BaseModel, Field, model_validator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
allowed_coordinate_systems = ["osgb", "lon_lat", "geostationary", "idx"]


class Location(BaseModel):
    """Represent a spatial location.

    A location is an (x, y) pair expressed in one of the coordinate systems
    listed in `allowed_coordinate_systems`, plus an optional integer id.
    On construction the coordinate system name is checked, and x and y are
    checked against fixed per-system bounds.
    """

    # One of ["osgb", "lon_lat", "geostationary", "idx"]
    coordinate_system: Optional[str] = "osgb"
    x: float
    y: float
    id: Optional[int] = Field(None)

    @model_validator(mode='after')
    def validate_coordinate_system(self):
        """Validate 'coordinate_system'

        Raises:
            ValueError: if the coordinate system is not in
                `allowed_coordinate_systems`.
        """
        if self.coordinate_system not in allowed_coordinate_systems:
            raise ValueError(f"coordinate_system = {self.coordinate_system} is not in {allowed_coordinate_systems}")
        return self

    @model_validator(mode='after')
    def validate_x(self):
        """Validate 'x'

        Raises:
            ValueError: if x lies outside the bounds for the coordinate system,
                or if no bounds are defined for the coordinate system.
        """
        # Inclusive [min, max] bounds on x for each coordinate system
        x_bounds = {
            "osgb": (-103976.3, 652897.98),
            "lon_lat": (-180, 180),
            "geostationary": (-5568748.275756836, 5567248.074173927),
            "idx": (0, np.inf),
        }

        co = self.coordinate_system
        if co not in x_bounds:
            # Fail explicitly instead of relying on another validator having
            # already rejected the coordinate system
            raise ValueError(f"no x bounds are defined for coordinate_system = {co}")

        min_x, max_x = x_bounds[co]
        if self.x < min_x or self.x > max_x:
            raise ValueError(f"x = {self.x} must be within {[min_x, max_x]} for {co} coordinate system")
        return self

    @model_validator(mode='after')
    def validate_y(self):
        """Validate 'y'

        Raises:
            ValueError: if y lies outside the bounds for the coordinate system,
                or if no bounds are defined for the coordinate system.
        """
        # Inclusive [min, max] bounds on y for each coordinate system
        y_bounds = {
            "osgb": (-16703.87, 1199851.44),
            "lon_lat": (-90, 90),
            "geostationary": (1393687.2151494026, 5570748.323202133),
            "idx": (0, np.inf),
        }

        co = self.coordinate_system
        if co not in y_bounds:
            # Fail explicitly instead of relying on another validator having
            # already rejected the coordinate system
            raise ValueError(f"no y bounds are defined for coordinate_system = {co}")

        min_y, max_y = y_bounds[co]
        if self.y < min_y or self.y > max_y:
            raise ValueError(f"y = {self.y} must be within {[min_y, max_y]} for {co} coordinate system")
        return self
|
|
@@ -5,15 +5,15 @@ import logging
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import xarray as xr
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
lon_lat_to_geostationary_area_coords,
|
|
8
|
+
from ocf_data_sampler.select.location import Location
|
|
9
|
+
from ocf_data_sampler.select.geospatial import (
|
|
11
10
|
lon_lat_to_osgb,
|
|
11
|
+
lon_lat_to_geostationary_area_coords,
|
|
12
12
|
osgb_to_geostationary_area_coords,
|
|
13
13
|
osgb_to_lon_lat,
|
|
14
14
|
spatial_coord_type,
|
|
15
15
|
)
|
|
16
|
-
|
|
16
|
+
|
|
17
17
|
|
|
18
18
|
logger = logging.getLogger(__name__)
|
|
19
19
|
|
|
@@ -45,9 +45,6 @@ def convert_coords_to_match_xarray(
|
|
|
45
45
|
if from_coords == "osgb":
|
|
46
46
|
x, y = osgb_to_geostationary_area_coords(x, y, da)
|
|
47
47
|
|
|
48
|
-
elif from_coords == "lon_lat":
|
|
49
|
-
x, y = lon_lat_to_geostationary_area_coords(x, y, da)
|
|
50
|
-
|
|
51
48
|
elif target_coords == "lon_lat":
|
|
52
49
|
if from_coords == "osgb":
|
|
53
50
|
x, y = osgb_to_lon_lat(x, y)
|
|
@@ -105,7 +102,7 @@ def _get_idx_of_pixel_closest_to_poi(
|
|
|
105
102
|
|
|
106
103
|
def _get_idx_of_pixel_closest_to_poi_geostationary(
|
|
107
104
|
da: xr.DataArray,
|
|
108
|
-
|
|
105
|
+
center: Location,
|
|
109
106
|
) -> Location:
|
|
110
107
|
"""
|
|
111
108
|
Return x and y index location of pixel at center of region of interest.
|
|
@@ -120,7 +117,12 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
|
|
|
120
117
|
|
|
121
118
|
_, x_dim, y_dim = spatial_coord_type(da)
|
|
122
119
|
|
|
123
|
-
|
|
120
|
+
if center.coordinate_system == 'osgb':
|
|
121
|
+
x, y = osgb_to_geostationary_area_coords(x=center.x, y=center.y, xr_data=da)
|
|
122
|
+
elif center.coordinate_system == 'lon_lat':
|
|
123
|
+
x, y = lon_lat_to_geostationary_area_coords(longitude=center.x, latitude=center.y, xr_data=da)
|
|
124
|
+
else:
|
|
125
|
+
x,y = center.x, center.y
|
|
124
126
|
center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")
|
|
125
127
|
|
|
126
128
|
# Check that the requested point lies within the data
|
|
@@ -130,13 +132,8 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
|
|
|
130
132
|
f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}"
|
|
131
133
|
|
|
132
134
|
# Get the index into x and y nearest to x_center_geostationary and y_center_geostationary:
|
|
133
|
-
x_index_at_center = searchsorted(
|
|
134
|
-
|
|
135
|
-
)
|
|
136
|
-
|
|
137
|
-
y_index_at_center = searchsorted(
|
|
138
|
-
da[y_dim].values, center_geostationary.y, assume_ascending=True
|
|
139
|
-
)
|
|
135
|
+
x_index_at_center = np.searchsorted(da[x_dim].values, center_geostationary.x)
|
|
136
|
+
y_index_at_center = np.searchsorted(da[y_dim].values, center_geostationary.y)
|
|
140
137
|
|
|
141
138
|
return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx")
|
|
142
139
|
|
|
@@ -39,23 +39,14 @@ def _sel_fillinterp(
|
|
|
39
39
|
def select_time_slice(
|
|
40
40
|
ds: xr.DataArray,
|
|
41
41
|
t0: pd.Timestamp,
|
|
42
|
+
interval_start: pd.Timedelta,
|
|
43
|
+
interval_end: pd.Timedelta,
|
|
42
44
|
sample_period_duration: pd.Timedelta,
|
|
43
|
-
history_duration: pd.Timedelta | None = None,
|
|
44
|
-
forecast_duration: pd.Timedelta | None = None,
|
|
45
|
-
interval_start: pd.Timedelta | None = None,
|
|
46
|
-
interval_end: pd.Timedelta | None = None,
|
|
47
45
|
fill_selection: bool = False,
|
|
48
46
|
max_steps_gap: int = 0,
|
|
49
47
|
):
|
|
50
48
|
"""Select a time slice from a Dataset or DataArray."""
|
|
51
|
-
used_duration = history_duration is not None and forecast_duration is not None
|
|
52
|
-
used_intervals = interval_start is not None and interval_end is not None
|
|
53
|
-
assert used_duration ^ used_intervals, "Either durations, or intervals must be supplied"
|
|
54
49
|
assert max_steps_gap >= 0, "max_steps_gap must be >= 0 "
|
|
55
|
-
|
|
56
|
-
if used_duration:
|
|
57
|
-
interval_start = - history_duration
|
|
58
|
-
interval_end = forecast_duration
|
|
59
50
|
|
|
60
51
|
if fill_selection and max_steps_gap == 0:
|
|
61
52
|
_sel = _sel_fillnan
|
|
@@ -75,11 +66,11 @@ def select_time_slice(
|
|
|
75
66
|
|
|
76
67
|
|
|
77
68
|
def select_time_slice_nwp(
|
|
78
|
-
|
|
69
|
+
da: xr.DataArray,
|
|
79
70
|
t0: pd.Timestamp,
|
|
71
|
+
interval_start: pd.Timedelta,
|
|
72
|
+
interval_end: pd.Timedelta,
|
|
80
73
|
sample_period_duration: pd.Timedelta,
|
|
81
|
-
history_duration: pd.Timedelta,
|
|
82
|
-
forecast_duration: pd.Timedelta,
|
|
83
74
|
dropout_timedeltas: list[pd.Timedelta] | None = None,
|
|
84
75
|
dropout_frac: float | None = 0,
|
|
85
76
|
accum_channels: list[str] = [],
|
|
@@ -92,31 +83,31 @@ def select_time_slice_nwp(
|
|
|
92
83
|
), "dropout timedeltas must be negative"
|
|
93
84
|
assert len(dropout_timedeltas) >= 1
|
|
94
85
|
assert 0 <= dropout_frac <= 1
|
|
95
|
-
|
|
86
|
+
consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0
|
|
96
87
|
|
|
97
88
|
|
|
98
89
|
# The accumatation and non-accumulation channels
|
|
99
90
|
accum_channels = np.intersect1d(
|
|
100
|
-
|
|
91
|
+
da[channel_dim_name].values, accum_channels
|
|
101
92
|
)
|
|
102
93
|
non_accum_channels = np.setdiff1d(
|
|
103
|
-
|
|
94
|
+
da[channel_dim_name].values, accum_channels
|
|
104
95
|
)
|
|
105
96
|
|
|
106
|
-
start_dt = (t0
|
|
107
|
-
end_dt = (t0 +
|
|
97
|
+
start_dt = (t0 + interval_start).ceil(sample_period_duration)
|
|
98
|
+
end_dt = (t0 + interval_end).ceil(sample_period_duration)
|
|
108
99
|
|
|
109
100
|
target_times = pd.date_range(start_dt, end_dt, freq=sample_period_duration)
|
|
110
101
|
|
|
111
102
|
# Maybe apply NWP dropout
|
|
112
|
-
if
|
|
103
|
+
if consider_dropout and (np.random.uniform() < dropout_frac):
|
|
113
104
|
dt = np.random.choice(dropout_timedeltas)
|
|
114
105
|
t0_available = t0 + dt
|
|
115
106
|
else:
|
|
116
107
|
t0_available = t0
|
|
117
108
|
|
|
118
109
|
# Forecasts made up to and including t0
|
|
119
|
-
available_init_times =
|
|
110
|
+
available_init_times = da.init_time_utc.sel(
|
|
120
111
|
init_time_utc=slice(None, t0_available)
|
|
121
112
|
)
|
|
122
113
|
|
|
@@ -139,7 +130,7 @@ def select_time_slice_nwp(
|
|
|
139
130
|
step_indexer = xr.DataArray(steps, coords=coords)
|
|
140
131
|
|
|
141
132
|
if len(accum_channels) == 0:
|
|
142
|
-
|
|
133
|
+
da_sel = da.sel(step=step_indexer, init_time_utc=init_time_indexer)
|
|
143
134
|
|
|
144
135
|
else:
|
|
145
136
|
# First minimise the size of the dataset we are diffing
|
|
@@ -149,7 +140,7 @@ def select_time_slice_nwp(
|
|
|
149
140
|
min_step = min(steps)
|
|
150
141
|
max_step = max(steps) + sample_period_duration
|
|
151
142
|
|
|
152
|
-
|
|
143
|
+
da_min = da.sel(
|
|
153
144
|
{
|
|
154
145
|
"init_time_utc": unique_init_times,
|
|
155
146
|
"step": slice(min_step, max_step),
|
|
@@ -157,28 +148,28 @@ def select_time_slice_nwp(
|
|
|
157
148
|
)
|
|
158
149
|
|
|
159
150
|
# Slice out the data which does not need to be diffed
|
|
160
|
-
|
|
161
|
-
|
|
151
|
+
da_non_accum = da_min.sel({channel_dim_name: non_accum_channels})
|
|
152
|
+
da_sel_non_accum = da_non_accum.sel(
|
|
162
153
|
step=step_indexer, init_time_utc=init_time_indexer
|
|
163
154
|
)
|
|
164
155
|
|
|
165
156
|
# Slice out the channels which need to be diffed
|
|
166
|
-
|
|
157
|
+
da_accum = da_min.sel({channel_dim_name: accum_channels})
|
|
167
158
|
|
|
168
159
|
# Take the diff and slice requested data
|
|
169
|
-
|
|
170
|
-
|
|
160
|
+
da_accum = da_accum.diff(dim="step", label="lower")
|
|
161
|
+
da_sel_accum = da_accum.sel(step=step_indexer, init_time_utc=init_time_indexer)
|
|
171
162
|
|
|
172
163
|
# Join diffed and non-diffed variables
|
|
173
|
-
|
|
164
|
+
da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim=channel_dim_name)
|
|
174
165
|
|
|
175
166
|
# Reorder the variable back to the original order
|
|
176
|
-
|
|
167
|
+
da_sel = da_sel.sel({channel_dim_name: da[channel_dim_name].values})
|
|
177
168
|
|
|
178
169
|
# Rename the diffed channels
|
|
179
|
-
|
|
170
|
+
da_sel[channel_dim_name] = [
|
|
180
171
|
f"diff_{v}" if v in accum_channels else v
|
|
181
|
-
for v in
|
|
172
|
+
for v in da_sel[channel_dim_name].values
|
|
182
173
|
]
|
|
183
174
|
|
|
184
|
-
return
|
|
175
|
+
return da_sel
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
""" Functions for selecting data around a given location """
|
|
2
|
+
from ocf_data_sampler.config import Configuration
|
|
3
|
+
from ocf_data_sampler.select.location import Location
|
|
4
|
+
from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def slice_datasets_by_space(
    datasets_dict: dict,
    location: Location,
    config: Configuration,
) -> dict:
    """Slice the dictionary of input data sources around a given location

    NWP and satellite data are cropped with `select_spatial_slice_pixels`
    using the pixel height/width from the configuration. GSP and site data
    are selected by `location.id`.

    Args:
        datasets_dict: Dictionary of the input data sources. Only the keys
            "nwp", "sat", "gsp" and "site" are allowed; "nwp" maps to a
            dictionary of datasets keyed like `config.input_data.nwp`.
        location: The location to sample around
        config: Configuration object.

    Returns:
        A new dictionary with the same keys as `datasets_dict`, holding the
        spatially sliced data.
    """

    allowed_keys = {"nwp", "sat", "gsp", "site"}
    assert set(datasets_dict.keys()).issubset(allowed_keys), (
        f"Unexpected keys in datasets_dict: {set(datasets_dict.keys()) - allowed_keys}"
    )

    sliced_datasets_dict = {}

    if "nwp" in datasets_dict:

        sliced_datasets_dict["nwp"] = {}

        # Every NWP source named in the config is expected in datasets_dict["nwp"]
        for nwp_key, nwp_config in config.input_data.nwp.items():

            sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
                datasets_dict["nwp"][nwp_key],
                location,
                height_pixels=nwp_config.image_size_pixels_height,
                width_pixels=nwp_config.image_size_pixels_width,
            )

    if "sat" in datasets_dict:
        sat_config = config.input_data.satellite

        sliced_datasets_dict["sat"] = select_spatial_slice_pixels(
            datasets_dict["sat"],
            location,
            height_pixels=sat_config.image_size_pixels_height,
            width_pixels=sat_config.image_size_pixels_width,
        )

    if "gsp" in datasets_dict:
        sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)

    if "site" in datasets_dict:
        sliced_datasets_dict["site"] = datasets_dict["site"].sel(site_id=location.id)

    return sliced_datasets_dict
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
""" Slice datasets by time"""
|
|
2
|
+
import pandas as pd
|
|
3
|
+
|
|
4
|
+
from ocf_data_sampler.config import Configuration
|
|
5
|
+
from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
|
|
6
|
+
from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice
|
|
7
|
+
from ocf_data_sampler.utils import minutes
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def slice_datasets_by_time(
    datasets_dict: dict,
    t0: pd.Timestamp,
    config: Configuration,
) -> dict:
    """Slice the dictionary of input data sources around a given t0 time

    Each source present in `datasets_dict` is cut to the time window given by
    its configuration (interval start/end relative to t0 and the time
    resolution). Dropout is then applied to satellite, GSP history and site
    data; NWP dropout settings are passed through to `select_time_slice_nwp`.

    Args:
        datasets_dict: Dictionary of the input data sources
        t0: The init-time
        config: Configuration object.

    Returns:
        A new dictionary of time-sliced data. A "gsp" input produces two
        entries: "gsp" (interval start up to t0, with dropout) and
        "gsp_future" (one time step after t0 up to interval end, no dropout).
    """

    sliced_datasets_dict = {}

    if "nwp" in datasets_dict:

        sliced_datasets_dict["nwp"] = {}

        for nwp_key, da_nwp in datasets_dict["nwp"].items():

            nwp_config = config.input_data.nwp[nwp_key]

            sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
                da_nwp,
                t0,
                sample_period_duration=minutes(nwp_config.time_resolution_minutes),
                interval_start=minutes(nwp_config.interval_start_minutes),
                interval_end=minutes(nwp_config.interval_end_minutes),
                dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
                dropout_frac=nwp_config.dropout_fraction,
                accum_channels=nwp_config.accum_channels,
            )

    if "sat" in datasets_dict:

        sat_config = config.input_data.satellite

        sliced_datasets_dict["sat"] = select_time_slice(
            datasets_dict["sat"],
            t0,
            sample_period_duration=minutes(sat_config.time_resolution_minutes),
            interval_start=minutes(sat_config.interval_start_minutes),
            interval_end=minutes(sat_config.interval_end_minutes),
            max_steps_gap=2,
        )

        # Randomly sample dropout
        sat_dropout_time = draw_dropout_time(
            t0,
            dropout_timedeltas=minutes(sat_config.dropout_timedeltas_minutes),
            dropout_frac=sat_config.dropout_fraction,
        )

        # Apply the dropout
        sliced_datasets_dict["sat"] = apply_dropout_time(
            sliced_datasets_dict["sat"],
            sat_dropout_time,
        )

    if "gsp" in datasets_dict:
        gsp_config = config.input_data.gsp

        # The future window starts one time step after t0, so it does not
        # overlap the history window below, which ends at t0
        sliced_datasets_dict["gsp_future"] = select_time_slice(
            datasets_dict["gsp"],
            t0,
            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
            interval_start=minutes(gsp_config.time_resolution_minutes),
            interval_end=minutes(gsp_config.interval_end_minutes),
        )

        sliced_datasets_dict["gsp"] = select_time_slice(
            datasets_dict["gsp"],
            t0,
            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
            interval_start=minutes(gsp_config.interval_start_minutes),
            interval_end=minutes(0),
        )

        # Dropout on the GSP, but not the future GSP
        gsp_dropout_time = draw_dropout_time(
            t0,
            dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
            dropout_frac=gsp_config.dropout_fraction,
        )

        sliced_datasets_dict["gsp"] = apply_dropout_time(
            sliced_datasets_dict["gsp"],
            gsp_dropout_time
        )

    if "site" in datasets_dict:
        site_config = config.input_data.site

        sliced_datasets_dict["site"] = select_time_slice(
            datasets_dict["site"],
            t0,
            sample_period_duration=minutes(site_config.time_resolution_minutes),
            interval_start=minutes(site_config.interval_start_minutes),
            interval_end=minutes(site_config.interval_end_minutes),
        )

        # Randomly sample dropout
        site_dropout_time = draw_dropout_time(
            t0,
            dropout_timedeltas=minutes(site_config.dropout_timedeltas_minutes),
            dropout_frac=site_config.dropout_fraction,
        )

        # Apply the dropout
        sliced_datasets_dict["site"] = apply_dropout_time(
            sliced_datasets_dict["site"],
            site_dropout_time,
        )

    return sliced_datasets_dict
|
|
@@ -1 +1,2 @@
|
|
|
1
|
-
|
|
1
|
+
from .pvnet_uk_regional import PVNetUKRegionalDataset
|
|
2
|
+
from .site import SitesDataset
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import xarray as xr
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from ocf_data_sampler.config import Configuration
|
|
7
|
+
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS,RSS_MEAN,RSS_STD
|
|
8
|
+
from ocf_data_sampler.numpy_sample import (
|
|
9
|
+
convert_nwp_to_numpy_sample,
|
|
10
|
+
convert_satellite_to_numpy_sample,
|
|
11
|
+
convert_gsp_to_numpy_sample,
|
|
12
|
+
make_sun_position_numpy_sample,
|
|
13
|
+
)
|
|
14
|
+
from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
|
|
15
|
+
from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
|
|
16
|
+
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
|
|
17
|
+
from ocf_data_sampler.select.location import Location
|
|
18
|
+
from ocf_data_sampler.utils import minutes
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def process_and_combine_datasets(
    dataset_dict: dict,
    config: Configuration,
    t0: Optional[pd.Timestamp] = None,
    location: Optional[Location] = None,
    target_key: str = 'gsp'
) -> dict:
    """Normalise and convert data to numpy arrays

    Args:
        dataset_dict: Dictionary of already-sliced input data. The keys
            "nwp", "sat" and "gsp" are handled; "gsp" also requires a
            "gsp_future" entry, which is concatenated after it in time.
        config: Configuration object.
        t0: The init-time. Used to build the sun position timestamps when
            `target_key` is 'gsp', so it must not be None in that case.
        location: The sample location. Used for the GSP coordinate data and
            the sun position, so it must not be None when "gsp" is present
            or `target_key` is 'gsp'.
        target_key: Name of the target; sun position data is only added
            when this is 'gsp'.

    Returns:
        A single dictionary merging all modalities, with NaNs in numeric
        arrays filled by `fill_nans_in_arrays`.
    """
    # NOTE(review): the nesting below was reconstructed from a rendering with
    # no indentation; confirm against the packaged source that the
    # coordinate data belongs to the "gsp" branch and that the sun position
    # branch sits at function level.
    numpy_modalities = []

    if "nwp" in dataset_dict:

        nwp_numpy_modalities = dict()

        for nwp_key, da_nwp in dataset_dict["nwp"].items():
            # Standardise
            provider = config.input_data.nwp[nwp_key].provider
            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
            # Convert to NumpySample
            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)

        # Combine the NWPs into NumpySample
        numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})

    if "sat" in dataset_dict:
        # Standardise
        da_sat = dataset_dict["sat"]
        da_sat = (da_sat - RSS_MEAN) / RSS_STD

        # Convert to NumpySample
        numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))

    # Read unconditionally: the sun position branch below also needs it
    gsp_config = config.input_data.gsp

    if "gsp" in dataset_dict:
        da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
        da_gsp = da_gsp / da_gsp.effective_capacity_mwp

        numpy_modalities.append(
            convert_gsp_to_numpy_sample(
                da_gsp,
                t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes
            )
        )

        # Add coordinate data
        # TODO: Do we need all of these?
        numpy_modalities.append(
            {
                GSPSampleKey.gsp_id: location.id,
                GSPSampleKey.x_osgb: location.x,
                GSPSampleKey.y_osgb: location.y,
            }
        )

    if target_key == 'gsp':
        # Make sun coords NumpySample
        datetimes = pd.date_range(
            t0+minutes(gsp_config.interval_start_minutes),
            t0+minutes(gsp_config.interval_end_minutes),
            freq=minutes(gsp_config.time_resolution_minutes),
        )

        lon, lat = osgb_to_lon_lat(location.x, location.y)

        numpy_modalities.append(
            make_sun_position_numpy_sample(datetimes, lon, lat, key_prefix=target_key)
        )

    # Combine all the modalities and fill NaNs
    combined_sample = merge_dicts(numpy_modalities)
    combined_sample = fill_nans_in_arrays(combined_sample)

    return combined_sample
|
|
98
|
+
|
|
99
|
+
def merge_dicts(list_of_dicts: list[dict]) -> dict:
    """Merge a list of dictionaries into a single dictionary

    Dictionaries are applied in list order, so when a key appears more than
    once the value from the later dictionary wins. The merge is shallow and
    the input dictionaries are not modified.

    Args:
        list_of_dicts: The dictionaries to merge

    Returns:
        A new dictionary containing every key from the inputs.
    """
    # TODO: This doesn't account for duplicate keys, which will be overwritten
    combined_dict = {}
    for d in list_of_dicts:
        combined_dict.update(d)
    return combined_dict
|
|
106
|
+
|
|
107
|
+
def fill_nans_in_arrays(sample: dict) -> dict:
    """Fills all NaN values in each np.ndarray in the sample dictionary with zeros.

    Operation is performed in-place on the sample.

    Only numeric arrays are touched; other values are left as they are.
    Nested dictionaries are processed recursively. Note that the fill uses
    `np.nan_to_num`, so in an array that contains at least one NaN any
    +/-inf entries are also replaced by very large finite numbers.

    Args:
        sample: Dictionary, possibly nested, of sample data

    Returns:
        The same `sample` object that was passed in.
    """
    for k, v in sample.items():
        if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
            # Only rewrite arrays that actually contain a NaN
            if np.isnan(v).any():
                sample[k] = np.nan_to_num(v, copy=False, nan=0.0)

        # Recursion is included to reach NWP arrays in subdict
        elif isinstance(v, dict):
            fill_nans_in_arrays(v)

    return sample
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def compute(xarray_dict: dict) -> dict:
    """Eagerly load a nested dictionary of xarray DataArrays

    Every non-dictionary value is replaced by the result of calling its
    `.compute(scheduler="single-threaded")` method; nested dictionaries are
    processed recursively. The dictionary is modified in place.

    Args:
        xarray_dict: Dictionary, possibly nested, of objects with a
            `.compute` method

    Returns:
        The same `xarray_dict` object that was passed in.
    """
    for k, v in xarray_dict.items():
        if isinstance(v, dict):
            xarray_dict[k] = compute(v)
        else:
            xarray_dict[k] = v.compute(scheduler="single-threaded")
    return xarray_dict
|