ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/model.py +25 -23
- ocf_data_sampler/load/satellite.py +21 -29
- ocf_data_sampler/load/site.py +1 -1
- ocf_data_sampler/numpy_sample/gsp.py +6 -2
- ocf_data_sampler/numpy_sample/nwp.py +7 -13
- ocf_data_sampler/numpy_sample/satellite.py +11 -8
- ocf_data_sampler/numpy_sample/site.py +6 -2
- ocf_data_sampler/numpy_sample/sun_position.py +9 -10
- ocf_data_sampler/sample/__init__.py +0 -7
- ocf_data_sampler/sample/base.py +16 -35
- ocf_data_sampler/sample/site.py +28 -65
- ocf_data_sampler/sample/uk_regional.py +52 -97
- ocf_data_sampler/select/dropout.py +38 -25
- ocf_data_sampler/select/fill_time_periods.py +3 -1
- ocf_data_sampler/select/find_contiguous_time_periods.py +0 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.11.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.11.dist-info}/RECORD +27 -27
- tests/config/test_config.py +3 -3
- tests/conftest.py +33 -0
- tests/numpy_sample/test_nwp.py +3 -42
- tests/select/test_dropout.py +7 -13
- tests/test_sample/test_site_sample.py +5 -35
- tests/test_sample/test_uk_regional_sample.py +8 -35
- tests/torch_datasets/test_pvnet_uk.py +6 -19
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.11.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.11.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.11.dist-info}/top_level.txt +0 -0
|
@@ -1,120 +1,75 @@
|
|
|
1
|
-
"""
|
|
2
|
-
PVNet - UK Regional sample / dataset implementation
|
|
3
|
-
"""
|
|
1
|
+
"""PVNet UK Regional sample implementation for dataset handling and visualisation"""
|
|
4
2
|
|
|
5
|
-
|
|
6
|
-
import pandas as pd
|
|
7
|
-
import torch
|
|
8
|
-
import logging
|
|
3
|
+
from typing_extensions import override
|
|
9
4
|
|
|
10
|
-
|
|
11
|
-
from pathlib import Path
|
|
5
|
+
import torch
|
|
12
6
|
|
|
7
|
+
from ocf_data_sampler.sample.base import SampleBase, NumpySample
|
|
13
8
|
from ocf_data_sampler.numpy_sample import (
|
|
14
9
|
NWPSampleKey,
|
|
15
10
|
GSPSampleKey,
|
|
16
11
|
SatelliteSampleKey
|
|
17
12
|
)
|
|
18
13
|
|
|
19
|
-
from ocf_data_sampler.sample.base import SampleBase
|
|
20
|
-
|
|
21
|
-
try:
|
|
22
|
-
import matplotlib.pyplot as plt
|
|
23
|
-
MATPLOTLIB_AVAILABLE = True
|
|
24
|
-
except ImportError:
|
|
25
|
-
MATPLOTLIB_AVAILABLE = False
|
|
26
|
-
plt = None
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
logger = logging.getLogger(__name__)
|
|
30
|
-
|
|
31
14
|
|
|
32
15
|
class UKRegionalSample(SampleBase):
|
|
33
|
-
"""
|
|
16
|
+
"""Handles UK Regional PVNet data operations"""
|
|
34
17
|
|
|
35
|
-
def __init__(self):
|
|
36
|
-
|
|
37
|
-
super().__init__()
|
|
38
|
-
self._data = {}
|
|
18
|
+
def __init__(self, data: NumpySample):
|
|
19
|
+
self._data = data
|
|
39
20
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
logger.debug("Converting sample data to numpy format")
|
|
21
|
+
@override
|
|
22
|
+
def to_numpy(self) -> NumpySample:
|
|
43
23
|
return self._data
|
|
44
24
|
|
|
45
|
-
def save(self, path:
|
|
46
|
-
"""
|
|
47
|
-
logger.debug(f"Saving UKRegionalSample to {path}")
|
|
48
|
-
path = Path(path)
|
|
49
|
-
|
|
50
|
-
if path.suffix != '.pt':
|
|
51
|
-
logger.error(f"Invalid file format: {path.suffix}")
|
|
52
|
-
raise ValueError(f"Only .pt format is supported: {path.suffix}")
|
|
25
|
+
def save(self, path: str) -> None:
|
|
26
|
+
"""Save PVNet sample as pickle format using torch.save
|
|
53
27
|
|
|
28
|
+
Args:
|
|
29
|
+
path: Path to save the sample data to
|
|
30
|
+
"""
|
|
54
31
|
torch.save(self._data, path)
|
|
55
|
-
logger.debug(f"Successfully saved UKRegionalSample to {path}")
|
|
56
32
|
|
|
57
33
|
@classmethod
|
|
58
|
-
def load(cls, path:
|
|
59
|
-
"""
|
|
60
|
-
logger.debug(f"Attempting to load UKRegionalSample from {path}")
|
|
61
|
-
path = Path(path)
|
|
34
|
+
def load(cls, path: str) -> 'UKRegionalSample':
|
|
35
|
+
"""Load PVNet sample data from .pt format
|
|
62
36
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
instance = cls()
|
|
37
|
+
Args:
|
|
38
|
+
path: Path to load the sample data from
|
|
39
|
+
"""
|
|
68
40
|
# TODO: We should move away from using torch.load(..., weights_only=False)
|
|
69
|
-
|
|
70
|
-
instance._data = torch.load(path, weights_only=False)
|
|
71
|
-
logger.debug(f"Successfully loaded UKRegionalSample from {path}")
|
|
72
|
-
return instance
|
|
73
|
-
|
|
74
|
-
def plot(self, **kwargs) -> None:
|
|
75
|
-
""" Sample visualisation definition """
|
|
76
|
-
logger.debug("Creating UKRegionalSample visualisation")
|
|
77
|
-
|
|
78
|
-
if not MATPLOTLIB_AVAILABLE:
|
|
79
|
-
raise ImportError(
|
|
80
|
-
"Matplotlib required for plotting"
|
|
81
|
-
"Install via 'ocf_data_sampler[plot]'"
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
try:
|
|
85
|
-
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
|
|
86
|
-
|
|
87
|
-
if NWPSampleKey.nwp in self._data:
|
|
88
|
-
logger.debug("Plotting NWP data")
|
|
89
|
-
first_nwp = list(self._data[NWPSampleKey.nwp].values())[0]
|
|
90
|
-
if 'nwp' in first_nwp:
|
|
91
|
-
axes[0, 1].imshow(first_nwp['nwp'][0])
|
|
92
|
-
axes[0, 1].set_title('NWP (First Channel)')
|
|
93
|
-
if NWPSampleKey.channel_names in first_nwp:
|
|
94
|
-
channel_names = first_nwp[NWPSampleKey.channel_names]
|
|
95
|
-
if len(channel_names) > 0:
|
|
96
|
-
axes[0, 1].set_title(f'NWP: {channel_names[0]}')
|
|
41
|
+
return cls(torch.load(path, weights_only=False))
|
|
97
42
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
axes[0, 0].set_title('GSP Generation')
|
|
102
|
-
|
|
103
|
-
if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
|
|
104
|
-
logger.debug("Plotting solar position data")
|
|
105
|
-
axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
|
|
106
|
-
axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
|
|
107
|
-
axes[1, 1].set_title('Solar Position')
|
|
108
|
-
axes[1, 1].legend()
|
|
43
|
+
def plot(self) -> None:
|
|
44
|
+
"""Creates visualisations for NWP, GSP, solar position, and satellite data"""
|
|
45
|
+
from matplotlib import pyplot as plt
|
|
109
46
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
47
|
+
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
|
|
48
|
+
|
|
49
|
+
if NWPSampleKey.nwp in self._data:
|
|
50
|
+
first_nwp = list(self._data[NWPSampleKey.nwp].values())[0]
|
|
51
|
+
if 'nwp' in first_nwp:
|
|
52
|
+
axes[0, 1].imshow(first_nwp['nwp'][0])
|
|
53
|
+
title = 'NWP (First Channel)'
|
|
54
|
+
if NWPSampleKey.channel_names in first_nwp:
|
|
55
|
+
channel_names = first_nwp[NWPSampleKey.channel_names]
|
|
56
|
+
if channel_names:
|
|
57
|
+
title = f'NWP: {channel_names[0]}'
|
|
58
|
+
axes[0, 1].set_title(title)
|
|
59
|
+
|
|
60
|
+
if GSPSampleKey.gsp in self._data:
|
|
61
|
+
axes[0, 0].plot(self._data[GSPSampleKey.gsp])
|
|
62
|
+
axes[0, 0].set_title('GSP Generation')
|
|
63
|
+
|
|
64
|
+
if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
|
|
65
|
+
axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
|
|
66
|
+
axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
|
|
67
|
+
axes[1, 1].set_title('Solar Position')
|
|
68
|
+
axes[1, 1].legend()
|
|
69
|
+
|
|
70
|
+
if SatelliteSampleKey.satellite_actual in self._data:
|
|
71
|
+
axes[1, 0].imshow(self._data[SatelliteSampleKey.satellite_actual])
|
|
72
|
+
axes[1, 0].set_title('Satellite Data')
|
|
73
|
+
|
|
74
|
+
plt.tight_layout()
|
|
75
|
+
plt.show()
|
|
@@ -1,39 +1,52 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Functions for simulating dropout in time series data
|
|
2
|
+
|
|
3
|
+
This is used for the following types of data: GSP, Satellite and Site
|
|
4
|
+
This is not used for NWP
|
|
5
|
+
"""
|
|
2
6
|
import numpy as np
|
|
3
7
|
import pandas as pd
|
|
4
8
|
import xarray as xr
|
|
5
9
|
|
|
6
10
|
|
|
7
11
|
def draw_dropout_time(
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
12
|
+
t0: pd.Timestamp,
|
|
13
|
+
dropout_timedeltas: list[pd.Timedelta],
|
|
14
|
+
dropout_frac: float,
|
|
15
|
+
) -> pd.Timestamp:
|
|
16
|
+
"""Randomly pick a dropout time from a list of timedeltas
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
t0: The forecast init-time
|
|
20
|
+
dropout_timedeltas: List of timedeltas relative to t0 to pick from
|
|
21
|
+
dropout_frac: Probability that dropout will be applied. This should be between 0 and 1
|
|
22
|
+
inclusive
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
if dropout_frac>0:
|
|
26
|
+
assert len(dropout_timedeltas) > 0, "To apply dropout dropout_timedeltas must be provided"
|
|
27
|
+
|
|
28
|
+
for t in dropout_timedeltas:
|
|
29
|
+
assert t <= pd.Timedelta("0min"), "Dropout timedeltas must be negative"
|
|
30
|
+
|
|
18
31
|
assert 0 <= dropout_frac <= 1
|
|
19
32
|
|
|
20
|
-
if (dropout_timedeltas
|
|
21
|
-
dropout_time =
|
|
33
|
+
if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
|
|
34
|
+
dropout_time = t0
|
|
22
35
|
else:
|
|
23
|
-
|
|
24
|
-
dt = np.random.choice(dropout_timedeltas)
|
|
25
|
-
dropout_time = t0_datetime_utc + dt
|
|
36
|
+
dropout_time = t0 + np.random.choice(dropout_timedeltas)
|
|
26
37
|
|
|
27
38
|
return dropout_time
|
|
28
39
|
|
|
29
40
|
|
|
30
41
|
def apply_dropout_time(
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
42
|
+
ds: xr.DataArray,
|
|
43
|
+
dropout_time: pd.Timestamp,
|
|
44
|
+
) -> xr.DataArray:
|
|
45
|
+
"""Apply dropout time to the data
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
ds: Xarray DataArray with 'time_utc' coordinate
|
|
49
|
+
dropout_time: Time after which data is set to NaN
|
|
50
|
+
"""
|
|
51
|
+
# This replaces the times after the dropout with NaNs
|
|
52
|
+
return ds.where(ds.time_utc <= dropout_time)
|
|
@@ -1,10 +1,12 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Fill time periods between start and end dates at specified frequency"""
|
|
2
2
|
|
|
3
3
|
import pandas as pd
|
|
4
4
|
import numpy as np
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
|
|
8
|
+
"""Generate DatetimeIndex for all timestamps between start and end dates"""
|
|
9
|
+
|
|
8
10
|
start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
|
|
9
11
|
end_dts = pd.to_datetime(time_periods["end_dt"].values)
|
|
10
12
|
date_ranges = [pd.date_range(start_dt, end_dt, freq=freq) for start_dt, end_dt in zip(start_dts, end_dts)]
|
|
@@ -3,14 +3,14 @@ ocf_data_sampler/constants.py,sha256=0HYNmqwBaHVTAEEx9qzk6WD9YInh0gSKLeI3pyq7aNs
|
|
|
3
3
|
ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
|
|
4
4
|
ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
|
|
5
5
|
ocf_data_sampler/config/load.py,sha256=sKCKmhkkeFvvkNL5xmnFvdAulaCtV4-rigPsFvVDPDc,634
|
|
6
|
-
ocf_data_sampler/config/model.py,sha256=
|
|
6
|
+
ocf_data_sampler/config/model.py,sha256=8PO-23uVy_JjWOJKgaZWdNMehQsAI-Jn8t0lcmBycwg,6992
|
|
7
7
|
ocf_data_sampler/config/save.py,sha256=OqCPT3e0d7vMI2g2iRzmifPD7GscDkFQztU_qE5I0JY,1066
|
|
8
8
|
ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
|
|
9
9
|
ocf_data_sampler/load/__init__.py,sha256=T5Zj1PGt0aiiNEN7Ra1Ac-cBsNKhphmmHy_8g7XU_w0,219
|
|
10
10
|
ocf_data_sampler/load/gsp.py,sha256=uRxEORH7J99JAJ-D38nm0iJFOQh7dkm_NCXcpbYkyvo,857
|
|
11
11
|
ocf_data_sampler/load/load_dataset.py,sha256=PHUGSm4hFHfS9nfIP2KjHHCp325O4br7uGBdQH_DP7g,1603
|
|
12
|
-
ocf_data_sampler/load/satellite.py,sha256=
|
|
13
|
-
ocf_data_sampler/load/site.py,sha256=
|
|
12
|
+
ocf_data_sampler/load/satellite.py,sha256=SEQZ9oPe-asEeZeEMDkB1xWK5hErhWMagxohFcBl6KI,2294
|
|
13
|
+
ocf_data_sampler/load/site.py,sha256=hMdoF6sn2PcSBfF2soj7nuQoK9SItaxDXco5nk2n-44,1232
|
|
14
14
|
ocf_data_sampler/load/utils.py,sha256=sAEkPMS9LXVCrc5pANQo97zaoEItVg9hoNj2ZWfx_Ug,1405
|
|
15
15
|
ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
|
|
16
16
|
ocf_data_sampler/load/nwp/nwp.py,sha256=Jyq1dE7DN0iSe6iSEGA76uu9LoeJz9FzfEUkq6ZZExQ,565
|
|
@@ -21,19 +21,19 @@ ocf_data_sampler/load/nwp/providers/utils.py,sha256=MFOZ5ZXLu3-SxYVJExdlo30b3y3s
|
|
|
21
21
|
ocf_data_sampler/numpy_sample/__init__.py,sha256=nY5C6CcuxiWZ_jrXRzWtN7WyKXhJImSiVTIG6Rz4B_4,401
|
|
22
22
|
ocf_data_sampler/numpy_sample/collate.py,sha256=oX5axq30sCsSquhNbmWAVMjM54HT1v3MCMopYHcO5Q0,1950
|
|
23
23
|
ocf_data_sampler/numpy_sample/datetime_features.py,sha256=D0RajbnBjg15qjYk16h2H0XO4wH3fw-x0--4VC2nq0s,1204
|
|
24
|
-
ocf_data_sampler/numpy_sample/gsp.py,sha256=
|
|
25
|
-
ocf_data_sampler/numpy_sample/nwp.py,sha256=
|
|
26
|
-
ocf_data_sampler/numpy_sample/satellite.py,sha256=
|
|
27
|
-
ocf_data_sampler/numpy_sample/site.py,sha256=
|
|
28
|
-
ocf_data_sampler/numpy_sample/sun_position.py,sha256=
|
|
29
|
-
ocf_data_sampler/sample/__init__.py,sha256=
|
|
30
|
-
ocf_data_sampler/sample/base.py,sha256=
|
|
31
|
-
ocf_data_sampler/sample/site.py,sha256=
|
|
32
|
-
ocf_data_sampler/sample/uk_regional.py,sha256=
|
|
24
|
+
ocf_data_sampler/numpy_sample/gsp.py,sha256=uBquCFCoWuhJKY8sXpgsTCUDWUuLuv1XeixtFnFw6KU,1115
|
|
25
|
+
ocf_data_sampler/numpy_sample/nwp.py,sha256=Tiba-es23XeyMoEPgZUpLT6EnJCGU9A_1MdY6qkE7bM,1015
|
|
26
|
+
ocf_data_sampler/numpy_sample/satellite.py,sha256=RdXMdGGXysUx-AdL9T33yFOlxprtIdPNBKKX99-mhpY,991
|
|
27
|
+
ocf_data_sampler/numpy_sample/site.py,sha256=TvoEU85fmjYW8pD9UZOyUUACjimdQYxEzulQXunRO6Q,1425
|
|
28
|
+
ocf_data_sampler/numpy_sample/sun_position.py,sha256=ithM--eztAhiIQ1g52tlxgj-tMKbsJzx8mk6CgV2tzk,1613
|
|
29
|
+
ocf_data_sampler/sample/__init__.py,sha256=zdS73NTnxFX_j8uh9tT-IXiURB6635wbneM1koWYV1o,169
|
|
30
|
+
ocf_data_sampler/sample/base.py,sha256=IH3HbfqEUwjHmq-h2eJYLd8Jk-0ZcOylnehMyCPMV38,2223
|
|
31
|
+
ocf_data_sampler/sample/site.py,sha256=ONf2Yz5zi8Ombd_znA4T7NXbO01F76kQsBZv6rfnC74,1343
|
|
32
|
+
ocf_data_sampler/sample/uk_regional.py,sha256=KhJ5Ik1pZRp7PgIJjGIrE4i7SQnIdVjUbBHnfn-7ghg,2649
|
|
33
33
|
ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
|
|
34
|
-
ocf_data_sampler/select/dropout.py,sha256=
|
|
35
|
-
ocf_data_sampler/select/fill_time_periods.py,sha256=
|
|
36
|
-
ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=
|
|
34
|
+
ocf_data_sampler/select/dropout.py,sha256=Pgov9P7rQMkSdqluG_hwm8loGyYNFOg-3PJUBLN_kjU,1526
|
|
35
|
+
ocf_data_sampler/select/fill_time_periods.py,sha256=EIcXG-77aQVOAYNwbDBEv6SGf6DO2p1WMEf96iW4MEM,596
|
|
36
|
+
ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=IwPQwvgu4cOiAZ5Gbjflv3fnQCcs0EVK0g4V6yqqSgw,11129
|
|
37
37
|
ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
|
|
38
38
|
ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
|
|
39
39
|
ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
|
|
@@ -48,8 +48,8 @@ ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW
|
|
|
48
48
|
ocf_data_sampler/torch_datasets/utils/validate_channels.py,sha256=u2EpiFAKAOHpmvINhOUJCT8Vbc-cle6qJ3YNVse4yLs,2884
|
|
49
49
|
scripts/refactor_site.py,sha256=xaJGxt2_WObIPrPAnRiOMMB68r-5Q51jWRx409AcscM,1747
|
|
50
50
|
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
51
|
-
tests/conftest.py,sha256=
|
|
52
|
-
tests/config/test_config.py,sha256=
|
|
51
|
+
tests/conftest.py,sha256=k7nM3u2YJmkMupN4SIbJP3BRoxNR1dpIoo2fPFf0abg,8588
|
|
52
|
+
tests/config/test_config.py,sha256=CzYVhAUpgT4lvQdIddtVxtJeMqYL_TJolfeIwaaohq4,3969
|
|
53
53
|
tests/config/test_load.py,sha256=8nui2UsgK_eufWGD74yXvf-6eY_SxBFKhDmGYUtRQxw,260
|
|
54
54
|
tests/config/test_save.py,sha256=BxSd2S50-bRPIXP_4iX0B6Wt7pRFJnUbLYtzfLaqlAs,915
|
|
55
55
|
tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
|
|
@@ -59,24 +59,24 @@ tests/load/test_load_sites.py,sha256=6V-U3_EtBklkV7w-hOoR4nba3dSaZ_cnjuRWFs8kYVU
|
|
|
59
59
|
tests/numpy_sample/test_collate.py,sha256=RqHCD5_LTRpe4r6kqC_2TKhmhM_IHYM0ZtFUvSjDqcM,654
|
|
60
60
|
tests/numpy_sample/test_datetime_features.py,sha256=iR9WdBLj1nIBNqoaTFE9rkUaH1eKFJSNb96nwiEaQH0,1449
|
|
61
61
|
tests/numpy_sample/test_gsp.py,sha256=FLlq4SlJ-9cSRAepf4_ksA6PsUVKegnKEAc5pUojCJ0,1458
|
|
62
|
-
tests/numpy_sample/test_nwp.py,sha256=
|
|
62
|
+
tests/numpy_sample/test_nwp.py,sha256=Lnd-PMa6gI-fSIJkSZ554QiHFfnwxeXZxLg-rpuBv1U,442
|
|
63
63
|
tests/numpy_sample/test_satellite.py,sha256=cCqtn5See-uSNfh89COGTUQNuFm6sIZ8QmBVHsuUeRI,1189
|
|
64
64
|
tests/numpy_sample/test_sun_position.py,sha256=_ENYzsNBVPdNXf--FI-UUFqw2u5w7_zqw6LcENU2uZM,2504
|
|
65
|
-
tests/select/test_dropout.py,sha256=
|
|
65
|
+
tests/select/test_dropout.py,sha256=aQuSSqZF9RxBjN9-ogkQ8O-_zktAM30CrT1Lz7j1hMg,2222
|
|
66
66
|
tests/select/test_fill_time_periods.py,sha256=o59f2YRe5b0vJrG3B0aYZkYeHnpNk4s6EJxdXZluNQg,907
|
|
67
67
|
tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM3agOhsvZYx8inXtUn1PM,5976
|
|
68
68
|
tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
|
|
69
69
|
tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
|
|
70
70
|
tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
|
|
71
71
|
tests/test_sample/test_base.py,sha256=sD9NZghYQWbkAcQP9YXypWZowqYkO3xeNMH-_mEoD5I,4833
|
|
72
|
-
tests/test_sample/test_site_sample.py,sha256=
|
|
73
|
-
tests/test_sample/test_uk_regional_sample.py,sha256=
|
|
72
|
+
tests/test_sample/test_site_sample.py,sha256=8HNenhIWYouCQu4y389PDQGokSPI5jQ4lS4CG-eA1Y8,5382
|
|
73
|
+
tests/test_sample/test_uk_regional_sample.py,sha256=MFibX9-M8mFK7vwMPu58gAG2VoY6y7w7chW5BlZclwk,3962
|
|
74
74
|
tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
|
|
75
|
-
tests/torch_datasets/test_pvnet_uk.py,sha256=
|
|
75
|
+
tests/torch_datasets/test_pvnet_uk.py,sha256=hgD_IDa4D8cgc4cgK1UqKYkT6sFlrTMAvgVn_iwD5_4,5086
|
|
76
76
|
tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
|
|
77
77
|
tests/torch_datasets/test_validate_channels_utils.py,sha256=Rzdweu98j1of45jCOUrSiBtyPlf-dDaCceulf0H7ml8,2921
|
|
78
|
-
ocf_data_sampler-0.1.
|
|
79
|
-
ocf_data_sampler-0.1.
|
|
80
|
-
ocf_data_sampler-0.1.
|
|
81
|
-
ocf_data_sampler-0.1.
|
|
82
|
-
ocf_data_sampler-0.1.
|
|
78
|
+
ocf_data_sampler-0.1.11.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
79
|
+
ocf_data_sampler-0.1.11.dist-info/METADATA,sha256=d8wctSlRyDbP1_yYHFvIGQgEC8DmOkM8h-ITI4XFuPw,12174
|
|
80
|
+
ocf_data_sampler-0.1.11.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
81
|
+
ocf_data_sampler-0.1.11.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
|
|
82
|
+
ocf_data_sampler-0.1.11.dist-info/RECORD,,
|
tests/config/test_config.py
CHANGED
|
@@ -30,7 +30,7 @@ def test_incorrect_interval_start_minutes(test_config_filename):
|
|
|
30
30
|
configuration.input_data.nwp['ukv'].interval_start_minutes = -1111
|
|
31
31
|
with pytest.raises(
|
|
32
32
|
ValueError,
|
|
33
|
-
match="interval_start_minutes
|
|
33
|
+
match="interval_start_minutes.*must be divisible.*time_resolution_minutes.*"
|
|
34
34
|
):
|
|
35
35
|
_ = Configuration(**configuration.model_dump())
|
|
36
36
|
|
|
@@ -45,7 +45,7 @@ def test_incorrect_interval_end_minutes(test_config_filename):
|
|
|
45
45
|
configuration.input_data.nwp['ukv'].interval_end_minutes = 1111
|
|
46
46
|
with pytest.raises(
|
|
47
47
|
ValueError,
|
|
48
|
-
match="interval_end_minutes
|
|
48
|
+
match="interval_end_minutes.*must be divisible.*time_resolution_minutes.*"
|
|
49
49
|
):
|
|
50
50
|
_ = Configuration(**configuration.model_dump())
|
|
51
51
|
|
|
@@ -103,7 +103,7 @@ def test_inconsistent_dropout_use(test_config_filename):
|
|
|
103
103
|
|
|
104
104
|
configuration = load_yaml_configuration(test_config_filename)
|
|
105
105
|
configuration.input_data.satellite.dropout_fraction= 1.0
|
|
106
|
-
configuration.input_data.satellite.dropout_timedeltas_minutes =
|
|
106
|
+
configuration.input_data.satellite.dropout_timedeltas_minutes = []
|
|
107
107
|
|
|
108
108
|
with pytest.raises(ValueError, match="To dropout fraction > 0 requires a list of dropout timedeltas"):
|
|
109
109
|
_ = Configuration(**configuration.model_dump())
|
tests/conftest.py
CHANGED
|
@@ -130,6 +130,39 @@ def nwp_ukv_zarr_path(session_tmp_path, ds_nwp_ukv):
|
|
|
130
130
|
yield zarr_path
|
|
131
131
|
|
|
132
132
|
|
|
133
|
+
@pytest.fixture()
|
|
134
|
+
def ds_nwp_ukv_time_sliced():
|
|
135
|
+
|
|
136
|
+
t0 = pd.to_datetime("2024-01-02 00:00")
|
|
137
|
+
|
|
138
|
+
x = np.arange(-100, 100, 10)
|
|
139
|
+
y = np.arange(-100, 100, 10)
|
|
140
|
+
steps = pd.timedelta_range("0h", "8h", freq="1h")
|
|
141
|
+
target_times = t0 + steps
|
|
142
|
+
|
|
143
|
+
channels = ["t", "dswrf"]
|
|
144
|
+
init_times = pd.to_datetime([t0]*len(steps))
|
|
145
|
+
|
|
146
|
+
# Create dummy time-sliced NWP data
|
|
147
|
+
da_nwp = xr.DataArray(
|
|
148
|
+
np.random.normal(size=(len(target_times), len(channels), len(x), len(y))),
|
|
149
|
+
coords=dict(
|
|
150
|
+
target_time_utc=(["target_time_utc"], target_times),
|
|
151
|
+
channel=(["channel"], channels),
|
|
152
|
+
x_osgb=(["x_osgb"], x),
|
|
153
|
+
y_osgb=(["y_osgb"], y),
|
|
154
|
+
)
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
# Add extra non-coordinate dimensions
|
|
158
|
+
da_nwp = da_nwp.assign_coords(
|
|
159
|
+
init_time_utc=("target_time_utc", init_times),
|
|
160
|
+
step=("target_time_utc", steps),
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
return da_nwp
|
|
164
|
+
|
|
165
|
+
|
|
133
166
|
@pytest.fixture(scope="session")
|
|
134
167
|
def ds_nwp_ecmwf():
|
|
135
168
|
init_times = pd.date_range(start="2023-01-01 00:00", freq="6h", periods=24 * 7)
|
tests/numpy_sample/test_nwp.py
CHANGED
|
@@ -1,52 +1,13 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import pandas as pd
|
|
3
|
-
import xarray as xr
|
|
4
|
-
|
|
5
|
-
import pytest
|
|
6
|
-
|
|
7
1
|
from ocf_data_sampler.numpy_sample import convert_nwp_to_numpy_sample, NWPSampleKey
|
|
8
2
|
|
|
9
|
-
@pytest.fixture(scope="module")
|
|
10
|
-
def da_nwp_like():
|
|
11
|
-
"""Create dummy data which looks like time-sliced NWP data"""
|
|
12
|
-
|
|
13
|
-
t0 = pd.to_datetime("2024-01-02 00:00")
|
|
14
|
-
|
|
15
|
-
x = np.arange(-100, 100, 10)
|
|
16
|
-
y = np.arange(-100, 100, 10)
|
|
17
|
-
steps = pd.timedelta_range("0h", "8h", freq="1h")
|
|
18
|
-
target_times = t0 + steps
|
|
19
|
-
|
|
20
|
-
channels = ["t", "dswrf"]
|
|
21
|
-
init_times = pd.to_datetime([t0]*len(steps))
|
|
22
|
-
|
|
23
|
-
# Create dummy time-sliced NWP data
|
|
24
|
-
da_nwp = xr.DataArray(
|
|
25
|
-
np.random.normal(size=(len(target_times), len(channels), len(x), len(y))),
|
|
26
|
-
coords=dict(
|
|
27
|
-
target_times_utc=(["target_times_utc"], target_times),
|
|
28
|
-
channel=(["channel"], channels),
|
|
29
|
-
x_osgb=(["x_osgb"], x),
|
|
30
|
-
y_osgb=(["y_osgb"], y),
|
|
31
|
-
)
|
|
32
|
-
)
|
|
33
|
-
|
|
34
|
-
# Add extra non-coordinate dimensions
|
|
35
|
-
da_nwp = da_nwp.assign_coords(
|
|
36
|
-
init_time_utc=("target_times_utc", init_times),
|
|
37
|
-
step=("target_times_utc", steps),
|
|
38
|
-
)
|
|
39
|
-
|
|
40
|
-
return da_nwp
|
|
41
|
-
|
|
42
3
|
|
|
43
|
-
def test_convert_nwp_to_numpy_sample(
|
|
4
|
+
def test_convert_nwp_to_numpy_sample(ds_nwp_ukv_time_sliced):
|
|
44
5
|
|
|
45
6
|
# Call the function
|
|
46
|
-
numpy_sample = convert_nwp_to_numpy_sample(
|
|
7
|
+
numpy_sample = convert_nwp_to_numpy_sample(ds_nwp_ukv_time_sliced)
|
|
47
8
|
|
|
48
9
|
# Assert the output type
|
|
49
10
|
assert isinstance(numpy_sample, dict)
|
|
50
11
|
|
|
51
12
|
# Assert the values of the numpy sample match the input data
|
|
52
|
-
assert (numpy_sample[NWPSampleKey.nwp] ==
|
|
13
|
+
assert (numpy_sample[NWPSampleKey.nwp] == ds_nwp_ukv_time_sliced.values).all()
|
tests/select/test_dropout.py
CHANGED
|
@@ -14,10 +14,8 @@ def da_sample():
|
|
|
14
14
|
datetimes = pd.date_range("2024-01-01 12:00", "2024-01-01 13:00", freq="5min")
|
|
15
15
|
|
|
16
16
|
da_sat = xr.DataArray(
|
|
17
|
-
np.random.normal(size=(len(datetimes)
|
|
18
|
-
coords=dict(
|
|
19
|
-
time_utc=(["time_utc"], datetimes),
|
|
20
|
-
)
|
|
17
|
+
np.random.normal(size=(len(datetimes))),
|
|
18
|
+
coords=dict(time_utc=datetimes)
|
|
21
19
|
)
|
|
22
20
|
return da_sat
|
|
23
21
|
|
|
@@ -29,7 +27,7 @@ def test_draw_dropout_time():
|
|
|
29
27
|
dropout_time = draw_dropout_time(t0, dropout_timedeltas, dropout_frac=1)
|
|
30
28
|
|
|
31
29
|
assert isinstance(dropout_time, pd.Timestamp)
|
|
32
|
-
assert dropout_time-t0 in dropout_timedeltas
|
|
30
|
+
assert (dropout_time-t0) in dropout_timedeltas
|
|
33
31
|
|
|
34
32
|
|
|
35
33
|
def test_draw_dropout_time_partial():
|
|
@@ -48,21 +46,17 @@ def test_draw_dropout_time_partial():
|
|
|
48
46
|
dropouts == {None} | set(t0 + dt for dt in dropout_timedeltas)
|
|
49
47
|
|
|
50
48
|
|
|
51
|
-
def
|
|
49
|
+
def test_draw_dropout_time_null():
|
|
52
50
|
t0 = pd.Timestamp("2021-01-01 04:00:00")
|
|
53
51
|
|
|
54
|
-
# No dropout timedeltas
|
|
55
|
-
dropout_time = draw_dropout_time(t0, dropout_timedeltas=None, dropout_frac=1)
|
|
56
|
-
assert dropout_time is None
|
|
57
|
-
|
|
58
52
|
# Dropout fraction is 0
|
|
59
53
|
dropout_timedeltas = [pd.Timedelta(-30, "min")]
|
|
60
54
|
dropout_time = draw_dropout_time(t0, dropout_timedeltas=dropout_timedeltas, dropout_frac=0)
|
|
61
|
-
assert dropout_time
|
|
55
|
+
assert dropout_time==t0
|
|
62
56
|
|
|
63
57
|
# No dropout timedeltas and dropout fraction is 0
|
|
64
|
-
dropout_time = draw_dropout_time(t0, dropout_timedeltas=
|
|
65
|
-
assert dropout_time
|
|
58
|
+
dropout_time = draw_dropout_time(t0, dropout_timedeltas=[], dropout_frac=0)
|
|
59
|
+
assert dropout_time==t0
|
|
66
60
|
|
|
67
61
|
|
|
68
62
|
@pytest.mark.parametrize("t0_str", ["12:00", "12:30", "13:00"])
|
|
@@ -74,17 +74,9 @@ def sample_data():
|
|
|
74
74
|
)
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
def test_site_sample_init():
|
|
78
|
-
""" Test initialisation """
|
|
79
|
-
sample = SiteSample()
|
|
80
|
-
assert isinstance(sample._data, dict)
|
|
81
|
-
assert len(sample._data) == 0
|
|
82
|
-
|
|
83
|
-
|
|
84
77
|
def test_site_sample_with_data(sample_data):
|
|
85
78
|
""" Testing of defined sample with actual data """
|
|
86
|
-
sample = SiteSample()
|
|
87
|
-
sample._data = sample_data
|
|
79
|
+
sample = SiteSample(sample_data)
|
|
88
80
|
|
|
89
81
|
# Assert data structure
|
|
90
82
|
assert isinstance(sample._data, Dataset)
|
|
@@ -109,8 +101,7 @@ def test_site_sample_with_data(sample_data):
|
|
|
109
101
|
|
|
110
102
|
def test_save_load(tmp_path, sample_data):
|
|
111
103
|
""" Save and load functionality """
|
|
112
|
-
sample = SiteSample()
|
|
113
|
-
sample._data = sample_data
|
|
104
|
+
sample = SiteSample(sample_data)
|
|
114
105
|
filepath = tmp_path / "test_sample.nc"
|
|
115
106
|
sample.save(filepath)
|
|
116
107
|
|
|
@@ -127,36 +118,16 @@ def test_save_load(tmp_path, sample_data):
|
|
|
127
118
|
xr.testing.assert_identical(sample._data, loaded._data)
|
|
128
119
|
|
|
129
120
|
|
|
130
|
-
def test_invalid_save_format(sample_data):
|
|
131
|
-
""" Saving with invalid format """
|
|
132
|
-
sample = SiteSample()
|
|
133
|
-
sample._data = sample_data
|
|
134
|
-
with pytest.raises(ValueError, match="Only .nc format is supported"):
|
|
135
|
-
sample.save("invalid.txt")
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
def test_invalid_load_format():
|
|
139
|
-
""" Loading with invalid format """
|
|
140
|
-
with pytest.raises(ValueError, match="Only .nc format is supported"):
|
|
141
|
-
SiteSample.load("invalid.txt")
|
|
142
|
-
|
|
143
|
-
|
|
144
121
|
def test_invalid_data_type():
|
|
145
122
|
""" Handling of invalid data types """
|
|
146
|
-
sample = SiteSample()
|
|
147
|
-
sample._data = {"invalid": "data"}
|
|
148
123
|
|
|
149
124
|
with pytest.raises(TypeError, match="Data must be xarray Dataset"):
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
with pytest.raises(TypeError, match="Data must be xarray Dataset for saving"):
|
|
153
|
-
sample.save("test.nc")
|
|
125
|
+
_ = SiteSample({"invalid": "data"})
|
|
154
126
|
|
|
155
127
|
|
|
156
128
|
def test_to_numpy(sample_data):
|
|
157
129
|
""" To numpy conversion """
|
|
158
|
-
sample = SiteSample()
|
|
159
|
-
sample._data = sample_data
|
|
130
|
+
sample = SiteSample(sample_data)
|
|
160
131
|
numpy_data = sample.to_numpy()
|
|
161
132
|
|
|
162
133
|
# Assert structure
|
|
@@ -180,8 +151,7 @@ def test_to_numpy(sample_data):
|
|
|
180
151
|
|
|
181
152
|
def test_data_consistency(sample_data):
|
|
182
153
|
""" Consistency of data across operations """
|
|
183
|
-
sample = SiteSample()
|
|
184
|
-
sample._data = sample_data
|
|
154
|
+
sample = SiteSample(sample_data)
|
|
185
155
|
numpy_data = sample.to_numpy()
|
|
186
156
|
|
|
187
157
|
# Assert components remain consistent after conversion above
|