PVNet 5.0.7__py3-none-any.whl → 5.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
1
  """ Data module for pytorch lightning """
2
2
 
3
+ import os
3
4
  from glob import glob
4
5
 
5
6
  import numpy as np
@@ -77,6 +78,7 @@ class BasePresavedDataModule(LightningDataModule):
77
78
  worker_init_fn=None,
78
79
  prefetch_factor=prefetch_factor,
79
80
  persistent_workers=persistent_workers,
81
+ multiprocessing_context="spawn" if num_workers>0 else None,
80
82
  )
81
83
 
82
84
  def _get_premade_samples_dataset(self, subdir: str) -> Dataset:
@@ -107,7 +109,7 @@ class BaseStreamedDataModule(LightningDataModule):
107
109
  train_period: list[str | None] = [None, None],
108
110
  val_period: list[str | None] = [None, None],
109
111
  seed: int | None = None,
110
-
112
+ dataset_pickle_dir: str | None = None,
111
113
  ):
112
114
  """Base Datamodule for streaming samples.
113
115
 
@@ -124,6 +126,8 @@ class BaseStreamedDataModule(LightningDataModule):
124
126
  train_period: Date range filter for train dataloader.
125
127
  val_period: Date range filter for val dataloader.
126
128
  seed: Random seed used in shuffling datasets.
129
+ dataset_pickle_dir: Directory in which the val and train set will be presaved as
130
+ pickle objects. Setting this speeds up instantiation of multiple workers a lot.
127
131
  """
128
132
  super().__init__()
129
133
 
@@ -131,6 +135,7 @@ class BaseStreamedDataModule(LightningDataModule):
131
135
  self.train_period = train_period
132
136
  self.val_period = val_period
133
137
  self.seed = seed
138
+ self.dataset_pickle_dir = dataset_pickle_dir
134
139
 
135
140
  self._common_dataloader_kwargs = dict(
136
141
  batch_size=batch_size,
@@ -143,6 +148,7 @@ class BaseStreamedDataModule(LightningDataModule):
143
148
  worker_init_fn=None,
144
149
  prefetch_factor=prefetch_factor,
145
150
  persistent_workers=persistent_workers,
151
+ multiprocessing_context="spawn" if num_workers>0 else None,
146
152
  )
147
153
 
148
154
  def setup(self, stage: str | None = None):
@@ -160,6 +166,33 @@ class BaseStreamedDataModule(LightningDataModule):
160
166
 
161
167
  shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset))
162
168
  self.val_dataset = Subset(val_dataset, shuffled_indices)
169
+
170
+ if self.dataset_pickle_dir is not None:
171
+ os.makedirs(self.dataset_pickle_dir, exist_ok=True)
172
+ train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl"
173
+ val_dataset_path = f"{self.dataset_pickle_dir}/val_dataset.pkl"
174
+
175
+ # For safety, these pickled datasets cannot be overwritten.
176
+ # See: https://github.com/openclimatefix/pvnet/pull/445
177
+ for path in [train_dataset_path, val_dataset_path]:
178
+ if os.path.exists(path):
179
+ raise FileExistsError(
180
+ f"The pickled dataset path '{path}' already exists. Make sure that "
181
+ "this can be safely deleted (i.e. not currently being used by any "
182
+ "training run) and delete it manually. Else change the "
183
+ "`dataset_pickle_dir` to a different directory."
184
+ )
185
+
186
+ self.train_dataset.presave_pickle(train_dataset_path)
187
+ self.train_dataset.presave_pickle(val_dataset_path)
188
+
189
+ def teardown(self, stage: str | None = None) -> None:
190
+ """Clean up the pickled datasets"""
191
+ if self.dataset_pickle_dir is not None:
192
+ for filename in ["val_dataset.pkl", "train_dataset.pkl"]:
193
+ filepath = f"{self.dataset_pickle_dir}/{filename}"
194
+ if os.path.exists(filepath):
195
+ os.remove(filepath)
163
196
 
