PVNet 5.3.2__py3-none-any.whl → 5.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -256,11 +256,28 @@ class PVNetLightningModule(pl.LightningModule):
256
256
 
257
257
  # Calculate the persistence losses - we only need to do this once per training run
258
258
  # not every epoch
259
- if self.current_epoch == 0:
259
+ if self.current_epoch==0:
260
+ # Need to find last valid value before forecast
261
+ target_data = batch["generation"]
262
+ history_data = target_data[:, :-(self.model.forecast_len)]
263
+
264
+ # Find where values aren't dropped
265
+ valid_mask = history_data >= 0
266
+
267
+ # Last valid value index for each sample
268
+ flipped_mask = valid_mask.float().flip(dims=[1])
269
+ last_valid_indices_flipped = torch.argmax(flipped_mask, dim=1)
270
+ last_valid_indices = history_data.shape[1] - 1 - last_valid_indices_flipped
271
+
272
+ # Grab those last valid values
273
+ batch_indices = torch.arange(
274
+ history_data.shape[0],
275
+ device=history_data.device
276
+ )
277
+ last_valid_values = history_data[batch_indices, last_valid_indices]
278
+
260
279
  y_persist = (
261
- batch["generation"][:, -(self.model.forecast_len + 1)]
262
- .unsqueeze(1)
263
- .expand(-1, self.model.forecast_len)
280
+ last_valid_values.unsqueeze(1).expand(-1, self.model.forecast_len)
264
281
  )
265
282
  mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
266
283
  self._val_persistence_horizon_maes.append(mae_step_persist)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.3.2
3
+ Version: 5.3.3
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: <3.14,>=3.11
@@ -19,11 +19,11 @@ pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4T
19
19
  pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
20
20
  pvnet/models/late_fusion/site_encoders/encoders.py,sha256=PemEUa_Wv5pFWw3usPKEtXcvs_MX2LSrO6nhldO_QVk,11320
21
21
  pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
22
- pvnet/training/lightning_module.py,sha256=57sT7bPCU7mJw4EskzOE-JJ9JhWIuAbs40_x5RoBbA8,12705
22
+ pvnet/training/lightning_module.py,sha256=WEC5cSIbQxA5rHVZi2bhH8uJM4A6vQoZmJvneZnGse0,13478
23
23
  pvnet/training/plots.py,sha256=7JtjA9zIotuoKZ2l0fbS-FZDB48TcIk_-XLA2EWVMv4,2448
24
24
  pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
25
- pvnet-5.3.2.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
26
- pvnet-5.3.2.dist-info/METADATA,sha256=xG-pdUU0A1WAfxY987N-Ln_RTu3wFJaY4XWjyMENTWU,16479
27
- pvnet-5.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
28
- pvnet-5.3.2.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
29
- pvnet-5.3.2.dist-info/RECORD,,
25
+ pvnet-5.3.3.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
26
+ pvnet-5.3.3.dist-info/METADATA,sha256=OJATSh8sSgvPaYjp16iyjsQgZVsIag2WenG78oX4SdA,16479
27
+ pvnet-5.3.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
28
+ pvnet-5.3.3.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
29
+ pvnet-5.3.3.dist-info/RECORD,,
File without changes