PVNet 5.0.15__py3-none-any.whl → 5.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pvnet/models/base_model.py +1 -1
- pvnet/models/late_fusion/encoders/basic_blocks.py +1 -1
- pvnet/models/late_fusion/encoders/encoders3d.py +3 -2
- pvnet/training/lightning_module.py +11 -5
- pvnet/training/train.py +0 -1
- pvnet/utils.py +69 -0
- {pvnet-5.0.15.dist-info → pvnet-5.0.17.dist-info}/METADATA +1 -1
- {pvnet-5.0.15.dist-info → pvnet-5.0.17.dist-info}/RECORD +11 -11
- {pvnet-5.0.15.dist-info → pvnet-5.0.17.dist-info}/WHEEL +0 -0
- {pvnet-5.0.15.dist-info → pvnet-5.0.17.dist-info}/licenses/LICENSE +0 -0
- {pvnet-5.0.15.dist-info → pvnet-5.0.17.dist-info}/top_level.txt +0 -0
pvnet/models/base_model.py
CHANGED
|
@@ -16,7 +16,6 @@ class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
|
|
|
16
16
|
self,
|
|
17
17
|
sequence_length: int,
|
|
18
18
|
image_size_pixels: int,
|
|
19
|
-
in_channels: int,
|
|
20
19
|
out_features: int,
|
|
21
20
|
):
|
|
22
21
|
"""Abstract class for NWP/satellite encoder.
|
|
@@ -31,6 +30,7 @@ class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
|
|
|
31
30
|
self.out_features = out_features
|
|
32
31
|
self.image_size_pixels = image_size_pixels
|
|
33
32
|
self.sequence_length = sequence_length
|
|
33
|
+
self.in_channels = in_channels
|
|
34
34
|
|
|
35
35
|
@abstractmethod
|
|
36
36
|
def forward(self):
|
|
@@ -41,7 +41,8 @@ class DefaultPVNet(AbstractNWPSatelliteEncoder):
|
|
|
41
41
|
padding: The padding used in the conv3d layers. If an int, the same padding
|
|
42
42
|
is used in all dimensions
|
|
43
43
|
"""
|
|
44
|
-
|
|
44
|
+
|
|
45
|
+
super().__init__(sequence_length, image_size_pixels, out_features)
|
|
45
46
|
|
|
46
47
|
if isinstance(padding, int):
|
|
47
48
|
padding = (padding, padding, padding)
|
|
@@ -136,7 +137,7 @@ class ResConv3DNet(AbstractNWPSatelliteEncoder):
|
|
|
136
137
|
batch_norm: Whether to include batch normalisation.
|
|
137
138
|
dropout_frac: Probability of an element to be zeroed in the residual pathways.
|
|
138
139
|
"""
|
|
139
|
-
super().__init__(sequence_length, image_size_pixels,
|
|
140
|
+
super().__init__(sequence_length, image_size_pixels, out_features)
|
|
140
141
|
|
|
141
142
|
model = [
|
|
142
143
|
nn.Conv3d(
|
|
@@ -15,6 +15,7 @@ from pvnet.data.base_datamodule import collate_fn
|
|
|
15
15
|
from pvnet.models.base_model import BaseModel
|
|
16
16
|
from pvnet.optimizers import AbstractOptimizer
|
|
17
17
|
from pvnet.training.plots import plot_sample_forecasts, wandb_line_plot
|
|
18
|
+
from pvnet.utils import validate_batch_against_config
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
class PVNetLightningModule(pl.LightningModule):
|
|
@@ -42,9 +43,6 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
42
43
|
# This setting is only used when lr is tuned with callback
|
|
43
44
|
self.lr = None
|
|
44
45
|
|
|
45
|
-
# Set up store for all all validation results so we can log these
|
|
46
|
-
self.save_all_validation_results = save_all_validation_results
|
|
47
|
-
|
|
48
46
|
def transfer_batch_to_device(
|
|
49
47
|
self,
|
|
50
48
|
batch: TensorBatch,
|
|
@@ -189,7 +187,7 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
189
187
|
self._val_horizon_maes: list[np.array] = []
|
|
190
188
|
if self.current_epoch==0:
|
|
191
189
|
self._val_persistence_horizon_maes: list[np.array] = []
|
|
192
|
-
|
|
190
|
+
|
|
193
191
|
# Plot some sample forecasts
|
|
194
192
|
val_dataset = self.trainer.val_dataloaders.dataset
|
|
195
193
|
|
|
@@ -205,6 +203,14 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
205
203
|
|
|
206
204
|
batch = collate_fn([val_dataset[i] for i in idxs])
|
|
207
205
|
batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
|
|
206
|
+
|
|
207
|
+
# Batch validation check only during sanity check phase - use first batch
|
|
208
|
+
if self.trainer.sanity_checking and plot_num == 0:
|
|
209
|
+
validate_batch_against_config(
|
|
210
|
+
batch=batch,
|
|
211
|
+
model=self.model
|
|
212
|
+
)
|
|
213
|
+
|
|
208
214
|
with torch.no_grad():
|
|
209
215
|
y_hat = self.model(batch)
|
|
210
216
|
|
|
@@ -306,7 +312,7 @@ class PVNetLightningModule(pl.LightningModule):
|
|
|
306
312
|
self.log_dict(extreme_error_metrics, on_step=False, on_epoch=True)
|
|
307
313
|
|
|
308
314
|
# Optionally save all validation results - these are overridden each epoch
|
|
309
|
-
if self.save_all_validation_results:
|
|
315
|
+
if self.hparams.save_all_validation_results:
|
|
310
316
|
# Add attributes
|
|
311
317
|
ds_val_results.attrs["epoch"] = self.current_epoch
|
|
312
318
|
|
pvnet/training/train.py
CHANGED
pvnet/utils.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
"""Utils"""
|
|
2
2
|
import logging
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
3
4
|
|
|
4
5
|
import rich.syntax
|
|
5
6
|
import rich.tree
|
|
6
7
|
from lightning.pytorch.utilities import rank_zero_only
|
|
7
8
|
from omegaconf import DictConfig, OmegaConf
|
|
8
9
|
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from pvnet.models.base_model import BaseModel
|
|
12
|
+
|
|
9
13
|
logger = logging.getLogger(__name__)
|
|
10
14
|
|
|
11
15
|
|
|
@@ -79,6 +83,7 @@ def print_config(
|
|
|
79
83
|
branch = tree.add(field, style=style, guide_style=style)
|
|
80
84
|
|
|
81
85
|
config_section = config.get(field)
|
|
86
|
+
|
|
82
87
|
branch_content = str(config_section)
|
|
83
88
|
if isinstance(config_section, DictConfig):
|
|
84
89
|
branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
|
|
@@ -86,3 +91,67 @@ def print_config(
|
|
|
86
91
|
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
|
|
87
92
|
|
|
88
93
|
rich.print(tree)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def validate_batch_against_config(
|
|
97
|
+
batch: dict,
|
|
98
|
+
model: "BaseModel",
|
|
99
|
+
) -> None:
|
|
100
|
+
"""Validates tensor shapes in batch against model configuration."""
|
|
101
|
+
logger.info("Performing batch shape validation against model config.")
|
|
102
|
+
|
|
103
|
+
# NWP validation
|
|
104
|
+
if hasattr(model, 'nwp_encoders_dict'):
|
|
105
|
+
if "nwp" not in batch:
|
|
106
|
+
raise ValueError(
|
|
107
|
+
"Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
for source, nwp_data in batch["nwp"].items():
|
|
111
|
+
if source in model.nwp_encoders_dict:
|
|
112
|
+
|
|
113
|
+
enc = model.nwp_encoders_dict[source]
|
|
114
|
+
expected_channels = enc.in_channels
|
|
115
|
+
if model.add_image_embedding_channel:
|
|
116
|
+
expected_channels -= 1
|
|
117
|
+
|
|
118
|
+
expected = (nwp_data["nwp"].shape[0], enc.sequence_length,
|
|
119
|
+
expected_channels, enc.image_size_pixels, enc.image_size_pixels)
|
|
120
|
+
if tuple(nwp_data["nwp"].shape) != expected:
|
|
121
|
+
actual_shape = tuple(nwp_data['nwp'].shape)
|
|
122
|
+
raise ValueError(
|
|
123
|
+
f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
# Satellite validation
|
|
127
|
+
if hasattr(model, 'sat_encoder'):
|
|
128
|
+
if "satellite_actual" not in batch:
|
|
129
|
+
raise ValueError(
|
|
130
|
+
"Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
enc = model.sat_encoder
|
|
134
|
+
expected_channels = enc.in_channels
|
|
135
|
+
if model.add_image_embedding_channel:
|
|
136
|
+
expected_channels -= 1
|
|
137
|
+
|
|
138
|
+
expected = (batch["satellite_actual"].shape[0], enc.sequence_length, expected_channels,
|
|
139
|
+
enc.image_size_pixels, enc.image_size_pixels)
|
|
140
|
+
if tuple(batch["satellite_actual"].shape) != expected:
|
|
141
|
+
actual_shape = tuple(batch['satellite_actual'].shape)
|
|
142
|
+
raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
|
|
143
|
+
|
|
144
|
+
# GSP/Site validation
|
|
145
|
+
key = model._target_key
|
|
146
|
+
if key in batch:
|
|
147
|
+
total_minutes = model.history_minutes + model.forecast_minutes
|
|
148
|
+
interval = model.interval_minutes
|
|
149
|
+
expected_len = total_minutes // interval + 1
|
|
150
|
+
expected = (batch[key].shape[0], expected_len)
|
|
151
|
+
if tuple(batch[key].shape) != expected:
|
|
152
|
+
actual_shape = tuple(batch[key].shape)
|
|
153
|
+
raise ValueError(
|
|
154
|
+
f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}"
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
logger.info("Batch shape validation successful!")
|
|
@@ -1,20 +1,20 @@
|
|
|
1
1
|
pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
|
|
2
2
|
pvnet/load_model.py,sha256=LzN06O3oXzqhj1Dh_VlschDTxOq_Eea0OWDxrrboSKw,3726
|
|
3
3
|
pvnet/optimizers.py,sha256=1N4b-Xd6QiIrcUU8cbU326bbFC0BvMNIV8VYWtGILJc,6548
|
|
4
|
-
pvnet/utils.py,sha256=
|
|
4
|
+
pvnet/utils.py,sha256=iHG0wlN_cITKXpR16w54fs2R68wkX7vy3a_6f4SuanY,5414
|
|
5
5
|
pvnet/data/__init__.py,sha256=FFD2tkLwEw9YiAVDam3tmaXNWMKiKVMHcnIz7zXCtrg,191
|
|
6
6
|
pvnet/data/base_datamodule.py,sha256=Ibz0RoSr15HT6tMCs6ftXTpMa-NOKAmEd5ky55MqEK0,8615
|
|
7
7
|
pvnet/data/site_datamodule.py,sha256=-KGxirGCBXVwcCREsjFkF7JDfa6NICv8bBDV6EILF_Q,962
|
|
8
8
|
pvnet/data/uk_regional_datamodule.py,sha256=KA2_7DYuSggmD5b-XiXshXq8xmu36BjtFmy_pS7e4QE,1017
|
|
9
9
|
pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
|
|
10
|
-
pvnet/models/base_model.py,sha256=
|
|
10
|
+
pvnet/models/base_model.py,sha256=CnQaaf2kAdOcXqo1319nWa120mHfLQiwOQ639m4OzPk,16182
|
|
11
11
|
pvnet/models/ensemble.py,sha256=1mFUEsl33kWcLL5d7zfDm9ypWxgAxBHgBiJLt0vwTeg,2363
|
|
12
12
|
pvnet/models/late_fusion/__init__.py,sha256=Jf0B-E0_5IvSBFoj1wvnPtwYDxs4pRIFm5qHv--Bbps,26
|
|
13
13
|
pvnet/models/late_fusion/basic_blocks.py,sha256=_cYGVyAIyEJS4wd-DEAXQXu0br66guZJn3ugoebWqZ0,1479
|
|
14
14
|
pvnet/models/late_fusion/late_fusion.py,sha256=VpP1aY646iO-JBlYkyoBlj0Z-gzsqaHnMgpNzjTqAIo,15735
|
|
15
15
|
pvnet/models/late_fusion/encoders/__init__.py,sha256=bLBQdnCeLYhwISW0t88ZZBz-ebS94m7ZwBcsofWMHR4,51
|
|
16
|
-
pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=
|
|
17
|
-
pvnet/models/late_fusion/encoders/encoders3d.py,sha256=
|
|
16
|
+
pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=Yg0Pch0nOdvXjxFln_yZyEeyWYUlSUjbnXhZIEdhHdw,3018
|
|
17
|
+
pvnet/models/late_fusion/encoders/encoders3d.py,sha256=pZ_QnCcvJv2Dw64ielFFfK_4rxxYy7bvNrb_9l4RZyQ,6620
|
|
18
18
|
pvnet/models/late_fusion/linear_networks/__init__.py,sha256=16dLdGfH4QWNrI1fUB-cXWx24lArqo2lWIjdUCWbcBY,96
|
|
19
19
|
pvnet/models/late_fusion/linear_networks/basic_blocks.py,sha256=RnwdeuX_-itY4ncM0NphZ5gRSOpogo7927XIlZJ9LM0,2787
|
|
20
20
|
pvnet/models/late_fusion/linear_networks/networks.py,sha256=exEIz_Z85f8nSwcvp4wqiiLECEAg9YbkKhSZJvFy75M,2231
|
|
@@ -22,11 +22,11 @@ pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4T
|
|
|
22
22
|
pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
|
|
23
23
|
pvnet/models/late_fusion/site_encoders/encoders.py,sha256=k4z690cfcP6J4pm2KtDujHN-W3uOl7QY0WvBIu1tM8c,11703
|
|
24
24
|
pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
|
|
25
|
-
pvnet/training/lightning_module.py,sha256=
|
|
25
|
+
pvnet/training/lightning_module.py,sha256=TkvLtOPswRtTIwcmAvNOSHg2RvIMzHsJVvs_d0xiRmQ,12891
|
|
26
26
|
pvnet/training/plots.py,sha256=4xID7TBA4IazaARaCN5AoG5fFPJF1wIprn0y6I0C31c,2469
|
|
27
|
-
pvnet/training/train.py,sha256=
|
|
28
|
-
pvnet-5.0.
|
|
29
|
-
pvnet-5.0.
|
|
30
|
-
pvnet-5.0.
|
|
31
|
-
pvnet-5.0.
|
|
32
|
-
pvnet-5.0.
|
|
27
|
+
pvnet/training/train.py,sha256=1tDA34ianCRfilS0yEeIpR9nsQWaJiiZfTD2qRUmgEc,5214
|
|
28
|
+
pvnet-5.0.17.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
|
|
29
|
+
pvnet-5.0.17.dist-info/METADATA,sha256=WxuG2XNuflRMoDdvhU6tpojEv9vf6OfivvEudSTbB2c,18017
|
|
30
|
+
pvnet-5.0.17.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
31
|
+
pvnet-5.0.17.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
|
|
32
|
+
pvnet-5.0.17.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|