pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -1,4 +1,4 @@
1
- from typing import Optional, Tuple, Union
1
+ from typing import Dict, List, Optional, Tuple, Union
2
2
 
3
3
  import torch
4
4
  from torch import Tensor
@@ -43,34 +43,15 @@ class LinkPredMetric(BaseMetric):
43
43
  self.register_buffer('accum', torch.tensor(0.))
44
44
  self.register_buffer('total', torch.tensor(0))
45
45
 
46
- def update(
47
- self,
46
+ @staticmethod
47
+ def _prepare(
48
48
  pred_index_mat: Tensor,
49
49
  edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
50
- ) -> None:
51
- r"""Updates the state variables based on the current mini-batch
52
- prediction.
53
-
54
- :meth:`update` can be repeated multiple times to accumulate the results
55
- of successive predictions, *e.g.*, inside a mini-batch training or
56
- evaluation loop.
57
-
58
- Args:
59
- pred_index_mat (torch.Tensor): The top-:math:`k` predictions of
60
- every example in the mini-batch with shape
61
- :obj:`[batch_size, k]`.
62
- edge_label_index (torch.Tensor): The ground-truth indices for every
63
- example in the mini-batch, given in COO format of shape
64
- :obj:`[2, num_ground_truth_indices]`.
65
- """
66
- if pred_index_mat.size(1) != self.k:
67
- raise ValueError(f"Expected 'pred_index_mat' to hold {self.k} "
68
- f"many indices for every entry "
69
- f"(got {pred_index_mat.size(1)})")
70
-
71
- # Compute a boolean matrix indicating if the k-th prediction is part of
72
- # the ground-truth. We do this by flattening both prediction and
73
- # target indices, and then determining overlaps via `torch.isin`.
50
+ ) -> Tuple[Tensor, Tensor]:
51
+ # Compute a boolean matrix indicating if the `k`-th prediction is part
52
+ # of the ground-truth, as well as the number of ground-truths for every
53
+ # example. We do this by flattening both prediction and ground-truth
54
+ # indices, and then determining overlaps via `torch.isin`.
74
55
  max_index = max( # type: ignore
75
56
  pred_index_mat.max() if pred_index_mat.numel() > 0 else 0,
76
57
  edge_label_index[1].max()
@@ -78,8 +59,8 @@ class LinkPredMetric(BaseMetric):
78
59
  ) + 1
79
60
  arange = torch.arange(
80
61
  start=0,
81
- end=max_index * pred_index_mat.size(0),
82
- step=max_index,
62
+ end=max_index * pred_index_mat.size(0), # type: ignore
63
+ step=max_index, # type: ignore
83
64
  device=pred_index_mat.device,
84
65
  ).view(-1, 1)
85
66
  flat_pred_index = (pred_index_mat + arange).view(-1)
@@ -88,7 +69,7 @@ class LinkPredMetric(BaseMetric):
88
69
  pred_isin_mat = torch.isin(flat_pred_index, flat_y_index)
89
70
  pred_isin_mat = pred_isin_mat.view(pred_index_mat.size())
90
71
 
91
- # Compute the number of targets per example:
72
+ # Compute the number of ground-truths per example:
92
73
  y_count = scatter(
93
74
  torch.ones_like(edge_label_index[0]),
94
75
  edge_label_index[0],
@@ -97,11 +78,41 @@ class LinkPredMetric(BaseMetric):
97
78
  reduce='sum',
98
79
  )
99
80
 
100
- metric = self._compute(pred_isin_mat, y_count)
81
+ return pred_isin_mat, y_count
101
82
 
83
+ def _update_from_prepared(
84
+ self,
85
+ pred_isin_mat: Tensor,
86
+ y_count: Tensor,
87
+ ) -> None:
88
+ metric = self._compute(pred_isin_mat[:, :self.k], y_count)
102
89
  self.accum += metric.sum()
103
90
  self.total += (y_count > 0).sum()
104
91
 
