pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +13 -7
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +317 -65
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +3 -5
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +329 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +56 -22
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -1,4 +1,5 @@
1
- from typing import Optional, Tuple, Union
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Optional, Tuple, Union
2
3
 
3
4
  import torch
4
5
  from torch import Tensor
@@ -14,6 +15,76 @@ except Exception:
14
15
  BaseMetric = torch.nn.Module # type: ignore
15
16
 
16
17
 
18
+ @dataclass(repr=False)
19
+ class LinkPredMetricData:
20
+ pred_index_mat: Tensor
21
+ edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]
22
+ edge_label_weight: Optional[Tensor] = None
23
+
24
+ @property
25
+ def pred_rel_mat(self) -> Tensor:
26
+ r"""Returns a matrix indicating the relevance of the `k`-th prediction.
27
+ If :obj:`edge_label_weight` is not given, relevance will be denoted as
28
+ binary.
29
+ """
30
+ if hasattr(self, '_pred_rel_mat'):
31
+ return self._pred_rel_mat # type: ignore
32
+
33
+ # Flatten both prediction and ground-truth indices, and determine
34
+ # overlaps afterwards via `torch.searchsorted`.
35
+ max_index = max( # type: ignore
36
+ self.pred_index_mat.max()
37
+ if self.pred_index_mat.numel() > 0 else 0,
38
+ self.edge_label_index[1].max()
39
+ if self.edge_label_index[1].numel() > 0 else 0,
40
+ ) + 1
41
+ arange = torch.arange(
42
+ start=0,
43
+ end=max_index * self.pred_index_mat.size(0), # type: ignore
44
+ step=max_index, # type: ignore
45
+ device=self.pred_index_mat.device,
46
+ ).view(-1, 1)
47
+ flat_pred_index = (self.pred_index_mat + arange).view(-1)
48
+ flat_label_index = max_index * self.edge_label_index[0]
49
+ flat_label_index = flat_label_index + self.edge_label_index[1]
50
+ flat_label_index, perm = flat_label_index.sort()
51
+ edge_label_weight = self.edge_label_weight
52
+ if edge_label_weight is not None:
53
+ assert edge_label_weight.size() == self.edge_label_index[0].size()
54
+ edge_label_weight = edge_label_weight[perm]
55
+
56
+ pos = torch.searchsorted(flat_label_index, flat_pred_index)
57
+ pos = pos.clamp(max=flat_label_index.size(0) - 1) # Out-of-bounds.
58
+
59
+ pred_rel_mat = flat_label_index[pos] == flat_pred_index # Find matches
60
+ if edge_label_weight is not None:
61
+ pred_rel_mat = edge_label_weight[pos].where(
62
+ pred_rel_mat,
63
+ pred_rel_mat.new_zeros(1),
64
+ )
65
+ pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size())
66
+
67
+ self._pred_rel_mat = pred_rel_mat
68
+ return pred_rel_mat
69
+
70
+ @property
71
+ def label_count(self) -> Tensor:
72
+ r"""The number of ground-truth labels for every example."""
73
+ if hasattr(self, '_label_count'):
74
+ return self._label_count # type: ignore
75
+
76
+ label_count = scatter(
77
+ torch.ones_like(self.edge_label_index[0]),
78
+ self.edge_label_index[0],
79
+ dim=0,
80
+ dim_size=self.pred_index_mat.size(0),
81
+ reduce='sum',
82
+ )
83
+
84
+ self._label_count = label_count
85
+ return label_count
86
+
87
+
17
88
  class LinkPredMetric(BaseMetric):
18
89
  r"""An abstract class for computing link prediction retrieval metrics.
19
90
 
@@ -23,6 +94,7 @@ class LinkPredMetric(BaseMetric):
23
94
  is_differentiable: bool = False
24
95
  full_state_update: bool = False
25
96
  higher_is_better: Optional[bool] = None
97
+ weighted: bool = False
26
98
 
27
99
  def __init__(self, k: int) -> None:
28
100
  super().__init__()
@@ -47,6 +119,7 @@ class LinkPredMetric(BaseMetric):
47
119
  self,
48
120
  pred_index_mat: Tensor,
49
121
  edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
122
+ edge_label_weight: Optional[Tensor] = None,
50
123
  ) -> None:
51
124
  r"""Updates the state variables based on the current mini-batch
