ocf-data-sampler 0.0.48__py3-none-any.whl → 0.0.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/torch_datasets/datasets/site.py +10 -9
- {ocf_data_sampler-0.0.48.dist-info → ocf_data_sampler-0.0.49.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.0.48.dist-info → ocf_data_sampler-0.0.49.dist-info}/RECORD +7 -7
- tests/torch_datasets/test_site.py +79 -8
- {ocf_data_sampler-0.0.48.dist-info → ocf_data_sampler-0.0.49.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.0.48.dist-info → ocf_data_sampler-0.0.49.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.0.48.dist-info → ocf_data_sampler-0.0.49.dist-info}/top_level.txt +0 -0
|
@@ -241,29 +241,30 @@ class SitesDataset(Dataset):
|
|
|
241
241
|
|
|
242
242
|
# add datetime features
|
|
243
243
|
datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
|
|
244
|
-
datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="
|
|
245
|
-
|
|
246
|
-
|
|
244
|
+
datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site_")
|
|
245
|
+
combined_sample_dataset = combined_sample_dataset.assign_coords(
|
|
246
|
+
{k: ("site__time_utc", v) for k, v in datetime_features.items()}
|
|
247
|
+
)
|
|
247
248
|
|
|
248
249
|
# add sun features
|
|
249
250
|
sun_position_features = make_sun_position_numpy_sample(
|
|
250
251
|
datetimes=datetimes,
|
|
251
252
|
lon=combined_sample_dataset.site__longitude.values,
|
|
252
253
|
lat=combined_sample_dataset.site__latitude.values,
|
|
253
|
-
key_prefix="
|
|
254
|
+
key_prefix="site_",
|
|
254
255
|
)
|
|
255
|
-
|
|
256
|
-
|
|
256
|
+
combined_sample_dataset = combined_sample_dataset.assign_coords(
|
|
257
|
+
{k: ("site__time_utc", v) for k, v in sun_position_features.items()}
|
|
257
258
|
)
|
|
258
|
-
combined_sample_dataset = xr.merge([combined_sample_dataset, sun_position_features_xr])
|
|
259
259
|
|
|
260
260
|
# TODO include t0_index in xr dataset?
|
|
261
261
|
|
|
262
262
|
# Fill any nan values
|
|
263
263
|
return combined_sample_dataset.fillna(0.0)
|
|
264
264
|
|
|
265
|
-
|
|
266
|
-
|
|
265
|
+
def merge_data_arrays(
|
|
266
|
+
self, normalised_data_arrays: list[Tuple[str, xr.DataArray]]
|
|
267
|
+
) -> xr.Dataset:
|
|
267
268
|
"""
|
|
268
269
|
Combine a list of DataArrays into a single Dataset with unique naming conventions.
|
|
269
270
|
|
|
@@ -38,7 +38,7 @@ ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq
|
|
|
38
38
|
ocf_data_sampler/select/time_slice_for_dataset.py,sha256=BFjNwWAzhcb1hpqx7UPi5RF9WWt15owbZp1WB-uGA6Q,4305
|
|
39
39
|
ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=nJUa2KzVa84ZoM0PT2AbDz26ennmAYc7M7WJVfypPMs,85
|
|
40
40
|
ocf_data_sampler/torch_datasets/datasets/pvnet_uk_regional.py,sha256=xxeX4Js9LQpydehi3BS7k9psqkYGzgJuM17uTYux40M,8742
|
|
41
|
-
ocf_data_sampler/torch_datasets/datasets/site.py,sha256=
|
|
41
|
+
ocf_data_sampler/torch_datasets/datasets/site.py,sha256=v7plMF_WJPkfwnJAUFf_8gXAy8SXE5Og_fgZMEm4c20,15257
|
|
42
42
|
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
|
|
43
43
|
ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
|
|
44
44
|
scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
|
|
@@ -65,9 +65,9 @@ tests/select/test_select_time_slice.py,sha256=K1EJR5TwZa9dJf_YTEHxGtvs398iy1xS2l
|
|
|
65
65
|
tests/torch_datasets/conftest.py,sha256=eRCzHE7cxS4AoskExkCGFDBeqItktAYNAdkfpMoFCeE,629
|
|
66
66
|
tests/torch_datasets/test_merge_and_fill_utils.py,sha256=ueA0A7gZaWEgNdsU8p3CnKuvSnlleTUjEhSw2HUUROM,1229
|
|
67
67
|
tests/torch_datasets/test_pvnet_uk_regional.py,sha256=FCiFueeFqrsXe7gWguSjBz5ZeUrvyhGbGw81gaVvkHM,5087
|
|
68
|
-
tests/torch_datasets/test_site.py,sha256=
|
|
69
|
-
ocf_data_sampler-0.0.
|
|
70
|
-
ocf_data_sampler-0.0.
|
|
71
|
-
ocf_data_sampler-0.0.
|
|
72
|
-
ocf_data_sampler-0.0.
|
|
73
|
-
ocf_data_sampler-0.0.
|
|
68
|
+
tests/torch_datasets/test_site.py,sha256=0gT_7k086BBnxqbvOayiUeI-vzJsYXlx3KvACC0c6lk,6114
|
|
69
|
+
ocf_data_sampler-0.0.49.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
70
|
+
ocf_data_sampler-0.0.49.dist-info/METADATA,sha256=GuLd3IDZ7qU9W9wwV84AQ5tN8rlouhF4ZpDThHsVUKo,11788
|
|
71
|
+
ocf_data_sampler-0.0.49.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
72
|
+
ocf_data_sampler-0.0.49.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
|
|
73
|
+
ocf_data_sampler-0.0.49.dist-info/RECORD,,
|
|
@@ -3,6 +3,8 @@ import numpy as np
|
|
|
3
3
|
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset, convert_from_dataset_to_dict_datasets
|
|
4
4
|
from xarray import Dataset, DataArray
|
|
5
5
|
|
|
6
|
+
from torch.utils.data import DataLoader
|
|
7
|
+
|
|
6
8
|
|
|
7
9
|
def test_site(site_config_filename):
|
|
8
10
|
|
|
@@ -18,17 +20,45 @@ def test_site(site_config_filename):
|
|
|
18
20
|
assert isinstance(sample, Dataset)
|
|
19
21
|
|
|
20
22
|
# Expected dimensions and data variables
|
|
21
|
-
expected_dims = {
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
23
|
+
expected_dims = {
|
|
24
|
+
"satellite__x_geostationary",
|
|
25
|
+
"site__time_utc",
|
|
26
|
+
"nwp-ukv__target_time_utc",
|
|
27
|
+
"nwp-ukv__x_osgb",
|
|
28
|
+
"satellite__channel",
|
|
29
|
+
"satellite__y_geostationary",
|
|
30
|
+
"satellite__time_utc",
|
|
31
|
+
"nwp-ukv__channel",
|
|
32
|
+
"nwp-ukv__y_osgb",
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
expected_coords_subset = {
|
|
36
|
+
"site__solar_azimuth",
|
|
37
|
+
"site__solar_elevation",
|
|
38
|
+
"site__date_cos",
|
|
39
|
+
"site__time_cos",
|
|
40
|
+
"site__time_sin",
|
|
41
|
+
"site__date_sin",
|
|
42
|
+
}
|
|
25
43
|
|
|
26
44
|
expected_data_vars = {"nwp-ukv", "satellite", "site"}
|
|
27
45
|
|
|
46
|
+
import xarray as xr
|
|
47
|
+
|
|
48
|
+
sample.to_netcdf("sample.nc")
|
|
49
|
+
sample = xr.open_dataset("sample.nc")
|
|
50
|
+
|
|
28
51
|
# Check dimensions
|
|
29
|
-
assert
|
|
52
|
+
assert (
|
|
53
|
+
set(sample.dims) == expected_dims
|
|
54
|
+
), f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
|
|
30
55
|
# Check data variables
|
|
31
|
-
assert
|
|
56
|
+
assert (
|
|
57
|
+
set(sample.data_vars) == expected_data_vars
|
|
58
|
+
), f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"
|
|
59
|
+
|
|
60
|
+
for coords in expected_coords_subset:
|
|
61
|
+
assert coords in sample.coords
|
|
32
62
|
|
|
33
63
|
# check the shape of the data is correct
|
|
34
64
|
# 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
|
|
@@ -38,6 +68,7 @@ def test_site(site_config_filename):
|
|
|
38
68
|
# 1.5 hours of 30 minute data (inclusive)
|
|
39
69
|
assert sample["site"].values.shape == (4,)
|
|
40
70
|
|
|
71
|
+
|
|
41
72
|
def test_site_time_filter_start(site_config_filename):
|
|
42
73
|
|
|
43
74
|
# Create dataset object
|
|
@@ -74,11 +105,51 @@ def test_convert_from_dataset_to_dict_datasets(site_config_filename):
|
|
|
74
105
|
|
|
75
106
|
assert isinstance(sample, dict)
|
|
76
107
|
|
|
77
|
-
print(sample.keys())
|
|
78
|
-
|
|
79
108
|
for key in ["nwp", "satellite", "site"]:
|
|
80
109
|
assert key in sample
|
|
81
110
|
|
|
111
|
+
|
|
112
|
+
def test_site_dataset_with_dataloader(site_config_filename):
    """Check that the site datetime/sun-position coords survive a DataLoader.

    ``SitesDataset`` attaches these features as coordinates along
    ``site__time_utc``. This test asserts they are present both on a sample
    taken directly from the dataset and on a sample yielded by a ``torch``
    ``DataLoader``.
    """
    dataset = SitesDataset(site_config_filename)

    expected_coords = {
        "site__solar_azimuth",
        "site__solar_elevation",
        "site__date_cos",
        "site__time_cos",
        "site__time_sin",
        "site__date_sin",
    }

    # Sample taken straight from the dataset
    sample = dataset[0]
    missing = sorted(key for key in expected_coords if key not in sample)
    assert not missing, f"Missing coords in dataset sample: {missing}"

    # batch_size=None with collate_fn=None disables automatic batching, so the
    # loader yields individual samples rather than collated batches.
    dataloader = DataLoader(dataset, collate_fn=None, batch_size=None)

    # Only the first sample is needed. next() fails loudly if the loader yields
    # nothing, instead of letting the check pass vacuously.
    loaded_sample = next(iter(dataloader))
    missing = sorted(key for key in expected_coords if key not in loaded_sample)
    assert not missing, f"Missing coords in DataLoader sample: {missing}"
|
|
151
|
+
|
|
152
|
+
|
|
82
153
|
def test_process_and_combine_site_sample_dict(site_config_filename):
|
|
83
154
|
# Load config
|
|
84
155
|
# config = load_yaml_configuration(pvnet_config_filename)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|