ocf-data-sampler 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (32)
  1. ocf_data_sampler/config/__init__.py +5 -0
  2. ocf_data_sampler/config/load.py +33 -0
  3. ocf_data_sampler/config/model.py +249 -0
  4. ocf_data_sampler/config/save.py +36 -0
  5. ocf_data_sampler/select/dropout.py +2 -2
  6. ocf_data_sampler/select/geospatial.py +118 -0
  7. ocf_data_sampler/select/location.py +62 -0
  8. ocf_data_sampler/select/select_spatial_slice.py +5 -14
  9. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +1 -2
  10. ocf_data_sampler-0.0.21.dist-info/METADATA +83 -0
  11. ocf_data_sampler-0.0.21.dist-info/RECORD +53 -0
  12. tests/config/test_config.py +152 -0
  13. tests/conftest.py +6 -1
  14. tests/load/test_load_gsp.py +15 -0
  15. tests/load/test_load_nwp.py +21 -0
  16. tests/load/test_load_satellite.py +17 -0
  17. tests/numpy_batch/test_gsp.py +23 -0
  18. tests/numpy_batch/test_nwp.py +54 -0
  19. tests/numpy_batch/test_satellite.py +42 -0
  20. tests/numpy_batch/test_sun_position.py +81 -0
  21. tests/select/test_dropout.py +75 -0
  22. tests/select/test_fill_time_periods.py +28 -0
  23. tests/select/test_find_contiguous_time_periods.py +202 -0
  24. tests/select/test_location.py +67 -0
  25. tests/select/test_select_spatial_slice.py +154 -0
  26. tests/select/test_select_time_slice.py +284 -0
  27. tests/torch_datasets/test_pvnet_uk_regional.py +72 -0
  28. ocf_data_sampler-0.0.19.dist-info/METADATA +0 -22
  29. ocf_data_sampler-0.0.19.dist-info/RECORD +0 -32
  30. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.21.dist-info}/LICENSE +0 -0
  31. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.21.dist-info}/WHEEL +0 -0
  32. {ocf_data_sampler-0.0.19.dist-info → ocf_data_sampler-0.0.21.dist-info}/top_level.txt +0 -0
tests/select/test_find_contiguous_time_periods.py
@@ -0,0 +1,202 @@
+ import pandas as pd
+
+ from ocf_data_sampler.select.find_contiguous_time_periods import (
+     find_contiguous_t0_periods,
+     find_contiguous_t0_periods_nwp,
+     intersection_of_multiple_dataframes_of_periods,
+ )
+
+
+ def test_find_contiguous_t0_periods():
+     # Create 5-minutely data timestamps
+     freq = pd.Timedelta(5, "min")
+     history_duration = pd.Timedelta(60, "min")
+     forecast_duration = pd.Timedelta(15, "min")
+
+     datetimes = (
+         pd.date_range("2023-01-01 12:00", "2023-01-01 17:00", freq=freq)
+         .delete([5, 6, 30])
+     )
+
+     periods = find_contiguous_t0_periods(
+         datetimes=datetimes,
+         history_duration=history_duration,
+         forecast_duration=forecast_duration,
+         sample_period_duration=freq,
+     )
+
+     expected_results = pd.DataFrame(
+         {
+             "start_dt": pd.to_datetime(["2023-01-01 13:35", "2023-01-01 15:35"]),
+             "end_dt": pd.to_datetime(["2023-01-01 14:10", "2023-01-01 16:45"]),
+         },
+     )
+
+     assert periods.equals(expected_results)
+
+
+ def test_find_contiguous_t0_periods_nwp():
+     # These are the expected results of the test
+     expected_results = [
+         pd.DataFrame(
+             {
+                 "start_dt": pd.to_datetime(["2023-01-01 03:00", "2023-01-02 03:00"]),
+                 "end_dt": pd.to_datetime(["2023-01-01 21:00", "2023-01-03 06:00"]),
+             },
+         ),
+         pd.DataFrame(
+             {
+                 "start_dt": pd.to_datetime(["2023-01-01 05:00", "2023-01-02 05:00"]),
+                 "end_dt": pd.to_datetime(["2023-01-01 21:00", "2023-01-03 06:00"]),
+             },
+         ),
+         pd.DataFrame(
+             {
+                 "start_dt": pd.to_datetime(
+                     ["2023-01-01 05:00", "2023-01-02 05:00", "2023-01-02 14:00"]
+                 ),
+                 "end_dt": pd.to_datetime(
+                     ["2023-01-01 18:00", "2023-01-02 09:00", "2023-01-03 03:00"]
+                 ),
+             },
+         ),
+         pd.DataFrame(
+             {
+                 "start_dt": pd.to_datetime(
+                     [
+                         "2023-01-01 05:00",
+                         "2023-01-01 11:00",
+                         "2023-01-02 05:00",
+                         "2023-01-02 14:00",
+                     ]
+                 ),
+                 "end_dt": pd.to_datetime(
+                     [
+                         "2023-01-01 06:00",
+                         "2023-01-01 15:00",
+                         "2023-01-02 06:00",
+                         "2023-01-03 00:00",
+                     ]
+                 ),
+             },
+         ),
+         pd.DataFrame(
+             {
+                 "start_dt": pd.to_datetime(
+                     [
+                         "2023-01-01 06:00",
+                         "2023-01-01 12:00",
+                         "2023-01-02 06:00",
+                         "2023-01-02 15:00",
+                     ]
+                 ),
+                 "end_dt": pd.to_datetime(
+                     [
+                         "2023-01-01 09:00",
+                         "2023-01-01 18:00",
+                         "2023-01-02 09:00",
+                         "2023-01-03 03:00",
+                     ]
+                 ),
+             },
+         ),
+     ]
+
+     # Create 3-hourly init times with a few time stamps missing
+     freq = pd.Timedelta(3, "h")
+
+     datetimes = (
+         pd.date_range("2023-01-01 03:00", "2023-01-02 21:00", freq=freq)
+         .delete([1, 4, 5, 6, 7, 9, 10])
+     )
+
+     # Choose some history durations, max stalenesses and max dropouts
+     history_durations_hr = [0, 2, 2, 2, 2]
+     max_stalenesses_hr = [9, 9, 6, 3, 6]
+     max_dropouts_hr = [0, 0, 0, 0, 3]
+
+     for i in range(len(expected_results)):
+         history_duration = pd.Timedelta(history_durations_hr[i], "h")
+         max_staleness = pd.Timedelta(max_stalenesses_hr[i], "h")
+         max_dropout = pd.Timedelta(max_dropouts_hr[i], "h")
+
+         time_periods = find_contiguous_t0_periods_nwp(
+             datetimes=datetimes,
+             history_duration=history_duration,
+             max_staleness=max_staleness,
+             max_dropout=max_dropout,
+         )
+
+         # Check if results are as expected
+         assert time_periods.equals(expected_results[i])
+
+
+ def test_intersection_of_multiple_dataframes_of_periods():
+     periods_1 = pd.DataFrame(
+         {
+             "start_dt": pd.to_datetime(["2023-01-01 05:00", "2023-01-01 14:10"]),
+             "end_dt": pd.to_datetime(["2023-01-01 13:35", "2023-01-01 18:00"]),
+         },
+     )
+
+     periods_2 = pd.DataFrame(
+         {
+             "start_dt": pd.to_datetime(["2023-01-01 12:00"]),
+             "end_dt": pd.to_datetime(["2023-01-02 00:00"]),
+         },
+     )
+
+     periods_3 = pd.DataFrame(
+         {
+             "start_dt": pd.to_datetime(["2023-01-01 00:00", "2023-01-01 13:00"]),
+             "end_dt": pd.to_datetime(["2023-01-01 12:30", "2023-01-01 23:00"]),
+         },
+     )
+
+     expected_result = pd.DataFrame(
+         {
+             "start_dt": pd.to_datetime(
+                 ["2023-01-01 12:00", "2023-01-01 13:00", "2023-01-01 14:10"]
+             ),
+             "end_dt": pd.to_datetime(
+                 ["2023-01-01 12:30", "2023-01-01 13:35", "2023-01-01 18:00"]
+             ),
+         },
+     )
+
+     overlapping_periods = intersection_of_multiple_dataframes_of_periods(
+         [periods_1, periods_2, periods_3]
+     )
+
+     # Check if results are as expected
+     assert overlapping_periods.equals(expected_result)
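
Taken together, these tests pin down the contract of the period helpers: each returns a DataFrame with "start_dt" and "end_dt" columns marking windows from which a valid t0 can be sampled. A minimal usage sketch based only on the call signature exercised above (the timestamps here are illustrative, not from the package):

import pandas as pd
from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods

# Half-hourly timestamps with one missing value
datetimes = pd.date_range("2023-06-01 00:00", periods=24, freq="30min").delete([5])

# Contiguous t0 periods that leave room for 60 min of history and 30 min of forecast
periods = find_contiguous_t0_periods(
    datetimes=datetimes,
    history_duration=pd.Timedelta("60min"),
    forecast_duration=pd.Timedelta("30min"),
    sample_period_duration=pd.Timedelta("30min"),
)
print(periods)  # DataFrame with "start_dt" and "end_dt" columns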
tests/select/test_location.py
@@ -0,0 +1,67 @@
+ import pytest
+
+ from ocf_data_sampler.select.location import Location
+
+
+ def test_make_valid_location_object_with_default_coordinate_system():
+     x, y = -1000.5, 50000
+     location = Location(x=x, y=y)
+     assert location.x == x, "location.x value not set correctly"
+     assert location.y == y, "location.y value not set correctly"
+     assert (
+         location.coordinate_system == "osgb"
+     ), "location.coordinate_system value not set correctly"
+
+
+ def test_make_valid_location_object_with_osgb_coordinate_system():
+     x, y, coordinate_system = 1.2, 22.9, "osgb"
+     location = Location(x=x, y=y, coordinate_system=coordinate_system)
+     assert location.x == x, "location.x value not set correctly"
+     assert location.y == y, "location.y value not set correctly"
+     assert (
+         location.coordinate_system == coordinate_system
+     ), "location.coordinate_system value not set correctly"
+
+
+ def test_make_valid_location_object_with_lon_lat_coordinate_system():
+     x, y, coordinate_system = 1.2, 1.2, "lon_lat"
+     location = Location(x=x, y=y, coordinate_system=coordinate_system)
+     assert location.x == x, "location.x value not set correctly"
+     assert location.y == y, "location.y value not set correctly"
+     assert (
+         location.coordinate_system == coordinate_system
+     ), "location.coordinate_system value not set correctly"
+
+
+ def test_make_invalid_location_object_with_invalid_osgb_x():
+     x, y, coordinate_system = 10000000, 1.2, "osgb"
+     with pytest.raises(ValueError) as err:
+         _ = Location(x=x, y=y, coordinate_system=coordinate_system)
+     assert err.typename == "ValidationError"
+
+
+ def test_make_invalid_location_object_with_invalid_osgb_y():
+     x, y, coordinate_system = 2.5, 10000000, "osgb"
+     with pytest.raises(ValueError) as err:
+         _ = Location(x=x, y=y, coordinate_system=coordinate_system)
+     assert err.typename == "ValidationError"
+
+
+ def test_make_invalid_location_object_with_invalid_lon_lat_x():
+     x, y, coordinate_system = 200, 1.2, "lon_lat"
+     with pytest.raises(ValueError) as err:
+         _ = Location(x=x, y=y, coordinate_system=coordinate_system)
+     assert err.typename == "ValidationError"
+
+
+ def test_make_invalid_location_object_with_invalid_lon_lat_y():
+     x, y, coordinate_system = 2.5, -200, "lon_lat"
+     with pytest.raises(ValueError) as err:
+         _ = Location(x=x, y=y, coordinate_system=coordinate_system)
+     assert err.typename == "ValidationError"
+
+
+ def test_make_invalid_location_object_with_invalid_coordinate_system():
+     x, y, coordinate_system = 2.5, 1000, "abcd"
+     with pytest.raises(ValueError) as err:
+         _ = Location(x=x, y=y, coordinate_system=coordinate_system)
+     assert err.typename == "ValidationError"
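
For reference, the behaviour these tests assert: Location appears to be a pydantic-style model (its ValidationError is caught as a ValueError), it defaults to the "osgb" coordinate system, and it rejects coordinates outside each system's plausible range. A short sketch consistent with the tests above; the exact validation bounds are inferred from the invalid-value cases, not from package documentation:

from ocf_data_sampler.select.location import Location

loc = Location(x=-1000.5, y=50000)  # coordinate_system defaults to "osgb"
print(loc.coordinate_system)        # "osgb"

try:
    Location(x=200, y=1.2, coordinate_system="lon_lat")  # longitude outside [-180, 180]
except ValueError as err:
    print(type(err).__name__)       # "ValidationError"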
tests/select/test_select_spatial_slice.py
@@ -0,0 +1,154 @@
+ import numpy as np
+ import pytest
+ import xarray as xr
+ from ocf_datapipes.utils import Location
+
+ from ocf_data_sampler.select.select_spatial_slice import (
+     _get_idx_of_pixel_closest_to_poi,
+     select_spatial_slice_pixels,
+ )
+
+
+ @pytest.fixture(scope="module")
+ def da():
+     # Create dummy data
+     x = np.arange(-100, 100)
+     y = np.arange(-100, 100)
+
+     da = xr.DataArray(
+         np.random.normal(size=(len(x), len(y))),
+         coords=dict(
+             x_osgb=(["x_osgb"], x),
+             y_osgb=(["y_osgb"], y),
+         ),
+     )
+     return da
+
+
+ def test_get_idx_of_pixel_closest_to_poi(da):
+     idx_location = _get_idx_of_pixel_closest_to_poi(
+         da,
+         location=Location(x=10, y=10, coordinate_system="osgb"),
+     )
+
+     assert idx_location.coordinate_system == "idx"
+     assert idx_location.x == 110
+     assert idx_location.y == 110
+
+
+ def test_select_spatial_slice_pixels(da):
+     # Select window which lies within x-y bounds of the data
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=-90, y=-80, coordinate_system="osgb"),
+         width_pixels=10,
+         height_pixels=10,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(-95, -85)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-85, -75)).all()
+     # No padding in this case so no NaNs
+     assert not da_sliced.isnull().any()
+
+     # Select window where the edge of the window lies right on the edge of the data
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=-90, y=-80, coordinate_system="osgb"),
+         width_pixels=20,
+         height_pixels=20,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-90, -70)).all()
+     # No padding in this case so no NaNs
+     assert not da_sliced.isnull().any()
+
+     # Select window which is partially outside the boundary of the data - padded on left
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=-90, y=-80, coordinate_system="osgb"),
+         width_pixels=30,
+         height_pixels=30,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(-105, -75)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-95, -65)).all()
+     # Data has been padded on left by 5 NaN pixels
+     assert da_sliced.isnull().sum() == 5 * len(da_sliced.y_osgb)
+
+     # Select window which is partially outside the boundary of the data - padded on right
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=90, y=-80, coordinate_system="osgb"),
+         width_pixels=30,
+         height_pixels=30,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(75, 105)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-95, -65)).all()
+     # Data has been padded on right by 5 NaN pixels
+     assert da_sliced.isnull().sum() == 5 * len(da_sliced.y_osgb)
+
+     # Select window which is partially outside the boundary of the data - padded on top
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=-90, y=95, coordinate_system="osgb"),
+         width_pixels=20,
+         height_pixels=20,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
+     assert (da_sliced.y_osgb.values == np.arange(85, 105)).all()
+     # Data has been padded on top by 5 NaN pixels
+     assert da_sliced.isnull().sum() == 5 * len(da_sliced.x_osgb)
+
+     # Select window which is partially outside the boundary of the data - padded on bottom
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=-90, y=-95, coordinate_system="osgb"),
+         width_pixels=20,
+         height_pixels=20,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-105, -85)).all()
+     # Data has been padded on bottom by 5 NaN pixels
+     assert da_sliced.isnull().sum() == 5 * len(da_sliced.x_osgb)
+
+     # Select window which is partially outside the boundary of the data - padded right and bottom
+     da_sliced = select_spatial_slice_pixels(
+         da,
+         location=Location(x=90, y=-80, coordinate_system="osgb"),
+         width_pixels=50,
+         height_pixels=50,
+         allow_partial_slice=True,
+     )
+
+     assert isinstance(da_sliced, xr.DataArray)
+     assert (da_sliced.x_osgb.values == np.arange(65, 115)).all()
+     assert (da_sliced.y_osgb.values == np.arange(-105, -55)).all()
+     # Padded on right by 15 pixels and on bottom by 5; the 15x5 corner overlap is counted once
+     assert da_sliced.isnull().sum() == 15 * len(da_sliced.y_osgb) + 5 * len(da_sliced.x_osgb) - 15 * 5
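
In short, select_spatial_slice_pixels cuts a width_pixels x height_pixels window centred on the pixel nearest the requested location, and with allow_partial_slice=True it NaN-pads whatever part of the window falls off the grid. A sketch reusing the same dummy-data setup as the fixture above (note the test file itself imports Location from ocf_datapipes.utils, so that dependency is assumed here too):

import numpy as np
import xarray as xr
from ocf_datapipes.utils import Location
from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels

x = np.arange(-100, 100)
y = np.arange(-100, 100)
da = xr.DataArray(
    np.random.normal(size=(len(x), len(y))),
    coords=dict(x_osgb=(["x_osgb"], x), y_osgb=(["y_osgb"], y)),
)

# A 30x30 window near the left edge: 5 columns fall outside the data and come back as NaN
da_sliced = select_spatial_slice_pixels(
    da,
    location=Location(x=-90, y=-80, coordinate_system="osgb"),
    width_pixels=30,
    height_pixels=30,
    allow_partial_slice=True,
)
print(int(da_sliced.isnull().sum()))  # 5 columns x 30 rows = 150 NaN pixels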
tests/select/test_select_time_slice.py
@@ -0,0 +1,284 @@
+ import numpy as np
+ import pandas as pd
+ import pytest
+ import xarray as xr
+
+ from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
+
+
+ NWP_FREQ = pd.Timedelta("3H")
+
+
+ @pytest.fixture(scope="module")
+ def da_sat_like():
+     """Create dummy data which looks like satellite data"""
+     x = np.arange(-100, 100)
+     y = np.arange(-100, 100)
+     datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")
+
+     da_sat = xr.DataArray(
+         np.random.normal(size=(len(datetimes), len(x), len(y))),
+         coords=dict(
+             time_utc=(["time_utc"], datetimes),
+             x_geostationary=(["x_geostationary"], x),
+             y_geostationary=(["y_geostationary"], y),
+         ),
+     )
+     return da_sat
+
+
+ @pytest.fixture(scope="module")
+ def da_nwp_like():
+     """Create dummy data which looks like NWP data"""
+     x = np.arange(-100, 100)
+     y = np.arange(-100, 100)
+     datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq=NWP_FREQ)
+     steps = pd.timedelta_range("0H", "16H", freq="1H")
+     channels = ["t", "dswrf"]
+
+     da_nwp = xr.DataArray(
+         np.random.normal(size=(len(datetimes), len(steps), len(channels), len(x), len(y))),
+         coords=dict(
+             init_time_utc=(["init_time_utc"], datetimes),
+             step=(["step"], steps),
+             channel=(["channel"], channels),
+             x_osgb=(["x_osgb"], x),
+             y_osgb=(["y_osgb"], y),
+         ),
+     )
+     return da_nwp
+
+
+ @pytest.mark.parametrize("t0_str", ["12:30", "12:40", "12:00"])
+ def test_select_time_slice(da_sat_like, t0_str):
+     """Test the basic functionality of select_time_slice"""
+     # Slice parameters
+     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+     forecast_duration = pd.Timedelta("0min")
+     history_duration = pd.Timedelta("60min")
+     freq = pd.Timedelta("5min")
+
+     # Expect to return these timestamps from the selection
+     expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+
+     # Make the selection using the `[x]_duration` parameters
+     sat_sample = select_time_slice(
+         ds=da_sat_like,
+         t0=t0,
+         history_duration=history_duration,
+         forecast_duration=forecast_duration,
+         sample_period_duration=freq,
+     )
+
+     # Check the returned times are as expected
+     assert (sat_sample.time_utc == expected_datetimes).all()
+
+     # Make the selection using the `interval_[x]` parameters
+     sat_sample = select_time_slice(
+         ds=da_sat_like,
+         t0=t0,
+         interval_start=-history_duration,
+         interval_end=forecast_duration,
+         sample_period_duration=freq,
+     )
+
+     # Check the returned times are as expected
+     assert (sat_sample.time_utc == expected_datetimes).all()
+
+
+ @pytest.mark.parametrize("t0_str", ["00:00", "00:25", "11:00", "11:55"])
+ def test_select_time_slice_out_of_bounds(da_sat_like, t0_str):
+     """Test the behaviour of select_time_slice when the selection is out of bounds"""
+     # Slice parameters
+     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+     forecast_duration = pd.Timedelta("30min")
+     history_duration = pd.Timedelta("60min")
+     freq = pd.Timedelta("5min")
+
+     # The data is available between these times
+     min_time = da_sat_like.time_utc.min()
+     max_time = da_sat_like.time_utc.max()
+
+     # Expect to return these timestamps from the selection
+     expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+
+     # Make the partially out of bounds selection
+     sat_sample = select_time_slice(
+         ds=da_sat_like,
+         t0=t0,
+         history_duration=history_duration,
+         forecast_duration=forecast_duration,
+         sample_period_duration=freq,
+         fill_selection=True,
+     )
+
+     # Check the returned times are as expected
+     assert (sat_sample.time_utc == expected_datetimes).all()
+
+     # Check all the values before the first timestamp available in the data are NaN
+     all_nan_space = sat_sample.isnull().all(dim=("x_geostationary", "y_geostationary"))
+     if expected_datetimes[0] < min_time:
+         assert all_nan_space.sel(time_utc=slice(None, min_time - freq)).all(dim="time_utc")
+
+     # Check all the values after the last timestamp available in the data are NaN
+     if expected_datetimes[-1] > max_time:
+         assert all_nan_space.sel(time_utc=slice(max_time + freq, None)).all(dim="time_utc")
+
+     # Check that none of the values between the first and last available timestamp are NaN
+     any_nan_space = sat_sample.isnull().any(dim=("x_geostationary", "y_geostationary"))
+     assert not any_nan_space.sel(time_utc=slice(min_time, max_time)).any(dim="time_utc")
+
+
+ @pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
+ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):
+     """Test the basic functionality of select_time_slice_nwp"""
+     # Slice parameters
+     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+     forecast_duration = pd.Timedelta("6H")
+     history_duration = pd.Timedelta("3H")
+     freq = pd.Timedelta("1H")
+
+     # Make the selection
+     da_slice = select_time_slice_nwp(
+         da_nwp_like,
+         t0,
+         sample_period_duration=freq,
+         history_duration=history_duration,
+         forecast_duration=forecast_duration,
+         dropout_timedeltas=None,
+         dropout_frac=0,
+         accum_channels=[],
+         channel_dim_name="channel",
+     )
+
+     # Check the target-times are as expected
+     expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+     assert (da_slice.target_time_utc == expected_target_times).all()
+
+     # Check the init-times are as expected
+     # - Forecast frequency is `NWP_FREQ`, and we can't have selected future init-times
+     expected_init_times = pd.to_datetime(
+         [t if t < t0 else t0 for t in expected_target_times]
+     ).floor(NWP_FREQ)
+     assert (da_slice.init_time_utc == expected_init_times).all()
+
+
+ @pytest.mark.parametrize("dropout_hours", [1, 2, 5])
+ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
+     """Test the functionality of select_time_slice_nwp with dropout"""
+     t0 = pd.Timestamp("2024-01-02 12:00")
+     forecast_duration = pd.Timedelta("6H")
+     history_duration = pd.Timedelta("3H")
+     freq = pd.Timedelta("1H")
+     dropout_timedelta = pd.Timedelta(f"-{dropout_hours}H")
+
+     da_slice = select_time_slice_nwp(
+         da_nwp_like,
+         t0,
+         sample_period_duration=freq,
+         history_duration=history_duration,
+         forecast_duration=forecast_duration,
+         dropout_timedeltas=[dropout_timedelta],
+         dropout_frac=1,
+         accum_channels=[],
+         channel_dim_name="channel",
+     )
+
+     # Check the target-times are as expected
+     expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+     assert (da_slice.target_time_utc == expected_target_times).all()
+
+     # Check the init-times are as expected considering the delay
+     t0_delayed = t0 + dropout_timedelta
+     expected_init_times = pd.to_datetime(
+         [t if t < t0_delayed else t0_delayed for t in expected_target_times]
+     ).floor(NWP_FREQ)
+     assert (da_slice.init_time_utc == expected_init_times).all()
+
+
+ @pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
+ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
+     """Test the functionality of select_time_slice_nwp with dropout and accumulated variables"""
+     # Slice parameters
+     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+     forecast_duration = pd.Timedelta("6H")
+     history_duration = pd.Timedelta("3H")
+     freq = pd.Timedelta("1H")
+     dropout_timedelta = pd.Timedelta("-2H")
+
+     t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
+
+     da_slice = select_time_slice_nwp(
+         da_nwp_like,
+         t0,
+         sample_period_duration=freq,
+         history_duration=history_duration,
+         forecast_duration=forecast_duration,
+         dropout_timedeltas=[dropout_timedelta],
+         dropout_frac=1,
+         accum_channels=["dswrf"],
+         channel_dim_name="channel",
+     )
+
+     # Check the target-times are as expected
+     expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+     assert (da_slice.target_time_utc == expected_target_times).all()
+
+     # Check the init-times are as expected considering the delay
+     expected_init_times = pd.to_datetime(
+         [t if t < t0_delayed else t0_delayed for t in expected_target_times]
+     ).floor(NWP_FREQ)
+     assert (da_slice.init_time_utc == expected_init_times).all()
+
+     # Check channels are as expected
+     assert (da_slice.channel.values == ["t", "diff_dswrf"]).all()
+
+     # Check the accumulated channel has been differenced correctly
+
+     # This part of the data is pulled from the init-time: t0_delayed
+     da_slice_accum = da_slice.sel(
+         target_time_utc=slice(t0_delayed, None),
+         channel="diff_dswrf",
+     )
+
+     # Get the original data for the t0_delayed init-time, diff it along the step dimension,
+     # then select the steps which are expected to be used in the above slice
+     da_orig_diffed = (
+         da_nwp_like.sel(
+             init_time_utc=t0_delayed,
+             channel="dswrf",
+         )
+         .diff(dim="step", label="lower")
+         .sel(step=slice(t0 - t0_delayed - history_duration, t0 - t0_delayed + forecast_duration))
+     )
+
+     # Check the values are the same
+     assert (da_slice_accum.values == da_orig_diffed.values).all()
+
+     # Check the non-accumulated channel has not been differenced
+
+     # This part of the data is pulled from the init-time: t0_delayed
+     da_slice_nonaccum = da_slice.sel(
+         target_time_utc=slice(t0_delayed, None),
+         channel="t",
+     )
+
+     # Get the original data for the t0_delayed init-time, and select the steps which are
+     # expected to be used in the above slice
+     da_orig = (
+         da_nwp_like.sel(
+             init_time_utc=t0_delayed,
+             channel="t",
+         )
+         .sel(step=slice(t0 - t0_delayed - history_duration, t0 - t0_delayed + forecast_duration))
+     )
+
+     # Check the values are the same
+     assert (da_slice_nonaccum.values == da_orig.values).all()
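
The satellite-style selection above reduces to a single call: given t0 and either history/forecast durations or signed intervals, select_time_slice returns the matching time_utc slice, NaN-filled past the data bounds when fill_selection=True. A sketch reusing the fixture's shape (the early t0 is illustrative; it forces the history window to start before the data does):

import numpy as np
import pandas as pd
import xarray as xr
from ocf_data_sampler.select.select_time_slice import select_time_slice

datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")
x = np.arange(-100, 100)
y = np.arange(-100, 100)
da_sat = xr.DataArray(
    np.random.normal(size=(len(datetimes), len(x), len(y))),
    coords=dict(
        time_utc=(["time_utc"], datetimes),
        x_geostationary=(["x_geostationary"], x),
        y_geostationary=(["y_geostationary"], y),
    ),
)

# One hour of history up to t0; timestamps before the data start come back NaN-filled
sample = select_time_slice(
    ds=da_sat,
    t0=pd.Timestamp("2024-01-02 00:30"),
    history_duration=pd.Timedelta("60min"),
    forecast_duration=pd.Timedelta("0min"),
    sample_period_duration=pd.Timedelta("5min"),
    fill_selection=True,
)
print(sample.time_utc.values[0], sample.time_utc.values[-1])  # 23:30 previous day to 00:30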