ocf-data-sampler 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ocf-data-sampler might be problematic.
- ocf_data_sampler/numpy_sample/collate.py +46 -57
- ocf_data_sampler/sample/uk_regional.py +3 -1
- ocf_data_sampler/select/fill_time_periods.py +1 -1
- ocf_data_sampler/select/time_slice_for_dataset.py +16 -13
- ocf_data_sampler/torch_datasets/datasets/__init__.py +2 -7
- ocf_data_sampler/torch_datasets/datasets/{pvnet_uk_regional.py → pvnet_uk.py} +114 -16
- {ocf_data_sampler-0.1.0.dist-info → ocf_data_sampler-0.1.2.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.1.0.dist-info → ocf_data_sampler-0.1.2.dist-info}/RECORD +17 -18
- tests/conftest.py +69 -70
- tests/load/test_load_satellite.py +3 -3
- tests/numpy_sample/test_collate.py +4 -9
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -2
- tests/torch_datasets/test_pvnet_uk.py +166 -0
- tests/torch_datasets/test_site.py +47 -36
- tests/torch_datasets/conftest.py +0 -18
- tests/torch_datasets/test_pvnet_uk_regional.py +0 -136
- {ocf_data_sampler-0.1.0.dist-info → ocf_data_sampler-0.1.2.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.1.0.dist-info → ocf_data_sampler-0.1.2.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.1.0.dist-info → ocf_data_sampler-0.1.2.dist-info}/top_level.txt +0 -0
ocf_data_sampler/numpy_sample/collate.py
CHANGED
@@ -1,75 +1,64 @@
-from ocf_data_sampler.numpy_sample import NWPSampleKey
-
 import numpy as np
-import logging
-from typing import Union
-
-logger = logging.getLogger(__name__)
-
-
-
-def stack_np_samples_into_batch(dict_list):
-    # """
-    # Stacks Numpy samples into a batch
 
-    # Args:
-    #     dict_list: A list of dict-like Numpy samples to stack
 
-
-
-    # """
+def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
+    """Stacks list of dict samples into a dict where all samples are joined along a new axis
 
-
-
+    Args:
+        dict_list: A list of dict-like samples to stack
 
-
-
-
+    Returns:
+        Dict of the samples stacked with new batch dimension on axis 0
+    """
+
+    batch = {}
+
+    keys = list(dict_list[0].keys())
+
+    for key in keys:
+        # NWP is nested so treat separately
+        if key == "nwp":
+            batch["nwp"] = {}
+
+            # Unpack NWP provider keys
+            nwp_providers = list(dict_list[0]["nwp"].keys())
+
+            for nwp_provider in nwp_providers:
+                # Keys can be different for different NWPs
+                nwp_keys = list(dict_list[0]["nwp"][nwp_provider].keys())
+
+                # Create dict to store NWP batch for this provider
+                nwp_provider_batch = {}
+
+                for nwp_key in nwp_keys:
+                    # Stack values under each NWP key for this provider
+                    nwp_provider_batch[nwp_key] = stack_data_list(
+                        [d["nwp"][nwp_provider][nwp_key] for d in dict_list],
+                        nwp_key,
+                    )
+
+                batch["nwp"][nwp_provider] = nwp_provider_batch
 
-    # Process - handle NWP separately due to nested structure
-    for sample_key in sample_keys:
-        if sample_key == "nwp":
-            sample["nwp"] = process_nwp_data(dict_list)
         else:
-
-            sample[sample_key] = stack_data_list([d[sample_key] for d in dict_list], sample_key)
-    return sample
+            batch[key] = stack_data_list([d[key] for d in dict_list], key)
 
+    return batch
 
-def process_nwp_data(dict_list):
-    """Stacks data for NWP, handling nested structure"""
-
-    nwp_sample = {}
-    nwp_sources = dict_list[0]["nwp"].keys()
 
-
-        nested_keys = dict_list[0]["nwp"][nwp_source].keys()
-        nwp_sample[nwp_source] = {
-            key: stack_data_list([d["nwp"][nwp_source][key] for d in dict_list], key)
-            for key in nested_keys
-        }
-    return nwp_sample
+def _key_is_constant(key: str):
+    return key.endswith("t0_idx") or key.endswith("channel_names")
 
-def _key_is_constant(sample_key):
-    return sample_key.endswith("t0_idx") or sample_key == NWPSampleKey.channel_names
 
-
-
-    """How to combine data entries for each key
+def stack_data_list(data_list: list, key: str):
+    """Stack a sequence of data elements along a new axis
 
     Args:
-        data_list: List of data
-
+        data_list: List of data elements to combine
+        key: string identifying the data type
     """
-    if _key_is_constant(
+    if _key_is_constant(key):
         # These are always the same for all examples.
         return data_list[0]
-
+    else:
         return np.stack(data_list)
-
-        logger.debug(f"Could not stack the following shapes together, ({sample_key})")
-        shapes = [example.shape for example in data_list]
-        logger.debug(shapes)
-        logger.error(e)
-        raise e
+
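The refactor above folds the old process_nwp_data helper into stack_np_samples_into_batch itself. A minimal usage sketch of the new function, with invented stand-in samples (not part of the diff):

import numpy as np
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch

# Two identical invented samples with the nested "nwp" structure the function expects
sample = {
    "gsp": np.zeros(7),
    "nwp": {"ukv": {"nwp": np.zeros((4, 1, 2, 2)), "nwp_t0_idx": 1}},
}
batch = stack_np_samples_into_batch([sample, sample])

assert batch["gsp"].shape == (2, 7)  # new batch dimension on axis 0
assert batch["nwp"]["ukv"]["nwp"].shape == (2, 4, 1, 2, 2)
assert batch["nwp"]["ukv"]["nwp_t0_idx"] == 1  # keys ending in "t0_idx" are taken from the first sample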
ocf_data_sampler/sample/uk_regional.py
CHANGED
@@ -65,7 +65,9 @@ class UKRegionalSample(SampleBase):
            raise ValueError(f"Only .pt format is supported: {path.suffix}")
 
        instance = cls()
-
+        # TODO: We should move away from using torch.load(..., weights_only=False)
+        # This is not recommended
        instance._data = torch.load(path, weights_only=False)
        logger.debug(f"Successfully loaded UKRegionalSample from {path}")
        return instance
 
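The new TODO is well founded: torch.load with weights_only=False unpickles arbitrary objects, so loading an untrusted .pt file can execute code. A sketch of the safer direction, assuming the stored sample holds only tensors and plain containers; the helper name is hypothetical and not part of this package:

import torch

def load_sample_safely(path: str):
    # weights_only=True restricts deserialisation to tensors and primitive
    # containers, so an untrusted file cannot run arbitrary code on load
    try:
        return torch.load(path, weights_only=True)
    except Exception:
        # Hypothetical fallback: only for files you produced yourself
        return torch.load(path, weights_only=False)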
ocf_data_sampler/select/fill_time_periods.py
CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
 import numpy as np
 
 
-def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta):
+def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
     start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
     end_dts = pd.to_datetime(time_periods["end_dt"].values)
     date_ranges = [pd.date_range(start_dt, end_dt, freq=freq) for start_dt, end_dt in zip(start_dts, end_dts)]
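The only change here is the added return annotation. A small worked example of what the function computes, mirroring its body (the final concatenation into a single DatetimeIndex is implied by the annotation rather than shown in the hunk):

import pandas as pd

# One availability window from 00:07 to 00:30, filled at 5-minute steps
time_periods = pd.DataFrame({
    "start_dt": [pd.Timestamp("2024-01-01 00:07")],
    "end_dt": [pd.Timestamp("2024-01-01 00:30")],
})
freq = pd.Timedelta("5min")

start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)  # -> 00:10
end_dts = pd.to_datetime(time_periods["end_dt"].values)
date_ranges = [pd.date_range(s, e, freq=freq) for s, e in zip(start_dts, end_dts)]
# date_ranges[0] -> 00:10, 00:15, 00:20, 00:25, 00:30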
ocf_data_sampler/select/time_slice_for_dataset.py
CHANGED
@@ -1,5 +1,6 @@
 """ Slice datasets by time"""
 import pandas as pd
+import xarray as xr
 
 from ocf_data_sampler.config import Configuration
 from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
@@ -64,16 +65,8 @@ def slice_datasets_by_time(
 
     if "gsp" in datasets_dict:
         gsp_config = config.input_data.gsp
-
-
-            datasets_dict["gsp"],
-            t0,
-            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start=minutes(gsp_config.time_resolution_minutes),
-            interval_end=minutes(gsp_config.interval_end_minutes),
-        )
-
-        sliced_datasets_dict["gsp"] = select_time_slice(
+
+        da_gsp_past = select_time_slice(
             datasets_dict["gsp"],
             t0,
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
@@ -81,17 +74,27 @@ def slice_datasets_by_time(
             interval_end=minutes(0),
         )
 
-        # Dropout on the GSP, but not the future GSP
+        # Dropout on the past GSP, but not the future GSP
         gsp_dropout_time = draw_dropout_time(
             t0,
             dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
             dropout_frac=gsp_config.dropout_fraction,
         )
 
-
-
+        da_gsp_past = apply_dropout_time(
+            da_gsp_past,
             gsp_dropout_time
         )
+
+        da_gsp_future = select_time_slice(
+            datasets_dict["gsp"],
+            t0,
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            interval_start=minutes(gsp_config.time_resolution_minutes),
+            interval_end=minutes(gsp_config.interval_end_minutes),
+        )
+
+        sliced_datasets_dict["gsp"] = xr.concat([da_gsp_past, da_gsp_future], dim="time_utc")
 
     if "site" in datasets_dict:
         site_config = config.input_data.site
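The restructured GSP branch makes the intent explicit: slice the past (up to t0), apply dropout to it, slice the future separately, then rejoin along time. A standalone sketch of the same pattern, using invented toy data rather than the package's own selectors:

import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2024-01-01 00:00", periods=8, freq="30min")
da = xr.DataArray(np.arange(8.0), coords={"time_utc": times}, dims="time_utc")
t0 = times[3]

# Past window up to and including t0; future window strictly after it
da_past = da.sel(time_utc=slice(None, t0))
da_future = da.sel(time_utc=slice(t0 + pd.Timedelta("30min"), None))

# Dropout on the past only: NaN-out steps after a drawn dropout time
dropout_time = t0 - pd.Timedelta("30min")
da_past = da_past.where(da_past.time_utc <= dropout_time)

# Rejoin along the time axis, as the new code does with xr.concat
da_joined = xr.concat([da_past, da_future], dim="time_utc")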
ocf_data_sampler/torch_datasets/datasets/__init__.py
CHANGED
@@ -1,11 +1,6 @@
-from .pvnet_uk_regional import PVNetUKRegionalDataset
+from .pvnet_uk import PVNetUKRegionalDataset, PVNetUKConcurrentDataset
 
 from .site import (
     convert_netcdf_to_numpy_sample,
     SitesDataset
-)
-
-__all__ = [
-    'convert_netcdf_to_numpy_sample',
-    'SitesDataset'
-]
+)
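With the explicit __all__ dropped, both dataset classes remain importable from the subpackage, e.g.:

from ocf_data_sampler.torch_datasets.datasets import (
    PVNetUKRegionalDataset,
    PVNetUKConcurrentDataset,
)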
ocf_data_sampler/torch_datasets/datasets/{pvnet_uk_regional.py → pvnet_uk.py}
CHANGED
@@ -1,15 +1,20 @@
-"""Torch dataset for PVNet"""
+"""Torch dataset for UK PVNet"""
+
+import pkg_resources
 
 import numpy as np
 import pandas as pd
-import pkg_resources
 import xarray as xr
 from torch.utils.data import Dataset
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
-from ocf_data_sampler.select import
+from ocf_data_sampler.select import (
+    fill_time_periods,
+    Location,
+    slice_datasets_by_space,
+    slice_datasets_by_time,
+)
 from ocf_data_sampler.utils import minutes
-from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
 from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
 from ocf_data_sampler.numpy_sample import (
     convert_nwp_to_numpy_sample,
@@ -17,13 +22,16 @@ from ocf_data_sampler.numpy_sample import (
     convert_gsp_to_numpy_sample,
     make_sun_position_numpy_sample,
 )
+from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
+from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
+from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
+from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
+from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
 from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
     merge_dicts,
     fill_nans_in_arrays,
 )
-
-from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
-from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
+
 
 xr.set_options(keep_attrs=True)
 
@@ -65,9 +73,10 @@ def process_and_combine_datasets(
     gsp_config = config.input_data.gsp
 
     if "gsp" in dataset_dict:
-        da_gsp =
+        da_gsp = dataset_dict["gsp"]
         da_gsp = da_gsp / da_gsp.effective_capacity_mwp
-
+
+        # Convert to NumpyBatch
         numpy_modalities.append(
             convert_gsp_to_numpy_sample(
                 da_gsp,
@@ -105,6 +114,7 @@ def process_and_combine_datasets(
 
     return combined_sample
 
+
 def compute(xarray_dict: dict) -> dict:
     """Eagerly load a nested dictionary of xarray DataArrays"""
     for k, v in xarray_dict.items():
@@ -114,10 +124,8 @@ def compute(xarray_dict: dict) -> dict:
            xarray_dict[k] = v.compute(scheduler="single-threaded")
     return xarray_dict
 
-
-
-    config: Configuration,
-):
+
+def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
     """Find the t0 times where all of the requested input data is available
 
     Args:
@@ -167,7 +175,7 @@ class PVNetUKRegionalDataset(Dataset):
         self,
         config_filename: str,
         start_time: str | None = None,
-        end_time: str| None = None,
+        end_time: str | None = None,
         gsp_ids: list[int] | None = None,
     ):
         """A torch Dataset for creating PVNet UK GSP samples
@@ -253,7 +261,7 @@ class PVNetUKRegionalDataset(Dataset):
     def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
         """Generate a sample for the given coordinates.
 
-        Useful for users to generate samples
+        Useful for users to generate specific samples.
 
         Args:
             t0: init-time for sample
@@ -265,4 +273,94 @@ class PVNetUKRegionalDataset(Dataset):
 
         location = self.location_lookup[gsp_id]
 
-        return self._get_sample(t0, location)
+        return self._get_sample(t0, location)
+
+
+class PVNetUKConcurrentDataset(Dataset):
+    def __init__(
+        self,
+        config_filename: str,
+        start_time: str | None = None,
+        end_time: str | None = None,
+        gsp_ids: list[int] | None = None,
+    ):
+        """A torch Dataset for creating concurrent samples of PVNet UK regional data
+
+        Each concurrent sample includes the data from all GSPs for a single t0 time
+
+        Args:
+            config_filename: Path to the configuration file
+            start_time: Limit the init-times to be after this
+            end_time: Limit the init-times to be before this
+            gsp_ids: List of all GSP IDs included in each sample. Defaults to all
+        """
+
+        config = load_yaml_configuration(config_filename)
+
+        datasets_dict = get_dataset_dict(config)
+
+        # Get t0 times where all input data is available
+        valid_t0_times = find_valid_t0_times(datasets_dict, config)
+
+        # Filter t0 times to given range
+        if start_time is not None:
+            valid_t0_times = valid_t0_times[valid_t0_times >= pd.Timestamp(start_time)]
+
+        if end_time is not None:
+            valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
+
+        # Construct list of locations to sample from
+        locations = get_gsp_locations(gsp_ids)
+
+        # Assign coords and indices to self
+        self.valid_t0_times = valid_t0_times
+        self.locations = locations
+
+        # Assign config and input data to self
+        self.datasets_dict = datasets_dict
+        self.config = config
+
+    def __len__(self):
+        return len(self.valid_t0_times)
+
+    def _get_sample(self, t0: pd.Timestamp) -> dict:
+        """Generate a concurrent PVNet sample for given init-time
+
+        Args:
+            t0: init-time for sample
+        """
+        # Slice by time then load to avoid loading the data multiple times from disk
+        sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
+        sample_dict = compute(sample_dict)
+
+        gsp_samples = []
+
+        # Prepare sample for each GSP
+        for location in self.locations:
+            gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
+            gsp_numpy_sample = process_and_combine_datasets(
+                gsp_sample_dict, self.config, t0, location
+            )
+            gsp_samples.append(gsp_numpy_sample)
+
+        # Stack GSP samples
+        return stack_np_samples_into_batch(gsp_samples)
+
+    def __getitem__(self, idx):
+        return self._get_sample(self.valid_t0_times[idx])
+
+    def get_sample(self, t0: pd.Timestamp) -> dict:
+        """Generate a sample for the given init-time.
+
+        Useful for users to generate specific samples.
+
+        Args:
+            t0: init-time for sample
+        """
+        # Check data is available for init-time t0
+        assert t0 in self.valid_t0_times
+        return self._get_sample(t0)
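A short usage sketch for the new class; the config path is illustrative, and batch_size=None reflects that each item already contains all requested GSPs for one t0:

from torch.utils.data import DataLoader
from ocf_data_sampler.torch_datasets.datasets import PVNetUKConcurrentDataset

dataset = PVNetUKConcurrentDataset("config.yaml", gsp_ids=[1, 2, 3])

# Each item is already a batch over GSPs, so skip default collation
dataloader = DataLoader(dataset, batch_size=None)
sample = next(iter(dataloader))
print(sample["gsp"].shape)  # (num_gsps, time_steps), per the tests below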
{ocf_data_sampler-0.1.0.dist-info → ocf_data_sampler-0.1.2.dist-info}/RECORD
CHANGED
@@ -19,7 +19,7 @@ ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=2iR1Iy542lo51rC6XFLV-3pbUE68
 ocf_data_sampler/load/nwp/providers/ukv.py,sha256=79Bm7q-K_GJPYMy62SUIZbRWRF4-tIaB1dYPEgLD9vo,1207
 ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3txG1_bBmBdchlc-yA,848
 ocf_data_sampler/numpy_sample/__init__.py,sha256=nY5C6CcuxiWZ_jrXRzWtN7WyKXhJImSiVTIG6Rz4B_4,401
-ocf_data_sampler/numpy_sample/collate.py,sha256=
+ocf_data_sampler/numpy_sample/collate.py,sha256=Onl_aKhsZ4pbFJsh70orjsHk523GHxrpRirH2vJq_GA,1911
 ocf_data_sampler/numpy_sample/datetime_features.py,sha256=U-9uRplfZ7VYFA4qBduI8OkG2x_65RYIP8wrLG4i-Nw,1441
 ocf_data_sampler/numpy_sample/gsp.py,sha256=5UaWO_aGRRVQo82wnDaT4zBKHihOnIsXiwgPjM8vGFM,1005
 ocf_data_sampler/numpy_sample/nwp.py,sha256=_seQNWsut3IzPsrpipqImjnaM3XNHZCy5_5be6syivk,1297
@@ -29,32 +29,32 @@ ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrd
 ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
 ocf_data_sampler/sample/base.py,sha256=4U78tczCRsKMDwU4HkD20nyGyYjIBSZV5neF2mT--2M,1197
 ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
-ocf_data_sampler/sample/uk_regional.py,sha256=
+ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcprcDm50jCJw,4381
 ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
 ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
-ocf_data_sampler/select/fill_time_periods.py,sha256=
+ocf_data_sampler/select/fill_time_periods.py,sha256=h0XD1Ds_wUUoy-7bILxmN8AIbjlQ6YdXRKuCk_Is5jo,460
 ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=q7IaNfX95A3z9XHqbhgtkZ4Js1gn5K9Qyp6DVLbsL-Q,11093
 ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
 ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
 ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
 ocf_data_sampler/select/select_time_slice.py,sha256=9M-yvDv9K77XfEys_OIR31_aVB56sNWk3BnCnkCgcPI,4725
 ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
-ocf_data_sampler/select/time_slice_for_dataset.py,sha256=
-ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=
-ocf_data_sampler/torch_datasets/datasets/
+ocf_data_sampler/select/time_slice_for_dataset.py,sha256=Z7pOiilSHScxmBKZNG18K5J-S4ifdXXAYGZoHRHD3AY,4324
+ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
+ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=rodvVSR4sh8qZ2hLdI8qAc3lyxq5U7cVGfS4rRKCzbs,11944
 ocf_data_sampler/torch_datasets/datasets/site.py,sha256=5T8nkTMUHHFidZRuFOunYeKAqNuyZ8V7sikBoBOBwwA,16033
 ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
 ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
 scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tests/conftest.py,sha256=
+tests/conftest.py,sha256=RlC7YYtBLipUzFS1tQxela1SgHCxSpReUKEJ4429PwQ,7689
 tests/config/test_config.py,sha256=Vq_kTL5tJcwEP-hXD_Nah5O6cgafo99iX6Fw1AN5NDY,5288
 tests/config/test_save.py,sha256=rA_XVxP1pOxB--5Ebujz4T5o-VbcrCbg2VSlSq2iI0o,1318
 tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
 tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
-tests/load/test_load_satellite.py,sha256=
+tests/load/test_load_satellite.py,sha256=IQ8ISRZKCEoi8IsJoPpXZJTolD0mwjnl2E7762RM_PM,524
 tests/load/test_load_sites.py,sha256=T9lSEnGPI8FQISudVYHHNTHeplNS62Vrx48jaZ6J_Jo,364
-tests/numpy_sample/test_collate.py,sha256=
+tests/numpy_sample/test_collate.py,sha256=RqHCD5_LTRpe4r6kqC_2TKhmhM_IHYM0ZtFUvSjDqcM,654
 tests/numpy_sample/test_datetime_features.py,sha256=o4t3KeKFdGrOBQ77rNFcDuDMQSD23ileCS5T5AP3wG4,1769
 tests/numpy_sample/test_gsp.py,sha256=FLlq4SlJ-9cSRAepf4_ksA6PsUVKegnKEAc5pUojCJ0,1458
 tests/numpy_sample/test_nwp.py,sha256=yf4u7mAU0E3FQ4xAH6YjuHuHBzzFoXjHSFNkOVJUdSM,1455
@@ -69,12 +69,11 @@ tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4
 tests/test_sample/test_base.py,sha256=ljtB38MmscTGN6OvUgclBceNnfx6m7AN8iHYDml9XW4,2189
 tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
 tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
-tests/torch_datasets/
-tests/torch_datasets/
-tests/torch_datasets/
-
-ocf_data_sampler-0.1.
-ocf_data_sampler-0.1.
-ocf_data_sampler-0.1.
-ocf_data_sampler-0.1.
-ocf_data_sampler-0.1.0.dist-info/RECORD,,
+tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
+tests/torch_datasets/test_pvnet_uk.py,sha256=OzT9ArdnWPa3iJKggxc2-7npkDqWmZyS5pzM4M08NZU,5566
+tests/torch_datasets/test_site.py,sha256=5MH5zkHFJXekwpnV6nHuSxt_sRNu9_mxiUjfWqmEhr0,6966
+ocf_data_sampler-0.1.2.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ocf_data_sampler-0.1.2.dist-info/METADATA,sha256=tWyfIpvmOufUWc7LquOXZ6g5Le_WhwihhZQBSZ0WKhA,12173
+ocf_data_sampler-0.1.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ocf_data_sampler-0.1.2.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
+ocf_data_sampler-0.1.2.dist-info/RECORD,,
tests/conftest.py
CHANGED
@@ -1,14 +1,15 @@
+import pytest
+
 import os
 import numpy as np
 import pandas as pd
-import pytest
 import xarray as xr
-import
-from typing import Generator
+import dask.array
 
 from ocf_data_sampler.config.model import Site
 from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
 
+
 _top_test_directory = os.path.dirname(os.path.realpath(__file__))
 
 @pytest.fixture()
@@ -18,40 +19,27 @@ def test_config_filename():
 
 @pytest.fixture(scope="session")
 def config_filename():
-    return f"{
+    return f"{_top_test_directory}/test_data/configs/pvnet_test_config.yaml"
 
 
 @pytest.fixture(scope="session")
-def sat_zarr_path():
-
-    # Load dataset which only contains coordinates, but no data
-    ds = xr.open_zarr(
-        f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.zarr.zip"
-    ).compute()
-
-    # Add time coord
-    ds = ds.assign_coords(time=pd.date_range("2023-01-01 00:00", "2023-01-02 23:55", freq="5min"))
-
-    # Add data to dataset
-    ds["data"] = xr.DataArray(
-        np.zeros([len(ds[c]) for c in ds.coords], dtype=np.float32),
-        coords=ds.coords,
-    )
-
-    # Transpose to variables, time, y, x (just in case)
-    ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary")
+def session_tmp_path(tmp_path_factory):
+    return tmp_path_factory.mktemp("data")
 
-    # add 100,000 to x_geostationary, this to make sure the fix index is within the satellite image
-    ds["x_geostationary"] = ds["x_geostationary"] - 200_000
 
-
-
-
-    #
-
-
-
+@pytest.fixture(scope="session")
+def sat_zarr_path(session_tmp_path):
+
+    # Define coords for satellite-like dataset
+    variables = [
+        'IR_016', 'IR_039', 'IR_087', 'IR_097', 'IR_108', 'IR_120',
+        'IR_134', 'VIS006', 'VIS008', 'WV_062', 'WV_073',
+    ]
+    x = np.linspace(start=15002, stop=-1824245, num=100)
+    y = np.linspace(start=4191563, stop=5304712, num=100)
+    times = pd.date_range("2023-01-01 00:00", "2023-01-01 23:55", freq="5min")
+
+    area_string = (
        """msg_seviri_rss_3km:
            description: MSG SEVIRI Rapid Scanning Service area definition with 3 km resolution
            projection:
@@ -73,16 +61,31 @@ def sat_zarr_path():
                units: m
        """
    )
-
-    #
-
+
+    # Create satellite-like data with some NaNs
+    data = dask.array.zeros(
+        shape=(len(variables), len(times), len(y), len(x)),
+        chunks=(-1, 10, -1, -1),
+        dtype=np.float32,
+    )
+    data[:, 10, :, :] = np.nan
+
+    ds = xr.DataArray(
+        data=data,
+        coords=dict(
+            variable=variables,
+            time=times,
+            y_geostationary=y,
+            x_geostationary=x,
+        ),
+        attrs=dict(area=area_string),
+    ).to_dataset(name="data")
 
     # Save temporarily as a zarr
-
-
-    ds.to_zarr(zarr_path)
+    zarr_path = session_tmp_path / "test_sat.zarr"
+    ds.to_zarr(zarr_path)
 
-
+    yield zarr_path
 
 
 @pytest.fixture(scope="session")
@@ -112,7 +115,7 @@ def ds_nwp_ukv():
 
 
 @pytest.fixture(scope="session")
-def nwp_ukv_zarr_path(ds_nwp_ukv):
+def nwp_ukv_zarr_path(session_tmp_path, ds_nwp_ukv):
     ds = ds_nwp_ukv.chunk(
         {
             "init_time": 1,
@@ -122,10 +125,9 @@ def nwp_ukv_zarr_path(ds_nwp_ukv):
             "y": 50,
         }
     )
-
-
-
-    yield filename
+    zarr_path = session_tmp_path / "ukv_nwp.zarr"
+    ds.to_zarr(zarr_path)
+    yield zarr_path
 
 
 @pytest.fixture(scope="session")
@@ -155,7 +157,7 @@ def ds_nwp_ecmwf():
 
 
 @pytest.fixture(scope="session")
-def nwp_ecmwf_zarr_path(ds_nwp_ecmwf):
+def nwp_ecmwf_zarr_path(session_tmp_path, ds_nwp_ecmwf):
     ds = ds_nwp_ecmwf.chunk(
         {
             "init_time": 1,
@@ -165,10 +167,10 @@ def nwp_ecmwf_zarr_path(ds_nwp_ecmwf):
             "latitude": 50,
         }
     )
-
-
-
-
+
+    zarr_path = session_tmp_path / "ukv_ecmwf.zarr"
+    ds.to_zarr(zarr_path)
+    yield zarr_path
 
 
 @pytest.fixture(scope="session")
@@ -201,7 +203,7 @@ def ds_uk_gsp():
 
 
 @pytest.fixture(scope="session")
-def data_sites() -> Generator[Site, None, None]:
+def data_sites(session_tmp_path) -> Site:
     """
     Make fake data for sites
    Returns: filename for netcdf file, and csv metadata
@@ -245,30 +247,27 @@ def data_sites() -> Generator[Site, None, None]:
         "generation_kw": da_gen,
     })
 
-
-
-
-
-
-
-
-
-
-
-
-    )
+    filename = f"{session_tmp_path}/sites.netcdf"
+    filename_csv = f"{session_tmp_path}/sites_metadata.csv"
+    generation.to_netcdf(filename)
+    meta_df.to_csv(filename_csv)
+
+    site = Site(
+        file_path=filename,
+        metadata_file_path=filename_csv,
+        interval_start_minutes=-30,
+        interval_end_minutes=60,
+        time_resolution_minutes=30,
+    )
 
-
+    yield site
 
 
 @pytest.fixture(scope="session")
-def uk_gsp_zarr_path(ds_uk_gsp):
-
-
-
-    ds_uk_gsp.to_zarr(filename)
-    yield filename
+def uk_gsp_zarr_path(session_tmp_path, ds_uk_gsp):
+    zarr_path = session_tmp_path / "uk_gsp.zarr"
+    ds_uk_gsp.to_zarr(zarr_path)
+    yield zarr_path
 
 
 @pytest.fixture()
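The unifying change in these fixtures is that all on-disk test data now goes through one session-scoped session_tmp_path built with pytest's tmp_path_factory, since session-scoped fixtures cannot use the function-scoped tmp_path. A minimal sketch of the pattern (the zarr fixture name is illustrative):

import pytest

@pytest.fixture(scope="session")
def session_tmp_path(tmp_path_factory):
    # One temporary directory shared by every session-scoped fixture
    return tmp_path_factory.mktemp("data")

@pytest.fixture(scope="session")
def example_zarr_path(session_tmp_path):
    zarr_path = session_tmp_path / "example.zarr"
    # ... write the dataset to zarr_path here ...
    yield zarr_path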
tests/load/test_load_satellite.py
CHANGED
@@ -8,10 +8,10 @@ def test_open_satellite(sat_zarr_path):
 
     assert isinstance(da, xr.DataArray)
     assert da.dims == ("time_utc", "channel", "x_geostationary", "y_geostationary")
-    #
+    # 288 is 1 day of data at 5 minute intervals, 12 * 24
     # There are 11 channels
-    # There are
-    assert da.shape == (
+    # There are 100 x 100 pixels
+    assert da.shape == (288, 11, 100, 100)
     assert np.issubdtype(da.dtype, np.number)
 
 
tests/numpy_sample/test_collate.py
CHANGED
@@ -1,17 +1,12 @@
-from ocf_data_sampler.numpy_sample import GSPSampleKey, SatelliteSampleKey
 from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
-from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
+from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
 
 
-def test_pvnet(pvnet_config_filename):
+def test_stack_np_samples_into_batch(pvnet_config_filename):
 
     # Create dataset object
     dataset = PVNetUKRegionalDataset(pvnet_config_filename)
 
-    assert len(dataset.locations) == 317
-    assert len(dataset.valid_t0_times) == 39
-    assert len(dataset) == 317 * 39
-
     # Generate 2 samples
     sample1 = dataset[0]
     sample2 = dataset[1]
@@ -22,5 +17,5 @@ def test_pvnet(pvnet_config_filename):
     assert "nwp" in batch
     assert isinstance(batch["nwp"], dict)
     assert "ukv" in batch["nwp"]
-    assert
-    assert
+    assert "gsp" in batch
+    assert "satellite_actual" in batch
tests/torch_datasets/test_merge_and_fill_utils.py
CHANGED
@@ -33,9 +33,7 @@ def test_fill_nans_in_arrays():
 
     result = fill_nans_in_arrays(nested_dict)
 
-    assert not np.isnan(result["array1"]).any()
     assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
-    assert not np.isnan(result["nested"]["array2"]).any()
     assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
     assert result["string_key"] == "not_an_array"
 
tests/torch_datasets/test_pvnet_uk.py
ADDED
@@ -0,0 +1,166 @@
+import numpy as np
+import pandas as pd
+import xarray as xr
+import dask.array
+
+from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
+from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import (
+    PVNetUKRegionalDataset,
+    PVNetUKConcurrentDataset,
+    process_and_combine_datasets,
+    compute,
+)
+from ocf_data_sampler.select.location import Location
+
+def test_process_and_combine_datasets(pvnet_config_filename):
+
+    # Load in config for function and define location
+    config = load_yaml_configuration(pvnet_config_filename)
+    t0 = pd.Timestamp("2024-01-01 00:00")
+    location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
+
+    nwp_data = xr.DataArray(
+        np.random.rand(4, 2, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
+            "channel": ["t2m", "dswrf"],
+            "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
+            "init_time_utc": pd.Timestamp("2024-01-01 00:00")
+        }
+    )
+
+    sat_data = xr.DataArray(
+        np.random.rand(7, 1, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
+            "channel": ["HRV"],
+            "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
+            "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
+        }
+    )
+
+    # Combine as dict
+    dataset_dict = {
+        "nwp": {"ukv": nwp_data},
+        "sat": sat_data
+    }
+
+    # Call relevant function
+    sample = process_and_combine_datasets(dataset_dict, config, t0, location)
+
+    # Assert result is dict - check and validate
+    assert isinstance(sample, dict)
+    assert "nwp" in sample
+    assert sample["satellite_actual"].shape == (7, 1, 2, 2)
+    assert sample["nwp"]["ukv"]["nwp"].shape == (4, 1, 2, 2)
+
+
+def test_compute():
+    """Test compute function with dask array"""
+    da_dask = xr.DataArray(dask.array.random.random((5, 5)))
+
+    # Create a nested dictionary with dask array
+    lazy_data_dict = {
+        "array1": da_dask,
+        "nested": {
+            "array2": da_dask
+        }
+    }
+
+    computed_data_dict = compute(lazy_data_dict)
+
+    # Assert that the result is no longer lazy
+    assert isinstance(computed_data_dict["array1"].data, np.ndarray)
+    assert isinstance(computed_data_dict["nested"]["array2"].data, np.ndarray)
+
+
+def test_pvnet_uk_regional_dataset(pvnet_config_filename):
+
+    # Create dataset object
+    dataset = PVNetUKRegionalDataset(pvnet_config_filename)
+
+    assert len(dataset.locations) == 317  # Number of regional GSPs
+    # NB. I have not checked the value (39 below) is in fact correct
+    assert len(dataset.valid_t0_times) == 39
+    assert len(dataset) == 317*39
+
+    # Generate a sample
+    sample = dataset[0]
+
+    assert isinstance(sample, dict)
+
+    for key in [
+        "nwp", "satellite_actual", "gsp",
+        "gsp_solar_azimuth", "gsp_solar_elevation",
+    ]:
+        assert key in sample
+
+    for nwp_source in ["ukv"]:
+        assert nwp_source in sample["nwp"]
+
+    # Check the shape of the data is correct
+    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
+    assert sample["satellite_actual"].shape == (7, 1, 2, 2)
+    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
+    assert sample["nwp"]["ukv"]["nwp"].shape == (4, 1, 2, 2)
+    # 3 hours of 30 minute data (inclusive)
+    assert sample["gsp"].shape == (7,)
+    # Solar angles have same shape as GSP data
+    assert sample["gsp_solar_azimuth"].shape == (7,)
+    assert sample["gsp_solar_elevation"].shape == (7,)
+
+
+def test_pvnet_no_gsp(tmp_path, pvnet_config_filename):
+
+    # Create new config without GSP inputs
+    config = load_yaml_configuration(pvnet_config_filename)
+    config.input_data.gsp.zarr_path = ''
+    new_config_path = tmp_path / "pvnet_config_no_gsp.yaml"
+    save_yaml_configuration(config, new_config_path)
+
+    # Create dataset object
+    dataset = PVNetUKRegionalDataset(new_config_path)
+
+    # Generate a sample
+    _ = dataset[0]
+
+
+def test_pvnet_uk_concurrent_dataset(pvnet_config_filename):
+
+    # Create dataset object using a limited set of GSPs for test
+    gsp_ids = [1, 2, 3]
+    num_gsps = len(gsp_ids)
+
+    dataset = PVNetUKConcurrentDataset(pvnet_config_filename, gsp_ids=gsp_ids)
+
+    assert len(dataset.locations) == num_gsps  # Number of regional GSPs
+    # NB. I have not checked the value (39 below) is in fact correct
+    assert len(dataset.valid_t0_times) == 39
+    assert len(dataset) == 39
+
+    # Generate a sample
+    sample = dataset[0]
+
+    assert isinstance(sample, dict)
+
+    for key in [
+        "nwp", "satellite_actual", "gsp",
+        "gsp_solar_azimuth", "gsp_solar_elevation",
+    ]:
+        assert key in sample
+
+    for nwp_source in ["ukv"]:
+        assert nwp_source in sample["nwp"]
+
+    # Check the shape of the data is correct
+    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
+    assert sample["satellite_actual"].shape == (num_gsps, 7, 1, 2, 2)
+    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
+    assert sample["nwp"]["ukv"]["nwp"].shape == (num_gsps, 4, 1, 2, 2)
+    # 3 hours of 30 minute data (inclusive)
+    assert sample["gsp"].shape == (num_gsps, 7)
+    # Solar angles have same shape as GSP data
+    assert sample["gsp_solar_azimuth"].shape == (num_gsps, 7)
+    assert sample["gsp_solar_elevation"].shape == (num_gsps, 7)
tests/torch_datasets/test_site.py
CHANGED
@@ -1,11 +1,37 @@
-import
+import pytest
+
 import numpy as np
-
-from xarray import Dataset, DataArray
+import pandas as pd
 import xarray as xr
 
 from torch.utils.data import DataLoader
 
+from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
+from ocf_data_sampler.torch_datasets.datasets.site import (
+    SitesDataset, convert_from_dataset_to_dict_datasets, coarsen_data
+)
+
+
+
+@pytest.fixture()
+def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites):
+
+    # adjust config to point to the zarr file
+    config = load_yaml_configuration(config_filename)
+    config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
+    config.input_data.satellite.zarr_path = sat_zarr_path
+    config.input_data.site = data_sites
+    config.input_data.gsp = None
+
+    filename = f"{tmp_path}/configuration.yaml"
+    save_yaml_configuration(config, filename)
+    yield filename
+
+
+@pytest.fixture()
+def sites_dataset(site_config_filename):
+    return SitesDataset(site_config_filename)
+
 
 def test_site(site_config_filename):
 
@@ -18,7 +44,7 @@ def test_site(site_config_filename):
     # Generate a sample
     sample = dataset[0]
 
-    assert isinstance(sample, Dataset)
+    assert isinstance(sample, xr.Dataset)
 
     # Expected dimensions and data variables
     expected_dims = {
@@ -85,21 +111,14 @@ def test_site_time_filter_end(site_config_filename):
     assert len(dataset) == 0
 
 
-def test_site_get_sample(site_config_filename):
-
-    # Create dataset object
-    dataset = SitesDataset(site_config_filename)
-
-    assert len(dataset) == 410
-    sample = dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1)
+def test_site_get_sample(sites_dataset):
+    sample = sites_dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1)
 
 
-def test_convert_from_dataset_to_dict_datasets(site_config_filename):
-    # Create dataset object
-    dataset = SitesDataset(site_config_filename)
+def test_convert_from_dataset_to_dict_datasets(sites_dataset):
 
-    # Generate
-    sample_xr =
+    # Generate sample
+    sample_xr = sites_dataset[0]
 
     sample = convert_from_dataset_to_dict_datasets(sample_xr)
 
@@ -109,9 +128,7 @@ def test_convert_from_dataset_to_dict_datasets(site_config_filename):
         assert key in sample
 
 
-def test_site_dataset_with_dataloader(
-    # Create dataset object
-    dataset = SitesDataset(site_config_filename)
+def test_site_dataset_with_dataloader(sites_dataset):
 
     expected_coods = {
         "site__solar_azimuth",
@@ -122,10 +139,6 @@ def test_site_dataset_with_dataloader(site_config_filename):
         "site__date_sin",
     }
 
-    sample = dataset[0]
-    for key in expected_coods:
-        assert key in sample
-
     dataloader_kwargs = dict(
         shuffle=False,
         batch_size=None,
@@ -141,25 +154,23 @@ def test_site_dataset_with_dataloader(site_config_filename):
         persistent_workers=False,  # Not needed since we only enter the dataloader loop once
     )
 
-    dataloader = DataLoader(
+    dataloader = DataLoader(sites_dataset, collate_fn=None, batch_size=None)
 
-
+    sample = next(iter(dataloader))
+
+    # check that expected_dims is in the sample
+    for key in expected_coods:
+        assert key in sample
 
-    # check that expected_dims is in the sample
-    for key in expected_coods:
-        assert key in sample
 
+def test_process_and_combine_site_sample_dict(sites_dataset):
 
-def test_process_and_combine_site_sample_dict(site_config_filename):
-    # Load config
-    # config = load_yaml_configuration(pvnet_config_filename)
-    site_ds = SitesDataset(site_config_filename)
     # Specify minimal structure for testing
     raw_nwp_values = np.random.rand(4, 1, 2, 2)  # Single channel
     fake_site_values = np.random.rand(197)
     site_dict = {
         "nwp": {
-            "ukv": DataArray(
+            "ukv": xr.DataArray(
                 raw_nwp_values,
                 dims=["time_utc", "channel", "y", "x"],
                 coords={
@@ -168,7 +179,7 @@ def test_process_and_combine_site_sample_dict(site_config_filename):
                },
            )
        },
-        "site": DataArray(
+        "site": xr.DataArray(
            fake_site_values,
            dims=["time_utc"],
            coords={
@@ -183,10 +194,10 @@ def test_process_and_combine_site_sample_dict(site_config_filename):
     print(f"Input site_dict: {site_dict}")
 
     # Call function
-    result =
+    result = sites_dataset.process_and_combine_site_sample_dict(site_dict)
 
     # Assert to validate output structure
-    assert isinstance(result, Dataset), "Result should be an xarray.Dataset"
+    assert isinstance(result, xr.Dataset), "Result should be an xarray.Dataset"
     assert len(result.data_vars) > 0, "Dataset should contain data variables"
 
     # Validate variable via assertion and shape of such
tests/torch_datasets/conftest.py
DELETED
@@ -1,18 +0,0 @@
-import pytest
-
-from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
-
-
-@pytest.fixture()
-def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites):
-
-    # adjust config to point to the zarr file
-    config = load_yaml_configuration(config_filename)
-    config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
-    config.input_data.satellite.zarr_path = sat_zarr_path
-    config.input_data.site = data_sites
-    config.input_data.gsp = None
-
-    filename = f"{tmp_path}/configuration.yaml"
-    save_yaml_configuration(config, filename)
-    return filename
tests/torch_datasets/test_pvnet_uk_regional.py
DELETED
@@ -1,136 +0,0 @@
-import numpy as np
-import pandas as pd
-import xarray as xr
-import dask.array as da
-import tempfile
-
-from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
-from ocf_data_sampler.config.save import save_yaml_configuration
-from ocf_data_sampler.config.load import load_yaml_configuration
-from ocf_data_sampler.numpy_sample import NWPSampleKey, GSPSampleKey, SatelliteSampleKey
-from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import process_and_combine_datasets, compute
-from ocf_data_sampler.select.location import Location
-
-def test_process_and_combine_datasets(pvnet_config_filename):
-
-    # Load in config for function and define location
-    config = load_yaml_configuration(pvnet_config_filename)
-    t0 = pd.Timestamp("2024-01-01 00:00")
-    location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
-
-    nwp_data = xr.DataArray(
-        np.random.rand(4, 2, 2, 2),
-        dims=["time_utc", "channel", "y", "x"],
-        coords={
-            "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
-            "channel": ["t2m", "dswrf"],
-            "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
-            "init_time_utc": pd.Timestamp("2024-01-01 00:00")
-        }
-    )
-
-    sat_data = xr.DataArray(
-        np.random.rand(7, 1, 2, 2),
-        dims=["time_utc", "channel", "y", "x"],
-        coords={
-            "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
-            "channel": ["HRV"],
-            "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
-            "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
-        }
-    )
-
-    # Combine as dict
-    dataset_dict = {
-        "nwp": {"ukv": nwp_data},
-        "sat": sat_data
-    }
-
-    # Call relevant function
-    result = process_and_combine_datasets(dataset_dict, config, t0, location)
-
-    # Assert result is dict - check and validate
-    assert isinstance(result, dict)
-    assert NWPSampleKey.nwp in result
-    assert result[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
-    assert result[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
-
-def test_compute():
-    """Test compute function with dask array"""
-    da_dask = xr.DataArray(da.random.random((5, 5)))
-
-    # Create a nested dictionary with dask array
-    nested_dict = {
-        "array1": da_dask,
-        "nested": {
-            "array2": da_dask
-        }
-    }
-
-    # Ensure initial data is lazy - i.e. not yet computed
-    assert not isinstance(nested_dict["array1"].data, np.ndarray)
-    assert not isinstance(nested_dict["nested"]["array2"].data, np.ndarray)
-
-    # Call the compute function
-    result = compute(nested_dict)
-
-    # Assert that the result is an xarray DataArray and no longer lazy
-    assert isinstance(result["array1"], xr.DataArray)
-    assert isinstance(result["nested"]["array2"], xr.DataArray)
-    assert isinstance(result["array1"].data, np.ndarray)
-    assert isinstance(result["nested"]["array2"].data, np.ndarray)
-
-    # Ensure there no NaN values in computed data
-    assert not np.isnan(result["array1"].data).any()
-    assert not np.isnan(result["nested"]["array2"].data).any()
-
-def test_pvnet(pvnet_config_filename):
-
-    # Create dataset object
-    dataset = PVNetUKRegionalDataset(pvnet_config_filename)
-
-    assert len(dataset.locations) == 317  # no of GSPs not including the National level
-    # NB. I have not checked this value is in fact correct, but it does seem to stay constant
-    assert len(dataset.valid_t0_times) == 39
-    assert len(dataset) == 317*39
-
-    # Generate a sample
-    sample = dataset[0]
-
-    assert isinstance(sample, dict)
-
-    for key in [
-        NWPSampleKey.nwp, SatelliteSampleKey.satellite_actual, GSPSampleKey.gsp,
-        GSPSampleKey.solar_azimuth, GSPSampleKey.solar_elevation,
-    ]:
-        assert key in sample
-
-    for nwp_source in ["ukv"]:
-        assert nwp_source in sample[NWPSampleKey.nwp]
-
-    # check the shape of the data is correct
-    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
-    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[NWPSampleKey.nwp]["ukv"][NWPSampleKey.nwp].shape == (4, 1, 2, 2)
-    # 3 hours of 30 minute data (inclusive)
-    assert sample[GSPSampleKey.gsp].shape == (7,)
-    # Solar angles have same shape as GSP data
-    assert sample[GSPSampleKey.solar_azimuth].shape == (7,)
-    assert sample[GSPSampleKey.solar_elevation].shape == (7,)
-
-def test_pvnet_no_gsp(pvnet_config_filename):
-
-    # load config
-    config = load_yaml_configuration(pvnet_config_filename)
-    # remove gsp
-    config.input_data.gsp.zarr_path = ''
-
-    # save temp config file
-    with tempfile.NamedTemporaryFile() as temp_config_file:
-        save_yaml_configuration(config, temp_config_file.name)
-        # Create dataset object
-        dataset = PVNetUKRegionalDataset(temp_config_file.name)
-
-        # Generate a sample
-        _ = dataset[0]
{ocf_data_sampler-0.1.0.dist-info → ocf_data_sampler-0.1.2.dist-info}/LICENSE
File without changes
{ocf_data_sampler-0.1.0.dist-info → ocf_data_sampler-0.1.2.dist-info}/WHEEL
File without changes
{ocf_data_sampler-0.1.0.dist-info → ocf_data_sampler-0.1.2.dist-info}/top_level.txt
File without changes