pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -0,0 +1,154 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
3
+
4
+ from torch_geometric.data import Data, FeatureStore, HeteroData
5
+ from torch_geometric.llm.utils.vectorrag import VectorRetriever
6
+ from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
7
+ from torch_geometric.typing import InputEdges, InputNodes
8
+
9
+
10
class RAGFeatureStore(Protocol):
    """Feature store template for remote GNN RAG backend."""
    @abstractmethod
    def retrieve_seed_nodes(self, query: Any,
                            **kwargs) -> Tuple[InputNodes, Any]:
        """Makes a comparison between the query and all the nodes to get all
        the closest nodes.

        Returns:
            A ``(seed_nodes, query_enc)`` pair: the indices of the nodes that
            are to be seeds for the RAG sampler, and a second value
            (presumably the encoded query). :class:`RAGQueryLoader` unpacks
            exactly two values from this call.
        """
        ...

    @property
    @abstractmethod
    def config(self) -> Dict[str, Any]:
        """Get the config for the RAGFeatureStore."""
        ...

    @config.setter
    @abstractmethod
    def config(self, config: Dict[str, Any]):
        """Set the config for the RAGFeatureStore."""
        ...

    @abstractmethod
    def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
        """Makes a comparison between the query and all the edges to get all
        the closest edges. Returns the edge indices that are to be the seeds
        for the RAG Sampler.
        """
        ...

    @abstractmethod
    def load_subgraph(
        self, sample: Union[SamplerOutput, HeteroSamplerOutput]
    ) -> Union[Data, HeteroData]:
        """Combines sampled subgraph output with features in a Data object."""
        ...
