glam4cm 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. glam4cm/__init__.py +9 -0
  2. glam4cm/data_loading/__init__.py +0 -0
  3. glam4cm/data_loading/data.py +631 -0
  4. glam4cm/data_loading/encoding.py +76 -0
  5. glam4cm/data_loading/graph_dataset.py +940 -0
  6. glam4cm/data_loading/metadata.py +84 -0
  7. glam4cm/data_loading/models_dataset.py +361 -0
  8. glam4cm/data_loading/utils.py +20 -0
  9. glam4cm/downstream_tasks/__init__.py +0 -0
  10. glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
  11. glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
  12. glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
  13. glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
  14. glam4cm/downstream_tasks/bert_node_classification.py +164 -0
  15. glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
  16. glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
  17. glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
  18. glam4cm/downstream_tasks/common_args.py +160 -0
  19. glam4cm/downstream_tasks/create_dataset.py +51 -0
  20. glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
  21. glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
  22. glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
  23. glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
  24. glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
  25. glam4cm/downstream_tasks/utils.py +35 -0
  26. glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
  27. glam4cm/embeddings/__init__.py +0 -0
  28. glam4cm/embeddings/bert.py +72 -0
  29. glam4cm/embeddings/common.py +43 -0
  30. glam4cm/embeddings/fasttext.py +0 -0
  31. glam4cm/embeddings/tfidf.py +25 -0
  32. glam4cm/embeddings/w2v.py +41 -0
  33. glam4cm/encoding/__init__.py +0 -0
  34. glam4cm/encoding/common.py +0 -0
  35. glam4cm/encoding/encoders.py +100 -0
  36. glam4cm/graph2str/__init__.py +0 -0
  37. glam4cm/graph2str/common.py +34 -0
  38. glam4cm/graph2str/constants.py +15 -0
  39. glam4cm/graph2str/ontouml.py +141 -0
  40. glam4cm/graph2str/uml.py +0 -0
  41. glam4cm/lang2graph/__init__.py +0 -0
  42. glam4cm/lang2graph/archimate.py +31 -0
  43. glam4cm/lang2graph/bpmn.py +0 -0
  44. glam4cm/lang2graph/common.py +416 -0
  45. glam4cm/lang2graph/ecore.py +221 -0
  46. glam4cm/lang2graph/ontouml.py +169 -0
  47. glam4cm/lang2graph/utils.py +80 -0
  48. glam4cm/models/cmgpt.py +352 -0
  49. glam4cm/models/gnn_layers.py +273 -0
  50. glam4cm/models/hf.py +10 -0
  51. glam4cm/run.py +99 -0
  52. glam4cm/run_configs.py +126 -0
  53. glam4cm/settings.py +54 -0
  54. glam4cm/tokenization/__init__.py +0 -0
  55. glam4cm/tokenization/special_tokens.py +4 -0
  56. glam4cm/tokenization/utils.py +37 -0
  57. glam4cm/trainers/__init__.py +0 -0
  58. glam4cm/trainers/bert_classifier.py +105 -0
  59. glam4cm/trainers/cm_gpt_trainer.py +153 -0
  60. glam4cm/trainers/gnn_edge_classifier.py +126 -0
  61. glam4cm/trainers/gnn_graph_classifier.py +123 -0
  62. glam4cm/trainers/gnn_link_predictor.py +144 -0
  63. glam4cm/trainers/gnn_node_classifier.py +135 -0
  64. glam4cm/trainers/gnn_trainer.py +129 -0
  65. glam4cm/trainers/metrics.py +55 -0
  66. glam4cm/utils.py +194 -0
  67. glam4cm-0.1.0.dist-info/LICENSE +21 -0
  68. glam4cm-0.1.0.dist-info/METADATA +86 -0
  69. glam4cm-0.1.0.dist-info/RECORD +72 -0
  70. glam4cm-0.1.0.dist-info/WHEEL +5 -0
  71. glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
  72. glam4cm-0.1.0.dist-info/top_level.txt +1 -0
