ocf-data-sampler 0.0.53__tar.gz → 0.0.54__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (87)
  1. {ocf_data_sampler-0.0.53/ocf_data_sampler.egg-info → ocf_data_sampler-0.0.54}/PKG-INFO +4 -1
  2. ocf_data_sampler-0.0.54/ocf_data_sampler/sample/__init__.py +10 -0
  3. ocf_data_sampler-0.0.54/ocf_data_sampler/sample/base.py +44 -0
  4. ocf_data_sampler-0.0.54/ocf_data_sampler/sample/site.py +81 -0
  5. ocf_data_sampler-0.0.54/ocf_data_sampler/sample/uk_regional.py +118 -0
  6. ocf_data_sampler-0.0.54/ocf_data_sampler/torch_datasets/datasets/__init__.py +11 -0
  7. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54/ocf_data_sampler.egg-info}/PKG-INFO +4 -1
  8. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler.egg-info/SOURCES.txt +7 -0
  9. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler.egg-info/requires.txt +4 -0
  10. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/pyproject.toml +6 -2
  11. ocf_data_sampler-0.0.54/tests/test_sample/test_base.py +86 -0
  12. ocf_data_sampler-0.0.54/tests/test_sample/test_site_sample.py +195 -0
  13. ocf_data_sampler-0.0.54/tests/test_sample/test_uk_regional_sample.py +163 -0
  14. ocf_data_sampler-0.0.53/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -2
  15. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/LICENSE +0 -0
  16. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/MANIFEST.in +0 -0
  17. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/README.md +0 -0
  18. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/__init__.py +0 -0
  19. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/config/__init__.py +0 -0
  20. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/config/load.py +0 -0
  21. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/config/model.py +0 -0
  22. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/config/save.py +0 -0
  23. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/constants.py +0 -0
  24. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/data/uk_gsp_locations.csv +0 -0
  25. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/__init__.py +0 -0
  26. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/gsp.py +0 -0
  27. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/load_dataset.py +0 -0
  28. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/nwp/__init__.py +0 -0
  29. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/nwp/nwp.py +0 -0
  30. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
  31. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
  32. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
  33. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
  34. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/satellite.py +0 -0
  35. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/site.py +0 -0
  36. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/load/utils.py +0 -0
  37. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
  38. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/numpy_sample/collate.py +0 -0
  39. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
  40. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
  41. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/numpy_sample/nwp.py +0 -0
  42. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
  43. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/numpy_sample/site.py +0 -0
  44. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
  45. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/select/__init__.py +0 -0
  46. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/select/dropout.py +0 -0
  47. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/select/fill_time_periods.py +0 -0
  48. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
  49. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/select/geospatial.py +0 -0
  50. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/select/location.py +0 -0
  51. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
  52. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/select/select_time_slice.py +0 -0
  53. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/select/spatial_slice_for_dataset.py +0 -0
  54. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/select/time_slice_for_dataset.py +0 -0
  55. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk_regional.py +0 -0
  56. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/torch_datasets/datasets/site.py +0 -0
  57. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
  58. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
  59. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler/utils.py +0 -0
  60. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
  61. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/ocf_data_sampler.egg-info/top_level.txt +0 -0
  62. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/scripts/refactor_site.py +0 -0
  63. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/setup.cfg +0 -0
  64. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/__init__.py +0 -0
  65. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/config/test_config.py +0 -0
  66. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/config/test_save.py +0 -0
  67. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/conftest.py +0 -0
  68. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/load/test_load_gsp.py +0 -0
  69. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/load/test_load_nwp.py +0 -0
  70. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/load/test_load_satellite.py +0 -0
  71. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/load/test_load_sites.py +0 -0
  72. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/numpy_sample/test_collate.py +0 -0
  73. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/numpy_sample/test_datetime_features.py +0 -0
  74. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/numpy_sample/test_gsp.py +0 -0
  75. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/numpy_sample/test_nwp.py +0 -0
  76. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/numpy_sample/test_satellite.py +0 -0
  77. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/numpy_sample/test_sun_position.py +0 -0
  78. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/select/test_dropout.py +0 -0
  79. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/select/test_fill_time_periods.py +0 -0
  80. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/select/test_find_contiguous_time_periods.py +0 -0
  81. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/select/test_location.py +0 -0
  82. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/select/test_select_spatial_slice.py +0 -0
  83. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/select/test_select_time_slice.py +0 -0
  84. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/torch_datasets/conftest.py +0 -0
  85. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/torch_datasets/test_merge_and_fill_utils.py +0 -0
  86. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/torch_datasets/test_pvnet_uk_regional.py +0 -0
  87. {ocf_data_sampler-0.0.53 → ocf_data_sampler-0.0.54}/tests/torch_datasets/test_site.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: ocf_data_sampler
3
- Version: 0.0.53
3
+ Version: 0.0.54
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -49,9 +49,12 @@ Requires-Dist: pyproj
49
49
  Requires-Dist: pathy
50
50
  Requires-Dist: pyaml_env
51
51
  Requires-Dist: pyresample
52
+ Requires-Dist: h5netcdf
52
53
  Provides-Extra: docs
53
54
  Requires-Dist: mkdocs>=1.2; extra == "docs"
54
55
  Requires-Dist: mkdocs-material>=8.0; extra == "docs"
56
+ Provides-Extra: plot
57
+ Requires-Dist: matplotlib; extra == "plot"
55
58
 
56
59
  # ocf-data-sampler
57
60
 
@@ -0,0 +1,10 @@
"""Public sample container classes."""

from .base import SampleBase
from .uk_regional import UKRegionalSample
from .site import SiteSample

__all__ = [
    "SampleBase",
    "UKRegionalSample",
    "SiteSample",
]
@@ -0,0 +1,44 @@
"""
Base class definition - abstract
Handling of both flat and nested structures - consideration for NWP
"""

import logging
import numpy as np

from pathlib import Path
from typing import Any, Dict, Optional, Union
from abc import ABC, abstractmethod

logger = logging.getLogger(__name__)


class SampleBase(ABC):
    """
    Abstract base class for all sample types

    Owns the sample payload (``self._data``) and declares the interface that
    concrete sample types implement: numpy conversion, plotting, and
    saving/loading to disk.
    """

    def __init__(self) -> None:
        """ Initialise an empty data container """
        logger.debug("Initialising SampleBase instance")
        # Created here so every subclass has the container even if its own
        # __init__ does not set it; subclasses decide what they store in it.
        self._data: Any = {}

    @abstractmethod
    def to_numpy(self) -> Dict[str, Any]:
        """ Return the sample data as a dict of numpy-compatible values """
        raise NotImplementedError

    @abstractmethod
    def plot(self, **kwargs) -> None:
        """ Visualise the sample data """
        raise NotImplementedError

    @abstractmethod
    def save(self, path: Union[str, Path]) -> None:
        """ Persist the sample data to ``path`` """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def load(cls, path: Union[str, Path]) -> 'SampleBase':
        """ Construct a sample from data previously saved at ``path`` """
        raise NotImplementedError
@@ -0,0 +1,81 @@
"""
PVNet - Site sample / dataset implementation
"""

import logging
import xarray as xr
import numpy as np

from pathlib import Path
from typing import Dict, Any, Union

from ocf_data_sampler.sample.base import SampleBase
from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample


logger = logging.getLogger(__name__)


class SiteSample(SampleBase):
    """ Sample class specific to Site PVNet

    The payload (``self._data``) starts as an empty dict and must be an
    ``xarray.Dataset`` before it can be converted or saved.
    """

    def __init__(self) -> None:
        logger.debug("Initialise SiteSample instance")
        super().__init__()
        self._data = {}

    def to_numpy(self) -> Dict[str, Any]:
        """ Convert the Dataset payload with ``convert_netcdf_to_numpy_sample``

        Raises:
            TypeError: if the payload is not an ``xarray.Dataset``.
        """
        logger.debug("Converting site sample to numpy format")

        if not isinstance(self._data, xr.Dataset):
            logger.error("Error converting to numpy: Data must be xarray Dataset")
            raise TypeError("Data must be xarray Dataset")

        try:
            numpy_data = convert_netcdf_to_numpy_sample(self._data)
        except Exception:
            logger.exception("Error converting to numpy")
            raise

        logger.debug("Successfully converted to numpy format")
        return numpy_data

    def save(self, path: Union[str, Path]) -> None:
        """ Save the Dataset payload as netCDF using the h5netcdf engine

        Raises:
            ValueError: if ``path`` does not end in ``.nc``.
            TypeError: if the payload is not an ``xarray.Dataset``.
        """
        logger.debug(f"Saving SiteSample to {path}")
        path = Path(path)

        if path.suffix != '.nc':
            logger.error(f"Invalid file format - {path.suffix}")
            raise ValueError("Only .nc format is supported")

        if not isinstance(self._data, xr.Dataset):
            raise TypeError("Data must be xarray Dataset for saving")

        self._data.to_netcdf(
            path,
            mode="w",
            engine="h5netcdf"
        )
        logger.debug(f"Successfully saved SiteSample - {path}")

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'SiteSample':
        """ Load a site sample from a netCDF file

        Raises:
            ValueError: if ``path`` does not end in ``.nc``.
        """
        logger.debug(f"Loading SiteSample from {path}")
        path = Path(path)

        if path.suffix != '.nc':
            logger.error(f"Invalid file format - {path.suffix}")
            raise ValueError("Only .nc format is supported")

        instance = cls()
        # load_dataset reads everything into memory and closes the file;
        # open_dataset would keep the handle open lazily (leaking it and, on
        # Windows, blocking the file from being overwritten or deleted).
        instance._data = xr.load_dataset(path)
        logger.debug(f"Loaded SiteSample from {path}")
        return instance

    # TODO: placeholder for now
    def plot(self, **kwargs) -> None:
        """ Plot sample data - placeholder, currently does nothing """
        pass
@@ -0,0 +1,118 @@
"""
PVNet - UK Regional sample / dataset implementation
"""

import numpy as np
import pandas as pd
import torch
import logging

from typing import Dict, Any, Union, List, Optional
from pathlib import Path

from ocf_data_sampler.numpy_sample import (
    NWPSampleKey,
    GSPSampleKey,
    SatelliteSampleKey
)

from ocf_data_sampler.sample.base import SampleBase

try:
    import matplotlib.pyplot as plt
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False
    plt = None


logger = logging.getLogger(__name__)


def _first_frame(data: Any) -> np.ndarray:
    """ Return a 2D slice of ``data`` suitable for ``imshow``

    Takes index 0 along every leading axis until two axes remain. Stacked
    image arrays with extra leading axes (e.g. time and channel) cannot be
    passed to ``imshow`` directly, which accepts only 2D, RGB or RGBA images.
    """
    frame = np.asarray(data)
    while frame.ndim > 2:
        frame = frame[0]
    return frame