92
+ def update(
93
+ self,
94
+ pred_index_mat: Tensor,
95
+ edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
96
+ ) -> None:
97
+ r"""Updates the state variables based on the current mini-batch
98
+ prediction.
99
+
100
+ :meth:`update` can be repeated multiple times to accumulate the results
101
+ of successive predictions, *e.g.*, inside a mini-batch training or
102
+ evaluation loop.
103
+
104
+ Args:
105
+ pred_index_mat (torch.Tensor): The top-:math:`k` predictions of
106
+ every example in the mini-batch with shape
107
+ :obj:`[batch_size, k]`.
108
+ edge_label_index (torch.Tensor): The ground-truth indices for every
109
+ example in the mini-batch, given in COO format of shape
110
+ :obj:`[2, num_ground_truth_indices]`.
111
+ """
112
+ pred_isin_mat, y_count = self._prepare(pred_index_mat,
113
+ edge_label_index)
114
+ self._update_from_prepared(pred_isin_mat, y_count)
115
+
105
116
  def compute(self) -> Tensor:
106
117
  r"""Computes the final metric value."""
107
118
  if self.total == 0:
@@ -133,6 +144,103 @@ class LinkPredMetric(BaseMetric):
133
144
  return f'{self.__class__.__name__}(k={self.k})'
134
145
 
135
146
 
147
+ class LinkPredMetricCollection(torch.nn.ModuleDict):
148
+ r"""A collection of metrics to reduce and speed-up computation of link
149
+ prediction metrics.
150
+
151
+ .. code-block:: python
152
+
153
+ from torch_geometric.metrics import (
154
+ LinkPredMAP,
155
+ LinkPredMetricCollection,
156
+ LinkPredPrecision,
157
+ LinkPredRecall,
158
+ )
159
+
160
+ metrics = LinkPredMetricCollection([
161
+ LinkPredMAP(k=10),
162
+ LinkPredPrecision(k=100),
163
+ LinkPredRecall(k=50),
164
+ ])
165
+
166
+ metrics.update(pred_index_mat, edge_label_index)
167
+ out = metrics.compute()
168
+ metrics.reset()
169
+
170
+ print(out)
171
+ >>> {'LinkPredMAP@10': tensor(0.375),
172
+ ... 'LinkPredPrecision@100': tensor(0.127),
173
+ ... 'LinkPredRecall@50': tensor(0.483)}
174
+
175
+ Args:
176
+ metrics: The link prediction metrics.
177
+ """
178
+ def __init__(
179
+ self,
180
+ metrics: Union[
181
+ List[LinkPredMetric],
182
+ Dict[str, LinkPredMetric],
183
+ ],
184
+ ) -> None:
185
+ super().__init__()
186
+
187
+ if isinstance(metrics, (list, tuple)):
188
+ metrics = {
189
+ f'{metric.__class__.__name__}@{metric.k}': metric
190
+ for metric in metrics
191
+ }
192
+ assert len(metrics) > 0
193
+ assert isinstance(metrics, dict)
194
+
195
+ for name, metric in metrics.items():
196
+ self[name] = metric
197
+
198
+ @property
199
+ def max_k(self) -> int:
200
+ r"""The maximum number of top-:math:`k` predictions to evaluate
201
+ against.
202
+ """
203
+ return max([metric.k for metric in self.values()])
204
+
205
+ def update( # type: ignore
206
+ self,
207
+ pred_index_mat: Tensor,
208
+ edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
209
+ ) -> None:
210
+ r"""Updates the state variables based on the current mini-batch
211
+ prediction.
212
+
213
+ :meth:`update` can be repeated multiple times to accumulate the results
214
+ of successive predictions, *e.g.*, inside a mini-batch training or
215
+ evaluation loop.
216
+
217
+ Args:
218
+ pred_index_mat (torch.Tensor): The top-:math:`k` predictions of
219
+ every example in the mini-batch with shape
220
+ :obj:`[batch_size, k]`.
221
+ edge_label_index (torch.Tensor): The ground-truth indices for every
222
+ example in the mini-batch, given in COO format of shape
223
+ :obj:`[2, num_ground_truth_indices]`.
224
+ """
225
+ pred_isin_mat, y_count = LinkPredMetric._prepare(
226
+ pred_index_mat, edge_label_index)
227
+ for metric in self.values():
228
+ metric._update_from_prepared(pred_isin_mat, y_count)
229
+
230
+ def compute(self) -> Dict[str, Tensor]:
231
+ r"""Computes the final metric values."""
232
+ return {name: metric.compute() for name, metric in self.items()}
233
+
234
+ def reset(self) -> None:
235
+ r"""Reset metric state variables to their default value."""
236
+ for metric in self.values():
237
+ metric.reset()
238
+
239
+ def __repr__(self) -> str:
240
+ names = [f' {name}: {metric},\n' for name, metric in self.items()]
241
+ return f'{self.__class__.__name__}([\n{"".join(names)}])'
242
+
243
+
136
244
  class LinkPredPrecision(LinkPredMetric):
