PVNet_summation 1.1.2 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,219 @@
+ """Optimizer factory-function classes."""
+
+ from abc import ABC, abstractmethod
+
+ import torch
+ from torch.nn import Module
+ from torch.nn.parameter import Parameter
+
+
+ def find_submodule_parameters(
+     model: Module,
+     search_modules: tuple[type[Module], ...],
+ ) -> list[Parameter]:
+     """Finds all parameters within the given submodule types
+
+     Args:
+         model: torch Module to search through
+         search_modules: Tuple of submodule types to search for
+     """
+     if isinstance(model, search_modules):
+         return list(model.parameters())
+
+     children = list(model.children())
+     if len(children) == 0:
+         return []
+     else:
+         params = []
+         for c in children:
+             params += find_submodule_parameters(c, search_modules)
+         return params
+
+
+ def find_other_than_submodule_parameters(
+     model: Module,
+     ignore_modules: tuple[type[Module], ...],
+ ) -> list[Parameter]:
+     """Finds all parameters not within the given submodule types
+
+     Args:
+         model: torch Module to search through
+         ignore_modules: Tuple of submodule types to ignore
+     """
+     if isinstance(model, ignore_modules):
+         return []
+
+     children = list(model.children())
+     if len(children) == 0:
+         return list(model.parameters())
+     else:
+         params = []
+         for c in children:
+             params += find_other_than_submodule_parameters(c, ignore_modules)
+         return params
+
+
+ class AbstractOptimizer(ABC):
+     """Abstract class for optimizers
+
+     Optimizer classes will be used by the model like:
+     > OptimizerGenerator = AbstractOptimizer()
+     > optimizer = OptimizerGenerator(model)
+     The returned object `optimizer` must be something that may be returned by `pytorch_lightning`'s
+     `configure_optimizers()` method.
+     See:
+     https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers
+     """
+
+     @abstractmethod
+     def __call__(self):
+         """Abstract call"""
+         pass
+
+
+ class Adam(AbstractOptimizer):
+     """Adam optimizer"""
+
+     def __init__(self, lr: float = 0.0005, **kwargs):
+         """Adam optimizer"""
+         self.lr = lr
+         self.kwargs = kwargs
+
+     def __call__(self, model: Module):
+         """Return optimizer"""
+         return torch.optim.Adam(model.parameters(), lr=self.lr, **self.kwargs)
+
+
+ class AdamW(AbstractOptimizer):
+     """AdamW optimizer"""
+
+     def __init__(self, lr: float = 0.0005, **kwargs):
+         """AdamW optimizer"""
+         self.lr = lr
+         self.kwargs = kwargs
+
+     def __call__(self, model: Module):
+         """Return optimizer"""
+         return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs)
+
+
+ class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
+     """AdamW optimizer and reduce-on-plateau scheduler, with no weight decay on embeddings"""
+
+     def __init__(
+         self,
+         lr: float = 0.0005,
+         weight_decay: float = 0.01,
+         patience: int = 3,
+         factor: float = 0.5,
+         threshold: float = 2e-4,
+         **opt_kwargs,
+     ):
+         """AdamW optimizer and reduce-on-plateau scheduler"""
+         self.lr = lr
+         self.weight_decay = weight_decay
+         self.patience = patience
+         self.factor = factor
+         self.threshold = threshold
+         self.opt_kwargs = opt_kwargs
+
+     def __call__(self, model):
+         """Return optimizer"""
+
+         search_modules = (torch.nn.Embedding,)
+
+         no_decay = find_submodule_parameters(model, search_modules)
+         decay = find_other_than_submodule_parameters(model, search_modules)
+
+         optim_groups = [
+             {"params": decay, "weight_decay": self.weight_decay},
+             {"params": no_decay, "weight_decay": 0.0},
+         ]
+         opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)
+
+         sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
+             opt,
+             factor=self.factor,
+             patience=self.patience,
+             threshold=self.threshold,
+         )
+         sch = {
+             "scheduler": sch,
+             "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+         }
+         return [opt], [sch]
+
+
+ class AdamWReduceLROnPlateau(AbstractOptimizer):
+     """AdamW optimizer and reduce-on-plateau scheduler"""
+
+     def __init__(
+         self,
+         lr: float | dict[str, float] = 0.0005,
+         patience: int = 3,
+         factor: float = 0.5,
+         threshold: float = 2e-4,
+         step_freq=None,
+         **opt_kwargs,
+     ):
+         """AdamW optimizer and reduce-on-plateau scheduler
+
+         The learning rate may be a single float, or a dict of per-submodule learning rates keyed
+         by parameter-name prefix, with a "default" entry for the remaining parameters.
+         """
+         self.lr = lr
+         self.patience = patience
+         self.factor = factor
+         self.threshold = threshold
+         self.step_freq = step_freq
+         self.opt_kwargs = opt_kwargs
+
+     def _call_multi(self, model):
+         """Return optimizer with per-submodule learning rates"""
+         remaining_params = {k: p for k, p in model.named_parameters()}
+
+         group_args = []
+
+         for key in self.lr.keys():
+             if key == "default":
+                 continue
+
+             submodule_params = []
+             for param_name in list(remaining_params.keys()):
+                 if param_name.startswith(key):
+                     submodule_params += [remaining_params.pop(param_name)]
+
+             group_args += [{"params": submodule_params, "lr": self.lr[key]}]
+
+         remaining_params = [p for k, p in remaining_params.items()]
+         group_args += [{"params": remaining_params}]
+
+         opt = torch.optim.AdamW(
+             group_args,
+             lr=self.lr["default"] if model.lr is None else model.lr,
+             **self.opt_kwargs,
+         )
+         sch = {
+             "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
+                 opt,
+                 factor=self.factor,
+                 patience=self.patience,
+                 threshold=self.threshold,
+             ),
+             "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+         }
+
+         return [opt], [sch]
+
+     def __call__(self, model):
+         """Return optimizer"""
+         if not isinstance(self.lr, float):
+             return self._call_multi(model)
+         else:
+             opt = torch.optim.AdamW(model.parameters(), lr=self.lr, **self.opt_kwargs)
+             sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                 opt,
+                 factor=self.factor,
+                 patience=self.patience,
+                 threshold=self.threshold,
+             )
+             sch = {
+                 "scheduler": sch,
+                 "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+             }
+             return [opt], [sch]
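
The file above appears to be the `pvnet_summation.optimizers` module (the lightning module later in this diff imports `AbstractOptimizer` from that path). A minimal usage sketch follows; the toy model and learning rate are assumptions for illustration only and are not part of the package:

import torch

from pvnet_summation.optimizers import (
    Adam,
    find_other_than_submodule_parameters,
    find_submodule_parameters,
)

# Toy stand-in for a model; purely illustrative
toy_model = torch.nn.Sequential(torch.nn.Embedding(10, 4), torch.nn.Linear(4, 1))

# The factory is configured first, then called on the model (as in configure_optimizers)
optimizer = Adam(lr=1e-4)(toy_model)

# EmbAdamWReduceLROnPlateau splits parameters like this, so embeddings get no weight decay
no_decay = find_submodule_parameters(toy_model, (torch.nn.Embedding,))
decay = find_other_than_submodule_parameters(toy_model, (torch.nn.Embedding,))

The `*ReduceLROnPlateau` factories instead return `([optimizer], [scheduler_dict])`, which is one of the return forms accepted by Lightning's `configure_optimizers()`; they also expect the model to expose `use_quantile_regression` (and, for the dict-lr case, `lr`), so they are not called on the toy model here.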
@@ -0,0 +1,3 @@
+ """Training submodule"""
+ from .lightning_module import PVNetSummationLightningModule
+ from .train import train
@@ -0,0 +1,278 @@
+ """PyTorch Lightning module for training the PVNet summation model"""
+
+ import lightning.pytorch as pl
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import wandb
+ from ocf_data_sampler.numpy_sample.common_types import TensorBatch
+ from torch.utils.data import default_collate
+
+ from pvnet_summation.models.base_model import BaseModel
+ from pvnet_summation.optimizers import AbstractOptimizer
+ from pvnet_summation.training.plots import plot_sample_forecasts, wandb_line_plot
+
+
+ class PVNetSummationLightningModule(pl.LightningModule):
+     """Lightning module for training the PVNet summation model"""
+
+     def __init__(
+         self,
+         model: BaseModel,
+         optimizer: AbstractOptimizer,
+     ):
+         """Lightning module for training the PVNet summation model
+
+         Args:
+             model: The PVNet summation model
+             optimizer: Optimizer factory
+         """
+         super().__init__()
+
+         self.model = model
+         self._optimizer = optimizer
+
+         # The module must have an lr attribute to allow tuning
+         # It is only used when the learning rate is tuned with the LR-finder callback
+         self.lr = None
+
+     def _calculate_quantile_loss(self, y_quantiles: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         """Calculate quantile loss.
+
+         Note:
+             Implementation copied from:
+             https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting
+             /metrics/quantile.html#QuantileLoss.loss
+
+         Args:
+             y_quantiles: Quantile prediction of network
+             y: Target values
+
+         Returns:
+             Quantile loss
+         """
+         losses = []
+         for i, q in enumerate(self.model.output_quantiles):
+             errors = y - y_quantiles[..., i]
+             losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
+         losses = 2 * torch.cat(losses, dim=2)
+
+         return losses.mean()
+
+     def configure_optimizers(self):
+         """Configure the optimizers, using the learning rate found by the LR finder if one was used"""
+         if self.lr is not None:
+             # Use learning rate found by the learning rate finder callback
+             self._optimizer.lr = self.lr
+         return self._optimizer(self.model)
+
+     def _calculate_common_losses(
+         self,
+         y: torch.Tensor,
+         y_hat: torch.Tensor,
+     ) -> dict[str, torch.Tensor]:
+         """Calculate losses common to train and val"""
+
+         losses = {}
+
+         if self.model.use_quantile_regression:
+             losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y)
+             y_hat = self.model._quantiles_to_prediction(y_hat)
+
+         losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)})
+
+         return losses
+
+     def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor:
+         """Run training step"""
+
+         y_hat = self.model(batch)
+
+         y = batch["target"]
+
+         losses = self._calculate_common_losses(y, y_hat)
+         losses = {f"{k}/train": v for k, v in losses.items()}
+
+         self.log_dict(losses, on_step=True, on_epoch=True)
+
+         if self.model.use_quantile_regression:
+             opt_target = losses["quantile_loss/train"]
+         else:
+             opt_target = losses["MAE/train"]
+         return opt_target
+
+     def _calculate_val_losses(
+         self,
+         y: torch.Tensor,
+         y_hat: torch.Tensor,
+     ) -> dict[str, torch.Tensor]:
+         """Calculate additional losses only run in validation"""
+
+         losses = {}
+
+         if self.model.use_quantile_regression:
+             metric_name = "val_fraction_below/fraction_below_{:.2f}_quantile"
+             # Add fraction below each quantile for calibration
+             for i, quantile in enumerate(self.model.output_quantiles):
+                 below_quant = y <= y_hat[..., i]
+                 # Mask out small values, which are dominated by night
+                 mask = y >= 0.01
+                 losses[metric_name.format(quantile)] = below_quant[mask].float().mean()
+
+         return losses
+
+     def _calculate_step_metrics(
+         self,
+         y: torch.Tensor,
+         y_hat: torch.Tensor,
+     ) -> tuple[np.ndarray, np.ndarray]:
+         """Calculate the MAE and MSE at each forecast step"""
+
+         mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0).cpu().numpy()
+         mse_each_step = torch.mean((y_hat - y) ** 2, dim=0).cpu().numpy()
+
+         return mae_each_step, mse_each_step
+
+     def on_validation_epoch_start(self):
+         """Run at start of val period"""
+         # Set up stores which we will fill during validation
+         self._val_horizon_maes: list[np.ndarray] = []
+         if self.current_epoch == 0:
+             self._val_persistence_horizon_maes: list[np.ndarray] = []
+             self._val_loc_sum_horizon_maes: list[np.ndarray] = []
+
+         # Plot some sample forecasts
+         val_dataset = self.trainer.val_dataloaders.dataset
+
+         plots_per_figure = 16
+         num_figures = 2
+
+         for plot_num in range(num_figures):
+             idxs = np.arange(plots_per_figure) + plot_num * plots_per_figure
+             idxs = idxs[idxs < len(val_dataset)]
+
+             if len(idxs) == 0:
+                 continue
+
+             batch = default_collate([val_dataset[i] for i in idxs])
+             batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
+             with torch.no_grad():
+                 y_hat = self.model(batch)
+
+             y_loc_sum = self.model.sum_of_locations(batch)
+
+             fig = plot_sample_forecasts(batch, y_hat, y_loc_sum, self.model.output_quantiles)
+
+             plot_name = f"val_forecast_samples/sample_set_{plot_num}"
+
+             self.logger.experiment.log({plot_name: wandb.Image(fig)})
+
+             plt.close(fig)
+
+     def validation_step(self, batch: TensorBatch, batch_idx: int) -> None:
+         """Run validation step"""
+
+         y_hat = self.model(batch)
+
+         y = batch["target"]
+
+         losses = self._calculate_common_losses(y, y_hat)
+         losses = {f"{k}/val": v for k, v in losses.items()}
+
+         losses.update(self._calculate_val_losses(y, y_hat))
+
+         # Calculate the horizon MAE/MSE metrics
+         if self.model.use_quantile_regression:
+             y_hat_mid = self.model._quantiles_to_prediction(y_hat)
+         else:
+             y_hat_mid = y_hat
+
+         mae_step, mse_step = self._calculate_step_metrics(y, y_hat_mid)
+
+         # Store to make the horizon-MAE plot
+         self._val_horizon_maes.append(mae_step)
+
+         # Also add each step to the logged metrics
+         losses.update({f"val_step_MAE/step_{i:03}": m for i, m in enumerate(mae_step)})
+         losses.update({f"val_step_MSE/step_{i:03}": m for i, m in enumerate(mse_step)})
+
+         # Calculate the persistence and sum-of-locations losses - we only need to do this once
+         # per training run, not every epoch
+         if self.current_epoch == 0:
+
+             # Persistence
+             y_persist = batch["last_outturn"].unsqueeze(1).expand(-1, self.model.forecast_len)
+             mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
+             self._val_persistence_horizon_maes.append(mae_step_persist)
+             losses.update(
+                 {
+                     "MAE/val_persistence": mae_step_persist.mean(),
+                     "MSE/val_persistence": mse_step_persist.mean(),
+                 }
+             )
+
+             # Sum of locations
+             y_loc_sum = self.model.sum_of_locations(batch)
+             mae_step_loc_sum, mse_step_loc_sum = self._calculate_step_metrics(y, y_loc_sum)
+             self._val_loc_sum_horizon_maes.append(mae_step_loc_sum)
+             losses.update(
+                 {
+                     "MAE/val_location_sum": mae_step_loc_sum.mean(),
+                     "MSE/val_location_sum": mse_step_loc_sum.mean(),
+                 }
+             )
+
+         # Log the metrics
+         self.log_dict(losses, on_step=False, on_epoch=True)
+
+     def on_validation_epoch_end(self) -> None:
+         """Run on epoch end"""
+
+         val_horizon_maes = np.mean(self._val_horizon_maes, axis=0)
+         self._val_horizon_maes = []
+
+         if isinstance(self.logger, pl.loggers.WandbLogger):
+
+             # Create the horizon accuracy curve
+             horizon_mae_plot = wandb_line_plot(
+                 x=np.arange(self.model.forecast_len),
+                 y=val_horizon_maes,
+                 xlabel="Horizon step",
+                 ylabel="MAE",
+                 title="Val horizon loss curve",
+             )
+
+             wandb.log({"val_horizon_mae_plot": horizon_mae_plot})
+
+             # Create the persistence and location-sum horizon accuracy curves on the first epoch
+             if self.current_epoch == 0:
+                 val_persistence_horizon_maes = np.mean(self._val_persistence_horizon_maes, axis=0)
+                 del self._val_persistence_horizon_maes
+
+                 val_loc_sum_horizon_maes = np.mean(self._val_loc_sum_horizon_maes, axis=0)
+                 del self._val_loc_sum_horizon_maes
+
+                 persist_horizon_mae_plot = wandb_line_plot(
+                     x=np.arange(self.model.forecast_len),
+                     y=val_persistence_horizon_maes,
+                     xlabel="Horizon step",
+                     ylabel="MAE",
+                     title="Val persistence horizon loss curve",
+                 )
+
+                 loc_sum_horizon_mae_plot = wandb_line_plot(
+                     x=np.arange(self.model.forecast_len),
+                     y=val_loc_sum_horizon_maes,
+                     xlabel="Horizon step",
+                     ylabel="MAE",
+                     title="Val location-sum horizon loss curve",
+                 )
+
+                 wandb.log(
+                     {
+                         "persistence_val_horizon_mae_plot": persist_horizon_mae_plot,
+                         "location_sum_val_horizon_mae_plot": loc_sum_horizon_mae_plot,
+                     }
+                 )
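
To make the pinball (quantile) loss computed by `_calculate_quantile_loss` above concrete, it can be restated standalone on toy tensors. The quantile list, shapes, and values below are assumptions for illustration only, not package defaults:

import torch

quantiles = [0.1, 0.5, 0.9]           # assumed example quantiles
y = torch.rand(4, 8)                  # (batch, forecast_len) target outturn
y_quantiles = torch.rand(4, 8, 3)     # (batch, forecast_len, n_quantiles) prediction

losses = []
for i, q in enumerate(quantiles):
    errors = y - y_quantiles[..., i]
    # Under-prediction is penalised with weight q, over-prediction with weight (1 - q)
    losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
quantile_loss = (2 * torch.cat(losses, dim=2)).mean()

The factor of 2 and the averaging over quantiles and forecast steps mirror the pytorch-forecasting implementation referenced in the method's docstring.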
@@ -0,0 +1,91 @@
+ """Plots logged during training"""
+ from collections.abc import Sequence
+
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import pylab
+ import torch
+ import wandb
+
+ from pvnet_summation.data.datamodule import SumTensorBatch
+
+
+ def wandb_line_plot(
+     x: Sequence[float],
+     y: Sequence[float],
+     xlabel: str,
+     ylabel: str,
+     title: str | None = None,
+ ) -> wandb.plot.CustomChart:
+     """Make a wandb line plot"""
+     data = [[xi, yi] for (xi, yi) in zip(x, y)]
+     table = wandb.Table(data=data, columns=[xlabel, ylabel])
+     return wandb.plot.line(table, xlabel, ylabel, title=title)
+
+
+ def plot_sample_forecasts(
+     batch: SumTensorBatch,
+     y_hat: torch.Tensor,
+     y_loc_sum: torch.Tensor,
+     quantiles: list[float] | None,
+ ) -> plt.Figure:
+     """Plot a batch of data and the forecast from that batch"""
+
+     y = batch["target"].cpu().numpy()
+     y_hat = y_hat.cpu().numpy()
+     y_loc_sum = y_loc_sum.cpu().numpy()
+     times_utc = pd.to_datetime(batch["valid_times"].cpu().numpy().astype("datetime64[ns]"))
+     batch_size = y.shape[0]
+
+     fig, axes = plt.subplots(4, 4, figsize=(16, 16))
+
+     for i, ax in enumerate(axes.ravel()[:batch_size]):
+
+         ax.plot(times_utc[i], y[i], marker=".", color="k", label=r"$y$")
+
+         ax.plot(
+             times_utc[i],
+             y_loc_sum[i],
+             marker=".",
+             linestyle="-.",
+             color="r",
+             label=r"$\hat{y}_{loc-sum}$",
+         )
+
+         if quantiles is None:
+             ax.plot(
+                 times_utc[i],
+                 y_hat[i],
+                 marker=".",
+                 color="r",
+                 label=r"$\hat{y}$",
+             )
+         else:
+             cm = pylab.get_cmap("twilight")
+             for nq, q in enumerate(quantiles):
+                 ax.plot(
+                     times_utc[i],
+                     y_hat[i, :, nq],
+                     color=cm(q),
+                     label=r"$\hat{y}$" + f"({q})",
+                     alpha=0.7,
+                 )
+
+         ax.set_title(f"{times_utc[i][0].date()}", fontsize="small")
+
+         xticks = [t for t in times_utc[i] if t.minute == 0][::2]
+         ax.set_xticks(ticks=xticks, labels=[f"{t.hour:02}" for t in xticks], rotation=90)
+         ax.grid()
+
+     axes[0, 0].legend(loc="best")
+
+     if batch_size < 16:
+         for ax in axes.ravel()[batch_size:]:
+             ax.axis("off")
+
+     for ax in axes[-1, :]:
+         ax.set_xlabel("Time (hour of day)")
+
+     plt.tight_layout()
+
+     return fig
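
As a rough sketch of how `wandb_line_plot` is consumed by the lightning module's `on_validation_epoch_end`, the snippet below builds the same kind of horizon-MAE chart from made-up values; the 16-step horizon and the MAE numbers are assumptions for the example only:

import numpy as np

from pvnet_summation.training.plots import wandb_line_plot

# Made-up per-step MAEs over an assumed 16-step horizon
val_horizon_maes = np.linspace(0.02, 0.06, num=16)

chart = wandb_line_plot(
    x=np.arange(len(val_horizon_maes)),
    y=val_horizon_maes,
    xlabel="Horizon step",
    ylabel="MAE",
    title="Val horizon loss curve",
)
# Within an active wandb run the chart is then logged,
# e.g. wandb.log({"val_horizon_mae_plot": chart})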