pyg-nightly 2.7.0.dev20250703__py3-none-any.whl → 2.7.0.dev20250704__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective registries, and is provided for informational purposes only.


@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyg-nightly
-Version: 2.7.0.dev20250703
+Version: 2.7.0.dev20250704
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=8_AcpgPpDfVr7gwK3sYPD-pu9JYUmziOQzs9SPRvqcE,2250
+torch_geometric/__init__.py,sha256=GOuL0XBOcsFqK-Q-c_STDpzZAG-vsctiDiU_Tg9W3t8,2250
 torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -119,7 +119,7 @@ torch_geometric/datasets/medshapenet.py,sha256=eCBCXKpueweCwDSf_Q4_MwVA3IbJd04FS
 torch_geometric/datasets/mixhop_synthetic_dataset.py,sha256=4NNvTHUvvV6pcqQCyVDS5XhppXUeF2H9GTfFoc49eyU,3951
 torch_geometric/datasets/mnist_superpixels.py,sha256=o2ArbZ0_OE0u8VCaHmWwvngESlOFr9oM9dSEP_tjAS4,3340
 torch_geometric/datasets/modelnet.py,sha256=-qmLjlQiKVWmtHefAIIE97dQxEcaBfetMJnvgYZuwkg,5347
-torch_geometric/datasets/molecule_gpt_dataset.py,sha256=gVZv14PuZCanE4oxxHlqRNrvzGv6_KN318q5yFA3lS0,18797
+torch_geometric/datasets/molecule_gpt_dataset.py,sha256=TFBduE3_3xxTFSHL3tirV-OAlBjSi6iHPOHJGQ_-tug,18785
 torch_geometric/datasets/molecule_net.py,sha256=pMzaJzd-LbBncZ0VoC87HfA8d1F4NwCWTb5YKvLM890,7404
 torch_geometric/datasets/movie_lens.py,sha256=M4Bu0Xus8IkW8GYzjxPxSdPXNbcCCx9cu6cncxBvLx8,4033
 torch_geometric/datasets/movie_lens_100k.py,sha256=eTpBAteM3jqTEtiwLxmhVj4r8JvftvPx8Hvs-3ZIHlU,6057
@@ -335,8 +335,9 @@ torch_geometric/nn/aggr/set_transformer.py,sha256=FG7_JizpFX14M6VSCwLSjYXYdJ1ZiQ
 torch_geometric/nn/aggr/sort.py,sha256=bvOOWnFkNOBOZih4rqVZQsjfeDX3vmXo1bpPSFD846w,2507
 torch_geometric/nn/aggr/utils.py,sha256=SQvdc0g6p_E2j0prA14MW2ekjEDvV-g545N0Q85uc-o,8625
 torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nqoyv8tZ6Y35qMYTf9w,1126
-torch_geometric/nn/attention/__init__.py,sha256=wLKTmlfP7qL9sZHy4cmDFHEtdwa-MEKE1dT51L1_w10,192
+torch_geometric/nn/attention/__init__.py,sha256=w-jDQFpVqARJKjttTgKkD9kkAqRJl4MpASCfiNYIfr0,263
 torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
+torch_geometric/nn/attention/polynormer.py,sha256=uBxGs0nldp6oGlByqbxgEk23VeXLEd6B3myS5BOKDRs,3998
 torch_geometric/nn/attention/qformer.py,sha256=7J-pWm_vpumK38IC-iCBz4oqL-BEIofEIxJ0wfjWq9A,2338
 torch_geometric/nn/attention/sgformer.py,sha256=OBC5HQxbY289bPDtwN8UbPH46To2GRTeVN-najogD-o,3747
 torch_geometric/nn/conv/__init__.py,sha256=8CK-DFG2PEo2ZaFyg-IUlQH8ecQoDDi556uv3ugeQyc,3572
@@ -431,7 +432,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
 torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
 torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
 torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
-torch_geometric/nn/models/__init__.py,sha256=fbHQauZw9Snvl2PuN5cjZoAW8SwUl6E-p2IOmwUKB3A,2395
+torch_geometric/nn/models/__init__.py,sha256=71Hqc-ZMfCKn9lelFYDjpHXapbEa0wqVAd2OXCb1y5o,2448
 torch_geometric/nn/models/attentive_fp.py,sha256=1z3iTV2O5W9tqHFAdno8FeBFeXmuG-TDZk4lwwVh3Ac,6634
 torch_geometric/nn/models/attract_repel.py,sha256=h9OyogT0NY0xiT0DkpJHMxH6ZUmo8R-CmwZdKEwq8Ek,5277
 torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
@@ -461,6 +462,7 @@ torch_geometric/nn/models/molecule_gpt.py,sha256=k-XULH6jaurj-R2EE4sIWTkqlNqa3Cz
 torch_geometric/nn/models/neural_fingerprint.py,sha256=pTLJgU9Uh2Lnf9bggLj4cKI8YdEFcMF-9MALuubqbuQ,2378
 torch_geometric/nn/models/node2vec.py,sha256=81Ku4Rp4IwLEAy06KEgJ2fYtXXVL_uv_Hb8lBr6YXrE,7664
 torch_geometric/nn/models/pmlp.py,sha256=dcAASVSyQMMhItSfEJWPeAFh0R3tNCwAHwdrShwQ8o4,3538
+torch_geometric/nn/models/polynormer.py,sha256=mayWdzdolT5PCt_Oo7UGG-JUripMHWB2lUWF1bh6goU,7640
 torch_geometric/nn/models/protein_mpnn.py,sha256=QXHfltiJPmakpzgJKw_1vwCGBlszv9nfY4r4F38Sg9k,11031
 torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3WAuvITk,8986
 torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
@@ -643,7 +645,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
 torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.dev20250703.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
-pyg_nightly-2.7.0.dev20250703.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
-pyg_nightly-2.7.0.dev20250703.dist-info/METADATA,sha256=lm-xVNswGkfSaczkhV_ANegM0oGwBEfhCv75gwj0X5Q,63005
-pyg_nightly-2.7.0.dev20250703.dist-info/RECORD,,
+pyg_nightly-2.7.0.dev20250704.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+pyg_nightly-2.7.0.dev20250704.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+pyg_nightly-2.7.0.dev20250704.dist-info/METADATA,sha256=Nau44bIMI13OXEqYNOlI1hYfDz8FXpUoPydv-JxRW2Q,63005
+pyg_nightly-2.7.0.dev20250704.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.7.0.dev20250703'
+__version__ = '2.7.0.dev20250704'
 
 __all__ = [
     'Index',
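
Aside from the metadata and hash updates above, the change to the package root is just the nightly version bump; a quick sanity check after installing the newer wheel (sketch):

    import torch_geometric

    # The nightly date is carried in the version string.
    print(torch_geometric.__version__)  # '2.7.0.dev20250704'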
@@ -438,7 +438,7 @@ class MoleculeGPTDataset(InMemoryDataset):
         for mol in tqdm(suppl):
             if mol.HasProp('PUBCHEM_COMPOUND_CID'):
                 CID = mol.GetProp("PUBCHEM_COMPOUND_CID")
-                CAN_SMILES = mol.GetProp("PUBCHEM_OPENEYE_CAN_SMILES")
+                CAN_SMILES = mol.GetProp("PUBCHEM_SMILES")
 
                 m: Chem.Mol = Chem.MolFromSmiles(CAN_SMILES)
                 if m is None:
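
The dataset change swaps the SDF property it reads, which appears to track PubChem replacing the OpenEye-derived `PUBCHEM_OPENEYE_CAN_SMILES` field with `PUBCHEM_SMILES`. A minimal sketch of reading that property with RDKit, independent of the dataset class (the file path is a placeholder):

    from rdkit import Chem

    # Placeholder path; any PubChem SDF dump is read the same way.
    suppl = Chem.SDMolSupplier('pubchem_compounds.sdf')
    for mol in suppl:
        # Entries RDKit fails to parse come back as None.
        if mol is not None and mol.HasProp('PUBCHEM_SMILES'):
            smiles = mol.GetProp('PUBCHEM_SMILES')
            print(mol.GetProp('PUBCHEM_COMPOUND_CID'), smiles)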
@@ -1,9 +1,11 @@
 from .performer import PerformerAttention
 from .qformer import QFormer
 from .sgformer import SGFormerAttention
+from .polynormer import PolynormerAttention
 
 __all__ = [
     'PerformerAttention',
     'QFormer',
     'SGFormerAttention',
+    'PolynormerAttention',
 ]
@@ -0,0 +1,107 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+class PolynormerAttention(torch.nn.Module):
+    r"""The polynomial-expressive attention mechanism from the
+    `"Polynormer: Polynomial-Expressive Graph Transformer in Linear Time"
+    <https://arxiv.org/abs/2403.01232>`_ paper.
+
+    Args:
+        channels (int): Size of each input sample.
+        heads (int): Number of parallel attention heads.
+        head_channels (int, optional): Size of each attention head.
+            (default: :obj:`64`)
+        beta (float, optional): Polynormer beta initialization.
+            (default: :obj:`0.9`)
+        qkv_bias (bool, optional): If specified, adds bias to the query, key
+            and value projections in the self-attention.
+            (default: :obj:`False`)
+        qk_shared (bool, optional): Whether the weights of query and key are
+            shared. (default: :obj:`True`)
+        dropout (float, optional): Dropout probability of the final
+            attention output. (default: :obj:`0.0`)
+    """
+    def __init__(
+        self,
+        channels: int,
+        heads: int,
+        head_channels: int = 64,
+        beta: float = 0.9,
+        qkv_bias: bool = False,
+        qk_shared: bool = True,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.head_channels = head_channels
+        self.heads = heads
+        self.beta = beta
+        self.qk_shared = qk_shared
+
+        inner_channels = heads * head_channels
+        self.h_lins = torch.nn.Linear(channels, inner_channels)
+        if not self.qk_shared:
+            self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.lns = torch.nn.LayerNorm(inner_channels)
+        self.lin_out = torch.nn.Linear(inner_channels, inner_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+
+    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): Node feature tensor
+                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
+                each graph, and feature dimension :math:`F`.
+            mask (torch.Tensor, optional): Mask matrix
+                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
+                the valid nodes for each graph. (default: :obj:`None`)
+        """
+        B, N, *_ = x.shape
+        h = self.h_lins(x)
+        k = self.k(x).sigmoid().view(B, N, self.head_channels, self.heads)
+        if self.qk_shared:
+            q = k
+        else:
+            q = F.sigmoid(self.q(x)).view(B, N, self.head_channels,
+                                          self.heads)
+        v = self.v(x).view(B, N, self.head_channels, self.heads)
+
+        if mask is not None:
+            mask = mask[:, :, None, None]
+            v.masked_fill_(~mask, 0.)
+
+        # numerator
+        kv = torch.einsum('bndh, bnmh -> bdmh', k, v)
+        num = torch.einsum('bndh, bdmh -> bnmh', q, kv)
+
+        # denominator
+        k_sum = torch.einsum('bndh -> bdh', k)
+        den = torch.einsum('bndh, bdh -> bnh', q, k_sum).unsqueeze(2)
+
+        # linear global attention based on kernel trick
+        x = (num / (den + 1e-6)).reshape(B, N, -1)
+        x = self.lns(x) * (h + self.beta)
+        x = F.relu(self.lin_out(x))
+        x = self.dropout(x)
+
+        return x
+
+    def reset_parameters(self) -> None:
+        self.h_lins.reset_parameters()
+        if not self.qk_shared:
+            self.q.reset_parameters()
+        self.k.reset_parameters()
+        self.v.reset_parameters()
+        self.lns.reset_parameters()
+        self.lin_out.reset_parameters()
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'heads={self.heads}, '
+                f'head_channels={self.head_channels})')
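
Since the diff shows the new attention module only in isolation, here is a minimal usage sketch on dense batched node features; the tensor shapes and values are illustrative assumptions, not part of the diff:

    import torch

    from torch_geometric.nn.attention import PolynormerAttention

    # Illustrative shapes: a batch of 2 graphs padded to 10 nodes, 32 features.
    attn = PolynormerAttention(channels=32, heads=4, head_channels=16)
    x = torch.randn(2, 10, 32)                  # (B, N, F)
    mask = torch.ones(2, 10, dtype=torch.bool)  # all nodes valid
    out = attn(x, mask)
    print(out.shape)  # torch.Size([2, 10, 64]) = (B, N, heads * head_channels)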
@@ -35,6 +35,7 @@ from .molecule_gpt import MoleculeGPT
 from .protein_mpnn import ProteinMPNN
 from .glem import GLEM
 from .sgformer import SGFormer
+from .polynormer import Polynormer
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
                                                       captum_output_to_dicts)
@@ -90,5 +91,6 @@ __all__ = classes = [
     'ProteinMPNN',
     'GLEM',
     'SGFormer',
+    'Polynormer',
     'ARLinkPredictor',
 ]
@@ -0,0 +1,206 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from torch_geometric.nn import GATConv, GCNConv
+from torch_geometric.nn.attention import PolynormerAttention
+from torch_geometric.utils import to_dense_batch
+
+
+class Polynormer(torch.nn.Module):
+    r"""The Polynormer model from the
+    `"Polynormer: Polynomial-Expressive Graph Transformer in Linear Time"
+    <https://arxiv.org/abs/2403.01232>`_ paper.
+
+    Args:
+        in_channels (int): Input channels.
+        hidden_channels (int): Hidden channels.
+        out_channels (int): Output channels.
+        local_layers (int, optional): The number of local attention layers.
+            (default: :obj:`7`)
+        global_layers (int, optional): The number of global attention layers.
+            (default: :obj:`2`)
+        in_dropout (float, optional): Input dropout rate.
+            (default: :obj:`0.15`)
+        dropout (float, optional): Dropout rate.
+            (default: :obj:`0.5`)
+        global_dropout (float, optional): Global dropout rate.
+            (default: :obj:`0.5`)
+        heads (int, optional): The number of attention heads.
+            (default: :obj:`1`)
+        beta (float, optional): Polynormer beta initialization.
+            (default: :obj:`0.9`)
+        qk_shared (bool, optional): Whether the weights of query and key are
+            shared. (default: :obj:`False`)
+        pre_ln (bool, optional): Whether to apply pre layer normalization.
+            (default: :obj:`False`)
+        post_bn (bool, optional): Whether to apply post batch normalization.
+            (default: :obj:`True`)
+        local_attn (bool, optional): Whether to use local attention.
+            (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        local_layers: int = 7,
+        global_layers: int = 2,
+        in_dropout: float = 0.15,
+        dropout: float = 0.5,
+        global_dropout: float = 0.5,
+        heads: int = 1,
+        beta: float = 0.9,
+        qk_shared: bool = False,
+        pre_ln: bool = False,
+        post_bn: bool = True,
+        local_attn: bool = False,
+    ) -> None:
+        super().__init__()
+        self._global = False
+        self.in_drop = in_dropout
+        self.dropout = dropout
+        self.pre_ln = pre_ln
+        self.post_bn = post_bn
+
+        self.beta = beta
+
+        self.h_lins = torch.nn.ModuleList()
+        self.local_convs = torch.nn.ModuleList()
+        self.lins = torch.nn.ModuleList()
+        self.lns = torch.nn.ModuleList()
+        if self.pre_ln:
+            self.pre_lns = torch.nn.ModuleList()
+        if self.post_bn:
+            self.post_bns = torch.nn.ModuleList()
+
+        # first layer
+        inner_channels = heads * hidden_channels
+        self.h_lins.append(torch.nn.Linear(in_channels, inner_channels))
+        if local_attn:
+            self.local_convs.append(
+                GATConv(in_channels, hidden_channels, heads=heads,
+                        concat=True, add_self_loops=False, bias=False))
+        else:
+            self.local_convs.append(
+                GCNConv(in_channels, inner_channels, cached=False,
+                        normalize=True))
+
+        self.lins.append(torch.nn.Linear(in_channels, inner_channels))
+        self.lns.append(torch.nn.LayerNorm(inner_channels))
+        if self.pre_ln:
+            self.pre_lns.append(torch.nn.LayerNorm(in_channels))
+        if self.post_bn:
+            self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))
+
+        # following layers
+        for _ in range(local_layers - 1):
+            self.h_lins.append(torch.nn.Linear(inner_channels,
+                                               inner_channels))
+            if local_attn:
+                self.local_convs.append(
+                    GATConv(inner_channels, hidden_channels, heads=heads,
+                            concat=True, add_self_loops=False, bias=False))
+            else:
+                self.local_convs.append(
+                    GCNConv(inner_channels, inner_channels, cached=False,
+                            normalize=True))
+
+            self.lins.append(torch.nn.Linear(inner_channels, inner_channels))
+            self.lns.append(torch.nn.LayerNorm(inner_channels))
+            if self.pre_ln:
+                self.pre_lns.append(torch.nn.LayerNorm(heads *
+                                                       hidden_channels))
+            if self.post_bn:
+                self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))
+
+        self.lin_in = torch.nn.Linear(in_channels, inner_channels)
+        self.ln = torch.nn.LayerNorm(inner_channels)
+
+        self.global_attn = torch.nn.ModuleList()
+        for _ in range(global_layers):
+            self.global_attn.append(
+                PolynormerAttention(
+                    channels=hidden_channels,
+                    heads=heads,
+                    head_channels=hidden_channels,
+                    beta=beta,
+                    dropout=global_dropout,
+                    qk_shared=qk_shared,
+                ))
+        self.pred_local = torch.nn.Linear(inner_channels, out_channels)
+        self.pred_global = torch.nn.Linear(inner_channels, out_channels)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        for local_conv in self.local_convs:
+            local_conv.reset_parameters()
+        for attn in self.global_attn:
+            attn.reset_parameters()
+        for lin in self.lins:
+            lin.reset_parameters()
+        for h_lin in self.h_lins:
+            h_lin.reset_parameters()
+        for ln in self.lns:
+            ln.reset_parameters()
+        if self.pre_ln:
+            for p_ln in self.pre_lns:
+                p_ln.reset_parameters()
+        if self.post_bn:
+            for p_bn in self.post_bns:
+                p_bn.reset_parameters()
+        self.lin_in.reset_parameters()
+        self.ln.reset_parameters()
+        self.pred_local.reset_parameters()
+        self.pred_global.reset_parameters()
+
+    def forward(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Optional[Tensor],
+    ) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): The input node features.
+            edge_index (torch.Tensor or SparseTensor): The edge indices.
+            batch (torch.Tensor, optional): The batch vector
+                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                each element to a specific example.
+        """
+        x = F.dropout(x, p=self.in_drop, training=self.training)
+
+        # equivariant local attention
+        x_local = 0
+        for i, local_conv in enumerate(self.local_convs):
+            if self.pre_ln:
+                x = self.pre_lns[i](x)
+            h = self.h_lins[i](x)
+            h = F.relu(h)
+            x = local_conv(x, edge_index) + self.lins[i](x)
+            if self.post_bn:
+                x = self.post_bns[i](x)
+            x = F.relu(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = (1 - self.beta) * self.lns[i](h * x) + self.beta * x
+            x_local = x_local + x
+
+        # equivariant global attention
+        if self._global:
+            batch, indices = batch.sort()
+            rev_perm = torch.empty_like(indices)
+            rev_perm[indices] = torch.arange(len(indices),
+                                             device=indices.device)
+            x_local = self.ln(x_local[indices])
+            x_global, mask = to_dense_batch(x_local, batch)
+            for attn in self.global_attn:
+                x_global = attn(x_global, mask)
+            x = x_global[mask][rev_perm]
+            x = self.pred_global(x)
+        else:
+            x = self.pred_local(x_local)
+
+        return F.log_softmax(x, dim=-1)
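
For context, the model runs in local-only mode until the private `_global` flag is flipped, matching the paper's local-then-global training schedule. A minimal end-to-end sketch on a toy graph; the graph, feature sizes, and the manual `_global` toggle are illustrative assumptions, and `heads` is left at its default of 1:

    import torch

    from torch_geometric.nn.models import Polynormer

    # Illustrative toy graph: 6 nodes on a ring, 16 input features, 3 classes.
    model = Polynormer(in_channels=16, hidden_channels=32, out_channels=3,
                       local_layers=2, global_layers=1)
    model.eval()  # freeze batch norm / dropout for a deterministic sketch

    x = torch.randn(6, 16)
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 3, 4, 5, 0]])
    batch = torch.zeros(6, dtype=torch.long)  # a single graph

    out = model(x, edge_index, batch)  # local-only: (6, 3) log-probabilities
    model._global = True               # enable the global attention branch
    out = model(x, edge_index, batch)  # local + global: same output shape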