glam4cm 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. glam4cm/__init__.py +9 -0
  2. glam4cm/data_loading/__init__.py +0 -0
  3. glam4cm/data_loading/data.py +631 -0
  4. glam4cm/data_loading/encoding.py +76 -0
  5. glam4cm/data_loading/graph_dataset.py +940 -0
  6. glam4cm/data_loading/metadata.py +84 -0
  7. glam4cm/data_loading/models_dataset.py +361 -0
  8. glam4cm/data_loading/utils.py +20 -0
  9. glam4cm/downstream_tasks/__init__.py +0 -0
  10. glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
  11. glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
  12. glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
  13. glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
  14. glam4cm/downstream_tasks/bert_node_classification.py +164 -0
  15. glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
  16. glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
  17. glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
  18. glam4cm/downstream_tasks/common_args.py +160 -0
  19. glam4cm/downstream_tasks/create_dataset.py +51 -0
  20. glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
  21. glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
  22. glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
  23. glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
  24. glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
  25. glam4cm/downstream_tasks/utils.py +35 -0
  26. glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
  27. glam4cm/embeddings/__init__.py +0 -0
  28. glam4cm/embeddings/bert.py +72 -0
  29. glam4cm/embeddings/common.py +43 -0
  30. glam4cm/embeddings/fasttext.py +0 -0
  31. glam4cm/embeddings/tfidf.py +25 -0
  32. glam4cm/embeddings/w2v.py +41 -0
  33. glam4cm/encoding/__init__.py +0 -0
  34. glam4cm/encoding/common.py +0 -0
  35. glam4cm/encoding/encoders.py +100 -0
  36. glam4cm/graph2str/__init__.py +0 -0
  37. glam4cm/graph2str/common.py +34 -0
  38. glam4cm/graph2str/constants.py +15 -0
  39. glam4cm/graph2str/ontouml.py +141 -0
  40. glam4cm/graph2str/uml.py +0 -0
  41. glam4cm/lang2graph/__init__.py +0 -0
  42. glam4cm/lang2graph/archimate.py +31 -0
  43. glam4cm/lang2graph/bpmn.py +0 -0
  44. glam4cm/lang2graph/common.py +416 -0
  45. glam4cm/lang2graph/ecore.py +221 -0
  46. glam4cm/lang2graph/ontouml.py +169 -0
  47. glam4cm/lang2graph/utils.py +80 -0
  48. glam4cm/models/cmgpt.py +352 -0
  49. glam4cm/models/gnn_layers.py +273 -0
  50. glam4cm/models/hf.py +10 -0
  51. glam4cm/run.py +99 -0
  52. glam4cm/run_configs.py +126 -0
  53. glam4cm/settings.py +54 -0
  54. glam4cm/tokenization/__init__.py +0 -0
  55. glam4cm/tokenization/special_tokens.py +4 -0
  56. glam4cm/tokenization/utils.py +37 -0
  57. glam4cm/trainers/__init__.py +0 -0
  58. glam4cm/trainers/bert_classifier.py +105 -0
  59. glam4cm/trainers/cm_gpt_trainer.py +153 -0
  60. glam4cm/trainers/gnn_edge_classifier.py +126 -0
  61. glam4cm/trainers/gnn_graph_classifier.py +123 -0
  62. glam4cm/trainers/gnn_link_predictor.py +144 -0
  63. glam4cm/trainers/gnn_node_classifier.py +135 -0
  64. glam4cm/trainers/gnn_trainer.py +129 -0
  65. glam4cm/trainers/metrics.py +55 -0
  66. glam4cm/utils.py +194 -0
  67. glam4cm-0.1.0.dist-info/LICENSE +21 -0
  68. glam4cm-0.1.0.dist-info/METADATA +86 -0
  69. glam4cm-0.1.0.dist-info/RECORD +72 -0
  70. glam4cm-0.1.0.dist-info/WHEEL +5 -0
  71. glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
  72. glam4cm-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,76 @@
