pyg-nightly 2.7.0.dev20241124__py3-none-any.whl → 2.7.0.dev20241126__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,677 @@
1
+ import os
2
+ import pickle as pkl
3
+ import shutil
4
+ from dataclasses import dataclass
5
+ from itertools import chain
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ Dict,
10
+ Hashable,
11
+ Iterable,
12
+ Iterator,
13
+ List,
14
+ Optional,
15
+ Sequence,
16
+ Set,
17
+ Tuple,
18
+ Union,
19
+ )
20
+
21
+ import torch
22
+ from torch import Tensor
23
+ from tqdm import tqdm
24
+
25
+ from torch_geometric.data import Data
26
+ from torch_geometric.typing import WITH_PT24
27
+
28
# One knowledge-graph edge, written as (head, relation, tail).
TripletLike = Tuple[(Hashable,) * 3]
# Any iterable of such triplets, e.g. a list of (head, relation, tail) tuples.
KnowledgeGraphLike = Iterable[TripletLike]
31
+
32
+
33
def ordered_set(values: Iterable[Hashable]) -> List[Hashable]:
    """Return the distinct items of ``values``, keeping first-seen order.

    Items that compare equal are collapsed to the first occurrence.
    """
    first_seen: Dict[Hashable, None] = {}
    for item in values:
        first_seen.setdefault(item, None)
    return [*first_seen]
35
+
36
+
37
# TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?

# Key in node_attr holding the unique node ids.
NODE_PID = "pid"
# Keys that every node_attr mapping is required to contain.
NODE_KEYS = {NODE_PID}

# Keys in edge_attr: unique edge id (the triplet), its head / relation /
# tail parts, and the (head position, tail position) node-index pair.
EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX = (
    "e_pid",
    "h",
    "r",
    "t",
    "edge_idx",
)
# Keys that every edge_attr mapping is required to contain.
EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}
50
+
51
FeatureValueType = Union[Sequence[Any], Tensor]


@dataclass
class MappedFeature:
    """A feature whose values are keyed by another feature's unique values.

    ``values[i]`` belongs to the i-th unique value (in first-seen order) of
    the feature called ``name``.
    """
    # Name of the feature this one is mapped from.
    name: str
    # One entry per unique value of the source feature.
    values: FeatureValueType

    def __eq__(self, value: object) -> bool:
        """Compare name and values; tensors are compared with torch.equal.

        Returns ``NotImplemented`` for non-``MappedFeature`` operands so
        Python falls back to its default comparison instead of raising.
        """
        if not isinstance(value, MappedFeature):
            return NotImplemented
        if self.name != value.name:
            return False
        self_is_tensor = isinstance(self.values, Tensor)
        other_is_tensor = isinstance(value.values, Tensor)
        if self_is_tensor or other_is_tensor:
            # A tensor never equals a non-tensor sequence here; torch.equal
            # requires both operands to be tensors.
            return (self_is_tensor and other_is_tensor
                    and torch.equal(self.values, value.values))
        return self.values == value.values
66
+
67
+
68
# torch >= 2.4: allow torch.load(weights_only=True) to rebuild MappedFeature
# objects, which LargeGraphIndexer.save can write into node/edge attr files.
_safe_globals = [MappedFeature] if WITH_PT24 else []
if _safe_globals:
    torch.serialization.add_safe_globals(_safe_globals)
