gridfm-graphkit 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,65 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import os
4
+ from torch.utils.data import Subset
5
+ from typing import Tuple
6
+
7
+
8
def split_dataset(
    dataset,
    log_dir: str,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed=None,
) -> Tuple[Subset, Subset, Subset]:
    """
    Randomly split a dataset into training, validation, and test subsets.

    The indices of each split are written to ``train_indices.csv``,
    ``val_indices.csv`` and ``test_indices.csv`` (one ``index`` column) inside
    ``log_dir`` so the split can be inspected or reused later.

    Args:
        dataset (torch_geometric.data.Dataset): The dataset to split. Any object
            supporting ``len()`` and integer indexing works.
        log_dir (str): Directory where the index CSV files are saved. It is
            created (including parents) if it does not exist.
        val_ratio (float, optional): Proportion of the dataset to include in the validation set.
        test_ratio (float, optional): Proportion of the dataset to include in the test set.
        seed (int, optional): Seed for the shuffle, making the split reproducible.
            When ``None`` (default), NumPy's global random state is used.

    Raises:
        ValueError: If either ratio is negative, or if ``val_ratio + test_ratio >= 1``,
            which would leave no data for the training set.

    Returns:
        tuple: ``(train_dataset, val_dataset, test_dataset)``, each a
        ``torch.utils.data.Subset`` of ``dataset``.
    """
    if val_ratio < 0 or test_ratio < 0:
        raise ValueError("val_ratio and test_ratio must be non-negative.")
    if val_ratio + test_ratio >= 1:
        raise ValueError("The sum of val_ratio and test_ratio must be less than 1.")

    n_samples = len(dataset)
    # int() truncates, so any rounding remainder ends up in the training set.
    val_size = int(val_ratio * n_samples)
    test_size = int(test_ratio * n_samples)
    train_size = n_samples - val_size - test_size

    # Shuffle once, then cut the permutation into three contiguous slices.
    if seed is None:
        indices = np.random.permutation(n_samples)
    else:
        indices = np.random.default_rng(seed).permutation(n_samples)
    train_indices = indices[:train_size]
    val_indices = indices[train_size : train_size + val_size]
    test_indices = indices[train_size + val_size :]

    # Persist the split. An empty log_dir means "current directory", which
    # always exists (and os.makedirs("") would raise).
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    for split_name, split_indices in (
        ("train", train_indices),
        ("val", val_indices),
        ("test", test_indices),
    ):
        pd.DataFrame(split_indices, columns=["index"]).to_csv(
            os.path.join(log_dir, f"{split_name}_indices.csv"),
            index=False,
        )

    return (
        Subset(dataset, train_indices),
        Subset(dataset, val_indices),
        Subset(dataset, test_indices),
    )
File without changes
@@ -0,0 +1,293 @@
1
+ from gridfm_graphkit.datasets.data_normalization import (
2
+ IdentityNormalizer,
3
+ MinMaxNormalizer,
4
+ Standardizer,
5
+ BaseMVANormalizer,
6
+ )
7
+ from gridfm_graphkit.datasets.transforms import (
8
+ AddRandomMask,
9
+ AddPFMask,
10
+ AddOPFMask,
11
+ AddIdentityMask,
12
+ )
13
+ from gridfm_graphkit.utils.loss import (
14
+ PBELoss,
15
+ MaskedMSELoss,
16
+ SCELoss,
17
+ MixedLoss,
18
+ MSELoss,
19
+ )
20
+ from gridfm_graphkit.models.graphTransformer import GNN_TransformerConv
21
+ from gridfm_graphkit.models.gps_transformer import GPSTransformer
22
+
23
+ import argparse
24
+ import itertools
25
+
26
+
27
class NestedNamespace(argparse.Namespace):
    """
    ``argparse.Namespace`` whose nested dictionaries become nested namespaces.

    Hierarchical configs can then be read as ``cfg.section.key``. They can be
    converted back with :meth:`to_dict` or collapsed to dotted keys with
    :meth:`flatten`.
    """

    def __init__(self, **kwargs):
        for name, item in kwargs.items():
            # Sub-dictionaries are wrapped recursively; everything else is stored as-is.
            wrapped = NestedNamespace(**item) if isinstance(item, dict) else item
            setattr(self, name, wrapped)

    def to_dict(self):
        """Return a plain nested ``dict`` mirroring this namespace."""
        return {
            name: (item.to_dict() if isinstance(item, NestedNamespace) else item)
            for name, item in self.__dict__.items()
        }

    def flatten(self, parent_key="", sep="."):
        """
        Return a single-level ``dict`` of all leaf attributes.

        Nested attribute names are joined with ``sep`` (e.g.
        ``"model.hidden_size"``). If ``parent_key`` is given, it prefixes
        every key.
        """
        flat = {}
        for name, item in self.__dict__.items():
            full_key = f"{parent_key}{sep}{name}" if parent_key else name
            if not isinstance(item, NestedNamespace):
                flat[full_key] = item
                continue
            flat.update(item.flatten(full_key, sep=sep))
        return flat
62
+
63
+
64
def flatten_dict(d, parent_key="", sep="."):
    """
    Collapse a nested dictionary into a single level, joining keys with ``sep``.

    Example: ``{"a": {"b": 1}, "c": 2}`` becomes ``{"a.b": 1, "c": 2}``.

    Args:
        d (dict): The dictionary to flatten; values that are ``dict`` are expanded.
        parent_key (str, optional): Prefix prepended to every produced key.
        sep (str, optional): Separator placed between key levels. Defaults to '.'.

    Returns:
        dict: The flattened dictionary, in depth-first insertion order.
    """

    def _walk(mapping, prefix):
        # Yield (joined_key, leaf_value) pairs depth-first.
        for key, value in mapping.items():
            full_key = f"{prefix}{sep}{key}" if prefix else key
            if isinstance(value, dict):
                yield from _walk(value, full_key)
            else:
                yield full_key, value

    return dict(_walk(d, parent_key))
84
+
85
+
86
def unflatten_dict(d, sep="."):
    """
    Rebuild a nested dictionary from keys whose path segments are joined by ``sep``.

    Example: ``{"a.b": 1, "c": 2}`` becomes ``{"a": {"b": 1}, "c": 2}``.

    Args:
        d (dict): Flattened dictionary with ``sep``-separated string keys.
        sep (str, optional): Separator used in the flattened keys. Defaults to '.'.

    Returns:
        dict: The reconstructed nested dictionary.
    """
    nested = {}
    for flat_key, value in d.items():
        *path, leaf = flat_key.split(sep)
        node = nested
        # Walk (creating as needed) the intermediate levels, then set the leaf.
        for segment in path:
            node = node.setdefault(segment, {})
        node[leaf] = value
    return nested
105
+
106
+
107
def merge_dict(base, updates):
    """
    Recursively apply ``updates`` onto ``base`` in place, rejecting unknown keys.

    Nested dictionaries are merged key by key. Any non-dict value in
    ``updates`` overwrites the corresponding ``base`` entry, even if that
    entry was a dict. Keys are processed in ``updates`` order, so entries
    handled before an error remain applied.

    Args:
        base (dict): The original dictionary to be updated (mutated).
        updates (dict): The dictionary containing updates.

    Raises:
        KeyError: If a key in updates does not exist in base.
        TypeError: If updates provide a dict where base holds a non-dict value.
    """
    for key, new_value in updates.items():
        if key not in base:
            raise KeyError(f"Key '{key}' not found in base configuration.")

        if not isinstance(new_value, dict):
            # Leaf update: replace the existing value outright.
            base[key] = new_value
            continue

        current = base[key]
        if not isinstance(current, dict):
            raise TypeError(
                f"Default config expects {type(current)}, but got a dict at key '{key}'",
            )
        merge_dict(current, new_value)
133
+
134
+
135
def param_combination_gen(grid_config):
    """
    Generate all combinations of parameters from a nested grid dictionary.

    Args:
        grid_config (dict): A nested dictionary whose leaves are lists of
            candidate values, e.g.
            ``{"model": {"hidden_size": [64, 128]}, "training": {"lr": [1e-3]}}``.
            NOTE: every leaf is iterated, so a bare string leaf would be
            expanded character by character. Wrap single values in a list.

    Returns:
        list: One nested dictionary (same nesting as ``grid_config``) per point
        of the Cartesian product of the leaf lists. An empty grid yields
        ``[{}]`` (a single empty combination). If any leaf list is empty, the
        result is ``[]``.
    """
    # Flatten so every tunable parameter is a single dotted key.
    flat_grid = flatten_dict(grid_config)
    if not flat_grid:
        # zip(*{}.items()) has nothing to unpack. The product of zero factors
        # is exactly one empty combination.
        return [{}]

    keys = list(flat_grid)
    value_lists = list(flat_grid.values())

    # Build each flat combination, then restore the original nesting.
    return [
        unflatten_dict(dict(zip(keys, combo)))
        for combo in itertools.product(*value_lists)
    ]
