glam4cm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glam4cm/__init__.py +9 -0
- glam4cm/data_loading/__init__.py +0 -0
- glam4cm/data_loading/data.py +631 -0
- glam4cm/data_loading/encoding.py +76 -0
- glam4cm/data_loading/graph_dataset.py +940 -0
- glam4cm/data_loading/metadata.py +84 -0
- glam4cm/data_loading/models_dataset.py +361 -0
- glam4cm/data_loading/utils.py +20 -0
- glam4cm/downstream_tasks/__init__.py +0 -0
- glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
- glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
- glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
- glam4cm/downstream_tasks/bert_node_classification.py +164 -0
- glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
- glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
- glam4cm/downstream_tasks/common_args.py +160 -0
- glam4cm/downstream_tasks/create_dataset.py +51 -0
- glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
- glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
- glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
- glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
- glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
- glam4cm/downstream_tasks/utils.py +35 -0
- glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
- glam4cm/embeddings/__init__.py +0 -0
- glam4cm/embeddings/bert.py +72 -0
- glam4cm/embeddings/common.py +43 -0
- glam4cm/embeddings/fasttext.py +0 -0
- glam4cm/embeddings/tfidf.py +25 -0
- glam4cm/embeddings/w2v.py +41 -0
- glam4cm/encoding/__init__.py +0 -0
- glam4cm/encoding/common.py +0 -0
- glam4cm/encoding/encoders.py +100 -0
- glam4cm/graph2str/__init__.py +0 -0
- glam4cm/graph2str/common.py +34 -0
- glam4cm/graph2str/constants.py +15 -0
- glam4cm/graph2str/ontouml.py +141 -0
- glam4cm/graph2str/uml.py +0 -0
- glam4cm/lang2graph/__init__.py +0 -0
- glam4cm/lang2graph/archimate.py +31 -0
- glam4cm/lang2graph/bpmn.py +0 -0
- glam4cm/lang2graph/common.py +416 -0
- glam4cm/lang2graph/ecore.py +221 -0
- glam4cm/lang2graph/ontouml.py +169 -0
- glam4cm/lang2graph/utils.py +80 -0
- glam4cm/models/cmgpt.py +352 -0
- glam4cm/models/gnn_layers.py +273 -0
- glam4cm/models/hf.py +10 -0
- glam4cm/run.py +99 -0
- glam4cm/run_configs.py +126 -0
- glam4cm/settings.py +54 -0
- glam4cm/tokenization/__init__.py +0 -0
- glam4cm/tokenization/special_tokens.py +4 -0
- glam4cm/tokenization/utils.py +37 -0
- glam4cm/trainers/__init__.py +0 -0
- glam4cm/trainers/bert_classifier.py +105 -0
- glam4cm/trainers/cm_gpt_trainer.py +153 -0
- glam4cm/trainers/gnn_edge_classifier.py +126 -0
- glam4cm/trainers/gnn_graph_classifier.py +123 -0
- glam4cm/trainers/gnn_link_predictor.py +144 -0
- glam4cm/trainers/gnn_node_classifier.py +135 -0
- glam4cm/trainers/gnn_trainer.py +129 -0
- glam4cm/trainers/metrics.py +55 -0
- glam4cm/utils.py +194 -0
- glam4cm-0.1.0.dist-info/LICENSE +21 -0
- glam4cm-0.1.0.dist-info/METADATA +86 -0
- glam4cm-0.1.0.dist-info/RECORD +72 -0
- glam4cm-0.1.0.dist-info/WHEEL +5 -0
- glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
- glam4cm-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,76 @@
|
|
1
|
+
import os
|
2
|
+
from glam4cm.downstream_tasks.common_args import (
|
3
|
+
get_common_args_parser,
|
4
|
+
get_gpt_args_parser
|
5
|
+
)
|
6
|
+
|
7
|
+
from glam4cm.data_loading.graph_dataset import GraphNodeDataset
|
8
|
+
from glam4cm.models.cmgpt import CMGPT, CMGPTClassifier
|
9
|
+
from glam4cm.downstream_tasks.utils import get_models_dataset
|
10
|
+
from glam4cm.tokenization.utils import get_tokenizer
|
11
|
+
from glam4cm.trainers.cm_gpt_trainer import CMGPTTrainer
|
12
|
+
from glam4cm.utils import merge_argument_parsers, set_seed
|
13
|
+
|
14
|
+
|
15
|
+
def get_parser():
    """Build the CLI parser: shared options + GPT options + task flags."""
    parser = merge_argument_parsers(get_common_args_parser(), get_gpt_args_parser())
    # NOTE(review): --cls_label is declared here but run() reads
    # args.node_cls_label (from the common parser) — confirm which is used.
    parser.add_argument('--cls_label', type=str)
    # Path to a pretrained CMGPT checkpoint to resume from (optional).
    parser.add_argument('--pretr', type=str, default=None)
    return parser
