gridfm-graphkit 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,218 @@
1
+ from abc import abstractmethod
2
+ from typing import Dict, Optional
3
+ import mlflow
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.optim import Optimizer
8
+ from torch.optim.lr_scheduler import LRScheduler
9
+
10
+
11
class TrainerPlugin:
    """
    Base class for training plugins.

    A `TrainerPlugin` is invoked during the training process either at regular step intervals,
    at the end of each epoch, or both. It can be extended to perform actions like logging,
    checkpointing, or validation.

    Note: `step` is decorated with `@abstractmethod`, but because this class does not use
    `ABCMeta`, the decorator is not enforced when a subclass is instantiated.

    Args:
        steps (int, optional): Interval (in steps) to run the plugin. If `None` (or 0), only
            runs at end of epoch.
    """

    def __init__(self, steps: Optional[int] = None):
        self.steps = steps

    def run(self, step: int, end_of_epoch: bool) -> bool:
        """
        Determines whether to execute the plugin at the current step.

        With `steps = N`, the plugin fires at the zero-based steps N-1, 2N-1, ...
        (i.e. after every N-th step). It always fires at the end of an epoch.

        Args:
            step (int): The current (zero-based) step number.
            end_of_epoch (bool): Whether this is the end of the epoch.

        Returns:
            bool: True if the plugin should run; False otherwise.
        """
        # By default we always run for epoch ends.
        if end_of_epoch:
            return True
        # No (or a zero) interval means we only run at epoch ends; this also avoids a
        # ZeroDivisionError for steps == 0.
        if not self.steps:
            return False
        # Fire after every `steps`-th step. No special case for step 0 is needed: for
        # steps > 1, (0 + 1) % steps != 0 already, and for steps == 1 every step qualifies.
        return (step + 1) % self.steps == 0

    @abstractmethod
    def step(
        self,
        epoch: int,
        step: int,
        metrics: Optional[Dict] = None,
        end_of_epoch: bool = False,
        **kwargs,
    ):
        """
        Hook called by the trainer after training steps, after validation batches, and once
        more at the end of each epoch (with `end_of_epoch=True` and `step` set to the last
        training step of the epoch). Implementations can use the passed in parameters for
        validation, checkpointing, logging, etc.

        Args:
            epoch (int): The current epoch number.
            step (int): The current step number.
            metrics (dict, optional): Dictionary of metrics (e.g., loss); `None` means no metrics.
            end_of_epoch (bool): Indicates if this call is at the end of an epoch.
            **kwargs (Any): Additional parameters such as model, optimizer, scheduler.
        """
        pass
70
+
71
+
72
class MLflowLoggerPlugin(TrainerPlugin):
    """
    Plugin to log training metrics to MLflow.

    Every metric value passed to `step` is buffered. At the end of each epoch, the mean of
    each buffered metric is logged to MLflow with `step=epoch`, and the buffer is cleared.
    Optional parameters are logged once, at construction time.

    Args:
        steps (int, optional): Stored for interface compatibility with `TrainerPlugin`;
            this plugin currently logs only at epoch ends, not at step intervals.
        params (dict, optional): Parameters to log to MLflow at the start.
    """

    def __init__(self, steps: Optional[int] = None, params: Optional[dict] = None):
        super().__init__(steps=steps)  # sets self.steps
        # Metric name -> list of values received during the current epoch.
        self.metrics_history = {}
        if params:
            # Log parameters to MLflow at the beginning of training
            mlflow.log_params(params)

    def step(
        self,
        epoch: int,
        step: int,
        metrics: Optional[Dict] = None,
        end_of_epoch: bool = False,
        **kwargs,
    ):
        """
        Buffer the given metrics; at the end of an epoch, log their means to MLflow.

        Args:
            epoch (int): The current epoch number (used as the MLflow step).
            step (int): The current step number (unused).
            metrics (Dict, optional): Metrics to buffer, e.g. {'Training Loss': value}.
            end_of_epoch (bool): If True, log the per-epoch mean of every buffered metric
                and reset the buffer.
        """
        for metric_name, metric_value in (metrics or {}).items():
            self.metrics_history.setdefault(metric_name, []).append(metric_value)

        if end_of_epoch:
            for metric_name, values in self.metrics_history.items():
                if values:  # Avoid division by zero or empty lists
                    avg_value = sum(values) / len(values)
                    mlflow.log_metric(metric_name, avg_value, step=epoch)

            # Clear metrics so the next epoch's means only include that epoch's values.
            self.metrics_history = {}
