glam4cm 0.1.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. glam4cm/__init__.py +2 -1
  2. glam4cm/data_loading/data.py +90 -146
  3. glam4cm/data_loading/encoding.py +17 -6
  4. glam4cm/data_loading/graph_dataset.py +192 -57
  5. glam4cm/data_loading/metadata.py +1 -1
  6. glam4cm/data_loading/models_dataset.py +42 -18
  7. glam4cm/downstream_tasks/bert_edge_classification.py +49 -22
  8. glam4cm/downstream_tasks/bert_graph_classification.py +44 -14
  9. glam4cm/downstream_tasks/bert_graph_classification_comp.py +47 -24
  10. glam4cm/downstream_tasks/bert_link_prediction.py +46 -26
  11. glam4cm/downstream_tasks/bert_node_classification.py +127 -89
  12. glam4cm/downstream_tasks/cm_gpt_node_classification.py +61 -15
  13. glam4cm/downstream_tasks/common_args.py +32 -4
  14. glam4cm/downstream_tasks/gnn_edge_classification.py +24 -7
  15. glam4cm/downstream_tasks/gnn_graph_cls.py +19 -6
  16. glam4cm/downstream_tasks/gnn_link_prediction.py +25 -13
  17. glam4cm/downstream_tasks/gnn_node_classification.py +19 -7
  18. glam4cm/downstream_tasks/utils.py +16 -2
  19. glam4cm/embeddings/bert.py +1 -1
  20. glam4cm/embeddings/common.py +7 -4
  21. glam4cm/encoding/encoders.py +1 -1
  22. glam4cm/lang2graph/archimate.py +0 -5
  23. glam4cm/lang2graph/common.py +99 -41
  24. glam4cm/lang2graph/ecore.py +1 -2
  25. glam4cm/lang2graph/ontouml.py +8 -7
  26. glam4cm/models/gnn_layers.py +20 -6
  27. glam4cm/models/hf.py +2 -2
  28. glam4cm/run.py +12 -7
  29. glam4cm/run_conf_v2.py +405 -0
  30. glam4cm/run_configs.py +70 -106
  31. glam4cm/run_confs.py +41 -0
  32. glam4cm/settings.py +15 -2
  33. glam4cm/tokenization/special_tokens.py +23 -1
  34. glam4cm/tokenization/utils.py +23 -4
  35. glam4cm/trainers/cm_gpt_trainer.py +1 -1
  36. glam4cm/trainers/gnn_edge_classifier.py +12 -1
  37. glam4cm/trainers/gnn_graph_classifier.py +12 -5
  38. glam4cm/trainers/gnn_link_predictor.py +18 -3
  39. glam4cm/trainers/gnn_link_predictor_v2.py +146 -0
  40. glam4cm/trainers/gnn_trainer.py +8 -0
  41. glam4cm/trainers/metrics.py +1 -1
  42. glam4cm/utils.py +265 -2
  43. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/METADATA +3 -2
  44. glam4cm-1.0.0.dist-info/RECORD +75 -0
  45. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/WHEEL +1 -1
  46. glam4cm-0.1.1.dist-info/RECORD +0 -72
  47. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/entry_points.txt +0 -0
  48. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info/licenses}/LICENSE +0 -0
  49. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/top_level.txt +0 -0
@@ -11,18 +11,29 @@ import numpy as np
11
11
  from scipy.sparse import csr_matrix
12
12
  from transformers import AutoTokenizer
13
13
  from glam4cm.data_loading.data import TorchEdgeGraph, TorchGraph, TorchNodeGraph
14
- from glam4cm.data_loading.models_dataset import ArchiMateDataset, EcoreDataset, OntoUMLDataset
14
+ from glam4cm.data_loading.models_dataset import (
15
+ ArchiMateDataset,
16
+ EcoreDataset,
17
+ OntoUMLDataset
18
+ )
15
19
  from glam4cm.data_loading.encoding import EncodingDataset, GPTTextDataset
16
20
  from tqdm.auto import tqdm
17
21
  from glam4cm.embeddings.w2v import Word2VecEmbedder
18
22
  from glam4cm.embeddings.tfidf import TfidfEmbedder
19
23
  from glam4cm.embeddings.common import get_embedding_model
20
24
  from glam4cm.lang2graph.common import LangGraph, get_node_data, get_edge_data
21
- from glam4cm.data_loading.metadata import ArchimateMetaData, EcoreMetaData, OntoUMLMetaData
25
+ from glam4cm.data_loading.metadata import (
26
+ ArchimateMetaData,
27
+ EcoreMetaData,
28
+ OntoUMLMetaData
29
+ )
22
30
  from glam4cm.settings import seed
23
31
  from glam4cm.settings import (
24
- LP_TASK_EDGE_CLS,
25
- LP_TASK_LINK_PRED,
32
+ EDGE_CLS_TASK,
33
+ LINK_PRED_TASK,
34
+ NODE_CLS_TASK,
35
+ GRAPH_CLS_TASK,
36
+ DUMMY_GRAPH_CLS_TASK
26
37
  )
27
38
  import glam4cm.utils as utils
28
39
 
@@ -94,7 +105,7 @@ class GraphDataset(torch.utils.data.Dataset):
94
105
  def __init__(
95
106
  self,
96
107
  models_dataset: Union[EcoreDataset, ArchiMateDataset],
97
- save_dir='datasets/graph_data',
108
+ task_type: str,
98
109
  distance=1,
99
110
  add_negative_train_samples=False,
100
111
  neg_sampling_ratio=1,
@@ -103,6 +114,8 @@ class GraphDataset(torch.utils.data.Dataset):
103
114
  use_node_types=False,
104
115
  use_edge_label=False,
105
116
  no_labels=False,
117
+
118
+ node_topk=-1,
106
119
 
107
120
  node_cls_label=None,
108
121
  edge_cls_label=None,
@@ -120,8 +133,11 @@ class GraphDataset(torch.utils.data.Dataset):
120
133
  randomize_ee=False,
121
134
  random_embed_dim=128,
122
135
 
123
- exclude_labels: list = [None, ''],
136
+ exclude_labels: List = None,
137
+ save_dir='datasets/graph_data',
124
138
  ):
139
+ self.task_type = task_type
140
+
125
141
  if isinstance(models_dataset, EcoreDataset):
126
142
  self.metadata = EcoreMetaData()
127
143
  elif isinstance(models_dataset, ArchiMateDataset):
@@ -147,12 +163,34 @@ class GraphDataset(torch.utils.data.Dataset):
147
163
 
148
164
  self.test_ratio = test_ratio
149
165
 
150
- self.no_shuffle = no_shuffle
151
- self.exclude_labels = exclude_labels
152
-
153
166
  self.use_special_tokens = use_special_tokens
154
167
  self.node_cls_label = node_cls_label
155
168
  self.edge_cls_label = edge_cls_label
169
+
170
+ self.node_topk = node_topk
171
+
172
+
173
+ self.no_shuffle = no_shuffle
174
+
175
+ all_labels = [
176
+ i[0] for i in sorted(
177
+ dict(Counter(
178
+ sum(
179
+ [
180
+ [
181
+ v for v in list(dict(model.numbered_graph.nodes(data=node_cls_label)).values()) if v
182
+ ]
183
+ for model in models_dataset],
184
+ []
185
+ )
186
+ )).items(),
187
+ key=lambda x: x[1],
188
+ reverse=True
189
+ )
190
+ ]
191
+
192
+ self.node_topk = all_labels[:node_topk] if node_topk > 0 else all_labels
193
+ self.exclude_labels = (all_labels[node_topk+1:] if node_topk > 0 else []) + [None, '']
156
194
 
157
195
  self.randomize_ne = randomize_ne
158
196
  self.randomize_ee = randomize_ee
@@ -161,6 +199,8 @@ class GraphDataset(torch.utils.data.Dataset):
161
199
  self.graphs: List[Union[TorchNodeGraph, TorchEdgeGraph]] = []
162
200
  self.config = dict(
163
201
  name=models_dataset.name,
202
+ task_type=task_type,
203
+ node_topk=node_topk,
164
204
  distance=distance,
165
205
  add_negative_train_samples=add_negative_train_samples,
166
206
  neg_sampling_ratio=neg_sampling_ratio,
@@ -181,7 +221,6 @@ class GraphDataset(torch.utils.data.Dataset):
181
221
  randomize_ee=randomize_ee,
182
222
  random_embed_dim=random_embed_dim
183
223
  )
184
-
185
224
  self.save_dir = os.path.join(save_dir, models_dataset.name)
186
225
  os.makedirs(self.save_dir, exist_ok=True)
187
226
 
@@ -210,22 +249,21 @@ class GraphDataset(torch.utils.data.Dataset):
210
249
  self.config_hash = self.get_config_hash()
211
250
  os.makedirs(os.path.join(self.save_dir, self.config_hash), exist_ok=True)