glam4cm/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ import warnings
2
+
3
+ warnings.filterwarnings(
4
+ "ignore",
5
+ message="Pydantic serializer warnings:",
6
+ category=UserWarning,
7
+ module="pydantic.main",
8
+ )
9
+ __version__ = "0.1.0"
glam4cm/data_loading/data.py ADDED
@@ -0,0 +1,631 @@
1
+ import pickle
2
+ import numpy as np
3
+ from sklearn.model_selection import train_test_split
4
+ import torch
5
+ import networkx as nx
6
+ import os
7
+ from glam4cm.data_loading.metadata import (
8
+ ArchimateMetaData,
9
+ EcoreMetaData
10
+ )
11
+ from glam4cm.embeddings.common import Embedder
12
+ from glam4cm.lang2graph.archimate import ArchiMateNxG
13
+ from glam4cm.lang2graph.ecore import EcoreNxG
14
+ from glam4cm.lang2graph.common import (
15
+ create_graph_from_edge_index,
16
+ get_node_texts,
17
+ get_edge_texts
18
+ )
19
+
20
+ from scipy.sparse import csr_matrix
21
+
22
+ from glam4cm.settings import LP_TASK_EDGE_CLS, LP_TASK_LINK_PRED
23
+ from glam4cm.tokenization.special_tokens import *
24
+ from torch_geometric.transforms import RandomLinkSplit
25
+ import torch
26
+ from torch_geometric.data import Data, Dataset
27
+ from typing import List, Optional, Sequence, Union
28
+ from glam4cm.tokenization.utils import doc_tokenizer
29
+
30
+
31
+
32
def edge_index_to_idx(graph, edge_index):
    """Map each (u, v) column of ``edge_index`` to its scalar edge id.

    Looks each pair up in ``graph.edge_to_idx`` and returns the ids as a
    1-D ``torch.long`` tensor, in column order.
    """
    pairs = edge_index.t().tolist()
    ids = [graph.edge_to_idx[(src, dst)] for src, dst in pairs]
    return torch.tensor(ids, dtype=torch.long)
40
+
41
+
42
class GraphData(Data):
    """PyG ``Data`` variant with custom mini-batching rules.

    When samples are collated, keys containing ``'index'`` or
    ``'node_mask'`` are offset by the node count and ``'edge_mask'`` keys
    by the edge count; ``'index'`` tensors concatenate along dim 1 (they
    are shaped ``(2, E)``), everything else along dim 0.
    """

    def __inc__(self, key, value, *args, **kwargs):
        # Per-graph offset applied to this attribute during batching.
        if 'index' in key or 'node_mask' in key:
            return self.num_nodes
        if 'edge_mask' in key:
            return self.num_edges
        return 0

    def __cat_dim__(self, key, value, *args, **kwargs):
        # Index tensors are column-stacked; all others row-stacked.
        return 1 if 'index' in key else 0
56
+
57
+
58
class NumpyData:
    """Attribute bag that stores graph tensors as numpy arrays.

    Torch tensors passed in are converted to numpy on assignment;
    ``to_graph_data`` converts everything back into a ``GraphData``.
    """

    def __init__(self, data: Optional[dict] = None):
        """Populate attributes from ``data`` (a dict or any object with
        ``.items()``); empty when ``data`` is None.
        """
        # Fix: the original used a shared mutable default (``data={}``).
        self.set_data(data if data is not None else {})

    def __getitem__(self, key):
        # Dict-style access mirrors attribute access.
        return getattr(self, key)

    def set_data(self, data: dict):
        """Assign every (key, value) pair as an attribute, converting
        torch tensors to numpy arrays."""
        for k, v in data.items():
            if isinstance(v, torch.Tensor):
                v = v.numpy()
            setattr(self, k, v)

    def __repr__(self):
        # Arrays are summarized by shape; scalars/objects shown verbatim.
        response = "NumpyData(" + ", ".join([
            f"{k}={list(v.shape)}" if isinstance(v, np.ndarray)
            else f"{k}={v}"
            for k, v in self.__dict__.items()
        ]) + ")"
        return response

    def to_graph_data(self):
        """Convert stored attributes back into a ``GraphData``.

        numpy arrays / scipy csr matrices / ints become torch tensors;
        float64 tensors are downcast to float32. Other values (strings,
        lists, ...) are attached unchanged.
        """
        data = GraphData()
        for k, v in self.__dict__.items():
            if isinstance(v, np.ndarray):
                v = torch.from_numpy(v)
            elif isinstance(v, csr_matrix):
                v = torch.from_numpy(v.toarray())
            elif isinstance(v, int):
                v = torch.tensor(v, dtype=torch.long)

            # Fix: guard the dtype check — the original raised
            # AttributeError for non-tensor attributes (e.g. strings).
            if isinstance(v, torch.Tensor) and v.dtype == torch.float64:
                v = v.float()

            setattr(data, k, v)
        return data
