pyg-nightly 2.7.0.dev20250905__py3-none-any.whl → 2.7.0.dev20250907__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/METADATA +2 -1
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/RECORD +32 -25
- torch_geometric/__init__.py +1 -1
- torch_geometric/data/__init__.py +0 -5
- torch_geometric/data/lightning/datamodule.py +2 -2
- torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
- torch_geometric/datasets/web_qsp_dataset.py +262 -210
- torch_geometric/graphgym/imports.py +2 -2
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
- torch_geometric/{nn → llm}/models/git_mol.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/backend_utils.py +442 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +124 -0
- torch_geometric/loader/__init__.py +0 -4
- torch_geometric/nn/__init__.py +0 -1
- torch_geometric/nn/models/__init__.py +0 -10
- torch_geometric/nn/models/sgformer.py +2 -0
- torch_geometric/loader/rag_loader.py +0 -107
- torch_geometric/nn/nlp/__init__.py +0 -9
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/licenses/LICENSE +0 -0
- /torch_geometric/{nn → llm}/models/glem.py +0 -0
- /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
- /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
@@ -0,0 +1,442 @@
|
|
1
|
+
import os
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from enum import Enum, auto
|
4
|
+
from typing import (
|
5
|
+
Any,
|
6
|
+
Callable,
|
7
|
+
Dict,
|
8
|
+
Iterable,
|
9
|
+
Iterator,
|
10
|
+
List,
|
11
|
+
Optional,
|
12
|
+
Protocol,
|
13
|
+
Tuple,
|
14
|
+
Type,
|
15
|
+
Union,
|
16
|
+
no_type_check,
|
17
|
+
runtime_checkable,
|
18
|
+
)
|
19
|
+
|
20
|
+
import numpy as np
|
21
|
+
import torch
|
22
|
+
from torch import Tensor
|
23
|
+
from torch.nn import Module
|
24
|
+
|
25
|
+
from torch_geometric.data import Data, FeatureStore, GraphStore
|
26
|
+
from torch_geometric.distributed import (
|
27
|
+
LocalFeatureStore,
|
28
|
+
LocalGraphStore,
|
29
|
+
Partitioner,
|
30
|
+
)
|
31
|
+
from torch_geometric.llm.large_graph_indexer import (
|
32
|
+
EDGE_RELATION,
|
33
|
+
LargeGraphIndexer,
|
34
|
+
TripletLike,
|
35
|
+
)
|
36
|
+
from torch_geometric.llm.models import SentenceTransformer
|
37
|
+
from torch_geometric.typing import EdgeType, NodeType
|
38
|
+
|
39
|
+
try:
|
40
|
+
from pandas import DataFrame
|
41
|
+
except ImportError:
|
42
|
+
DataFrame = None
|
43
|
+
RemoteGraphBackend = Tuple[FeatureStore, GraphStore]
|
44
|
+
|
45
|
+
# TODO: Make everything compatible with heterogeneous graphs as well
|
46
|
+
|
47
|
+
|
48
|
+
def preprocess_triplet(triplet: TripletLike) -> TripletLike:
    """Normalize a ``(head, relation, tail)`` triplet to lower-case strings.

    Each of the three components is converted with :func:`str` and then
    lower-cased.

    Args:
        triplet (TripletLike): An iterable that unpacks into exactly three
            components.

    Returns:
        TripletLike: A 3-tuple of lower-cased strings.
    """
    head, relation, tail = triplet
    return tuple(str(part).lower() for part in (head, relation, tail))