70
+
71
+
72
class LargeGraphIndexer:
    """For a dataset that consists of multiple subgraphs that are assumed to
    be part of a much larger graph, collate the values into a large graph store
    to save resources.

    Nodes are keyed by hashable ids and edges by ``(head, relation, tail)``
    triplets. ``node_attr`` / ``edge_attr`` map a feature name either to a
    sequence aligned with the unique ids (``node_attr[NODE_PID]`` /
    ``edge_attr[EDGE_PID]``), or to a :class:`MappedFeature` whose values are
    keyed by the unique values of another feature.
    """
    def __init__(
        self,
        nodes: Iterable[Hashable],
        edges: KnowledgeGraphLike,
        node_attr: Optional[Dict[str, List[Any]]] = None,
        edge_attr: Optional[Dict[str, List[Any]]] = None,
    ) -> None:
        r"""Constructs a new index that uniquely catalogs each node and edge
        by id. Not meant to be used directly.

        Args:
            nodes (Iterable[Hashable]): Node ids in the graph. Must be unique
                and support ``len`` (e.g. a list).
            edges (KnowledgeGraphLike): Edge ids in the graph, as
                ``(head, relation, tail)`` triplets. Must be unique and
                support ``len``.
            node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
                attribute name and list of their values in order of unique node
                ids. Defaults to None.
            edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping edge
                attribute name and list of their values in order of unique edge
                ids. Defaults to None.

        Raises:
            AttributeError: If ``nodes`` or ``edges`` contain duplicates, or if
                a given attribute mapping is missing required keys or does not
                list the same ids as ``nodes`` / ``edges``.
        """
        # id -> position in node_attr[NODE_PID] / edge_attr[EDGE_PID]
        self._nodes: Dict[Hashable, int] = dict()
        self._edges: Dict[TripletLike, int] = dict()

        # Names of features stored as MappedFeature objects.
        self._mapped_node_features: Set[str] = set()
        self._mapped_edge_features: Set[str] = set()

        if len(nodes) != len(set(nodes)):
            raise AttributeError("Nodes need to be unique")
        if len(edges) != len(set(edges)):
            raise AttributeError("Edges need to be unique")

        if node_attr is not None:
            # TODO: Validity checks btw nodes and node_attr
            self.node_attr = node_attr
            missing = NODE_KEYS - set(self.node_attr.keys())
            if missing:
                raise AttributeError(
                    f"Invalid node_attr object. Missing {missing}")
            if self.node_attr[NODE_PID] != nodes:
                raise AttributeError(
                    "Nodes provided do not match those in node_attr")
        else:
            self.node_attr = {NODE_PID: nodes}

        for i, node in enumerate(self.node_attr[NODE_PID]):
            self._nodes[node] = i

        if edge_attr is not None:
            # TODO: Validity checks btw edges and edge_attr
            self.edge_attr = edge_attr
            missing = EDGE_KEYS - set(self.edge_attr.keys())
            if missing:
                raise AttributeError(
                    f"Invalid edge_attr object. Missing {missing}")
            # Edge ids live in edge_attr; node_attr has no EDGE_PID key.
            if self.edge_attr[EDGE_PID] != edges:
                raise AttributeError(
                    "Edges provided do not match those in edge_attr")
        else:
            self.edge_attr = {key: list() for key in EDGE_KEYS}
            self.edge_attr[EDGE_PID] = edges
            for h, r, t in edges:
                self.edge_attr[EDGE_HEAD].append(h)
                self.edge_attr[EDGE_RELATION].append(r)
                self.edge_attr[EDGE_TAIL].append(t)
                # (head, tail) as positions into node_attr[NODE_PID]
                self.edge_attr[EDGE_INDEX].append(
                    (self._nodes[h], self._nodes[t]))

        for i, tup in enumerate(edges):
            self._edges[tup] = i

    @classmethod
    def from_triplets(
        cls,
        triplets: KnowledgeGraphLike,
        pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
    ) -> "LargeGraphIndexer":
        r"""Generate a new index from a series of triplets that represent edge
        relations between nodes.
        Formatted like (source_node, edge, dest_node).

        Args:
            triplets (KnowledgeGraphLike): Series of triplets representing
                knowledge graph relations.
            pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
                Optional preprocessing function to apply to triplets.
                Defaults to None.

        Returns:
            LargeGraphIndexer: Index of unique nodes and edges. Nodes and
            edges are deduplicated through sets, so their order follows set
            iteration order rather than input order.
        """
        # NOTE: Right now assumes that all trips can be loaded into memory
        nodes = set()
        edges = set()

        if pre_transform is not None:
            triplets = (pre_transform(trip) for trip in triplets)

        for h, r, t in triplets:
            nodes.add(h)
            nodes.add(t)
            edges.add((h, r, t))

        return cls(list(nodes), list(edges))

    @classmethod
    def collate(cls,
                graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer":
        r"""Combines a series of large graph indexes into a single large graph
        index.

        Only the triplets are merged; node and edge features of the inputs
        are not carried over.

        Args:
            graphs (Iterable["LargeGraphIndexer"]): Indices to be
                combined.

        Returns:
            LargeGraphIndexer: Singular unique index for all nodes and edges
                in input indices.
        """
        # FIXME Needs to merge node attrs and edge attrs?
        trips = chain.from_iterable([graph.to_triplets() for graph in graphs])
        return cls.from_triplets(trips)

    def get_unique_node_features(
            self, feature_name: str = NODE_PID) -> List[Hashable]:
        r"""Get all the unique values for a specific node attribute.

        Args:
            feature_name (str, optional): Name of feature to get.
                Defaults to NODE_PID.

        Returns:
            List[Hashable]: List of unique values for the specified feature,
                in first-seen order.

        Raises:
            IndexError: If the feature is a mapped feature.
            AttributeError: If no node feature has that name.
        """
        try:
            if feature_name in self._mapped_node_features:
                raise IndexError(
                    "Only non-mapped features can be retrieved uniquely.")
            return ordered_set(self.get_node_features(feature_name))

        except KeyError:
            raise AttributeError(
                f"Nodes do not have a feature called {feature_name}")

    def add_node_feature(
        self,
        new_feature_name: str,
        new_feature_vals: FeatureValueType,
        map_from_feature: str = NODE_PID,
    ) -> None:
        r"""Adds a new feature that corresponds to each unique node in
        the graph.

        Args:
            new_feature_name (str): Name to call the new feature.
            new_feature_vals (FeatureValueType): Values to map for that
                new feature.
            map_from_feature (str, optional): Key of feature to map from.
                Size must match the number of feature values.
                Defaults to NODE_PID.

        Raises:
            AttributeError: If the feature already exists, if
                ``map_from_feature`` is itself mapped, or if the number of
                values does not match the number of unique source values.
        """
        if new_feature_name in self.node_attr:
            raise AttributeError("Features cannot be overridden once created")
        if map_from_feature in self._mapped_node_features:
            raise AttributeError(
                f"{map_from_feature} is already a feature mapping.")

        feature_keys = self.get_unique_node_features(map_from_feature)
        if len(feature_keys) != len(new_feature_vals):
            raise AttributeError(
                f"Expected encodings for {len(feature_keys)} unique features,"
                f" but got {len(new_feature_vals)} encodings.")

        if map_from_feature == NODE_PID:
            self.node_attr[new_feature_name] = new_feature_vals
        else:
            # Store values keyed by the unique values of map_from_feature.
            self.node_attr[new_feature_name] = MappedFeature(
                name=map_from_feature, values=new_feature_vals)
            self._mapped_node_features.add(new_feature_name)

    def get_node_features(
        self,
        feature_name: str = NODE_PID,
        pids: Optional[Iterable[Hashable]] = None,
    ) -> List[Any]:
        r"""Get node feature values for a given set of unique node ids.
        Returned values are not necessarily unique.

        Args:
            feature_name (str, optional): Name of feature to fetch. Defaults
                to NODE_PID.
            pids (Optional[Iterable[Hashable]], optional): Node ids to fetch
                for. Defaults to None, which fetches all nodes.

        Returns:
            List[Any]: Node features corresponding to the specified ids, in
                ``pids`` order. If the stored values are a tensor, a tensor of
                the selected rows is returned instead of a list.
        """
        if feature_name in self._mapped_node_features:
            values = self.node_attr[feature_name].values
        else:
            values = self.node_attr[feature_name]

        # TODO: torch_geometric.utils.select
        if isinstance(values, torch.Tensor):
            idxs = list(
                self.get_node_features_iter(feature_name, pids,
                                            index_only=True))
            return values[idxs]
        return list(self.get_node_features_iter(feature_name, pids))

    def get_node_features_iter(
        self,
        feature_name: str = NODE_PID,
        pids: Optional[Iterable[Hashable]] = None,
        index_only: bool = False,
    ) -> Iterator[Any]:
        """Iterator version of get_node_features. If index_only is True,
        yields indices instead of values.

        For mapped features the yielded index is a position in the mapped
        feature's ``values``; otherwise it is the node's position.
        """
        if pids is None:
            pids = self.node_attr[NODE_PID]

        if feature_name in self._mapped_node_features:
            feature_map_info = self.node_attr[feature_name]
            from_feature_name = feature_map_info.name
            to_feature_vals = feature_map_info.values
            from_feature_vals = self.get_unique_node_features(
                from_feature_name)
            # unique source value -> position in to_feature_vals
            feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}

            for pid in pids:
                idx = self._nodes[pid]
                from_feature_val = self.node_attr[from_feature_name][idx]
                to_feature_idx = feature_mapping[from_feature_val]
                if index_only:
                    yield to_feature_idx
                else:
                    yield to_feature_vals[to_feature_idx]
        else:
            for pid in pids:
                idx = self._nodes[pid]
                if index_only:
                    yield idx
                else:
                    yield self.node_attr[feature_name][idx]

    def get_unique_edge_features(
            self, feature_name: str = EDGE_PID) -> List[Hashable]:
        r"""Get all the unique values for a specific edge attribute.

        Args:
            feature_name (str, optional): Name of feature to get.
                Defaults to EDGE_PID.

        Returns:
            List[Hashable]: List of unique values for the specified feature,
                in first-seen order.

        Raises:
            IndexError: If the feature is a mapped feature.
            AttributeError: If no edge feature has that name.
        """
        try:
            if feature_name in self._mapped_edge_features:
                raise IndexError(
                    "Only non-mapped features can be retrieved uniquely.")
            return ordered_set(self.get_edge_features(feature_name))
        except KeyError:
            raise AttributeError(
                f"Edges do not have a feature called {feature_name}")

    def add_edge_feature(
        self,
        new_feature_name: str,
        new_feature_vals: FeatureValueType,
        map_from_feature: str = EDGE_PID,
    ) -> None:
        r"""Adds a new feature that corresponds to each unique edge in
        the graph.

        Args:
            new_feature_name (str): Name to call the new feature.
            new_feature_vals (FeatureValueType): Values to map for that new
                feature.
            map_from_feature (str, optional): Key of feature to map from.
                Size must match the number of feature values.
                Defaults to EDGE_PID.

        Raises:
            AttributeError: If the feature already exists, if
                ``map_from_feature`` is itself mapped, or if the number of
                values does not match the number of unique source values.
        """
        if new_feature_name in self.edge_attr:
            raise AttributeError("Features cannot be overridden once created")
        if map_from_feature in self._mapped_edge_features:
            raise AttributeError(
                f"{map_from_feature} is already a feature mapping.")

        feature_keys = self.get_unique_edge_features(map_from_feature)
        if len(feature_keys) != len(new_feature_vals):
            raise AttributeError(
                f"Expected encodings for {len(feature_keys)} unique features, "
                + f"but got {len(new_feature_vals)} encodings.")

        if map_from_feature == EDGE_PID:
            self.edge_attr[new_feature_name] = new_feature_vals
        else:
            # Store values keyed by the unique values of map_from_feature.
            self.edge_attr[new_feature_name] = MappedFeature(
                name=map_from_feature, values=new_feature_vals)
            self._mapped_edge_features.add(new_feature_name)

    def get_edge_features(
        self,
        feature_name: str = EDGE_PID,
        pids: Optional[Iterable[Hashable]] = None,
    ) -> List[Any]:
        r"""Get edge feature values for a given set of unique edge ids.
        Returned values are not necessarily unique.

        Args:
            feature_name (str, optional): Name of feature to fetch.
                Defaults to EDGE_PID.
            pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch
                for. Defaults to None, which fetches all edges.

        Returns:
            List[Any]: Edge features corresponding to the specified ids, in
                ``pids`` order. If the stored values are a tensor, a tensor of
                the selected rows is returned instead of a list.
        """
        if feature_name in self._mapped_edge_features:
            values = self.edge_attr[feature_name].values
        else:
            values = self.edge_attr[feature_name]

        # TODO: torch_geometric.utils.select
        if isinstance(values, torch.Tensor):
            idxs = list(
                self.get_edge_features_iter(feature_name, pids,
                                            index_only=True))
            return values[idxs]
        return list(self.get_edge_features_iter(feature_name, pids))

    def get_edge_features_iter(
        self,
        feature_name: str = EDGE_PID,
        pids: Optional[KnowledgeGraphLike] = None,
        index_only: bool = False,
    ) -> Iterator[Any]:
        """Iterator version of get_edge_features. If index_only is True,
        yields indices instead of values.

        For mapped features the yielded index is a position in the mapped
        feature's ``values``; otherwise it is the edge's position.
        """
        if pids is None:
            pids = self.edge_attr[EDGE_PID]

        if feature_name in self._mapped_edge_features:
            feature_map_info = self.edge_attr[feature_name]
            from_feature_name = feature_map_info.name
            to_feature_vals = feature_map_info.values
            from_feature_vals = self.get_unique_edge_features(
                from_feature_name)
            # unique source value -> position in to_feature_vals
            feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}

            for pid in pids:
                idx = self._edges[pid]
                from_feature_val = self.edge_attr[from_feature_name][idx]
                to_feature_idx = feature_mapping[from_feature_val]
                if index_only:
                    yield to_feature_idx
                else:
                    yield to_feature_vals[to_feature_idx]
        else:
            for pid in pids:
                idx = self._edges[pid]
                if index_only:
                    yield idx
                else:
                    yield self.edge_attr[feature_name][idx]

    def to_triplets(self) -> Iterator[TripletLike]:
        """Iterate over the indexed ``(head, relation, tail)`` triplets."""
        return iter(self.edge_attr[EDGE_PID])

    def save(self, path: str) -> None:
        """Write the index to the directory ``path``.

        Any existing directory at ``path`` is deleted first. The id lookup
        tables and mapped-feature name sets are pickled; every node and edge
        attribute is written with ``torch.save`` as ``<name>.pt`` under
        ``node_attr/`` and ``edge_attr/``.
        """
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

        pickled = (
            ("edges", self._edges),
            ("nodes", self._nodes),
            ("mapped_edges", self._mapped_edge_features),
            ("mapped_nodes", self._mapped_node_features),
        )
        for fname, obj in pickled:
            with open(os.path.join(path, fname), "wb") as f:
                pkl.dump(obj, f)

        for dir_name, attrs in (("node_attr", self.node_attr),
                                ("edge_attr", self.edge_attr)):
            attr_dir = os.path.join(path, dir_name)
            os.makedirs(attr_dir, exist_ok=True)
            for attr_name, vals in attrs.items():
                torch.save(vals, os.path.join(attr_dir, f"{attr_name}.pt"))

    @staticmethod
    def _load_attr_dir(attr_dir: str, target: Dict[str, Any]) -> None:
        """Load every ``<name>.pt`` file in ``attr_dir`` into ``target``."""
        for fname in os.listdir(attr_dir):
            # splitext keeps dotted attribute names intact ("a.b.pt" -> "a.b")
            key, ext = os.path.splitext(fname)
            # save() only writes "<attr_name>.pt"; ignore anything else.
            if ext != ".pt":
                continue
            target[key] = torch.load(os.path.join(attr_dir, fname))

    @classmethod
    def from_disk(cls, path: str) -> "LargeGraphIndexer":
        """Load an index previously written with :meth:`save`.

        NOTE: this uses ``pickle`` and ``torch.load``; only load directories
        from trusted sources.
        """
        indexer = cls(list(), list())
        with open(os.path.join(path, "edges"), "rb") as f:
            indexer._edges = pkl.load(f)
        with open(os.path.join(path, "nodes"), "rb") as f:
            indexer._nodes = pkl.load(f)

        with open(os.path.join(path, "mapped_edges"), "rb") as f:
            indexer._mapped_edge_features = pkl.load(f)
        with open(os.path.join(path, "mapped_nodes"), "rb") as f:
            indexer._mapped_node_features = pkl.load(f)

        cls._load_attr_dir(os.path.join(path, "node_attr"), indexer.node_attr)
        cls._load_attr_dir(os.path.join(path, "edge_attr"), indexer.edge_attr)

        return indexer

    def to_data(self, node_feature_name: str,
                edge_feature_name: Optional[str] = None) -> Data:
        """Return a Data object containing all the specified node and
        edge features and the graph.

        Args:
            node_feature_name (str): Feature to use for nodes
            edge_feature_name (Optional[str], optional): Feature to use for
                edges. Defaults to None, in which case ``edge_attr`` is None.

        Returns:
            Data: Data object containing the specified node and
                edge features and the graph.
        """
        x = torch.Tensor(self.get_node_features(node_feature_name))
        node_id = torch.LongTensor(range(len(x)))

        edge_pairs = self.get_edge_features(EDGE_INDEX)
        # view(-1, 2) keeps a (2, num_edges) layout even with zero edges.
        edge_index = torch.t(torch.LongTensor(edge_pairs).view(-1, 2))

        edge_attr = (self.get_edge_features(edge_feature_name)
                     if edge_feature_name is not None else None)
        # Count edges from edge_index pairs so this works without edge_attr.
        edge_id = torch.LongTensor(range(len(edge_pairs)))

        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                    edge_id=edge_id, node_id=node_id)

    @staticmethod
    def _attrs_equal(ours: Dict[str, Any], theirs: Dict[str, Any]) -> bool:
        """Compare two attribute dicts that are known to share keys."""
        for key, our_vals in ours.items():
            their_vals = theirs[key]
            if not isinstance(our_vals, type(their_vals)):
                return False
            if isinstance(our_vals, torch.Tensor):
                if not torch.equal(our_vals, their_vals):
                    return False
            elif not our_vals == their_vals:
                return False
        return True

    def __eq__(self, value: object) -> bool:
        """Two indexers are equal if their lookups and attributes match."""
        if not isinstance(value, LargeGraphIndexer):
            return NotImplemented
        # Check structure first so the per-key loop never hits a missing key.
        if (self._nodes != value._nodes or self._edges != value._edges
                or self.node_attr.keys() != value.node_attr.keys()
                or self.edge_attr.keys() != value.edge_attr.keys()
                or self._mapped_node_features != value._mapped_node_features
                or self._mapped_edge_features
                != value._mapped_edge_features):
            return False
        return (self._attrs_equal(self.node_attr, value.node_attr)
                and self._attrs_equal(self.edge_attr, value.edge_attr))