94
+
95
+
96
class TorchGraph:
    """Base wrapper pairing a parsed model graph (Ecore or ArchiMate)
    with the text and embedding artifacts needed for PyG datasets.

    Subclasses (``TorchEdgeGraph``/``TorchNodeGraph``) populate
    ``self.data`` via ``get_pyg_data``; this base class stores settings,
    handles pickle (de)serialization, and renders node/edge text strings.
    """

    def __init__(
        self,
        graph: Union[EcoreNxG, ArchiMateNxG],
        metadata: Union[EcoreMetaData, ArchimateMetaData],
        distance = 0,
        test_ratio=0.2,
        use_edge_types=False,
        use_node_types=False,
        use_attributes=False,
        use_edge_label=False,
        use_special_tokens=False,
        no_labels=False,
        node_cls_label=None,
        edge_cls_label='type',
        fp='test_graph.pkl'
    ):
        # Path this object is pickled to/loaded from.
        self.fp = fp
        self.graph = graph
        self.metadata = metadata

        # Ecore graphs carry an ``xmi`` payload; ArchiMate a ``json_obj``.
        self.raw_data = graph.xmi if hasattr(graph, 'xmi') else graph.json_obj
        self.use_edge_types = use_edge_types
        self.use_node_types = use_node_types
        self.use_attributes = use_attributes
        self.use_edge_label = use_edge_label
        self.use_special_tokens = use_special_tokens
        self.no_labels = no_labels

        # Attribute names used as node/edge classification targets.
        self.node_cls_label = node_cls_label
        self.edge_cls_label = edge_cls_label

        # Neighborhood hop distance used when rendering texts.
        self.distance = distance
        self.test_ratio = test_ratio
        self.data = NumpyData()


    # NOTE(review): missing ``self`` / @staticmethod decorator — usable
    # only as ``TorchGraph.load(path)`` on the class, not on an instance.
    def load(pkl_path):
        """Unpickle and return a previously saved TorchGraph."""
        with open(pkl_path, 'rb') as f:
            return pickle.load(f)

    def save(self):
        """Pickle this object to ``self.fp``.

        NOTE(review): ``os.makedirs('')`` raises when ``fp`` has no
        directory component (e.g. the default 'test_graph.pkl').
        """
        os.makedirs(os.path.dirname(self.fp), exist_ok=True)
        with open(self.fp, 'wb') as f:
            pickle.dump(self, f)


    def get_node_edge_strings(self, edge_index):
        """Render node texts (restricted to ``edge_index``) and edge
        texts (over the full graph); returns ``(node_texts, edge_texts)``."""
        node_texts = self.get_graph_node_strs(
            edge_index=edge_index,
            distance=self.distance
        )

        edge_texts = self.get_graph_edge_strs()
        return node_texts, edge_texts


    def embed(
        self,
        embedder: Union[Embedder, None],
        reload=False,
        randomize_ne=False,
        randomize_ee=False,
        random_embed_dim=128
    ):
        """Attach node (``data.x``) and edge (``data.edge_attr``) embeddings.

        Random normal embeddings of width ``random_embed_dim`` are used
        when randomization is requested or no embedder is given; otherwise
        the embedder encodes the rendered node/edge texts. The result is
        pickled to ``self.fp``.
        """

        def generate_embeddings():
            # Node embeddings: random, or embedder over node texts.
            if randomize_ne or embedder is None:
                print("Randomizing node embeddings")
                self.data.x = np.random.randn(self.graph.number_of_nodes(), random_embed_dim)
            else:
                self.data.x = embedder.embed(list(self.node_texts.values()))

            # Edge embeddings: random, or embedder over edge texts.
            if randomize_ee or embedder is None:
                print("Randomizing edge embeddings")
                self.data.edge_attr = np.random.randn(self.graph.number_of_edges(), random_embed_dim)
            else:
                self.data.edge_attr = embedder.embed(list(self.edge_texts.values()))

        if os.path.exists(f"{self.fp}") and not reload:
            # A cached pickle exists: regenerate only if it lacks embeddings.
            # NOTE(review): the unpickled ``obj`` is inspected then
            # discarded; when it already has x/edge_attr, this instance's
            # ``self.data`` is left without embeddings — confirm intended.
            with open(f"{self.fp}", 'rb') as f:
                obj: Union[TorchEdgeGraph, TorchNodeGraph] = pickle.load(f)
                if not hasattr(obj.data, 'x') or not hasattr(obj.data, 'edge_attr'):
                    generate_embeddings()
                    self.save()
        else:
            if embedder is not None:
                generate_embeddings()
            else:
                self.data.x = np.random.randn(self.graph.number_of_nodes(), random_embed_dim)
                self.data.edge_attr = np.random.randn(self.graph.number_of_edges(), random_embed_dim)

            self.save()


    def get_graph_node_strs(
        self,
        edge_index: np.ndarray,
        distance = None,
        preprocessor: callable = doc_tokenizer
    ):
        """Render a text string per node of the subgraph induced by
        ``edge_index``, honoring the configured text options."""
        if distance is None:
            distance = self.distance

        subgraph = create_graph_from_edge_index(self.graph, edge_index)
        return get_node_texts(
            subgraph,
            distance,
            metadata=self.metadata,
            use_node_attributes=self.use_attributes,
            use_node_types=self.use_node_types,
            use_edge_types=self.use_edge_types,
            use_edge_label=self.use_edge_label,
            node_cls_label=self.node_cls_label,
            edge_cls_label=self.edge_cls_label,
            use_special_tokens=self.use_special_tokens,
            no_labels=self.no_labels,
            preprocessor=preprocessor
        )


    def get_graph_edge_strs(
        self,
        edge_index: np.ndarray = None,
        neg_samples=False,
        preprocessor: callable = doc_tokenizer
    ):
        """Render a text string per (u, v) edge column of ``edge_index``
        (defaults to all graph edges); returns ``{(u, v): text}``."""
        if edge_index is None:
            edge_index = self.graph.edge_index

        edge_strs = dict()
        for u, v in edge_index.T:
            edge_str = get_edge_texts(
                self.graph.numbered_graph,
                (u, v),
                d=self.distance,
                metadata=self.metadata,
                use_node_attributes=self.use_attributes,
                use_node_types=self.use_node_types,
                use_edge_types=self.use_edge_types,
                use_edge_label=self.use_edge_label,
                use_special_tokens=self.use_special_tokens,
                no_labels=self.no_labels,
                preprocessor=preprocessor,
                neg_samples=neg_samples
            )

            edge_strs[(u, v)] = edge_str

        return edge_strs


    def validate_data(self):
        """Sanity check: stored node count matches the nx graph."""
        assert self.data.num_nodes == self.graph.number_of_nodes()

    def set_graph_label(self):
        """Synthesize a graph-level label from all node texts when the
        metadata declares a label attribute the graph does not yet have."""
        if self.metadata.graph_label is not None and not hasattr(self.graph, self.metadata.graph_label): #Graph has a label
            text = doc_tokenizer("\n".join(list(self.node_texts.values())))
            setattr(self.graph, self.metadata.graph_label, text)


    @property
    def name(self):
        """Graph id with '/' flattened to '_' and the file extension dropped."""
        return '.'.join(self.graph.graph_id.replace('/', '_').split('.')[:-1])
