ocf-data-sampler 0.0.19__py3-none-any.whl → 0.0.43__py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release.
This version of ocf-data-sampler might be problematic.
- ocf_data_sampler/config/__init__.py +5 -0
- ocf_data_sampler/config/load.py +33 -0
- ocf_data_sampler/config/model.py +246 -0
- ocf_data_sampler/config/save.py +73 -0
- ocf_data_sampler/constants.py +173 -0
- ocf_data_sampler/load/load_dataset.py +55 -0
- ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
- ocf_data_sampler/load/site.py +30 -0
- ocf_data_sampler/numpy_sample/__init__.py +8 -0
- ocf_data_sampler/numpy_sample/collate.py +75 -0
- ocf_data_sampler/numpy_sample/gsp.py +34 -0
- ocf_data_sampler/numpy_sample/nwp.py +42 -0
- ocf_data_sampler/numpy_sample/satellite.py +30 -0
- ocf_data_sampler/numpy_sample/site.py +30 -0
- ocf_data_sampler/{numpy_batch → numpy_sample}/sun_position.py +9 -10
- ocf_data_sampler/select/__init__.py +8 -1
- ocf_data_sampler/select/dropout.py +4 -3
- ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
- ocf_data_sampler/select/geospatial.py +160 -0
- ocf_data_sampler/select/location.py +62 -0
- ocf_data_sampler/select/select_spatial_slice.py +13 -16
- ocf_data_sampler/select/select_time_slice.py +24 -33
- ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
- ocf_data_sampler/select/time_slice_for_dataset.py +125 -0
- ocf_data_sampler/torch_datasets/__init__.py +2 -1
- ocf_data_sampler/torch_datasets/process_and_combine.py +131 -0
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +11 -425
- ocf_data_sampler/torch_datasets/site.py +405 -0
- ocf_data_sampler/torch_datasets/valid_time_periods.py +116 -0
- ocf_data_sampler/utils.py +10 -0
- ocf_data_sampler-0.0.43.dist-info/METADATA +154 -0
- ocf_data_sampler-0.0.43.dist-info/RECORD +71 -0
- {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/top_level.txt +1 -0
- scripts/refactor_site.py +50 -0
- tests/config/test_config.py +161 -0
- tests/config/test_save.py +37 -0
- tests/conftest.py +86 -1
- tests/load/test_load_gsp.py +15 -0
- tests/load/test_load_nwp.py +21 -0
- tests/load/test_load_satellite.py +17 -0
- tests/load/test_load_sites.py +14 -0
- tests/numpy_sample/test_collate.py +26 -0
- tests/numpy_sample/test_gsp.py +38 -0
- tests/numpy_sample/test_nwp.py +52 -0
- tests/numpy_sample/test_satellite.py +40 -0
- tests/numpy_sample/test_sun_position.py +81 -0
- tests/select/test_dropout.py +75 -0
- tests/select/test_fill_time_periods.py +28 -0
- tests/select/test_find_contiguous_time_periods.py +202 -0
- tests/select/test_location.py +67 -0
- tests/select/test_select_spatial_slice.py +154 -0
- tests/select/test_select_time_slice.py +272 -0
- tests/torch_datasets/conftest.py +18 -0
- tests/torch_datasets/test_process_and_combine.py +126 -0
- tests/torch_datasets/test_pvnet_uk_regional.py +59 -0
- tests/torch_datasets/test_site.py +129 -0
- ocf_data_sampler/numpy_batch/__init__.py +0 -7
- ocf_data_sampler/numpy_batch/gsp.py +0 -20
- ocf_data_sampler/numpy_batch/nwp.py +0 -33
- ocf_data_sampler/numpy_batch/satellite.py +0 -23
- ocf_data_sampler-0.0.19.dist-info/METADATA +0 -22
- ocf_data_sampler-0.0.19.dist-info/RECORD +0 -32
- {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/LICENSE +0 -0
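The file list above records a package-wide rename: the ocf_data_sampler.numpy_batch modules are removed and replaced by ocf_data_sampler.numpy_sample. A minimal sketch of the downstream import change, where the old-style name is an assumption and the new names are taken from the tests added in this diff:

# Old (0.0.19), now removed. The exact old key name here is an assumption:
# from ocf_data_sampler.numpy_batch import NWPBatchKey

# New (0.0.43): these names appear in the added tests below
from ocf_data_sampler.numpy_sample import NWPSampleKey, GSPSampleKey, SatelliteSampleKey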
tests/select/test_select_spatial_slice.py
@@ -0,0 +1,154 @@
+import numpy as np
+import xarray as xr
+from ocf_data_sampler.select.location import Location
+import pytest
+
+from ocf_data_sampler.select.select_spatial_slice import (
+    select_spatial_slice_pixels, _get_idx_of_pixel_closest_to_poi
+)
+
+
+@pytest.fixture(scope="module")
+def da():
+    # Create dummy data
+    x = np.arange(-100, 100)
+    y = np.arange(-100, 100)
+
+    da = xr.DataArray(
+        np.random.normal(size=(len(x), len(y))),
+        coords=dict(
+            x_osgb=(["x_osgb"], x),
+            y_osgb=(["y_osgb"], y),
+        )
+    )
+    return da
+
+
+def test_get_idx_of_pixel_closest_to_poi(da):
+    idx_location = _get_idx_of_pixel_closest_to_poi(
+        da,
+        location=Location(x=10, y=10, coordinate_system="osgb"),
+    )
+
+    assert idx_location.coordinate_system == "idx"
+    assert idx_location.x == 110
+    assert idx_location.y == 110
+
+
+def test_select_spatial_slice_pixels(da):
+    # Select window which lies within x-y bounds of the data
+    da_sliced = select_spatial_slice_pixels(
+        da,
+        location=Location(x=-90, y=-80, coordinate_system="osgb"),
+        width_pixels=10,
+        height_pixels=10,
+        allow_partial_slice=True,
+    )
+
+    assert isinstance(da_sliced, xr.DataArray)
+    assert (da_sliced.x_osgb.values == np.arange(-95, -85)).all()
+    assert (da_sliced.y_osgb.values == np.arange(-85, -75)).all()
+    # No padding in this case so no NaNs
+    assert not da_sliced.isnull().any()
+
+    # Select window where the edge of the window lies right on the edge of the data
+    da_sliced = select_spatial_slice_pixels(
+        da,
+        location=Location(x=-90, y=-80, coordinate_system="osgb"),
+        width_pixels=20,
+        height_pixels=20,
+        allow_partial_slice=True,
+    )
+
+    assert isinstance(da_sliced, xr.DataArray)
+    assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
+    assert (da_sliced.y_osgb.values == np.arange(-90, -70)).all()
+    # No padding in this case so no NaNs
+    assert not da_sliced.isnull().any()
+
+    # Select window which is partially outside the boundary of the data - padded on left
+    da_sliced = select_spatial_slice_pixels(
+        da,
+        location=Location(x=-90, y=-80, coordinate_system="osgb"),
+        width_pixels=30,
+        height_pixels=30,
+        allow_partial_slice=True,
+    )
+
+    assert isinstance(da_sliced, xr.DataArray)
+    assert (da_sliced.x_osgb.values == np.arange(-105, -75)).all()
+    assert (da_sliced.y_osgb.values == np.arange(-95, -65)).all()
+    # Data has been padded on left by 5 NaN pixels
+    assert da_sliced.isnull().sum() == 5*len(da_sliced.y_osgb)
+
+    # Select window which is partially outside the boundary of the data - padded on right
+    da_sliced = select_spatial_slice_pixels(
+        da,
+        location=Location(x=90, y=-80, coordinate_system="osgb"),
+        width_pixels=30,
+        height_pixels=30,
+        allow_partial_slice=True,
+    )
+
+    assert isinstance(da_sliced, xr.DataArray)
+    assert (da_sliced.x_osgb.values == np.arange(75, 105)).all()
+    assert (da_sliced.y_osgb.values == np.arange(-95, -65)).all()
+    # Data has been padded on right by 5 NaN pixels
+    assert da_sliced.isnull().sum() == 5*len(da_sliced.y_osgb)
+
+    location = Location(x=-90, y=-0, coordinate_system="osgb")
+
+    # Select window which is partially outside the boundary of the data - padded on top
+    da_sliced = select_spatial_slice_pixels(
+        da,
+        location=Location(x=-90, y=95, coordinate_system="osgb"),
+        width_pixels=20,
+        height_pixels=20,
+        allow_partial_slice=True,
+    )
+
+    assert isinstance(da_sliced, xr.DataArray)
+    assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
+    assert (da_sliced.y_osgb.values == np.arange(85, 105)).all()
+    # Data has been padded on top by 5 NaN pixels
+    assert da_sliced.isnull().sum() == 5*len(da_sliced.x_osgb)
+
+    # Select window which is partially outside the boundary of the data - padded on bottom
+    da_sliced = select_spatial_slice_pixels(
+        da,
+        location=Location(x=-90, y=-95, coordinate_system="osgb"),
+        width_pixels=20,
+        height_pixels=20,
+        allow_partial_slice=True,
+    )
+
+    assert isinstance(da_sliced, xr.DataArray)
+    assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
+    assert (da_sliced.y_osgb.values == np.arange(-105, -85)).all()
+    # Data has been padded on bottom by 5 NaN pixels
+    assert da_sliced.isnull().sum() == 5*len(da_sliced.x_osgb)
+
+    # Select window which is partially outside the boundary of the data - padded right and bottom
+    da_sliced = select_spatial_slice_pixels(
+        da,
+        location=Location(x=90, y=-80, coordinate_system="osgb"),
+        width_pixels=50,
+        height_pixels=50,
+        allow_partial_slice=True,
+    )
+
+    assert isinstance(da_sliced, xr.DataArray)
+    assert (da_sliced.x_osgb.values == np.arange(65, 115)).all()
+    assert (da_sliced.y_osgb.values == np.arange(-105, -55)).all()
+    # Data has been padded on right by 15 pixels and bottom by 5 NaN pixels
+    assert da_sliced.isnull().sum() == 15*len(da_sliced.y_osgb) + 5*len(da_sliced.x_osgb) - 15*5
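As a cross-check on the index arithmetic in test_get_idx_of_pixel_closest_to_poi: the fixture's coordinates run from -100 to 99, so the pixel nearest x = 10 sits at index 10 - (-100) = 110, which is where the asserted values come from. A standalone sketch of that calculation (an illustration, not package code):

import numpy as np

x_osgb = np.arange(-100, 100)  # same coordinate grid as the `da` fixture
# Index of the pixel closest to a point of interest at x = 10
assert int(np.argmin(np.abs(x_osgb - 10))) == 110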
tests/select/test_select_time_slice.py
@@ -0,0 +1,272 @@
+from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+import pytest
+
+
+NWP_FREQ = pd.Timedelta("3h")
+
+
+@pytest.fixture(scope="module")
+def da_sat_like():
+    """Create dummy data which looks like satellite data"""
+    x = np.arange(-100, 100)
+    y = np.arange(-100, 100)
+    datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")
+
+    da_sat = xr.DataArray(
+        np.random.normal(size=(len(datetimes), len(x), len(y))),
+        coords=dict(
+            time_utc=(["time_utc"], datetimes),
+            x_geostationary=(["x_geostationary"], x),
+            y_geostationary=(["y_geostationary"], y),
+        )
+    )
+    return da_sat
+
+
+@pytest.fixture(scope="module")
+def da_nwp_like():
+    """Create dummy data which looks like NWP data"""
+    x = np.arange(-100, 100)
+    y = np.arange(-100, 100)
+    datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq=NWP_FREQ)
+    steps = pd.timedelta_range("0h", "16h", freq="1h")
+    channels = ["t", "dswrf"]
+
+    da_nwp = xr.DataArray(
+        np.random.normal(size=(len(datetimes), len(steps), len(channels), len(x), len(y))),
+        coords=dict(
+            init_time_utc=(["init_time_utc"], datetimes),
+            step=(["step"], steps),
+            channel=(["channel"], channels),
+            x_osgb=(["x_osgb"], x),
+            y_osgb=(["y_osgb"], y),
+        )
+    )
+    return da_nwp
+
+
+@pytest.mark.parametrize("t0_str", ["12:30", "12:40", "12:00"])
+def test_select_time_slice(da_sat_like, t0_str):
+    """Test the basic functionality of select_time_slice"""
+    # Slice parameters
+    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+    interval_start = pd.Timedelta(0, "min")
+    interval_end = pd.Timedelta(60, "min")
+    freq = pd.Timedelta("5min")
+
+    # Expect the selection to return these timestamps
+    expected_datetimes = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
+
+    # Make the selection
+    sat_sample = select_time_slice(
+        da_sat_like,
+        t0=t0,
+        interval_start=interval_start,
+        interval_end=interval_end,
+        sample_period_duration=freq,
+    )
+
+    # Check the returned times are as expected
+    assert (sat_sample.time_utc == expected_datetimes).all()
+
+
+@pytest.mark.parametrize("t0_str", ["00:00", "00:25", "11:00", "11:55"])
+def test_select_time_slice_out_of_bounds(da_sat_like, t0_str):
+    """Test the behaviour of select_time_slice when the selection is out of bounds"""
+    # Slice parameters
+    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+    interval_start = pd.Timedelta(-30, "min")
+    interval_end = pd.Timedelta(60, "min")
+    freq = pd.Timedelta("5min")
+
+    # The data is available between these times
+    min_time = da_sat_like.time_utc.min()
+    max_time = da_sat_like.time_utc.max()
+
+    # Expect the selection to return these timestamps
+    expected_datetimes = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
+
+    # Make the partially out-of-bounds selection
+    sat_sample = select_time_slice(
+        da_sat_like,
+        t0=t0,
+        interval_start=interval_start,
+        interval_end=interval_end,
+        sample_period_duration=freq,
+        fill_selection=True,
+    )
+
+    # Check the returned times are as expected
+    assert (sat_sample.time_utc == expected_datetimes).all()
+
+    # Check all the values before the first timestamp available in the data are NaN
+    all_nan_space = sat_sample.isnull().all(dim=("x_geostationary", "y_geostationary"))
+    if expected_datetimes[0] < min_time:
+        assert all_nan_space.sel(time_utc=slice(None, min_time - freq)).all(dim="time_utc")
+
+    # Check all the values after the last timestamp available in the data are NaN
+    if expected_datetimes[-1] > max_time:
+        assert all_nan_space.sel(time_utc=slice(max_time + freq, None)).all(dim="time_utc")
+
+    # Check that none of the values between the first and last available timestamps are NaN
+    any_nan_space = sat_sample.isnull().any(dim=("x_geostationary", "y_geostationary"))
+    assert not any_nan_space.sel(time_utc=slice(min_time, max_time)).any(dim="time_utc")
+
+
+@pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
+def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):
+    """Test the basic functionality of select_time_slice_nwp"""
+    # Slice parameters
+    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+    interval_start = pd.Timedelta(-6, "h")
+    interval_end = pd.Timedelta(3, "h")
+    freq = pd.Timedelta("1h")
+
+    # Make the selection
+    da_slice = select_time_slice_nwp(
+        da_nwp_like,
+        t0,
+        sample_period_duration=freq,
+        interval_start=interval_start,
+        interval_end=interval_end,
+        dropout_timedeltas=None,
+        dropout_frac=0,
+        accum_channels=[],
+        channel_dim_name="channel",
+    )
+
+    # Check the target times are as expected
+    expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
+    assert (da_slice.target_time_utc == expected_target_times).all()
+
+    # Check the init times are as expected
+    # - Forecasts are issued every `NWP_FREQ`, and we can't have selected future init times
+    expected_init_times = pd.to_datetime(
+        [t if t < t0 else t0 for t in expected_target_times]
+    ).floor(NWP_FREQ)
+    assert (da_slice.init_time_utc == expected_init_times).all()
+
+
+@pytest.mark.parametrize("dropout_hours", [1, 2, 5])
+def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
+    """Test the functionality of select_time_slice_nwp with dropout"""
+    t0 = pd.Timestamp("2024-01-02 12:00")
+    interval_start = pd.Timedelta(-6, "h")
+    interval_end = pd.Timedelta(3, "h")
+    freq = pd.Timedelta("1h")
+    dropout_timedelta = pd.Timedelta(f"-{dropout_hours}h")
+
+    da_slice = select_time_slice_nwp(
+        da_nwp_like,
+        t0,
+        sample_period_duration=freq,
+        interval_start=interval_start,
+        interval_end=interval_end,
+        dropout_timedeltas=[dropout_timedelta],
+        dropout_frac=1,
+        accum_channels=[],
+        channel_dim_name="channel",
+    )
+
+    # Check the target times are as expected
+    expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
+    assert (da_slice.target_time_utc == expected_target_times).all()
+
+    # Check the init times are as expected, allowing for the dropout delay
+    t0_delayed = t0 + dropout_timedelta
+    expected_init_times = pd.to_datetime(
+        [t if t < t0_delayed else t0_delayed for t in expected_target_times]
+    ).floor(NWP_FREQ)
+    assert (da_slice.init_time_utc == expected_init_times).all()
+
+
+@pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
+def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
+    """Test the functionality of select_time_slice_nwp with dropout and accumulated variables"""
+    # Slice parameters
+    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+    interval_start = pd.Timedelta(-6, "h")
+    interval_end = pd.Timedelta(3, "h")
+    freq = pd.Timedelta("1h")
+    dropout_timedelta = pd.Timedelta("-2h")
+
+    t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
+
+    da_slice = select_time_slice_nwp(
+        da_nwp_like,
+        t0,
+        sample_period_duration=freq,
+        interval_start=interval_start,
+        interval_end=interval_end,
+        dropout_timedeltas=[dropout_timedelta],
+        dropout_frac=1,
+        accum_channels=["dswrf"],
+        channel_dim_name="channel",
+    )
+
+    # Check the target times are as expected
+    expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
+    assert (da_slice.target_time_utc == expected_target_times).all()
+
+    # Check the init times are as expected, allowing for the dropout delay
+    expected_init_times = pd.to_datetime(
+        [t if t < t0_delayed else t0_delayed for t in expected_target_times]
+    ).floor(NWP_FREQ)
+    assert (da_slice.init_time_utc == expected_init_times).all()
+
+    # Check the channels are as expected
+    assert (da_slice.channel.values == ["t", "diff_dswrf"]).all()
+
+    # Check the accumulated channel has been differenced correctly
+
+    # This part of the data is pulled from the init time: t0_delayed
+    da_slice_accum = da_slice.sel(
+        target_time_utc=slice(t0_delayed, None),
+        channel="diff_dswrf",
+    )
+
+    # Get the original data for the t0_delayed init time, diff it along the step dim,
+    # then select the steps which are expected to be used in the above slice
+    da_orig_diffed = (
+        da_nwp_like.sel(
+            init_time_utc=t0_delayed,
+            channel="dswrf",
+        ).diff(dim="step", label="lower")
+        .sel(step=slice(t0 - t0_delayed + interval_start, t0 - t0_delayed + interval_end))
+    )
+
+    # Check the values are the same
+    assert (da_slice_accum.values == da_orig_diffed.values).all()
+
+    # Check the non-accumulated channel has not been differenced
+
+    # This part of the data is pulled from the init time: t0_delayed
+    da_slice_nonaccum = da_slice.sel(
+        target_time_utc=slice(t0_delayed, None),
+        channel="t",
+    )
+
+    # Get the original data for the t0_delayed init time, and select the steps which are
+    # expected to be used in the above slice
+    da_orig = (
+        da_nwp_like.sel(
+            init_time_utc=t0_delayed,
+            channel="t",
+        )
+        .sel(step=slice(t0 - t0_delayed + interval_start, t0 - t0_delayed + interval_end))
+    )
+
+    # Check the values are the same
+    assert (da_slice_nonaccum.values == da_orig.values).all()
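The expected-init-time construction in the NWP tests above deserves a gloss: each target time must come from the most recent forecast issued at or before t0 (or the dropout-delayed t0), and forecasts are only issued every NWP_FREQ, hence the clamp followed by a floor. A self-contained sketch mirroring the tests' own construction rather than the package internals:

import pandas as pd

NWP_FREQ = pd.Timedelta("3h")
t0 = pd.Timestamp("2024-01-02 11:00")

target_times = pd.date_range(t0 - pd.Timedelta("6h"), t0 + pd.Timedelta("3h"), freq="1h")
# Clamp future target times back to t0, then floor to the forecast issue frequency
init_times = pd.to_datetime([min(t, t0) for t in target_times]).floor(NWP_FREQ)

assert (init_times <= t0).all()
assert init_times[-1] == pd.Timestamp("2024-01-02 09:00")  # 11:00 floored to the 3h grid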
tests/torch_datasets/conftest.py
@@ -0,0 +1,18 @@
+import pytest
+
+from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
+
+
+@pytest.fixture()
+def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites):
+    # adjust config to point to the zarr file
+    config = load_yaml_configuration(config_filename)
+    config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
+    config.input_data.satellite.zarr_path = sat_zarr_path
+    config.input_data.site = data_sites
+    config.input_data.gsp = None
+
+    filename = f"{tmp_path}/configuration.yaml"
+    save_yaml_configuration(config, filename)
+    return filename
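pytest injects this fixture into any test that lists site_config_filename as a parameter; its consumers live in tests/torch_datasets/test_site.py, also added in this release. A hedged sketch of that pattern, in which the dataset class name is an assumption rather than something this diff confirms:

# Hypothetical consumer of the fixture above. "SitesDataset" is an assumed name
# for whatever ocf_data_sampler/torch_datasets/site.py actually exports.
def test_site_dataset_smoke(site_config_filename):
    from ocf_data_sampler.torch_datasets import SitesDataset  # assumption

    dataset = SitesDataset(site_config_filename)
    _ = dataset[0]  # draw one sample to exercise the whole pipeline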
tests/torch_datasets/test_process_and_combine.py
@@ -0,0 +1,126 @@
+import numpy as np
+import pandas as pd
+import xarray as xr
+import dask.array as da
+
+from ocf_data_sampler.config import load_yaml_configuration
+from ocf_data_sampler.select.location import Location
+from ocf_data_sampler.numpy_sample import NWPSampleKey, GSPSampleKey, SatelliteSampleKey
+from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
+
+from ocf_data_sampler.torch_datasets.process_and_combine import (
+    process_and_combine_datasets,
+    merge_dicts,
+    fill_nans_in_arrays,
+    compute,
+)
+
+
+def test_process_and_combine_datasets(pvnet_config_filename):
+    # Load in the config and define t0 and the location
+    config = load_yaml_configuration(pvnet_config_filename)
+    t0 = pd.Timestamp("2024-01-01 00:00")
+    location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
+
+    nwp_data = xr.DataArray(
+        np.random.rand(4, 2, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
+            "channel": ["t2m", "dswrf"],
+            "step": ("time_utc", pd.timedelta_range(start="0h", periods=4, freq="h")),
+            "init_time_utc": pd.Timestamp("2024-01-01 00:00"),
+        }
+    )
+
+    sat_data = xr.DataArray(
+        np.random.rand(7, 1, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
+            "channel": ["HRV"],
+            "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
+            "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]])),
+        }
+    )
+
+    # Combine into a dict
+    dataset_dict = {
+        "nwp": {"ukv": nwp_data},
+        "sat": sat_data,
+    }
+
+    # Call the function under test
+    result = process_and_combine_datasets(dataset_dict, config, t0, location)
+
+    # Check the result is a dict with the expected keys and shapes
+    assert isinstance(result, dict)
+    assert NWPSampleKey.nwp in result
+    assert result[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
+    assert result[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
+
+
+def test_merge_dicts():
+    """Test merge_dicts function"""
+    dict1 = {"a": 1, "b": 2}
+    dict2 = {"c": 3, "d": 4}
+    dict3 = {"e": 5}
+
+    result = merge_dicts([dict1, dict2, dict3])
+    assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
+
+    # Later dicts overwrite keys from earlier ones
+    dict4 = {"a": 10, "f": 6}
+    result = merge_dicts([dict1, dict4])
+    assert result["a"] == 10
+
+
+def test_fill_nans_in_arrays():
+    """Test the fill_nans_in_arrays function"""
+    array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
+    nested_dict = {
+        "array1": array_with_nans,
+        "nested": {
+            "array2": np.array([np.nan, 2.0, np.nan, 4.0]),
+        },
+        "string_key": "not_an_array",
+    }
+
+    result = fill_nans_in_arrays(nested_dict)
+
+    assert not np.isnan(result["array1"]).any()
+    assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
+    assert not np.isnan(result["nested"]["array2"]).any()
+    assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
+    assert result["string_key"] == "not_an_array"
+
+
+def test_compute():
+    """Test the compute function with a dask array"""
+    da_dask = xr.DataArray(da.random.random((5, 5)))
+
+    # Create a nested dictionary of dask-backed arrays
+    nested_dict = {
+        "array1": da_dask,
+        "nested": {
+            "array2": da_dask,
+        },
+    }
+
+    # Ensure the initial data is lazy, i.e. not yet computed
+    assert not isinstance(nested_dict["array1"].data, np.ndarray)
+    assert not isinstance(nested_dict["nested"]["array2"].data, np.ndarray)
+
+    # Call the compute function
+    result = compute(nested_dict)
+
+    # The results should be xarray DataArrays and no longer lazy
+    assert isinstance(result["array1"], xr.DataArray)
+    assert isinstance(result["nested"]["array2"], xr.DataArray)
+    assert isinstance(result["array1"].data, np.ndarray)
+    assert isinstance(result["nested"]["array2"].data, np.ndarray)
+
+    # Ensure there are no NaN values in the computed data
+    assert not np.isnan(result["array1"].data).any()
+    assert not np.isnan(result["nested"]["array2"].data).any()
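The assertions above pin down fill_nans_in_arrays exactly: NaNs become 0.0, nesting is preserved, and non-array values pass through untouched. A reference sketch consistent with those assertions (an illustration, not the package's actual implementation):

import numpy as np

def fill_nans_in_arrays_sketch(d: dict) -> dict:
    """Replace NaNs with 0.0 in every ndarray, recursing into nested dicts."""
    return {
        key: np.nan_to_num(value, nan=0.0) if isinstance(value, np.ndarray)
        else fill_nans_in_arrays_sketch(value) if isinstance(value, dict)
        else value
        for key, value in d.items()
    }

result = fill_nans_in_arrays_sketch({"a": np.array([1.0, np.nan]), "s": "kept"})
assert np.array_equal(result["a"], np.array([1.0, 0.0])) and result["s"] == "kept"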
tests/torch_datasets/test_pvnet_uk_regional.py
@@ -0,0 +1,59 @@
+import pytest
+import tempfile
+
+from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
+from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
+from ocf_data_sampler.numpy_sample import NWPSampleKey, GSPSampleKey, SatelliteSampleKey
+
+
+def test_pvnet(pvnet_config_filename):
+    # Create dataset object
+    dataset = PVNetUKRegionalDataset(pvnet_config_filename)
+
+    assert len(dataset.locations) == 317  # number of GSPs, not including the National level
+    # NB. I have not checked this value is in fact correct, but it does seem to stay constant
+    assert len(dataset.valid_t0_times) == 39
+    assert len(dataset) == 317 * 39
+
+    # Generate a sample
+    sample = dataset[0]
+
+    assert isinstance(sample, dict)
+
+    for key in [
+        NWPSampleKey.nwp, SatelliteSampleKey.satellite_actual, GSPSampleKey.gsp,
+        GSPSampleKey.solar_azimuth, GSPSampleKey.solar_elevation,
+    ]:
+        assert key in sample
+
+    for nwp_source in ["ukv"]:
+        assert nwp_source in sample[NWPSampleKey.nwp]
+
+    # Check the shape of the data is correct
+    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
+    assert sample[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
+    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
+    assert sample[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
+    # 3 hours of 30 minute data (inclusive)
+    assert sample[GSPSampleKey.gsp].shape == (7,)
+    # Solar angles have the same shape as the GSP data
+    assert sample[GSPSampleKey.solar_azimuth].shape == (7,)
+    assert sample[GSPSampleKey.solar_elevation].shape == (7,)
+
+
+def test_pvnet_no_gsp(pvnet_config_filename):
+    # Load the config and remove the GSP input
+    config = load_yaml_configuration(pvnet_config_filename)
+    config.input_data.gsp.zarr_path = ""
+
+    # Save a temporary config file
+    with tempfile.NamedTemporaryFile() as temp_config_file:
+        save_yaml_configuration(config, temp_config_file.name)
+        # Create dataset object
+        dataset = PVNetUKRegionalDataset(temp_config_file.name)
+
+        # Generate a sample
+        _ = dataset[0]