|
23
|
+
|
24
|
+
|
25
|
+
def run(args):
    """Fine-tune a CMGPT classifier on node classification.

    Loads (or creates) a CMGPT backbone, wraps it in a CMGPTClassifier with
    one class per distinct value of ``args.node_cls_label``, trains it on the
    LM-formatted node-classification split, then saves the model.
    """
    set_seed(args.seed)

    # NOTE(review): tokenizer is hard-coded to 'bert-base-cased' regardless
    # of --model_name — confirm this is intended.
    tokenizer = get_tokenizer('bert-base-cased', special_tokens=True)

    models_dataset_params = dict(
        language='en',
    )

    graph_params = dict(
        use_special_tokens=args.use_special_tokens,
        distance=args.distance,
        reload = args.reload
    )

    models_dataset = get_models_dataset(args.dataset, **models_dataset_params)
    graph_dataset = GraphNodeDataset(models_dataset, **graph_params)

    # The graph dataset exposes one `num_nodes_<label>` attribute per
    # available node label; fail fast when the requested label is absent.
    assert hasattr(graph_dataset, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node labels for {args.node_cls_label}"

    # NOTE(review): distance is fixed to 1 here even though the graph was
    # built with args.distance — confirm the mismatch is deliberate.
    node_label_dataset = graph_dataset.get_node_classification_lm_data(
        args.node_cls_label,
        tokenizer=tokenizer,
        distance=1,
    )

    # Resume from a pretrained CMGPT checkpoint when --pretr points at an
    # existing path; otherwise initialise a fresh backbone from the args.
    if args.pretr and os.path.exists(args.pretr):
        print(f"Loading pretrained model from {args.pretr}")
        cmgpt = CMGPT.from_pretrained(f"{args.pretr}")
    else:
        print("Creating new model")
        cmgpt = CMGPT(
            vocab_size=len(tokenizer),
            embed_dim=args.embed_dim,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
        )

    cmgpt_classifier = CMGPTClassifier(cmgpt, num_classes=getattr(graph_dataset, f"num_nodes_{args.node_cls_label}"))

    trainer = CMGPTTrainer(
        cmgpt_classifier,
        train_dataset=node_label_dataset['train'],
        test_dataset=node_label_dataset['test'],
        batch_size=args.batch_size,
        num_epochs=args.num_epochs
    )

    trainer.train()

    trainer.save_model()
|
@@ -0,0 +1,64 @@
|
|
1
|
+
from sklearn.model_selection import train_test_split
|
2
|
+
from glam4cm.downstream_tasks.common_args import (
|
3
|
+
get_common_args_parser,
|
4
|
+
get_gpt_args_parser
|
5
|
+
)
|
6
|
+
|
7
|
+
from glam4cm.data_loading.graph_dataset import get_models_gpt_dataset
|
8
|
+
from glam4cm.models.cmgpt import CMGPT
|
9
|
+
from glam4cm.downstream_tasks.utils import get_models_dataset
|
10
|
+
from glam4cm.tokenization.utils import get_tokenizer
|
11
|
+
from glam4cm.trainers.cm_gpt_trainer import CMGPTTrainer
|
12
|
+
from glam4cm.utils import merge_argument_parsers, set_seed
|
13
|
+
|
14
|
+
|
15
|
+
def get_parser():
    """Build the CLI parser combining the shared options with the GPT options."""
    return merge_argument_parsers(get_common_args_parser(), get_gpt_args_parser())
|
20
|
+
|
21
|
+
|
22
|
+
def run(args):
    """Pretrain a CMGPT language model on serialized model-graph text."""
    set_seed(args.seed)

    # NOTE(review): tokenizer is hard-coded to 'bert-base-cased' regardless
    # of --model_name — confirm this is intended.
    tokenizer = get_tokenizer('bert-base-cased', special_tokens=True)

    models_dataset_params = dict(
        language='en',
    )

    # NOTE(review): special tokens and distance are fixed here instead of
    # using args.use_special_tokens / args.distance — confirm intent.
    graph_params = dict(
        use_special_tokens=True,
        distance=1,
    )

    models_dataset = get_models_dataset(args.dataset, **models_dataset_params)
    graph_dataset = get_models_gpt_dataset(
        models_dataset,
        tokenizer,
        **graph_params
    )

    # Hold out 5% of the samples for evaluation.
    train_dataset, test_dataset = train_test_split(graph_dataset, test_size=0.05)

    cmgpt = CMGPT(
        vocab_size=len(tokenizer),
        embed_dim=args.embed_dim,
        block_size=args.block_size,
        n_layer=args.n_layer,
        n_head=args.n_head,
    )

    # NOTE(review): both a pre-split train/test pair AND test_ratio=0.05 are
    # passed to the trainer — one of the two is presumably redundant; verify
    # against CMGPTTrainer's signature.
    trainer = CMGPTTrainer(
        cmgpt,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        test_ratio=0.05,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs
    )

    trainer.train()

    trainer.save_model()
|
@@ -0,0 +1,160 @@
|
|
1
|
+
from argparse import ArgumentParser
|
2
|
+
from glam4cm.settings import (
|
3
|
+
BERT_MODEL,
|
4
|
+
WORD2VEC_MODEL,
|
5
|
+
TFIDF_MODEL
|
6
|
+
)
|
7
|
+
|
8
|
+
|
9
|
+
def get_config_params(args):
    """Collect the graph-dataset construction options from a parsed argument
    namespace into a single kwargs dictionary.

    The returned dict is what the Graph*Dataset constructors accept across
    the downstream tasks.
    """
    return {
        # Graph construction / caching / splitting
        'distance': args.distance,
        'reload': args.reload,
        'test_ratio': args.test_ratio,
        # Which features are rendered into the node/edge encodings
        'use_attributes': args.use_attributes,
        'use_node_types': args.use_node_types,
        'use_edge_types': args.use_edge_types,
        'use_edge_label': args.use_edge_label,
        'no_labels': args.no_labels,
        'use_special_tokens': args.use_special_tokens,
        # Pre-computed embedding options
        'use_embeddings': args.use_embeddings,
        'embed_model_name': args.embed_model_name,
        'ckpt': args.ckpt,
        # Shuffling / randomisation ablations
        'no_shuffle': args.no_shuffle,
        'randomize_ne': args.randomize_ne,
        'randomize_ee': args.randomize_ee,
        'random_embed_dim': args.random_embed_dim,
        # Cap on the number of models to load (-1 = all)
        'limit': args.limit,
        # Classification targets
        'node_cls_label': args.node_cls_label,
        'edge_cls_label': args.edge_cls_label,
    }
|
40
|
+
|
41
|
+
|
42
|
+
def get_common_args_parser():
    """Build the argument parser holding the options shared by every
    downstream task (dataset selection, feature switches, embeddings,
    splitting and generic training-loop settings)."""
    p = ArgumentParser()

    p.add_argument('--seed', type=int, default=42)

    # --- Raw models-dataset construction --------------------------------
    p.add_argument(
        '--dataset',
        type=str,
        default='ecore_555',
        choices=[
            'modelset',
            'ecore_555',
            'mar-ecore-github',
            'eamodelset',
            'ontouml'
        ]
    )
    p.add_argument('--remove_duplicates', action='store_true')
    p.add_argument('--reload', action='store_true')
    p.add_argument('--min_enr', type=float, default=-1.0)
    p.add_argument('--min_edges', type=int, default=-1)
    p.add_argument('--language', type=str, default='en')

    # Which textual/structural features go into the encodings.
    p.add_argument('--use_attributes', action='store_true')
    p.add_argument('--use_edge_label', action='store_true')
    p.add_argument('--use_edge_types', action='store_true')
    p.add_argument('--use_node_types', action='store_true')
    p.add_argument('--use_special_tokens', action='store_true')
    p.add_argument('--no_labels', action='store_true')

    # Classification targets for node / edge level tasks.
    p.add_argument('--node_cls_label', type=str, default=None)
    p.add_argument('--edge_cls_label', type=str, default=None)

    # Cap on the number of models to load (-1 = no limit).
    p.add_argument('--limit', type=int, default=-1)

    # --- Graph-dataset loading / embeddings -----------------------------
    p.add_argument('--distance', type=int, default=0)
    p.add_argument('--use_embeddings', action='store_true')
    p.add_argument('--regen_embeddings', action='store_true')
    p.add_argument(
        '--embed_model_name',
        type=str,
        default='bert-base-uncased',
        choices=[BERT_MODEL, WORD2VEC_MODEL, TFIDF_MODEL]
    )
    p.add_argument('--max_length', type=int, default=512)
    p.add_argument('--ckpt', type=str, default=None)

    # Ablation switches: shuffling and randomised node/edge encodings.
    p.add_argument('--no_shuffle', action='store_true')
    p.add_argument('--randomize_ne', action='store_true')
    p.add_argument('--randomize_ee', action='store_true')
    p.add_argument('--random_embed_dim', type=int, default=128)

    # Train/test split and negative sampling.
    p.add_argument('--test_ratio', type=float, default=0.2)
    p.add_argument('--add_negative_train_samples', action='store_true')
    p.add_argument('--neg_sampling_ratio', type=int, default=1)

    # Generic training-loop settings.
    p.add_argument('--num_epochs', type=int, default=100)
    p.add_argument('--batch_size', type=int, default=32)

    return p
|
108
|
+
|
109
|
+
|
110
|
+
def get_gnn_args_parser():
    """Build the argument parser holding GNN architecture and training options."""
    p = ArgumentParser()

    # Network depth and attention-head configuration.
    p.add_argument('--num_conv_layers', type=int, default=3)
    p.add_argument('--num_mlp_layers', type=int, default=3)
    p.add_argument('--num_heads', type=int, default=None)

    # Layer dimensions.
    p.add_argument('--input_dim', type=int, default=768)
    p.add_argument('--hidden_dim', type=int, default=128)
    p.add_argument('--output_dim', type=int, default=128)

    # Structural / regularisation switches and optimiser settings.
    p.add_argument('--residual', action='store_true')
    p.add_argument('--bias', action='store_true')
    p.add_argument('--l_norm', action='store_true')
    p.add_argument('--aggregation', type=str, default='sum')
    p.add_argument('--dropout', type=float, default=0.3)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--gnn_conv_model', type=str, default='SAGEConv')

    # Feed edge attributes into the convolution when set.
    p.add_argument('--use_edge_attrs', action='store_true')

    return p
|
131
|
+
|
132
|
+
|
133
|
+
def get_bert_args_parser():
    """Build the argument parser holding BERT fine-tuning options."""
    p = ArgumentParser()

    # Backbone selection; optionally freeze the pretrained encoder weights.
    p.add_argument('--freeze_pretrained_weights', action='store_true')
    p.add_argument('--model_name', type=str, default='bert-base-uncased')

    # HuggingFace-Trainer-style step schedules and batch sizes.
    p.add_argument('--warmup_steps', type=int, default=200)
    p.add_argument('--num_log_steps', type=int, default=200)
    p.add_argument('--num_eval_steps', type=int, default=200)
    p.add_argument('--num_save_steps', type=int, default=200)
    p.add_argument('--train_batch_size', type=int, default=2)
    p.add_argument('--eval_batch_size', type=int, default=128)
    return p
|
146
|
+
|
147
|
+
|
148
|
+
def get_gpt_args_parser():
    """Build the argument parser holding GPT/CMGPT architecture options."""
    p = ArgumentParser()
    p.add_argument('--model_name', type=str, default='gpt2')
    p.add_argument('--use_special_tokens', action='store_true')

    # Transformer shape and optimiser settings.
    p.add_argument('--warmup_steps', type=int, default=200)
    p.add_argument('--blocks', type=int, default=6)
    p.add_argument('--block_size', type=int, default=8)
    p.add_argument('--n_head', type=int, default=8)
    p.add_argument('--embed_dim', type=int, default=768)
    p.add_argument('--n_layer', type=int, default=6)
    p.add_argument('--lr', type=float, default=1e-5)
    return p
|
@@ -0,0 +1,51 @@
|
|
1
|
+
import random
|
2
|
+
import torch
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
from glam4cm.data_loading.models_dataset import ArchiMateDataset, EcoreDataset
|
6
|
+
from glam4cm.data_loading.graph_dataset import (
|
7
|
+
GraphNodeDataset,
|
8
|
+
GraphEdgeDataset
|
9
|
+
)
|
10
|
+
from glam4cm.downstream_tasks.common_args import get_common_args_parser
|
11
|
+
|
12
|
+
|
13
|
+
def get_parser():
    """Return the parsed common CLI arguments.

    NOTE(review): unlike the other downstream-task modules, this returns
    ``parser.parse_args()`` (a Namespace) rather than the parser itself —
    confirm that callers of this module expect a Namespace here.
    """
    parser = get_common_args_parser()
    return parser.parse_args()
|
16
|
+
|
17
|
+
|
18
|
+
def run(args):
    """Pre-build the edge- and node-level graph datasets for the four model
    corpora (ecore_555, modelset, mar-ecore-github, eamodelset) so later
    tasks can load them from cache.

    Bug fix: the common argument parser defines --add_negative_train_samples;
    the previously read attribute ``args.add_neg_samples`` does not exist and
    raised AttributeError at runtime.
    """
    # Fixed seed for reproducibility across all RNG sources in use.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    config_params = dict(
        min_enr = args.min_enr,
        min_edges = args.min_edges,
    )
    ecore = EcoreDataset('ecore_555', reload=args.reload, **config_params)
    modelset = EcoreDataset('modelset', reload=args.reload, remove_duplicates=True, **config_params)
    mar = EcoreDataset('mar-ecore-github', reload=args.reload, **config_params)
    eamodelset = ArchiMateDataset('eamodelset', reload=args.reload, **config_params)

    graph_data_params = dict(
        distance=args.distance,
        # FIX: was args.add_neg_samples (nonexistent attribute).
        add_negative_train_samples=args.add_negative_train_samples,
        neg_sampling_ratio=1,
    )

    # NOTE(review): only the 'ecore' corpus skips the graph-level reload —
    # presumably its cache is trusted; confirm the asymmetry is intended.
    GraphEdgeDataset(ecore, reload=False, **graph_data_params)
    GraphEdgeDataset(modelset, reload=True, **graph_data_params)
    GraphEdgeDataset(mar, reload=True, **graph_data_params)
    GraphEdgeDataset(eamodelset, reload=True, **graph_data_params)

    GraphNodeDataset(ecore, reload=False, **graph_data_params)
    GraphNodeDataset(modelset, reload=True, **graph_data_params)
    GraphNodeDataset(mar, reload=True, **graph_data_params)
    GraphNodeDataset(eamodelset, reload=True, **graph_data_params)
