pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic; see the registry's advisory page for this release for more details.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/typing.py CHANGED
@@ -3,7 +3,7 @@ import os
3
3
  import sys
4
4
  import typing
5
5
  import warnings
6
- from typing import Any, Dict, List, Optional, Set, Tuple, Union
6
+ from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union
7
7
 
8
8
  import numpy as np
9
9
  import torch
@@ -14,8 +14,10 @@ WITH_PT21 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 1
14
14
  WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2
15
15
  WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3
16
16
  WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
17
- WITH_PT111 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 11
18
- WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12
17
+ WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5
18
+ WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6
19
+ WITH_PT27 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 7
20
+ WITH_PT28 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 8
19
21
  WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13
20
22
 
21
23
  WITH_WINDOWS = os.name == 'nt'
@@ -62,10 +64,21 @@ try:
62
64
  pyg_lib.sampler.neighbor_sample).parameters)
63
65
  WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature(
64
66
  pyg_lib.sampler.neighbor_sample).parameters)
67
+ try:
68
+ torch.classes.pyg.CPUHashMap # noqa: B018
69
+ WITH_CPU_HASH_MAP = True
70
+ except Exception:
71
+ WITH_CPU_HASH_MAP = False
72
+ try:
73
+ torch.classes.pyg.CUDAHashMap # noqa: B018
74
+ WITH_CUDA_HASH_MAP = True
75
+ except Exception:
76
+ WITH_CUDA_HASH_MAP = False
65
77
  except Exception as e:
66
78
  if not isinstance(e, ImportError): # pragma: no cover