1
+ import os
2
+ from glam4cm.downstream_tasks.common_args import (
3
+ get_common_args_parser,
4
+ get_gpt_args_parser
5
+ )
6
+
7
+ from glam4cm.data_loading.graph_dataset import GraphNodeDataset
8
+ from glam4cm.models.cmgpt import CMGPT, CMGPTClassifier
9
+ from glam4cm.downstream_tasks.utils import get_models_dataset
10
+ from glam4cm.tokenization.utils import get_tokenizer
11
+ from glam4cm.trainers.cm_gpt_trainer import CMGPTTrainer
12
+ from glam4cm.utils import merge_argument_parsers, set_seed
13
+
14
+
15
def get_parser():
    """Build the CLI parser for CM-GPT node classification.

    Merges the common dataset options with the GPT model options and adds
    the task-specific classification-label and pretrained-checkpoint flags.
    """
    parser = merge_argument_parsers(get_common_args_parser(), get_gpt_args_parser())
    parser.add_argument('--cls_label', type=str)
    parser.add_argument('--pretr', type=str, default=None)
    return parser
23
+
24
+
25
def run(args):
    """Fine-tune CM-GPT for node classification on a conceptual-models dataset.

    Loads (or creates) a CM-GPT backbone, wraps it in a classification head
    sized to the node label set, and trains/saves via CMGPTTrainer.
    """
    set_seed(args.seed)

    tok = get_tokenizer('bert-base-cased', special_tokens=True)

    corpus = get_models_dataset(args.dataset, language='en')
    graphs = GraphNodeDataset(
        corpus,
        use_special_tokens=args.use_special_tokens,
        distance=args.distance,
        reload=args.reload,
    )

    assert hasattr(graphs, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node labels for {args.node_cls_label}"

    node_data = graphs.get_node_classification_lm_data(
        args.node_cls_label,
        tokenizer=tok,
        distance=1,
    )

    # Reuse a pretrained backbone when a checkpoint path is given and exists;
    # otherwise build a fresh model from the CLI hyperparameters.
    if args.pretr and os.path.exists(args.pretr):
        print(f"Loading pretrained model from {args.pretr}")
        backbone = CMGPT.from_pretrained(f"{args.pretr}")
    else:
        print("Creating new model")
        backbone = CMGPT(
            vocab_size=len(tok),
            embed_dim=args.embed_dim,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
        )

    classifier = CMGPTClassifier(backbone, num_classes=getattr(graphs, f"num_nodes_{args.node_cls_label}"))

    trainer = CMGPTTrainer(
        classifier,
        train_dataset=node_data['train'],
        test_dataset=node_data['test'],
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
    )
    trainer.train()
    trainer.save_model()
@@ -0,0 +1,64 @@
1
+ from sklearn.model_selection import train_test_split
2
+ from glam4cm.downstream_tasks.common_args import (
3
+ get_common_args_parser,
4
+ get_gpt_args_parser
5
+ )
6
+
7
+ from glam4cm.data_loading.graph_dataset import get_models_gpt_dataset
8
+ from glam4cm.models.cmgpt import CMGPT
9
+ from glam4cm.downstream_tasks.utils import get_models_dataset
10
+ from glam4cm.tokenization.utils import get_tokenizer
11
+ from glam4cm.trainers.cm_gpt_trainer import CMGPTTrainer
12
+ from glam4cm.utils import merge_argument_parsers, set_seed
13
+
14
+
15
def get_parser():
    """Build the CLI parser for CM-GPT pretraining (common + GPT options)."""
    return merge_argument_parsers(get_common_args_parser(), get_gpt_args_parser())
20
+
21
+
22
def run(args):
    """Pretrain CM-GPT on linearized graph sequences from a models dataset."""
    set_seed(args.seed)

    tok = get_tokenizer('bert-base-cased', special_tokens=True)

    corpus = get_models_dataset(args.dataset, language='en')
    gpt_dataset = get_models_gpt_dataset(
        corpus,
        tok,
        use_special_tokens=True,
        distance=1,
    )

    # Hold out 5% of the sequences for evaluation.
    train_split, eval_split = train_test_split(gpt_dataset, test_size=0.05)

    model = CMGPT(
        vocab_size=len(tok),
        embed_dim=args.embed_dim,
        block_size=args.block_size,
        n_layer=args.n_layer,
        n_head=args.n_head,
    )

    trainer = CMGPTTrainer(
        model,
        train_dataset=train_split,
        test_dataset=eval_split,
        test_ratio=0.05,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
    )
    trainer.train()
    trainer.save_model()
@@ -0,0 +1,160 @@
1
+ from argparse import ArgumentParser
2
+ from glam4cm.settings import (
3
+ BERT_MODEL,
4
+ WORD2VEC_MODEL,
5
+ TFIDF_MODEL
6
+ )
7
+
8
+
9
def get_config_params(args):
    """Collect the graph-dataset configuration shared by all downstream tasks.

    Reads the relevant attributes off the parsed CLI namespace and returns
    them as a flat dict suitable for **kwargs expansion into the graph
    dataset constructors.
    """
    return dict(
        distance=args.distance,
        reload=args.reload,
        test_ratio=args.test_ratio,
        # textual / structural feature switches
        use_attributes=args.use_attributes,
        use_node_types=args.use_node_types,
        use_edge_types=args.use_edge_types,
        use_edge_label=args.use_edge_label,
        no_labels=args.no_labels,
        use_special_tokens=args.use_special_tokens,
        # embedding configuration
        use_embeddings=args.use_embeddings,
        embed_model_name=args.embed_model_name,
        ckpt=args.ckpt,
        # shuffling / randomization controls
        no_shuffle=args.no_shuffle,
        randomize_ne=args.randomize_ne,
        randomize_ee=args.randomize_ee,
        random_embed_dim=args.random_embed_dim,
        limit=args.limit,
        # classification label names
        node_cls_label=args.node_cls_label,
        edge_cls_label=args.edge_cls_label,
    )
40
+
41
+
42
def get_common_args_parser():
    """Shared CLI options for dataset creation, loading, and training loops."""
    parser = ArgumentParser()

    parser.add_argument('--seed', type=int, default=42)

    # --- models dataset creation ---
    parser.add_argument(
        '--dataset',
        type=str,
        default='ecore_555',
        choices=['modelset', 'ecore_555', 'mar-ecore-github', 'eamodelset', 'ontouml'],
    )
    parser.add_argument('--remove_duplicates', action='store_true')
    parser.add_argument('--reload', action='store_true')
    parser.add_argument('--min_enr', type=float, default=-1.0)
    parser.add_argument('--min_edges', type=int, default=-1)
    parser.add_argument('--language', type=str, default='en')

    # boolean feature switches, all off by default
    for flag in (
        '--use_attributes',
        '--use_edge_label',
        '--use_edge_types',
        '--use_node_types',
        '--use_special_tokens',
        '--no_labels',
    ):
        parser.add_argument(flag, action='store_true')

    parser.add_argument('--node_cls_label', type=str, default=None)
    parser.add_argument('--edge_cls_label', type=str, default=None)
    parser.add_argument('--limit', type=int, default=-1)

    # --- model dataset loading ---
    parser.add_argument('--distance', type=int, default=0)
    parser.add_argument('--use_embeddings', action='store_true')
    parser.add_argument('--regen_embeddings', action='store_true')
    parser.add_argument(
        '--embed_model_name',
        type=str,
        default='bert-base-uncased',
        choices=[BERT_MODEL, WORD2VEC_MODEL, TFIDF_MODEL],
    )
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--ckpt', type=str, default=None)

    parser.add_argument('--no_shuffle', action='store_true')
    parser.add_argument('--randomize_ne', action='store_true')
    parser.add_argument('--randomize_ee', action='store_true')
    parser.add_argument('--random_embed_dim', type=int, default=128)

    # --- splits and training loop ---
    parser.add_argument('--test_ratio', type=float, default=0.2)
    parser.add_argument('--add_negative_train_samples', action='store_true')
    parser.add_argument('--neg_sampling_ratio', type=int, default=1)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=32)

    return parser
