pyg-nightly 2.7.0.dev20250905__py3-none-any.whl → 2.7.0.dev20250906__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/METADATA +2 -1
  2. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/RECORD +32 -25
  3. torch_geometric/__init__.py +1 -1
  4. torch_geometric/data/__init__.py +0 -5
  5. torch_geometric/data/lightning/datamodule.py +2 -2
  6. torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
  7. torch_geometric/datasets/web_qsp_dataset.py +262 -210
  8. torch_geometric/graphgym/imports.py +2 -2
  9. torch_geometric/llm/__init__.py +9 -0
  10. torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
  11. torch_geometric/llm/models/__init__.py +23 -0
  12. torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
  13. torch_geometric/{nn → llm}/models/git_mol.py +1 -1
  14. torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
  15. torch_geometric/llm/models/llm_judge.py +158 -0
  16. torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
  17. torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
  18. torch_geometric/llm/models/txt2kg.py +353 -0
  19. torch_geometric/llm/rag_loader.py +154 -0
  20. torch_geometric/llm/utils/backend_utils.py +442 -0
  21. torch_geometric/llm/utils/feature_store.py +169 -0
  22. torch_geometric/llm/utils/graph_store.py +199 -0
  23. torch_geometric/llm/utils/vectorrag.py +124 -0
  24. torch_geometric/loader/__init__.py +0 -4
  25. torch_geometric/nn/__init__.py +0 -1
  26. torch_geometric/nn/models/__init__.py +0 -10
  27. torch_geometric/nn/models/sgformer.py +2 -0
  28. torch_geometric/loader/rag_loader.py +0 -107
  29. torch_geometric/nn/nlp/__init__.py +0 -9
  30. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/WHEEL +0 -0
  31. {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/licenses/LICENSE +0 -0
  32. /torch_geometric/{nn → llm}/models/glem.py +0 -0
  33. /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
  34. /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
@@ -0,0 +1,442 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from enum import Enum, auto
4
+ from typing import (
5
+ Any,
6
+ Callable,
7
+ Dict,
8
+ Iterable,
9
+ Iterator,
10
+ List,
11
+ Optional,
12
+ Protocol,
13
+ Tuple,
14
+ Type,
15
+ Union,
16
+ no_type_check,
17
+ runtime_checkable,
18
+ )
19
+
20
+ import numpy as np
21
+ import torch
22
+ from torch import Tensor
23
+ from torch.nn import Module
24
+
25
+ from torch_geometric.data import Data, FeatureStore, GraphStore
26
+ from torch_geometric.distributed import (
27
+ LocalFeatureStore,
28
+ LocalGraphStore,
29
+ Partitioner,
30
+ )
31
+ from torch_geometric.llm.large_graph_indexer import (
32
+ EDGE_RELATION,
33
+ LargeGraphIndexer,
34
+ TripletLike,
35
+ )
36
+ from torch_geometric.llm.models import SentenceTransformer
37
+ from torch_geometric.typing import EdgeType, NodeType
38
+
39
+ try:
40
+ from pandas import DataFrame
41
+ except ImportError:
42
+ DataFrame = None
43
# A RAG backend as returned by ``RemoteGraphBackendLoader.load()``:
# a ``(feature_store, graph_store)`` pair.
RemoteGraphBackend = Tuple[FeatureStore, GraphStore]
44
+
45
# TODO: Make everything compatible with heterogeneous graphs as well
46
+
47
+
48
def preprocess_triplet(triplet: TripletLike) -> TripletLike:
    """Normalize a (head, relation, tail) triplet.

    Each of the three elements is converted with ``str`` and lower-cased.
    The input must unpack into exactly three items.
    """
    head, relation, tail = triplet
    return tuple(str(item).lower() for item in (head, relation, tail))
51
+
52
+
53
@no_type_check
def retrieval_via_pcst(
    data: Data,
    q_emb: Tensor,
    textual_nodes: Any,
    textual_edges: Any,
    topk: int = 3,
    topk_e: int = 5,
    cost_e: float = 0.5,
    num_clusters: int = 1,
) -> Tuple[Data, str]:
    """Retrieve a query-relevant subgraph via Prize-Collecting Steiner Tree.

    Nodes and edges receive prizes derived from their cosine similarity to
    ``q_emb``. Only the ``topk`` most similar nodes and the ``topk_e``
    largest distinct edge similarities get non-zero prizes. ``pcst_fast``
    then selects a subgraph that trades these prizes against a per-edge
    cost ``cost_e``.

    Args:
        data: Graph with ``x``, ``edge_index`` and ``edge_attr``. It must
            also carry ``node_idx`` / ``edge_idx`` attributes, which are
            sliced to track which input nodes/edges were selected.
        q_emb: Query embedding compared against ``data.x`` and
            ``data.edge_attr`` with cosine similarity over the last dim.
            NOTE(review): presumably shape ``[1, dim]`` -- confirm.
        textual_nodes: Table of node text, indexed positionally with
            ``.iloc`` and serialized with ``.to_csv``.
        textual_edges: Table of edge text with ``src``, ``edge_attr`` and
            ``dst`` columns.
        topk: Number of nodes given non-zero prizes.
        topk_e: Number of distinct edge-similarity values given non-zero
            prizes.
        cost_e: Cost of keeping an edge. It is lowered if needed so that at
            least one edge prize exceeds it.
        num_clusters: Number of connected components PCST may return.

    Returns:
        A ``(subgraph, desc)`` tuple.

        ``subgraph`` has node ids re-indexed to ``0..n-1``. Its
        ``node_idx`` / ``edge_idx`` map back to the input graph.

        ``desc`` is the CSV text of the selected nodes, a blank line, then
        the CSV text of the selected edges.

        If ``data`` lacks edge attributes, node features or edges, PCST is
        skipped. In that case the input graph is returned unchanged, with a
        description of all textual nodes and edges.
    """
    # skip PCST for bad graphs
    # (``booly`` is True when edge_attr, x or edge_index is missing/empty)
    booly = data.edge_attr is None or data.edge_attr.numel() == 0
    booly = booly or data.x is None or data.x.numel() == 0
    booly = booly or data.edge_index is None or data.edge_index.numel() == 0
    if not booly:
        # Small margin that keeps successive edge-prize tiers strictly
        # decreasing and keeps cost_e just below the largest edge prize.
        c = 0.01

        from pcst_fast import pcst_fast

        # pcst_fast configuration. root=-1 presumably means "unrooted";
        # 'gw' is the pruning strategy -- confirm against pcst_fast docs.
        root = -1
        pruning = 'gw'
        verbosity_level = 0
        # Node prizes: the topk nodes most similar to the query get prizes
        # topk, topk-1, ..., 1; every other node gets 0.
        if topk > 0:
            n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
            topk = min(topk, data.num_nodes)
            _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

            n_prizes = torch.zeros_like(n_prizes)
            n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
        else:
            n_prizes = torch.zeros(data.num_nodes)

        # Edge prizes: only the topk_e largest *distinct* similarity values
        # survive. Edges sharing the k-th value split a prize of
        # (topk_e - k). The prize is capped so each tier stays strictly
        # below the previous one.
        if topk_e > 0:
            e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
            topk_e = min(topk_e, e_prizes.unique().size(0))

            topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e,
                                          largest=True)
            e_prizes[e_prizes < topk_e_values[-1]] = 0.0
            last_topk_e_value = topk_e
            for k in range(topk_e):
                indices = e_prizes == topk_e_values[k]
                # sum() over the boolean mask counts the tied edges.
                value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
                e_prizes[indices] = value
                last_topk_e_value = value * (1 - c)
            # reduce the cost of the edges so that at least one edge is chosen
            cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
        else:
            e_prizes = torch.zeros(data.num_edges)

        # Build the PCST instance:
        # - Edges whose prize does not exceed cost_e are kept as real edges
        #   with cost (cost_e - prize).
        # - Edges with a larger prize are replaced by a virtual node that
        #   carries the net prize (prize - cost_e). It is joined to both
        #   endpoints by zero-cost edges, which expresses the edge's prize
        #   as a node prize.
        costs = []
        edges = []
        virtual_n_prizes = []
        virtual_edges = []
        virtual_costs = []
        # mapping_n: virtual node id -> original edge index
        mapping_n = {}
        # mapping_e: position in ``edges`` -> original edge index
        mapping_e = {}
        for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
            prize_e = e_prizes[i]
            if prize_e <= cost_e:
                mapping_e[len(edges)] = i
                edges.append((src, dst))
                costs.append(cost_e - prize_e)
            else:
                virtual_node_id = data.num_nodes + len(virtual_n_prizes)
                mapping_n[virtual_node_id] = i
                virtual_edges.append((src, virtual_node_id))
                virtual_edges.append((virtual_node_id, dst))
                virtual_costs.append(0)
                virtual_costs.append(0)
                virtual_n_prizes.append(prize_e - cost_e)

        prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
        # Number of real edges; PCST edge ids >= num_edges are virtual.
        num_edges = len(edges)
        # NOTE(review): ``costs``/``edges`` are only converted to numpy
        # arrays when virtual edges exist. Otherwise the Python lists are
        # passed straight to pcst_fast.
        if len(virtual_costs) > 0:
            costs = np.array(costs + virtual_costs)
            edges = np.array(edges + virtual_edges)

        vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
                                    pruning, verbosity_level)

        # Map the PCST solution back to original edges:
        # - real edges via mapping_e;
        # - the edge represented by each selected virtual node via mapping_n.
        selected_nodes = vertices[vertices < data.num_nodes]
        selected_edges = [mapping_e[e] for e in edges if e < num_edges]
        virtual_vertices = vertices[vertices >= data.num_nodes]
        if len(virtual_vertices) > 0:
            # NOTE(review): recomputes the same value as the line above.
            virtual_vertices = vertices[vertices >= data.num_nodes]
            virtual_edges = [mapping_n[i] for i in virtual_vertices]
            selected_edges = np.array(selected_edges + virtual_edges)

        edge_index = data.edge_index[:, selected_edges]
        # Keep every endpoint of a selected edge, even if PCST did not pick
        # it as a vertex.
        selected_nodes = np.unique(
            np.concatenate(
                [selected_nodes, edge_index[0].numpy(),
                 edge_index[1].numpy()]))

        n = textual_nodes.iloc[selected_nodes]
        e = textual_edges.iloc[selected_edges]
    else:
        n = textual_nodes
        e = textual_edges

    # Textual description: node CSV, a blank line, then edge CSV.
    desc = n.to_csv(index=False) + '\n' + e.to_csv(
        index=False, columns=['src', 'edge_attr', 'dst'])

    if booly:
        return data, desc

    # Re-index the selected nodes to 0..len-1 for the output edge_index.
    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]

    # HACK Added so that the subset of nodes and edges selected can be tracked
    node_idx = np.array(data.node_idx)[selected_nodes]
    edge_idx = np.array(data.edge_idx)[selected_edges]

    data = Data(
        x=data.x[selected_nodes],
        edge_index=torch.tensor([src, dst]).to(torch.long),
        edge_attr=data.edge_attr[selected_edges],
        # HACK: track subset of selected nodes/edges
        node_idx=node_idx,
        edge_idx=edge_idx,
    )

    return data, desc
