pyg-nightly 2.7.0.dev20241125__py3-none-any.whl → 2.7.0.dev20241126__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20241125
3
+ Version: 2.7.0.dev20241126
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=hPMlzqznHr3x2xBZhYBnmC-i7KVOX-tIpw1gy43En6g,1904
1
+ torch_geometric/__init__.py,sha256=byc0Xe43_b_bDlW1ufoO-jvXP0Uu6SzYXuRtGKXmqyw,1904
2
2
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -29,7 +29,7 @@ torch_geometric/contrib/nn/conv/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2e
29
29
  torch_geometric/contrib/nn/models/__init__.py,sha256=3ia5cX-TPhouLl6jn_HA-Rd2LaaQvFgy5CjRk0ovKRU,113
30
30
  torch_geometric/contrib/nn/models/rbcd_attack.py,sha256=qcyxBxAbx8LKzpp3RoJQ0cxl9aB2onsWT4oY1fsM7us,33280
31
31
  torch_geometric/contrib/transforms/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2ezDjbj9n9nCpvtk,23
32
- torch_geometric/data/__init__.py,sha256=OLkV82AGm6xMSynT_DHfRE6_INfPxLx4BQnY0-WVn54,4323
32
+ torch_geometric/data/__init__.py,sha256=D6Iz5A9vEb_2rpf96Zn7uM-lchZ3WpW8X7WdAD1yxKw,4565
33
33
  torch_geometric/data/batch.py,sha256=C9cT7-rcWPgnG68Eb_uAcn90HS3OvOG6n4fY3ihpFhI,8764
34
34
  torch_geometric/data/collate.py,sha256=RRiUMBLxDAitaHx7zF0qiMR2nW1NY_0uaNdxlUo5-bo,12756
35
35
  torch_geometric/data/data.py,sha256=l_gHy18g9WtiSCm1mDinR4vGrZOLetogrw5wJEcn23E,43807
@@ -43,6 +43,7 @@ torch_geometric/data/graph_store.py,sha256=oFrLDNP5hKf3HWWsFsjcamx5vLIEk8JnLjuGp
43
43
  torch_geometric/data/hetero_data.py,sha256=q0L3bENyEvo_BGLPwZPVzh730Aak6sQ7yXoawPgM72E,47982
44
44
  torch_geometric/data/hypergraph_data.py,sha256=33hsXW25Yz4Ju8mKajYinZOrkqrUi1SqThG7MlOOYNM,8294
45
45
  torch_geometric/data/in_memory_dataset.py,sha256=F35hU9Dw3qiJUL5E1CCAfq-1xrlUMstXBmQVEQdtJ1I,13403
46
+ torch_geometric/data/large_graph_indexer.py,sha256=JqozKbn5C-jLq2uydeImWqihvBRg8nl5Al55V5s53aw,25433
46
47
  torch_geometric/data/makedirs.py,sha256=6uOv4y34i947cm4rv7Aj2_YZBq-EOsyPKnlGA188YSw,463
47
48
  torch_geometric/data/on_disk_dataset.py,sha256=77om-e6kzcpBb77kf7um1xY8-yHmQaao_6R7I-3NwHk,6629
48
49
  torch_geometric/data/remote_backend_utils.py,sha256=Rzpq1PczXuHhUscrFtIAL6dua6pMehSJlXG7yEsrrrg,4503
@@ -261,7 +262,7 @@ torch_geometric/io/ply.py,sha256=NdeTtr79vJ1HS37ZV2N61EUmA5NGJd2I6cUj1Pg7Ypg,489
261
262
  torch_geometric/io/sdf.py,sha256=H2PC6dSW9Kncc1ulb0UN0JnTRT93NY2fY8lf6K4hb50,1165
262
263
  torch_geometric/io/tu.py,sha256=-v5Ago7DfmGTRBtB5RZFvmv4XpLnKKnk-NOnxlHtB_c,4881
263
264
  torch_geometric/io/txt_array.py,sha256=LDeX2qtlNKW-kVe-wpnskMwAdXQp1jVCGQnrJce7Smg,910
264
- torch_geometric/loader/__init__.py,sha256=w9LSTbyrLRkyrLXi_10d80csWgfKOKDRQDJXRdcfD0M,1835
265
+ torch_geometric/loader/__init__.py,sha256=o0wC0Gvv4rewpZU_YeVaJZCCZZJQG2v8MfZhjocvKp8,1896
265
266
  torch_geometric/loader/base.py,sha256=ataIwNEYL0px3CN3LJEgXIVTRylDHB6-yBFXXuX2JN0,1615
266
267
  torch_geometric/loader/cache.py,sha256=S65heO3YTyUPbttqizCNtKPHIoAw5iHRpbvw6KlXmok,2106
267
268
  torch_geometric/loader/cluster.py,sha256=eMNxVkvZt5oQ_gJRgmWm1NBX7zU2tZI_BPaXeB0wuyk,13465
@@ -280,6 +281,7 @@ torch_geometric/loader/neighbor_loader.py,sha256=vnLn_RhBKTux5h8pi0vzj0d7JPoOpLA
280
281
  torch_geometric/loader/neighbor_sampler.py,sha256=mraVFXIIGctYot4Xr2VOAhCKAOQyW2gP9KROf7g6tcc,8497
281
282
  torch_geometric/loader/node_loader.py,sha256=g_kV5N0tO6eMSFPc5fdbzfHr4COAeKVJi7FEq52f4zc,11848
282
283
  torch_geometric/loader/prefetch.py,sha256=p1mr54TL4nx3Ea0fBy0JulGYJ8Hq4_9rsiNioZsIW-4,3211
284
+ torch_geometric/loader/rag_loader.py,sha256=nwswemzYL4wCKljXqsxMDg07x6PkLU_kgAkNFj5TwUY,4555
283
285
  torch_geometric/loader/random_node_loader.py,sha256=rCmRXYv70SPxBo-Oh049eFEWEZDV7FmlRPzmjcoirXQ,2196
284
286
  torch_geometric/loader/shadow.py,sha256=_hCspYf9SlJYX0lqEjxFec9e9t1iMScNThOoWR1wQGM,4173
285
287
  torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_yndivXt7Z6_w62zI,2248
