glam4cm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glam4cm/__init__.py +9 -0
- glam4cm/data_loading/__init__.py +0 -0
- glam4cm/data_loading/data.py +631 -0
- glam4cm/data_loading/encoding.py +76 -0
- glam4cm/data_loading/graph_dataset.py +940 -0
- glam4cm/data_loading/metadata.py +84 -0
- glam4cm/data_loading/models_dataset.py +361 -0
- glam4cm/data_loading/utils.py +20 -0
- glam4cm/downstream_tasks/__init__.py +0 -0
- glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
- glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
- glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
- glam4cm/downstream_tasks/bert_node_classification.py +164 -0
- glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
- glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
- glam4cm/downstream_tasks/common_args.py +160 -0
- glam4cm/downstream_tasks/create_dataset.py +51 -0
- glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
- glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
- glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
- glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
- glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
- glam4cm/downstream_tasks/utils.py +35 -0
- glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
- glam4cm/embeddings/__init__.py +0 -0
- glam4cm/embeddings/bert.py +72 -0
- glam4cm/embeddings/common.py +43 -0
- glam4cm/embeddings/fasttext.py +0 -0
- glam4cm/embeddings/tfidf.py +25 -0
- glam4cm/embeddings/w2v.py +41 -0
- glam4cm/encoding/__init__.py +0 -0
- glam4cm/encoding/common.py +0 -0
- glam4cm/encoding/encoders.py +100 -0
- glam4cm/graph2str/__init__.py +0 -0
- glam4cm/graph2str/common.py +34 -0
- glam4cm/graph2str/constants.py +15 -0
- glam4cm/graph2str/ontouml.py +141 -0
- glam4cm/graph2str/uml.py +0 -0
- glam4cm/lang2graph/__init__.py +0 -0
- glam4cm/lang2graph/archimate.py +31 -0
- glam4cm/lang2graph/bpmn.py +0 -0
- glam4cm/lang2graph/common.py +416 -0
- glam4cm/lang2graph/ecore.py +221 -0
- glam4cm/lang2graph/ontouml.py +169 -0
- glam4cm/lang2graph/utils.py +80 -0
- glam4cm/models/cmgpt.py +352 -0
- glam4cm/models/gnn_layers.py +273 -0
- glam4cm/models/hf.py +10 -0
- glam4cm/run.py +99 -0
- glam4cm/run_configs.py +126 -0
- glam4cm/settings.py +54 -0
- glam4cm/tokenization/__init__.py +0 -0
- glam4cm/tokenization/special_tokens.py +4 -0
- glam4cm/tokenization/utils.py +37 -0
- glam4cm/trainers/__init__.py +0 -0
- glam4cm/trainers/bert_classifier.py +105 -0
- glam4cm/trainers/cm_gpt_trainer.py +153 -0
- glam4cm/trainers/gnn_edge_classifier.py +126 -0
- glam4cm/trainers/gnn_graph_classifier.py +123 -0
- glam4cm/trainers/gnn_link_predictor.py +144 -0
- glam4cm/trainers/gnn_node_classifier.py +135 -0
- glam4cm/trainers/gnn_trainer.py +129 -0
- glam4cm/trainers/metrics.py +55 -0
- glam4cm/utils.py +194 -0
- glam4cm-0.1.0.dist-info/LICENSE +21 -0
- glam4cm-0.1.0.dist-info/METADATA +86 -0
- glam4cm-0.1.0.dist-info/RECORD +72 -0
- glam4cm-0.1.0.dist-info/WHEEL +5 -0
- glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
- glam4cm-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,940 @@
|
|
1
|
+
import json
|
2
|
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
3
|
+
from sklearn.preprocessing import LabelEncoder
|
4
|
+
from collections import Counter, defaultdict
|
5
|
+
import os
|
6
|
+
from random import shuffle
|
7
|
+
from typing import List, Union
|
8
|
+
from sklearn.model_selection import StratifiedKFold
|
9
|
+
import torch
|
10
|
+
import numpy as np
|
11
|
+
from scipy.sparse import csr_matrix
|
12
|
+
from transformers import AutoTokenizer
|
13
|
+
from glam4cm.data_loading.data import TorchEdgeGraph, TorchGraph, TorchNodeGraph
|
14
|
+
from glam4cm.data_loading.models_dataset import ArchiMateDataset, EcoreDataset, OntoUMLDataset
|
15
|
+
from glam4cm.data_loading.encoding import EncodingDataset, GPTTextDataset
|
16
|
+
from tqdm.auto import tqdm
|
17
|
+
from glam4cm.embeddings.w2v import Word2VecEmbedder
|
18
|
+
from glam4cm.embeddings.tfidf import TfidfEmbedder
|
19
|
+
from glam4cm.embeddings.common import get_embedding_model
|
20
|
+
from glam4cm.lang2graph.common import LangGraph, get_node_data, get_edge_data
|
21
|
+
from glam4cm.data_loading.metadata import ArchimateMetaData, EcoreMetaData, OntoUMLMetaData
|
22
|
+
from glam4cm.settings import seed
|
23
|
+
from glam4cm.settings import (
|
24
|
+
LP_TASK_EDGE_CLS,
|
25
|
+
LP_TASK_LINK_PRED,
|
26
|
+
)
|
27
|
+
import glam4cm.utils as utils
|
28
|
+
|
29
|
+
|
30
|
+
def exclude_labels_from_data(X, labels, exclude_labels):
    """Drop every sample whose label is listed in ``exclude_labels``.

    Args:
        X: Indexable sequence of samples aligned position-by-position with
            ``labels`` (``labels[i]`` is the label of ``X[i]``).
        labels: Sequence of labels.
        exclude_labels: Collection of label values to filter out.

    Returns:
        A ``(X, labels)`` tuple of new lists holding only the kept entries.
    """
    kept_samples = [
        sample
        for position, sample in enumerate(X)
        if labels[position] not in exclude_labels
    ]
    kept_labels = [lab for lab in labels if lab not in exclude_labels]
    return kept_samples, kept_labels