46
+
47
+
48
class RAGGraphStore(Protocol):
    """Graph store template for remote GNN RAG backend."""
    @abstractmethod
    def sample_subgraph(
        self,
        seed_nodes: InputNodes,
        seed_edges: Optional[InputEdges] = None,
        **kwargs,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Sample a subgraph using the seeded nodes and edges.

        ``seed_edges`` is optional because :class:`RAGQueryLoader` calls this
        method with ``seed_nodes`` only.
        """
        ...

    @property
    @abstractmethod
    def config(self) -> Dict[str, Any]:
        """Get the config for the RAGGraphStore."""
        ...

    @config.setter
    @abstractmethod
    def config(self, config: Dict[str, Any]):
        """Set the config for the RAGGraphStore."""
        ...

    @abstractmethod
    def register_feature_store(self, feature_store: FeatureStore):
        """Register a feature store to be used with the sampler. Samplers need
        info from the feature store in order to work properly on HeteroGraphs.
        """
        ...
74
+
75
+
76
+ # TODO: Make compatible with heterogeneous graphs
77
+
78
+
79
class RAGQueryLoader:
    """Loader meant for making RAG queries from a remote backend."""
    def __init__(self, graph_data: Tuple[RAGFeatureStore, RAGGraphStore],
                 subgraph_filter: Optional[Callable[[Data, Any], Data]] = None,
                 augment_query: bool = False,
                 vector_retriever: Optional[VectorRetriever] = None,
                 config: Optional[Dict[str, Any]] = None):
        """Loader meant for making queries from a remote backend.

        Args:
            graph_data (Tuple[RAGFeatureStore, RAGGraphStore]):
                Remote FeatureStore and GraphStore to load from.
                Assumed to conform to the protocols listed above.
            subgraph_filter (Optional[Callable[[Data, Any], Data]], optional):
                Optional local transform to apply to data after retrieval.
                It is called as ``subgraph_filter(data, query)``.
                Defaults to None.
            augment_query (bool, optional): Whether to augment the query with
                retrieved documents, i.e. replace it with
                ``[query] + retrieved_docs`` before seed retrieval. Has no
                effect when no ``vector_retriever`` is given.
                Defaults to False.
            vector_retriever (Optional[VectorRetriever], optional):
                VectorRetriever to use for retrieving documents. When given,
                the retrieved documents are attached to the returned data as
                ``text_context``. Defaults to None.
            config (Optional[Dict[str, Any]], optional): Config to pass into
                the RAGQueryLoader. It is also pushed to both stores.
                Defaults to None.
        """
        fstore, gstore = graph_data
        self.vector_retriever = vector_retriever
        self.augment_query = augment_query
        self.feature_store = fstore
        self.graph_store = gstore
        # NOTE(review): ``edge_index`` is not part of the RAGGraphStore
        # protocol; this assumes the concrete graph store exposes it.
        self.graph_store.edge_index = self.graph_store.edge_index.contiguous()
        self.graph_store.register_feature_store(self.feature_store)
        self.subgraph_filter = subgraph_filter
        # Goes through the property setter, which also propagates the config
        # to the feature store and the graph store.
        self.config = config

    def _propagate_config(self, config: Dict[str, Any]):
        """Propagate the config to the relevant components."""
        self.feature_store.config = config
        self.graph_store.config = config

    @property
    def config(self):
        """Get the config for the RAGQueryLoader."""
        return self._config

    @config.setter
    def config(self, config: Dict[str, Any]):
        """Set the config for the RAGQueryLoader.

        The config is first pushed to the feature store and the graph store,
        then stored on the loader.

        Args:
            config (Dict[str, Any]): The config to set.
        """
        self._propagate_config(config)
        self._config = config

    def query(self, query: Any) -> Data:
        """Retrieve a subgraph associated with the query with all its feature
        attributes.

        Args:
            query (Any): The query used for document and seed retrieval.

        Returns:
            Data: The sampled subgraph, after ``subgraph_filter`` (if any).
            When a ``vector_retriever`` is configured, the retrieved
            documents are attached as ``data.text_context``.
        """
        retrieved_docs = None
        if self.vector_retriever is not None:
            retrieved_docs = self.vector_retriever.query(query)
            # Augmentation needs retrieved documents, so it only happens when
            # a retriever is configured (otherwise there is nothing to add).
            if self.augment_query:
                query = [query] + retrieved_docs

        # The feature store returns the seeds plus a second value that the
        # loader does not need.
        seed_nodes, _ = self.feature_store.retrieve_seed_nodes(query)

        subgraph_sample = self.graph_store.sample_subgraph(seed_nodes)

        data = self.feature_store.load_subgraph(sample=subgraph_sample)

        # apply local filter
        if self.subgraph_filter is not None:
            data = self.subgraph_filter(data, query)
        if self.vector_retriever is not None:
            data.text_context = retrieved_docs
        return data
@@ -0,0 +1,10 @@
1
+ from .backend_utils import * # noqa
2
+ from .feature_store import KNNRAGFeatureStore
3
+ from .graph_store import NeighborSamplingRAGGraphStore
4
+ from .vectorrag import DocumentRetriever
5
+
6
# Names exported by a star-import of this package; ``classes`` and
# ``__all__`` are bound to the very same list object.
classes = [
    'KNNRAGFeatureStore',
    'NeighborSamplingRAGGraphStore',
    'DocumentRetriever',
]
__all__ = classes
@@ -0,0 +1,443 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from enum import Enum, auto
4
+ from typing import (
5
+ Any,
6
+ Callable,
7
+ Dict,
8
+ Iterable,
9
+ Iterator,
10
+ List,
11
+ Optional,
12
+ Protocol,
13
+ Tuple,
14
+ Type,
15
+ Union,
16
+ no_type_check,
17
+ runtime_checkable,
18
+ )
19
+
20
+ import numpy as np
21
+ import torch
22
+ from torch import Tensor
23
+ from torch.nn import Module
24
+
25
+ from torch_geometric.data import Data, FeatureStore, GraphStore
26
+ from torch_geometric.distributed import (
27
+ LocalFeatureStore,
28
+ LocalGraphStore,
29
+ Partitioner,
30
+ )
31
+ from torch_geometric.llm.large_graph_indexer import (
32
+ EDGE_RELATION,
33
+ LargeGraphIndexer,
34
+ TripletLike,
35
+ )
36
+ from torch_geometric.llm.models import SentenceTransformer
37
+ from torch_geometric.typing import EdgeType, NodeType
38
+
39
# pandas is an optional dependency: ``DataFrame`` stays ``None`` when it is
# missing, and ``make_pcst_filter`` checks for that before using it.
DataFrame = None
try:
    from pandas import DataFrame
except ImportError:
    pass

# A remote RAG backend is a (feature store, graph store) pair.
RemoteGraphBackend = Tuple[FeatureStore, GraphStore]
44
+
45
+ # TODO: Make everything compatible with heterogeneous graphs as well
46
+
47
+
48
def preprocess_triplet(triplet: TripletLike) -> TripletLike:
    """Normalize a triplet by converting each element to a lowercase string.

    Args:
        triplet: A 3-element ``(head, relation, tail)`` sequence.

    Returns:
        A tuple ``(head, relation, tail)`` of lowercased strings.
    """
    head, relation, tail = triplet
    return tuple(str(part).lower() for part in (head, relation, tail))
51
+
52
+
53
@no_type_check
def retrieval_via_pcst(
    data: Data,
    q_emb: Tensor,
    textual_nodes: Any,
    textual_edges: Any,
    topk: int = 3,
    topk_e: int = 5,
    cost_e: float = 0.5,
    num_clusters: int = 1,
) -> Tuple[Data, str]:
    """Select a query-relevant subgraph via Prize-Collecting Steiner Tree.

    Nodes and edges are scored by cosine similarity between ``q_emb`` and
    ``data.x`` / ``data.edge_attr``. The ``topk`` best nodes and the edges
    holding the ``topk_e`` best distinct scores receive positive prizes, and
    everything else gets prize 0. ``pcst_fast`` then selects a subgraph that
    trades collected prizes against edge costs.

    If the graph is degenerate (``edge_attr``, ``x`` or ``edge_index`` is
    missing or empty), PCST is skipped. In that case ``data`` is returned
    unchanged together with a description of all textual nodes and edges.

    Args:
        data: Input graph. Its tensors must be on CPU (``.numpy()`` is called
            on ``edge_index``). When PCST runs, it must also carry
            ``node_idx`` and ``edge_idx`` attributes, which are subset
            alongside the selected nodes/edges.
        q_emb: Query embedding compared against ``data.x`` and
            ``data.edge_attr`` (presumably shape ``[1, dim]`` -- TODO
            confirm).
        textual_nodes: pandas-DataFrame-like node table; rows are selected
            with ``.iloc`` (row ``i`` is assumed to describe node ``i``).
        textual_edges: pandas-DataFrame-like edge table with ``src``,
            ``edge_attr`` and ``dst`` columns; row ``i`` is assumed to
            describe edge ``i``.
        topk: Number of top-scoring nodes that receive a prize.
        topk_e: Number of top distinct edge scores that receive a prize.
        cost_e: Base cost of an edge. When ``topk_e > 0`` it is lowered if
            needed so that at least one edge can be chosen.
        num_clusters: Number of connected components in the PCST output.

    Returns:
        ``(subgraph, desc)``. If PCST ran, ``subgraph`` is a new ``Data``
        with ``x``, relabeled ``edge_index``, ``edge_attr``, ``node_idx``
        and ``edge_idx`` restricted to the selection. ``desc`` is the CSV of
        the selected node rows, a newline, then the CSV of the selected edge
        rows (columns ``src,edge_attr,dst``).
    """
    # skip PCST for bad graphs
    booly = data.edge_attr is None or data.edge_attr.numel() == 0
    booly = booly or data.x is None or data.x.numel() == 0
    booly = booly or data.edge_index is None or data.edge_index.numel() == 0
    if not booly:
        # Small margin: keeps successive edge-prize levels strictly
        # decreasing and puts ``cost_e`` just below the largest edge prize.
        c = 0.01

        from pcst_fast import pcst_fast

        # pcst_fast settings: root=-1 (unrooted), 'gw' pruning, silent.
        root = -1
        pruning = 'gw'
        verbosity_level = 0
        if topk > 0:
            # Node prizes: the ``topk`` most similar nodes get prizes
            # topk, topk - 1, ..., 1; all other nodes get 0.
            n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
            topk = min(topk, data.num_nodes)
            _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

            n_prizes = torch.zeros_like(n_prizes)
            n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
        else:
            n_prizes = torch.zeros(data.num_nodes)

        if topk_e > 0:
            # Edge prizes: only edges whose score is among the ``topk_e``
            # largest distinct scores keep a non-zero prize.
            e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
            topk_e = min(topk_e, e_prizes.unique().size(0))

            topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e,
                                          largest=True)
            e_prizes[e_prizes < topk_e_values[-1]] = 0.0
            last_topk_e_value = topk_e
            for k in range(topk_e):
                # Edges tied at the k-th best score split the prize
                # (topk_e - k) evenly. The prize is capped so that each level
                # stays strictly below the previous one.
                # NOTE(review): ``sum(indices)`` iterates the bool tensor in
                # Python; ``indices.sum()`` would give the same count faster.
                indices = e_prizes == topk_e_values[k]
                value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
                e_prizes[indices] = value
                last_topk_e_value = value * (1 - c)
            # reduce the cost of the edges so that at least one edge is chosen
            cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
        else:
            e_prizes = torch.zeros(data.num_edges)

        # Build the PCST input. Prizes live on nodes only, so an edge whose
        # prize exceeds ``cost_e`` is replaced by a virtual node. That node
        # carries the surplus prize and is linked to both endpoints by
        # zero-cost edges. Every other edge is kept with cost
        # ``cost_e - prize``.
        costs = []
        edges = []
        virtual_n_prizes = []
        virtual_edges = []
        virtual_costs = []
        mapping_n = {}  # virtual node id -> original edge index
        mapping_e = {}  # position in ``edges`` -> original edge index
        for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
            prize_e = e_prizes[i]
            if prize_e <= cost_e:
                mapping_e[len(edges)] = i
                edges.append((src, dst))
                costs.append(cost_e - prize_e)
            else:
                virtual_node_id = data.num_nodes + len(virtual_n_prizes)
                mapping_n[virtual_node_id] = i
                virtual_edges.append((src, virtual_node_id))
                virtual_edges.append((virtual_node_id, dst))
                virtual_costs.append(0)
                virtual_costs.append(0)
                virtual_n_prizes.append(prize_e - cost_e)

        prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
        # Real edges are listed first, so PCST edge ids < num_edges are real.
        num_edges = len(edges)
        # NOTE(review): without virtual edges, ``costs``/``edges`` are passed
        # to pcst_fast as plain Python lists.
        if len(virtual_costs) > 0:
            costs = np.array(costs + virtual_costs)
            edges = np.array(edges + virtual_edges)

        vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
                                    pruning, verbosity_level)

        # Map the PCST solution back to original node / edge indices; each
        # selected virtual node stands for one original edge.
        selected_nodes = vertices[vertices < data.num_nodes]
        selected_edges = [mapping_e[e] for e in edges if e < num_edges]
        virtual_vertices = vertices[vertices >= data.num_nodes]
        if len(virtual_vertices) > 0:
            # NOTE(review): recomputes the same value as the line above.
            virtual_vertices = vertices[vertices >= data.num_nodes]
            virtual_edges = [mapping_n[i] for i in virtual_vertices]
            selected_edges = np.array(selected_edges + virtual_edges)

        # Keep the endpoints of every selected edge as selected nodes too.
        edge_index = data.edge_index[:, selected_edges]
        selected_nodes = np.unique(
            np.concatenate(
                [selected_nodes, edge_index[0].numpy(),
                 edge_index[1].numpy()]))

        n = textual_nodes.iloc[selected_nodes]
        e = textual_edges.iloc[selected_edges]
    else:
        n = textual_nodes
        e = textual_edges

    # Textual description: node CSV, a newline, then the edge CSV.
    desc = n.to_csv(index=False) + '\n' + e.to_csv(
        index=False, columns=['src', 'edge_attr', 'dst'])

    if booly:
        return data, desc

    # Relabel the selected edges into the compacted node numbering
    # (position in the sorted ``selected_nodes``).
    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]

    # HACK Added so that the subset of nodes and edges selected can be tracked
    node_idx = np.array(data.node_idx)[selected_nodes]
    edge_idx = np.array(data.edge_idx)[selected_edges]

    data = Data(
        x=data.x[selected_nodes],
        edge_index=torch.tensor([src, dst]).to(torch.long),
        edge_attr=data.edge_attr[selected_edges],
        # HACK: track subset of selected nodes/edges
        node_idx=node_idx,
        edge_idx=edge_idx,
    )

    return data, desc
180
+
181
+
182
def batch_knn(query_enc: Tensor, embeds: Tensor,
              k: int) -> Iterator[Tuple[Tensor, Tensor]]:
    """Lazily yield the ``k`` most cosine-similar ``embeds`` rows per query.

    ``embeds`` is moved to the device of ``query_enc`` before scoring. If
    ``k`` exceeds the number of embeddings, all of them are returned.

    Args:
        query_enc: Query embeddings, one query per row.
        embeds: Candidate embeddings, one candidate per row.
        k: Number of neighbors to return per query.

    Yields:
        ``(indices, query)`` for each query row: the indices of the top
        candidates (most similar first) and that query row reshaped to
        ``[1, ...]``.
    """
    from torchmetrics.functional import pairwise_cosine_similarity

    scores = pairwise_cosine_similarity(query_enc, embeds.to(query_enc.device))
    k_eff = min(k, len(embeds))
    for row_scores, query_row in zip(scores, query_enc):
        top = torch.topk(row_scores, k_eff, largest=True).indices
        yield top, query_row.unsqueeze(0)
190
+
191
+
192
# Adapted from LocalGraphStore
@runtime_checkable
class ConvertableGraphStore(Protocol):
    """Structural interface for graph-store classes that can serve as a RAG
    backend. A class satisfies it by providing the three alternate
    constructors below.
    """
    @classmethod
    def from_data(
        cls,
        edge_id: Tensor,
        edge_index: Tensor,
        num_nodes: int,
        is_sorted: bool = False,
    ) -> GraphStore:
        """Build a graph store from a homogeneous edge list."""

    @classmethod
    def from_hetero_data(
        cls,
        edge_id_dict: Dict[EdgeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        num_nodes_dict: Dict[NodeType, int],
        is_sorted: bool = False,
    ) -> GraphStore:
        """Build a graph store from per-edge-type edge lists."""

    @classmethod
    def from_partition(cls, root: str, pid: int) -> GraphStore:
        """Load the graph store of partition ``pid`` stored under ``root``."""
218
+
219
+
220
# Adapted from LocalFeatureStore
@runtime_checkable
class ConvertableFeatureStore(Protocol):
    """Structural interface for feature-store classes that can serve as a RAG
    backend. A class satisfies it by providing the three alternate
    constructors below.
    """
    @classmethod
    def from_data(
        cls,
        node_id: Tensor,
        x: Optional[Tensor] = None,
        y: Optional[Tensor] = None,
        edge_id: Optional[Tensor] = None,
        edge_attr: Optional[Tensor] = None,
    ) -> FeatureStore:
        """Build a feature store from homogeneous node/edge tensors."""

    @classmethod
    def from_hetero_data(
        cls,
        node_id_dict: Dict[NodeType, Tensor],
        x_dict: Optional[Dict[NodeType, Tensor]] = None,
        y_dict: Optional[Dict[NodeType, Tensor]] = None,
        edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None,
        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
    ) -> FeatureStore:
        """Build a feature store from per-type node/edge tensors."""

    @classmethod
    def from_partition(cls, root: str, pid: int) -> FeatureStore:
        """Load the feature store of partition ``pid`` stored under ``root``.
        """
248
+
249
+
250
class RemoteDataType(Enum):
    """Storage layout of a serialized RAG backend."""
    DATA = 1  # a single ``torch.save``-d graph object
    PARTITION = 2  # a partitioned store, loaded one partition id at a time
253
+
254
+
255
@dataclass
class RemoteGraphBackendLoader:
    """Utility class to load triplets into a RAG Backend.

    Attributes:
        path: Location of the serialized backend: a single file for
            ``RemoteDataType.DATA``, the partition root for
            ``RemoteDataType.PARTITION``.
        datatype: How the backend at ``path`` is stored.
        graph_store_type: Class used to build the graph store.
        feature_store_type: Class used to build the feature store.

    Note:
        When an instance is garbage collected, ``path`` is deleted if it is a
        regular file.
    """
    path: str
    datatype: RemoteDataType
    graph_store_type: Type[ConvertableGraphStore]
    feature_store_type: Type[ConvertableFeatureStore]

    def load(self, pid: Optional[int] = None) -> RemoteGraphBackend:
        """Build the ``(feature_store, graph_store)`` pair from ``path``.

        Args:
            pid: Partition ID to load; required for
                ``RemoteDataType.PARTITION``, ignored otherwise.

        Returns:
            A ``(feature_store, graph_store)`` tuple.

        Raises:
            ValueError: If ``pid`` is missing for a partitioned store.
            NotImplementedError: For an unsupported ``datatype``.
        """
        if self.datatype == RemoteDataType.DATA:
            # NOTE(security): ``weights_only=False`` unpickles arbitrary
            # objects; only load files written by a trusted source.
            data_obj = torch.load(self.path, weights_only=False)
            # is_sorted=true since assume nodes come sorted from indexer
            graph_store = self.graph_store_type.from_data(
                edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index,
                num_nodes=data_obj.num_nodes, is_sorted=True)
            feature_store = self.feature_store_type.from_data(
                node_id=data_obj['node_id'], x=data_obj.x,
                edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr)
        elif self.datatype == RemoteDataType.PARTITION:
            # An explicit raise (not ``assert``) so the check survives -O.
            if pid is None:
                raise ValueError(
                    "Partition ID must be defined for loading from a "
                    "partitioned store.")
            graph_store = self.graph_store_type.from_partition(self.path, pid)
            feature_store = self.feature_store_type.from_partition(
                self.path, pid)
        else:
            raise NotImplementedError
        return (feature_store, graph_store)

    def __del__(self) -> None:
        # Best-effort cleanup of the serialized backend. Only regular files
        # are removed: for PARTITION the path is the partition root
        # (presumably a directory), which ``os.remove`` cannot delete and
        # would only raise inside ``__del__``.
        path = getattr(self, 'path', None)
        if path and os.path.isfile(path):
            os.remove(path)
288
+
289
+
290
def create_graph_from_triples(
    triples: Iterable[TripletLike],
    embedding_model: Union[Module, Callable],
    embedding_method_kwargs: Optional[Dict[str, Any]] = None,
    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
) -> Data:
    """Utility function that can be used to create a graph from triples.

    The triples are indexed with :class:`LargeGraphIndexer`. Unique node texts
    are embedded as node feature ``x``, unique relation texts are embedded as
    ``edge_attr``, and the resulting graph is returned on CPU.

    Args:
        triples: ``(head, relation, tail)`` triples to index.
        embedding_model: Callable used to embed node and relation texts.
        embedding_method_kwargs: Extra keyword arguments forwarded to
            ``embedding_model`` on every call. Defaults to none.
        pre_transform: Optional per-triple transform applied by the indexer.

    Returns:
        The indexed graph as a :class:`Data` object on CPU.
    """
    embed_kwargs = {} if embedding_method_kwargs is None \
        else embedding_method_kwargs

    indexer = LargeGraphIndexer.from_triplets(triples,
                                              pre_transform=pre_transform)

    # Node features: one embedding per unique node text.
    indexer.add_node_feature(
        'x',
        embedding_model(indexer.get_unique_node_features(), **embed_kwargs))

    # Edge features: embed each unique relation text once, then map it onto
    # every edge carrying that relation.
    relation_embeddings = embedding_model(
        indexer.get_unique_edge_features(feature_name=EDGE_RELATION),
        **embed_kwargs)
    indexer.add_edge_feature(new_feature_name="edge_attr",
                             new_feature_vals=relation_embeddings,
                             map_from_feature=EDGE_RELATION)

    graph = indexer.to_data(node_feature_name='x',
                            edge_feature_name='edge_attr')
    return graph.to("cpu")
318
+
319
+
320
def create_remote_backend_from_graph_data(
    graph_data: Data,
    graph_db: Type[ConvertableGraphStore] = LocalGraphStore,
    feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore,
    path: str = '',
    n_parts: int = 1,
) -> RemoteGraphBackendLoader:
    """Utility function that can be used to create a RAG Backend from graph
    data.

    With ``n_parts == 1`` the graph is written to ``path`` with
    ``torch.save``. Otherwise it is split into ``n_parts`` partitions by
    :class:`Partitioner`, using ``path`` as root. A loader pointing at the
    written data is returned.

    Args:
        graph_data (Data): Graph data to load into the RAG Backend.
        graph_db (Type[ConvertableGraphStore], optional): GraphStore class to
            use. Defaults to LocalGraphStore.
        feature_db (Type[ConvertableFeatureStore], optional): FeatureStore
            class to use. Defaults to LocalFeatureStore.
        path (str, optional): Path to save the resulting stores. Must be
            non-empty when ``n_parts == 1``. Defaults to ''.
        n_parts (int, optional): Number of partitions to store in.
            Defaults to 1.

    Returns:
        RemoteGraphBackendLoader: Loader to load RAG backend from disk or
        memory.

    Raises:
        AttributeError: If ``graph_db`` or ``feature_db`` lacks one of the
            ``from_data``, ``from_hetero_data`` or ``from_partition``
            constructors.
        ValueError: If ``n_parts == 1`` and ``path`` is empty.
    """
    # Validate each store class independently. If the Protocol check fails,
    # touching each required constructor raises an AttributeError that names
    # the missing one.
    required = ('from_data', 'from_hetero_data', 'from_partition')
    for store_cls, protocol in ((graph_db, ConvertableGraphStore),
                                (feature_db, ConvertableFeatureStore)):
        if not issubclass(store_cls, protocol):
            for attr in required:
                getattr(store_cls, attr)

    if n_parts == 1:
        # ``torch.save`` cannot write to an empty path; fail with a clear
        # message instead of an opaque file error.
        if not path:
            raise ValueError("`path` must be a non-empty file path when "
                             "`n_parts == 1`")
        torch.save(graph_data, path)
        return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db,
                                        feature_db)

    partitioner = Partitioner(data=graph_data, num_parts=n_parts, root=path)
    partitioner.generate_partition()
    return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION, graph_db,
                                    feature_db)
363
+
364
+
365
def make_pcst_filter(triples: List[Tuple[str, str, str]],
                     model: SentenceTransformer, topk: int = 5,
                     topk_e: int = 5, cost_e: float = 0.5,
                     num_clusters: int = 1) -> Callable[[Data, str], Data]:
    """Creates a PCST (Prize-Collecting Steiner Tree) subgraph filter.

    Args:
        triples: ``(head, relation, tail)`` triples of the knowledge graph.
            Duplicates are dropped (first occurrence wins). A graph's
            ``edge_idx`` indexes this deduplicated list. Its ``node_idx``
            indexes the list of unique heads/tails in first-seen order.
        model: Embedding model; ``model.encode([query])`` must return a
            tensor.
        topk: Number of top-scoring nodes that receive a prize (default: 5).
        topk_e: Number of top distinct edge scores that receive a prize
            (default: 5).
        cost_e: Cost of edges (default: 0.5).
        num_clusters: Number of connected components in the PCST output.

    Returns:
        A filter ``(graph, query) -> Data``. It runs
        :func:`retrieval_via_pcst` on ``graph`` and sets ``desc``,
        ``triples`` and ``question`` on the result.

    Raises:
        ImportError: If :mod:`pandas` is not installed.
    """
    if DataFrame is None:
        raise ImportError("PCST requires `pip install pandas`")

    # Remove duplicate triples (order-preserving) to ensure a unique set.
    triples = list(dict.fromkeys(triples))

    # Unique entities (heads and tails) in first-seen order.
    full_textual_nodes = list(
        dict.fromkeys(node for h, _, t in triples for node in (h, t)))

    def apply_retrieval_via_pcst(
        graph: Data,  # Input graph data
        query: str,  # Search query
    ) -> Data:
        """Applies PCST filtering for retrieval.

        Args:
            graph: Input graph carrying ``node_idx`` and ``edge_idx``.
            query: Search query.

        Returns:
            The retrieved subgraph with ``desc`` (CSV description),
            ``triples`` (selected triples, or ``[]`` if PCST produced no
            edges) and ``question`` set.
        """
        # PCST relies on numpy and pcst_fast pypi libs, hence to("cpu")
        with torch.no_grad():
            q_emb = model.encode([query]).to("cpu")
        textual_nodes = DataFrame(
            [(int(i), full_textual_nodes[i]) for i in graph["node_idx"]],
            columns=["node_id", "node_attr"])
        textual_edges = DataFrame([triples[i] for i in graph["edge_idx"]],
                                  columns=["src", "edge_attr", "dst"])
        out_graph, desc = retrieval_via_pcst(graph.to(q_emb.device), q_emb,
                                             textual_nodes, textual_edges,
                                             topk=topk, topk_e=topk_e,
                                             cost_e=cost_e,
                                             num_clusters=num_clusters)
        out_graph["desc"] = desc

        # Recover the selected triples from the tracked edge indices instead
        # of re-parsing ``desc``. Splitting CSV text on "," breaks when an
        # entity or relation contains a comma (pandas quotes such fields).
        # A fixed offset past the header also assumes a one-character line
        # terminator.
        selected = [tuple(triples[i]) for i in out_graph["edge_idx"]]

        # Handle the case where PCST returns an isolated node.
        # TODO: find a better solution since these failed subgraphs
        # severely hurt accuracy.
        if not selected or out_graph.edge_index.numel() == 0:
            out_graph["triples"] = []
        else:
            out_graph["triples"] = selected
        out_graph["question"] = query
        return out_graph

    return apply_retrieval_via_pcst