pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229):
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -10,10 +10,16 @@ from packaging.requirements import Requirement
10
10
  from packaging.version import Version
11
11
 
12
12
  import torch_geometric
13
+ import torch_geometric.typing
13
14
  from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
14
15
  from torch_geometric.visualization.graph import has_graphviz
15
16
 
16
17
 
18
+ def is_rag_test() -> bool:
19
+ r"""Whether to run the RAG test suite."""
20
+ return os.getenv('RAG_TEST', '0') == '1'
21
+
22
+
17
23
  def is_full_test() -> bool:
18
24
  r"""Whether to run the full but time-consuming test suite."""
19
25
  return os.getenv('FULL_TEST', '0') == '1'
@@ -32,8 +38,8 @@ def onlyFullTest(func: Callable) -> Callable:
32
38
 
33
39
  def is_distributed_test() -> bool:
34
40
  r"""Whether to run the distributed test suite."""
35
- return ((is_full_test() or os.getenv('DIST_TEST', '0') == '1')
36
- and sys.platform == 'linux' and has_package('pyg_lib'))
41
+ return (os.getenv('DIST_TEST', '0') == '1' and sys.platform == 'linux'
42
+ and has_package('pyg_lib'))
37
43
 
38
44
 
39
45
  def onlyDistributedTest(func: Callable) -> Callable:
@@ -203,6 +209,18 @@ def withPackage(*args: str) -> Callable:
203
209
  return decorator
204
210
 
205
211
 
212
+ def onlyRAG(func: Callable) -> Callable:
213
+ r"""A decorator to specify that this function belongs to the RAG test
214
+ suite.
215
+ """
216
+ import pytest
217
+ func = pytest.mark.rag(func)
218
+ return pytest.mark.skipif(
219
+ not is_rag_test(),
220
+ reason="RAG tests are disabled",
221
+ )(func)
222
+
223
+
206
224
  def withCUDA(func: Callable) -> Callable:
207
225
  r"""A decorator to test both on CPU and CUDA (if available)."""
208
226
  import pytest
@@ -234,8 +252,9 @@ def withDevice(func: Callable) -> Callable:
234
252
  if device:
235
253
  backend = os.getenv('TORCH_BACKEND')
236
254
  if backend is None:
237
- warnings.warn(f"Please specify the backend via 'TORCH_BACKEND' in"
238
- f"order to test against '{device}'")
255
+ warnings.warn(
256
+ f"Please specify the backend via 'TORCH_BACKEND' in"
257
+ f"order to test against '{device}'", stacklevel=2)
239
258
  else:
240
259
  import_module(backend)
241
260
  devices.append(pytest.param(torch.device(device), id=device))
@@ -250,7 +269,7 @@ def withMETIS(func: Callable) -> Callable:
250
269
  with_metis = WITH_METIS
251
270
 
252
271
  if with_metis:
253
- try: # Test that METIS can succesfully execute:
272
+ try: # Test that METIS can successfully execute:
254
273
  # TODO Using `pyg-lib` metis partitioning leads to some weird bugs
255
274
  # in the # CI. As such, we require `torch-sparse` for now.
256
275
  rowptr = torch.tensor([0, 2, 4, 6])
@@ -265,6 +284,17 @@ def withMETIS(func: Callable) -> Callable:
265
284
  )(func)
266
285
 
267
286
 
287
+ def withHashTensor(func: Callable) -> Callable:
288
+ r"""A decorator to only test in case :class:`HashTensor` is available."""
289
+ import pytest
290
+
291
+ return pytest.mark.skipif(
292
+ not torch_geometric.typing.WITH_CPU_HASH_MAP
293
+ and not has_package('pandas'),
294
+ reason="HashTensor dependencies not available",
295
+ )(func)
296
+
297
+
268
298
  def disableExtensions(func: Callable) -> Callable:
269
299
  r"""A decorator to temporarily disable the usage of the
270
300
  :obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib` extension
@@ -73,7 +73,7 @@ def assert_run_mproc(
73
73
  ]
74
74
  results = []
75
75
 
76
- for p, q in zip(procs, queues):
76
+ for p, _ in zip(procs, queues):
77
77
  p.start()
78
78
 
79
79
  for p, q in zip(procs, queues):
@@ -37,6 +37,7 @@ from .rooted_subgraph import RootedEgoNets, RootedRWSubgraph
37
37
  from .largest_connected_components import LargestConnectedComponents
38
38
  from .virtual_node import VirtualNode
39
39
  from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE
40
+ from .add_gpse import AddGPSE
40
41
  from .feature_propagation import FeaturePropagation
41
42
  from .half_hop import HalfHop
42
43
 
@@ -108,6 +109,7 @@ graph_transforms = [
108
109
  'VirtualNode',
109
110
  'AddLaplacianEigenvectorPE',
110
111
  'AddRandomWalkPE',
112
+ 'AddGPSE',
111
113
  'FeaturePropagation',
112
114
  'HalfHop',
113
115
  ]
@@ -0,0 +1,49 @@
1
+ from typing import Any
2
+
3
+ from torch.nn import Module
4
+
5
+ from torch_geometric.data import Data
6
+ from torch_geometric.data.datapipes import functional_transform
7
+ from torch_geometric.transforms import BaseTransform, VirtualNode
8
+
9
+
10
+ @functional_transform('add_gpse')
11
+ class AddGPSE(BaseTransform):
12
+ r"""Adds the GPSE encoding from the `"Graph Positional and Structural
13
+ Encoder" <https://arxiv.org/abs/2307.07107>`_ paper to the given graph
14
+ (functional name: :obj:`add_gpse`).
15
+ To be used with a :class:`~torch_geometric.nn.GPSE` model, which generates
16
+ the actual encodings.
17
+
18
+ Args:
19
+ model (Module): The pre-trained GPSE model.
20
+ use_vn (bool, optional): Whether to use virtual nodes.
21
+ (default: :obj:`True`)
22
+ rand_type (str, optional): Type of random features to use. Options are
23
+ :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
24
+ (default: :obj:`NormalSE`)
25
+
26
+ """
27
+ def __init__(
28
+ self,
29
+ model: Module,
30
+ use_vn: bool = True,
31
+ rand_type: str = 'NormalSE',
32
+ ):
33
+ self.model = model
34
+ self.use_vn = use_vn
35
+ self.vn = VirtualNode()
36
+ self.rand_type = rand_type
37
+
38
+ def forward(self, data: Data) -> Any:
39
+ pass
40
+
41
+ def __call__(self, data: Data) -> Data:
42
+ from torch_geometric.nn.models.gpse import gpse_process
43
+
44
+ data_vn = self.vn(data.clone()) if self.use_vn else data.clone()
45
+ batch_out = gpse_process(self.model, data_vn, 'NormalSE', self.use_vn)
46
+ batch_out = batch_out.to('cpu', non_blocking=True)
47
+ data.pestat_GPSE = batch_out[:-1] if self.use_vn else batch_out
48
+
49
+ return data
@@ -108,13 +108,15 @@ class AddMetaPaths(BaseTransform):
108
108
  **kwargs: bool,
109
109
  ) -> None:
110
110
  if 'drop_orig_edges' in kwargs:
111
- warnings.warn("'drop_orig_edges' is deprecated. Use "
112
- "'drop_orig_edge_types' instead")
111
+ warnings.warn(
112
+ "'drop_orig_edges' is deprecated. Use "
113
+ "'drop_orig_edge_types' instead", stacklevel=2)
113
114
  drop_orig_edge_types = kwargs['drop_orig_edges']
114
115
 
115
116
  if 'drop_unconnected_nodes' in kwargs:
116
- warnings.warn("'drop_unconnected_nodes' is deprecated. Use "
117
- "'drop_unconnected_node_types' instead")
117
+ warnings.warn(
118
+ "'drop_unconnected_nodes' is deprecated. Use "
119
+ "'drop_unconnected_node_types' instead", stacklevel=2)
118
120
  drop_unconnected_node_types = kwargs['drop_unconnected_nodes']
119
121
 
120
122
  for path in metapaths:
@@ -144,7 +146,7 @@ class AddMetaPaths(BaseTransform):
144
146
  if self.max_sample is not None:
145
147
  edge_index, edge_weight = self._sample(edge_index, edge_weight)
146
148
 
147
- for i, edge_type in enumerate(metapath[1:]):
149
+ for edge_type in metapath[1:]:
148
150
  edge_index2, edge_weight2 = self._edge_index(data, edge_type)
149
151
 
150
152
  edge_index, edge_weight = edge_index.matmul(
@@ -276,7 +278,7 @@ class AddRandomMetaPaths(BaseTransform):
276
278
  row = start = torch.randperm(num_nodes)[:num_starts].repeat(
277
279
  self.walks_per_node[j])
278
280
 
279
- for i, edge_type in enumerate(metapath):
281
+ for edge_type in metapath:
280
282
  edge_index = EdgeIndex(
281
283
  data[edge_type].edge_index,
282
284
  sparse_size=data[edge_type].size(),
@@ -92,12 +92,12 @@ class AddLaplacianEigenvectorPE(BaseTransform):
92
92
  from numpy.linalg import eig, eigh
93
93
  eig_fn = eig if not self.is_undirected else eigh
94
94
 
95
- eig_vals, eig_vecs = eig_fn(L.todense()) # type: ignore
95
+ eig_vals, eig_vecs = eig_fn(L.todense())
96
96
  else:
97
97
  from scipy.sparse.linalg import eigs, eigsh
98
98
  eig_fn = eigs if not self.is_undirected else eigsh
99
99
 
100
- eig_vals, eig_vecs = eig_fn( # type: ignore
100
+ eig_vals, eig_vecs = eig_fn(
101
101
  L,
102
102
  k=self.k + 1,
103
103
  which='SR' if not self.is_undirected else 'SA',
@@ -1,5 +1,5 @@
1
1
  import copy
2
- from abc import ABC
2
+ from abc import ABC, abstractmethod
3
3
  from typing import Any
4
4
 
5
5
 
@@ -31,6 +31,7 @@ class BaseTransform(ABC):
31
31
  # Shallow-copy the data so that we prevent in-place data modification.
32
32
  return self.forward(copy.copy(data))
33
33
 
34
+ @abstractmethod
34
35
  def forward(self, data: Any) -> Any:
35
36
  pass
36
37
 
@@ -1,3 +1,5 @@
1
+ from typing import List
2
+
1
3
  import torch
2
4
 
3
5
  from torch_geometric.data import Data
@@ -5,30 +7,78 @@ from torch_geometric.data.datapipes import functional_transform
5
7
  from torch_geometric.transforms import BaseTransform
6
8
 
7
9
 
10
+ class _QhullTransform(BaseTransform):
11
+ r"""Q-hull implementation of delaunay triangulation."""
12
+ def forward(self, data: Data) -> Data:
13
+ assert data.pos is not None
14
+ import scipy.spatial
15
+
16
+ pos = data.pos.cpu().numpy()
17
+ tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
18
+ face = torch.from_numpy(tri.simplices)
19
+
20
+ data.face = face.t().contiguous().to(data.pos.device, torch.long)
21
+ return data
22
+
23
+
24
+ class _ShullTransform(BaseTransform):
25
+ r"""Sweep-hull implementation of delaunay triangulation."""
26
+ def forward(self, data: Data) -> Data:
27
+ assert data.pos is not None
28
+ from torch_delaunay.functional import shull2d
29
+
30
+ face = shull2d(data.pos.cpu())
31
+ data.face = face.t().contiguous().to(data.pos.device)
32
+ return data
33
+
34
+
35
+ class _SequentialTransform(BaseTransform):
36
+ r"""Runs the first successful transformation.
37
+
38
+ All intermediate exceptions are suppressed except the last.
39
+ """
40
+ def __init__(self, transforms: List[BaseTransform]) -> None:
41
+ assert len(transforms) > 0
42
+ self.transforms = transforms
43
+
44
+ def forward(self, data: Data) -> Data:
45
+ for i, transform in enumerate(self.transforms):
46
+ try:
47
+ return transform.forward(data)
48
+ except ImportError as e:
49
+ if i == len(self.transforms) - 1:
50
+ raise e
51
+ return data
52
+
53
+
8
54
  @functional_transform('delaunay')
9
55
  class Delaunay(BaseTransform):
10
56
  r"""Computes the delaunay triangulation of a set of points