160
+
161
+
162
def load_normalizer(args):
    """
    Build the (node, edge) normalizer pair selected by ``args.data.normalization``.

    Supported methods: ``"minmax"``, ``"standard"``, ``"baseMVAnorm"``
    (which also reads ``args.data.baseMVA``) and ``"identity"``.

    Args:
        args (NestedNamespace): contains configs.

    Returns:
        tuple: Node and edge normalizers

    Raises:
        ValueError: If an unknown normalization method is specified.
    """
    data_cfg = args.data
    method = data_cfg.normalization

    if method == "minmax":
        node_normalizer, edge_normalizer = MinMaxNormalizer(), MinMaxNormalizer()
    elif method == "standard":
        node_normalizer, edge_normalizer = Standardizer(), Standardizer()
    elif method == "baseMVAnorm":
        # Same configured baseMVA for both; the node_data flag selects the node
        # vs. edge variant.
        node_normalizer = BaseMVANormalizer(
            node_data=True,
            baseMVA_orig=data_cfg.baseMVA,
        )
        edge_normalizer = BaseMVANormalizer(node_data=False, baseMVA_orig=data_cfg.baseMVA)
    elif method == "identity":
        node_normalizer, edge_normalizer = IdentityNormalizer(), IdentityNormalizer()
    else:
        raise ValueError(f"Unknown normalization method: {method}")

    return node_normalizer, edge_normalizer
190
+
191
+
192
def get_loss_function(args):
    """
    Build the training loss as a ``MixedLoss`` over the configured loss terms.

    Each name in ``args.training.losses`` (``"MSE"``, ``"MaskedMSE"``, ``"SCE"``
    or ``"PBE"``) becomes one loss instance, in order. The list is combined
    using ``args.training.loss_weights``.

    Args:
        args (NestedNamespace): contains configs.

    Returns:
        nn.Module: Loss function

    Raises:
        ValueError: If an unknown loss function is specified.
    """

    def _make_loss(name):
        # Map one configured loss name to a fresh loss instance.
        if name == "MSE":
            return MSELoss()
        if name == "MaskedMSE":
            return MaskedMSELoss()
        if name == "SCE":
            return SCELoss()
        if name == "PBE":
            return PBELoss()
        raise ValueError(f"Unknown loss function: {name}")

    components = [_make_loss(name) for name in args.training.losses]
    return MixedLoss(loss_functions=components, weights=args.training.loss_weights)
219
+
220
+
221
def load_model(args):
    """
    Instantiate the model named by ``args.model.type``.

    Architecture hyperparameters are read from ``args.model``; the mask
    settings (``mask_dim``, ``mask_value``, ``learn_mask``) come from
    ``args.data``. Supported types: ``"GNN_TransformerConv"`` and
    ``"GPSTransformer"``.

    Args:
        args (NestedNamespace): contains configs.

    Returns:
        nn.Module: The selected model initialized with the provided configurations.

    Raises:
        ValueError: If an unknown model type is specified.
    """
    model_cfg = args.model
    model_type = model_cfg.type

    if model_type == "GNN_TransformerConv":
        model_cls = GNN_TransformerConv
        kwargs = {
            "input_dim": model_cfg.input_dim,
            "hidden_dim": model_cfg.hidden_size,
            "output_dim": model_cfg.output_dim,
            "edge_dim": model_cfg.edge_dim,
            "num_layers": model_cfg.num_layers,
            "heads": model_cfg.attention_head,
            "mask_dim": args.data.mask_dim,
            "mask_value": args.data.mask_value,
            "learn_mask": args.data.learn_mask,
        }
    elif model_type == "GPSTransformer":
        model_cls = GPSTransformer
        kwargs = {
            "input_dim": model_cfg.input_dim,
            "hidden_dim": model_cfg.hidden_size,
            "output_dim": model_cfg.output_dim,
            "edge_dim": model_cfg.edge_dim,
            "pe_dim": model_cfg.pe_dim,
            "heads": model_cfg.attention_head,
            "num_layers": model_cfg.num_layers,
            "dropout": model_cfg.dropout,
            "mask_dim": args.data.mask_dim,
            "mask_value": args.data.mask_value,
            "learn_mask": args.data.learn_mask,
        }
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    return model_cls(**kwargs)
264
+
265
+
266
def get_transform(args):
    """
    Build the masking transform selected by ``args.data.mask_type``.

    ``"rnd"`` gives a random mask configured by ``args.data.mask_dim`` and
    ``args.data.mask_ratio``. ``"pf"`` and ``"opf"`` give the PF and OPF
    masks. ``"none"`` gives the identity mask.

    Args:
        args (NestedNamespace): contains configs.

    Returns:
        BaseTransform: Transformation

    Raises:
        ValueError: If an unknown transform is specified.
    """
    data_cfg = args.data
    mask_type = data_cfg.mask_type

    if mask_type == "rnd":
        return AddRandomMask(mask_dim=data_cfg.mask_dim, mask_ratio=data_cfg.mask_ratio)
    if mask_type == "pf":
        return AddPFMask()
    if mask_type == "opf":
        return AddOPFMask()
    if mask_type == "none":
        return AddIdentityMask()

    raise ValueError(f"Unknown transformation: {mask_type}")
