pyg-nightly 2.7.0.dev20250115__py3-none-any.whl → 2.7.0.dev20250116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250115.dist-info → pyg_nightly-2.7.0.dev20250116.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250115.dist-info → pyg_nightly-2.7.0.dev20250116.dist-info}/RECORD +5 -5
- torch_geometric/__init__.py +1 -1
- torch_geometric/metrics/link_pred.py +63 -18
- {pyg_nightly-2.7.0.dev20250115.dist-info → pyg_nightly-2.7.0.dev20250116.dist-info}/WHEEL +0 -0
{pyg_nightly-2.7.0.dev20250115.dist-info → pyg_nightly-2.7.0.dev20250116.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: pyg-nightly
-Version: 2.7.0.dev20250115
+Version: 2.7.0.dev20250116
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
{pyg_nightly-2.7.0.dev20250115.dist-info → pyg_nightly-2.7.0.dev20250116.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=
+torch_geometric/__init__.py,sha256=rg2lsApb7EvEBE_oOKODzbZlkhjIg4lEI57BNr3rtCQ,1904
 torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -288,7 +288,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_
 torch_geometric/loader/utils.py,sha256=f27mczQ7fEP2HpTsJGJxKS0slPu0j8zTba3jP8ViNck,14901
 torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
 torch_geometric/metrics/__init__.py,sha256=xHDTWEG4kdv9xb5pGPlRfQjC5P-ZGbhJ0xDe3YNq3ss,393
-torch_geometric/metrics/link_pred.py,sha256=
+torch_geometric/metrics/link_pred.py,sha256=8H74eS28AcGtOUB0g8_xUUp8IX-zmlIFuBURp7Dx0No,18269
 torch_geometric/nn/__init__.py,sha256=RrWRzEoqtR3lsO2lAzYXboLPb3uYEX2z3tLxiBIVWjc,847
 torch_geometric/nn/data_parallel.py,sha256=lDAxRi83UNuzAQSj3eu9K2sQheOIU6wqR5elS6oDs90,4764
 torch_geometric/nn/encoding.py,sha256=QNjwWczYExZ1wRGBmpuqYbn6tB7NC4BU-DEgzjhcZqw,3115
@@ -629,6 +629,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
 torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
+pyg_nightly-2.7.0.dev20250116.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
+pyg_nightly-2.7.0.dev20250116.dist-info/METADATA,sha256=ciNy_xzkhPNoMBEJp2PEyuRFgZrdUsLz6gjAQ9nYQxk,62977
+pyg_nightly-2.7.0.dev20250116.dist-info/RECORD,,
torch_geometric/__init__.py
CHANGED
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

-__version__ = '2.7.0.dev20250115'
+__version__ = '2.7.0.dev20250116'

 __all__ = [
     'Index',
torch_geometric/metrics/link_pred.py
CHANGED
@@ -84,6 +84,51 @@ class LinkPredMetricData:
         self._label_count = label_count
         return label_count

+    @property
+    def label_weight_sum(self) -> Tensor:
+        r"""The sum of edge label weights for every example."""
+        if self.edge_label_weight is None:
+            return self.label_count
+
+        if hasattr(self, '_label_weight_sum'):
+            return self._label_weight_sum  # type: ignore
+
+        label_weight_sum = scatter(
+            self.edge_label_weight,
+            self.edge_label_index[0],
+            dim=0,
+            dim_size=self.pred_index_mat.size(0),
+            reduce='sum',
+        )
+
+        self._label_weight_sum = label_weight_sum
+        return label_weight_sum
+
+    @property
+    def edge_label_weight_pos(self) -> Optional[Tensor]:
+        r"""Returns the position of edge label weights in descending order
+        within example-wise buckets.
+        """
+        if self.edge_label_weight is None:
+            return None
+
+        if hasattr(self, '_edge_label_weight_pos'):
+            return self._edge_label_weight_pos  # type: ignore
+
+        # Get the permutation via two sorts: One globally on the weights,
+        # followed by a (stable) sort on the example indices.
+        perm1 = self.edge_label_weight.argsort(descending=True)
+        perm2 = self.edge_label_index[0][perm1].argsort(stable=True)
+        perm = perm1[perm2]
+        # Invert the permutation to get the final position:
+        pos = torch.empty_like(perm)
+        pos[perm] = torch.arange(perm.size(0), device=perm.device)
+        # Normalize position to zero within all buckets:
+        pos = pos - cumsum(self.label_count)[self.edge_label_index[0]]
+
+        self._edge_label_weight_pos = pos
+        return pos
+

 class LinkPredMetric(BaseMetric):
     r"""An abstract class for computing link prediction retrieval metrics.
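Note: the per-bucket ranking trick used by edge_label_weight_pos above (a global sort on the weights, a stable sort on the example indices, then inverting the permutation) can be reproduced in isolation. The following is a minimal standalone sketch with made-up index/weight tensors, not the library code itself:

import torch

# Hypothetical inputs: which example each edge belongs to, and its label weight:
index = torch.tensor([0, 0, 0, 1, 1])
weight = torch.tensor([0.2, 0.9, 0.5, 0.1, 0.7])

perm1 = weight.argsort(descending=True)    # global sort by weight
perm2 = index[perm1].argsort(stable=True)  # stable sort regroups examples, keeps weight order
perm = perm1[perm2]

pos = torch.empty_like(perm)
pos[perm] = torch.arange(perm.numel())     # invert the permutation

# Subtract each example's offset so positions restart at zero per example:
count = torch.bincount(index)              # analogue of data.label_count
offset = torch.cat([count.new_zeros(1), count.cumsum(0)[:-1]])
pos = pos - offset[index]

print(pos)  # tensor([2, 0, 1, 1, 0]): rank of each weight within its example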
@@ -231,7 +276,9 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):

         if isinstance(metrics, (list, tuple)):
             metrics = {
-                f'{metric.
+                (f'{"Weighted" if metric.weighted else ""}'
+                 f'{metric.__class__.__name__}@{metric.k}'):
+                metric
                 for metric in metrics
             }
         assert len(metrics) > 0
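As a quick illustration of the key scheme above (a usage sketch assuming the constructors shown elsewhere in this diff), weighted metrics now register under a 'Weighted'-prefixed key:

from torch_geometric.metrics import (
    LinkPredMetricCollection,
    LinkPredNDCG,
    LinkPredRecall,
)

metrics = LinkPredMetricCollection([
    LinkPredRecall(k=10),                 # key: 'LinkPredRecall@10'
    LinkPredRecall(k=10, weighted=True),  # key: 'WeightedLinkPredRecall@10'
    LinkPredNDCG(k=10, weighted=True),    # key: 'WeightedLinkPredNDCG@10'
])
print(list(metrics.keys()))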
@@ -301,6 +348,10 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
             data.edge_label_weight = None
             if hasattr(data, '_pred_rel_mat'):
                 data._pred_rel_mat = data._pred_rel_mat != 0.0
+            if hasattr(data, '_label_weight_sum'):
+                del data._label_weight_sum
+            if hasattr(data, '_edge_label_weight_pos'):
+                del data._edge_label_weight_pos

         for metric in self.values():
             if not metric.weighted:
@@ -343,11 +394,14 @@ class LinkPredRecall(LinkPredMetric):
         k (int): The number of top-:math:`k` predictions to evaluate against.
     """
     higher_is_better: bool = True
-
+
+    def __init__(self, k: int, weighted: bool = False):
+        super().__init__(k=k)
+        self.weighted = weighted

     def _compute(self, data: LinkPredMetricData) -> Tensor:
         pred_rel_mat = data.pred_rel_mat[:, :self.k]
-        return pred_rel_mat.sum(dim=-1) / data.
+        return pred_rel_mat.sum(dim=-1) / data.label_weight_sum.clamp(min=1e-7)


 class LinkPredF1(LinkPredMetric):
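Read together with label_weight_sum above, the weighted recall is the summed label weight recovered in the top-k divided by the example's total label weight (with label_count as the unweighted fallback). A tiny self-contained sketch with made-up numbers, not the library code:

import torch

k = 2
# Relevance of the top-k predictions per example (label weight if the
# prediction is a true edge, 0 otherwise), analogous to pred_rel_mat[:, :k]:
pred_rel_mat = torch.tensor([[0.9, 0.0],
                             [0.7, 0.1]])
# Total label weight per example, analogous to data.label_weight_sum:
label_weight_sum = torch.tensor([1.6, 0.8])

recall = pred_rel_mat.sum(dim=-1) / label_weight_sum.clamp(min=1e-7)
print(recall)  # tensor([0.5625, 1.0000])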
@@ -397,7 +451,6 @@ class LinkPredNDCG(LinkPredMetric):
             :obj:`edge_label_weight`. (default: :obj:`False`)
     """
     higher_is_better: bool = True
-    weighted: bool = False

     def __init__(self, k: int, weighted: bool = False):
         super().__init__(k=k)
@@ -424,26 +477,18 @@ class LinkPredNDCG(LinkPredMetric):
             idcg = self.idcg[data.label_count.clamp(max=self.k)]
         else:
             assert data.edge_label_weight is not None
-
-
-
-            perm1 = weight.argsort(descending=True)
-            perm2 = batch[perm1].argsort(stable=True)
-            global_index = torch.empty_like(perm1)
-            global_index[perm1[perm2]] = torch.arange(
-                global_index.size(0), device=global_index.device)
-            local_index = global_index - cumsum(data.label_count)[batch]
-
-            # Get the discount per local index:
+            pos = data.edge_label_weight_pos
+            assert pos is not None
+
             discount = torch.cat([
                 self.discount,
                 self.discount.new_full((1, ), fill_value=float('inf')),
             ])
-            discount = discount[
+            discount = discount[pos.clamp(max=self.k + 1)]

             idcg = scatter(  # Apply discount and aggregate:
-
-
+                data.edge_label_weight / discount,
+                data.edge_label_index[0],
                 dim_size=data.pred_index_mat.size(0),
                 reduce='sum',
             )
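For reference, a simplified standalone sketch of the weighted ideal-DCG this hunk switches to: rank the label weights per example (cf. edge_label_weight_pos), discount ranks 0..k-1 by log2, and give every later rank an infinite discount so only the k heaviest labels count. Inputs are made up and the discount table is rebuilt locally instead of reusing self.discount, so the clamping differs slightly from the library code:

import torch

k = 2
weight = torch.tensor([0.2, 0.9, 0.5, 0.1, 0.7])  # made-up edge label weights
index = torch.tensor([0, 0, 0, 1, 1])             # example each edge belongs to
num_examples = 2

# Per-example descending rank of each weight:
perm1 = weight.argsort(descending=True)
perm2 = index[perm1].argsort(stable=True)
pos = torch.empty_like(perm1)
pos[perm1[perm2]] = torch.arange(weight.numel())
count = torch.bincount(index, minlength=num_examples)
offset = torch.cat([count.new_zeros(1), count.cumsum(0)[:-1]])
pos = pos - offset[index]

# log2 discounts for ranks 0..k-1; ranks >= k get an infinite discount:
discount = torch.log2(torch.arange(2, k + 2, dtype=weight.dtype))
discount = torch.cat([discount, discount.new_full((1,), float('inf'))])
discount = discount[pos.clamp(max=k)]

idcg = torch.zeros(num_examples).scatter_add_(0, index, weight / discount)
print(idcg)  # tensor([1.2155, 0.7631])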
{pyg_nightly-2.7.0.dev20250115.dist-info → pyg_nightly-2.7.0.dev20250116.dist-info}/WHEEL
RENAMED
File without changes