pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
from
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
2
3
|
|
|
3
4
|
import torch
|
|
4
5
|
from torch import Tensor
|
|
@@ -14,7 +15,143 @@ except Exception:
|
|
|
14
15
|
BaseMetric = torch.nn.Module # type: ignore
|
|
15
16
|
|
|
16
17
|
|
|
17
|
-
|
|
18
|
+
@dataclass(repr=False)
|
|
19
|
+
class LinkPredMetricData:
|
|
20
|
+
pred_index_mat: Tensor
|
|
21
|
+
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]
|
|
22
|
+
edge_label_weight: Optional[Tensor] = None
|
|
23
|
+
|
|
24
|
+
def __post_init__(self) -> None:
|
|
25
|
+
# Filter all negative weights - they should not be used as ground-truth
|
|
26
|
+
if self.edge_label_weight is not None:
|
|
27
|
+
pos_mask = self.edge_label_weight > 0
|
|
28
|
+
self.edge_label_weight = self.edge_label_weight[pos_mask]
|
|
29
|
+
if isinstance(self.edge_label_index, Tensor):
|
|
30
|
+
self.edge_label_index = self.edge_label_index[:, pos_mask]
|
|
31
|
+
else:
|
|
32
|
+
self.edge_label_index = (
|
|
33
|
+
self.edge_label_index[0][pos_mask],
|
|
34
|
+
self.edge_label_index[1][pos_mask],
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
@property
|
|
38
|
+
def pred_rel_mat(self) -> Tensor:
|
|
39
|
+
r"""Returns a matrix indicating the relevance of the `k`-th prediction.
|
|
40
|
+
If :obj:`edge_label_weight` is not given, relevance will be denoted as
|
|
41
|
+
binary.
|
|
42
|
+
"""
|
|
43
|
+
if hasattr(self, '_pred_rel_mat'):
|
|
44
|
+
return self._pred_rel_mat # type: ignore
|
|
45
|
+
|
|
46
|
+
if self.edge_label_index[1].numel() == 0:
|
|
47
|
+
self._pred_rel_mat = torch.zeros_like(
|
|
48
|
+
self.pred_index_mat,
|
|
49
|
+
dtype=torch.bool if self.edge_label_weight is None else
|
|
50
|
+
torch.get_default_dtype(),
|
|
51
|
+
)
|
|
52
|
+
return self._pred_rel_mat
|
|
53
|
+
|
|
54
|
+
# Flatten both prediction and ground-truth indices, and determine
|
|
55
|
+
# overlaps afterwards via `torch.searchsorted`.
|
|
56
|
+
max_index = max(
|
|
57
|
+
self.pred_index_mat.max()
|
|
58
|
+
if self.pred_index_mat.numel() > 0 else 0,
|
|
59
|
+
self.edge_label_index[1].max()
|
|
60
|
+
if self.edge_label_index[1].numel() > 0 else 0,
|
|
61
|
+
) + 1
|
|
62
|
+
arange = torch.arange(
|
|
63
|
+
start=0,
|
|
64
|
+
end=max_index * self.pred_index_mat.size(0), # type: ignore
|
|
65
|
+
step=max_index, # type: ignore
|
|
66
|
+
device=self.pred_index_mat.device,
|
|
67
|
+
).view(-1, 1)
|
|
68
|
+
flat_pred_index = (self.pred_index_mat + arange).view(-1)
|
|
69
|
+
flat_label_index = max_index * self.edge_label_index[0]
|
|
70
|
+
flat_label_index = flat_label_index + self.edge_label_index[1]
|
|
71
|
+
flat_label_index, perm = flat_label_index.sort()
|
|
72
|
+
edge_label_weight = self.edge_label_weight
|
|
73
|
+
if edge_label_weight is not None:
|
|
74
|
+
assert edge_label_weight.size() == self.edge_label_index[0].size()
|
|
75
|
+
edge_label_weight = edge_label_weight[perm]
|
|
76
|
+
|
|
77
|
+
pos = torch.searchsorted(flat_label_index, flat_pred_index)
|
|
78
|
+
pos = pos.clamp(max=flat_label_index.size(0) - 1) # Out-of-bounds.
|
|
79
|
+
|
|
80
|
+
pred_rel_mat = flat_label_index[pos] == flat_pred_index # Find matches
|
|
81
|
+
if edge_label_weight is not None:
|
|
82
|
+
pred_rel_mat = edge_label_weight[pos].where(
|
|
83
|
+
pred_rel_mat,
|
|
84
|
+
pred_rel_mat.new_zeros(1),
|
|
85
|
+
)
|
|
86
|
+
pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size())
|
|
87
|
+
|
|
88
|
+
self._pred_rel_mat = pred_rel_mat
|
|
89
|
+
return pred_rel_mat
|
|
90
|
+
|
|
91
|
+
@property
|
|
92
|
+
def label_count(self) -> Tensor:
|
|
93
|
+
r"""The number of ground-truth labels for every example."""
|
|
94
|
+
if hasattr(self, '_label_count'):
|
|
95
|
+
return self._label_count # type: ignore
|
|
96
|
+
|
|
97
|
+
label_count = scatter(
|
|
98
|
+
torch.ones_like(self.edge_label_index[0]),
|
|
99
|
+
self.edge_label_index[0],
|
|
100
|
+
dim=0,
|
|
101
|
+
dim_size=self.pred_index_mat.size(0),
|
|
102
|
+
reduce='sum',
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
self._label_count = label_count
|
|
106
|
+
return label_count
|
|
107
|
+
|
|
108
|
+
@property
|
|
109
|
+
def label_weight_sum(self) -> Tensor:
|
|
110
|
+
r"""The sum of edge label weights for every example."""
|
|
111
|
+
if self.edge_label_weight is None:
|
|
112
|
+
return self.label_count
|
|
113
|
+
|
|
114
|
+
if hasattr(self, '_label_weight_sum'):
|
|
115
|
+
return self._label_weight_sum # type: ignore
|
|
116
|
+
|
|
117
|
+
label_weight_sum = scatter(
|
|
118
|
+
self.edge_label_weight,
|
|
119
|
+
self.edge_label_index[0],
|
|
120
|
+
dim=0,
|
|
121
|
+
dim_size=self.pred_index_mat.size(0),
|
|
122
|
+
reduce='sum',
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
self._label_weight_sum = label_weight_sum
|
|
126
|
+
return label_weight_sum
|
|
127
|
+
|
|
128
|
+
@property
|
|
129
|
+
def edge_label_weight_pos(self) -> Optional[Tensor]:
|
|
130
|
+
r"""Returns the position of edge label weights in descending order
|
|
131
|
+
within example-wise buckets.
|
|
132
|
+
"""
|
|
133
|
+
if self.edge_label_weight is None:
|
|
134
|
+
return None
|
|
135
|
+
|
|
136
|
+
if hasattr(self, '_edge_label_weight_pos'):
|
|
137
|
+
return self._edge_label_weight_pos # type: ignore
|
|
138
|
+
|
|
139
|
+
# Get the permutation via two sorts: One globally on the weights,
|
|
140
|
+
# followed by a (stable) sort on the example indices.
|
|
141
|
+
perm1 = self.edge_label_weight.argsort(descending=True)
|
|
142
|
+
perm2 = self.edge_label_index[0][perm1].argsort(stable=True)
|
|
143
|
+
perm = perm1[perm2]
|
|
144
|
+
# Invert the permutation to get the final position:
|
|
145
|
+
pos = torch.empty_like(perm)
|
|
146
|
+
pos[perm] = torch.arange(perm.size(0), device=perm.device)
|
|
147
|
+
# Normalize position to zero within all buckets:
|
|
148
|
+
pos = pos - cumsum(self.label_count)[self.edge_label_index[0]]
|
|
149
|
+
|
|
150
|
+
self._edge_label_weight_pos = pos
|
|
151
|
+
return pos
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class _LinkPredMetric(BaseMetric):
|
|
18
155
|
r"""An abstract class for computing link prediction retrieval metrics.
|
|
19
156
|
|
|
20
157
|
Args:
|
|
@@ -33,20 +170,11 @@ class LinkPredMetric(BaseMetric):
|
|
|
33
170
|
|
|
34
171
|
self.k = k
|
|
35
172
|
|
|
36
|
-
self.accum: Tensor
|
|
37
|
-
self.total: Tensor
|
|
38
|
-
|
|
39
|
-
if WITH_TORCHMETRICS:
|
|
40
|
-
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
|
|
41
|
-
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
|
|
42
|
-
else:
|
|
43
|
-
self.register_buffer('accum', torch.tensor(0.))
|
|
44
|
-
self.register_buffer('total', torch.tensor(0))
|
|
45
|
-
|
|
46
173
|
def update(
|
|
47
174
|
self,
|
|
48
175
|
pred_index_mat: Tensor,
|
|
49
176
|
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
177
|
+
edge_label_weight: Optional[Tensor] = None,
|
|
50
178
|
) -> None:
|
|
51
179
|
r"""Updates the state variables based on the current mini-batch
|
|
52
180
|
prediction.
|
|
@@ -62,99 +190,293 @@ class LinkPredMetric(BaseMetric):
|
|
|
62
190
|
edge_label_index (torch.Tensor): The ground-truth indices for every
|
|
63
191
|
example in the mini-batch, given in COO format of shape
|
|
64
192
|
:obj:`[2, num_ground_truth_indices]`.
|
|
193
|
+
edge_label_weight (torch.Tensor, optional): The weight of the
|
|
194
|
+
ground-truth indices for every example in the mini-batch of
|
|
195
|
+
shape :obj:`[num_ground_truth_indices]`. If given, needs to be
|
|
196
|
+
a vector of positive values. Required for weighted metrics,
|
|
197
|
+
ignored otherwise. (default: :obj:`None`)
|
|
65
198
|
"""
|
|
66
|
-
|
|
67
|
-
raise ValueError(f"Expected 'pred_index_mat' to hold {self.k} "
|
|
68
|
-
f"many indices for every entry "
|
|
69
|
-
f"(got {pred_index_mat.size(1)})")
|
|
70
|
-
|
|
71
|
-
# Compute a boolean matrix indicating if the k-th prediction is part of
|
|
72
|
-
# the ground-truth. We do this by flattening both prediction and
|
|
73
|
-
# target indices, and then determining overlaps via `torch.isin`.
|
|
74
|
-
max_index = max( # type: ignore
|
|
75
|
-
pred_index_mat.max() if pred_index_mat.numel() > 0 else 0,
|
|
76
|
-
edge_label_index[1].max()
|
|
77
|
-
if edge_label_index[1].numel() > 0 else 0,
|
|
78
|
-
) + 1
|
|
79
|
-
arange = torch.arange(
|
|
80
|
-
start=0,
|
|
81
|
-
end=max_index * pred_index_mat.size(0),
|
|
82
|
-
step=max_index,
|
|
83
|
-
device=pred_index_mat.device,
|
|
84
|
-
).view(-1, 1)
|
|
85
|
-
flat_pred_index = (pred_index_mat + arange).view(-1)
|
|
86
|
-
flat_y_index = max_index * edge_label_index[0] + edge_label_index[1]
|
|
199
|
+
raise NotImplementedError
|
|
87
200
|
|
|
88
|
-
|
|
89
|
-
|
|
201
|
+
def compute(self) -> Tensor:
|
|
202
|
+
r"""Computes the final metric value."""
|
|
203
|
+
raise NotImplementedError
|
|
90
204
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
205
|
+
def reset(self) -> None:
|
|
206
|
+
r"""Resets metric state variables to their default value."""
|
|
207
|
+
if WITH_TORCHMETRICS:
|
|
208
|
+
super().reset()
|
|
209
|
+
else:
|
|
210
|
+
self._reset()
|
|
211
|
+
|
|
212
|
+
def _reset(self) -> None:
|
|
213
|
+
raise NotImplementedError
|
|
214
|
+
|
|
215
|
+
def __repr__(self) -> str:
|
|
216
|
+
return f'{self.__class__.__name__}(k={self.k})'
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class LinkPredMetric(_LinkPredMetric):
|
|
220
|
+
r"""An abstract class for computing link prediction retrieval metrics.
|
|
221
|
+
|
|
222
|
+
Args:
|
|
223
|
+
k (int): The number of top-:math:`k` predictions to evaluate against.
|
|
224
|
+
"""
|
|
225
|
+
weighted: bool
|
|
226
|
+
|
|
227
|
+
def __init__(self, k: int) -> None:
|
|
228
|
+
super().__init__(k)
|
|
229
|
+
|
|
230
|
+
self.accum: Tensor
|
|
231
|
+
self.total: Tensor
|
|
232
|
+
|
|
233
|
+
if WITH_TORCHMETRICS:
|
|
234
|
+
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
|
|
235
|
+
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
|
|
236
|
+
else:
|
|
237
|
+
self.register_buffer('accum', torch.tensor(0.), persistent=False)
|
|
238
|
+
self.register_buffer('total', torch.tensor(0), persistent=False)
|
|
239
|
+
|
|
240
|
+
def update(
|
|
241
|
+
self,
|
|
242
|
+
pred_index_mat: Tensor,
|
|
243
|
+
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
244
|
+
edge_label_weight: Optional[Tensor] = None,
|
|
245
|
+
) -> None:
|
|
246
|
+
if self.weighted and edge_label_weight is None:
|
|
247
|
+
raise ValueError(f"'edge_label_weight' is a required argument for "
|
|
248
|
+
f"weighted '{self.__class__.__name__}' metrics")
|
|
249
|
+
if not self.weighted:
|
|
250
|
+
edge_label_weight = None
|
|
251
|
+
|
|
252
|
+
data = LinkPredMetricData(
|
|
253
|
+
pred_index_mat=pred_index_mat,
|
|
254
|
+
edge_label_index=edge_label_index,
|
|
255
|
+
edge_label_weight=edge_label_weight,
|
|
98
256
|
)
|
|
257
|
+
self._update(data)
|
|
99
258
|
|
|
100
|
-
|
|
259
|
+
def _update(self, data: LinkPredMetricData) -> None:
|
|
260
|
+
metric = self._compute(data)
|
|
101
261
|
|
|
102
262
|
self.accum += metric.sum()
|
|
103
|
-
self.total += (
|
|
263
|
+
self.total += (data.label_count > 0).sum()
|
|
104
264
|
|
|
105
265
|
def compute(self) -> Tensor:
|
|
106
|
-
r"""Computes the final metric value."""
|
|
107
266
|
if self.total == 0:
|
|
108
267
|
return torch.zeros_like(self.accum)
|
|
109
268
|
return self.accum / self.total
|
|
110
269
|
|
|
111
|
-
def
|
|
112
|
-
r"""
|
|
113
|
-
if WITH_TORCHMETRICS:
|
|
114
|
-
super().reset()
|
|
115
|
-
else:
|
|
116
|
-
self.accum.zero_()
|
|
117
|
-
self.total.zero_()
|
|
118
|
-
|
|
119
|
-
def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
|
|
120
|
-
r"""Compute the specific metric.
|
|
270
|
+
def _compute(self, data: LinkPredMetricData) -> Tensor:
|
|
271
|
+
r"""Computes the specific metric.
|
|
121
272
|
To be implemented separately for each metric class.
|
|
122
273
|
|
|
123
274
|
Args:
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
:obj:`i`-th example is correct or not.
|
|
127
|
-
y_count (torch.Tensor): A vector indicating the number of
|
|
128
|
-
ground-truth labels for each example.
|
|
275
|
+
data (LinkPredMetricData): The mini-batch data for computing a link
|
|
276
|
+
prediction metric per example.
|
|
129
277
|
"""
|
|
130
278
|
raise NotImplementedError
|
|
131
279
|
|
|
280
|
+
def _reset(self) -> None:
|
|
281
|
+
self.accum.zero_()
|
|
282
|
+
self.total.zero_()
|
|
283
|
+
|
|
132
284
|
def __repr__(self) -> str:
|
|
133
|
-
|
|
285
|
+
weighted_repr = ', weighted=True' if self.weighted else ''
|
|
286
|
+
return f'{self.__class__.__name__}(k={self.k}{weighted_repr})'
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class LinkPredMetricCollection(torch.nn.ModuleDict):
|
|
290
|
+
r"""A collection of metrics to reduce and speed-up computation of link
|
|
291
|
+
prediction metrics.
|
|
292
|
+
|
|
293
|
+
.. code-block:: python
|
|
294
|
+
|
|
295
|
+
from torch_geometric.metrics import (
|
|
296
|
+
LinkPredMAP,
|
|
297
|
+
LinkPredMetricCollection,
|
|
298
|
+
LinkPredPrecision,
|
|
299
|
+
LinkPredRecall,
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
metrics = LinkPredMetricCollection([
|
|
303
|
+
LinkPredMAP(k=10),
|
|
304
|
+
LinkPredPrecision(k=100),
|
|
305
|
+
LinkPredRecall(k=50),
|
|
306
|
+
])
|
|
307
|
+
|
|
308
|
+
metrics.update(pred_index_mat, edge_label_index)
|
|
309
|
+
out = metrics.compute()
|
|
310
|
+
metrics.reset()
|
|
311
|
+
|
|
312
|
+
print(out)
|
|
313
|
+
>>> {'LinkPredMAP@10': tensor(0.375),
|
|
314
|
+
... 'LinkPredPrecision@100': tensor(0.127),
|
|
315
|
+
... 'LinkPredRecall@50': tensor(0.483)}
|
|
316
|
+
|
|
317
|
+
Args:
|
|
318
|
+
metrics: The link prediction metrics.
|
|
319
|
+
"""
|
|
320
|
+
def __init__(
|
|
321
|
+
self,
|
|
322
|
+
metrics: Union[
|
|
323
|
+
List[LinkPredMetric],
|
|
324
|
+
Dict[str, LinkPredMetric],
|
|
325
|
+
],
|
|
326
|
+
) -> None:
|
|
327
|
+
super().__init__()
|
|
328
|
+
|
|
329
|
+
if isinstance(metrics, (list, tuple)):
|
|
330
|
+
metrics = {
|
|
331
|
+
(f'{"Weighted" if getattr(metric, "weighted", False) else ""}'
|
|
332
|
+
f'{metric.__class__.__name__}@{metric.k}'):
|
|
333
|
+
metric
|
|
334
|
+
for metric in metrics
|
|
335
|
+
}
|
|
336
|
+
assert len(metrics) > 0
|
|
337
|
+
assert isinstance(metrics, dict)
|
|
338
|
+
|
|
339
|
+
for name, metric in metrics.items():
|
|
340
|
+
assert isinstance(metric, _LinkPredMetric)
|
|
341
|
+
self[name] = metric
|
|
342
|
+
|
|
343
|
+
@property
|
|
344
|
+
def max_k(self) -> int:
|
|
345
|
+
r"""The maximum number of top-:math:`k` predictions to evaluate
|
|
346
|
+
against.
|
|
347
|
+
"""
|
|
348
|
+
return max([
|
|
349
|
+
metric.k # type: ignore[return-value]
|
|
350
|
+
for metric in self.values()
|
|
351
|
+
]) # type: ignore[type-var]
|
|
352
|
+
|
|
353
|
+
@property
|
|
354
|
+
def weighted(self) -> bool:
|
|
355
|
+
r"""Returns :obj:`True` in case the collection holds at least one
|
|
356
|
+
weighted link prediction metric.
|
|
357
|
+
"""
|
|
358
|
+
return any(
|
|
359
|
+
[getattr(metric, 'weighted', False) for metric in self.values()])
|
|
360
|
+
|
|
361
|
+
def update( # type: ignore
|
|
362
|
+
self,
|
|
363
|
+
pred_index_mat: Tensor,
|
|
364
|
+
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
365
|
+
edge_label_weight: Optional[Tensor] = None,
|
|
366
|
+
) -> None:
|
|
367
|
+
r"""Updates the state variables based on the current mini-batch
|
|
368
|
+
prediction.
|
|
369
|
+
|
|
370
|
+
:meth:`update` can be repeated multiple times to accumulate the results
|
|
371
|
+
of successive predictions, *e.g.*, inside a mini-batch training or
|
|
372
|
+
evaluation loop.
|
|
373
|
+
|
|
374
|
+
Args:
|
|
375
|
+
pred_index_mat (torch.Tensor): The top-:math:`k` predictions of
|
|
376
|
+
every example in the mini-batch with shape
|
|
377
|
+
:obj:`[batch_size, k]`.
|
|
378
|
+
edge_label_index (torch.Tensor): The ground-truth indices for every
|
|
379
|
+
example in the mini-batch, given in COO format of shape
|
|
380
|
+
:obj:`[2, num_ground_truth_indices]`.
|
|
381
|
+
edge_label_weight (torch.Tensor, optional): The weight of the
|
|
382
|
+
ground-truth indices for every example in the mini-batch of
|
|
383
|
+
shape :obj:`[num_ground_truth_indices]`. If given, needs to be
|
|
384
|
+
a vector of positive values. Required for weighted metrics,
|
|
385
|
+
ignored otherwise. (default: :obj:`None`)
|
|
386
|
+
"""
|
|
387
|
+
if self.weighted and edge_label_weight is None:
|
|
388
|
+
raise ValueError(f"'edge_label_weight' is a required argument for "
|
|
389
|
+
f"weighted '{self.__class__.__name__}' metrics")
|
|
390
|
+
|
|
391
|
+
data = LinkPredMetricData( # Share metric data across metrics.
|
|
392
|
+
pred_index_mat=pred_index_mat,
|
|
393
|
+
edge_label_index=edge_label_index,
|
|
394
|
+
edge_label_weight=edge_label_weight,
|
|
395
|
+
)
|
|
396
|
+
|
|
397
|
+
for metric in self.values():
|
|
398
|
+
if isinstance(metric, LinkPredMetric) and metric.weighted:
|
|
399
|
+
metric._update(data)
|
|
400
|
+
if WITH_TORCHMETRICS:
|
|
401
|
+
metric._update_count += 1
|
|
402
|
+
|
|
403
|
+
data.edge_label_weight = None
|
|
404
|
+
if hasattr(data, '_pred_rel_mat'):
|
|
405
|
+
data._pred_rel_mat = data._pred_rel_mat != 0.0
|
|
406
|
+
if hasattr(data, '_label_weight_sum'):
|
|
407
|
+
del data._label_weight_sum
|
|
408
|
+
if hasattr(data, '_edge_label_weight_pos'):
|
|
409
|
+
del data._edge_label_weight_pos
|
|
410
|
+
|
|
411
|
+
for metric in self.values():
|
|
412
|
+
if isinstance(metric, LinkPredMetric) and not metric.weighted:
|
|
413
|
+
metric._update(data)
|
|
414
|
+
if WITH_TORCHMETRICS:
|
|
415
|
+
metric._update_count += 1
|
|
416
|
+
|
|
417
|
+
for metric in self.values():
|
|
418
|
+
if not isinstance(metric, LinkPredMetric):
|
|
419
|
+
metric.update( # type: ignore[operator]
|
|
420
|
+
pred_index_mat,
|
|
421
|
+
edge_label_index,
|
|
422
|
+
edge_label_weight,
|
|
423
|
+
)
|
|
424
|
+
|
|
425
|
+
def compute(self) -> Dict[str, Tensor]:
|
|
426
|
+
r"""Computes the final metric values."""
|
|
427
|
+
return {
|
|
428
|
+
name: metric.compute() # type: ignore[operator]
|
|
429
|
+
for name, metric in self.items()
|
|
430
|
+
}
|
|
431
|
+
|
|
432
|
+
def reset(self) -> None:
|
|
433
|
+
r"""Reset metric state variables to their default value."""
|
|
434
|
+
for metric in self.values():
|
|
435
|
+
metric.reset() # type: ignore[operator]
|
|
436
|
+
|
|
437
|
+
def __repr__(self) -> str:
|
|
438
|
+
names = [f' {name}: {metric},\n' for name, metric in self.items()]
|
|
439
|
+
return f'{self.__class__.__name__}([\n{"".join(names)}])'
|
|
134
440
|
|
|
135
441
|
|
|
136
442
|
class LinkPredPrecision(LinkPredMetric):
    r"""A link prediction metric to compute Precision @ :math:`k`, *i.e.* the
    share of the top-:math:`k` recommendations that are actually relevant.

    The higher the precision, the better the model is at placing relevant
    items early in the ranking.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
    """
    higher_is_better: bool = True
    weighted: bool = False

    def _compute(self, data: LinkPredMetricData) -> Tensor:
        # Count the relevant entries among the first `k` columns and
        # normalize by `k` (not by the number of available columns).
        hits = data.pred_rel_mat[:, :self.k].sum(dim=-1)
        return hits / self.k
|
|
146
459
|
|
|
147
460
|
|
|
148
461
|
class LinkPredRecall(LinkPredMetric):
    r"""A link prediction metric to compute Recall @ :math:`k`, *i.e.* the
    share of relevant items that show up within the top-:math:`k`.

    The higher the recall, the larger the proportion of relevant items the
    model is able to retrieve.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
        weighted (bool, optional): Stored on the instance as
            :obj:`self.weighted`. (default: :obj:`False`)
    """
    higher_is_better: bool = True

    def __init__(self, k: int, weighted: bool = False):
        super().__init__(k=k)
        self.weighted = weighted

    def _compute(self, data: LinkPredMetricData) -> Tensor:
        hits = data.pred_rel_mat[:, :self.k].sum(dim=-1)
        # Clamp the denominator so that rows without any label weight do
        # not trigger a division by zero.
        denom = data.label_weight_sum.clamp(min=1e-7)
        return hits / denom
|
|
158
480
|
|
|
159
481
|
|
|
160
482
|
class LinkPredF1(LinkPredMetric):
|
|
@@ -164,55 +486,403 @@ class LinkPredF1(LinkPredMetric):
|
|
|
164
486
|
k (int): The number of top-:math:`k` predictions to evaluate against.
|
|
165
487
|
"""
|
|
166
488
|
higher_is_better: bool = True
|
|
489
|
+
weighted: bool = False
|
|
167
490
|
|
|
168
|
-
def _compute(self,
|
|
169
|
-
|
|
491
|
+
def _compute(self, data: LinkPredMetricData) -> Tensor:
|
|
492
|
+
pred_rel_mat = data.pred_rel_mat[:, :self.k]
|
|
493
|
+
isin_count = pred_rel_mat.sum(dim=-1)
|
|
170
494
|
precision = isin_count / self.k
|
|
171
|
-
recall = isin_count
|
|
495
|
+
recall = isin_count / data.label_count.clamp(min=1e-7)
|
|
172
496
|
return 2 * precision * recall / (precision + recall).clamp(min=1e-7)
|
|
173
497
|
|
|
174
498
|
|
|
175
499
|
class LinkPredMAP(LinkPredMetric):
    r"""A link prediction metric to compute MAP @ :math:`k` (Mean Average
    Precision), which takes the order of relevant items within the
    top-:math:`k` into account.

    Compared to precision alone, MAP @ :math:`k` can give a more
    comprehensive picture of ranking quality.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
    """
    higher_is_better: bool = True
    weighted: bool = False

    def _compute(self, data: LinkPredMetricData) -> Tensor:
        top_k = data.pred_rel_mat[:, :self.k]
        num_cols = top_k.size(1)

        # Precision at every cut-off position `1, ..., num_cols`:
        position = torch.arange(1, num_cols + 1, device=top_k.device)
        precision_at = top_k.cumsum(dim=1) / position

        # Only positions that hold a relevant item contribute:
        numerator = (precision_at * top_k).sum(dim=-1)
        # Normalize by the label count, limited to the range `[1e-7, k]`:
        denominator = data.label_count.clamp(min=1e-7, max=self.k)
        return numerator / denominator
|
|
189
520
|
|
|
190
521
|
|
|
191
522
|
class LinkPredNDCG(LinkPredMetric):
    r"""A link prediction metric to compute the NDCG @ :math:`k` (Normalized
    Discounted Cumulative Gain).

    NDCG can account for the position of relevant items by taking relevance
    scores into consideration, such that more relevant items appearing at the
    top receive a higher weight.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
        weighted (bool, optional): If set to :obj:`True`, assumes sorted lists
            of ground-truth items according to a relevance score as given by
            :obj:`edge_label_weight`. (default: :obj:`False`)
    """
    higher_is_better: bool = True

    def __init__(self, k: int, weighted: bool = False):
        super().__init__(k=k)
        self.weighted = weighted

        # Logarithmic position discount `log2(2), ..., log2(k + 1)`:
        dtype = torch.get_default_dtype()
        discount = torch.arange(2, k + 2, dtype=dtype).log2()

        self.discount: Tensor
        self.register_buffer('discount', discount, persistent=False)

        if weighted:
            # The ideal DCG depends on the label weights and is therefore
            # computed per call inside `_compute`.
            self.idcg = None
        else:
            # Without weights, the ideal DCG is looked up by label count
            # from the accumulated inverse discounts.
            self.register_buffer('idcg', cumsum(1.0 / discount),
                                 persistent=False)

    def _compute(self, data: LinkPredMetricData) -> Tensor:
        top_k = data.pred_rel_mat[:, :self.k]
        rank_discount = self.discount[:top_k.size(1)].view(1, -1)
        dcg = (top_k / rank_discount).sum(dim=-1)

        if self.weighted:
            assert data.edge_label_weight is not None
            pos = data.edge_label_weight_pos
            assert pos is not None

            # Append an `inf` entry at index `k`: positions clamped to `k`
            # pick up an infinite discount and hence add nothing.
            padded = torch.cat([
                self.discount,
                self.discount.new_full((1, ), fill_value=float('inf')),
            ])
            label_discount = padded[pos.clamp(max=self.k)]

            idcg = scatter(  # Apply discount and aggregate:
                data.edge_label_weight / label_discount,
                data.edge_label_index[0],
                dim_size=data.pred_index_mat.size(0),
                reduce='sum',
            )
        else:
            assert self.idcg is not None
            idcg = self.idcg[data.label_count.clamp(max=self.k)]

        out = dcg / idcg
        # Division by a zero ideal DCG yields NaN/Inf; report zero instead.
        invalid = out.isnan() | out.isinf()
        out[invalid] = 0.0
        return out
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
class LinkPredMRR(LinkPredMetric):
    r"""A link prediction metric to compute the MRR @ :math:`k` (Mean
    Reciprocal Rank), *i.e.* the mean of the reciprocal rank at which the
    first correct prediction occurs (or zero if there is none).

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
    """
    higher_is_better: bool = True
    weighted: bool = False

    def _compute(self, data: LinkPredMetricData) -> Tensor:
        top_k = data.pred_rel_mat[:, :self.k]
        rank = torch.arange(1, top_k.size(1) + 1, device=top_k.device)
        # Dividing by the rank and taking the row-wise maximum selects the
        # reciprocal rank of the earliest relevant entry.
        reciprocal = top_k / rank
        return reciprocal.max(dim=-1)[0]
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
class LinkPredHitRatio(LinkPredMetric):
    r"""A link prediction metric to compute the hit ratio @ :math:`k`, *i.e.*
    the percentage of users that have at least one relevant item among their
    top-:math:`k` recommendations.

    A high ratio signals that the model satisfies a broad range of user
    preferences.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
    """
    higher_is_better: bool = True
    weighted: bool = False

    def _compute(self, data: LinkPredMetricData) -> Tensor:
        top_k = data.pred_rel_mat[:, :self.k]
        # Row-wise maximum: non-zero as soon as a single entry is relevant.
        hit = top_k.max(dim=-1)[0]
        return hit.to(torch.get_default_dtype())
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
class LinkPredCoverage(_LinkPredMetric):
    r"""A link prediction metric to compute the Coverage @ :math:`k` of
    predictions, *i.e.* the percentage of unique items that are recommended
    to any user within the top-:math:`k`.

    A higher coverage points to a wider exploration of the item catalog.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
        num_dst_nodes (int): The total number of destination nodes.
    """
    higher_is_better: bool = True

    def __init__(self, k: int, num_dst_nodes: int) -> None:
        super().__init__(k)
        self.num_dst_nodes = num_dst_nodes

        # One flag per destination node, set once the node was recommended.
        self.mask: Tensor
        mask = torch.zeros(num_dst_nodes, dtype=torch.bool)
        if not WITH_TORCHMETRICS:
            self.register_buffer('mask', mask, persistent=False)
        else:
            self.add_state('mask', mask, dist_reduce_fx='max')

    def update(
        self,
        pred_index_mat: Tensor,
        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
        edge_label_weight: Optional[Tensor] = None,
    ) -> None:
        # Only the predictions are needed; the label arguments are unused.
        recommended = pred_index_mat[:, :self.k].flatten()
        self.mask[recommended] = True

    def compute(self) -> Tensor:
        covered = self.mask.to(torch.get_default_dtype())
        return covered.mean()

    def _reset(self) -> None:
        self.mask.zero_()

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f'{cls_name}(k={self.k}, num_dst_nodes={self.num_dst_nodes})'
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
class LinkPredDiversity(_LinkPredMetric):
    r"""A link prediction metric to compute the Diversity @ :math:`k` of
    predictions according to item categories.

    Diversity is computed as

    .. math::
        div_{u@k} = 1 - \left( \frac{1}{k \cdot (k-1)} \right) \sum_{i \neq j}
        sim(i, j)

    where

    .. math::
        sim(i,j) = \begin{cases}
            1 & \quad \text{if } i,j \text{ share category,}\\
            0 & \quad \text{otherwise.}
        \end{cases}

    which measures the pair-wise inequality of recommendations according to
    item categories.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
            Needs to be at least :obj:`2` since diversity is defined over
            pairs of recommendations.
        category (torch.Tensor): A vector that assigns each destination node to
            a specific category.

    Raises:
        ValueError: If :obj:`k` is smaller than :obj:`2`.
    """
    higher_is_better: bool = True

    def __init__(self, k: int, category: Tensor) -> None:
        super().__init__(k)

        # The normalization `1 / (k * (k - 1))` is undefined for `k < 2`.
        # Fail early with a clear message instead of letting `update` raise
        # an opaque `ZeroDivisionError`.
        if k < 2:
            raise ValueError(f"'{self.__class__.__name__}' requires 'k' to "
                             f"be at least 2 (got {k})")

        self.accum: Tensor
        self.total: Tensor

        if WITH_TORCHMETRICS:
            self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
            self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
        else:
            self.register_buffer('accum', torch.tensor(0.), persistent=False)
            self.register_buffer('total', torch.tensor(0), persistent=False)

        self.category: Tensor
        self.register_buffer('category', category, persistent=False)

    def update(
        self,
        pred_index_mat: Tensor,
        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
        edge_label_weight: Optional[Tensor] = None,
    ) -> None:
        # Only the predictions are needed; the label arguments are unused.
        category = self.category[pred_index_mat[:, :self.k]]

        # For every recommended item, count the items in the same row that
        # share its category (this count includes the item itself).
        sim = (category.unsqueeze(-2) == category.unsqueeze(-1)).sum(dim=-1)
        # Subtract one to drop the self-match before normalizing.
        div = 1 - 1 / (self.k * (self.k - 1)) * (sim - 1).sum(dim=-1)

        self.accum += div.sum()
        self.total += pred_index_mat.size(0)

    def compute(self) -> Tensor:
        # No rows seen so far: report zero rather than dividing by zero.
        if self.total == 0:
            return torch.zeros_like(self.accum)
        return self.accum / self.total

    def _reset(self) -> None:
        self.accum.zero_()
        self.total.zero_()
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
class LinkPredPersonalization(_LinkPredMetric):
    r"""A link prediction metric to compute the Personalization @ :math:`k`,
    *i.e.* the dissimilarity of recommendations across different users.

    Higher personalization suggests that the model tailors recommendations to
    individual user preferences rather than providing generic results.

    Dissimilarity is defined by the average inverse cosine similarity between
    users' lists of recommendations.

    Since the metric is defined over pairs of users, :meth:`compute` returns
    zero when fewer than two rows of recommendations have been collected.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
        max_src_nodes (int, optional): The maximum source nodes to consider to
            compute pair-wise dissimilarity. If specified,
            Personalization @ :math:`k` is approximated to avoid computation
            blowup due to quadratic complexity. (default: :obj:`2**12`)
        batch_size (int, optional): The batch size to determine how many pairs
            of user recommendations should be processed at once.
            (default: :obj:`2**16`)
    """
    higher_is_better: bool = True

    def __init__(
        self,
        k: int,
        max_src_nodes: Optional[int] = 2**12,
        batch_size: int = 2**16,
    ) -> None:
        super().__init__(k)
        self.max_src_nodes = max_src_nodes
        self.batch_size = batch_size

        self.preds: List[Tensor]
        self.total: Tensor

        if WITH_TORCHMETRICS:
            self.add_state('preds', default=[], dist_reduce_fx='cat')
            self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
        else:
            self.preds = []
            self.register_buffer('total', torch.tensor(0), persistent=False)

    def update(
        self,
        pred_index_mat: Tensor,
        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
        edge_label_weight: Optional[Tensor] = None,
    ) -> None:
        # Only the predictions are needed; the label arguments are unused.

        # NOTE Move to CPU to avoid memory blowup.
        pred_index_mat = pred_index_mat[:, :self.k].cpu()

        if self.max_src_nodes is None:
            self.preds.append(pred_index_mat)
            self.total += pred_index_mat.size(0)
        elif self.total < self.max_src_nodes:
            # Only keep as many rows as still fit into the budget.
            remaining = int(self.max_src_nodes - self.total)
            pred_index_mat = pred_index_mat[:remaining]
            self.preds.append(pred_index_mat)
            self.total += pred_index_mat.size(0)

    def compute(self) -> Tensor:
        device = self.total.device
        score = torch.tensor(0.0, device=device)
        total = torch.tensor(0, device=device)

        if len(self.preds) == 0:
            return score

        pred = torch.cat(self.preds, dim=0)

        # With fewer than two rows there is no pair to compare. Returning
        # early avoids the `0 / 0 = NaN` division at the end.
        if pred.size(0) < 2:
            return score

        # Calculate all pairs of nodes (e.g., triu_indices with offset=1).
        # NOTE We do this in chunks to avoid memory blow-up, which leads to a
        # more efficient but trickier implementation.
        num_pairs = (pred.size(0) * (pred.size(0) - 1)) // 2
        offset = torch.arange(pred.size(0) - 1, 0, -1, device=device)
        rowptr = cumsum(offset)
        for start in range(0, num_pairs, self.batch_size):
            end = min(start + self.batch_size, num_pairs)
            idx = torch.arange(start, end, device=device)

            # Find the corresponding row:
            row = torch.searchsorted(rowptr, idx, right=True) - 1
            # Find the corresponding column:
            col = idx - rowptr[row] + (pred.size(0) - offset[row])

            # `pred` is kept on CPU, so index there and move the chunk over.
            left = pred[row.cpu()].to(device)
            right = pred[col.cpu()].to(device)

            # Use offset to work around applying `isin` along a specific dim:
            i = max(int(left.max()), int(right.max())) + 1
            idx = torch.arange(0, i * row.size(0), i, device=device)
            idx = idx.view(-1, 1)
            isin = torch.isin(left + idx, right + idx)

            # Compute personalization via average inverse cosine similarity:
            cos = isin.sum(dim=-1) / pred.size(1)
            score += (1 - cos).sum()
            total += cos.numel()

        return score / total

    def _reset(self) -> None:
        self.preds = []
        self.total.zero_()
|
|
838
|
+
|
|
839
|
+
|
|
840
|
+
class LinkPredAveragePopularity(_LinkPredMetric):
    r"""A link prediction metric to compute the Average Recommendation
    Popularity (ARP) @ :math:`k`. It averages the popularity scores of the
    items within the top-:math:`k` recommendations and thereby indicates how
    strongly the model tends to recommend popular items.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
        popularity (torch.Tensor): The popularity of every item in the training
            set, *e.g.*, the number of times an item has been rated.
    """
    higher_is_better: bool = False

    def __init__(self, k: int, popularity: Tensor) -> None:
        super().__init__(k)

        self.accum: Tensor
        self.total: Tensor

        if not WITH_TORCHMETRICS:
            self.register_buffer('accum', torch.tensor(0.), persistent=False)
            self.register_buffer('total', torch.tensor(0), persistent=False)
        else:
            self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
            self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')

        self.popularity: Tensor
        self.register_buffer('popularity', popularity, persistent=False)

    def update(
        self,
        pred_index_mat: Tensor,
        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
        edge_label_weight: Optional[Tensor] = None,
    ) -> None:
        # Only the predictions are needed; the label arguments are unused.
        top_k = pred_index_mat[:, :self.k]
        # Look up the popularity of each recommended item and average it
        # per row in the dtype of the accumulator.
        per_row = self.popularity[top_k].to(self.accum.dtype).mean(dim=-1)
        self.accum += per_row.sum()
        self.total += per_row.numel()

    def compute(self) -> Tensor:
        # No rows seen so far: report zero rather than dividing by zero.
        if self.total == 0:
            return torch.zeros_like(self.accum)
        return self.accum / self.total

    def _reset(self) -> None:
        self.accum.zero_()
        self.total.zero_()
|