pyg-nightly 2.7.0.dev20250228__py3-none-any.whl → 2.7.0.dev20250302__py3-none-any.whl

This diff shows the differences between the contents of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20250228
3
+ Version: 2.7.0.dev20250302
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=55I0x87kSkECyYdIRxH0kF2v6J9VGbG_99GlaLPMQXQ,1978
1
+ torch_geometric/__init__.py,sha256=-ySUkCuMvYC1iL9it4tdObi1jhTkJrpNJ86BUKiAxE0,1978
2
2
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -334,7 +334,7 @@ torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nq
334
334
  torch_geometric/nn/attention/__init__.py,sha256=wLKTmlfP7qL9sZHy4cmDFHEtdwa-MEKE1dT51L1_w10,192
335
335
  torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
336
336
  torch_geometric/nn/attention/qformer.py,sha256=7J-pWm_vpumK38IC-iCBz4oqL-BEIofEIxJ0wfjWq9A,2338
337
- torch_geometric/nn/attention/sgformer.py,sha256=QT_3bVOPsqPY0fWVdwyr8E0LL9u9TwC_sv0_2EggdpA,2893
337
+ torch_geometric/nn/attention/sgformer.py,sha256=U4R_tGF1IWyKDlV0t4jFNatW2-payc-hPGM3sFdiYBE,3812
338
338
  torch_geometric/nn/conv/__init__.py,sha256=37zTdt0gfSAUPMtwXjZg5mWx_itojJVFNODYR1h1ch0,3515
339
339
  torch_geometric/nn/conv/agnn_conv.py,sha256=5nEPLx_BBHcDaO6HWzLuHfXc0Yd_reKynAOH0Iq09lU,3077
340
340
  torch_geometric/nn/conv/antisymmetric_conv.py,sha256=dhA6sCETy1jlXReYJZBSyToOcL_mZ1wL10fMIb8Ppuw,4387
@@ -458,7 +458,7 @@ torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3W
458
458
  torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
459
459
  torch_geometric/nn/models/rev_gnn.py,sha256=1b6wU-6YTuLsWn5p8c5LXQm2KugEAVcEYJKZbWTDvgQ,11796
460
460
  torch_geometric/nn/models/schnet.py,sha256=0aaHrVtxApdvn3RHCGLQJW1MbIb--CSYUrx9O3hDOZM,16656
461
- torch_geometric/nn/models/sgformer.py,sha256=QacmFjlTSUgk6T5Dk4NfG9zmxHZqr0ggx1ElNrPMJNc,5878
461
+ torch_geometric/nn/models/sgformer.py,sha256=V-F3J-yEdjuhkSsCkQePG_ByNHAu0BcKpknVxjCI3KY,6761
462
462
  torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6lASrKwUEAI,9841
463
463
  torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
464
464
  torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
@@ -633,7 +633,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
633
633
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
634
634
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
635
635
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
636
- pyg_nightly-2.7.0.dev20250228.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
637
- pyg_nightly-2.7.0.dev20250228.dist-info/WHEEL,sha256=_2ozNFCLWc93bK4WKHCO-eDUENDlo-dgc9cU3qokYO4,82
638
- pyg_nightly-2.7.0.dev20250228.dist-info/METADATA,sha256=WqOwZeA_woFXbV3pQXk5CAIZeKr1txi5OpmJnrtXtww,63021
639
- pyg_nightly-2.7.0.dev20250228.dist-info/RECORD,,
636
+ pyg_nightly-2.7.0.dev20250302.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
637
+ pyg_nightly-2.7.0.dev20250302.dist-info/WHEEL,sha256=_2ozNFCLWc93bK4WKHCO-eDUENDlo-dgc9cU3qokYO4,82
638
+ pyg_nightly-2.7.0.dev20250302.dist-info/METADATA,sha256=sLyKGhdjP0OYneFdjiW5cyxTIb8wDONfzlPRcuwVrYw,63021
639
+ pyg_nightly-2.7.0.dev20250302.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
31
31
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
32
32
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
33
33
 
34
- __version__ = '2.7.0.dev20250228'
34
+ __version__ = '2.7.0.dev20250302'
35
35
 
