pyg-nightly 2.7.0.dev20250115__py3-none-any.whl → 2.7.0.dev20250116__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20250115
3
+ Version: 2.7.0.dev20250116
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=QkrqTHL4gTBMa218nYocVEBizEOwmVBdqBkFozuzk4w,1904
1
+ torch_geometric/__init__.py,sha256=rg2lsApb7EvEBE_oOKODzbZlkhjIg4lEI57BNr3rtCQ,1904
2
2
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -288,7 +288,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_
288
288
  torch_geometric/loader/utils.py,sha256=f27mczQ7fEP2HpTsJGJxKS0slPu0j8zTba3jP8ViNck,14901
289
289
  torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
290
290
  torch_geometric/metrics/__init__.py,sha256=xHDTWEG4kdv9xb5pGPlRfQjC5P-ZGbhJ0xDe3YNq3ss,393
291
- torch_geometric/metrics/link_pred.py,sha256=6nd929rWVmSWpFaRJ1u9OSL0VXndr7Pggce4Ynz5UG8,16799
291
+ torch_geometric/metrics/link_pred.py,sha256=8H74eS28AcGtOUB0g8_xUUp8IX-zmlIFuBURp7Dx0No,18269
292
292
  torch_geometric/nn/__init__.py,sha256=RrWRzEoqtR3lsO2lAzYXboLPb3uYEX2z3tLxiBIVWjc,847
293
293
  torch_geometric/nn/data_parallel.py,sha256=lDAxRi83UNuzAQSj3eu9K2sQheOIU6wqR5elS6oDs90,4764
294
294
  torch_geometric/nn/encoding.py,sha256=QNjwWczYExZ1wRGBmpuqYbn6tB7NC4BU-DEgzjhcZqw,3115
@@ -629,6 +629,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
629
629
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
630
630
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
631
631
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
632
- pyg_nightly-2.7.0.dev20250115.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
633
- pyg_nightly-2.7.0.dev20250115.dist-info/METADATA,sha256=en3D5pZ3YXy64UdH3FG_SWtiKWcICTUL5Vl7Dhr-VC8,62977
634
- pyg_nightly-2.7.0.dev20250115.dist-info/RECORD,,
632
+ pyg_nightly-2.7.0.dev20250116.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
633
+ pyg_nightly-2.7.0.dev20250116.dist-info/METADATA,sha256=ciNy_xzkhPNoMBEJp2PEyuRFgZrdUsLz6gjAQ9nYQxk,62977
634
+ pyg_nightly-2.7.0.dev20250116.dist-info/RECORD,,
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
30
30
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
31
31
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
32
32
 
33
- __version__ = '2.7.0.dev20250115'
33
+ __version__ = '2.7.0.dev20250116'
34
34
 
