pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. See the package's registry page for more details.

Files changed (268):
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -10,10 +10,16 @@ from packaging.requirements import Requirement
10
10
  from packaging.version import Version
11
11
 
12
12
  import torch_geometric
13
+ import torch_geometric.typing
13
14
  from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
14
15
  from torch_geometric.visualization.graph import has_graphviz
15
16
 
16
17
 
18
+ def is_rag_test() -> bool:
19
+ r"""Whether to run the RAG test suite."""
20
+ return os.getenv('RAG_TEST', '0') == '1'
21
+
22
+
17
23
  def is_full_test() -> bool:
18
24
  r"""Whether to run the full but time-consuming test suite."""
19
25
  return os.getenv('FULL_TEST', '0') == '1'
@@ -32,8 +38,8 @@ def onlyFullTest(func: Callable) -> Callable:
32
38
 
33
39
  def is_distributed_test() -> bool:
34
40
  r"""Whether to run the distributed test suite."""
35
- return ((is_full_test() or os.getenv('DIST_TEST', '0') == '1')
36
- and sys.platform == 'linux' and has_package('pyg_lib'))
41
+ return (os.getenv('DIST_TEST', '0') == '1' and sys.platform == 'linux'
42
+ and has_package('pyg_lib'))
37
43
 
38
44
 
39
45
  def onlyDistributedTest(func: Callable) -> Callable:
@@ -203,6 +209,18 @@ def withPackage(*args: str) -> Callable:
203
209
  return decorator
204
210
 
205
211
 
212
+ def onlyRAG(func: Callable) -> Callable:
213
+ r"""A decorator to specify that this function belongs to the RAG test
214
+ suite.
215
+ """
216
+ import pytest
217
+ func = pytest.mark.rag(func)
218
+ return pytest.mark.skipif(
219
+ not is_rag_test(),
220
+ reason="RAG tests are disabled",
221
+ )(func)
222
+
223
+
206
224
  def withCUDA(func: Callable) -> Callable:
207
225
  r"""A decorator to test both on CPU and CUDA (if available)."""
208
226
  import pytest
@@ -234,8 +252,9 @@ def withDevice(func: Callable) -> Callable:
234
252
  if device:
235
253
  backend = os.getenv('TORCH_BACKEND')
236
254
  if backend is None:
237
- warnings.warn(f"Please specify the backend via 'TORCH_BACKEND' in"
238
- f"order to test against '{device}'")
255
+ warnings.warn(
256
+ f"Please specify the backend via 'TORCH_BACKEND' in"
257
+ f"order to test against '{device}'", stacklevel=2)
239
258
  else:
240
259
  import_module(backend)
241
260
  devices.append(pytest.param(torch.device(device), id=device))
@@ -250,7 +269,7 @@ def withMETIS(func: Callable) -> Callable:
250
269
  with_metis = WITH_METIS
251
270
 
252
271
  if with_metis:
253
- try: # Test that METIS can succesfully execute:
272
+ try: # Test that METIS can successfully execute:
254
273
  # TODO Using `pyg-lib` metis partitioning leads to some weird bugs
255
274
  # in the # CI. As such, we require `torch-sparse` for now.
256
275
  rowptr = torch.tensor([0, 2, 4, 6])
@@ -265,6 +284,17 @@ def withMETIS(func: Callable) -> Callable:
265
284
  )(func)
266
285
 
267
286
 
287
+ def withHashTensor(func: Callable) -> Callable:
288
+ r"""A decorator to only test in case :class:`HashTensor` is available."""
289
+ import pytest
290
+
291
+ return pytest.mark.skipif(
292
+ not torch_geometric.typing.WITH_CPU_HASH_MAP
293
+ and not has_package('pandas'),
294
+ reason="HashTensor dependencies not available",
295
+ )(func)
296
+
297
+
268
298
  def disableExtensions(func: Callable) -> Callable:
269
299
  r"""A decorator to temporarily disable the usage of the
270
300
  :obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib` extension
@@ -73,7 +73,7 @@ def assert_run_mproc(
73
73
  ]
74
74
  results = []
75
75
 
76
- for p, q in zip(procs, queues):
76
+ for p, _ in zip(procs, queues):
77
77
  p.start()
78
78
 
79
79
  for p, q in zip(procs, queues):
@@ -20,6 +20,7 @@ from .target_indegree import TargetIndegree
20
20
  from .local_degree_profile import LocalDegreeProfile
21
21
  from .add_self_loops import AddSelfLoops
22
22
  from .add_remaining_self_loops import AddRemainingSelfLoops
23
+ from .remove_self_loops import RemoveSelfLoops
23
24
  from .remove_isolated_nodes import RemoveIsolatedNodes
24
25
  from .remove_duplicated_edges import RemoveDuplicatedEdges
25
26
  from .knn_graph import KNNGraph
@@ -36,6 +37,7 @@ from .rooted_subgraph import RootedEgoNets, RootedRWSubgraph
36
37
  from .largest_connected_components import LargestConnectedComponents
37
38
  from .virtual_node import VirtualNode
38
39
  from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE
40
+ from .add_gpse import AddGPSE
39
41
  from .feature_propagation import FeaturePropagation
40
42
  from .half_hop import HalfHop
41
43
 
@@ -87,6 +89,7 @@ graph_transforms = [
87
89
  'LocalDegreeProfile',
88
90
  'AddSelfLoops',
89
91
  'AddRemainingSelfLoops',
92
+ 'RemoveSelfLoops',
90
93
  'RemoveIsolatedNodes',
91
94
  'RemoveDuplicatedEdges',
92
95
  'KNNGraph',
@@ -106,6 +109,7 @@ graph_transforms = [
106
109
  'VirtualNode',
107
110
  'AddLaplacianEigenvectorPE',
108
111
  'AddRandomWalkPE',
112
+ 'AddGPSE',
109
113
  'FeaturePropagation',
110
114
  'HalfHop',
111
115
  ]
@@ -0,0 +1,49 @@
1
+ from typing import Any
2
+
3
+ from torch.nn import Module
4
+
5
+ from torch_geometric.data import Data
6
+ from torch_geometric.data.datapipes import functional_transform
7
+ from torch_geometric.transforms import BaseTransform, VirtualNode
8
+
9
+
10
+ @functional_transform('add_gpse')
11
+ class AddGPSE(BaseTransform):
12
+ r"""Adds the GPSE encoding from the `"Graph Positional and Structural
13
+ Encoder" <https://arxiv.org/abs/2307.07107>`_ paper to the given graph
14
+ (functional name: :obj:`add_gpse`).
15
+ To be used with a :class:`~torch_geometric.nn.GPSE` model, which generates
16
+ the actual encodings.
17
+
18
+ Args:
19
+ model (Module): The pre-trained GPSE model.
20
+ use_vn (bool, optional): Whether to use virtual nodes.
21
+ (default: :obj:`True`)
22
+ rand_type (str, optional): Type of random features to use. Options are
23
+ :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
24
+ (default: :obj:`NormalSE`)
25
+
26
+ """
27
+ def __init__(
28
+ self,
29
+ model: Module,
30
+ use_vn: bool = True,
31
+ rand_type: str = 'NormalSE',
32
+ ):
33
+ self.model = model
34
+ self.use_vn = use_vn
35
+ self.vn = VirtualNode()
36
+ self.rand_type = rand_type
37
+
38
+ def forward(self, data: Data) -> Any:
39
+ pass
40
+
41
+ def __call__(self, data: Data) -> Data:
42
+ from torch_geometric.nn.models.gpse import gpse_process
43
+
44
+ data_vn = self.vn(data.clone()) if self.use_vn else data.clone()
45
+ batch_out = gpse_process(self.model, data_vn, 'NormalSE', self.use_vn)
46
+ batch_out = batch_out.to('cpu', non_blocking=True)
47
+ data.pestat_GPSE = batch_out[:-1] if self.use_vn else batch_out
48
+
49
+ return data
@@ -37,7 +37,7 @@ class AddMetaPaths(BaseTransform):
37
37
  :class:`~torch_geometric.data.HeteroData` object as edge type
38
38
  :obj:`(src_node_type, "metapath_*", dst_node_type)`, where
39
39
  :obj:`src_node_type` and :obj:`dst_node_type` denote :math:`\mathcal{V}_1`
40
- and :math:`\mathcal{V}_{\ell}`, repectively.
40
+ and :math:`\mathcal{V}_{\ell}`, respectively.
41
41
 
42
42
  In addition, a :obj:`metapath_dict` object is added to the
43
43
  :class:`~torch_geometric.data.HeteroData` object which maps the
@@ -108,13 +108,15 @@ class AddMetaPaths(BaseTransform):
108
108
  **kwargs: bool,
109
109
  ) -> None:
110
110
  if 'drop_orig_edges' in kwargs:
111
- warnings.warn("'drop_orig_edges' is dprecated. Use "
112
- "'drop_orig_edge_types' instead")
111
+ warnings.warn(
112
+ "'drop_orig_edges' is deprecated. Use "
113
+ "'drop_orig_edge_types' instead", stacklevel=2)
113
114
  drop_orig_edge_types = kwargs['drop_orig_edges']
114
115
 
115
116
  if 'drop_unconnected_nodes' in kwargs:
116
- warnings.warn("'drop_unconnected_nodes' is dprecated. Use "
117
- "'drop_unconnected_node_types' instead")
117
+ warnings.warn(
118
+ "'drop_unconnected_nodes' is deprecated. Use "
119
+ "'drop_unconnected_node_types' instead", stacklevel=2)
118
120
  drop_unconnected_node_types = kwargs['drop_unconnected_nodes']
119
121
 
120
122
  for path in metapaths:
@@ -144,7 +146,7 @@ class AddMetaPaths(BaseTransform):
144
146
  if self.max_sample is not None:
145
147
  edge_index, edge_weight = self._sample(edge_index, edge_weight)
146
148
 
147
- for i, edge_type in enumerate(metapath[1:]):
149
+ for edge_type in metapath[1:]:
148
150
  edge_index2, edge_weight2 = self._edge_index(data, edge_type)
149
151
 
150
152
  edge_index, edge_weight = edge_index.matmul(
@@ -231,7 +233,7 @@ class AddRandomMetaPaths(BaseTransform):
231
233
  will drop node types not connected by any edge type.
232
234
  (default: :obj:`False`)
233
235
  walks_per_node (int, List[int], optional): The number of random walks
234
- for each starting node in a metapth. (default: :obj:`1`)
236
+ for each starting node in a metapath. (default: :obj:`1`)
235
237
  sample_ratio (float, optional): The ratio of source nodes to start
236
238
  random walks from. (default: :obj:`1.0`)
237
239
  """
@@ -276,7 +278,7 @@ class AddRandomMetaPaths(BaseTransform):
276
278
  row = start = torch.randperm(num_nodes)[:num_starts].repeat(
277
279
  self.walks_per_node[j])
278
280
 
279
- for i, edge_type in enumerate(metapath):
281
+ for edge_type in metapath:
280
282
  edge_index = EdgeIndex(
281
283
  data[edge_type].edge_index,
282
284
  sparse_size=data[edge_type].size(),
@@ -92,12 +92,12 @@ class AddLaplacianEigenvectorPE(BaseTransform):
92
92
  from numpy.linalg import eig, eigh
93
93
  eig_fn = eig if not self.is_undirected else eigh
94
94
 
95
- eig_vals, eig_vecs = eig_fn(L.todense()) # type: ignore
95
+ eig_vals, eig_vecs = eig_fn(L.todense())
96
96
  else:
97
97
  from scipy.sparse.linalg import eigs, eigsh
98
98
  eig_fn = eigs if not self.is_undirected else eigsh
99
99
 
100
- eig_vals, eig_vecs = eig_fn( # type: ignore
100
+ eig_vals, eig_vecs = eig_fn(
101
101
  L,
102
102
  k=self.k + 1,
103
103
  which='SR' if not self.is_undirected else 'SA',
@@ -1,5 +1,5 @@
1
1
  import copy
2
- from abc import ABC
2
+ from abc import ABC, abstractmethod
3
3
  from typing import Any
4
4
 
5
5
 
@@ -31,6 +31,7 @@ class BaseTransform(ABC):
31
31
  # Shallow-copy the data so that we prevent in-place data modification.
32
32
  return self.forward(copy.copy(data))
33
33
 
34
+ @abstractmethod
34
35
  def forward(self, data: Any) -> Any:
35
36
  pass
36
37
 
@@ -1,3 +1,5 @@
1
+ from typing import List
2
+
1
3
  import torch
2
4
 
3
5
  from torch_geometric.data import Data
@@ -5,30 +7,78 @@ from torch_geometric.data.datapipes import functional_transform
5
7
  from torch_geometric.transforms import BaseTransform
6
8
 
7
9
 
10
+ class _QhullTransform(BaseTransform):
11
+ r"""Q-hull implementation of delaunay triangulation."""
12
+ def forward(self, data: Data) -> Data:
13
+ assert data.pos is not None
14
+ import scipy.spatial
15
+
16
+ pos = data.pos.cpu().numpy()
17
+ tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
18
+ face = torch.from_numpy(tri.simplices)
19
+
20
+ data.face = face.t().contiguous().to(data.pos.device, torch.long)
21
+ return data
22
+
23
+
24
+ class _ShullTransform(BaseTransform):
25
+ r"""Sweep-hull implementation of delaunay triangulation."""
26
+ def forward(self, data: Data) -> Data:
27
+ assert data.pos is not None
28
+ from torch_delaunay.functional import shull2d
29
+
30
+ face = shull2d(data.pos.cpu())
31
+ data.face = face.t().contiguous().to(data.pos.device)
32
+ return data
33
+
34
+
35
+ class _SequentialTransform(BaseTransform):
36
+ r"""Runs the first successful transformation.
37
+
38
+ All intermediate exceptions are suppressed except the last.
39
+ """
40
+ def __init__(self, transforms: List[BaseTransform]) -> None:
41
+ assert len(transforms) > 0
42
+ self.transforms = transforms
43
+
44
+ def forward(self, data: Data) -> Data:
45
+ for i, transform in enumerate(self.transforms):
46
+ try:
47
+ return transform.forward(data)
48
+ except ImportError as e:
49
+ if i == len(self.transforms) - 1:
50
+ raise e
51
+ return data
52
+
53
+
8
54
  @functional_transform('delaunay')
9
55
  class Delaunay(BaseTransform):
10
56
  r"""Computes the delaunay triangulation of a set of points
11
57
  (functional name: :obj:`delaunay`).
58
+
59
+ .. hint::
60
+ Consider installing the
61
+ `torch_delaunay <https://github.com/ybubnov/torch_delaunay>`_ package
62
+ to speed up computation.
12
63
  """
13
- def forward(self, data: Data) -> Data:
14
- import scipy.spatial
64
+ def __init__(self) -> None:
65
+ self._transform = _SequentialTransform([
66
+ _ShullTransform(),
67
+ _QhullTransform(),
68
+ ])
15
69
 
70
+ def forward(self, data: Data) -> Data:
16
71
  assert data.pos is not None
72
+ device = data.pos.device
17
73
 
18
74
  if data.pos.size(0) < 2:
19
- data.edge_index = torch.tensor([], dtype=torch.long,
20
- device=data.pos.device).view(2, 0)
21
- if data.pos.size(0) == 2:
22
- data.edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long,
23
- device=data.pos.device)
75
+ data.edge_index = torch.empty(2, 0, dtype=torch.long,
76
+ device=device)
77
+ elif data.pos.size(0) == 2:
78
+ data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device)
24
79
  elif data.pos.size(0) == 3:
25
- data.face = torch.tensor([[0], [1], [2]], dtype=torch.long,
26
- device=data.pos.device)
27
- if data.pos.size(0) > 3:
28
- pos = data.pos.cpu().numpy()
29
- tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
30
- face = torch.from_numpy(tri.simplices)
31
-
32
- data.face = face.t().contiguous().to(data.pos.device, torch.long)
80
+ data.face = torch.tensor([[0], [1], [2]], device=device)
81
+ else:
82
+ data = self._transform.forward(data)
33
83
 
34
84
  return data
@@ -8,8 +8,15 @@ from torch_geometric.utils import to_undirected
8
8
 
9
9
  @functional_transform('face_to_edge')
10
10
  class FaceToEdge(BaseTransform):
11
- r"""Converts mesh faces :obj:`[3, num_faces]` to edge indices
12
- :obj:`[2, num_edges]` (functional name: :obj:`face_to_edge`).
11
+ r"""Converts mesh faces of shape :obj:`[3, num_faces]` or
12
+ :obj:`[4, num_faces]` to edge indices of shape :obj:`[2, num_edges]`
13
+ (functional name: :obj:`face_to_edge`).
14
+
15
+ This transform supports both 2D triangular faces, represented by a
16
+ tensor of shape :obj:`[3, num_faces]`, and 3D tetrahedral mesh faces,
17
+ represented by a tensor of shape :obj:`[4, num_faces]`. It will convert
18
+ these faces into edge indices, where each edge is defined by the indices
19
+ of its two endpoints.
13
20
 
14
21
  Args:
15
22
  remove_faces (bool, optional): If set to :obj:`False`, the face tensor
@@ -22,7 +29,29 @@ class FaceToEdge(BaseTransform):
22
29
  if hasattr(data, 'face'):
23
30
  assert data.face is not None
24
31
  face = data.face
25
- edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
32
+
33
+ if face.size(0) not in [3, 4]:
34
+ raise RuntimeError(f"Expected 'face' tensor with shape "
35
+ f"[3, num_faces] or [4, num_faces] "
36
+ f"(got {list(face.size())})")
37
+
38
+ if face.size()[0] == 3:
39
+ edge_index = torch.cat([
40
+ face[:2],
41
+ face[1:],
42
+ face[::2],
43
+ ], dim=1)
44
+ else:
45
+ assert face.size()[0] == 4
46
+ edge_index = torch.cat([
47
+ face[:2],
48
+ face[1:3],
49
+ face[2:4],
50
+ face[::2],
51
+ face[1::2],
52
+ face[::3],
53
+ ], dim=1)
54
+
26
55
  edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
27
56
 
28
57
  data.edge_index = edge_index
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Tuple
1
+ from typing import Any, Dict, Optional, Tuple
2
2
 
3
3
  import numpy as np
4
4
  import torch
@@ -21,7 +21,7 @@ from torch_geometric.utils import (
21
21
  @functional_transform('gdc')
22
22
  class GDC(BaseTransform):
23
23
  r"""Processes the graph via Graph Diffusion Convolution (GDC) from the
24
- `"Diffusion Improves Graph Learning" <https://www.kdd.in.tum.de/gdc>`_
24
+ `"Diffusion Improves Graph Learning" <https://arxiv.org/abs/1911.05485>`_
25
25
  paper (functional name: :obj:`gdc`).
26
26
 
27
27
  .. note::
@@ -78,18 +78,17 @@ class GDC(BaseTransform):
78
78
  self_loop_weight: float = 1.,
79
79
  normalization_in: str = 'sym',
80
80
  normalization_out: str = 'col',
81
- diffusion_kwargs: Dict[str, Any] = dict(method='ppr', alpha=0.15),
82
- sparsification_kwargs: Dict[str, Any] = dict(
83
- method='threshold',
84
- avg_degree=64,
85
- ),
81
+ diffusion_kwargs: Optional[Dict[str, Any]] = None,
82
+ sparsification_kwargs: Optional[Dict[str, Any]] = None,
86
83
  exact: bool = True,
87
84
  ) -> None:
88
85
  self.self_loop_weight = self_loop_weight
89
86
  self.normalization_in = normalization_in
90
87
  self.normalization_out = normalization_out
91
- self.diffusion_kwargs = diffusion_kwargs
92
- self.sparsification_kwargs = sparsification_kwargs
88
+ self.diffusion_kwargs = diffusion_kwargs or dict(
89
+ method='ppr', alpha=0.15)
90
+ self.sparsification_kwargs = sparsification_kwargs or dict(
91
+ method='threshold', avg_degree=64)
93
92
  self.exact = exact
94
93
 
95
94
  if self_loop_weight:
@@ -47,7 +47,7 @@ class LargestConnectedComponents(BaseTransform):
47
47
  return data
48
48
 
49
49
  _, count = np.unique(component, return_counts=True)
50
- subset_np = np.in1d(component, count.argsort()[-self.num_components:])
50
+ subset_np = np.isin(component, count.argsort()[-self.num_components:])
51
51
  subset = torch.from_numpy(subset_np)
52
52
  subset = subset.to(data.edge_index.device, torch.bool)
53
53
 
@@ -19,7 +19,11 @@ def get_attrs_with_suffix(
19
19
  return [key for key in store.keys() if key.endswith(suffix)]
20
20
 
21
21
 
22
- def get_mask_size(attr: str, store: BaseStorage, size: Optional[int]) -> int:
22
+ def get_mask_size(
23
+ attr: str,
24
+ store: BaseStorage,
25
+ size: Optional[int],
26
+ ) -> Optional[int]:
23
27
  if size is not None:
24
28
  return size
25
29
  return store.num_edges if store.is_edge_attr(attr) else store.num_nodes
@@ -53,7 +53,7 @@ class NodePropertySplit(BaseTransform):
53
53
 
54
54
  property_name = 'popularity'
55
55
  ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
56
- tranaform = NodePropertySplit(property_name, ratios)
56
+ transform = NodePropertySplit(property_name, ratios)
57
57
 
58
58
  data = transform(data)
59
59
  """
@@ -1,4 +1,4 @@
1
- from typing import List, Union
1
+ from typing import List, Optional, Union
2
2
 
3
3
  from torch_geometric.data import Data, HeteroData
4
4
  from torch_geometric.data.datapipes import functional_transform
@@ -14,8 +14,8 @@ class NormalizeFeatures(BaseTransform):
14
14
  attrs (List[str]): The names of attributes to normalize.
15
15
  (default: :obj:`["x"]`)
16
16
  """
17
- def __init__(self, attrs: List[str] = ["x"]):
18
- self.attrs = attrs
17
+ def __init__(self, attrs: Optional[List[str]] = None) -> None:
18
+ self.attrs = attrs or ["x"]
19
19
 
20
20
  def forward(
21
21
  self,
@@ -262,7 +262,7 @@ class Pad(BaseTransform):
262
262
  All the attributes of node types other than :obj:`v0` and :obj:`v1` are
263
263
  padded using a value of :obj:`1.0`.
264
264
  All the attributes of the :obj:`('v0', 'e0', 'v1')` edge type are padded
265
- usin a value of :obj:`3.5`.
265
+ using a value of :obj:`3.5`.
266
266
  The :obj:`edge_attr` attributes of the :obj:`('v1', 'e0', 'v0')` edge type
267
267
  are padded using a value of :obj:`-1.5`, and any other attributes of this
268
268
  edge type are padded using a value of :obj:`5.5`.
@@ -245,7 +245,7 @@ class RandomLinkSplit(BaseTransform):
245
245
  warnings.warn(
246
246
  f"There are not enough negative edges to satisfy "
247
247
  "the provided sampling ratio. The ratio will be "
248
- f"adjusted to {ratio:.2f}.")
248
+ f"adjusted to {ratio:.2f}.", stacklevel=2)
249
249
  num_neg_train = int((num_neg_train / num_neg) * num_neg_found)
250
250
  num_neg_val = int((num_neg_val / num_neg) * num_neg_found)
251
251
  num_neg_test = num_neg_found - num_neg_train - num_neg_val
@@ -1,4 +1,4 @@
1
- from typing import List, Union
1
+ from typing import List, Optional, Union
2
2
 
3
3
  from torch_geometric.data import Data, HeteroData
4
4
  from torch_geometric.data.datapipes import functional_transform
@@ -22,9 +22,11 @@ class RemoveDuplicatedEdges(BaseTransform):
22
22
  """
23
23
  def __init__(
24
24
  self,
25
- key: Union[str, List[str]] = ['edge_attr', 'edge_weight'],
25
+ key: Optional[Union[str, List[str]]] = None,
26
26
  reduce: str = "add",
27
27
  ) -> None:
28
+ key = key or ['edge_attr', 'edge_weight']
29
+
28
30
  if isinstance(key, str):
29
31
  key = [key]
30
32
 
@@ -0,0 +1,36 @@
1
+ from typing import Union
2
+
3
+ from torch_geometric.data import Data, HeteroData
4
+ from torch_geometric.data.datapipes import functional_transform
5
+ from torch_geometric.transforms import BaseTransform
6
+ from torch_geometric.utils import remove_self_loops
7
+
8
+
9
+ @functional_transform('remove_self_loops')
10
+ class RemoveSelfLoops(BaseTransform):
11
+ r"""Removes all self-loops in the given homogeneous or heterogeneous
12
+ graph (functional name: :obj:`remove_self_loops`).
13
+
14
+ Args:
15
+ attr (str, optional): The name of the attribute of edge weights
16
+ or multi-dimensional edge features to pass to
17
+ :meth:`torch_geometric.utils.remove_self_loops`.
18
+ (default: :obj:`"edge_weight"`)
19
+ """
20
+ def __init__(self, attr: str = 'edge_weight') -> None:
21
+ self.attr = attr
22
+
23
+ def forward(
24
+ self,
25
+ data: Union[Data, HeteroData],
26
+ ) -> Union[Data, HeteroData]:
27
+ for store in data.edge_stores:
28
+ if store.is_bipartite() or 'edge_index' not in store:
29
+ continue
30
+
31
+ store.edge_index, store[self.attr] = remove_self_loops(
32
+ store.edge_index,
33
+ edge_attr=store.get(self.attr, None),
34
+ )
35
+
36
+ return data
@@ -94,7 +94,7 @@ class RootedSubgraph(BaseTransform, ABC):
94
94
  arange = torch.arange(n_id.size(0), device=data.edge_index.device)
95
95
  node_map = data.edge_index.new_ones(num_nodes, num_nodes)
96
96
  node_map[n_sub_batch, n_id] = arange
97
- sub_edge_index += (arange * data.num_nodes)[e_sub_batch]
97
+ sub_edge_index += (arange * num_nodes)[e_sub_batch]
98
98
  sub_edge_index = node_map.view(-1)[sub_edge_index]
99
99
 
100
100
  return sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch
@@ -11,7 +11,7 @@ class SVDFeatureReduction(BaseTransform):
11
11
  Decomposition (SVD) (functional name: :obj:`svd_feature_reduction`).
12
12
 
13
13
  Args:
14
- out_channels (int): The dimensionlity of node features after
14
+ out_channels (int): The dimensionality of node features after
15
15
  reduction.
16
16
  """
17
17
  def __init__(self, out_channels: int):
@@ -37,7 +37,8 @@ class VirtualNode(BaseTransform):
37
37
  col = torch.cat([col, full, arange], dim=0)
38
38
  edge_index = torch.stack([row, col], dim=0)
39
39
 
40
- new_type = edge_type.new_full((num_nodes, ), int(edge_type.max()) + 1)
40
+ num_edge_types = int(edge_type.max()) if edge_type.numel() > 0 else 0
41
+ new_type = edge_type.new_full((num_nodes, ), num_edge_types + 1)
41
42
  edge_type = torch.cat([edge_type, new_type, new_type + 1], dim=0)
42
43
 
43
44
  old_data = copy.copy(data)