glam4cm 0.1.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between versions as they appear in the public registry.
- glam4cm/__init__.py +2 -1
- glam4cm/data_loading/data.py +90 -146
- glam4cm/data_loading/encoding.py +17 -6
- glam4cm/data_loading/graph_dataset.py +192 -57
- glam4cm/data_loading/metadata.py +1 -1
- glam4cm/data_loading/models_dataset.py +42 -18
- glam4cm/downstream_tasks/bert_edge_classification.py +49 -22
- glam4cm/downstream_tasks/bert_graph_classification.py +44 -14
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +47 -24
- glam4cm/downstream_tasks/bert_link_prediction.py +46 -26
- glam4cm/downstream_tasks/bert_node_classification.py +127 -89
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +61 -15
- glam4cm/downstream_tasks/common_args.py +32 -4
- glam4cm/downstream_tasks/gnn_edge_classification.py +24 -7
- glam4cm/downstream_tasks/gnn_graph_cls.py +19 -6
- glam4cm/downstream_tasks/gnn_link_prediction.py +25 -13
- glam4cm/downstream_tasks/gnn_node_classification.py +19 -7
- glam4cm/downstream_tasks/utils.py +16 -2
- glam4cm/embeddings/bert.py +1 -1
- glam4cm/embeddings/common.py +7 -4
- glam4cm/encoding/encoders.py +1 -1
- glam4cm/lang2graph/archimate.py +0 -5
- glam4cm/lang2graph/common.py +99 -41
- glam4cm/lang2graph/ecore.py +1 -2
- glam4cm/lang2graph/ontouml.py +8 -7
- glam4cm/models/gnn_layers.py +20 -6
- glam4cm/models/hf.py +2 -2
- glam4cm/run.py +12 -7
- glam4cm/run_conf_v2.py +405 -0
- glam4cm/run_configs.py +70 -106
- glam4cm/run_confs.py +41 -0
- glam4cm/settings.py +15 -2
- glam4cm/tokenization/special_tokens.py +23 -1
- glam4cm/tokenization/utils.py +23 -4
- glam4cm/trainers/cm_gpt_trainer.py +1 -1
- glam4cm/trainers/gnn_edge_classifier.py +12 -1
- glam4cm/trainers/gnn_graph_classifier.py +12 -5
- glam4cm/trainers/gnn_link_predictor.py +18 -3
- glam4cm/trainers/gnn_link_predictor_v2.py +146 -0
- glam4cm/trainers/gnn_trainer.py +8 -0
- glam4cm/trainers/metrics.py +1 -1
- glam4cm/utils.py +265 -2
- {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/METADATA +3 -2
- glam4cm-1.0.0.dist-info/RECORD +75 -0
- {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/WHEEL +1 -1
- glam4cm-0.1.1.dist-info/RECORD +0 -72
- {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/entry_points.txt +0 -0
- {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/top_level.txt +0 -0
glam4cm/data_loading/graph_dataset.py
CHANGED
@@ -11,18 +11,29 @@ import numpy as np
 from scipy.sparse import csr_matrix
 from transformers import AutoTokenizer
 from glam4cm.data_loading.data import TorchEdgeGraph, TorchGraph, TorchNodeGraph
-from glam4cm.data_loading.models_dataset import
+from glam4cm.data_loading.models_dataset import (
+    ArchiMateDataset,
+    EcoreDataset,
+    OntoUMLDataset
+)
 from glam4cm.data_loading.encoding import EncodingDataset, GPTTextDataset
 from tqdm.auto import tqdm
 from glam4cm.embeddings.w2v import Word2VecEmbedder
 from glam4cm.embeddings.tfidf import TfidfEmbedder
 from glam4cm.embeddings.common import get_embedding_model
 from glam4cm.lang2graph.common import LangGraph, get_node_data, get_edge_data
-from glam4cm.data_loading.metadata import
+from glam4cm.data_loading.metadata import (
+    ArchimateMetaData,
+    EcoreMetaData,
+    OntoUMLMetaData
+)
 from glam4cm.settings import seed
 from glam4cm.settings import (
-
-
+    EDGE_CLS_TASK,
+    LINK_PRED_TASK,
+    NODE_CLS_TASK,
+    GRAPH_CLS_TASK,
+    DUMMY_GRAPH_CLS_TASK
 )
 import glam4cm.utils as utils
 
@@ -94,7 +105,7 @@ class GraphDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         models_dataset: Union[EcoreDataset, ArchiMateDataset],
-
+        task_type: str,
         distance=1,
         add_negative_train_samples=False,
         neg_sampling_ratio=1,
@@ -103,6 +114,8 @@ class GraphDataset(torch.utils.data.Dataset):
         use_node_types=False,
         use_edge_label=False,
         no_labels=False,
+
+        node_topk=-1,
 
         node_cls_label=None,
         edge_cls_label=None,
@@ -120,8 +133,11 @@ class GraphDataset(torch.utils.data.Dataset):
         randomize_ee=False,
         random_embed_dim=128,
 
-        exclude_labels:
+        exclude_labels: List = None,
+        save_dir='datasets/graph_data',
     ):
+        self.task_type = task_type
+
         if isinstance(models_dataset, EcoreDataset):
             self.metadata = EcoreMetaData()
         elif isinstance(models_dataset, ArchiMateDataset):
|
@@ -147,12 +163,34 @@ class GraphDataset(torch.utils.data.Dataset):
|
|
147
163
|
|
148
164
|
self.test_ratio = test_ratio
|
149
165
|
|
150
|
-
self.no_shuffle = no_shuffle
|
151
|
-
self.exclude_labels = exclude_labels
|
152
|
-
|
153
166
|
self.use_special_tokens = use_special_tokens
|
154
167
|
self.node_cls_label = node_cls_label
|
155
168
|
self.edge_cls_label = edge_cls_label
|
169
|
+
|
170
|
+
self.node_topk = node_topk
|
171
|
+
|
172
|
+
|
173
|
+
self.no_shuffle = no_shuffle
|
174
|
+
|
175
|
+
all_labels = [
|
176
|
+
i[0] for i in sorted(
|
177
|
+
dict(Counter(
|
178
|
+
sum(
|
179
|
+
[
|
180
|
+
[
|
181
|
+
v for v in list(dict(model.numbered_graph.nodes(data=node_cls_label)).values()) if v
|
182
|
+
]
|
183
|
+
for model in models_dataset],
|
184
|
+
[]
|
185
|
+
)
|
186
|
+
)).items(),
|
187
|
+
key=lambda x: x[1],
|
188
|
+
reverse=True
|
189
|
+
)
|
190
|
+
]
|
191
|
+
|
192
|
+
self.node_topk = all_labels[:node_topk] if node_topk > 0 else all_labels
|
193
|
+
self.exclude_labels = (all_labels[node_topk+1:] if node_topk > 0 else []) + [None, '']
|
156
194
|
|
157
195
|
self.randomize_ne = randomize_ne
|
158
196
|
self.randomize_ee = randomize_ee
|
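The constructor block above replaces the old pass-through `exclude_labels` argument: it counts every `node_cls_label` value across all models, ranks labels by frequency, keeps the `node_topk` most frequent ones, and routes the rest (plus `None` and empty strings) into `exclude_labels`. A standalone sketch of the same ranking idea, using plain lists in place of the package's `numbered_graph` objects (toy data, illustrative names):

```python
from collections import Counter

# Toy stand-in for per-model node labels pulled off numbered_graph.nodes(data=...).
models = [
    ['Class', 'Class', 'Package', None],
    ['Class', 'Attribute', '', 'Package'],
]

# Flatten, drop falsy labels, rank by frequency (most common first).
all_labels = [
    label for label, _ in sorted(
        Counter(v for model in models for v in model if v).items(),
        key=lambda x: x[1],
        reverse=True,
    )
]

node_topk = 2
kept = all_labels[:node_topk]                        # ['Class', 'Package']
excluded = all_labels[node_topk + 1:] + [None, '']   # mirrors the node_topk+1 slice above
print(kept, excluded)                                # ['Class', 'Package'] [None, '']
```

Note that the `node_topk + 1` slice leaves the label at index `node_topk` in neither set; paired with the `:node_topk` keep-slice, a `node_topk:` exclude-slice would partition the list exactly.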
@@ -161,6 +199,8 @@ class GraphDataset(torch.utils.data.Dataset):
         self.graphs: List[Union[TorchNodeGraph, TorchEdgeGraph]] = []
         self.config = dict(
             name=models_dataset.name,
+            task_type=task_type,
+            node_topk=node_topk,
             distance=distance,
             add_negative_train_samples=add_negative_train_samples,
             neg_sampling_ratio=neg_sampling_ratio,
@@ -181,7 +221,6 @@ class GraphDataset(torch.utils.data.Dataset):
             randomize_ee=randomize_ee,
             random_embed_dim=random_embed_dim
         )
-
         self.save_dir = os.path.join(save_dir, models_dataset.name)
         os.makedirs(self.save_dir, exist_ok=True)
 
@@ -210,22 +249,21 @@ class GraphDataset(torch.utils.data.Dataset):
         self.config_hash = self.get_config_hash()
         os.makedirs(os.path.join(self.save_dir, self.config_hash), exist_ok=True)
         self.file_paths = {
-            graph.hash: os.path.join(
+            graph.hash: os.path.join(
+                self.save_dir,
+                self.config_hash,
+                f'{graph.hash}_{self.get_string_gen_params_hash()}',
+                'data.pkl'
+            )
             for graph in models_dataset
         }
 
         print("Number of duplicate graphs: ", len(models_dataset) - len(self.file_paths))
 
-
-    def set_torch_graphs(
-        self,
-        type: str,
-        models_dataset: Union[EcoreDataset, ArchiMateDataset],
-        limit: int =-1
-    ):
-
+    def get_common_params(self):
         common_params = dict(
             metadata=self.metadata,
+            task_type=self.task_type,
             distance=self.distance,
             test_ratio=self.test_ratio,
             use_attributes=self.use_attributes,
@@ -235,8 +273,34 @@ class GraphDataset(torch.utils.data.Dataset):
             use_special_tokens=self.use_special_tokens,
             no_labels=self.no_labels,
             node_cls_label=self.node_cls_label,
-            edge_cls_label=self.edge_cls_label
+            edge_cls_label=self.edge_cls_label,
+            node_topk=self.node_topk,
         )
+        return common_params
+
+    def get_string_gen_params_hash(self):
+        string_gen_params = f"""
+            distance={self.distance},
+            use_attributes={self.use_attributes},
+            use_node_types={self.use_node_types},
+            use_edge_types={self.use_edge_types},
+            use_edge_label={self.use_edge_label},
+            use_special_tokens={self.use_special_tokens},
+            no_labels={self.no_labels},
+            node_cls_label={self.node_cls_label},
+            edge_cls_label={self.edge_cls_label},
+            node_topk={self.node_topk},
+            ckpt={self.ckpt},
+        """
+        return utils.md5_hash(string_gen_params)
+
+    def set_torch_graphs(
+        self,
+        models_dataset: Union[EcoreDataset, ArchiMateDataset],
+        limit: int =-1
+    ):
+
+        common_params = self.get_common_params()
         def create_node_graph(graph: LangGraph, fp: str) -> TorchNodeGraph:
             node_params = {
                 **common_params,
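`get_string_gen_params_hash` folds every text-generation setting into an MD5 digest, and the `file_paths` mapping above appends that digest to each graph's content hash, so the same graph rendered under different settings caches to a different `data.pkl`. A sketch of the pattern, assuming `utils.md5_hash` is a thin `hashlib` wrapper (the helper below is a hypothetical stand-in):

```python
import hashlib
import os

def md5_hash(text: str) -> str:
    # Hypothetical stand-in for glam4cm.utils.md5_hash.
    return hashlib.md5(text.encode('utf-8')).hexdigest()

# Settings that change the generated node/edge texts (subset, illustrative values).
string_gen_params = "distance=1, use_attributes=True, node_cls_label=abstract"

graph_hash = 'a3f0'                       # illustrative per-graph content hash
save_dir, config_hash = 'datasets/graph_data/modelset', 'cfg0'

# One directory per (graph, text-generation settings) pair, as in file_paths above.
fp = os.path.join(save_dir, config_hash,
                  f'{graph_hash}_{md5_hash(string_gen_params)}', 'data.pkl')
print(fp)
```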
@@ -256,24 +320,29 @@ class GraphDataset(torch.utils.data.Dataset):
             return torch_graph
 
 
-
         models_size = len(models_dataset) \
             if (limit == -1 or limit > len(models_dataset)) else limit
 
         self.set_file_hashes(models_dataset[:models_size])
 
-        for graph in tqdm(models_dataset[:models_size], desc=f'Creating {
+        for graph in tqdm(models_dataset[:models_size], desc=f'Creating {self.task_type} graphs'):
             fp = self.file_paths[graph.hash]
             if not os.path.exists(fp) or self.reload:
-                if
+                if self.task_type in [NODE_CLS_TASK, GRAPH_CLS_TASK, DUMMY_GRAPH_CLS_TASK]:
                     torch_graph: TorchNodeGraph = create_node_graph(graph, fp)
-                elif
+                elif self.task_type in [EDGE_CLS_TASK, LINK_PRED_TASK]:
                     torch_graph: TorchEdgeGraph = create_edge_graph(graph, fp)
-
+                else:
+                    raise ValueError(f"Invalid task type: {self.task_type}")
                 torch_graph.save()
 
-
-
+
+    def embed(self, models_dataset, limit):
+        models_size = len(models_dataset) \
+            if (limit == -1 or limit > len(models_dataset)) else limit
+        print("Limit: ", limit)
+        for graph in tqdm(models_dataset[:models_size], desc=f'Creating {self.task_type} graphs'):
+            fp = self.file_paths[graph.hash]
             torch_graph = TorchGraph.load(fp)
             torch_graph.embed(
                 self.embedder,
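`set_torch_graphs` no longer takes a positional `type` string; it dispatches on the dataset-wide `task_type`, building node-level graphs for node and graph classification and edge-level graphs for edge classification and link prediction, while `embed` becomes a separate pass over the cached files. The dispatch in isolation, with the task constants stubbed (in the package they come from `glam4cm.settings`; their concrete values are not shown in this diff):

```python
# Stubbed task constants; the real ones live in glam4cm.settings.
NODE_CLS_TASK, GRAPH_CLS_TASK, DUMMY_GRAPH_CLS_TASK = 'node_cls', 'graph_cls', 'dummy_graph_cls'
EDGE_CLS_TASK, LINK_PRED_TASK = 'edge_cls', 'link_pred'

def graph_kind(task_type: str) -> str:
    # Node-level views serve node and whole-graph classification...
    if task_type in [NODE_CLS_TASK, GRAPH_CLS_TASK, DUMMY_GRAPH_CLS_TASK]:
        return 'node_graph'
    # ...edge-level views serve edge classification and link prediction.
    if task_type in [EDGE_CLS_TASK, LINK_PRED_TASK]:
        return 'edge_graph'
    raise ValueError(f"Invalid task type: {task_type}")

print(graph_kind('link_pred'))  # edge_graph
```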
@@ -306,7 +375,8 @@ class GraphDataset(torch.utils.data.Dataset):
             return np.concatenate([a, b], axis=1)
 
         prefix_cls = getattr(self, f"{prefix}_cls_label")
-        num_classes = getattr(self, f"
+        num_classes = getattr(self, f"{prefix}_label_map_{prefix_cls}").classes_.shape[0]
+
         # print(f"Number of {prefix} types: {num_classes}")
         for g in self.graphs:
             types = np.eye(num_classes)[getattr(g.data, f"{prefix}_{prefix_cls}")]
@@ -325,10 +395,10 @@ class GraphDataset(torch.utils.data.Dataset):
             assert all(g.data.edge_attr.shape[1] == edge_dim for g in self.graphs), "Edge types not added correctly"
 
 
-        if self.use_node_types and self.node_cls_label:
+        if self.use_node_types and self.node_cls_label and self.task_type not in [NODE_CLS_TASK]:
             set_types('node')
 
-        if self.use_edge_types and self.edge_cls_label:
+        if self.use_edge_types and self.edge_cls_label and self.task_type not in [EDGE_CLS_TASK, LINK_PRED_TASK]:
             set_types('edge')
 
     def __len__(self):
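`set_types` appends one-hot class columns to the feature matrices by indexing an identity matrix; the new `task_type` guards skip that step when the classes being encoded are the prediction target, which would otherwise leak labels into the features. The indexing trick in isolation:

```python
import numpy as np

num_classes = 4
classes = np.array([0, 2, 3, 2])        # encoded node (or edge) classes

one_hot = np.eye(num_classes)[classes]  # one identity row per element, shape (4, 4)
features = np.random.rand(4, 8)         # existing embedding features

# Concatenate the one-hot type columns onto the features, as set_types does.
augmented = np.concatenate([features, one_hot], axis=1)
print(augmented.shape)                  # (4, 12)
```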
@@ -375,8 +445,9 @@ class GraphDataset(torch.utils.data.Dataset):
             assert torch_graph.data.overall_edge_index.shape[1] == torch_graph.graph.number_of_edges(), \
                 f"Number of edges mismatch, {torch_graph.data.edge_index.shape[1]} != {torch_graph.graph.number_of_edges()}"
         else:
-
-
+            if len(torch_graph.data.edge_index.shape) > 1:
+                assert torch_graph.data.edge_index.shape[1] == torch_graph.graph.number_of_edges(), \
+                    f"Number of edges mismatch, {torch_graph.data.edge_index.shape[1]} != {torch_graph.graph.number_of_edges()}"
 
 
 
@@ -407,7 +478,6 @@ class GraphDataset(torch.utils.data.Dataset):
         node_label_map = LabelEncoder()
         node_label_map.fit_transform([j for i in label_values for j in i])
         label_values = [node_label_map.transform(i) for i in label_values]
-        print(node_label_map.classes_)
 
         for torch_graph, node_classes in zip(self.graphs, label_values):
             setattr(torch_graph.data, f"node_{cls_label}", np.array(node_classes))
@@ -532,7 +602,7 @@ class GraphDataset(torch.utils.data.Dataset):
         y = [getattr(self.graphs[i].data, f'graph_{graph_label_name}')[0].item() for i in indices]
 
         dataset = EncodingDataset(tokenizer, X, y, remove_duplicates=remove_duplicates)
-        print("\n".join([f"Label: {self.graph_label_map_label.inverse_transform([l])[0]}, Text: {i}" for i, l in zip(X, y)]))
+        # print("\n".join([f"Label: {self.graph_label_map_label.inverse_transform([l])[0]}, Text: {i}" for i, l in zip(X, y)]))
 
         return dataset
 
@@ -566,7 +636,7 @@ class GraphEdgeDataset(GraphDataset):
     def __init__(
         self,
         models_dataset: Union[EcoreDataset, ArchiMateDataset],
-
+        task_type: str,
         distance=0,
         reload=False,
         test_ratio=0.2,
@@ -580,6 +650,8 @@ class GraphEdgeDataset(GraphDataset):
         use_node_types=False,
         no_labels=False,
 
+        node_topk = -1,
+
         use_embeddings=False,
         embed_model_name='bert-base-uncased',
         ckpt=None,
@@ -595,13 +667,13 @@ class GraphEdgeDataset(GraphDataset):
 
         node_cls_label: str = None,
         edge_cls_label: str = None,
-
-        task_type=LP_TASK_EDGE_CLS
+        save_dir='datasets/graph_data'
     ):
-
+        assert task_type in [EDGE_CLS_TASK, GRAPH_CLS_TASK], f"Invalid task type: Must be one of {[EDGE_CLS_TASK, GRAPH_CLS_TASK]}."
         super().__init__(
             models_dataset=models_dataset,
-
+            task_type=task_type,
+
             distance=distance,
             test_ratio=test_ratio,
 
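Callers of `GraphEdgeDataset` must now pass the task explicitly instead of relying on the removed `task_type=LP_TASK_EDGE_CLS` default, and the constructor asserts it up front. A hedged usage sketch (the dataset name and label are illustrative, not defaults):

```python
from glam4cm.data_loading.models_dataset import EcoreDataset
from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
from glam4cm.settings import EDGE_CLS_TASK

models = EcoreDataset('modelset')        # illustrative dataset name
edges = GraphEdgeDataset(
    models,
    task_type=EDGE_CLS_TASK,             # now required; checked by the new assert
    distance=1,
    edge_cls_label='type',               # dataset-specific label choice
    save_dir='datasets/graph_data',      # new configurable cache root
)
```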
@@ -609,8 +681,9 @@ class GraphEdgeDataset(GraphDataset):
             use_edge_types=use_edge_types,
             use_edge_label=use_edge_label,
             use_attributes=use_attributes,
-            no_labels=no_labels,
-
+            no_labels=no_labels,
+            node_topk = node_topk,
+
             add_negative_train_samples=add_negative_train_samples,
             neg_sampling_ratio=neg_sampling_ratio,
 
@@ -631,11 +704,10 @@ class GraphEdgeDataset(GraphDataset):
             randomize_ne=randomize_ne,
             randomize_ee=randomize_ee,
             random_embed_dim=random_embed_dim,
+            save_dir=save_dir
         )
 
-        self.
-
-        self.set_torch_graphs('edge', models_dataset, limit)
+        self.set_torch_graphs(models_dataset, limit)
 
         if self.use_embeddings and (isinstance(self.embedder, Word2VecEmbedder) or isinstance(self.embedder, TfidfEmbedder)):
             texts = self.get_link_prediction_texts(only_texts=True)
@@ -644,7 +716,7 @@ class GraphEdgeDataset(GraphDataset):
             self.embedder.train(texts)
             print(f"Trained {self.embedder.name} Embedder")
 
-        self.embed()
+        self.embed(models_dataset, limit)
 
         train_count, test_count = dict(), dict()
         for g in self.graphs:
@@ -675,12 +747,13 @@ class GraphEdgeDataset(GraphDataset):
         assert label is not None, "No edge label found in data. Please define edge label in metadata"
 
         data = defaultdict(list)
-        for torch_graph in tqdm(self.graphs, desc='Getting
+        for torch_graph in tqdm(self.graphs, desc=f'Getting {self.task_type} Texts'):
             # torch_graph: TorchEdgeGraph = TorchGraph.load(fp)
             graph_data = torch_graph.get_link_prediction_texts(label, self.task_type, only_texts)
             for k, v in graph_data.items():
                 data[k] += v
 
+
         print("Train Texts: ", data[f'train_pos_edges'][:20])
         print("Test Texts: ", data[f'test_pos_edges'][:20])
 
@@ -701,7 +774,7 @@ class GraphEdgeDataset(GraphDataset):
 
 
         print("Tokenizing data")
-        if self.task_type ==
+        if self.task_type == EDGE_CLS_TASK:
             datasets = {
                 'train': EncodingDataset(
                     tokenizer,
@@ -714,7 +787,7 @@ class GraphEdgeDataset(GraphDataset):
                     data['test_edge_classes']
                 )
             }
-        elif self.task_type ==
+        elif self.task_type == LINK_PRED_TASK:
             datasets = {
                 'train': EncodingDataset(
                     tokenizer,
@@ -738,8 +811,9 @@ class GraphEdgeDataset(GraphDataset):
 class GraphNodeDataset(GraphDataset):
     def __init__(
         self,
-        models_dataset: Union[EcoreDataset, ArchiMateDataset],
-
+        models_dataset: Union[EcoreDataset, ArchiMateDataset, OntoUMLDataset],
+        task_type: str,
+
         distance=0,
         test_ratio=0.2,
         reload=False,
@@ -749,6 +823,7 @@ class GraphNodeDataset(GraphDataset):
         use_node_types=False,
         use_edge_label=False,
         use_special_tokens=False,
+        node_topk=-1,
 
         use_embeddings=False,
         embed_model_name='bert-base-uncased',
@@ -762,11 +837,69 @@ class GraphNodeDataset(GraphDataset):
         limit: int = -1,
         no_labels=False,
         node_cls_label: str = None,
-        edge_cls_label: str = None
+        edge_cls_label: str = None,
+
+        save_dir='datasets/graph_data',
     ):
+        """
+        Parameters
+        ----------
+        models_dataset: Union[EcoreDataset, ArchiMateDataset, OntoUMLDataset]
+            The dataset of models to convert to a graph dataset.
+        task_type: str
+            The type of task to perform on the graph dataset. Must be one of 'node', 'edge', or 'graph'.
+        distance: int
+            The distance to consider when creating the graph. If 0, only the node itself is considered.
+        test_ratio: float
+            The proportion of the dataset to split into the test set.
+        reload: bool
+            Whether to reload the dataset from disk if it already exists.
+        use_attributes: bool
+            Whether to include attributes of the nodes and edges in the graph.
+        use_edge_types: bool
+            Whether to include the types of the edges in the graph.
+        use_node_types: bool
+            Whether to include the types of the nodes in the graph.
+        use_edge_label: bool
+            Whether to include the labels of the edges in the graph.
+        use_special_tokens: bool
+            Whether to include special tokens for the start and end of a node or edge sequence.
+        node_topk: int
+            The number of top nodes to include in the graph. If -1, all nodes are included.
+        use_embeddings: bool
+            Whether to use embeddings for the node and edge attributes.
+        embed_model_name: str
+            The name of the embedding model to use.
+        ckpt: str
+            The path to the checkpoint file of the embedding model.
+        no_shuffle: bool
+            Whether to shuffle the dataset before splitting it into train and test sets.
+        randomize_ne: bool
+            Whether to randomize the node embeddings.
+        randomize_ee: bool
+            Whether to randomize the edge embeddings.
+        random_embed_dim: int
+            The dimension of the random embeddings.
+        limit: int
+            The maximum number of models to include in the dataset.
+        no_labels: bool
+            Whether to include labels for the nodes and edges in the graph.
+        node_cls_label: str
+            The label to use for the node classification task.
+        edge_cls_label: str
+            The label to use for the edge classification task.
+        save_dir: str
+            The directory in which to save the dataset.
+
+        Returns
+        -------
+        A GraphNodeDataset object.
+        """
+        assert task_type in [NODE_CLS_TASK, GRAPH_CLS_TASK], f"Invalid task type: Must be one of {[NODE_CLS_TASK, GRAPH_CLS_TASK]}."
         super().__init__(
             models_dataset=models_dataset,
-
+            task_type=task_type,
+
             distance=distance,
             test_ratio=test_ratio,
 
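Per the new docstring, `GraphNodeDataset` now takes the task up front and restricts it to node or graph classification. A minimal hedged construction example (names and values are illustrative):

```python
from glam4cm.data_loading.models_dataset import EcoreDataset
from glam4cm.data_loading.graph_dataset import GraphNodeDataset
from glam4cm.settings import NODE_CLS_TASK

models = EcoreDataset('modelset')      # illustrative dataset name
nodes = GraphNodeDataset(
    models,
    task_type=NODE_CLS_TASK,           # must be NODE_CLS_TASK or GRAPH_CLS_TASK
    distance=1,
    node_cls_label='abstract',         # dataset-specific label choice
    node_topk=10,                      # keep only the 10 most frequent node labels
)
```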
@@ -778,6 +911,8 @@ class GraphNodeDataset(GraphDataset):
 
             node_cls_label=node_cls_label,
             edge_cls_label=edge_cls_label,
+
+            node_topk=node_topk,
 
             use_embeddings=use_embeddings,
             embed_model_name=embed_model_name,
@@ -791,9 +926,10 @@ class GraphNodeDataset(GraphDataset):
             randomize_ne=randomize_ne,
             randomize_ee=randomize_ee,
             random_embed_dim=random_embed_dim,
+            save_dir=save_dir,
         )
 
-        self.set_torch_graphs(
+        self.set_torch_graphs(models_dataset, limit)
 
         if self.use_embeddings and (isinstance(self.embedder, Word2VecEmbedder) or isinstance(self.embedder, TfidfEmbedder)):
             texts = self.get_node_classification_texts()
@@ -802,7 +938,7 @@ class GraphNodeDataset(GraphDataset):
             self.embedder.train(texts)
             print(f"Trained {self.embedder.name} Embedder")
 
-        self.embed()
+        self.embed(models_dataset, limit)
 
         node_labels = self.metadata.node_cls
         if isinstance(node_labels, str):
@@ -860,9 +996,8 @@ class GraphNodeDataset(GraphDataset):
             data['test_node_classes'] += test_node_classes
 
 
-        print("
-        print(data['
-        print(data['test_nodes'][:10])
+        # print("\n".join(data['train_nodes']))
+        # print("\n".join(data['test_nodes']))
         if hasattr(self, "node_label_map_type"):
             node_label_map.inverse_transform([i.item() for i in train_node_classes]) == train_node_strs
             node_label_map.inverse_transform([i.item() for i in test_node_classes]) == test_node_strs
glam4cm/data_loading/models_dataset.py
CHANGED
@@ -30,7 +30,8 @@ class ModelDataset:
         min_edges: int = -1,
         min_enr: float = -1,
         timeout=-1,
-        preprocess_graph_text: callable = None
+        preprocess_graph_text: callable = None,
+        include_dummies=False
     ):
         self.name = dataset_name
         self.dataset_dir = dataset_dir
@@ -41,6 +42,7 @@ class ModelDataset:
         self.min_enr = min_enr
         self.timeout = timeout
         self.preprocess_graph_text = preprocess_graph_text
+        self.include_dummies = include_dummies
 
         self.graphs: List[LangGraph] = []
 
@@ -114,12 +116,14 @@ class ModelDataset:
 
     def save(self):
         print(f'Saving {self.name} to pickle')
-
+        pkl_file = f'{self.name}{"_with_dummies" if self.include_dummies else ''}.pkl'
+        with open(os.path.join(self.save_dir, pkl_file), 'wb') as f:
             pickle.dump(self.graphs, f)
         print(f'Saved {self.name} to pickle')
 
 
     def filter_graphs(self):
+        # print("Filtering graphs with min edges and min enr: ", self.min_edges, self.min_enr)
         graphs = list()
         for graph in self.graphs:
             addable = True
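`save` and `load` now derive the pickle name from the `include_dummies` flag, so dummy-augmented and plain corpora cache side by side instead of overwriting each other. Note the added f-string nests single quotes inside single quotes (`else ''`), which is only valid syntax on Python 3.12+; a version-agnostic equivalent of the naming scheme:

```python
def pkl_name(name: str, include_dummies: bool) -> str:
    # Same naming scheme as ModelDataset.save/load, without nested f-string quotes.
    suffix = "_with_dummies" if include_dummies else ""
    return f"{name}{suffix}.pkl"

print(pkl_name('modelset', False))  # modelset.pkl
print(pkl_name('modelset', True))   # modelset_with_dummies.pkl
```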
@@ -129,6 +133,7 @@ class ModelDataset:
                 addable = False
 
             if addable:
+                # print("Addable because min edges and min enr: ", graph.number_of_edges())
                 graphs.append(graph)
 
         self.graphs = graphs
@@ -137,7 +142,8 @@ class ModelDataset:
 
     def load(self):
         print(f'Loading {self.name} from pickle')
-
+        pkl_file = f'{self.name}{"_with_dummies" if self.include_dummies else ''}.pkl'
+        with open(os.path.join(self.save_dir, pkl_file), 'rb') as f:
             self.graphs = pickle.load(f)
 
         self.filter_graphs()
@@ -172,7 +178,8 @@ class EcoreDataset(ModelDataset):
         remove_duplicates=False,
         min_edges: int = -1,
         min_enr: float = -1,
-        preprocess_graph_text: callable = None
+        preprocess_graph_text: callable = None,
+        include_dummies=False
     ):
         super().__init__(
             dataset_name,
@@ -180,22 +187,34 @@ class EcoreDataset(ModelDataset):
             save_dir=save_dir,
             min_edges=min_edges,
             min_enr=min_enr,
-            preprocess_graph_text=preprocess_graph_text
+            preprocess_graph_text=preprocess_graph_text,
+            include_dummies=include_dummies
         )
         os.makedirs(save_dir, exist_ok=True)
 
         dataset_exists = os.path.exists(os.path.join(save_dir, f'{dataset_name}.pkl'))
         if reload or not dataset_exists:
+
+
             self.graphs: List[EcoreNxG] = []
             data_path = os.path.join(dataset_dir, dataset_name)
-
-
-
-
-
-
-
-
+            file_name = os.path.join(data_path, 'ecore.jsonl') if not include_dummies\
+                else os.path.join(data_path, 'ecore-with-dummy.jsonl')
+
+            # for file in os.listdir(data_path):
+            #     if file.endswith('.jsonl') and file.startswith("ecore"):
+            json_objects = json.load(open(file_name))
+            for g in tqdm(json_objects, desc=f'Loading {dataset_name.title()}'):
+
+                if remove_duplicates and g['is_duplicated']:
+                    continue
+
+                if not include_dummies and g['labels'] == 'dummy':
+                    print(f"Skipping dummy graph {g['ids']}")
+                    continue
+
+                nxg = EcoreNxG(g)
+                self.graphs.append(nxg)
 
         print(f'Loaded Total {self.name} with {len(self.graphs)} graphs')
         print("Filtering...")
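`EcoreDataset` now resolves a fixed JSONL file from the flag rather than scanning the directory (the old scan survives only as a comment), skipping records marked `is_duplicated` or labeled `dummy` as it loads. Assuming the on-disk layout the diff implies, usage reduces to the flag (dataset name illustrative):

```python
from glam4cm.data_loading.models_dataset import EcoreDataset

# Reads <dataset_dir>/modelset/ecore.jsonl.
plain = EcoreDataset('modelset', include_dummies=False)

# Reads <dataset_dir>/modelset/ecore-with-dummy.jsonl and keeps 'dummy'-labeled graphs.
with_dummies = EcoreDataset('modelset', include_dummies=True)
```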
@@ -233,7 +252,8 @@ class ArchiMateDataset(ModelDataset):
         min_enr: float = -1,
         timeout=-1,
         language=None,
-        preprocess_graph_text: callable = None
+        preprocess_graph_text: callable = None,
+        include_dummies=False
     ):
         super().__init__(
             dataset_name,
@@ -242,7 +262,8 @@ class ArchiMateDataset(ModelDataset):
             min_edges=min_edges,
             min_enr=min_enr,
             timeout=timeout,
-            preprocess_graph_text=preprocess_graph_text
+            preprocess_graph_text=preprocess_graph_text,
+            include_dummies=include_dummies
         )
         os.makedirs(save_dir, exist_ok=True)
 
@@ -274,7 +295,7 @@ class ArchiMateDataset(ModelDataset):
 
             except Exception as e:
                 raise e
-
+            print("Total graphs:", len(self.graphs))
             self.filter_graphs()
             self.save()
         else:
@@ -283,6 +304,7 @@ class ArchiMateDataset(ModelDataset):
             if remove_duplicates:
                 self.dedup()
 
+            assert all([g.number_of_edges() >= min_edges for g in self.graphs]), f"Filtered out graphs with less than {min_edges} edges"
             print(f'Loaded {self.name} with {len(self.graphs)} graphs')
             print(f'Graphs: {len(self.graphs)}')
 
@@ -305,7 +327,8 @@ class OntoUMLDataset(ModelDataset):
         min_edges: int = -1,
         min_enr: float = -1,
         timeout=-1,
-        preprocess_graph_text: callable = None
+        preprocess_graph_text: callable = None,
+        include_dummies=False
     ):
         super().__init__(
             dataset_name,
@@ -314,7 +337,8 @@ class OntoUMLDataset(ModelDataset):
             min_edges=min_edges,
             min_enr=min_enr,
             timeout=timeout,
-            preprocess_graph_text=preprocess_graph_text
+            preprocess_graph_text=preprocess_graph_text,
+            include_dummies=include_dummies
         )
         os.makedirs(save_dir, exist_ok=True)
 