PVNet 5.0.15__py3-none-any.whl → 5.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pvnet/models/base_model.py CHANGED
@@ -451,4 +451,4 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
         """
         # y_quantiles Shape: batch_size, seq_length, num_quantiles
         idx = self.output_quantiles.index(0.5)
-        return y_quantiles[..., idx]
+        return y_quantiles[..., idx]
pvnet/models/late_fusion/encoders/basic_blocks.py CHANGED
@@ -16,7 +16,6 @@ class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
         self,
         sequence_length: int,
         image_size_pixels: int,
-        in_channels: int,
         out_features: int,
     ):
         """Abstract class for NWP/satellite encoder.
@@ -31,6 +30,7 @@ class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
         self.out_features = out_features
         self.image_size_pixels = image_size_pixels
         self.sequence_length = sequence_length
+        self.in_channels = in_channels
 
     @abstractmethod
     def forward(self):
pvnet/models/late_fusion/encoders/encoders3d.py CHANGED
@@ -41,7 +41,8 @@ class DefaultPVNet(AbstractNWPSatelliteEncoder):
             padding: The padding used in the conv3d layers. If an int, the same padding
                 is used in all dimensions
         """
-        super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
+
+        super().__init__(sequence_length, image_size_pixels, out_features)
 
         if isinstance(padding, int):
             padding = (padding, padding, padding)
@@ -136,7 +137,7 @@ class ResConv3DNet(AbstractNWPSatelliteEncoder):
             batch_norm: Whether to include batch normalisation.
             dropout_frac: Probability of an element to be zeroed in the residual pathways.
         """
-        super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
+        super().__init__(sequence_length, image_size_pixels, out_features)
 
         model = [
             nn.Conv3d(
pvnet/training/lightning_module.py CHANGED
@@ -15,6 +15,7 @@ from pvnet.data.base_datamodule import collate_fn
 from pvnet.models.base_model import BaseModel
 from pvnet.optimizers import AbstractOptimizer
 from pvnet.training.plots import plot_sample_forecasts, wandb_line_plot
+from pvnet.utils import validate_batch_against_config
 
 
 class PVNetLightningModule(pl.LightningModule):
@@ -42,9 +43,6 @@ class PVNetLightningModule(pl.LightningModule):
         # This setting is only used when lr is tuned with callback
         self.lr = None
 
-        # Set up store for all all validation results so we can log these
-        self.save_all_validation_results = save_all_validation_results
-
     def transfer_batch_to_device(
         self,
         batch: TensorBatch,
@@ -189,7 +187,7 @@ class PVNetLightningModule(pl.LightningModule):
         self._val_horizon_maes: list[np.array] = []
         if self.current_epoch==0:
             self._val_persistence_horizon_maes: list[np.array] = []
-
+
         # Plot some sample forecasts
         val_dataset = self.trainer.val_dataloaders.dataset
 
@@ -205,6 +203,14 @@ class PVNetLightningModule(pl.LightningModule):
 
             batch = collate_fn([val_dataset[i] for i in idxs])
             batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
+
+            # Batch validation check only during sanity check phase - use first batch
+            if self.trainer.sanity_checking and plot_num == 0:
+                validate_batch_against_config(
+                    batch=batch,
+                    model=self.model
+                )
+
             with torch.no_grad():
                 y_hat = self.model(batch)
 
@@ -306,7 +312,7 @@ class PVNetLightningModule(pl.LightningModule):
         self.log_dict(extreme_error_metrics, on_step=False, on_epoch=True)
 
         # Optionally save all validation results - these are overridden each epoch
-        if self.save_all_validation_results:
+        if self.hparams.save_all_validation_results:
            # Add attributes
            ds_val_results.attrs["epoch"] = self.current_epoch
 
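
Note on the last hunk above: reading the flag as self.hparams.save_all_validation_results instead of a hand-set attribute relies on the standard Lightning pattern in which constructor arguments captured by save_hyperparameters() are exposed on self.hparams. A minimal sketch of that pattern, with an illustrative module that is not part of PVNet:

import lightning.pytorch as pl


class ExampleModule(pl.LightningModule):
    """Illustrative module: __init__ arguments become available via self.hparams."""

    def __init__(self, save_all_validation_results: bool = False):
        super().__init__()
        # Records all __init__ arguments (including save_all_validation_results)
        # on self.hparams, so no explicit attribute assignment is needed.
        self.save_hyperparameters()

    def on_validation_epoch_end(self) -> None:
        # Read the flag back through hparams rather than a manually stored attribute.
        if self.hparams.save_all_validation_results:
            ...  # e.g. persist the per-sample validation results
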
pvnet/training/train.py CHANGED
@@ -26,7 +26,6 @@ from pvnet.utils import (
 log = logging.getLogger(__name__)
 
 
-
 def resolve_monitor_loss(output_quantiles: list | None) -> str:
     """Return the desired metric to monitor based on whether quantile regression is being used.
 
pvnet/utils.py CHANGED
@@ -1,11 +1,15 @@
 """Utils"""
 import logging
+from typing import TYPE_CHECKING
 
 import rich.syntax
 import rich.tree
 from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
 
+if TYPE_CHECKING:
+    from pvnet.models.base_model import BaseModel
+
 logger = logging.getLogger(__name__)
 
 
@@ -79,6 +83,7 @@ def print_config(
         branch = tree.add(field, style=style, guide_style=style)
 
         config_section = config.get(field)
+
         branch_content = str(config_section)
         if isinstance(config_section, DictConfig):
             branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
@@ -86,3 +91,67 @@ def print_config(
         branch.add(rich.syntax.Syntax(branch_content, "yaml"))
 
     rich.print(tree)
+
+
+def validate_batch_against_config(
+    batch: dict,
+    model: "BaseModel",
+) -> None:
+    """Validates tensor shapes in batch against model configuration."""
+    logger.info("Performing batch shape validation against model config.")
+
+    # NWP validation
+    if hasattr(model, 'nwp_encoders_dict'):
+        if "nwp" not in batch:
+            raise ValueError(
+                "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
+            )
+
+        for source, nwp_data in batch["nwp"].items():
+            if source in model.nwp_encoders_dict:
+
+                enc = model.nwp_encoders_dict[source]
+                expected_channels = enc.in_channels
+                if model.add_image_embedding_channel:
+                    expected_channels -= 1
+
+                expected = (nwp_data["nwp"].shape[0], enc.sequence_length,
+                            expected_channels, enc.image_size_pixels, enc.image_size_pixels)
+                if tuple(nwp_data["nwp"].shape) != expected:
+                    actual_shape = tuple(nwp_data['nwp'].shape)
+                    raise ValueError(
+                        f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
+                    )
+
+    # Satellite validation
+    if hasattr(model, 'sat_encoder'):
+        if "satellite_actual" not in batch:
+            raise ValueError(
+                "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
+            )
+
+        enc = model.sat_encoder
+        expected_channels = enc.in_channels
+        if model.add_image_embedding_channel:
+            expected_channels -= 1
+
+        expected = (batch["satellite_actual"].shape[0], enc.sequence_length, expected_channels,
+                    enc.image_size_pixels, enc.image_size_pixels)
+        if tuple(batch["satellite_actual"].shape) != expected:
+            actual_shape = tuple(batch['satellite_actual'].shape)
+            raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
+
+    # GSP/Site validation
+    key = model._target_key
+    if key in batch:
+        total_minutes = model.history_minutes + model.forecast_minutes
+        interval = model.interval_minutes
+        expected_len = total_minutes // interval + 1
+        expected = (batch[key].shape[0], expected_len)
+        if tuple(batch[key].shape) != expected:
+            actual_shape = tuple(batch[key].shape)
+            raise ValueError(
+                f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}"
+            )
+
+    logger.info("Batch shape validation successful!")
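
The new pvnet.utils.validate_batch_against_config helper added above is invoked from the Lightning module during the sanity-check phase (see the lightning_module.py hunks earlier in this diff). A minimal sketch of calling it directly on the first batch of a dataloader, assuming a model and dataloader have already been constructed; the wrapper function below is illustrative and not part of the package:

from pvnet.utils import validate_batch_against_config


def check_first_batch(model, dataloader) -> None:
    """Fail fast if the first batch does not match the shapes the model expects."""
    batch = next(iter(dataloader))
    try:
        validate_batch_against_config(batch=batch, model=model)
    except ValueError as err:
        # Mismatches between the data config and the model config surface here,
        # e.g. a wrong number of NWP channels or a wrong image size.
        raise RuntimeError(f"Batch/model config mismatch: {err}") from err
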
pvnet-5.0.15.dist-info/METADATA → pvnet-5.0.17.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: PVNet
-Version: 5.0.15
+Version: 5.0.17
 Summary: PVNet
 Author-email: Peter Dudfield <info@openclimatefix.org>
 Requires-Python: >=3.11
pvnet-5.0.15.dist-info/RECORD → pvnet-5.0.17.dist-info/RECORD
@@ -1,20 +1,20 @@
 pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
 pvnet/load_model.py,sha256=LzN06O3oXzqhj1Dh_VlschDTxOq_Eea0OWDxrrboSKw,3726
 pvnet/optimizers.py,sha256=1N4b-Xd6QiIrcUU8cbU326bbFC0BvMNIV8VYWtGILJc,6548
-pvnet/utils.py,sha256=h4w9nmx6V_IAtiRp6VQ90TmQZGGTdMzU63WfGmL-pPs,2666
+pvnet/utils.py,sha256=iHG0wlN_cITKXpR16w54fs2R68wkX7vy3a_6f4SuanY,5414
 pvnet/data/__init__.py,sha256=FFD2tkLwEw9YiAVDam3tmaXNWMKiKVMHcnIz7zXCtrg,191
 pvnet/data/base_datamodule.py,sha256=Ibz0RoSr15HT6tMCs6ftXTpMa-NOKAmEd5ky55MqEK0,8615
 pvnet/data/site_datamodule.py,sha256=-KGxirGCBXVwcCREsjFkF7JDfa6NICv8bBDV6EILF_Q,962
 pvnet/data/uk_regional_datamodule.py,sha256=KA2_7DYuSggmD5b-XiXshXq8xmu36BjtFmy_pS7e4QE,1017
 pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
-pvnet/models/base_model.py,sha256=4Sxal_8hq4F6irr2uG7GbtYZAjaOzUxmLBnO1U7fce0,16181
+pvnet/models/base_model.py,sha256=CnQaaf2kAdOcXqo1319nWa120mHfLQiwOQ639m4OzPk,16182
 pvnet/models/ensemble.py,sha256=1mFUEsl33kWcLL5d7zfDm9ypWxgAxBHgBiJLt0vwTeg,2363
 pvnet/models/late_fusion/__init__.py,sha256=Jf0B-E0_5IvSBFoj1wvnPtwYDxs4pRIFm5qHv--Bbps,26
 pvnet/models/late_fusion/basic_blocks.py,sha256=_cYGVyAIyEJS4wd-DEAXQXu0br66guZJn3ugoebWqZ0,1479
 pvnet/models/late_fusion/late_fusion.py,sha256=VpP1aY646iO-JBlYkyoBlj0Z-gzsqaHnMgpNzjTqAIo,15735
 pvnet/models/late_fusion/encoders/__init__.py,sha256=bLBQdnCeLYhwISW0t88ZZBz-ebS94m7ZwBcsofWMHR4,51
-pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=N8CZCQUwydmBip10AHY5hSkwrSP12B_Mrvm-5XVcz1s,3005
-pvnet/models/late_fusion/encoders/encoders3d.py,sha256=i8POpAsdNFwdnx64Hn7Zm9q9bby8qMPn7G7miwhsxGk,6645
+pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=Yg0Pch0nOdvXjxFln_yZyEeyWYUlSUjbnXhZIEdhHdw,3018
+pvnet/models/late_fusion/encoders/encoders3d.py,sha256=pZ_QnCcvJv2Dw64ielFFfK_4rxxYy7bvNrb_9l4RZyQ,6620
 pvnet/models/late_fusion/linear_networks/__init__.py,sha256=16dLdGfH4QWNrI1fUB-cXWx24lArqo2lWIjdUCWbcBY,96
 pvnet/models/late_fusion/linear_networks/basic_blocks.py,sha256=RnwdeuX_-itY4ncM0NphZ5gRSOpogo7927XIlZJ9LM0,2787
 pvnet/models/late_fusion/linear_networks/networks.py,sha256=exEIz_Z85f8nSwcvp4wqiiLECEAg9YbkKhSZJvFy75M,2231
@@ -22,11 +22,11 @@ pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4T
 pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
 pvnet/models/late_fusion/site_encoders/encoders.py,sha256=k4z690cfcP6J4pm2KtDujHN-W3uOl7QY0WvBIu1tM8c,11703
 pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
-pvnet/training/lightning_module.py,sha256=UvtpijxlRSDKrwH979FSQBeCLLMReYe-S-guWe1upl4,12685
+pvnet/training/lightning_module.py,sha256=TkvLtOPswRtTIwcmAvNOSHg2RvIMzHsJVvs_d0xiRmQ,12891
 pvnet/training/plots.py,sha256=4xID7TBA4IazaARaCN5AoG5fFPJF1wIprn0y6I0C31c,2469
-pvnet/training/train.py,sha256=zj9JMi9C6W68vGsQUBapWkJ4aDzDuJFMv0IVjO73s1k,5215
-pvnet-5.0.15.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
-pvnet-5.0.15.dist-info/METADATA,sha256=uiHZcCcXv50Cl82sg_5XU1D0adASBW8C9DRH6pvamaA,18017
-pvnet-5.0.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-pvnet-5.0.15.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
-pvnet-5.0.15.dist-info/RECORD,,
+pvnet/training/train.py,sha256=1tDA34ianCRfilS0yEeIpR9nsQWaJiiZfTD2qRUmgEc,5214
+pvnet-5.0.17.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
+pvnet-5.0.17.dist-info/METADATA,sha256=WxuG2XNuflRMoDdvhU6tpojEv9vf6OfivvEudSTbB2c,18017
+pvnet-5.0.17.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+pvnet-5.0.17.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
+pvnet-5.0.17.dist-info/RECORD,,