ocf-data-sampler 0.1.11__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff shows the content of publicly available package versions as released to their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of ocf-data-sampler might be problematic.
- ocf_data_sampler/config/load.py +3 -3
- ocf_data_sampler/config/model.py +146 -64
- ocf_data_sampler/config/save.py +5 -4
- ocf_data_sampler/load/gsp.py +6 -5
- ocf_data_sampler/load/load_dataset.py +5 -6
- ocf_data_sampler/load/nwp/nwp.py +17 -5
- ocf_data_sampler/load/nwp/providers/ecmwf.py +6 -7
- ocf_data_sampler/load/nwp/providers/gfs.py +36 -0
- ocf_data_sampler/load/nwp/providers/icon.py +46 -0
- ocf_data_sampler/load/nwp/providers/ukv.py +4 -5
- ocf_data_sampler/load/nwp/providers/utils.py +3 -1
- ocf_data_sampler/load/satellite.py +9 -10
- ocf_data_sampler/load/site.py +10 -6
- ocf_data_sampler/load/utils.py +21 -16
- ocf_data_sampler/numpy_sample/collate.py +10 -9
- ocf_data_sampler/numpy_sample/datetime_features.py +3 -5
- ocf_data_sampler/numpy_sample/gsp.py +12 -14
- ocf_data_sampler/numpy_sample/nwp.py +12 -12
- ocf_data_sampler/numpy_sample/satellite.py +9 -9
- ocf_data_sampler/numpy_sample/site.py +5 -8
- ocf_data_sampler/numpy_sample/sun_position.py +16 -21
- ocf_data_sampler/sample/base.py +15 -17
- ocf_data_sampler/sample/site.py +13 -20
- ocf_data_sampler/sample/uk_regional.py +29 -35
- ocf_data_sampler/select/dropout.py +16 -14
- ocf_data_sampler/select/fill_time_periods.py +15 -5
- ocf_data_sampler/select/find_contiguous_time_periods.py +88 -75
- ocf_data_sampler/select/geospatial.py +63 -54
- ocf_data_sampler/select/location.py +16 -51
- ocf_data_sampler/select/select_spatial_slice.py +105 -89
- ocf_data_sampler/select/select_time_slice.py +71 -58
- ocf_data_sampler/select/spatial_slice_for_dataset.py +7 -6
- ocf_data_sampler/select/time_slice_for_dataset.py +17 -16
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +140 -131
- ocf_data_sampler/torch_datasets/datasets/site.py +152 -112
- ocf_data_sampler/torch_datasets/utils/__init__.py +3 -0
- ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +11 -0
- ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +6 -2
- ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +23 -22
- ocf_data_sampler/utils.py +3 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/METADATA +7 -18
- ocf_data_sampler-0.1.17.dist-info/RECORD +56 -0
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.1.11.dist-info → ocf_data_sampler-0.1.17.dist-info}/top_level.txt +1 -1
- scripts/refactor_site.py +63 -33
- utils/compute_icon_mean_stddev.py +72 -0
- ocf_data_sampler/constants.py +0 -222
- ocf_data_sampler/torch_datasets/utils/validate_channels.py +0 -82
- ocf_data_sampler-0.1.11.dist-info/LICENSE +0 -21
- ocf_data_sampler-0.1.11.dist-info/RECORD +0 -82
- tests/__init__.py +0 -0
- tests/config/test_config.py +0 -113
- tests/config/test_load.py +0 -7
- tests/config/test_save.py +0 -28
- tests/conftest.py +0 -319
- tests/load/test_load_gsp.py +0 -15
- tests/load/test_load_nwp.py +0 -21
- tests/load/test_load_satellite.py +0 -17
- tests/load/test_load_sites.py +0 -14
- tests/numpy_sample/test_collate.py +0 -21
- tests/numpy_sample/test_datetime_features.py +0 -37
- tests/numpy_sample/test_gsp.py +0 -38
- tests/numpy_sample/test_nwp.py +0 -13
- tests/numpy_sample/test_satellite.py +0 -40
- tests/numpy_sample/test_sun_position.py +0 -81
- tests/select/test_dropout.py +0 -69
- tests/select/test_fill_time_periods.py +0 -28
- tests/select/test_find_contiguous_time_periods.py +0 -202
- tests/select/test_location.py +0 -67
- tests/select/test_select_spatial_slice.py +0 -154
- tests/select/test_select_time_slice.py +0 -275
- tests/test_sample/test_base.py +0 -164
- tests/test_sample/test_site_sample.py +0 -165
- tests/test_sample/test_uk_regional_sample.py +0 -136
- tests/torch_datasets/test_merge_and_fill_utils.py +0 -40
- tests/torch_datasets/test_pvnet_uk.py +0 -154
- tests/torch_datasets/test_site.py +0 -226
- tests/torch_datasets/test_validate_channels_utils.py +0 -78
tests/torch_datasets/test_site.py
@@ -1,226 +0,0 @@
-import pytest
-
-import numpy as np
-import pandas as pd
-import xarray as xr
-
-from torch.utils.data import DataLoader
-
-from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
-from ocf_data_sampler.torch_datasets.datasets.site import (
-    SitesDataset, convert_from_dataset_to_dict_datasets, coarsen_data
-)
-
-
-
-@pytest.fixture()
-def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites):
-
-    # adjust config to point to the zarr file
-    config = load_yaml_configuration(config_filename)
-    config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
-    config.input_data.satellite.zarr_path = sat_zarr_path
-    config.input_data.site = data_sites
-    config.input_data.gsp = None
-
-    filename = f"{tmp_path}/configuration.yaml"
-    save_yaml_configuration(config, filename)
-    yield filename
-
-
-@pytest.fixture()
-def sites_dataset(site_config_filename):
-    return SitesDataset(site_config_filename)
-
-
-def test_site(tmp_path, site_config_filename):
-
-    # Create dataset object
-    dataset = SitesDataset(site_config_filename)
-
-    assert len(dataset) == 10 * 41
-    # TODO check 41
-
-    # Generate a sample
-    sample = dataset[0]
-
-    assert isinstance(sample, xr.Dataset)
-
-    # Expected dimensions and data variables
-    expected_dims = {
-        "satellite__x_geostationary",
-        "site__time_utc",
-        "nwp-ukv__target_time_utc",
-        "nwp-ukv__x_osgb",
-        "satellite__channel",
-        "satellite__y_geostationary",
-        "satellite__time_utc",
-        "nwp-ukv__channel",
-        "nwp-ukv__y_osgb",
-    }
-
-    expected_coords_subset = {
-        "site__solar_azimuth",
-        "site__solar_elevation",
-        "site__date_cos",
-        "site__time_cos",
-        "site__time_sin",
-        "site__date_sin",
-    }
-
-    expected_data_vars = {"nwp-ukv", "satellite", "site"}
-
-
-    sample.to_netcdf(f"{tmp_path}/sample.nc")
-    sample = xr.open_dataset(f"{tmp_path}/sample.nc")
-
-    # Check dimensions
-    assert (
-        set(sample.dims) == expected_dims
-    ), f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
-    # Check data variables
-    assert (
-        set(sample.data_vars) == expected_data_vars
-    ), f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"
-
-    for coords in expected_coords_subset:
-        assert coords in sample.coords
-
-    # check the shape of the data is correct
-    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
-    assert sample["satellite"].values.shape == (7, 1, 2, 2)
-    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
-    assert sample["nwp-ukv"].values.shape == (4, 1, 2, 2)
-    # 1.5 hours of 30 minute data (inclusive)
-    assert sample["site"].values.shape == (4,)
-
-
-def test_site_time_filter_start(site_config_filename):
-
-    # Create dataset object
-    dataset = SitesDataset(site_config_filename, start_time="2024-01-01")
-
-    assert len(dataset) == 0
-
-
-def test_site_time_filter_end(site_config_filename):
-
-    # Create dataset object
-    dataset = SitesDataset(site_config_filename, end_time="2000-01-01")
-
-    assert len(dataset) == 0
-
-
-def test_site_get_sample(sites_dataset):
-    sample = sites_dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1)
-
-
-def test_convert_from_dataset_to_dict_datasets(sites_dataset):
-
-    # Generate sample
-    sample_xr = sites_dataset[0]
-
-    sample = convert_from_dataset_to_dict_datasets(sample_xr)
-
-    assert isinstance(sample, dict)
-
-    for key in ["nwp", "satellite", "site"]:
-        assert key in sample
-
-
-def test_site_dataset_with_dataloader(sites_dataset):
-
-    expected_coods = {
-        "site__solar_azimuth",
-        "site__solar_elevation",
-        "site__date_cos",
-        "site__time_cos",
-        "site__time_sin",
-        "site__date_sin",
-    }
-
-    dataloader_kwargs = dict(
-        shuffle=False,
-        batch_size=None,
-        sampler=None,
-        batch_sampler=None,
-        num_workers=1,
-        collate_fn=None,
-        pin_memory=False,  # Only using CPU to prepare samples so pinning is not beneficial
-        drop_last=False,
-        timeout=0,
-        worker_init_fn=None,
-        prefetch_factor=1,
-        persistent_workers=False,  # Not needed since we only enter the dataloader loop once
-    )
-
-    dataloader = DataLoader(sites_dataset, collate_fn=None, batch_size=None)
-
-    sample = next(iter(dataloader))
-
-    # check that expected_dims is in the sample
-    for key in expected_coods:
-        assert key in sample
-
-
-def test_process_and_combine_site_sample_dict(sites_dataset):
-
-    # Specify minimal structure for testing
-    raw_nwp_values = np.random.rand(4, 1, 2, 2)  # Single channel
-    fake_site_values = np.random.rand(197)
-    site_dict = {
-        "nwp": {
-            "ukv": xr.DataArray(
-                raw_nwp_values,
-                dims=["time_utc", "channel", "y", "x"],
-                coords={
-                    "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
-                    "channel": ["dswrf"],  # Single channel
-                },
-            )
-        },
-        "site": xr.DataArray(
-            fake_site_values,
-            dims=["time_utc"],
-            coords={
-                "time_utc": pd.date_range("2024-01-01 00:00", periods=197, freq="15min"),
-                "capacity_kwp": 1000,
-                "site_id": 1,
-                "longitude": -3.5,
-                "latitude": 51.5
-            }
-        )
-    }
-    print(f"Input site_dict: {site_dict}")
-
-    # Call function
-    result = sites_dataset.process_and_combine_site_sample_dict(site_dict)
-
-    # Assert to validate output structure
-    assert isinstance(result, xr.Dataset), "Result should be an xarray.Dataset"
-    assert len(result.data_vars) > 0, "Dataset should contain data variables"
-
-    # Validate variable via assertion and shape of such
-    expected_variables = ["nwp-ukv", "site"]
-    for expected_variable in expected_variables:
-        assert expected_variable in result.data_vars, f"Expected variable '{expected_variable}' not found"
-
-    nwp_result = result["nwp-ukv"]
-    assert nwp_result.shape == (4, 1, 2, 2), f"Unexpected shape for nwp-ukv: {nwp_result.shape}"
-    site_result = result["site"]
-    assert site_result.shape == (197,), f"Unexpected shape for site: {site_result.shape}"
-
-
-def test_potentially_coarsen(ds_nwp_ecmwf):
-    """Test potentially_coarsen function with ECMWF_UK data."""
-    nwp_data = ds_nwp_ecmwf
-    assert nwp_data.ECMWF_UK.shape[3:] == (15, 12)  # Check initial shape (lon, lat)
-
-    data = coarsen_data(xr_data=nwp_data, coarsen_to_deg=2)
-    assert data.ECMWF_UK.shape[3:] == (8, 6)  # Coarsen to every 2 degrees
-
-    data = coarsen_data(xr_data=nwp_data, coarsen_to_deg=3)
-    assert data.ECMWF_UK.shape[3:] == (5, 4)  # Coarsen to every 3 degrees
-
-    data = coarsen_data(xr_data=nwp_data, coarsen_to_deg=1)
-    assert data.ECMWF_UK.shape[3:] == (15, 12)  # No coarsening (same shape)
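The shape assertions in the removed test_site above follow from inclusive time windows: with both endpoints counted, a window of length L sampled every d minutes yields L/d + 1 steps. Below is a small illustrative check of that arithmetic; it is not part of the package, and the window lengths and intervals are taken from the test's own comments.

# Illustrative only: reproduces the step counts asserted in the removed test_site.
def n_steps(window_minutes: int, interval_minutes: int) -> int:
    # Inclusive window: both the start and end timestamps are counted.
    return window_minutes // interval_minutes + 1

assert n_steps(30, 5) == 7     # satellite -> shape (7, 1, 2, 2)
assert n_steps(180, 60) == 4   # nwp-ukv   -> shape (4, 1, 2, 2)
assert n_steps(90, 30) == 4    # site      -> shape (4,)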
tests/torch_datasets/test_validate_channels_utils.py
@@ -1,78 +0,0 @@
-"""Tests for channel validation utility functions"""
-
-import pytest
-from ocf_data_sampler.torch_datasets.utils.validate_channels import (
-    validate_channels,
-    validate_nwp_channels,
-    validate_satellite_channels,
-)
-
-
-class TestChannelValidation:
-    """Tests for channel validation functions"""
-
-    @pytest.mark.parametrize("test_case", [
-        # Base validation - success case
-        {
-            "data_channels": ["channel1", "channel2"],
-            "norm_channels": ["channel1", "channel2", "extra"],
-            "source_name": "test_source",
-            "expect_error": False
-        },
-        # Base validation - error case
-        {
-            "data_channels": ["channel1", "missing_channel"],
-            "norm_channels": ["channel1"],
-            "source_name": "test_source",
-            "expect_error": True,
-            "error_match": "following channels for test_source are missing in normalisation means"
-        },
-        # NWP case - success
-        {
-            "data_channels": ["t2m", "dswrf"],
-            "norm_channels": ["t2m", "dswrf", "extra"],
-            "source_name": "ecmwf",
-            "expect_error": False
-        },
-        # NWP case - error
-        {
-            "data_channels": ["t2m", "missing_channel"],
-            "norm_channels": ["t2m"],
-            "source_name": "ecmwf",
-            "expect_error": True,
-            "error_match": "following channels for ecmwf are missing in normalisation means"
-        },
-        # Satellite case - success
-        {
-            "data_channels": ["IR_016", "VIS006"],
-            "norm_channels": ["IR_016", "VIS006", "extra"],
-            "source_name": "satellite",
-            "expect_error": False
-        },
-        # Satellite case - error
-        {
-            "data_channels": ["IR_016", "missing_channel"],
-            "norm_channels": ["IR_016"],
-            "source_name": "satellite",
-            "expect_error": True,
-            "error_match": "following channels for satellite are missing in normalisation means"
-        }
-    ])
-    def test_channel_validation(self, test_case):
-        """Test channel validation for both base, NWP and satellite data"""
-        if test_case["expect_error"]:
-            with pytest.raises(ValueError, match=test_case["error_match"]):
-                validate_channels(
-                    data_channels=test_case["data_channels"],
-                    means_channels=test_case["norm_channels"],
-                    stds_channels=test_case["norm_channels"],
-                    source_name=test_case["source_name"]
-                )
-        else:
-            # Should not raise any exceptions
-            validate_channels(
-                data_channels=test_case["data_channels"],
-                means_channels=test_case["norm_channels"],
-                stds_channels=test_case["norm_channels"],
-                source_name=test_case["source_name"]
-            )
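The validate_channels helper exercised by these tests was also removed in this release (ocf_data_sampler/torch_datasets/utils/validate_channels.py, -82 lines). For context, here is a minimal sketch of a validator with the call signature and error wording the removed tests expect; it is an illustrative assumption, not the implementation that was removed.

# Hypothetical sketch matching the signature and error substring used in the
# removed tests above; not the actual ocf-data-sampler implementation.
from collections.abc import Sequence

def validate_channels(
    data_channels: Sequence[str],
    means_channels: Sequence[str],
    stds_channels: Sequence[str],
    source_name: str,
) -> None:
    # Raise if any data channel has no normalisation mean or std defined.
    missing_means = set(data_channels) - set(means_channels)
    if missing_means:
        raise ValueError(
            f"The following channels for {source_name} are missing in "
            f"normalisation means: {sorted(missing_means)}"
        )
    missing_stds = set(data_channels) - set(stds_channels)
    if missing_stds:
        raise ValueError(
            f"The following channels for {source_name} are missing in "
            f"normalisation stds: {sorted(missing_stds)}"
        )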