180
+
181
+
182
def batch_knn(query_enc: Tensor, embeds: Tensor,
              k: int) -> Iterator[Tuple[Tensor, Tensor]]:
    """Lazily yield the nearest embeddings for every encoded query.

    Every row of ``query_enc`` is scored against every row of ``embeds``
    (moved to ``query_enc``'s device) with pairwise cosine similarity.
    ``k`` is clamped to the number of embeddings.

    Yields:
        One ``(indices, query)`` tuple per query row, where:

        - ``indices`` are the positions of the highest-scoring rows of
          ``embeds`` (as returned by ``torch.topk``);
        - ``query`` is that query row with a leading dimension of size 1.
    """
    # Imported lazily: nothing runs until the generator is first advanced.
    from torchmetrics.functional import pairwise_cosine_similarity

    similarity = pairwise_cosine_similarity(
        query_enc, embeds.to(query_enc.device))
    num_neighbors = min(k, len(embeds))
    for scores, single_query in zip(similarity, query_enc):
        best = torch.topk(scores, num_neighbors, largest=True).indices
        yield best, single_query.unsqueeze(0)
190
+
191
+
192
# Adapted from LocalGraphStore
@runtime_checkable
class ConvertableGraphStore(Protocol):
    """Structural type for graph stores that can be built from raw data.

    Any class exposing these three classmethod constructors satisfies the
    protocol. Conformance can be checked at runtime with ``issubclass``.
    """
    @classmethod
    def from_data(
        cls,
        edge_id: Tensor,
        edge_index: Tensor,
        num_nodes: int,
        is_sorted: bool = False,
    ) -> GraphStore:
        """Build a graph store for a homogeneous graph."""

    @classmethod
    def from_hetero_data(
        cls,
        edge_id_dict: Dict[EdgeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        num_nodes_dict: Dict[NodeType, int],
        is_sorted: bool = False,
    ) -> GraphStore:
        """Build a graph store for a heterogeneous graph, keyed by type."""

    @classmethod
    def from_partition(cls, root: str, pid: int) -> GraphStore:
        """Build the graph store for partition ``pid`` stored under ``root``."""
218
+
219
+
220
# Adapted from LocalFeatureStore
@runtime_checkable
class ConvertableFeatureStore(Protocol):
    """Structural type for feature stores that can be built from raw data.

    Any class exposing these three classmethod constructors satisfies the
    protocol. Conformance can be checked at runtime with ``issubclass``.
    """
    @classmethod
    def from_data(
        cls,
        node_id: Tensor,
        x: Optional[Tensor] = None,
        y: Optional[Tensor] = None,
        edge_id: Optional[Tensor] = None,
        edge_attr: Optional[Tensor] = None,
    ) -> FeatureStore:
        """Build a feature store for a homogeneous graph."""

    @classmethod
    def from_hetero_data(
        cls,
        node_id_dict: Dict[NodeType, Tensor],
        x_dict: Optional[Dict[NodeType, Tensor]] = None,
        y_dict: Optional[Dict[NodeType, Tensor]] = None,
        edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None,
        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
    ) -> FeatureStore:
        """Build a feature store for a heterogeneous graph, keyed by type."""

    @classmethod
    def from_partition(cls, root: str, pid: int) -> FeatureStore:
        """Build the feature store for partition ``pid`` stored under ``root``."""
248
+
249
+
250
class RemoteDataType(Enum):
    """How a serialized RAG backend is laid out on disk."""
    # A single object read back with ``torch.load``.
    DATA = 1
    # Output of a graph partitioner, loaded one partition id at a time.
    PARTITION = 2
253
+
254
+
255
@dataclass
class RemoteGraphBackendLoader:
    """Utility class to load triplets into a RAG Backend.

    Attributes:
        path: Location of the serialized backend.

            - :obj:`RemoteDataType.DATA`: a file read with :func:`torch.load`.
            - :obj:`RemoteDataType.PARTITION`: the root passed to the
              stores' ``from_partition`` constructors.
        datatype: How the backend at ``path`` is stored.
        graph_store_type: Class used to build the graph store.
        feature_store_type: Class used to build the feature store.

    .. note::
        When this object is garbage-collected, a regular file at ``path``
        is deleted. Directories are left in place.
    """
    path: str
    datatype: RemoteDataType
    graph_store_type: Type[ConvertableGraphStore]
    feature_store_type: Type[ConvertableFeatureStore]

    def load(self, pid: Optional[int] = None) -> RemoteGraphBackend:
        """Build the feature store and graph store from :obj:`path`.

        Args:
            pid: Partition ID to load. Required for partitioned backends.

        Returns:
            A ``(feature_store, graph_store)`` tuple.

        Raises:
            AssertionError: If ``pid`` is missing for a partitioned backend.
            NotImplementedError: If ``datatype`` is not supported.
        """
        if self.datatype == RemoteDataType.DATA:
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load paths produced by trusted code.
            data_obj = torch.load(self.path, weights_only=False)
            # is_sorted=true since assume nodes come sorted from indexer
            graph_store = self.graph_store_type.from_data(
                edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index,
                num_nodes=data_obj.num_nodes, is_sorted=True)
            feature_store = self.feature_store_type.from_data(
                node_id=data_obj['node_id'], x=data_obj.x,
                edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr)
        elif self.datatype == RemoteDataType.PARTITION:
            if pid is None:
                # Explicit raise instead of ``assert`` so the check survives
                # ``python -O``. The AssertionError type is kept for callers
                # that already catch it.
                raise AssertionError(
                    "Partition ID must be defined for loading from a "
                    "partitioned store.")
            graph_store = self.graph_store_type.from_partition(self.path, pid)
            feature_store = self.feature_store_type.from_partition(
                self.path, pid)
        else:
            raise NotImplementedError
        return (feature_store, graph_store)

    def __del__(self) -> None:
        # Best-effort cleanup of the serialized backend.
        # - ``getattr`` guards against partially-constructed instances
        #   (e.g. a failed ``__init__``).
        # - Only regular files are removed: ``os.remove`` cannot delete a
        #   partition directory and would raise from the finalizer.
        path = getattr(self, "path", None)
        if path and os.path.isfile(path):
            os.remove(path)
288
+
289
+
290
def create_graph_from_triples(
    triples: Iterable[TripletLike],
    embedding_model: Union[Module, Callable],
    embedding_method_kwargs: Optional[Dict[str, Any]] = None,
    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
) -> Data:
    """Build an embedded :class:`Data` graph from triples.

    Steps:

    1. Index the triples with :class:`LargeGraphIndexer` (optionally
       applying ``pre_transform`` to each).
    2. Embed the unique node values with ``embedding_model`` and store
       them as ``x``.
    3. Embed the unique values of the ``EDGE_RELATION`` edge feature and
       map them onto the edges as ``edge_attr``.

    ``embedding_method_kwargs`` is forwarded to both embedding calls.

    Returns:
        The resulting graph, moved to the CPU.
    """
    kwargs = {} if embedding_method_kwargs is None else \
        embedding_method_kwargs

    indexer = LargeGraphIndexer.from_triplets(
        triples, pre_transform=pre_transform)

    unique_nodes = indexer.get_unique_node_features()
    indexer.add_node_feature('x', embedding_model(unique_nodes, **kwargs))

    unique_relations = indexer.get_unique_edge_features(
        feature_name=EDGE_RELATION)
    relation_embeddings = embedding_model(unique_relations, **kwargs)
    indexer.add_edge_feature(
        new_feature_name="edge_attr",
        new_feature_vals=relation_embeddings,
        map_from_feature=EDGE_RELATION,
    )

    graph = indexer.to_data(node_feature_name='x',
                            edge_feature_name='edge_attr')
    return graph.to("cpu")
318
+
319
+
320
def create_remote_backend_from_graph_data(
    graph_data: Data,
    graph_db: Type[ConvertableGraphStore] = LocalGraphStore,
    feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore,
    path: str = '',
    n_parts: int = 1,
) -> RemoteGraphBackendLoader:
    """Utility function that can be used to create a RAG Backend from a graph.

    Args:
        graph_data (Data): Graph data to load into the RAG Backend.
        graph_db (Type[ConvertableGraphStore], optional): GraphStore class to
            use. Defaults to LocalGraphStore.
        feature_db (Type[ConvertableFeatureStore], optional): FeatureStore
            class to use. Defaults to LocalFeatureStore.
        path (str, optional): Where to save the resulting stores.

            - ``n_parts == 1``: the file written with :func:`torch.save`.
            - Otherwise: the root directory given to :class:`Partitioner`.

            Defaults to ''.
        n_parts (int, optional): Number of partitions to store in.
            Defaults to 1.

    Returns:
        RemoteGraphBackendLoader: Loader that rebuilds the RAG backend from
            :obj:`path`. Its ``__del__`` tries to delete :obj:`path`, so keep
            a reference to it for as long as the stored data is needed.

    Raises:
        AttributeError: If ``graph_db`` or ``feature_db`` is missing one of
            the ``from_data`` / ``from_hetero_data`` / ``from_partition``
            constructors.
    """
    # Validate both classes independently. Touching each constructor
    # raises an AttributeError that names the missing one.
    if not issubclass(graph_db, ConvertableGraphStore):
        _ = graph_db.from_data
        _ = graph_db.from_hetero_data
        _ = graph_db.from_partition
    if not issubclass(feature_db, ConvertableFeatureStore):
        _ = feature_db.from_data
        _ = feature_db.from_hetero_data
        _ = feature_db.from_partition

    if n_parts == 1:
        torch.save(graph_data, path)
        return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db,
                                        feature_db)
    else:
        partitioner = Partitioner(data=graph_data, num_parts=n_parts,
                                  root=path)
        partitioner.generate_partition()
        return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION,
                                        graph_db, feature_db)