108
+
109
+
110
def get_gnn_args_parser():
    """CLI options controlling the GNN architecture and optimizer."""
    parser = ArgumentParser()

    # layer counts; num_heads only matters for attention-based convs
    parser.add_argument('--num_conv_layers', type=int, default=3)
    parser.add_argument('--num_mlp_layers', type=int, default=3)
    parser.add_argument('--num_heads', type=int, default=None)

    # layer widths
    parser.add_argument('--input_dim', type=int, default=768)
    parser.add_argument('--hidden_dim', type=int, default=128)
    parser.add_argument('--output_dim', type=int, default=128)

    # architecture toggles, all off by default
    for flag in ('--residual', '--bias', '--l_norm', '--use_edge_attrs'):
        parser.add_argument(flag, action='store_true')

    parser.add_argument('--aggregation', type=str, default='sum')
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gnn_conv_model', type=str, default='SAGEConv')

    return parser
131
+
132
+
133
def get_bert_args_parser():
    """CLI options for BERT fine-tuning (HF-style step counts and batch sizes)."""
    parser = ArgumentParser()

    parser.add_argument('--freeze_pretrained_weights', action='store_true')
    parser.add_argument('--model_name', type=str, default='bert-base-uncased')

    # all step-based schedules default to 200 steps
    for name in ('warmup_steps', 'num_log_steps', 'num_eval_steps', 'num_save_steps'):
        parser.add_argument(f'--{name}', type=int, default=200)

    parser.add_argument('--train_batch_size', type=int, default=2)
    parser.add_argument('--eval_batch_size', type=int, default=128)
    return parser