|
@@ -0,0 +1,106 @@
|
|
1
|
+
import os
|
2
|
+
from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
|
3
|
+
from glam4cm.models.gnn_layers import GNNConv, EdgeClassifer
|
4
|
+
from glam4cm.settings import LP_TASK_EDGE_CLS
|
5
|
+
from glam4cm.downstream_tasks.utils import get_models_dataset
|
6
|
+
from glam4cm.tokenization.special_tokens import *
|
7
|
+
from glam4cm.trainers.gnn_edge_classifier import GNNEdgeClassificationTrainer as Trainer
|
8
|
+
from glam4cm.utils import set_seed, merge_argument_parsers
|
9
|
+
from glam4cm.downstream_tasks.common_args import get_common_args_parser, get_config_params, get_gnn_args_parser
|
10
|
+
|
11
|
+
|
12
|
+
def get_parser():
    """Build the CLI parser combining the shared options with the GNN options."""
    return merge_argument_parsers(get_common_args_parser(), get_gnn_args_parser())
|
17
|
+
|
18
|
+
|
19
|
+
def run(args):
    """Train a GNN edge classifier over the selected models dataset."""

    set_seed(args.seed)

    # Filtering/loading options for the raw models dataset.
    config_params = dict(
        min_enr = args.min_enr,
        min_edges = args.min_edges,
        remove_duplicates = args.remove_duplicates,
        reload = args.reload,
        language = args.language
    )
    dataset_name = args.dataset

    dataset = get_models_dataset(dataset_name, **config_params)

    # Shared graph options plus the edge-classification task selector.
    graph_data_params = get_config_params(args)
    graph_data_params = {**graph_data_params, 'task_type': LP_TASK_EDGE_CLS}

    print("Loading graph dataset")
    graph_dataset = GraphEdgeDataset(dataset, **graph_data_params)
    print("Loaded graph dataset")

    graph_torch_data = graph_dataset.get_torch_dataset()

    # Node feature dimensionality, taken from the first torch graph.
    input_dim = graph_torch_data[0].x.shape[1]

    model_name = args.gnn_conv_model
    hidden_dim = args.hidden_dim
    output_dim = args.output_dim
    num_conv_layers = args.num_conv_layers
    num_mlp_layers = args.num_mlp_layers
    num_heads = args.num_heads
    # NOTE(review): hard-coded to False — the --residual CLI flag is parsed
    # but ignored in this task; confirm intent.
    residual = False
    l_norm = args.l_norm
    dropout = args.dropout
    aggregation = args.aggregation

    # Number of classes = distinct values of the chosen edge label; the
    # dataset exposes one `num_edges_<label>` attribute per available label.
    num_edges_label = f"num_edges_{args.edge_cls_label}"
    assert hasattr(graph_dataset, num_edges_label), f"Graph dataset does not have attribute {num_edges_label}"
    num_classes = getattr(graph_dataset, num_edges_label)

    # Edge attribute width (None disables edge attributes downstream).
    edge_dim = graph_dataset[0].data.edge_attr.shape[1] if args.use_edge_attrs else None

    logs_dir = os.path.join(
        "logs",
        dataset_name,
        "gnn_edge_cls",
        f"{graph_dataset.config_hash}",
    )

    gnn_conv_model = GNNConv(
        model_name=model_name,
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        out_dim=output_dim,
        num_layers=num_conv_layers,
        num_heads=num_heads,
        residual=residual,
        l_norm=l_norm,
        dropout=dropout,
        aggregation=aggregation,
        edge_dim=edge_dim
    )

    # Presumably attention-style convs emit output_dim per head, hence the
    # widened classifier input when --num_heads is set — verify in GNNConv.
    clf_input_dim = output_dim*num_heads if args.num_heads else output_dim
    mlp_predictor = EdgeClassifer(
        input_dim=clf_input_dim,
        hidden_dim=hidden_dim,
        num_layers=num_mlp_layers,
        num_classes=num_classes,
        edge_dim=edge_dim,
        bias=args.bias,
    )

    trainer = Trainer(
        gnn_conv_model,
        mlp_predictor,
        graph_torch_data,
        cls_label=args.edge_cls_label,
        lr=args.lr,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        use_edge_attrs=args.use_edge_attrs,
        logs_dir=logs_dir
    )

    print("Training GNN Edge Classification model")
    trainer.run()
|
@@ -0,0 +1,101 @@
|
|
1
|
+
import os
|
2
|
+
from glam4cm.data_loading.graph_dataset import GraphNodeDataset
|
3
|
+
from glam4cm.models.gnn_layers import GNNConv, GraphClassifer
|
4
|
+
from glam4cm.trainers.gnn_graph_classifier import GNNGraphClassificationTrainer as Trainer
|
5
|
+
from glam4cm.downstream_tasks.common_args import get_common_args_parser, get_config_params, get_gnn_args_parser
|
6
|
+
from glam4cm.utils import merge_argument_parsers, set_seed
|
7
|
+
from glam4cm.downstream_tasks.utils import get_models_dataset
|
8
|
+
|
9
|
+
|
10
|
+
def get_parser():
    """Build the CLI parser: shared + GNN options plus graph-classification extras."""
    parser = merge_argument_parsers(get_common_args_parser(), get_gnn_args_parser())

    # Graph-level label to predict, and the pooling used to summarise node
    # embeddings into a single graph embedding.
    parser.add_argument('--cls_label', type=str, default='label')
    parser.add_argument('--global_pool', type=str, default='mean')
    return parser
|
18
|
+
|
19
|
+
|
20
|
+
def run(args):
    """Train a GNN graph classifier on the first k-fold split.

    Bug fix: ``edge_dim`` was gated on ``args.num_heads`` instead of
    ``args.use_edge_attrs``, so edge attributes were switched on/off by an
    unrelated flag. The sibling edge-classification task gates it on
    --use_edge_attrs, and the trainer below also receives
    use_edge_attrs=args.use_edge_attrs; this makes the two consistent.
    """
    set_seed(args.seed)

    # Filtering/loading options for the raw models dataset.
    config_params = dict(
        min_enr = args.min_enr,
        min_edges = args.min_edges,
        remove_duplicates = args.remove_duplicates,
        reload = args.reload,
        language = args.language
    )
    dataset_name = args.dataset

    dataset = get_models_dataset(dataset_name, **config_params)

    graph_data_params = get_config_params(args)

    print("Loading graph dataset")
    graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
    print("Loaded graph dataset")

    # Number of classes = distinct values of the chosen graph-level label.
    cls_label = f"num_graph_{args.cls_label}"
    assert hasattr(graph_dataset, cls_label), f"Dataset does not have attribute {cls_label}"
    num_classes = getattr(graph_dataset, cls_label)

    print(f"Number of classes: {num_classes}")
    model_name = args.gnn_conv_model
    hidden_dim = args.hidden_dim
    output_dim = args.output_dim
    num_conv_layers = args.num_conv_layers
    num_heads = args.num_heads
    # NOTE(review): residual/l_norm are hard-coded here, ignoring the
    # --residual and --l_norm CLI flags; confirm intent.
    residual = True
    l_norm = False
    dropout = args.dropout
    aggregation = args.aggregation

    input_dim = graph_dataset[0].data.x.shape[1]

    # Loop-invariant, hoisted out of the fold loop.
    # FIX: gate on --use_edge_attrs (was: args.num_heads).
    edge_dim = graph_dataset[0].data.edge_attr.shape[1] if args.use_edge_attrs else None

    logs_dir = os.path.join(
        "logs",
        dataset_name,
        "gnn_graph_cls",
        f"{graph_dataset.config_hash}",
    )

    for datasets in graph_dataset.get_kfold_gnn_graph_classification_data():

        gnn_conv_model = GNNConv(
            model_name=model_name,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            out_dim=output_dim,
            num_layers=num_conv_layers,
            num_heads=num_heads,
            residual=residual,
            l_norm=l_norm,
            dropout=dropout,
            aggregation=aggregation,
            edge_dim=edge_dim,
        )

        # Presumably attention-style convs emit output_dim per head, hence
        # the widened classifier input when --num_heads is set.
        clf_input_dim = output_dim*num_heads if args.num_heads else output_dim
        classifier = GraphClassifer(
            input_dim=clf_input_dim,
            num_classes=num_classes,
            global_pool=args.global_pool,
        )

        trainer = Trainer(
            gnn_conv_model,
            classifier,
            datasets,
            lr=args.lr,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            use_edge_attrs=args.use_edge_attrs,
            logs_dir=logs_dir
        )

        trainer.run()
        # Only the first fold is trained; remove to run full k-fold CV.
        break