class UKRegionalSample(SampleBase):
    """ Sample class specific to UK Regional PVNet

    The payload (``self._data``) is a dict keyed by the
    ``ocf_data_sampler.numpy_sample`` sample keys; NWP data sits under
    ``NWPSampleKey.nwp`` as a further dict (presumably one entry per NWP
    provider).
    """

    def __init__(self) -> None:
        logger.debug("Initialise UKRegionalSample instance")
        super().__init__()
        self._data = {}

    def to_numpy(self) -> Dict[str, Any]:
        """ Return the sample payload

        The payload dict is returned as-is (not copied), so mutating the
        result mutates the sample.
        """
        logger.debug("Converting sample data to numpy format")
        return self._data

    def save(self, path: Union[str, Path]) -> None:
        """ Save the payload with ``torch.save``

        Raises:
            ValueError: if ``path`` does not end in ``.pt``.
        """
        logger.debug(f"Saving UKRegionalSample to {path}")
        path = Path(path)

        if path.suffix != '.pt':
            logger.error(f"Invalid file format: {path.suffix}")
            raise ValueError(f"Only .pt format is supported: {path.suffix}")

        torch.save(self._data, path)
        logger.debug(f"Successfully saved UKRegionalSample to {path}")

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'UKRegionalSample':
        """ Load a sample previously written by ``save``

        Raises:
            ValueError: if ``path`` does not end in ``.pt``.
        """
        logger.debug(f"Attempting to load UKRegionalSample from {path}")
        path = Path(path)

        if path.suffix != '.pt':
            logger.error(f"Invalid file format: {path.suffix}")
            raise ValueError(f"Only .pt format is supported: {path.suffix}")

        instance = cls()
        # Payloads may contain numpy arrays, which the restricted unpickler
        # used by torch>=2.6's default (weights_only=True) rejects, so full
        # unpickling is requested explicitly.
        # NOTE(security): this runs pickle - only load files from trusted sources.
        instance._data = torch.load(path, weights_only=False)
        logger.debug(f"Successfully loaded UKRegionalSample from {path}")
        return instance

    def plot(self, **kwargs) -> None:
        """ Show a 2x2 overview figure of the sample

        Panels: GSP generation (top left), first NWP frame (top right),
        first satellite frame (bottom left), solar position (bottom right).
        Panels whose data is missing are left empty.

        Raises:
            ImportError: if matplotlib is not installed.
        """
        logger.debug("Creating UKRegionalSample visualisation")

        if not MATPLOTLIB_AVAILABLE:
            raise ImportError(
                "Matplotlib required for plotting. "
                "Install via 'ocf_data_sampler[plot]'"
            )

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        try:
            if NWPSampleKey.nwp in self._data and self._data[NWPSampleKey.nwp]:
                logger.debug("Plotting NWP data")
                first_nwp = next(iter(self._data[NWPSampleKey.nwp].values()))
                if 'nwp' in first_nwp:
                    axes[0, 1].imshow(_first_frame(first_nwp['nwp']))
                    axes[0, 1].set_title('NWP (First Channel)')
                    if NWPSampleKey.channel_names in first_nwp:
                        channel_names = first_nwp[NWPSampleKey.channel_names]
                        if len(channel_names) > 0:
                            axes[0, 1].set_title(f'NWP: {channel_names[0]}')

            if GSPSampleKey.gsp in self._data:
                logger.debug("Plotting GSP generation data")
                axes[0, 0].plot(self._data[GSPSampleKey.gsp])
                axes[0, 0].set_title('GSP Generation')

            if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
                logger.debug("Plotting solar position data")
                axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
                axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
                axes[1, 1].set_title('Solar Position')
                axes[1, 1].legend()

            if SatelliteSampleKey.satellite_actual in self._data:
                logger.debug("Plotting satellite data")
                axes[1, 0].imshow(
                    _first_frame(self._data[SatelliteSampleKey.satellite_actual])
                )
                axes[1, 0].set_title('Satellite Data')

            plt.tight_layout()
            plt.show()
            logger.debug("Successfully created visualisation")
        except Exception as e:
            logger.error(f"Error creating visualisation: {str(e)}")
            # Don't leave a half-drawn figure registered with pyplot
            plt.close(fig)
            raise
@@ -0,0 +1,11 @@
"""Torch dataset classes and helpers."""

from .pvnet_uk_regional import PVNetUKRegionalDataset

from .site import (
    convert_netcdf_to_numpy_sample,
    SitesDataset
)

# Every name imported above is public API; any name missing from __all__ is
# silently dropped by `from ocf_data_sampler.torch_datasets.datasets import *`.
__all__ = [
    'PVNetUKRegionalDataset',
    'convert_netcdf_to_numpy_sample',
    'SitesDataset',
]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: ocf_data_sampler
3
- Version: 0.0.53
3
+ Version: 0.0.54
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -49,9 +49,12 @@ Requires-Dist: pyproj
49
49
  Requires-Dist: pathy
50
50
  Requires-Dist: pyaml_env
