PVNet 5.0.15 (tar.gz) → 5.0.17 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {pvnet-5.0.15 → pvnet-5.0.17}/PKG-INFO +1 -1
  2. {pvnet-5.0.15 → pvnet-5.0.17}/PVNet.egg-info/PKG-INFO +1 -1
  3. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/base_model.py +1 -1
  4. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/encoders/basic_blocks.py +1 -1
  5. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/encoders/encoders3d.py +3 -2
  6. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/training/lightning_module.py +11 -5
  7. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/training/train.py +0 -1
  8. pvnet-5.0.17/pvnet/utils.py +157 -0
  9. {pvnet-5.0.15 → pvnet-5.0.17}/tests/test_end2end.py +5 -4
  10. pvnet-5.0.15/pvnet/utils.py +0 -88
  11. {pvnet-5.0.15 → pvnet-5.0.17}/LICENSE +0 -0
  12. {pvnet-5.0.15 → pvnet-5.0.17}/PVNet.egg-info/SOURCES.txt +0 -0
  13. {pvnet-5.0.15 → pvnet-5.0.17}/PVNet.egg-info/dependency_links.txt +0 -0
  14. {pvnet-5.0.15 → pvnet-5.0.17}/PVNet.egg-info/requires.txt +0 -0
  15. {pvnet-5.0.15 → pvnet-5.0.17}/PVNet.egg-info/top_level.txt +0 -0
  16. {pvnet-5.0.15 → pvnet-5.0.17}/README.md +0 -0
  17. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/__init__.py +0 -0
  18. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/data/__init__.py +0 -0
  19. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/data/base_datamodule.py +0 -0
  20. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/data/site_datamodule.py +0 -0
  21. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/data/uk_regional_datamodule.py +0 -0
  22. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/load_model.py +0 -0
  23. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/__init__.py +0 -0
  24. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/ensemble.py +0 -0
  25. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/__init__.py +0 -0
  26. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/basic_blocks.py +0 -0
  27. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
  28. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/late_fusion.py +0 -0
  29. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
  30. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
  31. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
  32. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
  33. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
  34. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/models/late_fusion/site_encoders/encoders.py +0 -0
  35. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/optimizers.py +0 -0
  36. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/training/__init__.py +0 -0
  37. {pvnet-5.0.15 → pvnet-5.0.17}/pvnet/training/plots.py +0 -0
  38. {pvnet-5.0.15 → pvnet-5.0.17}/pyproject.toml +0 -0
  39. {pvnet-5.0.15 → pvnet-5.0.17}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.15
3
+ Version: 5.0.17
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: >=3.11
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.15
3
+ Version: 5.0.17
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: >=3.11
@@ -451,4 +451,4 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
451
451
  """
452
452
  # y_quantiles Shape: batch_size, seq_length, num_quantiles
453
453
  idx = self.output_quantiles.index(0.5)
454
- return y_quantiles[..., idx]
454
+ return y_quantiles[..., idx]
@@ -16,7 +16,6 @@ class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
16
16
  self,
17
17
  sequence_length: int,
18
18
  image_size_pixels: int,
19
- in_channels: int,
20
19
  out_features: int,
21
20
  ):
22
21
  """Abstract class for NWP/satellite encoder.
@@ -31,6 +30,7 @@ class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
31
30
  self.out_features = out_features
32
31
  self.image_size_pixels = image_size_pixels
33
32
  self.sequence_length = sequence_length
33
+ self.in_channels = in_channels
34
34
 
35
35
  @abstractmethod
36
36
  def forward(self):
@@ -41,7 +41,8 @@ class DefaultPVNet(AbstractNWPSatelliteEncoder):
41
41
  padding: The padding used in the conv3d layers. If an int, the same padding
42
42
  is used in all dimensions
43
43
  """
44
- super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
44
+
45
+ super().__init__(sequence_length, image_size_pixels, out_features)
45
46
 
46
47
  if isinstance(padding, int):
47
48
  padding = (padding, padding, padding)
@@ -136,7 +137,7 @@ class ResConv3DNet(AbstractNWPSatelliteEncoder):
136
137
  batch_norm: Whether to include batch normalisation.
137
138
  dropout_frac: Probability of an element to be zeroed in the residual pathways.
138
139
  """
