ocf-data-sampler 0.0.52__py3-none-any.whl → 0.0.54__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/load/utils.py +1 -1
- ocf_data_sampler/sample/__init__.py +10 -0
- ocf_data_sampler/sample/base.py +44 -0
- ocf_data_sampler/sample/site.py +81 -0
- ocf_data_sampler/sample/uk_regional.py +118 -0
- ocf_data_sampler/torch_datasets/datasets/__init__.py +10 -1
- ocf_data_sampler/torch_datasets/datasets/site.py +25 -1
- {ocf_data_sampler-0.0.52.dist-info → ocf_data_sampler-0.0.54.dist-info}/METADATA +4 -1
- {ocf_data_sampler-0.0.52.dist-info → ocf_data_sampler-0.0.54.dist-info}/RECORD +16 -9
- tests/test_sample/test_base.py +86 -0
- tests/test_sample/test_site_sample.py +195 -0
- tests/test_sample/test_uk_regional_sample.py +163 -0
- tests/torch_datasets/test_site.py +17 -2
- {ocf_data_sampler-0.0.52.dist-info → ocf_data_sampler-0.0.54.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.0.52.dist-info → ocf_data_sampler-0.0.54.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.0.52.dist-info → ocf_data_sampler-0.0.54.dist-info}/top_level.txt +0 -0
ocf_data_sampler/load/utils.py
CHANGED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base class definition - abstract
|
|
3
|
+
Handling of both flat and nested structures - consideration for NWP
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Dict, Optional, Union
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
class SampleBase(ABC):
    """Abstract base class for all sample types.

    Holds the sample payload in ``self._data`` and declares the interface
    every concrete sample must implement: ``to_numpy``, ``plot``, ``save``
    and the ``load`` class method. Instantiating this class directly raises
    ``TypeError`` because of the abstract methods.
    """

    def __init__(self):
        """Initialise the data container as an empty dict."""
        logger.debug("Initialising SampleBase instance")
        # Create the container here so the attribute exists on every
        # subclass that calls super().__init__(); subclasses remain free to
        # replace it with their own payload type afterwards.
        self._data = {}

    @abstractmethod
    def to_numpy(self) -> Dict[str, Any]:
        """Return the sample as a dictionary of numpy-compatible values."""
        raise NotImplementedError

    @abstractmethod
    def plot(self, **kwargs) -> None:
        """Visualise the sample; keyword arguments are implementation-defined."""
        raise NotImplementedError

    @abstractmethod
    def save(self, path: Union[str, Path]) -> None:
        """Persist the sample to ``path``."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def load(cls, path: Union[str, Path]) -> 'SampleBase':
        """Create a sample instance from the data stored at ``path``."""
        raise NotImplementedError
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PVNet - Site sample / dataset implementation
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import xarray as xr
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Dict, Any, Union
|
|
11
|
+
|
|
12
|
+
from ocf_data_sampler.sample.base import SampleBase
|
|
13
|
+
from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SiteSample(SampleBase):
    """Sample class specific to Site PVNet.

    The payload in ``_data`` is expected to be an ``xarray.Dataset``. It is
    persisted as netCDF and converted to numpy through
    ``convert_netcdf_to_numpy_sample``.
    """

    def __init__(self):
        """Create an empty sample; ``_data`` starts as an empty dict."""
        logger.debug("Initialise SiteSample instance")
        super().__init__()
        # Placeholder until a Dataset is assigned or loaded; to_numpy() and
        # save() reject anything that is not an xarray Dataset.
        self._data = {}

    def to_numpy(self) -> Dict[str, Any]:
        """Convert the stored Dataset to the numpy sample dictionary.

        Returns:
            The result of ``convert_netcdf_to_numpy_sample`` for ``_data``.

        Raises:
            TypeError: if ``_data`` is not an ``xarray.Dataset``.
        """
        logger.debug("Converting site sample to numpy format")

        try:
            if not isinstance(self._data, xr.Dataset):
                raise TypeError("Data must be xarray Dataset")
            numpy_data = convert_netcdf_to_numpy_sample(self._data)
        except Exception as e:
            # Log for context, then let the original exception propagate.
            logger.error("Error converting to numpy: %s", e)
            raise

        logger.debug("Successfully converted to numpy format")
        return numpy_data

    def save(self, path: Union[str, Path]) -> None:
        """Save the site sample as netCDF using the h5netcdf engine.

        Args:
            path: destination file; must have the ``.nc`` suffix.

        Raises:
            ValueError: if the suffix is not ``.nc``.
            TypeError: if ``_data`` is not an ``xarray.Dataset``.
        """
        logger.debug("Saving SiteSample to %s", path)
        path = Path(path)

        if path.suffix != '.nc':
            logger.error("Invalid file format - %s", path.suffix)
            raise ValueError("Only .nc format is supported")

        if not isinstance(self._data, xr.Dataset):
            raise TypeError("Data must be xarray Dataset for saving")

        self._data.to_netcdf(
            path,
            mode="w",
            engine="h5netcdf",
        )
        logger.debug("Successfully saved SiteSample - %s", path)

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'SiteSample':
        """Load a site sample from a netCDF file.

        Args:
            path: source file; must have the ``.nc`` suffix.

        Returns:
            A new instance whose ``_data`` is the loaded Dataset.

        Raises:
            ValueError: if the suffix is not ``.nc``.
        """
        logger.debug("Loading SiteSample from %s", path)
        path = Path(path)

        if path.suffix != '.nc':
            logger.error("Invalid file format - %s", path.suffix)
            raise ValueError("Only .nc format is supported")

        instance = cls()
        # load_dataset reads the whole file into memory and closes it, so no
        # file handle stays open behind the sample (an open handle would
        # block a later save() to the same path).
        instance._data = xr.load_dataset(path)
        logger.debug("Loaded SiteSample from %s", path)
        return instance

    # TODO: placeholder for now
    def plot(self, **kwargs) -> None:
        """Plot sample data - not implemented yet, does nothing."""
        pass
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PVNet - UK Regional sample / dataset implementation
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import torch
|
|
8
|
+
import logging
|
|
9
|
+
|
|
10
|
+
from typing import Dict, Any, Union, List, Optional
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
from ocf_data_sampler.numpy_sample import (
|
|
14
|
+
NWPSampleKey,
|
|
15
|
+
GSPSampleKey,
|
|
16
|
+
SatelliteSampleKey
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from ocf_data_sampler.sample.base import SampleBase
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
import matplotlib.pyplot as plt
|
|
23
|
+
MATPLOTLIB_AVAILABLE = True
|
|
24
|
+
except ImportError:
|
|
25
|
+
MATPLOTLIB_AVAILABLE = False
|
|
26
|
+
plt = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class UKRegionalSample(SampleBase):
    """Sample class specific to UK Regional PVNet.

    ``_data`` is a (possibly nested) dictionary keyed by the sample-key
    constants. It is persisted with ``torch.save`` / ``torch.load``.
    """

    def __init__(self):
        """Create an empty sample; ``_data`` starts as an empty dict."""
        logger.debug("Initialise UKRegionalSample instance")
        super().__init__()
        self._data = {}

    def to_numpy(self) -> Dict[str, Any]:
        """Return the stored dictionary itself (no copy, no conversion)."""
        logger.debug("Converting sample data to numpy format")
        return self._data

    def save(self, path: Union[str, Path]) -> None:
        """Save the sample dictionary as a ``.pt`` file.

        Args:
            path: destination file; must have the ``.pt`` suffix.

        Raises:
            ValueError: if the suffix is not ``.pt``.
        """
        logger.debug("Saving UKRegionalSample to %s", path)
        path = Path(path)

        if path.suffix != '.pt':
            logger.error("Invalid file format: %s", path.suffix)
            raise ValueError(f"Only .pt format is supported: {path.suffix}")

        torch.save(self._data, path)
        logger.debug("Successfully saved UKRegionalSample to %s", path)

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'UKRegionalSample':
        """Load a sample dictionary from a ``.pt`` file.

        Args:
            path: source file; must have the ``.pt`` suffix.

        Returns:
            A new instance whose ``_data`` is the loaded object.

        Raises:
            ValueError: if the suffix is not ``.pt``.
        """
        logger.debug("Attempting to load UKRegionalSample from %s", path)
        path = Path(path)

        if path.suffix != '.pt':
            logger.error("Invalid file format: %s", path.suffix)
            raise ValueError(f"Only .pt format is supported: {path.suffix}")

        instance = cls()
        # The payload is an arbitrary pickled dict (numpy arrays, nested
        # dicts), which the restricted weights-only unpickler rejects.
        # SECURITY: full unpickling can execute code - only load files from
        # a trusted source.
        instance._data = torch.load(path, weights_only=False)
        logger.debug("Successfully loaded UKRegionalSample from %s", path)
        return instance

    @staticmethod
    def _first_2d_slice(array):
        """Return the leading 2-D slice of ``array`` for use with ``imshow``."""
        image = np.asarray(array)
        # assumes any leading axes are (time, channel) - TODO confirm
        while image.ndim > 2:
            image = image[0]
        return image

    def plot(self, **kwargs) -> None:
        """Draw a 2x2 overview of the sample and show it.

        Panels: GSP generation (top-left), first NWP source (top-right),
        satellite (bottom-left), solar position (bottom-right). A panel is
        left empty when its key is absent from ``_data``.

        Raises:
            ImportError: if matplotlib is not installed.
        """
        logger.debug("Creating UKRegionalSample visualisation")

        if not MATPLOTLIB_AVAILABLE:
            raise ImportError(
                "Matplotlib required for plotting. "
                "Install via 'ocf_data_sampler[plot]'"
            )

        try:
            _fig, axes = plt.subplots(2, 2, figsize=(12, 8))

            if NWPSampleKey.nwp in self._data:
                logger.debug("Plotting NWP data")
                # Only the first NWP source is drawn; an empty mapping is skipped.
                first_nwp = next(iter(self._data[NWPSampleKey.nwp].values()), None)
                if first_nwp is not None and 'nwp' in first_nwp:
                    # imshow needs a 2-D image, so take the leading slice.
                    axes[0, 1].imshow(self._first_2d_slice(first_nwp['nwp']))
                    axes[0, 1].set_title('NWP (First Channel)')
                    if NWPSampleKey.channel_names in first_nwp:
                        channel_names = first_nwp[NWPSampleKey.channel_names]
                        if len(channel_names) > 0:
                            axes[0, 1].set_title(f'NWP: {channel_names[0]}')

            if GSPSampleKey.gsp in self._data:
                logger.debug("Plotting GSP generation data")
                axes[0, 0].plot(self._data[GSPSampleKey.gsp])
                axes[0, 0].set_title('GSP Generation')

            if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
                logger.debug("Plotting solar position data")
                axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
                axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
                axes[1, 1].set_title('Solar Position')
                axes[1, 1].legend()

            if SatelliteSampleKey.satellite_actual in self._data:
                logger.debug("Plotting satellite data")
                axes[1, 0].imshow(
                    self._first_2d_slice(self._data[SatelliteSampleKey.satellite_actual])
                )
                axes[1, 0].set_title('Satellite Data')

            plt.tight_layout()
            plt.show()
            logger.debug("Successfully created visualisation")
        except Exception as e:
            # Log for context, then let the original exception propagate.
            logger.error("Error creating visualisation: %s", e)
            raise
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Torch dataset for sites"""
|
|
2
2
|
import logging
|
|
3
|
-
|
|
3
|
+
import numpy as np
|
|
4
4
|
import pandas as pd
|
|
5
5
|
import xarray as xr
|
|
6
6
|
from typing import Tuple
|
|
@@ -421,3 +421,27 @@ def convert_to_numpy_and_combine(
|
|
|
421
421
|
combined_sample = fill_nans_in_arrays(combined_sample)
|
|
422
422
|
|
|
423
423
|
return combined_sample
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def coarsen_data(xr_data: "xr.Dataset", coarsen_to_deg: float = 0.1) -> "xr.Dataset":
    """
    Coarsen the data to a specified resolution in degrees.

    The native step is taken from the first two latitude values, and both
    latitude and longitude are coarsened by the same integer factor. The
    input is returned unchanged when it has no latitude/longitude coords,
    fewer than two latitude values, a zero step, or a factor of 1 or less.

    Args:
        xr_data: xarray dataset to coarsen
        coarsen_to_deg: resolution to coarsen to in degrees

    Returns:
        The coarsened dataset, or ``xr_data`` itself when nothing is done.
    """

    if "latitude" not in xr_data.coords or "longitude" not in xr_data.coords:
        return xr_data

    latitudes = xr_data.latitude.values
    if np.size(latitudes) < 2:
        # No second point to measure a grid step from.
        return xr_data

    step = np.round(np.abs(latitudes[1] - latitudes[0]), 4)
    if step == 0:
        # Duplicate latitudes: no usable step, and dividing would blow up.
        return xr_data

    # Floor the ratio, with a small tolerance so binary rounding such as
    # 0.3 / 0.1 == 2.9999999999999996 does not lose a whole factor.
    coarsen_factor = int(np.floor(coarsen_to_deg / step + 1e-9))
    if coarsen_factor > 1:
        xr_data = xr_data.coarsen(
            latitude=coarsen_factor,
            longitude=coarsen_factor,
            boundary="pad",
            coord_func="min",
        ).mean()

    return xr_data
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: ocf_data_sampler
|
|
3
|
-
Version: 0.0.52
|
|
3
|
+
Version: 0.0.54
|
|
4
4
|
Summary: Sample from weather data for renewable energy prediction
|
|
5
5
|
Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
|
|
6
6
|
Author-email: info@openclimatefix.org
|
|
@@ -49,9 +49,12 @@ Requires-Dist: pyproj
|
|
|
49
49
|
Requires-Dist: pathy
|
|
50
50
|
Requires-Dist: pyaml_env
|
|
51
51
|
Requires-Dist: pyresample
|
|
52
|
+
Requires-Dist: h5netcdf
|
|
52
53
|
Provides-Extra: docs
|
|
53
54
|
Requires-Dist: mkdocs>=1.2; extra == "docs"
|
|
54
55
|
Requires-Dist: mkdocs-material>=8.0; extra == "docs"
|
|
56
|
+
Provides-Extra: plot
|
|
57
|
+
Requires-Dist: matplotlib; extra == "plot"
|
|
55
58
|
|
|
56
59
|
# ocf-data-sampler
|
|
57
60
|
|
|
@@ -11,7 +11,7 @@ ocf_data_sampler/load/gsp.py,sha256=Gcr1JVUOPKhFRDCSHtfPDjxx0BtyyEhXrZvGEKLPJ5I,
|
|
|
11
11
|
ocf_data_sampler/load/load_dataset.py,sha256=Ua3RaUg4PIYJkD9BKqTfN8IWUbezbhThJGgEkd9PcaE,1587
|
|
12
12
|
ocf_data_sampler/load/satellite.py,sha256=3KlA1fx4SwxdzM-jC1WRaONXO0D6m0WxORnEnwUnZrA,2967
|
|
13
13
|
ocf_data_sampler/load/site.py,sha256=P83uz01WBDzoZajdOH0m8FQt4-buKDlUk19N548KqhA,1086
|
|
14
|
-
ocf_data_sampler/load/utils.py,sha256=
|
|
14
|
+
ocf_data_sampler/load/utils.py,sha256=sAEkPMS9LXVCrc5pANQo97zaoEItVg9hoNj2ZWfx_Ug,1405
|
|
15
15
|
ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
|
|
16
16
|
ocf_data_sampler/load/nwp/nwp.py,sha256=O4QnajEZem8BvBgTcYYDBhRhgqPYuJkolHmpMRmrXEA,610
|
|
17
17
|
ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -26,6 +26,10 @@ ocf_data_sampler/numpy_sample/nwp.py,sha256=_seQNWsut3IzPsrpipqImjnaM3XNHZCy5_5b
|
|
|
26
26
|
ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBunbhDN4Pjo0Grxs,910
|
|
27
27
|
ocf_data_sampler/numpy_sample/site.py,sha256=I-cAXCOF0SDdm5Hx43lFqYZ3jh61kltLQK-fc4_nNu0,1314
|
|
28
28
|
ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrdc1cWgegrYaqUoHOY,1611
|
|
29
|
+
ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
|
|
30
|
+
ocf_data_sampler/sample/base.py,sha256=4U78tczCRsKMDwU4HkD20nyGyYjIBSZV5neF2mT--2M,1197
|
|
31
|
+
ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
|
|
32
|
+
ocf_data_sampler/sample/uk_regional.py,sha256=FPaFi6qaTsi1ag42pfVKDZhopt3cDjQsF4rVI8k2qWo,4244
|
|
29
33
|
ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
|
|
30
34
|
ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
|
|
31
35
|
ocf_data_sampler/select/fill_time_periods.py,sha256=iTtMjIPFYG5xtUYYedAFBLjTWWUa7t7WQ0-yksWf0-E,440
|
|
@@ -36,9 +40,9 @@ ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejD
|
|
|
36
40
|
ocf_data_sampler/select/select_time_slice.py,sha256=9M-yvDv9K77XfEys_OIR31_aVB56sNWk3BnCnkCgcPI,4725
|
|
37
41
|
ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
|
|
38
42
|
ocf_data_sampler/select/time_slice_for_dataset.py,sha256=P7cAARfDzjttGDvpKt2zuA4WkLoTmSXy_lBpI8RiA6k,4249
|
|
39
|
-
ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=
|
|
43
|
+
ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=JMfMQ6DxCWQiwm-Xwdy_b0gnGqZnORSR_SGrLM1QEe4,201
|
|
40
44
|
ocf_data_sampler/torch_datasets/datasets/pvnet_uk_regional.py,sha256=xxeX4Js9LQpydehi3BS7k9psqkYGzgJuM17uTYux40M,8742
|
|
41
|
-
ocf_data_sampler/torch_datasets/datasets/site.py,sha256=
|
|
45
|
+
ocf_data_sampler/torch_datasets/datasets/site.py,sha256=5T8nkTMUHHFidZRuFOunYeKAqNuyZ8V7sikBoBOBwwA,16033
|
|
42
46
|
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
|
|
43
47
|
ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
|
|
44
48
|
scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
|
|
@@ -62,12 +66,15 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
|
|
|
62
66
|
tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
|
|
63
67
|
tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
|
|
64
68
|
tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
|
|
69
|
+
tests/test_sample/test_base.py,sha256=ljtB38MmscTGN6OvUgclBceNnfx6m7AN8iHYDml9XW4,2189
|
|
70
|
+
tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
|
|
71
|
+
tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
|
|
65
72
|
tests/torch_datasets/conftest.py,sha256=eRCzHE7cxS4AoskExkCGFDBeqItktAYNAdkfpMoFCeE,629
|
|
66
73
|
tests/torch_datasets/test_merge_and_fill_utils.py,sha256=ueA0A7gZaWEgNdsU8p3CnKuvSnlleTUjEhSw2HUUROM,1229
|
|
67
74
|
tests/torch_datasets/test_pvnet_uk_regional.py,sha256=FCiFueeFqrsXe7gWguSjBz5ZeUrvyhGbGw81gaVvkHM,5087
|
|
68
|
-
tests/torch_datasets/test_site.py,sha256=
|
|
69
|
-
ocf_data_sampler-0.0.
|
|
70
|
-
ocf_data_sampler-0.0.
|
|
71
|
-
ocf_data_sampler-0.0.
|
|
72
|
-
ocf_data_sampler-0.0.
|
|
73
|
-
ocf_data_sampler-0.0.
|
|
75
|
+
tests/torch_datasets/test_site.py,sha256=J1ZDE5V5MRlq7EuZ1zUu-aFRGTDJIiO-ZZzkOXvDdWA,6757
|
|
76
|
+
ocf_data_sampler-0.0.54.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
77
|
+
ocf_data_sampler-0.0.54.dist-info/METADATA,sha256=o6EIqhbRzXKCy74lRYn4HS48DZyEnyLS1XgiFDh-F-g,12231
|
|
78
|
+
ocf_data_sampler-0.0.54.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
79
|
+
ocf_data_sampler-0.0.54.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
|
|
80
|
+
ocf_data_sampler-0.0.54.dist-info/RECORD,,
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base class testing - SampleBase
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from ocf_data_sampler.sample.base import SampleBase
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TestSample(SampleBase):
    """
    SampleBase for testing purposes
    Minimal implementations - abstract methods
    """

    # The name matches pytest's ``Test*`` pattern, but this is a helper with
    # an __init__, not a test class; opt out of collection to avoid a
    # collection warning.
    __test__ = False

    def __init__(self):
        """Start with an empty ``_data`` dict."""
        super().__init__()
        self._data = {}

    def plot(self, **kwargs):
        """Minimal plot implementation - does nothing."""
        return None

    def to_numpy(self) -> dict:
        """Return each stored value wrapped in ``np.array``."""
        return {key: np.array(value) for key, value in self._data.items()}

    def save(self, path):
        """Minimal save implementation - writes a fixed byte string."""
        path = Path(path)
        with open(path, 'wb') as f:
            f.write(b'test_data')

    @classmethod
    def load(cls, path):
        """Minimal load implementation - returns a fresh, empty instance."""
        # The file content is never read back; Path() only validates the type.
        path = Path(path)
        instance = cls()
        return instance
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def test_sample_base_initialisation():
    """A freshly constructed subclass exposes an empty ``_data`` dict."""
    instance = TestSample()

    has_container = hasattr(instance, '_data')
    assert has_container, "Sample should have _data attribute"
    assert instance._data == {}, "Sample should start with empty dict"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def test_sample_base_save_load(tmp_path):
    """Round-trip through the minimal save / load implementations."""
    target = tmp_path / 'test_sample.dat'

    original = TestSample()
    original._data['test_data'] = [1, 2, 3]
    original.save(target)
    assert target.exists()

    restored = TestSample.load(target)
    assert isinstance(restored, TestSample)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def test_sample_base_abstract_methods():
    """SampleBase itself must refuse instantiation (abstract methods unimplemented)."""
    expected_message = "Can't instantiate abstract class"

    with pytest.raises(TypeError, match=expected_message):
        SampleBase()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def test_sample_base_to_numpy():
    """to_numpy wraps every stored value, scalar or list, in a numpy array."""
    sample = TestSample()
    sample._data = {
        'int_data': 42,
        'list_data': [1, 2, 3],
    }
    numpy_data = sample.to_numpy()

    assert isinstance(numpy_data, dict)
    assert all(isinstance(value, np.ndarray) for value in numpy_data.values())
    assert np.array_equal(numpy_data['list_data'], np.array([1, 2, 3]))

    # The scalar entry becomes a 0-d array holding the same value.
    assert numpy_data['int_data'].ndim == 0
    assert numpy_data['int_data'] == 42
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Site class testing - SiteSample
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
import numpy as np
|
|
7
|
+
import xarray as xr
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from xarray import Dataset
|
|
12
|
+
|
|
13
|
+
from ocf_data_sampler.sample.site import SiteSample
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@pytest.fixture
def sample_data():
    """Build a synthetic Dataset with NWP (ukv), satellite and site variables."""

    # Time axes for each modality
    init_time = pd.Timestamp("2023-01-01 00:00")
    target_times = pd.date_range("2023-01-01 00:00", periods=4, freq="1h")
    sat_times = pd.date_range("2023-01-01 00:00", periods=7, freq="5min")
    site_times = pd.date_range("2023-01-01 00:00", periods=4, freq="30min")

    # NWP step: offset of each target time from the init time
    steps = [target - init_time for target in target_times]

    nwp_dims = [
        "nwp-ukv__target_time_utc", "nwp-ukv__channel",
        "nwp-ukv__y_osgb", "nwp-ukv__x_osgb",
    ]
    sat_dims = [
        "satellite__time_utc", "satellite__channel",
        "satellite__y_geostationary", "satellite__x_geostationary",
    ]

    data_vars = {
        'nwp-ukv': (nwp_dims, np.random.rand(4, 1, 2, 2)),
        'satellite': (sat_dims, np.random.rand(7, 1, 2, 2)),
        'site': (["site__time_utc"], np.random.rand(4)),
    }

    coords = {
        # NWP coords
        'nwp-ukv__target_time_utc': target_times,
        'nwp-ukv__channel': ['dswrf'],
        'nwp-ukv__y_osgb': [0, 1],
        'nwp-ukv__x_osgb': [0, 1],
        'nwp-ukv__init_time_utc': init_time,
        'nwp-ukv__step': ('nwp-ukv__target_time_utc', steps),

        # Sat coords
        'satellite__time_utc': sat_times,
        'satellite__channel': ['vis'],
        'satellite__y_geostationary': [0, 1],
        'satellite__x_geostationary': [0, 1],

        # Site coords
        'site__time_utc': site_times,
        'site__capacity_kwp': 1000.0,
        'site__site_id': 1,
        'site__longitude': -3.5,
        'site__latitude': 51.5,
    }

    # Site features, stored as coords along the site time axis
    site_feature_names = (
        'site__solar_azimuth',
        'site__solar_elevation',
        'site__date_cos',
        'site__date_sin',
        'site__time_cos',
        'site__time_sin',
    )
    for feature_name in site_feature_names:
        coords[feature_name] = ('site__time_utc', np.random.rand(4))

    return Dataset(data_vars=data_vars, coords=coords)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def test_site_sample_init():
    """A new SiteSample holds an empty dict in ``_data``."""
    container = SiteSample()._data

    assert isinstance(container, dict)
    assert len(container) == 0
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def test_site_sample_with_data(sample_data):
    """A sample populated with the fixture keeps its structure and shapes."""
    sample = SiteSample()
    sample._data = sample_data
    dataset = sample._data

    # Data structure
    assert isinstance(dataset, Dataset)

    # Dimension names
    expected_dims = {
        "satellite__x_geostationary",
        "site__time_utc",
        "nwp-ukv__target_time_utc",
        "nwp-ukv__x_osgb",
        "satellite__channel",
        "satellite__y_geostationary",
        "satellite__time_utc",
        "nwp-ukv__channel",
        "nwp-ukv__y_osgb",
    }
    assert set(dataset.dims) == expected_dims

    # Variable shapes
    expected_shapes = [
        ("satellite", (7, 1, 2, 2)),
        ("nwp-ukv", (4, 1, 2, 2)),
        ("site", (4,)),
    ]
    for variable, shape in expected_shapes:
        assert dataset[variable].values.shape == shape
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def test_save_load(tmp_path, sample_data):
    """Saving then loading yields a SiteSample with an identical Dataset."""
    filepath = tmp_path / "test_sample.nc"

    original = SiteSample()
    original._data = sample_data
    original.save(filepath)

    # The file must exist and be non-empty
    assert filepath.exists()
    assert filepath.stat().st_size > 0

    # The loaded object has the right types
    restored = SiteSample.load(filepath)
    assert isinstance(restored, SiteSample)
    assert isinstance(restored._data, Dataset)

    # Original and loaded data match exactly
    xr.testing.assert_identical(original._data, restored._data)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def test_invalid_save_format(sample_data):
    """save() rejects a path whose suffix is not .nc."""
    populated = SiteSample()
    populated._data = sample_data

    expected_message = "Only .nc format is supported"
    with pytest.raises(ValueError, match=expected_message):
        populated.save("invalid.txt")
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def test_invalid_load_format():
    """load() rejects a path whose suffix is not .nc."""
    expected_message = "Only .nc format is supported"

    with pytest.raises(ValueError, match=expected_message):
        SiteSample.load("invalid.txt")
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def test_invalid_data_type():
    """to_numpy() and save() both raise TypeError when _data is not a Dataset."""
    bad_sample = SiteSample()
    bad_sample._data = {"invalid": "data"}

    conversion_error = pytest.raises(TypeError, match="Data must be xarray Dataset")
    with conversion_error:
        bad_sample.to_numpy()

    saving_error = pytest.raises(TypeError, match="Data must be xarray Dataset for saving")
    with saving_error:
        bad_sample.save("test.nc")
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def test_to_numpy(sample_data):
    """to_numpy() returns a dict with a flat site array and nested NWP data."""
    sample = SiteSample()
    sample._data = sample_data
    converted = sample.to_numpy()

    # Top-level structure
    assert isinstance(converted, dict)
    for required_key in ('site', 'nwp'):
        assert required_key in converted

    # Site is a numpy array, not a dict
    site_values = converted['site']
    assert isinstance(site_values, np.ndarray)
    assert site_values.ndim == 1
    assert len(site_values) == 4
    assert np.all(site_values >= 0) and np.all(site_values <= 1)

    # NWP is nested by provider
    assert 'ukv' in converted['nwp']
    ukv = converted['nwp']['ukv']
    assert 'nwp' in ukv
    assert ukv['nwp'].shape == (4, 1, 2, 2)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def test_data_consistency(sample_data):
    """Shapes and value ranges survive the conversion to numpy."""
    sample = SiteSample()
    sample._data = sample_data
    converted = sample.to_numpy()

    # NWP component keeps its shape after conversion
    assert converted['nwp']['ukv']['nwp'].shape == (4, 1, 2, 2)
    assert 'site' in converted

    # Site data is a numpy array with values in [0, 1]
    site_values = converted['site']
    assert isinstance(site_values, np.ndarray)
    assert site_values.shape == (4,)
    assert np.all(site_values >= 0)
    assert np.all(site_values <= 1)
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""
|
|
2
|
+
UK Regional class testing - UKRegionalSample
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
import tempfile
|
|
9
|
+
|
|
10
|
+
from ocf_data_sampler.numpy_sample import (
|
|
11
|
+
GSPSampleKey,
|
|
12
|
+
SatelliteSampleKey,
|
|
13
|
+
NWPSampleKey
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from ocf_data_sampler.sample.uk_regional import UKRegionalSample
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Fixture definitions
|
|
20
|
+
@pytest.fixture
def pvnet_config_filename(tmp_path):
    """Write a minimal YAML config for testing and return its path as a string."""
    yaml_text = """
    input_data:
        gsp:
            zarr_path: ""
            time_resolution_minutes: 30
            interval_start_minutes: -180
            interval_end_minutes: 0
        nwp:
            ukv:
                zarr_path: ""
                image_size_pixels_height: 64
                image_size_pixels_width: 64
                time_resolution_minutes: 60
                interval_start_minutes: -180
                interval_end_minutes: 0
                channels: ["t", "dswrf"]
                provider: "ukv"
        satellite:
            zarr_path: ""
            image_size_pixels_height: 64
            image_size_pixels_width: 64
            time_resolution_minutes: 30
            interval_start_minutes: -180
            interval_end_minutes: 0
            channels: ["HRV"]
    """
    config_path = tmp_path / "test_config.yaml"
    config_path.write_text(yaml_text)
    return str(config_path)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def create_test_data():
    """Generate a synthetic sample dictionary with random values."""

    # One NWP source: field values plus spatial coordinates
    ukv_sample = {
        'nwp': np.random.rand(4, 1, 2, 2),
        'x': np.array([1, 2]),
        'y': np.array([1, 2]),
        NWPSampleKey.channel_names: ['test_channel'],
    }

    synthetic = {'nwp': {'ukv': ukv_sample}}
    synthetic[GSPSampleKey.gsp] = np.random.rand(7)
    synthetic[SatelliteSampleKey.satellite_actual] = np.random.rand(7, 1, 2, 2)
    synthetic[GSPSampleKey.solar_azimuth] = np.random.rand(7)
    synthetic[GSPSampleKey.solar_elevation] = np.random.rand(7)
    return synthetic
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# UKRegionalSample testing
|
|
77
|
+
def test_sample_init():
    """A new UKRegionalSample exposes an empty dict in ``_data``."""
    fresh = UKRegionalSample()

    assert hasattr(fresh, '_data'), "Sample should have _data attribute"
    container = fresh._data
    assert isinstance(container, dict)
    assert len(container) == 0
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def test_sample_save_load():
    """Saving then loading a sample preserves its keys, shapes and GSP values."""
    from pathlib import Path

    sample = UKRegionalSample()
    sample._data = create_test_data()

    # Use a path inside a temporary directory rather than reopening an open
    # NamedTemporaryFile by name, which does not work on every platform.
    with tempfile.TemporaryDirectory() as tmp_dir:
        sample_path = Path(tmp_dir) / "sample.pt"
        sample.save(sample_path)
        loaded = UKRegionalSample.load(sample_path)

    assert set(loaded._data.keys()) == set(sample._data.keys())
    assert isinstance(loaded._data['nwp'], dict)
    assert 'ukv' in loaded._data['nwp']

    assert loaded._data[GSPSampleKey.gsp].shape == (7,)
    assert loaded._data[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
    assert loaded._data[GSPSampleKey.solar_azimuth].shape == (7,)
    assert loaded._data[GSPSampleKey.solar_elevation].shape == (7,)

    np.testing.assert_array_almost_equal(
        loaded._data[GSPSampleKey.gsp],
        sample._data[GSPSampleKey.gsp],
    )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def test_save_unsupported_format():
    """save() raises ValueError for a suffix other than .pt."""
    populated = UKRegionalSample()
    populated._data = create_test_data()

    expected_message = "Only .pt format is supported"
    with tempfile.NamedTemporaryFile(suffix='.npz') as tmp_file:
        wrong_path = tmp_file.name
        with pytest.raises(ValueError, match=expected_message):
            populated.save(wrong_path)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def test_load_unsupported_format():
    """load() raises ValueError for a suffix other than .pt."""
    expected_message = "Only .pt format is supported"

    with tempfile.NamedTemporaryFile(suffix='.npz') as tmp_file:
        wrong_path = tmp_file.name
        with pytest.raises(ValueError, match=expected_message):
            UKRegionalSample.load(wrong_path)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def test_load_corrupted_file():
    """load() raises when the .pt file does not contain a valid payload."""
    from pathlib import Path

    # Write to a path inside a temporary directory rather than reopening an
    # open NamedTemporaryFile by name, which does not work on every platform.
    with tempfile.TemporaryDirectory() as tmp_dir:
        corrupted_path = Path(tmp_dir) / "corrupted.pt"
        corrupted_path.write_bytes(b'corrupted data')

        # The exact exception type comes from torch's loader, so it is
        # deliberately left broad here.
        with pytest.raises(Exception):
            UKRegionalSample.load(corrupted_path)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def test_to_numpy():
    """to_numpy() hands back the stored dictionary with its arrays intact."""
    ukv_sample = {
        'nwp': np.random.rand(4, 1, 2, 2),
        'x': np.array([1, 2]),
        'y': np.array([1, 2]),
    }
    payload = {
        'nwp': {'ukv': ukv_sample},
        GSPSampleKey.gsp: np.random.rand(7),
        SatelliteSampleKey.satellite_actual: np.random.rand(7, 1, 2, 2),
        GSPSampleKey.solar_azimuth: np.random.rand(7),
        GSPSampleKey.solar_elevation: np.random.rand(7),
    }

    sample = UKRegionalSample()
    sample._data = payload
    converted = sample.to_numpy()

    # Returned data matches the stored data
    assert converted == sample._data
    assert len(converted) == len(sample._data)

    # Specific keys and types
    assert 'nwp' in converted
    assert isinstance(converted['nwp']['ukv']['nwp'], np.ndarray)
    assert converted[GSPSampleKey.gsp].shape == (7,)
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import pandas as pd
|
|
2
2
|
import numpy as np
|
|
3
|
-
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset, convert_from_dataset_to_dict_datasets
|
|
3
|
+
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset, convert_from_dataset_to_dict_datasets, coarsen_data
|
|
4
4
|
from xarray import Dataset, DataArray
|
|
5
|
+
import xarray as xr
|
|
5
6
|
|
|
6
7
|
from torch.utils.data import DataLoader
|
|
7
8
|
|
|
@@ -43,7 +44,6 @@ def test_site(site_config_filename):
|
|
|
43
44
|
|
|
44
45
|
expected_data_vars = {"nwp-ukv", "satellite", "site"}
|
|
45
46
|
|
|
46
|
-
import xarray as xr
|
|
47
47
|
|
|
48
48
|
sample.to_netcdf("sample.nc")
|
|
49
49
|
sample = xr.open_dataset("sample.nc")
|
|
@@ -198,3 +198,18 @@ def test_process_and_combine_site_sample_dict(site_config_filename):
|
|
|
198
198
|
assert nwp_result.shape == (4, 1, 2, 2), f"Unexpected shape for nwp-ukv : {nwp_result.shape}"
|
|
199
199
|
site_result = result["site"]
|
|
200
200
|
assert site_result.shape == (197,), f"Unexpected shape for site: {site_result.shape}"
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def test_potentially_coarsen(ds_nwp_ecmwf):
    """Test coarsen_data with ECMWF_UK data at several target resolutions."""
    nwp_data = ds_nwp_ecmwf

    # Initial trailing (spatial) shape before any coarsening
    assert nwp_data.ECMWF_UK.shape[3:] == (15, 12)

    # (target resolution in degrees, expected trailing shape)
    cases = [
        (2, (8, 6)),     # coarsen to every 2 degrees
        (3, (5, 4)),     # coarsen to every 3 degrees
        (1, (15, 12)),   # no coarsening, same shape
    ]
    for target_deg, expected_shape in cases:
        data = coarsen_data(xr_data=nwp_data, coarsen_to_deg=target_deg)
        assert data.ECMWF_UK.shape[3:] == expected_shape
|
|
File without changes
|
|
File without changes
|
|
File without changes
|