glam4cm 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. glam4cm/__init__.py +9 -0
  2. glam4cm/data_loading/__init__.py +0 -0
  3. glam4cm/data_loading/data.py +631 -0
  4. glam4cm/data_loading/encoding.py +76 -0
  5. glam4cm/data_loading/graph_dataset.py +940 -0
  6. glam4cm/data_loading/metadata.py +84 -0
  7. glam4cm/data_loading/models_dataset.py +361 -0
  8. glam4cm/data_loading/utils.py +20 -0
  9. glam4cm/downstream_tasks/__init__.py +0 -0
  10. glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
  11. glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
  12. glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
  13. glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
  14. glam4cm/downstream_tasks/bert_node_classification.py +164 -0
  15. glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
  16. glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
  17. glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
  18. glam4cm/downstream_tasks/common_args.py +160 -0
  19. glam4cm/downstream_tasks/create_dataset.py +51 -0
  20. glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
  21. glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
  22. glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
  23. glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
  24. glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
  25. glam4cm/downstream_tasks/utils.py +35 -0
  26. glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
  27. glam4cm/embeddings/__init__.py +0 -0
  28. glam4cm/embeddings/bert.py +72 -0
  29. glam4cm/embeddings/common.py +43 -0
  30. glam4cm/embeddings/fasttext.py +0 -0
  31. glam4cm/embeddings/tfidf.py +25 -0
  32. glam4cm/embeddings/w2v.py +41 -0
  33. glam4cm/encoding/__init__.py +0 -0
  34. glam4cm/encoding/common.py +0 -0
  35. glam4cm/encoding/encoders.py +100 -0
  36. glam4cm/graph2str/__init__.py +0 -0
  37. glam4cm/graph2str/common.py +34 -0
  38. glam4cm/graph2str/constants.py +15 -0
  39. glam4cm/graph2str/ontouml.py +141 -0
  40. glam4cm/graph2str/uml.py +0 -0
  41. glam4cm/lang2graph/__init__.py +0 -0
  42. glam4cm/lang2graph/archimate.py +31 -0
  43. glam4cm/lang2graph/bpmn.py +0 -0
  44. glam4cm/lang2graph/common.py +416 -0
  45. glam4cm/lang2graph/ecore.py +221 -0
  46. glam4cm/lang2graph/ontouml.py +169 -0
  47. glam4cm/lang2graph/utils.py +80 -0
  48. glam4cm/models/cmgpt.py +352 -0
  49. glam4cm/models/gnn_layers.py +273 -0
  50. glam4cm/models/hf.py +10 -0
  51. glam4cm/run.py +99 -0
  52. glam4cm/run_configs.py +126 -0
  53. glam4cm/settings.py +54 -0
  54. glam4cm/tokenization/__init__.py +0 -0
  55. glam4cm/tokenization/special_tokens.py +4 -0
  56. glam4cm/tokenization/utils.py +37 -0
  57. glam4cm/trainers/__init__.py +0 -0
  58. glam4cm/trainers/bert_classifier.py +105 -0
  59. glam4cm/trainers/cm_gpt_trainer.py +153 -0
  60. glam4cm/trainers/gnn_edge_classifier.py +126 -0
  61. glam4cm/trainers/gnn_graph_classifier.py +123 -0
  62. glam4cm/trainers/gnn_link_predictor.py +144 -0
  63. glam4cm/trainers/gnn_node_classifier.py +135 -0
  64. glam4cm/trainers/gnn_trainer.py +129 -0
  65. glam4cm/trainers/metrics.py +55 -0
  66. glam4cm/utils.py +194 -0
  67. glam4cm-0.1.0.dist-info/LICENSE +21 -0
  68. glam4cm-0.1.0.dist-info/METADATA +86 -0
  69. glam4cm-0.1.0.dist-info/RECORD +72 -0
  70. glam4cm-0.1.0.dist-info/WHEEL +5 -0
  71. glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
  72. glam4cm-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,126 @@