|
51
|
+
|
52
|
+
|
53
|
+
@no_type_check
def retrieval_via_pcst(
    data: Data,
    q_emb: Tensor,
    textual_nodes: Any,
    textual_edges: Any,
    topk: int = 3,
    topk_e: int = 5,
    cost_e: float = 0.5,
    num_clusters: int = 1,
) -> Tuple[Data, str]:
    """Extract a query-relevant subgraph via ``pcst_fast``.

    Nodes and edges are scored by cosine similarity between their features
    (``data.x`` / ``data.edge_attr``) and ``q_emb``. The best-scoring ones
    receive prizes, and ``pcst_fast`` selects the vertices/edges to keep.

    Args:
        data (Data): Graph with ``x``, ``edge_index``, ``edge_attr`` and, for
            the non-degenerate path, ``node_idx`` and ``edge_idx``.
        q_emb (Tensor): Query embedding compared against node/edge features.
        textual_nodes (Any): Table of node texts supporting ``.iloc`` and
            ``.to_csv`` (e.g. a pandas ``DataFrame``), row-aligned with
            ``data.x``.
        textual_edges (Any): Table of edge texts with columns ``src``,
            ``edge_attr`` and ``dst``, row-aligned with the edges of ``data``.
        topk (int, optional): Number of nodes that receive a prize.
            (default: :obj:`3`)
        topk_e (int, optional): Number of distinct edge similarity values
            that receive a prize. (default: :obj:`5`)
        cost_e (float, optional): Base cost of an edge. (default: :obj:`0.5`)
        num_clusters (int, optional): Forwarded to ``pcst_fast``.
            (default: :obj:`1`)

    Returns:
        Tuple[Data, str]: The selected subgraph (or ``data`` itself when PCST
        is skipped) and a CSV description of the kept nodes and edges.
    """
    # Skip PCST for degenerate graphs (missing or empty features/edges): the
    # input graph is returned untouched together with its full description.
    skip_pcst = (data.edge_attr is None or data.edge_attr.numel() == 0
                 or data.x is None or data.x.numel() == 0
                 or data.edge_index is None
                 or data.edge_index.numel() == 0)
    if skip_pcst:
        desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(
            index=False, columns=['src', 'edge_attr', 'dst'])
        return data, desc

    # Imported lazily so the dependency is only needed when PCST really runs.
    from pcst_fast import pcst_fast

    c = 0.01
    # Settings handed to `pcst_fast` unchanged.
    root = -1
    pruning = 'gw'
    verbosity_level = 0

    if topk > 0:
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
        topk = min(topk, data.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

        # Rank-based prizes: the best node gets `topk`, the next `topk - 1`,
        # ...; every other node gets 0.
        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
    else:
        n_prizes = torch.zeros(data.num_nodes)

    if topk_e > 0:
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
        unique_prizes = e_prizes.unique()
        topk_e = min(topk_e, unique_prizes.size(0))

        topk_e_values, _ = torch.topk(unique_prizes, topk_e, largest=True)
        # Edges below the k-th best distinct similarity get no prize.
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            indices = e_prizes == topk_e_values[k]
            # Edges tied on the same similarity share the rank prize, and
            # the prize is kept strictly below the previous rank's prize.
            value = min((topk_e - k) / indices.sum(), last_topk_e_value - c)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - c)
        # reduce the cost of the edges so that at least one edge is chosen
        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
    else:
        e_prizes = torch.zeros(data.num_edges)

    costs = []
    edges = []
    virtual_n_prizes = []
    virtual_edges = []
    virtual_costs = []
    mapping_n = {}  # virtual node id -> original edge position
    mapping_e = {}  # position in `edges` -> original edge position
    for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
        prize_e = e_prizes[i]
        if prize_e <= cost_e:
            mapping_e[len(edges)] = i
            edges.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            # An edge whose prize exceeds its cost is replaced by a virtual
            # node carrying the surplus, linked to both endpoints by two
            # zero-cost edges.
            virtual_node_id = data.num_nodes + len(virtual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges.append((src, virtual_node_id))
            virtual_edges.append((virtual_node_id, dst))
            virtual_costs.append(0)
            virtual_costs.append(0)
            virtual_n_prizes.append(prize_e - cost_e)

    prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
    num_edges = len(edges)
    if virtual_costs:
        costs = np.array(costs + virtual_costs)
        edges = np.array(edges + virtual_edges)

    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
                                pruning, verbosity_level)

    selected_nodes = vertices[vertices < data.num_nodes]
    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
    virtual_vertices = vertices[vertices >= data.num_nodes]
    if len(virtual_vertices) > 0:
        # Map selected virtual nodes back to the edges they stand for.
        virtual_edges = [mapping_n[i] for i in virtual_vertices]
        selected_edges = np.array(selected_edges + virtual_edges)

    edge_index = data.edge_index[:, selected_edges]
    # Endpoints of every kept edge are kept as well.
    selected_nodes = np.unique(
        np.concatenate(
            [selected_nodes, edge_index[0].numpy(),
             edge_index[1].numpy()]))

    n = textual_nodes.iloc[selected_nodes]
    e = textual_edges.iloc[selected_edges]
    desc = n.to_csv(index=False) + '\n' + e.to_csv(
        index=False, columns=['src', 'edge_attr', 'dst'])

    # Relabel the kept nodes to 0..len(selected_nodes) - 1.
    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]

    # HACK Added so that the subset of nodes and edges selected can be tracked
    node_idx = np.array(data.node_idx)[selected_nodes]
    edge_idx = np.array(data.edge_idx)[selected_edges]

    data = Data(
        x=data.x[selected_nodes],
        edge_index=torch.tensor([src, dst]).to(torch.long),
        edge_attr=data.edge_attr[selected_edges],
        # HACK: track subset of selected nodes/edges
        node_idx=node_idx,
        edge_idx=edge_idx,
    )

    return data, desc
