pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic; see the release's page on the package registry for more details.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
|
@@ -53,10 +53,11 @@ from ._negative_sampling import (negative_sampling, batched_negative_sampling,
|
|
|
53
53
|
structured_negative_sampling_feasible)
|
|
54
54
|
from .augmentation import shuffle_node, mask_feature, add_random_edge
|
|
55
55
|
from ._tree_decomposition import tree_decomposition
|
|
56
|
-
from .embedding import get_embeddings
|
|
56
|
+
from .embedding import get_embeddings, get_embeddings_hetero
|
|
57
57
|
from ._trim_to_layer import trim_to_layer
|
|
58
58
|
from .ppr import get_ppr
|
|
59
59
|
from ._train_test_split_edges import train_test_split_edges
|
|
60
|
+
from .influence import total_influence
|
|
60
61
|
|
|
61
62
|
__all__ = [
|
|
62
63
|
'scatter',
|
|
@@ -145,9 +146,11 @@ __all__ = [
|
|
|
145
146
|
'add_random_edge',
|
|
146
147
|
'tree_decomposition',
|
|
147
148
|
'get_embeddings',
|
|
149
|
+
'get_embeddings_hetero',
|
|
148
150
|
'trim_to_layer',
|
|
149
151
|
'get_ppr',
|
|
150
152
|
'train_test_split_edges',
|
|
153
|
+
'total_influence',
|
|
151
154
|
]
|
|
152
155
|
|
|
153
156
|
# `structured_negative_sampling_feasible` is a long name and thus destroys the
|
|
@@ -1,11 +1,7 @@
|
|
|
1
1
|
from typing import List
|
|
2
2
|
|
|
3
|
-
import numpy as np
|
|
4
|
-
import torch
|
|
5
3
|
from torch import Tensor
|
|
6
4
|
|
|
7
|
-
import torch_geometric.typing
|
|
8
|
-
|
|
9
5
|
|
|
10
6
|
def lexsort(
|
|
11
7
|
keys: List[Tensor],
|
|
@@ -28,11 +24,6 @@ def lexsort(
|
|
|
28
24
|
"""
|
|
29
25
|
assert len(keys) >= 1
|
|
30
26
|
|
|
31
|
-
if not torch_geometric.typing.WITH_PT113:
|
|
32
|
-
keys = [k.neg() for k in keys] if descending else keys
|
|
33
|
-
out = np.lexsort([k.detach().cpu().numpy() for k in keys], axis=dim)
|
|
34
|
-
return torch.from_numpy(out).to(keys[0].device)
|
|
35
|
-
|
|
36
27
|
out = keys[0].argsort(dim=dim, descending=descending, stable=True)
|
|
37
28
|
for k in keys[1:]:
|
|
38
29
|
index = k.gather(dim, out)
|
|
@@ -12,7 +12,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
|
|
|
12
12
|
def negative_sampling(
|
|
13
13
|
edge_index: Tensor,
|
|
14
14
|
num_nodes: Optional[Union[int, Tuple[int, int]]] = None,
|
|
15
|
-
num_neg_samples: Optional[int] = None,
|
|
15
|
+
num_neg_samples: Optional[Union[int, float]] = None,
|
|
16
16
|
method: str = "sparse",
|
|
17
17
|
force_undirected: bool = False,
|
|
18
18
|
) -> Tensor:
|
|
@@ -25,10 +25,12 @@ def negative_sampling(
|
|
|
25
25
|
If given as a tuple, then :obj:`edge_index` is interpreted as a
|
|
26
26
|
bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`.
|
|
27
27
|
(default: :obj:`None`)
|
|
28
|
-
num_neg_samples (int, optional): The (approximate) number of
|
|
29
|
-
samples to return.
|
|
30
|
-
|
|
31
|
-
positive
|
|
28
|
+
num_neg_samples (int or float, optional): The (approximate) number of
|
|
29
|
+
negative samples to return. If set to a floating-point value, it
|
|
30
|
+
represents the ratio of negative samples to generate based on the
|
|
31
|
+
number of positive edges. If set to :obj:`None`, will try to
|
|
32
|
+
return a negative edge for every positive edge.
|
|
33
|
+
(default: :obj:`None`)
|
|
32
34
|
method (str, optional): The method to use for negative sampling,
|
|
33
35
|
*i.e.* :obj:`"sparse"` or :obj:`"dense"`.
|
|
34
36
|
This is a memory/runtime trade-off.
|
|
@@ -48,6 +50,11 @@ def negative_sampling(
|
|
|
48
50
|
tensor([[3, 0, 0, 3],
|
|
49
51
|
[2, 3, 2, 1]])
|
|
50
52
|
|
|
53
|
+
>>> negative_sampling(edge_index, num_nodes=(3, 4),
|
|
54
|
+
... num_neg_samples=0.5) # 50% of positive edges
|
|
55
|
+
tensor([[0, 3],
|
|
56
|
+
[3, 0]])
|
|
57
|
+
|
|
51
58
|
>>> # For bipartite graph
|
|
52
59
|
>>> negative_sampling(edge_index, num_nodes=(3, 4))
|
|
53
60
|
tensor([[0, 2, 2, 1],
|
|
@@ -74,6 +81,8 @@ def negative_sampling(
|
|
|
74
81
|
|
|
75
82
|
if num_neg_samples is None:
|
|
76
83
|
num_neg_samples = edge_index.size(1)
|
|
84
|
+
elif isinstance(num_neg_samples, float):
|
|
85
|
+
num_neg_samples = int(num_neg_samples * edge_index.size(1))
|
|
77
86
|
if force_undirected:
|
|
78
87
|
num_neg_samples = num_neg_samples // 2
|
|
79
88
|
|
|
@@ -100,10 +109,9 @@ def negative_sampling(
|
|
|
100
109
|
idx = idx.to('cpu')
|
|
101
110
|
for _ in range(3): # Number of tries to sample negative indices.
|
|
102
111
|
rnd = sample(population, sample_size, device='cpu')
|
|
103
|
-
mask = np.isin(rnd.numpy(), idx.numpy())
|
|
112
|
+
mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool()
|
|
104
113
|
if neg_idx is not None:
|
|
105
|
-
mask |= np.isin(rnd, neg_idx.
|
|
106
|
-
mask = torch.from_numpy(mask).to(torch.bool)
|
|
114
|
+
mask |= torch.from_numpy(np.isin(rnd, neg_idx.cpu())).bool()
|
|
107
115
|
rnd = rnd[~mask].to(edge_index.device)
|
|
108
116
|
neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd])
|
|
109
117
|
if neg_idx.numel() >= num_neg_samples:
|
|
@@ -117,7 +125,7 @@ def negative_sampling(
|
|
|
117
125
|
def batched_negative_sampling(
|
|
118
126
|
edge_index: Tensor,
|
|
119
127
|
batch: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
120
|
-
num_neg_samples: Optional[int] = None,
|
|
128
|
+
num_neg_samples: Optional[Union[int, float]] = None,
|
|
121
129
|
method: str = "sparse",
|
|
122
130
|
force_undirected: bool = False,
|
|
123
131
|
) -> Tensor:
|
|
@@ -131,9 +139,11 @@ def batched_negative_sampling(
|
|
|
131
139
|
node to a specific example.
|
|
132
140
|
If given as a tuple, then :obj:`edge_index` is interpreted as a
|
|
133
141
|
bipartite graph connecting two different node types.
|
|
134
|
-
num_neg_samples (int, optional): The number of negative
|
|
135
|
-
return. If set to :obj:`None`, will try to return a
|
|
136
|
-
for every positive edge.
|
|
142
|
+
num_neg_samples (int or float, optional): The number of negative
|
|
143
|
+
samples to return. If set to :obj:`None`, will try to return a
|
|
144
|
+
negative edge for every positive edge. If float, it will generate
|
|
145
|
+
:obj:`num_neg_samples * num_edges` negative samples.
|
|
146
|
+
(default: :obj:`None`)
|
|
137
147
|
method (str, optional): The method to use for negative sampling,
|
|
138
148
|
*i.e.* :obj:`"sparse"` or :obj:`"dense"`.
|
|
139
149
|
This is a memory/runtime trade-off.
|
|
@@ -157,6 +167,11 @@ def batched_negative_sampling(
|
|
|
157
167
|
tensor([[3, 1, 3, 2, 7, 7, 6, 5],
|
|
158
168
|
[2, 0, 1, 1, 5, 6, 4, 4]])
|
|
159
169
|
|
|
170
|
+
>>> # Using float multiplier for negative samples
|
|
171
|
+
>>> batched_negative_sampling(edge_index, batch, num_neg_samples=1.5)
|
|
172
|
+
tensor([[3, 1, 3, 2, 7, 7, 6, 5, 2, 0, 1, 1],
|
|
173
|
+
[2, 0, 1, 1, 5, 6, 4, 4, 3, 2, 3, 0]])
|
|
174
|
+
|
|
160
175
|
>>> # For bipartite graph
|
|
161
176
|
>>> edge_index1 = torch.as_tensor([[0, 0, 1, 1], [0, 1, 2, 3]])
|
|
162
177
|
>>> edge_index2 = edge_index1 + torch.tensor([[2], [4]])
|
|
@@ -8,185 +8,134 @@ from torch_geometric import is_compiling, is_in_onnx_export, warnings
|
|
|
8
8
|
from torch_geometric.typing import torch_scatter
|
|
9
9
|
from torch_geometric.utils.functions import cumsum
|
|
10
10
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
|
|
14
|
-
|
|
15
|
-
def scatter(
|
|
16
|
-
src: Tensor,
|
|
17
|
-
index: Tensor,
|
|
18
|
-
dim: int = 0,
|
|
19
|
-
dim_size: Optional[int] = None,
|
|
20
|
-
reduce: str = 'sum',
|
|
21
|
-
) -> Tensor:
|
|
22
|
-
r"""Reduces all values from the :obj:`src` tensor at the indices
|
|
23
|
-
specified in the :obj:`index` tensor along a given dimension
|
|
24
|
-
:obj:`dim`. See the `documentation
|
|
25
|
-
<https://pytorch-scatter.readthedocs.io/en/latest/functions/
|
|
26
|
-
scatter.html>`__ of the :obj:`torch_scatter` package for more
|
|
27
|
-
information.
|
|
28
|
-
|
|
29
|
-
Args:
|
|
30
|
-
src (torch.Tensor): The source tensor.
|
|
31
|
-
index (torch.Tensor): The index tensor.
|
|
32
|
-
dim (int, optional): The dimension along which to index.
|
|
33
|
-
(default: :obj:`0`)
|
|
34
|
-
dim_size (int, optional): The size of the output tensor at
|
|
35
|
-
dimension :obj:`dim`. If set to :obj:`None`, will create a
|
|
36
|
-
minimal-sized output tensor according to
|
|
37
|
-
:obj:`index.max() + 1`. (default: :obj:`None`)
|
|
38
|
-
reduce (str, optional): The reduce operation (:obj:`"sum"`,
|
|
39
|
-
:obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`,
|
|
40
|
-
:obj:`"any"`). (default: :obj:`"sum"`)
|
|
41
|
-
"""
|
|
42
|
-
if isinstance(index, Tensor) and index.dim() != 1:
|
|
43
|
-
raise ValueError(f"The `index` argument must be one-dimensional "
|
|
44
|
-
f"(got {index.dim()} dimensions)")
|
|
45
|
-
|
|
46
|
-
dim = src.dim() + dim if dim < 0 else dim
|
|
47
|
-
|
|
48
|
-
if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):
|
|
49
|
-
raise ValueError(f"The `dim` argument must lay between 0 and "
|
|
50
|
-
f"{src.dim() - 1} (got {dim})")
|
|
51
|
-
|
|
52
|
-
if dim_size is None:
|
|
53
|
-
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
|
|
54
|
-
|
|
55
|
-
# For now, we maintain various different code paths, based on whether
|
|
56
|
-
# the input requires gradients and whether it lays on the CPU/GPU.
|
|
57
|
-
# For example, `torch_scatter` is usually faster than
|
|
58
|
-
# `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster
|
|
59
|
-
# on CPU.
|
|
60
|
-
# `torch.scatter_reduce` has a faster forward implementation for
|
|
61
|
-
# "min"/"max" reductions since it does not compute additional arg
|
|
62
|
-
# indices, but is therefore way slower in its backward implementation.
|
|
63
|
-
# More insights can be found in `test/utils/test_scatter.py`.
|
|
64
|
-
|
|
65
|
-
size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
|
|
66
|
-
|
|
67
|
-
# For "any" reduction, we use regular `scatter_`:
|
|
68
|
-
if reduce == 'any':
|
|
69
|
-
index = broadcast(index, src, dim)
|
|
70
|
-
return src.new_zeros(size).scatter_(dim, index, src)
|
|
11
|
+
warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
|
|
71
12
|
|
|
72
|
-
# For "sum" and "mean" reduction, we make use of `scatter_add_`:
|
|
73
|
-
if reduce == 'sum' or reduce == 'add':
|
|
74
|
-
index = broadcast(index, src, dim)
|
|
75
|
-
return src.new_zeros(size).scatter_add_(dim, index, src)
|
|
76
13
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
14
|
+
def scatter(
|
|
15
|
+
src: Tensor,
|
|
16
|
+
index: Tensor,
|
|
17
|
+
dim: int = 0,
|
|
18
|
+
dim_size: Optional[int] = None,
|
|
19
|
+
reduce: str = 'sum',
|
|
20
|
+
) -> Tensor:
|
|
21
|
+
r"""Reduces all values from the :obj:`src` tensor at the indices specified
|
|
22
|
+
in the :obj:`index` tensor along a given dimension ``dim``. See the
|
|
23
|
+
`documentation <https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html>`__ # noqa: E501
|
|
24
|
+
of the ``torch_scatter`` package for more information.
|
|
81
25
|
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
r"""Reduces all values from the :obj:`src` tensor at the indices
|
|
152
|
-
specified in the :obj:`index` tensor along a given dimension
|
|
153
|
-
:obj:`dim`. See the `documentation
|
|
154
|
-
<https://pytorch-scatter.readthedocs.io/en/latest/functions/
|
|
155
|
-
scatter.html>`_ of the :obj:`torch_scatter` package for more
|
|
156
|
-
information.
|
|
157
|
-
|
|
158
|
-
Args:
|
|
159
|
-
src (torch.Tensor): The source tensor.
|
|
160
|
-
index (torch.Tensor): The index tensor.
|
|
161
|
-
dim (int, optional): The dimension along which to index.
|
|
162
|
-
(default: :obj:`0`)
|
|
163
|
-
dim_size (int, optional): The size of the output tensor at
|
|
164
|
-
dimension :obj:`dim`. If set to :obj:`None`, will create a
|
|
165
|
-
minimal-sized output tensor according to
|
|
166
|
-
:obj:`index.max() + 1`. (default: :obj:`None`)
|
|
167
|
-
reduce (str, optional): The reduce operation (:obj:`"sum"`,
|
|
168
|
-
:obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`,
|
|
169
|
-
:obj:`"any"`). (default: :obj:`"sum"`)
|
|
170
|
-
"""
|
|
171
|
-
if reduce == 'any':
|
|
172
|
-
dim = src.dim() + dim if dim < 0 else dim
|
|
173
|
-
|
|
174
|
-
if dim_size is None:
|
|
175
|
-
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
|
|
176
|
-
|
|
177
|
-
size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
|
|
26
|
+
Args:
|
|
27
|
+
src (torch.Tensor): The source tensor.
|
|
28
|
+
index (torch.Tensor): The index tensor.
|
|
29
|
+
dim (int, optional): The dimension along which to index.
|
|
30
|
+
(default: ``0``)
|
|
31
|
+
dim_size (int, optional): The size of the output tensor at dimension
|
|
32
|
+
``dim``. If set to :obj:`None`, will create a minimal-sized output
|
|
33
|
+
tensor according to ``index.max() + 1``. (default: :obj:`None`)
|
|
34
|
+
reduce (str, optional): The reduce operation (``"sum"``, ``"mean"``,
|
|
35
|
+
``"mul"``, ``"min"``, ``"max"`` or ``"any"``). (default: ``"sum"``)
|
|
36
|
+
"""
|
|
37
|
+
if isinstance(index, Tensor) and index.dim() != 1:
|
|
38
|
+
raise ValueError(f"The `index` argument must be one-dimensional "
|
|
39
|
+
f"(got {index.dim()} dimensions)")
|
|
40
|
+
|
|
41
|
+
dim = src.dim() + dim if dim < 0 else dim
|
|
42
|
+
|
|
43
|
+
if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):
|
|
44
|
+
raise ValueError(f"The `dim` argument must lay between 0 and "
|
|
45
|
+
f"{src.dim() - 1} (got {dim})")
|
|
46
|
+
|
|
47
|
+
if dim_size is None:
|
|
48
|
+
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
|
|
49
|
+
|
|
50
|
+
# For now, we maintain various different code paths, based on whether
|
|
51
|
+
# the input requires gradients and whether it lays on the CPU/GPU.
|
|
52
|
+
# For example, `torch_scatter` is usually faster than
|
|
53
|
+
# `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster
|
|
54
|
+
# on CPU.
|
|
55
|
+
# `torch.scatter_reduce` has a faster forward implementation for
|
|
56
|
+
# "min"/"max" reductions since it does not compute additional arg
|
|
57
|
+
# indices, but is therefore way slower in its backward implementation.
|
|
58
|
+
# More insights can be found in `test/utils/test_scatter.py`.
|
|
59
|
+
|
|
60
|
+
size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
|
|
61
|
+
|
|
62
|
+
# For "any" reduction, we use regular `scatter_`:
|
|
63
|
+
if reduce == 'any':
|
|
64
|
+
index = broadcast(index, src, dim)
|
|
65
|
+
return src.new_zeros(size).scatter_(dim, index, src)
|
|
66
|
+
|
|
67
|
+
# For "sum" and "mean" reduction, we make use of `scatter_add_`:
|
|
68
|
+
if reduce == 'sum' or reduce == 'add':
|
|
69
|
+
index = broadcast(index, src, dim)
|
|
70
|
+
return src.new_zeros(size).scatter_add_(dim, index, src)
|
|
71
|
+
|
|
72
|
+
if reduce == 'mean':
|
|
73
|
+
count = src.new_zeros(dim_size)
|
|
74
|
+
count.scatter_add_(0, index, src.new_ones(src.size(dim)))
|
|
75
|
+
count = count.clamp(min=1)
|
|
76
|
+
|
|
77
|
+
index = broadcast(index, src, dim)
|
|
78
|
+
out = src.new_zeros(size).scatter_add_(dim, index, src)
|
|
79
|
+
|
|
80
|
+
return out / broadcast(count, out, dim)
|
|
81
|
+
|
|
82
|
+
# For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or
|
|
83
|
+
# in case the input does not require gradients:
|
|
84
|
+
if reduce in ['min', 'max', 'amin', 'amax']:
|
|
85
|
+
if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()
|
|
86
|
+
or is_in_onnx_export() or not src.is_cuda
|
|
87
|
+
or not src.requires_grad):
|
|
88
|
+
|
|
89
|
+
if (src.is_cuda and src.requires_grad and not is_compiling()
|
|
90
|
+
and not is_in_onnx_export()):
|
|
91
|
+
warnings.warn(
|
|
92
|
+
f"The usage of `scatter(reduce='{reduce}')` "
|
|
93
|
+
f"can be accelerated via the 'torch-scatter'"
|
|
94
|
+
f" package, but it was not found", stacklevel=2)
|
|
178
95
|
|
|
179
96
|
index = broadcast(index, src, dim)
|
|
180
|
-
|
|
97
|
+
if not is_in_onnx_export():
|
|
98
|
+
return src.new_zeros(size).scatter_reduce_(
|
|
99
|
+
dim, index, src, reduce=f'a{reduce[-3:]}',
|
|
100
|
+
include_self=False)
|
|
101
|
+
|
|
102
|
+
fill = torch.full( # type: ignore
|
|
103
|
+
size=(1, ),
|
|
104
|
+
fill_value=src.min() if 'max' in reduce else src.max(),
|
|
105
|
+
dtype=src.dtype,
|
|
106
|
+
device=src.device,
|
|
107
|
+
).expand_as(src)
|
|
108
|
+
out = src.new_zeros(size).scatter_reduce_(dim, index, fill,
|
|
109
|
+
reduce=f'a{reduce[-3:]}',
|
|
110
|
+
include_self=True)
|
|
111
|
+
return out.scatter_reduce_(dim, index, src,
|
|
112
|
+
reduce=f'a{reduce[-3:]}',
|
|
113
|
+
include_self=True)
|
|
181
114
|
|
|
182
|
-
|
|
183
|
-
|
|
115
|
+
return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
|
|
116
|
+
reduce=reduce[-3:])
|
|
117
|
+
|
|
118
|
+
# For "mul" reduction, we prefer `scatter_reduce_` on CPU:
|
|
119
|
+
if reduce == 'mul':
|
|
120
|
+
if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()
|
|
121
|
+
or not src.is_cuda):
|
|
122
|
+
|
|
123
|
+
if src.is_cuda and not is_compiling():
|
|
124
|
+
warnings.warn(
|
|
125
|
+
f"The usage of `scatter(reduce='{reduce}')` "
|
|
126
|
+
f"can be accelerated via the 'torch-scatter'"
|
|
127
|
+
f" package, but it was not found", stacklevel=2)
|
|
184
128
|
|
|
185
|
-
|
|
186
|
-
|
|
129
|
+
index = broadcast(index, src, dim)
|
|
130
|
+
# We initialize with `one` here to match `scatter_mul` output:
|
|
131
|
+
return src.new_ones(size).scatter_reduce_(dim, index, src,
|
|
132
|
+
reduce='prod',
|
|
133
|
+
include_self=True)
|
|
187
134
|
|
|
188
135
|
return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
|
|
189
|
-
reduce=
|
|
136
|
+
reduce='mul')
|
|
137
|
+
|
|
138
|
+
raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'")
|
|
190
139
|
|
|
191
140
|
|
|
192
141
|
def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor:
|
|
@@ -215,24 +164,18 @@ def scatter_argmax(
|
|
|
215
164
|
if dim_size is None:
|
|
216
165
|
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
|
|
217
166
|
|
|
218
|
-
if
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
include_self=False)
|
|
223
|
-
else:
|
|
224
|
-
# `include_self=False` is currently not supported by ONNX:
|
|
225
|
-
res = src.new_full(
|
|
226
|
-
size=(dim_size, ),
|
|
227
|
-
fill_value=src.min(), # type: ignore
|
|
228
|
-
)
|
|
229
|
-
res.scatter_reduce_(0, index, src.detach(), reduce="amax",
|
|
230
|
-
include_self=True)
|
|
231
|
-
elif torch_geometric.typing.WITH_PT111:
|
|
232
|
-
res = torch.scatter_reduce(src.detach(), 0, index, reduce='amax',
|
|
233
|
-
output_size=dim_size) # type: ignore
|
|
167
|
+
if not is_in_onnx_export():
|
|
168
|
+
res = src.new_empty(dim_size)
|
|
169
|
+
res.scatter_reduce_(0, index, src.detach(), reduce='amax',
|
|
170
|
+
include_self=False)
|
|
234
171
|
else:
|
|
235
|
-
|
|
172
|
+
# `include_self=False` is currently not supported by ONNX:
|
|
173
|
+
res = src.new_full(
|
|
174
|
+
size=(dim_size, ),
|
|
175
|
+
fill_value=src.min(), # type: ignore
|
|
176
|
+
)
|
|
177
|
+
res.scatter_reduce_(0, index, src.detach(), reduce="amax",
|
|
178
|
+
include_self=True)
|
|
236
179
|
|
|
237
180
|
out = index.new_full((dim_size, ), fill_value=dim_size - 1)
|
|
238
181
|
nonzero = (src == res[index]).nonzero().view(-1)
|
|
@@ -290,13 +233,7 @@ def group_argsort(
|
|
|
290
233
|
|
|
291
234
|
# Compute `grouped_argsort`:
|
|
292
235
|
src = src - 2 * index if descending else src + 2 * index
|
|
293
|
-
|
|
294
|
-
perm = src.argsort(descending=descending, stable=stable)
|
|
295
|
-
else:
|
|
296
|
-
perm = src.argsort(descending=descending)
|
|
297
|
-
if stable:
|
|
298
|
-
warnings.warn("Ignoring option `stable=True` in 'group_argsort' "
|
|
299
|
-
"since it requires PyTorch >= 1.13.0")
|
|
236
|
+
perm = src.argsort(descending=descending, stable=stable)
|
|
300
237
|
out = torch.empty_like(index)
|
|
301
238
|
out[perm] = torch.arange(index.numel(), device=index.device)
|
|
302
239
|
|
|
@@ -351,5 +288,5 @@ def group_cat(
|
|
|
351
288
|
"""
|
|
352
289
|
assert len(tensors) == len(indices)
|
|
353
290
|
index, perm = torch.cat(indices).sort(stable=True)
|
|
354
|
-
out = torch.cat(tensors, dim=
|
|
291
|
+
out = torch.cat(tensors, dim=dim).index_select(dim, perm)
|
|
355
292
|
return (out, index) if return_index else out
|
|
@@ -107,8 +107,6 @@ def sort_edge_index( # noqa: F811
|
|
|
107
107
|
num_nodes = maybe_num_nodes(edge_index, num_nodes)
|
|
108
108
|
|
|
109
109
|
if num_nodes * num_nodes > torch_geometric.typing.MAX_INT64:
|
|
110
|
-
if not torch_geometric.typing.WITH_PT113:
|
|
111
|
-
raise ValueError("'sort_edge_index' will result in an overflow")
|
|
112
110
|
perm = lexsort(keys=[
|
|
113
111
|
edge_index[int(sort_by_row)],
|
|
114
112
|
edge_index[1 - int(sort_by_row)],
|
torch_geometric/utils/_spmm.py
CHANGED
|
@@ -63,18 +63,20 @@ def spmm(
|
|
|
63
63
|
|
|
64
64
|
# Always convert COO to CSR for more efficient processing:
|
|
65
65
|
if src.layout == torch.sparse_coo:
|
|
66
|
-
warnings.warn(
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
66
|
+
warnings.warn(
|
|
67
|
+
f"Converting sparse tensor to CSR format for more "
|
|
68
|
+
f"efficient processing. Consider converting your "
|
|
69
|
+
f"sparse tensor to CSR format beforehand to avoid "
|
|
70
|
+
f"repeated conversion (got '{src.layout}')", stacklevel=2)
|
|
70
71
|
src = src.to_sparse_csr()
|
|
71
72
|
|
|
72
73
|
# Warn in case of CSC format without gradient computation:
|
|
73
74
|
if src.layout == torch.sparse_csc and not other.requires_grad:
|
|
74
|
-
warnings.warn(
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
75
|
+
warnings.warn(
|
|
76
|
+
f"Converting sparse tensor to CSR format for more "
|
|
77
|
+
f"efficient processing. Consider converting your "
|
|
78
|
+
f"sparse tensor to CSR format beforehand to avoid "
|
|
79
|
+
f"repeated conversion (got '{src.layout}')", stacklevel=2)
|
|
78
80
|
|
|
79
81
|
# Use the default code path for `sum` reduction (works on CPU/GPU):
|
|
80
82
|
if reduce == 'sum':
|
|
@@ -99,10 +101,11 @@ def spmm(
|
|
|
99
101
|
# TODO The `torch.sparse.mm` code path with the `reduce` argument does
|
|
100
102
|
# not yet support CSC :(
|
|
101
103
|
if src.layout == torch.sparse_csc:
|
|
102
|
-
warnings.warn(
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
104
|
+
warnings.warn(
|
|
105
|
+
f"Converting sparse tensor to CSR format for more "
|
|
106
|
+
f"efficient processing. Consider converting your "
|
|
107
|
+
f"sparse tensor to CSR format beforehand to avoid "
|
|
108
|
+
f"repeated conversion (got '{src.layout}')", stacklevel=2)
|
|
106
109
|
src = src.to_sparse_csr()
|
|
107
110
|
|
|
108
111
|
return torch.sparse.mm(src, other, reduce)
|
|
@@ -115,8 +118,7 @@ def spmm(
|
|
|
115
118
|
if src.layout == torch.sparse_csr:
|
|
116
119
|
ptr = src.crow_indices()
|
|
117
120
|
deg = ptr[1:] - ptr[:-1]
|
|
118
|
-
elif
|
|
119
|
-
and src.layout == torch.sparse_csc):
|
|
121
|
+
elif src.layout == torch.sparse_csc:
|
|
120
122
|
assert src.layout == torch.sparse_csc
|
|
121
123
|
ones = torch.ones_like(src.values())
|
|
122
124
|
index = src.row_indices()
|
|
@@ -346,10 +346,12 @@ def k_hop_subgraph(
|
|
|
346
346
|
|
|
347
347
|
subsets = [node_idx]
|
|
348
348
|
|
|
349
|
+
preserved_edge_mask = torch.zeros_like(edge_mask)
|
|
349
350
|
for _ in range(num_hops):
|
|
350
351
|
node_mask.fill_(False)
|
|
351
352
|
node_mask[subsets[-1]] = True
|
|
352
353
|
torch.index_select(node_mask, 0, row, out=edge_mask)
|
|
354
|
+
preserved_edge_mask |= edge_mask
|
|
353
355
|
subsets.append(col[edge_mask])
|
|
354
356
|
|
|
355
357
|
subset, inv = torch.cat(subsets).unique(return_inverse=True)
|
|
@@ -360,6 +362,8 @@ def k_hop_subgraph(
|
|
|
360
362
|
|
|
361
363
|
if not directed:
|
|
362
364
|
edge_mask = node_mask[row] & node_mask[col]
|
|
365
|
+
else:
|
|
366
|
+
edge_mask = preserved_edge_mask
|
|
363
367
|
|
|
364
368
|
edge_index = edge_index[:, edge_mask]
|
|
365
369
|
|
|
@@ -234,10 +234,10 @@ def trim_sparse_tensor(src: SparseTensor, size: Tuple[int, int],
|
|
|
234
234
|
rowptr = torch.narrow(rowptr, 0, 0, size[0] + 1).clone()
|
|
235
235
|
rowptr[num_seed_nodes + 1:] = rowptr[num_seed_nodes]
|
|
236
236
|
|
|
237
|
-
col = torch.narrow(col, 0, 0, rowptr[-1])
|
|
237
|
+
col = torch.narrow(col, 0, 0, rowptr[-1]) # type: ignore
|
|
238
238
|
|
|
239
239
|
if value is not None:
|
|
240
|
-
value = torch.narrow(value, 0, 0, rowptr[-1])
|
|
240
|
+
value = torch.narrow(value, 0, 0, rowptr[-1]) # type: ignore
|
|
241
241
|
|
|
242
242
|
csr2csc = src.storage._csr2csc
|
|
243
243
|
if csr2csc is not None:
|