ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. More details are linked from the original diff page.

Files changed (77):
  1. ocf_data_sampler/config/load.py +3 -3
  2. ocf_data_sampler/config/model.py +86 -72
  3. ocf_data_sampler/config/save.py +5 -4
  4. ocf_data_sampler/constants.py +140 -12
  5. ocf_data_sampler/load/gsp.py +6 -5
  6. ocf_data_sampler/load/load_dataset.py +5 -6
  7. ocf_data_sampler/load/nwp/nwp.py +17 -5
  8. ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
  9. ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
  10. ocf_data_sampler/load/nwp/providers/icon.py +46 -0
  11. ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
  12. ocf_data_sampler/load/nwp/providers/utils.py +3 -1
  13. ocf_data_sampler/load/satellite.py +27 -36
  14. ocf_data_sampler/load/site.py +11 -7
  15. ocf_data_sampler/load/utils.py +21 -16
  16. ocf_data_sampler/numpy_sample/collate.py +10 -9
  17. ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
  18. ocf_data_sampler/numpy_sample/gsp.py +15 -13
  19. ocf_data_sampler/numpy_sample/nwp.py +17 -23
  20. ocf_data_sampler/numpy_sample/satellite.py +17 -14
  21. ocf_data_sampler/numpy_sample/site.py +8 -7
  22. ocf_data_sampler/numpy_sample/sun_position.py +19 -25
  23. ocf_data_sampler/sample/__init__.py +0 -7
  24. ocf_data_sampler/sample/base.py +23 -44
  25. ocf_data_sampler/sample/site.py +25 -69
  26. ocf_data_sampler/sample/uk_regional.py +52 -103
  27. ocf_data_sampler/select/dropout.py +42 -27
  28. ocf_data_sampler/select/fill_time_periods.py +15 -3
  29. ocf_data_sampler/select/find_contiguous_time_periods.py +87 -75
  30. ocf_data_sampler/select/geospatial.py +63 -54
  31. ocf_data_sampler/select/location.py +16 -51
  32. ocf_data_sampler/select/select_spatial_slice.py +105 -89
  33. ocf_data_sampler/select/select_time_slice.py +71 -58
  34. ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
  35. ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
  36. ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +126 -118
  37. ocf_data_sampler/torch_datasets/datasets/site.py +135 -101
  38. ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
  39. ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
  40. ocf_data_sampler/torch_datasets/utils/validate_channels.py +23 -19
  41. ocf_data_sampler/utils.py +3 -1
  42. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/METADATA +7 -18
  43. ocf_data_sampler-0.1.16.dist-info/RECORD +56 -0
  44. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/WHEEL +1 -1
  45. {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.16.dist-info}/top_level.txt +1 -1
  46. scripts/refactor_site.py +62 -33
  47. utils/compute_icon_mean_stddev.py +72 -0
  48. ocf_data_sampler-0.1.10.dist-info/LICENSE +0 -21
  49. ocf_data_sampler-0.1.10.dist-info/RECORD +0 -82
  50. tests/__init__.py +0 -0
  51. tests/config/test_config.py +0 -113
  52. tests/config/test_load.py +0 -7
  53. tests/config/test_save.py +0 -28
  54. tests/conftest.py +0 -286
  55. tests/load/test_load_gsp.py +0 -15
  56. tests/load/test_load_nwp.py +0 -21
  57. tests/load/test_load_satellite.py +0 -17
  58. tests/load/test_load_sites.py +0 -14
  59. tests/numpy_sample/test_collate.py +0 -21
  60. tests/numpy_sample/test_datetime_features.py +0 -37
  61. tests/numpy_sample/test_gsp.py +0 -38
  62. tests/numpy_sample/test_nwp.py +0 -52
  63. tests/numpy_sample/test_satellite.py +0 -40
  64. tests/numpy_sample/test_sun_position.py +0 -81
  65. tests/select/test_dropout.py +0 -75
  66. tests/select/test_fill_time_periods.py +0 -28
  67. tests/select/test_find_contiguous_time_periods.py +0 -202
  68. tests/select/test_location.py +0 -67
  69. tests/select/test_select_spatial_slice.py +0 -154
  70. tests/select/test_select_time_slice.py +0 -275
  71. tests/test_sample/test_base.py +0 -164
  72. tests/test_sample/test_site_sample.py +0 -195
  73. tests/test_sample/test_uk_regional_sample.py +0 -163
  74. tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
  75. tests/torch_datasets/test_pvnet_uk.py +0 -167
  76. tests/torch_datasets/test_site.py +0 -226
  77. tests/torch_datasets/test_validate_channels_utils.py +0 -78