137
245
  r"""A link prediction metric to compute Precision @ :math:`k`.
138
246
 
@@ -182,8 +290,9 @@ class LinkPredMAP(LinkPredMetric):
182
290
  higher_is_better: bool = True
183
291
 
184
292
  def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
185
- cum_precision = (torch.cumsum(pred_isin_mat, dim=1) /
186
- torch.arange(1, self.k + 1, device=y_count.device))
293
+ device = pred_isin_mat.device
294
+ arange = torch.arange(1, pred_isin_mat.size(1) + 1, device=device)
295
+ cum_precision = pred_isin_mat.cumsum(dim=1) / arange
187
296
  return ((cum_precision * pred_isin_mat).sum(dim=-1) /
188
297
  y_count.clamp(min=1e-7, max=self.k))
189
298
 
@@ -210,9 +319,25 @@ class LinkPredNDCG(LinkPredMetric):
210
319
  self.register_buffer('idcg', cumsum(multiplier))
211
320
 
212
321
  def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
213
- dcg = (pred_isin_mat * self.multiplier.view(1, -1)).sum(dim=-1)
322
+ multiplier = self.multiplier[:pred_isin_mat.size(1)].view(1, -1)
323
+ dcg = (pred_isin_mat * multiplier).sum(dim=-1)
214
324
  idcg = self.idcg[y_count.clamp(max=self.k)]
215
325
 
216
326
  out = dcg / idcg
217
327
  out[out.isnan() | out.isinf()] = 0.0
218
328
  return out
329
+
330
+
331
+ class LinkPredMRR(LinkPredMetric):
332
+ r"""A link prediction metric to compute the MRR @ :math:`k` (Mean
333
+ Reciprocal Rank).
334
+
335
+ Args:
336
+ k (int): The number of top-:math:`k` predictions to evaluate against.
337
+ """
338
+ higher_is_better: bool = True
339
+
340
+ def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
341
+ device = pred_isin_mat.device
342
+ arange = torch.arange(1, pred_isin_mat.size(1) + 1, device=device)
343
+ return (pred_isin_mat / arange).max(dim=-1)[0]
@@ -24,6 +24,8 @@ from .mlp import MLPAggregation
24
24
  from .deep_sets import DeepSetsAggregation
25
25
  from .set_transformer import SetTransformerAggregation
26
26
  from .lcm import LCMAggregation
27
+ from .variance_preserving import VariancePreservingAggregation
28
+ from .patch_transformer import PatchTransformerAggregation
27
29
 
28
30
  __all__ = classes = [
29
31
  'Aggregation',
@@ -51,4 +53,6 @@ __all__ = classes = [
51
53
  'DeepSetsAggregation',
52
54
  'SetTransformerAggregation',
53
55
  'LCMAggregation',
56
+ 'VariancePreservingAggregation',
57
+ 'PatchTransformerAggregation',
54
58
  ]
@@ -65,8 +65,6 @@ class AttentionalAggregation(Aggregation):
65
65
  ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
66
66
  dim: int = -2) -> Tensor:
67
67
 
68
- self.assert_two_dimensional_input(x, dim)
69
-
70
68
  if self.gate_mlp is not None:
71
69
  gate = self.gate_mlp(x, batch=index, batch_size=dim_size)
72
70
  else:
@@ -25,7 +25,7 @@ class Aggregation(torch.nn.Module):
25
25
  Notably, :obj:`index` does not have to be sorted (for most aggregation
26
26
  operators):
27
27
 
28
- .. code-block::
28
+ .. code-block:: python
29
29
 
30
30
  # Feature matrix holding 10 elements with 64 features each:
31
31
  x = torch.randn(10, 64)
@@ -39,7 +39,7 @@ class Aggregation(torch.nn.Module):
39
39
  called :obj:`ptr`. Here, elements within the same set need to be grouped
40
40
  together in the input, and :obj:`ptr` defines their boundaries:
41
41
 
42
- .. code-block::
42
+ .. code-block:: python
43
43
 
44
44
  # Feature matrix holding 10 elements with 64 features each:
45
45
  x = torch.randn(10, 64)
@@ -94,11 +94,9 @@ class Aggregation(torch.nn.Module):
94
94
  max_num_elements: (int, optional): The maximum number of elements
95
95
  within a single aggregation group. (default: :obj:`None`)
96
96
  """
