pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. See the registry's advisory page for this release for more details.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -4,174 +4,138 @@ import torch
4
4
  from torch import Tensor
5
5
 
6
6
  import torch_geometric.typing
7
- from torch_geometric import is_compiling, warnings
7
+ from torch_geometric import is_compiling, is_in_onnx_export, warnings
8
8
  from torch_geometric.typing import torch_scatter
9
9
  from torch_geometric.utils.functions import cumsum
10
10
 
11
- if torch_geometric.typing.WITH_PT112: # pragma: no cover
12
-
13
- warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
14
-
15
- def scatter(
16
- src: Tensor,
17
- index: Tensor,
18
- dim: int = 0,
19
- dim_size: Optional[int] = None,
20
- reduce: str = 'sum',
21
- ) -> Tensor:
22
- r"""Reduces all values from the :obj:`src` tensor at the indices
23
- specified in the :obj:`index` tensor along a given dimension
24
- :obj:`dim`. See the `documentation
25
- <https://pytorch-scatter.readthedocs.io/en/latest/functions/
26
- scatter.html>`__ of the :obj:`torch_scatter` package for more
27
- information.
28
-
29
- Args:
30
- src (torch.Tensor): The source tensor.
31
- index (torch.Tensor): The index tensor.
32
- dim (int, optional): The dimension along which to index.
33
- (default: :obj:`0`)
34
- dim_size (int, optional): The size of the output tensor at
35
- dimension :obj:`dim`. If set to :obj:`None`, will create a
36
- minimal-sized output tensor according to
37
- :obj:`index.max() + 1`. (default: :obj:`None`)
38
- reduce (str, optional): The reduce operation (:obj:`"sum"`,
39
- :obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`,
40
- :obj:`"any"`). (default: :obj:`"sum"`)
41
- """
42
- if isinstance(index, Tensor) and index.dim() != 1:
43
- raise ValueError(f"The `index` argument must be one-dimensional "
44
- f"(got {index.dim()} dimensions)")
45
-
46
- dim = src.dim() + dim if dim < 0 else dim
47
-
48
- if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):
49
- raise ValueError(f"The `dim` argument must lay between 0 and "
50
- f"{src.dim() - 1} (got {dim})")
51
-
52
- if dim_size is None:
53
- dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
54
-
55
- # For now, we maintain various different code paths, based on whether
56
- # the input requires gradients and whether it lays on the CPU/GPU.
57
- # For example, `torch_scatter` is usually faster than
58
- # `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster
59
- # on CPU.
60
- # `torch.scatter_reduce` has a faster forward implementation for
61
- # "min"/"max" reductions since it does not compute additional arg
62
- # indices, but is therefore way slower in its backward implementation.
63
- # More insights can be found in `test/utils/test_scatter.py`.
64
-
65
- size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
66
-
67
- # For "any" reduction, we use regular `scatter_`:
68
- if reduce == 'any':
69
- index = broadcast(index, src, dim)
70
- return src.new_zeros(size).scatter_(dim, index, src)
11
+ warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
71
12
 
72
- # For "sum" and "mean" reduction, we make use of `scatter_add_`:
73
- if reduce == 'sum' or reduce == 'add':
74
- index = broadcast(index, src, dim)
75
- return src.new_zeros(size).scatter_add_(dim, index, src)
76
13
 
77
- if reduce == 'mean':
78
- count = src.new_zeros(dim_size)
79
- count.scatter_add_(0, index, src.new_ones(src.size(dim)))
80
- count = count.clamp(min=1)
14
+ def scatter(
15
+ src: Tensor,
16
+ index: Tensor,
17
+ dim: int = 0,
18
+ dim_size: Optional[int] = None,
19
+ reduce: str = 'sum',
20
+ ) -> Tensor:
21
+ r"""Reduces all values from the :obj:`src` tensor at the indices specified
22
+ in the :obj:`index` tensor along a given dimension ``dim``. See the
23
+ `documentation <https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html>`__ # noqa: E501
24
+ of the ``torch_scatter`` package for more information.
81
25
 
82
- index = broadcast(index, src, dim)
83
- out = src.new_zeros(size).scatter_add_(dim, index, src)
26
+ Args:
27
+ src (torch.Tensor): The source tensor.
28
+ index (torch.Tensor): The index tensor.
29
+ dim (int, optional): The dimension along which to index.
30
+ (default: ``0``)
31
+ dim_size (int, optional): The size of the output tensor at dimension
32
+ ``dim``. If set to :obj:`None`, will create a minimal-sized output
33
+ tensor according to ``index.max() + 1``. (default: :obj:`None`)
34
+ reduce (str, optional): The reduce operation (``"sum"``, ``"mean"``,
35
+ ``"mul"``, ``"min"``, ``"max"`` or ``"any"``). (default: ``"sum"``)
36
+ """
37
+ if isinstance(index, Tensor) and index.dim() != 1:
38
+ raise ValueError(f"The `index` argument must be one-dimensional "
39
+ f"(got {index.dim()} dimensions)")
40
+
41
+ dim = src.dim() + dim if dim < 0 else dim
84
42
 
85
- return out / broadcast(count, out, dim)
43
+ if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):
44
+ raise ValueError(f"The `dim` argument must lay between 0 and "
45
+ f"{src.dim() - 1} (got {dim})")
86
46
 
87
- # For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or
88
- # in case the input does not require gradients:
89
- if reduce in ['min', 'max', 'amin', 'amax']:
90
- if (not torch_geometric.typing.WITH_TORCH_SCATTER
91
- or is_compiling() or not src.is_cuda
92
- or not src.requires_grad):
47
+ if dim_size is None:
48
+ dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
93
49
 
94
- if src.is_cuda and src.requires_grad and not is_compiling():
95
- warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
96
- f"can be accelerated via the 'torch-scatter'"
97
- f" package, but it was not found")
50
+ # For now, we maintain various different code paths, based on whether
51
+ # the input requires gradients and whether it lays on the CPU/GPU.
52
+ # For example, `torch_scatter` is usually faster than
53
+ # `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster
54
+ # on CPU.
55
+ # `torch.scatter_reduce` has a faster forward implementation for
56
+ # "min"/"max" reductions since it does not compute additional arg
57
+ # indices, but is therefore way slower in its backward implementation.
58
+ # More insights can be found in `test/utils/test_scatter.py`.
59
+
60
+ size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
61
+
62
+ # For "any" reduction, we use regular `scatter_`:
63
+ if reduce == 'any':
64
+ index = broadcast(index, src, dim)
65
+ return src.new_zeros(size).scatter_(dim, index, src)
66
+
67
+ # For "sum" and "mean" reduction, we make use of `scatter_add_`:
68
+ if reduce == 'sum' or reduce == 'add':
69
+ index = broadcast(index, src, dim)
70
+ return src.new_zeros(size).scatter_add_(dim, index, src)
71
+
72
+ if reduce == 'mean':
73
+ count = src.new_zeros(dim_size)
74
+ count.scatter_add_(0, index, src.new_ones(src.size(dim)))
75
+ count = count.clamp(min=1)
76
+
77
+ index = broadcast(index, src, dim)
78
+ out = src.new_zeros(size).scatter_add_(dim, index, src)
79
+
80
+ return out / broadcast(count, out, dim)
81
+
82
+ # For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or
83
+ # in case the input does not require gradients:
84
+ if reduce in ['min', 'max', 'amin', 'amax']:
85
+ if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()
86
+ or is_in_onnx_export() or not src.is_cuda
87
+ or not src.requires_grad):
88
+
89
+ if (src.is_cuda and src.requires_grad and not is_compiling()
90
+ and not is_in_onnx_export()):
91
+ warnings.warn(
92
+ f"The usage of `scatter(reduce='{reduce}')` "
93
+ f"can be accelerated via the 'torch-scatter'"
94
+ f" package, but it was not found", stacklevel=2)
98
95
 
99
- index = broadcast(index, src, dim)
96
+ index = broadcast(index, src, dim)
97
+ if not is_in_onnx_export():
100
98
  return src.new_zeros(size).scatter_reduce_(
101
99
  dim, index, src, reduce=f'a{reduce[-3:]}',
102
100
  include_self=False)
103
101
 
104
- return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
105
- reduce=reduce[-3:])
106
-
107
- # For "mul" reduction, we prefer `scatter_reduce_` on CPU:
108
- if reduce == 'mul':
109
- if (not torch_geometric.typing.WITH_TORCH_SCATTER
110
- or is_compiling() or not src.is_cuda):
111
-
112
- if src.is_cuda and not is_compiling():
113
- warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
114
- f"can be accelerated via the 'torch-scatter'"
115
- f" package, but it was not found")
116
-
117
- index = broadcast(index, src, dim)
118
- # We initialize with `one` here to match `scatter_mul` output:
119
- return src.new_ones(size).scatter_reduce_(
120
- dim, index, src, reduce='prod', include_self=True)
121
-
122
- return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
123
- reduce='mul')
124
-
125
- raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'")
126
-
127
- else: # pragma: no cover
128
-
129
- def scatter(
130
- src: Tensor,
131
- index: Tensor,
132
- dim: int = 0,
133
- dim_size: Optional[int] = None,
134
- reduce: str = 'sum',
135
- ) -> Tensor:
136
- r"""Reduces all values from the :obj:`src` tensor at the indices
137
- specified in the :obj:`index` tensor along a given dimension
138
- :obj:`dim`. See the `documentation
139
- <https://pytorch-scatter.readthedocs.io/en/latest/functions/
140
- scatter.html>`_ of the :obj:`torch_scatter` package for more
141
- information.
142
-
143
- Args:
144
- src (torch.Tensor): The source tensor.
145
- index (torch.Tensor): The index tensor.
146
- dim (int, optional): The dimension along which to index.
147
- (default: :obj:`0`)
148
- dim_size (int, optional): The size of the output tensor at
149
- dimension :obj:`dim`. If set to :obj:`None`, will create a
150
- minimal-sized output tensor according to
151
- :obj:`index.max() + 1`. (default: :obj:`None`)
152
- reduce (str, optional): The reduce operation (:obj:`"sum"`,
153
- :obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`,
154
- :obj:`"any"`). (default: :obj:`"sum"`)
155
- """
156
- if reduce == 'any':
157
- dim = src.dim() + dim if dim < 0 else dim
158
-
159
- if dim_size is None:
160
- dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
161
-
162
- size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
102
+ fill = torch.full( # type: ignore
103
+ size=(1, ),
104
+ fill_value=src.min() if 'max' in reduce else src.max(),
105
+ dtype=src.dtype,
106
+ device=src.device,
107
+ ).expand_as(src)
108
+ out = src.new_zeros(size).scatter_reduce_(dim, index, fill,
109
+ reduce=f'a{reduce[-3:]}',
110
+ include_self=True)
111
+ return out.scatter_reduce_(dim, index, src,
112
+ reduce=f'a{reduce[-3:]}',
113
+ include_self=True)
163
114
 
164
- index = broadcast(index, src, dim)
165
- return src.new_zeros(size).scatter_(dim, index, src)
115
+ return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
116
+ reduce=reduce[-3:])
117
+
118
+ # For "mul" reduction, we prefer `scatter_reduce_` on CPU:
119
+ if reduce == 'mul':
120
+ if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()
121
+ or not src.is_cuda):
166
122
 
167
- if not torch_geometric.typing.WITH_TORCH_SCATTER:
168
- raise ImportError("'scatter' requires the 'torch-scatter' package")
123
+ if src.is_cuda and not is_compiling():
124
+ warnings.warn(
125
+ f"The usage of `scatter(reduce='{reduce}')` "
126
+ f"can be accelerated via the 'torch-scatter'"
127
+ f" package, but it was not found", stacklevel=2)
169
128
 
170
- if reduce == 'amin' or reduce == 'amax':
171
- reduce = reduce[-3:]
129
+ index = broadcast(index, src, dim)
130
+ # We initialize with `one` here to match `scatter_mul` output:
131
+ return src.new_ones(size).scatter_reduce_(dim, index, src,
132
+ reduce='prod',
133
+ include_self=True)
172
134
 
173
135
  return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
174
- reduce=reduce)
136
+ reduce='mul')
137
+
138
+ raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'")
175
139
 
176
140
 
177
141
  def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor:
@@ -187,7 +151,8 @@ def scatter_argmax(
187
151
  dim_size: Optional[int] = None,
188
152
  ) -> Tensor:
189
153
 
190
- if torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling():
154
+ if (torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling()
155
+ and not is_in_onnx_export()):
191
156
  out = torch_scatter.scatter_max(src, index, dim=dim, dim_size=dim_size)
192
157
  return out[1]
193
158
 
@@ -199,15 +164,18 @@ def scatter_argmax(
199
164
  if dim_size is None:
200
165
  dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
201
166
 
202
- if torch_geometric.typing.WITH_PT112:
167
+ if not is_in_onnx_export():
203
168
  res = src.new_empty(dim_size)
204
169
  res.scatter_reduce_(0, index, src.detach(), reduce='amax',
205
170
  include_self=False)
206
- elif torch_geometric.typing.WITH_PT111:
207
- res = torch.scatter_reduce(src.detach(), 0, index, reduce='amax',
208
- output_size=dim_size) # type: ignore
209
171
  else:
210
- raise ValueError("'scatter_argmax' requires PyTorch >= 1.11")
172
+ # `include_self=False` is currently not supported by ONNX:
173
+ res = src.new_full(
174
+ size=(dim_size, ),
175
+ fill_value=src.min(), # type: ignore
176
+ )
177
+ res.scatter_reduce_(0, index, src.detach(), reduce="amax",
178
+ include_self=True)
211
179
 
212
180
  out = index.new_full((dim_size, ), fill_value=dim_size - 1)
213
181
  nonzero = (src == res[index]).nonzero().view(-1)
@@ -265,13 +233,7 @@ def group_argsort(
265
233
 
266
234
  # Compute `grouped_argsort`:
267
235
  src = src - 2 * index if descending else src + 2 * index
268
- if torch_geometric.typing.WITH_PT113:
269
- perm = src.argsort(descending=descending, stable=stable)
270
- else:
271
- perm = src.argsort(descending=descending)
272
- if stable:
273
- warnings.warn("Ignoring option `stable=True` in 'group_argsort' "
274
- "since it requires PyTorch >= 1.13.0")
236
+ perm = src.argsort(descending=descending, stable=stable)
275
237
  out = torch.empty_like(index)
276
238
  out[perm] = torch.arange(index.numel(), device=index.device)
277
239
 
@@ -295,7 +257,7 @@ def group_cat(
295
257
  r"""Concatenates the given sequence of tensors :obj:`tensors` in the given
296
258
  dimension :obj:`dim`.
297
259
  Different from :meth:`torch.cat`, values along the concatenating dimension
298
- are grouped according to the indicies defined in the :obj:`index` tensors.
260
+ are grouped according to the indices defined in the :obj:`index` tensors.
299
261
  All tensors must have the same shape (except in the concatenating
300
262
  dimension).
301
263
 
@@ -326,5 +288,5 @@ def group_cat(
326
288
  """
327
289
  assert len(tensors) == len(indices)
328
290
  index, perm = torch.cat(indices).sort(stable=True)
329
- out = torch.cat(tensors, dim=0)[perm]
291
+ out = torch.cat(tensors, dim=dim).index_select(dim, perm)
330
292
  return (out, index) if return_index else out
@@ -107,8 +107,6 @@ def sort_edge_index( # noqa: F811
107
107
  num_nodes = maybe_num_nodes(edge_index, num_nodes)
108
108
 
109
109
  if num_nodes * num_nodes > torch_geometric.typing.MAX_INT64:
110
- if not torch_geometric.typing.WITH_PT113:
111
- raise ValueError("'sort_edge_index' will result in an overflow")
112
110
  perm = lexsort(keys=[
113
111
  edge_index[int(sort_by_row)],
114
112
  edge_index[1 - int(sort_by_row)],
@@ -63,18 +63,20 @@ def spmm(
63
63
 
64
64
  # Always convert COO to CSR for more efficient processing:
65
65
  if src.layout == torch.sparse_coo:
66
- warnings.warn(f"Converting sparse tensor to CSR format for more "
67
- f"efficient processing. Consider converting your "
68
- f"sparse tensor to CSR format beforehand to avoid "
69
- f"repeated conversion (got '{src.layout}')")
66
+ warnings.warn(
67
+ f"Converting sparse tensor to CSR format for more "
68
+ f"efficient processing. Consider converting your "
69
+ f"sparse tensor to CSR format beforehand to avoid "
70
+ f"repeated conversion (got '{src.layout}')", stacklevel=2)
70
71
  src = src.to_sparse_csr()
71
72
 
72
73
  # Warn in case of CSC format without gradient computation:
73
74
  if src.layout == torch.sparse_csc and not other.requires_grad:
74
- warnings.warn(f"Converting sparse tensor to CSR format for more "
75
- f"efficient processing. Consider converting your "
76
- f"sparse tensor to CSR format beforehand to avoid "
77
- f"repeated conversion (got '{src.layout}')")
75
+ warnings.warn(
76
+ f"Converting sparse tensor to CSR format for more "
77
+ f"efficient processing. Consider converting your "
78
+ f"sparse tensor to CSR format beforehand to avoid "
79
+ f"repeated conversion (got '{src.layout}')", stacklevel=2)
78
80
 
79
81
  # Use the default code path for `sum` reduction (works on CPU/GPU):
80
82
  if reduce == 'sum':
@@ -99,10 +101,11 @@ def spmm(
99
101
  # TODO The `torch.sparse.mm` code path with the `reduce` argument does
100
102
  # not yet support CSC :(
101
103
  if src.layout == torch.sparse_csc:
102
- warnings.warn(f"Converting sparse tensor to CSR format for more "
103
- f"efficient processing. Consider converting your "
104
- f"sparse tensor to CSR format beforehand to avoid "
105
- f"repeated conversion (got '{src.layout}')")
104
+ warnings.warn(
105
+ f"Converting sparse tensor to CSR format for more "
106
+ f"efficient processing. Consider converting your "
107
+ f"sparse tensor to CSR format beforehand to avoid "
108
+ f"repeated conversion (got '{src.layout}')", stacklevel=2)
106
109
  src = src.to_sparse_csr()
107
110
 
108
111
  return torch.sparse.mm(src, other, reduce)
@@ -115,8 +118,7 @@ def spmm(
115
118
  if src.layout == torch.sparse_csr:
116
119
  ptr = src.crow_indices()
117
120
  deg = ptr[1:] - ptr[:-1]
118
- elif (torch_geometric.typing.WITH_PT112
119
- and src.layout == torch.sparse_csc):
121
+ elif src.layout == torch.sparse_csc:
120
122
  assert src.layout == torch.sparse_csc
121
123
  ones = torch.ones_like(src.values())
122
124
  index = src.row_indices()
@@ -346,10 +346,12 @@ def k_hop_subgraph(
346
346
 
347
347
  subsets = [node_idx]
348
348
 
349
+ preserved_edge_mask = torch.zeros_like(edge_mask)
349
350
  for _ in range(num_hops):
350
351
  node_mask.fill_(False)
351
352
  node_mask[subsets[-1]] = True
352
353
  torch.index_select(node_mask, 0, row, out=edge_mask)
354
+ preserved_edge_mask |= edge_mask
353
355
  subsets.append(col[edge_mask])
354
356
 
355
357
  subset, inv = torch.cat(subsets).unique(return_inverse=True)
@@ -360,6 +362,8 @@ def k_hop_subgraph(
360
362
 
361
363
  if not directed:
362
364
  edge_mask = node_mask[row] & node_mask[col]
365
+ else:
366
+ edge_mask = preserved_edge_mask
363
367
 
364
368
  edge_index = edge_index[:, edge_mask]
365
369
 
@@ -64,7 +64,7 @@ def tree_decomposition(
64
64
  xs.append(1)
65
65
 
66
66
  # Generate `atom2cliques` mappings.
67
- atom2cliques: List[List[int]] = [[] for i in range(mol.GetNumAtoms())]
67
+ atom2cliques: List[List[int]] = [[] for _ in range(mol.GetNumAtoms())]
68
68
  for c in range(len(cliques)):
69
69
  for atom in cliques[c]:
70
70
  atom2cliques[atom].append(c)
@@ -234,10 +234,10 @@ def trim_sparse_tensor(src: SparseTensor, size: Tuple[int, int],
234
234
  rowptr = torch.narrow(rowptr, 0, 0, size[0] + 1).clone()
235
235
  rowptr[num_seed_nodes + 1:] = rowptr[num_seed_nodes]
236
236
 
237
- col = torch.narrow(col, 0, 0, rowptr[-1])
237
+ col = torch.narrow(col, 0, 0, rowptr[-1]) # type: ignore
238
238
 
239
239
  if value is not None:
240
- value = torch.narrow(value, 0, 0, rowptr[-1])
240
+ value = torch.narrow(value, 0, 0, rowptr[-1]) # type: ignore
241
241
 
242
242
  csr2csc = src.storage._csr2csc
243
243
  if csr2csc is not None:
@@ -12,7 +12,7 @@ def shuffle_node(
12
12
  training: bool = True,
13
13
  ) -> Tuple[Tensor, Tensor]:
14
14
  r"""Randomly shuffle the feature matrix :obj:`x` along the
15
- first dimmension.
15
+ first dimension.
16
16
 
17
17
  The method returns (1) the shuffled :obj:`x`, (2) the permutation
18
18
  indicating the orders of original nodes after shuffling.
@@ -251,13 +251,13 @@ def from_networkx(
251
251
  if group_edge_attrs is not None and not isinstance(group_edge_attrs, list):
252
252
  group_edge_attrs = edge_attrs
253
253
 
254
- for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
254
+ for _, feat_dict in G.nodes(data=True):
255
255
  if set(feat_dict.keys()) != set(node_attrs):
256
256
  raise ValueError('Not all nodes contain the same attributes')
257
257
  for key, value in feat_dict.items():
258
258
  data_dict[str(key)].append(value)
259
259
 
260
- for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
260
+ for _, _, feat_dict in G.edges(data=True):
261
261
  if set(feat_dict.keys()) != set(edge_attrs):
262
262
  raise ValueError('Not all edges contain the same attributes')
263
263
  for key, value in feat_dict.items():
@@ -452,15 +452,22 @@ def to_cugraph(
452
452
  g = cugraph.Graph(directed=directed)
453
453
  df = cudf.from_dlpack(to_dlpack(edge_index.t()))
454
454
 
455
+ df = cudf.DataFrame({
456
+ 'source':
457
+ cudf.from_dlpack(to_dlpack(edge_index[0])),
458
+ 'destination':
459
+ cudf.from_dlpack(to_dlpack(edge_index[1])),
460
+ })
461
+
455
462
  if edge_weight is not None:
456
463
  assert edge_weight.dim() == 1
457
- df['2'] = cudf.from_dlpack(to_dlpack(edge_weight))
464
+ df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight))
458
465
 
459
466
  g.from_cudf_edgelist(
460
467
  df,
461
- source=0,
462
- destination=1,
463
- edge_attr='2' if edge_weight is not None else None,
468
+ source='source',
469
+ destination='destination',
470
+ edge_attr='weight' if edge_weight is not None else None,
464
471
  renumber=relabel_nodes,
465
472
  )
466
473
 
@@ -476,13 +483,13 @@ def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
476
483
  """
477
484
  df = g.view_edge_list()
478
485
 
479
- src = from_dlpack(df[0].to_dlpack()).long()
480
- dst = from_dlpack(df[1].to_dlpack()).long()
486
+ src = from_dlpack(df[g.source_columns].to_dlpack()).long()
487
+ dst = from_dlpack(df[g.destination_columns].to_dlpack()).long()
481
488
  edge_index = torch.stack([src, dst], dim=0)
482
489
 
483
490
  edge_weight = None
484
- if '2' in df:
485
- edge_weight = from_dlpack(df['2'].to_dlpack())
491
+ if g.weight_column is not None:
492
+ edge_weight = from_dlpack(df[g.weight_column].to_dlpack())
486
493
 
487
494
  return edge_index, edge_weight
488
495
 
@@ -18,30 +18,51 @@ class SparseCrossEntropy(torch.autograd.Function):
18
18
  ) -> Tensor:
19
19
  assert inputs.dim() == 2
20
20
 
21
- logsumexp = inputs.logsumexp(dim=-1)
22
- ctx.save_for_backward(inputs, edge_label_index, edge_label_weight,
23
- logsumexp)
21
+ # Support for both positive and negative weights:
22
+ # Positive weights scale the logits *after* softmax.
23
+ # Negative weights scale the denominator *before* softmax:
24
+ pos_y = edge_label_index
25
+ neg_y = pos_weight = neg_weight = None
24
26
 
25
- out = inputs[edge_label_index[0], edge_label_index[1]]
26
- out.neg_().add_(logsumexp[edge_label_index[0]])
27
27
  if edge_label_weight is not None:
28
- out *= edge_label_weight
28
+ pos_mask = edge_label_weight >= 0
29
+ pos_y = edge_label_index[:, pos_mask]
30
+ pos_weight = edge_label_weight[pos_mask]
31
+
32
+ if pos_y.size(1) < edge_label_index.size(1):
33
+ neg_mask = ~pos_mask
34
+ neg_y = edge_label_index[:, neg_mask]
35
+ neg_weight = edge_label_weight[neg_mask]
36
+
37
+ if neg_y is not None and neg_weight is not None:
38
+ inputs = inputs.clone()
39
+ inputs[
40
+ neg_y[0],
41
+ neg_y[1],
42
+ ] += neg_weight.abs().log().clamp(min=1e-12)
43
+
44
+ logsumexp = inputs.logsumexp(dim=-1)
45
+ ctx.save_for_backward(inputs, pos_y, pos_weight, logsumexp)
46
+
47
+ out = inputs[pos_y[0], pos_y[1]]
48
+ out.neg_().add_(logsumexp[pos_y[0]])
49
+ if pos_weight is not None:
50
+ out *= pos_weight
29
51
 
30
52
  return out.sum() / inputs.size(0)
31
53
 
32
54
  @staticmethod
33
55
  @torch.autograd.function.once_differentiable
34
56
  def backward(ctx: Any, grad_out: Tensor) -> Tuple[Tensor, None, None]:
35
- inputs, edge_label_index, edge_label_weight, logsumexp = (
36
- ctx.saved_tensors)
57
+ inputs, pos_y, pos_weight, logsumexp = ctx.saved_tensors
37
58
 
38
59
  grad_out = grad_out / inputs.size(0)
39
- grad_out = grad_out.expand(edge_label_index.size(1))
60
+ grad_out = grad_out.expand(pos_y.size(1))
40
61
 
41
- if edge_label_weight is not None:
42
- grad_out = grad_out * edge_label_weight
62
+ if pos_weight is not None:
63
+ grad_out = grad_out * pos_weight
43
64
 
44
- grad_logsumexp = scatter(grad_out, edge_label_index[0], dim=0,
65
+ grad_logsumexp = scatter(grad_out, pos_y[0], dim=0,
45
66
  dim_size=inputs.size(0), reduce='sum')
46
67
 
47
68
  # Gradient computation of `logsumexp`: `grad * (self - result).exp()`
@@ -49,7 +70,7 @@ class SparseCrossEntropy(torch.autograd.Function):
49
70
  grad_input.exp_()
50
71
  grad_input.mul_(grad_logsumexp.view(-1, 1))
51
72
 
52
- grad_input[edge_label_index[0], edge_label_index[1]] -= grad_out
73
+ grad_input[pos_y[0], pos_y[1]] -= grad_out
53
74
 
54
75
  return grad_input, None, None
55
76