ocf-data-sampler 0.0.19__py3-none-any.whl → 0.0.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ocf-data-sampler has been flagged as possibly problematic.

Files changed (64):
  1. ocf_data_sampler/config/__init__.py +5 -0
  2. ocf_data_sampler/config/load.py +33 -0
  3. ocf_data_sampler/config/model.py +246 -0
  4. ocf_data_sampler/config/save.py +73 -0
  5. ocf_data_sampler/constants.py +173 -0
  6. ocf_data_sampler/load/load_dataset.py +55 -0
  7. ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
  8. ocf_data_sampler/load/site.py +30 -0
  9. ocf_data_sampler/numpy_sample/__init__.py +8 -0
  10. ocf_data_sampler/numpy_sample/collate.py +75 -0
  11. ocf_data_sampler/numpy_sample/gsp.py +34 -0
  12. ocf_data_sampler/numpy_sample/nwp.py +42 -0
  13. ocf_data_sampler/numpy_sample/satellite.py +30 -0
  14. ocf_data_sampler/numpy_sample/site.py +30 -0
  15. ocf_data_sampler/{numpy_batch → numpy_sample}/sun_position.py +9 -10
  16. ocf_data_sampler/select/__init__.py +8 -1
  17. ocf_data_sampler/select/dropout.py +4 -3
  18. ocf_data_sampler/select/find_contiguous_time_periods.py +40 -75
  19. ocf_data_sampler/select/geospatial.py +160 -0
  20. ocf_data_sampler/select/location.py +62 -0
  21. ocf_data_sampler/select/select_spatial_slice.py +13 -16
  22. ocf_data_sampler/select/select_time_slice.py +24 -33
  23. ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
  24. ocf_data_sampler/select/time_slice_for_dataset.py +125 -0
  25. ocf_data_sampler/torch_datasets/__init__.py +2 -1
  26. ocf_data_sampler/torch_datasets/process_and_combine.py +131 -0
  27. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +11 -425
  28. ocf_data_sampler/torch_datasets/site.py +405 -0
  29. ocf_data_sampler/torch_datasets/valid_time_periods.py +116 -0
  30. ocf_data_sampler/utils.py +10 -0
  31. ocf_data_sampler-0.0.43.dist-info/METADATA +154 -0
  32. ocf_data_sampler-0.0.43.dist-info/RECORD +71 -0
  33. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/WHEEL +1 -1
  34. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/top_level.txt +1 -0
  35. scripts/refactor_site.py +50 -0
  36. tests/config/test_config.py +161 -0
  37. tests/config/test_save.py +37 -0
  38. tests/conftest.py +86 -1
  39. tests/load/test_load_gsp.py +15 -0
  40. tests/load/test_load_nwp.py +21 -0
  41. tests/load/test_load_satellite.py +17 -0
  42. tests/load/test_load_sites.py +14 -0
  43. tests/numpy_sample/test_collate.py +26 -0
  44. tests/numpy_sample/test_gsp.py +38 -0
  45. tests/numpy_sample/test_nwp.py +52 -0
  46. tests/numpy_sample/test_satellite.py +40 -0
  47. tests/numpy_sample/test_sun_position.py +81 -0
  48. tests/select/test_dropout.py +75 -0
  49. tests/select/test_fill_time_periods.py +28 -0
  50. tests/select/test_find_contiguous_time_periods.py +202 -0
  51. tests/select/test_location.py +67 -0
  52. tests/select/test_select_spatial_slice.py +154 -0
  53. tests/select/test_select_time_slice.py +272 -0
  54. tests/torch_datasets/conftest.py +18 -0
  55. tests/torch_datasets/test_process_and_combine.py +126 -0
  56. tests/torch_datasets/test_pvnet_uk_regional.py +59 -0
  57. tests/torch_datasets/test_site.py +129 -0
  58. ocf_data_sampler/numpy_batch/__init__.py +0 -7
  59. ocf_data_sampler/numpy_batch/gsp.py +0 -20
  60. ocf_data_sampler/numpy_batch/nwp.py +0 -33
  61. ocf_data_sampler/numpy_batch/satellite.py +0 -23
  62. ocf_data_sampler-0.0.19.dist-info/METADATA +0 -22
  63. ocf_data_sampler-0.0.19.dist-info/RECORD +0 -32
  64. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.43.dist-info}/LICENSE +0 -0
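
The headline change in this release is the rename of the `numpy_batch` subpackage to `numpy_sample` (files 9-15 and 58-61 above), alongside a major refactor of the torch datasets. A hedged sketch of what this means for downstream imports - the new names below are taken from the import lines in the test files later in this diff, not independently verified against the wheel:

# ocf_data_sampler.numpy_batch is removed in 0.0.43, so any
# `from ocf_data_sampler.numpy_batch import ...` in code written against
# 0.0.19 now raises ModuleNotFoundError. The replacement module is
# numpy_sample; these key classes are the ones imported by the new tests below.
from ocf_data_sampler.numpy_sample import NWPSampleKey, GSPSampleKey, SatelliteSampleKey

print(NWPSampleKey.nwp, GSPSampleKey.gsp, SatelliteSampleKey.satellite_actual)
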
tests/select/test_select_spatial_slice.py
@@ -0,0 +1,154 @@
+ import numpy as np
+ import pytest
+ import xarray as xr
+
+ from ocf_data_sampler.select.location import Location
+ from ocf_data_sampler.select.select_spatial_slice import (
+     select_spatial_slice_pixels,
+     _get_idx_of_pixel_closest_to_poi,
+ )
+
+
+ @pytest.fixture(scope="module")
+ def da():
+     # Create dummy data
+     x = np.arange(-100, 100)
+     y = np.arange(-100, 100)
+
+     da = xr.DataArray(
+         np.random.normal(size=(len(x), len(y))),
+         coords=dict(
+             x_osgb=(["x_osgb"], x),
+             y_osgb=(["y_osgb"], y),
+         ),
+     )
+     return da
+
+
+ def test_get_idx_of_pixel_closest_to_poi(da):
+     idx_location = _get_idx_of_pixel_closest_to_poi(
+         da,
+         location=Location(x=10, y=10, coordinate_system="osgb"),
+     )
+
+     assert idx_location.coordinate_system == "idx"
+     assert idx_location.x == 110
+     assert idx_location.y == 110
+
+
+ def test_select_spatial_slice_pixels(da):
+     # Select a window which lies entirely within the x-y bounds of the data
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=-90, y=-80, coordinate_system="osgb"),
+         width_pixels=10,
+         height_pixels=10,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(-95, -85)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-85, -75)).all()
+     # No padding in this case, so no NaNs
+     assert not da_sliced.isnull().any()
+
+     # Select a window whose edge lies exactly on the edge of the data
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=-90, y=-80, coordinate_system="osgb"),
+         width_pixels=20,
+         height_pixels=20,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-90, -70)).all()
+     # No padding in this case, so no NaNs
+     assert not da_sliced.isnull().any()
+
+     # Select a window which is partially outside the boundary of the data - padded on left
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=-90, y=-80, coordinate_system="osgb"),
+         width_pixels=30,
+         height_pixels=30,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(-105, -75)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-95, -65)).all()
+     # Data has been padded on the left by 5 NaN pixels
+     assert da_sliced.isnull().sum() == 5 * len(da_sliced.y_osgb)
+
+     # Select a window which is partially outside the boundary of the data - padded on right
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=90, y=-80, coordinate_system="osgb"),
+         width_pixels=30,
+         height_pixels=30,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(75, 105)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-95, -65)).all()
+     # Data has been padded on the right by 5 NaN pixels
+     assert da_sliced.isnull().sum() == 5 * len(da_sliced.y_osgb)
+
+     # Select a window which is partially outside the boundary of the data - padded on top
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=-90, y=95, coordinate_system="osgb"),
+         width_pixels=20,
+         height_pixels=20,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
+     assert (da_sliced.y_osgb.values == np.arange(85, 105)).all()
+     # Data has been padded on top by 5 NaN pixels
+     assert da_sliced.isnull().sum() == 5 * len(da_sliced.x_osgb)
+
+     # Select a window which is partially outside the boundary of the data - padded on bottom
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=-90, y=-95, coordinate_system="osgb"),
+         width_pixels=20,
+         height_pixels=20,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-105, -85)).all()
+     # Data has been padded on the bottom by 5 NaN pixels
+     assert da_sliced.isnull().sum() == 5 * len(da_sliced.x_osgb)
+
+     # Select a window which is partially outside the boundary of the data - padded on right and bottom
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=90, y=-80, coordinate_system="osgb"),
+         width_pixels=50,
+         height_pixels=50,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(65, 115)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-105, -55)).all()
+     # Data has been padded on the right by 15 pixels and on the bottom by 5 NaN pixels
+     assert da_sliced.isnull().sum() == 15 * len(da_sliced.y_osgb) + 5 * len(da_sliced.x_osgb) - 15 * 5
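
For orientation, a minimal standalone sketch of the padding behaviour these tests exercise. The call signature is exactly the one used above and the data setup mirrors the `da` fixture; this is an illustration, not part of the diff:

import numpy as np
import xarray as xr

from ocf_data_sampler.select.location import Location
from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels

# Dummy data on a 200x200 OSGB grid, mirroring the fixture above
x = np.arange(-100, 100)
y = np.arange(-100, 100)
da = xr.DataArray(
    np.random.normal(size=(len(x), len(y))),
    coords=dict(x_osgb=(["x_osgb"], x), y_osgb=(["y_osgb"], y)),
)

# A 30x30-pixel window centred near the left edge of the data; with
# allow_partial_slice=True, the 5 out-of-bounds columns come back as NaN padding
window = select_spatial_slice_pixels(
    da,
    location=Location(x=-90, y=-80, coordinate_system="osgb"),
    width_pixels=30,
    height_pixels=30,
    allow_partial_slice=True,
)
assert window.shape == (30, 30)
assert int(window.isnull().sum()) == 5 * 30  # 5 padded columns of 30 pixels each
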
tests/select/test_select_time_slice.py
@@ -0,0 +1,272 @@
+ import numpy as np
+ import pandas as pd
+ import pytest
+ import xarray as xr
+
+ from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
+
+
+ NWP_FREQ = pd.Timedelta("3h")
+
+
+ @pytest.fixture(scope="module")
+ def da_sat_like():
+     """Create dummy data which looks like satellite data"""
+     x = np.arange(-100, 100)
+     y = np.arange(-100, 100)
+     datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")
+
+     da_sat = xr.DataArray(
+         np.random.normal(size=(len(datetimes), len(x), len(y))),
+         coords=dict(
+             time_utc=(["time_utc"], datetimes),
+             x_geostationary=(["x_geostationary"], x),
+             y_geostationary=(["y_geostationary"], y),
+         ),
+     )
+     return da_sat
+
+
+ @pytest.fixture(scope="module")
+ def da_nwp_like():
+     """Create dummy data which looks like NWP data"""
+     x = np.arange(-100, 100)
+     y = np.arange(-100, 100)
+     datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq=NWP_FREQ)
+     steps = pd.timedelta_range("0h", "16h", freq="1h")
+     channels = ["t", "dswrf"]
+
+     da_nwp = xr.DataArray(
+         np.random.normal(size=(len(datetimes), len(steps), len(channels), len(x), len(y))),
+         coords=dict(
+             init_time_utc=(["init_time_utc"], datetimes),
+             step=(["step"], steps),
+             channel=(["channel"], channels),
+             x_osgb=(["x_osgb"], x),
+             y_osgb=(["y_osgb"], y),
+         ),
+     )
+     return da_nwp
+
+
+ @pytest.mark.parametrize("t0_str", ["12:30", "12:40", "12:00"])
+ def test_select_time_slice(da_sat_like, t0_str):
+     """Test the basic functionality of select_time_slice"""
+     # Slice parameters
+     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+     interval_start = pd.Timedelta(0, "min")
+     interval_end = pd.Timedelta(60, "min")
+     freq = pd.Timedelta("5min")
+
+     # Expect the selection to return these timestamps
+     expected_datetimes = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
+
+     # Make the selection
+     sat_sample = select_time_slice(
+         da_sat_like,
+         t0=t0,
+         interval_start=interval_start,
+         interval_end=interval_end,
+         sample_period_duration=freq,
+     )
+
+     # Check the returned times are as expected
+     assert (sat_sample.time_utc == expected_datetimes).all()
+
+
+ @pytest.mark.parametrize("t0_str", ["00:00", "00:25", "11:00", "11:55"])
+ def test_select_time_slice_out_of_bounds(da_sat_like, t0_str):
+     """Test the behaviour of select_time_slice when the selection is out of bounds"""
+     # Slice parameters
+     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+     interval_start = pd.Timedelta(-30, "min")
+     interval_end = pd.Timedelta(60, "min")
+     freq = pd.Timedelta("5min")
+
+     # The data is available between these times
+     min_time = da_sat_like.time_utc.min()
+     max_time = da_sat_like.time_utc.max()
+
+     # Expect the selection to return these timestamps
+     expected_datetimes = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
+
+     # Make the partially out-of-bounds selection
+     sat_sample = select_time_slice(
+         da_sat_like,
+         t0=t0,
+         interval_start=interval_start,
+         interval_end=interval_end,
+         sample_period_duration=freq,
+         fill_selection=True,
+     )
+
+     # Check the returned times are as expected
+     assert (sat_sample.time_utc == expected_datetimes).all()
+
+     # Check all the values before the first timestamp available in the data are NaN
+     all_nan_space = sat_sample.isnull().all(dim=("x_geostationary", "y_geostationary"))
+     if expected_datetimes[0] < min_time:
+         assert all_nan_space.sel(time_utc=slice(None, min_time - freq)).all(dim="time_utc")
+
+     # Check all the values after the last timestamp available in the data are NaN
+     if expected_datetimes[-1] > max_time:
+         assert all_nan_space.sel(time_utc=slice(max_time + freq, None)).all(dim="time_utc")
+
+     # Check that none of the values between the first and last available timestamps are NaN
+     any_nan_space = sat_sample.isnull().any(dim=("x_geostationary", "y_geostationary"))
+     assert not any_nan_space.sel(time_utc=slice(min_time, max_time)).any(dim="time_utc")
+
+
+ @pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
+ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):
+     """Test the basic functionality of select_time_slice_nwp"""
+     # Slice parameters
+     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+     interval_start = pd.Timedelta(-6, "h")
+     interval_end = pd.Timedelta(3, "h")
+     freq = pd.Timedelta("1h")
+
+     # Make the selection
+     da_slice = select_time_slice_nwp(
+         da_nwp_like,
+         t0,
+         sample_period_duration=freq,
+         interval_start=interval_start,
+         interval_end=interval_end,
+         dropout_timedeltas=None,
+         dropout_frac=0,
+         accum_channels=[],
+         channel_dim_name="channel",
+     )
+
+     # Check the target times are as expected
+     expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
+     assert (da_slice.target_time_utc == expected_target_times).all()
+
+     # Check the init times are as expected
+     # - Forecast frequency is `NWP_FREQ`, and we can't have selected future init times
+     expected_init_times = pd.to_datetime(
+         [t if t < t0 else t0 for t in expected_target_times]
+     ).floor(NWP_FREQ)
+     assert (da_slice.init_time_utc == expected_init_times).all()
+
+
+ @pytest.mark.parametrize("dropout_hours", [1, 2, 5])
+ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
+     """Test the functionality of select_time_slice_nwp with dropout"""
+     t0 = pd.Timestamp("2024-01-02 12:00")
+     interval_start = pd.Timedelta(-6, "h")
+     interval_end = pd.Timedelta(3, "h")
+     freq = pd.Timedelta("1h")
+     dropout_timedelta = pd.Timedelta(f"-{dropout_hours}h")
+
+     da_slice = select_time_slice_nwp(
+         da_nwp_like,
+         t0,
+         sample_period_duration=freq,
+         interval_start=interval_start,
+         interval_end=interval_end,
+         dropout_timedeltas=[dropout_timedelta],
+         dropout_frac=1,
+         accum_channels=[],
+         channel_dim_name="channel",
+     )
+
+     # Check the target times are as expected
+     expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
+     assert (da_slice.target_time_utc == expected_target_times).all()
+
+     # Check the init times are as expected, considering the delay
+     t0_delayed = t0 + dropout_timedelta
+     expected_init_times = pd.to_datetime(
+         [t if t < t0_delayed else t0_delayed for t in expected_target_times]
+     ).floor(NWP_FREQ)
+     assert (da_slice.init_time_utc == expected_init_times).all()
+
+
+ @pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
+ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
+     """Test the functionality of select_time_slice_nwp with dropout and accumulated variables"""
+     # Slice parameters
+     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+     interval_start = pd.Timedelta(-6, "h")
+     interval_end = pd.Timedelta(3, "h")
+     freq = pd.Timedelta("1h")
+     dropout_timedelta = pd.Timedelta("-2h")
+
+     t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
+
+     da_slice = select_time_slice_nwp(
+         da_nwp_like,
+         t0,
+         sample_period_duration=freq,
+         interval_start=interval_start,
+         interval_end=interval_end,
+         dropout_timedeltas=[dropout_timedelta],
+         dropout_frac=1,
+         accum_channels=["dswrf"],
+         channel_dim_name="channel",
+     )
+
+     # Check the target times are as expected
+     expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
+     assert (da_slice.target_time_utc == expected_target_times).all()
+
+     # Check the init times are as expected, considering the delay
+     expected_init_times = pd.to_datetime(
+         [t if t < t0_delayed else t0_delayed for t in expected_target_times]
+     ).floor(NWP_FREQ)
+     assert (da_slice.init_time_utc == expected_init_times).all()
+
+     # Check the channels are as expected
+     assert (da_slice.channel.values == ["t", "diff_dswrf"]).all()
+
+     # Check the accumulated channel has been differenced correctly
+
+     # This part of the data is pulled from the init time: t0_delayed
+     da_slice_accum = da_slice.sel(
+         target_time_utc=slice(t0_delayed, None),
+         channel="diff_dswrf",
+     )
+
+     # Get the original data for the t0_delayed init time, diff it along step,
+     # then select the steps which are expected to be used in the slice above
+     da_orig_diffed = (
+         da_nwp_like.sel(
+             init_time_utc=t0_delayed,
+             channel="dswrf",
+         )
+         .diff(dim="step", label="lower")
+         .sel(step=slice(t0 - t0_delayed + interval_start, t0 - t0_delayed + interval_end))
+     )
+
+     # Check the values are the same
+     assert (da_slice_accum.values == da_orig_diffed.values).all()
+
+     # Check the non-accumulated channel has not been differenced
+
+     # This part of the data is pulled from the init time: t0_delayed
+     da_slice_nonaccum = da_slice.sel(
+         target_time_utc=slice(t0_delayed, None),
+         channel="t",
+     )
+
+     # Get the original data for the t0_delayed init time, and select the steps
+     # which are expected to be used in the slice above
+     da_orig = da_nwp_like.sel(
+         init_time_utc=t0_delayed,
+         channel="t",
+     ).sel(step=slice(t0 - t0_delayed + interval_start, t0 - t0_delayed + interval_end))
+
+     # Check the values are the same
+     assert (da_slice_nonaccum.values == da_orig.values).all()
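
A minimal standalone usage sketch of `select_time_slice`, mirroring the basic satellite test above (illustration only, not part of the diff; the signature is the one the tests use):

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.select.select_time_slice import select_time_slice

# Dummy satellite-like data, mirroring the da_sat_like fixture above
datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")
x = np.arange(-100, 100)
y = np.arange(-100, 100)
da_sat_like = xr.DataArray(
    np.random.normal(size=(len(datetimes), len(x), len(y))),
    coords=dict(
        time_utc=(["time_utc"], datetimes),
        x_geostationary=(["x_geostationary"], x),
        y_geostationary=(["y_geostationary"], y),
    ),
)

# Select the hour of data following midday: 13 timestamps at 5-min frequency,
# inclusive of both endpoints, as the basic test above asserts
t0 = pd.Timestamp("2024-01-02 12:00")
sample = select_time_slice(
    da_sat_like,
    t0=t0,
    interval_start=pd.Timedelta(0, "min"),
    interval_end=pd.Timedelta(60, "min"),
    sample_period_duration=pd.Timedelta("5min"),
)
assert len(sample.time_utc) == 13
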
tests/torch_datasets/conftest.py
@@ -0,0 +1,18 @@
+ import pytest
+
+ from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
+
+
+ @pytest.fixture()
+ def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites):
+     # Adjust the config to point to the zarr files
+     config = load_yaml_configuration(config_filename)
+     config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
+     config.input_data.satellite.zarr_path = sat_zarr_path
+     config.input_data.site = data_sites
+     config.input_data.gsp = None
+
+     filename = f"{tmp_path}/configuration.yaml"
+     save_yaml_configuration(config, filename)
+     return filename

tests/torch_datasets/test_process_and_combine.py
@@ -0,0 +1,126 @@
+ import dask.array as da
+ import numpy as np
+ import pandas as pd
+ import xarray as xr
+
+ from ocf_data_sampler.config import load_yaml_configuration
+ from ocf_data_sampler.select.location import Location
+ from ocf_data_sampler.numpy_sample import NWPSampleKey, SatelliteSampleKey
+ from ocf_data_sampler.torch_datasets.process_and_combine import (
+     process_and_combine_datasets,
+     merge_dicts,
+     fill_nans_in_arrays,
+     compute,
+ )
+
+
+ def test_process_and_combine_datasets(pvnet_config_filename):
+     # Load the config and define a t0 and location for the sample
+     config = load_yaml_configuration(pvnet_config_filename)
+     t0 = pd.Timestamp("2024-01-01 00:00")
+     location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
+
+     nwp_data = xr.DataArray(
+         np.random.rand(4, 2, 2, 2),
+         dims=["time_utc", "channel", "y", "x"],
+         coords={
+             "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
+             "channel": ["t2m", "dswrf"],
+             "step": ("time_utc", pd.timedelta_range(start="0h", periods=4, freq="h")),
+             "init_time_utc": pd.Timestamp("2024-01-01 00:00"),
+         },
+     )
+
+     sat_data = xr.DataArray(
+         np.random.rand(7, 1, 2, 2),
+         dims=["time_utc", "channel", "y", "x"],
+         coords={
+             "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
+             "channel": ["HRV"],
+             "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
+             "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]])),
+         },
+     )
+
+     # Combine into a dict
+     dataset_dict = {
+         "nwp": {"ukv": nwp_data},
+         "sat": sat_data,
+     }
+
+     # Call the function under test
+     result = process_and_combine_datasets(dataset_dict, config, t0, location)
+
+     # Validate the resulting dict
+     assert isinstance(result, dict)
+     assert NWPSampleKey.nwp in result
+     assert result[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
+     assert result[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
+
+
+ def test_merge_dicts():
+     """Test the merge_dicts function"""
+     dict1 = {"a": 1, "b": 2}
+     dict2 = {"c": 3, "d": 4}
+     dict3 = {"e": 5}
+
+     result = merge_dicts([dict1, dict2, dict3])
+     assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
+
+     # Test that later keys overwrite earlier ones
+     dict4 = {"a": 10, "f": 6}
+     result = merge_dicts([dict1, dict4])
+     assert result["a"] == 10
+
+
+ def test_fill_nans_in_arrays():
+     """Test the fill_nans_in_arrays function"""
+     array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
+     nested_dict = {
+         "array1": array_with_nans,
+         "nested": {
+             "array2": np.array([np.nan, 2.0, np.nan, 4.0]),
+         },
+         "string_key": "not_an_array",
+     }
+
+     result = fill_nans_in_arrays(nested_dict)
+
+     assert not np.isnan(result["array1"]).any()
+     assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
+     assert not np.isnan(result["nested"]["array2"]).any()
+     assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
+     assert result["string_key"] == "not_an_array"
+
+
+ def test_compute():
+     """Test the compute function with dask arrays"""
+     da_dask = xr.DataArray(da.random.random((5, 5)))
+
+     # Create a nested dictionary of dask-backed arrays
+     nested_dict = {
+         "array1": da_dask,
+         "nested": {
+             "array2": da_dask,
+         },
+     }
+
+     # Ensure the initial data is lazy - i.e. not yet computed
+     assert not isinstance(nested_dict["array1"].data, np.ndarray)
+     assert not isinstance(nested_dict["nested"]["array2"].data, np.ndarray)
+
+     # Call the compute function
+     result = compute(nested_dict)
+
+     # Assert that the results are xarray DataArrays and no longer lazy
+     assert isinstance(result["array1"], xr.DataArray)
+     assert isinstance(result["nested"]["array2"], xr.DataArray)
+     assert isinstance(result["array1"].data, np.ndarray)
+     assert isinstance(result["nested"]["array2"].data, np.ndarray)
+
+     # Ensure there are no NaN values in the computed data
+     assert not np.isnan(result["array1"].data).any()
+     assert not np.isnan(result["nested"]["array2"].data).any()
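
The tests above pin down the contract of `fill_nans_in_arrays`: walk a nested dict, zero-fill NaNs in numpy arrays, and pass other values through untouched. A minimal sketch consistent with that asserted behaviour - not the package's actual implementation:

import numpy as np

def fill_nans_in_arrays_sketch(d: dict) -> dict:
    # Recursively replace NaNs with 0.0 in any numpy array values;
    # non-array values (e.g. strings) pass through unchanged
    out = {}
    for key, value in d.items():
        if isinstance(value, np.ndarray):
            out[key] = np.nan_to_num(value, nan=0.0)
        elif isinstance(value, dict):
            out[key] = fill_nans_in_arrays_sketch(value)
        else:
            out[key] = value
    return out
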
tests/torch_datasets/test_pvnet_uk_regional.py
@@ -0,0 +1,59 @@
+ import tempfile
+
+ from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
+ from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
+ from ocf_data_sampler.numpy_sample import NWPSampleKey, GSPSampleKey, SatelliteSampleKey
+
+
+ def test_pvnet(pvnet_config_filename):
+     # Create the dataset object
+     dataset = PVNetUKRegionalDataset(pvnet_config_filename)
+
+     assert len(dataset.locations) == 317  # Number of GSPs, not including the National level
+     # NB. I have not checked this value is in fact correct, but it does seem to stay constant
+     assert len(dataset.valid_t0_times) == 39
+     assert len(dataset) == 317 * 39
+
+     # Generate a sample
+     sample = dataset[0]
+
+     assert isinstance(sample, dict)
+
+     for key in [
+         NWPSampleKey.nwp, SatelliteSampleKey.satellite_actual, GSPSampleKey.gsp,
+         GSPSampleKey.solar_azimuth, GSPSampleKey.solar_elevation,
+     ]:
+         assert key in sample
+
+     for nwp_source in ["ukv"]:
+         assert nwp_source in sample[NWPSampleKey.nwp]
+
+     # Check the shapes of the data are correct
+     # 30 minutes of 5-minute data (inclusive), one channel, 2x2 pixels
+     assert sample[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
+     # 3 hours of 60-minute data (inclusive), one channel, 2x2 pixels
+     assert sample[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
+     # 3 hours of 30-minute data (inclusive)
+     assert sample[GSPSampleKey.gsp].shape == (7,)
+     # Solar angles have the same shape as the GSP data
+     assert sample[GSPSampleKey.solar_azimuth].shape == (7,)
+     assert sample[GSPSampleKey.solar_elevation].shape == (7,)
+
+
+ def test_pvnet_no_gsp(pvnet_config_filename):
+     # Load the config
+     config = load_yaml_configuration(pvnet_config_filename)
+     # Remove the GSP input
+     config.input_data.gsp.zarr_path = ""
+
+     # Save a temporary config file
+     with tempfile.NamedTemporaryFile() as temp_config_file:
+         save_yaml_configuration(config, temp_config_file.name)
+         # Create the dataset object
+         dataset = PVNetUKRegionalDataset(temp_config_file.name)
+
+         # Generate a sample
+         _ = dataset[0]
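
Since `PVNetUKRegionalDataset` supports `len(dataset)` and `dataset[0]` as exercised above, it appears to follow the map-style PyTorch dataset interface. A hedged sketch of wrapping it in a DataLoader - the config path is hypothetical, and note the package also ships its own collate helpers (numpy_sample/collate.py in the file list) whose API is not shown in this diff:

from torch.utils.data import DataLoader

from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset

# "configuration.yaml" is a placeholder for a config like the test fixtures build
dataset = PVNetUKRegionalDataset("configuration.yaml")

# batch_size=None yields one sample dict at a time, sidestepping default collation
loader = DataLoader(dataset, batch_size=None, shuffle=True)
sample = next(iter(loader))
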