PVNet 5.0.22__tar.gz → 5.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. {pvnet-5.0.22 → pvnet-5.3.1}/PKG-INFO +5 -10
  2. {pvnet-5.0.22 → pvnet-5.3.1}/PVNet.egg-info/PKG-INFO +5 -10
  3. {pvnet-5.0.22 → pvnet-5.3.1}/PVNet.egg-info/SOURCES.txt +2 -4
  4. {pvnet-5.0.22 → pvnet-5.3.1}/PVNet.egg-info/requires.txt +1 -1
  5. {pvnet-5.0.22 → pvnet-5.3.1}/README.md +2 -7
  6. pvnet-5.0.22/pvnet/data/base_datamodule.py → pvnet-5.3.1/pvnet/datamodule.py +16 -100
  7. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/load_model.py +2 -0
  8. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/base_model.py +18 -23
  9. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/ensemble.py +0 -4
  10. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/late_fusion.py +40 -61
  11. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/site_encoders/encoders.py +14 -24
  12. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/training/lightning_module.py +46 -51
  13. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/training/plots.py +2 -2
  14. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/training/train.py +1 -22
  15. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/utils.py +39 -17
  16. {pvnet-5.0.22 → pvnet-5.3.1}/pyproject.toml +2 -2
  17. pvnet-5.3.1/tests/test_datamodule.py +15 -0
  18. {pvnet-5.0.22 → pvnet-5.3.1}/tests/test_end2end.py +4 -4
  19. pvnet-5.0.22/pvnet/data/__init__.py +0 -3
  20. pvnet-5.0.22/pvnet/data/site_datamodule.py +0 -29
  21. pvnet-5.0.22/pvnet/data/uk_regional_datamodule.py +0 -29
  22. {pvnet-5.0.22 → pvnet-5.3.1}/LICENSE +0 -0
  23. {pvnet-5.0.22 → pvnet-5.3.1}/PVNet.egg-info/dependency_links.txt +0 -0
  24. {pvnet-5.0.22 → pvnet-5.3.1}/PVNet.egg-info/top_level.txt +0 -0
  25. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/__init__.py +0 -0
  26. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/__init__.py +0 -0
  27. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/__init__.py +0 -0
  28. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/basic_blocks.py +0 -0
  29. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
  30. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/encoders/basic_blocks.py +0 -0
  31. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/encoders/encoders3d.py +0 -0
  32. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
  33. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
  34. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
  35. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
  36. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
  37. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/optimizers.py +0 -0
  38. {pvnet-5.0.22 → pvnet-5.3.1}/pvnet/training/__init__.py +0 -0
  39. {pvnet-5.0.22 → pvnet-5.3.1}/setup.cfg +0 -0
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.22
3
+ Version: 5.3.1
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
- Requires-Python: >=3.11
6
+ Requires-Python: <3.14,>=3.11
7
7
  Description-Content-Type: text/markdown
8
8
  License-File: LICENSE
9
- Requires-Dist: ocf-data-sampler>=0.5.20
9
+ Requires-Dist: ocf-data-sampler>=0.6.0
10
10
  Requires-Dist: numpy
11
11
  Requires-Dist: pandas
12
12
  Requires-Dist: matplotlib
@@ -29,7 +29,7 @@ Dynamic: license-file
29
29
 
30
30
  # PVNet
