pyg-nightly 2.7.0.dev20250228__py3-none-any.whl → 2.7.0.dev20250301__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- {pyg_nightly-2.7.0.dev20250228.dist-info → pyg_nightly-2.7.0.dev20250301.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250228.dist-info → pyg_nightly-2.7.0.dev20250301.dist-info}/RECORD +7 -7
- torch_geometric/__init__.py +1 -1
- torch_geometric/nn/attention/sgformer.py +41 -19
- torch_geometric/nn/models/sgformer.py +30 -5
- {pyg_nightly-2.7.0.dev20250228.dist-info → pyg_nightly-2.7.0.dev20250301.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250228.dist-info → pyg_nightly-2.7.0.dev20250301.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250228.dist-info → pyg_nightly-2.7.0.dev20250301.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyg-nightly
-Version: 2.7.0.dev20250228
+Version: 2.7.0.dev20250301
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
{pyg_nightly-2.7.0.dev20250228.dist-info → pyg_nightly-2.7.0.dev20250301.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=
+torch_geometric/__init__.py,sha256=dOxOnSZ6b_-SJ2LAthheeG9ubXVjh8-19erKucOrbf4,1978
 torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -334,7 +334,7 @@ torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nq
 torch_geometric/nn/attention/__init__.py,sha256=wLKTmlfP7qL9sZHy4cmDFHEtdwa-MEKE1dT51L1_w10,192
 torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
 torch_geometric/nn/attention/qformer.py,sha256=7J-pWm_vpumK38IC-iCBz4oqL-BEIofEIxJ0wfjWq9A,2338
-torch_geometric/nn/attention/sgformer.py,sha256=
+torch_geometric/nn/attention/sgformer.py,sha256=U4R_tGF1IWyKDlV0t4jFNatW2-payc-hPGM3sFdiYBE,3812
 torch_geometric/nn/conv/__init__.py,sha256=37zTdt0gfSAUPMtwXjZg5mWx_itojJVFNODYR1h1ch0,3515
 torch_geometric/nn/conv/agnn_conv.py,sha256=5nEPLx_BBHcDaO6HWzLuHfXc0Yd_reKynAOH0Iq09lU,3077
 torch_geometric/nn/conv/antisymmetric_conv.py,sha256=dhA6sCETy1jlXReYJZBSyToOcL_mZ1wL10fMIb8Ppuw,4387
@@ -458,7 +458,7 @@ torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3W
 torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
 torch_geometric/nn/models/rev_gnn.py,sha256=1b6wU-6YTuLsWn5p8c5LXQm2KugEAVcEYJKZbWTDvgQ,11796
 torch_geometric/nn/models/schnet.py,sha256=0aaHrVtxApdvn3RHCGLQJW1MbIb--CSYUrx9O3hDOZM,16656
-torch_geometric/nn/models/sgformer.py,sha256=
+torch_geometric/nn/models/sgformer.py,sha256=V-F3J-yEdjuhkSsCkQePG_ByNHAu0BcKpknVxjCI3KY,6761
 torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6lASrKwUEAI,9841
 torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
 torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
@@ -633,7 +633,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
 torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
+pyg_nightly-2.7.0.dev20250301.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+pyg_nightly-2.7.0.dev20250301.dist-info/WHEEL,sha256=_2ozNFCLWc93bK4WKHCO-eDUENDlo-dgc9cU3qokYO4,82
+pyg_nightly-2.7.0.dev20250301.dist-info/METADATA,sha256=MyWSa_ap4oMCz_T8PQTHTMvOnGJZtUF8T4U4XJo5Q38,63021
+pyg_nightly-2.7.0.dev20250301.dist-info/RECORD,,
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.7.0.dev20250228'
+__version__ = '2.7.0.dev20250301'
 
 __all__ = [
     'Index',
torch_geometric/nn/attention/sgformer.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from torch import Tensor
 
@@ -38,34 +40,54 @@ class SGFormerAttention(torch.nn.Module):
         self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
         self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
 
-    def forward(self, x: Tensor) -> Tensor:
-
-
-
-
+    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): Node feature tensor
+                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
+                each graph, and feature dimension :math:`F`.
+            mask (torch.Tensor, optional): Mask matrix
+                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
+                the valid nodes for each graph. (default: :obj:`None`)
+        """
+        B, N, *_ = x.shape
+        qs, ks, vs = self.q(x), self.k(x), self.v(x)
+        # reshape and permute q, k and v to proper shape
+        # (b, n, num_heads * head_channels) to (b, n, num_heads, head_channels)
+        qs, ks, vs = map(
+            lambda t: t.reshape(B, N, self.heads, self.head_channels),
+            (qs, ks, vs))
 
-
-
-
-
+        if mask is not None:
+            mask = mask[:, :, None, None]
+            vs.masked_fill_(~mask, 0.)
+        # replace 0's with epsilon
+        epsilon = 1e-6
+        qs[qs == 0] = epsilon
+        ks[ks == 0] = epsilon
+        # normalize input, shape not changed
+        qs, ks = map(
+            lambda t: t / torch.linalg.norm(t, ord=2, dim=-1, keepdim=True),
+            (qs, ks))
 
         # numerator
-        kvs = torch.einsum("
-        attention_num = torch.einsum("
+        kvs = torch.einsum("blhm,blhd->bhmd", ks, vs)
+        attention_num = torch.einsum("bnhm,bhmd->bnhd", qs, kvs)
         attention_num += N * vs
 
         # denominator
-        all_ones = torch.ones([
-        ks_sum = torch.einsum("
-        attention_normalizer = torch.einsum("
-
+        all_ones = torch.ones([B, N]).to(ks.device)
+        ks_sum = torch.einsum("blhm,bl->bhm", ks, all_ones)
+        attention_normalizer = torch.einsum("bnhm,bhm->bnh", qs, ks_sum)
         # attentive aggregated results
-        attention_normalizer = torch.unsqueeze(
-
+        attention_normalizer = torch.unsqueeze(attention_normalizer,
+                                               len(attention_normalizer.shape))
         attention_normalizer += torch.ones_like(attention_normalizer) * N
-        attn_output = attention_num / attention_normalizer
+        attn_output = attention_num / attention_normalizer
 
-        return attn_output.mean(dim=
+        return attn_output.mean(dim=2)
 
     def reset_parameters(self):
         self.q.reset_parameters()
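Note: the truncated "-" lines above hide the old single-graph einsum strings, so here is a minimal standalone sketch of the new batched linear-attention math, using the same contractions as the diff. All tensor names and sizes below are illustrative, not part of the package API:

import torch

B, N, H, M, D = 2, 5, 4, 16, 16  # batch, nodes, heads, key dim, value dim

qs = torch.rand(B, N, H, M)  # queries
ks = torch.rand(B, N, H, M)  # keys
vs = torch.rand(B, N, H, D)  # values

# Zero out padded nodes of the second graph, mirroring
# `vs.masked_fill_(~mask, 0.)` in the diff.
mask = torch.ones(B, N, dtype=torch.bool)
mask[1, 3:] = False
vs = vs.masked_fill(~mask[:, :, None, None], 0.)

# L2-normalize queries and keys along the feature dim, as in the diff.
qs = qs / torch.linalg.norm(qs, ord=2, dim=-1, keepdim=True)
ks = ks / torch.linalg.norm(ks, ord=2, dim=-1, keepdim=True)

# Numerator: contracting keys with values first keeps the cost linear
# in N (no N x N attention matrix is ever formed).
kvs = torch.einsum("blhm,blhd->bhmd", ks, vs)
attention_num = torch.einsum("bnhm,bhmd->bnhd", qs, kvs) + N * vs

# Denominator: per-query normalizer built from the summed keys
# (equivalent to the all-ones einsum in the diff).
ks_sum = ks.sum(dim=1)                                      # [B, H, M]
attention_normalizer = torch.einsum("bnhm,bhm->bnh", qs, ks_sum)
attention_normalizer = attention_normalizer[..., None] + N  # [B, N, H, 1]

out = (attention_num / attention_normalizer).mean(dim=2)    # [B, N, D]
print(out.shape)  # torch.Size([2, 5, 16])

The diff additionally replaces exact zeros in qs/ks with 1e-6 before normalizing; that step is omitted here since torch.rand outputs are almost surely nonzero.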
torch_geometric/nn/models/sgformer.py
CHANGED
@@ -1,8 +1,12 @@
+from typing import Optional
+
 import torch
 import torch.nn.functional as F
+from torch import Tensor
 
 from torch_geometric.nn.attention import SGFormerAttention
 from torch_geometric.nn.conv import GCNConv
+from torch_geometric.utils import to_dense_batch
 
 
 class GraphModule(torch.nn.Module):
@@ -84,7 +88,11 @@ class SGModule(torch.nn.Module):
         for fc in self.fcs:
             fc.reset_parameters()
 
-    def forward(self, x):
+    def forward(self, x: Tensor, batch: Tensor):
+        # to dense batch expects sorted batch
+        batch, indices = batch.sort(stable=True)
+        x = x[indices]
+        x, mask = to_dense_batch(x, batch)
         layer_ = []
 
         # input MLP layer
@@ -97,14 +105,17 @@ class SGModule(torch.nn.Module):
         layer_.append(x)
 
         for i, attn in enumerate(self.attns):
-            x = attn(x)
+            x = attn(x, mask)
             x = (x + layer_[i]) / 2.
             x = self.bns[i + 1](x)
             x = self.activation(x)
             x = F.dropout(x, p=self.dropout, training=self.training)
             layer_.append(x)
 
-
+        x_mask = x[mask]
+        # reverse the sorting
+        unsorted_x_mask = x_mask[indices.argsort()]
+        return unsorted_x_mask
 
 
 class SGFormer(torch.nn.Module):
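Since to_dense_batch requires a sorted batch vector, the new SGModule.forward sorts stably, densifies, and undoes the permutation at the end. A small round-trip sketch of that pattern (the `batch` values are hypothetical):

import torch

from torch_geometric.utils import to_dense_batch

x = torch.arange(5, dtype=torch.float).view(5, 1)  # 5 nodes, 1 feature
batch = torch.tensor([1, 0, 1, 0, 1])              # unsorted assignment

# `to_dense_batch` expects `batch` to be sorted, hence the stable sort.
batch_sorted, indices = batch.sort(stable=True)
x_sorted = x[indices]

dense_x, mask = to_dense_batch(x_sorted, batch_sorted)
# dense_x: [B=2, N_max=3, 1]; `mask` marks the real (non-padded) slots.

out = dense_x[mask]           # flat [5, 1] tensor, still in sorted order
out = out[indices.argsort()]  # restore the original node order
assert torch.equal(out, x)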
@@ -179,8 +190,22 @@ class SGFormer(torch.nn.Module):
         self.graph_conv.reset_parameters()
         self.fc.reset_parameters()
 
-    def forward(
-
+    def forward(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Optional[Tensor],
+    ) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): The input node features.
+            edge_index (torch.Tensor or SparseTensor): The edge indices.
+            batch (torch.Tensor, optional): The batch vector
+                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                each element to a specific example.
+        """
+        x1 = self.trans_conv(x, batch)
         x2 = self.graph_conv(x, edge_index)
         if self.aggregate == 'add':
             x = self.graph_weight * x2 + (1 - self.graph_weight) * x1
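Putting it together, a hypothetical end-to-end call with the new `batch` argument. The constructor arguments below are assumptions for illustration; only the `forward(x, edge_index, batch)` signature is shown by this diff:

import torch

from torch_geometric.nn.models import SGFormer

# Channel sizes are assumed; check the class docstring for the full
# constructor signature.
model = SGFormer(in_channels=16, hidden_channels=32, out_channels=4)

x = torch.randn(6, 16)                    # 6 nodes, 16 features each
edge_index = torch.tensor([[0, 1, 3, 4],  # two small chain graphs
                           [1, 2, 4, 5]])
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # graph assignment per node

out = model(x, edge_index, batch)
print(out.shape)  # expected: per-node output, torch.Size([6, 4])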
{pyg_nightly-2.7.0.dev20250228.dist-info → pyg_nightly-2.7.0.dev20250301.dist-info}/WHEEL
RENAMED
File without changes
{pyg_nightly-2.7.0.dev20250228.dist-info → pyg_nightly-2.7.0.dev20250301.dist-info}/licenses/LICENSE
RENAMED
File without changes