ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (77)
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +86 -72
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/constants.py +140 -12
  5. ocf_data_sampler/load/gsp.py +6 -5
  6. ocf_data_sampler/load/load_dataset.py +5 -6
  7. ocf_data_sampler/load/nwp/nwp.py +17 -5
  8. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  9. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  10. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  11. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  12. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  13. ocf_data_sampler/load/satellite.py +27 -36
  14. ocf_data_sampler/load/site.py +11 -7
  15. ocf_data_sampler/load/utils.py +21 -16
  16. ocf_data_sampler/numpy_sample/collate.py +10 -9
  17. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  18. ocf_data_sampler/numpy_sample/gsp.py +15 -13
  19. ocf_data_sampler/numpy_sample/nwp.py +17 -23
  20. ocf_data_sampler/numpy_sample/satellite.py +17 -14
  21. ocf_data_sampler/numpy_sample/site.py +8 -7
  22. ocf_data_sampler/numpy_sample/sun_position.py +19 -25
  23. ocf_data_sampler/sample/__init__.py +0 -7
  24. ocf_data_sampler/sample/base.py +23 -44
  25. ocf_data_sampler/sample/site.py +25 -69
  26. ocf_data_sampler/sample/uk_regional.py +52 -103
  27. ocf_data_sampler/select/dropout.py +42 -27
  28. ocf_data_sampler/select/fill_time_periods.py +15 -3
  29. ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
  30. ocf_data_sampler/select/geospatial.py +63 -54
  31. ocf_data_sampler/select/location.py +16 -51
  32. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  33. ocf_data_sampler/select/select_time_slice.py +71 -58
  34. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  35. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  36. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
  37. ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
  41. ocf_data_sampler/utils.py +3 -1
  42. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
  43. ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
  44. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
  45. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
  46. scripts/refactor_site.py +62 -33
  47. utils/compute_icon_mean_stddev.py +72 -0
  48. ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
  49. ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
  50. tests/__init__.py +0 -0
  51. tests/config/test_config.py +0 -113
  52. tests/config/test_load.py +0 -7
  53. tests/config/test_save.py +0 -28
  54. tests/conftest.py +0 -286
  55. tests/load/test_load_gsp.py +0 -15
  56. tests/load/test_load_nwp.py +0 -21
  57. tests/load/test_load_satellite.py +0 -17
  58. tests/load/test_load_sites.py +0 -14
  59. tests/numpy_sample/test_collate.py +0 -21
  60. tests/numpy_sample/test_datetime_features.py +0 -37
  61. tests/numpy_sample/test_gsp.py +0 -38
  62. tests/numpy_sample/test_nwp.py +0 -52
  63. tests/numpy_sample/test_satellite.py +0 -40
  64. tests/numpy_sample/test_sun_position.py +0 -81
  65. tests/select/test_dropout.py +0 -75
  66. tests/select/test_fill_time_periods.py +0 -28
  67. tests/select/test_find_contiguous_time_periods.py +0 -202
  68. tests/select/test_location.py +0 -67
  69. tests/select/test_select_spatial_slice.py +0 -154
  70. tests/select/test_select_time_slice.py +0 -275
  71. tests/test_sample/test_base.py +0 -164
  72. tests/test_sample/test_site_sample.py +0 -195
  73. tests/test_sample/test_uk_regional_sample.py +0 -163
  74. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  75. tests/torch_datasets/test_pvnet_uk.py +0 -167
  76. tests/torch_datasets/test_site.py +0 -226
  77. tests/torch_datasets/test_validate_channels_utils.py +0 -78
@@ -1,275 +0,0 @@
1
- from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
2
-
3
- import numpy as np
4
- import pandas as pd
5
- import xarray as xr
6
- import pytest
7
-
8
-
9
- NWP_FREQ = pd.Timedelta("3h")
10
-
11
- @pytest.fixture(scope="module")
12
- def da_sat_like():
13
- """Create dummy data which looks like satellite data"""
14
- x = np.arange(-100, 100)
15
- y = np.arange(-100, 100)
16
- datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")
17
-
18
- da_sat = xr.DataArray(
19
- np.random.normal(size=(len(datetimes), len(x), len(y))),
20
- coords=dict(
21
- time_utc=(["time_utc"], datetimes),
22
- x_geostationary=(["x_geostationary"], x),
23
- y_geostationary=(["y_geostationary"], y),
24
- )
25
- )
26
- return da_sat
27
-
28
-
29
- @pytest.fixture(scope="module")
30
- def da_nwp_like():
31
- """Create dummy data which looks like NWP data"""
32
-
33
- x = np.arange(-100, 100)
34
- y = np.arange(-100, 100)
35
- datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq=NWP_FREQ)
36
- steps = pd.timedelta_range("0h", "16h", freq="1h")
37
- channels = ["t", "dswrf"]
38
-
39
- da_nwp = xr.DataArray(
40
- np.random.normal(size=(len(datetimes), len(steps), len(channels), len(x), len(y))),
41
- coords=dict(
42
- init_time_utc=(["init_time_utc"], datetimes),
43
- step=(["step"], steps),
44
- channel=(["channel"], channels),
45
- x_osgb=(["x_osgb"], x),
46
- y_osgb=(["y_osgb"], y),
47
- )
48
- )
49
- return da_nwp
50
-
51
-
52
- @pytest.mark.parametrize("t0_str", ["12:30", "12:40", "12:00"])
53
- def test_select_time_slice(da_sat_like, t0_str):
54
- """Test the basic functionality of select_time_slice"""
55
-
56
- # Slice parameters
57
- t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
58
- interval_start = pd.Timedelta(-0, "min")
59
- interval_end = pd.Timedelta(60, "min")
60
- freq = pd.Timedelta("5min")
61
-
62
- # Expect to return these timestamps from the selection
63
- expected_datetimes = pd.date_range(t0 +interval_start, t0 + interval_end, freq=freq)
64
-
65
- # Make the selection
66
- sat_sample = select_time_slice(
67
- da_sat_like,
68
- t0=t0,
69
- interval_start=interval_start,
70
- interval_end=interval_end,
71
- sample_period_duration=freq,
72
- )
73
-
74
- # Check the returned times are as expected
75
- assert (sat_sample.time_utc == expected_datetimes).all()
76
-
77
-
78
- @pytest.mark.parametrize("t0_str", ["00:00", "00:25", "11:00", "11:55"])
79
- def test_select_time_slice_out_of_bounds(da_sat_like, t0_str):
80
- """Test the behaviour of select_time_slice when the selection is out of bounds"""
81
-
82
- # Slice parameters
83
- t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
84
- interval_start = pd.Timedelta(-30, "min")
85
- interval_end = pd.Timedelta(60, "min")
86
- freq = pd.Timedelta("5min")
87
-
88
- # The data is available between these times
89
- min_time = pd.Timestamp(da_sat_like.time_utc.min().item())
90
- max_time = pd.Timestamp(da_sat_like.time_utc.max().item())
91
-
92
- # Expect to return these timestamps within the requested range
93
- expected_datetimes = pd.date_range(
94
- max(t0 + interval_start, min_time),
95
- min(t0 + interval_end, max_time),
96
- freq=freq,
97
- )
98
-
99
- # Make the partially out of bounds selection
100
- sat_sample = select_time_slice(
101
- da_sat_like,
102
- t0=t0,
103
- interval_start=interval_start,
104
- interval_end=interval_end,
105
- sample_period_duration=freq,
106
- )
107
-
108
- # Check the returned times are as expected
109
- assert (sat_sample.time_utc == expected_datetimes).all()
110
-
111
-
112
- # Check all the values before the first timestamp available in the data are NaN
113
- all_nan_space = sat_sample.isnull().all(dim=("x_geostationary", "y_geostationary"))
114
- if expected_datetimes[0] < min_time:
115
- assert all_nan_space.sel(time_utc=slice(None, min_time-freq)).all(dim="time_utc")
116
-
117
- # Check all the values after the last timestamp available in the data are NaN
118
- if expected_datetimes[-1] > max_time:
119
- assert all_nan_space.sel(time_utc=slice(max_time+freq, None)).all(dim="time_utc")
120
-
121
- # Check that none of the values between the first and last available timestamp are NaN
122
- any_nan_space = sat_sample.isnull().any(dim=("x_geostationary", "y_geostationary"))
123
- assert not any_nan_space.sel(time_utc=slice(min_time, max_time)).any(dim="time_utc")
124
-
125
-
126
- @pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
127
- def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):
128
- """Test the basic functionality of select_time_slice_nwp"""
129
-
130
- # Slice parameters
131
- t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
132
- interval_start = pd.Timedelta(-6, "h")
133
- interval_end = pd.Timedelta(3, "h")
134
- freq = pd.Timedelta("1h")
135
-
136
- # Make the selection
137
- da_slice = select_time_slice_nwp(
138
- da_nwp_like,
139
- t0,
140
- sample_period_duration=freq,
141
- interval_start=interval_start,
142
- interval_end=interval_end,
143
- dropout_timedeltas = None,
144
- dropout_frac = 0,
145
- accum_channels = [],
146
- channel_dim_name = "channel",
147
- )
148
-
149
- # Check the target-times are as expected
150
- expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
151
- assert (da_slice.target_time_utc==expected_target_times).all()
152
-
153
- # Check the init-times are as expected
154
- # - Forecast frequency is `NWP_FREQ`, and we can't have selected future init-times
155
- expected_init_times = pd.to_datetime(
156
- [t if t<t0 else t0 for t in expected_target_times]
157
- ).floor(NWP_FREQ)
158
- assert (da_slice.init_time_utc==expected_init_times).all()
159
-
160
-
161
- @pytest.mark.parametrize("dropout_hours", [1, 2, 5])
162
- def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
163
- """Test the functionality of select_time_slice_nwp with dropout"""
164
-
165
- t0 = pd.Timestamp("2024-01-02 12:00")
166
- interval_start = pd.Timedelta(-6, "h")
167
- interval_end = pd.Timedelta(3, "h")
168
- freq = pd.Timedelta("1h")
169
- dropout_timedelta = pd.Timedelta(f"-{dropout_hours}h")
170
-
171
- da_slice = select_time_slice_nwp(
172
- da_nwp_like,
173
- t0,
174
- sample_period_duration=freq,
175
- interval_start=interval_start,
176
- interval_end=interval_end,
177
- dropout_timedeltas = [dropout_timedelta],
178
- dropout_frac = 1,
179
- accum_channels = [],
180
- channel_dim_name = "channel",
181
- )
182
-
183
- # Check the target-times are as expected
184
- expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
185
- assert (da_slice.target_time_utc==expected_target_times).all()
186
-
187
- # Check the init-times are as expected considering the delay
188
- t0_delayed = t0 + dropout_timedelta
189
- expected_init_times = pd.to_datetime(
190
- [t if t<t0_delayed else t0_delayed for t in expected_target_times]
191
- ).floor(NWP_FREQ)
192
- assert (da_slice.init_time_utc==expected_init_times).all()
193
-
194
-
195
- @pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
196
- def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
197
- """Test the functionality of select_time_slice_nwp with dropout and accumulated variables"""
198
-
199
- # Slice parameters
200
- t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
201
- interval_start = pd.Timedelta(-6, "h")
202
- interval_end = pd.Timedelta(3, "h")
203
- freq = pd.Timedelta("1h")
204
- dropout_timedelta = pd.Timedelta("-2h")
205
-
206
- t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
207
-
208
- da_slice = select_time_slice_nwp(
209
- da_nwp_like,
210
- t0,
211
- sample_period_duration=freq,
212
- interval_start=interval_start,
213
- interval_end=interval_end,
214
- dropout_timedeltas=[dropout_timedelta],
215
- dropout_frac=1,
216
- accum_channels=["dswrf"],
217
- channel_dim_name="channel",
218
- )
219
-
220
- # Check the target-times are as expected
221
- expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
222
- assert (da_slice.target_time_utc==expected_target_times).all()
223
-
224
- # Check the init-times are as expected considering the delay
225
- expected_init_times = pd.to_datetime(
226
- [t if t<t0_delayed else t0_delayed for t in expected_target_times]
227
- ).floor(NWP_FREQ)
228
- assert (da_slice.init_time_utc==expected_init_times).all()
229
-
230
- # Check channels are as expected
231
- assert (da_slice.channel.values == ["t", "diff_dswrf"]).all()
232
-
233
- # Check the accumulated channel has been differenced correctly
234
-
235
- # This part of the data is pulled from the init-time: t0_delayed
236
- da_slice_accum = da_slice.sel(
237
- target_time_utc=slice(t0_delayed, None),
238
- channel="diff_dswrf"
239
- )
240
-
241
- # Get the original data for the t0_delayed init-time, and diff it along steps
242
- # then select the steps which are expected to be used in the above slice
243
- da_orig_diffed = (
244
- da_nwp_like.sel(
245
- init_time_utc=t0_delayed,
246
- channel="dswrf",
247
- ).diff(dim="step", label="lower")
248
- .sel(step=slice(t0-t0_delayed + interval_start, t0-t0_delayed + interval_end))
249
- )
250
-
251
- # Check the values are the same
252
- assert (da_slice_accum.values == da_orig_diffed.values).all()
253
-
254
- # Check the non-accumulated channel has not been differenced
255
-
256
- # This part of the data is pulled from the init-time: t0_delayed
257
- da_slice_nonaccum = da_slice.sel(
258
- target_time_utc=slice(t0_delayed, None),
259
- channel="t"
260
- )
261
-
262
- # Get the original data for the t0_delayed init-time, and select the steps which are expected
263
- # to be used in the above slice
264
- da_orig = (
265
- da_nwp_like.sel(
266
- init_time_utc=t0_delayed,
267
- channel="t",
268
- )
269
- .sel(step=slice(t0-t0_delayed + interval_start, t0-t0_delayed + interval_end))
270
- )
271
-
272
- # Check the values are the same
273
- assert (da_slice_nonaccum.values == da_orig.values).all()
274
-
275
-
@@ -1,164 +0,0 @@
1
- """
2
- Base class testing - SampleBase
3
- """
4
-
5
- import pytest
6
- import torch
7
- import numpy as np
8
-
9
- from pathlib import Path
10
- from ocf_data_sampler.sample.base import (
11
- SampleBase,
12
- batch_to_tensor,
13
- copy_batch_to_device
14
- )
15
-
16
- class TestSample(SampleBase):
17
- """
18
- SampleBase for testing purposes
19
- Minimal implementations - abstract methods
20
- """
21
-
22
- def __init__(self):
23
- super().__init__()
24
- self._data = {}
25
-
26
- def plot(self, **kwargs):
27
- """ Minimal plot implementation """
28
- return None
29
-
30
- def to_numpy(self) -> None:
31
- """ Standard implementation """
32
- return {key: np.array(value) for key, value in self._data.items()}
33
-
34
- def save(self, path):
35
- """ Minimal save implementation """
36
- path = Path(path)
37
- with open(path, 'wb') as f:
38
- f.write(b'test_data')
39
-
40
- @classmethod
41
- def load(cls, path):
42
- """ Minimal load implementation """
43
- path = Path(path)
44
- instance = cls()
45
- return instance
46
-
47
-
48
- def test_sample_base_initialisation():
49
- """ Initialisation of SampleBase subclass """
50
-
51
- sample = TestSample()
52
- assert hasattr(sample, '_data'), "Sample should have _data attribute"
53
- assert sample._data == {}, "Sample should start with empty dict"
54
-
55
-
56
- def test_sample_base_save_load(tmp_path):
57
- """ Test basic save and load functionality """
58
-
59
- sample = TestSample()
60
- sample._data['test_data'] = [1, 2, 3]
61
-
62
- save_path = tmp_path / 'test_sample.dat'
63
- sample.save(save_path)
64
- assert save_path.exists()
65
-
66
- loaded_sample = TestSample.load(save_path)
67
- assert isinstance(loaded_sample, TestSample)
68
-
69
-
70
- def test_sample_base_abstract_methods():
71
- """ Test abstract method enforcement """
72
-
73
- with pytest.raises(TypeError, match="Can't instantiate abstract class"):
74
- SampleBase()
75
-
76
-
77
- def test_sample_base_to_numpy():
78
- """ Test the to_numpy functionality """
79
- import numpy as np
80
-
81
- sample = TestSample()
82
- sample._data = {
83
- 'int_data': 42,
84
- 'list_data': [1, 2, 3]
85
- }
86
- numpy_data = sample.to_numpy()
87
-
88
- assert isinstance(numpy_data, dict)
89
- assert all(isinstance(value, np.ndarray) for value in numpy_data.values())
90
- assert np.array_equal(numpy_data['list_data'], np.array([1, 2, 3]))
91
-
92
-
93
- def test_batch_to_tensor_nested():
94
- """ Test nested dictionary conversion """
95
- batch = {
96
- 'outer': {
97
- 'inner': np.array([1, 2, 3])
98
- }
99
- }
100
- tensor_batch = batch_to_tensor(batch)
101
-
102
- assert torch.equal(tensor_batch['outer']['inner'], torch.tensor([1, 2, 3]))
103
-
104
-
105
- def test_batch_to_tensor_mixed_types():
106
- """ Test handling of mixed data types """
107
- batch = {
108
- 'tensor_data': np.array([1, 2, 3]),
109
- 'string_data': 'not_a_tensor',
110
- 'nested': {
111
- 'numbers': np.array([4, 5, 6]),
112
- 'text': 'still_not_a_tensor'
113
- }
114
- }
115
- tensor_batch = batch_to_tensor(batch)
116
-
117
- assert isinstance(tensor_batch['tensor_data'], torch.Tensor)
118
- assert isinstance(tensor_batch['string_data'], str)
119
- assert isinstance(tensor_batch['nested']['numbers'], torch.Tensor)
120
- assert isinstance(tensor_batch['nested']['text'], str)
121
-
122
-
123
- def test_batch_to_tensor_different_dtypes():
124
- """ Test conversion of arrays with different dtypes """
125
- batch = {
126
- 'float_data': np.array([1.0, 2.0, 3.0], dtype=np.float32),
127
- 'int_data': np.array([1, 2, 3], dtype=np.int64),
128
- 'bool_data': np.array([True, False, True], dtype=np.bool_)
129
- }
130
- tensor_batch = batch_to_tensor(batch)
131
-
132
- assert isinstance(tensor_batch['bool_data'], torch.Tensor)
133
- assert tensor_batch['float_data'].dtype == torch.float32
134
- assert tensor_batch['int_data'].dtype == torch.int64
135
- assert tensor_batch['bool_data'].dtype == torch.bool
136
-
137
-
138
- def test_batch_to_tensor_multidimensional():
139
- """ Test conversion of multidimensional arrays """
140
- batch = {
141
- 'matrix': np.array([[1, 2], [3, 4]]),
142
- 'tensor': np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
143
- }
144
- tensor_batch = batch_to_tensor(batch)
145
-
146
- assert tensor_batch['matrix'].shape == (2, 2)
147
- assert tensor_batch['tensor'].shape == (2, 2, 2)
148
- assert torch.equal(tensor_batch['matrix'], torch.tensor([[1, 2], [3, 4]]))
149
-
150
-
151
- def test_copy_batch_to_device():
152
- """ Test moving tensors to a different device """
153
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154
- batch = {
155
- 'tensor_data': torch.tensor([1, 2, 3]),
156
- 'nested': {
157
- 'matrix': torch.tensor([[1, 2], [3, 4]])
158
- },
159
- 'non_tensor': 'unchanged'
160
- }
161
- moved_batch = copy_batch_to_device(batch, device)
162
- assert moved_batch['tensor_data'].device == device
163
- assert moved_batch['nested']['matrix'].device == device
164
- assert moved_batch['non_tensor'] == 'unchanged' # Non-tensors should remain unchanged
@@ -1,195 +0,0 @@
1
- """
2
- Site class testing - SiteSample
3
- """
4
-
5
- import pytest
6
- import numpy as np
7
- import xarray as xr
8
- import pandas as pd
9
-
10
- from pathlib import Path
11
- from xarray import Dataset
12
-
13
- from ocf_data_sampler.sample.site import SiteSample
14
-
15
-
16
- @pytest.fixture
17
- def sample_data():
18
- """ Fixture creation with sample data """
19
-
20
- # Time periods specified
21
- init_time = pd.Timestamp("2023-01-01 00:00")
22
- target_times = pd.date_range("2023-01-01 00:00", periods=4, freq="1h")
23
- sat_times = pd.date_range("2023-01-01 00:00", periods=7, freq="5min")
24
- site_times = pd.date_range("2023-01-01 00:00", periods=4, freq="30min")
25
-
26
- # Defined steps for NWP data
27
- steps = [(t - init_time) for t in target_times]
28
-
29
- # Create the sample dataset
30
- return Dataset(
31
- data_vars={
32
- 'nwp-ukv': (
33
- ["nwp-ukv__target_time_utc", "nwp-ukv__channel",
34
- "nwp-ukv__y_osgb", "nwp-ukv__x_osgb"],
35
- np.random.rand(4, 1, 2, 2)
36
- ),
37
- 'satellite': (
38
- ["satellite__time_utc", "satellite__channel",
39
- "satellite__y_geostationary", "satellite__x_geostationary"],
40
- np.random.rand(7, 1, 2, 2)
41
- ),
42
- 'site': (["site__time_utc"], np.random.rand(4))
43
- },
44
- coords={
45
- # NWP coords
46
- 'nwp-ukv__target_time_utc': target_times,
47
- 'nwp-ukv__channel': ['dswrf'],
48
- 'nwp-ukv__y_osgb': [0, 1],
49
- 'nwp-ukv__x_osgb': [0, 1],
50
- 'nwp-ukv__init_time_utc': init_time,
51
- 'nwp-ukv__step': ('nwp-ukv__target_time_utc', steps),
52
-
53
- # Sat coords
54
- 'satellite__time_utc': sat_times,
55
- 'satellite__channel': ['vis'],
56
- 'satellite__y_geostationary': [0, 1],
57
- 'satellite__x_geostationary': [0, 1],
58
-
59
- # Site coords
60
- 'site__time_utc': site_times,
61
- 'site__capacity_kwp': 1000.0,
62
- 'site__site_id': 1,
63
- 'site__longitude': -3.5,
64
- 'site__latitude': 51.5,
65
-
66
- # Site features as coords
67
- 'site__solar_azimuth': ('site__time_utc', np.random.rand(4)),
68
- 'site__solar_elevation': ('site__time_utc', np.random.rand(4)),
69
- 'site__date_cos': ('site__time_utc', np.random.rand(4)),
70
- 'site__date_sin': ('site__time_utc', np.random.rand(4)),
71
- 'site__time_cos': ('site__time_utc', np.random.rand(4)),
72
- 'site__time_sin': ('site__time_utc', np.random.rand(4))
73
- }
74
- )
75
-
76
-
77
- def test_site_sample_init():
78
- """ Test initialisation """
79
- sample = SiteSample()
80
- assert isinstance(sample._data, dict)
81
- assert len(sample._data) == 0
82
-
83
-
84
- def test_site_sample_with_data(sample_data):
85
- """ Testing of defined sample with actual data """
86
- sample = SiteSample()
87
- sample._data = sample_data
88
-
89
- # Assert data structure
90
- assert isinstance(sample._data, Dataset)
91
-
92
- # Assert dimensions / shapes
93
- expected_dims = {
94
- "satellite__x_geostationary",
95
- "site__time_utc",
96
- "nwp-ukv__target_time_utc",
97
- "nwp-ukv__x_osgb",
98
- "satellite__channel",
99
- "satellite__y_geostationary",
100
- "satellite__time_utc",
101
- "nwp-ukv__channel",
102
- "nwp-ukv__y_osgb",
103
- }
104
- assert set(sample._data.dims) == expected_dims
105
- assert sample._data["satellite"].values.shape == (7, 1, 2, 2)
106
- assert sample._data["nwp-ukv"].values.shape == (4, 1, 2, 2)
107
- assert sample._data["site"].values.shape == (4,)
108
-
109
-
110
- def test_save_load(tmp_path, sample_data):
111
- """ Save and load functionality """
112
- sample = SiteSample()
113
- sample._data = sample_data
114
- filepath = tmp_path / "test_sample.nc"
115
- sample.save(filepath)
116
-
117
- # Assert file exists and has content
118
- assert filepath.exists()
119
- assert filepath.stat().st_size > 0
120
-
121
- # Load and verify
122
- loaded = SiteSample.load(filepath)
123
- assert isinstance(loaded, SiteSample)
124
- assert isinstance(loaded._data, Dataset)
125
-
126
- # Compare original / loaded data
127
- xr.testing.assert_identical(sample._data, loaded._data)
128
-
129
-
130
- def test_invalid_save_format(sample_data):
131
- """ Saving with invalid format """
132
- sample = SiteSample()
133
- sample._data = sample_data
134
- with pytest.raises(ValueError, match="Only .nc format is supported"):
135
- sample.save("invalid.txt")
136
-
137
-
138
- def test_invalid_load_format():
139
- """ Loading with invalid format """
140
- with pytest.raises(ValueError, match="Only .nc format is supported"):
141
- SiteSample.load("invalid.txt")
142
-
143
-
144
- def test_invalid_data_type():
145
- """ Handling of invalid data types """
146
- sample = SiteSample()
147
- sample._data = {"invalid": "data"}
148
-
149
- with pytest.raises(TypeError, match="Data must be xarray Dataset"):
150
- sample.to_numpy()
151
-
152
- with pytest.raises(TypeError, match="Data must be xarray Dataset for saving"):
153
- sample.save("test.nc")
154
-
155
-
156
- def test_to_numpy(sample_data):
157
- """ To numpy conversion """
158
- sample = SiteSample()
159
- sample._data = sample_data
160
- numpy_data = sample.to_numpy()
161
-
162
- # Assert structure
163
- assert isinstance(numpy_data, dict)
164
- assert 'site' in numpy_data
165
- assert 'nwp' in numpy_data
166
-
167
- # Check site - numpy array instead of dict
168
- site_data = numpy_data['site']
169
- assert isinstance(site_data, np.ndarray)
170
- assert site_data.ndim == 1
171
- assert len(site_data) == 4
172
- assert np.all(site_data >= 0) and np.all(site_data <= 1)
173
-
174
- # Check NWP
175
- assert 'ukv' in numpy_data['nwp']
176
- nwp_data = numpy_data['nwp']['ukv']
177
- assert 'nwp' in nwp_data
178
- assert nwp_data['nwp'].shape == (4, 1, 2, 2)
179
-
180
-
181
- def test_data_consistency(sample_data):
182
- """ Consistency of data across operations """
183
- sample = SiteSample()
184
- sample._data = sample_data
185
- numpy_data = sample.to_numpy()
186
-
187
- # Assert components remain consistent after conversion above
188
- assert numpy_data['nwp']['ukv']['nwp'].shape == (4, 1, 2, 2)
189
- assert 'site' in numpy_data
190
-
191
- # Update site data checks to expect numpy array
192
- assert isinstance(numpy_data['site'], np.ndarray)
193
- assert numpy_data['site'].shape == (4,)
194
- assert np.all(numpy_data['site'] >= 0)
195
- assert np.all(numpy_data['site'] <= 1)