gridfm-graphkit 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gridfm_graphkit/__init__.py +0 -0
- gridfm_graphkit/__main__.py +62 -0
- gridfm_graphkit/cli.py +530 -0
- gridfm_graphkit/datasets/__init__.py +0 -0
- gridfm_graphkit/datasets/data_normalization.py +227 -0
- gridfm_graphkit/datasets/globals.py +19 -0
- gridfm_graphkit/datasets/powergrid.py +192 -0
- gridfm_graphkit/datasets/transforms.py +223 -0
- gridfm_graphkit/datasets/utils.py +65 -0
- gridfm_graphkit/io/__init__.py +0 -0
- gridfm_graphkit/io/param_handler.py +293 -0
- gridfm_graphkit/models/__init__.py +0 -0
- gridfm_graphkit/models/gps_transformer.py +143 -0
- gridfm_graphkit/models/graphTransformer.py +96 -0
- gridfm_graphkit/training/__init__.py +0 -0
- gridfm_graphkit/training/callbacks.py +47 -0
- gridfm_graphkit/training/plugins.py +218 -0
- gridfm_graphkit/training/trainer.py +156 -0
- gridfm_graphkit/utils/__init__.py +0 -0
- gridfm_graphkit/utils/loss.py +198 -0
- gridfm_graphkit/utils/visualization.py +324 -0
- gridfm_graphkit-0.0.1.dist-info/METADATA +163 -0
- gridfm_graphkit-0.0.1.dist-info/RECORD +27 -0
- gridfm_graphkit-0.0.1.dist-info/WHEEL +5 -0
- gridfm_graphkit-0.0.1.dist-info/entry_points.txt +2 -0
- gridfm_graphkit-0.0.1.dist-info/licenses/LICENSE +201 -0
- gridfm_graphkit-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import os
|
|
4
|
+
from torch.utils.data import Subset
|
|
5
|
+
from typing import Tuple
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def split_dataset(
    dataset,
    log_dir: str,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
) -> Tuple[Subset, Subset, Subset]:
    """
    Randomly split a dataset into training, validation, and test subsets.

    The indices assigned to each split are also written to
    ``train_indices.csv``, ``val_indices.csv`` and ``test_indices.csv`` inside
    ``log_dir`` so the split can be inspected later.

    The shuffle uses ``np.random.permutation``, so reproducibility depends on
    the state of NumPy's global random generator at call time.

    Args:
        dataset (torch_geometric.data.Dataset): The dataset to split. Only
            ``len(dataset)`` is used here; items are accessed lazily through
            the returned subsets.
        log_dir (str): Directory where the CSV files containing the indices
            for each split are saved. It is created if it does not exist.
        val_ratio (float, optional): Proportion of the dataset to include in
            the validation set.
        test_ratio (float, optional): Proportion of the dataset to include in
            the test set.

    Raises:
        ValueError: If either ratio is negative, or if
            `val_ratio + test_ratio >= 1`, which would leave no data for the
            training set.

    Returns:
        tuple: A tuple containing:
            - train_dataset (torch.utils.data.Subset): The training subset.
            - val_dataset (torch.utils.data.Subset): The validation subset.
            - test_dataset (torch.utils.data.Subset): The test subset.
    """
    # A negative ratio produces a negative split size, which makes the slices
    # below overlap (test samples leaking into the training set).
    if val_ratio < 0 or test_ratio < 0:
        raise ValueError("val_ratio and test_ratio must be non-negative.")
    if val_ratio + test_ratio >= 1:
        raise ValueError("The sum of val_ratio and test_ratio must be less than 1.")

    num_samples = len(dataset)
    val_size = int(val_ratio * num_samples)
    test_size = int(test_ratio * num_samples)
    # The training split absorbs whatever is lost to integer truncation.
    train_size = num_samples - val_size - test_size

    # Shuffle once, then carve the permutation into three contiguous slices.
    indices = np.random.permutation(num_samples)
    split_indices = {
        "train": indices[:train_size],
        "val": indices[train_size : train_size + val_size],
        "test": indices[train_size + val_size :],
    }

    # Make sure the target directory exists before writing the CSV files.
    os.makedirs(log_dir, exist_ok=True)
    for split_name, split_idx in split_indices.items():
        pd.DataFrame(split_idx, columns=["index"]).to_csv(
            os.path.join(log_dir, f"{split_name}_indices.csv"),
            index=False,
        )

    train_dataset = Subset(dataset, split_indices["train"])
    val_dataset = Subset(dataset, split_indices["val"])
    test_dataset = Subset(dataset, split_indices["test"])

    return train_dataset, val_dataset, test_dataset