52
125
  prediction.
@@ -62,45 +135,30 @@ class LinkPredMetric(BaseMetric):
62
135
  edge_label_index (torch.Tensor): The ground-truth indices for every
63
136
  example in the mini-batch, given in COO format of shape
64
137
  :obj:`[2, num_ground_truth_indices]`.
138
+ edge_label_weight (torch.Tensor, optional): The weight of the
139
+ ground-truth indices for every example in the mini-batch of
140
+ shape :obj:`[num_ground_truth_indices]`. If given, needs to be
141
+ a vector of positive values. Required for weighted metrics,
142
+ ignored otherwise. (default: :obj:`None`)
65
143
  """
66
- if pred_index_mat.size(1) != self.k:
67
- raise ValueError(f"Expected 'pred_index_mat' to hold {self.k} "
68
- f"many indices for every entry "
69
- f"(got {pred_index_mat.size(1)})")
70
-
71
- # Compute a boolean matrix indicating if the k-th prediction is part of
72
- # the ground-truth. We do this by flattening both prediction and
73
- # target indices, and then determining overlaps via `torch.isin`.
74
- max_index = max( # type: ignore
75
- pred_index_mat.max() if pred_index_mat.numel() > 0 else 0,
76
- edge_label_index[1].max()
77
- if edge_label_index[1].numel() > 0 else 0,
78
- ) + 1
79
- arange = torch.arange(
80
- start=0,
81
- end=max_index * pred_index_mat.size(0),
82
- step=max_index,
83
- device=pred_index_mat.device,
84
- ).view(-1, 1)
85
- flat_pred_index = (pred_index_mat + arange).view(-1)
86
- flat_y_index = max_index * edge_label_index[0] + edge_label_index[1]
87
-
88
- pred_isin_mat = torch.isin(flat_pred_index, flat_y_index)
89
- pred_isin_mat = pred_isin_mat.view(pred_index_mat.size())
90
-
91
- # Compute the number of targets per example:
92
- y_count = scatter(
93
- torch.ones_like(edge_label_index[0]),
94
- edge_label_index[0],
95
- dim=0,
96
- dim_size=pred_index_mat.size(0),
97
- reduce='sum',
144
+ if self.weighted and edge_label_weight is None:
145
+ raise ValueError(f"'edge_label_weight' is a required argument for "
146
+ f"weighted '{self.__class__.__name__}' metrics")
147
+ if not self.weighted:
148
+ edge_label_weight = None
149
+
150
+ data = LinkPredMetricData(
151
+ pred_index_mat=pred_index_mat,
152
+ edge_label_index=edge_label_index,
153
+ edge_label_weight=edge_label_weight,
98
154
  )
155
+ self._update(data)
99
156
 
100
- metric = self._compute(pred_isin_mat, y_count)
157
+ def _update(self, data: LinkPredMetricData) -> None:
158
+ metric = self._compute(data)
101
159
 
102
160
  self.accum += metric.sum()
103
- self.total += (y_count > 0).sum()
161
+ self.total += (data.label_count > 0).sum()
104
162
 
105
163
  def compute(self) -> Tensor:
106
164
  r"""Computes the final metric value."""
@@ -109,28 +167,159 @@ class LinkPredMetric(BaseMetric):
109
167
  return self.accum / self.total
110
168
 
111
169
  def reset(self) -> None:
112
- r"""Reset metric state variables to their default value."""
170
+ r"""Resets metric state variables to their default value."""
113
171
  if WITH_TORCHMETRICS:
114
172
  super().reset()
115
173
  else:
116
174
  self.accum.zero_()
117
175
  self.total.zero_()
118
176
 
119
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
120
- r"""Compute the specific metric.
177
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
178
+ r"""Computes the specific metric.
121
179
  To be implemented separately for each metric class.
122
180
 
123
181
  Args:
124
- pred_isin_mat (torch.Tensor): A boolean matrix whose :obj:`(i,k)`
125
- element indicates if the :obj:`k`-th prediction for the
126
- :obj:`i`-th example is correct or not.
127
- y_count (torch.Tensor): A vector indicating the number of
128
- ground-truth labels for each example.
182
+ data (LinkPredMetricData): The mini-batch data for computing a link
183
+ prediction metric per example.
129
184
  """
