pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -19,7 +19,11 @@ def get_attrs_with_suffix(
19
19
  return [key for key in store.keys() if key.endswith(suffix)]
20
20
 
21
21
 
22
- def get_mask_size(attr: str, store: BaseStorage, size: Optional[int]) -> int:
22
+ def get_mask_size(
23
+ attr: str,
24
+ store: BaseStorage,
25
+ size: Optional[int],
26
+ ) -> Optional[int]:
23
27
  if size is not None:
24
28
  return size
25
29
  return store.num_edges if store.is_edge_attr(attr) else store.num_nodes
@@ -44,7 +44,6 @@ class NodePropertySplit(BaseTransform):
44
44
  of the node property, so that nodes with greater values of the
45
45
  property are considered to be OOD (default: :obj:`True`)
46
46
 
47
- Example:
48
47
  .. code-block:: python
49
48
 
50
49
  from torch_geometric.transforms import NodePropertySplit
@@ -54,7 +53,7 @@ class NodePropertySplit(BaseTransform):
54
53
 
55
54
  property_name = 'popularity'
56
55
  ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
57
- tranaform = NodePropertySplit(property_name, ratios)
56
+ transform = NodePropertySplit(property_name, ratios)
58
57
 
59
58
  data = transform(data)
60
59
  """
@@ -262,15 +262,14 @@ class Pad(BaseTransform):
262
262
  All the attributes of node types other than :obj:`v0` and :obj:`v1` are
263
263
  padded using a value of :obj:`1.0`.
264
264
  All the attributes of the :obj:`('v0', 'e0', 'v1')` edge type are padded
265
- usin a value of :obj:`3.5`.
265
+ using a value of :obj:`3.5`.
266
266
  The :obj:`edge_attr` attributes of the :obj:`('v1', 'e0', 'v0')` edge type
267
267
  are padded using a value of :obj:`-1.5`, and any other attributes of this
268
268
  edge type are padded using a value of :obj:`5.5`.
269
269
  All the attributes of edge types other than these two are padded using a
270
270
  value of :obj:`1.5`.
271
271
 
272
- Example:
273
- .. code-block::
272
+ .. code-block:: python
274
273
 
275
274
  num_nodes = {'v0': 10, 'v1': 20, 'v2':30}
276
275
  num_edges = {('v0', 'e0', 'v1'): 80}
@@ -467,9 +466,11 @@ class Pad(BaseTransform):
467
466
  edge_type: Optional[EdgeType] = None,
468
467
  ) -> None:
469
468
 
470
- attrs_to_pad = set(
471
- attr for attr in store.keys()
472
- if store.is_edge_attr(attr) and self.__should_pad_edge_attr(attr))
469
+ attrs_to_pad = {
470
+ attr
471
+ for attr in store.keys()
472
+ if store.is_edge_attr(attr) and self.__should_pad_edge_attr(attr)
473
+ }
473
474
  if not attrs_to_pad:
474
475
  return
475
476
  num_target_edges = self.max_num_edges.get_value(edge_type)
@@ -23,7 +23,7 @@ class RandomLinkSplit(BaseTransform):
23
23
  in validation and test splits; and the validation split does not include
24
24
  edges in the test split.
25
25
 
26
- .. code-block::
26
+ .. code-block:: python
27
27
 
28
28
  from torch_geometric.transforms import RandomLinkSplit
29
29
 
@@ -0,0 +1,36 @@
1
+ from typing import Union
2
+
3
+ from torch_geometric.data import Data, HeteroData
4
+ from torch_geometric.data.datapipes import functional_transform
5
+ from torch_geometric.transforms import BaseTransform
6
+ from torch_geometric.utils import remove_self_loops
7
+
8
+
9
+ @functional_transform('remove_self_loops')
10
+ class RemoveSelfLoops(BaseTransform):
11
+ r"""Removes all self-loops in the given homogeneous or heterogeneous
12
+ graph (functional name: :obj:`remove_self_loops`).
13
+
14
+ Args:
15
+ attr (str, optional): The name of the attribute of edge weights
16
+ or multi-dimensional edge features to pass to
17
+ :meth:`torch_geometric.utils.remove_self_loops`.
18
+ (default: :obj:`"edge_weight"`)
19
+ """
20
+ def __init__(self, attr: str = 'edge_weight') -> None:
21
+ self.attr = attr
22
+
23
+ def forward(
24
+ self,
25
+ data: Union[Data, HeteroData],
26
+ ) -> Union[Data, HeteroData]:
27
+ for store in data.edge_stores:
28
+ if store.is_bipartite() or 'edge_index' not in store:
29
+ continue
30
+
31
+ store.edge_index, store[self.attr] = remove_self_loops(
32
+ store.edge_index,
33
+ edge_attr=store.get(self.attr, None),
34
+ )
35
+
36
+ return data
@@ -11,7 +11,7 @@ class SVDFeatureReduction(BaseTransform):
11
11
  Decomposition (SVD) (functional name: :obj:`svd_feature_reduction`).
12
12
 
13
13
  Args:
14
- out_channels (int): The dimensionlity of node features after
14
+ out_channels (int): The dimensionality of node features after
15
15
  reduction.
16
16
  """
17
17
  def __init__(self, out_channels: int):
@@ -79,7 +79,7 @@ class ToSparseTensor(BaseTransform):
79
79
 
80
80
  keys, values = [], []
81
81
  for key, value in store.items():
82
- if key == 'edge_index':
82
+ if key in {'edge_index', 'edge_label', 'edge_label_index'}:
83
83
  continue
84
84
 
85
85
  if store.is_edge_attr(key):
@@ -19,7 +19,7 @@ class TwoHop(BaseTransform):
19
19
 
20
20
  edge_index = EdgeIndex(edge_index, sparse_size=(N, N))
21
21
  edge_index = edge_index.sort_by('row')[0]
22
- edge_index2, _ = edge_index @ edge_index
22
+ edge_index2 = edge_index.matmul(edge_index)[0].as_tensor()
23
23
  edge_index2, _ = remove_self_loops(edge_index2)
24
24
  edge_index = torch.cat([edge_index, edge_index2], dim=1)
25
25
 
@@ -37,7 +37,8 @@ class VirtualNode(BaseTransform):
37
37
  col = torch.cat([col, full, arange], dim=0)
38
38
  edge_index = torch.stack([row, col], dim=0)
39
39
 
40
- new_type = edge_type.new_full((num_nodes, ), int(edge_type.max()) + 1)
40
+ num_edge_types = int(edge_type.max()) if edge_type.numel() > 0 else 0
41
+ new_type = edge_type.new_full((num_nodes, ), num_edge_types + 1)
41
42
  edge_type = torch.cat([edge_type, new_type, new_type + 1], dim=0)
42
43
 
43
44
  old_data = copy.copy(data)
torch_geometric/typing.py CHANGED
@@ -1,8 +1,9 @@
1
1
  import inspect
2
2
  import os
3
3
  import sys
4
+ import typing
4
5
  import warnings
5
- from typing import Any, Dict, List, Optional, Tuple, Union
6
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
6
7
 
7
8
  import numpy as np
8
9
  import torch
@@ -12,6 +13,9 @@ WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2
12
13
  WITH_PT21 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 1
13
14
  WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2
14
15
  WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3
16
+ WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
17
+ WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5
18
+ WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6
15
19
  WITH_PT111 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 11
16
20
  WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12
17
21
  WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13
@@ -21,6 +25,16 @@ NO_MKL = 'USE_MKL=OFF' in torch.__config__.show() or WITH_WINDOWS
21
25
 
22
26
  MAX_INT64 = torch.iinfo(torch.int64).max
23
27
 
28
+ if WITH_PT20:
29
+ INDEX_DTYPES: Set[torch.dtype] = {
30
+ torch.int32,
31
+ torch.int64,
32
+ }
33
+ elif not typing.TYPE_CHECKING: # pragma: no cover
34
+ INDEX_DTYPES: Set[torch.dtype] = {
35
+ torch.int64,
36
+ }
37
+
24
38
  if not hasattr(torch, 'sparse_csc'):
25
39
  torch.sparse_csc = torch.sparse_coo
26
40
 
@@ -293,6 +307,8 @@ class EdgeTypeStr(str):
293
307
  r"""A helper class to construct serializable edge types by merging an edge
294
308
  type tuple into a single string.
295
309
  """
310
+ edge_type: tuple[str, str, str]
311
+
296
312
  def __new__(cls, *args: Any) -> 'EdgeTypeStr':
297
313
  if isinstance(args[0], (list, tuple)):
298
314
  # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`:
@@ -300,27 +316,37 @@ class EdgeTypeStr(str):
300
316
 
301
317
  if len(args) == 1 and isinstance(args[0], str):
302
318
  arg = args[0] # An edge type string was passed.
319
+ edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT))
320
+ if len(edge_type) != 3:
321
+ raise ValueError(f"Cannot convert the edge type '{arg}' to a "
322
+ f"tuple since it holds invalid characters")
303
323
 