123
+
124
+
125
class CheckpointerPlugin(TrainerPlugin):
    """
    Plugin to periodically save model checkpoints.

    Whenever `run` allows it (every `steps` steps and/or at the end of each epoch), stores
    the epoch number and the model, optimizer, and scheduler state dicts in
    `<checkpoint_dir>/checkpoint_last_epoch.pth`. Each save overwrites the previous file.

    Args:
        checkpoint_dir (str): Directory where checkpoints will be saved (created if missing).
        steps (int, optional): Interval in steps for checkpointing.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        steps: Optional[int] = None,
    ):
        super().__init__(steps=steps)
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def step(
        self,
        epoch: int,
        step: int,
        metrics: Optional[Dict] = None,
        end_of_epoch: bool = False,
        model: Optional[nn.Module] = None,
        optimizer: Optional[Optimizer] = None,
        scheduler: Optional[LRScheduler] = None,
        **kwargs,
    ):
        """
        Saves a checkpoint if the conditions to run the plugin are met.

        Args:
            epoch (int): Current epoch number.
            step (int): Current training step.
            metrics (dict, optional): Metrics dictionary (unused here).
            end_of_epoch (bool): Whether this is the end of the epoch.
            model (nn.Module, optional): Model to be checkpointed.
            optimizer (Optimizer, optional): Optimizer to save.
            scheduler (LRScheduler, optional): Scheduler to save.
            **kwargs: Ignored; accepted for signature parity with `TrainerPlugin.step`.
        """
        # Check if we should save at this step or end of epoch
        if not self.run(step, end_of_epoch):
            return

        # Use `is not None`: module containers such as an empty nn.Sequential define
        # __len__ and would be falsy under a plain truthiness check.
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict() if model is not None else None,
            "optimizer_state_dict": (
                optimizer.state_dict() if optimizer is not None else None
            ),
            "scheduler_state_dict": (
                scheduler.state_dict() if scheduler is not None else None
            ),
        }

        checkpoint_path = os.path.join(
            self.checkpoint_dir,
            "checkpoint_last_epoch.pth",
        )
        # Each save overwrites the only checkpoint, so write to a temporary file first and
        # swap it in with os.replace. An interrupted save then cannot corrupt the previous file.
        tmp_path = checkpoint_path + ".tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, checkpoint_path)
184
+
185
+
186
class MetricsTrackerPlugin(TrainerPlugin):
    """
    Tracks the per-epoch mean validation loss.

    Metric values passed to `step` are buffered in `metrics_history` during an epoch. At the
    end of the epoch, the mean of the buffered "Validation Loss" values (if any) is appended
    to `validation_losses`, and the buffer is cleared. As a result, each recorded value covers
    only that epoch's validation batches.
    """

    def __init__(self):
        super().__init__()
        # One mean validation loss per completed epoch, in epoch order.
        self.validation_losses = []
        # Metric name -> list of values received during the current epoch.
        self.metrics_history = {}

    def step(
        self,
        epoch: int,
        step: int,
        metrics: Optional[Dict] = None,
        end_of_epoch: bool = False,
        **kwargs,
    ):
        """
        Buffer the given metrics; at the end of an epoch, record the mean validation loss.

        Args:
            epoch (int): The current epoch number (unused).
            step (int): The current step number (unused).
            metrics (Dict, optional): Metrics to buffer.
            end_of_epoch (bool): If True, record the epoch's mean "Validation Loss" and reset
                the buffer.
        """
        for metric_name, metric_value in (metrics or {}).items():
            self.metrics_history.setdefault(metric_name, []).append(metric_value)

        if end_of_epoch:
            values = self.metrics_history.get("Validation Loss")
            if values:  # Avoid division by zero or empty lists
                self.validation_losses.append(sum(values) / len(values))
            # Without this reset, each epoch's "mean" would include all previous epochs.
            self.metrics_history = {}

    def get_losses(self):
        """Return the list of per-epoch mean validation losses recorded so far."""
        return self.validation_losses
@@ -0,0 +1,156 @@
1
+ from gridfm_graphkit.training.plugins import TrainerPlugin
2
+ from gridfm_graphkit.training.callbacks import EarlyStopper
3
+
4
+ from typing import List
5
+ import torch
6
+ from torch import nn
7
+ from torch.optim import Optimizer
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+
11
+
12
class Trainer:
    """
    A flexible training loop for GridFM models with validation, learning rate scheduling,
    early stopping, and plugin callbacks for logging or custom behavior.

    Attributes:
        model (nn.Module): The PyTorch model to train. The loop reads its `mask_value`
            (used to fill masked input entries) and `learn_mask` attributes.
        optimizer (Optimizer): The optimizer used for updating model parameters.
        device: The device batches are moved to (CPU or CUDA).
        loss_fn (nn.Module): Loss function returning a dict with at least a "loss" tensor.
        early_stopper (EarlyStopper): Called after every epoch with the mean validation loss;
            training stops when it returns True.
        train_dataloader (DataLoader): Dataloader for training data.
        val_dataloader (DataLoader): Dataloader for validation data. It must yield at least one
            batch, because the epoch's validation loss is divided by its length.
        lr_scheduler (optional): Learning rate scheduler; its `step` is called with the mean
            validation loss once per epoch.
        plugins (List[TrainerPlugin]): Plugin callbacks notified after every training step,
            after every validation batch, and at the end of each epoch.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        device,
        loss_fn: nn.Module,
        early_stopper: EarlyStopper,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        lr_scheduler=None,
        plugins: List[TrainerPlugin] = None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.early_stopper = early_stopper
        self.loss_fn = loss_fn
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.lr_scheduler = lr_scheduler
        # Fresh list per instance; a shared `[]` default would leak plugins between trainers.
        self.plugins = plugins if plugins is not None else []

    def __one_step(
        self,
        input: torch.Tensor,
        edge_index: torch.Tensor,
        label: torch.Tensor,
        edge_attr: torch.Tensor,
        mask: torch.Tensor = None,
        batch: torch.Tensor = None,
        pe: torch.Tensor = None,
        val: bool = False,
    ):
        """
        Run a forward pass and the loss. Unless `val` is True, also run backward and an
        optimizer step.

        If `mask` is given, the masked entries of `input` are overwritten in place with the
        model's `mask_value` before the forward pass.

        Returns:
            dict: The dictionary returned by `loss_fn`.
        """
        if mask is not None:
            # Expand the learnable mask to one row per input row.
            mask_value_expanded = self.model.mask_value.expand(input.shape[0], -1)
            # Write the mask values through a view of the first mask.shape[1] feature columns.
            # This overwrites the last mask values, which is fine as long as the features
            # which are masked do not change between batches.
            input[:, : mask.shape[1]][mask] = mask_value_expanded[mask]
        output = self.model(input, pe, edge_index, edge_attr, batch)

        loss_dict = self.loss_fn(output, label, edge_index, edge_attr, mask)

        if not val:
            self.optimizer.zero_grad()
            loss_dict["loss"].backward()
            self.optimizer.step()

        return loss_dict

    def __one_epoch(self, epoch: int, prev_step: int):
        """
        Train for one epoch, then validate, step the LR scheduler, and notify plugins.

        Args:
            epoch (int): The current epoch number.
            prev_step (int): Global index of the last step taken before this epoch.

        Returns:
            tuple: `(val_loss, last_step)`: the mean validation loss, and the global index
            of the last training step of this epoch (`prev_step` if no batch was seen).
        """
        self.model.train()

        last_step = prev_step
        for offset, batch in enumerate(self.train_dataloader):
            last_step = prev_step + offset + 1
            batch = batch.to(self.device)

            mask = getattr(batch, "mask", None)

            loss_dict = self.__one_step(
                batch.x,
                batch.edge_index,
                batch.y,
                batch.edge_attr,
                mask,
                batch.batch,
                batch.pe,
            )
            metrics = {
                "Training Loss": loss_dict["loss"].item(),
                "Learning Rate": self.optimizer.param_groups[0]["lr"],
            }

            # The mask gradient only exists if the mask was actually used in this step.
            if self.model.learn_mask and self.model.mask_value.grad is not None:
                metrics["Mask Gradient Norm"] = self.model.mask_value.grad.norm().item()

            for plugin in self.plugins:
                plugin.step(epoch, last_step, metrics=metrics)

        self.model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in self.val_dataloader:
                batch = batch.to(self.device)
                mask = getattr(batch, "mask", None)
                metrics = self.__one_step(
                    batch.x,
                    batch.edge_index,
                    batch.y,
                    batch.edge_attr,
                    mask,
                    batch.batch,
                    batch.pe,
                    True,
                )
                batch_val_loss = metrics.pop("loss").item()
                val_loss += batch_val_loss
                metrics["Validation Loss"] = batch_val_loss

                for plugin in self.plugins:
                    plugin.step(epoch, last_step, metrics=metrics)
        val_loss /= len(self.val_dataloader)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(val_loss)
        for plugin in self.plugins:
            plugin.step(
                epoch,
                step=last_step,
                end_of_epoch=True,
                model=self.model,
                optimizer=self.optimizer,
                scheduler=self.lr_scheduler,
            )
        return val_loss, last_step

    def train(self, start_epoch: int = 0, epochs: int = 1, prev_step: int = -1):
        """
        Main training loop.

        Args:
            start_epoch (int): Epoch to start training from.
            epochs (int): Number of epochs to run, starting at `start_epoch`.
            prev_step (int): Global index of the last step already taken (for logging and
                checkpoint continuity). The default -1 makes the first step 0.
        """
        for epoch in tqdm(range(start_epoch, start_epoch + epochs), desc="Epochs"):
            # Carry the global step forward so step numbers keep increasing across epochs
            # instead of restarting at `prev_step + 1` every epoch.
            val_loss, prev_step = self.__one_epoch(epoch, prev_step)
            if self.early_stopper.early_stop(val_loss, self.model):
                break
