pyg-nightly 2.7.0.dev20250407__py3-none-any.whl → 2.7.0.dev20250408__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pyg-nightly
- Version: 2.7.0.dev20250407
+ Version: 2.7.0.dev20250408
  Summary: Graph Neural Network Library for PyTorch
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -413,18 +413,6 @@ These approaches have been implemented in PyG, and can benefit from the above GN
 
  PyG is available for Python 3.9 to Python 3.12.
 
- ### Anaconda
-
- You can now install PyG via [Anaconda](https://anaconda.org/pyg/pyg) for all major OS/PyTorch/CUDA combinations 🤗
- If you have not yet installed PyTorch, install it via `conda` as described in the [official PyTorch documentation](https://pytorch.org/get-started/locally/).
- Given that you have PyTorch installed (`>=1.8.0`), simply run
-
- ```
- conda install pyg -c pyg
- ```
-
- ### PyPi
-
  From **PyG 2.3** onwards, you can install and use PyG **without any external library** required except for PyTorch.
  For this, simply run
 
@@ -448,28 +436,28 @@ We recommend to start with a minimal installation, and install additional depend
 
  For ease of installation of these extensions, we provide `pip` wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl).
 
- #### PyTorch 2.5
+ #### PyTorch 2.6
 
- To install the binaries for PyTorch 2.5.0, simply run
+ To install the binaries for PyTorch 2.6.0, simply run
 
  ```
- pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.5.0+${CUDA}.html
+ pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+${CUDA}.html
  ```
 
- where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
+ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu124`, or `cu126` depending on your PyTorch installation.
 
- | | `cpu` | `cu118` | `cu121` | `cu124` |
+ | | `cpu` | `cu118` | `cu124` | `cu126` |
  | ----------- | ----- | ------- | ------- | ------- |
  | **Linux** | ✅ | ✅ | ✅ | ✅ |
  | **Windows** | ✅ | ✅ | ✅ | ✅ |
  | **macOS** | ✅ | | | |
 
- #### PyTorch 2.4
+ #### PyTorch 2.5
 
- To install the binaries for PyTorch 2.4.0, simply run
+ To install the binaries for PyTorch 2.5.0/2.5.1, simply run
 
  ```
- pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+${CUDA}.html
+ pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.5.0+${CUDA}.html
  ```
 
  where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
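The per-version instructions in this README hunk appear to follow one pattern: the find-links URL is `https://data.pyg.org/whl/torch-<major>.<minor>.0+${CUDA}.html`, and a PyTorch 2.5.1 install reuses the `torch-2.5.0` index. A minimal sketch for assembling that URL in a setup script, assuming that pattern holds; `pyg_wheel_url` and `SUPPORTED_CUDA` are hypothetical names, and the table only encodes the CUDA tags this diff lists for PyTorch 2.6 and 2.5 (other versions should be checked against https://data.pyg.org/whl):

```python
from typing import Dict, Set

# Hypothetical table: only the CUDA tags listed in this README diff.
SUPPORTED_CUDA: Dict[str, Set[str]] = {
    '2.6': {'cpu', 'cu118', 'cu124', 'cu126'},
    '2.5': {'cpu', 'cu118', 'cu121', 'cu124'},
}


def pyg_wheel_url(torch_version: str, cuda: str) -> str:
    """Hypothetical helper: build the URL passed to ``pip install ... -f``.

    Assumes the index is keyed by ``<major>.<minor>.0``, as the README does
    for PyTorch 2.5.0/2.5.1 (both use the ``torch-2.5.0`` index).
    """
    # Drop a local version suffix such as '+cu124' before splitting.
    major, minor = torch_version.split('+')[0].split('.')[:2]
    key = f'{major}.{minor}'
    if key not in SUPPORTED_CUDA:
        raise ValueError(f'no wheel table for PyTorch {key} in this sketch')
    if cuda not in SUPPORTED_CUDA[key]:
        raise ValueError(f'{cuda!r} is not listed for PyTorch {key}')
    return f'https://data.pyg.org/whl/torch-{key}.0+{cuda}.html'


print(pyg_wheel_url('2.5.1', 'cu121'))
# https://data.pyg.org/whl/torch-2.5.0+cu121.html
```

Rejecting unlisted tags early (for example `cu121` with PyTorch 2.6) mirrors the table change in this diff, where `cu121` was replaced by `cu126`.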
@@ -480,7 +468,7 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124`
  | **Windows** | ✅ | ✅ | ✅ | ✅ |
  | **macOS** | ✅ | | | |
 
- **Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, and PyTorch 2.3.0/2.3.1 (following the same procedure).
+ **Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, PyTorch 2.3.0/2.3.1, and PyTorch 2.4.0/2.4.1 (following the same procedure).
  **For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source.
  You can look up the latest supported version number [here](https://data.pyg.org/whl).
 
@@ -1,4 +1,4 @@
- torch_geometric/__init__.py,sha256=peFG3sVQB1R7kf4erqOCMV6_UO_k1PvDHdQg1HBaqog,1978
+ torch_geometric/__init__.py,sha256=rucRaUNeMS0eQLE0b6LFcLimKD5JwNS4_nYsmtC3BMU,1978
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -188,7 +188,7 @@ torch_geometric/distributed/event_loop.py,sha256=wr3iwMYEWOGkBlvC5huD2k5YxisaGE9
  torch_geometric/distributed/local_feature_store.py,sha256=CLW37RN0ouDufEs2tY9d2nLLvpxubiT6zgW3LIHAU8k,19058
  torch_geometric/distributed/local_graph_store.py,sha256=wNHXSS824Kk2HynbtWFXx-W4pl97UUBv6qFHAVqO3W4,8445
  torch_geometric/distributed/partition.py,sha256=X0BleuY0ROlUtVXKvvz814pJwglZQ2_6OiMi1K0Hhvo,14731
- torch_geometric/distributed/rpc.py,sha256=t0Ts4tzUE0LQyBr71i2iBjQDLe9NWkmVRf7C4xOllJc,5753
+ torch_geometric/distributed/rpc.py,sha256=rJqiVR6Vbb2mpyVSC0Y5tPApqP-b1ck1Uq3IQpCsNSw,5737
  torch_geometric/distributed/utils.py,sha256=FGrr3qw7hx7EQaIjjqasurloCFJ9q_0jt8jdSIUjBeM,6567
  torch_geometric/explain/__init__.py,sha256=pRxVB33zsxhED1StRWdHboQWh3e06__g9N298Hzi42Y,359
  torch_geometric/explain/config.py,sha256=_0j67NAwPwjrWHPncNywCT-oKyMiryJNxufxVN1BFlM,7834
@@ -290,7 +290,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_
  torch_geometric/loader/utils.py,sha256=f27mczQ7fEP2HpTsJGJxKS0slPu0j8zTba3jP8ViNck,14901
  torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
  torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7z48e94,699
- torch_geometric/metrics/link_pred.py,sha256=t2YHbEYc8Jbj_4Sb-Wdk5T5uzsSErpjBpUiSqOSf-NM,30729
+ torch_geometric/metrics/link_pred.py,sha256=dtaI39JB-WqE1B-raiElns6xySRwmkbb9izbcyt6xHI,30886
  torch_geometric/nn/__init__.py,sha256=kQHHHUxFDht2ztD-XFQuv98TvC8MdodaFsIjAvltJBw,874
  torch_geometric/nn/data_parallel.py,sha256=lDAxRi83UNuzAQSj3eu9K2sQheOIU6wqR5elS6oDs90,4764
  torch_geometric/nn/encoding.py,sha256=QNjwWczYExZ1wRGBmpuqYbn6tB7NC4BU-DEgzjhcZqw,3115
@@ -441,7 +441,7 @@ torch_geometric/nn/models/g_retriever.py,sha256=CdSOasnPiMvq5AjduNTpz-LIZiNp3X0x
  torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
  torch_geometric/nn/models/glem.py,sha256=sT0XM4klVlci9wduvUoXupATUw9p25uXtaJBrmv3yvs,16431
  torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
- torch_geometric/nn/models/gpse.py,sha256=my-KIw_Ov8o0pXSCyh43NZRBAW95TFfmBgxzSimx8-A,42680
+ torch_geometric/nn/models/gpse.py,sha256=Fwldw9N3axV--BcSnI4im1sy1r87a5mAXZAXHu_6k2Y,41932
  torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
  torch_geometric/nn/models/graph_unet.py,sha256=N8TSmJo8AlbZjjcame0xW_jZvMOirL5ahw6qv5Yjpbs,5586
  torch_geometric/nn/models/jumping_knowledge.py,sha256=9JR2EoViXKjcDSLb8tjJm-UHfv1mQCJvZAAEsYa0Ocw,5496
@@ -571,7 +571,7 @@ torch_geometric/transforms/remove_duplicated_edges.py,sha256=EMX6E1R_gXFoXwBqMqt
  torch_geometric/transforms/remove_isolated_nodes.py,sha256=Q89b73es1tPsAmTdS7tWTIM7JcPUpL37v3EZTAd25Fc,2449
  torch_geometric/transforms/remove_self_loops.py,sha256=JfoooSnTO2KPXXuC3KWGhLS0tqr4yiXOl3A0sVv2riM,1221
  torch_geometric/transforms/remove_training_classes.py,sha256=GMCZwI_LYo2ZF29DABZXeuM0Sn2i3twx_V3KBUGu2As,932
- torch_geometric/transforms/rooted_subgraph.py,sha256=exQ-6bRePiH44o7f_VoBnkyj79PfGjH_hyAoolTIih8,6509
+ torch_geometric/transforms/rooted_subgraph.py,sha256=2P_2NSRcDNzfeR0zYScCaqZDkAvC0t3Ivr_lINO9abc,6504
  torch_geometric/transforms/sample_points.py,sha256=UfD44528J7SKH0I2_4ELM1A4hKKHCIDeMV6UzBbQAVU,2280
  torch_geometric/transforms/sign.py,sha256=bZUvUm9fMGXcYkI1GwPOW5ZC1QFu84vPBKpFnYxz2nA,2329
  torch_geometric/transforms/spherical.py,sha256=nU7h4IFw69JqUwRqaweVEBegHZWPOHDoTYYVRMzIZ7U,2320
@@ -592,7 +592,7 @@ torch_geometric/utils/_grid.py,sha256=1coutST2TMV9TSQcmpXze0GIK9odzZ9wBtbKs6u26D
  torch_geometric/utils/_homophily.py,sha256=1nXxGUATFPB3icEGpvEWUiuYbjU9gDGtlWpuLbtWhJk,5090
  torch_geometric/utils/_index_sort.py,sha256=FTJacmOsqgsyof7MJFHlVVdXhHOjR0j7siTb0UZ-YT0,1283
  torch_geometric/utils/_lexsort.py,sha256=chMEJJRXqfE6-K4vrVszdr3c338EhMZyi0Q9IEJD3p0,1403
- torch_geometric/utils/_negative_sampling.py,sha256=G4O572zAQgQQlVMz6ihhE13HFKEekLLYVXcYp4ZSdcQ,15521
+ torch_geometric/utils/_negative_sampling.py,sha256=jxsmpryeoTT8qQrvIH11MgyhgoWzvqPGRAcVyU85VCU,15494
  torch_geometric/utils/_normalize_edge_index.py,sha256=H6DY-Dzi1Psr3igG_nb0U3ZPNZz-BBDntO2iuA8FtzA,1682
  torch_geometric/utils/_normalized_cut.py,sha256=uwVJkl-Q0tpY-w0nvcHajcQYcqFh1oDOf55XELdjJBU,1167
  torch_geometric/utils/_one_hot.py,sha256=vXC7l7zudYRZIwWv6mT-Biuk2zKELyqteJXLynPocPM,1404
@@ -636,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
- pyg_nightly-2.7.0.dev20250407.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
- pyg_nightly-2.7.0.dev20250407.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
- pyg_nightly-2.7.0.dev20250407.dist-info/METADATA,sha256=uPWrrVVadg-GQ1_t4bwTD4LTyNy8r0m9_g5BaKW1AVs,63021
- pyg_nightly-2.7.0.dev20250407.dist-info/RECORD,,
+ pyg_nightly-2.7.0.dev20250408.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+ pyg_nightly-2.7.0.dev20250408.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+ pyg_nightly-2.7.0.dev20250408.dist-info/METADATA,sha256=OX81Kse1GucsZCReO0E00CJl4rrtT9HhHEkNEbJU6KY,62652
+ pyg_nightly-2.7.0.dev20250408.dist-info/RECORD,,
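Each RECORD line above has the form `path,sha256=<digest>,<size>`. Under the wheel format, the digest is the URL-safe base64 encoding of the file's SHA-256 hash with the trailing `=` padding removed, and the size is in bytes; the RECORD file lists itself with empty hash and size fields (`RECORD,,`). A minimal standard-library sketch for recomputing an entry, for example to check a changed file such as `rpc.py` against its listed hash; `record_entry` and `verify_entry` are hypothetical helpers, not part of any packaging tool:

```python
import base64
import hashlib


def record_entry(path: str, data: bytes) -> str:
    """Hypothetical helper: build a wheel RECORD line for ``data``."""
    digest = hashlib.sha256(data).digest()
    # URL-safe base64 ('-' and '_' instead of '+' and '/'), padding stripped.
    b64 = base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii')
    return f'{path},sha256={b64},{len(data)}'


def verify_entry(line: str, data: bytes) -> bool:
    """Hypothetical helper: check a RECORD line other than RECORD's own."""
    path = line.rsplit(',', 2)[0]
    return record_entry(path, data) == line


entry = record_entry('pkg/empty.py', b'')
print(entry)
# pkg/empty.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
```

A 32-byte SHA-256 digest always encodes to 43 characters once the single `=` is dropped, which matches the length of every hash in this RECORF hunk.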
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
- __version__ = '2.7.0.dev20250407'
+ __version__ = '2.7.0.dev20250408'
 
  __all__ = [
      'Index',
@@ -144,7 +144,7 @@ _rpc_call_pool: Dict[int, RPCCallBase] = {}
  @rpc_require_initialized
  def rpc_register(call: RPCCallBase) -> int:
      r"""Registers a call for RPC requests."""
-     global _rpc_call_id, _rpc_call_pool
+     global _rpc_call_id
 
      with _rpc_call_lock:
          call_id = _rpc_call_id
@@ -670,6 +670,9 @@ class LinkPredDiversity(_LinkPredMetric):
      def __init__(self, k: int, category: Tensor) -> None:
          super().__init__(k)
 
+         self.accum: Tensor
+         self.total: Tensor
+
          if WITH_TORCHMETRICS:
              self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
              self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
@@ -736,11 +739,14 @@ class LinkPredPersonalization(_LinkPredMetric):
          self.max_src_nodes = max_src_nodes
          self.batch_size = batch_size
 
+         self.preds: List[Tensor]
+         self.total: Tensor
+
          if WITH_TORCHMETRICS:
              self.add_state('preds', default=[], dist_reduce_fx='cat')
              self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
          else:
-             self.preds: List[Tensor] = []
+             self.preds = []
              self.register_buffer('total', torch.tensor(0), persistent=False)
 
      def update(
@@ -826,6 +832,9 @@ class LinkPredAveragePopularity(_LinkPredMetric):
      def __init__(self, k: int, popularity: Tensor) -> None:
          super().__init__(k)
 
+         self.accum: Tensor
+         self.total: Tensor
+
          if WITH_TORCHMETRICS:
              self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
              self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
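The three link-prediction metric hunks now declare `self.accum`, `self.total`, and `self.preds` with bare annotations before the `if WITH_TORCHMETRICS:` branch, instead of annotating inside one branch only. A bare annotation such as `self.total: Tensor` creates no attribute at runtime; it only tells static type checkers the attribute's type. Because `add_state` and `register_buffer` set attributes dynamically, which a checker cannot follow, hoisting the declarations makes the type apply whichever branch runs. A pure-Python sketch of the pattern, where the `use_state` flag and the `setattr` calls are hypothetical stand-ins for `WITH_TORCHMETRICS` and the dynamic registration:

```python
from typing import List


class Metric:
    """Hypothetical metric illustrating hoisted attribute annotations."""

    def __init__(self, use_state: bool) -> None:
        # Bare annotations: nothing is assigned here; they only declare
        # types for static checkers and cover both branches below.
        self.preds: List[int]
        self.total: int

        if use_state:
            # Stand-in for dynamic registration (e.g. add_state()).
            setattr(self, 'preds', [])
            setattr(self, 'total', 0)
        else:
            self.preds = []
            self.total = 0

    def update(self, pred: int) -> None:
        self.preds.append(pred)
        self.total += 1
```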
@@ -379,7 +379,7 @@ class GPSE(torch.nn.Module):
 
      .. code-block:: python
 
-         from torch_geometric.nn import GPSE, GPSENodeEncoder,
+         from torch_geometric.nn import GPSE, GPSENodeEncoder
          from torch_geometric.transforms import AddGPSE
          from torch_geometric.nn.models.gpse import precompute_GPSE
 
@@ -417,13 +417,11 @@ class GPSE(torch.nn.Module):
 
          encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
                                    expand_x=False)
-         gnn = GNN(dim_in=128, dim_out=128, num_layers=4)
+         gnn = GNN(...)
 
          for batch in loader:
-             batch = encoder(batch)
-             batch = gnn(batch)
-             # Do something with the batch, which now includes 128-dimensional
-             # node representations
+             x = encoder(batch.x, batch.pestat_GPSE)
+             out = gnn(x, batch.edge_index)
 
 
      Args:
@@ -571,8 +569,7 @@ class GPSE(torch.nn.Module):
          self.reset_parameters()
 
      def reset_parameters(self):
-         from torch_geometric.graphgym.init import init_weights
-         self.apply(init_weights)
+         pass
 
      @classmethod
      def from_pretrained(cls, name: str, root: str = 'GPSE_pretrained'):
@@ -608,6 +605,7 @@ class GPSE(torch.nn.Module):
          return model
 
      def forward(self, batch):
+         batch = batch.clone()
          for module in self.children():
              batch = module(batch)
          return batch
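`GPSE.forward` now calls `batch.clone()` before running its child modules, so any module that assigns to attributes of the batch (such as `x`) acts on a copy and the caller's batch is left unchanged. PyG documents `clone()` on its data objects as cloning the tensors of all attributes. A dependency-free sketch of the pattern; `Batch`, `AddOne`, and `Pipeline` are hypothetical stand-ins, with lists in place of tensors and `copy.deepcopy` in place of tensor cloning:

```python
import copy
from typing import Any, List


class Batch:
    """Hypothetical minimal container with attribute storage."""

    def __init__(self, **kwargs: Any) -> None:
        self.__dict__.update(kwargs)

    def clone(self) -> 'Batch':
        # Deep copy so nested lists (standing in for tensors) are new objects.
        return copy.deepcopy(self)


class AddOne:
    """Hypothetical module that overwrites ``batch.x`` on the object it gets."""

    def __call__(self, batch: Batch) -> Batch:
        batch.x = [v + 1 for v in batch.x]
        return batch


class Pipeline:
    """Hypothetical container that applies its modules in order."""

    def __init__(self, *modules: Any) -> None:
        self.modules: List[Any] = list(modules)

    def forward(self, batch: Batch) -> Batch:
        batch = batch.clone()  # Mirrors the line added to GPSE.forward.
        for module in self.modules:
            batch = module(batch)
        return batch


inp = Batch(x=[1, 2])
out = Pipeline(AddOne(), AddOne()).forward(inp)
print(out.x, inp.x)  # [3, 4] [1, 2]
```

Without the `clone()` call, the same `Batch` object would flow through every module, and `inp.x` would also end up as `[3, 4]`.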
@@ -635,13 +633,11 @@ class GPSENodeEncoder(torch.nn.Module):
 
          encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
                                    expand_x=False)
-         gnn = GNN(dim_in=128, dim_out=128, num_layers=4)
+         gnn = GNN(...)
 
          for batch in loader:
-             batch = encoder(batch)
-             batch = gnn(batch)
-             # Do something with the batch, which now includes 128-dimensional
-             # node representations
+             x = encoder(batch.x, batch.pestat_GPSE)
+             batch = gnn(x, batch.edge_index)
 
      Args:
          dim_emb (int): Size of final node embedding.
@@ -705,28 +701,17 @@ class GPSENodeEncoder(torch.nn.Module):
              raise ValueError(f"{self.__class__.__name__}: Does not support "
                               f"'{model_type}' encoder model.")
 
-     def forward(self, batch):
-         if not hasattr(batch, 'pestat_GPSE'):
-             raise ValueError('Precomputed "pestat_GPSE" variable is required '
-                              'for GNNNodeEncoder; either run '
-                              '`precompute_GPSE(gpse_model, dataset)` on your '
-                              'dataset or add `AddGPSE(gpse_model)` as a (pre) '
-                              'transform.')
-
-         pos_enc = batch.pestat_GPSE
-
+     def forward(self, x, pos_enc):
          pos_enc = self.dropout_be(pos_enc)
          pos_enc = self.raw_norm(pos_enc) if self.raw_norm else pos_enc
          pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe
          pos_enc = self.dropout_ae(pos_enc)
 
          # Expand node features if needed
-         h = self.linear_x(batch.x) if self.expand_x else batch.x
+         h = self.linear_x(x) if self.expand_x else x
 
          # Concatenate final PEs to input embedding
-         batch.x = torch.cat((h, pos_enc), 1)
-
-         return batch
+         return torch.cat((h, pos_enc), 1)
 
 
  @torch.no_grad()
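After this change, `GPSENodeEncoder.forward(x, pos_enc)` takes the node features and the precomputed `pestat_GPSE` encodings directly and returns their concatenation along the feature dimension, instead of reading from and writing to a batch object. The shape bookkeeping can be sketched with NumPy. This is an illustration under stated assumptions, not the module itself: dropout and normalization are omitted, and the random matrices are placeholders for the learned `pe_encoder` and `linear_x` layers.

```python
from typing import Optional

import numpy as np


def encode(x: np.ndarray, pos_enc: np.ndarray, w_pe: np.ndarray,
           w_x: Optional[np.ndarray] = None) -> np.ndarray:
    """Tensor-in, tensor-out sketch of the encoder's forward pass.

    ``w_pe`` is a placeholder linear map for the positional-encoding network
    and ``w_x`` (if given) for the optional feature expansion.
    """
    pos_enc = pos_enc @ w_pe  # [num_nodes, dim_pe_out]
    h = x @ w_x if w_x is not None else x  # [num_nodes, dim_h]
    # Concatenate along the feature axis, as torch.cat((h, pos_enc), 1) does.
    return np.concatenate((h, pos_enc), axis=1)


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 128))      # node features
pestat = rng.normal(size=(5, 32))  # stands in for batch.pestat_GPSE
out = encode(x, pestat, w_pe=rng.normal(size=(32, 64)))
print(out.shape)  # (5, 192)
```

Since the removed `hasattr` check no longer guards the input, producing `pestat_GPSE` (through `precompute_GPSE` or the `AddGPSE` transform named in the deleted error message) is now the caller's responsibility.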
@@ -94,7 +94,7 @@ class RootedSubgraph(BaseTransform, ABC):
          arange = torch.arange(n_id.size(0), device=data.edge_index.device)
          node_map = data.edge_index.new_ones(num_nodes, num_nodes)
          node_map[n_sub_batch, n_id] = arange
-         sub_edge_index += (arange * data.num_nodes)[e_sub_batch]
+         sub_edge_index += (arange * num_nodes)[e_sub_batch]
          sub_edge_index = node_map.view(-1)[sub_edge_index]
 
          return sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch
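The changed line belongs to a flattened-index relabeling. `node_map` is a `(num_nodes, num_nodes)` table whose row is the subgraph index and whose column is the global node ID, holding each node's position in the concatenated `n_id`. Adding `subgraph_index * num_nodes` to each global endpoint ID and indexing `node_map.view(-1)` therefore looks up each endpoint's local position, so the offset must equal the table's row stride, `num_nodes`. Because `arange[i] == i`, `(arange * num_nodes)[e_sub_batch]` is simply `e_sub_batch * num_nodes` here. A NumPy sketch of the same computation, where the function name `relabel` is hypothetical:

```python
import numpy as np


def relabel(sub_edge_index: np.ndarray, n_id: np.ndarray,
            n_sub_batch: np.ndarray, e_sub_batch: np.ndarray,
            num_nodes: int) -> np.ndarray:
    """Hypothetical helper: map global endpoint IDs to positions in n_id."""
    # Row = subgraph index, column = global node ID, value = position in n_id.
    # Filled with ones as in the original; unused cells are not read as long
    # as every edge endpoint is a node of its own subgraph.
    node_map = np.ones((num_nodes, num_nodes), dtype=np.int64)
    node_map[n_sub_batch, n_id] = np.arange(n_id.shape[0])
    # Offset each edge by its subgraph's row; num_nodes is the row stride.
    flat_idx = sub_edge_index + e_sub_batch * num_nodes
    return node_map.reshape(-1)[flat_idx]


# Two subgraphs over a 3-node graph: {0, 1} and {1, 2}.
out = relabel(np.array([[0, 1, 2], [1, 2, 1]]), np.array([0, 1, 1, 2]),
              np.array([0, 0, 1, 1]), np.array([0, 1, 1]), num_nodes=3)
print(out.tolist())  # [[0, 2, 3], [1, 3, 2]]
```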
@@ -109,10 +109,9 @@ def negative_sampling(
          idx = idx.to('cpu')
          for _ in range(3):  # Number of tries to sample negative indices.
              rnd = sample(population, sample_size, device='cpu')
-             mask = np.isin(rnd.numpy(), idx.numpy())  # type: ignore
+             mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool()
              if neg_idx is not None:
-                 mask |= np.isin(rnd, neg_idx.to('cpu'))
-             mask = torch.from_numpy(mask).to(torch.bool)
+                 mask |= torch.from_numpy(np.isin(rnd, neg_idx.cpu())).bool()
              rnd = rnd[~mask].to(edge_index.device)
              neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd])
              if neg_idx.numel() >= num_neg_samples:
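The sparse branch of `negative_sampling` draws candidate indices on the CPU, rejects any that are true edges (`idx`) or were already accepted (`neg_idx`) using `np.isin`, and retries up to three times. The change converts each NumPy mask to a boolean tensor as soon as it is computed, so `|=` combines two tensors rather than a NumPy array and a tensor. A NumPy-only sketch of the same rejection loop follows. `rng.choice(..., replace=False)` is an assumed stand-in for PyG's internal `sample` helper, and the final truncation is an assumption because the hunk ends at the `if`:

```python
import numpy as np


def sample_negatives(idx: np.ndarray, population: int, num_neg_samples: int,
                     sample_size: int, rng: np.random.Generator) -> np.ndarray:
    """Rejection-sample up to ``num_neg_samples`` indices not in ``idx``."""
    neg_idx = None
    for _ in range(3):  # Number of tries, as in the original loop.
        rnd = rng.choice(population, size=min(sample_size, population),
                         replace=False)
        mask = np.isin(rnd, idx)  # Reject positive (true-edge) indices.
        if neg_idx is not None:
            mask |= np.isin(rnd, neg_idx)  # Reject already-accepted ones.
        rnd = rnd[~mask]
        neg_idx = rnd if neg_idx is None else np.concatenate([neg_idx, rnd])
        if neg_idx.size >= num_neg_samples:
            break
    return neg_idx[:num_neg_samples]  # Assumed truncation after the loop.


rng = np.random.default_rng(0)
positives = np.arange(0, 100, 10)  # 10 "true edge" indices out of 100
negatives = sample_negatives(positives, 100, 10, 20, rng)
print(len(negatives))  # 10
```

In this example the first draw already suffices: 20 distinct candidates can include at most the 10 positives, so at least 10 survive the mask.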