gridfm-graphkit 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gridfm_graphkit/__init__.py +0 -0
- gridfm_graphkit/__main__.py +62 -0
- gridfm_graphkit/cli.py +530 -0
- gridfm_graphkit/datasets/__init__.py +0 -0
- gridfm_graphkit/datasets/data_normalization.py +227 -0
- gridfm_graphkit/datasets/globals.py +19 -0
- gridfm_graphkit/datasets/powergrid.py +192 -0
- gridfm_graphkit/datasets/transforms.py +223 -0
- gridfm_graphkit/datasets/utils.py +65 -0
- gridfm_graphkit/io/__init__.py +0 -0
- gridfm_graphkit/io/param_handler.py +293 -0
- gridfm_graphkit/models/__init__.py +0 -0
- gridfm_graphkit/models/gps_transformer.py +143 -0
- gridfm_graphkit/models/graphTransformer.py +96 -0
- gridfm_graphkit/training/__init__.py +0 -0
- gridfm_graphkit/training/callbacks.py +47 -0
- gridfm_graphkit/training/plugins.py +218 -0
- gridfm_graphkit/training/trainer.py +156 -0
- gridfm_graphkit/utils/__init__.py +0 -0
- gridfm_graphkit/utils/loss.py +198 -0
- gridfm_graphkit/utils/visualization.py +324 -0
- gridfm_graphkit-0.0.1.dist-info/METADATA +163 -0
- gridfm_graphkit-0.0.1.dist-info/RECORD +27 -0
- gridfm_graphkit-0.0.1.dist-info/WHEEL +5 -0
- gridfm_graphkit-0.0.1.dist-info/entry_points.txt +2 -0
- gridfm_graphkit-0.0.1.dist-info/licenses/LICENSE +201 -0
- gridfm_graphkit-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from typing import Dict, Optional
|
|
3
|
+
import mlflow
|
|
4
|
+
import os
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
from torch.optim import Optimizer
|
|
8
|
+
from torch.optim.lr_scheduler import LRScheduler
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TrainerPlugin:
    """
    Base class for training plugins.

    A `TrainerPlugin` is invoked during the training process either at regular step intervals,
    at the end of each epoch, or both. It can be extended to perform actions like logging,
    checkpointing, or validation.

    Subclasses override :meth:`step`. Note that ``@abstractmethod`` is only
    enforced for classes whose metaclass is ``abc.ABCMeta``; this class does not
    derive from ``abc.ABC``, so the decorator documents intent but does not
    prevent instantiating ``TrainerPlugin`` directly.

    Args:
        steps (int, optional): Interval (in steps) to run the plugin. If `None`,
            only runs at end of epoch.
    """

    def __init__(self, steps: Optional[int] = None):
        self.steps = steps

    def run(self, step: int, end_of_epoch: bool) -> bool:
        """
        Determines whether to execute the plugin at the current step.

        Args:
            step (int): The current step number.
            end_of_epoch (bool): Whether this is the end of the epoch.

        Returns:
            bool: True if the plugin should run; False otherwise.
        """
        # Epoch ends always trigger the plugin.
        if end_of_epoch:
            return True
        # Without an interval, the plugin only runs at epoch ends.
        if self.steps is None:
            return False
        # Fire on steps-1, 2*steps-1, ... (i.e. after every `steps` steps when
        # counting from 0). Step 0 never fires; that only matters for steps == 1.
        return step != 0 and (step + 1) % self.steps == 0

    @abstractmethod
    def step(
        self,
        epoch: int,
        step: int,
        metrics: Optional[Dict] = None,
        end_of_epoch: bool = False,
        **kwargs,
    ):
        """
        Hook called by the training loop. Implementations can use the passed-in
        parameters for validation, checkpointing, logging, etc.

        Args:
            epoch (int): The current epoch number.
            step (int): The current step number. At the end of an epoch this is
                the last step reached in that epoch.
            metrics (dict, optional): Dictionary of metrics (e.g., loss). ``None``
                is treated as an empty dict.
            end_of_epoch (bool): Indicates if this call is at the end of an epoch.
            **kwargs (Any): Additional parameters such as model, optimizer, scheduler.
        """
        pass
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class MLflowLoggerPlugin(TrainerPlugin):
    """
    Plugin to log training metrics to MLflow.

    Every metric value passed to :meth:`step` is buffered per metric name. At
    the end of each epoch, the mean of each metric's buffered values is logged
    to MLflow with ``step=epoch``, and the buffer is cleared. Optional
    parameters are logged once, when the plugin is constructed.

    Note: ``steps`` is forwarded to ``TrainerPlugin`` for API compatibility,
    but this plugin does not log at intermediate step intervals.

    Args:
        steps (int, optional): Interval in steps (stored on the base class).
        params (dict, optional): Parameters to log to MLflow at construction.
    """

    def __init__(self, steps: Optional[int] = None, params: Optional[dict] = None):
        super().__init__(steps=steps)
        # metric name -> list of values received since the last epoch end
        self.metrics_history = {}
        if params:
            # Log parameters to MLflow at the beginning of training
            mlflow.log_params(params)

    def step(
        self,
        epoch: int,
        step: int,
        metrics: Optional[Dict] = None,
        end_of_epoch: bool = False,
        **kwargs,
    ):
        """
        Buffer ``metrics`` and, at the end of an epoch, log their per-epoch means.

        Args:
            epoch (int): The current epoch number (used as the MLflow step).
            step (int): The current step number (unused).
            metrics (Dict, optional): Metrics to buffer, e.g. {'Training Loss': value}.
            end_of_epoch (bool): Flag indicating whether this is the end of the epoch.
        """
        for metric_name, metric_value in (metrics or {}).items():
            self.metrics_history.setdefault(metric_name, []).append(metric_value)

        if not end_of_epoch:
            return

        for metric_name, values in self.metrics_history.items():
            if values:  # avoid division by zero
                mlflow.log_metric(str(metric_name), sum(values) / len(values), step=epoch)

        # Clear metrics for the next epoch
        self.metrics_history = {}
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class CheckpointerPlugin(TrainerPlugin):
    """
    Plugin to periodically save model checkpoints.

    Stores the model, optimizer, and scheduler states in ``checkpoint_dir``
    whenever :meth:`TrainerPlugin.run` reports that the plugin should fire.
    The file name is fixed (``checkpoint_last_epoch.pth``), so each save
    replaces the previous one.

    Args:
        checkpoint_dir (str): Directory where checkpoints will be saved; it is
            created if it does not exist.
        steps (int, optional): Interval in steps for checkpointing.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        steps: Optional[int] = None,
    ):
        super().__init__(steps=steps)
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def step(
        self,
        epoch: int,
        step: int,
        metrics: Optional[Dict] = None,
        end_of_epoch: bool = False,
        model: Optional[nn.Module] = None,
        optimizer: Optional[Optimizer] = None,
        scheduler: Optional[LRScheduler] = None,
        **kwargs,
    ):
        """
        Saves a checkpoint if the conditions to run the plugin are met.

        Calls that do not provide a ``model`` (e.g. per-batch calls made without
        the training objects) are ignored. Saving in that case would replace the
        last useful checkpoint with one whose ``model_state_dict`` is ``None``.

        Args:
            epoch (int): Current epoch number.
            step (int): Current training step.
            metrics (dict, optional): Metrics dictionary (unused here).
            end_of_epoch (bool): Whether this is the end of the epoch.
            model (nn.Module, optional): Model to be checkpointed.
            optimizer (Optimizer, optional): Optimizer to save.
            scheduler (LRScheduler, optional): Scheduler to save.
            **kwargs: Ignored; accepted for signature compatibility with TrainerPlugin.
        """
        # Check if we should save at this step or end of epoch
        if not self.run(step, end_of_epoch):
            return
        if model is None:
            return

        # `is not None` rather than truthiness: containers such as an empty
        # nn.Sequential define __len__ and would evaluate as False.
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
            "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
        }

        checkpoint_path = os.path.join(
            self.checkpoint_dir,
            "checkpoint_last_epoch.pth",
        )
        torch.save(checkpoint, checkpoint_path)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class MetricsTrackerPlugin(TrainerPlugin):
    """
    Tracks the per-epoch mean validation loss.

    Metric values passed to :meth:`step` are buffered per metric name. At the
    end of each epoch, the mean of the buffered ``"Validation Loss"`` values is
    appended to ``validation_losses``. The buffer is then cleared, so each entry
    reflects a single epoch only.
    """

    def __init__(self):
        super().__init__()
        self.validation_losses = []  # one mean validation loss per completed epoch
        self.metrics_history = {}  # metric name -> values since the last epoch end

    def step(
        self,
        epoch: int,
        step: int,
        metrics: Optional[Dict] = None,
        end_of_epoch: bool = False,
        **kwargs,
    ):
        """
        Buffer ``metrics``; at the end of an epoch, record the mean validation loss.

        Args:
            epoch (int): The current epoch number (unused).
            step (int): The current step number (unused).
            metrics (dict, optional): Metrics to buffer. ``None`` means no metrics.
            end_of_epoch (bool): Whether this call marks the end of an epoch.
        """
        for metric_name, metric_value in (metrics or {}).items():
            self.metrics_history.setdefault(metric_name, []).append(metric_value)

        if not end_of_epoch:
            return

        values = self.metrics_history.get("Validation Loss")
        if values:  # avoid division by zero
            self.validation_losses.append(sum(values) / len(values))

        # Reset so the next epoch's average is not mixed with earlier epochs.
        self.metrics_history = {}

    def get_losses(self):
        """Return the list of per-epoch mean validation losses."""
        return self.validation_losses
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
from typing import List, Optional

import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

from gridfm_graphkit.training.callbacks import EarlyStopper
from gridfm_graphkit.training.plugins import TrainerPlugin
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Trainer:
    """
    A training loop for GridFM models with per-epoch validation, optional
    learning rate scheduling, and plugin callbacks for logging or custom behavior.

    Attributes:
        model (nn.Module): The PyTorch model to train. It is expected to expose a
            ``mask_value`` tensor (expanded to ``(num_rows, -1)`` when masking)
            and a boolean ``learn_mask`` attribute.
        optimizer (Optimizer): The optimizer used for updating model parameters.
        device: The device to train on (CPU or CUDA).
        loss_fn (nn.Module): Loss function returning a dict with at least a
            ``"loss"`` tensor. During validation, its other entries are forwarded
            to the plugins.
        early_stopper (EarlyStopper): Callback for early stopping based on validation loss.
        train_dataloader (DataLoader): Dataloader for training data.
        val_dataloader (DataLoader): Dataloader for validation data. It is required:
            its mean batch loss drives the scheduler and early stopping.
        lr_scheduler (optional): Learning rate scheduler; ``step`` is called once
            per epoch with the mean validation loss.
        plugins (List[TrainerPlugin]): List of plugin callbacks.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        device,
        loss_fn: nn.Module,
        early_stopper: EarlyStopper,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        lr_scheduler=None,
        plugins: Optional[List[TrainerPlugin]] = None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.early_stopper = early_stopper
        self.loss_fn = loss_fn
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.lr_scheduler = lr_scheduler
        # Fresh list per instance: a shared mutable default would leak plugins
        # between Trainer objects if one of them appended to `self.plugins`.
        self.plugins = plugins if plugins is not None else []

    def __one_step(
        self,
        input: torch.Tensor,
        edge_index: torch.Tensor,
        label: torch.Tensor,
        edge_attr: torch.Tensor,
        mask: torch.Tensor = None,
        batch: torch.Tensor = None,
        pe: torch.Tensor = None,
        val: bool = False,
    ):
        """
        Run one forward pass and, unless ``val`` is True, one optimizer update.

        Returns:
            dict: The loss dictionary produced by ``loss_fn``.
        """
        if mask is not None:
            # Expand the learnable mask value to the number of input rows.
            mask_value_expanded = self.model.mask_value.expand(input.shape[0], -1)
            # Write the mask value into `input` (in place) wherever `mask` is set.
            # Only the leading `mask.shape[1]` feature columns are covered. Values
            # from a previous call are overwritten, which is fine as long as the
            # masked features do not change between batches.
            input[:, : mask.shape[1]][mask] = mask_value_expanded[mask]
        output = self.model(input, pe, edge_index, edge_attr, batch)

        loss_dict = self.loss_fn(output, label, edge_index, edge_attr, mask)

        if not val:
            self.optimizer.zero_grad()
            loss_dict["loss"].backward()
            self.optimizer.step()

        return loss_dict

    def __train_epoch(self, epoch: int, prev_step: int) -> int:
        """Run one pass over the training data and return the last step number used."""
        self.model.train()

        highest_step = prev_step
        for batch_idx, batch in enumerate(self.train_dataloader):
            step = prev_step + batch_idx + 1
            highest_step = step
            batch = batch.to(self.device)
            mask = getattr(batch, "mask", None)

            loss_dict = self.__one_step(
                batch.x,
                batch.edge_index,
                batch.y,
                batch.edge_attr,
                mask,
                batch.batch,
                batch.pe,
            )

            metrics = {
                "Training Loss": loss_dict["loss"].item(),
                "Learning Rate": self.optimizer.param_groups[0]["lr"],
            }
            if self.model.learn_mask:
                metrics["Mask Gradient Norm"] = self.model.mask_value.grad.norm().item()

            for plugin in self.plugins:
                plugin.step(epoch, step, metrics=metrics)

        return highest_step

    def __validate(self, epoch: int, step: int) -> float:
        """Evaluate on the validation data and return the mean per-batch loss."""
        self.model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in self.val_dataloader:
                batch = batch.to(self.device)
                mask = getattr(batch, "mask", None)
                val_metrics = self.__one_step(
                    batch.x,
                    batch.edge_index,
                    batch.y,
                    batch.edge_attr,
                    mask,
                    batch.batch,
                    batch.pe,
                    True,
                )
                # Report the scalar loss under a validation-specific name; the
                # remaining loss-dict entries are forwarded unchanged.
                batch_loss = val_metrics.pop("loss").item()
                val_loss += batch_loss
                val_metrics["Validation Loss"] = batch_loss

                for plugin in self.plugins:
                    plugin.step(epoch, step, metrics=val_metrics)
        return val_loss / len(self.val_dataloader)

    def __one_epoch(self, epoch: int, prev_step: int):
        """Train, validate, step the scheduler and notify plugins for one epoch."""
        highest_step = self.__train_epoch(epoch, prev_step)
        val_loss = self.__validate(epoch, highest_step)

        if self.lr_scheduler is not None:
            self.lr_scheduler.step(val_loss)

        for plugin in self.plugins:
            plugin.step(
                epoch,
                step=highest_step,
                end_of_epoch=True,
                model=self.model,
                optimizer=self.optimizer,
                scheduler=self.lr_scheduler,
            )
        return val_loss

    def train(self, start_epoch: int = 0, epochs: int = 1, prev_step: int = -1):
        """
        Main training loop.

        Args:
            start_epoch (int): Epoch to start training from.
            epochs (int): Total number of epochs to train.
            prev_step (int): Previous training step (for logging continuity).
        """
        for epoch in tqdm(range(start_epoch, start_epoch + epochs), desc="Epochs"):
            # NOTE(review): prev_step is not advanced between epochs, so step
            # numbering restarts every epoch; confirm whether global step
            # continuity across epochs was intended.
            val_loss = self.__one_epoch(epoch, prev_step)
            if self.early_stopper.early_stop(val_loss, self.model):
                break
