glam4cm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glam4cm/__init__.py +9 -0
- glam4cm/data_loading/__init__.py +0 -0
- glam4cm/data_loading/data.py +631 -0
- glam4cm/data_loading/encoding.py +76 -0
- glam4cm/data_loading/graph_dataset.py +940 -0
- glam4cm/data_loading/metadata.py +84 -0
- glam4cm/data_loading/models_dataset.py +361 -0
- glam4cm/data_loading/utils.py +20 -0
- glam4cm/downstream_tasks/__init__.py +0 -0
- glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
- glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
- glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
- glam4cm/downstream_tasks/bert_node_classification.py +164 -0
- glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
- glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
- glam4cm/downstream_tasks/common_args.py +160 -0
- glam4cm/downstream_tasks/create_dataset.py +51 -0
- glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
- glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
- glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
- glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
- glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
- glam4cm/downstream_tasks/utils.py +35 -0
- glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
- glam4cm/embeddings/__init__.py +0 -0
- glam4cm/embeddings/bert.py +72 -0
- glam4cm/embeddings/common.py +43 -0
- glam4cm/embeddings/fasttext.py +0 -0
- glam4cm/embeddings/tfidf.py +25 -0
- glam4cm/embeddings/w2v.py +41 -0
- glam4cm/encoding/__init__.py +0 -0
- glam4cm/encoding/common.py +0 -0
- glam4cm/encoding/encoders.py +100 -0
- glam4cm/graph2str/__init__.py +0 -0
- glam4cm/graph2str/common.py +34 -0
- glam4cm/graph2str/constants.py +15 -0
- glam4cm/graph2str/ontouml.py +141 -0
- glam4cm/graph2str/uml.py +0 -0
- glam4cm/lang2graph/__init__.py +0 -0
- glam4cm/lang2graph/archimate.py +31 -0
- glam4cm/lang2graph/bpmn.py +0 -0
- glam4cm/lang2graph/common.py +416 -0
- glam4cm/lang2graph/ecore.py +221 -0
- glam4cm/lang2graph/ontouml.py +169 -0
- glam4cm/lang2graph/utils.py +80 -0
- glam4cm/models/cmgpt.py +352 -0
- glam4cm/models/gnn_layers.py +273 -0
- glam4cm/models/hf.py +10 -0
- glam4cm/run.py +99 -0
- glam4cm/run_configs.py +126 -0
- glam4cm/settings.py +54 -0
- glam4cm/tokenization/__init__.py +0 -0
- glam4cm/tokenization/special_tokens.py +4 -0
- glam4cm/tokenization/utils.py +37 -0
- glam4cm/trainers/__init__.py +0 -0
- glam4cm/trainers/bert_classifier.py +105 -0
- glam4cm/trainers/cm_gpt_trainer.py +153 -0
- glam4cm/trainers/gnn_edge_classifier.py +126 -0
- glam4cm/trainers/gnn_graph_classifier.py +123 -0
- glam4cm/trainers/gnn_link_predictor.py +144 -0
- glam4cm/trainers/gnn_node_classifier.py +135 -0
- glam4cm/trainers/gnn_trainer.py +129 -0
- glam4cm/trainers/metrics.py +55 -0
- glam4cm/utils.py +194 -0
- glam4cm-0.1.0.dist-info/LICENSE +21 -0
- glam4cm-0.1.0.dist-info/METADATA +86 -0
- glam4cm-0.1.0.dist-info/RECORD +72 -0
- glam4cm-0.1.0.dist-info/WHEEL +5 -0
- glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
- glam4cm-0.1.0.dist-info/top_level.txt +1 -0
glam4cm/trainers/gnn_edge_classifier.py
@@ -0,0 +1,126 @@
from typing import List
import torch
from collections import defaultdict
from torch_geometric.loader import DataLoader
from glam4cm.data_loading.data import GraphData
from glam4cm.models.gnn_layers import (
    GNNConv,
    EdgeClassifer
)

from glam4cm.trainers.gnn_trainer import Trainer
from glam4cm.settings import device


class GNNEdgeClassificationTrainer(Trainer):
    """
    Trainer class for GNN edge classification.
    Trains a GNN encoder together with an edge classifier to predict
    the class (cls_label) of each edge in the graph.
    """
    def __init__(
        self,
        model: GNNConv,
        predictor: EdgeClassifer,
        dataset: List[GraphData],
        cls_label='type',
        lr=1e-3,
        num_epochs=100,
        batch_size=32,
        use_edge_attrs=False,
        logs_dir='./logs'
    ) -> None:

        super().__init__(
            model=model,
            predictor=predictor,
            cls_label=cls_label,
            lr=lr,
            num_epochs=num_epochs,
            use_edge_attrs=use_edge_attrs,
            logs_dir=logs_dir
        )

        self.dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True
        )

        print("GNN Trainer initialized.")

    def train(self):
        self.model.train()
        self.predictor.train()

        all_preds, all_labels = list(), list()
        epoch_loss = 0
        epoch_metrics = defaultdict(float)
        for data in self.dataloader:
            self.optimizer.zero_grad()
            self.model.zero_grad()
            self.predictor.zero_grad()
            x = data.x
            edge_index = data.train_pos_edge_label_index
            train_mask = data.train_edge_mask
            edge_attr = data.edge_attr[train_mask] if self.use_edge_attrs else None

            h = self.get_logits(x, edge_index, edge_attr)
            scores = self.get_prediction_score(h, edge_index, edge_attr)
            labels = getattr(data, f"edge_{self.cls_label}")[train_mask]
            loss = self.criterion(scores, labels.to(device))
            all_preds.append(scores.detach().cpu())
            all_labels.append(labels)

            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            epoch_loss += loss.item()

        all_preds = torch.cat(all_preds, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        epoch_metrics = self.compute_metrics(all_preds, all_labels)
        epoch_metrics['loss'] = epoch_loss
        epoch_metrics['phase'] = 'train'

        # print(f"Train Metrics: {epoch_metrics}")

        return epoch_metrics

    def test(self):
        self.model.eval()
        self.predictor.eval()
        all_preds, all_labels = list(), list()
        with torch.no_grad():
            epoch_loss = 0
            epoch_metrics = defaultdict(float)
            for data in self.dataloader:
                x = data.x
                edge_index = data.test_pos_edge_label_index
                test_mask = data.test_edge_mask
                edge_attr = data.edge_attr[test_mask] if self.use_edge_attrs else None

                h = self.get_logits(x, edge_index, edge_attr)
                scores = self.get_prediction_score(h, edge_index, edge_attr)
                labels = getattr(data, f"edge_{self.cls_label}")[test_mask]
                all_preds.append(scores.detach().cpu())
                all_labels.append(labels)
                loss = self.criterion(scores, labels.to(device))

                epoch_loss += loss.item()

            all_preds = torch.cat(all_preds, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            epoch_metrics = self.compute_metrics(all_preds, all_labels)

            epoch_metrics['loss'] = epoch_loss
            epoch_metrics['phase'] = 'test'
            # print(f"Epoch Test Loss: {epoch_loss}\nTest Accuracy: {epoch_acc}\nTest F1: {epoch_f1}")
            self.results.append(epoch_metrics)

            print(f"Epoch: {len(self.results)}\n{epoch_metrics}")

            return epoch_metrics
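For reference, the trainer above reads a fixed set of edge-split attributes off every batch: x, train_pos_edge_label_index, train_edge_mask, test_pos_edge_label_index, test_edge_mask, edge_attr (only when use_edge_attrs=True), and edge_<cls_label>. The toy sketch below illustrates those attributes on a single torch_geometric Data object; the shapes and values are illustrative assumptions, and the real objects are built by glam4cm.data_loading.graph_dataset, which is not shown in this hunk.

# Toy illustration of the attributes read by GNNEdgeClassificationTrainer;
# shapes and values are assumptions, not taken from the package.
import torch
from torch_geometric.data import Data

num_nodes, num_edges, feat_dim = 6, 8, 16
edge_index = torch.randint(0, num_nodes, (2, num_edges))

data = Data(
    x=torch.randn(num_nodes, feat_dim),      # node features
    edge_index=edge_index,
    edge_attr=torch.randn(num_edges, 4),     # only used when use_edge_attrs=True
)
# Edge-level train/test split: boolean masks plus the corresponding edge indices.
data.train_edge_mask = torch.tensor([True] * 6 + [False] * 2)
data.test_edge_mask = ~data.train_edge_mask
data.train_pos_edge_label_index = edge_index[:, data.train_edge_mask]
data.test_pos_edge_label_index = edge_index[:, data.test_edge_mask]
# Per-edge class labels; with cls_label='type' the trainer looks up `edge_type`.
data.edge_type = torch.randint(0, 3, (num_edges,))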
glam4cm/trainers/gnn_graph_classifier.py
@@ -0,0 +1,123 @@
from typing import Dict, List
import torch
from collections import defaultdict
from torch_geometric.loader import DataLoader

from torch_geometric.data import Data
from glam4cm.models.gnn_layers import (
    GNNConv,
    GraphClassifer
)
from glam4cm.trainers.gnn_trainer import Trainer
from glam4cm.settings import device


class GNNGraphClassificationTrainer(Trainer):
    """
    Trainer class for GNN graph classification.
    Trains a GNN encoder together with a graph-level classifier to predict
    a label (cls_label) for each graph.
    """
    def __init__(
        self,
        model: GNNConv,
        predictor: GraphClassifer,
        dataset: Dict[str, List[Data]],
        cls_label='label',
        lr=1e-4,
        num_epochs=100,
        batch_size=32,
        use_edge_attrs=False,
        logs_dir='./logs'
    ) -> None:

        super().__init__(
            model=model,
            predictor=predictor,
            cls_label='type',
            lr=lr,
            num_epochs=num_epochs,
            use_edge_attrs=use_edge_attrs,
            logs_dir=logs_dir
        )

        self.cls_label = cls_label
        self.dataloaders = dict()
        self.dataloaders['train'] = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
        self.dataloaders['test'] = DataLoader(dataset['test'], batch_size=batch_size, shuffle=False)

        self.results = list()

        print("GNN Trainer initialized.")

    def train(self):
        self.model.train()
        self.predictor.train()

        epoch_loss = 0
        epoch_metrics = defaultdict(float)
        preds, all_labels = list(), list()
        for data in self.dataloaders['train']:
            self.optimizer.zero_grad()
            self.model.train()
            self.predictor.train()

            h = self.model(data.x.to(device), data.edge_index.to(device))
            g_pred = self.predictor(h, data.batch.to(device))

            labels = getattr(data, f"graph_{self.cls_label}")
            loss = self.criterion(g_pred, labels.to(device))

            preds.append(g_pred.detach().cpu())
            all_labels.append(labels)

            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            epoch_loss += loss.item()

        preds = torch.cat(preds, dim=0)
        labels = torch.cat(all_labels, dim=0)

        epoch_metrics = self.compute_metrics(preds, labels)
        epoch_metrics['loss'] = epoch_loss
        epoch_metrics['phase'] = 'train'

        self.results.append(epoch_metrics)

        return epoch_metrics

    def test(self):
        self.model.eval()
        self.predictor.eval()
        with torch.no_grad():
            epoch_loss = 0
            preds, all_labels = list(), list()
            for data in self.dataloaders['test']:
                h = self.model(data.x.to(device), data.edge_index.to(device))
                g_pred = self.predictor(h, data.batch.to(device))
                labels = getattr(data, f"graph_{self.cls_label}")

                loss = self.criterion(g_pred, labels.to(device))
                epoch_loss += loss.item()

                preds.append(g_pred.cpu().detach())
                all_labels.append(labels.cpu())

            preds = torch.cat(preds, dim=0)
            labels = torch.cat(all_labels, dim=0)

            epoch_metrics = self.compute_metrics(preds, labels)
            epoch_metrics['loss'] = epoch_loss
            epoch_metrics['phase'] = 'test'
            self.results.append(epoch_metrics)

            s2t = lambda x: x.replace("_", " ").title()
            print(f"Epoch: {len(self.results)//2} {' | '.join([f'{s2t(k)}: {v:.4f}' for k, v in epoch_metrics.items() if k != 'phase'])}")

            return epoch_metrics
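For reference, the constructor above indexes dataset['train'] and dataset['test'], and every batch it draws must expose x, edge_index, batch (added by the PyG loader), and a graph_<cls_label> attribute. The toy sketch below illustrates that split structure; graph sizes and labels are made-up assumptions, not values taken from the package.

# Toy illustration of the train/test split structure consumed by
# GNNGraphClassificationTrainer; shapes and labels are assumptions.
import torch
from torch_geometric.data import Data

def toy_graph(label: int) -> Data:
    g = Data(
        x=torch.randn(5, 16),
        edge_index=torch.randint(0, 5, (2, 10)),
    )
    # With cls_label='label' the trainer reads `graph_label`; shape (1,) so the
    # PyG loader concatenates it into one label per graph in the batch.
    g.graph_label = torch.tensor([label])
    return g

dataset = {
    'train': [toy_graph(i % 2) for i in range(8)],
    'test': [toy_graph(i % 2) for i in range(4)],
}
# trainer = GNNGraphClassificationTrainer(model, predictor, dataset, cls_label='label')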
glam4cm/trainers/gnn_link_predictor.py
@@ -0,0 +1,144 @@
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import torch
from collections import defaultdict
from typing import List

from glam4cm.models.gnn_layers import (
    GNNConv,
    EdgeClassifer
)

from glam4cm.trainers.gnn_trainer import Trainer
from tqdm.auto import tqdm
from glam4cm.settings import device


class GNNLinkPredictionTrainer(Trainer):
    """
    Trainer class for GNN Link Prediction
    This class is used to train the GNN model for the link prediction task
    The model is trained to predict the link between two nodes
    """
    def __init__(
        self,
        model: GNNConv,
        predictor: EdgeClassifer,
        dataset: List[Data],
        cls_label='type',
        lr=1e-3,
        num_epochs=100,
        batch_size=32,
        use_edge_attrs=False,
        logs_dir='./logs'
    ) -> None:

        super().__init__(
            model=model,
            predictor=predictor,
            lr=lr,
            cls_label=cls_label,
            num_epochs=num_epochs,
            use_edge_attrs=use_edge_attrs,
            logs_dir=logs_dir
        )
        self.dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True
        )
        self.results = list()

        print("GNN Trainer initialized.")

    def train(self):
        self.model.train()
        self.predictor.train()

        all_preds, all_labels = list(), list()
        epoch_loss = 0
        epoch_metrics = defaultdict(float)
        for data in tqdm(self.dataloader, desc='Training Batches'):
            self.optimizer.zero_grad()
            self.model.zero_grad()
            self.predictor.zero_grad()

            x = data.x
            pos_edge_index = data.train_pos_edge_label_index
            neg_edge_index = data.train_neg_edge_label_index
            train_mask = data.train_edge_mask
            edge_attr = data.edge_attr[train_mask] if self.use_edge_attrs else None

            h = self.get_logits(x, pos_edge_index, edge_attr)
            # h = x

            pos_scores = self.get_prediction_score(h, pos_edge_index, edge_attr)
            neg_scores = self.get_prediction_score(h, neg_edge_index, edge_attr)
            loss = self.compute_loss(pos_scores, neg_scores)
            all_labels.append(torch.cat([torch.ones(pos_scores.size(0)), torch.zeros(neg_scores.size(0))]))
            all_preds.append(torch.cat([pos_scores.detach().cpu(), neg_scores.detach().cpu()]))

            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            epoch_loss += loss.item()

        all_preds = torch.cat(all_preds, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        epoch_metrics = self.compute_metrics(all_preds, all_labels)
        epoch_metrics['loss'] = epoch_loss
        epoch_metrics['phase'] = 'train'

        return epoch_metrics

    def test(self):
        self.model.eval()
        self.predictor.eval()
        all_preds, all_labels = list(), list()
        with torch.no_grad():
            epoch_loss = 0
            epoch_metrics = defaultdict(float)
            for data in tqdm(self.dataloader, desc='Testing Batches'):
                x = data.x
                pos_edge_index = data.test_pos_edge_label_index
                neg_edge_index = data.test_neg_edge_label_index
                test_mask = data.test_edge_mask
                edge_attr = data.edge_attr[test_mask] if self.use_edge_attrs else None

                h = self.get_logits(x, pos_edge_index, edge_attr)
                # h = x
                pos_score = self.get_prediction_score(h, pos_edge_index, edge_attr)
                neg_score = self.get_prediction_score(h, neg_edge_index, edge_attr)

                loss = self.compute_loss(pos_score, neg_score)
                all_labels.append(torch.cat([torch.ones(pos_score.size(0)), torch.zeros(neg_score.size(0))]))
                all_preds.append(torch.cat([pos_score.detach().cpu(), neg_score.detach().cpu()]))

                epoch_loss += loss.item()

            all_preds = torch.cat(all_preds, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            epoch_metrics = self.compute_metrics(all_preds, all_labels)

            epoch_metrics['loss'] = epoch_loss
            epoch_metrics['phase'] = 'test'
            # print(f"Epoch Test Loss: {epoch_loss}\nTest Accuracy: {epoch_acc}\nTest F1: {epoch_f1}")
            self.results.append(epoch_metrics)

            print(f"Test Epoch: {len(self.results)}\n{epoch_metrics}")

            return epoch_metrics

    def compute_loss(self, pos_score, neg_score):
        pos_label = torch.ones(pos_score.size(0), dtype=torch.long).to(device)
        neg_label = torch.zeros(neg_score.size(0), dtype=torch.long).to(device)

        scores = torch.cat([pos_score, neg_score], dim=0)
        labels = torch.cat([pos_label, neg_label], dim=0)

        loss = self.criterion(scores, labels)
        return loss
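For reference, compute_loss above frames link prediction as a two-class problem: positive edges are class 1, sampled negative edges are class 0, and both score sets go through the shared CrossEntropyLoss, so the predictor is expected to emit two logits per candidate edge. The standalone sketch below reproduces that arithmetic with made-up tensor shapes.

# Standalone illustration of the two-class link prediction loss used above;
# tensor shapes are assumptions.
import torch
import torch.nn as nn

pos_score = torch.randn(5, 2)   # logits for 5 positive (real) edges
neg_score = torch.randn(7, 2)   # logits for 7 negative (sampled) edges

scores = torch.cat([pos_score, neg_score], dim=0)
labels = torch.cat([
    torch.ones(pos_score.size(0), dtype=torch.long),    # real edge -> class 1
    torch.zeros(neg_score.size(0), dtype=torch.long),   # non-edge  -> class 0
])
loss = nn.CrossEntropyLoss()(scores, labels)
print(loss.item())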
glam4cm/trainers/gnn_node_classifier.py
@@ -0,0 +1,135 @@
from typing import List
from torch_geometric.loader import DataLoader
import torch
from collections import defaultdict
from torch_geometric.data import Data
from glam4cm.models.gnn_layers import (
    GNNConv,
    NodeClassifier
)
from glam4cm.trainers.gnn_trainer import Trainer
from glam4cm.settings import device


class GNNNodeClassificationTrainer(Trainer):
    """
    Trainer class for GNN node classification.
    Trains a GNN encoder together with a node classifier to predict
    the class (cls_label) of each node in the graph.
    """
    def __init__(
        self,
        model: GNNConv,
        predictor: NodeClassifier,
        dataset: List[Data],
        cls_label,
        exclude_labels=None,
        lr=1e-3,
        num_epochs=100,
        batch_size=32,
        use_edge_attrs=False,
        logs_dir='./logs'
    ) -> None:

        super().__init__(
            model=model,
            predictor=predictor,
            cls_label=cls_label,
            lr=lr,
            num_epochs=num_epochs,
            use_edge_attrs=use_edge_attrs,
            logs_dir=logs_dir
        )

        # Labels listed in exclude_labels are dropped from the loss and metrics;
        # fall back to an empty tensor so the None default does not break torch.tensor().
        self.exclude_labels = torch.tensor(
            exclude_labels if exclude_labels is not None else [], dtype=torch.long
        )
        self.results = list()
        self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        print("GNN Trainer initialized.")

    def train(self):
        self.model.train()
        self.predictor.train()

        all_preds, all_labels = list(), list()
        epoch_loss = 0
        epoch_metrics = defaultdict(float)
        # for i, data in tqdm(enumerate(self.dataloader), desc=f"Training batches", total=len(self.dataloader)):
        for data in self.dataloader:
            self.optimizer.zero_grad()
            self.model.zero_grad()
            self.predictor.zero_grad()

            h = self.get_logits(
                data.x,
                data.edge_index,
                data.edge_attr if self.use_edge_attrs else None
            )
            scores = self.get_prediction_score(h)[data.train_node_mask]
            labels = getattr(data, f"node_{self.cls_label}")[data.train_node_mask]

            mask = ~torch.isin(labels, self.exclude_labels)
            labels = labels[mask]
            scores = scores[mask]

            loss = self.criterion(scores, labels.to(device))

            all_preds.append(scores.detach().cpu())
            all_labels.append(labels)

            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            epoch_loss += loss.item()

        all_preds = torch.cat(all_preds, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        epoch_metrics = self.compute_metrics(all_preds, all_labels)
        epoch_metrics['loss'] = epoch_loss
        epoch_metrics['phase'] = 'train'

        return epoch_metrics

    def test(self):
        self.model.eval()
        self.predictor.eval()
        all_preds, all_labels = list(), list()
        with torch.no_grad():
            epoch_loss = 0
            epoch_metrics = defaultdict(float)
            # for _, data in tqdm(enumerate(self.dataloader), desc=f"Evaluating batches", total=len(self.dataloader)):
            for data in self.dataloader:
                h = self.get_logits(
                    data.x,
                    data.edge_index,
                    data.edge_attr if self.use_edge_attrs else None
                )

                scores = self.get_prediction_score(h)[data.test_node_mask]
                labels = getattr(data, f"node_{self.cls_label}")[data.test_node_mask]

                mask = ~torch.isin(labels, self.exclude_labels)
                labels = labels[mask]
                scores = scores[mask]
                loss = self.criterion(scores, labels.to(device))
                epoch_loss += loss.item()

                all_preds.append(scores.detach().cpu())
                all_labels.append(labels)

            all_preds = torch.cat(all_preds, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            epoch_metrics = self.compute_metrics(all_preds, all_labels)

            epoch_metrics['loss'] = epoch_loss
            epoch_metrics['phase'] = 'test'
            self.results.append(epoch_metrics)

            print(f"Epoch: {len(self.results)}\n{epoch_metrics}")

            return epoch_metrics
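For reference, the node trainer above drops any node whose label appears in exclude_labels before computing the loss and metrics, using torch.isin. The standalone sketch below shows that filtering with made-up values.

# Standalone illustration of the exclude_labels filtering used above;
# values are assumptions.
import torch

labels = torch.tensor([0, 3, 1, 3, 2])               # node_<cls_label> of the masked nodes
scores = torch.randn(labels.size(0), 4)               # one logit row per node
exclude_labels = torch.tensor([3], dtype=torch.long)

keep = ~torch.isin(labels, exclude_labels)
labels, scores = labels[keep], scores[keep]
print(labels)   # tensor([0, 1, 2]) -- both label-3 nodes are dropped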
glam4cm/trainers/gnn_trainer.py
@@ -0,0 +1,129 @@
from abc import abstractmethod
import torch
from typing import Union
import pandas as pd

from glam4cm.models.gnn_layers import (
    GNNConv,
    EdgeClassifer,
    NodeClassifier
)
from glam4cm.settings import device
from itertools import chain
from tqdm.auto import tqdm
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import Adam
from torch_geometric.loader import DataLoader

from tensorboardX import SummaryWriter
from glam4cm.trainers.metrics import compute_classification_metrics


class Trainer:
    """
    Base trainer for the GNN downstream tasks (node, edge and graph
    classification, link prediction). Holds the shared model, predictor,
    optimizer, scheduler, loss and TensorBoard writer; subclasses implement
    train() and test().
    """
    def __init__(
        self,
        model: GNNConv,
        predictor: Union[EdgeClassifer, NodeClassifier],
        cls_label,
        lr=1e-3,
        num_epochs=100,
        use_edge_attrs=False,
        logs_dir='./logs'
    ) -> None:
        self.model = model
        self.predictor = predictor
        self.model.to(device)
        self.predictor.to(device)

        self.cls_label = cls_label
        self.num_epochs = num_epochs

        self.optimizer = Adam(chain(model.parameters(), predictor.parameters()), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs)

        self.edge2index = lambda g: torch.stack(list(g.edges())).contiguous()
        self.results = list()
        self.criterion = nn.CrossEntropyLoss()

        self.use_edge_attrs = use_edge_attrs

        self.logs_dir = logs_dir
        self.writer = SummaryWriter(log_dir=self.logs_dir)

        print("GNN Trainer initialized.")

    def set_dataloader(self, dataset, batch_size):
        self.dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True
        )

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def test(self):
        pass

    def get_logits(self, x, edge_index, edge_attr=None):
        edge_index = edge_index.to(device)
        x = x.to(device)

        if edge_attr is not None:
            edge_attr = edge_attr.to(device)
            h = self.model(x, edge_index, edge_attr)
        else:
            h = self.model(x, edge_index)
        return h

    def get_prediction_score(self, h, edge_index=None, edge_attr=None):
        h = h.to(device)
        if edge_attr is not None:
            edge_attr = edge_attr.to(device)
            edge_index = edge_index.to(device)
            prediction_score = self.predictor(h, edge_index, edge_attr)
        elif edge_index is not None:
            edge_index = edge_index.to(device)
            prediction_score = self.predictor(h, edge_index)
        else:
            prediction_score = self.predictor(h)
        return prediction_score

    def plot_metrics(self):
        df = pd.DataFrame(self.results)
        df.index = range(1, len(df) + 1)
        df['epoch'] = df.index

        columns = [c for c in df.columns if c not in ['epoch', 'phase']]
        df.loc[df['phase'] == 'test'].plot(x='epoch', y=columns, kind='line')

    def run(self):
        for epoch in tqdm(range(self.num_epochs), desc="Running Epochs"):
            train_metrics = self.train()
            test_metrics = self.test()

            for k, v in train_metrics.items():
                if k != 'phase':
                    self.writer.add_scalar(f"train/{k}", v, epoch)

            for k, v in test_metrics.items():
                if k != 'phase':
                    self.writer.add_scalar(f"test/{k}", v, epoch)

        self.writer.close()

    def compute_metrics(self, all_preds, all_labels):
        return compute_classification_metrics(all_preds, all_labels)