pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pyg-nightly might be problematic.

Files changed (228)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_trim_to_layer.py +2 -2
  215. torch_geometric/utils/convert.py +17 -10
  216. torch_geometric/utils/cross_entropy.py +34 -13
  217. torch_geometric/utils/embedding.py +91 -2
  218. torch_geometric/utils/geodesic.py +4 -3
  219. torch_geometric/utils/influence.py +279 -0
  220. torch_geometric/utils/map.py +13 -9
  221. torch_geometric/utils/nested.py +1 -1
  222. torch_geometric/utils/smiles.py +3 -3
  223. torch_geometric/utils/sparse.py +7 -14
  224. torch_geometric/visualization/__init__.py +2 -1
  225. torch_geometric/visualization/graph.py +250 -5
  226. torch_geometric/warnings.py +11 -2
  227. torch_geometric/nn/nlp/__init__.py +0 -7
  228. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -1,4 +1,5 @@
- from typing import Optional, Tuple, Union
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple, Union

  import torch
  from torch import Tensor
@@ -14,7 +15,143 @@ except Exception:
      BaseMetric = torch.nn.Module  # type: ignore


- class LinkPredMetric(BaseMetric):
+ @dataclass(repr=False)
+ class LinkPredMetricData:
+     pred_index_mat: Tensor
+     edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]
+     edge_label_weight: Optional[Tensor] = None
+
+     def __post_init__(self) -> None:
+         # Filter all negative weights - they should not be used as ground-truth
+         if self.edge_label_weight is not None:
+             pos_mask = self.edge_label_weight > 0
+             self.edge_label_weight = self.edge_label_weight[pos_mask]
+             if isinstance(self.edge_label_index, Tensor):
+                 self.edge_label_index = self.edge_label_index[:, pos_mask]
+             else:
+                 self.edge_label_index = (
+                     self.edge_label_index[0][pos_mask],
+                     self.edge_label_index[1][pos_mask],
+                 )
+
+     @property
+     def pred_rel_mat(self) -> Tensor:
+         r"""Returns a matrix indicating the relevance of the `k`-th prediction.
+         If :obj:`edge_label_weight` is not given, relevance will be denoted as
+         binary.
+         """
+         if hasattr(self, '_pred_rel_mat'):
+             return self._pred_rel_mat  # type: ignore
+
+         if self.edge_label_index[1].numel() == 0:
+             self._pred_rel_mat = torch.zeros_like(
+                 self.pred_index_mat,
+                 dtype=torch.bool if self.edge_label_weight is None else
+                 torch.get_default_dtype(),
+             )
+             return self._pred_rel_mat
+
+         # Flatten both prediction and ground-truth indices, and determine
+         # overlaps afterwards via `torch.searchsorted`.
+         max_index = max(
+             self.pred_index_mat.max()
+             if self.pred_index_mat.numel() > 0 else 0,
+             self.edge_label_index[1].max()
+             if self.edge_label_index[1].numel() > 0 else 0,
+         ) + 1
+         arange = torch.arange(
+             start=0,
+             end=max_index * self.pred_index_mat.size(0),  # type: ignore
+             step=max_index,  # type: ignore
+             device=self.pred_index_mat.device,
+         ).view(-1, 1)
+         flat_pred_index = (self.pred_index_mat + arange).view(-1)
+         flat_label_index = max_index * self.edge_label_index[0]
+         flat_label_index = flat_label_index + self.edge_label_index[1]
+         flat_label_index, perm = flat_label_index.sort()
+         edge_label_weight = self.edge_label_weight
+         if edge_label_weight is not None:
+             assert edge_label_weight.size() == self.edge_label_index[0].size()
+             edge_label_weight = edge_label_weight[perm]
+
+         pos = torch.searchsorted(flat_label_index, flat_pred_index)
+         pos = pos.clamp(max=flat_label_index.size(0) - 1)  # Out-of-bounds.
+
+         pred_rel_mat = flat_label_index[pos] == flat_pred_index  # Find matches
+         if edge_label_weight is not None:
+             pred_rel_mat = edge_label_weight[pos].where(
+                 pred_rel_mat,
+                 pred_rel_mat.new_zeros(1),
+             )
+         pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size())
+
+         self._pred_rel_mat = pred_rel_mat
+         return pred_rel_mat
+
+     @property
+     def label_count(self) -> Tensor:
+         r"""The number of ground-truth labels for every example."""
+         if hasattr(self, '_label_count'):
+             return self._label_count  # type: ignore
+
+         label_count = scatter(
+             torch.ones_like(self.edge_label_index[0]),
+             self.edge_label_index[0],
+             dim=0,
+             dim_size=self.pred_index_mat.size(0),
+             reduce='sum',
+         )
+
+         self._label_count = label_count
+         return label_count
+
+     @property
+     def label_weight_sum(self) -> Tensor:
+         r"""The sum of edge label weights for every example."""
+         if self.edge_label_weight is None:
+             return self.label_count
+
+         if hasattr(self, '_label_weight_sum'):
+             return self._label_weight_sum  # type: ignore
+
+         label_weight_sum = scatter(
+             self.edge_label_weight,
+             self.edge_label_index[0],
+             dim=0,
+             dim_size=self.pred_index_mat.size(0),
+             reduce='sum',
+         )
+
+         self._label_weight_sum = label_weight_sum
+         return label_weight_sum
+
+     @property
+     def edge_label_weight_pos(self) -> Optional[Tensor]:
+         r"""Returns the position of edge label weights in descending order
+         within example-wise buckets.
+         """
+         if self.edge_label_weight is None:
+             return None
+
+         if hasattr(self, '_edge_label_weight_pos'):
+             return self._edge_label_weight_pos  # type: ignore
+
+         # Get the permutation via two sorts: One globally on the weights,
+         # followed by a (stable) sort on the example indices.
+         perm1 = self.edge_label_weight.argsort(descending=True)
+         perm2 = self.edge_label_index[0][perm1].argsort(stable=True)
+         perm = perm1[perm2]
+         # Invert the permutation to get the final position:
+         pos = torch.empty_like(perm)
+         pos[perm] = torch.arange(perm.size(0), device=perm.device)
+         # Normalize position to zero within all buckets:
+         pos = pos - cumsum(self.label_count)[self.edge_label_index[0]]
+
+         self._edge_label_weight_pos = pos
+         return pos
+
+
+ class _LinkPredMetric(BaseMetric):
      r"""An abstract class for computing link prediction retrieval metrics.

      Args:
@@ -33,20 +170,11 @@ class LinkPredMetric(BaseMetric):

          self.k = k

-         self.accum: Tensor
-         self.total: Tensor
-
-         if WITH_TORCHMETRICS:
-             self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
-             self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
-         else:
-             self.register_buffer('accum', torch.tensor(0.))
-             self.register_buffer('total', torch.tensor(0))
-
      def update(
          self,
          pred_index_mat: Tensor,
          edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
+         edge_label_weight: Optional[Tensor] = None,
      ) -> None:
          r"""Updates the state variables based on the current mini-batch
          prediction.
@@ -62,99 +190,293 @@ class LinkPredMetric(BaseMetric):
              edge_label_index (torch.Tensor): The ground-truth indices for every
                  example in the mini-batch, given in COO format of shape
                  :obj:`[2, num_ground_truth_indices]`.
+             edge_label_weight (torch.Tensor, optional): The weight of the
+                 ground-truth indices for every example in the mini-batch of
+                 shape :obj:`[num_ground_truth_indices]`. If given, needs to be
+                 a vector of positive values. Required for weighted metrics,
+                 ignored otherwise. (default: :obj:`None`)
          """
-         if pred_index_mat.size(1) != self.k:
-             raise ValueError(f"Expected 'pred_index_mat' to hold {self.k} "
-                              f"many indices for every entry "
-                              f"(got {pred_index_mat.size(1)})")
-
-         # Compute a boolean matrix indicating if the k-th prediction is part of
-         # the ground-truth. We do this by flattening both prediction and
-         # target indices, and then determining overlaps via `torch.isin`.
-         max_index = max(  # type: ignore
-             pred_index_mat.max() if pred_index_mat.numel() > 0 else 0,
-             edge_label_index[1].max()
-             if edge_label_index[1].numel() > 0 else 0,
-         ) + 1
-         arange = torch.arange(
-             start=0,
-             end=max_index * pred_index_mat.size(0),
-             step=max_index,
-             device=pred_index_mat.device,
-         ).view(-1, 1)
-         flat_pred_index = (pred_index_mat + arange).view(-1)
-         flat_y_index = max_index * edge_label_index[0] + edge_label_index[1]
+         raise NotImplementedError

-         pred_isin_mat = torch.isin(flat_pred_index, flat_y_index)
-         pred_isin_mat = pred_isin_mat.view(pred_index_mat.size())
+     def compute(self) -> Tensor:
+         r"""Computes the final metric value."""
+         raise NotImplementedError

-         # Compute the number of targets per example:
-         y_count = scatter(
-             torch.ones_like(edge_label_index[0]),
-             edge_label_index[0],
-             dim=0,
-             dim_size=pred_index_mat.size(0),
-             reduce='sum',
+     def reset(self) -> None:
+         r"""Resets metric state variables to their default value."""
+         if WITH_TORCHMETRICS:
+             super().reset()
+         else:
+             self._reset()
+
+     def _reset(self) -> None:
+         raise NotImplementedError
+
+     def __repr__(self) -> str:
+         return f'{self.__class__.__name__}(k={self.k})'
+
+
+ class LinkPredMetric(_LinkPredMetric):
+     r"""An abstract class for computing link prediction retrieval metrics.
+
+     Args:
+         k (int): The number of top-:math:`k` predictions to evaluate against.
+     """
+     weighted: bool
+
+     def __init__(self, k: int) -> None:
+         super().__init__(k)
+
+         self.accum: Tensor
+         self.total: Tensor
+
+         if WITH_TORCHMETRICS:
+             self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
+             self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
+         else:
+             self.register_buffer('accum', torch.tensor(0.), persistent=False)
+             self.register_buffer('total', torch.tensor(0), persistent=False)
+
+     def update(
+         self,
+         pred_index_mat: Tensor,
+         edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
+         edge_label_weight: Optional[Tensor] = None,
+     ) -> None:
+         if self.weighted and edge_label_weight is None:
+             raise ValueError(f"'edge_label_weight' is a required argument for "
+                              f"weighted '{self.__class__.__name__}' metrics")
+         if not self.weighted:
+             edge_label_weight = None
+
+         data = LinkPredMetricData(
+             pred_index_mat=pred_index_mat,
+             edge_label_index=edge_label_index,
+             edge_label_weight=edge_label_weight,
          )
+         self._update(data)

-         metric = self._compute(pred_isin_mat, y_count)
+     def _update(self, data: LinkPredMetricData) -> None:
+         metric = self._compute(data)

          self.accum += metric.sum()
-         self.total += (y_count > 0).sum()
+         self.total += (data.label_count > 0).sum()

      def compute(self) -> Tensor:
-         r"""Computes the final metric value."""
          if self.total == 0:
              return torch.zeros_like(self.accum)
          return self.accum / self.total

-     def reset(self) -> None:
-         r"""Reset metric state variables to their default value."""
-         if WITH_TORCHMETRICS:
-             super().reset()
-         else:
-             self.accum.zero_()
-             self.total.zero_()
-
-     def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
-         r"""Compute the specific metric.
+     def _compute(self, data: LinkPredMetricData) -> Tensor:
+         r"""Computes the specific metric.
          To be implemented separately for each metric class.

          Args:
-             pred_isin_mat (torch.Tensor): A boolean matrix whose :obj:`(i,k)`
-                 element indicates if the :obj:`k`-th prediction for the
-                 :obj:`i`-th example is correct or not.
-             y_count (torch.Tensor): A vector indicating the number of
-                 ground-truth labels for each example.
+             data (LinkPredMetricData): The mini-batch data for computing a link
+                 prediction metric per example.
          """
          raise NotImplementedError

+     def _reset(self) -> None:
+         self.accum.zero_()
+         self.total.zero_()
+
      def __repr__(self) -> str:
-         return f'{self.__class__.__name__}(k={self.k})'
+         weighted_repr = ', weighted=True' if self.weighted else ''
+         return f'{self.__class__.__name__}(k={self.k}{weighted_repr})'
+
+
+ class LinkPredMetricCollection(torch.nn.ModuleDict):
+     r"""A collection of metrics to reduce and speed-up computation of link
+     prediction metrics.
+
+     .. code-block:: python
+
+         from torch_geometric.metrics import (
+             LinkPredMAP,
+             LinkPredMetricCollection,
+             LinkPredPrecision,
+             LinkPredRecall,
+         )
+
+         metrics = LinkPredMetricCollection([
+             LinkPredMAP(k=10),
+             LinkPredPrecision(k=100),
+             LinkPredRecall(k=50),
+         ])
+
+         metrics.update(pred_index_mat, edge_label_index)
+         out = metrics.compute()
+         metrics.reset()
+
+         print(out)
+         >>> {'LinkPredMAP@10': tensor(0.375),
+         ...  'LinkPredPrecision@100': tensor(0.127),
+         ...  'LinkPredRecall@50': tensor(0.483)}
+
+     Args:
+         metrics: The link prediction metrics.
+     """
+     def __init__(
+         self,
+         metrics: Union[
+             List[LinkPredMetric],
+             Dict[str, LinkPredMetric],
+         ],
+     ) -> None:
+         super().__init__()
+
+         if isinstance(metrics, (list, tuple)):
+             metrics = {
+                 (f'{"Weighted" if getattr(metric, "weighted", False) else ""}'
+                  f'{metric.__class__.__name__}@{metric.k}'):
+                 metric
+                 for metric in metrics
+             }
+         assert len(metrics) > 0
+         assert isinstance(metrics, dict)
+
+         for name, metric in metrics.items():
+             assert isinstance(metric, _LinkPredMetric)
+             self[name] = metric
+
+     @property
+     def max_k(self) -> int:
+         r"""The maximum number of top-:math:`k` predictions to evaluate
+         against.
+         """
+         return max([
+             metric.k  # type: ignore[return-value]
+             for metric in self.values()
+         ])  # type: ignore[type-var]
+
+     @property
+     def weighted(self) -> bool:
+         r"""Returns :obj:`True` in case the collection holds at least one
+         weighted link prediction metric.
+         """
+         return any(
+             [getattr(metric, 'weighted', False) for metric in self.values()])
+
+     def update(  # type: ignore
+         self,
+         pred_index_mat: Tensor,
+         edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
+         edge_label_weight: Optional[Tensor] = None,
+     ) -> None:
+         r"""Updates the state variables based on the current mini-batch
+         prediction.
+
+         :meth:`update` can be repeated multiple times to accumulate the results
+         of successive predictions, *e.g.*, inside a mini-batch training or
+         evaluation loop.
+
+         Args:
+             pred_index_mat (torch.Tensor): The top-:math:`k` predictions of
+                 every example in the mini-batch with shape
+                 :obj:`[batch_size, k]`.
+             edge_label_index (torch.Tensor): The ground-truth indices for every
+                 example in the mini-batch, given in COO format of shape
+                 :obj:`[2, num_ground_truth_indices]`.
+             edge_label_weight (torch.Tensor, optional): The weight of the
+                 ground-truth indices for every example in the mini-batch of
+                 shape :obj:`[num_ground_truth_indices]`. If given, needs to be
+                 a vector of positive values. Required for weighted metrics,
+                 ignored otherwise. (default: :obj:`None`)
+         """
+         if self.weighted and edge_label_weight is None:
+             raise ValueError(f"'edge_label_weight' is a required argument for "
+                              f"weighted '{self.__class__.__name__}' metrics")
+
+         data = LinkPredMetricData(  # Share metric data across metrics.
+             pred_index_mat=pred_index_mat,
+             edge_label_index=edge_label_index,
+             edge_label_weight=edge_label_weight,
+         )
+
+         for metric in self.values():
+             if isinstance(metric, LinkPredMetric) and metric.weighted:
+                 metric._update(data)
+                 if WITH_TORCHMETRICS:
+                     metric._update_count += 1
+
+         data.edge_label_weight = None
+         if hasattr(data, '_pred_rel_mat'):
+             data._pred_rel_mat = data._pred_rel_mat != 0.0
+         if hasattr(data, '_label_weight_sum'):
+             del data._label_weight_sum
+         if hasattr(data, '_edge_label_weight_pos'):
+             del data._edge_label_weight_pos
+
+         for metric in self.values():
+             if isinstance(metric, LinkPredMetric) and not metric.weighted:
+                 metric._update(data)
+                 if WITH_TORCHMETRICS:
+                     metric._update_count += 1
+
+         for metric in self.values():
+             if not isinstance(metric, LinkPredMetric):
+                 metric.update(  # type: ignore[operator]
+                     pred_index_mat,
+                     edge_label_index,
+                     edge_label_weight,
+                 )
+
+     def compute(self) -> Dict[str, Tensor]:
+         r"""Computes the final metric values."""
+         return {
+             name: metric.compute()  # type: ignore[operator]
+             for name, metric in self.items()
+         }
+
+     def reset(self) -> None:
+         r"""Reset metric state variables to their default value."""
+         for metric in self.values():
+             metric.reset()  # type: ignore[operator]
+
+     def __repr__(self) -> str:
+         names = [f'  {name}: {metric},\n' for name, metric in self.items()]
+         return f'{self.__class__.__name__}([\n{"".join(names)}])'


  class LinkPredPrecision(LinkPredMetric):
-     r"""A link prediction metric to compute Precision @ :math:`k`.
+     r"""A link prediction metric to compute Precision @ :math:`k`, *i.e.* the
+     proportion of recommendations within the top-:math:`k` that are actually
+     relevant.
+
+     A higher precision indicates the model's ability to surface relevant items
+     early in the ranking.

      Args:
          k (int): The number of top-:math:`k` predictions to evaluate against.
      """
      higher_is_better: bool = True
+     weighted: bool = False

-     def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
-         return pred_isin_mat.sum(dim=-1) / self.k
+     def _compute(self, data: LinkPredMetricData) -> Tensor:
+         pred_rel_mat = data.pred_rel_mat[:, :self.k]
+         return pred_rel_mat.sum(dim=-1) / self.k


  class LinkPredRecall(LinkPredMetric):
-     r"""A link prediction metric to compute Recall @ :math:`k`.
+     r"""A link prediction metric to compute Recall @ :math:`k`, *i.e.* the
+     proportion of relevant items that appear within the top-:math:`k`.
+
+     A higher recall indicates the model's ability to retrieve a larger
+     proportion of relevant items.

      Args:
          k (int): The number of top-:math:`k` predictions to evaluate against.
      """
      higher_is_better: bool = True

-     def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
-         return pred_isin_mat.sum(dim=-1) / y_count.clamp(min=1e-7)
+     def __init__(self, k: int, weighted: bool = False):
+         super().__init__(k=k)
+         self.weighted = weighted
+
+     def _compute(self, data: LinkPredMetricData) -> Tensor:
+         pred_rel_mat = data.pred_rel_mat[:, :self.k]
+         return pred_rel_mat.sum(dim=-1) / data.label_weight_sum.clamp(min=1e-7)


  class LinkPredF1(LinkPredMetric):
@@ -164,54 +486,96 @@ class LinkPredF1(LinkPredMetric):
          k (int): The number of top-:math:`k` predictions to evaluate against.
      """
      higher_is_better: bool = True
+     weighted: bool = False

-     def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
-         isin_count = pred_isin_mat.sum(dim=-1)
+     def _compute(self, data: LinkPredMetricData) -> Tensor:
+         pred_rel_mat = data.pred_rel_mat[:, :self.k]
+         isin_count = pred_rel_mat.sum(dim=-1)
          precision = isin_count / self.k
-         recall = isin_count = isin_count / y_count.clamp(min=1e-7)
+         recall = isin_count / data.label_count.clamp(min=1e-7)
          return 2 * precision * recall / (precision + recall).clamp(min=1e-7)


  class LinkPredMAP(LinkPredMetric):
      r"""A link prediction metric to compute MAP @ :math:`k` (Mean Average
-     Precision).
+     Precision), considering the order of relevant items within the
+     top-:math:`k`.
+
+     MAP @ :math:`k` can provide a more comprehensive view of ranking quality
+     than precision alone.

      Args:
          k (int): The number of top-:math:`k` predictions to evaluate against.
      """
      higher_is_better: bool = True
+     weighted: bool = False

-     def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
-         cum_precision = (torch.cumsum(pred_isin_mat, dim=1) /
-                          torch.arange(1, self.k + 1, device=y_count.device))
-         return ((cum_precision * pred_isin_mat).sum(dim=-1) /
-                 y_count.clamp(min=1e-7, max=self.k))
+     def _compute(self, data: LinkPredMetricData) -> Tensor:
+         pred_rel_mat = data.pred_rel_mat[:, :self.k]
+         device = pred_rel_mat.device
+         arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
+         cum_precision = pred_rel_mat.cumsum(dim=1) / arange
+         return ((cum_precision * pred_rel_mat).sum(dim=-1) /
+                 data.label_count.clamp(min=1e-7, max=self.k))


  class LinkPredNDCG(LinkPredMetric):
      r"""A link prediction metric to compute the NDCG @ :math:`k` (Normalized
      Discounted Cumulative Gain).

+     In particular, can account for the position of relevant items by
+     considering relevance scores, giving higher weight to more relevant items
+     appearing at the top.
+
      Args:
          k (int): The number of top-:math:`k` predictions to evaluate against.
+         weighted (bool, optional): If set to :obj:`True`, assumes sorted lists
+             of ground-truth items according to a relevance score as given by
+             :obj:`edge_label_weight`. (default: :obj:`False`)
      """
      higher_is_better: bool = True

-     def __init__(self, k: int):
+     def __init__(self, k: int, weighted: bool = False):
          super().__init__(k=k)
+         self.weighted = weighted

          dtype = torch.get_default_dtype()
-         multiplier = 1.0 / torch.arange(2, k + 2, dtype=dtype).log2()
+         discount = torch.arange(2, k + 2, dtype=dtype).log2()
+
+         self.discount: Tensor
+         self.register_buffer('discount', discount, persistent=False)

-         self.multiplier: Tensor
-         self.register_buffer('multiplier', multiplier)
+         if not weighted:
+             self.register_buffer('idcg', cumsum(1.0 / discount),
+                                  persistent=False)
+         else:
+             self.idcg = None

-         self.idcg: Tensor
-         self.register_buffer('idcg', cumsum(multiplier))
+     def _compute(self, data: LinkPredMetricData) -> Tensor:
+         pred_rel_mat = data.pred_rel_mat[:, :self.k]
+         discount = self.discount[:pred_rel_mat.size(1)].view(1, -1)
+         dcg = (pred_rel_mat / discount).sum(dim=-1)

-     def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
-         dcg = (pred_isin_mat * self.multiplier.view(1, -1)).sum(dim=-1)
-         idcg = self.idcg[y_count.clamp(max=self.k)]
+         if not self.weighted:
+             assert self.idcg is not None
+             idcg = self.idcg[data.label_count.clamp(max=self.k)]
+         else:
+             assert data.edge_label_weight is not None
+             pos = data.edge_label_weight_pos
+             assert pos is not None
+
+             discount = torch.cat([
+                 self.discount,
+                 self.discount.new_full((1, ), fill_value=float('inf')),
+             ])
+             discount = discount[pos.clamp(max=self.k)]
+
+             idcg = scatter(  # Apply discount and aggregate:
+                 data.edge_label_weight / discount,
+                 data.edge_label_index[0],
+                 dim_size=data.pred_index_mat.size(0),
+                 reduce='sum',
+             )

          out = dcg / idcg
          out[out.isnan() | out.isinf()] = 0.0
@@ -220,16 +584,305 @@ class LinkPredNDCG(LinkPredMetric):

  class LinkPredMRR(LinkPredMetric):
      r"""A link prediction metric to compute the MRR @ :math:`k` (Mean
-     Reciprocal Rank).
+     Reciprocal Rank), *i.e.* the mean reciprocal rank of the first correct
+     prediction (or zero otherwise).
+
+     Args:
+         k (int): The number of top-:math:`k` predictions to evaluate against.
+     """
+     higher_is_better: bool = True
+     weighted: bool = False
+
+     def _compute(self, data: LinkPredMetricData) -> Tensor:
+         pred_rel_mat = data.pred_rel_mat[:, :self.k]
+         device = pred_rel_mat.device
+         arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
+         return (pred_rel_mat / arange).max(dim=-1)[0]
+
+
+ class LinkPredHitRatio(LinkPredMetric):
+     r"""A link prediction metric to compute the hit ratio @ :math:`k`, *i.e.*
+     the percentage of users for whom at least one relevant item is present
+     within the top-:math:`k` recommendations.
+
+     A high ratio signifies the model's effectiveness in satisfying a broad
+     range of user preferences.
+     """
+     higher_is_better: bool = True
+     weighted: bool = False
+
+     def _compute(self, data: LinkPredMetricData) -> Tensor:
+         pred_rel_mat = data.pred_rel_mat[:, :self.k]
+         return pred_rel_mat.max(dim=-1)[0].to(torch.get_default_dtype())
+
+
+ class LinkPredCoverage(_LinkPredMetric):
+     r"""A link prediction metric to compute the Coverage @ :math:`k` of
+     predictions, *i.e.* the percentage of unique items recommended across all
+     users within the top-:math:`k`.
+
+     Higher coverage indicates a wider exploration of the item catalog.

      Args:
          k (int): The number of top-:math:`k` predictions to evaluate against.
+         num_dst_nodes (int): The total number of destination nodes.
      """
      higher_is_better: bool = True

-     def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
-         rank = pred_isin_mat.type(torch.uint8).argmax(dim=-1)
-         is_correct = pred_isin_mat.gather(1, rank.view(-1, 1)).view(-1)
-         reciprocals = 1.0 / (rank + 1)
-         reciprocals[~is_correct] = 0.0
-         return reciprocals
+     def __init__(self, k: int, num_dst_nodes: int) -> None:
+         super().__init__(k)
+         self.num_dst_nodes = num_dst_nodes
+
+         self.mask: Tensor
+         mask = torch.zeros(num_dst_nodes, dtype=torch.bool)
+         if WITH_TORCHMETRICS:
+             self.add_state('mask', mask, dist_reduce_fx='max')
+         else:
+             self.register_buffer('mask', mask, persistent=False)
+
+     def update(
+         self,
+         pred_index_mat: Tensor,
+         edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
+         edge_label_weight: Optional[Tensor] = None,
+     ) -> None:
+         self.mask[pred_index_mat[:, :self.k].flatten()] = True
+
+     def compute(self) -> Tensor:
+         return self.mask.to(torch.get_default_dtype()).mean()
+
+     def _reset(self) -> None:
+         self.mask.zero_()
+
+     def __repr__(self) -> str:
+         return (f'{self.__class__.__name__}(k={self.k}, '
+                 f'num_dst_nodes={self.num_dst_nodes})')
+
+
+ class LinkPredDiversity(_LinkPredMetric):
+     r"""A link prediction metric to compute the Diversity @ :math:`k` of
+     predictions according to item categories.
+
+     Diversity is computed as
+
+     .. math::
+         div_{u@k} = 1 - \left( \frac{1}{k \cdot (k-1)} \right) \sum_{i \neq j}
+         sim(i, j)
+
+     where
+
+     .. math::
+         sim(i,j) = \begin{cases}
+             1 & \quad \text{if } i,j \text{ share category,}\\
+             0 & \quad \text{otherwise.}
+         \end{cases}
+
+     which measures the pair-wise inequality of recommendations according to
+     item categories.
+
+     Args:
+         k (int): The number of top-:math:`k` predictions to evaluate against.
+         category (torch.Tensor): A vector that assigns each destination node to
+             a specific category.
+     """
+     higher_is_better: bool = True
+
+     def __init__(self, k: int, category: Tensor) -> None:
+         super().__init__(k)
+
+         self.accum: Tensor
+         self.total: Tensor
+
+         if WITH_TORCHMETRICS:
+             self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
+             self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
+         else:
+             self.register_buffer('accum', torch.tensor(0.), persistent=False)
+             self.register_buffer('total', torch.tensor(0), persistent=False)
+
+         self.category: Tensor
+         self.register_buffer('category', category, persistent=False)
+
+     def update(
+         self,
+         pred_index_mat: Tensor,
+         edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
+         edge_label_weight: Optional[Tensor] = None,
+     ) -> None:
+         category = self.category[pred_index_mat[:, :self.k]]
+
+         sim = (category.unsqueeze(-2) == category.unsqueeze(-1)).sum(dim=-1)
+         div = 1 - 1 / (self.k * (self.k - 1)) * (sim - 1).sum(dim=-1)
+
+         self.accum += div.sum()
+         self.total += pred_index_mat.size(0)
+
+     def compute(self) -> Tensor:
+         if self.total == 0:
+             return torch.zeros_like(self.accum)
+         return self.accum / self.total
+
+     def _reset(self) -> None:
+         self.accum.zero_()
+         self.total.zero_()
+
+
+ class LinkPredPersonalization(_LinkPredMetric):
+     r"""A link prediction metric to compute the Personalization @ :math:`k`,
+     *i.e.* the dissimilarity of recommendations across different users.
+
+     Higher personalization suggests that the model tailors recommendations to
+     individual user preferences rather than providing generic results.
+
+     Dissimilarity is defined by the average inverse cosine similarity between
+     users' lists of recommendations.
+
+     Args:
+         k (int): The number of top-:math:`k` predictions to evaluate against.
+         max_src_nodes (int, optional): The maximum source nodes to consider to
+             compute pair-wise dissimilarity. If specified,
+             Personalization @ :math:`k` is approximated to avoid computation
+             blowup due to quadratic complexity. (default: :obj:`2**12`)
+         batch_size (int, optional): The batch size to determine how many pairs
+             of user recommendations should be processed at once.
+             (default: :obj:`2**16`)
+     """
+     higher_is_better: bool = True
+
+     def __init__(
+         self,
+         k: int,
+         max_src_nodes: Optional[int] = 2**12,
+         batch_size: int = 2**16,
+     ) -> None:
+         super().__init__(k)
+         self.max_src_nodes = max_src_nodes
+         self.batch_size = batch_size
+
+         self.preds: List[Tensor]
+         self.total: Tensor
+
+         if WITH_TORCHMETRICS:
+             self.add_state('preds', default=[], dist_reduce_fx='cat')
+             self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
+         else:
+             self.preds = []
+             self.register_buffer('total', torch.tensor(0), persistent=False)
+
+     def update(
+         self,
+         pred_index_mat: Tensor,
+         edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
+         edge_label_weight: Optional[Tensor] = None,
+     ) -> None:
+
+         # NOTE Move to CPU to avoid memory blowup.
+         pred_index_mat = pred_index_mat[:, :self.k].cpu()
+
+         if self.max_src_nodes is None:
+             self.preds.append(pred_index_mat)
+             self.total += pred_index_mat.size(0)
+         elif self.total < self.max_src_nodes:
+             remaining = int(self.max_src_nodes - self.total)
+             pred_index_mat = pred_index_mat[:remaining]
+             self.preds.append(pred_index_mat)
+             self.total += pred_index_mat.size(0)
+
+     def compute(self) -> Tensor:
+         device = self.total.device
+         score = torch.tensor(0.0, device=device)
+         total = torch.tensor(0, device=device)
+
+         if len(self.preds) == 0:
+             return score
+
+         pred = torch.cat(self.preds, dim=0)
+
+         if pred.size(0) == 0:
+             return score
+
+         # Calculate all pairs of nodes (e.g., triu_indices with offset=1).
+         # NOTE We do this in chunks to avoid memory blow-up, which leads to a
+         # more efficient but trickier implementation.
+         num_pairs = (pred.size(0) * (pred.size(0) - 1)) // 2
+         offset = torch.arange(pred.size(0) - 1, 0, -1, device=device)
+         rowptr = cumsum(offset)
+         for start in range(0, num_pairs, self.batch_size):
+             end = min(start + self.batch_size, num_pairs)
+             idx = torch.arange(start, end, device=device)
+
+             # Find the corresponding row:
+             row = torch.searchsorted(rowptr, idx, right=True) - 1
+             # Find the corresponding column:
+             col = idx - rowptr[row] + (pred.size(0) - offset[row])
+
+             left = pred[row.cpu()].to(device)
+             right = pred[col.cpu()].to(device)
+
+             # Use offset to work around applying `isin` along a specific dim:
+             i = max(int(left.max()), int(right.max())) + 1
+             idx = torch.arange(0, i * row.size(0), i, device=device)
+             idx = idx.view(-1, 1)
+             isin = torch.isin(left + idx, right + idx)
+
+             # Compute personalization via average inverse cosine similarity:
+             cos = isin.sum(dim=-1) / pred.size(1)
+             score += (1 - cos).sum()
+             total += cos.numel()
+
+         return score / total
+
+     def _reset(self) -> None:
+         self.preds = []
+         self.total.zero_()
+
+
+ class LinkPredAveragePopularity(_LinkPredMetric):
+     r"""A link prediction metric to compute the Average Recommendation
+     Popularity (ARP) @ :math:`k`, which provides insights into the model's
+     tendency to recommend popular items by averaging the popularity scores of
+     items within the top-:math:`k` recommendations.
+
+     Args:
+         k (int): The number of top-:math:`k` predictions to evaluate against.
+         popularity (torch.Tensor): The popularity of every item in the training
+             set, *e.g.*, the number of times an item has been rated.
+     """
+     higher_is_better: bool = False
+
+     def __init__(self, k: int, popularity: Tensor) -> None:
+         super().__init__(k)
+
+         self.accum: Tensor
+         self.total: Tensor
+
+         if WITH_TORCHMETRICS:
+             self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
+             self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
+         else:
+             self.register_buffer('accum', torch.tensor(0.), persistent=False)
+             self.register_buffer('total', torch.tensor(0), persistent=False)
+
+         self.popularity: Tensor
+         self.register_buffer('popularity', popularity, persistent=False)
+
+     def update(
+         self,
+         pred_index_mat: Tensor,
+         edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
+         edge_label_weight: Optional[Tensor] = None,
+     ) -> None:
+         pred_index_mat = pred_index_mat[:, :self.k]
+         popularity = self.popularity[pred_index_mat]
+         popularity = popularity.to(self.accum.dtype).mean(dim=-1)
+         self.accum += popularity.sum()
+         self.total += popularity.numel()
+
+     def compute(self) -> Tensor:
+         if self.total == 0:
+             return torch.zeros_like(self.accum)
+         return self.accum / self.total
+
+     def _reset(self) -> None:
+         self.accum.zero_()
+         self.total.zero_()
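
For orientation, a minimal usage sketch of the reworked metrics API introduced by this diff, pieced together from the docstrings above. The tensor shapes and random inputs are illustrative only, and importing LinkPredNDCG alongside the documented classes from torch_geometric.metrics is an assumption based on the class definitions in this file rather than a verified export list.

    import torch

    # Assumed imports; the exact re-exports live in torch_geometric/metrics/__init__.py.
    from torch_geometric.metrics import (
        LinkPredMetricCollection,
        LinkPredNDCG,
        LinkPredPrecision,
        LinkPredRecall,
    )

    num_src, num_dst, k = 32, 100, 10

    # Top-k predicted destination indices per source node: [batch_size, k].
    pred_index_mat = torch.randint(num_dst, (num_src, k))

    # Ground-truth pairs in COO format: [2, num_ground_truth_indices].
    edge_label_index = torch.stack([
        torch.randint(num_src, (200, )),
        torch.randint(num_dst, (200, )),
    ])

    # Positive relevance scores, only consumed by weighted metrics:
    edge_label_weight = torch.rand(200) + 0.1

    metrics = LinkPredMetricCollection([
        LinkPredPrecision(k=k),
        LinkPredRecall(k=k),
        LinkPredNDCG(k=k, weighted=True),  # Consumes `edge_label_weight`.
    ])

    metrics.update(pred_index_mat, edge_label_index, edge_label_weight)
    print(metrics.compute())  # e.g., {'LinkPredPrecision@10': tensor(...), ...}
    metrics.reset()

Per the collection's update logic in this diff, weighted metrics consume the shared relevance data first; the relevance matrix is then binarized and reused for the unweighted metrics, so the weights only influence metrics constructed with weighted=True.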