glam4cm 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. glam4cm/__init__.py +9 -0
  2. glam4cm/data_loading/__init__.py +0 -0
  3. glam4cm/data_loading/data.py +631 -0
  4. glam4cm/data_loading/encoding.py +76 -0
  5. glam4cm/data_loading/graph_dataset.py +940 -0
  6. glam4cm/data_loading/metadata.py +84 -0
  7. glam4cm/data_loading/models_dataset.py +361 -0
  8. glam4cm/data_loading/utils.py +20 -0
  9. glam4cm/downstream_tasks/__init__.py +0 -0
  10. glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
  11. glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
  12. glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
  13. glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
  14. glam4cm/downstream_tasks/bert_node_classification.py +164 -0
  15. glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
  16. glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
  17. glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
  18. glam4cm/downstream_tasks/common_args.py +160 -0
  19. glam4cm/downstream_tasks/create_dataset.py +51 -0
  20. glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
  21. glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
  22. glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
  23. glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
  24. glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
  25. glam4cm/downstream_tasks/utils.py +35 -0
  26. glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
  27. glam4cm/embeddings/__init__.py +0 -0
  28. glam4cm/embeddings/bert.py +72 -0
  29. glam4cm/embeddings/common.py +43 -0
  30. glam4cm/embeddings/fasttext.py +0 -0
  31. glam4cm/embeddings/tfidf.py +25 -0
  32. glam4cm/embeddings/w2v.py +41 -0
  33. glam4cm/encoding/__init__.py +0 -0
  34. glam4cm/encoding/common.py +0 -0
  35. glam4cm/encoding/encoders.py +100 -0
  36. glam4cm/graph2str/__init__.py +0 -0
  37. glam4cm/graph2str/common.py +34 -0
  38. glam4cm/graph2str/constants.py +15 -0
  39. glam4cm/graph2str/ontouml.py +141 -0
  40. glam4cm/graph2str/uml.py +0 -0
  41. glam4cm/lang2graph/__init__.py +0 -0
  42. glam4cm/lang2graph/archimate.py +31 -0
  43. glam4cm/lang2graph/bpmn.py +0 -0
  44. glam4cm/lang2graph/common.py +416 -0
  45. glam4cm/lang2graph/ecore.py +221 -0
  46. glam4cm/lang2graph/ontouml.py +169 -0
  47. glam4cm/lang2graph/utils.py +80 -0
  48. glam4cm/models/cmgpt.py +352 -0
  49. glam4cm/models/gnn_layers.py +273 -0
  50. glam4cm/models/hf.py +10 -0
  51. glam4cm/run.py +99 -0
  52. glam4cm/run_configs.py +126 -0
  53. glam4cm/settings.py +54 -0
  54. glam4cm/tokenization/__init__.py +0 -0
  55. glam4cm/tokenization/special_tokens.py +4 -0
  56. glam4cm/tokenization/utils.py +37 -0
  57. glam4cm/trainers/__init__.py +0 -0
  58. glam4cm/trainers/bert_classifier.py +105 -0
  59. glam4cm/trainers/cm_gpt_trainer.py +153 -0
  60. glam4cm/trainers/gnn_edge_classifier.py +126 -0
  61. glam4cm/trainers/gnn_graph_classifier.py +123 -0
  62. glam4cm/trainers/gnn_link_predictor.py +144 -0
  63. glam4cm/trainers/gnn_node_classifier.py +135 -0
  64. glam4cm/trainers/gnn_trainer.py +129 -0
  65. glam4cm/trainers/metrics.py +55 -0
  66. glam4cm/utils.py +194 -0
  67. glam4cm-0.1.0.dist-info/LICENSE +21 -0
  68. glam4cm-0.1.0.dist-info/METADATA +86 -0
  69. glam4cm-0.1.0.dist-info/RECORD +72 -0
  70. glam4cm-0.1.0.dist-info/WHEEL +5 -0
  71. glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
  72. glam4cm-0.1.0.dist-info/top_level.txt +1 -0
glam4cm/data_loading/graph_dataset.py
@@ -0,0 +1,940 @@
+ import json
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from sklearn.preprocessing import LabelEncoder
+ from collections import Counter, defaultdict
+ import os
+ from random import shuffle
+ from typing import List, Union
+ from sklearn.model_selection import StratifiedKFold
+ import torch
+ import numpy as np
+ from scipy.sparse import csr_matrix
+ from transformers import AutoTokenizer
+ from glam4cm.data_loading.data import TorchEdgeGraph, TorchGraph, TorchNodeGraph
+ from glam4cm.data_loading.models_dataset import ArchiMateDataset, EcoreDataset, OntoUMLDataset
+ from glam4cm.data_loading.encoding import EncodingDataset, GPTTextDataset
+ from tqdm.auto import tqdm
+ from glam4cm.embeddings.w2v import Word2VecEmbedder
+ from glam4cm.embeddings.tfidf import TfidfEmbedder
+ from glam4cm.embeddings.common import get_embedding_model
+ from glam4cm.lang2graph.common import LangGraph, get_node_data, get_edge_data
+ from glam4cm.data_loading.metadata import ArchimateMetaData, EcoreMetaData, OntoUMLMetaData
+ from glam4cm.settings import seed
+ from glam4cm.settings import (
+     LP_TASK_EDGE_CLS,
+     LP_TASK_LINK_PRED,
+ )
+ import glam4cm.utils as utils
+ 
+ 
+ def exclude_labels_from_data(X, labels, exclude_labels):
+     X = [X[i] for i in range(len(X)) if labels[i] not in exclude_labels]
+     labels = [labels[i] for i in range(len(labels)) if labels[i] not in exclude_labels]
+     return X, labels
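exclude_labels_from_data filters a parallel pair of sample and label lists, dropping every position whose label is in the exclusion set. A minimal sketch of the behavior with made-up values, not package data:

    X = ['EClass', 'EAttribute', 'EReference']
    labels = ['type_a', None, 'type_a']
    X, labels = exclude_labels_from_data(X, labels, exclude_labels=[None, ''])
    # X == ['EClass', 'EReference'], labels == ['type_a', 'type_a']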
+ def validate_classes(torch_graphs: List[TorchGraph], label, exclude_labels, element):
+     train_labels, test_labels = list(), list()
+     train_idx_label = f"train_{element}_mask"
+     test_idx_label = f"test_{element}_mask"
+ 
+     for torch_graph in tqdm(torch_graphs, desc=f"Validating {element} classes"):
+         labels = getattr(torch_graph.data, f"{element}_{label}")
+         train_idx = getattr(torch_graph.data, train_idx_label)
+         test_idx = getattr(torch_graph.data, test_idx_label)
+         indices = np.nonzero(np.isin(labels, exclude_labels))[0]
+ 
+         if len(indices) > 0:
+             train_idx = train_idx[~np.isin(train_idx, indices)]
+             test_idx = test_idx[~np.isin(test_idx, indices)]
+ 
+         train_labels.append(labels[train_idx])
+         test_labels.append(labels[test_idx])
+ 
+         edge_classes = [t for _, _, t in torch_graph.graph.edges(data=label)]
+         node_classes = [t for _, t in torch_graph.graph.nodes(data=label)]
+ 
+         if element == 'edge':
+             for idx in train_idx.tolist() + test_idx.tolist():
+                 t_c = edge_classes[idx]
+                 edge = torch_graph.graph.idx_to_edge[idx]
+                 assert torch_graph.graph.numbered_graph.edges[edge][label] == t_c, f"Edge {label} mismatch for {edge}"
+         elif element == 'node':
+             for idx in train_idx.tolist() + test_idx.tolist():
+                 t_c = node_classes[idx]
+                 node_label = torch_graph.graph.id_to_node_label[idx]
+                 assert torch_graph.graph.nodes[node_label][label] == t_c, f"Node {label} mismatch for {idx}"
+         else:
+             raise ValueError(f"Invalid element: {element}")
+ 
+     train_classes = set(sum([
+         getattr(torch_graph.data, f"{element}_{label}")[getattr(torch_graph.data, train_idx_label)].tolist()
+         for torch_graph in torch_graphs], []
+     ))
+     test_classes = set(sum([
+         getattr(torch_graph.data, f"{element}_{label}")[getattr(torch_graph.data, test_idx_label)].tolist()
+         for torch_graph in torch_graphs], []
+     ))
+     num_train_classes = len(train_classes)
+     num_test_classes = len(test_classes)
+     print("Train classes:", train_classes)
+     print("Test classes:", test_classes)
+     print(f"Number of classes in training set: {num_train_classes}")
+     print(f"Number of classes in test set: {num_test_classes}")
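The index filtering in validate_classes uses np.isin twice: once to find the positions that carry an excluded label, and once to drop those positions from the train/test index arrays. The same two-step in isolation, with toy arrays:

    import numpy as np

    labels = np.array([0, 2, 1, 2, 0])   # class id per element
    train_idx = np.array([0, 1, 4])      # candidate training positions
    exclude_labels = [2]

    indices = np.nonzero(np.isin(labels, exclude_labels))[0]  # positions with excluded labels -> [1, 3]
    train_idx = train_idx[~np.isin(train_idx, indices)]       # keep only clean positions -> [0, 4]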
+ class GraphDataset(torch.utils.data.Dataset):
+     def __init__(
+         self,
+         models_dataset: Union[EcoreDataset, ArchiMateDataset],
+         save_dir='datasets/graph_data',
+         distance=1,
+         add_negative_train_samples=False,
+         neg_sampling_ratio=1,
+         use_attributes=False,
+         use_edge_types=False,
+         use_node_types=False,
+         use_edge_label=False,
+         no_labels=False,
+ 
+         node_cls_label=None,
+         edge_cls_label=None,
+ 
+         test_ratio=0.2,
+ 
+         use_embeddings=False,
+         use_special_tokens=False,
+         embed_model_name='bert-base-uncased',
+         ckpt=None,
+         reload=False,
+         no_shuffle=False,
+ 
+         randomize_ne=False,
+         randomize_ee=False,
+         random_embed_dim=128,
+ 
+         exclude_labels: tuple = (None, ''),
+     ):
+         if isinstance(models_dataset, EcoreDataset):
+             self.metadata = EcoreMetaData()
+         elif isinstance(models_dataset, ArchiMateDataset):
+             self.metadata = ArchimateMetaData()
+         elif isinstance(models_dataset, OntoUMLDataset):
+             self.metadata = OntoUMLMetaData()
+ 
+         self.distance = distance
+         self.use_embeddings = use_embeddings
+         self.ckpt = ckpt
+         self.embedder = get_embedding_model(embed_model_name, ckpt) if use_embeddings else None
+ 
+         self.reload = reload
+ 
+         self.use_edge_types = use_edge_types
+         self.use_node_types = use_node_types
+         self.use_attributes = use_attributes
+         self.use_edge_label = use_edge_label
+         self.no_labels = no_labels
+ 
+         self.add_negative_train_samples = add_negative_train_samples
+         self.neg_sampling_ratio = neg_sampling_ratio
+ 
+         self.test_ratio = test_ratio
+ 
+         self.no_shuffle = no_shuffle
+         self.exclude_labels = exclude_labels
+ 
+         self.use_special_tokens = use_special_tokens
+         self.node_cls_label = node_cls_label
+         self.edge_cls_label = edge_cls_label
+ 
+         self.randomize_ne = randomize_ne
+         self.randomize_ee = randomize_ee
+         self.random_embed_dim = random_embed_dim
+ 
+         self.graphs: List[Union[TorchNodeGraph, TorchEdgeGraph]] = []
+         self.config = dict(
+             name=models_dataset.name,
+             distance=distance,
+             add_negative_train_samples=add_negative_train_samples,
+             neg_sampling_ratio=neg_sampling_ratio,
+             use_attributes=use_attributes,
+             use_edge_types=use_edge_types,
+             use_node_types=use_node_types,
+             use_edge_label=use_edge_label,
+             no_labels=no_labels,
+             use_special_tokens=use_special_tokens,
+             use_embeddings=use_embeddings,
+             embed_model_name=embed_model_name if use_embeddings else None,
+             ckpt=ckpt if use_embeddings else None,
+             exclude_labels=exclude_labels,
+             node_cls_label=node_cls_label,
+             edge_cls_label=edge_cls_label,
+             test_ratio=test_ratio,
+             randomize_ne=randomize_ne,
+             randomize_ee=randomize_ee,
+             random_embed_dim=random_embed_dim
+         )
+ 
+         self.save_dir = os.path.join(save_dir, models_dataset.name)
+         os.makedirs(self.save_dir, exist_ok=True)
+ 
+     @property
+     def node_dim(self):
+         pass
+ 
+     def get_config_hash(self):
+         if os.path.exists(os.path.join(self.save_dir, 'configs.json')):
+             with open(os.path.join(self.save_dir, 'configs.json'), 'r') as f:
+                 configs = json.load(f)
+         else:
+             configs = dict()
+ 
+         config_hash = utils.md5_hash(str(self.config))
+         if config_hash not in configs:
+             configs[config_hash] = self.config
+             with open(os.path.join(self.save_dir, 'configs.json'), 'w') as f:
+                 json.dump(configs, f)
+ 
+         return config_hash
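get_config_hash keys the on-disk cache by an MD5 digest of the stringified config dict, and appends any new digest to configs.json so a cache directory can be traced back to the settings that produced it. utils.md5_hash is defined in glam4cm/utils.py; assuming it is the usual hashlib one-liner, an equivalent sketch is:

    import hashlib

    def md5_hash(s: str) -> str:
        # 32-character hex digest of the UTF-8 encoded string
        return hashlib.md5(s.encode('utf-8')).hexdigest()

    config = dict(name='modelset', distance=1, test_ratio=0.2)  # illustrative values
    cache_key = md5_hash(str(config))  # stable across runs for the same dict contents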
+     def set_file_hashes(self, models_dataset: Union[EcoreDataset, ArchiMateDataset]):
+         self.config_hash = self.get_config_hash()
+         os.makedirs(os.path.join(self.save_dir, self.config_hash), exist_ok=True)
+         self.file_paths = {
+             graph.hash: os.path.join(self.save_dir, self.config_hash, f'{graph.hash}', 'data.pkl')
+             for graph in models_dataset
+         }
+ 
+         print("Number of duplicate graphs: ", len(models_dataset) - len(self.file_paths))
+ 
+     def set_torch_graphs(
+         self,
+         type: str,
+         models_dataset: Union[EcoreDataset, ArchiMateDataset],
+         limit: int = -1
+     ):
+         common_params = dict(
+             metadata=self.metadata,
+             distance=self.distance,
+             test_ratio=self.test_ratio,
+             use_attributes=self.use_attributes,
+             use_node_types=self.use_node_types,
+             use_edge_types=self.use_edge_types,
+             use_edge_label=self.use_edge_label,
+             use_special_tokens=self.use_special_tokens,
+             no_labels=self.no_labels,
+             node_cls_label=self.node_cls_label,
+             edge_cls_label=self.edge_cls_label
+         )
+ 
+         def create_node_graph(graph: LangGraph, fp: str) -> TorchNodeGraph:
+             node_params = {
+                 **common_params,
+                 'fp': fp,
+             }
+             torch_graph = TorchNodeGraph(graph, **node_params)
+             return torch_graph
+ 
+         def create_edge_graph(graph: LangGraph, fp: str) -> TorchEdgeGraph:
+             edge_params = {
+                 **common_params,
+                 'add_negative_train_samples': self.add_negative_train_samples,
+                 'neg_samples_ratio': self.neg_sampling_ratio,
+                 'fp': fp,
+             }
+             torch_graph = TorchEdgeGraph(graph, **edge_params)
+             return torch_graph
+ 
+         models_size = len(models_dataset) \
+             if (limit == -1 or limit > len(models_dataset)) else limit
+ 
+         self.set_file_hashes(models_dataset[:models_size])
+ 
+         for graph in tqdm(models_dataset[:models_size], desc=f'Creating {type} graphs'):
+             fp = self.file_paths[graph.hash]
+             if not os.path.exists(fp) or self.reload:
+                 if type == 'node':
+                     torch_graph: TorchNodeGraph = create_node_graph(graph, fp)
+                 elif type == 'edge':
+                     torch_graph: TorchEdgeGraph = create_edge_graph(graph, fp)
+                 else:
+                     raise ValueError(f"Invalid graph type: {type}")
+ 
+                 torch_graph.save()
+ 
+     def embed(self):
+         for fp in tqdm(self.file_paths.values(), desc='Embedding graphs'):
+             torch_graph = TorchGraph.load(fp)
+             torch_graph.embed(
+                 self.embedder,
+                 reload=self.reload,
+                 randomize_ne=self.randomize_ne,
+                 randomize_ee=self.randomize_ee,
+                 random_embed_dim=self.random_embed_dim
+             )
+ 
+         for fp in tqdm(self.file_paths.values(), desc='Re-Loading graphs'):
+             torch_graph = TorchGraph.load(fp)
+             self.graphs.append(torch_graph)
+ 
+         if not self.no_shuffle:
+             shuffle(self.graphs)
+ 
+         self.post_process_graphs()
+         self.validate_graphs()
+         # self.save()
+         print("Graphs loaded and validated")
+ 
+     def post_process_graphs(self):
+         self.add_cls_labels()
+ 
+         def set_types(prefix):
+             def concatenate(a, b):
+                 if isinstance(a, csr_matrix):
+                     return csr_matrix(np.concatenate([a.toarray(), b], axis=1))
+                 return np.concatenate([a, b], axis=1)
+ 
+             prefix_cls = getattr(self, f"{prefix}_cls_label")
+             num_classes = getattr(self, f"num_{prefix}s_{prefix_cls}") + 1
+             # print(f"Number of {prefix} types: {num_classes}")
+             for g in self.graphs:
+                 types = np.eye(num_classes)[getattr(g.data, f"{prefix}_{prefix_cls}")]
+ 
+                 if prefix == 'node':
+                     g.data.x = concatenate(g.data.x, types)
+                 elif prefix == 'edge':
+                     g.data.edge_attr = concatenate(g.data.edge_attr, types)
+ 
+                 # g.save()
+ 
+             node_dim = self.graphs[0].data.x.shape[1]
+             assert all(g.data.x.shape[1] == node_dim for g in self.graphs), "Node types not added correctly"
+             edge_dim = self.graphs[0].data.edge_attr.shape[1]
+             assert all(g.data.edge_attr.shape[1] == edge_dim for g in self.graphs), "Edge types not added correctly"
+ 
+         if self.use_node_types and self.node_cls_label:
+             set_types('node')
+ 
+         if self.use_edge_types and self.edge_cls_label:
+             set_types('edge')
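set_types widens each feature matrix with a one-hot encoding of the element's class, implemented as a row lookup into an identity matrix. The trick in isolation, with toy shapes rather than the package's real feature dimensions:

    import numpy as np

    x = np.random.rand(4, 8)                # 4 nodes, 8 embedding dims
    node_cls = np.array([0, 2, 1, 2])       # encoded class per node
    num_classes = node_cls.max() + 2        # +1 for 0-based size, +1 headroom as in set_types
    types = np.eye(num_classes)[node_cls]   # (4, num_classes) one-hot rows
    x = np.concatenate([x, types], axis=1)  # features become (4, 8 + num_classes)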
+     def __len__(self):
+         return len(self.graphs)
+ 
+     def __getitem__(self, index: int):
+         return self.graphs[index]
+ 
+     def get_torch_dataset(self):
+         return [g.data.to_graph_data() for g in self.graphs]
+ 
+     def save(self):
+         for torch_graph in self.graphs:
+             torch_graph.save()
+ 
+     def get_train_test_split(self):
+         n = len(self.graphs)
+         train_size = int(n * (1 - self.test_ratio))
+         idx = list(range(n))
+         shuffle(idx)
+         train_idx = idx[:train_size]
+         test_idx = idx[train_size:]
+         return train_idx, test_idx
+ 
+     def k_fold_split(self):
+         k = int(1 / self.test_ratio)
+         kfold = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
+         n = len(self.graphs)
+         for train_idx, test_idx in kfold.split(np.zeros(n), np.zeros(n)):
+             yield train_idx, test_idx
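k_fold_split derives the fold count from the test ratio, so the default test_ratio=0.2 gives k = int(1 / 0.2) = 5 folds of roughly 20% each. Note that the labels passed to StratifiedKFold are all zeros, so stratification is a no-op and the split behaves like a plain shuffled KFold over graphs. A usage sketch, assuming a dataset whose graphs are already loaded:

    for fold, (train_idx, test_idx) in enumerate(dataset.k_fold_split()):
        print(f"fold {fold}: {len(train_idx)} train / {len(test_idx)} test graphs")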
+     def validate_graphs(self):
+         for torch_graph in self.graphs:
+             assert torch_graph.data.x.shape[0] == torch_graph.graph.number_of_nodes(), \
+                 f"Number of nodes mismatch, {torch_graph.data.x.shape[0]} != {torch_graph.graph.number_of_nodes()}"
+ 
+             if isinstance(torch_graph, TorchEdgeGraph):
+                 assert torch_graph.data.overall_edge_index.shape[1] == torch_graph.graph.number_of_edges(), \
+                     f"Number of edges mismatch, {torch_graph.data.overall_edge_index.shape[1]} != {torch_graph.graph.number_of_edges()}"
+             else:
+                 assert torch_graph.data.edge_index.shape[1] == torch_graph.graph.number_of_edges(), \
+                     f"Number of edges mismatch, {torch_graph.data.edge_index.shape[1]} != {torch_graph.graph.number_of_edges()}"
+ 
+     def add_cls_labels(self):
+         self.add_node_labels()
+         self.add_edge_labels()
+         self.add_graph_labels()
+         self.add_graph_text()
+ 
+     def add_node_labels(self):
+         model_type = self.metadata.type
+         cls_labels = self.metadata.node_cls
+         if isinstance(cls_labels, str):
+             cls_labels = [cls_labels]
+ 
+         for cls_label in cls_labels:
+             label_values = list()
+             for torch_graph in self.graphs:
+                 values = list()
+                 for _, node in torch_graph.graph.nodes(data=True):
+                     node_data = get_node_data(node, cls_label, model_type)
+                     values.append(node_data)
+ 
+                 label_values.append(values)
+ 
+             node_label_map = LabelEncoder()
+             node_label_map.fit([j for i in label_values for j in i])
+             label_values = [node_label_map.transform(i) for i in label_values]
+             print("Node Classes: ", node_label_map.classes_)
+ 
+             for torch_graph, node_classes in zip(self.graphs, label_values):
+                 setattr(torch_graph.data, f"node_{cls_label}", np.array(node_classes))
+ 
+             setattr(self, f"node_label_map_{cls_label}", node_label_map)
+ 
+             exclude_labels = [
+                 node_label_map.transform([e])[0]
+                 for e in self.exclude_labels
+                 if e in node_label_map.classes_
+             ]
+             setattr(self, f"node_exclude_{cls_label}", exclude_labels)
+ 
+             num_labels = len(node_label_map.classes_) - len(exclude_labels)
+ 
+             print("Setting num_nodes_", cls_label, num_labels)
+             setattr(self, f"num_nodes_{cls_label}", num_labels)
+ 
+             if hasattr(self.graphs[0].data, 'train_node_mask'):
+                 validate_classes(self.graphs, cls_label, exclude_labels, 'node')
+ 
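add_node_labels fits a single LabelEncoder over the label values of every graph before re-encoding each graph, so a given class name maps to the same integer id across the whole dataset. The pattern in miniature, with toy label lists:

    from sklearn.preprocessing import LabelEncoder

    per_graph = [['Class', 'Attribute'], ['Class', 'Reference']]
    le = LabelEncoder()
    le.fit([v for graph in per_graph for v in graph])       # shared vocabulary
    encoded = [le.transform(graph) for graph in per_graph]  # 'Class' -> same id everywhere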
+     def add_edge_labels(self):
+         model_type = self.metadata.type
+         cls_labels = self.metadata.edge_cls
+         if isinstance(cls_labels, str):
+             cls_labels = [cls_labels]
+ 
+         for cls_label in cls_labels:
+             label_values = list()
+             for torch_graph in self.graphs:
+                 values = list()
+                 for _, _, edge_data in torch_graph.graph.edges(data=True):
+                     values.append(get_edge_data(edge_data, cls_label, model_type))
+                 label_values.append(values)
+ 
+             edge_label_map = LabelEncoder()
+             edge_label_map.fit([j for i in label_values for j in i])
+             label_values = [edge_label_map.transform(i) for i in label_values]
+             print("Edge Classes: ", edge_label_map.classes_)
+ 
+             for torch_graph, edge_classes in zip(self.graphs, label_values):
+                 setattr(torch_graph.data, f"edge_{cls_label}", np.array(edge_classes))
+ 
+             setattr(self, f"edge_label_map_{cls_label}", edge_label_map)
+ 
+             exclude_labels = [
+                 edge_label_map.transform([e])[0]
+                 for e in self.exclude_labels
+                 if e in edge_label_map.classes_
+             ]
+             setattr(self, f"edge_exclude_{cls_label}", exclude_labels)
+ 
+             num_labels = len(edge_label_map.classes_) - len(exclude_labels)
+             setattr(self, f"num_edges_{cls_label}", num_labels)
+ 
+             if hasattr(self.graphs[0].data, 'train_edge_mask'):
+                 validate_classes(self.graphs, cls_label, exclude_labels, 'edge')
+ 
+     def add_graph_labels(self):
+         if hasattr(self.metadata, 'graph'):
+             cls_label = self.metadata.graph_cls
+             if cls_label and hasattr(self.graphs[0].graph, cls_label):
+                 graph_labels = list()
+                 for torch_graph in self.graphs:
+                     graph_labels.append(getattr(torch_graph.graph, cls_label))
+ 
+                 graph_label_map = LabelEncoder()
+                 graph_labels = graph_label_map.fit_transform(graph_labels)
+ 
+                 print("Graph Classes: ", graph_label_map.classes_)
+ 
+                 for torch_graph, graph_label in zip(self.graphs, graph_labels):
+                     setattr(torch_graph.data, f"graph_{cls_label}", np.array([graph_label]))
+ 
+                 exclude_labels = [
+                     graph_label_map.transform([e])[0]
+                     for e in self.exclude_labels
+                     if e in graph_label_map.classes_
+                 ]
+                 setattr(self, f"graph_exclude_{cls_label}", exclude_labels)
+                 num_labels = len(graph_label_map.classes_) - len(exclude_labels)
+ 
+                 print("Setting num_graph_", cls_label, num_labels)
+                 setattr(self, f"num_graph_{cls_label}", num_labels)
+                 setattr(self, f"graph_label_map_{cls_label}", graph_label_map)
+ 
+     def add_graph_text(self):
+         label = self.metadata.graph_label
+         if label:
+             for torch_graph in self.graphs:
+                 setattr(torch_graph, label, getattr(torch_graph.graph, label))
+ 
+     def get_gnn_graph_classification_data(self):
+         train_idx, test_idx = self.get_train_test_split()
+         train_data = [self.graphs[i].data for i in train_idx]
+         test_data = [self.graphs[i].data for i in test_idx]
+         return {
+             'train': train_data,
+             'test': test_data
+         }
+ 
+     def get_kfold_gnn_graph_classification_data(self):
+         for train_idx, test_idx in self.k_fold_split():
+             train_data = [self.graphs[i].data.to_graph_data() for i in train_idx]
+             test_data = [self.graphs[i].data.to_graph_data() for i in test_idx]
+             yield {
+                 'train': train_data,
+                 'test': test_data,
+             }
+ 
+     def __get_lm_data(self, indices, tokenizer, remove_duplicates=False):
+         graph_label_name = self.metadata.graph_cls
+         assert graph_label_name is not None, "No graph label found in data. Please define the graph label in the metadata"
+         X = [getattr(self.graphs[i], self.metadata.graph_label) for i in indices]
+         y = [getattr(self.graphs[i].data, f'graph_{graph_label_name}')[0].item() for i in indices]
+ 
+         dataset = EncodingDataset(tokenizer, X, y, remove_duplicates=remove_duplicates)
+         graph_label_map = getattr(self, f'graph_label_map_{graph_label_name}')
+         print("\n".join([f"Label: {graph_label_map.inverse_transform([l])[0]}, Text: {i}" for i, l in zip(X[:10], y[:10])]))
+ 
+         return dataset
+ 
+     def get_lm_graph_classification_data(self, tokenizer):
+         assert self.metadata.graph_cls, "No graph label found in data. Please define the graph label in the metadata"
+         train_idx, test_idx = self.get_train_test_split()
+         train_dataset = self.__get_lm_data(train_idx, tokenizer)
+         test_dataset = self.__get_lm_data(test_idx, tokenizer)
+ 
+         return {
+             'train': train_dataset,
+             'test': test_dataset,
+             'num_classes': getattr(self, f'num_graph_{self.metadata.graph_cls}')
+         }
+ 
+     def get_kfold_lm_graph_classification_data(self, tokenizer, remove_duplicates=True):
+         assert self.metadata.graph_cls, "No graph label found in data. Please define the graph label in the metadata"
+         for train_idx, test_idx in self.k_fold_split():
+             train_dataset = self.__get_lm_data(train_idx, tokenizer, remove_duplicates)
+             test_dataset = self.__get_lm_data(test_idx, tokenizer, remove_duplicates)
+             yield {
+                 'train': train_dataset,
+                 'test': test_dataset,
+                 'num_classes': getattr(self, f'num_graph_{self.metadata.graph_cls}')
+             }
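Putting the graph-classification path together: a hedged usage sketch, assuming a labeled models_dataset (one whose metadata defines graph_cls) and a stock HuggingFace tokenizer; neither is constructed in this file:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    dataset = GraphNodeDataset(models_dataset)  # any GraphDataset subclass over a labeled corpus
    for fold in dataset.get_kfold_lm_graph_classification_data(tokenizer):
        train_ds, test_ds = fold['train'], fold['test']  # EncodingDataset objects
        num_classes = fold['num_classes']
        # hand train_ds/test_ds to a sequence-classification Trainer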
+ class GraphEdgeDataset(GraphDataset):
+     def __init__(
+         self,
+         models_dataset: Union[EcoreDataset, ArchiMateDataset],
+         save_dir='datasets/graph_data',
+         distance=0,
+         reload=False,
+         test_ratio=0.2,
+ 
+         add_negative_train_samples=False,
+         neg_sampling_ratio=1,
+ 
+         use_attributes=False,
+         use_edge_types=False,
+         use_edge_label=False,
+         use_node_types=False,
+         no_labels=False,
+ 
+         use_embeddings=False,
+         embed_model_name='bert-base-uncased',
+         ckpt=None,
+ 
+         no_shuffle=False,
+         randomize_ne=False,
+         randomize_ee=False,
+         random_embed_dim=128,
+ 
+         use_special_tokens=False,
+ 
+         limit: int = -1,
+ 
+         node_cls_label: str = None,
+         edge_cls_label: str = None,
+ 
+         task_type=LP_TASK_EDGE_CLS
+     ):
+         super().__init__(
+             models_dataset=models_dataset,
+             save_dir=save_dir,
+             distance=distance,
+             test_ratio=test_ratio,
+ 
+             use_node_types=use_node_types,
+             use_edge_types=use_edge_types,
+             use_edge_label=use_edge_label,
+             use_attributes=use_attributes,
+             no_labels=no_labels,
+ 
+             add_negative_train_samples=add_negative_train_samples,
+             neg_sampling_ratio=neg_sampling_ratio,
+ 
+             use_special_tokens=use_special_tokens,
+ 
+             node_cls_label=node_cls_label,
+             edge_cls_label=edge_cls_label,
+ 
+             use_embeddings=use_embeddings,
+             embed_model_name=embed_model_name,
+             ckpt=ckpt,
+ 
+             reload=reload,
+             no_shuffle=no_shuffle,
+ 
+             randomize_ne=randomize_ne,
+             randomize_ee=randomize_ee,
+             random_embed_dim=random_embed_dim,
+         )
+ 
+         self.task_type = task_type
+ 
+         self.set_torch_graphs('edge', models_dataset, limit)
+ 
+         if self.use_embeddings and isinstance(self.embedder, (Word2VecEmbedder, TfidfEmbedder)):
+             texts = self.get_link_prediction_texts(only_texts=True)
+             texts = sum([v for k, v in texts.items() if not k.endswith("classes")], [])
+             print(f"Training {self.embedder.name} Embedder")
+             self.embedder.train(texts)
+             print(f"Trained {self.embedder.name} Embedder")
+ 
+         self.embed()
+ 
+         train_count, test_count = dict(), dict()
+         for g in self.graphs:
+             train_mask = g.data.train_edge_mask
+             test_mask = g.data.test_edge_mask
+             train_labels = getattr(g.data, f'edge_{self.metadata.edge_cls}')[train_mask]
+             test_labels = getattr(g.data, f'edge_{self.metadata.edge_cls}')[test_mask]
+             t1 = dict(Counter(train_labels.tolist()))
+             t2 = dict(Counter(test_labels.tolist()))
+             for k in t1:
+                 train_count[k] = train_count.get(k, 0) + t1[k]
+ 
+             for k in t2:
+                 test_count[k] = test_count.get(k, 0) + t2[k]
+ 
+         print(f"Train edge classes: {train_count}")
+         print(f"Test edge classes: {test_count}")
+ 
+     def get_link_prediction_texts(
+         self,
+         label: str = None,
+         only_texts: bool = False
+     ):
+         if label is None:
+             label = self.edge_cls_label
+ 
+         assert label is not None, "No edge label found in data. Please define the edge label in the metadata"
+ 
+         data = defaultdict(list)
+         for torch_graph in tqdm(self.graphs, desc='Getting Graph Texts'):
+             # torch_graph: TorchEdgeGraph = TorchGraph.load(fp)
+             graph_data = torch_graph.get_link_prediction_texts(label, self.task_type, only_texts)
+             for k, v in graph_data.items():
+                 data[k] += v
+ 
+         print("Train Texts: ", data['train_pos_edges'][:20])
+         print("Test Texts: ", data['test_pos_edges'][:20])
+ 
+         # print("Train Classes", edge_label_map.inverse_transform([i.item() for i in data['train_edge_classes'][:20]]))
+         # print("Test Classes", edge_label_map.inverse_transform([i.item() for i in data['test_edge_classes'][:20]]))
+         return data
+ 
+     def get_link_prediction_lm_data(
+         self,
+         tokenizer: AutoTokenizer,
+         label: str = None,
+     ):
+         if label is None:
+             label = self.edge_cls_label
+ 
+         data = self.get_link_prediction_texts(label)
+ 
+         print("Tokenizing data")
+         if self.task_type == LP_TASK_EDGE_CLS:
+             datasets = {
+                 'train': EncodingDataset(
+                     tokenizer,
+                     data['train_pos_edges'],
+                     data['train_edge_classes']
+                 ),
+                 'test': EncodingDataset(
+                     tokenizer,
+                     data['test_pos_edges'],
+                     data['test_edge_classes']
+                 )
+             }
+         elif self.task_type == LP_TASK_LINK_PRED:
+             datasets = {
+                 'train': EncodingDataset(
+                     tokenizer,
+                     data['train_pos_edges'] + data['train_neg_edges'],
+                     [1] * len(data['train_pos_edges']) + [0] * len(data['train_neg_edges'])
+                 ),
+                 'test': EncodingDataset(
+                     tokenizer,
+                     data['test_pos_edges'] + data['test_neg_edges'],
+                     [1] * len(data['test_pos_edges']) + [0] * len(data['test_neg_edges'])
+                 )
+             }
+         else:
+             raise ValueError(f"Invalid task type: {self.task_type}")
+ 
+         print("Tokenized data")
+ 
+         return datasets
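For LP_TASK_LINK_PRED the splits above label positive edge texts 1 and negative samples 0, while LP_TASK_EDGE_CLS reuses the encoded edge classes as targets. A usage sketch under the same assumptions as before (the edge_cls_label value depends on the dataset's metadata):

    edge_dataset = GraphEdgeDataset(
        models_dataset,
        edge_cls_label='type',          # assumed label name
        task_type=LP_TASK_LINK_PRED,
        add_negative_train_samples=True,
    )
    splits = edge_dataset.get_link_prediction_lm_data(tokenizer)
    train_ds, test_ds = splits['train'], splits['test']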
+ class GraphNodeDataset(GraphDataset):
+     def __init__(
+         self,
+         models_dataset: Union[EcoreDataset, ArchiMateDataset],
+         save_dir='datasets/graph_data',
+         distance=0,
+         test_ratio=0.2,
+         reload=False,
+ 
+         use_attributes=False,
+         use_edge_types=False,
+         use_node_types=False,
+         use_edge_label=False,
+         use_special_tokens=False,
+ 
+         use_embeddings=False,
+         embed_model_name='bert-base-uncased',
+         ckpt=None,
+ 
+         no_shuffle=False,
+         randomize_ne=False,
+         randomize_ee=False,
+         random_embed_dim=128,
+ 
+         limit: int = -1,
+         no_labels=False,
+         node_cls_label: str = None,
+         edge_cls_label: str = None
+     ):
+         super().__init__(
+             models_dataset=models_dataset,
+             save_dir=save_dir,
+             distance=distance,
+             test_ratio=test_ratio,
+ 
+             use_node_types=use_node_types,
+             use_edge_types=use_edge_types,
+             use_edge_label=use_edge_label,
+             use_attributes=use_attributes,
+             no_labels=no_labels,
+ 
+             node_cls_label=node_cls_label,
+             edge_cls_label=edge_cls_label,
+ 
+             use_embeddings=use_embeddings,
+             embed_model_name=embed_model_name,
+             ckpt=ckpt,
+ 
+             reload=reload,
+             no_shuffle=no_shuffle,
+ 
+             use_special_tokens=use_special_tokens,
+ 
+             randomize_ne=randomize_ne,
+             randomize_ee=randomize_ee,
+             random_embed_dim=random_embed_dim,
+         )
+ 
+         self.set_torch_graphs('node', models_dataset, limit)
+ 
+         if self.use_embeddings and isinstance(self.embedder, (Word2VecEmbedder, TfidfEmbedder)):
+             texts = self.get_node_classification_texts()
+             texts = sum([v for k, v in texts.items() if not k.endswith("classes")], [])
+             print(f"Training {self.embedder.name} Embedder")
+             self.embedder.train(texts)
+             print(f"Trained {self.embedder.name} Embedder")
+ 
+         self.embed()
+ 
+         node_labels = self.metadata.node_cls
+         if isinstance(node_labels, str):
+             node_labels = [node_labels]
+ 
+         for node_label in node_labels:
+             print(f"Node label: {node_label}")
+             train_count, test_count = dict(), dict()
+             for g in self.graphs:
+                 train_idx = g.data.train_node_mask
+                 test_idx = g.data.test_node_mask
+                 train_labels = getattr(g.data, f'node_{node_label}')[train_idx]
+                 test_labels = getattr(g.data, f'node_{node_label}')[test_idx]
+                 t1 = dict(Counter(train_labels.tolist()))
+                 t2 = dict(Counter(test_labels.tolist()))
+                 for k in t1:
+                     train_count[k] = train_count.get(k, 0) + t1[k]
+ 
+                 for k in t2:
+                     test_count[k] = test_count.get(k, 0) + t2[k]
+ 
+             print(f"Train Node classes: {train_count}")
+             print(f"Test Node classes: {test_count}")
+ 
+     def get_node_classification_texts(self, distance=None, label=None):
+         if distance is None:
+             distance = self.distance
+ 
+         label = self.metadata.node_cls if label is None else label
+ 
+         if isinstance(label, list):
+             label = label[0]
+ 
+         data = {'train_nodes': [], 'train_node_classes': [], 'test_nodes': [], 'test_node_classes': []}
+         for torch_graph in tqdm(self.graphs, desc='Getting node classification data'):
+             # graph: TorchNodeGraph = TorchGraph.load(fp)
+             node_strs = list(torch_graph.get_graph_node_strs(torch_graph.data.edge_index, distance).values())
+ 
+             train_node_strs = [node_strs[i.item()] for i in torch_graph.data.train_node_mask]
+             test_node_strs = [node_strs[i.item()] for i in torch_graph.data.test_node_mask]
+ 
+             train_node_classes = getattr(torch_graph.data, f'node_{label}')[torch_graph.data.train_node_mask]
+             test_node_classes = getattr(torch_graph.data, f'node_{label}')[torch_graph.data.test_node_mask]
+ 
+             exclude_labels = getattr(self, f'node_exclude_{label}')
+             train_node_strs, train_node_classes = exclude_labels_from_data(train_node_strs, train_node_classes, exclude_labels)
+             test_node_strs, test_node_classes = exclude_labels_from_data(test_node_strs, test_node_classes, exclude_labels)
+ 
+             data['train_nodes'] += train_node_strs
+             data['train_node_classes'] += train_node_classes
+             data['test_nodes'] += test_node_strs
+             data['test_node_classes'] += test_node_classes
+ 
+         print(data['train_nodes'][:10])
+         print(data['test_nodes'][:10])
+         print(len(data['train_nodes']))
+         print(len(data['train_node_classes']))
+         print(len(data['test_nodes']))
+         print(len(data['test_node_classes']))
+         return data
+ 
+     def get_node_classification_lm_data(
+         self,
+         label: str,
+         tokenizer: AutoTokenizer,
+         distance: int = 0,
+     ):
+         data = self.get_node_classification_texts(distance, label)
+         print("Tokenizing data")
+         dataset = {
+             'train': EncodingDataset(
+                 tokenizer,
+                 data['train_nodes'],
+                 data['train_node_classes']
+             ),
+             'test': EncodingDataset(
+                 tokenizer,
+                 data['test_nodes'],
+                 data['test_node_classes']
+             )
+         }
+ 
+         print("Tokenized data")
+ 
+         return dataset
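The node-classification counterpart, again as a sketch with assumed names ('abstract' stands in for whatever node label the metadata defines):

    node_dataset = GraphNodeDataset(models_dataset, node_cls_label='abstract')
    splits = node_dataset.get_node_classification_lm_data(
        label='abstract',
        tokenizer=tokenizer,
        distance=1,   # include 1-hop neighbourhood text around each node
    )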
+ def get_models_gpt_dataset(
+     models_dataset: Union[ArchiMateDataset, EcoreDataset],
+     tokenizer: AutoTokenizer,
+     chunk_size: int = 100,
+     chunk_overlap: int = 20,
+     max_length: int = 128,
+     **config_params
+ ):
+     def split_texts_into_chunks(
+         texts: List[str],
+         size: int = 100,
+         overlap: int = 20,
+     ):
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=size,
+             chunk_overlap=overlap,
+             length_function=len,
+             is_separator_regex=False,
+         )
+         return [t.page_content for t in text_splitter.create_documents(texts)]
+ 
+     graph_dataset = GraphEdgeDataset(models_dataset, **config_params)
+     texts_data = graph_dataset.get_link_prediction_texts()
+     texts = texts_data['train_pos_edges'] + texts_data['test_pos_edges']
+ 
+     print("Total texts", len(texts))
+     splitted_texts = split_texts_into_chunks(
+         texts,
+         size=chunk_size,
+         overlap=chunk_overlap
+     )
+     print("Total chunks", len(splitted_texts))
+     print("Tokenizing...")
+     dataset = GPTTextDataset(splitted_texts, tokenizer, max_length=max_length)
+     print("Tokenized")
+     print("Input IDs shape", dataset[:]['input_ids'].shape)
+     return dataset
+ return dataset