ocf-data-sampler 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

@@ -86,7 +86,7 @@ ECMWF_STD = {
86
86
  "lcc": 0.3791404366493225,
87
87
  "mcc": 0.38039860129356384,
88
88
  "prate": 9.81039775069803e-05,
89
- "sde": 0.000913831521756947,
89
+ "sd": 0.000913831521756947,
90
90
  "sr": 16294988.0,
91
91
  "t2m": 3.692270040512085,
92
92
  "tcc": 0.37487083673477173,
@@ -110,7 +110,7 @@ ECMWF_MEAN = {
110
110
  "lcc": 0.44901806116104126,
111
111
  "mcc": 0.3288780450820923,
112
112
  "prate": 3.108070450252853e-05,
113
- "sde": 8.107526082312688e-05,
113
+ "sd": 8.107526082312688e-05,
114
114
  "sr": 12905302.0,
115
115
  "t2m": 283.48333740234375,
116
116
  "tcc": 0.7049227356910706,
@@ -5,25 +5,34 @@ Handling of both flat and nested structures - consideration for NWP
5
5
 
6
6
  import logging
7
7
  import numpy as np
8
+ import torch
9
+ import xarray as xr
8
10
 
9
11
  from pathlib import Path
10
- from typing import Any, Dict, Optional, Union
12
+ from typing import Any, Dict, Optional, Union, TypeAlias
11
13
  from abc import ABC, abstractmethod
12
14
 
15
+
13
16
  logger = logging.getLogger(__name__)
14
17
 
18
+ NumpySample: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
19
+ NumpyBatch: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
20
+ TensorBatch: TypeAlias = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
21
+
22
+
15
23
  class SampleBase(ABC):
16
24
  """
17
25
  Abstract base class for all sample types
18
26
  Provides core data storage functionality
19
27
  """
20
28
 
21
- def __init__(self):
29
+ def __init__(self, data: Optional[Union[NumpySample, xr.Dataset]] = None):
22
30
  """ Initialise data container """
23
31
  logger.debug("Initialising SampleBase instance")
32
+ self._data = data
24
33
 
25
34
  @abstractmethod
26
- def to_numpy(self) -> Dict[str, Any]:
35
+ def to_numpy(self) -> NumpySample:
27
36
  """ Convert data to a numpy array representation """
28
37
  raise NotImplementedError
29
38
 
@@ -42,3 +51,25 @@ class SampleBase(ABC):
42
51
  def load(cls, path: Union[str, Path]) -> 'SampleBase':
43
52
  """ Abstract class method for loading sample data """
44
53
  raise NotImplementedError
54
+
55
+
56
def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
    """
    Convert the numpy arrays in a (possibly nested) batch dict to torch tensors

    The conversion happens in place: values of ``batch`` are overwritten and
    the same dict object is returned. Boolean arrays become ``torch.bool``
    tensors, other numeric arrays are converted with ``torch.as_tensor``, and
    every other value (e.g. strings or non-numeric arrays) is left untouched.

    Args:
        batch: NumpyBatch with data in numpy arrays

    Returns:
        TensorBatch with data in torch tensors (the dict that was passed in)

    Raises:
        ValueError: If the top-level batch is empty
    """
    if not batch:
        raise ValueError("Cannot convert empty batch to tensors")

    def _convert(mapping: dict) -> dict:
        # The emptiness check above is only meant for the top-level batch.
        # Recursing through this helper keeps an empty nested dict as-is
        # instead of rejecting an otherwise valid batch.
        for key, value in mapping.items():
            if isinstance(value, dict):
                mapping[key] = _convert(value)
            elif isinstance(value, np.ndarray):
                if value.dtype == np.bool_:
                    mapping[key] = torch.tensor(value, dtype=torch.bool)
                elif np.issubdtype(value.dtype, np.number):
                    mapping[key] = torch.as_tensor(value)
        return mapping

    return _convert(batch)
@@ -31,10 +31,15 @@ from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
31
31
  merge_dicts,
32
32
  fill_nans_in_arrays,
33
33
  )
34
+ from ocf_data_sampler.torch_datasets.utils.validate_channels import (
35
+ validate_nwp_channels,
36
+ validate_satellite_channels,
37
+ )
34
38
 
35
39
 
36
40
  xr.set_options(keep_attrs=True)
37
41
 
42
+
38
43
  def process_and_combine_datasets(
39
44
  dataset_dict: dict,
40
45
  config: Configuration,
@@ -47,27 +52,23 @@ def process_and_combine_datasets(
47
52
  numpy_modalities = []
48
53
 
49
54
  if "nwp" in dataset_dict:
50
-
51
55
  nwp_numpy_modalities = dict()
52
56
 
53
57
  for nwp_key, da_nwp in dataset_dict["nwp"].items():
54
- # Standardise
55
58
  provider = config.input_data.nwp[nwp_key].provider
56
- da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
57
59
 
58
- # Convert to NumpyBatch
60
+ # Standardise and convert to NumpyBatch
61
+ da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
59
62
  nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
60
63
 
61
64
  # Combine the NWPs into NumpyBatch
62
65
  numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
63
66
 
64
-
65
67
  if "sat" in dataset_dict:
66
- # Standardise
67
68
  da_sat = dataset_dict["sat"]
68
- da_sat = (da_sat - RSS_MEAN) / RSS_STD
69
69
 
70
- # Convert to NumpyBatch
70
+ # Standardise and convert to NumpyBatch
71
+ da_sat = (da_sat - RSS_MEAN) / RSS_STD
71
72
  numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
72
73
 
73
74
  gsp_config = config.input_data.gsp
@@ -186,9 +187,13 @@ class PVNetUKRegionalDataset(Dataset):
186
187
  """
187
188
 
188
189
  config = load_yaml_configuration(config_filename)
189
-
190
+
191
+ # Validate channels for NWP and satellite data
192
+ validate_nwp_channels(config)
193
+ validate_satellite_channels(config)
194
+
190
195
  datasets_dict = get_dataset_dict(config.input_data)
191
-
196
+
192
197
  # Get t0 times where all input data is available
193
198
  valid_t0_times = find_valid_t0_times(datasets_dict, config)
194
199
 
@@ -294,7 +299,11 @@ class PVNetUKConcurrentDataset(Dataset):
294
299
  """
295
300
 
296
301
  config = load_yaml_configuration(config_filename)
297
-
302
+
303
+ # Validate channels for NWP and satellite data
304
+ validate_nwp_channels(config)
305
+ validate_satellite_channels(config)
306
+
298
307
  datasets_dict = get_dataset_dict(config.input_data)
299
308
 
300
309
  # Get t0 times where all input data is available
@@ -361,4 +370,4 @@ class PVNetUKConcurrentDataset(Dataset):
361
370
  """
362
371
  # Check data is available for init-time t0
363
372
  assert t0 in self.valid_t0_times
364
- return self._get_sample(t0)
373
+ return self._get_sample(t0)
@@ -1,4 +1,5 @@
1
1
  """Torch dataset for sites"""
2
+
2
3
  import logging
3
4
  import numpy as np
4
5
  import pandas as pd
@@ -19,6 +20,8 @@ from ocf_data_sampler.select import (
19
20
  from ocf_data_sampler.utils import minutes
20
21
  from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
21
22
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts, fill_nans_in_arrays
23
+ from ocf_data_sampler.torch_datasets.utils.validate_channels import validate_nwp_channels
24
+
22
25
  from ocf_data_sampler.numpy_sample import (
23
26
  convert_site_to_numpy_sample,
24
27
  convert_satellite_to_numpy_sample,
@@ -29,8 +32,10 @@ from ocf_data_sampler.numpy_sample import (
29
32
  from ocf_data_sampler.numpy_sample import NWPSampleKey
30
33
  from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
31
34
 
35
+
32
36
  xr.set_options(keep_attrs=True)
33
37
 
38
+
34
39
  class SitesDataset(Dataset):
35
40
  def __init__(
36
41
  self,
@@ -47,6 +52,10 @@ class SitesDataset(Dataset):
47
52
  """
48
53
 
49
54
  config: Configuration = load_yaml_configuration(config_filename)
55
+
56
+ # Validate NWP channels
57
+ validate_nwp_channels(config)
58
+
50
59
  datasets_dict = get_dataset_dict(config.input_data)
51
60
 
52
61
  # Assign config and input data to self
@@ -221,8 +230,9 @@ class SitesDataset(Dataset):
221
230
 
222
231
  if "nwp" in dataset_dict:
223
232
  for nwp_key, da_nwp in dataset_dict["nwp"].items():
224
- # Standardise
225
233
  provider = self.config.input_data.nwp[nwp_key].provider
234
+
235
+ # Standardise
226
236
  da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
227
237
  data_arrays.append((f"nwp-{provider}", da_nwp))
228
238
 
@@ -0,0 +1,82 @@
1
+ import xarray as xr
2
+
3
+ from ocf_data_sampler.config import Configuration
4
+ from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
5
+
6
+
7
+ def validate_channels(
8
+ data_channels: list,
9
+ means_channels: list,
10
+ stds_channels: list,
11
+ source_name: str | None = None
12
+ ) -> None:
13
+ """
14
+ Validates that all channels in data have corresponding normalisation constants.
15
+
16
+ Args:
17
+ data_channels: Set of channels from the data
18
+ means_channels: Set of channels from means constants
19
+ stds_channels: Set of channels from stds constants
20
+ source_name: Name of data source (e.g., 'ecmwf', 'satellite') for error messages
21
+
22
+ Raises:
23
+ ValueError: If there's a mismatch between data channels and normalisation constants
24
+ """
25
+
26
+ data_set = set(data_channels)
27
+ means_set = set(means_channels)
28
+ stds_set = set(stds_channels)
29
+
30
+ # Find missing channels in means
31
+ missing_in_means = data_set - means_set
32
+ if missing_in_means:
33
+ raise ValueError(
34
+ f"The following channels for {source_name} are missing in normalisation means: "
35
+ f"{missing_in_means}"
36
+ )
37
+
38
+ # Find missing channels in stds
39
+ missing_in_stds = data_set - stds_set
40
+ if missing_in_stds:
41
+ raise ValueError(
42
+ f"The following channels for {source_name} are missing in normalisation stds: "
43
+ f"{missing_in_stds}"
44
+ )
45
+
46
+
47
def validate_nwp_channels(config: Configuration) -> None:
    """Validate that NWP channels in config have corresponding normalisation constants.

    Does nothing when ``config.input_data.nwp`` is missing or ``None``.
    Otherwise each NWP source is checked against the means and stds of its provider.

    Args:
        config: Configuration object containing NWP channel information

    Raises:
        ValueError: If there's a mismatch between configured NWP channels and normalisation constants
    """
    # hasattr() alone is not enough here: an attribute that exists but holds None
    # would fail on .items() below.
    # NOTE(review): assumes an unconfigured NWP section is represented as None - confirm against Configuration
    nwp_sources = getattr(config.input_data, "nwp", None)
    if nwp_sources is None:
        return

    for _nwp_key, nwp_config in nwp_sources.items():
        provider = nwp_config.provider
        validate_channels(
            data_channels=nwp_config.channels,
            means_channels=NWP_MEANS[provider].channel.values,
            stds_channels=NWP_STDS[provider].channel.values,
            source_name=provider
        )
65
+
66
+
67
def validate_satellite_channels(config: Configuration) -> None:
    """Validate that satellite channels in config have corresponding normalisation constants.

    Does nothing when ``config.input_data.satellite`` is missing or ``None``.

    Args:
        config: Configuration object containing satellite channel information

    Raises:
        ValueError: If there's a mismatch between configured satellite channels and normalisation constants
    """
    # hasattr() alone is not enough here: an attribute that exists but holds None
    # would fail on the .channels access below.
    # NOTE(review): assumes an unconfigured satellite section is represented as None - confirm against Configuration
    satellite_config = getattr(config.input_data, "satellite", None)
    if satellite_config is None:
        return

    validate_channels(
        data_channels=satellite_config.channels,
        means_channels=RSS_MEAN.channel.values,
        stds_channels=RSS_STD.channel.values,
        source_name="satellite"
    )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: ocf_data_sampler
3
- Version: 0.1.6
3
+ Version: 0.1.8
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -1,5 +1,5 @@
1
1
  ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
2
- ocf_data_sampler/constants.py,sha256=ClteRIgp7EPlUPqIbkel83BfIaD7_VIDjUeHzUfyhnM,5079
2
+ ocf_data_sampler/constants.py,sha256=0HYNmqwBaHVTAEEx9qzk6WD9YInh0gSKLeI3pyq7aNs,5077
3
3
  ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
4
4
  ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
5
5
  ocf_data_sampler/config/load.py,sha256=sKCKmhkkeFvvkNL5xmnFvdAulaCtV4-rigPsFvVDPDc,634
@@ -27,7 +27,7 @@ ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBun
27
27
  ocf_data_sampler/numpy_sample/site.py,sha256=I-cAXCOF0SDdm5Hx43lFqYZ3jh61kltLQK-fc4_nNu0,1314
28
28
  ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrdc1cWgegrYaqUoHOY,1611
29
29
  ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
30
- ocf_data_sampler/sample/base.py,sha256=4U78tczCRsKMDwU4HkD20nyGyYjIBSZV5neF2mT--2M,1197
30
+ ocf_data_sampler/sample/base.py,sha256=qeKuWyyO8M4QX6QDbItioeCiss0fG05NXRtf0TCMQSc,2246
31
31
  ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
32
32
  ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcprcDm50jCJw,4381
33
33
  ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
@@ -41,10 +41,11 @@ ocf_data_sampler/select/select_time_slice.py,sha256=9M-yvDv9K77XfEys_OIR31_aVB56
41
41
  ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
42
42
  ocf_data_sampler/select/time_slice_for_dataset.py,sha256=Z7pOiilSHScxmBKZNG18K5J-S4ifdXXAYGZoHRHD3AY,4324
43
43
  ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
44
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=xuNJyCXZ4dZ9UldX1lqOoRSRNP39Vcy0DR77Vr7dxlk,11895
45
- ocf_data_sampler/torch_datasets/datasets/site.py,sha256=ZjvJS0mWUyQE7ZcrhS1TdMHaPrEZXVbBAv2vDwBvQwA,16044
44
+ ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=N85duDyEm6LIYgYIpLhrpxHddMIcvFosuZg8rzIztwE,12267
45
+ ocf_data_sampler/torch_datasets/datasets/site.py,sha256=L_4w967ZxPjd7vHRkPtj7ZSmamEShKRT28j9_f-enJY,16228
46
46
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
47
47
  ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
48
+ ocf_data_sampler/torch_datasets/utils/validate_channels.py,sha256=u2EpiFAKAOHpmvINhOUJCT8Vbc-cle6qJ3YNVse4yLs,2884
48
49
  scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
49
50
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
51
  tests/conftest.py,sha256=RlC7YYtBLipUzFS1tQxela1SgHCxSpReUKEJ4429PwQ,7689
@@ -67,14 +68,15 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
67
68
  tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
68
69
  tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
69
70
  tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
70
- tests/test_sample/test_base.py,sha256=ljtB38MmscTGN6OvUgclBceNnfx6m7AN8iHYDml9XW4,2189
71
+ tests/test_sample/test_base.py,sha256=CkqKCZbrq3Vb4T7bOwPh3_0p8OTl0LfSLNBctYC_jag,4199
71
72
  tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
72
73
  tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
73
74
  tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
74
- tests/torch_datasets/test_pvnet_uk.py,sha256=loueo7PUUYJVda3-vBn3bQIC_zgrTAThfx-GTDcBOZg,5596
75
+ tests/torch_datasets/test_pvnet_uk.py,sha256=F0D-DugFgVtt8G1q7lylmPLrOZj6H6YPNd9s_6Wn_yM,5594
75
76
  tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
76
- ocf_data_sampler-0.1.6.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
77
- ocf_data_sampler-0.1.6.dist-info/METADATA,sha256=qltSR8dsD54ufCfXXFFYYLY_l_1saBWGaxwzZDIaJoU,12173
78
- ocf_data_sampler-0.1.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
79
- ocf_data_sampler-0.1.6.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
80
- ocf_data_sampler-0.1.6.dist-info/RECORD,,
77
+ tests/torch_datasets/test_validate_channels_utils.py,sha256=Rzdweu98j1of45jCOUrSiBtyPlf-dDaCceulf0H7ml8,2921
78
+ ocf_data_sampler-0.1.8.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
79
+ ocf_data_sampler-0.1.8.dist-info/METADATA,sha256=hWohmy0-J73u-uy3MPEG0_tuprAXOh32hX8WyIDPqaU,12173
80
+ ocf_data_sampler-0.1.8.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
81
+ ocf_data_sampler-0.1.8.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
82
+ ocf_data_sampler-0.1.8.dist-info/RECORD,,
@@ -3,11 +3,14 @@ Base class testing - SampleBase
3
3
  """
4
4
 
5
5
  import pytest
6
+ import torch
6
7
  import numpy as np
7
8
 
8
9
  from pathlib import Path
9
- from ocf_data_sampler.sample.base import SampleBase
10
-
10
+ from ocf_data_sampler.sample.base import (
11
+ SampleBase,
12
+ batch_to_tensor
13
+ )
11
14
 
12
15
  class TestSample(SampleBase):
13
16
  """
@@ -84,3 +87,61 @@ def test_sample_base_to_numpy():
84
87
  assert isinstance(numpy_data, dict)
85
88
  assert all(isinstance(value, np.ndarray) for value in numpy_data.values())
86
89
  assert np.array_equal(numpy_data['list_data'], np.array([1, 2, 3]))
90
+
91
+
92
def test_batch_to_tensor_nested():
    """ Arrays inside a nested dict are converted to tensors """
    inner_values = np.array([1, 2, 3])
    tensor_batch = batch_to_tensor({'outer': {'inner': inner_values}})

    expected = torch.tensor([1, 2, 3])
    assert torch.equal(tensor_batch['outer']['inner'], expected)
102
+
103
+
104
def test_batch_to_tensor_mixed_types():
    """ Arrays become tensors while string values pass through unchanged """
    nested_part = {
        'numbers': np.array([4, 5, 6]),
        'text': 'still_not_a_tensor',
    }
    batch = {
        'tensor_data': np.array([1, 2, 3]),
        'string_data': 'not_a_tensor',
        'nested': nested_part,
    }

    tensor_batch = batch_to_tensor(batch)
    nested_result = tensor_batch['nested']

    # Top-level entries
    assert isinstance(tensor_batch['tensor_data'], torch.Tensor)
    assert isinstance(tensor_batch['string_data'], str)
    # Nested entries
    assert isinstance(nested_result['numbers'], torch.Tensor)
    assert isinstance(nested_result['text'], str)
120
+
121
+
122
def test_batch_to_tensor_different_dtypes():
    """ Float, integer and boolean arrays map to the matching torch dtypes """
    float_array = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    int_array = np.array([1, 2, 3], dtype=np.int64)
    bool_array = np.array([True, False, True], dtype=np.bool_)

    tensor_batch = batch_to_tensor({
        'float_data': float_array,
        'int_data': int_array,
        'bool_data': bool_array,
    })

    bool_tensor = tensor_batch['bool_data']
    assert isinstance(bool_tensor, torch.Tensor)
    assert tensor_batch['float_data'].dtype == torch.float32
    assert tensor_batch['int_data'].dtype == torch.int64
    assert bool_tensor.dtype == torch.bool
135
+
136
+
137
def test_batch_to_tensor_multidimensional():
    """ 2D and 3D arrays keep their shape and values after conversion """
    matrix = np.array([[1, 2], [3, 4]])
    cube = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

    tensor_batch = batch_to_tensor({'matrix': matrix, 'tensor': cube})

    assert tensor_batch['matrix'].shape == (2, 2)
    assert tensor_batch['tensor'].shape == (2, 2, 2)

    expected_matrix = torch.tensor([[1, 2], [3, 4]])
    assert torch.equal(tensor_batch['matrix'], expected_matrix)
@@ -24,7 +24,7 @@ def test_process_and_combine_datasets(pvnet_config_filename):
24
24
  dims=["time_utc", "channel", "y", "x"],
25
25
  coords={
26
26
  "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
27
- "channel": ["t2m", "dswrf"],
27
+ "channel": ["t", "dswrf"],
28
28
  "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
29
29
  "init_time_utc": pd.Timestamp("2024-01-01 00:00")
30
30
  }
@@ -54,7 +54,7 @@ def test_process_and_combine_datasets(pvnet_config_filename):
54
54
  assert isinstance(sample, dict)
55
55
  assert "nwp" in sample
56
56
  assert sample["satellite_actual"].shape == (7, 1, 2, 2)
57
- assert sample["nwp"]["ukv"]["nwp"].shape == (4, 1, 2, 2)
57
+ assert sample["nwp"]["ukv"]["nwp"].shape == (4, 2, 2, 2)
58
58
  assert "gsp_id" in sample
59
59
 
60
60
 
@@ -0,0 +1,78 @@
1
+ """Tests for channel validation utility functions"""
2
+
3
+ import pytest
4
+ from ocf_data_sampler.torch_datasets.utils.validate_channels import (
5
+ validate_channels,
6
+ validate_nwp_channels,
7
+ validate_satellite_channels,
8
+ )
9
+
10
+
11
class TestChannelValidation:
    """Tests for channel validation functions"""

    @pytest.mark.parametrize("test_case", [
        # Base validation - success case
        {
            "data_channels": ["channel1", "channel2"],
            "norm_channels": ["channel1", "channel2", "extra"],
            "source_name": "test_source",
            "expect_error": False
        },
        # Base validation - error case
        {
            "data_channels": ["channel1", "missing_channel"],
            "norm_channels": ["channel1"],
            "source_name": "test_source",
            "expect_error": True,
            "error_match": "following channels for test_source are missing in normalisation means"
        },
        # NWP case - success
        {
            "data_channels": ["t2m", "dswrf"],
            "norm_channels": ["t2m", "dswrf", "extra"],
            "source_name": "ecmwf",
            "expect_error": False
        },
        # NWP case - error
        {
            "data_channels": ["t2m", "missing_channel"],
            "norm_channels": ["t2m"],
            "source_name": "ecmwf",
            "expect_error": True,
            "error_match": "following channels for ecmwf are missing in normalisation means"
        },
        # Satellite case - success
        {
            "data_channels": ["IR_016", "VIS006"],
            "norm_channels": ["IR_016", "VIS006", "extra"],
            "source_name": "satellite",
            "expect_error": False
        },
        # Satellite case - error
        {
            "data_channels": ["IR_016", "missing_channel"],
            "norm_channels": ["IR_016"],
            "source_name": "satellite",
            "expect_error": True,
            "error_match": "following channels for satellite are missing in normalisation means"
        }
    ])
    def test_channel_validation(self, test_case):
        """Run validate_channels on one case; the same channel list serves as both means and stds"""
        call_kwargs = {
            "data_channels": test_case["data_channels"],
            "means_channels": test_case["norm_channels"],
            "stds_channels": test_case["norm_channels"],
            "source_name": test_case["source_name"],
        }

        if not test_case["expect_error"]:
            # Should not raise any exceptions
            validate_channels(**call_kwargs)
            return

        with pytest.raises(ValueError, match=test_case["error_match"]):
            validate_channels(**call_kwargs)