263
+
264
+
265
+
266
class TorchEdgeGraph(TorchGraph):
    """TorchGraph specialization for edge-level tasks (edge classification
    and link prediction).

    Edges are split into train/test with ``RandomLinkSplit``; only train
    edges remain in ``data.edge_index`` for message passing, and held-out
    edges are flagged ``masked`` on the numbered nx graph.
    """

    def __init__(
        self,
        graph: Union[EcoreNxG, ArchiMateNxG],
        metadata: Union[EcoreMetaData, ArchimateMetaData],
        distance = 1,
        test_ratio=0.2,
        add_negative_train_samples=False,
        neg_samples_ratio=1,
        use_edge_types=False,
        use_node_types=False,
        use_edge_label=False,
        use_attributes=False,
        use_special_tokens=False,
        node_cls_label=None,
        edge_cls_label='type',
        no_labels=False,
        fp: str = 'test_graph.pkl'
    ):

        super().__init__(
            graph=graph,
            metadata=metadata,
            distance=distance,
            test_ratio=test_ratio,
            use_node_types=use_node_types,
            use_edge_types=use_edge_types,
            use_attributes=use_attributes,
            use_edge_label=use_edge_label,
            use_special_tokens=use_special_tokens,
            no_labels=no_labels,
            node_cls_label=node_cls_label,
            edge_cls_label=edge_cls_label,
            fp=fp
        )
        self.add_negative_train_samples = add_negative_train_samples
        self.neg_sampling_ratio = neg_samples_ratio
        # Build the edge split and text artifacts eagerly at construction.
        self.data, self.node_texts, self.edge_texts = self.get_pyg_data()
        self.validate_data()
        self.set_graph_label()


    def get_pyg_data(self):
        """Split edges into train/test, mark held-out edges on the nx
        graph, and return ``(NumpyData, node_texts, edge_texts)``."""

        d = GraphData()

        # No validation split; negatives optional at the configured ratio.
        transform = RandomLinkSplit(
            num_val=0,
            num_test=self.test_ratio,
            add_negative_train_samples=self.add_negative_train_samples,
            neg_sampling_ratio=self.neg_sampling_ratio,
            split_labels=True
        )

        try:
            train_data, _, test_data = transform(GraphData(
                edge_index=torch.tensor(self.graph.edge_index),
                num_nodes=self.graph.number_of_nodes()
            ))
        except IndexError as e:
            # Surface the offending edge_index before re-raising.
            print(self.graph.edge_index)
            raise e

        # Scalar edge ids (via graph.edge_to_idx) for each split.
        train_idx = edge_index_to_idx(self.graph, train_data.edge_index)
        test_idx = edge_index_to_idx(self.graph, test_data.pos_edge_label_index)

        setattr(d, 'train_edge_mask', train_idx)
        setattr(d, 'test_edge_mask', test_idx)

        # Positive split edges must actually exist in the graph.
        assert all([self.graph.numbered_graph.has_edge(*edge) for edge in train_data.edge_index.t().tolist()])
        assert all([self.graph.numbered_graph.has_edge(*edge) for edge in test_data.pos_edge_label_index.t().tolist()])

        setattr(d, 'train_pos_edge_label_index', train_data.pos_edge_label_index)
        setattr(d, 'train_pos_edge_label', train_data.pos_edge_label)
        setattr(d, 'test_pos_edge_label_index', test_data.pos_edge_label_index)
        setattr(d, 'test_pos_edge_label', test_data.pos_edge_label)

        # Negative samples (if generated) must NOT exist in the graph.
        if hasattr(train_data, 'neg_edge_label_index'):
            assert not any([self.graph.numbered_graph.has_edge(*edge) for edge in train_data.neg_edge_label_index.t().tolist()])
            assert not any([self.graph.numbered_graph.has_edge(*edge) for edge in test_data.neg_edge_label_index.t().tolist()])
            setattr(d, 'train_neg_edge_label_index', train_data.neg_edge_label_index)
            setattr(d, 'train_neg_edge_label', train_data.neg_edge_label)
            setattr(d, 'test_neg_edge_label_index', test_data.neg_edge_label_index)
            setattr(d, 'test_neg_edge_label', test_data.neg_edge_label)

        # Flag test edges as 'masked' on the numbered nx graph so the text
        # renderer can treat held-out edges specially.
        nx.set_edge_attributes(
            self.graph.numbered_graph,
            {tuple(edge): False for edge in train_data.pos_edge_label_index.T.tolist()},
            'masked'
        )
        nx.set_edge_attributes(
            self.graph.numbered_graph,
            {tuple(edge): True for edge in test_data.pos_edge_label_index.T.tolist()},
            'masked'
        )

        # Message-passing edges = train split only; full graph kept aside.
        edge_index = train_data.edge_index
        setattr(d, 'overall_edge_index', self.graph.edge_index)
        setattr(d, 'edge_index', edge_index)

        node_texts, edge_texts = self.get_node_edge_strings(
            edge_index=edge_index.numpy(),
        )

        setattr(d, 'num_nodes', self.graph.number_of_nodes())
        setattr(d, 'num_edges', self.graph.number_of_edges())
        d = NumpyData(d)
        return d, node_texts, edge_texts


    def get_link_prediction_texts(self, label, task_type, only_texts=False):
        """Render per-split edge text lists for training.

        Returns a dict with ``'<split>_edges'`` text lists for each
        available split and, for edge classification (unless
        ``only_texts``), train/test class arrays from ``data.edge_<label>``.
        """
        data = dict()
        train_pos_edge_index = self.data.edge_index
        test_pos_edge_index = self.data.test_pos_edge_label_index

        # Negative edges are only needed for the link-prediction task.
        if task_type == LP_TASK_LINK_PRED:
            train_neg_edge_index = self.data.train_neg_edge_label_index
            test_neg_edge_index = self.data.test_neg_edge_label_index
        else:
            train_neg_edge_index = None
            test_neg_edge_index = None

        validate_edges(self)

        edge_indices = {
            'train_pos': train_pos_edge_index,
            'train_neg': train_neg_edge_index,
            'test_pos': test_pos_edge_index,
            'test_neg': test_neg_edge_index
        }

        for edge_index_label, edge_index in edge_indices.items():
            if edge_index is None:
                continue
            edge_strs = self.get_graph_edge_strs(
                edge_index=edge_index,
                neg_samples="neg" in edge_index_label,
            )

            edge_strs = list(edge_strs.values())
            data[f'{edge_index_label}_edges'] = edge_strs


        if task_type == LP_TASK_EDGE_CLS and not only_texts:
            train_mask = self.data.train_edge_mask
            test_mask = self.data.test_edge_mask
            train_classes, test_classes = getattr(self.data, f'edge_{label}')[train_mask], getattr(self.data, f'edge_{label}')[test_mask]
            data['train_edge_classes'] = train_classes.tolist()
            data['test_edge_classes'] = test_classes.tolist()

        return data