11
57
  (functional name: :obj:`delaunay`).
58
+
59
+ .. hint::
60
+ Consider installing the
61
+ `torch_delaunay <https://github.com/ybubnov/torch_delaunay>`_ package
62
+ to speed up computation.
12
63
  """
13
- def forward(self, data: Data) -> Data:
14
- import scipy.spatial
64
+ def __init__(self) -> None:
65
+ self._transform = _SequentialTransform([
66
+ _ShullTransform(),
67
+ _QhullTransform(),
68
+ ])
15
69
 
70
+ def forward(self, data: Data) -> Data:
16
71
  assert data.pos is not None
72
+ device = data.pos.device
17
73
 
18
74
  if data.pos.size(0) < 2:
19
- data.edge_index = torch.tensor([], dtype=torch.long,
20
- device=data.pos.device).view(2, 0)
21
- if data.pos.size(0) == 2:
22
- data.edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long,
23
- device=data.pos.device)
75
+ data.edge_index = torch.empty(2, 0, dtype=torch.long,
76
+ device=device)
77
+ elif data.pos.size(0) == 2:
78
+ data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device)
24
79
  elif data.pos.size(0) == 3:
25
- data.face = torch.tensor([[0], [1], [2]], dtype=torch.long,
26
- device=data.pos.device)
27
- if data.pos.size(0) > 3:
28
- pos = data.pos.cpu().numpy()
29
- tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
30
- face = torch.from_numpy(tri.simplices)
31
-
32
- data.face = face.t().contiguous().to(data.pos.device, torch.long)
80
+ data.face = torch.tensor([[0], [1], [2]], device=device)
81
+ else:
82
+ data = self._transform.forward(data)
33
83
 
34
84
  return data
@@ -8,8 +8,15 @@ from torch_geometric.utils import to_undirected
8
8
 
9
9
  @functional_transform('face_to_edge')
10
10
  class FaceToEdge(BaseTransform):
11
- r"""Converts mesh faces :obj:`[3, num_faces]` to edge indices
12
- :obj:`[2, num_edges]` (functional name: :obj:`face_to_edge`).
11
+ r"""Converts mesh faces of shape :obj:`[3, num_faces]` or
12
+ :obj:`[4, num_faces]` to edge indices of shape :obj:`[2, num_edges]`
13
+ (functional name: :obj:`face_to_edge`).
14
+
15
+ This transform supports both 2D triangular faces, represented by a
16
+ tensor of shape :obj:`[3, num_faces]`, and 3D tetrahedral mesh faces,
17
+ represented by a tensor of shape :obj:`[4, num_faces]`. It will convert
18
+ these faces into edge indices, where each edge is defined by the indices
19
+ of its two endpoints.
13
20
 
14
21
  Args:
15
22
  remove_faces (bool, optional): If set to :obj:`False`, the face tensor
@@ -22,7 +29,29 @@ class FaceToEdge(BaseTransform):
22
29
  if hasattr(data, 'face'):
23
30
  assert data.face is not None
24
31
  face = data.face
25
- edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
32
+
33
+ if face.size(0) not in [3, 4]:
34
+ raise RuntimeError(f"Expected 'face' tensor with shape "
35
+ f"[3, num_faces] or [4, num_faces] "
36
+ f"(got {list(face.size())})")
37
+
38
+ if face.size()[0] == 3:
39
+ edge_index = torch.cat([
40
+ face[:2],
41
+ face[1:],
42
+ face[::2],
43
+ ], dim=1)
44
+ else:
45
+ assert face.size()[0] == 4
46
+ edge_index = torch.cat([
47
+ face[:2],
48
+ face[1:3],
49
+ face[2:4],
50
+ face[::2],
51
+ face[1::2],
52
+ face[::3],
53
+ ], dim=1)
54
+
26
55
  edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
27
56
 
28
57
  data.edge_index = edge_index
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Tuple
1
+ from typing import Any, Dict, Optional, Tuple
2
2
 
3
3
  import numpy as np
4
4
  import torch
@@ -78,18 +78,17 @@ class GDC(BaseTransform):
78
78
  self_loop_weight: float = 1.,
79
79
  normalization_in: str = 'sym',
80
80
  normalization_out: str = 'col',
81
- diffusion_kwargs: Dict[str, Any] = dict(method='ppr', alpha=0.15),
82
- sparsification_kwargs: Dict[str, Any] = dict(
83
- method='threshold',
84
- avg_degree=64,
85
- ),
81
+ diffusion_kwargs: Optional[Dict[str, Any]] = None,
82
+ sparsification_kwargs: Optional[Dict[str, Any]] = None,
86
83
  exact: bool = True,
87
84
  ) -> None:
88
85
  self.self_loop_weight = self_loop_weight
89
86
  self.normalization_in = normalization_in
90
87
  self.normalization_out = normalization_out
91
- self.diffusion_kwargs = diffusion_kwargs
92
- self.sparsification_kwargs = sparsification_kwargs
88
+ self.diffusion_kwargs = diffusion_kwargs or dict(
89
+ method='ppr', alpha=0.15)
90
+ self.sparsification_kwargs = sparsification_kwargs or dict(
91
+ method='threshold', avg_degree=64)
93
92
  self.exact = exact
94
93
 
95
94
  if self_loop_weight:
@@ -47,7 +47,7 @@ class LargestConnectedComponents(BaseTransform):
47
47
  return data
48
48
 
49
49
  _, count = np.unique(component, return_counts=True)
50
- subset_np = np.in1d(component, count.argsort()[-self.num_components:])
50
+ subset_np = np.isin(component, count.argsort()[-self.num_components:])
51
51
  subset = torch.from_numpy(subset_np)
52
52
  subset = subset.to(data.edge_index.device, torch.bool)
53
53
 
@@ -19,7 +19,11 @@ def get_attrs_with_suffix(
19
19
  return [key for key in store.keys() if key.endswith(suffix)]
20
20
 
21
21
 
22
- def get_mask_size(attr: str, store: BaseStorage, size: Optional[int]) -> int:
22
+ def get_mask_size(
23
+ attr: str,
24
+ store: BaseStorage,
25
+ size: Optional[int],
26
+ ) -> Optional[int]:
23
27
  if size is not None:
24
28
  return size
25
29
  return store.num_edges if store.is_edge_attr(attr) else store.num_nodes
@@ -1,4 +1,4 @@
1
- from typing import List, Union
1
+ from typing import List, Optional, Union
2
2
 
3
3
  from torch_geometric.data import Data, HeteroData
4
4
  from torch_geometric.data.datapipes import functional_transform
@@ -14,8 +14,8 @@ class NormalizeFeatures(BaseTransform):
14
14
  attrs (List[str]): The names of attributes to normalize.
15
15
  (default: :obj:`["x"]`)
16
16
  """