146
+
147
+
148
def get_gpt_args_parser():
    """CLI options for the CM-GPT model (architecture + learning rate)."""
    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str, default='gpt2')
    parser.add_argument('--use_special_tokens', action='store_true')

    parser.add_argument('--warmup_steps', type=int, default=200)
    parser.add_argument('--blocks', type=int, default=6)
    parser.add_argument('--block_size', type=int, default=8)
    parser.add_argument('--n_head', type=int, default=8)
    parser.add_argument('--embed_dim', type=int, default=768)
    parser.add_argument('--n_layer', type=int, default=6)
    parser.add_argument('--lr', type=float, default=1e-5)
    return parser
@@ -0,0 +1,51 @@
1
+ import random
2
+ import torch
3
+ import numpy as np
4
+
5
+ from glam4cm.data_loading.models_dataset import ArchiMateDataset, EcoreDataset
6
+ from glam4cm.data_loading.graph_dataset import (
7
+ GraphNodeDataset,
8
+ GraphEdgeDataset
9
+ )
10
+ from glam4cm.downstream_tasks.common_args import get_common_args_parser
11
+
12
+
13
def get_parser():
    """Return the parsed common CLI arguments.

    NOTE(review): unlike the other downstream tasks, this returns the parsed
    namespace (``parse_args()``) rather than the parser object — confirm the
    dispatcher in run.py expects this asymmetry.
    """
    return get_common_args_parser().parse_args()
