pyg-nightly 2.7.0.dev20250701__py3-none-any.whl → 2.7.0.dev20250703__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/RECORD +25 -22
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/__init__.py +4 -0
- torch_geometric/datasets/git_mol_dataset.py +1 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/conv/meshcnn_conv.py +9 -15
- torch_geometric/nn/encoding.py +12 -3
- torch_geometric/nn/models/__init__.py +2 -0
- torch_geometric/nn/models/glem.py +7 -3
- torch_geometric/nn/models/protein_mpnn.py +304 -0
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +12 -4
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/utils/convert.py +15 -8
- torch_geometric/utils/smiles.py +1 -1
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/licenses/LICENSE +0 -0
torch_geometric/nn/models/protein_mpnn.py
ADDED

@@ -0,0 +1,304 @@
+from itertools import product
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+
+from torch_geometric.nn import knn_graph
+from torch_geometric.nn.conv import MessagePassing
+from torch_geometric.utils import to_dense_adj, to_dense_batch
+
+
+class PositionWiseFeedForward(torch.nn.Module):
+    def __init__(self, in_channels: int, hidden_channels: int) -> None:
+        super().__init__()
+        self.out = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, in_channels),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.out(x)
+
+
+class PositionalEncoding(torch.nn.Module):
+    def __init__(self, hidden_channels: int,
+                 max_relative_feature: int = 32) -> None:
+        super().__init__()
+        self.max_relative_feature = max_relative_feature
+        self.emb = torch.nn.Embedding(2 * max_relative_feature + 2,
+                                      hidden_channels)
+
+    def forward(self, offset, mask) -> torch.Tensor:
+        d = torch.clip(offset + self.max_relative_feature, 0,
+                       2 * self.max_relative_feature) * mask + (1 - mask) * (
+                           2 * self.max_relative_feature + 1)  # noqa: E501
+        return self.emb(d.long())
+
+
+class Encoder(MessagePassing):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        dropout: float = 0.1,
+        scale: float = 30,
+    ) -> None:
+        super().__init__()
+        self.out_v = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+        )
+        self.out_e = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+        )
+        self.dropout1 = torch.nn.Dropout(dropout)
+        self.dropout2 = torch.nn.Dropout(dropout)
+        self.dropout3 = torch.nn.Dropout(dropout)
+        self.norm1 = torch.nn.LayerNorm(hidden_channels)
+        self.norm2 = torch.nn.LayerNorm(hidden_channels)
+        self.norm3 = torch.nn.LayerNorm(hidden_channels)
+        self.scale = scale
+        self.dense = PositionWiseFeedForward(hidden_channels,
+                                             hidden_channels * 4)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_attr: torch.Tensor,
+    ) -> torch.Tensor:
+        # x: [N, d_v]
+        # edge_index: [2, E]
+        # edge_attr: [E, d_e]
+        # update node features
+        h_message = self.propagate(x=x, edge_index=edge_index,
+                                   edge_attr=edge_attr)
+        dh = h_message / self.scale
+        x = self.norm1(x + self.dropout1(dh))
+        dh = self.dense(x)
+        x = self.norm2(x + self.dropout2(dh))
+        # update edge features
+        row, col = edge_index
+        x_i, x_j = x[row], x[col]
+        h_e = torch.cat([x_i, x_j, edge_attr], dim=-1)
+        h_e = self.out_e(h_e)
+        edge_attr = self.norm3(edge_attr + self.dropout3(h_e))
+        return x, edge_attr
+
+    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
+                edge_attr: torch.Tensor) -> torch.Tensor:
+        h = torch.cat([x_i, x_j, edge_attr], dim=-1)  # [E, 2*d_v + d_e]
+        h = self.out_e(h)  # [E, d_e]
+        return h
+
+
+class Decoder(MessagePassing):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        dropout: float = 0.1,
+        scale: float = 30,
+    ) -> None:
+        super().__init__()
+        self.out_v = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+        )
+        self.dropout1 = torch.nn.Dropout(dropout)
+        self.dropout2 = torch.nn.Dropout(dropout)
+        self.norm1 = torch.nn.LayerNorm(hidden_channels)
+        self.norm2 = torch.nn.LayerNorm(hidden_channels)
+        self.scale = scale
+        self.dense = PositionWiseFeedForward(hidden_channels,
+                                             hidden_channels * 4)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_attr: torch.Tensor,
+        x_label: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> torch.Tensor:
+        # x: [N, d_v]
+        # edge_index: [2, E]
+        # edge_attr: [E, d_e]
+        h_message = self.propagate(x=x, x_label=x_label, edge_index=edge_index,
+                                   edge_attr=edge_attr, mask=mask)
+        dh = h_message / self.scale
+        x = self.norm1(x + self.dropout1(dh))
+        dh = self.dense(x)
+        x = self.norm2(x + self.dropout2(dh))
+        return x
+
+    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
+                x_label_j: torch.Tensor, edge_attr: torch.Tensor,
+                mask: torch.Tensor) -> torch.Tensor:
+        h_1 = torch.cat([x_j, edge_attr, x_label_j], dim=-1)
+        h_0 = torch.cat([x_j, edge_attr, torch.zeros_like(x_label_j)], dim=-1)
+        h = h_1 * mask + h_0 * (1 - mask)
+        h = torch.concat([x_i, h], dim=-1)
+        h = self.out_v(h)
+        return h
+
+
+class ProteinMPNN(torch.nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int = 128,
+        num_encoder_layers: int = 3,
+        num_decoder_layers: int = 3,
+        num_neighbors: int = 30,
+        num_rbf: int = 16,
+        dropout: float = 0.1,
+        augment_eps: float = 0.2,
+        num_positional_embedding: int = 16,
+        vocab_size: int = 21,
+    ) -> None:
+        super().__init__()
+        self.augment_eps = augment_eps
+        self.hidden_dim = hidden_dim
+        self.num_neighbors = num_neighbors
+        self.num_rbf = num_rbf
+        self.embedding = PositionalEncoding(num_positional_embedding)
+        self.edge_mlp = torch.nn.Sequential(
+            torch.nn.Linear(num_positional_embedding + 400, hidden_dim),
+            torch.nn.LayerNorm(hidden_dim),
+            torch.nn.Linear(hidden_dim, hidden_dim),
+        )
+        self.label_embedding = torch.nn.Embedding(vocab_size, hidden_dim)
+        self.encoder_layers = torch.nn.ModuleList([
+            Encoder(hidden_dim * 3, hidden_dim, dropout)
+            for _ in range(num_encoder_layers)
+        ])
+
+        self.decoder_layers = torch.nn.ModuleList([
+            Decoder(hidden_dim * 4, hidden_dim, dropout)
+            for _ in range(num_decoder_layers)
+        ])
+        self.output = torch.nn.Linear(hidden_dim, vocab_size)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                torch.nn.init.xavier_uniform_(p)
+
+    def _featurize(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        batch: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        N, Ca, C, O = (x[:, i, :] for i in range(4))  # noqa: E741
+        b = Ca - N
+        c = C - Ca
+        a = torch.cross(b, c, dim=-1)
+        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca
+
+        valid_mask = mask.bool()
+        valid_Ca = Ca[valid_mask]
+        valid_batch = batch[valid_mask]
+
+        edge_index = knn_graph(valid_Ca, k=self.num_neighbors,
+                               batch=valid_batch, loop=True)
+
+        row, col = edge_index
+        original_indices = torch.arange(Ca.size(0),
+                                        device=x.device)[valid_mask]
+        edge_index_original = torch.stack(
+            [original_indices[row], original_indices[col]], dim=0)
+        row, col = edge_index_original
+
+        rbf_all = []
+        for A, B in list(product([N, Ca, C, O, Cb], repeat=2)):
+            distances = torch.sqrt(torch.sum((A[row] - B[col])**2, 1) + 1e-6)
+            rbf = self._rbf(distances)
+            rbf_all.append(rbf)
+
+        return edge_index_original, torch.cat(rbf_all, dim=-1)
+
+    def _rbf(self, D: torch.Tensor) -> torch.Tensor:
+        D_min, D_max, D_count = 2., 22., self.num_rbf
+        D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)
+        D_mu = D_mu.view([1, -1])
+        D_sigma = (D_max - D_min) / D_count
+        D_expand = torch.unsqueeze(D, -1)
+        RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
+        return RBF
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        chain_seq_label: torch.Tensor,
+        mask: torch.Tensor,
+        chain_mask_all: torch.Tensor,
+        residue_idx: torch.Tensor,
+        chain_encoding_all: torch.Tensor,
+        batch: torch.Tensor,
+    ) -> torch.Tensor:
+        device = x.device
+        if self.training and self.augment_eps > 0:
+            x = x + self.augment_eps * torch.randn_like(x)
+
+        edge_index, edge_attr = self._featurize(x, mask, batch)
+
+        row, col = edge_index
+        offset = residue_idx[row] - residue_idx[col]
+        # find self vs non-self interaction
+        e_chains = ((chain_encoding_all[row] -
+                     chain_encoding_all[col]) == 0).long()
+        e_pos = self.embedding(offset, e_chains)
+        h_e = self.edge_mlp(torch.cat([edge_attr, e_pos], dim=-1))
+        h_v = torch.zeros(x.size(0), self.hidden_dim, device=x.device)
+
+        # encoder
+        for encoder in self.encoder_layers:
+            h_v, h_e = encoder(h_v, edge_index, h_e)
+
+        # mask
+        h_label = self.label_embedding(chain_seq_label)
+        batch_chain_mask_all, _ = to_dense_batch(chain_mask_all * mask,
+                                                 batch)  # [B, N]
+        # 0 - visible - encoder, 1 - masked - decoder
+        decoding_order = torch.argsort(
+            (batch_chain_mask_all + 1e-4) * (torch.abs(
+                torch.randn(batch_chain_mask_all.shape, device=device))))
+        mask_size = batch_chain_mask_all.size(1)
+        permutation_matrix_reverse = F.one_hot(decoding_order,
+                                               num_classes=mask_size).float()
+        order_mask_backward = torch.einsum(
+            'ij, biq, bjp->bqp',
+            1 - torch.triu(torch.ones(mask_size, mask_size, device=device)),
+            permutation_matrix_reverse,
+            permutation_matrix_reverse,
+        )
+        adj = to_dense_adj(edge_index, batch)
+        mask_attend = order_mask_backward[adj.bool()].unsqueeze(-1)
+
+        # decoder
+        for decoder in self.decoder_layers:
+            h_v = decoder(
+                h_v,
+                edge_index,
+                h_e,
+                h_label,
+                mask_attend,
+            )
+
+        logits = self.output(h_v)
+        return F.log_softmax(logits, dim=-1)
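Note (not part of the diff): a minimal, hypothetical smoke test for the new model is sketched below. The input shapes and semantics are inferred from the forward() and _featurize() code above, and the import path assumes the export added in torch_geometric/nn/models/__init__.py (listed in the file summary); none of this is taken from official documentation.

# Hypothetical smoke test; shapes inferred from the forward() signature above.
import torch

from torch_geometric.nn.models import ProteinMPNN  # export assumed per __init__.py change

num_residues = 50
model = ProteinMPNN(hidden_dim=128, num_neighbors=30, vocab_size=21)

x = torch.randn(num_residues, 4, 3)                       # N, Ca, C, O backbone coordinates
chain_seq_label = torch.randint(0, 21, (num_residues, ))  # ground-truth residue types
mask = torch.ones(num_residues)                           # 1 = residue has valid coordinates
chain_mask_all = torch.ones(num_residues)                 # 1 = position to be decoded
residue_idx = torch.arange(num_residues)                  # residue index within the chain
chain_encoding_all = torch.ones(num_residues, dtype=torch.long)  # chain id per residue
batch = torch.zeros(num_residues, dtype=torch.long)       # a single protein in the batch

log_probs = model(x, chain_seq_label, mask, chain_mask_all, residue_idx,
                  chain_encoding_all, batch)
print(log_probs.shape)  # expected: [num_residues, vocab_size] log-probabilities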
torch_geometric/nn/norm/batch_norm.py
CHANGED

@@ -39,6 +39,8 @@ class BatchNorm(torch.nn.Module):
             with only a single element will work as during in evaluation.
             That is the running mean and variance will be used.
             Requires :obj:`track_running_stats=True`. (default: :obj:`False`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -48,6 +50,7 @@ class BatchNorm(torch.nn.Module):
         affine: bool = True,
         track_running_stats: bool = True,
         allow_single_element: bool = False,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -56,7 +59,7 @@ class BatchNorm(torch.nn.Module):
                              "'track_running_stats' to be set to `True`")
 
         self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine,
-                                           track_running_stats)
+                                           track_running_stats, device=device)
         self.in_channels = in_channels
         self.allow_single_element = allow_single_element
 
@@ -114,6 +117,8 @@ class HeteroBatchNorm(torch.nn.Module):
            :obj:`False`, this module does not track such statistics and always
            uses batch statistics in both training and eval modes.
            (default: :obj:`True`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -123,6 +128,7 @@ class HeteroBatchNorm(torch.nn.Module):
         momentum: Optional[float] = 0.1,
         affine: bool = True,
         track_running_stats: bool = True,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -134,17 +140,21 @@ class HeteroBatchNorm(torch.nn.Module):
         self.track_running_stats = track_running_stats
 
         if self.affine:
-            self.weight = Parameter(torch.empty(num_types, in_channels))
-            self.bias = Parameter(torch.empty(num_types, in_channels))
+            self.weight = Parameter(
+                torch.empty(num_types, in_channels, device=device))
+            self.bias = Parameter(
+                torch.empty(num_types, in_channels, device=device))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
 
         if self.track_running_stats:
-            self.register_buffer('running_mean',
-                                 torch.empty(num_types, in_channels))
-            self.register_buffer('running_var',
-                                 torch.empty(num_types, in_channels))
+            self.register_buffer(
+                'running_mean',
+                torch.empty(num_types, in_channels, device=device))
+            self.register_buffer(
+                'running_var',
+                torch.empty(num_types, in_channels, device=device))
             self.register_buffer('num_batches_tracked', torch.tensor(0))
         else:
             self.register_buffer('running_mean', None)
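Note (not part of the diff): the same `device` keyword is threaded through the remaining normalization modules below (DiffGroupNorm, GraphNorm, InstanceNorm, LayerNorm, HeteroLayerNorm, MessageNorm). A small sketch of the intended usage, assuming the constructor signatures shown in this diff:

# Sketch only: parameters and running statistics are now allocated directly
# on the requested device instead of being moved with .to(device) afterwards.
import torch
from torch_geometric.nn import BatchNorm, LayerNorm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

bn = BatchNorm(64, device=device)               # previously: BatchNorm(64).to(device)
ln = LayerNorm(64, mode='node', device=device)

x = torch.randn(32, 64, device=device)
out = ln(bn(x))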
torch_geometric/nn/norm/diff_group_norm.py
CHANGED

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from torch import Tensor
 from torch.nn import BatchNorm1d, Linear
@@ -39,6 +41,8 @@ class DiffGroupNorm(torch.nn.Module):
            :obj:`False`, this module does not track such statistics and always
            uses batch statistics in both training and eval modes.
            (default: :obj:`True`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -49,6 +53,7 @@ class DiffGroupNorm(torch.nn.Module):
         momentum: float = 0.1,
         affine: bool = True,
         track_running_stats: bool = True,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -56,9 +61,9 @@ class DiffGroupNorm(torch.nn.Module):
         self.groups = groups
         self.lamda = lamda
 
-        self.lin = Linear(in_channels, groups, bias=False)
+        self.lin = Linear(in_channels, groups, bias=False, device=device)
         self.norm = BatchNorm1d(groups * in_channels, eps, momentum, affine,
-                                track_running_stats)
+                                track_running_stats, device=device)
 
         self.reset_parameters()
 
torch_geometric/nn/norm/graph_norm.py
CHANGED

@@ -26,16 +26,21 @@ class GraphNorm(torch.nn.Module):
         in_channels (int): Size of each input sample.
         eps (float, optional): A value added to the denominator for numerical
             stability. (default: :obj:`1e-5`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
-    def __init__(self, in_channels: int, eps: float = 1e-5):
+    def __init__(self, in_channels: int, eps: float = 1e-5,
+                 device: Optional[torch.device] = None):
         super().__init__()
 
         self.in_channels = in_channels
         self.eps = eps
 
-        self.weight = torch.nn.Parameter(torch.empty(in_channels))
-        self.bias = torch.nn.Parameter(torch.empty(in_channels))
-        self.mean_scale = torch.nn.Parameter(torch.empty(in_channels))
+        self.weight = torch.nn.Parameter(
+            torch.empty(in_channels, device=device))
+        self.bias = torch.nn.Parameter(torch.empty(in_channels, device=device))
+        self.mean_scale = torch.nn.Parameter(
+            torch.empty(in_channels, device=device))
 
         self.reset_parameters()
 
torch_geometric/nn/norm/instance_norm.py
CHANGED

@@ -1,5 +1,6 @@
 from typing import Optional
 
+import torch
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.modules.instancenorm import _InstanceNorm
@@ -36,6 +37,8 @@ class InstanceNorm(_InstanceNorm):
            :obj:`False`, this module does not track such statistics and always
            uses instance statistics in both training and eval modes.
            (default: :obj:`False`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -44,9 +47,10 @@ class InstanceNorm(_InstanceNorm):
         momentum: float = 0.1,
         affine: bool = False,
         track_running_stats: bool = False,
+        device: Optional[torch.device] = None,
     ):
         super().__init__(in_channels, eps, momentum, affine,
-                         track_running_stats)
+                         track_running_stats, device=device)
 
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
torch_geometric/nn/norm/layer_norm.py
CHANGED

@@ -35,6 +35,8 @@ class LayerNorm(torch.nn.Module):
            is used, each graph will be considered as an element to be
            normalized. If `"node"` is used, each node will be considered as
            an element to be normalized. (default: :obj:`"graph"`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -42,6 +44,7 @@ class LayerNorm(torch.nn.Module):
         eps: float = 1e-5,
         affine: bool = True,
         mode: str = 'graph',
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -51,8 +54,8 @@ class LayerNorm(torch.nn.Module):
         self.mode = mode
 
         if affine:
-            self.weight = Parameter(torch.empty(in_channels))
-            self.bias = Parameter(torch.empty(in_channels))
+            self.weight = Parameter(torch.empty(in_channels, device=device))
+            self.bias = Parameter(torch.empty(in_channels, device=device))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
@@ -134,6 +137,8 @@ class HeteroLayerNorm(torch.nn.Module):
            normalization (:obj:`"node"`). If `"node"` is used, each node will
            be considered as an element to be normalized.
            (default: :obj:`"node"`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -142,6 +147,7 @@ class HeteroLayerNorm(torch.nn.Module):
         eps: float = 1e-5,
         affine: bool = True,
         mode: str = 'node',
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
         assert mode == 'node'
@@ -152,8 +158,10 @@ class HeteroLayerNorm(torch.nn.Module):
         self.affine = affine
 
         if affine:
-            self.weight = Parameter(torch.empty(num_types, in_channels))
-            self.bias = Parameter(torch.empty(num_types, in_channels))
+            self.weight = Parameter(
+                torch.empty(num_types, in_channels, device=device))
+            self.bias = Parameter(
+                torch.empty(num_types, in_channels, device=device))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
torch_geometric/nn/norm/msg_norm.py
CHANGED

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import torch.nn.functional as F
 from torch import Tensor
@@ -19,10 +21,14 @@ class MessageNorm(torch.nn.Module):
         learn_scale (bool, optional): If set to :obj:`True`, will learn the
             scaling factor :math:`s` of message normalization.
            (default: :obj:`False`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
-    def __init__(self, learn_scale: bool = False):
+    def __init__(self, learn_scale: bool = False,
+                 device: Optional[torch.device] = None):
         super().__init__()
-        self.scale = Parameter(torch.empty(1), requires_grad=learn_scale)
+        self.scale = Parameter(torch.empty(1, device=device),
+                               requires_grad=learn_scale)
         self.reset_parameters()
 
     def reset_parameters(self):
torch_geometric/utils/convert.py
CHANGED

@@ -452,15 +452,22 @@ def to_cugraph(
     g = cugraph.Graph(directed=directed)
     df = cudf.from_dlpack(to_dlpack(edge_index.t()))
 
+    df = cudf.DataFrame({
+        'source':
+        cudf.from_dlpack(to_dlpack(edge_index[0])),
+        'destination':
+        cudf.from_dlpack(to_dlpack(edge_index[1])),
+    })
+
     if edge_weight is not None:
         assert edge_weight.dim() == 1
-        df['
+        df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight))
 
     g.from_cudf_edgelist(
         df,
-        source=
-        destination=
-        edge_attr='
+        source='source',
+        destination='destination',
+        edge_attr='weight' if edge_weight is not None else None,
         renumber=relabel_nodes,
     )
 
@@ -476,13 +483,13 @@ def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
     """
     df = g.view_edge_list()
 
-    src = from_dlpack(df[
-    dst = from_dlpack(df[
+    src = from_dlpack(df[g.source_columns].to_dlpack()).long()
+    dst = from_dlpack(df[g.destination_columns].to_dlpack()).long()
     edge_index = torch.stack([src, dst], dim=0)
 
     edge_weight = None
-    if
-        edge_weight = from_dlpack(df[
+    if g.weight_column is not None:
+        edge_weight = from_dlpack(df[g.weight_column].to_dlpack())
 
     return edge_index, edge_weight
 
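Note (not part of the diff): a hypothetical round trip through the updated converters is sketched below. It assumes the to_cugraph(edge_index, edge_weight=..., relabel_nodes=..., ...) parameter names implied by the hunks above, and it requires CUDA tensors with cudf/cugraph installed.

# Sketch: convert a small weighted edge list to a cugraph.Graph and back.
import torch
from torch_geometric.utils import from_cugraph, to_cugraph

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], device='cuda')
edge_weight = torch.rand(3, device='cuda')

g = to_cugraph(edge_index, edge_weight=edge_weight, relabel_nodes=False)
# The edge list is now stored under 'source'/'destination'/'weight' columns.
edge_index2, edge_weight2 = from_cugraph(g)  # reads g.source_columns / g.weight_column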
torch_geometric/utils/smiles.py
CHANGED
File without changes

{pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250703.dist-info}/licenses/LICENSE
RENAMED
File without changes