|
180
|
+
|
181
|
+
|
182
|
+
def batch_knn(query_enc: Tensor, embeds: Tensor,
              k: int) -> Iterator[Tuple[Tensor, Tensor]]:
    """Yield, per query, the indices of its most similar embeddings.

    Similarity is the pairwise cosine similarity between every row of
    ``query_enc`` and every row of ``embeds`` (moved to the device of
    ``query_enc``).

    Args:
        query_enc (Tensor): Encoded queries, one per row.
        embeds (Tensor): Candidate embeddings, one per row.
        k (int): Number of neighbors requested; capped at ``len(embeds)``.

    Yields:
        Tuple[Tensor, Tensor]: The top-``k`` indices into ``embeds`` for a
        query, and that query's encoding with a leading dimension of size 1.
    """
    from torchmetrics.functional import pairwise_cosine_similarity

    similarities = pairwise_cosine_similarity(
        query_enc, embeds.to(query_enc.device))
    num_neighbors = min(k, len(embeds))
    for row, scores in enumerate(similarities):
        top_indices = torch.topk(scores, num_neighbors, largest=True).indices
        yield top_indices, query_enc[row].unsqueeze(0)
|
190
|
+
|
191
|
+
|
192
|
+
# Adapted from LocalGraphStore
|
193
|
+
@runtime_checkable
class ConvertableGraphStore(Protocol):
    """Structural interface for graph stores with alternate constructors.

    A class satisfies this protocol when it provides the three class-level
    constructors declared below. Because the protocol is runtime-checkable,
    ``issubclass``/``isinstance`` only verify that the members exist, not
    that their signatures match.
    """
    @classmethod
    def from_data(
        cls,
        edge_id: Tensor,
        edge_index: Tensor,
        num_nodes: int,
        is_sorted: bool = False,
    ) -> GraphStore:
        """Build a store from a single set of edge ids and an edge index."""
        ...

    @classmethod
    def from_hetero_data(
        cls,
        edge_id_dict: Dict[EdgeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        num_nodes_dict: Dict[NodeType, int],
        is_sorted: bool = False,
    ) -> GraphStore:
        """Build a store from per-edge-type and per-node-type dictionaries."""
        ...

    @classmethod
    def from_partition(cls, root: str, pid: int) -> GraphStore:
        """Build a store from partition ``pid`` located under ``root``."""
        ...
|
218
|
+
|
219
|
+
|
220
|
+
# Adapted from LocalFeatureStore
|
221
|
+
@runtime_checkable
class ConvertableFeatureStore(Protocol):
    """Structural interface for feature stores with alternate constructors.

    A class satisfies this protocol when it provides the three class-level
    constructors declared below. Because the protocol is runtime-checkable,
    ``issubclass``/``isinstance`` only verify that the members exist, not
    that their signatures match.
    """
    @classmethod
    def from_data(
        cls,
        node_id: Tensor,
        x: Optional[Tensor] = None,
        y: Optional[Tensor] = None,
        edge_id: Optional[Tensor] = None,
        edge_attr: Optional[Tensor] = None,
    ) -> FeatureStore:
        """Build a store from node ids plus optional node/edge tensors."""
        ...

    @classmethod
    def from_hetero_data(
        cls,
        node_id_dict: Dict[NodeType, Tensor],
        x_dict: Optional[Dict[NodeType, Tensor]] = None,
        y_dict: Optional[Dict[NodeType, Tensor]] = None,
        edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None,
        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
    ) -> FeatureStore:
        """Build a store from per-node-type and per-edge-type dictionaries."""
        ...

    @classmethod
    def from_partition(cls, root: str, pid: int) -> FeatureStore:
        """Build a store from partition ``pid`` located under ``root``."""
        ...
|
248
|
+
|
249
|
+
|
250
|
+
class RemoteDataType(Enum):
    """Storage layouts that ``RemoteGraphBackendLoader.load`` understands."""
    # A single serialized data object read back with `torch.load`.
    DATA = auto()
    # A partitioned store, read back through `from_partition(path, pid)`.
    PARTITION = auto()
|
253
|
+
|
254
|
+
|
255
|
+
@dataclass
class RemoteGraphBackendLoader:
    """Utility class to load triplets into a RAG Backend.

    The loader remembers where a graph was written (``path``), how it was
    written (``datatype``) and which store classes should be used to read it
    back. When the loader is garbage collected it removes the file at
    ``path`` (best effort).
    """
    # Location of the serialized graph (file) or partitioned store.
    path: str
    # Layout of what is stored at `path`.
    datatype: RemoteDataType
    # Classes used to build the returned stores.
    graph_store_type: Type[ConvertableGraphStore]
    feature_store_type: Type[ConvertableFeatureStore]

    def load(self, pid: Optional[int] = None) -> RemoteGraphBackend:
        """Build the ``(feature_store, graph_store)`` pair from ``path``.

        Args:
            pid (int, optional): Partition id. Required when ``datatype`` is
                ``RemoteDataType.PARTITION``; unused otherwise.

        Returns:
            RemoteGraphBackend: Tuple ``(feature_store, graph_store)``.

        Raises:
            AssertionError: If a partitioned store is loaded without ``pid``.
            NotImplementedError: If ``datatype`` is not a supported value.
        """
        if self.datatype == RemoteDataType.DATA:
            # SECURITY: `weights_only=False` unpickles arbitrary objects;
            # only load files that this process (or a trusted party) wrote.
            data_obj = torch.load(self.path, weights_only=False)
            # is_sorted=true since assume nodes come sorted from indexer
            graph_store = self.graph_store_type.from_data(
                edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index,
                num_nodes=data_obj.num_nodes, is_sorted=True)
            feature_store = self.feature_store_type.from_data(
                node_id=data_obj['node_id'], x=data_obj.x,
                edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr)
        elif self.datatype == RemoteDataType.PARTITION:
            if pid is None:
                # Raised explicitly (instead of via `assert`) so the check
                # still runs under `python -O`; the exception type is
                # unchanged for existing callers.
                raise AssertionError(
                    "Partition ID must be defined for loading from a "
                    "partitioned store.")
            graph_store = self.graph_store_type.from_partition(self.path, pid)
            feature_store = self.feature_store_type.from_partition(
                self.path, pid)
        else:
            raise NotImplementedError
        return (feature_store, graph_store)

    def __del__(self) -> None:
        """Best-effort removal of the file at ``path``."""
        # `path` may be missing if `__init__` never completed.
        path = getattr(self, 'path', None)
        if not path:
            return
        try:
            # `os.remove` cannot delete directories (presumably what a
            # partitioned store is); skip them instead of raising inside
            # a finalizer.
            if os.path.exists(path) and not os.path.isdir(path):
                os.remove(path)
        except OSError:
            # Cleanup is best effort; a finalizer must not raise.
            pass
|
288
|
+
|
289
|
+
|
290
|
+
def create_graph_from_triples(
    triples: Iterable[TripletLike],
    embedding_model: Union[Module, Callable],
    embedding_method_kwargs: Optional[Dict[str, Any]] = None,
    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
) -> Data:
    """Utility function that can be used to create a graph from triples.

    The triples are indexed with ``LargeGraphIndexer``. The unique node
    features and the unique ``EDGE_RELATION`` features are each passed
    through ``embedding_model`` and stored as ``x`` and ``edge_attr``.

    Args:
        triples (Iterable[TripletLike]): Triples to index.
        embedding_model (Union[Module, Callable]): Callable invoked as
            ``embedding_model(features, **embedding_method_kwargs)``.
        embedding_method_kwargs (Dict[str, Any], optional): Extra keyword
            arguments for ``embedding_model``. (default: :obj:`None`)
        pre_transform (Callable, optional): Forwarded to
            ``LargeGraphIndexer.from_triplets``. (default: :obj:`None`)

    Returns:
        Data: The resulting graph, moved to the CPU.
    """
    if embedding_method_kwargs is None:
        embedding_method_kwargs = {}

    indexer = LargeGraphIndexer.from_triplets(triples,
                                              pre_transform=pre_transform)

    # Embed each distinct node once and attach the result as `x`.
    unique_nodes = indexer.get_unique_node_features()
    indexer.add_node_feature(
        'x', embedding_model(unique_nodes, **embedding_method_kwargs))

    # Embed each distinct relation once and map it onto the edges.
    unique_relations = indexer.get_unique_edge_features(
        feature_name=EDGE_RELATION)
    indexer.add_edge_feature(
        new_feature_name="edge_attr",
        new_feature_vals=embedding_model(unique_relations,
                                         **embedding_method_kwargs),
        map_from_feature=EDGE_RELATION,
    )

    graph = indexer.to_data(node_feature_name='x',
                            edge_feature_name='edge_attr')
    return graph.to("cpu")
|
318
|
+
|
319
|
+
|
320
|
+
def create_remote_backend_from_graph_data(
    graph_data: Data,
    graph_db: Type[ConvertableGraphStore] = LocalGraphStore,
    feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore,
    path: str = '',
    n_parts: int = 1,
) -> RemoteGraphBackendLoader:
    """Utility function that can be used to create a RAG Backend from graph
    data.

    With ``n_parts == 1`` the graph is written to ``path`` with
    ``torch.save``; otherwise it is partitioned into ``n_parts`` parts under
    ``path`` with ``Partitioner``. Note that the default ``path=''`` is
    passed to these calls unchanged, so callers normally need to supply one.

    Args:
        graph_data (Data): Graph data to load into the RAG Backend.
        graph_db (Type[ConvertableGraphStore], optional): GraphStore class to
            use. Defaults to LocalGraphStore.
        feature_db (Type[ConvertableFeatureStore], optional): FeatureStore
            class to use. Defaults to LocalFeatureStore.
        path (str, optional): path to save resulting stores. Defaults to ''.
        n_parts (int, optional): Number of partitions to store in.
            Defaults to 1.

    Returns:
        RemoteGraphBackendLoader: Loader to load RAG backend from disk or
            memory.

    Raises:
        AttributeError: If ``graph_db`` or ``feature_db`` lacks one of the
            constructors ``from_data``, ``from_hetero_data`` or
            ``from_partition``.
    """
    required = ('from_data', 'from_hetero_data', 'from_partition')
    # Will return attribute errors for missing attributes. The two stores
    # are validated independently of each other.
    if not issubclass(graph_db, ConvertableGraphStore):
        for name in required:
            getattr(graph_db, name)
    if not issubclass(feature_db, ConvertableFeatureStore):
        for name in required:
            getattr(feature_db, name)

    if n_parts == 1:
        torch.save(graph_data, path)
        return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db,
                                        feature_db)

    partitioner = Partitioner(data=graph_data, num_parts=n_parts, root=path)
    partitioner.generate_partition()
    return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION, graph_db,
                                    feature_db)
