ocf-data-sampler 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/model.py +84 -87
- ocf_data_sampler/load/load_dataset.py +55 -0
- ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
- ocf_data_sampler/load/site.py +30 -0
- ocf_data_sampler/numpy_batch/__init__.py +1 -0
- ocf_data_sampler/numpy_batch/site.py +29 -0
- ocf_data_sampler/select/__init__.py +8 -1
- ocf_data_sampler/select/dropout.py +2 -1
- ocf_data_sampler/select/geospatial.py +43 -1
- ocf_data_sampler/select/select_spatial_slice.py +8 -2
- ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
- ocf_data_sampler/select/time_slice_for_dataset.py +124 -0
- ocf_data_sampler/time_functions.py +11 -0
- ocf_data_sampler/torch_datasets/process_and_combine.py +153 -0
- ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +8 -418
- ocf_data_sampler/torch_datasets/site.py +196 -0
- ocf_data_sampler/torch_datasets/valid_time_periods.py +108 -0
- {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/RECORD +28 -16
- {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/top_level.txt +1 -0
- scripts/refactor_site.py +50 -0
- tests/config/test_config.py +9 -6
- tests/conftest.py +62 -0
- tests/load/test_load_sites.py +14 -0
- tests/torch_datasets/test_pvnet_uk_regional.py +4 -4
- tests/torch_datasets/test_site.py +85 -0
- {ocf_data_sampler-0.0.24.dist-info → ocf_data_sampler-0.0.26.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
|
|
4
|
+
from ocf_data_sampler.config import Configuration
|
|
5
|
+
from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods_nwp, \
|
|
6
|
+
find_contiguous_t0_periods, intersection_of_multiple_dataframes_of_periods
|
|
7
|
+
from ocf_data_sampler.time_functions import minutes
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def find_valid_time_periods(
    datasets_dict: dict,
    config: Configuration,
):
    """Find the t0 times where all of the requested input data is available.

    A table of contiguous t0 periods is computed for every input source present in
    ``datasets_dict`` (one per configured NWP provider, plus satellite and GSP), and
    the tables are then intersected so that only t0 times valid for every source
    remain.

    Args:
        datasets_dict: A dictionary of input datasets. Only the keys "nwp", "sat" and
            "gsp" are supported. ``datasets_dict["nwp"]`` is indexed by the same keys
            as ``config.input_data.nwp``.
        config: Configuration object describing the input data.

    Returns:
        The valid t0 periods, as produced by the ``find_contiguous_t0_periods*`` /
        ``intersection_of_multiple_dataframes_of_periods`` helpers.

    Raises:
        ValueError: If no input source produced any periods, or if the intersection
            of all sources leaves no valid time periods.
    """
    supported_keys = {"nwp", "sat", "gsp"}
    unsupported_keys = set(datasets_dict) - supported_keys
    assert not unsupported_keys, f"Unsupported keys in datasets_dict: {sorted(unsupported_keys)}"

    # Contiguous t0 periods from each data source, keyed by source name
    contiguous_time_periods: dict[str, pd.DataFrame] = {}

    if "nwp" in datasets_dict:
        for nwp_key, nwp_config in config.input_data.nwp.items():

            da = datasets_dict["nwp"][nwp_key]

            if nwp_config.dropout_timedeltas_minutes is None:
                max_dropout = minutes(0)
            else:
                max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes)))

            if nwp_config.max_staleness_minutes is None:
                max_staleness = None
            else:
                max_staleness = minutes(nwp_config.max_staleness_minutes)

            # The last step of the forecast is lost if we have to diff channels
            if len(nwp_config.accum_channels) > 0:
                end_buffer = minutes(nwp_config.time_resolution_minutes)
            else:
                end_buffer = minutes(0)

            # This is the max staleness we can use considering the max step of the input data
            max_possible_staleness = (
                pd.Timedelta(da["step"].max().item())
                - minutes(nwp_config.forecast_minutes)
                - end_buffer
            )

            # Default to use max possible staleness unless specified in config
            if max_staleness is None:
                max_staleness = max_possible_staleness
            else:
                # Make sure the max acceptable staleness isn't longer than the max possible
                assert max_staleness <= max_possible_staleness, (
                    f"NWP '{nwp_key}': max staleness {max_staleness} exceeds the max "
                    f"possible staleness {max_possible_staleness}"
                )

            time_periods = find_contiguous_t0_periods_nwp(
                datetimes=pd.DatetimeIndex(da["init_time_utc"]),
                history_duration=minutes(nwp_config.history_minutes),
                max_staleness=max_staleness,
                max_dropout=max_dropout,
            )

            contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods

    if "sat" in datasets_dict:
        sat_config = config.input_data.satellite

        time_periods = find_contiguous_t0_periods(
            pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
            sample_period_duration=minutes(sat_config.time_resolution_minutes),
            history_duration=minutes(sat_config.history_minutes),
            forecast_duration=minutes(sat_config.forecast_minutes),
        )

        contiguous_time_periods['sat'] = time_periods

    if "gsp" in datasets_dict:
        gsp_config = config.input_data.gsp

        time_periods = find_contiguous_t0_periods(
            pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
            history_duration=minutes(gsp_config.history_minutes),
            forecast_duration=minutes(gsp_config.forecast_minutes),
        )

        contiguous_time_periods['gsp'] = time_periods

    # Without any source there is nothing to intersect; fail clearly instead of
    # with an IndexError below.
    if not contiguous_time_periods:
        raise ValueError(
            "No input data sources to find valid time periods for: "
            f"datasets_dict keys={sorted(datasets_dict)}"
        )

    # just get the values (not the keys)
    contiguous_time_periods_values = list(contiguous_time_periods.values())

    # Find joint overlapping contiguous time periods
    if len(contiguous_time_periods_values) > 1:
        valid_time_periods = intersection_of_multiple_dataframes_of_periods(
            contiguous_time_periods_values
        )
    else:
        valid_time_periods = contiguous_time_periods_values[0]

    # check there are some valid time periods
    if len(valid_time_periods) == 0:
        raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")

    return valid_time_periods
