pyg-nightly 2.7.0.dev20250406__py3-none-any.whl → 2.7.0.dev20250408__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pyg-nightly
- Version: 2.7.0.dev20250406
+ Version: 2.7.0.dev20250408
  Summary: Graph Neural Network Library for PyTorch
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -413,18 +413,6 @@ These approaches have been implemented in PyG, and can benefit from the above GN
  
  PyG is available for Python 3.9 to Python 3.12.
  
- ### Anaconda
-
- You can now install PyG via [Anaconda](https://anaconda.org/pyg/pyg) for all major OS/PyTorch/CUDA combinations 🤗
- If you have not yet installed PyTorch, install it via `conda` as described in the [official PyTorch documentation](https://pytorch.org/get-started/locally/).
- Given that you have PyTorch installed (`>=1.8.0`), simply run
-
- ```
- conda install pyg -c pyg
- ```
-
- ### PyPi
-
  From **PyG 2.3** onwards, you can install and use PyG **without any external library** required except for PyTorch.
  For this, simply run
  
@@ -448,28 +436,28 @@ We recommend to start with a minimal installation, and install additional depend
  
  For ease of installation of these extensions, we provide `pip` wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl).
  
- #### PyTorch 2.5
+ #### PyTorch 2.6
  
- To install the binaries for PyTorch 2.5.0, simply run
+ To install the binaries for PyTorch 2.6.0, simply run
  
  ```
- pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.5.0+${CUDA}.html
+ pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+${CUDA}.html
  ```
  
- where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
+ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu124`, or `cu126` depending on your PyTorch installation.
  
- |             | `cpu` | `cu118` | `cu121` | `cu124` |
+ |             | `cpu` | `cu118` | `cu124` | `cu126` |
  | ----------- | ----- | ------- | ------- | ------- |
  | **Linux**   | ✅    | ✅      | ✅      | ✅      |
  | **Windows** | ✅    | ✅      | ✅      | ✅      |
  | **macOS**   | ✅    |         |         |         |
  
- #### PyTorch 2.4
+ #### PyTorch 2.5
  
- To install the binaries for PyTorch 2.4.0, simply run
+ To install the binaries for PyTorch 2.5.0/2.5.1, simply run
  
  ```
- pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+${CUDA}.html
+ pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.5.0+${CUDA}.html
  ```
  
  where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
@@ -480,7 +468,7 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124`
  | **Windows** | ✅    | ✅      | ✅      | ✅      |
  | **macOS**   | ✅    |         |         |         |
  
- **Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, and PyTorch 2.3.0/2.3.1 (following the same procedure).
+ **Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, PyTorch 2.3.0/2.3.1, and PyTorch 2.4.0/2.4.1 (following the same procedure).
  **For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source.
  You can look up the latest supported version number [here](https://data.pyg.org/whl).
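
The correct index URL can also be derived from the installed PyTorch build rather than typed by hand. A minimal sketch, assuming a standard PyTorch install (`torch.version.cuda` is `None` on CPU-only builds):

```
import torch

# Derive the matching wheel index for the local PyTorch/CUDA combination.
version = torch.__version__.split('+')[0]  # e.g. '2.6.0'
cuda = torch.version.cuda                  # e.g. '12.6', or None on CPU builds
suffix = 'cpu' if cuda is None else 'cu' + cuda.replace('.', '')
print(f'https://data.pyg.org/whl/torch-{version}+{suffix}.html')
```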
 
@@ -1,4 +1,4 @@
- torch_geometric/__init__.py,sha256=EFUlgJy_cHoHOgqO8KCynWIfRJFW8DFqG7O5v9DFOzI,1978
+ torch_geometric/__init__.py,sha256=rucRaUNeMS0eQLE0b6LFcLimKD5JwNS4_nYsmtC3BMU,1978
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -188,7 +188,7 @@ torch_geometric/distributed/event_loop.py,sha256=wr3iwMYEWOGkBlvC5huD2k5YxisaGE9
  torch_geometric/distributed/local_feature_store.py,sha256=CLW37RN0ouDufEs2tY9d2nLLvpxubiT6zgW3LIHAU8k,19058
  torch_geometric/distributed/local_graph_store.py,sha256=wNHXSS824Kk2HynbtWFXx-W4pl97UUBv6qFHAVqO3W4,8445
  torch_geometric/distributed/partition.py,sha256=X0BleuY0ROlUtVXKvvz814pJwglZQ2_6OiMi1K0Hhvo,14731
- torch_geometric/distributed/rpc.py,sha256=t0Ts4tzUE0LQyBr71i2iBjQDLe9NWkmVRf7C4xOllJc,5753
+ torch_geometric/distributed/rpc.py,sha256=rJqiVR6Vbb2mpyVSC0Y5tPApqP-b1ck1Uq3IQpCsNSw,5737
  torch_geometric/distributed/utils.py,sha256=FGrr3qw7hx7EQaIjjqasurloCFJ9q_0jt8jdSIUjBeM,6567
  torch_geometric/explain/__init__.py,sha256=pRxVB33zsxhED1StRWdHboQWh3e06__g9N298Hzi42Y,359
  torch_geometric/explain/config.py,sha256=_0j67NAwPwjrWHPncNywCT-oKyMiryJNxufxVN1BFlM,7834
@@ -200,7 +200,7 @@ torch_geometric/explain/algorithm/base.py,sha256=wwJcREUFKDLFUDjRa9o4X3DWqQgMvhS
  torch_geometric/explain/algorithm/captum.py,sha256=k6hNgC5Kn9lVirOYVJzej8-hRuf5C2mPFUXFLd2wWsY,12857
  torch_geometric/explain/algorithm/captum_explainer.py,sha256=oz-c40hvdzii4_chEQPHzQo_dFjHr9HLuJhDLsqRIVU,7346
  torch_geometric/explain/algorithm/dummy_explainer.py,sha256=jvcVQmfngmUWgoKa5p7CXzju2HM5D5DfieJhZW3gbLc,2872
- torch_geometric/explain/algorithm/gnn_explainer.py,sha256=TRGwaKYn9nLn3fp0rSSzeGs9uHj2rZzfomMseDfq8O8,12454
+ torch_geometric/explain/algorithm/gnn_explainer.py,sha256=iu45fGWdd4c6wNczWEAT-29HCAz7ncuoaS6cpx-xDJM,24660
  torch_geometric/explain/algorithm/graphmask_explainer.py,sha256=T2B081dK-JSpaQmutnkQd5xF3JF49_dPZCOgwqIKJDk,21367
  torch_geometric/explain/algorithm/pg_explainer.py,sha256=zPsl0tT9ISSWK1xP1KKpe1ZjUarhSB736WTtqwcmDIo,10372
  torch_geometric/explain/algorithm/utils.py,sha256=eh0ARPG41V7piVw5jdMYpV0p7WjTlpehnY-bWqPV_zg,2564
@@ -290,7 +290,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_
  torch_geometric/loader/utils.py,sha256=f27mczQ7fEP2HpTsJGJxKS0slPu0j8zTba3jP8ViNck,14901
  torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
  torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7z48e94,699
- torch_geometric/metrics/link_pred.py,sha256=t2YHbEYc8Jbj_4Sb-Wdk5T5uzsSErpjBpUiSqOSf-NM,30729
+ torch_geometric/metrics/link_pred.py,sha256=dtaI39JB-WqE1B-raiElns6xySRwmkbb9izbcyt6xHI,30886
  torch_geometric/nn/__init__.py,sha256=kQHHHUxFDht2ztD-XFQuv98TvC8MdodaFsIjAvltJBw,874
  torch_geometric/nn/data_parallel.py,sha256=lDAxRi83UNuzAQSj3eu9K2sQheOIU6wqR5elS6oDs90,4764
  torch_geometric/nn/encoding.py,sha256=QNjwWczYExZ1wRGBmpuqYbn6tB7NC4BU-DEgzjhcZqw,3115
@@ -441,7 +441,7 @@ torch_geometric/nn/models/g_retriever.py,sha256=CdSOasnPiMvq5AjduNTpz-LIZiNp3X0x
  torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
  torch_geometric/nn/models/glem.py,sha256=sT0XM4klVlci9wduvUoXupATUw9p25uXtaJBrmv3yvs,16431
  torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
- torch_geometric/nn/models/gpse.py,sha256=my-KIw_Ov8o0pXSCyh43NZRBAW95TFfmBgxzSimx8-A,42680
+ torch_geometric/nn/models/gpse.py,sha256=Fwldw9N3axV--BcSnI4im1sy1r87a5mAXZAXHu_6k2Y,41932
  torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
  torch_geometric/nn/models/graph_unet.py,sha256=N8TSmJo8AlbZjjcame0xW_jZvMOirL5ahw6qv5Yjpbs,5586
  torch_geometric/nn/models/jumping_knowledge.py,sha256=9JR2EoViXKjcDSLb8tjJm-UHfv1mQCJvZAAEsYa0Ocw,5496
@@ -571,7 +571,7 @@ torch_geometric/transforms/remove_duplicated_edges.py,sha256=EMX6E1R_gXFoXwBqMqt
  torch_geometric/transforms/remove_isolated_nodes.py,sha256=Q89b73es1tPsAmTdS7tWTIM7JcPUpL37v3EZTAd25Fc,2449
  torch_geometric/transforms/remove_self_loops.py,sha256=JfoooSnTO2KPXXuC3KWGhLS0tqr4yiXOl3A0sVv2riM,1221
  torch_geometric/transforms/remove_training_classes.py,sha256=GMCZwI_LYo2ZF29DABZXeuM0Sn2i3twx_V3KBUGu2As,932
- torch_geometric/transforms/rooted_subgraph.py,sha256=exQ-6bRePiH44o7f_VoBnkyj79PfGjH_hyAoolTIih8,6509
+ torch_geometric/transforms/rooted_subgraph.py,sha256=2P_2NSRcDNzfeR0zYScCaqZDkAvC0t3Ivr_lINO9abc,6504
  torch_geometric/transforms/sample_points.py,sha256=UfD44528J7SKH0I2_4ELM1A4hKKHCIDeMV6UzBbQAVU,2280
  torch_geometric/transforms/sign.py,sha256=bZUvUm9fMGXcYkI1GwPOW5ZC1QFu84vPBKpFnYxz2nA,2329
  torch_geometric/transforms/spherical.py,sha256=nU7h4IFw69JqUwRqaweVEBegHZWPOHDoTYYVRMzIZ7U,2320
@@ -592,7 +592,7 @@ torch_geometric/utils/_grid.py,sha256=1coutST2TMV9TSQcmpXze0GIK9odzZ9wBtbKs6u26D
  torch_geometric/utils/_homophily.py,sha256=1nXxGUATFPB3icEGpvEWUiuYbjU9gDGtlWpuLbtWhJk,5090
  torch_geometric/utils/_index_sort.py,sha256=FTJacmOsqgsyof7MJFHlVVdXhHOjR0j7siTb0UZ-YT0,1283
  torch_geometric/utils/_lexsort.py,sha256=chMEJJRXqfE6-K4vrVszdr3c338EhMZyi0Q9IEJD3p0,1403
- torch_geometric/utils/_negative_sampling.py,sha256=G4O572zAQgQQlVMz6ihhE13HFKEekLLYVXcYp4ZSdcQ,15521
+ torch_geometric/utils/_negative_sampling.py,sha256=jxsmpryeoTT8qQrvIH11MgyhgoWzvqPGRAcVyU85VCU,15494
  torch_geometric/utils/_normalize_edge_index.py,sha256=H6DY-Dzi1Psr3igG_nb0U3ZPNZz-BBDntO2iuA8FtzA,1682
  torch_geometric/utils/_normalized_cut.py,sha256=uwVJkl-Q0tpY-w0nvcHajcQYcqFh1oDOf55XELdjJBU,1167
  torch_geometric/utils/_one_hot.py,sha256=vXC7l7zudYRZIwWv6mT-Biuk2zKELyqteJXLynPocPM,1404
@@ -636,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
- pyg_nightly-2.7.0.dev20250406.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
- pyg_nightly-2.7.0.dev20250406.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
- pyg_nightly-2.7.0.dev20250406.dist-info/METADATA,sha256=csAfUo5zCWohsFtqQsynpzuSflZRg7_4f9DB3JWSVWE,63021
- pyg_nightly-2.7.0.dev20250406.dist-info/RECORD,,
+ pyg_nightly-2.7.0.dev20250408.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+ pyg_nightly-2.7.0.dev20250408.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+ pyg_nightly-2.7.0.dev20250408.dist-info/METADATA,sha256=OX81Kse1GucsZCReO0E00CJl4rrtT9HhHEkNEbJU6KY,62652
+ pyg_nightly-2.7.0.dev20250408.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
  
- __version__ = '2.7.0.dev20250406'
+ __version__ = '2.7.0.dev20250408'
  
  __all__ = [
      'Index',
@@ -144,7 +144,7 @@ _rpc_call_pool: Dict[int, RPCCallBase] = {}
  @rpc_require_initialized
  def rpc_register(call: RPCCallBase) -> int:
      r"""Registers a call for RPC requests."""
-     global _rpc_call_id, _rpc_call_pool
+     global _rpc_call_id
  
      with _rpc_call_lock:
          call_id = _rpc_call_id
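
The dropped `_rpc_call_pool` name reflects a general Python rule: `global` is only required to rebind a module-level name, while in-place mutation (such as inserting into the pool dictionary) needs no declaration. A standalone sketch with hypothetical names:

```
_registry = {}  # mutated in place below: no `global` needed
_next_id = 0    # rebound below: `global` required

def register(obj) -> int:
    global _next_id
    _registry[_next_id] = obj  # mutating the dict is fine without `global`
    _next_id += 1              # rebinding, covered by the declaration above
    return _next_id - 1
```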
@@ -1,14 +1,24 @@
  from math import sqrt
- from typing import Optional, Tuple, Union
+ from typing import Dict, Optional, Tuple, Union, overload
  
  import torch
  from torch import Tensor
  from torch.nn.parameter import Parameter
  
- from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig
+ from torch_geometric.explain import (
+     ExplainerConfig,
+     Explanation,
+     HeteroExplanation,
+     ModelConfig,
+ )
  from torch_geometric.explain.algorithm import ExplainerAlgorithm
- from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
+ from torch_geometric.explain.algorithm.utils import (
+     clear_masks,
+     set_hetero_masks,
+     set_masks,
+ )
  from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel
+ from torch_geometric.typing import EdgeType, NodeType
  
  
  class GNNExplainer(ExplainerAlgorithm):
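
The new `overload` import supports the pattern used throughout the hunks below: two `@overload` stubs describe the homogeneous and heterogeneous signatures for type checkers, while a single union-typed implementation does the actual work. The mechanism in miniature, with a hypothetical function:

```
from typing import Dict, Union, overload

@overload
def describe(x: int) -> str: ...
@overload
def describe(x: Dict[str, int]) -> Dict[str, str]: ...

def describe(x: Union[int, Dict[str, int]]):
    # Only this implementation runs; the `...` stubs exist purely so that
    # static type checkers can match call sites to a precise signature.
    if isinstance(x, dict):
        return {k: str(v) for k, v in x.items()}
    return str(x)
```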
@@ -69,7 +79,9 @@ class GNNExplainer(ExplainerAlgorithm):
  
          self.node_mask = self.hard_node_mask = None
          self.edge_mask = self.hard_edge_mask = None
+         self.is_hetero = False
  
+     @overload
      def forward(
          self,
          model: torch.nn.Module,
@@ -80,30 +92,87 @@ class GNNExplainer(ExplainerAlgorithm):
          index: Optional[Union[int, Tensor]] = None,
          **kwargs,
      ) -> Explanation:
-         if isinstance(x, dict) or isinstance(edge_index, dict):
-             raise ValueError(f"Heterogeneous graphs not yet supported in "
-                              f"'{self.__class__.__name__}'")
+         ...
  
-         self._train(model, x, edge_index, target=target, index=index, **kwargs)
-
-         node_mask = self._post_process_mask(
-             self.node_mask,
-             self.hard_node_mask,
-             apply_sigmoid=True,
-         )
-         edge_mask = self._post_process_mask(
-             self.edge_mask,
-             self.hard_edge_mask,
-             apply_sigmoid=True,
-         )
+     @overload
+     def forward(
+         self,
+         model: torch.nn.Module,
+         x: Dict[NodeType, Tensor],
+         edge_index: Dict[EdgeType, Tensor],
+         *,
+         target: Tensor,
+         index: Optional[Union[int, Tensor]] = None,
+         **kwargs,
+     ) -> HeteroExplanation:
+         ...
  
+     def forward(
+         self,
+         model: torch.nn.Module,
+         x: Union[Tensor, Dict[NodeType, Tensor]],
+         edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+         *,
+         target: Tensor,
+         index: Optional[Union[int, Tensor]] = None,
+         **kwargs,
+     ) -> Union[Explanation, HeteroExplanation]:
+         self.is_hetero = isinstance(x, dict)
+         self._train(model, x, edge_index, target=target, index=index, **kwargs)
+         explanation = self._create_explanation()
          self._clean_model(model)
+         return explanation
+
+     def _create_explanation(self) -> Union[Explanation, HeteroExplanation]:
+         """Create an explanation object from the current masks."""
+         if self.is_hetero:
+             # For heterogeneous graphs, process each type separately
+             node_mask_dict = {}
+             edge_mask_dict = {}
+
+             for node_type, mask in self.node_mask.items():
+                 if mask is not None:
+                     node_mask_dict[node_type] = self._post_process_mask(
+                         mask,
+                         self.hard_node_mask[node_type],
+                         apply_sigmoid=True,
+                     )
+
+             for edge_type, mask in self.edge_mask.items():
+                 if mask is not None:
+                     edge_mask_dict[edge_type] = self._post_process_mask(
+                         mask,
+                         self.hard_edge_mask[edge_type],
+                         apply_sigmoid=True,
+                     )
+
+             # Create heterogeneous explanation
+             explanation = HeteroExplanation()
+             explanation.set_value_dict('node_mask', node_mask_dict)
+             explanation.set_value_dict('edge_mask', edge_mask_dict)
  
-         return Explanation(node_mask=node_mask, edge_mask=edge_mask)
+         else:
+             # For homogeneous graphs, process single masks
+             node_mask = self._post_process_mask(
+                 self.node_mask,
+                 self.hard_node_mask,
+                 apply_sigmoid=True,
+             )
+             edge_mask = self._post_process_mask(
+                 self.edge_mask,
+                 self.hard_edge_mask,
+                 apply_sigmoid=True,
+             )
+
+             # Create homogeneous explanation
+             explanation = Explanation(node_mask=node_mask, edge_mask=edge_mask)
+
+         return explanation
  
      def supports(self) -> bool:
          return True
  
+     @overload
      def _train(
          self,
          model: torch.nn.Module,
@@ -113,57 +182,222 @@ class GNNExplainer(ExplainerAlgorithm):
          target: Tensor,
          index: Optional[Union[int, Tensor]] = None,
          **kwargs,
-     ):
+     ) -> None:
+         ...
+
+     @overload
+     def _train(
+         self,
+         model: torch.nn.Module,
+         x: Dict[NodeType, Tensor],
+         edge_index: Dict[EdgeType, Tensor],
+         *,
+         target: Tensor,
+         index: Optional[Union[int, Tensor]] = None,
+         **kwargs,
+     ) -> None:
+         ...
+
+     def _train(
+         self,
+         model: torch.nn.Module,
+         x: Union[Tensor, Dict[NodeType, Tensor]],
+         edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+         *,
+         target: Tensor,
+         index: Optional[Union[int, Tensor]] = None,
+         **kwargs,
+     ) -> None:
+         # Initialize masks based on input type
          self._initialize_masks(x, edge_index)
  
-         parameters = []
-         if self.node_mask is not None:
-             parameters.append(self.node_mask)
-         if self.edge_mask is not None:
-             set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
-             parameters.append(self.edge_mask)
+         # Collect parameters for optimization
+         parameters = self._collect_parameters(model, edge_index)
  
+         # Create optimizer
          optimizer = torch.optim.Adam(parameters, lr=self.lr)
  
+         # Training loop
          for i in range(self.epochs):
              optimizer.zero_grad()
  
-             h = x if self.node_mask is None else x * self.node_mask.sigmoid()
-             y_hat, y = model(h, edge_index, **kwargs), target
+             # Forward pass with masked inputs
+             y_hat = self._forward_with_masks(model, x, edge_index, **kwargs)
+             y = target
  
+             # Handle index if provided
              if index is not None:
                  y_hat, y = y_hat[index], y[index]
  
+             # Calculate loss
              loss = self._loss(y_hat, y)
  
+             # Backward pass
              loss.backward()
              optimizer.step()
  
-             # In the first iteration, we collect the nodes and edges that are
-             # involved into making the prediction. These are all the nodes and
-             # edges with gradient != 0 (without regularization applied).
-             if i == 0 and self.node_mask is not None:
-                 if self.node_mask.grad is None:
-                     raise ValueError("Could not compute gradients for node "
-                                      "features. Please make sure that node "
-                                      "features are used inside the model or "
-                                      "disable it via `node_mask_type=None`.")
-                 self.hard_node_mask = self.node_mask.grad != 0.0
-             if i == 0 and self.edge_mask is not None:
-                 if self.edge_mask.grad is None:
-                     raise ValueError("Could not compute gradients for edges. "
-                                      "Please make sure that edges are used "
-                                      "via message passing inside the model or "
-                                      "disable it via `edge_mask_type=None`.")
-                 self.hard_edge_mask = self.edge_mask.grad != 0.0
-
-     def _initialize_masks(self, x: Tensor, edge_index: Tensor):
+             # In the first iteration, collect gradients to identify important
+             # nodes/edges
+             if i == 0:
+                 self._collect_gradients()
+
+     def _collect_parameters(self, model, edge_index):
+         """Collect parameters for optimization."""
+         parameters = []
+
+         if self.is_hetero:
+             # For heterogeneous graphs, collect parameters from all types
+             for mask in self.node_mask.values():
+                 if mask is not None:
+                     parameters.append(mask)
+             if any(v is not None for v in self.edge_mask.values()):
+                 set_hetero_masks(model, self.edge_mask, edge_index)
+                 for mask in self.edge_mask.values():
+                     if mask is not None:
+                         parameters.append(mask)
+         else:
+             # For homogeneous graphs, collect single parameters
+             if self.node_mask is not None:
+                 parameters.append(self.node_mask)
+             if self.edge_mask is not None:
+                 set_masks(model, self.edge_mask, edge_index,
+                           apply_sigmoid=True)
+                 parameters.append(self.edge_mask)
+
+         return parameters
+
+     @overload
+     def _forward_with_masks(
+         self,
+         model: torch.nn.Module,
+         x: Tensor,
+         edge_index: Tensor,
+         **kwargs,
+     ) -> Tensor:
+         ...
+
+     @overload
+     def _forward_with_masks(
+         self,
+         model: torch.nn.Module,
+         x: Dict[NodeType, Tensor],
+         edge_index: Dict[EdgeType, Tensor],
+         **kwargs,
+     ) -> Tensor:
+         ...
+
+     def _forward_with_masks(
+         self,
+         model: torch.nn.Module,
+         x: Union[Tensor, Dict[NodeType, Tensor]],
+         edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+         **kwargs,
+     ) -> Tensor:
+         """Forward pass with masked inputs."""
+         if self.is_hetero:
+             # Apply masks to heterogeneous inputs
+             h_dict = {}
+             for node_type, features in x.items():
+                 if node_type in self.node_mask and self.node_mask[
+                         node_type] is not None:
+                     h_dict[node_type] = features * self.node_mask[
+                         node_type].sigmoid()
+                 else:
+                     h_dict[node_type] = features
+
+             # Forward pass with masked features
+             return model(h_dict, edge_index, **kwargs)
+         else:
+             # Apply mask to homogeneous input
+             h = x if self.node_mask is None else x * self.node_mask.sigmoid()
+
+             # Forward pass with masked features
+             return model(h, edge_index, **kwargs)
+
+     def _initialize_masks(
+         self,
+         x: Union[Tensor, Dict[NodeType, Tensor]],
+         edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+     ) -> None:
          node_mask_type = self.explainer_config.node_mask_type
          edge_mask_type = self.explainer_config.edge_mask_type
  
-         device = x.device
-         (N, F), E = x.size(), edge_index.size(1)
+         if self.is_hetero:
+             # Initialize dictionaries for heterogeneous masks
+             self.node_mask = {}
+             self.hard_node_mask = {}
+             self.edge_mask = {}
+             self.hard_edge_mask = {}
+
+             # Initialize node masks for each node type
+             for node_type, features in x.items():
+                 device = features.device
+                 N, F = features.size()
+                 self._initialize_node_mask(node_mask_type, node_type, N, F,
+                                            device)
+
+             # Initialize edge masks for each edge type
+             for edge_type, indices in edge_index.items():
+                 device = indices.device
+                 E = indices.size(1)
+                 N = max(indices.max().item() + 1,
+                         max(feat.size(0) for feat in x.values()))
+                 self._initialize_edge_mask(edge_mask_type, edge_type, E, N,
+                                            device)
+         else:
+             # Initialize masks for homogeneous graph
+             device = x.device
+             (N, F), E = x.size(), edge_index.size(1)
+
+             # Initialize homogeneous node and edge masks
+             self._initialize_homogeneous_masks(node_mask_type, edge_mask_type,
+                                                N, F, E, device)
+
+     def _initialize_node_mask(
+         self,
+         node_mask_type,
+         node_type,
+         N,
+         F,
+         device,
+     ) -> None:
+         """Initialize node mask for a specific node type."""
+         std = 0.1
+         if node_mask_type is None:
+             self.node_mask[node_type] = None
+             self.hard_node_mask[node_type] = None
+         elif node_mask_type == MaskType.object:
+             self.node_mask[node_type] = Parameter(
+                 torch.randn(N, 1, device=device) * std)
+             self.hard_node_mask[node_type] = None
+         elif node_mask_type == MaskType.attributes:
+             self.node_mask[node_type] = Parameter(
+                 torch.randn(N, F, device=device) * std)
+             self.hard_node_mask[node_type] = None
+         elif node_mask_type == MaskType.common_attributes:
+             self.node_mask[node_type] = Parameter(
+                 torch.randn(1, F, device=device) * std)
+             self.hard_node_mask[node_type] = None
+         else:
+             raise ValueError(f"Invalid node mask type: {node_mask_type}")
+
+     def _initialize_edge_mask(self, edge_mask_type, edge_type, E, N, device):
+         """Initialize edge mask for a specific edge type."""
+         if edge_mask_type is None:
+             self.edge_mask[edge_type] = None
+             self.hard_edge_mask[edge_type] = None
+         elif edge_mask_type == MaskType.object:
+             std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
+             self.edge_mask[edge_type] = Parameter(
+                 torch.randn(E, device=device) * std)
+             self.hard_edge_mask[edge_type] = None
+         else:
+             raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
  
+     def _initialize_homogeneous_masks(self, node_mask_type, edge_mask_type, N,
+                                       F, E, device):
+         """Initialize masks for homogeneous graph."""
+         # Initialize node mask
          std = 0.1
          if node_mask_type is None:
              self.node_mask = None
@@ -174,43 +408,145 @@ class GNNExplainer(ExplainerAlgorithm):
          elif node_mask_type == MaskType.common_attributes:
              self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
          else:
-             assert False
+             raise ValueError(f"Invalid node mask type: {node_mask_type}")
  
+         # Initialize edge mask
          if edge_mask_type is None:
              self.edge_mask = None
          elif edge_mask_type == MaskType.object:
              std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
              self.edge_mask = Parameter(torch.randn(E, device=device) * std)
          else:
-             assert False
+             raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
+
+     def _collect_gradients(self) -> None:
+         if self.is_hetero:
+             self._collect_hetero_gradients()
+         else:
+             self._collect_homo_gradients()
+
+     def _collect_hetero_gradients(self):
+         """Collect gradients for heterogeneous graph."""
+         for node_type, mask in self.node_mask.items():
+             if mask is not None:
+                 if mask.grad is None:
+                     raise ValueError(
+                         f"Could not compute gradients for node masks of type "
+                         f"'{node_type}'. Please make sure that node masks are "
+                         f"used inside the model or disable it via "
+                         f"`node_mask_type=None`.")
+
+                 self.hard_node_mask[node_type] = mask.grad != 0.0
+
+         for edge_type, mask in self.edge_mask.items():
+             if mask is not None:
+                 if mask.grad is None:
+                     raise ValueError(
+                         f"Could not compute gradients for edge masks of type "
+                         f"'{edge_type}'. Please make sure that edge masks are "
+                         f"used inside the model or disable it via "
+                         f"`edge_mask_type=None`.")
+                 self.hard_edge_mask[edge_type] = mask.grad != 0.0
+
+     def _collect_homo_gradients(self):
+         """Collect gradients for homogeneous graph."""
+         if self.node_mask is not None:
+             if self.node_mask.grad is None:
+                 raise ValueError("Could not compute gradients for node "
+                                  "features. Please make sure that node "
+                                  "features are used inside the model or "
+                                  "disable it via `node_mask_type=None`.")
+             self.hard_node_mask = self.node_mask.grad != 0.0
+
+         if self.edge_mask is not None:
+             if self.edge_mask.grad is None:
+                 raise ValueError("Could not compute gradients for edges. "
+                                  "Please make sure that edges are used "
+                                  "via message passing inside the model or "
+                                  "disable it via `edge_mask_type=None`.")
+             self.hard_edge_mask = self.edge_mask.grad != 0.0
  
      def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
+         # Calculate base loss based on model configuration
+         loss = self._calculate_base_loss(y_hat, y)
+
+         # Apply regularization based on graph type
+         if self.is_hetero:
+             # Apply regularization for heterogeneous graph
+             loss = self._apply_hetero_regularization(loss)
+         else:
+             # Apply regularization for homogeneous graph
+             loss = self._apply_homo_regularization(loss)
+
+         return loss
+
+     def _calculate_base_loss(self, y_hat, y):
+         """Calculate base loss based on model configuration."""
          if self.model_config.mode == ModelMode.binary_classification:
-             loss = self._loss_binary_classification(y_hat, y)
+             return self._loss_binary_classification(y_hat, y)
          elif self.model_config.mode == ModelMode.multiclass_classification:
-             loss = self._loss_multiclass_classification(y_hat, y)
+             return self._loss_multiclass_classification(y_hat, y)
          elif self.model_config.mode == ModelMode.regression:
-             loss = self._loss_regression(y_hat, y)
+             return self._loss_regression(y_hat, y)
          else:
-             assert False
+             raise ValueError(f"Invalid model mode: {self.model_config.mode}")
+
+     def _apply_hetero_regularization(self, loss):
+         """Apply regularization for heterogeneous graph."""
+         # Apply regularization for each edge type
+         for edge_type, mask in self.edge_mask.items():
+             if (mask is not None
+                     and self.hard_edge_mask[edge_type] is not None):
+                 loss = self._add_mask_regularization(
+                     loss, mask, self.hard_edge_mask[edge_type],
+                     self.coeffs['edge_size'], self.coeffs['edge_reduction'],
+                     self.coeffs['edge_ent'])
+
+         # Apply regularization for each node type
+         for node_type, mask in self.node_mask.items():
+             if (mask is not None
+                     and self.hard_node_mask[node_type] is not None):
+                 loss = self._add_mask_regularization(
+                     loss, mask, self.hard_node_mask[node_type],
+                     self.coeffs['node_feat_size'],
+                     self.coeffs['node_feat_reduction'],
+                     self.coeffs['node_feat_ent'])
  
+         return loss
+
+     def _apply_homo_regularization(self, loss):
+         """Apply regularization for homogeneous graph."""
+         # Apply regularization for edge mask
          if self.hard_edge_mask is not None:
              assert self.edge_mask is not None
-             m = self.edge_mask[self.hard_edge_mask].sigmoid()
-             edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
-             loss = loss + self.coeffs['edge_size'] * edge_reduce(m)
-             ent = -m * torch.log(m + self.coeffs['EPS']) - (
-                 1 - m) * torch.log(1 - m + self.coeffs['EPS'])
-             loss = loss + self.coeffs['edge_ent'] * ent.mean()
+             loss = self._add_mask_regularization(loss, self.edge_mask,
+                                                  self.hard_edge_mask,
+                                                  self.coeffs['edge_size'],
+                                                  self.coeffs['edge_reduction'],
+                                                  self.coeffs['edge_ent'])
  
+         # Apply regularization for node mask
          if self.hard_node_mask is not None:
              assert self.node_mask is not None
-             m = self.node_mask[self.hard_node_mask].sigmoid()
-             node_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
-             loss = loss + self.coeffs['node_feat_size'] * node_reduce(m)
-             ent = -m * torch.log(m + self.coeffs['EPS']) - (
-                 1 - m) * torch.log(1 - m + self.coeffs['EPS'])
-             loss = loss + self.coeffs['node_feat_ent'] * ent.mean()
+             loss = self._add_mask_regularization(
+                 loss, self.node_mask, self.hard_node_mask,
+                 self.coeffs['node_feat_size'],
+                 self.coeffs['node_feat_reduction'],
+                 self.coeffs['node_feat_ent'])
+
+         return loss
+
+     def _add_mask_regularization(self, loss, mask, hard_mask, size_coeff,
+                                  reduction_name, ent_coeff):
+         """Add size and entropy regularization for a mask."""
+         m = mask[hard_mask].sigmoid()
+         reduce_fn = getattr(torch, reduction_name)
+         # Add size regularization
+         loss = loss + size_coeff * reduce_fn(m)
+         # Add entropy regularization
+         ent = -m * torch.log(m + self.coeffs['EPS']) - (
+             1 - m) * torch.log(1 - m + self.coeffs['EPS'])
+         loss = loss + ent_coeff * ent.mean()
  
          return loss
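
Taken together, the hunks above let `GNNExplainer` consume `x_dict`/`edge_index_dict` inputs and emit a `HeteroExplanation` with per-type masks. A hedged usage sketch via the public `Explainer` wrapper; the `model`, `data` (a `HeteroData` object), and target index are placeholders, and the dict-based call pattern is inferred from the new overloads:

```
from torch_geometric.explain import Explainer, GNNExplainer

explainer = Explainer(
    model=model,  # e.g. a homogeneous GNN converted with `to_hetero`
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(mode='multiclass_classification', task_level='node',
                      return_type='log_probs'),
)
# Passing dicts now routes through the heterogeneous code path and
# returns per-node-type and per-edge-type masks:
explanation = explainer(data.x_dict, data.edge_index_dict, index=0)
```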
 
@@ -670,6 +670,9 @@ class LinkPredDiversity(_LinkPredMetric):
      def __init__(self, k: int, category: Tensor) -> None:
          super().__init__(k)
  
+         self.accum: Tensor
+         self.total: Tensor
+
          if WITH_TORCHMETRICS:
              self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
              self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
@@ -736,11 +739,14 @@ class LinkPredPersonalization(_LinkPredMetric):
          self.max_src_nodes = max_src_nodes
          self.batch_size = batch_size
  
+         self.preds: List[Tensor]
+         self.total: Tensor
+
          if WITH_TORCHMETRICS:
              self.add_state('preds', default=[], dist_reduce_fx='cat')
              self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
          else:
-             self.preds: List[Tensor] = []
+             self.preds = []
              self.register_buffer('total', torch.tensor(0), persistent=False)
  
      def update(
@@ -826,6 +832,9 @@ class LinkPredAveragePopularity(_LinkPredMetric):
      def __init__(self, k: int, popularity: Tensor) -> None:
          super().__init__(k)
  
+         self.accum: Tensor
+         self.total: Tensor
+
          if WITH_TORCHMETRICS:
              self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
              self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
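
The bare `self.accum: Tensor` annotations added in these three metrics declare attribute types that are otherwise created dynamically, by `torchmetrics`' `add_state` on one branch and `register_buffer` on the other, so static checkers accept later reads. The pattern in isolation, as a sketch:

```
import torch

class Accumulator(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Declare the types up front; `register_buffer` (or `add_state`)
        # creates the attributes dynamically, invisible to type checkers.
        self.accum: torch.Tensor
        self.total: torch.Tensor
        self.register_buffer('accum', torch.tensor(0.), persistent=False)
        self.register_buffer('total', torch.tensor(0), persistent=False)
```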
@@ -379,7 +379,7 @@ class GPSE(torch.nn.Module):
  
      .. code-block:: python
  
-         from torch_geometric.nn import GPSE, GPSENodeEncoder,
+         from torch_geometric.nn import GPSE, GPSENodeEncoder
          from torch_geometric.transforms import AddGPSE
          from torch_geometric.nn.models.gpse import precompute_GPSE
  
@@ -417,13 +417,11 @@ class GPSE(torch.nn.Module):
  
          encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
                                    expand_x=False)
-         gnn = GNN(dim_in=128, dim_out=128, num_layers=4)
+         gnn = GNN(...)
  
          for batch in loader:
-             batch = encoder(batch)
-             batch = gnn(batch)
-             # Do something with the batch, which now includes 128-dimensional
-             # node representations
+             x = encoder(batch.x, batch.pestat_GPSE)
+             out = gnn(x, batch.edge_index)
  
  
      Args:
@@ -571,8 +569,7 @@ class GPSE(torch.nn.Module):
          self.reset_parameters()
  
      def reset_parameters(self):
-         from torch_geometric.graphgym.init import init_weights
-         self.apply(init_weights)
+         pass
  
      @classmethod
      def from_pretrained(cls, name: str, root: str = 'GPSE_pretrained'):
@@ -608,6 +605,7 @@ class GPSE(torch.nn.Module):
          return model
  
      def forward(self, batch):
+         batch = batch.clone()
          for module in self.children():
              batch = module(batch)
          return batch
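
The added `clone()` keeps `GPSE.forward` from writing its child modules' intermediate results back into the caller's batch. `Data.clone()` returns an independent copy, as this small sketch illustrates:

```
import torch
from torch_geometric.data import Data

data = Data(x=torch.ones(2, 2))
copy = data.clone()    # independent copy of the storages
copy.x = copy.x * 0.0  # leaves the caller's `data.x` untouched
assert data.x.sum() == 4
```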
@@ -635,13 +633,11 @@ class GPSENodeEncoder(torch.nn.Module):
  
          encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
                                    expand_x=False)
-         gnn = GNN(dim_in=128, dim_out=128, num_layers=4)
+         gnn = GNN(...)
  
          for batch in loader:
-             batch = encoder(batch)
-             batch = gnn(batch)
-             # Do something with the batch, which now includes 128-dimensional
-             # node representations
+             x = encoder(batch.x, batch.pestat_GPSE)
+             batch = gnn(x, batch.edge_index)
  
      Args:
          dim_emb (int): Size of final node embedding.
@@ -705,28 +701,17 @@ class GPSENodeEncoder(torch.nn.Module):
              raise ValueError(f"{self.__class__.__name__}: Does not support "
                               f"'{model_type}' encoder model.")
  
-     def forward(self, batch):
-         if not hasattr(batch, 'pestat_GPSE'):
-             raise ValueError('Precomputed "pestat_GPSE" variable is required '
-                              'for GNNNodeEncoder; either run '
-                              '`precompute_GPSE(gpse_model, dataset)` on your '
-                              'dataset or add `AddGPSE(gpse_model)` as a (pre) '
-                              'transform.')
-
-         pos_enc = batch.pestat_GPSE
-
+     def forward(self, x, pos_enc):
          pos_enc = self.dropout_be(pos_enc)
          pos_enc = self.raw_norm(pos_enc) if self.raw_norm else pos_enc
          pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe
          pos_enc = self.dropout_ae(pos_enc)
  
          # Expand node features if needed
-         h = self.linear_x(batch.x) if self.expand_x else batch.x
+         h = self.linear_x(x) if self.expand_x else x
  
          # Concatenate final PEs to input embedding
-         batch.x = torch.cat((h, pos_enc), 1)
-
-         return batch
+         return torch.cat((h, pos_enc), 1)
  
  
  @torch.no_grad()
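
With the new signature, `GPSENodeEncoder` is a plain tensor-in/tensor-out module instead of a batch-mutating one. A minimal call sketch, assuming `pestat_GPSE` was attached beforehand via `precompute_GPSE` or the `AddGPSE` transform and that `encoder`/`gnn` are configured as in the docstrings above:

```
for batch in loader:
    # Returns the concatenated [features | positional encodings] tensor
    # rather than mutating `batch.x` in place.
    x = encoder(batch.x, batch.pestat_GPSE)
    out = gnn(x, batch.edge_index)
```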
@@ -94,7 +94,7 @@ class RootedSubgraph(BaseTransform, ABC):
          arange = torch.arange(n_id.size(0), device=data.edge_index.device)
          node_map = data.edge_index.new_ones(num_nodes, num_nodes)
          node_map[n_sub_batch, n_id] = arange
-         sub_edge_index += (arange * data.num_nodes)[e_sub_batch]
+         sub_edge_index += (arange * num_nodes)[e_sub_batch]
          sub_edge_index = node_map.view(-1)[sub_edge_index]
  
          return sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch
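
The fix matters because `node_map.view(-1)` flattens a table whose second dimension is the local `num_nodes`, so the row offset must use that same value; whenever it differs from `data.num_nodes`, the old stride pointed at the wrong rows. The row-major indexing identity the code relies on:

```
import torch

# Entry (b, n) of a (B, C) table lands at index b * C + n after view(-1);
# the offset stride must equal the flattened table's column count C.
C = 5
table = torch.arange(3 * C).view(3, C)
b, n = 2, 4
assert table.view(-1)[b * C + n] == table[b, n]
```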
@@ -109,10 +109,9 @@ def negative_sampling(
      idx = idx.to('cpu')
      for _ in range(3):  # Number of tries to sample negative indices.
          rnd = sample(population, sample_size, device='cpu')
-         mask = np.isin(rnd.numpy(), idx.numpy())  # type: ignore
+         mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool()
          if neg_idx is not None:
-             mask |= np.isin(rnd, neg_idx.to('cpu'))
-             mask = torch.from_numpy(mask).to(torch.bool)
+             mask |= torch.from_numpy(np.isin(rnd, neg_idx.cpu())).bool()
          rnd = rnd[~mask].to(edge_index.device)
          neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd])
          if neg_idx.numel() >= num_neg_samples:
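
After the rewrite, both branches produce boolean `torch.Tensor` masks immediately, so `|=` and `~` never mix NumPy arrays with tensors. The core conversion in isolation:

```
import numpy as np
import torch

rnd = torch.tensor([3, 7, 1, 9])
existing = torch.tensor([1, 2, 3])

# np.isin operates on NumPy arrays; wrap the result back into a boolean
# torch mask so the downstream `|=` and `~` operations stay within torch.
mask = torch.from_numpy(np.isin(rnd.numpy(), existing.numpy())).bool()
print(rnd[~mask])  # tensor([7, 9])
```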