67
- warnings.warn(f"An issue occurred while importing 'pyg-lib'. "
68
- f"Disabling its usage. Stacktrace: {e}")
79
+ warnings.warn(
80
+ f"An issue occurred while importing 'pyg-lib'. "
81
+ f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
69
82
  pyg_lib = object
70
83
  WITH_PYG_LIB = False
71
84
  WITH_GMM = False
@@ -76,14 +89,41 @@ except Exception as e:
76
89
  WITH_METIS = False
77
90
  WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False
78
91
  WITH_WEIGHTED_NEIGHBOR_SAMPLE = False
92
+ WITH_CPU_HASH_MAP = False
93
+ WITH_CUDA_HASH_MAP = False
94
+
95
if not WITH_CPU_HASH_MAP:

    class CPUHashMap:  # type: ignore
        r"""Stand-in used when :obj:`torch.classes.pyg.CPUHashMap` is not
        available; constructing it raises an :class:`ImportError`.
        """
        def __init__(self, key: Tensor) -> None:
            raise ImportError("'CPUHashMap' requires 'pyg-lib'")

        def get(self, query: Tensor) -> Tensor:
            raise ImportError("'CPUHashMap' requires 'pyg-lib'")

else:
    # Expose the compiled `pyg-lib` hash map under the same public name:
    CPUHashMap: TypeAlias = torch.classes.pyg.CPUHashMap  # type: ignore  # noqa: E501
107
if not WITH_CUDA_HASH_MAP:

    class CUDAHashMap:  # type: ignore
        r"""Stand-in used when :obj:`torch.classes.pyg.CUDAHashMap` is not
        available; constructing it raises an :class:`ImportError`.
        """
        def __init__(self, key: Tensor) -> None:
            raise ImportError("'CUDAHashMap' requires 'pyg-lib'")

        def get(self, query: Tensor) -> Tensor:
            raise ImportError("'CUDAHashMap' requires 'pyg-lib'")

else:
    # Expose the compiled `pyg-lib` hash map under the same public name:
    CUDAHashMap: TypeAlias = torch.classes.pyg.CUDAHashMap  # type: ignore  # noqa: E501
79
118
 
80
119
  try:
81
120
  import torch_scatter # noqa
82
121
  WITH_TORCH_SCATTER = True
83
122
  except Exception as e:
84
123
  if not isinstance(e, ImportError): # pragma: no cover
85
- warnings.warn(f"An issue occurred while importing 'torch-scatter'. "
86
- f"Disabling its usage. Stacktrace: {e}")
124
+ warnings.warn(
125
+ f"An issue occurred while importing 'torch-scatter'. "
126
+ f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
87
127
  torch_scatter = object
88
128
  WITH_TORCH_SCATTER = False
89
129
 
@@ -93,8 +133,9 @@ try:
93
133
  WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in torch_cluster.knn.__doc__
94
134
  except Exception as e:
95
135
  if not isinstance(e, ImportError): # pragma: no cover
96
- warnings.warn(f"An issue occurred while importing 'torch-cluster'. "
97
- f"Disabling its usage. Stacktrace: {e}")
136
+ warnings.warn(
137
+ f"An issue occurred while importing 'torch-cluster'. "
138
+ f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
98
139
  WITH_TORCH_CLUSTER = False
99
140
  WITH_TORCH_CLUSTER_BATCH_SIZE = False
100
141
 
@@ -111,7 +152,7 @@ except Exception as e:
111
152
  if not isinstance(e, ImportError): # pragma: no cover
112
153
  warnings.warn(
113
154
  f"An issue occurred while importing 'torch-spline-conv'. "
114
- f"Disabling its usage. Stacktrace: {e}")
155
+ f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
115
156
  WITH_TORCH_SPLINE_CONV = False
116
157
 
117
158
  try:
@@ -120,8 +161,9 @@ try:
120
161
  WITH_TORCH_SPARSE = True
121
162
  except Exception as e:
122
163
  if not isinstance(e, ImportError): # pragma: no cover
123
- warnings.warn(f"An issue occurred while importing 'torch-sparse'. "
124
- f"Disabling its usage. Stacktrace: {e}")
164
+ warnings.warn(
165
+ f"An issue occurred while importing 'torch-sparse'. "
166
+ f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
125
167
  WITH_TORCH_SPARSE = False
126
168
 
127
169
  class SparseStorage: # type: ignore
@@ -305,6 +347,8 @@ class EdgeTypeStr(str):
305
347
  r"""A helper class to construct serializable edge types by merging an edge
306
348
  type tuple into a single string.
307
349
  """
350
+ edge_type: tuple[str, str, str]
351
+
308
352
  def __new__(cls, *args: Any) -> 'EdgeTypeStr':
309
353
  if isinstance(args[0], (list, tuple)):
310
354
  # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`:
@@ -312,27 +356,37 @@ class EdgeTypeStr(str):
312
356
 
313
357
  if len(args) == 1 and isinstance(args[0], str):
314
358
  arg = args[0] # An edge type string was passed.
359
+ edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT))
360
+ if len(edge_type) != 3:
361
+ raise ValueError(f"Cannot convert the edge type '{arg}' to a "
362
+ f"tuple since it holds invalid characters")
315
363
 
316
364
  elif len(args) == 2 and all(isinstance(arg, str) for arg in args):
317
365
  # A `(src, dst)` edge type was passed - add `DEFAULT_REL`:
318
- arg = EDGE_TYPE_STR_SPLIT.join((args[0], DEFAULT_REL, args[1]))
366
+ edge_type = (args[0], DEFAULT_REL, args[1])
367
+ arg = EDGE_TYPE_STR_SPLIT.join(edge_type)
319
368
 
320
369
  elif len(args) == 3 and all(isinstance(arg, str) for arg in args):
321
370
  # A `(src, rel, dst)` edge type was passed:
371
+ edge_type = tuple(args)
322
372
  arg = EDGE_TYPE_STR_SPLIT.join(args)
323
373
 
324
374
  else:
325
375
  raise ValueError(f"Encountered invalid edge type '{args}'")
326
376
 
327
- return str.__new__(cls, arg)
377
+ out = str.__new__(cls, arg)
378
+ out.edge_type = edge_type # type: ignore
379
+ return out
328
380
 
329
381
  def to_tuple(self) -> EdgeType:
330
382
  r"""Returns the original edge type."""
331
- out = tuple(self.split(EDGE_TYPE_STR_SPLIT))
332
- if len(out) != 3:
383
+ if len(self.edge_type) != 3:
333
384
  raise ValueError(f"Cannot convert the edge type '{self}' to a "
334
385
  f"tuple since it holds invalid characters")
335
- return out
386
+ return self.edge_type
387
+
388
+ def __reduce__(self) -> tuple[Any, Any]:
389
+ return (self.__class__, (self.edge_type, ))
336
390
 
337
391
 
338
392
  # There exist some short-cuts to query edge-types (given that the full triplet
@@ -370,3 +424,14 @@ MaybeHeteroEdgeTensor = Union[Tensor, Dict[EdgeType, Tensor]]
370
424
 
371
425
  InputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor]]
372
426
  InputEdges = Union[OptTensor, EdgeType, Tuple[EdgeType, OptTensor]]
427
+
428
# Serialization ###############################################################

if WITH_PT24:
    # `torch.serialization.add_safe_globals` is only called on PyTorch >= 2.4
    # (guarded by `WITH_PT24`). Registering these PyG classes allow-lists
    # them for restricted (`weights_only`) unpickling:
    torch.serialization.add_safe_globals(
        [SparseTensor, SparseStorage, TensorFrame, MockTorchCSCTensor,
         EdgeTypeStr])
@@ -21,6 +21,7 @@ from ._subgraph import (get_num_hops, subgraph, k_hop_subgraph,
21
21
  from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path
22
22
  from ._homophily import homophily
23
23
  from ._assortativity import assortativity
24
+ from ._normalize_edge_index import normalize_edge_index
24
25
  from .laplacian import get_laplacian
25
26
  from .mesh_laplacian import get_mesh_laplacian
26
27
  from .mask import mask_select, index_to_mask, mask_to_index
@@ -52,10 +53,11 @@ from ._negative_sampling import (negative_sampling, batched_negative_sampling,
52
53
  structured_negative_sampling_feasible)
53
54
  from .augmentation import shuffle_node, mask_feature, add_random_edge
54
55
  from ._tree_decomposition import tree_decomposition
55
- from .embedding import get_embeddings
56
+ from .embedding import get_embeddings, get_embeddings_hetero
56
57
  from ._trim_to_layer import trim_to_layer
57
58
  from .ppr import get_ppr
58
59
  from ._train_test_split_edges import train_test_split_edges
60
+ from .influence import total_influence
59
61
 
60
62
  __all__ = [
61
63
  'scatter',
@@ -89,6 +91,7 @@ __all__ = [
89
91
  'dropout_adj',
90
92
  'homophily',
91
93
  'assortativity',
94
+ 'normalize_edge_index',
92
95
  'get_laplacian',
93
96
  'get_mesh_laplacian',
94
97
  'mask_select',
@@ -143,9 +146,11 @@ __all__ = [
143
146
  'add_random_edge',
144
147
  'tree_decomposition',
145
148
  'get_embeddings',
149
+ 'get_embeddings_hetero',
146
150
  'trim_to_layer',
147
151
  'get_ppr',
148
152
  'train_test_split_edges',
153
+ 'total_influence',
149
154
  ]
150
155
 
151
156
  # `structured_negative_sampling_feasible` is a long name and thus destroys the
@@ -1,11 +1,7 @@
1
1
  from typing import List
2
2
 
3
- import numpy as np
4
- import torch
5
3
  from torch import Tensor
6
4
 
7
- import torch_geometric.typing
8
-
9
5
 
10
6
  def lexsort(
11
7
  keys: List[Tensor],
@@ -28,11 +24,6 @@ def lexsort(
28
24
  """
29
25
  assert len(keys) >= 1
30
26
 
31
- if not torch_geometric.typing.WITH_PT113:
32
- keys = [k.neg() for k in keys] if descending else keys
33
- out = np.lexsort([k.detach().cpu().numpy() for k in keys], axis=dim)
34
- return torch.from_numpy(out).to(keys[0].device)
35
-
36
27
  out = keys[0].argsort(dim=dim, descending=descending, stable=True)
37
28
  for k in keys[1:]:
38
29
  index = k.gather(dim, out)
@@ -12,7 +12,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
12
12
  def negative_sampling(
13
13
  edge_index: Tensor,
14
14
  num_nodes: Optional[Union[int, Tuple[int, int]]] = None,
15
- num_neg_samples: Optional[int] = None,
15
+ num_neg_samples: Optional[Union[int, float]] = None,
16
16
  method: str = "sparse",
17
17
  force_undirected: bool = False,
18
18
  ) -> Tensor:
@@ -25,10 +25,12 @@ def negative_sampling(
25
25
  If given as a tuple, then :obj:`edge_index` is interpreted as a
26
26
  bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`.
27
27
  (default: :obj:`None`)
28
- num_neg_samples (int, optional): The (approximate) number of negative
29
- samples to return.
30
- If set to :obj:`None`, will try to return a negative edge for every
31
- positive edge. (default: :obj:`None`)
28
+ num_neg_samples (int or float, optional): The (approximate) number of
29
+ negative samples to return. If set to a floating-point value, it
30
+ represents the ratio of negative samples to generate based on the
31
+ number of positive edges. If set to :obj:`None`, will try to
32
+ return a negative edge for every positive edge.
33
+ (default: :obj:`None`)
32
34
  method (str, optional): The method to use for negative sampling,
33
35
  *i.e.* :obj:`"sparse"` or :obj:`"dense"`.
34
36
  This is a memory/runtime trade-off.
@@ -48,6 +50,11 @@ def negative_sampling(
48
50
  tensor([[3, 0, 0, 3],
49
51
  [2, 3, 2, 1]])
50
52
 
53
+ >>> negative_sampling(edge_index, num_nodes=(3, 4),
54
+ ... num_neg_samples=0.5) # 50% of positive edges
55
+ tensor([[0, 3],
56
+ [3, 0]])
57
+
51
58
  >>> # For bipartite graph
52
59
  >>> negative_sampling(edge_index, num_nodes=(3, 4))
53
60
  tensor([[0, 2, 2, 1],
@@ -74,6 +81,8 @@ def negative_sampling(
74
81
 
75
82
  if num_neg_samples is None:
76
83
  num_neg_samples = edge_index.size(1)
84
+ elif isinstance(num_neg_samples, float):
85
+ num_neg_samples = int(num_neg_samples * edge_index.size(1))
77
86
  if force_undirected:
78
87
  num_neg_samples = num_neg_samples // 2
79
88
 
@@ -100,10 +109,9 @@ def negative_sampling(
100
109
  idx = idx.to('cpu')
101
110
  for _ in range(3): # Number of tries to sample negative indices.
102
111
  rnd = sample(population, sample_size, device='cpu')
103
- mask = np.isin(rnd.numpy(), idx.numpy()) # type: ignore
112
+ mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool()
104
113
  if neg_idx is not None:
105
- mask |= np.isin(rnd, neg_idx.to('cpu'))
106
- mask = torch.from_numpy(mask).to(torch.bool)
114
+ mask |= torch.from_numpy(np.isin(rnd, neg_idx.cpu())).bool()
107
115
  rnd = rnd[~mask].to(edge_index.device)
108
116
  neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd])
109
117
  if neg_idx.numel() >= num_neg_samples:
@@ -117,7 +125,7 @@ def negative_sampling(
117
125
  def batched_negative_sampling(
118
126
  edge_index: Tensor,
119
127
  batch: Union[Tensor, Tuple[Tensor, Tensor]],
120
- num_neg_samples: Optional[int] = None,
128
+ num_neg_samples: Optional[Union[int, float]] = None,
121
129
  method: str = "sparse",
122
130
  force_undirected: bool = False,
123
131
  ) -> Tensor:
@@ -131,9 +139,11 @@ def batched_negative_sampling(
131
139
  node to a specific example.
132
140
  If given as a tuple, then :obj:`edge_index` is interpreted as a
133
141
  bipartite graph connecting two different node types.
134
- num_neg_samples (int, optional): The number of negative samples to
135
- return. If set to :obj:`None`, will try to return a negative edge
136
- for every positive edge. (default: :obj:`None`)
142
+ num_neg_samples (int or float, optional): The number of negative
143
+ samples to return. If set to :obj:`None`, will try to return a
144
+ negative edge for every positive edge. If float, it will generate
145
+ :obj:`num_neg_samples * num_edges` negative samples.
146
+ (default: :obj:`None`)
137
147
  method (str, optional): The method to use for negative sampling,
138
148
  *i.e.* :obj:`"sparse"` or :obj:`"dense"`.
139
149
  This is a memory/runtime trade-off.
@@ -157,6 +167,11 @@ def batched_negative_sampling(
157
167
  tensor([[3, 1, 3, 2, 7, 7, 6, 5],
158
168
  [2, 0, 1, 1, 5, 6, 4, 4]])
159
169
 
170
+ >>> # Using float multiplier for negative samples
171
+ >>> batched_negative_sampling(edge_index, batch, num_neg_samples=1.5)
172
+ tensor([[3, 1, 3, 2, 7, 7, 6, 5, 2, 0, 1, 1],
173
+ [2, 0, 1, 1, 5, 6, 4, 4, 3, 2, 3, 0]])
174
+
160
175
  >>> # For bipartite graph
161
176
  >>> edge_index1 = torch.as_tensor([[0, 0, 1, 1], [0, 1, 2, 3]])
162
177
  >>> edge_index2 = edge_index1 + torch.tensor([[2], [4]])
@@ -265,7 +280,7 @@ def structured_negative_sampling_feasible(
265
280
  :meth:`~torch_geometric.utils.structured_negative_sampling` is feasible
266
281
  on the graph given by :obj:`edge_index`.
267
282
  :meth:`~torch_geometric.utils.structured_negative_sampling` is infeasible
268
- if atleast one node is connected to all other nodes.
283
+ if at least one node is connected to all other nodes.
269
284
 
270
285
  Args:
271
286
  edge_index (LongTensor): The edge indices.
@@ -0,0 +1,46 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torch_geometric.utils import add_self_loops as add_self_loops_fn
7
+ from torch_geometric.utils import degree
8
+
9
+
10
+ def normalize_edge_index(
11
+ edge_index: Tensor,
12
+ num_nodes: Optional[int] = None,
13
+ add_self_loops: bool = True,
14
+ symmetric: bool = True,
15
+ ) -> Tuple[Tensor, Tensor]:
16
+ """Applies normalization to the edges of a graph.
17
+
18
+ This function can add self-loops to the graph and apply either symmetric or
19
+ asymmetric normalization based on the node degrees.
20
+
21
+ Args:
22
+ edge_index (LongTensor): The edge indices.
23
+ num_nodes (int, int], optional): The number of nodes, *i.e.*
24
+ :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
25
+ add_self_loops (bool, optional): If set to :obj:`False`, will not add
26
+ self-loops to the input graph. (default: :obj:`True`)
27
+ symmetric (bool, optional): If set to :obj:`True`, symmetric
28
+ normalization (:math:`D^{-1/2} A D^{-1/2}`) is used, otherwise
29
+ asymmetric normalization (:math:`D^{-1} A`).
30
+ """
31
+ if add_self_loops:
32
+ edge_index, _ = add_self_loops_fn(edge_index, num_nodes=num_nodes)
33
+
34
+ row, col = edge_index[0], edge_index[1]
35
+ deg = degree(row, num_nodes, dtype=torch.get_default_dtype())
36
+
37
+ if symmetric: # D^-1/2 * A * D^-1/2
38
+ deg_inv_sqrt = deg.pow(-0.5)
39
+ deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0
40
+ edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col]
41
+ else: # D^-1 * A
42
+ deg_inv = deg.pow(-1)
43
+ deg_inv[torch.isinf(deg_inv)] = 0
44
+ edge_weight = deg_inv[row]
45
+
46
+ return edge_index, edge_weight