glam4cm 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glam4cm/__init__.py +2 -1
- glam4cm/data_loading/data.py +90 -146
- glam4cm/data_loading/encoding.py +17 -6
- glam4cm/data_loading/graph_dataset.py +192 -57
- glam4cm/data_loading/metadata.py +1 -1
- glam4cm/data_loading/models_dataset.py +42 -18
- glam4cm/downstream_tasks/bert_edge_classification.py +49 -22
- glam4cm/downstream_tasks/bert_graph_classification.py +44 -14
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +47 -24
- glam4cm/downstream_tasks/bert_link_prediction.py +46 -26
- glam4cm/downstream_tasks/bert_node_classification.py +127 -89
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +61 -15
- glam4cm/downstream_tasks/common_args.py +32 -4
- glam4cm/downstream_tasks/gnn_edge_classification.py +24 -7
- glam4cm/downstream_tasks/gnn_graph_cls.py +19 -6
- glam4cm/downstream_tasks/gnn_link_prediction.py +25 -13
- glam4cm/downstream_tasks/gnn_node_classification.py +19 -7
- glam4cm/downstream_tasks/utils.py +16 -2
- glam4cm/embeddings/bert.py +1 -1
- glam4cm/embeddings/common.py +7 -4
- glam4cm/encoding/encoders.py +1 -1
- glam4cm/lang2graph/archimate.py +0 -5
- glam4cm/lang2graph/common.py +99 -41
- glam4cm/lang2graph/ecore.py +1 -2
- glam4cm/lang2graph/ontouml.py +8 -7
- glam4cm/models/gnn_layers.py +20 -6
- glam4cm/models/hf.py +2 -2
- glam4cm/run.py +13 -9
- glam4cm/run_conf_v2.py +405 -0
- glam4cm/run_configs.py +70 -106
- glam4cm/run_confs.py +41 -0
- glam4cm/settings.py +15 -2
- glam4cm/tokenization/special_tokens.py +23 -1
- glam4cm/tokenization/utils.py +23 -4
- glam4cm/trainers/cm_gpt_trainer.py +1 -1
- glam4cm/trainers/gnn_edge_classifier.py +12 -1
- glam4cm/trainers/gnn_graph_classifier.py +12 -5
- glam4cm/trainers/gnn_link_predictor.py +18 -3
- glam4cm/trainers/gnn_link_predictor_v2.py +146 -0
- glam4cm/trainers/gnn_trainer.py +8 -0
- glam4cm/trainers/metrics.py +1 -1
- glam4cm/utils.py +265 -2
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/METADATA +3 -2
- glam4cm-1.0.0.dist-info/RECORD +75 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/WHEEL +1 -1
- glam4cm-0.1.0.dist-info/RECORD +0 -72
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/entry_points.txt +0 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,18 @@
|
|
1
1
|
import os
|
2
2
|
from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
|
3
3
|
from glam4cm.models.gnn_layers import GNNConv, EdgeClassifer
|
4
|
-
from glam4cm.settings import
|
4
|
+
from glam4cm.settings import LINK_PRED_TASK, results_dir
|
5
5
|
from glam4cm.downstream_tasks.utils import get_models_dataset
|
6
6
|
from glam4cm.tokenization.special_tokens import *
|
7
7
|
from glam4cm.trainers.gnn_link_predictor import GNNLinkPredictionTrainer as Trainer
|
8
|
-
from glam4cm.utils import merge_argument_parsers, set_seed
|
9
|
-
from glam4cm.downstream_tasks.common_args import
|
10
|
-
|
8
|
+
from glam4cm.utils import merge_argument_parsers, set_seed, set_torch_encoding_labels
|
9
|
+
from glam4cm.downstream_tasks.common_args import (
|
10
|
+
get_common_args_parser,
|
11
|
+
get_config_params,
|
12
|
+
get_gnn_args_parser
|
13
|
+
)
|
11
14
|
|
15
|
+
|
12
16
|
def get_parser():
|
13
17
|
common_parser = get_common_args_parser()
|
14
18
|
gnn_parser = get_gnn_args_parser()
|
@@ -21,6 +25,7 @@ def run(args):
|
|
21
25
|
set_seed(args.seed)
|
22
26
|
|
23
27
|
config_params = dict(
|
28
|
+
include_dummies = args.include_dummies,
|
24
29
|
min_enr = args.min_enr,
|
25
30
|
min_edges = args.min_edges,
|
26
31
|
remove_duplicates = args.remove_duplicates,
|
@@ -42,14 +47,18 @@ def run(args):
|
|
42
47
|
aggregation = args.aggregation
|
43
48
|
|
44
49
|
graph_data_params = get_config_params(args)
|
50
|
+
|
51
|
+
if args.use_embeddings:
|
52
|
+
graph_data_params['embed_model_name'] = os.path.join(results_dir, dataset_name, f"LM_{LINK_PRED_TASK}")
|
53
|
+
|
45
54
|
print("Loading graph dataset")
|
46
55
|
graph_dataset = GraphEdgeDataset(
|
47
|
-
dataset,
|
48
|
-
|
56
|
+
dataset,
|
57
|
+
task_type=LINK_PRED_TASK,
|
58
|
+
**dict(
|
49
59
|
**graph_data_params,
|
50
|
-
add_negative_train_samples=
|
60
|
+
add_negative_train_samples=True,
|
51
61
|
neg_sampling_ratio=args.neg_sampling_ratio,
|
52
|
-
task=LP_TASK_LINK_PRED
|
53
62
|
))
|
54
63
|
|
55
64
|
input_dim = graph_dataset[0].data.x.shape[1]
|
@@ -78,7 +87,7 @@ def run(args):
|
|
78
87
|
logs_dir = os.path.join(
|
79
88
|
"logs",
|
80
89
|
dataset_name,
|
81
|
-
"
|
90
|
+
f"GNN_{LINK_PRED_TASK}",
|
82
91
|
f'{graph_dataset.config_hash}',
|
83
92
|
)
|
84
93
|
|
@@ -92,11 +101,14 @@ def run(args):
|
|
92
101
|
bias=False,
|
93
102
|
)
|
94
103
|
|
104
|
+
graph_torch_data = graph_dataset.get_torch_dataset()
|
105
|
+
# exclude_labels = getattr(graph_dataset, f"node_exclude_{args.node_cls_label}")
|
106
|
+
# set_torch_encoding_labels(graph_torch_data, f"node_{args.node_cls_label}", exclude_labels)
|
95
107
|
|
96
108
|
trainer = Trainer(
|
97
|
-
gnn_conv_model,
|
98
|
-
mlp_predictor,
|
99
|
-
|
109
|
+
model=gnn_conv_model,
|
110
|
+
predictor=mlp_predictor,
|
111
|
+
dataset=graph_torch_data,
|
100
112
|
lr=args.lr,
|
101
113
|
num_epochs=args.num_epochs,
|
102
114
|
batch_size=args.batch_size,
|
@@ -106,4 +118,4 @@ def run(args):
|
|
106
118
|
|
107
119
|
|
108
120
|
print("Training GNN Link Prediction model")
|
109
|
-
trainer.run()
|
121
|
+
trainer.run()
|
@@ -2,10 +2,15 @@ import os
|
|
2
2
|
from glam4cm.data_loading.graph_dataset import GraphNodeDataset
|
3
3
|
from glam4cm.models.gnn_layers import GNNConv, NodeClassifier
|
4
4
|
from glam4cm.downstream_tasks.utils import get_models_dataset
|
5
|
+
from glam4cm.settings import NODE_CLS_TASK, results_dir
|
5
6
|
from glam4cm.tokenization.special_tokens import *
|
6
7
|
from glam4cm.trainers.gnn_node_classifier import GNNNodeClassificationTrainer as Trainer
|
7
|
-
from glam4cm.utils import merge_argument_parsers, set_seed
|
8
|
-
from glam4cm.downstream_tasks.common_args import
|
8
|
+
from glam4cm.utils import merge_argument_parsers, set_seed, set_torch_encoding_labels
|
9
|
+
from glam4cm.downstream_tasks.common_args import (
|
10
|
+
get_common_args_parser,
|
11
|
+
get_config_params,
|
12
|
+
get_gnn_args_parser
|
13
|
+
)
|
9
14
|
|
10
15
|
|
11
16
|
def get_parser():
|
@@ -20,6 +25,7 @@ def run(args):
|
|
20
25
|
set_seed(args.seed)
|
21
26
|
|
22
27
|
config_params = dict(
|
28
|
+
include_dummies = args.include_dummies,
|
23
29
|
min_enr = args.min_enr,
|
24
30
|
min_edges = args.min_edges,
|
25
31
|
remove_duplicates = args.remove_duplicates,
|
@@ -29,13 +35,19 @@ def run(args):
|
|
29
35
|
dataset_name = args.dataset
|
30
36
|
|
31
37
|
dataset = get_models_dataset(dataset_name, **config_params)
|
32
|
-
graph_data_params = get_config_params(args)
|
38
|
+
graph_data_params = {**get_config_params(args), 'task_type': NODE_CLS_TASK}
|
39
|
+
|
40
|
+
if args.use_embeddings:
|
41
|
+
graph_data_params['embed_model_name'] = os.path.join(results_dir, dataset_name, f'{args.node_cls_label}')
|
33
42
|
|
34
43
|
print("Loading graph dataset")
|
35
44
|
graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
|
36
45
|
print("Loaded graph dataset")
|
37
|
-
|
46
|
+
|
47
|
+
|
38
48
|
graph_torch_data = graph_dataset.get_torch_dataset()
|
49
|
+
exclude_labels = getattr(graph_dataset, f"node_exclude_{args.node_cls_label}")
|
50
|
+
set_torch_encoding_labels(graph_torch_data, f"node_{args.node_cls_label}", exclude_labels)
|
39
51
|
|
40
52
|
num_nodes_label = f"num_nodes_{args.node_cls_label}"
|
41
53
|
assert hasattr(graph_dataset, num_nodes_label), f"Graph dataset does not have attribute {num_nodes_label}"
|
@@ -83,7 +95,7 @@ def run(args):
|
|
83
95
|
logs_dir = os.path.join(
|
84
96
|
"logs",
|
85
97
|
dataset_name,
|
86
|
-
"
|
98
|
+
f"GNN_{NODE_CLS_TASK}",
|
87
99
|
f"{graph_dataset.config_hash}",
|
88
100
|
)
|
89
101
|
|
@@ -92,7 +104,7 @@ def run(args):
|
|
92
104
|
mlp_predictor,
|
93
105
|
graph_torch_data,
|
94
106
|
cls_label=args.node_cls_label,
|
95
|
-
exclude_labels=
|
107
|
+
exclude_labels=[-1],
|
96
108
|
lr=args.lr,
|
97
109
|
num_epochs=args.num_epochs,
|
98
110
|
use_edge_attrs=args.use_edge_attrs,
|
@@ -100,4 +112,4 @@ def run(args):
|
|
100
112
|
)
|
101
113
|
|
102
114
|
print("Training GNN Node Classification model")
|
103
|
-
trainer.run()
|
115
|
+
trainer.run()
|
@@ -1,6 +1,7 @@
|
|
1
1
|
from glam4cm.data_loading.models_dataset import (
|
2
2
|
ArchiMateDataset,
|
3
|
-
EcoreDataset
|
3
|
+
EcoreDataset,
|
4
|
+
OntoUMLDataset
|
4
5
|
)
|
5
6
|
|
6
7
|
|
@@ -8,7 +9,8 @@ dataset_to_metamodel = {
|
|
8
9
|
'modelset': 'ecore',
|
9
10
|
'ecore_555': 'ecore',
|
10
11
|
'mar-ecore-github': 'ecore',
|
11
|
-
'eamodelset': 'ea'
|
12
|
+
'eamodelset': 'ea',
|
13
|
+
'ontouml': 'ontouml',
|
12
14
|
}
|
13
15
|
|
14
16
|
|
@@ -22,6 +24,8 @@ def get_model_dataset_class(dataset_name):
|
|
22
24
|
dataset_class = ArchiMateDataset
|
23
25
|
elif dataset_type == 'ecore':
|
24
26
|
dataset_class = EcoreDataset
|
27
|
+
elif dataset_type == 'ontouml':
|
28
|
+
dataset_class = OntoUMLDataset
|
25
29
|
else:
|
26
30
|
raise ValueError(f"Unknown dataset type: {dataset_type}")
|
27
31
|
return dataset_class
|
@@ -33,3 +37,13 @@ def get_models_dataset(dataset_name, **config_params):
|
|
33
37
|
del config_params['language']
|
34
38
|
dataset_class = get_model_dataset_class(dataset_name)
|
35
39
|
return dataset_class(dataset_name, **config_params)
|
40
|
+
|
41
|
+
|
42
|
+
def get_logging_steps(dataset_size, num_epochs, batch_size):
|
43
|
+
"""
|
44
|
+
Calculate the logging steps based on the dataset size, number of epochs, and batch size.
|
45
|
+
"""
|
46
|
+
num_steps = dataset_size // batch_size
|
47
|
+
logging_steps = num_steps * num_epochs // 20
|
48
|
+
print(f"Logging steps: {logging_steps}")
|
49
|
+
return logging_steps
|
glam4cm/embeddings/bert.py
CHANGED
@@ -36,7 +36,7 @@ class BertEmbedder(Embedder):
|
|
36
36
|
print("Number of Texts: ", len(text))
|
37
37
|
|
38
38
|
dataset = EncodingDataset(self.tokenizer, texts=text, remove_duplicates=False)
|
39
|
-
loader = DataLoader(dataset, batch_size=
|
39
|
+
loader = DataLoader(dataset, batch_size=64)
|
40
40
|
|
41
41
|
embeddings = list()
|
42
42
|
with torch.no_grad():
|
glam4cm/embeddings/common.py
CHANGED
@@ -5,7 +5,9 @@ from typing import List, Union
|
|
5
5
|
import torch
|
6
6
|
from glam4cm.settings import (
|
7
7
|
WORD2VEC_MODEL,
|
8
|
-
TFIDF_MODEL
|
8
|
+
TFIDF_MODEL,
|
9
|
+
MODERN_BERT,
|
10
|
+
BERT_MODEL
|
9
11
|
)
|
10
12
|
|
11
13
|
|
@@ -27,10 +29,11 @@ def get_embedding_model(
|
|
27
29
|
model_name: str,
|
28
30
|
ckpt: str = None
|
29
31
|
) -> Embedder:
|
30
|
-
if ckpt:
|
31
|
-
|
32
|
+
# if ckpt:
|
33
|
+
# model_name = json.load(open(os.path.join(ckpt, 'config.json')))['_name_or_path']
|
34
|
+
# print("Model name:", model_name)
|
32
35
|
|
33
|
-
if
|
36
|
+
if model_name in [MODERN_BERT, BERT_MODEL]:
|
34
37
|
from glam4cm.embeddings.bert import BertEmbedder
|
35
38
|
return BertEmbedder(model_name, ckpt)
|
36
39
|
elif WORD2VEC_MODEL in model_name:
|
glam4cm/encoding/encoders.py
CHANGED
glam4cm/lang2graph/archimate.py
CHANGED
@@ -18,11 +18,6 @@ class ArchiMateNxG(LangGraph):
|
|
18
18
|
self.__create_graph()
|
19
19
|
self.set_numbered_labels()
|
20
20
|
|
21
|
-
# self.text = " ".join([
|
22
|
-
# self.nodes[node]['name'] if 'name' in self.nodes[node] else ''
|
23
|
-
# for node in self.nodes
|
24
|
-
# ])
|
25
|
-
|
26
21
|
|
27
22
|
def __create_graph(self):
|
28
23
|
for node in self.json_obj['elements']:
|
glam4cm/lang2graph/common.py
CHANGED
@@ -8,11 +8,15 @@ from glam4cm.data_loading.metadata import GraphMetadata
|
|
8
8
|
from glam4cm.tokenization.special_tokens import *
|
9
9
|
from glam4cm.tokenization.utils import doc_tokenizer
|
10
10
|
import glam4cm.utils as utils
|
11
|
+
from glam4cm.settings import (
|
12
|
+
SUPERTYPE,
|
13
|
+
REFERENCE,
|
14
|
+
CONTAINMENT,
|
15
|
+
|
16
|
+
EDGE_CLS_TASK,
|
17
|
+
LINK_PRED_TASK,
|
18
|
+
)
|
11
19
|
|
12
|
-
SEP = ' '
|
13
|
-
REFERENCE = 'reference'
|
14
|
-
SUPERTYPE = 'supertype'
|
15
|
-
CONTAINMENT = 'containment'
|
16
20
|
|
17
21
|
|
18
22
|
class LangGraph(nx.DiGraph):
|
@@ -112,13 +116,14 @@ def create_graph_from_edge_index(graph, edge_index: np.ndarray):
|
|
112
116
|
subgraph.id_to_node_label = graph.id_to_node_label
|
113
117
|
subgraph.edge_label_to_id = graph.edge_label_to_id
|
114
118
|
subgraph.id_to_edge_label = graph.id_to_edge_label
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
119
|
+
if len(edge_index) > 0:
|
120
|
+
try:
|
121
|
+
assert subgraph.number_of_edges() == edge_index.shape[1]
|
122
|
+
except AssertionError as e:
|
123
|
+
print(f"Number of edges mismatch {subgraph.number_of_edges()} != {edge_index.size(1)}")
|
124
|
+
import pickle
|
125
|
+
pickle.dump([graph, edge_index], open("subgraph.pkl", "wb"))
|
126
|
+
raise e
|
122
127
|
|
123
128
|
return subgraph
|
124
129
|
|
@@ -140,15 +145,24 @@ def format_path(
|
|
140
145
|
):
|
141
146
|
"""Format a path into a string representation."""
|
142
147
|
def get_node_label(node):
|
148
|
+
|
143
149
|
masked = graph.nodes[node].get('masked')
|
144
|
-
node_type = f"{graph.nodes[node].get(f'{node_cls_label}', '')}"
|
145
|
-
|
150
|
+
node_type = f"{graph.nodes[node].get(f'{node_cls_label}', '')}" \
|
151
|
+
if use_node_types and not masked and node_cls_label else ''
|
152
|
+
|
153
|
+
if node_type != '':
|
154
|
+
if isinstance(graph.nodes[node].get(f'{node_cls_label}'), bool):
|
155
|
+
node_type = node_cls_label.title() if graph.nodes[node].get(f'{node_cls_label}') else ''
|
156
|
+
|
157
|
+
|
146
158
|
node_label = get_node_name(
|
147
159
|
graph.nodes[node],
|
148
160
|
metadata.node_label,
|
149
161
|
use_node_attributes,
|
150
162
|
metadata.node_attributes
|
151
163
|
) if not no_labels else ''
|
164
|
+
|
165
|
+
|
152
166
|
if preprocessor:
|
153
167
|
node_label = preprocessor(node_label)
|
154
168
|
|
@@ -174,8 +188,9 @@ def format_path(
|
|
174
188
|
|
175
189
|
return edge_label.strip()
|
176
190
|
|
191
|
+
# import code; code.interact(local=locals())
|
177
192
|
assert len(path) > 0, "Path must contain at least one node."
|
178
|
-
formatted = [
|
193
|
+
formatted = []
|
179
194
|
for i in range(1, len(path)):
|
180
195
|
n1 = path[i - 1]
|
181
196
|
n2 = path[i]
|
@@ -184,12 +199,18 @@ def format_path(
|
|
184
199
|
formatted.append(get_edge_label(n1, n2))
|
185
200
|
formatted.append(get_node_label(n2))
|
186
201
|
|
187
|
-
|
202
|
+
node_str = get_node_label(path[0])
|
203
|
+
if len(formatted) > 0:
|
204
|
+
node_str += " | " + " ".join(formatted).strip()
|
205
|
+
|
206
|
+
return node_str
|
207
|
+
|
188
208
|
|
189
209
|
def get_edge_texts(
|
190
210
|
graph: LangGraph,
|
191
211
|
edge: tuple,
|
192
212
|
d: int,
|
213
|
+
task_type: str,
|
193
214
|
metadata: GraphMetadata,
|
194
215
|
use_node_attributes=False,
|
195
216
|
use_node_types=False,
|
@@ -206,7 +227,8 @@ def get_edge_texts(
|
|
206
227
|
if not neg_samples:
|
207
228
|
masked = graph.edges[n1, n2].get('masked')
|
208
229
|
graph.edges[n1, n2]['masked'] = True
|
209
|
-
|
230
|
+
|
231
|
+
|
210
232
|
n1_text = get_node_text(
|
211
233
|
graph=graph,
|
212
234
|
node=n1,
|
@@ -239,10 +261,26 @@ def get_edge_texts(
|
|
239
261
|
preprocessor=preprocessor,
|
240
262
|
exclude_edges=[edge]
|
241
263
|
)
|
264
|
+
|
265
|
+
|
266
|
+
edge_text = ""
|
267
|
+
|
242
268
|
if not neg_samples:
|
243
269
|
graph.edges[n1, n2]['masked'] = masked or False
|
270
|
+
|
271
|
+
edge_data = graph.get_edge_data(n1, n2)
|
272
|
+
edge_type = get_edge_data(edge_data, edge_cls_label, metadata.type)
|
273
|
+
edge_label = edge_data.get(metadata.edge_label, '') if use_edge_label and not no_labels else ''
|
274
|
+
|
275
|
+
if task_type not in [EDGE_CLS_TASK, LINK_PRED_TASK]:
|
276
|
+
if use_edge_types :
|
277
|
+
edge_text += f" {edge_cls_label}: {edge_type} " if not no_labels else ''
|
278
|
+
|
279
|
+
if use_edge_label:
|
280
|
+
edge_text += f" {edge_label} " if not no_labels else ''
|
244
281
|
|
245
|
-
|
282
|
+
|
283
|
+
return n1_text + EDGE_START + f"{edge_text}" + EDGE_END + n2_text
|
246
284
|
|
247
285
|
|
248
286
|
def get_node_text(
|
@@ -263,28 +301,39 @@ def get_node_text(
|
|
263
301
|
):
|
264
302
|
masked = graph.nodes[node].get('masked')
|
265
303
|
graph.nodes[node]['masked'] = True
|
266
|
-
raw_paths = utils.bfs(graph=graph, start_node=node, d=d, exclude_edges=exclude_edges)
|
267
|
-
unique_paths = utils.remove_subsets(list_of_lists=raw_paths)
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
304
|
+
# raw_paths = utils.bfs(graph=graph, start_node=node, d=d, exclude_edges=exclude_edges)
|
305
|
+
# unique_paths = utils.remove_subsets(list_of_lists=raw_paths)
|
306
|
+
node_neighbour_texts = list()
|
307
|
+
node_neighbours = utils.get_node_neighbours(graph, node, d, exclude_edges=exclude_edges)
|
308
|
+
for neighbour in node_neighbours:
|
309
|
+
unique_paths = [p for p in nx.all_simple_paths(graph, node, neighbour, cutoff=d)]
|
310
|
+
|
311
|
+
node_neighbour_texts.extend([
|
312
|
+
format_path(
|
313
|
+
graph=graph,
|
314
|
+
path=path,
|
315
|
+
metadata=metadata,
|
316
|
+
use_node_attributes=use_node_attributes,
|
317
|
+
use_node_types=use_node_types,
|
318
|
+
use_edge_types=use_edge_types,
|
319
|
+
use_edge_label=use_edge_label,
|
320
|
+
node_cls_label=node_cls_label,
|
321
|
+
edge_cls_label=edge_cls_label,
|
322
|
+
use_special_tokens=use_special_tokens,
|
323
|
+
no_labels=no_labels,
|
324
|
+
preprocessor=preprocessor,
|
325
|
+
neg_sample=False
|
326
|
+
)
|
327
|
+
for path in unique_paths
|
328
|
+
])
|
329
|
+
|
286
330
|
graph.nodes[node]['masked'] = masked or False
|
287
|
-
|
331
|
+
node_str = "\n".join(node_neighbour_texts).strip() if node_neighbour_texts else ''
|
332
|
+
|
333
|
+
if node_cls_label == 'stereotype':
|
334
|
+
node_str = graph.nodes[node]['type'].title() + " " + node_str
|
335
|
+
|
336
|
+
return node_str.strip()
|
288
337
|
|
289
338
|
|
290
339
|
def get_node_texts(
|
@@ -326,6 +375,8 @@ def get_attribute_labels(node_data, attribute_labels):
|
|
326
375
|
if isinstance(node_data[attribute_labels], list):
|
327
376
|
if not node_data[attribute_labels]:
|
328
377
|
return ''
|
378
|
+
if isinstance(node_data[attribute_labels][0], str):
|
379
|
+
return ", ".join(node_data[attribute_labels])
|
329
380
|
if isinstance(node_data[attribute_labels][0], tuple):
|
330
381
|
return ", ".join([f"{k}: {v}" for k, v in node_data[attribute_labels]])
|
331
382
|
elif isinstance(node_data[attribute_labels][0], dict):
|
@@ -346,8 +397,12 @@ def get_node_name(
|
|
346
397
|
attributes_str = "(" + get_attribute_labels(node_data, attribute_labels) + ")"
|
347
398
|
else:
|
348
399
|
attributes_str = ''
|
349
|
-
|
350
|
-
node_label = '' if
|
400
|
+
|
401
|
+
node_label = node_data.get(label, '') if node_data.get(label, '') else ''
|
402
|
+
node_label = '' if node_label and node_label.lower() in ['null', 'none'] else node_label
|
403
|
+
# if attributes_str:
|
404
|
+
# print(f"Node label: {node_label} | Attributes: {attributes_str}")
|
405
|
+
|
351
406
|
return f"{node_label}{attributes_str}".strip()
|
352
407
|
|
353
408
|
|
@@ -405,7 +460,10 @@ def get_uml_edge_data(edge_data: dict, edge_label: str):
|
|
405
460
|
raise ValueError(f"Unknown edge label: {edge_label}")
|
406
461
|
|
407
462
|
def get_ontouml_edge_data(edge_data: dict, edge_label: str):
|
408
|
-
|
463
|
+
try:
|
464
|
+
return {'rel': "relates", "gen": "generalizes"}[edge_data.get(edge_label)]
|
465
|
+
except KeyError:
|
466
|
+
raise ValueError(f"Unknown edge label: {edge_label}")
|
409
467
|
|
410
468
|
def get_uml_edge_type(edge_data):
|
411
469
|
edge_type = edge_data.get('type')
|
glam4cm/lang2graph/ecore.py
CHANGED
@@ -58,8 +58,7 @@ class EcoreNxG(LangGraph):
|
|
58
58
|
for f in structural_features:
|
59
59
|
if f['type'] == 'ecore:EAttribute':
|
60
60
|
name = f['name']
|
61
|
-
|
62
|
-
attributes.append((name, attr_type))
|
61
|
+
attributes.append(name)
|
63
62
|
|
64
63
|
self.add_node(
|
65
64
|
classifier_name,
|
glam4cm/lang2graph/ontouml.py
CHANGED
@@ -48,13 +48,14 @@ extra_properties = [
|
|
48
48
|
class OntoUMLNxG(LangGraph):
|
49
49
|
def __init__(self, json_obj: dict, rel_as_node=True):
|
50
50
|
super().__init__()
|
51
|
+
self.graph_id = json_obj['id']
|
51
52
|
self.json_obj = json_obj
|
52
53
|
self.rel_as_node = rel_as_node
|
53
54
|
self.__create_graph()
|
54
55
|
self.set_numbered_labels()
|
55
56
|
|
56
57
|
self.text = " ".join([
|
57
|
-
self.nodes[node]['name'] if 'name' in self.nodes[node] else ''
|
58
|
+
self.nodes[node]['name'] if 'name' in self.nodes[node] and self.nodes[node]['name'] else ''
|
58
59
|
for node in self.nodes
|
59
60
|
])
|
60
61
|
|
@@ -76,6 +77,7 @@ class OntoUMLNxG(LangGraph):
|
|
76
77
|
ontouml_id2obj(item)
|
77
78
|
|
78
79
|
def create_nxg():
|
80
|
+
|
79
81
|
for k, v in id2obj_map.items():
|
80
82
|
node_name = v.get('name', '')
|
81
83
|
|
@@ -85,7 +87,8 @@ class OntoUMLNxG(LangGraph):
|
|
85
87
|
self.nodes[k][prop] = v[prop] if prop in v else False
|
86
88
|
|
87
89
|
logger.info(f"Node: {node_name} type: {v[ONTOUML_ELEMENT_TYPE]}")
|
88
|
-
|
90
|
+
# else:
|
91
|
+
# continue
|
89
92
|
|
90
93
|
logger.info(f"Node: {node_name} type: {v[ONTOUML_ELEMENT_TYPE]}")
|
91
94
|
if ONTOUML_STEREOTYPE in v and v[ONTOUML_STEREOTYPE] is not None:
|
@@ -108,10 +111,8 @@ class OntoUMLNxG(LangGraph):
|
|
108
111
|
|
109
112
|
elif ONTOUML_PROPERTIES in v and v[ONTOUML_PROPERTIES] is not None:
|
110
113
|
properties = v[ONTOUML_PROPERTIES] if isinstance(v[ONTOUML_PROPERTIES], list) else [v[ONTOUML_PROPERTIES]]
|
111
|
-
|
112
|
-
|
113
|
-
logger.info(f"Properties: {properties_str}")
|
114
|
-
|
114
|
+
self.nodes[k][ONTOUML_PROPERTIES] = [property[ONTOUML_ELEMENT_NAME] for property in properties]
|
115
|
+
|
115
116
|
|
116
117
|
elif v[ONTOUML_ELEMENT_TYPE] == ONTOUML_RELATION:
|
117
118
|
properties = v[ONTOUML_PROPERTIES] if isinstance(v[ONTOUML_PROPERTIES], list) else [v[ONTOUML_PROPERTIES]]
|
@@ -144,7 +145,7 @@ class OntoUMLNxG(LangGraph):
|
|
144
145
|
|
145
146
|
def create_nxg_rel_as_edge():
|
146
147
|
# TODO: To be implemented
|
147
|
-
|
148
|
+
raise NotImplementedError
|
148
149
|
|
149
150
|
|
150
151
|
id2obj_map = dict()
|
glam4cm/models/gnn_layers.py
CHANGED
@@ -123,19 +123,33 @@ class GNNConv(torch.nn.Module):
|
|
123
123
|
h = self.dropout(h)
|
124
124
|
return h
|
125
125
|
|
126
|
+
edge_attr_val = isinstance(edge_attr, torch.Tensor) and self.is_headed_model()
|
126
127
|
h = in_feat
|
127
|
-
h = self.conv_layers[0](h, edge_index, edge_attr)
|
128
|
-
|
128
|
+
h = self.conv_layers[0](h, edge_index, edge_attr) \
|
129
|
+
if edge_attr_val else self.conv_layers[0](h, edge_index)
|
130
|
+
h = activate(h)
|
129
131
|
|
130
132
|
for conv in self.conv_layers[1:-1]:
|
131
|
-
nh = conv(h, edge_index, edge_attr) if
|
133
|
+
nh = conv(h, edge_index, edge_attr) if edge_attr_val else conv(h, edge_index)
|
132
134
|
h = nh if not self.residual else nh + h
|
133
|
-
activate(h)
|
135
|
+
h = activate(h)
|
134
136
|
|
135
137
|
h = self.conv_layers[-1](h, edge_index)
|
136
|
-
activate(h)
|
138
|
+
h = activate(h)
|
137
139
|
return h
|
138
|
-
|
140
|
+
|
141
|
+
def is_headed_model(self):
|
142
|
+
""""
|
143
|
+
Returns True if the model is a headed model
|
144
|
+
Checks if the model name is in the supported_conv_models dictionary
|
145
|
+
and if the model requires num_heads
|
146
|
+
"""
|
147
|
+
headed = self.num_heads is not None
|
148
|
+
model_name = self.conv_layers[0].__class__.__name__
|
149
|
+
if model_name in supported_conv_models:
|
150
|
+
return supported_conv_models[model_name] and headed
|
151
|
+
return False
|
152
|
+
|
139
153
|
|
140
154
|
class EdgeClassifer(nn.Module):
|
141
155
|
|
glam4cm/models/hf.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from transformers import AutoModelForSequenceClassification
|
2
2
|
|
3
|
-
def get_model(model_name, num_labels, len_tokenizer=None) -> AutoModelForSequenceClassification:
|
4
|
-
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
|
3
|
+
def get_model(model_name, num_labels, len_tokenizer=None, trust_remote_code=False) -> AutoModelForSequenceClassification:
|
4
|
+
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels, trust_remote_code=trust_remote_code)
|
5
5
|
if len_tokenizer:
|
6
6
|
model.resize_token_embeddings(len_tokenizer)
|
7
7
|
assert model.config.vocab_size == len_tokenizer,\
|
glam4cm/run.py
CHANGED
@@ -43,8 +43,8 @@ tasks = {
|
|
43
43
|
|
44
44
|
6: 'GNN Graph Classification',
|
45
45
|
7: 'GNN Node Classification',
|
46
|
-
8: 'GNN
|
47
|
-
9: 'GNN
|
46
|
+
8: 'GNN Link Prediction',
|
47
|
+
9: 'GNN Edge Classification',
|
48
48
|
10: 'CM-GPT Causal Modeling',
|
49
49
|
11: 'CM-GPT Node Classification',
|
50
50
|
12: 'CM-GPT Edge Classification'
|
@@ -60,16 +60,15 @@ tasks_handler_map = {
|
|
60
60
|
5: (bert_edge_classification.run, bert_ec_parse_args),
|
61
61
|
6: (gnn_graph_cls.run, gnn_parse_args),
|
62
62
|
7: (gnn_node_classification.run, gnn_nc_parse_args),
|
63
|
-
8: (
|
64
|
-
9: (
|
63
|
+
8: (gnn_link_prediction.run, gnn_lp_parse_args),
|
64
|
+
9: (gnn_edge_classification.run, gnn_ec_parse_args),
|
65
65
|
10: (cm_gpt_pretraining.run, cm_gpt_parse_args),
|
66
66
|
11: (cm_gpt_node_classification.run, cm_gpt_nc_parse_args),
|
67
67
|
12: (cm_gpt_edge_classification.run, cm_gpt_ec_parse_args)
|
68
68
|
}
|
69
69
|
|
70
70
|
|
71
|
-
|
72
|
-
|
71
|
+
def main():
|
73
72
|
main_parser = argparse.ArgumentParser(description="Train ML models on conceptual models")
|
74
73
|
main_parser.add_argument('--task_id', type=int, required=True, help=f'ID of the task to run. Options are: {"\n".join(f"{k}: {v}" for k, v in tasks.items())}', choices=list(tasks.keys()), default=0)
|
75
74
|
main_parser.add_argument('--th', '--task_help', action="store_true", help="Help for the task specified by --task_id")
|
@@ -85,7 +84,7 @@ if __name__ == '__main__':
|
|
85
84
|
### If args has -h or --help, print help
|
86
85
|
if any(x in remaining_args for x in ['-th', '--task_help']):
|
87
86
|
task_id = args.task_id
|
88
|
-
|
87
|
+
task_handler, task_parser = tasks_handler_map[task_id]
|
89
88
|
print("Help for task:", tasks[task_id])
|
90
89
|
task_parser().print_help()
|
91
90
|
exit(0)
|
@@ -94,6 +93,11 @@ if __name__ == '__main__':
|
|
94
93
|
|
95
94
|
|
96
95
|
task_id = args.task_id
|
97
|
-
|
96
|
+
task_handler, task_parser = tasks_handler_map[task_id]
|
98
97
|
task_args = task_parser().parse_args(remaining_args)
|
99
|
-
|
98
|
+
task_handler(task_args)
|
99
|
+
|
100
|
+
|
101
|
+
if __name__ == '__main__':
|
102
|
+
main()
|
103
|
+
|