51
51
  Requires-Dist: pyresample
52
+ Requires-Dist: h5netcdf
52
53
  Provides-Extra: docs
53
54
  Requires-Dist: mkdocs>=1.2; extra == "docs"
54
55
  Requires-Dist: mkdocs-material>=8.0; extra == "docs"
56
+ Provides-Extra: plot
57
+ Requires-Dist: matplotlib; extra == "plot"
55
58
 
56
59
  # ocf-data-sampler
57
60
 
@@ -35,6 +35,10 @@ ocf_data_sampler/numpy_sample/nwp.py
35
35
  ocf_data_sampler/numpy_sample/satellite.py
36
36
  ocf_data_sampler/numpy_sample/site.py
37
37
  ocf_data_sampler/numpy_sample/sun_position.py
38
+ ocf_data_sampler/sample/__init__.py
39
+ ocf_data_sampler/sample/base.py
40
+ ocf_data_sampler/sample/site.py
41
+ ocf_data_sampler/sample/uk_regional.py
38
42
  ocf_data_sampler/select/__init__.py
39
43
  ocf_data_sampler/select/dropout.py
40
44
  ocf_data_sampler/select/fill_time_periods.py
@@ -71,6 +75,9 @@ tests/select/test_find_contiguous_time_periods.py
71
75
  tests/select/test_location.py
72
76
  tests/select/test_select_spatial_slice.py
73
77
  tests/select/test_select_time_slice.py
78
+ tests/test_sample/test_base.py
79
+ tests/test_sample/test_site_sample.py
80
+ tests/test_sample/test_uk_regional_sample.py
74
81
  tests/torch_datasets/conftest.py
75
82
  tests/torch_datasets/test_merge_and_fill_utils.py
76
83
  tests/torch_datasets/test_pvnet_uk_regional.py
@@ -11,7 +11,11 @@ pyproj
11
11
  pathy
12
12
  pyaml_env
13
13
  pyresample
14
+ h5netcdf
14
15
 
15
16
  [docs]
16
17
  mkdocs>=1.2
17
18
  mkdocs-material>=8.0
19
+
20
+ [plot]
21
+ matplotlib
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "ocf_data_sampler"
7
- version = "0.0.53"
7
+ version = "0.0.54"
8
8
  license = { file = "LICENSE" }
9
9
  readme = "README.md"
10
10
  description = "Sample from weather data for renewable energy prediction"
@@ -30,7 +30,8 @@ dependencies = [ # Migration from requirements.txt
30
30
  "pyproj",
31
31
  "pathy",
32
32
  "pyaml_env",
33
- "pyresample"
33
+ "pyresample",
34
+ "h5netcdf",
34
35
  ]
35
36
 
