pyg-nightly 2.7.0.dev20250406__py3-none-any.whl → 2.7.0.dev20250408__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250406.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/METADATA +10 -22
- {pyg_nightly-2.7.0.dev20250406.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/RECORD +11 -11
- torch_geometric/__init__.py +1 -1
- torch_geometric/distributed/rpc.py +1 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +403 -67
- torch_geometric/metrics/link_pred.py +10 -1
- torch_geometric/nn/models/gpse.py +12 -27
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/utils/_negative_sampling.py +2 -3
- {pyg_nightly-2.7.0.dev20250406.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250406.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250406.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.dev20250406
|
3
|
+
Version: 2.7.0.dev20250408
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -413,18 +413,6 @@ These approaches have been implemented in PyG, and can benefit from the above GN
|
|
413
413
|
|
414
414
|
PyG is available for Python 3.9 to Python 3.12.
|
415
415
|
|
416
|
-
### Anaconda
|
417
|
-
|
418
|
-
You can now install PyG via [Anaconda](https://anaconda.org/pyg/pyg) for all major OS/PyTorch/CUDA combinations 🤗
|
419
|
-
If you have not yet installed PyTorch, install it via `conda` as described in the [official PyTorch documentation](https://pytorch.org/get-started/locally/).
|
420
|
-
Given that you have PyTorch installed (`>=1.8.0`), simply run
|
421
|
-
|
422
|
-
```
|
423
|
-
conda install pyg -c pyg
|
424
|
-
```
|
425
|
-
|
426
|
-
### PyPi
|
427
|
-
|
428
416
|
From **PyG 2.3** onwards, you can install and use PyG **without any external library** required except for PyTorch.
|
429
417
|
For this, simply run
|
430
418
|
|
@@ -448,28 +436,28 @@ We recommend to start with a minimal installation, and install additional depend
|
|
448
436
|
|
449
437
|
For ease of installation of these extensions, we provide `pip` wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl).
|
450
438
|
|
451
|
-
#### PyTorch 2.5
|
439
|
+
#### PyTorch 2.6
|
452
440
|
|
453
|
-
To install the binaries for PyTorch 2.5.0/2.5.1, simply run
|
441
|
+
To install the binaries for PyTorch 2.6.0, simply run
|
454
442
|
|
455
443
|
```
|
456
|
-
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.
|
444
|
+
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+${CUDA}.html
|
457
445
|
```
|
458
446
|
|
459
|
-
where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
|
447
|
+
where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu124`, or `cu126` depending on your PyTorch installation.
|
460
448
|
|
461
|
-
|             | `cpu` | `cu118` | `cu121` | `cu124` |
|
449
|
+
| | `cpu` | `cu118` | `cu124` | `cu126` |
|
462
450
|
| ----------- | ----- | ------- | ------- | ------- |
|
463
451
|
| **Linux** | ✅ | ✅ | ✅ | ✅ |
|
464
452
|
| **Windows** | ✅ | ✅ | ✅ | ✅ |
|
465
453
|
| **macOS** | ✅ | | | |
|
466
454
|
|
467
|
-
#### PyTorch 2.4
|
455
|
+
#### PyTorch 2.5
|
468
456
|
|
469
|
-
To install the binaries for PyTorch 2.
|
457
|
+
To install the binaries for PyTorch 2.5.0/2.5.1, simply run
|
470
458
|
|
471
459
|
```
|
472
|
-
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.
|
460
|
+
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.5.0+${CUDA}.html
|
473
461
|
```
|
474
462
|
|
475
463
|
where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` depending on your PyTorch installation.
|
@@ -480,7 +468,7 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124`
|
|
480
468
|
| **Windows** | ✅ | ✅ | ✅ | ✅ |
|
481
469
|
| **macOS** | ✅ | | | |
|
482
470
|
|
483
|
-
**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, and PyTorch 2.3.0/2.3.1 (following the same procedure).
|
471
|
+
**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, PyTorch 2.3.0/2.3.1, and PyTorch 2.4.0/2.4.1 (following the same procedure).
|
484
472
|
**For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source.
|
485
473
|
You can look up the latest supported version number [here](https://data.pyg.org/whl).
|
486
474
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=rucRaUNeMS0eQLE0b6LFcLimKD5JwNS4_nYsmtC3BMU,1978
|
2
2
|
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -188,7 +188,7 @@ torch_geometric/distributed/event_loop.py,sha256=wr3iwMYEWOGkBlvC5huD2k5YxisaGE9
|
|
188
188
|
torch_geometric/distributed/local_feature_store.py,sha256=CLW37RN0ouDufEs2tY9d2nLLvpxubiT6zgW3LIHAU8k,19058
|
189
189
|
torch_geometric/distributed/local_graph_store.py,sha256=wNHXSS824Kk2HynbtWFXx-W4pl97UUBv6qFHAVqO3W4,8445
|
190
190
|
torch_geometric/distributed/partition.py,sha256=X0BleuY0ROlUtVXKvvz814pJwglZQ2_6OiMi1K0Hhvo,14731
|
191
|
-
torch_geometric/distributed/rpc.py,sha256=
|
191
|
+
torch_geometric/distributed/rpc.py,sha256=rJqiVR6Vbb2mpyVSC0Y5tPApqP-b1ck1Uq3IQpCsNSw,5737
|
192
192
|
torch_geometric/distributed/utils.py,sha256=FGrr3qw7hx7EQaIjjqasurloCFJ9q_0jt8jdSIUjBeM,6567
|
193
193
|
torch_geometric/explain/__init__.py,sha256=pRxVB33zsxhED1StRWdHboQWh3e06__g9N298Hzi42Y,359
|
194
194
|
torch_geometric/explain/config.py,sha256=_0j67NAwPwjrWHPncNywCT-oKyMiryJNxufxVN1BFlM,7834
|
@@ -200,7 +200,7 @@ torch_geometric/explain/algorithm/base.py,sha256=wwJcREUFKDLFUDjRa9o4X3DWqQgMvhS
|
|
200
200
|
torch_geometric/explain/algorithm/captum.py,sha256=k6hNgC5Kn9lVirOYVJzej8-hRuf5C2mPFUXFLd2wWsY,12857
|
201
201
|
torch_geometric/explain/algorithm/captum_explainer.py,sha256=oz-c40hvdzii4_chEQPHzQo_dFjHr9HLuJhDLsqRIVU,7346
|
202
202
|
torch_geometric/explain/algorithm/dummy_explainer.py,sha256=jvcVQmfngmUWgoKa5p7CXzju2HM5D5DfieJhZW3gbLc,2872
|
203
|
-
torch_geometric/explain/algorithm/gnn_explainer.py,sha256=
|
203
|
+
torch_geometric/explain/algorithm/gnn_explainer.py,sha256=iu45fGWdd4c6wNczWEAT-29HCAz7ncuoaS6cpx-xDJM,24660
|
204
204
|
torch_geometric/explain/algorithm/graphmask_explainer.py,sha256=T2B081dK-JSpaQmutnkQd5xF3JF49_dPZCOgwqIKJDk,21367
|
205
205
|
torch_geometric/explain/algorithm/pg_explainer.py,sha256=zPsl0tT9ISSWK1xP1KKpe1ZjUarhSB736WTtqwcmDIo,10372
|
206
206
|
torch_geometric/explain/algorithm/utils.py,sha256=eh0ARPG41V7piVw5jdMYpV0p7WjTlpehnY-bWqPV_zg,2564
|
@@ -290,7 +290,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_
|
|
290
290
|
torch_geometric/loader/utils.py,sha256=f27mczQ7fEP2HpTsJGJxKS0slPu0j8zTba3jP8ViNck,14901
|
291
291
|
torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
|
292
292
|
torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7z48e94,699
|
293
|
-
torch_geometric/metrics/link_pred.py,sha256=
|
293
|
+
torch_geometric/metrics/link_pred.py,sha256=dtaI39JB-WqE1B-raiElns6xySRwmkbb9izbcyt6xHI,30886
|
294
294
|
torch_geometric/nn/__init__.py,sha256=kQHHHUxFDht2ztD-XFQuv98TvC8MdodaFsIjAvltJBw,874
|
295
295
|
torch_geometric/nn/data_parallel.py,sha256=lDAxRi83UNuzAQSj3eu9K2sQheOIU6wqR5elS6oDs90,4764
|
296
296
|
torch_geometric/nn/encoding.py,sha256=QNjwWczYExZ1wRGBmpuqYbn6tB7NC4BU-DEgzjhcZqw,3115
|
@@ -441,7 +441,7 @@ torch_geometric/nn/models/g_retriever.py,sha256=CdSOasnPiMvq5AjduNTpz-LIZiNp3X0x
|
|
441
441
|
torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
|
442
442
|
torch_geometric/nn/models/glem.py,sha256=sT0XM4klVlci9wduvUoXupATUw9p25uXtaJBrmv3yvs,16431
|
443
443
|
torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
|
444
|
-
torch_geometric/nn/models/gpse.py,sha256=
|
444
|
+
torch_geometric/nn/models/gpse.py,sha256=Fwldw9N3axV--BcSnI4im1sy1r87a5mAXZAXHu_6k2Y,41932
|
445
445
|
torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
|
446
446
|
torch_geometric/nn/models/graph_unet.py,sha256=N8TSmJo8AlbZjjcame0xW_jZvMOirL5ahw6qv5Yjpbs,5586
|
447
447
|
torch_geometric/nn/models/jumping_knowledge.py,sha256=9JR2EoViXKjcDSLb8tjJm-UHfv1mQCJvZAAEsYa0Ocw,5496
|
@@ -571,7 +571,7 @@ torch_geometric/transforms/remove_duplicated_edges.py,sha256=EMX6E1R_gXFoXwBqMqt
|
|
571
571
|
torch_geometric/transforms/remove_isolated_nodes.py,sha256=Q89b73es1tPsAmTdS7tWTIM7JcPUpL37v3EZTAd25Fc,2449
|
572
572
|
torch_geometric/transforms/remove_self_loops.py,sha256=JfoooSnTO2KPXXuC3KWGhLS0tqr4yiXOl3A0sVv2riM,1221
|
573
573
|
torch_geometric/transforms/remove_training_classes.py,sha256=GMCZwI_LYo2ZF29DABZXeuM0Sn2i3twx_V3KBUGu2As,932
|
574
|
-
torch_geometric/transforms/rooted_subgraph.py,sha256=
|
574
|
+
torch_geometric/transforms/rooted_subgraph.py,sha256=2P_2NSRcDNzfeR0zYScCaqZDkAvC0t3Ivr_lINO9abc,6504
|
575
575
|
torch_geometric/transforms/sample_points.py,sha256=UfD44528J7SKH0I2_4ELM1A4hKKHCIDeMV6UzBbQAVU,2280
|
576
576
|
torch_geometric/transforms/sign.py,sha256=bZUvUm9fMGXcYkI1GwPOW5ZC1QFu84vPBKpFnYxz2nA,2329
|
577
577
|
torch_geometric/transforms/spherical.py,sha256=nU7h4IFw69JqUwRqaweVEBegHZWPOHDoTYYVRMzIZ7U,2320
|
@@ -592,7 +592,7 @@ torch_geometric/utils/_grid.py,sha256=1coutST2TMV9TSQcmpXze0GIK9odzZ9wBtbKs6u26D
|
|
592
592
|
torch_geometric/utils/_homophily.py,sha256=1nXxGUATFPB3icEGpvEWUiuYbjU9gDGtlWpuLbtWhJk,5090
|
593
593
|
torch_geometric/utils/_index_sort.py,sha256=FTJacmOsqgsyof7MJFHlVVdXhHOjR0j7siTb0UZ-YT0,1283
|
594
594
|
torch_geometric/utils/_lexsort.py,sha256=chMEJJRXqfE6-K4vrVszdr3c338EhMZyi0Q9IEJD3p0,1403
|
595
|
-
torch_geometric/utils/_negative_sampling.py,sha256=
|
595
|
+
torch_geometric/utils/_negative_sampling.py,sha256=jxsmpryeoTT8qQrvIH11MgyhgoWzvqPGRAcVyU85VCU,15494
|
596
596
|
torch_geometric/utils/_normalize_edge_index.py,sha256=H6DY-Dzi1Psr3igG_nb0U3ZPNZz-BBDntO2iuA8FtzA,1682
|
597
597
|
torch_geometric/utils/_normalized_cut.py,sha256=uwVJkl-Q0tpY-w0nvcHajcQYcqFh1oDOf55XELdjJBU,1167
|
598
598
|
torch_geometric/utils/_one_hot.py,sha256=vXC7l7zudYRZIwWv6mT-Biuk2zKELyqteJXLynPocPM,1404
|
@@ -636,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
636
636
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
637
637
|
torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
|
638
638
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
639
|
-
pyg_nightly-2.7.0.
|
640
|
-
pyg_nightly-2.7.0.
|
641
|
-
pyg_nightly-2.7.0.
|
642
|
-
pyg_nightly-2.7.0.
|
639
|
+
pyg_nightly-2.7.0.dev20250408.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
640
|
+
pyg_nightly-2.7.0.dev20250408.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
641
|
+
pyg_nightly-2.7.0.dev20250408.dist-info/METADATA,sha256=OX81Kse1GucsZCReO0E00CJl4rrtT9HhHEkNEbJU6KY,62652
|
642
|
+
pyg_nightly-2.7.0.dev20250408.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
33
33
|
|
34
|
-
__version__ = '2.7.0.
|
34
|
+
__version__ = '2.7.0.dev20250408'
|
35
35
|
|
36
36
|
__all__ = [
|
37
37
|
'Index',
|
@@ -144,7 +144,7 @@ _rpc_call_pool: Dict[int, RPCCallBase] = {}
|
|
144
144
|
@rpc_require_initialized
|
145
145
|
def rpc_register(call: RPCCallBase) -> int:
|
146
146
|
r"""Registers a call for RPC requests."""
|
147
|
-
global _rpc_call_id
|
147
|
+
global _rpc_call_id
|
148
148
|
|
149
149
|
with _rpc_call_lock:
|
150
150
|
call_id = _rpc_call_id
|
@@ -1,14 +1,24 @@
|
|
1
1
|
from math import sqrt
|
2
|
-
from typing import Optional, Tuple, Union
|
2
|
+
from typing import Dict, Optional, Tuple, Union, overload
|
3
3
|
|
4
4
|
import torch
|
5
5
|
from torch import Tensor
|
6
6
|
from torch.nn.parameter import Parameter
|
7
7
|
|
8
|
-
from torch_geometric.explain import
|
8
|
+
from torch_geometric.explain import (
|
9
|
+
ExplainerConfig,
|
10
|
+
Explanation,
|
11
|
+
HeteroExplanation,
|
12
|
+
ModelConfig,
|
13
|
+
)
|
9
14
|
from torch_geometric.explain.algorithm import ExplainerAlgorithm
|
10
|
-
from torch_geometric.explain.algorithm.utils import
|
15
|
+
from torch_geometric.explain.algorithm.utils import (
|
16
|
+
clear_masks,
|
17
|
+
set_hetero_masks,
|
18
|
+
set_masks,
|
19
|
+
)
|
11
20
|
from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel
|
21
|
+
from torch_geometric.typing import EdgeType, NodeType
|
12
22
|
|
13
23
|
|
14
24
|
class GNNExplainer(ExplainerAlgorithm):
|
@@ -69,7 +79,9 @@ class GNNExplainer(ExplainerAlgorithm):
|
|
69
79
|
|
70
80
|
self.node_mask = self.hard_node_mask = None
|
71
81
|
self.edge_mask = self.hard_edge_mask = None
|
82
|
+
self.is_hetero = False
|
72
83
|
|
84
|
+
@overload
|
73
85
|
def forward(
|
74
86
|
self,
|
75
87
|
model: torch.nn.Module,
|
@@ -80,30 +92,87 @@ class GNNExplainer(ExplainerAlgorithm):
|
|
80
92
|
index: Optional[Union[int, Tensor]] = None,
|
81
93
|
**kwargs,
|
82
94
|
) -> Explanation:
|
83
|
-
|
84
|
-
raise ValueError(f"Heterogeneous graphs not yet supported in "
|
85
|
-
f"'{self.__class__.__name__}'")
|
95
|
+
...
|
86
96
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
97
|
+
@overload
|
98
|
+
def forward(
|
99
|
+
self,
|
100
|
+
model: torch.nn.Module,
|
101
|
+
x: Dict[NodeType, Tensor],
|
102
|
+
edge_index: Dict[EdgeType, Tensor],
|
103
|
+
*,
|
104
|
+
target: Tensor,
|
105
|
+
index: Optional[Union[int, Tensor]] = None,
|
106
|
+
**kwargs,
|
107
|
+
) -> HeteroExplanation:
|
108
|
+
...
|
99
109
|
|
110
|
+
def forward(
|
111
|
+
self,
|
112
|
+
model: torch.nn.Module,
|
113
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
114
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
115
|
+
*,
|
116
|
+
target: Tensor,
|
117
|
+
index: Optional[Union[int, Tensor]] = None,
|
118
|
+
**kwargs,
|
119
|
+
) -> Union[Explanation, HeteroExplanation]:
|
120
|
+
self.is_hetero = isinstance(x, dict)
|
121
|
+
self._train(model, x, edge_index, target=target, index=index, **kwargs)
|
122
|
+
explanation = self._create_explanation()
|
100
123
|
self._clean_model(model)
|
124
|
+
return explanation
|
125
|
+
|
126
|
+
def _create_explanation(self) -> Union[Explanation, HeteroExplanation]:
|
127
|
+
"""Create an explanation object from the current masks."""
|
128
|
+
if self.is_hetero:
|
129
|
+
# For heterogeneous graphs, process each type separately
|
130
|
+
node_mask_dict = {}
|
131
|
+
edge_mask_dict = {}
|
132
|
+
|
133
|
+
for node_type, mask in self.node_mask.items():
|
134
|
+
if mask is not None:
|
135
|
+
node_mask_dict[node_type] = self._post_process_mask(
|
136
|
+
mask,
|
137
|
+
self.hard_node_mask[node_type],
|
138
|
+
apply_sigmoid=True,
|
139
|
+
)
|
140
|
+
|
141
|
+
for edge_type, mask in self.edge_mask.items():
|
142
|
+
if mask is not None:
|
143
|
+
edge_mask_dict[edge_type] = self._post_process_mask(
|
144
|
+
mask,
|
145
|
+
self.hard_edge_mask[edge_type],
|
146
|
+
apply_sigmoid=True,
|
147
|
+
)
|
148
|
+
|
149
|
+
# Create heterogeneous explanation
|
150
|
+
explanation = HeteroExplanation()
|
151
|
+
explanation.set_value_dict('node_mask', node_mask_dict)
|
152
|
+
explanation.set_value_dict('edge_mask', edge_mask_dict)
|
101
153
|
|
102
|
-
|
154
|
+
else:
|
155
|
+
# For homogeneous graphs, process single masks
|
156
|
+
node_mask = self._post_process_mask(
|
157
|
+
self.node_mask,
|
158
|
+
self.hard_node_mask,
|
159
|
+
apply_sigmoid=True,
|
160
|
+
)
|
161
|
+
edge_mask = self._post_process_mask(
|
162
|
+
self.edge_mask,
|
163
|
+
self.hard_edge_mask,
|
164
|
+
apply_sigmoid=True,
|
165
|
+
)
|
166
|
+
|
167
|
+
# Create homogeneous explanation
|
168
|
+
explanation = Explanation(node_mask=node_mask, edge_mask=edge_mask)
|
169
|
+
|
170
|
+
return explanation
|
103
171
|
|
104
172
|
def supports(self) -> bool:
|
105
173
|
return True
|
106
174
|
|
175
|
+
@overload
|
107
176
|
def _train(
|
108
177
|
self,
|
109
178
|
model: torch.nn.Module,
|
@@ -113,57 +182,222 @@ class GNNExplainer(ExplainerAlgorithm):
|
|
113
182
|
target: Tensor,
|
114
183
|
index: Optional[Union[int, Tensor]] = None,
|
115
184
|
**kwargs,
|
116
|
-
):
|
185
|
+
) -> None:
|
186
|
+
...
|
187
|
+
|
188
|
+
@overload
|
189
|
+
def _train(
|
190
|
+
self,
|
191
|
+
model: torch.nn.Module,
|
192
|
+
x: Dict[NodeType, Tensor],
|
193
|
+
edge_index: Dict[EdgeType, Tensor],
|
194
|
+
*,
|
195
|
+
target: Tensor,
|
196
|
+
index: Optional[Union[int, Tensor]] = None,
|
197
|
+
**kwargs,
|
198
|
+
) -> None:
|
199
|
+
...
|
200
|
+
|
201
|
+
def _train(
|
202
|
+
self,
|
203
|
+
model: torch.nn.Module,
|
204
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
205
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
206
|
+
*,
|
207
|
+
target: Tensor,
|
208
|
+
index: Optional[Union[int, Tensor]] = None,
|
209
|
+
**kwargs,
|
210
|
+
) -> None:
|
211
|
+
# Initialize masks based on input type
|
117
212
|
self._initialize_masks(x, edge_index)
|
118
213
|
|
119
|
-
parameters
|
120
|
-
|
121
|
-
parameters.append(self.node_mask)
|
122
|
-
if self.edge_mask is not None:
|
123
|
-
set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
|
124
|
-
parameters.append(self.edge_mask)
|
214
|
+
# Collect parameters for optimization
|
215
|
+
parameters = self._collect_parameters(model, edge_index)
|
125
216
|
|
217
|
+
# Create optimizer
|
126
218
|
optimizer = torch.optim.Adam(parameters, lr=self.lr)
|
127
219
|
|
220
|
+
# Training loop
|
128
221
|
for i in range(self.epochs):
|
129
222
|
optimizer.zero_grad()
|
130
223
|
|
131
|
-
|
132
|
-
y_hat
|
224
|
+
# Forward pass with masked inputs
|
225
|
+
y_hat = self._forward_with_masks(model, x, edge_index, **kwargs)
|
226
|
+
y = target
|
133
227
|
|
228
|
+
# Handle index if provided
|
134
229
|
if index is not None:
|
135
230
|
y_hat, y = y_hat[index], y[index]
|
136
231
|
|
232
|
+
# Calculate loss
|
137
233
|
loss = self._loss(y_hat, y)
|
138
234
|
|
235
|
+
# Backward pass
|
139
236
|
loss.backward()
|
140
237
|
optimizer.step()
|
141
238
|
|
142
|
-
# In the first iteration,
|
143
|
-
#
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
239
|
+
# In the first iteration, collect gradients to identify important
|
240
|
+
# nodes/edges
|
241
|
+
if i == 0:
|
242
|
+
self._collect_gradients()
|
243
|
+
|
244
|
+
def _collect_parameters(self, model, edge_index):
|
245
|
+
"""Collect parameters for optimization."""
|
246
|
+
parameters = []
|
247
|
+
|
248
|
+
if self.is_hetero:
|
249
|
+
# For heterogeneous graphs, collect parameters from all types
|
250
|
+
for mask in self.node_mask.values():
|
251
|
+
if mask is not None:
|
252
|
+
parameters.append(mask)
|
253
|
+
if any(v is not None for v in self.edge_mask.values()):
|
254
|
+
set_hetero_masks(model, self.edge_mask, edge_index)
|
255
|
+
for mask in self.edge_mask.values():
|
256
|
+
if mask is not None:
|
257
|
+
parameters.append(mask)
|
258
|
+
else:
|
259
|
+
# For homogeneous graphs, collect single parameters
|
260
|
+
if self.node_mask is not None:
|
261
|
+
parameters.append(self.node_mask)
|
262
|
+
if self.edge_mask is not None:
|
263
|
+
set_masks(model, self.edge_mask, edge_index,
|
264
|
+
apply_sigmoid=True)
|
265
|
+
parameters.append(self.edge_mask)
|
266
|
+
|
267
|
+
return parameters
|
268
|
+
|
269
|
+
@overload
|
270
|
+
def _forward_with_masks(
|
271
|
+
self,
|
272
|
+
model: torch.nn.Module,
|
273
|
+
x: Tensor,
|
274
|
+
edge_index: Tensor,
|
275
|
+
**kwargs,
|
276
|
+
) -> Tensor:
|
277
|
+
...
|
278
|
+
|
279
|
+
@overload
|
280
|
+
def _forward_with_masks(
|
281
|
+
self,
|
282
|
+
model: torch.nn.Module,
|
283
|
+
x: Dict[NodeType, Tensor],
|
284
|
+
edge_index: Dict[EdgeType, Tensor],
|
285
|
+
**kwargs,
|
286
|
+
) -> Tensor:
|
287
|
+
...
|
288
|
+
|
289
|
+
def _forward_with_masks(
|
290
|
+
self,
|
291
|
+
model: torch.nn.Module,
|
292
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
293
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
294
|
+
**kwargs,
|
295
|
+
) -> Tensor:
|
296
|
+
"""Forward pass with masked inputs."""
|
297
|
+
if self.is_hetero:
|
298
|
+
# Apply masks to heterogeneous inputs
|
299
|
+
h_dict = {}
|
300
|
+
for node_type, features in x.items():
|
301
|
+
if node_type in self.node_mask and self.node_mask[
|
302
|
+
node_type] is not None:
|
303
|
+
h_dict[node_type] = features * self.node_mask[
|
304
|
+
node_type].sigmoid()
|
305
|
+
else:
|
306
|
+
h_dict[node_type] = features
|
307
|
+
|
308
|
+
# Forward pass with masked features
|
309
|
+
return model(h_dict, edge_index, **kwargs)
|
310
|
+
else:
|
311
|
+
# Apply mask to homogeneous input
|
312
|
+
h = x if self.node_mask is None else x * self.node_mask.sigmoid()
|
313
|
+
|
314
|
+
# Forward pass with masked features
|
315
|
+
return model(h, edge_index, **kwargs)
|
316
|
+
|
317
|
+
def _initialize_masks(
|
318
|
+
self,
|
319
|
+
x: Union[Tensor, Dict[NodeType, Tensor]],
|
320
|
+
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
|
321
|
+
) -> None:
|
161
322
|
node_mask_type = self.explainer_config.node_mask_type
|
162
323
|
edge_mask_type = self.explainer_config.edge_mask_type
|
163
324
|
|
164
|
-
|
165
|
-
|
325
|
+
if self.is_hetero:
|
326
|
+
# Initialize dictionaries for heterogeneous masks
|
327
|
+
self.node_mask = {}
|
328
|
+
self.hard_node_mask = {}
|
329
|
+
self.edge_mask = {}
|
330
|
+
self.hard_edge_mask = {}
|
331
|
+
|
332
|
+
# Initialize node masks for each node type
|
333
|
+
for node_type, features in x.items():
|
334
|
+
device = features.device
|
335
|
+
N, F = features.size()
|
336
|
+
self._initialize_node_mask(node_mask_type, node_type, N, F,
|
337
|
+
device)
|
338
|
+
|
339
|
+
# Initialize edge masks for each edge type
|
340
|
+
for edge_type, indices in edge_index.items():
|
341
|
+
device = indices.device
|
342
|
+
E = indices.size(1)
|
343
|
+
N = max(indices.max().item() + 1,
|
344
|
+
max(feat.size(0) for feat in x.values()))
|
345
|
+
self._initialize_edge_mask(edge_mask_type, edge_type, E, N,
|
346
|
+
device)
|
347
|
+
else:
|
348
|
+
# Initialize masks for homogeneous graph
|
349
|
+
device = x.device
|
350
|
+
(N, F), E = x.size(), edge_index.size(1)
|
351
|
+
|
352
|
+
# Initialize homogeneous node and edge masks
|
353
|
+
self._initialize_homogeneous_masks(node_mask_type, edge_mask_type,
|
354
|
+
N, F, E, device)
|
355
|
+
|
356
|
+
def _initialize_node_mask(
|
357
|
+
self,
|
358
|
+
node_mask_type,
|
359
|
+
node_type,
|
360
|
+
N,
|
361
|
+
F,
|
362
|
+
device,
|
363
|
+
) -> None:
|
364
|
+
"""Initialize node mask for a specific node type."""
|
365
|
+
std = 0.1
|
366
|
+
if node_mask_type is None:
|
367
|
+
self.node_mask[node_type] = None
|
368
|
+
self.hard_node_mask[node_type] = None
|
369
|
+
elif node_mask_type == MaskType.object:
|
370
|
+
self.node_mask[node_type] = Parameter(
|
371
|
+
torch.randn(N, 1, device=device) * std)
|
372
|
+
self.hard_node_mask[node_type] = None
|
373
|
+
elif node_mask_type == MaskType.attributes:
|
374
|
+
self.node_mask[node_type] = Parameter(
|
375
|
+
torch.randn(N, F, device=device) * std)
|
376
|
+
self.hard_node_mask[node_type] = None
|
377
|
+
elif node_mask_type == MaskType.common_attributes:
|
378
|
+
self.node_mask[node_type] = Parameter(
|
379
|
+
torch.randn(1, F, device=device) * std)
|
380
|
+
self.hard_node_mask[node_type] = None
|
381
|
+
else:
|
382
|
+
raise ValueError(f"Invalid node mask type: {node_mask_type}")
|
383
|
+
|
384
|
+
def _initialize_edge_mask(self, edge_mask_type, edge_type, E, N, device):
|
385
|
+
"""Initialize edge mask for a specific edge type."""
|
386
|
+
if edge_mask_type is None:
|
387
|
+
self.edge_mask[edge_type] = None
|
388
|
+
self.hard_edge_mask[edge_type] = None
|
389
|
+
elif edge_mask_type == MaskType.object:
|
390
|
+
std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
|
391
|
+
self.edge_mask[edge_type] = Parameter(
|
392
|
+
torch.randn(E, device=device) * std)
|
393
|
+
self.hard_edge_mask[edge_type] = None
|
394
|
+
else:
|
395
|
+
raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
|
166
396
|
|
397
|
+
def _initialize_homogeneous_masks(self, node_mask_type, edge_mask_type, N,
|
398
|
+
F, E, device):
|
399
|
+
"""Initialize masks for homogeneous graph."""
|
400
|
+
# Initialize node mask
|
167
401
|
std = 0.1
|
168
402
|
if node_mask_type is None:
|
169
403
|
self.node_mask = None
|
@@ -174,43 +408,145 @@ class GNNExplainer(ExplainerAlgorithm):
|
|
174
408
|
elif node_mask_type == MaskType.common_attributes:
|
175
409
|
self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
|
176
410
|
else:
|
177
|
-
|
411
|
+
raise ValueError(f"Invalid node mask type: {node_mask_type}")
|
178
412
|
|
413
|
+
# Initialize edge mask
|
179
414
|
if edge_mask_type is None:
|
180
415
|
self.edge_mask = None
|
181
416
|
elif edge_mask_type == MaskType.object:
|
182
417
|
std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
|
183
418
|
self.edge_mask = Parameter(torch.randn(E, device=device) * std)
|
184
419
|
else:
|
185
|
-
|
420
|
+
raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
|
421
|
+
|
422
|
+
def _collect_gradients(self) -> None:
|
423
|
+
if self.is_hetero:
|
424
|
+
self._collect_hetero_gradients()
|
425
|
+
else:
|
426
|
+
self._collect_homo_gradients()
|
427
|
+
|
428
|
+
def _collect_hetero_gradients(self):
|
429
|
+
"""Collect gradients for heterogeneous graph."""
|
430
|
+
for node_type, mask in self.node_mask.items():
|
431
|
+
if mask is not None:
|
432
|
+
if mask.grad is None:
|
433
|
+
raise ValueError(
|
434
|
+
f"Could not compute gradients for node masks of type "
|
435
|
+
f"'{node_type}'. Please make sure that node masks are "
|
436
|
+
f"used inside the model or disable it via "
|
437
|
+
f"`node_mask_type=None`.")
|
438
|
+
|
439
|
+
self.hard_node_mask[node_type] = mask.grad != 0.0
|
440
|
+
|
441
|
+
for edge_type, mask in self.edge_mask.items():
|
442
|
+
if mask is not None:
|
443
|
+
if mask.grad is None:
|
444
|
+
raise ValueError(
|
445
|
+
f"Could not compute gradients for edge masks of type "
|
446
|
+
f"'{edge_type}'. Please make sure that edge masks are "
|
447
|
+
f"used inside the model or disable it via "
|
448
|
+
f"`edge_mask_type=None`.")
|
449
|
+
self.hard_edge_mask[edge_type] = mask.grad != 0.0
|
450
|
+
|
451
|
+
def _collect_homo_gradients(self):
|
452
|
+
"""Collect gradients for homogeneous graph."""
|
453
|
+
if self.node_mask is not None:
|
454
|
+
if self.node_mask.grad is None:
|
455
|
+
raise ValueError("Could not compute gradients for node "
|
456
|
+
"features. Please make sure that node "
|
457
|
+
"features are used inside the model or "
|
458
|
+
"disable it via `node_mask_type=None`.")
|
459
|
+
self.hard_node_mask = self.node_mask.grad != 0.0
|
460
|
+
|
461
|
+
if self.edge_mask is not None:
|
462
|
+
if self.edge_mask.grad is None:
|
463
|
+
raise ValueError("Could not compute gradients for edges. "
|
464
|
+
"Please make sure that edges are used "
|
465
|
+
"via message passing inside the model or "
|
466
|
+
"disable it via `edge_mask_type=None`.")
|
467
|
+
self.hard_edge_mask = self.edge_mask.grad != 0.0
|
186
468
|
|
187
469
|
def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
|
470
|
+
# Calculate base loss based on model configuration
|
471
|
+
loss = self._calculate_base_loss(y_hat, y)
|
472
|
+
|
473
|
+
# Apply regularization based on graph type
|
474
|
+
if self.is_hetero:
|
475
|
+
# Apply regularization for heterogeneous graph
|
476
|
+
loss = self._apply_hetero_regularization(loss)
|
477
|
+
else:
|
478
|
+
# Apply regularization for homogeneous graph
|
479
|
+
loss = self._apply_homo_regularization(loss)
|
480
|
+
|
481
|
+
return loss
|
482
|
+
|
483
|
+
def _calculate_base_loss(self, y_hat, y):
|
484
|
+
"""Calculate base loss based on model configuration."""
|
188
485
|
if self.model_config.mode == ModelMode.binary_classification:
|
189
|
-
|
486
|
+
return self._loss_binary_classification(y_hat, y)
|
190
487
|
elif self.model_config.mode == ModelMode.multiclass_classification:
|
191
|
-
|
488
|
+
return self._loss_multiclass_classification(y_hat, y)
|
192
489
|
elif self.model_config.mode == ModelMode.regression:
|
193
|
-
|
490
|
+
return self._loss_regression(y_hat, y)
|
194
491
|
else:
|
195
|
-
|
492
|
+
raise ValueError(f"Invalid model mode: {self.model_config.mode}")
|
493
|
+
|
494
|
+
def _apply_hetero_regularization(self, loss):
|
495
|
+
"""Apply regularization for heterogeneous graph."""
|
496
|
+
# Apply regularization for each edge type
|
497
|
+
for edge_type, mask in self.edge_mask.items():
|
498
|
+
if (mask is not None
|
499
|
+
and self.hard_edge_mask[edge_type] is not None):
|
500
|
+
loss = self._add_mask_regularization(
|
501
|
+
loss, mask, self.hard_edge_mask[edge_type],
|
502
|
+
self.coeffs['edge_size'], self.coeffs['edge_reduction'],
|
503
|
+
self.coeffs['edge_ent'])
|
504
|
+
|
505
|
+
# Apply regularization for each node type
|
506
|
+
for node_type, mask in self.node_mask.items():
|
507
|
+
if (mask is not None
|
508
|
+
and self.hard_node_mask[node_type] is not None):
|
509
|
+
loss = self._add_mask_regularization(
|
510
|
+
loss, mask, self.hard_node_mask[node_type],
|
511
|
+
self.coeffs['node_feat_size'],
|
512
|
+
self.coeffs['node_feat_reduction'],
|
513
|
+
self.coeffs['node_feat_ent'])
|
196
514
|
|
515
|
+
return loss
|
516
|
+
|
517
|
+
def _apply_homo_regularization(self, loss):
|
518
|
+
"""Apply regularization for homogeneous graph."""
|
519
|
+
# Apply regularization for edge mask
|
197
520
|
if self.hard_edge_mask is not None:
|
198
521
|
assert self.edge_mask is not None
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
loss = loss + self.coeffs['edge_ent'] * ent.mean()
|
522
|
+
loss = self._add_mask_regularization(loss, self.edge_mask,
|
523
|
+
self.hard_edge_mask,
|
524
|
+
self.coeffs['edge_size'],
|
525
|
+
self.coeffs['edge_reduction'],
|
526
|
+
self.coeffs['edge_ent'])
|
205
527
|
|
528
|
+
# Apply regularization for node mask
|
206
529
|
if self.hard_node_mask is not None:
|
207
530
|
assert self.node_mask is not None
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
531
|
+
loss = self._add_mask_regularization(
|
532
|
+
loss, self.node_mask, self.hard_node_mask,
|
533
|
+
self.coeffs['node_feat_size'],
|
534
|
+
self.coeffs['node_feat_reduction'],
|
535
|
+
self.coeffs['node_feat_ent'])
|
536
|
+
|
537
|
+
return loss
|
538
|
+
|
539
|
+
def _add_mask_regularization(self, loss, mask, hard_mask, size_coeff,
|
540
|
+
reduction_name, ent_coeff):
|
541
|
+
"""Add size and entropy regularization for a mask."""
|
542
|
+
m = mask[hard_mask].sigmoid()
|
543
|
+
reduce_fn = getattr(torch, reduction_name)
|
544
|
+
# Add size regularization
|
545
|
+
loss = loss + size_coeff * reduce_fn(m)
|
546
|
+
# Add entropy regularization
|
547
|
+
ent = -m * torch.log(m + self.coeffs['EPS']) - (
|
548
|
+
1 - m) * torch.log(1 - m + self.coeffs['EPS'])
|
549
|
+
loss = loss + ent_coeff * ent.mean()
|
214
550
|
|
215
551
|
return loss
|
216
552
|
|
@@ -670,6 +670,9 @@ class LinkPredDiversity(_LinkPredMetric):
|
|
670
670
|
def __init__(self, k: int, category: Tensor) -> None:
|
671
671
|
super().__init__(k)
|
672
672
|
|
673
|
+
self.accum: Tensor
|
674
|
+
self.total: Tensor
|
675
|
+
|
673
676
|
if WITH_TORCHMETRICS:
|
674
677
|
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
|
675
678
|
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
|
@@ -736,11 +739,14 @@ class LinkPredPersonalization(_LinkPredMetric):
|
|
736
739
|
self.max_src_nodes = max_src_nodes
|
737
740
|
self.batch_size = batch_size
|
738
741
|
|
742
|
+
self.preds: List[Tensor]
|
743
|
+
self.total: Tensor
|
744
|
+
|
739
745
|
if WITH_TORCHMETRICS:
|
740
746
|
self.add_state('preds', default=[], dist_reduce_fx='cat')
|
741
747
|
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
|
742
748
|
else:
|
743
|
-
self.preds
|
749
|
+
self.preds = []
|
744
750
|
self.register_buffer('total', torch.tensor(0), persistent=False)
|
745
751
|
|
746
752
|
def update(
|
@@ -826,6 +832,9 @@ class LinkPredAveragePopularity(_LinkPredMetric):
|
|
826
832
|
def __init__(self, k: int, popularity: Tensor) -> None:
|
827
833
|
super().__init__(k)
|
828
834
|
|
835
|
+
self.accum: Tensor
|
836
|
+
self.total: Tensor
|
837
|
+
|
829
838
|
if WITH_TORCHMETRICS:
|
830
839
|
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
|
831
840
|
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
|
@@ -379,7 +379,7 @@ class GPSE(torch.nn.Module):
|
|
379
379
|
|
380
380
|
.. code-block:: python
|
381
381
|
|
382
|
-
from torch_geometric.nn import GPSE, GPSENodeEncoder
|
382
|
+
from torch_geometric.nn import GPSE, GPSENodeEncoder
|
383
383
|
from torch_geometric.transforms import AddGPSE
|
384
384
|
from torch_geometric.nn.models.gpse import precompute_GPSE
|
385
385
|
|
@@ -417,13 +417,11 @@ class GPSE(torch.nn.Module):
|
|
417
417
|
|
418
418
|
encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
|
419
419
|
expand_x=False)
|
420
|
-
gnn = GNN(
|
420
|
+
gnn = GNN(...)
|
421
421
|
|
422
422
|
for batch in loader:
|
423
|
-
|
424
|
-
|
425
|
-
# Do something with the batch, which now includes 128-dimensional
|
426
|
-
# node representations
|
423
|
+
x = encoder(batch.x, batch.pestat_GPSE)
|
424
|
+
out = gnn(x, batch.edge_index)
|
427
425
|
|
428
426
|
|
429
427
|
Args:
|
@@ -571,8 +569,7 @@ class GPSE(torch.nn.Module):
|
|
571
569
|
self.reset_parameters()
|
572
570
|
|
573
571
|
def reset_parameters(self):
|
574
|
-
|
575
|
-
self.apply(init_weights)
|
572
|
+
pass
|
576
573
|
|
577
574
|
@classmethod
|
578
575
|
def from_pretrained(cls, name: str, root: str = 'GPSE_pretrained'):
|
@@ -608,6 +605,7 @@ class GPSE(torch.nn.Module):
|
|
608
605
|
return model
|
609
606
|
|
610
607
|
def forward(self, batch):
|
608
|
+
batch = batch.clone()
|
611
609
|
for module in self.children():
|
612
610
|
batch = module(batch)
|
613
611
|
return batch
|
@@ -635,13 +633,11 @@ class GPSENodeEncoder(torch.nn.Module):
|
|
635
633
|
|
636
634
|
encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
|
637
635
|
expand_x=False)
|
638
|
-
gnn = GNN(
|
636
|
+
gnn = GNN(...)
|
639
637
|
|
640
638
|
for batch in loader:
|
641
|
-
|
642
|
-
batch = gnn(batch)
|
643
|
-
# Do something with the batch, which now includes 128-dimensional
|
644
|
-
# node representations
|
639
|
+
x = encoder(batch.x, batch.pestat_GPSE)
|
640
|
+
batch = gnn(x, batch.edge_index)
|
645
641
|
|
646
642
|
Args:
|
647
643
|
dim_emb (int): Size of final node embedding.
|
@@ -705,28 +701,17 @@ class GPSENodeEncoder(torch.nn.Module):
|
|
705
701
|
raise ValueError(f"{self.__class__.__name__}: Does not support "
|
706
702
|
f"'{model_type}' encoder model.")
|
707
703
|
|
708
|
-
def forward(self,
|
709
|
-
if not hasattr(batch, 'pestat_GPSE'):
|
710
|
-
raise ValueError('Precomputed "pestat_GPSE" variable is required '
|
711
|
-
'for GNNNodeEncoder; either run '
|
712
|
-
'`precompute_GPSE(gpse_model, dataset)` on your '
|
713
|
-
'dataset or add `AddGPSE(gpse_model)` as a (pre) '
|
714
|
-
'transform.')
|
715
|
-
|
716
|
-
pos_enc = batch.pestat_GPSE
|
717
|
-
|
704
|
+
def forward(self, x, pos_enc):
|
718
705
|
pos_enc = self.dropout_be(pos_enc)
|
719
706
|
pos_enc = self.raw_norm(pos_enc) if self.raw_norm else pos_enc
|
720
707
|
pos_enc = self.pe_encoder(pos_enc) # (Num nodes) x dim_pe
|
721
708
|
pos_enc = self.dropout_ae(pos_enc)
|
722
709
|
|
723
710
|
# Expand node features if needed
|
724
|
-
h = self.linear_x(
|
711
|
+
h = self.linear_x(x) if self.expand_x else x
|
725
712
|
|
726
713
|
# Concatenate final PEs to input embedding
|
727
|
-
|
728
|
-
|
729
|
-
return batch
|
714
|
+
return torch.cat((h, pos_enc), 1)
|
730
715
|
|
731
716
|
|
732
717
|
@torch.no_grad()
|
@@ -94,7 +94,7 @@ class RootedSubgraph(BaseTransform, ABC):
|
|
94
94
|
arange = torch.arange(n_id.size(0), device=data.edge_index.device)
|
95
95
|
node_map = data.edge_index.new_ones(num_nodes, num_nodes)
|
96
96
|
node_map[n_sub_batch, n_id] = arange
|
97
|
-
sub_edge_index += (arange *
|
97
|
+
sub_edge_index += (arange * num_nodes)[e_sub_batch]
|
98
98
|
sub_edge_index = node_map.view(-1)[sub_edge_index]
|
99
99
|
|
100
100
|
return sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch
|
@@ -109,10 +109,9 @@ def negative_sampling(
|
|
109
109
|
idx = idx.to('cpu')
|
110
110
|
for _ in range(3): # Number of tries to sample negative indices.
|
111
111
|
rnd = sample(population, sample_size, device='cpu')
|
112
|
-
mask = np.isin(rnd.numpy(), idx.numpy())
|
112
|
+
mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool()
|
113
113
|
if neg_idx is not None:
|
114
|
-
mask |= np.isin(rnd, neg_idx.
|
115
|
-
mask = torch.from_numpy(mask).to(torch.bool)
|
114
|
+
mask |= torch.from_numpy(np.isin(rnd, neg_idx.cpu())).bool()
|
116
115
|
rnd = rnd[~mask].to(edge_index.device)
|
117
116
|
neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd])
|
118
117
|
if neg_idx.numel() >= num_neg_samples:
|
File without changes
|
{pyg_nightly-2.7.0.dev20250406.dist-info → pyg_nightly-2.7.0.dev20250408.dist-info}/licenses/LICENSE
RENAMED
File without changes
|