ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +146 -64
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +9 -10
- ocf_data_sampler/load/site.py +10 -6
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +12 -14
- ocf_data_sampler/numpy_sample/nwp.py +12 -12
- ocf_data_sampler/numpy_sample/satellite.py +9 -9
- ocf_data_sampler/numpy_sample/site.py +5 -8
- ocf_data_sampler/numpy_sample/sun_position.py +16 -21
- ocf_data_sampler/sample/base.py +15 -17
- ocf_data_sampler/sample/site.py +13 -20
- ocf_data_sampler/sample/uk_regional.py +29 -35
- ocf_data_sampler/select/dropout.py +16 -14
- ocf_data_sampler/select/fill_time_periods.py +15 -5
- ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +140 -131
- ocf_data_sampler/torch_datasets/datasets/site.py +152 -112
- ocf_data_sampler/torch_datasets/utils/__init__.py +3 -0
- ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.17.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +63 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler/constants.py +0 -222
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -82
- ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -319
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -13
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -69
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -165
- tests/test_sample/test_uk_regional_sample.py +0 -136
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -154
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
|
tests/select/test_select_spatial_slice.py
DELETED
|
@@ -1,154 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import xarray as xr
|
|
3
|
-
from ocf_data_sampler.select.location import Location
|
|
4
|
-
import pytest
|
|
5
|
-
|
|
6
|
-
from ocf_data_sampler.select.select_spatial_slice import (
|
|
7
|
-
select_spatial_slice_pixels, _get_idx_of_pixel_closest_to_poi
|
|
8
|
-
)
|
|
9
|
-
|
|
10
|
-
@pytest.fixture(scope="module")
def da():
    """Return a DataArray of random values on an integer OSGB grid.

    Both ``x_osgb`` and ``y_osgb`` run from -100 to 99 inclusive in unit steps.
    """
    x_values = np.arange(-100, 100)
    y_values = np.arange(-100, 100)
    return xr.DataArray(
        np.random.normal(size=(x_values.size, y_values.size)),
        coords={
            "x_osgb": (["x_osgb"], x_values),
            "y_osgb": (["y_osgb"], y_values),
        },
    )
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def test_get_idx_of_pixel_closest_to_poi(da):
    """The OSGB point (10, 10) should map to pixel index (110, 110) on the fixture grid."""
    poi = Location(x=10, y=10, coordinate_system="osgb")
    result = _get_idx_of_pixel_closest_to_poi(da, location=poi)

    assert result.coordinate_system == "idx"
    assert (result.x, result.y) == (110, 110)
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def test_select_spatial_slice_pixels(da):
    """Check spatial window selection, including NaN-padding at the data boundary.

    The ``da`` fixture covers -100..99 on both OSGB axes. Any part of a requested
    window outside that range is expected to come back filled with NaN.
    """

    def select_square_window(x, y, size):
        # Every case below requests a square window and allows partial slices
        # (the padding checks rely on this)
        return select_spatial_slice_pixels(
            da,
            location=Location(x=x, y=y, coordinate_system="osgb"),
            width_pixels=size,
            height_pixels=size,
            allow_partial_slice=True,
        )

    # Select window which lies within x-y bounds of the data
    da_sliced = select_square_window(-90, -80, 10)
    assert isinstance(da_sliced, xr.DataArray)
    assert (da_sliced.x_osgb.values == np.arange(-95, -85)).all()
    assert (da_sliced.y_osgb.values == np.arange(-85, -75)).all()
    # No padding in this case so no NaNs
    assert not da_sliced.isnull().any()

    # Select window where the edge of the window lies right on the edge of the data
    da_sliced = select_square_window(-90, -80, 20)
    assert isinstance(da_sliced, xr.DataArray)
    assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
    assert (da_sliced.y_osgb.values == np.arange(-90, -70)).all()
    # No padding in this case so no NaNs
    assert not da_sliced.isnull().any()

    # Select window which is partially outside the boundary of the data - padded on left
    da_sliced = select_square_window(-90, -80, 30)
    assert isinstance(da_sliced, xr.DataArray)
    assert (da_sliced.x_osgb.values == np.arange(-105, -75)).all()
    assert (da_sliced.y_osgb.values == np.arange(-95, -65)).all()
    # Data has been padded on left by 5 NaN pixels (x = -105..-101)
    assert da_sliced.isnull().sum() == 5 * len(da_sliced.y_osgb)

    # Select window which is partially outside the boundary of the data - padded on right
    da_sliced = select_square_window(90, -80, 30)
    assert isinstance(da_sliced, xr.DataArray)
    assert (da_sliced.x_osgb.values == np.arange(75, 105)).all()
    assert (da_sliced.y_osgb.values == np.arange(-95, -65)).all()
    # Data has been padded on right by 5 NaN pixels (x = 100..104)
    assert da_sliced.isnull().sum() == 5 * len(da_sliced.y_osgb)

    # Select window which is partially outside the boundary of the data - padded on top
    da_sliced = select_square_window(-90, 95, 20)
    assert isinstance(da_sliced, xr.DataArray)
    assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
    assert (da_sliced.y_osgb.values == np.arange(85, 105)).all()
    # Data has been padded on top by 5 NaN pixels (y = 100..104)
    assert da_sliced.isnull().sum() == 5 * len(da_sliced.x_osgb)

    # Select window which is partially outside the boundary of the data - padded on bottom
    da_sliced = select_square_window(-90, -95, 20)
    assert isinstance(da_sliced, xr.DataArray)
    assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all()
    assert (da_sliced.y_osgb.values == np.arange(-105, -85)).all()
    # Data has been padded on bottom by 5 NaN pixels (y = -105..-101)
    assert da_sliced.isnull().sum() == 5 * len(da_sliced.x_osgb)

    # Select window which is partially outside the boundary of the data - padded right and bottom
    da_sliced = select_square_window(90, -80, 50)
    assert isinstance(da_sliced, xr.DataArray)
    assert (da_sliced.x_osgb.values == np.arange(65, 115)).all()
    assert (da_sliced.y_osgb.values == np.arange(-105, -55)).all()
    # Data has been padded on right by 15 pixels and bottom by 5 NaN pixels. The 15x5
    # corner belongs to both padded strips, so it is subtracted once to avoid double counting.
    assert (
        da_sliced.isnull().sum()
        == 15 * len(da_sliced.y_osgb) + 5 * len(da_sliced.x_osgb) - 15 * 5
    )
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
tests/select/test_select_time_slice.py
DELETED
|
@@ -1,275 +0,0 @@
|
|
|
1
|
-
from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import pandas as pd
|
|
5
|
-
import xarray as xr
|
|
6
|
-
import pytest
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
# Spacing between init-times in the dummy NWP fixture; also used to floor expected init-times
NWP_FREQ = pd.Timedelta(hours=3)
|
|
10
|
-
|
|
11
|
-
@pytest.fixture(scope="module")
def da_sat_like():
    """Create dummy data which looks like satellite data

    Random values indexed by (time_utc, x_geostationary, y_geostationary). Timestamps run
    from 2024-01-02 00:00 to 2024-01-03 00:00 inclusive at 5-minute resolution.
    """
    x_values = np.arange(-100, 100)
    y_values = np.arange(-100, 100)
    times = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")

    shape = (len(times), len(x_values), len(y_values))
    coords = {
        "time_utc": (["time_utc"], times),
        "x_geostationary": (["x_geostationary"], x_values),
        "y_geostationary": (["y_geostationary"], y_values),
    }
    return xr.DataArray(np.random.normal(size=shape), coords=coords)
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
@pytest.fixture(scope="module")
def da_nwp_like():
    """Create dummy data which looks like NWP data

    Random values indexed by (init_time_utc, step, channel, x_osgb, y_osgb):
    - init-times every NWP_FREQ from 2024-01-02 00:00 to 2024-01-03 00:00;
    - hourly steps from 0h to 16h;
    - the channels "t" and "dswrf".
    """
    x_values = np.arange(-100, 100)
    y_values = np.arange(-100, 100)
    init_times = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq=NWP_FREQ)
    steps = pd.timedelta_range("0h", "16h", freq="1h")
    channels = ["t", "dswrf"]

    # Dimension name -> coordinate values, in the order the data axes are laid out
    axis_values = {
        "init_time_utc": init_times,
        "step": steps,
        "channel": channels,
        "x_osgb": x_values,
        "y_osgb": y_values,
    }
    shape = tuple(len(values) for values in axis_values.values())
    coords = {name: ([name], values) for name, values in axis_values.items()}
    return xr.DataArray(np.random.normal(size=shape), coords=coords)
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@pytest.mark.parametrize("t0_str", ["12:30", "12:40", "12:00"])
def test_select_time_slice(da_sat_like, t0_str):
    """Test the basic functionality of select_time_slice

    The window [t0, t0 + 60min] lies inside the fixture's time range. Every 5-minute
    timestamp in that window is expected back.
    """
    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
    freq = pd.Timedelta("5min")
    interval_start = pd.Timedelta(0, "min")
    interval_end = pd.Timedelta(60, "min")

    sat_sample = select_time_slice(
        da_sat_like,
        t0=t0,
        interval_start=interval_start,
        interval_end=interval_end,
        sample_period_duration=freq,
    )

    expected_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
    assert (sat_sample.time_utc == expected_times).all()
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
@pytest.mark.parametrize("t0_str", ["00:00", "00:25", "11:00", "11:55", "23:00", "23:55"])
def test_select_time_slice_out_of_bounds(da_sat_like, t0_str):
    """Test the behaviour of select_time_slice when the selection is out of bounds

    The satellite fixture spans 2024-01-02 00:00 to 2024-01-03 00:00:
    - early t0 values overrun the start of the data;
    - late t0 values ("23:xx") overrun the end;
    - mid-day values are fully in bounds.
    """

    # Slice parameters
    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
    interval_start = pd.Timedelta(-30, "min")
    interval_end = pd.Timedelta(60, "min")
    freq = pd.Timedelta("5min")

    # The data is available between these times
    min_time = pd.Timestamp(da_sat_like.time_utc.min().item())
    max_time = pd.Timestamp(da_sat_like.time_utc.max().item())

    # Expect to return only the timestamps that are both requested and available
    expected_datetimes = pd.date_range(
        max(t0 + interval_start, min_time),
        min(t0 + interval_end, max_time),
        freq=freq,
    )

    # Make the partially out of bounds selection
    sat_sample = select_time_slice(
        da_sat_like,
        t0=t0,
        interval_start=interval_start,
        interval_end=interval_end,
        sample_period_duration=freq,
    )

    # Check the returned times are as expected
    assert (sat_sample.time_utc == expected_datetimes).all()

    # NOTE(review): expected_datetimes is clipped to [min_time, max_time] above, so the two
    # NaN-padding branches below cannot currently trigger. Revisit if select_time_slice is
    # meant to pad out-of-range times with NaN.
    all_nan_space = sat_sample.isnull().all(dim=("x_geostationary", "y_geostationary"))

    # Check all the values before the first timestamp available in the data are NaN
    if expected_datetimes[0] < min_time:
        assert all_nan_space.sel(time_utc=slice(None, min_time - freq)).all(dim="time_utc")

    # Check all the values after the last timestamp available in the data are NaN
    if expected_datetimes[-1] > max_time:
        assert all_nan_space.sel(time_utc=slice(max_time + freq, None)).all(dim="time_utc")

    # Check that none of the values between the first and last available timestamp are NaN
    any_nan_space = sat_sample.isnull().any(dim=("x_geostationary", "y_geostationary"))
    assert not any_nan_space.sel(time_utc=slice(min_time, max_time)).any(dim="time_utc")
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
@pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):
    """Test the basic functionality of select_time_slice_nwp"""
    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
    freq = pd.Timedelta("1h")
    interval_start = pd.Timedelta(-6, "h")
    interval_end = pd.Timedelta(3, "h")

    da_slice = select_time_slice_nwp(
        da_nwp_like,
        t0,
        sample_period_duration=freq,
        interval_start=interval_start,
        interval_end=interval_end,
        dropout_timedeltas=None,
        dropout_frac=0,
        accum_channels=[],
        channel_dim_name="channel",
    )

    # One target-time per `freq` across the requested interval
    target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
    assert (da_slice.target_time_utc == target_times).all()

    # Expected init-times: never later than t0, snapped down to the NWP_FREQ forecast grid
    init_times = pd.to_datetime([min(target, t0) for target in target_times]).floor(NWP_FREQ)
    assert (da_slice.init_time_utc == init_times).all()
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
@pytest.mark.parametrize("dropout_hours", [1, 2, 5])
def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
    """Test the functionality of select_time_slice_nwp with dropout"""
    t0 = pd.Timestamp("2024-01-02 12:00")
    freq = pd.Timedelta("1h")
    interval_start = pd.Timedelta(-6, "h")
    interval_end = pd.Timedelta(3, "h")
    dropout_timedelta = pd.Timedelta(f"-{dropout_hours}h")

    da_slice = select_time_slice_nwp(
        da_nwp_like,
        t0,
        sample_period_duration=freq,
        interval_start=interval_start,
        interval_end=interval_end,
        dropout_timedeltas=[dropout_timedelta],
        dropout_frac=1,
        accum_channels=[],
        channel_dim_name="channel",
    )

    target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
    assert (da_slice.target_time_utc == target_times).all()

    # dropout_frac=1 is passed, so the expectation assumes the delay is always applied:
    # no init-time later than t0 + dropout_timedelta (floored to the forecast grid)
    t0_delayed = t0 + dropout_timedelta
    init_times = pd.to_datetime(
        [min(target, t0_delayed) for target in target_times]
    ).floor(NWP_FREQ)
    assert (da_slice.init_time_utc == init_times).all()
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
@pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
    """Test the functionality of select_time_slice_nwp with dropout and accumulated variables"""
    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
    freq = pd.Timedelta("1h")
    interval_start = pd.Timedelta(-6, "h")
    interval_end = pd.Timedelta(3, "h")
    dropout_timedelta = pd.Timedelta("-2h")

    # Latest init-time expected to be usable once the dropout delay is applied
    t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)

    da_slice = select_time_slice_nwp(
        da_nwp_like,
        t0,
        sample_period_duration=freq,
        interval_start=interval_start,
        interval_end=interval_end,
        dropout_timedeltas=[dropout_timedelta],
        dropout_frac=1,
        accum_channels=["dswrf"],
        channel_dim_name="channel",
    )

    target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
    assert (da_slice.target_time_utc == target_times).all()

    init_times = pd.to_datetime(
        [min(target, t0_delayed) for target in target_times]
    ).floor(NWP_FREQ)
    assert (da_slice.init_time_utc == init_times).all()

    # The accumulated channel comes back under a "diff_" prefixed name
    assert (da_slice.channel.values == ["t", "diff_dswrf"]).all()

    # From t0_delayed onwards the slice is drawn from the t0_delayed forecast.
    # These are the steps of that forecast the slice should cover.
    step_window = slice(t0 - t0_delayed + interval_start, t0 - t0_delayed + interval_end)
    forecast = da_nwp_like.sel(init_time_utc=t0_delayed)

    # Accumulated channel: should match the step-wise difference of the raw forecast
    sliced_accum = da_slice.sel(target_time_utc=slice(t0_delayed, None), channel="diff_dswrf")
    expected_accum = (
        forecast.sel(channel="dswrf")
        .diff(dim="step", label="lower")
        .sel(step=step_window)
    )
    assert (sliced_accum.values == expected_accum.values).all()

    # Non-accumulated channel: should match the raw forecast untouched
    sliced_plain = da_slice.sel(target_time_utc=slice(t0_delayed, None), channel="t")
    expected_plain = forecast.sel(channel="t").sel(step=step_window)
    assert (sliced_plain.values == expected_plain.values).all()
