pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (228) hide show
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_trim_to_layer.py +2 -2
  215. torch_geometric/utils/convert.py +17 -10
  216. torch_geometric/utils/cross_entropy.py +34 -13
  217. torch_geometric/utils/embedding.py +91 -2
  218. torch_geometric/utils/geodesic.py +4 -3
  219. torch_geometric/utils/influence.py +279 -0
  220. torch_geometric/utils/map.py +13 -9
  221. torch_geometric/utils/nested.py +1 -1
  222. torch_geometric/utils/smiles.py +3 -3
  223. torch_geometric/utils/sparse.py +7 -14
  224. torch_geometric/visualization/__init__.py +2 -1
  225. torch_geometric/visualization/graph.py +250 -5
  226. torch_geometric/warnings.py +11 -2
  227. torch_geometric/nn/nlp/__init__.py +0 -7
  228. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -0,0 +1,199 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torch_geometric.data import FeatureStore
7
+ from torch_geometric.distributed import LocalGraphStore
8
+ from torch_geometric.sampler import (
9
+ BidirectionalNeighborSampler,
10
+ NodeSamplerInput,
11
+ SamplerOutput,
12
+ )
13
+ from torch_geometric.utils import index_sort
14
+
15
+ # A representation of an edge index, following the possible formats:
16
+ # * default: Tensor, size = [2, num_edges]
17
+ # * Tensor[0, :] == row, Tensor[1, :] == col
18
+ # * COO: (row, col)
19
+ # * CSC: (row, colptr)
20
+ # * CSR: (rowptr, col)
21
+ _EdgeTensorType = Union[Tensor, Tuple[Tensor, Tensor]]
22
+
23
+
24
class NeighborSamplingRAGGraphStore(LocalGraphStore):
    """Neighbor sampling based graph-store to store & retrieve graph data.

    The store keeps a lazily constructed
    :class:`BidirectionalNeighborSampler` that is rebuilt whenever the
    feature store, the edge ids or the edge index change.
    """
    def __init__(  # type: ignore[no-untyped-def]
        self,
        feature_store: Optional[FeatureStore] = None,
        **kwargs,
    ):
        """Initializes the graph store.
        Optional feature store and neighbor sampling settings.

        Args:
            feature_store (optional): The feature store to use.
                None if not yet registered.
            **kwargs (optional):
                Additional keyword arguments for neighbor sampling.
        """
        self.feature_store = feature_store
        self.sample_kwargs = kwargs
        self._sampler_is_initialized = False
        self._config: Dict[str, Any] = {}

        # to be set by the config
        self.num_neighbors = None
        # Permutation produced when the `edge_index` setter sorts the edges.
        # It stays `None` until that setter has been used, in which case
        # sampled edge ids need no remapping.
        self.perm: Optional[Tensor] = None
        super().__init__()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the config for the graph store."""
        return self._config

    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:
        """Set an attribute from the config.

        Args:
            config (Dict[str, Any]): Config dictionary
            attr_name (str): Name of attribute to set

        Raises:
            ValueError: If required attribute not found in config
        """
        if attr_name not in config:
            raise ValueError(
                f"Required config parameter '{attr_name}' not found")
        setattr(self, attr_name, config[attr_name])

    @config.setter  # type: ignore
    def config(self, config: Dict[str, Any]) -> None:
        """Set the config for the graph store.

        Args:
            config (Dict[str, Any]):
                Config dictionary containing required parameters
                (currently ``"num_neighbors"``).

        Raises:
            ValueError: If required parameters missing from config
        """
        self._set_from_config(config, "num_neighbors")
        # Keep an already constructed sampler in sync with the new setting.
        if hasattr(self, 'sampler'):
            self.sampler.num_neighbors = (  # type: ignore[has-type]
                self.num_neighbors)

        self._config = config

    def _init_sampler(self) -> None:
        """Initializes neighbor sampler with the registered feature store.

        Raises:
            AttributeError: If no feature store has been registered yet.
        """
        if self.feature_store is None:
            raise AttributeError("Feature store not registered yet.")
        assert self.num_neighbors is not None, \
            "Please set num_neighbors through config"
        self.sampler = BidirectionalNeighborSampler(
            data=(self.feature_store, self), num_neighbors=self.num_neighbors,
            **self.sample_kwargs)
        self._sampler_is_initialized = True

    def register_feature_store(self, feature_store: FeatureStore) -> None:
        """Registers a feature store with the graph store.

        :param feature_store: The feature store to register.
        """
        self.feature_store = feature_store
        # Force the sampler to be rebuilt against the new feature store.
        self._sampler_is_initialized = False

    def put_edge_id(  # type: ignore[no-untyped-def]
            self, edge_id: Tensor, *args, **kwargs) -> bool:
        """Stores an edge ID in the graph store.

        :param edge_id: The edge ID to store.
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)
        self._sampler_is_initialized = False
        return ret

    @property
    def edge_index(self) -> _EdgeTensorType:
        """Gets the edge index of the graph.

        The lookup re-uses the arguments of the most recent
        :meth:`put_edge_index` call.

        :return: The edge index as a tensor.
        """
        return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)

    def put_edge_index(  # type: ignore[no-untyped-def]
            self, edge_index: _EdgeTensorType, *args, **kwargs) -> bool:
        """Stores an edge index in the graph store.

        :param edge_index: The edge index to store.
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_index(edge_index, *args, **kwargs)
        # HACK: remember the attributes so that the `edge_index` property
        # can fetch the same edge index again later.
        self.edge_idx_args = args
        self.edge_idx_kwargs = kwargs
        self._sampler_is_initialized = False
        return ret

    # HACKY
    @edge_index.setter  # type: ignore
    def edge_index(self, edge_index: _EdgeTensorType) -> None:
        """Sets the edge index of the graph.

        The edges are stably sorted by their second (``col``) entry before
        being stored, and the sorting permutation is kept in ``self.perm``
        so that sampled edge ids can be mapped back to the original order.

        :param edge_index: The edge index to set, either a tensor of shape
            ``[2, num_edges]`` or a ``(row, col)`` tuple of tensors.
        """
        # correct since we make node list from triples
        if isinstance(edge_index, Tensor):
            num_nodes = int(edge_index.max()) + 1
        else:
            assert isinstance(edge_index, tuple) \
                and isinstance(edge_index[0], Tensor) \
                and isinstance(edge_index[1], Tensor), (
                    "edge_index must be a Tensor of [2, num_edges] "
                    "or a tuple of Tensors, (row, col).")

            # Both endpoints have to be inspected: the largest node id may
            # only ever appear as a column entry.
            num_nodes = max(int(edge_index[0].max()),
                            int(edge_index[1].max())) + 1
        attr = dict(
            edge_type=None,
            layout='coo',
            size=(num_nodes, num_nodes),
            is_sorted=False,
        )
        # edge index needs to be sorted here and the perm saved for later
        col_sorted, self.perm = index_sort(edge_index[1], num_nodes,
                                           stable=True)
        row_sorted = edge_index[0][self.perm]
        edge_index_sorted = torch.stack([row_sorted, col_sorted], dim=0)
        self.put_edge_index(edge_index_sorted, **attr)

    def sample_subgraph(
        self,
        seed_nodes: Tensor,
    ) -> SamplerOutput:
        """Sample the graph starting from the given nodes using the
        in-built neighbor sampler.

        The number of hops and neighbors per hop is taken from
        ``num_neighbors``, which is set through :attr:`config`.

        Args:
            seed_nodes (Tensor): Seed nodes to start sampling from.
                Duplicates are removed before sampling.

        Returns:
            SamplerOutput: The sampler output for the input, with edge ids
            mapped back to the order of the originally provided edge index.
        """
        # TODO add support for Hetero
        if not self._sampler_is_initialized:
            self._init_sampler()

        seed_nodes = seed_nodes.unique().contiguous()
        node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)
        out = self.sampler.sample_from_nodes(  # type: ignore[has-type]
            node_sample_input)

        # edge ids need to be remapped to the original indices; this only
        # applies when the edges were permuted by the `edge_index` setter and
        # the sampler actually returned edge ids.
        if self.perm is not None and out.edge is not None:
            out.edge = self.perm[out.edge]

        return out
@@ -0,0 +1,125 @@
1
+ # mypy: ignore-errors
2
+ import os
3
+ from abc import abstractmethod
4
+ from typing import Any, Callable, Dict, List, Optional, Protocol, Union
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+ from torch_geometric.data import Data
10
+ from torch_geometric.llm.models import SentenceTransformer
11
+ from torch_geometric.llm.utils.backend_utils import batch_knn
12
+
13
+
14
class VectorRetriever(Protocol):
    """Protocol for VectorRAG.

    Implementations expose a single :meth:`query` method that maps a query
    to the retrieved context.
    """
    @abstractmethod
    def query(self, query: Any, **kwargs: Any) -> Data:
        """Retrieve a context for a given query.

        Args:
            query (Any): The query to retrieve a context for.
            **kwargs (optional): Additional retriever-specific keyword
                arguments.
        """
        ...
20
+
21
+
22
class DocumentRetriever(VectorRetriever):
    """Retrieve documents from a vector database."""
    def __init__(self, raw_docs: List[str],
                 embedded_docs: Optional[Tensor] = None, k_for_docs: int = 2,
                 model: Optional[Union[SentenceTransformer, torch.nn.Module,
                                       Callable]] = None,
                 model_kwargs: Optional[Dict[str, Any]] = None):
        """Retrieve documents from a vector database.

        Args:
            raw_docs: List[str]: List of raw documents.
            embedded_docs: Optional[Tensor]: Embedded documents.
                If omitted, they are computed by calling :obj:`model` on
                :obj:`raw_docs`.
            k_for_docs: int: Number of documents to retrieve.
            model: Optional[Union[SentenceTransformer, torch.nn.Module]]:
                Model to use for encoding. Required if :obj:`embedded_docs`
                is not given or if string queries are used.
            model_kwargs: Optional[Dict[str, Any]]:
                Keyword arguments to pass to the model.
        """
        self.raw_docs = raw_docs
        self.embedded_docs = embedded_docs
        self.k_for_docs = k_for_docs
        self.model = model

        # Always define both attributes so that `query` never fails on a
        # missing attribute or on unpacking `None`.
        self.encoder = self.model
        # Work on a copy: "verbose" is popped below and the caller's
        # dictionary must not be modified as a side effect.
        self.model_kwargs: Dict[str, Any] = dict(model_kwargs or {})

        if self.embedded_docs is None:
            assert self.model is not None, \
                "Model must be provided if embedded_docs is not provided"
            self.embedded_docs = self.encoder(self.raw_docs,
                                              **self.model_kwargs)
            # we don't want to print the verbose output in `query`
            self.model_kwargs.pop("verbose", None)

    def query(self, query: Union[str, Tensor]) -> List[str]:
        """Retrieve documents from the vector database.

        Args:
            query: Union[str, Tensor]: Query to retrieve documents for.
                A string is encoded with the model first, anything else is
                used as the query embedding directly.

        Returns:
            List[str]: Documents retrieved from the vector database.

        Raises:
            ValueError: If a string query is given but no model is
                available to encode it.
        """
        if isinstance(query, str):
            if self.encoder is None:
                raise ValueError(
                    "A model must be provided to encode string queries")
            with torch.no_grad():
                query_enc = self.encoder(query, **self.model_kwargs)
        else:
            query_enc = query

        selected_doc_idxs, _ = next(
            batch_knn(query_enc, self.embedded_docs, self.k_for_docs))
        return [self.raw_docs[i] for i in selected_doc_idxs]

    def save(self, path: str) -> None:
        """Save the DocumentRetriever instance to disk.

        Only the raw documents, their embeddings and :obj:`k_for_docs` are
        stored; the model is not serialized.

        Args:
            path: str: Path where to save the retriever.
        """
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

        # Prepare data to save
        save_dict = {
            'raw_docs': self.raw_docs,
            'embedded_docs': self.embedded_docs,
            'k_for_docs': self.k_for_docs,
        }

        # We do not serialize the model
        torch.save(save_dict, path)

    @classmethod
    def load(
        cls, path: str, model: Union[SentenceTransformer, torch.nn.Module,
                                     Callable],
        model_kwargs: Optional[Dict[str, Any]] = None
    ) -> 'DocumentRetriever':
        """Load a DocumentRetriever instance from disk.

        Args:
            path: str: Path to the saved retriever.
            model: Union[SentenceTransformer, torch.nn.Module, Callable]:
                Model to use for encoding.
                The model is not part of the saved file and therefore has
                to be passed again.
            model_kwargs: Optional[Dict[str, Any]]
                Key word args to be passed to model

        Returns:
            DocumentRetriever: The loaded retriever.

        Raises:
            FileNotFoundError: If nothing exists at :obj:`path`.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"No saved document retriever found at {path}")

        # NOTE: `weights_only=False` unpickles arbitrary Python objects.
        # Only load files that come from a trusted source.
        save_dict = torch.load(path, weights_only=False)
        if model_kwargs is not None:
            # Copy so that the caller's dictionary is left untouched.
            model_kwargs = dict(model_kwargs)
            if isinstance(save_dict['embedded_docs'], Tensor):
                # Documents are already embedded, so "verbose" would only
                # affect (and clutter) later calls to `query`.
                model_kwargs.pop("verbose", None)
        # Create a new DocumentRetriever with the loaded data
        return cls(raw_docs=save_dict['raw_docs'],
                   embedded_docs=save_dict['embedded_docs'],
                   k_for_docs=save_dict['k_for_docs'], model=model,
                   model_kwargs=model_kwargs)