304
324
  elif len(args) == 2 and all(isinstance(arg, str) for arg in args):
305
325
  # A `(src, dst)` edge type was passed - add `DEFAULT_REL`:
306
- arg = EDGE_TYPE_STR_SPLIT.join((args[0], DEFAULT_REL, args[1]))
326
+ edge_type = (args[0], DEFAULT_REL, args[1])
327
+ arg = EDGE_TYPE_STR_SPLIT.join(edge_type)
307
328
 
308
329
  elif len(args) == 3 and all(isinstance(arg, str) for arg in args):
309
330
  # A `(src, rel, dst)` edge type was passed:
331
+ edge_type = tuple(args)
310
332
  arg = EDGE_TYPE_STR_SPLIT.join(args)
311
333
 
312
334
  else:
313
335
  raise ValueError(f"Encountered invalid edge type '{args}'")
314
336
 
315
- return str.__new__(cls, arg)
337
+ out = str.__new__(cls, arg)
338
+ out.edge_type = edge_type # type: ignore
339
+ return out
316
340
 
317
341
  def to_tuple(self) -> EdgeType:
318
342
  r"""Returns the original edge type."""
319
- out = tuple(self.split(EDGE_TYPE_STR_SPLIT))
320
- if len(out) != 3:
343
+ if len(self.edge_type) != 3:
321
344
  raise ValueError(f"Cannot convert the edge type '{self}' to a "
322
345
  f"tuple since it holds invalid characters")
323
- return out
346
+ return self.edge_type
347
+
348
+ def __reduce__(self) -> tuple[Any, Any]:
349
+ return (self.__class__, (self.edge_type, ))
324
350
 
325
351
 
326
352
  # There exist some short-cuts to query edge-types (given that the full triplet
@@ -358,3 +384,14 @@ MaybeHeteroEdgeTensor = Union[Tensor, Dict[EdgeType, Tensor]]
358
384
 
359
385
  InputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor]]
360
386
  InputEdges = Union[OptTensor, EdgeType, Tuple[EdgeType, OptTensor]]
387
+
388
+ # Serialization ###############################################################
389
+
390
+ if WITH_PT24:
391
+ torch.serialization.add_safe_globals([
392
+ SparseTensor,
393
+ SparseStorage,
394
+ TensorFrame,
395
+ MockTorchCSCTensor,
396
+ EdgeTypeStr,
397
+ ])
@@ -21,6 +21,7 @@ from ._subgraph import (get_num_hops, subgraph, k_hop_subgraph,
21
21
  from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path
22
22
  from ._homophily import homophily
23
23
  from ._assortativity import assortativity
24
+ from ._normalize_edge_index import normalize_edge_index
24
25
  from .laplacian import get_laplacian
25
26
  from .mesh_laplacian import get_mesh_laplacian
26
27
  from .mask import mask_select, index_to_mask, mask_to_index
@@ -44,7 +45,7 @@ from .convert import to_networkit, from_networkit
44
45
  from .convert import to_trimesh, from_trimesh
45
46
  from .convert import to_cugraph, from_cugraph
46
47
  from .convert import to_dgl, from_dgl
47
- from .smiles import from_smiles, to_smiles
48
+ from .smiles import from_rdmol, to_rdmol, from_smiles, to_smiles
48
49
  from .random import (erdos_renyi_graph, stochastic_blockmodel_graph,
49
50
  barabasi_albert_graph)
50
51
  from ._negative_sampling import (negative_sampling, batched_negative_sampling,
@@ -89,6 +90,7 @@ __all__ = [
89
90
  'dropout_adj',
90
91
  'homophily',
91
92
  'assortativity',
93
+ 'normalize_edge_index',
92
94
  'get_laplacian',
93
95
  'get_mesh_laplacian',
94
96
  'mask_select',
@@ -127,6 +129,8 @@ __all__ = [
127
129
  'from_cugraph',
128
130
  'to_dgl',
129
131
  'from_dgl',
132
+ 'from_rdmol',
133
+ 'to_rdmol',
130
134
  'from_smiles',
131
135
  'to_smiles',
132
136
  'erdos_renyi_graph',
@@ -265,7 +265,7 @@ def structured_negative_sampling_feasible(
265
265
  :meth:`~torch_geometric.utils.structured_negative_sampling` is feasible
266
266
  on the graph given by :obj:`edge_index`.
267
267
  :meth:`~torch_geometric.utils.structured_negative_sampling` is infeasible
268
- if atleast one node is connected to all other nodes.
268
+ if at least one node is connected to all other nodes.
269
269
 
270
270
  Args:
271
271
  edge_index (LongTensor): The edge indices.
@@ -0,0 +1,46 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torch_geometric.utils import add_self_loops as add_self_loops_fn
7
+ from torch_geometric.utils import degree
8
+
9
+
10
+ def normalize_edge_index(
11
+ edge_index: Tensor,
12
+ num_nodes: Optional[int] = None,
13
+ add_self_loops: bool = True,
14
+ symmetric: bool = True,
15
+ ) -> Tuple[Tensor, Tensor]:
16
+ """Applies normalization to the edges of a graph.
17
+
18
+ This function can add self-loops to the graph and apply either symmetric or
19
+ asymmetric normalization based on the node degrees.
20
+
21
+ Args:
22
+ edge_index (LongTensor): The edge indices.
23
+ num_nodes (int, int], optional): The number of nodes, *i.e.*
24
+ :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
25
+ add_self_loops (bool, optional): If set to :obj:`False`, will not add
26
+ self-loops to the input graph. (default: :obj:`True`)
27
+ symmetric (bool, optional): If set to :obj:`True`, symmetric
28
+ normalization (:math:`D^{-1/2} A D^{-1/2}`) is used, otherwise
29
+ asymmetric normalization (:math:`D^{-1} A`).
30
+ """
31
+ if add_self_loops:
32
+ edge_index, _ = add_self_loops_fn(edge_index, num_nodes=num_nodes)
33
+
34
+ row, col = edge_index[0], edge_index[1]
35
+ deg = degree(row, num_nodes, dtype=torch.get_default_dtype())
36
+
37
+ if symmetric: # D^-1/2 * A * D^-1/2
38
+ deg_inv_sqrt = deg.pow(-0.5)
39
+ deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0
40
+ edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col]
41
+ else: # D^-1 * A
42
+ deg_inv = deg.pow(-1)
43
+ deg_inv[torch.isinf(deg_inv)] = 0
44
+ edge_weight = deg_inv[row]
45
+
46
+ return edge_index, edge_weight
@@ -4,7 +4,7 @@ import torch
4
4
  from torch import Tensor
5
5
 
6
6
  import torch_geometric.typing
7
- from torch_geometric import is_compiling, warnings
7
+ from torch_geometric import is_compiling, is_in_onnx_export, warnings
8
8
  from torch_geometric.typing import torch_scatter
9
9
  from torch_geometric.utils.functions import cumsum
10
10
 
@@ -88,18 +88,33 @@ if torch_geometric.typing.WITH_PT112: # pragma: no cover
88
88
  # in case the input does not require gradients:
89
89
  if reduce in ['min', 'max', 'amin', 'amax']:
90
90
  if (not torch_geometric.typing.WITH_TORCH_SCATTER
91
- or is_compiling() or not src.is_cuda
91
+ or is_compiling() or is_in_onnx_export() or not src.is_cuda
92
92
  or not src.requires_grad):
93
93
 
94
- if src.is_cuda and src.requires_grad and not is_compiling():
94
+ if (src.is_cuda and src.requires_grad and not is_compiling()
95
+ and not is_in_onnx_export()):
95
96
  warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
96
97
  f"can be accelerated via the 'torch-scatter'"
97
98
  f" package, but it was not found")
98
99
 
99
100
  index = broadcast(index, src, dim)
100
- return src.new_zeros(size).scatter_reduce_(
101
- dim, index, src, reduce=f'a{reduce[-3:]}',
102
- include_self=False)
101
+ if not is_in_onnx_export():
102
+ return src.new_zeros(size).scatter_reduce_(
103
+ dim, index, src, reduce=f'a{reduce[-3:]}',
104
+ include_self=False)
105
+
106
+ fill = torch.full( # type: ignore
107
+ size=(1, ),
108
+ fill_value=src.min() if 'max' in reduce else src.max(),
109
+ dtype=src.dtype,
110
+ device=src.device,
111
+ ).expand_as(src)
112
+ out = src.new_zeros(size).scatter_reduce_(
113
+ dim, index, fill, reduce=f'a{reduce[-3:]}',
114
+ include_self=True)
115
+ return out.scatter_reduce_(dim, index, src,
116
+ reduce=f'a{reduce[-3:]}',
117
+ include_self=True)
103
118
 
104
119
  return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
105
120
  reduce=reduce[-3:])
@@ -175,6 +190,7 @@ else: # pragma: no cover
175
190
 
176
191
 
177
192
  def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor:
193
+ dim = ref.dim() + dim if dim < 0 else dim
178
194
  size = ((1, ) * dim) + (-1, ) + ((1, ) * (ref.dim() - dim - 1))
179
195
  return src.view(size).expand_as(ref)
180
196
 
@@ -186,7 +202,8 @@ def scatter_argmax(
186
202
  dim_size: Optional[int] = None,
187
203
  ) -> Tensor:
188
204
 
189
- if torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling():
205
+ if (torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling()
206
+ and not is_in_onnx_export()):
190
207
  out = torch_scatter.scatter_max(src, index, dim=dim, dim_size=dim_size)
191
208
  return out[1]
192
209
 
@@ -199,9 +216,18 @@ def scatter_argmax(
199
216
  dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
200
217
 
201
218
  if torch_geometric.typing.WITH_PT112:
202
- res = src.new_empty(dim_size)
203
- res.scatter_reduce_(0, index, src.detach(), reduce='amax',
204
- include_self=False)
219
+ if not is_in_onnx_export():
220
+ res = src.new_empty(dim_size)
221
+ res.scatter_reduce_(0, index, src.detach(), reduce='amax',
222
+ include_self=False)
223
+ else:
224
+ # `include_self=False` is currently not supported by ONNX:
225
+ res = src.new_full(
226
+ size=(dim_size, ),
227
+ fill_value=src.min(), # type: ignore
228
+ )
229
+ res.scatter_reduce_(0, index, src.detach(), reduce="amax",
230
+ include_self=True)
205
231
  elif torch_geometric.typing.WITH_PT111:
206
232
  res = torch.scatter_reduce(src.detach(), 0, index, reduce='amax',
207
233
  output_size=dim_size) # type: ignore
@@ -294,7 +320,7 @@ def group_cat(
294
320
  r"""Concatenates the given sequence of tensors :obj:`tensors` in the given
295
321
  dimension :obj:`dim`.
296
322
  Different from :meth:`torch.cat`, values along the concatenating dimension
297
- are grouped according to the indicies defined in the :obj:`index` tensors.
323
+ are grouped according to the indices defined in the :obj:`index` tensors.
298
324
  All tensors must have the same shape (except in the concatenating
299
325
  dimension).
300
326
 
@@ -325,5 +351,5 @@ def group_cat(
325
351
  """
326
352
  assert len(tensors) == len(indices)
327
353
  index, perm = torch.cat(indices).sort(stable=True)
328
- out = torch.cat(tensors, dim=0)[perm]
354
+ out = torch.cat(tensors, dim=dim).index_select(dim, perm)
329
355
  return (out, index) if return_index else out
@@ -346,10 +346,12 @@ def k_hop_subgraph(
346
346
 
347
347
  subsets = [node_idx]
348
348
 
349
+ preserved_edge_mask = torch.zeros_like(edge_mask)
349
350
  for _ in range(num_hops):
350
351
  node_mask.fill_(False)
351
352
  node_mask[subsets[-1]] = True
352
353
  torch.index_select(node_mask, 0, row, out=edge_mask)
354
+ preserved_edge_mask |= edge_mask
353
355
  subsets.append(col[edge_mask])
354
356
 
355
357
  subset, inv = torch.cat(subsets).unique(return_inverse=True)
@@ -360,6 +362,8 @@ def k_hop_subgraph(
360
362
 
361
363
  if not directed:
362
364
  edge_mask = node_mask[row] & node_mask[col]
365
+ else:
366
+ edge_mask = preserved_edge_mask
363
367
 
364
368
  edge_index = edge_index[:, edge_mask]
365
369
 
@@ -2,7 +2,6 @@ from itertools import chain
2
2
  from typing import Any, List, Literal, Tuple, Union, overload
3
3
 
4
4
  import torch
5
- from scipy.sparse.csgraph import minimum_spanning_tree
6
5
  from torch import Tensor
7
6
 
8
7
  from torch_geometric.utils import (
@@ -54,6 +53,7 @@ def tree_decomposition(
54
53
  :obj:`False`, else :obj:`(LongTensor, LongTensor, int, LongTensor)`
55
54
  """
56
55
  import rdkit.Chem as Chem
56
+ from scipy.sparse.csgraph import minimum_spanning_tree
57
57
 
58
58
  # Cliques = rings and bonds.
59
59
  cliques: List[List[int]] = [list(x) for x in Chem.GetSymmSSSR(mol)]
@@ -64,7 +64,7 @@ def tree_decomposition(
64
64
  xs.append(1)
65
65
 
66
66
  # Generate `atom2cliques` mappings.
67
- atom2cliques: List[List[int]] = [[] for i in range(mol.GetNumAtoms())]
67
+ atom2cliques: List[List[int]] = [[] for _ in range(mol.GetNumAtoms())]
68
68
  for c in range(len(cliques)):
69
69
  for atom in cliques[c]:
70
70
  atom2cliques[atom].append(c)
@@ -12,7 +12,7 @@ def shuffle_node(
12
12
  training: bool = True,
13
13
  ) -> Tuple[Tensor, Tensor]:
14
14
  r"""Randomly shuffle the feature matrix :obj:`x` along the
15
- first dimmension.
15
+ first dimension.
16
16
 
17
17
  The method returns (1) the shuffled :obj:`x`, (2) the permutation
18
18
  indicating the orders of original nodes after shuffling.
@@ -1,7 +1,6 @@
1
1
  from collections import defaultdict
2
2
  from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
3
3
 
4
- import scipy.sparse
5
4
  import torch
6
5
  from torch import Tensor
7
6
  from torch.utils.dlpack import from_dlpack, to_dlpack
@@ -14,7 +13,7 @@ def to_scipy_sparse_matrix(
14
13
  edge_index: Tensor,
15
14
  edge_attr: Optional[Tensor] = None,
16
15
  num_nodes: Optional[int] = None,
17
- ) -> scipy.sparse.coo_matrix:
16
+ ) -> Any:
18
17
  r"""Converts a graph given by edge indices and edge attributes to a scipy
19
18
  sparse matrix.
20
19
 
@@ -34,22 +33,23 @@ def to_scipy_sparse_matrix(
34
33
  <4x4 sparse matrix of type '<class 'numpy.float32'>'
35
34
  with 6 stored elements in COOrdinate format>
36
35
  """
36
+ import scipy.sparse as sp
37
+
37
38
  row, col = edge_index.cpu()
38
39
 
39
40
  if edge_attr is None:
40
- edge_attr = torch.ones(row.size(0))
41
+ edge_attr = torch.ones(row.size(0), device="cpu")
41
42
  else:
42
43
  edge_attr = edge_attr.view(-1).cpu()
43
44
  assert edge_attr.size(0) == row.size(0)
44
45
 
45
46
  N = maybe_num_nodes(edge_index, num_nodes)
46
- out = scipy.sparse.coo_matrix(
47
+ out = sp.coo_matrix( #
47
48
  (edge_attr.numpy(), (row.numpy(), col.numpy())), (N, N))
48
49
  return out
49
50
 
50
51
 
51
- def from_scipy_sparse_matrix(
52
- A: scipy.sparse.spmatrix) -> Tuple[Tensor, Tensor]:
52
+ def from_scipy_sparse_matrix(A: Any) -> Tuple[Tensor, Tensor]:
53
53
  r"""Converts a scipy sparse matrix to edge indices and edge attributes.
54
54
 
55
55
  Args:
@@ -527,10 +527,14 @@ def to_dgl(
527
527
  if isinstance(data, Data):
528
528
  if data.edge_index is not None:
529
529
  row, col = data.edge_index
530
- else:
530
+ elif 'adj' in data:
531
+ row, col, _ = data.adj.coo()
532
+ elif 'adj_t' in data:
531
533
  row, col, _ = data.adj_t.t().coo()
534
+ else:
535
+ row, col = [], []
532
536
 
533
- g = dgl.graph((row, col))
537
+ g = dgl.graph((row, col), num_nodes=data.num_nodes)
534
538
 
535
539
  for attr in data.node_attrs():
536
540
  g.ndata[attr] = data[attr]
@@ -2,6 +2,7 @@ import multiprocessing as mp
2
2
  import warnings
3
3
  from typing import Optional
4
4
 
5
+ import numpy as np
5
6
  import torch
6
7
  from torch import Tensor
7
8
 
@@ -82,54 +83,55 @@ def geodesic_distance( # noqa: D417
82
83
 
83
84
  dtype = pos.dtype
84
85
 
85
- pos = pos.detach().cpu().to(torch.double).numpy()
86
- face = face.detach().t().cpu().to(torch.int).numpy()
86
+ pos_np = pos.detach().cpu().to(torch.double).numpy()
87
+ face_np = face.detach().t().cpu().to(torch.int).numpy()
87
88
 
88
89
  if src is None and dst is None:
89
- out = gdist.local_gdist_matrix(pos, face,
90
- max_distance * scale).toarray() / scale
90
+ out = gdist.local_gdist_matrix(
91
+ pos_np,
92
+ face_np,
93
+ max_distance * scale,
94
+ ).toarray() / scale
91
95
  return torch.from_numpy(out).to(dtype)
92
96
 
93
97
  if src is None:
94
- src = torch.arange(pos.shape[0], dtype=torch.int).numpy()
98
+ src_np = torch.arange(pos.size(0), dtype=torch.int).numpy()
95
99
  else:
96
- src = src.detach().cpu().to(torch.int).numpy()
97
- assert src is not None
100
+ src_np = src.detach().cpu().to(torch.int).numpy()
98
101
 
99
- dst = None if dst is None else dst.detach().cpu().to(torch.int).numpy()
102
+ dst_np = None if dst is None else dst.detach().cpu().to(torch.int).numpy()
100
103
 
101
104
  def _parallel_loop(
102
- pos: Tensor,
103
- face: Tensor,
104
- src: Tensor,
105
- dst: Optional[Tensor],
105
+ pos_np: np.ndarray,
106
+ face_np: np.ndarray,
107
+ src_np: np.ndarray,
108
+ dst_np: Optional[np.ndarray],
106
109
  max_distance: float,
107
110
  scale: float,
108
111
  i: int,
109
112
  dtype: torch.dtype,
110
113
  ) -> Tensor:
111
- s = src[i:i + 1]
112
- d = None if dst is None else dst[i:i + 1]
113
- out = gdist.compute_gdist(pos, face, s, d, max_distance * scale)
114
+ s = src_np[i:i + 1]
115
+ d = None if dst_np is None else dst_np[i:i + 1]
116
+ out = gdist.compute_gdist(pos_np, face_np, s, d, max_distance * scale)
114
117
  out = out / scale
115
118
  return torch.from_numpy(out).to(dtype)
116
119
 
117
120
  num_workers = mp.cpu_count() if num_workers <= -1 else num_workers
118
121
  if num_workers > 0:
119
122
  with mp.Pool(num_workers) as pool:
120
- outs = pool.starmap(
121
- _parallel_loop,
122
- [(pos, face, src, dst, max_distance, scale, i, dtype)
123
- for i in range(len(src))])
123
+ data = [(pos_np, face_np, src_np, dst_np, max_distance, scale, i,
124
+ dtype) for i in range(len(src_np))]
125
+ outs = pool.starmap(_parallel_loop, data)
124
126
  else:
125
127
  outs = [
126
- _parallel_loop(pos, face, src, dst, max_distance, scale, i, dtype)
127
- for i in range(len(src))
128
+ _parallel_loop(pos_np, face_np, src_np, dst_np, max_distance,
129
+ scale, i, dtype) for i in range(len(src_np))
128
130
  ]
129
131
 
130
132
  out = torch.cat(outs, dim=0)
131
133
 
132
134
  if dst is None:
133
- out = out.view(-1, pos.shape[0])
135
+ out = out.view(-1, pos.size(0))
134
136
 
135
137
  return out
@@ -63,7 +63,7 @@ def group_hetero_graph(
63
63
 
64
64
  def get_unused_node_types(node_types: List[NodeType],
65
65
  edge_types: List[EdgeType]) -> Set[NodeType]:
66
- dst_node_types = set(edge_type[-1] for edge_type in edge_types)
66
+ dst_node_types = {edge_type[-1] for edge_type in edge_types}
67
67
  return set(node_types) - set(dst_node_types)
68
68
 
69
69