212
251
  self.file_paths = {
213
- graph.hash: os.path.join(self.save_dir, self.config_hash, f'{graph.hash}', 'data.pkl')
252
+ graph.hash: os.path.join(
253
+ self.save_dir,
254
+ self.config_hash,
255
+ f'{graph.hash}_{self.get_string_gen_params_hash()}',
256
+ 'data.pkl'
257
+ )
214
258
  for graph in models_dataset
215
259
  }
216
260
 
217
261
  print("Number of duplicate graphs: ", len(models_dataset) - len(self.file_paths))
218
262
 
219
-
220
- def set_torch_graphs(
221
- self,
222
- type: str,
223
- models_dataset: Union[EcoreDataset, ArchiMateDataset],
224
- limit: int =-1
225
- ):
226
-
263
+ def get_common_params(self):
227
264
  common_params = dict(
228
265
  metadata=self.metadata,
266
+ task_type=self.task_type,
229
267
  distance=self.distance,
230
268
  test_ratio=self.test_ratio,
231
269
  use_attributes=self.use_attributes,
@@ -235,8 +273,34 @@ class GraphDataset(torch.utils.data.Dataset):
235
273
  use_special_tokens=self.use_special_tokens,
236
274
  no_labels=self.no_labels,
237
275
  node_cls_label=self.node_cls_label,
238
- edge_cls_label=self.edge_cls_label
276
+ edge_cls_label=self.edge_cls_label,
277
+ node_topk=self.node_topk,
239
278
  )
279
+ return common_params
280
+
281
+ def get_string_gen_params_hash(self):
282
+ string_gen_params = f"""
283
+ distance={self.distance},
284
+ use_attributes={self.use_attributes},
285
+ use_node_types={self.use_node_types},
286
+ use_edge_types={self.use_edge_types},
287
+ use_edge_label={self.use_edge_label},
288
+ use_special_tokens={self.use_special_tokens},
289
+ no_labels={self.no_labels},
290
+ node_cls_label={self.node_cls_label},
291
+ edge_cls_label={self.edge_cls_label},
292
+ node_topk={self.node_topk},
293
+ ckpt={self.ckpt},
294
+ """
295
+ return utils.md5_hash(string_gen_params)
296
+
297
+ def set_torch_graphs(
298
+ self,
299
+ models_dataset: Union[EcoreDataset, ArchiMateDataset],
300
+ limit: int =-1
301
+ ):
302
+
303
+ common_params = self.get_common_params()
240
304
  def create_node_graph(graph: LangGraph, fp: str) -> TorchNodeGraph:
241
305
  node_params = {
242
306
  **common_params,
@@ -256,24 +320,29 @@ class GraphDataset(torch.utils.data.Dataset):
256
320
  return torch_graph
257
321
 
258
322
 
259
-
260
323
  models_size = len(models_dataset) \
261
324
  if (limit == -1 or limit > len(models_dataset)) else limit
262
325
 
263
326
  self.set_file_hashes(models_dataset[:models_size])
264
327
 
265
- for graph in tqdm(models_dataset[:models_size], desc=f'Creating {type} graphs'):
328
+ for graph in tqdm(models_dataset[:models_size], desc=f'Creating {self.task_type} graphs'):
266
329
  fp = self.file_paths[graph.hash]
267
330
  if not os.path.exists(fp) or self.reload:
268
- if type == 'node':
331
+ if self.task_type in [NODE_CLS_TASK, GRAPH_CLS_TASK, DUMMY_GRAPH_CLS_TASK]:
269
332
  torch_graph: TorchNodeGraph = create_node_graph(graph, fp)
270
- elif type == 'edge':
333
+ elif self.task_type in [EDGE_CLS_TASK, LINK_PRED_TASK]:
271
334
  torch_graph: TorchEdgeGraph = create_edge_graph(graph, fp)
272
-
335
+ else:
336
+ raise ValueError(f"Invalid task type: {self.task_type}")
273
337
  torch_graph.save()
274
338
 
275
- def embed(self):
276
- for fp in tqdm(self.file_paths.values(), desc='Embedding graphs'):
339
+
340
+ def embed(self, models_dataset, limit):
341
+ models_size = len(models_dataset) \
342
+ if (limit == -1 or limit > len(models_dataset)) else limit
343
+ print("Limit: ", limit)
344
+ for graph in tqdm(models_dataset[:models_size], desc=f'Creating {self.task_type} graphs'):
345
+ fp = self.file_paths[graph.hash]
277
346
  torch_graph = TorchGraph.load(fp)
278
347
  torch_graph.embed(
279
348
  self.embedder,
@@ -306,7 +375,8 @@ class GraphDataset(torch.utils.data.Dataset):
306
375
  return np.concatenate([a, b], axis=1)
307
376
 
308
377
  prefix_cls = getattr(self, f"{prefix}_cls_label")
309
- num_classes = getattr(self, f"num_{prefix}s_{prefix_cls}") + 1
378
+ num_classes = getattr(self, f"{prefix}_label_map_{prefix_cls}").classes_.shape[0]
379
+
310
380
  # print(f"Number of {prefix} types: {num_classes}")
311
381
  for g in self.graphs:
312
382
  types = np.eye(num_classes)[getattr(g.data, f"{prefix}_{prefix_cls}")]
@@ -325,10 +395,10 @@ class GraphDataset(torch.utils.data.Dataset):
325
395
  assert all(g.data.edge_attr.shape[1] == edge_dim for g in self.graphs), "Edge types not added correctly"
326
396
 
327
397
 
328
- if self.use_node_types and self.node_cls_label:
398
+ if self.use_node_types and self.node_cls_label and self.task_type not in [NODE_CLS_TASK]:
329
399
  set_types('node')
330
400
 
331
- if self.use_edge_types and self.edge_cls_label:
401
+ if self.use_edge_types and self.edge_cls_label and self.task_type not in [EDGE_CLS_TASK, LINK_PRED_TASK]:
332
402
  set_types('edge')
333
403
 
334
404
  def __len__(self):
@@ -375,8 +445,9 @@ class GraphDataset(torch.utils.data.Dataset):
375
445
  assert torch_graph.data.overall_edge_index.shape[1] == torch_graph.graph.number_of_edges(), \
376
446
  f"Number of edges mismatch, {torch_graph.data.edge_index.shape[1]} != {torch_graph.graph.number_of_edges()}"
377
447
  else:
378
- assert torch_graph.data.edge_index.shape[1] == torch_graph.graph.number_of_edges(), \
379
- f"Number of edges mismatch, {torch_graph.data.edge_index.shape[1]} != {torch_graph.graph.number_of_edges()}"
448
+ if len(torch_graph.data.edge_index.shape) > 1:
449
+ assert torch_graph.data.edge_index.shape[1] == torch_graph.graph.number_of_edges(), \
450
+ f"Number of edges mismatch, {torch_graph.data.edge_index.shape[1]} != {torch_graph.graph.number_of_edges()}"
380
451
 
381
452
 
382
453
 
@@ -407,7 +478,6 @@ class GraphDataset(torch.utils.data.Dataset):
407
478
  node_label_map = LabelEncoder()
408
479
  node_label_map.fit_transform([j for i in label_values for j in i])
409
480
  label_values = [node_label_map.transform(i) for i in label_values]
410
- print(node_label_map.classes_)
411
481
 
412
482
  for torch_graph, node_classes in zip(self.graphs, label_values):
413
483
  setattr(torch_graph.data, f"node_{cls_label}", np.array(node_classes))
@@ -532,7 +602,7 @@ class GraphDataset(torch.utils.data.Dataset):
532
602
  y = [getattr(self.graphs[i].data, f'graph_{graph_label_name}')[0].item() for i in indices]
533
603
 
534
604
  dataset = EncodingDataset(tokenizer, X, y, remove_duplicates=remove_duplicates)
535
- print("\n".join([f"Label: {self.graph_label_map_label.inverse_transform([l])[0]}, Text: {i}" for i, l in zip(X, y)]))
605
+ # print("\n".join([f"Label: {self.graph_label_map_label.inverse_transform([l])[0]}, Text: {i}" for i, l in zip(X, y)]))
536
606
 
537
607
  return dataset
538
608
 
@@ -566,7 +636,7 @@ class GraphEdgeDataset(GraphDataset):
566
636
  def __init__(
567
637
  self,
568
638
  models_dataset: Union[EcoreDataset, ArchiMateDataset],
569
- save_dir='datasets/graph_data',
639
+ task_type: str,
570
640
  distance=0,
571
641
  reload=False,
572
642
  test_ratio=0.2,
@@ -580,6 +650,8 @@ class GraphEdgeDataset(GraphDataset):
580
650
  use_node_types=False,
581
651
  no_labels=False,
582
652
 
653
+ node_topk = -1,
654
+
583
655
  use_embeddings=False,
584
656
  embed_model_name='bert-base-uncased',
585
657
  ckpt=None,
@@ -595,13 +667,13 @@ class GraphEdgeDataset(GraphDataset):
595
667
 
596
668
  node_cls_label: str = None,
597
669
  edge_cls_label: str = None,
598
-
599
- task_type=LP_TASK_EDGE_CLS
670
+ save_dir='datasets/graph_data'
600
671
  ):
601
-
672
+ assert task_type in [EDGE_CLS_TASK, GRAPH_CLS_TASK], f"Invalid task type: Must be one of {[EDGE_CLS_TASK, GRAPH_CLS_TASK]}."
602
673
  super().__init__(
603
674
  models_dataset=models_dataset,
604
- save_dir=save_dir,
675
+ task_type=task_type,
676
+
605
677
  distance=distance,
606
678
  test_ratio=test_ratio,
607
679
 
@@ -609,8 +681,9 @@ class GraphEdgeDataset(GraphDataset):
609
681
  use_edge_types=use_edge_types,
610
682
  use_edge_label=use_edge_label,
611
683
  use_attributes=use_attributes,
612
- no_labels=no_labels,
613
-
684
+ no_labels=no_labels,
685
+ node_topk = node_topk,
686
+
614
687
  add_negative_train_samples=add_negative_train_samples,
615
688
  neg_sampling_ratio=neg_sampling_ratio,
616
689
 
@@ -631,11 +704,10 @@ class GraphEdgeDataset(GraphDataset):
631
704
  randomize_ne=randomize_ne,
632
705
  randomize_ee=randomize_ee,
633
706
  random_embed_dim=random_embed_dim,
707
+ save_dir=save_dir
634
708
  )
635
709
 
636
- self.task_type = task_type
637
-
638
- self.set_torch_graphs('edge', models_dataset, limit)
710
+ self.set_torch_graphs(models_dataset, limit)
639
711
 
640
712
  if self.use_embeddings and (isinstance(self.embedder, Word2VecEmbedder) or isinstance(self.embedder, TfidfEmbedder)):
641
713
  texts = self.get_link_prediction_texts(only_texts=True)
@@ -644,7 +716,7 @@ class GraphEdgeDataset(GraphDataset):
644
716
  self.embedder.train(texts)
645
717
  print(f"Trained {self.embedder.name} Embedder")
646
718
 
647
- self.embed()
719
+ self.embed(models_dataset, limit)
648
720
 
649
721
  train_count, test_count = dict(), dict()
650
722
  for g in self.graphs:
@@ -675,12 +747,13 @@ class GraphEdgeDataset(GraphDataset):
675
747
  assert label is not None, "No edge label found in data. Please define edge label in metadata"
676
748
 
677
749
  data = defaultdict(list)
678
- for torch_graph in tqdm(self.graphs, desc='Getting Graph Texts'):
750
+ for torch_graph in tqdm(self.graphs, desc=f'Getting {self.task_type} Texts'):
679
751
  # torch_graph: TorchEdgeGraph = TorchGraph.load(fp)
680
752
  graph_data = torch_graph.get_link_prediction_texts(label, self.task_type, only_texts)
681
753
  for k, v in graph_data.items():
682
754
  data[k] += v
683
755
 
756
+
684
757
  print("Train Texts: ", data[f'train_pos_edges'][:20])
685
758
  print("Test Texts: ", data[f'test_pos_edges'][:20])
686
759
 
@@ -701,7 +774,7 @@ class GraphEdgeDataset(GraphDataset):
701
774
 
702
775
 
703
776
  print("Tokenizing data")
704
- if self.task_type == LP_TASK_EDGE_CLS:
777
+ if self.task_type == EDGE_CLS_TASK:
705
778
  datasets = {
706
779
  'train': EncodingDataset(
707
780
  tokenizer,
@@ -714,7 +787,7 @@ class GraphEdgeDataset(GraphDataset):
714
787
  data['test_edge_classes']
715
788
  )
716
789
  }
717
- elif self.task_type == LP_TASK_LINK_PRED:
790
+ elif self.task_type == LINK_PRED_TASK:
718
791
  datasets = {
719
792
  'train': EncodingDataset(
720
793
  tokenizer,
@@ -738,8 +811,9 @@ class GraphEdgeDataset(GraphDataset):
738
811
  class GraphNodeDataset(GraphDataset):
739
812
  def __init__(
740
813
  self,
741
- models_dataset: Union[EcoreDataset, ArchiMateDataset],
742
- save_dir='datasets/graph_data',
814
+ models_dataset: Union[EcoreDataset, ArchiMateDataset, OntoUMLDataset],
815
+ task_type: str,
816
+
743
817
  distance=0,
744
818
  test_ratio=0.2,
745
819
  reload=False,
@@ -749,6 +823,7 @@ class GraphNodeDataset(GraphDataset):
749
823
  use_node_types=False,
750
824
  use_edge_label=False,
751
825
  use_special_tokens=False,
826
+ node_topk=-1,
752
827
 
753
828
  use_embeddings=False,
754
829
  embed_model_name='bert-base-uncased',
@@ -762,11 +837,69 @@ class GraphNodeDataset(GraphDataset):
762
837
  limit: int = -1,
763
838
  no_labels=False,
764
839
  node_cls_label: str = None,
765
- edge_cls_label: str = None
840
+ edge_cls_label: str = None,
841
+
842
+ save_dir='datasets/graph_data',
766
843
  ):
844
+ """
845
+ Parameters
846
+ ----------
847
+ models_dataset: Union[EcoreDataset, ArchiMateDataset, OntoUMLDataset]
848
+ The dataset of models to convert to a graph dataset.
849
+ task_type: str
850
+ The type of task to perform on the graph dataset. Must be NODE_CLS_TASK or GRAPH_CLS_TASK (enforced by the assertion below).
851
+ distance: int
852
+ The distance to consider when creating the graph. If 0, only the node itself is considered.
853
+ test_ratio: float
854
+ The proportion of the dataset to split into the test set.
855
+ reload: bool
856
+ Whether to rebuild and re-save each graph even if a cached file for it already exists on disk.
857
+ use_attributes: bool
858
+ Whether to include attributes of the nodes and edges in the graph.
859
+ use_edge_types: bool
860
+ Whether to include the types of the edges in the graph.
861
+ use_node_types: bool
862
+ Whether to include the types of the nodes in the graph.
863
+ use_edge_label: bool
864
+ Whether to include the labels of the edges in the graph.
865
+ use_special_tokens: bool
866
+ Whether to include special tokens for the start and end of a node or edge sequence.
867
+ node_topk: int
868
+ The number of most frequent node labels (values of node_cls_label) to keep. If -1, all labels are kept.
869
+ use_embeddings: bool
870
+ Whether to use embeddings for the node and edge attributes.
871
+ embed_model_name: str
872
+ The name of the embedding model to use.
873
+ ckpt: str
874
+ The path to the checkpoint file of the embedding model.
875
+ no_shuffle: bool
876
+ If True, do not shuffle the dataset before splitting it into train and test sets.
877
+ randomize_ne: bool
878
+ Whether to randomize the node embeddings.
879
+ randomize_ee: bool
880
+ Whether to randomize the edge embeddings.
881
+ random_embed_dim: int
882
+ The dimension of the random embeddings.
883
+ limit: int
884
+ The maximum number of models to include in the dataset.
885
+ no_labels: bool
886
+ If True, leave out the labels of the nodes and edges in the graph.
887
+ node_cls_label: str
888
+ The label to use for the node classification task.
889
+ edge_cls_label: str
890
+ The label to use for the edge classification task.
891
+ save_dir: str
892
+ The directory in which to save the dataset.
893
+
894
+ Returns
895
+ -------
896
+ A GraphNodeDataset object.
897
+ """
898
+ assert task_type in [NODE_CLS_TASK, GRAPH_CLS_TASK], f"Invalid task type: Must be one of {[NODE_CLS_TASK, GRAPH_CLS_TASK]}."
767
899
  super().__init__(
768
900
  models_dataset=models_dataset,
769
- save_dir=save_dir,
901
+ task_type=task_type,
902
+
770
903
  distance=distance,
771
904
  test_ratio=test_ratio,
772
905
 
@@ -778,6 +911,8 @@ class GraphNodeDataset(GraphDataset):
778
911
 
779
912
  node_cls_label=node_cls_label,
780
913
  edge_cls_label=edge_cls_label,
914
+
915
+ node_topk=node_topk,
781
916
 
782
917
  use_embeddings=use_embeddings,
783
918
  embed_model_name=embed_model_name,
@@ -791,9 +926,10 @@ class GraphNodeDataset(GraphDataset):
791
926
  randomize_ne=randomize_ne,
792
927
  randomize_ee=randomize_ee,
793
928
  random_embed_dim=random_embed_dim,
929
+ save_dir=save_dir,
794
930
  )
795
931
 
796
- self.set_torch_graphs('node', models_dataset, limit)
932
+ self.set_torch_graphs(models_dataset, limit)
797
933
 
798
934
  if self.use_embeddings and (isinstance(self.embedder, Word2VecEmbedder) or isinstance(self.embedder, TfidfEmbedder)):
799
935
  texts = self.get_node_classification_texts()
@@ -802,7 +938,7 @@ class GraphNodeDataset(GraphDataset):
802
938
  self.embedder.train(texts)
803
939
  print(f"Trained {self.embedder.name} Embedder")
804
940
 
805
- self.embed()
941
+ self.embed(models_dataset, limit)
806
942
 
807
943
  node_labels = self.metadata.node_cls
808
944
  if isinstance(node_labels, str):
@@ -860,9 +996,8 @@ class GraphNodeDataset(GraphDataset):
860
996
  data['test_node_classes'] += test_node_classes
861
997
 
862
998
 
863
- print("Tokenizing data")
864
- print(data['train_nodes'][:10])
865
- print(data['test_nodes'][:10])
999
+ # print("\n".join(data['train_nodes']))
1000
+ # print("\n".join(data['test_nodes']))
866
1001
  if hasattr(self, "node_label_map_type"):
867
1002
  node_label_map.inverse_transform([i.item() for i in train_node_classes]) == train_node_strs
868
1003
  node_label_map.inverse_transform([i.item() for i in test_node_classes]) == test_node_strs
@@ -76,7 +76,7 @@ class OntoUMLMetaData(GraphMetadata):
76
76
  "attributes": "properties"
77
77
  }
78
78
  self.edge = {
79
- "cls": "stereotype"
79
+ "cls": 'type'
80
80
  }
81
81
 
82
82
  self.graph = {
@@ -30,7 +30,8 @@ class ModelDataset:
30
30
  min_edges: int = -1,
31
31
  min_enr: float = -1,
32
32
  timeout=-1,
33
- preprocess_graph_text: callable = None
33
+ preprocess_graph_text: callable = None,
34
+ include_dummies=False
34
35
  ):
35
36
  self.name = dataset_name
36
37
  self.dataset_dir = dataset_dir
@@ -41,6 +42,7 @@ class ModelDataset:
41
42
  self.min_enr = min_enr
42
43
  self.timeout = timeout
43
44
  self.preprocess_graph_text = preprocess_graph_text
45
+ self.include_dummies = include_dummies
44
46
 
45
47
  self.graphs: List[LangGraph] = []
46
48
 
@@ -114,12 +116,14 @@ class ModelDataset:
114
116
 
115
117
  def save(self):
116
118
  print(f'Saving {self.name} to pickle')
117
- with open(os.path.join(self.save_dir, f'{self.name}.pkl'), 'wb') as f:
119
+ pkl_file = f'{self.name}{"_with_dummies" if self.include_dummies else ''}.pkl'
120
+ with open(os.path.join(self.save_dir, pkl_file), 'wb') as f:
118
121
  pickle.dump(self.graphs, f)
119
122
  print(f'Saved {self.name} to pickle')
120
123
 
121
124
 
122
125
  def filter_graphs(self):
126
+ # print("Filtering graphs with min edges and min enr: ", self.min_edges, self.min_enr)
123
127
  graphs = list()
124
128
  for graph in self.graphs:
125
129
  addable = True
@@ -129,6 +133,7 @@ class ModelDataset:
129
133
  addable = False
130
134
 
131
135
  if addable:
136
+ # print("Addable because min edges and min enr: ", graph.number_of_edges())
132
137
  graphs.append(graph)
133
138
 
134
139
  self.graphs = graphs
@@ -137,7 +142,8 @@ class ModelDataset:
137
142
 
138
143
  def load(self):
139
144
  print(f'Loading {self.name} from pickle')
140
- with open(os.path.join(self.save_dir, f'{self.name}.pkl'), 'rb') as f:
145
+ pkl_file = f'{self.name}{"_with_dummies" if self.include_dummies else ''}.pkl'
146
+ with open(os.path.join(self.save_dir, pkl_file), 'rb') as f:
141
147
  self.graphs = pickle.load(f)
142
148
 
143
149
  self.filter_graphs()
@@ -172,7 +178,8 @@ class EcoreDataset(ModelDataset):
172
178
  remove_duplicates=False,
173
179
  min_edges: int = -1,
174
180
  min_enr: float = -1,
175
- preprocess_graph_text: callable = None
181
+ preprocess_graph_text: callable = None,
182
+ include_dummies=False
176
183
  ):
177
184
  super().__init__(
178
185
  dataset_name,
@@ -180,22 +187,34 @@ class EcoreDataset(ModelDataset):
180
187
  save_dir=save_dir,
181
188
  min_edges=min_edges,
182
189
  min_enr=min_enr,
183
- preprocess_graph_text=preprocess_graph_text
190
+ preprocess_graph_text=preprocess_graph_text,
191
+ include_dummies=include_dummies
184
192
  )
185
193
  os.makedirs(save_dir, exist_ok=True)
186
194
 
187
195
  dataset_exists = os.path.exists(os.path.join(save_dir, f'{dataset_name}.pkl'))
188
196
  if reload or not dataset_exists:
197
+
198
+
189
199
  self.graphs: List[EcoreNxG] = []
190
200
  data_path = os.path.join(dataset_dir, dataset_name)
191
- for file in os.listdir(data_path):
192
- if file.endswith('.jsonl') and file.startswith('ecore'):
193
- json_objects = json.load(open(os.path.join(data_path, file)))
194
- for g in tqdm(json_objects, desc=f'Loading {dataset_name.title()}'):
195
- if remove_duplicates and g['is_duplicated']:
196
- continue
197
- nxg = EcoreNxG(g)
198
- self.graphs.append(nxg)
201
+ file_name = os.path.join(data_path, 'ecore.jsonl') if not include_dummies\
202
+ else os.path.join(data_path, 'ecore-with-dummy.jsonl')
203
+
204
+ # for file in os.listdir(data_path):
205
+ # if file.endswith('.jsonl') and file.startswith("ecore"):
206
+ json_objects = json.load(open(file_name))
207
+ for g in tqdm(json_objects, desc=f'Loading {dataset_name.title()}'):
208
+
209
+ if remove_duplicates and g['is_duplicated']:
210
+ continue
211
+
212
+ if not include_dummies and g['labels'] == 'dummy':
213
+ print(f"Skipping dummy graph {g['ids']}")
214
+ continue
215
+
216
+ nxg = EcoreNxG(g)
217
+ self.graphs.append(nxg)
199
218
 
200
219
  print(f'Loaded Total {self.name} with {len(self.graphs)} graphs')
201
220
  print("Filtering...")
@@ -233,7 +252,8 @@ class ArchiMateDataset(ModelDataset):
233
252
  min_enr: float = -1,
234
253
  timeout=-1,
235
254
  language=None,
236
- preprocess_graph_text: callable = None
255
+ preprocess_graph_text: callable = None,
256
+ include_dummies=False
237
257
  ):
238
258
  super().__init__(
239
259
  dataset_name,
@@ -242,7 +262,8 @@ class ArchiMateDataset(ModelDataset):
242
262
  min_edges=min_edges,
243
263
  min_enr=min_enr,
244
264
  timeout=timeout,
245
- preprocess_graph_text=preprocess_graph_text
265
+ preprocess_graph_text=preprocess_graph_text,
266
+ include_dummies=include_dummies
246
267
  )
247
268
  os.makedirs(save_dir, exist_ok=True)
248
269
 
@@ -274,7 +295,7 @@ class ArchiMateDataset(ModelDataset):
274
295
 
275
296
  except Exception as e:
276
297
  raise e
277
-
298
+ print("Total graphs:", len(self.graphs))
278
299
  self.filter_graphs()
279
300
  self.save()
280
301
  else:
@@ -283,6 +304,7 @@ class ArchiMateDataset(ModelDataset):
283
304
  if remove_duplicates:
284
305
  self.dedup()
285
306
 
307
+ assert all([g.number_of_edges() >= min_edges for g in self.graphs]), f"Filtered out graphs with less than {min_edges} edges"
286
308
  print(f'Loaded {self.name} with {len(self.graphs)} graphs')
287
309
  print(f'Graphs: {len(self.graphs)}')
288
310
 
@@ -305,7 +327,8 @@ class OntoUMLDataset(ModelDataset):
305
327
  min_edges: int = -1,
306
328
  min_enr: float = -1,
307
329
  timeout=-1,
308
- preprocess_graph_text: callable = None
330
+ preprocess_graph_text: callable = None,
331
+ include_dummies=False
309
332
  ):
310
333
  super().__init__(
311
334
  dataset_name,
@@ -314,7 +337,8 @@ class OntoUMLDataset(ModelDataset):
314
337
  min_edges=min_edges,
315
338
  min_enr=min_enr,
316
339
  timeout=timeout,
317
- preprocess_graph_text=preprocess_graph_text
340
+ preprocess_graph_text=preprocess_graph_text,
341
+ include_dummies=include_dummies
318
342
  )
319
343
  os.makedirs(save_dir, exist_ok=True)
320
344