@@ -1,163 +0,0 @@
1
- """
2
- UK Regional class testing - UKRegionalSample
3
- """
4
-
5
- import pytest
6
- import numpy as np
7
- import torch
8
- import tempfile
9
-
10
- from ocf_data_sampler.numpy_sample import (
11
- GSPSampleKey,
12
- SatelliteSampleKey,
13
- NWPSampleKey
14
- )
15
-
16
- from ocf_data_sampler.sample.uk_regional import UKRegionalSample
17
-
18
-
19
- # Fixture define
20
- @pytest.fixture
21
- def pvnet_config_filename(tmp_path):
22
- """ Minimal config file - testing """
23
- config_content = """
24
- input_data:
25
- gsp:
26
- zarr_path: ""
27
- time_resolution_minutes: 30
28
- interval_start_minutes: -180
29
- interval_end_minutes: 0
30
- nwp:
31
- ukv:
32
- zarr_path: ""
33
- image_size_pixels_height: 64
34
- image_size_pixels_width: 64
35
- time_resolution_minutes: 60
36
- interval_start_minutes: -180
37
- interval_end_minutes: 0
38
- channels: ["t", "dswrf"]
39
- provider: "ukv"
40
- satellite:
41
- zarr_path: ""
42
- image_size_pixels_height: 64
43
- image_size_pixels_width: 64
44
- time_resolution_minutes: 30
45
- interval_start_minutes: -180
46
- interval_end_minutes: 0
47
- channels: ["HRV"]
48
- """
49
- config_file = tmp_path / "test_config.yaml"
50
- config_file.write_text(config_content)
51
- return str(config_file)
52
-
53
-
54
- def create_test_data():
55
- """ Synthetic data generation """
56
-
57
- # Field / spatial coordinates
58
- nwp_data = {
59
- 'nwp': np.random.rand(4, 1, 2, 2),
60
- 'x': np.array([1, 2]),
61
- 'y': np.array([1, 2]),
62
- NWPSampleKey.channel_names: ['test_channel']
63
- }
64
-
65
- return {
66
- 'nwp': {
67
- 'ukv': nwp_data
68
- },
69
- GSPSampleKey.gsp: np.random.rand(7),
70
- SatelliteSampleKey.satellite_actual: np.random.rand(7, 1, 2, 2),
71
- GSPSampleKey.solar_azimuth: np.random.rand(7),
72
- GSPSampleKey.solar_elevation: np.random.rand(7)
73
- }
74
-
75
-
76
- # UKRegionalSample testing
77
- def test_sample_init():
78
- """ Initialisation """
79
- sample = UKRegionalSample()
80
- assert hasattr(sample, '_data'), "Sample should have _data attribute"
81
- assert isinstance(sample._data, dict)
82
- assert len(sample._data) == 0
83
-
84
-
85
- def test_sample_save_load():
86
- sample = UKRegionalSample()
87
- sample._data = create_test_data()
88
-
89
- with tempfile.NamedTemporaryFile(suffix='.pt') as tf:
90
- sample.save(tf.name)
91
- loaded = UKRegionalSample.load(tf.name)
92
-
93
- assert set(loaded._data.keys()) == set(sample._data.keys())
94
- assert isinstance(loaded._data['nwp'], dict)
95
- assert 'ukv' in loaded._data['nwp']
96
-
97
- assert loaded._data[GSPSampleKey.gsp].shape == (7,)
98
- assert loaded._data[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
99
- assert loaded._data[GSPSampleKey.solar_azimuth].shape == (7,)
100
- assert loaded._data[GSPSampleKey.solar_elevation].shape == (7,)
101
-
102
- np.testing.assert_array_almost_equal(
103
- loaded._data[GSPSampleKey.gsp],
104
- sample._data[GSPSampleKey.gsp]
105
- )
106
-
107
-
108
- def test_save_unsupported_format():
109
- """ Test saving - unsupported file format """
110
- sample = UKRegionalSample()
111
- sample._data = create_test_data()
112
-
113
- with tempfile.NamedTemporaryFile(suffix='.npz') as tf:
114
- with pytest.raises(ValueError, match="Only .pt format is supported"):
115
- sample.save(tf.name)
116
-
117
-
118
- def test_load_unsupported_format():
119
- """ Test loading - unsupported file format """
120
-
121
- with tempfile.NamedTemporaryFile(suffix='.npz') as tf:
122
- with pytest.raises(ValueError, match="Only .pt format is supported"):
123
- UKRegionalSample.load(tf.name)
124
-
125
-
126
- def test_load_corrupted_file():
127
- """ Test loading - corrupted / empty file """
128
-
129
- with tempfile.NamedTemporaryFile(suffix='.pt') as tf:
130
- with open(tf.name, 'wb') as f:
131
- f.write(b'corrupted data')
132
-
133
- with pytest.raises(Exception):
134
- UKRegionalSample.load(tf.name)
135
-
136
-
137
- def test_to_numpy():
138
- """ To numpy conversion check """
139
- sample = UKRegionalSample()
140
- sample._data = {
141
- 'nwp': {
142
- 'ukv': {
143
- 'nwp': np.random.rand(4, 1, 2, 2),
144
- 'x': np.array([1, 2]),
145
- 'y': np.array([1, 2])
146
- }
147
- },
148
- GSPSampleKey.gsp: np.random.rand(7),
149
- SatelliteSampleKey.satellite_actual: np.random.rand(7, 1, 2, 2),
150
- GSPSampleKey.solar_azimuth: np.random.rand(7),
151
- GSPSampleKey.solar_elevation: np.random.rand(7)
152
- }
153
-
154
- numpy_data = sample.to_numpy()
155
-
156
- # Check returned data matches
157
- assert numpy_data == sample._data
158
- assert len(numpy_data) == len(sample._data)
159
-
160
- # Assert specific keys and types
161
- assert 'nwp' in numpy_data
162
- assert isinstance(numpy_data['nwp']['ukv']['nwp'], np.ndarray)
163
- assert numpy_data[GSPSampleKey.gsp].shape == (7,)
@@ -1,40 +0,0 @@
1
- import numpy as np
2
-
3
- from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
4
- merge_dicts,
5
- fill_nans_in_arrays,
6
- )
7
-
8
- def test_merge_dicts():
9
- """Test merge_dicts function"""
10
- dict1 = {"a": 1, "b": 2}
11
- dict2 = {"c": 3, "d": 4}
12
- dict3 = {"e": 5}
13
-
14
- result = merge_dicts([dict1, dict2, dict3])
15
- assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
16
-
17
- # Test key overwriting
18
- dict4 = {"a": 10, "f": 6}
19
- result = merge_dicts([dict1, dict4])
20
- assert result["a"] == 10
21
-
22
-
23
- def test_fill_nans_in_arrays():
24
- """Test the fill_nans_in_arrays function"""
25
- array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
26
- nested_dict = {
27
- "array1": array_with_nans,
28
- "nested": {
29
- "array2": np.array([np.nan, 2.0, np.nan, 4.0])
30
- },
31
- "string_key": "not_an_array"
32
- }
33
-
34
- result = fill_nans_in_arrays(nested_dict)
35
-
36
- assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
37
- assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
38
- assert result["string_key"] == "not_an_array"
39
-
40
-
@@ -1,167 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- import xarray as xr
4
- import dask.array
5
-
6
- from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
7
- from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import (
8
- PVNetUKRegionalDataset,
9
- PVNetUKConcurrentDataset,
10
- process_and_combine_datasets,
11
- compute,
12
- )
13
- from ocf_data_sampler.select.location import Location
14
-
15
- def test_process_and_combine_datasets(pvnet_config_filename):
16
-
17
- # Load in config for function and define location
18
- config = load_yaml_configuration(pvnet_config_filename)
19
- t0 = pd.Timestamp("2024-01-01 00:00")
20
- location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
21
-
22
- nwp_data = xr.DataArray(
23
- np.random.rand(4, 2, 2, 2),
24
- dims=["time_utc", "channel", "y", "x"],
25
- coords={
26
- "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
27
- "channel": ["t", "dswrf"],
28
- "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
29
- "init_time_utc": pd.Timestamp("2024-01-01 00:00")
30
- }
31
- )
32
-
33
- sat_data = xr.DataArray(
34
- np.random.rand(7, 1, 2, 2),
35
- dims=["time_utc", "channel", "y", "x"],
36
- coords={
37
- "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
38
- "channel": ["HRV"],
39
- "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
40
- "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
41
- }
42
- )
43
-
44
- # Combine as dict
45
- dataset_dict = {
46
- "nwp": {"ukv": nwp_data},
47
- "sat": sat_data
48
- }
49
-
50
- # Call relevant function
51
- sample = process_and_combine_datasets(dataset_dict, config, t0, location)
52
-
53
- # Assert result is dict - check and validate
54
- assert isinstance(sample, dict)
55
- assert "nwp" in sample
56
- assert sample["satellite_actual"].shape == (7, 1, 2, 2)
57
- assert sample["nwp"]["ukv"]["nwp"].shape == (4, 2, 2, 2)
58
- assert "gsp_id" in sample
59
-
60
-
61
- def test_compute():
62
- """Test compute function with dask array"""
63
- da_dask = xr.DataArray(dask.array.random.random((5, 5)))
64
-
65
- # Create a nested dictionary with dask array
66
- lazy_data_dict = {
67
- "array1": da_dask,
68
- "nested": {
69
- "array2": da_dask
70
- }
71
- }
72
-
73
- computed_data_dict = compute(lazy_data_dict)
74
-
75
- # Assert that the result is no longer lazy
76
- assert isinstance(computed_data_dict["array1"].data, np.ndarray)
77
- assert isinstance(computed_data_dict["nested"]["array2"].data, np.ndarray)
78
-
79
-
80
- def test_pvnet_uk_regional_dataset(pvnet_config_filename):
81
-
82
- # Create dataset object
83
- dataset = PVNetUKRegionalDataset(pvnet_config_filename)
84
-
85
- assert len(dataset.locations) == 317 # Number of regional GSPs
86
- # NB. I have not checked the value (39 below) is in fact correct
87
- assert len(dataset.valid_t0_times) == 39
88
- assert len(dataset) == 317*39
89
-
90
- # Generate a sample
91
- sample = dataset[0]
92
-
93
- assert isinstance(sample, dict)
94
-
95
- for key in [
96
- "nwp", "satellite_actual", "gsp",
97
- "gsp_solar_azimuth", "gsp_solar_elevation",
98
- ]:
99
- assert key in sample
100
-
101
- for nwp_source in ["ukv"]:
102
- assert nwp_source in sample["nwp"]
103
-
104
- # Check the shape of the data is correct
105
- # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
106
- assert sample["satellite_actual"].shape == (7, 1, 2, 2)
107
- # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
108
- assert sample["nwp"]["ukv"]["nwp"].shape == (4, 1, 2, 2)
109
- # 3 hours of 30 minute data (inclusive)
110
- assert sample["gsp"].shape == (7,)
111
- # Solar angles have same shape as GSP data
112
- assert sample["gsp_solar_azimuth"].shape == (7,)
113
- assert sample["gsp_solar_elevation"].shape == (7,)
114
-
115
-
116
- def test_pvnet_no_gsp(tmp_path, pvnet_config_filename):
117
-
118
- # Create new config without GSP inputs
119
- config = load_yaml_configuration(pvnet_config_filename)
120
- config.input_data.gsp.zarr_path = ''
121
- new_config_path = tmp_path / "pvnet_config_no_gsp.yaml"
122
- save_yaml_configuration(config, new_config_path)
123
-
124
- # Create dataset object
125
- dataset = PVNetUKRegionalDataset(new_config_path)
126
-
127
- # Generate a sample
128
- _ = dataset[0]
129
-
130
-
131
- def test_pvnet_uk_concurrent_dataset(pvnet_config_filename):
132
-
133
- # Create dataset object using a limited set of GSPs for test
134
- gsp_ids = [1,2,3]
135
- num_gsps = len(gsp_ids)
136
-
137
- dataset = PVNetUKConcurrentDataset(pvnet_config_filename, gsp_ids=gsp_ids)
138
-
139
- assert len(dataset.locations) == num_gsps # Number of regional GSPs
140
- # NB. I have not checked the value (39 below) is in fact correct
141
- assert len(dataset.valid_t0_times) == 39
142
- assert len(dataset) == 39
143
-
144
- # Generate a sample
145
- sample = dataset[0]
146
-
147
- assert isinstance(sample, dict)
148
-
149
- for key in [
150
- "nwp", "satellite_actual", "gsp",
151
- "gsp_solar_azimuth", "gsp_solar_elevation",
152
- ]:
153
- assert key in sample
154
-
155
- for nwp_source in ["ukv"]:
156
- assert nwp_source in sample["nwp"]
157
-
158
- # Check the shape of the data is correct
159
- # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
160
- assert sample["satellite_actual"].shape == (num_gsps, 7, 1, 2, 2)
161
- # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
162
- assert sample["nwp"]["ukv"]["nwp"].shape == (num_gsps, 4, 1, 2, 2)
163
- # 3 hours of 30 minute data (inclusive)
164
- assert sample["gsp"].shape == (num_gsps, 7,)
165
- # Solar angles have same shape as GSP data
166
- assert sample["gsp_solar_azimuth"].shape == (num_gsps, 7,)
167
- assert sample["gsp_solar_elevation"].shape == (num_gsps, 7,)
@@ -1,226 +0,0 @@
1
- import pytest
2
-
3
- import numpy as np
4
- import pandas as pd
5
- import xarray as xr
6
-
7
- from torch.utils.data import DataLoader
8
-
9
- from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
10
- from ocf_data_sampler.torch_datasets.datasets.site import (
11
- SitesDataset, convert_from_dataset_to_dict_datasets, coarsen_data
12
- )
13
-
14
-
15
-
16
- @pytest.fixture()
17
- def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites):
18
-
19
- # adjust config to point to the zarr file
20
- config = load_yaml_configuration(config_filename)
21
- config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
22
- config.input_data.satellite.zarr_path = sat_zarr_path
23
- config.input_data.site = data_sites
24
- config.input_data.gsp = None
25
-
26
- filename = f"{tmp_path}/configuration.yaml"
27
- save_yaml_configuration(config, filename)
28
- yield filename
29
-
30
-
31
- @pytest.fixture()
32
- def sites_dataset(site_config_filename):
33
- return SitesDataset(site_config_filename)
34
-
35
-
36
- def test_site(tmp_path, site_config_filename):
37
-
38
- # Create dataset object
39
- dataset = SitesDataset(site_config_filename)
40
-
41
- assert len(dataset) == 10 * 41
42
- # TODO check 41
43
-
44
- # Generate a sample
45
- sample = dataset[0]
46
-
47
- assert isinstance(sample, xr.Dataset)
48
-
49
- # Expected dimensions and data variables
50
- expected_dims = {
51
- "satellite__x_geostationary",
52
- "site__time_utc",
53
- "nwp-ukv__target_time_utc",
54
- "nwp-ukv__x_osgb",
55
- "satellite__channel",
56
- "satellite__y_geostationary",
57
- "satellite__time_utc",
58
- "nwp-ukv__channel",
59
- "nwp-ukv__y_osgb",
60
- }
61
-
62
- expected_coords_subset = {
63
- "site__solar_azimuth",
64
- "site__solar_elevation",
65
- "site__date_cos",
66
- "site__time_cos",
67
- "site__time_sin",
68
- "site__date_sin",
69
- }
70
-
71
- expected_data_vars = {"nwp-ukv", "satellite", "site"}
72
-
73
-
74
- sample.to_netcdf(f"{tmp_path}/sample.nc")
75
- sample = xr.open_dataset(f"{tmp_path}/sample.nc")
76
-
77
- # Check dimensions
78
- assert (
79
- set(sample.dims) == expected_dims
80
- ), f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
81
- # Check data variables
82
- assert (
83
- set(sample.data_vars) == expected_data_vars
84
- ), f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"
85
-
86
- for coords in expected_coords_subset:
87
- assert coords in sample.coords
88
-
89
- # check the shape of the data is correct
90
- # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
91
- assert sample["satellite"].values.shape == (7, 1, 2, 2)
92
- # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
93
- assert sample["nwp-ukv"].values.shape == (4, 1, 2, 2)
94
- # 1.5 hours of 30 minute data (inclusive)
95
- assert sample["site"].values.shape == (4,)
96
-
97
-
98
- def test_site_time_filter_start(site_config_filename):
99
-
100
- # Create dataset object
101
- dataset = SitesDataset(site_config_filename, start_time="2024-01-01")
102
-
103
- assert len(dataset) == 0
104
-
105
-
106
- def test_site_time_filter_end(site_config_filename):
107
-
108
- # Create dataset object
109
- dataset = SitesDataset(site_config_filename, end_time="2000-01-01")
110
-
111
- assert len(dataset) == 0
112
-
113
-
114
- def test_site_get_sample(sites_dataset):
115
- sample = sites_dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1)
116
-
117
-
118
- def test_convert_from_dataset_to_dict_datasets(sites_dataset):
119
-
120
- # Generate sample
121
- sample_xr = sites_dataset[0]
122
-
123
- sample = convert_from_dataset_to_dict_datasets(sample_xr)
124
-
125
- assert isinstance(sample, dict)
126
-
127
- for key in ["nwp", "satellite", "site"]:
128
- assert key in sample
129
-
130
-
131
- def test_site_dataset_with_dataloader(sites_dataset):
132
-
133
- expected_coods = {
134
- "site__solar_azimuth",
135
- "site__solar_elevation",
136
- "site__date_cos",
137
- "site__time_cos",
138
- "site__time_sin",
139
- "site__date_sin",
140
- }
141
-
142
- dataloader_kwargs = dict(
143
- shuffle=False,
144
- batch_size=None,
145
- sampler=None,
146
- batch_sampler=None,
147
- num_workers=1,
148
- collate_fn=None,
149
- pin_memory=False, # Only using CPU to prepare samples so pinning is not beneficial
150
- drop_last=False,
151
- timeout=0,
152
- worker_init_fn=None,
153
- prefetch_factor=1,
154
- persistent_workers=False, # Not needed since we only enter the dataloader loop once
155
- )
156
-
157
- dataloader = DataLoader(sites_dataset, collate_fn=None, batch_size=None)
158
-
159
- sample = next(iter(dataloader))
160
-
161
- # check that expected_dims is in the sample
162
- for key in expected_coods:
163
- assert key in sample
164
-
165
-
166
- def test_process_and_combine_site_sample_dict(sites_dataset):
167
-
168
- # Specify minimal structure for testing
169
- raw_nwp_values = np.random.rand(4, 1, 2, 2) # Single channel
170
- fake_site_values = np.random.rand(197)
171
- site_dict = {
172
- "nwp": {
173
- "ukv": xr.DataArray(
174
- raw_nwp_values,
175
- dims=["time_utc", "channel", "y", "x"],
176
- coords={
177
- "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
178
- "channel": ["dswrf"], # Single channel
179
- },
180
- )
181
- },
182
- "site": xr.DataArray(
183
- fake_site_values,
184
- dims=["time_utc"],
185
- coords={
186
- "time_utc": pd.date_range("2024-01-01 00:00", periods=197, freq="15min"),
187
- "capacity_kwp": 1000,
188
- "site_id": 1,
189
- "longitude": -3.5,
190
- "latitude": 51.5
191
- }
192
- )
193
- }
194
- print(f"Input site_dict: {site_dict}")
195
-
196
- # Call function
197
- result = sites_dataset.process_and_combine_site_sample_dict(site_dict)
198
-
199
- # Assert to validate output structure
200
- assert isinstance(result, xr.Dataset), "Result should be an xarray.Dataset"
201
- assert len(result.data_vars) > 0, "Dataset should contain data variables"
202
-
203
- # Validate variable via assertion and shape of such
204
- expected_variables = ["nwp-ukv", "site"]
205
- for expected_variable in expected_variables:
206
- assert expected_variable in result.data_vars, f"Expected variable '{expected_variable}' not found"
207
-
208
- nwp_result = result["nwp-ukv"]
209
- assert nwp_result.shape == (4, 1, 2, 2), f"Unexpected shape for nwp-ukv : {nwp_result.shape}"
210
- site_result = result["site"]
211
- assert site_result.shape == (197,), f"Unexpected shape for site: {site_result.shape}"
212
-
213
-
214
- def test_potentially_coarsen(ds_nwp_ecmwf):
215
- """Test potentially_coarsen function with ECMWF_UK data."""
216
- nwp_data = ds_nwp_ecmwf
217
- assert nwp_data.ECMWF_UK.shape[3:] == (15, 12) # Check initial shape (lon, lat)
218
-
219
- data = coarsen_data(xr_data=nwp_data, coarsen_to_deg=2)
220
- assert data.ECMWF_UK.shape[3:] == (8, 6) # Coarsen to every 2 degrees
221
-
222
- data = coarsen_data(xr_data=nwp_data, coarsen_to_deg=3)
223
- assert data.ECMWF_UK.shape[3:] == (5, 4) # Coarsen to every 3 degrees
224
-
225
- data = coarsen_data(xr_data=nwp_data, coarsen_to_deg=1)
226
- assert data.ECMWF_UK.shape[3:] == (15, 12) # No coarsening (same shape)
@@ -1,78 +0,0 @@
1
- """Tests for channel validation utility functions"""
2
-
3
- import pytest
4
- from ocf_data_sampler.torch_datasets.utils.validate_channels import (
5
- validate_channels,
6
- validate_nwp_channels,
7
- validate_satellite_channels,
8
- )
9
-
10
-
11
- class TestChannelValidation:
12
- """Tests for channel validation functions"""
13
-
14
- @pytest.mark.parametrize("test_case", [
15
- # Base validation - success case
16
- {
17
- "data_channels": ["channel1", "channel2"],
18
- "norm_channels": ["channel1", "channel2", "extra"],
19
- "source_name": "test_source",
20
- "expect_error": False
21
- },
22
- # Base validation - error case
23
- {
24
- "data_channels": ["channel1", "missing_channel"],
25
- "norm_channels": ["channel1"],
26
- "source_name": "test_source",
27
- "expect_error": True,
28
- "error_match": "following channels for test_source are missing in normalisation means"
29
- },
30
- # NWP case - success
31
- {
32
- "data_channels": ["t2m", "dswrf"],
33
- "norm_channels": ["t2m", "dswrf", "extra"],
34
- "source_name": "ecmwf",
35
- "expect_error": False
36
- },
37
- # NWP case - error
38
- {
39
- "data_channels": ["t2m", "missing_channel"],
40
- "norm_channels": ["t2m"],
41
- "source_name": "ecmwf",
42
- "expect_error": True,
43
- "error_match": "following channels for ecmwf are missing in normalisation means"
44
- },
45
- # Satellite case - success
46
- {
47
- "data_channels": ["IR_016", "VIS006"],
48
- "norm_channels": ["IR_016", "VIS006", "extra"],
49
- "source_name": "satellite",
50
- "expect_error": False
51
- },
52
- # Satellite case - error
53
- {
54
- "data_channels": ["IR_016", "missing_channel"],
55
- "norm_channels": ["IR_016"],
56
- "source_name": "satellite",
57
- "expect_error": True,
58
- "error_match": "following channels for satellite are missing in normalisation means"
59
- }
60
- ])
61
- def test_channel_validation(self, test_case):
62
- """Test channel validation for both base, NWP and satellite data"""
63
- if test_case["expect_error"]:
64
- with pytest.raises(ValueError, match=test_case["error_match"]):
65
- validate_channels(
66
- data_channels=test_case["data_channels"],
67
- means_channels=test_case["norm_channels"],
68
- stds_channels=test_case["norm_channels"],
69
- source_name=test_case["source_name"]
70
- )
71
- else:
72
- # Should not raise any exceptions
73
- validate_channels(
74
- data_channels=test_case["data_channels"],
75
- means_channels=test_case["norm_channels"],
76
- stds_channels=test_case["norm_channels"],
77
- source_name=test_case["source_name"]
78
- )