16
+
17
+
18
def run(args):
    """Pre-build and cache the graph datasets for all supported model corpora.

    Loads the four raw model datasets (Ecore x3, ArchiMate x1) and then
    constructs both the edge-level and node-level graph datasets for each,
    which materializes/caches them on disk for later downstream tasks.
    """
    # Fix every RNG source so cached splits/negative samples are reproducible.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    config_params = dict(
        min_enr=args.min_enr,
        min_edges=args.min_edges,
    )
    ecore = EcoreDataset('ecore_555', reload=args.reload, **config_params)
    modelset = EcoreDataset('modelset', reload=args.reload, remove_duplicates=True, **config_params)
    mar = EcoreDataset('mar-ecore-github', reload=args.reload, **config_params)
    eamodelset = ArchiMateDataset('eamodelset', reload=args.reload, **config_params)

    graph_data_params = dict(
        distance=args.distance,
        # BUG FIX: the common parser defines --add_negative_train_samples;
        # the previous `args.add_neg_samples` raised AttributeError at runtime.
        add_negative_train_samples=args.add_negative_train_samples,
        neg_sampling_ratio=1,
    )

    GraphEdgeDataset(ecore, reload=False, **graph_data_params)
    GraphEdgeDataset(modelset, reload=True, **graph_data_params)
    GraphEdgeDataset(mar, reload=True, **graph_data_params)
    GraphEdgeDataset(eamodelset, reload=True, **graph_data_params)

    GraphNodeDataset(ecore, reload=False, **graph_data_params)
    GraphNodeDataset(modelset, reload=True, **graph_data_params)
    GraphNodeDataset(mar, reload=True, **graph_data_params)
    GraphNodeDataset(eamodelset, reload=True, **graph_data_params)
@@ -0,0 +1,106 @@
1
+ import os
2
+ from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
3
+ from glam4cm.models.gnn_layers import GNNConv, EdgeClassifer
4
+ from glam4cm.settings import LP_TASK_EDGE_CLS
5
+ from glam4cm.downstream_tasks.utils import get_models_dataset
6
+ from glam4cm.tokenization.special_tokens import *
7
+ from glam4cm.trainers.gnn_edge_classifier import GNNEdgeClassificationTrainer as Trainer
8
+ from glam4cm.utils import set_seed, merge_argument_parsers
9
+ from glam4cm.downstream_tasks.common_args import get_common_args_parser, get_config_params, get_gnn_args_parser
10
+
11
+
12
def get_parser():
    """Build the CLI parser for GNN edge classification (common + GNN options)."""
    return merge_argument_parsers(get_common_args_parser(), get_gnn_args_parser())
17
+
18
+
19
def run(args):
    """Train a GNN edge classifier on the chosen models dataset."""
    set_seed(args.seed)

    dataset_name = args.dataset
    dataset = get_models_dataset(
        dataset_name,
        min_enr=args.min_enr,
        min_edges=args.min_edges,
        remove_duplicates=args.remove_duplicates,
        reload=args.reload,
        language=args.language,
    )

    # Common graph options plus the task selector for edge classification.
    graph_data_params = dict(get_config_params(args), task_type=LP_TASK_EDGE_CLS)

    print("Loading graph dataset")
    graph_dataset = GraphEdgeDataset(dataset, **graph_data_params)
    print("Loaded graph dataset")

    graph_torch_data = graph_dataset.get_torch_dataset()
    input_dim = graph_torch_data[0].x.shape[1]

    num_edges_label = f"num_edges_{args.edge_cls_label}"
    assert hasattr(graph_dataset, num_edges_label), f"Graph dataset does not have attribute {num_edges_label}"
    num_classes = getattr(graph_dataset, num_edges_label)

    # Edge features are only fed to the model when explicitly requested.
    edge_dim = graph_dataset[0].data.edge_attr.shape[1] if args.use_edge_attrs else None

    logs_dir = os.path.join(
        "logs",
        dataset_name,
        "gnn_edge_cls",
        f"{graph_dataset.config_hash}",
    )

    gnn_conv_model = GNNConv(
        model_name=args.gnn_conv_model,
        input_dim=input_dim,
        hidden_dim=args.hidden_dim,
        out_dim=args.output_dim,
        num_layers=args.num_conv_layers,
        num_heads=args.num_heads,
        residual=False,  # hard-coded in the original; the --residual flag is unused here
        l_norm=args.l_norm,
        dropout=args.dropout,
        aggregation=args.aggregation,
        edge_dim=edge_dim,
    )

    # Attention convs concatenate heads, widening the classifier input.
    clf_input_dim = args.output_dim * args.num_heads if args.num_heads else args.output_dim
    mlp_predictor = EdgeClassifer(
        input_dim=clf_input_dim,
        hidden_dim=args.hidden_dim,
        num_layers=args.num_mlp_layers,
        num_classes=num_classes,
        edge_dim=edge_dim,
        bias=args.bias,
    )

    trainer = Trainer(
        gnn_conv_model,
        mlp_predictor,
        graph_torch_data,
        cls_label=args.edge_cls_label,
        lr=args.lr,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        use_edge_attrs=args.use_edge_attrs,
        logs_dir=logs_dir,
    )

    print("Training GNN Edge Classification model")
    trainer.run()