97
- pass
98
97
 
99
98
  def reset_parameters(self):
100
99
  r"""Resets all learnable parameters of the module."""
101
- pass
102
100
 
103
101
  @disable_dynamic_shapes(required_args=['dim_size'])
104
102
  def __call__(
@@ -0,0 +1,143 @@
1
+ import math
2
+ from typing import List, Optional, Union
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from torch_geometric.experimental import disable_dynamic_shapes
8
+ from torch_geometric.nn.aggr import Aggregation
9
+ from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
10
+ from torch_geometric.nn.encoding import PositionalEncoding
11
+ from torch_geometric.utils import scatter
12
+
13
+
14
+ class PatchTransformerAggregation(Aggregation):
15
+ r"""Performs patch transformer aggregation in which the elements to
16
+ aggregate are processed by multi-head attention blocks across patches, as
17
+ described in the `"Simplifying Temporal Heterogeneous Network for
18
+ Continuous-Time Link Prediction"
19
+ <https://dl.acm.org/doi/pdf/10.1145/3583780.3615059>`_ paper.
20
+
21
+ Args:
22
+ in_channels (int): Size of each input sample.
23
+ out_channels (int): Size of each output sample.
24
+ patch_size (int): Number of elements in a patch.
25
+ hidden_channels (int): Intermediate size of each sample.
26
+ num_transformer_blocks (int, optional): Number of transformer blocks
27
+ (default: :obj:`1`).
28
+ heads (int, optional): Number of multi-head-attentions.
29
+ (default: :obj:`1`)
30
+ dropout (float, optional): Dropout probability of attention weights.
31
+ (default: :obj:`0.0`)
32
+ aggr (str or list[str], optional): The aggregation module, *e.g.*,
33
+ :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
34
+ :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
35
+ """
36
+ def __init__(
37
+ self,
38
+ in_channels: int,
39
+ out_channels: int,
40
+ patch_size: int,
41
+ hidden_channels: int,
42
+ num_transformer_blocks: int = 1,
43
+ heads: int = 1,
44
+ dropout: float = 0.0,
45
+ aggr: Union[str, List[str]] = 'mean',
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ self.in_channels = in_channels
50
+ self.out_channels = out_channels
51
+ self.patch_size = patch_size
52
+ self.aggrs = [aggr] if isinstance(aggr, str) else aggr
53
+
54
+ assert len(self.aggrs) > 0
55
+ for aggr in self.aggrs:
56
+ assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']
57
+
58
+ self.lin = torch.nn.Linear(in_channels, hidden_channels)
59
+ self.pad_projector = torch.nn.Linear(
60
+ patch_size * hidden_channels,
61
+ hidden_channels,
62
+ )
63
+ self.pe = PositionalEncoding(hidden_channels)
64
+
65
+ self.blocks = torch.nn.ModuleList([
66
+ MultiheadAttentionBlock(
67
+ channels=hidden_channels,
68
+ heads=heads,
69
+ layer_norm=True,
70
+ dropout=dropout,
71
+ ) for _ in range(num_transformer_blocks)
72
+ ])
73
+
74
+ self.fc = torch.nn.Linear(
75
+ hidden_channels * len(self.aggrs),
76
+ out_channels,
77
+ )
78
+
79
+ def reset_parameters(self) -> None:
80
+ self.lin.reset_parameters()
81
+ self.pad_projector.reset_parameters()
82
+ self.pe.reset_parameters()
83
+ for block in self.blocks:
84
+ block.reset_parameters()
85
+ self.fc.reset_parameters()
86
+
87
+ @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
88
+ def forward(
89
+ self,
90
+ x: Tensor,
91
+ index: Tensor,
92
+ ptr: Optional[Tensor] = None,
93
+ dim_size: Optional[int] = None,
94
+ dim: int = -2,
95
+ max_num_elements: Optional[int] = None,
96
+ ) -> Tensor:
97
+
98
+ if max_num_elements is None:
99
+ if ptr is not None:
100
+ count = ptr.diff()
101
+ else:
102
+ count = scatter(torch.ones_like(index), index, dim=0,
103
+ dim_size=dim_size, reduce='sum')
104
+ max_num_elements = int(count.max()) + 1
105
+
106
+ # Set `max_num_elements` to a multiple of `patch_size`:
107
+ max_num_elements = (math.floor(max_num_elements / self.patch_size) *
108
+ self.patch_size)
109
+
110
+ x = self.lin(x)
111
+
112
+ # TODO If groups are heavily unbalanced, this will create a lot of
113
+ # "empty" patches. Try to figure out a way to fix this.
114
+ # [batch_size, num_patches * patch_size, hidden_channels]
115
+ x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
116
+ max_num_elements=max_num_elements)
117
+
118
+ # [batch_size, num_patches, patch_size * hidden_channels]
119
+ x = x.view(x.size(0), max_num_elements // self.patch_size,
120
+ self.patch_size * x.size(-1))
121
+
122
+ # [batch_size, num_patches, hidden_channels]
123
+ x = self.pad_projector(x)
124
+
125
+ x = x + self.pe(torch.arange(x.size(1), device=x.device))
126
+
127
+ # [batch_size, num_patches, hidden_channels]
128
+ for block in self.blocks:
129
+ x = block(x, x)
130
+
131
+ # [batch_size, hidden_channels]
132
+ outs: List[Tensor] = []
133
+ for aggr in self.aggrs:
134
+ out = getattr(torch, aggr)(x, dim=1)
135
+ outs.append(out[0] if isinstance(out, tuple) else out)
136
+ out = torch.cat(outs, dim=1) if len(outs) > 1 else outs[0]
137
+
138
+ # [batch_size, out_channels]
139
+ return self.fc(out)
140
+
141
+ def __repr__(self) -> str:
142
+ return (f'{self.__class__.__name__}({self.in_channels}, '
143
+ f'{self.out_channels}, patch_size={self.patch_size})')
@@ -38,7 +38,7 @@ class SetTransformerAggregation(Aggregation):
38
38
  (default: :obj:`1`)
39
39
  concat (bool, optional): If set to :obj:`False`, the seed embeddings
40
40
  are averaged instead of concatenated. (default: :obj:`True`)
41
- norm (str, optional): If set to :obj:`True`, will apply layer
41
+ layer_norm (str, optional): If set to :obj:`True`, will apply layer
42
42
  normalization. (default: :obj:`False`)
43
43
  dropout (float, optional): Dropout probability of attention weights.
44
44
  (default: :obj:`0`)
@@ -0,0 +1,33 @@
1
+ from typing import Optional
2
+
3
+ from torch import Tensor
4
+
5
+ from torch_geometric.nn.aggr import Aggregation
6
+ from torch_geometric.utils import degree
7
+ from torch_geometric.utils._scatter import broadcast
8
+
9
+
10
+ class VariancePreservingAggregation(Aggregation):
11
+ r"""Performs the Variance Preserving Aggregation (VPA) from the `"GNN-VPA:
12
+ A Variance-Preserving Aggregation Strategy for Graph Neural Networks"
13
+ <https://arxiv.org/abs/2403.04747>`_ paper.
14
+
15
+ .. math::
16
+ \mathrm{vpa}(\mathcal{X}) = \frac{1}{\sqrt{|\mathcal{X}|}}
17
+ \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i
18
+ """
19
+ def forward(self, x: Tensor, index: Optional[Tensor] = None,
20
+ ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
21
+ dim: int = -2) -> Tensor:
22
+
23
+ out = self.reduce(x, index, ptr, dim_size, dim, reduce='sum')
24
+
25
+ if ptr is not None:
26
+ count = ptr.diff().to(out.dtype)
27
+ else:
28
+ count = degree(index, dim_size, dtype=out.dtype)
29
+
30
+ count = count.sqrt().clamp(min=1.0)
31
+ count = broadcast(count, ref=out, dim=dim)
32
+
33
+ return out / count
@@ -1,3 +1,7 @@
1
1
  from .performer import PerformerAttention
2
+ from .qformer import QFormer
2
3
 
3
- __all__ = ['PerformerAttention']
4
+ __all__ = [
5
+ 'PerformerAttention',
6
+ 'QFormer',
7
+ ]
@@ -0,0 +1,71 @@
1
+ from typing import Callable
2
+
3
+ import torch
4
+
5
+
6
+ class QFormer(torch.nn.Module):
7
+ r"""The Querying Transformer (Q-Former) from
8
+ `"BLIP-2: Bootstrapping Language-Image Pre-training
9
+ with Frozen Image Encoders and Large Language Models"
10
+ <https://arxiv.org/pdf/2301.12597>`_ paper.
11
+
12
+ Args:
13
+ input_dim (int): The number of features in the input.
14
+ hidden_dim (int): The dimension of the fnn in the encoder layer.
15
+ output_dim (int): The final output dimension.
16
+ num_heads (int): The number of multi-attention-heads.
17
+ num_layers (int): The number of sub-encoder-layers in the encoder.
18
+ dropout (int): The dropout value in each encoder layer.
19
+
20
+
21
+ .. note::
22
+ This is a simplified version of the original Q-Former implementation.
23
+ """
24
+ def __init__(
25
+ self,
26
+ input_dim: int,
27
+ hidden_dim: int,
28
+ output_dim: int,
29
+ num_heads: int,
30
+ num_layers: int,
31
+ dropout: float = 0.0,
32
+ activation: Callable = torch.nn.ReLU(),
33
+ ) -> None:
34
+
35
+ super().__init__()
36
+ self.num_layers = num_layers
37
+ self.num_heads = num_heads
38
+
39
+ self.layer_norm = torch.nn.LayerNorm(input_dim)
40
+ self.encoder_layer = torch.nn.TransformerEncoderLayer(
41
+ d_model=input_dim,
42
+ nhead=num_heads,
43
+ dim_feedforward=hidden_dim,
44
+ dropout=dropout,
45
+ activation=activation,
46
+ batch_first=True,
47
+ )
48
+ self.encoder = torch.nn.TransformerEncoder(
49
+ self.encoder_layer,
50
+ num_layers=num_layers,
51
+ )
52
+ self.project = torch.nn.Linear(input_dim, output_dim)
53
+
54
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
55
+ r"""Forward pass.
56
+
57
+ Args:
58
+ x (torch.Tensor): Input sequence to the encoder layer.
59
+ :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
60
+ batch-size :math:`B`, sequence length :math:`N`,
61
+ and feature dimension :math:`F`.
62
+ """
63
+ x = self.layer_norm(x)
64
+ x = self.encoder(x)
65
+ out = self.project(x)
66
+ return out
67
+
68
+ def __repr__(self) -> str:
69
+ return (f'{self.__class__.__name__}('
70
+ f'num_heads={self.num_heads}, '
71
+ f'num_layers={self.num_layers})')
@@ -4,8 +4,8 @@ import torch
4
4
  from torch import Tensor
5
5
 
6
6
  from torch_geometric import EdgeIndex
7
+ from torch_geometric.index import ptr2index
7
8
  from torch_geometric.utils import is_torch_sparse_tensor
8
- from torch_geometric.utils.sparse import ptr2index
9
9
  from torch_geometric.typing import SparseTensor
10
10
 
11
11
 
@@ -98,13 +98,16 @@ def {{collect_name}}(
98
98
 
99
99
  {%- if 'edge_weight' in collect_param_dict and
100
100
  collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %}
101
- assert edge_weight is not None
101
+ if torch.jit.is_scripting():
102
+ assert edge_weight is not None
102
103
  {%- elif 'edge_attr' in collect_param_dict and
103
104
  collect_param_dict['edge_attr'].type_repr.endswith('Tensor') %}
104
- assert edge_attr is not None
105
+ if torch.jit.is_scripting():
106
+ assert edge_attr is not None
105
107
  {%- elif 'edge_type' in collect_param_dict and
106
108
  collect_param_dict['edge_type'].type_repr.endswith('Tensor') %}
107
- assert edge_type is not None
109
+ if torch.jit.is_scripting():
110
+ assert edge_type is not None
108
111
  {%- endif %}
109
112
 
110
113
  # Collect user-defined arguments:
@@ -7,12 +7,7 @@ from torch_geometric import EdgeIndex
7
7
 
8
8
  try: # pragma: no cover
9
9
  LEGACY_MODE = False
10
- from pylibcugraphops.pytorch import (
11
- SampledCSC,
12
- SampledHeteroCSC,
13
- StaticCSC,
14
- StaticHeteroCSC,
15
- )
10
+ from pylibcugraphops.pytorch import CSC, HeteroCSC
16
11
  HAS_PYLIBCUGRAPHOPS = True
17
12
  except ImportError:
18
13
  HAS_PYLIBCUGRAPHOPS = False
@@ -41,7 +36,6 @@ class CuGraphModule(torch.nn.Module): # pragma: no cover
41
36
 
42
37
  def reset_parameters(self):
43
38
  r"""Resets all learnable parameters of the module."""
44
- pass
45
39
 
46
40
  def get_cugraph(
47
41
  self,
@@ -79,12 +73,13 @@ class CuGraphModule(torch.nn.Module): # pragma: no cover
79
73
  return make_mfg_csr(dst_nodes, colptr, row, max_num_neighbors,
80
74
  num_src_nodes)
81
75
 
82
- return SampledCSC(colptr, row, max_num_neighbors, num_src_nodes)
76
+ return CSC(colptr, row, num_src_nodes,
77
+ dst_max_in_degree=max_num_neighbors)
83
78
 
84
79
  if LEGACY_MODE:
85
80
  return make_fg_csr(colptr, row)
86
81
 
87
- return StaticCSC(colptr, row)
82
+ return CSC(colptr, row, num_src_nodes=num_src_nodes)
88
83
 
89
84
  def get_typed_cugraph(
90
85
  self,
@@ -135,15 +130,16 @@ class CuGraphModule(torch.nn.Module): # pragma: no cover
135
130
  out_node_types=None, in_node_types=None,
136
131
  edge_types=edge_type)
137
132
 
138
- return SampledHeteroCSC(colptr, row, edge_type, max_num_neighbors,
139
- num_src_nodes, num_edge_types)
133
+ return HeteroCSC(colptr, row, edge_type, num_src_nodes,
134
+ num_edge_types,
135
+ dst_max_in_degree=max_num_neighbors)
140
136
 
141
137
  if LEGACY_MODE:
142
138
  return make_fg_csr_hg(colptr, row, n_node_types=0,
143
139
  n_edge_types=num_edge_types, node_types=None,
144
140
  edge_types=edge_type)
145
141
 
146
- return StaticHeteroCSC(colptr, row, edge_type, num_edge_types)
142
+ return HeteroCSC(colptr, row, edge_type, num_src_nodes, num_edge_types)
147
143
 
148
144
  def forward(
149
145
  self,
@@ -3,13 +3,14 @@ from typing import Callable, Optional, Union
3
3
  import torch
4
4
  from torch import Tensor
5
5
 
6
+ import torch_geometric.typing
6
7
  from torch_geometric.nn.conv import MessagePassing
7
8
  from torch_geometric.nn.inits import reset
8
9
  from torch_geometric.typing import Adj, OptTensor, PairOptTensor, PairTensor
9
10
 
10
- try:
11
+ if torch_geometric.typing.WITH_TORCH_CLUSTER:
11
12
  from torch_cluster import knn
12
- except ImportError:
13
+ else:
13
14
  knn = None
14
15
 
15
16