164
197
  def _get_streamed_samples_dataset(
165
198
  self,
@@ -105,8 +105,9 @@ class PVNetLightningModule(pl.LightningModule):
105
105
  """Run training step"""
106
106
  y_hat = self.model(batch)
107
107
 
108
- # Batch is adapted in the model forward method, but needs to be adapted here too
109
- batch = self.model._adapt_batch(batch)
108
+ # Batch may be adapted in the model forward method, would need adapting here too
109
+ if self.model.adapt_batches:
110
+ batch = self.model._adapt_batch(batch)
110
111
 
111
112
  y = batch[self.model._target_key][:, -self.model.forecast_len :]
112
113
 
@@ -211,8 +212,9 @@ class PVNetLightningModule(pl.LightningModule):
211
212
  with torch.no_grad():
212
213
  y_hat = self.model(batch)
213
214
 
214
- # Batch is adapted in the model forward method, but needs to be adapted here too
215
- batch = self.model._adapt_batch(batch)
215
+ # Batch may be adapted in the model forward method, would need adapting here too
216
+ if self.model.adapt_batches:
217
+ batch = self.model._adapt_batch(batch)
216
218
 
217
219
  fig = plot_sample_forecasts(
218
220
  batch,
@@ -223,7 +225,9 @@ class PVNetLightningModule(pl.LightningModule):
223
225
 
224
226
  plot_name = f"val_forecast_samples/sample_set_{plot_num}"
225
227
 
226
- self.logger.experiment.log({plot_name: wandb.Image(fig)})
228
+ # Disabled for testing or using no logger
229
+ if self.logger:
230
+ self.logger.experiment.log({plot_name: wandb.Image(fig)})
227
231
 
228
232
  plt.close(fig)
229
233
 
@@ -231,8 +235,9 @@ class PVNetLightningModule(pl.LightningModule):
231
235
  """Run validation step"""
232
236
 
233
237
  y_hat = self.model(batch)
234
- # Batch is adapted in the model forward method, but needs to be adapted here too
235
- batch = self.model._adapt_batch(batch)
238
+ # Batch may be adapted in the model forward method, would need adapting here too
239
+ if self.model.adapt_batches:
240
+ batch = self.model._adapt_batch(batch)
236
241
 
237
242
  # Internally store the val predictions
238
243
  self._store_val_predictions(batch, y_hat)
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.7
3
+ Version: 5.0.9
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
- Requires-Python: >=3.10
6
+ Requires-Python: >=3.11
7
7
  Description-Content-Type: text/markdown
8
8
  License-File: LICENSE
9
- Requires-Dist: ocf-data-sampler>=0.2.32
9
+ Requires-Dist: ocf-data-sampler>=0.5.20
10
10
  Requires-Dist: numpy
11
11
  Requires-Dist: pandas
12
12
  Requires-Dist: matplotlib
@@ -3,7 +3,7 @@ pvnet/load_model.py,sha256=LzN06O3oXzqhj1Dh_VlschDTxOq_Eea0OWDxrrboSKw,3726
3
3
  pvnet/optimizers.py,sha256=1N4b-Xd6QiIrcUU8cbU326bbFC0BvMNIV8VYWtGILJc,6548
4
4
  pvnet/utils.py,sha256=h4w9nmx6V_IAtiRp6VQ90TmQZGGTdMzU63WfGmL-pPs,2666
5
5
  pvnet/data/__init__.py,sha256=FFD2tkLwEw9YiAVDam3tmaXNWMKiKVMHcnIz7zXCtrg,191
6
- pvnet/data/base_datamodule.py,sha256=A5x6HsVWyXSZopNqyJM2qKxUzhlGiggs_K187leVaQ4,6678
6
+ pvnet/data/base_datamodule.py,sha256=Ibz0RoSr15HT6tMCs6ftXTpMa-NOKAmEd5ky55MqEK0,8615
7
7
  pvnet/data/site_datamodule.py,sha256=-KGxirGCBXVwcCREsjFkF7JDfa6NICv8bBDV6EILF_Q,962
8
8
  pvnet/data/uk_regional_datamodule.py,sha256=KA2_7DYuSggmD5b-XiXshXq8xmu36BjtFmy_pS7e4QE,1017
9
9
  pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
@@ -22,11 +22,11 @@ pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4T
22
22
  pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
23
23
  pvnet/models/late_fusion/site_encoders/encoders.py,sha256=k4z690cfcP6J4pm2KtDujHN-W3uOl7QY0WvBIu1tM8c,11703
24
24
  pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
25
- pvnet/training/lightning_module.py,sha256=8UXOeL4mnsnHTZf1kHNbRxP8aeasTI-clASegSxuWzI,13029
25
+ pvnet/training/lightning_module.py,sha256=GVqdi5ALFo9-_WRYeyMMj2qH_k4gPxQ2sG6FhL_wRFE,13242
26
26
  pvnet/training/plots.py,sha256=4xID7TBA4IazaARaCN5AoG5fFPJF1wIprn0y6I0C31c,2469
27
27
  pvnet/training/train.py,sha256=zj9JMi9C6W68vGsQUBapWkJ4aDzDuJFMv0IVjO73s1k,5215
28
- pvnet-5.0.7.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
29
- pvnet-5.0.7.dist-info/METADATA,sha256=UyYMlgXH6SCla4n2ez-SVmPW6doYcYKeS2f9M5SqsRY,18043
30
- pvnet-5.0.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
- pvnet-5.0.7.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
32
- pvnet-5.0.7.dist-info/RECORD,,
28
+ pvnet-5.0.9.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
29
+ pvnet-5.0.9.dist-info/METADATA,sha256=Z909YP0GU68upJxamWI1BlHI4vZjbSUAAlEnIt2V_fc,18043
30
+ pvnet-5.0.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
+ pvnet-5.0.9.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
32
+ pvnet-5.0.9.dist-info/RECORD,,
File without changes