PVNet 5.0.14__py3-none-any.whl → 5.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,4 @@
1
1
  """Base model for all PVNet submodels"""
2
- import copy
3
2
  import logging
4
3
  import os
5
4
  import shutil
@@ -12,9 +11,7 @@ import torch
12
11
  import yaml
13
12
  from huggingface_hub import ModelCard, ModelCardData, snapshot_download
14
13
  from huggingface_hub.hf_api import HfApi
15
- from ocf_data_sampler.numpy_sample.common_types import TensorBatch
16
14
  from safetensors.torch import load_file, save_file
17
- from torchvision.transforms.functional import center_crop
18
15
 
19
16
  from pvnet.utils import (
20
17
  DATA_CONFIG_NAME,
@@ -437,69 +434,6 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
437
434
  else:
438
435
  self.num_output_features = self.forecast_len
439
436
 
440
- def _adapt_batch(self, batch: TensorBatch) -> TensorBatch:
441
- """Slice batches into appropriate shapes for model.
442
-
443
- Returns a new batch dictionary with adapted data, leaving the original batch unchanged.
444
- We make some specific assumptions about the original batch and the derived sliced batch:
445
- - We are only limiting the future projections. I.e. we are never shrinking the batch from
446
- the left hand side of the time axis, only slicing it from the right
447
- - We are only shrinking the spatial crop of the satellite and NWP data
448
-
449
- """
450
- # Create a copy of the batch to avoid modifying the original
451
- new_batch = {key: copy.deepcopy(value) for key, value in batch.items()}
452
-
453
- if "gsp" in new_batch.keys():
454
- # Slice off the end of the GSP data
455
- gsp_len = self.forecast_len + self.history_len + 1
456
- new_batch["gsp"] = new_batch["gsp"][:, :gsp_len]
457
- new_batch["gsp_time_utc"] = new_batch["gsp_time_utc"][:, :gsp_len]
458
-
459
- if "site" in new_batch.keys():
460
- # Slice off the end of the site data
461
- site_len = self.forecast_len + self.history_len + 1
462
- new_batch["site"] = new_batch["site"][:, :site_len]
463
-
464
- # Slice all site related datetime coordinates and features
465
- site_time_keys = [
466
- "site_time_utc",
467
- "site_date_sin",
468
- "site_date_cos",
469
- "site_time_sin",
470
- "site_time_cos",
471
- ]
472
-
473
- for key in site_time_keys:
474
- if key in new_batch.keys():
475
- new_batch[key] = new_batch[key][:, :site_len]
476
-
477
- if self.include_sat:
478
- # Slice off the end of the satellite data and spatially crop
479
- # Shape: batch_size, seq_length, channel, height, width
480
- new_batch["satellite_actual"] = center_crop(
481
- new_batch["satellite_actual"][:, : self.sat_sequence_len],
482
- output_size=self.sat_encoder.image_size_pixels,
483
- )
484
-
485
- if self.include_nwp:
486
- # Slice off the end of the NWP data and spatially crop
487
- for nwp_source in self.nwp_encoders_dict:
488
- # shape: batch_size, seq_len, n_chans, height, width
489
- new_batch["nwp"][nwp_source]["nwp"] = center_crop(
490
- new_batch["nwp"][nwp_source]["nwp"],
491
- output_size=self.nwp_encoders_dict[nwp_source].image_size_pixels,
492
- )[:, : self.nwp_encoders_dict[nwp_source].sequence_length]
493
-
494
- if self.include_sun:
495
- sun_len = self.forecast_len + self.history_len + 1
496
- # Slice off end of solar coords
497
- for s in ["solar_azimuth", "solar_elevation"]:
498
- if s in new_batch.keys():
499
- new_batch[s] = new_batch[s][:, :sun_len]
500
-
501
- return new_batch
502
-
503
437
  def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:
504
438
  """
505
439
  Convert network prediction into a point prediction.
@@ -517,4 +451,4 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
517
451
  """
518
452
  # y_quantiles Shape: batch_size, seq_length, num_quantiles
519
453
  idx = self.output_quantiles.index(0.5)
520
- return y_quantiles[..., idx]
454
+ return y_quantiles[..., idx]
@@ -31,6 +31,7 @@ class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
31
31
  self.out_features = out_features
32
32
  self.image_size_pixels = image_size_pixels
33
33
  self.sequence_length = sequence_length
34
+ self.in_channels = in_channels
34
35
 
35
36
  @abstractmethod
36
37
  def forward(self):
@@ -61,7 +61,6 @@ class LateFusionModel(BaseModel):
61
61
  nwp_interval_minutes: DictConfig | None = None,
62
62
  pv_interval_minutes: int = 5,
63
63
  sat_interval_minutes: int = 5,
64
- adapt_batches: bool = False,
65
64
  ):
