PVNet 5.0.15__py3-none-any.whl → 5.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -451,4 +451,4 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
451
451
  """
452
452
  # y_quantiles Shape: batch_size, seq_length, num_quantiles
453
453
  idx = self.output_quantiles.index(0.5)
454
- return y_quantiles[..., idx]
454
+ return y_quantiles[..., idx]
@@ -31,6 +31,7 @@ class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
31
31
  self.out_features = out_features
32
32
  self.image_size_pixels = image_size_pixels
33
33
  self.sequence_length = sequence_length
34
+ self.in_channels = in_channels
34
35
 
35
36
  @abstractmethod
36
37
  def forward(self):
@@ -15,6 +15,7 @@ from pvnet.data.base_datamodule import collate_fn
15
15
  from pvnet.models.base_model import BaseModel
16
16
  from pvnet.optimizers import AbstractOptimizer
17
17
  from pvnet.training.plots import plot_sample_forecasts, wandb_line_plot
18
+ from pvnet.utils import validate_batch_against_config
18
19
 
19
20
 
20
21
  class PVNetLightningModule(pl.LightningModule):
@@ -42,9 +43,6 @@ class PVNetLightningModule(pl.LightningModule):
42
43
  # This setting is only used when lr is tuned with callback
43
44
  self.lr = None
44
45
 
45
- # Set up store for all all validation results so we can log these
46
- self.save_all_validation_results = save_all_validation_results
47
-
48
46
  def transfer_batch_to_device(
49
47
  self,
50
48
  batch: TensorBatch,
@@ -189,7 +187,7 @@ class PVNetLightningModule(pl.LightningModule):
189
187
  self._val_horizon_maes: list[np.array] = []
190
188
  if self.current_epoch==0:
191
189
  self._val_persistence_horizon_maes: list[np.array] = []
192
-
190
+
193
191
  # Plot some sample forecasts
194
192
  val_dataset = self.trainer.val_dataloaders.dataset
195
193
 
@@ -205,6 +203,14 @@ class PVNetLightningModule(pl.LightningModule):
205
203
 
206
204
  batch = collate_fn([val_dataset[i] for i in idxs])
207
205
  batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
206
+
207
+ # Batch validation check only during sanity check phase - use first batch
208
+ if self.trainer.sanity_checking and plot_num == 0:
209
+ validate_batch_against_config(
210
+ batch=batch,
211
+ model=self.model
212
+ )
213
+
208
214
  with torch.no_grad():
209
215
  y_hat = self.model(batch)
210
216
 
@@ -306,7 +312,7 @@ class PVNetLightningModule(pl.LightningModule):
306
312
  self.log_dict(extreme_error_metrics, on_step=False, on_epoch=True)
307
313
 
308
314
  # Optionally save all validation results - these are overridden each epoch
309
- if self.save_all_validation_results:
315
+ if self.hparams.save_all_validation_results:
310
316
  # Add attributes
311
317
  ds_val_results.attrs["epoch"] = self.current_epoch
312
318
 
pvnet/training/train.py CHANGED
@@ -26,7 +26,6 @@ from pvnet.utils import (
26
26
  log = logging.getLogger(__name__)
27
27
 
28
28
 
29
-
30
29
  def resolve_monitor_loss(output_quantiles: list | None) -> str:
31
30
  """Return the desired metric to monitor based on whether quantile regression is being used.
32
31
 
pvnet/utils.py CHANGED
@@ -1,11 +1,15 @@
1
1
  """Utils"""
2
2
  import logging
3
+ from typing import TYPE_CHECKING
3
4
 
4
5
  import rich.syntax
5
6
  import rich.tree
6
7
  from lightning.pytorch.utilities import rank_zero_only
7
8
  from omegaconf import DictConfig, OmegaConf
8
9
 
10
+ if TYPE_CHECKING:
11
+ from pvnet.models.base_model import BaseModel
12
+
9
13
  logger = logging.getLogger(__name__)
10
14
 
11
15
 
