glam4cm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glam4cm/__init__.py +9 -0
- glam4cm/data_loading/__init__.py +0 -0
- glam4cm/data_loading/data.py +631 -0
- glam4cm/data_loading/encoding.py +76 -0
- glam4cm/data_loading/graph_dataset.py +940 -0
- glam4cm/data_loading/metadata.py +84 -0
- glam4cm/data_loading/models_dataset.py +361 -0
- glam4cm/data_loading/utils.py +20 -0
- glam4cm/downstream_tasks/__init__.py +0 -0
- glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
- glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
- glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
- glam4cm/downstream_tasks/bert_node_classification.py +164 -0
- glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
- glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
- glam4cm/downstream_tasks/common_args.py +160 -0
- glam4cm/downstream_tasks/create_dataset.py +51 -0
- glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
- glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
- glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
- glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
- glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
- glam4cm/downstream_tasks/utils.py +35 -0
- glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
- glam4cm/embeddings/__init__.py +0 -0
- glam4cm/embeddings/bert.py +72 -0
- glam4cm/embeddings/common.py +43 -0
- glam4cm/embeddings/fasttext.py +0 -0
- glam4cm/embeddings/tfidf.py +25 -0
- glam4cm/embeddings/w2v.py +41 -0
- glam4cm/encoding/__init__.py +0 -0
- glam4cm/encoding/common.py +0 -0
- glam4cm/encoding/encoders.py +100 -0
- glam4cm/graph2str/__init__.py +0 -0
- glam4cm/graph2str/common.py +34 -0
- glam4cm/graph2str/constants.py +15 -0
- glam4cm/graph2str/ontouml.py +141 -0
- glam4cm/graph2str/uml.py +0 -0
- glam4cm/lang2graph/__init__.py +0 -0
- glam4cm/lang2graph/archimate.py +31 -0
- glam4cm/lang2graph/bpmn.py +0 -0
- glam4cm/lang2graph/common.py +416 -0
- glam4cm/lang2graph/ecore.py +221 -0
- glam4cm/lang2graph/ontouml.py +169 -0
- glam4cm/lang2graph/utils.py +80 -0
- glam4cm/models/cmgpt.py +352 -0
- glam4cm/models/gnn_layers.py +273 -0
- glam4cm/models/hf.py +10 -0
- glam4cm/run.py +99 -0
- glam4cm/run_configs.py +126 -0
- glam4cm/settings.py +54 -0
- glam4cm/tokenization/__init__.py +0 -0
- glam4cm/tokenization/special_tokens.py +4 -0
- glam4cm/tokenization/utils.py +37 -0
- glam4cm/trainers/__init__.py +0 -0
- glam4cm/trainers/bert_classifier.py +105 -0
- glam4cm/trainers/cm_gpt_trainer.py +153 -0
- glam4cm/trainers/gnn_edge_classifier.py +126 -0
- glam4cm/trainers/gnn_graph_classifier.py +123 -0
- glam4cm/trainers/gnn_link_predictor.py +144 -0
- glam4cm/trainers/gnn_node_classifier.py +135 -0
- glam4cm/trainers/gnn_trainer.py +129 -0
- glam4cm/trainers/metrics.py +55 -0
- glam4cm/utils.py +194 -0
- glam4cm-0.1.0.dist-info/LICENSE +21 -0
- glam4cm-0.1.0.dist-info/METADATA +86 -0
- glam4cm-0.1.0.dist-info/RECORD +72 -0
- glam4cm-0.1.0.dist-info/WHEEL +5 -0
- glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
- glam4cm-0.1.0.dist-info/top_level.txt +1 -0
glam4cm/__init__.py
ADDED
File without changes
|
@@ -0,0 +1,631 @@
|
|
1
|
+
import pickle
|
2
|
+
import numpy as np
|
3
|
+
from sklearn.model_selection import train_test_split
|
4
|
+
import torch
|
5
|
+
import networkx as nx
|
6
|
+
import os
|
7
|
+
from glam4cm.data_loading.metadata import (
|
8
|
+
ArchimateMetaData,
|
9
|
+
EcoreMetaData
|
10
|
+
)
|
11
|
+
from glam4cm.embeddings.common import Embedder
|
12
|
+
from glam4cm.lang2graph.archimate import ArchiMateNxG
|
13
|
+
from glam4cm.lang2graph.ecore import EcoreNxG
|
14
|
+
from glam4cm.lang2graph.common import (
|
15
|
+
create_graph_from_edge_index,
|
16
|
+
get_node_texts,
|
17
|
+
get_edge_texts
|
18
|
+
)
|
19
|
+
|
20
|
+
from scipy.sparse import csr_matrix
|
21
|
+
|
22
|
+
from glam4cm.settings import LP_TASK_EDGE_CLS, LP_TASK_LINK_PRED
|
23
|
+
from glam4cm.tokenization.special_tokens import *
|
24
|
+
from torch_geometric.transforms import RandomLinkSplit
|
25
|
+
import torch
|
26
|
+
from torch_geometric.data import Data, Dataset
|
27
|
+
from typing import List, Optional, Sequence, Union
|
28
|
+
from glam4cm.tokenization.utils import doc_tokenizer
|
29
|
+
|
30
|
+
|
31
|
+
|
32
|
+
def edge_index_to_idx(graph, edge_index):
    """Map each ``(u, v)`` column of ``edge_index`` to the graph's edge id.

    Args:
        graph: Object exposing ``edge_to_idx``, a mapping from ``(u, v)``
            tuples to integer edge ids.
        edge_index: Two-row edge index whose columns are ``(u, v)`` pairs.
            May be a ``torch.Tensor``, a NumPy array or a nested list.

    Returns:
        A 1-D ``torch.long`` tensor with one edge id per column, in column
        order.

    Raises:
        KeyError: If a pair is not present in ``graph.edge_to_idx``.
    """
    # as_tensor is a no-op for tensors and lets NumPy/list inputs share the
    # same code path (graph.edge_index is handled as a NumPy array elsewhere).
    pairs = torch.as_tensor(edge_index).t().tolist()
    return torch.tensor(
        [graph.edge_to_idx[(u, v)] for u, v in pairs],
        dtype=torch.long,
    )
|
40
|
+
|
41
|
+
|
42
|
+
class GraphData(Data):
    """``Data`` subclass whose batching rules are chosen from attribute names.

    ``__inc__`` and ``__cat_dim__`` are the hooks torch_geometric consults
    when several graphs are collated into one batch (see the PyG docs on
    advanced mini-batching).
    """

    def __inc__(self, key, value, *args, **kwargs):
        """Return the increment applied to attribute ``key`` during batching."""
        # Anything named like an index or a node mask is shifted by the node
        # count; edge masks are shifted by the edge count.
        if 'index' in key or 'node_mask' in key:
            return self.num_nodes
        if 'edge_mask' in key:
            return self.num_edges
        return 0

    def __cat_dim__(self, key, value, *args, **kwargs):
        """Return the dimension along which attribute ``key`` is concatenated."""
        # Index-like attributes are concatenated along dim 1, the rest along dim 0.
        return 1 if 'index' in key else 0
|
56
|
+
|
57
|
+
|
58
|
+
class NumpyData:
    """Attribute container that keeps tensor fields as NumPy arrays.

    Every key of the mapping given to the constructor becomes an attribute;
    ``torch.Tensor`` values are converted with ``.numpy()`` on the way in and
    can be turned back into tensors with :meth:`to_graph_data`.
    """

    def __init__(self, data: Optional[dict] = None):
        """Populate the container from ``data`` (anything exposing ``.items()``).

        Args:
            data: Mapping of attribute name to value. ``None`` (the default)
                creates an empty container.
        """
        # ``None`` instead of ``{}`` avoids sharing one dict across all calls.
        self.set_data({} if data is None else data)

    def __getitem__(self, key):
        """Allow ``obj['name']`` as an alias for ``obj.name``."""
        return getattr(self, key)

    def set_data(self, data: dict):
        """Copy every item of ``data`` onto ``self``, converting tensors to NumPy."""
        for k, v in data.items():
            if isinstance(v, torch.Tensor):
                v = v.numpy()
            setattr(self, k, v)

    def __repr__(self):
        """Summarise arrays by shape and show every other value verbatim."""
        fields = [
            f"{k}={list(v.shape)}" if isinstance(v, np.ndarray)
            else f"{k}={v}"
            for k, v in self.__dict__.items()
        ]
        return "NumpyData(" + ", ".join(fields) + ")"

    def to_graph_data(self):
        """Build a ``GraphData`` holding tensor versions of the stored fields.

        NumPy arrays and SciPy CSR matrices (densified) become tensors, Python
        ints become ``torch.long`` scalars, and ``float64`` tensors are cast
        to ``float32``. Values of any other type are copied unchanged.
        """
        data = GraphData()
        for k, v in self.__dict__.items():
            if isinstance(v, np.ndarray):
                v = torch.from_numpy(v)
            elif isinstance(v, csr_matrix):
                v = torch.from_numpy(v.toarray())
            elif isinstance(v, int):
                v = torch.tensor(v, dtype=torch.long)

            # Only tensors have a meaningful dtype here; the guard keeps
            # non-tensor values (str, float, list, ...) from raising
            # AttributeError on ``.dtype``.
            if isinstance(v, torch.Tensor) and v.dtype == torch.float64:
                v = v.float()

            setattr(data, k, v)
        return data
|
94
|
+
|
95
|
+
|
96
|
+
class TorchGraph:
    """Base class tying a parsed model graph to its texts and embeddings.

    The constructor only records configuration and an empty ``NumpyData``.
    The methods below read ``self.node_texts`` and ``self.edge_texts``, which
    are not set here; the subclasses in this module assign them (together
    with ``self.data``) from their ``get_pyg_data`` result.
    """

    def __init__(
        self,
        graph: Union[EcoreNxG, ArchiMateNxG],
        metadata: Union[EcoreMetaData, ArchimateMetaData],
        distance = 0,
        test_ratio=0.2,
        use_edge_types=False,
        use_node_types=False,
        use_attributes=False,
        use_edge_label=False,
        use_special_tokens=False,
        no_labels=False,
        node_cls_label=None,
        edge_cls_label='type',
        fp='test_graph.pkl'
    ):
        """Store the graph, its metadata and the text-rendering options.

        Args:
            graph: Parsed model graph.
            metadata: Metadata object forwarded to the text helpers.
            distance: Forwarded to ``get_node_texts`` / ``get_edge_texts``
                (presumably a neighbourhood radius; confirm in
                ``lang2graph.common``).
            test_ratio: Split ratio; stored here and read by subclasses.
            use_edge_types, use_node_types, use_attributes, use_edge_label,
            use_special_tokens, no_labels: Flags forwarded unchanged to the
                text helpers.
            node_cls_label: Forwarded to ``get_node_texts``.
            edge_cls_label: Forwarded to ``get_node_texts``.
            fp: Path of the pickle file used by :meth:`save` and :meth:`embed`.
        """
        self.fp = fp
        self.graph = graph
        self.metadata = metadata

        # Keep the raw source: ``xmi`` when the graph has it, else ``json_obj``.
        self.raw_data = graph.xmi if hasattr(graph, 'xmi') else graph.json_obj
        self.use_edge_types = use_edge_types
        self.use_node_types = use_node_types
        self.use_attributes = use_attributes
        self.use_edge_label = use_edge_label
        self.use_special_tokens = use_special_tokens
        self.no_labels = no_labels

        self.node_cls_label = node_cls_label
        self.edge_cls_label = edge_cls_label

        self.distance = distance
        self.test_ratio = test_ratio
        self.data = NumpyData()

    @staticmethod
    def load(pkl_path):
        """Unpickle and return the object stored at ``pkl_path``.

        Declared static so it works both as ``TorchGraph.load(path)`` and on
        an instance.
        """
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(pkl_path, 'rb') as f:
            return pickle.load(f)

    def save(self):
        """Pickle this object to ``self.fp``, creating parent directories."""
        parent = os.path.dirname(self.fp)
        # A bare file name has an empty dirname, and os.makedirs('') raises
        # FileNotFoundError, so only create a directory when there is one.
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(self.fp, 'wb') as f:
            pickle.dump(self, f)

    def get_node_edge_strings(self, edge_index):
        """Return ``(node_texts, edge_texts)`` for this graph.

        Node texts are rendered on the subgraph induced by ``edge_index``;
        edge texts are rendered for every edge of the full graph.
        """
        node_texts = self.get_graph_node_strs(
            edge_index=edge_index,
            distance=self.distance
        )

        edge_texts = self.get_graph_edge_strs()
        return node_texts, edge_texts

    def embed(
        self,
        embedder: Union[Embedder, None],
        reload=False,
        randomize_ne=False,
        randomize_ee=False,
        random_embed_dim=128
    ):
        """Fill ``self.data.x`` and ``self.data.edge_attr`` and pickle the result.

        Args:
            embedder: Object with an ``embed(list_of_texts)`` method, or
                ``None`` to use random embeddings.
            reload: When true, ignore any pickle already present at ``self.fp``.
            randomize_ne: Use random node embeddings even if an embedder is given.
            randomize_ee: Use random edge embeddings even if an embedder is given.
            random_embed_dim: Width of the random embeddings.
        """

        def generate_embeddings():
            """Compute node and edge embeddings into ``self.data``."""
            if randomize_ne or embedder is None:
                print("Randomizing node embeddings")
                self.data.x = np.random.randn(self.graph.number_of_nodes(), random_embed_dim)
            else:
                self.data.x = embedder.embed(list(self.node_texts.values()))

            if randomize_ee or embedder is None:
                print("Randomizing edge embeddings")
                self.data.edge_attr = np.random.randn(self.graph.number_of_edges(), random_embed_dim)
            else:
                self.data.edge_attr = embedder.embed(list(self.edge_texts.values()))

        if os.path.exists(self.fp) and not reload:
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(self.fp, 'rb') as f:
                obj = pickle.load(f)  # a previously saved TorchEdgeGraph/TorchNodeGraph
            # NOTE(review): when the cached object already has both
            # embeddings, nothing is copied into ``self.data``. That is only
            # correct if ``self`` is the object loaded from this file;
            # confirm against callers.
            if not hasattr(obj.data, 'x') or not hasattr(obj.data, 'edge_attr'):
                generate_embeddings()
                self.save()
        else:
            if embedder is not None:
                generate_embeddings()
            else:
                self.data.x = np.random.randn(self.graph.number_of_nodes(), random_embed_dim)
                self.data.edge_attr = np.random.randn(self.graph.number_of_edges(), random_embed_dim)

            self.save()

    def get_graph_node_strs(
        self,
        edge_index: np.ndarray,
        distance = None,
        preprocessor: callable = doc_tokenizer
    ):
        """Render node texts on the subgraph described by ``edge_index``.

        Args:
            edge_index: Edge index handed to ``create_graph_from_edge_index``.
            distance: Overrides ``self.distance`` when not ``None``.
            preprocessor: Callable forwarded to ``get_node_texts``.

        Returns:
            Whatever ``get_node_texts`` returns (used elsewhere in this class
            as a mapping whose values are strings).
        """
        if distance is None:
            distance = self.distance

        subgraph = create_graph_from_edge_index(self.graph, edge_index)
        return get_node_texts(
            subgraph,
            distance,
            metadata=self.metadata,
            use_node_attributes=self.use_attributes,
            use_node_types=self.use_node_types,
            use_edge_types=self.use_edge_types,
            use_edge_label=self.use_edge_label,
            node_cls_label=self.node_cls_label,
            edge_cls_label=self.edge_cls_label,
            use_special_tokens=self.use_special_tokens,
            no_labels=self.no_labels,
            preprocessor=preprocessor
        )

    def get_graph_edge_strs(
        self,
        edge_index: np.ndarray = None,
        neg_samples=False,
        preprocessor: callable = doc_tokenizer
    ):
        """Render one text per edge and return them keyed by ``(u, v)``.

        Args:
            edge_index: Two-row edge index whose columns are ``(u, v)`` pairs;
                defaults to ``self.graph.edge_index``.
            neg_samples: Forwarded to ``get_edge_texts``.
            preprocessor: Callable forwarded to ``get_edge_texts``.

        Returns:
            Dict mapping each ``(u, v)`` pair to its text. Repeated pairs keep
            only the last rendered text.
        """
        if edge_index is None:
            edge_index = self.graph.edge_index

        edge_strs = {}
        for u, v in edge_index.T:
            edge_strs[(u, v)] = get_edge_texts(
                self.graph.numbered_graph,
                (u, v),
                d=self.distance,
                metadata=self.metadata,
                use_node_attributes=self.use_attributes,
                use_node_types=self.use_node_types,
                use_edge_types=self.use_edge_types,
                use_edge_label=self.use_edge_label,
                use_special_tokens=self.use_special_tokens,
                no_labels=self.no_labels,
                preprocessor=preprocessor,
                neg_samples=neg_samples
            )

        return edge_strs

    def validate_data(self):
        """Assert that ``self.data`` reports the same node count as the graph."""
        assert self.data.num_nodes == self.graph.number_of_nodes()

    def set_graph_label(self):
        """Synthesise the graph-level label attribute when the graph lacks it.

        If the metadata names a graph label and the graph has no attribute of
        that name, the attribute is set to the tokenised concatenation of all
        node texts.
        """
        label_attr = self.metadata.graph_label
        # Metadata defines a label but the graph object does not carry it yet.
        if label_attr is not None and not hasattr(self.graph, label_attr):
            text = doc_tokenizer("\n".join(list(self.node_texts.values())))
            setattr(self.graph, label_attr, text)

    @property
    def name(self):
        """``graph_id`` with ``/`` replaced by ``_`` and its last ``.``-part dropped.

        An id without any dot yields an empty string.
        """
        return '.'.join(self.graph.graph_id.replace('/', '_').split('.')[:-1])
|
263
|
+
|
264
|
+
|
265
|
+
|
266
|
+
class TorchEdgeGraph(TorchGraph):
    """Graph wrapper that splits edges into train/test sets at construction."""

    def __init__(
        self,
        graph: Union[EcoreNxG, ArchiMateNxG],
        metadata: Union[EcoreMetaData, ArchimateMetaData],
        distance = 1,
        test_ratio=0.2,
        add_negative_train_samples=False,
        neg_samples_ratio=1,
        use_edge_types=False,
        use_node_types=False,
        use_edge_label=False,
        use_attributes=False,
        use_special_tokens=False,
        node_cls_label=None,
        edge_cls_label='type',
        no_labels=False,
        fp: str = 'test_graph.pkl'
    ):
        """Record the options, build the edge split, then validate and label.

        ``add_negative_train_samples`` and ``neg_samples_ratio`` are handed to
        ``RandomLinkSplit``; every other argument goes to ``TorchGraph``.
        """
        super().__init__(
            graph=graph,
            metadata=metadata,
            fp=fp,
            distance=distance,
            test_ratio=test_ratio,
            node_cls_label=node_cls_label,
            edge_cls_label=edge_cls_label,
            use_node_types=use_node_types,
            use_edge_types=use_edge_types,
            use_attributes=use_attributes,
            use_edge_label=use_edge_label,
            use_special_tokens=use_special_tokens,
            no_labels=no_labels,
        )
        self.add_negative_train_samples = add_negative_train_samples
        self.neg_sampling_ratio = neg_samples_ratio
        self.data, self.node_texts, self.edge_texts = self.get_pyg_data()
        self.validate_data()
        self.set_graph_label()

    def get_pyg_data(self):
        """Split the edges and return ``(NumpyData, node_texts, edge_texts)``.

        The returned data carries the train/test edge ids (stored under
        ``*_edge_mask``), the positive (and, when sampled for training, the
        negative) label indices, the training ``edge_index``, the complete
        ``overall_edge_index`` and the node/edge counts. Each edge of the
        numbered graph is also tagged with a boolean ``masked`` attribute
        (``True`` for test edges).
        """
        splitter = RandomLinkSplit(
            num_val=0,
            num_test=self.test_ratio,
            add_negative_train_samples=self.add_negative_train_samples,
            neg_sampling_ratio=self.neg_sampling_ratio,
            split_labels=True
        )

        try:
            whole = GraphData(
                edge_index=torch.tensor(self.graph.edge_index),
                num_nodes=self.graph.number_of_nodes()
            )
            train_split, _, test_split = splitter(whole)
        except IndexError:
            # Dump the offending edge index before propagating the error.
            print(self.graph.edge_index)
            raise

        pyg = GraphData()
        pyg.train_edge_mask = edge_index_to_idx(self.graph, train_split.edge_index)
        pyg.test_edge_mask = edge_index_to_idx(self.graph, test_split.pos_edge_label_index)

        nxg = self.graph.numbered_graph
        parts = (('train', train_split), ('test', test_split))

        # Every positive edge must exist in the numbered graph.
        assert all(nxg.has_edge(*pair) for pair in train_split.edge_index.t().tolist())
        assert all(nxg.has_edge(*pair) for pair in test_split.pos_edge_label_index.t().tolist())

        for prefix, part in parts:
            setattr(pyg, f'{prefix}_pos_edge_label_index', part.pos_edge_label_index)
            setattr(pyg, f'{prefix}_pos_edge_label', part.pos_edge_label)

        if hasattr(train_split, 'neg_edge_label_index'):
            # No sampled negative may coincide with a real edge.
            for _, part in parts:
                assert not any(nxg.has_edge(*pair) for pair in part.neg_edge_label_index.t().tolist())
            for prefix, part in parts:
                setattr(pyg, f'{prefix}_neg_edge_label_index', part.neg_edge_label_index)
                setattr(pyg, f'{prefix}_neg_edge_label', part.neg_edge_label)

        # Train edges are tagged first, so the test tag wins on any overlap.
        for part, is_masked in ((train_split, False), (test_split, True)):
            tags = {tuple(pair): is_masked for pair in part.pos_edge_label_index.T.tolist()}
            nx.set_edge_attributes(nxg, tags, 'masked')

        train_edge_index = train_split.edge_index
        pyg.overall_edge_index = self.graph.edge_index
        pyg.edge_index = train_edge_index

        node_texts, edge_texts = self.get_node_edge_strings(
            edge_index=train_edge_index.numpy(),
        )

        pyg.num_nodes = self.graph.number_of_nodes()
        pyg.num_edges = self.graph.number_of_edges()
        return NumpyData(pyg), node_texts, edge_texts

    def get_link_prediction_texts(self, label, task_type, only_texts=False):
        """Collect edge texts (and optionally edge classes) per split.

        Args:
            label: Suffix of the ``edge_<label>`` data attribute holding the
                edge classes.
            task_type: ``LP_TASK_LINK_PRED`` also renders the negative edges;
                ``LP_TASK_EDGE_CLS`` adds the train/test edge classes.
            only_texts: Skip the edge classes even for edge classification.

        Returns:
            Dict with ``<split>_edges`` text lists and, when requested,
            ``train_edge_classes`` / ``test_edge_classes`` lists.
        """
        splits = {
            'train_pos': self.data.edge_index,
            'train_neg': None,
            'test_pos': self.data.test_pos_edge_label_index,
            'test_neg': None,
        }
        if task_type == LP_TASK_LINK_PRED:
            splits['train_neg'] = self.data.train_neg_edge_label_index
            splits['test_neg'] = self.data.test_neg_edge_label_index

        validate_edges(self)

        result = dict()
        for split_name, split_index in splits.items():
            if split_index is not None:
                rendered = self.get_graph_edge_strs(
                    edge_index=split_index,
                    neg_samples="neg" in split_name,
                )
                result[f'{split_name}_edges'] = list(rendered.values())

        if task_type == LP_TASK_EDGE_CLS and not only_texts:
            train_ids = self.data.train_edge_mask
            test_ids = self.data.test_edge_mask
            train_classes = getattr(self.data, f'edge_{label}')[train_ids]
            test_classes = getattr(self.data, f'edge_{label}')[test_ids]
            result['train_edge_classes'] = train_classes.tolist()
            result['test_edge_classes'] = test_classes.tolist()

        return result
|
425
|
+
|
426
|
+
|
427
|
+
|
428
|
+
|
429
|
+
class TorchNodeGraph(TorchGraph):
    """Graph wrapper that splits nodes into train/test sets at construction.

    The ``name`` property is inherited from ``TorchGraph``; the identical
    re-definition that used to live here was redundant.
    """

    def __init__(
        self,
        graph: Union[EcoreNxG, ArchiMateNxG],
        metadata: Union[EcoreMetaData, ArchimateMetaData],
        distance = 1,
        test_ratio=0.2,
        use_node_types=False,
        use_edge_types=False,
        use_edge_label=False,
        use_attributes=False,
        use_special_tokens=False,
        no_labels=False,
        node_cls_label=None,
        edge_cls_label='type',
        fp='test_graph.pkl'
    ):
        """Forward all options to ``TorchGraph``, build the split, validate, label."""
        super().__init__(
            graph,
            metadata=metadata,
            distance=distance,
            test_ratio=test_ratio,
            use_node_types=use_node_types,
            use_edge_types=use_edge_types,
            use_edge_label=use_edge_label,
            use_attributes=use_attributes,
            use_special_tokens=use_special_tokens,
            no_labels=no_labels,
            node_cls_label=node_cls_label,
            edge_cls_label=edge_cls_label,
            fp=fp
        )

        self.data, self.node_texts, self.edge_texts = self.get_pyg_data()
        self.validate_data()
        self.set_graph_label()

    def get_pyg_data(self):
        """Split the nodes and return ``(NumpyData, node_texts, edge_texts)``.

        Nodes of the numbered graph are split with a fixed seed
        (``random_state=42``) and tagged with a boolean ``masked`` attribute
        (``True`` for test nodes). The returned data holds the train/test node
        ids (stored under ``*_node_mask``), the full ``edge_index`` and the
        node count.
        """
        d = GraphData()
        nxg = self.graph.numbered_graph
        train_nodes, test_nodes = train_test_split(
            list(nxg.nodes),
            test_size=self.test_ratio,
            shuffle=True,
            random_state=42
        )

        nx.set_node_attributes(nxg, {node: False for node in train_nodes}, 'masked')
        nx.set_node_attributes(nxg, {node: True for node in test_nodes}, 'masked')

        setattr(d, 'train_node_mask', torch.tensor(train_nodes, dtype=torch.long))
        setattr(d, 'test_node_mask', torch.tensor(test_nodes, dtype=torch.long))

        assert all(nxg.has_node(n) for n in train_nodes)
        assert all(nxg.has_node(n) for n in test_nodes)

        # Node classification keeps every edge; only the nodes are split.
        edge_index = self.graph.edge_index
        setattr(d, 'edge_index', edge_index)

        node_texts, edge_texts = self.get_node_edge_strings(
            edge_index=edge_index,
        )

        setattr(d, 'num_nodes', self.graph.number_of_nodes())
        d = NumpyData(d)
        return d, node_texts, edge_texts
|
508
|
+
|
509
|
+
|
510
|
+
|
511
|
+
def validate_edges(graph: "Union[TorchEdgeGraph, TorchNodeGraph]"):
    """Assert that the edge splits stored on ``graph.data`` are consistent.

    Checks, in order:
      * train and test positive edges all exist in ``graph.graph.numbered_graph``;
      * train and test positive edges do not overlap;
      * negative edges (when present) never coincide with a real edge;
      * train and test negative edges (when both present) do not overlap.

    Args:
        graph: Object with a ``data`` attribute holding ``edge_index`` and
            ``test_pos_edge_label_index`` (plus optional
            ``train_neg_edge_label_index`` / ``test_neg_edge_label_index``)
            and a ``graph.numbered_graph`` exposing ``edges()``.

    Raises:
        AssertionError: If any of the checks fails.
        AttributeError: If ``graph.data`` has no ``test_pos_edge_label_index``.
    """

    def as_pairs(edge_index):
        # Columns of the two-row index become hashable (u, v) tuples.
        return {(a, b) for a, b in edge_index.T.tolist()}

    train_pos = as_pairs(graph.data.edge_index)
    test_pos = as_pairs(graph.data.test_pos_edge_label_index)
    train_neg_index = getattr(graph.data, 'train_neg_edge_label_index', None)
    test_neg_index = getattr(graph.data, 'test_neg_edge_label_index', None)

    # Build the reference edge set once instead of once per assertion.
    existing = set(graph.graph.numbered_graph.edges())

    assert train_pos.issubset(existing), "train positive edge missing from graph"
    assert test_pos.issubset(existing), "test positive edge missing from graph"
    assert train_pos.isdisjoint(test_pos), "train and test positive edges overlap"

    train_neg = None if train_neg_index is None else as_pairs(train_neg_index)
    test_neg = None if test_neg_index is None else as_pairs(test_neg_index)

    if train_neg is not None:
        assert existing.isdisjoint(train_neg), "train negative edge exists in graph"

    if test_neg is not None:
        assert existing.isdisjoint(test_neg), "test negative edge exists in graph"

    if train_neg is not None and test_neg is not None:
        assert train_neg.isdisjoint(test_neg), "train and test negative edges overlap"
