pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
torch_geometric/utils/__init__.py
@@ -53,10 +53,11 @@ from ._negative_sampling import (negative_sampling, batched_negative_sampling,
                                  structured_negative_sampling_feasible)
 from .augmentation import shuffle_node, mask_feature, add_random_edge
 from ._tree_decomposition import tree_decomposition
-from .embedding import get_embeddings
+from .embedding import get_embeddings, get_embeddings_hetero
 from ._trim_to_layer import trim_to_layer
 from .ppr import get_ppr
 from ._train_test_split_edges import train_test_split_edges
+from .influence import total_influence
 
 __all__ = [
     'scatter',
@@ -145,9 +146,11 @@ __all__ = [
     'add_random_edge',
     'tree_decomposition',
     'get_embeddings',
+    'get_embeddings_hetero',
     'trim_to_layer',
     'get_ppr',
     'train_test_split_edges',
+    'total_influence',
 ]
 
 # `structured_negative_sampling_feasible` is a long name and thus destroys the
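
The hunk above exports two new utilities, `get_embeddings_hetero` and `total_influence`. A hedged sketch of the established `get_embeddings` call style, which collects intermediate `MessagePassing` outputs via forward hooks; the heterogeneous variant presumably mirrors it, but its exact signature and that of `total_influence` should be checked against the release documentation.

```python
import torch
from torch_geometric.nn import GCN
from torch_geometric.utils import get_embeddings

model = GCN(in_channels=16, hidden_channels=32, num_layers=2)
x = torch.randn(10, 16)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

# Runs one forward pass and captures the output of every MessagePassing
# layer via temporarily registered hooks:
embs = get_embeddings(model, x, edge_index)
print([e.shape for e in embs])  # one tensor per GCN layer
```
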
torch_geometric/utils/_lexsort.py
@@ -1,11 +1,7 @@
 from typing import List
 
-import numpy as np
-import torch
 from torch import Tensor
 
-import torch_geometric.typing
-
 
 def lexsort(
     keys: List[Tensor],
@@ -28,11 +24,6 @@ def lexsort(
     """
     assert len(keys) >= 1
 
-    if not torch_geometric.typing.WITH_PT113:
-        keys = [k.neg() for k in keys] if descending else keys
-        out = np.lexsort([k.detach().cpu().numpy() for k in keys], axis=dim)
-        return torch.from_numpy(out).to(keys[0].device)
-
     out = keys[0].argsort(dim=dim, descending=descending, stable=True)
     for k in keys[1:]:
         index = k.gather(dim, out)
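
With the NumPy fallback removed, `lexsort` always takes the pure-PyTorch path shown above: a chain of stable argsorts applied so that, as with `np.lexsort`, the last key in the list is the primary sort key. A minimal sanity check:

```python
import torch
from torch_geometric.utils import lexsort

secondary = torch.tensor([3, 2, 1, 0])  # breaks ties
primary = torch.tensor([1, 0, 1, 0])    # last key = primary sort key

perm = lexsort([secondary, primary])
print(perm)  # tensor([3, 1, 2, 0])
```
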
torch_geometric/utils/_negative_sampling.py
@@ -12,7 +12,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
 def negative_sampling(
     edge_index: Tensor,
     num_nodes: Optional[Union[int, Tuple[int, int]]] = None,
-    num_neg_samples: Optional[int] = None,
+    num_neg_samples: Optional[Union[int, float]] = None,
     method: str = "sparse",
     force_undirected: bool = False,
 ) -> Tensor:
@@ -25,10 +25,12 @@ def negative_sampling(
             If given as a tuple, then :obj:`edge_index` is interpreted as a
             bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`.
             (default: :obj:`None`)
-        num_neg_samples (int, optional): The (approximate) number of negative
-            samples to return.
-            If set to :obj:`None`, will try to return a negative edge for every
-            positive edge. (default: :obj:`None`)
+        num_neg_samples (int or float, optional): The (approximate) number of
+            negative samples to return. If set to a floating-point value, it
+            represents the ratio of negative samples to generate based on the
+            number of positive edges. If set to :obj:`None`, will try to
+            return a negative edge for every positive edge.
+            (default: :obj:`None`)
         method (str, optional): The method to use for negative sampling,
             *i.e.* :obj:`"sparse"` or :obj:`"dense"`.
             This is a memory/runtime trade-off.
@@ -48,6 +50,11 @@ def negative_sampling(
         tensor([[3, 0, 0, 3],
                 [2, 3, 2, 1]])
 
+        >>> negative_sampling(edge_index, num_nodes=(3, 4),
+        ...                   num_neg_samples=0.5)  # 50% of positive edges
+        tensor([[0, 3],
+                [3, 0]])
+
         >>> # For bipartite graph
         >>> negative_sampling(edge_index, num_nodes=(3, 4))
         tensor([[0, 2, 2, 1],
@@ -74,6 +81,8 @@ def negative_sampling(
 
     if num_neg_samples is None:
         num_neg_samples = edge_index.size(1)
+    elif isinstance(num_neg_samples, float):
+        num_neg_samples = int(num_neg_samples * edge_index.size(1))
     if force_undirected:
         num_neg_samples = num_neg_samples // 2
 
@@ -100,10 +109,9 @@ def negative_sampling(
         idx = idx.to('cpu')
         for _ in range(3):  # Number of tries to sample negative indices.
             rnd = sample(population, sample_size, device='cpu')
-            mask = np.isin(rnd.numpy(), idx.numpy())  # type: ignore
+            mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool()
             if neg_idx is not None:
-                mask |= np.isin(rnd, neg_idx.to('cpu'))
-            mask = torch.from_numpy(mask).to(torch.bool)
+                mask |= torch.from_numpy(np.isin(rnd, neg_idx.cpu())).bool()
             rnd = rnd[~mask].to(edge_index.device)
             neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd])
             if neg_idx.numel() >= num_neg_samples:
@@ -117,7 +125,7 @@ def negative_sampling(
 def batched_negative_sampling(
     edge_index: Tensor,
     batch: Union[Tensor, Tuple[Tensor, Tensor]],
-    num_neg_samples: Optional[int] = None,
+    num_neg_samples: Optional[Union[int, float]] = None,
     method: str = "sparse",
     force_undirected: bool = False,
 ) -> Tensor:
@@ -131,9 +139,11 @@ def batched_negative_sampling(
             node to a specific example.
             If given as a tuple, then :obj:`edge_index` is interpreted as a
             bipartite graph connecting two different node types.
-        num_neg_samples (int, optional): The number of negative samples to
-            return. If set to :obj:`None`, will try to return a negative edge
-            for every positive edge. (default: :obj:`None`)
+        num_neg_samples (int or float, optional): The number of negative
+            samples to return. If set to :obj:`None`, will try to return a
+            negative edge for every positive edge. If float, it will generate
+            :obj:`num_neg_samples * num_edges` negative samples.
+            (default: :obj:`None`)
         method (str, optional): The method to use for negative sampling,
             *i.e.* :obj:`"sparse"` or :obj:`"dense"`.
             This is a memory/runtime trade-off.
@@ -157,6 +167,11 @@ def batched_negative_sampling(
         tensor([[3, 1, 3, 2, 7, 7, 6, 5],
                 [2, 0, 1, 1, 5, 6, 4, 4]])
 
+        >>> # Using float multiplier for negative samples
+        >>> batched_negative_sampling(edge_index, batch, num_neg_samples=1.5)
+        tensor([[3, 1, 3, 2, 7, 7, 6, 5, 2, 0, 1, 1],
+                [2, 0, 1, 1, 5, 6, 4, 4, 3, 2, 3, 0]])
+
         >>> # For bipartite graph
         >>> edge_index1 = torch.as_tensor([[0, 0, 1, 1], [0, 1, 2, 3]])
         >>> edge_index2 = edge_index1 + torch.tensor([[2], [4]])
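
A quick sketch of the new float-valued `num_neg_samples`: instead of an absolute count, it scales the number of negatives by the number of positive edges (sampling is random, so the exact edges vary between runs):

```python
import torch
from torch_geometric.utils import negative_sampling

edge_index = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 3]])  # 4 positive edges

# int(0.5 * 4) = 2 negative edges are requested:
neg = negative_sampling(edge_index, num_nodes=4, num_neg_samples=0.5)
print(neg.size(1))  # 2
```
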
torch_geometric/utils/_scatter.py
@@ -8,185 +8,134 @@ from torch_geometric import is_compiling, is_in_onnx_export, warnings
 from torch_geometric.typing import torch_scatter
 from torch_geometric.utils.functions import cumsum
 
-if torch_geometric.typing.WITH_PT112:  # pragma: no cover
-
-    warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
-
-    def scatter(
-        src: Tensor,
-        index: Tensor,
-        dim: int = 0,
-        dim_size: Optional[int] = None,
-        reduce: str = 'sum',
-    ) -> Tensor:
-        r"""Reduces all values from the :obj:`src` tensor at the indices
-        specified in the :obj:`index` tensor along a given dimension
-        :obj:`dim`. See the `documentation
-        <https://pytorch-scatter.readthedocs.io/en/latest/functions/
-        scatter.html>`__ of the :obj:`torch_scatter` package for more
-        information.
-
-        Args:
-            src (torch.Tensor): The source tensor.
-            index (torch.Tensor): The index tensor.
-            dim (int, optional): The dimension along which to index.
-                (default: :obj:`0`)
-            dim_size (int, optional): The size of the output tensor at
-                dimension :obj:`dim`. If set to :obj:`None`, will create a
-                minimal-sized output tensor according to
-                :obj:`index.max() + 1`. (default: :obj:`None`)
-            reduce (str, optional): The reduce operation (:obj:`"sum"`,
-                :obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`,
-                :obj:`"any"`). (default: :obj:`"sum"`)
-        """
-        if isinstance(index, Tensor) and index.dim() != 1:
-            raise ValueError(f"The `index` argument must be one-dimensional "
-                             f"(got {index.dim()} dimensions)")
-
-        dim = src.dim() + dim if dim < 0 else dim
-
-        if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):
-            raise ValueError(f"The `dim` argument must lay between 0 and "
-                             f"{src.dim() - 1} (got {dim})")
-
-        if dim_size is None:
-            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
-
-        # For now, we maintain various different code paths, based on whether
-        # the input requires gradients and whether it lays on the CPU/GPU.
-        # For example, `torch_scatter` is usually faster than
-        # `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster
-        # on CPU.
-        # `torch.scatter_reduce` has a faster forward implementation for
-        # "min"/"max" reductions since it does not compute additional arg
-        # indices, but is therefore way slower in its backward implementation.
-        # More insights can be found in `test/utils/test_scatter.py`.
-
-        size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
-
-        # For "any" reduction, we use regular `scatter_`:
-        if reduce == 'any':
-            index = broadcast(index, src, dim)
-            return src.new_zeros(size).scatter_(dim, index, src)
+warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
 
-        # For "sum" and "mean" reduction, we make use of `scatter_add_`:
-        if reduce == 'sum' or reduce == 'add':
-            index = broadcast(index, src, dim)
-            return src.new_zeros(size).scatter_add_(dim, index, src)
 
-        if reduce == 'mean':
-            count = src.new_zeros(dim_size)
-            count.scatter_add_(0, index, src.new_ones(src.size(dim)))
-            count = count.clamp(min=1)
+def scatter(
+    src: Tensor,
+    index: Tensor,
+    dim: int = 0,
+    dim_size: Optional[int] = None,
+    reduce: str = 'sum',
+) -> Tensor:
+    r"""Reduces all values from the :obj:`src` tensor at the indices specified
+    in the :obj:`index` tensor along a given dimension ``dim``. See the
+    `documentation <https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html>`__  # noqa: E501
+    of the ``torch_scatter`` package for more information.
 
-            index = broadcast(index, src, dim)
-            out = src.new_zeros(size).scatter_add_(dim, index, src)
-
-            return out / broadcast(count, out, dim)
-
-        # For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or
-        # in case the input does not require gradients:
-        if reduce in ['min', 'max', 'amin', 'amax']:
-            if (not torch_geometric.typing.WITH_TORCH_SCATTER
-                    or is_compiling() or is_in_onnx_export() or not src.is_cuda
-                    or not src.requires_grad):
-
-                if (src.is_cuda and src.requires_grad and not is_compiling()
-                        and not is_in_onnx_export()):
-                    warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
-                                  f"can be accelerated via the 'torch-scatter'"
-                                  f" package, but it was not found")
-
-                index = broadcast(index, src, dim)
-                if not is_in_onnx_export():
-                    return src.new_zeros(size).scatter_reduce_(
-                        dim, index, src, reduce=f'a{reduce[-3:]}',
-                        include_self=False)
-
-                fill = torch.full(  # type: ignore
-                    size=(1, ),
-                    fill_value=src.min() if 'max' in reduce else src.max(),
-                    dtype=src.dtype,
-                    device=src.device,
-                ).expand_as(src)
-                out = src.new_zeros(size).scatter_reduce_(
-                    dim, index, fill, reduce=f'a{reduce[-3:]}',
-                    include_self=True)
-                return out.scatter_reduce_(dim, index, src,
-                                           reduce=f'a{reduce[-3:]}',
-                                           include_self=True)
-
-            return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
-                                         reduce=reduce[-3:])
-
-        # For "mul" reduction, we prefer `scatter_reduce_` on CPU:
-        if reduce == 'mul':
-            if (not torch_geometric.typing.WITH_TORCH_SCATTER
-                    or is_compiling() or not src.is_cuda):
-
-                if src.is_cuda and not is_compiling():
-                    warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
-                                  f"can be accelerated via the 'torch-scatter'"
-                                  f" package, but it was not found")
-
-                index = broadcast(index, src, dim)
-                # We initialize with `one` here to match `scatter_mul` output:
-                return src.new_ones(size).scatter_reduce_(
-                    dim, index, src, reduce='prod', include_self=True)
-
-            return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
-                                         reduce='mul')
-
-        raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'")
-
-    else:  # pragma: no cover
-
-    def scatter(
-        src: Tensor,
-        index: Tensor,
-        dim: int = 0,
-        dim_size: Optional[int] = None,
-        reduce: str = 'sum',
-    ) -> Tensor:
-        r"""Reduces all values from the :obj:`src` tensor at the indices
-        specified in the :obj:`index` tensor along a given dimension
-        :obj:`dim`. See the `documentation
-        <https://pytorch-scatter.readthedocs.io/en/latest/functions/
-        scatter.html>`_ of the :obj:`torch_scatter` package for more
-        information.
-
-        Args:
-            src (torch.Tensor): The source tensor.
-            index (torch.Tensor): The index tensor.
-            dim (int, optional): The dimension along which to index.
-                (default: :obj:`0`)
-            dim_size (int, optional): The size of the output tensor at
-                dimension :obj:`dim`. If set to :obj:`None`, will create a
-                minimal-sized output tensor according to
-                :obj:`index.max() + 1`. (default: :obj:`None`)
-            reduce (str, optional): The reduce operation (:obj:`"sum"`,
-                :obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`,
-                :obj:`"any"`). (default: :obj:`"sum"`)
-        """
-        if reduce == 'any':
-            dim = src.dim() + dim if dim < 0 else dim
-
-            if dim_size is None:
-                dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
-
-            size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
+    Args:
+        src (torch.Tensor): The source tensor.
+        index (torch.Tensor): The index tensor.
+        dim (int, optional): The dimension along which to index.
+            (default: ``0``)
+        dim_size (int, optional): The size of the output tensor at dimension
+            ``dim``. If set to :obj:`None`, will create a minimal-sized output
+            tensor according to ``index.max() + 1``. (default: :obj:`None`)
+        reduce (str, optional): The reduce operation (``"sum"``, ``"mean"``,
+            ``"mul"``, ``"min"``, ``"max"`` or ``"any"``). (default: ``"sum"``)
+    """
+    if isinstance(index, Tensor) and index.dim() != 1:
+        raise ValueError(f"The `index` argument must be one-dimensional "
+                         f"(got {index.dim()} dimensions)")
+
+    dim = src.dim() + dim if dim < 0 else dim
+
+    if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):
+        raise ValueError(f"The `dim` argument must lay between 0 and "
+                         f"{src.dim() - 1} (got {dim})")
+
+    if dim_size is None:
+        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
+
+    # For now, we maintain various different code paths, based on whether
+    # the input requires gradients and whether it lays on the CPU/GPU.
+    # For example, `torch_scatter` is usually faster than
+    # `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster
+    # on CPU.
+    # `torch.scatter_reduce` has a faster forward implementation for
+    # "min"/"max" reductions since it does not compute additional arg
+    # indices, but is therefore way slower in its backward implementation.
+    # More insights can be found in `test/utils/test_scatter.py`.
+
+    size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
+
+    # For "any" reduction, we use regular `scatter_`:
+    if reduce == 'any':
+        index = broadcast(index, src, dim)
+        return src.new_zeros(size).scatter_(dim, index, src)
+
+    # For "sum" and "mean" reduction, we make use of `scatter_add_`:
+    if reduce == 'sum' or reduce == 'add':
+        index = broadcast(index, src, dim)
+        return src.new_zeros(size).scatter_add_(dim, index, src)
+
+    if reduce == 'mean':
+        count = src.new_zeros(dim_size)
+        count.scatter_add_(0, index, src.new_ones(src.size(dim)))
+        count = count.clamp(min=1)
+
+        index = broadcast(index, src, dim)
+        out = src.new_zeros(size).scatter_add_(dim, index, src)
+
+        return out / broadcast(count, out, dim)
+
+    # For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or
+    # in case the input does not require gradients:
+    if reduce in ['min', 'max', 'amin', 'amax']:
+        if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()
+                or is_in_onnx_export() or not src.is_cuda
+                or not src.requires_grad):
+
+            if (src.is_cuda and src.requires_grad and not is_compiling()
+                    and not is_in_onnx_export()):
+                warnings.warn(
+                    f"The usage of `scatter(reduce='{reduce}')` "
+                    f"can be accelerated via the 'torch-scatter'"
+                    f" package, but it was not found", stacklevel=2)
 
             index = broadcast(index, src, dim)
-            return src.new_zeros(size).scatter_(dim, index, src)
+            if not is_in_onnx_export():
+                return src.new_zeros(size).scatter_reduce_(
+                    dim, index, src, reduce=f'a{reduce[-3:]}',
+                    include_self=False)
+
+            fill = torch.full(  # type: ignore
+                size=(1, ),
+                fill_value=src.min() if 'max' in reduce else src.max(),
+                dtype=src.dtype,
+                device=src.device,
+            ).expand_as(src)
+            out = src.new_zeros(size).scatter_reduce_(dim, index, fill,
+                                                      reduce=f'a{reduce[-3:]}',
+                                                      include_self=True)
+            return out.scatter_reduce_(dim, index, src,
+                                       reduce=f'a{reduce[-3:]}',
+                                       include_self=True)
 
-        if not torch_geometric.typing.WITH_TORCH_SCATTER:
-            raise ImportError("'scatter' requires the 'torch-scatter' package")
+        return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
+                                     reduce=reduce[-3:])
+
+    # For "mul" reduction, we prefer `scatter_reduce_` on CPU:
+    if reduce == 'mul':
+        if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()
+                or not src.is_cuda):
+
+            if src.is_cuda and not is_compiling():
+                warnings.warn(
+                    f"The usage of `scatter(reduce='{reduce}')` "
+                    f"can be accelerated via the 'torch-scatter'"
+                    f" package, but it was not found", stacklevel=2)
 
-        if reduce == 'amin' or reduce == 'amax':
-            reduce = reduce[-3:]
+            index = broadcast(index, src, dim)
+            # We initialize with `one` here to match `scatter_mul` output:
+            return src.new_ones(size).scatter_reduce_(dim, index, src,
+                                                      reduce='prod',
+                                                      include_self=True)
 
         return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
-                                     reduce=reduce)
+                                     reduce='mul')
+
+    raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'")
 
 
 def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor:
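
The version guard is gone, but the public call surface of `scatter` is unchanged. A short demonstration of the documented reduce modes:

```python
import torch
from torch_geometric.utils import scatter

src = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 0, 1, 1])  # groups: {1, 2} and {3, 4}

print(scatter(src, index, dim=0, dim_size=2, reduce='sum'))   # tensor([3., 7.])
print(scatter(src, index, dim=0, dim_size=2, reduce='mean'))  # tensor([1.5000, 3.5000])
print(scatter(src, index, dim=0, dim_size=2, reduce='max'))   # tensor([2., 4.])
```
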
@@ -215,24 +164,18 @@ def scatter_argmax(
     if dim_size is None:
         dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
 
-    if torch_geometric.typing.WITH_PT112:
-        if not is_in_onnx_export():
-            res = src.new_empty(dim_size)
-            res.scatter_reduce_(0, index, src.detach(), reduce='amax',
-                                include_self=False)
-        else:
-            # `include_self=False` is currently not supported by ONNX:
-            res = src.new_full(
-                size=(dim_size, ),
-                fill_value=src.min(),  # type: ignore
-            )
-            res.scatter_reduce_(0, index, src.detach(), reduce="amax",
-                                include_self=True)
-    elif torch_geometric.typing.WITH_PT111:
-        res = torch.scatter_reduce(src.detach(), 0, index, reduce='amax',
-                                   output_size=dim_size)  # type: ignore
+    if not is_in_onnx_export():
+        res = src.new_empty(dim_size)
+        res.scatter_reduce_(0, index, src.detach(), reduce='amax',
+                            include_self=False)
     else:
-        raise ValueError("'scatter_argmax' requires PyTorch >= 1.11")
+        # `include_self=False` is currently not supported by ONNX:
+        res = src.new_full(
+            size=(dim_size, ),
+            fill_value=src.min(),  # type: ignore
+        )
+        res.scatter_reduce_(0, index, src.detach(), reduce="amax",
+                            include_self=True)
 
     out = index.new_full((dim_size, ), fill_value=dim_size - 1)
     nonzero = (src == res[index]).nonzero().view(-1)
@@ -290,13 +233,7 @@ def group_argsort(
 
     # Compute `grouped_argsort`:
     src = src - 2 * index if descending else src + 2 * index
-    if torch_geometric.typing.WITH_PT113:
-        perm = src.argsort(descending=descending, stable=stable)
-    else:
-        perm = src.argsort(descending=descending)
-        if stable:
-            warnings.warn("Ignoring option `stable=True` in 'group_argsort' "
-                          "since it requires PyTorch >= 1.13.0")
+    perm = src.argsort(descending=descending, stable=stable)
     out = torch.empty_like(index)
    out[perm] = torch.arange(index.numel(), device=index.device)
 
@@ -351,5 +288,5 @@ def group_cat(
     """
     assert len(tensors) == len(indices)
     index, perm = torch.cat(indices).sort(stable=True)
-    out = torch.cat(tensors, dim=0)[perm]
+    out = torch.cat(tensors, dim=dim).index_select(dim, perm)
     return (out, index) if return_index else out
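
The `group_cat` fix threads `dim` through both the concatenation and the permutation; previously the permutation indexed dimension 0 regardless of `dim`. A sketch of the (default `dim=0`) behavior:

```python
import torch
from torch_geometric.utils import group_cat

x1 = torch.tensor([[0.2, 0.7], [0.2, 0.5]])
x2 = torch.tensor([[0.3, 0.6]])
index1 = torch.tensor([0, 1])
index2 = torch.tensor([0])

# Rows are merged so that entries sharing a group index stay adjacent:
out = group_cat([x1, x2], [index1, index2], dim=0)
print(out)  # rows ordered by group: x1[0], x2[0], x1[1]
```
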
torch_geometric/utils/_sort_edge_index.py
@@ -107,8 +107,6 @@ def sort_edge_index(  # noqa: F811
     num_nodes = maybe_num_nodes(edge_index, num_nodes)
 
     if num_nodes * num_nodes > torch_geometric.typing.MAX_INT64:
-        if not torch_geometric.typing.WITH_PT113:
-            raise ValueError("'sort_edge_index' will result in an overflow")
         perm = lexsort(keys=[
             edge_index[int(sort_by_row)],
             edge_index[1 - int(sort_by_row)],
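
With the guard removed, the `lexsort`-based path is taken unconditionally whenever `num_nodes * num_nodes` would overflow a fused 64-bit sort key; for ordinary graph sizes, `sort_edge_index` behaves as before:

```python
import torch
from torch_geometric.utils import sort_edge_index

edge_index = torch.tensor([[2, 1, 0], [0, 2, 1]])
print(sort_edge_index(edge_index))
# tensor([[0, 1, 2],
#         [1, 2, 0]])
```
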
torch_geometric/utils/_spmm.py
@@ -63,18 +63,20 @@ def spmm(
 
     # Always convert COO to CSR for more efficient processing:
     if src.layout == torch.sparse_coo:
-        warnings.warn(f"Converting sparse tensor to CSR format for more "
-                      f"efficient processing. Consider converting your "
-                      f"sparse tensor to CSR format beforehand to avoid "
-                      f"repeated conversion (got '{src.layout}')")
+        warnings.warn(
+            f"Converting sparse tensor to CSR format for more "
+            f"efficient processing. Consider converting your "
+            f"sparse tensor to CSR format beforehand to avoid "
+            f"repeated conversion (got '{src.layout}')", stacklevel=2)
         src = src.to_sparse_csr()
 
     # Warn in case of CSC format without gradient computation:
     if src.layout == torch.sparse_csc and not other.requires_grad:
-        warnings.warn(f"Converting sparse tensor to CSR format for more "
-                      f"efficient processing. Consider converting your "
-                      f"sparse tensor to CSR format beforehand to avoid "
-                      f"repeated conversion (got '{src.layout}')")
+        warnings.warn(
+            f"Converting sparse tensor to CSR format for more "
+            f"efficient processing. Consider converting your "
+            f"sparse tensor to CSR format beforehand to avoid "
+            f"repeated conversion (got '{src.layout}')", stacklevel=2)
 
     # Use the default code path for `sum` reduction (works on CPU/GPU):
     if reduce == 'sum':
@@ -99,10 +101,11 @@ def spmm(
         # TODO The `torch.sparse.mm` code path with the `reduce` argument does
         # not yet support CSC :(
         if src.layout == torch.sparse_csc:
-            warnings.warn(f"Converting sparse tensor to CSR format for more "
-                          f"efficient processing. Consider converting your "
-                          f"sparse tensor to CSR format beforehand to avoid "
-                          f"repeated conversion (got '{src.layout}')")
+            warnings.warn(
+                f"Converting sparse tensor to CSR format for more "
+                f"efficient processing. Consider converting your "
+                f"sparse tensor to CSR format beforehand to avoid "
+                f"repeated conversion (got '{src.layout}')", stacklevel=2)
             src = src.to_sparse_csr()
 
         return torch.sparse.mm(src, other, reduce)
@@ -115,8 +118,7 @@ def spmm(
     if src.layout == torch.sparse_csr:
         ptr = src.crow_indices()
         deg = ptr[1:] - ptr[:-1]
-    elif (torch_geometric.typing.WITH_PT112
-          and src.layout == torch.sparse_csc):
+    elif src.layout == torch.sparse_csc:
         assert src.layout == torch.sparse_csc
         ones = torch.ones_like(src.values())
         index = src.row_indices()
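
All three conversion warnings now pass `stacklevel=2`, so they point at the user's call site rather than at `spmm` internals. Converting to CSR once up front avoids the warnings entirely; a minimal sketch:

```python
import torch
from torch_geometric.utils import spmm

src = torch.eye(3).to_sparse_csr()  # CSR up front: no conversion warning
other = torch.randn(3, 4)

out = spmm(src, other, reduce='sum')
print(out.shape)  # torch.Size([3, 4])
```
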
torch_geometric/utils/_subgraph.py
@@ -346,10 +346,12 @@ def k_hop_subgraph(
 
     subsets = [node_idx]
 
+    preserved_edge_mask = torch.zeros_like(edge_mask)
     for _ in range(num_hops):
         node_mask.fill_(False)
         node_mask[subsets[-1]] = True
         torch.index_select(node_mask, 0, row, out=edge_mask)
+        preserved_edge_mask |= edge_mask
         subsets.append(col[edge_mask])
 
     subset, inv = torch.cat(subsets).unique(return_inverse=True)
@@ -360,6 +362,8 @@ def k_hop_subgraph(
 
     if not directed:
         edge_mask = node_mask[row] & node_mask[col]
+    else:
+        edge_mask = preserved_edge_mask
 
     edge_index = edge_index[:, edge_mask]
 
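
With `preserved_edge_mask`, `directed=True` now returns every edge traversed across all hops instead of only the mask left over from the final iteration. A small sketch on a directed chain (assuming the default `flow='source_to_target'`, which walks edges pointing into the seed node):

```python
import torch
from torch_geometric.utils import k_hop_subgraph

edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # chain 0->1->2->3->4

subset, sub_edge_index, mapping, edge_mask = k_hop_subgraph(
    2, num_hops=2, edge_index=edge_index, directed=True)
print(sub_edge_index)  # both traversed edges 0->1 and 1->2, not just the last hop
```
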
torch_geometric/utils/_to_dense_batch.py
@@ -123,8 +123,8 @@ def to_dense_batch(
         x, idx = x[mask], idx[mask]
 
     size = [batch_size * max_num_nodes] + list(x.size())[1:]
-    out = torch.as_tensor(fill_value, device=x.device)
-    out = out.to(x.dtype).repeat(size)
+    out = torch.as_tensor(fill_value, device=x.device, dtype=x.dtype)
+    out = out.repeat(size)
     out[idx] = x
     out = out.view([batch_size, max_num_nodes] + list(x.size())[1:])
 
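
Functionally equivalent, but `fill_value` is now cast through `dtype=` at construction time instead of a follow-up `.to()` call. A quick check of the padding behavior:

```python
import torch
from torch_geometric.utils import to_dense_batch

x = torch.arange(6.0).view(3, 2)
batch = torch.tensor([0, 0, 1])  # two graphs with 2 and 1 nodes

out, mask = to_dense_batch(x, batch, fill_value=-1.0)
print(out.shape)  # torch.Size([2, 2, 2]); padded slots hold -1.0
```
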
torch_geometric/utils/_trim_to_layer.py
@@ -234,10 +234,10 @@ def trim_sparse_tensor(src: SparseTensor, size: Tuple[int, int],
     rowptr = torch.narrow(rowptr, 0, 0, size[0] + 1).clone()
     rowptr[num_seed_nodes + 1:] = rowptr[num_seed_nodes]
 
-    col = torch.narrow(col, 0, 0, rowptr[-1])
+    col = torch.narrow(col, 0, 0, rowptr[-1])  # type: ignore
 
     if value is not None:
-        value = torch.narrow(value, 0, 0, rowptr[-1])
+        value = torch.narrow(value, 0, 0, rowptr[-1])  # type: ignore
 
     csr2csc = src.storage._csr2csc
     if csr2csc is not None: