pyg-nightly 2.7.0.dev20250703__py3-none-any.whl → 2.7.0.dev20250705__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.7.0.dev20250703.dist-info → pyg_nightly-2.7.0.dev20250705.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250703.dist-info → pyg_nightly-2.7.0.dev20250705.dist-info}/RECORD +10 -8
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
- torch_geometric/nn/attention/__init__.py +2 -0
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/models/__init__.py +2 -0
- torch_geometric/nn/models/polynormer.py +206 -0
- {pyg_nightly-2.7.0.dev20250703.dist-info → pyg_nightly-2.7.0.dev20250705.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250703.dist-info → pyg_nightly-2.7.0.dev20250705.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250703.dist-info → pyg_nightly-2.7.0.dev20250705.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pyg-nightly
|
|
3
|
-
Version: 2.7.0.dev20250703
|
|
3
|
+
Version: 2.7.0.dev20250705
|
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
|
1
|
+
torch_geometric/__init__.py,sha256=1bw1VDVg35KJk3xGL1TIsA4fTzak3tUjkH7pRIkfoyg,2250
|
|
2
2
|
torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
|
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
|
@@ -119,7 +119,7 @@ torch_geometric/datasets/medshapenet.py,sha256=eCBCXKpueweCwDSf_Q4_MwVA3IbJd04FS
|
|
|
119
119
|
torch_geometric/datasets/mixhop_synthetic_dataset.py,sha256=4NNvTHUvvV6pcqQCyVDS5XhppXUeF2H9GTfFoc49eyU,3951
|
|
120
120
|
torch_geometric/datasets/mnist_superpixels.py,sha256=o2ArbZ0_OE0u8VCaHmWwvngESlOFr9oM9dSEP_tjAS4,3340
|
|
121
121
|
torch_geometric/datasets/modelnet.py,sha256=-qmLjlQiKVWmtHefAIIE97dQxEcaBfetMJnvgYZuwkg,5347
|
|
122
|
-
torch_geometric/datasets/molecule_gpt_dataset.py,sha256=
|
|
122
|
+
torch_geometric/datasets/molecule_gpt_dataset.py,sha256=TFBduE3_3xxTFSHL3tirV-OAlBjSi6iHPOHJGQ_-tug,18785
|
|
123
123
|
torch_geometric/datasets/molecule_net.py,sha256=pMzaJzd-LbBncZ0VoC87HfA8d1F4NwCWTb5YKvLM890,7404
|
|
124
124
|
torch_geometric/datasets/movie_lens.py,sha256=M4Bu0Xus8IkW8GYzjxPxSdPXNbcCCx9cu6cncxBvLx8,4033
|
|
125
125
|
torch_geometric/datasets/movie_lens_100k.py,sha256=eTpBAteM3jqTEtiwLxmhVj4r8JvftvPx8Hvs-3ZIHlU,6057
|
|
@@ -335,8 +335,9 @@ torch_geometric/nn/aggr/set_transformer.py,sha256=FG7_JizpFX14M6VSCwLSjYXYdJ1ZiQ
|
|
|
335
335
|
torch_geometric/nn/aggr/sort.py,sha256=bvOOWnFkNOBOZih4rqVZQsjfeDX3vmXo1bpPSFD846w,2507
|
|
336
336
|
torch_geometric/nn/aggr/utils.py,sha256=SQvdc0g6p_E2j0prA14MW2ekjEDvV-g545N0Q85uc-o,8625
|
|
337
337
|
torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nqoyv8tZ6Y35qMYTf9w,1126
|
|
338
|
-
torch_geometric/nn/attention/__init__.py,sha256=
|
|
338
|
+
torch_geometric/nn/attention/__init__.py,sha256=w-jDQFpVqARJKjttTgKkD9kkAqRJl4MpASCfiNYIfr0,263
|
|
339
339
|
torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
|
|
340
|
+
torch_geometric/nn/attention/polynormer.py,sha256=uBxGs0nldp6oGlByqbxgEk23VeXLEd6B3myS5BOKDRs,3998
|
|
340
341
|
torch_geometric/nn/attention/qformer.py,sha256=7J-pWm_vpumK38IC-iCBz4oqL-BEIofEIxJ0wfjWq9A,2338
|
|
341
342
|
torch_geometric/nn/attention/sgformer.py,sha256=OBC5HQxbY289bPDtwN8UbPH46To2GRTeVN-najogD-o,3747
|
|
342
343
|
torch_geometric/nn/conv/__init__.py,sha256=8CK-DFG2PEo2ZaFyg-IUlQH8ecQoDDi556uv3ugeQyc,3572
|
|
@@ -431,7 +432,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
|
|
|
431
432
|
torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
|
|
432
433
|
torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
|
|
433
434
|
torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
|
|
434
|
-
torch_geometric/nn/models/__init__.py,sha256=
|
|
435
|
+
torch_geometric/nn/models/__init__.py,sha256=71Hqc-ZMfCKn9lelFYDjpHXapbEa0wqVAd2OXCb1y5o,2448
|
|
435
436
|
torch_geometric/nn/models/attentive_fp.py,sha256=1z3iTV2O5W9tqHFAdno8FeBFeXmuG-TDZk4lwwVh3Ac,6634
|
|
436
437
|
torch_geometric/nn/models/attract_repel.py,sha256=h9OyogT0NY0xiT0DkpJHMxH6ZUmo8R-CmwZdKEwq8Ek,5277
|
|
437
438
|
torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
|
|
@@ -461,6 +462,7 @@ torch_geometric/nn/models/molecule_gpt.py,sha256=k-XULH6jaurj-R2EE4sIWTkqlNqa3Cz
|
|
|
461
462
|
torch_geometric/nn/models/neural_fingerprint.py,sha256=pTLJgU9Uh2Lnf9bggLj4cKI8YdEFcMF-9MALuubqbuQ,2378
|
|
462
463
|
torch_geometric/nn/models/node2vec.py,sha256=81Ku4Rp4IwLEAy06KEgJ2fYtXXVL_uv_Hb8lBr6YXrE,7664
|
|
463
464
|
torch_geometric/nn/models/pmlp.py,sha256=dcAASVSyQMMhItSfEJWPeAFh0R3tNCwAHwdrShwQ8o4,3538
|
|
465
|
+
torch_geometric/nn/models/polynormer.py,sha256=mayWdzdolT5PCt_Oo7UGG-JUripMHWB2lUWF1bh6goU,7640
|
|
464
466
|
torch_geometric/nn/models/protein_mpnn.py,sha256=QXHfltiJPmakpzgJKw_1vwCGBlszv9nfY4r4F38Sg9k,11031
|
|
465
467
|
torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3WAuvITk,8986
|
|
466
468
|
torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
|
|
@@ -643,7 +645,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
|
643
645
|
torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
|
|
644
646
|
torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
|
|
645
647
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
|
646
|
-
pyg_nightly-2.7.0.dev20250703.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
|
647
|
-
pyg_nightly-2.7.0.dev20250703.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
|
648
|
-
pyg_nightly-2.7.0.
|
|
649
|
-
pyg_nightly-2.7.0.dev20250703.dist-info/RECORD,,
|
|
648
|
+
pyg_nightly-2.7.0.dev20250705.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
|
649
|
+
pyg_nightly-2.7.0.dev20250705.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
|
650
|
+
pyg_nightly-2.7.0.dev20250705.dist-info/METADATA,sha256=YTG5Idy8eBGYKfe-87eyt6OjQ5_Me0nc1RWKkeiKDjE,63005
|
|
651
|
+
pyg_nightly-2.7.0.dev20250705.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
|
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
|
33
33
|
|
|
34
|
-
__version__ = '2.7.0.dev20250703'
|
|
34
|
+
__version__ = '2.7.0.dev20250705'
|
|
35
35
|
|
|
36
36
|
__all__ = [
|
|
37
37
|
'Index',
|
|
@@ -438,7 +438,7 @@ class MoleculeGPTDataset(InMemoryDataset):
|
|
|
438
438
|
for mol in tqdm(suppl):
|
|
439
439
|
if mol.HasProp('PUBCHEM_COMPOUND_CID'):
|
|
440
440
|
CID = mol.GetProp("PUBCHEM_COMPOUND_CID")
|
|
441
|
-
CAN_SMILES = mol.GetProp("PUBCHEM_OPENEYE_CAN_SMILES")
|
|
441
|
+
CAN_SMILES = mol.GetProp("PUBCHEM_SMILES")
|
|
442
442
|
|
|
443
443
|
m: Chem.Mol = Chem.MolFromSmiles(CAN_SMILES)
|
|
444
444
|
if m is None:
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
from .performer import PerformerAttention
|
|
2
2
|
from .qformer import QFormer
|
|
3
3
|
from .sgformer import SGFormerAttention
|
|
4
|
+
from .polynormer import PolynormerAttention
|
|
4
5
|
|
|
5
6
|
__all__ = [
|
|
6
7
|
'PerformerAttention',
|
|
7
8
|
'QFormer',
|
|
8
9
|
'SGFormerAttention',
|
|
10
|
+
'PolynormerAttention',
|
|
9
11
|
]
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PolynormerAttention(torch.nn.Module):
|
|
9
|
+
r"""The polynomial-expressive attention mechanism from the
|
|
10
|
+
`"Polynormer: Polynomial-Expressive Graph Transformer in Linear Time"
|
|
11
|
+
<https://arxiv.org/abs/2403.01232>`_ paper.
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
channels (int): Size of each input sample.
|
|
15
|
+
heads (int, optional): Number of parallel attention heads.
|
|
16
|
+
head_channels (int, optional): Size of each attention head.
|
|
17
|
+
(default: :obj:`64.`)
|
|
18
|
+
beta (float, optional): Polynormer beta initialization.
|
|
19
|
+
(default: :obj:`0.9`)
|
|
20
|
+
qkv_bias (bool, optional): If specified, add bias to query, key
|
|
21
|
+
and value in the self attention. (default: :obj:`False`)
|
|
22
|
+
qk_shared (bool optional): Whether weight of query and key are shared.
|
|
23
|
+
(default: :obj:`True`)
|
|
24
|
+
dropout (float, optional): Dropout probability of the final
|
|
25
|
+
attention output. (default: :obj:`0.0`)
|
|
26
|
+
"""
|
|
27
|
+
def __init__(
|
|
28
|
+
self,
|
|
29
|
+
channels: int,
|
|
30
|
+
heads: int,
|
|
31
|
+
head_channels: int = 64,
|
|
32
|
+
beta: float = 0.9,
|
|
33
|
+
qkv_bias: bool = False,
|
|
34
|
+
qk_shared: bool = True,
|
|
35
|
+
dropout: float = 0.0,
|
|
36
|
+
) -> None:
|
|
37
|
+
super().__init__()
|
|
38
|
+
|
|
39
|
+
self.head_channels = head_channels
|
|
40
|
+
self.heads = heads
|
|
41
|
+
self.beta = beta
|
|
42
|
+
self.qk_shared = qk_shared
|
|
43
|
+
|
|
44
|
+
inner_channels = heads * head_channels
|
|
45
|
+
self.h_lins = torch.nn.Linear(channels, inner_channels)
|
|
46
|
+
if not self.qk_shared:
|
|
47
|
+
self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
|
|
48
|
+
self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
|
|
49
|
+
self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
|
|
50
|
+
self.lns = torch.nn.LayerNorm(inner_channels)
|
|
51
|
+
self.lin_out = torch.nn.Linear(inner_channels, inner_channels)
|
|
52
|
+
self.dropout = torch.nn.Dropout(dropout)
|
|
53
|
+
|
|
54
|
+
def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
|
|
55
|
+
r"""Forward pass.
|
|
56
|
+
|
|
57
|
+
Args:
|
|
58
|
+
x (torch.Tensor): Node feature tensor
|
|
59
|
+
:math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
|
|
60
|
+
batch-size :math:`B`, (maximum) number of nodes :math:`N` for
|
|
61
|
+
each graph, and feature dimension :math:`F`.
|
|
62
|
+
mask (torch.Tensor, optional): Mask matrix
|
|
63
|
+
:math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
|
|
64
|
+
the valid nodes for each graph. (default: :obj:`None`)
|
|
65
|
+
"""
|
|
66
|
+
B, N, *_ = x.shape
|
|
67
|
+
h = self.h_lins(x)
|
|
68
|
+
k = self.k(x).sigmoid().view(B, N, self.head_channels, self.heads)
|
|
69
|
+
if self.qk_shared:
|
|
70
|
+
q = k
|
|
71
|
+
else:
|
|
72
|
+
q = F.sigmoid(self.q(x)).view(B, N, self.head_channels, self.heads)
|
|
73
|
+
v = self.v(x).view(B, N, self.head_channels, self.heads)
|
|
74
|
+
|
|
75
|
+
if mask is not None:
|
|
76
|
+
mask = mask[:, :, None, None]
|
|
77
|
+
v.masked_fill_(~mask, 0.)
|
|
78
|
+
|
|
79
|
+
# numerator
|
|
80
|
+
kv = torch.einsum('bndh, bnmh -> bdmh', k, v)
|
|
81
|
+
num = torch.einsum('bndh, bdmh -> bnmh', q, kv)
|
|
82
|
+
|
|
83
|
+
# denominator
|
|
84
|
+
k_sum = torch.einsum('bndh -> bdh', k)
|
|
85
|
+
den = torch.einsum('bndh, bdh -> bnh', q, k_sum).unsqueeze(2)
|
|
86
|
+
|
|
87
|
+
# linear global attention based on kernel trick
|
|
88
|
+
x = (num / (den + 1e-6)).reshape(B, N, -1)
|
|
89
|
+
x = self.lns(x) * (h + self.beta)
|
|
90
|
+
x = F.relu(self.lin_out(x))
|
|
91
|
+
x = self.dropout(x)
|
|
92
|
+
|
|
93
|
+
return x
|
|
94
|
+
|
|
95
|
+
def reset_parameters(self) -> None:
|
|
96
|
+
self.h_lins.reset_parameters()
|
|
97
|
+
if not self.qk_shared:
|
|
98
|
+
self.q.reset_parameters()
|
|
99
|
+
self.k.reset_parameters()
|
|
100
|
+
self.v.reset_parameters()
|
|
101
|
+
self.lns.reset_parameters()
|
|
102
|
+
self.lin_out.reset_parameters()
|
|
103
|
+
|
|
104
|
+
def __repr__(self) -> str:
|
|
105
|
+
return (f'{self.__class__.__name__}('
|
|
106
|
+
f'heads={self.heads}, '
|
|
107
|
+
f'head_channels={self.head_channels})')
|
|
@@ -35,6 +35,7 @@ from .molecule_gpt import MoleculeGPT
|
|
|
35
35
|
from .protein_mpnn import ProteinMPNN
|
|
36
36
|
from .glem import GLEM
|
|
37
37
|
from .sgformer import SGFormer
|
|
38
|
+
from .polynormer import Polynormer
|
|
38
39
|
# Deprecated:
|
|
39
40
|
from torch_geometric.explain.algorithm.captum import (to_captum_input,
|
|
40
41
|
captum_output_to_dicts)
|
|
@@ -90,5 +91,6 @@ __all__ = classes = [
|
|
|
90
91
|
'ProteinMPNN',
|
|
91
92
|
'GLEM',
|
|
92
93
|
'SGFormer',
|
|
94
|
+
'Polynormer',
|
|
93
95
|
'ARLinkPredictor',
|
|
94
96
|
]
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from torch_geometric.nn import GATConv, GCNConv
|
|
8
|
+
from torch_geometric.nn.attention import PolynormerAttention
|
|
9
|
+
from torch_geometric.utils import to_dense_batch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Polynormer(torch.nn.Module):
|
|
13
|
+
r"""The polynormer module from the
|
|
14
|
+
`"Polynormer: polynomial-expressive graph
|
|
15
|
+
transformer in linear time"
|
|
16
|
+
<https://arxiv.org/abs/2403.01232>`_ paper.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
in_channels (int): Input channels.
|
|
20
|
+
hidden_channels (int): Hidden channels.
|
|
21
|
+
out_channels (int): Output channels.
|
|
22
|
+
local_layers (int): The number of local attention layers.
|
|
23
|
+
(default: :obj:`7`)
|
|
24
|
+
global_layers (int): The number of global attention layers.
|
|
25
|
+
(default: :obj:`2`)
|
|
26
|
+
in_dropout (float): Input dropout rate.
|
|
27
|
+
(default: :obj:`0.15`)
|
|
28
|
+
dropout (float): Dropout rate.
|
|
29
|
+
(default: :obj:`0.5`)
|
|
30
|
+
global_dropout (float): Global dropout rate.
|
|
31
|
+
(default: :obj:`0.5`)
|
|
32
|
+
heads (int): The number of heads.
|
|
33
|
+
(default: :obj:`1`)
|
|
34
|
+
beta (float): Aggregate type.
|
|
35
|
+
(default: :obj:`0.9`)
|
|
36
|
+
qk_shared (bool optional): Whether weight of query and key are shared.
|
|
37
|
+
(default: :obj:`True`)
|
|
38
|
+
pre_ln (bool): Pre layer normalization.
|
|
39
|
+
(default: :obj:`False`)
|
|
40
|
+
post_bn (bool): Post batch normlization.
|
|
41
|
+
(default: :obj:`True`)
|
|
42
|
+
local_attn (bool): Whether use local attention.
|
|
43
|
+
(default: :obj:`False`)
|
|
44
|
+
"""
|
|
45
|
+
def __init__(
|
|
46
|
+
self,
|
|
47
|
+
in_channels: int,
|
|
48
|
+
hidden_channels: int,
|
|
49
|
+
out_channels: int,
|
|
50
|
+
local_layers: int = 7,
|
|
51
|
+
global_layers: int = 2,
|
|
52
|
+
in_dropout: float = 0.15,
|
|
53
|
+
dropout: float = 0.5,
|
|
54
|
+
global_dropout: float = 0.5,
|
|
55
|
+
heads: int = 1,
|
|
56
|
+
beta: float = 0.9,
|
|
57
|
+
qk_shared: bool = False,
|
|
58
|
+
pre_ln: bool = False,
|
|
59
|
+
post_bn: bool = True,
|
|
60
|
+
local_attn: bool = False,
|
|
61
|
+
) -> None:
|
|
62
|
+
super().__init__()
|
|
63
|
+
self._global = False
|
|
64
|
+
self.in_drop = in_dropout
|
|
65
|
+
self.dropout = dropout
|
|
66
|
+
self.pre_ln = pre_ln
|
|
67
|
+
self.post_bn = post_bn
|
|
68
|
+
|
|
69
|
+
self.beta = beta
|
|
70
|
+
|
|
71
|
+
self.h_lins = torch.nn.ModuleList()
|
|
72
|
+
self.local_convs = torch.nn.ModuleList()
|
|
73
|
+
self.lins = torch.nn.ModuleList()
|
|
74
|
+
self.lns = torch.nn.ModuleList()
|
|
75
|
+
if self.pre_ln:
|
|
76
|
+
self.pre_lns = torch.nn.ModuleList()
|
|
77
|
+
if self.post_bn:
|
|
78
|
+
self.post_bns = torch.nn.ModuleList()
|
|
79
|
+
|
|
80
|
+
# first layer
|
|
81
|
+
inner_channels = heads * hidden_channels
|
|
82
|
+
self.h_lins.append(torch.nn.Linear(in_channels, inner_channels))
|
|
83
|
+
if local_attn:
|
|
84
|
+
self.local_convs.append(
|
|
85
|
+
GATConv(in_channels, hidden_channels, heads=heads, concat=True,
|
|
86
|
+
add_self_loops=False, bias=False))
|
|
87
|
+
else:
|
|
88
|
+
self.local_convs.append(
|
|
89
|
+
GCNConv(in_channels, inner_channels, cached=False,
|
|
90
|
+
normalize=True))
|
|
91
|
+
|
|
92
|
+
self.lins.append(torch.nn.Linear(in_channels, inner_channels))
|
|
93
|
+
self.lns.append(torch.nn.LayerNorm(inner_channels))
|
|
94
|
+
if self.pre_ln:
|
|
95
|
+
self.pre_lns.append(torch.nn.LayerNorm(in_channels))
|
|
96
|
+
if self.post_bn:
|
|
97
|
+
self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))
|
|
98
|
+
|
|
99
|
+
# following layers
|
|
100
|
+
for _ in range(local_layers - 1):
|
|
101
|
+
self.h_lins.append(torch.nn.Linear(inner_channels, inner_channels))
|
|
102
|
+
if local_attn:
|
|
103
|
+
self.local_convs.append(
|
|
104
|
+
GATConv(inner_channels, hidden_channels, heads=heads,
|
|
105
|
+
concat=True, add_self_loops=False, bias=False))
|
|
106
|
+
else:
|
|
107
|
+
self.local_convs.append(
|
|
108
|
+
GCNConv(inner_channels, inner_channels, cached=False,
|
|
109
|
+
normalize=True))
|
|
110
|
+
|
|
111
|
+
self.lins.append(torch.nn.Linear(inner_channels, inner_channels))
|
|
112
|
+
self.lns.append(torch.nn.LayerNorm(inner_channels))
|
|
113
|
+
if self.pre_ln:
|
|
114
|
+
self.pre_lns.append(torch.nn.LayerNorm(heads *
|
|
115
|
+
hidden_channels))
|
|
116
|
+
if self.post_bn:
|
|
117
|
+
self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))
|
|
118
|
+
|
|
119
|
+
self.lin_in = torch.nn.Linear(in_channels, inner_channels)
|
|
120
|
+
self.ln = torch.nn.LayerNorm(inner_channels)
|
|
121
|
+
|
|
122
|
+
self.global_attn = torch.nn.ModuleList()
|
|
123
|
+
for _ in range(global_layers):
|
|
124
|
+
self.global_attn.append(
|
|
125
|
+
PolynormerAttention(
|
|
126
|
+
channels=hidden_channels,
|
|
127
|
+
heads=heads,
|
|
128
|
+
head_channels=hidden_channels,
|
|
129
|
+
beta=beta,
|
|
130
|
+
dropout=global_dropout,
|
|
131
|
+
qk_shared=qk_shared,
|
|
132
|
+
))
|
|
133
|
+
self.pred_local = torch.nn.Linear(inner_channels, out_channels)
|
|
134
|
+
self.pred_global = torch.nn.Linear(inner_channels, out_channels)
|
|
135
|
+
self.reset_parameters()
|
|
136
|
+
|
|
137
|
+
def reset_parameters(self) -> None:
|
|
138
|
+
for local_conv in self.local_convs:
|
|
139
|
+
local_conv.reset_parameters()
|
|
140
|
+
for attn in self.global_attn:
|
|
141
|
+
attn.reset_parameters()
|
|
142
|
+
for lin in self.lins:
|
|
143
|
+
lin.reset_parameters()
|
|
144
|
+
for h_lin in self.h_lins:
|
|
145
|
+
h_lin.reset_parameters()
|
|
146
|
+
for ln in self.lns:
|
|
147
|
+
ln.reset_parameters()
|
|
148
|
+
if self.pre_ln:
|
|
149
|
+
for p_ln in self.pre_lns:
|
|
150
|
+
p_ln.reset_parameters()
|
|
151
|
+
if self.post_bn:
|
|
152
|
+
for p_bn in self.post_bns:
|
|
153
|
+
p_bn.reset_parameters()
|
|
154
|
+
self.lin_in.reset_parameters()
|
|
155
|
+
self.ln.reset_parameters()
|
|
156
|
+
self.pred_local.reset_parameters()
|
|
157
|
+
self.pred_global.reset_parameters()
|
|
158
|
+
|
|
159
|
+
def forward(
|
|
160
|
+
self,
|
|
161
|
+
x: Tensor,
|
|
162
|
+
edge_index: Tensor,
|
|
163
|
+
batch: Optional[Tensor],
|
|
164
|
+
) -> Tensor:
|
|
165
|
+
r"""Forward pass.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
x (torch.Tensor): The input node features.
|
|
169
|
+
edge_index (torch.Tensor or SparseTensor): The edge indices.
|
|
170
|
+
batch (torch.Tensor, optional): The batch vector
|
|
171
|
+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
|
|
172
|
+
each element to a specific example.
|
|
173
|
+
"""
|
|
174
|
+
x = F.dropout(x, p=self.in_drop, training=self.training)
|
|
175
|
+
|
|
176
|
+
# equivariant local attention
|
|
177
|
+
x_local = 0
|
|
178
|
+
for i, local_conv in enumerate(self.local_convs):
|
|
179
|
+
if self.pre_ln:
|
|
180
|
+
x = self.pre_lns[i](x)
|
|
181
|
+
h = self.h_lins[i](x)
|
|
182
|
+
h = F.relu(h)
|
|
183
|
+
x = local_conv(x, edge_index) + self.lins[i](x)
|
|
184
|
+
if self.post_bn:
|
|
185
|
+
x = self.post_bns[i](x)
|
|
186
|
+
x = F.relu(x)
|
|
187
|
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
|
188
|
+
x = (1 - self.beta) * self.lns[i](h * x) + self.beta * x
|
|
189
|
+
x_local = x_local + x
|
|
190
|
+
|
|
191
|
+
# equivariant global attention
|
|
192
|
+
if self._global:
|
|
193
|
+
batch, indices = batch.sort()
|
|
194
|
+
rev_perm = torch.empty_like(indices)
|
|
195
|
+
rev_perm[indices] = torch.arange(len(indices),
|
|
196
|
+
device=indices.device)
|
|
197
|
+
x_local = self.ln(x_local[indices])
|
|
198
|
+
x_global, mask = to_dense_batch(x_local, batch)
|
|
199
|
+
for attn in self.global_attn:
|
|
200
|
+
x_global = attn(x_global, mask)
|
|
201
|
+
x = x_global[mask][rev_perm]
|
|
202
|
+
x = self.pred_global(x)
|
|
203
|
+
else:
|
|
204
|
+
x = self.pred_local(x_local)
|
|
205
|
+
|
|
206
|
+
return F.log_softmax(x, dim=-1)
|
|
File without changes
|
{pyg_nightly-2.7.0.dev20250703.dist-info → pyg_nightly-2.7.0.dev20250705.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|