66
65
  """Neural network which combines information from different sources.
67
66
 
@@ -110,9 +109,6 @@ class LateFusionModel(BaseModel):
110
109
  data for each source
111
110
  pv_interval_minutes: The interval between each sample of the PV data
112
111
  sat_interval_minutes: The interval between each sample of the satellite data
113
- adapt_batches: If set to true, we attempt to slice the batches to the expected shape for
114
- the model to use. This allows us to overprepare batches and slice from them for the
115
- data we need for a model run.
116
112
  """
117
113
  super().__init__(
118
114
  history_minutes=history_minutes,
@@ -134,7 +130,6 @@ class LateFusionModel(BaseModel):
134
130
  self.add_image_embedding_channel = add_image_embedding_channel
135
131
  self.interval_minutes = interval_minutes
136
132
  self.min_sat_delay_minutes = min_sat_delay_minutes
137
- self.adapt_batches = adapt_batches
138
133
 
139
134
  if self.location_id_mapping is None:
140
135
  logger.warning(
@@ -272,9 +267,6 @@ class LateFusionModel(BaseModel):
272
267
  def forward(self, x: TensorBatch) -> torch.Tensor:
273
268
  """Run model forward"""
274
269
 
275
- if self.adapt_batches:
276
- x = self._adapt_batch(x)
277
-
278
270
  if self.use_id_embedding:
279
271
  # eg: x['gsp_id'] = [1] with location_id_mapping = {1:0}, would give [0]
280
272
  id = torch.tensor(
@@ -15,6 +15,7 @@ from pvnet.data.base_datamodule import collate_fn
15
15
  from pvnet.models.base_model import BaseModel
16
16
  from pvnet.optimizers import AbstractOptimizer
17
17
  from pvnet.training.plots import plot_sample_forecasts, wandb_line_plot
18
+ from pvnet.utils import validate_batch_against_config
18
19
 
19
20
 
20
21
  class PVNetLightningModule(pl.LightningModule):
@@ -42,9 +43,6 @@ class PVNetLightningModule(pl.LightningModule):
42
43
  # This setting is only used when lr is tuned with callback
43
44
  self.lr = None
44
45
 
45
- # Set up store for all all validation results so we can log these
46
- self.save_all_validation_results = save_all_validation_results
47
-
48
46
  def transfer_batch_to_device(
49
47
  self,
50
48
  batch: TensorBatch,
@@ -105,10 +103,6 @@ class PVNetLightningModule(pl.LightningModule):
105
103
  """Run training step"""
106
104
  y_hat = self.model(batch)
107
105
 
108
- # Batch may be adapted in the model forward method, would need adapting here too
109
- if self.model.adapt_batches:
110
- batch = self.model._adapt_batch(batch)
111
-
112
106
  y = batch[self.model._target_key][:, -self.model.forecast_len :]
113
107
 
114
108
  losses = self._calculate_common_losses(y, y_hat)
@@ -193,7 +187,7 @@ class PVNetLightningModule(pl.LightningModule):
193
187
  self._val_horizon_maes: list[np.array] = []
194
188
  if self.current_epoch==0:
195
189
  self._val_persistence_horizon_maes: list[np.array] = []
196
-
190
+
197
191
  # Plot some sample forecasts
198
192
  val_dataset = self.trainer.val_dataloaders.dataset
199
193
 
@@ -209,13 +203,17 @@ class PVNetLightningModule(pl.LightningModule):
209
203
 
210
204
  batch = collate_fn([val_dataset[i] for i in idxs])
211
205
  batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
206
+
207
+ # Batch validation check only during sanity check phase - use first batch
208
+ if self.trainer.sanity_checking and plot_num == 0:
209
+ validate_batch_against_config(
210
+ batch=batch,
211
+ model=self.model
212
+ )
213
+
212
214
  with torch.no_grad():
213
215
  y_hat = self.model(batch)
214
216
 
215
- # Batch may be adapted in the model forward method, would need adapting here too
216
- if self.model.adapt_batches:
217
- batch = self.model._adapt_batch(batch)
218
-
219
217
  fig = plot_sample_forecasts(
220
218
  batch,
221
219
  y_hat,
@@ -235,9 +233,6 @@ class PVNetLightningModule(pl.LightningModule):
235
233
  """Run validation step"""
236
234
 
237
235
  y_hat = self.model(batch)
238
- # Batch may be adapted in the model forward method, would need adapting here too
239
- if self.model.adapt_batches:
240
- batch = self.model._adapt_batch(batch)
241
236
 
242
237
  # Internally store the val predictions
243
238
  self._store_val_predictions(batch, y_hat)
@@ -317,7 +312,7 @@ class PVNetLightningModule(pl.LightningModule):
317
312
  self.log_dict(extreme_error_metrics, on_step=False, on_epoch=True)
318
313
 
319
314
  # Optionally save all validation results - these are overridden each epoch
320
- if self.save_all_validation_results:
315
+ if self.hparams.save_all_validation_results:
321
316
  # Add attributes
322
317
  ds_val_results.attrs["epoch"] = self.current_epoch
323
318
 
pvnet/training/train.py CHANGED
@@ -26,7 +26,6 @@ from pvnet.utils import (
26
26
  log = logging.getLogger(__name__)
27
27
 
28
28
 
29
-
30
29
  def resolve_monitor_loss(output_quantiles: list | None) -> str:
31
30
  """Return the desired metric to monitor based on whether quantile regression is being used.
32
31
 
pvnet/utils.py CHANGED
@@ -1,11 +1,15 @@
1
1
  """Utils"""
2
2
  import logging
3
+ from typing import TYPE_CHECKING
3
4
 
4
5
  import rich.syntax
5
6
  import rich.tree
6
7
  from lightning.pytorch.utilities import rank_zero_only
7
8
  from omegaconf import DictConfig, OmegaConf
8
9
 
10
+ if TYPE_CHECKING:
11
+ from pvnet.models.base_model import BaseModel
12
+
9
13
  logger = logging.getLogger(__name__)
10
14
 
11
15
 
@@ -79,6 +83,7 @@ def print_config(
79
83
  branch = tree.add(field, style=style, guide_style=style)
80
84
 
81
85
  config_section = config.get(field)
86
+
82
87
  branch_content = str(config_section)
83
88
  if isinstance(config_section, DictConfig):
84
89
  branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
@@ -86,3 +91,67 @@ def print_config(
86
91
  branch.add(rich.syntax.Syntax(branch_content, "yaml"))
87
92
 
88
93
  rich.print(tree)
94
+
95
+
96
+ def validate_batch_against_config(
97
+ batch: dict,
98
+ model: "BaseModel",
99
+ ) -> None:
100
+ """Validates tensor shapes in batch against model configuration."""
101
+ logger.info("Performing batch shape validation against model config.")
102
+
103
+ # NWP validation
104
+ if hasattr(model, 'nwp_encoders_dict'):
105
+ if "nwp" not in batch:
106
+ raise ValueError(
107
+ "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
108
+ )
109
+
110
+ for source, nwp_data in batch["nwp"].items():
111
+ if source in model.nwp_encoders_dict:
112
+
113
+ enc = model.nwp_encoders_dict[source]
114
+ expected_channels = enc.in_channels
115
+ if model.add_image_embedding_channel:
116
+ expected_channels -= 1
117
+
118
+ expected = (nwp_data["nwp"].shape[0], enc.sequence_length,
119
+ expected_channels, enc.image_size_pixels, enc.image_size_pixels)
120
+ if tuple(nwp_data["nwp"].shape) != expected:
121
+ actual_shape = tuple(nwp_data['nwp'].shape)
122
+ raise ValueError(
123
+ f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
124
+ )
125
+
126
+ # Satellite validation
127
+ if hasattr(model, 'sat_encoder'):
128
+ if "satellite_actual" not in batch:
129
+ raise ValueError(
130
+ "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
131
+ )
132
+
133
+ enc = model.sat_encoder
134
+ expected_channels = enc.in_channels
135
+ if model.add_image_embedding_channel:
136
+ expected_channels -= 1
137
+
138
+ expected = (batch["satellite_actual"].shape[0], enc.sequence_length, expected_channels,
139
+ enc.image_size_pixels, enc.image_size_pixels)
140
+ if tuple(batch["satellite_actual"].shape) != expected:
141
+ actual_shape = tuple(batch['satellite_actual'].shape)
142
+ raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
143
+
144
+ # GSP/Site validation
145
+ key = model._target_key
146
+ if key in batch:
147
+ total_minutes = model.history_minutes + model.forecast_minutes
148
+ interval = model.interval_minutes
149
+ expected_len = total_minutes // interval + 1
150
+ expected = (batch[key].shape[0], expected_len)
151
+ if tuple(batch[key].shape) != expected:
152
+ actual_shape = tuple(batch[key].shape)
153
+ raise ValueError(
154
+ f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}"
155
+ )
156
+
157
+ logger.info("Batch shape validation successful!")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.0.14
3
+ Version: 5.0.16
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: >=3.11
@@ -14,7 +14,6 @@ Requires-Dist: xarray
14
14
  Requires-Dist: h5netcdf
15
15
  Requires-Dist: torch>=2.0.0
16
16
  Requires-Dist: lightning
17
- Requires-Dist: torchvision
18
17
  Requires-Dist: typer
19
18
  Requires-Dist: sqlalchemy
20
19
  Requires-Dist: fsspec[s3]
@@ -1,19 +1,19 @@
1
1
  pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
2
2
  pvnet/load_model.py,sha256=LzN06O3oXzqhj1Dh_VlschDTxOq_Eea0OWDxrrboSKw,3726
3
3
  pvnet/optimizers.py,sha256=1N4b-Xd6QiIrcUU8cbU326bbFC0BvMNIV8VYWtGILJc,6548
4
- pvnet/utils.py,sha256=h4w9nmx6V_IAtiRp6VQ90TmQZGGTdMzU63WfGmL-pPs,2666
4
+ pvnet/utils.py,sha256=iHG0wlN_cITKXpR16w54fs2R68wkX7vy3a_6f4SuanY,5414
5
5
  pvnet/data/__init__.py,sha256=FFD2tkLwEw9YiAVDam3tmaXNWMKiKVMHcnIz7zXCtrg,191
6
6
  pvnet/data/base_datamodule.py,sha256=Ibz0RoSr15HT6tMCs6ftXTpMa-NOKAmEd5ky55MqEK0,8615
7
7
  pvnet/data/site_datamodule.py,sha256=-KGxirGCBXVwcCREsjFkF7JDfa6NICv8bBDV6EILF_Q,962
8
8
  pvnet/data/uk_regional_datamodule.py,sha256=KA2_7DYuSggmD5b-XiXshXq8xmu36BjtFmy_pS7e4QE,1017
9
9
  pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
10
- pvnet/models/base_model.py,sha256=sdVWUuJivQYmfK0CF2wzv8Hm-C50lJ--zttMScRYirY,19203
10
+ pvnet/models/base_model.py,sha256=CnQaaf2kAdOcXqo1319nWa120mHfLQiwOQ639m4OzPk,16182
11
11
  pvnet/models/ensemble.py,sha256=1mFUEsl33kWcLL5d7zfDm9ypWxgAxBHgBiJLt0vwTeg,2363
12
12
  pvnet/models/late_fusion/__init__.py,sha256=Jf0B-E0_5IvSBFoj1wvnPtwYDxs4pRIFm5qHv--Bbps,26
13
13
  pvnet/models/late_fusion/basic_blocks.py,sha256=_cYGVyAIyEJS4wd-DEAXQXu0br66guZJn3ugoebWqZ0,1479
14
- pvnet/models/late_fusion/late_fusion.py,sha256=Lz4qrhIwVz7hrIR2sOzOoWAL2rKWWmUREIAtQs0zI8c,16131
14
+ pvnet/models/late_fusion/late_fusion.py,sha256=VpP1aY646iO-JBlYkyoBlj0Z-gzsqaHnMgpNzjTqAIo,15735
15
15
  pvnet/models/late_fusion/encoders/__init__.py,sha256=bLBQdnCeLYhwISW0t88ZZBz-ebS94m7ZwBcsofWMHR4,51
16
- pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=N8CZCQUwydmBip10AHY5hSkwrSP12B_Mrvm-5XVcz1s,3005
16
+ pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=DGkFFIZv4S4FLTaAIOrAngAFBpgZQHfkGM4dzezZLk4,3044
17
17
  pvnet/models/late_fusion/encoders/encoders3d.py,sha256=i8POpAsdNFwdnx64Hn7Zm9q9bby8qMPn7G7miwhsxGk,6645
18
18
  pvnet/models/late_fusion/linear_networks/__init__.py,sha256=16dLdGfH4QWNrI1fUB-cXWx24lArqo2lWIjdUCWbcBY,96
19
19
  pvnet/models/late_fusion/linear_networks/basic_blocks.py,sha256=RnwdeuX_-itY4ncM0NphZ5gRSOpogo7927XIlZJ9LM0,2787
@@ -22,11 +22,11 @@ pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4T
22
22
  pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
23
23
  pvnet/models/late_fusion/site_encoders/encoders.py,sha256=k4z690cfcP6J4pm2KtDujHN-W3uOl7QY0WvBIu1tM8c,11703
24
24
  pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
25
- pvnet/training/lightning_module.py,sha256=GVqdi5ALFo9-_WRYeyMMj2qH_k4gPxQ2sG6FhL_wRFE,13242
25
+ pvnet/training/lightning_module.py,sha256=TkvLtOPswRtTIwcmAvNOSHg2RvIMzHsJVvs_d0xiRmQ,12891
26
26
  pvnet/training/plots.py,sha256=4xID7TBA4IazaARaCN5AoG5fFPJF1wIprn0y6I0C31c,2469
27
- pvnet/training/train.py,sha256=zj9JMi9C6W68vGsQUBapWkJ4aDzDuJFMv0IVjO73s1k,5215
28
- pvnet-5.0.14.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
29
- pvnet-5.0.14.dist-info/METADATA,sha256=FJMwH-1nLPAkE6zsgtdQeofXdWoJn3qzEpUU9ltWn_s,18044
30
- pvnet-5.0.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
- pvnet-5.0.14.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
32
- pvnet-5.0.14.dist-info/RECORD,,
27
+ pvnet/training/train.py,sha256=1tDA34ianCRfilS0yEeIpR9nsQWaJiiZfTD2qRUmgEc,5214
28
+ pvnet-5.0.16.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
29
+ pvnet-5.0.16.dist-info/METADATA,sha256=d0m-tj088qPWfu3s9CMgUCWNbDaf8LCYXpGP1PN0Idg,18017
30
+ pvnet-5.0.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
+ pvnet-5.0.16.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
32
+ pvnet-5.0.16.dist-info/RECORD,,
File without changes