PVNet 5.3.1__tar.gz → 5.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {pvnet-5.3.1 → pvnet-5.3.3}/PKG-INFO +1 -1
  2. {pvnet-5.3.1 → pvnet-5.3.3}/PVNet.egg-info/PKG-INFO +1 -1
  3. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/training/lightning_module.py +21 -4
  4. {pvnet-5.3.1 → pvnet-5.3.3}/LICENSE +0 -0
  5. {pvnet-5.3.1 → pvnet-5.3.3}/PVNet.egg-info/SOURCES.txt +0 -0
  6. {pvnet-5.3.1 → pvnet-5.3.3}/PVNet.egg-info/dependency_links.txt +0 -0
  7. {pvnet-5.3.1 → pvnet-5.3.3}/PVNet.egg-info/requires.txt +0 -0
  8. {pvnet-5.3.1 → pvnet-5.3.3}/PVNet.egg-info/top_level.txt +0 -0
  9. {pvnet-5.3.1 → pvnet-5.3.3}/README.md +0 -0
  10. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/__init__.py +0 -0
  11. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/datamodule.py +0 -0
  12. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/load_model.py +0 -0
  13. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/__init__.py +0 -0
  14. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/base_model.py +0 -0
  15. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/ensemble.py +0 -0
  16. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/__init__.py +0 -0
  17. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/basic_blocks.py +0 -0
  18. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
  19. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/encoders/basic_blocks.py +0 -0
  20. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/encoders/encoders3d.py +0 -0
  21. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/late_fusion.py +0 -0
  22. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
  23. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
  24. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
  25. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
  26. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
  27. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/models/late_fusion/site_encoders/encoders.py +0 -0
  28. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/optimizers.py +0 -0
  29. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/training/__init__.py +0 -0
  30. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/training/plots.py +0 -0
  31. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/training/train.py +0 -0
  32. {pvnet-5.3.1 → pvnet-5.3.3}/pvnet/utils.py +0 -0
  33. {pvnet-5.3.1 → pvnet-5.3.3}/pyproject.toml +0 -0
  34. {pvnet-5.3.1 → pvnet-5.3.3}/setup.cfg +0 -0
  35. {pvnet-5.3.1 → pvnet-5.3.3}/tests/test_datamodule.py +0 -0
  36. {pvnet-5.3.1 → pvnet-5.3.3}/tests/test_end2end.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: PVNet
- Version: 5.3.1
+ Version: 5.3.3
  Summary: PVNet
  Author-email: Peter Dudfield <info@openclimatefix.org>
  Requires-Python: <3.14,>=3.11
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: PVNet
- Version: 5.3.1
+ Version: 5.3.3
  Summary: PVNet
  Author-email: Peter Dudfield <info@openclimatefix.org>
  Requires-Python: <3.14,>=3.11
@@ -256,11 +256,28 @@ class PVNetLightningModule(pl.LightningModule):
256
256
 
257
257
  # Calculate the persistance losses - we only need to do this once per training run
258
258
  # not every epoch
259
- if self.current_epoch == 0:
259
+ if self.current_epoch==0:
260
+ # Need to find last valid value before forecast
261
+ target_data = batch["generation"]
262
+ history_data = target_data[:, :-(self.model.forecast_len)]
263
+
264
+ # Find where values aren't dropped
265
+ valid_mask = history_data >= 0
266
+
267
+ # Last valid value index for each sample
268
+ flipped_mask = valid_mask.float().flip(dims=[1])
269
+ last_valid_indices_flipped = torch.argmax(flipped_mask, dim=1)
270
+ last_valid_indices = history_data.shape[1] - 1 - last_valid_indices_flipped
271
+
272
+ # Grab those last valid values
273
+ batch_indices = torch.arange(
274
+ history_data.shape[0],
275
+ device=history_data.device
276
+ )
277
+ last_valid_values = history_data[batch_indices, last_valid_indices]
278
+
260
279
  y_persist = (
261
- batch["generation"][:, -(self.model.forecast_len + 1)]
262
- .unsqueeze(1)
263
- .expand(-1, self.model.forecast_len)
280
+ last_valid_values.unsqueeze(1).expand(-1, self.model.forecast_len)
264
281
  )
265
282
  mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
266
283
  self._val_persistence_horizon_maes.append(mae_step_persist)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes