pyg-nightly 2.7.0.dev20250115__py3-none-any.whl → 2.7.0.dev20250116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20250115
3
+ Version: 2.7.0.dev20250116
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=QkrqTHL4gTBMa218nYocVEBizEOwmVBdqBkFozuzk4w,1904
1
+ torch_geometric/__init__.py,sha256=rg2lsApb7EvEBE_oOKODzbZlkhjIg4lEI57BNr3rtCQ,1904
2
2
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -288,7 +288,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_
288
288
  torch_geometric/loader/utils.py,sha256=f27mczQ7fEP2HpTsJGJxKS0slPu0j8zTba3jP8ViNck,14901
289
289
  torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
290
290
  torch_geometric/metrics/__init__.py,sha256=xHDTWEG4kdv9xb5pGPlRfQjC5P-ZGbhJ0xDe3YNq3ss,393
291
- torch_geometric/metrics/link_pred.py,sha256=6nd929rWVmSWpFaRJ1u9OSL0VXndr7Pggce4Ynz5UG8,16799
291
+ torch_geometric/metrics/link_pred.py,sha256=8H74eS28AcGtOUB0g8_xUUp8IX-zmlIFuBURp7Dx0No,18269
292
292
  torch_geometric/nn/__init__.py,sha256=RrWRzEoqtR3lsO2lAzYXboLPb3uYEX2z3tLxiBIVWjc,847
293
293
  torch_geometric/nn/data_parallel.py,sha256=lDAxRi83UNuzAQSj3eu9K2sQheOIU6wqR5elS6oDs90,4764
294
294
  torch_geometric/nn/encoding.py,sha256=QNjwWczYExZ1wRGBmpuqYbn6tB7NC4BU-DEgzjhcZqw,3115
@@ -629,6 +629,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
629
629
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
630
630
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
631
631
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
632
- pyg_nightly-2.7.0.dev20250115.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
633
- pyg_nightly-2.7.0.dev20250115.dist-info/METADATA,sha256=en3D5pZ3YXy64UdH3FG_SWtiKWcICTUL5Vl7Dhr-VC8,62977
634
- pyg_nightly-2.7.0.dev20250115.dist-info/RECORD,,
632
+ pyg_nightly-2.7.0.dev20250116.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
633
+ pyg_nightly-2.7.0.dev20250116.dist-info/METADATA,sha256=ciNy_xzkhPNoMBEJp2PEyuRFgZrdUsLz6gjAQ9nYQxk,62977
634
+ pyg_nightly-2.7.0.dev20250116.dist-info/RECORD,,
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
30
30
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
31
31
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
32
32
 
33
- __version__ = '2.7.0.dev20250115'
33
+ __version__ = '2.7.0.dev20250116'
34
34
 
35
35
  __all__ = [
36
36
  'Index',
@@ -84,6 +84,51 @@ class LinkPredMetricData:
84
84
  self._label_count = label_count
85
85
  return label_count
86
86
 
87
+ @property
88
+ def label_weight_sum(self) -> Tensor:
89
+ r"""The sum of edge label weights for every example."""
90
+ if self.edge_label_weight is None:
91
+ return self.label_count
92
+
93
+ if hasattr(self, '_label_weight_sum'):
94
+ return self._label_weight_sum # type: ignore
95
+
96
+ label_weight_sum = scatter(
97
+ self.edge_label_weight,
98
+ self.edge_label_index[0],
99
+ dim=0,
100
+ dim_size=self.pred_index_mat.size(0),
101
+ reduce='sum',
102
+ )
103
+
104
+ self._label_weight_sum = label_weight_sum
105
+ return label_weight_sum
106
+
107
+ @property
108
+ def edge_label_weight_pos(self) -> Optional[Tensor]:
109
+ r"""Returns the position of edge label weights in descending order
110
+ within example-wise buckets.
111
+ """
112
+ if self.edge_label_weight is None:
113
+ return None
114
+
115
+ if hasattr(self, '_edge_label_weight_pos'):
116
+ return self._edge_label_weight_pos # type: ignore
117
+
118
+ # Get the permutation via two sorts: One globally on the weights,
119
+ # followed by a (stable) sort on the example indices.
120
+ perm1 = self.edge_label_weight.argsort(descending=True)
121
+ perm2 = self.edge_label_index[0][perm1].argsort(stable=True)
122
+ perm = perm1[perm2]
123
+ # Invert the permutation to get the final position:
124
+ pos = torch.empty_like(perm)
125
+ pos[perm] = torch.arange(perm.size(0), device=perm.device)
126
+ # Normalize position to zero within all buckets:
127
+ pos = pos - cumsum(self.label_count)[self.edge_label_index[0]]
128
+
129
+ self._edge_label_weight_pos = pos
130
+ return pos
131
+
87
132
 
88
133
  class LinkPredMetric(BaseMetric):
89
134
  r"""An abstract class for computing link prediction retrieval metrics.
@@ -231,7 +276,9 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
231
276
 
232
277
  if isinstance(metrics, (list, tuple)):
233
278
  metrics = {
234
- f'{metric.__class__.__name__}@{metric.k}': metric
279
+ (f'{"Weighted" if metric.weighted else ""}'
280
+ f'{metric.__class__.__name__}@{metric.k}'):
281
+ metric
235
282
  for metric in metrics
236
283
  }
237
284
  assert len(metrics) > 0
@@ -301,6 +348,10 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
301
348
  data.edge_label_weight = None
302
349
  if hasattr(data, '_pred_rel_mat'):
303
350
  data._pred_rel_mat = data._pred_rel_mat != 0.0
351
+ if hasattr(data, '_label_weight_sum'):
352
+ del data._label_weight_sum
353
+ if hasattr(data, '_edge_label_weight_pos'):
354
+ del data._edge_label_weight_pos
304
355
 
305
356
  for metric in self.values():
306
357
  if not metric.weighted:
@@ -343,11 +394,14 @@ class LinkPredRecall(LinkPredMetric):
343
394
  k (int): The number of top-:math:`k` predictions to evaluate against.
344
395
  """
345
396
  higher_is_better: bool = True
346
- weighted: bool = False
397
+
398
+ def __init__(self, k: int, weighted: bool = False):
399
+ super().__init__(k=k)
400
+ self.weighted = weighted
347
401
 
348
402
  def _compute(self, data: LinkPredMetricData) -> Tensor:
349
403
  pred_rel_mat = data.pred_rel_mat[:, :self.k]
350
- return pred_rel_mat.sum(dim=-1) / data.label_count.clamp(min=1e-7)
404
+ return pred_rel_mat.sum(dim=-1) / data.label_weight_sum.clamp(min=1e-7)
351
405
 
352
406
 
353
407
  class LinkPredF1(LinkPredMetric):
@@ -397,7 +451,6 @@ class LinkPredNDCG(LinkPredMetric):
397
451
  :obj:`edge_label_weight`. (default: :obj:`False`)
398
452
  """
399
453
  higher_is_better: bool = True
400
- weighted: bool = False
401
454
 
402
455
  def __init__(self, k: int, weighted: bool = False):
403
456
  super().__init__(k=k)
@@ -424,26 +477,18 @@ class LinkPredNDCG(LinkPredMetric):
424
477
  idcg = self.idcg[data.label_count.clamp(max=self.k)]
425
478
  else:
426
479
  assert data.edge_label_weight is not None
427
- # Sort weights within example-wise buckets via two sorts to get the
428
- # local index order within buckets:
429
- weight, batch = data.edge_label_weight, data.edge_label_index[0]
430
- perm1 = weight.argsort(descending=True)
431
- perm2 = batch[perm1].argsort(stable=True)
432
- global_index = torch.empty_like(perm1)
433
- global_index[perm1[perm2]] = torch.arange(
434
- global_index.size(0), device=global_index.device)
435
- local_index = global_index - cumsum(data.label_count)[batch]
436
-
437
- # Get the discount per local index:
480
+ pos = data.edge_label_weight_pos
481
+ assert pos is not None
482
+
438
483
  discount = torch.cat([
439
484
  self.discount,
440
485
  self.discount.new_full((1, ), fill_value=float('inf')),
441
486
  ])
442
- discount = discount[local_index.clamp(max=self.k + 1)]
487
+ discount = discount[pos.clamp(max=self.k + 1)]
443
488
 
444
489
  idcg = scatter( # Apply discount and aggregate:
445
- weight / discount,
446
- batch,
490
+ data.edge_label_weight / discount,
491
+ data.edge_label_index[0],
447
492
  dim_size=data.pred_index_mat.size(0),
448
493
  reduce='sum',
449
494
  )