130
185
  raise NotImplementedError
131
186
 
132
187
  def __repr__(self) -> str:
133
- return f'{self.__class__.__name__}(k={self.k})'
188
+ weighted_repr = ', weighted=True' if self.weighted else ''
189
+ return f'{self.__class__.__name__}(k={self.k}{weighted_repr})'
190
+
191
+
192
+ class LinkPredMetricCollection(torch.nn.ModuleDict):
193
+ r"""A collection of metrics to reduce and speed-up computation of link
194
+ prediction metrics.
195
+
196
+ .. code-block:: python
197
+
198
+ from torch_geometric.metrics import (
199
+ LinkPredMAP,
200
+ LinkPredMetricCollection,
201
+ LinkPredPrecision,
202
+ LinkPredRecall,
203
+ )
204
+
205
+ metrics = LinkPredMetricCollection([
206
+ LinkPredMAP(k=10),
207
+ LinkPredPrecision(k=100),
208
+ LinkPredRecall(k=50),
209
+ ])
210
+
211
+ metrics.update(pred_index_mat, edge_label_index)
212
+ out = metrics.compute()
213
+ metrics.reset()
214
+
215
+ print(out)
216
+ >>> {'LinkPredMAP@10': tensor(0.375),
217
+ ... 'LinkPredPrecision@100': tensor(0.127),
218
+ ... 'LinkPredRecall@50': tensor(0.483)}
219
+
220
+ Args:
221
+ metrics: The link prediction metrics.
222
+ """
223
+ def __init__(
224
+ self,
225
+ metrics: Union[
226
+ List[LinkPredMetric],
227
+ Dict[str, LinkPredMetric],
228
+ ],
229
+ ) -> None:
230
+ super().__init__()
231
+
232
+ if isinstance(metrics, (list, tuple)):
233
+ metrics = {
234
+ f'{metric.__class__.__name__}@{metric.k}': metric
235
+ for metric in metrics
236
+ }
237
+ assert len(metrics) > 0
238
+ assert isinstance(metrics, dict)
239
+
240
+ for name, metric in metrics.items():
241
+ self[name] = metric
242
+
243
+ @property
244
+ def max_k(self) -> int:
245
+ r"""The maximum number of top-:math:`k` predictions to evaluate
246
+ against.
247
+ """
248
+ return max([metric.k for metric in self.values()])
249
+
250
+ @property
251
+ def weighted(self) -> bool:
252
+ r"""Returns :obj:`True` in case the collection holds at least one
253
+ weighted link prediction metric.
254
+ """
255
+ return any([metric.weighted for metric in self.values()])
256
+
257
+ def update( # type: ignore
258
+ self,
259
+ pred_index_mat: Tensor,
260
+ edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
261
+ edge_label_weight: Optional[Tensor] = None,
262
+ ) -> None:
263
+ r"""Updates the state variables based on the current mini-batch
264
+ prediction.
265
+
266
+ :meth:`update` can be repeated multiple times to accumulate the results
267
+ of successive predictions, *e.g.*, inside a mini-batch training or
268
+ evaluation loop.
269
+
270
+ Args:
271
+ pred_index_mat (torch.Tensor): The top-:math:`k` predictions of
272
+ every example in the mini-batch with shape
273
+ :obj:`[batch_size, k]`.
274
+ edge_label_index (torch.Tensor): The ground-truth indices for every
275
+ example in the mini-batch, given in COO format of shape
276
+ :obj:`[2, num_ground_truth_indices]`.
277
+ edge_label_weight (torch.Tensor, optional): The weight of the
278
+ ground-truth indices for every example in the mini-batch of
279
+ shape :obj:`[num_ground_truth_indices]`. If given, needs to be
280
+ a vector of positive values. Required for weighted metrics,
281
+ ignored otherwise. (default: :obj:`None`)
282
+ """
283
+ if self.weighted and edge_label_weight is None:
284
+ raise ValueError(f"'edge_label_weight' is a required argument for "
285
+ f"weighted '{self.__class__.__name__}' metrics")
286
+ if not self.weighted:
287
+ edge_label_weight = None
288
+
289
+ data = LinkPredMetricData( # Share metric data across metrics.
290
+ pred_index_mat=pred_index_mat,
291
+ edge_label_index=edge_label_index,
292
+ edge_label_weight=edge_label_weight,
293
+ )
294
+
295
+ for metric in self.values():
296
+ if metric.weighted:
297
+ metric._update(data)
298
+ if WITH_TORCHMETRICS:
299
+ metric._update_count += 1
300
+
301
+ data.edge_label_weight = None
302
+ if hasattr(data, '_pred_rel_mat'):
303
+ data._pred_rel_mat = data._pred_rel_mat != 0.0
304
+
305
+ for metric in self.values():
306
+ if not metric.weighted:
307
+ metric._update(data)
308
+ if WITH_TORCHMETRICS:
309
+ metric._update_count += 1
310
+
311
+ def compute(self) -> Dict[str, Tensor]:
312
+ r"""Computes the final metric values."""
313
+ return {name: metric.compute() for name, metric in self.items()}
314
+
315
+ def reset(self) -> None:
316
+ r"""Reset metric state variables to their default value."""
317
+ for metric in self.values():
318
+ metric.reset()
319
+
320
+ def __repr__(self) -> str:
321
+ names = [f' {name}: {metric},\n' for name, metric in self.items()]
322
+ return f'{self.__class__.__name__}([\n{"".join(names)}])'
134
323
 