@@ -0,0 +1,101 @@
1
+ import os
2
+ from glam4cm.data_loading.graph_dataset import GraphNodeDataset
3
+ from glam4cm.models.gnn_layers import GNNConv, GraphClassifer
4
+ from glam4cm.trainers.gnn_graph_classifier import GNNGraphClassificationTrainer as Trainer
5
+ from glam4cm.downstream_tasks.common_args import get_common_args_parser, get_config_params, get_gnn_args_parser
6
+ from glam4cm.utils import merge_argument_parsers, set_seed
7
+ from glam4cm.downstream_tasks.utils import get_models_dataset
8
+
9
+
10
def get_parser():
    """Build the CLI parser for GNN graph classification."""
    parser = merge_argument_parsers(get_common_args_parser(), get_gnn_args_parser())
    parser.add_argument('--cls_label', type=str, default='label')
    parser.add_argument('--global_pool', type=str, default='mean')
    return parser
18
+
19
+
20
def run(args):
    """Train a GNN graph classifier using k-fold splits (first fold only).

    Builds the node-level graph dataset, sizes the classifier to the graph
    label set, and trains on the first k-fold split before breaking out.
    """
    set_seed(args.seed)

    config_params = dict(
        min_enr=args.min_enr,
        min_edges=args.min_edges,
        remove_duplicates=args.remove_duplicates,
        reload=args.reload,
        language=args.language,
    )
    dataset_name = args.dataset
    dataset = get_models_dataset(dataset_name, **config_params)

    graph_data_params = get_config_params(args)

    print("Loading graph dataset")
    graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
    print("Loaded graph dataset")

    cls_label = f"num_graph_{args.cls_label}"
    assert hasattr(graph_dataset, cls_label), f"Dataset does not have attribute {cls_label}"
    num_classes = getattr(graph_dataset, cls_label)
    print(f"Number of classes: {num_classes}")

    input_dim = graph_dataset[0].data.x.shape[1]

    # BUG FIX: edge_dim was previously gated on args.num_heads (a copy-paste
    # from the attention-head count); edge features should be controlled by
    # --use_edge_attrs, as in the sibling GNN tasks. Also hoisted out of the
    # fold loop since it is loop-invariant.
    edge_dim = graph_dataset[0].data.edge_attr.shape[1] if args.use_edge_attrs else None

    logs_dir = os.path.join(
        "logs",
        dataset_name,
        "gnn_graph_cls",
        f"{graph_dataset.config_hash}",
    )

    for datasets in graph_dataset.get_kfold_gnn_graph_classification_data():
        gnn_conv_model = GNNConv(
            model_name=args.gnn_conv_model,
            input_dim=input_dim,
            hidden_dim=args.hidden_dim,
            out_dim=args.output_dim,
            num_layers=args.num_conv_layers,
            num_heads=args.num_heads,
            residual=True,   # hard-coded in the original; --residual flag unused here
            l_norm=False,    # hard-coded in the original; --l_norm flag unused here
            dropout=args.dropout,
            aggregation=args.aggregation,
            edge_dim=edge_dim,
        )

        # Attention convs concatenate heads, widening the classifier input.
        clf_input_dim = args.output_dim * args.num_heads if args.num_heads else args.output_dim
        classifier = GraphClassifer(
            input_dim=clf_input_dim,
            num_classes=num_classes,
            global_pool=args.global_pool,
        )

        trainer = Trainer(
            gnn_conv_model,
            classifier,
            datasets,
            lr=args.lr,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            use_edge_attrs=args.use_edge_attrs,
            logs_dir=logs_dir,
        )
        trainer.run()
        break  # only the first fold is trained, matching original behavior