|
363
|
+
|
364
|
+
|
365
|
+
def make_pcst_filter(triples: List[Tuple[str, str,
                                         str]], model: SentenceTransformer,
                     topk: int = 5, topk_e: int = 5, cost_e: float = 0.5,
                     num_clusters: int = 1) -> Callable[[Data, str], Data]:
    """Creates a PCST (Prize-Collecting Steiner Tree) filter.

    :param triples: List of triples (head, relation, tail) representing KG data
    :param model: SentenceTransformer model for embedding text
    :param topk: Number of top-K nodes that receive a prize (default: 5)
    :param topk_e: Number of top-K edge results to consider (default: 5)
    :param cost_e: Cost of edges (default: 0.5)
    :param num_clusters: Number of connected components in the PCST output.
    :return: PCST Filter function
    :raises ImportError: If pandas is not installed
    """
    if DataFrame is None:  # Check if pandas is installed
        raise ImportError("PCST requires `pip install pandas`")

    # Remove duplicate triples while keeping first-seen order, so that the
    # positions match the `edge_idx` values stored on the graphs.
    triples = list(dict.fromkeys(triples))

    # Unique head/tail entities in first-seen order; indexed by `node_idx`.
    full_textual_nodes = list(
        dict.fromkeys(node for h, _, t in triples for node in (h, t)))

    def apply_retrieval_via_pcst(
        graph: Data,  # Input graph data
        query: str,  # Search query
    ) -> Data:
        """Applies PCST filtering for retrieval.

        :param graph: Input graph data (must carry ``node_idx``/``edge_idx``)
        :param query: Search query
        :return: Retrieved graph with ``desc``, ``triples`` and ``question``
        """
        # PCST relies on numpy and pcst_fast pypi libs, hence to("cpu")
        q_emb = model.encode([query]).to("cpu")
        textual_nodes = DataFrame(
            [(int(i), full_textual_nodes[i]) for i in graph["node_idx"]],
            columns=["node_id", "node_attr"])
        textual_edges = DataFrame([triples[i] for i in graph["edge_idx"]],
                                  columns=["src", "edge_attr", "dst"])
        out_graph, desc = retrieval_via_pcst(graph.to(q_emb.device), q_emb,
                                             textual_nodes, textual_edges,
                                             topk=topk, topk_e=topk_e,
                                             cost_e=cost_e,
                                             num_clusters=num_clusters)
        out_graph["desc"] = desc

        # Handle case where PCST returns an isolated node
        # TODO find a better solution since these failed subgraphs
        # severely hurt accuracy.
        edge_index = out_graph.edge_index
        if edge_index is None or edge_index.numel() == 0:
            out_graph["triples"] = []
        else:
            # Look the kept triples up through `edge_idx` instead of parsing
            # them back out of the CSV text in `desc`: splitting on "," and
            # "\n" breaks for entities containing commas, quotes or newlines.
            out_graph["triples"] = [
                tuple(triples[i]) for i in out_graph["edge_idx"]
            ]
        out_graph["question"] = query
        return out_graph

    return apply_retrieval_via_pcst