135
324
 
136
325
  class LinkPredPrecision(LinkPredMetric):
@@ -140,9 +329,11 @@ class LinkPredPrecision(LinkPredMetric):
140
329
  k (int): The number of top-:math:`k` predictions to evaluate against.
141
330
  """
142
331
  higher_is_better: bool = True
332
+ weighted: bool = False
143
333
 
144
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
145
- return pred_isin_mat.sum(dim=-1) / self.k
334
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
335
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
336
+ return pred_rel_mat.sum(dim=-1) / self.k
146
337
 
147
338
 
148
339
  class LinkPredRecall(LinkPredMetric):
@@ -152,9 +343,11 @@ class LinkPredRecall(LinkPredMetric):
152
343
  k (int): The number of top-:math:`k` predictions to evaluate against.
153
344
  """
154
345
  higher_is_better: bool = True
346
+ weighted: bool = False
155
347
 
156
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
157
- return pred_isin_mat.sum(dim=-1) / y_count.clamp(min=1e-7)
348
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
349
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
350
+ return pred_rel_mat.sum(dim=-1) / data.label_count.clamp(min=1e-7)
158
351
 
159
352
 
160
353
  class LinkPredF1(LinkPredMetric):
@@ -164,11 +357,13 @@ class LinkPredF1(LinkPredMetric):
164
357
  k (int): The number of top-:math:`k` predictions to evaluate against.
165
358
  """
166
359
  higher_is_better: bool = True
360
+ weighted: bool = False
167
361
 
168
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
169
- isin_count = pred_isin_mat.sum(dim=-1)
362
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
363
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
364
+ isin_count = pred_rel_mat.sum(dim=-1)
170
365
  precision = isin_count / self.k
171
- recall = isin_count = isin_count / y_count.clamp(min=1e-7)
366
+ recall = isin_count / data.label_count.clamp(min=1e-7)
172
367
  return 2 * precision * recall / (precision + recall).clamp(min=1e-7)
173
368
 
174
369
 
@@ -180,12 +375,15 @@ class LinkPredMAP(LinkPredMetric):
180
375
  k (int): The number of top-:math:`k` predictions to evaluate against.
181
376
  """