|
34
|
+
|
35
|
+
|
36
|
+
def validate_classes(torch_graphs: List[TorchGraph], label, exclude_labels, element):
    """Check per-graph class labels against the graph lookups and report class sets.

    For every graph, the split indices stored on ``torch_graph.data`` under
    ``train_{element}_mask`` / ``test_{element}_mask`` are filtered. Positions
    whose encoded label (``data.{element}_{label}``) is in ``exclude_labels``
    are dropped. For each remaining index, the raw ``label`` attribute read from
    ``torch_graph.graph`` is compared with the value found through the graph's
    own lookup (``idx_to_edge`` + ``numbered_graph`` for edges,
    ``id_to_node_label`` for nodes). Finally, the distinct encoded classes in
    the train and test splits are printed, after exclusion.

    Args:
        torch_graphs: Graph wrappers exposing ``.data`` and ``.graph``.
        label: Name of the class attribute being validated.
        exclude_labels: Encoded label values to ignore.
        element: Either ``'node'`` or ``'edge'``.

    Raises:
        ValueError: If ``element`` is neither ``'node'`` nor ``'edge'``.
        AssertionError: If a stored label disagrees with the graph lookup.
    """
    if element not in ('node', 'edge'):
        raise ValueError(f"Invalid element: {element}")

    train_idx_label = f"train_{element}_mask"
    test_idx_label = f"test_{element}_mask"
    train_labels, test_labels = [], []

    for torch_graph in tqdm(torch_graphs, desc=f"Validating {element} classes"):
        labels = getattr(torch_graph.data, f"{element}_{label}")
        train_idx = getattr(torch_graph.data, train_idx_label)
        test_idx = getattr(torch_graph.data, test_idx_label)

        # Positions whose encoded class must be ignored.
        excluded = np.nonzero(np.isin(labels, exclude_labels))[0]
        if len(excluded) > 0:
            train_idx = train_idx[~np.isin(train_idx, excluded)]
            test_idx = test_idx[~np.isin(test_idx, excluded)]

        train_labels.append(labels[train_idx])
        test_labels.append(labels[test_idx])

        all_idx = train_idx.tolist() + test_idx.tolist()
        if element == 'edge':
            edge_classes = [t for _, _, t in torch_graph.graph.edges(data=label)]
            for idx in all_idx:
                edge = torch_graph.graph.idx_to_edge[idx]
                assert torch_graph.graph.numbered_graph.edges[edge][label] == edge_classes[idx], \
                    f"Edge {label} mismatch for {edge}"
        else:
            node_classes = [t for _, t in torch_graph.graph.nodes(data=label)]
            for idx in all_idx:
                node_label = torch_graph.graph.id_to_node_label[idx]
                assert torch_graph.graph.nodes[node_label][label] == node_classes[idx], \
                    f"Node {label} mismatch for {idx}"

    # Report classes from the exclusion-filtered splits. The unfiltered masks
    # would count excluded labels as real classes.
    train_classes = set()
    for split_labels in train_labels:
        train_classes.update(split_labels.tolist())
    test_classes = set()
    for split_labels in test_labels:
        test_classes.update(split_labels.tolist())

    num_train_classes = len(train_classes)
    num_test_classes = len(test_classes)
    print("Train classes:", train_classes)
    print("Test classes:", test_classes)
    print(f"Number of classes in training set: {num_train_classes}")
    print(f"Number of classes in test set: {num_test_classes}")
|
90
|
+
|
91
|
+
|
92
|
+
|
93
|
+
class GraphDataset(torch.utils.data.Dataset):
    """Base dataset that converts modelling-language graphs into cached torch graphs.

    One torch graph per input model is saved under
    ``<save_dir>/<dataset name>/<config hash>/<graph hash>/data.pkl``.
    Label-encoded node, edge and graph classes are attached to each graph, and
    train/test splits are offered for GNN and language-model pipelines.
    Subclasses fill ``self.graphs`` by calling :meth:`set_torch_graphs` and then
    :meth:`embed`.
    """

    def __init__(
        self,
        models_dataset: Union[EcoreDataset, ArchiMateDataset],
        save_dir='datasets/graph_data',
        distance=1,
        add_negative_train_samples=False,
        neg_sampling_ratio=1,
        use_attributes=False,
        use_edge_types=False,
        use_node_types=False,
        use_edge_label=False,
        no_labels=False,

        node_cls_label=None,
        edge_cls_label=None,

        test_ratio=0.2,

        use_embeddings=False,
        use_special_tokens=False,
        embed_model_name='bert-base-uncased',
        ckpt=None,
        reload=False,
        no_shuffle=False,

        randomize_ne=False,
        randomize_ee=False,
        random_embed_dim=128,

        exclude_labels: list = None,
    ):
        """Store the configuration and create the on-disk cache directory.

        Args:
            models_dataset: Source models. Must be an ``EcoreDataset``,
                ``ArchiMateDataset`` or ``OntoUMLDataset``; the type selects the
                matching metadata class.
            save_dir: Root directory for cached graphs. A sub-directory named
                after ``models_dataset.name`` is created.
            distance: Forwarded to the torch graph constructors.
            add_negative_train_samples: Forwarded to edge graphs only.
            neg_sampling_ratio: Forwarded to edge graphs only.
            use_attributes: Forwarded to the graph constructors.
            use_edge_types: Forwarded to the graph constructors.
            use_node_types: Forwarded to the graph constructors.
            use_edge_label: Forwarded to the graph constructors.
            no_labels: Forwarded to the graph constructors.
            node_cls_label: Node class attribute used for one-hot node types.
            edge_cls_label: Edge class attribute used for one-hot edge types.
            test_ratio: Held-out fraction. :meth:`k_fold_split` also uses it to
                pick ``int(1 / test_ratio)`` folds.
            use_embeddings: If true, an embedder is built with
                ``get_embedding_model(embed_model_name, ckpt)``.
            use_special_tokens: Forwarded to the graph constructors.
            embed_model_name: Embedding model name (used only with embeddings).
            ckpt: Embedding checkpoint (used only with embeddings).
            reload: Rebuild cached graphs and embeddings even if present.
            no_shuffle: Keep ``self.graphs`` in file order instead of shuffling.
            randomize_ne: Forwarded to ``TorchGraph.embed``.
            randomize_ee: Forwarded to ``TorchGraph.embed``.
            random_embed_dim: Forwarded to ``TorchGraph.embed``.
            exclude_labels: Raw label values that are not counted as classes.
                Defaults to ``[None, '']``.

        Raises:
            TypeError: If ``models_dataset`` is not a supported dataset type.
        """
        if exclude_labels is None:
            # Fresh list per instance. A shared mutable default would leak
            # mutations between datasets.
            exclude_labels = [None, '']

        if isinstance(models_dataset, EcoreDataset):
            self.metadata = EcoreMetaData()
        elif isinstance(models_dataset, ArchiMateDataset):
            self.metadata = ArchimateMetaData()
        elif isinstance(models_dataset, OntoUMLDataset):
            self.metadata = OntoUMLMetaData()
        else:
            # Without metadata every later step fails anyway, so fail fast.
            raise TypeError(f"Unsupported models dataset type: {type(models_dataset).__name__}")

        self.distance = distance
        self.use_embeddings = use_embeddings
        self.ckpt = ckpt
        self.embedder = get_embedding_model(embed_model_name, ckpt) if use_embeddings else None

        self.reload = reload

        self.use_edge_types = use_edge_types
        self.use_node_types = use_node_types
        self.use_attributes = use_attributes
        self.use_edge_label = use_edge_label
        self.no_labels = no_labels

        self.add_negative_train_samples = add_negative_train_samples
        self.neg_sampling_ratio = neg_sampling_ratio

        self.test_ratio = test_ratio

        self.no_shuffle = no_shuffle
        self.exclude_labels = exclude_labels

        self.use_special_tokens = use_special_tokens
        self.node_cls_label = node_cls_label
        self.edge_cls_label = edge_cls_label

        self.randomize_ne = randomize_ne
        self.randomize_ee = randomize_ee
        self.random_embed_dim = random_embed_dim

        self.graphs: List[Union[TorchNodeGraph, TorchEdgeGraph]] = []
        # Everything that affects cached graph contents; hashed to name the cache dir.
        self.config = dict(
            name=models_dataset.name,
            distance=distance,
            add_negative_train_samples=add_negative_train_samples,
            neg_sampling_ratio=neg_sampling_ratio,
            use_attributes=use_attributes,
            use_edge_types=use_edge_types,
            use_node_types=use_node_types,
            use_edge_label=use_edge_label,
            no_labels=no_labels,
            use_special_tokens=use_special_tokens,
            use_embeddings=use_embeddings,
            embed_model_name=embed_model_name if use_embeddings else None,
            ckpt=ckpt if use_embeddings else None,
            exclude_labels=exclude_labels,
            node_cls_label=node_cls_label,
            edge_cls_label=edge_cls_label,
            test_ratio=test_ratio,
            randomize_ne=randomize_ne,
            randomize_ee=randomize_ee,
            random_embed_dim=random_embed_dim
        )

        self.save_dir = os.path.join(save_dir, models_dataset.name)
        os.makedirs(self.save_dir, exist_ok=True)

    @property
    def node_dim(self):
        """Placeholder that always returns ``None``. Not implemented in this class."""
        return None

    def get_config_hash(self):
        """Return the MD5 hash of ``self.config`` and record it in ``configs.json``.

        ``<save_dir>/configs.json`` maps each hash to its config so that cache
        directories can be traced back to their settings. The file is rewritten
        only when the hash is new.
        """
        configs_path = os.path.join(self.save_dir, 'configs.json')
        if os.path.exists(configs_path):
            with open(configs_path, 'r') as f:
                configs = json.load(f)
        else:
            configs = dict()

        config_hash = utils.md5_hash(str(self.config))
        if config_hash not in configs:
            configs[config_hash] = self.config
            with open(configs_path, 'w') as f:
                json.dump(configs, f)

        return config_hash

    def set_file_hashes(self, models_dataset: Union[EcoreDataset, ArchiMateDataset]):
        """Map each model's ``hash`` to its cache file path in ``self.file_paths``.

        Models that share a hash collapse into a single entry, so the printed
        duplicate count is the number of dropped models.
        """
        self.config_hash = self.get_config_hash()
        os.makedirs(os.path.join(self.save_dir, self.config_hash), exist_ok=True)
        self.file_paths = {
            graph.hash: os.path.join(self.save_dir, self.config_hash, f'{graph.hash}', 'data.pkl')
            for graph in models_dataset
        }

        print("Number of duplicate graphs: ", len(models_dataset) - len(self.file_paths))

    def set_torch_graphs(
        self,
        type: str,
        models_dataset: Union[EcoreDataset, ArchiMateDataset],
        limit: int = -1
    ):
        """Build and save a torch graph for each model whose cache file is missing.

        Args:
            type: ``'node'`` for ``TorchNodeGraph`` or ``'edge'`` for
                ``TorchEdgeGraph``. The name shadows the builtin but is kept
                for caller compatibility.
            models_dataset: Source models.
            limit: Maximum number of models to use. ``-1`` or a value larger
                than the dataset means all models.

        Raises:
            ValueError: If ``type`` is not ``'node'`` or ``'edge'``.
        """
        if type not in ('node', 'edge'):
            raise ValueError(f"Invalid graph type: {type}")

        common_params = dict(
            metadata=self.metadata,
            distance=self.distance,
            test_ratio=self.test_ratio,
            use_attributes=self.use_attributes,
            use_node_types=self.use_node_types,
            use_edge_types=self.use_edge_types,
            use_edge_label=self.use_edge_label,
            use_special_tokens=self.use_special_tokens,
            no_labels=self.no_labels,
            node_cls_label=self.node_cls_label,
            edge_cls_label=self.edge_cls_label
        )

        def create_node_graph(graph: LangGraph, fp: str) -> TorchNodeGraph:
            node_params = {
                **common_params,
                'fp': fp,
            }
            return TorchNodeGraph(graph, **node_params)

        def create_edge_graph(graph: LangGraph, fp: str) -> TorchEdgeGraph:
            edge_params = {
                **common_params,
                'add_negative_train_samples': self.add_negative_train_samples,
                'neg_samples_ratio': self.neg_sampling_ratio,
                'fp': fp,
            }
            return TorchEdgeGraph(graph, **edge_params)

        models_size = len(models_dataset) \
            if (limit == -1 or limit > len(models_dataset)) else limit
        subset = models_dataset[:models_size]

        self.set_file_hashes(subset)

        create_graph = create_node_graph if type == 'node' else create_edge_graph
        for graph in tqdm(subset, desc=f'Creating {type} graphs'):
            fp = self.file_paths[graph.hash]
            if not os.path.exists(fp) or self.reload:
                torch_graph = create_graph(graph, fp)
                torch_graph.save()

    def embed(self):
        """Embed every cached graph, then load them all into ``self.graphs``.

        After loading, the graphs are shuffled (unless ``no_shuffle``),
        post-processed (labels and optional type features) and validated.
        """
        for fp in tqdm(self.file_paths.values(), desc='Embedding graphs'):
            torch_graph = TorchGraph.load(fp)
            torch_graph.embed(
                self.embedder,
                reload=self.reload,
                randomize_ne=self.randomize_ne,
                randomize_ee=self.randomize_ee,
                random_embed_dim=self.random_embed_dim
            )

        for fp in tqdm(self.file_paths.values(), desc='Re-Loading graphs'):
            torch_graph = TorchGraph.load(fp)
            self.graphs.append(torch_graph)

        if not self.no_shuffle:
            shuffle(self.graphs)

        self.post_process_graphs()
        self.validate_graphs()
        print("Graphs saved")

    def post_process_graphs(self):
        """Attach encoded class labels and, if enabled, one-hot type features.

        With ``use_node_types`` and a ``node_cls_label``, a one-hot encoding of
        the node class is appended to ``data.x``. The edge equivalent is
        appended to ``data.edge_attr``.
        """
        self.add_cls_labels()

        def concatenate(a, b):
            # Keep sparse feature matrices sparse after appending the columns.
            if isinstance(a, csr_matrix):
                return csr_matrix(np.concatenate([a.toarray(), b], axis=1))
            return np.concatenate([a, b], axis=1)

        def set_types(prefix):
            prefix_cls = getattr(self, f"{prefix}_cls_label")
            label_map = getattr(self, f"{prefix}_label_map_{prefix_cls}")
            # The historical width is (#non-excluded classes + 1). Widen it when
            # needed so every encoded label, excluded ones included, has a
            # column. Otherwise np.eye(...)[labels] raises IndexError when two
            # or more labels are excluded.
            num_classes = max(
                getattr(self, f"num_{prefix}s_{prefix_cls}") + 1,
                len(label_map.classes_),
            )
            for g in self.graphs:
                types = np.eye(num_classes)[getattr(g.data, f"{prefix}_{prefix_cls}")]
                if prefix == 'node':
                    g.data.x = concatenate(g.data.x, types)
                elif prefix == 'edge':
                    g.data.edge_attr = concatenate(g.data.edge_attr, types)

            if not self.graphs:
                return
            node_dim = self.graphs[0].data.x.shape[1]
            assert all(g.data.x.shape[1] == node_dim for g in self.graphs), "Node types not added correctly"
            edge_dim = self.graphs[0].data.edge_attr.shape[1]
            assert all(g.data.edge_attr.shape[1] == edge_dim for g in self.graphs), "Edge types not added correctly"

        if self.use_node_types and self.node_cls_label:
            set_types('node')

        if self.use_edge_types and self.edge_cls_label:
            set_types('edge')

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, index: int):
        return self.graphs[index]

    def get_torch_dataset(self):
        """Return ``data.to_graph_data()`` for every loaded graph."""
        return [g.data.to_graph_data() for g in self.graphs]

    def save(self):
        """Persist every loaded graph via its own ``save()``."""
        for torch_graph in self.graphs:
            torch_graph.save()

    def get_train_test_split(self):
        """Return ``(train_idx, test_idx)`` graph index lists from a random split.

        NOTE(review): uses the unseeded ``random.shuffle``, unlike
        :meth:`k_fold_split`, so splits are not reproducible across runs.
        """
        n = len(self.graphs)
        train_size = int(n * (1 - self.test_ratio))
        idx = list(range(n))
        shuffle(idx)
        return idx[:train_size], idx[train_size:]

    def k_fold_split(self):
        """Yield ``(train_idx, test_idx)`` for ``int(1 / test_ratio)`` seeded folds.

        ``StratifiedKFold`` gets constant targets, so it behaves like a plain
        shuffled K-fold.
        """
        k = int(1 / self.test_ratio)
        kfold = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
        n = len(self.graphs)
        for train_idx, test_idx in kfold.split(np.zeros(n), np.zeros(n)):
            yield train_idx, test_idx

    def validate_graphs(self):
        """Assert that node and edge tensor sizes match the underlying graphs.

        Edge graphs are checked against ``data.overall_edge_index``. Other
        graphs are checked against ``data.edge_index``.
        """
        for torch_graph in self.graphs:
            num_nodes = torch_graph.graph.number_of_nodes()
            assert torch_graph.data.x.shape[0] == num_nodes, \
                f"Number of nodes mismatch, {torch_graph.data.x.shape[0]} != {num_nodes}"

            edge_index = (
                torch_graph.data.overall_edge_index
                if isinstance(torch_graph, TorchEdgeGraph)
                else torch_graph.data.edge_index
            )
            num_edges = torch_graph.graph.number_of_edges()
            # Report the same tensor that is actually checked.
            assert edge_index.shape[1] == num_edges, \
                f"Number of edges mismatch, {edge_index.shape[1]} != {num_edges}"

    def add_cls_labels(self):
        """Attach node, edge and graph class labels and the graph text attribute."""
        self.add_node_labels()
        self.add_edge_labels()
        self.add_graph_labels()
        self.add_graph_text()

    def _encode_labels(self, label_values):
        """Fit a ``LabelEncoder`` on the per-graph values pooled across all graphs.

        Args:
            label_values: One list of raw label values per graph.

        Returns:
            ``(encoder, encoded, exclude_ids)``. ``encoded`` holds one encoded
            array per graph. ``exclude_ids`` holds the encoded ids of the
            values in ``self.exclude_labels`` that occur in the data.
        """
        label_map = LabelEncoder()
        label_map.fit([value for values in label_values for value in values])
        encoded = [label_map.transform(values) for values in label_values]
        exclude_ids = [
            label_map.transform([e])[0]
            for e in self.exclude_labels
            if e in label_map.classes_
        ]
        return label_map, encoded, exclude_ids

    def add_node_labels(self):
        """Label-encode each node class attribute listed in ``metadata.node_cls``.

        For each attribute ``c`` this sets ``data.node_<c>`` on every graph
        (encoded labels in ``graph.nodes`` order). On the dataset it sets
        ``node_label_map_<c>``, ``node_exclude_<c>`` (encoded excluded ids) and
        ``num_nodes_<c>`` (number of non-excluded classes).
        """
        model_type = self.metadata.type
        cls_labels = self.metadata.node_cls
        if isinstance(cls_labels, str):
            cls_labels = [cls_labels]

        for cls_label in cls_labels:
            label_values = [
                [get_node_data(node, cls_label, model_type) for _, node in torch_graph.graph.nodes(data=True)]
                for torch_graph in self.graphs
            ]
            node_label_map, encoded, exclude_labels = self._encode_labels(label_values)
            print(node_label_map.classes_)

            for torch_graph, node_classes in zip(self.graphs, encoded):
                setattr(torch_graph.data, f"node_{cls_label}", np.array(node_classes))

            setattr(self, f"node_label_map_{cls_label}", node_label_map)
            setattr(self, f"node_exclude_{cls_label}", exclude_labels)

            num_labels = len(node_label_map.classes_) - len(exclude_labels)
            print("Setting num_nodes_", cls_label, num_labels)
            setattr(self, f"num_nodes_{cls_label}", num_labels)

            if self.graphs and hasattr(self.graphs[0].data, 'train_node_mask'):
                validate_classes(self.graphs, cls_label, exclude_labels, 'node')

    def add_edge_labels(self):
        """Label-encode each edge class attribute listed in ``metadata.edge_cls``.

        For each attribute ``c`` this sets ``data.edge_<c>`` on every graph
        (encoded labels in ``graph.edges`` order). On the dataset it sets
        ``edge_label_map_<c>``, ``edge_exclude_<c>`` and ``num_edges_<c>``.
        """
        model_type = self.metadata.type
        cls_labels = self.metadata.edge_cls
        if isinstance(cls_labels, str):
            cls_labels = [cls_labels]

        for cls_label in cls_labels:
            label_values = [
                [get_edge_data(edge_data, cls_label, model_type) for _, _, edge_data in torch_graph.graph.edges(data=True)]
                for torch_graph in self.graphs
            ]
            edge_label_map, encoded, exclude_labels = self._encode_labels(label_values)
            print("Edge Classes: ", edge_label_map.classes_)

            for torch_graph, edge_classes in zip(self.graphs, encoded):
                setattr(torch_graph.data, f"edge_{cls_label}", np.array(edge_classes))

            setattr(self, f"edge_label_map_{cls_label}", edge_label_map)

            # NOTE(review): unlike node/graph labels, edges store the *raw*
            # (decoded) excluded labels here. Kept because callers may rely on it.
            setattr(self, f"edge_exclude_{cls_label}", edge_label_map.inverse_transform(exclude_labels))

            num_labels = len(edge_label_map.classes_) - len(exclude_labels)
            setattr(self, f"num_edges_{cls_label}", num_labels)

            if self.graphs and hasattr(self.graphs[0].data, 'train_edge_mask'):
                validate_classes(self.graphs, cls_label, exclude_labels, 'edge')

    def add_graph_labels(self):
        """Label-encode the graph-level class ``metadata.graph_cls``, if available.

        Sets ``data.graph_<c>`` (a one-element array) on every graph. On the
        dataset it sets ``graph_exclude_<c>``, ``num_graph_<c>`` and
        ``graph_label_map_<c>``.
        NOTE(review): this is gated on ``hasattr(self.metadata, 'graph')`` but
        reads ``graph_cls``. Confirm that the metadata classes define ``graph``.
        """
        if hasattr(self.metadata, 'graph'):
            cls_label = self.metadata.graph_cls
            if cls_label and self.graphs and hasattr(self.graphs[0].graph, cls_label):
                graph_labels = [getattr(torch_graph.graph, cls_label) for torch_graph in self.graphs]

                graph_label_map = LabelEncoder()
                graph_labels = graph_label_map.fit_transform(graph_labels)

                print("Graph Classes: ", graph_label_map.classes_)

                for torch_graph, graph_label in zip(self.graphs, graph_labels):
                    setattr(torch_graph.data, f"graph_{cls_label}", np.array([graph_label]))

                exclude_labels = [
                    graph_label_map.transform([e])[0]
                    for e in self.exclude_labels
                    if e in graph_label_map.classes_
                ]
                setattr(self, f"graph_exclude_{cls_label}", exclude_labels)
                num_labels = len(graph_label_map.classes_) - len(exclude_labels)

                print("Setting num_graph_", cls_label, num_labels)
                setattr(self, f"num_graph_{cls_label}", num_labels)
                setattr(self, f"graph_label_map_{cls_label}", graph_label_map)

    def add_graph_text(self):
        """Copy the ``metadata.graph_label`` attribute from each graph onto its wrapper."""
        label = self.metadata.graph_label
        if label:
            for torch_graph in self.graphs:
                setattr(torch_graph, label, getattr(torch_graph.graph, label))

    def get_gnn_graph_classification_data(self):
        """Return a random ``{'train', 'test'}`` split of the graphs' ``data`` objects.

        NOTE(review): returns raw ``data`` objects, whereas the k-fold variant
        returns ``data.to_graph_data()``. Kept as-is for existing callers.
        """
        train_idx, test_idx = self.get_train_test_split()
        return {
            'train': [self.graphs[i].data for i in train_idx],
            'test': [self.graphs[i].data for i in test_idx]
        }

    def get_kfold_gnn_graph_classification_data(self):
        """Yield ``{'train', 'test'}`` lists of ``to_graph_data()`` for each fold."""
        for train_idx, test_idx in self.k_fold_split():
            yield {
                'train': [self.graphs[i].data.to_graph_data() for i in train_idx],
                'test': [self.graphs[i].data.to_graph_data() for i in test_idx],
            }

    def __get_lm_data(self, indices, tokenizer, remove_duplicates=False):
        """Build an ``EncodingDataset`` of (graph text, graph class) for ``indices``.

        Also prints every (decoded label, text) pair that was used.
        """
        graph_label_name = self.metadata.graph_cls
        assert graph_label_name is not None, "No Graph Label found in data. Please define graph label in metadata"
        X = [getattr(self.graphs[i], self.metadata.graph_label) for i in indices]
        y = [getattr(self.graphs[i].data, f'graph_{graph_label_name}')[0].item() for i in indices]

        dataset = EncodingDataset(tokenizer, X, y, remove_duplicates=remove_duplicates)
        # Look up the encoder for the configured graph class. The old code
        # hard-coded ``graph_label_map_label`` and failed for any other name.
        label_map = getattr(self, f"graph_label_map_{graph_label_name}")
        print("\n".join(
            f"Label: {label_map.inverse_transform([label_id])[0]}, Text: {text}"
            for text, label_id in zip(X, y)
        ))

        return dataset

    def get_lm_graph_classification_data(self, tokenizer):
        """Return tokenized train/test graph-classification datasets and the class count."""
        assert self.metadata.graph_cls, "No Graph Label found in data. Please define graph label in metadata"
        train_idx, test_idx = self.get_train_test_split()
        train_dataset = self.__get_lm_data(train_idx, tokenizer)
        test_dataset = self.__get_lm_data(test_idx, tokenizer)

        return {
            'train': train_dataset,
            'test': test_dataset,
            'num_classes': getattr(self, f'num_graph_{self.metadata.graph_cls}')
        }

    def get_kfold_lm_graph_classification_data(self, tokenizer, remove_duplicates=True):
        """Yield tokenized train/test datasets and the class count for each fold."""
        assert self.metadata.graph_cls, "No Graph Label found in data. Please define graph label in metadata"
        for train_idx, test_idx in self.k_fold_split():
            train_dataset = self.__get_lm_data(train_idx, tokenizer, remove_duplicates)
            test_dataset = self.__get_lm_data(test_idx, tokenizer, remove_duplicates)
            yield {
                'train': train_dataset,
                'test': test_dataset,
                'num_classes': getattr(self, f'num_graph_{self.metadata.graph_cls}')
            }
