pyg-nightly 2.7.0.dev20241125__py3-none-any.whl → 2.7.0.dev20241126__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241125.dist-info → pyg_nightly-2.7.0.dev20241126.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20241125.dist-info → pyg_nightly-2.7.0.dev20241126.dist-info}/RECORD +11 -8
- torch_geometric/__init__.py +1 -1
- torch_geometric/data/__init__.py +5 -0
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/loader/__init__.py +2 -0
- torch_geometric/loader/rag_loader.py +106 -0
- torch_geometric/nn/models/g_retriever.py +12 -1
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- {pyg_nightly-2.7.0.dev20241125.dist-info → pyg_nightly-2.7.0.dev20241126.dist-info}/WHEEL +0 -0
{pyg_nightly-2.7.0.dev20241125.dist-info → pyg_nightly-2.7.0.dev20241126.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.
|
3
|
+
Version: 2.7.0.dev20241126
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=byc0Xe43_b_bDlW1ufoO-jvXP0Uu6SzYXuRtGKXmqyw,1904
|
2
2
|
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -29,7 +29,7 @@ torch_geometric/contrib/nn/conv/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2e
|
|
29
29
|
torch_geometric/contrib/nn/models/__init__.py,sha256=3ia5cX-TPhouLl6jn_HA-Rd2LaaQvFgy5CjRk0ovKRU,113
|
30
30
|
torch_geometric/contrib/nn/models/rbcd_attack.py,sha256=qcyxBxAbx8LKzpp3RoJQ0cxl9aB2onsWT4oY1fsM7us,33280
|
31
31
|
torch_geometric/contrib/transforms/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2ezDjbj9n9nCpvtk,23
|
32
|
-
torch_geometric/data/__init__.py,sha256=
|
32
|
+
torch_geometric/data/__init__.py,sha256=D6Iz5A9vEb_2rpf96Zn7uM-lchZ3WpW8X7WdAD1yxKw,4565
|
33
33
|
torch_geometric/data/batch.py,sha256=C9cT7-rcWPgnG68Eb_uAcn90HS3OvOG6n4fY3ihpFhI,8764
|
34
34
|
torch_geometric/data/collate.py,sha256=RRiUMBLxDAitaHx7zF0qiMR2nW1NY_0uaNdxlUo5-bo,12756
|
35
35
|
torch_geometric/data/data.py,sha256=l_gHy18g9WtiSCm1mDinR4vGrZOLetogrw5wJEcn23E,43807
|
@@ -43,6 +43,7 @@ torch_geometric/data/graph_store.py,sha256=oFrLDNP5hKf3HWWsFsjcamx5vLIEk8JnLjuGp
|
|
43
43
|
torch_geometric/data/hetero_data.py,sha256=q0L3bENyEvo_BGLPwZPVzh730Aak6sQ7yXoawPgM72E,47982
|
44
44
|
torch_geometric/data/hypergraph_data.py,sha256=33hsXW25Yz4Ju8mKajYinZOrkqrUi1SqThG7MlOOYNM,8294
|
45
45
|
torch_geometric/data/in_memory_dataset.py,sha256=F35hU9Dw3qiJUL5E1CCAfq-1xrlUMstXBmQVEQdtJ1I,13403
|
46
|
+
torch_geometric/data/large_graph_indexer.py,sha256=JqozKbn5C-jLq2uydeImWqihvBRg8nl5Al55V5s53aw,25433
|
46
47
|
torch_geometric/data/makedirs.py,sha256=6uOv4y34i947cm4rv7Aj2_YZBq-EOsyPKnlGA188YSw,463
|
47
48
|
torch_geometric/data/on_disk_dataset.py,sha256=77om-e6kzcpBb77kf7um1xY8-yHmQaao_6R7I-3NwHk,6629
|
48
49
|
torch_geometric/data/remote_backend_utils.py,sha256=Rzpq1PczXuHhUscrFtIAL6dua6pMehSJlXG7yEsrrrg,4503
|
@@ -261,7 +262,7 @@ torch_geometric/io/ply.py,sha256=NdeTtr79vJ1HS37ZV2N61EUmA5NGJd2I6cUj1Pg7Ypg,489
|
|
261
262
|
torch_geometric/io/sdf.py,sha256=H2PC6dSW9Kncc1ulb0UN0JnTRT93NY2fY8lf6K4hb50,1165
|
262
263
|
torch_geometric/io/tu.py,sha256=-v5Ago7DfmGTRBtB5RZFvmv4XpLnKKnk-NOnxlHtB_c,4881
|
263
264
|
torch_geometric/io/txt_array.py,sha256=LDeX2qtlNKW-kVe-wpnskMwAdXQp1jVCGQnrJce7Smg,910
|
264
|
-
torch_geometric/loader/__init__.py,sha256=
|
265
|
+
torch_geometric/loader/__init__.py,sha256=o0wC0Gvv4rewpZU_YeVaJZCCZZJQG2v8MfZhjocvKp8,1896
|
265
266
|
torch_geometric/loader/base.py,sha256=ataIwNEYL0px3CN3LJEgXIVTRylDHB6-yBFXXuX2JN0,1615
|
266
267
|
torch_geometric/loader/cache.py,sha256=S65heO3YTyUPbttqizCNtKPHIoAw5iHRpbvw6KlXmok,2106
|
267
268
|
torch_geometric/loader/cluster.py,sha256=eMNxVkvZt5oQ_gJRgmWm1NBX7zU2tZI_BPaXeB0wuyk,13465
|
@@ -280,6 +281,7 @@ torch_geometric/loader/neighbor_loader.py,sha256=vnLn_RhBKTux5h8pi0vzj0d7JPoOpLA
|
|
280
281
|
torch_geometric/loader/neighbor_sampler.py,sha256=mraVFXIIGctYot4Xr2VOAhCKAOQyW2gP9KROf7g6tcc,8497
|
281
282
|
torch_geometric/loader/node_loader.py,sha256=g_kV5N0tO6eMSFPc5fdbzfHr4COAeKVJi7FEq52f4zc,11848
|
282
283
|
torch_geometric/loader/prefetch.py,sha256=p1mr54TL4nx3Ea0fBy0JulGYJ8Hq4_9rsiNioZsIW-4,3211
|
284
|
+
torch_geometric/loader/rag_loader.py,sha256=nwswemzYL4wCKljXqsxMDg07x6PkLU_kgAkNFj5TwUY,4555
|
283
285
|
torch_geometric/loader/random_node_loader.py,sha256=rCmRXYv70SPxBo-Oh049eFEWEZDV7FmlRPzmjcoirXQ,2196
|
284
286
|
torch_geometric/loader/shadow.py,sha256=_hCspYf9SlJYX0lqEjxFec9e9t1iMScNThOoWR1wQGM,4173
|
285
287
|
torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_yndivXt7Z6_w62zI,2248
|
@@ -431,7 +433,7 @@ torch_geometric/nn/models/deep_graph_infomax.py,sha256=u6j-5-iHBASDCZ776dyfCI1N8
|
|
431
433
|
torch_geometric/nn/models/deepgcn.py,sha256=tIgT03cj8MghYlxEozpoGvGG_CwpJrGDxv1Z0CVIUts,4339
|
432
434
|
torch_geometric/nn/models/dimenet.py,sha256=Kc5p-rB5q-0e8lY22l-OdQTscTxJh2lTEpeRFMdL4RY,36186
|
433
435
|
torch_geometric/nn/models/dimenet_utils.py,sha256=Eyn_EiJqwKvuYj6BtRpSxrzMG3v4Gk98X9MxZ7uvwm4,5069
|
434
|
-
torch_geometric/nn/models/g_retriever.py,sha256=
|
436
|
+
torch_geometric/nn/models/g_retriever.py,sha256=CdSOasnPiMvq5AjduNTpz-LIZiNp3X0xM5sx5MEW8Ok,8258
|
435
437
|
torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
|
436
438
|
torch_geometric/nn/models/glem.py,sha256=gqQF4jlU7U_u5-zGeJZuHiEqhSXa-wLU5TghN4u5fYY,16389
|
437
439
|
torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
|
@@ -496,8 +498,9 @@ torch_geometric/nn/pool/select/base.py,sha256=On7xaGjMf_ISRvDmpBpJ0krYof0a78XEzS
|
|
496
498
|
torch_geometric/nn/pool/select/topk.py,sha256=R1LTjOvanJqlrcDe0qinqz286qOJpmjC1tPeiQdPGcU,5305
|
497
499
|
torch_geometric/nn/unpool/__init__.py,sha256=J6I3abNR1MRxisXzbX3sBRH-hlMpmUe7FVc3UziZ67s,129
|
498
500
|
torch_geometric/nn/unpool/knn_interpolate.py,sha256=8GlKoB-wzZz6ETJP7SsKHbzwenr4JiPg6sK3uh9I6R8,2586
|
499
|
-
torch_geometric/profile/__init__.py,sha256=
|
501
|
+
torch_geometric/profile/__init__.py,sha256=R5dQw0vA5Ukf6FrJgptNMCOb__kcgwnxFThA8qaJF8k,902
|
500
502
|
torch_geometric/profile/benchmark.py,sha256=EuD12qJiiPCSwkg5w8arELXiRT_QY_3Wz_rqs7LpDKE,5256
|
503
|
+
torch_geometric/profile/nvtx.py,sha256=AKBr-rqlHDnls_UM02Dfq5BZmyFTHS5Li5gaeKmsAJI,2032
|
501
504
|
torch_geometric/profile/profile.py,sha256=cHCY4U0XtyqyKC5u380q6TspsOZ5tGHNXaZsKuzYi1A,11793
|
502
505
|
torch_geometric/profile/profiler.py,sha256=rfNciRzWDka_BgO6aPFi3cy8mcT4lSgFWy-WfPgI2SI,16891
|
503
506
|
torch_geometric/profile/utils.py,sha256=7h6vzTzW8vv-ZqMOz2DV8HHNgC9ViOrN7IR9d3BPDZ8,5497
|
@@ -626,6 +629,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
626
629
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
627
630
|
torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
|
628
631
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
629
|
-
pyg_nightly-2.7.0.
|
630
|
-
pyg_nightly-2.7.0.
|
631
|
-
pyg_nightly-2.7.0.
|
632
|
+
pyg_nightly-2.7.0.dev20241126.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
|
633
|
+
pyg_nightly-2.7.0.dev20241126.dist-info/METADATA,sha256=kYXIEsYg6p5-3FGWoEJAsjfBZTNbsotYFGiCqV6r4JA,62979
|
634
|
+
pyg_nightly-2.7.0.dev20241126.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
|
|
30
30
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
31
31
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
32
32
|
|
33
|
-
__version__ = '2.7.0.
|
33
|
+
__version__ = '2.7.0.dev20241126'
|
34
34
|
|
35
35
|
__all__ = [
|
36
36
|
'Index',
|
torch_geometric/data/__init__.py
CHANGED
@@ -16,6 +16,7 @@ from .on_disk_dataset import OnDiskDataset
|
|
16
16
|
from .makedirs import makedirs
|
17
17
|
from .download import download_url, download_google_url
|
18
18
|
from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
|
19
|
+
from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups
|
19
20
|
|
20
21
|
from torch_geometric.lazy_loader import LazyLoader
|
21
22
|
|
@@ -27,6 +28,8 @@ data_classes = [
|
|
27
28
|
'Dataset',
|
28
29
|
'InMemoryDataset',
|
29
30
|
'OnDiskDataset',
|
31
|
+
'LargeGraphIndexer',
|
32
|
+
'TripletLike',
|
30
33
|
]
|
31
34
|
|
32
35
|
remote_backend_classes = [
|
@@ -50,6 +53,8 @@ helper_functions = [
|
|
50
53
|
'extract_zip',
|
51
54
|
'extract_bz2',
|
52
55
|
'extract_gz',
|
56
|
+
'get_features_for_triplets',
|
57
|
+
"get_features_for_triplets_groups",
|
53
58
|
]
|
54
59
|
|
55
60
|
__all__ = data_classes + remote_backend_classes + helper_functions
|
@@ -0,0 +1,677 @@
|
|
1
|
+
import os
|
2
|
+
import pickle as pkl
|
3
|
+
import shutil
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from itertools import chain
|
6
|
+
from typing import (
|
7
|
+
Any,
|
8
|
+
Callable,
|
9
|
+
Dict,
|
10
|
+
Hashable,
|
11
|
+
Iterable,
|
12
|
+
Iterator,
|
13
|
+
List,
|
14
|
+
Optional,
|
15
|
+
Sequence,
|
16
|
+
Set,
|
17
|
+
Tuple,
|
18
|
+
Union,
|
19
|
+
)
|
20
|
+
|
21
|
+
import torch
|
22
|
+
from torch import Tensor
|
23
|
+
from tqdm import tqdm
|
24
|
+
|
25
|
+
from torch_geometric.data import Data
|
26
|
+
from torch_geometric.typing import WITH_PT24
|
27
|
+
|
28
|
+
TripletLike = Tuple[Hashable, Hashable, Hashable]
|
29
|
+
|
30
|
+
KnowledgeGraphLike = Iterable[TripletLike]
|
31
|
+
|
32
|
+
|
33
|
+
def ordered_set(values: Iterable[Hashable]) -> List[Hashable]:
|
34
|
+
return list(dict.fromkeys(values))
|
35
|
+
|
36
|
+
|
37
|
+
# TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?
|
38
|
+
|
39
|
+
NODE_PID = "pid"
|
40
|
+
|
41
|
+
NODE_KEYS = {NODE_PID}
|
42
|
+
|
43
|
+
EDGE_PID = "e_pid"
|
44
|
+
EDGE_HEAD = "h"
|
45
|
+
EDGE_RELATION = "r"
|
46
|
+
EDGE_TAIL = "t"
|
47
|
+
EDGE_INDEX = "edge_idx"
|
48
|
+
|
49
|
+
EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}
|
50
|
+
|
51
|
+
FeatureValueType = Union[Sequence[Any], Tensor]
|
52
|
+
|
53
|
+
|
54
|
+
@dataclass
|
55
|
+
class MappedFeature:
|
56
|
+
name: str
|
57
|
+
values: FeatureValueType
|
58
|
+
|
59
|
+
def __eq__(self, value: "MappedFeature") -> bool:
|
60
|
+
eq = self.name == value.name
|
61
|
+
if isinstance(self.values, torch.Tensor):
|
62
|
+
eq &= torch.equal(self.values, value.values)
|
63
|
+
else:
|
64
|
+
eq &= self.values == value.values
|
65
|
+
return eq
|
66
|
+
|
67
|
+
|
68
|
+
if WITH_PT24:
|
69
|
+
torch.serialization.add_safe_globals([MappedFeature])
|
70
|
+
|
71
|
+
|
72
|
+
class LargeGraphIndexer:
|
73
|
+
"""For a dataset that consists of mulitiple subgraphs that are assumed to
|
74
|
+
be part of a much larger graph, collate the values into a large graph store
|
75
|
+
to save resources.
|
76
|
+
"""
|
77
|
+
def __init__(
|
78
|
+
self,
|
79
|
+
nodes: Iterable[Hashable],
|
80
|
+
edges: KnowledgeGraphLike,
|
81
|
+
node_attr: Optional[Dict[str, List[Any]]] = None,
|
82
|
+
edge_attr: Optional[Dict[str, List[Any]]] = None,
|
83
|
+
) -> None:
|
84
|
+
r"""Constructs a new index that uniquely catalogs each node and edge
|
85
|
+
by id. Not meant to be used directly.
|
86
|
+
|
87
|
+
Args:
|
88
|
+
nodes (Iterable[Hashable]): Node ids in the graph.
|
89
|
+
edges (KnowledgeGraphLike): Edge ids in the graph.
|
90
|
+
node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
|
91
|
+
attribute name and list of their values in order of unique node
|
92
|
+
ids. Defaults to None.
|
93
|
+
edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping edge
|
94
|
+
attribute name and list of their values in order of unique edge
|
95
|
+
ids. Defaults to None.
|
96
|
+
"""
|
97
|
+
self._nodes: Dict[Hashable, int] = dict()
|
98
|
+
self._edges: Dict[TripletLike, int] = dict()
|
99
|
+
|
100
|
+
self._mapped_node_features: Set[str] = set()
|
101
|
+
self._mapped_edge_features: Set[str] = set()
|
102
|
+
|
103
|
+
if len(nodes) != len(set(nodes)):
|
104
|
+
raise AttributeError("Nodes need to be unique")
|
105
|
+
if len(edges) != len(set(edges)):
|
106
|
+
raise AttributeError("Edges need to be unique")
|
107
|
+
|
108
|
+
if node_attr is not None:
|
109
|
+
# TODO: Validity checks btw nodes and node_attr
|
110
|
+
self.node_attr = node_attr
|
111
|
+
if NODE_KEYS & set(self.node_attr.keys()) != NODE_KEYS:
|
112
|
+
raise AttributeError(
|
113
|
+
"Invalid node_attr object. Missing " +
|
114
|
+
f"{NODE_KEYS - set(self.node_attr.keys())}")
|
115
|
+
elif self.node_attr[NODE_PID] != nodes:
|
116
|
+
raise AttributeError(
|
117
|
+
"Nodes provided do not match those in node_attr")
|
118
|
+
else:
|
119
|
+
self.node_attr = dict()
|
120
|
+
self.node_attr[NODE_PID] = nodes
|
121
|
+
|
122
|
+
for i, node in enumerate(self.node_attr[NODE_PID]):
|
123
|
+
self._nodes[node] = i
|
124
|
+
|
125
|
+
if edge_attr is not None:
|
126
|
+
# TODO: Validity checks btw edges and edge_attr
|
127
|
+
self.edge_attr = edge_attr
|
128
|
+
|
129
|
+
if EDGE_KEYS & set(self.edge_attr.keys()) != EDGE_KEYS:
|
130
|
+
raise AttributeError(
|
131
|
+
"Invalid edge_attr object. Missing " +
|
132
|
+
f"{EDGE_KEYS - set(self.edge_attr.keys())}")
|
133
|
+
elif self.node_attr[EDGE_PID] != edges:
|
134
|
+
raise AttributeError(
|
135
|
+
"Edges provided do not match those in edge_attr")
|
136
|
+
|
137
|
+
else:
|
138
|
+
self.edge_attr = dict()
|
139
|
+
for default_key in EDGE_KEYS:
|
140
|
+
self.edge_attr[default_key] = list()
|
141
|
+
self.edge_attr[EDGE_PID] = edges
|
142
|
+
|
143
|
+
for i, tup in enumerate(edges):
|
144
|
+
h, r, t = tup
|
145
|
+
self.edge_attr[EDGE_HEAD].append(h)
|
146
|
+
self.edge_attr[EDGE_RELATION].append(r)
|
147
|
+
self.edge_attr[EDGE_TAIL].append(t)
|
148
|
+
self.edge_attr[EDGE_INDEX].append(
|
149
|
+
(self._nodes[h], self._nodes[t]))
|
150
|
+
|
151
|
+
for i, tup in enumerate(edges):
|
152
|
+
self._edges[tup] = i
|
153
|
+
|
154
|
+
@classmethod
|
155
|
+
def from_triplets(
|
156
|
+
cls,
|
157
|
+
triplets: KnowledgeGraphLike,
|
158
|
+
pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
|
159
|
+
) -> "LargeGraphIndexer":
|
160
|
+
r"""Generate a new index from a series of triplets that represent edge
|
161
|
+
relations between nodes.
|
162
|
+
Formatted like (source_node, edge, dest_node).
|
163
|
+
|
164
|
+
Args:
|
165
|
+
triplets (KnowledgeGraphLike): Series of triplets representing
|
166
|
+
knowledge graph relations.
|
167
|
+
pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
|
168
|
+
Optional preprocessing function to apply to triplets.
|
169
|
+
Defaults to None.
|
170
|
+
|
171
|
+
Returns:
|
172
|
+
LargeGraphIndexer: Index of unique nodes and edges.
|
173
|
+
"""
|
174
|
+
# NOTE: Right now assumes that all trips can be loaded into memory
|
175
|
+
nodes = set()
|
176
|
+
edges = set()
|
177
|
+
|
178
|
+
if pre_transform is not None:
|
179
|
+
|
180
|
+
def apply_transform(
|
181
|
+
trips: KnowledgeGraphLike) -> Iterator[TripletLike]:
|
182
|
+
for trip in trips:
|
183
|
+
yield pre_transform(trip)
|
184
|
+
|
185
|
+
triplets = apply_transform(triplets)
|
186
|
+
|
187
|
+
for h, r, t in triplets:
|
188
|
+
|
189
|
+
for node in (h, t):
|
190
|
+
nodes.add(node)
|
191
|
+
|
192
|
+
edge_idx = (h, r, t)
|
193
|
+
edges.add(edge_idx)
|
194
|
+
|
195
|
+
return cls(list(nodes), list(edges))
|
196
|
+
|
197
|
+
@classmethod
|
198
|
+
def collate(cls,
|
199
|
+
graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer":
|
200
|
+
r"""Combines a series of large graph indexes into a single large graph
|
201
|
+
index.
|
202
|
+
|
203
|
+
Args:
|
204
|
+
graphs (Iterable["LargeGraphIndexer"]): Indices to be
|
205
|
+
combined.
|
206
|
+
|
207
|
+
Returns:
|
208
|
+
LargeGraphIndexer: Singular unique index for all nodes and edges
|
209
|
+
in input indices.
|
210
|
+
"""
|
211
|
+
# FIXME Needs to merge node attrs and edge attrs?
|
212
|
+
trips = chain.from_iterable([graph.to_triplets() for graph in graphs])
|
213
|
+
return cls.from_triplets(trips)
|
214
|
+
|
215
|
+
def get_unique_node_features(
|
216
|
+
self, feature_name: str = NODE_PID) -> List[Hashable]:
|
217
|
+
r"""Get all the unique values for a specific node attribute.
|
218
|
+
|
219
|
+
Args:
|
220
|
+
feature_name (str, optional): Name of feature to get.
|
221
|
+
Defaults to NODE_PID.
|
222
|
+
|
223
|
+
Returns:
|
224
|
+
List[Hashable]: List of unique values for the specified feature.
|
225
|
+
"""
|
226
|
+
try:
|
227
|
+
if feature_name in self._mapped_node_features:
|
228
|
+
raise IndexError(
|
229
|
+
"Only non-mapped features can be retrieved uniquely.")
|
230
|
+
return ordered_set(self.get_node_features(feature_name))
|
231
|
+
|
232
|
+
except KeyError:
|
233
|
+
raise AttributeError(
|
234
|
+
f"Nodes do not have a feature called {feature_name}")
|
235
|
+
|
236
|
+
def add_node_feature(
|
237
|
+
self,
|
238
|
+
new_feature_name: str,
|
239
|
+
new_feature_vals: FeatureValueType,
|
240
|
+
map_from_feature: str = NODE_PID,
|
241
|
+
) -> None:
|
242
|
+
r"""Adds a new feature that corresponds to each unique node in
|
243
|
+
the graph.
|
244
|
+
|
245
|
+
Args:
|
246
|
+
new_feature_name (str): Name to call the new feature.
|
247
|
+
new_feature_vals (FeatureValueType): Values to map for that
|
248
|
+
new feature.
|
249
|
+
map_from_feature (str, optional): Key of feature to map from.
|
250
|
+
Size must match the number of feature values.
|
251
|
+
Defaults to NODE_PID.
|
252
|
+
"""
|
253
|
+
if new_feature_name in self.node_attr:
|
254
|
+
raise AttributeError("Features cannot be overridden once created")
|
255
|
+
if map_from_feature in self._mapped_node_features:
|
256
|
+
raise AttributeError(
|
257
|
+
f"{map_from_feature} is already a feature mapping.")
|
258
|
+
|
259
|
+
feature_keys = self.get_unique_node_features(map_from_feature)
|
260
|
+
if len(feature_keys) != len(new_feature_vals):
|
261
|
+
raise AttributeError(
|
262
|
+
"Expected encodings for {len(feature_keys)} unique features," +
|
263
|
+
f" but got {len(new_feature_vals)} encodings.")
|
264
|
+
|
265
|
+
if map_from_feature == NODE_PID:
|
266
|
+
self.node_attr[new_feature_name] = new_feature_vals
|
267
|
+
else:
|
268
|
+
self.node_attr[new_feature_name] = MappedFeature(
|
269
|
+
name=map_from_feature, values=new_feature_vals)
|
270
|
+
self._mapped_node_features.add(new_feature_name)
|
271
|
+
|
272
|
+
def get_node_features(
|
273
|
+
self,
|
274
|
+
feature_name: str = NODE_PID,
|
275
|
+
pids: Optional[Iterable[Hashable]] = None,
|
276
|
+
) -> List[Any]:
|
277
|
+
r"""Get node feature values for a given set of unique node ids.
|
278
|
+
Returned values are not necessarily unique.
|
279
|
+
|
280
|
+
Args:
|
281
|
+
feature_name (str, optional): Name of feature to fetch. Defaults
|
282
|
+
to NODE_PID.
|
283
|
+
pids (Optional[Iterable[Hashable]], optional): Node ids to fetch
|
284
|
+
for. Defaults to None, which fetches all nodes.
|
285
|
+
|
286
|
+
Returns:
|
287
|
+
List[Any]: Node features corresponding to the specified ids.
|
288
|
+
"""
|
289
|
+
if feature_name in self._mapped_node_features:
|
290
|
+
values = self.node_attr[feature_name].values
|
291
|
+
else:
|
292
|
+
values = self.node_attr[feature_name]
|
293
|
+
|
294
|
+
# TODO: torch_geometric.utils.select
|
295
|
+
if isinstance(values, torch.Tensor):
|
296
|
+
idxs = list(
|
297
|
+
self.get_node_features_iter(feature_name, pids,
|
298
|
+
index_only=True))
|
299
|
+
return values[idxs]
|
300
|
+
return list(self.get_node_features_iter(feature_name, pids))
|
301
|
+
|
302
|
+
def get_node_features_iter(
|
303
|
+
self,
|
304
|
+
feature_name: str = NODE_PID,
|
305
|
+
pids: Optional[Iterable[Hashable]] = None,
|
306
|
+
index_only: bool = False,
|
307
|
+
) -> Iterator[Any]:
|
308
|
+
"""Iterator version of get_node_features. If index_only is True,
|
309
|
+
yields indices instead of values.
|
310
|
+
"""
|
311
|
+
if pids is None:
|
312
|
+
pids = self.node_attr[NODE_PID]
|
313
|
+
|
314
|
+
if feature_name in self._mapped_node_features:
|
315
|
+
feature_map_info = self.node_attr[feature_name]
|
316
|
+
from_feature_name, to_feature_vals = (
|
317
|
+
feature_map_info.name,
|
318
|
+
feature_map_info.values,
|
319
|
+
)
|
320
|
+
from_feature_vals = self.get_unique_node_features(
|
321
|
+
from_feature_name)
|
322
|
+
feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}
|
323
|
+
|
324
|
+
for pid in pids:
|
325
|
+
idx = self._nodes[pid]
|
326
|
+
from_feature_val = self.node_attr[from_feature_name][idx]
|
327
|
+
to_feature_idx = feature_mapping[from_feature_val]
|
328
|
+
if index_only:
|
329
|
+
yield to_feature_idx
|
330
|
+
else:
|
331
|
+
yield to_feature_vals[to_feature_idx]
|
332
|
+
else:
|
333
|
+
for pid in pids:
|
334
|
+
idx = self._nodes[pid]
|
335
|
+
if index_only:
|
336
|
+
yield idx
|
337
|
+
else:
|
338
|
+
yield self.node_attr[feature_name][idx]
|
339
|
+
|
340
|
+
def get_unique_edge_features(
|
341
|
+
self, feature_name: str = EDGE_PID) -> List[Hashable]:
|
342
|
+
r"""Get all the unique values for a specific edge attribute.
|
343
|
+
|
344
|
+
Args:
|
345
|
+
feature_name (str, optional): Name of feature to get.
|
346
|
+
Defaults to EDGE_PID.
|
347
|
+
|
348
|
+
Returns:
|
349
|
+
List[Hashable]: List of unique values for the specified feature.
|
350
|
+
"""
|
351
|
+
try:
|
352
|
+
if feature_name in self._mapped_edge_features:
|
353
|
+
raise IndexError(
|
354
|
+
"Only non-mapped features can be retrieved uniquely.")
|
355
|
+
return ordered_set(self.get_edge_features(feature_name))
|
356
|
+
except KeyError:
|
357
|
+
raise AttributeError(
|
358
|
+
f"Edges do not have a feature called {feature_name}")
|
359
|
+
|
360
|
+
def add_edge_feature(
|
361
|
+
self,
|
362
|
+
new_feature_name: str,
|
363
|
+
new_feature_vals: FeatureValueType,
|
364
|
+
map_from_feature: str = EDGE_PID,
|
365
|
+
) -> None:
|
366
|
+
r"""Adds a new feature that corresponds to each unique edge in
|
367
|
+
the graph.
|
368
|
+
|
369
|
+
Args:
|
370
|
+
new_feature_name (str): Name to call the new feature.
|
371
|
+
new_feature_vals (FeatureValueType): Values to map for that new
|
372
|
+
feature.
|
373
|
+
map_from_feature (str, optional): Key of feature to map from.
|
374
|
+
Size must match the number of feature values.
|
375
|
+
Defaults to EDGE_PID.
|
376
|
+
"""
|
377
|
+
if new_feature_name in self.edge_attr:
|
378
|
+
raise AttributeError("Features cannot be overridden once created")
|
379
|
+
if map_from_feature in self._mapped_edge_features:
|
380
|
+
raise AttributeError(
|
381
|
+
f"{map_from_feature} is already a feature mapping.")
|
382
|
+
|
383
|
+
feature_keys = self.get_unique_edge_features(map_from_feature)
|
384
|
+
if len(feature_keys) != len(new_feature_vals):
|
385
|
+
raise AttributeError(
|
386
|
+
f"Expected encodings for {len(feature_keys)} unique features, "
|
387
|
+
+ f"but got {len(new_feature_vals)} encodings.")
|
388
|
+
|
389
|
+
if map_from_feature == EDGE_PID:
|
390
|
+
self.edge_attr[new_feature_name] = new_feature_vals
|
391
|
+
else:
|
392
|
+
self.edge_attr[new_feature_name] = MappedFeature(
|
393
|
+
name=map_from_feature, values=new_feature_vals)
|
394
|
+
self._mapped_edge_features.add(new_feature_name)
|
395
|
+
|
396
|
+
def get_edge_features(
|
397
|
+
self,
|
398
|
+
feature_name: str = EDGE_PID,
|
399
|
+
pids: Optional[Iterable[Hashable]] = None,
|
400
|
+
) -> List[Any]:
|
401
|
+
r"""Get edge feature values for a given set of unique edge ids.
|
402
|
+
Returned values are not necessarily unique.
|
403
|
+
|
404
|
+
Args:
|
405
|
+
feature_name (str, optional): Name of feature to fetch.
|
406
|
+
Defaults to EDGE_PID.
|
407
|
+
pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch
|
408
|
+
for. Defaults to None, which fetches all edges.
|
409
|
+
|
410
|
+
Returns:
|
411
|
+
List[Any]: Node features corresponding to the specified ids.
|
412
|
+
"""
|
413
|
+
if feature_name in self._mapped_edge_features:
|
414
|
+
values = self.edge_attr[feature_name].values
|
415
|
+
else:
|
416
|
+
values = self.edge_attr[feature_name]
|
417
|
+
|
418
|
+
# TODO: torch_geometric.utils.select
|
419
|
+
if isinstance(values, torch.Tensor):
|
420
|
+
idxs = list(
|
421
|
+
self.get_edge_features_iter(feature_name, pids,
|
422
|
+
index_only=True))
|
423
|
+
return values[idxs]
|
424
|
+
return list(self.get_edge_features_iter(feature_name, pids))
|
425
|
+
|
426
|
+
def get_edge_features_iter(
|
427
|
+
self,
|
428
|
+
feature_name: str = EDGE_PID,
|
429
|
+
pids: Optional[KnowledgeGraphLike] = None,
|
430
|
+
index_only: bool = False,
|
431
|
+
) -> Iterator[Any]:
|
432
|
+
"""Iterator version of get_edge_features. If index_only is True,
|
433
|
+
yields indices instead of values.
|
434
|
+
"""
|
435
|
+
if pids is None:
|
436
|
+
pids = self.edge_attr[EDGE_PID]
|
437
|
+
|
438
|
+
if feature_name in self._mapped_edge_features:
|
439
|
+
feature_map_info = self.edge_attr[feature_name]
|
440
|
+
from_feature_name, to_feature_vals = (
|
441
|
+
feature_map_info.name,
|
442
|
+
feature_map_info.values,
|
443
|
+
)
|
444
|
+
from_feature_vals = self.get_unique_edge_features(
|
445
|
+
from_feature_name)
|
446
|
+
feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}
|
447
|
+
|
448
|
+
for pid in pids:
|
449
|
+
idx = self._edges[pid]
|
450
|
+
from_feature_val = self.edge_attr[from_feature_name][idx]
|
451
|
+
to_feature_idx = feature_mapping[from_feature_val]
|
452
|
+
if index_only:
|
453
|
+
yield to_feature_idx
|
454
|
+
else:
|
455
|
+
yield to_feature_vals[to_feature_idx]
|
456
|
+
else:
|
457
|
+
for pid in pids:
|
458
|
+
idx = self._edges[pid]
|
459
|
+
if index_only:
|
460
|
+
yield idx
|
461
|
+
else:
|
462
|
+
yield self.edge_attr[feature_name][idx]
|
463
|
+
|
464
|
+
def to_triplets(self) -> Iterator[TripletLike]:
|
465
|
+
return iter(self.edge_attr[EDGE_PID])
|
466
|
+
|
467
|
+
def save(self, path: str) -> None:
|
468
|
+
if os.path.exists(path):
|
469
|
+
shutil.rmtree(path)
|
470
|
+
os.makedirs(path, exist_ok=True)
|
471
|
+
with open(path + "/edges", "wb") as f:
|
472
|
+
pkl.dump(self._edges, f)
|
473
|
+
with open(path + "/nodes", "wb") as f:
|
474
|
+
pkl.dump(self._nodes, f)
|
475
|
+
|
476
|
+
with open(path + "/mapped_edges", "wb") as f:
|
477
|
+
pkl.dump(self._mapped_edge_features, f)
|
478
|
+
with open(path + "/mapped_nodes", "wb") as f:
|
479
|
+
pkl.dump(self._mapped_node_features, f)
|
480
|
+
|
481
|
+
node_attr_path = path + "/node_attr"
|
482
|
+
os.makedirs(node_attr_path, exist_ok=True)
|
483
|
+
for attr_name, vals in self.node_attr.items():
|
484
|
+
torch.save(vals, node_attr_path + f"/{attr_name}.pt")
|
485
|
+
|
486
|
+
edge_attr_path = path + "/edge_attr"
|
487
|
+
os.makedirs(edge_attr_path, exist_ok=True)
|
488
|
+
for attr_name, vals in self.edge_attr.items():
|
489
|
+
torch.save(vals, edge_attr_path + f"/{attr_name}.pt")
|
490
|
+
|
491
|
+
@classmethod
|
492
|
+
def from_disk(cls, path: str) -> "LargeGraphIndexer":
|
493
|
+
indexer = cls(list(), list())
|
494
|
+
with open(path + "/edges", "rb") as f:
|
495
|
+
indexer._edges = pkl.load(f)
|
496
|
+
with open(path + "/nodes", "rb") as f:
|
497
|
+
indexer._nodes = pkl.load(f)
|
498
|
+
|
499
|
+
with open(path + "/mapped_edges", "rb") as f:
|
500
|
+
indexer._mapped_edge_features = pkl.load(f)
|
501
|
+
with open(path + "/mapped_nodes", "rb") as f:
|
502
|
+
indexer._mapped_node_features = pkl.load(f)
|
503
|
+
|
504
|
+
node_attr_path = path + "/node_attr"
|
505
|
+
for fname in os.listdir(node_attr_path):
|
506
|
+
full_fname = f"{node_attr_path}/{fname}"
|
507
|
+
key = fname.split(".")[0]
|
508
|
+
indexer.node_attr[key] = torch.load(full_fname)
|
509
|
+
|
510
|
+
edge_attr_path = path + "/edge_attr"
|
511
|
+
for fname in os.listdir(edge_attr_path):
|
512
|
+
full_fname = f"{edge_attr_path}/{fname}"
|
513
|
+
key = fname.split(".")[0]
|
514
|
+
indexer.edge_attr[key] = torch.load(full_fname)
|
515
|
+
|
516
|
+
return indexer
|
517
|
+
|
518
|
+
def to_data(self, node_feature_name: str,
|
519
|
+
edge_feature_name: Optional[str] = None) -> Data:
|
520
|
+
"""Return a Data object containing all the specified node and
|
521
|
+
edge features and the graph.
|
522
|
+
|
523
|
+
Args:
|
524
|
+
node_feature_name (str): Feature to use for nodes
|
525
|
+
edge_feature_name (Optional[str], optional): Feature to use for
|
526
|
+
edges. Defaults to None.
|
527
|
+
|
528
|
+
Returns:
|
529
|
+
Data: Data object containing the specified node and
|
530
|
+
edge features and the graph.
|
531
|
+
"""
|
532
|
+
x = torch.Tensor(self.get_node_features(node_feature_name))
|
533
|
+
node_id = torch.LongTensor(range(len(x)))
|
534
|
+
|
535
|
+
edge_index = torch.t(
|
536
|
+
torch.LongTensor(self.get_edge_features(EDGE_INDEX)))
|
537
|
+
|
538
|
+
edge_attr = (self.get_edge_features(edge_feature_name)
|
539
|
+
if edge_feature_name is not None else None)
|
540
|
+
edge_id = torch.LongTensor(range(len(edge_attr)))
|
541
|
+
|
542
|
+
return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
|
543
|
+
edge_id=edge_id, node_id=node_id)
|
544
|
+
|
545
|
+
def __eq__(self, value: "LargeGraphIndexer") -> bool:
|
546
|
+
eq = True
|
547
|
+
eq &= self._nodes == value._nodes
|
548
|
+
eq &= self._edges == value._edges
|
549
|
+
eq &= self.node_attr.keys() == value.node_attr.keys()
|
550
|
+
eq &= self.edge_attr.keys() == value.edge_attr.keys()
|
551
|
+
eq &= self._mapped_node_features == value._mapped_node_features
|
552
|
+
eq &= self._mapped_edge_features == value._mapped_edge_features
|
553
|
+
|
554
|
+
for k in self.node_attr:
|
555
|
+
eq &= isinstance(self.node_attr[k], type(value.node_attr[k]))
|
556
|
+
if isinstance(self.node_attr[k], torch.Tensor):
|
557
|
+
eq &= torch.equal(self.node_attr[k], value.node_attr[k])
|
558
|
+
else:
|
559
|
+
eq &= self.node_attr[k] == value.node_attr[k]
|
560
|
+
for k in self.edge_attr:
|
561
|
+
eq &= isinstance(self.edge_attr[k], type(value.edge_attr[k]))
|
562
|
+
if isinstance(self.edge_attr[k], torch.Tensor):
|
563
|
+
eq &= torch.equal(self.edge_attr[k], value.edge_attr[k])
|
564
|
+
else:
|
565
|
+
eq &= self.edge_attr[k] == value.edge_attr[k]
|
566
|
+
return eq
|
567
|
+
|
568
|
+
|
569
|
+
def get_features_for_triplets_groups(
    indexer: LargeGraphIndexer,
    triplet_groups: Iterable[KnowledgeGraphLike],
    node_feature_name: str = "x",
    edge_feature_name: str = "edge_attr",
    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
    verbose: bool = False,
) -> Iterator[Data]:
    """Given an indexer and a series of triplet groups (like a dataset),
    retrieve the specified node and edge features for each triplet from the
    index.

    Args:
        indexer (LargeGraphIndexer): Indexer containing desired features
        triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of
            triplets to fetch features for
        node_feature_name (str, optional): Node feature to fetch.
            Defaults to "x".
        edge_feature_name (str, optional): edge feature to fetch.
            Defaults to "edge_attr".
        pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
            Optional preprocessing to perform on triplets.
            Defaults to None.
        verbose (bool, optional): Whether to print progress. Defaults to False.

    Yields:
        Iterator[Data]: For each triplet group, yield a data object containing
            the unique graph and features from the index.
    """
    if pre_transform is not None:

        def apply_transform(trips):
            for trip in trips:
                yield pre_transform(tuple(trip))

        # TODO: Make this safe for large amounts of triplets?
        triplet_groups = (list(apply_transform(triplets))
                          for triplets in triplet_groups)

    # Per-group node ids, edge ids, and local edge indices, gathered first so
    # that all feature lookups against the (large) indexer can be batched.
    node_keys = []
    edge_keys = []
    edge_index = []

    for triplets in tqdm(triplet_groups, disable=not verbose):
        # NOTE(review): `pre_transform` was already applied to the triplets
        # above; passing it again here re-applies it inside `from_triplets`.
        # Harmless for idempotent transforms — TODO confirm intent.
        small_graph_indexer = LargeGraphIndexer.from_triplets(
            triplets, pre_transform=pre_transform)

        node_keys.append(small_graph_indexer.get_node_features())
        edge_keys.append(small_graph_indexer.get_edge_features(pids=triplets))
        edge_index.append(
            small_graph_indexer.get_edge_features(EDGE_INDEX, triplets))

    # One batched fetch per feature across all groups; sliced back out below.
    node_feats = indexer.get_node_features(feature_name=node_feature_name,
                                           pids=chain.from_iterable(node_keys))
    edge_feats = indexer.get_edge_features(feature_name=edge_feature_name,
                                           pids=chain.from_iterable(edge_keys))

    last_node_idx, last_edge_idx = 0, 0
    for (nkeys, ekeys, eidx) in zip(node_keys, edge_keys, edge_index):
        nlen, elen = len(nkeys), len(ekeys)
        x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])
        last_node_idx += nlen

        edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
                                            elen])
        last_edge_idx += elen

        edge_idx = torch.LongTensor(eidx).T

        data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
        # Fix: attach the *per-group* node/edge ids, not the outer lists of
        # all groups (consistent with `node_idx`/`edge_idx` below).
        data_obj[NODE_PID] = nkeys
        data_obj[EDGE_PID] = ekeys
        data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
        data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]

        yield data_obj