567
+
568
+
569
def get_features_for_triplets_groups(
    indexer: LargeGraphIndexer,
    triplet_groups: Iterable[KnowledgeGraphLike],
    node_feature_name: str = "x",
    edge_feature_name: str = "edge_attr",
    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
    verbose: bool = False,
) -> Iterator[Data]:
    """Given an indexer and a series of triplet groups (like a dataset),
    retrieve the specified node and edge features for each triplet from the
    index.

    Args:
        indexer (LargeGraphIndexer): Indexer containing desired features
        triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of
            triplets to fetch features for
        node_feature_name (str, optional): Node feature to fetch.
            Defaults to "x".
        edge_feature_name (str, optional): edge feature to fetch.
            Defaults to "edge_attr".
        pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
            Optional preprocessing to perform on triplets. Applied exactly
            once per triplet. Defaults to None.
        verbose (bool, optional): Whether to print progress. Defaults to False.

    Yields:
        Iterator[Data]: For each triplet group, yield a data object containing
        the unique graph and features from the index. ``NODE_PID`` and
        ``EDGE_PID`` hold that group's own node and edge keys.
    """
    def _prepare(triplets: KnowledgeGraphLike) -> List[TripletLike]:
        # Materialize as hashable tuples: each group is read twice below
        # (to build the small index and to look up its edges), so one-shot
        # iterators must not be consumed by the first pass.
        if pre_transform is None:
            return [tuple(trip) for trip in triplets]
        return [tuple(pre_transform(tuple(trip))) for trip in triplets]

    node_keys = []
    edge_keys = []
    edge_index = []

    # TODO: Make this safe for large amounts of triplets?
    for group in tqdm(triplet_groups, disable=not verbose):
        triplets = _prepare(group)
        # pre_transform was already applied in _prepare; passing it again
        # would transform every triplet twice.
        small_graph_indexer = LargeGraphIndexer.from_triplets(triplets)

        node_keys.append(small_graph_indexer.get_node_features())
        edge_keys.append(small_graph_indexer.get_edge_features(pids=triplets))
        edge_index.append(
            small_graph_indexer.get_edge_features(EDGE_INDEX, triplets))

    # Fetch features for all groups in one pass, then slice per group.
    node_feats = indexer.get_node_features(feature_name=node_feature_name,
                                           pids=chain.from_iterable(node_keys))
    edge_feats = indexer.get_edge_features(feature_name=edge_feature_name,
                                           pids=chain.from_iterable(edge_keys))

    last_node_idx, last_edge_idx = 0, 0
    for nkeys, ekeys, eidx in zip(node_keys, edge_keys, edge_index):
        nlen, elen = len(nkeys), len(ekeys)
        x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])
        last_node_idx += nlen

        edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
                                            elen])
        last_edge_idx += elen

        # view(-1, 2) keeps a (2, num_edges) layout for empty groups too.
        edge_idx = torch.LongTensor(eidx).view(-1, 2).t()

        data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
        # Store only this group's keys, not the keys of every group.
        data_obj[NODE_PID] = nkeys
        data_obj[EDGE_PID] = ekeys
        data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
        data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]

        yield data_obj