35
35
  __all__ = [
36
36
  'Index',
@@ -84,6 +84,51 @@ class LinkPredMetricData:
84
84
  self._label_count = label_count
85
85
  return label_count
86
86
 
87
+ @property
88
+ def label_weight_sum(self) -> Tensor:
89
+ r"""The sum of edge label weights for every example."""
90
+ if self.edge_label_weight is None:
91
+ return self.label_count
92
+
93
+ if hasattr(self, '_label_weight_sum'):
94
+ return self._label_weight_sum # type: ignore
95
+
96
+ label_weight_sum = scatter(
97
+ self.edge_label_weight,
98
+ self.edge_label_index[0],
99
+ dim=0,
100
+ dim_size=self.pred_index_mat.size(0),
101
+ reduce='sum',
102
+ )
103
+
104
+ self._label_weight_sum = label_weight_sum
105
+ return label_weight_sum
106
+
107
+ @property
108
+ def edge_label_weight_pos(self) -> Optional[Tensor]:
109
+ r"""Returns the position of edge label weights in descending order
110
+ within example-wise buckets.
111
+ """
112
+ if self.edge_label_weight is None:
113
+ return None
114
+
115
+ if hasattr(self, '_edge_label_weight_pos'):
116
+ return self._edge_label_weight_pos # type: ignore
117
+
118
+ # Get the permutation via two sorts: One globally on the weights,
119
+ # followed by a (stable) sort on the example indices.
120
+ perm1 = self.edge_label_weight.argsort(descending=True)
121
+ perm2 = self.edge_label_index[0][perm1].argsort(stable=True)
122
+ perm = perm1[perm2]
123
+ # Invert the permutation to get the final position:
124
+ pos = torch.empty_like(perm)
125
+ pos[perm] = torch.arange(perm.size(0), device=perm.device)
126
+ # Normalize position to zero within all buckets:
127
+ pos = pos - cumsum(self.label_count)[self.edge_label_index[0]]
128
+
129
+ self._edge_label_weight_pos = pos
130
+ return pos
131
+
87
132
 
88
133
  class LinkPredMetric(BaseMetric):
89
134
  r"""An abstract class for computing link prediction retrieval metrics.
@@ -231,7 +276,9 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
231
276
 
232
277
  if isinstance(metrics, (list, tuple)):
233
278
  metrics = {
234
- f'{metric.__class__.__name__}@{metric.k}': metric
279
+ (f'{"Weighted" if metric.weighted else ""}'
280
+ f'{metric.__class__.__name__}@{metric.k}'):
281
+ metric
235
282
  for metric in metrics
236
283
  }
237
284
  assert len(metrics) > 0
@@ -301,6 +348,10 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
301
348
  data.edge_label_weight = None
302
349
  if hasattr(data, '_pred_rel_mat'):
303
350
  data._pred_rel_mat = data._pred_rel_mat != 0.0
351
+ if hasattr(data, '_label_weight_sum'):
352
+ del data._label_weight_sum
353
+ if hasattr(data, '_edge_label_weight_pos'):
354
+ del data._edge_label_weight_pos
304
355
 
305
356
  for metric in self.values():
306
357
  if not metric.weighted:
@@ -343,11 +394,14 @@ class LinkPredRecall(LinkPredMetric):
343
394
  k (int): The number of top-:math:`k` predictions to evaluate against.
344
395
  """
345
396
  higher_is_better: bool = True
346
- weighted: bool = False
397
+
398
+ def __init__(self, k: int, weighted: bool = False):
399
+ super().__init__(k=k)
400
+ self.weighted = weighted
347
401
 
348
402
  def _compute(self, data: LinkPredMetricData) -> Tensor:
349
403
  pred_rel_mat = data.pred_rel_mat[:, :self.k]
350
- return pred_rel_mat.sum(dim=-1) / data.label_count.clamp(min=1e-7)
404
+ return pred_rel_mat.sum(dim=-1) / data.label_weight_sum.clamp(min=1e-7)
351
405
 
352
406
 
353
407
  class LinkPredF1(LinkPredMetric):
@@ -397,7 +451,6 @@ class LinkPredNDCG(LinkPredMetric):
397
451
  :obj:`edge_label_weight`. (default: :obj:`False`)
398
452
  """
399
453
  higher_is_better: bool = True
400
- weighted: bool = False
401
454
 
402
455
  def __init__(self, k: int, weighted: bool = False):
403
456
  super().__init__(k=k)
@@ -424,26 +477,18 @@ class LinkPredNDCG(LinkPredMetric):
424
477
  idcg = self.idcg[data.label_count.clamp(max=self.k)]
425
478
  else:
426
479
  assert data.edge_label_weight is not None
427
- # Sort weights within example-wise buckets via two sorts to get the
428
- # local index order within buckets:
429
- weight, batch = data.edge_label_weight, data.edge_label_index[0]
430
- perm1 = weight.argsort(descending=True)
431
- perm2 = batch[perm1].argsort(stable=True)
432
- global_index = torch.empty_like(perm1)
433
- global_index[perm1[perm2]] = torch.arange(
434
- global_index.size(0), device=global_index.device)
435
- local_index = global_index - cumsum(data.label_count)[batch]
436
-
437
- # Get the discount per local index:
480
+ pos = data.edge_label_weight_pos
481
+ assert pos is not None
482
+
438
483
  discount = torch.cat([
439
484
  self.discount,
440
485
  self.discount.new_full((1, ), fill_value=float('inf')),
441
486
  ])
442
- discount = discount[local_index.clamp(max=self.k + 1)]
487
+ discount = discount[pos.clamp(max=self.k + 1)]
443
488
 
444
489
  idcg = scatter( # Apply discount and aggregate:
445
- weight / discount,
446
- batch,
490
+ data.edge_label_weight / discount,
491
+ data.edge_label_index[0],
447
492
  dim_size=data.pred_index_mat.size(0),
448
493
  reduce='sum',
449
494
  )