File without changes
@@ -0,0 +1,198 @@
1
+ from gridfm_graphkit.datasets.globals import PD, QD, PG, QG, VM, VA, G, B
2
+
3
+ import torch.nn.functional as F
4
+ import torch
5
+ from torch_geometric.utils import to_torch_coo_tensor
6
+ import torch.nn as nn
7
+
8
+
9
class MaskedMSELoss(nn.Module):
    """
    Mean Squared Error restricted to the entries selected by a boolean mask.

    Args:
        reduction (str): Reduction mode forwarded to `F.mse_loss` (default "mean").
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None):
        """
        Compare `pred[mask]` with `target[mask]`.

        `edge_index` and `edge_attr` are accepted only so that all losses share one call
        signature; they are ignored here.

        Returns:
            dict: "loss" maps to the loss tensor; "Masked MSE loss" maps to its `.item()` value.
        """
        selected_pred = pred[mask]
        selected_target = target[mask]
        masked_mse = F.mse_loss(selected_pred, selected_target, reduction=self.reduction)
        return {"loss": masked_mse, "Masked MSE loss": masked_mse.item()}
21
+
22
+
23
class MSELoss(nn.Module):
    """
    Plain Mean Squared Error over every entry (graph arguments and mask are ignored).

    Args:
        reduction (str): Reduction mode forwarded to `F.mse_loss` (default "mean").
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None):
        """
        Return {"loss": tensor, "MSE loss": float} for the full `pred`/`target` pair.
        """
        mse_value = F.mse_loss(pred, target, reduction=self.reduction)
        result = {"loss": mse_value}
        result["MSE loss"] = mse_value.item()
        return result