@@ -79,6 +83,7 @@ def print_config(
79
83
  branch = tree.add(field, style=style, guide_style=style)
80
84
 
81
85
  config_section = config.get(field)
86
+
82
87
  branch_content = str(config_section)
83
88
  if isinstance(config_section, DictConfig):
84
89
  branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
@@ -86,3 +91,67 @@ def print_config(
86
91
  branch.add(rich.syntax.Syntax(branch_content, "yaml"))
87
92
 
88
93
  rich.print(tree)
94
+
95
+
96
def _expected_image_shape(
    encoder,
    batch_size: int,
    add_image_embedding_channel: bool,
) -> tuple:
    """Return the shape an image encoder expects for its raw batch input.

    The shape is ``(batch_size, encoder.sequence_length, channels,
    encoder.image_size_pixels, encoder.image_size_pixels)``.

    Args:
        encoder: Encoder exposing ``in_channels``, ``sequence_length`` and
            ``image_size_pixels``.
        batch_size: Batch dimension, taken from the data being validated.
        add_image_embedding_channel: Whether the model adds an embedding channel.

    Returns:
        The expected 5-tuple shape.
    """
    channels = encoder.in_channels
    # NOTE(review): when the model adds an image-embedding channel, the encoder's
    # in_channels is assumed to count that extra channel. The raw batch data is
    # therefore expected to carry one fewer channel — confirm against the model.
    if add_image_embedding_channel:
        channels -= 1
    return (
        batch_size,
        encoder.sequence_length,
        channels,
        encoder.image_size_pixels,
        encoder.image_size_pixels,
    )


def validate_batch_against_config(
    batch: dict,
    model: "BaseModel",
) -> None:
    """Validate tensor shapes in a batch against the model configuration.

    Checks, for each input the model is configured with:
      - NWP: every source with an encoder in ``model.nwp_encoders_dict`` must be
        present in ``batch["nwp"]``. Each configured source's ``"nwp"`` tensor
        must match the encoder's expected shape. Batch sources without an
        encoder are not checked.
      - Satellite: if the model has a ``sat_encoder``,
        ``batch["satellite_actual"]`` must be present and match its shape.
      - Target: if ``batch`` contains ``model._target_key``, that tensor must
        have shape ``(batch_size, (history_minutes + forecast_minutes)
        // interval_minutes + 1)``.

    The batch dimension is read from the data itself, so only the remaining
    dimensions are effectively checked.

    Args:
        batch: Batch dictionary (e.g. the output of ``collate_fn``).
        model: Model whose configuration the batch is checked against.

    Raises:
        ValueError: If a configured input is missing or a shape does not match.
    """
    logger.info("Performing batch shape validation against model config.")

    # NWP validation
    if hasattr(model, "nwp_encoders_dict"):
        if "nwp" not in batch:
            raise ValueError(
                "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
            )

        # Every configured source must be supplied. Checking only the sources
        # present in the batch would let a missing source pass silently.
        missing_sources = [
            source for source in model.nwp_encoders_dict if source not in batch["nwp"]
        ]
        if missing_sources:
            raise ValueError(
                f"Model configured with NWP sources {missing_sources} "
                "but they are missing from batch['nwp']."
            )

        for source, nwp_data in batch["nwp"].items():
            if source not in model.nwp_encoders_dict:
                continue  # no encoder configured for this source; nothing to check
            actual_shape = tuple(nwp_data["nwp"].shape)
            expected = _expected_image_shape(
                model.nwp_encoders_dict[source],
                actual_shape[0],
                model.add_image_embedding_channel,
            )
            if actual_shape != expected:
                raise ValueError(
                    f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
                )

    # Satellite validation
    if hasattr(model, "sat_encoder"):
        if "satellite_actual" not in batch:
            raise ValueError(
                "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
            )

        actual_shape = tuple(batch["satellite_actual"].shape)
        expected = _expected_image_shape(
            model.sat_encoder,
            actual_shape[0],
            model.add_image_embedding_channel,
        )
        if actual_shape != expected:
            raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")

    # GSP/Site validation
    key = model._target_key
    if key in batch:
        actual_shape = tuple(batch[key].shape)
        total_minutes = model.history_minutes + model.forecast_minutes
        expected_len = total_minutes // model.interval_minutes + 1
        expected = (actual_shape[0], expected_len)
        if actual_shape != expected:
            raise ValueError(
                f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}"
            )

    logger.info("Batch shape validation successful!")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.15
3
+ Version: 5.0.16
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: >=3.11
@@ -1,19 +1,19 @@
1
1
  pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
2
2
  pvnet/load_model.py,sha256=LzN06O3oXzqhj1Dh_VlschDTxOq_Eea0OWDxrrboSKw,3726
3
3
  pvnet/optimizers.py,sha256=1N4b-Xd6QiIrcUU8cbU326bbFC0BvMNIV8VYWtGILJc,6548
4
- pvnet/utils.py,sha256=h4w9nmx6V_IAtiRp6VQ90TmQZGGTdMzU63WfGmL-pPs,2666
4
+ pvnet/utils.py,sha256=iHG0wlN_cITKXpR16w54fs2R68wkX7vy3a_6f4SuanY,5414
5
5
  pvnet/data/__init__.py,sha256=FFD2tkLwEw9YiAVDam3tmaXNWMKiKVMHcnIz7zXCtrg,191
6
6
  pvnet/data/base_datamodule.py,sha256=Ibz0RoSr15HT6tMCs6ftXTpMa-NOKAmEd5ky55MqEK0,8615
7
7
  pvnet/data/site_datamodule.py,sha256=-KGxirGCBXVwcCREsjFkF7JDfa6NICv8bBDV6EILF_Q,962
8
8
  pvnet/data/uk_regional_datamodule.py,sha256=KA2_7DYuSggmD5b-XiXshXq8xmu36BjtFmy_pS7e4QE,1017
9
9
  pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
10
- pvnet/models/base_model.py,sha256=4Sxal_8hq4F6irr2uG7GbtYZAjaOzUxmLBnO1U7fce0,16181
10
+ pvnet/models/base_model.py,sha256=CnQaaf2kAdOcXqo1319nWa120mHfLQiwOQ639m4OzPk,16182
11
11
  pvnet/models/ensemble.py,sha256=1mFUEsl33kWcLL5d7zfDm9ypWxgAxBHgBiJLt0vwTeg,2363
12
12
  pvnet/models/late_fusion/__init__.py,sha256=Jf0B-E0_5IvSBFoj1wvnPtwYDxs4pRIFm5qHv--Bbps,26
13
13
  pvnet/models/late_fusion/basic_blocks.py,sha256=_cYGVyAIyEJS4wd-DEAXQXu0br66guZJn3ugoebWqZ0,1479
14
14
  pvnet/models/late_fusion/late_fusion.py,sha256=VpP1aY646iO-JBlYkyoBlj0Z-gzsqaHnMgpNzjTqAIo,15735
15
15
  pvnet/models/late_fusion/encoders/__init__.py,sha256=bLBQdnCeLYhwISW0t88ZZBz-ebS94m7ZwBcsofWMHR4,51
16
- pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=N8CZCQUwydmBip10AHY5hSkwrSP12B_Mrvm-5XVcz1s,3005
16
+ pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=DGkFFIZv4S4FLTaAIOrAngAFBpgZQHfkGM4dzezZLk4,3044
17
17
  pvnet/models/late_fusion/encoders/encoders3d.py,sha256=i8POpAsdNFwdnx64Hn7Zm9q9bby8qMPn7G7miwhsxGk,6645
18
18
  pvnet/models/late_fusion/linear_networks/__init__.py,sha256=16dLdGfH4QWNrI1fUB-cXWx24lArqo2lWIjdUCWbcBY,96
19
19
  pvnet/models/late_fusion/linear_networks/basic_blocks.py,sha256=RnwdeuX_-itY4ncM0NphZ5gRSOpogo7927XIlZJ9LM0,2787
@@ -22,11 +22,11 @@ pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4T
22
22
  pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
23
23
  pvnet/models/late_fusion/site_encoders/encoders.py,sha256=k4z690cfcP6J4pm2KtDujHN-W3uOl7QY0WvBIu1tM8c,11703
24
24
  pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
25
- pvnet/training/lightning_module.py,sha256=UvtpijxlRSDKrwH979FSQBeCLLMReYe-S-guWe1upl4,12685
25
+ pvnet/training/lightning_module.py,sha256=TkvLtOPswRtTIwcmAvNOSHg2RvIMzHsJVvs_d0xiRmQ,12891
26
26
  pvnet/training/plots.py,sha256=4xID7TBA4IazaARaCN5AoG5fFPJF1wIprn0y6I0C31c,2469
27
- pvnet/training/train.py,sha256=zj9JMi9C6W68vGsQUBapWkJ4aDzDuJFMv0IVjO73s1k,5215
28
- pvnet-5.0.15.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
29
- pvnet-5.0.15.dist-info/METADATA,sha256=uiHZcCcXv50Cl82sg_5XU1D0adASBW8C9DRH6pvamaA,18017
30
- pvnet-5.0.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
- pvnet-5.0.15.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
32
- pvnet-5.0.15.dist-info/RECORD,,
27
+ pvnet/training/train.py,sha256=1tDA34ianCRfilS0yEeIpR9nsQWaJiiZfTD2qRUmgEc,5214
28
+ pvnet-5.0.16.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
29
+ pvnet-5.0.16.dist-info/METADATA,sha256=d0m-tj088qPWfu3s9CMgUCWNbDaf8LCYXpGP1PN0Idg,18017
30
+ pvnet-5.0.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
+ pvnet-5.0.16.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
32
+ pvnet-5.0.16.dist-info/RECORD,,
File without changes