|
645
|
+
|
646
|
+
|
647
|
+
def get_features_for_triplets(
    indexer: LargeGraphIndexer,
    triplets: KnowledgeGraphLike,
    node_feature_name: str = "x",
    edge_feature_name: str = "edge_attr",
    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
    verbose: bool = False,
) -> Data:
    """For a given set of triplets retrieve a Data object containing the
    unique graph and features from the index.

    Args:
        indexer (LargeGraphIndexer): Indexer containing desired features
        triplets (KnowledgeGraphLike): Triplets to fetch features for
        node_feature_name (str, optional): Feature to use for node features.
            Defaults to "x".
        edge_feature_name (str, optional): Feature to use for edge features.
            Defaults to "edge_attr".
        pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
            Optional preprocessing function for triplets. Defaults to None.
        verbose (bool, optional): Whether to print progress. Defaults to False.

    Returns:
        Data: Data object containing the unique graph and features from the
            index for the given triplets.
    """
    # Delegate to the grouped variant with a single group and take its
    # first (and only) yielded data object.
    data_iter = get_features_for_triplets_groups(
        indexer,
        [triplets],
        node_feature_name=node_feature_name,
        edge_feature_name=edge_feature_name,
        pre_transform=pre_transform,
        verbose=verbose,
    )
    return next(data_iter)