33
+
34
+
35
class SCELoss(nn.Module):
    """
    Scaled Cosine Error: mean of `(1 - cos(pred, target)) ** alpha` along the last dim.

    Both tensors are L2-normalised along the last dimension first; when `mask` is given,
    only `pred[mask]` and `target[mask]` are used.

    Args:
        alpha (int or float): Exponent applied to the cosine error (default 3).
    """

    def __init__(self, alpha=3):
        super().__init__()
        self.alpha = alpha

    def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None):
        """
        Return {"loss": tensor, "SCE loss": float}. Graph arguments are ignored.
        """
        if mask is None:
            pred_sel, target_sel = pred, target
        else:
            pred_sel, target_sel = pred[mask], target[mask]

        pred_unit = F.normalize(pred_sel, p=2, dim=-1)
        target_unit = F.normalize(target_sel, p=2, dim=-1)

        cosine = (pred_unit * target_unit).sum(dim=-1)
        sce = (1 - cosine).pow(self.alpha).mean()

        return {
            "loss": sce,
            "SCE loss": sce.item(),
        }
56
+
57
+
58
class PBELoss(nn.Module):
    """
    Loss based on the Power Balance Equations.

    For every node, the complex power injection implied by the voltages and the admittance
    matrix, `S_inj = V * conj(Y @ V)`, is compared with the net power
    `(PG - PD) + 1j * (QG - QD)`. The loss is the mean absolute complex residual. Masked
    entries come from `pred`; unmasked entries are taken from `target`.

    Args:
        visualization (bool): If True, also return the per-node absolute active and reactive
            residual tensors.
    """

    def __init__(self, visualization=False):
        super().__init__()

        self.visualization = visualization

    def forward(self, pred, target, edge_index, edge_attr, mask):
        # Create a temporary copy of pred to avoid modifying it
        temp_pred = pred.clone()

        # If a value is not masked, then use the original one
        unmasked = ~mask
        temp_pred[unmasked] = target[unmasked]

        # Voltage magnitudes and angles
        V_m = temp_pred[:, VM]
        V_a = temp_pred[:, VA]

        # Complex voltage vector V and its conjugate
        V = V_m * torch.exp(1j * V_a)
        V_conj = torch.conj(V)

        # Complex admittance entries G + jB, one per edge
        edge_complex = edge_attr[:, G] + 1j * edge_attr[:, B]

        # Single sparse complex admittance matrix Y_bus of size (num_nodes, num_nodes)
        num_nodes = target.size(0)
        Y_bus_sparse = to_torch_coo_tensor(
            edge_index,
            edge_complex,
            size=(num_nodes, num_nodes),
        )

        # Conjugate of the admittance matrix
        Y_bus_conj = torch.conj(Y_bus_sparse)

        # Complex power injection S_i = V_i * sum_j conj(Y_ij) * conj(V_j).
        # Mathematically identical to diag(V) @ conj(Y) @ conj(V), but uses an elementwise
        # product instead of materialising a dense (num_nodes x num_nodes) diagonal matrix.
        S_injection = V * (Y_bus_conj @ V_conj.unsqueeze(-1)).squeeze(-1)

        # Compute net power balance
        net_P = temp_pred[:, PG] - temp_pred[:, PD]
        net_Q = temp_pred[:, QG] - temp_pred[:, QD]
        S_net_power_balance = net_P + 1j * net_Q

        # Complex residual, computed once and reused by every reported quantity
        residual = S_net_power_balance - S_injection
        nodal_active = torch.abs(torch.real(residual))
        nodal_reactive = torch.abs(torch.imag(residual))

        # Power balance loss: mean of absolute complex residual
        loss = torch.mean(torch.abs(residual))
        real_loss_power = torch.mean(nodal_active)
        imag_loss_power = torch.mean(nodal_reactive)

        result = {
            "loss": loss,
            "Power power loss in p.u.": loss.item(),
            "Active Power Loss in p.u.": real_loss_power.item(),
            "Reactive Power Loss in p.u.": imag_loss_power.item(),
        }
        if self.visualization:
            result["Nodal Active Power Loss in p.u."] = nodal_active
            result["Nodal Reactive Power Loss in p.u."] = nodal_reactive
        return result
