pyg-nightly 2.7.0.dev20250114__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20250114
3
+ Version: 2.7.0.dev20250115
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=OePuhKBlW6WPSGttNbtvnSX0xm9ofblbpOm1oj7VB8E,1904
1
+ torch_geometric/__init__.py,sha256=QkrqTHL4gTBMa218nYocVEBizEOwmVBdqBkFozuzk4w,1904
2
2
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -288,7 +288,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_
288
288
  torch_geometric/loader/utils.py,sha256=f27mczQ7fEP2HpTsJGJxKS0slPu0j8zTba3jP8ViNck,14901
289
289
  torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
290
290
  torch_geometric/metrics/__init__.py,sha256=xHDTWEG4kdv9xb5pGPlRfQjC5P-ZGbhJ0xDe3YNq3ss,393
291
- torch_geometric/metrics/link_pred.py,sha256=UTFxnRJw6bu2AMjgM_nN14g33W2cHBaoMFcT-Tglj6c,11653
291
+ torch_geometric/metrics/link_pred.py,sha256=6nd929rWVmSWpFaRJ1u9OSL0VXndr7Pggce4Ynz5UG8,16799
292
292
  torch_geometric/nn/__init__.py,sha256=RrWRzEoqtR3lsO2lAzYXboLPb3uYEX2z3tLxiBIVWjc,847
293
293
  torch_geometric/nn/data_parallel.py,sha256=lDAxRi83UNuzAQSj3eu9K2sQheOIU6wqR5elS6oDs90,4764
294
294
  torch_geometric/nn/encoding.py,sha256=QNjwWczYExZ1wRGBmpuqYbn6tB7NC4BU-DEgzjhcZqw,3115
@@ -459,7 +459,7 @@ torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6
459
459
  torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
460
460
  torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
461
461
  torch_geometric/nn/nlp/__init__.py,sha256=q6CPUiJHcc9bXw90lyj-ID4F3kfW8uPM-SOxW9uCMHs,213
462
- torch_geometric/nn/nlp/llm.py,sha256=M15Qn0yHyA6HL2rHCH2p4H6hKjUvLfnzlxdfEFvRxSA,11732
462
+ torch_geometric/nn/nlp/llm.py,sha256=vcFvqW-veEfVZDLSHKFKXY-1k0TbiOzmf3LZIwIA0zM,12146
463
463
  torch_geometric/nn/nlp/sentence_transformer.py,sha256=q5M7SGtrUzoSiNhKCGFb7JatWiukdhNF6zdq2yiqxwE,4475
464
464
  torch_geometric/nn/nlp/vision_transformer.py,sha256=diVBefjIynzYs8WBlcpTeSVnw1PUecHY--B9Yd-W2hA,863
465
465
  torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
@@ -629,6 +629,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
629
629
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
630
630
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
631
631
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
632
- pyg_nightly-2.7.0.dev20250114.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
633
- pyg_nightly-2.7.0.dev20250114.dist-info/METADATA,sha256=XZA1HPdycLk4Y21bu651u5e0tIDnkLgRfRn_FQgbIXc,62977
634
- pyg_nightly-2.7.0.dev20250114.dist-info/RECORD,,
632
+ pyg_nightly-2.7.0.dev20250115.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
633
+ pyg_nightly-2.7.0.dev20250115.dist-info/METADATA,sha256=en3D5pZ3YXy64UdH3FG_SWtiKWcICTUL5Vl7Dhr-VC8,62977
634
+ pyg_nightly-2.7.0.dev20250115.dist-info/RECORD,,
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
30
30
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
31
31
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
32
32
 
33
- __version__ = '2.7.0.dev20250114'
33
+ __version__ = '2.7.0.dev20250115'
34
34
 
35
35
  __all__ = [
36
36
  'Index',
@@ -1,3 +1,4 @@
1
+ from dataclasses import dataclass
1
2
  from typing import Dict, List, Optional, Tuple, Union
2
3
 
3
4
  import torch
@@ -14,6 +15,76 @@ except Exception:
14
15
  BaseMetric = torch.nn.Module # type: ignore
15
16
 
16
17
 
18
+ @dataclass(repr=False)
19
+ class LinkPredMetricData:
20
+ pred_index_mat: Tensor
21
+ edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]
22
+ edge_label_weight: Optional[Tensor] = None
23
+
24
+ @property
25
+ def pred_rel_mat(self) -> Tensor:
26
+ r"""Returns a matrix indicating the relevance of the `k`-th prediction.
27
+ If :obj:`edge_label_weight` is not given, relevance will be denoted as
28
+ binary.
29
+ """
30
+ if hasattr(self, '_pred_rel_mat'):
31
+ return self._pred_rel_mat # type: ignore
32
+
33
+ # Flatten both prediction and ground-truth indices, and determine
34
+ # overlaps afterwards via `torch.searchsorted`.
35
+ max_index = max( # type: ignore
36
+ self.pred_index_mat.max()
37
+ if self.pred_index_mat.numel() > 0 else 0,
38
+ self.edge_label_index[1].max()
39
+ if self.edge_label_index[1].numel() > 0 else 0,
40
+ ) + 1
41
+ arange = torch.arange(
42
+ start=0,
43
+ end=max_index * self.pred_index_mat.size(0), # type: ignore
44
+ step=max_index, # type: ignore
45
+ device=self.pred_index_mat.device,
46
+ ).view(-1, 1)
47
+ flat_pred_index = (self.pred_index_mat + arange).view(-1)
48
+ flat_label_index = max_index * self.edge_label_index[0]
49
+ flat_label_index = flat_label_index + self.edge_label_index[1]
50
+ flat_label_index, perm = flat_label_index.sort()
51
+ edge_label_weight = self.edge_label_weight
52
+ if edge_label_weight is not None:
53
+ assert edge_label_weight.size() == self.edge_label_index[0].size()
54
+ edge_label_weight = edge_label_weight[perm]
55
+
56
+ pos = torch.searchsorted(flat_label_index, flat_pred_index)
57
+ pos = pos.clamp(max=flat_label_index.size(0) - 1) # Out-of-bounds.
58
+
59
+ pred_rel_mat = flat_label_index[pos] == flat_pred_index # Find matches
60
+ if edge_label_weight is not None:
61
+ pred_rel_mat = edge_label_weight[pos].where(
62
+ pred_rel_mat,
63
+ pred_rel_mat.new_zeros(1),
64
+ )
65
+ pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size())
66
+
67
+ self._pred_rel_mat = pred_rel_mat
68
+ return pred_rel_mat
69
+
70
+ @property
71
+ def label_count(self) -> Tensor:
72
+ r"""The number of ground-truth labels for every example."""
73
+ if hasattr(self, '_label_count'):
74
+ return self._label_count # type: ignore
75
+
76
+ label_count = scatter(
77
+ torch.ones_like(self.edge_label_index[0]),
78
+ self.edge_label_index[0],
79
+ dim=0,
80
+ dim_size=self.pred_index_mat.size(0),
81
+ reduce='sum',
82
+ )
83
+
84
+ self._label_count = label_count
85
+ return label_count
86
+
87
+
17
88
  class LinkPredMetric(BaseMetric):
18
89
  r"""An abstract class for computing link prediction retrieval metrics.
19
90
 
@@ -23,6 +94,7 @@ class LinkPredMetric(BaseMetric):
23
94
  is_differentiable: bool = False
24
95
  full_state_update: bool = False
25
96
  higher_is_better: Optional[bool] = None
97
+ weighted: bool = False
26
98
 
27
99
  def __init__(self, k: int) -> None:
28
100
  super().__init__()
@@ -43,56 +115,11 @@ class LinkPredMetric(BaseMetric):
43
115
  self.register_buffer('accum', torch.tensor(0.))
44
116
  self.register_buffer('total', torch.tensor(0))
45
117
 
46
- @staticmethod
47
- def _prepare(
48
- pred_index_mat: Tensor,
49
- edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
50
- ) -> Tuple[Tensor, Tensor]:
51
- # Compute a boolean matrix indicating if the `k`-th prediction is part
52
- # of the ground-truth, as well as the number of ground-truths for every
53
- # example. We do this by flattening both prediction and ground-truth
54
- # indices, and then determining overlaps via `torch.isin`.
55
- max_index = max( # type: ignore
56
- pred_index_mat.max() if pred_index_mat.numel() > 0 else 0,
57
- edge_label_index[1].max()
58
- if edge_label_index[1].numel() > 0 else 0,
59
- ) + 1
60
- arange = torch.arange(
61
- start=0,
62
- end=max_index * pred_index_mat.size(0), # type: ignore
63
- step=max_index, # type: ignore
64
- device=pred_index_mat.device,
65
- ).view(-1, 1)
66
- flat_pred_index = (pred_index_mat + arange).view(-1)
67
- flat_y_index = max_index * edge_label_index[0] + edge_label_index[1]
68
-
69
- pred_isin_mat = torch.isin(flat_pred_index, flat_y_index)
70
- pred_isin_mat = pred_isin_mat.view(pred_index_mat.size())
71
-
72
- # Compute the number of ground-truths per example:
73
- y_count = scatter(
74
- torch.ones_like(edge_label_index[0]),
75
- edge_label_index[0],
76
- dim=0,
77
- dim_size=pred_index_mat.size(0),
78
- reduce='sum',
79
- )
80
-
81
- return pred_isin_mat, y_count
82
-
83
- def _update_from_prepared(
84
- self,
85
- pred_isin_mat: Tensor,
86
- y_count: Tensor,
87
- ) -> None:
88
- metric = self._compute(pred_isin_mat[:, :self.k], y_count)
89
- self.accum += metric.sum()
90
- self.total += (y_count > 0).sum()
91
-
92
118
  def update(
93
119
  self,
94
120
  pred_index_mat: Tensor,
95
121
  edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
122
+ edge_label_weight: Optional[Tensor] = None,
96
123
  ) -> None:
97
124
  r"""Updates the state variables based on the current mini-batch
98
125
  prediction.
@@ -108,10 +135,30 @@ class LinkPredMetric(BaseMetric):
108
135
  edge_label_index (torch.Tensor): The ground-truth indices for every
109
136
  example in the mini-batch, given in COO format of shape
110
137
  :obj:`[2, num_ground_truth_indices]`.
138
+ edge_label_weight (torch.Tensor, optional): The weight of the
139
+ ground-truth indices for every example in the mini-batch of
140
+ shape :obj:`[num_ground_truth_indices]`. If given, needs to be
141
+ a vector of positive values. Required for weighted metrics,
142
+ ignored otherwise. (default: :obj:`None`)
111
143
  """
112
- pred_isin_mat, y_count = self._prepare(pred_index_mat,
113
- edge_label_index)
114
- self._update_from_prepared(pred_isin_mat, y_count)
144
+ if self.weighted and edge_label_weight is None:
145
+ raise ValueError(f"'edge_label_weight' is a required argument for "
146
+ f"weighted '{self.__class__.__name__}' metrics")
147
+ if not self.weighted:
148
+ edge_label_weight = None
149
+
150
+ data = LinkPredMetricData(
151
+ pred_index_mat=pred_index_mat,
152
+ edge_label_index=edge_label_index,
153
+ edge_label_weight=edge_label_weight,
154
+ )
155
+ self._update(data)
156
+
157
+ def _update(self, data: LinkPredMetricData) -> None:
158
+ metric = self._compute(data)
159
+
160
+ self.accum += metric.sum()
161
+ self.total += (data.label_count > 0).sum()
115
162
 
116
163
  def compute(self) -> Tensor:
117
164
  r"""Computes the final metric value."""
@@ -120,28 +167,26 @@ class LinkPredMetric(BaseMetric):
120
167
  return self.accum / self.total
121
168
 
122
169
  def reset(self) -> None:
123
- r"""Reset metric state variables to their default value."""
170
+ r"""Resets metric state variables to their default value."""
124
171
  if WITH_TORCHMETRICS:
125
172
  super().reset()
126
173
  else:
127
174
  self.accum.zero_()
128
175
  self.total.zero_()
129
176
 
130
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
131
- r"""Compute the specific metric.
177
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
178
+ r"""Computes the specific metric.
132
179
  To be implemented separately for each metric class.
133
180
 
134
181
  Args:
135
- pred_isin_mat (torch.Tensor): A boolean matrix whose :obj:`(i,k)`
136
- element indicates if the :obj:`k`-th prediction for the
137
- :obj:`i`-th example is correct or not.
138
- y_count (torch.Tensor): A vector indicating the number of
139
- ground-truth labels for each example.
182
+ data (LinkPredMetricData): The mini-batch data for computing a link
183
+ prediction metric per example.
140
184
  """
141
185
  raise NotImplementedError
142
186
 
143
187
  def __repr__(self) -> str:
144
- return f'{self.__class__.__name__}(k={self.k})'
188
+ weighted_repr = ', weighted=True' if self.weighted else ''
189
+ return f'{self.__class__.__name__}(k={self.k}{weighted_repr})'
145
190
 
146
191
 
147
192
  class LinkPredMetricCollection(torch.nn.ModuleDict):
@@ -202,10 +247,18 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
202
247
  """
203
248
  return max([metric.k for metric in self.values()])
204
249
 
250
+ @property
251
+ def weighted(self) -> bool:
252
+ r"""Returns :obj:`True` in case the collection holds at least one
253
+ weighted link prediction metric.
254
+ """
255
+ return any([metric.weighted for metric in self.values()])
256
+
205
257
  def update( # type: ignore
206
258
  self,
207
259
  pred_index_mat: Tensor,
208
260
  edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
261
+ edge_label_weight: Optional[Tensor] = None,
209
262
  ) -> None:
210
263
  r"""Updates the state variables based on the current mini-batch
211
264
  prediction.
@@ -221,11 +274,39 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
221
274
  edge_label_index (torch.Tensor): The ground-truth indices for every
222
275
  example in the mini-batch, given in COO format of shape
223
276
  :obj:`[2, num_ground_truth_indices]`.
277
+ edge_label_weight (torch.Tensor, optional): The weight of the
278
+ ground-truth indices for every example in the mini-batch of
279
+ shape :obj:`[num_ground_truth_indices]`. If given, needs to be
280
+ a vector of positive values. Required for weighted metrics,
281
+ ignored otherwise. (default: :obj:`None`)
224
282
  """
225
- pred_isin_mat, y_count = LinkPredMetric._prepare(
226
- pred_index_mat, edge_label_index)
283
+ if self.weighted and edge_label_weight is None:
284
+ raise ValueError(f"'edge_label_weight' is a required argument for "
285
+ f"weighted '{self.__class__.__name__}' metrics")
286
+ if not self.weighted:
287
+ edge_label_weight = None
288
+
289
+ data = LinkPredMetricData( # Share metric data across metrics.
290
+ pred_index_mat=pred_index_mat,
291
+ edge_label_index=edge_label_index,
292
+ edge_label_weight=edge_label_weight,
293
+ )
294
+
227
295
  for metric in self.values():
228
- metric._update_from_prepared(pred_isin_mat, y_count)
296
+ if metric.weighted:
297
+ metric._update(data)
298
+ if WITH_TORCHMETRICS:
299
+ metric._update_count += 1
300
+
301
+ data.edge_label_weight = None
302
+ if hasattr(data, '_pred_rel_mat'):
303
+ data._pred_rel_mat = data._pred_rel_mat != 0.0
304
+
305
+ for metric in self.values():
306
+ if not metric.weighted:
307
+ metric._update(data)
308
+ if WITH_TORCHMETRICS:
309
+ metric._update_count += 1
229
310
 
230
311
  def compute(self) -> Dict[str, Tensor]:
231
312
  r"""Computes the final metric values."""
@@ -248,9 +329,11 @@ class LinkPredPrecision(LinkPredMetric):
248
329
  k (int): The number of top-:math:`k` predictions to evaluate against.
249
330
  """
250
331
  higher_is_better: bool = True
332
+ weighted: bool = False
251
333
 
252
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
253
- return pred_isin_mat.sum(dim=-1) / self.k
334
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
335
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
336
+ return pred_rel_mat.sum(dim=-1) / self.k
254
337
 
255
338
 
256
339
  class LinkPredRecall(LinkPredMetric):
@@ -260,9 +343,11 @@ class LinkPredRecall(LinkPredMetric):
260
343
  k (int): The number of top-:math:`k` predictions to evaluate against.
261
344
  """
262
345
  higher_is_better: bool = True
346
+ weighted: bool = False
263
347
 
264
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
265
- return pred_isin_mat.sum(dim=-1) / y_count.clamp(min=1e-7)
348
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
349
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
350
+ return pred_rel_mat.sum(dim=-1) / data.label_count.clamp(min=1e-7)
266
351
 
267
352
 
268
353
  class LinkPredF1(LinkPredMetric):
@@ -272,11 +357,13 @@ class LinkPredF1(LinkPredMetric):
272
357
  k (int): The number of top-:math:`k` predictions to evaluate against.
273
358
  """
274
359
  higher_is_better: bool = True
360
+ weighted: bool = False
275
361
 
276
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
277
- isin_count = pred_isin_mat.sum(dim=-1)
362
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
363
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
364
+ isin_count = pred_rel_mat.sum(dim=-1)
278
365
  precision = isin_count / self.k
279
- recall = isin_count = isin_count / y_count.clamp(min=1e-7)
366
+ recall = isin_count / data.label_count.clamp(min=1e-7)
280
367
  return 2 * precision * recall / (precision + recall).clamp(min=1e-7)
281
368
 
282
369
 
@@ -288,13 +375,15 @@ class LinkPredMAP(LinkPredMetric):
288
375
  k (int): The number of top-:math:`k` predictions to evaluate against.
289
376
  """
290
377
  higher_is_better: bool = True
378
+ weighted: bool = False
291
379
 
292
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
293
- device = pred_isin_mat.device
294
- arange = torch.arange(1, pred_isin_mat.size(1) + 1, device=device)
295
- cum_precision = pred_isin_mat.cumsum(dim=1) / arange
296
- return ((cum_precision * pred_isin_mat).sum(dim=-1) /
297
- y_count.clamp(min=1e-7, max=self.k))
380
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
381
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
382
+ device = pred_rel_mat.device
383
+ arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
384
+ cum_precision = pred_rel_mat.cumsum(dim=1) / arange
385
+ return ((cum_precision * pred_rel_mat).sum(dim=-1) /
386
+ data.label_count.clamp(min=1e-7, max=self.k))
298
387
 
299
388
 
300
389
  class LinkPredNDCG(LinkPredMetric):
@@ -303,25 +392,61 @@ class LinkPredNDCG(LinkPredMetric):
303
392
 
304
393
  Args:
305
394
  k (int): The number of top-:math:`k` predictions to evaluate against.
395
+ weighted (bool, optional): If set to :obj:`True`, assumes sorted lists
396
+ of ground-truth items according to a relevance score as given by
397
+ :obj:`edge_label_weight`. (default: :obj:`False`)
306
398
  """
307
399
  higher_is_better: bool = True
400
+ weighted: bool = False
308
401
 
309
- def __init__(self, k: int):
402
+ def __init__(self, k: int, weighted: bool = False):
310
403
  super().__init__(k=k)
404
+ self.weighted = weighted
311
405
 
312
406
  dtype = torch.get_default_dtype()
313
- multiplier = 1.0 / torch.arange(2, k + 2, dtype=dtype).log2()
407
+ discount = torch.arange(2, k + 2, dtype=dtype).log2()
314
408
 
315
- self.multiplier: Tensor
316
- self.register_buffer('multiplier', multiplier)
409
+ self.discount: Tensor
410
+ self.register_buffer('discount', discount)
317
411
 
318
- self.idcg: Tensor
319
- self.register_buffer('idcg', cumsum(multiplier))
412
+ if not weighted:
413
+ self.register_buffer('idcg', cumsum(1.0 / discount))
414
+ else:
415
+ self.idcg = None
416
+
417
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
418
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
419
+ discount = self.discount[:pred_rel_mat.size(1)].view(1, -1)
420
+ dcg = (pred_rel_mat / discount).sum(dim=-1)
320
421
 
