glam4cm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glam4cm/__init__.py +9 -0
- glam4cm/data_loading/__init__.py +0 -0
- glam4cm/data_loading/data.py +631 -0
- glam4cm/data_loading/encoding.py +76 -0
- glam4cm/data_loading/graph_dataset.py +940 -0
- glam4cm/data_loading/metadata.py +84 -0
- glam4cm/data_loading/models_dataset.py +361 -0
- glam4cm/data_loading/utils.py +20 -0
- glam4cm/downstream_tasks/__init__.py +0 -0
- glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
- glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
- glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
- glam4cm/downstream_tasks/bert_node_classification.py +164 -0
- glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
- glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
- glam4cm/downstream_tasks/common_args.py +160 -0
- glam4cm/downstream_tasks/create_dataset.py +51 -0
- glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
- glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
- glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
- glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
- glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
- glam4cm/downstream_tasks/utils.py +35 -0
- glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
- glam4cm/embeddings/__init__.py +0 -0
- glam4cm/embeddings/bert.py +72 -0
- glam4cm/embeddings/common.py +43 -0
- glam4cm/embeddings/fasttext.py +0 -0
- glam4cm/embeddings/tfidf.py +25 -0
- glam4cm/embeddings/w2v.py +41 -0
- glam4cm/encoding/__init__.py +0 -0
- glam4cm/encoding/common.py +0 -0
- glam4cm/encoding/encoders.py +100 -0
- glam4cm/graph2str/__init__.py +0 -0
- glam4cm/graph2str/common.py +34 -0
- glam4cm/graph2str/constants.py +15 -0
- glam4cm/graph2str/ontouml.py +141 -0
- glam4cm/graph2str/uml.py +0 -0
- glam4cm/lang2graph/__init__.py +0 -0
- glam4cm/lang2graph/archimate.py +31 -0
- glam4cm/lang2graph/bpmn.py +0 -0
- glam4cm/lang2graph/common.py +416 -0
- glam4cm/lang2graph/ecore.py +221 -0
- glam4cm/lang2graph/ontouml.py +169 -0
- glam4cm/lang2graph/utils.py +80 -0
- glam4cm/models/cmgpt.py +352 -0
- glam4cm/models/gnn_layers.py +273 -0
- glam4cm/models/hf.py +10 -0
- glam4cm/run.py +99 -0
- glam4cm/run_configs.py +126 -0
- glam4cm/settings.py +54 -0
- glam4cm/tokenization/__init__.py +0 -0
- glam4cm/tokenization/special_tokens.py +4 -0
- glam4cm/tokenization/utils.py +37 -0
- glam4cm/trainers/__init__.py +0 -0
- glam4cm/trainers/bert_classifier.py +105 -0
- glam4cm/trainers/cm_gpt_trainer.py +153 -0
- glam4cm/trainers/gnn_edge_classifier.py +126 -0
- glam4cm/trainers/gnn_graph_classifier.py +123 -0
- glam4cm/trainers/gnn_link_predictor.py +144 -0
- glam4cm/trainers/gnn_node_classifier.py +135 -0
- glam4cm/trainers/gnn_trainer.py +129 -0
- glam4cm/trainers/metrics.py +55 -0
- glam4cm/utils.py +194 -0
- glam4cm-0.1.0.dist-info/LICENSE +21 -0
- glam4cm-0.1.0.dist-info/METADATA +86 -0
- glam4cm-0.1.0.dist-info/RECORD +72 -0
- glam4cm-0.1.0.dist-info/WHEEL +5 -0
- glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
- glam4cm-0.1.0.dist-info/top_level.txt +1 -0
glam4cm/downstream_tasks/bert_graph_classification.py
@@ -0,0 +1,137 @@
import os
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    recall_score
)
from transformers import (
    Trainer,
    TrainingArguments
)

from glam4cm.data_loading.graph_dataset import GraphNodeDataset
from glam4cm.models.hf import get_model
from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
from glam4cm.downstream_tasks.utils import get_models_dataset
from glam4cm.tokenization.utils import get_tokenizer
from glam4cm.utils import merge_argument_parsers, set_seed



def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    acc = (preds == labels).mean()
    f1_macro = f1_score(labels, preds, average='macro')
    accuracy = accuracy_score(labels, preds)
    recall = recall_score(labels, preds, average='macro')
    balanced_acc = balanced_accuracy_score(labels, preds)

    return {
        'balanced_accuracy': balanced_acc,
        'accuracy': acc,
        'f1_macro': f1_macro,
        'precision': accuracy,
        'recall': recall
    }


def get_parser():
    common_parser = get_common_args_parser()
    bert_parser = get_bert_args_parser()
    parser = merge_argument_parsers(common_parser, bert_parser)

    parser.add_argument('--cls_label', type=str, default='label')
    parser.add_argument('--remove_duplicate_graphs', action='store_true')
    return parser


def run(args):
    set_seed(args.seed)

    config_params = dict(
        min_enr = args.min_enr,
        min_edges = args.min_edges,
        remove_duplicates = args.remove_duplicates,
        reload = args.reload,
        language = args.language
    )
    dataset_name = args.dataset

    dataset = get_models_dataset(dataset_name, **config_params)
    graph_data_params = get_config_params(args)
    print("Loading graph dataset")
    graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
    print("Loaded graph dataset")

    model_name = args.model_name
    tokenizer = get_tokenizer(model_name, args.use_special_tokens)

    fold_id = 0
    for classification_dataset in graph_dataset.get_kfold_lm_graph_classification_data(
        tokenizer,
        remove_duplicates=args.remove_duplicate_graphs
    ):
        train_dataset = classification_dataset['train']
        test_dataset = classification_dataset['test']
        num_labels = classification_dataset['num_classes']

        print(len(train_dataset), len(test_dataset), num_labels)

        print("Training model")
        output_dir = os.path.join(
            'results',
            dataset_name,
            f'graph_cls_',
            f"{graph_dataset.config_hash}",
        )

        logs_dir = os.path.join(
            'logs',
            dataset_name,
            f'graph_cls_',
            f"{graph_dataset.config_hash}"
        )

        model = get_model(args.ckpt if args.ckpt else model_name, num_labels, len(tokenizer))

        if args.freeze_pretrained_weights:
            for param in model.base_model.parameters():
                param.requires_grad = False

        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=args.num_epochs,
            eval_strategy="steps",
            per_device_train_batch_size=args.train_batch_size,
            per_device_eval_batch_size=args.eval_batch_size,
            warmup_steps=200,
            weight_decay=0.01,
            learning_rate=5e-5,
            logging_dir=logs_dir,
            logging_steps=args.num_log_steps,
            eval_steps=args.num_eval_steps,
            save_steps=args.num_save_steps,
            save_total_limit=2,
            load_best_model_at_end=True,
            fp16=True
        )

        # Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            compute_metrics=compute_metrics
        )

        # Train the model
        trainer.train()
        results = trainer.evaluate()
        print(results)

        fold_id += 1
        break
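The compute_metrics helper above operates on a plain EvalPrediction-style object, so it can be exercised outside a Trainer run. A minimal sketch follows; the dummy labels and logits are illustrative and not taken from the package.

import numpy as np
from types import SimpleNamespace

from glam4cm.downstream_tasks.bert_graph_classification import compute_metrics

# Fake "EvalPrediction": 4 samples, 3 classes; predictions hold raw scores.
pred = SimpleNamespace(
    label_ids=np.array([0, 1, 2, 1]),
    predictions=np.array([
        [0.9, 0.05, 0.05],
        [0.1, 0.80, 0.10],
        [0.2, 0.20, 0.60],
        [0.7, 0.20, 0.10],
    ]),
)
print(compute_metrics(pred))
# -> dict with balanced_accuracy, accuracy, f1_macro, precision, recall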
glam4cm/downstream_tasks/bert_graph_classification_comp.py
@@ -0,0 +1,156 @@
import os
import json
from argparse import ArgumentParser
from random import shuffle

import numpy as np
from sklearn.model_selection import StratifiedKFold

from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    recall_score
)
from transformers import (
    AutoTokenizer,
    Trainer,
    TrainingArguments
)

from glam4cm.data_loading.encoding import EncodingDataset
from glam4cm.models.hf import get_model



def compute_metrics(pred):
    labels = pred.label_ids
    preds = np.argmax(pred.predictions, axis=1)
    acc = (preds == labels).mean()
    f1_macro = f1_score(labels, preds, average='macro')
    accuracy = accuracy_score(labels, preds)
    recall = recall_score(labels, preds, average='macro')
    balanced_acc = balanced_accuracy_score(labels, preds)

    return {
        'balanced_accuracy': balanced_acc,
        'accuracy': acc,
        'f1_macro': f1_macro,
        'precision': accuracy,
        'recall': recall
    }


def get_parser():
    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--dataset_name', type=str, default='ecore_555')
    parser.add_argument('--model_name', type=str, default='bert-base-uncased')
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--k', type=int, default=10)

    parser.add_argument('--num_epochs', type=int, default=10)

    parser.add_argument('--warmup_steps', type=int, default=500)
    parser.add_argument('--num_log_steps', type=int, default=50)
    parser.add_argument('--num_eval_steps', type=int, default=50)
    parser.add_argument('--num_save_steps', type=int, default=50)
    parser.add_argument('--train_batch_size', type=int, default=2)
    parser.add_argument('--eval_batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-5)

    return parser


def run(args):
    dataset_name = args.dataset_name
    model_name = args.model_name


    texts = [
        (g['txt'], g['labels'])
        for file_name in os.listdir(f'datasets/{dataset_name}')
        for g in json.load(open(f'datasets/{dataset_name}/{file_name}'))
        if 'ecore' in file_name and file_name.endswith('.jsonl')
    ]
    shuffle(texts)
    labels = [y for _, y in texts]
    y_map = {label: i for i, label in enumerate(set(y for y in labels))}
    y = [y_map[y] for y in labels]
    n = len(texts)

    texts = [text for text, _ in texts]

    num_labels = len(y_map)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    k = args.k
    kfold = StratifiedKFold(n_splits=k, shuffle=True, random_state=args.seed)

    i = 0
    for train_idx, test_idx in kfold.split(np.zeros(n), np.zeros(n)):

        print(f'Fold {i+1}/{k}')

        train_texts = [texts[i] for i in train_idx]
        test_texts = [texts[i] for i in test_idx]
        train_y = [y[i] for i in train_idx]
        test_y = [y[i] for i in test_idx]


        print(f'Train: {len(train_texts)}, Test: {len(test_texts)}', num_labels)

        train_dataset = EncodingDataset(tokenizer, train_texts, train_y)
        test_dataset = EncodingDataset(tokenizer, test_texts, test_y)

        model = get_model(args.ckpt if args.ckpt else model_name, num_labels, len(tokenizer))

        print("Training model")
        output_dir = os.path.join(
            'results',
            dataset_name,
            'graph_cls_comp',
        )

        logs_dir = os.path.join(
            'logs',
            dataset_name,
            'graph_cls_comp',
        )

        print("Running epochs: ", args.num_epochs)

        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=args.num_epochs,
            eval_strategy="steps",
            per_device_train_batch_size=args.train_batch_size,
            per_device_eval_batch_size=args.eval_batch_size,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir=logs_dir,
            logging_steps=10,
            eval_steps=10,
            save_total_limit=2,
            load_best_model_at_end=True,
            fp16=True
        )

        # Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            compute_metrics=compute_metrics
        )

        # Train the model
        trainer.train()
        results = trainer.evaluate()
        print(results)

        i += 1
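Because this comparison script declares its full argument parser inline, its entry points can be driven directly. A minimal invocation sketch, assuming text/label pairs have already been exported under datasets/<dataset_name>/ in the layout the list comprehension in run() expects:

from glam4cm.downstream_tasks.bert_graph_classification_comp import get_parser, run

# All flags below are declared in get_parser(); the values are illustrative.
args = get_parser().parse_args([
    '--dataset_name', 'ecore_555',
    '--model_name', 'bert-base-uncased',
    '--k', '10',
    '--num_epochs', '10',
    '--train_batch_size', '2',
])
run(args)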
glam4cm/downstream_tasks/bert_link_prediction.py
@@ -0,0 +1,145 @@
from collections import Counter
import os
from transformers import TrainingArguments, Trainer
from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
from glam4cm.models.hf import get_model
from glam4cm.settings import LP_TASK_LINK_PRED
from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
from glam4cm.downstream_tasks.utils import get_models_dataset
from glam4cm.tokenization.special_tokens import *


from sklearn.metrics import (
    f1_score,
    precision_score,
    balanced_accuracy_score,
    recall_score
)

from glam4cm.tokenization.utils import get_tokenizer
from glam4cm.utils import merge_argument_parsers, set_seed


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    print(Counter(preds), Counter(labels))
    acc = (preds == labels).mean()
    # roc = roc_auc_score(labels, preds)
    f1_macro = f1_score(labels, preds)
    f1_micro = f1_score(labels, preds, )
    precision = precision_score(labels, preds)
    balanced_accuracy = balanced_accuracy_score(labels, preds)
    recall = recall_score(labels, preds)

    return {
        'accuracy': acc,
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'precision': precision,
        'recall': recall,
        'balanced_accuracy': balanced_accuracy
    }



def get_parser():

    common_parser = get_common_args_parser()
    bert_parser = get_bert_args_parser()
    parser = merge_argument_parsers(common_parser, bert_parser)

    return parser


def run(args):
    set_seed(args.seed)

    config_params = dict(
        min_enr = args.min_enr,
        min_edges = args.min_edges,
        remove_duplicates = args.remove_duplicates,
        reload = args.reload,
        language = args.language
    )
    dataset_name = args.dataset
    dataset = get_models_dataset(dataset_name, **config_params)

    print("Loaded dataset")


    graph_data_params = get_config_params(args)
    graph_data_params = {**graph_data_params, 'task': LP_TASK_LINK_PRED}

    print("Loading graph dataset")
    graph_dataset = GraphEdgeDataset(
        dataset,
        dict(
            **graph_data_params,
            add_negative_train_samples=args.add_negative_train_samples,
            neg_sampling_ratio=args.neg_sampling_ratio,
            task=LP_TASK_LINK_PRED
        ))
    print("Loaded graph dataset")



    model_name = args.model_name
    tokenizer = get_tokenizer(model_name, args.use_special_tokens)


    print("Getting link prediction data")
    bert_dataset = graph_dataset.get_link_prediction_lm_data(
        tokenizer=tokenizer,
        task_type=LP_TASK_LINK_PRED
    )

    print("Training model")
    model = get_model(args.ckpt if args.ckpt else model_name, num_labels=2, len_tokenizer=len(tokenizer))

    if args.freeze_pretrained_weights:
        for param in model.base_model.parameters():
            param.requires_grad = False


    output_dir = os.path.join(
        'results',
        dataset_name,
        'lp',
        f"{graph_dataset.config_hash}",
    )

    logs_dir = os.path.join(
        'logs',
        dataset_name,
        'lp',
        f"{graph_dataset.config_hash}",
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        weight_decay=0.01,
        logging_dir=logs_dir,
        logging_steps=200,
        eval_strategy='steps',
        eval_steps=200,
        save_steps=200,
        save_total_limit=2,
        load_best_model_at_end=True,
        fp16=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=bert_dataset['train'],
        eval_dataset=bert_dataset['test'],
        compute_metrics=compute_metrics
    )

    trainer.train()
    print(trainer.evaluate())
    trainer.save_model()
glam4cm/downstream_tasks/bert_node_classification.py
@@ -0,0 +1,164 @@
from glam4cm.models.hf import get_model
from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
import os
from transformers import TrainingArguments, Trainer
from glam4cm.data_loading.graph_dataset import GraphNodeDataset
from glam4cm.data_loading.utils import oversample_dataset
from glam4cm.downstream_tasks.utils import get_models_dataset
from glam4cm.tokenization.special_tokens import *


from sklearn.metrics import (
    accuracy_score,
    f1_score,
    recall_score,
    balanced_accuracy_score
)

from glam4cm.tokenization.utils import get_tokenizer
from glam4cm.utils import merge_argument_parsers, set_seed



def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    acc = (preds == labels).mean()
    f1_macro = f1_score(labels, preds, average='macro')
    accuracy = accuracy_score(labels, preds)
    recall = recall_score(labels, preds, average='macro')
    balanced_acc = balanced_accuracy_score(labels, preds)

    return {
        'balanced_accuracy': balanced_acc,
        'accuracy': acc,
        'f1_macro': f1_macro,
        'precision': accuracy,
        'recall': recall
    }


def get_num_labels(dataset):
    train_labels = dataset['train'][:]['labels'].unique().tolist()
    test_labels = dataset['test'][:]['labels'].unique().tolist()
    return len(set(train_labels + test_labels))


def get_parser():
    common_parser = get_common_args_parser()
    bert_parser = get_bert_args_parser()
    parser = merge_argument_parsers(common_parser, bert_parser)

    parser.add_argument('--oversampling_ratio', type=float, default=-1)

    return parser



def run(args):
    set_seed(args.seed)

    config_params = dict(
        min_enr = args.min_enr,
        min_edges = args.min_edges,
        remove_duplicates = args.remove_duplicates,
        reload=args.reload,
        language = args.language
    )
    dataset_name = args.dataset
    distance = args.distance
    dataset = get_models_dataset(dataset_name, **config_params)

    print("Loaded dataset")

    graph_data_params = get_config_params(args)
    print("Loading graph dataset")
    graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
    print("Loaded graph dataset")

    assert hasattr(graph_dataset, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node_{args.node_cls_label} attribute"
    num_labels = getattr(graph_dataset, f"num_nodes_{args.node_cls_label}")

    model_name = args.model_name
    tokenizer = get_tokenizer(model_name, use_special_tokens=args.use_special_tokens)

    print("Getting node classification data")
    bert_dataset = graph_dataset.get_node_classification_lm_data(
        label=args.node_cls_label,
        tokenizer=tokenizer,
        distance=distance,
    )

    # exit(0)

    if args.oversampling_ratio != -1:
        ind_w_oversamples = oversample_dataset(bert_dataset['train'])
        bert_dataset['train'].inputs = bert_dataset['train'][ind_w_oversamples]

    model = get_model(
        args.ckpt if args.ckpt else model_name,
        num_labels=2,
        len_tokenizer=len(tokenizer)
    )

    if args.freeze_pretrained_weights:
        for param in model.base_model.parameters():
            param.requires_grad = False


    print("Training model")
    output_dir = os.path.join(
        'results',
        dataset_name,
        'node_cls',
        f'{args.node_cls_label}',
        f"{graph_dataset.config_hash}",
    )

    logs_dir = os.path.join(
        'logs',
        dataset_name,
        'node_cls',
        f'{args.node_cls_label}',
        f"{graph_dataset.config_hash}",
    )

    print("Output Dir: ", output_dir)
    print("Logs Dir: ", logs_dir)
    print("Len Train Dataset: ", len(bert_dataset['train']))
    print("Len Test Dataset: ", len(bert_dataset['test']))

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        weight_decay=0.01,
        logging_dir=logs_dir,
        logging_steps=args.num_log_steps,
        eval_strategy='steps',
        eval_steps=args.num_eval_steps,
        save_steps=args.num_save_steps,
        save_total_limit=2,
        load_best_model_at_end=True,
        fp16=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=bert_dataset['train'],
        eval_dataset=bert_dataset['test'],
        compute_metrics=compute_metrics
    )

    trainer.train()
    results = trainer.evaluate()
    print(results)

    trainer.save_model()


if __name__ == '__main__':
    args = get_parser()
    run(args)
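A minimal sketch of driving the node-classification script from code. Only --oversampling_ratio is declared locally; the remaining flag spellings (and the example label value) are assumptions based on the attributes run() reads from args, such as args.node_cls_label and args.distance:

from glam4cm.downstream_tasks.bert_node_classification import get_parser, run

# Hypothetical spellings for the shared parser options; values are illustrative.
args = get_parser().parse_args([
    '--dataset', 'ecore_555',
    '--model_name', 'bert-base-uncased',
    '--node_cls_label', 'abstract',
    '--oversampling_ratio', '-1',
])
run(args)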
glam4cm/downstream_tasks/cm_gpt_edge_classification.py
@@ -0,0 +1,73 @@
import os
from glam4cm.downstream_tasks.common_args import (
    get_common_args_parser,
    get_gpt_args_parser
)

from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
from glam4cm.models.cmgpt import CMGPT, CMGPTClassifier
from glam4cm.downstream_tasks.utils import get_models_dataset
from glam4cm.tokenization.utils import get_tokenizer
from glam4cm.trainers.cm_gpt_trainer import CMGPTTrainer
from glam4cm.utils import merge_argument_parsers, set_seed


def get_parser():
    common_parser = get_common_args_parser()
    bert_parser = get_gpt_args_parser()
    parser = merge_argument_parsers(common_parser, bert_parser)

    parser.add_argument('--cls_label', type=str, default='type')
    parser.add_argument('--pretr', type=str, default=None)
    return parser


def run(args):
    set_seed(args.seed)

    tokenizer = get_tokenizer('bert-base-cased', special_tokens=True)

    models_dataset_params = dict(
        language='en',
    )

    graph_params = dict(
        use_special_tokens=args.use_special_tokens,
        distance=args.distance,
        reload = args.reload
    )

    models_dataset = get_models_dataset(args.dataset, **models_dataset_params)
    graph_dataset = GraphEdgeDataset(models_dataset, **graph_params)

    assert hasattr(graph_dataset, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node labels for {args.node_cls_label}"

    node_label_dataset = graph_dataset.get_link_prediction_lm_data(
        tokenizer=tokenizer,
        label=args.node_cls_label
    )

    if args.pretr and os.path.exists(args.pretr):
        cmgpt = CMGPT.from_pretrained(f"{args.pretr}")
    else:
        cmgpt = CMGPT(
            vocab_size=len(tokenizer),
            embed_dim=args.embed_dim,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
        )

    cmgpt_classifier = CMGPTClassifier(cmgpt, num_classes=getattr(graph_dataset, f"num_nodes_{args.node_cls_label}"))

    trainer = CMGPTTrainer(
        cmgpt_classifier,
        train_dataset=node_label_dataset['train'],
        test_dataset=node_label_dataset['test'],
        batch_size=args.batch_size,
        num_epochs=args.num_epochs
    )

    trainer.train()

    trainer.save_model()