363
+
364
+
365
def make_pcst_filter(triples: List[Tuple[str, str, str]],
                     model: SentenceTransformer, topk: int = 5,
                     topk_e: int = 5, cost_e: float = 0.5,
                     num_clusters: int = 1) -> Callable[[Data, str], Data]:
    """Creates a PCST (Prize-Collecting Steiner Tree) filter.

    :param triples: List of triples (head, relation, tail) representing KG
        data. Duplicates are dropped (first occurrence kept). Graphs passed
        to the returned filter index this de-duplicated list through
        ``edge_idx``, and its unique head/tail entities through
        ``node_idx``.
    :param model: SentenceTransformer model used to embed the query.
    :param topk: Number of nodes given non-zero prizes (default: 5)
    :param topk_e: Number of distinct edge-similarity values given non-zero
        prizes (default: 5)
    :param cost_e: Cost of edges (default: 0.5)
    :param num_clusters: Number of connected components in the PCST output.
    :return: PCST filter function mapping ``(graph, query)`` to the
        retrieved subgraph. The subgraph is annotated with ``desc``,
        ``triples`` and ``question``.
    :raises ImportError: If pandas is not installed.
    """
    import csv
    import io

    if DataFrame is None:
        raise ImportError("PCST requires `pip install pandas`")

    # Remove duplicate triples to ensure unique set (order preserved)
    triples = list(dict.fromkeys(triples))

    # Unique head/tail entities, in order of first appearance
    full_textual_nodes = list(
        dict.fromkeys(node for h, _, t in triples for node in (h, t)))

    def apply_retrieval_via_pcst(graph: Data, query: str) -> Data:
        """Applies PCST filtering for retrieval.

        :param graph: Input graph data with ``node_idx`` / ``edge_idx``
        :param query: Search query
        :return: Retrieved graph/query data
        """
        # PCST relies on numpy and pcst_fast pypi libs, hence to("cpu")
        q_emb = model.encode([query]).to("cpu")
        textual_nodes = DataFrame(
            [(int(i), full_textual_nodes[i]) for i in graph["node_idx"]],
            columns=["node_id", "node_attr"])
        textual_edges = DataFrame(
            [triples[i] for i in graph["edge_idx"]],
            columns=["src", "edge_attr", "dst"])
        out_graph, desc = retrieval_via_pcst(graph.to(q_emb.device), q_emb,
                                             textual_nodes, textual_edges,
                                             topk=topk, topk_e=topk_e,
                                             cost_e=cost_e,
                                             num_clusters=num_clusters)
        out_graph["desc"] = desc

        # Parse the edge section of ``desc`` back into (src, edge_attr, dst)
        # tuples. The csv module (not ``str.split``) correctly handles:
        # - quoted fields, e.g. entity names containing commas;
        # - platform-specific line terminators written by
        #   ``DataFrame.to_csv``.
        parsed_trips = []
        where_trips_start = desc.find("src,edge_attr,dst")
        if where_trips_start != -1:
            reader = csv.reader(
                io.StringIO(desc[where_trips_start:], newline=''))
            next(reader, None)  # skip the header row
            parsed_trips = [tuple(row) for row in reader if row]

        # Handle case where PCST returns an isolated node (or no edges).
        # TODO find a better solution since these failed subgraphs
        # severely hurt accuracy.
        edge_index = out_graph.edge_index
        if not parsed_trips or edge_index is None or edge_index.numel() == 0:
            out_graph["triples"] = []
        else:
            out_graph["triples"] = parsed_trips
        out_graph["question"] = query
        return out_graph

    return apply_retrieval_via_pcst