138
+
139
+
140
class MixedLoss(nn.Module):
    """
    Weighted sum of several loss modules.

    Args:
        loss_functions (list[nn.Module]): Loss modules; each returns a dict with a "loss" key.
        weights (list[float]): One weight per loss module, in the same order.

    Raises:
        ValueError: If the two lists have different lengths.
    """

    def __init__(self, loss_functions, weights):
        super().__init__()

        if len(loss_functions) == len(weights):
            self.loss_functions = nn.ModuleList(loss_functions)
            self.weights = weights
            return

        raise ValueError(
            "The number of loss functions must match the number of weights.",
        )

    def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None):
        """
        Evaluate every loss and combine them.

        Parameters:

        - pred: Predictions.
        - target: Ground truth.
        - edge_index: Optional edge index for graph-based losses.
        - edge_attr: Optional edge attributes for graph-based losses.
        - mask: Optional mask forwarded to every loss.

        Returns:
        - A dict holding every non-"loss" entry reported by the individual losses (later
          losses overwrite earlier ones on a key clash), followed by "loss", the weighted total.
        """
        combined = 0.0
        details = {}

        for index, criterion in enumerate(self.loss_functions):
            outputs = criterion(
                pred,
                target,
                edge_index=edge_index,
                edge_attr=edge_attr,
                mask=mask,
            )
            combined += self.weights[index] * outputs.pop("loss")
            details.update(outputs)

        details["loss"] = combined
        return details