|
|
File without changes
|
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
from gridfm_graphkit.datasets.data_normalization import (
|
|
2
|
+
IdentityNormalizer,
|
|
3
|
+
MinMaxNormalizer,
|
|
4
|
+
Standardizer,
|
|
5
|
+
BaseMVANormalizer,
|
|
6
|
+
)
|
|
7
|
+
from gridfm_graphkit.datasets.transforms import (
|
|
8
|
+
AddRandomMask,
|
|
9
|
+
AddPFMask,
|
|
10
|
+
AddOPFMask,
|
|
11
|
+
AddIdentityMask,
|
|
12
|
+
)
|
|
13
|
+
from gridfm_graphkit.utils.loss import (
|
|
14
|
+
PBELoss,
|
|
15
|
+
MaskedMSELoss,
|
|
16
|
+
SCELoss,
|
|
17
|
+
MixedLoss,
|
|
18
|
+
MSELoss,
|
|
19
|
+
)
|
|
20
|
+
from gridfm_graphkit.models.graphTransformer import GNN_TransformerConv
|
|
21
|
+
from gridfm_graphkit.models.gps_transformer import GPSTransformer
|
|
22
|
+
|
|
23
|
+
import argparse
|
|
24
|
+
import itertools
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class NestedNamespace(argparse.Namespace):
    """
    Namespace that mirrors a nested dictionary.

    Every dictionary value passed to the constructor is converted, recursively,
    into another ``NestedNamespace`` so hierarchical configurations can be read
    with attribute access (``cfg.section.option``).
    """

    def __init__(self, **kwargs):
        for name, entry in kwargs.items():
            # Dictionaries become nested namespaces; any other value is stored as given.
            converted = NestedNamespace(**entry) if isinstance(entry, dict) else entry
            setattr(self, name, converted)

    def to_dict(self):
        """Return the namespace as a plain dictionary, converting nested namespaces recursively."""
        return {
            name: entry.to_dict() if isinstance(entry, NestedNamespace) else entry
            for name, entry in self.__dict__.items()
        }

    def flatten(self, parent_key="", sep="."):
        """
        Return a single-level dictionary whose keys are ``sep``-joined attribute paths.

        Args:
            parent_key (str, optional): Prefix prepended to every key.
            sep (str, optional): Separator placed between path components.
        """
        flat = {}
        for name, entry in self.__dict__.items():
            path = name if not parent_key else f"{parent_key}{sep}{name}"
            if isinstance(entry, NestedNamespace):
                # Descend, carrying the accumulated path as the new prefix.
                flat.update(entry.flatten(path, sep=sep))
            else:
                flat[path] = entry
        return flat
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def flatten_dict(d, parent_key="", sep="."):
    """
    Collapse a nested dictionary into one level, joining key paths with ``sep``.

    Args:
        d (dict): The dictionary to flatten.
        parent_key (str, optional): Prefix for the keys in the flattened dictionary.
        sep (str, optional): Separator for nested keys. Defaults to '.'.

    Returns:
        dict: A flattened version of the input dictionary. Nested dictionaries
        that are empty contribute no entries.
    """
    flat = {}
    for name, entry in d.items():
        path = f"{parent_key}{sep}{name}" if parent_key else name
        if not isinstance(entry, dict):
            flat[path] = entry
            continue
        # Recurse with the accumulated path as the new prefix.
        flat.update(flatten_dict(entry, path, sep=sep))
    return flat
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def unflatten_dict(d, sep="."):
    """
    Rebuild a nested dictionary from a flat one whose keys are ``sep``-joined paths.

    Args:
        d (dict): The flattened dictionary to unflatten.
        sep (str, optional): Separator used in the flattened keys. Defaults to '.'.

    Returns:
        dict: A nested dictionary reconstructed from the flattened input.
    """
    nested = {}
    for flat_key, value in d.items():
        # Everything before the last component names the chain of parent dicts.
        *parents, leaf = flat_key.split(sep)
        node = nested
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return nested
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def merge_dict(base, updates):
    """
    Apply ``updates`` onto ``base`` in place, recursing into nested dictionaries.

    Only keys that already exist in ``base`` may be updated.

    Args:
        base (dict): The original dictionary to be updated (modified in place).
        updates (dict): The dictionary containing updates.

    Raises:
        KeyError: If a key in updates does not exist in base.
        TypeError: If a key in base is not a dictionary but updates attempt to provide nested values.
    """
    for key, value in updates.items():
        if key not in base:
            raise KeyError(f"Key '{key}' not found in base configuration.")

        if not isinstance(value, dict):
            # Plain value: overwrite the existing entry.
            base[key] = value
            continue

        # Nested update: the base entry must itself be a dictionary.
        if not isinstance(base[key], dict):
            raise TypeError(
                f"Default config expects {type(base[key])}, but got a dict at key '{key}'",
            )
        merge_dict(base[key], value)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def param_combination_gen(grid_config):
    """
    Generate all combinations of parameters from a nested dictionary

    Args:
        grid_config (dict): A nested dictionary where keys are parameter names
            and values are lists of possible values.

    Returns:
        list: A list of dictionaries representing all possible parameter combinations.
            Each dictionary corresponds to one combination. A grid with no
            parameters yields a single empty combination (``[{}]``).
    """
    # Flatten the grid config so every parameter is addressed by one dotted key
    flat_grid = flatten_dict(grid_config)

    # With no parameters there is exactly one (empty) combination. Handling it
    # here avoids the unpacking error an empty grid would otherwise trigger.
    if not flat_grid:
        return [{}]

    keys = list(flat_grid)

    # Cartesian product over the candidate values, rebuilt into nested dictionaries
    return [
        unflatten_dict(dict(zip(keys, combo)))
        for combo in itertools.product(*flat_grid.values())
    ]
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def load_normalizer(args):
    """
    Build the pair of normalizers selected by ``args.data.normalization``.

    Args:
        args (NestedNamespace): contains configs.

    Returns:
        tuple: Node and edge normalizers

    Raises:
        ValueError: If an unknown normalization method is specified.
    """
    method = args.data.normalization

    if method == "minmax":
        return MinMaxNormalizer(), MinMaxNormalizer()

    if method == "standard":
        return Standardizer(), Standardizer()

    if method == "baseMVAnorm":
        # Both normalizers share the configured base MVA; only the node_data flag differs.
        node_normalizer = BaseMVANormalizer(
            node_data=True,
            baseMVA_orig=args.data.baseMVA,
        )
        edge_normalizer = BaseMVANormalizer(
            node_data=False,
            baseMVA_orig=args.data.baseMVA,
        )
        return node_normalizer, edge_normalizer

    if method == "identity":
        return IdentityNormalizer(), IdentityNormalizer()

    raise ValueError(f"Unknown normalization method: {method}")
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def get_loss_function(args):
    """
    Build the combined loss described by ``args.training``.

    One loss object is created for each name in ``args.training.losses`` and
    the results are wrapped in a ``MixedLoss`` together with
    ``args.training.loss_weights``.

    Args:
        args (NestedNamespace): contains configs.

    Returns:
        nn.Module: Loss function

    Raises:
        ValueError: If an unknown loss function is specified.
    """

    def build(loss_name):
        # Map one configured loss name to a fresh loss instance.
        if loss_name == "MSE":
            return MSELoss()
        if loss_name == "MaskedMSE":
            return MaskedMSELoss()
        if loss_name == "SCE":
            return SCELoss()
        if loss_name == "PBE":
            return PBELoss()
        raise ValueError(f"Unknown loss function: {loss_name}")

    components = [build(loss_name) for loss_name in args.training.losses]
    return MixedLoss(loss_functions=components, weights=args.training.loss_weights)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def load_model(args):
    """
    Instantiate the model selected by ``args.model.type``.

    Architecture hyperparameters are read from ``args.model``; the mask
    settings (``mask_dim``, ``mask_value``, ``learn_mask``) are read from
    ``args.data``.

    Args:
        args (NestedNamespace): contains configs.

    Returns:
        nn.Module: The selected model initialized with the provided configurations.

    Raises:
        ValueError: If an unknown model type is specified.
    """
    model_type = args.model.type

    # Reject unsupported types before reading any other configuration value.
    if model_type not in ("GNN_TransformerConv", "GPSTransformer"):
        raise ValueError(f"Unknown model type: {model_type}")

    model_cfg = args.model
    if model_type == "GNN_TransformerConv":
        hyperparams = dict(
            input_dim=model_cfg.input_dim,
            hidden_dim=model_cfg.hidden_size,
            output_dim=model_cfg.output_dim,
            edge_dim=model_cfg.edge_dim,
            num_layers=model_cfg.num_layers,
            heads=model_cfg.attention_head,
        )
        model_cls = GNN_TransformerConv
    else:
        # GPSTransformer additionally needs the positional-encoding size and dropout.
        hyperparams = dict(
            input_dim=model_cfg.input_dim,
            hidden_dim=model_cfg.hidden_size,
            output_dim=model_cfg.output_dim,
            edge_dim=model_cfg.edge_dim,
            pe_dim=model_cfg.pe_dim,
            heads=model_cfg.attention_head,
            num_layers=model_cfg.num_layers,
            dropout=model_cfg.dropout,
        )
        model_cls = GPSTransformer

    # Mask settings are shared by both architectures.
    data_cfg = args.data
    return model_cls(
        **hyperparams,
        mask_dim=data_cfg.mask_dim,
        mask_value=data_cfg.mask_value,
        learn_mask=data_cfg.learn_mask,
    )
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def get_transform(args):
    """
    Build the masking transform selected by ``args.data.mask_type``.

    Args:
        args (NestedNamespace): contains configs.

    Returns:
        BaseTransform: Transformation

    Raises:
        ValueError: If an unknown transform is specified.
    """
    mask_type = args.data.mask_type

    if mask_type == "rnd":
        # Only the random mask takes configuration values.
        return AddRandomMask(
            mask_dim=args.data.mask_dim,
            mask_ratio=args.data.mask_ratio,
        )

    if mask_type == "pf":
        return AddPFMask()

    if mask_type == "opf":
        return AddOPFMask()

    if mask_type == "none":
        return AddIdentityMask()

    raise ValueError(f"Unknown transformation: {mask_type}")
