ocf-data-sampler 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

@@ -86,7 +86,7 @@ ECMWF_STD = {
86
86
  "lcc": 0.3791404366493225,
87
87
  "mcc": 0.38039860129356384,
88
88
  "prate": 9.81039775069803e-05,
89
- "sde": 0.000913831521756947,
89
+ "sd": 0.000913831521756947,
90
90
  "sr": 16294988.0,
91
91
  "t2m": 3.692270040512085,
92
92
  "tcc": 0.37487083673477173,
@@ -110,7 +110,7 @@ ECMWF_MEAN = {
110
110
  "lcc": 0.44901806116104126,
111
111
  "mcc": 0.3288780450820923,
112
112
  "prate": 3.108070450252853e-05,
113
- "sde": 8.107526082312688e-05,
113
+ "sd": 8.107526082312688e-05,
114
114
  "sr": 12905302.0,
115
115
  "t2m": 283.48333740234375,
116
116
  "tcc": 0.7049227356910706,
@@ -73,3 +73,26 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
73
73
  elif np.issubdtype(v.dtype, np.number):
74
74
  batch[k] = torch.as_tensor(v)
75
75
  return batch
76
+
77
+
78
def copy_batch_to_device(batch: dict, device: torch.device) -> dict:
    """
    Build a copy of a (possibly nested) batch dict whose tensors live on ``device``.

    Args:
        batch: Dict whose values may be tensors, further dicts, or anything else.
        device: Target device for every tensor found in the structure.

    Returns:
        A new dict with the same keys (in the same order). Nested dicts are copied
        recursively, tensors are replaced by ``tensor.to(device)``, and all other
        values are carried over as the very same objects.
    """

    def _relocate(item):
        # Dicts are checked first so nested structures are always rebuilt
        if isinstance(item, dict):
            return copy_batch_to_device(item, device)
        if isinstance(item, torch.Tensor):
            return item.to(device)
        return item

    return {key: _relocate(item) for key, item in batch.items()}
@@ -31,10 +31,15 @@ from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
31
31
  merge_dicts,
32
32
  fill_nans_in_arrays,
33
33
  )
34
+ from ocf_data_sampler.torch_datasets.utils.validate_channels import (
35
+ validate_nwp_channels,
36
+ validate_satellite_channels,
37
+ )
34
38
 
35
39
 
36
40
  xr.set_options(keep_attrs=True)
37
41
 
42
+
38
43
  def process_and_combine_datasets(
39
44
  dataset_dict: dict,
40
45
  config: Configuration,
@@ -47,27 +52,23 @@ def process_and_combine_datasets(
47
52
  numpy_modalities = []
48
53
 
49
54
  if "nwp" in dataset_dict:
50
-
51
55
  nwp_numpy_modalities = dict()
52
56
 
53
57
  for nwp_key, da_nwp in dataset_dict["nwp"].items():
54
- # Standardise
55
58
  provider = config.input_data.nwp[nwp_key].provider
56
- da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
57
59
 
58
- # Convert to NumpyBatch
60
+ # Standardise and convert to NumpyBatch
61
+ da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
59
62
  nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
60
63
 
61
64
  # Combine the NWPs into NumpyBatch
62
65
  numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
63
66
 
64
-
65
67
  if "sat" in dataset_dict:
66
- # Standardise
67
68
  da_sat = dataset_dict["sat"]
68
- da_sat = (da_sat - RSS_MEAN) / RSS_STD
69
69
 
70
- # Convert to NumpyBatch
70
+ # Standardise and convert to NumpyBatch
71
+ da_sat = (da_sat - RSS_MEAN) / RSS_STD
71
72
  numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
72
73
 
73
74
  gsp_config = config.input_data.gsp
@@ -186,9 +187,13 @@ class PVNetUKRegionalDataset(Dataset):
186
187
  """
187
188
 
188
189
  config = load_yaml_configuration(config_filename)
189
-
190
+
191
+ # Validate channels for NWP and satellite data
192
+ validate_nwp_channels(config)
193
+ validate_satellite_channels(config)
194
+
190
195
  datasets_dict = get_dataset_dict(config.input_data)
191
-
196
+
192
197
  # Get t0 times where all input data is available
193
198
  valid_t0_times = find_valid_t0_times(datasets_dict, config)
194
199
 
@@ -294,7 +299,11 @@ class PVNetUKConcurrentDataset(Dataset):
294
299
  """
295
300
 
296
301
  config = load_yaml_configuration(config_filename)
297
-
302
+
303
+ # Validate channels for NWP and satellite data
304
+ validate_nwp_channels(config)
305
+ validate_satellite_channels(config)
306
+
298
307
  datasets_dict = get_dataset_dict(config.input_data)
299
308
 
300
309
  # Get t0 times where all input data is available
@@ -361,4 +370,4 @@ class PVNetUKConcurrentDataset(Dataset):
361
370
  """
362
371
  # Check data is availablle for init-time t0
363
372
  assert t0 in self.valid_t0_times
364
- return self._get_sample(t0)
373
+ return self._get_sample(t0)
@@ -1,4 +1,5 @@
1
1
  """Torch dataset for sites"""
2
+
2
3
  import logging
3
4
  import numpy as np
4
5
  import pandas as pd
@@ -19,6 +20,8 @@ from ocf_data_sampler.select import (
19
20
  from ocf_data_sampler.utils import minutes
20
21
  from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
21
22
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts, fill_nans_in_arrays
23
+ from ocf_data_sampler.torch_datasets.utils.validate_channels import validate_nwp_channels
24
+
22
25
  from ocf_data_sampler.numpy_sample import (
23
26
  convert_site_to_numpy_sample,
24
27
  convert_satellite_to_numpy_sample,
@@ -29,8 +32,10 @@ from ocf_data_sampler.numpy_sample import (
29
32
  from ocf_data_sampler.numpy_sample import NWPSampleKey
30
33
  from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
31
34
 
35
+
32
36
  xr.set_options(keep_attrs=True)
33
37
 
38
+
34
39
  class SitesDataset(Dataset):
35
40
  def __init__(
36
41
  self,
@@ -47,6 +52,10 @@ class SitesDataset(Dataset):
47
52
  """
48
53
 
49
54
  config: Configuration = load_yaml_configuration(config_filename)
55
+
56
+ # Validate NWP channels
57
+ validate_nwp_channels(config)
58
+
50
59
  datasets_dict = get_dataset_dict(config.input_data)
51
60
 
52
61
  # Assign config and input data to self
@@ -221,8 +230,9 @@ class SitesDataset(Dataset):
221
230
 
222
231
  if "nwp" in dataset_dict:
223
232
  for nwp_key, da_nwp in dataset_dict["nwp"].items():
224
- # Standardise
225
233
  provider = self.config.input_data.nwp[nwp_key].provider
234
+
235
+ # Standardise
226
236
  da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
227
237
  data_arrays.append((f"nwp-{provider}", da_nwp))
228
238
 
@@ -0,0 +1,82 @@
1
+ import xarray as xr
2
+
3
+ from ocf_data_sampler.config import Configuration
4
+ from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
5
+
6
+
7
+ def validate_channels(
8
+ data_channels: list,
9
+ means_channels: list,
10
+ stds_channels: list,
11
+ source_name: str | None = None
12
+ ) -> None:
13
+ """
14
+ Validates that all channels in data have corresponding normalisation constants.
15
+
16
+ Args:
17
+ data_channels: Set of channels from the data
18
+ means_channels: Set of channels from means constants
19
+ stds_channels: Set of channels from stds constants
20
+ source_name: Name of data source (e.g., 'ecmwf', 'satellite') for error messages
21
+
22
+ Raises:
23
+ ValueError: If there's a mismatch between data channels and normalisation constants
24
+ """
25
+
26
+ data_set = set(data_channels)
27
+ means_set = set(means_channels)
28
+ stds_set = set(stds_channels)
29
+
30
+ # Find missing channels in means
31
+ missing_in_means = data_set - means_set
32
+ if missing_in_means:
33
+ raise ValueError(
34
+ f"The following channels for {source_name} are missing in normalisation means: "
35
+ f"{missing_in_means}"
36
+ )
37
+
38
+ # Find missing channels in stds
39
+ missing_in_stds = data_set - stds_set
40
+ if missing_in_stds:
41
+ raise ValueError(
42
+ f"The following channels for {source_name} are missing in normalisation stds: "
43
+ f"{missing_in_stds}"
44
+ )
45
+
46
+
47
def validate_nwp_channels(config: Configuration) -> None:
    """Validate that NWP channels in config have corresponding normalisation constants.

    Does nothing when the configuration has no NWP section (attribute missing or None).

    Args:
        config: Configuration object containing NWP channel information

    Raises:
        ValueError: If there's a mismatch between configured NWP channels and normalisation constants
    """
    # NOTE(review): assumes the config model declares ``nwp`` as an optional field that
    # may be None. In that case ``hasattr`` is still True and calling ``.items()`` on
    # None would raise AttributeError, so check the value rather than the attribute.
    nwp_configs = getattr(config.input_data, "nwp", None)
    if nwp_configs is None:
        return

    for _nwp_key, nwp_config in nwp_configs.items():
        provider = nwp_config.provider
        validate_channels(
            data_channels=nwp_config.channels,
            means_channels=NWP_MEANS[provider].channel.values,
            stds_channels=NWP_STDS[provider].channel.values,
            source_name=provider
        )
65
+
66
+
67
def validate_satellite_channels(config: Configuration) -> None:
    """Validate that satellite channels in config have corresponding normalisation constants.

    Does nothing when the configuration has no satellite section (attribute missing or None).

    Args:
        config: Configuration object containing satellite channel information

    Raises:
        ValueError: If there's a mismatch between configured satellite channels and normalisation constants
    """
    # NOTE(review): assumes the config model declares ``satellite`` as an optional field
    # that may be None. ``hasattr`` would be True then and ``.channels`` would raise
    # AttributeError, so check the value rather than the attribute.
    satellite_config = getattr(config.input_data, "satellite", None)
    if satellite_config is None:
        return

    validate_channels(
        data_channels=satellite_config.channels,
        means_channels=RSS_MEAN.channel.values,
        stds_channels=RSS_STD.channel.values,
        source_name="satellite"
    )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: ocf_data_sampler
3
- Version: 0.1.7
3
+ Version: 0.1.9
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -1,5 +1,5 @@
1
1
  ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
2
- ocf_data_sampler/constants.py,sha256=ClteRIgp7EPlUPqIbkel83BfIaD7_VIDjUeHzUfyhnM,5079
2
+ ocf_data_sampler/constants.py,sha256=0HYNmqwBaHVTAEEx9qzk6WD9YInh0gSKLeI3pyq7aNs,5077
3
3
  ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
4
4
  ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
5
5
  ocf_data_sampler/config/load.py,sha256=sKCKmhkkeFvvkNL5xmnFvdAulaCtV4-rigPsFvVDPDc,634
@@ -27,7 +27,7 @@ ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBun
27
27
  ocf_data_sampler/numpy_sample/site.py,sha256=I-cAXCOF0SDdm5Hx43lFqYZ3jh61kltLQK-fc4_nNu0,1314
28
28
  ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrdc1cWgegrYaqUoHOY,1611
29
29
  ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
30
- ocf_data_sampler/sample/base.py,sha256=qeKuWyyO8M4QX6QDbItioeCiss0fG05NXRtf0TCMQSc,2246
30
+ ocf_data_sampler/sample/base.py,sha256=q3wpqoW4JXRmzfar6ed7UMn1nxBxSJXNvMLJmHXy1dw,2856
31
31
  ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
32
32
  ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcprcDm50jCJw,4381
33
33
  ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
@@ -41,11 +41,12 @@ ocf_data_sampler/select/select_time_slice.py,sha256=9M-yvDv9K77XfEys_OIR31_aVB56
41
41
  ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
42
42
  ocf_data_sampler/select/time_slice_for_dataset.py,sha256=Z7pOiilSHScxmBKZNG18K5J-S4ifdXXAYGZoHRHD3AY,4324
43
43
  ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
44
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=xuNJyCXZ4dZ9UldX1lqOoRSRNP39Vcy0DR77Vr7dxlk,11895
45
- ocf_data_sampler/torch_datasets/datasets/site.py,sha256=ZjvJS0mWUyQE7ZcrhS1TdMHaPrEZXVbBAv2vDwBvQwA,16044
44
+ ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=N85duDyEm6LIYgYIpLhrpxHddMIcvFosuZg8rzIztwE,12267
45
+ ocf_data_sampler/torch_datasets/datasets/site.py,sha256=L_4w967ZxPjd7vHRkPtj7ZSmamEShKRT28j9_f-enJY,16228
46
46
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
47
47
  ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
48
- scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
48
+ ocf_data_sampler/torch_datasets/utils/validate_channels.py,sha256=u2EpiFAKAOHpmvINhOUJCT8Vbc-cle6qJ3YNVse4yLs,2884
49
+ scripts/refactor_site.py,sha256=xaJGxt2_WObIPrPAnRiOMMB68r-5Q51jWRx409AcscM,1747
49
50
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
51
  tests/conftest.py,sha256=RlC7YYtBLipUzFS1tQxela1SgHCxSpReUKEJ4429PwQ,7689
51
52
  tests/config/test_config.py,sha256=VQjNiucIk5VnPQdGA6Mr-RNd9CwGI06AiikChTHrcnY,3969
@@ -67,14 +68,15 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
67
68
  tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
68
69
  tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
69
70
  tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
70
- tests/test_sample/test_base.py,sha256=CkqKCZbrq3Vb4T7bOwPh3_0p8OTl0LfSLNBctYC_jag,4199
71
+ tests/test_sample/test_base.py,sha256=sD9NZghYQWbkAcQP9YXypWZowqYkO3xeNMH-_mEoD5I,4833
71
72
  tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
72
73
  tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
73
74
  tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
74
- tests/torch_datasets/test_pvnet_uk.py,sha256=loueo7PUUYJVda3-vBn3bQIC_zgrTAThfx-GTDcBOZg,5596
75
+ tests/torch_datasets/test_pvnet_uk.py,sha256=F0D-DugFgVtt8G1q7lylmPLrOZj6H6YPNd9s_6Wn_yM,5594
75
76
  tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
76
- ocf_data_sampler-0.1.7.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
77
- ocf_data_sampler-0.1.7.dist-info/METADATA,sha256=8SbL1qjkmeFDYdv1_hHBL9jxbSpt4aFCpx70rEEPeb0,12173
78
- ocf_data_sampler-0.1.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
79
- ocf_data_sampler-0.1.7.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
80
- ocf_data_sampler-0.1.7.dist-info/RECORD,,
77
+ tests/torch_datasets/test_validate_channels_utils.py,sha256=Rzdweu98j1of45jCOUrSiBtyPlf-dDaCceulf0H7ml8,2921
78
+ ocf_data_sampler-0.1.9.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
79
+ ocf_data_sampler-0.1.9.dist-info/METADATA,sha256=Lfu8Yrj4CSlqPzGhk0iDy5r5zCLd5REnGAlVcFuKuow,12173
80
+ ocf_data_sampler-0.1.9.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
81
+ ocf_data_sampler-0.1.9.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
82
+ ocf_data_sampler-0.1.9.dist-info/RECORD,,
scripts/refactor_site.py CHANGED
@@ -1,5 +1,5 @@
1
1
  """ Helper functions for refactoring legacy site data """
2
-
2
+ import xarray as xr
3
3
 
4
4
  def legacy_format(data_ds, metadata_df):
5
5
  """This formats old legacy data to the new format.
@@ -9,7 +9,8 @@ import numpy as np
9
9
  from pathlib import Path
10
10
  from ocf_data_sampler.sample.base import (
11
11
  SampleBase,
12
- batch_to_tensor
12
+ batch_to_tensor,
13
+ copy_batch_to_device
13
14
  )
14
15
 
15
16
  class TestSample(SampleBase):
@@ -145,3 +146,19 @@ def test_batch_to_tensor_multidimensional():
145
146
  assert tensor_batch['matrix'].shape == (2, 2)
146
147
  assert tensor_batch['tensor'].shape == (2, 2, 2)
147
148
  assert torch.equal(tensor_batch['matrix'], torch.tensor([[1, 2], [3, 4]]))
149
+
150
+
151
def test_copy_batch_to_device():
    """ Test moving tensors to a different device """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = {
        'tensor_data': torch.tensor([1, 2, 3]),
        'nested': {
            'matrix': torch.tensor([[1, 2], [3, 4]])
        },
        'non_tensor': 'unchanged'
    }
    moved_batch = copy_batch_to_device(batch, device)

    # Compare device *types*: torch.device("cuda") != torch.device("cuda:0"), and a
    # tensor moved to "cuda" reports the indexed device, so direct equality fails on GPU.
    assert moved_batch['tensor_data'].device.type == device.type
    assert moved_batch['nested']['matrix'].device.type == device.type
    assert moved_batch['non_tensor'] == 'unchanged'  # Non-tensors should remain unchanged

    # Values must survive the move (bring back to CPU so both operands share a device)
    assert torch.equal(moved_batch['tensor_data'].cpu(), batch['tensor_data'])
    assert torch.equal(moved_batch['nested']['matrix'].cpu(), batch['nested']['matrix'])

    # A copy is returned: containers are new objects and the input stays on CPU
    assert moved_batch is not batch
    assert moved_batch['nested'] is not batch['nested']
    assert batch['tensor_data'].device.type == "cpu"
@@ -24,7 +24,7 @@ def test_process_and_combine_datasets(pvnet_config_filename):
24
24
  dims=["time_utc", "channel", "y", "x"],
25
25
  coords={
26
26
  "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
27
- "channel": ["t2m", "dswrf"],
27
+ "channel": ["t", "dswrf"],
28
28
  "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
29
29
  "init_time_utc": pd.Timestamp("2024-01-01 00:00")
30
30
  }
@@ -54,7 +54,7 @@ def test_process_and_combine_datasets(pvnet_config_filename):
54
54
  assert isinstance(sample, dict)
55
55
  assert "nwp" in sample
56
56
  assert sample["satellite_actual"].shape == (7, 1, 2, 2)
57
- assert sample["nwp"]["ukv"]["nwp"].shape == (4, 1, 2, 2)
57
+ assert sample["nwp"]["ukv"]["nwp"].shape == (4, 2, 2, 2)
58
58
  assert "gsp_id" in sample
59
59
 
60
60
 
@@ -0,0 +1,78 @@
1
+ """Tests for channel validation utility functions"""
2
+
3
+ import pytest
4
+ from ocf_data_sampler.torch_datasets.utils.validate_channels import (
5
+ validate_channels,
6
+ validate_nwp_channels,
7
+ validate_satellite_channels,
8
+ )
9
+
10
+
11
class TestChannelValidation:
    """Tests for channel validation functions"""

    @pytest.mark.parametrize("test_case", [
        # Base validation - success case
        {
            "data_channels": ["channel1", "channel2"],
            "norm_channels": ["channel1", "channel2", "extra"],
            "source_name": "test_source",
            "expect_error": False
        },
        # Base validation - error case (missing from means)
        {
            "data_channels": ["channel1", "missing_channel"],
            "norm_channels": ["channel1"],
            "source_name": "test_source",
            "expect_error": True,
            "error_match": "following channels for test_source are missing in normalisation means"
        },
        # Base validation - error case (present in means but missing from stds)
        {
            "data_channels": ["channel1", "channel2"],
            "norm_channels": ["channel1", "channel2"],
            "stds_channels": ["channel1"],
            "source_name": "test_source",
            "expect_error": True,
            "error_match": "following channels for test_source are missing in normalisation stds"
        },
        # NWP-style channel names - success
        {
            "data_channels": ["t2m", "dswrf"],
            "norm_channels": ["t2m", "dswrf", "extra"],
            "source_name": "ecmwf",
            "expect_error": False
        },
        # NWP-style channel names - error
        {
            "data_channels": ["t2m", "missing_channel"],
            "norm_channels": ["t2m"],
            "source_name": "ecmwf",
            "expect_error": True,
            "error_match": "following channels for ecmwf are missing in normalisation means"
        },
        # Satellite-style channel names - success
        {
            "data_channels": ["IR_016", "VIS006"],
            "norm_channels": ["IR_016", "VIS006", "extra"],
            "source_name": "satellite",
            "expect_error": False
        },
        # Satellite-style channel names - error
        {
            "data_channels": ["IR_016", "missing_channel"],
            "norm_channels": ["IR_016"],
            "source_name": "satellite",
            "expect_error": True,
            "error_match": "following channels for satellite are missing in normalisation means"
        }
    ])
    def test_channel_validation(self, test_case):
        """Test ``validate_channels`` with generic, NWP-style and satellite-style channel names.

        ``norm_channels`` is used for both the means and the stds unless a case supplies
        a separate ``stds_channels`` list (used to exercise the stds check).
        """
        kwargs = dict(
            data_channels=test_case["data_channels"],
            means_channels=test_case["norm_channels"],
            stds_channels=test_case.get("stds_channels", test_case["norm_channels"]),
            source_name=test_case["source_name"],
        )
        if test_case["expect_error"]:
            with pytest.raises(ValueError, match=test_case["error_match"]):
                validate_channels(**kwargs)
        else:
            # Should not raise any exceptions
            validate_channels(**kwargs)