|
|
@@ -1,41 +1,52 @@
|
|
|
1
1
|
ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
2
2
|
ocf_data_sampler/constants.py,sha256=tUwHrsGShqIn5Izze4i32_xB6X0v67rvQwIYB-P5PJQ,3355
|
|
3
|
+
ocf_data_sampler/time_functions.py,sha256=R6ZlVEe6h4UlJeUW7paZYAMWveOv9MTjMsoISCwnsiE,284
|
|
3
4
|
ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
|
|
4
5
|
ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
|
|
5
|
-
ocf_data_sampler/config/model.py,sha256=
|
|
6
|
+
ocf_data_sampler/config/model.py,sha256=YnGOzt6T835h6bozWqrlMnUIHPo26U8o-DTKAKvv_24,7121
|
|
6
7
|
ocf_data_sampler/config/save.py,sha256=wKdctbv0dxIIiQtcRHLRxpWQVhEFQ_FCWg-oNaRLIps,1093
|
|
7
8
|
ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
|
|
8
9
|
ocf_data_sampler/load/__init__.py,sha256=MjgfxilTzyz1RYFoBEeAXmE9hyjknLvdmlHPmlAoiQY,44
|
|
9
10
|
ocf_data_sampler/load/gsp.py,sha256=Gcr1JVUOPKhFRDCSHtfPDjxx0BtyyEhXrZvGEKLPJ5I,759
|
|
11
|
+
ocf_data_sampler/load/load_dataset.py,sha256=Ua3RaUg4PIYJkD9BKqTfN8IWUbezbhThJGgEkd9PcaE,1587
|
|
10
12
|
ocf_data_sampler/load/satellite.py,sha256=3KlA1fx4SwxdzM-jC1WRaONXO0D6m0WxORnEnwUnZrA,2967
|
|
13
|
+
ocf_data_sampler/load/site.py,sha256=ROif2XXIIgBz-JOOiHymTq1CMXswJ3AzENU9DJmYpcU,782
|
|
11
14
|
ocf_data_sampler/load/utils.py,sha256=EQGvVWlGMoSOdbDYuMfVAa0v6wmAOPmHIAemdrTB5v4,1406
|
|
12
15
|
ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
|
|
13
16
|
ocf_data_sampler/load/nwp/nwp.py,sha256=O4QnajEZem8BvBgTcYYDBhRhgqPYuJkolHmpMRmrXEA,610
|
|
14
17
|
ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
15
|
-
ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=
|
|
18
|
+
ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=2iR1Iy542lo51rC6XFLV-3pbUE68dWjlHa6TVJzx3ac,1280
|
|
16
19
|
ocf_data_sampler/load/nwp/providers/ukv.py,sha256=79Bm7q-K_GJPYMy62SUIZbRWRF4-tIaB1dYPEgLD9vo,1207
|
|
17
20
|
ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3txG1_bBmBdchlc-yA,848
|
|
18
|
-
ocf_data_sampler/numpy_batch/__init__.py,sha256=
|
|
21
|
+
ocf_data_sampler/numpy_batch/__init__.py,sha256=8MgRF29rK9bKP4b4iHakaoGwBKUcjWZ-VFKjCcq53QA,336
|
|
19
22
|
ocf_data_sampler/numpy_batch/gsp.py,sha256=QjQ25JmtufvdiSsxUkBTPhxouYGWPnnWze8pXr_aBno,960
|
|
20
23
|
ocf_data_sampler/numpy_batch/nwp.py,sha256=dAehfRo5DL2Yb20ifHHl5cU1QOrm3ZOpQmN39fSUOw8,1255
|
|
21
24
|
ocf_data_sampler/numpy_batch/satellite.py,sha256=3NoE_ElzMHwO60apqJeFAwI6J7eIxD0OWTyAVl-uJi8,903
|
|
25
|
+
ocf_data_sampler/numpy_batch/site.py,sha256=lJYMEot50UgSBnSOgADQMjUhky1YyWKYqwNsisyYV6w,789
|
|
22
26
|
ocf_data_sampler/numpy_batch/sun_position.py,sha256=zw2bjtcjsm_tvKk0r_MZmgfYUJLHuLjLly2sMjwP3XI,1606
|
|
23
|
-
ocf_data_sampler/select/__init__.py,sha256=
|
|
24
|
-
ocf_data_sampler/select/dropout.py,sha256=
|
|
27
|
+
ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
|
|
28
|
+
ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
|
|
25
29
|
ocf_data_sampler/select/fill_time_periods.py,sha256=iTtMjIPFYG5xtUYYedAFBLjTWWUa7t7WQ0-yksWf0-E,440
|
|
26
30
|
ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=6ioB8LeFpFNBMgKDxrgG3zqzNjkBF_jlV9yye2ZYT2E,11925
|
|
27
|
-
ocf_data_sampler/select/geospatial.py,sha256=
|
|
31
|
+
ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
|
|
28
32
|
ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
|
|
29
|
-
ocf_data_sampler/select/select_spatial_slice.py,sha256=
|
|
33
|
+
ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
|
|
30
34
|
ocf_data_sampler/select/select_time_slice.py,sha256=41cch1fQr59fZgv7UHsNGc3OvoynrixT3bmr3_1d7cU,6628
|
|
35
|
+
ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
|
|
36
|
+
ocf_data_sampler/select/time_slice_for_dataset.py,sha256=5gcTGgQ1D524OhullNRWq3hxCwl2SoliGR210G-62JA,4216
|
|
31
37
|
ocf_data_sampler/torch_datasets/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
32
|
-
ocf_data_sampler/torch_datasets/
|
|
38
|
+
ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=GA-tGZLEMNAqX5Zun_7tPcTWVxlVtwejC9zfXPECwSk,4989
|
|
39
|
+
ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=TpHALGU7hpo3iLbvD0nkoY6zu94Vq99W1V1qSGEcIW8,5552
|
|
40
|
+
ocf_data_sampler/torch_datasets/site.py,sha256=1k0fWXYwAAIWG5DX_j3tgNfY8gglfPGLNzNlZd8EnJs,6631
|
|
41
|
+
ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=vP25e7DpWAu4dACTFMJZm0bi304iUFdi1XySAmxi_c0,4159
|
|
42
|
+
scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
|
|
33
43
|
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
|
-
tests/conftest.py,sha256=
|
|
35
|
-
tests/config/test_config.py,sha256=
|
|
44
|
+
tests/conftest.py,sha256=ZRktySCynj3NBbFRR4EFNLRLFMErkQsC-qQlmQzhbRg,7360
|
|
45
|
+
tests/config/test_config.py,sha256=C8NppoEVCMKxTTUf3o_z1Jb_I2DDH75XKpQ9x45U3Hw,5090
|
|
36
46
|
tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
|
|
37
47
|
tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
|
|
38
48
|
tests/load/test_load_satellite.py,sha256=STX5AqqmOAgUgE9R1xyq_sM3P1b8NKdGjO-hDhayfxM,524
|
|
49
|
+
tests/load/test_load_sites.py,sha256=T9lSEnGPI8FQISudVYHHNTHeplNS62Vrx48jaZ6J_Jo,364
|
|
39
50
|
tests/numpy_batch/test_gsp.py,sha256=VANXV32K8aLX4dCdhCUnDorJmyNN-Bjc7Wc1N-RzWEk,548
|
|
40
51
|
tests/numpy_batch/test_nwp.py,sha256=Fnj7cR-VR2Z0kMu8SrgnIayjxWnPWrYFjWSjMmnrh4Y,1445
|
|
41
52
|
tests/numpy_batch/test_satellite.py,sha256=8a4ZwMLpsOmYKmwI1oW_su_hwkCNYMEJAEfa0dbsx1k,1179
|
|
@@ -46,9 +57,10 @@ tests/select/test_find_contiguous_time_periods.py,sha256=G6tJRJd0DMfH9EdfzlKWsmf
|
|
|
46
57
|
tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
|
|
47
58
|
tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
|
|
48
59
|
tests/select/test_select_time_slice.py,sha256=XC1J3DBBDnt81jcba5u-Hnd0yKv8GIQErLm-OECV6rs,10147
|
|
49
|
-
tests/torch_datasets/test_pvnet_uk_regional.py,sha256=
|
|
50
|
-
|
|
51
|
-
ocf_data_sampler-0.0.
|
|
52
|
-
ocf_data_sampler-0.0.
|
|
53
|
-
ocf_data_sampler-0.0.
|
|
54
|
-
ocf_data_sampler-0.0.
|
|
60
|
+
tests/torch_datasets/test_pvnet_uk_regional.py,sha256=8gxjJO8FhY-ImX6eGnihDFsa8fhU2Zb4bVJaToJwuwo,2653
|
|
61
|
+
tests/torch_datasets/test_site.py,sha256=yTv6tAT6lha5yLYJiC8DNms1dct8o_ObPV97dHZyT7I,2719
|
|
62
|
+
ocf_data_sampler-0.0.26.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
63
|
+
ocf_data_sampler-0.0.26.dist-info/METADATA,sha256=VRnSRX4dgDbz4k9bwSM66uqaHI4P97xC97_NsEIt5qU,5269
|
|
64
|
+
ocf_data_sampler-0.0.26.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
65
|
+
ocf_data_sampler-0.0.26.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
|
|
66
|
+
ocf_data_sampler-0.0.26.dist-info/RECORD,,
|
scripts/refactor_site.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
""" Helper functions for refactoring legacy site data """
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def legacy_format(data_ds, metadata_df):
    """Convert legacy-format site data into the new format.

    1. Adds new-style columns to ``metadata_df`` in place, when the legacy columns
       are present: ``site_id`` (copied from ``system_id``) and ``capacity_kwp``
       (``capacity_megawatts`` * 1000).
    2. If ``data_ds`` stores one data variable per site, named by the site id as a
       string (detected by a variable called "0"), it is re-formatted into a
       dataset with a ``site_id`` dimension. That dataset holds ``generation_kw``
       and a ``capacity_kwp`` time series for each site.

    Args:
        data_ds: Site generation data, either in the legacy one-variable-per-site
            layout or already in the new layout.
        metadata_df: Site metadata table. It is modified in place. In the legacy
            layout it must contain a row for every site variable in ``data_ds``.

    Returns:
        The re-formatted dataset, or ``data_ds`` unchanged if it is not in the
        legacy layout.
    """
    if "system_id" in metadata_df.columns:
        metadata_df["site_id"] = metadata_df["system_id"]

    if "capacity_megawatts" in metadata_df.columns:
        metadata_df["capacity_kwp"] = metadata_df["capacity_megawatts"] * 1000

    # only site data has the site_id as data variables.
    # We want to join them all together and create another coordinate called site_id
    if "0" in data_ds:
        # xarray is only needed for the legacy conversion (and was previously
        # used here without ever being imported).
        import xarray as xr

        gen_df = data_ds.to_dataframe()

        # Look up each data column's metadata row by site id, so that the
        # site_id labels and capacities follow the column order of the data
        # rather than assuming the metadata rows are in the same order.
        site_id_str = metadata_df["site_id"].astype(str)
        site_ids = []
        # Copy: assigning capacities must not overwrite the generation values
        capacity_df = gen_df.copy()
        for col in gen_df.columns:
            site_meta = metadata_df[site_id_str == col]
            site_ids.append(site_meta["site_id"].iloc[0])
            capacity_df[col] = site_meta["capacity_kwp"].iloc[0]

        coords = (
            ("time_utc", gen_df.index.values),
            ("site_id", site_ids),
        )
        gen_da = xr.DataArray(
            data=gen_df.values,
            coords=coords,
            name="generation_kw",
        )
        capacity_da = xr.DataArray(
            data=capacity_df.values,
            coords=coords,
            name="capacity_kwp",
        )
        data_ds = xr.Dataset(
            {
                "generation_kw": gen_da,
                "capacity_kwp": capacity_da,
            }
        )
    return data_ds
