PVNet 5.0.15.tar.gz → 5.0.16.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pvnet-5.0.15 → pvnet-5.0.16}/PKG-INFO +1 -1
- {pvnet-5.0.15 → pvnet-5.0.16}/PVNet.egg-info/PKG-INFO +1 -1
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/base_model.py +1 -1
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/encoders/basic_blocks.py +1 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/training/lightning_module.py +11 -5
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/training/train.py +0 -1
- pvnet-5.0.16/pvnet/utils.py +157 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/tests/test_end2end.py +5 -4
- pvnet-5.0.15/pvnet/utils.py +0 -88
- {pvnet-5.0.15 → pvnet-5.0.16}/LICENSE +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/PVNet.egg-info/SOURCES.txt +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/PVNet.egg-info/dependency_links.txt +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/PVNet.egg-info/requires.txt +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/PVNet.egg-info/top_level.txt +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/README.md +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/__init__.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/data/__init__.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/data/base_datamodule.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/data/site_datamodule.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/data/uk_regional_datamodule.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/load_model.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/__init__.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/ensemble.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/__init__.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/basic_blocks.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/encoders/encoders3d.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/late_fusion.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/models/late_fusion/site_encoders/encoders.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/optimizers.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/training/__init__.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pvnet/training/plots.py +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/pyproject.toml +0 -0
- {pvnet-5.0.15 → pvnet-5.0.16}/setup.cfg +0 -0
pvnet/models/late_fusion/encoders/basic_blocks.py

@@ -31,6 +31,7 @@ class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
         self.out_features = out_features
         self.image_size_pixels = image_size_pixels
         self.sequence_length = sequence_length
+        self.in_channels = in_channels

     @abstractmethod
     def forward(self):
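The single added line above matters because the new batch validation in pvnet/utils.py reads the stored attribute back off each encoder. A minimal sketch of that relationship, using an illustrative TinyEncoder in place of the real encoder classes:

import torch.nn as nn

class TinyEncoder(nn.Module):
    """Illustrative stand-in for an AbstractNWPSatelliteEncoder subclass."""
    def __init__(self, sequence_length, image_size_pixels, in_channels, out_features):
        super().__init__()
        self.out_features = out_features
        self.image_size_pixels = image_size_pixels
        self.sequence_length = sequence_length
        self.in_channels = in_channels  # persisted as of 5.0.16

enc = TinyEncoder(sequence_length=7, image_size_pixels=24, in_channels=11, out_features=128)
# validate_batch_against_config derives the expected input shape from these attributes:
batch_size = 4
expected = (batch_size, enc.sequence_length, enc.in_channels,
            enc.image_size_pixels, enc.image_size_pixels)
print(expected)  # (4, 7, 11, 24, 24)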
pvnet/training/lightning_module.py

@@ -15,6 +15,7 @@ from pvnet.data.base_datamodule import collate_fn
 from pvnet.models.base_model import BaseModel
 from pvnet.optimizers import AbstractOptimizer
 from pvnet.training.plots import plot_sample_forecasts, wandb_line_plot
+from pvnet.utils import validate_batch_against_config


 class PVNetLightningModule(pl.LightningModule):
@@ -42,9 +43,6 @@ class PVNetLightningModule(pl.LightningModule):
         # This setting is only used when lr is tuned with callback
         self.lr = None

-        # Set up store for all all validation results so we can log these
-        self.save_all_validation_results = save_all_validation_results
-
     def transfer_batch_to_device(
         self,
         batch: TensorBatch,
@@ -189,7 +187,7 @@ class PVNetLightningModule(pl.LightningModule):
         self._val_horizon_maes: list[np.array] = []
         if self.current_epoch==0:
             self._val_persistence_horizon_maes: list[np.array] = []
-
+
         # Plot some sample forecasts
         val_dataset = self.trainer.val_dataloaders.dataset

@@ -205,6 +203,14 @@ class PVNetLightningModule(pl.LightningModule):

             batch = collate_fn([val_dataset[i] for i in idxs])
             batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
+
+            # Batch validation check only during sanity check phase - use first batch
+            if self.trainer.sanity_checking and plot_num == 0:
+                validate_batch_against_config(
+                    batch=batch,
+                    model=self.model
+                )
+
             with torch.no_grad():
                 y_hat = self.model(batch)

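The validation call is gated on Lightning's pre-training sanity-check pass, so the shape checks run once per fit rather than on every validation epoch. A minimal sketch of the same gating pattern, assuming a hypothetical DummyModule (Trainer.sanity_checking is the real Lightning flag; everything else here is illustrative):

import lightning.pytorch as pl

class DummyModule(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        # self.trainer.sanity_checking is True only during the short
        # sanity-check validation run Lightning performs before training
        if self.trainer.sanity_checking and batch_idx == 0:
            assert batch["x"].ndim == 5, f"unexpected shape {tuple(batch['x'].shape)}"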
@@ -306,7 +312,7 @@ class PVNetLightningModule(pl.LightningModule):
         self.log_dict(extreme_error_metrics, on_step=False, on_epoch=True)

         # Optionally save all validation results - these are overridden each epoch
-        if self.save_all_validation_results:
+        if self.hparams.save_all_validation_results:
             # Add attributes
             ds_val_results.attrs["epoch"] = self.current_epoch

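The explicit self.save_all_validation_results attribute removed in the -42,9 hunk above is replaced here by a self.hparams lookup. This presumably relies on Lightning's save_hyperparameters(), which records constructor arguments and exposes them as self.hparams.<name>; a minimal sketch of that mechanism, with an illustrative Example class:

import lightning.pytorch as pl

class Example(pl.LightningModule):
    def __init__(self, save_all_validation_results: bool = False):
        super().__init__()
        # Stores __init__ arguments so they stay reachable via self.hparams
        self.save_hyperparameters()

m = Example(save_all_validation_results=True)
print(m.hparams.save_all_validation_results)  # True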
pvnet-5.0.16/pvnet/utils.py ADDED

@@ -0,0 +1,157 @@
+"""Utils"""
+import logging
+from typing import TYPE_CHECKING
+
+import rich.syntax
+import rich.tree
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf
+
+if TYPE_CHECKING:
+    from pvnet.models.base_model import BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+PYTORCH_WEIGHTS_NAME = "model_weights.safetensors"
+MODEL_CONFIG_NAME = "model_config.yaml"
+DATA_CONFIG_NAME = "data_config.yaml"
+DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
+FULL_CONFIG_NAME = "full_experiment_config.yaml"
+MODEL_CARD_NAME = "README.md"
+
+
+
+def run_config_utilities(config: DictConfig) -> None:
+    """A couple of optional utilities.
+
+    Controlled by main config file:
+    - forcing debug friendly configuration
+
+    Modifies DictConfig in place.
+
+    Args:
+        config (DictConfig): Configuration composed by Hydra.
+    """
+
+    # Enable adding new keys to config
+    OmegaConf.set_struct(config, False)
+
+    # Force debugger friendly configuration if <config.trainer.fast_dev_run=True>
+    if config.trainer.get("fast_dev_run"):
+        logger.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
+        # Debuggers don't like GPUs or multiprocessing
+        if config.trainer.get("gpus"):
+            config.trainer.gpus = 0
+        if config.datamodule.get("pin_memory"):
+            config.datamodule.pin_memory = False
+        if config.datamodule.get("num_workers"):
+            config.datamodule.num_workers = 0
+        if config.datamodule.get("prefetch_factor"):
+            config.datamodule.prefetch_factor = None
+
+    # Disable adding new keys to config
+    OmegaConf.set_struct(config, True)
+
+
+@rank_zero_only
+def print_config(
+    config: DictConfig,
+    fields: tuple[str] = (
+        "trainer",
+        "model",
+        "datamodule",
+        "callbacks",
+        "logger",
+        "seed",
+    ),
+    resolve: bool = True,
+) -> None:
+    """Prints content of DictConfig using Rich library and its tree structure.
+
+    Args:
+        config (DictConfig): Configuration composed by Hydra.
+        fields (Sequence[str], optional): Determines which main fields from config will
+            be printed and in what order.
+        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+    """
+
+    style = "dim"
+    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+    for field in fields:
+        branch = tree.add(field, style=style, guide_style=style)
+
+        config_section = config.get(field)
+
+        branch_content = str(config_section)
+        if isinstance(config_section, DictConfig):
+            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
+
+        branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+    rich.print(tree)
+
+
+def validate_batch_against_config(
+    batch: dict,
+    model: "BaseModel",
+) -> None:
+    """Validates tensor shapes in batch against model configuration."""
+    logger.info("Performing batch shape validation against model config.")
+
+    # NWP validation
+    if hasattr(model, 'nwp_encoders_dict'):
+        if "nwp" not in batch:
+            raise ValueError(
+                "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
+            )
+
+        for source, nwp_data in batch["nwp"].items():
+            if source in model.nwp_encoders_dict:
+
+                enc = model.nwp_encoders_dict[source]
+                expected_channels = enc.in_channels
+                if model.add_image_embedding_channel:
+                    expected_channels -= 1
+
+                expected = (nwp_data["nwp"].shape[0], enc.sequence_length,
+                            expected_channels, enc.image_size_pixels, enc.image_size_pixels)
+                if tuple(nwp_data["nwp"].shape) != expected:
+                    actual_shape = tuple(nwp_data['nwp'].shape)
+                    raise ValueError(
+                        f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
+                    )
+
+    # Satellite validation
+    if hasattr(model, 'sat_encoder'):
+        if "satellite_actual" not in batch:
+            raise ValueError(
+                "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
+            )
+
+        enc = model.sat_encoder
+        expected_channels = enc.in_channels
+        if model.add_image_embedding_channel:
+            expected_channels -= 1
+
+        expected = (batch["satellite_actual"].shape[0], enc.sequence_length, expected_channels,
+                    enc.image_size_pixels, enc.image_size_pixels)
+        if tuple(batch["satellite_actual"].shape) != expected:
+            actual_shape = tuple(batch['satellite_actual'].shape)
+            raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
+
+    # GSP/Site validation
+    key = model._target_key
+    if key in batch:
+        total_minutes = model.history_minutes + model.forecast_minutes
+        interval = model.interval_minutes
+        expected_len = total_minutes // interval + 1
+        expected = (batch[key].shape[0], expected_len)
+        if tuple(batch[key].shape) != expected:
+            actual_shape = tuple(batch[key].shape)
+            raise ValueError(
+                f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}"
+            )
+
+    logger.info("Batch shape validation successful!")
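To make the expected-shape logic above concrete, here is a hedged usage sketch that exercises only the GSP/site branch of validate_batch_against_config. The SimpleNamespace model is an illustrative stand-in (in real use the model is a pvnet BaseModel and the batch comes from the datamodule):

import torch
from types import SimpleNamespace
from pvnet.utils import validate_batch_against_config

# Stand-in exposing just the attributes the GSP/site branch reads
model = SimpleNamespace(
    _target_key="gsp",
    history_minutes=120,
    forecast_minutes=480,
    interval_minutes=30,
)

# (120 + 480) minutes at 30-minute steps -> 600 // 30 + 1 = 21 time steps
batch = {"gsp": torch.zeros(4, 21)}
validate_batch_against_config(batch=batch, model=model)  # passes

batch_bad = {"gsp": torch.zeros(4, 20)}
validate_batch_against_config(batch=batch_bad, model=model)  # raises ValueError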
tests/test_end2end.py

@@ -1,7 +1,8 @@
 import lightning
-
-from pvnet.
+
+from pvnet.data import UKRegionalStreamedDataModule
 from pvnet.optimizers import EmbAdamWReduceLROnPlateau
+from pvnet.training.lightning_module import PVNetLightningModule


 def test_model_trainer_fit(session_tmp_path, uk_data_config_path, late_fusion_model):
@@ -15,7 +16,7 @@ def test_model_trainer_fit(session_tmp_path, uk_data_config_path, late_fusion_model):
         dataset_pickle_dir=f"{session_tmp_path}/dataset_pickles"
     )

-
+    lightning_model = PVNetLightningModule(
         model=late_fusion_model,
         optimizer=EmbAdamWReduceLROnPlateau(),
     )
@@ -29,4 +30,4 @@ def test_model_trainer_fit(session_tmp_path, uk_data_config_path, late_fusion_model):
         logger=False,
         enable_checkpointing=False,
     )
-    trainer.fit(model=
+    trainer.fit(model=lightning_model, datamodule=datamodule)
pvnet-5.0.15/pvnet/utils.py DELETED

@@ -1,88 +0,0 @@
-"""Utils"""
-import logging
-
-import rich.syntax
-import rich.tree
-from lightning.pytorch.utilities import rank_zero_only
-from omegaconf import DictConfig, OmegaConf
-
-logger = logging.getLogger(__name__)
-
-
-PYTORCH_WEIGHTS_NAME = "model_weights.safetensors"
-MODEL_CONFIG_NAME = "model_config.yaml"
-DATA_CONFIG_NAME = "data_config.yaml"
-DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
-FULL_CONFIG_NAME = "full_experiment_config.yaml"
-MODEL_CARD_NAME = "README.md"
-
-
-
-def run_config_utilities(config: DictConfig) -> None:
-    """A couple of optional utilities.
-
-    Controlled by main config file:
-    - forcing debug friendly configuration
-
-    Modifies DictConfig in place.
-
-    Args:
-        config (DictConfig): Configuration composed by Hydra.
-    """
-
-    # Enable adding new keys to config
-    OmegaConf.set_struct(config, False)
-
-    # Force debugger friendly configuration if <config.trainer.fast_dev_run=True>
-    if config.trainer.get("fast_dev_run"):
-        logger.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
-        # Debuggers don't like GPUs or multiprocessing
-        if config.trainer.get("gpus"):
-            config.trainer.gpus = 0
-        if config.datamodule.get("pin_memory"):
-            config.datamodule.pin_memory = False
-        if config.datamodule.get("num_workers"):
-            config.datamodule.num_workers = 0
-        if config.datamodule.get("prefetch_factor"):
-            config.datamodule.prefetch_factor = None
-
-    # Disable adding new keys to config
-    OmegaConf.set_struct(config, True)
-
-
-@rank_zero_only
-def print_config(
-    config: DictConfig,
-    fields: tuple[str] = (
-        "trainer",
-        "model",
-        "datamodule",
-        "callbacks",
-        "logger",
-        "seed",
-    ),
-    resolve: bool = True,
-) -> None:
-    """Prints content of DictConfig using Rich library and its tree structure.
-
-    Args:
-        config (DictConfig): Configuration composed by Hydra.
-        fields (Sequence[str], optional): Determines which main fields from config will
-            be printed and in what order.
-        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
-    """
-
-    style = "dim"
-    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
-
-    for field in fields:
-        branch = tree.add(field, style=style, guide_style=style)
-
-        config_section = config.get(field)
-        branch_content = str(config_section)
-        if isinstance(config_section, DictConfig):
-            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
-
-        branch.add(rich.syntax.Syntax(branch_content, "yaml"))
-
-    rich.print(tree)