36
36
  __all__ = [
37
37
  'Index',
@@ -1,3 +1,5 @@
1
+ from typing import Optional
2
+
1
3
  import torch
2
4
  from torch import Tensor
3
5
 
@@ -38,34 +40,54 @@ class SGFormerAttention(torch.nn.Module):
38
40
  self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
39
41
  self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
40
42
 
41
- def forward(self, x: Tensor) -> Tensor:
42
- # feature transformation
43
- qs = self.q(x).reshape(-1, self.heads, self.head_channels)
44
- ks = self.k(x).reshape(-1, self.heads, self.head_channels)
45
- vs = self.v(x).reshape(-1, self.heads, self.head_channels)
43
+ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
44
+ r"""Forward pass.
45
+
46
+ Args:
47
+ x (torch.Tensor): Node feature tensor
48
+ :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
49
+ batch-size :math:`B`, (maximum) number of nodes :math:`N` for
50
+ each graph, and feature dimension :math:`F`.
51
+ mask (torch.Tensor, optional): Mask matrix
52
+ :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
53
+ the valid nodes for each graph. (default: :obj:`None`)
54
+ """
55
+ B, N, *_ = x.shape
56
+ qs, ks, vs = self.q(x), self.k(x), self.v(x)
57
+ # reshape and permute q, k and v to proper shape
58
+ # (b, n, num_heads * head_channels) to (b, n, num_heads, head_channels)
59
+ qs, ks, vs = map(
60
+ lambda t: t.reshape(B, N, self.heads, self.head_channels),
61
+ (qs, ks, vs))
46
62
 
47
- # normalize input
48
- qs = qs / torch.norm(qs, p=2) # [N, H, M]
49
- ks = ks / torch.norm(ks, p=2) # [L, H, M]
50
- N = qs.shape[0]
63
+ if mask is not None:
64
+ mask = mask[:, :, None, None]
65
+ vs.masked_fill_(~mask, 0.)
66
+ # replace 0's with epsilon
67
+ epsilon = 1e-6
68
+ qs[qs == 0] = epsilon
69
+ ks[ks == 0] = epsilon
70
+ # normalize input, shape not changed
71
+ qs, ks = map(
72
+ lambda t: t / torch.linalg.norm(t, ord=2, dim=-1, keepdim=True),
73
+ (qs, ks))
51
74
 
52
75
  # numerator
53
- kvs = torch.einsum("lhm,lhd->hmd", ks, vs)
54
- attention_num = torch.einsum("nhm,hmd->nhd", qs, kvs) # [N, H, D]
76
+ kvs = torch.einsum("blhm,blhd->bhmd", ks, vs)
77
+ attention_num = torch.einsum("bnhm,bhmd->bnhd", qs, kvs)
55
78
  attention_num += N * vs
56
79
 
57
80
  # denominator
58
- all_ones = torch.ones([ks.shape[0]]).to(ks.device)
59
- ks_sum = torch.einsum("lhm,l->hm", ks, all_ones)
60
- attention_normalizer = torch.einsum("nhm,hm->nh", qs, ks_sum) # [N, H]
61
-
81
+ all_ones = torch.ones([B, N]).to(ks.device)
82
+ ks_sum = torch.einsum("blhm,bl->bhm", ks, all_ones)
83
+ attention_normalizer = torch.einsum("bnhm,bhm->bnh", qs, ks_sum)
62
84
  # attentive aggregated results
63
- attention_normalizer = torch.unsqueeze(
64
- attention_normalizer, len(attention_normalizer.shape)) # [N, H, 1]
85
+ attention_normalizer = torch.unsqueeze(attention_normalizer,
86
+ len(attention_normalizer.shape))
65
87
  attention_normalizer += torch.ones_like(attention_normalizer) * N
66
- attn_output = attention_num / attention_normalizer # [N, H, D]
88
+ attn_output = attention_num / attention_normalizer
67
89
 
68
- return attn_output.mean(dim=1)
90
+ return attn_output.mean(dim=2)
69
91
 
70
92
  def reset_parameters(self):
71
93
  self.q.reset_parameters()
@@ -1,8 +1,12 @@
1
+ from typing import Optional
2
+
1
3
  import torch
2
4
  import torch.nn.functional as F
5
+ from torch import Tensor
3
6
 
4
7
  from torch_geometric.nn.attention import SGFormerAttention
5
8
  from torch_geometric.nn.conv import GCNConv
9
+ from torch_geometric.utils import to_dense_batch
6
10
 
7
11
 
8
12
  class GraphModule(torch.nn.Module):
@@ -84,7 +88,11 @@ class SGModule(torch.nn.Module):
84
88
  for fc in self.fcs:
85
89
  fc.reset_parameters()
86
90
 
87
- def forward(self, x):
91
+ def forward(self, x: Tensor, batch: Tensor):
92
+ # to dense batch expects sorted batch
93
+ batch, indices = batch.sort(stable=True)
94
+ x = x[indices]
95
+ x, mask = to_dense_batch(x, batch)
88
96
  layer_ = []
89
97
 
90
98
  # input MLP layer
@@ -97,14 +105,17 @@ class SGModule(torch.nn.Module):
97
105
  layer_.append(x)
98
106
 
99
107
  for i, attn in enumerate(self.attns):
100
- x = attn(x)
108
+ x = attn(x, mask)
101
109
  x = (x + layer_[i]) / 2.
102
110
  x = self.bns[i + 1](x)
103
111
  x = self.activation(x)
104
112
  x = F.dropout(x, p=self.dropout, training=self.training)
105
113
  layer_.append(x)
106
114
 
107
- return x
115
+ x_mask = x[mask]
116
+ # reverse the sorting
117
+ unsorted_x_mask = x_mask[indices.argsort()]
118
+ return unsorted_x_mask
108
119
 
109
120
 
110
121
  class SGFormer(torch.nn.Module):
@@ -179,8 +190,22 @@ class SGFormer(torch.nn.Module):
179
190
  self.graph_conv.reset_parameters()
180
191
  self.fc.reset_parameters()
181
192
 
182
- def forward(self, x, edge_index):
183
- x1 = self.trans_conv(x)
193
+ def forward(
194
+ self,
195
+ x: Tensor,
196
+ edge_index: Tensor,
197
+ batch: Optional[Tensor],
198
+ ) -> Tensor:
199
+ r"""Forward pass.
200
+
201
+ Args:
202
+ x (torch.Tensor): The input node features.
203
+ edge_index (torch.Tensor or SparseTensor): The edge indices.
204
+ batch (torch.Tensor, optional): The batch vector
205
+ :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
206
+ each element to a specific example.
207
+ """
208
+ x1 = self.trans_conv(x, batch)
184
209
  x2 = self.graph_conv(x, edge_index)
185
210
  if self.aggregate == 'add':
186
211
  x = self.graph_weight * x2 + (1 - self.graph_weight) * x1