31
31
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
32
- [![All Contributors](https://img.shields.io/badge/all_contributors-20-orange.svg?style=flat-square)](#contributors-)
32
+ [![All Contributors](https://img.shields.io/badge/all_contributors-21-orange.svg?style=flat-square)](#contributors-)
33
33
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
34
34
 
35
35
  [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/PVNet?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/PVNet/tags)
@@ -142,12 +142,6 @@ pip install -e <PATH-TO-ocf-data-sampler-REPO>
142
142
  If you install the local version of `ocf-data-sampler` that is more recent than the version
143
143
  specified in `PVNet` it is not guaranteed to function properly with this library.
144
144
 
145
- ## Streaming samples (no pre-save)
146
-
147
- PVNet now trains and validates directly from **streamed_samples** (i.e. no pre-saving to disk).
148
-
149
- Make sure you have copied example configs (as already stated above):
150
- cp -r configs.example configs
151
145
 
152
146
  ### Set up and config example for streaming
153
147
 
@@ -251,6 +245,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
251
245
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/markus-kreft"><img src="https://avatars.githubusercontent.com/u/129367085?v=4?s=100" width="100px;" alt="Markus Kreft"/><br /><sub><b>Markus Kreft</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=markus-kreft" title="Code">💻</a></td>
252
246
  <td align="center" valign="top" width="14.28%"><a href="http://jack-kelly.com"><img src="https://avatars.githubusercontent.com/u/460756?v=4?s=100" width="100px;" alt="Jack Kelly"/><br /><sub><b>Jack Kelly</b></sub></a><br /><a href="#ideas-JackKelly" title="Ideas, Planning, & Feedback">🤔</a></td>
253
247
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/zaryab-ali"><img src="https://avatars.githubusercontent.com/u/85732412?v=4?s=100" width="100px;" alt="zaryab-ali"/><br /><sub><b>zaryab-ali</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=zaryab-ali" title="Code">💻</a></td>
248
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/Lex-Ashu"><img src="https://avatars.githubusercontent.com/u/181084934?v=4?s=100" width="100px;" alt="Lex-Ashu"/><br /><sub><b>Lex-Ashu</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=Lex-Ashu" title="Code">💻</a></td>
254
249
  </tr>
255
250
  </tbody>
256
251
  </table>
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.22
3
+ Version: 5.3.1
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
- Requires-Python: >=3.11
6
+ Requires-Python: <3.14,>=3.11
7
7
  Description-Content-Type: text/markdown
8
8
  License-File: LICENSE
9
- Requires-Dist: ocf-data-sampler>=0.5.20
9
+ Requires-Dist: ocf-data-sampler>=0.6.0
10
10
  Requires-Dist: numpy
11
11
  Requires-Dist: pandas
12
12
  Requires-Dist: matplotlib
@@ -29,7 +29,7 @@ Dynamic: license-file
29
29
 
30
30
  # PVNet
31
31
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
32
- [![All Contributors](https://img.shields.io/badge/all_contributors-20-orange.svg?style=flat-square)](#contributors-)
32
+ [![All Contributors](https://img.shields.io/badge/all_contributors-21-orange.svg?style=flat-square)](#contributors-)
33
33
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
34
34
 
35
35
  [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/PVNet?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/PVNet/tags)
@@ -142,12 +142,6 @@ pip install -e <PATH-TO-ocf-data-sampler-REPO>
142
142
  If you install the local version of `ocf-data-sampler` that is more recent than the version
143
143
  specified in `PVNet` it is not guaranteed to function properly with this library.
144
144
 
145
- ## Streaming samples (no pre-save)
146
-
147
- PVNet now trains and validates directly from **streamed_samples** (i.e. no pre-saving to disk).
148
-
149
- Make sure you have copied example configs (as already stated above):
150
- cp -r configs.example configs
151
145
 
152
146
  ### Set up and config example for streaming
153
147
 
@@ -251,6 +245,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
251
245
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/markus-kreft"><img src="https://avatars.githubusercontent.com/u/129367085?v=4?s=100" width="100px;" alt="Markus Kreft"/><br /><sub><b>Markus Kreft</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=markus-kreft" title="Code">💻</a></td>
252
246
  <td align="center" valign="top" width="14.28%"><a href="http://jack-kelly.com"><img src="https://avatars.githubusercontent.com/u/460756?v=4?s=100" width="100px;" alt="Jack Kelly"/><br /><sub><b>Jack Kelly</b></sub></a><br /><a href="#ideas-JackKelly" title="Ideas, Planning, & Feedback">🤔</a></td>
253
247
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/zaryab-ali"><img src="https://avatars.githubusercontent.com/u/85732412?v=4?s=100" width="100px;" alt="zaryab-ali"/><br /><sub><b>zaryab-ali</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=zaryab-ali" title="Code">💻</a></td>
248
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/Lex-Ashu"><img src="https://avatars.githubusercontent.com/u/181084934?v=4?s=100" width="100px;" alt="Lex-Ashu"/><br /><sub><b>Lex-Ashu</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=Lex-Ashu" title="Code">💻</a></td>
254
249
  </tr>
255
250
  </tbody>
256
251
  </table>
@@ -7,13 +7,10 @@ PVNet.egg-info/dependency_links.txt
7
7
  PVNet.egg-info/requires.txt
8
8
  PVNet.egg-info/top_level.txt
9
9
  pvnet/__init__.py
10
+ pvnet/datamodule.py
10
11
  pvnet/load_model.py
11
12
  pvnet/optimizers.py
12
13
  pvnet/utils.py
13
- pvnet/data/__init__.py
14
- pvnet/data/base_datamodule.py
15
- pvnet/data/site_datamodule.py
16
- pvnet/data/uk_regional_datamodule.py
17
14
  pvnet/models/__init__.py
18
15
  pvnet/models/base_model.py
19
16
  pvnet/models/ensemble.py
@@ -33,4 +30,5 @@ pvnet/training/__init__.py
33
30
  pvnet/training/lightning_module.py
34
31
  pvnet/training/plots.py
35
32
  pvnet/training/train.py
33
+ tests/test_datamodule.py
36
34
  tests/test_end2end.py
@@ -1,4 +1,4 @@
1
- ocf-data-sampler>=0.5.20
1
+ ocf-data-sampler>=0.6.0
2
2
  numpy
3
3
  pandas
4
4
  matplotlib
@@ -1,6 +1,6 @@
1
1
  # PVNet
2
2
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
3
- [![All Contributors](https://img.shields.io/badge/all_contributors-20-orange.svg?style=flat-square)](#contributors-)
3
+ [![All Contributors](https://img.shields.io/badge/all_contributors-21-orange.svg?style=flat-square)](#contributors-)
4
4
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
5
5
 
6
6
  [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/PVNet?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/PVNet/tags)
@@ -113,12 +113,6 @@ pip install -e <PATH-TO-ocf-data-sampler-REPO>
113
113
  If you install the local version of `ocf-data-sampler` that is more recent than the version
114
114
  specified in `PVNet` it is not guaranteed to function properly with this library.
115
115
 
116
- ## Streaming samples (no pre-save)
117
-
118
- PVNet now trains and validates directly from **streamed_samples** (i.e. no pre-saving to disk).
119
-
120
- Make sure you have copied example configs (as already stated above):
121
- cp -r configs.example configs
122
116
 
123
117
  ### Set up and config example for streaming
124
118
 
@@ -222,6 +216,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
222
216
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/markus-kreft"><img src="https://avatars.githubusercontent.com/u/129367085?v=4?s=100" width="100px;" alt="Markus Kreft"/><br /><sub><b>Markus Kreft</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=markus-kreft" title="Code">💻</a></td>
223
217
  <td align="center" valign="top" width="14.28%"><a href="http://jack-kelly.com"><img src="https://avatars.githubusercontent.com/u/460756?v=4?s=100" width="100px;" alt="Jack Kelly"/><br /><sub><b>Jack Kelly</b></sub></a><br /><a href="#ideas-JackKelly" title="Ideas, Planning, & Feedback">🤔</a></td>
224
218
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/zaryab-ali"><img src="https://avatars.githubusercontent.com/u/85732412?v=4?s=100" width="100px;" alt="zaryab-ali"/><br /><sub><b>zaryab-ali</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=zaryab-ali" title="Code">💻</a></td>
219
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/Lex-Ashu"><img src="https://avatars.githubusercontent.com/u/181084934?v=4?s=100" width="100px;" alt="Lex-Ashu"/><br /><sub><b>Lex-Ashu</b></sub></a><br /><a href="https://github.com/openclimatefix/pvnet/commits?author=Lex-Ashu" title="Code">💻</a></td>
225
220
  </tr>
226
221
  </tbody>
227
222
  </table>
@@ -1,14 +1,14 @@
1
- """ Data module for pytorch lightning """
1
+ """Data module for pytorch lightning"""
2
2
 
3
3
  import os
4
- from glob import glob
5
4
 
6
5
  import numpy as np
7
6
  from lightning.pytorch import LightningDataModule
8
7
  from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
9
8
  from ocf_data_sampler.numpy_sample.common_types import NumpySample, TensorBatch
10
- from ocf_data_sampler.torch_datasets.sample.base import SampleBase, batch_to_tensor
11
- from torch.utils.data import DataLoader, Dataset, Subset
9
+ from ocf_data_sampler.torch_datasets.pvnet_dataset import PVNetDataset
10
+ from ocf_data_sampler.torch_datasets.utils.torch_batch_utils import batch_to_tensor
11
+ from torch.utils.data import DataLoader, Subset
12
12
 
13
13
 
14
14
  def collate_fn(samples: list[NumpySample]) -> TensorBatch:
@@ -16,87 +16,8 @@ def collate_fn(samples: list[NumpySample]) -> TensorBatch:
16
16
  return batch_to_tensor(stack_np_samples_into_batch(samples))
17
17
 
18
18
 
19
- class PresavedSamplesDataset(Dataset):
20
- """Dataset of pre-saved samples
21
-
22
- Args:
23
- sample_dir: Path to the directory of pre-saved samples.
24
- sample_class: sample class type to use for save/load/to_numpy
25
- """
26
-
27
- def __init__(self, sample_dir: str, sample_class: SampleBase):
28
- """Initialise PresavedSamplesDataset"""
29
- self.sample_paths = glob(f"{sample_dir}/*")
30
- self.sample_class = sample_class
31
-
32
- def __len__(self) -> int:
33
- return len(self.sample_paths)
34
-
35
- def __getitem__(self, idx) -> NumpySample:
36
- sample = self.sample_class.load(self.sample_paths[idx])
37
- return sample.to_numpy()
38
-
39
-
40
- class BasePresavedDataModule(LightningDataModule):
41
- """Base Datamodule for loading pre-saved samples."""
42
-
43
- def __init__(
44
- self,
45
- sample_dir: str,
46
- batch_size: int = 16,
47
- num_workers: int = 0,
48
- prefetch_factor: int | None = None,
49
- persistent_workers: bool = False,
50
- pin_memory: bool = False,
51
- ):
52
- """Base Datamodule for loading pre-saved samples
53
-
54
- Args:
55
- sample_dir: Path to the directory of pre-saved samples.
56
- batch_size: Batch size.
57
- num_workers: Number of workers to use in multiprocess batch loading.
58
- prefetch_factor: Number of batches loaded in advance by each worker.
59
- persistent_workers: If True, the data loader will not shut down the worker processes
60
- after a dataset has been consumed once. This allows to maintain the workers Dataset
61
- instances alive.
62
- pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
63
- before returning them.
64
- """
65
- super().__init__()
66
-
67
- self.sample_dir = sample_dir
68
-
69
- self._common_dataloader_kwargs = dict(
70
- batch_size=batch_size,
71
- sampler=None,
72
- batch_sampler=None,
73
- num_workers=num_workers,
74
- collate_fn=collate_fn,
75
- pin_memory=pin_memory,
76
- drop_last=False,
77
- timeout=0,
78
- worker_init_fn=None,
79
- prefetch_factor=prefetch_factor,
80
- persistent_workers=persistent_workers,
81
- multiprocessing_context="spawn" if num_workers>0 else None,
82
- )
83
-
84
- def _get_premade_samples_dataset(self, subdir: str) -> Dataset:
85
- raise NotImplementedError
86
-
87
- def train_dataloader(self) -> DataLoader:
88
- """Construct train dataloader"""
89
- dataset = self._get_premade_samples_dataset("train")
90
- return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)
91
-
92
- def val_dataloader(self) -> DataLoader:
93
- """Construct val dataloader"""
94
- dataset = self._get_premade_samples_dataset("val")
95
- return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
96
-
97
-
98
- class BaseStreamedDataModule(LightningDataModule):
99
- """Base Datamodule which streams samples using a sampler for ocf-data-sampler."""
19
+ class PVNetDataModule(LightningDataModule):
20
+ """Base Datamodule which streams samples using a sampler from ocf-data-sampler."""
100
21
 
101
22
  def __init__(
102
23
  self,
@@ -118,10 +39,10 @@ class BaseStreamedDataModule(LightningDataModule):
118
39
  batch_size: Batch size.
119
40
  num_workers: Number of workers to use in multiprocess batch loading.
120
41
  prefetch_factor: Number of batches loaded in advance by each worker.
121
- persistent_workers: If True, the data loader will not shut down the worker processes
122
- after a dataset has been consumed once. This allows to maintain the workers Dataset
42
+ persistent_workers: If True, the data loader will not shut down the worker processes
43
+ after a dataset has been consumed once. This allows to maintain the workers Dataset
123
44
  instances alive.
124
- pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
45
+ pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
125
46
  before returning them.
126
47
  train_period: Date range filter for train dataloader.
127
48
  val_period: Date range filter for val dataloader.
@@ -148,7 +69,7 @@ class BaseStreamedDataModule(LightningDataModule):
148
69
  worker_init_fn=None,
149
70
  prefetch_factor=prefetch_factor,
150
71
  persistent_workers=persistent_workers,
151
- multiprocessing_context="spawn" if num_workers>0 else None,
72
+ multiprocessing_context="spawn" if num_workers > 0 else None,
152
73
  )
153
74
 
154
75
  def setup(self, stage: str | None = None):
@@ -157,16 +78,15 @@ class BaseStreamedDataModule(LightningDataModule):
157
78
  # This logic runs only once at the start of training, therefore the val dataset is only
158
79
  # shuffled once
159
80
  if stage == "fit":
160
-
161
81
  # Prepare the train dataset
162
- self.train_dataset = self._get_streamed_samples_dataset(*self.train_period)
82
+ self.train_dataset = self._get_dataset(*self.train_period)
163
83
 
164
- # Prepare and pre-shuffle the val dataset and set seed for reproducibility
165
- val_dataset = self._get_streamed_samples_dataset(*self.val_period)
84
+ # Prepare and pre-shuffle the val dataset and set seed for reproducibility
85
+ val_dataset = self._get_dataset(*self.val_period)
166
86
 
167
87
  shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
168
88
  self.val_dataset = Subset(val_dataset, shuffled_indices)
169
-
89
+
170
90
  if self.dataset_pickle_dir is not None:
171
91
  os.makedirs(self.dataset_pickle_dir, exist_ok=True)
172
92
  train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
@@ -194,12 +114,8 @@ class BaseStreamedDataModule(LightningDataModule):
194
114
  if os.path.exists(filepath):
195
115
  os.remove(filepath)
196
116
 
197
- def _get_streamed_samples_dataset(
198
- self,
199
- start_time: str | None,
200
- end_time: str | None
201
- ) -> Dataset:
202
- raise NotImplementedError
117
+ def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetDataset:
118
+ return PVNetDataset(self.configuration, start_time=start_time, end_time=end_time)
203
119
 
204
120
  def train_dataloader(self) -> DataLoader:
205
121
  """Construct train dataloader"""
@@ -73,6 +73,8 @@ def get_model_from_checkpoints(
73
73
  else:
74
74
  raise FileNotFoundError(f"File {data_config} does not exist")
75
75
 
76
+ # TODO: This should be removed in a future release since no new models will be trained on
77
+ # presaved samples
76
78
  # Check for datamodule config
77
79
  # This only exists if the model was trained with presaved samples
78
80
  datamodule_config = f"{path}/{DATAMODULE_CONFIG_NAME}"
@@ -1,4 +1,5 @@
1
1
  """Base model for all PVNet submodels"""
2
+
2
3
  import logging
3
4
  import os
4
5
  import shutil
@@ -32,7 +33,7 @@ def fill_config_paths_with_placeholder(config: dict, placeholder: str = "PLACEHO
32
33
  """
33
34
  input_config = config["input_data"]
34
35
 
35
- for source in ["gsp", "satellite"]:
36
+ for source in ["generation", "satellite"]:
36
37
  if source in input_config:
37
38
  # If not empty - i.e. if used
38
39
  if input_config[source]["zarr_path"] != "":
@@ -78,8 +79,8 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:
78
79
 
79
80
  # Replace the interval_end_minutes minutes
80
81
  nwp_config["interval_end_minutes"] = (
81
- nwp_config["interval_start_minutes"] +
82
- (model.nwp_encoders_dict[nwp_source].sequence_length - 1)
82
+ nwp_config["interval_start_minutes"]
83
+ + (model.nwp_encoders_dict[nwp_source].sequence_length - 1)
83
84
  * nwp_config["time_resolution_minutes"]
84
85
  )
85
86
 
@@ -96,20 +97,19 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:
96
97
 
97
98
  # Replace the interval_end_minutes minutes
98
99
  sat_config["interval_end_minutes"] = (
99
- sat_config["interval_start_minutes"] +
100
- (model.sat_encoder.sequence_length - 1)
101
- * sat_config["time_resolution_minutes"]
100
+ sat_config["interval_start_minutes"]
101
+ + (model.sat_encoder.sequence_length - 1) * sat_config["time_resolution_minutes"]
102
102
  )
103
103
 
104
104
  if "pv" in input_config:
105
105
  if not model.include_pv:
106
106
  del input_config["pv"]
107
107
 
108
- if "gsp" in input_config:
109
- gsp_config = input_config["gsp"]
108
+ if "generation" in input_config:
109
+ generation_config = input_config["generation"]
110
110
 
111
111
  # Replace the forecast minutes
112
- gsp_config["interval_end_minutes"] = model.forecast_minutes
112
+ generation_config["interval_end_minutes"] = model.forecast_minutes
113
113
 
114
114
  if "solar_position" in input_config:
115
115
  solar_config = input_config["solar_position"]
@@ -138,9 +138,9 @@ def download_from_hf(
138
138
  force_download: Whether to force a new download
139
139
  max_retries: Maximum number of retry attempts
140
140
  wait_time: Wait time (in seconds) before retrying
141
- token:
141
+ token:
142
142
  HF authentication token. If True, the token is read from the HuggingFace config folder.
143
- If a string, it is used as the authentication token.
143
+ If a string, it is used as the authentication token.
144
144
 
145
145
  Returns:
146
146
  The local file path of the downloaded file(s)
@@ -160,7 +160,7 @@ def download_from_hf(
160
160
  return [f"{save_dir}/{f}" for f in filename]
161
161
  else:
162
162
  return f"{save_dir}/{filename}"
163
-
163
+
164
164
  except Exception as e:
165
165
  if attempt == max_retries:
166
166
  raise Exception(
@@ -205,7 +205,7 @@ class HuggingfaceMixin:
205
205
  force_download=force_download,
206
206
  max_retries=5,
207
207
  wait_time=10,
208
- token=token
208
+ token=token,
209
209
  )
210
210
 
211
211
  with open(config_file, "r") as f:
@@ -240,7 +240,7 @@ class HuggingfaceMixin:
240
240
  force_download=force_download,
241
241
  max_retries=5,
242
242
  wait_time=10,
243
- token=token
243
+ token=token,
244
244
  )
245
245
 
246
246
  return data_config_file
@@ -301,7 +301,7 @@ class HuggingfaceMixin:
301
301
  # Save cleaned version of input data configuration file
302
302
  with open(data_config_path) as cfg:
303
303
  config = yaml.load(cfg, Loader=yaml.FullLoader)
304
-
304
+
305
305
  config = fill_config_paths_with_placeholder(config)
306
306
  config = minimize_config_for_model(config, self)
307
307
 
@@ -311,7 +311,7 @@ class HuggingfaceMixin:
311
311
  # Save the datamodule config
312
312
  if datamodule_config_path is not None:
313
313
  shutil.copyfile(datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME)
314
-
314
+
315
315
  # Save the full experimental config
316
316
  if experiment_config_path is not None:
317
317
  shutil.copyfile(experiment_config_path, save_directory / FULL_CONFIG_NAME)
@@ -378,7 +378,6 @@ class HuggingfaceMixin:
378
378
  packages_to_display = ["pvnet", "ocf-data-sampler"]
379
379
  packages_and_versions = {package: version(package) for package in packages_to_display}
380
380
 
381
-
382
381
  package_versions_markdown = ""
383
382
  for package, v in packages_and_versions.items():
384
383
  package_versions_markdown += f" - {package}=={v}\n"
@@ -399,23 +398,19 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
399
398
  history_minutes: int,
400
399
  forecast_minutes: int,
401
400
  output_quantiles: list[float] | None = None,
402
- target_key: str = "gsp",
403
401
  interval_minutes: int = 30,
404
402
  ):
405
403
  """Abstract base class for PVNet submodels.
406
404
 
407
405
  Args:
408
- history_minutes (int): Length of the GSP history period in minutes
409
- forecast_minutes (int): Length of the GSP forecast period in minutes
406
+ history_minutes (int): Length of the generation history period in minutes
407
+ forecast_minutes (int): Length of the generation forecast period in minutes
410
408
  output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
411
409
  None the output is a single value.
412
- target_key: The key of the target variable in the batch
413
410
  interval_minutes: The interval in minutes between each timestep in the data
414
411
  """
415
412
  super().__init__()
416
413
 
417
- self._target_key = target_key
418
-
419
414
  self.history_minutes = history_minutes
420
415
  self.forecast_minutes = forecast_minutes
421
416
  self.output_quantiles = output_quantiles
@@ -26,7 +26,6 @@ class Ensemble(BaseModel):
26
26
  output_quantiles = []
27
27
  history_minutes = []
28
28
  forecast_minutes = []
29
- target_key = []
30
29
  interval_minutes = []
31
30
 
32
31
  # Get some model properties from each model
@@ -34,7 +33,6 @@ class Ensemble(BaseModel):
34
33
  output_quantiles.append(model.output_quantiles)
35
34
  history_minutes.append(model.history_minutes)
36
35
  forecast_minutes.append(model.forecast_minutes)
37
- target_key.append(model._target_key)
38
36
  interval_minutes.append(model.interval_minutes)
39
37
 
40
38
  # Check these properties are all the same
@@ -42,7 +40,6 @@ class Ensemble(BaseModel):
42
40
  output_quantiles,
43
41
  history_minutes,
44
42
  forecast_minutes,
45
- target_key,
46
43
  interval_minutes,
47
44
  ]:
48
45
  assert all([p == param_list[0] for p in param_list]), param_list
@@ -51,7 +48,6 @@ class Ensemble(BaseModel):
51
48
  history_minutes=history_minutes[0],
52
49
  forecast_minutes=forecast_minutes[0],
53
50
  output_quantiles=output_quantiles[0],
54
- target_key=target_key[0],
55
51
  interval_minutes=interval_minutes[0],
56
52
  )
57
53