@@ -431,7 +433,7 @@ torch_geometric/nn/models/deep_graph_infomax.py,sha256=u6j-5-iHBASDCZ776dyfCI1N8
431
433
  torch_geometric/nn/models/deepgcn.py,sha256=tIgT03cj8MghYlxEozpoGvGG_CwpJrGDxv1Z0CVIUts,4339
432
434
  torch_geometric/nn/models/dimenet.py,sha256=Kc5p-rB5q-0e8lY22l-OdQTscTxJh2lTEpeRFMdL4RY,36186
433
435
  torch_geometric/nn/models/dimenet_utils.py,sha256=Eyn_EiJqwKvuYj6BtRpSxrzMG3v4Gk98X9MxZ7uvwm4,5069
434
- torch_geometric/nn/models/g_retriever.py,sha256=VueRImNJlh1WvRWcsSXliSw8RlxlzWlu2WSFs_VQaJc,7749
436
+ torch_geometric/nn/models/g_retriever.py,sha256=CdSOasnPiMvq5AjduNTpz-LIZiNp3X0xM5sx5MEW8Ok,8258
435
437
  torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
436
438
  torch_geometric/nn/models/glem.py,sha256=gqQF4jlU7U_u5-zGeJZuHiEqhSXa-wLU5TghN4u5fYY,16389
437
439
  torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
@@ -496,8 +498,9 @@ torch_geometric/nn/pool/select/base.py,sha256=On7xaGjMf_ISRvDmpBpJ0krYof0a78XEzS
496
498
  torch_geometric/nn/pool/select/topk.py,sha256=R1LTjOvanJqlrcDe0qinqz286qOJpmjC1tPeiQdPGcU,5305
497
499
  torch_geometric/nn/unpool/__init__.py,sha256=J6I3abNR1MRxisXzbX3sBRH-hlMpmUe7FVc3UziZ67s,129
498
500
  torch_geometric/nn/unpool/knn_interpolate.py,sha256=8GlKoB-wzZz6ETJP7SsKHbzwenr4JiPg6sK3uh9I6R8,2586
499
- torch_geometric/profile/__init__.py,sha256=G-GJ-sIFctmEGepDrbr5-ETWSxZIjsHLS7XzxiOQJ1E,863
501
+ torch_geometric/profile/__init__.py,sha256=R5dQw0vA5Ukf6FrJgptNMCOb__kcgwnxFThA8qaJF8k,902
500
502
  torch_geometric/profile/benchmark.py,sha256=EuD12qJiiPCSwkg5w8arELXiRT_QY_3Wz_rqs7LpDKE,5256
503
+ torch_geometric/profile/nvtx.py,sha256=AKBr-rqlHDnls_UM02Dfq5BZmyFTHS5Li5gaeKmsAJI,2032
501
504
  torch_geometric/profile/profile.py,sha256=cHCY4U0XtyqyKC5u380q6TspsOZ5tGHNXaZsKuzYi1A,11793
502
505
  torch_geometric/profile/profiler.py,sha256=rfNciRzWDka_BgO6aPFi3cy8mcT4lSgFWy-WfPgI2SI,16891
503
506
  torch_geometric/profile/utils.py,sha256=7h6vzTzW8vv-ZqMOz2DV8HHNgC9ViOrN7IR9d3BPDZ8,5497
@@ -626,6 +629,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
626
629
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
627
630
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
628
631
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
629
- pyg_nightly-2.7.0.dev20241125.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
630
- pyg_nightly-2.7.0.dev20241125.dist-info/METADATA,sha256=bDgjxvVn0QZLKMZH40NUhX3W96-XohGqDUXoYJ8Ly3A,62979
631
- pyg_nightly-2.7.0.dev20241125.dist-info/RECORD,,
632
+ pyg_nightly-2.7.0.dev20241126.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
633
+ pyg_nightly-2.7.0.dev20241126.dist-info/METADATA,sha256=kYXIEsYg6p5-3FGWoEJAsjfBZTNbsotYFGiCqV6r4JA,62979
634
+ pyg_nightly-2.7.0.dev20241126.dist-info/RECORD,,
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
30
30
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
31
31
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
32
32
 
33
- __version__ = '2.7.0.dev20241125'
33
+ __version__ = '2.7.0.dev20241126'
34
34
 