36
37
  keywords = [ # I've added some keywords, but please provide feedback if you'd like them changed!
@@ -51,6 +52,9 @@ docs = [
51
52
  "mkdocs>=1.2",
52
53
  "mkdocs-material>=8.0"
53
54
  ]
55
+ plot = [
56
+ "matplotlib"
57
+ ]
54
58
 
55
59
  [project.urls]
56
60
  homepage = "https://github.com/openclimatefix"
@@ -0,0 +1,86 @@
"""
Base class testing - SampleBase
"""

import pytest
import numpy as np

from pathlib import Path
from ocf_data_sampler.sample.base import SampleBase


class TestSample(SampleBase):
    """
    SampleBase for testing purposes
    Minimal implementations - abstract methods
    """

    # The "Test" prefix would otherwise make pytest try to collect this helper
    # as a test class (and warn, because it defines __init__).
    __test__ = False

    def __init__(self):
        super().__init__()
        self._data = {}

    def plot(self, **kwargs):
        """ Minimal plot implementation """
        return None

    def to_numpy(self) -> dict:
        """ Convert every stored value to a numpy array """
        return {key: np.array(value) for key, value in self._data.items()}

    def save(self, path):
        """ Minimal save implementation - writes a fixed marker payload """
        Path(path).write_bytes(b'test_data')

    @classmethod
    def load(cls, path):
        """ Minimal load implementation - returns a fresh empty instance """
        return cls()


def test_sample_base_initialisation():
    """ Initialisation of SampleBase subclass """

    sample = TestSample()
    assert hasattr(sample, '_data'), "Sample should have _data attribute"
    assert sample._data == {}, "Sample should start with empty dict"


def test_sample_base_save_load(tmp_path):
    """ Test basic save and load functionality """

    sample = TestSample()
    sample._data['test_data'] = [1, 2, 3]

    save_path = tmp_path / 'test_sample.dat'
    sample.save(save_path)
    assert save_path.exists()
    assert save_path.read_bytes() == b'test_data'

    loaded_sample = TestSample.load(save_path)
    assert isinstance(loaded_sample, TestSample)


def test_sample_base_abstract_methods():
    """ Test abstract method enforcement """

    with pytest.raises(TypeError, match="Can't instantiate abstract class"):
        SampleBase()


def test_sample_base_to_numpy():
    """ Test the to_numpy functionality """

    sample = TestSample()
    sample._data = {
        'int_data': 42,
        'list_data': [1, 2, 3]
    }
    numpy_data = sample.to_numpy()

    assert isinstance(numpy_data, dict)
    assert all(isinstance(value, np.ndarray) for value in numpy_data.values())
    assert np.array_equal(numpy_data['list_data'], np.array([1, 2, 3]))
@@ -0,0 +1,195 @@
"""
Site class testing - SiteSample
"""

import pytest
import numpy as np
import xarray as xr
import pandas as pd

from ocf_data_sampler.sample.site import SiteSample


@pytest.fixture
def sample_data():
    """ Site sample Dataset with NWP, satellite and site variables """

    # Fixed seed so any failure is reproducible
    rng = np.random.default_rng(42)

    # Time periods specified
    init_time = pd.Timestamp("2023-01-01 00:00")
    target_times = pd.date_range("2023-01-01 00:00", periods=4, freq="1h")
    sat_times = pd.date_range("2023-01-01 00:00", periods=7, freq="5min")
    site_times = pd.date_range("2023-01-01 00:00", periods=4, freq="30min")

    # Defined steps for NWP data
    steps = [(t - init_time) for t in target_times]

    # Create the sample dataset
    return xr.Dataset(
        data_vars={
            'nwp-ukv': (
                ["nwp-ukv__target_time_utc", "nwp-ukv__channel",
                 "nwp-ukv__y_osgb", "nwp-ukv__x_osgb"],
                rng.random((4, 1, 2, 2))
            ),
            'satellite': (
                ["satellite__time_utc", "satellite__channel",
                 "satellite__y_geostationary", "satellite__x_geostationary"],
                rng.random((7, 1, 2, 2))
            ),
            'site': (["site__time_utc"], rng.random(4))
        },
        coords={
            # NWP coords
            'nwp-ukv__target_time_utc': target_times,
            'nwp-ukv__channel': ['dswrf'],
            'nwp-ukv__y_osgb': [0, 1],
            'nwp-ukv__x_osgb': [0, 1],
            'nwp-ukv__init_time_utc': init_time,
            'nwp-ukv__step': ('nwp-ukv__target_time_utc', steps),

            # Sat coords
            'satellite__time_utc': sat_times,
            'satellite__channel': ['vis'],
            'satellite__y_geostationary': [0, 1],
            'satellite__x_geostationary': [0, 1],

            # Site coords
            'site__time_utc': site_times,
            'site__capacity_kwp': 1000.0,
            'site__site_id': 1,
            'site__longitude': -3.5,
            'site__latitude': 51.5,

            # Site features as coords
            'site__solar_azimuth': ('site__time_utc', rng.random(4)),
            'site__solar_elevation': ('site__time_utc', rng.random(4)),
            'site__date_cos': ('site__time_utc', rng.random(4)),
            'site__date_sin': ('site__time_utc', rng.random(4)),
            'site__time_cos': ('site__time_utc', rng.random(4)),
            'site__time_sin': ('site__time_utc', rng.random(4))
        }
    )


def test_site_sample_init():
    """ Test initialisation """
    sample = SiteSample()
    assert isinstance(sample._data, dict)
    assert len(sample._data) == 0


def test_site_sample_with_data(sample_data):
    """ Testing of defined sample with actual data """
    sample = SiteSample()
    sample._data = sample_data

    # Assert data structure
    assert isinstance(sample._data, xr.Dataset)

    # Assert dimensions / shapes
    expected_dims = {
        "satellite__x_geostationary",
        "site__time_utc",
        "nwp-ukv__target_time_utc",
        "nwp-ukv__x_osgb",
        "satellite__channel",
        "satellite__y_geostationary",
        "satellite__time_utc",
        "nwp-ukv__channel",
        "nwp-ukv__y_osgb",
    }
    assert set(sample._data.dims) == expected_dims
    assert sample._data["satellite"].values.shape == (7, 1, 2, 2)
    assert sample._data["nwp-ukv"].values.shape == (4, 1, 2, 2)
    assert sample._data["site"].values.shape == (4,)


def test_save_load(tmp_path, sample_data):
    """ Save and load functionality """
    sample = SiteSample()
    sample._data = sample_data
    filepath = tmp_path / "test_sample.nc"
    sample.save(filepath)

    # Assert file exists and has content
    assert filepath.exists()
    assert filepath.stat().st_size > 0

    # Load and verify
    loaded = SiteSample.load(filepath)
    assert isinstance(loaded, SiteSample)
    assert isinstance(loaded._data, xr.Dataset)

    # Compare original / loaded data
    xr.testing.assert_identical(sample._data, loaded._data)


def test_invalid_save_format(sample_data):
    """ Saving with invalid format """
    sample = SiteSample()
    sample._data = sample_data
    with pytest.raises(ValueError, match="Only .nc format is supported"):
        sample.save("invalid.txt")


def test_invalid_load_format():
    """ Loading with invalid format """
    with pytest.raises(ValueError, match="Only .nc format is supported"):
        SiteSample.load("invalid.txt")


def test_invalid_data_type(tmp_path):
    """ Handling of invalid data types """
    sample = SiteSample()
    sample._data = {"invalid": "data"}

    with pytest.raises(TypeError, match="Data must be xarray Dataset"):
        sample.to_numpy()

    # Target tmp_path so nothing can ever be written into the working directory
    with pytest.raises(TypeError, match="Data must be xarray Dataset for saving"):
        sample.save(tmp_path / "test.nc")


def test_to_numpy(sample_data):
    """ To numpy conversion """
    sample = SiteSample()
    sample._data = sample_data
    numpy_data = sample.to_numpy()

    # Assert structure
    assert isinstance(numpy_data, dict)
    assert 'site' in numpy_data
    assert 'nwp' in numpy_data

    # Check site - numpy array instead of dict
    site_data = numpy_data['site']
    assert isinstance(site_data, np.ndarray)
    assert site_data.ndim == 1
    assert len(site_data) == 4
    assert np.all(site_data >= 0) and np.all(site_data <= 1)

    # Check NWP
    assert 'ukv' in numpy_data['nwp']
    nwp_data = numpy_data['nwp']['ukv']
    assert 'nwp' in nwp_data
    assert nwp_data['nwp'].shape == (4, 1, 2, 2)


def test_data_consistency(sample_data):
    """ Consistency of data across operations """
    sample = SiteSample()
    sample._data = sample_data
    numpy_data = sample.to_numpy()

    # Assert components remain consistent after conversion above
    assert numpy_data['nwp']['ukv']['nwp'].shape == (4, 1, 2, 2)
    assert 'site' in numpy_data

    # Site data is expected as a numpy array
    assert isinstance(numpy_data['site'], np.ndarray)
    assert numpy_data['site'].shape == (4,)
    assert np.all(numpy_data['site'] >= 0)
    assert np.all(numpy_data['site'] <= 1)
@@ -0,0 +1,163 @@
"""
UK Regional class testing - UKRegionalSample
"""

import pytest
import numpy as np

from ocf_data_sampler.numpy_sample import (
    GSPSampleKey,
    SatelliteSampleKey,
    NWPSampleKey
)

from ocf_data_sampler.sample.uk_regional import UKRegionalSample


def create_test_data(seed: int = 0):
    """ Synthetic UK regional sample: nested NWP, GSP, satellite and solar data """
    rng = np.random.default_rng(seed)

    # Field / spatial coordinates
    nwp_data = {
        'nwp': rng.random((4, 1, 2, 2)),
        'x': np.array([1, 2]),
        'y': np.array([1, 2]),
        NWPSampleKey.channel_names: ['test_channel']
    }

    return {
        'nwp': {
            'ukv': nwp_data
        },
        GSPSampleKey.gsp: rng.random(7),
        SatelliteSampleKey.satellite_actual: rng.random((7, 1, 2, 2)),
        GSPSampleKey.solar_azimuth: rng.random(7),
        GSPSampleKey.solar_elevation: rng.random(7)
    }


# UKRegionalSample testing
def test_sample_init():
    """ Initialisation """
    sample = UKRegionalSample()
    assert hasattr(sample, '_data'), "Sample should have _data attribute"
    assert isinstance(sample._data, dict)
    assert len(sample._data) == 0


def test_sample_save_load(tmp_path):
    """ Round trip through save / load """
    sample = UKRegionalSample()
    sample._data = create_test_data()

    # tmp_path rather than NamedTemporaryFile: re-opening an open
    # NamedTemporaryFile by name fails on Windows
    path = tmp_path / "sample.pt"
    sample.save(path)
    loaded = UKRegionalSample.load(path)

    assert set(loaded._data.keys()) == set(sample._data.keys())
    assert isinstance(loaded._data['nwp'], dict)
    assert 'ukv' in loaded._data['nwp']

    assert loaded._data[GSPSampleKey.gsp].shape == (7,)
    assert loaded._data[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
    assert loaded._data[GSPSampleKey.solar_azimuth].shape == (7,)
    assert loaded._data[GSPSampleKey.solar_elevation].shape == (7,)

    np.testing.assert_array_almost_equal(
        loaded._data[GSPSampleKey.gsp],
        sample._data[GSPSampleKey.gsp]
    )


def test_save_unsupported_format(tmp_path):
    """ Test saving - unsupported file format """
    sample = UKRegionalSample()
    sample._data = create_test_data()

    path = tmp_path / "sample.npz"
    with pytest.raises(ValueError, match="Only .pt format is supported"):
        sample.save(path)
    assert not path.exists()


def test_load_unsupported_format(tmp_path):
    """ Test loading - unsupported file format """

    with pytest.raises(ValueError, match="Only .pt format is supported"):
        UKRegionalSample.load(tmp_path / "sample.npz")


def test_load_corrupted_file(tmp_path):
    """ Test loading - corrupted / empty file """

    path = tmp_path / "corrupted.pt"
    path.write_bytes(b'corrupted data')

    # torch raises different exception types depending on version
    with pytest.raises(Exception):
        UKRegionalSample.load(path)


def test_to_numpy():
    """ To numpy conversion check """
    sample = UKRegionalSample()
    sample._data = create_test_data()

    numpy_data = sample.to_numpy()

    # Compare structurally: `==` on dicts of arrays only works via the
    # identity shortcut and raises "truth value ... is ambiguous" otherwise
    assert numpy_data.keys() == sample._data.keys()
    np.testing.assert_array_equal(
        numpy_data[GSPSampleKey.gsp],
        sample._data[GSPSampleKey.gsp]
    )

    # Assert specific keys and types
    assert 'nwp' in numpy_data
    assert isinstance(numpy_data['nwp']['ukv']['nwp'], np.ndarray)
    assert numpy_data[GSPSampleKey.gsp].shape == (7,)
@@ -1,2 +0,0 @@
1
- from .pvnet_uk_regional import PVNetUKRegionalDataset
2
- from .site import SitesDataset