425
+
426
+
427
+
428
+
429
class TorchNodeGraph(TorchGraph):
    """TorchGraph specialization for node-level tasks.

    Nodes are split into train/test (deterministic seed); all edges stay
    visible for message passing, unlike the edge variant. Held-out nodes
    are flagged ``masked`` on the numbered nx graph.
    """

    def __init__(
        self,
        graph: Union[EcoreNxG, ArchiMateNxG],
        metadata: dict,
        distance = 1,
        test_ratio=0.2,
        use_node_types=False,
        use_edge_types=False,
        use_edge_label=False,
        use_attributes=False,
        use_special_tokens=False,
        no_labels=False,
        node_cls_label=None,
        edge_cls_label='type',
        fp='test_graph.pkl'
    ):

        super().__init__(
            graph,
            metadata=metadata,
            distance=distance,
            test_ratio=test_ratio,
            use_node_types=use_node_types,
            use_edge_types=use_edge_types,
            use_edge_label=use_edge_label,
            use_attributes=use_attributes,
            use_special_tokens=use_special_tokens,
            no_labels=no_labels,
            node_cls_label=node_cls_label,
            edge_cls_label=edge_cls_label,
            fp=fp
        )

        # Build the node split and text artifacts eagerly at construction.
        self.data, self.node_texts, self.edge_texts = self.get_pyg_data()
        self.validate_data()
        self.set_graph_label()


    def get_pyg_data(self):
        """Hold out ``test_ratio`` of nodes, mark them on the nx graph,
        and return ``(NumpyData, node_texts, edge_texts)``."""
        d = GraphData()
        train_nodes, test_nodes = train_test_split(
            list(self.graph.numbered_graph.nodes),
            test_size=self.test_ratio,
            shuffle=True,
            random_state=42
        )

        # Flag held-out nodes as 'masked' for the text renderer.
        nx.set_node_attributes(self.graph.numbered_graph, {node: False for node in train_nodes}, 'masked')
        nx.set_node_attributes(self.graph.numbered_graph, {node: True for node in test_nodes}, 'masked')

        train_idx = torch.tensor(train_nodes, dtype=torch.long)
        test_idx = torch.tensor(test_nodes, dtype=torch.long)

        setattr(d, 'train_node_mask', train_idx)
        setattr(d, 'test_node_mask', test_idx)

        assert all([self.graph.numbered_graph.has_node(n) for n in train_nodes])
        assert all([self.graph.numbered_graph.has_node(n) for n in test_nodes])

        # Node tasks keep the full edge set for message passing.
        edge_index = self.graph.edge_index
        setattr(d, 'edge_index', edge_index)

        node_texts, edge_texts = self.get_node_edge_strings(
            edge_index=edge_index,
        )

        setattr(d, 'num_nodes', self.graph.number_of_nodes())
        d = NumpyData(d)
        return d, node_texts, edge_texts


    @property
    def name(self):
        # NOTE(review): identical to the inherited TorchGraph.name —
        # this override is redundant.
        return '.'.join(self.graph.graph_id.replace('/', '_').split('.')[:-1])