182
377
  higher_is_better: bool = True
378
+ weighted: bool = False
183
379
 
184
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
185
- cum_precision = (torch.cumsum(pred_isin_mat, dim=1) /
186
- torch.arange(1, self.k + 1, device=y_count.device))
187
- return ((cum_precision * pred_isin_mat).sum(dim=-1) /
188
- y_count.clamp(min=1e-7, max=self.k))
380
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
381
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
382
+ device = pred_rel_mat.device
383
+ arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
384
+ cum_precision = pred_rel_mat.cumsum(dim=1) / arange
385
+ return ((cum_precision * pred_rel_mat).sum(dim=-1) /
386
+ data.label_count.clamp(min=1e-7, max=self.k))
189
387
 
190
388
 
191
389
  class LinkPredNDCG(LinkPredMetric):
@@ -194,25 +392,79 @@ class LinkPredNDCG(LinkPredMetric):
194
392
 
195
393
  Args:
196
394
  k (int): The number of top-:math:`k` predictions to evaluate against.
395
+ weighted (bool, optional): If set to :obj:`True`, assumes sorted lists
396
+ of ground-truth items according to a relevance score as given by
397
+ :obj:`edge_label_weight`. (default: :obj:`False`)
197
398
  """
198
399
  higher_is_better: bool = True
400
+ weighted: bool = False
199
401
 
200
- def __init__(self, k: int):
402
+ def __init__(self, k: int, weighted: bool = False):
201
403
  super().__init__(k=k)
404
+ self.weighted = weighted
202
405
 
203
406
  dtype = torch.get_default_dtype()
204
- multiplier = 1.0 / torch.arange(2, k + 2, dtype=dtype).log2()
407
+ discount = torch.arange(2, k + 2, dtype=dtype).log2()
408
+
409
+ self.discount: Tensor
410
+ self.register_buffer('discount', discount)
205
411
 
206
- self.multiplier: Tensor
207
- self.register_buffer('multiplier', multiplier)
412
+ if not weighted:
413
+ self.register_buffer('idcg', cumsum(1.0 / discount))
414
+ else:
415
+ self.idcg = None
208
416
 
209
- self.idcg: Tensor
210
- self.register_buffer('idcg', cumsum(multiplier))
417
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
418
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
419
+ discount = self.discount[:pred_rel_mat.size(1)].view(1, -1)
420
+ dcg = (pred_rel_mat / discount).sum(dim=-1)
211
421
 
212
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
213
- dcg = (pred_isin_mat * self.multiplier.view(1, -1)).sum(dim=-1)
214
- idcg = self.idcg[y_count.clamp(max=self.k)]
422
+ if not self.weighted:
423
+ assert self.idcg is not None
424
+ idcg = self.idcg[data.label_count.clamp(max=self.k)]
425
+ else:
426
+ assert data.edge_label_weight is not None
427
+ # Sort weights within example-wise buckets via two sorts to get the
428
+ # local index order within buckets:
429
+ weight, batch = data.edge_label_weight, data.edge_label_index[0]
430
+ perm1 = weight.argsort(descending=True)
431
+ perm2 = batch[perm1].argsort(stable=True)
432
+ global_index = torch.empty_like(perm1)
433
+ global_index[perm1[perm2]] = torch.arange(
434
+ global_index.size(0), device=global_index.device)
435
+ local_index = global_index - cumsum(data.label_count)[batch]
436
+
437
+ # Get the discount per local index:
438
+ discount = torch.cat([
439
+ self.discount,
440
+ self.discount.new_full((1, ), fill_value=float('inf')),
441
+ ])
442
+ discount = discount[local_index.clamp(max=self.k + 1)]
443
+
444
+ idcg = scatter( # Apply discount and aggregate:
445
+ weight / discount,
446
+ batch,
447
+ dim_size=data.pred_index_mat.size(0),
448
+ reduce='sum',
449
+ )
215
450
 
216
451
  out = dcg / idcg
217
452
  out[out.isnan() | out.isinf()] = 0.0
218
453
  return out
454
+
455
+
456
+ class LinkPredMRR(LinkPredMetric):
457
+ r"""A link prediction metric to compute the MRR @ :math:`k` (Mean
458
+ Reciprocal Rank).
459
+
460
+ Args:
461
+ k (int): The number of top-:math:`k` predictions to evaluate against.
462
+ """
463
+ higher_is_better: bool = True
464
+ weighted: bool = False
465
+
466
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
467
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
468
+ device = pred_rel_mat.device
469
+ arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
470
+ return (pred_rel_mat / arange).max(dim=-1)[0]
@@ -24,6 +24,8 @@ from .mlp import MLPAggregation
24
24
  from .deep_sets import DeepSetsAggregation
25
25
  from .set_transformer import SetTransformerAggregation
26
26
  from .lcm import LCMAggregation
27
+ from .variance_preserving import VariancePreservingAggregation
28
+ from .patch_transformer import PatchTransformerAggregation
27
29
 
28
30
  __all__ = classes = [
29
31
  'Aggregation',
@@ -51,4 +53,6 @@ __all__ = classes = [
51
53
  'DeepSetsAggregation',
52
54
  'SetTransformerAggregation',
53
55
  'LCMAggregation',
56
+ 'VariancePreservingAggregation',
57
+ 'PatchTransformerAggregation',
54
58
  ]
@@ -65,8 +65,6 @@ class AttentionalAggregation(Aggregation):
65
65
  ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
66
66
  dim: int = -2) -> Tensor:
67
67
 
68
- self.assert_two_dimensional_input(x, dim)
69
-
70
68
  if self.gate_mlp is not None:
71
69
  gate = self.gate_mlp(x, batch=index, batch_size=dim_size)
72
70
  else:
@@ -25,7 +25,7 @@ class Aggregation(torch.nn.Module):
25
25
  Notably, :obj:`index` does not have to be sorted (for most aggregation
26
26
  operators):
27
27
 
28
- .. code-block::
28
+ .. code-block:: python
29
29
 
30
30
  # Feature matrix holding 10 elements with 64 features each:
31
31
  x = torch.randn(10, 64)
@@ -39,7 +39,7 @@ class Aggregation(torch.nn.Module):
39
39
  called :obj:`ptr`. Here, elements within the same set need to be grouped
40
40
  together in the input, and :obj:`ptr` defines their boundaries:
41
41
 
42
- .. code-block::
42
+ .. code-block:: python
43
43
 
44
44
  # Feature matrix holding 10 elements with 64 features each:
45
45
  x = torch.randn(10, 64)
@@ -47,7 +47,7 @@ class Aggregation(torch.nn.Module):
47
47
  # Define the boundary indices for three sets:
48
48
  ptr = torch.tensor([0, 4, 7, 10])
49
49
 
50
- output = aggr(x, ptr=ptr) # Output shape: [4, 64]
50
+ output = aggr(x, ptr=ptr) # Output shape: [3, 64]
51
51
 
52
52
  Note that at least one of :obj:`index` or :obj:`ptr` must be defined.
53
53
 
@@ -94,11 +94,9 @@ class Aggregation(torch.nn.Module):
94
94
  max_num_elements: (int, optional): The maximum number of elements
95
95
  within a single aggregation group. (default: :obj:`None`)
96
96
  """
97
- pass
98
97
 
99
98
  def reset_parameters(self):
100
99
  r"""Resets all learnable parameters of the module."""
101
- pass
102
100
 
103
101
  @disable_dynamic_shapes(required_args=['dim_size'])
104
102
  def __call__(