|
563
|
+
|
564
|
+
|
565
|
+
class GraphEdgeDataset(GraphDataset):
    """Graph dataset for edge-level tasks: edge classification or link prediction.

    Construction builds (or reuses) cached ``TorchEdgeGraph`` objects. If a
    Word2Vec/TF-IDF embedder is configured, it is trained on the edge texts.
    The graphs are then embedded and loaded, and per-class edge counts for the
    train and test masks are printed.
    """

    def __init__(
        self,
        models_dataset: Union[EcoreDataset, ArchiMateDataset],
        save_dir='datasets/graph_data',
        distance=0,
        reload=False,
        test_ratio=0.2,

        add_negative_train_samples=False,
        neg_sampling_ratio=1,

        use_attributes=False,
        use_edge_types=False,
        use_edge_label=False,
        use_node_types=False,
        no_labels=False,

        use_embeddings=False,
        embed_model_name='bert-base-uncased',
        ckpt=None,

        no_shuffle=False,
        randomize_ne=False,
        randomize_ee=False,
        random_embed_dim=128,

        use_special_tokens=False,

        limit: int = -1,

        node_cls_label: str = None,
        edge_cls_label: str = None,

        task_type=LP_TASK_EDGE_CLS
    ):
        """Build, embed and load the edge graphs.

        Args:
            limit: Maximum number of models to use (``-1`` means all).
            task_type: ``LP_TASK_EDGE_CLS`` or ``LP_TASK_LINK_PRED``. Selects
                how texts are produced and how LM datasets are labelled.
            All other arguments are forwarded to ``GraphDataset.__init__``.
        """
        super().__init__(
            models_dataset=models_dataset,
            save_dir=save_dir,
            distance=distance,
            test_ratio=test_ratio,

            use_node_types=use_node_types,
            use_edge_types=use_edge_types,
            use_edge_label=use_edge_label,
            use_attributes=use_attributes,
            no_labels=no_labels,

            add_negative_train_samples=add_negative_train_samples,
            neg_sampling_ratio=neg_sampling_ratio,

            use_special_tokens=use_special_tokens,

            node_cls_label=node_cls_label,
            edge_cls_label=edge_cls_label,

            use_embeddings=use_embeddings,
            embed_model_name=embed_model_name,
            ckpt=ckpt,

            reload=reload,
            no_shuffle=no_shuffle,

            randomize_ne=randomize_ne,
            randomize_ee=randomize_ee,
            random_embed_dim=random_embed_dim,
        )

        self.task_type = task_type

        self.set_torch_graphs('edge', models_dataset, limit)

        if self.use_embeddings and isinstance(self.embedder, (Word2VecEmbedder, TfidfEmbedder)):
            # These embedders must be fitted on the corpus before embed() runs.
            texts = self.get_link_prediction_texts(only_texts=True)
            texts = [
                text
                for key, values in texts.items()
                if not key.endswith("classes")
                for text in values
            ]
            print(f"Training {self.embedder.name} Embedder")
            self.embedder.train(texts)
            print(f"Trained {self.embedder.name} Embedder")

        self.embed()
        self._print_edge_class_counts()

    def _print_edge_class_counts(self):
        """Print per-class edge counts over all graphs' train and test edge masks.

        Handles ``metadata.edge_cls`` given either as a single name or as a list.
        """
        cls_labels = self.metadata.edge_cls
        if isinstance(cls_labels, str):
            cls_labels = [cls_labels]

        for cls_label in cls_labels:
            train_count, test_count = Counter(), Counter()
            for g in self.graphs:
                encoded = getattr(g.data, f'edge_{cls_label}')
                train_count.update(encoded[g.data.train_edge_mask].tolist())
                test_count.update(encoded[g.data.test_edge_mask].tolist())

            print(f"Train edge classes: {dict(train_count)}")
            print(f"Test edge classes: {dict(test_count)}")

    def get_link_prediction_texts(
        self,
        label: str = None,
        only_texts: bool = False
    ):
        """Collect per-split edge texts (and classes) from every graph.

        Args:
            label: Edge class attribute; defaults to ``self.edge_cls_label``.
            only_texts: Forwarded to ``TorchEdgeGraph.get_link_prediction_texts``.

        Returns:
            A ``defaultdict(list)`` that concatenates each graph's per-key lists
            (e.g. ``train_pos_edges``, ``test_pos_edges``).
        """
        if label is None:
            label = self.edge_cls_label

        assert label is not None, "No edge label found in data. Please define edge label in metadata"

        # Before embed() runs, self.graphs is still empty. Fall back to the
        # graphs cached on disk so that callers (e.g. embedder training in
        # __init__) see real data instead of an empty corpus.
        if self.graphs:
            graph_iter = self.graphs
            total = len(self.graphs)
        else:
            file_paths = getattr(self, 'file_paths', {})
            graph_iter = (TorchGraph.load(fp) for fp in file_paths.values())
            total = len(file_paths)

        data = defaultdict(list)
        for torch_graph in tqdm(graph_iter, desc='Getting Graph Texts', total=total):
            graph_data = torch_graph.get_link_prediction_texts(label, self.task_type, only_texts)
            for key, values in graph_data.items():
                data[key] += values

        print("Train Texts: ", data['train_pos_edges'][:20])
        print("Test Texts: ", data['test_pos_edges'][:20])
        return data

    def get_link_prediction_lm_data(
        self,
        tokenizer: AutoTokenizer,
        label: str = None,
    ):
        """Tokenize edge texts into train/test ``EncodingDataset`` objects.

        For ``LP_TASK_EDGE_CLS``, targets are the edge classes. For
        ``LP_TASK_LINK_PRED``, positive edges get label 1 and negative edges
        get label 0.

        Raises:
            ValueError: If ``self.task_type`` is not a supported task.
        """
        # Validate first so a bad task type fails before the expensive text pass.
        if self.task_type not in (LP_TASK_EDGE_CLS, LP_TASK_LINK_PRED):
            raise ValueError(f"Invalid task type: {self.task_type}")

        if label is None:
            label = self.edge_cls_label

        data = self.get_link_prediction_texts(label)

        print("Tokenizing data")
        if self.task_type == LP_TASK_EDGE_CLS:
            datasets = {
                'train': EncodingDataset(
                    tokenizer,
                    data['train_pos_edges'],
                    data['train_edge_classes']
                ),
                'test': EncodingDataset(
                    tokenizer,
                    data['test_pos_edges'],
                    data['test_edge_classes']
                )
            }
        else:
            datasets = {
                'train': EncodingDataset(
                    tokenizer,
                    data['train_pos_edges'] + data['train_neg_edges'],
                    [1] * len(data['train_pos_edges']) + [0] * len(data['train_neg_edges'])
                ),
                'test': EncodingDataset(
                    tokenizer,
                    data['test_pos_edges'] + data['test_neg_edges'],
                    [1] * len(data['test_pos_edges']) + [0] * len(data['test_neg_edges'])
                )
            }

        print("Tokenized data")

        return datasets