File without changes
@@ -0,0 +1,143 @@
1
+ from torch_geometric.nn import GPSConv, GINEConv
2
+ from torch import nn
3
+ import torch
4
+
5
+
6
class GPSTransformer(nn.Module):
    """
    A GPS graph transformer model built from [GPSConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GPSConv.html) layers that use [GINEConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GINEConv.html) for local message passing, both from PyTorch Geometric.

    Processing pipeline:

    - Node features are encoded to ``hidden_dim - pe_dim`` channels and
      batch-normalized.
    - The positional encoding is batch-normalized separately.
    - The two are concatenated into ``hidden_dim`` channels.
    - ``num_layers`` GPSConv layers are applied, each followed by batch
      normalization.
    - A final normalization and a two-layer MLP decode to ``output_dim``.

    Args:
        input_dim (int): Dimension of input node features.
        hidden_dim (int): Hidden dimension size for all layers.
        output_dim (int): Dimension of the output node features.
        edge_dim (int): Dimension of edge features.
        pe_dim (int): Dimension of the positional encoding.
            Must be less than hidden_dim.
        num_layers (int): Number of GPSConv layers.
        heads (int, optional): Number of attention heads in GPSConv. Must be a
            positive divisor of ``hidden_dim`` when ``num_layers > 0``.
        dropout (float, optional): Dropout rate in GPSConv.
        mask_dim (int, optional): Length of the mask-value vector.
        mask_value (float, optional): Base value of the mask vector. With
            ``learn_mask=True`` the vector starts at ``mask_value + randn(mask_dim)``.
            Otherwise every entry equals ``mask_value``.
        learn_mask (bool, optional): Whether to learn mask values as parameters.

    Raises:
        ValueError: If ``pe_dim`` is not less than ``hidden_dim``, or if
            ``num_layers > 0`` and ``heads`` is not a positive divisor of
            ``hidden_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        edge_dim: int,
        pe_dim: int,
        num_layers: int,
        heads: int = 1,
        dropout: float = 0.0,
        mask_dim: int = 6,
        mask_value: float = -1.0,
        learn_mask: bool = True,
    ):
        super(GPSTransformer, self).__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.edge_dim = edge_dim
        self.pe_dim = pe_dim
        self.heads = heads
        self.dropout = dropout
        self.mask_dim = mask_dim
        self.mask_value = mask_value  # replaced by an nn.Parameter further below
        self.learn_mask = learn_mask

        if not pe_dim < hidden_dim:
            raise ValueError(
                "positional encoding dimension must be smaller than model hidden dimension",
            )
        # GPSConv's multi-head attention splits `hidden_dim` channels evenly across
        # heads. Check here for a clear error instead of a failure deep inside layer
        # construction. The check runs before any submodule is built, so weight
        # initialization (and RNG consumption) is unchanged for valid configs.
        if num_layers > 0 and (heads < 1 or hidden_dim % heads != 0):
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive number of attention heads (got {heads})",
            )

        self.layers = nn.ModuleList()

        # Node-feature encoder; leaves pe_dim channels free for the positional encoding.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, self.hidden_dim - self.pe_dim),
            nn.LeakyReLU(),
        )
        self.input_norm = nn.BatchNorm1d(self.hidden_dim - self.pe_dim)
        self.pe_norm = nn.BatchNorm1d(self.pe_dim)

        for _ in range(self.num_layers):
            # Local message-passing MLP used by GINEConv inside each GPS layer.
            mlp = nn.Sequential(
                nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim),
                nn.LeakyReLU(),
            )
            self.layers.append(
                nn.ModuleDict(
                    {
                        "conv": GPSConv(
                            channels=self.hidden_dim,
                            conv=GINEConv(nn=mlp, edge_dim=self.edge_dim),
                            heads=self.heads,
                            dropout=self.dropout,
                        ),
                        "norm": nn.BatchNorm1d(
                            self.hidden_dim,
                        ),  # BatchNorm after each graph layer
                    },
                ),
            )

        self.pre_decoder_norm = nn.BatchNorm1d(self.hidden_dim)
        # Two-layer MLP decoder applied after the GPS layers.
        self.decoder = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_dim, output_dim),
        )

        # Mask-value vector of length mask_dim; it is not read by forward() here.
        # NOTE(review): presumably consumed by the masking/training code — confirm.
        if learn_mask:
            self.mask_value = nn.Parameter(
                torch.randn(mask_dim) + mask_value,
                requires_grad=True,
            )
        else:
            self.mask_value = nn.Parameter(
                torch.zeros(mask_dim) + mask_value,
                requires_grad=False,
            )

    def forward(self, x, pe, edge_index, edge_attr, batch):
        """
        Forward pass for the GPSTransformer.

        Args:
            x (Tensor): Input node features of shape [num_nodes, input_dim].
            pe (Tensor): Positional encoding of shape [num_nodes, pe_dim].
            edge_index (Tensor): Edge indices for graph convolution.
            edge_attr (Tensor): Edge feature tensor (passed to GINEConv, which was
                built with ``edge_dim``).
            batch (Tensor): Batch vector assigning nodes to graphs.

        Returns:
            output (Tensor): Output node features of shape [num_nodes, output_dim].
        """
        x_pe = self.pe_norm(pe)

        x = self.encoder(x)
        x = self.input_norm(x)

        # (hidden_dim - pe_dim) feature channels + pe_dim positional channels = hidden_dim.
        x = torch.cat((x, x_pe), 1)
        for layer in self.layers:
            x = layer["conv"](
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                batch=batch,
            )
            x = layer["norm"](x)

        x = self.pre_decoder_norm(x)
        x = self.decoder(x)

        return x
@@ -0,0 +1,96 @@
1
+ from torch_geometric.nn import TransformerConv
2
+ from torch import nn
3
+ import torch
4
+
5
+
6
class GNN_TransformerConv(nn.Module):
    """
    Graph Neural Network using [TransformerConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.TransformerConv.html) layers from PyTorch Geometric.

    ``num_layers`` TransformerConv layers (each followed by LeakyReLU) feed a
    two-layer MLP head. The layer sizes assume each TransformerConv outputs
    ``hidden_dim * heads`` channels (concatenated heads). Every layer after
    the first therefore takes that many input channels.

    Args:
        input_dim (int): Dimensionality of input node features.
        hidden_dim (int): Hidden dimension size (per head) for TransformerConv layers.
        output_dim (int): Output dimension size.
        edge_dim (int): Dimensionality of edge features.
        num_layers (int): Number of TransformerConv layers.
        heads (int, optional): Number of attention heads.
        mask_dim (int, optional): Length of the mask-value vector.
        mask_value (float, optional): Base value of the mask vector. With
            ``learn_mask=True`` the vector starts at ``mask_value + randn(mask_dim)``.
            Otherwise every entry equals ``mask_value``.
        learn_mask (bool, optional): Whether mask values are learnable.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        edge_dim: int,
        num_layers: int,
        heads: int = 1,
        mask_dim: int = 6,
        mask_value: float = -1.0,
        learn_mask: bool = False,
    ):
        super(GNN_TransformerConv, self).__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.edge_dim = edge_dim
        self.heads = heads
        self.mask_dim = mask_dim
        self.mask_value = mask_value  # replaced by an nn.Parameter further below
        self.learn_mask = learn_mask

        self.layers = nn.ModuleList()
        current_dim = input_dim  # First layer takes `input_dim` as input

        for _ in range(self.num_layers):
            self.layers.append(
                TransformerConv(
                    current_dim,
                    self.hidden_dim,
                    heads=self.heads,
                    edge_dim=self.edge_dim,
                    beta=False,
                ),
            )
            # Update the dimension for the next layer (heads are concatenated)
            current_dim = self.hidden_dim * self.heads

        # Two-layer MLP head applied after the TransformerConv layers.
        self.mlps = nn.Sequential(
            nn.Linear(self.hidden_dim * self.heads, self.hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_dim, output_dim),
        )

        # Mask-value vector of length mask_dim; it is not read by forward() here.
        # NOTE(review): presumably consumed by the masking/training code — confirm.
        if learn_mask:
            self.mask_value = nn.Parameter(
                torch.randn(mask_dim) + mask_value,
                requires_grad=True,
            )
        else:
            self.mask_value = nn.Parameter(
                torch.zeros(mask_dim) + mask_value,
                requires_grad=False,
            )

    def forward(self, x, pe, edge_index, edge_attr, batch):
        """
        Forward pass for the GNN_TransformerConv model.

        Args:
            x (Tensor): Input node features of shape [num_nodes, input_dim].
            pe (Tensor): Positional encoding of shape [num_nodes, pe_dim] (not used).
            edge_index (Tensor): Edge indices for graph convolution.
            edge_attr (Tensor): Edge feature tensor.
            batch (Tensor): Batch vector assigning nodes to graphs (not used).

        Returns:
            output (Tensor): Output node features of shape [num_nodes, output_dim].
        """
        for conv in self.layers:
            x = conv(x, edge_index, edge_attr)
            # Functional LeakyReLU: same default slope (0.01) as nn.LeakyReLU(),
            # without building a throwaway module per layer on every call.
            x = nn.functional.leaky_relu(x)

        x = self.mlps(x)
        return x
