glam4cm 0.1.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (49)
  1. glam4cm/__init__.py +2 -1
  2. glam4cm/data_loading/data.py +90 -146
  3. glam4cm/data_loading/encoding.py +17 -6
  4. glam4cm/data_loading/graph_dataset.py +192 -57
  5. glam4cm/data_loading/metadata.py +1 -1
  6. glam4cm/data_loading/models_dataset.py +42 -18
  7. glam4cm/downstream_tasks/bert_edge_classification.py +49 -22
  8. glam4cm/downstream_tasks/bert_graph_classification.py +44 -14
  9. glam4cm/downstream_tasks/bert_graph_classification_comp.py +47 -24
  10. glam4cm/downstream_tasks/bert_link_prediction.py +46 -26
  11. glam4cm/downstream_tasks/bert_node_classification.py +127 -89
  12. glam4cm/downstream_tasks/cm_gpt_node_classification.py +61 -15
  13. glam4cm/downstream_tasks/common_args.py +32 -4
  14. glam4cm/downstream_tasks/gnn_edge_classification.py +24 -7
  15. glam4cm/downstream_tasks/gnn_graph_cls.py +19 -6
  16. glam4cm/downstream_tasks/gnn_link_prediction.py +25 -13
  17. glam4cm/downstream_tasks/gnn_node_classification.py +19 -7
  18. glam4cm/downstream_tasks/utils.py +16 -2
  19. glam4cm/embeddings/bert.py +1 -1
  20. glam4cm/embeddings/common.py +7 -4
  21. glam4cm/encoding/encoders.py +1 -1
  22. glam4cm/lang2graph/archimate.py +0 -5
  23. glam4cm/lang2graph/common.py +99 -41
  24. glam4cm/lang2graph/ecore.py +1 -2
  25. glam4cm/lang2graph/ontouml.py +8 -7
  26. glam4cm/models/gnn_layers.py +20 -6
  27. glam4cm/models/hf.py +2 -2
  28. glam4cm/run.py +12 -7
  29. glam4cm/run_conf_v2.py +405 -0
  30. glam4cm/run_configs.py +70 -106
  31. glam4cm/run_confs.py +41 -0
  32. glam4cm/settings.py +15 -2
  33. glam4cm/tokenization/special_tokens.py +23 -1
  34. glam4cm/tokenization/utils.py +23 -4
  35. glam4cm/trainers/cm_gpt_trainer.py +1 -1
  36. glam4cm/trainers/gnn_edge_classifier.py +12 -1
  37. glam4cm/trainers/gnn_graph_classifier.py +12 -5
  38. glam4cm/trainers/gnn_link_predictor.py +18 -3
  39. glam4cm/trainers/gnn_link_predictor_v2.py +146 -0
  40. glam4cm/trainers/gnn_trainer.py +8 -0
  41. glam4cm/trainers/metrics.py +1 -1
  42. glam4cm/utils.py +265 -2
  43. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/METADATA +3 -2
  44. glam4cm-1.0.0.dist-info/RECORD +75 -0
  45. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/WHEEL +1 -1
  46. glam4cm-0.1.1.dist-info/RECORD +0 -72
  47. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/entry_points.txt +0 -0
  48. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info/licenses}/LICENSE +0 -0
  49. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/top_level.txt +0 -0
glam4cm/downstream_tasks/gnn_link_prediction.py CHANGED
@@ -1,14 +1,18 @@
  import os
  from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
  from glam4cm.models.gnn_layers import GNNConv, EdgeClassifer
- from glam4cm.settings import LP_TASK_LINK_PRED
+ from glam4cm.settings import LINK_PRED_TASK, results_dir
  from glam4cm.downstream_tasks.utils import get_models_dataset
  from glam4cm.tokenization.special_tokens import *
  from glam4cm.trainers.gnn_link_predictor import GNNLinkPredictionTrainer as Trainer
- from glam4cm.utils import merge_argument_parsers, set_seed
- from glam4cm.downstream_tasks.common_args import get_common_args_parser, get_config_params, get_gnn_args_parser
-
+ from glam4cm.utils import merge_argument_parsers, set_seed, set_torch_encoding_labels
+ from glam4cm.downstream_tasks.common_args import (
+     get_common_args_parser,
+     get_config_params,
+     get_gnn_args_parser
+ )
 
+
  def get_parser():
      common_parser = get_common_args_parser()
      gnn_parser = get_gnn_args_parser()
@@ -21,6 +25,7 @@ def run(args):
      set_seed(args.seed)
 
      config_params = dict(
+         include_dummies = args.include_dummies,
          min_enr = args.min_enr,
          min_edges = args.min_edges,
          remove_duplicates = args.remove_duplicates,
@@ -42,14 +47,18 @@ def run(args):
      aggregation = args.aggregation
 
      graph_data_params = get_config_params(args)
+
+     if args.use_embeddings:
+         graph_data_params['embed_model_name'] = os.path.join(results_dir, dataset_name, f"LM_{LINK_PRED_TASK}")
+
      print("Loading graph dataset")
      graph_dataset = GraphEdgeDataset(
-         dataset,
-         dict(
+         dataset,
+         task_type=LINK_PRED_TASK,
+         **dict(
              **graph_data_params,
-             add_negative_train_samples=args.add_negative_train_samples,
+             add_negative_train_samples=True,
              neg_sampling_ratio=args.neg_sampling_ratio,
-             task=LP_TASK_LINK_PRED
      ))
 
      input_dim = graph_dataset[0].data.x.shape[1]
@@ -78,7 +87,7 @@ def run(args):
      logs_dir = os.path.join(
          "logs",
          dataset_name,
-         "gnn_lp",
+         f"GNN_{LINK_PRED_TASK}",
          f'{graph_dataset.config_hash}',
      )
@@ -92,11 +101,14 @@ def run(args):
          bias=False,
      )
 
+     graph_torch_data = graph_dataset.get_torch_dataset()
+     # exclude_labels = getattr(graph_dataset, f"node_exclude_{args.node_cls_label}")
+     # set_torch_encoding_labels(graph_torch_data, f"node_{args.node_cls_label}", exclude_labels)
 
      trainer = Trainer(
-         gnn_conv_model,
-         mlp_predictor,
-         graph_dataset.get_torch_dataset(),
+         model=gnn_conv_model,
+         predictor=mlp_predictor,
+         dataset=graph_torch_data,
          lr=args.lr,
          num_epochs=args.num_epochs,
          batch_size=args.batch_size,
@@ -106,4 +118,4 @@
 
 
      print("Training GNN Link Prediction model")
-     trainer.run()
+     trainer.run()
glam4cm/downstream_tasks/gnn_node_classification.py CHANGED
@@ -2,10 +2,15 @@ import os
  from glam4cm.data_loading.graph_dataset import GraphNodeDataset
  from glam4cm.models.gnn_layers import GNNConv, NodeClassifier
  from glam4cm.downstream_tasks.utils import get_models_dataset
+ from glam4cm.settings import NODE_CLS_TASK, results_dir
  from glam4cm.tokenization.special_tokens import *
  from glam4cm.trainers.gnn_node_classifier import GNNNodeClassificationTrainer as Trainer
- from glam4cm.utils import merge_argument_parsers, set_seed
- from glam4cm.downstream_tasks.common_args import get_common_args_parser, get_config_params, get_gnn_args_parser
+ from glam4cm.utils import merge_argument_parsers, set_seed, set_torch_encoding_labels
+ from glam4cm.downstream_tasks.common_args import (
+     get_common_args_parser,
+     get_config_params,
+     get_gnn_args_parser
+ )
 
 
  def get_parser():
@@ -20,6 +25,7 @@ def run(args):
      set_seed(args.seed)
 
      config_params = dict(
+         include_dummies = args.include_dummies,
          min_enr = args.min_enr,
          min_edges = args.min_edges,
          remove_duplicates = args.remove_duplicates,
@@ -29,13 +35,19 @@ def run(args):
      dataset_name = args.dataset
 
      dataset = get_models_dataset(dataset_name, **config_params)
-     graph_data_params = get_config_params(args)
+     graph_data_params = {**get_config_params(args), 'task_type': NODE_CLS_TASK}
+
+     if args.use_embeddings:
+         graph_data_params['embed_model_name'] = os.path.join(results_dir, dataset_name, f'{args.node_cls_label}')
 
      print("Loading graph dataset")
      graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
      print("Loaded graph dataset")
-
+
+
      graph_torch_data = graph_dataset.get_torch_dataset()
+     exclude_labels = getattr(graph_dataset, f"node_exclude_{args.node_cls_label}")
+     set_torch_encoding_labels(graph_torch_data, f"node_{args.node_cls_label}", exclude_labels)
 
      num_nodes_label = f"num_nodes_{args.node_cls_label}"
      assert hasattr(graph_dataset, num_nodes_label), f"Graph dataset does not have attribute {num_nodes_label}"
@@ -83,7 +95,7 @@ def run(args):
      logs_dir = os.path.join(
          "logs",
          dataset_name,
-         "gnn_node_cls",
+         f"GNN_{NODE_CLS_TASK}",
          f"{graph_dataset.config_hash}",
      )
@@ -92,7 +104,7 @@ def run(args):
          mlp_predictor,
          graph_torch_data,
          cls_label=args.node_cls_label,
-         exclude_labels=getattr(graph_dataset, f"node_exclude_{args.node_cls_label}"),
+         exclude_labels=[-1],
          lr=args.lr,
          num_epochs=args.num_epochs,
          use_edge_attrs=args.use_edge_attrs,
@@ -100,4 +112,4 @@ def run(args):
      )
 
      print("Training GNN Node Classification model")
-     trainer.run()
+     trainer.run()
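
Downstream of this change, the trainer now receives exclude_labels=[-1], which together with the new set_torch_encoding_labels call suggests that excluded label values are re-encoded to -1 before training. A hypothetical sketch of that relabeling step (the real helper lives in glam4cm.utils and its signature may differ):

    import torch

    def relabel_excluded(labels: torch.Tensor, exclude_ids) -> torch.Tensor:
        # Map excluded class ids to -1 so the trainer can filter them out.
        out = labels.clone()
        for ex in exclude_ids:
            out[labels == ex] = -1
        return out

    y = torch.tensor([0, 1, 2, 1])
    print(relabel_excluded(y, exclude_ids=[2]))  # tensor([ 0,  1, -1,  1])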
glam4cm/downstream_tasks/utils.py CHANGED
@@ -1,6 +1,7 @@
  from glam4cm.data_loading.models_dataset import (
      ArchiMateDataset,
-     EcoreDataset
+     EcoreDataset,
+     OntoUMLDataset
  )
 
 
@@ -8,7 +9,8 @@ dataset_to_metamodel = {
      'modelset': 'ecore',
      'ecore_555': 'ecore',
      'mar-ecore-github': 'ecore',
-     'eamodelset': 'ea'
+     'eamodelset': 'ea',
+     'ontouml': 'ontouml',
  }
 
 
@@ -22,6 +24,8 @@ def get_model_dataset_class(dataset_name):
          dataset_class = ArchiMateDataset
      elif dataset_type == 'ecore':
          dataset_class = EcoreDataset
+     elif dataset_type == 'ontouml':
+         dataset_class = OntoUMLDataset
      else:
          raise ValueError(f"Unknown dataset type: {dataset_type}")
      return dataset_class
@@ -33,3 +37,13 @@ def get_models_dataset(dataset_name, **config_params):
      del config_params['language']
      dataset_class = get_model_dataset_class(dataset_name)
      return dataset_class(dataset_name, **config_params)
+
+
+ def get_logging_steps(dataset_size, num_epochs, batch_size):
+     """
+     Calculate the logging steps based on the dataset size, number of epochs, and batch size.
+     """
+     num_steps = dataset_size // batch_size
+     logging_steps = num_steps * num_epochs // 20
+     print(f"Logging steps: {logging_steps}")
+     return logging_steps
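
The new get_logging_steps helper sizes the logging interval so that a training run emits roughly 20 log events in total. A quick worked example (the numbers are illustrative, not from the package):

    from glam4cm.downstream_tasks.utils import get_logging_steps

    # 1000 samples at batch size 32 -> 31 steps per epoch;
    # 31 steps * 10 epochs = 310 total steps; 310 // 20 -> log every 15 steps.
    steps = get_logging_steps(dataset_size=1000, num_epochs=10, batch_size=32)
    assert steps == 15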
glam4cm/embeddings/bert.py CHANGED
@@ -36,7 +36,7 @@ class BertEmbedder(Embedder):
          print("Number of Texts: ", len(text))
 
          dataset = EncodingDataset(self.tokenizer, texts=text, remove_duplicates=False)
-         loader = DataLoader(dataset, batch_size=256)
+         loader = DataLoader(dataset, batch_size=64)
 
          embeddings = list()
          with torch.no_grad():
glam4cm/embeddings/common.py CHANGED
@@ -5,7 +5,9 @@ from typing import List, Union
  import torch
  from glam4cm.settings import (
      WORD2VEC_MODEL,
-     TFIDF_MODEL
+     TFIDF_MODEL,
+     MODERN_BERT,
+     BERT_MODEL
  )
 
 
@@ -27,10 +29,11 @@ def get_embedding_model(
      model_name: str,
      ckpt: str = None
  ) -> Embedder:
-     if ckpt:
-         model_name = json.load(open(os.path.join(ckpt, 'config.json')))['_name_or_path']
+     # if ckpt:
+     #     model_name = json.load(open(os.path.join(ckpt, 'config.json')))['_name_or_path']
+     #     print("Model name:", model_name)
 
-     if 'bert' in model_name:
+     if model_name in [MODERN_BERT, BERT_MODEL]:
          from glam4cm.embeddings.bert import BertEmbedder
          return BertEmbedder(model_name, ckpt)
      elif WORD2VEC_MODEL in model_name:
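
The dispatch change above replaces a substring test ('bert' in model_name, which would also match names like 'roberta-base') with an exact match against the configured constants. A minimal standalone sketch of the same pattern, with placeholder values standing in for the constants in glam4cm.settings:

    # Placeholder values; the real constants live in glam4cm.settings.
    BERT_MODEL = 'bert-base-uncased'
    MODERN_BERT = 'answerdotai/ModernBERT-base'

    def pick_embedder(model_name: str) -> str:
        # Exact membership avoids false positives such as 'roberta-base'.
        if model_name in (BERT_MODEL, MODERN_BERT):
            return 'bert'
        raise ValueError(f"Unsupported embedding model: {model_name}")

    assert pick_embedder('bert-base-uncased') == 'bert'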
glam4cm/encoding/encoders.py CHANGED
@@ -4,7 +4,7 @@ from sklearn.preprocessing import LabelEncoder
  import fasttext
  from scipy.sparse import csr_matrix
  import numpy as np
- from encoding.common import (
+ from glam4cm.encoding.common import (
      doc_tokenizer,
      SEP
  )
glam4cm/lang2graph/archimate.py CHANGED
@@ -18,11 +18,6 @@ class ArchiMateNxG(LangGraph):
          self.__create_graph()
          self.set_numbered_labels()
 
-         # self.text = " ".join([
-         #     self.nodes[node]['name'] if 'name' in self.nodes[node] else ''
-         #     for node in self.nodes
-         # ])
-
 
      def __create_graph(self):
          for node in self.json_obj['elements']:
glam4cm/lang2graph/common.py CHANGED
@@ -8,11 +8,15 @@ from glam4cm.data_loading.metadata import GraphMetadata
  from glam4cm.tokenization.special_tokens import *
  from glam4cm.tokenization.utils import doc_tokenizer
  import glam4cm.utils as utils
+ from glam4cm.settings import (
+     SUPERTYPE,
+     REFERENCE,
+     CONTAINMENT,
+
+     EDGE_CLS_TASK,
+     LINK_PRED_TASK,
+ )
 
- SEP = ' '
- REFERENCE = 'reference'
- SUPERTYPE = 'supertype'
- CONTAINMENT = 'containment'
 
 
  class LangGraph(nx.DiGraph):
@@ -112,13 +116,14 @@ def create_graph_from_edge_index(graph, edge_index: np.ndarray):
      subgraph.id_to_node_label = graph.id_to_node_label
      subgraph.edge_label_to_id = graph.edge_label_to_id
      subgraph.id_to_edge_label = graph.id_to_edge_label
-     try:
-         assert subgraph.number_of_edges() == edge_index.shape[1]
-     except AssertionError as e:
-         print(f"Number of edges mismatch {subgraph.number_of_edges()} != {edge_index.size(1)}")
-         import pickle
-         pickle.dump([graph, edge_index], open("subgraph.pkl", "wb"))
-         raise e
+     if len(edge_index) > 0:
+         try:
+             assert subgraph.number_of_edges() == edge_index.shape[1]
+         except AssertionError as e:
+             print(f"Number of edges mismatch {subgraph.number_of_edges()} != {edge_index.size(1)}")
+             import pickle
+             pickle.dump([graph, edge_index], open("subgraph.pkl", "wb"))
+             raise e
 
      return subgraph
 
@@ -140,15 +145,24 @@ def format_path(
  ):
      """Format a path into a string representation."""
      def get_node_label(node):
+
          masked = graph.nodes[node].get('masked')
-         node_type = f"{graph.nodes[node].get(f'{node_cls_label}', '')}" if use_node_types and not masked and node_cls_label else ''
-         node_type = f"{node_cls_label}: {node_type}" if node_type else ''
+         node_type = f"{graph.nodes[node].get(f'{node_cls_label}', '')}" \
+             if use_node_types and not masked and node_cls_label else ''
+
+         if node_type != '':
+             if isinstance(graph.nodes[node].get(f'{node_cls_label}'), bool):
+                 node_type = node_cls_label.title() if graph.nodes[node].get(f'{node_cls_label}') else ''
+
+
          node_label = get_node_name(
              graph.nodes[node],
              metadata.node_label,
              use_node_attributes,
              metadata.node_attributes
          ) if not no_labels else ''
+
+
          if preprocessor:
              node_label = preprocessor(node_label)
 
@@ -174,8 +188,9 @@ def format_path(
 
          return edge_label.strip()
 
+     # import code; code.interact(local=locals())
      assert len(path) > 0, "Path must contain at least one node."
-     formatted = [get_node_label(path[0])]
+     formatted = []
      for i in range(1, len(path)):
          n1 = path[i - 1]
          n2 = path[i]
@@ -184,12 +199,18 @@ def format_path(
          formatted.append(get_edge_label(n1, n2))
          formatted.append(get_node_label(n2))
 
-     return " ".join(formatted).strip()
+     node_str = get_node_label(path[0])
+     if len(formatted) > 0:
+         node_str += " | " + " ".join(formatted).strip()
+
+     return node_str
+
 
  def get_edge_texts(
      graph: LangGraph,
      edge: tuple,
      d: int,
+     task_type: str,
      metadata: GraphMetadata,
      use_node_attributes=False,
      use_node_types=False,
@@ -206,7 +227,8 @@ def get_edge_texts(
      if not neg_samples:
          masked = graph.edges[n1, n2].get('masked')
          graph.edges[n1, n2]['masked'] = True
-
+
+
      n1_text = get_node_text(
          graph=graph,
          node=n1,
@@ -239,10 +261,26 @@ def get_edge_texts(
          preprocessor=preprocessor,
          exclude_edges=[edge]
      )
+
+
+     edge_text = ""
+
      if not neg_samples:
          graph.edges[n1, n2]['masked'] = masked or False
+
+     edge_data = graph.get_edge_data(n1, n2)
+     edge_type = get_edge_data(edge_data, edge_cls_label, metadata.type)
+     edge_label = edge_data.get(metadata.edge_label, '') if use_edge_label and not no_labels else ''
+
+     if task_type not in [EDGE_CLS_TASK, LINK_PRED_TASK]:
+         if use_edge_types:
+             edge_text += f" {edge_cls_label}: {edge_type} " if not no_labels else ''
+
+         if use_edge_label:
+             edge_text += f" {edge_label} " if not no_labels else ''
 
-     return n1_text + EDGE_START + EDGE_END + n2_text
+
+     return n1_text + EDGE_START + f"{edge_text}" + EDGE_END + n2_text
 
 
  def get_node_text(
@@ -263,28 +301,39 @@
  ):
      masked = graph.nodes[node].get('masked')
      graph.nodes[node]['masked'] = True
-     raw_paths = utils.bfs(graph=graph, start_node=node, d=d, exclude_edges=exclude_edges)
-     unique_paths = utils.remove_subsets(list_of_lists=raw_paths)
-     text = "\n".join([
-         format_path(
-             graph=graph,
-             path=path,
-             metadata=metadata,
-             use_node_attributes=use_node_attributes,
-             use_node_types=use_node_types,
-             use_edge_types=use_edge_types,
-             use_edge_label=use_edge_label,
-             node_cls_label=node_cls_label,
-             edge_cls_label=edge_cls_label,
-             use_special_tokens=use_special_tokens,
-             no_labels=no_labels,
-             preprocessor=preprocessor,
-             neg_sample=False
-         )
-         for path in unique_paths
-     ])
+     # raw_paths = utils.bfs(graph=graph, start_node=node, d=d, exclude_edges=exclude_edges)
+     # unique_paths = utils.remove_subsets(list_of_lists=raw_paths)
+     node_neighbour_texts = list()
+     node_neighbours = utils.get_node_neighbours(graph, node, d, exclude_edges=exclude_edges)
+     for neighbour in node_neighbours:
+         unique_paths = [p for p in nx.all_simple_paths(graph, node, neighbour, cutoff=d)]
+
+         node_neighbour_texts.extend([
+             format_path(
+                 graph=graph,
+                 path=path,
+                 metadata=metadata,
+                 use_node_attributes=use_node_attributes,
+                 use_node_types=use_node_types,
+                 use_edge_types=use_edge_types,
+                 use_edge_label=use_edge_label,
+                 node_cls_label=node_cls_label,
+                 edge_cls_label=edge_cls_label,
+                 use_special_tokens=use_special_tokens,
+                 no_labels=no_labels,
+                 preprocessor=preprocessor,
+                 neg_sample=False
+             )
+             for path in unique_paths
+         ])
+
      graph.nodes[node]['masked'] = masked or False
-     return text
+     node_str = "\n".join(node_neighbour_texts).strip() if node_neighbour_texts else ''
+
+     if node_cls_label == 'stereotype':
+         node_str = graph.nodes[node]['type'].title() + " " + node_str
+
+     return node_str.strip()
 
 
  def get_node_texts(
@@ -326,6 +375,8 @@ def get_attribute_labels(node_data, attribute_labels):
      if isinstance(node_data[attribute_labels], list):
          if not node_data[attribute_labels]:
              return ''
+         if isinstance(node_data[attribute_labels][0], str):
+             return ", ".join(node_data[attribute_labels])
          if isinstance(node_data[attribute_labels][0], tuple):
              return ", ".join([f"{k}: {v}" for k, v in node_data[attribute_labels]])
          elif isinstance(node_data[attribute_labels][0], dict):
@@ -346,8 +397,12 @@ def get_node_name(
          attributes_str = "(" + get_attribute_labels(node_data, attribute_labels) + ")"
      else:
          attributes_str = ''
-     node_label = node_data.get(label, '')
-     node_label = '' if node_label.lower() == 'null' else node_label
+
+     node_label = node_data.get(label, '') if node_data.get(label, '') else ''
+     node_label = '' if node_label and node_label.lower() in ['null', 'none'] else node_label
+     # if attributes_str:
+     #     print(f"Node label: {node_label} | Attributes: {attributes_str}")
+
      return f"{node_label}{attributes_str}".strip()
 
 
@@ -405,7 +460,10 @@ def get_uml_edge_data(edge_data: dict, edge_label: str):
          raise ValueError(f"Unknown edge label: {edge_label}")
 
  def get_ontouml_edge_data(edge_data: dict, edge_label: str):
-     return edge_data.get(edge_label)
+     try:
+         return {'rel': "relates", "gen": "generalizes"}[edge_data.get(edge_label)]
+     except KeyError:
+         raise ValueError(f"Unknown edge label: {edge_label}")
 
  def get_uml_edge_type(edge_data):
      edge_type = edge_data.get('type')
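
The get_node_text rewrite above swaps the old BFS-plus-subset-pruning for explicit path enumeration: it collects every d-hop neighbour, lists all simple paths to each one with networkx, and formats each path (the root label is now separated from the rest by ' | '). A self-contained sketch of the path-enumeration core, using a toy graph in place of a LangGraph:

    import networkx as nx

    G = nx.DiGraph()
    G.add_edges_from([('A', 'B'), ('B', 'C'), ('A', 'C')])

    node, d = 'A', 2
    # Stand-in for utils.get_node_neighbours: nodes reachable within d hops.
    neighbours = [m for m in G.nodes if m != node
                  and nx.has_path(G, node, m)
                  and nx.shortest_path_length(G, node, m) <= d]
    for m in neighbours:
        for path in nx.all_simple_paths(G, node, m, cutoff=d):
            print(path)  # e.g. ['A', 'B'], ['A', 'B', 'C'], ['A', 'C']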
glam4cm/lang2graph/ecore.py CHANGED
@@ -58,8 +58,7 @@ class EcoreNxG(LangGraph):
          for f in structural_features:
              if f['type'] == 'ecore:EAttribute':
                  name = f['name']
-                 attr_type = f['ref'] if f['ref'] else ''
-                 attributes.append((name, attr_type))
+                 attributes.append(name)
 
          self.add_node(
              classifier_name,
glam4cm/lang2graph/ontouml.py CHANGED
@@ -48,13 +48,14 @@ extra_properties = [
  class OntoUMLNxG(LangGraph):
      def __init__(self, json_obj: dict, rel_as_node=True):
          super().__init__()
+         self.graph_id = json_obj['id']
          self.json_obj = json_obj
          self.rel_as_node = rel_as_node
          self.__create_graph()
          self.set_numbered_labels()
 
          self.text = " ".join([
-             self.nodes[node]['name'] if 'name' in self.nodes[node] else ''
+             self.nodes[node]['name'] if 'name' in self.nodes[node] and self.nodes[node]['name'] else ''
              for node in self.nodes
          ])
@@ -76,6 +77,7 @@ class OntoUMLNxG(LangGraph):
              ontouml_id2obj(item)
 
          def create_nxg():
+
              for k, v in id2obj_map.items():
                  node_name = v.get('name', '')
@@ -85,7 +87,8 @@ class OntoUMLNxG(LangGraph):
                      self.nodes[k][prop] = v[prop] if prop in v else False
 
                  logger.info(f"Node: {node_name} type: {v[ONTOUML_ELEMENT_TYPE]}")
-
+                 # else:
+                 #     continue
 
                  logger.info(f"Node: {node_name} type: {v[ONTOUML_ELEMENT_TYPE]}")
                  if ONTOUML_STEREOTYPE in v and v[ONTOUML_STEREOTYPE] is not None:
@@ -108,10 +111,8 @@ class OntoUMLNxG(LangGraph):
 
                  elif ONTOUML_PROPERTIES in v and v[ONTOUML_PROPERTIES] is not None:
                      properties = v[ONTOUML_PROPERTIES] if isinstance(v[ONTOUML_PROPERTIES], list) else [v[ONTOUML_PROPERTIES]]
-                     properties_str = ", ".join([property[ONTOUML_ELEMENT_NAME] for property in properties])
-                     self.nodes[k][ONTOUML_PROPERTIES] = properties_str
-                     logger.info(f"Properties: {properties_str}")
-
+                     self.nodes[k][ONTOUML_PROPERTIES] = [property[ONTOUML_ELEMENT_NAME] for property in properties]
+
 
                  elif v[ONTOUML_ELEMENT_TYPE] == ONTOUML_RELATION:
                      properties = v[ONTOUML_PROPERTIES] if isinstance(v[ONTOUML_PROPERTIES], list) else [v[ONTOUML_PROPERTIES]]
@@ -144,7 +145,7 @@ class OntoUMLNxG(LangGraph):
 
          def create_nxg_rel_as_edge():
              # TODO: To be implemented
-             pass
+             raise NotImplementedError
 
 
          id2obj_map = dict()
glam4cm/models/gnn_layers.py CHANGED
@@ -123,19 +123,33 @@ class GNNConv(torch.nn.Module):
              h = self.dropout(h)
              return h
 
+         edge_attr_val = isinstance(edge_attr, torch.Tensor) and self.is_headed_model()
          h = in_feat
-         h = self.conv_layers[0](h, edge_index, edge_attr) if isinstance(edge_attr, torch.Tensor) else self.conv_layers[0](h, edge_index)
-         activate(h)
+         h = self.conv_layers[0](h, edge_index, edge_attr) \
+             if edge_attr_val else self.conv_layers[0](h, edge_index)
+         h = activate(h)
 
          for conv in self.conv_layers[1:-1]:
-             nh = conv(h, edge_index, edge_attr) if isinstance(edge_attr, torch.Tensor) else conv(h, edge_index)
+             nh = conv(h, edge_index, edge_attr) if edge_attr_val else conv(h, edge_index)
              h = nh if not self.residual else nh + h
-             activate(h)
+             h = activate(h)
 
          h = self.conv_layers[-1](h, edge_index)
-         activate(h)
+         h = activate(h)
          return h
-
+
+     def is_headed_model(self):
+         """
+         Returns True if the model is a headed model.
+         Checks if the model name is in the supported_conv_models dictionary
+         and if the model requires num_heads.
+         """
+         headed = self.num_heads is not None
+         model_name = self.conv_layers[0].__class__.__name__
+         if model_name in supported_conv_models:
+             return supported_conv_models[model_name] and headed
+         return False
+
 
  class EdgeClassifer(nn.Module):
 
glam4cm/models/hf.py CHANGED
@@ -1,7 +1,7 @@
  from transformers import AutoModelForSequenceClassification
 
- def get_model(model_name, num_labels, len_tokenizer=None) -> AutoModelForSequenceClassification:
-     model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
+ def get_model(model_name, num_labels, len_tokenizer=None, trust_remote_code=False) -> AutoModelForSequenceClassification:
+     model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels, trust_remote_code=trust_remote_code)
      if len_tokenizer:
          model.resize_token_embeddings(len_tokenizer)
          assert model.config.vocab_size == len_tokenizer,\
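
The new trust_remote_code flag is passed straight through to transformers and is only needed for checkpoints that ship custom modeling code. A short usage sketch (model name and label count are illustrative):

    from glam4cm.models.hf import get_model

    # Standard checkpoint: no remote code required.
    model = get_model('bert-base-uncased', num_labels=4)

    # A hub checkpoint with custom modeling code would need:
    # model = get_model('some-org/custom-model', num_labels=4, trust_remote_code=True)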
glam4cm/run.py CHANGED
@@ -43,8 +43,8 @@ tasks = {
 
      6: 'GNN Graph Classification',
      7: 'GNN Node Classification',
-     8: 'GNN Edge Classification',
-     9: 'GNN Link Prediction',
+     8: 'GNN Link Prediction',
+     9: 'GNN Edge Classification',
      10: 'CM-GPT Causal Modeling',
      11: 'CM-GPT Node Classification',
      12: 'CM-GPT Edge Classification'
@@ -60,8 +60,8 @@ tasks_handler_map = {
      5: (bert_edge_classification.run, bert_ec_parse_args),
      6: (gnn_graph_cls.run, gnn_parse_args),
      7: (gnn_node_classification.run, gnn_nc_parse_args),
-     8: (gnn_edge_classification.run, gnn_ec_parse_args),
-     9: (gnn_link_prediction.run, gnn_lp_parse_args),
+     8: (gnn_link_prediction.run, gnn_lp_parse_args),
+     9: (gnn_edge_classification.run, gnn_ec_parse_args),
      10: (cm_gpt_pretraining.run, cm_gpt_parse_args),
      11: (cm_gpt_node_classification.run, cm_gpt_nc_parse_args),
      12: (cm_gpt_edge_classification.run, cm_gpt_ec_parse_args)
@@ -84,7 +84,7 @@ def main():
      ### If args has -h or --help, print help
      if any(x in remaining_args for x in ['-th', '--task_help']):
          task_id = args.task_id
-         hander, task_parser = tasks_handler_map[task_id]
+         task_handler, task_parser = tasks_handler_map[task_id]
          print("Help for task:", tasks[task_id])
          task_parser().print_help()
          exit(0)
@@ -93,6 +93,11 @@ def main():
 
 
      task_id = args.task_id
-     hander, task_parser = tasks_handler_map[task_id]
+     task_handler, task_parser = tasks_handler_map[task_id]
      task_args = task_parser().parse_args(remaining_args)
-     hander(task_args)
+     task_handler(task_args)
+
+
+ if __name__ == '__main__':
+     main()
+
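
Note that task ids 8 and 9 swap meanings in this release (8 is now GNN Link Prediction, 9 is GNN Edge Classification), so scripts pinned to the old ids must be updated; the new __main__ guard also makes the launcher runnable as a module. Assuming the top-level parser exposes the id as --task_id (as args.task_id in main() suggests), an invocation looks like:

    # 0.1.1: GNN Link Prediction was task 9
    # python -m glam4cm.run --task_id 9 ...
    # 1.0.0: the same task is now id 8
    python -m glam4cm.run --task_id 8 ...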