321
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
322
- multiplier = self.multiplier[:pred_isin_mat.size(1)].view(1, -1)
323
- dcg = (pred_isin_mat * multiplier).sum(dim=-1)
324
- idcg = self.idcg[y_count.clamp(max=self.k)]
422
+ if not self.weighted:
423
+ assert self.idcg is not None
424
+ idcg = self.idcg[data.label_count.clamp(max=self.k)]
425
+ else:
426
+ assert data.edge_label_weight is not None
427
+ # Sort weights within example-wise buckets via two sorts to get the
428
+ # local index order within buckets:
429
+ weight, batch = data.edge_label_weight, data.edge_label_index[0]
430
+ perm1 = weight.argsort(descending=True)
431
+ perm2 = batch[perm1].argsort(stable=True)
432
+ global_index = torch.empty_like(perm1)
433
+ global_index[perm1[perm2]] = torch.arange(
434
+ global_index.size(0), device=global_index.device)
435
+ local_index = global_index - cumsum(data.label_count)[batch]
436
+
437
+ # Get the discount per local index:
438
+ discount = torch.cat([
439
+ self.discount,
440
+ self.discount.new_full((1, ), fill_value=float('inf')),
441
+ ])
442
+ discount = discount[local_index.clamp(max=self.k + 1)]
443
+
444
+ idcg = scatter( # Apply discount and aggregate:
445
+ weight / discount,
446
+ batch,
447
+ dim_size=data.pred_index_mat.size(0),
448
+ reduce='sum',
449
+ )
325
450
 
326
451
  out = dcg / idcg
327
452
  out[out.isnan() | out.isinf()] = 0.0
@@ -336,8 +461,10 @@ class LinkPredMRR(LinkPredMetric):
336
461
  k (int): The number of top-:math:`k` predictions to evaluate against.
337
462
  """
338
463
  higher_is_better: bool = True
464
+ weighted: bool = False
339
465
 
340
- def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
341
- device = pred_isin_mat.device
342
- arange = torch.arange(1, pred_isin_mat.size(1) + 1, device=device)
343
- return (pred_isin_mat / arange).max(dim=-1)[0]
466
+ def _compute(self, data: LinkPredMetricData) -> Tensor:
467
+ pred_rel_mat = data.pred_rel_mat[:, :self.k]
468
+ device = pred_rel_mat.device
469
+ arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
470
+ return (pred_rel_mat / arange).max(dim=-1)[0]
@@ -51,17 +51,18 @@ class LLM(torch.nn.Module):
51
51
 
52
52
  model_name (str): The HuggingFace model name, *e.g.*, :obj:`"llama2"` or
53
53
  :obj:`"gemma"`.
54
- num_params (int): An integer representing how many parameters the
54
+ num_params (int, optional): An integer representing how many parameters the
55
55
  HuggingFace model has, in billions. This is used to automatically
56
56
  allocate the correct number of GPUs needed, given the available GPU
57
- memory of your GPUs.
57
+ memory of your GPUs. If not specified, the number of parameters
58
+ is determined using the `huggingface_hub` module.
58
59
  dtype (torch.dtype, optional): The data type to use for the LLM.
59
60
  (default :obj: `torch.bfloat16`)
60
61
  """
61
62
  def __init__(
62
63
  self,
63
64
  model_name: str,
64
- num_params: int,
65
+ num_params: int = None,
65
66
  dtype=torch.bfloat16,
66
67
  ) -> None:
67
68
  super().__init__()
@@ -70,6 +71,12 @@ class LLM(torch.nn.Module):
70
71
 
71
72
  from transformers import AutoModelForCausalLM, AutoTokenizer
72
73
 
74
+ if num_params is None:
75
+ from huggingface_hub import get_safetensors_metadata
76
+ safetensors_metadata = get_safetensors_metadata(model_name)
77
+ param_count = safetensors_metadata.parameter_count
78
+ num_params = list(param_count.values())[0] // 10**9
79
+
73
80
  # A rough heuristic on GPU memory requirements, e.g., we found that
74
81
  # LLAMA2 (7B parameters) fits on a 85GB GPU.
75
82
  required_memory = 85 * num_params / 7