17
- def __init__(self, attrs: List[str] = ["x"]):
18
- self.attrs = attrs
17
+ def __init__(self, attrs: Optional[List[str]] = None) -> None:
18
+ self.attrs = attrs or ["x"]
19
19
 
20
20
  def forward(
21
21
  self,
@@ -245,7 +245,7 @@ class RandomLinkSplit(BaseTransform):
245
245
  warnings.warn(
246
246
  f"There are not enough negative edges to satisfy "
247
247
  "the provided sampling ratio. The ratio will be "
248
- f"adjusted to {ratio:.2f}.")
248
+ f"adjusted to {ratio:.2f}.", stacklevel=2)
249
249
  num_neg_train = int((num_neg_train / num_neg) * num_neg_found)
250
250
  num_neg_val = int((num_neg_val / num_neg) * num_neg_found)
251
251
  num_neg_test = num_neg_found - num_neg_train - num_neg_val
@@ -1,4 +1,4 @@
1
- from typing import List, Union
1
+ from typing import List, Optional, Union
2
2
 
3
3
  from torch_geometric.data import Data, HeteroData
4
4
  from torch_geometric.data.datapipes import functional_transform
@@ -22,9 +22,11 @@ class RemoveDuplicatedEdges(BaseTransform):
22
22
  """
23
23
  def __init__(
24
24
  self,
25
- key: Union[str, List[str]] = ['edge_attr', 'edge_weight'],
25
+ key: Optional[Union[str, List[str]]] = None,
26
26
  reduce: str = "add",
27
27
  ) -> None:
28
+ key = key or ['edge_attr', 'edge_weight']
29
+
28
30
  if isinstance(key, str):
29
31
  key = [key]
30
32
 
@@ -94,7 +94,7 @@ class RootedSubgraph(BaseTransform, ABC):
94
94
  arange = torch.arange(n_id.size(0), device=data.edge_index.device)
95
95
  node_map = data.edge_index.new_ones(num_nodes, num_nodes)
96
96
  node_map[n_sub_batch, n_id] = arange
97
- sub_edge_index += (arange * data.num_nodes)[e_sub_batch]
97
+ sub_edge_index += (arange * num_nodes)[e_sub_batch]
98
98
  sub_edge_index = node_map.view(-1)[sub_edge_index]
99
99
 
100
100
  return sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch
torch_geometric/typing.py CHANGED
@@ -3,7 +3,7 @@ import os
3
3
  import sys
4
4
  import typing
5
5
  import warnings
6
- from typing import Any, Dict, List, Optional, Set, Tuple, Union
6
+ from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union
7
7
 
8
8
  import numpy as np
9
9
  import torch
@@ -15,8 +15,9 @@ WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2
15
15
  WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3
16
16
  WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
17
17
  WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5
18
- WITH_PT111 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 11
19
- WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12
18
+ WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6
19
+ WITH_PT27 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 7
20
+ WITH_PT28 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 8
20
21
  WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13
21
22
 
22
23
  WITH_WINDOWS = os.name == 'nt'
@@ -63,10 +64,21 @@ try:
63
64
  pyg_lib.sampler.neighbor_sample).parameters)
64
65
  WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature(
65
66
  pyg_lib.sampler.neighbor_sample).parameters)
67
+ try:
68
+ torch.classes.pyg.CPUHashMap # noqa: B018
69
+ WITH_CPU_HASH_MAP = True
70
+ except Exception:
71
+ WITH_CPU_HASH_MAP = False
72
+ try:
73
+ torch.classes.pyg.CUDAHashMap # noqa: B018
74
+ WITH_CUDA_HASH_MAP = True
75
+ except Exception:
76
+ WITH_CUDA_HASH_MAP = False
66
77
  except Exception as e:
67
78
  if not isinstance(e, ImportError): # pragma: no cover
68
- warnings.warn(f"An issue occurred while importing 'pyg-lib'. "
69
- f"Disabling its usage. Stacktrace: {e}")
79
+ warnings.warn(
80
+ f"An issue occurred while importing 'pyg-lib'. "
81
+ f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
70
82
  pyg_lib = object
71
83
  WITH_PYG_LIB = False
72
84
  WITH_GMM = False
@@ -77,14 +89,41 @@ except Exception as e:
77
89
  WITH_METIS = False
78
90
  WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False
79
91
  WITH_WEIGHTED_NEIGHBOR_SAMPLE = False
92
+ WITH_CPU_HASH_MAP = False
93
+ WITH_CUDA_HASH_MAP = False
94
+
95
+ if WITH_CPU_HASH_MAP:
96
+ CPUHashMap: TypeAlias = torch.classes.pyg.CPUHashMap # type: ignore[name-defined] # noqa: E501
97
+ else:
98
+
99
+ class CPUHashMap: # type: ignore
100
+ def __init__(self, key: Tensor) -> None:
101
+ raise ImportError("'CPUHashMap' requires 'pyg-lib'")
102
+
103
+ def get(self, query: Tensor) -> Tensor:
104
+ raise ImportError("'CPUHashMap' requires 'pyg-lib'")
105
+
106
+
107
+ if WITH_CUDA_HASH_MAP:
108
+ CUDAHashMap: TypeAlias = torch.classes.pyg.CUDAHashMap # type: ignore[name-defined] # noqa: E501
109
+ else:
110
+
111
+ class CUDAHashMap: # type: ignore
112
+ def __init__(self, key: Tensor) -> None:
113
+ raise ImportError("'CUDAHashMap' requires 'pyg-lib'")
114
+
115
+ def get(self, query: Tensor) -> Tensor:
116
+ raise ImportError("'CUDAHashMap' requires 'pyg-lib'")
117
+
80
118
 
81
119
  try:
82
120
  import torch_scatter # noqa
83
121
  WITH_TORCH_SCATTER = True
84
122
  except Exception as e:
85
123
  if not isinstance(e, ImportError): # pragma: no cover
86
- warnings.warn(f"An issue occurred while importing 'torch-scatter'. "
87
- f"Disabling its usage. Stacktrace: {e}")
124
+ warnings.warn(
125
+ f"An issue occurred while importing 'torch-scatter'. "
126
+ f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
88
127
  torch_scatter = object
89
128
  WITH_TORCH_SCATTER = False
90
129
 
@@ -94,8 +133,9 @@ try:
94
133
  WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in torch_cluster.knn.__doc__
95
134
  except Exception as e:
96
135
  if not isinstance(e, ImportError): # pragma: no cover
97
- warnings.warn(f"An issue occurred while importing 'torch-cluster'. "
98
- f"Disabling its usage. Stacktrace: {e}")
136
+ warnings.warn(
137
+ f"An issue occurred while importing 'torch-cluster'. "
138
+ f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
99
139
  WITH_TORCH_CLUSTER = False
100
140
  WITH_TORCH_CLUSTER_BATCH_SIZE = False
101
141
 
@@ -112,7 +152,7 @@ except Exception as e:
112
152
  if not isinstance(e, ImportError): # pragma: no cover
113
153
  warnings.warn(
114
154
  f"An issue occurred while importing 'torch-spline-conv'. "
115
- f"Disabling its usage. Stacktrace: {e}")
155
+ f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
116
156
  WITH_TORCH_SPLINE_CONV = False
117
157
 
118
158
  try:
@@ -121,8 +161,9 @@ try:
121
161
  WITH_TORCH_SPARSE = True
122
162
  except Exception as e:
123
163
  if not isinstance(e, ImportError): # pragma: no cover
124
- warnings.warn(f"An issue occurred while importing 'torch-sparse'. "
125
- f"Disabling its usage. Stacktrace: {e}")
164
+ warnings.warn(
165
+ f"An issue occurred while importing 'torch-sparse'. "
166
+ f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
126
167
  WITH_TORCH_SPARSE = False
127
168
 
128
169
  class SparseStorage: # type: ignore
@@ -306,6 +347,8 @@ class EdgeTypeStr(str):
306
347
  r"""A helper class to construct serializable edge types by merging an edge
307
348
  type tuple into a single string.
308
349
  """
350
+ edge_type: tuple[str, str, str]
351
+
309
352
  def __new__(cls, *args: Any) -> 'EdgeTypeStr':
310
353
  if isinstance(args[0], (list, tuple)):
311
354
  # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`:
@@ -313,27 +356,37 @@ class EdgeTypeStr(str):
313
356
 
314
357
  if len(args) == 1 and isinstance(args[0], str):
315
358
  arg = args[0] # An edge type string was passed.
359
+ edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT))
360
+ if len(edge_type) != 3:
361
+ raise ValueError(f"Cannot convert the edge type '{arg}' to a "
362
+ f"tuple since it holds invalid characters")
316
363
 
317
364
  elif len(args) == 2 and all(isinstance(arg, str) for arg in args):
318
365
  # A `(src, dst)` edge type was passed - add `DEFAULT_REL`:
319
- arg = EDGE_TYPE_STR_SPLIT.join((args[0], DEFAULT_REL, args[1]))
366
+ edge_type = (args[0], DEFAULT_REL, args[1])
367
+ arg = EDGE_TYPE_STR_SPLIT.join(edge_type)
320
368
 
321
369
  elif len(args) == 3 and all(isinstance(arg, str) for arg in args):
322
370
  # A `(src, rel, dst)` edge type was passed:
371
+ edge_type = tuple(args)
323
372
  arg = EDGE_TYPE_STR_SPLIT.join(args)
324
373
 
325
374
  else:
326
375
  raise ValueError(f"Encountered invalid edge type '{args}'")
327
376
 
328
- return str.__new__(cls, arg)
377
+ out = str.__new__(cls, arg)
378
+ out.edge_type = edge_type # type: ignore
379
+ return out
329
380
 
330
381
  def to_tuple(self) -> EdgeType:
331
382
  r"""Returns the original edge type."""
332
- out = tuple(self.split(EDGE_TYPE_STR_SPLIT))
333
- if len(out) != 3:
383
+ if len(self.edge_type) != 3:
334
384
  raise ValueError(f"Cannot convert the edge type '{self}' to a "
335
385
  f"tuple since it holds invalid characters")
336
- return out
386
+ return self.edge_type
387
+
388
+ def __reduce__(self) -> tuple[Any, Any]:
389
+ return (self.__class__, (self.edge_type, ))
337
390
 
338
391
 
339
392
  # There exist some short-cuts to query edge-types (given that the full triplet