pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/metrics/link_pred.py
CHANGED
@@ -1,4 +1,5 @@
-from typing import Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -14,6 +15,76 @@ except Exception:
     BaseMetric = torch.nn.Module  # type: ignore
 
 
+@dataclass(repr=False)
+class LinkPredMetricData:
+    pred_index_mat: Tensor
+    edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]
+    edge_label_weight: Optional[Tensor] = None
+
+    @property
+    def pred_rel_mat(self) -> Tensor:
+        r"""Returns a matrix indicating the relevance of the `k`-th prediction.
+        If :obj:`edge_label_weight` is not given, relevance will be denoted as
+        binary.
+        """
+        if hasattr(self, '_pred_rel_mat'):
+            return self._pred_rel_mat  # type: ignore
+
+        # Flatten both prediction and ground-truth indices, and determine
+        # overlaps afterwards via `torch.searchsorted`.
+        max_index = max(  # type: ignore
+            self.pred_index_mat.max()
+            if self.pred_index_mat.numel() > 0 else 0,
+            self.edge_label_index[1].max()
+            if self.edge_label_index[1].numel() > 0 else 0,
+        ) + 1
+        arange = torch.arange(
+            start=0,
+            end=max_index * self.pred_index_mat.size(0),  # type: ignore
+            step=max_index,  # type: ignore
+            device=self.pred_index_mat.device,
+        ).view(-1, 1)
+        flat_pred_index = (self.pred_index_mat + arange).view(-1)
+        flat_label_index = max_index * self.edge_label_index[0]
+        flat_label_index = flat_label_index + self.edge_label_index[1]
+        flat_label_index, perm = flat_label_index.sort()
+        edge_label_weight = self.edge_label_weight
+        if edge_label_weight is not None:
+            assert edge_label_weight.size() == self.edge_label_index[0].size()
+            edge_label_weight = edge_label_weight[perm]
+
+        pos = torch.searchsorted(flat_label_index, flat_pred_index)
+        pos = pos.clamp(max=flat_label_index.size(0) - 1)  # Out-of-bounds.
+
+        pred_rel_mat = flat_label_index[pos] == flat_pred_index  # Find matches
+        if edge_label_weight is not None:
+            pred_rel_mat = edge_label_weight[pos].where(
+                pred_rel_mat,
+                pred_rel_mat.new_zeros(1),
+            )
+        pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size())
+
+        self._pred_rel_mat = pred_rel_mat
+        return pred_rel_mat
+
+    @property
+    def label_count(self) -> Tensor:
+        r"""The number of ground-truth labels for every example."""
+        if hasattr(self, '_label_count'):
+            return self._label_count  # type: ignore
+
+        label_count = scatter(
+            torch.ones_like(self.edge_label_index[0]),
+            self.edge_label_index[0],
+            dim=0,
+            dim_size=self.pred_index_mat.size(0),
+            reduce='sum',
+        )
+
+        self._label_count = label_count
+        return label_count
+
+
 class LinkPredMetric(BaseMetric):
     r"""An abstract class for computing link prediction retrieval metrics.
 
@@ -23,6 +94,7 @@ class LinkPredMetric(BaseMetric):
     is_differentiable: bool = False
     full_state_update: bool = False
     higher_is_better: Optional[bool] = None
+    weighted: bool = False
 
     def __init__(self, k: int) -> None:
         super().__init__()
@@ -47,6 +119,7 @@ class LinkPredMetric(BaseMetric):
         self,
         pred_index_mat: Tensor,
         edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
+        edge_label_weight: Optional[Tensor] = None,
     ) -> None:
         r"""Updates the state variables based on the current mini-batch
         prediction.
@@ -62,45 +135,30 @@ class LinkPredMetric(BaseMetric):
             edge_label_index (torch.Tensor): The ground-truth indices for every
                 example in the mini-batch, given in COO format of shape
                 :obj:`[2, num_ground_truth_indices]`.
+            edge_label_weight (torch.Tensor, optional): The weight of the
+                ground-truth indices for every example in the mini-batch of
+                shape :obj:`[num_ground_truth_indices]`. If given, needs to be
+                a vector of positive values. Required for weighted metrics,
+                ignored otherwise. (default: :obj:`None`)
         """
-        if
-            raise ValueError(f"
-                             f"
-
-
-
-
-
-
-
-            edge_label_index[1].max()
-            if edge_label_index[1].numel() > 0 else 0,
-        ) + 1
-        arange = torch.arange(
-            start=0,
-            end=max_index * pred_index_mat.size(0),
-            step=max_index,
-            device=pred_index_mat.device,
-        ).view(-1, 1)
-        flat_pred_index = (pred_index_mat + arange).view(-1)
-        flat_y_index = max_index * edge_label_index[0] + edge_label_index[1]
-
-        pred_isin_mat = torch.isin(flat_pred_index, flat_y_index)
-        pred_isin_mat = pred_isin_mat.view(pred_index_mat.size())
-
-        # Compute the number of targets per example:
-        y_count = scatter(
-            torch.ones_like(edge_label_index[0]),
-            edge_label_index[0],
-            dim=0,
-            dim_size=pred_index_mat.size(0),
-            reduce='sum',
+        if self.weighted and edge_label_weight is None:
+            raise ValueError(f"'edge_label_weight' is a required argument for "
+                             f"weighted '{self.__class__.__name__}' metrics")
+        if not self.weighted:
+            edge_label_weight = None
+
+        data = LinkPredMetricData(
+            pred_index_mat=pred_index_mat,
+            edge_label_index=edge_label_index,
+            edge_label_weight=edge_label_weight,
         )
+        self._update(data)
 
-
+    def _update(self, data: LinkPredMetricData) -> None:
+        metric = self._compute(data)
 
         self.accum += metric.sum()
-        self.total += (
+        self.total += (data.label_count > 0).sum()
 
     def compute(self) -> Tensor:
         r"""Computes the final metric value."""
@@ -109,28 +167,159 @@ class LinkPredMetric(BaseMetric):
         return self.accum / self.total
 
     def reset(self) -> None:
-        r"""
+        r"""Resets metric state variables to their default value."""
         if WITH_TORCHMETRICS:
             super().reset()
         else:
             self.accum.zero_()
             self.total.zero_()
 
-    def _compute(self,
-        r"""
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        r"""Computes the specific metric.
        To be implemented separately for each metric class.
 
         Args:
-
-
-                :obj:`i`-th example is correct or not.
-            y_count (torch.Tensor): A vector indicating the number of
-                ground-truth labels for each example.
+            data (LinkPredMetricData): The mini-batch data for computing a link
+                prediction metric per example.
         """
         raise NotImplementedError
 
     def __repr__(self) -> str:
-
+        weighted_repr = ', weighted=True' if self.weighted else ''
+        return f'{self.__class__.__name__}(k={self.k}{weighted_repr})'
+
+
+class LinkPredMetricCollection(torch.nn.ModuleDict):
+    r"""A collection of metrics to reduce and speed-up computation of link
+    prediction metrics.
+
+    .. code-block:: python
+
+        from torch_geometric.metrics import (
+            LinkPredMAP,
+            LinkPredMetricCollection,
+            LinkPredPrecision,
+            LinkPredRecall,
+        )
+
+        metrics = LinkPredMetricCollection([
+            LinkPredMAP(k=10),
+            LinkPredPrecision(k=100),
+            LinkPredRecall(k=50),
+        ])
+
+        metrics.update(pred_index_mat, edge_label_index)
+        out = metrics.compute()
+        metrics.reset()
+
+        print(out)
+        >>> {'LinkPredMAP@10': tensor(0.375),
+        ...  'LinkPredPrecision@100': tensor(0.127),
+        ...  'LinkPredRecall@50': tensor(0.483)}
+
+    Args:
+        metrics: The link prediction metrics.
+    """
+    def __init__(
+        self,
+        metrics: Union[
+            List[LinkPredMetric],
+            Dict[str, LinkPredMetric],
+        ],
+    ) -> None:
+        super().__init__()
+
+        if isinstance(metrics, (list, tuple)):
+            metrics = {
+                f'{metric.__class__.__name__}@{metric.k}': metric
+                for metric in metrics
+            }
+        assert len(metrics) > 0
+        assert isinstance(metrics, dict)
+
+        for name, metric in metrics.items():
+            self[name] = metric
+
+    @property
+    def max_k(self) -> int:
+        r"""The maximum number of top-:math:`k` predictions to evaluate
+        against.
+        """
+        return max([metric.k for metric in self.values()])
+
+    @property
+    def weighted(self) -> bool:
+        r"""Returns :obj:`True` in case the collection holds at least one
+        weighted link prediction metric.
+        """
+        return any([metric.weighted for metric in self.values()])
+
+    def update(  # type: ignore
+        self,
+        pred_index_mat: Tensor,
+        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
+        edge_label_weight: Optional[Tensor] = None,
+    ) -> None:
+        r"""Updates the state variables based on the current mini-batch
+        prediction.
+
+        :meth:`update` can be repeated multiple times to accumulate the results
+        of successive predictions, *e.g.*, inside a mini-batch training or
+        evaluation loop.
+
+        Args:
+            pred_index_mat (torch.Tensor): The top-:math:`k` predictions of
+                every example in the mini-batch with shape
+                :obj:`[batch_size, k]`.
+            edge_label_index (torch.Tensor): The ground-truth indices for every
+                example in the mini-batch, given in COO format of shape
+                :obj:`[2, num_ground_truth_indices]`.
+            edge_label_weight (torch.Tensor, optional): The weight of the
+                ground-truth indices for every example in the mini-batch of
+                shape :obj:`[num_ground_truth_indices]`. If given, needs to be
+                a vector of positive values. Required for weighted metrics,
+                ignored otherwise. (default: :obj:`None`)
+        """
+        if self.weighted and edge_label_weight is None:
+            raise ValueError(f"'edge_label_weight' is a required argument for "
+                             f"weighted '{self.__class__.__name__}' metrics")
+        if not self.weighted:
+            edge_label_weight = None
+
+        data = LinkPredMetricData(  # Share metric data across metrics.
+            pred_index_mat=pred_index_mat,
+            edge_label_index=edge_label_index,
+            edge_label_weight=edge_label_weight,
+        )
+
+        for metric in self.values():
+            if metric.weighted:
+                metric._update(data)
+                if WITH_TORCHMETRICS:
+                    metric._update_count += 1
+
+        data.edge_label_weight = None
+        if hasattr(data, '_pred_rel_mat'):
+            data._pred_rel_mat = data._pred_rel_mat != 0.0
+
+        for metric in self.values():
+            if not metric.weighted:
+                metric._update(data)
+                if WITH_TORCHMETRICS:
+                    metric._update_count += 1
+
+    def compute(self) -> Dict[str, Tensor]:
+        r"""Computes the final metric values."""
+        return {name: metric.compute() for name, metric in self.items()}
+
+    def reset(self) -> None:
+        r"""Reset metric state variables to their default value."""
+        for metric in self.values():
+            metric.reset()
+
+    def __repr__(self) -> str:
+        names = [f'  {name}: {metric},\n' for name, metric in self.items()]
+        return f'{self.__class__.__name__}([\n{"".join(names)}])'
 
 
 class LinkPredPrecision(LinkPredMetric):
@@ -140,9 +329,11 @@ class LinkPredPrecision(LinkPredMetric):
         k (int): The number of top-:math:`k` predictions to evaluate against.
     """
     higher_is_better: bool = True
+    weighted: bool = False
 
-    def _compute(self,
-
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        return pred_rel_mat.sum(dim=-1) / self.k
 
 
 class LinkPredRecall(LinkPredMetric):
@@ -152,9 +343,11 @@ class LinkPredRecall(LinkPredMetric):
         k (int): The number of top-:math:`k` predictions to evaluate against.
     """
     higher_is_better: bool = True
+    weighted: bool = False
 
-    def _compute(self,
-
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        return pred_rel_mat.sum(dim=-1) / data.label_count.clamp(min=1e-7)
 
 
 class LinkPredF1(LinkPredMetric):
@@ -164,11 +357,13 @@ class LinkPredF1(LinkPredMetric):
         k (int): The number of top-:math:`k` predictions to evaluate against.
     """
     higher_is_better: bool = True
+    weighted: bool = False
 
-    def _compute(self,
-
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        isin_count = pred_rel_mat.sum(dim=-1)
         precision = isin_count / self.k
-        recall = isin_count
+        recall = isin_count / data.label_count.clamp(min=1e-7)
         return 2 * precision * recall / (precision + recall).clamp(min=1e-7)
 
 
@@ -180,12 +375,15 @@ class LinkPredMAP(LinkPredMetric):
         k (int): The number of top-:math:`k` predictions to evaluate against.
     """
     higher_is_better: bool = True
+    weighted: bool = False
 
-    def _compute(self,
-
-
-
-
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        device = pred_rel_mat.device
+        arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
+        cum_precision = pred_rel_mat.cumsum(dim=1) / arange
+        return ((cum_precision * pred_rel_mat).sum(dim=-1) /
+                data.label_count.clamp(min=1e-7, max=self.k))
 
 
 class LinkPredNDCG(LinkPredMetric):
@@ -194,25 +392,79 @@ class LinkPredNDCG(LinkPredMetric):
 
     Args:
         k (int): The number of top-:math:`k` predictions to evaluate against.
+        weighted (bool, optional): If set to :obj:`True`, assumes sorted lists
+            of ground-truth items according to a relevance score as given by
+            :obj:`edge_label_weight`. (default: :obj:`False`)
     """
     higher_is_better: bool = True
+    weighted: bool = False
 
-    def __init__(self, k: int):
+    def __init__(self, k: int, weighted: bool = False):
         super().__init__(k=k)
+        self.weighted = weighted
 
         dtype = torch.get_default_dtype()
-
+        discount = torch.arange(2, k + 2, dtype=dtype).log2()
+
+        self.discount: Tensor
+        self.register_buffer('discount', discount)
 
-
-
+        if not weighted:
+            self.register_buffer('idcg', cumsum(1.0 / discount))
+        else:
+            self.idcg = None
 
-
-        self.
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        discount = self.discount[:pred_rel_mat.size(1)].view(1, -1)
+        dcg = (pred_rel_mat / discount).sum(dim=-1)
 
-
-
-
+        if not self.weighted:
+            assert self.idcg is not None
+            idcg = self.idcg[data.label_count.clamp(max=self.k)]
+        else:
+            assert data.edge_label_weight is not None
+            # Sort weights within example-wise buckets via two sorts to get the
+            # local index order within buckets:
+            weight, batch = data.edge_label_weight, data.edge_label_index[0]
+            perm1 = weight.argsort(descending=True)
+            perm2 = batch[perm1].argsort(stable=True)
+            global_index = torch.empty_like(perm1)
+            global_index[perm1[perm2]] = torch.arange(
+                global_index.size(0), device=global_index.device)
+            local_index = global_index - cumsum(data.label_count)[batch]
+
+            # Get the discount per local index:
+            discount = torch.cat([
+                self.discount,
+                self.discount.new_full((1, ), fill_value=float('inf')),
+            ])
+            discount = discount[local_index.clamp(max=self.k + 1)]
+
+            idcg = scatter(  # Apply discount and aggregate:
+                weight / discount,
+                batch,
+                dim_size=data.pred_index_mat.size(0),
+                reduce='sum',
+            )
 
         out = dcg / idcg
         out[out.isnan() | out.isinf()] = 0.0
         return out
+
+
+class LinkPredMRR(LinkPredMetric):
+    r"""A link prediction metric to compute the MRR @ :math:`k` (Mean
+    Reciprocal Rank).
+
+    Args:
+        k (int): The number of top-:math:`k` predictions to evaluate against.
+    """
+    higher_is_better: bool = True
+    weighted: bool = False
+
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        device = pred_rel_mat.device
+        arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
+        return (pred_rel_mat / arange).max(dim=-1)[0]
torch_geometric/nn/aggr/__init__.py
CHANGED
@@ -24,6 +24,8 @@ from .mlp import MLPAggregation
 from .deep_sets import DeepSetsAggregation
 from .set_transformer import SetTransformerAggregation
 from .lcm import LCMAggregation
+from .variance_preserving import VariancePreservingAggregation
+from .patch_transformer import PatchTransformerAggregation
 
 __all__ = classes = [
     'Aggregation',
@@ -51,4 +53,6 @@ __all__ = classes = [
     'DeepSetsAggregation',
     'SetTransformerAggregation',
     'LCMAggregation',
+    'VariancePreservingAggregation',
+    'PatchTransformerAggregation',
 ]
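The two new exports above make VariancePreservingAggregation and PatchTransformerAggregation available from torch_geometric.nn.aggr. A small sketch of calling the variance-preserving operator, assuming it is parameter-free and follows the standard Aggregation call signature documented in base.py below:

import torch

from torch_geometric.nn.aggr import VariancePreservingAggregation

aggr = VariancePreservingAggregation()

# 10 elements with 64 features each, assigned to 3 groups (toy values):
x = torch.randn(10, 64)
index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])

out = aggr(x, index)  # Expected shape: [3, 64]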
torch_geometric/nn/aggr/attention.py
CHANGED
@@ -65,8 +65,6 @@ class AttentionalAggregation(Aggregation):
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                 dim: int = -2) -> Tensor:
 
-        self.assert_two_dimensional_input(x, dim)
-
         if self.gate_mlp is not None:
             gate = self.gate_mlp(x, batch=index, batch_size=dim_size)
         else:
torch_geometric/nn/aggr/base.py
CHANGED
@@ -25,7 +25,7 @@ class Aggregation(torch.nn.Module):
     Notably, :obj:`index` does not have to be sorted (for most aggregation
     operators):
 
-    .. code-block::
+    .. code-block:: python
 
         # Feature matrix holding 10 elements with 64 features each:
         x = torch.randn(10, 64)
@@ -39,7 +39,7 @@ class Aggregation(torch.nn.Module):
     called :obj:`ptr`. Here, elements within the same set need to be grouped
     together in the input, and :obj:`ptr` defines their boundaries:
 
-    .. code-block::
+    .. code-block:: python
 
         # Feature matrix holding 10 elements with 64 features each:
         x = torch.randn(10, 64)
@@ -47,7 +47,7 @@ class Aggregation(torch.nn.Module):
         # Define the boundary indices for three sets:
         ptr = torch.tensor([0, 4, 7, 10])
 
-        output = aggr(x, ptr=ptr)  # Output shape: [
+        output = aggr(x, ptr=ptr)  # Output shape: [3, 64]
 
     Note that at least one of :obj:`index` or :obj:`ptr` must be defined.
 
@@ -94,11 +94,9 @@ class Aggregation(torch.nn.Module):
         max_num_elements: (int, optional): The maximum number of elements
             within a single aggregation group. (default: :obj:`None`)
         """
-        pass
 
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
-        pass
 
     @disable_dynamic_shapes(required_args=['dim_size'])
     def __call__(
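The docstring hunks above mark the snippets as Python code blocks and adjust the output-shape comment of the ptr-based example. Pulled together into a runnable form (a sketch that assumes MeanAggregation as the concrete operator; any Aggregation subclass would do):

import torch

from torch_geometric.nn.aggr import MeanAggregation

aggr = MeanAggregation()

# Feature matrix holding 10 elements with 64 features each:
x = torch.randn(10, 64)

# Index-based grouping into three sets (does not need to be sorted):
index = torch.tensor([0, 1, 0, 1, 2, 0, 2, 1, 0, 2])
out = aggr(x, index)    # Output shape: [3, 64]

# ptr-based grouping; elements of the same set must be contiguous:
ptr = torch.tensor([0, 4, 7, 10])
out = aggr(x, ptr=ptr)  # Output shape: [3, 64]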
|