508
+
509
+
510
+
511
def validate_edges(graph: Union[TorchEdgeGraph, TorchNodeGraph]):
    """Sanity-check the train/test edge splits stored on ``graph.data``.

    Asserts that positive splits are real, disjoint subsets of the
    graph's edges, and that negative splits (when present) contain no
    real edges and do not overlap each other.
    """

    def as_pairs(index_array):
        # Columns of a (2, E) index array as a set of (u, v) tuples.
        return set((a, b) for a, b in index_array.T.tolist())

    data = graph.data
    real_edges = set(graph.graph.numbered_graph.edges())

    train_pos = as_pairs(data.edge_index)
    test_pos = as_pairs(data.test_pos_edge_label_index)
    train_neg = as_pairs(data.train_neg_edge_label_index) if hasattr(data, 'train_neg_edge_label_index') else None
    test_neg = as_pairs(data.test_neg_edge_label_index) if hasattr(data, 'test_neg_edge_label_index') else None

    # Positive splits must be disjoint subsets of the true edge set.
    assert train_pos.issubset(real_edges)
    assert test_pos.issubset(real_edges)
    assert len(train_pos.intersection(test_pos)) == 0

    # Negative samples must never coincide with real edges.
    if train_neg is not None:
        assert len(real_edges.intersection(train_neg)) == 0

    if test_neg is not None:
        assert len(real_edges.intersection(test_neg)) == 0

    # Negative splits must not overlap each other either.
    if train_neg is not None and test_neg is not None:
        assert len(train_neg.intersection(test_neg)) == 0
