pyg-nightly 2.7.0.dev20250904__py3-none-any.whl → 2.7.0.dev20250906__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/METADATA +2 -1
  2. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/RECORD +34 -27
  3. torch_geometric/__init__.py +1 -1
  4. torch_geometric/data/__init__.py +0 -5
  5. torch_geometric/data/lightning/datamodule.py +2 -2
  6. torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
  7. torch_geometric/datasets/web_qsp_dataset.py +262 -210
  8. torch_geometric/graphgym/imports.py +2 -2
  9. torch_geometric/llm/__init__.py +9 -0
  10. torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
  11. torch_geometric/llm/models/__init__.py +23 -0
  12. torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
  13. torch_geometric/{nn → llm}/models/git_mol.py +1 -1
  14. torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
  15. torch_geometric/llm/models/llm_judge.py +158 -0
  16. torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
  17. torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
  18. torch_geometric/llm/models/txt2kg.py +353 -0
  19. torch_geometric/llm/rag_loader.py +154 -0
  20. torch_geometric/llm/utils/backend_utils.py +442 -0
  21. torch_geometric/llm/utils/feature_store.py +169 -0
  22. torch_geometric/llm/utils/graph_store.py +199 -0
  23. torch_geometric/llm/utils/vectorrag.py +124 -0
  24. torch_geometric/loader/__init__.py +0 -4
  25. torch_geometric/metrics/link_pred.py +13 -2
  26. torch_geometric/nn/__init__.py +0 -1
  27. torch_geometric/nn/models/__init__.py +0 -10
  28. torch_geometric/nn/models/sgformer.py +2 -0
  29. torch_geometric/utils/cross_entropy.py +34 -13
  30. torch_geometric/loader/rag_loader.py +0 -107
  31. torch_geometric/nn/nlp/__init__.py +0 -9
  32. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/WHEEL +0 -0
  33. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/licenses/LICENSE +0 -0
  34. /torch_geometric/{nn → llm}/models/glem.py +0 -0
  35. /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
  36. /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
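
The listing above amounts to a reorganization: the LLM tooling previously spread across torch_geometric.nn.nlp, torch_geometric.nn.models, torch_geometric.data, and torch_geometric.loader now lives under a new torch_geometric.llm package. A minimal sketch of the corresponding import paths follows; only the module paths come from the file list above, while the class names (LLM, SentenceTransformer, GRetriever, RAGQueryLoader) are assumed to be unchanged by the move, and any re-exports added in the new __init__.py files are not visible in this diff.

# Sketch of the new import locations implied by the file moves above.
# Old (2.7.0.dev20250904):
#   from torch_geometric.nn.nlp import LLM, SentenceTransformer
#   from torch_geometric.nn.models import GRetriever
#   from torch_geometric.loader import RAGQueryLoader
# New (2.7.0.dev20250906), importing from the moved modules directly:
from torch_geometric.llm.models.llm import LLM
from torch_geometric.llm.models.sentence_transformer import SentenceTransformer
from torch_geometric.llm.models.g_retriever import GRetriever
from torch_geometric.llm.rag_loader import RAGQueryLoader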
torch_geometric/llm/models/txt2kg.py (new file)
@@ -0,0 +1,353 @@
+ import os
+ import time
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.multiprocessing as mp
+
+ CLIENT_INITD = False
+
+ CLIENT = None
+ GLOBAL_NIM_KEY = ""
+ SYSTEM_PROMPT = "Please convert the above text into a list of knowledge triples with the form ('entity', 'relation', 'entity'). Separate each with a new line. Do not output anything else. Try to focus on key triples that form a connected graph."  # noqa
+
+
+ class TXT2KG():
+     """A class to convert text data into a Knowledge Graph (KG) format.
+     Uses NVIDIA NIMs + Prompt engineering by default.
+     Default model `nvidia/llama-3.1-nemotron-70b-instruct`
+     is on par or better than GPT4o in benchmarks.
+     We need a high quality model to ensure high quality KG.
+     Otherwise we have garbage in garbage out for the rest of the
+     GNN+LLM RAG pipeline.
+
+     Use local_lm flag for local debugging/dev. You still need to be able to
+     inference a 14B param LLM, 'VAGOsolutions/SauerkrautLM-v2-14b-DPO'.
+     Smaller LLMs did not work at all in testing.
+     Note this 14B model requires a considerable amount of GPU memory.
+     See examples/llm/txt2kg_rag.py for an example.
+
+     Args:
+         NVIDIA_NIM_MODEL : str, optional
+             The name of the NVIDIA NIM model to use.
+             (default: "nvidia/llama-3.1-nemotron-70b-instruct").
+         NVIDIA_API_KEY : str, optional
+             The API key for accessing NVIDIA's NIM models (default: "").
+         ENDPOINT_URL : str, optional
+             The URL hosting your model, in case you are not using
+             the public NIM.
+             (default: "https://integrate.api.nvidia.com/v1").
+         local_LM : bool, optional
+             A flag indicating whether a local Language Model (LM)
+             should be used. This uses HuggingFace and will be slower
+             than deploying your own private NIM endpoint. This flag
+             is mainly recommended for dev/debug.
+             (default: False).
+         chunk_size : int, optional
+             The size of the chunks in which the text data is processed
+             (default: 512).
+     """
+     def __init__(
+         self,
+         NVIDIA_NIM_MODEL: Optional[
+             str] = "nvidia/llama-3.1-nemotron-70b-instruct",
+         NVIDIA_API_KEY: Optional[str] = "",
+         ENDPOINT_URL: Optional[str] = "https://integrate.api.nvidia.com/v1",
+         local_LM: bool = False,
+         chunk_size: int = 512,
+     ) -> None:
+         self.local_LM = local_LM
+         # Initialize the local LM flag and the NIM model info accordingly
+         if self.local_LM:
+             # If using a local LM, set the initd_LM flag to False
+             self.initd_LM = False
+         else:
+             # If not using a local LM, store the provided NIM model info
+             self.NVIDIA_API_KEY = NVIDIA_API_KEY
+             self.NIM_MODEL = NVIDIA_NIM_MODEL
+             self.ENDPOINT_URL = ENDPOINT_URL
+
+         # Set the chunk size for processing text data
+         self.chunk_size = chunk_size
+
+         # Initialize counters and storage for parsing results
+         self.doc_id_counter = 0
+         self.relevant_triples = {}
+         self.total_chars_parsed = 0
+         self.time_to_parse = 0.0
+
+     def save_kg(self, path: str) -> None:
+         """Saves the relevant triples in the knowledge graph (KG) to a file.
+
+         Args:
+             path (str): The file path where the KG will be saved.
+
+         Returns:
+             None
+         """
+         torch.save(self.relevant_triples, path)
+
+     def _chunk_to_triples_str_local(self, txt: str) -> str:
+         # call LLM on text
+         chunk_start_time = time.time()
+         if not self.initd_LM:
+             from torch_geometric.nn.nlp import LLM
+             LM_name = "VAGOsolutions/SauerkrautLM-v2-14b-DPO"
+             self.model = LLM(LM_name).eval()
+             self.initd_LM = True
+         out_str = self.model.inference(question=[txt + '\n' + SYSTEM_PROMPT],
+                                        max_tokens=self.chunk_size)[0]
+         # for debug
+         self.total_chars_parsed += len(txt)
+         self.time_to_parse += round(time.time() - chunk_start_time, 2)
+         self.avg_chars_parsed_per_sec = self.total_chars_parsed / self.time_to_parse  # noqa
+         return out_str
+
+     def add_doc_2_KG(
+         self,
+         txt: str,
+         QA_pair: Optional[Tuple[str, str]] = None,
+     ) -> None:
+         """Add a document to the Knowledge Graph (KG).
+
+         Args:
+             txt (str): The text to extract triples from.
+             QA_pair (Tuple[str, str], optional):
+                 A QA pair to associate with the extracted triples.
+                 Useful for downstream evaluation.
+
+         Returns:
+             None
+         """
+         if not self.local_LM:
+             # Ensure NVIDIA_API_KEY is set before proceeding
+             assert self.NVIDIA_API_KEY != '', \
+                 "Please init TXT2KG w/ NVIDIA_API_KEY or set local_lm=True"
+         if QA_pair:
+             # QA_pairs should be unique keys, check if already exists in KG
+             if QA_pair in self.relevant_triples.keys():
+                 print("Warning: QA_Pair was already added to the set")
+                 print("Q=", QA_pair[0])
+                 print("A=", QA_pair[1])
+                 print("Previously parsed triples=",
+                       self.relevant_triples[QA_pair])
+                 print("Skipping...")
+             key = QA_pair
+         else:
+             # If no QA_pair, use the current doc_id_counter as the key
+             key = self.doc_id_counter
+
+         # Handle empty text (context-less QA pairs)
+         if txt == "":
+             self.relevant_triples[key] = []
+         else:
+             # Chunk the text into smaller pieces for processing
+             chunks = _chunk_text(txt, chunk_size=self.chunk_size)
+
+             if self.local_LM:
+                 # For debugging purposes...
+                 # process chunks sequentially on the local LM
+                 self.relevant_triples[key] = _llm_then_python_parse(
+                     chunks, _parse_n_check_triples,
+                     self._chunk_to_triples_str_local)
+             else:
+                 # Process chunks in parallel using multiple processes
+                 num_procs = min(len(chunks), _get_num_procs())
+                 meta_chunk_size = int(len(chunks) / num_procs)
+                 in_chunks_per_proc = {
+                     j:
+                     chunks[j *
+                            meta_chunk_size:min((j + 1) *
+                                                meta_chunk_size, len(chunks))]
+                     for j in range(num_procs)
+                 }
+                 for _retry_j in range(5):
+                     try:
+                         for _retry_i in range(200):
+                             try:
+                                 # Spawn multiple processes to
+                                 # process chunks in parallel
+                                 mp.spawn(
+                                     _multiproc_helper,
+                                     args=(in_chunks_per_proc,
+                                           _parse_n_check_triples,
+                                           _chunk_to_triples_str_cloud,
+                                           self.NVIDIA_API_KEY, self.NIM_MODEL,
+                                           self.ENDPOINT_URL), nprocs=num_procs)
+                                 break
+                             except:  # noqa
+                                 # keep retrying...
+                                 # txt2kg is costly -> stoppage is costly
+                                 pass
+
+                         # Collect the results from each process
+                         self.relevant_triples[key] = []
+                         for rank in range(num_procs):
+                             self.relevant_triples[key] += torch.load(
+                                 "/tmp/outs_for_proc_" + str(rank))
+                             os.remove("/tmp/outs_for_proc_" + str(rank))
+                         break
+                     except:  # noqa
+                         pass
+         # Increment the doc_id_counter for the next document
+         self.doc_id_counter += 1
+
+
+ known_reasoners = [
+     "llama-3.1-nemotron-ultra-253b-v1",
+     "kimi-k2-instruct",
+     "nemotron-super-49b-v1_5",
+     "gpt-oss",
+ ]
+
+
+ def _chunk_to_triples_str_cloud(
+         txt: str, GLOBAL_NIM_KEY='',
+         NIM_MODEL="nvidia/llama-3.1-nemotron-ultra-253b-v1",
+         ENDPOINT_URL="https://integrate.api.nvidia.com/v1",
+         post_text=SYSTEM_PROMPT) -> str:
+     global CLIENT_INITD
+     if not CLIENT_INITD:
+         # We use NIMs since most PyG users may not be able to run a 70B+ model
+         try:
+             from openai import OpenAI
+         except ImportError:
+             quit(
+                 "Failed to import `openai` package, please install it and rerun the script"  # noqa
+             )
+         global CLIENT
+         CLIENT = OpenAI(base_url=ENDPOINT_URL, api_key=GLOBAL_NIM_KEY)
+         CLIENT_INITD = True
+     txt_input = txt
+     if post_text != "":
+         txt_input += '\n' + post_text
+     messages = []
+     if any([model_name_str in NIM_MODEL
+             for model_name_str in known_reasoners]):
+         messages.append({"role": "system", "content": "detailed thinking on"})
+     messages.append({"role": "user", "content": txt_input})
+     completion = CLIENT.chat.completions.create(model=NIM_MODEL,
+                                                 messages=messages,
+                                                 temperature=0, top_p=1,
+                                                 max_tokens=1024, stream=True)
+     out_str = ""
+     for chunk in completion:
+         if chunk.choices[0].delta.content is not None:
+             out_str += chunk.choices[0].delta.content
+     return out_str
+
+
+ def _parse_n_check_triples(triples_str: str) -> List[Tuple[str, str, str]]:
+     # use pythonic checks for triples
+     processed = []
+     split_by_newline = triples_str.split("\n")
+     # sometimes LLM fails to obey the prompt
+     if len(split_by_newline) > 1:
+         split_triples = split_by_newline
+         llm_obeyed = True
+     else:
+         # handles form "(e, r, e) (e, r, e) ... (e, r, e)"
+         split_triples = triples_str[1:-1].split(") (")
+         llm_obeyed = False
+     for triple_str in split_triples:
+         try:
+             if llm_obeyed:
+                 # remove parenthesis and single quotes for parsing
+                 triple_str = triple_str.replace("(", "").replace(")",
+                                                                  "").replace(
+                                                                      "'", "")
+             split_trip = triple_str.split(',')
+             # remove blank space at beginning or end
+             split_trip = [(i[1:] if i[0] == " " else i) for i in split_trip]
+             split_trip = [(i[:-1].lower() if i[-1] == " " else i)
+                           for i in split_trip]
+             potential_trip = tuple(split_trip)
+         except:  # noqa
+             continue
+         if 'tuple' in str(type(potential_trip)) and len(
+                 potential_trip
+         ) == 3 and "note:" not in potential_trip[0].lower():
+             # additional check for empty node/edge attrs
+             if potential_trip[0] != '' and potential_trip[
+                     1] != '' and potential_trip[2] != '':
+                 processed.append(potential_trip)
+     return processed
+
+
+ def _llm_then_python_parse(chunks, py_fn, llm_fn, **kwargs):
+     relevant_triples = []
+     for chunk in chunks:
+         relevant_triples += py_fn(llm_fn(chunk, **kwargs))
+     return relevant_triples
+
+
+ def _multiproc_helper(rank, in_chunks_per_proc, py_fn, llm_fn, NIM_KEY,
+                       NIM_MODEL, ENDPOINT_URL):
+     out = _llm_then_python_parse(in_chunks_per_proc[rank], py_fn, llm_fn,
+                                  GLOBAL_NIM_KEY=NIM_KEY, NIM_MODEL=NIM_MODEL,
+                                  ENDPOINT_URL=ENDPOINT_URL)
+     torch.save(out, "/tmp/outs_for_proc_" + str(rank))
+
+
+ def _get_num_procs():
+     num_proc = None  # fall back to os.cpu_count() if affinity is unavailable
+     if hasattr(os, "sched_getaffinity"):
+         try:
+             num_proc = len(os.sched_getaffinity(0)) / (2)
+         except Exception:
+             pass
+     if num_proc is None:
+         num_proc = os.cpu_count() / (2)
+     return int(num_proc)
+
+
+ def _chunk_text(text: str, chunk_size: int = 512) -> list[str]:
+     """Function to chunk text into sentence-based segments.
+     Co-authored with Claude AI.
+     """
+     # If the input text is empty or None, return an empty list
+     if not text:
+         return []
+
+     # List of punctuation marks that typically end sentences
+     sentence_endings = '.!?'
+
+     # List to store the resulting chunks
+     chunks = []
+
+     # Continue processing the entire text
+     while text:
+         # If the remaining text is shorter than chunk_size, add it and break
+         if len(text) <= chunk_size:
+             chunks.append(text.strip())
+             break
+
+         # Start with the maximum possible chunk
+         chunk = text[:chunk_size]
+
+         # Try to find the last sentence ending within the chunk
+         best_split = chunk_size
+         for ending in sentence_endings:
+             # Find the last occurrence of the ending punctuation
+             last_ending = chunk.rfind(ending)
+             if last_ending != -1:
+                 # Ensure we include the punctuation and any following space
+                 best_split = min(
+                     best_split, last_ending + 1 +
+                     (1 if last_ending + 1 < len(chunk)
+                      and chunk[last_ending + 1].isspace() else 0))
+
+         # Adjust to ensure we don't break words
+         # If the next character is a letter, find the last space
+         if best_split < len(text) and text[best_split].isalpha():
+             # Find the last space before the current split point
+             space_split = text[:best_split].rfind(' ')
+             if space_split != -1:
+                 best_split = space_split
+
+         # Append the chunk, ensuring it's stripped
+         chunks.append(text[:best_split].strip())
+
+         # Remove the processed part from the text
+         text = text[best_split:].lstrip()
+
+     return chunks
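
Before moving on to the next file, here is a brief usage sketch of the TXT2KG class added above. It is illustrative only: the document text and file name are made up, and a real NVIDIA NIM API key (or local_LM=True plus enough GPU memory for the 14B model) is required for add_doc_2_KG to produce triples.

# Illustrative sketch of the TXT2KG API defined above (not part of the diff).
from torch_geometric.llm.models.txt2kg import TXT2KG

kg_maker = TXT2KG(
    NVIDIA_API_KEY="<your NIM key>",  # required unless local_LM=True
    chunk_size=512,
)

doc = ("Alan Turing worked at Bletchley Park. "
       "Turing also proposed the Turing test.")
qa = ("Where did Alan Turing work?", "Bletchley Park")

# Each call chunks the text, prompts the LLM with SYSTEM_PROMPT, and stores
# the parsed ('entity', 'relation', 'entity') tuples under the QA pair
# (or under a running doc id when no QA pair is given).
kg_maker.add_doc_2_KG(doc, QA_pair=qa)

# relevant_triples is a dict mapping each key to a list of 3-tuples;
# save_kg() persists it with torch.save for the downstream RAG pipeline.
print(kg_maker.relevant_triples[qa])
kg_maker.save_kg("kg.pt")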
torch_geometric/llm/rag_loader.py (new file)
@@ -0,0 +1,154 @@
+ from abc import abstractmethod
+ from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
+
+ from torch_geometric.data import Data, FeatureStore, HeteroData
+ from torch_geometric.llm.utils.vectorrag import VectorRetriever
+ from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
+ from torch_geometric.typing import InputEdges, InputNodes
+
+
+ class RAGFeatureStore(Protocol):
+     """Feature store template for remote GNN RAG backend."""
+     @abstractmethod
+     def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
+         """Makes a comparison between the query and all the nodes to get all
+         the closest nodes. Return the indices of the nodes that are to be seeds
+         for the RAG Sampler.
+         """
+         ...
+
+     @property
+     @abstractmethod
+     def config(self) -> Dict[str, Any]:
+         """Get the config for the RAGFeatureStore."""
+         ...
+
+     @config.setter
+     @abstractmethod
+     def config(self, config: Dict[str, Any]):
+         """Set the config for the RAGFeatureStore."""
+         ...
+
+     @abstractmethod
+     def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
+         """Makes a comparison between the query and all the edges to get all
+         the closest edges. Returns the edge indices that are to be the seeds
+         for the RAG Sampler.
+         """
+         ...
+
+     @abstractmethod
+     def load_subgraph(
+         self, sample: Union[SamplerOutput, HeteroSamplerOutput]
+     ) -> Union[Data, HeteroData]:
+         """Combines sampled subgraph output with features in a Data object."""
+         ...
+
+
+ class RAGGraphStore(Protocol):
+     """Graph store template for remote GNN RAG backend."""
+     @abstractmethod
+     def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
+                         **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
+         """Sample a subgraph using the seeded nodes and edges."""
+         ...
+
+     @property
+     @abstractmethod
+     def config(self) -> Dict[str, Any]:
+         """Get the config for the RAGGraphStore."""
+         ...
+
+     @config.setter
+     @abstractmethod
+     def config(self, config: Dict[str, Any]):
+         """Set the config for the RAGGraphStore."""
+         ...
+
+     @abstractmethod
+     def register_feature_store(self, feature_store: FeatureStore):
+         """Register a feature store to be used with the sampler. Samplers need
+         info from the feature store in order to work properly on HeteroGraphs.
+         """
+         ...
+
+
+ # TODO: Make compatible with Heterographs
+
+
+ class RAGQueryLoader:
+     """Loader meant for making RAG queries from a remote backend."""
+     def __init__(self, graph_data: Tuple[RAGFeatureStore, RAGGraphStore],
+                  subgraph_filter: Optional[Callable[[Data, Any], Data]] = None,
+                  augment_query: bool = False,
+                  vector_retriever: Optional[VectorRetriever] = None,
+                  config: Optional[Dict[str, Any]] = None):
+         """Loader meant for making queries from a remote backend.
+
+         Args:
+             graph_data (Tuple[RAGFeatureStore, RAGGraphStore]):
+                 Remote FeatureStore and GraphStore to load from.
+                 Assumed to conform to the protocols listed above.
+             subgraph_filter (Optional[Callable[[Data, Any], Data]], optional):
+                 Optional local transform to apply to data after retrieval.
+                 Defaults to None.
+             augment_query (bool, optional): Whether to augment the query with
+                 retrieved documents. Defaults to False.
+             vector_retriever (Optional[VectorRetriever], optional):
+                 VectorRetriever to use for retrieving documents.
+                 Defaults to None.
+             config (Optional[Dict[str, Any]], optional): Config to pass into
+                 the RAGQueryLoader. Defaults to None.
+         """
+         fstore, gstore = graph_data
+         self.vector_retriever = vector_retriever
+         self.augment_query = augment_query
+         self.feature_store = fstore
+         self.graph_store = gstore
+         self.graph_store.edge_index = self.graph_store.edge_index.contiguous()
+         self.graph_store.register_feature_store(self.feature_store)
+         self.subgraph_filter = subgraph_filter
+         self.config = config
+
+     def _propagate_config(self, config: Dict[str, Any]):
+         """Propagate the config to the relevant components."""
+         self.feature_store.config = config
+         self.graph_store.config = config
+
+     @property
+     def config(self):
+         """Get the config for the RAGQueryLoader."""
+         return self._config
+
+     @config.setter
+     def config(self, config: Dict[str, Any]):
+         """Set the config for the RAGQueryLoader.
+
+         Args:
+             config (Dict[str, Any]): The config to set.
+         """
+         self._propagate_config(config)
+         self._config = config
+
+     def query(self, query: Any) -> Data:
+         """Retrieve a subgraph associated with the query with all its feature
+         attributes.
+         """
+         if self.vector_retriever:
+             retrieved_docs = self.vector_retriever.query(query)
+
+             if self.augment_query:
+                 query = [query] + retrieved_docs
+
+         seed_nodes, query_enc = self.feature_store.retrieve_seed_nodes(query)
+
+         subgraph_sample = self.graph_store.sample_subgraph(seed_nodes)
+
+         data = self.feature_store.load_subgraph(sample=subgraph_sample)
+
+         # apply local filter
+         if self.subgraph_filter:
+             data = self.subgraph_filter(data, query)
+         if self.vector_retriever:
+             data.text_context = retrieved_docs
+         return data
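
To make the Protocol-driven design of rag_loader.py concrete, the sketch below wires RAGQueryLoader to a pair of toy in-memory stores. ToyFeatureStore and ToyGraphStore are hypothetical stand-ins written for this example (the shipped backends live in torch_geometric/llm/utils/feature_store.py and graph_store.py, whose contents are not shown here). Note that query() unpacks two values from retrieve_seed_nodes() and that __init__ expects the graph store to expose an edge_index attribute, so the toys mirror both.

# Toy sketch (not part of the diff): minimal stand-ins for the
# RAGFeatureStore / RAGGraphStore protocols, wired into RAGQueryLoader.
from types import SimpleNamespace
from typing import Any, Dict

import torch

from torch_geometric.data import Data
from torch_geometric.llm.rag_loader import RAGQueryLoader


class ToyFeatureStore:
    def __init__(self):
        self.x = torch.randn(3, 4)  # 3 nodes, 4 features
        self._config: Dict[str, Any] = {}

    @property
    def config(self):
        return self._config

    @config.setter
    def config(self, config):
        self._config = config or {}

    def retrieve_seed_nodes(self, query: Any, **kwargs):
        # Pretend node 0 is always the best match; RAGQueryLoader.query()
        # unpacks (seed_nodes, query_enc), so return a 2-tuple.
        return torch.tensor([0]), None

    def retrieve_seed_edges(self, query: Any, **kwargs):
        return torch.tensor([0])

    def load_subgraph(self, sample):
        # Attach node features to whatever the graph store sampled.
        return Data(x=self.x[sample.node], edge_index=sample.edge_index)


class ToyGraphStore:
    def __init__(self):
        # RAGQueryLoader.__init__ calls .edge_index.contiguous() on the store.
        self.edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
        self._config: Dict[str, Any] = {}

    @property
    def config(self):
        return self._config

    @config.setter
    def config(self, config):
        self._config = config or {}

    def register_feature_store(self, feature_store):
        self.feature_store = feature_store

    def sample_subgraph(self, seed_nodes, seed_edges=None, **kwargs):
        # Trivial sampler: always return the full 3-node toy graph.
        return SimpleNamespace(node=torch.arange(3),
                               edge_index=self.edge_index)


loader = RAGQueryLoader(graph_data=(ToyFeatureStore(), ToyGraphStore()))
print(loader.query("Who proposed the Turing test?"))
# -> Data(x=[3, 4], edge_index=[2, 3])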