|
@@ -22,6 +22,7 @@ from .dynamic_batch_sampler import DynamicBatchSampler
|
|
22
22
|
from .prefetch import PrefetchLoader
|
23
23
|
from .cache import CachedLoader
|
24
24
|
from .mixin import AffinityMixin
|
25
|
+
from .rag_loader import RAGQueryLoader
|
25
26
|
|
26
27
|
__all__ = classes = [
|
27
28
|
'DataLoader',
|
@@ -50,6 +51,7 @@ __all__ = classes = [
|
|
50
51
|
'PrefetchLoader',
|
51
52
|
'CachedLoader',
|
52
53
|
'AffinityMixin',
|
54
|
+
'RAGQueryLoader',
|
53
55
|
]
|
54
56
|
|
55
57
|
RandomNodeSampler = deprecated(
|
@@ -0,0 +1,106 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
|
3
|
+
|
4
|
+
from torch_geometric.data import Data, FeatureStore, HeteroData
|
5
|
+
from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
|
6
|
+
from torch_geometric.typing import InputEdges, InputNodes
|
7
|
+
|
8
|
+
|
9
|
+
class RAGFeatureStore(Protocol):
    """Feature store protocol for a remote GNN RAG backend.

    Implementations select query-relevant seed nodes/edges and materialize
    sampled subgraphs together with their features.
    """
    @abstractmethod
    def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
        """Makes a comparison between the query and all the nodes to get all
        the closest nodes. Return the indices of the nodes that are to be seeds
        for the RAG Sampler.
        """
        ...

    @abstractmethod
    def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
        """Makes a comparison between the query and all the edges to get all
        the closest edges. Returns the edge indices that are to be the seeds
        for the RAG Sampler.
        """
        ...

    @abstractmethod
    def load_subgraph(
        self, sample: Union[SamplerOutput, HeteroSamplerOutput]
    ) -> Union[Data, HeteroData]:
        """Combines sampled subgraph output with features in a Data object."""
        ...
|
33
|
+
|
34
|
+
|
35
|
+
class RAGGraphStore(Protocol):
    """Graph store protocol for a remote GNN RAG backend.

    Implementations sample a subgraph from seeded nodes/edges; the sampler
    may consult a registered feature store for the metadata it needs.
    """
    @abstractmethod
    def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
                        **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Sample a subgraph using the seeded nodes and edges."""
        ...

    @abstractmethod
    def register_feature_store(self, feature_store: FeatureStore):
        """Register a feature store to be used with the sampler. Samplers need
        info from the feature store in order to work properly on HeteroGraphs.
        """
        ...
|
49
|
+
|
50
|
+
|
51
|
+
# TODO: Make compatible with Heterographs
|
52
|
+
|
53
|
+
|
54
|
+
class RAGQueryLoader:
    def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
                 local_filter: Optional[Callable[[Data, Any], Data]] = None,
                 seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
                 seed_edges_kwargs: Optional[Dict[str, Any]] = None,
                 sampler_kwargs: Optional[Dict[str, Any]] = None,
                 loader_kwargs: Optional[Dict[str, Any]] = None):
        """Loader meant for making queries from a remote backend.

        Args:
            data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore
                and GraphStore to load from. Assumed to conform to the
                protocols listed above.
            local_filter (Optional[Callable[[Data, Any], Data]], optional):
                Optional local transform to apply to data after retrieval.
                Defaults to None.
            seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Parameters
                to pass into process for fetching seed nodes. Defaults to None.
            seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters
                to pass into process for fetching seed edges. Defaults to None.
            sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to
                pass into process for sampling graph. Defaults to None.
            loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to
                pass into process for loading graph features. Defaults to None.
        """
        self.feature_store, self.graph_store = data
        # The sampler may need feature-store metadata (e.g. for HeteroGraphs).
        self.graph_store.register_feature_store(self.feature_store)
        self.local_filter = local_filter
        self.seed_nodes_kwargs = seed_nodes_kwargs or {}
        self.seed_edges_kwargs = seed_edges_kwargs or {}
        self.sampler_kwargs = sampler_kwargs or {}
        self.loader_kwargs = loader_kwargs or {}

    def query(self, query: Any) -> Data:
        """Retrieve a subgraph associated with the query with all its feature
        attributes.
        """
        # 1) Seed selection against the remote feature store.
        nodes = self.feature_store.retrieve_seed_nodes(
            query, **self.seed_nodes_kwargs)
        edges = self.feature_store.retrieve_seed_edges(
            query, **self.seed_edges_kwargs)

        # 2) Subgraph sampling from the seeds.
        sample = self.graph_store.sample_subgraph(nodes, edges,
                                                  **self.sampler_kwargs)

        # 3) Feature materialization for the sampled subgraph.
        data = self.feature_store.load_subgraph(sample=sample,
                                                **self.loader_kwargs)

        # 4) Optional user-supplied post-processing.
        if self.local_filter:
            data = self.local_filter(data, query)
        return data
|
@@ -21,6 +21,8 @@ class GRetriever(torch.nn.Module):
|
|
21
21
|
(default: :obj:`False`)
|
22
22
|
mlp_out_channels (int, optional): The size of each graph embedding
|
23
23
|
after projection. (default: :obj:`4096`)
|
24
|
+
mlp_out_tokens (int, optional): Number of LLM prefix tokens to
|
25
|
+
reserve for GNN output. (default: :obj:`1`)
|
24
26
|
|
25
27
|
.. warning::
|
26
28
|
This module has been tested with the following HuggingFace models
|
@@ -43,6 +45,7 @@ class GRetriever(torch.nn.Module):
|
|
43
45
|
gnn: torch.nn.Module,
|
44
46
|
use_lora: bool = False,
|
45
47
|
mlp_out_channels: int = 4096,
|
48
|
+
mlp_out_tokens: int = 1,
|
46
49
|
) -> None:
|
47
50
|
super().__init__()
|
48
51
|
|
@@ -77,7 +80,9 @@ class GRetriever(torch.nn.Module):
|
|
77
80
|
self.projector = torch.nn.Sequential(
|
78
81
|
torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
|
79
82
|
torch.nn.Sigmoid(),
|
80
|
-
torch.nn.Linear(mlp_hidden_channels,
|
83
|
+
torch.nn.Linear(mlp_hidden_channels,
|
84
|
+
mlp_out_channels * mlp_out_tokens),
|
85
|
+
torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
|
81
86
|
).to(self.llm.device)
|
82
87
|
|
83
88
|
def encode(
|
@@ -126,6 +131,9 @@ class GRetriever(torch.nn.Module):
|
|
126
131
|
x = self.projector(x)
|
127
132
|
xs = x.split(1, dim=0)
|
128
133
|
|
134
|
+
        # Handle case where there's more than one embedding for each sample
|
135
|
+
xs = [x.squeeze(0) for x in xs]
|
136
|
+
|
129
137
|
# Handle questions without node features:
|
130
138
|
batch_unique = batch.unique()
|
131
139
|
batch_size = len(question)
|
@@ -182,6 +190,9 @@ class GRetriever(torch.nn.Module):
|
|
182
190
|
x = self.projector(x)
|
183
191
|
xs = x.split(1, dim=0)
|
184
192
|
|
193
|
+
        # Handle case where there's more than one embedding for each sample
|
194
|
+
xs = [x.squeeze(0) for x in xs]
|
195
|
+
|
185
196
|
# Handle questions without node features:
|
186
197
|
batch_unique = batch.unique()
|
187
198
|
batch_size = len(question)
|
@@ -20,6 +20,7 @@ from .utils import (
|
|
20
20
|
get_gpu_memory_from_nvidia_smi,
|
21
21
|
get_model_size,
|
22
22
|
)
|
23
|
+
from .nvtx import nvtxit
|
23
24
|
|
24
25
|
__all__ = [
|
25
26
|
'profileit',
|
@@ -38,6 +39,7 @@ __all__ = [
|
|
38
39
|
'get_gpu_memory_from_nvidia_smi',
|
39
40
|
'get_gpu_memory_from_ipex',
|
40
41
|
'benchmark',
|
42
|
+
'nvtxit',
|
41
43
|
]
|
42
44
|
|
43
45
|
classes = __all__
|
@@ -0,0 +1,66 @@
|
|
1
|
+
from functools import wraps
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
import torch
|
5
|
+
|
6
|
+
# Module-level flag: True while cudaProfilerStart() has been issued, so
# nested profiled frames do not start/stop the profiler repeatedly.
CUDA_PROFILE_STARTED = False


def begin_cuda_profile():
    """Start the CUDA profiler unless it is already running.

    Returns the previous value of the started-flag so that a matching
    ``end_cuda_profile`` call can restore it.
    """
    global CUDA_PROFILE_STARTED
    was_started = CUDA_PROFILE_STARTED
    if not was_started:
        CUDA_PROFILE_STARTED = True
        torch.cuda.cudart().cudaProfilerStart()
    return was_started
|
16
|
+
|
17
|
+
|
18
|
+
def end_cuda_profile(prev_state: bool):
    """Restore the profiler flag and stop the CUDA profiler if this frame
    was the one that started it (i.e. ``prev_state`` is False).
    """
    global CUDA_PROFILE_STARTED
    CUDA_PROFILE_STARTED = prev_state
    if not prev_state:
        torch.cuda.cudart().cudaProfilerStop()
|
23
|
+
|
24
|
+
|
25
|
+
def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
           n_iters: Optional[int] = None):
    """Enables NVTX profiling for a function.

    Args:
        name (Optional[str], optional): Name to give the reference frame for
            the function being wrapped. Defaults to the name of the
            function in code.
        n_warmups (int, optional): Number of iters to call that function
            before starting. Defaults to 0.
        n_iters (Optional[int], optional): Number of iters of that function to
            record. Defaults to all of them.
    """
    def nvtx(func):

        nonlocal name
        call_count = 0
        if name is None:
            name = func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal call_count
            # Without CUDA there is nothing to profile: plain passthrough.
            if not torch.cuda.is_available():
                return func(*args, **kwargs)
            # Warm-up calls are counted but not recorded.
            if call_count < n_warmups:
                call_count += 1
                return func(*args, **kwargs)
            # Record this call if still within the requested iteration budget.
            if n_iters is None or call_count < n_iters + n_warmups:
                prev_state = begin_cuda_profile()
                torch.cuda.nvtx.range_push(f"{name}_{call_count}")
                result = func(*args, **kwargs)
                torch.cuda.nvtx.range_pop()
                end_cuda_profile(prev_state)
                call_count += 1
                return result
            # Budget exhausted: passthrough without profiling.
            return func(*args, **kwargs)

        return wrapper

    return nvtx
|
File without changes
|