glam4cm 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff shows the changes between two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
- glam4cm/__init__.py +2 -1
- glam4cm/data_loading/data.py +90 -146
- glam4cm/data_loading/encoding.py +17 -6
- glam4cm/data_loading/graph_dataset.py +192 -57
- glam4cm/data_loading/metadata.py +1 -1
- glam4cm/data_loading/models_dataset.py +42 -18
- glam4cm/downstream_tasks/bert_edge_classification.py +49 -22
- glam4cm/downstream_tasks/bert_graph_classification.py +44 -14
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +47 -24
- glam4cm/downstream_tasks/bert_link_prediction.py +46 -26
- glam4cm/downstream_tasks/bert_node_classification.py +127 -89
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +61 -15
- glam4cm/downstream_tasks/common_args.py +32 -4
- glam4cm/downstream_tasks/gnn_edge_classification.py +24 -7
- glam4cm/downstream_tasks/gnn_graph_cls.py +19 -6
- glam4cm/downstream_tasks/gnn_link_prediction.py +25 -13
- glam4cm/downstream_tasks/gnn_node_classification.py +19 -7
- glam4cm/downstream_tasks/utils.py +16 -2
- glam4cm/embeddings/bert.py +1 -1
- glam4cm/embeddings/common.py +7 -4
- glam4cm/encoding/encoders.py +1 -1
- glam4cm/lang2graph/archimate.py +0 -5
- glam4cm/lang2graph/common.py +99 -41
- glam4cm/lang2graph/ecore.py +1 -2
- glam4cm/lang2graph/ontouml.py +8 -7
- glam4cm/models/gnn_layers.py +20 -6
- glam4cm/models/hf.py +2 -2
- glam4cm/run.py +13 -9
- glam4cm/run_conf_v2.py +405 -0
- glam4cm/run_configs.py +70 -106
- glam4cm/run_confs.py +41 -0
- glam4cm/settings.py +15 -2
- glam4cm/tokenization/special_tokens.py +23 -1
- glam4cm/tokenization/utils.py +23 -4
- glam4cm/trainers/cm_gpt_trainer.py +1 -1
- glam4cm/trainers/gnn_edge_classifier.py +12 -1
- glam4cm/trainers/gnn_graph_classifier.py +12 -5
- glam4cm/trainers/gnn_link_predictor.py +18 -3
- glam4cm/trainers/gnn_link_predictor_v2.py +146 -0
- glam4cm/trainers/gnn_trainer.py +8 -0
- glam4cm/trainers/metrics.py +1 -1
- glam4cm/utils.py +265 -2
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/METADATA +3 -2
- glam4cm-1.0.0.dist-info/RECORD +75 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/WHEEL +1 -1
- glam4cm-0.1.0.dist-info/RECORD +0 -72
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/entry_points.txt +0 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/top_level.txt +0 -0
glam4cm/downstream_tasks/bert_edge_classification.py

@@ -2,10 +2,15 @@ import os
 from transformers import TrainingArguments, Trainer
 from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
 from glam4cm.data_loading.utils import oversample_dataset
-from glam4cm.settings import
-from glam4cm.downstream_tasks.common_args import
+from glam4cm.settings import EDGE_CLS_TASK, results_dir
+from glam4cm.downstream_tasks.common_args import (
+    get_bert_args_parser,
+    get_common_args_parser,
+    get_config_params,
+    get_config_str
+)
 from glam4cm.models.hf import get_model
-from glam4cm.downstream_tasks.utils import get_models_dataset
+from glam4cm.downstream_tasks.utils import get_logging_steps, get_models_dataset


 from sklearn.metrics import (
@@ -16,7 +21,7 @@ from sklearn.metrics import (
 )

 from glam4cm.tokenization.utils import get_tokenizer
-from glam4cm.utils import merge_argument_parsers, set_seed
+from glam4cm.utils import merge_argument_parsers, set_encoded_labels, set_seed


 def compute_metrics(pred):
@@ -49,27 +54,39 @@ def get_parser():

 def run(args):
     set_seed(args.seed)
+    dataset_name = args.dataset
+    output_dir = os.path.join(
+        results_dir,
+        dataset_name,
+        f"LM_{EDGE_CLS_TASK}",
+        f'{args.edge_cls_label}',
+        get_config_str(args)
+    )
+    if os.path.exists(output_dir):
+        print(f"Output directory {output_dir} already exists. Exiting.")
+        exit(0)

     config_params = dict(
+        include_dummies = args.include_dummies,
         min_enr = args.min_enr,
         min_edges = args.min_edges,
         remove_duplicates = args.remove_duplicates,
         language = args.language,
         reload=args.reload
     )
-
+

     print("Loaded dataset")
     dataset = get_models_dataset(dataset_name, **config_params)

     graph_data_params = get_config_params(args)
-    graph_data_params = {**graph_data_params, 'task_type':
+    graph_data_params = {**graph_data_params, 'task_type': EDGE_CLS_TASK}

     print("Loading graph dataset")
     graph_dataset = GraphEdgeDataset(dataset, **graph_data_params)
     print("Loaded graph dataset")

-    assert hasattr(graph_dataset, f'num_edges_{args.edge_cls_label}'), f"Dataset does not have
+    assert hasattr(graph_dataset, f'num_edges_{args.edge_cls_label}'), f"Dataset does not have edge_{args.edge_cls_label} attribute"
     num_labels = getattr(graph_dataset, f"num_edges_{args.edge_cls_label}")


@@ -78,6 +95,11 @@ def run(args):

     print("Getting Edge Classification data")
     bert_dataset = graph_dataset.get_link_prediction_lm_data(tokenizer=tokenizer)
+
+    train_dataset = bert_dataset['train']
+    test_dataset = bert_dataset['test']
+    set_encoded_labels(train_dataset, test_dataset)
+

     # exit(0)

@@ -88,28 +110,32 @@ def run(args):
     print("Training model")
     print(f'Number of labels: {num_labels}')

-    model = get_model(
-
+    model = get_model(
+        args.ckpt if args.ckpt else model_name,
+        num_labels,
+        len(tokenizer),
+        trust_remote_code=args.trust_remote_code
+    )
+
     if args.freeze_pretrained_weights:
         for param in model.base_model.parameters():
             param.requires_grad = False

-    output_dir = os.path.join(
-        'results',
-        dataset_name,
-        'edge_cls',
-        f'{args.edge_cls_label}',
-        f"{graph_dataset.config_hash}",
-    )

     logs_dir = os.path.join(
         'logs',
         dataset_name,
-
+        f"LM_{EDGE_CLS_TASK}",
         f'{args.edge_cls_label}',
         f"{graph_dataset.config_hash}",
     )

+    logging_steps = get_logging_steps(
+        len(train_dataset),
+        args.num_epochs,
+        args.train_batch_size
+    )
+
     training_args = TrainingArguments(
         output_dir=output_dir,
         num_train_epochs=args.num_epochs,
@@ -117,13 +143,14 @@ def run(args):
         per_device_eval_batch_size=args.eval_batch_size,
         weight_decay=0.01,
         logging_dir=logs_dir,
-        logging_steps=
+        logging_steps=logging_steps,
         eval_strategy='steps',
-        eval_steps=
-        save_steps=args.num_save_steps,
-        save_total_limit=2,
-        load_best_model_at_end=True,
+        eval_steps=logging_steps,
+        # save_steps=args.num_save_steps,
+        # save_total_limit=2,
+        # load_best_model_at_end=True,
         fp16=True,
+        save_strategy="no"
     )

     trainer = Trainer(
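All four training scripts in 1.0.0 replace hard-coded logging_steps/eval_steps arguments with a value returned by the new get_logging_steps helper, called as get_logging_steps(len(train_dataset), args.num_epochs, args.train_batch_size). The helper's body is not part of the hunks shown here (it lands in glam4cm/downstream_tasks/utils.py, +16 -2); the sketch below is only a guess at the kind of calculation such a helper might perform, assuming it spreads a fixed number of log/eval points over the total number of optimizer steps.

# Illustrative sketch only; the shipped get_logging_steps in
# glam4cm/downstream_tasks/utils.py may compute this differently.
def get_logging_steps(num_samples: int, num_epochs: int, batch_size: int, num_points: int = 100) -> int:
    # Total optimizer steps across the whole run.
    total_steps = max(1, (num_samples // max(1, batch_size)) * num_epochs)
    # Log/evaluate roughly num_points times, never less often than every step.
    return max(1, total_steps // num_points)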
glam4cm/downstream_tasks/bert_graph_classification.py

@@ -12,10 +12,16 @@ from transformers import (

 from glam4cm.data_loading.graph_dataset import GraphNodeDataset
 from glam4cm.models.hf import get_model
-from glam4cm.downstream_tasks.common_args import
-
+from glam4cm.downstream_tasks.common_args import (
+    get_bert_args_parser,
+    get_common_args_parser,
+    get_config_params,
+    get_config_str
+)
+from glam4cm.downstream_tasks.utils import get_logging_steps, get_models_dataset
+from glam4cm.settings import GRAPH_CLS_TASK, results_dir
 from glam4cm.tokenization.utils import get_tokenizer
-from glam4cm.utils import merge_argument_parsers, set_seed
+from glam4cm.utils import merge_argument_parsers, set_encoded_labels, set_seed



@@ -51,6 +57,7 @@ def run(args):
     set_seed(args.seed)

     config_params = dict(
+        include_dummies = args.include_dummies,
         min_enr = args.min_enr,
         min_edges = args.min_edges,
         remove_duplicates = args.remove_duplicates,
@@ -60,7 +67,7 @@ def run(args):
     dataset_name = args.dataset

     dataset = get_models_dataset(dataset_name, **config_params)
-    graph_data_params = get_config_params(args)
+    graph_data_params = {**get_config_params(args), 'task_type': GRAPH_CLS_TASK}
     print("Loading graph dataset")
     graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
     print("Loaded graph dataset")
@@ -76,30 +83,50 @@ def run(args):
     train_dataset = classification_dataset['train']
     test_dataset = classification_dataset['test']
     num_labels = classification_dataset['num_classes']
+
+    set_encoded_labels(train_dataset, test_dataset)

     print(len(train_dataset), len(test_dataset), num_labels)

     print("Training model")
     output_dir = os.path.join(
-
+        results_dir,
         dataset_name,
-        f
-        f
+        f"LM_{GRAPH_CLS_TASK}",
+        f'{args.cls_label}',
+        get_config_str(args)
     )
+    # if os.path.exists(output_dir):
+    #     print(f"Output directory {output_dir} already exists. Exiting.")
+    #     exit(0)

     logs_dir = os.path.join(
         'logs',
         dataset_name,
-        f
-        f
+        f"LM_{GRAPH_CLS_TASK}",
+        f'{args.cls_label}',
+        f"{graph_dataset.config_hash}_{fold_id}",
+
     )

-    model = get_model(
+    model = get_model(
+        args.ckpt if args.ckpt else model_name,
+        num_labels,
+        len(tokenizer),
+        trust_remote_code=args.trust_remote_code
+    )

     if args.freeze_pretrained_weights:
         for param in model.base_model.parameters():
             param.requires_grad = False

+
+    logging_steps = get_logging_steps(
+        len(train_dataset),
+        args.num_epochs,
+        args.train_batch_size
+    )
+    #
     # Training arguments
     training_args = TrainingArguments(
         output_dir=output_dir,
@@ -111,12 +138,13 @@ def run(args):
         weight_decay=0.01,
         learning_rate=5e-5,
         logging_dir=logs_dir,
-        logging_steps=
-        eval_steps=
-        save_steps=
+        logging_steps=logging_steps,
+        eval_steps=logging_steps,
+        save_steps=logging_steps,
         save_total_limit=2,
         load_best_model_at_end=True,
-        fp16=True
+        fp16=True,
+        save_strategy="steps"
     )

     # Trainer
@@ -133,5 +161,7 @@ def run(args):
     results = trainer.evaluate()
     print(results)

+    trainer.save_model()
+
     fold_id += 1
     break
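This script, like the edge-classification and link-prediction scripts, now calls set_encoded_labels(train_dataset, test_dataset), newly imported from glam4cm.utils (+265 -2), right after the train/test split. Its implementation is not visible in this diff; the sketch below only illustrates the label-alignment idea the call sites suggest, namely encoding labels with one shared mapping so integer ids mean the same class in both splits. The function name and the list-based signature below are assumptions for illustration, not the shipped code, which operates on the dataset objects directly.

# Assumption-based sketch of a shared label encoding; the real set_encoded_labels
# in glam4cm.utils works on the dataset objects passed above and may differ.
def encode_labels_consistently(train_labels, test_labels):
    # One mapping built over both splits so every id refers to the same class everywhere.
    label_to_id = {label: i for i, label in enumerate(sorted(set(train_labels) | set(test_labels)))}
    return [label_to_id[l] for l in train_labels], [label_to_id[l] for l in test_labels]

train_ids, test_ids = encode_labels_consistently(["A", "B", "A"], ["B", "C"])
print(train_ids, test_ids)  # [0, 1, 0] [1, 2]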
glam4cm/downstream_tasks/bert_graph_classification_comp.py

@@ -1,3 +1,4 @@
+from collections import Counter
 import os
 import json
 from argparse import ArgumentParser
@@ -20,8 +21,7 @@ from transformers import (

 from glam4cm.data_loading.encoding import EncodingDataset
 from glam4cm.models.hf import get_model
-
-
+from glam4cm.settings import results_dir

 def compute_metrics(pred):
     labels = pred.label_ids
@@ -49,13 +49,17 @@ def get_parser():
     parser.add_argument('--ckpt', type=str, default=None)
     parser.add_argument('--max_length', type=int, default=512)
     parser.add_argument('--k', type=int, default=10)
+    parser.add_argument('--limit', type=int, default=-1)
+    parser.add_argument('--trust_remote_code', action='store_true')
+    parser.add_argument('--include_dummies', action='store_true')
+    parser.add_argument('--task_type', type=str, default='graph_cls')

     parser.add_argument('--num_epochs', type=int, default=10)

     parser.add_argument('--warmup_steps', type=int, default=500)
-    parser.add_argument('--num_log_steps', type=int, default=
-    parser.add_argument('--num_eval_steps', type=int, default=
-    parser.add_argument('--num_save_steps', type=int, default=
+    parser.add_argument('--num_log_steps', type=int, default=500)
+    parser.add_argument('--num_eval_steps', type=int, default=500)
+    parser.add_argument('--num_save_steps', type=int, default=500)
     parser.add_argument('--train_batch_size', type=int, default=2)
     parser.add_argument('--eval_batch_size', type=int, default=128)
     parser.add_argument('--lr', type=float, default=1e-5)
@@ -66,15 +70,27 @@ def get_parser():
 def run(args):
     dataset_name = args.dataset_name
     model_name = args.model_name
-
-
+    include_dummies = args.include_dummies
+
+    file_name = 'ecore.jsonl' if include_dummies and dataset_name == 'modelset' else 'ecore-with-dummy.jsonl'
     texts = [
         (g['txt'], g['labels'])
-        for file_name in os.listdir(f'datasets/{dataset_name}')
         for g in json.load(open(f'datasets/{dataset_name}/{file_name}'))
-        if '
+        if g['labels'] not in ['dummy', 'unknown']
+    ]
+    allowed_labels = [label for label, _ in dict(Counter([t[1] for t in texts]).most_common(48)).items()]
+    texts = [
+        (t, l) for t, l in texts
+        if l in allowed_labels
     ]
+    if args.task_type == 'dd':
+        print("Task type: DD")
+        texts = [(t, 0 if l in 'dummy' else 1) for t, l in texts]
+
     shuffle(texts)
+    limit = args.limit if args.limit > 0 else len(texts)
+    texts = texts[:limit]
+
     labels = [y for _, y in texts]
     y_map = {label: i for i, label in enumerate(set(y for y in labels))}
     y = [y_map[y] for y in labels]
@@ -84,13 +100,14 @@ def run(args):

     num_labels = len(y_map)

-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=args.trust_remote_code)
     k = args.k
     kfold = StratifiedKFold(n_splits=k, shuffle=True, random_state=args.seed)

     i = 0
     for train_idx, test_idx in kfold.split(np.zeros(n), np.zeros(n)):
-
+        # if i == 0:
+        #     continue
         print(f'Fold {i+1}/{k}')

         train_texts = [texts[i] for i in train_idx]
@@ -101,27 +118,34 @@ def run(args):

         print(f'Train: {len(train_texts)}, Test: {len(test_texts)}', num_labels)

-        train_dataset = EncodingDataset(tokenizer, train_texts, train_y)
-        test_dataset = EncodingDataset(tokenizer, test_texts, test_y)
+        train_dataset = EncodingDataset(tokenizer, train_texts, train_y, max_length=args.max_length)
+        test_dataset = EncodingDataset(tokenizer, test_texts, test_y, max_length=args.max_length)
+        # import code; code.interact(local=locals())

-        model = get_model(
+        model = get_model(
+            args.ckpt if args.ckpt else model_name,
+            num_labels,
+            len(tokenizer),
+            trust_remote_code=args.trust_remote_code
+        )

         print("Training model")
         output_dir = os.path.join(
-
+            results_dir,
             dataset_name,
-            '
+            f'graph_cls_comp_{"dummy" if include_dummies else ""}{i+1}',
         )

         logs_dir = os.path.join(
             'logs',
-            dataset_name,
-            '
+            f"{dataset_name}_{args.model_name if args.ckpt is None else args.ckpt.split('/')[-1]}",
+            f'graph_cls_comp_{"dummy" if include_dummies else ""}{i+1}',
         )

         print("Running epochs: ", args.num_epochs)

         # Training arguments
+        print("Batch size: ", args.train_batch_size)
         training_args = TrainingArguments(
             output_dir=output_dir,
             num_train_epochs=args.num_epochs,
@@ -131,11 +155,10 @@ def run(args):
             warmup_steps=500,
             weight_decay=0.01,
             logging_dir=logs_dir,
-            logging_steps=
-            eval_steps=
-
-
-            fp16=True
+            logging_steps=args.num_log_steps,
+            eval_steps=args.num_eval_steps,
+            fp16=True,
+            save_strategy="no"
         )

         # Trainer
@@ -153,4 +176,4 @@ def run(args):
         print(results)

         i += 1
-
+        # break
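New in this comparison script: before the k-fold split, the samples are restricted to the 48 most frequent labels using collections.Counter (the allowed_labels block in the hunk above). Run on toy data, the same filtering pattern looks like this; the label strings here are made up for illustration, only the pattern mirrors the added lines.

from collections import Counter

# Toy (text, label) pairs standing in for the (g['txt'], g['labels']) records.
texts = [("a", "Class"), ("b", "Class"), ("c", "StateMachine"), ("d", "StateMachine"), ("e", "rare")]

# Keep only samples whose label is among the k most common ones (the script uses k=48).
k = 2
allowed_labels = [label for label, _ in Counter(l for _, l in texts).most_common(k)]
texts = [(t, l) for t, l in texts if l in allowed_labels]

print(texts)  # the single "rare" sample is dropped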
glam4cm/downstream_tasks/bert_link_prediction.py

@@ -3,9 +3,14 @@ import os
 from transformers import TrainingArguments, Trainer
 from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
 from glam4cm.models.hf import get_model
-from glam4cm.settings import
-from glam4cm.downstream_tasks.common_args import
-
+from glam4cm.settings import LINK_PRED_TASK, results_dir
+from glam4cm.downstream_tasks.common_args import (
+    get_bert_args_parser,
+    get_common_args_parser,
+    get_config_params,
+    get_config_str
+)
+from glam4cm.downstream_tasks.utils import get_logging_steps, get_models_dataset
 from glam4cm.tokenization.special_tokens import *


@@ -17,7 +22,7 @@ from sklearn.metrics import (
 )

 from glam4cm.tokenization.utils import get_tokenizer
-from glam4cm.utils import merge_argument_parsers, set_seed
+from glam4cm.utils import merge_argument_parsers, set_encoded_labels, set_seed


 def compute_metrics(pred):
@@ -56,6 +61,7 @@ def run(args):
     set_seed(args.seed)

     config_params = dict(
+        include_dummies = args.include_dummies,
         min_enr = args.min_enr,
         min_edges = args.min_edges,
         remove_duplicates = args.remove_duplicates,
@@ -69,17 +75,19 @@ def run(args):


     graph_data_params = get_config_params(args)
-    graph_data_params = {
+    graph_data_params = {
+        **graph_data_params,
+        'add_negative_train_samples': True,
+        'neg_sampling_ratio': args.neg_sampling_ratio,
+        'task_type': LINK_PRED_TASK
+    }
+    print(graph_data_params)

     print("Loading graph dataset")
     graph_dataset = GraphEdgeDataset(
         dataset,
-
-
-        add_negative_train_samples=args.add_negative_train_samples,
-        neg_sampling_ratio=args.neg_sampling_ratio,
-        task=LP_TASK_LINK_PRED
-    ))
+        **graph_data_params
+    )
     print("Loaded graph dataset")


@@ -89,33 +97,44 @@ def run(args):


     print("Getting link prediction data")
-    bert_dataset = graph_dataset.get_link_prediction_lm_data(
-
-
-    )
+    bert_dataset = graph_dataset.get_link_prediction_lm_data(tokenizer=tokenizer)
+    train_dataset = bert_dataset['train']
+    test_dataset = bert_dataset['test']
+    set_encoded_labels(train_dataset, test_dataset)
+

     print("Training model")
-    model = get_model(args.ckpt if args.ckpt else model_name, num_labels=2, len_tokenizer=len(tokenizer))
+    model = get_model(args.ckpt if args.ckpt else model_name, num_labels=2, len_tokenizer=len(tokenizer), trust_remote_code=args.trust_remote_code)

     if args.freeze_pretrained_weights:
         for param in model.base_model.parameters():
             param.requires_grad = False

-
+
     output_dir = os.path.join(
-
+        results_dir,
         dataset_name,
-
-
+        f"LM_{LINK_PRED_TASK}",
+        get_config_str(args)
     )
+    if os.path.exists(output_dir):
+        print(f"Output directory {output_dir} already exists. Exiting.")
+        exit(0)

     logs_dir = os.path.join(
         'logs',
         dataset_name,
-
+        f"LM_{LINK_PRED_TASK}",
         f"{graph_dataset.config_hash}",
     )

+    logging_steps = get_logging_steps(
+        len(train_dataset),
+        args.num_epochs,
+        args.train_batch_size
+    )
+
+
     training_args = TrainingArguments(
         output_dir=output_dir,
         num_train_epochs=args.num_epochs,
@@ -123,13 +142,14 @@ def run(args):
         per_device_eval_batch_size=args.eval_batch_size,
         weight_decay=0.01,
         logging_dir=logs_dir,
-        logging_steps=
+        logging_steps=logging_steps,
         eval_strategy='steps',
-        eval_steps=
-        save_steps=200,
-        save_total_limit=2,
-        load_best_model_at_end=True,
+        eval_steps=logging_steps,
+        # save_steps=200,
+        # save_total_limit=2,
+        # load_best_model_at_end=True,
         fp16=True,
+        save_strategy="no"
     )

     trainer = Trainer(
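A pattern shared by the edge-classification and link-prediction scripts in 1.0.0: the output path is built from results_dir (imported from glam4cm.settings) plus an LM_<task> segment and get_config_str(args), and a run whose output directory already exists is skipped. The snippet below restates that guard in isolation; results_dir, the task constant's value, and the config string are stand-ins, since their actual values are defined outside the hunks shown here.

import os

# Hypothetical stand-in values; in glam4cm 1.0.0 these come from
# glam4cm.settings (results_dir, LINK_PRED_TASK) and get_config_str(args).
results_dir = "results"
dataset_name = "modelset"
LINK_PRED_TASK = "lp"
config_str = "abc123"

# 0.1.0 hard-coded a 'results/<dataset>/<task>/...' tree; 1.0.0 builds the
# path from settings and exits early when the directory already exists.
output_dir = os.path.join(results_dir, dataset_name, f"LM_{LINK_PRED_TASK}", config_str)
if os.path.exists(output_dir):
    print(f"Output directory {output_dir} already exists. Exiting.")
else:
    print(f"Would write results to {output_dir}")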