1
+ from typing import List
2
+ import torch
3
+ from collections import defaultdict
4
+ from torch_geometric.loader import DataLoader
5
+ from glam4cm.data_loading.data import GraphData
6
+ from glam4cm.models.gnn_layers import (
7
+ GNNConv,
8
+ EdgeClassifer
9
+ )
10
+
11
+ from glam4cm.trainers.gnn_trainer import Trainer
12
+ from glam4cm.settings import device
13
+
14
+
15
class GNNEdgeClassificationTrainer(Trainer):
    """
    Trainer class for GNN Edge Classification

    For every batch the model output is handed, together with the edge index
    of the current split, to the predictor, and the resulting scores are
    compared against the per-edge targets stored on the batch.

    Each batch is expected to expose ``x``, ``train_pos_edge_label_index`` /
    ``test_pos_edge_label_index``, ``train_edge_mask`` / ``test_edge_mask``,
    ``edge_<cls_label>`` and, when ``use_edge_attrs`` is set, ``edge_attr``.
    """
    def __init__(
        self,
        model: GNNConv,
        predictor: EdgeClassifer,
        dataset: List[GraphData],
        cls_label='type',
        lr=1e-3,
        num_epochs=100,
        batch_size=32,
        use_edge_attrs=False,
        logs_dir='./logs'
    ) -> None:
        """
        Args:
            model: network applied to ``(x, edge_index[, edge_attr])``.
            predictor: edge-level head applied to the model output.
            dataset: graphs that are batched by a shuffling ``DataLoader``.
            cls_label: suffix of the batch attribute ``edge_<cls_label>``
                that holds the edge targets.
            lr: learning rate forwarded to the base trainer.
            num_epochs: number of epochs forwarded to the base trainer.
            batch_size: number of graphs per batch.
            use_edge_attrs: when true, masked ``edge_attr`` is passed to the
                model and the predictor.
            logs_dir: directory forwarded to the base trainer for logging.
        """
        super().__init__(
            model=model,
            predictor=predictor,
            cls_label=cls_label,
            lr=lr,
            num_epochs=num_epochs,
            use_edge_attrs=use_edge_attrs,
            logs_dir=logs_dir
        )

        self.dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True
        )

        print("GNN Trainer initialized.")

    def _score_split(self, data, split):
        """Return ``(scores, labels)`` for the ``split`` ('train'/'test') edges of a batch."""
        edge_index = getattr(data, f"{split}_pos_edge_label_index")
        edge_mask = getattr(data, f"{split}_edge_mask")
        edge_attr = data.edge_attr[edge_mask] if self.use_edge_attrs else None

        h = self.get_logits(data.x, edge_index, edge_attr)
        scores = self.get_prediction_score(h, edge_index, edge_attr)
        labels = getattr(data, f"edge_{self.cls_label}")[edge_mask]
        return scores, labels

    def train(self):
        """
        Run one optimisation pass over the dataloader.

        Returns:
            The metrics computed by ``compute_metrics`` over all training
            edges of the epoch, extended with ``'loss'`` (sum of the batch
            losses) and ``'phase'`` (``'train'``).
        """
        self.model.train()
        self.predictor.train()

        all_preds, all_labels = list(), list()
        epoch_loss = 0
        for data in self.dataloader:
            # The optimizer owns the parameters of both model and predictor,
            # so a single zero_grad() clears every gradient.
            self.optimizer.zero_grad()

            scores, labels = self._score_split(data, 'train')
            loss = self.criterion(scores, labels.to(device))

            loss.backward()
            self.optimizer.step()
            epoch_loss += loss.item()

            all_preds.append(scores.detach().cpu())
            all_labels.append(labels.detach().cpu())

        # The base class builds CosineAnnealingLR with T_max=num_epochs, so the
        # schedule has to advance once per epoch; stepping it after every batch
        # would sweep through the cosine cycle many times during training.
        self.scheduler.step()

        all_preds = torch.cat(all_preds, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        epoch_metrics = self.compute_metrics(all_preds, all_labels)
        epoch_metrics['loss'] = epoch_loss
        epoch_metrics['phase'] = 'train'

        return epoch_metrics

    def test(self):
        """
        Evaluate on the test edges of every batch without gradient tracking.

        The resulting metrics are appended to ``self.results`` and printed.

        Returns:
            The metrics computed by ``compute_metrics``, extended with
            ``'loss'`` (sum of the batch losses) and ``'phase'`` (``'test'``).
        """
        self.model.eval()
        self.predictor.eval()

        all_preds, all_labels = list(), list()
        epoch_loss = 0
        with torch.no_grad():
            for data in self.dataloader:
                scores, labels = self._score_split(data, 'test')
                loss = self.criterion(scores, labels.to(device))
                epoch_loss += loss.item()

                all_preds.append(scores.detach().cpu())
                all_labels.append(labels.detach().cpu())

            all_preds = torch.cat(all_preds, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            epoch_metrics = self.compute_metrics(all_preds, all_labels)

            epoch_metrics['loss'] = epoch_loss
            epoch_metrics['phase'] = 'test'
            self.results.append(epoch_metrics)

            print(f"Epoch: {len(self.results)}\n{epoch_metrics}")

        return epoch_metrics
@@ -0,0 +1,123 @@
1
+ from typing import List, Tuple
2
+ import torch
3
+ from collections import defaultdict
4
+ from torch_geometric.loader import DataLoader
5
+
6
+ from torch_geometric.data import Data
7
+ from glam4cm.models.gnn_layers import (
8
+ GNNConv,
9
+ GraphClassifer
10
+ )
11
+ from glam4cm.trainers.gnn_trainer import Trainer
12
+ from glam4cm.settings import device
13
+
14
+
15
class GNNGraphClassificationTrainer(Trainer):
    """
    Trainer class for GNN Graph Classification

    The model is applied to the node features and edge index of each batch,
    the predictor turns the model output plus the ``batch`` assignment vector
    into one prediction per graph, and the targets are read from the batch
    attribute ``graph_<cls_label>``.
    """
    def __init__(
        self,
        model: GNNConv,
        predictor: GraphClassifer,
        dataset: dict,
        cls_label='label',
        lr=1e-4,
        num_epochs=100,
        batch_size=32,
        use_edge_attrs=False,
        logs_dir='./logs'
    ) -> None:
        """
        Args:
            model: network applied to ``(x, edge_index)``.
            predictor: graph-level head applied to ``(h, batch)``.
            dataset: mapping with the keys ``'train'`` and ``'test'``; each
                value is wrapped in its own ``DataLoader`` (only the training
                loader shuffles).
            cls_label: suffix of the batch attribute ``graph_<cls_label>``
                that holds the graph targets.
            lr: learning rate forwarded to the base trainer.
            num_epochs: number of epochs forwarded to the base trainer.
            batch_size: number of graphs per batch.
            use_edge_attrs: forwarded to the base trainer.
                NOTE(review): ``train``/``test`` call the model without
                ``edge_attr``, so this flag currently has no effect here --
                confirm whether that is intended.
            logs_dir: directory forwarded to the base trainer for logging.
        """
        super().__init__(
            model=model,
            predictor=predictor,
            cls_label=cls_label,
            lr=lr,
            num_epochs=num_epochs,
            use_edge_attrs=use_edge_attrs,
            logs_dir=logs_dir
        )

        self.cls_label = cls_label
        self.dataloaders = {
            'train': DataLoader(dataset['train'], batch_size=batch_size, shuffle=True),
            'test': DataLoader(dataset['test'], batch_size=batch_size, shuffle=False),
        }

        self.results = list()

        print("GNN Trainer initialized.")

    def _predict_graphs(self, data):
        """Return ``(graph_predictions, labels)`` for one batch."""
        h = self.model(data.x.to(device), data.edge_index.to(device))
        g_pred = self.predictor(h, data.batch.to(device))
        labels = getattr(data, f"graph_{self.cls_label}")
        return g_pred, labels

    def train(self):
        """
        Run one optimisation pass over the training loader.

        The resulting metrics are appended to ``self.results``.

        Returns:
            The metrics computed by ``compute_metrics``, extended with
            ``'loss'`` (sum of the batch losses) and ``'phase'`` (``'train'``).
        """
        # Training mode is set once per epoch; nothing inside the loop
        # switches the modules back to eval mode.
        self.model.train()
        self.predictor.train()

        epoch_loss = 0
        preds, all_labels = list(), list()
        for data in self.dataloaders['train']:
            self.optimizer.zero_grad()

            g_pred, labels = self._predict_graphs(data)
            loss = self.criterion(g_pred, labels.to(device))

            loss.backward()
            self.optimizer.step()
            epoch_loss += loss.item()

            preds.append(g_pred.detach().cpu())
            all_labels.append(labels.detach().cpu())

        # The base class builds CosineAnnealingLR with T_max=num_epochs, so the
        # schedule has to advance once per epoch rather than once per batch.
        self.scheduler.step()

        preds = torch.cat(preds, dim=0)
        labels = torch.cat(all_labels, dim=0)

        epoch_metrics = self.compute_metrics(preds, labels)
        epoch_metrics['loss'] = epoch_loss
        epoch_metrics['phase'] = 'train'

        self.results.append(epoch_metrics)

        return epoch_metrics

    def test(self):
        """
        Evaluate on the test loader without gradient tracking.

        The resulting metrics are appended to ``self.results`` and a one-line
        summary is printed.

        Returns:
            The metrics computed by ``compute_metrics``, extended with
            ``'loss'`` (sum of the batch losses) and ``'phase'`` (``'test'``).
        """
        self.model.eval()
        self.predictor.eval()

        epoch_loss = 0
        preds, all_labels = list(), list()
        with torch.no_grad():
            for data in self.dataloaders['test']:
                g_pred, labels = self._predict_graphs(data)

                loss = self.criterion(g_pred, labels.to(device))
                epoch_loss += loss.item()

                preds.append(g_pred.cpu().detach())
                all_labels.append(labels.cpu())

            preds = torch.cat(preds, dim=0)
            labels = torch.cat(all_labels, dim=0)

            epoch_metrics = self.compute_metrics(preds, labels)
            epoch_metrics['loss'] = epoch_loss
            epoch_metrics['phase'] = 'test'
            self.results.append(epoch_metrics)

            # train() and test() each append one entry per epoch, hence the
            # integer division by two to recover the epoch number.
            summary = ' | '.join(
                f'{name.replace("_", " ").title()}: {value:.4f}'
                for name, value in epoch_metrics.items()
                if name != 'phase'
            )
            print(f"Epoch: {len(self.results)//2} {summary}")

        return epoch_metrics
@@ -0,0 +1,144 @@
1
+ from torch_geometric.loader import DataLoader
2
+ from torch_geometric.data import Data
3
+ import torch
4
+ from collections import defaultdict
5
+ from typing import List
6
+
7
+ from glam4cm.models.gnn_layers import (
8
+ GNNConv,
9
+ EdgeClassifer
10
+ )
11
+
12
+ from glam4cm.trainers.gnn_trainer import Trainer
13
+ from tqdm.auto import tqdm
14
+ from glam4cm.settings import device
15
+
16
+
17
class GNNLinkPredictionTrainer(Trainer):
    """
    Trainer class for GNN Link Prediction

    For every batch the model is run on the positive edges of the current
    split, the predictor scores both the positive and the negative edges, and
    the scores are trained against the targets 1 (positive) and 0 (negative).

    Each batch is expected to expose ``x``, ``<split>_pos_edge_label_index``,
    ``<split>_neg_edge_label_index`` and ``<split>_edge_mask`` for the splits
    ``train`` and ``test`` and, when ``use_edge_attrs`` is set, ``edge_attr``.
    """
    def __init__(
        self,
        model: GNNConv,
        predictor: EdgeClassifer,
        dataset: List[Data],
        cls_label='type',
        lr=1e-3,
        num_epochs=100,
        batch_size=32,
        use_edge_attrs=False,
        logs_dir='./logs'
    ) -> None:
        """
        Args:
            model: network applied to ``(x, edge_index[, edge_attr])``.
            predictor: edge-level head applied to the model output.
            dataset: graphs that are batched by a shuffling ``DataLoader``.
            cls_label: forwarded to the base trainer.
            lr: learning rate forwarded to the base trainer.
            num_epochs: number of epochs forwarded to the base trainer.
            batch_size: number of graphs per batch.
            use_edge_attrs: when true, masked ``edge_attr`` is passed to the
                model and the predictor.
            logs_dir: directory forwarded to the base trainer for logging.
        """
        super().__init__(
            model=model,
            predictor=predictor,
            lr=lr,
            cls_label=cls_label,
            num_epochs=num_epochs,
            use_edge_attrs=use_edge_attrs,
            logs_dir=logs_dir
        )
        self.dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True
        )
        self.results = list()

        print("GNN Trainer initialized.")

    def _score_split(self, data, split):
        """Return ``(pos_scores, neg_scores)`` for the ``split`` ('train'/'test') edges of a batch."""
        pos_edge_index = getattr(data, f"{split}_pos_edge_label_index")
        neg_edge_index = getattr(data, f"{split}_neg_edge_label_index")
        edge_mask = getattr(data, f"{split}_edge_mask")
        edge_attr = data.edge_attr[edge_mask] if self.use_edge_attrs else None

        h = self.get_logits(data.x, pos_edge_index, edge_attr)

        pos_scores = self.get_prediction_score(h, pos_edge_index, edge_attr)
        # NOTE(review): the attributes selected for the positive edges are also
        # passed when scoring the negative edges (unchanged behaviour) --
        # confirm the predictor expects this when use_edge_attrs is enabled.
        neg_scores = self.get_prediction_score(h, neg_edge_index, edge_attr)
        return pos_scores, neg_scores

    @staticmethod
    def _collect(pos_scores, neg_scores):
        """Return CPU ``(predictions, targets)`` for metric computation (1 = positive, 0 = negative)."""
        targets = torch.cat([
            torch.ones(pos_scores.size(0)),
            torch.zeros(neg_scores.size(0))
        ])
        predictions = torch.cat([
            pos_scores.detach().cpu(),
            neg_scores.detach().cpu()
        ])
        return predictions, targets

    def train(self):
        """
        Run one optimisation pass over the dataloader.

        Returns:
            The metrics computed by ``compute_metrics``, extended with
            ``'loss'`` (sum of the batch losses) and ``'phase'`` (``'train'``).
        """
        self.model.train()
        self.predictor.train()

        all_preds, all_labels = list(), list()
        epoch_loss = 0
        for data in tqdm(self.dataloader, desc='Training Batches'):
            # The optimizer owns the parameters of both model and predictor,
            # so a single zero_grad() clears every gradient.
            self.optimizer.zero_grad()

            pos_scores, neg_scores = self._score_split(data, 'train')
            loss = self.compute_loss(pos_scores, neg_scores)

            loss.backward()
            self.optimizer.step()
            epoch_loss += loss.item()

            predictions, targets = self._collect(pos_scores, neg_scores)
            all_preds.append(predictions)
            all_labels.append(targets)

        # The base class builds CosineAnnealingLR with T_max=num_epochs, so the
        # schedule has to advance once per epoch rather than once per batch.
        self.scheduler.step()

        all_preds = torch.cat(all_preds, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        epoch_metrics = self.compute_metrics(all_preds, all_labels)
        epoch_metrics['loss'] = epoch_loss
        epoch_metrics['phase'] = 'train'

        return epoch_metrics

    def test(self):
        """
        Evaluate on the test edges of every batch without gradient tracking.

        The resulting metrics are appended to ``self.results`` and printed.

        Returns:
            The metrics computed by ``compute_metrics``, extended with
            ``'loss'`` (sum of the batch losses) and ``'phase'`` (``'test'``).
        """
        self.model.eval()
        self.predictor.eval()

        all_preds, all_labels = list(), list()
        epoch_loss = 0
        with torch.no_grad():
            for data in tqdm(self.dataloader, desc='Testing Batches'):
                pos_score, neg_score = self._score_split(data, 'test')

                loss = self.compute_loss(pos_score, neg_score)
                epoch_loss += loss.item()

                predictions, targets = self._collect(pos_score, neg_score)
                all_preds.append(predictions)
                all_labels.append(targets)

            all_preds = torch.cat(all_preds, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            epoch_metrics = self.compute_metrics(all_preds, all_labels)

            epoch_metrics['loss'] = epoch_loss
            epoch_metrics['phase'] = 'test'
            self.results.append(epoch_metrics)

            print(f"Test Epoch: {len(self.results)}\n{epoch_metrics}")

        return epoch_metrics

    def compute_loss(self, pos_score, neg_score):
        """
        Apply ``self.criterion`` to the concatenated positive/negative scores.

        Positive rows receive the integer target 1 and negative rows the
        integer target 0.

        Returns:
            The loss tensor produced by ``self.criterion``.
        """
        pos_label = torch.ones(pos_score.size(0), dtype=torch.long).to(device)
        neg_label = torch.zeros(neg_score.size(0), dtype=torch.long).to(device)

        scores = torch.cat([pos_score, neg_score], dim=0)
        labels = torch.cat([pos_label, neg_label], dim=0)

        return self.criterion(scores, labels)
@@ -0,0 +1,135 @@
1
+ from typing import List
2
+ from torch_geometric.loader import DataLoader
3
+ import torch
4
+ from collections import defaultdict
5
+ from torch_geometric.data import Data
6
+ from glam4cm.models.gnn_layers import (
7
+ GNNConv,
8
+ NodeClassifier
9
+ )
10
+ from glam4cm.trainers.gnn_trainer import Trainer
11
+ from glam4cm.settings import device
12
+
13
+
14
class GNNNodeClassificationTrainer(Trainer):
    """
    Trainer class for GNN Node Classification

    The model is applied to the whole batch, the predictor scores every node,
    and the nodes selected by ``train_node_mask`` / ``test_node_mask`` are
    compared against the targets stored in ``node_<cls_label>``. Nodes whose
    target is listed in ``exclude_labels`` are left out of loss and metrics.
    """
    def __init__(
        self,
        model: GNNConv,
        predictor: NodeClassifier,
        dataset: List[Data],
        cls_label,
        exclude_labels=None,
        lr=1e-3,
        num_epochs=100,
        batch_size=32,
        use_edge_attrs=False,
        logs_dir='./logs'
    ) -> None:
        """
        Args:
            model: network applied to ``(x, edge_index[, edge_attr])``.
            predictor: node-level head applied to the model output.
            dataset: graphs that are batched by a shuffling ``DataLoader``.
            cls_label: suffix of the batch attribute ``node_<cls_label>``
                that holds the node targets.
            exclude_labels: target values to ignore; ``None`` (the default)
                excludes nothing.
            lr: learning rate forwarded to the base trainer.
            num_epochs: number of epochs forwarded to the base trainer.
            batch_size: number of graphs per batch.
            use_edge_attrs: when true, ``edge_attr`` is passed to the model.
            logs_dir: directory forwarded to the base trainer for logging.
        """
        super().__init__(
            model=model,
            predictor=predictor,
            cls_label=cls_label,
            lr=lr,
            num_epochs=num_epochs,
            use_edge_attrs=use_edge_attrs,
            logs_dir=logs_dir
        )

        # torch.tensor(None) raises, which made the documented default
        # unusable; treat None as "no label is excluded".
        if exclude_labels is None:
            exclude_labels = []
        self.exclude_labels = torch.tensor(exclude_labels, dtype=torch.long)
        self.results = list()
        self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        print("GNN Trainer initialized.")

    def _score_nodes(self, data, node_mask):
        """Return ``(scores, labels)`` for the masked nodes of a batch, minus excluded labels."""
        h = self.get_logits(
            data.x,
            data.edge_index,
            data.edge_attr if self.use_edge_attrs else None
        )
        scores = self.get_prediction_score(h)[node_mask]
        labels = getattr(data, f"node_{self.cls_label}")[node_mask]

        if self.exclude_labels.numel() > 0:
            keep = ~torch.isin(labels, self.exclude_labels.to(labels.device))
            labels = labels[keep]
            scores = scores[keep.to(scores.device)]
        return scores, labels

    def train(self):
        """
        Run one optimisation pass over the dataloader.

        Batches without any remaining node after masking are skipped, because
        the loss of an empty selection is NaN.

        Returns:
            The metrics computed by ``compute_metrics``, extended with
            ``'loss'`` (sum of the batch losses) and ``'phase'`` (``'train'``).
        """
        self.model.train()
        self.predictor.train()

        all_preds, all_labels = list(), list()
        epoch_loss = 0
        for data in self.dataloader:
            # The optimizer owns the parameters of both model and predictor,
            # so a single zero_grad() clears every gradient.
            self.optimizer.zero_grad()

            scores, labels = self._score_nodes(data, data.train_node_mask)
            if labels.numel() == 0:
                continue

            loss = self.criterion(scores, labels.to(device))

            loss.backward()
            self.optimizer.step()
            epoch_loss += loss.item()

            all_preds.append(scores.detach().cpu())
            all_labels.append(labels.detach().cpu())

        # The base class builds CosineAnnealingLR with T_max=num_epochs, so the
        # schedule has to advance once per epoch rather than once per batch.
        self.scheduler.step()

        all_preds = torch.cat(all_preds, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        epoch_metrics = self.compute_metrics(all_preds, all_labels)
        epoch_metrics['loss'] = epoch_loss
        epoch_metrics['phase'] = 'train'

        return epoch_metrics

    def test(self):
        """
        Evaluate on the test nodes of every batch without gradient tracking.

        Batches without any remaining node after masking are skipped. The
        resulting metrics are appended to ``self.results`` and printed.

        Returns:
            The metrics computed by ``compute_metrics``, extended with
            ``'loss'`` (sum of the batch losses) and ``'phase'`` (``'test'``).
        """
        self.model.eval()
        self.predictor.eval()

        all_preds, all_labels = list(), list()
        epoch_loss = 0
        with torch.no_grad():
            for data in self.dataloader:
                scores, labels = self._score_nodes(data, data.test_node_mask)
                if labels.numel() == 0:
                    continue

                loss = self.criterion(scores, labels.to(device))
                epoch_loss += loss.item()

                all_preds.append(scores.detach().cpu())
                all_labels.append(labels.detach().cpu())

            all_preds = torch.cat(all_preds, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            epoch_metrics = self.compute_metrics(all_preds, all_labels)

            epoch_metrics['loss'] = epoch_loss
            epoch_metrics['phase'] = 'test'
            self.results.append(epoch_metrics)

            print(f"Epoch: {len(self.results)}\n{epoch_metrics}")

        return epoch_metrics
@@ -0,0 +1,129 @@
1
+ from abc import abstractmethod
2
+ import torch
3
+ from typing import Union
4
+ import pandas as pd
5
+
6
+ from glam4cm.models.gnn_layers import (
7
+ GNNConv,
8
+ EdgeClassifer,
9
+ NodeClassifier
10
+ )
11
+ from glam4cm.settings import device
12
+ from itertools import chain
13
+ from tqdm.auto import tqdm
14
+ import torch.nn as nn
15
+ from torch.optim.lr_scheduler import CosineAnnealingLR
16
+ from torch.optim import Adam
17
+ from torch_geometric.loader import DataLoader
18
+
19
+ from tensorboardX import SummaryWriter
20
+ from glam4cm.trainers.metrics import compute_classification_metrics
21
+
22
+
23
class Trainer:
    """
    Base trainer shared by the task-specific GNN trainers.

    It owns the model/predictor pair, a joint Adam optimizer with a cosine
    learning-rate schedule, a cross-entropy criterion, the result history and
    a TensorBoard writer. Subclasses provide ``train`` and ``test``.

    Note: ``train``/``test`` are marked with ``@abstractmethod`` but the class
    does not derive from ``abc.ABC``, so the markers are not enforced.
    """
    def __init__(
        self,
        model: "GNNConv",
        predictor: "Union[EdgeClassifer, NodeClassifier]",
        cls_label,
        lr=1e-3,
        num_epochs=100,
        use_edge_attrs=False,

        logs_dir='./logs'
    ) -> None:
        """
        Args:
            model: network whose output is fed to ``predictor``.
            predictor: task head applied to the model output.
            cls_label: name of the target attribute; stored for subclasses.
            lr: learning rate of the Adam optimizer.
            num_epochs: number of epochs executed by ``run``; also used as
                ``T_max`` of the cosine schedule.
            use_edge_attrs: stored for subclasses that pass edge attributes.
            logs_dir: directory the ``SummaryWriter`` writes to.
        """
        self.model = model
        self.predictor = predictor
        self.model.to(device)
        self.predictor.to(device)

        self.cls_label = cls_label
        self.num_epochs = num_epochs

        # One optimizer over the parameters of both modules.
        self.optimizer = Adam(chain(model.parameters(), predictor.parameters()), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs)

        self.edge2index = lambda g: torch.stack(list(g.edges())).contiguous()
        self.results = list()
        self.criterion = nn.CrossEntropyLoss()

        self.use_edge_attrs = use_edge_attrs

        self.logs_dir = logs_dir

        self.writer = SummaryWriter(log_dir=self.logs_dir)

        print("GNN Trainer initialized.")

    def set_dataloader(self, dataset, batch_size):
        """Replace ``self.dataloader`` with a shuffling loader over ``dataset``."""
        self.dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True
        )

    @abstractmethod
    def train(self):
        """Run one training epoch and return its metrics (implemented by subclasses)."""
        pass

    @abstractmethod
    def test(self):
        """Run one evaluation epoch and return its metrics (implemented by subclasses)."""
        pass

    def get_logits(self, x, edge_index, edge_attr=None):
        """
        Move the inputs to ``device`` and apply the model.

        ``edge_attr`` is only passed to the model when it is not ``None``.
        """
        edge_index = edge_index.to(device)
        x = x.to(device)

        if edge_attr is None:
            return self.model(x, edge_index)
        return self.model(x, edge_index, edge_attr.to(device))

    def get_prediction_score(self, h, edge_index=None, edge_attr=None):
        """
        Move the inputs to ``device`` and apply the predictor.

        The predictor is called with ``(h, edge_index, edge_attr)`` when edge
        attributes are given, with ``(h, edge_index)`` when only an edge index
        is given, and with ``(h)`` otherwise.
        """
        h = h.to(device)
        if edge_attr is not None:
            edge_attr = edge_attr.to(device)
            edge_index = edge_index.to(device)
            return self.predictor(h, edge_index, edge_attr)
        if edge_index is not None:
            return self.predictor(h, edge_index.to(device))
        return self.predictor(h)

    def _metrics_frame(self):
        """Return ``self.results`` as a DataFrame indexed from 1, with an ``epoch`` column."""
        df = pd.DataFrame(self.results)
        # Relabel the rows in place. Passing ``index=`` to the DataFrame
        # constructor on an existing frame would *reindex* instead, dropping
        # the first entry and leaving the last row filled with NaN.
        df.index = range(1, len(df) + 1)
        df['epoch'] = df.index
        return df

    def plot_metrics(self):
        """Plot every recorded metric of the ``'test'`` phase against its entry number."""
        df = self._metrics_frame()
        if df.empty:
            # Nothing recorded yet; there is no 'phase' column to filter on.
            return

        columns = [c for c in df.columns if c not in ['epoch', 'phase']]
        df.loc[df['phase'] == 'test'].plot(x='epoch', y=columns, kind='line')

    def run(self):
        """
        Alternate ``train`` and ``test`` for ``num_epochs`` epochs.

        Every metric except ``'phase'`` is written to the TensorBoard writer
        under ``train/<name>`` and ``test/<name>``. The writer is closed when
        the loop ends, including when an epoch raises.
        """
        try:
            for epoch in tqdm(range(self.num_epochs), desc="Running Epochs"):
                train_metrics = self.train()
                test_metrics = self.test()

                for k, v in train_metrics.items():
                    if k != 'phase':
                        self.writer.add_scalar(f"train/{k}", v, epoch)

                for k, v in test_metrics.items():
                    if k != 'phase':
                        self.writer.add_scalar(f"test/{k}", v, epoch)
        finally:
            self.writer.close()

    def compute_metrics(self, all_preds, all_labels):
        """Delegate to ``compute_classification_metrics``."""
        return compute_classification_metrics(all_preds, all_labels)