@@ -235,9 +235,9 @@ class ClusterData(torch.utils.data.Dataset):
235
235
  class ClusterLoader(torch.utils.data.DataLoader):
236
236
  r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm
237
237
  for Training Deep and Large Graph Convolutional Networks"
238
- <https://arxiv.org/abs/1905.07953>`_ paper which merges partioned subgraphs
239
- and their between-cluster links from a large-scale graph data object to
240
- form a mini-batch.
238
+ <https://arxiv.org/abs/1905.07953>`_ paper which merges partitioned
239
+ subgraphs and their between-cluster links from a large-scale graph data
240
+ object to form a mini-batch.
241
241
 
242
242
  .. note::
243
243
 
@@ -252,7 +252,7 @@ class ClusterLoader(torch.utils.data.DataLoader):
252
252
 
253
253
  Args:
254
254
  cluster_data (torch_geometric.loader.ClusterData): The already
255
- partioned data object.
255
+ partitioned data object.
256
256
  **kwargs (optional): Additional arguments of
257
257
  :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
258
258
  :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
@@ -148,7 +148,7 @@ def indices_complete_check(
148
148
  if isinstance(aux, Tensor):
149
149
  aux = aux.cpu().numpy()
150
150
 
151
- assert np.all(np.in1d(out,
151
+ assert np.all(np.isin(out,
152
152
  aux)), "Not all output nodes are in aux nodes!"
153
153
  outs.append(out)
154
154
 
@@ -236,7 +236,7 @@ def create_batchwise_out_aux_pairs(
236
236
  logits[tele_set, i] = 1. / len(tele_set)
237
237
 
238
238
  new_logits = logits.clone()
239
- for i in range(num_iter):
239
+ for _ in range(num_iter):
240
240
  new_logits = adj @ new_logits * (1 - alpha) + alpha * logits
241
241
 
242
242
  inds = new_logits.argsort(0)
@@ -498,7 +498,7 @@ class IBMBBaseLoader(torch.utils.data.DataLoader):
498
498
  assert adj is not None
499
499
 
500
500
  for out, aux in pbar:
501
- mask = torch.from_numpy(np.in1d(aux, out))
501
+ mask = torch.from_numpy(np.isin(aux, out))
502
502
  if isinstance(aux, np.ndarray):
503
503
  aux = torch.from_numpy(aux)
504
504
  subg = get_subgraph(aux, graph, return_edge_index_type, adj,
@@ -541,7 +541,7 @@ class IBMBBaseLoader(torch.utils.data.DataLoader):
541
541
  out, aux = zip(*data_list)
542
542
  out = np.concatenate(out)
543
543
  aux = np.unique(np.concatenate(aux))
544
- mask = torch.from_numpy(np.in1d(aux, out))
544
+ mask = torch.from_numpy(np.isin(aux, out))
545
545
  aux = torch.from_numpy(aux)
546
546
 
547
547
  subg = get_subgraph(aux, self.graph, self.return_edge_index_type,
@@ -70,7 +70,7 @@ class LinkLoader(
70
70
  :obj:`edge_label_index`. If set, temporal sampling will be
71
71
  used such that neighbors are guaranteed to fulfill temporal
72
72
  constraints, *i.e.*, neighbors have an earlier timestamp than
73
- the ouput edge. The :obj:`time_attr` needs to be set for this
73
+ the output edge. The :obj:`time_attr` needs to be set for this
74
74
  to work. (default: :obj:`None`)
75
75
  neg_sampling (NegativeSampling, optional): The negative sampling
76
76
  configuration.
@@ -117,7 +117,7 @@ class LinkNeighborLoader(LinkLoader):
117
117
  :obj:`edge_label_index`. If set, temporal sampling will be
118
118
  used such that neighbors are guaranteed to fulfill temporal
119
119
  constraints, *i.e.*, neighbors have an earlier timestamp than
120
- the ouput edge. The :obj:`time_attr` needs to be set for this
120
+ the output edge. The :obj:`time_attr` needs to be set for this
121
121
  to work. (default: :obj:`None`)
122
122
  replace (bool, optional): If set to :obj:`True`, will sample with
123
123
  replacement. (default: :obj:`False`)
@@ -170,6 +170,7 @@ class LinkNeighborLoader(LinkLoader):
170
170
  negative sampling mode.
171
171
  If set to :obj:`None`, no negative sampling strategy is applied.
172
172
  (default: :obj:`None`)
173
+ For example use :obj:`neg_sampling=dict(mode='binary', amount=0.5)`.
173
174
  neg_sampling_ratio (int or float, optional): The ratio of sampled
174
175
  negative edges to the number of positive edges.
175
176
  Deprecated in favor of the :obj:`neg_sampling` argument.
@@ -106,9 +106,9 @@ class MultithreadingMixin:
106
106
  def _mt_init_fn(self, worker_id: int) -> None:
107
107
  try:
108
108
  torch.set_num_threads(int(self._worker_threads))
109
- except IndexError:
109
+ except IndexError as e:
110
110
  raise ValueError(f"Cannot set {self.worker_threads} threads "
111
- f"in worker {worker_id}")
111
+ f"in worker {worker_id}") from e
112
112
 
113
113
  # Chain worker init functions:
114
114
  self._old_worker_init_fn(worker_id)
@@ -213,9 +213,9 @@ class AffinityMixin:
213
213
 
214
214
  psutil.Process().cpu_affinity(worker_cores)
215
215
 
216
- except IndexError:
216
+ except IndexError as e:
217
217
  raise ValueError(f"Cannot use CPU affinity for worker ID "
218
- f"{worker_id} on CPU {self.loader_cores}")
218
+ f"{worker_id} on CPU {self.loader_cores}") from e
219
219
 
220
220
  # Chain worker init functions:
221
221
  self._old_worker_init_fn(worker_id)
@@ -248,7 +248,8 @@ class AffinityMixin:
248
248
  warnings.warn(
249
249
  "Due to conflicting parallelization methods it is not advised "
250
250
  "to use affinitization with 'HeteroData' datasets. "
251
- "Use `enable_multithreading` for better performance.")
251
+ "Use `enable_multithreading` for better performance.",
252
+ stacklevel=2)
252
253
 
253
254
  self.loader_cores = loader_cores[:] if loader_cores else None
254
255
  if self.loader_cores is None:
@@ -14,7 +14,7 @@ class NeighborLoader(NodeLoader):
14
14
  This loader allows for mini-batch training of GNNs on large-scale graphs
15
15
  where full-batch training is not feasible.
16
16
 
17
- More specifically, :obj:`num_neighbors` denotes how much neighbors are
17
+ More specifically, :obj:`num_neighbors` denotes how many neighbors are
18
18
  sampled for each node in each iteration.
19
19
  :class:`~torch_geometric.loader.NeighborLoader` takes in this list of
20
20
  :obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for
@@ -72,9 +72,9 @@ class NeighborSampler(torch.utils.data.DataLoader):
72
72
  `examples/reddit.py
73
73
  <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
74
74
  reddit.py>`_ or
75
- `examples/ogbn_products_sage.py
75
+ `examples/ogbn_train.py
76
76
  <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
77
- ogbn_products_sage.py>`_.
77
+ ogbn_train.py>`_.
78
78
 
79
79
  Args:
80
80
  edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
@@ -27,8 +27,9 @@ class DeviceHelper:
27
27
 
28
28
  if ((self.device.type == 'cuda' and not with_cuda)
29
29
  or (self.device.type == 'xpu' and not with_xpu)):
30
- warnings.warn(f"Requested device '{self.device.type}' is not "
31
- f"available, falling back to CPU")
30
+ warnings.warn(
31
+ f"Requested device '{self.device.type}' is not "
32
+ f"available, falling back to CPU", stacklevel=2)
32
33
  self.device = torch.device('cpu')
33
34
 
34
35
  self.stream = None
@@ -6,7 +6,7 @@ from torch_geometric.data import TemporalData
6
6
 
7
7
 
8
8
  class TemporalDataLoader(torch.utils.data.DataLoader):
9
- r"""A data loader which merges succesive events of a
9
+ r"""A data loader which merges successive events of a
10
10
  :class:`torch_geometric.data.TemporalData` to a mini-batch.
11
11
 
12
12
  Args:
@@ -15,7 +15,7 @@ class TemporalDataLoader(torch.utils.data.DataLoader):
15
15
  batch_size (int, optional): How many samples per batch to load.
16
16
  (default: :obj:`1`)
17
17
  neg_sampling_ratio (float, optional): The ratio of sampled negative
18
- destination nodes to the number of postive destination nodes.
18
+ destination nodes to the number of positive destination nodes.
19
19
  (default: :obj:`0.0`)
20
20
  **kwargs (optional): Additional arguments of
21
21
  :class:`torch.utils.data.DataLoader`.
@@ -178,7 +178,7 @@ def filter_hetero_data(
178
178
  out = copy.copy(data)
179
179
 
180
180
  for node_type in out.node_types:
181
- # Handle the case of disconneted graph sampling:
181
+ # Handle the case of disconnected graph sampling:
182
182
  if node_type not in node_dict:
183
183
  node_dict[node_type] = torch.empty(0, dtype=torch.long)
184
184
 
@@ -186,7 +186,7 @@ def filter_hetero_data(
186
186
  node_dict[node_type])
187
187
 
188
188
  for edge_type in out.edge_types:
189
- # Handle the case of disconneted graph sampling:
189
+ # Handle the case of disconnected graph sampling:
190
190
  if edge_type not in row_dict:
191
191
  row_dict[edge_type] = torch.empty(0, dtype=torch.long)
192
192
  if edge_type not in col_dict:
@@ -256,14 +256,6 @@ def filter_custom_hetero_store(
256
256
  # Construct a new `HeteroData` object:
257
257
  data = custom_cls() if custom_cls is not None else HeteroData()
258
258
 
259
- # Filter edge storage:
260
- # TODO support edge attributes
261
- for attr in graph_store.get_all_edge_attrs():
262
- key = attr.edge_type
263
- if key in row_dict and key in col_dict:
264
- edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
265
- data[attr.edge_type].edge_index = edge_index
266
-
267
259
  # Filter node storage:
268
260
  required_attrs = []
269
261
  for attr in feature_store.get_all_tensor_attrs():
@@ -280,6 +272,14 @@ def filter_custom_hetero_store(
280
272
  for i, attr in enumerate(required_attrs):
281
273
  data[attr.group_name][attr.attr_name] = tensors[i]
282
274
 
275
+ # Filter edge storage:
276
+ # TODO support edge attributes
277
+ for attr in graph_store.get_all_edge_attrs():
278
+ key = attr.edge_type
279
+ if key in row_dict and key in col_dict:
280
+ edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
281
+ data[attr.edge_type].edge_index = edge_index
282
+
283
283
  return data
284
284
 
285
285
 
@@ -1,21 +1,35 @@
1
1
  # flake8: noqa
2
2
 
3
3
  from .link_pred import (
4
+ LinkPredMetric,
5
+ LinkPredMetricCollection,
4
6
  LinkPredPrecision,
5
7
  LinkPredRecall,
6
8
  LinkPredF1,
7
9
  LinkPredMAP,
8
10
  LinkPredNDCG,
9
11
  LinkPredMRR,
12
+ LinkPredHitRatio,
13
+ LinkPredCoverage,
14
+ LinkPredDiversity,
15
+ LinkPredPersonalization,
16
+ LinkPredAveragePopularity,
10
17
  )
11
18
 
12
19
  link_pred_metrics = [
20
+ 'LinkPredMetric',
21
+ 'LinkPredMetricCollection',
13
22
  'LinkPredPrecision',
14
23
  'LinkPredRecall',
15
24
  'LinkPredF1',
16
25
  'LinkPredMAP',
17
26
  'LinkPredNDCG',
18
27
  'LinkPredMRR',
28
+ 'LinkPredHitRatio',
29
+ 'LinkPredCoverage',
30
+ 'LinkPredDiversity',
31
+ 'LinkPredPersonalization',
32
+ 'LinkPredAveragePopularity',
19
33
  ]
20
34
 
21
35
  __all__ = link_pred_metrics