topologicpy 0.8.98__py3-none-any.whl → 0.8.99__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
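The same comparison can be reproduced locally once the two wheels are downloaded. Below is a minimal sketch; the wheel file names are assumptions inferred from the header above, and a wheel is simply a zip archive:

import difflib, zipfile

def wheel_source(wheel_path, member="topologicpy/PyG.py"):
    # Read one module's source text out of the wheel (zip) archive.
    return zipfile.ZipFile(wheel_path).read(member).decode("utf-8").splitlines()

old = wheel_source("topologicpy-0.8.98-py3-none-any.whl")
new = wheel_source("topologicpy-0.8.99-py3-none-any.whl")
print("\n".join(difflib.unified_diff(old, new, "0.8.98", "0.8.99", lineterm="")))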
- topologicpy/ANN.py +1 -1
- topologicpy/Aperture.py +1 -1
- topologicpy/BVH.py +1 -1
- topologicpy/CSG.py +1 -1
- topologicpy/Cell.py +1 -1
- topologicpy/CellComplex.py +1 -1
- topologicpy/Cluster.py +1 -1
- topologicpy/Color.py +1 -1
- topologicpy/Context.py +1 -1
- topologicpy/DGL.py +1 -1
- topologicpy/Dictionary.py +92 -1
- topologicpy/Edge.py +1 -1
- topologicpy/EnergyModel.py +1 -1
- topologicpy/Face.py +1 -1
- topologicpy/Graph.py +887 -4
- topologicpy/Grid.py +1 -1
- topologicpy/Helper.py +1 -1
- topologicpy/Honeybee.py +1 -1
- topologicpy/Matrix.py +1 -1
- topologicpy/Neo4j.py +1 -1
- topologicpy/Plotly.py +1 -1
- topologicpy/Polyskel.py +1 -1
- topologicpy/PyG.py +1287 -2308
- topologicpy/ShapeGrammar.py +1 -1
- topologicpy/Shell.py +1 -1
- topologicpy/Speckle.py +1 -1
- topologicpy/Sun.py +1 -1
- topologicpy/Topology.py +1 -1
- topologicpy/Vector.py +1 -1
- topologicpy/Vertex.py +1 -1
- topologicpy/Wire.py +1 -1
- topologicpy/__init__.py +1 -1
- topologicpy/version.py +1 -1
- {topologicpy-0.8.98.dist-info → topologicpy-0.8.99.dist-info}/METADATA +1 -1
- topologicpy-0.8.99.dist-info/RECORD +39 -0
- topologicpy-0.8.98.dist-info/RECORD +0 -39
- {topologicpy-0.8.98.dist-info → topologicpy-0.8.99.dist-info}/WHEEL +0 -0
- {topologicpy-0.8.98.dist-info → topologicpy-0.8.99.dist-info}/licenses/LICENSE +0 -0
- {topologicpy-0.8.98.dist-info → topologicpy-0.8.99.dist-info}/top_level.txt +0 -0
topologicpy/PyG.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C)
+# Copyright (C) 2026
 # Wassim Jabi <wassim.jabi@gmail.com>
 #
 # This program is free software: you can redistribute it and/or modify it under
@@ -14,2396 +14,1375 @@
 # You should have received a copy of the GNU Affero General Public License along with
 # this program. If not, see <https://www.gnu.org/licenses/>.
 
-
-
-
-import gc
-
-try:
-    import numpy as np
-except:
-    print("PyG - Installing required numpy library.")
-    try:
-        os.system("pip install numpy")
-    except:
-        os.system("pip install numpy --user")
-    try:
-        import numpy as np
-        print("PyG - numpy library installed successfully.")
-    except:
-        warnings.warn("PyG - Error: Could not import numpy.")
-
-try:
-    import pandas as pd
-except:
-    print("PyG - Installing required pandas library.")
-    try:
-        os.system("pip install pandas")
-    except:
-        os.system("pip install pandas --user")
-    try:
-        import numpy as np
-        print("PyG - pandas library installed successfully.")
-    except:
-        warnings.warn("PyG - Error: Could not import pandas.")
-
-try:
-    from tqdm.auto import tqdm
-except:
-    print("PyG - Installing required tqdm library.")
-    try:
-        os.system("pip install tqdm")
-    except:
-        os.system("pip install tqdm --user")
-    try:
-        from tqdm.auto import tqdm
-        print("PyG - tqdm library installed correctly.")
-    except:
-        raise Exception("PyG - Error: Could not import tqdm.")
-
-try:
-    import torch
-    import torch.nn as nn
-    import torch.nn.functional as F
-    from torch.utils.data.sampler import SubsetRandomSampler
-except:
-    print("PyG - Installing required torch library.")
-    try:
-        os.system("pip install torch")
-    except:
-        os.system("pip install torch --user")
-    try:
-        import torch
-        import torch.nn as nn
-        import torch.nn.functional as F
-        from torch.utils.data.sampler import SubsetRandomSampler
-        print("PyG - torch library installed correctly.")
-    except:
-        warnings.warn("PyG - Error: Could not import torch.")
-
-try:
-    from torch_geometric.data import Data, Dataset
-    from torch_geometric.loader import DataLoader
-    from torch_geometric.nn import SAGEConv, global_mean_pool, global_max_pool, global_add_pool
-except:
-    print("PyG - Installing required torch_geometric library.")
-    try:
-        os.system("pip install torch_geometric")
-    except:
-        os.system("pip install torch_geometric --user")
-    try:
-        from torch_geometric.data import Data, Dataset
-        from torch_geometric.loader import DataLoader
-        from torch_geometric.nn import SAGEConv, global_mean_pool, global_max_pool, global_add_pool
-        print("PyG - torch_geometric library installed correctly.")
-    except:
-        warnings.warn("PyG - Error: Could not import torch.")
-
-try:
-    from sklearn.model_selection import KFold
-    from sklearn.metrics import accuracy_score
-except:
-    print("PyG - Installing required scikit-learn library.")
-    try:
-        os.system("pip install -U scikit-learn")
-    except:
-        os.system("pip install -U scikit-learn --user")
-    try:
-        from sklearn.model_selection import KFold
-        from sklearn.metrics import accuracy_score
-        print("PyG - scikit-learn library installed correctly.")
-    except:
-        warnings.warn("PyG - Error: Could not import scikit. Please install it manually.")
-
-class CustomGraphDataset(Dataset):
-    def __init__(self, root=None, data_list=None, indices=None, node_level=False, graph_level=True,
-                 node_attr_key='feat', edge_attr_key='feat'):
-        """
-        Initializes the CustomGraphDataset.
-
-        Parameters:
-        - root: Root directory of the dataset (used only if data_list is None)
-        - data_list: List of preprocessed data objects (used if provided)
-        - indices: List of indices to select a subset of the data
-        - node_level: Boolean flag indicating if the dataset is node-level
-        - graph_level: Boolean flag indicating if the dataset is graph-level
-        - node_attr_key: Key for node attributes
-        - edge_attr_key: Key for edge attributes
-        """
-        assert not (node_level and graph_level), "Both node_level and graph_level cannot be True at the same time"
-        assert node_level or graph_level, "Both node_level and graph_level cannot be False at the same time"
-
-        self.node_level = node_level
-        self.graph_level = graph_level
-        self.node_attr_key = node_attr_key
-        self.edge_attr_key = edge_attr_key
-
-        if data_list is not None:
-            self.data_list = data_list  # Use the provided data list
-        elif root is not None:
-            # Load and process data from root directory if data_list is not provided
-            self.graph_df = pd.read_csv(os.path.join(root, 'graphs.csv'))
-            self.nodes_df = pd.read_csv(os.path.join(root, 'nodes.csv'))
-            self.edges_df = pd.read_csv(os.path.join(root, 'edges.csv'))
-            self.data_list = self.process_all()
-        else:
-            raise ValueError("Either a root directory or a data_list must be provided.")
-
-        # Filter data_list based on indices if provided
-        if indices is not None:
-            self.data_list = [self.data_list[i] for i in indices]
-
-    def process_all(self):
-        data_list = []
-        for graph_id in self.graph_df['graph_id'].unique():
-            graph_nodes = self.nodes_df[self.nodes_df['graph_id'] == graph_id]
-            graph_edges = self.edges_df[self.edges_df['graph_id'] == graph_id]
-
-            if self.node_attr_key in graph_nodes.columns and not graph_nodes[self.node_attr_key].isnull().all():
-                x = torch.tensor(graph_nodes[self.node_attr_key].values.tolist(), dtype=torch.float)
-                if x.ndim == 1:
-                    x = x.unsqueeze(1)  # Ensure x has shape [num_nodes, *]
-            else:
-                x = None
-
-            edge_index = torch.tensor(graph_edges[['src_id', 'dst_id']].values.T, dtype=torch.long)
-
-            if self.edge_attr_key in graph_edges.columns and not graph_edges[self.edge_attr_key].isnull().all():
-                edge_attr = torch.tensor(graph_edges[self.edge_attr_key].values.tolist(), dtype=torch.float)
-            else:
-                edge_attr = None
-
-            if self.graph_level:
-                label_value = self.graph_df[self.graph_df['graph_id'] == graph_id]['label'].values[0]
-
-                if isinstance(label_value, np.int64):
-                    label_value = int(label_value)
-                if isinstance(label_value, np.float64):
-                    label_value = float(label_value)
-
-                if isinstance(label_value, int) or isinstance(label_value, np.int64):
-                    y = torch.tensor([label_value], dtype=torch.long)
-                elif isinstance(label_value, float):
-                    y = torch.tensor([label_value], dtype=torch.float)
-                else:
-                    raise ValueError(f"Unexpected label type: {type(label_value)}. Expected int or float.")
-
-            elif self.node_level:
-                label_values = graph_nodes['label'].values
-
-                if issubclass(label_values.dtype.type, int):
-                    y = torch.tensor(label_values, dtype=torch.long)
-                elif issubclass(label_values.dtype.type, float):
-                    y = torch.tensor(label_values, dtype=torch.float)
-                else:
-                    raise ValueError(f"Unexpected label types: {label_values.dtype}. Expected int or float.")
-
-            data = Data(x=x, edge_index=edge_index, y=y)
-            if edge_attr is not None:
-                data.edge_attr = edge_attr
-
-            data_list.append(data)
-
-        return data_list
-
-    def __len__(self):
-        return len(self.data_list)
-
-    def __getitem__(self, idx):
-        return self.data_list[idx]
+"""
+TopologicPy: PyTorch Geometric (PyG) helper class
+=================================================
 
+This module provides a clean, beginner-friendly interface to:
 
+1) Load TopologicPy-exported CSV datasets (graphs.csv, nodes.csv, edges.csv)
+2) Train / validate / test models for:
+   - Graph-level prediction (classification or regression)
+   - Node-level prediction (classification or regression)
+   - Edge-level prediction (classification or regression)
+   - Link prediction (binary edge existence)
 
+3) Report performance metrics and interactive Plotly visualisations.
 
-
-
-
-
-
-
+User-controlled hyperparameters (medium-level)
+----------------------------------------------
+- Cross-validation: holdout or k-fold (graph-level)
+- Network topology: number of hidden layers and neurons per layer (hidden_dims)
+- GNN backbone: conv type (sage/gcn/gatv2), activation, dropout, batch_norm, residual
+- Training: epochs, batch_size, lr, weight_decay, optimizer (adam/adamw),
+  gradient clipping, early stopping
 
-
-
-
-
+CSV assumptions
+---------------
+- graphs.csv contains at least: graph_id, label, and optional graph feature columns:
+  feat_0, feat_1, ...
 
-
-
-# self.edges_df = pd.read_csv(os.path.join(root, 'edges.csv'))
+- nodes.csv contains at least: graph_id, node_id, label, optional masks, and feature columns:
+  feat_0, feat_1, ...
 
-
+- edges.csv contains at least: graph_id, src_id, dst_id, label, optional masks, and feature columns:
+  feat_0, feat_1, ...
 
-
-
-
+Notes
+-----
+- This module intentionally avoids auto-install behaviour.
+- It aims to be easy to read and modify by non-ML experts.
 
-
-# data_list = []
-# for graph_id in self.graph_df['graph_id'].unique():
-#     graph_nodes = self.nodes_df[self.nodes_df['graph_id'] == graph_id]
-#     graph_edges = self.edges_df[self.edges_df['graph_id'] == graph_id]
+"""
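To make the CSV layout described in the new docstring concrete, here is a hypothetical two-graph dataset written with pandas; the column names come from the docstring above, the values are invented for illustration:

import pandas as pd

# Two toy triangle graphs, one binary graph label each.
pd.DataFrame({"graph_id": [0, 1], "label": [0, 1]}).to_csv("graphs.csv", index=False)
pd.DataFrame({
    "graph_id": [0, 0, 0, 1, 1, 1],
    "node_id":  [0, 1, 2, 0, 1, 2],
    "label":    [0, 0, 0, 1, 1, 1],
    "feat_0":   [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
}).to_csv("nodes.csv", index=False)
pd.DataFrame({
    "graph_id": [0, 0, 0, 1, 1, 1],
    "src_id":   [0, 1, 2, 0, 1, 2],
    "dst_id":   [1, 2, 0, 1, 2, 0],
    "label":    [0, 0, 0, 1, 1, 1],
}).to_csv("edges.csv", index=False)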
 
-
-# x = torch.tensor(graph_nodes[self.node_attr_key].values.tolist(), dtype=torch.float)
-# if x.ndim == 1:
-#     x = x.unsqueeze(1) # Ensure x has shape [num_nodes, *]
-# else:
-#     x = None
+from __future__ import annotations
 
-
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union, Literal
 
-
-
-
-
-
-
-
-# if self.graph_level:
-#     label_value = self.graph_df[self.graph_df['graph_id'] == graph_id]['label'].values[0]
-
-# # Check if the label is an integer or a float and cast accordingly
-# if isinstance(label_value, int):
-#     y = torch.tensor([label_value], dtype=torch.long)
-# elif isinstance(label_value, float):
-#     y = torch.tensor([label_value], dtype=torch.float)
-# else:
-#     raise ValueError(f"Unexpected label type: {type(label_value)}. Expected int or float.")
-
-# elif self.node_level:
-#     label_values = graph_nodes['label'].values
-
-# # Check if the labels are integers or floats and cast accordingly
-# if issubclass(label_values.dtype.type, int):
-#     y = torch.tensor(label_values, dtype=torch.long)
-# elif issubclass(label_values.dtype.type, float):
-#     y = torch.tensor(label_values, dtype=torch.float)
-# else:
-#     raise ValueError(f"Unexpected label types: {label_values.dtype}. Expected int or float.")
-
-
-# # if self.graph_level:
-# #     y = torch.tensor([self.graph_df[self.graph_df['graph_id'] == graph_id]['label'].values[0]], dtype=torch.long)
-# # elif self.node_level:
-# #     y = torch.tensor(graph_nodes['label'].values, dtype=torch.long)
+import os
+import math
+import random
+import copy
 
-
-
-
+import numpy as np
+import pandas as pd
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import (
+    SAGEConv, GCNConv, GATv2Conv,
+    global_mean_pool, global_max_pool, global_add_pool
+)
+from torch_geometric.transforms import RandomLinkSplit
+
+from sklearn.metrics import (
+    accuracy_score, precision_recall_fscore_support, confusion_matrix,
+    mean_absolute_error, mean_squared_error, r2_score
+)
+
+import plotly.graph_objects as go
+import plotly.express as px
+
+
+LabelType = Literal["categorical", "continuous"]
+Level = Literal["graph", "node", "edge", "link"]
+TaskKind = Literal["classification", "regression", "link_prediction"]
+ConvKind = Literal["sage", "gcn", "gatv2"]
+PoolingKind = Literal["mean", "max", "sum"]
+
+
+@dataclass
+class _RunConfig:
+    # ----------------------------
+    # Task selection
+    # ----------------------------
+    level: Level = "graph"  # "graph" | "node" | "edge" | "link"
+    task: TaskKind = "classification"  # "classification" | "regression" | "link_prediction"
+
+    # label types (graph/node/edge)
+    graph_label_type: LabelType = "categorical"
+    node_label_type: LabelType = "categorical"
+    edge_label_type: LabelType = "categorical"
+
+    # ----------------------------
+    # CSV headers
+    # ----------------------------
+    graph_id_header: str = "graph_id"
+    graph_label_header: str = "label"
+    graph_features_header: str = "feat"
+
+    node_id_header: str = "node_id"
+    node_label_header: str = "label"
+    node_features_header: str = "feat"
+
+    edge_src_header: str = "src_id"
+    edge_dst_header: str = "dst_id"
+    edge_label_header: str = "label"
+    edge_features_header: str = "feat"
+
+    # masks (optional)
+    node_train_mask_header: str = "train_mask"
+    node_val_mask_header: str = "val_mask"
+    node_test_mask_header: str = "test_mask"
+
+    edge_train_mask_header: str = "train_mask"
+    edge_val_mask_header: str = "val_mask"
+    edge_test_mask_header: str = "test_mask"
+
+    # ----------------------------
+    # Cross-validation / splitting
+    # ----------------------------
+    cv: Literal["holdout", "kfold"] = "holdout"
+    split: Tuple[float, float, float] = (0.8, 0.1, 0.1)  # used for holdout
+    k_folds: int = 5  # used for kfold (graph-level only)
+    k_shuffle: bool = True
+    k_stratify: bool = True  # only if categorical labels exist
+    random_state: int = 42
+    shuffle: bool = True  # affects holdout + in-graph mask fallback
+
+    # link prediction split (within each graph)
+    link_val_ratio: float = 0.1
+    link_test_ratio: float = 0.1
+    link_is_undirected: bool = False
+
+    # ----------------------------
+    # Training hyperparameters
+    # ----------------------------
+    epochs: int = 50
+    batch_size: int = 32
+    lr: float = 1e-3
+    weight_decay: float = 0.0
+    optimizer: Literal["adam", "adamw"] = "adam"
+    gradient_clip_norm: Optional[float] = None
+    early_stopping: bool = False
+    early_stopping_patience: int = 10
+    use_gpu: bool = True
+
+    # ----------------------------
+    # Network topology / model hyperparameters
+    # ----------------------------
+    conv: ConvKind = "sage"
+    hidden_dims: Tuple[int, ...] = (64, 64)  # explicit per-layer widths (controls depth)
+    activation: Literal["relu", "gelu", "elu"] = "relu"
+    dropout: float = 0.1
+    batch_norm: bool = False
+    residual: bool = False
+    pooling: PoolingKind = "mean"  # only for graph-level
+
+
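Every hyperparameter named in the docstring funnels through this dataclass. A hypothetical override of its defaults (all field names and allowed values are taken from the fields shown above; the chosen values are illustrative only) might look like:

# Sketch, assuming the module-private _RunConfig is reachable as shown above.
cfg = _RunConfig(
    level="graph",
    task="classification",
    cv="kfold",
    k_folds=5,
    conv="gatv2",
    hidden_dims=(128, 128, 64),  # three conv layers; widths control depth
    dropout=0.2,
    optimizer="adamw",
    early_stopping=True,
)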
+class _GNNBackbone(nn.Module):
+    """
+    Shared GNN encoder that produces node embeddings.
+    """
+
+    def __init__(self,
+                 in_dim: int,
+                 hidden_dims: Tuple[int, ...],
+                 conv: ConvKind = "sage",
+                 activation: str = "relu",
+                 dropout: float = 0.1,
+                 batch_norm: bool = False,
+                 residual: bool = False):
+        super().__init__()
+        if in_dim <= 0:
+            raise ValueError("in_dim must be > 0. Your dataset has no node features columns.")
+        if hidden_dims is None or len(hidden_dims) == 0:
+            raise ValueError("hidden_dims must contain at least one layer width, e.g. (64, 64).")
+
+        self.dropout = float(dropout)
+        self.use_bn = bool(batch_norm)
+        self.use_residual = bool(residual)
+
+        if activation == "relu":
+            self.act = F.relu
+        elif activation == "gelu":
+            self.act = F.gelu
+        elif activation == "elu":
+            self.act = F.elu
+        else:
+            raise ValueError("Unsupported activation. Use 'relu', 'gelu', or 'elu'.")
 
-
+        dims = [int(in_dim)] + [int(d) for d in hidden_dims]
 
-
+        self.convs = nn.ModuleList()
+        self.bns = nn.ModuleList()
 
-
-
+        for i in range(1, len(dims)):
+            in_ch, out_ch = dims[i - 1], dims[i]
 
-
-
+            if conv == "sage":
+                self.convs.append(SAGEConv(in_ch, out_ch))
+            elif conv == "gcn":
+                self.convs.append(GCNConv(in_ch, out_ch))
+            elif conv == "gatv2":
+                self.convs.append(GATv2Conv(in_ch, out_ch, heads=1, concat=False))
+            else:
+                raise ValueError(f"Unsupported conv='{conv}'.")
 
-
-
+            if self.use_bn:
+                self.bns.append(nn.BatchNorm1d(out_ch))
 
-
-    def __init__(self, model_type="ClassifierHoldout", optimizer_str="Adam", amsgrad=False, betas=(0.9, 0.999), eps=1e-6, lr=0.001, lr_decay= 0, maximize=False, rho=0.9, weight_decay=0, cv_type="Holdout", split=[0.8,0.1, 0.1], k_folds=5, hl_widths=[32], conv_layer_type='SAGEConv', pooling="AvgPooling", batch_size=32, epochs=1,
-                 use_gpu=False, loss_function="Cross Entropy", input_type="graph"):
-        """
-        Parameters
-        ----------
-        cv : str
-            A string to define the method of cross-validation
-            "Holdout": Holdout
-            "K-Fold": K-Fold cross validation
-        k_folds : int
-            An int value in the range of 2 to X to define the number of k-folds for cross-validation. Default is 5.
-        split : list
-            A list of three item in the range of 0 to 1 to define the split of train,
-            validate, and test data. A default value of [0.8,0.1,0.1] means 80% of data will be
-            used for training, 10% will be used for validation, and the remaining 10% will be used for training
-        hl_widths : list
-            List of hidden neurons for each layer such as [32] will mean
-            that there is one hidden layers in the network with 32 neurons
-        optimizer : torch.optim object
-            This will be the selected optimizer from torch.optim package. By
-            default, torch.optim.Adam is selected
-        learning_rate : float
-            a step value to be used to apply the gradients by optimizer
-        batch_size : int
-            to define a set of samples to be used for training and testing in
-            each step of an epoch
-        epochs : int
-            An epoch means training the neural network with all the training data for one cycle. In an epoch, we use all of the data exactly once. A forward pass and a backward pass together are counted as one pass
-        use_GPU : use the GPU. Otherwise, use the CPU
-        input_type : str
-            selects the input_type of model such as graph, node or edge
+        self.out_dim = dims[-1]
 
-
-
-
+    def forward(self, x, edge_index):
+        h = x
+        for i, conv in enumerate(self.convs):
+            h_in = h
+            h = conv(h, edge_index)
+            if self.use_bn:
+                h = self.bns[i](h)
+            h = self.act(h)
 
-
-
-        self.model_type = model_type
-        self.optimizer_str = optimizer_str
-        self.amsgrad = amsgrad
-        self.betas = betas
-        self.eps = eps
-        self.lr = lr
-        self.lr_decay = lr_decay
-        self.maximize = maximize
-        self.rho = rho
-        self.weight_decay = weight_decay
-        self.cv_type = cv_type
-        self.split = split
-        self.k_folds = k_folds
-        self.hl_widths = hl_widths
-        self.conv_layer_type = conv_layer_type
-        self.pooling = pooling
-        self.batch_size = batch_size
-        self.epochs = epochs
-        self.use_gpu = use_gpu
-        self.loss_function = loss_function
-        self.input_type = input_type
-
-class _SAGEConv(nn.Module):
-    def __init__(self, in_feats, h_feats, num_classes, pooling=None):
-        super(_SAGEConv, self).__init__()
-        assert isinstance(h_feats, list), "h_feats must be a list"
-        h_feats = [x for x in h_feats if x is not None]
-        assert len(h_feats) != 0, "h_feats is empty. unable to add hidden layers"
-        self.list_of_layers = nn.ModuleList()
-        dim = [in_feats] + h_feats
-
-        # Convolution (Hidden) Layers
-        for i in range(1, len(dim)):
-            self.list_of_layers.append(SAGEConv(dim[i-1], dim[i]))
-
-        # Final Layer
-        self.final = nn.Linear(dim[-1], num_classes)
-
-        # Pooling layer
-        if pooling is None:
-            self.pooling_layer = None
-        else:
-            if "av" in pooling.lower():
-                self.pooling_layer = global_mean_pool
-            elif "max" in pooling.lower():
-                self.pooling_layer = global_max_pool
-            elif "sum" in pooling.lower():
-                self.pooling_layer = global_add_pool
-            else:
-                raise NotImplementedError
+            if self.use_residual and h_in.shape == h.shape:
+                h = h + h_in
 
-
-        x, edge_index, batch = data.x, data.edge_index, data.batch
-        h = x
-        # Generate node features
-        for layer in self.list_of_layers:
-            h = layer(h, edge_index)
-            h = F.relu(h)
-        # h will now be a matrix of dimension [num_nodes, h_feats[-1]]
-        h = self.final(h)
-        # Go from node-level features to graph-level features by pooling
-        if self.pooling_layer:
-            h = self.pooling_layer(h, batch)
-        # h will now be a vector of dimension [num_classes]
+            h = F.dropout(h, p=self.dropout, training=self.training)
         return h
 
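A quick sanity check of the encoder added above, using toy tensors whose shapes are invented for illustration:

import torch

# Four nodes with 3 features each, connected in a directed 4-cycle.
backbone = _GNNBackbone(in_dim=3, hidden_dims=(16, 16), conv="sage",
                        batch_norm=True, residual=True)
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)
h = backbone(x, edge_index)        # node embeddings
print(h.shape, backbone.out_dim)   # torch.Size([4, 16]) 16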
412
|
-
class _GraphRegressorHoldout:
|
|
413
|
-
def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
|
|
414
|
-
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
415
|
-
self.trainingDataset = trainingDataset
|
|
416
|
-
self.validationDataset = validationDataset
|
|
417
|
-
self.testingDataset = testingDataset
|
|
418
|
-
self.hparams = hparams
|
|
419
|
-
if hparams.conv_layer_type.lower() == 'sageconv':
|
|
420
|
-
self.model = _SAGEConv(trainingDataset[0].num_node_features, hparams.hl_widths, 1, hparams.pooling).to(self.device)
|
|
421
|
-
else:
|
|
422
|
-
raise NotImplementedError
|
|
423
|
-
|
|
424
|
-
if hparams.optimizer_str.lower() == "adadelta":
|
|
425
|
-
self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
|
|
426
|
-
lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
|
|
427
|
-
elif hparams.optimizer_str.lower() == "adagrad":
|
|
428
|
-
self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
|
|
429
|
-
lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
|
|
430
|
-
elif hparams.optimizer_str.lower() == "adam":
|
|
431
|
-
self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
|
|
432
|
-
lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
|
|
433
|
-
|
|
434
|
-
self.use_gpu = hparams.use_gpu
|
|
435
|
-
self.training_loss_list = []
|
|
436
|
-
self.validation_loss_list = []
|
|
437
|
-
self.node_attr_key = trainingDataset[0].x.shape[1]
|
|
438
|
-
|
|
439
|
-
# Train, validate, test split
|
|
440
|
-
num_train = int(len(trainingDataset) * hparams.split[0])
|
|
441
|
-
num_validate = int(len(trainingDataset) * hparams.split[1])
|
|
442
|
-
num_test = len(trainingDataset) - num_train - num_validate
|
|
443
|
-
idx = torch.randperm(len(trainingDataset))
|
|
444
|
-
train_sampler = SubsetRandomSampler(idx[:num_train])
|
|
445
|
-
validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
|
|
446
|
-
test_sampler = SubsetRandomSampler(idx[num_train+num_validate:])
|
|
447
|
-
|
|
448
|
-
if validationDataset:
|
|
449
|
-
self.train_dataloader = DataLoader(trainingDataset,
|
|
450
|
-
batch_size=hparams.batch_size,
|
|
451
|
-
drop_last=False)
|
|
452
|
-
self.validate_dataloader = DataLoader(validationDataset,
|
|
453
|
-
batch_size=hparams.batch_size,
|
|
454
|
-
drop_last=False)
|
|
455
|
-
else:
|
|
456
|
-
self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler,
|
|
457
|
-
batch_size=hparams.batch_size,
|
|
458
|
-
drop_last=False)
|
|
459
|
-
self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler,
|
|
460
|
-
batch_size=hparams.batch_size,
|
|
461
|
-
drop_last=False)
|
|
462
|
-
|
|
463
|
-
if testingDataset:
|
|
464
|
-
self.test_dataloader = DataLoader(testingDataset,
|
|
465
|
-
batch_size=len(testingDataset),
|
|
466
|
-
drop_last=False)
|
|
467
|
-
else:
|
|
468
|
-
self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler,
|
|
469
|
-
batch_size=hparams.batch_size,
|
|
470
|
-
drop_last=False)
|
|
471
|
-
|
|
472
|
-
def train(self):
|
|
473
|
-
# Init the loss and accuracy reporting lists
|
|
474
|
-
self.training_loss_list = []
|
|
475
|
-
self.validation_loss_list = []
|
|
476
|
-
|
|
477
|
-
# Run the training loop for defined number of epochs
|
|
478
|
-
for _ in tqdm(range(self.hparams.epochs), desc='Epochs', total=self.hparams.epochs, leave=False):
|
|
479
|
-
# Iterate over the DataLoader for training data
|
|
480
|
-
for data in tqdm(self.train_dataloader, desc='Training', leave=False):
|
|
481
|
-
data = data.to(self.device)
|
|
482
|
-
# Make sure the model is in training mode
|
|
483
|
-
self.model.train()
|
|
484
|
-
# Zero the gradients
|
|
485
|
-
self.optimizer.zero_grad()
|
|
486
|
-
|
|
487
|
-
# Perform forward pass
|
|
488
|
-
pred = self.model(data).to(self.device)
|
|
489
|
-
# Compute loss
|
|
490
|
-
loss = F.mse_loss(torch.flatten(pred), data.y.float())
|
|
491
|
-
|
|
492
|
-
# Perform backward pass
|
|
493
|
-
loss.backward()
|
|
494
|
-
|
|
495
|
-
# Perform optimization
|
|
496
|
-
self.optimizer.step()
|
|
497
|
-
|
|
498
|
-
self.training_loss_list.append(torch.sqrt(loss).item())
|
|
499
|
-
self.validate()
|
|
500
|
-
self.validation_loss_list.append(torch.sqrt(self.validation_loss).item())
|
|
501
|
-
gc.collect()
|
|
502
|
-
|
|
503
|
-
def validate(self):
|
|
504
|
-
self.model.eval()
|
|
505
|
-
for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
|
|
506
|
-
data = data.to(self.device)
|
|
507
|
-
pred = self.model(data).to(self.device)
|
|
508
|
-
loss = F.mse_loss(torch.flatten(pred), data.y.float())
|
|
509
|
-
self.validation_loss = loss
|
|
510
|
-
|
|
511
|
-
def test(self):
|
|
512
|
-
self.model.eval()
|
|
513
|
-
for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
|
|
514
|
-
data = data.to(self.device)
|
|
515
|
-
pred = self.model(data).to(self.device)
|
|
516
|
-
loss = F.mse_loss(torch.flatten(pred), data.y.float())
|
|
517
|
-
self.testing_loss = torch.sqrt(loss).item()
|
|
518
|
-
|
|
519
|
-
def save(self, path):
|
|
520
|
-
if path:
|
|
521
|
-
# Make sure the file extension is .pt
|
|
522
|
-
ext = path[-3:]
|
|
523
|
-
if ext.lower() != ".pt":
|
|
524
|
-
path = path + ".pt"
|
|
525
|
-
torch.save(self.model.state_dict(), path)
|
|
526
|
-
|
|
527
|
-
def load(self, path):
|
|
528
|
-
#self.model.load_state_dict(torch.load(path))
|
|
529
|
-
self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
|
|
530
|
-
|
|
531
|
-
class _GraphRegressorKFold:
|
|
532
|
-
def __init__(self, hparams, trainingDataset, testingDataset=None):
|
|
533
|
-
self.trainingDataset = trainingDataset
|
|
534
|
-
self.testingDataset = testingDataset
|
|
535
|
-
self.hparams = hparams
|
|
536
|
-
self.losses = []
|
|
537
|
-
self.min_loss = 0
|
|
538
|
-
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
539
|
-
|
|
540
|
-
self.model = self._initialize_model(hparams, trainingDataset)
|
|
541
|
-
self.optimizer = self._initialize_optimizer(hparams)
|
|
542
|
-
|
|
543
|
-
self.use_gpu = hparams.use_gpu
|
|
544
|
-
self.training_loss_list = []
|
|
545
|
-
self.validation_loss_list = []
|
|
546
|
-
self.node_attr_key = trainingDataset.node_attr_key
|
|
547
|
-
|
|
548
|
-
# Train, validate, test split
|
|
549
|
-
num_train = int(len(trainingDataset) * hparams.split[0])
|
|
550
|
-
num_validate = int(len(trainingDataset) * hparams.split[1])
|
|
551
|
-
num_test = len(trainingDataset) - num_train - num_validate
|
|
552
|
-
idx = torch.randperm(len(trainingDataset))
|
|
553
|
-
test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])
|
|
554
|
-
|
|
555
|
-
if testingDataset:
|
|
556
|
-
self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
|
|
557
|
-
else:
|
|
558
|
-
self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)
|
|
559
|
-
|
|
560
|
-
def _initialize_model(self, hparams, dataset):
|
|
561
|
-
if hparams.conv_layer_type.lower() == 'sageconv':
|
|
562
|
-
return _SAGEConv(dataset[0].num_node_features, hparams.hl_widths, 1, hparams.pooling).to(self.device)
|
|
563
|
-
#return _SAGEConv(dataset.num_node_features, hparams.hl_widths, 1, hparams.pooling).to(self.device)
|
|
564
|
-
else:
|
|
565
|
-
raise NotImplementedError
|
|
566
|
-
|
|
567
|
-
def _initialize_optimizer(self, hparams):
|
|
568
|
-
if hparams.optimizer_str.lower() == "adadelta":
|
|
569
|
-
return torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps, lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
|
|
570
|
-
elif hparams.optimizer_str.lower() == "adagrad":
|
|
571
|
-
return torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps, lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
|
|
572
|
-
elif hparams.optimizer_str.lower() == "adam":
|
|
573
|
-
return torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps, lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
|
|
574
|
-
|
|
575
|
-
def reset_weights(self):
|
|
576
|
-
self.model = self._initialize_model(self.hparams, self.trainingDataset)
|
|
577
|
-
self.optimizer = self._initialize_optimizer(self.hparams)
|
|
578
|
-
|
|
579
|
-
def train(self):
|
|
580
|
-
k_folds = self.hparams.k_folds
|
|
581
|
-
torch.manual_seed(42)
|
|
582
|
-
|
|
583
|
-
kfold = KFold(n_splits=k_folds, shuffle=True)
|
|
584
|
-
models, weights, losses, train_dataloaders, validate_dataloaders = [], [], [], [], []
|
|
585
|
-
|
|
586
|
-
for fold, (train_ids, validate_ids) in tqdm(enumerate(kfold.split(self.trainingDataset)), desc="Fold", total=k_folds, leave=False):
|
|
587
|
-
epoch_training_loss_list, epoch_validation_loss_list = [], []
|
|
588
|
-
train_subsampler = SubsetRandomSampler(train_ids)
|
|
589
|
-
validate_subsampler = SubsetRandomSampler(validate_ids)
|
|
590
|
-
|
|
591
|
-
self.train_dataloader = DataLoader(self.trainingDataset, sampler=train_subsampler, batch_size=self.hparams.batch_size, drop_last=False)
|
|
592
|
-
self.validate_dataloader = DataLoader(self.trainingDataset, sampler=validate_subsampler, batch_size=self.hparams.batch_size, drop_last=False)
|
|
593
|
-
|
|
594
|
-
self.reset_weights()
|
|
595
|
-
best_rmse = np.inf
|
|
596
|
-
|
|
597
|
-
for _ in tqdm(range(self.hparams.epochs), desc='Epochs', total=self.hparams.epochs, leave=False):
|
|
598
|
-
for batched_graph in tqdm(self.train_dataloader, desc='Training', leave=False):
|
|
599
|
-
self.model.train()
|
|
600
|
-
self.optimizer.zero_grad()
|
|
601
|
-
|
|
602
|
-
batched_graph = batched_graph.to(self.device)
|
|
603
|
-
pred = self.model(batched_graph)
|
|
604
|
-
loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
|
|
605
|
-
loss.backward()
|
|
606
|
-
self.optimizer.step()
|
|
607
|
-
|
|
608
|
-
epoch_training_loss_list.append(torch.sqrt(loss).item())
|
|
609
|
-
self.validate()
|
|
610
|
-
epoch_validation_loss_list.append(torch.sqrt(self.validation_loss).item())
|
|
611
|
-
gc.collect()
|
|
612
|
-
|
|
613
|
-
models.append(self.model)
|
|
614
|
-
weights.append(copy.deepcopy(self.model.state_dict()))
|
|
615
|
-
losses.append(torch.sqrt(self.validation_loss).item())
|
|
616
|
-
train_dataloaders.append(self.train_dataloader)
|
|
617
|
-
validate_dataloaders.append(self.validate_dataloader)
|
|
618
|
-
self.training_loss_list.append(epoch_training_loss_list)
|
|
619
|
-
self.validation_loss_list.append(epoch_validation_loss_list)
|
|
620
|
-
|
|
621
|
-
self.losses = losses
|
|
622
|
-
self.min_loss = min(losses)
|
|
623
|
-
ind = losses.index(self.min_loss)
|
|
624
|
-
self.model = models[ind]
|
|
625
|
-
self.model.load_state_dict(weights[ind])
|
|
626
|
-
self.model.eval()
|
|
627
|
-
self.training_loss_list = self.training_loss_list[ind]
|
|
628
|
-
self.validation_loss_list = self.validation_loss_list[ind]
|
|
629
247
|
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
batched_graph = batched_graph.to(self.device)
|
|
642
|
-
pred = self.model(batched_graph)
|
|
643
|
-
loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
|
|
644
|
-
self.testing_loss = torch.sqrt(loss).item()
|
|
645
|
-
|
|
646
|
-
def save(self, path):
|
|
647
|
-
if path:
|
|
648
|
-
ext = path[-3:]
|
|
649
|
-
if ext.lower() != ".pt":
|
|
650
|
-
path = path + ".pt"
|
|
651
|
-
torch.save(self.model.state_dict(), path)
|
|
652
|
-
|
|
653
|
-
def load(self, path):
|
|
654
|
-
self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
|
|
655
|
-
|
|
656
|
-
class _GraphClassifierKFold:
|
|
657
|
-
def __init__(self, hparams, trainingDataset, testingDataset=None):
|
|
658
|
-
self.trainingDataset = trainingDataset
|
|
659
|
-
self.testingDataset = testingDataset
|
|
660
|
-
self.hparams = hparams
|
|
661
|
-
self.testing_accuracy = 0
|
|
662
|
-
self.accuracies = []
|
|
663
|
-
self.max_accuracy = 0
|
|
664
|
-
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
665
|
-
|
|
666
|
-
if hparams.conv_layer_type.lower() == 'sageconv':
|
|
667
|
-
self.model = _SAGEConv(trainingDataset.num_node_features, hparams.hl_widths,
|
|
668
|
-
trainingDataset.num_classes, hparams.pooling).to(self.device)
|
|
669
|
-
else:
|
|
670
|
-
raise NotImplementedError
|
|
671
|
-
|
|
672
|
-
if hparams.optimizer_str.lower() == "adadelta":
|
|
673
|
-
self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
|
|
674
|
-
lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
|
|
675
|
-
elif hparams.optimizer_str.lower() == "adagrad":
|
|
676
|
-
self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
|
|
677
|
-
lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
|
|
678
|
-
elif hparams.optimizer_str.lower() == "adam":
|
|
679
|
-
self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
|
|
680
|
-
lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
|
|
681
|
-
self.use_gpu = hparams.use_gpu
|
|
682
|
-
self.training_loss_list = []
|
|
683
|
-
self.validation_loss_list = []
|
|
684
|
-
self.training_accuracy_list = []
|
|
685
|
-
self.validation_accuracy_list = []
|
|
686
|
-
|
|
687
|
-
def reset_weights(self):
|
|
688
|
-
if self.hparams.conv_layer_type.lower() == 'sageconv':
|
|
689
|
-
self.model = _SAGEConv(self.trainingDataset.num_node_features, self.hparams.hl_widths,
|
|
690
|
-
self.trainingDataset.num_classes, self.hparams.pooling).to(self.device)
|
|
691
|
-
else:
|
|
692
|
-
raise NotImplementedError
|
|
693
|
-
|
|
694
|
-
if self.hparams.optimizer_str.lower() == "adadelta":
|
|
695
|
-
self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=self.hparams.eps,
|
|
696
|
-
lr=self.hparams.lr, rho=self.hparams.rho, weight_decay=self.hparams.weight_decay)
|
|
697
|
-
elif self.hparams.optimizer_str.lower() == "adagrad":
|
|
698
|
-
self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=self.hparams.eps,
|
|
699
|
-
lr=self.hparams.lr, lr_decay=self.hparams.lr_decay, weight_decay=self.hparams.weight_decay)
|
|
700
|
-
elif self.hparams.optimizer_str.lower() == "adam":
|
|
701
|
-
self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=self.hparams.amsgrad, betas=self.hparams.betas, eps=self.hparams.eps,
|
|
702
|
-
lr=self.hparams.lr, maximize=self.hparams.maximize, weight_decay=self.hparams.weight_decay)
|
|
703
|
-
|
|
704
|
-
def train(self):
|
|
705
|
-
k_folds = self.hparams.k_folds
|
|
706
|
-
|
|
707
|
-
# Init the loss and accuracy reporting lists
|
|
708
|
-
self.training_accuracy_list = []
|
|
709
|
-
self.training_loss_list = []
|
|
710
|
-
self.validation_accuracy_list = []
|
|
711
|
-
self.validation_loss_list = []
|
|
712
|
-
|
|
713
|
-
# Set fixed random number seed
|
|
714
|
-
torch.manual_seed(42)
|
|
715
|
-
|
|
716
|
-
# Define the K-fold Cross Validator
|
|
717
|
-
kfold = KFold(n_splits=k_folds, shuffle=True)
|
|
718
|
-
|
|
719
|
-
models = []
|
|
720
|
-
weights = []
|
|
721
|
-
accuracies = []
|
|
722
|
-
train_dataloaders = []
|
|
723
|
-
validate_dataloaders = []
|
|
724
|
-
|
|
725
|
-
# K-fold Cross-validation model evaluation
|
|
726
|
-
for fold, (train_ids, validate_ids) in tqdm(enumerate(kfold.split(self.trainingDataset)), desc="Fold", initial=1, total=k_folds, leave=False):
|
|
727
|
-
epoch_training_loss_list = []
|
|
728
|
-
epoch_training_accuracy_list = []
|
|
729
|
-
epoch_validation_loss_list = []
|
|
730
|
-
epoch_validation_accuracy_list = []
|
|
731
|
-
# Sample elements randomly from a given list of ids, no replacement.
|
|
732
|
-
train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
|
|
733
|
-
validate_subsampler = torch.utils.data.SubsetRandomSampler(validate_ids)
|
|
734
|
-
|
|
735
|
-
# Define data loaders for training and testing data in this fold
|
|
736
|
-
self.train_dataloader = DataLoader(self.trainingDataset, sampler=train_subsampler,
|
|
737
|
-
batch_size=self.hparams.batch_size,
|
|
738
|
-
drop_last=False)
|
|
739
|
-
self.validate_dataloader = DataLoader(self.trainingDataset, sampler=validate_subsampler,
|
|
740
|
-
batch_size=self.hparams.batch_size,
|
|
741
|
-
drop_last=False)
|
|
742
|
-
# Init the neural network
|
|
743
|
-
self.reset_weights()
|
|
744
|
-
|
|
745
|
-
# Run the training loop for defined number of epochs
|
|
746
|
-
for _ in tqdm(range(0,self.hparams.epochs), desc='Epochs', initial=1, total=self.hparams.epochs, leave=False):
|
|
747
|
-
temp_loss_list = []
|
|
748
|
-
temp_acc_list = []
|
|
749
|
-
|
|
750
|
-
# Iterate over the DataLoader for training data
|
|
751
|
-
for data in tqdm(self.train_dataloader, desc='Training', leave=False):
|
|
752
|
-
data = data.to(self.device)
|
|
753
|
-
# Make sure the model is in training mode
|
|
754
|
-
self.model.train()
|
|
755
|
-
|
|
756
|
-
# Zero the gradients
|
|
757
|
-
self.optimizer.zero_grad()
|
|
758
|
-
|
|
759
|
-
# Perform forward pass
|
|
760
|
-
pred = self.model(data)
|
|
761
|
-
|
|
762
|
-
# Compute loss
|
|
763
|
-
if self.hparams.loss_function.lower() == "negative log likelihood":
|
|
764
|
-
logp = F.log_softmax(pred, 1)
|
|
765
|
-
loss = F.nll_loss(logp, data.y)
|
|
766
|
-
elif self.hparams.loss_function.lower() == "cross entropy":
|
|
767
|
-
loss = F.cross_entropy(pred, data.y)
|
|
768
|
-
|
|
769
|
-
# Save loss information for reporting
|
|
770
|
-
temp_loss_list.append(loss.item())
|
|
771
|
-
temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
|
|
772
|
-
|
|
773
|
-
# Perform backward pass
|
|
774
|
-
loss.backward()
|
|
775
|
-
|
|
776
|
-
# Perform optimization
|
|
777
|
-
self.optimizer.step()
|
|
778
|
-
|
|
779
|
-
epoch_training_accuracy_list.append(np.mean(temp_acc_list).item())
|
|
780
|
-
epoch_training_loss_list.append(np.mean(temp_loss_list).item())
|
|
781
|
-
self.validate()
|
|
782
|
-
epoch_validation_accuracy_list.append(self.validation_accuracy)
|
|
783
|
-
epoch_validation_loss_list.append(self.validation_loss)
|
|
784
|
-
gc.collect()
|
|
785
|
-
models.append(self.model)
|
|
786
|
-
weights.append(copy.deepcopy(self.model.state_dict()))
|
|
787
|
-
accuracies.append(self.validation_accuracy)
|
|
788
|
-
train_dataloaders.append(self.train_dataloader)
|
|
789
|
-
validate_dataloaders.append(self.validate_dataloader)
|
|
790
|
-
self.training_accuracy_list.append(epoch_training_accuracy_list)
|
|
791
|
-
self.training_loss_list.append(epoch_training_loss_list)
|
|
792
|
-
self.validation_accuracy_list.append(epoch_validation_accuracy_list)
|
|
793
|
-
self.validation_loss_list.append(epoch_validation_loss_list)
|
|
794
|
-
self.accuracies = accuracies
|
|
795
|
-
max_accuracy = max(accuracies)
|
|
796
|
-
self.max_accuracy = max_accuracy
|
|
797
|
-
ind = accuracies.index(max_accuracy)
|
|
798
|
-
self.model = models[ind]
|
|
799
|
-
self.model.load_state_dict(weights[ind])
|
|
800
|
-
self.model.eval()
|
|
801
|
-
self.training_accuracy_list = self.training_accuracy_list[ind]
|
|
802
|
-
self.training_loss_list = self.training_loss_list[ind]
|
|
803
|
-
self.validation_accuracy_list = self.validation_accuracy_list[ind]
|
|
804
|
-
self.validation_loss_list = self.validation_loss_list[ind]
|
|
805
|
-
|
|
806
|
-
def validate(self):
|
|
807
|
-
temp_loss_list = []
|
|
808
|
-
temp_acc_list = []
|
|
809
|
-
self.model.eval()
|
|
810
|
-
for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
|
|
811
|
-
data = data.to(self.device)
|
|
812
|
-
pred = self.model(data)
|
|
813
|
-
if self.hparams.loss_function.lower() == "negative log likelihood":
|
|
814
|
-
logp = F.log_softmax(pred, 1)
|
|
815
|
-
loss = F.nll_loss(logp, data.y)
|
|
816
|
-
elif self.hparams.loss_function.lower() == "cross entropy":
|
|
817
|
-
loss = F.cross_entropy(pred, data.y)
|
|
818
|
-
temp_loss_list.append(loss.item())
|
|
819
|
-
temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
|
|
820
|
-
self.validation_accuracy = np.mean(temp_acc_list).item()
|
|
821
|
-
self.validation_loss = np.mean(temp_loss_list).item()
|
|
822
|
-
|
|
823
|
-
def test(self):
|
|
824
|
-
if self.testingDataset:
|
|
825
|
-
self.test_dataloader = DataLoader(self.testingDataset,
|
|
826
|
-
batch_size=len(self.testingDataset),
|
|
827
|
-
drop_last=False)
|
|
828
|
-
temp_loss_list = []
|
|
829
|
-
temp_acc_list = []
|
|
830
|
-
self.model.eval()
|
|
831
|
-
for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
|
|
832
|
-
data = data.to(self.device)
|
|
833
|
-
pred = self.model(data)
|
|
834
|
-
if self.hparams.loss_function.lower() == "negative log likelihood":
|
|
835
|
-
logp = F.log_softmax(pred, 1)
|
|
836
|
-
loss = F.nll_loss(logp, data.y)
|
|
837
|
-
elif self.hparams.loss_function.lower() == "cross entropy":
|
|
838
|
-
loss = F.cross_entropy(pred, data.y)
|
|
839
|
-
temp_loss_list.append(loss.item())
|
|
840
|
-
temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
|
|
841
|
-
self.testing_accuracy = np.mean(temp_acc_list).item()
|
|
842
|
-
self.testing_loss = np.mean(temp_loss_list).item()
|
|
843
|
-
|
|
844
|
-
def save(self, path):
|
|
845
|
-
if path:
|
|
846
|
-
# Make sure the file extension is .pt
|
|
847
|
-
ext = path[-3:]
|
|
848
|
-
if ext.lower() != ".pt":
|
|
849
|
-
path = path + ".pt"
|
|
850
|
-
torch.save(self.model.state_dict(), path)
|
|
851
|
-
|
|
852
|
-
def load(self, path):
|
|
853
|
-
#self.model.load_state_dict(torch.load(path))
|
|
854
|
-
self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
|
|
855
|
-
|
|
856
|
-
class _GraphClassifierHoldout:
|
|
857
|
-
def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
|
|
858
|
-
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
859
|
-
self.trainingDataset = trainingDataset
|
|
860
|
-
self.validationDataset = validationDataset
|
|
861
|
-
self.testingDataset = testingDataset
|
|
862
|
-
self.hparams = hparams
|
|
863
|
-
gclasses = trainingDataset.num_classes
|
|
864
|
-
nfeats = trainingDataset.num_node_features
|
|
865
|
-
|
|
866
|
-
if hparams.conv_layer_type.lower() == 'sageconv':
|
|
867
|
-
self.model = _SAGEConv(nfeats, hparams.hl_widths, gclasses, hparams.pooling).to(self.device)
|
|
868
|
-
else:
|
|
869
|
-
raise NotImplementedError
|
|
870
|
-
|
|
871
|
-
if hparams.optimizer_str.lower() == "adadelta":
|
|
872
|
-
self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
|
|
873
|
-
lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
|
|
874
|
-
elif hparams.optimizer_str.lower() == "adagrad":
|
|
875
|
-
self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
|
|
876
|
-
lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
|
|
877
|
-
elif hparams.optimizer_str.lower() == "adam":
|
|
878
|
-
self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
|
|
879
|
-
lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
|
|
880
|
-
self.use_gpu = hparams.use_gpu
|
|
881
|
-
self.training_loss_list = []
|
|
882
|
-
self.validation_loss_list = []
|
|
883
|
-
self.training_accuracy_list = []
|
|
884
|
-
self.validation_accuracy_list = []
|
|
885
|
-
self.node_attr_key = trainingDataset[0].x.shape[1]
|
|
886
|
-
|
|
887
|
-
# train, validate, test split
|
|
888
|
-
num_train = int(len(trainingDataset) * hparams.split[0])
|
|
889
|
-
num_validate = int(len(trainingDataset) * hparams.split[1])
|
|
890
|
-
num_test = len(trainingDataset) - num_train - num_validate
|
|
891
|
-
idx = torch.randperm(len(trainingDataset))
|
|
892
|
-
train_sampler = SubsetRandomSampler(idx[:num_train])
|
|
893
|
-
validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
|
|
894
|
-
test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])
|
|
895
|
-
|
|
896
|
-
if validationDataset:
|
|
897
|
-
self.train_dataloader = DataLoader(trainingDataset, batch_size=hparams.batch_size, drop_last=False)
|
|
898
|
-
self.validate_dataloader = DataLoader(validationDataset, batch_size=hparams.batch_size, drop_last=False)
|
|
899
|
-
else:
|
|
900
|
-
self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler, batch_size=hparams.batch_size, drop_last=False)
|
|
901
|
-
self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler, batch_size=hparams.batch_size, drop_last=False)
|
|
902
|
-
|
|
903
|
-
if testingDataset:
|
|
904
|
-
self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
|
|
905
|
-
else:
|
|
906
|
-
self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)
|
|
907
|
-
|
|
908
|
-
def train(self):
|
|
909
|
-
# Init the loss and accuracy reporting lists
|
|
910
|
-
self.training_accuracy_list = []
|
|
911
|
-
self.training_loss_list = []
|
|
912
|
-
self.validation_accuracy_list = []
|
|
913
|
-
self.validation_loss_list = []
|
|
914
|
-
|
|
915
|
-
# Run the training loop for defined number of epochs
|
|
916
|
-
for _ in tqdm(range(self.hparams.epochs), desc='Epochs', initial=1, leave=False):
|
|
917
|
-
temp_loss_list = []
|
|
918
|
-
temp_acc_list = []
|
|
919
|
-
# Make sure the model is in training mode
|
|
920
|
-
self.model.train()
|
|
921
|
-
# Iterate over the DataLoader for training data
|
|
922
|
-
for data in tqdm(self.train_dataloader, desc='Training', leave=False):
|
|
923
|
-
data = data.to(self.device)
|
|
924
|
-
|
|
925
|
-
# Zero the gradients
|
|
926
|
-
self.optimizer.zero_grad()
|
|
927
|
-
|
|
928
|
-
# Perform forward pass
|
|
929
|
-
pred = self.model(data)
|
|
930
|
-
|
|
931
|
-
# Compute loss
|
|
932
|
-
if self.hparams.loss_function.lower() == "negative log likelihood":
|
|
933
|
-
logp = F.log_softmax(pred, 1)
|
|
934
|
-
loss = F.nll_loss(logp, data.y)
|
|
935
|
-
elif self.hparams.loss_function.lower() == "cross entropy":
|
|
936
|
-
loss = F.cross_entropy(pred, data.y)
|
|
937
|
-
|
|
938
|
-
# Save loss information for reporting
|
|
939
|
-
temp_loss_list.append(loss.item())
|
|
940
|
-
temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
|
|
941
|
-
|
|
942
|
-
# Perform backward pass
|
|
943
|
-
loss.backward()
|
|
944
|
-
|
|
945
|
-
# Perform optimization
|
|
946
|
-
self.optimizer.step()
|
|
947
|
-
|
|
948
|
-
self.training_accuracy_list.append(np.mean(temp_acc_list).item())
|
|
949
|
-
self.training_loss_list.append(np.mean(temp_loss_list).item())
|
|
950
|
-
self.validate()
|
|
951
|
-
self.validation_accuracy_list.append(self.validation_accuracy)
|
|
952
|
-
self.validation_loss_list.append(self.validation_loss)
|
|
953
|
-
gc.collect()
|
|
954
|
-
|
|
955
|
-
def validate(self):
|
|
956
|
-
temp_loss_list = []
|
|
957
|
-
temp_acc_list = []
|
|
958
|
-
self.model.eval()
|
|
959
|
-
for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
|
|
960
|
-
data = data.to(self.device)
|
|
961
|
-
pred = self.model(data)
|
|
962
|
-
if self.hparams.loss_function.lower() == "negative log likelihood":
|
|
963
|
-
logp = F.log_softmax(pred, 1)
|
|
964
|
-
loss = F.nll_loss(logp, data.y)
|
|
965
|
-
elif self.hparams.loss_function.lower() == "cross entropy":
|
|
966
|
-
loss = F.cross_entropy(pred, data.y)
|
|
967
|
-
temp_loss_list.append(loss.item())
|
|
968
|
-
temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
|
|
969
|
-
self.validation_accuracy = np.mean(temp_acc_list).item()
|
|
970
|
-
self.validation_loss = np.mean(temp_loss_list).item()
|
|
971
|
-
|
|
972
|
-
def test(self):
|
|
973
|
-
if self.test_dataloader:
|
|
974
|
-
temp_loss_list = []
|
|
975
|
-
temp_acc_list = []
|
|
976
|
-
self.model.eval()
|
|
977
|
-
for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
|
|
978
|
-
data = data.to(self.device)
|
|
979
|
-
pred = self.model(data)
|
|
980
|
-
if self.hparams.loss_function.lower() == "negative log likelihood":
|
|
981
|
-
logp = F.log_softmax(pred, 1)
|
|
982
|
-
loss = F.nll_loss(logp, data.y)
|
|
983
|
-
elif self.hparams.loss_function.lower() == "cross entropy":
|
|
984
|
-
loss = F.cross_entropy(pred, data.y)
|
|
985
|
-
temp_loss_list.append(loss.item())
|
|
986
|
-
temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
|
|
987
|
-
self.testing_accuracy = np.mean(temp_acc_list).item()
|
|
988
|
-
self.testing_loss = np.mean(temp_loss_list).item()
|
|
989
|
-
|
|
990
|
-
def save(self, path):
|
|
991
|
-
if path:
|
|
992
|
-
# Make sure the file extension is .pt
|
|
993
|
-
ext = path[-3:]
|
|
994
|
-
if ext.lower() != ".pt":
|
|
995
|
-
path = path + ".pt"
|
|
996
|
-
torch.save(self.model.state_dict(), path)
|
|
997
|
-
|
|
998
|
-
def load(self, path):
|
|
999
|
-
#self.model.load_state_dict(torch.load(path))
|
|
1000
|
-
self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
|
|
1001
|
-
|
|
1002
|
-
-class _NodeClassifierHoldout:
-    def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self.trainingDataset = trainingDataset
-        self.validationDataset = validationDataset
-        self.testingDataset = testingDataset
-        self.hparams = hparams
-        gclasses = trainingDataset.num_classes
-        nfeats = trainingDataset.num_node_features
-
-        if hparams.conv_layer_type.lower() == 'sageconv':
-            # pooling is set None for Node classifier
-            self.model = _SAGEConv(nfeats, hparams.hl_widths, gclasses, None).to(self.device)
-        else:
-            raise NotImplementedError
-
-        if hparams.optimizer_str.lower() == "adadelta":
-            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
-                                                  lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
-        elif hparams.optimizer_str.lower() == "adagrad":
-            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
-                                                 lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
-        elif hparams.optimizer_str.lower() == "adam":
-            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
-                                              lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
-        self.use_gpu = hparams.use_gpu
-        self.training_loss_list = []
-        self.validation_loss_list = []
-        self.training_accuracy_list = []
-        self.validation_accuracy_list = []
-        self.node_attr_key = trainingDataset[0].x.shape[1]
-
-        # train, validate, test split
-        num_train = int(len(trainingDataset) * hparams.split[0])
-        num_validate = int(len(trainingDataset) * hparams.split[1])
-        num_test = len(trainingDataset) - num_train - num_validate
-        idx = torch.randperm(len(trainingDataset))
-        train_sampler = SubsetRandomSampler(idx[:num_train])
-        validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
-        test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])
-
-        if validationDataset:
-            self.train_dataloader = DataLoader(trainingDataset, batch_size=hparams.batch_size, drop_last=False)
-            self.validate_dataloader = DataLoader(validationDataset, batch_size=hparams.batch_size, drop_last=False)
-        else:
-            self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler, batch_size=hparams.batch_size, drop_last=False)
-            self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler, batch_size=hparams.batch_size, drop_last=False)
-
-        if testingDataset:
-            self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
-        else:
-            self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)
-
-    def train(self):
-        # Init the loss and accuracy reporting lists
-        self.training_accuracy_list = []
-        self.training_loss_list = []
-        self.validation_accuracy_list = []
-        self.validation_loss_list = []
-
-        # Run the training loop for defined number of epochs
-        for _ in tqdm(range(self.hparams.epochs), desc='Epochs', initial=1, leave=False):
-            temp_loss_list = []
-            temp_acc_list = []
-            # Iterate over the DataLoader for training data
-            for data in tqdm(self.train_dataloader, desc='Training', leave=False):
-                data = data.to(self.device)
-                # Make sure the model is in training mode
-                self.model.train()
-
-                # Zero the gradients
-                self.optimizer.zero_grad()
-
-                # Perform forward pass
-                pred = self.model(data)
-
-                # Compute loss
-                if self.hparams.loss_function.lower() == "negative log likelihood":
-                    logp = F.log_softmax(pred, 1)
-                    loss = F.nll_loss(logp, data.y)
-                elif self.hparams.loss_function.lower() == "cross entropy":
-                    loss = F.cross_entropy(pred, data.y)
-
-                # Save loss information for reporting
-                temp_loss_list.append(loss.item())
-                temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
-
-                # Perform backward pass
-                loss.backward()
-
-                # Perform optimization
-                self.optimizer.step()
-
-            self.training_accuracy_list.append(np.mean(temp_acc_list).item())
-            self.training_loss_list.append(np.mean(temp_loss_list).item())
-            self.validate()
-            self.validation_accuracy_list.append(self.validation_accuracy)
-            self.validation_loss_list.append(self.validation_loss)
-            gc.collect()
-
-    def validate(self):
-        temp_loss_list = []
-        temp_acc_list = []
-        self.model.eval()
-        for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
-            data = data.to(self.device)
-            pred = self.model(data)
-            if self.hparams.loss_function.lower() == "negative log likelihood":
-                logp = F.log_softmax(pred, 1)
-                loss = F.nll_loss(logp, data.y)
-            elif self.hparams.loss_function.lower() == "cross entropy":
-                loss = F.cross_entropy(pred, data.y)
-            temp_loss_list.append(loss.item())
-            temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
-        self.validation_accuracy = np.mean(temp_acc_list).item()
-        self.validation_loss = np.mean(temp_loss_list).item()
-
-    def test(self):
-        if self.test_dataloader:
-            temp_loss_list = []
-            temp_acc_list = []
-            self.model.eval()
-            for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
-                data = data.to(self.device)
-                pred = self.model(data)
-                if self.hparams.loss_function.lower() == "negative log likelihood":
-                    logp = F.log_softmax(pred, 1)
-                    loss = F.nll_loss(logp, data.y)
-                elif self.hparams.loss_function.lower() == "cross entropy":
-                    loss = F.cross_entropy(pred, data.y)
-                temp_loss_list.append(loss.item())
-                temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
-            self.testing_accuracy = np.mean(temp_acc_list).item()
-            self.testing_loss = np.mean(temp_loss_list).item()
-
-    def save(self, path):
-        if path:
-            # Make sure the file extension is .pt
-            ext = path[-3:]
-            if ext.lower() != ".pt":
-                path = path + ".pt"
-            torch.save(self.model.state_dict(), path)
-
-    def load(self, path):
-        #self.model.load_state_dict(torch.load(path))
-        self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
-
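The holdout constructor above derives three disjoint samplers from one shuffled index permutation. A minimal, self-contained sketch of the same pattern (the ten single-edge graphs are stand-ins, not the real dataset):

    import torch
    from torch.utils.data import SubsetRandomSampler
    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader

    dataset = [Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0], [1]])) for _ in range(10)]
    split = [0.8, 0.1, 0.1]
    num_train = int(len(dataset) * split[0])
    num_validate = int(len(dataset) * split[1])
    idx = torch.randperm(len(dataset))  # one permutation feeds all three samplers

    train_loader = DataLoader(dataset, sampler=SubsetRandomSampler(idx[:num_train]), batch_size=1)
    val_loader = DataLoader(dataset, sampler=SubsetRandomSampler(idx[num_train:num_train + num_validate]), batch_size=1)
    test_loader = DataLoader(dataset, sampler=SubsetRandomSampler(idx[num_train + num_validate:]), batch_size=1)
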
-class _NodeRegressorHoldout:
-    def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self.trainingDataset = trainingDataset
-        self.validationDataset = validationDataset
-        self.testingDataset = testingDataset
-        self.hparams = hparams
-        if hparams.conv_layer_type.lower() == 'sageconv':
-            # pooling is set None for Node regressor
-            self.model = _SAGEConv(trainingDataset[0].num_node_features, hparams.hl_widths, 1, None).to(self.device)
-        else:
-            raise NotImplementedError
-
-        if hparams.optimizer_str.lower() == "adadelta":
-            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
-                                                  lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
-        elif hparams.optimizer_str.lower() == "adagrad":
-            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
-                                                 lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
-        elif hparams.optimizer_str.lower() == "adam":
-            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
-                                              lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
-
-        self.use_gpu = hparams.use_gpu
-        self.training_loss_list = []
-        self.validation_loss_list = []
-        self.node_attr_key = trainingDataset[0].x.shape[1]
-
-        # Train, validate, test split
-        num_train = int(len(trainingDataset) * hparams.split[0])
-        num_validate = int(len(trainingDataset) * hparams.split[1])
-        num_test = len(trainingDataset) - num_train - num_validate
-        idx = torch.randperm(len(trainingDataset))
-        train_sampler = SubsetRandomSampler(idx[:num_train])
-        validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
-        test_sampler = SubsetRandomSampler(idx[num_train+num_validate:])
-
-        if validationDataset:
-            self.train_dataloader = DataLoader(trainingDataset,
-                                               batch_size=hparams.batch_size,
-                                               drop_last=False)
-            self.validate_dataloader = DataLoader(validationDataset,
-                                                  batch_size=hparams.batch_size,
-                                                  drop_last=False)
-        else:
-            self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler,
-                                               batch_size=hparams.batch_size,
-                                               drop_last=False)
-            self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler,
-                                                  batch_size=hparams.batch_size,
-                                                  drop_last=False)
-
-        if testingDataset:
-            self.test_dataloader = DataLoader(testingDataset,
-                                              batch_size=len(testingDataset),
-                                              drop_last=False)
-        else:
-            self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler,
-                                              batch_size=hparams.batch_size,
-                                              drop_last=False)
-
-    def train(self):
-        # Init the loss and accuracy reporting lists
-        self.training_loss_list = []
-        self.validation_loss_list = []
-
-        # Run the training loop for defined number of epochs
-        for _ in tqdm(range(self.hparams.epochs), desc='Epochs', total=self.hparams.epochs, leave=False):
-            # Iterate over the DataLoader for training data
-            for data in tqdm(self.train_dataloader, desc='Training', leave=False):
-                data = data.to(self.device)
-                # Make sure the model is in training mode
-                self.model.train()
-                # Zero the gradients
-                self.optimizer.zero_grad()
-
-                # Perform forward pass
-                pred = self.model(data).to(self.device)
-                # Compute loss
-                loss = F.mse_loss(torch.flatten(pred), data.y.float())
-
-                # Perform backward pass
-                loss.backward()
-
-                # Perform optimization
-                self.optimizer.step()
-
-            self.training_loss_list.append(torch.sqrt(loss).item())
-            self.validate()
-            self.validation_loss_list.append(torch.sqrt(self.validation_loss).item())
-            gc.collect()
-
-    def validate(self):
-        self.model.eval()
-        for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
-            data = data.to(self.device)
-            pred = self.model(data).to(self.device)
-            loss = F.mse_loss(torch.flatten(pred), data.y.float())
-        self.validation_loss = loss
-
-    def test(self):
-        self.model.eval()
-        for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
-            data = data.to(self.device)
-            pred = self.model(data).to(self.device)
-            loss = F.mse_loss(torch.flatten(pred), data.y.float())
-        self.testing_loss = torch.sqrt(loss).item()
-
-    def save(self, path):
-        if path:
-            # Make sure the file extension is .pt
-            ext = path[-3:]
-            if ext.lower() != ".pt":
-                path = path + ".pt"
-            torch.save(self.model.state_dict(), path)
-
-    def load(self, path):
-        #self.model.load_state_dict(torch.load(path))
-        self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
-
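The regressor logs RMSE rather than raw MSE: each epoch appends torch.sqrt of the last computed F.mse_loss. A tiny worked sketch (values are illustrative only):

    import torch
    import torch.nn.functional as F

    pred = torch.tensor([2.5, 0.0, 2.0])
    target = torch.tensor([3.0, -0.5, 2.0])
    mse = F.mse_loss(pred, target)   # mean of (-0.5)^2, (0.5)^2, 0^2 = 0.1667
    rmse = torch.sqrt(mse).item()    # ~0.408, the figure the loss lists record
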
-class _NodeClassifierKFold:
-    def __init__(self, hparams, trainingDataset, testingDataset=None):
-        self.trainingDataset = trainingDataset
-        self.testingDataset = testingDataset
-        self.hparams = hparams
-        self.testing_accuracy = 0
-        self.accuracies = []
-        self.max_accuracy = 0
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-        if hparams.conv_layer_type.lower() == 'sageconv':
-            # pooling is set None for Node classifier
-            self.model = _SAGEConv(trainingDataset.num_node_features, hparams.hl_widths,
-                                   trainingDataset.num_classes, None).to(self.device)
-        else:
-            raise NotImplementedError
-
-        if hparams.optimizer_str.lower() == "adadelta":
-            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
-                                                  lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
-        elif hparams.optimizer_str.lower() == "adagrad":
-            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
-                                                 lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
-        elif hparams.optimizer_str.lower() == "adam":
-            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
-                                              lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
-        self.use_gpu = hparams.use_gpu
-        self.training_loss_list = []
-        self.validation_loss_list = []
-        self.training_accuracy_list = []
-        self.validation_accuracy_list = []
-
-    def reset_weights(self):
-        if self.hparams.conv_layer_type.lower() == 'sageconv':
-            # pooling is set None for Node classifier
-            self.model = _SAGEConv(self.trainingDataset.num_node_features, self.hparams.hl_widths,
-                                   self.trainingDataset.num_classes, None).to(self.device)
-        else:
-            raise NotImplementedError
-
-        if self.hparams.optimizer_str.lower() == "adadelta":
-            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=self.hparams.eps,
-                                                  lr=self.hparams.lr, rho=self.hparams.rho, weight_decay=self.hparams.weight_decay)
-        elif self.hparams.optimizer_str.lower() == "adagrad":
-            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=self.hparams.eps,
-                                                 lr=self.hparams.lr, lr_decay=self.hparams.lr_decay, weight_decay=self.hparams.weight_decay)
-        elif self.hparams.optimizer_str.lower() == "adam":
-            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=self.hparams.amsgrad, betas=self.hparams.betas, eps=self.hparams.eps,
-                                              lr=self.hparams.lr, maximize=self.hparams.maximize, weight_decay=self.hparams.weight_decay)
-
-    def train(self):
-        k_folds = self.hparams.k_folds
-
-        # Init the loss and accuracy reporting lists
-        self.training_accuracy_list = []
-        self.training_loss_list = []
-        self.validation_accuracy_list = []
-        self.validation_loss_list = []
-
-        # Set fixed random number seed
-        torch.manual_seed(42)
-
-        # Define the K-fold Cross Validator
-        kfold = KFold(n_splits=k_folds, shuffle=True)
-
-        models = []
-        weights = []
-        accuracies = []
-        train_dataloaders = []
-        validate_dataloaders = []
-
-        # K-fold Cross-validation model evaluation
-        for fold, (train_ids, validate_ids) in tqdm(enumerate(kfold.split(self.trainingDataset)), desc="Fold", initial=1, total=k_folds, leave=False):
-            epoch_training_loss_list = []
-            epoch_training_accuracy_list = []
-            epoch_validation_loss_list = []
-            epoch_validation_accuracy_list = []
-            # Sample elements randomly from a given list of ids, no replacement.
-            train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
-            validate_subsampler = torch.utils.data.SubsetRandomSampler(validate_ids)
-
-            # Define data loaders for training and testing data in this fold
-            self.train_dataloader = DataLoader(self.trainingDataset, sampler=train_subsampler,
-                                               batch_size=self.hparams.batch_size,
-                                               drop_last=False)
-            self.validate_dataloader = DataLoader(self.trainingDataset, sampler=validate_subsampler,
-                                                  batch_size=self.hparams.batch_size,
-                                                  drop_last=False)
-            # Init the neural network
-            self.reset_weights()
-
-            # Run the training loop for defined number of epochs
-            for _ in tqdm(range(0,self.hparams.epochs), desc='Epochs', initial=1, total=self.hparams.epochs, leave=False):
-                temp_loss_list = []
-                temp_acc_list = []
-
-                # Iterate over the DataLoader for training data
-                for data in tqdm(self.train_dataloader, desc='Training', leave=False):
-                    data = data.to(self.device)
-                    # Make sure the model is in training mode
-                    self.model.train()
-
-                    # Zero the gradients
-                    self.optimizer.zero_grad()
-
-                    # Perform forward pass
-                    pred = self.model(data)
-
-                    # Compute loss
-                    if self.hparams.loss_function.lower() == "negative log likelihood":
-                        logp = F.log_softmax(pred, 1)
-                        loss = F.nll_loss(logp, data.y)
-                    elif self.hparams.loss_function.lower() == "cross entropy":
-                        loss = F.cross_entropy(pred, data.y)
-
-                    # Save loss information for reporting
-                    temp_loss_list.append(loss.item())
-                    temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
-
-                    # Perform backward pass
-                    loss.backward()
-
-                    # Perform optimization
-                    self.optimizer.step()
-
-                epoch_training_accuracy_list.append(np.mean(temp_acc_list).item())
-                epoch_training_loss_list.append(np.mean(temp_loss_list).item())
-                self.validate()
-                epoch_validation_accuracy_list.append(self.validation_accuracy)
-                epoch_validation_loss_list.append(self.validation_loss)
-                gc.collect()
-            models.append(self.model)
-            weights.append(copy.deepcopy(self.model.state_dict()))
-            accuracies.append(self.validation_accuracy)
-            train_dataloaders.append(self.train_dataloader)
-            validate_dataloaders.append(self.validate_dataloader)
-            self.training_accuracy_list.append(epoch_training_accuracy_list)
-            self.training_loss_list.append(epoch_training_loss_list)
-            self.validation_accuracy_list.append(epoch_validation_accuracy_list)
-            self.validation_loss_list.append(epoch_validation_loss_list)
-        self.accuracies = accuracies
-        max_accuracy = max(accuracies)
-        self.max_accuracy = max_accuracy
-        ind = accuracies.index(max_accuracy)
-        self.model = models[ind]
-        self.model.load_state_dict(weights[ind])
-        self.model.eval()
-        self.training_accuracy_list = self.training_accuracy_list[ind]
-        self.training_loss_list = self.training_loss_list[ind]
-        self.validation_accuracy_list = self.validation_accuracy_list[ind]
-        self.validation_loss_list = self.validation_loss_list[ind]
-
-    def validate(self):
-        temp_loss_list = []
-        temp_acc_list = []
-        self.model.eval()
-        for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
-            data = data.to(self.device)
-            pred = self.model(data)
-            if self.hparams.loss_function.lower() == "negative log likelihood":
-                logp = F.log_softmax(pred, 1)
-                loss = F.nll_loss(logp, data.y)
-            elif self.hparams.loss_function.lower() == "cross entropy":
-                loss = F.cross_entropy(pred, data.y)
-            temp_loss_list.append(loss.item())
-            temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
-        self.validation_accuracy = np.mean(temp_acc_list).item()
-        self.validation_loss = np.mean(temp_loss_list).item()
-
-    def test(self):
-        if self.testingDataset:
-            self.test_dataloader = DataLoader(self.testingDataset,
-                                              batch_size=len(self.testingDataset),
-                                              drop_last=False)
-            temp_loss_list = []
-            temp_acc_list = []
-            self.model.eval()
-            for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
-                data = data.to(self.device)
-                pred = self.model(data)
-                if self.hparams.loss_function.lower() == "negative log likelihood":
-                    logp = F.log_softmax(pred, 1)
-                    loss = F.nll_loss(logp, data.y)
-                elif self.hparams.loss_function.lower() == "cross entropy":
-                    loss = F.cross_entropy(pred, data.y)
-                temp_loss_list.append(loss.item())
-                temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
-            self.testing_accuracy = np.mean(temp_acc_list).item()
-            self.testing_loss = np.mean(temp_loss_list).item()
-
-    def save(self, path):
-        if path:
-            # Make sure the file extension is .pt
-            ext = path[-3:]
-            if ext.lower() != ".pt":
-                path = path + ".pt"
-            torch.save(self.model.state_dict(), path)
-
-    def load(self, path):
-        #self.model.load_state_dict(torch.load(path))
-        self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
-
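The k-fold trainer above trains one model per fold and keeps the weights of the fold with the highest validation accuracy. Its control flow, reduced to a runnable sketch (training and validation are stubbed out with placeholder values):

    import copy
    import torch
    import torch.nn as nn
    from sklearn.model_selection import KFold

    dataset = list(range(20))                 # stand-in for a graph dataset
    torch.manual_seed(42)                     # the trainer fixes this seed
    kfold = KFold(n_splits=5, shuffle=True)

    model = nn.Linear(4, 2)
    accuracies, weights = [], []
    for fold, (train_ids, validate_ids) in enumerate(kfold.split(dataset)):
        # ... reset weights, train on train_ids, validate on validate_ids ...
        accuracies.append(0.5 + 0.01 * fold)  # placeholder validation accuracies
        weights.append(copy.deepcopy(model.state_dict()))

    best = accuracies.index(max(accuracies))  # the fold whose weights are kept
    model.load_state_dict(weights[best])
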
-class _NodeRegressorKFold:
-    def __init__(self, hparams, trainingDataset, testingDataset=None):
-        self.trainingDataset = trainingDataset
-        self.testingDataset = testingDataset
-        self.hparams = hparams
-        self.losses = []
-        self.min_loss = 0
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-        self.model = self._initialize_model(hparams, trainingDataset)
-        self.optimizer = self._initialize_optimizer(hparams)
-
-        self.use_gpu = hparams.use_gpu
-        self.training_loss_list = []
-        self.validation_loss_list = []
-        self.node_attr_key = trainingDataset.node_attr_key
-
-        # Train, validate, test split
-        num_train = int(len(trainingDataset) * hparams.split[0])
-        num_validate = int(len(trainingDataset) * hparams.split[1])
-        num_test = len(trainingDataset) - num_train - num_validate
-        idx = torch.randperm(len(trainingDataset))
-        test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])
-
-        if testingDataset:
-            self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
-        else:
-            self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)
-
-    def _initialize_model(self, hparams, dataset):
-        if hparams.conv_layer_type.lower() == 'sageconv':
-            # pooling is set None for Node
-            return _SAGEConv(dataset.num_node_features, hparams.hl_widths, 1, None).to(self.device)
+class _GraphHead(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, pooling: PoolingKind = "mean", dropout: float = 0.1):
+        super().__init__()
+        self.dropout = float(dropout)
+
+        if pooling == "mean":
+            self.pool = global_mean_pool
+        elif pooling == "max":
+            self.pool = global_max_pool
+        elif pooling == "sum":
+            self.pool = global_add_pool
         else:
-            raise
-[…62 removed blank lines…]
-        self.
+            raise ValueError("GraphHead requires pooling in {'mean','max','sum'}.")
+
+        self.mlp = nn.Sequential(
+            nn.Linear(in_dim, in_dim),
+            nn.ReLU(),
+            nn.Dropout(self.dropout),
+            nn.Linear(in_dim, out_dim),
+        )
+
+    def forward(self, node_emb, batch):
+        g = self.pool(node_emb, batch)
+        return self.mlp(g)
+
+
+class _NodeHead(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(in_dim, in_dim),
+            nn.ReLU(),
+            nn.Dropout(float(dropout)),
+            nn.Linear(in_dim, out_dim),
+        )
+
+    def forward(self, node_emb):
+        return self.mlp(node_emb)
+
+
+class _EdgeHead(nn.Module):
+    """
+    Edge prediction head using concatenation of endpoint embeddings.
+    """
+    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(in_dim * 2, in_dim),
+            nn.ReLU(),
+            nn.Dropout(float(dropout)),
+            nn.Linear(in_dim, out_dim),
+        )
+
+    def forward(self, node_emb, edge_index):
+        src, dst = edge_index[0], edge_index[1]
+        h = torch.cat([node_emb[src], node_emb[dst]], dim=-1)
+        return self.mlp(h)
+
+
+class _LinkPredictor(nn.Module):
+    """
+    Binary link predictor (edge exists or not).
+    """
+    def __init__(self, in_dim: int, hidden: int = 64, dropout: float = 0.1):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(in_dim * 2, hidden),
+            nn.ReLU(),
+            nn.Dropout(float(dropout)),
+            nn.Linear(hidden, 1),
+        )
+
+    def forward(self, node_emb, edge_label_index):
+        src, dst = edge_label_index[0], edge_label_index[1]
+        h = torch.cat([node_emb[src], node_emb[dst]], dim=-1)
+        return self.net(h).squeeze(-1)  # logits
 
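How the new heads consume node embeddings, sketched with dummy tensors (the pooling functions come from torch_geometric, as the new assignments above suggest):

    import torch
    from torch_geometric.nn import global_mean_pool

    node_emb = torch.randn(5, 16)                  # 5 nodes, 16-dim embeddings
    batch = torch.tensor([0, 0, 0, 1, 1])          # two graphs in one mini-batch
    graph_emb = global_mean_pool(node_emb, batch)  # (2, 16), fed to _GraphHead's MLP

    # _EdgeHead and _LinkPredictor score node pairs via concatenation:
    edge_label_index = torch.tensor([[0, 3], [1, 4]])
    src, dst = edge_label_index[0], edge_label_index[1]
    pair = torch.cat([node_emb[src], node_emb[dst]], dim=-1)  # (2, 32)
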
-    def validate(self):
-        self.model.eval()
-        for batched_graph in tqdm(self.validate_dataloader, desc='Validating', leave=False):
-            batched_graph = batched_graph.to(self.device)
-            pred = self.model(batched_graph)
-            loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
-        self.validation_loss = loss
-
-    def test(self):
-        self.model.eval()
-        for batched_graph in tqdm(self.test_dataloader, desc='Testing', leave=False):
-            batched_graph = batched_graph.to(self.device)
-            pred = self.model(batched_graph)
-            loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
-        self.testing_loss = torch.sqrt(loss).item()
-
-    def save(self, path):
-        if path:
-            ext = path[-3:]
-            if ext.lower() != ".pt":
-                path = path + ".pt"
-            torch.save(self.model.state_dict(), path)
-
-    def load(self, path):
-        self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
 
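The save/load idiom repeated across all the removed trainers: persist only the state_dict, and reload it with weights_only=True mapped onto the current device. Standalone:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    path = "model.pt"                 # a ".pt" suffix is enforced before saving
    torch.save(model.state_dict(), path)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load(path, weights_only=True, map_location=device))
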
 class PyG:
+    """
+    A clean PyTorch Geometric interface for TopologicPy-exported CSV datasets.
+
+    You can control medium-level hyperparameters by passing keyword arguments to ByCSVPath,
+    for example:
+
+        pyg = PyG.ByCSVPath(
+            path="C:/dataset",
+            level="graph",
+            task="classification",
+            graphLabelType="categorical",
+            cv="kfold",
+            k_folds=5,
+            conv="gatv2",
+            hidden_dims=(128, 128, 64),
+            activation="gelu",
+            batch_norm=True,
+            residual=True,
+            dropout=0.2,
+            lr=1e-3,
+            optimizer="adamw",
+            early_stopping=True,
+            early_stopping_patience=10,
+            gradient_clip_norm=1.0
+        )
+    """
+
+    # ---------
+    # Creation
+    # ---------
     @staticmethod
-    def
-[…removed blank lines…]
+    def ByCSVPath(path: str,
+                  level: Level = "graph",
+                  task: TaskKind = "classification",
+                  graphLabelType: LabelType = "categorical",
+                  nodeLabelType: LabelType = "categorical",
+                  edgeLabelType: LabelType = "categorical",
+                  **kwargs) -> "PyG":
+        cfg = _RunConfig(level=level, task=task,
+                         graph_label_type=graphLabelType,
+                         node_label_type=nodeLabelType,
+                         edge_label_type=edgeLabelType)
 
+        # allow override of any config field via kwargs
+        for k, v in kwargs.items():
+            if hasattr(cfg, k):
+                setattr(cfg, k, v)
 
-        -------
-        PyG Dataset
-            The PyG dataset
-        """
-        if not isinstance(path, str):
-            print("PyG.DatasetByCSVPath - Error: The input path parameter is not a valid string. Returning None.")
-            return None
-        if not os.path.exists(path):
-            print("PyG.DatasetByCSVPath - Error: The input path parameter does not exist. Returning None.")
-            return None
-
-        return CustomGraphDataset(root=path, node_level=node_level, graph_level=graph_level, node_attr_key=nodeATTRKey, edge_attr_key=edgeATTRKey)
-
-    @staticmethod
-    def DatasetGraphLabels(dataset, graphLabelHeader="label"):
-        """
-        Returns the labels of the graphs in the input dataset
-
-        Parameters
-        ----------
-        dataset : CustomDataset
-            The input dataset
-        graphLabelHeader: str , optional
-            The key string under which the graph labels are stored. Default is "label".
-
-        Returns
-        -------
-        list
-            The list of graph labels.
-        """
-        import torch
+        return PyG(path=path, config=cfg)
 
-            label = g.y
-            graph_labels.append(label.item())
-        return graph_labels
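ByCSVPath builds a configuration object and then lets any keyword argument override a matching field. The same pattern in isolation, with a hypothetical stand-in for the module's _RunConfig (whose full definition is outside this hunk):

    from dataclasses import dataclass

    @dataclass
    class RunConfigSketch:          # hypothetical stand-in for _RunConfig
        level: str = "graph"
        task: str = "classification"
        k_folds: int = 5

    cfg = RunConfigSketch()
    for k, v in {"k_folds": 10, "unknown_key": 1}.items():
        if hasattr(cfg, k):         # unknown keys are silently skipped here
            setattr(cfg, k, v)
    assert cfg.k_folds == 10
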
+    def __init__(self, path: str, config: _RunConfig):
+        self.path = path
+        self.config = config
 
-    def DatasetSplit(dataset, split=[0.8,0.1,0.1], shuffle=True, randomState=42):
-        """
-        Splits the dataset into three subsets.
+        self.device = torch.device("cuda:0" if (config.use_gpu and torch.cuda.is_available()) else "cpu")
 
-            The input dataset
-        split: list , optional
-            The list of ratios. This list must be made out of three numbers adding to 1.
-        shuffle: boolean , optional
-            If set to True, the subsets are created from random indices. Otherwise, they are split sequentially. Default is True.
-        randomState : int , optional
-            The random seed to use for reproducibility. Default is 42.
-
-        Returns
-        -------
-        list
-            The list of three subset datasets.
-        """
+        self.graph_df: Optional[pd.DataFrame] = None
+        self.nodes_df: Optional[pd.DataFrame] = None
+        self.edges_df: Optional[pd.DataFrame] = None
 
-        # Calculate the number of samples for each split
-        dataset_len = len(dataset)
-        train_len = int(train_ratio * dataset_len)
-        val_len = int(val_ratio * dataset_len)
-        test_len = dataset_len - train_len - val_len # Ensure it adds up correctly
-
-        ## Generate indices for the split
-        indices = list(range(dataset_len))
-        if shuffle:
-            torch.manual_seed(randomState) # For reproducibility
-            indices = torch.randperm(dataset_len).tolist() # Shuffled indices
-
-        # Create splits
-        train_indices = indices[:train_len]
-        val_indices = indices[train_len:train_len + val_len]
-        test_indices = indices[train_len + val_len:train_len + val_len + test_len]
-
-        # Create new instances of CustomGraphDataset using the indices
-        train_dataset = CustomGraphDataset(data_list=dataset.data_list, indices=train_indices)
-        val_dataset = CustomGraphDataset(data_list=dataset.data_list, indices=val_indices)
-        test_dataset = CustomGraphDataset(data_list=dataset.data_list, indices=test_indices)
-
-        return train_dataset, val_dataset, test_dataset
+        self.data_list: List[Data] = []
+        self.train_set: Optional[List[Data]] = None
+        self.val_set: Optional[List[Data]] = None
+        self.test_set: Optional[List[Data]] = None
 
-        Returns the parameters of the optimizer
+        self.model: Optional[nn.Module] = None
+        self.history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
+        self.cv_report: Optional[Dict[str, Union[float, List[Dict[str, float]]]]] = None
 
-        ----------
-        amsgrad : bool , optional.
-            amsgrad is an extension to the Adam version of gradient descent that attempts to improve the convergence properties of the algorithm, avoiding large abrupt changes in the learning rate for each input variable. Default is True.
-        betas : tuple , optional
-            Betas are used as for smoothing the path to the convergence also providing some momentum to cross a local minima or saddle point. Default is (0.9, 0.999).
-        eps : float . optional.
-            eps is a term added to the denominator to improve numerical stability. Default is 0.000001.
-        lr : float
-            The learning rate (lr) defines the adjustment in the weights of our network with respect to the loss gradient descent. Default is 0.001.
-        maximize : float , optional
-            maximize the params based on the objective, instead of minimizing. Default is False.
-        weightDecay : float , optional
-            weightDecay (L2 penalty) is a regularization technique applied to the weights of a neural network. Default is 0.0.
+        self._num_outputs: int = 1
 
-            The dictionary of the optimizer parameters. The dictionary contains the following keys and values:
-            - "name" (str): The name of the optimizer
-            - "amsgrad" (bool):
-            - "betas" (tuple):
-            - "eps" (float):
-            - "lr" (float):
-            - "maximize" (bool):
-            - weightDecay (float):
+        self._load_csv()
+        self._build_data_list()
+        self._split_holdout()
 
-        return {"name":name, "amsgrad":amsgrad, "betas":betas, "eps":eps, "lr": lr, "maximize":maximize, "weight_decay":weightDecay, "rho":rho, "lr_decay":lr_decay}
-
-    @staticmethod
-    def Hyperparameters(optimizer, model_type="classifier", cv_type="Holdout", split=[0.8,0.1,0.1], k_folds=5,
-                        hl_widths=[32], conv_layer_type="SAGEConv", pooling="AvgPooling",
-                        batch_size=1, epochs=1, use_gpu=False, loss_function="Cross Entropy",
-                        input_type="graph"):
-        """
-        Creates a hyperparameters object based on the input settings.
+        self._build_model()
 
-        Parameters
-        ----------
-        model_type : str , optional
-            The desired type of model. The options are:
-            - "Classifier"
-            - "Regressor"
-            The option is case insensitive. Default is "classifierholdout"
-        optimizer : Optimizer
-            The desired optimizer.
-        cv_type : str , optional
-            The desired cross-validation method. This can be "Holdout" or "K-Fold". It is case-insensitive. Default is "Holdout".
-        split : list , optional
-            The desired split between training validation, and testing. [0.8, 0.1, 0.1] means that 80% of the data is used for training 10% of the data is used for validation, and 10% is used for testing. Default is [0.8, 0.1, 0.1].
-        k_folds : int , optional
-            The desired number of k-folds. Default is 5.
-        hl_widths : list , optional
-            The list of hidden layer widths. A list of [16, 32, 16] means that the model will have 3 hidden layers with number of neurons in each being 16, 32, 16 respectively from input to output. Default is [32].
-        conv_layer_type : str , optional
-            The desired type of the convolution layer. The options are "Classic", "GraphConv", "GINConv", "SAGEConv", "TAGConv", "DGN". It is case insensitive. Default is "SAGEConv".
-        pooling : str , optional
-            The desired type of pooling. The options are "AvgPooling", "MaxPooling", or "SumPooling". It is case insensitive. Default is "AvgPooling".
-        batch_size : int , optional
-            The desired batch size. Default is 1.
-        epochs : int , optional
-            The desired number of epochs. Default is 1.
-        use_gpu : bool , optional
-            If set to True, the model will attempt to use the GPU. Default is False.
-        loss_function : str , optional
-            The desired loss function. The options are "Cross-Entropy" or "Negative Log Likelihood". It is case insensitive. Default is "Cross-Entropy".
-        input_type : str
-            selects the input_type of model such as graph, node or edge
-        Returns
-        -------
-        Hyperparameters
-            The created hyperparameters object.
 
+    # ----------------------------
+    # Convenience: hyperparameters
+    # ----------------------------
+    def SetHyperparameters(self, **kwargs) -> Dict[str, Union[str, int, float, bool, Tuple]]:
         """
-[…removed blank lines…]
-            split,
-            k_folds,
-            hl_widths,
-            conv_layer_type,
-            pooling,
-            batch_size,
-            epochs,
-            use_gpu,
-            loss_function,
-            input_type)
-
-    @staticmethod
-    def Model(hparams, trainingDataset, validationDataset=None, testingDataset=None):
+        Set one or more configuration values (hyperparameters) in a safe, readable way.
+
+        Examples
+        --------
+        pyg.SetHyperparameters(
+            cv="kfold", k_folds=5, k_stratify=True,
+            conv="gatv2", hidden_dims=(128, 128, 64),
+            activation="gelu", batch_norm=True, residual=True,
+            lr=1e-3, optimizer="adamw",
+            early_stopping=True, early_stopping_patience=10,
+            gradient_clip_norm=1.0
+        )
+
+        Notes
+        -----
+        - Unknown keys are ignored (with a warning if verbose=True).
+        - Some values are validated (e.g., split sums, hidden_dims).
+        - If you change model-related settings (conv/hidden_dims/etc.) the model is rebuilt automatically.
         """
-[…removed blank lines…]
+        cfg = self.config
+        changed_model = False
+
+        for k, v in kwargs.items():
+            if not hasattr(cfg, k):
+                if cfg.verbose:
+                    print(f"PyG.SetHyperparameters - Warning: Unknown parameter '{k}' ignored.")
+                continue
+
+            # Basic validation / normalisation
+            if k == "split":
+                if (not isinstance(v, (tuple, list))) or len(v) != 3:
+                    raise ValueError("split must be a 3-tuple, e.g. (0.8, 0.1, 0.1).")
+                s = float(v[0]) + float(v[1]) + float(v[2])
+                if abs(s - 1.0) > 1e-3:
+                    raise ValueError("split ratios must sum to 1.")
+                v = (float(v[0]), float(v[1]), float(v[2]))
+
+            if k == "hidden_dims":
+                if isinstance(v, list):
+                    v = tuple(v)
+                if (not isinstance(v, tuple)) or len(v) == 0:
+                    raise ValueError("hidden_dims must be a non-empty tuple, e.g. (64, 64).")
+                v = tuple(int(x) for x in v)
+                changed_model = True
+
+            if k in ["conv", "activation", "dropout", "batch_norm", "residual", "pooling"]:
+                changed_model = True
+
+            setattr(cfg, k, v)
+
+        # rebuild model if needed
+        if changed_model:
+            self._build_model()
+
+        return self.Summary()
+
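The split check inside SetHyperparameters is worth seeing on its own: three ratios that must sum to 1 within a small tolerance, normalised to a float tuple. A standalone sketch of the same logic:

    def validate_split(v):
        if (not isinstance(v, (tuple, list))) or len(v) != 3:
            raise ValueError("split must be a 3-tuple, e.g. (0.8, 0.1, 0.1).")
        if abs(sum(float(x) for x in v) - 1.0) > 1e-3:
            raise ValueError("split ratios must sum to 1.")
        return tuple(float(x) for x in v)

    print(validate_split([0.8, 0.1, 0.1]))   # (0.8, 0.1, 0.1)
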
+    def Summary(self) -> Dict[str, Union[str, int, float, bool, Tuple]]:
         """
-
-        model = None
-        if hparams.model_type.lower() == "classifier":
-            if hparams.input_type == 'graph':
-                if hparams.cv_type.lower() == "holdout":
-                    model = _GraphClassifierHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
-                elif "k" in hparams.cv_type.lower():
-                    model = _GraphClassifierKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
-            elif hparams.input_type == 'node':
-                if hparams.cv_type.lower() == "holdout":
-                    model = _NodeClassifierHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
-                elif "k" in hparams.cv_type.lower():
-                    model = _NodeClassifierKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
-        elif hparams.model_type.lower() == "regressor":
-            if hparams.input_type == 'graph':
-                if hparams.cv_type.lower() == "holdout":
-                    model = _GraphRegressorHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
-                elif "k" in hparams.cv_type.lower():
-                    model = _GraphRegressorKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
-            elif hparams.input_type == 'node':
-                if hparams.cv_type.lower() == "holdout":
-                    model = _NodeRegressorHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
-                elif "k" in hparams.cv_type.lower():
-                    model = _NodeRegressorKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
-        else:
-            raise NotImplementedError
-        return model
-
-    @staticmethod
-    def ModelTrain(model):
-        """
-        Trains the neural network model.
-
-        Parameters
-        ----------
-        model : Model
-            The input model.
-
-        Returns
-        -------
-        Model
-            The trained model
-
+        Return a compact dictionary of the most relevant configuration choices.
         """
-[…removed blank lines…]
+        cfg = self.config
+        return {
+            "level": cfg.level,
+            "task": cfg.task,
+            "graph_label_type": cfg.graph_label_type,
+            "node_label_type": cfg.node_label_type,
+            "edge_label_type": cfg.edge_label_type,
+            "cv": cfg.cv,
+            "split": cfg.split,
+            "k_folds": cfg.k_folds,
+            "conv": cfg.conv,
+            "hidden_dims": cfg.hidden_dims,
+            "activation": cfg.activation,
+            "dropout": cfg.dropout,
+            "batch_norm": cfg.batch_norm,
+            "residual": cfg.residual,
+            "pooling": cfg.pooling,
+            "epochs": cfg.epochs,
+            "batch_size": cfg.batch_size,
+            "lr": cfg.lr,
+            "weight_decay": cfg.weight_decay,
+            "optimizer": cfg.optimizer,
+            "gradient_clip_norm": cfg.gradient_clip_norm,
+            "early_stopping": cfg.early_stopping,
+            "early_stopping_patience": cfg.early_stopping_patience,
+            "device": str(self.device),
+            "num_graphs": len(self.data_list),
+            "num_outputs": int(self._num_outputs),
+        }
+
+    # ----------------------------
+    # Convenience: CV visualisation
+    # ----------------------------
+    def PlotCrossValidationSummary(self,
+                                   cv_report: Optional[Dict[str, Union[float, List[Dict[str, float]]]]] = None,
+                                   metrics: Optional[List[str]] = None,
+                                   show_mean_std: bool = True):
         """
-
+        Plot a cross-validation summary as grouped bars per fold (Plotly).
 
         Parameters
         ----------
-
-
+        cv_report : dict, optional
+            Output from CrossValidate(). If None, uses self.cv_report.
+        metrics : list[str], optional
+            Metrics to display. If None, chooses a sensible default based on task:
+            - classification: ["accuracy", "f1", "precision", "recall"]
+            - regression    : ["mae", "rmse", "r2"]
+        show_mean_std : bool, optional
+            If True, includes mean and ±std reference lines.
 
         Returns
         -------
-
-            The tested model
-
+        plotly.graph_objects.Figure
         """
-        if
-[…removed blank lines…]
+        if cv_report is None:
+            cv_report = self.cv_report
+        if cv_report is None:
+            raise ValueError("No cross-validation report found. Run CrossValidate() first or pass cv_report.")
+
+        fold_metrics = cv_report.get("fold_metrics", [])
+        if not fold_metrics:
+            raise ValueError("cv_report has no fold_metrics.")
+
+        # default metrics
+        if metrics is None:
+            if self.config.task == "regression":
+                metrics = ["mae", "rmse", "r2"]
+            else:
+                metrics = ["accuracy", "f1", "precision", "recall"]
+
+        folds = [int(fm.get("fold", i)) for i, fm in enumerate(fold_metrics)]
+
+        fig = go.Figure()
+        for met in metrics:
+            vals = [float(fm.get(met, 0.0)) for fm in fold_metrics]
+            fig.add_trace(go.Bar(name=met, x=folds, y=vals))
+
+            if show_mean_std:
+                mean_k = f"mean_{met}"
+                std_k = f"std_{met}"
+                if mean_k in cv_report and std_k in cv_report:
+                    mu = float(cv_report[mean_k])
+                    sd = float(cv_report[std_k])
+                    # mean line
+                    fig.add_trace(go.Scatter(
+                        x=[min(folds), max(folds)], y=[mu, mu],
+                        mode="lines", name=f"{met} mean", line=dict(dash="dash")
+                    ))
+                    # +/- std (as band using two lines)
+                    fig.add_trace(go.Scatter(
+                        x=[min(folds), max(folds)], y=[mu + sd, mu + sd],
+                        mode="lines", name=f"{met} +std", line=dict(dash="dot")
+                    ))
+                    fig.add_trace(go.Scatter(
+                        x=[min(folds), max(folds)], y=[mu - sd, mu - sd],
+                        mode="lines", name=f"{met} -std", line=dict(dash="dot")
+                    ))
+
+        fig.update_layout(
+            barmode="group",
+            title="Cross-Validation Summary",
+            xaxis_title="Fold",
+            yaxis_title="Metric Value"
+        )
+        return fig
+
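The plotting method consumes a report dictionary with per-fold metric dictionaries plus optional mean_*/std_* aggregates. Its expected shape, sketched with dummy numbers (CrossValidate(), which would normally produce it, is outside this hunk):

    cv_report = {
        "fold_metrics": [
            {"fold": 0, "accuracy": 0.81, "f1": 0.79},
            {"fold": 1, "accuracy": 0.77, "f1": 0.74},
        ],
        "mean_accuracy": 0.79,
        "std_accuracy": 0.02,
    }
    # fig = pyg.PlotCrossValidationSummary(cv_report=cv_report, metrics=["accuracy", "f1"])
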
+    # ----------------
+    # Dataset loading
+    # ----------------
+    def _load_csv(self):
+        if not isinstance(self.path, str) or (not os.path.exists(self.path)):
+            raise ValueError("PyG - Error: path does not exist.")
+
+        gpath = os.path.join(self.path, "graphs.csv")
+        npath = os.path.join(self.path, "nodes.csv")
+        epath = os.path.join(self.path, "edges.csv")
+
+        if not os.path.exists(gpath) or not os.path.exists(npath) or not os.path.exists(epath):
+            raise ValueError("PyG - Error: graphs.csv, nodes.csv, edges.csv must exist in the folder.")
+
+        self.graph_df = pd.read_csv(gpath)
+        self.nodes_df = pd.read_csv(npath)
+        self.edges_df = pd.read_csv(epath)
+
+    def _feature_columns(self, df: pd.DataFrame, prefix: str) -> List[str]:
+        cols = [c for c in df.columns if c.startswith(prefix + "_")]
+        def _key(c):
+            parts = c.rsplit("_", 1)
+            if len(parts) == 2 and parts[1].isdigit():
+                return int(parts[1])
+            return 10**9
+        return sorted(cols, key=_key)
 
-        """
-        import os
-
-        if model == None:
-            print("PyG.ModelSave - Error: The input model parameter is invalid. Returning None.")
-            return None
-        if path == None:
-            print("PyG.ModelSave - Error: The input path parameter is invalid. Returning None.")
-            return None
-        if not overwrite and os.path.exists(path):
-            print("PyG.ModelSave - Error: a file already exists at the specified path and overwrite is set to False. Returning None.")
-            return None
-        if overwrite and os.path.exists(path):
-            os.remove(path)
-        # Make sure the file extension is .pt
-        ext = path[len(path)-3:len(path)]
-        if ext.lower() != ".pt":
-            path = path+".pt"
-        model.save(path)
-        return True
-
     @staticmethod
-    def
-[…removed blank lines…]
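_feature_columns sorts by the numeric suffix so that "x_2" precedes "x_10", with non-numeric suffixes pushed to the end. The same key function as a hypothetical free-function sketch:

    def sort_feature_columns(cols, prefix="feat"):
        cols = [c for c in cols if c.startswith(prefix + "_")]
        def _key(c):
            parts = c.rsplit("_", 1)
            return int(parts[1]) if len(parts) == 2 and parts[1].isdigit() else 10**9
        return sorted(cols, key=_key)

    print(sort_feature_columns(["feat_10", "feat_2", "feat_x", "other"]))
    # ['feat_2', 'feat_10', 'feat_x']
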
+    def _infer_num_classes(values: np.ndarray) -> int:
+        uniq = np.unique(values[~pd.isna(values)])
+        return int(len(uniq))
+
+    def _build_data_list(self):
+        assert self.graph_df is not None and self.nodes_df is not None and self.edges_df is not None
+
+        cfg = self.config
+        gdf = self.graph_df
+        ndf = self.nodes_df
+        edf = self.edges_df
+
+        graph_feat_cols = self._feature_columns(gdf, cfg.graph_features_header)
+        node_feat_cols = self._feature_columns(ndf, cfg.node_features_header)
+        edge_feat_cols = self._feature_columns(edf, cfg.edge_features_header)
+
+        if len(node_feat_cols) == 0:
+            raise ValueError(
+                f"PyG - Error: No node feature columns found. "
+                f"Expected columns starting with '{cfg.node_features_header}_'."
+            )
+
+        for gid in gdf[cfg.graph_id_header].unique():
+            g_row = gdf[gdf[cfg.graph_id_header] == gid]
+            g_nodes = ndf[ndf[cfg.graph_id_header] == gid].sort_values(cfg.node_id_header)
+            g_edges = edf[edf[cfg.graph_id_header] == gid]
+
+            x = torch.tensor(g_nodes[node_feat_cols].values, dtype=torch.float32)
+            edge_index = torch.tensor(
+                g_edges[[cfg.edge_src_header, cfg.edge_dst_header]].values.T,
+                dtype=torch.long
+            )
+
+            data = Data(x=x, edge_index=edge_index)
+
+            if len(edge_feat_cols) > 0:
+                data.edge_attr = torch.tensor(g_edges[edge_feat_cols].values, dtype=torch.float32)
+
+            # graph-level
+            if cfg.level == "graph":
+                y_val = g_row[cfg.graph_label_header].values[0]
+                if cfg.graph_label_type == "categorical":
+                    data.y = torch.tensor([int(y_val)], dtype=torch.long)
+                else:
+                    data.y = torch.tensor([float(y_val)], dtype=torch.float32)
 
-        data = {'Model Type': [model.hparams.model_type],
-                'Optimizer': [model.hparams.optimizer_str],
-                'CV Type': [model.hparams.cv_type],
-                'Split': model.hparams.split,
-                'K-Folds': [model.hparams.k_folds],
-                'HL Widths': model.hparams.hl_widths,
-                'Conv Layer Type': [model.hparams.conv_layer_type],
-                'Pooling': [model.hparams.pooling],
-                'Learning Rate': [model.hparams.lr],
-                'Batch Size': [model.hparams.batch_size],
-                'Epochs': [model.hparams.epochs]
-                }
-
-        if model.hparams.model_type.lower() == "classifier":
-            testing_accuracy_list = [model.testing_accuracy] * model.hparams.epochs
-            try:
-                testing_loss_list = [model.testing_loss] * model.hparams.epochs
-            except:
-                testing_loss_list = [0.] * model.hparams.epochs
-            metrics_data = {
-                'Training Accuracy': [model.training_accuracy_list],
-                'Validation Accuracy': [model.validation_accuracy_list],
-                'Testing Accuracy' : [testing_accuracy_list],
-                'Training Loss': [model.training_loss_list],
-                'Validation Loss': [model.validation_loss_list],
-                'Testing Loss' : [testing_loss_list]
-            }
-            if model.hparams.cv_type.lower() == "k-fold":
-                accuracy_data = {
-                    'Accuracies' : [model.accuracies],
-                    'Max Accuracy' : [model.max_accuracy]
-                }
-                metrics_data.update(accuracy_data)
-            data.update(metrics_data)
-
-        elif model.hparams.model_type.lower() == "regressor":
-            testing_loss_list = [model.testing_loss] * model.hparams.epochs
-            metrics_data = {
-                'Training Loss': [model.training_loss_list],
-                'Validation Loss': [model.validation_loss_list],
-                'Testing Loss' : [testing_loss_list]
-            }
-            if model.hparams.cv_type.lower() == "k-fold":
-                loss_data = {
-                    'Losses' : [model.losses],
-                    'Min Loss' : [model.min_loss]
-                }
-                metrics_data.update(loss_data)
-            data.update(metrics_data)
-
-        return data
-
-    @staticmethod
-    def Show(data,
-             labels,
-             title="Training/Validation",
-             xTitle="Epochs",
-             xSpacing=1,
-             yTitle="Accuracy and Loss",
-             ySpacing=0.1,
-             useMarkers=False,
-             chartType="Line",
-             width=950,
-             height=500,
-             backgroundColor='rgba(0,0,0,0)',
-             gridColor='lightgray',
-             marginLeft=0,
-             marginRight=0,
-             marginTop=40,
-             marginBottom=0,
-             renderer = "notebook"):
-        """
-        Shows the data in a plolty graph.
if len(graph_feat_cols) > 0:
|
|
652
|
+
data.u = torch.tensor(g_row[graph_feat_cols].values[0], dtype=torch.float32)
|
|
2054
653
|
|
|
2055
|
-
|
|
2056
|
-
|
|
2057
|
-
|
|
2058
|
-
|
|
2059
|
-
|
|
2060
|
-
|
|
2061
|
-
|
|
2062
|
-
|
|
2063
|
-
|
|
2064
|
-
|
|
2065
|
-
|
|
2066
|
-
|
|
2067
|
-
|
|
2068
|
-
|
|
2069
|
-
|
|
2070
|
-
|
|
2071
|
-
|
|
2072
|
-
|
|
2073
|
-
|
|
2074
|
-
|
|
2075
|
-
|
|
2076
|
-
|
|
2077
|
-
|
|
2078
|
-
|
|
2079
|
-
|
|
2080
|
-
|
|
2081
|
-
|
|
2082
|
-
|
|
2083
|
-
|
|
2084
|
-
|
|
2085
|
-
|
|
2086
|
-
|
|
2087
|
-
|
|
2088
|
-
|
|
2089
|
-
|
|
2090
|
-
|
|
2091
|
-
|
|
2092
|
-
|
|
2093
|
-
|
|
2094
|
-
|
|
2095
|
-
|
|
2096
|
-
|
|
2097
|
-
|
|
2098
|
-
|
|
2099
|
-
|
|
2100
|
-
|
|
2101
|
-
|
|
2102
|
-
|
|
2103
|
-
|
|
2104
|
-
|
|
654
|
+
# node-level
|
|
655
|
+
if cfg.level == "node":
|
|
656
|
+
y_vals = g_nodes[cfg.node_label_header].values
|
|
657
|
+
if cfg.node_label_type == "categorical":
|
|
658
|
+
data.y = torch.tensor(y_vals.astype(int), dtype=torch.long)
|
|
659
|
+
else:
|
|
660
|
+
data.y = torch.tensor(y_vals.astype(float), dtype=torch.float32)
|
|
661
|
+
data.train_mask, data.val_mask, data.test_mask = self._get_or_make_node_masks(g_nodes)
|
|
662
|
+
|
|
663
|
+
# edge-level
|
|
664
|
+
if cfg.level == "edge":
|
|
665
|
+
y_vals = g_edges[cfg.edge_label_header].values
|
|
666
|
+
if cfg.edge_label_type == "categorical":
|
|
667
|
+
data.edge_y = torch.tensor(y_vals.astype(int), dtype=torch.long)
|
|
668
|
+
else:
|
|
669
|
+
data.edge_y = torch.tensor(y_vals.astype(float), dtype=torch.float32)
|
|
670
|
+
data.edge_train_mask, data.edge_val_mask, data.edge_test_mask = self._get_or_make_edge_masks(g_edges)
|
|
671
|
+
|
|
672
|
+
self.data_list.append(data)
|
|
673
|
+
|
|
674
|
+
# output dimensionality
|
|
675
|
+
if cfg.level == "graph":
|
|
676
|
+
self._num_outputs = self._infer_num_classes(gdf[cfg.graph_label_header].values) if cfg.graph_label_type == "categorical" else 1
|
|
677
|
+
elif cfg.level == "node":
|
|
678
|
+
self._num_outputs = self._infer_num_classes(ndf[cfg.node_label_header].values) if cfg.node_label_type == "categorical" else 1
|
|
679
|
+
elif cfg.level == "edge":
|
|
680
|
+
self._num_outputs = self._infer_num_classes(edf[cfg.edge_label_header].values) if cfg.edge_label_type == "categorical" else 1
|
|
681
|
+
elif cfg.level == "link":
|
|
682
|
+
self._num_outputs = 1
|
|
683
|
+
else:
|
|
684
|
+
raise ValueError("Unsupported level.")
|
|
685
|
+
|
|
686
|
+
def _get_or_make_node_masks(self, g_nodes: pd.DataFrame):
|
|
687
|
+
cfg = self.config
|
|
688
|
+
cols = g_nodes.columns
|
|
689
|
+
|
|
690
|
+
if (cfg.node_train_mask_header in cols) and (cfg.node_val_mask_header in cols) and (cfg.node_test_mask_header in cols):
|
|
691
|
+
train_mask = torch.tensor(g_nodes[cfg.node_train_mask_header].astype(bool).values, dtype=torch.bool)
|
|
692
|
+
val_mask = torch.tensor(g_nodes[cfg.node_val_mask_header].astype(bool).values, dtype=torch.bool)
|
|
693
|
+
test_mask = torch.tensor(g_nodes[cfg.node_test_mask_header].astype(bool).values, dtype=torch.bool)
|
|
694
|
+
return train_mask, val_mask, test_mask
|
|
695
|
+
|
|
696
|
+
n = len(g_nodes)
|
|
697
|
+
idx = list(range(n))
|
|
698
|
+
if cfg.shuffle:
|
|
699
|
+
random.Random(cfg.random_state).shuffle(idx)
|
|
700
|
+
|
|
701
|
+
n_train = max(1, int(cfg.split[0] * n))
|
|
702
|
+
n_val = max(1, int(cfg.split[1] * n))
|
|
703
|
+
n_test = max(0, n - n_train - n_val)
|
|
704
|
+
|
|
705
|
+
train_idx = set(idx[:n_train])
|
|
706
|
+
val_idx = set(idx[n_train:n_train + n_val])
|
|
707
|
+
test_idx = set(idx[n_train + n_val:n_train + n_val + n_test])
|
|
708
|
+
|
|
709
|
+
train_mask = torch.tensor([i in train_idx for i in range(n)], dtype=torch.bool)
|
|
710
|
+
val_mask = torch.tensor([i in val_idx for i in range(n)], dtype=torch.bool)
|
|
711
|
+
test_mask = torch.tensor([i in test_idx for i in range(n)], dtype=torch.bool)
|
|
712
|
+
return train_mask, val_mask, test_mask
|
|
713
|
+
|
|
714
|
+
def _get_or_make_edge_masks(self, g_edges: pd.DataFrame):
|
|
715
|
+
cfg = self.config
|
|
716
|
+
cols = g_edges.columns
|
|
717
|
+
|
|
718
|
+
if (cfg.edge_train_mask_header in cols) and (cfg.edge_val_mask_header in cols) and (cfg.edge_test_mask_header in cols):
|
|
719
|
+
train_mask = torch.tensor(g_edges[cfg.edge_train_mask_header].astype(bool).values, dtype=torch.bool)
|
|
720
|
+
val_mask = torch.tensor(g_edges[cfg.edge_val_mask_header].astype(bool).values, dtype=torch.bool)
|
|
721
|
+
test_mask = torch.tensor(g_edges[cfg.edge_test_mask_header].astype(bool).values, dtype=torch.bool)
|
|
722
|
+
return train_mask, val_mask, test_mask
|
|
723
|
+
|
|
724
|
+
n = len(g_edges)
|
|
725
|
+
idx = list(range(n))
|
|
726
|
+
if cfg.shuffle:
|
|
727
|
+
random.Random(cfg.random_state).shuffle(idx)
|
|
728
|
+
|
|
729
|
+
n_train = max(1, int(cfg.split[0] * n))
|
|
730
|
+
n_val = max(1, int(cfg.split[1] * n))
|
|
731
|
+
n_test = max(0, n - n_train - n_val)
|
|
732
|
+
|
|
733
|
+
train_idx = set(idx[:n_train])
|
|
734
|
+
val_idx = set(idx[n_train:n_train + n_val])
|
|
735
|
+
test_idx = set(idx[n_train + n_val:n_train + n_val + n_test])
|
|
736
|
+
|
|
737
|
+
train_mask = torch.tensor([i in train_idx for i in range(n)], dtype=torch.bool)
|
|
738
|
+
val_mask = torch.tensor([i in val_idx for i in range(n)], dtype=torch.bool)
|
|
739
|
+
test_mask = torch.tensor([i in test_idx for i in range(n)], dtype=torch.bool)
|
|
740
|
+
return train_mask, val_mask, test_mask
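
The two mask builders above prefer explicit boolean mask columns in the per-graph node/edge tables and fall back to a ratio-based split only when those columns are absent. A minimal sketch of that fallback arithmetic, assuming hypothetical values of cfg.split == (0.8, 0.1, 0.1), cfg.shuffle == False and n == 10 (none of these defaults are visible in this hunk):

    import torch

    n = 10
    split = (0.8, 0.1, 0.1)
    idx = list(range(n))

    n_train = max(1, int(split[0] * n))   # 8
    n_val = max(1, int(split[1] * n))     # 1
    n_test = max(0, n - n_train - n_val)  # 1

    train_idx = set(idx[:n_train])
    train_mask = torch.tensor([i in train_idx for i in range(n)], dtype=torch.bool)
    print(int(train_mask.sum()))          # 8

The val/test masks follow the same slicing. Note that max(1, ...) guarantees at least one training and one validation element, so very small graphs can end up with an empty test set.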
+
+    # ----------------------------
+    # Holdout split (graph-level)
+    # ----------------------------
+    def _split_holdout(self):
+        cfg = self.config
+        if cfg.level in ["node", "edge", "link"]:
+            self.train_set = self.data_list
+            self.val_set = self.data_list
+            self.test_set = self.data_list
+            return
+
+        n = len(self.data_list)
+        idx = list(range(n))
+        if cfg.shuffle:
+            random.Random(cfg.random_state).shuffle(idx)
+
+        n_train = max(1, int(cfg.split[0] * n))
+        n_val = max(1, int(cfg.split[1] * n))
+        n_test = max(0, n - n_train - n_val)
+
+        train_idx = idx[:n_train]
+        val_idx = idx[n_train:n_train + n_val]
+        test_idx = idx[n_train + n_val:n_train + n_val + n_test]
+
+        self.train_set = [self.data_list[i] for i in train_idx]
+        self.val_set = [self.data_list[i] for i in val_idx]
+        self.test_set = [self.data_list[i] for i in test_idx]
+
+    # --------------
+    # Model building
+    # --------------
+    def _build_model(self):
+        cfg = self.config
+        in_dim = int(self.data_list[0].x.shape[1])
+
+        encoder = _GNNBackbone(
+            in_dim=in_dim,
+            hidden_dims=cfg.hidden_dims,
+            conv=cfg.conv,
+            activation=cfg.activation,
+            dropout=cfg.dropout,
+            batch_norm=cfg.batch_norm,
+            residual=cfg.residual
+        )
+
+        if cfg.level == "graph":
+            head = _GraphHead(encoder.out_dim, self._num_outputs, pooling=cfg.pooling, dropout=cfg.dropout)
+            self.model = nn.ModuleDict({"encoder": encoder, "head": head}).to(self.device)
+        elif cfg.level == "node":
+            head = _NodeHead(encoder.out_dim, self._num_outputs, dropout=cfg.dropout)
+            self.model = nn.ModuleDict({"encoder": encoder, "head": head}).to(self.device)
+        elif cfg.level == "edge":
+            head = _EdgeHead(encoder.out_dim, self._num_outputs, dropout=cfg.dropout)
+            self.model = nn.ModuleDict({"encoder": encoder, "head": head}).to(self.device)
+        elif cfg.level == "link":
+            predictor = _LinkPredictor(encoder.out_dim, hidden=max(32, encoder.out_dim), dropout=cfg.dropout)
+            self.model = nn.ModuleDict({"encoder": encoder, "predictor": predictor}).to(self.device)
+        else:
+            raise ValueError("Unsupported level.")
 
-
-
-
+        if cfg.optimizer == "adamw":
+            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
+        else:
+            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
 
+        if cfg.level == "link":
+            self.criterion = nn.BCEWithLogitsLoss()
+        else:
+            if cfg.task == "regression":
+                self.criterion = nn.MSELoss()
+            else:
+                self.criterion = nn.CrossEntropyLoss()
+
+    def _apply_gradients(self):
+        cfg = self.config
+        if cfg.gradient_clip_norm is not None:
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), float(cfg.gradient_clip_norm))
+        self.optimizer.step()
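
_apply_gradients clips the global gradient norm before stepping the optimizer whenever cfg.gradient_clip_norm is set. A self-contained sketch of that clipping call (the layer and max_norm value are illustrative, not taken from this diff):

    import torch

    layer = torch.nn.Linear(4, 1)
    loss = layer(torch.randn(8, 4)).pow(2).mean()
    loss.backward()

    # clip_grad_norm_ rescales all gradients in place so their joint norm is
    # at most max_norm, and returns the norm they had before clipping.
    total_norm = torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=1.0)
    print(float(total_norm))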
+
+    # -----------------------
+    # Training / evaluation
+    # -----------------------
+    def Train(self, epochs: Optional[int] = None, batch_size: Optional[int] = None) -> Dict[str, List[float]]:
         """
-
-
-        dataFrame = Plotly.DataByDGL(data, labels)
-        fig = Plotly.FigureByDataFrame(dataFrame,
-                                       labels=labels,
-                                       title=title,
-                                       xTitle=xTitle,
-                                       xSpacing=xSpacing,
-                                       yTitle=yTitle,
-                                       ySpacing=ySpacing,
-                                       useMarkers=useMarkers,
-                                       chartType=chartType,
-                                       width=width,
-                                       height=height,
-                                       backgroundColor=backgroundColor,
-                                       gridColor=gridColor,
-                                       marginRight=marginRight,
-                                       marginLeft=marginLeft,
-                                       marginTop=marginTop,
-                                       marginBottom=marginBottom
-                                       )
-        Plotly.Show(fig, renderer=renderer)
-
-    @staticmethod
-    def ModelLoad(path, model):
-        """
-        Returns the model found at the input file path.
-
-        Parameters
-        ----------
-        path : str
-            File path for the saved classifier.
-        model : torch.nn.module
-            Initialized instance of model
+        Train a model using holdout splitting (or in-graph masks for node/edge tasks).
 
-
-        -------
-        PyG Classifier
-            The classifier.
-
-        """
-        if not path:
-            return None
-
-        model.load(path)
-        return model
-
-    @staticmethod
-    def ConfusionMatrix(actual, predicted, normalize=False):
+        If you want k-fold cross-validation for graph-level tasks, call CrossValidate().
         """
-
+        cfg = self.config
+        if epochs is not None:
+            cfg.epochs = int(epochs)
+        if batch_size is not None:
+            cfg.batch_size = int(batch_size)
+
+        self.history = {"train_loss": [], "val_loss": []}
+
+        if cfg.level == "graph":
+            train_loader = DataLoader(self.train_set, batch_size=cfg.batch_size, shuffle=True)
+            val_loader = DataLoader(self.val_set, batch_size=cfg.batch_size, shuffle=False)
+
+            best_val = float("inf")
+            patience = 0
+
+            for _ in range(cfg.epochs):
+                tr = self._train_epoch_graph(train_loader)
+                va = self._eval_epoch_graph(val_loader)
+                self.history["train_loss"].append(tr)
+                self.history["val_loss"].append(va)
+
+                if cfg.early_stopping:
+                    if va < best_val - 1e-9:
+                        best_val = va
+                        patience = 0
+                    else:
+                        patience += 1
+                        if patience >= int(cfg.early_stopping_patience):
+                            break
+
+        elif cfg.level == "node":
+            train_loader = DataLoader(self.data_list, batch_size=1, shuffle=True)
+            val_loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            for _ in range(cfg.epochs):
+                tr = self._train_epoch_node(train_loader)
+                va = self._eval_epoch_node(val_loader)
+                self.history["train_loss"].append(tr)
+                self.history["val_loss"].append(va)
+
+        elif cfg.level == "edge":
+            train_loader = DataLoader(self.data_list, batch_size=1, shuffle=True)
+            val_loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            for _ in range(cfg.epochs):
+                tr = self._train_epoch_edge(train_loader)
+                va = self._eval_epoch_edge(val_loader)
+                self.history["train_loss"].append(tr)
+                self.history["val_loss"].append(va)
+
+        elif cfg.level == "link":
+            train_loader = DataLoader(self.data_list, batch_size=1, shuffle=True)
+            val_loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            for _ in range(cfg.epochs):
+                tr = self._train_epoch_link(train_loader)
+                va = self._eval_epoch_link(val_loader)
+                self.history["train_loss"].append(tr)
+                self.history["val_loss"].append(va)
 
-        Parameters
-        ----------
-        actual : list
-            The input list of actual labels.
-        predicted : list
-            The input list of predicts labels.
-        normalized : bool , optional
-            If set to True, the returned data will be normalized (proportion of 1). Otherwise, actual numbers are returned. Default is False.
-
-        Returns
-        -------
-        list
-            The created confusion matrix.
-
-        """
-
-        try:
-            from sklearn import metrics
-            from sklearn.metrics import accuracy_score
-        except:
-            print("PyG - Installing required scikit-learn (sklearn) library.")
-            try:
-                os.system("pip install scikit-learn")
-            except:
-                os.system("pip install scikit-learn --user")
-            try:
-                from sklearn import metrics
-                from sklearn.metrics import accuracy_score
-                print("PyG - scikit-learn (sklearn) library installed correctly.")
-            except:
-                warnings.warn("PyG - Error: Could not import scikit-learn (sklearn). Please try to install scikit-learn manually. Returning None.")
-                return None
-
-        if not isinstance(actual, list):
-            print("PyG.ConfusionMatrix - ERROR: The actual input is not a list. Returning None")
-            return None
-        if not isinstance(predicted, list):
-            print("PyG.ConfusionMatrix - ERROR: The predicted input is not a list. Returning None")
-            return None
-        if len(actual) != len(predicted):
-            print("PyG.ConfusionMatrix - ERROR: The two input lists do not have the same length. Returning None")
-            return None
-        if normalize:
-            cm = np.transpose(metrics.confusion_matrix(y_true=actual, y_pred=predicted, normalize="true"))
         else:
-
-            return cm
+            raise ValueError("Unsupported level.")
 
-
-    def ModelPredict(model, dataset, nodeATTRKey="feat"):
-        """
-        Predicts the value of the input dataset.
-
-        Parameters
-        ----------
-        dataset : PyGDataset
-            The input PyG dataset.
-        model : Model
-            The input trained model.
-        nodeATTRKey : str , optional
-            The key used for node attributes. Default is "feat".
-
-        Returns
-        -------
-        list
-            The list of predictions
-        """
-        try:
-            model = model.model #The inoput model might be our wrapper model. In that case, get its model attribute to do the prediciton.
-        except:
-            pass
-        values = []
-        dataloader = DataLoader(dataset, batch_size=1, drop_last=False)
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        model.eval()
-        for data in tqdm(dataloader, desc='Predicting', leave=False):
-            data = data.to(device)
-            pred = model(data)
-            values.extend(list(np.round(pred.detach().cpu().numpy().flatten(), 3)))
-        return values
+        return self.history
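
Train drives the whole holdout loop and records per-epoch losses. A hedged usage sketch; `trainer` stands for an instance of the wrapper class defined earlier in this module, already constructed with its config and data (the constructor is outside this hunk, so the name is hypothetical):

    history = trainer.Train(epochs=50, batch_size=32)
    print(history["train_loss"][-1], history["val_loss"][-1])

Note that passing epochs or batch_size mutates the stored config, so later calls reuse the overridden values.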
 
-
-
+    def CrossValidate(self,
+                      k_folds: Optional[int] = None,
+                      epochs: Optional[int] = None,
+                      batch_size: Optional[int] = None) -> Dict[str, Union[float, List[Dict[str, float]]]]:
         """
-
+        Perform K-Fold cross-validation (graph-level only).
 
         Parameters
         ----------
-
-
-
-
-
-
+        k_folds : int, optional
+            Number of folds. Defaults to config.k_folds.
+        epochs : int, optional
+            Training epochs per fold. Defaults to config.epochs.
+        batch_size : int, optional
+            Batch size. Defaults to config.batch_size.
 
         Returns
         -------
         dict
-
-
-
+            {
+              "fold_metrics": [ {metric: value, ...}, ... ],
+              "mean_<metric>": value,
+              "std_<metric>": value
+            }
 
+        Notes
+        -----
+        - Supported only for graph-level tasks (node/edge tasks typically use per-graph masks).
+        - If labels are categorical, stratified splits can be enabled (config.k_stratify).
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        cfg = self.config
+        if cfg.level != "graph":
+            raise ValueError("CrossValidate is supported for graph-level tasks only.")
+
+        if k_folds is None:
+            k_folds = int(cfg.k_folds)
+        if k_folds < 2:
+            raise ValueError("k_folds must be >= 2.")
+        if epochs is not None:
+            cfg.epochs = int(epochs)
+        if batch_size is not None:
+            cfg.batch_size = int(batch_size)
+
+        n = len(self.data_list)
+        indices = np.arange(n)
+
+        # Stratification labels (optional)
+        y = None
+        if cfg.k_stratify and cfg.task == "classification" and cfg.graph_label_type == "categorical":
+            y = np.array([int(d.y.item()) for d in self.data_list], dtype=int)
+
+        rng = np.random.RandomState(cfg.random_state)
+        if cfg.k_shuffle:
+            rng.shuffle(indices)
+
+        # Build folds
+        if y is None:
+            folds = np.array_split(indices, k_folds)
+        else:
+            folds = [np.array([], dtype=int) for _ in range(k_folds)]
+            classes = np.unique(y)
+            for c in classes:
+                cls_idx = indices[y[indices] == c]
+                cls_chunks = np.array_split(cls_idx, k_folds)
+                for fi in range(k_folds):
+                    folds[fi] = np.concatenate([folds[fi], cls_chunks[fi]])
+            folds = [rng.permutation(f) for f in folds]
+
+        fold_metrics: List[Dict[str, float]] = []
+        base_config = copy.deepcopy(cfg)
+
+        for fi in range(k_folds):
+            test_idx = folds[fi]
+            train_idx = np.concatenate([folds[j] for j in range(k_folds) if j != fi])
+
+            train_set = [self.data_list[i] for i in train_idx.tolist()]
+            test_set = [self.data_list[i] for i in test_idx.tolist()]
+
+            # Fresh model per fold
+            self.config = copy.deepcopy(base_config)
+            self._build_model()
+            self.history = {"train_loss": [], "val_loss": []}
+
+            train_loader = DataLoader(train_set, batch_size=self.config.batch_size, shuffle=True)
+            test_loader = DataLoader(test_set, batch_size=self.config.batch_size, shuffle=False)
+
+            best_loss = float("inf")
+            patience = 0
+
+            for _ in range(self.config.epochs):
+                tr_loss = self._train_epoch_graph(train_loader)
+                te_loss = self._eval_epoch_graph(test_loader)
+                self.history["train_loss"].append(tr_loss)
+                self.history["val_loss"].append(te_loss)
+
+                if self.config.early_stopping:
+                    if te_loss < best_loss - 1e-9:
+                        best_loss = te_loss
+                        patience = 0
+                    else:
+                        patience += 1
+                        if patience >= int(self.config.early_stopping_patience):
+                            break
+
+            # Metrics (unprefixed) for the fold
+            metrics = self._metrics_graph(test_loader, prefix="")
+            metrics["fold"] = float(fi)
+            fold_metrics.append(metrics)
+
+        # Restore original config and rebuild model
+        self.config = copy.deepcopy(base_config)
+        self._build_model()
+
+        # Aggregate
+        summary: Dict[str, Union[float, List[Dict[str, float]]]] = {"fold_metrics": fold_metrics}
+        metric_keys = [k for k in fold_metrics[0].keys()] if fold_metrics else []
+        metric_keys = [k for k in metric_keys if k != "fold"]
+
+        for k in metric_keys:
+            vals = np.array([fm[k] for fm in fold_metrics], dtype=float)
+            summary[f"mean_{k}"] = float(np.mean(vals))
+            summary[f"std_{k}"] = float(np.std(vals))
+
+        self.cv_report = summary
+        return summary
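
A hedged sketch of consuming the summary CrossValidate returns; `trainer` is a hypothetical instance of the surrounding wrapper class, and the "accuracy" key assumes a classification task, where per-fold metrics are stored unprefixed:

    report = trainer.CrossValidate(k_folds=5)
    for fm in report["fold_metrics"]:
        print(int(fm["fold"]), fm["accuracy"])
    print(report["mean_accuracy"], report["std_accuracy"])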
+
+    def Validate(self) -> Dict[str, float]:
+        cfg = self.config
+        if cfg.level == "graph":
+            loader = DataLoader(self.val_set, batch_size=cfg.batch_size, shuffle=False)
+            return self._metrics_graph(loader, prefix="val_")
+        if cfg.level == "node":
+            loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            return self._metrics_node(loader, split="val")
+        if cfg.level == "edge":
+            loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            return self._metrics_edge(loader, split="val")
+        if cfg.level == "link":
+            loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            return self._metrics_link(loader, split="val")
+        raise ValueError("Unsupported level.")
+
+    def Test(self) -> Dict[str, float]:
+        cfg = self.config
+        if cfg.level == "graph":
+            loader = DataLoader(self.test_set, batch_size=cfg.batch_size, shuffle=False)
+            return self._metrics_graph(loader, prefix="test_")
+        if cfg.level == "node":
+            loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            return self._metrics_node(loader, split="test")
+        if cfg.level == "edge":
+            loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            return self._metrics_edge(loader, split="test")
+        if cfg.level == "link":
+            loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            return self._metrics_link(loader, split="test")
+        raise ValueError("Unsupported level.")
+
+    # --------
+    # Epochs
+    # --------
+    def _loss_from_logits(self, logits, y, task: TaskKind):
+        if task == "regression":
+            pred = logits.squeeze(-1)
+            return self.criterion(pred.float(), y.float())
+        return self.criterion(logits, y.long())
+
+    def _train_epoch_graph(self, loader):
+        self.model.train()
+        losses = []
+        for batch in loader:
+            batch = batch.to(self.device)
+            self.optimizer.zero_grad()
+            node_emb = self.model["encoder"](batch.x, batch.edge_index)
+            logits = self.model["head"](node_emb, batch.batch)
+            loss = self._loss_from_logits(logits, batch.y, self.config.task)
+            loss.backward()
+            self._apply_gradients()
+            losses.append(float(loss.detach().cpu()))
+        return float(np.mean(losses)) if losses else 0.0
+
+    @torch.no_grad()
+    def _eval_epoch_graph(self, loader):
+        self.model.eval()
+        losses = []
+        for batch in loader:
+            batch = batch.to(self.device)
+            node_emb = self.model["encoder"](batch.x, batch.edge_index)
+            logits = self.model["head"](node_emb, batch.batch)
+            loss = self._loss_from_logits(logits, batch.y, self.config.task)
+            losses.append(float(loss.detach().cpu()))
+        return float(np.mean(losses)) if losses else 0.0
+
+    def _train_epoch_node(self, loader):
+        self.model.train()
+        losses = []
+        for data in loader:
+            data = data.to(self.device)
+            self.optimizer.zero_grad()
+            node_emb = self.model["encoder"](data.x, data.edge_index)
+            logits = self.model["head"](node_emb)
+            mask = data.train_mask
+            loss = self._loss_from_logits(logits[mask], data.y[mask], self.config.task)
+            loss.backward()
+            self._apply_gradients()
+            losses.append(float(loss.detach().cpu()))
+        return float(np.mean(losses)) if losses else 0.0
+
+    @torch.no_grad()
+    def _eval_epoch_node(self, loader):
+        self.model.eval()
+        losses = []
+        for data in loader:
+            data = data.to(self.device)
+            node_emb = self.model["encoder"](data.x, data.edge_index)
+            logits = self.model["head"](node_emb)
+            mask = data.val_mask
+            loss = self._loss_from_logits(logits[mask], data.y[mask], self.config.task)
+            losses.append(float(loss.detach().cpu()))
+        return float(np.mean(losses)) if losses else 0.0
+
+    def _train_epoch_edge(self, loader):
+        self.model.train()
+        losses = []
+        for data in loader:
+            data = data.to(self.device)
+            self.optimizer.zero_grad()
+            node_emb = self.model["encoder"](data.x, data.edge_index)
+            logits = self.model["head"](node_emb, data.edge_index)
+            mask = data.edge_train_mask
+            loss = self._loss_from_logits(logits[mask], data.edge_y[mask], self.config.task)
+            loss.backward()
+            self._apply_gradients()
+            losses.append(float(loss.detach().cpu()))
+        return float(np.mean(losses)) if losses else 0.0
+
+    @torch.no_grad()
+    def _eval_epoch_edge(self, loader):
+        self.model.eval()
+        losses = []
+        for data in loader:
+            data = data.to(self.device)
+            node_emb = self.model["encoder"](data.x, data.edge_index)
+            logits = self.model["head"](node_emb, data.edge_index)
+            mask = data.edge_val_mask
+            loss = self._loss_from_logits(logits[mask], data.edge_y[mask], self.config.task)
+            losses.append(float(loss.detach().cpu()))
+        return float(np.mean(losses)) if losses else 0.0
+
+    def _train_epoch_link(self, loader):
+        self.model.train()
+        losses = []
+        split = RandomLinkSplit(
+            num_val=self.config.link_val_ratio,
+            num_test=self.config.link_test_ratio,
+            is_undirected=self.config.link_is_undirected,
+            add_negative_train_samples=True,
+            neg_sampling_ratio=1.0
+        )
+        for data in loader:
+            train_data, _, _ = split(data)
+            train_data = train_data.to(self.device)
+            self.optimizer.zero_grad()
+            node_emb = self.model["encoder"](train_data.x, train_data.edge_index)
+            logits = self.model["predictor"](node_emb, train_data.edge_label_index)
+            loss = self.criterion(logits, train_data.edge_label.float())
+            loss.backward()
+            self._apply_gradients()
+            losses.append(float(loss.detach().cpu()))
+        return float(np.mean(losses)) if losses else 0.0
+
+    @torch.no_grad()
+    def _eval_epoch_link(self, loader):
+        self.model.eval()
+        losses = []
+        split = RandomLinkSplit(
+            num_val=self.config.link_val_ratio,
+            num_test=self.config.link_test_ratio,
+            is_undirected=self.config.link_is_undirected,
+            add_negative_train_samples=True,
+            neg_sampling_ratio=1.0
+        )
+        for data in loader:
+            _, val_data, _ = split(data)
+            val_data = val_data.to(self.device)
+            node_emb = self.model["encoder"](val_data.x, val_data.edge_index)
+            logits = self.model["predictor"](node_emb, val_data.edge_label_index)
+            loss = self.criterion(logits, val_data.edge_label.float())
+            losses.append(float(loss.detach().cpu()))
+        return float(np.mean(losses)) if losses else 0.0
+
+    # --------------
+    # Metrics helpers
+    # --------------
+    @torch.no_grad()
+    def _predict_graph(self, loader):
+        self.model.eval()
+        y_true, y_pred = [], []
+        for batch in loader:
+            batch = batch.to(self.device)
+            node_emb = self.model["encoder"](batch.x, batch.edge_index)
+            out = self.model["head"](node_emb, batch.batch)
+            if self.config.task == "regression":
+                y_true.extend(batch.y.squeeze(-1).detach().cpu().numpy().tolist())
+                y_pred.extend(out.squeeze(-1).detach().cpu().numpy().tolist())
+            else:
+                probs = F.softmax(out, dim=-1)
+                y_true.extend(batch.y.detach().cpu().numpy().tolist())
+                y_pred.extend(probs.argmax(dim=-1).detach().cpu().numpy().tolist())
+        return np.array(y_true), np.array(y_pred)
 
-
-
-
-
-
-
-
-
-
+    @torch.no_grad()
+    def _predict_node(self, loader, mask_name: str):
+        self.model.eval()
+        y_true, y_pred = [], []
+        for data in loader:
+            data = data.to(self.device)
+            node_emb = self.model["encoder"](data.x, data.edge_index)
+            out = self.model["head"](node_emb)
+            mask = getattr(data, mask_name)
+            if self.config.task == "regression":
+                y_true.extend(data.y[mask].detach().cpu().numpy().tolist())
+                y_pred.extend(out.squeeze(-1)[mask].detach().cpu().numpy().tolist())
             else:
-
-
-
-
-        return {"accuracy":accuracy, "correct":correct, "mask":mask, "size":size, "wrong":wrong}
-
-    @staticmethod
-    def MSE(actual, predicted, mantissa: int = 6):
-        """
-        Computes the Mean Squared Error (MSE) of the input predictions based on the input labels. This is to be used with regression models.
+                probs = F.softmax(out, dim=-1)
+                y_true.extend(data.y[mask].detach().cpu().numpy().tolist())
+                y_pred.extend(probs.argmax(dim=-1)[mask].detach().cpu().numpy().tolist())
+        return np.array(y_true), np.array(y_pred)
 
-
-
-
-
-
-
-
-
+    @torch.no_grad()
+    def _predict_edge(self, loader, mask_name: str):
+        self.model.eval()
+        y_true, y_pred = [], []
+        for data in loader:
+            data = data.to(self.device)
+            node_emb = self.model["encoder"](data.x, data.edge_index)
+            out = self.model["head"](node_emb, data.edge_index)
+            mask = getattr(data, mask_name)
+            if self.config.task == "regression":
+                y_true.extend(data.edge_y[mask].detach().cpu().numpy().tolist())
+                y_pred.extend(out.squeeze(-1)[mask].detach().cpu().numpy().tolist())
+            else:
+                probs = F.softmax(out, dim=-1)
+                y_true.extend(data.edge_y[mask].detach().cpu().numpy().tolist())
+                y_pred.extend(probs.argmax(dim=-1)[mask].detach().cpu().numpy().tolist())
+        return np.array(y_true), np.array(y_pred)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+    @torch.no_grad()
+    def _predict_link(self, loader, split_name: str):
+        self.model.eval()
+        split = RandomLinkSplit(
+            num_val=self.config.link_val_ratio,
+            num_test=self.config.link_test_ratio,
+            is_undirected=self.config.link_is_undirected,
+            add_negative_train_samples=True,
+            neg_sampling_ratio=1.0
+        )
+        y_true, y_score = [], []
+        for data in loader:
+            tr, va, te = split(data)
+            use = {"train": tr, "val": va, "test": te}[split_name]
+            use = use.to(self.device)
+            node_emb = self.model["encoder"](use.x, use.edge_index)
+            logits = self.model["predictor"](node_emb, use.edge_label_index)
+            probs = torch.sigmoid(logits).detach().cpu().numpy()
+            y = use.edge_label.detach().cpu().numpy()
+            y_true.extend(y.tolist())
+            y_score.extend(probs.tolist())
+        return np.array(y_true), np.array(y_score)
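
The link-level helpers rebuild RandomLinkSplit on every call, so positive/negative edge labels are re-sampled each epoch rather than fixed once. A toy illustration of the transform's output as used above; the 4-node graph and split ratios are made up for the example:

    import torch
    from torch_geometric.data import Data
    from torch_geometric.transforms import RandomLinkSplit

    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)
    data = Data(x=torch.randn(4, 3), edge_index=edge_index)

    split = RandomLinkSplit(num_val=0.25, num_test=0.25, is_undirected=False,
                            add_negative_train_samples=True, neg_sampling_ratio=1.0)
    train_data, val_data, test_data = split(data)
    # edge_label holds 1s for real edges and 0s for sampled negatives;
    # edge_label_index holds the corresponding node pairs.
    print(train_data.edge_label.shape, train_data.edge_label_index.shape)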
+
+    # ----------
+    # Public metrics API
+    # ----------
+    def _metrics_graph(self, loader, prefix: str):
+        y_true, y_pred = self._predict_graph(loader)
+        if self.config.task == "regression":
+            return self._regression_metrics(y_true, y_pred, prefix=prefix)
+        return self._classification_metrics(y_true, y_pred, prefix=prefix)
+
+    def _metrics_node(self, loader, split: Literal["train", "val", "test"]):
+        mask = "train_mask" if split == "train" else ("val_mask" if split == "val" else "test_mask")
+        y_true, y_pred = self._predict_node(loader, mask)
+        if self.config.task == "regression":
+            return self._regression_metrics(y_true, y_pred, prefix=f"{split}_")
+        return self._classification_metrics(y_true, y_pred, prefix=f"{split}_")
+
+    def _metrics_edge(self, loader, split: Literal["train", "val", "test"]):
+        mask = "edge_train_mask" if split == "train" else ("edge_val_mask" if split == "val" else "edge_test_mask")
+        y_true, y_pred = self._predict_edge(loader, mask)
+        if self.config.task == "regression":
+            return self._regression_metrics(y_true, y_pred, prefix=f"{split}_")
+        return self._classification_metrics(y_true, y_pred, prefix=f"{split}_")
+
+    def _metrics_link(self, loader, split: Literal["train", "val", "test"]):
+        y_true, y_score = self._predict_link(loader, split)
+        y_pred = (y_score >= 0.5).astype(int)
+        return self._classification_metrics(y_true, y_pred, prefix=f"{split}_")
 
-
+    @staticmethod
+    def _classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, prefix: str = "") -> Dict[str, float]:
+        acc = float(accuracy_score(y_true, y_pred)) if len(y_true) else 0.0
+        prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted", zero_division=0)
+        return {
+            f"{prefix}accuracy": float(acc),
+            f"{prefix}precision": float(prec),
+            f"{prefix}recall": float(rec),
+            f"{prefix}f1": float(f1),
+        }
 
     @staticmethod
-    def
-
-
+    def _regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, prefix: str = "") -> Dict[str, float]:
+        if len(y_true) == 0:
+            return {f"{prefix}mae": 0.0, f"{prefix}rmse": 0.0, f"{prefix}r2": 0.0}
+        mae = float(mean_absolute_error(y_true, y_pred))
+        rmse = float(math.sqrt(mean_squared_error(y_true, y_pred)))
+        r2 = float(r2_score(y_true, y_pred))
+        return {f"{prefix}mae": mae, f"{prefix}rmse": rmse, f"{prefix}r2": r2}
+
+    # -----------------
+    # Plotly reporting
+    # -----------------
+    def PlotHistory(self):
+        fig = go.Figure()
+        fig.add_trace(go.Scatter(y=self.history["train_loss"], mode="lines+markers", name="Train Loss"))
+        fig.add_trace(go.Scatter(y=self.history["val_loss"], mode="lines+markers", name="Val Loss"))
+        fig.update_layout(title="Training History", xaxis_title="Epoch", yaxis_title="Loss")
+        return fig
+
+    def PlotConfusionMatrix(self, split: Literal["train", "val", "test"] = "test"):
+        if self.config.task != "classification" or self.config.level == "link":
+            raise ValueError("Confusion matrix is only available for classification (graph/node/edge).")
+
+        if self.config.level == "graph":
+            if split == "train":
+                loader = DataLoader(self.train_set, batch_size=self.config.batch_size, shuffle=False)
+            elif split == "val":
+                loader = DataLoader(self.val_set, batch_size=self.config.batch_size, shuffle=False)
+            else:
+                loader = DataLoader(self.test_set, batch_size=self.config.batch_size, shuffle=False)
+            y_true, y_pred = self._predict_graph(loader)
+
+        elif self.config.level == "node":
+            loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            mask = "train_mask" if split == "train" else ("val_mask" if split == "val" else "test_mask")
+            y_true, y_pred = self._predict_node(loader, mask)
+
+        else:  # edge
+            loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            mask = "edge_train_mask" if split == "train" else ("edge_val_mask" if split == "val" else "edge_test_mask")
+            y_true, y_pred = self._predict_edge(loader, mask)
+
+        cm = confusion_matrix(y_true, y_pred)
+        fig = px.imshow(cm, text_auto=True, title=f"Confusion Matrix ({split})")
+        fig.update_layout(xaxis_title="Predicted", yaxis_title="True")
+        return fig
+
+    def PlotParity(self, split: Literal["train", "val", "test"] = "test"):
+        if self.config.task != "regression":
+            raise ValueError("Parity plot is only available for regression tasks.")
+
+        if self.config.level == "graph":
+            if split == "train":
+                loader = DataLoader(self.train_set, batch_size=self.config.batch_size, shuffle=False)
+            elif split == "val":
+                loader = DataLoader(self.val_set, batch_size=self.config.batch_size, shuffle=False)
+            else:
+                loader = DataLoader(self.test_set, batch_size=self.config.batch_size, shuffle=False)
+            y_true, y_pred = self._predict_graph(loader)
+
+        elif self.config.level == "node":
+            loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            mask = "train_mask" if split == "train" else ("val_mask" if split == "val" else "test_mask")
+            y_true, y_pred = self._predict_node(loader, mask)
 
-        Parameters
-        ----------
-        actual : list
-            The input list of actual values.
-        predicted : list
-            The input list of predicted values.
-        mantissa : int , optional
-            The number of decimal places to round the result to. Default is 6.
-
-        Returns
-        -------
-        dict
-            The dictionary containing the performance measures. The keys in the dictionary are: 'mae', 'mape', 'mse', 'r', 'r2', 'rmse'.
-        """
-
-        if not isinstance(actual, list):
-            print("PyG.Performance - ERROR: The actual input is not a list. Returning None")
-            return None
-        if not isinstance(predicted, list):
-            print("PyG.Performance - ERROR: The predicted input is not a list. Returning None")
-            return None
-        if not (len(actual) == len(predicted)):
-            print("PyG.Performance - ERROR: The actual and predicted input lists have different lengths. Returning None")
-            return None
-
-        predicted = np.array(predicted)
-        actual = np.array(actual)
-
-        mae = np.mean(np.abs(predicted - actual))
-        mape = np.mean(np.abs((actual - predicted) / actual))*100
-        mse = np.mean((predicted - actual)**2)
-        correlation_matrix = np.corrcoef(predicted, actual)
-        r = correlation_matrix[0, 1]
-        r2 = r**2
-        absolute_errors = np.abs(predicted - actual)
-        mean_actual = np.mean(actual)
-        if mean_actual == 0:
-            rae = None
         else:
-
-
-
-
-
-
-
-
-
-
+            loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
+            mask = "edge_train_mask" if split == "train" else ("edge_val_mask" if split == "val" else "edge_test_mask")
+            y_true, y_pred = self._predict_edge(loader, mask)
+
+        fig = go.Figure()
+        fig.add_trace(go.Scatter(x=y_true, y=y_pred, mode="markers", name="Predictions"))
+        mn = float(min(np.min(y_true), np.min(y_pred))) if len(y_true) else 0.0
+        mx = float(max(np.max(y_true), np.max(y_pred))) if len(y_true) else 1.0
+        fig.add_trace(go.Scatter(x=[mn, mx], y=[mn, mx], mode="lines", name="Ideal"))
+        fig.update_layout(title=f"Parity Plot ({split})", xaxis_title="True", yaxis_title="Predicted")
+        return fig
+
+    def SaveModel(self, path: str):
+        if not path.lower().endswith(".pt"):
+            path = path + ".pt"
+        torch.save(self.model.state_dict(), path)
+
+    def LoadModel(self, path: str):
+        state = torch.load(path, map_location=self.device)
+        self.model.load_state_dict(state)
+        self.model.to(self.device)
+        self.model.eval()
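
SaveModel persists only the state_dict, so LoadModel assumes the architecture has already been rebuilt, for example by constructing the wrapper with the same config and data. A hedged round-trip sketch with a hypothetical `trainer` instance:

    trainer.SaveModel("pyg_model")     # written to pyg_model.pt
    trainer.LoadModel("pyg_model.pt")  # restores weights, moves to device, sets eval mode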