|
|
File without changes
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
from torch_geometric.nn import GPSConv, GINEConv
|
|
2
|
+
from torch import nn
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class GPSTransformer(nn.Module):
    """
    A GPS (Graph Transformer) model based on [GPSConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GPSConv.html) and [GINEConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GINEConv.html) layers from Pytorch Geometric.

    Node features are encoded to ``hidden_dim - pe_dim`` channels and
    concatenated with the batch-normalized positional encoding, giving
    ``hidden_dim`` channels. The result passes through ``num_layers`` GPSConv
    layers, each followed by batch normalization, and is then decoded to
    ``output_dim`` by a two-layer MLP.

    Args:
        input_dim (int): Dimension of input node features.
        hidden_dim (int): Hidden dimension size for all layers.
        output_dim (int): Dimension of the output node features.
        edge_dim (int): Dimension of edge features.
        pe_dim (int): Dimension of the positional encoding.
            Must be less than hidden_dim.
        num_layers (int): Number of GPSConv layers.
        heads (int, optional): Number of attention heads in GPSConv.
        dropout (float, optional): Dropout rate in GPSConv.
        mask_dim (int, optional): Dimension of the mask vector.
        mask_value (float, optional): Initial value for learnable mask parameters.
        learn_mask (bool, optional): Whether to learn mask values as parameters.

    Raises:
        ValueError: If `pe_dim` is not less than `hidden_dim`.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        edge_dim: int,
        pe_dim: int,
        num_layers: int,
        heads: int = 1,
        dropout: float = 0.0,
        mask_dim: int = 6,
        mask_value: float = -1.0,
        learn_mask: bool = True,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.edge_dim = edge_dim
        self.pe_dim = pe_dim
        self.heads = heads
        self.dropout = dropout
        self.mask_dim = mask_dim
        self.learn_mask = learn_mask

        if not pe_dim < hidden_dim:
            raise ValueError(
                "positional encoding dimension must be smaller than model hidden dimension",
            )

        self.layers = nn.ModuleList()

        # The encoder leaves room for the positional encoding, so the
        # concatenation in forward() has exactly hidden_dim channels.
        node_channels = hidden_dim - pe_dim
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, node_channels),
            nn.LeakyReLU(),
        )
        self.input_norm = nn.BatchNorm1d(node_channels)
        self.pe_norm = nn.BatchNorm1d(pe_dim)

        for _ in range(num_layers):
            # Each GPSConv wraps its own GINEConv, which in turn owns its own MLP.
            local_mlp = nn.Sequential(
                nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
                nn.LeakyReLU(),
            )
            gps_conv = GPSConv(
                channels=hidden_dim,
                conv=GINEConv(nn=local_mlp, edge_dim=edge_dim),
                heads=heads,
                dropout=dropout,
            )
            # BatchNorm is applied after each graph layer.
            layer_norm = nn.BatchNorm1d(hidden_dim)
            self.layers.append(nn.ModuleDict({"conv": gps_conv, "norm": layer_norm}))

        self.pre_decoder_norm = nn.BatchNorm1d(hidden_dim)
        # Fully connected (MLP) decoder applied after the graph layers
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

        # mask_value is always stored as a parameter: a trainable one with
        # random noise around the initial value, or a frozen constant one.
        if learn_mask:
            mask_init, trainable = torch.randn(mask_dim), True
        else:
            mask_init, trainable = torch.zeros(mask_dim), False
        self.mask_value = nn.Parameter(mask_init + mask_value, requires_grad=trainable)

    def forward(self, x, pe, edge_index, edge_attr, batch):
        """
        Forward pass for the GPSTransformer.

        Args:
            x (Tensor): Input node features of shape [num_nodes, input_dim].
            pe (Tensor): Positional encoding of shape [num_nodes, pe_dim].
            edge_index (Tensor): Edge indices for graph convolution.
            edge_attr (Tensor): Edge feature tensor.
            batch (Tensor): Batch vector assigning nodes to graphs.

        Returns:
            output (Tensor): Output node features of shape [num_nodes, output_dim].
        """
        pe_features = self.pe_norm(pe)
        node_features = self.input_norm(self.encoder(x))

        # Concatenate along the feature axis: (hidden_dim - pe_dim) + pe_dim channels.
        hidden = torch.cat((node_features, pe_features), 1)
        for block in self.layers:
            hidden = block["conv"](
                x=hidden,
                edge_index=edge_index,
                edge_attr=edge_attr,
                batch=batch,
            )
            hidden = block["norm"](hidden)

        return self.decoder(self.pre_decoder_norm(hidden))
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
from torch_geometric.nn import TransformerConv
|
|
2
|
+
from torch import nn
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class GNN_TransformerConv(nn.Module):
    """
    Graph Neural Network using [TransformerConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.TransformerConv.html) layers from PyTorch Geometric.

    The model stacks ``num_layers`` TransformerConv layers, each followed by a
    LeakyReLU activation, and finishes with a two-layer MLP that maps the
    ``hidden_dim * heads`` channels to ``output_dim``.

    Args:
        input_dim (int): Dimensionality of input node features.
        hidden_dim (int): Hidden dimension size for TransformerConv layers.
        output_dim (int): Output dimension size.
        edge_dim (int): Dimensionality of edge features.
        num_layers (int): Number of TransformerConv layers.
        heads (int, optional): Number of attention heads.
        mask_dim (int, optional): Dimension of mask vector.
        mask_value (float, optional): Initial mask value.
        learn_mask (bool, optional): Whether mask values are learnable.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        edge_dim: int,
        num_layers: int,
        heads: int = 1,
        mask_dim: int = 6,
        mask_value: float = -1.0,
        learn_mask: bool = False,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.edge_dim = edge_dim
        self.heads = heads
        self.mask_dim = mask_dim
        self.learn_mask = learn_mask

        self.layers = nn.ModuleList()
        current_dim = input_dim  # First layer takes `input_dim` as input

        for _ in range(self.num_layers):
            self.layers.append(
                TransformerConv(
                    current_dim,
                    self.hidden_dim,
                    heads=self.heads,
                    edge_dim=self.edge_dim,
                    beta=False,
                ),
            )
            # Every following layer consumes hidden_dim * heads channels,
            # which is also the input width of the MLP below.
            current_dim = self.hidden_dim * self.heads

        # Fully connected (MLP) layers after the TransformerConv layers
        self.mlps = nn.Sequential(
            nn.Linear(self.hidden_dim * self.heads, self.hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_dim, output_dim),
        )

        # mask_value is always stored as a parameter: a trainable one with
        # random noise around the initial value, or a frozen constant one.
        if learn_mask:
            self.mask_value = nn.Parameter(
                torch.randn(mask_dim) + mask_value,
                requires_grad=True,
            )
        else:
            self.mask_value = nn.Parameter(
                torch.zeros(mask_dim) + mask_value,
                requires_grad=False,
            )

    def forward(self, x, pe, edge_index, edge_attr, batch):
        """
        Forward pass for the GNN_TransformerConv model.

        Args:
            x (Tensor): Input node features of shape [num_nodes, input_dim].
            pe (Tensor): Positional encoding of shape [num_nodes, pe_dim] (not used).
            edge_index (Tensor): Edge indices for graph convolution.
            edge_attr (Tensor): Edge feature tensor.
            batch (Tensor): Batch vector assigning nodes to graphs (not used).

        Returns:
            output (Tensor): Output node features of shape [num_nodes, output_dim].
        """
        for conv in self.layers:
            x = conv(x, edge_index, edge_attr)
            # Functional activation: same default slope as nn.LeakyReLU(),
            # without building a new module object on every iteration.
            x = nn.functional.leaky_relu(x)

        return self.mlps(x)
