PVNet 5.0.22__tar.gz → 5.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pvnet-5.0.22 → pvnet-5.3.1}/PKG-INFO +5 -10
- {pvnet-5.0.22 → pvnet-5.3.1}/PVNet.egg-info/PKG-INFO +5 -10
- {pvnet-5.0.22 → pvnet-5.3.1}/PVNet.egg-info/SOURCES.txt +2 -4
- {pvnet-5.0.22 → pvnet-5.3.1}/PVNet.egg-info/requires.txt +1 -1
- {pvnet-5.0.22 → pvnet-5.3.1}/README.md +2 -7
- pvnet-5.0.22/pvnet/data/base_datamodule.py → pvnet-5.3.1/pvnet/datamodule.py +16 -100
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/load_model.py +2 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/base_model.py +18 -23
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/ensemble.py +0 -4
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/late_fusion.py +40 -61
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/site_encoders/encoders.py +14 -24
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/training/lightning_module.py +46 -51
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/training/plots.py +2 -2
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/training/train.py +1 -22
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/utils.py +39 -17
- {pvnet-5.0.22 → pvnet-5.3.1}/pyproject.toml +2 -2
- pvnet-5.3.1/tests/test_datamodule.py +15 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/tests/test_end2end.py +4 -4
- pvnet-5.0.22/pvnet/data/__init__.py +0 -3
- pvnet-5.0.22/pvnet/data/site_datamodule.py +0 -29
- pvnet-5.0.22/pvnet/data/uk_regional_datamodule.py +0 -29
- {pvnet-5.0.22 → pvnet-5.3.1}/LICENSE +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/PVNet.egg-info/dependency_links.txt +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/PVNet.egg-info/top_level.txt +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/__init__.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/__init__.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/__init__.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/basic_blocks.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/encoders/basic_blocks.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/encoders/encoders3d.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/optimizers.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/training/__init__.py +0 -0
- {pvnet-5.0.22 → pvnet-5.3.1}/setup.cfg +0 -0
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: PVNet
|
|
3
|
-
Version: 5.0.22
|
|
3
|
+
Version: 5.3.1
|
|
4
4
|
Summary: PVNet
|
|
5
5
|
Author-email: Peter Dudfield <info@openclimatefix.org>
|
|
6
|
-
Requires-Python:
|
|
6
|
+
Requires-Python: <3.14,>=3.11
|
|
7
7
|
Description-Content-Type: text/markdown
|
|
8
8
|
License-File: LICENSE
|
|
9
|
-
Requires-Dist: ocf-data-sampler>=0.
|
|
9
|
+
Requires-Dist: ocf-data-sampler>=0.6.0
|
|
10
10
|
Requires-Dist: numpy
|
|
11
11
|
Requires-Dist: pandas
|
|
12
12
|
Requires-Dist: matplotlib
|
|
@@ -29,7 +29,7 @@ Dynamic: license-file
|
|
|
29
29
|
|
|
30
30
|
# PVNet
|
|
31
31
|
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
|
|
32
|
-
[](#contributors-)
|
|
33
33
|
<!-- ALL-CONTRIBUTORS-BADGE:END -->
|
|
34
34
|
|
|
35
35
|
[](https://github.com/openclimatefix/PVNet/tags)
|
|
@@ -142,12 +142,6 @@ pip install -e <PATH-TO-ocf-data-sampler-REPO>
|
|
|
142
142
|
If you install the local version of `ocf-data-sampler` that is more recent than the version
|
|
143
143
|
specified in `PVNet` it is not guarenteed to function properly with this library.
|
|
144
144
|
|
|
145
|
-
## Streaming samples (no pre-save)
|
|
146
|
-
|
|
147
|
-
PVNet now trains and validates directly from **streamed_samples** (i.e. no pre-saving to disk).
|
|
148
|
-
|
|
149
|
-
Make sure you have copied example configs (as already stated above):
|
|
150
|
-
cp -r configs.example configs
|
|
151
145
|
|
|
152
146
|
### Set up and config example for streaming
|
|
153
147
|
|
|
@@ -251,6 +245,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
|
|
|
251
245
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/markus-kreft"><img src="https://avatars.githubusercontent.com/u/129367085?v=4?s=100" width="100px;" alt="Markus Kreft"/><br /><sub><b>Markus Kreft</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=markus-kreft" title="Code">💻</a></td>
|
|
252
246
|
<td align="center" valign="top" width="14.28%"><a href="http://jack-kelly.com"><img src="https://avatars.githubusercontent.com/u/460756?v=4?s=100" width="100px;" alt="Jack Kelly"/><br /><sub><b>Jack Kelly</b></sub></a><br /><a href="#ideas-JackKelly" title="Ideas, Planning, & Feedback">🤔</a></td>
|
|
253
247
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/zaryab-ali"><img src="https://avatars.githubusercontent.com/u/85732412?v=4?s=100" width="100px;" alt="zaryab-ali"/><br /><sub><b>zaryab-ali</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=zaryab-ali" title="Code">💻</a></td>
|
|
248
|
+
<td align="center" valign="top" width="14.28%"><a href="https://github.com/Lex-Ashu"><img src="https://avatars.githubusercontent.com/u/181084934?v=4?s=100" width="100px;" alt="Lex-Ashu"/><br /><sub><b>Lex-Ashu</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=Lex-Ashu" title="Code">💻</a></td>
|
|
254
249
|
</tr>
|
|
255
250
|
</tbody>
|
|
256
251
|
</table>
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: PVNet
|
|
3
|
-
Version: 5.0.22
|
|
3
|
+
Version: 5.3.1
|
|
4
4
|
Summary: PVNet
|
|
5
5
|
Author-email: Peter Dudfield <info@openclimatefix.org>
|
|
6
|
-
Requires-Python:
|
|
6
|
+
Requires-Python: <3.14,>=3.11
|
|
7
7
|
Description-Content-Type: text/markdown
|
|
8
8
|
License-File: LICENSE
|
|
9
|
-
Requires-Dist: ocf-data-sampler>=0.
|
|
9
|
+
Requires-Dist: ocf-data-sampler>=0.6.0
|
|
10
10
|
Requires-Dist: numpy
|
|
11
11
|
Requires-Dist: pandas
|
|
12
12
|
Requires-Dist: matplotlib
|
|
@@ -29,7 +29,7 @@ Dynamic: license-file
|
|
|
29
29
|
|
|
30
30
|
# PVNet
|
|
31
31
|
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
|
|
32
|
-
[](#contributors-)
|
|
33
33
|
<!-- ALL-CONTRIBUTORS-BADGE:END -->
|
|
34
34
|
|
|
35
35
|
[](https://github.com/openclimatefix/PVNet/tags)
|
|
@@ -142,12 +142,6 @@ pip install -e <PATH-TO-ocf-data-sampler-REPO>
|
|
|
142
142
|
If you install the local version of `ocf-data-sampler` that is more recent than the version
|
|
143
143
|
specified in `PVNet` it is not guarenteed to function properly with this library.
|
|
144
144
|
|
|
145
|
-
## Streaming samples (no pre-save)
|
|
146
|
-
|
|
147
|
-
PVNet now trains and validates directly from **streamed_samples** (i.e. no pre-saving to disk).
|
|
148
|
-
|
|
149
|
-
Make sure you have copied example configs (as already stated above):
|
|
150
|
-
cp -r configs.example configs
|
|
151
145
|
|
|
152
146
|
### Set up and config example for streaming
|
|
153
147
|
|
|
@@ -251,6 +245,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
|
|
|
251
245
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/markus-kreft"><img src="https://avatars.githubusercontent.com/u/129367085?v=4?s=100" width="100px;" alt="Markus Kreft"/><br /><sub><b>Markus Kreft</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=markus-kreft" title="Code">💻</a></td>
|
|
252
246
|
<td align="center" valign="top" width="14.28%"><a href="http://jack-kelly.com"><img src="https://avatars.githubusercontent.com/u/460756?v=4?s=100" width="100px;" alt="Jack Kelly"/><br /><sub><b>Jack Kelly</b></sub></a><br /><a href="#ideas-JackKelly" title="Ideas, Planning, & Feedback">🤔</a></td>
|
|
253
247
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/zaryab-ali"><img src="https://avatars.githubusercontent.com/u/85732412?v=4?s=100" width="100px;" alt="zaryab-ali"/><br /><sub><b>zaryab-ali</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=zaryab-ali" title="Code">💻</a></td>
|
|
248
|
+
<td align="center" valign="top" width="14.28%"><a href="https://github.com/Lex-Ashu"><img src="https://avatars.githubusercontent.com/u/181084934?v=4?s=100" width="100px;" alt="Lex-Ashu"/><br /><sub><b>Lex-Ashu</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=Lex-Ashu" title="Code">💻</a></td>
|
|
254
249
|
</tr>
|
|
255
250
|
</tbody>
|
|
256
251
|
</table>
|
|
@@ -7,13 +7,10 @@ PVNet.egg-info/dependency_links.txt
|
|
|
7
7
|
PVNet.egg-info/requires.txt
|
|
8
8
|
PVNet.egg-info/top_level.txt
|
|
9
9
|
pvnet/__init__.py
|
|
10
|
+
pvnet/datamodule.py
|
|
10
11
|
pvnet/load_model.py
|
|
11
12
|
pvnet/optimizers.py
|
|
12
13
|
pvnet/utils.py
|
|
13
|
-
pvnet/data/__init__.py
|
|
14
|
-
pvnet/data/base_datamodule.py
|
|
15
|
-
pvnet/data/site_datamodule.py
|
|
16
|
-
pvnet/data/uk_regional_datamodule.py
|
|
17
14
|
pvnet/models/__init__.py
|
|
18
15
|
pvnet/models/base_model.py
|
|
19
16
|
pvnet/models/ensemble.py
|
|
@@ -33,4 +30,5 @@ pvnet/training/__init__.py
|
|
|
33
30
|
pvnet/training/lightning_module.py
|
|
34
31
|
pvnet/training/plots.py
|
|
35
32
|
pvnet/training/train.py
|
|
33
|
+
tests/test_datamodule.py
|
|
36
34
|
tests/test_end2end.py
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# PVNet
|
|
2
2
|
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
|
|
3
|
-
[](#contributors-)
|
|
4
4
|
<!-- ALL-CONTRIBUTORS-BADGE:END -->
|
|
5
5
|
|
|
6
6
|
[](https://github.com/openclimatefix/PVNet/tags)
|
|
@@ -113,12 +113,6 @@ pip install -e <PATH-TO-ocf-data-sampler-REPO>
|
|
|
113
113
|
If you install the local version of `ocf-data-sampler` that is more recent than the version
|
|
114
114
|
specified in `PVNet` it is not guarenteed to function properly with this library.
|
|
115
115
|
|
|
116
|
-
## Streaming samples (no pre-save)
|
|
117
|
-
|
|
118
|
-
PVNet now trains and validates directly from **streamed_samples** (i.e. no pre-saving to disk).
|
|
119
|
-
|
|
120
|
-
Make sure you have copied example configs (as already stated above):
|
|
121
|
-
cp -r configs.example configs
|
|
122
116
|
|
|
123
117
|
### Set up and config example for streaming
|
|
124
118
|
|
|
@@ -222,6 +216,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
|
|
|
222
216
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/markus-kreft"><img src="https://avatars.githubusercontent.com/u/129367085?v=4?s=100" width="100px;" alt="Markus Kreft"/><br /><sub><b>Markus Kreft</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=markus-kreft" title="Code">💻</a></td>
|
|
223
217
|
<td align="center" valign="top" width="14.28%"><a href="http://jack-kelly.com"><img src="https://avatars.githubusercontent.com/u/460756?v=4?s=100" width="100px;" alt="Jack Kelly"/><br /><sub><b>Jack Kelly</b></sub></a><br /><a href="#ideas-JackKelly" title="Ideas, Planning, & Feedback">🤔</a></td>
|
|
224
218
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/zaryab-ali"><img src="https://avatars.githubusercontent.com/u/85732412?v=4?s=100" width="100px;" alt="zaryab-ali"/><br /><sub><b>zaryab-ali</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=zaryab-ali" title="Code">💻</a></td>
|
|
219
|
+
<td align="center" valign="top" width="14.28%"><a href="https://github.com/Lex-Ashu"><img src="https://avatars.githubusercontent.com/u/181084934?v=4?s=100" width="100px;" alt="Lex-Ashu"/><br /><sub><b>Lex-Ashu</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=Lex-Ashu" title="Code">💻</a></td>
|
|
225
220
|
</tr>
|
|
226
221
|
</tbody>
|
|
227
222
|
</table>
|
|
@@ -1,14 +1,14 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Data module for pytorch lightning"""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
-
from glob import glob
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
from lightning.pytorch import LightningDataModule
|
|
8
7
|
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
|
|
9
8
|
from ocf_data_sampler.numpy_sample.common_types import NumpySample, TensorBatch
|
|
10
|
-
from ocf_data_sampler.torch_datasets.
|
|
11
|
-
from
|
|
9
|
+
from ocf_data_sampler.torch_datasets.pvnet_dataset import PVNetDataset
|
|
10
|
+
from ocf_data_sampler.torch_datasets.utils.torch_batch_utils import batch_to_tensor
|
|
11
|
+
from torch.utils.data import DataLoader, Subset
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def collate_fn(samples: list[NumpySample]) -> TensorBatch:
|
|
@@ -16,87 +16,8 @@ def collate_fn(samples: list[NumpySample]) -> TensorBatch:
|
|
|
16
16
|
return batch_to_tensor(stack_np_samples_into_batch(samples))
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
class
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
Args:
|
|
23
|
-
sample_dir: Path to the directory of pre-saved samples.
|
|
24
|
-
sample_class: sample class type to use for save/load/to_numpy
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
def __init__(self, sample_dir: str, sample_class: SampleBase):
|
|
28
|
-
"""Initialise PresavedSamplesDataset"""
|
|
29
|
-
self.sample_paths = glob(f"{sample_dir}/*")
|
|
30
|
-
self.sample_class = sample_class
|
|
31
|
-
|
|
32
|
-
def __len__(self) -> int:
|
|
33
|
-
return len(self.sample_paths)
|
|
34
|
-
|
|
35
|
-
def __getitem__(self, idx) -> NumpySample:
|
|
36
|
-
sample = self.sample_class.load(self.sample_paths[idx])
|
|
37
|
-
return sample.to_numpy()
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class BasePresavedDataModule(LightningDataModule):
|
|
41
|
-
"""Base Datamodule for loading pre-saved samples."""
|
|
42
|
-
|
|
43
|
-
def __init__(
|
|
44
|
-
self,
|
|
45
|
-
sample_dir: str,
|
|
46
|
-
batch_size: int = 16,
|
|
47
|
-
num_workers: int = 0,
|
|
48
|
-
prefetch_factor: int | None = None,
|
|
49
|
-
persistent_workers: bool = False,
|
|
50
|
-
pin_memory: bool = False,
|
|
51
|
-
):
|
|
52
|
-
"""Base Datamodule for loading pre-saved samples
|
|
53
|
-
|
|
54
|
-
Args:
|
|
55
|
-
sample_dir: Path to the directory of pre-saved samples.
|
|
56
|
-
batch_size: Batch size.
|
|
57
|
-
num_workers: Number of workers to use in multiprocess batch loading.
|
|
58
|
-
prefetch_factor: Number of batches loaded in advance by each worker.
|
|
59
|
-
persistent_workers: If True, the data loader will not shut down the worker processes
|
|
60
|
-
after a dataset has been consumed once. This allows to maintain the workers Dataset
|
|
61
|
-
instances alive.
|
|
62
|
-
pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
|
|
63
|
-
before returning them.
|
|
64
|
-
"""
|
|
65
|
-
super().__init__()
|
|
66
|
-
|
|
67
|
-
self.sample_dir = sample_dir
|
|
68
|
-
|
|
69
|
-
self._common_dataloader_kwargs = dict(
|
|
70
|
-
batch_size=batch_size,
|
|
71
|
-
sampler=None,
|
|
72
|
-
batch_sampler=None,
|
|
73
|
-
num_workers=num_workers,
|
|
74
|
-
collate_fn=collate_fn,
|
|
75
|
-
pin_memory=pin_memory,
|
|
76
|
-
drop_last=False,
|
|
77
|
-
timeout=0,
|
|
78
|
-
worker_init_fn=None,
|
|
79
|
-
prefetch_factor=prefetch_factor,
|
|
80
|
-
persistent_workers=persistent_workers,
|
|
81
|
-
multiprocessing_context="spawn" if num_workers>0 else None,
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
def _get_premade_samples_dataset(self, subdir: str) -> Dataset:
|
|
85
|
-
raise NotImplementedError
|
|
86
|
-
|
|
87
|
-
def train_dataloader(self) -> DataLoader:
|
|
88
|
-
"""Construct train dataloader"""
|
|
89
|
-
dataset = self._get_premade_samples_dataset("train")
|
|
90
|
-
return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)
|
|
91
|
-
|
|
92
|
-
def val_dataloader(self) -> DataLoader:
|
|
93
|
-
"""Construct val dataloader"""
|
|
94
|
-
dataset = self._get_premade_samples_dataset("val")
|
|
95
|
-
return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
class BaseStreamedDataModule(LightningDataModule):
|
|
99
|
-
"""Base Datamodule which streams samples using a sampler for ocf-data-sampler."""
|
|
19
|
+
class PVNetDataModule(LightningDataModule):
|
|
20
|
+
"""Base Datamodule which streams samples using a sampler from ocf-data-sampler."""
|
|
100
21
|
|
|
101
22
|
def __init__(
|
|
102
23
|
self,
|
|
@@ -118,10 +39,10 @@ class BaseStreamedDataModule(LightningDataModule):
|
|
|
118
39
|
batch_size: Batch size.
|
|
119
40
|
num_workers: Number of workers to use in multiprocess batch loading.
|
|
120
41
|
prefetch_factor: Number of batches loaded in advance by each worker.
|
|
121
|
-
persistent_workers: If True, the data loader will not shut down the worker processes
|
|
122
|
-
after a dataset has been consumed once. This allows to maintain the workers Dataset
|
|
42
|
+
persistent_workers: If True, the data loader will not shut down the worker processes
|
|
43
|
+
after a dataset has been consumed once. This allows to maintain the workers Dataset
|
|
123
44
|
instances alive.
|
|
124
|
-
pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
|
|
45
|
+
pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
|
|
125
46
|
before returning them.
|
|
126
47
|
train_period: Date range filter for train dataloader.
|
|
127
48
|
val_period: Date range filter for val dataloader.
|
|
@@ -148,7 +69,7 @@ class BaseStreamedDataModule(LightningDataModule):
|
|
|
148
69
|
worker_init_fn=None,
|
|
149
70
|
prefetch_factor=prefetch_factor,
|
|
150
71
|
persistent_workers=persistent_workers,
|
|
151
|
-
multiprocessing_context="spawn" if num_workers>0 else None,
|
|
72
|
+
multiprocessing_context="spawn" if num_workers > 0 else None,
|
|
152
73
|
)
|
|
153
74
|
|
|
154
75
|
def setup(self, stage: str | None = None):
|
|
@@ -157,16 +78,15 @@ class BaseStreamedDataModule(LightningDataModule):
|
|
|
157
78
|
# This logic runs only once at the start of training, therefore the val dataset is only
|
|
158
79
|
# shuffled once
|
|
159
80
|
if stage == "fit":
|
|
160
|
-
|
|
161
81
|
# Prepare the train dataset
|
|
162
|
-
self.train_dataset = self.
|
|
82
|
+
self.train_dataset = self._get_dataset(*self.train_period)
|
|
163
83
|
|
|
164
|
-
#
|
|
165
|
-
val_dataset = self.
|
|
84
|
+
# Prepare and pre-shuffle the val dataset and set seed for reproducibility
|
|
85
|
+
val_dataset = self._get_dataset(*self.val_period)
|
|
166
86
|
|
|
167
87
|
shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
|
|
168
88
|
self.val_dataset = Subset(val_dataset, shuffled_indices)
|
|
169
|
-
|
|
89
|
+
|
|
170
90
|
if self.dataset_pickle_dir is not None:
|
|
171
91
|
os.makedirs(self.dataset_pickle_dir, exist_ok=True)
|
|
172
92
|
train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
|
|
@@ -194,12 +114,8 @@ class BaseStreamedDataModule(LightningDataModule):
|
|
|
194
114
|
if os.path.exists(filepath):
|
|
195
115
|
os.remove(filepath)
|
|
196
116
|
|
|
197
|
-
def
|
|
198
|
-
self,
|
|
199
|
-
start_time: str | None,
|
|
200
|
-
end_time: str | None
|
|
201
|
-
) -> Dataset:
|
|
202
|
-
raise NotImplementedError
|
|
117
|
+
def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetDataset:
|
|
118
|
+
return PVNetDataset(self.configuration, start_time=start_time, end_time=end_time)
|
|
203
119
|
|
|
204
120
|
def train_dataloader(self) -> DataLoader:
|
|
205
121
|
"""Construct train dataloader"""
|
|
@@ -73,6 +73,8 @@ def get_model_from_checkpoints(
|
|
|
73
73
|
else:
|
|
74
74
|
raise FileNotFoundError(f"File {data_config} does not exist")
|
|
75
75
|
|
|
76
|
+
# TODO: This should be removed in a future release since no new models will be trained on
|
|
77
|
+
# presaved samples
|
|
76
78
|
# Check for datamodule config
|
|
77
79
|
# This only exists if the model was trained with presaved samples
|
|
78
80
|
datamodule_config = f"{path}/{DATAMODULE_CONFIG_NAME}"
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Base model for all PVNet submodels"""
|
|
2
|
+
|
|
2
3
|
import logging
|
|
3
4
|
import os
|
|
4
5
|
import shutil
|
|
@@ -32,7 +33,7 @@ def fill_config_paths_with_placeholder(config: dict, placeholder: str = "PLACEHO
|
|
|
32
33
|
"""
|
|
33
34
|
input_config = config["input_data"]
|
|
34
35
|
|
|
35
|
-
for source in ["
|
|
36
|
+
for source in ["generation", "satellite"]:
|
|
36
37
|
if source in input_config:
|
|
37
38
|
# If not empty - i.e. if used
|
|
38
39
|
if input_config[source]["zarr_path"] != "":
|
|
@@ -78,8 +79,8 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:
|
|
|
78
79
|
|
|
79
80
|
# Replace the interval_end_minutes minutes
|
|
80
81
|
nwp_config["interval_end_minutes"] = (
|
|
81
|
-
nwp_config["interval_start_minutes"]
|
|
82
|
-
(model.nwp_encoders_dict[nwp_source].sequence_length - 1)
|
|
82
|
+
nwp_config["interval_start_minutes"]
|
|
83
|
+
+ (model.nwp_encoders_dict[nwp_source].sequence_length - 1)
|
|
83
84
|
* nwp_config["time_resolution_minutes"]
|
|
84
85
|
)
|
|
85
86
|
|
|
@@ -96,20 +97,19 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:
|
|
|
96
97
|
|
|
97
98
|
# Replace the interval_end_minutes minutes
|
|
98
99
|
sat_config["interval_end_minutes"] = (
|
|
99
|
-
sat_config["interval_start_minutes"]
|
|
100
|
-
(model.sat_encoder.sequence_length - 1)
|
|
101
|
-
* sat_config["time_resolution_minutes"]
|
|
100
|
+
sat_config["interval_start_minutes"]
|
|
101
|
+
+ (model.sat_encoder.sequence_length - 1) * sat_config["time_resolution_minutes"]
|
|
102
102
|
)
|
|
103
103
|
|
|
104
104
|
if "pv" in input_config:
|
|
105
105
|
if not model.include_pv:
|
|
106
106
|
del input_config["pv"]
|
|
107
107
|
|
|
108
|
-
if "
|
|
109
|
-
|
|
108
|
+
if "generation" in input_config:
|
|
109
|
+
generation_config = input_config["generation"]
|
|
110
110
|
|
|
111
111
|
# Replace the forecast minutes
|
|
112
|
-
|
|
112
|
+
generation_config["interval_end_minutes"] = model.forecast_minutes
|
|
113
113
|
|
|
114
114
|
if "solar_position" in input_config:
|
|
115
115
|
solar_config = input_config["solar_position"]
|
|
@@ -138,9 +138,9 @@ def download_from_hf(
|
|
|
138
138
|
force_download: Whether to force a new download
|
|
139
139
|
max_retries: Maximum number of retry attempts
|
|
140
140
|
wait_time: Wait time (in seconds) before retrying
|
|
141
|
-
token:
|
|
141
|
+
token:
|
|
142
142
|
HF authentication token. If True, the token is read from the HuggingFace config folder.
|
|
143
|
-
If a string, it is used as the authentication token.
|
|
143
|
+
If a string, it is used as the authentication token.
|
|
144
144
|
|
|
145
145
|
Returns:
|
|
146
146
|
The local file path of the downloaded file(s)
|
|
@@ -160,7 +160,7 @@ def download_from_hf(
|
|
|
160
160
|
return [f"{save_dir}/{f}" for f in filename]
|
|
161
161
|
else:
|
|
162
162
|
return f"{save_dir}/{filename}"
|
|
163
|
-
|
|
163
|
+
|
|
164
164
|
except Exception as e:
|
|
165
165
|
if attempt == max_retries:
|
|
166
166
|
raise Exception(
|
|
@@ -205,7 +205,7 @@ class HuggingfaceMixin:
|
|
|
205
205
|
force_download=force_download,
|
|
206
206
|
max_retries=5,
|
|
207
207
|
wait_time=10,
|
|
208
|
-
token=token
|
|
208
|
+
token=token,
|
|
209
209
|
)
|
|
210
210
|
|
|
211
211
|
with open(config_file, "r") as f:
|
|
@@ -240,7 +240,7 @@ class HuggingfaceMixin:
|
|
|
240
240
|
force_download=force_download,
|
|
241
241
|
max_retries=5,
|
|
242
242
|
wait_time=10,
|
|
243
|
-
token=token
|
|
243
|
+
token=token,
|
|
244
244
|
)
|
|
245
245
|
|
|
246
246
|
return data_config_file
|
|
@@ -301,7 +301,7 @@ class HuggingfaceMixin:
|
|
|
301
301
|
# Save cleaned version of input data configuration file
|
|
302
302
|
with open(data_config_path) as cfg:
|
|
303
303
|
config = yaml.load(cfg, Loader=yaml.FullLoader)
|
|
304
|
-
|
|
304
|
+
|
|
305
305
|
config = fill_config_paths_with_placeholder(config)
|
|
306
306
|
config = minimize_config_for_model(config, self)
|
|
307
307
|
|
|
@@ -311,7 +311,7 @@ class HuggingfaceMixin:
|
|
|
311
311
|
# Save the datamodule config
|
|
312
312
|
if datamodule_config_path is not None:
|
|
313
313
|
shutil.copyfile(datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME)
|
|
314
|
-
|
|
314
|
+
|
|
315
315
|
# Save the full experimental config
|
|
316
316
|
if experiment_config_path is not None:
|
|
317
317
|
shutil.copyfile(experiment_config_path, save_directory / FULL_CONFIG_NAME)
|
|
@@ -378,7 +378,6 @@ class HuggingfaceMixin:
|
|
|
378
378
|
packages_to_display = ["pvnet", "ocf-data-sampler"]
|
|
379
379
|
packages_and_versions = {package: version(package) for package in packages_to_display}
|
|
380
380
|
|
|
381
|
-
|
|
382
381
|
package_versions_markdown = ""
|
|
383
382
|
for package, v in packages_and_versions.items():
|
|
384
383
|
package_versions_markdown += f" - {package}=={v}\n"
|
|
@@ -399,23 +398,19 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
|
|
|
399
398
|
history_minutes: int,
|
|
400
399
|
forecast_minutes: int,
|
|
401
400
|
output_quantiles: list[float] | None = None,
|
|
402
|
-
target_key: str = "gsp",
|
|
403
401
|
interval_minutes: int = 30,
|
|
404
402
|
):
|
|
405
403
|
"""Abtstract base class for PVNet submodels.
|
|
406
404
|
|
|
407
405
|
Args:
|
|
408
|
-
history_minutes (int): Length of the
|
|
409
|
-
forecast_minutes (int): Length of the
|
|
406
|
+
history_minutes (int): Length of the generation history period in minutes
|
|
407
|
+
forecast_minutes (int): Length of the generation forecast period in minutes
|
|
410
408
|
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
|
|
411
409
|
None the output is a single value.
|
|
412
|
-
target_key: The key of the target variable in the batch
|
|
413
410
|
interval_minutes: The interval in minutes between each timestep in the data
|
|
414
411
|
"""
|
|
415
412
|
super().__init__()
|
|
416
413
|
|
|
417
|
-
self._target_key = target_key
|
|
418
|
-
|
|
419
414
|
self.history_minutes = history_minutes
|
|
420
415
|
self.forecast_minutes = forecast_minutes
|
|
421
416
|
self.output_quantiles = output_quantiles
|
|
@@ -26,7 +26,6 @@ class Ensemble(BaseModel):
|
|
|
26
26
|
output_quantiles = []
|
|
27
27
|
history_minutes = []
|
|
28
28
|
forecast_minutes = []
|
|
29
|
-
target_key = []
|
|
30
29
|
interval_minutes = []
|
|
31
30
|
|
|
32
31
|
# Get some model properties from each model
|
|
@@ -34,7 +33,6 @@ class Ensemble(BaseModel):
|
|
|
34
33
|
output_quantiles.append(model.output_quantiles)
|
|
35
34
|
history_minutes.append(model.history_minutes)
|
|
36
35
|
forecast_minutes.append(model.forecast_minutes)
|
|
37
|
-
target_key.append(model._target_key)
|
|
38
36
|
interval_minutes.append(model.interval_minutes)
|
|
39
37
|
|
|
40
38
|
# Check these properties are all the same
|
|
@@ -42,7 +40,6 @@ class Ensemble(BaseModel):
|
|
|
42
40
|
output_quantiles,
|
|
43
41
|
history_minutes,
|
|
44
42
|
forecast_minutes,
|
|
45
|
-
target_key,
|
|
46
43
|
interval_minutes,
|
|
47
44
|
]:
|
|
48
45
|
assert all([p == param_list[0] for p in param_list]), param_list
|
|
@@ -51,7 +48,6 @@ class Ensemble(BaseModel):
|
|
|
51
48
|
history_minutes=history_minutes[0],
|
|
52
49
|
forecast_minutes=forecast_minutes[0],
|
|
53
50
|
output_quantiles=output_quantiles[0],
|
|
54
|
-
target_key=target_key[0],
|
|
55
51
|
interval_minutes=interval_minutes[0],
|
|
56
52
|
)
|
|
57
53
|
|