645
+
646
+
647
def get_features_for_triplets(
    indexer: LargeGraphIndexer,
    triplets: KnowledgeGraphLike,
    node_feature_name: str = "x",
    edge_feature_name: str = "edge_attr",
    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
    verbose: bool = False,
) -> Data:
    """Build one Data object for a single set of triplets from the index.

    This is the single-group form of ``get_features_for_triplets_groups``:
    the triplets are wrapped as a one-element group and the first (and only)
    yielded Data object is returned.

    Args:
        indexer (LargeGraphIndexer): Index holding the requested features.
        triplets (KnowledgeGraphLike): Triplets whose features are fetched.
        node_feature_name (str, optional): Node feature to use.
            Defaults to "x".
        edge_feature_name (str, optional): Edge feature to use.
            Defaults to "edge_attr".
        pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
            Optional preprocessing for each triplet. Defaults to None.
        verbose (bool, optional): Whether to show progress. Defaults to False.

    Returns:
        Data: The unique graph for ``triplets`` with features from the index.
    """
    single_group = [triplets]
    data_objects = get_features_for_triplets_groups(
        indexer,
        single_group,
        node_feature_name=node_feature_name,
        edge_feature_name=edge_feature_name,
        pre_transform=pre_transform,
        verbose=verbose,
    )
    # Exactly one group was passed in, so exactly one Data object is yielded.
    return next(data_objects)
@@ -77,6 +77,7 @@ from .myket import MyketDataset
77
77
  from .brca_tgca import BrcaTcga
78
78
  from .neurograph import NeuroGraphDataset
79
79
  from .web_qsp_dataset import WebQSPDataset
80
+ from .git_mol_dataset import GitMolDataset
80
81
  from .molecule_gpt_dataset import MoleculeGPTDataset
81
82
  from .tag_dataset import TAGDataset
82
83
 
@@ -192,6 +193,7 @@ homo_datasets = [
192
193
  'BrcaTcga',
193
194
  'NeuroGraphDataset',
194
195
  'WebQSPDataset',
196
+ 'GitMolDataset',
195
197
  'MoleculeGPTDataset',
196
198
  'TAGDataset',
197
199
  ]