ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +86 -72
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/constants.py +140 -12
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +27 -36
- ocf_data_sampler/load/site.py +11 -7
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +15 -13
- ocf_data_sampler/numpy_sample/nwp.py +17 -23
- ocf_data_sampler/numpy_sample/satellite.py +17 -14
- ocf_data_sampler/numpy_sample/site.py +8 -7
- ocf_data_sampler/numpy_sample/sun_position.py +19 -25
- ocf_data_sampler/sample/__init__.py +0 -7
- ocf_data_sampler/sample/base.py +23 -44
- ocf_data_sampler/sample/site.py +25 -69
- ocf_data_sampler/sample/uk_regional.py +52 -103
- ocf_data_sampler/select/dropout.py +42 -27
- ocf_data_sampler/select/fill_time_periods.py +15 -3
- ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
- ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +62 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -286
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -52
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -75
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -195
- tests/test_sample/test_uk_regional_sample.py +0 -163
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -167
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
|
@@ -1,275 +0,0 @@
|
|
|
1
|
-
from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import pandas as pd
|
|
5
|
-
import xarray as xr
|
|
6
|
-
import pytest
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
NWP_FREQ = pd.Timedelta("3h")
|
|
10
|
-
|
|
11
|
-
@pytest.fixture(scope="module")
|
|
12
|
-
def da_sat_like():
|
|
13
|
-
"""Create dummy data which looks like satellite data"""
|
|
14
|
-
x = np.arange(-100, 100)
|
|
15
|
-
y = np.arange(-100, 100)
|
|
16
|
-
datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")
|
|
17
|
-
|
|
18
|
-
da_sat = xr.DataArray(
|
|
19
|
-
np.random.normal(size=(len(datetimes), len(x), len(y))),
|
|
20
|
-
coords=dict(
|
|
21
|
-
time_utc=(["time_utc"], datetimes),
|
|
22
|
-
x_geostationary=(["x_geostationary"], x),
|
|
23
|
-
y_geostationary=(["y_geostationary"], y),
|
|
24
|
-
)
|
|
25
|
-
)
|
|
26
|
-
return da_sat
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
@pytest.fixture(scope="module")
|
|
30
|
-
def da_nwp_like():
|
|
31
|
-
"""Create dummy data which looks like NWP data"""
|
|
32
|
-
|
|
33
|
-
x = np.arange(-100, 100)
|
|
34
|
-
y = np.arange(-100, 100)
|
|
35
|
-
datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq=NWP_FREQ)
|
|
36
|
-
steps = pd.timedelta_range("0h", "16h", freq="1h")
|
|
37
|
-
channels = ["t", "dswrf"]
|
|
38
|
-
|
|
39
|
-
da_nwp = xr.DataArray(
|
|
40
|
-
np.random.normal(size=(len(datetimes), len(steps), len(channels), len(x), len(y))),
|
|
41
|
-
coords=dict(
|
|
42
|
-
init_time_utc=(["init_time_utc"], datetimes),
|
|
43
|
-
step=(["step"], steps),
|
|
44
|
-
channel=(["channel"], channels),
|
|
45
|
-
x_osgb=(["x_osgb"], x),
|
|
46
|
-
y_osgb=(["y_osgb"], y),
|
|
47
|
-
)
|
|
48
|
-
)
|
|
49
|
-
return da_nwp
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@pytest.mark.parametrize("t0_str", ["12:30", "12:40", "12:00"])
|
|
53
|
-
def test_select_time_slice(da_sat_like, t0_str):
|
|
54
|
-
"""Test the basic functionality of select_time_slice"""
|
|
55
|
-
|
|
56
|
-
# Slice parameters
|
|
57
|
-
t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
|
|
58
|
-
interval_start = pd.Timedelta(-0, "min")
|
|
59
|
-
interval_end = pd.Timedelta(60, "min")
|
|
60
|
-
freq = pd.Timedelta("5min")
|
|
61
|
-
|
|
62
|
-
# Expect to return these timestamps from the selection
|
|
63
|
-
expected_datetimes = pd.date_range(t0 +interval_start, t0 + interval_end, freq=freq)
|
|
64
|
-
|
|
65
|
-
# Make the selection
|
|
66
|
-
sat_sample = select_time_slice(
|
|
67
|
-
da_sat_like,
|
|
68
|
-
t0=t0,
|
|
69
|
-
interval_start=interval_start,
|
|
70
|
-
interval_end=interval_end,
|
|
71
|
-
sample_period_duration=freq,
|
|
72
|
-
)
|
|
73
|
-
|
|
74
|
-
# Check the returned times are as expected
|
|
75
|
-
assert (sat_sample.time_utc == expected_datetimes).all()
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
@pytest.mark.parametrize("t0_str", ["00:00", "00:25", "11:00", "11:55"])
|
|
79
|
-
def test_select_time_slice_out_of_bounds(da_sat_like, t0_str):
|
|
80
|
-
"""Test the behaviour of select_time_slice when the selection is out of bounds"""
|
|
81
|
-
|
|
82
|
-
# Slice parameters
|
|
83
|
-
t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
|
|
84
|
-
interval_start = pd.Timedelta(-30, "min")
|
|
85
|
-
interval_end = pd.Timedelta(60, "min")
|
|
86
|
-
freq = pd.Timedelta("5min")
|
|
87
|
-
|
|
88
|
-
# The data is available between these times
|
|
89
|
-
min_time = pd.Timestamp(da_sat_like.time_utc.min().item())
|
|
90
|
-
max_time = pd.Timestamp(da_sat_like.time_utc.max().item())
|
|
91
|
-
|
|
92
|
-
# Expect to return these timestamps within the requested range
|
|
93
|
-
expected_datetimes = pd.date_range(
|
|
94
|
-
max(t0 + interval_start, min_time),
|
|
95
|
-
min(t0 + interval_end, max_time),
|
|
96
|
-
freq=freq,
|
|
97
|
-
)
|
|
98
|
-
|
|
99
|
-
# Make the partially out of bounds selection
|
|
100
|
-
sat_sample = select_time_slice(
|
|
101
|
-
da_sat_like,
|
|
102
|
-
t0=t0,
|
|
103
|
-
interval_start=interval_start,
|
|
104
|
-
interval_end=interval_end,
|
|
105
|
-
sample_period_duration=freq,
|
|
106
|
-
)
|
|
107
|
-
|
|
108
|
-
# Check the returned times are as expected
|
|
109
|
-
assert (sat_sample.time_utc == expected_datetimes).all()
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
# Check all the values before the first timestamp available in the data are NaN
|
|
113
|
-
all_nan_space = sat_sample.isnull().all(dim=("x_geostationary", "y_geostationary"))
|
|
114
|
-
if expected_datetimes[0] < min_time:
|
|
115
|
-
assert all_nan_space.sel(time_utc=slice(None, min_time-freq)).all(dim="time_utc")
|
|
116
|
-
|
|
117
|
-
# Check all the values after the last timestamp available in the data are NaN
|
|
118
|
-
if expected_datetimes[-1] > max_time:
|
|
119
|
-
assert all_nan_space.sel(time_utc=slice(max_time+freq, None)).all(dim="time_utc")
|
|
120
|
-
|
|
121
|
-
# Check that none of the values between the first and last available timestamp are NaN
|
|
122
|
-
any_nan_space = sat_sample.isnull().any(dim=("x_geostationary", "y_geostationary"))
|
|
123
|
-
assert not any_nan_space.sel(time_utc=slice(min_time, max_time)).any(dim="time_utc")
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
@pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
|
|
127
|
-
def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):
|
|
128
|
-
"""Test the basic functionality of select_time_slice_nwp"""
|
|
129
|
-
|
|
130
|
-
# Slice parameters
|
|
131
|
-
t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
|
|
132
|
-
interval_start = pd.Timedelta(-6, "h")
|
|
133
|
-
interval_end = pd.Timedelta(3, "h")
|
|
134
|
-
freq = pd.Timedelta("1h")
|
|
135
|
-
|
|
136
|
-
# Make the selection
|
|
137
|
-
da_slice = select_time_slice_nwp(
|
|
138
|
-
da_nwp_like,
|
|
139
|
-
t0,
|
|
140
|
-
sample_period_duration=freq,
|
|
141
|
-
interval_start=interval_start,
|
|
142
|
-
interval_end=interval_end,
|
|
143
|
-
dropout_timedeltas = None,
|
|
144
|
-
dropout_frac = 0,
|
|
145
|
-
accum_channels = [],
|
|
146
|
-
channel_dim_name = "channel",
|
|
147
|
-
)
|
|
148
|
-
|
|
149
|
-
# Check the target-times are as expected
|
|
150
|
-
expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
|
|
151
|
-
assert (da_slice.target_time_utc==expected_target_times).all()
|
|
152
|
-
|
|
153
|
-
# Check the init-times are as expected
|
|
154
|
-
# - Forecast frequency is `NWP_FREQ`, and we can't have selected future init-times
|
|
155
|
-
expected_init_times = pd.to_datetime(
|
|
156
|
-
[t if t<t0 else t0 for t in expected_target_times]
|
|
157
|
-
).floor(NWP_FREQ)
|
|
158
|
-
assert (da_slice.init_time_utc==expected_init_times).all()
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
@pytest.mark.parametrize("dropout_hours", [1, 2, 5])
|
|
162
|
-
def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
|
|
163
|
-
"""Test the functionality of select_time_slice_nwp with dropout"""
|
|
164
|
-
|
|
165
|
-
t0 = pd.Timestamp("2024-01-02 12:00")
|
|
166
|
-
interval_start = pd.Timedelta(-6, "h")
|
|
167
|
-
interval_end = pd.Timedelta(3, "h")
|
|
168
|
-
freq = pd.Timedelta("1h")
|
|
169
|
-
dropout_timedelta = pd.Timedelta(f"-{dropout_hours}h")
|
|
170
|
-
|
|
171
|
-
da_slice = select_time_slice_nwp(
|
|
172
|
-
da_nwp_like,
|
|
173
|
-
t0,
|
|
174
|
-
sample_period_duration=freq,
|
|
175
|
-
interval_start=interval_start,
|
|
176
|
-
interval_end=interval_end,
|
|
177
|
-
dropout_timedeltas = [dropout_timedelta],
|
|
178
|
-
dropout_frac = 1,
|
|
179
|
-
accum_channels = [],
|
|
180
|
-
channel_dim_name = "channel",
|
|
181
|
-
)
|
|
182
|
-
|
|
183
|
-
# Check the target-times are as expected
|
|
184
|
-
expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
|
|
185
|
-
assert (da_slice.target_time_utc==expected_target_times).all()
|
|
186
|
-
|
|
187
|
-
# Check the init-times are as expected considering the delay
|
|
188
|
-
t0_delayed = t0 + dropout_timedelta
|
|
189
|
-
expected_init_times = pd.to_datetime(
|
|
190
|
-
[t if t<t0_delayed else t0_delayed for t in expected_target_times]
|
|
191
|
-
).floor(NWP_FREQ)
|
|
192
|
-
assert (da_slice.init_time_utc==expected_init_times).all()
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
@pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
|
|
196
|
-
def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
|
|
197
|
-
"""Test the functionality of select_time_slice_nwp with dropout and accumulated variables"""
|
|
198
|
-
|
|
199
|
-
# Slice parameters
|
|
200
|
-
t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
|
|
201
|
-
interval_start = pd.Timedelta(-6, "h")
|
|
202
|
-
interval_end = pd.Timedelta(3, "h")
|
|
203
|
-
freq = pd.Timedelta("1h")
|
|
204
|
-
dropout_timedelta = pd.Timedelta("-2h")
|
|
205
|
-
|
|
206
|
-
t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
|
|
207
|
-
|
|
208
|
-
da_slice = select_time_slice_nwp(
|
|
209
|
-
da_nwp_like,
|
|
210
|
-
t0,
|
|
211
|
-
sample_period_duration=freq,
|
|
212
|
-
interval_start=interval_start,
|
|
213
|
-
interval_end=interval_end,
|
|
214
|
-
dropout_timedeltas=[dropout_timedelta],
|
|
215
|
-
dropout_frac=1,
|
|
216
|
-
accum_channels=["dswrf"],
|
|
217
|
-
channel_dim_name="channel",
|
|
218
|
-
)
|
|
219
|
-
|
|
220
|
-
# Check the target-times are as expected
|
|
221
|
-
expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
|
|
222
|
-
assert (da_slice.target_time_utc==expected_target_times).all()
|
|
223
|
-
|
|
224
|
-
# Check the init-times are as expected considering the delay
|
|
225
|
-
expected_init_times = pd.to_datetime(
|
|
226
|
-
[t if t<t0_delayed else t0_delayed for t in expected_target_times]
|
|
227
|
-
).floor(NWP_FREQ)
|
|
228
|
-
assert (da_slice.init_time_utc==expected_init_times).all()
|
|
229
|
-
|
|
230
|
-
# Check channels are as expected
|
|
231
|
-
assert (da_slice.channel.values == ["t", "diff_dswrf"]).all()
|
|
232
|
-
|
|
233
|
-
# Check the accumulated channel has been differenced correctly
|
|
234
|
-
|
|
235
|
-
# This part of the data is pulled from the init-time: t0_delayed
|
|
236
|
-
da_slice_accum = da_slice.sel(
|
|
237
|
-
target_time_utc=slice(t0_delayed, None),
|
|
238
|
-
channel="diff_dswrf"
|
|
239
|
-
)
|
|
240
|
-
|
|
241
|
-
# Get the original data for the t0_delayed init-time, and diff it along steps
|
|
242
|
-
# then select the steps which are expected to be used in the above slice
|
|
243
|
-
da_orig_diffed = (
|
|
244
|
-
da_nwp_like.sel(
|
|
245
|
-
init_time_utc=t0_delayed,
|
|
246
|
-
channel="dswrf",
|
|
247
|
-
).diff(dim="step", label="lower")
|
|
248
|
-
.sel(step=slice(t0-t0_delayed + interval_start, t0-t0_delayed + interval_end))
|
|
249
|
-
)
|
|
250
|
-
|
|
251
|
-
# Check the values are the same
|
|
252
|
-
assert (da_slice_accum.values == da_orig_diffed.values).all()
|
|
253
|
-
|
|
254
|
-
# Check the non-accumulated channel has not been differenced
|
|
255
|
-
|
|
256
|
-
# This part of the data is pulled from the init-time: t0_delayed
|
|
257
|
-
da_slice_nonaccum = da_slice.sel(
|
|
258
|
-
target_time_utc=slice(t0_delayed, None),
|
|
259
|
-
channel="t"
|
|
260
|
-
)
|
|
261
|
-
|
|
262
|
-
# Get the original data for the t0_delayed init-time, and select the steps which are expected
|
|
263
|
-
# to be used in the above slice
|
|
264
|
-
da_orig = (
|
|
265
|
-
da_nwp_like.sel(
|
|
266
|
-
init_time_utc=t0_delayed,
|
|
267
|
-
channel="t",
|
|
268
|
-
)
|
|
269
|
-
.sel(step=slice(t0-t0_delayed + interval_start, t0-t0_delayed + interval_end))
|
|
270
|
-
)
|
|
271
|
-
|
|
272
|
-
# Check the values are the same
|
|
273
|
-
assert (da_slice_nonaccum.values == da_orig.values).all()
|
|
274
|
-
|
|
275
|
-
|
tests/test_sample/test_base.py
DELETED
|
@@ -1,164 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Base class testing - SampleBase
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
import pytest
|
|
6
|
-
import torch
|
|
7
|
-
import numpy as np
|
|
8
|
-
|
|
9
|
-
from pathlib import Path
|
|
10
|
-
from ocf_data_sampler.sample.base import (
|
|
11
|
-
SampleBase,
|
|
12
|
-
batch_to_tensor,
|
|
13
|
-
copy_batch_to_device
|
|
14
|
-
)
|
|
15
|
-
|
|
16
|
-
class TestSample(SampleBase):
|
|
17
|
-
"""
|
|
18
|
-
SampleBase for testing purposes
|
|
19
|
-
Minimal implementations - abstract methods
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
def __init__(self):
|
|
23
|
-
super().__init__()
|
|
24
|
-
self._data = {}
|
|
25
|
-
|
|
26
|
-
def plot(self, **kwargs):
|
|
27
|
-
""" Minimal plot implementation """
|
|
28
|
-
return None
|
|
29
|
-
|
|
30
|
-
def to_numpy(self) -> None:
|
|
31
|
-
""" Standard implementation """
|
|
32
|
-
return {key: np.array(value) for key, value in self._data.items()}
|
|
33
|
-
|
|
34
|
-
def save(self, path):
|
|
35
|
-
""" Minimal save implementation """
|
|
36
|
-
path = Path(path)
|
|
37
|
-
with open(path, 'wb') as f:
|
|
38
|
-
f.write(b'test_data')
|
|
39
|
-
|
|
40
|
-
@classmethod
|
|
41
|
-
def load(cls, path):
|
|
42
|
-
""" Minimal load implementation """
|
|
43
|
-
path = Path(path)
|
|
44
|
-
instance = cls()
|
|
45
|
-
return instance
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def test_sample_base_initialisation():
|
|
49
|
-
""" Initialisation of SampleBase subclass """
|
|
50
|
-
|
|
51
|
-
sample = TestSample()
|
|
52
|
-
assert hasattr(sample, '_data'), "Sample should have _data attribute"
|
|
53
|
-
assert sample._data == {}, "Sample should start with empty dict"
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def test_sample_base_save_load(tmp_path):
|
|
57
|
-
""" Test basic save and load functionality """
|
|
58
|
-
|
|
59
|
-
sample = TestSample()
|
|
60
|
-
sample._data['test_data'] = [1, 2, 3]
|
|
61
|
-
|
|
62
|
-
save_path = tmp_path / 'test_sample.dat'
|
|
63
|
-
sample.save(save_path)
|
|
64
|
-
assert save_path.exists()
|
|
65
|
-
|
|
66
|
-
loaded_sample = TestSample.load(save_path)
|
|
67
|
-
assert isinstance(loaded_sample, TestSample)
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def test_sample_base_abstract_methods():
|
|
71
|
-
""" Test abstract method enforcement """
|
|
72
|
-
|
|
73
|
-
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
|
74
|
-
SampleBase()
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def test_sample_base_to_numpy():
|
|
78
|
-
""" Test the to_numpy functionality """
|
|
79
|
-
import numpy as np
|
|
80
|
-
|
|
81
|
-
sample = TestSample()
|
|
82
|
-
sample._data = {
|
|
83
|
-
'int_data': 42,
|
|
84
|
-
'list_data': [1, 2, 3]
|
|
85
|
-
}
|
|
86
|
-
numpy_data = sample.to_numpy()
|
|
87
|
-
|
|
88
|
-
assert isinstance(numpy_data, dict)
|
|
89
|
-
assert all(isinstance(value, np.ndarray) for value in numpy_data.values())
|
|
90
|
-
assert np.array_equal(numpy_data['list_data'], np.array([1, 2, 3]))
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def test_batch_to_tensor_nested():
|
|
94
|
-
""" Test nested dictionary conversion """
|
|
95
|
-
batch = {
|
|
96
|
-
'outer': {
|
|
97
|
-
'inner': np.array([1, 2, 3])
|
|
98
|
-
}
|
|
99
|
-
}
|
|
100
|
-
tensor_batch = batch_to_tensor(batch)
|
|
101
|
-
|
|
102
|
-
assert torch.equal(tensor_batch['outer']['inner'], torch.tensor([1, 2, 3]))
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
def test_batch_to_tensor_mixed_types():
|
|
106
|
-
""" Test handling of mixed data types """
|
|
107
|
-
batch = {
|
|
108
|
-
'tensor_data': np.array([1, 2, 3]),
|
|
109
|
-
'string_data': 'not_a_tensor',
|
|
110
|
-
'nested': {
|
|
111
|
-
'numbers': np.array([4, 5, 6]),
|
|
112
|
-
'text': 'still_not_a_tensor'
|
|
113
|
-
}
|
|
114
|
-
}
|
|
115
|
-
tensor_batch = batch_to_tensor(batch)
|
|
116
|
-
|
|
117
|
-
assert isinstance(tensor_batch['tensor_data'], torch.Tensor)
|
|
118
|
-
assert isinstance(tensor_batch['string_data'], str)
|
|
119
|
-
assert isinstance(tensor_batch['nested']['numbers'], torch.Tensor)
|
|
120
|
-
assert isinstance(tensor_batch['nested']['text'], str)
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
def test_batch_to_tensor_different_dtypes():
|
|
124
|
-
""" Test conversion of arrays with different dtypes """
|
|
125
|
-
batch = {
|
|
126
|
-
'float_data': np.array([1.0, 2.0, 3.0], dtype=np.float32),
|
|
127
|
-
'int_data': np.array([1, 2, 3], dtype=np.int64),
|
|
128
|
-
'bool_data': np.array([True, False, True], dtype=np.bool_)
|
|
129
|
-
}
|
|
130
|
-
tensor_batch = batch_to_tensor(batch)
|
|
131
|
-
|
|
132
|
-
assert isinstance(tensor_batch['bool_data'], torch.Tensor)
|
|
133
|
-
assert tensor_batch['float_data'].dtype == torch.float32
|
|
134
|
-
assert tensor_batch['int_data'].dtype == torch.int64
|
|
135
|
-
assert tensor_batch['bool_data'].dtype == torch.bool
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
def test_batch_to_tensor_multidimensional():
|
|
139
|
-
""" Test conversion of multidimensional arrays """
|
|
140
|
-
batch = {
|
|
141
|
-
'matrix': np.array([[1, 2], [3, 4]]),
|
|
142
|
-
'tensor': np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
|
|
143
|
-
}
|
|
144
|
-
tensor_batch = batch_to_tensor(batch)
|
|
145
|
-
|
|
146
|
-
assert tensor_batch['matrix'].shape == (2, 2)
|
|
147
|
-
assert tensor_batch['tensor'].shape == (2, 2, 2)
|
|
148
|
-
assert torch.equal(tensor_batch['matrix'], torch.tensor([[1, 2], [3, 4]]))
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
def test_copy_batch_to_device():
|
|
152
|
-
""" Test moving tensors to a different device """
|
|
153
|
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
154
|
-
batch = {
|
|
155
|
-
'tensor_data': torch.tensor([1, 2, 3]),
|
|
156
|
-
'nested': {
|
|
157
|
-
'matrix': torch.tensor([[1, 2], [3, 4]])
|
|
158
|
-
},
|
|
159
|
-
'non_tensor': 'unchanged'
|
|
160
|
-
}
|
|
161
|
-
moved_batch = copy_batch_to_device(batch, device)
|
|
162
|
-
assert moved_batch['tensor_data'].device == device
|
|
163
|
-
assert moved_batch['nested']['matrix'].device == device
|
|
164
|
-
assert moved_batch['non_tensor'] == 'unchanged' # Non-tensors should remain unchanged
|
|
@@ -1,195 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Site class testing - SiteSample
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
import pytest
|
|
6
|
-
import numpy as np
|
|
7
|
-
import xarray as xr
|
|
8
|
-
import pandas as pd
|
|
9
|
-
|
|
10
|
-
from pathlib import Path
|
|
11
|
-
from xarray import Dataset
|
|
12
|
-
|
|
13
|
-
from ocf_data_sampler.sample.site import SiteSample
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
@pytest.fixture
|
|
17
|
-
def sample_data():
|
|
18
|
-
""" Fixture creation with sample data """
|
|
19
|
-
|
|
20
|
-
# Time periods specified
|
|
21
|
-
init_time = pd.Timestamp("2023-01-01 00:00")
|
|
22
|
-
target_times = pd.date_range("2023-01-01 00:00", periods=4, freq="1h")
|
|
23
|
-
sat_times = pd.date_range("2023-01-01 00:00", periods=7, freq="5min")
|
|
24
|
-
site_times = pd.date_range("2023-01-01 00:00", periods=4, freq="30min")
|
|
25
|
-
|
|
26
|
-
# Defined steps for NWP data
|
|
27
|
-
steps = [(t - init_time) for t in target_times]
|
|
28
|
-
|
|
29
|
-
# Create the sample dataset
|
|
30
|
-
return Dataset(
|
|
31
|
-
data_vars={
|
|
32
|
-
'nwp-ukv': (
|
|
33
|
-
["nwp-ukv__target_time_utc", "nwp-ukv__channel",
|
|
34
|
-
"nwp-ukv__y_osgb", "nwp-ukv__x_osgb"],
|
|
35
|
-
np.random.rand(4, 1, 2, 2)
|
|
36
|
-
),
|
|
37
|
-
'satellite': (
|
|
38
|
-
["satellite__time_utc", "satellite__channel",
|
|
39
|
-
"satellite__y_geostationary", "satellite__x_geostationary"],
|
|
40
|
-
np.random.rand(7, 1, 2, 2)
|
|
41
|
-
),
|
|
42
|
-
'site': (["site__time_utc"], np.random.rand(4))
|
|
43
|
-
},
|
|
44
|
-
coords={
|
|
45
|
-
# NWP coords
|
|
46
|
-
'nwp-ukv__target_time_utc': target_times,
|
|
47
|
-
'nwp-ukv__channel': ['dswrf'],
|
|
48
|
-
'nwp-ukv__y_osgb': [0, 1],
|
|
49
|
-
'nwp-ukv__x_osgb': [0, 1],
|
|
50
|
-
'nwp-ukv__init_time_utc': init_time,
|
|
51
|
-
'nwp-ukv__step': ('nwp-ukv__target_time_utc', steps),
|
|
52
|
-
|
|
53
|
-
# Sat coords
|
|
54
|
-
'satellite__time_utc': sat_times,
|
|
55
|
-
'satellite__channel': ['vis'],
|
|
56
|
-
'satellite__y_geostationary': [0, 1],
|
|
57
|
-
'satellite__x_geostationary': [0, 1],
|
|
58
|
-
|
|
59
|
-
# Site coords
|
|
60
|
-
'site__time_utc': site_times,
|
|
61
|
-
'site__capacity_kwp': 1000.0,
|
|
62
|
-
'site__site_id': 1,
|
|
63
|
-
'site__longitude': -3.5,
|
|
64
|
-
'site__latitude': 51.5,
|
|
65
|
-
|
|
66
|
-
# Site features as coords
|
|
67
|
-
'site__solar_azimuth': ('site__time_utc', np.random.rand(4)),
|
|
68
|
-
'site__solar_elevation': ('site__time_utc', np.random.rand(4)),
|
|
69
|
-
'site__date_cos': ('site__time_utc', np.random.rand(4)),
|
|
70
|
-
'site__date_sin': ('site__time_utc', np.random.rand(4)),
|
|
71
|
-
'site__time_cos': ('site__time_utc', np.random.rand(4)),
|
|
72
|
-
'site__time_sin': ('site__time_utc', np.random.rand(4))
|
|
73
|
-
}
|
|
74
|
-
)
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def test_site_sample_init():
|
|
78
|
-
""" Test initialisation """
|
|
79
|
-
sample = SiteSample()
|
|
80
|
-
assert isinstance(sample._data, dict)
|
|
81
|
-
assert len(sample._data) == 0
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def test_site_sample_with_data(sample_data):
|
|
85
|
-
""" Testing of defined sample with actual data """
|
|
86
|
-
sample = SiteSample()
|
|
87
|
-
sample._data = sample_data
|
|
88
|
-
|
|
89
|
-
# Assert data structure
|
|
90
|
-
assert isinstance(sample._data, Dataset)
|
|
91
|
-
|
|
92
|
-
# Assert dimensions / shapes
|
|
93
|
-
expected_dims = {
|
|
94
|
-
"satellite__x_geostationary",
|
|
95
|
-
"site__time_utc",
|
|
96
|
-
"nwp-ukv__target_time_utc",
|
|
97
|
-
"nwp-ukv__x_osgb",
|
|
98
|
-
"satellite__channel",
|
|
99
|
-
"satellite__y_geostationary",
|
|
100
|
-
"satellite__time_utc",
|
|
101
|
-
"nwp-ukv__channel",
|
|
102
|
-
"nwp-ukv__y_osgb",
|
|
103
|
-
}
|
|
104
|
-
assert set(sample._data.dims) == expected_dims
|
|
105
|
-
assert sample._data["satellite"].values.shape == (7, 1, 2, 2)
|
|
106
|
-
assert sample._data["nwp-ukv"].values.shape == (4, 1, 2, 2)
|
|
107
|
-
assert sample._data["site"].values.shape == (4,)
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
def test_save_load(tmp_path, sample_data):
|
|
111
|
-
""" Save and load functionality """
|
|
112
|
-
sample = SiteSample()
|
|
113
|
-
sample._data = sample_data
|
|
114
|
-
filepath = tmp_path / "test_sample.nc"
|
|
115
|
-
sample.save(filepath)
|
|
116
|
-
|
|
117
|
-
# Assert file exists and has content
|
|
118
|
-
assert filepath.exists()
|
|
119
|
-
assert filepath.stat().st_size > 0
|
|
120
|
-
|
|
121
|
-
# Load and verify
|
|
122
|
-
loaded = SiteSample.load(filepath)
|
|
123
|
-
assert isinstance(loaded, SiteSample)
|
|
124
|
-
assert isinstance(loaded._data, Dataset)
|
|
125
|
-
|
|
126
|
-
# Compare original / loaded data
|
|
127
|
-
xr.testing.assert_identical(sample._data, loaded._data)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
def test_invalid_save_format(sample_data):
|
|
131
|
-
""" Saving with invalid format """
|
|
132
|
-
sample = SiteSample()
|
|
133
|
-
sample._data = sample_data
|
|
134
|
-
with pytest.raises(ValueError, match="Only .nc format is supported"):
|
|
135
|
-
sample.save("invalid.txt")
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
def test_invalid_load_format():
|
|
139
|
-
""" Loading with invalid format """
|
|
140
|
-
with pytest.raises(ValueError, match="Only .nc format is supported"):
|
|
141
|
-
SiteSample.load("invalid.txt")
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def test_invalid_data_type():
|
|
145
|
-
""" Handling of invalid data types """
|
|
146
|
-
sample = SiteSample()
|
|
147
|
-
sample._data = {"invalid": "data"}
|
|
148
|
-
|
|
149
|
-
with pytest.raises(TypeError, match="Data must be xarray Dataset"):
|
|
150
|
-
sample.to_numpy()
|
|
151
|
-
|
|
152
|
-
with pytest.raises(TypeError, match="Data must be xarray Dataset for saving"):
|
|
153
|
-
sample.save("test.nc")
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
def test_to_numpy(sample_data):
|
|
157
|
-
""" To numpy conversion """
|
|
158
|
-
sample = SiteSample()
|
|
159
|
-
sample._data = sample_data
|
|
160
|
-
numpy_data = sample.to_numpy()
|
|
161
|
-
|
|
162
|
-
# Assert structure
|
|
163
|
-
assert isinstance(numpy_data, dict)
|
|
164
|
-
assert 'site' in numpy_data
|
|
165
|
-
assert 'nwp' in numpy_data
|
|
166
|
-
|
|
167
|
-
# Check site - numpy array instead of dict
|
|
168
|
-
site_data = numpy_data['site']
|
|
169
|
-
assert isinstance(site_data, np.ndarray)
|
|
170
|
-
assert site_data.ndim == 1
|
|
171
|
-
assert len(site_data) == 4
|
|
172
|
-
assert np.all(site_data >= 0) and np.all(site_data <= 1)
|
|
173
|
-
|
|
174
|
-
# Check NWP
|
|
175
|
-
assert 'ukv' in numpy_data['nwp']
|
|
176
|
-
nwp_data = numpy_data['nwp']['ukv']
|
|
177
|
-
assert 'nwp' in nwp_data
|
|
178
|
-
assert nwp_data['nwp'].shape == (4, 1, 2, 2)
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
def test_data_consistency(sample_data):
|
|
182
|
-
""" Consistency of data across operations """
|
|
183
|
-
sample = SiteSample()
|
|
184
|
-
sample._data = sample_data
|
|
185
|
-
numpy_data = sample.to_numpy()
|
|
186
|
-
|
|
187
|
-
# Assert components remain consistent after conversion above
|
|
188
|
-
assert numpy_data['nwp']['ukv']['nwp'].shape == (4, 1, 2, 2)
|
|
189
|
-
assert 'site' in numpy_data
|
|
190
|
-
|
|
191
|
-
# Update site data checks to expect numpy array
|
|
192
|
-
assert isinstance(numpy_data['site'], np.ndarray)
|
|
193
|
-
assert numpy_data['site'].shape == (4,)
|
|
194
|
-
assert np.all(numpy_data['site'] >= 0)
|
|
195
|
-
assert np.all(numpy_data['site'] <= 1)
|