|
@@ -0,0 +1,169 @@
|
|
1
|
+
import gc
|
2
|
+
from collections.abc import Iterable, Iterator
|
3
|
+
from typing import Any, Dict, List, Tuple, Union
|
4
|
+
|
5
|
+
import torch
|
6
|
+
from torch import Tensor
|
7
|
+
|
8
|
+
from torch_geometric.data import Data, HeteroData
|
9
|
+
from torch_geometric.distributed import LocalFeatureStore
|
10
|
+
from torch_geometric.llm.utils.backend_utils import batch_knn
|
11
|
+
from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
|
12
|
+
from torch_geometric.typing import InputNodes
|
13
|
+
|
14
|
+
|
15
|
+
# NOTE: Only compatible with Homogeneous graphs for now
|
16
|
+
class KNNRAGFeatureStore(LocalFeatureStore):
    """A feature store that uses a KNN-based retrieval.

    The store must be configured through :attr:`config` with an
    ``encoder_model`` (used to embed queries) and ``k_nodes`` (number of
    seed nodes to retrieve) before queries can be answered.
    """
    def __init__(self) -> None:
        """Initializes the feature store."""
        # to be set by the config
        self.encoder_model = None
        self.k_nodes = None
        self._config: Dict[str, Any] = {}
        super().__init__()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the config for the feature store."""
        return self._config

    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:
        """Set an attribute from the config.

        Args:
            config (Dict[str, Any]): Config dictionary
            attr_name (str): Name of attribute to set

        Raises:
            ValueError: If required attribute not found in config
        """
        if attr_name not in config:
            raise ValueError(
                f"Required config parameter '{attr_name}' not found")
        setattr(self, attr_name, config[attr_name])

    @config.setter  # type: ignore
    def config(self, config: Dict[str, Any]) -> None:
        """Set the config for the feature store.

        The config is validated completely before any attribute is changed,
        so a rejected config leaves the store in its previous state.

        Args:
            config (Dict[str, Any]):
                Config dictionary containing required parameters
                (``k_nodes`` and ``encoder_model``)

        Raises:
            ValueError: If required parameters missing from config
        """
        required = ("k_nodes", "encoder_model")
        for attr_name in required:
            if attr_name not in config:
                raise ValueError(
                    f"Required config parameter '{attr_name}' not found")
        assert config["encoder_model"] is not None, \
            "Need to define encoder model from config"

        for attr_name in required:
            self._set_from_config(config, attr_name)
        self.encoder_model.eval()

        self._config = config

    @property
    def x(self) -> Tensor:
        """Returns the node features."""
        return Tensor(self.get_tensor(group_name=None, attr_name='x'))

    @property
    def edge_attr(self) -> Tensor:
        """Returns the edge attributes."""
        return Tensor(
            self.get_tensor(group_name=(None, None), attr_name='edge_attr'))

    def retrieve_seed_nodes(  # noqa: D417
        self, query: Union[str, List[str], Tuple[str]]
    ) -> Union[Tuple[InputNodes, Tensor], Dict[str, Tuple[InputNodes,
                                                          Tensor]]]:
        """Retrieves the k_nodes most similar nodes to the given query.

        Args:
        - query (Union[str, List[str], Tuple[str]]):
            The query or list of queries to search for.

        Returns:
        - For a single query: the indices of the most similar nodes and the
          encoded query.
        - For several queries: a dictionary mapping each query to such a
          pair (identical queries share one entry).
        """
        if not isinstance(query, (list, tuple)):
            query = [query]
        assert self.k_nodes is not None, "please set k_nodes via config"

        batches = self._retrieve_seed_nodes_batch(query, self.k_nodes)
        if len(query) == 1:
            out = next(batches)
        else:
            out = {query[i]: res for i, res in enumerate(batches)}

        # Release memory held by intermediate encodings/similarities.
        gc.collect()
        torch.cuda.empty_cache()
        return out

    def _retrieve_seed_nodes_batch(  # noqa: D417
            self, query: Iterable[Any],
            k_nodes: int) -> Iterator[Tuple[InputNodes, Tensor]]:
        """Retrieves the k_nodes most similar nodes to each query in the batch.

        Args:
        - query (Iterable[Any]): The batch of queries to search for.
        - k_nodes (int): The number of nodes to retrieve.

        Yields:
        - The indices of the most similar nodes for each query, together
          with that query's encoding.

        Raises:
            NotImplementedError: If the store is marked as heterogeneous.
        """
        if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
            raise NotImplementedError
        assert self.encoder_model is not None, \
            "Need to define encoder model from config"
        query_enc = self.encoder_model.encode(query)
        return batch_knn(query_enc, self.x, k_nodes)

    def load_subgraph(  # noqa
        self,
        sample: Union[SamplerOutput, HeteroSamplerOutput],
        induced: bool = True,
    ) -> Union[Data, HeteroData]:
        """Loads a subgraph from the given sample.

        Args:
        - sample: The sample to load the subgraph from.
        - induced: Whether to return the induced subgraph.
            Resets node and edge ids. If ``True`` the edge index is built
            from ``sample.row``/``sample.col``, otherwise from
            ``sample.global_row``/``sample.global_col``.

        Returns:
        - The loaded subgraph.

        Raises:
            NotImplementedError: If ``sample`` is a ``HeteroSamplerOutput``.
        """
        if isinstance(sample, HeteroSamplerOutput):
            raise NotImplementedError
        # NOTE: torch_geometric.loader.utils.filter_custom_store
        # can be used here if it supported edge features.
        edge_id = sample.edge
        x = self.x[sample.node]
        edge_attr = self.edge_attr[edge_id]

        if induced:
            edge_idx = torch.stack([sample.row, sample.col], dim=0)
        else:
            edge_idx = torch.stack([sample.global_row, sample.global_col],
                                   dim=0)
        result = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)

        # useful for tracking what subset of the graph was sampled
        result.node_idx = sample.node
        result.edge_idx = edge_id

        return result
|
160
|
+
|
161
|
+
|
162
|
+
"""
|
163
|
+
TODO: make class CuVSKNNRAGFeatureStore(KNNRAGFeatureStore)
|
164
|
+
include an approximate KNN flag for CuVS.
|
165
|
+
Connect this with a CuGraphGraphStore
|
166
|
+
for enabling an accelerated boolean flag for RAGQueryLoader.
|
167
|
+
On by default if CuGraph+CuVS avail.
|
168
|
+
If not raise note mentioning its speedup.
|
169
|
+
"""
|