glam4cm 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glam4cm/__init__.py +2 -1
- glam4cm/data_loading/data.py +90 -146
- glam4cm/data_loading/encoding.py +17 -6
- glam4cm/data_loading/graph_dataset.py +192 -57
- glam4cm/data_loading/metadata.py +1 -1
- glam4cm/data_loading/models_dataset.py +42 -18
- glam4cm/downstream_tasks/bert_edge_classification.py +49 -22
- glam4cm/downstream_tasks/bert_graph_classification.py +44 -14
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +47 -24
- glam4cm/downstream_tasks/bert_link_prediction.py +46 -26
- glam4cm/downstream_tasks/bert_node_classification.py +127 -89
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +61 -15
- glam4cm/downstream_tasks/common_args.py +32 -4
- glam4cm/downstream_tasks/gnn_edge_classification.py +24 -7
- glam4cm/downstream_tasks/gnn_graph_cls.py +19 -6
- glam4cm/downstream_tasks/gnn_link_prediction.py +25 -13
- glam4cm/downstream_tasks/gnn_node_classification.py +19 -7
- glam4cm/downstream_tasks/utils.py +16 -2
- glam4cm/embeddings/bert.py +1 -1
- glam4cm/embeddings/common.py +7 -4
- glam4cm/encoding/encoders.py +1 -1
- glam4cm/lang2graph/archimate.py +0 -5
- glam4cm/lang2graph/common.py +99 -41
- glam4cm/lang2graph/ecore.py +1 -2
- glam4cm/lang2graph/ontouml.py +8 -7
- glam4cm/models/gnn_layers.py +20 -6
- glam4cm/models/hf.py +2 -2
- glam4cm/run.py +13 -9
- glam4cm/run_conf_v2.py +405 -0
- glam4cm/run_configs.py +70 -106
- glam4cm/run_confs.py +41 -0
- glam4cm/settings.py +15 -2
- glam4cm/tokenization/special_tokens.py +23 -1
- glam4cm/tokenization/utils.py +23 -4
- glam4cm/trainers/cm_gpt_trainer.py +1 -1
- glam4cm/trainers/gnn_edge_classifier.py +12 -1
- glam4cm/trainers/gnn_graph_classifier.py +12 -5
- glam4cm/trainers/gnn_link_predictor.py +18 -3
- glam4cm/trainers/gnn_link_predictor_v2.py +146 -0
- glam4cm/trainers/gnn_trainer.py +8 -0
- glam4cm/trainers/metrics.py +1 -1
- glam4cm/utils.py +265 -2
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/METADATA +3 -2
- glam4cm-1.0.0.dist-info/RECORD +75 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/WHEEL +1 -1
- glam4cm-0.1.0.dist-info/RECORD +0 -72
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/entry_points.txt +0 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/top_level.txt +0 -0
glam4cm/downstream_tasks/bert_node_classification.py

@@ -1,12 +1,19 @@
+import numpy as np
 from glam4cm.models.hf import get_model
-from glam4cm.downstream_tasks.common_args import
+from glam4cm.downstream_tasks.common_args import (
+    get_bert_args_parser,
+    get_common_args_parser,
+    get_config_params,
+    get_config_str
+)
 import os
 from transformers import TrainingArguments, Trainer
 from glam4cm.data_loading.graph_dataset import GraphNodeDataset
 from glam4cm.data_loading.utils import oversample_dataset
-from glam4cm.downstream_tasks.utils import get_models_dataset
+from glam4cm.downstream_tasks.utils import get_logging_steps, get_models_dataset
+from glam4cm.settings import NODE_CLS_TASK, results_dir
 from glam4cm.tokenization.special_tokens import *
-
+from sklearn.model_selection import StratifiedKFold
 
 from sklearn.metrics import (
     accuracy_score,
@@ -16,7 +23,7 @@ from sklearn.metrics import (
 )
 
 from glam4cm.tokenization.utils import get_tokenizer
-from glam4cm.utils import merge_argument_parsers, set_seed
+from glam4cm.utils import merge_argument_parsers, set_encoded_labels, set_seed
 
 
 
@@ -54,11 +61,24 @@ def get_parser():
     return parser
 
 
-
 def run(args):
     set_seed(args.seed)
+    dataset_name = args.dataset
+    print("Training model")
+    output_dir = os.path.join(
+        results_dir,
+        dataset_name,
+        f'LM_{NODE_CLS_TASK}',
+        f'{args.node_cls_label}',
+        get_config_str(args)
+    )
+
+    # if os.path.exists(output_dir):
+    #     print(f"Output directory {output_dir} already exists. Exiting.")
+    #     exit(0)
 
     config_params = dict(
+        include_dummies = args.include_dummies,
         min_enr = args.min_enr,
         min_edges = args.min_edges,
         remove_duplicates = args.remove_duplicates,
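The hunk above pins each run to a deterministic results path built from the dataset, task, label, and config string. A rough illustration of the resulting layout; `results_dir` and `NODE_CLS_TASK` live in `glam4cm/settings.py` and are not part of this diff, so the values below (and the dataset name `ecore`) are made-up stand-ins:

```python
import os

# Assumed stand-ins for the settings constants (not shown in this diff).
results_dir = 'results'
NODE_CLS_TASK = 'node_cls'

# --dataset ecore --node_cls_label abstract with a config string of
# "_attrs_st_abstract_1" would resolve to:
path = os.path.join(results_dir, 'ecore', f'LM_{NODE_CLS_TASK}',
                    'abstract', '_attrs_st_abstract_1')
print(path)  # results/ecore/LM_node_cls/abstract/_attrs_st_abstract_1
```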
@@ -71,92 +91,110 @@ def run(args):
 
     print("Loaded dataset")
 
-    graph_data_params = get_config_params(args)
+    graph_data_params = {**get_config_params(args), 'task_type': NODE_CLS_TASK}
     print("Loading graph dataset")
-    graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
-    print("Loaded graph dataset")
-
-    assert hasattr(graph_dataset, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node_{args.node_cls_label} attribute"
-    num_labels = getattr(graph_dataset, f"num_nodes_{args.node_cls_label}")
-
-    model_name = args.model_name
-    tokenizer = get_tokenizer(model_name, use_special_tokens=args.use_special_tokens)
-
-    print("Getting node classification data")
-    bert_dataset = graph_dataset.get_node_classification_lm_data(
-        label=args.node_cls_label,
-        tokenizer=tokenizer,
-        distance=distance,
-    )
-
-    # exit(0)
-
-    if args.oversampling_ratio != -1:
-        ind_w_oversamples = oversample_dataset(bert_dataset['train'])
-        bert_dataset['train'].inputs = bert_dataset['train'][ind_w_oversamples]
-
-    model = get_model(
-        args.ckpt if args.ckpt else model_name,
-        num_labels=2,
-        len_tokenizer=len(tokenizer)
-    )
-
-    if args.freeze_pretrained_weights:
-        for param in model.base_model.parameters():
-            param.requires_grad = False
-
 
[51 more deleted lines (old 109-159) are not rendered in the diff source]
+    k = int(1 / args.test_ratio)
+
+    for i in range(k):
+        set_seed(np.random.randint(0, 1000))
+        graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
+        print("Loaded graph dataset")
+
+        assert hasattr(graph_dataset, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node_{args.node_cls_label} attribute"
+        num_labels = getattr(graph_dataset, f"num_nodes_{args.node_cls_label}")
+
+        model_name = args.model_name
+        tokenizer = get_tokenizer(model_name, use_special_tokens=args.use_special_tokens)
+
+        print("Getting node classification data")
+        bert_dataset = graph_dataset.get_node_classification_lm_data(
+            label=args.node_cls_label,
+            tokenizer=tokenizer,
+            distance=distance,
+        )
+
+        # exit(0)
+
+        if args.oversampling_ratio != -1:
+            ind_w_oversamples = oversample_dataset(bert_dataset['train'])
+            bert_dataset['train'].inputs = bert_dataset['train'][ind_w_oversamples]
+
+
+        model = get_model(
+            args.ckpt if args.ckpt else model_name,
+            num_labels=num_labels,
+            len_tokenizer=len(tokenizer),
+            trust_remote_code=args.trust_remote_code
+        )
+
+        if args.freeze_pretrained_weights:
+            for param in model.base_model.parameters():
+                param.requires_grad = False
+
+
+        logs_dir = os.path.join(
+            'logs',
+            dataset_name,
+            f'BERT_{NODE_CLS_TASK}',
+            f'{args.node_cls_label}',
+            f"{graph_dataset.config_hash}_{i}",
+        )
+
+        print("Output Dir: ", output_dir)
+        print("Logs Dir: ", logs_dir)
+        print("Len Train Dataset: ", len(bert_dataset['train']))
+        print("Len Test Dataset: ", len(bert_dataset['test']))
+
+        train_dataset = bert_dataset['train']
+        test_dataset = bert_dataset['test']
+        set_encoded_labels(train_dataset, test_dataset)
+
+
+        print("Num epochs: ", args.num_epochs)
+
+        logging_steps = get_logging_steps(
+            len(train_dataset),
+            args.num_epochs,
+            args.train_batch_size
+        )
+
+        training_args = TrainingArguments(
+            output_dir=output_dir,
+            num_train_epochs=args.num_epochs,
+            per_device_train_batch_size=args.train_batch_size,
+            per_device_eval_batch_size=args.eval_batch_size,
+            weight_decay=0.01,
+            logging_dir=logs_dir,
+            logging_steps=logging_steps,
+            eval_strategy='steps',
+            eval_steps=logging_steps,
+            # save_steps=args.num_save_steps,
+            # save_total_limit=2,
+            # load_best_model_at_end=True,
+            fp16=True,
+            save_strategy="no"
+        )
+
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=bert_dataset['train'],
+            eval_dataset=bert_dataset['test'],
+            compute_metrics=compute_metrics
+        )
+
+        trainer.train()
+        results = trainer.evaluate()
+
+        # with open(os.path.join(output_dir, 'results.txt'), 'a') as f:
+        #     f.write(str(results))
+        #     f.write('\n')
+
+        print(results)
+
+        trainer.save_model()
+        break
 
 
 if __name__ == '__main__':
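Worth noting in the rewritten `run`: the fold count is derived from the test ratio, each iteration re-seeds and rebuilds the split, and the trailing `break` means only the first fold is ever trained (the newly imported `StratifiedKFold` is not used in the lines shown). A minimal sketch of that control flow, with the dataset and training calls stubbed out:

```python
import numpy as np

def run_resampled(test_ratio: float, train_one_split) -> None:
    # test_ratio 0.2 -> k = 5, test_ratio 0.1 -> k = 10
    k = int(1 / test_ratio)
    for i in range(k):
        seed = int(np.random.randint(0, 1000))  # fresh random split per fold
        train_one_split(fold=i, seed=seed)
        break  # as in the diff: exits after the first fold
```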
glam4cm/downstream_tasks/cm_gpt_node_classification.py

@@ -1,6 +1,8 @@
 import os
 from glam4cm.downstream_tasks.common_args import (
-    get_common_args_parser,
+    get_common_args_parser,
+    get_config_params,
+    get_config_str,
     get_gpt_args_parser
 )
 
@@ -9,7 +11,8 @@ from glam4cm.models.cmgpt import CMGPT, CMGPTClassifier
 from glam4cm.downstream_tasks.utils import get_models_dataset
 from glam4cm.tokenization.utils import get_tokenizer
 from glam4cm.trainers.cm_gpt_trainer import CMGPTTrainer
-from glam4cm.utils import merge_argument_parsers, set_seed
+from glam4cm.utils import merge_argument_parsers, set_encoded_labels, set_seed
+from glam4cm.settings import NODE_CLS_TASK, results_dir
 
 
 def get_parser():
@@ -25,29 +28,68 @@ def get_parser():
 def run(args):
     set_seed(args.seed)
 
-    tokenizer = get_tokenizer('bert-base-cased',
+    tokenizer = get_tokenizer('bert-base-cased', use_special_tokens=args.use_special_tokens)
 
-
-
+    set_seed(args.seed)
+    dataset_name = args.dataset
+    print("Training model")
+    output_dir = os.path.join(
+        results_dir,
+        dataset_name,
+        f'LM_{NODE_CLS_TASK}',
+        f'{args.node_cls_label}',
+        get_config_str(args)
     )
 
-
-
-
-
+    # if os.path.exists(output_dir):
+    #     print(f"Output directory {output_dir} already exists. Exiting.")
+    #     exit(0)
+
+    config_params = dict(
+        include_dummies = args.include_dummies,
+        min_enr = args.min_enr,
+        min_edges = args.min_edges,
+        remove_duplicates = args.remove_duplicates,
+        reload=args.reload,
+        language = args.language
     )
+    dataset_name = args.dataset
+    distance = args.distance
+    dataset = get_models_dataset(dataset_name, **config_params)
+
+    print("Loaded dataset")
 
-
-
+    graph_data_params = {**get_config_params(args), 'task_type': NODE_CLS_TASK}
+    print("Loading graph dataset")
+    graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
 
     assert hasattr(graph_dataset, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node labels for {args.node_cls_label}"
 
     node_label_dataset = graph_dataset.get_node_classification_lm_data(
         args.node_cls_label,
         tokenizer=tokenizer,
-        distance=
+        distance=args.distance,
+    )
+
+    set_encoded_labels(node_label_dataset['train'], node_label_dataset['test'])
+
+    print("Training model")
+    output_dir = os.path.join(
+        results_dir,
+        args.dataset,
+        f'LM_{NODE_CLS_TASK}',
+        f'{args.node_cls_label}',
+    )
+
+    logs_dir = os.path.join(
+        'logs',
+        args.dataset,
+        f'CMGPT_{NODE_CLS_TASK}',
+        f'{args.node_cls_label}',
+        f"{graph_dataset.config_hash}",
     )
 
+
     if args.pretr and os.path.exists(args.pretr):
         print(f"Loading pretrained model from {args.pretr}")
         cmgpt = CMGPT.from_pretrained(f"{args.pretr}")
@@ -60,7 +102,8 @@ def run(args):
         n_layer=args.n_layer,
         n_head=args.n_head,
     )
-
+    print(f"Train dataset size: {len(node_label_dataset['train'])}")
+    print(f"Test dataset size: {len(node_label_dataset['test'])}")
     cmgpt_classifier = CMGPTClassifier(cmgpt, num_classes=getattr(graph_dataset, f"num_nodes_{args.node_cls_label}"))
 
     trainer = CMGPTTrainer(
@@ -68,9 +111,12 @@ def run(args):
         train_dataset=node_label_dataset['train'],
         test_dataset=node_label_dataset['test'],
         batch_size=args.batch_size,
-        num_epochs=args.num_epochs
+        num_epochs=args.num_epochs,
+        log_dir=logs_dir,
+        results_dir=output_dir,
     )
 
     trainer.train()
 
-    trainer.save_model()
+    trainer.save_model()
+
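Both node-classification entry points now call `set_encoded_labels(train, test)` before training. Its body lands in `glam4cm/utils.py`, whose diff is not shown here, so this is only a guess at the intent: fit one label encoding across both splits so train and test ids stay aligned. A hedged sketch of that idea, assuming a `labels` attribute on each split (also note the new code above calls `GraphNodeDataset` without a corresponding import appearing in these hunks):

```python
from sklearn.preprocessing import LabelEncoder

def set_encoded_labels_sketch(train_ds, test_ds):
    # One encoder fitted over both splits: a class that appears only in
    # the test split still receives a stable integer id.
    enc = LabelEncoder().fit(list(train_ds.labels) + list(test_ds.labels))
    train_ds.labels = enc.transform(train_ds.labels)
    test_ds.labels = enc.transform(test_ds.labels)
```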
glam4cm/downstream_tasks/common_args.py

@@ -1,13 +1,35 @@
 from argparse import ArgumentParser
 from glam4cm.settings import (
+    MODERN_BERT,
     BERT_MODEL,
     WORD2VEC_MODEL,
     TFIDF_MODEL
 )
 
+def get_config_str(args):
+    config_str = ""
+    if args.use_attributes:
+        config_str += "_attrs"
+    if args.use_edge_label:
+        config_str += "_el"
+    if args.use_edge_types:
+        config_str += "_et"
+    if args.use_node_types:
+        config_str += "_nt"
+    if args.use_special_tokens:
+        config_str += "_st"
+    if args.no_labels:
+        config_str += "_nolb"
+    config_str += f"_{args.node_cls_label}" if args.node_cls_label else ""
+    config_str += f"_{args.edge_cls_label}" if args.edge_cls_label else ""
+    config_str += f"_{args.distance}" if args.distance else ""
+
+    return config_str
+
 
 def get_config_params(args):
     common_params = dict(
+
        distance=args.distance,
        reload=args.reload,
        test_ratio=args.test_ratio,
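Since `get_config_str` appears in full above, its output can be read off directly: boolean flags append fixed suffixes in source order, then the label and distance values follow (note that `distance=0` is falsy, so it appends nothing). For example:

```python
from types import SimpleNamespace

# Example arguments for the get_config_str shown in the hunk above.
args = SimpleNamespace(
    use_attributes=True, use_edge_label=False, use_edge_types=True,
    use_node_types=False, use_special_tokens=True, no_labels=False,
    node_cls_label='abstract', edge_cls_label=None, distance=1,
)
# get_config_str(args) == "_attrs_et_st_abstract_1"
```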
@@ -17,6 +39,8 @@ def get_config_params(args):
        use_edge_types=args.use_edge_types,
        use_edge_label=args.use_edge_label,
        no_labels=args.no_labels,
+
+       node_topk=args.node_topk,
 
        use_special_tokens=args.use_special_tokens,
 
@@ -58,6 +82,7 @@ def get_common_args_parser():
         ]
     )
     parser.add_argument('--remove_duplicates', action='store_true')
+    parser.add_argument('--include_dummies', action='store_true')
     parser.add_argument('--reload', action='store_true')
     parser.add_argument('--min_enr', type=float, default=-1.0)
     parser.add_argument('--min_edges', type=int, default=-1)
@@ -72,6 +97,8 @@ def get_common_args_parser():
 
     parser.add_argument('--node_cls_label', type=str, default=None)
     parser.add_argument('--edge_cls_label', type=str, default=None)
+
+    parser.add_argument('--node_topk', type=int, default=-1)
 
 
     parser.add_argument('--limit', type=int, default=-1)
@@ -84,9 +111,11 @@ def get_common_args_parser():
     parser.add_argument(
         '--embed_model_name',
         type=str,
-        default=
-        choices=[BERT_MODEL, WORD2VEC_MODEL, TFIDF_MODEL]
+        default=MODERN_BERT,
+        choices=[MODERN_BERT, BERT_MODEL, WORD2VEC_MODEL, TFIDF_MODEL]
     )
+
+    parser.add_argument('--trust_remote_code', action='store_true')
     parser.add_argument('--max_length', type=int, default=512)
     parser.add_argument('--ckpt', type=str, default=None)
 
@@ -134,7 +163,7 @@ def get_bert_args_parser():
     parser = ArgumentParser()
 
     parser.add_argument('--freeze_pretrained_weights', action='store_true')
-    parser.add_argument('--model_name', type=str, default='
+    parser.add_argument('--model_name', type=str, default='answerdotai/ModernBERT-base')
 
     parser.add_argument('--warmup_steps', type=int, default=200)
     parser.add_argument('--num_log_steps', type=int, default=200)
@@ -148,7 +177,6 @@ def get_bert_args_parser():
 def get_gpt_args_parser():
     parser = ArgumentParser()
     parser.add_argument('--model_name', type=str, default='gpt2')
-    parser.add_argument('--use_special_tokens', action='store_true')
 
     parser.add_argument('--warmup_steps', type=int, default=200)
     parser.add_argument('--blocks', type=int, default=6)
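The `--use_special_tokens` flag dropped from `get_gpt_args_parser` would otherwise collide when the GPT parser is merged with the common parser, which presumably already defines it. The mechanics of `merge_argument_parsers` are not in this diff, but plain `argparse` shows the failure mode such a duplicate causes:

```python
from argparse import ArgumentParser

common = ArgumentParser(add_help=False)
common.add_argument('--use_special_tokens', action='store_true')
gpt = ArgumentParser(add_help=False)
gpt.add_argument('--use_special_tokens', action='store_true')

# Raises argparse.ArgumentError: conflicting option string --use_special_tokens
merged = ArgumentParser(parents=[common, gpt])
```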
glam4cm/downstream_tasks/gnn_edge_classification.py

@@ -1,12 +1,16 @@
 import os
 from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
 from glam4cm.models.gnn_layers import GNNConv, EdgeClassifer
-from glam4cm.settings import
+from glam4cm.settings import EDGE_CLS_TASK, results_dir
 from glam4cm.downstream_tasks.utils import get_models_dataset
 from glam4cm.tokenization.special_tokens import *
 from glam4cm.trainers.gnn_edge_classifier import GNNEdgeClassificationTrainer as Trainer
-from glam4cm.utils import set_seed, merge_argument_parsers
-from glam4cm.downstream_tasks.common_args import
+from glam4cm.utils import set_seed, merge_argument_parsers, set_torch_encoding_labels
+from glam4cm.downstream_tasks.common_args import (
+    get_common_args_parser,
+    get_config_params,
+    get_gnn_args_parser
+)
 
 
 def get_parser():
@@ -21,6 +25,7 @@ def run(args):
     set_seed(args.seed)
 
     config_params = dict(
+        include_dummies = args.include_dummies,
         min_enr = args.min_enr,
         min_edges = args.min_edges,
         remove_duplicates = args.remove_duplicates,
@@ -32,13 +37,21 @@ def run(args):
     dataset = get_models_dataset(dataset_name, **config_params)
 
     graph_data_params = get_config_params(args)
-    graph_data_params = {**graph_data_params, 'task_type':
+    graph_data_params = {**graph_data_params, 'task_type': EDGE_CLS_TASK}
+    print("Using model: ", graph_data_params['embed_model_name'])
+    if args.ckpt:
+        print("Using checkpoint: ", args.ckpt)
+
+    # if args.use_embeddings:
+    #     graph_data_params['embed_model_name'] = os.path.join(results_dir, dataset_name, f'{args.edge_cls_label}')
 
     print("Loading graph dataset")
     graph_dataset = GraphEdgeDataset(dataset, **graph_data_params)
     print("Loaded graph dataset")
 
     graph_torch_data = graph_dataset.get_torch_dataset()
+    exclude_labels = getattr(graph_dataset, f"edge_exclude_{args.edge_cls_label}")
+    set_torch_encoding_labels(graph_torch_data, f"edge_{args.edge_cls_label}", exclude_labels)
 
     input_dim = graph_torch_data[0].x.shape[1]
 
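`set_torch_encoding_labels` is new here and defined in `glam4cm/utils.py`, whose diff is not shown. From the call site it re-encodes the per-edge `edge_<label>` tensors while masking out the classes listed in `edge_exclude_<label>`. A hedged sketch of that masking idea (not the package's actual code), assuming torch tensors of integer class ids:

```python
import torch

def mask_excluded_labels(labels: torch.Tensor, exclude_ids,
                         ignore_index: int = -1) -> torch.Tensor:
    # Replace excluded class ids with ignore_index so a loss such as
    # F.cross_entropy(..., ignore_index=-1) skips those edges entirely.
    out = labels.clone()
    for ex in exclude_ids:
        out[out == ex] = ignore_index
    return out
```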
@@ -59,12 +72,16 @@ def run(args):
 
     edge_dim = graph_dataset[0].data.edge_attr.shape[1] if args.use_edge_attrs else None
 
+    ue = "" if not args.use_edge_attrs else "_ue"
+
     logs_dir = os.path.join(
         "logs",
         dataset_name,
-        "
+        f"GNN_{EDGE_CLS_TASK}",
+        f"{args.edge_cls_label}{ue}",
         f"{graph_dataset.config_hash}",
     )
+
 
     gnn_conv_model = GNNConv(
         model_name=model_name,
@@ -99,8 +116,8 @@ def run(args):
         num_epochs=args.num_epochs,
         batch_size=args.batch_size,
         use_edge_attrs=args.use_edge_attrs,
-        logs_dir=logs_dir
+        logs_dir=logs_dir,
     )
 
     print("Training GNN Edge Classification model")
-    trainer.run()
+    trainer.run()
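The `ue` suffix keeps runs with and without edge attributes in separate log directories. Illustration only; the dataset name, `EDGE_CLS_TASK` value, label, and hash below are hypothetical:

```python
import os

use_edge_attrs = True
ue = "" if not use_edge_attrs else "_ue"
print(os.path.join("logs", "modelset", "GNN_edge_cls", f"type{ue}", "3f9c"))
# logs/modelset/GNN_edge_cls/type_ue/3f9c
```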
glam4cm/downstream_tasks/gnn_graph_cls.py

@@ -1,6 +1,7 @@
 import os
 from glam4cm.data_loading.graph_dataset import GraphNodeDataset
 from glam4cm.models.gnn_layers import GNNConv, GraphClassifer
+from glam4cm.settings import GRAPH_CLS_TASK, DUMMY_GRAPH_CLS_TASK, results_dir
 from glam4cm.trainers.gnn_graph_classifier import GNNGraphClassificationTrainer as Trainer
 from glam4cm.downstream_tasks.common_args import get_common_args_parser, get_config_params, get_gnn_args_parser
 from glam4cm.utils import merge_argument_parsers, set_seed
@@ -21,6 +22,7 @@ def run(args):
     set_seed(args.seed)
 
     config_params = dict(
+        include_dummies = args.include_dummies,
         min_enr = args.min_enr,
         min_edges = args.min_edges,
         remove_duplicates = args.remove_duplicates,
@@ -30,8 +32,10 @@ def run(args):
     dataset_name = args.dataset
 
     dataset = get_models_dataset(dataset_name, **config_params)
-
-    graph_data_params = get_config_params(args)
+
+    graph_data_params = {**get_config_params(args), 'task_type': GRAPH_CLS_TASK if not args.include_dummies else DUMMY_GRAPH_CLS_TASK}
+    # if args.use_embeddings:
+    #     graph_data_params['ckpt'] = os.path.join(results_dir, dataset_name, f'{args.cls_label}')
 
     print("Loading graph dataset")
     graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
@@ -40,8 +44,16 @@ def run(args):
     cls_label = f"num_graph_{args.cls_label}"
     assert hasattr(graph_dataset, cls_label), f"Dataset does not have attribute {cls_label}"
     num_classes = getattr(graph_dataset, cls_label)
-
     print(f"Number of classes: {num_classes}")
+
+    if args.include_dummies:
+        import numpy as np
+        dummy_class = int(graph_dataset.graph_label_map_label.transform(['dummy'])[0])
+        for g, l in zip(graph_dataset, [int(g.data.graph_label[0]) == dummy_class for g in graph_dataset]):
+            setattr(g.data, f"graph_{args.cls_label}", np.array([int(l)]))
+        num_classes = 2
+
+
     model_name = args.gnn_conv_model
     hidden_dim = args.hidden_dim
     output_dim = args.output_dim
@@ -53,14 +65,15 @@ def run(args):
     aggregation = args.aggregation
 
     input_dim = graph_dataset[0].data.x.shape[1]
-
+    ue = "" if not args.use_edge_attrs else "_ue"
     logs_dir = os.path.join(
         "logs",
         dataset_name,
-        "
+        f"GNN_{GRAPH_CLS_TASK}{ue}",
         f"{graph_dataset.config_hash}",
     )
 
+    fold_id = 0
     for datasets in graph_dataset.get_kfold_gnn_graph_classification_data():
 
         edge_dim = graph_dataset[0].data.edge_attr.shape[1] if args.num_heads else None
@@ -94,7 +107,7 @@ def run(args):
             num_epochs=args.num_epochs,
             batch_size=args.batch_size,
             use_edge_attrs=args.use_edge_attrs,
-            logs_dir=logs_dir
+            logs_dir=logs_dir + f"_{fold_id}",
        )
 
        trainer.run()