glam4cm 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glam4cm/__init__.py +2 -1
- glam4cm/data_loading/data.py +90 -146
- glam4cm/data_loading/encoding.py +17 -6
- glam4cm/data_loading/graph_dataset.py +192 -57
- glam4cm/data_loading/metadata.py +1 -1
- glam4cm/data_loading/models_dataset.py +42 -18
- glam4cm/downstream_tasks/bert_edge_classification.py +49 -22
- glam4cm/downstream_tasks/bert_graph_classification.py +44 -14
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +47 -24
- glam4cm/downstream_tasks/bert_link_prediction.py +46 -26
- glam4cm/downstream_tasks/bert_node_classification.py +127 -89
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +61 -15
- glam4cm/downstream_tasks/common_args.py +32 -4
- glam4cm/downstream_tasks/gnn_edge_classification.py +24 -7
- glam4cm/downstream_tasks/gnn_graph_cls.py +19 -6
- glam4cm/downstream_tasks/gnn_link_prediction.py +25 -13
- glam4cm/downstream_tasks/gnn_node_classification.py +19 -7
- glam4cm/downstream_tasks/utils.py +16 -2
- glam4cm/embeddings/bert.py +1 -1
- glam4cm/embeddings/common.py +7 -4
- glam4cm/encoding/encoders.py +1 -1
- glam4cm/lang2graph/archimate.py +0 -5
- glam4cm/lang2graph/common.py +99 -41
- glam4cm/lang2graph/ecore.py +1 -2
- glam4cm/lang2graph/ontouml.py +8 -7
- glam4cm/models/gnn_layers.py +20 -6
- glam4cm/models/hf.py +2 -2
- glam4cm/run.py +13 -9
- glam4cm/run_conf_v2.py +405 -0
- glam4cm/run_configs.py +70 -106
- glam4cm/run_confs.py +41 -0
- glam4cm/settings.py +15 -2
- glam4cm/tokenization/special_tokens.py +23 -1
- glam4cm/tokenization/utils.py +23 -4
- glam4cm/trainers/cm_gpt_trainer.py +1 -1
- glam4cm/trainers/gnn_edge_classifier.py +12 -1
- glam4cm/trainers/gnn_graph_classifier.py +12 -5
- glam4cm/trainers/gnn_link_predictor.py +18 -3
- glam4cm/trainers/gnn_link_predictor_v2.py +146 -0
- glam4cm/trainers/gnn_trainer.py +8 -0
- glam4cm/trainers/metrics.py +1 -1
- glam4cm/utils.py +265 -2
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/METADATA +3 -2
- glam4cm-1.0.0.dist-info/RECORD +75 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/WHEEL +1 -1
- glam4cm-0.1.0.dist-info/RECORD +0 -72
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/entry_points.txt +0 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/top_level.txt +0 -0
glam4cm/settings.py
CHANGED
@@ -9,6 +9,7 @@ logger.setLevel(logging.DEBUG)
|
|
9
9
|
|
10
10
|
|
11
11
|
BERT_MODEL = 'bert-base-uncased'
|
12
|
+
MODERN_BERT = 'answerdotai/ModernBERT-base'
|
12
13
|
WORD2VEC_MODEL = 'word2vec'
|
13
14
|
TFIDF_MODEL = 'tfidf'
|
14
15
|
FAST_TEXT_MODEL = 'uml-fasttext.bin'
|
@@ -35,12 +36,22 @@ modelsets_ecore_json_path = os.path.join(datasets_dir, 'modelset/ecore.jsonl')
|
|
35
36
|
|
36
37
|
|
37
38
|
graph_data_dir = 'datasets/graph_data'
|
39
|
+
results_dir = 'results'
|
38
40
|
|
39
41
|
# Path: settings.py
|
40
42
|
|
41
43
|
|
42
|
-
|
43
|
-
|
44
|
+
EDGE_CLS_TASK = 'edge_cls'
|
45
|
+
LINK_PRED_TASK = 'lp'
|
46
|
+
NODE_CLS_TASK = 'node_cls'
|
47
|
+
GRAPH_CLS_TASK = 'graph_cls'
|
48
|
+
DUMMY_GRAPH_CLS_TASK = 'dummy_graph_cls'
|
49
|
+
|
50
|
+
|
51
|
+
SEP = ' '
|
52
|
+
REFERENCE = 'reference'
|
53
|
+
SUPERTYPE = 'supertype'
|
54
|
+
CONTAINMENT = 'containment'
|
44
55
|
|
45
56
|
|
46
57
|
EPOCH = 'epoch'
|
@@ -52,3 +63,5 @@ TEST_ACC = 'test_acc'
|
|
52
63
|
TRAINING_PHASE = 'train'
|
53
64
|
VALIDATION_PHASE = 'val'
|
54
65
|
TESTING_PHASE = 'test'
|
66
|
+
|
67
|
+
|
@@ -1,4 +1,26 @@
|
|
1
1
|
EDGE_START = '<edge_begin>'
|
2
2
|
EDGE_END = '<edge_end>'
|
3
3
|
NODE_BEGIN = '<node_begin>'
|
4
|
-
NODE_END = '<node_end>'
|
4
|
+
NODE_END = '<node_end>'
|
5
|
+
|
6
|
+
escape_keywords = [
|
7
|
+
"EString",
|
8
|
+
"EInt",
|
9
|
+
"EBoolean",
|
10
|
+
"EFloat",
|
11
|
+
"EAttribute",
|
12
|
+
"EReference",
|
13
|
+
"EClass",
|
14
|
+
"EEnum",
|
15
|
+
"EEnumLiteral",
|
16
|
+
"EDataType",
|
17
|
+
"EOperation",
|
18
|
+
"EParameter",
|
19
|
+
"ETypeParameter",
|
20
|
+
"EAnnotation",
|
21
|
+
"stereotype",
|
22
|
+
EDGE_START,
|
23
|
+
EDGE_END,
|
24
|
+
NODE_BEGIN,
|
25
|
+
NODE_END
|
26
|
+
]
|
glam4cm/tokenization/utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from re import finditer
|
2
|
+
from typing import List
|
2
3
|
from glam4cm.tokenization.special_tokens import (
|
3
|
-
EDGE_START, EDGE_END, NODE_BEGIN, NODE_END
|
4
|
+
EDGE_START, EDGE_END, NODE_BEGIN, NODE_END, escape_keywords
|
4
5
|
)
|
5
6
|
from transformers import AutoTokenizer
|
6
7
|
|
@@ -24,6 +25,8 @@ def get_tokenizer(model_name, use_special_tokens=False, max_length=512) -> AutoT
|
|
24
25
|
|
25
26
|
|
26
27
|
def camel_case_split(identifier) -> list:
|
28
|
+
if any(ek in identifier for ek in escape_keywords):
|
29
|
+
return [identifier]
|
27
30
|
matches = finditer('.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)', identifier)
|
28
31
|
return [m.group(0) for m in matches]
|
29
32
|
|
@@ -31,7 +34,23 @@ def camel_case_split(identifier) -> list:
|
|
31
34
|
def doc_tokenizer(doc, lower=False) -> str:
|
32
35
|
words = doc.split()
|
33
36
|
# split _
|
34
|
-
|
37
|
+
snake_words: List[str] = list()
|
38
|
+
for w1 in words:
|
39
|
+
if any(ek in w1 for ek in escape_keywords):
|
40
|
+
snake_words.append(w1)
|
41
|
+
else:
|
42
|
+
snake_words.extend([w2 for w2 in w1.split('_') if w2 != ''])
|
43
|
+
|
44
|
+
|
35
45
|
# camelcase
|
36
|
-
|
37
|
-
|
46
|
+
final_words: List[str] = list()
|
47
|
+
for word in snake_words:
|
48
|
+
if any(ek in word for ek in escape_keywords):
|
49
|
+
final_words.append(word)
|
50
|
+
else:
|
51
|
+
final_words.extend(camel_case_split(word))
|
52
|
+
|
53
|
+
if lower:
|
54
|
+
final_words = [w.lower() for w in final_words]
|
55
|
+
|
56
|
+
return " ".join(final_words)
|
@@ -52,7 +52,7 @@ class CMGPTTrainer:
|
|
52
52
|
self.compute_metrics = compute_metrics
|
53
53
|
|
54
54
|
print(f"Number of parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)/ 1000000:.3f}M")
|
55
|
-
|
55
|
+
print(f"Logging to: {log_dir}")
|
56
56
|
|
57
57
|
def step(self, batch, idx=None):
|
58
58
|
# B, T = batch['input_ids'].shape
|
@@ -57,6 +57,9 @@ class GNNEdgeClassificationTrainer(Trainer):
|
|
57
57
|
all_preds, all_labels = list(), list()
|
58
58
|
epoch_loss = 0
|
59
59
|
epoch_metrics = defaultdict(float)
|
60
|
+
# print("Total dataloader size: ", len(self.dataloader))
|
61
|
+
# from tqdm.auto import tqdm
|
62
|
+
# for data in tqdm(self.dataloader):
|
60
63
|
for data in self.dataloader:
|
61
64
|
self.optimizer.zero_grad()
|
62
65
|
self.model.zero_grad()
|
@@ -81,7 +84,10 @@ class GNNEdgeClassificationTrainer(Trainer):
|
|
81
84
|
|
82
85
|
all_preds = torch.cat(all_preds, dim=0)
|
83
86
|
all_labels = torch.cat(all_labels, dim=0)
|
87
|
+
# import time
|
88
|
+
# t1 = time.time()
|
84
89
|
epoch_metrics = self.compute_metrics(all_preds, all_labels)
|
90
|
+
# print(f"Time taken: {time.time() - t1}")
|
85
91
|
epoch_metrics['loss'] = epoch_loss
|
86
92
|
epoch_metrics['phase'] = 'train'
|
87
93
|
|
@@ -99,11 +105,16 @@ class GNNEdgeClassificationTrainer(Trainer):
|
|
99
105
|
epoch_metrics = defaultdict(float)
|
100
106
|
for data in self.dataloader:
|
101
107
|
x = data.x
|
108
|
+
train_edge_index = data.train_pos_edge_label_index
|
109
|
+
train_mask = data.train_edge_mask
|
110
|
+
train_edge_attr = data.edge_attr[train_mask] if self.use_edge_attrs else None
|
111
|
+
|
102
112
|
edge_index = data.test_pos_edge_label_index
|
103
113
|
test_mask = data.test_edge_mask
|
104
114
|
edge_attr = data.edge_attr[test_mask] if self.use_edge_attrs else None
|
105
115
|
|
106
|
-
h = self.get_logits(x,
|
116
|
+
h = self.get_logits(x, train_edge_index, train_edge_attr)
|
117
|
+
|
107
118
|
scores = self.get_prediction_score(h, edge_index, edge_attr)
|
108
119
|
labels = getattr(data, f"edge_{self.cls_label}")[test_mask]
|
109
120
|
all_preds.append(scores.detach().cpu())
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import List, Tuple
|
1
|
+
from typing import Dict, List, Tuple
|
2
2
|
import torch
|
3
3
|
from collections import defaultdict
|
4
4
|
from torch_geometric.loader import DataLoader
|
@@ -22,7 +22,7 @@ class GNNGraphClassificationTrainer(Trainer):
|
|
22
22
|
self,
|
23
23
|
model: GNNConv,
|
24
24
|
predictor: GraphClassifer,
|
25
|
-
dataset:
|
25
|
+
dataset: Dict[str, List[Data]],
|
26
26
|
cls_label='label',
|
27
27
|
lr=1e-4,
|
28
28
|
num_epochs=100,
|
@@ -43,8 +43,14 @@ class GNNGraphClassificationTrainer(Trainer):
|
|
43
43
|
|
44
44
|
self.cls_label = cls_label
|
45
45
|
self.dataloaders = dict()
|
46
|
-
self.dataloaders['train'] = DataLoader(
|
47
|
-
|
46
|
+
self.dataloaders['train'] = DataLoader(
|
47
|
+
[g for g in dataset['train'] if len(g.edge_index) != 0],
|
48
|
+
batch_size=batch_size, shuffle=True
|
49
|
+
)
|
50
|
+
self.dataloaders['test'] = DataLoader(
|
51
|
+
[g for g in dataset['test'] if len(g.edge_index) != 0],
|
52
|
+
batch_size=batch_size, shuffle=False
|
53
|
+
)
|
48
54
|
|
49
55
|
self.results = list()
|
50
56
|
|
@@ -120,4 +126,5 @@ class GNNGraphClassificationTrainer(Trainer):
|
|
120
126
|
s2t = lambda x: x.replace("_", " ").title()
|
121
127
|
print(f"Epoch: {len(self.results)//2} {' | '.join([f'{s2t(k)}: {v:.4f}' for k, v in epoch_metrics.items() if k != 'phase'])}")
|
122
128
|
|
123
|
-
return epoch_metrics
|
129
|
+
return epoch_metrics
|
130
|
+
|
@@ -58,6 +58,12 @@ class GNNLinkPredictionTrainer(Trainer):
|
|
58
58
|
all_preds, all_labels = list(), list()
|
59
59
|
epoch_loss = 0
|
60
60
|
epoch_metrics = defaultdict(float)
|
61
|
+
|
62
|
+
total_pos_edges = sum([data.train_pos_edge_label_index.size(1) for data in self.dataloader.dataset])
|
63
|
+
total_neg_edges = sum([data.train_neg_edge_label_index.size(1) for data in self.dataloader.dataset])
|
64
|
+
print(f"Total positive edges: {total_pos_edges}")
|
65
|
+
print(f"Total negative edges: {total_neg_edges}")
|
66
|
+
|
61
67
|
for data in tqdm(self.dataloader, desc='Training Batches'):
|
62
68
|
self.optimizer.zero_grad()
|
63
69
|
self.model.zero_grad()
|
@@ -103,14 +109,23 @@ class GNNLinkPredictionTrainer(Trainer):
|
|
103
109
|
for data in tqdm(self.dataloader, desc='Testing Batches'):
|
104
110
|
|
105
111
|
x = data.x
|
112
|
+
|
113
|
+
train_edge_index = torch.cat([
|
114
|
+
data.train_pos_edge_label_index,
|
115
|
+
data.train_neg_edge_label_index
|
116
|
+
], dim=1)
|
117
|
+
train_edge_attr = (
|
118
|
+
data.edge_attr[data.train_edge_mask]
|
119
|
+
if self.use_edge_attrs else None
|
120
|
+
)
|
121
|
+
|
122
|
+
h = self.get_logits(x, train_edge_index, train_edge_attr)
|
123
|
+
|
106
124
|
pos_edge_index = data.test_pos_edge_label_index
|
107
125
|
neg_edge_index = data.test_neg_edge_label_index
|
108
126
|
test_mask = data.test_edge_mask
|
109
127
|
edge_attr = data.edge_attr[test_mask] if self.use_edge_attrs else None
|
110
128
|
|
111
|
-
|
112
|
-
h = self.get_logits(x, pos_edge_index, edge_attr)
|
113
|
-
# h = x
|
114
129
|
pos_score = self.get_prediction_score(h, pos_edge_index, edge_attr)
|
115
130
|
neg_score = self.get_prediction_score(h, neg_edge_index, edge_attr)
|
116
131
|
|
@@ -0,0 +1,146 @@
|
|
1
|
+
from sklearn.metrics import roc_auc_score, average_precision_score
|
2
|
+
from torch_geometric.loader import DataLoader
|
3
|
+
from torch_geometric.data import Data
|
4
|
+
from torch_geometric.nn import GATConv, VGAE
|
5
|
+
import torch.nn.functional as F
|
6
|
+
import torch
|
7
|
+
from typing import List
|
8
|
+
|
9
|
+
from glam4cm.models.gnn_layers import (
|
10
|
+
GNNConv,
|
11
|
+
EdgeClassifer
|
12
|
+
)
|
13
|
+
|
14
|
+
from tqdm.auto import tqdm
|
15
|
+
from glam4cm.settings import device
|
16
|
+
|
17
|
+
|
18
|
+
class GATVGAEEncoder(torch.nn.Module):
|
19
|
+
def __init__(self, in_channels, hid_channels, out_channels, heads=(4,2), dropout=0.3):
|
20
|
+
super().__init__()
|
21
|
+
self.conv1 = GATConv(in_channels, hid_channels, heads=heads[0], dropout=dropout)
|
22
|
+
# mu and log_std each map to latent dim
|
23
|
+
self.conv_mu = GATConv(hid_channels * heads[0], out_channels, heads=heads[1], concat=False)
|
24
|
+
self.conv_logstd = GATConv(hid_channels * heads[0], out_channels, heads=heads[1], concat=False)
|
25
|
+
|
26
|
+
def forward(self, x, edge_index):
|
27
|
+
x = F.elu(self.conv1(x, edge_index))
|
28
|
+
return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)
|
29
|
+
|
30
|
+
|
31
|
+
|
32
|
+
class GNNLinkPredictionTrainerV2:
|
33
|
+
"""
|
34
|
+
Trainer class for GNN Link Prediction
|
35
|
+
This class is used to train the GNN model for the link prediction task
|
36
|
+
The model is trained to predict the link between two nodes
|
37
|
+
"""
|
38
|
+
def __init__(
|
39
|
+
self,
|
40
|
+
model: GNNConv,
|
41
|
+
predictor: EdgeClassifer,
|
42
|
+
dataset: List[Data],
|
43
|
+
cls_label='type',
|
44
|
+
lr=1e-3,
|
45
|
+
num_epochs=100,
|
46
|
+
batch_size=32,
|
47
|
+
use_edge_attrs=False,
|
48
|
+
logs_dir='./logs'
|
49
|
+
) -> None:
|
50
|
+
|
51
|
+
self.num_epochs = num_epochs
|
52
|
+
self.lr = lr
|
53
|
+
in_dim = dataset[0].data.x.shape[1]
|
54
|
+
hid_dim = 64
|
55
|
+
out_dim = 32
|
56
|
+
self.encoder = GATVGAEEncoder(in_dim, hid_dim, out_dim).to(device)
|
57
|
+
self.model = VGAE(self.encoder).to(device)
|
58
|
+
self.opt = torch.optim.Adam(model.parameters(), lr=lr)
|
59
|
+
|
60
|
+
self.dataloader = DataLoader(
|
61
|
+
dataset, batch_size=batch_size, shuffle=True
|
62
|
+
)
|
63
|
+
self.results = list()
|
64
|
+
|
65
|
+
print("GNN Trainer initialized.")
|
66
|
+
|
67
|
+
|
68
|
+
|
69
|
+
def train(self):
|
70
|
+
self.model.train()
|
71
|
+
total_loss = 0.0
|
72
|
+
|
73
|
+
for data in self.dataloader:
|
74
|
+
data = data.to(device)
|
75
|
+
# Encode over the **train positive** graph only:
|
76
|
+
z = self.model.encode(data.x, data.train_pos_edge_label_index)
|
77
|
+
|
78
|
+
# recon_loss only on positives:
|
79
|
+
loss = self.model.recon_loss(z, data.train_pos_edge_label_index)
|
80
|
+
# KL regularizer:
|
81
|
+
loss += (1. / data.num_nodes) * self.model.kl_loss()
|
82
|
+
|
83
|
+
self.opt.zero_grad()
|
84
|
+
loss.backward()
|
85
|
+
self.opt.step()
|
86
|
+
|
87
|
+
total_loss += loss.item()
|
88
|
+
|
89
|
+
return total_loss / len(self.dataloader)
|
90
|
+
|
91
|
+
|
92
|
+
@torch.no_grad()
|
93
|
+
def test(self):
|
94
|
+
self.model.eval()
|
95
|
+
all_auc, all_ap = [], []
|
96
|
+
|
97
|
+
for data in self.dataloader:
|
98
|
+
data = data.to(device)
|
99
|
+
z = self.model.encode(data.x, data.train_pos_edge_label_index)
|
100
|
+
|
101
|
+
# positive edges from your test split
|
102
|
+
pos_idx = data.test_pos_edge_label_index
|
103
|
+
# generate equal‐size negative sample
|
104
|
+
neg_idx = data.test_neg_edge_label_index
|
105
|
+
|
106
|
+
pos_scores = self.model.decoder(z, pos_idx).sigmoid()
|
107
|
+
neg_scores = self.model.decoder(z, neg_idx).sigmoid()
|
108
|
+
|
109
|
+
y_true = torch.cat([torch.ones(pos_scores.size(0)),
|
110
|
+
torch.zeros(neg_scores.size(0))]).cpu()
|
111
|
+
y_score = torch.cat([pos_scores, neg_scores]).cpu()
|
112
|
+
|
113
|
+
all_auc.append( roc_auc_score(y_true, y_score) )
|
114
|
+
all_ap.append( average_precision_score(y_true, y_score) )
|
115
|
+
|
116
|
+
return {
|
117
|
+
'AUC': sum(all_auc) / len(all_auc),
|
118
|
+
'AP': sum(all_ap) / len(all_ap),
|
119
|
+
}
|
120
|
+
|
121
|
+
|
122
|
+
def compute_loss(self, pos_score, neg_score):
|
123
|
+
pos_label = torch.ones(pos_score.size(0), dtype=torch.long).to(device)
|
124
|
+
neg_label = torch.zeros(neg_score.size(0), dtype=torch.long).to(device)
|
125
|
+
|
126
|
+
scores = torch.cat([pos_score, neg_score], dim=0)
|
127
|
+
labels = torch.cat([pos_label, neg_label], dim=0)
|
128
|
+
|
129
|
+
loss = self.criterion(scores, labels)
|
130
|
+
return loss
|
131
|
+
|
132
|
+
|
133
|
+
def run(self):
|
134
|
+
all_metrics = list()
|
135
|
+
for epoch in tqdm(range(self.num_epochs), desc="Running Epochs"):
|
136
|
+
self.train()
|
137
|
+
test_metrics = self.test()
|
138
|
+
all_metrics.append(test_metrics)
|
139
|
+
print(f"Epoch {epoch+1}/{self.num_epochs} | AUC: {test_metrics['AUC']:.4f} | AP: {test_metrics['AP']:.4f}")
|
140
|
+
|
141
|
+
print("Training complete.")
|
142
|
+
best_metrics = sorted(all_metrics, key=lambda x: x['AUC'], reverse=True)[0]
|
143
|
+
|
144
|
+
s2t = lambda x: x.replace("_", " ").title()
|
145
|
+
print(f"Best: {' | '.join([f'{s2t(k)}: {v:.4f}' for k, v in best_metrics.items()])}")
|
146
|
+
|
glam4cm/trainers/gnn_trainer.py
CHANGED
@@ -111,9 +111,11 @@ class Trainer:
|
|
111
111
|
|
112
112
|
|
113
113
|
def run(self):
|
114
|
+
all_metrics = list()
|
114
115
|
for epoch in tqdm(range(self.num_epochs), desc="Running Epochs"):
|
115
116
|
train_metrics = self.train()
|
116
117
|
test_metrics = self.test()
|
118
|
+
all_metrics.append(test_metrics)
|
117
119
|
|
118
120
|
for k, v in train_metrics.items():
|
119
121
|
if k != 'phase':
|
@@ -124,6 +126,12 @@ class Trainer:
|
|
124
126
|
self.writer.add_scalar(f"test/{k}", v, epoch)
|
125
127
|
|
126
128
|
self.writer.close()
|
129
|
+
print("Training complete.")
|
130
|
+
best_metrics = sorted(all_metrics, key=lambda x: x['balanced_accuracy'], reverse=True)[0]
|
131
|
+
|
132
|
+
s2t = lambda x: x.replace("_", " ").title()
|
133
|
+
print(f"Best: {' | '.join([f'{s2t(k)}: {v:.4f}' for k, v in best_metrics.items() if k != 'phase'])}")
|
134
|
+
|
127
135
|
|
128
136
|
def compute_metrics(self, all_preds, all_labels):
|
129
137
|
return compute_classification_metrics(all_preds, all_labels)
|
glam4cm/trainers/metrics.py
CHANGED
@@ -20,7 +20,7 @@ def compute_metrics(p):
|
|
20
20
|
}
|
21
21
|
|
22
22
|
|
23
|
-
def compute_classification_metrics(preds, labels):
|
23
|
+
def compute_classification_metrics(preds: torch.Tensor, labels: torch.Tensor) -> dict:
|
24
24
|
"""
|
25
25
|
Compute F1-score, balanced accuracy, precision, and recall for multi-class classification.
|
26
26
|
|