@@ -0,0 +1,169 @@
1
+ import gc
2
+ from collections.abc import Iterable, Iterator
3
+ from typing import Any, Dict, List, Tuple, Union
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+ from torch_geometric.data import Data, HeteroData
9
+ from torch_geometric.distributed import LocalFeatureStore
10
+ from torch_geometric.llm.utils.backend_utils import batch_knn
11
+ from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
12
+ from torch_geometric.typing import InputNodes
13
+
14
+
15
# NOTE: Only compatible with Homogeneous graphs for now
class KNNRAGFeatureStore(LocalFeatureStore):
    """A feature store that uses a KNN-based retrieval.

    Query text is encoded with ``encoder_model`` and matched against the
    node features ``x`` with :func:`batch_knn`. The ``k_nodes`` best matches
    become the seed nodes. Both settings must be supplied via
    :attr:`config` before retrieval.
    """
    def __init__(self) -> None:
        """Initializes the feature store."""
        # to be set by the config
        self.encoder_model = None
        self.k_nodes = None
        self._config: Dict[str, Any] = {}
        super().__init__()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the config for the feature store."""
        return self._config

    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:
        """Set an attribute from the config.

        Args:
            config (Dict[str, Any]): Config dictionary
            attr_name (str): Name of attribute to set

        Raises:
            ValueError: If required attribute not found in config
        """
        if attr_name not in config:
            raise ValueError(
                f"Required config parameter '{attr_name}' not found")
        setattr(self, attr_name, config[attr_name])

    @config.setter  # type: ignore
    def config(self, config: Dict[str, Any]) -> None:
        """Set the config for the feature store.

        Args:
            config (Dict[str, Any]):
                Config dictionary containing required parameters
                (``k_nodes`` and ``encoder_model``).

        Raises:
            ValueError: If required parameters missing from config
        """
        self._set_from_config(config, "k_nodes")
        self._set_from_config(config, "encoder_model")
        assert self.encoder_model is not None, \
            "Need to define encoder model from config"
        # Put the encoder in inference mode.
        self.encoder_model.eval()

        self._config = config

    @staticmethod
    def _as_tensor(value: Any) -> Tensor:
        """Return ``value`` as a :class:`Tensor`.

        Tensors are returned unchanged. The legacy ``Tensor(other)``
        constructor only accepts tensors of the default tensor type (CPU
        float32 unless changed), so wrapping would fail for e.g.
        half-precision or GPU features. Other values keep the original
        ``Tensor(...)`` conversion.
        """
        if isinstance(value, Tensor):
            return value
        return Tensor(value)

    @property
    def x(self) -> Tensor:
        """Returns the node features."""
        return self._as_tensor(self.get_tensor(group_name=None, attr_name='x'))

    @property
    def edge_attr(self) -> Tensor:
        """Returns the edge attributes."""
        return self._as_tensor(
            self.get_tensor(group_name=(None, None), attr_name='edge_attr'))

    def retrieve_seed_nodes(  # noqa: D417
        self, query: Union[str, List[str], Tuple[str]]
    ) -> Union[Tuple[InputNodes, Tensor], Dict[Any, Tuple[InputNodes,
                                                          Tensor]]]:
        """Retrieves the k_nodes most similar nodes to the given query.

        Args:
            - query (Union[str, List[str], Tuple[str]]):
                The query or list of queries to search for.

        Returns:
            - A single query (or a one-element list/tuple): the indices of
              the most similar nodes and the encoded query.
            - Several queries: a dict mapping each query to that pair.
              Duplicate queries keep the last result.
        """
        if not isinstance(query, (list, tuple)):
            query = [query]
        assert self.k_nodes is not None, "please set k_nodes via config"
        if len(query) == 1:
            result, query_enc = next(
                self._retrieve_seed_nodes_batch(query, self.k_nodes))
            gc.collect()
            torch.cuda.empty_cache()
            return result, query_enc
        else:
            out_dict = {}
            for q, out in zip(
                    query, self._retrieve_seed_nodes_batch(query,
                                                           self.k_nodes)):
                out_dict[q] = out
            gc.collect()
            torch.cuda.empty_cache()
            return out_dict

    def _retrieve_seed_nodes_batch(  # noqa: D417
            self, query: Iterable[Any],
            k_nodes: int) -> Iterator[Tuple[InputNodes, Tensor]]:
        """Retrieves the k_nodes most similar nodes to each query in the batch.

        Args:
            - query (Iterable[Any]): The batch of queries to search for.
            - k_nodes (int): The number of nodes to retrieve.

        Yields:
            - The indices of the most similar nodes for each query, paired
              with that query's encoding.
        """
        if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
            raise NotImplementedError
        assert self.encoder_model is not None, \
            "Need to define encoder model from config"
        query_enc = self.encoder_model.encode(query)
        return batch_knn(query_enc, self.x, k_nodes)

    def load_subgraph(  # noqa
        self,
        sample: Union[SamplerOutput, HeteroSamplerOutput],
        induced: bool = True,
    ) -> Union[Data, HeteroData]:
        """Loads a subgraph from the given sample.

        Args:
            - sample: The sample to load the subgraph from.
            - induced: Whether to return the induced subgraph.
                Resets node and edge ids.

        Returns:
            - The loaded subgraph, with ``node_idx`` / ``edge_idx`` recording
              which nodes and edges were sampled.
        """
        if isinstance(sample, HeteroSamplerOutput):
            raise NotImplementedError

        # NOTE: torch_geometric.loader.utils.filter_custom_store
        # can be used here if it supported edge features.
        edge_id = sample.edge
        x = self.x[sample.node]
        edge_attr = self.edge_attr[edge_id]

        if induced:
            edge_idx = torch.stack([sample.row, sample.col], dim=0)
        else:
            edge_idx = torch.stack([sample.global_row, sample.global_col],
                                   dim=0)
        result = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)

        # useful for tracking what subset of the graph was sampled
        result.node_idx = sample.node
        result.edge_idx = edge_id

        return result
160
+
161
+
162
"""
TODO: make a class CuVSKNNRAGFeatureStore(KNNRAGFeatureStore)
that includes an approximate-KNN flag for CuVS.
Connect this with a CuGraphGraphStore
to enable an accelerated boolean flag for RAGQueryLoader.
It should be on by default if CuGraph+CuVS are available.
If they are not, emit a note mentioning the speedup they provide.
"""