|
736
|
+
|
737
|
+
|
738
|
+
class GraphNodeDataset(GraphDataset):
    """Graph dataset prepared for node-classification tasks.

    On construction it builds the torch graphs via
    ``set_torch_graphs('node', models_dataset, limit)``. If embeddings are
    enabled with a Word2Vec or TF-IDF embedder, it trains that embedder on the
    node texts. It then calls ``self.embed()`` and prints the train/test class
    distribution for every configured node label.
    """

    def __init__(
        self,
        models_dataset: Union[EcoreDataset, ArchiMateDataset],
        save_dir='datasets/graph_data',
        distance=0,
        test_ratio=0.2,
        reload=False,

        use_attributes=False,
        use_edge_types=False,
        use_node_types=False,
        use_edge_label=False,
        use_special_tokens=False,

        use_embeddings=False,
        embed_model_name='bert-base-uncased',
        ckpt=None,

        no_shuffle=False,
        randomize_ne=False,
        randomize_ee=False,
        random_embed_dim=128,

        limit: int = -1,
        no_labels=False,
        node_cls_label: str = None,
        edge_cls_label: str = None
    ):
        """Create the node dataset.

        All arguments except ``limit`` are forwarded unchanged to
        ``GraphDataset.__init__``. ``limit`` is passed to ``set_torch_graphs``.
        """
        super().__init__(
            models_dataset=models_dataset,
            save_dir=save_dir,
            distance=distance,
            test_ratio=test_ratio,

            use_node_types=use_node_types,
            use_edge_types=use_edge_types,
            use_edge_label=use_edge_label,
            use_attributes=use_attributes,
            no_labels=no_labels,

            node_cls_label=node_cls_label,
            edge_cls_label=edge_cls_label,

            use_embeddings=use_embeddings,
            embed_model_name=embed_model_name,
            ckpt=ckpt,

            reload=reload,
            no_shuffle=no_shuffle,

            use_special_tokens=use_special_tokens,

            randomize_ne=randomize_ne,
            randomize_ee=randomize_ee,
            random_embed_dim=random_embed_dim,
        )

        self.set_torch_graphs('node', models_dataset, limit)

        # Only these embedder types are trained here (presumably they must be
        # fitted on the corpus before embedding — confirm in embeddings/).
        if self.use_embeddings and isinstance(self.embedder, (Word2VecEmbedder, TfidfEmbedder)):
            texts = self.get_node_classification_texts()
            # Train on node strings only: keys ending in "classes" hold labels.
            # (The key, not the list value, must be tested — lists have no .endswith.)
            corpus = [
                text
                for key, values in texts.items()
                if not key.endswith("classes")
                for text in values
            ]
            print(f"Training {self.embedder.name} Embedder")
            self.embedder.train(corpus)
            print(f"Trained {self.embedder.name} Embedder")

        self.embed()

        node_labels = self.metadata.node_cls
        if isinstance(node_labels, str):
            node_labels = [node_labels]

        for node_label in node_labels:
            self._print_node_class_distribution(node_label)

    def _print_node_class_distribution(self, node_label):
        """Print per-class counts of ``node_<node_label>`` over all train/test nodes."""
        print(f"Node label: {node_label}")
        train_count, test_count = Counter(), Counter()
        for g in self.graphs:
            labels = getattr(g.data, f'node_{node_label}')
            train_count.update(labels[g.data.train_node_mask].tolist())
            test_count.update(labels[g.data.test_node_mask].tolist())

        # dict() keeps the plain "{...}" print format (insertion order preserved).
        print(f"Train Node classes: {dict(train_count)}")
        print(f"Test Node classes: {dict(test_count)}")

    def get_node_classification_texts(self, distance=None, label=None):
        """Collect node strings and their labels for the train and test splits.

        Args:
            distance: Neighbourhood distance passed to ``get_graph_node_strs``;
                ``self.distance`` when ``None``.
            label: Node label name; ``self.metadata.node_cls`` when ``None``.
                If it is a list, only its first entry is used.

        Returns:
            dict with keys ``'train_nodes'``, ``'train_node_classes'``,
            ``'test_nodes'`` and ``'test_node_classes'``. Each value is a list
            accumulated over all graphs, after ``exclude_labels_from_data`` has
            removed entries for the excluded labels.
        """
        if distance is None:
            distance = self.distance

        label = self.metadata.node_cls if label is None else label

        if isinstance(label, list):
            label = label[0]

        # Same for every graph, so look it up once.
        exclude_labels = getattr(self, f'node_exclude_{label}')

        data = {'train_nodes': [], 'train_node_classes': [], 'test_nodes': [], 'test_node_classes': []}
        for torch_graph in tqdm(self.graphs, desc='Getting node classification data'):
            graph_data = torch_graph.data
            node_strs = list(torch_graph.get_graph_node_strs(graph_data.edge_index, distance).values())

            # NOTE(review): the masks are iterated as node-index tensors here
            # (i.item() used as a list index); if they were boolean masks this
            # would select nodes 0/1 only — confirm against the graph builder.
            train_node_strs = [node_strs[i.item()] for i in graph_data.train_node_mask]
            test_node_strs = [node_strs[i.item()] for i in graph_data.test_node_mask]

            node_classes = getattr(graph_data, f'node_{label}')
            train_node_classes = node_classes[graph_data.train_node_mask]
            test_node_classes = node_classes[graph_data.test_node_mask]

            train_node_strs, train_node_classes = exclude_labels_from_data(train_node_strs, train_node_classes, exclude_labels)
            test_node_strs, test_node_classes = exclude_labels_from_data(test_node_strs, test_node_classes, exclude_labels)

            data['train_nodes'] += train_node_strs
            data['train_node_classes'] += train_node_classes
            data['test_nodes'] += test_node_strs
            data['test_node_classes'] += test_node_classes

        print("Tokenizing data")
        print(data['train_nodes'][:10])
        print(data['test_nodes'][:10])

        print(len(data['train_nodes']))
        print(len(data['train_node_classes']))
        print(len(data['test_nodes']))
        print(len(data['test_node_classes']))
        return data

    def get_node_classification_lm_data(
        self,
        label: str,
        tokenizer: AutoTokenizer,
        distance: int = 0,
    ):
        """Build tokenized train/test ``EncodingDataset`` objects for node classification.

        Args:
            label: Node label name passed to ``get_node_classification_texts``.
            tokenizer: Tokenizer handed to each ``EncodingDataset``.
            distance: Neighbourhood distance for the node strings.
                NOTE(review): defaults to 0, not ``self.distance`` (unlike
                ``get_node_classification_texts``) — confirm this is intended.

        Returns:
            dict with keys ``'train'`` and ``'test'``.
        """
        data = self.get_node_classification_texts(distance, label)
        dataset = {
            'train': EncodingDataset(
                tokenizer,
                data['train_nodes'],
                data['train_node_classes']
            ),
            'test': EncodingDataset(
                tokenizer,
                data['test_nodes'],
                data['test_node_classes']
            )
        }

        print("Tokenized data")

        return dataset
