pyg-nightly 2.7.0.dev20250114__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
- {pyg_nightly-2.7.0.dev20250114.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250114.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +6 -6
- torch_geometric/__init__.py +1 -1
- torch_geometric/metrics/link_pred.py +215 -88
- torch_geometric/nn/nlp/llm.py +10 -3
- {pyg_nightly-2.7.0.dev20250114.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +0 -0
{pyg_nightly-2.7.0.dev20250114.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: pyg-nightly
-Version: 2.7.0.dev20250114
+Version: 2.7.0.dev20250115
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
{pyg_nightly-2.7.0.dev20250114.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=
+torch_geometric/__init__.py,sha256=QkrqTHL4gTBMa218nYocVEBizEOwmVBdqBkFozuzk4w,1904
 torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -288,7 +288,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_
 torch_geometric/loader/utils.py,sha256=f27mczQ7fEP2HpTsJGJxKS0slPu0j8zTba3jP8ViNck,14901
 torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
 torch_geometric/metrics/__init__.py,sha256=xHDTWEG4kdv9xb5pGPlRfQjC5P-ZGbhJ0xDe3YNq3ss,393
-torch_geometric/metrics/link_pred.py,sha256=
+torch_geometric/metrics/link_pred.py,sha256=6nd929rWVmSWpFaRJ1u9OSL0VXndr7Pggce4Ynz5UG8,16799
 torch_geometric/nn/__init__.py,sha256=RrWRzEoqtR3lsO2lAzYXboLPb3uYEX2z3tLxiBIVWjc,847
 torch_geometric/nn/data_parallel.py,sha256=lDAxRi83UNuzAQSj3eu9K2sQheOIU6wqR5elS6oDs90,4764
 torch_geometric/nn/encoding.py,sha256=QNjwWczYExZ1wRGBmpuqYbn6tB7NC4BU-DEgzjhcZqw,3115
@@ -459,7 +459,7 @@ torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6
 torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
 torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
 torch_geometric/nn/nlp/__init__.py,sha256=q6CPUiJHcc9bXw90lyj-ID4F3kfW8uPM-SOxW9uCMHs,213
-torch_geometric/nn/nlp/llm.py,sha256=
+torch_geometric/nn/nlp/llm.py,sha256=vcFvqW-veEfVZDLSHKFKXY-1k0TbiOzmf3LZIwIA0zM,12146
 torch_geometric/nn/nlp/sentence_transformer.py,sha256=q5M7SGtrUzoSiNhKCGFb7JatWiukdhNF6zdq2yiqxwE,4475
 torch_geometric/nn/nlp/vision_transformer.py,sha256=diVBefjIynzYs8WBlcpTeSVnw1PUecHY--B9Yd-W2hA,863
 torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
@@ -629,6 +629,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
 torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.dev20250114.dist-info/WHEEL,sha256=
-pyg_nightly-2.7.0.dev20250114.dist-info/METADATA,sha256=
-pyg_nightly-2.7.0.dev20250114.dist-info/RECORD,,
+pyg_nightly-2.7.0.dev20250115.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
+pyg_nightly-2.7.0.dev20250115.dist-info/METADATA,sha256=en3D5pZ3YXy64UdH3FG_SWtiKWcICTUL5Vl7Dhr-VC8,62977
+pyg_nightly-2.7.0.dev20250115.dist-info/RECORD,,
torch_geometric/__init__.py
CHANGED
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.7.0.dev20250114'
+__version__ = '2.7.0.dev20250115'
 
 __all__ = [
     'Index',
torch_geometric/metrics/link_pred.py
CHANGED
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -14,6 +15,76 @@ except Exception:
     BaseMetric = torch.nn.Module  # type: ignore
 
 
+@dataclass(repr=False)
+class LinkPredMetricData:
+    pred_index_mat: Tensor
+    edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]
+    edge_label_weight: Optional[Tensor] = None
+
+    @property
+    def pred_rel_mat(self) -> Tensor:
+        r"""Returns a matrix indicating the relevance of the `k`-th prediction.
+        If :obj:`edge_label_weight` is not given, relevance will be denoted as
+        binary.
+        """
+        if hasattr(self, '_pred_rel_mat'):
+            return self._pred_rel_mat  # type: ignore
+
+        # Flatten both prediction and ground-truth indices, and determine
+        # overlaps afterwards via `torch.searchsorted`.
+        max_index = max(  # type: ignore
+            self.pred_index_mat.max()
+            if self.pred_index_mat.numel() > 0 else 0,
+            self.edge_label_index[1].max()
+            if self.edge_label_index[1].numel() > 0 else 0,
+        ) + 1
+        arange = torch.arange(
+            start=0,
+            end=max_index * self.pred_index_mat.size(0),  # type: ignore
+            step=max_index,  # type: ignore
+            device=self.pred_index_mat.device,
+        ).view(-1, 1)
+        flat_pred_index = (self.pred_index_mat + arange).view(-1)
+        flat_label_index = max_index * self.edge_label_index[0]
+        flat_label_index = flat_label_index + self.edge_label_index[1]
+        flat_label_index, perm = flat_label_index.sort()
+        edge_label_weight = self.edge_label_weight
+        if edge_label_weight is not None:
+            assert edge_label_weight.size() == self.edge_label_index[0].size()
+            edge_label_weight = edge_label_weight[perm]
+
+        pos = torch.searchsorted(flat_label_index, flat_pred_index)
+        pos = pos.clamp(max=flat_label_index.size(0) - 1)  # Out-of-bounds.
+
+        pred_rel_mat = flat_label_index[pos] == flat_pred_index  # Find matches
+        if edge_label_weight is not None:
+            pred_rel_mat = edge_label_weight[pos].where(
+                pred_rel_mat,
+                pred_rel_mat.new_zeros(1),
+            )
+        pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size())
+
+        self._pred_rel_mat = pred_rel_mat
+        return pred_rel_mat
+
+    @property
+    def label_count(self) -> Tensor:
+        r"""The number of ground-truth labels for every example."""
+        if hasattr(self, '_label_count'):
+            return self._label_count  # type: ignore
+
+        label_count = scatter(
+            torch.ones_like(self.edge_label_index[0]),
+            self.edge_label_index[0],
+            dim=0,
+            dim_size=self.pred_index_mat.size(0),
+            reduce='sum',
+        )
+
+        self._label_count = label_count
+        return label_count
+
+
 class LinkPredMetric(BaseMetric):
     r"""An abstract class for computing link prediction retrieval metrics.
 
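As an aside (not part of the diff), the matching trick behind the new `pred_rel_mat` property can be illustrated on toy tensors: every (example, node) pair is mapped to a unique flat index, and `torch.searchsorted` then looks up each prediction in the sorted ground-truth indices. A minimal, self-contained sketch with made-up numbers:

import torch

# Two examples with top-3 predictions each:
pred_index_mat = torch.tensor([[2, 0, 1], [1, 3, 0]])
# Ground-truth edges in COO format: example 0 -> {0, 1}, example 1 -> {3}:
edge_label_index = torch.tensor([[0, 0, 1], [0, 1, 3]])

max_index = int(max(pred_index_mat.max(), edge_label_index[1].max())) + 1
offset = torch.arange(0, max_index * pred_index_mat.size(0), max_index).view(-1, 1)

flat_pred = (pred_index_mat + offset).view(-1)  # Unique id per (example, node).
flat_label = (max_index * edge_label_index[0] + edge_label_index[1]).sort()[0]

pos = torch.searchsorted(flat_label, flat_pred).clamp(max=flat_label.numel() - 1)
rel = (flat_label[pos] == flat_pred).view(pred_index_mat.size())
print(rel)  # [[False, True, True], [False, True, False]]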
@@ -23,6 +94,7 @@ class LinkPredMetric(BaseMetric):
     is_differentiable: bool = False
     full_state_update: bool = False
     higher_is_better: Optional[bool] = None
+    weighted: bool = False
 
     def __init__(self, k: int) -> None:
         super().__init__()
@@ -43,56 +115,11 @@
         self.register_buffer('accum', torch.tensor(0.))
         self.register_buffer('total', torch.tensor(0))
 
-    @staticmethod
-    def _prepare(
-        pred_index_mat: Tensor,
-        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
-    ) -> Tuple[Tensor, Tensor]:
-        # Compute a boolean matrix indicating if the `k`-th prediction is part
-        # of the ground-truth, as well as the number of ground-truths for every
-        # example. We do this by flattening both prediction and ground-truth
-        # indices, and then determining overlaps via `torch.isin`.
-        max_index = max(  # type: ignore
-            pred_index_mat.max() if pred_index_mat.numel() > 0 else 0,
-            edge_label_index[1].max()
-            if edge_label_index[1].numel() > 0 else 0,
-        ) + 1
-        arange = torch.arange(
-            start=0,
-            end=max_index * pred_index_mat.size(0),  # type: ignore
-            step=max_index,  # type: ignore
-            device=pred_index_mat.device,
-        ).view(-1, 1)
-        flat_pred_index = (pred_index_mat + arange).view(-1)
-        flat_y_index = max_index * edge_label_index[0] + edge_label_index[1]
-
-        pred_isin_mat = torch.isin(flat_pred_index, flat_y_index)
-        pred_isin_mat = pred_isin_mat.view(pred_index_mat.size())
-
-        # Compute the number of ground-truths per example:
-        y_count = scatter(
-            torch.ones_like(edge_label_index[0]),
-            edge_label_index[0],
-            dim=0,
-            dim_size=pred_index_mat.size(0),
-            reduce='sum',
-        )
-
-        return pred_isin_mat, y_count
-
-    def _update_from_prepared(
-        self,
-        pred_isin_mat: Tensor,
-        y_count: Tensor,
-    ) -> None:
-        metric = self._compute(pred_isin_mat[:, :self.k], y_count)
-        self.accum += metric.sum()
-        self.total += (y_count > 0).sum()
-
     def update(
         self,
         pred_index_mat: Tensor,
         edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
+        edge_label_weight: Optional[Tensor] = None,
     ) -> None:
         r"""Updates the state variables based on the current mini-batch
         prediction.
@@ -108,10 +135,30 @@
             edge_label_index (torch.Tensor): The ground-truth indices for every
                 example in the mini-batch, given in COO format of shape
                 :obj:`[2, num_ground_truth_indices]`.
+            edge_label_weight (torch.Tensor, optional): The weight of the
+                ground-truth indices for every example in the mini-batch of
+                shape :obj:`[num_ground_truth_indices]`. If given, needs to be
+                a vector of positive values. Required for weighted metrics,
+                ignored otherwise. (default: :obj:`None`)
         """
-
-
-
+        if self.weighted and edge_label_weight is None:
+            raise ValueError(f"'edge_label_weight' is a required argument for "
+                             f"weighted '{self.__class__.__name__}' metrics")
+        if not self.weighted:
+            edge_label_weight = None
+
+        data = LinkPredMetricData(
+            pred_index_mat=pred_index_mat,
+            edge_label_index=edge_label_index,
+            edge_label_weight=edge_label_weight,
+        )
+        self._update(data)
+
+    def _update(self, data: LinkPredMetricData) -> None:
+        metric = self._compute(data)
+
+        self.accum += metric.sum()
+        self.total += (data.label_count > 0).sum()
 
     def compute(self) -> Tensor:
         r"""Computes the final metric value."""
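For orientation (not part of the diff), a minimal usage sketch of the new argument, assuming the :class:`LinkPredNDCG` metric defined further below with :obj:`weighted=True`; the numbers are toy values and the resulting score is not asserted here:

import torch
from torch_geometric.metrics import LinkPredNDCG

metric = LinkPredNDCG(k=2, weighted=True)

pred_index_mat = torch.tensor([[1, 0], [2, 3]])          # Top-2 predictions per example.
edge_label_index = torch.tensor([[0, 0, 1], [1, 2, 2]])  # Ground-truth edges in COO format.
edge_label_weight = torch.tensor([2.0, 1.0, 0.5])        # Positive relevance per edge.

metric.update(pred_index_mat, edge_label_index, edge_label_weight)
print(metric.compute())
metric.reset()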
@@ -120,28 +167,26 @@
         return self.accum / self.total
 
     def reset(self) -> None:
-        r"""
+        r"""Resets metric state variables to their default value."""
         if WITH_TORCHMETRICS:
             super().reset()
         else:
             self.accum.zero_()
             self.total.zero_()
 
-    def _compute(self,
-        r"""
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        r"""Computes the specific metric.
         To be implemented separately for each metric class.
 
         Args:
-
-
-                :obj:`i`-th example is correct or not.
-            y_count (torch.Tensor): A vector indicating the number of
-                ground-truth labels for each example.
+            data (LinkPredMetricData): The mini-batch data for computing a link
+                prediction metric per example.
         """
         raise NotImplementedError
 
     def __repr__(self) -> str:
-
+        weighted_repr = ', weighted=True' if self.weighted else ''
+        return f'{self.__class__.__name__}(k={self.k}{weighted_repr})'
 
 
 class LinkPredMetricCollection(torch.nn.ModuleDict):
@@ -202,10 +247,18 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
         """
         return max([metric.k for metric in self.values()])
 
+    @property
+    def weighted(self) -> bool:
+        r"""Returns :obj:`True` in case the collection holds at least one
+        weighted link prediction metric.
+        """
+        return any([metric.weighted for metric in self.values()])
+
     def update(  # type: ignore
         self,
         pred_index_mat: Tensor,
         edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
+        edge_label_weight: Optional[Tensor] = None,
     ) -> None:
         r"""Updates the state variables based on the current mini-batch
         prediction.
@@ -221,11 +274,39 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
             edge_label_index (torch.Tensor): The ground-truth indices for every
                 example in the mini-batch, given in COO format of shape
                 :obj:`[2, num_ground_truth_indices]`.
+            edge_label_weight (torch.Tensor, optional): The weight of the
+                ground-truth indices for every example in the mini-batch of
+                shape :obj:`[num_ground_truth_indices]`. If given, needs to be
+                a vector of positive values. Required for weighted metrics,
+                ignored otherwise. (default: :obj:`None`)
         """
-
-
+        if self.weighted and edge_label_weight is None:
+            raise ValueError(f"'edge_label_weight' is a required argument for "
+                             f"weighted '{self.__class__.__name__}' metrics")
+        if not self.weighted:
+            edge_label_weight = None
+
+        data = LinkPredMetricData(  # Share metric data across metrics.
+            pred_index_mat=pred_index_mat,
+            edge_label_index=edge_label_index,
+            edge_label_weight=edge_label_weight,
+        )
+
         for metric in self.values():
-            metric.
+            if metric.weighted:
+                metric._update(data)
+                if WITH_TORCHMETRICS:
+                    metric._update_count += 1
+
+        data.edge_label_weight = None
+        if hasattr(data, '_pred_rel_mat'):
+            data._pred_rel_mat = data._pred_rel_mat != 0.0
+
+        for metric in self.values():
+            if not metric.weighted:
+                metric._update(data)
+                if WITH_TORCHMETRICS:
+                    metric._update_count += 1
 
     def compute(self) -> Dict[str, Tensor]:
         r"""Computes the final metric values."""
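Also as a side note (not part of the diff), a sketch of how a mixed collection might be driven through the shared code path above; :class:`LinkPredMAP` and :class:`LinkPredNDCG` are the metrics defined later in this file, the list-style constructor and the printed key format are assumptions, and the random toy inputs are only illustrative:

import torch
from torch_geometric.metrics import (
    LinkPredMAP,
    LinkPredMetricCollection,
    LinkPredNDCG,
)

metrics = LinkPredMetricCollection([
    LinkPredMAP(k=10),                  # Unweighted: sees binarized relevance.
    LinkPredNDCG(k=10, weighted=True),  # Weighted: consumes `edge_label_weight`.
])

pred_index_mat = torch.randint(0, 100, (32, 10))  # Top-10 predictions for 32 examples.
edge_label_index = torch.stack([
    torch.randint(0, 32, (64, )),                 # Example id per ground-truth edge.
    torch.randint(0, 100, (64, )),                # Destination node per ground-truth edge.
])
edge_label_weight = torch.rand(64) + 0.1          # Positive relevance scores.

metrics.update(pred_index_mat, edge_label_index, edge_label_weight)
print(metrics.compute())  # e.g., {'LinkPredMAP@10': ..., 'LinkPredNDCG@10': ...}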
@@ -248,9 +329,11 @@ class LinkPredPrecision(LinkPredMetric):
         k (int): The number of top-:math:`k` predictions to evaluate against.
     """
     higher_is_better: bool = True
+    weighted: bool = False
 
-    def _compute(self,
-
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        return pred_rel_mat.sum(dim=-1) / self.k
 
 
 class LinkPredRecall(LinkPredMetric):
@@ -260,9 +343,11 @@ class LinkPredRecall(LinkPredMetric):
         k (int): The number of top-:math:`k` predictions to evaluate against.
     """
     higher_is_better: bool = True
+    weighted: bool = False
 
-    def _compute(self,
-
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        return pred_rel_mat.sum(dim=-1) / data.label_count.clamp(min=1e-7)
 
 
 class LinkPredF1(LinkPredMetric):
@@ -272,11 +357,13 @@ class LinkPredF1(LinkPredMetric):
         k (int): The number of top-:math:`k` predictions to evaluate against.
     """
     higher_is_better: bool = True
+    weighted: bool = False
 
-    def _compute(self,
-
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        isin_count = pred_rel_mat.sum(dim=-1)
         precision = isin_count / self.k
-        recall = isin_count
+        recall = isin_count / data.label_count.clamp(min=1e-7)
         return 2 * precision * recall / (precision + recall).clamp(min=1e-7)
 
 
@@ -288,13 +375,15 @@ class LinkPredMAP(LinkPredMetric):
         k (int): The number of top-:math:`k` predictions to evaluate against.
     """
     higher_is_better: bool = True
+    weighted: bool = False
 
-    def _compute(self,
-
-
-
-
-
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        device = pred_rel_mat.device
+        arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
+        cum_precision = pred_rel_mat.cumsum(dim=1) / arange
+        return ((cum_precision * pred_rel_mat).sum(dim=-1) /
+                data.label_count.clamp(min=1e-7, max=self.k))
 
 
 class LinkPredNDCG(LinkPredMetric):
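A quick sanity check (not part of the diff) of the cumulative-precision formulation above, worked out on a single toy example:

import torch

k = 3
pred_rel_mat = torch.tensor([[1., 0., 1.]])          # Hits at ranks 1 and 3.
label_count = torch.tensor([2.])                     # Two ground-truth items.
arange = torch.arange(1, k + 1)
cum_precision = pred_rel_mat.cumsum(dim=1) / arange  # [1/1, 1/2, 2/3]
ap = (cum_precision * pred_rel_mat).sum(dim=-1) / label_count.clamp(max=k)
print(ap)  # tensor([0.8333]) == (1/1 + 2/3) / 2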
@@ -303,25 +392,61 @@
 
     Args:
         k (int): The number of top-:math:`k` predictions to evaluate against.
+        weighted (bool, optional): If set to :obj:`True`, assumes sorted lists
+            of ground-truth items according to a relevance score as given by
+            :obj:`edge_label_weight`. (default: :obj:`False`)
     """
     higher_is_better: bool = True
+    weighted: bool = False
 
-    def __init__(self, k: int):
+    def __init__(self, k: int, weighted: bool = False):
         super().__init__(k=k)
+        self.weighted = weighted
 
         dtype = torch.get_default_dtype()
-
+        discount = torch.arange(2, k + 2, dtype=dtype).log2()
 
-        self.
-        self.register_buffer('
+        self.discount: Tensor
+        self.register_buffer('discount', discount)
 
-
-
+        if not weighted:
+            self.register_buffer('idcg', cumsum(1.0 / discount))
+        else:
+            self.idcg = None
+
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        discount = self.discount[:pred_rel_mat.size(1)].view(1, -1)
+        dcg = (pred_rel_mat / discount).sum(dim=-1)
 
-
-
-
-
+        if not self.weighted:
+            assert self.idcg is not None
+            idcg = self.idcg[data.label_count.clamp(max=self.k)]
+        else:
+            assert data.edge_label_weight is not None
+            # Sort weights within example-wise buckets via two sorts to get the
+            # local index order within buckets:
+            weight, batch = data.edge_label_weight, data.edge_label_index[0]
+            perm1 = weight.argsort(descending=True)
+            perm2 = batch[perm1].argsort(stable=True)
+            global_index = torch.empty_like(perm1)
+            global_index[perm1[perm2]] = torch.arange(
+                global_index.size(0), device=global_index.device)
+            local_index = global_index - cumsum(data.label_count)[batch]
+
+            # Get the discount per local index:
+            discount = torch.cat([
+                self.discount,
+                self.discount.new_full((1, ), fill_value=float('inf')),
+            ])
+            discount = discount[local_index.clamp(max=self.k + 1)]
+
+            idcg = scatter(  # Apply discount and aggregate:
+                weight / discount,
+                batch,
+                dim_size=data.pred_index_mat.size(0),
+                reduce='sum',
+            )
 
         out = dcg / idcg
         out[out.isnan() | out.isinf()] = 0.0
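As an aside (not part of the diff), the "two argsorts" trick in the weighted branch above assigns every ground-truth item its rank within its example when sorted by descending weight. A standalone sketch on toy data, with the exclusive prefix sum that :obj:`cumsum(data.label_count)` provides written out by hand:

import torch

batch = torch.tensor([0, 0, 1, 1, 1])             # Example id per ground-truth item.
weight = torch.tensor([0.2, 0.9, 0.5, 0.7, 0.1])  # Relevance per ground-truth item.

perm1 = weight.argsort(descending=True)           # Global order by descending weight.
perm2 = batch[perm1].argsort(stable=True)         # Group by example, keep weight order.
global_index = torch.empty_like(perm1)
global_index[perm1[perm2]] = torch.arange(weight.numel())

ptr = torch.tensor([0, 2])                        # Exclusive prefix sum of label counts [2, 3].
local_index = global_index - ptr[batch]
print(local_index)  # tensor([1, 0, 1, 0, 2]): per-example rank by descending weight.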
@@ -336,8 +461,10 @@ class LinkPredMRR(LinkPredMetric):
         k (int): The number of top-:math:`k` predictions to evaluate against.
     """
     higher_is_better: bool = True
+    weighted: bool = False
 
-    def _compute(self,
-
-
-
+    def _compute(self, data: LinkPredMetricData) -> Tensor:
+        pred_rel_mat = data.pred_rel_mat[:, :self.k]
+        device = pred_rel_mat.device
+        arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
+        return (pred_rel_mat / arange).max(dim=-1)[0]
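And a final toy check (not part of the diff) of the reciprocal-rank formulation used above:

import torch

pred_rel_mat = torch.tensor([[0., 1., 0.],   # First hit at rank 2.
                             [1., 0., 1.],   # First hit at rank 1.
                             [0., 0., 0.]])  # No hit within the top-3.
arange = torch.arange(1, pred_rel_mat.size(1) + 1)
print((pred_rel_mat / arange).max(dim=-1)[0])  # tensor([0.5000, 1.0000, 0.0000])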
torch_geometric/nn/nlp/llm.py
CHANGED
@@ -51,17 +51,18 @@ class LLM(torch.nn.Module):
 
         model_name (str): The HuggingFace model name, *e.g.*, :obj:`"llama2"` or
             :obj:`"gemma"`.
-        num_params (int): An integer representing how many parameters the
+        num_params (int, optional): An integer representing how many parameters the
             HuggingFace model has, in billions. This is used to automatically
             allocate the correct number of GPUs needed, given the available GPU
-            memory of your GPUs.
+            memory of your GPUs. If not specified, the number of parameters
+            is determined using the `huggingface_hub` module.
         dtype (torch.dtype, optional): The data type to use for the LLM.
             (default :obj: `torch.bfloat16`)
     """
     def __init__(
         self,
         model_name: str,
-        num_params: int,
+        num_params: int = None,
         dtype=torch.bfloat16,
     ) -> None:
         super().__init__()
@@ -70,6 +71,12 @@ class LLM(torch.nn.Module):
 
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
+        if num_params is None:
+            from huggingface_hub import get_safetensors_metadata
+            safetensors_metadata = get_safetensors_metadata(model_name)
+            param_count = safetensors_metadata.parameter_count
+            num_params = list(param_count.values())[0] // 10**9
+
         # A rough heuristic on GPU memory requirements, e.g., we found that
         # LLAMA2 (7B parameters) fits on a 85GB GPU.
         required_memory = 85 * num_params / 7
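As a side note (not part of the diff), a minimal sketch of the parameter-count lookup this new branch performs; the model name is only an example and the call needs access to the Hugging Face Hub:

from huggingface_hub import get_safetensors_metadata

metadata = get_safetensors_metadata("meta-llama/Llama-2-7b-hf")
param_count = metadata.parameter_count  # Dict mapping dtype name to parameter count.
num_params = list(param_count.values())[0] // 10**9
print(num_params)                       # Approximate model size in billions of parameters.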
{pyg_nightly-2.7.0.dev20250114.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL
RENAMED
File without changes