|
|
274
|
-
|
|
275
|
-
|
tests/test_sample/test_base.py
DELETED
|
@@ -1,164 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Base class testing - SampleBase
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
import pytest
|
|
6
|
-
import torch
|
|
7
|
-
import numpy as np
|
|
8
|
-
|
|
9
|
-
from pathlib import Path
|
|
10
|
-
from ocf_data_sampler.sample.base import (
|
|
11
|
-
SampleBase,
|
|
12
|
-
batch_to_tensor,
|
|
13
|
-
copy_batch_to_device
|
|
14
|
-
)
|
|
15
|
-
|
|
16
|
-
class TestSample(SampleBase):
    """
    SampleBase for testing purposes
    Minimal implementations - abstract methods
    """

    # The "Test" prefix would otherwise make pytest try to collect this helper as a test
    # class. Because it defines __init__, that attempt emits a PytestCollectionWarning.
    __test__ = False

    def __init__(self):
        super().__init__()
        self._data = {}

    def plot(self, **kwargs):
        """Minimal plot implementation: draws nothing and returns None."""
        return None

    def to_numpy(self) -> dict:
        """Return a dict with each value of ``_data`` wrapped in ``np.array``, same keys."""
        return {key: np.array(value) for key, value in self._data.items()}

    def save(self, path):
        """Write a fixed placeholder payload to ``path``; the sample data itself is not stored."""
        path = Path(path)
        with open(path, 'wb') as f:
            f.write(b'test_data')

    @classmethod
    def load(cls, path):
        """Return a fresh, empty instance; the file at ``path`` is not read."""
        return cls()
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def test_sample_base_initialisation():
    """A freshly constructed SampleBase subclass exposes an empty ``_data`` dict."""
    fresh = TestSample()

    data_present = hasattr(fresh, '_data')
    assert data_present, "Sample should have _data attribute"
    assert fresh._data == {}, "Sample should start with empty dict"
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def test_sample_base_save_load(tmp_path):
    """save() should create a file, and load() should return a TestSample instance."""
    original = TestSample()
    original._data['test_data'] = [1, 2, 3]

    target = tmp_path / 'test_sample.dat'
    original.save(target)
    assert target.exists()

    restored = TestSample.load(target)
    assert isinstance(restored, TestSample)
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def test_sample_base_abstract_methods():
    """Instantiating SampleBase directly must raise TypeError, since it is abstract."""
    expected_message = "Can't instantiate abstract class"
    with pytest.raises(TypeError, match=expected_message):
        SampleBase()
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def test_sample_base_to_numpy():
    """ Test the to_numpy functionality """
    # numpy is already imported at module level; no local re-import needed
    sample = TestSample()
    sample._data = {
        'int_data': 42,
        'list_data': [1, 2, 3]
    }
    numpy_data = sample.to_numpy()

    assert isinstance(numpy_data, dict)
    assert set(numpy_data) == {'int_data', 'list_data'}
    assert all(isinstance(value, np.ndarray) for value in numpy_data.values())
    # A Python scalar becomes a 0-d array holding the same value
    assert numpy_data['int_data'].shape == ()
    assert numpy_data['int_data'] == 42
    assert np.array_equal(numpy_data['list_data'], np.array([1, 2, 3]))
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def test_batch_to_tensor_nested():
    """Arrays inside nested dictionaries are converted to tensors too."""
    inner_values = np.array([1, 2, 3])
    converted = batch_to_tensor({'outer': {'inner': inner_values}})

    assert torch.equal(converted['outer']['inner'], torch.tensor([1, 2, 3]))
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
def test_batch_to_tensor_mixed_types():
    """Arrays become tensors while strings pass through unchanged, at any nesting level."""
    batch = {
        'tensor_data': np.array([1, 2, 3]),
        'string_data': 'not_a_tensor',
        'nested': {
            'numbers': np.array([4, 5, 6]),
            'text': 'still_not_a_tensor',
        },
    }
    result = batch_to_tensor(batch)
    nested = result['nested']

    for converted in (result['tensor_data'], nested['numbers']):
        assert isinstance(converted, torch.Tensor)
    for untouched in (result['string_data'], nested['text']):
        assert isinstance(untouched, str)
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
def test_batch_to_tensor_different_dtypes():
    """float32, int64 and bool numpy arrays map onto the matching torch dtypes."""
    batch = {
        'float_data': np.array([1.0, 2.0, 3.0], dtype=np.float32),
        'int_data': np.array([1, 2, 3], dtype=np.int64),
        'bool_data': np.array([True, False, True], dtype=np.bool_),
    }
    result = batch_to_tensor(batch)

    assert isinstance(result['bool_data'], torch.Tensor)
    expected_dtypes = {
        'float_data': torch.float32,
        'int_data': torch.int64,
        'bool_data': torch.bool,
    }
    for key, dtype in expected_dtypes.items():
        assert result[key].dtype == dtype
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
def test_batch_to_tensor_multidimensional():
    """2-D and 3-D arrays keep their shapes (and values) through conversion."""
    matrix = np.array([[1, 2], [3, 4]])
    cube = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    result = batch_to_tensor({'matrix': matrix, 'tensor': cube})

    assert result['matrix'].shape == (2, 2)
    assert result['tensor'].shape == (2, 2, 2)
    assert torch.equal(result['matrix'], torch.tensor([[1, 2], [3, 4]]))
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
def test_copy_batch_to_device():
    """ Test moving tensors to a different device """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = {
        'tensor_data': torch.tensor([1, 2, 3]),
        'nested': {
            'matrix': torch.tensor([[1, 2], [3, 4]])
        },
        'non_tensor': 'unchanged'
    }
    moved_batch = copy_batch_to_device(batch, device)

    # Compare device *types*, not devices. A tensor moved to torch.device("cuda") reports an
    # explicit index (cuda:0), and torch.device("cuda") != torch.device("cuda:0"), so a direct
    # equality check would fail on any GPU machine.
    assert moved_batch['tensor_data'].device.type == device.type
    assert moved_batch['nested']['matrix'].device.type == device.type
    assert moved_batch['non_tensor'] == 'unchanged'  # Non-tensors should remain unchanged
|