topologicpy 0.7.37__py3-none-any.whl → 0.7.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- topologicpy/ANN.py +2 -18
- topologicpy/Face.py +4 -8
- topologicpy/Graph.py +34 -13
- topologicpy/Grid.py +0 -2
- topologicpy/Helper.py +0 -34
- topologicpy/Neo4j.py +5 -3
- topologicpy/Polyskel.py +5 -1
- topologicpy/PyG.py +2067 -0
- topologicpy/Shell.py +2 -2
- topologicpy/Sun.py +26 -36
- topologicpy/Topology.py +0 -22
- topologicpy/Wire.py +3 -7
- topologicpy/version.py +1 -1
- {topologicpy-0.7.37.dist-info → topologicpy-0.7.39.dist-info}/METADATA +1 -1
- {topologicpy-0.7.37.dist-info → topologicpy-0.7.39.dist-info}/RECORD +18 -17
- {topologicpy-0.7.37.dist-info → topologicpy-0.7.39.dist-info}/WHEEL +1 -1
- {topologicpy-0.7.37.dist-info → topologicpy-0.7.39.dist-info}/LICENSE +0 -0
- {topologicpy-0.7.37.dist-info → topologicpy-0.7.39.dist-info}/top_level.txt +0 -0
topologicpy/PyG.py
ADDED
@@ -0,0 +1,2067 @@
# Copyright (C) 2024
# Wassim Jabi <wassim.jabi@gmail.com>
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License along with
# this program. If not, see <https://www.gnu.org/licenses/>.

import os
import copy
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv, global_mean_pool, global_max_pool, global_add_pool
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm
import gc

class CustomGraphDataset(Dataset):
    def __init__(self, root, node_level=False, graph_level=True, node_attr_key='feat',
                 edge_attr_key='feat', transform=None, pre_transform=None):
        super(CustomGraphDataset, self).__init__(root, transform, pre_transform)
        assert not (node_level and graph_level), "Both node_level and graph_level cannot be True at the same time"
        assert node_level or graph_level, "Both node_level and graph_level cannot be False at the same time"

        self.node_level = node_level
        self.graph_level = graph_level
        self.node_attr_key = node_attr_key
        self.edge_attr_key = edge_attr_key

        self.graph_df = pd.read_csv(os.path.join(root, 'graphs.csv'))
        self.nodes_df = pd.read_csv(os.path.join(root, 'nodes.csv'))
        self.edges_df = pd.read_csv(os.path.join(root, 'edges.csv'))

        self.data_list = self.process_all()

    @property
    def raw_file_names(self):
        return ['graphs.csv', 'nodes.csv', 'edges.csv']

    def process_all(self):
        data_list = []
        for graph_id in self.graph_df['graph_id'].unique():
            graph_nodes = self.nodes_df[self.nodes_df['graph_id'] == graph_id]
            graph_edges = self.edges_df[self.edges_df['graph_id'] == graph_id]

            if self.node_attr_key in graph_nodes.columns and not graph_nodes[self.node_attr_key].isnull().all():
                x = torch.tensor(graph_nodes[self.node_attr_key].values.tolist(), dtype=torch.float)
                if x.ndim == 1:
                    x = x.unsqueeze(1)  # Ensure x has shape [num_nodes, *]
            else:
                x = None

            edge_index = torch.tensor(graph_edges[['src_id', 'dst_id']].values.T, dtype=torch.long)

            if self.edge_attr_key in graph_edges.columns and not graph_edges[self.edge_attr_key].isnull().all():
                edge_attr = torch.tensor(graph_edges[self.edge_attr_key].values.tolist(), dtype=torch.float)
            else:
                edge_attr = None

            if self.graph_level:
                y = torch.tensor([self.graph_df[self.graph_df['graph_id'] == graph_id]['label'].values[0]], dtype=torch.long)
            elif self.node_level:
                y = torch.tensor(graph_nodes['label'].values, dtype=torch.long)

            data = Data(x=x, edge_index=edge_index, y=y)
            if edge_attr is not None:
                data.edge_attr = edge_attr

            data_list.append(data)

        return data_list

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        return self.data_list[idx]

    def __getitem__(self, idx):
        return self.get(idx)

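For orientation, a minimal usage sketch of CustomGraphDataset. The column names graph_id, label, src_id, dst_id and the default 'feat' attribute key come from the class above; the folder name and the exact file layout are assumptions for illustration, not a schema documented by the package.

# Illustrative only: the class reads three CSV files from one folder.
#   graphs.csv : graph_id,label        -> one row per graph (label used when graph_level=True)
#   nodes.csv  : graph_id,feat[,label] -> one row per node ('feat' matches node_attr_key)
#   edges.csv  : graph_id,src_id,dst_id-> one row per edge
dataset = CustomGraphDataset(root="my_dataset_folder", graph_level=True, node_attr_key="feat")
print(len(dataset))   # number of graphs
print(dataset[0])     # a torch_geometric.data.Data object for the first graph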
class _Hparams:
    def __init__(self, model_type="ClassifierHoldout", optimizer_str="Adam", amsgrad=False, betas=(0.9, 0.999), eps=1e-6, lr=0.001, lr_decay=0, maximize=False, rho=0.9, weight_decay=0, cv_type="Holdout", split=[0.8, 0.1, 0.1], k_folds=5, hl_widths=[32], conv_layer_type='SAGEConv', pooling="AvgPooling", batch_size=32, epochs=1,
                 use_gpu=False, loss_function="Cross Entropy", input_type="graph"):
        """
        Parameters
        ----------
        cv_type : str
            A string to define the method of cross-validation.
            "Holdout": Holdout
            "K-Fold": K-Fold cross-validation
        k_folds : int
            An int value of 2 or greater to define the number of folds for cross-validation. Default is 5.
        split : list
            A list of three items in the range 0 to 1 to define the split of train,
            validate, and test data. The default value of [0.8, 0.1, 0.1] means 80% of the data will be
            used for training, 10% for validation, and the remaining 10% for testing.
        hl_widths : list
            List of hidden neurons per layer; for example, [32] means
            there is one hidden layer in the network with 32 neurons.
        optimizer_str : str
            The selected optimizer from the torch.optim package. By
            default, torch.optim.Adam is selected.
        lr : float
            A step value used by the optimizer when applying the gradients.
        batch_size : int
            The number of samples used for training and testing in
            each step of an epoch.
        epochs : int
            An epoch means training the neural network with all the training data for one cycle. In an epoch, we use all of the data exactly once. A forward pass and a backward pass together are counted as one pass.
        use_gpu : bool
            If True, use the GPU. Otherwise, use the CPU.
        input_type : str
            Selects the input type of the model, such as graph, node, or edge.

        Returns
        -------
        None

        """

        self.model_type = model_type
        self.optimizer_str = optimizer_str
        self.amsgrad = amsgrad
        self.betas = betas
        self.eps = eps
        self.lr = lr
        self.lr_decay = lr_decay
        self.maximize = maximize
        self.rho = rho
        self.weight_decay = weight_decay
        self.cv_type = cv_type
        self.split = split
        self.k_folds = k_folds
        self.hl_widths = hl_widths
        self.conv_layer_type = conv_layer_type
        self.pooling = pooling
        self.batch_size = batch_size
        self.epochs = epochs
        self.use_gpu = use_gpu
        self.loss_function = loss_function
        self.input_type = input_type

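As a quick illustration of the hyperparameters documented in the docstring above, a minimal sketch of building a configuration; the specific values are arbitrary examples, not defaults recommended by the package beyond those already in the signature.

# Illustrative only: a holdout graph-classification configuration.
hparams = _Hparams(model_type="ClassifierHoldout",
                   cv_type="Holdout", split=[0.8, 0.1, 0.1],
                   hl_widths=[64, 32],           # two hidden SAGEConv layers
                   pooling="AvgPooling", batch_size=16, epochs=20,
                   lr=0.001, loss_function="Cross Entropy", input_type="graph")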
class _SAGEConv(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes, pooling=None):
        super(_SAGEConv, self).__init__()
        assert isinstance(h_feats, list), "h_feats must be a list"
        h_feats = [x for x in h_feats if x is not None]
        assert len(h_feats) != 0, "h_feats is empty. unable to add hidden layers"
        self.list_of_layers = nn.ModuleList()
        dim = [in_feats] + h_feats

        # Convolution (Hidden) Layers
        for i in range(1, len(dim)):
            self.list_of_layers.append(SAGEConv(dim[i-1], dim[i]))

        # Final Layer
        self.final = nn.Linear(dim[-1], num_classes)

        # Pooling layer
        if pooling is None:
            self.pooling_layer = None
        else:
            if "av" in pooling.lower():
                self.pooling_layer = global_mean_pool
            elif "max" in pooling.lower():
                self.pooling_layer = global_max_pool
            elif "sum" in pooling.lower():
                self.pooling_layer = global_add_pool
            else:
                raise NotImplementedError

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        h = x
        # Generate node features
        for layer in self.list_of_layers:
            h = layer(h, edge_index)
            h = F.relu(h)
        # h will now be a matrix of dimension [num_nodes, h_feats[-1]]
        h = self.final(h)
        # Go from node-level features to graph-level features by pooling
        if self.pooling_layer:
            h = self.pooling_layer(h, batch)
        # h will now be a vector of dimension [num_classes]
        return h

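A minimal sketch of pushing a toy graph through _SAGEConv to make the tensor shapes concrete; the graph, feature sizes, and layer widths below are invented for illustration.

# Illustrative only: a 4-node toy graph with 3-dimensional node features.
toy = Data(x=torch.rand(4, 3),
           edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long))
toy.batch = torch.zeros(4, dtype=torch.long)   # all four nodes belong to graph 0
model = _SAGEConv(in_feats=3, h_feats=[16], num_classes=2, pooling="AvgPooling")
out = model(toy)                               # node features -> pooled logits, shape [1, 2]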
class _GraphRegressorHoldout:
    def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.trainingDataset = trainingDataset
        self.validationDataset = validationDataset
        self.testingDataset = testingDataset
        self.hparams = hparams
        if hparams.conv_layer_type.lower() == 'sageconv':
            self.model = _SAGEConv(trainingDataset[0].num_node_features, hparams.hl_widths, 1, hparams.pooling).to(self.device)
        else:
            raise NotImplementedError

        if hparams.optimizer_str.lower() == "adadelta":
            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
                                                  lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adagrad":
            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
                                                 lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
                                              lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)

        self.use_gpu = hparams.use_gpu
        self.training_loss_list = []
        self.validation_loss_list = []
        self.node_attr_key = trainingDataset[0].x.shape[1]

        # Train, validate, test split
        num_train = int(len(trainingDataset) * hparams.split[0])
        num_validate = int(len(trainingDataset) * hparams.split[1])
        num_test = len(trainingDataset) - num_train - num_validate
        idx = torch.randperm(len(trainingDataset))
        train_sampler = SubsetRandomSampler(idx[:num_train])
        validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
        test_sampler = SubsetRandomSampler(idx[num_train+num_validate:])

        if validationDataset:
            self.train_dataloader = DataLoader(trainingDataset,
                                               batch_size=hparams.batch_size,
                                               drop_last=False)
            self.validate_dataloader = DataLoader(validationDataset,
                                                  batch_size=hparams.batch_size,
                                                  drop_last=False)
        else:
            self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler,
                                               batch_size=hparams.batch_size,
                                               drop_last=False)
            self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler,
                                                  batch_size=hparams.batch_size,
                                                  drop_last=False)

        if testingDataset:
            self.test_dataloader = DataLoader(testingDataset,
                                              batch_size=len(testingDataset),
                                              drop_last=False)
        else:
            self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler,
                                              batch_size=hparams.batch_size,
                                              drop_last=False)

    def train(self):
        # Init the loss and accuracy reporting lists
        self.training_loss_list = []
        self.validation_loss_list = []

        # Run the training loop for defined number of epochs
        for _ in tqdm(range(self.hparams.epochs), desc='Epochs', total=self.hparams.epochs, leave=False):
            # Iterate over the DataLoader for training data
            for data in tqdm(self.train_dataloader, desc='Training', leave=False):
                data = data.to(self.device)
                # Make sure the model is in training mode
                self.model.train()
                # Zero the gradients
                self.optimizer.zero_grad()

                # Perform forward pass
                pred = self.model(data).to(self.device)
                # Compute loss
                loss = F.mse_loss(torch.flatten(pred), data.y.float())

                # Perform backward pass
                loss.backward()

                # Perform optimization
                self.optimizer.step()

            self.training_loss_list.append(torch.sqrt(loss).item())
            self.validate()
            self.validation_loss_list.append(torch.sqrt(self.validation_loss).item())
            gc.collect()

    def validate(self):
        self.model.eval()
        for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
            data = data.to(self.device)
            pred = self.model(data).to(self.device)
            loss = F.mse_loss(torch.flatten(pred), data.y.float())
        self.validation_loss = loss

    def test(self):
        self.model.eval()
        for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
            data = data.to(self.device)
            pred = self.model(data).to(self.device)
            loss = F.mse_loss(torch.flatten(pred), data.y.float())
        self.testing_loss = torch.sqrt(loss).item()

    def save(self, path):
        if path:
            # Make sure the file extension is .pt
            ext = path[-3:]
            if ext.lower() != ".pt":
                path = path + ".pt"
            torch.save(self.model.state_dict(), path)

    def load(self, path):
        #self.model.load_state_dict(torch.load(path))
        self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))

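A sketch of how the pieces above would typically be wired together for holdout graph regression; the folder name and hyperparameter values are assumptions for illustration, not documented defaults.

# Illustrative only: dataset -> hyperparameters -> holdout regressor.
dataset = CustomGraphDataset(root="my_dataset_folder", graph_level=True)
hparams = _Hparams(model_type="RegressorHoldout", split=[0.8, 0.1, 0.1],
                   hl_widths=[32], pooling="AvgPooling", batch_size=8, epochs=10)
regressor = _GraphRegressorHoldout(hparams, dataset)
regressor.train()
regressor.test()
print(regressor.testing_loss)   # RMSE on the internal test split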
class _GraphRegressorKFold:
    def __init__(self, hparams, trainingDataset, testingDataset=None):
        self.trainingDataset = trainingDataset
        self.testingDataset = testingDataset
        self.hparams = hparams
        self.losses = []
        self.min_loss = 0
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.model = self._initialize_model(hparams, trainingDataset)
        self.optimizer = self._initialize_optimizer(hparams)

        self.use_gpu = hparams.use_gpu
        self.training_loss_list = []
        self.validation_loss_list = []
        self.node_attr_key = trainingDataset.node_attr_key

        # Train, validate, test split
        num_train = int(len(trainingDataset) * hparams.split[0])
        num_validate = int(len(trainingDataset) * hparams.split[1])
        num_test = len(trainingDataset) - num_train - num_validate
        idx = torch.randperm(len(trainingDataset))
        test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])

        if testingDataset:
            self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
        else:
            self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)

    def _initialize_model(self, hparams, dataset):
        if hparams.conv_layer_type.lower() == 'sageconv':
            return _SAGEConv(dataset.num_node_features, hparams.hl_widths, 1, hparams.pooling).to(self.device)
        else:
            raise NotImplementedError

    def _initialize_optimizer(self, hparams):
        if hparams.optimizer_str.lower() == "adadelta":
            return torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps, lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adagrad":
            return torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps, lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adam":
            return torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps, lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)

    def reset_weights(self):
        self.model = self._initialize_model(self.hparams, self.trainingDataset)
        self.optimizer = self._initialize_optimizer(self.hparams)

    def train(self):
        k_folds = self.hparams.k_folds
        torch.manual_seed(42)

        kfold = KFold(n_splits=k_folds, shuffle=True)
        models, weights, losses, train_dataloaders, validate_dataloaders = [], [], [], [], []

        for fold, (train_ids, validate_ids) in tqdm(enumerate(kfold.split(self.trainingDataset)), desc="Fold", total=k_folds, leave=False):
            epoch_training_loss_list, epoch_validation_loss_list = [], []
            train_subsampler = SubsetRandomSampler(train_ids)
            validate_subsampler = SubsetRandomSampler(validate_ids)

            self.train_dataloader = DataLoader(self.trainingDataset, sampler=train_subsampler, batch_size=self.hparams.batch_size, drop_last=False)
            self.validate_dataloader = DataLoader(self.trainingDataset, sampler=validate_subsampler, batch_size=self.hparams.batch_size, drop_last=False)

            self.reset_weights()
            best_rmse = np.inf

            for _ in tqdm(range(self.hparams.epochs), desc='Epochs', total=self.hparams.epochs, leave=False):
                for batched_graph in tqdm(self.train_dataloader, desc='Training', leave=False):
                    self.model.train()
                    self.optimizer.zero_grad()

                    batched_graph = batched_graph.to(self.device)
                    pred = self.model(batched_graph)
                    loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
                    loss.backward()
                    self.optimizer.step()

                epoch_training_loss_list.append(torch.sqrt(loss).item())
                self.validate()
                epoch_validation_loss_list.append(torch.sqrt(self.validation_loss).item())
                gc.collect()

            models.append(self.model)
            weights.append(copy.deepcopy(self.model.state_dict()))
            losses.append(torch.sqrt(self.validation_loss).item())
            train_dataloaders.append(self.train_dataloader)
            validate_dataloaders.append(self.validate_dataloader)
            self.training_loss_list.append(epoch_training_loss_list)
            self.validation_loss_list.append(epoch_validation_loss_list)

        self.losses = losses
        self.min_loss = min(losses)
        ind = losses.index(self.min_loss)
        self.model = models[ind]
        self.model.load_state_dict(weights[ind])
        self.model.eval()
        self.training_loss_list = self.training_loss_list[ind]
        self.validation_loss_list = self.validation_loss_list[ind]

    def validate(self):
        self.model.eval()
        for batched_graph in tqdm(self.validate_dataloader, desc='Validating', leave=False):
            batched_graph = batched_graph.to(self.device)
            pred = self.model(batched_graph)
            loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
        self.validation_loss = loss

    def test(self):
        self.model.eval()
        for batched_graph in tqdm(self.test_dataloader, desc='Testing', leave=False):
            batched_graph = batched_graph.to(self.device)
            pred = self.model(batched_graph)
            loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
        self.testing_loss = torch.sqrt(loss).item()

    def save(self, path):
        if path:
            ext = path[-3:]
            if ext.lower() != ".pt":
                path = path + ".pt"
            torch.save(self.model.state_dict(), path)

    def load(self, path):
        self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))

class _GraphClassifierKFold:
    def __init__(self, hparams, trainingDataset, testingDataset=None):
        self.trainingDataset = trainingDataset
        self.testingDataset = testingDataset
        self.hparams = hparams
        self.testing_accuracy = 0
        self.accuracies = []
        self.max_accuracy = 0
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        if hparams.conv_layer_type.lower() == 'sageconv':
            self.model = _SAGEConv(trainingDataset.num_node_features, hparams.hl_widths,
                                   trainingDataset.num_classes, hparams.pooling).to(self.device)
        else:
            raise NotImplementedError

        if hparams.optimizer_str.lower() == "adadelta":
            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
                                                  lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adagrad":
            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
                                                 lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
                                              lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
        self.use_gpu = hparams.use_gpu
        self.training_loss_list = []
        self.validation_loss_list = []
        self.training_accuracy_list = []
        self.validation_accuracy_list = []

    def reset_weights(self):
        if self.hparams.conv_layer_type.lower() == 'sageconv':
            self.model = _SAGEConv(self.trainingDataset.num_node_features, self.hparams.hl_widths,
                                   self.trainingDataset.num_classes, self.hparams.pooling).to(self.device)
        else:
            raise NotImplementedError

        if self.hparams.optimizer_str.lower() == "adadelta":
            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=self.hparams.eps,
                                                  lr=self.hparams.lr, rho=self.hparams.rho, weight_decay=self.hparams.weight_decay)
        elif self.hparams.optimizer_str.lower() == "adagrad":
            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=self.hparams.eps,
                                                 lr=self.hparams.lr, lr_decay=self.hparams.lr_decay, weight_decay=self.hparams.weight_decay)
        elif self.hparams.optimizer_str.lower() == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=self.hparams.amsgrad, betas=self.hparams.betas, eps=self.hparams.eps,
                                              lr=self.hparams.lr, maximize=self.hparams.maximize, weight_decay=self.hparams.weight_decay)

    def train(self):
        k_folds = self.hparams.k_folds

        # Init the loss and accuracy reporting lists
        self.training_accuracy_list = []
        self.training_loss_list = []
        self.validation_accuracy_list = []
        self.validation_loss_list = []

        # Set fixed random number seed
        torch.manual_seed(42)

        # Define the K-fold Cross Validator
        kfold = KFold(n_splits=k_folds, shuffle=True)

        models = []
        weights = []
        accuracies = []
        train_dataloaders = []
        validate_dataloaders = []

        # K-fold Cross-validation model evaluation
        for fold, (train_ids, validate_ids) in tqdm(enumerate(kfold.split(self.trainingDataset)), desc="Fold", initial=1, total=k_folds, leave=False):
            epoch_training_loss_list = []
            epoch_training_accuracy_list = []
            epoch_validation_loss_list = []
            epoch_validation_accuracy_list = []
            # Sample elements randomly from a given list of ids, no replacement.
            train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
            validate_subsampler = torch.utils.data.SubsetRandomSampler(validate_ids)

            # Define data loaders for training and testing data in this fold
            self.train_dataloader = DataLoader(self.trainingDataset, sampler=train_subsampler,
                                               batch_size=self.hparams.batch_size,
                                               drop_last=False)
            self.validate_dataloader = DataLoader(self.trainingDataset, sampler=validate_subsampler,
                                                  batch_size=self.hparams.batch_size,
                                                  drop_last=False)
            # Init the neural network
            self.reset_weights()

            # Run the training loop for defined number of epochs
            for _ in tqdm(range(0, self.hparams.epochs), desc='Epochs', initial=1, total=self.hparams.epochs, leave=False):
                temp_loss_list = []
                temp_acc_list = []

                # Iterate over the DataLoader for training data
                for data in tqdm(self.train_dataloader, desc='Training', leave=False):
                    data = data.to(self.device)
                    # Make sure the model is in training mode
                    self.model.train()

                    # Zero the gradients
                    self.optimizer.zero_grad()

                    # Perform forward pass
                    pred = self.model(data)

                    # Compute loss
                    if self.hparams.loss_function.lower() == "negative log likelihood":
                        logp = F.log_softmax(pred, 1)
                        loss = F.nll_loss(logp, data.y)
                    elif self.hparams.loss_function.lower() == "cross entropy":
                        loss = F.cross_entropy(pred, data.y)

                    # Save loss information for reporting
                    temp_loss_list.append(loss.item())
                    temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))

                    # Perform backward pass
                    loss.backward()

                    # Perform optimization
                    self.optimizer.step()

                epoch_training_accuracy_list.append(np.mean(temp_acc_list).item())
                epoch_training_loss_list.append(np.mean(temp_loss_list).item())
                self.validate()
                epoch_validation_accuracy_list.append(self.validation_accuracy)
                epoch_validation_loss_list.append(self.validation_loss)
                gc.collect()
            models.append(self.model)
            weights.append(copy.deepcopy(self.model.state_dict()))
            accuracies.append(self.validation_accuracy)
            train_dataloaders.append(self.train_dataloader)
            validate_dataloaders.append(self.validate_dataloader)
            self.training_accuracy_list.append(epoch_training_accuracy_list)
            self.training_loss_list.append(epoch_training_loss_list)
            self.validation_accuracy_list.append(epoch_validation_accuracy_list)
            self.validation_loss_list.append(epoch_validation_loss_list)
        self.accuracies = accuracies
        max_accuracy = max(accuracies)
        self.max_accuracy = max_accuracy
        ind = accuracies.index(max_accuracy)
        self.model = models[ind]
        self.model.load_state_dict(weights[ind])
        self.model.eval()
        self.training_accuracy_list = self.training_accuracy_list[ind]
        self.training_loss_list = self.training_loss_list[ind]
        self.validation_accuracy_list = self.validation_accuracy_list[ind]
        self.validation_loss_list = self.validation_loss_list[ind]

    def validate(self):
        temp_loss_list = []
        temp_acc_list = []
        self.model.eval()
        for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
            data = data.to(self.device)
            pred = self.model(data)
            if self.hparams.loss_function.lower() == "negative log likelihood":
                logp = F.log_softmax(pred, 1)
                loss = F.nll_loss(logp, data.y)
            elif self.hparams.loss_function.lower() == "cross entropy":
                loss = F.cross_entropy(pred, data.y)
            temp_loss_list.append(loss.item())
            temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
        self.validation_accuracy = np.mean(temp_acc_list).item()
        self.validation_loss = np.mean(temp_loss_list).item()

    def test(self):
        if self.testingDataset:
            self.test_dataloader = DataLoader(self.testingDataset,
                                              batch_size=len(self.testingDataset),
                                              drop_last=False)
        temp_loss_list = []
        temp_acc_list = []
        self.model.eval()
        for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
            data = data.to(self.device)
            pred = self.model(data)
            if self.hparams.loss_function.lower() == "negative log likelihood":
                logp = F.log_softmax(pred, 1)
                loss = F.nll_loss(logp, data.y)
            elif self.hparams.loss_function.lower() == "cross entropy":
                loss = F.cross_entropy(pred, data.y)
            temp_loss_list.append(loss.item())
            temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
        self.testing_accuracy = np.mean(temp_acc_list).item()
        self.testing_loss = np.mean(temp_loss_list).item()

    def save(self, path):
        if path:
            # Make sure the file extension is .pt
            ext = path[-3:]
            if ext.lower() != ".pt":
                path = path + ".pt"
            torch.save(self.model.state_dict(), path)

    def load(self, path):
        #self.model.load_state_dict(torch.load(path))
        self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))

class _GraphClassifierHoldout:
    def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.trainingDataset = trainingDataset
        self.validationDataset = validationDataset
        self.testingDataset = testingDataset
        self.hparams = hparams
        gclasses = trainingDataset.num_classes
        nfeats = trainingDataset.num_node_features

        if hparams.conv_layer_type.lower() == 'sageconv':
            self.model = _SAGEConv(nfeats, hparams.hl_widths, gclasses, hparams.pooling).to(self.device)
        else:
            raise NotImplementedError

        if hparams.optimizer_str.lower() == "adadelta":
            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
                                                  lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adagrad":
            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
                                                 lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
                                              lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
        self.use_gpu = hparams.use_gpu
        self.training_loss_list = []
        self.validation_loss_list = []
        self.training_accuracy_list = []
        self.validation_accuracy_list = []
        self.node_attr_key = trainingDataset[0].x.shape[1]

        # train, validate, test split
        num_train = int(len(trainingDataset) * hparams.split[0])
        num_validate = int(len(trainingDataset) * hparams.split[1])
        num_test = len(trainingDataset) - num_train - num_validate
        idx = torch.randperm(len(trainingDataset))
        train_sampler = SubsetRandomSampler(idx[:num_train])
        validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
        test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])

        if validationDataset:
            self.train_dataloader = DataLoader(trainingDataset, batch_size=hparams.batch_size, drop_last=False)
            self.validate_dataloader = DataLoader(validationDataset, batch_size=hparams.batch_size, drop_last=False)
        else:
            self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler, batch_size=hparams.batch_size, drop_last=False)
            self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler, batch_size=hparams.batch_size, drop_last=False)

        if testingDataset:
            self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
        else:
            self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)

    def train(self):
        # Init the loss and accuracy reporting lists
        self.training_accuracy_list = []
        self.training_loss_list = []
        self.validation_accuracy_list = []
        self.validation_loss_list = []

        # Run the training loop for defined number of epochs
        for _ in tqdm(range(self.hparams.epochs), desc='Epochs', initial=1, leave=False):
            temp_loss_list = []
            temp_acc_list = []
            # Make sure the model is in training mode
            self.model.train()
            # Iterate over the DataLoader for training data
            for data in tqdm(self.train_dataloader, desc='Training', leave=False):
                data = data.to(self.device)

                # Zero the gradients
                self.optimizer.zero_grad()

                # Perform forward pass
                pred = self.model(data)

                # Compute loss
                if self.hparams.loss_function.lower() == "negative log likelihood":
                    logp = F.log_softmax(pred, 1)
                    loss = F.nll_loss(logp, data.y)
                elif self.hparams.loss_function.lower() == "cross entropy":
                    loss = F.cross_entropy(pred, data.y)

                # Save loss information for reporting
                temp_loss_list.append(loss.item())
                temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))

                # Perform backward pass
                loss.backward()

                # Perform optimization
                self.optimizer.step()

            self.training_accuracy_list.append(np.mean(temp_acc_list).item())
            self.training_loss_list.append(np.mean(temp_loss_list).item())
            self.validate()
            self.validation_accuracy_list.append(self.validation_accuracy)
            self.validation_loss_list.append(self.validation_loss)
            gc.collect()

    def validate(self):
        temp_loss_list = []
        temp_acc_list = []
        self.model.eval()
        for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
            data = data.to(self.device)
            pred = self.model(data)
            if self.hparams.loss_function.lower() == "negative log likelihood":
                logp = F.log_softmax(pred, 1)
                loss = F.nll_loss(logp, data.y)
            elif self.hparams.loss_function.lower() == "cross entropy":
                loss = F.cross_entropy(pred, data.y)
            temp_loss_list.append(loss.item())
            temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
        self.validation_accuracy = np.mean(temp_acc_list).item()
        self.validation_loss = np.mean(temp_loss_list).item()

    def test(self):
        if self.test_dataloader:
            temp_loss_list = []
            temp_acc_list = []
            self.model.eval()
            for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
                data = data.to(self.device)
                pred = self.model(data)
                if self.hparams.loss_function.lower() == "negative log likelihood":
                    logp = F.log_softmax(pred, 1)
                    loss = F.nll_loss(logp, data.y)
                elif self.hparams.loss_function.lower() == "cross entropy":
                    loss = F.cross_entropy(pred, data.y)
                temp_loss_list.append(loss.item())
                temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
            self.testing_accuracy = np.mean(temp_acc_list).item()
            self.testing_loss = np.mean(temp_loss_list).item()

    def save(self, path):
        if path:
            # Make sure the file extension is .pt
            ext = path[-3:]
            if ext.lower() != ".pt":
                path = path + ".pt"
            torch.save(self.model.state_dict(), path)

    def load(self, path):
        #self.model.load_state_dict(torch.load(path))
        self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))

class _NodeClassifierHoldout:
    def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.trainingDataset = trainingDataset
        self.validationDataset = validationDataset
        self.testingDataset = testingDataset
        self.hparams = hparams
        gclasses = trainingDataset.num_classes
        nfeats = trainingDataset.num_node_features

        if hparams.conv_layer_type.lower() == 'sageconv':
            # pooling is set None for Node classifier
            self.model = _SAGEConv(nfeats, hparams.hl_widths, gclasses, None).to(self.device)
        else:
            raise NotImplementedError

        if hparams.optimizer_str.lower() == "adadelta":
            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
                                                  lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adagrad":
            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
                                                 lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
                                              lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
        self.use_gpu = hparams.use_gpu
        self.training_loss_list = []
        self.validation_loss_list = []
        self.training_accuracy_list = []
        self.validation_accuracy_list = []
        self.node_attr_key = trainingDataset[0].x.shape[1]

        # train, validate, test split
        num_train = int(len(trainingDataset) * hparams.split[0])
        num_validate = int(len(trainingDataset) * hparams.split[1])
        num_test = len(trainingDataset) - num_train - num_validate
        idx = torch.randperm(len(trainingDataset))
        train_sampler = SubsetRandomSampler(idx[:num_train])
        validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
        test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])

        if validationDataset:
            self.train_dataloader = DataLoader(trainingDataset, batch_size=hparams.batch_size, drop_last=False)
            self.validate_dataloader = DataLoader(validationDataset, batch_size=hparams.batch_size, drop_last=False)
        else:
            self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler, batch_size=hparams.batch_size, drop_last=False)
            self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler, batch_size=hparams.batch_size, drop_last=False)

        if testingDataset:
            self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
        else:
            self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)

    def train(self):
        # Init the loss and accuracy reporting lists
        self.training_accuracy_list = []
        self.training_loss_list = []
        self.validation_accuracy_list = []
        self.validation_loss_list = []

        # Run the training loop for defined number of epochs
        for _ in tqdm(range(self.hparams.epochs), desc='Epochs', initial=1, leave=False):
            temp_loss_list = []
            temp_acc_list = []
            # Iterate over the DataLoader for training data
            for data in tqdm(self.train_dataloader, desc='Training', leave=False):
                data = data.to(self.device)
                # Make sure the model is in training mode
                self.model.train()

                # Zero the gradients
                self.optimizer.zero_grad()

                # Perform forward pass
                pred = self.model(data)

                # Compute loss
                if self.hparams.loss_function.lower() == "negative log likelihood":
                    logp = F.log_softmax(pred, 1)
                    loss = F.nll_loss(logp, data.y)
                elif self.hparams.loss_function.lower() == "cross entropy":
                    loss = F.cross_entropy(pred, data.y)

                # Save loss information for reporting
                temp_loss_list.append(loss.item())
                temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))

                # Perform backward pass
                loss.backward()

                # Perform optimization
                self.optimizer.step()

            self.training_accuracy_list.append(np.mean(temp_acc_list).item())
            self.training_loss_list.append(np.mean(temp_loss_list).item())
            self.validate()
            self.validation_accuracy_list.append(self.validation_accuracy)
            self.validation_loss_list.append(self.validation_loss)
            gc.collect()

    def validate(self):
        temp_loss_list = []
        temp_acc_list = []
        self.model.eval()
        for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
            data = data.to(self.device)
            pred = self.model(data)
            if self.hparams.loss_function.lower() == "negative log likelihood":
                logp = F.log_softmax(pred, 1)
                loss = F.nll_loss(logp, data.y)
            elif self.hparams.loss_function.lower() == "cross entropy":
                loss = F.cross_entropy(pred, data.y)
            temp_loss_list.append(loss.item())
            temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
        self.validation_accuracy = np.mean(temp_acc_list).item()
        self.validation_loss = np.mean(temp_loss_list).item()

    def test(self):
        if self.test_dataloader:
            temp_loss_list = []
            temp_acc_list = []
            self.model.eval()
            for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
                data = data.to(self.device)
                pred = self.model(data)
                if self.hparams.loss_function.lower() == "negative log likelihood":
                    logp = F.log_softmax(pred, 1)
                    loss = F.nll_loss(logp, data.y)
                elif self.hparams.loss_function.lower() == "cross entropy":
                    loss = F.cross_entropy(pred, data.y)
                temp_loss_list.append(loss.item())
                temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
            self.testing_accuracy = np.mean(temp_acc_list).item()
            self.testing_loss = np.mean(temp_loss_list).item()

    def save(self, path):
        if path:
            # Make sure the file extension is .pt
            ext = path[-3:]
            if ext.lower() != ".pt":
                path = path + ".pt"
            torch.save(self.model.state_dict(), path)

    def load(self, path):
        #self.model.load_state_dict(torch.load(path))
        self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))

class _NodeRegressorHoldout:
    def __init__(self, hparams, trainingDataset, validationDataset=None, testingDataset=None):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.trainingDataset = trainingDataset
        self.validationDataset = validationDataset
        self.testingDataset = testingDataset
        self.hparams = hparams
        if hparams.conv_layer_type.lower() == 'sageconv':
            # pooling is set None for Node regressor
            self.model = _SAGEConv(trainingDataset[0].num_node_features, hparams.hl_widths, 1, None).to(self.device)
        else:
            raise NotImplementedError

        if hparams.optimizer_str.lower() == "adadelta":
            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
                                                  lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adagrad":
            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
                                                 lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
                                              lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)

        self.use_gpu = hparams.use_gpu
        self.training_loss_list = []
        self.validation_loss_list = []
        self.node_attr_key = trainingDataset[0].x.shape[1]

        # Train, validate, test split
        num_train = int(len(trainingDataset) * hparams.split[0])
        num_validate = int(len(trainingDataset) * hparams.split[1])
        num_test = len(trainingDataset) - num_train - num_validate
        idx = torch.randperm(len(trainingDataset))
        train_sampler = SubsetRandomSampler(idx[:num_train])
        validate_sampler = SubsetRandomSampler(idx[num_train:num_train+num_validate])
        test_sampler = SubsetRandomSampler(idx[num_train+num_validate:])

        if validationDataset:
            self.train_dataloader = DataLoader(trainingDataset,
                                               batch_size=hparams.batch_size,
                                               drop_last=False)
            self.validate_dataloader = DataLoader(validationDataset,
                                                  batch_size=hparams.batch_size,
                                                  drop_last=False)
        else:
            self.train_dataloader = DataLoader(trainingDataset, sampler=train_sampler,
                                               batch_size=hparams.batch_size,
                                               drop_last=False)
            self.validate_dataloader = DataLoader(trainingDataset, sampler=validate_sampler,
                                                  batch_size=hparams.batch_size,
                                                  drop_last=False)

        if testingDataset:
            self.test_dataloader = DataLoader(testingDataset,
                                              batch_size=len(testingDataset),
                                              drop_last=False)
        else:
            self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler,
                                              batch_size=hparams.batch_size,
                                              drop_last=False)

    def train(self):
        # Init the loss and accuracy reporting lists
        self.training_loss_list = []
        self.validation_loss_list = []

        # Run the training loop for defined number of epochs
        for _ in tqdm(range(self.hparams.epochs), desc='Epochs', total=self.hparams.epochs, leave=False):
            # Iterate over the DataLoader for training data
            for data in tqdm(self.train_dataloader, desc='Training', leave=False):
                data = data.to(self.device)
                # Make sure the model is in training mode
                self.model.train()
                # Zero the gradients
                self.optimizer.zero_grad()

                # Perform forward pass
                pred = self.model(data).to(self.device)
                # Compute loss
                loss = F.mse_loss(torch.flatten(pred), data.y.float())

                # Perform backward pass
                loss.backward()

                # Perform optimization
                self.optimizer.step()

            self.training_loss_list.append(torch.sqrt(loss).item())
            self.validate()
            self.validation_loss_list.append(torch.sqrt(self.validation_loss).item())
            gc.collect()

    def validate(self):
        self.model.eval()
        for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
            data = data.to(self.device)
            pred = self.model(data).to(self.device)
            loss = F.mse_loss(torch.flatten(pred), data.y.float())
        self.validation_loss = loss

    def test(self):
        self.model.eval()
        for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
            data = data.to(self.device)
            pred = self.model(data).to(self.device)
            loss = F.mse_loss(torch.flatten(pred), data.y.float())
        self.testing_loss = torch.sqrt(loss).item()

    def save(self, path):
        if path:
            # Make sure the file extension is .pt
            ext = path[-3:]
            if ext.lower() != ".pt":
                path = path + ".pt"
            torch.save(self.model.state_dict(), path)

    def load(self, path):
        #self.model.load_state_dict(torch.load(path))
        self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))

class _NodeClassifierKFold:
    def __init__(self, hparams, trainingDataset, testingDataset=None):
        self.trainingDataset = trainingDataset
        self.testingDataset = testingDataset
        self.hparams = hparams
        self.testing_accuracy = 0
        self.accuracies = []
        self.max_accuracy = 0
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        if hparams.conv_layer_type.lower() == 'sageconv':
            # pooling is set None for Node classifier
            self.model = _SAGEConv(trainingDataset.num_node_features, hparams.hl_widths,
                                   trainingDataset.num_classes, None).to(self.device)
        else:
            raise NotImplementedError

        if hparams.optimizer_str.lower() == "adadelta":
            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps,
                                                  lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adagrad":
            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps,
                                                 lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
        elif hparams.optimizer_str.lower() == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps,
                                              lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
        self.use_gpu = hparams.use_gpu
        self.training_loss_list = []
        self.validation_loss_list = []
        self.training_accuracy_list = []
        self.validation_accuracy_list = []

    def reset_weights(self):
        if self.hparams.conv_layer_type.lower() == 'sageconv':
            # pooling is set None for Node classifier
            self.model = _SAGEConv(self.trainingDataset.num_node_features, self.hparams.hl_widths,
                                   self.trainingDataset.num_classes, None).to(self.device)
        else:
            raise NotImplementedError

        if self.hparams.optimizer_str.lower() == "adadelta":
            self.optimizer = torch.optim.Adadelta(self.model.parameters(), eps=self.hparams.eps,
                                                  lr=self.hparams.lr, rho=self.hparams.rho, weight_decay=self.hparams.weight_decay)
        elif self.hparams.optimizer_str.lower() == "adagrad":
            self.optimizer = torch.optim.Adagrad(self.model.parameters(), eps=self.hparams.eps,
                                                 lr=self.hparams.lr, lr_decay=self.hparams.lr_decay, weight_decay=self.hparams.weight_decay)
        elif self.hparams.optimizer_str.lower() == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), amsgrad=self.hparams.amsgrad, betas=self.hparams.betas, eps=self.hparams.eps,
                                              lr=self.hparams.lr, maximize=self.hparams.maximize, weight_decay=self.hparams.weight_decay)

    def train(self):
        k_folds = self.hparams.k_folds

        # Init the loss and accuracy reporting lists
        self.training_accuracy_list = []
        self.training_loss_list = []
        self.validation_accuracy_list = []
        self.validation_loss_list = []

        # Set fixed random number seed
        torch.manual_seed(42)

        # Define the K-fold Cross Validator
        kfold = KFold(n_splits=k_folds, shuffle=True)

        models = []
        weights = []
        accuracies = []
        train_dataloaders = []
        validate_dataloaders = []

        # K-fold Cross-validation model evaluation
        for fold, (train_ids, validate_ids) in tqdm(enumerate(kfold.split(self.trainingDataset)), desc="Fold", initial=1, total=k_folds, leave=False):
            epoch_training_loss_list = []
            epoch_training_accuracy_list = []
            epoch_validation_loss_list = []
            epoch_validation_accuracy_list = []
            # Sample elements randomly from a given list of ids, no replacement.
            train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
            validate_subsampler = torch.utils.data.SubsetRandomSampler(validate_ids)

            # Define data loaders for training and testing data in this fold
            self.train_dataloader = DataLoader(self.trainingDataset, sampler=train_subsampler,
                                               batch_size=self.hparams.batch_size,
                                               drop_last=False)
            self.validate_dataloader = DataLoader(self.trainingDataset, sampler=validate_subsampler,
                                                  batch_size=self.hparams.batch_size,
                                                  drop_last=False)
            # Init the neural network
            self.reset_weights()

            # Run the training loop for defined number of epochs
            for _ in tqdm(range(0, self.hparams.epochs), desc='Epochs', initial=1, total=self.hparams.epochs, leave=False):
                temp_loss_list = []
                temp_acc_list = []

                # Iterate over the DataLoader for training data
                for data in tqdm(self.train_dataloader, desc='Training', leave=False):
                    data = data.to(self.device)
                    # Make sure the model is in training mode
                    self.model.train()

                    # Zero the gradients
                    self.optimizer.zero_grad()

                    # Perform forward pass
                    pred = self.model(data)

                    # Compute loss
                    if self.hparams.loss_function.lower() == "negative log likelihood":
                        logp = F.log_softmax(pred, 1)
                        loss = F.nll_loss(logp, data.y)
                    elif self.hparams.loss_function.lower() == "cross entropy":
                        loss = F.cross_entropy(pred, data.y)

                    # Save loss information for reporting
                    temp_loss_list.append(loss.item())
                    temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))

                    # Perform backward pass
                    loss.backward()

                    # Perform optimization
                    self.optimizer.step()

                epoch_training_accuracy_list.append(np.mean(temp_acc_list).item())
                epoch_training_loss_list.append(np.mean(temp_loss_list).item())
                self.validate()
                epoch_validation_accuracy_list.append(self.validation_accuracy)
                epoch_validation_loss_list.append(self.validation_loss)
                gc.collect()
            models.append(self.model)
            weights.append(copy.deepcopy(self.model.state_dict()))
            accuracies.append(self.validation_accuracy)
            train_dataloaders.append(self.train_dataloader)
            validate_dataloaders.append(self.validate_dataloader)
            self.training_accuracy_list.append(epoch_training_accuracy_list)
            self.training_loss_list.append(epoch_training_loss_list)
            self.validation_accuracy_list.append(epoch_validation_accuracy_list)
            self.validation_loss_list.append(epoch_validation_loss_list)
        self.accuracies = accuracies
        max_accuracy = max(accuracies)
        self.max_accuracy = max_accuracy
        ind = accuracies.index(max_accuracy)
        self.model = models[ind]
        self.model.load_state_dict(weights[ind])
        self.model.eval()
        self.training_accuracy_list = self.training_accuracy_list[ind]
        self.training_loss_list = self.training_loss_list[ind]
        self.validation_accuracy_list = self.validation_accuracy_list[ind]
        self.validation_loss_list = self.validation_loss_list[ind]

    def validate(self):
        temp_loss_list = []
        temp_acc_list = []
        self.model.eval()
        for data in tqdm(self.validate_dataloader, desc='Validating', leave=False):
            data = data.to(self.device)
            pred = self.model(data)
            if self.hparams.loss_function.lower() == "negative log likelihood":
                logp = F.log_softmax(pred, 1)
                loss = F.nll_loss(logp, data.y)
            elif self.hparams.loss_function.lower() == "cross entropy":
                loss = F.cross_entropy(pred, data.y)
            temp_loss_list.append(loss.item())
            temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
        self.validation_accuracy = np.mean(temp_acc_list).item()
        self.validation_loss = np.mean(temp_loss_list).item()

    def test(self):
        if self.testingDataset:
            self.test_dataloader = DataLoader(self.testingDataset,
                                              batch_size=len(self.testingDataset),
                                              drop_last=False)
        temp_loss_list = []
        temp_acc_list = []
        self.model.eval()
        for data in tqdm(self.test_dataloader, desc='Testing', leave=False):
            data = data.to(self.device)
            pred = self.model(data)
            if self.hparams.loss_function.lower() == "negative log likelihood":
|
1240
|
+
logp = F.log_softmax(pred, 1)
|
1241
|
+
loss = F.nll_loss(logp, data.y)
|
1242
|
+
elif self.hparams.loss_function.lower() == "cross entropy":
|
1243
|
+
loss = F.cross_entropy(pred, data.y)
|
1244
|
+
temp_loss_list.append(loss.item())
|
1245
|
+
temp_acc_list.append(accuracy_score(data.y.cpu(), pred.argmax(1).cpu()))
|
1246
|
+
self.testing_accuracy = np.mean(temp_acc_list).item()
|
1247
|
+
self.testing_loss = np.mean(temp_loss_list).item()
|
1248
|
+
|
1249
|
+
def save(self, path):
|
1250
|
+
if path:
|
1251
|
+
# Make sure the file extension is .pt
|
1252
|
+
ext = path[-3:]
|
1253
|
+
if ext.lower() != ".pt":
|
1254
|
+
path = path + ".pt"
|
1255
|
+
torch.save(self.model.state_dict(), path)
|
1256
|
+
|
1257
|
+
def load(self, path):
|
1258
|
+
#self.model.load_state_dict(torch.load(path))
|
1259
|
+
self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
|
1260
|
+
|
1261
|
+
class _NodeRegressorKFold:
|
1262
|
+
def __init__(self, hparams, trainingDataset, testingDataset=None):
|
1263
|
+
self.trainingDataset = trainingDataset
|
1264
|
+
self.testingDataset = testingDataset
|
1265
|
+
self.hparams = hparams
|
1266
|
+
self.losses = []
|
1267
|
+
self.min_loss = 0
|
1268
|
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1269
|
+
|
1270
|
+
self.model = self._initialize_model(hparams, trainingDataset)
|
1271
|
+
self.optimizer = self._initialize_optimizer(hparams)
|
1272
|
+
|
1273
|
+
self.use_gpu = hparams.use_gpu
|
1274
|
+
self.training_loss_list = []
|
1275
|
+
self.validation_loss_list = []
|
1276
|
+
self.node_attr_key = trainingDataset.node_attr_key
|
1277
|
+
|
1278
|
+
# Train, validate, test split
|
1279
|
+
num_train = int(len(trainingDataset) * hparams.split[0])
|
1280
|
+
num_validate = int(len(trainingDataset) * hparams.split[1])
|
1281
|
+
num_test = len(trainingDataset) - num_train - num_validate
|
1282
|
+
idx = torch.randperm(len(trainingDataset))
|
1283
|
+
test_sampler = SubsetRandomSampler(idx[num_train+num_validate:num_train+num_validate+num_test])
|
1284
|
+
|
1285
|
+
if testingDataset:
|
1286
|
+
self.test_dataloader = DataLoader(testingDataset, batch_size=len(testingDataset), drop_last=False)
|
1287
|
+
else:
|
1288
|
+
self.test_dataloader = DataLoader(trainingDataset, sampler=test_sampler, batch_size=hparams.batch_size, drop_last=False)
|
1289
|
+
|
1290
|
+
def _initialize_model(self, hparams, dataset):
|
1291
|
+
if hparams.conv_layer_type.lower() == 'sageconv':
|
1292
|
+
# pooling is set to None for the node regressor
|
1293
|
+
return _SAGEConv(dataset.num_node_features, hparams.hl_widths, 1, None).to(self.device)
|
1294
|
+
else:
|
1295
|
+
raise NotImplementedError
|
1296
|
+
|
1297
|
+
def _initialize_optimizer(self, hparams):
|
1298
|
+
if hparams.optimizer_str.lower() == "adadelta":
|
1299
|
+
return torch.optim.Adadelta(self.model.parameters(), eps=hparams.eps, lr=hparams.lr, rho=hparams.rho, weight_decay=hparams.weight_decay)
|
1300
|
+
elif hparams.optimizer_str.lower() == "adagrad":
|
1301
|
+
return torch.optim.Adagrad(self.model.parameters(), eps=hparams.eps, lr=hparams.lr, lr_decay=hparams.lr_decay, weight_decay=hparams.weight_decay)
|
1302
|
+
elif hparams.optimizer_str.lower() == "adam":
|
1303
|
+
return torch.optim.Adam(self.model.parameters(), amsgrad=hparams.amsgrad, betas=hparams.betas, eps=hparams.eps, lr=hparams.lr, maximize=hparams.maximize, weight_decay=hparams.weight_decay)
|
1304
|
+
|
1305
|
+
def reset_weights(self):
|
1306
|
+
self.model = self._initialize_model(self.hparams, self.trainingDataset)
|
1307
|
+
self.optimizer = self._initialize_optimizer(self.hparams)
|
1308
|
+
|
1309
|
+
def train(self):
|
1310
|
+
k_folds = self.hparams.k_folds
|
1311
|
+
torch.manual_seed(42)
|
1312
|
+
|
1313
|
+
kfold = KFold(n_splits=k_folds, shuffle=True)
|
1314
|
+
models, weights, losses, train_dataloaders, validate_dataloaders = [], [], [], [], []
|
1315
|
+
|
1316
|
+
for fold, (train_ids, validate_ids) in tqdm(enumerate(kfold.split(self.trainingDataset)), desc="Fold", total=k_folds, leave=False):
|
1317
|
+
epoch_training_loss_list, epoch_validation_loss_list = [], []
|
1318
|
+
train_subsampler = SubsetRandomSampler(train_ids)
|
1319
|
+
validate_subsampler = SubsetRandomSampler(validate_ids)
|
1320
|
+
|
1321
|
+
self.train_dataloader = DataLoader(self.trainingDataset, sampler=train_subsampler, batch_size=self.hparams.batch_size, drop_last=False)
|
1322
|
+
self.validate_dataloader = DataLoader(self.trainingDataset, sampler=validate_subsampler, batch_size=self.hparams.batch_size, drop_last=False)
|
1323
|
+
|
1324
|
+
self.reset_weights()
|
1325
|
+
best_rmse = np.inf
|
1326
|
+
|
1327
|
+
for _ in tqdm(range(self.hparams.epochs), desc='Epochs', total=self.hparams.epochs, leave=False):
|
1328
|
+
for batched_graph in tqdm(self.train_dataloader, desc='Training', leave=False):
|
1329
|
+
self.model.train()
|
1330
|
+
self.optimizer.zero_grad()
|
1331
|
+
|
1332
|
+
batched_graph = batched_graph.to(self.device)
|
1333
|
+
pred = self.model(batched_graph)
|
1334
|
+
loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
|
1335
|
+
loss.backward()
|
1336
|
+
self.optimizer.step()
|
1337
|
+
|
1338
|
+
epoch_training_loss_list.append(torch.sqrt(loss).item())
|
1339
|
+
self.validate()
|
1340
|
+
epoch_validation_loss_list.append(torch.sqrt(self.validation_loss).item())
|
1341
|
+
gc.collect()
|
1342
|
+
|
1343
|
+
models.append(self.model)
|
1344
|
+
weights.append(copy.deepcopy(self.model.state_dict()))
|
1345
|
+
losses.append(torch.sqrt(self.validation_loss).item())
|
1346
|
+
train_dataloaders.append(self.train_dataloader)
|
1347
|
+
validate_dataloaders.append(self.validate_dataloader)
|
1348
|
+
self.training_loss_list.append(epoch_training_loss_list)
|
1349
|
+
self.validation_loss_list.append(epoch_validation_loss_list)
|
1350
|
+
|
1351
|
+
self.losses = losses
|
1352
|
+
self.min_loss = min(losses)
|
1353
|
+
ind = losses.index(self.min_loss)
|
1354
|
+
self.model = models[ind]
|
1355
|
+
self.model.load_state_dict(weights[ind])
|
1356
|
+
self.model.eval()
|
1357
|
+
self.training_loss_list = self.training_loss_list[ind]
|
1358
|
+
self.validation_loss_list = self.validation_loss_list[ind]
|
1359
|
+
|
1360
|
+
def validate(self):
|
1361
|
+
self.model.eval()
|
1362
|
+
for batched_graph in tqdm(self.validate_dataloader, desc='Validating', leave=False):
|
1363
|
+
batched_graph = batched_graph.to(self.device)
|
1364
|
+
pred = self.model(batched_graph)
|
1365
|
+
loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
|
1366
|
+
self.validation_loss = loss
|
1367
|
+
|
1368
|
+
def test(self):
|
1369
|
+
self.model.eval()
|
1370
|
+
for batched_graph in tqdm(self.test_dataloader, desc='Testing', leave=False):
|
1371
|
+
batched_graph = batched_graph.to(self.device)
|
1372
|
+
pred = self.model(batched_graph)
|
1373
|
+
loss = F.mse_loss(torch.flatten(pred), batched_graph.y.float())
|
1374
|
+
self.testing_loss = torch.sqrt(loss).item()
|
1375
|
+
|
1376
|
+
def save(self, path):
|
1377
|
+
if path:
|
1378
|
+
ext = path[-3:]
|
1379
|
+
if ext.lower() != ".pt":
|
1380
|
+
path = path + ".pt"
|
1381
|
+
torch.save(self.model.state_dict(), path)
|
1382
|
+
|
1383
|
+
def load(self, path):
|
1384
|
+
self.model.load_state_dict(torch.load(path, weights_only=True, map_location=self.device))
|
1385
|
+
|
1386
|
+
class PyG:
|
1387
|
+
@staticmethod
|
1388
|
+
def DatasetByCSVPath(path, numberOfGraphClasses=0, nodeATTRKey='feat', edgeATTRKey='feat', nodeOneHotEncode=False,
|
1389
|
+
nodeFeaturesCategories=[], edgeOneHotEncode=False, edgeFeaturesCategories=[], addSelfLoop=False,
|
1390
|
+
node_level=False, graph_level=True):
|
1391
|
+
"""
|
1392
|
+
Returns a PyTorch Geometric dataset based on the input CSV folder path. The folder must contain "graphs.csv",
|
1393
|
+
"edges.csv", "nodes.csv", and "meta.yml" files according to conventions.
|
1394
|
+
|
1395
|
+
Parameters
|
1396
|
+
----------
|
1397
|
+
path : str
|
1398
|
+
The path to the folder containing the necessary CSV and YML files.
|
1399
|
+
|
1400
|
+
Returns
|
1401
|
+
-------
|
1402
|
+
PyG Dataset
|
1403
|
+
The PyG dataset
|
1404
|
+
"""
|
1405
|
+
if not isinstance(path, str):
|
1406
|
+
print("PyG.DatasetByCSVPath - Error: The input path parameter is not a valid string. Returning None.")
|
1407
|
+
return None
|
1408
|
+
if not os.path.exists(path):
|
1409
|
+
print("PyG.DatasetByCSVPath - Error: The input path parameter does not exist. Returning None.")
|
1410
|
+
return None
|
1411
|
+
|
1412
|
+
return CustomGraphDataset(root=path, node_level=node_level, graph_level=graph_level, node_attr_key=nodeATTRKey, edge_attr_key=edgeATTRKey)
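A minimal usage sketch for DatasetByCSVPath (illustrative, not from the diffed file; the import path follows the package layout and the folder path is a placeholder):

from topologicpy.PyG import PyG

# Load a graph-level dataset from a folder containing graphs.csv, nodes.csv, and edges.csv.
dataset = PyG.DatasetByCSVPath("path/to/dataset_folder", graph_level=True, node_level=False)
print(len(dataset))  # number of graphs in the dataset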
|
1413
|
+
|
1414
|
+
@staticmethod
|
1415
|
+
def Optimizer(name="Adam", amsgrad=True, betas=(0.9,0.999), eps=0.000001, lr=0.001, maximize=False, weightDecay=0.0, rho=0.9, lr_decay=0.0):
|
1416
|
+
"""
|
1417
|
+
Returns the parameters of the optimizer as a dictionary.
|
1418
|
+
|
1419
|
+
Parameters
|
1420
|
+
----------
|
1421
|
+
amsgrad : bool , optional
|
1422
|
+
amsgrad is an extension to the Adam version of gradient descent that attempts to improve the convergence properties of the algorithm, avoiding large abrupt changes in the learning rate for each input variable. The default is True.
|
1423
|
+
betas : tuple , optional
|
1424
|
+
Betas are used for smoothing the path to convergence and provide some momentum to help cross a local minimum or saddle point. The default is (0.9, 0.999).
|
1425
|
+
eps : float , optional
|
1426
|
+
eps is a term added to the denominator to improve numerical stability. The default is 0.000001.
|
1427
|
+
lr : float
|
1428
|
+
The learning rate (lr) defines the size of the adjustment applied to the network weights with respect to the loss gradient. The default is 0.001.
|
1429
|
+
maximize : bool , optional
|
1430
|
+
If set to True, the objective is maximized with respect to the params instead of minimized. The default is False.
|
1431
|
+
weightDecay : float , optional
|
1432
|
+
weightDecay (L2 penalty) is a regularization technique applied to the weights of a neural network. The default is 0.0.
|
1433
|
+
|
1434
|
+
Returns
|
1435
|
+
-------
|
1436
|
+
dict
|
1437
|
+
The dictionary of the optimizer parameters. The dictionary contains the following keys and values:
|
1438
|
+
- "name" (str): The name of the optimizer
|
1439
|
+
- "amsgrad" (bool):
|
1440
|
+
- "betas" (tuple):
|
1441
|
+
- "eps" (float):
|
1442
|
+
- "lr" (float):
|
1443
|
+
- "maximize" (bool):
|
1444
|
+
- "weight_decay" (float): The weight decay (L2 penalty)
- "rho" (float): The rho coefficient (Adadelta only)
- "lr_decay" (float): The learning rate decay (Adagrad only)
|
1445
|
+
|
1446
|
+
"""
|
1447
|
+
return {"name":name, "amsgrad":amsgrad, "betas":betas, "eps":eps, "lr": lr, "maximize":maximize, "weight_decay":weightDecay, "rho":rho, "lr_decay":lr_decay}
|
1448
|
+
|
1449
|
+
@staticmethod
|
1450
|
+
def Hyperparameters(optimizer, model_type="classifier", cv_type="Holdout", split=[0.8,0.1,0.1], k_folds=5,
|
1451
|
+
hl_widths=[32], conv_layer_type="SAGEConv", pooling="AvgPooling",
|
1452
|
+
batch_size=1, epochs=1, use_gpu=False, loss_function="Cross Entropy",
|
1453
|
+
input_type="graph"):
|
1454
|
+
"""
|
1455
|
+
Creates a hyperparameters object based on the input settings.
|
1456
|
+
|
1457
|
+
Parameters
|
1458
|
+
----------
|
1459
|
+
model_type : str , optional
|
1460
|
+
The desired type of model. The options are:
|
1461
|
+
- "Classifier"
|
1462
|
+
- "Regressor"
|
1463
|
+
The option is case insensitive. The default is "classifierholdout"
|
1464
|
+
optimizer : Optimizer
|
1465
|
+
The desired optimizer.
|
1466
|
+
cv_type : str , optional
|
1467
|
+
The desired cross-validation method. This can be "Holdout" or "K-Fold". It is case-insensitive. The default is "Holdout".
|
1468
|
+
split : list , optional
|
1469
|
+
The desired split between training, validation, and testing. [0.8, 0.1, 0.1] means that 80% of the data is used for training, 10% is used for validation, and 10% is used for testing. The default is [0.8, 0.1, 0.1].
|
1470
|
+
k_folds : int , optional
|
1471
|
+
The desired number of k-folds. The default is 5.
|
1472
|
+
hl_widths : list , optional
|
1473
|
+
The list of hidden layer widths. A list of [16, 32, 16] means that the model will have 3 hidden layers with number of neurons in each being 16, 32, 16 respectively from input to output. The default is [32].
|
1474
|
+
conv_layer_type : str , optional
|
1475
|
+
The desired type of the convolution layer. The options are "Classic", "GraphConv", "GINConv", "SAGEConv", "TAGConv", "DGN". It is case insensitive. The default is "SAGEConv".
|
1476
|
+
pooling : str , optional
|
1477
|
+
The desired type of pooling. The options are "AvgPooling", "MaxPooling", or "SumPooling". It is case insensitive. The default is "AvgPooling".
|
1478
|
+
batch_size : int , optional
|
1479
|
+
The desired batch size. The default is 1.
|
1480
|
+
epochs : int , optional
|
1481
|
+
The desired number of epochs. The default is 1.
|
1482
|
+
use_gpu : bool , optional
|
1483
|
+
If set to True, the model will attempt to use the GPU. The default is False.
|
1484
|
+
loss_function : str , optional
|
1485
|
+
The desired loss function. The options are "Cross-Entropy" or "Negative Log Likelihood". It is case insensitive. The default is "Cross-Entropy".
|
1486
|
+
input_type : str , optional
|
1487
|
+
The input type of the model. The options are "graph", "node", or "edge"; only "graph" and "node" are currently handled by Model. The default is "graph".
|
1488
|
+
Returns
|
1489
|
+
-------
|
1490
|
+
Hyperparameters
|
1491
|
+
The created hyperparameters object.
|
1492
|
+
|
1493
|
+
"""
|
1494
|
+
|
1495
|
+
if optimizer['name'].lower() == "adadelta":
|
1496
|
+
optimizer_str = "Adadelta"
|
1497
|
+
elif optimizer['name'].lower() == "adagrad":
|
1498
|
+
optimizer_str = "Adagrad"
|
1499
|
+
elif optimizer['name'].lower() == "adam":
|
1500
|
+
optimizer_str = "Adam"
|
1501
|
+
return _Hparams(model_type,
|
1502
|
+
optimizer_str,
|
1503
|
+
optimizer['amsgrad'],
|
1504
|
+
optimizer['betas'],
|
1505
|
+
optimizer['eps'],
|
1506
|
+
optimizer['lr'],
|
1507
|
+
optimizer['lr_decay'],
|
1508
|
+
optimizer['maximize'],
|
1509
|
+
optimizer['rho'],
|
1510
|
+
optimizer['weight_decay'],
|
1511
|
+
cv_type,
|
1512
|
+
split,
|
1513
|
+
k_folds,
|
1514
|
+
hl_widths,
|
1515
|
+
conv_layer_type,
|
1516
|
+
pooling,
|
1517
|
+
batch_size,
|
1518
|
+
epochs,
|
1519
|
+
use_gpu,
|
1520
|
+
loss_function,
|
1521
|
+
input_type)
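A minimal usage sketch for Hyperparameters (illustrative, not from the diffed file; the chosen values are examples, not recommendations):

from topologicpy.PyG import PyG

optimizer = PyG.Optimizer(name="Adam", lr=0.001)
# Graph-level classifier trained with hold-out cross-validation for 20 epochs.
hparams = PyG.Hyperparameters(optimizer,
                              model_type="classifier",
                              cv_type="Holdout",
                              split=[0.8, 0.1, 0.1],
                              hl_widths=[32, 32],
                              conv_layer_type="SAGEConv",
                              batch_size=4,
                              epochs=20,
                              loss_function="Cross Entropy",
                              input_type="graph")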
|
1522
|
+
|
1523
|
+
@staticmethod
|
1524
|
+
def Model(hparams, trainingDataset, validationDataset=None, testingDataset=None):
|
1525
|
+
"""
|
1526
|
+
Creates a neural network model (classifier or regressor) based on the input hyperparameters.
|
1527
|
+
|
1528
|
+
Parameters
|
1529
|
+
----------
|
1530
|
+
hparams : HParams
|
1531
|
+
The input hyperparameters
|
1532
|
+
trainingDataset : PyG Dataset
|
1533
|
+
The input training dataset.
|
1534
|
+
validationDataset : PyG Dataset , optional
|
1535
|
+
The input validation dataset. If not specified, a portion of the trainingDataset will be used for validation according to the split list specified in the hyper-parameters.
|
1536
|
+
testingDataset : PyG Dataset , optional
|
1537
|
+
The input testing dataset. If not specified, a portion of the trainingDataset will be used for testing according to the split list specified in the hyper-parameters.
|
1538
|
+
|
1539
|
+
Returns
|
1540
|
+
-------
|
1541
|
+
Model
|
1542
|
+
The created model
|
1543
|
+
|
1544
|
+
"""
|
1545
|
+
|
1546
|
+
model = None
|
1547
|
+
if hparams.model_type.lower() == "classifier":
|
1548
|
+
if hparams.input_type == 'graph':
|
1549
|
+
if hparams.cv_type.lower() == "holdout":
|
1550
|
+
model = _GraphClassifierHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
|
1551
|
+
elif "k" in hparams.cv_type.lower():
|
1552
|
+
model = _GraphClassifierKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
|
1553
|
+
elif hparams.input_type == 'node':
|
1554
|
+
if hparams.cv_type.lower() == "holdout":
|
1555
|
+
model = _NodeClassifierHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
|
1556
|
+
elif "k" in hparams.cv_type.lower():
|
1557
|
+
model = _NodeClassifierKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
|
1558
|
+
elif hparams.model_type.lower() == "regressor":
|
1559
|
+
if hparams.input_type == 'graph':
|
1560
|
+
if hparams.cv_type.lower() == "holdout":
|
1561
|
+
model = _GraphRegressorHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
|
1562
|
+
elif "k" in hparams.cv_type.lower():
|
1563
|
+
model = _GraphRegressorKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
|
1564
|
+
elif hparams.input_type == 'node':
|
1565
|
+
if hparams.cv_type.lower() == "holdout":
|
1566
|
+
model = _NodeRegressorHoldout(hparams=hparams, trainingDataset=trainingDataset, validationDataset=validationDataset, testingDataset=testingDataset)
|
1567
|
+
elif "k" in hparams.cv_type.lower():
|
1568
|
+
model = _NodeRegressorKFold(hparams=hparams, trainingDataset=trainingDataset, testingDataset=testingDataset)
|
1569
|
+
else:
|
1570
|
+
raise NotImplementedError
|
1571
|
+
return model
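A minimal usage sketch for Model (illustrative, not from the diffed file; `hparams` and `dataset` are assumed to come from the Hyperparameters and DatasetByCSVPath sketches above):

from topologicpy.PyG import PyG

# Wraps a SAGEConv-based classifier or regressor according to the hyperparameters.
model = PyG.Model(hparams, trainingDataset=dataset)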
|
1572
|
+
|
1573
|
+
@staticmethod
|
1574
|
+
def ModelTrain(model):
|
1575
|
+
"""
|
1576
|
+
Trains the neural network model.
|
1577
|
+
|
1578
|
+
Parameters
|
1579
|
+
----------
|
1580
|
+
model : Model
|
1581
|
+
The input model.
|
1582
|
+
|
1583
|
+
Returns
|
1584
|
+
-------
|
1585
|
+
Model
|
1586
|
+
The trained model
|
1587
|
+
|
1588
|
+
"""
|
1589
|
+
if not model:
|
1590
|
+
return None
|
1591
|
+
model.train()
|
1592
|
+
return model
|
1593
|
+
|
1594
|
+
@staticmethod
|
1595
|
+
def ModelTest(model):
|
1596
|
+
"""
|
1597
|
+
Tests the neural network model.
|
1598
|
+
|
1599
|
+
Parameters
|
1600
|
+
----------
|
1601
|
+
model : Model
|
1602
|
+
The input model.
|
1603
|
+
|
1604
|
+
Returns
|
1605
|
+
-------
|
1606
|
+
Model
|
1607
|
+
The tested model
|
1608
|
+
|
1609
|
+
"""
|
1610
|
+
if not model:
|
1611
|
+
return None
|
1612
|
+
model.test()
|
1613
|
+
return model
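A minimal usage sketch for ModelTrain and ModelTest (illustrative, not from the diffed file; `model` is assumed to be the wrapper returned by PyG.Model):

from topologicpy.PyG import PyG

model = PyG.ModelTrain(model)   # runs the training loop defined by the hyperparameters
model = PyG.ModelTest(model)    # evaluates on the testing dataset or testing split
print(model.testing_accuracy)   # available on classifier wrappers after test()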
|
1614
|
+
|
1615
|
+
@staticmethod
|
1616
|
+
def ModelSave(model, path, overwrite=False):
|
1617
|
+
"""
|
1618
|
+
Saves the model.
|
1619
|
+
|
1620
|
+
Parameters
|
1621
|
+
----------
|
1622
|
+
model : Model
|
1623
|
+
The input model.
|
1624
|
+
path : str
|
1625
|
+
The file path at which to save the model.
|
1626
|
+
overwrite : bool, optional
|
1627
|
+
If set to True, any existing file will be overwritten. Otherwise, it won't. The default is False.
|
1628
|
+
|
1629
|
+
Returns
|
1630
|
+
-------
|
1631
|
+
bool
|
1632
|
+
True if the model is saved correctly. False otherwise.
|
1633
|
+
|
1634
|
+
"""
|
1635
|
+
import os
|
1636
|
+
|
1637
|
+
if model == None:
|
1638
|
+
print("DGL.ModelSave - Error: The input model parameter is invalid. Returning None.")
|
1639
|
+
return None
|
1640
|
+
if path == None:
|
1641
|
+
print("DGL.ModelSave - Error: The input path parameter is invalid. Returning None.")
|
1642
|
+
return None
|
1643
|
+
if not overwrite and os.path.exists(path):
|
1644
|
+
print("DGL.ModelSave - Error: a file already exists at the specified path and overwrite is set to False. Returning None.")
|
1645
|
+
return None
|
1646
|
+
if overwrite and os.path.exists(path):
|
1647
|
+
os.remove(path)
|
1648
|
+
# Make sure the file extension is .pt
|
1649
|
+
ext = path[len(path)-3:len(path)]
|
1650
|
+
if ext.lower() != ".pt":
|
1651
|
+
path = path+".pt"
|
1652
|
+
model.save(path)
|
1653
|
+
return True
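A minimal usage sketch for ModelSave (illustrative, not from the diffed file; the file path is a placeholder):

from topologicpy.PyG import PyG

# A ".pt" extension is appended automatically if the path does not already end with it.
status = PyG.ModelSave(model, "my_classifier", overwrite=True)
print(status)  # True if the model was saved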
|
1654
|
+
|
1655
|
+
@staticmethod
|
1656
|
+
def ModelData(model):
|
1657
|
+
"""
|
1658
|
+
Returns the data of the model
|
1659
|
+
|
1660
|
+
Parameters
|
1661
|
+
----------
|
1662
|
+
model : Model
|
1663
|
+
The input model.
|
1664
|
+
|
1665
|
+
Returns
|
1666
|
+
-------
|
1667
|
+
dict
|
1668
|
+
A dictionary containing the model data. The keys in the dictionary are:
|
1669
|
+
'Model Type'
|
1670
|
+
'Optimizer'
|
1671
|
+
'CV Type'
|
1672
|
+
'Split'
|
1673
|
+
'K-Folds'
|
1674
|
+
'HL Widths'
|
1675
|
+
'Conv Layer Type'
|
1676
|
+
'Pooling'
|
1677
|
+
'Learning Rate'
|
1678
|
+
'Batch Size'
|
1679
|
+
'Epochs'
|
1680
|
+
'Training Accuracy'
|
1681
|
+
'Validation Accuracy'
|
1682
|
+
'Testing Accuracy'
|
1683
|
+
'Training Loss'
|
1684
|
+
'Validation Loss'
|
1685
|
+
'Testing Loss'
|
1686
|
+
'Accuracies' (Classifier and K-Fold only)
|
1687
|
+
'Max Accuracy' (Classifier and K-Fold only)
|
1688
|
+
'Losses' (Regressor and K-fold only)
|
1689
|
+
'Min Loss' (Regressor and K-fold only)
|
1690
|
+
|
1691
|
+
"""
|
1692
|
+
from topologicpy.Helper import Helper
|
1693
|
+
|
1694
|
+
data = {'Model Type': [model.hparams.model_type],
|
1695
|
+
'Optimizer': [model.hparams.optimizer_str],
|
1696
|
+
'CV Type': [model.hparams.cv_type],
|
1697
|
+
'Split': model.hparams.split,
|
1698
|
+
'K-Folds': [model.hparams.k_folds],
|
1699
|
+
'HL Widths': model.hparams.hl_widths,
|
1700
|
+
'Conv Layer Type': [model.hparams.conv_layer_type],
|
1701
|
+
'Pooling': [model.hparams.pooling],
|
1702
|
+
'Learning Rate': [model.hparams.lr],
|
1703
|
+
'Batch Size': [model.hparams.batch_size],
|
1704
|
+
'Epochs': [model.hparams.epochs]
|
1705
|
+
}
|
1706
|
+
|
1707
|
+
if model.hparams.model_type.lower() == "classifier":
|
1708
|
+
testing_accuracy_list = [model.testing_accuracy] * model.hparams.epochs
|
1709
|
+
try:
|
1710
|
+
testing_loss_list = [model.testing_loss] * model.hparams.epochs
|
1711
|
+
except:
|
1712
|
+
testing_loss_list = [0.] * model.hparams.epochs
|
1713
|
+
metrics_data = {
|
1714
|
+
'Training Accuracy': [model.training_accuracy_list],
|
1715
|
+
'Validation Accuracy': [model.validation_accuracy_list],
|
1716
|
+
'Testing Accuracy' : [testing_accuracy_list],
|
1717
|
+
'Training Loss': [model.training_loss_list],
|
1718
|
+
'Validation Loss': [model.validation_loss_list],
|
1719
|
+
'Testing Loss' : [testing_loss_list]
|
1720
|
+
}
|
1721
|
+
if model.hparams.cv_type.lower() == "k-fold":
|
1722
|
+
accuracy_data = {
|
1723
|
+
'Accuracies' : [model.accuracies],
|
1724
|
+
'Max Accuracy' : [model.max_accuracy]
|
1725
|
+
}
|
1726
|
+
metrics_data.update(accuracy_data)
|
1727
|
+
data.update(metrics_data)
|
1728
|
+
|
1729
|
+
elif model.hparams.model_type.lower() == "regressor":
|
1730
|
+
testing_loss_list = [model.testing_loss] * model.hparams.epochs
|
1731
|
+
metrics_data = {
|
1732
|
+
'Training Loss': [model.training_loss_list],
|
1733
|
+
'Validation Loss': [model.validation_loss_list],
|
1734
|
+
'Testing Loss' : [testing_loss_list]
|
1735
|
+
}
|
1736
|
+
if model.hparams.cv_type.lower() == "k-fold":
|
1737
|
+
loss_data = {
|
1738
|
+
'Losses' : [model.losses],
|
1739
|
+
'Min Loss' : [model.min_loss]
|
1740
|
+
}
|
1741
|
+
metrics_data.update(loss_data)
|
1742
|
+
data.update(metrics_data)
|
1743
|
+
|
1744
|
+
return data
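A minimal usage sketch for ModelData (illustrative, not from the diffed file; `model` is assumed to be a trained and tested wrapper):

from topologicpy.PyG import PyG

data = PyG.ModelData(model)
print(data["Model Type"], data["Epochs"])
print(data["Training Loss"])  # per-epoch training loss of the (best) trained model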
|
1745
|
+
|
1746
|
+
@staticmethod
|
1747
|
+
def Show(data,
|
1748
|
+
labels,
|
1749
|
+
title="Training/Validation",
|
1750
|
+
xTitle="Epochs",
|
1751
|
+
xSpacing=1,
|
1752
|
+
yTitle="Accuracy and Loss",
|
1753
|
+
ySpacing=0.1,
|
1754
|
+
useMarkers=False,
|
1755
|
+
chartType="Line",
|
1756
|
+
width=950,
|
1757
|
+
height=500,
|
1758
|
+
backgroundColor='rgba(0,0,0,0)',
|
1759
|
+
gridColor='lightgray',
|
1760
|
+
marginLeft=0,
|
1761
|
+
marginRight=0,
|
1762
|
+
marginTop=40,
|
1763
|
+
marginBottom=0,
|
1764
|
+
renderer = "notebook"):
|
1765
|
+
"""
|
1766
|
+
Shows the data in a Plotly graph.
|
1767
|
+
|
1768
|
+
Parameters
|
1769
|
+
----------
|
1770
|
+
data : list
|
1771
|
+
The data to display.
|
1772
|
+
labels : list
|
1773
|
+
The labels to use for the data.
|
1774
|
+
width : int , optional
|
1775
|
+
The desired width of the figure. The default is 950.
|
1776
|
+
height : int , optional
|
1777
|
+
The desired height of the figure. The default is 500.
|
1778
|
+
title : str , optional
|
1779
|
+
The chart title. The default is "Training/Validation".
|
1780
|
+
xTitle : str , optional
|
1781
|
+
The X-axis title. The default is "Epochs".
|
1782
|
+
xSpacing : float , optional
|
1783
|
+
The X-axis spacing. The default is 1.0.
|
1784
|
+
yTitle : str , optional
|
1785
|
+
The Y-axis title. The default is "Accuracy and Loss".
|
1786
|
+
ySpacing : float , optional
|
1787
|
+
The Y-axis spacing. The default is 0.1.
|
1788
|
+
useMarkers : bool , optional
|
1789
|
+
If set to True, markers will be displayed. The default is False.
|
1790
|
+
chartType : str , optional
|
1791
|
+
The desired type of chart. The options are "Line", "Bar", or "Scatter". It is case insensitive. The default is "Line".
|
1792
|
+
backgroundColor : str , optional
|
1793
|
+
The desired background color. This can be any plotly color string and may be specified as:
|
1794
|
+
- A hex string (e.g. '#ff0000')
|
1795
|
+
- An rgb/rgba string (e.g. 'rgb(255,0,0)')
|
1796
|
+
- An hsl/hsla string (e.g. 'hsl(0,100%,50%)')
|
1797
|
+
- An hsv/hsva string (e.g. 'hsv(0,100%,100%)')
|
1798
|
+
- A named CSS color.
|
1799
|
+
The default is 'rgba(0,0,0,0)' (transparent).
|
1800
|
+
gridColor : str , optional
|
1801
|
+
The desired grid color. This can be any plotly color string and may be specified as:
|
1802
|
+
- A hex string (e.g. '#ff0000')
|
1803
|
+
- An rgb/rgba string (e.g. 'rgb(255,0,0)')
|
1804
|
+
- An hsl/hsla string (e.g. 'hsl(0,100%,50%)')
|
1805
|
+
- An hsv/hsva string (e.g. 'hsv(0,100%,100%)')
|
1806
|
+
- A named CSS color.
|
1807
|
+
The default is 'lightgray'.
|
1808
|
+
marginLeft : int , optional
|
1809
|
+
The desired left margin in pixels. The default is 0.
|
1810
|
+
marginRight : int , optional
|
1811
|
+
The desired right margin in pixels. The default is 0.
|
1812
|
+
marginTop : int , optional
|
1813
|
+
The desired top margin in pixels. The default is 40.
|
1814
|
+
marginBottom : int , optional
|
1815
|
+
The desired bottom margin in pixels. The default is 0.
|
1816
|
+
renderer : str , optional
|
1817
|
+
The desired plotly renderer. The default is "notebook".
|
1818
|
+
|
1819
|
+
Returns
|
1820
|
+
-------
|
1821
|
+
None.
|
1822
|
+
|
1823
|
+
"""
|
1824
|
+
from topologicpy.Plotly import Plotly
|
1825
|
+
|
1826
|
+
dataFrame = Plotly.DataByDGL(data, labels)
|
1827
|
+
fig = Plotly.FigureByDataFrame(dataFrame,
|
1828
|
+
labels=labels,
|
1829
|
+
title=title,
|
1830
|
+
xTitle=xTitle,
|
1831
|
+
xSpacing=xSpacing,
|
1832
|
+
yTitle=yTitle,
|
1833
|
+
ySpacing=ySpacing,
|
1834
|
+
useMarkers=useMarkers,
|
1835
|
+
chartType=chartType,
|
1836
|
+
width=width,
|
1837
|
+
height=height,
|
1838
|
+
backgroundColor=backgroundColor,
|
1839
|
+
gridColor=gridColor,
|
1840
|
+
marginRight=marginRight,
|
1841
|
+
marginLeft=marginLeft,
|
1842
|
+
marginTop=marginTop,
|
1843
|
+
marginBottom=marginBottom
|
1844
|
+
)
|
1845
|
+
Plotly.Show(fig, renderer=renderer)
|
1846
|
+
|
1847
|
+
@staticmethod
|
1848
|
+
def ModelLoad(path, model):
|
1849
|
+
"""
|
1850
|
+
Returns the model found at the input file path.
|
1851
|
+
|
1852
|
+
Parameters
|
1853
|
+
----------
|
1854
|
+
path : str
|
1855
|
+
The file path of the saved model.
|
1856
|
+
model : torch.nn.module
|
1857
|
+
Initialized instance of model
|
1858
|
+
|
1859
|
+
Returns
|
1860
|
+
-------
|
1861
|
+
Model
|
1862
|
+
The model with the loaded weights.
|
1863
|
+
|
1864
|
+
"""
|
1865
|
+
if not path:
|
1866
|
+
return None
|
1867
|
+
|
1868
|
+
model.load(path)
|
1869
|
+
return model
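A minimal usage sketch for ModelLoad (illustrative, not from the diffed file; it assumes the weights were saved earlier with ModelSave and that `hparams` and `dataset` match the saved model):

from topologicpy.PyG import PyG

# Re-create an untrained wrapper with matching hyperparameters, then load the saved weights.
new_model = PyG.Model(hparams, trainingDataset=dataset)
new_model = PyG.ModelLoad("my_classifier.pt", new_model)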
|
1870
|
+
|
1871
|
+
@staticmethod
|
1872
|
+
def ConfusionMatrix(actual, predicted, normalize=False):
|
1873
|
+
"""
|
1874
|
+
Returns the confusion matrix for the input actual and predicted labels. This is to be used with classification tasks only, not regression. Note that the returned matrix is transposed relative to scikit-learn's convention: rows correspond to predicted labels and columns to actual labels.
|
1875
|
+
|
1876
|
+
Parameters
|
1877
|
+
----------
|
1878
|
+
actual : list
|
1879
|
+
The input list of actual labels.
|
1880
|
+
predicted : list
|
1881
|
+
The input list of predicted labels.
|
1882
|
+
normalize : bool , optional
|
1883
|
+
If set to True, the returned data will be normalized (proportion of 1). Otherwise, actual numbers are returned. The default is False.
|
1884
|
+
|
1885
|
+
Returns
|
1886
|
+
-------
|
1887
|
+
list
|
1888
|
+
The created confusion matrix.
|
1889
|
+
|
1890
|
+
"""
|
1891
|
+
|
1892
|
+
try:
|
1893
|
+
from sklearn import metrics
|
1894
|
+
from sklearn.metrics import accuracy_score
|
1895
|
+
except:
|
1896
|
+
print("DGL - Installing required scikit-learn (sklearn) library.")
|
1897
|
+
try:
|
1898
|
+
os.system("pip install scikit-learn")
|
1899
|
+
except:
|
1900
|
+
os.system("pip install scikit-learn --user")
|
1901
|
+
try:
|
1902
|
+
from sklearn import metrics
|
1903
|
+
from sklearn.metrics import accuracy_score
|
1904
|
+
print("DGL - scikit-learn (sklearn) library installed correctly.")
|
1905
|
+
except:
|
1906
|
+
warnings.warn("DGL - Error: Could not import scikit-learn (sklearn). Please try to install scikit-learn manually. Returning None.")
|
1907
|
+
return None
|
1908
|
+
|
1909
|
+
if not isinstance(actual, list):
|
1910
|
+
print("DGL.ConfusionMatrix - ERROR: The actual input is not a list. Returning None")
|
1911
|
+
return None
|
1912
|
+
if not isinstance(predicted, list):
|
1913
|
+
print("DGL.ConfusionMatrix - ERROR: The predicted input is not a list. Returning None")
|
1914
|
+
return None
|
1915
|
+
if len(actual) != len(predicted):
|
1916
|
+
print("DGL.ConfusionMatrix - ERROR: The two input lists do not have the same length. Returning None")
|
1917
|
+
return None
|
1918
|
+
if normalize:
|
1919
|
+
cm = np.transpose(metrics.confusion_matrix(y_true=actual, y_pred=predicted, normalize="true"))
|
1920
|
+
else:
|
1921
|
+
cm = np.transpose(metrics.confusion_matrix(y_true=actual, y_pred=predicted))
|
1922
|
+
return cm
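A minimal worked example for ConfusionMatrix (illustrative labels, not from the diffed file):

from topologicpy.PyG import PyG

actual    = [0, 0, 1, 1, 2, 2]
predicted = [0, 1, 1, 1, 2, 0]
cm = PyG.ConfusionMatrix(actual, predicted)
print(cm)  # transposed sklearn matrix: rows are predicted labels, columns are actual labels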
|
1923
|
+
|
1924
|
+
@staticmethod
|
1925
|
+
def ModelPredict(model, dataset, nodeATTRKey="feat"):
|
1926
|
+
"""
|
1927
|
+
Predicts the values for the input dataset.
|
1928
|
+
|
1929
|
+
Parameters
|
1930
|
+
----------
|
1931
|
+
dataset : PyGDataset
|
1932
|
+
The input PyG dataset.
|
1933
|
+
model : Model
|
1934
|
+
The input trained model.
|
1935
|
+
nodeATTRKey : str , optional
|
1936
|
+
The key used for node attributes. The default is "feat".
|
1937
|
+
|
1938
|
+
Returns
|
1939
|
+
-------
|
1940
|
+
list
|
1941
|
+
The list of predictions
|
1942
|
+
"""
|
1943
|
+
try:
|
1944
|
+
model = model.model # The input model might be our wrapper model. In that case, get its model attribute to do the prediction.
|
1945
|
+
except:
|
1946
|
+
pass
|
1947
|
+
values = []
|
1948
|
+
dataloader = DataLoader(dataset, batch_size=1, drop_last=False)
|
1949
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1950
|
+
model.eval()
|
1951
|
+
for data in tqdm(dataloader, desc='Predicting', leave=False):
|
1952
|
+
data = data.to(device)
|
1953
|
+
pred = model(data)
|
1954
|
+
values.extend(list(np.round(pred.detach().cpu().numpy().flatten(), 3)))
|
1955
|
+
return values
|
1956
|
+
|
1957
|
+
@staticmethod
|
1958
|
+
def ModelClassify(model, dataset, nodeATTRKey="feat"):
|
1959
|
+
"""
|
1960
|
+
Predicts the classification labels of the input dataset.
|
1961
|
+
|
1962
|
+
Parameters
|
1963
|
+
----------
|
1964
|
+
dataset : PyGDataset
|
1965
|
+
The input PyG dataset.
|
1966
|
+
model : Model
|
1967
|
+
The input trained model.
|
1968
|
+
nodeATTRKey : str , optional
|
1969
|
+
The key used for node attributes. The default is "feat".
|
1970
|
+
|
1971
|
+
Returns
|
1972
|
+
-------
|
1973
|
+
dict
|
1974
|
+
Dictionary containing labels and probabilities. The included keys and values are:
|
1975
|
+
- "predictions" (list): the list of predicted labels
|
1976
|
+
- "probabilities" (list): the list of probabilities that the label is one of the categories.
|
1977
|
+
|
1978
|
+
"""
|
1979
|
+
try:
|
1980
|
+
model = model.model # The input model might be our wrapper model. In that case, get its model attribute to do the prediction.
|
1981
|
+
except:
|
1982
|
+
pass
|
1983
|
+
labels = []
|
1984
|
+
probabilities = []
|
1985
|
+
dataloader = DataLoader(dataset, batch_size=1, drop_last=False)
|
1986
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1987
|
+
for data in tqdm(dataloader, desc='Classifying', leave=False):
|
1988
|
+
data = data.to(device)
|
1989
|
+
pred = model(data)
|
1990
|
+
labels.extend(pred.argmax(1).tolist())
|
1991
|
+
probability = (torch.nn.functional.softmax(pred, dim=1).tolist())
|
1992
|
+
probability = probability[0]
|
1993
|
+
temp_probability = []
|
1994
|
+
for p in probability:
|
1995
|
+
temp_probability.append(round(p, 3))
|
1996
|
+
probabilities.extend(temp_probability)
|
1997
|
+
return {"predictions":labels, "probabilities":probabilities}
|
1998
|
+
|
1999
|
+
@staticmethod
|
2000
|
+
def Accuracy(actual, predicted, mantissa: int = 6):
|
2001
|
+
"""
|
2002
|
+
Computes the accuracy of the input predictions based on the input actual values. This is to be used only with classification, not with regression.
|
2003
|
+
|
2004
|
+
Parameters
|
2005
|
+
----------
|
2006
|
+
actual : list
|
2007
|
+
The input list of actual values.
|
2008
|
+
predicted : list
|
2009
|
+
The input list of predicted values.
|
2010
|
+
mantissa : int , optional
|
2011
|
+
The desired length of the mantissa. The default is 6.
|
2012
|
+
|
2013
|
+
Returns
|
2014
|
+
-------
|
2015
|
+
dict
|
2016
|
+
A dictionary returning the accuracy information. This contains the following keys and values:
|
2017
|
+
- "accuracy" (float): The number of correct predictions divided by the length of the list.
|
2018
|
+
- "correct" (int): The number of correct predictions
|
2019
|
+
- "mask" (list): A boolean mask for correct vs. wrong predictions which can be used to filter the list of predictions
|
2020
|
+
- "size" (int): The size of the predictions list
|
2021
|
+
- "wrong" (int): The number of wrong predictions
|
2022
|
+
|
2023
|
+
"""
|
2024
|
+
if len(predicted) < 1 or len(actual) < 1 or not len(predicted) == len(actual):
|
2025
|
+
return None
|
2026
|
+
correct = 0
|
2027
|
+
mask = []
|
2028
|
+
for i in range(len(predicted)):
|
2029
|
+
if predicted[i] == actual[i]:
|
2030
|
+
correct = correct + 1
|
2031
|
+
mask.append(True)
|
2032
|
+
else:
|
2033
|
+
mask.append(False)
|
2034
|
+
size = len(predicted)
|
2035
|
+
wrong = len(predicted) - correct
|
2036
|
+
accuracy = round(float(correct) / float(len(predicted)), mantissa)
|
2037
|
+
return {"accuracy":accuracy, "correct":correct, "mask":mask, "size":size, "wrong":wrong}
|
2038
|
+
|
2039
|
+
@staticmethod
|
2040
|
+
def MSE(actual, predicted, mantissa: int = 6):
|
2041
|
+
"""
|
2042
|
+
Computes the Mean Squared Error (MSE) of the input predictions based on the input labels. This is to be used with regression models.
|
2043
|
+
|
2044
|
+
Parameters
|
2045
|
+
----------
|
2046
|
+
actual : list
|
2047
|
+
The input list of actual values.
|
2048
|
+
predicted : list
|
2049
|
+
The input list of predicted values.
|
2050
|
+
mantissa : int , optional
|
2051
|
+
The desired length of the mantissa. The default is 6.
|
2052
|
+
|
2053
|
+
Returns
|
2054
|
+
-------
|
2055
|
+
dict
|
2056
|
+
A dictionary returning the MSE information. This contains the following keys and values:
|
2057
|
+
- "mse" (float): The mean squared error rounded to the specified mantissa.
|
2058
|
+
- "size" (int): The size of the predictions list.
|
2059
|
+
"""
|
2060
|
+
if len(predicted) < 1 or len(actual) < 1 or not len(predicted) == len(actual):
|
2061
|
+
return None
|
2062
|
+
|
2063
|
+
mse = np.mean((np.array(predicted) - np.array(actual)) ** 2)
|
2064
|
+
mse = round(mse, mantissa)
|
2065
|
+
size = len(predicted)
|
2066
|
+
|
2067
|
+
return {"mse": mse, "size": size}
|