|
530
|
+
|
531
|
+
|
532
|
+
|
533
|
+
|
534
|
+
class LinkPredictionCollater:
    """Collate a list of link-prediction graphs into one ``GraphData`` batch.

    Attributes holding node ids are shifted by the number of nodes already
    batched, attributes holding edge ids by the number of edges already
    batched. This mirrors the rules in ``GraphData.__inc__`` (``index`` keys
    grow by the node count, ``edge_mask`` keys by the edge count).
    """

    # Two-row tensors of node ids: shifted by the node offset, joined on dim 1.
    _NODE_INDEXED = (
        'edge_index',
        'overall_edge_index',
        'train_pos_edge_label_index',
        'train_neg_edge_label_index',
        'test_pos_edge_label_index',
        'test_neg_edge_label_index',
    )
    # 1-D tensors of edge ids: shifted by the edge offset, joined on dim 0.
    _EDGE_INDEXED = ('train_edge_mask', 'test_edge_mask')
    # Everything else is concatenated on dim 0 without any shift.
    _PLAIN = (
        'x',
        'edge_attr',
        'edge_classes',
        'train_pos_edge_label',
        'train_neg_edge_label',
        'test_pos_edge_label',
        'test_neg_edge_label',
    )

    def __init__(
        self,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None
    ):
        """Store ``follow_batch`` / ``exclude_keys``; neither is used when collating."""
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch: List[Data]):
        """Merge ``batch`` into a single ``GraphData``.

        Args:
            batch: Graphs that each provide every attribute named in
                ``_NODE_INDEXED``, ``_EDGE_INDEXED`` and ``_PLAIN`` plus
                ``y`` and ``num_nodes``.

        Returns:
            A ``GraphData`` with the concatenated attributes, ``y`` stacked
            via ``torch.tensor`` and ``num_nodes`` set to the total node count.
        """
        gathered = {
            key: []
            for key in (*self._NODE_INDEXED, *self._EDGE_INDEXED, *self._PLAIN)
        }
        y = []

        node_offset = 0
        edge_offset = 0

        for data in batch:
            for key in self._NODE_INDEXED:
                # Fix: overall_edge_index holds node ids, so it takes the node
                # offset (it used to be shifted by the edge offset).
                gathered[key].append(getattr(data, key) + node_offset)
            for key in self._EDGE_INDEXED:
                # Fix: the masks hold edge ids into the concatenated edge
                # tensors, so they take the edge offset (they were not
                # shifted at all before).
                gathered[key].append(getattr(data, key) + edge_offset)
            for key in self._PLAIN:
                gathered[key].append(getattr(data, key))
            y.append(data.y)

            node_offset += data.num_nodes
            edge_offset += data.edge_attr.size(0)

        return GraphData(
            x=torch.cat(gathered['x'], dim=0),
            edge_index=torch.cat(gathered['edge_index'], dim=1),
            edge_attr=torch.cat(gathered['edge_attr'], dim=0),
            y=torch.tensor(y),
            overall_edge_index=torch.cat(gathered['overall_edge_index'], dim=1),
            edge_classes=torch.cat(gathered['edge_classes']),
            train_edge_mask=torch.cat(gathered['train_edge_mask']),
            test_edge_mask=torch.cat(gathered['test_edge_mask']),
            train_pos_edge_label_index=torch.cat(gathered['train_pos_edge_label_index'], dim=1),
            train_pos_edge_label=torch.cat(gathered['train_pos_edge_label']),
            train_neg_edge_label_index=torch.cat(gathered['train_neg_edge_label_index'], dim=1),
            train_neg_edge_label=torch.cat(gathered['train_neg_edge_label']),
            test_pos_edge_label_index=torch.cat(gathered['test_pos_edge_label_index'], dim=1),
            test_pos_edge_label=torch.cat(gathered['test_pos_edge_label']),
            test_neg_edge_label_index=torch.cat(gathered['test_neg_edge_label_index'], dim=1),
            test_neg_edge_label=torch.cat(gathered['test_neg_edge_label']),
            num_nodes=node_offset
        )
|
609
|
+
|
610
|
+
|
611
|
+
class LinkPredictionDataLoader(torch.utils.data.DataLoader):
    """``DataLoader`` whose default collate function is ``LinkPredictionCollater``."""

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[Data]],
        batch_size: int = 1,
        shuffle: bool = False,
        collate_fn=None,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        """Create the loader.

        Args:
            dataset: Graphs to iterate over.
            batch_size: Number of graphs per batch.
            shuffle: Whether to reshuffle the data every epoch.
            collate_fn: Custom collate callable; when ``None`` a
                ``LinkPredictionCollater(follow_batch, exclude_keys)`` is used.
            follow_batch: Passed to the default collater only.
            exclude_keys: Passed to the default collater only.
            **kwargs: Forwarded to ``torch.utils.data.DataLoader``.
        """
        collater = collate_fn
        if collater is None:
            collater = LinkPredictionCollater(follow_batch, exclude_keys)

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=collater,
            **kwargs,
        )
|