|
900
|
+
|
901
|
+
|
902
|
+
def get_models_gpt_dataset(
    models_dataset: Union[ArchiMateDataset, EcoreDataset],
    tokenizer: AutoTokenizer,
    chunk_size: int = 100,
    chunk_overlap: int = 20,
    max_length: int = 128,
    **config_params
):
    """Build a ``GPTTextDataset`` of chunked edge texts from a models dataset.

    A ``GraphEdgeDataset`` is built from ``models_dataset`` and
    ``config_params``. Its train and test positive-edge texts are split into
    character chunks, and those chunks are tokenized.

    Args:
        models_dataset: Source models dataset.
        tokenizer: Tokenizer passed to ``GPTTextDataset``.
        chunk_size: Target chunk length in characters for the text splitter.
        chunk_overlap: Characters shared by consecutive chunks.
        max_length: ``max_length`` passed to ``GPTTextDataset``.
        **config_params: Forwarded to ``GraphEdgeDataset``.

    Returns:
        The ``GPTTextDataset`` built from the chunks.
    """

    def split_texts_into_chunks(
        texts: List[str],
        size: int = 100,
        overlap: int = 20,
    ):
        """Split each text into chunks targeting ``size`` characters with ``overlap`` overlap."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=size,
            chunk_overlap=overlap,
            length_function=len,  # chunk length is measured in characters
            is_separator_regex=False,
        )
        return [t.page_content for t in text_splitter.create_documents(texts)]

    graph_dataset = GraphEdgeDataset(models_dataset, **config_params)
    texts_data = graph_dataset.get_link_prediction_texts()
    texts = texts_data['train_pos_edges'] + texts_data['test_pos_edges']

    print("Total texts", len(texts))
    splitted_texts = split_texts_into_chunks(
        texts,
        size=chunk_size,
        overlap=chunk_overlap
    )
    print(len(splitted_texts))
    print("Tokenizing...")
    # Tokenize the chunks (not the raw texts) so chunk_size/chunk_overlap take effect.
    dataset = GPTTextDataset(splitted_texts, tokenizer, max_length=max_length)
    print("Tokenized")
    print("Max length", dataset[:]['input_ids'].shape)
    return dataset
|