rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from mpnn.model.layers.position_wise_feed_forward import PositionWiseFeedForward
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Gather functions borrowed from ProteinMPNN;
|
|
7
|
+
# originally from:
|
|
8
|
+
# https://github.com/jingraham/neurips19-graph-protein-design/tree/master
|
|
9
|
+
def gather_edges(edge_features, neighbor_idx):
    """
    Gather edge features for the neighbors of each node.

    Args:
        edge_features: [B,L,L,H] - edge features
        neighbor_idx: [B,L,K] - neighbor indices
    Returns:
        edge_features_at_neighbors: [B,L,K,H] - neighbor edge features,
            gathered at the neighbor indices.
    """
    feature_dim = edge_features.size(-1)

    # Broadcast each neighbor index across the feature dimension so that one
    # gather along the second node axis pulls whole feature vectors.
    expanded_idx = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, feature_dim)

    return torch.gather(edge_features, 2, expanded_idx)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def gather_nodes(node_features, neighbor_idx):
    """
    Gather node features for the neighbors of each node.

    NOTE: in most cases, L1 == L2. This is the straightforward case where the
    node features are gathered at the neighbor indices for every node. However,
    L2 can differ from L1, which allows for gathering node features of less
    nodes (useful for gathering features during decoding, where we decode
    one node at a time).

    Args:
        node_features: [B,L1,H] - node features
        neighbor_idx: [B,L2,K] - neighbor indices
    Returns:
        node_features_at_neighbors: [B,L2,K,H] - neighbor node features,
            gathered at the neighbor indices.
    """
    batch, num_query_nodes, num_neighbors = neighbor_idx.shape
    hidden = node_features.size(-1)

    # Collapse the (node, neighbor) axes into a single index axis so that
    # torch.gather can operate along the node dimension, then broadcast the
    # indices across the feature dimension.
    flat_idx = neighbor_idx.reshape(batch, num_query_nodes * num_neighbors, 1)
    flat_idx = flat_idx.expand(-1, -1, hidden)

    # Gather flattened, then restore the [B,L2,K,H] layout.
    gathered_flat = torch.gather(node_features, 1, flat_idx)
    return gathered_flat.view(batch, num_query_nodes, num_neighbors, hidden)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def cat_neighbors_nodes(node_features, edge_features_at_neighbors, neighbor_idx):
    """
    Gather node features for the neighbors of each node and concatenate them
    with the edge features.

    NOTE: in most cases, L1 == L2. This is the straightforward case where the
    node features are gathered at the neighbor indices for every node, then
    concatenated with the edge features. However, L2 can differ from L1, which
    allows for gathering node features and concatenating to edge features for
    less nodes (useful for gathering features during decoding, where we decode
    one node at a time).

    Args:
        node_features [B,L1,H1]: node features
        edge_features_at_neighbors [B,L2,K,H2]: edge hidden states
        neighbor_idx [B,L2,K]: neighbor indices
    Returns:
        edge_and_node_features_at_neighbors [B,L2,K,H2+H1]: concatenated node
            and edge features, with the edge features first.
    """
    # [B,L2,K,H1] - destination-node features pulled to each (node, neighbor).
    neighbor_node_features = gather_nodes(node_features, neighbor_idx)

    # Edge features lead in the concatenated representation: [B,L2,K,H2+H1].
    return torch.cat((edge_features_at_neighbors, neighbor_node_features), dim=-1)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class EncLayer(nn.Module):
    """Encoder message-passing layer: updates both node and edge hidden states."""

    def __init__(self, num_hidden, num_in, dropout=0.1, scale=30):
        """
        Args:
            num_hidden (int): hidden dimension of node/edge states.
            num_in (int): input dimension of the message MLPs (3H here).
            dropout (float): dropout probability on residual branches.
            scale (int): divisor for the scaled-sum neighbor aggregation.
        """
        super(EncLayer, self).__init__()

        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(num_hidden)
        self.norm2 = nn.LayerNorm(num_hidden)
        self.norm3 = nn.LayerNorm(num_hidden)

        # Three-layer MLP producing the node-update messages.
        self.W1 = nn.Linear(num_in, num_hidden, bias=True)
        self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)

        # Three-layer MLP producing the edge updates.
        self.W11 = nn.Linear(num_in, num_hidden, bias=True)
        self.W12 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W13 = nn.Linear(num_hidden, num_hidden, bias=True)

        self.act = torch.nn.GELU()

        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, E_idx, mask_V=None, mask_E=None):
        """
        Encoder message passing step; updates both the node and edge hidden
        states.

        NOTE: num_in = 3H

        Args:
            h_V [B, L, H] - node hidden states
            h_E [B, L, K, H] - edge hidden states
            E_idx [B, L, K] - edge indices
            mask_V [B, L] - node mask
            mask_E [B, L, K] - edge mask
        Returns:
            h_V [B, L, H] - updated node hidden states
            h_E [B, L, K, H] - updated edge hidden states
        """
        num_neighbors = h_E.size(-2)

        # Per-edge MLP input [B, L, K, 3H]: source-node features broadcast
        # over neighbors, followed by edge features and the gathered
        # destination-node features.
        h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
        source_nodes = h_V.unsqueeze(-2).expand(-1, -1, num_neighbors, -1)
        h_EV = torch.cat([source_nodes, h_EV], -1)

        # Node-update messages: [B, L, K, 3H] => [B, L, K, H].
        messages = self.W3(self.act(self.W2(self.act(self.W1(h_EV)))))

        # Zero out messages along masked edges.
        if mask_E is not None:
            messages = messages * mask_E.unsqueeze(-1)

        # Scaled-sum aggregation over neighbors, then dropout+residual+norm.
        h_V = self.norm1(h_V + self.dropout1(messages.sum(-2) / self.scale))

        # Position-wise feedforward, then dropout+residual+norm.
        h_V = self.norm2(h_V + self.dropout2(self.dense(h_V)))

        # Zero out masked nodes.
        if mask_V is not None:
            h_V = h_V * mask_V.unsqueeze(-1)

        # Rebuild the per-edge input with the UPDATED node states and run the
        # edge-update MLP: [B, L, K, 3H] => [B, L, K, H].
        h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
        source_nodes = h_V.unsqueeze(-2).expand(-1, -1, num_neighbors, -1)
        h_EV = torch.cat([source_nodes, h_EV], -1)
        edge_updates = self.W13(self.act(self.W12(self.act(self.W11(h_EV)))))

        # Residual edge update with dropout + norm.
        h_E = self.norm3(h_E + self.dropout3(edge_updates))

        return h_V, h_E
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class DecLayer(nn.Module):
    """Decoder message-passing layer: updates only the node hidden states."""

    def __init__(self, num_hidden, num_in, dropout=0.1, scale=30):
        """
        Args:
            num_hidden (int): hidden dimension of the node states.
            num_in (int): input dimension of the message MLP.
            dropout (float): dropout probability on residual branches.
            scale (int): divisor for the scaled-sum neighbor aggregation.
        """
        super(DecLayer, self).__init__()

        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(num_hidden)
        self.norm2 = nn.LayerNorm(num_hidden)

        # Three-layer MLP producing the node-update messages.
        self.W1 = nn.Linear(num_in, num_hidden, bias=True)
        self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)

        self.act = torch.nn.GELU()

        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_E=None):
        """
        Decoder message passing step; updates only the node hidden states.

        NOTE: this function is used for both the protein decoder and the ligand
        context encoder. As such, this function operates on the "deepest"
        graph in the tensor.

        Shapes per application:
            protein decoder:
                ... = empty, node_num = L, neighbor_num = K, num_in = 4H
            ligand subgraph encoder:
                ... = L, node_num = M, neighbor_num = M, num_in = 2H
                BUG: this should be 3H, but the original LigandMPNN does not
                pre-concatenate the destination node features to the edge
                features, which breaks the message passing in the ligand
                subgraphs.
            protein-ligand graph encoder:
                ... = empty, node_num = L, neighbor_num = M, num_in = 3H

        Args:
            h_V [B, ..., node_num, H] - node hidden states
            h_E [B, ..., node_num, neighbor_num, num_in - H] - edge hidden
                states;
                NOTE: for message passing to behave in the decoder, the
                destination node features (and sequence if applicable) MUST
                be pre-concatenated to the edge features.
                So, h_E is actually:
                - protein decoder: h_E_ij cat h_S_j cat h_V_j
                - ligand subgraph encoder: h_ligand_subgraph_edges_ij;
                    BUG: this should be h_ligand_subgraph_edges_ij cat
                    h_ligand_subgraph_nodes_j; in its current form
                    (replicating the original LigandMPNN), the destination
                    node features are not concatenated to the edge features,
                    which breaks the message passing in the ligand subgraph.
                - protein-ligand graph encoder: h_E_protein_to_ligand_ij
                    cat h_ligand_subgraph_nodes_j
            mask_V [B, ..., node_num] - node mask
            mask_E [B, ..., node_num, neighbor_num] - edge mask
        Returns:
            h_V [B, ..., node_num, H] - updated node hidden states
        """
        # Broadcast source-node features over the neighbor axis and prepend
        # them to the (destination-augmented) edge features:
        # [B, ..., node_num, neighbor_num, num_in - H] + [B, ..., node_num, H]
        # => [B, ..., node_num, neighbor_num, num_in]
        source_nodes = h_V.unsqueeze(-2).expand(
            *h_V.shape[:-1],   # B, ..., node_num
            h_E.size(-2),      # neighbor_num
            h_V.shape[-1],     # H
        )
        h_EV = torch.cat([source_nodes, h_E], -1)

        # Messages: [B, ..., node_num, neighbor_num, num_in] =>
        # [B, ..., node_num, neighbor_num, H].
        messages = self.W3(self.act(self.W2(self.act(self.W1(h_EV)))))

        # Zero out messages along masked edges.
        if mask_E is not None:
            messages = messages * mask_E.unsqueeze(-1)

        # Scaled-sum aggregation over neighbors, then dropout+residual+norm.
        h_V = self.norm1(h_V + self.dropout1(messages.sum(-2) / self.scale))

        # Position-wise feedforward, then dropout+residual+norm.
        h_V = self.norm2(h_V + self.dropout2(self.dense(h_V)))

        # Zero out masked nodes.
        if mask_V is not None:
            h_V = h_V * mask_V.unsqueeze(-1)

        return h_V
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class PositionWiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward block with a GELU nonlinearity."""

    def __init__(self, num_hidden, num_ff):
        """
        Position-wise feed-forward layer.

        Args:
            num_hidden (int): The hidden dimension size of the input and output.
            num_ff (int): The hidden dimension size of the feed-forward layer.
        """
        super(PositionWiseFeedForward, self).__init__()

        # Expansion and projection linear layers, both with bias.
        self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
        self.W_out = nn.Linear(num_ff, num_hidden, bias=True)

        self.act = torch.nn.GELU()

    def forward(self, h_V):
        """
        Forward pass of the position-wise feed-forward layer.

        Args:
            h_V (torch.Tensor): [B, L, num_hidden] - the hidden embedding
                of the node features.
        Returns:
            torch.Tensor: [B, L, num_hidden] - the output of the
                position-wise feed-forward layer.
        """
        # Expand to the feed-forward dimension, apply GELU, project back down.
        return self.W_out(self.act(self.W_in(h_V)))
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class PositionalEncodings(nn.Module):
    """Chain-aware relative positional encodings for the MPNN model."""

    def __init__(self, num_positional_embeddings, max_relative_feature=32):
        """
        Positional encodings for the MPNN model.

        Args:
            num_positional_embeddings (int): The dimension of the embeddings
                for the positional encodings.
            max_relative_feature (int): The maximum relative feature offset.
                Default is 32, which means the positional encodings will handle
                offsets in the range [-32, 32]. This is used to determine the
                size of the one-hot encoding for the positional offsets.
        """
        super(PositionalEncodings, self).__init__()

        self.num_positional_embeddings = num_positional_embeddings
        self.max_relative_feature = max_relative_feature

        # One bucket per offset in [-max, ..., 0, ..., max], plus one extra
        # bucket reserved for residue pairs that are not on the same chain.
        self.num_positional_features = 2 * max_relative_feature + 1 + 1

        # Linear map from the one-hot bucket encoding to the embedding space.
        self.embed_positional_features = nn.Linear(
            self.num_positional_features, num_positional_embeddings
        )

    def forward(self, positional_offset, same_chain_mask):
        """
        Forward pass of the positional encodings.

        Args:
            positional_offset (torch.Tensor): [B, L, K] - pairwise differences
                between the indices of residues, gathered for the K nearest
                neighbors.
            same_chain_mask (torch.Tensor): [B, L, K] - a boolean mask that is
                True where both residues are on the same chain.
        Returns:
            torch.Tensor: [B, L, K, num_positional_embeddings] - embeddings of
                the offsets. Each offset is shifted and clipped into
                [0, 2 * max_relative_feature]; pairs on different chains are
                routed to the dedicated extra bucket
                (2 * max_relative_feature + 1). The bucket index is one-hot
                encoded and passed through a linear layer.
        Raises:
            ValueError: if same_chain_mask is not of boolean dtype.
        """
        # The mask must be boolean so the bucket selection below is well-defined.
        if same_chain_mask.dtype != torch.bool:
            raise ValueError("The same_chain_mask must be of boolean dtype.")

        # Shift offsets to be non-negative and clip them to the supported
        # range so they can serve directly as one-hot bucket indices. Offsets
        # originally in [-max_relative_feature, max_relative_feature] land in
        # [0, 2 * max_relative_feature]; anything outside is clipped to the
        # nearest end of that range.
        bucket_idx = torch.clip(
            positional_offset + self.max_relative_feature,
            0,
            2 * self.max_relative_feature,
        )

        # Chain awareness: pairs on different chains all share the dedicated
        # final bucket instead of a (meaningless) sequence-offset bucket.
        cross_chain_bucket = 2 * self.max_relative_feature + 1
        bucket_idx = torch.where(
            same_chain_mask,
            bucket_idx,
            torch.full_like(bucket_idx, cross_chain_bucket),
        )

        # One-hot encode the bucket indices:
        # [B, L, K] => [B, L, K, num_positional_features].
        bucket_one_hot = torch.nn.functional.one_hot(
            bucket_idx, num_classes=self.num_positional_features
        ).float()

        # Project the one-hot encoding into the embedding space:
        # [B, L, K, num_positional_features] =>
        # [B, L, K, num_positional_embeddings].
        return self.embed_positional_features(bucket_one_hot)
|