35
35
  __all__ = [
36
36
  'Index',
@@ -16,6 +16,7 @@ from .on_disk_dataset import OnDiskDataset
16
16
  from .makedirs import makedirs
17
17
  from .download import download_url, download_google_url
18
18
  from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
19
+ from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups
19
20
 
20
21
  from torch_geometric.lazy_loader import LazyLoader
21
22
 
@@ -27,6 +28,8 @@ data_classes = [
27
28
  'Dataset',
28
29
  'InMemoryDataset',
29
30
  'OnDiskDataset',
31
+ 'LargeGraphIndexer',
32
+ 'TripletLike',
30
33
  ]
31
34
 
32
35
  remote_backend_classes = [
@@ -50,6 +53,8 @@ helper_functions = [
50
53
  'extract_zip',
51
54
  'extract_bz2',
52
55
  'extract_gz',
56
+ 'get_features_for_triplets',
57
+ "get_features_for_triplets_groups",
53
58
  ]
54
59
 
55
60
  __all__ = data_classes + remote_backend_classes + helper_functions
@@ -0,0 +1,677 @@
1
+ import os
2
+ import pickle as pkl
3
+ import shutil
4
+ from dataclasses import dataclass
5
+ from itertools import chain
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ Dict,
10
+ Hashable,
11
+ Iterable,
12
+ Iterator,
13
+ List,
14
+ Optional,
15
+ Sequence,
16
+ Set,
17
+ Tuple,
18
+ Union,
19
+ )
20
+
21
+ import torch
22
+ from torch import Tensor
23
+ from tqdm import tqdm
24
+
25
+ from torch_geometric.data import Data
26
+ from torch_geometric.typing import WITH_PT24
27
+
28
+ TripletLike = Tuple[Hashable, Hashable, Hashable]
29
+
30
+ KnowledgeGraphLike = Iterable[TripletLike]
31
+
32
+
33
+ def ordered_set(values: Iterable[Hashable]) -> List[Hashable]:
34
+ return list(dict.fromkeys(values))
35
+
36
+
37
+ # TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?
38
+
39
+ NODE_PID = "pid"
40
+
41
+ NODE_KEYS = {NODE_PID}
42
+
43
+ EDGE_PID = "e_pid"
44
+ EDGE_HEAD = "h"
45
+ EDGE_RELATION = "r"
46
+ EDGE_TAIL = "t"
47
+ EDGE_INDEX = "edge_idx"
48
+
49
+ EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}
50
+
51
+ FeatureValueType = Union[Sequence[Any], Tensor]
52
+
53
+
54
+ @dataclass
55
+ class MappedFeature:
56
+ name: str
57
+ values: FeatureValueType
58
+
59
+ def __eq__(self, value: "MappedFeature") -> bool:
60
+ eq = self.name == value.name
61
+ if isinstance(self.values, torch.Tensor):
62
+ eq &= torch.equal(self.values, value.values)
63
+ else:
64
+ eq &= self.values == value.values
65
+ return eq
66
+
67
+
68
+ if WITH_PT24:
69
+ torch.serialization.add_safe_globals([MappedFeature])
70
+
71
+
72
+ class LargeGraphIndexer:
73
+ """For a dataset that consists of mulitiple subgraphs that are assumed to
74
+ be part of a much larger graph, collate the values into a large graph store
75
+ to save resources.
76
+ """
77
+ def __init__(
78
+ self,
79
+ nodes: Iterable[Hashable],
80
+ edges: KnowledgeGraphLike,
81
+ node_attr: Optional[Dict[str, List[Any]]] = None,
82
+ edge_attr: Optional[Dict[str, List[Any]]] = None,
83
+ ) -> None:
84
+ r"""Constructs a new index that uniquely catalogs each node and edge
85
+ by id. Not meant to be used directly.
86
+
87
+ Args:
88
+ nodes (Iterable[Hashable]): Node ids in the graph.
89
+ edges (KnowledgeGraphLike): Edge ids in the graph.
90
+ node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
91
+ attribute name and list of their values in order of unique node
92
+ ids. Defaults to None.
93
+ edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping edge
94
+ attribute name and list of their values in order of unique edge
95
+ ids. Defaults to None.
96
+ """
97
+ self._nodes: Dict[Hashable, int] = dict()
98
+ self._edges: Dict[TripletLike, int] = dict()
99
+
100
+ self._mapped_node_features: Set[str] = set()
101
+ self._mapped_edge_features: Set[str] = set()
102
+
103
+ if len(nodes) != len(set(nodes)):
104
+ raise AttributeError("Nodes need to be unique")
105
+ if len(edges) != len(set(edges)):
106
+ raise AttributeError("Edges need to be unique")
107
+
108
+ if node_attr is not None:
109
+ # TODO: Validity checks btw nodes and node_attr
110
+ self.node_attr = node_attr
111
+ if NODE_KEYS & set(self.node_attr.keys()) != NODE_KEYS:
112
+ raise AttributeError(
113
+ "Invalid node_attr object. Missing " +
114
+ f"{NODE_KEYS - set(self.node_attr.keys())}")
115
+ elif self.node_attr[NODE_PID] != nodes:
116
+ raise AttributeError(
117
+ "Nodes provided do not match those in node_attr")
118
+ else:
119
+ self.node_attr = dict()
120
+ self.node_attr[NODE_PID] = nodes
121
+
122
+ for i, node in enumerate(self.node_attr[NODE_PID]):
123
+ self._nodes[node] = i
124
+
125
+ if edge_attr is not None:
126
+ # TODO: Validity checks btw edges and edge_attr
127
+ self.edge_attr = edge_attr
128
+
129
+ if EDGE_KEYS & set(self.edge_attr.keys()) != EDGE_KEYS:
130
+ raise AttributeError(
131
+ "Invalid edge_attr object. Missing " +
132
+ f"{EDGE_KEYS - set(self.edge_attr.keys())}")
133
+ elif self.node_attr[EDGE_PID] != edges:
134
+ raise AttributeError(
135
+ "Edges provided do not match those in edge_attr")
136
+
137
+ else:
138
+ self.edge_attr = dict()
139
+ for default_key in EDGE_KEYS:
140
+ self.edge_attr[default_key] = list()
141
+ self.edge_attr[EDGE_PID] = edges
142
+
143
+ for i, tup in enumerate(edges):
144
+ h, r, t = tup
145
+ self.edge_attr[EDGE_HEAD].append(h)
146
+ self.edge_attr[EDGE_RELATION].append(r)
147
+ self.edge_attr[EDGE_TAIL].append(t)
148
+ self.edge_attr[EDGE_INDEX].append(
149
+ (self._nodes[h], self._nodes[t]))
150
+
151
+ for i, tup in enumerate(edges):
152
+ self._edges[tup] = i
153
+
154
+ @classmethod
155
+ def from_triplets(
156
+ cls,
157
+ triplets: KnowledgeGraphLike,
158
+ pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
159
+ ) -> "LargeGraphIndexer":
160
+ r"""Generate a new index from a series of triplets that represent edge
161
+ relations between nodes.
162
+ Formatted like (source_node, edge, dest_node).
163
+
164
+ Args:
165
+ triplets (KnowledgeGraphLike): Series of triplets representing
166
+ knowledge graph relations.
167
+ pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
168
+ Optional preprocessing function to apply to triplets.
169
+ Defaults to None.
170
+
171
+ Returns:
172
+ LargeGraphIndexer: Index of unique nodes and edges.
173
+ """
174
+ # NOTE: Right now assumes that all trips can be loaded into memory
175
+ nodes = set()
176
+ edges = set()
177
+
178
+ if pre_transform is not None:
179
+
180
+ def apply_transform(
181
+ trips: KnowledgeGraphLike) -> Iterator[TripletLike]:
182
+ for trip in trips:
183
+ yield pre_transform(trip)
184
+
185
+ triplets = apply_transform(triplets)
186
+
187
+ for h, r, t in triplets:
188
+
189
+ for node in (h, t):
190
+ nodes.add(node)
191
+
192
+ edge_idx = (h, r, t)
193
+ edges.add(edge_idx)
194
+
195
+ return cls(list(nodes), list(edges))
196
+
197
+ @classmethod
198
+ def collate(cls,
199
+ graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer":
200
+ r"""Combines a series of large graph indexes into a single large graph
201
+ index.
202
+
203
+ Args:
204
+ graphs (Iterable[&quot;LargeGraphIndexer&quot;]): Indices to be
205
+ combined.
206
+
207
+ Returns:
208
+ LargeGraphIndexer: Singular unique index for all nodes and edges
209
+ in input indices.
210
+ """
211
+ # FIXME Needs to merge node attrs and edge attrs?
212
+ trips = chain.from_iterable([graph.to_triplets() for graph in graphs])
213
+ return cls.from_triplets(trips)
214
+
215
+ def get_unique_node_features(
216
+ self, feature_name: str = NODE_PID) -> List[Hashable]:
217
+ r"""Get all the unique values for a specific node attribute.
218
+
219
+ Args:
220
+ feature_name (str, optional): Name of feature to get.
221
+ Defaults to NODE_PID.
222
+
223
+ Returns:
224
+ List[Hashable]: List of unique values for the specified feature.
225
+ """
226
+ try:
227
+ if feature_name in self._mapped_node_features:
228
+ raise IndexError(
229
+ "Only non-mapped features can be retrieved uniquely.")
230
+ return ordered_set(self.get_node_features(feature_name))
231
+
232
+ except KeyError:
233
+ raise AttributeError(
234
+ f"Nodes do not have a feature called {feature_name}")
235
+
236
+ def add_node_feature(
237
+ self,
238
+ new_feature_name: str,
239
+ new_feature_vals: FeatureValueType,
240
+ map_from_feature: str = NODE_PID,
241
+ ) -> None:
242
+ r"""Adds a new feature that corresponds to each unique node in
243
+ the graph.
244
+
245
+ Args:
246
+ new_feature_name (str): Name to call the new feature.
247
+ new_feature_vals (FeatureValueType): Values to map for that
248
+ new feature.
249
+ map_from_feature (str, optional): Key of feature to map from.
250
+ Size must match the number of feature values.
251
+ Defaults to NODE_PID.
252
+ """
253
+ if new_feature_name in self.node_attr:
254
+ raise AttributeError("Features cannot be overridden once created")
255
+ if map_from_feature in self._mapped_node_features:
256
+ raise AttributeError(
257
+ f"{map_from_feature} is already a feature mapping.")
258
+
259
+ feature_keys = self.get_unique_node_features(map_from_feature)
260
+ if len(feature_keys) != len(new_feature_vals):
261
+ raise AttributeError(
262
+ "Expected encodings for {len(feature_keys)} unique features," +
263
+ f" but got {len(new_feature_vals)} encodings.")
264
+
265
+ if map_from_feature == NODE_PID:
266
+ self.node_attr[new_feature_name] = new_feature_vals
267
+ else:
268
+ self.node_attr[new_feature_name] = MappedFeature(
269
+ name=map_from_feature, values=new_feature_vals)
270
+ self._mapped_node_features.add(new_feature_name)
271
+
272
+ def get_node_features(
273
+ self,
274
+ feature_name: str = NODE_PID,
275
+ pids: Optional[Iterable[Hashable]] = None,
276
+ ) -> List[Any]:
277
+ r"""Get node feature values for a given set of unique node ids.
278
+ Returned values are not necessarily unique.
279
+
280
+ Args:
281
+ feature_name (str, optional): Name of feature to fetch. Defaults
282
+ to NODE_PID.
283
+ pids (Optional[Iterable[Hashable]], optional): Node ids to fetch
284
+ for. Defaults to None, which fetches all nodes.
285
+
286
+ Returns:
287
+ List[Any]: Node features corresponding to the specified ids.
288
+ """
289
+ if feature_name in self._mapped_node_features:
290
+ values = self.node_attr[feature_name].values
291
+ else:
292
+ values = self.node_attr[feature_name]
293
+
294
+ # TODO: torch_geometric.utils.select
295
+ if isinstance(values, torch.Tensor):
296
+ idxs = list(
297
+ self.get_node_features_iter(feature_name, pids,
298
+ index_only=True))
299
+ return values[idxs]
300
+ return list(self.get_node_features_iter(feature_name, pids))
301
+
302
+ def get_node_features_iter(
303
+ self,
304
+ feature_name: str = NODE_PID,
305
+ pids: Optional[Iterable[Hashable]] = None,
306
+ index_only: bool = False,
307
+ ) -> Iterator[Any]:
308
+ """Iterator version of get_node_features. If index_only is True,
309
+ yields indices instead of values.
310
+ """
311
+ if pids is None:
312
+ pids = self.node_attr[NODE_PID]
313
+
314
+ if feature_name in self._mapped_node_features:
315
+ feature_map_info = self.node_attr[feature_name]
316
+ from_feature_name, to_feature_vals = (
317
+ feature_map_info.name,
318
+ feature_map_info.values,
319
+ )
320
+ from_feature_vals = self.get_unique_node_features(
321
+ from_feature_name)
322
+ feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}
323
+
324
+ for pid in pids:
325
+ idx = self._nodes[pid]
326
+ from_feature_val = self.node_attr[from_feature_name][idx]
327
+ to_feature_idx = feature_mapping[from_feature_val]
328
+ if index_only:
329
+ yield to_feature_idx
330
+ else:
331
+ yield to_feature_vals[to_feature_idx]
332
+ else:
333
+ for pid in pids:
334
+ idx = self._nodes[pid]
335
+ if index_only:
336
+ yield idx
337
+ else:
338
+ yield self.node_attr[feature_name][idx]
339
+
340
+ def get_unique_edge_features(
341
+ self, feature_name: str = EDGE_PID) -> List[Hashable]:
342
+ r"""Get all the unique values for a specific edge attribute.
343
+
344
+ Args:
345
+ feature_name (str, optional): Name of feature to get.
346
+ Defaults to EDGE_PID.
347
+
348
+ Returns:
349
+ List[Hashable]: List of unique values for the specified feature.
350
+ """
351
+ try:
352
+ if feature_name in self._mapped_edge_features:
353
+ raise IndexError(
354
+ "Only non-mapped features can be retrieved uniquely.")
355
+ return ordered_set(self.get_edge_features(feature_name))
356
+ except KeyError:
357
+ raise AttributeError(
358
+ f"Edges do not have a feature called {feature_name}")
359
+
360
+ def add_edge_feature(
361
+ self,
362
+ new_feature_name: str,
363
+ new_feature_vals: FeatureValueType,
364
+ map_from_feature: str = EDGE_PID,
365
+ ) -> None:
366
+ r"""Adds a new feature that corresponds to each unique edge in
367
+ the graph.
368
+
369
+ Args:
370
+ new_feature_name (str): Name to call the new feature.
371
+ new_feature_vals (FeatureValueType): Values to map for that new
372
+ feature.
373
+ map_from_feature (str, optional): Key of feature to map from.
374
+ Size must match the number of feature values.
375
+ Defaults to EDGE_PID.
376
+ """
377
+ if new_feature_name in self.edge_attr:
378
+ raise AttributeError("Features cannot be overridden once created")
379
+ if map_from_feature in self._mapped_edge_features:
380
+ raise AttributeError(
381
+ f"{map_from_feature} is already a feature mapping.")
382
+
383
+ feature_keys = self.get_unique_edge_features(map_from_feature)
384
+ if len(feature_keys) != len(new_feature_vals):
385
+ raise AttributeError(
386
+ f"Expected encodings for {len(feature_keys)} unique features, "
387
+ + f"but got {len(new_feature_vals)} encodings.")
388
+
389
+ if map_from_feature == EDGE_PID:
390
+ self.edge_attr[new_feature_name] = new_feature_vals
391
+ else:
392
+ self.edge_attr[new_feature_name] = MappedFeature(
393
+ name=map_from_feature, values=new_feature_vals)
394
+ self._mapped_edge_features.add(new_feature_name)
395
+
396
+ def get_edge_features(
397
+ self,
398
+ feature_name: str = EDGE_PID,
399
+ pids: Optional[Iterable[Hashable]] = None,
400
+ ) -> List[Any]:
401
+ r"""Get edge feature values for a given set of unique edge ids.
402
+ Returned values are not necessarily unique.
403
+
404
+ Args:
405
+ feature_name (str, optional): Name of feature to fetch.
406
+ Defaults to EDGE_PID.
407
+ pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch
408
+ for. Defaults to None, which fetches all edges.
409
+
410
+ Returns:
411
+ List[Any]: Node features corresponding to the specified ids.
412
+ """
413
+ if feature_name in self._mapped_edge_features:
414
+ values = self.edge_attr[feature_name].values
415
+ else:
416
+ values = self.edge_attr[feature_name]
417
+
418
+ # TODO: torch_geometric.utils.select
419
+ if isinstance(values, torch.Tensor):
420
+ idxs = list(
421
+ self.get_edge_features_iter(feature_name, pids,
422
+ index_only=True))
423
+ return values[idxs]
424
+ return list(self.get_edge_features_iter(feature_name, pids))
425
+
426
+ def get_edge_features_iter(
427
+ self,
428
+ feature_name: str = EDGE_PID,
429
+ pids: Optional[KnowledgeGraphLike] = None,
430
+ index_only: bool = False,
431
+ ) -> Iterator[Any]:
432
+ """Iterator version of get_edge_features. If index_only is True,
433
+ yields indices instead of values.
434
+ """
435
+ if pids is None:
436
+ pids = self.edge_attr[EDGE_PID]
437
+
438
+ if feature_name in self._mapped_edge_features:
439
+ feature_map_info = self.edge_attr[feature_name]
440
+ from_feature_name, to_feature_vals = (
441
+ feature_map_info.name,
442
+ feature_map_info.values,
443
+ )
444
+ from_feature_vals = self.get_unique_edge_features(
445
+ from_feature_name)
446
+ feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}
447
+
448
+ for pid in pids:
449
+ idx = self._edges[pid]
450
+ from_feature_val = self.edge_attr[from_feature_name][idx]
451
+ to_feature_idx = feature_mapping[from_feature_val]
452
+ if index_only:
453
+ yield to_feature_idx
454
+ else:
455
+ yield to_feature_vals[to_feature_idx]
456
+ else:
457
+ for pid in pids:
458
+ idx = self._edges[pid]
459
+ if index_only:
460
+ yield idx
461
+ else:
462
+ yield self.edge_attr[feature_name][idx]
463
+
464
+ def to_triplets(self) -> Iterator[TripletLike]:
465
+ return iter(self.edge_attr[EDGE_PID])
466
+
467
+ def save(self, path: str) -> None:
468
+ if os.path.exists(path):
469
+ shutil.rmtree(path)
470
+ os.makedirs(path, exist_ok=True)
471
+ with open(path + "/edges", "wb") as f:
472
+ pkl.dump(self._edges, f)
473
+ with open(path + "/nodes", "wb") as f:
474
+ pkl.dump(self._nodes, f)
475
+
476
+ with open(path + "/mapped_edges", "wb") as f:
477
+ pkl.dump(self._mapped_edge_features, f)
478
+ with open(path + "/mapped_nodes", "wb") as f:
479
+ pkl.dump(self._mapped_node_features, f)
480
+
481
+ node_attr_path = path + "/node_attr"
482
+ os.makedirs(node_attr_path, exist_ok=True)
483
+ for attr_name, vals in self.node_attr.items():
484
+ torch.save(vals, node_attr_path + f"/{attr_name}.pt")
485
+
486
+ edge_attr_path = path + "/edge_attr"
487
+ os.makedirs(edge_attr_path, exist_ok=True)
488
+ for attr_name, vals in self.edge_attr.items():
489
+ torch.save(vals, edge_attr_path + f"/{attr_name}.pt")
490
+
491
+ @classmethod
492
+ def from_disk(cls, path: str) -> "LargeGraphIndexer":
493
+ indexer = cls(list(), list())
494
+ with open(path + "/edges", "rb") as f:
495
+ indexer._edges = pkl.load(f)
496
+ with open(path + "/nodes", "rb") as f:
497
+ indexer._nodes = pkl.load(f)
498
+
499
+ with open(path + "/mapped_edges", "rb") as f:
500
+ indexer._mapped_edge_features = pkl.load(f)
501
+ with open(path + "/mapped_nodes", "rb") as f:
502
+ indexer._mapped_node_features = pkl.load(f)
503
+
504
+ node_attr_path = path + "/node_attr"
505
+ for fname in os.listdir(node_attr_path):
506
+ full_fname = f"{node_attr_path}/{fname}"
507
+ key = fname.split(".")[0]
508
+ indexer.node_attr[key] = torch.load(full_fname)
509
+
510
+ edge_attr_path = path + "/edge_attr"
511
+ for fname in os.listdir(edge_attr_path):
512
+ full_fname = f"{edge_attr_path}/{fname}"
513
+ key = fname.split(".")[0]
514
+ indexer.edge_attr[key] = torch.load(full_fname)
515
+
516
+ return indexer
517
+
518
+ def to_data(self, node_feature_name: str,
519
+ edge_feature_name: Optional[str] = None) -> Data:
520
+ """Return a Data object containing all the specified node and
521
+ edge features and the graph.
522
+
523
+ Args:
524
+ node_feature_name (str): Feature to use for nodes
525
+ edge_feature_name (Optional[str], optional): Feature to use for
526
+ edges. Defaults to None.
527
+
528
+ Returns:
529
+ Data: Data object containing the specified node and
530
+ edge features and the graph.
531
+ """
532
+ x = torch.Tensor(self.get_node_features(node_feature_name))
533
+ node_id = torch.LongTensor(range(len(x)))
534
+
535
+ edge_index = torch.t(
536
+ torch.LongTensor(self.get_edge_features(EDGE_INDEX)))
537
+
538
+ edge_attr = (self.get_edge_features(edge_feature_name)
539
+ if edge_feature_name is not None else None)
540
+ edge_id = torch.LongTensor(range(len(edge_attr)))
541
+
542
+ return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
543
+ edge_id=edge_id, node_id=node_id)
544
+
545
+ def __eq__(self, value: "LargeGraphIndexer") -> bool:
546
+ eq = True
547
+ eq &= self._nodes == value._nodes
548
+ eq &= self._edges == value._edges
549
+ eq &= self.node_attr.keys() == value.node_attr.keys()
550
+ eq &= self.edge_attr.keys() == value.edge_attr.keys()
551
+ eq &= self._mapped_node_features == value._mapped_node_features
552
+ eq &= self._mapped_edge_features == value._mapped_edge_features
553
+
554
+ for k in self.node_attr:
555
+ eq &= isinstance(self.node_attr[k], type(value.node_attr[k]))
556
+ if isinstance(self.node_attr[k], torch.Tensor):
557
+ eq &= torch.equal(self.node_attr[k], value.node_attr[k])
558
+ else:
559
+ eq &= self.node_attr[k] == value.node_attr[k]
560
+ for k in self.edge_attr:
561
+ eq &= isinstance(self.edge_attr[k], type(value.edge_attr[k]))
562
+ if isinstance(self.edge_attr[k], torch.Tensor):
563
+ eq &= torch.equal(self.edge_attr[k], value.edge_attr[k])
564
+ else:
565
+ eq &= self.edge_attr[k] == value.edge_attr[k]
566
+ return eq
567
+
568
+
569
+ def get_features_for_triplets_groups(
570
+ indexer: LargeGraphIndexer,
571
+ triplet_groups: Iterable[KnowledgeGraphLike],
572
+ node_feature_name: str = "x",
573
+ edge_feature_name: str = "edge_attr",
574
+ pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
575
+ verbose: bool = False,
576
+ ) -> Iterator[Data]:
577
+ """Given an indexer and a series of triplet groups (like a dataset),
578
+ retrieve the specified node and edge features for each triplet from the
579
+ index.
580
+
581
+ Args:
582
+ indexer (LargeGraphIndexer): Indexer containing desired features
583
+ triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of
584
+ triplets to fetch features for
585
+ node_feature_name (str, optional): Node feature to fetch.
586
+ Defaults to "x".
587
+ edge_feature_name (str, optional): edge feature to fetch.
588
+ Defaults to "edge_attr".
589
+ pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
590
+ Optional preprocessing to perform on triplets.
591
+ Defaults to None.
592
+ verbose (bool, optional): Whether to print progress. Defaults to False.
593
+
594
+ Yields:
595
+ Iterator[Data]: For each triplet group, yield a data object containing
596
+ the unique graph and features from the index.
597
+ """
598
+ if pre_transform is not None:
599
+
600
+ def apply_transform(trips):
601
+ for trip in trips:
602
+ yield pre_transform(tuple(trip))
603
+
604
+ # TODO: Make this safe for large amounts of triplets?
605
+ triplet_groups = (list(apply_transform(triplets))
606
+ for triplets in triplet_groups)
607
+
608
+ node_keys = []
609
+ edge_keys = []
610
+ edge_index = []
611
+
612
+ for triplets in tqdm(triplet_groups, disable=not verbose):
613
+ small_graph_indexer = LargeGraphIndexer.from_triplets(
614
+ triplets, pre_transform=pre_transform)
615
+
616
+ node_keys.append(small_graph_indexer.get_node_features())
617
+ edge_keys.append(small_graph_indexer.get_edge_features(pids=triplets))
618
+ edge_index.append(
619
+ small_graph_indexer.get_edge_features(EDGE_INDEX, triplets))
620
+
621
+ node_feats = indexer.get_node_features(feature_name=node_feature_name,
622
+ pids=chain.from_iterable(node_keys))
623
+ edge_feats = indexer.get_edge_features(feature_name=edge_feature_name,
624
+ pids=chain.from_iterable(edge_keys))
625
+
626
+ last_node_idx, last_edge_idx = 0, 0
627
+ for (nkeys, ekeys, eidx) in zip(node_keys, edge_keys, edge_index):
628
+ nlen, elen = len(nkeys), len(ekeys)
629
+ x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])
630
+ last_node_idx += len(nkeys)
631
+
632
+ edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
633
+ elen])
634
+ last_edge_idx += len(ekeys)
635
+
636
+ edge_idx = torch.LongTensor(eidx).T
637
+
638
+ data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
639
+ data_obj[NODE_PID] = node_keys
640
+ data_obj[EDGE_PID] = edge_keys
641
+ data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
642
+ data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]
643
+
644
+ yield data_obj
645
+
646
+
647
+ def get_features_for_triplets(
648
+ indexer: LargeGraphIndexer,
649
+ triplets: KnowledgeGraphLike,
650
+ node_feature_name: str = "x",
651
+ edge_feature_name: str = "edge_attr",
652
+ pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
653
+ verbose: bool = False,
654
+ ) -> Data:
655
+ """For a given set of triplets retrieve a Data object containing the
656
+ unique graph and features from the index.
657
+
658
+ Args:
659
+ indexer (LargeGraphIndexer): Indexer containing desired features
660
+ triplets (KnowledgeGraphLike): Triplets to fetch features for
661
+ node_feature_name (str, optional): Feature to use for node features.
662
+ Defaults to "x".
663
+ edge_feature_name (str, optional): Feature to use for edge features.
664
+ Defaults to "edge_attr".
665
+ pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
666
+ Optional preprocessing function for triplets. Defaults to None.
667
+ verbose (bool, optional): Whether to print progress. Defaults to False.
668
+
669
+ Returns:
670
+ Data: Data object containing the unique graph and features from the
671
+ index for the given triplets.
672
+ """
673
+ gen = get_features_for_triplets_groups(indexer, [triplets],
674
+ node_feature_name,
675
+ edge_feature_name, pre_transform,
676
+ verbose)
677
+ return next(gen)
@@ -22,6 +22,7 @@ from .dynamic_batch_sampler import DynamicBatchSampler
22
22
  from .prefetch import PrefetchLoader
