pyg-nightly 2.7.0.dev20250905__py3-none-any.whl → 2.7.0.dev20250906__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/METADATA +2 -1
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/RECORD +32 -25
- torch_geometric/__init__.py +1 -1
- torch_geometric/data/__init__.py +0 -5
- torch_geometric/data/lightning/datamodule.py +2 -2
- torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
- torch_geometric/datasets/web_qsp_dataset.py +262 -210
- torch_geometric/graphgym/imports.py +2 -2
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
- torch_geometric/{nn → llm}/models/git_mol.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/backend_utils.py +442 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +124 -0
- torch_geometric/loader/__init__.py +0 -4
- torch_geometric/nn/__init__.py +0 -1
- torch_geometric/nn/models/__init__.py +0 -10
- torch_geometric/nn/models/sgformer.py +2 -0
- torch_geometric/loader/rag_loader.py +0 -107
- torch_geometric/nn/nlp/__init__.py +0 -9
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/licenses/LICENSE +0 -0
- /torch_geometric/{nn → llm}/models/glem.py +0 -0
- /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
- /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
@@ -0,0 +1,353 @@
|
|
1
|
+
import os
|
2
|
+
import time
|
3
|
+
from typing import List, Optional, Tuple
|
4
|
+
|
5
|
+
import torch
|
6
|
+
import torch.multiprocessing as mp
|
7
|
+
|
8
|
+
CLIENT_INITD = False
|
9
|
+
|
10
|
+
CLIENT = None
|
11
|
+
GLOBAL_NIM_KEY = ""
|
12
|
+
SYSTEM_PROMPT = "Please convert the above text into a list of knowledge triples with the form ('entity', 'relation', 'entity'). Separate each with a new line. Do not output anything else. Try to focus on key triples that form a connected graph." # noqa
|
13
|
+
|
14
|
+
|
15
|
+
class TXT2KG():
|
16
|
+
"""A class to convert text data into a Knowledge Graph (KG) format.
|
17
|
+
Uses NVIDIA NIMs + Prompt engineering by default.
|
18
|
+
Default model `nvidia/llama-3.1-nemotron-70b-instruct`
|
19
|
+
is on par or better than GPT4o in benchmarks.
|
20
|
+
We need a high quality model to ensure high quality KG.
|
21
|
+
Otherwise we have garbage in garbage out for the rest of the
|
22
|
+
GNN+LLM RAG pipeline.
|
23
|
+
|
24
|
+
Use local_lm flag for local debugging/dev. You still need to be able to
|
25
|
+
inference a 14B param LLM, 'VAGOsolutions/SauerkrautLM-v2-14b-DPO'.
|
26
|
+
Smaller LLMs did not work at all in testing.
|
27
|
+
Note this 14B model requires a considerable amount of GPU memory.
|
28
|
+
See examples/llm/txt2kg_rag.py for an example.
|
29
|
+
|
30
|
+
Args:
|
31
|
+
NVIDIA_NIM_MODEL : str, optional
|
32
|
+
The name of the NVIDIA NIM model to use.
|
33
|
+
(default: "nvidia/llama-3.1-nemotron-70b-instruct").
|
34
|
+
NVIDIA_API_KEY : str, optional
|
35
|
+
The API key for accessing NVIDIA's NIM models (default: "").
|
36
|
+
ENDPOINT_URL : str, optional
|
37
|
+
The URL hosting your model, in case you are not using
|
38
|
+
the public NIM.
|
39
|
+
(default: "https://integrate.api.nvidia.com/v1").
|
40
|
+
local_LM : bool, optional
|
41
|
+
A flag indicating whether a local Language Model (LM)
|
42
|
+
should be used. This uses HuggingFace and will be slower
|
43
|
+
than deploying your own private NIM endpoint. This flag
|
44
|
+
is mainly recommended for dev/debug.
|
45
|
+
(default: False).
|
46
|
+
chunk_size : int, optional
|
47
|
+
The size of the chunks in which the text data is processed
|
48
|
+
(default: 512).
|
49
|
+
"""
|
50
|
+
def __init__(
|
51
|
+
self,
|
52
|
+
NVIDIA_NIM_MODEL: Optional[
|
53
|
+
str] = "nvidia/llama-3.1-nemotron-70b-instruct",
|
54
|
+
NVIDIA_API_KEY: Optional[str] = "",
|
55
|
+
ENDPOINT_URL: Optional[str] = "https://integrate.api.nvidia.com/v1",
|
56
|
+
local_LM: bool = False,
|
57
|
+
chunk_size: int = 512,
|
58
|
+
) -> None:
|
59
|
+
self.local_LM = local_LM
|
60
|
+
# Initialize the local LM flag and the NIM model info accordingly
|
61
|
+
if self.local_LM:
|
62
|
+
# If using a local LM, set the initd_LM flag to False
|
63
|
+
self.initd_LM = False
|
64
|
+
else:
|
65
|
+
# If not using a local LM, store the provided NIM model info
|
66
|
+
self.NVIDIA_API_KEY = NVIDIA_API_KEY
|
67
|
+
self.NIM_MODEL = NVIDIA_NIM_MODEL
|
68
|
+
self.ENDPOINT_URL = ENDPOINT_URL
|
69
|
+
|
70
|
+
# Set the chunk size for processing text data
|
71
|
+
self.chunk_size = chunk_size
|
72
|
+
|
73
|
+
# Initialize counters and storage for parsing results
|
74
|
+
self.doc_id_counter = 0
|
75
|
+
self.relevant_triples = {}
|
76
|
+
self.total_chars_parsed = 0
|
77
|
+
self.time_to_parse = 0.0
|
78
|
+
|
79
|
+
def save_kg(self, path: str) -> None:
|
80
|
+
"""Saves the relevant triples in the knowledge graph (KG) to a file.
|
81
|
+
|
82
|
+
Args:
|
83
|
+
path (str): The file path where the KG will be saved.
|
84
|
+
|
85
|
+
Returns:
|
86
|
+
None
|
87
|
+
"""
|
88
|
+
torch.save(self.relevant_triples, path)
|
89
|
+
|
90
|
+
def _chunk_to_triples_str_local(self, txt: str) -> str:
|
91
|
+
# call LLM on text
|
92
|
+
chunk_start_time = time.time()
|
93
|
+
if not self.initd_LM:
|
94
|
+
from torch_geometric.nn.nlp import LLM
|
95
|
+
LM_name = "VAGOsolutions/SauerkrautLM-v2-14b-DPO"
|
96
|
+
self.model = LLM(LM_name).eval()
|
97
|
+
self.initd_LM = True
|
98
|
+
out_str = self.model.inference(question=[txt + '\n' + SYSTEM_PROMPT],
|
99
|
+
max_tokens=self.chunk_size)[0]
|
100
|
+
# for debug
|
101
|
+
self.total_chars_parsed += len(txt)
|
102
|
+
self.time_to_parse += round(time.time() - chunk_start_time, 2)
|
103
|
+
self.avg_chars_parsed_per_sec = self.total_chars_parsed / self.time_to_parse # noqa
|
104
|
+
return out_str
|
105
|
+
|
106
|
+
def add_doc_2_KG(
|
107
|
+
self,
|
108
|
+
txt: str,
|
109
|
+
QA_pair: Optional[Tuple[str, str]] = None,
|
110
|
+
) -> None:
|
111
|
+
"""Add a document to the Knowledge Graph (KG).
|
112
|
+
|
113
|
+
Args:
|
114
|
+
txt (str): The text to extract triples from.
|
115
|
+
QA_pair (Tuple[str, str]], optional):
|
116
|
+
A QA pair to associate with the extracted triples.
|
117
|
+
Useful for downstream evaluation.
|
118
|
+
|
119
|
+
Returns:
|
120
|
+
- None
|
121
|
+
"""
|
122
|
+
if not self.local_LM:
|
123
|
+
# Ensure NVIDIA_API_KEY is set before proceeding
|
124
|
+
assert self.NVIDIA_API_KEY != '', \
|
125
|
+
"Please init TXT2KG w/ NVIDIA_API_KEY or set local_lm=True"
|
126
|
+
if QA_pair:
|
127
|
+
# QA_pairs should be unique keys, check if already exists in KG
|
128
|
+
if QA_pair in self.relevant_triples.keys():
|
129
|
+
print("Warning: QA_Pair was already added to the set")
|
130
|
+
print("Q=", QA_pair[0])
|
131
|
+
print("A=", QA_pair[1])
|
132
|
+
print("Previously parsed triples=",
|
133
|
+
self.relevant_triples[QA_pair])
|
134
|
+
print("Skipping...")
|
135
|
+
key = QA_pair
|
136
|
+
else:
|
137
|
+
# If no QA_pair, use the current doc_id_counter as the key
|
138
|
+
key = self.doc_id_counter
|
139
|
+
|
140
|
+
# Handle empty text (context-less QA pairs)
|
141
|
+
if txt == "":
|
142
|
+
self.relevant_triples[key] = []
|
143
|
+
else:
|
144
|
+
# Chunk the text into smaller pieces for processing
|
145
|
+
chunks = _chunk_text(txt, chunk_size=self.chunk_size)
|
146
|
+
|
147
|
+
if self.local_LM:
|
148
|
+
# For debugging purposes...
|
149
|
+
# process chunks sequentially on the local LM
|
150
|
+
self.relevant_triples[key] = _llm_then_python_parse(
|
151
|
+
chunks, _parse_n_check_triples,
|
152
|
+
self._chunk_to_triples_str_local)
|
153
|
+
else:
|
154
|
+
# Process chunks in parallel using multiple processes
|
155
|
+
num_procs = min(len(chunks), _get_num_procs())
|
156
|
+
meta_chunk_size = int(len(chunks) / num_procs)
|
157
|
+
in_chunks_per_proc = {
|
158
|
+
j:
|
159
|
+
chunks[j *
|
160
|
+
meta_chunk_size:min((j + 1) *
|
161
|
+
meta_chunk_size, len(chunks))]
|
162
|
+
for j in range(num_procs)
|
163
|
+
}
|
164
|
+
for _retry_j in range(5):
|
165
|
+
try:
|
166
|
+
for _retry_i in range(200):
|
167
|
+
try:
|
168
|
+
# Spawn multiple processes
|
169
|
+
# process chunks in parallel
|
170
|
+
mp.spawn(
|
171
|
+
_multiproc_helper,
|
172
|
+
args=(in_chunks_per_proc,
|
173
|
+
_parse_n_check_triples,
|
174
|
+
_chunk_to_triples_str_cloud,
|
175
|
+
self.NVIDIA_API_KEY, self.NIM_MODEL,
|
176
|
+
self.ENDPOINT_URL), nprocs=num_procs)
|
177
|
+
break
|
178
|
+
except: # noqa
|
179
|
+
# keep retrying...
|
180
|
+
# txt2kg is costly -> stoppage is costly
|
181
|
+
pass
|
182
|
+
|
183
|
+
# Collect the results from each process
|
184
|
+
self.relevant_triples[key] = []
|
185
|
+
for rank in range(num_procs):
|
186
|
+
self.relevant_triples[key] += torch.load(
|
187
|
+
"/tmp/outs_for_proc_" + str(rank))
|
188
|
+
os.remove("/tmp/outs_for_proc_" + str(rank))
|
189
|
+
break
|
190
|
+
except: # noqa
|
191
|
+
pass
|
192
|
+
# Increment the doc_id_counter for the next document
|
193
|
+
self.doc_id_counter += 1
|
194
|
+
|
195
|
+
|
196
|
+
known_reasoners = [
|
197
|
+
"llama-3.1-nemotron-ultra-253b-v1",
|
198
|
+
"kimi-k2-instruct",
|
199
|
+
"nemotron-super-49b-v1_5",
|
200
|
+
"gpt-oss",
|
201
|
+
]
|
202
|
+
|
203
|
+
|
204
|
+
def _chunk_to_triples_str_cloud(
|
205
|
+
txt: str, GLOBAL_NIM_KEY='',
|
206
|
+
NIM_MODEL="nvidia/llama-3.1-nemotron-ultra-253b-v1",
|
207
|
+
ENDPOINT_URL="https://integrate.api.nvidia.com/v1",
|
208
|
+
post_text=SYSTEM_PROMPT) -> str:
|
209
|
+
global CLIENT_INITD
|
210
|
+
if not CLIENT_INITD:
|
211
|
+
# We use NIMs since most PyG users may not be able to run a 70B+ model
|
212
|
+
try:
|
213
|
+
from openai import OpenAI
|
214
|
+
except ImportError:
|
215
|
+
quit(
|
216
|
+
"Failed to import `openai` package, please install it and rerun the script" # noqa
|
217
|
+
)
|
218
|
+
global CLIENT
|
219
|
+
CLIENT = OpenAI(base_url=ENDPOINT_URL, api_key=GLOBAL_NIM_KEY)
|
220
|
+
CLIENT_INITD = True
|
221
|
+
txt_input = txt
|
222
|
+
if post_text != "":
|
223
|
+
txt_input += '\n' + post_text
|
224
|
+
messages = []
|
225
|
+
if any([model_name_str in NIM_MODEL
|
226
|
+
for model_name_str in known_reasoners]):
|
227
|
+
messages.append({"role": "system", "content": "detailed thinking on"})
|
228
|
+
messages.append({"role": "user", "content": txt_input})
|
229
|
+
completion = CLIENT.chat.completions.create(model=NIM_MODEL,
|
230
|
+
messages=messages,
|
231
|
+
temperature=0, top_p=1,
|
232
|
+
max_tokens=1024, stream=True)
|
233
|
+
out_str = ""
|
234
|
+
for chunk in completion:
|
235
|
+
if chunk.choices[0].delta.content is not None:
|
236
|
+
out_str += chunk.choices[0].delta.content
|
237
|
+
return out_str
|
238
|
+
|
239
|
+
|
240
|
+
def _parse_n_check_triples(triples_str: str) -> List[Tuple[str, str, str]]:
|
241
|
+
# use pythonic checks for triples
|
242
|
+
processed = []
|
243
|
+
split_by_newline = triples_str.split("\n")
|
244
|
+
# sometimes LLM fails to obey the prompt
|
245
|
+
if len(split_by_newline) > 1:
|
246
|
+
split_triples = split_by_newline
|
247
|
+
llm_obeyed = True
|
248
|
+
else:
|
249
|
+
# handles form "(e, r, e) (e, r, e) ... (e, r, e)""
|
250
|
+
split_triples = triples_str[1:-1].split(") (")
|
251
|
+
llm_obeyed = False
|
252
|
+
for triple_str in split_triples:
|
253
|
+
try:
|
254
|
+
if llm_obeyed:
|
255
|
+
# remove parenthesis and single quotes for parsing
|
256
|
+
triple_str = triple_str.replace("(", "").replace(")",
|
257
|
+
"").replace(
|
258
|
+
"'", "")
|
259
|
+
split_trip = triple_str.split(',')
|
260
|
+
# remove blank space at beginning or end
|
261
|
+
split_trip = [(i[1:] if i[0] == " " else i) for i in split_trip]
|
262
|
+
split_trip = [(i[:-1].lower() if i[-1] == " " else i)
|
263
|
+
for i in split_trip]
|
264
|
+
potential_trip = tuple(split_trip)
|
265
|
+
except: # noqa
|
266
|
+
continue
|
267
|
+
if 'tuple' in str(type(potential_trip)) and len(
|
268
|
+
potential_trip
|
269
|
+
) == 3 and "note:" not in potential_trip[0].lower():
|
270
|
+
# additional check for empty node/edge attrs
|
271
|
+
if potential_trip[0] != '' and potential_trip[
|
272
|
+
1] != '' and potential_trip[2] != '':
|
273
|
+
processed.append(potential_trip)
|
274
|
+
return processed
|
275
|
+
|
276
|
+
|
277
|
+
def _llm_then_python_parse(chunks, py_fn, llm_fn, **kwargs):
|
278
|
+
relevant_triples = []
|
279
|
+
for chunk in chunks:
|
280
|
+
relevant_triples += py_fn(llm_fn(chunk, **kwargs))
|
281
|
+
return relevant_triples
|
282
|
+
|
283
|
+
|
284
|
+
def _multiproc_helper(rank, in_chunks_per_proc, py_fn, llm_fn, NIM_KEY,
|
285
|
+
NIM_MODEL, ENDPOINT_URL):
|
286
|
+
out = _llm_then_python_parse(in_chunks_per_proc[rank], py_fn, llm_fn,
|
287
|
+
GLOBAL_NIM_KEY=NIM_KEY, NIM_MODEL=NIM_MODEL,
|
288
|
+
ENDPOINT_URL=ENDPOINT_URL)
|
289
|
+
torch.save(out, "/tmp/outs_for_proc_" + str(rank))
|
290
|
+
|
291
|
+
|
292
|
+
def _get_num_procs():
|
293
|
+
if hasattr(os, "sched_getaffinity"):
|
294
|
+
try:
|
295
|
+
num_proc = len(os.sched_getaffinity(0)) / (2)
|
296
|
+
except Exception:
|
297
|
+
pass
|
298
|
+
if num_proc is None:
|
299
|
+
num_proc = os.cpu_count() / (2)
|
300
|
+
return int(num_proc)
|
301
|
+
|
302
|
+
|
303
|
+
def _chunk_text(text: str, chunk_size: int = 512) -> list[str]:
|
304
|
+
"""Function to chunk text into sentence-based segments.
|
305
|
+
Co-authored with Claude AI.
|
306
|
+
"""
|
307
|
+
# If the input text is empty or None, return an empty list
|
308
|
+
if not text:
|
309
|
+
return []
|
310
|
+
|
311
|
+
# List of punctuation marks that typically end sentences
|
312
|
+
sentence_endings = '.!?'
|
313
|
+
|
314
|
+
# List to store the resulting chunks
|
315
|
+
chunks = []
|
316
|
+
|
317
|
+
# Continue processing the entire text
|
318
|
+
while text:
|
319
|
+
# If the remaining text is shorter than chunk_size, add it and break
|
320
|
+
if len(text) <= chunk_size:
|
321
|
+
chunks.append(text.strip())
|
322
|
+
break
|
323
|
+
|
324
|
+
# Start with the maximum possible chunk
|
325
|
+
chunk = text[:chunk_size]
|
326
|
+
|
327
|
+
# Try to find the last sentence ending within the chunk
|
328
|
+
best_split = chunk_size
|
329
|
+
for ending in sentence_endings:
|
330
|
+
# Find the last occurrence of the ending punctuation
|
331
|
+
last_ending = chunk.rfind(ending)
|
332
|
+
if last_ending != -1:
|
333
|
+
# Ensure we include the punctuation and any following space
|
334
|
+
best_split = min(
|
335
|
+
best_split, last_ending + 1 +
|
336
|
+
(1 if last_ending + 1 < len(chunk)
|
337
|
+
and chunk[last_ending + 1].isspace() else 0))
|
338
|
+
|
339
|
+
# Adjust to ensure we don't break words
|
340
|
+
# If the next character is a letter, find the last space
|
341
|
+
if best_split < len(text) and text[best_split].isalpha():
|
342
|
+
# Find the last space before the current split point
|
343
|
+
space_split = text[:best_split].rfind(' ')
|
344
|
+
if space_split != -1:
|
345
|
+
best_split = space_split
|
346
|
+
|
347
|
+
# Append the chunk, ensuring it's stripped
|
348
|
+
chunks.append(text[:best_split].strip())
|
349
|
+
|
350
|
+
# Remove the processed part from the text
|
351
|
+
text = text[best_split:].lstrip()
|
352
|
+
|
353
|
+
return chunks
|
@@ -0,0 +1,154 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
|
3
|
+
|
4
|
+
from torch_geometric.data import Data, FeatureStore, HeteroData
|
5
|
+
from torch_geometric.llm.utils.vectorrag import VectorRetriever
|
6
|
+
from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
|
7
|
+
from torch_geometric.typing import InputEdges, InputNodes
|
8
|
+
|
9
|
+
|
10
|
+
class RAGFeatureStore(Protocol):
|
11
|
+
"""Feature store template for remote GNN RAG backend."""
|
12
|
+
@abstractmethod
|
13
|
+
def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
|
14
|
+
"""Makes a comparison between the query and all the nodes to get all
|
15
|
+
the closest nodes. Return the indices of the nodes that are to be seeds
|
16
|
+
for the RAG Sampler.
|
17
|
+
"""
|
18
|
+
...
|
19
|
+
|
20
|
+
@property
|
21
|
+
@abstractmethod
|
22
|
+
def config(self) -> Dict[str, Any]:
|
23
|
+
"""Get the config for the RAGFeatureStore."""
|
24
|
+
...
|
25
|
+
|
26
|
+
@config.setter
|
27
|
+
@abstractmethod
|
28
|
+
def config(self, config: Dict[str, Any]):
|
29
|
+
"""Set the config for the RAGFeatureStore."""
|
30
|
+
...
|
31
|
+
|
32
|
+
@abstractmethod
|
33
|
+
def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
|
34
|
+
"""Makes a comparison between the query and all the edges to get all
|
35
|
+
the closest nodes. Returns the edge indices that are to be the seeds
|
36
|
+
for the RAG Sampler.
|
37
|
+
"""
|
38
|
+
...
|
39
|
+
|
40
|
+
@abstractmethod
|
41
|
+
def load_subgraph(
|
42
|
+
self, sample: Union[SamplerOutput, HeteroSamplerOutput]
|
43
|
+
) -> Union[Data, HeteroData]:
|
44
|
+
"""Combines sampled subgraph output with features in a Data object."""
|
45
|
+
...
|
46
|
+
|
47
|
+
|
48
|
+
class RAGGraphStore(Protocol):
|
49
|
+
"""Graph store template for remote GNN RAG backend."""
|
50
|
+
@abstractmethod
|
51
|
+
def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
|
52
|
+
**kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
|
53
|
+
"""Sample a subgraph using the seeded nodes and edges."""
|
54
|
+
...
|
55
|
+
|
56
|
+
@property
|
57
|
+
@abstractmethod
|
58
|
+
def config(self) -> Dict[str, Any]:
|
59
|
+
"""Get the config for the RAGGraphStore."""
|
60
|
+
...
|
61
|
+
|
62
|
+
@config.setter
|
63
|
+
@abstractmethod
|
64
|
+
def config(self, config: Dict[str, Any]):
|
65
|
+
"""Set the config for the RAGGraphStore."""
|
66
|
+
...
|
67
|
+
|
68
|
+
@abstractmethod
|
69
|
+
def register_feature_store(self, feature_store: FeatureStore):
|
70
|
+
"""Register a feature store to be used with the sampler. Samplers need
|
71
|
+
info from the feature store in order to work properly on HeteroGraphs.
|
72
|
+
"""
|
73
|
+
...
|
74
|
+
|
75
|
+
|
76
|
+
# TODO: Make compatible with Heterographs
|
77
|
+
|
78
|
+
|
79
|
+
class RAGQueryLoader:
|
80
|
+
"""Loader meant for making RAG queries from a remote backend."""
|
81
|
+
def __init__(self, graph_data: Tuple[RAGFeatureStore, RAGGraphStore],
|
82
|
+
subgraph_filter: Optional[Callable[[Data, Any], Data]] = None,
|
83
|
+
augment_query: bool = False,
|
84
|
+
vector_retriever: Optional[VectorRetriever] = None,
|
85
|
+
config: Optional[Dict[str, Any]] = None):
|
86
|
+
"""Loader meant for making queries from a remote backend.
|
87
|
+
|
88
|
+
Args:
|
89
|
+
graph_data (Tuple[RAGFeatureStore, RAGGraphStore]):
|
90
|
+
Remote FeatureStore and GraphStore to load from.
|
91
|
+
Assumed to conform to the protocols listed above.
|
92
|
+
subgraph_filter (Optional[Callable[[Data, Any], Data]], optional):
|
93
|
+
Optional local transform to apply to data after retrieval.
|
94
|
+
Defaults to None.
|
95
|
+
augment_query (bool, optional): Whether to augment the query with
|
96
|
+
retrieved documents. Defaults to False.
|
97
|
+
vector_retriever (Optional[VectorRetriever], optional):
|
98
|
+
VectorRetriever to use for retrieving documents.
|
99
|
+
Defaults to None.
|
100
|
+
config (Optional[Dict[str, Any]], optional): Config to pass into
|
101
|
+
the RAGQueryLoader. Defaults to None.
|
102
|
+
"""
|
103
|
+
fstore, gstore = graph_data
|
104
|
+
self.vector_retriever = vector_retriever
|
105
|
+
self.augment_query = augment_query
|
106
|
+
self.feature_store = fstore
|
107
|
+
self.graph_store = gstore
|
108
|
+
self.graph_store.edge_index = self.graph_store.edge_index.contiguous()
|
109
|
+
self.graph_store.register_feature_store(self.feature_store)
|
110
|
+
self.subgraph_filter = subgraph_filter
|
111
|
+
self.config = config
|
112
|
+
|
113
|
+
def _propagate_config(self, config: Dict[str, Any]):
|
114
|
+
"""Propagate the config the relevant components."""
|
115
|
+
self.feature_store.config = config
|
116
|
+
self.graph_store.config = config
|
117
|
+
|
118
|
+
@property
|
119
|
+
def config(self):
|
120
|
+
"""Get the config for the RAGQueryLoader."""
|
121
|
+
return self._config
|
122
|
+
|
123
|
+
@config.setter
|
124
|
+
def config(self, config: Dict[str, Any]):
|
125
|
+
"""Set the config for the RAGQueryLoader.
|
126
|
+
|
127
|
+
Args:
|
128
|
+
config (Dict[str, Any]): The config to set.
|
129
|
+
"""
|
130
|
+
self._propagate_config(config)
|
131
|
+
self._config = config
|
132
|
+
|
133
|
+
def query(self, query: Any) -> Data:
|
134
|
+
"""Retrieve a subgraph associated with the query with all its feature
|
135
|
+
attributes.
|
136
|
+
"""
|
137
|
+
if self.vector_retriever:
|
138
|
+
retrieved_docs = self.vector_retriever.query(query)
|
139
|
+
|
140
|
+
if self.augment_query:
|
141
|
+
query = [query] + retrieved_docs
|
142
|
+
|
143
|
+
seed_nodes, query_enc = self.feature_store.retrieve_seed_nodes(query)
|
144
|
+
|
145
|
+
subgraph_sample = self.graph_store.sample_subgraph(seed_nodes)
|
146
|
+
|
147
|
+
data = self.feature_store.load_subgraph(sample=subgraph_sample)
|
148
|
+
|
149
|
+
# apply local filter
|
150
|
+
if self.subgraph_filter:
|
151
|
+
data = self.subgraph_filter(data, query)
|
152
|
+
if self.vector_retriever:
|
153
|
+
data.text_context = retrieved_docs
|
154
|
+
return data
|