530
+
531
+
532
+
533
+
534
class LinkPredictionCollater:
    """Collate function that merges link-prediction samples into one big
    disconnected graph, offsetting node indices per sample.

    Expects every sample to carry the full set of attributes referenced
    below (including negative splits) — samples without negatives will
    raise AttributeError.
    """

    def __init__(
        self,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None
    ):
        # Stored for API parity with PyG's Collater; not consulted below.
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch: List[Data]):
        # Initialize lists to collect batched properties
        x = []
        edge_index = []
        edge_attr = []
        y = []
        overall_edge_index = []
        edge_classes = []
        train_edge_mask = []
        test_edge_mask = []
        train_pos_edge_label_index = []
        train_pos_edge_label = []
        train_neg_edge_label_index = []
        train_neg_edge_label = []
        test_pos_edge_label_index = []
        test_pos_edge_label = []
        test_neg_edge_label_index = []
        test_neg_edge_label = []

        # Offsets for edge indices
        node_offset = 0
        edge_offset = 0

        for data in batch:
            x.append(data.x)
            edge_index.append(data.edge_index + node_offset)
            edge_attr.append(data.edge_attr)
            y.append(data.y)
            # NOTE(review): overall_edge_index holds node pairs yet is
            # shifted by edge_offset rather than node_offset — confirm.
            overall_edge_index.append(data.overall_edge_index + edge_offset)
            edge_classes.append(data.edge_classes)

            # NOTE(review): these masks are per-graph edge ids but are
            # concatenated without any offset — confirm intended.
            train_edge_mask.append(data.train_edge_mask)
            test_edge_mask.append(data.test_edge_mask)

            train_pos_edge_label_index.append(data.train_pos_edge_label_index + node_offset)
            train_pos_edge_label.append(data.train_pos_edge_label)
            train_neg_edge_label_index.append(data.train_neg_edge_label_index + node_offset)
            train_neg_edge_label.append(data.train_neg_edge_label)

            test_pos_edge_label_index.append(data.test_pos_edge_label_index + node_offset)
            test_pos_edge_label.append(data.test_pos_edge_label)
            test_neg_edge_label_index.append(data.test_neg_edge_label_index + node_offset)
            test_neg_edge_label.append(data.test_neg_edge_label)

            # Advance offsets by this sample's node and edge counts.
            node_offset += data.num_nodes
            edge_offset += data.edge_attr.size(0)

        return GraphData(
            x=torch.cat(x, dim=0),
            edge_index=torch.cat(edge_index, dim=1),
            edge_attr=torch.cat(edge_attr, dim=0),
            y=torch.tensor(y),
            overall_edge_index=torch.cat(overall_edge_index, dim=1),
            edge_classes=torch.cat(edge_classes),
            train_edge_mask=torch.cat(train_edge_mask),
            test_edge_mask=torch.cat(test_edge_mask),
            train_pos_edge_label_index=torch.cat(train_pos_edge_label_index, dim=1),
            train_pos_edge_label=torch.cat(train_pos_edge_label),
            train_neg_edge_label_index=torch.cat(train_neg_edge_label_index, dim=1),
            train_neg_edge_label=torch.cat(train_neg_edge_label),
            test_pos_edge_label_index=torch.cat(test_pos_edge_label_index, dim=1),
            test_pos_edge_label=torch.cat(test_pos_edge_label),
            test_neg_edge_label_index=torch.cat(test_neg_edge_label_index, dim=1),
            test_neg_edge_label=torch.cat(test_neg_edge_label),
            num_nodes=node_offset
        )
609
+
610
+
611
class LinkPredictionDataLoader(torch.utils.data.DataLoader):
    """DataLoader that defaults to ``LinkPredictionCollater``.

    Behaves exactly like ``torch.utils.data.DataLoader`` except that,
    when no ``collate_fn`` is supplied, batches are merged with the
    link-prediction-aware collater.
    """

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[Data]],
        batch_size: int = 1,
        shuffle: bool = False,
        collate_fn=None,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # Fall back to the link-prediction collater when none is given.
        if collate_fn is None:
            collate_fn = LinkPredictionCollater(follow_batch, exclude_keys)

        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            **kwargs,
        )