ocf-data-sampler 0.0.37__py3-none-any.whl → 0.0.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/constants.py +38 -0
- ocf_data_sampler/numpy_batch/nwp.py +0 -1
- ocf_data_sampler/numpy_batch/satellite.py +2 -1
- ocf_data_sampler/torch_datasets/process_and_combine.py +11 -4
- {ocf_data_sampler-0.0.37.dist-info → ocf_data_sampler-0.0.38.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.0.37.dist-info → ocf_data_sampler-0.0.38.dist-info}/RECORD +10 -9
- tests/torch_datasets/test_process_and_combine.py +165 -0
- {ocf_data_sampler-0.0.37.dist-info → ocf_data_sampler-0.0.38.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.0.37.dist-info → ocf_data_sampler-0.0.38.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.0.37.dist-info → ocf_data_sampler-0.0.38.dist-info}/top_level.txt +0 -0
ocf_data_sampler/constants.py
CHANGED
|
@@ -28,6 +28,7 @@ class NWPStatDict(dict):
|
|
|
28
28
|
f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
|
|
29
29
|
)
|
|
30
30
|
|
|
31
|
+
|
|
31
32
|
# ------ UKV
|
|
32
33
|
# Means and std computed WITH version_7 and higher, MetOffice values
|
|
33
34
|
UKV_STD = {
|
|
@@ -49,6 +50,7 @@ UKV_STD = {
|
|
|
49
50
|
"prmsl": 1252.71790539,
|
|
50
51
|
"prate": 0.00021497,
|
|
51
52
|
}
|
|
53
|
+
|
|
52
54
|
UKV_MEAN = {
|
|
53
55
|
"cdcb": 1412.26599062,
|
|
54
56
|
"lcc": 50.08362643,
|
|
@@ -97,6 +99,7 @@ ECMWF_STD = {
|
|
|
97
99
|
"diff_duvrs": 81605.25,
|
|
98
100
|
"diff_sr": 818950.6875,
|
|
99
101
|
}
|
|
102
|
+
|
|
100
103
|
ECMWF_MEAN = {
|
|
101
104
|
"dlwrf": 27187026.0,
|
|
102
105
|
"dswrf": 11458988.0,
|
|
@@ -133,3 +136,38 @@ NWP_MEANS = NWPStatDict(
|
|
|
133
136
|
ecmwf=ECMWF_MEAN,
|
|
134
137
|
)
|
|
135
138
|
|
|
139
|
+
# ------ Satellite
|
|
140
|
+
# RSS Mean and std values from randomised 20% of 2020 imagery
|
|
141
|
+
|
|
142
|
+
RSS_STD = {
|
|
143
|
+
"HRV": 0.11405209,
|
|
144
|
+
"IR_016": 0.21462157,
|
|
145
|
+
"IR_039": 0.04618041,
|
|
146
|
+
"IR_087": 0.06687243,
|
|
147
|
+
"IR_097": 0.0468558,
|
|
148
|
+
"IR_108": 0.17482725,
|
|
149
|
+
"IR_120": 0.06115861,
|
|
150
|
+
"IR_134": 0.04492306,
|
|
151
|
+
"VIS006": 0.12184761,
|
|
152
|
+
"VIS008": 0.13090034,
|
|
153
|
+
"WV_062": 0.16111417,
|
|
154
|
+
"WV_073": 0.12924142,
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
RSS_MEAN = {
|
|
158
|
+
"HRV": 0.09298719,
|
|
159
|
+
"IR_016": 0.17594202,
|
|
160
|
+
"IR_039": 0.86167645,
|
|
161
|
+
"IR_087": 0.7719318,
|
|
162
|
+
"IR_097": 0.8014212,
|
|
163
|
+
"IR_108": 0.71254843,
|
|
164
|
+
"IR_120": 0.89058584,
|
|
165
|
+
"IR_134": 0.944365,
|
|
166
|
+
"VIS006": 0.09633306,
|
|
167
|
+
"VIS008": 0.11426069,
|
|
168
|
+
"WV_062": 0.7359355,
|
|
169
|
+
"WV_073": 0.62479186,
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
RSS_STD = _to_data_array(RSS_STD)
|
|
173
|
+
RSS_MEAN = _to_data_array(RSS_MEAN)
|
|
@@ -13,6 +13,7 @@ class SatelliteBatchKey:
|
|
|
13
13
|
|
|
14
14
|
def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
|
|
15
15
|
"""Convert from Xarray to NumpyBatch"""
|
|
16
|
+
|
|
16
17
|
example = {
|
|
17
18
|
SatelliteBatchKey.satellite_actual: da.values,
|
|
18
19
|
SatelliteBatchKey.time_utc: da.time_utc.values.astype(float),
|
|
@@ -27,4 +28,4 @@ def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None
|
|
|
27
28
|
if t0_idx is not None:
|
|
28
29
|
example[SatelliteBatchKey.t0_idx] = t0_idx
|
|
29
30
|
|
|
30
|
-
return example
|
|
31
|
+
return example
|
|
@@ -4,7 +4,7 @@ import xarray as xr
|
|
|
4
4
|
from typing import Tuple
|
|
5
5
|
|
|
6
6
|
from ocf_data_sampler.config import Configuration
|
|
7
|
-
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
|
|
7
|
+
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
|
|
8
8
|
from ocf_data_sampler.numpy_batch import (
|
|
9
9
|
convert_nwp_to_numpy_batch,
|
|
10
10
|
convert_satellite_to_numpy_batch,
|
|
@@ -25,8 +25,8 @@ def process_and_combine_datasets(
|
|
|
25
25
|
location: Location,
|
|
26
26
|
target_key: str = 'gsp'
|
|
27
27
|
) -> dict:
|
|
28
|
-
"""Normalize and convert data to numpy arrays"""
|
|
29
28
|
|
|
29
|
+
"""Normalise and convert data to numpy arrays"""
|
|
30
30
|
numpy_modalities = []
|
|
31
31
|
|
|
32
32
|
if "nwp" in dataset_dict:
|
|
@@ -37,19 +37,23 @@ def process_and_combine_datasets(
|
|
|
37
37
|
# Standardise
|
|
38
38
|
provider = config.input_data.nwp[nwp_key].provider
|
|
39
39
|
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
|
|
40
|
+
|
|
40
41
|
# Convert to NumpyBatch
|
|
41
42
|
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
|
|
42
43
|
|
|
43
44
|
# Combine the NWPs into NumpyBatch
|
|
44
45
|
numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})
|
|
45
46
|
|
|
47
|
+
|
|
46
48
|
if "sat" in dataset_dict:
|
|
47
|
-
#
|
|
49
|
+
# Standardise
|
|
48
50
|
da_sat = dataset_dict["sat"]
|
|
51
|
+
da_sat = (da_sat - RSS_MEAN) / RSS_STD
|
|
49
52
|
|
|
50
53
|
# Convert to NumpyBatch
|
|
51
54
|
numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
|
|
52
55
|
|
|
56
|
+
|
|
53
57
|
gsp_config = config.input_data.gsp
|
|
54
58
|
|
|
55
59
|
if "gsp" in dataset_dict:
|
|
@@ -93,6 +97,7 @@ def process_and_combine_datasets(
|
|
|
93
97
|
|
|
94
98
|
return combined_sample
|
|
95
99
|
|
|
100
|
+
|
|
96
101
|
def process_and_combine_site_sample_dict(
|
|
97
102
|
dataset_dict: dict,
|
|
98
103
|
config: Configuration,
|
|
@@ -119,8 +124,9 @@ def process_and_combine_site_sample_dict(
|
|
|
119
124
|
data_arrays.append((f"nwp-{provider}", da_nwp))
|
|
120
125
|
|
|
121
126
|
if "sat" in dataset_dict:
|
|
122
|
-
#
|
|
127
|
+
# Standardise
|
|
123
128
|
da_sat = dataset_dict["sat"]
|
|
129
|
+
da_sat = (da_sat - RSS_MEAN) / RSS_STD
|
|
124
130
|
data_arrays.append(("satellite", da_sat))
|
|
125
131
|
|
|
126
132
|
if "site" in dataset_dict:
|
|
@@ -143,6 +149,7 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
|
|
|
143
149
|
combined_dict.update(d)
|
|
144
150
|
return combined_dict
|
|
145
151
|
|
|
152
|
+
|
|
146
153
|
def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
|
|
147
154
|
"""
|
|
148
155
|
Combine a list of DataArrays into a single Dataset with unique naming conventions.
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
2
|
-
ocf_data_sampler/constants.py,sha256=
|
|
2
|
+
ocf_data_sampler/constants.py,sha256=G2VfkE_-veq_0hNBQQOQCtCsfC37O5-QG9mJWEmln5s,4153
|
|
3
3
|
ocf_data_sampler/utils.py,sha256=rKA0BHAyAG4f90zEcgxp25EEYrXS-aOVNzttZ6Mzv2k,250
|
|
4
4
|
ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
|
|
5
5
|
ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
|
|
@@ -21,8 +21,8 @@ ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3tx
|
|
|
21
21
|
ocf_data_sampler/numpy_batch/__init__.py,sha256=8MgRF29rK9bKP4b4iHakaoGwBKUcjWZ-VFKjCcq53QA,336
|
|
22
22
|
ocf_data_sampler/numpy_batch/collate.py,sha256=KyWdDi8AXD5YiokXXiqr2_X1SC1me1GrhnQMelg0Qx8,2202
|
|
23
23
|
ocf_data_sampler/numpy_batch/gsp.py,sha256=QjQ25JmtufvdiSsxUkBTPhxouYGWPnnWze8pXr_aBno,960
|
|
24
|
-
ocf_data_sampler/numpy_batch/nwp.py,sha256=
|
|
25
|
-
ocf_data_sampler/numpy_batch/satellite.py,sha256=
|
|
24
|
+
ocf_data_sampler/numpy_batch/nwp.py,sha256=bEvBB9xGf7B8okPBZ-eZLK4PBWA0nvmmEFiN49dgqPU,1254
|
|
25
|
+
ocf_data_sampler/numpy_batch/satellite.py,sha256=VKo8eiSIcYhAdHHBUH697HMz7rBv6S9XZ6_XCZ-qG4Y,905
|
|
26
26
|
ocf_data_sampler/numpy_batch/site.py,sha256=CWI0efUl8SrnGm0VNGdGwAqrmlT1XaVbJIUE2hSOz9E,744
|
|
27
27
|
ocf_data_sampler/numpy_batch/sun_position.py,sha256=zw2bjtcjsm_tvKk0r_MZmgfYUJLHuLjLly2sMjwP3XI,1606
|
|
28
28
|
ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
|
|
@@ -36,7 +36,7 @@ ocf_data_sampler/select/select_time_slice.py,sha256=D5P_cSvnv8Qs49K5au7lPxDr9U_V
|
|
|
36
36
|
ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
|
|
37
37
|
ocf_data_sampler/select/time_slice_for_dataset.py,sha256=LMw8KnOCKnPjD0m4UubAWERpaiQtzRKkI2cSh5a0A-M,4335
|
|
38
38
|
ocf_data_sampler/torch_datasets/__init__.py,sha256=nJUa2KzVa84ZoM0PT2AbDz26ennmAYc7M7WJVfypPMs,85
|
|
39
|
-
ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=
|
|
39
|
+
ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=ImfU4I75x7A57KCShWj6dr62tNtJqJ0ImKRiT0hijIQ,7564
|
|
40
40
|
ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=QRFqbdfNchVWj4y70n-rJdFvFGvQj-WpZLdFqWjnOTw,5543
|
|
41
41
|
ocf_data_sampler/torch_datasets/site.py,sha256=NYuhgm9ti9SRt1dcb_WrFYYo14NgVdOsaoPbc5FsnaA,6560
|
|
42
42
|
ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
|
|
@@ -59,10 +59,11 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
|
|
|
59
59
|
tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
|
|
60
60
|
tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
|
|
61
61
|
tests/select/test_select_time_slice.py,sha256=K1EJR5TwZa9dJf_YTEHxGtvs398iy1xS2lr1BgJZkoo,9603
|
|
62
|
+
tests/torch_datasets/test_process_and_combine.py,sha256=SWmrI59JVfMnHK78N5yhKzQR8b5kJ8TeMZke9Mlnc-o,5717
|
|
62
63
|
tests/torch_datasets/test_pvnet_uk_regional.py,sha256=eqy0nQOWoHnqltlJlGmRlgIiIzPEwOC6o5A6GARryKA,2118
|
|
63
64
|
tests/torch_datasets/test_site.py,sha256=YuVjWTI14_kmEOx23XE5J_RZ8UalCKD2xRv6mqYizB8,2872
|
|
64
|
-
ocf_data_sampler-0.0.
|
|
65
|
-
ocf_data_sampler-0.0.
|
|
66
|
-
ocf_data_sampler-0.0.
|
|
67
|
-
ocf_data_sampler-0.0.
|
|
68
|
-
ocf_data_sampler-0.0.
|
|
65
|
+
ocf_data_sampler-0.0.38.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
66
|
+
ocf_data_sampler-0.0.38.dist-info/METADATA,sha256=YbU2ymHq94ZLsyjlD1ZdKoYpVVDzUUmyWN7xRDBvQDM,10290
|
|
67
|
+
ocf_data_sampler-0.0.38.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
|
68
|
+
ocf_data_sampler-0.0.38.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
|
|
69
|
+
ocf_data_sampler-0.0.38.dist-info/RECORD,,
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import tempfile
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import xarray as xr
|
|
7
|
+
import dask.array as da
|
|
8
|
+
|
|
9
|
+
from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
|
|
10
|
+
from ocf_data_sampler.config import Configuration
|
|
11
|
+
from ocf_data_sampler.select.location import Location
|
|
12
|
+
from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey
|
|
13
|
+
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
|
|
14
|
+
|
|
15
|
+
from ocf_data_sampler.torch_datasets.process_and_combine import (
|
|
16
|
+
process_and_combine_datasets,
|
|
17
|
+
process_and_combine_site_sample_dict,
|
|
18
|
+
merge_dicts,
|
|
19
|
+
fill_nans_in_arrays,
|
|
20
|
+
compute,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def test_process_and_combine_datasets(pvnet_config_filename):
|
|
25
|
+
|
|
26
|
+
# Load in config for function and define location
|
|
27
|
+
config = load_yaml_configuration(pvnet_config_filename)
|
|
28
|
+
t0 = pd.Timestamp("2024-01-01 00:00")
|
|
29
|
+
location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
|
|
30
|
+
|
|
31
|
+
nwp_data = xr.DataArray(
|
|
32
|
+
np.random.rand(4, 2, 2, 2),
|
|
33
|
+
dims=["time_utc", "channel", "y", "x"],
|
|
34
|
+
coords={
|
|
35
|
+
"time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
|
|
36
|
+
"channel": ["t2m", "dswrf"],
|
|
37
|
+
"step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
|
|
38
|
+
"init_time_utc": pd.Timestamp("2024-01-01 00:00")
|
|
39
|
+
}
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
sat_data = xr.DataArray(
|
|
43
|
+
np.random.rand(7, 1, 2, 2),
|
|
44
|
+
dims=["time_utc", "channel", "y", "x"],
|
|
45
|
+
coords={
|
|
46
|
+
"time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
|
|
47
|
+
"channel": ["HRV"],
|
|
48
|
+
"x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
|
|
49
|
+
"y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
|
|
50
|
+
}
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
# Combine as dict
|
|
54
|
+
dataset_dict = {
|
|
55
|
+
"nwp": {"ukv": nwp_data},
|
|
56
|
+
"sat": sat_data
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
# Call relevant function
|
|
60
|
+
result = process_and_combine_datasets(dataset_dict, config, t0, location)
|
|
61
|
+
|
|
62
|
+
# Assert result is dict - check and validate
|
|
63
|
+
assert isinstance(result, dict)
|
|
64
|
+
assert NWPBatchKey.nwp in result
|
|
65
|
+
assert result[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
|
|
66
|
+
assert result[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def test_merge_dicts():
|
|
70
|
+
"""Test merge_dicts function"""
|
|
71
|
+
dict1 = {"a": 1, "b": 2}
|
|
72
|
+
dict2 = {"c": 3, "d": 4}
|
|
73
|
+
dict3 = {"e": 5}
|
|
74
|
+
|
|
75
|
+
result = merge_dicts([dict1, dict2, dict3])
|
|
76
|
+
assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
|
|
77
|
+
|
|
78
|
+
# Test key overwriting
|
|
79
|
+
dict4 = {"a": 10, "f": 6}
|
|
80
|
+
result = merge_dicts([dict1, dict4])
|
|
81
|
+
assert result["a"] == 10
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def test_fill_nans_in_arrays():
|
|
85
|
+
"""Test the fill_nans_in_arrays function"""
|
|
86
|
+
array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
|
|
87
|
+
nested_dict = {
|
|
88
|
+
"array1": array_with_nans,
|
|
89
|
+
"nested": {
|
|
90
|
+
"array2": np.array([np.nan, 2.0, np.nan, 4.0])
|
|
91
|
+
},
|
|
92
|
+
"string_key": "not_an_array"
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
result = fill_nans_in_arrays(nested_dict)
|
|
96
|
+
|
|
97
|
+
assert not np.isnan(result["array1"]).any()
|
|
98
|
+
assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
|
|
99
|
+
assert not np.isnan(result["nested"]["array2"]).any()
|
|
100
|
+
assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
|
|
101
|
+
assert result["string_key"] == "not_an_array"
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def test_compute():
|
|
105
|
+
"""Test compute function with dask array"""
|
|
106
|
+
da_dask = xr.DataArray(da.random.random((5, 5)))
|
|
107
|
+
|
|
108
|
+
# Create a nested dictionary with dask array
|
|
109
|
+
nested_dict = {
|
|
110
|
+
"array1": da_dask,
|
|
111
|
+
"nested": {
|
|
112
|
+
"array2": da_dask
|
|
113
|
+
}
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
# Ensure initial data is lazy - i.e. not yet computed
|
|
117
|
+
assert not isinstance(nested_dict["array1"].data, np.ndarray)
|
|
118
|
+
assert not isinstance(nested_dict["nested"]["array2"].data, np.ndarray)
|
|
119
|
+
|
|
120
|
+
# Call the compute function
|
|
121
|
+
result = compute(nested_dict)
|
|
122
|
+
|
|
123
|
+
# Assert that the result is an xarray DataArray and no longer lazy
|
|
124
|
+
assert isinstance(result["array1"], xr.DataArray)
|
|
125
|
+
assert isinstance(result["nested"]["array2"], xr.DataArray)
|
|
126
|
+
assert isinstance(result["array1"].data, np.ndarray)
|
|
127
|
+
assert isinstance(result["nested"]["array2"].data, np.ndarray)
|
|
128
|
+
|
|
129
|
+
# Ensure there no NaN values in computed data
|
|
130
|
+
assert not np.isnan(result["array1"].data).any()
|
|
131
|
+
assert not np.isnan(result["nested"]["array2"].data).any()
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def test_process_and_combine_site_sample_dict(pvnet_config_filename):
|
|
135
|
+
# Load config
|
|
136
|
+
config = load_yaml_configuration(pvnet_config_filename)
|
|
137
|
+
|
|
138
|
+
# Specify minimal structure for testing
|
|
139
|
+
raw_nwp_values = np.random.rand(4, 1, 2, 2) # Single channel
|
|
140
|
+
site_dict = {
|
|
141
|
+
"nwp": {
|
|
142
|
+
"ukv": xr.DataArray(
|
|
143
|
+
raw_nwp_values,
|
|
144
|
+
dims=["time_utc", "channel", "y", "x"],
|
|
145
|
+
coords={
|
|
146
|
+
"time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
|
|
147
|
+
"channel": ["dswrf"], # Single channel
|
|
148
|
+
},
|
|
149
|
+
)
|
|
150
|
+
}
|
|
151
|
+
}
|
|
152
|
+
print(f"Input site_dict: {site_dict}")
|
|
153
|
+
|
|
154
|
+
# Call function
|
|
155
|
+
result = process_and_combine_site_sample_dict(site_dict, config)
|
|
156
|
+
|
|
157
|
+
# Assert to validate output structure
|
|
158
|
+
assert isinstance(result, xr.Dataset), "Result should be an xarray.Dataset"
|
|
159
|
+
assert len(result.data_vars) > 0, "Dataset should contain data variables"
|
|
160
|
+
|
|
161
|
+
# Validate variable via assertion and shape of such
|
|
162
|
+
expected_variable = "nwp-ukv"
|
|
163
|
+
assert expected_variable in result.data_vars, f"Expected variable '{expected_variable}' not found"
|
|
164
|
+
nwp_result = result[expected_variable]
|
|
165
|
+
assert nwp_result.shape == (4, 1, 2, 2), f"Unexpected shape for '{expected_variable}': {nwp_result.shape}"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|