|
tests/config/test_config.py
CHANGED
|
@@ -10,13 +10,13 @@ from ocf_data_sampler.config import (
|
|
|
10
10
|
)
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
def
|
|
13
|
+
def test_default_configuration():
|
|
14
14
|
"""Test default pydantic class"""
|
|
15
15
|
|
|
16
16
|
_ = Configuration()
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
def
|
|
19
|
+
def test_load_yaml_configuration(test_config_filename):
|
|
20
20
|
"""
|
|
21
21
|
Test that yaml loading works for 'test_config.yaml'
|
|
22
22
|
and fails for an empty .yaml file
|
|
@@ -56,7 +56,7 @@ def test_yaml_save(test_config_filename):
|
|
|
56
56
|
assert test_config == tmp_config
|
|
57
57
|
|
|
58
58
|
|
|
59
|
-
def
|
|
59
|
+
def test_extra_field_error():
|
|
60
60
|
"""
|
|
61
61
|
Check an extra parameters in config causes error
|
|
62
62
|
"""
|
|
@@ -99,10 +99,11 @@ def test_incorrect_nwp_provider(test_config_filename):
|
|
|
99
99
|
|
|
100
100
|
configuration = load_yaml_configuration(test_config_filename)
|
|
101
101
|
|
|
102
|
-
configuration.input_data.nwp['ukv'].
|
|
102
|
+
configuration.input_data.nwp['ukv'].provider = "unexpected_provider"
|
|
103
103
|
with pytest.raises(Exception, match="NWP provider"):
|
|
104
104
|
_ = Configuration(**configuration.model_dump())
|
|
105
105
|
|
|
106
|
+
|
|
106
107
|
def test_incorrect_dropout(test_config_filename):
|
|
107
108
|
"""
|
|
108
109
|
Check a dropout timedelta over 0 causes error and 0 doesn't
|
|
@@ -119,6 +120,7 @@ def test_incorrect_dropout(test_config_filename):
|
|
|
119
120
|
configuration.input_data.nwp['ukv'].dropout_timedeltas_minutes = [0]
|
|
120
121
|
_ = Configuration(**configuration.model_dump())
|
|
121
122
|
|
|
123
|
+
|
|
122
124
|
def test_incorrect_dropout_fraction(test_config_filename):
|
|
123
125
|
"""
|
|
124
126
|
Check dropout fraction outside of range causes error
|
|
@@ -127,11 +129,12 @@ def test_incorrect_dropout_fraction(test_config_filename):
|
|
|
127
129
|
configuration = load_yaml_configuration(test_config_filename)
|
|
128
130
|
|
|
129
131
|
configuration.input_data.nwp['ukv'].dropout_fraction= 1.1
|
|
130
|
-
|
|
132
|
+
|
|
133
|
+
with pytest.raises(ValidationError, match="Input should be less than or equal to 1"):
|
|
131
134
|
_ = Configuration(**configuration.model_dump())
|
|
132
135
|
|
|
133
136
|
configuration.input_data.nwp['ukv'].dropout_fraction= -0.1
|
|
134
|
-
with pytest.raises(
|
|
137
|
+
with pytest.raises(ValidationError, match="Input should be greater than or equal to 0"):
|
|
135
138
|
_ = Configuration(**configuration.model_dump())
|
|
136
139
|
|
|
137
140
|
|
tests/conftest.py
CHANGED
|
@@ -6,6 +6,8 @@ import pytest
|
|
|
6
6
|
import xarray as xr
|
|
7
7
|
import tempfile
|
|
8
8
|
|
|
9
|
+
from ocf_data_sampler.config.model import Site
|
|
10
|
+
|
|
9
11
|
_top_test_directory = os.path.dirname(os.path.realpath(__file__))
|
|
10
12
|
|
|
11
13
|
@pytest.fixture()
|
|
@@ -197,6 +199,66 @@ def ds_uk_gsp():
|
|
|
197
199
|
})
|
|
198
200
|
|
|
199
201
|
|
|
202
|
+
@pytest.fixture(scope="session")
def data_sites() -> Site:
    """Make fake site data on disk and yield a matching ``Site`` config.

    Writes two files into a temporary directory that lives for the whole test
    session:
    - a NetCDF file with ``generation_kw`` and ``capacity_kwp`` variables over
      (``time_utc``, ``site_id``);
    - a CSV file of site metadata (``site_id``, ``capacity_kwp``, ``longitude``,
      ``latitude``).

    Yields:
        Site: config pointing at the generated NetCDF and metadata CSV files.
    """
    times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min")
    site_ids = list(range(0, 10))
    capacity_kwp_1d = np.array([0.1, 1.1, 4, 6, 8, 9, 15, 2, 3, 4])
    # these are quite specific for the fake satellite data.
    # Built from an integer range so there is exactly one value per site
    # (a float-step np.arange can gain or lose an element through rounding).
    site_offsets = 0.1 * np.arange(len(site_ids))
    longitude = -4 + site_offsets
    latitude = 51 + site_offsets

    # Seeded so the session fixture is reproducible between test runs
    rng = np.random.default_rng(seed=0)
    generation = rng.uniform(0, 200, size=(len(times), len(site_ids))).astype(np.float32)

    # repeat the per-site capacity for every timestamp -> (len(times), len(site_ids))
    capacity_kwp = np.tile(capacity_kwp_1d, (len(times), 1))

    coords = (
        ("time_utc", times),
        ("site_id", site_ids),
    )

    da_cap = xr.DataArray(
        capacity_kwp,
        coords=coords,
    )

    da_gen = xr.DataArray(
        generation,
        coords=coords,
    )

    # metadata
    meta_df = pd.DataFrame(columns=[], data=[])
    meta_df['site_id'] = site_ids
    meta_df['capacity_kwp'] = capacity_kwp_1d
    meta_df['longitude'] = longitude
    meta_df['latitude'] = latitude

    generation = xr.Dataset({
        "capacity_kwp": da_cap,
        "generation_kw": da_gen,
    })

    with tempfile.TemporaryDirectory() as tmpdir:
        filename = os.path.join(tmpdir, "sites.netcdf")
        filename_csv = os.path.join(tmpdir, "sites_metadata.csv")
        generation.to_netcdf(filename)
        meta_df.to_csv(filename_csv)

        site = Site(file_path=filename,
                    metadata_file_path=filename_csv,
                    time_resolution_minutes=30,
                    forecast_minutes=60,
                    history_minutes=30)

        # Yield inside the `with` so the files exist for the whole session
        yield site
|
|
260
|
+
|
|
261
|
+
|
|
200
262
|
@pytest.fixture(scope="session")
|
|
201
263
|
def uk_gsp_zarr_path(ds_uk_gsp):
|
|
202
264
|
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from ocf_data_sampler.load.site import open_site
|
|
2
|
+
import xarray as xr
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def test_open_site(data_sites):
    """Opening the fake site data gives a (time_utc, site_id) DataArray with site metadata coords."""
    site_da = open_site(data_sites)

    assert isinstance(site_da, xr.DataArray)
    assert site_da.dims == ("time_utc", "site_id")

    for coord_name in ("capacity_kwp", "latitude", "longitude"):
        assert coord_name in site_da.coords

    # one day of half-hourly timestamps (inclusive) x 10 sites
    assert site_da.shape == (49, 10)
|
|
@@ -11,9 +11,9 @@ def pvnet_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, uk_gsp_z
|
|
|
11
11
|
|
|
12
12
|
# adjust config to point to the zarr file
|
|
13
13
|
config = load_yaml_configuration(config_filename)
|
|
14
|
-
config.input_data.nwp['ukv'].
|
|
15
|
-
config.input_data.satellite.
|
|
16
|
-
config.input_data.gsp.
|
|
14
|
+
config.input_data.nwp['ukv'].zarr_path = nwp_ukv_zarr_path
|
|
15
|
+
config.input_data.satellite.zarr_path = sat_zarr_path
|
|
16
|
+
config.input_data.gsp.zarr_path = uk_gsp_zarr_path
|
|
17
17
|
|
|
18
18
|
filename = f"{tmp_path}/configuration.yaml"
|
|
19
19
|
save_yaml_configuration(config, filename)
|
|
@@ -60,7 +60,7 @@ def test_pvnet_no_gsp(pvnet_config_filename):
|
|
|
60
60
|
# load config
|
|
61
61
|
config = load_yaml_configuration(pvnet_config_filename)
|
|
62
62
|
# remove gsp
|
|
63
|
-
config.input_data.gsp.
|
|
63
|
+
config.input_data.gsp.zarr_path = ''
|
|
64
64
|
|
|
65
65
|
# save temp config file
|
|
66
66
|
with tempfile.NamedTemporaryFile() as temp_config_file:
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
from ocf_data_sampler.torch_datasets.site import SitesDataset
|
|
5
|
+
from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
|
|
6
|
+
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
|
|
7
|
+
from ocf_data_sampler.numpy_batch.site import SiteBatchKey
|
|
8
|
+
from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@pytest.fixture()
def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites):
    """Save a copy of the test config that uses fake site data instead of GSP; return its path."""
    config = load_yaml_configuration(config_filename)
    input_data = config.input_data

    # point NWP and satellite inputs at the generated zarr test data
    input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
    input_data.satellite.zarr_path = sat_zarr_path

    # use the fake sites in place of GSP
    input_data.site = data_sites
    input_data.gsp = None

    out_path = f"{tmp_path}/configuration.yaml"
    save_yaml_configuration(config, out_path)
    return out_path
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_site(site_config_filename):

    # Create dataset object
    dataset = SitesDataset(site_config_filename)

    # presumably 10 fake sites x 41 valid t0 times
    assert len(dataset) == 10 * 41
    # TODO check 41

    # Generate a sample
    sample = dataset[0]

    assert isinstance(sample, dict)

    for key in [
        NWPBatchKey.nwp,
        SatelliteBatchKey.satellite_actual,
        SiteBatchKey.generation,
        SiteBatchKey.site_solar_azimuth,
        SiteBatchKey.site_solar_elevation,
    ]:
        assert key in sample

    for nwp_source in ["ukv"]:
        assert nwp_source in sample[NWPBatchKey.nwp]

    # check the shape of the data is correct
    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
    assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
    assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
    # 30 minutes history + 60 minutes forecast of 30 minute site data (inclusive)
    assert sample[SiteBatchKey.generation].shape == (4,)
    # Solar angles have same shape as the site generation data
    assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,)
    assert sample[SiteBatchKey.site_solar_elevation].shape == (4,)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_site_time_filter_start(site_config_filename):
    """A start time after all of the fake data leaves the dataset empty."""
    filtered = SitesDataset(site_config_filename, start_time="2024-01-01")
    assert len(filtered) == 0
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def test_site_time_filter_end(site_config_filename):
    """An end time before all of the fake data leaves the dataset empty."""
    filtered = SitesDataset(site_config_filename, end_time="2000-01-01")
    assert len(filtered) == 0
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def test_site_get_sample(site_config_filename):
    """A sample can be built for an explicit t0 and site id."""
    dataset = SitesDataset(site_config_filename)
    assert len(dataset) == 410

    t0 = pd.Timestamp("2023-01-01 12:00")
    _ = dataset.get_sample(t0=t0, site_id=1)
|
|
File without changes
|