23
23
  from .cache import CachedLoader
24
24
  from .mixin import AffinityMixin
25
+ from .rag_loader import RAGQueryLoader
25
26
 
26
27
  __all__ = classes = [
27
28
  'DataLoader',
@@ -50,6 +51,7 @@ __all__ = classes = [
50
51
  'PrefetchLoader',
51
52
  'CachedLoader',
52
53
  'AffinityMixin',
54
+ 'RAGQueryLoader',
53
55
  ]
54
56
 
55
57
  RandomNodeSampler = deprecated(
@@ -0,0 +1,106 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
3
+
4
+ from torch_geometric.data import Data, FeatureStore, HeteroData
5
+ from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
6
+ from torch_geometric.typing import InputEdges, InputNodes
7
+
8
+
9
class RAGFeatureStore(Protocol):
    """Feature store for remote GNN RAG backend.

    Structural interface (:class:`typing.Protocol`): any object providing
    these three methods can serve as the feature-store half of a
    ``RAGQueryLoader`` backend.
    """
    @abstractmethod
    def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
        """Makes a comparison between the query and all the nodes to get all
        the closest nodes. Return the indices of the nodes that are to be seeds
        for the RAG Sampler.
        """
        ...

    @abstractmethod
    def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
        """Makes a comparison between the query and all the edges to get all
        the closest edges. Returns the edge indices that are to be the seeds
        for the RAG Sampler.
        """
        ...

    @abstractmethod
    def load_subgraph(
        self, sample: Union[SamplerOutput, HeteroSamplerOutput]
    ) -> Union[Data, HeteroData]:
        """Combines sampled subgraph output with features in a Data object."""
        ...
33
+
34
+
35
class RAGGraphStore(Protocol):
    """Graph store for remote GNN RAG backend.

    Structural interface (:class:`typing.Protocol`): any object providing
    these two methods can serve as the graph-store half of a
    ``RAGQueryLoader`` backend.
    """
    @abstractmethod
    def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
                        **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Sample a subgraph using the seeded nodes and edges."""
        ...

    @abstractmethod
    def register_feature_store(self, feature_store: FeatureStore):
        """Register a feature store to be used with the sampler. Samplers need
        info from the feature store in order to work properly on HeteroGraphs.
        """
        ...
49
+
50
+
51
+ # TODO: Make compatible with Heterographs
52
+
53
+
54
class RAGQueryLoader:
    def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
                 local_filter: Optional[Callable[[Data, Any], Data]] = None,
                 seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
                 seed_edges_kwargs: Optional[Dict[str, Any]] = None,
                 sampler_kwargs: Optional[Dict[str, Any]] = None,
                 loader_kwargs: Optional[Dict[str, Any]] = None):
        """Loader meant for making queries from a remote backend.

        Args:
            data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore
                and GraphStore to load from. Assumed to conform to the
                protocols listed above.
            local_filter (Optional[Callable[[Data, Any], Data]], optional):
                Optional local transform to apply to data after retrieval.
                Defaults to None.
            seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Parameters
                to pass into process for fetching seed nodes. Defaults to None.
            seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters
                to pass into process for fetching seed edges. Defaults to None.
            sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to
                pass into process for sampling graph. Defaults to None.
            loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to
                pass into process for loading graph features. Defaults to None.
        """
        # Unpack the remote backend pair and wire the two stores together;
        # samplers need feature-store info to work properly on HeteroGraphs.
        self.feature_store, self.graph_store = data
        self.graph_store.register_feature_store(self.feature_store)
        self.local_filter = local_filter
        # Normalize all optional kwargs dicts to empty dicts up front so the
        # query path can always splat them unconditionally.
        self.seed_nodes_kwargs = seed_nodes_kwargs or {}
        self.seed_edges_kwargs = seed_edges_kwargs or {}
        self.sampler_kwargs = sampler_kwargs or {}
        self.loader_kwargs = loader_kwargs or {}

    def query(self, query: Any) -> Data:
        """Retrieve a subgraph associated with the query with all its feature
        attributes.
        """
        # Stage 1: seed selection against the feature store.
        nodes = self.feature_store.retrieve_seed_nodes(
            query, **self.seed_nodes_kwargs)
        edges = self.feature_store.retrieve_seed_edges(
            query, **self.seed_edges_kwargs)

        # Stage 2: subgraph sampling around the seeds.
        sample = self.graph_store.sample_subgraph(nodes, edges,
                                                  **self.sampler_kwargs)

        # Stage 3: attach features to the sampled topology.
        result = self.feature_store.load_subgraph(sample=sample,
                                                  **self.loader_kwargs)

        # Optional user-supplied post-processing of the retrieved Data.
        if self.local_filter:
            result = self.local_filter(result, query)
        return result
@@ -21,6 +21,8 @@ class GRetriever(torch.nn.Module):
21
21
  (default: :obj:`False`)
22
22
  mlp_out_channels (int, optional): The size of each graph embedding
23
23
  after projection. (default: :obj:`4096`)
24
+ mlp_out_tokens (int, optional): Number of LLM prefix tokens to
25
+ reserve for GNN output. (default: :obj:`1`)
24
26
 
25
27
  .. warning::
26
28
  This module has been tested with the following HuggingFace models
@@ -43,6 +45,7 @@ class GRetriever(torch.nn.Module):
43
45
  gnn: torch.nn.Module,
44
46
  use_lora: bool = False,
45
47
  mlp_out_channels: int = 4096,
48
+ mlp_out_tokens: int = 1,
46
49
  ) -> None:
47
50
  super().__init__()
48
51
 
@@ -77,7 +80,9 @@ class GRetriever(torch.nn.Module):
77
80
  self.projector = torch.nn.Sequential(
78
81
  torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
79
82
  torch.nn.Sigmoid(),
80
- torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
83
+ torch.nn.Linear(mlp_hidden_channels,
84
+ mlp_out_channels * mlp_out_tokens),
85
+ torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
81
86
  ).to(self.llm.device)
82
87
 
83
88
  def encode(
@@ -126,6 +131,9 @@ class GRetriever(torch.nn.Module):
126
131
  x = self.projector(x)
127
132
  xs = x.split(1, dim=0)
128
133
 
134
+ # Handle case where theres more than one embedding for each sample
135
+ xs = [x.squeeze(0) for x in xs]
136
+
129
137
  # Handle questions without node features:
130
138
  batch_unique = batch.unique()
131
139
  batch_size = len(question)
@@ -182,6 +190,9 @@ class GRetriever(torch.nn.Module):
182
190
  x = self.projector(x)
183
191
  xs = x.split(1, dim=0)
184
192
 
193
+ # Handle case where theres more than one embedding for each sample
194
+ xs = [x.squeeze(0) for x in xs]
195
+
185
196
  # Handle questions without node features:
186
197
  batch_unique = batch.unique()
187
198
  batch_size = len(question)
@@ -20,6 +20,7 @@ from .utils import (
20
20
  get_gpu_memory_from_nvidia_smi,
21
21
  get_model_size,
22
22
  )
23
+ from .nvtx import nvtxit
23
24
 
24
25
  __all__ = [
25
26
  'profileit',
@@ -38,6 +39,7 @@ __all__ = [
38
39
  'get_gpu_memory_from_nvidia_smi',
39
40
  'get_gpu_memory_from_ipex',
40
41
  'benchmark',
42
+ 'nvtxit',
41
43
  ]
42
44
 
43
45
  classes = __all__
@@ -0,0 +1,66 @@
1
+ from functools import wraps
2
+ from typing import Optional
3
+
4
+ import torch
5
+
6
+ CUDA_PROFILE_STARTED = False
7
+
8
+
9
def begin_cuda_profile():
    """Start CUDA profiling unless it is already running.

    Returns the previous value of the module-level ``CUDA_PROFILE_STARTED``
    flag so the caller can later hand it to :func:`end_cuda_profile` and
    restore the outer state (supports nested profiled regions).
    """
    global CUDA_PROFILE_STARTED
    was_started = CUDA_PROFILE_STARTED
    if was_started is False:
        # Outermost region: actually kick off the CUDA profiler.
        CUDA_PROFILE_STARTED = True
        torch.cuda.cudart().cudaProfilerStart()
    return was_started
16
+
17
+
18
def end_cuda_profile(prev_state: bool):
    """Restore the profiling flag saved by :func:`begin_cuda_profile`.

    Args:
        prev_state (bool): Flag value returned by the matching
            :func:`begin_cuda_profile` call. When it is ``False``, this call
            closes the outermost profiled region and stops the CUDA profiler.
    """
    global CUDA_PROFILE_STARTED
    CUDA_PROFILE_STARTED = prev_state
    if prev_state is False:
        # Leaving the outermost region: shut the CUDA profiler down.
        torch.cuda.cudart().cudaProfilerStop()
23
+
24
+
25
def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
           n_iters: Optional[int] = None):
    """Enables NVTX profiling for a function.

    Args:
        name (Optional[str], optional): Name to give the reference frame for
            the function being wrapped. Defaults to the name of the
            function in code.
        n_warmups (int, optional): Number of iters to call that function
            before starting. Defaults to 0.
        n_iters (Optional[int], optional): Number of iters of that function to
            record. Defaults to all of them.

    Returns:
        A decorator that wraps a function with NVTX range annotations.
    """
    def nvtx(func):
        # Resolve the frame label per decoration instead of mutating the
        # enclosing `name` (the previous `nonlocal` approach leaked the first
        # function's name into later uses of the same decorator object).
        frame_name = name if name is not None else func.__name__
        # Per-function call counter used for warmup/recording bookkeeping.
        iters_so_far = 0

        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal iters_so_far
            if not torch.cuda.is_available():
                # No CUDA: profiling is meaningless; call through untouched.
                return func(*args, **kwargs)
            elif iters_so_far < n_warmups:
                # Still warming up: run but do not record.
                iters_so_far += 1
                return func(*args, **kwargs)
            elif n_iters is None or iters_so_far < n_iters + n_warmups:
                # Recording window: ensure the CUDA profiler is running and
                # wrap the call in a uniquely-numbered NVTX range.
                prev_state = begin_cuda_profile()
                torch.cuda.nvtx.range_push(f"{frame_name}_{iters_so_far}")
                result = func(*args, **kwargs)
                torch.cuda.nvtx.range_pop()
                end_cuda_profile(prev_state)
                iters_so_far += 1
                return result
            else:
                # Recording budget exhausted: call through untouched.
                return func(*args, **kwargs)

        return wrapper

    return nvtx