pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268) hide show
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -0,0 +1,169 @@
1
+ import gc
2
+ from collections.abc import Iterable, Iterator
3
+ from typing import Any, Dict, List, Tuple, Union
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+ from torch_geometric.data import Data, HeteroData
9
+ from torch_geometric.distributed import LocalFeatureStore
10
+ from torch_geometric.llm.utils.backend_utils import batch_knn
11
+ from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
12
+ from torch_geometric.typing import InputNodes
13
+
14
+
15
+ # NOTE: Only compatible with Homogeneous graphs for now
16
class KNNRAGFeatureStore(LocalFeatureStore):
    """A feature store that retrieves seed nodes via KNN search.

    The store must be configured through :attr:`config` before retrieval:
    the config has to provide ``"k_nodes"`` (number of nodes to retrieve per
    query) and ``"encoder_model"`` (an object exposing ``eval()`` and
    ``encode(...)``, used to embed the queries).
    """
    def __init__(self) -> None:
        """Initializes the feature store."""
        # to be set by the config
        self.encoder_model = None
        self.k_nodes = None
        self._config: Dict[str, Any] = {}
        super().__init__()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the config for the feature store."""
        return self._config

    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:
        """Set an attribute from the config.

        Args:
            config (Dict[str, Any]): Config dictionary
            attr_name (str): Name of attribute to set

        Raises:
            ValueError: If required attribute not found in config
        """
        if attr_name not in config:
            raise ValueError(
                f"Required config parameter '{attr_name}' not found")
        setattr(self, attr_name, config[attr_name])

    @config.setter  # type: ignore
    def config(self, config: Dict[str, Any]) -> None:
        """Set the config for the feature store.

        All required parameters are validated before any attribute is
        updated, so a rejected config leaves the store unchanged.

        Args:
            config (Dict[str, Any]):
                Config dictionary containing required parameters
                (``"k_nodes"`` and ``"encoder_model"``).

        Raises:
            ValueError: If required parameters missing from config
            AssertionError: If ``config["encoder_model"]`` is :obj:`None`
        """
        required = ("k_nodes", "encoder_model")

        # Validate everything up front so that a bad config cannot leave the
        # store half-configured (e.g. `k_nodes` updated but no encoder).
        for attr_name in required:
            if attr_name not in config:
                raise ValueError(
                    f"Required config parameter '{attr_name}' not found")
        encoder_model = config["encoder_model"]
        assert encoder_model is not None, \
            "Need to define encoder model from config"
        encoder_model.eval()

        for attr_name in required:
            self._set_from_config(config, attr_name)
        self._config = config

    @property
    def x(self) -> Tensor:
        """Returns the node features."""
        return Tensor(self.get_tensor(group_name=None, attr_name='x'))

    @property
    def edge_attr(self) -> Tensor:
        """Returns the edge attributes."""
        return Tensor(
            self.get_tensor(group_name=(None, None), attr_name='edge_attr'))

    def retrieve_seed_nodes(  # noqa: D417
        self, query: Union[str, List[str], Tuple[str]]
    ) -> Union[Tuple[InputNodes, Tensor], Dict[str, Tuple[InputNodes,
                                                          Tensor]]]:
        """Retrieves the k_nodes most similar nodes to the given query.

        Args:
            query (Union[str, List[str], Tuple[str]]): The query
                or list of queries to search for.

        Returns:
            For a single query (a string or a one-element list/tuple), the
            pair produced by the batched retrieval for that query (unpacked
            below as the retrieved node indices and the query encoding).
            For several queries, a dictionary mapping each query to its
            pair. Duplicate queries share one dictionary entry.
        """
        if not isinstance(query, (list, tuple)):
            query = [query]
        assert self.k_nodes is not None, "please set k_nodes via config"

        if len(query) == 1:
            result, query_enc = next(
                self._retrieve_seed_nodes_batch(query, self.k_nodes))
            gc.collect()
            torch.cuda.empty_cache()
            return result, query_enc

        out_dict = {}
        for single_query, out in zip(
                query, self._retrieve_seed_nodes_batch(query, self.k_nodes)):
            out_dict[single_query] = out
            # Free memory after every query rather than once at the end.
            gc.collect()
            torch.cuda.empty_cache()
        return out_dict

    def _retrieve_seed_nodes_batch(  # noqa: D417
            self, query: Iterable[Any],
            k_nodes: int) -> Iterator[Tuple[InputNodes, Tensor]]:
        """Retrieves the k_nodes most similar nodes to each query in the batch.

        Args:
            query (Iterable[Any]): The batch of queries to search for.
            k_nodes (int): The number of nodes to retrieve.

        Returns:
            The iterator returned by :func:`batch_knn` for the encoded
            queries against the node features :attr:`x`; one item per query.

        Raises:
            NotImplementedError: If the store's meta marks it as
                heterogeneous.
        """
        if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
            raise NotImplementedError
        assert self.encoder_model is not None, \
            "Need to define encoder model from config"
        query_enc = self.encoder_model.encode(query)
        return batch_knn(query_enc, self.x, k_nodes)

    def load_subgraph(  # noqa
        self,
        sample: Union[SamplerOutput, HeteroSamplerOutput],
        induced: bool = True,
    ) -> Union[Data, HeteroData]:
        """Loads a subgraph from the given sample.

        Args:
            sample: The sample to load the subgraph from.
            induced: Whether to return the induced subgraph.
                Resets node and edge ids.

        Returns:
            The loaded subgraph.

        Raises:
            NotImplementedError: If :obj:`sample` is a
                :class:`HeteroSamplerOutput`.
        """
        if isinstance(sample, HeteroSamplerOutput):
            raise NotImplementedError

        # NOTE: torch_geometric.loader.utils.filter_custom_store
        # can be used here if it supported edge features.
        edge_id = sample.edge
        x = self.x[sample.node]
        edge_attr = self.edge_attr[edge_id]

        if induced:
            edge_idx = torch.stack([sample.row, sample.col], dim=0)
        else:
            edge_idx = torch.stack([sample.global_row, sample.global_col],
                                   dim=0)
        result = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)

        # useful for tracking what subset of the graph was sampled
        result.node_idx = sample.node
        result.edge_idx = edge_id

        return result
160
+
161
+
162
+ """
163
+ TODO: make class CuVSKNNRAGFeatureStore(KNNRAGFeatureStore)
164
+ include an approximate knn flag for the CuVS.
165
+ Connect this with a CuGraphGraphStore
166
+ for enabling an accelerated boolean flag for RAGQueryLoader.
167
+ On by default if CuGraph+CuVS avail.
168
+ If not raise note mentioning its speedup.
169
+ """
@@ -0,0 +1,199 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torch_geometric.data import FeatureStore
7
+ from torch_geometric.distributed import LocalGraphStore
8
+ from torch_geometric.sampler import (
9
+ BidirectionalNeighborSampler,
10
+ NodeSamplerInput,
11
+ SamplerOutput,
12
+ )
13
+ from torch_geometric.utils import index_sort
14
+
15
+ # A representation of an edge index, following the possible formats:
16
+ # * default: Tensor, size = [2, num_edges]
17
+ # * Tensor[0, :] == row, Tensor[1, :] == col
18
+ # * COO: (row, col)
19
+ # * CSC: (row, colptr)
20
+ # * CSR: (rowptr, col)
21
+ _EdgeTensorType = Union[Tensor, Tuple[Tensor, Tensor]]
22
+
23
+
24
class NeighborSamplingRAGGraphStore(LocalGraphStore):
    """Neighbor sampling based graph-store to store & retrieve graph data.

    The number of neighbors to sample has to be provided through
    :attr:`config` (key ``"num_neighbors"``) and a feature store has to be
    registered before :meth:`sample_subgraph` can be used.
    """
    def __init__(  # type: ignore[no-untyped-def]
        self,
        feature_store: Optional[FeatureStore] = None,
        **kwargs,
    ):
        """Initializes the graph store.
        Optional feature store and neighbor sampling settings.

        Args:
            feature_store (optional): The feature store to use.
                None if not yet registered.
            **kwargs (optional):
                Additional keyword arguments for neighbor sampling.
        """
        self.feature_store = feature_store
        self.sample_kwargs = kwargs
        self._sampler_is_initialized = False
        self._config: Dict[str, Any] = {}

        # to be set by the config
        self.num_neighbors = None
        super().__init__()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the config for the graph store."""
        return self._config

    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:
        """Set an attribute from the config.

        Args:
            config (Dict[str, Any]): Config dictionary
            attr_name (str): Name of attribute to set

        Raises:
            ValueError: If required attribute not found in config
        """
        if attr_name not in config:
            raise ValueError(
                f"Required config parameter '{attr_name}' not found")
        setattr(self, attr_name, config[attr_name])

    @config.setter  # type: ignore
    def config(self, config: Dict[str, Any]) -> None:
        """Set the config for the graph store.

        Args:
            config (Dict[str, Any]):
                Config dictionary containing required parameters
                (``"num_neighbors"``).

        Raises:
            ValueError: If required parameters missing from config
        """
        self._set_from_config(config, "num_neighbors")
        # Keep an already constructed sampler in sync with the new setting.
        if hasattr(self, 'sampler'):
            self.sampler.num_neighbors = (  # type: ignore[has-type]
                self.num_neighbors)

        self._config = config

    def _init_sampler(self) -> None:
        """Initializes neighbor sampler with the registered feature store.

        Raises:
            AttributeError: If no feature store has been registered.
        """
        if self.feature_store is None:
            raise AttributeError("Feature store not registered yet.")
        assert self.num_neighbors is not None, \
            "Please set num_neighbors through config"
        self.sampler = BidirectionalNeighborSampler(
            data=(self.feature_store, self), num_neighbors=self.num_neighbors,
            **self.sample_kwargs)
        self._sampler_is_initialized = True

    def register_feature_store(self, feature_store: FeatureStore) -> None:
        """Registers a feature store with the graph store.

        The sampler is re-created on the next call to
        :meth:`sample_subgraph`.

        :param feature_store: The feature store to register.
        """
        self.feature_store = feature_store
        self._sampler_is_initialized = False

    def put_edge_id(  # type: ignore[no-untyped-def]
            self, edge_id: Tensor, *args, **kwargs) -> bool:
        """Stores an edge ID in the graph store.

        :param edge_id: The edge ID to store.
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)
        self._sampler_is_initialized = False
        return ret

    @property
    def edge_index(self) -> _EdgeTensorType:
        """Gets the edge index of the graph.

        Uses the attributes recorded by the most recent call to
        :meth:`put_edge_index`.

        :return: The edge index as a tensor.
        """
        return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)

    def put_edge_index(  # type: ignore[no-untyped-def]
            self, edge_index: _EdgeTensorType, *args, **kwargs) -> bool:
        """Stores an edge index in the graph store.

        :param edge_index: The edge index to store.
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_index(edge_index, *args, **kwargs)
        # HACK: remember how the edge index was stored so that the
        # `edge_index` property can look it up again.
        self.edge_idx_args = args
        self.edge_idx_kwargs = kwargs
        self._sampler_is_initialized = False
        return ret

    # HACKY
    @edge_index.setter  # type: ignore
    def edge_index(self, edge_index: _EdgeTensorType) -> None:
        """Sets the edge index of the graph.

        The edges are sorted by their second row (``edge_index[1]``) before
        being stored; the permutation is kept in :attr:`perm` so that
        sampled edge ids can be mapped back to the caller's edge order.

        :param edge_index: The edge index to set.
        """
        # correct since we make node list from triples
        if isinstance(edge_index, Tensor):
            num_nodes = int(edge_index.max()) + 1
        else:
            assert isinstance(edge_index, tuple) \
                and isinstance(edge_index[0], Tensor) \
                and isinstance(edge_index[1], Tensor), (
                    "edge_index must be a Tensor of [2, num_edges] "
                    "or a tuple of Tensors, (row, col).")

            # Consider both endpoints: the largest node id may only ever
            # appear in `col`, and `num_nodes` is also used as the maximum
            # value handed to `index_sort` for `col` below.
            num_nodes = max(int(edge_index[0].max()),
                            int(edge_index[1].max())) + 1
        attr = dict(
            edge_type=None,
            layout='coo',
            size=(num_nodes, num_nodes),
            is_sorted=False,
        )
        # edge index needs to be sorted here and the perm saved for later
        col_sorted, self.perm = index_sort(edge_index[1], num_nodes,
                                           stable=True)
        row_sorted = edge_index[0][self.perm]
        edge_index_sorted = torch.stack([row_sorted, col_sorted], dim=0)
        self.put_edge_index(edge_index_sorted, **attr)

    def sample_subgraph(
        self,
        seed_nodes: Tensor,
    ) -> SamplerOutput:
        """Sample the graph starting from the given nodes using the
        in-built NeighborSampler.

        The sampler is (re-)initialized lazily if the feature store or the
        stored edges changed since the last call.

        Args:
            seed_nodes (Tensor): Seed nodes to start sampling from.
                Duplicates are removed before sampling.

        Returns:
            SamplerOutput: The sampler output for the input, with
            :obj:`edge` remapped through :attr:`perm`.
        """
        # TODO add support for Hetero
        if not self._sampler_is_initialized:
            self._init_sampler()

        seed_nodes = seed_nodes.unique().contiguous()
        node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)
        out = self.sampler.sample_from_nodes(  # type: ignore[has-type]
            node_sample_input)

        # edge ids need to be remapped to the original indices
        # NOTE: `perm` is only set by the `edge_index` setter.
        out.edge = self.perm[out.edge]

        return out
@@ -0,0 +1,125 @@
1
+ # mypy: ignore-errors
2
+ import os
3
+ from abc import abstractmethod
4
+ from typing import Any, Callable, Dict, List, Optional, Protocol, Union
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+ from torch_geometric.data import Data
10
+ from torch_geometric.llm.models import SentenceTransformer
11
+ from torch_geometric.llm.utils.backend_utils import batch_knn
12
+
13
+
14
class VectorRetriever(Protocol):
    """Protocol for VectorRAG.

    Implementations expose a single :meth:`query` method that maps a query
    to the retrieved context.
    """
    @abstractmethod
    def query(self, query: Any, **kwargs: Any) -> Data:
        """Retrieve a context for a given query.

        Args:
            query: The query to retrieve a context for.
            **kwargs: Implementation-specific retrieval options.
        """
        ...
20
+
21
+
22
class DocumentRetriever(VectorRetriever):
    """Retrieve documents from a vector database."""
    def __init__(self, raw_docs: List[str],
                 embedded_docs: Optional[Tensor] = None, k_for_docs: int = 2,
                 model: Optional[Union[SentenceTransformer, torch.nn.Module,
                                       Callable]] = None,
                 model_kwargs: Optional[Dict[str, Any]] = None):
        """Retrieve documents from a vector database.

        Args:
            raw_docs: List[str]: List of raw documents.
            embedded_docs: Optional[Tensor]: Embedded documents. If not
                given, they are computed by calling :obj:`model` on
                :obj:`raw_docs`.
            k_for_docs: int: Number of documents to retrieve.
            model: Optional[Union[SentenceTransformer, torch.nn.Module,
                Callable]]: Model to use for encoding. Required if
                :obj:`embedded_docs` is not given or if string queries are
                used.
            model_kwargs: Optional[Dict[str, Any]]:
                Keyword arguments to pass to the model. The dictionary is
                copied and never modified in place.
        """
        self.raw_docs = raw_docs
        self.embedded_docs = embedded_docs
        self.k_for_docs = k_for_docs
        self.model = model

        # Always define both attributes (on a private copy of the kwargs) so
        # that `query` can neither hit a missing attribute nor try to unpack
        # `**None`, and so that the caller's dictionary is left untouched.
        self.encoder = self.model
        self.model_kwargs: Dict[str, Any] = dict(model_kwargs or {})

        if self.embedded_docs is None:
            assert self.model is not None, \
                "Model must be provided if embedded_docs is not provided"
            self.embedded_docs = self.encoder(self.raw_docs,
                                              **self.model_kwargs)
            # we don't want to print the verbose output in `query`
            self.model_kwargs.pop("verbose", None)

    def query(self, query: Union[str, Tensor]) -> List[str]:
        """Retrieve documents from the vector database.

        Args:
            query: Union[str, Tensor]: Query to retrieve documents for.
                A string is encoded with the model first; anything else is
                used as the query encoding as-is.

        Returns:
            List[str]: Documents retrieved from the vector database.

        Raises:
            ValueError: If :obj:`query` is a string but no model is
                available to encode it.
        """
        if isinstance(query, str):
            if self.encoder is None:
                raise ValueError(
                    "A model is required to encode string queries")
            with torch.no_grad():
                query_enc = self.encoder(query, **self.model_kwargs)
        else:
            query_enc = query

        selected_doc_idxs, _ = next(
            batch_knn(query_enc, self.embedded_docs, self.k_for_docs))
        return [self.raw_docs[i] for i in selected_doc_idxs]

    def save(self, path: str) -> None:
        """Save the DocumentRetriever instance to disk.

        Only the raw documents, their embeddings and :obj:`k_for_docs` are
        stored; the model is not.

        Args:
            path: str: Path where to save the retriever.
        """
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

        # Prepare data to save
        save_dict = {
            'raw_docs': self.raw_docs,
            'embedded_docs': self.embedded_docs,
            'k_for_docs': self.k_for_docs,
        }

        # We do not serialize the model
        torch.save(save_dict, path)

    @classmethod
    def load(
            cls, path: str, model: Union[SentenceTransformer, torch.nn.Module,
                                         Callable],
            model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "DocumentRetriever":
        """Load a DocumentRetriever instance from disk.

        Args:
            path: str: Path to the saved retriever.
            model: Union[SentenceTransformer, torch.nn.Module, Callable]:
                Model to use for encoding. The model is not part of the
                saved file and therefore has to be passed in again.
            model_kwargs: Optional[Dict[str, Any]]
                Key word args to be passed to model. The dictionary is not
                modified.

        Returns:
            DocumentRetriever: The loaded retriever.

        Raises:
            FileNotFoundError: If :obj:`path` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"No saved document retriever found at {path}")

        # NOTE(security): `weights_only=False` unpickles arbitrary objects;
        # only load retriever files from trusted sources.
        save_dict = torch.load(path, weights_only=False)
        if isinstance(save_dict['embedded_docs'], Tensor) \
                and model_kwargs is not None:
            # Embeddings are already available, so only `query` will call
            # the model; drop `verbose` on a copy of the caller's kwargs.
            model_kwargs = {
                key: value
                for key, value in model_kwargs.items() if key != "verbose"
            }
        # Create a new DocumentRetriever with the loaded data
        return cls(raw_docs=save_dict['raw_docs'],
                   embedded_docs=save_dict['embedded_docs'],
                   k_for_docs=save_dict['k_for_docs'], model=model,
                   model_kwargs=model_kwargs)
@@ -12,6 +12,7 @@ from torch import Tensor
12
12
  import torch_geometric.typing
13
13
  from torch_geometric.data import Data
14
14
  from torch_geometric.index import index2ptr, ptr2index
15
+ from torch_geometric.io import fs
15
16
  from torch_geometric.typing import pyg_lib
16
17
  from torch_geometric.utils import index_sort, narrow, select, sort_edge_index
17
18
  from torch_geometric.utils.map import map_index
@@ -77,7 +78,7 @@ class ClusterData(torch.utils.data.Dataset):
77
78
  path = osp.join(root_dir, filename or 'metis.pt')
78
79
 
79
80
  if save_dir is not None and osp.exists(path):
80
- self.partition = torch.load(path)
81
+ self.partition = fs.torch_load(path)
81
82
  else:
82
83
  if log: # pragma: no cover
83
84
  print('Computing METIS partitioning...', file=sys.stderr)
@@ -234,9 +235,9 @@ class ClusterData(torch.utils.data.Dataset):
234
235
  class ClusterLoader(torch.utils.data.DataLoader):
235
236
  r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm
236
237
  for Training Deep and Large Graph Convolutional Networks"
237
- <https://arxiv.org/abs/1905.07953>`_ paper which merges partioned subgraphs
238
- and their between-cluster links from a large-scale graph data object to
239
- form a mini-batch.
238
+ <https://arxiv.org/abs/1905.07953>`_ paper which merges partitioned
239
+ subgraphs and their between-cluster links from a large-scale graph data
240
+ object to form a mini-batch.
240
241
 
241
242
  .. note::
242
243
 
@@ -251,7 +252,7 @@ class ClusterLoader(torch.utils.data.DataLoader):
251
252
 
252
253
  Args:
253
254
  cluster_data (torch_geometric.loader.ClusterData): The already
254
- partioned data object.
255
+ partitioned data object.
255
256
  **kwargs (optional): Additional arguments of
256
257
  :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
257
258
  :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
@@ -4,6 +4,7 @@ from typing import Optional
4
4
  import torch
5
5
  from tqdm import tqdm
6
6
 
7
+ from torch_geometric.io import fs
7
8
  from torch_geometric.typing import SparseTensor
8
9
 
9
10
 
@@ -77,7 +78,7 @@ class GraphSAINTSampler(torch.utils.data.DataLoader):
77
78
  if self.sample_coverage > 0:
78
79
  path = osp.join(save_dir or '', self._filename)
79
80
  if save_dir is not None and osp.exists(path): # pragma: no cover
80
- self.node_norm, self.edge_norm = torch.load(path)
81
+ self.node_norm, self.edge_norm = fs.torch_load(path)
81
82
  else:
82
83
  self.node_norm, self.edge_norm = self._compute_norm()
83
84
  if save_dir is not None: # pragma: no cover
@@ -148,7 +148,7 @@ def indices_complete_check(
148
148
  if isinstance(aux, Tensor):
149
149
  aux = aux.cpu().numpy()
150
150
 
151
- assert np.all(np.in1d(out,
151
+ assert np.all(np.isin(out,
152
152
  aux)), "Not all output nodes are in aux nodes!"
153
153
  outs.append(out)
154
154
 
@@ -236,7 +236,7 @@ def create_batchwise_out_aux_pairs(
236
236
  logits[tele_set, i] = 1. / len(tele_set)
237
237
 
238
238
  new_logits = logits.clone()
239
- for i in range(num_iter):
239
+ for _ in range(num_iter):
240
240
  new_logits = adj @ new_logits * (1 - alpha) + alpha * logits
241
241
 
242
242
  inds = new_logits.argsort(0)
@@ -498,7 +498,7 @@ class IBMBBaseLoader(torch.utils.data.DataLoader):
498
498
  assert adj is not None
499
499
 
500
500
  for out, aux in pbar:
501
- mask = torch.from_numpy(np.in1d(aux, out))
501
+ mask = torch.from_numpy(np.isin(aux, out))
502
502
  if isinstance(aux, np.ndarray):
503
503
  aux = torch.from_numpy(aux)
504
504
  subg = get_subgraph(aux, graph, return_edge_index_type, adj,
@@ -541,7 +541,7 @@ class IBMBBaseLoader(torch.utils.data.DataLoader):
541
541
  out, aux = zip(*data_list)
542
542
  out = np.concatenate(out)
543
543
  aux = np.unique(np.concatenate(aux))
544
- mask = torch.from_numpy(np.in1d(aux, out))
544
+ mask = torch.from_numpy(np.isin(aux, out))
545
545
  aux = torch.from_numpy(aux)
546
546
 
547
547
  subg = get_subgraph(aux, self.graph, self.return_edge_index_type,
@@ -70,7 +70,7 @@ class LinkLoader(
70
70
  :obj:`edge_label_index`. If set, temporal sampling will be
71
71
  used such that neighbors are guaranteed to fulfill temporal
72
72
  constraints, *i.e.*, neighbors have an earlier timestamp than
73
- the ouput edge. The :obj:`time_attr` needs to be set for this
73
+ the output edge. The :obj:`time_attr` needs to be set for this
74
74
  to work. (default: :obj:`None`)
75
75
  neg_sampling (NegativeSampling, optional): The negative sampling
76
76
  configuration.
@@ -117,7 +117,7 @@ class LinkNeighborLoader(LinkLoader):
117
117
  :obj:`edge_label_index`. If set, temporal sampling will be
118
118
  used such that neighbors are guaranteed to fulfill temporal
119
119
  constraints, *i.e.*, neighbors have an earlier timestamp than
120
- the ouput edge. The :obj:`time_attr` needs to be set for this
120
+ the output edge. The :obj:`time_attr` needs to be set for this
121
121
  to work. (default: :obj:`None`)
122
122
  replace (bool, optional): If set to :obj:`True`, will sample with
123
123
  replacement. (default: :obj:`False`)
@@ -170,6 +170,7 @@ class LinkNeighborLoader(LinkLoader):
170
170
  negative sampling mode.
171
171
  If set to :obj:`None`, no negative sampling strategy is applied.
172
172
  (default: :obj:`None`)
173
+ For example, use :obj:`neg_sampling=dict(mode='binary', amount=0.5)`.
173
174
  neg_sampling_ratio (int or float, optional): The ratio of sampled
174
175
  negative edges to the number of positive edges.
175
176
  Deprecated in favor of the :obj:`neg_sampling` argument.
@@ -106,9 +106,9 @@ class MultithreadingMixin:
106
106
  def _mt_init_fn(self, worker_id: int) -> None:
107
107
  try:
108
108
  torch.set_num_threads(int(self._worker_threads))
109
- except IndexError:
109
+ except IndexError as e:
110
110
  raise ValueError(f"Cannot set {self.worker_threads} threads "
111
- f"in worker {worker_id}")
111
+ f"in worker {worker_id}") from e
112
112
 
113
113
  # Chain worker init functions:
114
114
  self._old_worker_init_fn(worker_id)
@@ -213,9 +213,9 @@ class AffinityMixin:
213
213
 
214
214
  psutil.Process().cpu_affinity(worker_cores)
215
215
 
216
- except IndexError:
216
+ except IndexError as e:
217
217
  raise ValueError(f"Cannot use CPU affinity for worker ID "
218
- f"{worker_id} on CPU {self.loader_cores}")
218
+ f"{worker_id} on CPU {self.loader_cores}") from e
219
219
 
220
220
  # Chain worker init functions:
221
221
  self._old_worker_init_fn(worker_id)
@@ -248,7 +248,8 @@ class AffinityMixin:
248
248
  warnings.warn(
249
249
  "Due to conflicting parallelization methods it is not advised "
250
250
  "to use affinitization with 'HeteroData' datasets. "
251
- "Use `enable_multithreading` for better performance.")
251
+ "Use `enable_multithreading` for better performance.",
252
+ stacklevel=2)
252
253
 
253
254
  self.loader_cores = loader_cores[:] if loader_cores else None
254
255
  if self.loader_cores is None:
@@ -14,7 +14,7 @@ class NeighborLoader(NodeLoader):
14
14
  This loader allows for mini-batch training of GNNs on large-scale graphs
15
15
  where full-batch training is not feasible.
16
16
 
17
- More specifically, :obj:`num_neighbors` denotes how much neighbors are
17
+ More specifically, :obj:`num_neighbors` denotes how many neighbors are
18
18
  sampled for each node in each iteration.
19
19
  :class:`~torch_geometric.loader.NeighborLoader` takes in this list of
20
20
  :obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for
@@ -72,9 +72,9 @@ class NeighborSampler(torch.utils.data.DataLoader):
72
72
  `examples/reddit.py
73
73
  <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
74
74
  reddit.py>`_ or
75
- `examples/ogbn_products_sage.py
75
+ `examples/ogbn_train.py
76
76
  <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
77
- ogbn_products_sage.py>`_.
77
+ ogbn_train.py>`_.
78
78
 
79
79
  Args:
80
80
  edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a