File without changes
@@ -0,0 +1,47 @@
1
+ import torch
2
+
3
+
4
class EarlyStopper:
    """
    Tracks the validation loss, checkpoints the best model, and signals when
    training should stop early.
    """

    def __init__(
        self,
        saving_path,
        patience=5,
        tol=0,
        min_validation_loss=float("inf"),
    ):
        """
        Args:
            saving_path (str): Destination passed to ``torch.save`` each time a new
                best validation loss is seen.
            patience (int): number of non-improving epochs tolerated before stopping.
                -1 means no early stopping.
                0 means stop the first time the validation loss exceeds the best
                loss by more than ``tol``.
            tol (float): tolerance; a loss counts as "worse" only if it exceeds the
                best loss so far by more than this amount.
            min_validation_loss (float): best validation loss seen so far
                (defaults to ``inf``, so the first loss always counts as an improvement).
        """

        self.patience = patience
        self.tol = tol
        self.counter = 0  # consecutive "worse" epochs since the last improvement
        self.min_validation_loss = min_validation_loss
        self.saving_path = saving_path

    def early_stop(self, validation_loss, model):
        """
        Record one epoch's validation loss and decide whether to stop.

        A strictly lower loss becomes the new best and resets the counter, and
        ``model`` is saved to ``saving_path``. A loss that exceeds the best by
        more than ``tol`` increments the counter. A NaN loss also increments
        it. A loss within ``tol`` of the best leaves the counter unchanged.

        Args:
            validation_loss (float): Validation loss for the current epoch.
            model: Object passed to ``torch.save`` on improvement.

        Returns:
            bool: True if training should stop, False otherwise.
        """
        import math  # local import: the module only imports torch

        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = (
                validation_loss  # Update the best validation loss
            )
            self.counter = 0

            # Save the best model whenever a new minimum is found
            torch.save(model, self.saving_path)

        # NaN compares False against everything, so it would otherwise never
        # count as "worse" and a diverged run would never stop.
        elif math.isnan(validation_loss) or validation_loss > (
            self.min_validation_loss + self.tol
        ):
            self.counter += 1

        if self.patience != -1 and self.counter > self.patience:
            print(
                "Early stopping after {} epochs of no improvement.".format(
                    self.counter,
                ),
            )
            return True
        return False