|
|
File without changes
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class EarlyStopper:
    """
    Track the validation loss, checkpoint the best model, and signal when to stop.

    Each call to :meth:`early_stop` compares the new validation loss with the
    best one seen so far. An improvement saves the model and resets the
    counter; a loss worse than the best by more than ``tol`` (or a NaN loss)
    increments the counter.
    """

    def __init__(
        self,
        saving_path,
        patience=5,
        tol=0,
        min_validation_loss=float("inf"),
    ):
        """
        Args:
            saving_path: destination passed to ``torch.save`` whenever a new
                best validation loss is found
            patience (int): number of epochs to wait before early stopping
                            -1 means no early stopping
                            0 means stop training the first time the validation loss increases
            tol (float): tolerance to consider validation loss as worse as the best one so far
            min_validation_loss (float): best validation loss known so far;
                defaults to infinity so the first finite loss counts as an improvement
        """

        self.patience = patience
        self.tol = tol
        self.counter = 0
        self.min_validation_loss = min_validation_loss
        self.saving_path = saving_path

    def early_stop(self, validation_loss, model):
        """
        Record one validation result and report whether training should stop.

        Args:
            validation_loss: validation loss of the current epoch.
            model: object saved with ``torch.save`` when the loss improves.

        Returns:
            bool: True when early stopping is enabled (``patience != -1``) and
            the number of counted non-improving epochs exceeds ``patience``.
        """
        if validation_loss < self.min_validation_loss:
            # New best: remember it, reset the counter and checkpoint the model.
            self.min_validation_loss = validation_loss
            self.counter = 0
            torch.save(model, self.saving_path)

        # Worse than the best loss by more than the tolerance. Written as
        # `not (a <= b)` so that a NaN loss, for which every comparison is
        # False, also counts as an epoch without improvement.
        elif not validation_loss <= (self.min_validation_loss + self.tol):
            self.counter += 1

        if self.patience != -1 and self.counter > self.patience:
            print(f"Early stopping after {self.counter} epochs of no improvement.")
            return True
        return False
|