@@ -0,0 +1,109 @@
1
+ import os
2
+ from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
3
+ from glam4cm.models.gnn_layers import GNNConv, EdgeClassifer
4
+ from glam4cm.settings import LP_TASK_LINK_PRED
5
+ from glam4cm.downstream_tasks.utils import get_models_dataset
6
+ from glam4cm.tokenization.special_tokens import *
7
+ from glam4cm.trainers.gnn_link_predictor import GNNLinkPredictionTrainer as Trainer
8
+ from glam4cm.utils import merge_argument_parsers, set_seed
9
+ from glam4cm.downstream_tasks.common_args import get_common_args_parser, get_config_params, get_gnn_args_parser
10
+
11
+
12
def get_parser():
    """Build the CLI parser for GNN link prediction (common + GNN options)."""
    return merge_argument_parsers(get_common_args_parser(), get_gnn_args_parser())
17
+
18
+
19
def run(args):
    """Train a GNN link-prediction model on the chosen models dataset.

    Builds the edge-level graph dataset with negative sampling, then trains
    a GNN encoder plus a binary MLP edge predictor.
    """
    set_seed(args.seed)

    config_params = dict(
        min_enr=args.min_enr,
        min_edges=args.min_edges,
        remove_duplicates=args.remove_duplicates,
        reload=args.reload,
        language=args.language,
    )
    dataset_name = args.dataset
    dataset = get_models_dataset(dataset_name, **config_params)

    graph_data_params = get_config_params(args)
    print("Loading graph dataset")
    # BUG FIX: the options were previously passed as a single positional dict,
    # which bound them to GraphEdgeDataset's second positional parameter
    # instead of being spread as keyword arguments (every sibling task calls
    # GraphEdgeDataset(dataset, **params)).
    graph_dataset = GraphEdgeDataset(
        dataset,
        **graph_data_params,
        add_negative_train_samples=args.add_negative_train_samples,
        neg_sampling_ratio=args.neg_sampling_ratio,
        # NOTE(review): the edge-classification task passes this option under
        # the key 'task_type' — confirm which keyword GraphEdgeDataset expects.
        task=LP_TASK_LINK_PRED,
    )

    input_dim = graph_dataset[0].data.x.shape[1]

    # Edge feature width: embedding dim when embeddings are used, raw
    # edge_attr width otherwise; None disables edge features entirely.
    edge_dim = None
    if args.use_edge_attrs:
        if args.use_embeddings:
            edge_dim = graph_dataset.embedder.embedding_dim
        else:
            edge_dim = graph_dataset[0].data.edge_attr.shape[1]

    gnn_conv_model = GNNConv(
        model_name=args.gnn_conv_model,
        input_dim=input_dim,
        hidden_dim=args.hidden_dim,
        out_dim=args.output_dim,
        num_layers=args.num_conv_layers,
        num_heads=args.num_heads,
        residual=True,  # hard-coded in the original; --residual flag unused here
        l_norm=args.l_norm,
        dropout=args.dropout,
        aggregation=args.aggregation,
        edge_dim=edge_dim,
    )

    logs_dir = os.path.join(
        "logs",
        dataset_name,
        "gnn_lp",
        f'{graph_dataset.config_hash}',
    )

    # Attention convs concatenate heads, widening the predictor input.
    clf_input_dim = gnn_conv_model.out_dim * args.num_heads if args.num_heads else args.output_dim
    mlp_predictor = EdgeClassifer(
        input_dim=clf_input_dim,
        hidden_dim=args.hidden_dim,
        num_layers=args.num_mlp_layers,
        num_classes=2,  # binary: edge exists vs. does not
        edge_dim=edge_dim,
        bias=False,
    )

    trainer = Trainer(
        gnn_conv_model,
        mlp_predictor,
        graph_dataset.get_torch_dataset(),
        lr=args.lr,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        use_edge_attrs=args.use_edge_attrs,
        logs_dir=logs_dir,
    )

    print("Training GNN Link Prediction model")
    trainer.run()