|
@@ -0,0 +1,109 @@
|
|
1
|
+
import os
|
2
|
+
from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
|
3
|
+
from glam4cm.models.gnn_layers import GNNConv, EdgeClassifer
|
4
|
+
from glam4cm.settings import LP_TASK_LINK_PRED
|
5
|
+
from glam4cm.downstream_tasks.utils import get_models_dataset
|
6
|
+
from glam4cm.tokenization.special_tokens import *
|
7
|
+
from glam4cm.trainers.gnn_link_predictor import GNNLinkPredictionTrainer as Trainer
|
8
|
+
from glam4cm.utils import merge_argument_parsers, set_seed
|
9
|
+
from glam4cm.downstream_tasks.common_args import get_common_args_parser, get_config_params, get_gnn_args_parser
|
10
|
+
|
11
|
+
|
12
|
+
def get_parser():
    """Build the CLI parser combining the shared options with the GNN options."""
    return merge_argument_parsers(get_common_args_parser(), get_gnn_args_parser())
|
17
|
+
|
18
|
+
|
19
|
+
def run(args):
    """Train a GNN link-prediction model over the selected models dataset.

    Bug fix: the graph options were passed to GraphEdgeDataset as ONE
    positional dict (``GraphEdgeDataset(dataset, dict(...))``) so every
    option — including the task selector — bound to the constructor's second
    positional parameter instead of its keyword arguments. The sibling
    edge-classification task calls ``GraphEdgeDataset(dataset,
    **graph_data_params)``; the dict is now unpacked the same way.
    """

    set_seed(args.seed)

    # Filtering/loading options for the raw models dataset.
    config_params = dict(
        min_enr = args.min_enr,
        min_edges = args.min_edges,
        remove_duplicates = args.remove_duplicates,
        reload = args.reload,
        language = args.language
    )
    dataset_name = args.dataset
    dataset = get_models_dataset(dataset_name, **config_params)

    model_name = args.gnn_conv_model
    hidden_dim = args.hidden_dim
    output_dim = args.output_dim
    num_conv_layers = args.num_conv_layers
    num_mlp_layers = args.num_mlp_layers
    num_heads = args.num_heads
    # NOTE(review): hard-coded to True — the --residual CLI flag is parsed
    # but ignored in this task; confirm intent.
    residual = True
    l_norm = args.l_norm
    dropout = args.dropout
    aggregation = args.aggregation

    graph_data_params = get_config_params(args)
    print("Loading graph dataset")
    # NOTE(review): this passes task=... while the edge-classification task
    # uses a 'task_type' key — verify the keyword against GraphEdgeDataset.
    graph_dataset = GraphEdgeDataset(
        dataset,
        **dict(
            **graph_data_params,
            add_negative_train_samples=args.add_negative_train_samples,
            neg_sampling_ratio=args.neg_sampling_ratio,
            task=LP_TASK_LINK_PRED
        ))

    # Node feature dimensionality, taken from the first graph.
    input_dim = graph_dataset[0].data.x.shape[1]

    # Edge attribute width (None disables edge attributes downstream).
    edge_dim = None
    if args.use_edge_attrs:
        if args.use_embeddings:
            edge_dim = graph_dataset.embedder.embedding_dim
        else:
            edge_dim = graph_dataset[0].data.edge_attr.shape[1]

    gnn_conv_model = GNNConv(
        model_name=model_name,
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        out_dim=output_dim,
        num_layers=num_conv_layers,
        num_heads=num_heads,
        residual=residual,
        l_norm=l_norm,
        dropout=dropout,
        aggregation=aggregation,
        edge_dim=edge_dim
    )

    logs_dir = os.path.join(
        "logs",
        dataset_name,
        "gnn_lp",
        f'{graph_dataset.config_hash}',
    )

    # Presumably attention-style convs emit out_dim per head — hence the
    # widened predictor input; assumes GNNConv exposes .out_dim (verify).
    clf_input_dim = gnn_conv_model.out_dim*num_heads if args.num_heads else output_dim
    # Binary classifier: edge exists / does not exist.
    mlp_predictor = EdgeClassifer(
        input_dim=clf_input_dim,
        hidden_dim=hidden_dim,
        num_layers=num_mlp_layers,
        num_classes=2,
        edge_dim=edge_dim,
        bias=False,
    )

    trainer = Trainer(
        gnn_conv_model,
        mlp_predictor,
        graph_dataset.get_torch_dataset(),
        lr=args.lr,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        use_edge_attrs=args.use_edge_attrs,
        logs_dir=logs_dir
    )

    print("Training GNN Link Prediction model")
    trainer.run()
|