|
|
File without changes
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
from gridfm_graphkit.datasets.globals import PD, QD, PG, QG, VM, VA, G, B
|
|
2
|
+
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
import torch
|
|
5
|
+
from torch_geometric.utils import to_torch_coo_tensor
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MaskedMSELoss(nn.Module):
    """
    Mean Squared Error restricted to the entries selected by ``mask``.

    Args:
        reduction (str): Reduction mode forwarded to ``F.mse_loss``.
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None):
        """
        Return ``{"loss": tensor, "Masked MSE loss": float}``.

        ``edge_index`` and ``edge_attr`` are accepted for interface
        compatibility with the other losses and are ignored.
        """
        selected_pred = pred[mask]
        selected_target = target[mask]
        masked_mse = F.mse_loss(selected_pred, selected_target, reduction=self.reduction)
        result = {"loss": masked_mse}
        result["Masked MSE loss"] = masked_mse.item()
        return result
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MSELoss(nn.Module):
    """
    Standard (unmasked) Mean Squared Error loss.

    Args:
        reduction (str): Reduction mode forwarded to ``F.mse_loss``.
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None):
        """
        Return ``{"loss": tensor, "MSE loss": float}`` over all entries.

        The graph arguments and ``mask`` are accepted for interface
        compatibility and are ignored.
        """
        mse = F.mse_loss(pred, target, reduction=self.reduction)
        result = {"loss": mse}
        result["MSE loss"] = mse.item()
        return result
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SCELoss(nn.Module):
    """
    Scaled Cosine Error loss with optional masking.

    Computes ``mean((1 - cos(pred, target)) ** alpha)``, with the cosine
    taken along the last dimension after L2 normalisation.

    If ``mask`` is given, it is applied with boolean indexing first. When the
    mask has the same shape as ``pred``, that indexing yields a 1-D tensor, so
    a single cosine is computed over all selected entries rather than one per
    row.

    Args:
        alpha (float): Exponent applied to the cosine error.
    """

    def __init__(self, alpha=3):
        super().__init__()
        self.alpha = alpha

    def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None):
        """Return ``{"loss": tensor, "SCE loss": float}``; graph args are ignored."""
        if mask is not None:
            pred, target = pred[mask], target[mask]

        pred_unit = F.normalize(pred, p=2, dim=-1)
        target_unit = F.normalize(target, p=2, dim=-1)
        cosine = (pred_unit * target_unit).sum(dim=-1)
        loss = (1 - cosine).pow(self.alpha).mean()

        result = {"loss": loss}
        result["SCE loss"] = loss.item()
        return result
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class PBELoss(nn.Module):
    """
    Loss based on the Power Balance Equations.

    Entries of ``pred`` outside ``mask`` are replaced by the corresponding
    ``target`` values. The resulting voltage magnitudes and angles give the
    nodal complex power injection ``S_inj = V * conj(Y_bus @ V)``. ``Y_bus``
    is a sparse admittance matrix built from ``edge_index`` and the ``G``/``B``
    edge attributes. The loss is the mean absolute complex residual between
    the net power balance ``(PG - PD) + j (QG - QD)`` and ``S_inj``.

    Args:
        visualization (bool): If True, the result also contains the per-node
            absolute active and reactive residual tensors.
    """

    def __init__(self, visualization=False):
        super(PBELoss, self).__init__()

        self.visualization = visualization

    def forward(self, pred, target, edge_index, edge_attr, mask):
        # Work on a copy so the caller's prediction tensor is not modified.
        temp_pred = pred.clone()

        # If a value is not masked, then use the original one
        unmasked = ~mask
        temp_pred[unmasked] = target[unmasked]

        # Complex voltage vector from magnitudes and angles
        V = temp_pred[:, VM] * torch.exp(1j * temp_pred[:, VA])

        # Sparse complex admittance matrix Y_bus
        edge_complex = edge_attr[:, G] + 1j * edge_attr[:, B]
        num_nodes = target.size(0)
        Y_bus_sparse = to_torch_coo_tensor(
            edge_index,
            edge_complex,
            size=(num_nodes, num_nodes),
        )

        # diag(V) @ conj(Y) @ conj(V) == V * conj(Y @ V). The element-wise form
        # keeps Y sparse and avoids materialising a dense N x N diagonal matrix.
        Y_times_V = (Y_bus_sparse @ V.unsqueeze(-1)).squeeze(-1)
        S_injection = V * torch.conj(Y_times_V)

        # Net power balance per node
        net_P = temp_pred[:, PG] - temp_pred[:, PD]
        net_Q = temp_pred[:, QG] - temp_pred[:, QD]
        S_net_power_balance = net_P + 1j * net_Q

        residual = S_net_power_balance - S_injection
        nodal_active = torch.abs(torch.real(residual))
        nodal_reactive = torch.abs(torch.imag(residual))

        # Mean of absolute complex power mismatch
        loss = torch.mean(torch.abs(residual))
        real_loss_power = torch.mean(nodal_active)
        imag_loss_power = torch.mean(nodal_reactive)

        result = {
            "loss": loss,
            "Power power loss in p.u.": loss.item(),
            "Active Power Loss in p.u.": real_loss_power.item(),
            "Reactive Power Loss in p.u.": imag_loss_power.item(),
        }
        if self.visualization:
            result["Nodal Active Power Loss in p.u."] = nodal_active
            result["Nodal Reactive Power Loss in p.u."] = nodal_reactive
        return result
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class MixedLoss(nn.Module):
    """
    Weighted sum of several loss modules.

    Each loss module is called with the same arguments and must return a dict
    that contains a ``"loss"`` tensor. The result merges every module's
    remaining entries (later modules win on key clashes) and stores the
    weighted total under ``"loss"``.

    Args:
        loss_functions (list[nn.Module]): List of loss functions.
        weights (list[float]): Corresponding weights for each loss function.

    Raises:
        ValueError: If the two lists have different lengths.
    """

    def __init__(self, loss_functions, weights):
        super().__init__()

        if len(weights) != len(loss_functions):
            raise ValueError(
                "The number of loss functions must match the number of weights.",
            )

        self.loss_functions = nn.ModuleList(loss_functions)
        self.weights = weights

    def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None):
        """
        Compute the weighted sum of all configured losses.

        Parameters:
            - pred: Predictions.
            - target: Ground truth.
            - edge_index: Optional edge index for graph-based losses.
            - edge_attr: Optional edge attributes for graph-based losses.
            - mask: Optional mask forwarded to every loss.

        Returns:
            - A dictionary with the total loss under ``"loss"`` plus every
              non-``"loss"`` entry reported by the individual losses.
        """
        merged = {}
        total_loss = 0.0

        for idx, component in enumerate(self.loss_functions):
            outputs = component(
                pred,
                target,
                edge_index=edge_index,
                edge_attr=edge_attr,
                mask=mask,
            )
            # Remove the tensor so only the reporting entries get merged below.
            total_loss += self.weights[idx] * outputs.pop("loss")
            merged.update(outputs)

        merged["loss"] = total_loss
        return merged
|