139
- super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
140
+ super().__init__(sequence_length, image_size_pixels, out_features)
140
141
 
141
142
  model = [
142
143
  nn.Conv3d(
@@ -15,6 +15,7 @@ from pvnet.data.base_datamodule import collate_fn
15
15
  from pvnet.models.base_model import BaseModel
16
16
  from pvnet.optimizers import AbstractOptimizer
17
17
  from pvnet.training.plots import plot_sample_forecasts, wandb_line_plot
18
+ from pvnet.utils import validate_batch_against_config
18
19
 
19
20
 
20
21
  class PVNetLightningModule(pl.LightningModule):
@@ -42,9 +43,6 @@ class PVNetLightningModule(pl.LightningModule):
42
43
  # This setting is only used when lr is tuned with callback
43
44
  self.lr = None
44
45
 
45
- # Set up store for all all validation results so we can log these
46
- self.save_all_validation_results = save_all_validation_results
47
-
48
46
  def transfer_batch_to_device(
49
47
  self,
50
48
  batch: TensorBatch,
@@ -189,7 +187,7 @@ class PVNetLightningModule(pl.LightningModule):
189
187
  self._val_horizon_maes: list[np.array] = []
190
188
  if self.current_epoch==0:
191
189
  self._val_persistence_horizon_maes: list[np.array] = []
192
-
190
+
193
191
  # Plot some sample forecasts
194
192
  val_dataset = self.trainer.val_dataloaders.dataset
195
193
 
@@ -205,6 +203,14 @@ class PVNetLightningModule(pl.LightningModule):
205
203
 
206
204
  batch = collate_fn([val_dataset[i] for i in idxs])
207
205
  batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
206
+
207
+ # Batch validation check only during sanity check phase - use first batch
208
+ if self.trainer.sanity_checking and plot_num == 0:
209
+ validate_batch_against_config(
210
+ batch=batch,
211
+ model=self.model
212
+ )
213
+
208
214
  with torch.no_grad():
209
215
  y_hat = self.model(batch)
210
216
 
@@ -306,7 +312,7 @@ class PVNetLightningModule(pl.LightningModule):
306
312
  self.log_dict(extreme_error_metrics, on_step=False, on_epoch=True)
307
313
 
308
314
  # Optionally save all validation results - these are overridden each epoch
309
- if self.save_all_validation_results:
315
+ if self.hparams.save_all_validation_results:
310
316
  # Add attributes
311
317
  ds_val_results.attrs["epoch"] = self.current_epoch
312
318
 
@@ -26,7 +26,6 @@ from pvnet.utils import (
26
26
  log = logging.getLogger(__name__)
27
27
 
28
28
 
29
-
30
29
  def resolve_monitor_loss(output_quantiles: list | None) -> str:
31
30
  """Return the desired metric to monitor based on whether quantile regression is being used.
32
31
 
@@ -0,0 +1,157 @@
1
+ """Utils"""
2
+ import logging
3
+ from typing import TYPE_CHECKING
4
+
5
+ import rich.syntax
6
+ import rich.tree
7
+ from lightning.pytorch.utilities import rank_zero_only
8
+ from omegaconf import DictConfig, OmegaConf
9
+
10
+ if TYPE_CHECKING:
11
+ from pvnet.models.base_model import BaseModel
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ PYTORCH_WEIGHTS_NAME = "model_weights.safetensors"
17
+ MODEL_CONFIG_NAME = "model_config.yaml"
18
+ DATA_CONFIG_NAME = "data_config.yaml"
19
+ DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
20
+ FULL_CONFIG_NAME = "full_experiment_config.yaml"
21
+ MODEL_CARD_NAME = "README.md"
22
+
23
+
24
+
25
+ def run_config_utilities(config: DictConfig) -> None:
26
+ """A couple of optional utilities.
27
+
28
+ Controlled by main config file:
29
+ - forcing debug friendly configuration
30
+
31
+ Modifies DictConfig in place.
32
+
33
+ Args:
34
+ config (DictConfig): Configuration composed by Hydra.
35
+ """
36
+
37
+ # Enable adding new keys to config
38
+ OmegaConf.set_struct(config, False)
39
+
40
+ # Force debugger friendly configuration if <config.trainer.fast_dev_run=True>
41
+ if config.trainer.get("fast_dev_run"):
42
+ logger.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
43
+ # Debuggers don't like GPUs or multiprocessing
44
+ if config.trainer.get("gpus"):
45
+ config.trainer.gpus = 0
46
+ if config.datamodule.get("pin_memory"):
47
+ config.datamodule.pin_memory = False
48
+ if config.datamodule.get("num_workers"):
49
+ config.datamodule.num_workers = 0
50
+ if config.datamodule.get("prefetch_factor"):
51
+ config.datamodule.prefetch_factor = None
52
+
53
+ # Disable adding new keys to config
54
+ OmegaConf.set_struct(config, True)
55
+
56
+
57
+ @rank_zero_only
58
+ def print_config(
59
+ config: DictConfig,
60
+ fields: tuple[str] = (
61
+ "trainer",
62
+ "model",
63
+ "datamodule",
64
+ "callbacks",
65
+ "logger",
66
+ "seed",
67
+ ),
68
+ resolve: bool = True,
69
+ ) -> None:
70
+ """Prints content of DictConfig using Rich library and its tree structure.
71
+
72
+ Args:
73
+ config (DictConfig): Configuration composed by Hydra.
74
+ fields (Sequence[str], optional): Determines which main fields from config will
75
+ be printed and in what order.
76
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig.
77
+ """
78
+
79
+ style = "dim"
80
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
81
+
82
+ for field in fields:
83
+ branch = tree.add(field, style=style, guide_style=style)
84
+
85
+ config_section = config.get(field)
86
+
87
+ branch_content = str(config_section)
88
+ if isinstance(config_section, DictConfig):
89
+ branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
90
+
91
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
92
+
93
+ rich.print(tree)
94
+
95
+
96
+ def validate_batch_against_config(
97
+ batch: dict,
98
+ model: "BaseModel",
99
+ ) -> None:
100
+ """Validates tensor shapes in batch against model configuration."""
101
+ logger.info("Performing batch shape validation against model config.")
102
+
103
+ # NWP validation
104
+ if hasattr(model, 'nwp_encoders_dict'):
105
+ if "nwp" not in batch:
106
+ raise ValueError(
107
+ "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
108
+ )
109
+
110
+ for source, nwp_data in batch["nwp"].items():
111
+ if source in model.nwp_encoders_dict:
112
+
113
+ enc = model.nwp_encoders_dict[source]
114
+ expected_channels = enc.in_channels
115
+ if model.add_image_embedding_channel:
116
+ expected_channels -= 1
117
+
118
+ expected = (nwp_data["nwp"].shape[0], enc.sequence_length,
119
+ expected_channels, enc.image_size_pixels, enc.image_size_pixels)
120
+ if tuple(nwp_data["nwp"].shape) != expected:
121
+ actual_shape = tuple(nwp_data['nwp'].shape)
122
+ raise ValueError(
123
+ f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
124
+ )
125
+
126
+ # Satellite validation
127
+ if hasattr(model, 'sat_encoder'):
128
+ if "satellite_actual" not in batch:
129
+ raise ValueError(
130
+ "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
131
+ )
132
+
133
+ enc = model.sat_encoder
134
+ expected_channels = enc.in_channels
135
+ if model.add_image_embedding_channel:
136
+ expected_channels -= 1
137
+
138
+ expected = (batch["satellite_actual"].shape[0], enc.sequence_length, expected_channels,
139
+ enc.image_size_pixels, enc.image_size_pixels)
140
+ if tuple(batch["satellite_actual"].shape) != expected:
141
+ actual_shape = tuple(batch['satellite_actual'].shape)
142
+ raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
143
+
144
+ # GSP/Site validation
145
+ key = model._target_key
146
+ if key in batch:
147
+ total_minutes = model.history_minutes + model.forecast_minutes
148
+ interval = model.interval_minutes
149
+ expected_len = total_minutes // interval + 1
150
+ expected = (batch[key].shape[0], expected_len)
151
+ if tuple(batch[key].shape) != expected:
152
+ actual_shape = tuple(batch[key].shape)
153
+ raise ValueError(
154
+ f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}"
155
+ )
156
+
157
+ logger.info("Batch shape validation successful!")
@@ -1,7 +1,8 @@
1
1
  import lightning
2
- from pvnet.data import UKRegionalStreamedDataModule
3
- from pvnet.training.lightning_module import PVNetLightningModule
2
+
3
+ from pvnet.data import UKRegionalStreamedDataModule
4
4
  from pvnet.optimizers import EmbAdamWReduceLROnPlateau
5
+ from pvnet.training.lightning_module import PVNetLightningModule
5
6
 
6
7
 
7
8
  def test_model_trainer_fit(session_tmp_path, uk_data_config_path, late_fusion_model):
@@ -15,7 +16,7 @@ def test_model_trainer_fit(session_tmp_path, uk_data_config_path, late_fusion_mo
15
16
  dataset_pickle_dir=f"{session_tmp_path}/dataset_pickles"
16
17
  )
17
18
 
18
- ligtning_model = PVNetLightningModule(
19
+ lightning_model = PVNetLightningModule(
19
20
  model=late_fusion_model,
20
21
  optimizer=EmbAdamWReduceLROnPlateau(),
21
22
  )
@@ -29,4 +30,4 @@ def test_model_trainer_fit(session_tmp_path, uk_data_config_path, late_fusion_mo
29
30
  logger=False,
30
31
  enable_checkpointing=False,
31
32
  )
32
- trainer.fit(model=ligtning_model, datamodule=datamodule)
33
+ trainer.fit(model=lightning_model, datamodule=datamodule)
@@ -1,88 +0,0 @@
1
- """Utils"""
2
- import logging
3
-
4
- import rich.syntax
5
- import rich.tree
6
- from lightning.pytorch.utilities import rank_zero_only
7
- from omegaconf import DictConfig, OmegaConf
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
-
12
- PYTORCH_WEIGHTS_NAME = "model_weights.safetensors"
13
- MODEL_CONFIG_NAME = "model_config.yaml"
14
- DATA_CONFIG_NAME = "data_config.yaml"
15
- DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
16
- FULL_CONFIG_NAME = "full_experiment_config.yaml"
17
- MODEL_CARD_NAME = "README.md"
18
-
19
-
20
-
21
- def run_config_utilities(config: DictConfig) -> None:
22
- """A couple of optional utilities.
23
-
24
- Controlled by main config file:
25
- - forcing debug friendly configuration
26
-
27
- Modifies DictConfig in place.
28
-
29
- Args:
30
- config (DictConfig): Configuration composed by Hydra.
31
- """
32
-
33
- # Enable adding new keys to config
34
- OmegaConf.set_struct(config, False)
35
-
36
- # Force debugger friendly configuration if <config.trainer.fast_dev_run=True>
37
- if config.trainer.get("fast_dev_run"):
38
- logger.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
39
- # Debuggers don't like GPUs or multiprocessing
40
- if config.trainer.get("gpus"):
41
- config.trainer.gpus = 0
42
- if config.datamodule.get("pin_memory"):
43
- config.datamodule.pin_memory = False
44
- if config.datamodule.get("num_workers"):
45
- config.datamodule.num_workers = 0
46
- if config.datamodule.get("prefetch_factor"):
47
- config.datamodule.prefetch_factor = None
48
-
49
- # Disable adding new keys to config
50
- OmegaConf.set_struct(config, True)
51
-
52
-
53
- @rank_zero_only
54
- def print_config(
55
- config: DictConfig,
56
- fields: tuple[str] = (
57
- "trainer",
58
- "model",
59
- "datamodule",
60
- "callbacks",
61
- "logger",
62
- "seed",
63
- ),
64
- resolve: bool = True,
65
- ) -> None:
66
- """Prints content of DictConfig using Rich library and its tree structure.
67
-
68
- Args:
69
- config (DictConfig): Configuration composed by Hydra.
70
- fields (Sequence[str], optional): Determines which main fields from config will
71
- be printed and in what order.
72
- resolve (bool, optional): Whether to resolve reference fields of DictConfig.
73
- """
74
-
75
- style = "dim"
76
- tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
77
-
78
- for field in fields:
79
- branch = tree.add(field, style=style, guide_style=style)
80
-
81
- config_section = config.get(field)
82
- branch_content = str(config_section)
83
- if isinstance(config_section, DictConfig):
84
- branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
85
-
86
- branch.add(rich.syntax.Syntax(branch_content, "yaml"))
87
-
88
- rich.print(tree)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes