pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/utils/_scatter.py
CHANGED

@@ -4,174 +4,138 @@ import torch
 from torch import Tensor
 
 import torch_geometric.typing
-from torch_geometric import is_compiling, warnings
+from torch_geometric import is_compiling, is_in_onnx_export, warnings
 from torch_geometric.typing import torch_scatter
 from torch_geometric.utils.functions import cumsum
 
-if torch_geometric.typing.WITH_PT112:  # pragma: no cover
-
-    warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
-
-    def scatter(
-        src: Tensor,
-        index: Tensor,
-        dim: int = 0,
-        dim_size: Optional[int] = None,
-        reduce: str = 'sum',
-    ) -> Tensor:
-        r"""Reduces all values from the :obj:`src` tensor at the indices
-        specified in the :obj:`index` tensor along a given dimension
-        :obj:`dim`. See the `documentation
-        <https://pytorch-scatter.readthedocs.io/en/latest/functions/
-        scatter.html>`__ of the :obj:`torch_scatter` package for more
-        information.
-
-        Args:
-            src (torch.Tensor): The source tensor.
-            index (torch.Tensor): The index tensor.
-            dim (int, optional): The dimension along which to index.
-                (default: :obj:`0`)
-            dim_size (int, optional): The size of the output tensor at
-                dimension :obj:`dim`. If set to :obj:`None`, will create a
-                minimal-sized output tensor according to
-                :obj:`index.max() + 1`. (default: :obj:`None`)
-            reduce (str, optional): The reduce operation (:obj:`"sum"`,
-                :obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`,
-                :obj:`"any"`). (default: :obj:`"sum"`)
-        """
-        if isinstance(index, Tensor) and index.dim() != 1:
-            raise ValueError(f"The `index` argument must be one-dimensional "
-                             f"(got {index.dim()} dimensions)")
-
-        dim = src.dim() + dim if dim < 0 else dim
-
-        if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):
-            raise ValueError(f"The `dim` argument must lay between 0 and "
-                             f"{src.dim() - 1} (got {dim})")
-
-        if dim_size is None:
-            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
-
-        # For now, we maintain various different code paths, based on whether
-        # the input requires gradients and whether it lays on the CPU/GPU.
-        # For example, `torch_scatter` is usually faster than
-        # `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster
-        # on CPU.
-        # `torch.scatter_reduce` has a faster forward implementation for
-        # "min"/"max" reductions since it does not compute additional arg
-        # indices, but is therefore way slower in its backward implementation.
-        # More insights can be found in `test/utils/test_scatter.py`.
-
-        size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
-
-        # For "any" reduction, we use regular `scatter_`:
-        if reduce == 'any':
-            index = broadcast(index, src, dim)
-            return src.new_zeros(size).scatter_(dim, index, src)
-
-        # For "sum" and "mean" reduction, we make use of `scatter_add_`:
-        if reduce == 'sum' or reduce == 'add':
-            index = broadcast(index, src, dim)
-            return src.new_zeros(size).scatter_add_(dim, index, src)
-
-        if reduce == 'mean':
-            count = src.new_zeros(dim_size)
-            count.scatter_add_(0, index, src.new_ones(src.size(dim)))
-            count = count.clamp(min=1)
-
-            index = broadcast(index, src, dim)
-            out = src.new_zeros(size).scatter_add_(dim, index, src)
-
-            return out / broadcast(count, out, dim)
-
-        # For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or
-        # in case the input does not require gradients:
-        if reduce in ['min', 'max', 'amin', 'amax']:
-            if (not torch_geometric.typing.WITH_TORCH_SCATTER
-                    or is_compiling() or not src.is_cuda
-                    or not src.requires_grad):
-
-                if src.is_cuda and src.requires_grad and not is_compiling():
-                    warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
-                                  f"can be accelerated via the 'torch-scatter'"
-                                  f" package, but it was not found")
-
-                index = broadcast(index, src, dim)
-                return src.new_zeros(size).scatter_reduce_(
-                    dim, index, src, reduce=f'a{reduce[-3:]}',
-                    include_self=False)
-
-            return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
-                                         reduce=reduce[-3:])
-
-        # For "mul" reduction, we prefer `scatter_reduce_` on CPU:
-        if reduce == 'mul':
-            if (not torch_geometric.typing.WITH_TORCH_SCATTER
-                    or is_compiling() or not src.is_cuda):
-
-                if src.is_cuda and not is_compiling():
-                    warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
-                                  f"can be accelerated via the 'torch-scatter'"
-                                  f" package, but it was not found")
-
-                index = broadcast(index, src, dim)
-                # We initialize with `one` here to match `scatter_mul` output:
-                return src.new_ones(size).scatter_reduce_(
-                    dim, index, src, reduce='prod', include_self=True)
-
-            return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
-                                         reduce='mul')
-
-        raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'")
-
-else:  # pragma: no cover
-
-    def scatter(
-        src: Tensor,
-        index: Tensor,
-        dim: int = 0,
-        dim_size: Optional[int] = None,
-        reduce: str = 'sum',
-    ) -> Tensor:
-        r"""Reduces all values from the :obj:`src` tensor at the indices
-        specified in the :obj:`index` tensor along a given dimension
-        :obj:`dim`. See the `documentation
-        <https://pytorch-scatter.readthedocs.io/en/latest/functions/
-        scatter.html>`_ of the :obj:`torch_scatter` package for more
-        information.
-
-        Args:
-            src (torch.Tensor): The source tensor.
-            index (torch.Tensor): The index tensor.
-            dim (int, optional): The dimension along which to index.
-                (default: :obj:`0`)
-            dim_size (int, optional): The size of the output tensor at
-                dimension :obj:`dim`. If set to :obj:`None`, will create a
-                minimal-sized output tensor according to
-                :obj:`index.max() + 1`. (default: :obj:`None`)
-            reduce (str, optional): The reduce operation (:obj:`"sum"`,
-                :obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`,
-                :obj:`"any"`). (default: :obj:`"sum"`)
-        """
-        if reduce == 'any':
-            dim = src.dim() + dim if dim < 0 else dim
-
-            if dim_size is None:
-                dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
-
-            size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
-
-            index = broadcast(index, src, dim)
-            return src.new_zeros(size).scatter_(dim, index, src)
-
-        if not torch_geometric.typing.WITH_TORCH_SCATTER:
-            raise ImportError("'scatter' requires the 'torch-scatter' package")
-
-        return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
-                                     reduce=reduce)
+warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
+
+
+def scatter(
+    src: Tensor,
+    index: Tensor,
+    dim: int = 0,
+    dim_size: Optional[int] = None,
+    reduce: str = 'sum',
+) -> Tensor:
+    r"""Reduces all values from the :obj:`src` tensor at the indices specified
+    in the :obj:`index` tensor along a given dimension ``dim``. See the
+    `documentation <https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html>`__  # noqa: E501
+    of the ``torch_scatter`` package for more information.
+
+    Args:
+        src (torch.Tensor): The source tensor.
+        index (torch.Tensor): The index tensor.
+        dim (int, optional): The dimension along which to index.
+            (default: ``0``)
+        dim_size (int, optional): The size of the output tensor at dimension
+            ``dim``. If set to :obj:`None`, will create a minimal-sized output
+            tensor according to ``index.max() + 1``. (default: :obj:`None`)
+        reduce (str, optional): The reduce operation (``"sum"``, ``"mean"``,
+            ``"mul"``, ``"min"``, ``"max"`` or ``"any"``). (default: ``"sum"``)
+    """
+    if isinstance(index, Tensor) and index.dim() != 1:
+        raise ValueError(f"The `index` argument must be one-dimensional "
+                         f"(got {index.dim()} dimensions)")
+
+    dim = src.dim() + dim if dim < 0 else dim
+
+    if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):
+        raise ValueError(f"The `dim` argument must lay between 0 and "
+                         f"{src.dim() - 1} (got {dim})")
+
+    if dim_size is None:
+        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
+
+    # For now, we maintain various different code paths, based on whether
+    # the input requires gradients and whether it lays on the CPU/GPU.
+    # For example, `torch_scatter` is usually faster than
+    # `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster
+    # on CPU.
+    # `torch.scatter_reduce` has a faster forward implementation for
+    # "min"/"max" reductions since it does not compute additional arg
+    # indices, but is therefore way slower in its backward implementation.
+    # More insights can be found in `test/utils/test_scatter.py`.
+
+    size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]
+
+    # For "any" reduction, we use regular `scatter_`:
+    if reduce == 'any':
+        index = broadcast(index, src, dim)
+        return src.new_zeros(size).scatter_(dim, index, src)
+
+    # For "sum" and "mean" reduction, we make use of `scatter_add_`:
+    if reduce == 'sum' or reduce == 'add':
+        index = broadcast(index, src, dim)
+        return src.new_zeros(size).scatter_add_(dim, index, src)
+
+    if reduce == 'mean':
+        count = src.new_zeros(dim_size)
+        count.scatter_add_(0, index, src.new_ones(src.size(dim)))
+        count = count.clamp(min=1)
+
+        index = broadcast(index, src, dim)
+        out = src.new_zeros(size).scatter_add_(dim, index, src)
+
+        return out / broadcast(count, out, dim)
+
+    # For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or
+    # in case the input does not require gradients:
+    if reduce in ['min', 'max', 'amin', 'amax']:
+        if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()
+                or is_in_onnx_export() or not src.is_cuda
+                or not src.requires_grad):
+
+            if (src.is_cuda and src.requires_grad and not is_compiling()
+                    and not is_in_onnx_export()):
+                warnings.warn(
+                    f"The usage of `scatter(reduce='{reduce}')` "
+                    f"can be accelerated via the 'torch-scatter'"
+                    f" package, but it was not found", stacklevel=2)
+
+            index = broadcast(index, src, dim)
+            if not is_in_onnx_export():
+                return src.new_zeros(size).scatter_reduce_(
+                    dim, index, src, reduce=f'a{reduce[-3:]}',
+                    include_self=False)
+
+            fill = torch.full(  # type: ignore
+                size=(1, ),
+                fill_value=src.min() if 'max' in reduce else src.max(),
+                dtype=src.dtype,
+                device=src.device,
+            ).expand_as(src)
+            out = src.new_zeros(size).scatter_reduce_(dim, index, fill,
+                                                      reduce=f'a{reduce[-3:]}',
+                                                      include_self=True)
+            return out.scatter_reduce_(dim, index, src,
+                                       reduce=f'a{reduce[-3:]}',
+                                       include_self=True)
+
+        return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
+                                     reduce=reduce[-3:])
+
+    # For "mul" reduction, we prefer `scatter_reduce_` on CPU:
+    if reduce == 'mul':
+        if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()
+                or not src.is_cuda):
+
+            if src.is_cuda and not is_compiling():
+                warnings.warn(
+                    f"The usage of `scatter(reduce='{reduce}')` "
+                    f"can be accelerated via the 'torch-scatter'"
+                    f" package, but it was not found", stacklevel=2)
+
+            index = broadcast(index, src, dim)
+            # We initialize with `one` here to match `scatter_mul` output:
+            return src.new_ones(size).scatter_reduce_(dim, index, src,
+                                                      reduce='prod',
+                                                      include_self=True)
+
+        return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
+                                     reduce='mul')
+
+    raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'")
 
 
 def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor:
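For reference, a minimal usage sketch of the `scatter` entry point touched above (the example tensors are invented for illustration and are not part of the diff):

    import torch
    from torch_geometric.utils import scatter

    src = torch.randn(6, 4)                   # six rows of features
    index = torch.tensor([0, 0, 1, 1, 2, 2])  # target group per row

    out_sum = scatter(src, index, dim=0, reduce='sum')    # shape: [3, 4]
    out_mean = scatter(src, index, dim=0, reduce='mean')  # shape: [3, 4]
    out_max = scatter(src, index, dim=0, reduce='max')    # shape: [3, 4]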
@@ -187,7 +151,8 @@ def scatter_argmax(
     dim_size: Optional[int] = None,
 ) -> Tensor:
 
-    if torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling():
+    if (torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling()
+            and not is_in_onnx_export()):
         out = torch_scatter.scatter_max(src, index, dim=dim, dim_size=dim_size)
         return out[1]
 
@@ -199,15 +164,18 @@ def scatter_argmax(
     if dim_size is None:
         dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
 
-    if torch_geometric.typing.WITH_PT112:
+    if not is_in_onnx_export():
         res = src.new_empty(dim_size)
         res.scatter_reduce_(0, index, src.detach(), reduce='amax',
                             include_self=False)
-    elif torch_geometric.typing.WITH_PT111:
-        res = torch.scatter_reduce(src.detach(), 0, index, reduce='amax',
-                                   output_size=dim_size)  # type: ignore
     else:
-        raise ValueError("'scatter_argmax' requires PyTorch >= 1.11")
+        # `include_self=False` is currently not supported by ONNX:
+        res = src.new_full(
+            size=(dim_size, ),
+            fill_value=src.min(),  # type: ignore
+        )
+        res.scatter_reduce_(0, index, src.detach(), reduce="amax",
+                            include_self=True)
 
     out = index.new_full((dim_size, ), fill_value=dim_size - 1)
     nonzero = (src == res[index]).nonzero().view(-1)
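The ONNX fallback above replaces `include_self=False` by pre-filling the output so that the seed values can never win the reduction. A self-contained sketch of the same trick using only plain PyTorch ops (values invented for illustration):

    import torch

    src = torch.tensor([3.0, 1.0, 2.0, 5.0])
    index = torch.tensor([0, 0, 1, 1])
    dim_size = 2

    # Default path (not exportable to ONNX):
    ref = src.new_empty(dim_size).scatter_reduce_(
        0, index, src, reduce='amax', include_self=False)

    # ONNX-friendly equivalent: seed with the global minimum so the seed
    # can never win an "amax" reduction, then keep `include_self=True`:
    res = src.new_full((dim_size, ), fill_value=src.min().item())
    res.scatter_reduce_(0, index, src, reduce='amax', include_self=True)

    assert torch.equal(ref, res)  # both are tensor([3., 5.])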
@@ -265,13 +233,7 @@ def group_argsort(
 
     # Compute `grouped_argsort`:
     src = src - 2 * index if descending else src + 2 * index
-    if torch_geometric.typing.WITH_PT113:
-        perm = src.argsort(descending=descending, stable=stable)
-    else:
-        perm = src.argsort(descending=descending)
-        if stable:
-            warnings.warn("Ignoring option `stable=True` in 'group_argsort' "
-                          "since it requires PyTorch >= 1.13.0")
+    perm = src.argsort(descending=descending, stable=stable)
     out = torch.empty_like(index)
     out[perm] = torch.arange(index.numel(), device=index.device)
 
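With the pre-1.13 fallback gone, `group_argsort` always honors `stable=`. A small sketch of its per-group ranking behavior (values invented for illustration):

    import torch
    from torch_geometric.utils import group_argsort

    src = torch.tensor([3.0, 1.0, 2.0, 5.0, 4.0])
    index = torch.tensor([0, 0, 0, 1, 1])

    group_argsort(src, index)  # tensor([2, 0, 1, 1, 0]): rank within group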
@@ -295,7 +257,7 @@ def group_cat(
     r"""Concatenates the given sequence of tensors :obj:`tensors` in the given
     dimension :obj:`dim`.
     Different from :meth:`torch.cat`, values along the concatenating dimension
-    are grouped according to the
+    are grouped according to the indices defined in the :obj:`index` tensors.
     All tensors must have the same shape (except in the concatenating
     dimension).
 
@@ -326,5 +288,5 @@ def group_cat(
     """
     assert len(tensors) == len(indices)
     index, perm = torch.cat(indices).sort(stable=True)
-    out = torch.cat(tensors, dim=dim)[perm]
+    out = torch.cat(tensors, dim=dim).index_select(dim, perm)
     return (out, index) if return_index else out
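The switch to `index_select(dim, perm)` applies the permutation along the requested dimension rather than always along the first one. A usage sketch (tensors invented for illustration):

    import torch
    from torch_geometric.utils import group_cat

    x1, x2 = torch.randn(4, 8), torch.randn(2, 8)
    index1 = torch.tensor([0, 0, 1, 2])
    index2 = torch.tensor([0, 2])

    out, index = group_cat([x1, x2], [index1, index2], dim=0,
                           return_index=True)
    # out has shape [6, 8]; index == tensor([0, 0, 0, 1, 2, 2])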
torch_geometric/utils/_sort_edge_index.py
CHANGED

@@ -107,8 +107,6 @@ def sort_edge_index(  # noqa: F811
     num_nodes = maybe_num_nodes(edge_index, num_nodes)
 
     if num_nodes * num_nodes > torch_geometric.typing.MAX_INT64:
-        if not torch_geometric.typing.WITH_PT113:
-            raise ValueError("'sort_edge_index' will result in an overflow")
         perm = lexsort(keys=[
             edge_index[int(sort_by_row)],
             edge_index[1 - int(sort_by_row)],
torch_geometric/utils/_spmm.py
CHANGED

@@ -63,18 +63,20 @@ def spmm(
 
     # Always convert COO to CSR for more efficient processing:
     if src.layout == torch.sparse_coo:
-        warnings.warn(f"Converting sparse tensor to CSR format for more "
-                      f"efficient processing. Consider converting your "
-                      f"sparse tensor to CSR format beforehand to avoid "
-                      f"repeated conversion (got '{src.layout}')")
+        warnings.warn(
+            f"Converting sparse tensor to CSR format for more "
+            f"efficient processing. Consider converting your "
+            f"sparse tensor to CSR format beforehand to avoid "
+            f"repeated conversion (got '{src.layout}')", stacklevel=2)
         src = src.to_sparse_csr()
 
     # Warn in case of CSC format without gradient computation:
     if src.layout == torch.sparse_csc and not other.requires_grad:
-        warnings.warn(f"Converting sparse tensor to CSR format for more "
-                      f"efficient processing. Consider converting your "
-                      f"sparse tensor to CSR format beforehand to avoid "
-                      f"repeated conversion (got '{src.layout}')")
+        warnings.warn(
+            f"Converting sparse tensor to CSR format for more "
+            f"efficient processing. Consider converting your "
+            f"sparse tensor to CSR format beforehand to avoid "
+            f"repeated conversion (got '{src.layout}')", stacklevel=2)
 
     # Use the default code path for `sum` reduction (works on CPU/GPU):
     if reduce == 'sum':

@@ -99,10 +101,11 @@ def spmm(
     # TODO The `torch.sparse.mm` code path with the `reduce` argument does
     # not yet support CSC :(
     if src.layout == torch.sparse_csc:
-        warnings.warn(f"Converting sparse tensor to CSR format for more "
-                      f"efficient processing. Consider converting your "
-                      f"sparse tensor to CSR format beforehand to avoid "
-                      f"repeated conversion (got '{src.layout}')")
+        warnings.warn(
+            f"Converting sparse tensor to CSR format for more "
+            f"efficient processing. Consider converting your "
+            f"sparse tensor to CSR format beforehand to avoid "
+            f"repeated conversion (got '{src.layout}')", stacklevel=2)
         src = src.to_sparse_csr()
 
     return torch.sparse.mm(src, other, reduce)

@@ -115,8 +118,7 @@ def spmm(
     if src.layout == torch.sparse_csr:
         ptr = src.crow_indices()
         deg = ptr[1:] - ptr[:-1]
-    elif (torch_geometric.typing.WITH_PT112
-          and src.layout == torch.sparse_csc):
+    elif src.layout == torch.sparse_csc:
         assert src.layout == torch.sparse_csc
         ones = torch.ones_like(src.values())
         index = src.row_indices()
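Since `spmm` warns (now with a proper `stacklevel`) on every implicit layout conversion, passing a CSR tensor up front keeps the hot path silent. A minimal sketch (tensors invented for illustration):

    import torch
    from torch_geometric.utils import spmm

    adj = torch.tensor([[1.0, 0.0, 1.0],
                        [0.0, 1.0, 0.0]]).to_sparse_csr()
    x = torch.randn(3, 16)

    out = spmm(adj, x, reduce='sum')  # shape: [2, 16]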
torch_geometric/utils/_subgraph.py
CHANGED

@@ -346,10 +346,12 @@ def k_hop_subgraph(
 
     subsets = [node_idx]
 
+    preserved_edge_mask = torch.zeros_like(edge_mask)
     for _ in range(num_hops):
         node_mask.fill_(False)
         node_mask[subsets[-1]] = True
         torch.index_select(node_mask, 0, row, out=edge_mask)
+        preserved_edge_mask |= edge_mask
         subsets.append(col[edge_mask])
 
     subset, inv = torch.cat(subsets).unique(return_inverse=True)

@@ -360,6 +362,8 @@ def k_hop_subgraph(
 
     if not directed:
         edge_mask = node_mask[row] & node_mask[col]
+    else:
+        edge_mask = preserved_edge_mask
 
     edge_index = edge_index[:, edge_mask]
 
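The new `preserved_edge_mask` accumulates the edges of every hop, so `directed=True` no longer returns only the edges of the final hop. A sketch on an invented chain graph (the default `flow='source_to_target'` walks edges backwards from the seed node):

    import torch
    from torch_geometric.utils import k_hop_subgraph

    # Chain: 0 -> 1 -> 2 -> 3
    edge_index = torch.tensor([[0, 1, 2],
                               [1, 2, 3]])

    subset, sub_edge_index, mapping, edge_mask = k_hop_subgraph(
        3, 2, edge_index, directed=True)
    # subset == tensor([1, 2, 3]); edge_mask now keeps the edges of all
    # traversed hops (1->2 and 2->3), not only those of the last hop.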
torch_geometric/utils/_tree_decomposition.py
CHANGED

@@ -64,7 +64,7 @@ def tree_decomposition(
         xs.append(1)
 
     # Generate `atom2cliques` mappings.
-    atom2cliques: List[List[int]] = [[] for i in range(mol.GetNumAtoms())]
+    atom2cliques: List[List[int]] = [[] for _ in range(mol.GetNumAtoms())]
     for c in range(len(cliques)):
         for atom in cliques[c]:
             atom2cliques[atom].append(c)
torch_geometric/utils/_trim_to_layer.py
CHANGED

@@ -234,10 +234,10 @@ def trim_sparse_tensor(src: SparseTensor, size: Tuple[int, int],
     rowptr = torch.narrow(rowptr, 0, 0, size[0] + 1).clone()
     rowptr[num_seed_nodes + 1:] = rowptr[num_seed_nodes]
 
-    col = torch.narrow(col, 0, 0, rowptr[-1])
+    col = torch.narrow(col, 0, 0, rowptr[-1])  # type: ignore
 
     if value is not None:
-        value = torch.narrow(value, 0, 0, rowptr[-1])
+        value = torch.narrow(value, 0, 0, rowptr[-1])  # type: ignore
 
     csr2csc = src.storage._csr2csc
     if csr2csc is not None:
torch_geometric/utils/augmentation.py
CHANGED

@@ -12,7 +12,7 @@ def shuffle_node(
     training: bool = True,
 ) -> Tuple[Tensor, Tensor]:
     r"""Randomly shuffle the feature matrix :obj:`x` along the
-    first
+    first dimension.
 
     The method returns (1) the shuffled :obj:`x`, (2) the permutation
     indicating the orders of original nodes after shuffling.
torch_geometric/utils/convert.py
CHANGED

@@ -251,13 +251,13 @@ def from_networkx(
     if group_edge_attrs is not None and not isinstance(group_edge_attrs, list):
         group_edge_attrs = edge_attrs
 
-    for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
+    for _, feat_dict in G.nodes(data=True):
         if set(feat_dict.keys()) != set(node_attrs):
             raise ValueError('Not all nodes contain the same attributes')
         for key, value in feat_dict.items():
             data_dict[str(key)].append(value)
 
-    for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
+    for _, _, feat_dict in G.edges(data=True):
         if set(feat_dict.keys()) != set(edge_attrs):
             raise ValueError('Not all edges contain the same attributes')
         for key, value in feat_dict.items():
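The loops above now iterate without the unused `enumerate` counter; behavior is unchanged. For reference, a minimal conversion sketch (graph invented for illustration):

    import networkx as nx
    from torch_geometric.utils import from_networkx

    G = nx.Graph()
    G.add_node(0, x=[1.0, 0.0])
    G.add_node(1, x=[0.0, 1.0])
    G.add_edge(0, 1, weight=0.5)

    data = from_networkx(G, group_node_attrs=['x'],
                         group_edge_attrs=['weight'])
    # data.x: [2, 2] node features; data.edge_attr: [2, 1] edge features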
@@ -452,15 +452,22 @@ def to_cugraph(
     g = cugraph.Graph(directed=directed)
     df = cudf.from_dlpack(to_dlpack(edge_index.t()))
 
+    df = cudf.DataFrame({
+        'source':
+        cudf.from_dlpack(to_dlpack(edge_index[0])),
+        'destination':
+        cudf.from_dlpack(to_dlpack(edge_index[1])),
+    })
+
     if edge_weight is not None:
         assert edge_weight.dim() == 1
-        df['2'] = cudf.from_dlpack(to_dlpack(edge_weight))
+        df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight))
 
     g.from_cudf_edgelist(
         df,
-        source=0,
-        destination=1,
-        edge_attr='2' if edge_weight is not None else None,
+        source='source',
+        destination='destination',
+        edge_attr='weight' if edge_weight is not None else None,
         renumber=relabel_nodes,
     )
 

@@ -476,13 +483,13 @@ def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
     """
     df = g.view_edge_list()
 
-    src = from_dlpack(df['src'].to_dlpack()).long()
-    dst = from_dlpack(df['dst'].to_dlpack()).long()
+    src = from_dlpack(df[g.source_columns].to_dlpack()).long()
+    dst = from_dlpack(df[g.destination_columns].to_dlpack()).long()
     edge_index = torch.stack([src, dst], dim=0)
 
     edge_weight = None
-    if 'weights' in df:
-        edge_weight = from_dlpack(df['weights'].to_dlpack())
+    if g.weight_column is not None:
+        edge_weight = from_dlpack(df[g.weight_column].to_dlpack())
 
     return edge_index, edge_weight
 
torch_geometric/utils/cross_entropy.py
CHANGED

@@ -18,30 +18,51 @@ class SparseCrossEntropy(torch.autograd.Function):
     ) -> Tensor:
         assert inputs.dim() == 2
 
-        logsumexp = inputs.logsumexp(dim=-1)
-        ctx.save_for_backward(inputs, edge_label_index, edge_label_weight,
-                              logsumexp)
+        # Support for both positive and negative weights:
+        # Positive weights scale the logits *after* softmax.
+        # Negative weights scale the denominator *before* softmax:
+        pos_y = edge_label_index
+        neg_y = pos_weight = neg_weight = None
 
-        out = inputs[edge_label_index[0], edge_label_index[1]]
-        out.neg_().add_(logsumexp[edge_label_index[0]])
         if edge_label_weight is not None:
-            out *= edge_label_weight
+            pos_mask = edge_label_weight >= 0
+            pos_y = edge_label_index[:, pos_mask]
+            pos_weight = edge_label_weight[pos_mask]
+
+            if pos_y.size(1) < edge_label_index.size(1):
+                neg_mask = ~pos_mask
+                neg_y = edge_label_index[:, neg_mask]
+                neg_weight = edge_label_weight[neg_mask]
+
+        if neg_y is not None and neg_weight is not None:
+            inputs = inputs.clone()
+            inputs[
+                neg_y[0],
+                neg_y[1],
+            ] += neg_weight.abs().log().clamp(min=1e-12)
+
+        logsumexp = inputs.logsumexp(dim=-1)
+        ctx.save_for_backward(inputs, pos_y, pos_weight, logsumexp)
+
+        out = inputs[pos_y[0], pos_y[1]]
+        out.neg_().add_(logsumexp[pos_y[0]])
+        if pos_weight is not None:
+            out *= pos_weight
 
         return out.sum() / inputs.size(0)
 
     @staticmethod
     @torch.autograd.function.once_differentiable
     def backward(ctx: Any, grad_out: Tensor) -> Tuple[Tensor, None, None]:
-        inputs, edge_label_index, edge_label_weight, logsumexp = (
-            ctx.saved_tensors)
+        inputs, pos_y, pos_weight, logsumexp = ctx.saved_tensors
 
         grad_out = grad_out / inputs.size(0)
-        grad_out = grad_out.expand(edge_label_index.size(1))
+        grad_out = grad_out.expand(pos_y.size(1))
 
-        if edge_label_weight is not None:
-            grad_out = grad_out * edge_label_weight
+        if pos_weight is not None:
+            grad_out = grad_out * pos_weight
 
-        grad_logsumexp = scatter(grad_out, edge_label_index[0], dim=0,
+        grad_logsumexp = scatter(grad_out, pos_y[0], dim=0,
                                  dim_size=inputs.size(0), reduce='sum')
 
         # Gradient computation of `logsumexp`: `grad * (self - result).exp()`

@@ -49,7 +70,7 @@ class SparseCrossEntropy(torch.autograd.Function):
         grad_input.exp_()
         grad_input.mul_(grad_logsumexp.view(-1, 1))
 
-        grad_input[edge_label_index[0], edge_label_index[1]] -= grad_out
+        grad_input[pos_y[0], pos_y[1]] -= grad_out
 
         return grad_input, None, None
 
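A sketch of the weighted behavior through the module's functional wrapper (assuming `sparse_cross_entropy` from the same file; tensors invented for illustration). Positive weights scale the per-label loss, while negative weights now rescale the softmax denominator as implemented above:

    import torch
    from torch_geometric.utils.cross_entropy import sparse_cross_entropy

    inputs = torch.randn(3, 5, requires_grad=True)  # [num_rows, num_classes]
    edge_label_index = torch.tensor([[0, 1, 2],     # (row, class) pairs
                                     [4, 0, 3]])
    edge_label_weight = torch.tensor([1.0, 2.0, -0.5])

    loss = sparse_cross_entropy(inputs, edge_label_index, edge_label_weight)
    loss.backward()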