ocf-data-sampler 0.1.6__tar.gz → 0.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- {ocf_data_sampler-0.1.6/ocf_data_sampler.egg-info → ocf_data_sampler-0.1.8}/PKG-INFO +1 -1
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/constants.py +2 -2
- ocf_data_sampler-0.1.8/ocf_data_sampler/sample/base.py +75 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +21 -12
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/torch_datasets/datasets/site.py +11 -1
- ocf_data_sampler-0.1.8/ocf_data_sampler/torch_datasets/utils/validate_channels.py +82 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8/ocf_data_sampler.egg-info}/PKG-INFO +1 -1
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler.egg-info/SOURCES.txt +3 -1
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/pyproject.toml +1 -1
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/test_sample/test_base.py +63 -2
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/torch_datasets/test_pvnet_uk.py +2 -2
- ocf_data_sampler-0.1.8/tests/torch_datasets/test_validate_channels_utils.py +78 -0
- ocf_data_sampler-0.1.6/ocf_data_sampler/sample/base.py +0 -44
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/LICENSE +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/MANIFEST.in +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/README.md +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/__init__.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/config/__init__.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/config/load.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/config/model.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/config/save.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/data/uk_gsp_locations.csv +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/__init__.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/gsp.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/load_dataset.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/nwp/__init__.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/nwp/nwp.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/satellite.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/site.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/utils.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/numpy_sample/collate.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/numpy_sample/nwp.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/numpy_sample/site.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/sample/__init__.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/sample/site.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/sample/uk_regional.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/__init__.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/dropout.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/fill_time_periods.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/geospatial.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/location.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/select_time_slice.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/spatial_slice_for_dataset.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/time_slice_for_dataset.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/utils.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler.egg-info/requires.txt +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler.egg-info/top_level.txt +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/scripts/refactor_site.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/setup.cfg +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/__init__.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/config/test_config.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/config/test_load.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/config/test_save.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/conftest.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/load/test_load_gsp.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/load/test_load_nwp.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/load/test_load_satellite.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/load/test_load_sites.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/numpy_sample/test_collate.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/numpy_sample/test_datetime_features.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/numpy_sample/test_gsp.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/numpy_sample/test_nwp.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/numpy_sample/test_satellite.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/numpy_sample/test_sun_position.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/select/test_dropout.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/select/test_fill_time_periods.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/select/test_find_contiguous_time_periods.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/select/test_location.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/select/test_select_spatial_slice.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/select/test_select_time_slice.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/test_sample/test_site_sample.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/test_sample/test_uk_regional_sample.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/torch_datasets/test_merge_and_fill_utils.py +0 -0
- {ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/torch_datasets/test_site.py +0 -0
|
@@ -86,7 +86,7 @@ ECMWF_STD = {
|
|
|
86
86
|
"lcc": 0.3791404366493225,
|
|
87
87
|
"mcc": 0.38039860129356384,
|
|
88
88
|
"prate": 9.81039775069803e-05,
|
|
89
|
-
"
|
|
89
|
+
"sd": 0.000913831521756947,
|
|
90
90
|
"sr": 16294988.0,
|
|
91
91
|
"t2m": 3.692270040512085,
|
|
92
92
|
"tcc": 0.37487083673477173,
|
|
@@ -110,7 +110,7 @@ ECMWF_MEAN = {
|
|
|
110
110
|
"lcc": 0.44901806116104126,
|
|
111
111
|
"mcc": 0.3288780450820923,
|
|
112
112
|
"prate": 3.108070450252853e-05,
|
|
113
|
-
"
|
|
113
|
+
"sd": 8.107526082312688e-05,
|
|
114
114
|
"sr": 12905302.0,
|
|
115
115
|
"t2m": 283.48333740234375,
|
|
116
116
|
"tcc": 0.7049227356910706,
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base class definition - abstract
|
|
3
|
+
Handling of both flat and nested structures - consideration for NWP
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
import xarray as xr
|
|
10
|
+
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any, Dict, Optional, Union, TypeAlias
|
|
13
|
+
from abc import ABC, abstractmethod
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
NumpySample: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
|
|
19
|
+
NumpyBatch: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
|
|
20
|
+
TensorBatch: TypeAlias = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SampleBase(ABC):
    """
    Common abstract interface for every sample type.

    A concrete sample keeps its payload in ``self._data`` and must implement
    numpy conversion, plotting, and saving/loading.
    """

    def __init__(self, data: Optional[Union[NumpySample, xr.Dataset]] = None):
        """Attach ``data`` to the new sample.

        Args:
            data: Payload for the sample - a NumpySample dict or an xarray
                Dataset. Defaults to ``None``.
        """
        logger.debug("Initialising SampleBase instance")
        self._data = data

    @abstractmethod
    def to_numpy(self) -> NumpySample:
        """Return the sample's data as a NumpySample dict (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def plot(self, **kwargs) -> None:
        """Visualise the sample; accepted keyword options are subclass-defined."""
        raise NotImplementedError

    @abstractmethod
    def save(self, path: Union[str, Path]) -> None:
        """Persist the sample to ``path`` (implemented by subclasses)."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def load(cls, path: Union[str, Path]) -> 'SampleBase':
        """Build a sample instance from data previously written to ``path``."""
        raise NotImplementedError
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def batch_to_tensor(batch: "NumpyBatch") -> "TensorBatch":
    """
    Convert the numpy arrays in a (possibly nested) batch dict to torch tensors.

    The conversion is done in place: ``batch`` itself is updated and returned.
    Nested dicts are converted recursively and may be empty. Boolean arrays
    become ``torch.bool`` tensors, and numeric arrays go through
    ``torch.as_tensor``. Every other value (strings, non-numeric arrays such as
    ``datetime64``, ...) is left untouched.

    Args:
        batch: NumpyBatch with data in numpy arrays

    Returns:
        The same dict object, now holding torch tensors (a TensorBatch)

    Raises:
        ValueError: If the top-level ``batch`` is empty
    """
    if not batch:
        raise ValueError("Cannot convert empty batch to tensors")
    return _convert_arrays_in_place(batch)


def _convert_arrays_in_place(batch: dict) -> dict:
    """Recursively replace numpy arrays in ``batch`` with tensors; empty dicts are allowed."""
    for key, value in batch.items():
        if isinstance(value, dict):
            # The emptiness check only applies to the top-level batch: a nested
            # group (e.g. a set of NWP inputs) may legitimately contain nothing.
            batch[key] = _convert_arrays_in_place(value)
        elif isinstance(value, np.ndarray):
            if value.dtype == np.bool_:
                batch[key] = torch.tensor(value, dtype=torch.bool)
            elif np.issubdtype(value.dtype, np.number):
                batch[key] = torch.as_tensor(value)
    return batch
|
|
@@ -31,10 +31,15 @@ from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
|
|
|
31
31
|
merge_dicts,
|
|
32
32
|
fill_nans_in_arrays,
|
|
33
33
|
)
|
|
34
|
+
from ocf_data_sampler.torch_datasets.utils.validate_channels import (
|
|
35
|
+
validate_nwp_channels,
|
|
36
|
+
validate_satellite_channels,
|
|
37
|
+
)
|
|
34
38
|
|
|
35
39
|
|
|
36
40
|
xr.set_options(keep_attrs=True)
|
|
37
41
|
|
|
42
|
+
|
|
38
43
|
def process_and_combine_datasets(
|
|
39
44
|
dataset_dict: dict,
|
|
40
45
|
config: Configuration,
|
|
@@ -47,27 +52,23 @@ def process_and_combine_datasets(
|
|
|
47
52
|
numpy_modalities = []
|
|
48
53
|
|
|
49
54
|
if "nwp" in dataset_dict:
|
|
50
|
-
|
|
51
55
|
nwp_numpy_modalities = dict()
|
|
52
56
|
|
|
53
57
|
for nwp_key, da_nwp in dataset_dict["nwp"].items():
|
|
54
|
-
# Standardise
|
|
55
58
|
provider = config.input_data.nwp[nwp_key].provider
|
|
56
|
-
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
|
|
57
59
|
|
|
58
|
-
#
|
|
60
|
+
# Standardise and convert to NumpyBatch
|
|
61
|
+
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
|
|
59
62
|
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
|
|
60
63
|
|
|
61
64
|
# Combine the NWPs into NumpyBatch
|
|
62
65
|
numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
|
|
63
66
|
|
|
64
|
-
|
|
65
67
|
if "sat" in dataset_dict:
|
|
66
|
-
# Standardise
|
|
67
68
|
da_sat = dataset_dict["sat"]
|
|
68
|
-
da_sat = (da_sat - RSS_MEAN) / RSS_STD
|
|
69
69
|
|
|
70
|
-
#
|
|
70
|
+
# Standardise and convert to NumpyBatch
|
|
71
|
+
da_sat = (da_sat - RSS_MEAN) / RSS_STD
|
|
71
72
|
numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
|
|
72
73
|
|
|
73
74
|
gsp_config = config.input_data.gsp
|
|
@@ -186,9 +187,13 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
186
187
|
"""
|
|
187
188
|
|
|
188
189
|
config = load_yaml_configuration(config_filename)
|
|
189
|
-
|
|
190
|
+
|
|
191
|
+
# Validate channels for NWP and satellite data
|
|
192
|
+
validate_nwp_channels(config)
|
|
193
|
+
validate_satellite_channels(config)
|
|
194
|
+
|
|
190
195
|
datasets_dict = get_dataset_dict(config.input_data)
|
|
191
|
-
|
|
196
|
+
|
|
192
197
|
# Get t0 times where all input data is available
|
|
193
198
|
valid_t0_times = find_valid_t0_times(datasets_dict, config)
|
|
194
199
|
|
|
@@ -294,7 +299,11 @@ class PVNetUKConcurrentDataset(Dataset):
|
|
|
294
299
|
"""
|
|
295
300
|
|
|
296
301
|
config = load_yaml_configuration(config_filename)
|
|
297
|
-
|
|
302
|
+
|
|
303
|
+
# Validate channels for NWP and satellite data
|
|
304
|
+
validate_nwp_channels(config)
|
|
305
|
+
validate_satellite_channels(config)
|
|
306
|
+
|
|
298
307
|
datasets_dict = get_dataset_dict(config.input_data)
|
|
299
308
|
|
|
300
309
|
# Get t0 times where all input data is available
|
|
@@ -361,4 +370,4 @@ class PVNetUKConcurrentDataset(Dataset):
|
|
|
361
370
|
"""
|
|
362
371
|
# Check data is availablle for init-time t0
|
|
363
372
|
assert t0 in self.valid_t0_times
|
|
364
|
-
return self._get_sample(t0)
|
|
373
|
+
return self._get_sample(t0)
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/torch_datasets/datasets/site.py
RENAMED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Torch dataset for sites"""
|
|
2
|
+
|
|
2
3
|
import logging
|
|
3
4
|
import numpy as np
|
|
4
5
|
import pandas as pd
|
|
@@ -19,6 +20,8 @@ from ocf_data_sampler.select import (
|
|
|
19
20
|
from ocf_data_sampler.utils import minutes
|
|
20
21
|
from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
|
|
21
22
|
from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts, fill_nans_in_arrays
|
|
23
|
+
from ocf_data_sampler.torch_datasets.utils.validate_channels import validate_nwp_channels
|
|
24
|
+
|
|
22
25
|
from ocf_data_sampler.numpy_sample import (
|
|
23
26
|
convert_site_to_numpy_sample,
|
|
24
27
|
convert_satellite_to_numpy_sample,
|
|
@@ -29,8 +32,10 @@ from ocf_data_sampler.numpy_sample import (
|
|
|
29
32
|
from ocf_data_sampler.numpy_sample import NWPSampleKey
|
|
30
33
|
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
|
|
31
34
|
|
|
35
|
+
|
|
32
36
|
xr.set_options(keep_attrs=True)
|
|
33
37
|
|
|
38
|
+
|
|
34
39
|
class SitesDataset(Dataset):
|
|
35
40
|
def __init__(
|
|
36
41
|
self,
|
|
@@ -47,6 +52,10 @@ class SitesDataset(Dataset):
|
|
|
47
52
|
"""
|
|
48
53
|
|
|
49
54
|
config: Configuration = load_yaml_configuration(config_filename)
|
|
55
|
+
|
|
56
|
+
# Validate NWP channels
|
|
57
|
+
validate_nwp_channels(config)
|
|
58
|
+
|
|
50
59
|
datasets_dict = get_dataset_dict(config.input_data)
|
|
51
60
|
|
|
52
61
|
# Assign config and input data to self
|
|
@@ -221,8 +230,9 @@ class SitesDataset(Dataset):
|
|
|
221
230
|
|
|
222
231
|
if "nwp" in dataset_dict:
|
|
223
232
|
for nwp_key, da_nwp in dataset_dict["nwp"].items():
|
|
224
|
-
# Standardise
|
|
225
233
|
provider = self.config.input_data.nwp[nwp_key].provider
|
|
234
|
+
|
|
235
|
+
# Standardise
|
|
226
236
|
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
|
|
227
237
|
data_arrays.append((f"nwp-{provider}", da_nwp))
|
|
228
238
|
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
|
|
3
|
+
from ocf_data_sampler.config import Configuration
|
|
4
|
+
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def validate_channels(
|
|
8
|
+
data_channels: list,
|
|
9
|
+
means_channels: list,
|
|
10
|
+
stds_channels: list,
|
|
11
|
+
source_name: str | None = None
|
|
12
|
+
) -> None:
|
|
13
|
+
"""
|
|
14
|
+
Validates that all channels in data have corresponding normalisation constants.
|
|
15
|
+
|
|
16
|
+
Args:
|
|
17
|
+
data_channels: Set of channels from the data
|
|
18
|
+
means_channels: Set of channels from means constants
|
|
19
|
+
stds_channels: Set of channels from stds constants
|
|
20
|
+
source_name: Name of data source (e.g., 'ecmwf', 'satellite') for error messages
|
|
21
|
+
|
|
22
|
+
Raises:
|
|
23
|
+
ValueError: If there's a mismatch between data channels and normalisation constants
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
data_set = set(data_channels)
|
|
27
|
+
means_set = set(means_channels)
|
|
28
|
+
stds_set = set(stds_channels)
|
|
29
|
+
|
|
30
|
+
# Find missing channels in means
|
|
31
|
+
missing_in_means = data_set - means_set
|
|
32
|
+
if missing_in_means:
|
|
33
|
+
raise ValueError(
|
|
34
|
+
f"The following channels for {source_name} are missing in normalisation means: "
|
|
35
|
+
f"{missing_in_means}"
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
# Find missing channels in stds
|
|
39
|
+
missing_in_stds = data_set - stds_set
|
|
40
|
+
if missing_in_stds:
|
|
41
|
+
raise ValueError(
|
|
42
|
+
f"The following channels for {source_name} are missing in normalisation stds: "
|
|
43
|
+
f"{missing_in_stds}"
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def validate_nwp_channels(config: Configuration) -> None:
    """Validate that NWP channels in config have corresponding normalisation constants.

    Does nothing when the config has no NWP section, whether the attribute is
    absent or set to ``None``.

    Args:
        config: Configuration object containing NWP channel information

    Raises:
        ValueError: If there's a mismatch between configured NWP channels and normalisation constants
    """
    # `hasattr` alone is not enough: an optional section can exist with value None,
    # which would crash on `.items()` below.
    nwp_configs = getattr(config.input_data, "nwp", None)
    if nwp_configs is None:
        return

    for _, nwp_config in nwp_configs.items():
        provider = nwp_config.provider
        validate_channels(
            data_channels=nwp_config.channels,
            means_channels=NWP_MEANS[provider].channel.values,
            stds_channels=NWP_STDS[provider].channel.values,
            source_name=provider,
        )
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def validate_satellite_channels(config: Configuration) -> None:
    """Validate that satellite channels in config have corresponding normalisation constants.

    Does nothing when the config has no satellite section, whether the
    attribute is absent or set to ``None``.

    Args:
        config: Configuration object containing satellite channel information

    Raises:
        ValueError: If there's a mismatch between configured satellite channels and normalisation constants
    """
    # `hasattr` alone is not enough: an optional section can exist with value None,
    # which would crash on `.channels` below (e.g. an NWP-only config).
    satellite_config = getattr(config.input_data, "satellite", None)
    if satellite_config is None:
        return

    validate_channels(
        data_channels=satellite_config.channels,
        means_channels=RSS_MEAN.channel.values,
        stds_channels=RSS_STD.channel.values,
        source_name="satellite",
    )
|
|
@@ -54,6 +54,7 @@ ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
|
|
|
54
54
|
ocf_data_sampler/torch_datasets/datasets/site.py
|
|
55
55
|
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py
|
|
56
56
|
ocf_data_sampler/torch_datasets/utils/valid_time_periods.py
|
|
57
|
+
ocf_data_sampler/torch_datasets/utils/validate_channels.py
|
|
57
58
|
scripts/refactor_site.py
|
|
58
59
|
tests/__init__.py
|
|
59
60
|
tests/conftest.py
|
|
@@ -81,4 +82,5 @@ tests/test_sample/test_site_sample.py
|
|
|
81
82
|
tests/test_sample/test_uk_regional_sample.py
|
|
82
83
|
tests/torch_datasets/test_merge_and_fill_utils.py
|
|
83
84
|
tests/torch_datasets/test_pvnet_uk.py
|
|
84
|
-
tests/torch_datasets/test_site.py
|
|
85
|
+
tests/torch_datasets/test_site.py
|
|
86
|
+
tests/torch_datasets/test_validate_channels_utils.py
|
|
@@ -3,11 +3,14 @@ Base class testing - SampleBase
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
import pytest
|
|
6
|
+
import torch
|
|
6
7
|
import numpy as np
|
|
7
8
|
|
|
8
9
|
from pathlib import Path
|
|
9
|
-
from ocf_data_sampler.sample.base import
|
|
10
|
-
|
|
10
|
+
from ocf_data_sampler.sample.base import (
|
|
11
|
+
SampleBase,
|
|
12
|
+
batch_to_tensor
|
|
13
|
+
)
|
|
11
14
|
|
|
12
15
|
class TestSample(SampleBase):
|
|
13
16
|
"""
|
|
@@ -84,3 +87,61 @@ def test_sample_base_to_numpy():
|
|
|
84
87
|
assert isinstance(numpy_data, dict)
|
|
85
88
|
assert all(isinstance(value, np.ndarray) for value in numpy_data.values())
|
|
86
89
|
assert np.array_equal(numpy_data['list_data'], np.array([1, 2, 3]))
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def test_batch_to_tensor_nested():
    """ Arrays inside nested dicts are converted recursively """
    inner_values = np.array([1, 2, 3])
    batch = {'outer': {'inner': inner_values}}

    converted = batch_to_tensor(batch)
    expected = torch.tensor([1, 2, 3])

    assert torch.equal(converted['outer']['inner'], expected)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def test_batch_to_tensor_mixed_types():
    """ Arrays become tensors while non-array values pass through unchanged """
    nested = {
        'numbers': np.array([4, 5, 6]),
        'text': 'still_not_a_tensor',
    }
    batch = {
        'tensor_data': np.array([1, 2, 3]),
        'string_data': 'not_a_tensor',
        'nested': nested,
    }

    result = batch_to_tensor(batch)

    # Converted entries
    assert isinstance(result['tensor_data'], torch.Tensor)
    assert isinstance(result['nested']['numbers'], torch.Tensor)
    # Untouched entries
    assert isinstance(result['string_data'], str)
    assert isinstance(result['nested']['text'], str)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def test_batch_to_tensor_different_dtypes():
    """ Each numpy dtype maps onto the matching torch dtype """
    batch = {
        'float_data': np.array([1.0, 2.0, 3.0], dtype=np.float32),
        'int_data': np.array([1, 2, 3], dtype=np.int64),
        'bool_data': np.array([True, False, True], dtype=np.bool_),
    }
    expected_dtypes = {
        'float_data': torch.float32,
        'int_data': torch.int64,
        'bool_data': torch.bool,
    }

    result = batch_to_tensor(batch)

    assert isinstance(result['bool_data'], torch.Tensor)
    for key, dtype in expected_dtypes.items():
        assert result[key].dtype == dtype
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def test_batch_to_tensor_multidimensional():
    """ Array shapes and values survive the conversion """
    matrix = np.array([[1, 2], [3, 4]])
    cube = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

    result = batch_to_tensor({'matrix': matrix, 'tensor': cube})

    assert result['matrix'].shape == (2, 2)
    assert result['tensor'].shape == (2, 2, 2)
    assert torch.equal(result['matrix'], torch.tensor([[1, 2], [3, 4]]))
|
|
@@ -24,7 +24,7 @@ def test_process_and_combine_datasets(pvnet_config_filename):
|
|
|
24
24
|
dims=["time_utc", "channel", "y", "x"],
|
|
25
25
|
coords={
|
|
26
26
|
"time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
|
|
27
|
-
"channel": ["
|
|
27
|
+
"channel": ["t", "dswrf"],
|
|
28
28
|
"step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
|
|
29
29
|
"init_time_utc": pd.Timestamp("2024-01-01 00:00")
|
|
30
30
|
}
|
|
@@ -54,7 +54,7 @@ def test_process_and_combine_datasets(pvnet_config_filename):
|
|
|
54
54
|
assert isinstance(sample, dict)
|
|
55
55
|
assert "nwp" in sample
|
|
56
56
|
assert sample["satellite_actual"].shape == (7, 1, 2, 2)
|
|
57
|
-
assert sample["nwp"]["ukv"]["nwp"].shape == (4,
|
|
57
|
+
assert sample["nwp"]["ukv"]["nwp"].shape == (4, 2, 2, 2)
|
|
58
58
|
assert "gsp_id" in sample
|
|
59
59
|
|
|
60
60
|
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""Tests for channel validation utility functions"""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from ocf_data_sampler.torch_datasets.utils.validate_channels import (
|
|
5
|
+
validate_channels,
|
|
6
|
+
validate_nwp_channels,
|
|
7
|
+
validate_satellite_channels,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TestChannelValidation:
    """Tests for channel validation functions"""

    @pytest.mark.parametrize("test_case", [
        # Base validation - success case
        {
            "data_channels": ["channel1", "channel2"],
            "norm_channels": ["channel1", "channel2", "extra"],
            "source_name": "test_source",
            "expect_error": False
        },
        # Base validation - error case
        {
            "data_channels": ["channel1", "missing_channel"],
            "norm_channels": ["channel1"],
            "source_name": "test_source",
            "expect_error": True,
            "error_match": "following channels for test_source are missing in normalisation means"
        },
        # Base validation - channel has a mean but no std
        {
            "data_channels": ["channel1", "channel2"],
            "norm_channels": ["channel1", "channel2"],
            "stds_channels": ["channel1"],
            "source_name": "test_source",
            "expect_error": True,
            "error_match": "following channels for test_source are missing in normalisation stds"
        },
        # NWP case - success
        {
            "data_channels": ["t2m", "dswrf"],
            "norm_channels": ["t2m", "dswrf", "extra"],
            "source_name": "ecmwf",
            "expect_error": False
        },
        # NWP case - error
        {
            "data_channels": ["t2m", "missing_channel"],
            "norm_channels": ["t2m"],
            "source_name": "ecmwf",
            "expect_error": True,
            "error_match": "following channels for ecmwf are missing in normalisation means"
        },
        # Satellite case - success
        {
            "data_channels": ["IR_016", "VIS006"],
            "norm_channels": ["IR_016", "VIS006", "extra"],
            "source_name": "satellite",
            "expect_error": False
        },
        # Satellite case - error
        {
            "data_channels": ["IR_016", "missing_channel"],
            "norm_channels": ["IR_016"],
            "source_name": "satellite",
            "expect_error": True,
            "error_match": "following channels for satellite are missing in normalisation means"
        }
    ])
    def test_channel_validation(self, test_case):
        """Test channel validation for base, NWP and satellite data.

        Cases may override the stds channel list with "stds_channels" (and the
        means list with "means_channels"); otherwise both use "norm_channels".
        """
        from contextlib import nullcontext

        means_channels = test_case.get("means_channels", test_case["norm_channels"])
        stds_channels = test_case.get("stds_channels", test_case["norm_channels"])

        if test_case["expect_error"]:
            context = pytest.raises(ValueError, match=test_case["error_match"])
        else:
            # Should not raise any exceptions
            context = nullcontext()

        with context:
            validate_channels(
                data_channels=test_case["data_channels"],
                means_channels=means_channels,
                stds_channels=stds_channels,
                source_name=test_case["source_name"]
            )
|
|
@@ -1,44 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Base class definition - abstract
|
|
3
|
-
Handling of both flat and nested structures - consideration for NWP
|
|
4
|
-
"""
|
|
5
|
-
|
|
6
|
-
import logging
|
|
7
|
-
import numpy as np
|
|
8
|
-
|
|
9
|
-
from pathlib import Path
|
|
10
|
-
from typing import Any, Dict, Optional, Union
|
|
11
|
-
from abc import ABC, abstractmethod
|
|
12
|
-
|
|
13
|
-
logger = logging.getLogger(__name__)
|
|
14
|
-
|
|
15
|
-
class SampleBase(ABC):
|
|
16
|
-
"""
|
|
17
|
-
Abstract base class for all sample types
|
|
18
|
-
Provides core data storage functionality
|
|
19
|
-
"""
|
|
20
|
-
|
|
21
|
-
def __init__(self):
|
|
22
|
-
""" Initialise data container """
|
|
23
|
-
logger.debug("Initialising SampleBase instance")
|
|
24
|
-
|
|
25
|
-
@abstractmethod
|
|
26
|
-
def to_numpy(self) -> Dict[str, Any]:
|
|
27
|
-
""" Convert data to a numpy array representation """
|
|
28
|
-
raise NotImplementedError
|
|
29
|
-
|
|
30
|
-
@abstractmethod
|
|
31
|
-
def plot(self, **kwargs) -> None:
|
|
32
|
-
""" Abstract method for plotting """
|
|
33
|
-
raise NotImplementedError
|
|
34
|
-
|
|
35
|
-
@abstractmethod
|
|
36
|
-
def save(self, path: Union[str, Path]) -> None:
|
|
37
|
-
""" Abstract method for saving sample data """
|
|
38
|
-
raise NotImplementedError
|
|
39
|
-
|
|
40
|
-
@classmethod
|
|
41
|
-
@abstractmethod
|
|
42
|
-
def load(cls, path: Union[str, Path]) -> 'SampleBase':
|
|
43
|
-
""" Abstract class method for loading sample data """
|
|
44
|
-
raise NotImplementedError
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/data/uk_gsp_locations.csv
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/nwp/providers/__init__.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/nwp/providers/ecmwf.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/nwp/providers/ukv.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/load/nwp/providers/utils.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/numpy_sample/datetime_features.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/numpy_sample/satellite.py
RENAMED
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/numpy_sample/sun_position.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/fill_time_periods.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/select_spatial_slice.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/select_time_slice.py
RENAMED
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler/select/time_slice_for_dataset.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/ocf_data_sampler.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/numpy_sample/test_datetime_features.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/select/test_find_contiguous_time_periods.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/test_sample/test_uk_regional_sample.py
RENAMED
|
File without changes
|
{ocf_data_sampler-0.1.6 → ocf_data_sampler-0.1.8}/tests/torch_datasets/test_merge_and_fill_utils.py
RENAMED
|
File without changes
|
|
File without changes
|