ocf-data-sampler 0.0.36__py3-none-any.whl → 0.0.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/numpy_batch/collate.py +79 -0
- {ocf_data_sampler-0.0.36.dist-info → ocf_data_sampler-0.0.37.dist-info}/METADATA +3 -2
- {ocf_data_sampler-0.0.36.dist-info → ocf_data_sampler-0.0.37.dist-info}/RECORD +9 -7
- tests/conftest.py +16 -0
- tests/numpy_batch/test_collate.py +26 -0
- tests/torch_datasets/test_pvnet_uk_regional.py +0 -13
- {ocf_data_sampler-0.0.36.dist-info → ocf_data_sampler-0.0.37.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.0.36.dist-info → ocf_data_sampler-0.0.37.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.0.36.dist-info → ocf_data_sampler-0.0.37.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from ocf_data_sampler.numpy_batch import NWPBatchKey
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Union
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def stack_np_examples_into_batch(dict_list):
    """
    Combine a list of per-example dicts into one batch dict

    The keys are read from the first example only; every example in the list
    is then indexed with those same keys. The "nwp" entry is nested
    (source -> key -> data), so it is combined one source and one key at a time.

    The original docstring pointed to `unstack_np_batch_into_examples()` as the
    opposite operation; that function is not visible here, so confirm it exists.

    Args:
        dict_list: A list of dict-like Numpy examples to stack. The first
            element is indexed unconditionally, so the list must not be empty

    Returns:
        A dict with the keys of the first example, each value produced by
        `stack_data_list()`. If "nwp" is present its value is a dict of
        per-source dicts
    """
    first_example = dict_list[0]
    batch = {}

    for batch_key in list(first_example.keys()):
        if batch_key == "nwp":
            # Keys can be different for different NWPs, so they are read per source
            batch["nwp"] = {
                nwp_source: {
                    nwp_batch_key: stack_data_list(
                        [example["nwp"][nwp_source][nwp_batch_key] for example in dict_list],
                        nwp_batch_key,
                    )
                    for nwp_batch_key in list(first_example["nwp"][nwp_source].keys())
                }
                for nwp_source in list(first_example["nwp"].keys())
            }
        else:
            batch[batch_key] = stack_data_list(
                [example[batch_key] for example in dict_list],
                batch_key,
            )

    return batch
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _key_is_constant(batch_key):
|
|
59
|
+
is_constant = batch_key.endswith("t0_idx") or batch_key == NWPBatchKey.channel_names
|
|
60
|
+
return is_constant
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def stack_data_list(
    data_list: list,
    batch_key: Union[str, NWPBatchKey],
):
    """Combine the per-example entries stored under a single batch key

    Args:
        data_list: The values for `batch_key`, one per example
        batch_key: The key the values belong to; passed to `_key_is_constant()`
            to decide whether the values are stacked at all

    Returns:
        `data_list[0]` when `_key_is_constant(batch_key)` is true, otherwise
        `np.stack(data_list)`

    Raises:
        Exception: Whatever `np.stack` raised, re-raised unchanged after the
            key, the entry shapes and the error have been logged
    """
    if _key_is_constant(batch_key):
        # These are always the same for all examples.
        return data_list[0]
    try:
        return np.stack(data_list)
    except Exception as e:
        logger.debug(f"Could not stack the following shapes together, ({batch_key})")
        # Entries without a `.shape` (e.g. plain Python scalars or lists) must not
        # raise inside this handler, or the original stacking error would be hidden
        shapes = [getattr(example, "shape", None) for example in data_list]
        logger.debug(shapes)
        logger.error(e)
        raise
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ocf_data_sampler
|
|
3
|
-
Version: 0.0.36
|
|
3
|
+
Version: 0.0.37
|
|
4
4
|
Summary: Sample from weather data for renewable energy prediction
|
|
5
5
|
Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
|
|
6
6
|
Author-email: info@openclimatefix.org
|
|
@@ -56,7 +56,7 @@ Requires-Dist: mkdocs-material>=8.0; extra == "docs"
|
|
|
56
56
|
# ocf-data-sampler
|
|
57
57
|
|
|
58
58
|
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
|
|
59
|
-
[](#contributors-)
|
|
60
60
|
<!-- ALL-CONTRIBUTORS-BADGE:END -->
|
|
61
61
|
|
|
62
62
|
[](https://github.com/openclimatefix/ocf-data-sampler/tags)
|
|
@@ -129,6 +129,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
|
|
|
129
129
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/peterdudfield"><img src="https://avatars.githubusercontent.com/u/34686298?v=4?s=100" width="100px;" alt="Peter Dudfield"/><br /><sub><b>Peter Dudfield</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=peterdudfield" title="Code">💻</a></td>
|
|
130
130
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/VikramsDataScience"><img src="https://avatars.githubusercontent.com/u/45002417?v=4?s=100" width="100px;" alt="Vikram Pande"/><br /><sub><b>Vikram Pande</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=VikramsDataScience" title="Code">💻</a></td>
|
|
131
131
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/SophiaLi20"><img src="https://avatars.githubusercontent.com/u/163532536?v=4?s=100" width="100px;" alt="Unnati Bhardwaj"/><br /><sub><b>Unnati Bhardwaj</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=SophiaLi20" title="Documentation">📖</a></td>
|
|
132
|
+
<td align="center" valign="top" width="14.28%"><a href="https://github.com/alirashidAR"><img src="https://avatars.githubusercontent.com/u/110668489?v=4?s=100" width="100px;" alt="Ali Rashid"/><br /><sub><b>Ali Rashid</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=alirashidAR" title="Code">💻</a></td>
|
|
132
133
|
</tr>
|
|
133
134
|
</tbody>
|
|
134
135
|
</table>
|
|
@@ -19,6 +19,7 @@ ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=2iR1Iy542lo51rC6XFLV-3pbUE68
|
|
|
19
19
|
ocf_data_sampler/load/nwp/providers/ukv.py,sha256=79Bm7q-K_GJPYMy62SUIZbRWRF4-tIaB1dYPEgLD9vo,1207
|
|
20
20
|
ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3txG1_bBmBdchlc-yA,848
|
|
21
21
|
ocf_data_sampler/numpy_batch/__init__.py,sha256=8MgRF29rK9bKP4b4iHakaoGwBKUcjWZ-VFKjCcq53QA,336
|
|
22
|
+
ocf_data_sampler/numpy_batch/collate.py,sha256=KyWdDi8AXD5YiokXXiqr2_X1SC1me1GrhnQMelg0Qx8,2202
|
|
22
23
|
ocf_data_sampler/numpy_batch/gsp.py,sha256=QjQ25JmtufvdiSsxUkBTPhxouYGWPnnWze8pXr_aBno,960
|
|
23
24
|
ocf_data_sampler/numpy_batch/nwp.py,sha256=dAehfRo5DL2Yb20ifHHl5cU1QOrm3ZOpQmN39fSUOw8,1255
|
|
24
25
|
ocf_data_sampler/numpy_batch/satellite.py,sha256=3NoE_ElzMHwO60apqJeFAwI6J7eIxD0OWTyAVl-uJi8,903
|
|
@@ -41,12 +42,13 @@ ocf_data_sampler/torch_datasets/site.py,sha256=NYuhgm9ti9SRt1dcb_WrFYYo14NgVdOsa
|
|
|
41
42
|
ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
|
|
42
43
|
scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
|
|
43
44
|
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
44
|
-
tests/conftest.py,sha256=
|
|
45
|
+
tests/conftest.py,sha256=68hH-HPdHPLvLrtYJU8bjfkdGKbhPfNveLKvUs6_Lr0,7970
|
|
45
46
|
tests/config/test_config.py,sha256=eaye_F7-el4tTP4n2vRME8qlV0b2jaKUX4HhgOUpa7E,5203
|
|
46
47
|
tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
|
|
47
48
|
tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
|
|
48
49
|
tests/load/test_load_satellite.py,sha256=STX5AqqmOAgUgE9R1xyq_sM3P1b8NKdGjO-hDhayfxM,524
|
|
49
50
|
tests/load/test_load_sites.py,sha256=T9lSEnGPI8FQISudVYHHNTHeplNS62Vrx48jaZ6J_Jo,364
|
|
51
|
+
tests/numpy_batch/test_collate.py,sha256=U0u5LLpkImr7R50xDuTKNo2Of8sUT5pEs9F2ZMW3jEU,836
|
|
50
52
|
tests/numpy_batch/test_gsp.py,sha256=VANXV32K8aLX4dCdhCUnDorJmyNN-Bjc7Wc1N-RzWEk,548
|
|
51
53
|
tests/numpy_batch/test_nwp.py,sha256=Fnj7cR-VR2Z0kMu8SrgnIayjxWnPWrYFjWSjMmnrh4Y,1445
|
|
52
54
|
tests/numpy_batch/test_satellite.py,sha256=8a4ZwMLpsOmYKmwI1oW_su_hwkCNYMEJAEfa0dbsx1k,1179
|
|
@@ -57,10 +59,10 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
|
|
|
57
59
|
tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
|
|
58
60
|
tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
|
|
59
61
|
tests/select/test_select_time_slice.py,sha256=K1EJR5TwZa9dJf_YTEHxGtvs398iy1xS2lr1BgJZkoo,9603
|
|
60
|
-
tests/torch_datasets/test_pvnet_uk_regional.py,sha256=
|
|
62
|
+
tests/torch_datasets/test_pvnet_uk_regional.py,sha256=eqy0nQOWoHnqltlJlGmRlgIiIzPEwOC6o5A6GARryKA,2118
|
|
61
63
|
tests/torch_datasets/test_site.py,sha256=YuVjWTI14_kmEOx23XE5J_RZ8UalCKD2xRv6mqYizB8,2872
|
|
62
|
-
ocf_data_sampler-0.0.
|
|
63
|
-
ocf_data_sampler-0.0.
|
|
64
|
-
ocf_data_sampler-0.0.
|
|
65
|
-
ocf_data_sampler-0.0.
|
|
66
|
-
ocf_data_sampler-0.0.
|
|
64
|
+
ocf_data_sampler-0.0.37.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
65
|
+
ocf_data_sampler-0.0.37.dist-info/METADATA,sha256=tKixIA37U0AA76QsYmCIfLzpzE2aSGRmquSx69jX4aY,10290
|
|
66
|
+
ocf_data_sampler-0.0.37.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
|
67
|
+
ocf_data_sampler-0.0.37.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
|
|
68
|
+
ocf_data_sampler-0.0.37.dist-info/RECORD,,
|
tests/conftest.py
CHANGED
|
@@ -7,6 +7,7 @@ import xarray as xr
|
|
|
7
7
|
import tempfile
|
|
8
8
|
|
|
9
9
|
from ocf_data_sampler.config.model import Site
|
|
10
|
+
from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
|
|
10
11
|
|
|
11
12
|
_top_test_directory = os.path.dirname(os.path.realpath(__file__))
|
|
12
13
|
|
|
@@ -269,3 +270,18 @@ def uk_gsp_zarr_path(ds_uk_gsp):
|
|
|
269
270
|
ds_uk_gsp.to_zarr(filename)
|
|
270
271
|
yield filename
|
|
271
272
|
|
|
273
|
+
|
|
274
|
+
@pytest.fixture()
def pvnet_config_filename(
    tmp_path, config_filename, nwp_ukv_zarr_path, uk_gsp_zarr_path, sat_zarr_path
):
    """Save a copy of the test configuration that points at the test zarr paths

    Loads the configuration from `config_filename`, overwrites the `zarr_path`
    of the "ukv" NWP, satellite and GSP inputs with the paths supplied by the
    other fixtures, writes the result to `configuration.yaml` under `tmp_path`
    and returns that filename.
    """
    configuration = load_yaml_configuration(config_filename)

    # Redirect each input source to the zarr path provided for the test
    configuration.input_data.gsp.zarr_path = uk_gsp_zarr_path
    configuration.input_data.satellite.zarr_path = sat_zarr_path
    configuration.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path

    output_filename = f"{tmp_path}/configuration.yaml"
    save_yaml_configuration(configuration, output_filename)
    return output_filename
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from ocf_data_sampler.numpy_batch import GSPBatchKey, SatelliteBatchKey
|
|
2
|
+
from ocf_data_sampler.numpy_batch.collate import stack_np_examples_into_batch
|
|
3
|
+
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def test_pvnet(pvnet_config_filename):
    """Collating two dataset samples yields a dict batch with the expected keys"""
    # Build the dataset from the configuration written by the fixture
    pvnet_dataset = PVNetUKRegionalDataset(pvnet_config_filename)

    n_locations = len(pvnet_dataset.locations)
    assert n_locations == 317
    n_t0_times = len(pvnet_dataset.valid_t0_times)
    assert n_t0_times == 39
    assert len(pvnet_dataset) == 317 * 39

    # Take the first two samples and stack them into one batch
    first_sample = pvnet_dataset[0]
    second_sample = pvnet_dataset[1]
    collated = stack_np_examples_into_batch([first_sample, second_sample])

    # Top level is a dict; the NWP entry is itself a dict keyed by source
    assert isinstance(collated, dict)
    assert "nwp" in collated
    nwp_part = collated["nwp"]
    assert isinstance(nwp_part, dict)
    assert "ukv" in nwp_part

    for expected_key in (GSPBatchKey.gsp, SatelliteBatchKey.satellite_actual):
        assert expected_key in collated
|
|
@@ -6,19 +6,6 @@ from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configura
|
|
|
6
6
|
from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
@pytest.fixture()
def pvnet_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, uk_gsp_zarr_path, sat_zarr_path):
    """Save a copy of the test configuration that points at the test zarr paths

    Loads the configuration from `config_filename`, overwrites the `zarr_path`
    of the 'ukv' NWP, satellite and GSP inputs with the paths supplied by the
    other fixtures, writes the result to `configuration.yaml` under `tmp_path`
    and returns that filename.
    """
    configuration = load_yaml_configuration(config_filename)

    # Redirect each input source to the zarr path provided for the test
    configuration.input_data.gsp.zarr_path = uk_gsp_zarr_path
    configuration.input_data.satellite.zarr_path = sat_zarr_path
    configuration.input_data.nwp['ukv'].zarr_path = nwp_ukv_zarr_path

    output_filename = f"{tmp_path}/configuration.yaml"
    save_yaml_configuration(configuration, output_filename)
    return output_filename
|
|
21
|
-
|
|
22
9
|
|
|
23
10
|
def test_pvnet(pvnet_config_filename):
|
|
24
11
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|