PVNet 5.0.7__py3-none-any.whl → 5.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pvnet/data/base_datamodule.py +34 -1
- pvnet/training/lightning_module.py +12 -7
- {pvnet-5.0.7.dist-info → pvnet-5.0.9.dist-info}/METADATA +3 -3
- {pvnet-5.0.7.dist-info → pvnet-5.0.9.dist-info}/RECORD +7 -7
- {pvnet-5.0.7.dist-info → pvnet-5.0.9.dist-info}/WHEEL +0 -0
- {pvnet-5.0.7.dist-info → pvnet-5.0.9.dist-info}/licenses/LICENSE +0 -0
- {pvnet-5.0.7.dist-info → pvnet-5.0.9.dist-info}/top_level.txt +0 -0
pvnet/data/base_datamodule.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
""" Data module for pytorch lightning """
|
|
2
2
|
|
|
3
|
+
import os
|
|
3
4
|
from glob import glob
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
@@ -77,6 +78,7 @@ class BasePresavedDataModule(LightningDataModule):
|
|
|
77
78
|
worker_init_fn=None,
|
|
78
79
|
prefetch_factor=prefetch_factor,
|
|
79
80
|
persistent_workers=persistent_workers,
|
|
81
|
+
multiprocessing_context="spawn" if num_workers>0 else None,
|
|
80
82
|
)
|
|
81
83
|
|
|
82
84
|
def _get_premade_samples_dataset(self, subdir: str) -> Dataset:
|
|
@@ -107,7 +109,7 @@ class BaseStreamedDataModule(LightningDataModule):
|
|
|
107
109
|
train_period: list[str | None] = [None, None],
|
|
108
110
|
val_period: list[str | None] = [None, None],
|
|
109
111
|
seed: int | None = None,
|
|
110
|
-
|
|
112
|
+
dataset_pickle_dir: str | None = None,
|
|
111
113
|
):
|
|
112
114
|
"""Base Datamodule for streaming samples.
|
|
113
115
|
|
|
@@ -124,6 +126,8 @@ class BaseStreamedDataModule(LightningDataModule):
|
|
|
124
126
|
train_period: Date range filter for train dataloader.
|
|
125
127
|
val_period: Date range filter for val dataloader.
|
|
126
128
|
seed: Random seed used in shuffling datasets.
|
|
129
|
+
dataset_pickle_dir: Directory in which the val and train set will be presaved as
|
|
130
|
+
pickle objects. Setting this speeds up instantiation of multiple workers a lot.
|
|
127
131
|
"""
|
|
128
132
|
super().__init__()
|
|
129
133
|
|
|
@@ -131,6 +135,7 @@ class BaseStreamedDataModule(LightningDataModule):
|
|
|
131
135
|
self.train_period = train_period
|
|
132
136
|
self.val_period = val_period
|
|
133
137
|
self.seed = seed
|
|
138
|
+
self.dataset_pickle_dir = dataset_pickle_dir
|
|
134
139
|
|
|
135
140
|
self._common_dataloader_kwargs = dict(
|
|
136
141
|
batch_size=batch_size,
|
|
@@ -143,6 +148,7 @@ class BaseStreamedDataModule(LightningDataModule):
|
|
|
143
148
|
worker_init_fn=None,
|
|
144
149
|
prefetch_factor=prefetch_factor,
|
|
145
150
|
persistent_workers=persistent_workers,
|
|
151
|
+
multiprocessing_context="spawn" if num_workers>0 else None,
|
|
146
152
|
)
|
|
147
153
|
|
|
148
154
|
def setup(self, stage: str | None = None):
|
|
@@ -160,6 +166,33 @@ class BaseStreamedDataModule(LightningDataModule):
|
|
|
160
166
|
|
|
161
167
|
shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
|
|
162
168
|
self.val_dataset = Subset(val_dataset, shuffled_indices)
|
|
169
|
+
|
|
170
|
+
if self.dataset_pickle_dir is not None:
|
|
171
|
+
os.makedirs(self.dataset_pickle_dir, exist_ok=True)
|
|
172
|
+
train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
|
|
173
|
+
val_dataset_path = f"{self.dataset_pickle_dir}/val_dataset.pkl"
|
|
174
|
+
|
|
175
|
+
# For safety, these pickled datasets cannot be overwritten.
|
|
176
|
+
# See: https://github.com/openclimatefix/pvnet/pull/445
|
|
177
|
+
for path in [train_dataset_path, val_dataset_path]:
|
|
178
|
+
if os.path.exists(path):
|
|
179
|
+
raise FileExistsError(
|
|
180
|
+
f"The pickled dataset path '{path}' already exists. Make sure that "
|
|
181
|
+
"this can be safely deleted (i.e. not currently being used by any "
|
|
182
|
+
"training run) and delete it manually. Else change the "
|
|
183
|
+
"`dataset_pickle_dir` to a different directory."
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
self.train_dataset.presave_pickle(train_dataset_path)
|
|
187
|
+
self.train_dataset.presave_pickle(val_dataset_path)
|
|
188
|
+
|
|
189
|
+
def teardown(self, stage: str | None = None) -> None:
|
|
190
|
+
"""Clean up the pickled datasets"""
|
|
191
|
+
if self.dataset_pickle_dir is not None:
|
|
192
|
+
for filename in ["val_dataset.pkl", "train_dataset.pkl"]:
|
|
193
|
+
filepath = f"{self.dataset_pickle_dir}/{filename}"
|
|
194
|
+
if os.path.exists(filepath):
|
|
195
|
+
os.remove(filepath)
|
|
163
196
|
|
|
164
197
|
def _get_streamed_samples_dataset(
|
|
165
198
|
self,
|
|
@@ -105,8 +105,9 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
105
105
|
"""Run training step"""
|
|
106
106
|
y_hat = self.model(batch)
|
|
107
107
|
|
|
108
|
-
# Batch
|
|
109
|
-
|
|
108
|
+
# Batch may be adapted in the model forward method, would need adapting here too
|
|
109
|
+
if self.model.adapt_batches:
|
|
110
|
+
batch = self.model._adapt_batch(batch)
|
|
110
111
|
|
|
111
112
|
y = batch[self.model._target_key][:, -self.model.forecast_len :]
|
|
112
113
|
|
|
@@ -211,8 +212,9 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
211
212
|
with torch.no_grad():
|
|
212
213
|
y_hat = self.model(batch)
|
|
213
214
|
|
|
214
|
-
# Batch
|
|
215
|
-
|
|
215
|
+
# Batch may be adapted in the model forward method, would need adapting here too
|
|
216
|
+
if self.model.adapt_batches:
|
|
217
|
+
batch = self.model._adapt_batch(batch)
|
|
216
218
|
|
|
217
219
|
fig = plot_sample_forecasts(
|
|
218
220
|
batch,
|
|
@@ -223,7 +225,9 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
223
225
|
|
|
224
226
|
plot_name = f"val_forecast_samples/sample_set_{plot_num}"
|
|
225
227
|
|
|
226
|
-
|
|
228
|
+
# Disabled for testing or using no logger
|
|
229
|
+
if self.logger:
|
|
230
|
+
self.logger.experiment.log({plot_name: wandb.Image(fig)})
|
|
227
231
|
|
|
228
232
|
plt.close(fig)
|
|
229
233
|
|
|
@@ -231,8 +235,9 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
231
235
|
"""Run validation step"""
|
|
232
236
|
|
|
233
237
|
y_hat = self.model(batch)
|
|
234
|
-
# Batch
|
|
235
|
-
|
|
238
|
+
# Batch may be adapted in the model forward method, would need adapting here too
|
|
239
|
+
if self.model.adapt_batches:
|
|
240
|
+
batch = self.model._adapt_batch(batch)
|
|
236
241
|
|
|
237
242
|
# Internally store the val predictions
|
|
238
243
|
self._store_val_predictions(batch, y_hat)
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: PVNet
|
|
3
|
-
Version: 5.0.7
|
|
3
|
+
Version: 5.0.9
|
|
4
4
|
Summary: PVNet
|
|
5
5
|
Author-email: Peter Dudfield <info@openclimatefix.org>
|
|
6
|
-
Requires-Python: >=3.
|
|
6
|
+
Requires-Python: >=3.11
|
|
7
7
|
Description-Content-Type: text/markdown
|
|
8
8
|
License-File: LICENSE
|
|
9
|
-
Requires-Dist: ocf-data-sampler>=0.
|
|
9
|
+
Requires-Dist: ocf-data-sampler>=0.5.20
|
|
10
10
|
Requires-Dist: numpy
|
|
11
11
|
Requires-Dist: pandas
|
|
12
12
|
Requires-Dist: matplotlib
|
|
@@ -3,7 +3,7 @@ pvnet/load_model.py,sha256=LzN06O3oXzqhj1Dh_VlschDTxOq_Eea0OWDxrrboSKw,3726
|
|
|
3
3
|
pvnet/optimizers.py,sha256=1N4b-Xd6QiIrcUU8cbU326bbFC0BvMNIV8VYWtGILJc,6548
|
|
4
4
|
pvnet/utils.py,sha256=h4w9nmx6V_IAtiRp6VQ90TmQZGGTdMzU63WfGmL-pPs,2666
|
|
5
5
|
pvnet/data/__init__.py,sha256=FFD2tkLwEw9YiAVDam3tmaXNWMKiKVMHcnIz7zXCtrg,191
|
|
6
|
-
pvnet/data/base_datamodule.py,sha256=
|
|
6
|
+
pvnet/data/base_datamodule.py,sha256=Ibz0RoSr15HT6tMCs6ftXTpMa-NOKAmEd5ky55MqEK0,8615
|
|
7
7
|
pvnet/data/site_datamodule.py,sha256=-KGxirGCBXVwcCREsjFkF7JDfa6NICv8bBDV6EILF_Q,962
|
|
8
8
|
pvnet/data/uk_regional_datamodule.py,sha256=KA2_7DYuSggmD5b-XiXshXq8xmu36BjtFmy_pS7e4QE,1017
|
|
9
9
|
pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
|
|
@@ -22,11 +22,11 @@ pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4T
|
|
|
22
22
|
pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
|
|
23
23
|
pvnet/models/late_fusion/site_encoders/encoders.py,sha256=k4z690cfcP6J4pm2KtDujHN-W3uOl7QY0WvBIu1tM8c,11703
|
|
24
24
|
pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
|
|
25
|
-
pvnet/training/lightning_module.py,sha256=
|
|
25
|
+
pvnet/training/lightning_module.py,sha256=GVqdi5ALFo9-_WRYeyMMj2qH_k4gPxQ2sG6FhL_wRFE,13242
|
|
26
26
|
pvnet/training/plots.py,sha256=4xID7TBA4IazaARaCN5AoG5fFPJF1wIprn0y6I0C31c,2469
|
|
27
27
|
pvnet/training/train.py,sha256=zj9JMi9C6W68vGsQUBapWkJ4aDzDuJFMv0IVjO73s1k,5215
|
|
28
|
-
pvnet-5.0.7.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
|
|
29
|
-
pvnet-5.0.
|
|
30
|
-
pvnet-5.0.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
31
|
-
pvnet-5.0.7.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
|
|
32
|
-
pvnet-5.0.7.dist-info/RECORD,,
|
|
28
|
+
pvnet-5.0.9.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
|
|
29
|
+
pvnet-5.0.9.dist-info/METADATA,sha256=Z909YP0GU68upJxamWI1BlHI4vZjbSUAAlEnIt2V_fc,18043
|
|
30
|
+
pvnet-5.0.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
31
|
+
pvnet-5.0.9.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
|
|
32
|
+
pvnet-5.0.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|