topologicpy 0.8.98__py3-none-any.whl → 0.8.99__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
topologicpy/PyG.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2025
1
+ # Copyright (C) 2026
2
2
  # Wassim Jabi <wassim.jabi@gmail.com>
3
3
  #
4
4
  # This program is free software: you can redistribute it and/or modify it under
@@ -14,2396 +14,1375 @@
14
14
  # You should have received a copy of the GNU Affero General Public License along with
15
15
  # this program. If not, see <https://www.gnu.org/licenses/>.
16
16
 
17
- import os
18
- import copy
19
- import warnings
20
- import gc
21
-
22
- try:
23
- import numpy as np
24
- except:
25
- print("PyG - Installing required numpy library.")
26
- try:
27
- os.system("pip install numpy")
28
- except:
29
- os.system("pip install numpy --user")
30
- try:
31
- import numpy as np
32
- print("PyG - numpy library installed successfully.")
33
- except:
34
- warnings.warn("PyG - Error: Could not import numpy.")
35
-
36
- try:
37
- import pandas as pd
38
- except:
39
- print("PyG - Installing required pandas library.")
40
- try:
41
- os.system("pip install pandas")
42
- except:
43
- os.system("pip install pandas --user")
44
- try:
45
- import pandas as pd
46
- print("PyG - pandas library installed successfully.")
47
- except:
48
- warnings.warn("PyG - Error: Could not import pandas.")
49
-
50
- try:
51
- from tqdm.auto import tqdm
52
- except:
53
- print("PyG - Installing required tqdm library.")
54
- try:
55
- os.system("pip install tqdm")
56
- except:
57
- os.system("pip install tqdm --user")
58
- try:
59
- from tqdm.auto import tqdm
60
- print("PyG - tqdm library installed correctly.")
61
- except:
62
- raise Exception("PyG - Error: Could not import tqdm.")
63
-
64
- try:
65
- import torch
66
- import torch.nn as nn
67
- import torch.nn.functional as F
68
- from torch.utils.data.sampler import SubsetRandomSampler
69
- except:
70
- print("PyG - Installing required torch library.")
71
- try:
72
- os.system("pip install torch")
73
- except:
74
- os.system("pip install torch --user")
75
- try:
76
- import torch
77
- import torch.nn as nn
78
- import torch.nn.functional as F
79
- from torch.utils.data.sampler import SubsetRandomSampler
80
- print("PyG - torch library installed correctly.")
81
- except:
82
- warnings.warn("PyG - Error: Could not import torch.")
83
-
84
- try:
85
- from torch_geometric.data import Data, Dataset
86
- from torch_geometric.loader import DataLoader
87
- from torch_geometric.nn import SAGEConv, global_mean_pool, global_max_pool, global_add_pool
88
- except:
89
- print("PyG - Installing required torch_geometric library.")
90
- try:
91
- os.system("pip install torch_geometric")
92
- except:
93
- os.system("pip install torch_geometric --user")
94
- try:
95
- from torch_geometric.data import Data, Dataset
96
- from torch_geometric.loader import DataLoader
97
- from torch_geometric.nn import SAGEConv, global_mean_pool, global_max_pool, global_add_pool
98
- print("PyG - torch_geometric library installed correctly.")
99
- except:
100
- warnings.warn("PyG - Error: Could not import torch.")
101
-
102
- try:
103
- from sklearn.model_selection import KFold
104
- from sklearn.metrics import accuracy_score
105
- except:
106
- print("PyG - Installing required scikit-learn library.")
107
- try:
108
- os.system("pip install -U scikit-learn")
109
- except:
110
- os.system("pip install -U scikit-learn --user")
111
- try:
112
- from sklearn.model_selection import KFold
113
- from sklearn.metrics import accuracy_score
114
- print("PyG - scikit-learn library installed correctly.")
115
- except:
116
- warnings.warn("PyG - Error: Could not import scikit. Please install it manually.")
117
-
118
- class CustomGraphDataset(Dataset):
119
- def __init__(self, root=None, data_list=None, indices=None, node_level=False, graph_level=True,
120
- node_attr_key='feat', edge_attr_key='feat'):
121
- """
122
- Initializes the CustomGraphDataset.
123
-
124
- Parameters:
125
- - root: Root directory of the dataset (used only if data_list is None)
126
- - data_list: List of preprocessed data objects (used if provided)
127
- - indices: List of indices to select a subset of the data
128
- - node_level: Boolean flag indicating if the dataset is node-level
129
- - graph_level: Boolean flag indicating if the dataset is graph-level
130
- - node_attr_key: Key for node attributes
131
- - edge_attr_key: Key for edge attributes
132
- """
133
- assert not (node_level and graph_level), "Both node_level and graph_level cannot be True at the same time"
134
- assert node_level or graph_level, "Both node_level and graph_level cannot be False at the same time"
135
-
136
- self.node_level = node_level
137
- self.graph_level = graph_level
138
- self.node_attr_key = node_attr_key
139
- self.edge_attr_key = edge_attr_key
140
-
141
- if data_list is not None:
142
- self.data_list = data_list # Use the provided data list
143
- elif root is not None:
144
- # Load and process data from root directory if data_list is not provided
145
- self.graph_df = pd.read_csv(os.path.join(root, 'graphs.csv'))
146
- self.nodes_df = pd.read_csv(os.path.join(root, 'nodes.csv'))
147
- self.edges_df = pd.read_csv(os.path.join(root, 'edges.csv'))
148
- self.data_list = self.process_all()
149
- else:
150
- raise ValueError("Either a root directory or a data_list must be provided.")
151
-
152
- # Filter data_list based on indices if provided
153
- if indices is not None:
154
- self.data_list = [self.data_list[i] for i in indices]
155
-
156
- def process_all(self):
157
- data_list = []
158
- for graph_id in self.graph_df['graph_id'].unique():
159
- graph_nodes = self.nodes_df[self.nodes_df['graph_id'] == graph_id]
160
- graph_edges = self.edges_df[self.edges_df['graph_id'] == graph_id]
161
-
162
- if self.node_attr_key in graph_nodes.columns and not graph_nodes[self.node_attr_key].isnull().all():
163
- x = torch.tensor(graph_nodes[self.node_attr_key].values.tolist(), dtype=torch.float)
164
- if x.ndim == 1:
165
- x = x.unsqueeze(1) # Ensure x has shape [num_nodes, *]
166
- else:
167
- x = None
168
-
169
- edge_index = torch.tensor(graph_edges[['src_id', 'dst_id']].values.T, dtype=torch.long)
170
-
171
- if self.edge_attr_key in graph_edges.columns and not graph_edges[self.edge_attr_key].isnull().all():
172
- edge_attr = torch.tensor(graph_edges[self.edge_attr_key].values.tolist(), dtype=torch.float)
173
- else:
174
- edge_attr = None
175
-
176
- if self.graph_level:
177
- label_value = self.graph_df[self.graph_df['graph_id'] == graph_id]['label'].values[0]
178
-
179
- if isinstance(label_value, np.int64):
180
- label_value = int(label_value)
181
- if isinstance(label_value, np.float64):
182
- label_value = float(label_value)
183
-
184
- if isinstance(label_value, int) or isinstance(label_value, np.int64):
185
- y = torch.tensor([label_value], dtype=torch.long)
186
- elif isinstance(label_value, float):
187
- y = torch.tensor([label_value], dtype=torch.float)
188
- else:
189
- raise ValueError(f"Unexpected label type: {type(label_value)}. Expected int or float.")
190
-
191
- elif self.node_level:
192
- label_values = graph_nodes['label'].values
193
-
194
- if issubclass(label_values.dtype.type, int):
195
- y = torch.tensor(label_values, dtype=torch.long)
196
- elif issubclass(label_values.dtype.type, float):
197
- y = torch.tensor(label_values, dtype=torch.float)
198
- else:
199
- raise ValueError(f"Unexpected label types: {label_values.dtype}. Expected int or float.")
200
-
201
- data = Data(x=x, edge_index=edge_index, y=y)
202
- if edge_attr is not None:
203
- data.edge_attr = edge_attr
204
-
205
- data_list.append(data)
206
-
207
- return data_list
208
-
209
- def __len__(self):
210
- return len(self.data_list)
211
-
212
- def __getitem__(self, idx):
213
- return self.data_list[idx]
17
+ """
18
+ TopologicPy: PyTorch Geometric (PyG) helper class
19
+ =================================================
214
20
 
21
+ This module provides a clean, beginner-friendly interface to:
215
22
 
23
+ 1) Load TopologicPy-exported CSV datasets (graphs.csv, nodes.csv, edges.csv)
24
+ 2) Train / validate / test models for:
25
+ - Graph-level prediction (classification or regression)
26
+ - Node-level prediction (classification or regression)
27
+ - Edge-level prediction (classification or regression)
28
+ - Link prediction (binary edge existence)
216
29
 
30
+ 3) Report performance metrics and interactive Plotly visualisations.
217
31
 
218
- # class CustomGraphDataset(Dataset):
219
- # def __init__(self, root, node_level=False, graph_level=True, node_attr_key='feat',
220
- # edge_attr_key='feat', transform=None, pre_transform=None):
221
- # super(CustomGraphDataset, self).__init__(root, transform, pre_transform)
222
- # assert not (node_level and graph_level), "Both node_level and graph_level cannot be True at the same time"
223
- # assert node_level or graph_level, "Both node_level and graph_level cannot be False at the same time"
32
+ User-controlled hyperparameters (medium-level)
33
+ ----------------------------------------------
34
+ - Cross-validation: holdout or k-fold (graph-level)
35
+ - Network topology: number of hidden layers and neurons per layer (hidden_dims)
36
+ - GNN backbone: conv type (sage/gcn/gatv2), activation, dropout, batch_norm, residual
37
+ - Training: epochs, batch_size, lr, weight_decay, optimizer (adam/adamw),
38
+ gradient clipping, early stopping
224
39
 
225
- # self.node_level = node_level
226
- # self.graph_level = graph_level
227
- # self.node_attr_key = node_attr_key
228
- # self.edge_attr_key = edge_attr_key
40
+ CSV assumptions
41
+ ---------------
42
+ - graphs.csv contains at least: graph_id, label, and optional graph feature columns:
43
+ feat_0, feat_1, ...
229
44
 
230
- # self.graph_df = pd.read_csv(os.path.join(root, 'graphs.csv'))
231
- # self.nodes_df = pd.read_csv(os.path.join(root, 'nodes.csv'))
232
- # self.edges_df = pd.read_csv(os.path.join(root, 'edges.csv'))
45
+ - nodes.csv contains at least: graph_id, node_id, label, optional masks, and feature columns:
46
+ feat_0, feat_1, ...
233
47
 
234
- # self.data_list = self.process_all()
48
+ - edges.csv contains at least: graph_id, src_id, dst_id, label, optional masks, and feature columns:
49
+ feat_0, feat_1, ...
235
50
 
236
- # @property
237
- # def raw_file_names(self):
238
- # return ['graphs.csv', 'nodes.csv', 'edges.csv']
51
+ Notes
52
+ -----
53
+ - This module intentionally avoids auto-install behaviour.
54
+ - It aims to be easy to read and modify by non-ML experts.
239
55
 
240
- # def process_all(self):
241
- # data_list = []
242
- # for graph_id in self.graph_df['graph_id'].unique():
243
- # graph_nodes = self.nodes_df[self.nodes_df['graph_id'] == graph_id]
244
- # graph_edges = self.edges_df[self.edges_df['graph_id'] == graph_id]
56
+ """
245
57
 
246
- # if self.node_attr_key in graph_nodes.columns and not graph_nodes[self.node_attr_key].isnull().all():
247
- # x = torch.tensor(graph_nodes[self.node_attr_key].values.tolist(), dtype=torch.float)
248
- # if x.ndim == 1:
249
- # x = x.unsqueeze(1) # Ensure x has shape [num_nodes, *]
250
- # else:
251
- # x = None
58
+ from __future__ import annotations
252
59
 
253
- # edge_index = torch.tensor(graph_edges[['src_id', 'dst_id']].values.T, dtype=torch.long)
60
+ from dataclasses import dataclass
61
+ from typing import Dict, List, Optional, Tuple, Union, Literal
254
62
 
255
- # if self.edge_attr_key in graph_edges.columns and not graph_edges[self.edge_attr_key].isnull().all():
256
- # edge_attr = torch.tensor(graph_edges[self.edge_attr_key].values.tolist(), dtype=torch.float)
257
- # else:
258
- # edge_attr = None
259
-
260
-
261
-
262
- # if self.graph_level:
263
- # label_value = self.graph_df[self.graph_df['graph_id'] == graph_id]['label'].values[0]
264
-
265
- # # Check if the label is an integer or a float and cast accordingly
266
- # if isinstance(label_value, int):
267
- # y = torch.tensor([label_value], dtype=torch.long)
268
- # elif isinstance(label_value, float):
269
- # y = torch.tensor([label_value], dtype=torch.float)
270
- # else:
271
- # raise ValueError(f"Unexpected label type: {type(label_value)}. Expected int or float.")
272
-
273
- # elif self.node_level:
274
- # label_values = graph_nodes['label'].values
275
-
276
- # # Check if the labels are integers or floats and cast accordingly
277
- # if issubclass(label_values.dtype.type, int):
278
- # y = torch.tensor(label_values, dtype=torch.long)
279
- # elif issubclass(label_values.dtype.type, float):
280
- # y = torch.tensor(label_values, dtype=torch.float)
281
- # else:
282
- # raise ValueError(f"Unexpected label types: {label_values.dtype}. Expected int or float.")
283
-
284
-
285
- # # if self.graph_level:
286
- # # y = torch.tensor([self.graph_df[self.graph_df['graph_id'] == graph_id]['label'].values[0]], dtype=torch.long)
287
- # # elif self.node_level:
288
- # # y = torch.tensor(graph_nodes['label'].values, dtype=torch.long)
63
+ import os
64
+ import math
65
+ import random
66
+ import copy
289
67
 
290
- # data = Data(x=x, edge_index=edge_index, y=y)
291
- # if edge_attr is not None:
292
- # data.edge_attr = edge_attr
68
+ import numpy as np
69
+ import pandas as pd
70
+
71
+ import torch
72
+ import torch.nn as nn
73
+ import torch.nn.functional as F
74
+
75
+ from torch_geometric.data import Data
76
+ from torch_geometric.loader import DataLoader
77
+ from torch_geometric.nn import (
78
+ SAGEConv, GCNConv, GATv2Conv,
79
+ global_mean_pool, global_max_pool, global_add_pool
80
+ )
81
+ from torch_geometric.transforms import RandomLinkSplit
82
+
83
+ from sklearn.metrics import (
84
+ accuracy_score, precision_recall_fscore_support, confusion_matrix,
85
+ mean_absolute_error, mean_squared_error, r2_score
86
+ )
87
+
88
+ import plotly.graph_objects as go
89
+ import plotly.express as px
90
+
91
+
92
+ LabelType = Literal["categorical", "continuous"]
93
+ Level = Literal["graph", "node", "edge", "link"]
94
+ TaskKind = Literal["classification", "regression", "link_prediction"]
95
+ ConvKind = Literal["sage", "gcn", "gatv2"]
96
+ PoolingKind = Literal["mean", "max", "sum"]
97
+
98
+
99
+ @dataclass
100
+ class _RunConfig:
101
+ # ----------------------------
102
+ # Task selection
103
+ # ----------------------------
104
+ level: Level = "graph" # "graph" | "node" | "edge" | "link"
105
+ task: TaskKind = "classification" # "classification" | "regression" | "link_prediction"
106
+
107
+ # label types (graph/node/edge)
108
+ graph_label_type: LabelType = "categorical"
109
+ node_label_type: LabelType = "categorical"
110
+ edge_label_type: LabelType = "categorical"
111
+
112
+ # ----------------------------
113
+ # CSV headers
114
+ # ----------------------------
115
+ graph_id_header: str = "graph_id"
116
+ graph_label_header: str = "label"
117
+ graph_features_header: str = "feat"
118
+
119
+ node_id_header: str = "node_id"
120
+ node_label_header: str = "label"
121
+ node_features_header: str = "feat"
122
+
123
+ edge_src_header: str = "src_id"
124
+ edge_dst_header: str = "dst_id"
125
+ edge_label_header: str = "label"
126
+ edge_features_header: str = "feat"
127
+
128
+ # masks (optional)
129
+ node_train_mask_header: str = "train_mask"
130
+ node_val_mask_header: str = "val_mask"
131
+ node_test_mask_header: str = "test_mask"
132
+
133
+ edge_train_mask_header: str = "train_mask"
134
+ edge_val_mask_header: str = "val_mask"
135
+ edge_test_mask_header: str = "test_mask"
136
+
137
+ # ----------------------------
138
+ # Cross-validation / splitting
139
+ # ----------------------------
140
+ cv: Literal["holdout", "kfold"] = "holdout"
141
+ split: Tuple[float, float, float] = (0.8, 0.1, 0.1) # used for holdout
142
+ k_folds: int = 5 # used for kfold (graph-level only)
143
+ k_shuffle: bool = True
144
+ k_stratify: bool = True # only if categorical labels exist
145
+ random_state: int = 42
146
+ shuffle: bool = True # affects holdout + in-graph mask fallback
147
+
148
+ # link prediction split (within each graph)
149
+ link_val_ratio: float = 0.1
150
+ link_test_ratio: float = 0.1
151
+ link_is_undirected: bool = False
152
+
153
+ # ----------------------------
154
+ # Training hyperparameters
155
+ # ----------------------------
156
+ epochs: int = 50
157
+ batch_size: int = 32
158
+ lr: float = 1e-3
159
+ weight_decay: float = 0.0
160
+ optimizer: Literal["adam", "adamw"] = "adam"
161
+ gradient_clip_norm: Optional[float] = None
162
+ early_stopping: bool = False
163
+ early_stopping_patience: int = 10
164
+ use_gpu: bool = True
165
+
166
+ # ----------------------------
167
+ # Network topology / model hyperparameters
168
+ # ----------------------------
169
+ conv: ConvKind = "sage"
170
+ hidden_dims: Tuple[int, ...] = (64, 64) # explicit per-layer widths (controls depth)
171
+ activation: Literal["relu", "gelu", "elu"] = "relu"
172
+ dropout: float = 0.1
173
+ batch_norm: bool = False
174
+ residual: bool = False
175
+ pooling: PoolingKind = "mean" # only for graph-level
176
+
177
+
178
+ class _GNNBackbone(nn.Module):
179
+ """
180
+ Shared GNN encoder that produces node embeddings.
181
+ """
182
+
183
+ def __init__(self,
184
+ in_dim: int,
185
+ hidden_dims: Tuple[int, ...],
186
+ conv: ConvKind = "sage",
187
+ activation: str = "relu",
188
+ dropout: float = 0.1,
189
+ batch_norm: bool = False,
190
+ residual: bool = False):
191
+ super().__init__()
192
+ if in_dim <= 0:
193
+ raise ValueError("in_dim must be > 0. Your dataset has no node features columns.")
194
+ if hidden_dims is None or len(hidden_dims) == 0:
195
+ raise ValueError("hidden_dims must contain at least one layer width, e.g. (64, 64).")
196
+
197
+ self.dropout = float(dropout)
198
+ self.use_bn = bool(batch_norm)
199
+ self.use_residual = bool(residual)
200
+
201
+ if activation == "relu":
202
+ self.act = F.relu
203
+ elif activation == "gelu":
204
+ self.act = F.gelu
205
+ elif activation == "elu":
206
+ self.act = F.elu
207
+ else:
208
+ raise ValueError("Unsupported activation. Use 'relu', 'gelu', or 'elu'.")
293
209
 
294
- # data_list.append(data)
210
+ dims = [int(in_dim)] + [int(d) for d in hidden_dims]
295
211
 
296
- # return data_list
212
+ self.convs = nn.ModuleList()
213
+ self.bns = nn.ModuleList()
297
214
 
298
- # def len(self):
299
- # return len(self.data_list)
215
+ for i in range(1, len(dims)):
216
+ in_ch, out_ch = dims[i - 1], dims[i]
300
217
 
301
- # def get(self, idx):
302
- # return self.data_list[idx]
218
+ if conv == "sage":
219
+ self.convs.append(SAGEConv(in_ch, out_ch))
220
+ elif conv == "gcn":
221
+ self.convs.append(GCNConv(in_ch, out_ch))
222
+ elif conv == "gatv2":
223
+ self.convs.append(GATv2Conv(in_ch, out_ch, heads=1, concat=False))
224
+ else:
225
+ raise ValueError(f"Unsupported conv='{conv}'.")
303
226
 
304
- # def __getitem__(self, idx):
305
- # return self.get(idx)
227
+ if self.use_bn:
228
+ self.bns.append(nn.BatchNorm1d(out_ch))
306
229
 
307
- class _Hparams:
308
- def __init__(self, model_type="ClassifierHoldout", optimizer_str="Adam", amsgrad=False, betas=(0.9, 0.999), eps=1e-6, lr=0.001, lr_decay= 0, maximize=False, rho=0.9, weight_decay=0, cv_type="Holdout", split=[0.8,0.1, 0.1], k_folds=5, hl_widths=[32], conv_layer_type='SAGEConv', pooling="AvgPooling", batch_size=32, epochs=1,
309
- use_gpu=False, loss_function="Cross Entropy", input_type="graph"):
310
- """
311
- Parameters
312
- ----------
313
- cv_type : str
314
- A string to define the method of cross-validation
315
- "Holdout": Holdout
316
- "K-Fold": K-Fold cross validation
317
- k_folds : int
318
- An int value greater than or equal to 2 to define the number of k-folds for cross-validation. Default is 5.
319
- split : list
320
- A list of three item in the range of 0 to 1 to define the split of train,
321
- validate, and test data. A default value of [0.8,0.1,0.1] means 80% of data will be
322
- used for training, 10% will be used for validation, and the remaining 10% will be used for training
323
- hl_widths : list
324
- List of hidden neurons for each layer such as [32] will mean
325
- that there is one hidden layers in the network with 32 neurons
326
- optimizer_str : str
327
- The name of the optimizer to use from the torch.optim package. By
328
- default, torch.optim.Adam is selected
329
- learning_rate : float
330
- a step value to be used to apply the gradients by optimizer
331
- batch_size : int
332
- to define a set of samples to be used for training and testing in
333
- each step of an epoch
334
- epochs : int
335
- One epoch trains the neural network on all of the training data for one cycle, so every sample is used exactly once; a forward pass and a backward pass together count as one pass
336
- use_GPU : use the GPU. Otherwise, use the CPU
337
- input_type : str
338
- selects the input type of the model: graph, node, or edge
230
+ self.out_dim = dims[-1]
339
231
 
340
- Returns
341
- -------
342
- None
232
+ def forward(self, x, edge_index):
233
+ h = x
234
+ for i, conv in enumerate(self.convs):
235
+ h_in = h
236
+ h = conv(h, edge_index)
237
+ if self.use_bn:
238
+ h = self.bns[i](h)
239
+ h = self.act(h)
343
240
 
344
- """
345
-
346
- self.model_type = model_type
347
- self.optimizer_str = optimizer_str
348
- self.amsgrad = amsgrad
349
- self.betas = betas
350
- self.eps = eps
351
- self.lr = lr
352
- self.lr_decay = lr_decay
353
- self.maximize = maximize
354
- self.rho = rho
355
- self.weight_decay = weight_decay
356
- self.cv_type = cv_type
357
- self.split = split
358
- self.k_folds = k_folds
359
- self.hl_widths = hl_widths
360
- self.conv_layer_type = conv_layer_type
361
- self.pooling = pooling
362
- self.batch_size = batch_size
363
- self.epochs = epochs
364
- self.use_gpu = use_gpu
365
- self.loss_function = loss_function
366
- self.input_type = input_type
367
-
368
- class _SAGEConv(nn.Module):
369
- def __init__(self, in_feats, h_feats, num_classes, pooling=None):
370
- super(_SAGEConv, self).__init__()
371
- assert isinstance(h_feats, list), "h_feats must be a list"
372
- h_feats = [x for x in h_feats if x is not None]
373
- assert len(h_feats) != 0, "h_feats is empty. unable to add hidden layers"
374
- self.list_of_layers = nn.ModuleList()
375
- dim = [in_feats] + h_feats
376
-
377
- # Convolution (Hidden) Layers
378
- for i in range(1, len(dim)):
379
- self.list_of_layers.append(SAGEConv(dim[i-1], dim[i]))
380
-
381
- # Final Layer
382
- self.final = nn.Linear(dim[-1], num_classes)
383
-
384
- # Pooling layer
385
- if pooling is None:
386
- self.pooling_layer = None
387
- else:
388
- if "av" in pooling.lower():
389
- self.pooling_layer = global_mean_pool
390
- elif "max" in pooling.lower():
391
- self.pooling_layer = global_max_pool
392
- elif "sum" in pooling.lower():
393
- self.pooling_layer = global_add_pool
394
- else:
395
- raise NotImplementedError
241
+ if self.use_residual and h_in.shape == h.shape:
242
+ h = h + h_in
396
243
 
397
- def forward(self, data):
398
- x, edge_index, batch = data.x, data.edge_index, data.batch
399
- h = x
400
- # Generate node features
401
- for layer in self.list_of_layers:
402
- h = layer(h, edge_index)
403
- h = F.relu(h)
404
- # h will now be a matrix of dimension [num_nodes, h_feats[-1]]
405
- h = self.final(h)
406
- # Go from node-level features to graph-level features by pooling
407
- if self.pooling_layer:
408
- h = self.pooling_layer(h, batch)
409
- # h will now be a vector of dimension [num_classes]
244
+ h = F.dropout(h, p=self.dropout, training=self.training)
410
245
  return h
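A small usage sketch of the encoder above, assuming a toy graph of four nodes with three features each; the output width follows from the hidden_dims argument.

# Toy usage of the encoder above: 4 nodes with 3 features each, 4 directed edges.
import torch

x = torch.randn(4, 3)                                    # [num_nodes, in_dim]
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]], dtype=torch.long)

encoder = _GNNBackbone(in_dim=3, hidden_dims=(64, 64),
                       conv="sage", dropout=0.1, residual=True)
encoder.eval()                                           # disables dropout
with torch.no_grad():
    h = encoder(x, edge_index)                           # node embeddings
print(h.shape)  # torch.Size([4, 64]) == [num_nodes, encoder.out_dim]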
411
246
 
412
- class _GraphRegressorHoldout:
413
- def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
414
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
415
- self.trainingDataset = trainingDataset
416
- self.validationDataset = validationDataset
417
- self.testingDataset = testingDataset
418
- self.hparams = hparams
419
- if hparams.conv_layer_type.lower() == 'sageconv':
420
- self.model = _SAGEConv(trainingDataset[0].num_node_features, hparams.hl_widths, 1, hparams.pooling).to(self.device)
421
- else:
422
- raise NotImplementedError
423
-
424
- if hparams.optimizer_str.lower() == "adadelta":
425
- self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
426
- lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
427
- elif hparams.optimizer_str.lower() == "adagrad":
428
- self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
429
- lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
430
- elif hparams.optimizer_str.lower() == "adam":
431
- self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
432
- lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
433
-
434
- self.use_gpu = hparams.use_gpu
435
- self.training_loss_list = []
436
- self.validation_loss_list = []
437
- self.node_attr_key = trainingDataset[0].x.shape[1]
438
-
439
- # Train, validate, test split
440
- num_train = int(len(trainingDataset) * hparams.split[0])
441
- num_validate = int(len(trainingDataset) * hparams.split[1])
442
- num_test = len(trainingDataset) - num_train - num_validate
443
- idx = torch.randperm(len(trainingDataset))
444
- train_sampler = SubsetRandomSampler(idx[:num_train])
445
- validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
446
- test_sampler = SubsetRandomSampler(idx[num_train+num_validate:])
447
-
448
- if validationDataset:
449
- self.train_dataloader = DataLoader(trainingDataset,
450
- batch_size=hparams.batch_size,
451
- drop_last=False)
452
- self.validate_dataloader = DataLoader(validationDataset,
453
- batch_size=hparams.batch_size,
454
- drop_last=False)
455
- else:
456
- self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler,
457
- batch_size=hparams.batch_size,
458
- drop_last=False)
459
- self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler,
460
- batch_size=hparams.batch_size,
461
- drop_last=False)
462
-
463
- if testingDataset:
464
- self.test_dataloader = DataLoader(testingDataset,
465
- batch_size=len(testingDataset),
466
- drop_last=False)
467
- else:
468
- self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler,
469
- batch_size=hparams.batch_size,
470
- drop_last=False)
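For concreteness, a worked example of the holdout split arithmetic used in the previous-version code above, assuming a dataset of 100 graphs and the default split of [0.8, 0.1, 0.1].

# Holdout split arithmetic (assumes 100 graphs): 80 train, 10 validate, 10 test.
import torch
from torch.utils.data.sampler import SubsetRandomSampler

split = [0.8, 0.1, 0.1]
n = 100
num_train = int(n * split[0])              # 80
num_validate = int(n * split[1])           # 10
num_test = n - num_train - num_validate    # 10

idx = torch.randperm(n)                    # random permutation of graph indices
train_sampler = SubsetRandomSampler(idx[:num_train])
validate_sampler = SubsetRandomSampler(idx[num_train:num_train + num_validate])
test_sampler = SubsetRandomSampler(idx[num_train + num_validate:])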
471
-
472
- def train(self):
473
- # Init the loss and accuracy reporting lists
474
- self.training_loss_list = []
475
- self.validation_loss_list = []
476
-
477
- # Run the training loop for defined number of epochs
478
- for _ in tqdm(range(self.hparams.epochs), desc='Epochs', total=self.hparams.epochs, leave=False):
479
- # Iterate over the DataLoader for training data
480
- for data in tqdm(self.train_dataloader, desc='Training', leave=False):
481
- data = data.to(self.device)
482
- # Make sure the model is in training mode
483
- self.model.train()
484
- # Zero the gradients
485
- self.optimizer.zero_grad()
486
-
487
- # Perform forward pass
488
- pred = self.model(data).to(self.device)
489
- # Compute loss
490
- loss = F.mse_loss(torch.flatten(pred), data.y.float())
491
-
492
- # Perform backward pass
493
- loss.backward()
494
-
495
- # Perform optimization
496
- self.optimizer.step()
497
-
498
- self.training_loss_list.append(torch.sqrt(loss).item())
499
- self.validate()
500
- self.validation_loss_list.append(torch.sqrt(self.validation_loss).item())
501
- gc.collect()
502
-
503
- def validate(self):
504
- self.model.eval()
505
- for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
506
- data = data.to(self.device)
507
- pred = self.model(data).to(self.device)
508
- loss = F.mse_loss(torch.flatten(pred), data.y.float())
509
- self.validation_loss = loss
510
-
511
- def test(self):
512
- self.model.eval()
513
- for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
514
- data = data.to(self.device)
515
- pred = self.model(data).to(self.device)
516
- loss = F.mse_loss(torch.flatten(pred), data.y.float())
517
- self.testing_loss = torch.sqrt(loss).item()
518
-
519
- def save(self, path):
520
- if path:
521
- # Make sure the file extension is .pt
522
- ext = path[-3:]
523
- if ext.lower() != ".pt":
524
- path = path + ".pt"
525
- torch.save(self.model.state_dict(), path)
526
-
527
- def load(self, path):
528
- #self.model.load_state_dict(torch.load(path))
529
- self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
530
-
531
- class _GraphRegressorKFold:
532
- def __init__(self, hparams, trainingDataset, testingDataset=None):
533
- self.trainingDataset = trainingDataset
534
- self.testingDataset = testingDataset
535
- self.hparams = hparams
536
- self.losses = []
537
- self.min_loss = 0
538
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
539
-
540
- self.model = self._initialize_model(hparams, trainingDataset)
541
- self.optimizer = self._initialize_optimizer(hparams)
542
-
543
- self.use_gpu = hparams.use_gpu
544
- self.training_loss_list = []
545
- self.validation_loss_list = []
546
- self.node_attr_key = trainingDataset.node_attr_key
547
-
548
- # Train, validate, test split
549
- num_train = int(len(trainingDataset) * hparams.split[0])
550
- num_validate = int(len(trainingDataset) * hparams.split[1])
551
- num_test = len(trainingDataset) - num_train - num_validate
552
- idx = torch.randperm(len(trainingDataset))
553
- test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])
554
-
555
- if testingDataset:
556
- self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
557
- else:
558
- self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)
559
-
560
- def _initialize_model(self, hparams, dataset):
561
- if hparams.conv_layer_type.lower() == 'sageconv':
562
- return _SAGEConv(dataset[0].num_node_features, hparams.hl_widths, 1, hparams.pooling).to(self.device)
563
- #return _SAGEConv(dataset.num_node_features, hparams.hl_widths, 1, hparams.pooling).to(self.device)
564
- else:
565
- raise NotImplementedError
566
-
567
- def _initialize_optimizer(self, hparams):
568
- if hparams.optimizer_str.lower() == "adadelta":
569
- return torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps, lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
570
- elif hparams.optimizer_str.lower() == "adagrad":
571
- return torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps, lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
572
- elif hparams.optimizer_str.lower() == "adam":
573
- return torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps, lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
574
-
575
- def reset_weights(self):
576
- self.model = self._initialize_model(self.hparams, self.trainingDataset)
577
- self.optimizer = self._initialize_optimizer(self.hparams)
578
-
579
- def train(self):
580
- k_folds = self.hparams.k_folds
581
- torch.manual_seed(42)
582
-
583
- kfold = KFold(n_splits=k_folds, shuffle=True)
584
- models, weights, losses, train_dataloaders, validate_dataloaders = [], [], [], [], []
585
-
586
- for fold, (train_ids, validate_ids) in tqdm(enumerate(kfold.split(self.trainingDataset)), desc="Fold", total=k_folds, leave=False):
587
- epoch_training_loss_list, epoch_validation_loss_list = [], []
588
- train_subsampler = SubsetRandomSampler(train_ids)
589
- validate_subsampler = SubsetRandomSampler(validate_ids)
590
-
591
- self.train_dataloader = DataLoader(self.trainingDataset, sampler=train_subsampler, batch_size=self.hparams.batch_size, drop_last=False)
592
- self.validate_dataloader = DataLoader(self.trainingDataset, sampler=validate_subsampler, batch_size=self.hparams.batch_size, drop_last=False)
593
-
594
- self.reset_weights()
595
- best_rmse = np.inf
596
-
597
- for _ in tqdm(range(self.hparams.epochs), desc='Epochs', total=self.hparams.epochs, leave=False):
598
- for batched_graph in tqdm(self.train_dataloader, desc='Training', leave=False):
599
- self.model.train()
600
- self.optimizer.zero_grad()
601
-
602
- batched_graph = batched_graph.to(self.device)
603
- pred = self.model(batched_graph)
604
- loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
605
- loss.backward()
606
- self.optimizer.step()
607
-
608
- epoch_training_loss_list.append(torch.sqrt(loss).item())
609
- self.validate()
610
- epoch_validation_loss_list.append(torch.sqrt(self.validation_loss).item())
611
- gc.collect()
612
-
613
- models.append(self.model)
614
- weights.append(copy.deepcopy(self.model.state_dict()))
615
- losses.append(torch.sqrt(self.validation_loss).item())
616
- train_dataloaders.append(self.train_dataloader)
617
- validate_dataloaders.append(self.validate_dataloader)
618
- self.training_loss_list.append(epoch_training_loss_list)
619
- self.validation_loss_list.append(epoch_validation_loss_list)
620
-
621
- self.losses = losses
622
- self.min_loss = min(losses)
623
- ind = losses.index(self.min_loss)
624
- self.model = models[ind]
625
- self.model.load_state_dict(weights[ind])
626
- self.model.eval()
627
- self.training_loss_list = self.training_loss_list[ind]
628
- self.validation_loss_list = self.validation_loss_list[ind]
629
247
 
630
- def validate(self):
631
- self.model.eval()
632
- for batched_graph in tqdm(self.validate_dataloader, desc='Validating', leave=False):
633
- batched_graph = batched_graph.to(self.device)
634
- pred = self.model(batched_graph)
635
- loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
636
- self.validation_loss = loss
637
-
638
- def test(self):
639
- self.model.eval()
640
- for batched_graph in tqdm(self.test_dataloader, desc='Testing', leave=False):
641
- batched_graph = batched_graph.to(self.device)
642
- pred = self.model(batched_graph)
643
- loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
644
- self.testing_loss = torch.sqrt(loss).item()
645
-
646
- def save(self, path):
647
- if path:
648
- ext = path[-3:]
649
- if ext.lower() != ".pt":
650
- path = path + ".pt"
651
- torch.save(self.model.state_dict(), path)
652
-
653
- def load(self, path):
654
- self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
655
-
656
- class _GraphClassifierKFold:
657
- def __init__(self, hparams, trainingDataset, testingDataset=None):
658
- self.trainingDataset = trainingDataset
659
- self.testingDataset = testingDataset
660
- self.hparams = hparams
661
- self.testing_accuracy = 0
662
- self.accuracies = []
663
- self.max_accuracy = 0
664
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
665
-
666
- if hparams.conv_layer_type.lower() == 'sageconv':
667
- self.model = _SAGEConv(trainingDataset.num_node_features, hparams.hl_widths,
668
- trainingDataset.num_classes, hparams.pooling).to(self.device)
669
- else:
670
- raise NotImplementedError
671
-
672
- if hparams.optimizer_str.lower() == "adadelta":
673
- self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
674
- lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
675
- elif hparams.optimizer_str.lower() == "adagrad":
676
- self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
677
- lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
678
- elif hparams.optimizer_str.lower() == "adam":
679
- self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
680
- lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
681
- self.use_gpu = hparams.use_gpu
682
- self.training_loss_list = []
683
- self.validation_loss_list = []
684
- self.training_accuracy_list = []
685
- self.validation_accuracy_list = []
686
-
687
- def reset_weights(self):
688
- if self.hparams.conv_layer_type.lower() == 'sageconv':
689
- self.model = _SAGEConv(self.trainingDataset.num_node_features, self.hparams.hl_widths,
690
- self.trainingDataset.num_classes, self.hparams.pooling).to(self.device)
691
- else:
692
- raise NotImplementedError
693
-
694
- if self.hparams.optimizer_str.lower() == "adadelta":
695
- self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=self.hparams.eps,
696
- lr=self.hparams.lr, rho=self.hparams.rho, weight_decay=self.hparams.weight_decay)
697
- elif self.hparams.optimizer_str.lower() == "adagrad":
698
- self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=self.hparams.eps,
699
- lr=self.hparams.lr, lr_decay=self.hparams.lr_decay, weight_decay=self.hparams.weight_decay)
700
- elif self.hparams.optimizer_str.lower() == "adam":
701
- self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=self.hparams.amsgrad, betas=self.hparams.betas, eps=self.hparams.eps,
702
- lr=self.hparams.lr, maximize=self.hparams.maximize, weight_decay=self.hparams.weight_decay)
703
-
704
- def train(self):
705
- k_folds = self.hparams.k_folds
706
-
707
- # Init the loss and accuracy reporting lists
708
- self.training_accuracy_list = []
709
- self.training_loss_list = []
710
- self.validation_accuracy_list = []
711
- self.validation_loss_list = []
712
-
713
- # Set fixed random number seed
714
- torch.manual_seed(42)
715
-
716
- # Define the K-fold Cross Validator
717
- kfold = KFold(n_splits=k_folds, shuffle=True)
718
-
719
- models = []
720
- weights = []
721
- accuracies = []
722
- train_dataloaders = []
723
- validate_dataloaders = []
724
-
725
- # K-fold Cross-validation model evaluation
726
- for fold, (train_ids, validate_ids) in tqdm(enumerate(kfold.split(self.trainingDataset)), desc="Fold", initial=1, total=k_folds, leave=False):
727
- epoch_training_loss_list = []
728
- epoch_training_accuracy_list = []
729
- epoch_validation_loss_list = []
730
- epoch_validation_accuracy_list = []
731
- # Sample elements randomly from a given list of ids, no replacement.
732
- train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
733
- validate_subsampler = torch.utils.data.SubsetRandomSampler(validate_ids)
734
-
735
- # Define data loaders for training and testing data in this fold
736
- self.train_dataloader = DataLoader(self.trainingDataset, sampler=train_subsampler,
737
- batch_size=self.hparams.batch_size,
738
- drop_last=False)
739
- self.validate_dataloader = DataLoader(self.trainingDataset, sampler=validate_subsampler,
740
- batch_size=self.hparams.batch_size,
741
- drop_last=False)
742
- # Init the neural network
743
- self.reset_weights()
744
-
745
- # Run the training loop for defined number of epochs
746
- for _ in tqdm(range(0,self.hparams.epochs), desc='Epochs', initial=1, total=self.hparams.epochs, leave=False):
747
- temp_loss_list = []
748
- temp_acc_list = []
749
-
750
- # Iterate over the DataLoader for training data
751
- for data in tqdm(self.train_dataloader, desc='Training', leave=False):
752
- data = data.to(self.device)
753
- # Make sure the model is in training mode
754
- self.model.train()
755
-
756
- # Zero the gradients
757
- self.optimizer.zero_grad()
758
-
759
- # Perform forward pass
760
- pred = self.model(data)
761
-
762
- # Compute loss
763
- if self.hparams.loss_function.lower() == "negative log likelihood":
764
- logp = F.log_softmax(pred, 1)
765
- loss = F.nll_loss(logp, data.y)
766
- elif self.hparams.loss_function.lower() == "cross entropy":
767
- loss = F.cross_entropy(pred, data.y)
768
-
769
- # Save loss information for reporting
770
- temp_loss_list.append(loss.item())
771
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
772
-
773
- # Perform backward pass
774
- loss.backward()
775
-
776
- # Perform optimization
777
- self.optimizer.step()
778
-
779
- epoch_training_accuracy_list.append(np.mean(temp_acc_list).item())
780
- epoch_training_loss_list.append(np.mean(temp_loss_list).item())
781
- self.validate()
782
- epoch_validation_accuracy_list.append(self.validation_accuracy)
783
- epoch_validation_loss_list.append(self.validation_loss)
784
- gc.collect()
785
- models.append(self.model)
786
- weights.append(copy.deepcopy(self.model.state_dict()))
787
- accuracies.append(self.validation_accuracy)
788
- train_dataloaders.append(self.train_dataloader)
789
- validate_dataloaders.append(self.validate_dataloader)
790
- self.training_accuracy_list.append(epoch_training_accuracy_list)
791
- self.training_loss_list.append(epoch_training_loss_list)
792
- self.validation_accuracy_list.append(epoch_validation_accuracy_list)
793
- self.validation_loss_list.append(epoch_validation_loss_list)
794
- self.accuracies = accuracies
795
- max_accuracy = max(accuracies)
796
- self.max_accuracy = max_accuracy
797
- ind = accuracies.index(max_accuracy)
798
- self.model = models[ind]
799
- self.model.load_state_dict(weights[ind])
800
- self.model.eval()
801
- self.training_accuracy_list = self.training_accuracy_list[ind]
802
- self.training_loss_list = self.training_loss_list[ind]
803
- self.validation_accuracy_list = self.validation_accuracy_list[ind]
804
- self.validation_loss_list = self.validation_loss_list[ind]
805
-
806
- def validate(self):
807
- temp_loss_list = []
808
- temp_acc_list = []
809
- self.model.eval()
810
- for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
811
- data = data.to(self.device)
812
- pred = self.model(data)
813
- if self.hparams.loss_function.lower() == "negative log likelihood":
814
- logp = F.log_softmax(pred, 1)
815
- loss = F.nll_loss(logp, data.y)
816
- elif self.hparams.loss_function.lower() == "cross entropy":
817
- loss = F.cross_entropy(pred, data.y)
818
- temp_loss_list.append(loss.item())
819
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
820
- self.validation_accuracy = np.mean(temp_acc_list).item()
821
- self.validation_loss = np.mean(temp_loss_list).item()
822
-
823
- def test(self):
824
- if self.testingDataset:
825
- self.test_dataloader = DataLoader(self.testingDataset,
826
- batch_size=len(self.testingDataset),
827
- drop_last=False)
828
- temp_loss_list = []
829
- temp_acc_list = []
830
- self.model.eval()
831
- for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
832
- data = data.to(self.device)
833
- pred = self.model(data)
834
- if self.hparams.loss_function.lower() == "negative log likelihood":
835
- logp = F.log_softmax(pred, 1)
836
- loss = F.nll_loss(logp, data.y)
837
- elif self.hparams.loss_function.lower() == "cross entropy":
838
- loss = F.cross_entropy(pred, data.y)
839
- temp_loss_list.append(loss.item())
840
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
841
- self.testing_accuracy = np.mean(temp_acc_list).item()
842
- self.testing_loss = np.mean(temp_loss_list).item()
843
-
844
- def save(self, path):
845
- if path:
846
- # Make sure the file extension is .pt
847
- ext = path[-3:]
848
- if ext.lower() != ".pt":
849
- path = path + ".pt"
850
- torch.save(self.model.state_dict(), path)
851
-
852
- def load(self, path):
853
- #self.model.load_state_dict(torch.load(path))
854
- self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
855
-
856
- class _GraphClassifierHoldout:
857
- def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
858
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
859
- self.trainingDataset = trainingDataset
860
- self.validationDataset = validationDataset
861
- self.testingDataset = testingDataset
862
- self.hparams = hparams
863
- gclasses = trainingDataset.num_classes
864
- nfeats = trainingDataset.num_node_features
865
-
866
- if hparams.conv_layer_type.lower() == 'sageconv':
867
- self.model = _SAGEConv(nfeats, hparams.hl_widths, gclasses, hparams.pooling).to(self.device)
868
- else:
869
- raise NotImplementedError
870
-
871
- if hparams.optimizer_str.lower() == "adadelta":
872
- self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
873
- lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
874
- elif hparams.optimizer_str.lower() == "adagrad":
875
- self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
876
- lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
877
- elif hparams.optimizer_str.lower() == "adam":
878
- self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
879
- lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
880
- self.use_gpu = hparams.use_gpu
881
- self.training_loss_list = []
882
- self.validation_loss_list = []
883
- self.training_accuracy_list = []
884
- self.validation_accuracy_list = []
885
- self.node_attr_key = trainingDataset[0].x.shape[1]
886
-
887
- # train, validate, test split
888
- num_train = int(len(trainingDataset) * hparams.split[0])
889
- num_validate = int(len(trainingDataset) * hparams.split[1])
890
- num_test = len(trainingDataset) - num_train - num_validate
891
- idx = torch.randperm(len(trainingDataset))
892
- train_sampler = SubsetRandomSampler(idx[:num_train])
893
- validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
894
- test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])
895
-
896
- if validationDataset:
897
- self.train_dataloader = DataLoader(trainingDataset, batch_size=hparams.batch_size, drop_last=False)
898
- self.validate_dataloader = DataLoader(validationDataset, batch_size=hparams.batch_size, drop_last=False)
899
- else:
900
- self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler, batch_size=hparams.batch_size, drop_last=False)
901
- self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler, batch_size=hparams.batch_size, drop_last=False)
902
-
903
- if testingDataset:
904
- self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
905
- else:
906
- self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)
907
-
908
- def train(self):
909
- # Init the loss and accuracy reporting lists
910
- self.training_accuracy_list = []
911
- self.training_loss_list = []
912
- self.validation_accuracy_list = []
913
- self.validation_loss_list = []
914
-
915
- # Run the training loop for defined number of epochs
916
- for _ in tqdm(range(self.hparams.epochs), desc='Epochs', initial=1, leave=False):
917
- temp_loss_list = []
918
- temp_acc_list = []
919
- # Make sure the model is in training mode
920
- self.model.train()
921
- # Iterate over the DataLoader for training data
922
- for data in tqdm(self.train_dataloader, desc='Training', leave=False):
923
- data = data.to(self.device)
924
-
925
- # Zero the gradients
926
- self.optimizer.zero_grad()
927
-
928
- # Perform forward pass
929
- pred = self.model(data)
930
-
931
- # Compute loss
932
- if self.hparams.loss_function.lower() == "negative log likelihood":
933
- logp = F.log_softmax(pred, 1)
934
- loss = F.nll_loss(logp, data.y)
935
- elif self.hparams.loss_function.lower() == "cross entropy":
936
- loss = F.cross_entropy(pred, data.y)
937
-
938
- # Save loss information for reporting
939
- temp_loss_list.append(loss.item())
940
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
941
-
942
- # Perform backward pass
943
- loss.backward()
944
-
945
- # Perform optimization
946
- self.optimizer.step()
947
-
948
- self.training_accuracy_list.append(np.mean(temp_acc_list).item())
949
- self.training_loss_list.append(np.mean(temp_loss_list).item())
950
- self.validate()
951
- self.validation_accuracy_list.append(self.validation_accuracy)
952
- self.validation_loss_list.append(self.validation_loss)
953
- gc.collect()
954
-
955
- def validate(self):
956
- temp_loss_list = []
957
- temp_acc_list = []
958
- self.model.eval()
959
- for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
960
- data = data.to(self.device)
961
- pred = self.model(data)
962
- if self.hparams.loss_function.lower() == "negative log likelihood":
963
- logp = F.log_softmax(pred, 1)
964
- loss = F.nll_loss(logp, data.y)
965
- elif self.hparams.loss_function.lower() == "cross entropy":
966
- loss = F.cross_entropy(pred, data.y)
967
- temp_loss_list.append(loss.item())
968
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
969
- self.validation_accuracy = np.mean(temp_acc_list).item()
970
- self.validation_loss = np.mean(temp_loss_list).item()
971
-
972
- def test(self):
973
- if self.test_dataloader:
974
- temp_loss_list = []
975
- temp_acc_list = []
976
- self.model.eval()
977
- for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
978
- data = data.to(self.device)
979
- pred = self.model(data)
980
- if self.hparams.loss_function.lower() == "negative log likelihood":
981
- logp = F.log_softmax(pred, 1)
982
- loss = F.nll_loss(logp, data.y)
983
- elif self.hparams.loss_function.lower() == "cross entropy":
984
- loss = F.cross_entropy(pred, data.y)
985
- temp_loss_list.append(loss.item())
986
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
987
- self.testing_accuracy = np.mean(temp_acc_list).item()
988
- self.testing_loss = np.mean(temp_loss_list).item()
989
-
990
- def save(self, path):
991
- if path:
992
- # Make sure the file extension is .pt
993
- ext = path[-3:]
994
- if ext.lower() != ".pt":
995
- path = path + ".pt"
996
- torch.save(self.model.state_dict(), path)
997
-
998
- def load(self, path):
999
- #self.model.load_state_dict(torch.load(path))
1000
- self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
1001
-
1002
- class _NodeClassifierHoldout:
1003
- def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
1004
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1005
- self.trainingDataset = trainingDataset
1006
- self.validationDataset = validationDataset
1007
- self.testingDataset = testingDataset
1008
- self.hparams = hparams
1009
- gclasses = trainingDataset.num_classes
1010
- nfeats = trainingDataset.num_node_features
1011
-
1012
- if hparams.conv_layer_type.lower() == 'sageconv':
1013
- # pooling is set None for Node classifier
1014
- self.model = _SAGEConv(nfeats, hparams.hl_widths, gclasses, None).to(self.device)
1015
- else:
1016
- raise NotImplementedError
1017
-
1018
- if hparams.optimizer_str.lower() == "adadelta":
1019
- self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
1020
- lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
1021
- elif hparams.optimizer_str.lower() == "adagrad":
1022
- self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
1023
- lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
1024
- elif hparams.optimizer_str.lower() == "adam":
1025
- self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
1026
- lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
1027
- self.use_gpu = hparams.use_gpu
1028
- self.training_loss_list = []
1029
- self.validation_loss_list = []
1030
- self.training_accuracy_list = []
1031
- self.validation_accuracy_list = []
1032
- self.node_attr_key = trainingDataset[0].x.shape[1]
1033
-
1034
- # train, validate, test split
1035
- num_train = int(len(trainingDataset) * hparams.split[0])
1036
- num_validate = int(len(trainingDataset) * hparams.split[1])
1037
- num_test = len(trainingDataset) - num_train - num_validate
1038
- idx = torch.randperm(len(trainingDataset))
1039
- train_sampler = SubsetRandomSampler(idx[:num_train])
1040
- validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
1041
- test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])
1042
-
1043
- if validationDataset:
1044
- self.train_dataloader = DataLoader(trainingDataset, batch_size=hparams.batch_size, drop_last=False)
1045
- self.validate_dataloader = DataLoader(validationDataset, batch_size=hparams.batch_size, drop_last=False)
1046
- else:
1047
- self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler, batch_size=hparams.batch_size, drop_last=False)
1048
- self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler, batch_size=hparams.batch_size, drop_last=False)
1049
-
1050
- if testingDataset:
1051
- self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
1052
- else:
1053
- self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)
1054
-
1055
- def train(self):
1056
- # Init the loss and accuracy reporting lists
1057
- self.training_accuracy_list = []
1058
- self.training_loss_list = []
1059
- self.validation_accuracy_list = []
1060
- self.validation_loss_list = []
1061
-
1062
- # Run the training loop for defined number of epochs
1063
- for _ in tqdm(range(self.hparams.epochs), desc='Epochs', initial=1, leave=False):
1064
- temp_loss_list = []
1065
- temp_acc_list = []
1066
- # Iterate over the DataLoader for training data
1067
- for data in tqdm(self.train_dataloader, desc='Training', leave=False):
1068
- data = data.to(self.device)
1069
- # Make sure the model is in training mode
1070
- self.model.train()
1071
-
1072
- # Zero the gradients
1073
- self.optimizer.zero_grad()
1074
-
1075
- # Perform forward pass
1076
- pred = self.model(data)
1077
-
1078
- # Compute loss
1079
- if self.hparams.loss_function.lower() == "negative log likelihood":
1080
- logp = F.log_softmax(pred, 1)
1081
- loss = F.nll_loss(logp, data.y)
1082
- elif self.hparams.loss_function.lower() == "cross entropy":
1083
- loss = F.cross_entropy(pred, data.y)
1084
-
1085
- # Save loss information for reporting
1086
- temp_loss_list.append(loss.item())
1087
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
1088
-
1089
- # Perform backward pass
1090
- loss.backward()
1091
-
1092
- # Perform optimization
1093
- self.optimizer.step()
1094
-
1095
- self.training_accuracy_list.append(np.mean(temp_acc_list).item())
1096
- self.training_loss_list.append(np.mean(temp_loss_list).item())
1097
- self.validate()
1098
- self.validation_accuracy_list.append(self.validation_accuracy)
1099
- self.validation_loss_list.append(self.validation_loss)
1100
- gc.collect()
1101
-
1102
- def validate(self):
1103
- temp_loss_list = []
1104
- temp_acc_list = []
1105
- self.model.eval()
1106
- for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
1107
- data = data.to(self.device)
1108
- pred = self.model(data)
1109
- if self.hparams.loss_function.lower() == "negative log likelihood":
1110
- logp = F.log_softmax(pred, 1)
1111
- loss = F.nll_loss(logp, data.y)
1112
- elif self.hparams.loss_function.lower() == "cross entropy":
1113
- loss = F.cross_entropy(pred, data.y)
1114
- temp_loss_list.append(loss.item())
1115
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
1116
- self.validation_accuracy = np.mean(temp_acc_list).item()
1117
- self.validation_loss = np.mean(temp_loss_list).item()
1118
-
1119
- def test(self):
1120
- if self.test_dataloader:
1121
- temp_loss_list = []
1122
- temp_acc_list = []
1123
- self.model.eval()
1124
- for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
1125
- data = data.to(self.device)
1126
- pred = self.model(data)
1127
- if self.hparams.loss_function.lower() == "negative log likelihood":
1128
- logp = F.log_softmax(pred, 1)
1129
- loss = F.nll_loss(logp, data.y)
1130
- elif self.hparams.loss_function.lower() == "cross entropy":
1131
- loss = F.cross_entropy(pred, data.y)
1132
- temp_loss_list.append(loss.item())
1133
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
1134
- self.testing_accuracy = np.mean(temp_acc_list).item()
1135
- self.testing_loss = np.mean(temp_loss_list).item()
1136
-
1137
- def save(self, path):
1138
- if path:
1139
- # Make sure the file extension is .pt
1140
- ext = path[-3:]
1141
- if ext.lower() != ".pt":
1142
- path = path + ".pt"
1143
- torch.save(self.model.state_dict(), path)
1144
-
1145
- def load(self, path):
1146
- #self.model.load_state_dict(torch.load(path))
1147
- self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
1148
-
1149
- class _NodeRegressorHoldout:
1150
- def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
1151
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1152
- self.trainingDataset = trainingDataset
1153
- self.validationDataset = validationDataset
1154
- self.testingDataset = testingDataset
1155
- self.hparams = hparams
1156
- if hparams.conv_layer_type.lower() == 'sageconv':
1157
- # pooling is set None for Node regressor
1158
- self.model = _SAGEConv(trainingDataset[0].num_node_features, hparams.hl_widths, 1, None).to(self.device)
1159
- else:
1160
- raise NotImplementedError
1161
-
1162
- if hparams.optimizer_str.lower() == "adadelta":
1163
- self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
1164
- lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
1165
- elif hparams.optimizer_str.lower() == "adagrad":
1166
- self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
1167
- lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
1168
- elif hparams.optimizer_str.lower() == "adam":
1169
- self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
1170
- lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
1171
-
1172
- self.use_gpu = hparams.use_gpu
1173
- self.training_loss_list = []
1174
- self.validation_loss_list = []
1175
- self.node_attr_key = trainingDataset[0].x.shape[1]
1176
-
1177
- # Train, validate, test split
1178
- num_train = int(len(trainingDataset) * hparams.split[0])
1179
- num_validate = int(len(trainingDataset) * hparams.split[1])
1180
- num_test = len(trainingDataset) - num_train - num_validate
1181
- idx = torch.randperm(len(trainingDataset))
1182
- train_sampler = SubsetRandomSampler(idx[:num_train])
1183
- validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
1184
- test_sampler = SubsetRandomSampler(idx[num_train+num_validate:])
1185
-
1186
- if validationDataset:
1187
- self.train_dataloader = DataLoader(trainingDataset,
1188
- batch_size=hparams.batch_size,
1189
- drop_last=False)
1190
- self.validate_dataloader = DataLoader(validationDataset,
1191
- batch_size=hparams.batch_size,
1192
- drop_last=False)
1193
- else:
1194
- self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler,
1195
- batch_size=hparams.batch_size,
1196
- drop_last=False)
1197
- self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler,
1198
- batch_size=hparams.batch_size,
1199
- drop_last=False)
1200
-
1201
- if testingDataset:
1202
- self.test_dataloader = DataLoader(testingDataset,
1203
- batch_size=len(testingDataset),
1204
- drop_last=False)
1205
- else:
1206
- self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler,
1207
- batch_size=hparams.batch_size,
1208
- drop_last=False)
1209
-
1210
- def train(self):
1211
- # Init the loss and accuracy reporting lists
1212
- self.training_loss_list = []
1213
- self.validation_loss_list = []
1214
-
1215
- # Run the training loop for defined number of epochs
1216
- for _ in tqdm(range(self.hparams.epochs), desc='Epochs', total=self.hparams.epochs, leave=False):
1217
- # Iterate over the DataLoader for training data
1218
- for data in tqdm(self.train_dataloader, desc='Training', leave=False):
1219
- data = data.to(self.device)
1220
- # Make sure the model is in training mode
1221
- self.model.train()
1222
- # Zero the gradients
1223
- self.optimizer.zero_grad()
1224
-
1225
- # Perform forward pass
1226
- pred = self.model(data).to(self.device)
1227
- # Compute loss
1228
- loss = F.mse_loss(torch.flatten(pred), data.y.float())
1229
-
1230
- # Perform backward pass
1231
- loss.backward()
1232
-
1233
- # Perform optimization
1234
- self.optimizer.step()
1235
-
1236
- self.training_loss_list.append(torch.sqrt(loss).item())
1237
- self.validate()
1238
- self.validation_loss_list.append(torch.sqrt(self.validation_loss).item())
1239
- gc.collect()
1240
-
1241
- def validate(self):
1242
- self.model.eval()
1243
- for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
1244
- data = data.to(self.device)
1245
- pred = self.model(data).to(self.device)
1246
- loss = F.mse_loss(torch.flatten(pred), data.y.float())
1247
- self.validation_loss = loss
1248
-
1249
- def test(self):
1250
- self.model.eval()
1251
- for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
1252
- data = data.to(self.device)
1253
- pred = self.model(data).to(self.device)
1254
- loss = F.mse_loss(torch.flatten(pred), data.y.float())
1255
- self.testing_loss = torch.sqrt(loss).item()
1256
-
1257
- def save(self, path):
1258
- if path:
1259
- # Make sure the file extension is .pt
1260
- ext = path[-3:]
1261
- if ext.lower() != ".pt":
1262
- path = path + ".pt"
1263
- torch.save(self.model.state_dict(), path)
1264
-
1265
- def load(self, path):
1266
- #self.model.load_state_dict(torch.load(path))
1267
- self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
1268
-
1269
- class _NodeClassifierKFold:
1270
- def __init__(self, hparams, trainingDataset, testingDataset=None):
1271
- self.trainingDataset = trainingDataset
1272
- self.testingDataset = testingDataset
1273
- self.hparams = hparams
1274
- self.testing_accuracy = 0
1275
- self.accuracies = []
1276
- self.max_accuracy = 0
1277
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1278
-
1279
- if hparams.conv_layer_type.lower() == 'sageconv':
1280
- # pooling is set None for Node classifier
1281
- self.model = _SAGEConv(trainingDataset.num_node_features, hparams.hl_widths,
1282
- trainingDataset.num_classes, None).to(self.device)
1283
- else:
1284
- raise NotImplementedError
1285
-
1286
- if hparams.optimizer_str.lower() == "adadelta":
1287
- self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
1288
- lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
1289
- elif hparams.optimizer_str.lower() == "adagrad":
1290
- self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
1291
- lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
1292
- elif hparams.optimizer_str.lower() == "adam":
1293
- self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
1294
- lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
1295
- self.use_gpu = hparams.use_gpu
1296
- self.training_loss_list = []
1297
- self.validation_loss_list = []
1298
- self.training_accuracy_list = []
1299
- self.validation_accuracy_list = []
1300
-
1301
- def reset_weights(self):
1302
- if self.hparams.conv_layer_type.lower() == 'sageconv':
1303
- # pooling is set None for Node classifier
1304
- self.model = _SAGEConv(self.trainingDataset.num_node_features, self.hparams.hl_widths,
1305
- self.trainingDataset.num_classes, None).to(self.device)
1306
- else:
1307
- raise NotImplementedError
1308
-
1309
- if self.hparams.optimizer_str.lower() == "adadelta":
1310
- self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=self.hparams.eps,
1311
- lr=self.hparams.lr, rho=self.hparams.rho, weight_decay=self.hparams.weight_decay)
1312
- elif self.hparams.optimizer_str.lower() == "adagrad":
1313
- self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=self.hparams.eps,
1314
- lr=self.hparams.lr, lr_decay=self.hparams.lr_decay, weight_decay=self.hparams.weight_decay)
1315
- elif self.hparams.optimizer_str.lower() == "adam":
1316
- self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=self.hparams.amsgrad, betas=self.hparams.betas, eps=self.hparams.eps,
1317
- lr=self.hparams.lr, maximize=self.hparams.maximize, weight_decay=self.hparams.weight_decay)
1318
-
1319
- def train(self):
1320
- k_folds = self.hparams.k_folds
1321
-
1322
- # Init the loss and accuracy reporting lists
1323
- self.training_accuracy_list = []
1324
- self.training_loss_list = []
1325
- self.validation_accuracy_list = []
1326
- self.validation_loss_list = []
1327
-
1328
- # Set fixed random number seed
1329
- torch.manual_seed(42)
1330
-
1331
- # Define the K-fold Cross Validator
1332
- kfold = KFold(n_splits=k_folds, shuffle=True)
1333
-
1334
- models = []
1335
- weights = []
1336
- accuracies = []
1337
- train_dataloaders = []
1338
- validate_dataloaders = []
1339
-
1340
- # K-fold Cross-validation model evaluation
1341
- for fold, (train_ids, validate_ids) in tqdm(enumerate(kfold.split(self.trainingDataset)), desc="Fold", initial=1, total=k_folds, leave=False):
1342
- epoch_training_loss_list = []
1343
- epoch_training_accuracy_list = []
1344
- epoch_validation_loss_list = []
1345
- epoch_validation_accuracy_list = []
1346
- # Sample elements randomly from a given list of ids, no replacement.
1347
- train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
1348
- validate_subsampler = torch.utils.data.SubsetRandomSampler(validate_ids)
1349
-
1350
- # Define data loaders for training and testing data in this fold
1351
- self.train_dataloader = DataLoader(self.trainingDataset, sampler=train_subsampler,
1352
- batch_size=self.hparams.batch_size,
1353
- drop_last=False)
1354
- self.validate_dataloader = DataLoader(self.trainingDataset, sampler=validate_subsampler,
1355
- batch_size=self.hparams.batch_size,
1356
- drop_last=False)
1357
- # Init the neural network
1358
- self.reset_weights()
1359
-
1360
- # Run the training loop for defined number of epochs
1361
- for _ in tqdm(range(0,self.hparams.epochs), desc='Epochs', initial=1, total=self.hparams.epochs, leave=False):
1362
- temp_loss_list = []
1363
- temp_acc_list = []
1364
-
1365
- # Iterate over the DataLoader for training data
1366
- for data in tqdm(self.train_dataloader, desc='Training', leave=False):
1367
- data = data.to(self.device)
1368
- # Make sure the model is in training mode
1369
- self.model.train()
1370
-
1371
- # Zero the gradients
1372
- self.optimizer.zero_grad()
1373
-
1374
- # Perform forward pass
1375
- pred = self.model(data)
1376
-
1377
- # Compute loss
1378
- if self.hparams.loss_function.lower() == "negative log likelihood":
1379
- logp = F.log_softmax(pred, 1)
1380
- loss = F.nll_loss(logp, data.y)
1381
- elif self.hparams.loss_function.lower() == "cross entropy":
1382
- loss = F.cross_entropy(pred, data.y)
1383
-
1384
- # Save loss information for reporting
1385
- temp_loss_list.append(loss.item())
1386
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
1387
-
1388
- # Perform backward pass
1389
- loss.backward()
1390
-
1391
- # Perform optimization
1392
- self.optimizer.step()
1393
-
1394
- epoch_training_accuracy_list.append(np.mean(temp_acc_list).item())
1395
- epoch_training_loss_list.append(np.mean(temp_loss_list).item())
1396
- self.validate()
1397
- epoch_validation_accuracy_list.append(self.validation_accuracy)
1398
- epoch_validation_loss_list.append(self.validation_loss)
1399
- gc.collect()
1400
- models.append(self.model)
1401
- weights.append(copy.deepcopy(self.model.state_dict()))
1402
- accuracies.append(self.validation_accuracy)
1403
- train_dataloaders.append(self.train_dataloader)
1404
- validate_dataloaders.append(self.validate_dataloader)
1405
- self.training_accuracy_list.append(epoch_training_accuracy_list)
1406
- self.training_loss_list.append(epoch_training_loss_list)
1407
- self.validation_accuracy_list.append(epoch_validation_accuracy_list)
1408
- self.validation_loss_list.append(epoch_validation_loss_list)
1409
- self.accuracies = accuracies
1410
- max_accuracy = max(accuracies)
1411
- self.max_accuracy = max_accuracy
1412
- ind = accuracies.index(max_accuracy)
1413
- self.model = models[ind]
1414
- self.model.load_state_dict(weights[ind])
1415
- self.model.eval()
1416
- self.training_accuracy_list = self.training_accuracy_list[ind]
1417
- self.training_loss_list = self.training_loss_list[ind]
1418
- self.validation_accuracy_list = self.validation_accuracy_list[ind]
1419
- self.validation_loss_list = self.validation_loss_list[ind]
1420
-
1421
- def validate(self):
1422
- temp_loss_list = []
1423
- temp_acc_list = []
1424
- self.model.eval()
1425
- for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
1426
- data = data.to(self.device)
1427
- pred = self.model(data)
1428
- if self.hparams.loss_function.lower() == "negative log likelihood":
1429
- logp = F.log_softmax(pred, 1)
1430
- loss = F.nll_loss(logp, data.y)
1431
- elif self.hparams.loss_function.lower() == "cross entropy":
1432
- loss = F.cross_entropy(pred, data.y)
1433
- temp_loss_list.append(loss.item())
1434
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
1435
- self.validation_accuracy = np.mean(temp_acc_list).item()
1436
- self.validation_loss = np.mean(temp_loss_list).item()
1437
-
1438
- def test(self):
1439
- if self.testingDataset:
1440
- self.test_dataloader = DataLoader(self.testingDataset,
1441
- batch_size=len(self.testingDataset),
1442
- drop_last=False)
1443
- temp_loss_list = []
1444
- temp_acc_list = []
1445
- self.model.eval()
1446
- for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
1447
- data = data.to(self.device)
1448
- pred = self.model(data)
1449
- if self.hparams.loss_function.lower() == "negative log likelihood":
1450
- logp = F.log_softmax(pred, 1)
1451
- loss = F.nll_loss(logp, data.y)
1452
- elif self.hparams.loss_function.lower() == "cross entropy":
1453
- loss = F.cross_entropy(pred, data.y)
1454
- temp_loss_list.append(loss.item())
1455
- temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
1456
- self.testing_accuracy = np.mean(temp_acc_list).item()
1457
- self.testing_loss = np.mean(temp_loss_list).item()
1458
-
1459
- def save(self, path):
1460
- if path:
1461
- # Make sure the file extension is .pt
1462
- ext = path[-3:]
1463
- if ext.lower() != ".pt":
1464
- path = path + ".pt"
1465
- torch.save(self.model.state_dict(), path)
1466
-
1467
- def load(self, path):
1468
- #self.model.load_state_dict(torch.load(path))
1469
- self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
1470
-
1471
- class _NodeRegressorKFold:
1472
- def __init__(self, hparams, trainingDataset, testingDataset=None):
1473
- self.trainingDataset = trainingDataset
1474
- self.testingDataset = testingDataset
1475
- self.hparams = hparams
1476
- self.losses = []
1477
- self.min_loss = 0
1478
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1479
-
1480
- self.model = self._initialize_model(hparams, trainingDataset)
1481
- self.optimizer = self._initialize_optimizer(hparams)
1482
-
1483
- self.use_gpu = hparams.use_gpu
1484
- self.training_loss_list = []
1485
- self.validation_loss_list = []
1486
- self.node_attr_key = trainingDataset.node_attr_key
1487
-
1488
- # Train, validate, test split
1489
- num_train = int(len(trainingDataset) * hparams.split[0])
1490
- num_validate = int(len(trainingDataset) * hparams.split[1])
1491
- num_test = len(trainingDataset) - num_train - num_validate
1492
- idx = torch.randperm(len(trainingDataset))
1493
- test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])
1494
-
1495
- if testingDataset:
1496
- self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
1497
- else:
1498
- self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)
1499
-
1500
- def _initialize_model(self, hparams, dataset):
1501
- if hparams.conv_layer_type.lower() == 'sageconv':
1502
- # pooling is set None for Node
1503
- return _SAGEConv(dataset.num_node_features, hparams.hl_widths, 1, None).to(self.device)
248
+ class _GraphHead(nn.Module):
249
+ def __init__(self, in_dim: int, out_dim: int, pooling: PoolingKind = "mean", dropout: float = 0.1):
250
+ super().__init__()
251
+ self.dropout = float(dropout)
252
+
253
+ if pooling == "mean":
254
+ self.pool = global_mean_pool
255
+ elif pooling == "max":
256
+ self.pool = global_max_pool
257
+ elif pooling == "sum":
258
+ self.pool = global_add_pool
1504
259
  else:
1505
- raise NotImplementedError
1506
-
1507
- def _initialize_optimizer(self, hparams):
1508
- if hparams.optimizer_str.lower() == "adadelta":
1509
- return torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps, lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
1510
- elif hparams.optimizer_str.lower() == "adagrad":
1511
- return torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps, lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
1512
- elif hparams.optimizer_str.lower() == "adam":
1513
- return torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps, lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
1514
-
1515
- def reset_weights(self):
1516
- self.model = self._initialize_model(self.hparams, self.trainingDataset)
1517
- self.optimizer = self._initialize_optimizer(self.hparams)
1518
-
1519
- def train(self):
1520
- k_folds = self.hparams.k_folds
1521
- torch.manual_seed(42)
1522
-
1523
- kfold = KFold(n_splits=k_folds, shuffle=True)
1524
- models, weights, losses, train_dataloaders, validate_dataloaders = [], [], [], [], []
1525
-
1526
- for fold, (train_ids, validate_ids) in tqdm(enumerate(kfold.split(self.trainingDataset)), desc="Fold", total=k_folds, leave=False):
1527
- epoch_training_loss_list, epoch_validation_loss_list = [], []
1528
- train_subsampler = SubsetRandomSampler(train_ids)
1529
- validate_subsampler = SubsetRandomSampler(validate_ids)
1530
-
1531
- self.train_dataloader = DataLoader(self.trainingDataset, sampler=train_subsampler, batch_size=self.hparams.batch_size, drop_last=False)
1532
- self.validate_dataloader = DataLoader(self.trainingDataset, sampler=validate_subsampler, batch_size=self.hparams.batch_size, drop_last=False)
1533
-
1534
- self.reset_weights()
1535
- best_rmse = np.inf
1536
-
1537
- for _ in tqdm(range(self.hparams.epochs), desc='Epochs', total=self.hparams.epochs, leave=False):
1538
- for batched_graph in tqdm(self.train_dataloader, desc='Training', leave=False):
1539
- self.model.train()
1540
- self.optimizer.zero_grad()
1541
-
1542
- batched_graph = batched_graph.to(self.device)
1543
- pred = self.model(batched_graph)
1544
- loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
1545
- loss.backward()
1546
- self.optimizer.step()
1547
-
1548
- epoch_training_loss_list.append(torch.sqrt(loss).item())
1549
- self.validate()
1550
- epoch_validation_loss_list.append(torch.sqrt(self.validation_loss).item())
1551
- gc.collect()
1552
-
1553
- models.append(self.model)
1554
- weights.append(copy.deepcopy(self.model.state_dict()))
1555
- losses.append(torch.sqrt(self.validation_loss).item())
1556
- train_dataloaders.append(self.train_dataloader)
1557
- validate_dataloaders.append(self.validate_dataloader)
1558
- self.training_loss_list.append(epoch_training_loss_list)
1559
- self.validation_loss_list.append(epoch_validation_loss_list)
1560
-
1561
- self.losses = losses
1562
- self.min_loss = min(losses)
1563
- ind = losses.index(self.min_loss)
1564
- self.model = models[ind]
1565
- self.model.load_state_dict(weights[ind])
1566
- self.model.eval()
1567
- self.training_loss_list = self.training_loss_list[ind]
1568
- self.validation_loss_list = self.validation_loss_list[ind]
260
+ raise ValueError("GraphHead requires pooling in {'mean','max','sum'}.")
261
+
262
+ self.mlp = nn.Sequential(
263
+ nn.Linear(in_dim, in_dim),
264
+ nn.ReLU(),
265
+ nn.Dropout(self.dropout),
266
+ nn.Linear(in_dim, out_dim),
267
+ )
268
+
269
+ def forward(self, node_emb, batch):
270
+ g = self.pool(node_emb, batch)
271
+ return self.mlp(g)
272
+
273
+
274
+ class _NodeHead(nn.Module):
275
+ def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
276
+ super().__init__()
277
+ self.mlp = nn.Sequential(
278
+ nn.Linear(in_dim, in_dim),
279
+ nn.ReLU(),
280
+ nn.Dropout(float(dropout)),
281
+ nn.Linear(in_dim, out_dim),
282
+ )
283
+
284
+ def forward(self, node_emb):
285
+ return self.mlp(node_emb)
286
+
287
+
288
+ class _EdgeHead(nn.Module):
289
+ """
290
+ Edge prediction head using concatenation of endpoint embeddings.
291
+ """
292
+ def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
293
+ super().__init__()
294
+ self.mlp = nn.Sequential(
295
+ nn.Linear(in_dim * 2, in_dim),
296
+ nn.ReLU(),
297
+ nn.Dropout(float(dropout)),
298
+ nn.Linear(in_dim, out_dim),
299
+ )
300
+
301
+ def forward(self, node_emb, edge_index):
302
+ src, dst = edge_index[0], edge_index[1]
303
+ h = torch.cat([node_emb[src], node_emb[dst]], dim=-1)
304
+ return self.mlp(h)
305
+
306
+
307
+ class _LinkPredictor(nn.Module):
308
+ """
309
+ Binary link predictor (edge exists or not).
310
+ """
311
+ def __init__(self, in_dim: int, hidden: int = 64, dropout: float = 0.1):
312
+ super().__init__()
313
+ self.net = nn.Sequential(
314
+ nn.Linear(in_dim * 2, hidden),
315
+ nn.ReLU(),
316
+ nn.Dropout(float(dropout)),
317
+ nn.Linear(hidden, 1),
318
+ )
319
+
320
+ def forward(self, node_emb, edge_label_index):
321
+ src, dst = edge_label_index[0], edge_label_index[1]
322
+ h = torch.cat([node_emb[src], node_emb[dst]], dim=-1)
323
+ return self.net(h).squeeze(-1) # logits
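# Illustrative sketch (not part of the package): how the edge/link heads above
# consume node embeddings. Both gather the two endpoint embeddings via the edge
# index and concatenate them, which is why their first Linear layer expects
# 2 * in_dim features. Tensor names below are hypothetical.
import torch

num_nodes, emb_dim = 5, 8
node_emb = torch.randn(num_nodes, emb_dim)          # output of the GNN encoder
edge_index = torch.tensor([[0, 1, 3], [2, 4, 0]])   # 2 x num_edges (src row, dst row)

src, dst = edge_index[0], edge_index[1]
pair_emb = torch.cat([node_emb[src], node_emb[dst]], dim=-1)
print(pair_emb.shape)   # torch.Size([3, 16]) == (num_edges, 2 * emb_dim)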
1569
324
 
1570
- def validate(self):
1571
- self.model.eval()
1572
- for batched_graph in tqdm(self.validate_dataloader, desc='Validating', leave=False):
1573
- batched_graph = batched_graph.to(self.device)
1574
- pred = self.model(batched_graph)
1575
- loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
1576
- self.validation_loss = loss
1577
-
1578
- def test(self):
1579
- self.model.eval()
1580
- for batched_graph in tqdm(self.test_dataloader, desc='Testing', leave=False):
1581
- batched_graph = batched_graph.to(self.device)
1582
- pred = self.model(batched_graph)
1583
- loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
1584
- self.testing_loss = torch.sqrt(loss).item()
1585
-
1586
- def save(self, path):
1587
- if path:
1588
- ext = path[-3:]
1589
- if ext.lower() != ".pt":
1590
- path = path + ".pt"
1591
- torch.save(self.model.state_dict(), path)
1592
-
1593
- def load(self, path):
1594
- self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
1595
325
 
1596
326
  class PyG:
327
+ """
328
+ A clean PyTorch Geometric interface for TopologicPy-exported CSV datasets.
329
+
330
+ You can control medium-level hyperparameters by passing keyword arguments to ByCSVPath,
331
+ for example:
332
+
333
+ pyg = PyG.ByCSVPath(
334
+ path="C:/dataset",
335
+ level="graph",
336
+ task="classification",
337
+ graphLabelType="categorical",
338
+ cv="kfold",
339
+ k_folds=5,
340
+ conv="gatv2",
341
+ hidden_dims=(128, 128, 64),
342
+ activation="gelu",
343
+ batch_norm=True,
344
+ residual=True,
345
+ dropout=0.2,
346
+ lr=1e-3,
347
+ optimizer="adamw",
348
+ early_stopping=True,
349
+ early_stopping_patience=10,
350
+ gradient_clip_norm=1.0
351
+ )
352
+ """
353
+
354
+ # ---------
355
+ # Creation
356
+ # ---------
1597
357
  @staticmethod
1598
- def DatasetByCSVPath(path, numberOfGraphClasses=0, nodeATTRKey='feat', edgeATTRKey='feat', nodeOneHotEncode=False,
1599
- nodeFeaturesCategories=[], edgeOneHotEncode=False, edgeFeaturesCategories=[], addSelfLoop=False,
1600
- node_level=False, graph_level=True):
1601
- """
1602
- Returns a PyTorch Geometric dataset according to the input CSV folder path. The folder must contain "graphs.csv",
1603
- "edges.csv", "nodes.csv", and "meta.yml" files according to conventions.
358
+ def ByCSVPath(path: str,
359
+ level: Level = "graph",
360
+ task: TaskKind = "classification",
361
+ graphLabelType: LabelType = "categorical",
362
+ nodeLabelType: LabelType = "categorical",
363
+ edgeLabelType: LabelType = "categorical",
364
+ **kwargs) -> "PyG":
365
+ cfg = _RunConfig(level=level, task=task,
366
+ graph_label_type=graphLabelType,
367
+ node_label_type=nodeLabelType,
368
+ edge_label_type=edgeLabelType)
1604
369
 
1605
- Parameters
1606
- ----------
1607
- path : str
1608
- The path to the folder containing the necessary CSV and YML files.
370
+ # allow override of any config field via kwargs
371
+ for k, v in kwargs.items():
372
+ if hasattr(cfg, k):
373
+ setattr(cfg, k, v)
1609
374
 
1610
- Returns
1611
- -------
1612
- PyG Dataset
1613
- The PyG dataset
1614
- """
1615
- if not isinstance(path, str):
1616
- print("PyG.DatasetByCSVPath - Error: The input path parameter is not a valid string. Returning None.")
1617
- return None
1618
- if not os.path.exists(path):
1619
- print("PyG.DatasetByCSVPath - Error: The input path parameter does not exist. Returning None.")
1620
- return None
1621
-
1622
- return CustomGraphDataset(root=path, node_level=node_level, graph_level=graph_level, node_attr_key=nodeATTRKey, edge_attr_key=edgeATTRKey)
1623
-
1624
- @staticmethod
1625
- def DatasetGraphLabels(dataset, graphLabelHeader="label"):
1626
- """
1627
- Returns the labels of the graphs in the input dataset
1628
-
1629
- Parameters
1630
- ----------
1631
- dataset : CustomDataset
1632
- The input dataset
1633
- graphLabelHeader: str , optional
1634
- The key string under which the graph labels are stored. Default is "label".
1635
-
1636
- Returns
1637
- -------
1638
- list
1639
- The list of graph labels.
1640
- """
1641
- import torch
375
+ return PyG(path=path, config=cfg)
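# Illustrative sketch (assumption: _RunConfig behaves like a plain dataclass of
# hyperparameter fields). ByCSVPath forwards **kwargs and only overrides fields
# that already exist on the config, so misspelled keys are silently dropped.
from dataclasses import dataclass

@dataclass
class _ToyConfig:          # hypothetical stand-in for _RunConfig
    lr: float = 0.001
    epochs: int = 20

cfg = _ToyConfig()
for k, v in {"lr": 1e-4, "epochz": 99}.items():   # note the typo "epochz"
    if hasattr(cfg, k):
        setattr(cfg, k, v)
print(cfg)   # _ToyConfig(lr=0.0001, epochs=20) - the misspelled key was ignored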
1642
376
 
1643
- graph_labels = []
1644
- for g in dataset:
1645
- # Get the label of the graph
1646
- label = g.y
1647
- graph_labels.append(label.item())
1648
- return graph_labels
377
+ def __init__(self, path: str, config: _RunConfig):
378
+ self.path = path
379
+ self.config = config
1649
380
 
1650
- @staticmethod
1651
- def DatasetSplit(dataset, split=[0.8,0.1,0.1], shuffle=True, randomState=42):
1652
- """
1653
- Splits the dataset into three subsets.
381
+ self.device = torch.device("cuda:0" if (config.use_gpu and torch.cuda.is_available()) else "cpu")
1654
382
 
1655
- Parameters
1656
- ----------
1657
- dataset : CustomDataset
1658
- The input dataset
1659
- split: list , optional
1660
- The list of ratios. This list must consist of three numbers that sum to 1.
1661
- shuffle: boolean , optional
1662
- If set to True, the subsets are created from random indices. Otherwise, they are split sequentially. Default is True.
1663
- randomState : int , optional
1664
- The random seed to use for reproducibility. Default is 42.
1665
-
1666
- Returns
1667
- -------
1668
- list
1669
- The list of three subset datasets.
1670
- """
383
+ self.graph_df: Optional[pd.DataFrame] = None
384
+ self.nodes_df: Optional[pd.DataFrame] = None
385
+ self.edges_df: Optional[pd.DataFrame] = None
1671
386
 
1672
- import torch
1673
- from torch.utils.data import random_split
1674
- train_ratio, val_ratio, test_ratio = split
1675
- assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must add up to 1."
1676
-
1677
- # Calculate the number of samples for each split
1678
- dataset_len = len(dataset)
1679
- train_len = int(train_ratio * dataset_len)
1680
- val_len = int(val_ratio * dataset_len)
1681
- test_len = dataset_len - train_len - val_len # Ensure it adds up correctly
1682
-
1683
- ## Generate indices for the split
1684
- indices = list(range(dataset_len))
1685
- if shuffle:
1686
- torch.manual_seed(randomState) # For reproducibility
1687
- indices = torch.randperm(dataset_len).tolist() # Shuffled indices
1688
-
1689
- # Create splits
1690
- train_indices = indices[:train_len]
1691
- val_indices = indices[train_len:train_len + val_len]
1692
- test_indices = indices[train_len + val_len:train_len + val_len + test_len]
1693
-
1694
- # Create new instances of CustomGraphDataset using the indices
1695
- train_dataset = CustomGraphDataset(data_list=dataset.data_list, indices=train_indices)
1696
- val_dataset = CustomGraphDataset(data_list=dataset.data_list, indices=val_indices)
1697
- test_dataset = CustomGraphDataset(data_list=dataset.data_list, indices=test_indices)
1698
-
1699
- return train_dataset, val_dataset, test_dataset
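# Worked example of the split arithmetic used by DatasetSplit above: the test
# count absorbs any rounding remainder so the three subsets always cover the
# whole dataset. The numbers below are illustrative.
dataset_len = 103
train_ratio, val_ratio, test_ratio = 0.8, 0.1, 0.1
train_len = int(train_ratio * dataset_len)       # 82
val_len = int(val_ratio * dataset_len)           # 10
test_len = dataset_len - train_len - val_len     # 11 (picks up the remainder)
assert train_len + val_len + test_len == dataset_len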
387
+ self.data_list: List[Data] = []
388
+ self.train_set: Optional[List[Data]] = None
389
+ self.val_set: Optional[List[Data]] = None
390
+ self.test_set: Optional[List[Data]] = None
1700
391
 
1701
- @staticmethod
1702
- def Optimizer(name="Adam", amsgrad=True, betas=(0.9,0.999), eps=0.000001, lr=0.001, maximize=False, weightDecay=0.0, rho=0.9, lr_decay=0.0):
1703
- """
1704
- Returns the parameters of the optimizer
392
+ self.model: Optional[nn.Module] = None
393
+ self.history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
394
+ self.cv_report: Optional[Dict[str, Union[float, List[Dict[str, float]]]]] = None
1705
395
 
1706
- Parameters
1707
- ----------
1708
- amsgrad : bool , optional.
1709
- amsgrad is an extension to the Adam version of gradient descent that attempts to improve the convergence properties of the algorithm, avoiding large abrupt changes in the learning rate for each input variable. Default is True.
1710
- betas : tuple , optional
1711
- Betas are used for smoothing the path to convergence, also providing some momentum to cross a local minimum or saddle point. Default is (0.9, 0.999).
1712
- eps : float , optional
1713
- eps is a term added to the denominator to improve numerical stability. Default is 0.000001.
1714
- lr : float
1715
- The learning rate (lr) defines the size of the adjustment to the network weights with respect to the loss gradient. Default is 0.001.
1716
- maximize : bool , optional
1717
- If set to True, the objective is maximized with respect to the parameters instead of minimized. Default is False.
1718
- weightDecay : float , optional
1719
- weightDecay (L2 penalty) is a regularization technique applied to the weights of a neural network. Default is 0.0.
396
+ self._num_outputs: int = 1
1720
397
 
1721
- Returns
1722
- -------
1723
- dict
1724
- The dictionary of the optimizer parameters. The dictionary contains the following keys and values:
1725
- - "name" (str): The name of the optimizer
1726
- - "amsgrad" (bool):
1727
- - "betas" (tuple):
1728
- - "eps" (float):
1729
- - "lr" (float):
1730
- - "maximize" (bool):
1731
- - "weightDecay" (float):
398
+ self._load_csv()
399
+ self._build_data_list()
400
+ self._split_holdout()
1732
401
 
1733
- """
1734
- return {"name":name, "amsgrad":amsgrad, "betas":betas, "eps":eps, "lr": lr, "maximize":maximize, "weight_decay":weightDecay, "rho":rho, "lr_decay":lr_decay}
1735
-
1736
- @staticmethod
1737
- def Hyperparameters(optimizer, model_type="classifier", cv_type="Holdout", split=[0.8,0.1,0.1], k_folds=5,
1738
- hl_widths=[32], conv_layer_type="SAGEConv", pooling="AvgPooling",
1739
- batch_size=1, epochs=1, use_gpu=False, loss_function="Cross Entropy",
1740
- input_type="graph"):
1741
- """
1742
- Creates a hyperparameters object based on the input settings.
402
+ self._build_model()
1743
403
 
1744
- Parameters
1745
- ----------
1746
- model_type : str , optional
1747
- The desired type of model. The options are:
1748
- - "Classifier"
1749
- - "Regressor"
1750
- The option is case insensitive. Default is "classifier".
1751
- optimizer : Optimizer
1752
- The desired optimizer.
1753
- cv_type : str , optional
1754
- The desired cross-validation method. This can be "Holdout" or "K-Fold". It is case-insensitive. Default is "Holdout".
1755
- split : list , optional
1756
- The desired split between training, validation, and testing. [0.8, 0.1, 0.1] means that 80% of the data is used for training, 10% is used for validation, and 10% is used for testing. Default is [0.8, 0.1, 0.1].
1757
- k_folds : int , optional
1758
- The desired number of k-folds. Default is 5.
1759
- hl_widths : list , optional
1760
- The list of hidden layer widths. A list of [16, 32, 16] means that the model will have 3 hidden layers with number of neurons in each being 16, 32, 16 respectively from input to output. Default is [32].
1761
- conv_layer_type : str , optional
1762
- The desired type of the convolution layer. The options are "Classic", "GraphConv", "GINConv", "SAGEConv", "TAGConv", "DGN". It is case insensitive. Default is "SAGEConv".
1763
- pooling : str , optional
1764
- The desired type of pooling. The options are "AvgPooling", "MaxPooling", or "SumPooling". It is case insensitive. Default is "AvgPooling".
1765
- batch_size : int , optional
1766
- The desired batch size. Default is 1.
1767
- epochs : int , optional
1768
- The desired number of epochs. Default is 1.
1769
- use_gpu : bool , optional
1770
- If set to True, the model will attempt to use the GPU. Default is False.
1771
- loss_function : str , optional
1772
- The desired loss function. The options are "Cross-Entropy" or "Negative Log Likelihood". It is case insensitive. Default is "Cross-Entropy".
1773
- input_type : str
1774
- Selects the input type of the model: "graph", "node", or "edge". Default is "graph".
1775
- Returns
1776
- -------
1777
- Hyperparameters
1778
- The created hyperparameters object.
1779
404
 
405
+ # ----------------------------
406
+ # Convenience: hyperparameters
407
+ # ----------------------------
408
+ def SetHyperparameters(self, **kwargs) -> Dict[str, Union[str, int, float, bool, Tuple]]:
1780
409
  """
1781
-
1782
- if optimizer['name'].lower() == "adadelta":
1783
- optimizer_str = "Adadelta"
1784
- elif optimizer['name'].lower() == "adagrad":
1785
- optimizer_str = "Adagrad"
1786
- elif optimizer['name'].lower() == "adam":
1787
- optimizer_str = "Adam"
1788
- return _Hparams(model_type,
1789
- optimizer_str,
1790
- optimizer['amsgrad'],
1791
- optimizer['betas'],
1792
- optimizer['eps'],
1793
- optimizer['lr'],
1794
- optimizer['lr_decay'],
1795
- optimizer['maximize'],
1796
- optimizer['rho'],
1797
- optimizer['weight_decay'],
1798
- cv_type,
1799
- split,
1800
- k_folds,
1801
- hl_widths,
1802
- conv_layer_type,
1803
- pooling,
1804
- batch_size,
1805
- epochs,
1806
- use_gpu,
1807
- loss_function,
1808
- input_type)
1809
-
1810
- @staticmethod
1811
- def Model(hparams, trainingDataset, validationDataset=None, testingDataset=None):
410
+ Set one or more configuration values (hyperparameters) in a safe, readable way.
411
+
412
+ Examples
413
+ --------
414
+ pyg.SetHyperparameters(
415
+ cv="kfold", k_folds=5, k_stratify=True,
416
+ conv="gatv2", hidden_dims=(128, 128, 64),
417
+ activation="gelu", batch_norm=True, residual=True,
418
+ lr=1e-3, optimizer="adamw",
419
+ early_stopping=True, early_stopping_patience=10,
420
+ gradient_clip_norm=1.0
421
+ )
422
+
423
+ Notes
424
+ -----
425
+ - Unknown keys are ignored (with a warning if verbose=True).
426
+ - Some values are validated (e.g., split sums, hidden_dims).
427
+ - If you change model-related settings (conv/hidden_dims/etc.) the model is rebuilt automatically.
1812
428
  """
1813
- Creates a neural network model (classifier or regressor) according to the input hyperparameters.
1814
-
1815
- Parameters
1816
- ----------
1817
- hparams : HParams
1818
- The input hyperparameters
1819
- trainingDataset : CustomDataset
1820
- The input training dataset.
1821
- validationDataset : CustomDataset
1822
- The input validation dataset. If not specified, a portion of the trainingDataset will be used for validation according to the split list as specified in the hyper-parameters.
1823
- testingDataset : CustomDataset
1824
- The input testing dataset. If not specified, a portion of the trainingDataset will be used for testing according to the split list as specified in the hyper-parameters.
1825
-
1826
- Returns
1827
- -------
1828
- Model
1829
- The created model (classifier or regressor).
1830
-
429
+ cfg = self.config
430
+ changed_model = False
431
+
432
+ for k, v in kwargs.items():
433
+ if not hasattr(cfg, k):
434
+ if cfg.verbose:
435
+ print(f"PyG.SetHyperparameters - Warning: Unknown parameter '{k}' ignored.")
436
+ continue
437
+
438
+ # Basic validation / normalisation
439
+ if k == "split":
440
+ if (not isinstance(v, (tuple, list))) or len(v) != 3:
441
+ raise ValueError("split must be a 3-tuple, e.g. (0.8, 0.1, 0.1).")
442
+ s = float(v[0]) + float(v[1]) + float(v[2])
443
+ if abs(s - 1.0) > 1e-3:
444
+ raise ValueError("split ratios must sum to 1.")
445
+ v = (float(v[0]), float(v[1]), float(v[2]))
446
+
447
+ if k == "hidden_dims":
448
+ if isinstance(v, list):
449
+ v = tuple(v)
450
+ if (not isinstance(v, tuple)) or len(v) == 0:
451
+ raise ValueError("hidden_dims must be a non-empty tuple, e.g. (64, 64).")
452
+ v = tuple(int(x) for x in v)
453
+ changed_model = True
454
+
455
+ if k in ["conv", "activation", "dropout", "batch_norm", "residual", "pooling"]:
456
+ changed_model = True
457
+
458
+ setattr(cfg, k, v)
459
+
460
+ # rebuild model if needed
461
+ if changed_model:
462
+ self._build_model()
463
+
464
+ return self.Summary()
465
+
466
+ def Summary(self) -> Dict[str, Union[str, int, float, bool, Tuple]]:
1831
467
  """
1832
-
1833
- model = None
1834
- if hparams.model_type.lower() == "classifier":
1835
- if hparams.input_type == 'graph':
1836
- if hparams.cv_type.lower() == "holdout":
1837
- model = _GraphClassifierHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
1838
- elif "k" in hparams.cv_type.lower():
1839
- model = _GraphClassifierKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
1840
- elif hparams.input_type == 'node':
1841
- if hparams.cv_type.lower() == "holdout":
1842
- model = _NodeClassifierHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
1843
- elif "k" in hparams.cv_type.lower():
1844
- model = _NodeClassifierKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
1845
- elif hparams.model_type.lower() == "regressor":
1846
- if hparams.input_type == 'graph':
1847
- if hparams.cv_type.lower() == "holdout":
1848
- model = _GraphRegressorHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
1849
- elif "k" in hparams.cv_type.lower():
1850
- model = _GraphRegressorKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
1851
- elif hparams.input_type == 'node':
1852
- if hparams.cv_type.lower() == "holdout":
1853
- model = _NodeRegressorHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
1854
- elif "k" in hparams.cv_type.lower():
1855
- model = _NodeRegressorKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
1856
- else:
1857
- raise NotImplementedError
1858
- return model
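# Sketch of the 0.8.98 workflow that this diff removes, shown for context only.
# 'dataset' is a hypothetical CustomGraphDataset (e.g. from DatasetByCSVPath);
# Model() dispatches on model_type, input_type and cv_type exactly as coded above.
from topologicpy.PyG import PyG

optimizer = PyG.Optimizer(name="Adam", lr=0.001)
hparams = PyG.Hyperparameters(optimizer, model_type="Classifier", cv_type="Holdout",
                              hl_widths=[32], epochs=10, input_type="graph")
model = PyG.Model(hparams, dataset)   # -> _GraphClassifierHoldout under the hood
model = PyG.ModelTrain(model)
model = PyG.ModelTest(model)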
1859
-
1860
- @staticmethod
1861
- def ModelTrain(model):
1862
- """
1863
- Trains the neural network model.
1864
-
1865
- Parameters
1866
- ----------
1867
- model : Model
1868
- The input model.
1869
-
1870
- Returns
1871
- -------
1872
- Model
1873
- The trained model
1874
-
468
+ Return a compact dictionary of the most relevant configuration choices.
1875
469
  """
1876
- if not model:
1877
- return None
1878
- model.train()
1879
- return model
1880
-
1881
- @staticmethod
1882
- def ModelTest(model):
470
+ cfg = self.config
471
+ return {
472
+ "level": cfg.level,
473
+ "task": cfg.task,
474
+ "graph_label_type": cfg.graph_label_type,
475
+ "node_label_type": cfg.node_label_type,
476
+ "edge_label_type": cfg.edge_label_type,
477
+ "cv": cfg.cv,
478
+ "split": cfg.split,
479
+ "k_folds": cfg.k_folds,
480
+ "conv": cfg.conv,
481
+ "hidden_dims": cfg.hidden_dims,
482
+ "activation": cfg.activation,
483
+ "dropout": cfg.dropout,
484
+ "batch_norm": cfg.batch_norm,
485
+ "residual": cfg.residual,
486
+ "pooling": cfg.pooling,
487
+ "epochs": cfg.epochs,
488
+ "batch_size": cfg.batch_size,
489
+ "lr": cfg.lr,
490
+ "weight_decay": cfg.weight_decay,
491
+ "optimizer": cfg.optimizer,
492
+ "gradient_clip_norm": cfg.gradient_clip_norm,
493
+ "early_stopping": cfg.early_stopping,
494
+ "early_stopping_patience": cfg.early_stopping_patience,
495
+ "device": str(self.device),
496
+ "num_graphs": len(self.data_list),
497
+ "num_outputs": int(self._num_outputs),
498
+ }
499
+
500
+ # ----------------------------
501
+ # Convenience: CV visualisation
502
+ # ----------------------------
503
+ def PlotCrossValidationSummary(self,
504
+ cv_report: Optional[Dict[str, Union[float, List[Dict[str, float]]]]] = None,
505
+ metrics: Optional[List[str]] = None,
506
+ show_mean_std: bool = True):
1883
507
  """
1884
- Tests the neural network model.
508
+ Plot a cross-validation summary as grouped bars per fold (Plotly).
1885
509
 
1886
510
  Parameters
1887
511
  ----------
1888
- model : Model
1889
- The input model.
512
+ cv_report : dict, optional
513
+ Output from CrossValidate(). If None, uses self.cv_report.
514
+ metrics : list[str], optional
515
+ Metrics to display. If None, chooses a sensible default based on task:
516
+ - classification: ["accuracy", "f1", "precision", "recall"]
517
+ - regression : ["mae", "rmse", "r2"]
518
+ show_mean_std : bool, optional
519
+ If True, includes mean and ±std reference lines.
1890
520
 
1891
521
  Returns
1892
522
  -------
1893
- Model
1894
- The tested model
1895
-
523
+ plotly.graph_objects.Figure
1896
524
  """
1897
- if not model:
1898
- return None
1899
- model.test()
1900
- return model
1901
-
1902
- @staticmethod
1903
- def ModelSave(model, path, overwrite=False):
1904
- """
1905
- Saves the model.
1906
-
1907
- Parameters
1908
- ----------
1909
- model : Model
1910
- The input model.
1911
- path : str
1912
- The file path at which to save the model.
1913
- overwrite : bool, optional
1914
- If set to True, any existing file will be overwritten. Otherwise, it won't. Default is False.
1915
-
1916
- Returns
1917
- -------
1918
- bool
1919
- True if the model is saved correctly. False otherwise.
525
+ if cv_report is None:
526
+ cv_report = self.cv_report
527
+ if cv_report is None:
528
+ raise ValueError("No cross-validation report found. Run CrossValidate() first or pass cv_report.")
529
+
530
+ fold_metrics = cv_report.get("fold_metrics", [])
531
+ if not fold_metrics:
532
+ raise ValueError("cv_report has no fold_metrics.")
533
+
534
+ # default metrics
535
+ if metrics is None:
536
+ if self.config.task == "regression":
537
+ metrics = ["mae", "rmse", "r2"]
538
+ else:
539
+ metrics = ["accuracy", "f1", "precision", "recall"]
540
+
541
+ folds = [int(fm.get("fold", i)) for i, fm in enumerate(fold_metrics)]
542
+
543
+ fig = go.Figure()
544
+ for met in metrics:
545
+ vals = [float(fm.get(met, 0.0)) for fm in fold_metrics]
546
+ fig.add_trace(go.Bar(name=met, x=folds, y=vals))
547
+
548
+ if show_mean_std:
549
+ mean_k = f"mean_{met}"
550
+ std_k = f"std_{met}"
551
+ if mean_k in cv_report and std_k in cv_report:
552
+ mu = float(cv_report[mean_k])
553
+ sd = float(cv_report[std_k])
554
+ # mean line
555
+ fig.add_trace(go.Scatter(
556
+ x=[min(folds), max(folds)], y=[mu, mu],
557
+ mode="lines", name=f"{met} mean", line=dict(dash="dash")
558
+ ))
559
+ # +/- std (as band using two lines)
560
+ fig.add_trace(go.Scatter(
561
+ x=[min(folds), max(folds)], y=[mu + sd, mu + sd],
562
+ mode="lines", name=f"{met} +std", line=dict(dash="dot")
563
+ ))
564
+ fig.add_trace(go.Scatter(
565
+ x=[min(folds), max(folds)], y=[mu - sd, mu - sd],
566
+ mode="lines", name=f"{met} -std", line=dict(dash="dot")
567
+ ))
568
+
569
+ fig.update_layout(
570
+ barmode="group",
571
+ title="Cross-Validation Summary",
572
+ xaxis_title="Fold",
573
+ yaxis_title="Metric Value"
574
+ )
575
+ return fig
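# Illustrative sketch of the cv_report structure PlotCrossValidationSummary expects,
# inferred from the lookups above (a "fold_metrics" list with a "fold" key plus
# mean_/std_ aggregates). All numbers are made up.
cv_report = {
    "fold_metrics": [
        {"fold": 0, "accuracy": 0.81, "f1": 0.79},
        {"fold": 1, "accuracy": 0.84, "f1": 0.82},
    ],
    "mean_accuracy": 0.825, "std_accuracy": 0.015,
    "mean_f1": 0.805, "std_f1": 0.015,
}
# fig = pyg.PlotCrossValidationSummary(cv_report=cv_report, metrics=["accuracy", "f1"])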
576
+
577
+ # ----------------
578
+ # Dataset loading
579
+ # ----------------
580
+ def _load_csv(self):
581
+ if not isinstance(self.path, str) or (not os.path.exists(self.path)):
582
+ raise ValueError("PyG - Error: path does not exist.")
583
+
584
+ gpath = os.path.join(self.path, "graphs.csv")
585
+ npath = os.path.join(self.path, "nodes.csv")
586
+ epath = os.path.join(self.path, "edges.csv")
587
+
588
+ if not os.path.exists(gpath) or not os.path.exists(npath) or not os.path.exists(epath):
589
+ raise ValueError("PyG - Error: graphs.csv, nodes.csv, edges.csv must exist in the folder.")
590
+
591
+ self.graph_df = pd.read_csv(gpath)
592
+ self.nodes_df = pd.read_csv(npath)
593
+ self.edges_df = pd.read_csv(epath)
594
+
595
+ def _feature_columns(self, df: pd.DataFrame, prefix: str) -> List[str]:
596
+ cols = [c for c in df.columns if c.startswith(prefix + "_")]
597
+ def _key(c):
598
+ parts = c.rsplit("_", 1)
599
+ if len(parts) == 2 and parts[1].isdigit():
600
+ return int(parts[1])
601
+ return 10**9
602
+ return sorted(cols, key=_key)
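# Illustrative example of what _feature_columns returns: columns are filtered by
# prefix and ordered by their numeric suffix, so "feat_10" sorts after "feat_2"
# (plain lexicographic sorting would not do that). The column names are hypothetical.
columns = ["feat_0", "feat_10", "feat_2", "label", "graph_id"]

def _key(c):
    parts = c.rsplit("_", 1)
    return int(parts[1]) if len(parts) == 2 and parts[1].isdigit() else 10**9

print(sorted([c for c in columns if c.startswith("feat_")], key=_key))
# ['feat_0', 'feat_2', 'feat_10']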
1920
603
 
1921
- """
1922
- import os
1923
-
1924
- if model == None:
1925
- print("PyG.ModelSave - Error: The input model parameter is invalid. Returning None.")
1926
- return None
1927
- if path == None:
1928
- print("PyG.ModelSave - Error: The input path parameter is invalid. Returning None.")
1929
- return None
1930
- if not overwrite and os.path.exists(path):
1931
- print("PyG.ModelSave - Error: a file already exists at the specified path and overwrite is set to False. Returning None.")
1932
- return None
1933
- if overwrite and os.path.exists(path):
1934
- os.remove(path)
1935
- # Make sure the file extension is .pt
1936
- ext = path[len(path)-3:len(path)]
1937
- if ext.lower() != ".pt":
1938
- path = path+".pt"
1939
- model.save(path)
1940
- return True
1941
-
1942
604
  @staticmethod
1943
- def ModelData(model):
1944
- """
1945
- Returns the data of the model
1946
-
1947
- Parameters
1948
- ----------
1949
- model : Model
1950
- The input model.
1951
-
1952
- Returns
1953
- -------
1954
- dict
1955
- A dictionary containing the model data. The keys in the dictionary are:
1956
- 'Model Type'
1957
- 'Optimizer'
1958
- 'CV Type'
1959
- 'Split'
1960
- 'K-Folds'
1961
- 'HL Widths'
1962
- 'Conv Layer Type'
1963
- 'Pooling'
1964
- 'Learning Rate'
1965
- 'Batch Size'
1966
- 'Epochs'
1967
- 'Training Accuracy'
1968
- 'Validation Accuracy'
1969
- 'Testing Accuracy'
1970
- 'Training Loss'
1971
- 'Validation Loss'
1972
- 'Testing Loss'
1973
- 'Accuracies' (Classifier and K-Fold only)
1974
- 'Max Accuracy' (Classifier and K-Fold only)
1975
- 'Losses' (Regressor and K-fold only)
1976
- 'Min Loss' (Regressor and K-fold only)
605
+ def _infer_num_classes(values: np.ndarray) -> int:
606
+ uniq = np.unique(values[~pd.isna(values)])
607
+ return int(len(uniq))
608
+
609
+ def _build_data_list(self):
610
+ assert self.graph_df is not None and self.nodes_df is not None and self.edges_df is not None
611
+
612
+ cfg = self.config
613
+ gdf = self.graph_df
614
+ ndf = self.nodes_df
615
+ edf = self.edges_df
616
+
617
+ graph_feat_cols = self._feature_columns(gdf, cfg.graph_features_header)
618
+ node_feat_cols = self._feature_columns(ndf, cfg.node_features_header)
619
+ edge_feat_cols = self._feature_columns(edf, cfg.edge_features_header)
620
+
621
+ if len(node_feat_cols) == 0:
622
+ raise ValueError(
623
+ f"PyG - Error: No node feature columns found. "
624
+ f"Expected columns starting with '{cfg.node_features_header}_'."
625
+ )
626
+
627
+ for gid in gdf[cfg.graph_id_header].unique():
628
+ g_row = gdf[gdf[cfg.graph_id_header] == gid]
629
+ g_nodes = ndf[ndf[cfg.graph_id_header] == gid].sort_values(cfg.node_id_header)
630
+ g_edges = edf[edf[cfg.graph_id_header] == gid]
631
+
632
+ x = torch.tensor(g_nodes[node_feat_cols].values, dtype=torch.float32)
633
+ edge_index = torch.tensor(
634
+ g_edges[[cfg.edge_src_header, cfg.edge_dst_header]].values.T,
635
+ dtype=torch.long
636
+ )
637
+
638
+ data = Data(x=x, edge_index=edge_index)
639
+
640
+ if len(edge_feat_cols) > 0:
641
+ data.edge_attr = torch.tensor(g_edges[edge_feat_cols].values, dtype=torch.float32)
642
+
643
+ # graph-level
644
+ if cfg.level == "graph":
645
+ y_val = g_row[cfg.graph_label_header].values[0]
646
+ if cfg.graph_label_type == "categorical":
647
+ data.y = torch.tensor([int(y_val)], dtype=torch.long)
648
+ else:
649
+ data.y = torch.tensor([float(y_val)], dtype=torch.float32)
1977
650
 
1978
- """
1979
- from topologicpy.Helper import Helper
1980
-
1981
- data = {'Model Type': [model.hparams.model_type],
1982
- 'Optimizer': [model.hparams.optimizer_str],
1983
- 'CV Type': [model.hparams.cv_type],
1984
- 'Split': model.hparams.split,
1985
- 'K-Folds': [model.hparams.k_folds],
1986
- 'HL Widths': model.hparams.hl_widths,
1987
- 'Conv Layer Type': [model.hparams.conv_layer_type],
1988
- 'Pooling': [model.hparams.pooling],
1989
- 'Learning Rate': [model.hparams.lr],
1990
- 'Batch Size': [model.hparams.batch_size],
1991
- 'Epochs': [model.hparams.epochs]
1992
- }
1993
-
1994
- if model.hparams.model_type.lower() == "classifier":
1995
- testing_accuracy_list = [model.testing_accuracy] * model.hparams.epochs
1996
- try:
1997
- testing_loss_list = [model.testing_loss] * model.hparams.epochs
1998
- except:
1999
- testing_loss_list = [0.] * model.hparams.epochs
2000
- metrics_data = {
2001
- 'Training Accuracy': [model.training_accuracy_list],
2002
- 'Validation Accuracy': [model.validation_accuracy_list],
2003
- 'Testing Accuracy' : [testing_accuracy_list],
2004
- 'Training Loss': [model.training_loss_list],
2005
- 'Validation Loss': [model.validation_loss_list],
2006
- 'Testing Loss' : [testing_loss_list]
2007
- }
2008
- if model.hparams.cv_type.lower() == "k-fold":
2009
- accuracy_data = {
2010
- 'Accuracies' : [model.accuracies],
2011
- 'Max Accuracy' : [model.max_accuracy]
2012
- }
2013
- metrics_data.update(accuracy_data)
2014
- data.update(metrics_data)
2015
-
2016
- elif model.hparams.model_type.lower() == "regressor":
2017
- testing_loss_list = [model.testing_loss] * model.hparams.epochs
2018
- metrics_data = {
2019
- 'Training Loss': [model.training_loss_list],
2020
- 'Validation Loss': [model.validation_loss_list],
2021
- 'Testing Loss' : [testing_loss_list]
2022
- }
2023
- if model.hparams.cv_type.lower() == "k-fold":
2024
- loss_data = {
2025
- 'Losses' : [model.losses],
2026
- 'Min Loss' : [model.min_loss]
2027
- }
2028
- metrics_data.update(loss_data)
2029
- data.update(metrics_data)
2030
-
2031
- return data
2032
-
2033
- @staticmethod
2034
- def Show(data,
2035
- labels,
2036
- title="Training/Validation",
2037
- xTitle="Epochs",
2038
- xSpacing=1,
2039
- yTitle="Accuracy and Loss",
2040
- ySpacing=0.1,
2041
- useMarkers=False,
2042
- chartType="Line",
2043
- width=950,
2044
- height=500,
2045
- backgroundColor='rgba(0,0,0,0)',
2046
- gridColor='lightgray',
2047
- marginLeft=0,
2048
- marginRight=0,
2049
- marginTop=40,
2050
- marginBottom=0,
2051
- renderer = "notebook"):
2052
- """
2053
- Shows the data in a plotly graph.
651
+ if len(graph_feat_cols) > 0:
652
+ data.u = torch.tensor(g_row[graph_feat_cols].values[0], dtype=torch.float32)
2054
653
 
2055
- Parameters
2056
- ----------
2057
- data : list
2058
- The data to display.
2059
- labels : list
2060
- The labels to use for the data.
2061
- width : int , optional
2062
- The desired width of the figure. Default is 950.
2063
- height : int , optional
2064
- The desired height of the figure. Default is 500.
2065
- title : str , optional
2066
- The chart title. Default is "Training/Validation".
2067
- xTitle : str , optional
2068
- The X-axis title. Default is "Epochs".
2069
- xSpacing : float , optional
2070
- The X-axis spacing. Default is 1.0.
2071
- yTitle : str , optional
2072
- The Y-axis title. Default is "Accuracy and Loss".
2073
- ySpacing : float , optional
2074
- The Y-axis spacing. Default is 0.1.
2075
- useMarkers : bool , optional
2076
- If set to True, markers will be displayed. Default is False.
2077
- chartType : str , optional
2078
- The desired type of chart. The options are "Line", "Bar", or "Scatter". It is case insensitive. Default is "Line".
2079
- backgroundColor : str , optional
2080
- The desired background color. This can be any plotly color string and may be specified as:
2081
- - A hex string (e.g. '#ff0000')
2082
- - An rgb/rgba string (e.g. 'rgb(255,0,0)')
2083
- - An hsl/hsla string (e.g. 'hsl(0,100%,50%)')
2084
- - An hsv/hsva string (e.g. 'hsv(0,100%,100%)')
2085
- - A named CSS color.
2086
- The default is 'rgba(0,0,0,0)' (transparent).
2087
- gridColor : str , optional
2088
- The desired grid color. This can be any plotly color string and may be specified as:
2089
- - A hex string (e.g. '#ff0000')
2090
- - An rgb/rgba string (e.g. 'rgb(255,0,0)')
2091
- - An hsl/hsla string (e.g. 'hsl(0,100%,50%)')
2092
- - An hsv/hsva string (e.g. 'hsv(0,100%,100%)')
2093
- - A named CSS color.
2094
- The default is 'lightgray'.
2095
- marginLeft : int , optional
2096
- The desired left margin in pixels. Default is 0.
2097
- marginRight : int , optional
2098
- The desired right margin in pixels. Default is 0.
2099
- marginTop : int , optional
2100
- The desired top margin in pixels. Default is 40.
2101
- marginBottom : int , optional
2102
- The desired bottom margin in pixels. Default is 0.
2103
- renderer : str , optional
2104
- The desired plotly renderer. Default is "notebook".
654
+ # node-level
655
+ if cfg.level == "node":
656
+ y_vals = g_nodes[cfg.node_label_header].values
657
+ if cfg.node_label_type == "categorical":
658
+ data.y = torch.tensor(y_vals.astype(int), dtype=torch.long)
659
+ else:
660
+ data.y = torch.tensor(y_vals.astype(float), dtype=torch.float32)
661
+ data.train_mask, data.val_mask, data.test_mask = self._get_or_make_node_masks(g_nodes)
662
+
663
+ # edge-level
664
+ if cfg.level == "edge":
665
+ y_vals = g_edges[cfg.edge_label_header].values
666
+ if cfg.edge_label_type == "categorical":
667
+ data.edge_y = torch.tensor(y_vals.astype(int), dtype=torch.long)
668
+ else:
669
+ data.edge_y = torch.tensor(y_vals.astype(float), dtype=torch.float32)
670
+ data.edge_train_mask, data.edge_val_mask, data.edge_test_mask = self._get_or_make_edge_masks(g_edges)
671
+
672
+ self.data_list.append(data)
673
+
674
+ # output dimensionality
675
+ if cfg.level == "graph":
676
+ self._num_outputs = self._infer_num_classes(gdf[cfg.graph_label_header].values) if cfg.graph_label_type == "categorical" else 1
677
+ elif cfg.level == "node":
678
+ self._num_outputs = self._infer_num_classes(ndf[cfg.node_label_header].values) if cfg.node_label_type == "categorical" else 1
679
+ elif cfg.level == "edge":
680
+ self._num_outputs = self._infer_num_classes(edf[cfg.edge_label_header].values) if cfg.edge_label_type == "categorical" else 1
681
+ elif cfg.level == "link":
682
+ self._num_outputs = 1
683
+ else:
684
+ raise ValueError("Unsupported level.")
685
+
686
+ def _get_or_make_node_masks(self, g_nodes: pd.DataFrame):
687
+ cfg = self.config
688
+ cols = g_nodes.columns
689
+
690
+ if (cfg.node_train_mask_header in cols) and (cfg.node_val_mask_header in cols) and (cfg.node_test_mask_header in cols):
691
+ train_mask = torch.tensor(g_nodes[cfg.node_train_mask_header].astype(bool).values, dtype=torch.bool)
692
+ val_mask = torch.tensor(g_nodes[cfg.node_val_mask_header].astype(bool).values, dtype=torch.bool)
693
+ test_mask = torch.tensor(g_nodes[cfg.node_test_mask_header].astype(bool).values, dtype=torch.bool)
694
+ return train_mask, val_mask, test_mask
695
+
696
+ n = len(g_nodes)
697
+ idx = list(range(n))
698
+ if cfg.shuffle:
699
+ random.Random(cfg.random_state).shuffle(idx)
700
+
701
+ n_train = max(1, int(cfg.split[0] * n))
702
+ n_val = max(1, int(cfg.split[1] * n))
703
+ n_test = max(0, n - n_train - n_val)
704
+
705
+ train_idx = set(idx[:n_train])
706
+ val_idx = set(idx[n_train:n_train + n_val])
707
+ test_idx = set(idx[n_train + n_val:n_train + n_val + n_test])
708
+
709
+ train_mask = torch.tensor([i in train_idx for i in range(n)], dtype=torch.bool)
710
+ val_mask = torch.tensor([i in val_idx for i in range(n)], dtype=torch.bool)
711
+ test_mask = torch.tensor([i in test_idx for i in range(n)], dtype=torch.bool)
712
+ return train_mask, val_mask, test_mask
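# Worked example of the mask arithmetic above, assuming split=(0.8, 0.1, 0.1):
# with 12 nodes the counts come out as 9 train / 1 val / 2 test, and the
# max(1, ...) guards keep train and val non-empty even for tiny graphs.
n = 12
split = (0.8, 0.1, 0.1)
n_train = max(1, int(split[0] * n))     # 9
n_val = max(1, int(split[1] * n))       # 1
n_test = max(0, n - n_train - n_val)    # 2
assert (n_train, n_val, n_test) == (9, 1, 2)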
713
+
714
+ def _get_or_make_edge_masks(self, g_edges: pd.DataFrame):
715
+ cfg = self.config
716
+ cols = g_edges.columns
717
+
718
+ if (cfg.edge_train_mask_header in cols) and (cfg.edge_val_mask_header in cols) and (cfg.edge_test_mask_header in cols):
719
+ train_mask = torch.tensor(g_edges[cfg.edge_train_mask_header].astype(bool).values, dtype=torch.bool)
720
+ val_mask = torch.tensor(g_edges[cfg.edge_val_mask_header].astype(bool).values, dtype=torch.bool)
721
+ test_mask = torch.tensor(g_edges[cfg.edge_test_mask_header].astype(bool).values, dtype=torch.bool)
722
+ return train_mask, val_mask, test_mask
723
+
724
+ n = len(g_edges)
725
+ idx = list(range(n))
726
+ if cfg.shuffle:
727
+ random.Random(cfg.random_state).shuffle(idx)
728
+
729
+ n_train = max(1, int(cfg.split[0] * n))
730
+ n_val = max(1, int(cfg.split[1] * n))
731
+ n_test = max(0, n - n_train - n_val)
732
+
733
+ train_idx = set(idx[:n_train])
734
+ val_idx = set(idx[n_train:n_train + n_val])
735
+ test_idx = set(idx[n_train + n_val:n_train + n_val + n_test])
736
+
737
+ train_mask = torch.tensor([i in train_idx for i in range(n)], dtype=torch.bool)
738
+ val_mask = torch.tensor([i in val_idx for i in range(n)], dtype=torch.bool)
739
+ test_mask = torch.tensor([i in test_idx for i in range(n)], dtype=torch.bool)
740
+ return train_mask, val_mask, test_mask
741
+
742
+ # ----------------------------
743
+ # Holdout split (graph-level)
744
+ # ----------------------------
745
+ def _split_holdout(self):
746
+ cfg = self.config
747
+ if cfg.level in ["node", "edge", "link"]:
748
+ self.train_set = self.data_list
749
+ self.val_set = self.data_list
750
+ self.test_set = self.data_list
751
+ return
752
+
753
+ n = len(self.data_list)
754
+ idx = list(range(n))
755
+ if cfg.shuffle:
756
+ random.Random(cfg.random_state).shuffle(idx)
757
+
758
+ n_train = max(1, int(cfg.split[0] * n))
759
+ n_val = max(1, int(cfg.split[1] * n))
760
+ n_test = max(0, n - n_train - n_val)
761
+
762
+ train_idx = idx[:n_train]
763
+ val_idx = idx[n_train:n_train + n_val]
764
+ test_idx = idx[n_train + n_val:n_train + n_val + n_test]
765
+
766
+ self.train_set = [self.data_list[i] for i in train_idx]
767
+ self.val_set = [self.data_list[i] for i in val_idx]
768
+ self.test_set = [self.data_list[i] for i in test_idx]
769
+
770
+ # --------------
771
+ # Model building
772
+ # --------------
773
+ def _build_model(self):
774
+ cfg = self.config
775
+ in_dim = int(self.data_list[0].x.shape[1])
776
+
777
+ encoder = _GNNBackbone(
778
+ in_dim=in_dim,
779
+ hidden_dims=cfg.hidden_dims,
780
+ conv=cfg.conv,
781
+ activation=cfg.activation,
782
+ dropout=cfg.dropout,
783
+ batch_norm=cfg.batch_norm,
784
+ residual=cfg.residual
785
+ )
786
+
787
+ if cfg.level == "graph":
788
+ head = _GraphHead(encoder.out_dim, self._num_outputs, pooling=cfg.pooling, dropout=cfg.dropout)
789
+ self.model = nn.ModuleDict({"encoder": encoder, "head": head}).to(self.device)
790
+ elif cfg.level == "node":
791
+ head = _NodeHead(encoder.out_dim, self._num_outputs, dropout=cfg.dropout)
792
+ self.model = nn.ModuleDict({"encoder": encoder, "head": head}).to(self.device)
793
+ elif cfg.level == "edge":
794
+ head = _EdgeHead(encoder.out_dim, self._num_outputs, dropout=cfg.dropout)
795
+ self.model = nn.ModuleDict({"encoder": encoder, "head": head}).to(self.device)
796
+ elif cfg.level == "link":
797
+ predictor = _LinkPredictor(encoder.out_dim, hidden=max(32, encoder.out_dim), dropout=cfg.dropout)
798
+ self.model = nn.ModuleDict({"encoder": encoder, "predictor": predictor}).to(self.device)
799
+ else:
800
+ raise ValueError("Unsupported level.")
2105
801
 
2106
- Returns
2107
- -------
2108
- None.
802
+ if cfg.optimizer == "adamw":
803
+ self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
804
+ else:
805
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
2109
806
 
807
+ if cfg.level == "link":
808
+ self.criterion = nn.BCEWithLogitsLoss()
809
+ else:
810
+ if cfg.task == "regression":
811
+ self.criterion = nn.MSELoss()
812
+ else:
813
+ self.criterion = nn.CrossEntropyLoss()
814
+
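The ModuleDict built above keeps the shared _GNNBackbone encoder separate from the task-specific head, so every level runs the same two-stage forward pass. A sketch of that pass for the graph-level case, mirroring the training loops further down; the tensors and in_dim value are placeholders, and model stands in for the self.model attribute assembled here.

    import torch

    in_dim = 16                                          # placeholder feature width
    x = torch.randn(4, in_dim)                           # node features
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])    # COO connectivity
    batch = torch.zeros(4, dtype=torch.long)             # all four nodes belong to graph 0

    node_emb = model["encoder"](x, edge_index)           # shared GNN backbone
    logits = model["head"](node_emb, batch)              # pooled, per-graph prediction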
815
+ def _apply_gradients(self):
816
+ cfg = self.config
817
+ if cfg.gradient_clip_norm is not None:
818
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), float(cfg.gradient_clip_norm))
819
+ self.optimizer.step()
820
+
821
+ # -----------------------
822
+ # Training / evaluation
823
+ # -----------------------
824
+ def Train(self, epochs: Optional[int] = None, batch_size: Optional[int] = None) -> Dict[str, List[float]]:
2110
825
  """
2111
- from topologicpy.Plotly import Plotly
2112
-
2113
- dataFrame = Plotly.DataByDGL(data, labels)
2114
- fig = Plotly.FigureByDataFrame(dataFrame,
2115
- labels=labels,
2116
- title=title,
2117
- xTitle=xTitle,
2118
- xSpacing=xSpacing,
2119
- yTitle=yTitle,
2120
- ySpacing=ySpacing,
2121
- useMarkers=useMarkers,
2122
- chartType=chartType,
2123
- width=width,
2124
- height=height,
2125
- backgroundColor=backgroundColor,
2126
- gridColor=gridColor,
2127
- marginRight=marginRight,
2128
- marginLeft=marginLeft,
2129
- marginTop=marginTop,
2130
- marginBottom=marginBottom
2131
- )
2132
- Plotly.Show(fig, renderer=renderer)
2133
-
2134
- @staticmethod
2135
- def ModelLoad(path, model):
2136
- """
2137
- Returns the model found at the input file path.
2138
-
2139
- Parameters
2140
- ----------
2141
- path : str
2142
- File path for the saved classifier.
2143
- model : torch.nn.module
2144
- Initialized instance of model
826
+ Train a model using holdout splitting for graph-level tasks, in-graph masks for node/edge tasks, and per-epoch random link splits for link prediction.
2145
827
 
2146
- Returns
2147
- -------
2148
- PyG Classifier
2149
- The classifier.
2150
-
2151
- """
2152
- if not path:
2153
- return None
2154
-
2155
- model.load(path)
2156
- return model
2157
-
2158
- @staticmethod
2159
- def ConfusionMatrix(actual, predicted, normalize=False):
828
+ If you want k-fold cross-validation for graph-level tasks, call CrossValidate().
2160
829
  """
2161
- Returns the confusion matrix for the input actual and predicted labels. This is to be used with classification tasks only not regression.
830
+ cfg = self.config
831
+ if epochs is not None:
832
+ cfg.epochs = int(epochs)
833
+ if batch_size is not None:
834
+ cfg.batch_size = int(batch_size)
835
+
836
+ self.history = {"train_loss": [], "val_loss": []}
837
+
838
+ if cfg.level == "graph":
839
+ train_loader = DataLoader(self.train_set, batch_size=cfg.batch_size, shuffle=True)
840
+ val_loader = DataLoader(self.val_set, batch_size=cfg.batch_size, shuffle=False)
841
+
842
+ best_val = float("inf")
843
+ patience = 0
844
+
845
+ for _ in range(cfg.epochs):
846
+ tr = self._train_epoch_graph(train_loader)
847
+ va = self._eval_epoch_graph(val_loader)
848
+ self.history["train_loss"].append(tr)
849
+ self.history["val_loss"].append(va)
850
+
851
+ if cfg.early_stopping:
852
+ if va < best_val - 1e-9:
853
+ best_val = va
854
+ patience = 0
855
+ else:
856
+ patience += 1
857
+ if patience >= int(cfg.early_stopping_patience):
858
+ break
859
+
860
+ elif cfg.level == "node":
861
+ train_loader = DataLoader(self.data_list, batch_size=1, shuffle=True)
862
+ val_loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
863
+ for _ in range(cfg.epochs):
864
+ tr = self._train_epoch_node(train_loader)
865
+ va = self._eval_epoch_node(val_loader)
866
+ self.history["train_loss"].append(tr)
867
+ self.history["val_loss"].append(va)
868
+
869
+ elif cfg.level == "edge":
870
+ train_loader = DataLoader(self.data_list, batch_size=1, shuffle=True)
871
+ val_loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
872
+ for _ in range(cfg.epochs):
873
+ tr = self._train_epoch_edge(train_loader)
874
+ va = self._eval_epoch_edge(val_loader)
875
+ self.history["train_loss"].append(tr)
876
+ self.history["val_loss"].append(va)
877
+
878
+ elif cfg.level == "link":
879
+ train_loader = DataLoader(self.data_list, batch_size=1, shuffle=True)
880
+ val_loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
881
+ for _ in range(cfg.epochs):
882
+ tr = self._train_epoch_link(train_loader)
883
+ va = self._eval_epoch_link(val_loader)
884
+ self.history["train_loss"].append(tr)
885
+ self.history["val_loss"].append(va)
2162
886
 
2163
- Parameters
2164
- ----------
2165
- actual : list
2166
- The input list of actual labels.
2167
- predicted : list
2168
- The input list of predicts labels.
2169
- normalized : bool , optional
2170
- If set to True, the returned data will be normalized (proportion of 1). Otherwise, actual numbers are returned. Default is False.
2171
-
2172
- Returns
2173
- -------
2174
- list
2175
- The created confusion matrix.
2176
-
2177
- """
2178
-
2179
- try:
2180
- from sklearn import metrics
2181
- from sklearn.metrics import accuracy_score
2182
- except:
2183
- print("PyG - Installing required scikit-learn (sklearn) library.")
2184
- try:
2185
- os.system("pip install scikit-learn")
2186
- except:
2187
- os.system("pip install scikit-learn --user")
2188
- try:
2189
- from sklearn import metrics
2190
- from sklearn.metrics import accuracy_score
2191
- print("PyG - scikit-learn (sklearn) library installed correctly.")
2192
- except:
2193
- warnings.warn("PyG - Error: Could not import scikit-learn (sklearn). Please try to install scikit-learn manually. Returning None.")
2194
- return None
2195
-
2196
- if not isinstance(actual, list):
2197
- print("PyG.ConfusionMatrix - ERROR: The actual input is not a list. Returning None")
2198
- return None
2199
- if not isinstance(predicted, list):
2200
- print("PyG.ConfusionMatrix - ERROR: The predicted input is not a list. Returning None")
2201
- return None
2202
- if len(actual) != len(predicted):
2203
- print("PyG.ConfusionMatrix - ERROR: The two input lists do not have the same length. Returning None")
2204
- return None
2205
- if normalize:
2206
- cm = np.transpose(metrics.confusion_matrix(y_true=actual, y_pred=predicted, normalize="true"))
2207
887
  else:
2208
- cm = np.transpose(metrics.confusion_matrix(y_true=actual, y_pred=predicted))
2209
- return cm
888
+ raise ValueError("Unsupported level.")
2210
889
 
2211
- @staticmethod
2212
- def ModelPredict(model, dataset, nodeATTRKey="feat"):
2213
- """
2214
- Predicts the value of the input dataset.
2215
-
2216
- Parameters
2217
- ----------
2218
- dataset : PyGDataset
2219
- The input PyG dataset.
2220
- model : Model
2221
- The input trained model.
2222
- nodeATTRKey : str , optional
2223
- The key used for node attributes. Default is "feat".
2224
-
2225
- Returns
2226
- -------
2227
- list
2228
- The list of predictions
2229
- """
2230
- try:
2231
- model = model.model #The inoput model might be our wrapper model. In that case, get its model attribute to do the prediciton.
2232
- except:
2233
- pass
2234
- values = []
2235
- dataloader = DataLoader(dataset, batch_size=1, drop_last=False)
2236
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2237
- model.eval()
2238
- for data in tqdm(dataloader, desc='Predicting', leave=False):
2239
- data = data.to(device)
2240
- pred = model(data)
2241
- values.extend(list(np.round(pred.detach().cpu().numpy().flatten(), 3)))
2242
- return values
890
+ return self.history
2243
891
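A hedged usage sketch of the holdout workflow; trainer stands for an already-constructed instance of this class (its constructor and configuration are defined elsewhere in this module and are not repeated here).

    history = trainer.Train(epochs=50, batch_size=32)   # overrides config.epochs / config.batch_size
    print(history["train_loss"][-1], history["val_loss"][-1])

    val_metrics = trainer.Validate()    # e.g. {"val_accuracy": ..., "val_f1": ...} for classification
    test_metrics = trainer.Test()       # e.g. {"test_mae": ..., "test_rmse": ...} for regression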
 
2244
- @staticmethod
2245
- def ModelClassify(model, dataset, nodeATTRKey="feat"):
892
+ def CrossValidate(self,
893
+ k_folds: Optional[int] = None,
894
+ epochs: Optional[int] = None,
895
+ batch_size: Optional[int] = None) -> Dict[str, Union[float, List[Dict[str, float]]]]:
2246
896
  """
2247
- Predicts the classification the labels of the input dataset.
897
+ Perform K-Fold cross-validation (graph-level only).
2248
898
 
2249
899
  Parameters
2250
900
  ----------
2251
- dataset : PyGDataset
2252
- The input PyG dataset.
2253
- model : Model
2254
- The input trained model.
2255
- nodeATTRKey : str , optional
2256
- The key used for node attributes. Default is "feat".
901
+ k_folds : int, optional
902
+ Number of folds. Defaults to config.k_folds.
903
+ epochs : int, optional
904
+ Training epochs per fold. Defaults to config.epochs.
905
+ batch_size : int, optional
906
+ Batch size. Defaults to config.batch_size.
2257
907
 
2258
908
  Returns
2259
909
  -------
2260
910
  dict
2261
- Dictionary containing labels and probabilities. The included keys and values are:
2262
- - "predictions" (list): the list of predicted labels
2263
- - "probabilities" (list): the list of probabilities that the label is one of the categories.
911
+ {
912
+ "fold_metrics": [ {metric: value, ...}, ... ],
913
+ "mean_<metric>": value,
914
+ "std_<metric>": value
915
+ }
2264
916
 
917
+ Notes
918
+ -----
919
+ - Supported only for graph-level tasks (node/edge tasks typically use per-graph masks).
920
+ - If labels are categorical, stratified splits can be enabled (config.k_stratify).
2265
921
  """
2266
- try:
2267
- model = model.model #The inoput model might be our wrapper model. In that case, get its model attribute to do the prediciton.
2268
- except:
2269
- pass
2270
- labels = []
2271
- probabilities = []
2272
- dataloader = DataLoader(dataset, batch_size=1, drop_last=False)
2273
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2274
- for data in tqdm(dataloader, desc='Classifying', leave=False):
2275
- data = data.to(device)
2276
- pred = model(data)
2277
- labels.extend(pred.argmax(1).tolist())
2278
- probability = (torch.nn.functional.softmax(pred, dim=1).tolist())
2279
- probability = probability[0]
2280
- temp_probability = []
2281
- for p in probability:
2282
- temp_probability.append(round(p, 3))
2283
- probabilities.extend(temp_probability)
2284
- return {"predictions":labels, "probabilities":probabilities}
2285
-
2286
- @staticmethod
2287
- def Accuracy(actual, predicted, mantissa: int = 6):
2288
- """
2289
- Computes the accuracy of the input predictions based on the input labels. This is to be used only with classification not with regression.
2290
-
2291
- Parameters
2292
- ----------
2293
- actual : list
2294
- The input list of actual values.
2295
- predicted : list
2296
- The input list of predicted values.
2297
- mantissa : int , optional
2298
- The number of decimal places to round the result to. Default is 6.
2299
-
2300
- Returns
2301
- -------
2302
- dict
2303
- A dictionary returning the accuracy information. This contains the following keys and values:
2304
- - "accuracy" (float): The number of correct predictions divided by the length of the list.
2305
- - "correct" (int): The number of correct predictions
2306
- - "mask" (list): A boolean mask for correct vs. wrong predictions which can be used to filter the list of predictions
2307
- - "size" (int): The size of the predictions list
2308
- - "wrong" (int): The number of wrong predictions
922
+ cfg = self.config
923
+ if cfg.level != "graph":
924
+ raise ValueError("CrossValidate is supported for graph-level tasks only.")
925
+
926
+ if k_folds is None:
927
+ k_folds = int(cfg.k_folds)
928
+ if k_folds < 2:
929
+ raise ValueError("k_folds must be >= 2.")
930
+ if epochs is not None:
931
+ cfg.epochs = int(epochs)
932
+ if batch_size is not None:
933
+ cfg.batch_size = int(batch_size)
934
+
935
+ n = len(self.data_list)
936
+ indices = np.arange(n)
937
+
938
+ # Stratification labels (optional)
939
+ y = None
940
+ if cfg.k_stratify and cfg.task == "classification" and cfg.graph_label_type == "categorical":
941
+ y = np.array([int(d.y.item()) for d in self.data_list], dtype=int)
942
+
943
+ rng = np.random.RandomState(cfg.random_state)
944
+ if cfg.k_shuffle:
945
+ rng.shuffle(indices)
946
+
947
+ # Build folds
948
+ if y is None:
949
+ folds = np.array_split(indices, k_folds)
950
+ else:
951
+ folds = [np.array([], dtype=int) for _ in range(k_folds)]
952
+ classes = np.unique(y)
953
+ for c in classes:
954
+ cls_idx = indices[y[indices] == c]
955
+ cls_chunks = np.array_split(cls_idx, k_folds)
956
+ for fi in range(k_folds):
957
+ folds[fi] = np.concatenate([folds[fi], cls_chunks[fi]])
958
+ folds = [rng.permutation(f) for f in folds]
959
+
960
+ fold_metrics: List[Dict[str, float]] = []
961
+ base_config = copy.deepcopy(cfg)
962
+
963
+ for fi in range(k_folds):
964
+ test_idx = folds[fi]
965
+ train_idx = np.concatenate([folds[j] for j in range(k_folds) if j != fi])
966
+
967
+ train_set = [self.data_list[i] for i in train_idx.tolist()]
968
+ test_set = [self.data_list[i] for i in test_idx.tolist()]
969
+
970
+ # Fresh model per fold
971
+ self.config = copy.deepcopy(base_config)
972
+ self._build_model()
973
+ self.history = {"train_loss": [], "val_loss": []}
974
+
975
+ train_loader = DataLoader(train_set, batch_size=self.config.batch_size, shuffle=True)
976
+ test_loader = DataLoader(test_set, batch_size=self.config.batch_size, shuffle=False)
977
+
978
+ best_loss = float("inf")
979
+ patience = 0
980
+
981
+ for _ in range(self.config.epochs):
982
+ tr_loss = self._train_epoch_graph(train_loader)
983
+ te_loss = self._eval_epoch_graph(test_loader)
984
+ self.history["train_loss"].append(tr_loss)
985
+ self.history["val_loss"].append(te_loss)
986
+
987
+ if self.config.early_stopping:
988
+ if te_loss < best_loss - 1e-9:
989
+ best_loss = te_loss
990
+ patience = 0
991
+ else:
992
+ patience += 1
993
+ if patience >= int(self.config.early_stopping_patience):
994
+ break
995
+
996
+ # Metrics (unprefixed) for the fold
997
+ metrics = self._metrics_graph(test_loader, prefix="")
998
+ metrics["fold"] = float(fi)
999
+ fold_metrics.append(metrics)
1000
+
1001
+ # Restore original config and rebuild model
1002
+ self.config = copy.deepcopy(base_config)
1003
+ self._build_model()
1004
+
1005
+ # Aggregate
1006
+ summary: Dict[str, Union[float, List[Dict[str, float]]]] = {"fold_metrics": fold_metrics}
1007
+ metric_keys = [k for k in fold_metrics[0].keys()] if fold_metrics else []
1008
+ metric_keys = [k for k in metric_keys if k != "fold"]
1009
+
1010
+ for k in metric_keys:
1011
+ vals = np.array([fm[k] for fm in fold_metrics], dtype=float)
1012
+ summary[f"mean_{k}"] = float(np.mean(vals))
1013
+ summary[f"std_{k}"] = float(np.std(vals))
1014
+
1015
+ self.cv_report = summary
1016
+ return summary
1017
+
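A sketch of how the cross-validation report can be consumed, again assuming an already-built trainer instance; the metric names in the summary depend on the task (accuracy/precision/recall/f1 for classification, mae/rmse/r2 for regression).

    report = trainer.CrossValidate(k_folds=5, epochs=30)

    for fm in report["fold_metrics"]:
        print(int(fm["fold"]), {k: v for k, v in fm.items() if k != "fold"})

    # Aggregates are exposed as mean_<metric> and std_<metric>.
    print(report.get("mean_accuracy"), report.get("std_accuracy"))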
1018
+ def Validate(self) -> Dict[str, float]:
1019
+ cfg = self.config
1020
+ if cfg.level == "graph":
1021
+ loader = DataLoader(self.val_set, batch_size=cfg.batch_size, shuffle=False)
1022
+ return self._metrics_graph(loader, prefix="val_")
1023
+ if cfg.level == "node":
1024
+ loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
1025
+ return self._metrics_node(loader, split="val")
1026
+ if cfg.level == "edge":
1027
+ loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
1028
+ return self._metrics_edge(loader, split="val")
1029
+ if cfg.level == "link":
1030
+ loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
1031
+ return self._metrics_link(loader, split="val")
1032
+ raise ValueError("Unsupported level.")
1033
+
1034
+ def Test(self) -> Dict[str, float]:
1035
+ cfg = self.config
1036
+ if cfg.level == "graph":
1037
+ loader = DataLoader(self.test_set, batch_size=cfg.batch_size, shuffle=False)
1038
+ return self._metrics_graph(loader, prefix="test_")
1039
+ if cfg.level == "node":
1040
+ loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
1041
+ return self._metrics_node(loader, split="test")
1042
+ if cfg.level == "edge":
1043
+ loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
1044
+ return self._metrics_edge(loader, split="test")
1045
+ if cfg.level == "link":
1046
+ loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
1047
+ return self._metrics_link(loader, split="test")
1048
+ raise ValueError("Unsupported level.")
1049
+
1050
+ # --------
1051
+ # Epochs
1052
+ # --------
1053
+ def _loss_from_logits(self, logits, y, task: TaskKind):
1054
+ if task == "regression":
1055
+ pred = logits.squeeze(-1)
1056
+ return self.criterion(pred.float(), y.float())
1057
+ return self.criterion(logits, y.long())
1058
+
1059
+ def _train_epoch_graph(self, loader):
1060
+ self.model.train()
1061
+ losses = []
1062
+ for batch in loader:
1063
+ batch = batch.to(self.device)
1064
+ self.optimizer.zero_grad()
1065
+ node_emb = self.model["encoder"](batch.x, batch.edge_index)
1066
+ logits = self.model["head"](node_emb, batch.batch)
1067
+ loss = self._loss_from_logits(logits, batch.y, self.config.task)
1068
+ loss.backward()
1069
+ self._apply_gradients()
1070
+ losses.append(float(loss.detach().cpu()))
1071
+ return float(np.mean(losses)) if losses else 0.0
1072
+
1073
+ @torch.no_grad()
1074
+ def _eval_epoch_graph(self, loader):
1075
+ self.model.eval()
1076
+ losses = []
1077
+ for batch in loader:
1078
+ batch = batch.to(self.device)
1079
+ node_emb = self.model["encoder"](batch.x, batch.edge_index)
1080
+ logits = self.model["head"](node_emb, batch.batch)
1081
+ loss = self._loss_from_logits(logits, batch.y, self.config.task)
1082
+ losses.append(float(loss.detach().cpu()))
1083
+ return float(np.mean(losses)) if losses else 0.0
1084
+
1085
+ def _train_epoch_node(self, loader):
1086
+ self.model.train()
1087
+ losses = []
1088
+ for data in loader:
1089
+ data = data.to(self.device)
1090
+ self.optimizer.zero_grad()
1091
+ node_emb = self.model["encoder"](data.x, data.edge_index)
1092
+ logits = self.model["head"](node_emb)
1093
+ mask = data.train_mask
1094
+ loss = self._loss_from_logits(logits[mask], data.y[mask], self.config.task)
1095
+ loss.backward()
1096
+ self._apply_gradients()
1097
+ losses.append(float(loss.detach().cpu()))
1098
+ return float(np.mean(losses)) if losses else 0.0
1099
+
1100
+ @torch.no_grad()
1101
+ def _eval_epoch_node(self, loader):
1102
+ self.model.eval()
1103
+ losses = []
1104
+ for data in loader:
1105
+ data = data.to(self.device)
1106
+ node_emb = self.model["encoder"](data.x, data.edge_index)
1107
+ logits = self.model["head"](node_emb)
1108
+ mask = data.val_mask
1109
+ loss = self._loss_from_logits(logits[mask], data.y[mask], self.config.task)
1110
+ losses.append(float(loss.detach().cpu()))
1111
+ return float(np.mean(losses)) if losses else 0.0
1112
+
1113
+ def _train_epoch_edge(self, loader):
1114
+ self.model.train()
1115
+ losses = []
1116
+ for data in loader:
1117
+ data = data.to(self.device)
1118
+ self.optimizer.zero_grad()
1119
+ node_emb = self.model["encoder"](data.x, data.edge_index)
1120
+ logits = self.model["head"](node_emb, data.edge_index)
1121
+ mask = data.edge_train_mask
1122
+ loss = self._loss_from_logits(logits[mask], data.edge_y[mask], self.config.task)
1123
+ loss.backward()
1124
+ self._apply_gradients()
1125
+ losses.append(float(loss.detach().cpu()))
1126
+ return float(np.mean(losses)) if losses else 0.0
1127
+
1128
+ @torch.no_grad()
1129
+ def _eval_epoch_edge(self, loader):
1130
+ self.model.eval()
1131
+ losses = []
1132
+ for data in loader:
1133
+ data = data.to(self.device)
1134
+ node_emb = self.model["encoder"](data.x, data.edge_index)
1135
+ logits = self.model["head"](node_emb, data.edge_index)
1136
+ mask = data.edge_val_mask
1137
+ loss = self._loss_from_logits(logits[mask], data.edge_y[mask], self.config.task)
1138
+ losses.append(float(loss.detach().cpu()))
1139
+ return float(np.mean(losses)) if losses else 0.0
1140
+
1141
+ def _train_epoch_link(self, loader):
1142
+ self.model.train()
1143
+ losses = []
1144
+ split = RandomLinkSplit(
1145
+ num_val=self.config.link_val_ratio,
1146
+ num_test=self.config.link_test_ratio,
1147
+ is_undirected=self.config.link_is_undirected,
1148
+ add_negative_train_samples=True,
1149
+ neg_sampling_ratio=1.0
1150
+ )
1151
+ for data in loader:
1152
+ train_data, _, _ = split(data)
1153
+ train_data = train_data.to(self.device)
1154
+ self.optimizer.zero_grad()
1155
+ node_emb = self.model["encoder"](train_data.x, train_data.edge_index)
1156
+ logits = self.model["predictor"](node_emb, train_data.edge_label_index)
1157
+ loss = self.criterion(logits, train_data.edge_label.float())
1158
+ loss.backward()
1159
+ self._apply_gradients()
1160
+ losses.append(float(loss.detach().cpu()))
1161
+ return float(np.mean(losses)) if losses else 0.0
1162
+
1163
+ @torch.no_grad()
1164
+ def _eval_epoch_link(self, loader):
1165
+ self.model.eval()
1166
+ losses = []
1167
+ split = RandomLinkSplit(
1168
+ num_val=self.config.link_val_ratio,
1169
+ num_test=self.config.link_test_ratio,
1170
+ is_undirected=self.config.link_is_undirected,
1171
+ add_negative_train_samples=True,
1172
+ neg_sampling_ratio=1.0
1173
+ )
1174
+ for data in loader:
1175
+ _, val_data, _ = split(data)
1176
+ val_data = val_data.to(self.device)
1177
+ node_emb = self.model["encoder"](val_data.x, val_data.edge_index)
1178
+ logits = self.model["predictor"](node_emb, val_data.edge_label_index)
1179
+ loss = self.criterion(logits, val_data.edge_label.float())
1180
+ losses.append(float(loss.detach().cpu()))
1181
+ return float(np.mean(losses)) if losses else 0.0
1182
+
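The two link-level loops above rebuild the RandomLinkSplit transform on every call, so negative edge pairs are redrawn each epoch. A standalone sketch of what that transform (from torch_geometric.transforms) returns; the toy graph is purely illustrative.

    import torch
    from torch_geometric.data import Data
    from torch_geometric.transforms import RandomLinkSplit

    data = Data(x=torch.randn(4, 8),
                edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]))

    transform = RandomLinkSplit(num_val=0.25, num_test=0.25, is_undirected=False,
                                add_negative_train_samples=True, neg_sampling_ratio=1.0)
    train_data, val_data, test_data = transform(data)

    # edge_label_index holds the candidate node pairs and edge_label the 1/0 targets
    # that BCEWithLogitsLoss is computed against in the loops above.
    print(train_data.edge_label_index.shape, train_data.edge_label)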
1183
+ # --------------
1184
+ # Metrics helpers
1185
+ # --------------
1186
+ @torch.no_grad()
1187
+ def _predict_graph(self, loader):
1188
+ self.model.eval()
1189
+ y_true, y_pred = [], []
1190
+ for batch in loader:
1191
+ batch = batch.to(self.device)
1192
+ node_emb = self.model["encoder"](batch.x, batch.edge_index)
1193
+ out = self.model["head"](node_emb, batch.batch)
1194
+ if self.config.task == "regression":
1195
+ y_true.extend(batch.y.squeeze(-1).detach().cpu().numpy().tolist())
1196
+ y_pred.extend(out.squeeze(-1).detach().cpu().numpy().tolist())
1197
+ else:
1198
+ probs = F.softmax(out, dim=-1)
1199
+ y_true.extend(batch.y.detach().cpu().numpy().tolist())
1200
+ y_pred.extend(probs.argmax(dim=-1).detach().cpu().numpy().tolist())
1201
+ return np.array(y_true), np.array(y_pred)
2309
1202
 
2310
- """
2311
- if len(predicted) < 1 or len(actual) < 1 or not len(predicted) == len(actual):
2312
- return None
2313
- correct = 0
2314
- mask = []
2315
- for i in range(len(predicted)):
2316
- if predicted[i] == actual[i]:
2317
- correct = correct + 1
2318
- mask.append(True)
1203
+ @torch.no_grad()
1204
+ def _predict_node(self, loader, mask_name: str):
1205
+ self.model.eval()
1206
+ y_true, y_pred = [], []
1207
+ for data in loader:
1208
+ data = data.to(self.device)
1209
+ node_emb = self.model["encoder"](data.x, data.edge_index)
1210
+ out = self.model["head"](node_emb)
1211
+ mask = getattr(data, mask_name)
1212
+ if self.config.task == "regression":
1213
+ y_true.extend(data.y[mask].detach().cpu().numpy().tolist())
1214
+ y_pred.extend(out.squeeze(-1)[mask].detach().cpu().numpy().tolist())
2319
1215
  else:
2320
- mask.append(False)
2321
- size = len(predicted)
2322
- wrong = len(predicted)- correct
2323
- accuracy = round(float(correct) / float(len(predicted)), mantissa)
2324
- return {"accuracy":accuracy, "correct":correct, "mask":mask, "size":size, "wrong":wrong}
2325
-
2326
- @staticmethod
2327
- def MSE(actual, predicted, mantissa: int = 6):
2328
- """
2329
- Computes the Mean Squared Error (MSE) of the input predictions based on the input labels. This is to be used with regression models.
1216
+ probs = F.softmax(out, dim=-1)
1217
+ y_true.extend(data.y[mask].detach().cpu().numpy().tolist())
1218
+ y_pred.extend(probs.argmax(dim=-1)[mask].detach().cpu().numpy().tolist())
1219
+ return np.array(y_true), np.array(y_pred)
2330
1220
 
2331
- Parameters
2332
- ----------
2333
- actual : list
2334
- The input list of actual values.
2335
- predicted : list
2336
- The input list of predicted values.
2337
- mantissa : int , optional
2338
- The number of decimal places to round the result to. Default is 6.
1221
+ @torch.no_grad()
1222
+ def _predict_edge(self, loader, mask_name: str):
1223
+ self.model.eval()
1224
+ y_true, y_pred = [], []
1225
+ for data in loader:
1226
+ data = data.to(self.device)
1227
+ node_emb = self.model["encoder"](data.x, data.edge_index)
1228
+ out = self.model["head"](node_emb, data.edge_index)
1229
+ mask = getattr(data, mask_name)
1230
+ if self.config.task == "regression":
1231
+ y_true.extend(data.edge_y[mask].detach().cpu().numpy().tolist())
1232
+ y_pred.extend(out.squeeze(-1)[mask].detach().cpu().numpy().tolist())
1233
+ else:
1234
+ probs = F.softmax(out, dim=-1)
1235
+ y_true.extend(data.edge_y[mask].detach().cpu().numpy().tolist())
1236
+ y_pred.extend(probs.argmax(dim=-1)[mask].detach().cpu().numpy().tolist())
1237
+ return np.array(y_true), np.array(y_pred)
2339
1238
 
2340
- Returns
2341
- -------
2342
- dict
2343
- A dictionary returning the MSE information. This contains the following keys and values:
2344
- - "mse" (float): The mean squared error rounded to the specified mantissa.
2345
- - "size" (int): The size of the predictions list.
2346
- """
2347
- if len(predicted) < 1 or len(actual) < 1 or not len(predicted) == len(actual):
2348
- return None
2349
-
2350
- mse = np.mean((np.array(predicted) - np.array(actual)) ** 2)
2351
- mse = round(mse, mantissa)
2352
- size = len(predicted)
1239
+ @torch.no_grad()
1240
+ def _predict_link(self, loader, split_name: str):
1241
+ self.model.eval()
1242
+ split = RandomLinkSplit(
1243
+ num_val=self.config.link_val_ratio,
1244
+ num_test=self.config.link_test_ratio,
1245
+ is_undirected=self.config.link_is_undirected,
1246
+ add_negative_train_samples=True,
1247
+ neg_sampling_ratio=1.0
1248
+ )
1249
+ y_true, y_score = [], []
1250
+ for data in loader:
1251
+ tr, va, te = split(data)
1252
+ use = {"train": tr, "val": va, "test": te}[split_name]
1253
+ use = use.to(self.device)
1254
+ node_emb = self.model["encoder"](use.x, use.edge_index)
1255
+ logits = self.model["predictor"](node_emb, use.edge_label_index)
1256
+ probs = torch.sigmoid(logits).detach().cpu().numpy()
1257
+ y = use.edge_label.detach().cpu().numpy()
1258
+ y_true.extend(y.tolist())
1259
+ y_score.extend(probs.tolist())
1260
+ return np.array(y_true), np.array(y_score)
1261
+
1262
+ # ----------
1263
+ # Per-level metrics dispatch (used by Validate, Test and CrossValidate)
1264
+ # ----------
1265
+ def _metrics_graph(self, loader, prefix: str):
1266
+ y_true, y_pred = self._predict_graph(loader)
1267
+ if self.config.task == "regression":
1268
+ return self._regression_metrics(y_true, y_pred, prefix=prefix)
1269
+ return self._classification_metrics(y_true, y_pred, prefix=prefix)
1270
+
1271
+ def _metrics_node(self, loader, split: Literal["train", "val", "test"]):
1272
+ mask = "train_mask" if split == "train" else ("val_mask" if split == "val" else "test_mask")
1273
+ y_true, y_pred = self._predict_node(loader, mask)
1274
+ if self.config.task == "regression":
1275
+ return self._regression_metrics(y_true, y_pred, prefix=f"{split}_")
1276
+ return self._classification_metrics(y_true, y_pred, prefix=f"{split}_")
1277
+
1278
+ def _metrics_edge(self, loader, split: Literal["train", "val", "test"]):
1279
+ mask = "edge_train_mask" if split == "train" else ("edge_val_mask" if split == "val" else "edge_test_mask")
1280
+ y_true, y_pred = self._predict_edge(loader, mask)
1281
+ if self.config.task == "regression":
1282
+ return self._regression_metrics(y_true, y_pred, prefix=f"{split}_")
1283
+ return self._classification_metrics(y_true, y_pred, prefix=f"{split}_")
1284
+
1285
+ def _metrics_link(self, loader, split: Literal["train", "val", "test"]):
1286
+ y_true, y_score = self._predict_link(loader, split)
1287
+ y_pred = (y_score >= 0.5).astype(int)
1288
+ return self._classification_metrics(y_true, y_pred, prefix=f"{split}_")
2353
1289
 
2354
- return {"mse": mse, "size": size}
1290
+ @staticmethod
1291
+ def _classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, prefix: str = "") -> Dict[str, float]:
1292
+ if len(y_true) == 0:
+ return {f"{prefix}accuracy": 0.0, f"{prefix}precision": 0.0, f"{prefix}recall": 0.0, f"{prefix}f1": 0.0}
+ acc = float(accuracy_score(y_true, y_pred))
1293
+ prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted", zero_division=0)
1294
+ return {
1295
+ f"{prefix}accuracy": float(acc),
1296
+ f"{prefix}precision": float(prec),
1297
+ f"{prefix}recall": float(rec),
1298
+ f"{prefix}f1": float(f1),
1299
+ }
2355
1300
 
2356
1301
  @staticmethod
2357
- def Performance(actual, predicted, mantissa: int = 6):
2358
- """
2359
- Computes regression model performance measures. This is to be used only with regression not with classification.
1302
+ def _regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, prefix: str = "") -> Dict[str, float]:
1303
+ if len(y_true) == 0:
1304
+ return {f"{prefix}mae": 0.0, f"{prefix}rmse": 0.0, f"{prefix}r2": 0.0}
1305
+ mae = float(mean_absolute_error(y_true, y_pred))
1306
+ rmse = float(math.sqrt(mean_squared_error(y_true, y_pred)))
1307
+ r2 = float(r2_score(y_true, y_pred))
1308
+ return {f"{prefix}mae": mae, f"{prefix}rmse": rmse, f"{prefix}r2": r2}
1309
+
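A small illustration of what these two helpers report, using sklearn directly on toy arrays; the numbers are only for the example, and the prefix is whatever the caller passes ("", "val_", or "test_").

    import numpy as np
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    y_true = np.array([0, 1, 1, 0])
    y_pred = np.array([0, 1, 0, 0])          # three of four correct

    acc = accuracy_score(y_true, y_pred)     # 0.75
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0)

    # The helper above would package these as, for example:
    # {"test_accuracy": 0.75, "test_precision": prec, "test_recall": rec, "test_f1": f1}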
1310
+ # -----------------
1311
+ # Plotly reporting
1312
+ # -----------------
1313
+ def PlotHistory(self):
1314
+ fig = go.Figure()
1315
+ fig.add_trace(go.Scatter(y=self.history["train_loss"], mode="lines+markers", name="Train Loss"))
1316
+ fig.add_trace(go.Scatter(y=self.history["val_loss"], mode="lines+markers", name="Val Loss"))
1317
+ fig.update_layout(title="Training History", xaxis_title="Epoch", yaxis_title="Loss")
1318
+ return fig
1319
+
1320
+ def PlotConfusionMatrix(self, split: Literal["train", "val", "test"] = "test"):
1321
+ if self.config.task != "classification" or self.config.level == "link":
1322
+ raise ValueError("Confusion matrix is only available for classification (graph/node/edge).")
1323
+
1324
+ if self.config.level == "graph":
1325
+ if split == "train":
1326
+ loader = DataLoader(self.train_set, batch_size=self.config.batch_size, shuffle=False)
1327
+ elif split == "val":
1328
+ loader = DataLoader(self.val_set, batch_size=self.config.batch_size, shuffle=False)
1329
+ else:
1330
+ loader = DataLoader(self.test_set, batch_size=self.config.batch_size, shuffle=False)
1331
+ y_true, y_pred = self._predict_graph(loader)
1332
+
1333
+ elif self.config.level == "node":
1334
+ loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
1335
+ mask = "train_mask" if split == "train" else ("val_mask" if split == "val" else "test_mask")
1336
+ y_true, y_pred = self._predict_node(loader, mask)
1337
+
1338
+ else: # edge
1339
+ loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
1340
+ mask = "edge_train_mask" if split == "train" else ("edge_val_mask" if split == "val" else "edge_test_mask")
1341
+ y_true, y_pred = self._predict_edge(loader, mask)
1342
+
1343
+ cm = confusion_matrix(y_true, y_pred)
1344
+ fig = px.imshow(cm, text_auto=True, title=f"Confusion Matrix ({split})")
1345
+ fig.update_layout(xaxis_title="Predicted", yaxis_title="True")
1346
+ return fig
1347
+
1348
+ def PlotParity(self, split: Literal["train", "val", "test"] = "test"):
1349
+ if self.config.task != "regression":
1350
+ raise ValueError("Parity plot is only available for regression tasks.")
1351
+
1352
+ if self.config.level == "graph":
1353
+ if split == "train":
1354
+ loader = DataLoader(self.train_set, batch_size=self.config.batch_size, shuffle=False)
1355
+ elif split == "val":
1356
+ loader = DataLoader(self.val_set, batch_size=self.config.batch_size, shuffle=False)
1357
+ else:
1358
+ loader = DataLoader(self.test_set, batch_size=self.config.batch_size, shuffle=False)
1359
+ y_true, y_pred = self._predict_graph(loader)
1360
+
1361
+ elif self.config.level == "node":
1362
+ loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
1363
+ mask = "train_mask" if split == "train" else ("val_mask" if split == "val" else "test_mask")
1364
+ y_true, y_pred = self._predict_node(loader, mask)
2360
1365
 
2361
- Parameters
2362
- ----------
2363
- actual : list
2364
- The input list of actual values.
2365
- predicted : list
2366
- The input list of predicted values.
2367
- mantissa : int , optional
2368
- The number of decimal places to round the result to. Default is 6.
2369
-
2370
- Returns
2371
- -------
2372
- dict
2373
- The dictionary containing the performance measures. The keys in the dictionary are: 'mae', 'mape', 'mse', 'r', 'r2', 'rmse'.
2374
- """
2375
-
2376
- if not isinstance(actual, list):
2377
- print("PyG.Performance - ERROR: The actual input is not a list. Returning None")
2378
- return None
2379
- if not isinstance(predicted, list):
2380
- print("PyG.Performance - ERROR: The predicted input is not a list. Returning None")
2381
- return None
2382
- if not (len(actual) == len(predicted)):
2383
- print("PyG.Performance - ERROR: The actual and predicted input lists have different lengths. Returning None")
2384
- return None
2385
-
2386
- predicted = np.array(predicted)
2387
- actual = np.array(actual)
2388
-
2389
- mae = np.mean(np.abs(predicted - actual))
2390
- mape = np.mean(np.abs((actual - predicted) / actual))*100
2391
- mse = np.mean((predicted - actual)**2)
2392
- correlation_matrix = np.corrcoef(predicted, actual)
2393
- r = correlation_matrix[0, 1]
2394
- r2 = r**2
2395
- absolute_errors = np.abs(predicted - actual)
2396
- mean_actual = np.mean(actual)
2397
- if mean_actual == 0:
2398
- rae = None
2399
1366
  else:
2400
- rae = np.mean(absolute_errors) / mean_actual
2401
- rmse = np.sqrt(mse)
2402
- return {'mae': round(mae, mantissa),
2403
- 'mape': round(mape, mantissa),
2404
- 'mse': round(mse, mantissa),
2405
- 'r': round(r, mantissa),
2406
- 'r2': round(r2, mantissa),
2407
- 'rae': round(rae, mantissa),
2408
- 'rmse': round(rmse, mantissa)
2409
- }
1367
+ loader = DataLoader(self.data_list, batch_size=1, shuffle=False)
1368
+ mask = "edge_train_mask" if split == "train" else ("edge_val_mask" if split == "val" else "edge_test_mask")
1369
+ y_true, y_pred = self._predict_edge(loader, mask)
1370
+
1371
+ fig = go.Figure()
1372
+ fig.add_trace(go.Scatter(x=y_true, y=y_pred, mode="markers", name="Predictions"))
1373
+ mn = float(min(np.min(y_true), np.min(y_pred))) if len(y_true) else 0.0
1374
+ mx = float(max(np.max(y_true), np.max(y_pred))) if len(y_true) else 1.0
1375
+ fig.add_trace(go.Scatter(x=[mn, mx], y=[mn, mx], mode="lines", name="Ideal"))
1376
+ fig.update_layout(title=f"Parity Plot ({split})", xaxis_title="True", yaxis_title="Predicted")
1377
+ return fig
1378
+
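The three reporting methods above return Plotly figures instead of displaying them, leaving rendering to the caller; a brief sketch, again with trainer as a hypothetical instance of this class.

    fig = trainer.PlotHistory()
    fig.show()                                    # interactive display
    fig.write_html("training_history.html")       # or persist for later inspection

    # Depending on the task, the other two plots work the same way:
    # trainer.PlotConfusionMatrix(split="test")   # classification only
    # trainer.PlotParity(split="test")            # regression only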
1379
+ def SaveModel(self, path: str):
1380
+ if not path.lower().endswith(".pt"):
1381
+ path = path + ".pt"
1382
+ torch.save(self.model.state_dict(), path)
1383
+
1384
+ def LoadModel(self, path: str):
1385
+ state = torch.load(path, map_location=self.device)
1386
+ self.model.load_state_dict(state)
1387
+ self.model.to(self.device)
1388
+ self.model.eval()
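Finally, a sketch of the save/load round trip. SaveModel writes only the state_dict (appending ".pt" when missing), so LoadModel has to be called on an instance whose model was built with the same architecture and configuration; the trainer names are hypothetical.

    trainer.SaveModel("sage_classifier")          # written as sage_classifier.pt

    # Later, on a freshly constructed trainer with an identical configuration:
    new_trainer.LoadModel("sage_classifier.pt")   # loads the weights and switches to eval mode
    metrics = new_trainer.Test()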