pyg-nightly 2.7.0.dev20250702__py3-none-any.whl → 2.7.0.dev20250704__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Note: this release of pyg-nightly has been flagged as potentially problematic.

@@ -0,0 +1,304 @@
+from itertools import product
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+
+from torch_geometric.nn import knn_graph
+from torch_geometric.nn.conv import MessagePassing
+from torch_geometric.utils import to_dense_adj, to_dense_batch
+
+
+class PositionWiseFeedForward(torch.nn.Module):
+    def __init__(self, in_channels: int, hidden_channels: int) -> None:
+        super().__init__()
+        self.out = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, in_channels),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.out(x)
+
+
+class PositionalEncoding(torch.nn.Module):
+    def __init__(self, hidden_channels: int,
+                 max_relative_feature: int = 32) -> None:
+        super().__init__()
+        self.max_relative_feature = max_relative_feature
+        self.emb = torch.nn.Embedding(2 * max_relative_feature + 2,
+                                      hidden_channels)
+
+    def forward(self, offset, mask) -> torch.Tensor:
+        d = torch.clip(offset + self.max_relative_feature, 0,
+                       2 * self.max_relative_feature) * mask + (1 - mask) * (
+                           2 * self.max_relative_feature + 1)  # noqa: E501
+        return self.emb(d.long())
+
+
+class Encoder(MessagePassing):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        dropout: float = 0.1,
+        scale: float = 30,
+    ) -> None:
+        super().__init__()
+        self.out_v = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+        )
+        self.out_e = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+        )
+        self.dropout1 = torch.nn.Dropout(dropout)
+        self.dropout2 = torch.nn.Dropout(dropout)
+        self.dropout3 = torch.nn.Dropout(dropout)
+        self.norm1 = torch.nn.LayerNorm(hidden_channels)
+        self.norm2 = torch.nn.LayerNorm(hidden_channels)
+        self.norm3 = torch.nn.LayerNorm(hidden_channels)
+        self.scale = scale
+        self.dense = PositionWiseFeedForward(hidden_channels,
+                                             hidden_channels * 4)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_attr: torch.Tensor,
+    ) -> torch.Tensor:
+        # x: [N, d_v]
+        # edge_index: [2, E]
+        # edge_attr: [E, d_e]
+        # update node features
+        h_message = self.propagate(x=x, edge_index=edge_index,
+                                   edge_attr=edge_attr)
+        dh = h_message / self.scale
+        x = self.norm1(x + self.dropout1(dh))
+        dh = self.dense(x)
+        x = self.norm2(x + self.dropout2(dh))
+        # update edge features
+        row, col = edge_index
+        x_i, x_j = x[row], x[col]
+        h_e = torch.cat([x_i, x_j, edge_attr], dim=-1)
+        h_e = self.out_e(h_e)
+        edge_attr = self.norm3(edge_attr + self.dropout3(h_e))
+        return x, edge_attr
+
+    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
+                edge_attr: torch.Tensor) -> torch.Tensor:
+        h = torch.cat([x_i, x_j, edge_attr], dim=-1)  # [E, 2*d_v + d_e]
+        h = self.out_e(h)  # [E, d_e]
+        return h
+
+
+class Decoder(MessagePassing):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        dropout: float = 0.1,
+        scale: float = 30,
+    ) -> None:
+        super().__init__()
+        self.out_v = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+            torch.nn.GELU(),
+            torch.nn.Linear(hidden_channels, hidden_channels),
+        )
+        self.dropout1 = torch.nn.Dropout(dropout)
+        self.dropout2 = torch.nn.Dropout(dropout)
+        self.norm1 = torch.nn.LayerNorm(hidden_channels)
+        self.norm2 = torch.nn.LayerNorm(hidden_channels)
+        self.scale = scale
+        self.dense = PositionWiseFeedForward(hidden_channels,
+                                             hidden_channels * 4)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_attr: torch.Tensor,
+        x_label: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> torch.Tensor:
+        # x: [N, d_v]
+        # edge_index: [2, E]
+        # edge_attr: [E, d_e]
+        h_message = self.propagate(x=x, x_label=x_label, edge_index=edge_index,
+                                   edge_attr=edge_attr, mask=mask)
+        dh = h_message / self.scale
+        x = self.norm1(x + self.dropout1(dh))
+        dh = self.dense(x)
+        x = self.norm2(x + self.dropout2(dh))
+        return x
+
+    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
+                x_label_j: torch.Tensor, edge_attr: torch.Tensor,
+                mask: torch.Tensor) -> torch.Tensor:
+        h_1 = torch.cat([x_j, edge_attr, x_label_j], dim=-1)
+        h_0 = torch.cat([x_j, edge_attr, torch.zeros_like(x_label_j)], dim=-1)
+        h = h_1 * mask + h_0 * (1 - mask)
+        h = torch.concat([x_i, h], dim=-1)
+        h = self.out_v(h)
+        return h
+
+
+class ProteinMPNN(torch.nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int = 128,
+        num_encoder_layers: int = 3,
+        num_decoder_layers: int = 3,
+        num_neighbors: int = 30,
+        num_rbf: int = 16,
+        dropout: float = 0.1,
+        augment_eps: float = 0.2,
+        num_positional_embedding: int = 16,
+        vocab_size: int = 21,
+    ) -> None:
+        super().__init__()
+        self.augment_eps = augment_eps
+        self.hidden_dim = hidden_dim
+        self.num_neighbors = num_neighbors
+        self.num_rbf = num_rbf
+        self.embedding = PositionalEncoding(num_positional_embedding)
+        self.edge_mlp = torch.nn.Sequential(
+            torch.nn.Linear(num_positional_embedding + 400, hidden_dim),
+            torch.nn.LayerNorm(hidden_dim),
+            torch.nn.Linear(hidden_dim, hidden_dim),
+        )
+        self.label_embedding = torch.nn.Embedding(vocab_size, hidden_dim)
+        self.encoder_layers = torch.nn.ModuleList([
+            Encoder(hidden_dim * 3, hidden_dim, dropout)
+            for _ in range(num_encoder_layers)
+        ])
+
+        self.decoder_layers = torch.nn.ModuleList([
+            Decoder(hidden_dim * 4, hidden_dim, dropout)
+            for _ in range(num_decoder_layers)
+        ])
+        self.output = torch.nn.Linear(hidden_dim, vocab_size)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                torch.nn.init.xavier_uniform_(p)
+
+    def _featurize(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        batch: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        N, Ca, C, O = (x[:, i, :] for i in range(4))  # noqa: E741
+        b = Ca - N
+        c = C - Ca
+        a = torch.cross(b, c, dim=-1)
+        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca
+
+        valid_mask = mask.bool()
+        valid_Ca = Ca[valid_mask]
+        valid_batch = batch[valid_mask]
+
+        edge_index = knn_graph(valid_Ca, k=self.num_neighbors,
+                               batch=valid_batch, loop=True)
+
+        row, col = edge_index
+        original_indices = torch.arange(Ca.size(0),
+                                        device=x.device)[valid_mask]
+        edge_index_original = torch.stack(
+            [original_indices[row], original_indices[col]], dim=0)
+        row, col = edge_index_original
+
+        rbf_all = []
+        for A, B in list(product([N, Ca, C, O, Cb], repeat=2)):
+            distances = torch.sqrt(torch.sum((A[row] - B[col])**2, 1) + 1e-6)
+            rbf = self._rbf(distances)
+            rbf_all.append(rbf)
+
+        return edge_index_original, torch.cat(rbf_all, dim=-1)
+
+    def _rbf(self, D: torch.Tensor) -> torch.Tensor:
+        D_min, D_max, D_count = 2., 22., self.num_rbf
+        D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)
+        D_mu = D_mu.view([1, -1])
+        D_sigma = (D_max - D_min) / D_count
+        D_expand = torch.unsqueeze(D, -1)
+        RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
+        return RBF
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        chain_seq_label: torch.Tensor,
+        mask: torch.Tensor,
+        chain_mask_all: torch.Tensor,
+        residue_idx: torch.Tensor,
+        chain_encoding_all: torch.Tensor,
+        batch: torch.Tensor,
+    ) -> torch.Tensor:
+        device = x.device
+        if self.training and self.augment_eps > 0:
+            x = x + self.augment_eps * torch.randn_like(x)
+
+        edge_index, edge_attr = self._featurize(x, mask, batch)
+
+        row, col = edge_index
+        offset = residue_idx[row] - residue_idx[col]
+        # find self vs non-self interaction
+        e_chains = ((chain_encoding_all[row] -
+                     chain_encoding_all[col]) == 0).long()
+        e_pos = self.embedding(offset, e_chains)
+        h_e = self.edge_mlp(torch.cat([edge_attr, e_pos], dim=-1))
+        h_v = torch.zeros(x.size(0), self.hidden_dim, device=x.device)
+
+        # encoder
+        for encoder in self.encoder_layers:
+            h_v, h_e = encoder(h_v, edge_index, h_e)
+
+        # mask
+        h_label = self.label_embedding(chain_seq_label)
+        batch_chain_mask_all, _ = to_dense_batch(chain_mask_all * mask,
+                                                 batch)  # [B, N]
+        # 0 - visible - encoder, 1 - masked - decoder
+        decoding_order = torch.argsort(
+            (batch_chain_mask_all + 1e-4) * (torch.abs(
+                torch.randn(batch_chain_mask_all.shape, device=device))))
+        mask_size = batch_chain_mask_all.size(1)
+        permutation_matrix_reverse = F.one_hot(decoding_order,
+                                               num_classes=mask_size).float()
+        order_mask_backward = torch.einsum(
+            'ij, biq, bjp->bqp',
+            1 - torch.triu(torch.ones(mask_size, mask_size, device=device)),
+            permutation_matrix_reverse,
+            permutation_matrix_reverse,
+        )
+        adj = to_dense_adj(edge_index, batch)
+        mask_attend = order_mask_backward[adj.bool()].unsqueeze(-1)
+
+        # decoder
+        for decoder in self.decoder_layers:
+            h_v = decoder(
+                h_v,
+                edge_index,
+                h_e,
+                h_label,
+                mask_attend,
+            )
+
+        logits = self.output(h_v)
+        return F.log_softmax(logits, dim=-1)
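
The new file above adds a ProteinMPNN-style inverse-folding model: encoder/decoder message passing over a k-NN graph built from backbone atom coordinates, ending in per-residue log-probabilities over a 21-token vocabulary. Below is a minimal smoke test; the input shapes and the import path are assumptions inferred from the `forward()` signature (they are not part of the diff), and `knn_graph` additionally requires `torch-cluster`.

```python
# Hedged smoke test for the new model. Shapes are inferred from forward():
# x holds the N/Ca/C/O backbone coordinates per residue; all other inputs
# are per-residue vectors. The import path is an assumption.
import torch

from torch_geometric.nn.models import ProteinMPNN  # assumed location

num_nodes = 32
model = ProteinMPNN()  # hidden_dim=128, vocab_size=21 by default
model.eval()  # disables dropout and the training-time coordinate noise

x = torch.randn(num_nodes, 4, 3)  # [N, 4 backbone atoms, xyz]
chain_seq_label = torch.randint(0, 21, (num_nodes, ))  # amino-acid labels
mask = torch.ones(num_nodes)  # 1 = valid residue
chain_mask_all = torch.ones(num_nodes)  # 1 = position to decode
residue_idx = torch.arange(num_nodes)
chain_encoding_all = torch.ones(num_nodes, dtype=torch.long)
batch = torch.zeros(num_nodes, dtype=torch.long)  # single protein

log_probs = model(x, chain_seq_label, mask, chain_mask_all, residue_idx,
                  chain_encoding_all, batch)
assert log_probs.size() == (num_nodes, 21)  # per-residue log-probabilities
```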
@@ -452,15 +452,22 @@ def to_cugraph(
     g = cugraph.Graph(directed=directed)
     df = cudf.from_dlpack(to_dlpack(edge_index.t()))
 
+    df = cudf.DataFrame({
+        'source':
+        cudf.from_dlpack(to_dlpack(edge_index[0])),
+        'destination':
+        cudf.from_dlpack(to_dlpack(edge_index[1])),
+    })
+
     if edge_weight is not None:
         assert edge_weight.dim() == 1
-        df['2'] = cudf.from_dlpack(to_dlpack(edge_weight))
+        df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight))
 
     g.from_cudf_edgelist(
         df,
-        source=0,
-        destination=1,
-        edge_attr='2' if edge_weight is not None else None,
+        source='source',
+        destination='destination',
+        edge_attr='weight' if edge_weight is not None else None,
         renumber=relabel_nodes,
     )
 
@@ -476,13 +483,13 @@ def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
     """
     df = g.view_edge_list()
 
-    src = from_dlpack(df[0].to_dlpack()).long()
-    dst = from_dlpack(df[1].to_dlpack()).long()
+    src = from_dlpack(df[g.source_columns].to_dlpack()).long()
+    dst = from_dlpack(df[g.destination_columns].to_dlpack()).long()
     edge_index = torch.stack([src, dst], dim=0)
 
     edge_weight = None
-    if '2' in df:
-        edge_weight = from_dlpack(df['2'].to_dlpack())
+    if g.weight_column is not None:
+        edge_weight = from_dlpack(df[g.weight_column].to_dlpack())
 
     return edge_index, edge_weight
 
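Taken together, the two hunks above replace the positional/`'2'` column scheme with named columns on the cudf edge list, and read them back through the column-name metadata (`source_columns`, `destination_columns`, `weight_column`) that the `cugraph.Graph` itself exposes, so the round trip no longer depends on how cugraph orders its edge-list columns. A hedged round-trip sketch, assuming a CUDA device with cugraph and cudf installed:

```python
# Round trip through cugraph under the new column naming. Requires a CUDA
# device plus the cugraph/cudf packages; tensor values are illustrative.
import torch

from torch_geometric.utils import from_cugraph, to_cugraph

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], device='cuda')
edge_weight = torch.tensor([0.5, 1.0, 2.0], device='cuda')

# Edges are now stored under 'source'/'destination'/'weight' column names.
g = to_cugraph(edge_index, edge_weight, relabel_nodes=False)
out_index, out_weight = from_cugraph(g)  # reads g.weight_column et al.
```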
@@ -148,7 +148,7 @@ def from_smiles(
     """
     from rdkit import Chem, RDLogger
 
-    RDLogger.DisableLog('rdApp.*')  # type: ignore
+    RDLogger.DisableLog('rdApp.*')
 
     mol = Chem.MolFromSmiles(smiles)
 
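The hunk above only drops a stale `# type: ignore` comment; behavior is unchanged. For context, a small illustrative use of `from_smiles` (requires rdkit):

```python
# Illustrative only; unrelated to the lint fix above.
from torch_geometric.utils import from_smiles

data = from_smiles('CCO')  # ethanol, hydrogens omitted by default
print(data.num_nodes, data.num_edges)  # 3 heavy atoms, 4 directed bonds
```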