PVNet_summation 1.0.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of PVNet_summation might be problematic.
- pvnet_summation/__init__.py +1 -0
- pvnet_summation/data/__init__.py +2 -0
- pvnet_summation/data/datamodule.py +213 -0
- pvnet_summation/load_model.py +70 -0
- pvnet_summation/models/__init__.py +3 -0
- pvnet_summation/models/base_model.py +345 -0
- pvnet_summation/models/dense_model.py +75 -0
- pvnet_summation/optimizers.py +219 -0
- pvnet_summation/training/__init__.py +3 -0
- pvnet_summation/training/lightning_module.py +247 -0
- pvnet_summation/training/plots.py +80 -0
- pvnet_summation/training/train.py +185 -0
- pvnet_summation/utils.py +87 -0
- pvnet_summation-1.0.0.dist-info/METADATA +100 -0
- pvnet_summation-1.0.0.dist-info/RECORD +18 -0
- pvnet_summation-1.0.0.dist-info/WHEEL +5 -0
- pvnet_summation-1.0.0.dist-info/licenses/LICENSE +21 -0
- pvnet_summation-1.0.0.dist-info/top_level.txt +1 -0

pvnet_summation/models/dense_model.py
@@ -0,0 +1,75 @@
"""Simple model which only uses outputs of PVNet for all GSPs"""

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from pvnet_summation.data.datamodule import SumTensorBatch
from pvnet_summation.models.base_model import BaseModel


class DenseModel(BaseModel):
    """Neural network architecture based on naive dense layers

    This model flattens all the features into a 1D vector before feeding them into the sub network
    """

    def __init__(
        self,
        output_quantiles: list[float] | None,
        num_input_locations: int,
        input_quantiles: list[float] | None,
        history_minutes: int,
        forecast_minutes: int,
        interval_minutes: int,
        output_network: torch.nn.Module,
        predict_difference_from_sum: bool = False,
    ):
        """Neural network architecture based on naive dense layers"""

        super().__init__(
            output_quantiles,
            num_input_locations,
            input_quantiles,
            history_minutes,
            forecast_minutes,
            interval_minutes,
        )

        self.predict_difference_from_sum = predict_difference_from_sum

        self.model = output_network(
            in_features=np.prod(self.input_shape),
            out_features=self.num_output_features,
        )

        # Add linear layer if predicting difference from sum
        # This allows difference to be positive or negative
        if predict_difference_from_sum:
            self.model = nn.Sequential(
                self.model,
                nn.Linear(self.num_output_features, self.num_output_features),
            )

    def forward(self, x: SumTensorBatch) -> torch.Tensor:
        """Run model forward"""

        x_in = torch.flatten(x["pvnet_outputs"], start_dim=1)
        out = self.model(x_in)

        if self.use_quantile_regression:
            # Shape: [batch_size, seq_length * num_quantiles]
            out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles))

        if self.predict_difference_from_sum:
            loc_sum = self.sum_of_locations(x)

            if self.use_quantile_regression:
                loc_sum = loc_sum.unsqueeze(-1)

            out = F.leaky_relu(loc_sum + out)

        return out
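
For orientation, a minimal construction sketch. Everything here is hypothetical: make_mlp is an illustrative stand-in for whatever output_network factory the training config supplies (DenseModel calls it with in_features and out_features, as above), and the location count, horizon and quantiles are made-up values, not defaults from this package.

from torch import nn

from pvnet_summation.models.dense_model import DenseModel


def make_mlp(in_features: int, out_features: int) -> nn.Module:
    """Illustrative stand-in for the configured output network."""
    return nn.Sequential(
        nn.Linear(in_features, 512),
        nn.LeakyReLU(),
        nn.Linear(512, out_features),
    )


# Hypothetical settings: 317 GSP-level inputs, an 8-hour forecast at 30-minute
# intervals, and the same three quantiles on the inputs and the summed output
model = DenseModel(
    output_quantiles=[0.1, 0.5, 0.9],
    num_input_locations=317,
    input_quantiles=[0.1, 0.5, 0.9],
    history_minutes=60,
    forecast_minutes=480,
    interval_minutes=30,
    output_network=make_mlp,
    predict_difference_from_sum=True,
)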

pvnet_summation/optimizers.py
@@ -0,0 +1,219 @@
"""Optimizer factory-function classes."""

from abc import ABC, abstractmethod

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter


def find_submodule_parameters(model: Module, search_modules: list[Module]) -> list[Parameter]:
    """Finds all parameters within given submodule types

    Args:
        model: torch Module to search through
        search_modules: List of submodule types to search for
    """
    if isinstance(model, search_modules):
        return model.parameters()

    children = list(model.children())
    if len(children) == 0:
        return []
    else:
        params = []
        for c in children:
            params += find_submodule_parameters(c, search_modules)
        return params


def find_other_than_submodule_parameters(
    model: Module,
    ignore_modules: list[Module],
) -> list[Parameter]:
    """Finds all parameters not within given submodule types

    Args:
        model: torch Module to search through
        ignore_modules: List of submodule types to ignore
    """
    if isinstance(model, ignore_modules):
        return []

    children = list(model.children())
    if len(children) == 0:
        return model.parameters()
    else:
        params = []
        for c in children:
            params += find_other_than_submodule_parameters(c, ignore_modules)
        return params


class AbstractOptimizer(ABC):
    """Abstract class for optimizer

    Optimizer classes will be used by the model like:
    > OptimizerGenerator = AbstractOptimizer()
    > optimizer = OptimizerGenerator(model)
    The returned object `optimizer` must be something that may be returned by `pytorch_lightning`'s
    `configure_optimizers()` method.
    See:
    https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers
    """

    @abstractmethod
    def __call__(self):
        """Abstract call"""
        pass


class Adam(AbstractOptimizer):
    """Adam optimizer"""

    def __init__(self, lr: float = 0.0005, **kwargs):
        """Adam optimizer"""
        self.lr = lr
        self.kwargs = kwargs

    def __call__(self, model: Module):
        """Return optimizer"""
        return torch.optim.Adam(model.parameters(), lr=self.lr, **self.kwargs)


class AdamW(AbstractOptimizer):
    """AdamW optimizer"""

    def __init__(self, lr: float = 0.0005, **kwargs):
        """AdamW optimizer"""
        self.lr = lr
        self.kwargs = kwargs

    def __call__(self, model: Module):
        """Return optimizer"""
        return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs)


class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
    """AdamW optimizer and reduce on plateau scheduler"""

    def __init__(
        self,
        lr: float = 0.0005,
        weight_decay: float = 0.01,
        patience: int = 3,
        factor: float = 0.5,
        threshold: float = 2e-4,
        **opt_kwargs,
    ):
        """AdamW optimizer and reduce on plateau scheduler"""
        self.lr = lr
        self.weight_decay = weight_decay
        self.patience = patience
        self.factor = factor
        self.threshold = threshold
        self.opt_kwargs = opt_kwargs

    def __call__(self, model):
        """Return optimizer"""

        search_modules = (torch.nn.Embedding,)

        no_decay = find_submodule_parameters(model, search_modules)
        decay = find_other_than_submodule_parameters(model, search_modules)

        optim_groups = [
            {"params": decay, "weight_decay": self.weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]
        opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)

        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            factor=self.factor,
            patience=self.patience,
            threshold=self.threshold,
        )
        sch = {
            "scheduler": sch,
            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
        }
        return [opt], [sch]


class AdamWReduceLROnPlateau(AbstractOptimizer):
    """AdamW optimizer and reduce on plateau scheduler"""

    def __init__(
        self,
        lr: float = 0.0005,
        patience: int = 3,
        factor: float = 0.5,
        threshold: float = 2e-4,
        step_freq=None,
        **opt_kwargs,
    ):
        """AdamW optimizer and reduce on plateau scheduler"""
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.threshold = threshold
        self.step_freq = step_freq
        self.opt_kwargs = opt_kwargs

    def _call_multi(self, model):
        """Return optimizer with per-submodule learning rates (used when lr is a dict)"""
        remaining_params = {k: p for k, p in model.named_parameters()}

        group_args = []

        for key in self.lr.keys():
            if key == "default":
                continue

            submodule_params = []
            for param_name in list(remaining_params.keys()):
                if param_name.startswith(key):
                    submodule_params += [remaining_params.pop(param_name)]

            group_args += [{"params": submodule_params, "lr": self.lr[key]}]

        remaining_params = [p for k, p in remaining_params.items()]
        group_args += [{"params": remaining_params}]

        opt = torch.optim.AdamW(
            group_args,
            lr=self.lr["default"] if model.lr is None else model.lr,
            **self.opt_kwargs,
        )
        sch = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                opt,
                factor=self.factor,
                patience=self.patience,
                threshold=self.threshold,
            ),
            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
        }

        return [opt], [sch]

    def __call__(self, model):
        """Return optimizer"""
        if not isinstance(self.lr, float):
            return self._call_multi(model)
        else:
            opt = torch.optim.AdamW(model.parameters(), lr=self.lr, **self.opt_kwargs)
            sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
                opt,
                factor=self.factor,
                patience=self.patience,
                threshold=self.threshold,
            )
            sch = {
                "scheduler": sch,
                "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
            }
            return [opt], [sch]
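
To make the factory pattern above concrete, a small self-contained sketch using a toy model: the classes are configured with hyperparameters only, and the torch optimizer (plus, for the scheduler variants, the Lightning scheduler dict) is only built when the factory is called with a model. The attribute set on the toy model below is only needed because the scheduler variants read use_quantile_regression to choose which metric to monitor.

import torch
from torch import nn

from pvnet_summation.optimizers import Adam, AdamWReduceLROnPlateau

toy_model = nn.Linear(4, 1)

# Plain factory: calling it returns a ready-to-use torch.optim.Adam
optimizer = Adam(lr=1e-3)(toy_model)
assert isinstance(optimizer, torch.optim.Adam)

# Scheduler variant: returns ([optimizer], [scheduler dict]), the form accepted
# by LightningModule.configure_optimizers()
toy_model.use_quantile_regression = True  # read when picking the monitor metric
opts, scheds = AdamWReduceLROnPlateau(lr=5e-4, patience=5)(toy_model)
assert scheds[0]["monitor"] == "quantile_loss/val"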

pvnet_summation/training/lightning_module.py
@@ -0,0 +1,247 @@
"""PyTorch Lightning module for training PVNet models"""

import lightning.pytorch as pl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from ocf_data_sampler.numpy_sample.common_types import TensorBatch
from torch.utils.data import default_collate

from pvnet_summation.models.base_model import BaseModel
from pvnet_summation.optimizers import AbstractOptimizer
from pvnet_summation.training.plots import plot_sample_forecasts, wandb_line_plot


class PVNetSummationLightningModule(pl.LightningModule):
    """Lightning module for training PVNet models"""

    def __init__(
        self,
        model: BaseModel,
        optimizer: AbstractOptimizer,
    ):
        """Lightning module for training PVNet models

        Args:
            model: The PVNet model
            optimizer: Optimizer
        """
        super().__init__()

        self.model = model
        self._optimizer = optimizer

        # Model must have lr to allow tuning
        # This setting is only used when lr is tuned with callback
        self.lr = None

    def _calculate_quantile_loss(self, y_quantiles: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Calculate quantile loss.

        Note:
            Implementation copied from:
            https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting
            /metrics/quantile.html#QuantileLoss.loss

        Args:
            y_quantiles: Quantile prediction of network
            y: Target values

        Returns:
            Quantile loss
        """
        losses = []
        for i, q in enumerate(self.model.output_quantiles):
            errors = y - y_quantiles[..., i]
            losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
        losses = 2 * torch.cat(losses, dim=2)

        return losses.mean()

    def configure_optimizers(self):
        """Configure the optimizers using learning rate found with LR finder if used"""
        if self.lr is not None:
            # Use learning rate found by learning rate finder callback
            self._optimizer.lr = self.lr
        return self._optimizer(self.model)

    def _calculate_common_losses(
        self,
        y: torch.Tensor,
        y_hat: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Calculate losses common to train and val"""

        losses = {}

        if self.model.use_quantile_regression:
            losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y)
            y_hat = self.model._quantiles_to_prediction(y_hat)

        losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)})

        return losses

    def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor:
        """Run training step"""

        y_hat = self.model(batch)

        y = batch["target"]

        losses = self._calculate_common_losses(y, y_hat)
        losses = {f"{k}/train": v for k, v in losses.items()}

        self.log_dict(losses, on_step=True, on_epoch=True)

        if self.model.use_quantile_regression:
            opt_target = losses["quantile_loss/train"]
        else:
            opt_target = losses["MAE/train"]
        return opt_target

    def _calculate_val_losses(
        self,
        y: torch.Tensor,
        y_hat: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Calculate additional losses only run in validation"""

        losses = {}

        if self.model.use_quantile_regression:
            metric_name = "val_fraction_below/fraction_below_{:.2f}_quantile"
            # Add fraction below each quantile for calibration
            for i, quantile in enumerate(self.model.output_quantiles):
                below_quant = y <= y_hat[..., i]
                # Mask out small values, which are dominated by night
                mask = y >= 0.01
                losses[metric_name.format(quantile)] = below_quant[mask].float().mean()

        return losses

    def _calculate_step_metrics(
        self,
        y: torch.Tensor,
        y_hat: torch.Tensor,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Calculate the MAE and MSE at each forecast step"""

        mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0).cpu().numpy()
        mse_each_step = torch.mean((y_hat - y) ** 2, dim=0).cpu().numpy()

        return mae_each_step, mse_each_step

    def on_validation_epoch_start(self):
        """Run at start of val period"""
        # Set up stores which we will fill during validation
        self._val_horizon_maes: list[np.ndarray] = []
        if self.current_epoch == 0:
            self._val_persistence_horizon_maes: list[np.ndarray] = []

        # Plot some sample forecasts
        val_dataset = self.trainer.val_dataloaders.dataset

        plots_per_figure = 16
        num_figures = 2

        for plot_num in range(num_figures):
            idxs = np.arange(plots_per_figure) + plot_num * plots_per_figure
            idxs = idxs[idxs < len(val_dataset)]

            if len(idxs) == 0:
                continue

            batch = default_collate([val_dataset[i] for i in idxs])
            batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
            with torch.no_grad():
                y_hat = self.model(batch)

            fig = plot_sample_forecasts(batch, y_hat, quantiles=self.model.output_quantiles)

            plot_name = f"val_forecast_samples/sample_set_{plot_num}"

            self.logger.experiment.log({plot_name: wandb.Image(fig)})

            plt.close(fig)

    def validation_step(self, batch: TensorBatch, batch_idx: int) -> None:
        """Run validation step"""

        y_hat = self.model(batch)

        y = batch["target"]

        losses = self._calculate_common_losses(y, y_hat)
        losses = {f"{k}/val": v for k, v in losses.items()}

        losses.update(self._calculate_val_losses(y, y_hat))

        # Calculate the horizon MAE/MSE metrics
        if self.model.use_quantile_regression:
            y_hat_mid = self.model._quantiles_to_prediction(y_hat)
        else:
            y_hat_mid = y_hat

        mae_step, mse_step = self._calculate_step_metrics(y, y_hat_mid)

        # Store to make horizon-MAE plot
        self._val_horizon_maes.append(mae_step)

        # Also add each step to logged metrics
        losses.update({f"val_step_MAE/step_{i:03}": m for i, m in enumerate(mae_step)})
        losses.update({f"val_step_MSE/step_{i:03}": m for i, m in enumerate(mse_step)})

        # Calculate the persistence losses - we only need to do this once per training run,
        # not every epoch
        if self.current_epoch == 0:
            y_persist = batch["last_outturn"].unsqueeze(1).expand(-1, self.model.forecast_len)
            mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
            self._val_persistence_horizon_maes.append(mae_step_persist)
            losses.update(
                {
                    "MAE/val_persistence": mae_step_persist.mean(),
                    "MSE/val_persistence": mse_step_persist.mean(),
                }
            )

        # Log the metrics
        self.log_dict(losses, on_step=False, on_epoch=True)

    def on_validation_epoch_end(self) -> None:
        """Run on epoch end"""

        val_horizon_maes = np.mean(self._val_horizon_maes, axis=0)
        self._val_horizon_maes = []

        # We only run this on the first epoch
        if self.current_epoch == 0:
            val_persistence_horizon_maes = np.mean(self._val_persistence_horizon_maes, axis=0)
            self._val_persistence_horizon_maes = []

        if isinstance(self.logger, pl.loggers.WandbLogger):

            # Create the horizon accuracy curve
            horizon_mae_plot = wandb_line_plot(
                x=np.arange(self.model.forecast_len),
                y=val_horizon_maes,
                xlabel="Horizon step",
                ylabel="MAE",
                title="Val horizon loss curve",
            )

            wandb.log({"val_horizon_mae_plot": horizon_mae_plot})

            # Create persistence horizon accuracy curve but only on first epoch
            if self.current_epoch == 0:
                persist_horizon_mae_plot = wandb_line_plot(
                    x=np.arange(self.model.forecast_len),
                    y=val_persistence_horizon_maes,
                    xlabel="Horizon step",
                    ylabel="MAE",
                    title="Val persistence horizon loss curve",
                )
                wandb.log({"persistence_val_horizon_mae_plot": persist_horizon_mae_plot})
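
The _calculate_quantile_loss above is the standard pinball (quantile) loss, averaged over samples, forecast steps and quantiles and multiplied by 2. Below is a small self-contained check of that computation on dummy tensors; the shapes are arbitrary and only for illustration. For the median quantile, twice the pinball loss reduces to the MAE, which gives a quick sanity check on the formula.

import torch
import torch.nn.functional as F

quantiles = [0.1, 0.5, 0.9]

# Dummy data: batch of 4 national forecasts, 16 forecast steps, one value per quantile
y = torch.rand(4, 16)
y_quantiles = torch.rand(4, 16, len(quantiles))

losses = []
for i, q in enumerate(quantiles):
    errors = y - y_quantiles[..., i]
    # Pinball loss: q * error when under-predicting, (1 - q) * |error| when over-predicting
    losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
loss = 2 * torch.cat(losses, dim=2).mean()

# Sanity check: for q = 0.5 alone, 2 * pinball loss equals the MAE
errors = y - y_quantiles[..., 1]
median_loss = 2 * torch.max(-0.5 * errors, 0.5 * errors).mean()
assert torch.allclose(median_loss, F.l1_loss(y_quantiles[..., 1], y))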

pvnet_summation/training/plots.py
@@ -0,0 +1,80 @@
"""Plots logged during training"""
from collections.abc import Sequence

import matplotlib.pyplot as plt
import pandas as pd
import pylab
import torch
import wandb

from pvnet_summation.data.datamodule import SumTensorBatch


def wandb_line_plot(
    x: Sequence[float],
    y: Sequence[float],
    xlabel: str,
    ylabel: str,
    title: str | None = None,
) -> wandb.plot.CustomChart:
    """Make a wandb line plot"""
    data = [[xi, yi] for (xi, yi) in zip(x, y)]
    table = wandb.Table(data=data, columns=[xlabel, ylabel])
    return wandb.plot.line(table, xlabel, ylabel, title=title)


def plot_sample_forecasts(
    batch: SumTensorBatch,
    y_hat: torch.Tensor,
    quantiles: list[float] | None,
) -> plt.Figure:
    """Plot a batch of data and the forecast from that batch"""

    y = batch["target"].cpu().numpy()
    y_hat = y_hat.cpu().numpy()
    times_utc = pd.to_datetime(batch["valid_times"].cpu().numpy().astype("datetime64[ns]"))
    batch_size = y.shape[0]

    fig, axes = plt.subplots(4, 4, figsize=(16, 16))

    for i, ax in enumerate(axes.ravel()[:batch_size]):

        ax.plot(times_utc[i], y[i], marker=".", color="k", label=r"$y$")

        if quantiles is None:
            ax.plot(
                times_utc[i][-len(y_hat[i]) :],
                y_hat[i],
                marker=".",
                color="r",
                label=r"$\hat{y}$",
            )
        else:
            cm = pylab.get_cmap("twilight")
            for nq, q in enumerate(quantiles):
                ax.plot(
                    times_utc[i][-len(y_hat[i]) :],
                    y_hat[i, :, nq],
                    color=cm(q),
                    label=r"$\hat{y}$" + f"({q})",
                    alpha=0.7,
                )

        ax.set_title(f"{times_utc[i][0].date()}", fontsize="small")

        xticks = [t for t in times_utc[i] if t.minute == 0][::2]
        ax.set_xticks(ticks=xticks, labels=[f"{t.hour:02}" for t in xticks], rotation=90)
        ax.grid()

    axes[0, 0].legend(loc="best")

    if batch_size < 16:
        for ax in axes.ravel()[batch_size:]:
            ax.axis("off")

    for ax in axes[-1, :]:
        ax.set_xlabel("Time (hour of day)")

    plt.tight_layout()

    return fig
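
As a usage sketch, wandb_line_plot can also be called outside the Lightning module, for example to log a horizon-MAE curve from a script. This assumes a wandb run has been started; offline mode is used here so nothing is uploaded, and all values are dummy data rather than real metrics from this package.

import numpy as np
import wandb

from pvnet_summation.training.plots import wandb_line_plot

with wandb.init(project="pvnet_summation_demo", mode="offline") as run:
    steps = np.arange(16)        # dummy forecast horizon steps
    maes = 0.02 + 0.001 * steps  # dummy MAE per step
    chart = wandb_line_plot(
        x=steps,
        y=maes,
        xlabel="Horizon step",
        ylabel="MAE",
        title="Val horizon loss curve",
    )
    run.log({"val_horizon_mae_plot": chart})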