ocf-data-sampler 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/model.py +25 -23
- ocf_data_sampler/load/satellite.py +21 -29
- ocf_data_sampler/load/site.py +1 -1
- ocf_data_sampler/numpy_sample/gsp.py +6 -2
- ocf_data_sampler/numpy_sample/nwp.py +7 -13
- ocf_data_sampler/numpy_sample/satellite.py +11 -8
- ocf_data_sampler/numpy_sample/site.py +6 -2
- ocf_data_sampler/numpy_sample/sun_position.py +9 -10
- ocf_data_sampler/sample/__init__.py +0 -7
- ocf_data_sampler/sample/base.py +16 -35
- ocf_data_sampler/sample/site.py +28 -65
- ocf_data_sampler/sample/uk_regional.py +52 -97
- ocf_data_sampler/select/dropout.py +38 -25
- ocf_data_sampler/select/fill_time_periods.py +3 -1
- ocf_data_sampler/select/find_contiguous_time_periods.py +0 -1
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +2 -3
- ocf_data_sampler/torch_datasets/datasets/site.py +9 -5
- {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/RECORD +29 -29
- tests/config/test_config.py +3 -3
- tests/conftest.py +33 -0
- tests/numpy_sample/test_nwp.py +3 -42
- tests/select/test_dropout.py +7 -13
- tests/test_sample/test_site_sample.py +5 -35
- tests/test_sample/test_uk_regional_sample.py +8 -35
- tests/torch_datasets/test_pvnet_uk.py +6 -19
- {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.1.9.dist-info → ocf_data_sampler-0.1.11.dist-info}/top_level.txt +0 -0
|
@@ -1,120 +1,75 @@
|
|
|
1
|
-
"""
|
|
2
|
-
PVNet - UK Regional sample / dataset implementation
|
|
3
|
-
"""
|
|
1
|
+
"""PVNet UK Regional sample implementation for dataset handling and visualisation"""
|
|
4
2
|
|
|
5
|
-
|
|
6
|
-
import pandas as pd
|
|
7
|
-
import torch
|
|
8
|
-
import logging
|
|
3
|
+
from typing_extensions import override
|
|
9
4
|
|
|
10
|
-
|
|
11
|
-
from pathlib import Path
|
|
5
|
+
import torch
|
|
12
6
|
|
|
7
|
+
from ocf_data_sampler.sample.base import SampleBase, NumpySample
|
|
13
8
|
from ocf_data_sampler.numpy_sample import (
|
|
14
9
|
NWPSampleKey,
|
|
15
10
|
GSPSampleKey,
|
|
16
11
|
SatelliteSampleKey
|
|
17
12
|
)
|
|
18
13
|
|
|
19
|
-
from ocf_data_sampler.sample.base import SampleBase
|
|
20
|
-
|
|
21
|
-
try:
|
|
22
|
-
import matplotlib.pyplot as plt
|
|
23
|
-
MATPLOTLIB_AVAILABLE = True
|
|
24
|
-
except ImportError:
|
|
25
|
-
MATPLOTLIB_AVAILABLE = False
|
|
26
|
-
plt = None
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
logger = logging.getLogger(__name__)
|
|
30
|
-
|
|
31
14
|
|
|
32
15
|
class UKRegionalSample(SampleBase):
|
|
33
|
-
"""
|
|
16
|
+
"""Handles UK Regional PVNet data operations"""
|
|
34
17
|
|
|
35
|
-
def __init__(self):
|
|
36
|
-
|
|
37
|
-
super().__init__()
|
|
38
|
-
self._data = {}
|
|
18
|
+
def __init__(self, data: NumpySample):
|
|
19
|
+
self._data = data
|
|
39
20
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
logger.debug("Converting sample data to numpy format")
|
|
21
|
+
@override
|
|
22
|
+
def to_numpy(self) -> NumpySample:
|
|
43
23
|
return self._data
|
|
44
24
|
|
|
45
|
-
def save(self, path:
|
|
46
|
-
"""
|
|
47
|
-
logger.debug(f"Saving UKRegionalSample to {path}")
|
|
48
|
-
path = Path(path)
|
|
49
|
-
|
|
50
|
-
if path.suffix != '.pt':
|
|
51
|
-
logger.error(f"Invalid file format: {path.suffix}")
|
|
52
|
-
raise ValueError(f"Only .pt format is supported: {path.suffix}")
|
|
25
|
+
def save(self, path: str) -> None:
|
|
26
|
+
"""Save PVNet sample as pickle format using torch.save
|
|
53
27
|
|
|
28
|
+
Args:
|
|
29
|
+
path: Path to save the sample data to
|
|
30
|
+
"""
|
|
54
31
|
torch.save(self._data, path)
|
|
55
|
-
logger.debug(f"Successfully saved UKRegionalSample to {path}")
|
|
56
32
|
|
|
57
33
|
@classmethod
|
|
58
|
-
def load(cls, path:
|
|
59
|
-
"""
|
|
60
|
-
logger.debug(f"Attempting to load UKRegionalSample from {path}")
|
|
61
|
-
path = Path(path)
|
|
34
|
+
def load(cls, path: str) -> 'UKRegionalSample':
|
|
35
|
+
"""Load PVNet sample data from .pt format
|
|
62
36
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
instance = cls()
|
|
37
|
+
Args:
|
|
38
|
+
path: Path to load the sample data from
|
|
39
|
+
"""
|
|
68
40
|
# TODO: We should move away from using torch.load(..., weights_only=False)
|
|
69
|
-
|
|
70
|
-
instance._data = torch.load(path, weights_only=False)
|
|
71
|
-
logger.debug(f"Successfully loaded UKRegionalSample from {path}")
|
|
72
|
-
return instance
|
|
73
|
-
|
|
74
|
-
def plot(self, **kwargs) -> None:
|
|
75
|
-
""" Sample visualisation definition """
|
|
76
|
-
logger.debug("Creating UKRegionalSample visualisation")
|
|
77
|
-
|
|
78
|
-
if not MATPLOTLIB_AVAILABLE:
|
|
79
|
-
raise ImportError(
|
|
80
|
-
"Matplotlib required for plotting"
|
|
81
|
-
"Install via 'ocf_data_sampler[plot]'"
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
try:
|
|
85
|
-
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
|
|
86
|
-
|
|
87
|
-
if NWPSampleKey.nwp in self._data:
|
|
88
|
-
logger.debug("Plotting NWP data")
|
|
89
|
-
first_nwp = list(self._data[NWPSampleKey.nwp].values())[0]
|
|
90
|
-
if 'nwp' in first_nwp:
|
|
91
|
-
axes[0, 1].imshow(first_nwp['nwp'][0])
|
|
92
|
-
axes[0, 1].set_title('NWP (First Channel)')
|
|
93
|
-
if NWPSampleKey.channel_names in first_nwp:
|
|
94
|
-
channel_names = first_nwp[NWPSampleKey.channel_names]
|
|
95
|
-
if len(channel_names) > 0:
|
|
96
|
-
axes[0, 1].set_title(f'NWP: {channel_names[0]}')
|
|
41
|
+
return cls(torch.load(path, weights_only=False))
|
|
97
42
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
axes[0, 0].set_title('GSP Generation')
|
|
102
|
-
|
|
103
|
-
if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
|
|
104
|
-
logger.debug("Plotting solar position data")
|
|
105
|
-
axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
|
|
106
|
-
axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
|
|
107
|
-
axes[1, 1].set_title('Solar Position')
|
|
108
|
-
axes[1, 1].legend()
|
|
43
|
+
def plot(self) -> None:
|
|
44
|
+
"""Creates visualisations for NWP, GSP, solar position, and satellite data"""
|
|
45
|
+
from matplotlib import pyplot as plt
|
|
109
46
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
47
|
+
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
|
|
48
|
+
|
|
49
|
+
if NWPSampleKey.nwp in self._data:
|
|
50
|
+
first_nwp = list(self._data[NWPSampleKey.nwp].values())[0]
|
|
51
|
+
if 'nwp' in first_nwp:
|
|
52
|
+
axes[0, 1].imshow(first_nwp['nwp'][0])
|
|
53
|
+
title = 'NWP (First Channel)'
|
|
54
|
+
if NWPSampleKey.channel_names in first_nwp:
|
|
55
|
+
channel_names = first_nwp[NWPSampleKey.channel_names]
|
|
56
|
+
if channel_names:
|
|
57
|
+
title = f'NWP: {channel_names[0]}'
|
|
58
|
+
axes[0, 1].set_title(title)
|
|
59
|
+
|
|
60
|
+
if GSPSampleKey.gsp in self._data:
|
|
61
|
+
axes[0, 0].plot(self._data[GSPSampleKey.gsp])
|
|
62
|
+
axes[0, 0].set_title('GSP Generation')
|
|
63
|
+
|
|
64
|
+
if GSPSampleKey.solar_azimuth in self._data and GSPSampleKey.solar_elevation in self._data:
|
|
65
|
+
axes[1, 1].plot(self._data[GSPSampleKey.solar_azimuth], label='Azimuth')
|
|
66
|
+
axes[1, 1].plot(self._data[GSPSampleKey.solar_elevation], label='Elevation')
|
|
67
|
+
axes[1, 1].set_title('Solar Position')
|
|
68
|
+
axes[1, 1].legend()
|
|
69
|
+
|
|
70
|
+
if SatelliteSampleKey.satellite_actual in self._data:
|
|
71
|
+
axes[1, 0].imshow(self._data[SatelliteSampleKey.satellite_actual])
|
|
72
|
+
axes[1, 0].set_title('Satellite Data')
|
|
73
|
+
|
|
74
|
+
plt.tight_layout()
|
|
75
|
+
plt.show()
|
|
@@ -1,39 +1,52 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Functions for simulating dropout in time series data
|
|
2
|
+
|
|
3
|
+
This is used for the following types of data: GSP, Satellite and Site
|
|
4
|
+
This is not used for NWP
|
|
5
|
+
"""
|
|
2
6
|
import numpy as np
|
|
3
7
|
import pandas as pd
|
|
4
8
|
import xarray as xr
|
|
5
9
|
|
|
6
10
|
|
|
7
11
|
def draw_dropout_time(
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
12
|
+
t0: pd.Timestamp,
|
|
13
|
+
dropout_timedeltas: list[pd.Timedelta],
|
|
14
|
+
dropout_frac: float,
|
|
15
|
+
) -> pd.Timestamp:
|
|
16
|
+
"""Randomly pick a dropout time from a list of timedeltas
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
t0: The forecast init-time
|
|
20
|
+
dropout_timedeltas: List of timedeltas relative to t0 to pick from
|
|
21
|
+
dropout_frac: Probability that dropout will be applied. This should be between 0 and 1
|
|
22
|
+
inclusive
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
if dropout_frac>0:
|
|
26
|
+
assert len(dropout_timedeltas) > 0, "To apply dropout dropout_timedeltas must be provided"
|
|
27
|
+
|
|
28
|
+
for t in dropout_timedeltas:
|
|
29
|
+
assert t <= pd.Timedelta("0min"), "Dropout timedeltas must be negative"
|
|
30
|
+
|
|
18
31
|
assert 0 <= dropout_frac <= 1
|
|
19
32
|
|
|
20
|
-
if (dropout_timedeltas
|
|
21
|
-
dropout_time =
|
|
33
|
+
if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
|
|
34
|
+
dropout_time = t0
|
|
22
35
|
else:
|
|
23
|
-
|
|
24
|
-
dt = np.random.choice(dropout_timedeltas)
|
|
25
|
-
dropout_time = t0_datetime_utc + dt
|
|
36
|
+
dropout_time = t0 + np.random.choice(dropout_timedeltas)
|
|
26
37
|
|
|
27
38
|
return dropout_time
|
|
28
39
|
|
|
29
40
|
|
|
30
41
|
def apply_dropout_time(
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
42
|
+
ds: xr.DataArray,
|
|
43
|
+
dropout_time: pd.Timestamp,
|
|
44
|
+
) -> xr.DataArray:
|
|
45
|
+
"""Apply dropout time to the data
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
ds: Xarray DataArray with 'time_utc' coordiante
|
|
49
|
+
dropout_time: Time after which data is set to NaN
|
|
50
|
+
"""
|
|
51
|
+
# This replaces the times after the dropout with NaNs
|
|
52
|
+
return ds.where(ds.time_utc <= dropout_time)
|
|
@@ -1,10 +1,12 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Fill time periods between start and end dates at specified frequency"""
|
|
2
2
|
|
|
3
3
|
import pandas as pd
|
|
4
4
|
import numpy as np
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
|
|
8
|
+
"""Generate DatetimeIndex for all timestamps between start and end dates"""
|
|
9
|
+
|
|
8
10
|
start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
|
|
9
11
|
end_dts = pd.to_datetime(time_periods["end_dt"].values)
|
|
10
12
|
date_ranges = [pd.date_range(start_dt, end_dt, freq=freq) for start_dt, end_dt in zip(start_dts, end_dts)]
|
|
@@ -186,9 +186,8 @@ class PVNetUKRegionalDataset(Dataset):
|
|
|
186
186
|
gsp_ids: List of GSP IDs to create samples for. Defaults to all
|
|
187
187
|
"""
|
|
188
188
|
|
|
189
|
-
config = load_yaml_configuration(config_filename)
|
|
190
|
-
|
|
191
|
-
# Validate channels for NWP and satellite data
|
|
189
|
+
# config = load_yaml_configuration(config_filename)
|
|
190
|
+
config: Configuration = load_yaml_configuration(config_filename)
|
|
192
191
|
validate_nwp_channels(config)
|
|
193
192
|
validate_satellite_channels(config)
|
|
194
193
|
|
|
@@ -20,7 +20,6 @@ from ocf_data_sampler.select import (
|
|
|
20
20
|
from ocf_data_sampler.utils import minutes
|
|
21
21
|
from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
|
|
22
22
|
from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts, fill_nans_in_arrays
|
|
23
|
-
from ocf_data_sampler.torch_datasets.utils.validate_channels import validate_nwp_channels
|
|
24
23
|
|
|
25
24
|
from ocf_data_sampler.numpy_sample import (
|
|
26
25
|
convert_site_to_numpy_sample,
|
|
@@ -30,8 +29,12 @@ from ocf_data_sampler.numpy_sample import (
|
|
|
30
29
|
make_sun_position_numpy_sample,
|
|
31
30
|
)
|
|
32
31
|
from ocf_data_sampler.numpy_sample import NWPSampleKey
|
|
33
|
-
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
|
|
32
|
+
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
|
|
34
33
|
|
|
34
|
+
from ocf_data_sampler.torch_datasets.utils.validate_channels import (
|
|
35
|
+
validate_nwp_channels,
|
|
36
|
+
validate_satellite_channels,
|
|
37
|
+
)
|
|
35
38
|
|
|
36
39
|
xr.set_options(keep_attrs=True)
|
|
37
40
|
|
|
@@ -52,9 +55,8 @@ class SitesDataset(Dataset):
|
|
|
52
55
|
"""
|
|
53
56
|
|
|
54
57
|
config: Configuration = load_yaml_configuration(config_filename)
|
|
55
|
-
|
|
56
|
-
# Validate NWP channels
|
|
57
58
|
validate_nwp_channels(config)
|
|
59
|
+
validate_satellite_channels(config)
|
|
58
60
|
|
|
59
61
|
datasets_dict = get_dataset_dict(config.input_data)
|
|
60
62
|
|
|
@@ -237,8 +239,10 @@ class SitesDataset(Dataset):
|
|
|
237
239
|
data_arrays.append((f"nwp-{provider}", da_nwp))
|
|
238
240
|
|
|
239
241
|
if "sat" in dataset_dict:
|
|
240
|
-
# TODO add some satellite normalisation
|
|
241
242
|
da_sat = dataset_dict["sat"]
|
|
243
|
+
|
|
244
|
+
# Standardise
|
|
245
|
+
da_sat = (da_sat - RSS_MEAN) / RSS_STD
|
|
242
246
|
data_arrays.append(("satellite", da_sat))
|
|
243
247
|
|
|
244
248
|
if "site" in dataset_dict:
|
|
@@ -3,14 +3,14 @@ ocf_data_sampler/constants.py,sha256=0HYNmqwBaHVTAEEx9qzk6WD9YInh0gSKLeI3pyq7aNs
|
|
|
3
3
|
ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
|
|
4
4
|
ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
|
|
5
5
|
ocf_data_sampler/config/load.py,sha256=sKCKmhkkeFvvkNL5xmnFvdAulaCtV4-rigPsFvVDPDc,634
|
|
6
|
-
ocf_data_sampler/config/model.py,sha256=
|
|
6
|
+
ocf_data_sampler/config/model.py,sha256=8PO-23uVy_JjWOJKgaZWdNMehQsAI-Jn8t0lcmBycwg,6992
|
|
7
7
|
ocf_data_sampler/config/save.py,sha256=OqCPT3e0d7vMI2g2iRzmifPD7GscDkFQztU_qE5I0JY,1066
|
|
8
8
|
ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
|
|
9
9
|
ocf_data_sampler/load/__init__.py,sha256=T5Zj1PGt0aiiNEN7Ra1Ac-cBsNKhphmmHy_8g7XU_w0,219
|
|
10
10
|
ocf_data_sampler/load/gsp.py,sha256=uRxEORH7J99JAJ-D38nm0iJFOQh7dkm_NCXcpbYkyvo,857
|
|
11
11
|
ocf_data_sampler/load/load_dataset.py,sha256=PHUGSm4hFHfS9nfIP2KjHHCp325O4br7uGBdQH_DP7g,1603
|
|
12
|
-
ocf_data_sampler/load/satellite.py,sha256=
|
|
13
|
-
ocf_data_sampler/load/site.py,sha256=
|
|
12
|
+
ocf_data_sampler/load/satellite.py,sha256=SEQZ9oPe-asEeZeEMDkB1xWK5hErhWMagxohFcBl6KI,2294
|
|
13
|
+
ocf_data_sampler/load/site.py,sha256=hMdoF6sn2PcSBfF2soj7nuQoK9SItaxDXco5nk2n-44,1232
|
|
14
14
|
ocf_data_sampler/load/utils.py,sha256=sAEkPMS9LXVCrc5pANQo97zaoEItVg9hoNj2ZWfx_Ug,1405
|
|
15
15
|
ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
|
|
16
16
|
ocf_data_sampler/load/nwp/nwp.py,sha256=Jyq1dE7DN0iSe6iSEGA76uu9LoeJz9FzfEUkq6ZZExQ,565
|
|
@@ -21,19 +21,19 @@ ocf_data_sampler/load/nwp/providers/utils.py,sha256=MFOZ5ZXLu3-SxYVJExdlo30b3y3s
|
|
|
21
21
|
ocf_data_sampler/numpy_sample/__init__.py,sha256=nY5C6CcuxiWZ_jrXRzWtN7WyKXhJImSiVTIG6Rz4B_4,401
|
|
22
22
|
ocf_data_sampler/numpy_sample/collate.py,sha256=oX5axq30sCsSquhNbmWAVMjM54HT1v3MCMopYHcO5Q0,1950
|
|
23
23
|
ocf_data_sampler/numpy_sample/datetime_features.py,sha256=D0RajbnBjg15qjYk16h2H0XO4wH3fw-x0--4VC2nq0s,1204
|
|
24
|
-
ocf_data_sampler/numpy_sample/gsp.py,sha256=
|
|
25
|
-
ocf_data_sampler/numpy_sample/nwp.py,sha256=
|
|
26
|
-
ocf_data_sampler/numpy_sample/satellite.py,sha256=
|
|
27
|
-
ocf_data_sampler/numpy_sample/site.py,sha256=
|
|
28
|
-
ocf_data_sampler/numpy_sample/sun_position.py,sha256=
|
|
29
|
-
ocf_data_sampler/sample/__init__.py,sha256=
|
|
30
|
-
ocf_data_sampler/sample/base.py,sha256=
|
|
31
|
-
ocf_data_sampler/sample/site.py,sha256=
|
|
32
|
-
ocf_data_sampler/sample/uk_regional.py,sha256=
|
|
24
|
+
ocf_data_sampler/numpy_sample/gsp.py,sha256=uBquCFCoWuhJKY8sXpgsTCUDWUuLuv1XeixtFnFw6KU,1115
|
|
25
|
+
ocf_data_sampler/numpy_sample/nwp.py,sha256=Tiba-es23XeyMoEPgZUpLT6EnJCGU9A_1MdY6qkE7bM,1015
|
|
26
|
+
ocf_data_sampler/numpy_sample/satellite.py,sha256=RdXMdGGXysUx-AdL9T33yFOlxprtIdPNBKKX99-mhpY,991
|
|
27
|
+
ocf_data_sampler/numpy_sample/site.py,sha256=TvoEU85fmjYW8pD9UZOyUUACjimdQYxEzulQXunRO6Q,1425
|
|
28
|
+
ocf_data_sampler/numpy_sample/sun_position.py,sha256=ithM--eztAhiIQ1g52tlxgj-tMKbsJzx8mk6CgV2tzk,1613
|
|
29
|
+
ocf_data_sampler/sample/__init__.py,sha256=zdS73NTnxFX_j8uh9tT-IXiURB6635wbneM1koWYV1o,169
|
|
30
|
+
ocf_data_sampler/sample/base.py,sha256=IH3HbfqEUwjHmq-h2eJYLd8Jk-0ZcOylnehMyCPMV38,2223
|
|
31
|
+
ocf_data_sampler/sample/site.py,sha256=ONf2Yz5zi8Ombd_znA4T7NXbO01F76kQsBZv6rfnC74,1343
|
|
32
|
+
ocf_data_sampler/sample/uk_regional.py,sha256=KhJ5Ik1pZRp7PgIJjGIrE4i7SQnIdVjUbBHnfn-7ghg,2649
|
|
33
33
|
ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
|
|
34
|
-
ocf_data_sampler/select/dropout.py,sha256=
|
|
35
|
-
ocf_data_sampler/select/fill_time_periods.py,sha256=
|
|
36
|
-
ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=
|
|
34
|
+
ocf_data_sampler/select/dropout.py,sha256=Pgov9P7rQMkSdqluG_hwm8loGyYNFOg-3PJUBLN_kjU,1526
|
|
35
|
+
ocf_data_sampler/select/fill_time_periods.py,sha256=EIcXG-77aQVOAYNwbDBEv6SGf6DO2p1WMEf96iW4MEM,596
|
|
36
|
+
ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=IwPQwvgu4cOiAZ5Gbjflv3fnQCcs0EVK0g4V6yqqSgw,11129
|
|
37
37
|
ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
|
|
38
38
|
ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
|
|
39
39
|
ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
|
|
@@ -41,15 +41,15 @@ ocf_data_sampler/select/select_time_slice.py,sha256=9M-yvDv9K77XfEys_OIR31_aVB56
|
|
|
41
41
|
ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
|
|
42
42
|
ocf_data_sampler/select/time_slice_for_dataset.py,sha256=Z7pOiilSHScxmBKZNG18K5J-S4ifdXXAYGZoHRHD3AY,4324
|
|
43
43
|
ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
|
|
44
|
-
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=
|
|
45
|
-
ocf_data_sampler/torch_datasets/datasets/site.py,sha256=
|
|
44
|
+
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=ZgfvVCcEU3dj3RoY0zdBdKGppC7Wm81qecqB17gYTmE,12286
|
|
45
|
+
ocf_data_sampler/torch_datasets/datasets/site.py,sha256=_uHmqg-VJu-MHgXc5JFDX1noPfH6E8nY4XhQmsrOav4,16325
|
|
46
46
|
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
|
|
47
47
|
ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
|
|
48
48
|
ocf_data_sampler/torch_datasets/utils/validate_channels.py,sha256=u2EpiFAKAOHpmvINhOUJCT8Vbc-cle6qJ3YNVse4yLs,2884
|
|
49
49
|
scripts/refactor_site.py,sha256=xaJGxt2_WObIPrPAnRiOMMB68r-5Q51jWRx409AcscM,1747
|
|
50
50
|
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
51
|
-
tests/conftest.py,sha256=
|
|
52
|
-
tests/config/test_config.py,sha256=
|
|
51
|
+
tests/conftest.py,sha256=k7nM3u2YJmkMupN4SIbJP3BRoxNR1dpIoo2fPFf0abg,8588
|
|
52
|
+
tests/config/test_config.py,sha256=CzYVhAUpgT4lvQdIddtVxtJeMqYL_TJolfeIwaaohq4,3969
|
|
53
53
|
tests/config/test_load.py,sha256=8nui2UsgK_eufWGD74yXvf-6eY_SxBFKhDmGYUtRQxw,260
|
|
54
54
|
tests/config/test_save.py,sha256=BxSd2S50-bRPIXP_4iX0B6Wt7pRFJnUbLYtzfLaqlAs,915
|
|
55
55
|
tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
|
|
@@ -59,24 +59,24 @@ tests/load/test_load_sites.py,sha256=6V-U3_EtBklkV7w-hOoR4nba3dSaZ_cnjuRWFs8kYVU
|
|
|
59
59
|
tests/numpy_sample/test_collate.py,sha256=RqHCD5_LTRpe4r6kqC_2TKhmhM_IHYM0ZtFUvSjDqcM,654
|
|
60
60
|
tests/numpy_sample/test_datetime_features.py,sha256=iR9WdBLj1nIBNqoaTFE9rkUaH1eKFJSNb96nwiEaQH0,1449
|
|
61
61
|
tests/numpy_sample/test_gsp.py,sha256=FLlq4SlJ-9cSRAepf4_ksA6PsUVKegnKEAc5pUojCJ0,1458
|
|
62
|
-
tests/numpy_sample/test_nwp.py,sha256=
|
|
62
|
+
tests/numpy_sample/test_nwp.py,sha256=Lnd-PMa6gI-fSIJkSZ554QiHFfnwxeXZxLg-rpuBv1U,442
|
|
63
63
|
tests/numpy_sample/test_satellite.py,sha256=cCqtn5See-uSNfh89COGTUQNuFm6sIZ8QmBVHsuUeRI,1189
|
|
64
64
|
tests/numpy_sample/test_sun_position.py,sha256=_ENYzsNBVPdNXf--FI-UUFqw2u5w7_zqw6LcENU2uZM,2504
|
|
65
|
-
tests/select/test_dropout.py,sha256=
|
|
65
|
+
tests/select/test_dropout.py,sha256=aQuSSqZF9RxBjN9-ogkQ8O-_zktAM30CrT1Lz7j1hMg,2222
|
|
66
66
|
tests/select/test_fill_time_periods.py,sha256=o59f2YRe5b0vJrG3B0aYZkYeHnpNk4s6EJxdXZluNQg,907
|
|
67
67
|
tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM3agOhsvZYx8inXtUn1PM,5976
|
|
68
68
|
tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
|
|
69
69
|
tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
|
|
70
70
|
tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
|
|
71
71
|
tests/test_sample/test_base.py,sha256=sD9NZghYQWbkAcQP9YXypWZowqYkO3xeNMH-_mEoD5I,4833
|
|
72
|
-
tests/test_sample/test_site_sample.py,sha256=
|
|
73
|
-
tests/test_sample/test_uk_regional_sample.py,sha256=
|
|
72
|
+
tests/test_sample/test_site_sample.py,sha256=8HNenhIWYouCQu4y389PDQGokSPI5jQ4lS4CG-eA1Y8,5382
|
|
73
|
+
tests/test_sample/test_uk_regional_sample.py,sha256=MFibX9-M8mFK7vwMPu58gAG2VoY6y7w7chW5BlZclwk,3962
|
|
74
74
|
tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
|
|
75
|
-
tests/torch_datasets/test_pvnet_uk.py,sha256=
|
|
75
|
+
tests/torch_datasets/test_pvnet_uk.py,sha256=hgD_IDa4D8cgc4cgK1UqKYkT6sFlrTMAvgVn_iwD5_4,5086
|
|
76
76
|
tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
|
|
77
77
|
tests/torch_datasets/test_validate_channels_utils.py,sha256=Rzdweu98j1of45jCOUrSiBtyPlf-dDaCceulf0H7ml8,2921
|
|
78
|
-
ocf_data_sampler-0.1.
|
|
79
|
-
ocf_data_sampler-0.1.
|
|
80
|
-
ocf_data_sampler-0.1.
|
|
81
|
-
ocf_data_sampler-0.1.
|
|
82
|
-
ocf_data_sampler-0.1.
|
|
78
|
+
ocf_data_sampler-0.1.11.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
79
|
+
ocf_data_sampler-0.1.11.dist-info/METADATA,sha256=d8wctSlRyDbP1_yYHFvIGQgEC8DmOkM8h-ITI4XFuPw,12174
|
|
80
|
+
ocf_data_sampler-0.1.11.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
81
|
+
ocf_data_sampler-0.1.11.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
|
|
82
|
+
ocf_data_sampler-0.1.11.dist-info/RECORD,,
|
tests/config/test_config.py
CHANGED
|
@@ -30,7 +30,7 @@ def test_incorrect_interval_start_minutes(test_config_filename):
|
|
|
30
30
|
configuration.input_data.nwp['ukv'].interval_start_minutes = -1111
|
|
31
31
|
with pytest.raises(
|
|
32
32
|
ValueError,
|
|
33
|
-
match="interval_start_minutes
|
|
33
|
+
match="interval_start_minutes.*must be divisible.*time_resolution_minutes.*"
|
|
34
34
|
):
|
|
35
35
|
_ = Configuration(**configuration.model_dump())
|
|
36
36
|
|
|
@@ -45,7 +45,7 @@ def test_incorrect_interval_end_minutes(test_config_filename):
|
|
|
45
45
|
configuration.input_data.nwp['ukv'].interval_end_minutes = 1111
|
|
46
46
|
with pytest.raises(
|
|
47
47
|
ValueError,
|
|
48
|
-
match="interval_end_minutes
|
|
48
|
+
match="interval_end_minutes.*must be divisible.*time_resolution_minutes.*"
|
|
49
49
|
):
|
|
50
50
|
_ = Configuration(**configuration.model_dump())
|
|
51
51
|
|
|
@@ -103,7 +103,7 @@ def test_inconsistent_dropout_use(test_config_filename):
|
|
|
103
103
|
|
|
104
104
|
configuration = load_yaml_configuration(test_config_filename)
|
|
105
105
|
configuration.input_data.satellite.dropout_fraction= 1.0
|
|
106
|
-
configuration.input_data.satellite.dropout_timedeltas_minutes =
|
|
106
|
+
configuration.input_data.satellite.dropout_timedeltas_minutes = []
|
|
107
107
|
|
|
108
108
|
with pytest.raises(ValueError, match="To dropout fraction > 0 requires a list of dropout timedeltas"):
|
|
109
109
|
_ = Configuration(**configuration.model_dump())
|
tests/conftest.py
CHANGED
|
@@ -130,6 +130,39 @@ def nwp_ukv_zarr_path(session_tmp_path, ds_nwp_ukv):
|
|
|
130
130
|
yield zarr_path
|
|
131
131
|
|
|
132
132
|
|
|
133
|
+
@pytest.fixture()
|
|
134
|
+
def ds_nwp_ukv_time_sliced():
|
|
135
|
+
|
|
136
|
+
t0 = pd.to_datetime("2024-01-02 00:00")
|
|
137
|
+
|
|
138
|
+
x = np.arange(-100, 100, 10)
|
|
139
|
+
y = np.arange(-100, 100, 10)
|
|
140
|
+
steps = pd.timedelta_range("0h", "8h", freq="1h")
|
|
141
|
+
target_times = t0 + steps
|
|
142
|
+
|
|
143
|
+
channels = ["t", "dswrf"]
|
|
144
|
+
init_times = pd.to_datetime([t0]*len(steps))
|
|
145
|
+
|
|
146
|
+
# Create dummy time-sliced NWP data
|
|
147
|
+
da_nwp = xr.DataArray(
|
|
148
|
+
np.random.normal(size=(len(target_times), len(channels), len(x), len(y))),
|
|
149
|
+
coords=dict(
|
|
150
|
+
target_time_utc=(["target_time_utc"], target_times),
|
|
151
|
+
channel=(["channel"], channels),
|
|
152
|
+
x_osgb=(["x_osgb"], x),
|
|
153
|
+
y_osgb=(["y_osgb"], y),
|
|
154
|
+
)
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
# Add extra non-coordinate dimensions
|
|
158
|
+
da_nwp = da_nwp.assign_coords(
|
|
159
|
+
init_time_utc=("target_time_utc", init_times),
|
|
160
|
+
step=("target_time_utc", steps),
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
return da_nwp
|
|
164
|
+
|
|
165
|
+
|
|
133
166
|
@pytest.fixture(scope="session")
|
|
134
167
|
def ds_nwp_ecmwf():
|
|
135
168
|
init_times = pd.date_range(start="2023-01-01 00:00", freq="6h", periods=24 * 7)
|
tests/numpy_sample/test_nwp.py
CHANGED
|
@@ -1,52 +1,13 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import pandas as pd
|
|
3
|
-
import xarray as xr
|
|
4
|
-
|
|
5
|
-
import pytest
|
|
6
|
-
|
|
7
1
|
from ocf_data_sampler.numpy_sample import convert_nwp_to_numpy_sample, NWPSampleKey
|
|
8
2
|
|
|
9
|
-
@pytest.fixture(scope="module")
|
|
10
|
-
def da_nwp_like():
|
|
11
|
-
"""Create dummy data which looks like time-sliced NWP data"""
|
|
12
|
-
|
|
13
|
-
t0 = pd.to_datetime("2024-01-02 00:00")
|
|
14
|
-
|
|
15
|
-
x = np.arange(-100, 100, 10)
|
|
16
|
-
y = np.arange(-100, 100, 10)
|
|
17
|
-
steps = pd.timedelta_range("0h", "8h", freq="1h")
|
|
18
|
-
target_times = t0 + steps
|
|
19
|
-
|
|
20
|
-
channels = ["t", "dswrf"]
|
|
21
|
-
init_times = pd.to_datetime([t0]*len(steps))
|
|
22
|
-
|
|
23
|
-
# Create dummy time-sliced NWP data
|
|
24
|
-
da_nwp = xr.DataArray(
|
|
25
|
-
np.random.normal(size=(len(target_times), len(channels), len(x), len(y))),
|
|
26
|
-
coords=dict(
|
|
27
|
-
target_times_utc=(["target_times_utc"], target_times),
|
|
28
|
-
channel=(["channel"], channels),
|
|
29
|
-
x_osgb=(["x_osgb"], x),
|
|
30
|
-
y_osgb=(["y_osgb"], y),
|
|
31
|
-
)
|
|
32
|
-
)
|
|
33
|
-
|
|
34
|
-
# Add extra non-coordinate dimensions
|
|
35
|
-
da_nwp = da_nwp.assign_coords(
|
|
36
|
-
init_time_utc=("target_times_utc", init_times),
|
|
37
|
-
step=("target_times_utc", steps),
|
|
38
|
-
)
|
|
39
|
-
|
|
40
|
-
return da_nwp
|
|
41
|
-
|
|
42
3
|
|
|
43
|
-
def test_convert_nwp_to_numpy_sample(
|
|
4
|
+
def test_convert_nwp_to_numpy_sample(ds_nwp_ukv_time_sliced):
|
|
44
5
|
|
|
45
6
|
# Call the function
|
|
46
|
-
numpy_sample = convert_nwp_to_numpy_sample(
|
|
7
|
+
numpy_sample = convert_nwp_to_numpy_sample(ds_nwp_ukv_time_sliced)
|
|
47
8
|
|
|
48
9
|
# Assert the output type
|
|
49
10
|
assert isinstance(numpy_sample, dict)
|
|
50
11
|
|
|
51
12
|
# Assert the shape of the numpy sample
|
|
52
|
-
assert (numpy_sample[NWPSampleKey.nwp] ==
|
|
13
|
+
assert (numpy_sample[NWPSampleKey.nwp] == ds_nwp_ukv_time_sliced.values).all()
|
tests/select/test_dropout.py
CHANGED
|
@@ -14,10 +14,8 @@ def da_sample():
|
|
|
14
14
|
datetimes = pd.date_range("2024-01-01 12:00", "2024-01-01 13:00", freq="5min")
|
|
15
15
|
|
|
16
16
|
da_sat = xr.DataArray(
|
|
17
|
-
np.random.normal(size=(len(datetimes)
|
|
18
|
-
coords=dict(
|
|
19
|
-
time_utc=(["time_utc"], datetimes),
|
|
20
|
-
)
|
|
17
|
+
np.random.normal(size=(len(datetimes))),
|
|
18
|
+
coords=dict(time_utc=datetimes)
|
|
21
19
|
)
|
|
22
20
|
return da_sat
|
|
23
21
|
|
|
@@ -29,7 +27,7 @@ def test_draw_dropout_time():
|
|
|
29
27
|
dropout_time = draw_dropout_time(t0, dropout_timedeltas, dropout_frac=1)
|
|
30
28
|
|
|
31
29
|
assert isinstance(dropout_time, pd.Timestamp)
|
|
32
|
-
assert dropout_time-t0 in dropout_timedeltas
|
|
30
|
+
assert (dropout_time-t0) in dropout_timedeltas
|
|
33
31
|
|
|
34
32
|
|
|
35
33
|
def test_draw_dropout_time_partial():
|
|
@@ -48,21 +46,17 @@ def test_draw_dropout_time_partial():
|
|
|
48
46
|
dropouts == {None} | set(t0 + dt for dt in dropout_timedeltas)
|
|
49
47
|
|
|
50
48
|
|
|
51
|
-
def
|
|
49
|
+
def test_draw_dropout_time_null():
|
|
52
50
|
t0 = pd.Timestamp("2021-01-01 04:00:00")
|
|
53
51
|
|
|
54
|
-
# No dropout timedeltas
|
|
55
|
-
dropout_time = draw_dropout_time(t0, dropout_timedeltas=None, dropout_frac=1)
|
|
56
|
-
assert dropout_time is None
|
|
57
|
-
|
|
58
52
|
# Dropout fraction is 0
|
|
59
53
|
dropout_timedeltas = [pd.Timedelta(-30, "min")]
|
|
60
54
|
dropout_time = draw_dropout_time(t0, dropout_timedeltas=dropout_timedeltas, dropout_frac=0)
|
|
61
|
-
assert dropout_time
|
|
55
|
+
assert dropout_time==t0
|
|
62
56
|
|
|
63
57
|
# No dropout timedeltas and dropout fraction is 0
|
|
64
|
-
dropout_time = draw_dropout_time(t0, dropout_timedeltas=
|
|
65
|
-
assert dropout_time
|
|
58
|
+
dropout_time = draw_dropout_time(t0, dropout_timedeltas=[], dropout_frac=0)
|
|
59
|
+
assert dropout_time==t0
|
|
66
60
|
|
|
67
61
|
|
|
68
62
|
@pytest.mark.parametrize("t0_str", ["12:00", "12:30", "13:00"])
|