rc-foundry 0.1.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
mpnn/model/mpnn.py
ADDED
@@ -0,0 +1,2632 @@
import torch
import torch.nn as nn
from atomworks.constants import UNKNOWN_AA
from mpnn.model.layers.graph_embeddings import (
    ProteinFeatures,
    ProteinFeaturesLigand,
    ProteinFeaturesMembrane,
    ProteinFeaturesPSSM,
)
from mpnn.model.layers.message_passing import (
    DecLayer,
    EncLayer,
    cat_neighbors_nodes,
    gather_nodes,
)
from mpnn.utils.probability import sample_bernoulli_rv


class ProteinMPNN(nn.Module):
    """
    Class for default ProteinMPNN.
    """

    HAS_NODE_FEATURES = False

    @staticmethod
    def init_weights(module):
        """
        Initialize the weights of the module.

        Args:
            module (nn.Module): The module to initialize.
        Side Effects:
            Initializes the weights of the module using Xavier uniform
            initialization for parameters with a dimension greater than 1.
        """
        # Initialize the weights of the module using Xavier uniform, skipping
        # any parameters with a dimension of 1 or less (for example, biases).
        for parameter in module.parameters():
            if parameter.dim() > 1:
                nn.init.xavier_uniform_(parameter)

    def __init__(
        self,
        num_node_features=128,
        num_edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        num_neighbors=48,
        dropout_rate=0.1,
        num_positional_embeddings=16,
        min_rbf_mean=2.0,
        max_rbf_mean=22.0,
        num_rbf=16,
        graph_featurization_module=None,
    ):
        """
        Setup the ProteinMPNN model.

        Args:
            num_node_features (int): Number of node features.
            num_edge_features (int): Number of edge features.
            hidden_dim (int): Hidden dimension size.
            num_encoder_layers (int): Number of encoder layers.
            num_decoder_layers (int): Number of decoder layers.
            num_neighbors (int): Number of neighbors for each polymer residue.
            dropout_rate (float): Dropout rate.
            num_positional_embeddings (int): Number of positional embeddings.
            min_rbf_mean (float): Minimum radial basis function mean.
            max_rbf_mean (float): Maximum radial basis function mean.
            num_rbf (int): Number of radial basis functions.
            graph_featurization_module (nn.Module, optional): Custom graph
                featurization module. If None, the default ProteinFeatures
                module is used.
        """
        super(ProteinMPNN, self).__init__()

        # Internal dimensions
        self.num_node_features = num_node_features
        self.num_edge_features = num_edge_features
        self.hidden_dim = hidden_dim

        # Number of layers in the encoder and decoder
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers

        # Dropout rate
        self.dropout_rate = dropout_rate

        # Module for featurizing the graph.
        if graph_featurization_module is not None:
            self.graph_featurization_module = graph_featurization_module
        else:
            self.graph_featurization_module = ProteinFeatures(
                num_edge_output_features=num_edge_features,
                num_node_output_features=num_node_features,
                num_positional_embeddings=num_positional_embeddings,
                min_rbf_mean=min_rbf_mean,
                max_rbf_mean=max_rbf_mean,
                num_rbf=num_rbf,
                num_neighbors=num_neighbors,
            )

        # Provide a shorter reference to the graph featurization token-to-idx
        # mapping.
        self.token_to_idx = self.graph_featurization_module.TOKEN_ENCODING.token_to_idx

        # Size of the vocabulary
        self.vocab_size = self.graph_featurization_module.TOKEN_ENCODING.n_tokens

        # Unknown residue token indices, from the TOKEN_ENCODING.
        self.unknown_token_indices = list(
            map(
                lambda token: self.token_to_idx[token],
                self.graph_featurization_module.TOKEN_ENCODING.unknown_tokens,
            )
        )

        # Linear layer for the edge features.
        self.W_e = nn.Linear(num_edge_features, hidden_dim, bias=True)

        # Linear layer for the sequence features.
        self.W_s = nn.Embedding(self.vocab_size, hidden_dim)

        if self.HAS_NODE_FEATURES:
            # Linear layer for the node features.
            self.W_v = nn.Linear(num_node_features, hidden_dim, bias=True)

        # Dropout layer
        self.dropout = nn.Dropout(dropout_rate)

        # Encoder layers
        self.encoder_layers = nn.ModuleList(
            [
                EncLayer(hidden_dim, hidden_dim * 3, dropout=dropout_rate)
                for _ in range(num_encoder_layers)
            ]
        )

        # Decoder layers
        self.decoder_layers = nn.ModuleList(
            [
                DecLayer(hidden_dim, hidden_dim * 4, dropout=dropout_rate)
                for _ in range(num_decoder_layers)
            ]
        )

        # Linear layer for the output
        self.W_out = nn.Linear(hidden_dim, self.vocab_size, bias=True)

    def construct_known_residue_mask(self, S):
        """
        Construct a mask for the known residues based on the sequence S.

        Args:
            S (torch.Tensor): [B, L] - the sequence of residues.
        Returns:
            known_residue_mask (torch.Tensor): [B, L] - mask for known residues,
                where True is a residue with one of the canonical residue types,
                and False is a residue with an unknown residue type.
        """
        # Create a mask for known residues.
        known_residue_mask = torch.isin(
            S,
            torch.tensor(self.unknown_token_indices, device=S.device, dtype=S.dtype),
            invert=True,
        )

        return known_residue_mask

    def sample_and_construct_masks(self, input_features):
        """
        Sample and construct masks for the input features.

        Args:
            input_features (dict): Input features containing the residue mask.
                - residue_mask (torch.Tensor): [B, L] - mask for the residues.
                - S (torch.Tensor): [B, L] - sequence of residues.
                - designed_residue_mask (torch.Tensor): [B, L] - mask for the
                    designed residues.
        Side Effects:
            input_features["residue_mask"] (torch.Tensor): [B, L] - mask for the
                residues, where True is a residue that is valid and False is a
                residue that is invalid.
            input_features["known_residue_mask"] (torch.Tensor): [B, L] - mask
                for known residues, where True is a residue with one of the
                canonical residue types, and False is a residue with an unknown
                residue type.
            input_features["designed_residue_mask"] (torch.Tensor): [B, L] -
                mask for designed residues, where True is a residue that is
                designed, and False is a residue that is not designed.
            input_features["mask_for_loss"] (torch.Tensor): [B, L] - mask for
                loss, where True is a residue that is included in the loss
                calculation, and False is a residue that is not included in the
                loss calculation.
        """
        # Check that the input features contain the necessary keys.
        if "residue_mask" not in input_features:
            raise ValueError("Input features must contain 'residue_mask' key.")
        if "S" not in input_features:
            raise ValueError("Input features must contain 'S' key.")
        if "designed_residue_mask" not in input_features:
            raise ValueError("Input features must contain 'designed_residue_mask' key.")

        # Ensure that the residue_mask is a boolean.
        input_features["residue_mask"] = input_features["residue_mask"].bool()

        # Mask is true for canonical residues, false for unknown residues.
        input_features["known_residue_mask"] = self.construct_known_residue_mask(
            input_features["S"]
        )

        # Mask for residues that are designed. If the designed_residue_mask
        # is None, then we assume that all valid residues are designed.
        if input_features["designed_residue_mask"] is None:
            input_features["designed_residue_mask"] = input_features[
                "residue_mask"
            ].clone()
        else:
            input_features["designed_residue_mask"] = input_features[
                "designed_residue_mask"
            ].bool()

        # Check that the designed_residue_mask is a subset of valid residues.
        if not torch.all(
            input_features["residue_mask"][input_features["designed_residue_mask"]]
        ):
            raise ValueError("Designed residues must all be valid residues.")

        # Mask for loss.
        input_features["mask_for_loss"] = (
            input_features["residue_mask"]
            & input_features["known_residue_mask"]
            & input_features["designed_residue_mask"]
        )

    def graph_featurization(self, input_features):
        """
        Apply the graph featurization to the input features.

        Args:
            input_features (dict): Input features to be featurized.
        Returns:
            graph_features (dict): Featurized graph features (contains both node
                and edge features).
        """
        graph_features = self.graph_featurization_module(input_features)

        return graph_features

    def encode(self, input_features, graph_features):
        """
        Encode the protein features with message passing.

        # NOTE: K = self.num_neighbors
        Args:
            input_features (dict): Input features containing the residue mask.
                - residue_mask (torch.Tensor): [B, L] - mask for the residues.
            graph_features (dict): Graph features containing the featurized
                node and edge inputs.
                - E (torch.Tensor): [B, L, K, self.num_edge_features] - edge
                    features.
                - E_idx (torch.Tensor): [B, L, K] - edge indices.
                - V (torch.Tensor, optional): [B, L, self.num_node_features] -
                    node features (if HAS_NODE_FEATURES is True).
        Returns:
            encoder_features (dict): Encoded features containing the encoded
                protein node and protein edge features.
                - h_V (torch.Tensor): [B, L, self.hidden_dim] - the protein node
                    features after encoding message passing.
                - h_E (torch.Tensor): [B, L, K, self.hidden_dim] - the protein
                    edge features after encoding message passing.
        """
        # Check that the input features contains the necessary keys.
        if "residue_mask" not in input_features:
            raise ValueError("Input features must contain 'residue_mask' key.")

        # Check that the graph features contains the necessary keys.
        if "E" not in graph_features:
            raise ValueError("Graph features must contain 'E' key.")
        if "E_idx" not in graph_features:
            raise ValueError("Graph features must contain 'E_idx' key.")

        B, L, _, _ = graph_features["E"].shape

        # Embed the node features.
        # h_V [B, L, self.num_node_features] - the embedding of the node
        # features.
        if self.HAS_NODE_FEATURES:
            if "V" not in graph_features:
                raise ValueError("Graph features must contain 'V' key.")
            h_V = self.W_v(graph_features["V"])
        else:
            h_V = torch.zeros(
                (B, L, self.num_node_features), device=graph_features["E"].device
            )

        # Embed the edge features.
        # h_E [B, L, K, self.edge_features] - the embedding of the edge
        # features.
        h_E = self.W_e(graph_features["E"])

        # Gather the per-residue mask of the nearest neighbors.
        # mask_E [B, L, K] - the mask for the edges, gathered at the
        # neighbor indices.
        mask_E = gather_nodes(
            input_features["residue_mask"].unsqueeze(-1), graph_features["E_idx"]
        ).squeeze(-1)
        mask_E = input_features["residue_mask"].unsqueeze(-1) * mask_E

        # Perform the message passing in the encoder.
        for layer in self.encoder_layers:
            # h_V [B, L, self.hidden_dim] - the updated node features.
            # h_E [B, L, K, self.hidden_dim] - the updated edge features.
            h_V, h_E = torch.utils.checkpoint.checkpoint(
                layer,
                h_V,
                h_E,
                graph_features["E_idx"],
                mask_V=input_features["residue_mask"],
                mask_E=mask_E,
                use_reentrant=False,
            )

        # Create the encoder features dictionary.
        encoder_features = {
            "h_V": h_V,
            "h_E": h_E,
        }

        return encoder_features

    def setup_causality_masks(self, input_features, graph_features, decoding_eps=1e-4):
        """
        Setup the causality masks for the decoder. This can involve sampling
        the decoding order.

        Args:
            input_features (dict): Input features containing the residue mask.
                - residue_mask (torch.Tensor): [B, L] - mask for the residues.
                - designed_residue_mask (torch.Tensor): [B, L] - mask for the
                    designed residues.
                - symmetry_equivalence_group (torch.Tensor, optional): [B, L] -
                    an integer for every residue, indicating the symmetry group
                    that it belongs to. If None, the residues are not grouped by
                    symmetry. For example, if residue i and j should be decoded
                    symmetrically, then symmetry_equivalence_group[i] ==
                    symmetry_equivalence_group[j]. Must be torch.int64 to allow
                    for use as an index. These values should range from 0 to
                    the maximum number of symmetry groups - 1 for each example.
                - causality_pattern (str): The pattern of causality to use for
                    the decoder.
            graph_features (dict): Graph features containing the featurized
                node and edge inputs.
                - E_idx (torch.Tensor): [B, L, K] - edge indices.
            decoding_eps (float): Small epsilon value added to the
                decode_last_mask to prevent the case where every randomly
                sampled number is multiplied by zero, which would result
                in an incorrect decoding order.
        Returns:
            decoder_features (dict): Decoding features containing the decoding
                order and masks for the decoder.
                - causal_mask (torch.Tensor): [B, L, K, 1] - the causal mask for
                    the decoder.
                - anti_causal_mask (torch.Tensor): [B, L, K, 1] - the
                    anti-causal mask for the decoder.
                - decoding_order (torch.Tensor): [B, L] - the order in which the
                    residues should be decoded.
                - decode_last_mask (torch.Tensor): [B, L] - mask for residues
                    that should be decoded last, where False is a residue that
                    should be decoded first (invalid or fixed), and True is a
                    residue that should not be decoded first (designed
                    residues).
        """
        # Check that the input features contains the necessary keys.
        if "residue_mask" not in input_features:
            raise ValueError("Input features must contain 'residue_mask' key.")
        if "designed_residue_mask" not in input_features:
            raise ValueError("Input features must contain 'designed_residue_mask' key.")
        if "symmetry_equivalence_group" not in input_features:
            raise ValueError(
                "Input features must contain 'symmetry_equivalence_group' key."
            )
        if "causality_pattern" not in input_features:
            raise ValueError("Input features must contain 'causality_pattern' key.")

        # Check that the encoder features contains the necessary keys.
        if "E_idx" not in graph_features:
            raise ValueError("Graph features must contain 'E_idx' key.")

        B, L = input_features["residue_mask"].shape

        # decode_last_mask [B, L] - mask for residues that should be
        # decoded last, where False is a residue that should be decoded first
        # (invalid or fixed), and True is a residue that should not be decoded
        # first (designed residues).
        decode_last_mask = (
            input_features["residue_mask"] & input_features["designed_residue_mask"]
        ).bool()

        # Compute the noise for the decoding order.
        if input_features["symmetry_equivalence_group"] is None:
            # noise [B, L] - the noise for each residue, sampled from a normal
            # distribution. This is used to randomly sample the decoding order.
            noise = torch.randn((B, L), device=input_features["residue_mask"].device)
        else:
            # Assume that all symmetry groups are non-negative.
            assert input_features["symmetry_equivalence_group"].min() >= 0

            # Compute the maximum number of symmetry groups.
            G = int(input_features["symmetry_equivalence_group"].max().item()) + 1

            # noise_per_group [B, G] - the noise for each
            # symmetry group, sampled from a normal distribution.
            noise_per_group = torch.randn(
                (B, G), device=input_features["residue_mask"].device
            )

            # batch_idx [B, 1] - the batch indices for each example.
            batch_idx = torch.arange(
                B, device=input_features["residue_mask"].device
            ).unsqueeze(-1)

            # noise [B, L] - the noise for each residue, sampled from a normal
            # distribution, where the noise is the same for residues in the same
            # symmetry group.
            noise = noise_per_group[
                batch_idx, input_features["symmetry_equivalence_group"]
            ]

        # decoding_order [B, L] - the order in which the residues should be
        # decoded. Specifically, decoding_order[b, i] = j specifies that the
        # jth residue should be decoded ith. Sampled for every example.
        # Numbers will be smaller where decode_last_mask is False (0), and
        # larger where decode_last_mask is True (1), leading to the appropriate
        # index ordering after the argsort.
        decoding_order = torch.argsort(
            (decode_last_mask.float() + decoding_eps) * torch.abs(noise), dim=-1
        )

        # permutation_matrix_reverse [B, L, L] - the reverse permutation
        # matrix (the transpose/inverse of the permutation matrix) computed from
        # the decoding order; such that permutation_matrix_reverse[i, j] = 1 if
        # the ith entry in the original will be sent to the jth position (the
        # ith row/column in the original all by all causal mask will be sent
        # to the jth row/column in the permuted all by all causal mask).
        permutation_matrix_reverse = torch.nn.functional.one_hot(
            decoding_order, num_classes=L
        ).float()

        # Create the all by all causal mask for the decoder.
        # causal_mask_all_by_all [L, L] - the all by all causal mask for the
        # decoder, constructed based on the specified causality pattern.
        if input_features["causality_pattern"] == "auto_regressive":
            # left_to_right_causal_mask [L, L] - the causal mask for the
            # left-to-right attention (lower triangular with zeros on the
            # diagonal). Residue at position i can "see" residues at positions
            # j < i, but not at positions j >= i.
            left_to_right_causal_mask = 1 - torch.triu(
                torch.ones(L, L, device=input_features["residue_mask"].device)
            )

            causal_mask_all_by_all = left_to_right_causal_mask
        elif input_features["causality_pattern"] == "unconditional":
            # zeros_causal_mask [L, L] - the causal mask for the decoder,
            # where all entries are zeros. Residue at position i cannot see
            # any other residues, including itself.
            zeros_causal_mask = torch.zeros(
                (L, L), device=input_features["residue_mask"].device
            )

            causal_mask_all_by_all = zeros_causal_mask
        elif input_features["causality_pattern"] == "conditional":
            # ones_causal_mask [L, L] - the causal mask for the decoder,
            # where all entries are ones. Residue at position i can see all
            # other residues, including itself.
            ones_causal_mask = torch.ones(
                (L, L), device=input_features["residue_mask"].device
            )

            causal_mask_all_by_all = ones_causal_mask
        elif input_features["causality_pattern"] == "conditional_minus_self":
            # I [L, L] - the identity matrix, repeated along the batch.
            I = torch.eye(L, device=input_features["residue_mask"].device)

            # ones_minus_self_causal_mask [L, L] - the causal mask for the
            # decoder, where all entries are ones except for the diagonal
            # entries, which are zeros. Residue at position i can see all other
            # residues, but not itself.
            ones_minus_self_causal_mask = (
                torch.ones((L, L), device=input_features["residue_mask"].device) - I
            )

            causal_mask_all_by_all = ones_minus_self_causal_mask
        else:
            raise ValueError(
                "Unknown causality pattern: " + f"{input_features['causality_pattern']}"
            )

        # permuted_causal_mask_all_by_all [B, L, L] - the causal mask for the
        # decoder, permuted according to the decoding order.
        permuted_causal_mask_all_by_all = torch.einsum(
            "ij, biq, bjp->bqp",
            causal_mask_all_by_all,
            permutation_matrix_reverse,
            permutation_matrix_reverse,
        )

        # If the symmetry equivalence group is not None, then we need to
        # mask out residues that belong to the same symmetry group.
        if input_features["symmetry_equivalence_group"] is not None:
            # same_symmetry_group [B, L, L] - a mask for the residues that
            # belong to the same symmetry group, where True is a residue pair
            # that belongs to the same symmetry group, and False is a residue
            # pair that does not belong to the same symmetry group.
            same_symmetry_group = (
                input_features["symmetry_equivalence_group"][:, :, None]
                == input_features["symmetry_equivalence_group"][:, None, :]
            )

            permuted_causal_mask_all_by_all[same_symmetry_group] = 0.0

        # causal_mask_nearest_neighbors [B, L, K, 1] - the causal mask for
        # the decoder, gathered at the neighbor indices. This limits the
        # attention to the nearest neighbors.
        causal_mask_nearest_neighbors = torch.gather(
            permuted_causal_mask_all_by_all, 2, graph_features["E_idx"]
        ).unsqueeze(-1)

        # causal_mask [B, L, K, 1] - the final causal mask for the decoder;
        # masked version of causal_mask_nearest_neighbors.
        causal_mask = (
            causal_mask_nearest_neighbors
            * input_features["residue_mask"].view([B, L, 1, 1]).float()
        )

        # anti_causal_mask [B, L, K, 1] - the anti-causal mask for the decoder.
        anti_causal_mask = (1.0 - causal_mask_nearest_neighbors) * input_features[
            "residue_mask"
        ].view([B, L, 1, 1]).float()

        # Add the masks to the decoder features.
        decoder_features = {
            "causal_mask": causal_mask,
            "anti_causal_mask": anti_causal_mask,
            "decoding_order": decoding_order,
            "decode_last_mask": decode_last_mask,
        }

        return decoder_features

    def repeat_along_batch(self, input_features, graph_features, encoder_features):
        """
        Given the input features, graph features, and encoder features,
        repeat the samples along the batch dimension. This is useful during
        inference, to prevent re-running the encoder for every sample (since
        the encoder is deterministic and sequence-agnostic).

        NOTE: if `repeat_sample_num` is not None and greater than 1, then
        B must be 1, since repeating samples along the batch dimension is not
        supported when more than one sample is provided in the batch.

        Args:
            input_features (dict): Input features containing the residue mask
                and sequence.
                - S (torch.Tensor): [B, L] - sequence of residues.
                - residue_mask (torch.Tensor): [B, L] - mask for the residues.
                - temperature (torch.Tensor, optional): [B, L] - the per-residue
                    temperature to use for sampling. If None, the code will
                    implicitly use a temperature of 1.0.
                - bias (torch.Tensor, optional): [B, L, 21] - the per-residue
                    bias to use for sampling. If None, the code will implicitly
                    use a bias of 0.0 for all residues.
                - pair_bias (torch.Tensor, optional): [B, L, 21, L, 21] - the
                    per-residue pair bias to use for sampling. If None, the code
                    will implicitly use a pair bias of 0.0 for all residue
                    pairs.
                - symmetry_equivalence_group (torch.Tensor, optional): [B, L] -
                    an integer for every residue, indicating the symmetry group
                    that it belongs to. If None, the residues are not grouped by
                    symmetry. For example, if residue i and j should be decoded
                    symmetrically, then symmetry_equivalence_group[i] ==
                    symmetry_equivalence_group[j]. Must be torch.int64 to allow
                    for use as an index. These values should range from 0 to
                    the maximum number of symmetry groups - 1 for each example.
                - symmetry_weight (torch.Tensor, optional): [B, L] - the weight
                    for the symmetry equivalence group. If None, the code will
                    implicitly use a weight of 1.0 for all residues.
                - repeat_sample_num (int, optional): Number of times to repeat
                    the samples along the batch dimension. If None, no
                    repetition is performed. If greater than 1, the samples
                    are repeated along the batch dimension. If greater than 1,
                    B must be 1, since repeating samples along the batch
                    dimension is not supported when more than one sample is
                    provided in the batch.
            graph_features (dict): Graph features containing the featurized
                node and edge inputs.
                - E_idx (torch.Tensor): [B, L, K] - edge indices.
            encoder_features (dict): Encoder features containing the encoded
                protein node and protein edge features.
                - h_V (torch.Tensor): [B, L, H] - the protein node features
                    after encoding message passing.
                - h_E (torch.Tensor): [B, L, K, H] - the protein edge features
                    after encoding message passing.
        Side Effects:
            input_features["S"] (torch.Tensor): [repeat_sample_num, L] - the
                sequence of residues, repeated along the batch dimension.
            input_features["residue_mask"] (torch.Tensor):
                [repeat_sample_num, L] - the mask for the residues, repeated
                along the batch dimension.
            input_features["mask_for_loss"] (torch.Tensor):
                [repeat_sample_num, L] - the mask for the loss, repeated
                along the batch dimension.
            input_features["designed_residue_mask"] (torch.Tensor):
                [repeat_sample_num, L] - the mask for designed residues,
                repeated along the batch dimension.
            input_features["temperature"] (torch.Tensor, optional):
                [repeat_sample_num, L] - the per-residue temperature to use for
                sampling, repeated along the batch dimension. If None, the code
                will implicitly use a temperature of 1.0.
            input_features["bias"] (torch.Tensor, optional):
                [repeat_sample_num, L, 21] - the per-residue bias to use for
                sampling, repeated along the batch dimension. If None, the code
                will implicitly use a bias of 0.0 for all residues.
            input_features["pair_bias"] (torch.Tensor, optional):
                [repeat_sample_num, L, 21, L, 21] - the per-residue pair bias
                to use for sampling, repeated along the batch dimension. If
                None, the code will implicitly use a pair bias of 0.0 for all
                residue pairs.
            input_features["symmetry_equivalence_group"] (torch.Tensor,
                optional): [repeat_sample_num, L] - the symmetry equivalence
                group for each residue, repeated along the batch dimension.
            input_features["symmetry_weight"] (torch.Tensor, optional):
                [repeat_sample_num, L] - the symmetry weight for each residue,
                repeated along the batch dimension. If None, the code will
                implicitly use a weight of 1.0 for all residues.
            encoder_features["h_V"] (torch.Tensor): [repeat_sample_num, L, H] -
                the protein node features after encoding message passing,
                repeated along the batch dimension.
            encoder_features["h_E"] (torch.Tensor):
                [repeat_sample_num, L, K, H] - the protein edge features
                after encoding message passing, repeated along the batch
                dimension.
            graph_features["E_idx"] (torch.Tensor): [repeat_sample_num, L, K] -
                the edge indices, repeated along the batch dimension.
        """
        # Check that the input features contains the necessary keys.
        if "S" not in input_features:
            raise ValueError("Input features must contain 'S' key.")
        if "residue_mask" not in input_features:
            raise ValueError("Input features must contain 'residue_mask' key.")
        if "mask_for_loss" not in input_features:
            raise ValueError("Input features must contain 'mask_for_loss' key.")
        if "temperature" not in input_features:
            raise ValueError("Input features must contain 'temperature' key.")
        if "bias" not in input_features:
            raise ValueError("Input features must contain 'bias' key.")
        if "pair_bias" not in input_features:
            raise ValueError("Input features must contain 'pair_bias' key.")
        if "symmetry_equivalence_group" not in input_features:
            raise ValueError(
                "Input features must contain 'symmetry_equivalence_group' key."
            )
        if "symmetry_weight" not in input_features:
            raise ValueError("Input features must contain 'symmetry_weight' key.")
        if "repeat_sample_num" not in input_features:
            raise ValueError("Input features must contain 'repeat_sample_num' key.")

        # Check that the graph features contains the necessary keys.
        if "E_idx" not in graph_features:
            raise ValueError("Graph features must contain 'E_idx' key.")

        # Check that the encoder features contains the necessary keys.
        if "h_V" not in encoder_features:
            raise ValueError("Encoder features must contain 'h_V' key.")
        if "h_E" not in encoder_features:
            raise ValueError("Encoder features must contain 'h_E' key.")

        # Repeating a sample along the batch dimension is not supported
        # when more than one sample is provided in the batch.
        if (
            input_features["repeat_sample_num"] is not None
            and input_features["repeat_sample_num"] > 1
            and input_features["S"].shape[0] > 1
        ):
            raise ValueError(
                "Cannot repeat samples when more than one sample "
                + "is provided in the batch."
            )

        # Repeat the samples along the batch dimension if necessary.
        if (
            input_features["repeat_sample_num"] is not None
            and input_features["repeat_sample_num"] > 1
        ):
            # S [repeat_sample_num, L]
            input_features["S"] = input_features["S"][0].repeat(
                input_features["repeat_sample_num"], 1
            )
            # residue_mask [repeat_sample_num, L]
            input_features["residue_mask"] = input_features["residue_mask"][0].repeat(
                input_features["repeat_sample_num"], 1
            )
            # mask_for_loss [repeat_sample_num, L]
            input_features["mask_for_loss"] = input_features["mask_for_loss"][0].repeat(
                input_features["repeat_sample_num"], 1
            )
            # designed_residue_mask [repeat_sample_num, L]
            input_features["designed_residue_mask"] = input_features[
                "designed_residue_mask"
            ][0].repeat(input_features["repeat_sample_num"], 1)
            if input_features["temperature"] is not None:
                # temperature [repeat_sample_num, L]
                input_features["temperature"] = input_features["temperature"][0].repeat(
                    input_features["repeat_sample_num"], 1
                )
            if input_features["bias"] is not None:
                # bias [repeat_sample_num, L, 21]
                input_features["bias"] = input_features["bias"][0].repeat(
                    input_features["repeat_sample_num"], 1, 1
                )
            if input_features["pair_bias"] is not None:
                # pair_bias [repeat_sample_num, L, 21, L, 21]
                input_features["pair_bias"] = input_features["pair_bias"][0].repeat(
                    input_features["repeat_sample_num"], 1, 1, 1, 1
                )
            if input_features["symmetry_equivalence_group"] is not None:
                # symmetry_equivalence_group [repeat_sample_num, L]
                input_features["symmetry_equivalence_group"] = input_features[
                    "symmetry_equivalence_group"
                ][0].repeat(input_features["repeat_sample_num"], 1)
            if input_features["symmetry_weight"] is not None:
                # symmetry_weight [repeat_sample_num, L]
                input_features["symmetry_weight"] = input_features["symmetry_weight"][
                    0
                ].repeat(input_features["repeat_sample_num"], 1)

            # h_V [repeat_sample_num, L, H]
            encoder_features["h_V"] = encoder_features["h_V"][0].repeat(
                input_features["repeat_sample_num"], 1, 1
            )
            # h_E [repeat_sample_num, L, K, H]
            encoder_features["h_E"] = encoder_features["h_E"][0].repeat(
                input_features["repeat_sample_num"], 1, 1, 1
            )

            # E_idx [repeat_sample_num, L, K]
            graph_features["E_idx"] = graph_features["E_idx"][0].repeat(
                input_features["repeat_sample_num"], 1, 1
            )

    def decode_setup(
        self, input_features, graph_features, encoder_features, decoder_features
    ):
        """
        Given the input features, graph features, encoder features, and initial
        decoder features, set up the decoder for the autoregressive decoding.

        Args:
            input_features (dict): Input features containing the residue mask
                and sequence.
                - S (torch.Tensor): [B, L] - sequence of residues.
                - residue_mask (torch.Tensor): [B, L] - mask for the residues.
                - initialize_sequence_embedding_with_ground_truth (bool):
                    If True, initialize the sequence embedding with the ground
                    truth sequence S. Else, initialize the sequence
                    embedding with zeros.
            graph_features (dict): Graph features containing the featurized
                node and edge inputs.
                - E_idx (torch.Tensor): [B, L, K] - edge indices.
            encoder_features (dict): Encoder features containing the encoded
                protein node and protein edge features.
                - h_V (torch.Tensor): [B, L, H] - the protein node features
                    after encoding message passing.
                - h_E (torch.Tensor): [B, L, K, H] - the protein edge features
                    after encoding message passing.
            decoder_features (dict): Initial decoder features containing the
                anti-causal mask for the decoder.
                - anti_causal_mask (torch.Tensor): [B, L, K, 1] - the
                    anti-causal mask for the decoder.
        Returns:
            h_EXV_encoder_anti_causal (torch.Tensor): [B, L, K, 3H] - the
                encoder embeddings masked with the anti-causal mask.
            mask_E (torch.Tensor): [B, L, K] - the mask for the edges, gathered
                at the neighbor indices.
            h_S (torch.Tensor): [B, L, H] - the sequence embeddings for the
                decoder.
        """
        # Check that the input features contains the necessary keys.
        if "S" not in input_features:
            raise ValueError("Input features must contain 'S' key.")
        if "residue_mask" not in input_features:
            raise ValueError("Input features must contain 'residue_mask' key.")
        if "initialize_sequence_embedding_with_ground_truth" not in input_features:
            raise ValueError(
                "Input features must contain "
                + "'initialize_sequence_embedding_with_ground_truth' key."
            )

        # Check that the graph features contains the necessary keys.
        if "E_idx" not in graph_features:
            raise ValueError("Graph features must contain 'E_idx' key.")

        # Check that the encoder features contains the necessary keys.
        if "h_V" not in encoder_features:
            raise ValueError("Encoder features must contain 'h_V' key.")
        if "h_E" not in encoder_features:
            raise ValueError("Encoder features must contain 'h_E' key.")

        # Check that the decoder features contains the necessary keys.
        if "anti_causal_mask" not in decoder_features:
            raise ValueError("Decoder features must contain 'anti_causal_mask' key.")

        # Build encoder embeddings.
        # h_EX_encoder [B, L, K, 2H] - h_E_ij cat (0 vector); the edge features
        # concatenated with the zero vector, since there is no sequence
        # information in the encoder.
        h_EX_encoder = cat_neighbors_nodes(
            torch.zeros_like(encoder_features["h_V"]),
            encoder_features["h_E"],
            graph_features["E_idx"],
        )

        # h_EXV_encoder [B, L, K, 3H] - h_E_ij cat (0 vector) cat h_V_j; the
        # edge features concatenated with the zero vector and the destination
        # node features from the encoder.
        h_EXV_encoder = cat_neighbors_nodes(
            encoder_features["h_V"], h_EX_encoder, graph_features["E_idx"]
        )

        # h_EXV_encoder_anti_causal [B, L, K, 3H] - the encoder embeddings,
        # masked with the anti-causal mask.
        h_EXV_encoder_anti_causal = h_EXV_encoder * decoder_features["anti_causal_mask"]

        # Gather the per-residue mask of the nearest neighbors.
        # mask_E [B, L, K] - the mask for the edges, gathered at the
        # neighbor indices.
        mask_E = gather_nodes(
            input_features["residue_mask"].unsqueeze(-1), graph_features["E_idx"]
        ).squeeze(-1)
        mask_E = input_features["residue_mask"].unsqueeze(-1) * mask_E

        # Build sequence embedding for the decoder.
        # h_S [B, L, H] - the sequence embeddings for the decoder, obtained by
        # embedding the ground truth sequence S.
        if input_features["initialize_sequence_embedding_with_ground_truth"]:
            h_S = self.W_s(input_features["S"])
        else:
            h_S = torch.zeros_like(encoder_features["h_V"])

        return h_EXV_encoder_anti_causal, mask_E, h_S

    def logits_to_sample(self, logits, bias, pair_bias, S_for_pair_bias, temperature):
        """
        Convert the logits to log probabilities, probabilities, sampled
        probabilities, predicted sequence, and argmax sequence.

        Args:
            logits (torch.Tensor): [B, L, self.vocab_size] - the logits for the
                sequence.
            bias (torch.Tensor, optional): [B, L, self.vocab_size] - the
                bias for the sequence. If None, the code will implicitly use a
                bias of 0.0 for all residues.
            pair_bias (torch.Tensor, optional): [B, L, self.vocab_size, L',
                self.vocab_size] - the pair bias for the sequence. Note,
                L is the length for the logits, and L' is the length for the
                S_for_pair_bias. In some cases, L' may be different from L
                (for example, when the logits are only computed for a subset of
                residues). If None, the code will implicitly use a pair bias
                of 0.0 for all residue pairs.
            S_for_pair_bias (torch.Tensor, optional): [B, L'] - the sequence for
                the pair bias. This is used to compute the total pair bias for
                every position. Allowed to be None if pair_bias is None.
            temperature (torch.Tensor, optional): [B, L] - the per-residue
                temperature to use for sampling. If None, the code will
                implicitly use a temperature of 1.0 for all residues.
        Returns:
            sample_dict (dict): A dictionary containing the following keys:
                - log_probs (torch.Tensor): [B, L, self.vocab_size] - the log
                    probabilities for the sequence.
                - probs (torch.Tensor): [B, L, self.vocab_size] - the
                    probabilities for the sequence.
                - probs_sample (torch.Tensor): [B, L, self.vocab_size] -
                    the probabilities for the sequence, with the unknown
                    residues zeroed out and the other residues normalized.
                - S_sampled (torch.Tensor): [B, L] - the predicted sequence,
                    sampled from the probabilities (unknown residues are not
                    sampled).
                - S_argmax (torch.Tensor): [B, L] - the predicted sequence,
                    obtained by taking the argmax of the probabilities (unknown
                    residues are not selected).
        """
        B, L, vocab_size = logits.shape

        if pair_bias is not None:
            # pair_bias_total [B, L, self.vocab_size] - the total pair bias to
            # add to the sequence logits, computed for every residue by
            # indexing the pair bias with the sequence (S_for_pair_bias) and
            # summing over the second sequence dimension (L').
            pair_bias_total = torch.gather(
                pair_bias,
                -1,
                S_for_pair_bias[:, None, None, :, None].expand(
                    -1, -1, self.vocab_size, -1, -1
                ),
            ).sum(dim=(-2, -1))
        else:
            pair_bias_total = None

        # modified_logits [B, L, self.vocab_size] - the logits for the
        # sequence, modified by temperature, bias, and total pair bias.
        modified_logits = (
            logits
            + (0.0 if bias is None else bias)
            + (0.0 if pair_bias_total is None else pair_bias_total)
        )
        modified_logits = modified_logits / (
            1.0 if temperature is None else temperature.unsqueeze(-1)
        )

        # log_probs [B, L, self.vocab_size] - the log probabilities for the
        # sequence.
        log_probs = torch.nn.functional.log_softmax(modified_logits, dim=-1)

        # probs [B, L, self.vocab_size] - the probabilities for the sequence.
        probs = torch.nn.functional.softmax(modified_logits, dim=-1)

        # probs_sample [B, L, self.vocab_size] - the probabilities for the
        # sequence, with the unknown residues zeroed out and the other residues
        # normalized.
        probs_sample = probs.clone()
        probs_sample[:, :, self.unknown_token_indices] = 0.0
        probs_sample = probs_sample / torch.sum(probs_sample, dim=-1, keepdim=True)

        # probs_sample_flat [B * L, self.vocab_size] - the flattened
        # probabilities for the sequence.
        probs_sample_flat = probs_sample.view(B * L, vocab_size)

        # S_sampled [B, L] - the predicted sequence, sampled from the
        # probabilities (unknown residues are not sampled).
        S_sampled = torch.multinomial(probs_sample_flat, 1).squeeze(-1).view(B, L)

        # S_argmax [B, L] - the predicted sequence, obtained by taking the
        # argmax of the probabilities (unknown residues are not selected).
        S_argmax = torch.argmax(probs_sample, dim=-1)

        sample_dict = {
            "log_probs": log_probs,
            "probs": probs,
            "probs_sample": probs_sample,
            "S_sampled": S_sampled,
            "S_argmax": S_argmax,
        }

        return sample_dict

    def decode_teacher_forcing(
        self, input_features, graph_features, encoder_features, decoder_features
    ):
        """
        Given the input features, graph features, encoder features, and
        decoder features, perform the decoding with teacher forcing.

        Although h_S is computed from the ground truth sequence S, the causal
        mask will ensure that the decoder only attends to the sequence of
        previously decoded residues. Using the ground truth for all previous
        residues is called teacher forcing, and it is a common technique in
        language modeling tasks.

        Args:
            input_features (dict): Input features containing the residue mask
                and sequence.
                - S (torch.Tensor): [B, L] - sequence of residues.
                - residue_mask (torch.Tensor): [B, L] - mask for the residues.
                - bias (torch.Tensor, optional): [B, L, self.vocab_size] - the
                    per-residue bias to use for sampling. If None, the code will
                    implicitly use a bias of 0.0 for all residues.
                - pair_bias (torch.Tensor, optional): [B, L, self.vocab_size,
                    L, self.vocab_size] - the per-residue pair bias to use
                    for sampling. If None, the code will implicitly use a pair
                    bias of 0.0 for all residue pairs.
                - temperature (torch.Tensor, optional): [B, L] - the per-residue
                    temperature to use for sampling. If None, the code will
                    implicitly use a temperature of 1.0.
                - initialize_sequence_embedding_with_ground_truth (bool):
                    If True, initialize the sequence embedding with the ground
                    truth sequence S. Else, initialize the sequence
                    embedding with zeros.
                - symmetry_equivalence_group (torch.Tensor, optional): [B, L] -
                    an integer for every residue, indicating the symmetry group
                    that it belongs to. If None, the residues are not grouped by
                    symmetry. For example, if residue i and j should be decoded
                    symmetrically, then symmetry_equivalence_group[i] ==
                    symmetry_equivalence_group[j]. Must be torch.int64 to allow
                    for use as an index. These values should range from 0 to
                    the maximum number of symmetry groups - 1 for each example.
                - symmetry_weight (torch.Tensor, optional): [B, L] - the weights
                    for each residue, to be used when aggregating across its
                    respective symmetry group. If None, the weights are
                    assumed to be 1.0 for all residues.
            graph_features (dict): Graph features containing the featurized
                node and edge inputs.
                - E_idx (torch.Tensor): [B, L, K] - edge indices.
            encoder_features (dict): Encoder features containing the encoded
                protein node and protein edge features.
                - h_V (torch.Tensor): [B, L, H] - the protein node
                    features after encoding message passing.
                - h_E (torch.Tensor): [B, L, K, H] - the
                    protein edge features after encoding message passing.
            decoder_features (dict): Initial decoder features containing the
                causal mask for the decoder.
                - causal_mask (torch.Tensor): [B, L, K, 1] - the
                    causal mask for the decoder.
                - anti_causal_mask (torch.Tensor): [B, L, K, 1] - the
                    anti-causal mask for the decoder.
        Side Effects:
            decoder_features["h_V"] (torch.Tensor): [B, L, H] - the updated
                node features for the decoder.
            decoder_features["logits"] (torch.Tensor): [B, L, self.vocab_size] -
                the sequence logits for the decoder.
            decoder_features["log_probs"] (torch.Tensor): [B, L,
                self.vocab_size] - the log probabilities for the sequence.
            decoder_features["probs"] (torch.Tensor): [B, L, self.vocab_size] -
                the probabilities for the sequence.
            decoder_features["probs_sample"] (torch.Tensor): [B, L,
                self.vocab_size] - the probabilities for the sequence, with the
                unknown residues zeroed out and the other residues normalized.
            decoder_features["S_sampled"] (torch.Tensor): [B, L] - the
                predicted sequence, sampled from the probabilities (unknown
                residues are not sampled).
            decoder_features["S_argmax"] (torch.Tensor): [B, L] - the predicted
                sequence, obtained by taking the argmax of the probabilities
                (unknown residues are not selected).
        """
        # Check that the input features contains the necessary keys.
        if "residue_mask" not in input_features:
            raise ValueError("Input features must contain 'residue_mask' key.")
        if "S" not in input_features:
            raise ValueError("Input features must contain 'S' key.")
        if "bias" not in input_features:
            raise ValueError("Input features must contain 'bias' key.")
        if "pair_bias" not in input_features:
            raise ValueError("Input features must contain 'pair_bias' key.")
        if "temperature" not in input_features:
            raise ValueError("Input features must contain 'temperature' key.")
        if "symmetry_equivalence_group" not in input_features:
            raise ValueError(
                "Input features must contain 'symmetry_equivalence_group' key."
            )
        if "symmetry_weight" not in input_features:
            raise ValueError("Input features must contain 'symmetry_weight' key.")

        # Check that the graph features contains the necessary keys.
|
|
1055
|
+
if "E_idx" not in graph_features:
|
|
1056
|
+
raise ValueError("Graph features must contain 'E_idx' key.")
|
|
1057
|
+
|
|
1058
|
+
# Check that the encoder features contains the necessary keys.
|
|
1059
|
+
if "h_V" not in encoder_features:
|
|
1060
|
+
raise ValueError("Encoder features must contain 'h_V' key.")
|
|
1061
|
+
if "h_E" not in encoder_features:
|
|
1062
|
+
raise ValueError("Encoder features must contain 'h_E' key.")
|
|
1063
|
+
|
|
1064
|
+
# Check that the decoder features contains the necessary keys.
|
|
1065
|
+
if "causal_mask" not in decoder_features:
|
|
1066
|
+
raise ValueError("Decoder features must contain 'causal_mask' key.")
|
|
1067
|
+
|
|
1068
|
+
# Do the setup for the decoder.
|
|
1069
|
+
# h_EXV_encoder_anti_causal [B, L, K, 3H] - the encoder embeddings,
|
|
1070
|
+
# masked with the anti-causal mask.
|
|
1071
|
+
# mask_E [B, L, K] - the mask for the edges, gathered at the
|
|
1072
|
+
# neighbor indices.
|
|
1073
|
+
# h_S [B, L, H] - the sequence embeddings for the decoder.
|
|
1074
|
+
h_EXV_encoder_anti_causal, mask_E, h_S = self.decode_setup(
|
|
1075
|
+
input_features, graph_features, encoder_features, decoder_features
|
|
1076
|
+
)
|
|
1077
|
+
|
|
1078
|
+
# h_ES [B, L, K, 2H] - h_E_ij cat h_S_j; the edge features concatenated
|
|
1079
|
+
# with the sequence embeddings for the destination nodes.
|
|
1080
|
+
h_ES = cat_neighbors_nodes(
|
|
1081
|
+
h_S, encoder_features["h_E"], graph_features["E_idx"]
|
|
1082
|
+
)
|
|
1083
|
+
|
|
1084
|
+
# Run the decoder layers.
|
|
1085
|
+
h_V_decoder = encoder_features["h_V"]
|
|
1086
|
+
for layer in self.decoder_layers:
|
|
1087
|
+
# h_ESV_decoder [B, L, K, 3H] - h_E_ij cat h_S_j cat h_V_decoder_j;
|
|
1088
|
+
# for the decoder embeddings, the edge features are concatenated
|
|
1089
|
+
# with the destination node sequence embeddings and node features.
|
|
1090
|
+
h_ESV_decoder = cat_neighbors_nodes(
|
|
1091
|
+
h_V_decoder, h_ES, graph_features["E_idx"]
|
|
1092
|
+
)
|
|
1093
|
+
|
|
1094
|
+
# h_ESV [B, L, K, 3H] - the encoder and decoder embeddings,
|
|
1095
|
+
# combined according to the causal and anti-causal masks.
|
|
1096
|
+
# Combine the encoder embeddings with the decoder embeddings,
|
|
1097
|
+
# using the causal and anti-causal masks. When decoding the residue
|
|
1098
|
+
# at position i:
|
|
1099
|
+
# - for residue j, decoded before i:
|
|
1100
|
+
# - h_ESV_ij = h_E_ij cat h_S_j cat h_V_decoder_j
|
|
1101
|
+
# - encoder edge embedding, decoder destination node
|
|
1102
|
+
# sequence embedding, and decoder destination node
|
|
1103
|
+
# embedding.
|
|
1104
|
+
# - for residue j, decoded after i (including i):
|
|
1105
|
+
# - h_ESV_ij = h_E_ij cat (0 vector) cat h_V_j
|
|
1106
|
+
# - encoder edge embedding, zero vector (no sequence
|
|
1107
|
+
# information), and encoder destination node embedding.
|
|
1108
|
+
# This prevents leakage of sequence information.
|
|
1109
|
+
# - NOTE: h_V_j comes from the encoder.
|
|
1110
|
+
# - NOTE: h_E is not updated in the decoder, h_E_ij comes from
|
|
1111
|
+
# the encoder.
|
|
1112
|
+
# - NOTE: within the decoder layer itself, h_V_decoder_i will
|
|
1113
|
+
# be concatenated to h_ESV_ij.
|
|
1114
|
+
h_ESV = (
|
|
1115
|
+
decoder_features["causal_mask"] * h_ESV_decoder
|
|
1116
|
+
+ h_EXV_encoder_anti_causal
|
|
1117
|
+
)
|
|
1118
|
+
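            # In other words, the causal and anti-causal masks are effectively
            # complementary selectors over the K neighbors: an edge to a
            # neighbor decoded before i keeps the sequence-aware decoder
            # embedding, while an edge to a neighbor decoded at or after i
            # falls back to the sequence-free encoder embedding.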

            # h_V_decoder [B, L, H] - the updated node features for the
            # decoder.
            h_V_decoder = torch.utils.checkpoint.checkpoint(
                layer,
                h_V_decoder,
                h_ESV,
                mask_V=input_features["residue_mask"],
                mask_E=mask_E,
                use_reentrant=False,
            )

        # logits [B, L, self.vocab_size] - project the final node features to
        # get the sequence logits.
        logits = self.W_out(h_V_decoder)

        # Handle symmetry equivalence groups if they are provided, performing
        # a (possibly weighted) sum of the logits across residues that
        # belong to the same symmetry group.
        if input_features["symmetry_equivalence_group"] is not None:
            # Assume that all symmetry groups are non-negative.
            assert input_features["symmetry_equivalence_group"].min() >= 0

            B, L, _ = logits.shape

            # The maximum number of symmetry groups in the batch.
            G = (input_features["symmetry_equivalence_group"].max().item()) + 1

            # symmetry_equivalence_group_one_hot [B, L, G] - one-hot encoding
            # of the symmetry equivalence group for each residue.
            symmetry_equivalence_group_one_hot = torch.nn.functional.one_hot(
                input_features["symmetry_equivalence_group"], num_classes=G
            ).float()

            # scaled_symmetry_equivalence_group_one_hot [B, L, G] - the one-hot
            # encoding of the symmetry equivalence group, scaled by the
            # symmetry weights, if they are provided. If not provided, the
            # symmetry weights are implicitly assumed to be 1.0 for all
            # residues.
            scaled_symmetry_equivalence_group_one_hot = (
                symmetry_equivalence_group_one_hot
                * (
                    1.0
                    if input_features["symmetry_weight"] is None
                    else input_features["symmetry_weight"].unsqueeze(-1)
                )
            )

            # weighted_sum_logits [B, G, self.vocab_size] - the logits for the
            # sequence, summed across the symmetry groups, weighted by the
            # symmetry weights.
            weighted_sum_logits = torch.einsum(
                "blg,blv->bgv", scaled_symmetry_equivalence_group_one_hot, logits
            )

            # logits [B, L, self.vocab_size] - overwrite the original logits
            # with the weighted and summed logits for the residues that belong
            # to the same symmetry group.
            logits = torch.einsum(
                "blg,bgv->blv", symmetry_equivalence_group_one_hot, weighted_sum_logits
            )
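            # Worked example of the two einsums above: with one batch element,
            # L = 3, residues 0 and 2 in group 0, residue 1 in group 1, and all
            # weights 1.0, the first einsum produces the per-group sums
            # weighted_sum_logits[0] = logits[0] + logits[2] and
            # weighted_sum_logits[1] = logits[1]; the second einsum then copies
            # each group sum back to every member, so residues 0 and 2 end up
            # with identical logits.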

        # Compute the log probabilities, probabilities, sampled probabilities,
        # predicted sequence, and argmax sequence.
        sample_dict = self.logits_to_sample(
            logits,
            input_features["bias"],
            input_features["pair_bias"],
            input_features["S"],
            input_features["temperature"],
        )

        # All outputs from logits_to_sample should be the same across
        # symmetry equivalence groups, except for S_sampled (due to the
        # sampling being stochastic). Correct for this by overwriting S_sampled
        # with the first sampled residue in each group.
        if input_features["symmetry_equivalence_group"] is not None:
            S_sampled = sample_dict["S_sampled"]

            B, L = S_sampled.shape

            # Compute the maximum number of symmetry groups in the batch.
            G = (input_features["symmetry_equivalence_group"].max().item()) + 1

            for b in range(B):
                for g in range(G):
                    # group_mask [L] - mask for the residues that belong to
                    # the symmetry equivalence group g for the batch example b.
                    group_mask = input_features["symmetry_equivalence_group"][b] == g

                    # If there are residues in the group, set every S_sampled
                    # in the group to the first S_sampled in the group.
                    if group_mask.any():
                        first = torch.where(group_mask)[0][0]
                        S_sampled[b, group_mask] = S_sampled[b, first]

            sample_dict["S_sampled"] = S_sampled

        # Update the decoder features with the final node features, the computed
        # logits, log probabilities, probabilities, sampled probabilities,
        # predicted sequence, and argmax sequence.
        decoder_features["h_V"] = h_V_decoder
        decoder_features["logits"] = logits
        decoder_features["log_probs"] = sample_dict["log_probs"]
        decoder_features["probs"] = sample_dict["probs"]
        decoder_features["probs_sample"] = sample_dict["probs_sample"]
        decoder_features["S_sampled"] = sample_dict["S_sampled"]
        decoder_features["S_argmax"] = sample_dict["S_argmax"]
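        # NOTE: teacher forcing decodes every position in a single pass over
        # the decoder layers, whereas decode_auto_regressive below revisits the
        # decoder layers once per residue in decoding order; both paths write
        # the same set of keys with the same shapes into decoder_features.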

    def decode_auto_regressive(
        self, input_features, graph_features, encoder_features, decoder_features
    ):
        """
        Given the input features, graph features, encoder features, and
        decoder features, perform the autoregressive decoding.

        Args:
            input_features (dict): Input features containing the residue mask
                and sequence.
                - S (torch.Tensor): [B, L] - sequence of residues.
                - residue_mask (torch.Tensor): [B, L] - mask for the residues.
                - bias (torch.Tensor, optional): [B, L, self.vocab_size] -
                    the per-residue bias to use for sampling. If None, the
                    code will implicitly use a bias of 0.0 for all residues.
                - pair_bias (torch.Tensor, optional): [B, L, self.vocab_size,
                    L, self.vocab_size] - the per-residue pair bias to use
                    for sampling. If None, the code will implicitly use a pair
                    bias of 0.0 for all residue pairs.
                - temperature (torch.Tensor, optional): [B, L] - the per-residue
                    temperature to use for sampling. If None, the code will
                    implicitly use a temperature of 1.0.
                - initialize_sequence_embedding_with_ground_truth (bool):
                    If True, initialize the sequence embedding with the ground
                    truth sequence S. Else, initialize the sequence
                    embedding with zeros. Also, if True, initialize S_sampled
                    with the ground truth sequence S, which should only affect
                    the application of pair bias (which relies on the predicted
                    sequence). This is useful if we want to perform
                    auto-regressive redesign.
                - symmetry_equivalence_group (torch.Tensor, optional): [B, L] -
                    an integer for every residue, indicating the symmetry group
                    that it belongs to. If None, the residues are not grouped by
                    symmetry. For example, if residue i and j should be decoded
                    symmetrically, then symmetry_equivalence_group[i] ==
                    symmetry_equivalence_group[j]. Must be torch.int64 to allow
                    for use as an index. These values should range from 0 to
                    the maximum number of symmetry groups - 1 for each example.
                    NOTE: bias, pair_bias, and temperature should be the same
                    for all residues in the symmetry equivalence group;
                    otherwise, the intended behavior may not be achieved. The
                    residues within a symmetry group should all have the same
                    validity and design/fixed status.
                - symmetry_weight (torch.Tensor, optional): [B, L] - the weights
                    for each residue, to be used when aggregating across its
                    respective symmetry group. If None, the weights are
                    assumed to be 1.0 for all residues.
            graph_features (dict): Graph features containing the featurized
                node and edge inputs.
                - E_idx (torch.Tensor): [B, L, K] - edge indices.
            encoder_features (dict): Encoder features containing the encoded
                protein node and protein edge features.
                - h_V (torch.Tensor): [B, L, H] - the protein node features
                    after encoding message passing.
                - h_E (torch.Tensor): [B, L, K, H] - the protein edge features
                    after encoding message passing.
            decoder_features (dict): Initial decoder features containing the
                causal mask for the decoder.
                - decoding_order (torch.Tensor): [B, L] - the order in which
                    the residues should be decoded.
                - decode_last_mask (torch.Tensor): [B, L] - the mask for which
                    residues should be decoded last, where False is a residue
                    that should be decoded first (invalid or fixed), and True
                    is a residue that should not be decoded first (designed
                    residues).
                - causal_mask (torch.Tensor): [B, L, K, 1] - the causal mask
                    for the decoder.
                - anti_causal_mask (torch.Tensor): [B, L, K, 1] - the anti-
                    causal mask for the decoder.
        Side Effects:
            decoder_features["h_V"] (torch.Tensor): [B, L, H] - the updated
                node features for the decoder.
            decoder_features["logits"] (torch.Tensor): [B, L, self.vocab_size] -
                the sequence logits for the decoder.
            decoder_features["log_probs"] (torch.Tensor): [B, L,
                self.vocab_size] - the log probabilities for the sequence.
            decoder_features["probs"] (torch.Tensor): [B, L, self.vocab_size] -
                the probabilities for the sequence.
            decoder_features["probs_sample"] (torch.Tensor): [B, L,
                self.vocab_size] - the probabilities for the sequence, with the
                unknown residues zeroed out and the other residues normalized.
            decoder_features["S_sampled"] (torch.Tensor): [B, L] - the
                predicted sequence, sampled from the probabilities (unknown
                residues are not sampled).
            decoder_features["S_argmax"] (torch.Tensor): [B, L] - the predicted
                sequence, obtained by taking the argmax of the probabilities
                (unknown residues are not selected).
        """
        # Check that the input features contain the necessary keys.
        if "residue_mask" not in input_features:
            raise ValueError("Input features must contain 'residue_mask' key.")
        if "S" not in input_features:
            raise ValueError("Input features must contain 'S' key.")
        if "temperature" not in input_features:
            raise ValueError("Input features must contain 'temperature' key.")
        if "bias" not in input_features:
            raise ValueError("Input features must contain 'bias' key.")
        if "pair_bias" not in input_features:
            raise ValueError("Input features must contain 'pair_bias' key.")
        if "symmetry_equivalence_group" not in input_features:
            raise ValueError(
                "Input features must contain 'symmetry_equivalence_group' key."
            )
        if "symmetry_weight" not in input_features:
            raise ValueError("Input features must contain 'symmetry_weight' key.")

        # Check that the graph features contain the necessary keys.
        if "E_idx" not in graph_features:
            raise ValueError("Graph features must contain 'E_idx' key.")

        # Check that the encoder features contain the necessary keys.
        if "h_V" not in encoder_features:
            raise ValueError("Encoder features must contain 'h_V' key.")
        if "h_E" not in encoder_features:
            raise ValueError("Encoder features must contain 'h_E' key.")

        # Check that the decoder features contain the necessary keys.
        if "decoding_order" not in decoder_features:
            raise ValueError("Decoder features must contain 'decoding_order' key.")
        if "decode_last_mask" not in decoder_features:
            raise ValueError("Decoder features must contain 'decode_last_mask' key.")
        if "causal_mask" not in decoder_features:
            raise ValueError("Decoder features must contain 'causal_mask' key.")

        B, L = input_features["residue_mask"].shape

        # Do the setup for the decoder.
        # h_EXV_encoder_anti_causal [B, L, K, 3H] - the encoder embeddings,
        # masked with the anti-causal mask.
        # mask_E [B, L, K] - the mask for the edges, gathered at the
        # neighbor indices.
        # h_S [B, L, H] - the sequence embeddings for the decoder.
        h_EXV_encoder_anti_causal, mask_E, h_S = self.decode_setup(
            input_features, graph_features, encoder_features, decoder_features
        )

        # We can precompute the output dtype depending on automatic mixed
        # precision settings. This works because the W_out layer is a linear
        # layer, which has predictable dtype behavior with AMP.
        device = input_features["residue_mask"].device
        if device.type in ("cuda", "cpu") and torch.is_autocast_enabled(
            device_type=device.type
        ):
            output_dtype = torch.get_autocast_dtype(device_type=device.type)
        else:
            output_dtype = torch.float32

        # logits [B, L, self.vocab_size] - the logits for every residue
        # position and residue type.
        logits = torch.zeros(
            (B, L, self.vocab_size),
            device=input_features["residue_mask"].device,
            dtype=output_dtype,
        )

        # logits_i [B, 1, self.vocab_size] - the logits for the
        # residue at the current decoding index, computed from the
        # decoded node features. Declared here for accumulation use when
        # performing symmetry decoding.
        logits_i = torch.zeros(
            (B, 1, self.vocab_size),
            device=input_features["residue_mask"].device,
            dtype=output_dtype,
        )

        # log_probs [B, L, self.vocab_size] - the log probabilities for every
        # residue position and residue type.
        log_probs = torch.zeros(
            (B, L, self.vocab_size),
            device=input_features["residue_mask"].device,
            dtype=torch.float32,
        )

        # probs [B, L, self.vocab_size] - the probabilities for every residue
        # position and residue type.
        probs = torch.zeros(
            (B, L, self.vocab_size),
            device=input_features["residue_mask"].device,
            dtype=torch.float32,
        )

        # probs_sample [B, L, self.vocab_size] - the probabilities for every
        # residue position and residue type, with the unknown residues zeroed
        # out and the other residues normalized.
        probs_sample = torch.zeros(
            (B, L, self.vocab_size),
            device=input_features["residue_mask"].device,
            dtype=torch.float32,
        )

        # S_sampled [B, L] - the predicted/sampled sequence of residues. If
        # initialize_sequence_embedding_with_ground_truth is True, it is
        # initialized with the ground truth sequence S; this should only
        # affect the application of pair bias (which relies on the predicted
        # sequence), and is useful if we want to perform auto-regressive
        # redesign. Otherwise, this should have no effect, as we overwrite
        # S_sampled with the sampled sequence at every decoding step.
        if input_features["initialize_sequence_embedding_with_ground_truth"]:
            S_sampled = input_features["S"].clone()
        else:
            S_sampled = torch.full(
                (B, L),
                fill_value=self.token_to_idx[UNKNOWN_AA],
                device=input_features["S"].device,
                dtype=input_features["S"].dtype,
            )

        # S_argmax [B, L] - the argmax sequence of residues, initialized with
        # the unknown residue type.
        S_argmax = torch.full(
            (B, L),
            fill_value=self.token_to_idx[UNKNOWN_AA],
            device=input_features["S"].device,
            dtype=input_features["S"].dtype,
        )

        # h_V_decoder_stack - list containing the hidden node embeddings from
        # each decoder layer; populated iteratively during the decoding.
        # h_V_decoder_stack[i] [B, L, H] - the hidden node embeddings,
        # the 0th entry is the initial node embeddings from the encoder, and
        # the i-th entry is the hidden node embeddings after the i-th decoder
        # layer.
        # NOTE: it is necessary to keep the embeddings from all decoder layers,
        # since later decoding positions rely on the intermediate decoder layer
        # embeddings of previously decoded residues.
        h_V_decoder_stack = [encoder_features["h_V"]] + [
            torch.zeros_like(
                encoder_features["h_V"], device=input_features["residue_mask"].device
            )
            for _ in range(len(self.decoder_layers))
        ]
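        # For instance, with three decoder layers this stack holds four
        # [B, L, H] tensors: entry 0 is the encoder output, and entries 1-3 are
        # filled in column by column as residues are decoded, acting as a
        # per-layer cache of already-decoded positions.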

        # batch_idx [B, 1] - the batch indices for the decoder.
        batch_idx = torch.arange(
            B, device=input_features["residue_mask"].device
        ).unsqueeze(-1)

        # Iteratively decode, updating the hidden sequence embeddings.
        for decoding_idx in range(L):
            # i [B, 1] - the indices of the residues to decode in the
            # current iteration, based on the decoding order.
            i = decoder_features["decoding_order"][:, decoding_idx].unsqueeze(-1)

            # decode_last_mask_i [B, 1] - the mask for residues that should be
            # decoded last, where False is a residue that should be decoded
            # first (invalid or fixed), and True is a residue that should not be
            # decoded first (designed residues); at the current decoding
            # index.
            decode_last_mask_i = decoder_features["decode_last_mask"][batch_idx, i]

            # residue_mask_i [B, 1] - the mask for the residue at the current
            # decoding index.
            residue_mask_i = input_features["residue_mask"][batch_idx, i]

            # mask_E_i [B, 1, K] - the mask for the edges at the current
            # decoding index, gathered at the neighbor indices.
            mask_E_i = mask_E[batch_idx, i]

            # S_i [B, 1] - the ground truth sequence for the residue at the
            # current decoding index (for designed positions, undefined).
            S_i = input_features["S"][batch_idx, i]

            # Setup the temperature, bias, and pair bias for the current
            # decoding index.
            if input_features["temperature"] is not None:
                # temperature_i [B, 1] - the temperature for the residue at the
                # current decoding index.
                temperature_i = input_features["temperature"][batch_idx, i]
            else:
                temperature_i = None

            if input_features["bias"] is not None:
                # bias_i [B, 1, self.vocab_size] - the bias for the residue at
                # the current decoding index.
                bias_i = input_features["bias"][batch_idx, i]
            else:
                bias_i = None

            if input_features["pair_bias"] is not None:
                # pair_bias_i [B, 1, self.vocab_size, L, self.vocab_size] - the
                # pair bias for the residue at the current decoding index.
                pair_bias_i = input_features["pair_bias"][batch_idx, i]
            else:
                pair_bias_i = None

            if input_features["symmetry_equivalence_group"] is not None:
                # symmetry_equivalence_group_i [B, 1] - the symmetry
                # equivalence group for the residue at the current decoding
                # index.
                symmetry_equivalence_group_i = input_features[
                    "symmetry_equivalence_group"
                ][batch_idx, i]
            else:
                symmetry_equivalence_group_i = None

            if input_features["symmetry_weight"] is not None:
                # symmetry_weight_i [B, 1] - the symmetry weights for the
                # residue at the current decoding index.
                symmetry_weight_i = input_features["symmetry_weight"][batch_idx, i]
            else:
                symmetry_weight_i = None

            # Gather the graph, encoder, and sequence features for the
            # current decoding index.
            # E_idx_i [B, 1, K] - the edge indices for the residue at the
            # current decoding index.
            E_idx_i = graph_features["E_idx"][batch_idx, i]

            # h_E_i [B, 1, K, H] - the post-encoder edge features for the
            # residue at the current decoding index.
            h_E_i = encoder_features["h_E"][batch_idx, i]

            # h_ES_i [B, 1, K, 2H] - the edge features concatenated with the
            # sequence embeddings for the destination nodes, for the residue at
            # the current decoding index.
            h_ES_i = cat_neighbors_nodes(h_S, h_E_i, E_idx_i)

            # h_EXV_encoder_anti_causal_i [B, 1, K, 3H] - the encoder
            # embeddings, masked with the anti-causal mask, for the residue at
            # the current decoding index.
            h_EXV_encoder_anti_causal_i = h_EXV_encoder_anti_causal[batch_idx, i]

            # causal_mask_i [B, 1, K, 1] - the causal mask for the residue at
            # the current decoding index.
            causal_mask_i = decoder_features["causal_mask"][batch_idx, i]

            # Apply the decoder layers, updating the hidden node embeddings
            # for the current decoding index.
            for layer_idx, layer in enumerate(self.decoder_layers):
                # h_ESV_decoder_i [B, 1, K, 3H] - h_E_ij cat h_S_j cat
                # h_V_decoder_j; for the decoder embeddings, the edge features
                # are concatenated with the destination node sequence embeddings
                # and node features, for the residue at the current decoding
                # index.
                h_ESV_decoder_i = cat_neighbors_nodes(
                    h_V_decoder_stack[layer_idx], h_ES_i, E_idx_i
                )

                # h_ESV_i [B, 1, K, 3H] - the encoder and decoder embeddings,
                # combined according to the causal and anti-causal masks.
                # Combine the encoder embeddings with the decoder embeddings,
                # using the causal and anti-causal masks. When decoding the
                # residue at position i:
                # - for residue j, decoded before i:
                #     - h_ESV_ij = h_E_ij cat h_S_j cat h_V_decoder_j
                #         - encoder edge embedding, decoder destination node
                #           sequence embedding, and decoder destination node
                #           embedding.
                # - for residue j, decoded after i (including i):
                #     - h_ESV_ij = h_E_ij cat (0 vector) cat h_V_j
                #         - encoder edge embedding, zero vector (no sequence
                #           information), and encoder destination node
                #           embedding. This prevents leakage of sequence
                #           information.
                #     - NOTE: h_V_j comes from the encoder.
                #     - NOTE: h_E is not updated in the decoder, h_E_ij comes
                #       from the encoder.
                # - NOTE: within the decoder layer itself, h_V_decoder_i will
                #   be concatenated to h_ESV_ij.
                h_ESV_i = causal_mask_i * h_ESV_decoder_i + h_EXV_encoder_anti_causal_i

                # h_V_decoder_i [B, 1, H] - the updated node features for the
                # decoder, after applying the layer at the current decoding
                # index.
                h_V_decoder_i = torch.utils.checkpoint.checkpoint(
                    layer,
                    h_V_decoder_stack[layer_idx][batch_idx, i],
                    h_ESV_i,
                    mask_V=residue_mask_i,
                    mask_E=mask_E_i,
                    use_reentrant=False,
                )

                # h_V_decoder_stack[layer_idx + 1][batch_idx, i] [B, 1, H] -
                # the updated node features for the decoder, after applying the
                # layer at the current decoding index.
                if not torch.is_grad_enabled():
                    h_V_decoder_stack[layer_idx + 1][batch_idx, i] = h_V_decoder_i
                else:
                    # For gradient tracking, we can't use in-place operations.
                    h_V_decoder_stack[layer_idx + 1] = h_V_decoder_stack[
                        layer_idx + 1
                    ].scatter(
                        1,
                        i.unsqueeze(-1).expand(-1, -1, self.hidden_dim),
                        h_V_decoder_i,
                    )

            if input_features["symmetry_equivalence_group"] is None:
                # logits_i [B, 1, self.vocab_size] - the logits for the
                # residue at the current decoding index, computed from the
                # decoded node features.
                logits_i = self.W_out(h_V_decoder_stack[-1][batch_idx, i])
            else:
                # logits_i [B, 1, self.vocab_size] - the logits for the
                # residue at the current decoding index, computed from the
                # decoded node features, aggregated across symmetry groups,
                # weighted by the symmetry weights.
                logits_i += self.W_out(h_V_decoder_stack[-1][batch_idx, i]) * (
                    1.0
                    if symmetry_weight_i is None
                    else symmetry_weight_i.unsqueeze(-1)
                )

            # Compute the log probabilities, probabilities, sampled
            # probabilities, predicted sequence, and argmax sequence for the
            # current decoding index.
            sample_dict = self.logits_to_sample(
                logits_i, bias_i, pair_bias_i, S_sampled, temperature_i
            )

            log_probs_i = sample_dict["log_probs"]
            probs_i = sample_dict["probs"]
            probs_sample_i = sample_dict["probs_sample"]
            S_sampled_i = sample_dict["S_sampled"]
            S_argmax_i = sample_dict["S_argmax"]

            if input_features["symmetry_equivalence_group"] is None:
                # Save the logits, probabilities, sampled probabilities, and
                # log probabilities for the current decoding index. These are
                # saved but not sampled for invalid/fixed residues.
                logits[batch_idx, i] = logits_i
                probs[batch_idx, i] = probs_i
                probs_sample[batch_idx, i] = probs_sample_i
                if not torch.is_grad_enabled():
                    log_probs[batch_idx, i] = log_probs_i
                else:
                    # For gradient tracking, we can't use in-place operations.
                    log_probs = log_probs.scatter(
                        1, i.unsqueeze(-1).expand(-1, -1, self.vocab_size), log_probs_i
                    )
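                    # scatter writes log_probs_i into column i of log_probs
                    # without mutating the original tensor; for i == 3 it is
                    # the out-of-place equivalent of
                    # log_probs[:, 3, :] = log_probs_i[:, 0, :], which keeps
                    # the computation graph intact for backpropagation.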

                # Update the predicted sequence and argmax sequence for the
                # current decoding index.
                S_sampled[batch_idx, i] = (S_sampled_i * decode_last_mask_i) + (
                    S_i * (~decode_last_mask_i)
                )
                S_argmax[batch_idx, i] = (S_argmax_i * decode_last_mask_i) + (
                    S_i * (~decode_last_mask_i)
                )
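                # In effect, decode_last_mask_i acts as a per-example switch:
                # where it is True the freshly sampled token is kept, and where
                # it is False the ground-truth token S_i is written back, so
                # fixed or invalid positions always retain their input
                # identity.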

                # h_S_i [B, 1, self.hidden_dim] - the sequence embeddings of the
                # sampled/fixed residue at the current decoding index.
                h_S_i = self.W_s(S_sampled[batch_idx, i])

                # Update the decoder sequence embeddings with the predicted
                # sequence for the current decoding index.
                if not torch.is_grad_enabled():
                    h_S[batch_idx, i] = h_S_i
                else:
                    # For gradient tracking, we can't use in-place operations.
                    h_S = h_S.scatter(
                        1, i.unsqueeze(-1).expand(-1, -1, self.hidden_dim), h_S_i
                    )
            else:
                # symm_group_end_mask [B, 1] - mask for the residues that are
                # at the end of a symmetry group, where True is a residue that
                # is at the end of a symmetry group, and False is a residue that
                # is not at the end of a symmetry group. When we are at the last
                # decoding index, we know that all residues are at the end of a
                # symmetry group.
                if decoding_idx == (L - 1):
                    symm_group_end_mask = torch.ones(
                        (B, 1),
                        device=input_features["residue_mask"].device,
                        dtype=torch.bool,
                    )
                else:
                    # next_i [B, 1] - the indices of the next residues to decode
                    # in the current iteration, based on the decoding order.
                    next_i = decoder_features["decoding_order"][
                        :, decoding_idx + 1
                    ].unsqueeze(-1)

                    # symmetry_equivalence_group_next_i [B, 1] - the symmetry
                    # equivalence group for the residue at the next decoding
                    # index.
                    symmetry_equivalence_group_next_i = input_features[
                        "symmetry_equivalence_group"
                    ][batch_idx, next_i]

                    symm_group_end_mask = (
                        symmetry_equivalence_group_i
                        != symmetry_equivalence_group_next_i
                    )

                # same_symm_group_mask [B, L] - mask for the residues that
                # belong to the same symmetry group as the current residue,
                # where True is a residue that belongs to the same symmetry
                # group, and False is a residue that does not belong to the same
                # symmetry group.
                same_symm_group_mask = (
                    input_features["symmetry_equivalence_group"]
                    == symmetry_equivalence_group_i
                )

                # symm_end_and_same_mask [B, L] - mask that combines the
                # symm_group_end_mask and same_symm_group_mask, where all
                # residues with the same symmetry group as the current residue
                # (if the current residue is at the end of a symmetry group)
                # are True, and all other residues are False.
                symm_end_and_same_mask = symm_group_end_mask & same_symm_group_mask

                # symm_end_and_same_mask_vocab [B, L, self.vocab_size] -
                # symm_end_and_same_mask projected to the vocabulary size.
                symm_end_and_same_mask_vocab = symm_end_and_same_mask.unsqueeze(
                    -1
                ).expand(-1, -1, self.vocab_size)

                # symm_end_and_same_mask_for_hidden [B, L, self.hidden_dim] -
                # symm_end_and_same_mask projected to the hidden dimension.
                symm_end_and_same_mask_for_hidden = symm_end_and_same_mask.unsqueeze(
                    -1
                ).expand(-1, -1, self.hidden_dim)

                # Save the logits, probabilities, sampled probabilities, and log
                # probabilities for the current decoding index, if the residue
                # is at the end of a symmetry group.
                logits[symm_end_and_same_mask_vocab] = logits_i.expand(-1, L, -1)[
                    symm_end_and_same_mask_vocab
                ]
                probs[symm_end_and_same_mask_vocab] = probs_i.expand(-1, L, -1)[
                    symm_end_and_same_mask_vocab
                ]
                probs_sample[symm_end_and_same_mask_vocab] = probs_sample_i.expand(
                    -1, L, -1
                )[symm_end_and_same_mask_vocab]
                if not torch.is_grad_enabled():
                    log_probs[symm_end_and_same_mask_vocab] = log_probs_i.expand(
                        -1, L, -1
                    )[symm_end_and_same_mask_vocab]
                else:
                    # For gradient tracking, we can't use in-place operations.
                    log_probs = torch.where(
                        symm_end_and_same_mask_vocab,
                        log_probs_i.expand(-1, L, -1),
                        log_probs,
                    )

                # Update the predicted sequence and argmax sequence for the
                # current decoding index, if the residue is at the end of a
                # symmetry group.
                S_sampled[symm_end_and_same_mask] = (
                    (S_sampled_i * decode_last_mask_i) + (S_i * (~decode_last_mask_i))
                ).expand(-1, L)[symm_end_and_same_mask]
                S_argmax[symm_end_and_same_mask] = (
                    (S_argmax_i * decode_last_mask_i) + (S_i * (~decode_last_mask_i))
                ).expand(-1, L)[symm_end_and_same_mask]

                # h_S_i [B, 1, self.hidden_dim] - the sequence embeddings of the
                # sampled/fixed residue at the current decoding index.
                h_S_i = self.W_s(S_sampled[batch_idx, i])

                # Update the decoder sequence embeddings with the predicted
                # sequence for the current decoding index, if the residue is at
                # the end of a symmetry group.
                if not torch.is_grad_enabled():
                    h_S[symm_end_and_same_mask_for_hidden] = h_S_i.expand(-1, L, -1)[
                        symm_end_and_same_mask_for_hidden
                    ]
                else:
                    # For gradient tracking, we can't use in-place operations.
                    h_S = torch.where(
                        symm_end_and_same_mask_for_hidden, h_S_i.expand(-1, L, -1), h_S
                    )

                # Zero out the current position logits (used for accumulation)
                # if a batch example's current residue is at the end of a
                # symmetry group.
                logits_i[symm_group_end_mask] = 0.0
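                # Because logits_i is only reset once a symmetry group has been
                # fully visited, the weighted per-residue logits of a group
                # accumulate across its members and the final, shared logits
                # are committed exactly once, at the group's last position in
                # the decoding order.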

        # Update the decoder features with the final node features, the computed
        # logits, log probabilities, probabilities, sampled probabilities,
        # predicted sequence, and argmax sequence.
        decoder_features["h_V"] = h_V_decoder_stack[-1]
        decoder_features["logits"] = logits
        decoder_features["log_probs"] = log_probs
        decoder_features["probs"] = probs
        decoder_features["probs_sample"] = probs_sample
        decoder_features["S_sampled"] = S_sampled
        decoder_features["S_argmax"] = S_argmax

    def construct_output_dictionary(
        self, input_features, graph_features, encoder_features, decoder_features
    ):
        """
        Constructs the output dictionary based on the requested features.

        Args:
            input_features (dict): Input features containing the requested
                features to return.
                - features_to_return (dict, optional): dictionary determining
                    which features to return from the model. If None, return all
                    features (including modified input features, graph features,
                    encoder features, and decoder features). Otherwise,
                    expects a dictionary with the following key, value pairs:
                    - "input_features": list - the input features to return.
                    - "graph_features": list - the graph features to return.
                    - "encoder_features": list - the encoder features to return.
                    - "decoder_features": list - the decoder features to return.
            graph_features (dict): Graph features containing the featurized
                node and edge inputs.
            encoder_features (dict): Encoder features containing the encoded
                protein node and protein edge features.
            decoder_features (dict): Decoder features containing the post-
                decoder features (including causal masks, logits, probabilities,
                predicted sequence, etc.).
        Returns:
            output_dict (dict): Output dictionary containing the requested
                features based on the input features' "features_to_return" key.
                If "features_to_return" is None, returns all features.
        """
        # Check that the input features contain the necessary keys.
        if "features_to_return" not in input_features:
            raise ValueError("Input features must contain 'features_to_return' key.")

        # Create the output dictionary based on the requested features.
        if input_features["features_to_return"] is None:
            output_dict = {
                "input_features": input_features,
                "graph_features": graph_features,
                "encoder_features": encoder_features,
                "decoder_features": decoder_features,
            }
        else:
            # Filter the output dictionary based on the requested features.
            output_dict = dict()
            output_dict["input_features"] = {
                key: input_features[key]
                for key in input_features["features_to_return"].get(
                    "input_features", []
                )
            }
            output_dict["graph_features"] = {
                key: graph_features[key]
                for key in input_features["features_to_return"].get(
                    "graph_features", []
                )
            }
            output_dict["encoder_features"] = {
                key: encoder_features[key]
                for key in input_features["features_to_return"].get(
                    "encoder_features", []
                )
            }
            output_dict["decoder_features"] = {
                key: decoder_features[key]
                for key in input_features["features_to_return"].get(
                    "decoder_features", []
                )
            }

        return output_dict

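    # For example, passing
    # features_to_return={"decoder_features": ["S_sampled", "log_probs"]}
    # yields an output dictionary whose "decoder_features" entry contains only
    # those two tensors, while the other three feature groups come back empty.
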
def forward(self, network_input):
|
|
1883
|
+
"""
|
|
1884
|
+
Forward pass of the ProteinMPNN model.
|
|
1885
|
+
|
|
1886
|
+
A NOTE on shapes:
|
|
1887
|
+
- B = batch dimension size
|
|
1888
|
+
- L = sequence length (number of residues)
|
|
1889
|
+
- K = number of neighbors per residue
|
|
1890
|
+
- H = hidden dimension size
|
|
1891
|
+
- vocab_size = self.vocab_size
|
|
1892
|
+
- num_atoms =
|
|
1893
|
+
self.graph_featurization_module.TOKEN_ENCODING.n_atoms_per_token
|
|
1894
|
+
- num_backbone_atoms = len(
|
|
1895
|
+
self.graph_featurization_module.BACKBONE_ATOM_NAMES
|
|
1896
|
+
)
|
|
1897
|
+
- num_virtual_atoms = len(
|
|
1898
|
+
self.graph_featurization_module.DATA_TO_CALCULATE_VIRTUAL_ATOMS
|
|
1899
|
+
)
|
|
1900
|
+
- num_rep_atoms = len(
|
|
1901
|
+
self.graph_featurization_module.REPRESENTATIVE_ATOM_NAMES
|
|
1902
|
+
)
|
|
1903
|
+
- num_edge_output_features =
|
|
1904
|
+
self.graph_featurization_module.num_edge_output_features
|
|
1905
|
+
- num_node_output_features =
|
|
1906
|
+
self.graph_featurization_module.num_node_output_features
|
|
1907
|
+
|
|
1908
|
+
Args:
|
|
1909
|
+
network_input (dict): Dictionary containing the input to the
|
|
1910
|
+
network.
|
|
1911
|
+
- input_features (dict): dictionary containing input features
|
|
1912
|
+
and all necessary information for the model to run.
|
|
1913
|
+
- X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of
|
|
1914
|
+
polymer atoms.
|
|
1915
|
+
- X_m (torch.Tensor): [B, L, num_atoms] - Mask indicating
|
|
1916
|
+
which polymer atoms are valid.
|
|
1917
|
+
- S (torch.Tensor): [B, L] - Sequence of the polymer
|
|
1918
|
+
residues.
|
|
1919
|
+
- R_idx (torch.Tensor): [B, L] - indices of the residues.
|
|
1920
|
+
- chain_labels (torch.Tensor): [B, L] - chain labels for
|
|
1921
|
+
each residue.
|
|
1922
|
+
- residue_mask (torch.Tensor): [B, L] - Mask indicating
|
|
1923
|
+
which residues are valid.
|
|
1924
|
+
- designed_residue_mask (torch.Tensor): [B, L] - mask for
|
|
1925
|
+
the designed residues.
|
|
1926
|
+
- symmetry_equivalence_group (torch.Tensor, optional):
|
|
1927
|
+
[B, L] - an integer for every residue, indicating the
|
|
1928
|
+
symmetry group that it belongs to. If None, the
|
|
1929
|
+
residues are not grouped by symmetry. For example, if
|
|
1930
|
+
residue i and j should be decoded symmetrically, then
|
|
1931
|
+
symmetry_equivalence_group[i] ==
|
|
1932
|
+
symmetry_equivalence_group[j]. Must be torch.int64 to
|
|
1933
|
+
allow for use as an index. These values should range
|
|
1934
|
+
from 0 to the maximum number of symmetry groups - 1 for
|
|
1935
|
+
each example. NOTE: bias, pair_bias, and temperature
|
|
1936
|
+
should be the same for all residues in the symmetry
|
|
1937
|
+
equivalence group; otherwise, the intended behavior may
|
|
1938
|
+
not be achieved. The residues within a symmetry group
|
|
1939
|
+
should all have the same validity and design/fixed
|
|
1940
|
+
status.
|
|
1941
|
+
- symmetry_weight (torch.Tensor, optional): [B, L] - the
|
|
1942
|
+
weights for each residue, to be used when aggregating
|
|
1943
|
+
across its respective symmetry group. If None, the
|
|
1944
|
+
weights are assumed to be 1.0 for all residues.
|
|
1945
|
+
- bias (torch.Tensor, optional): [B, L, 21] - the
|
|
1946
|
+
per-residue bias to use for sampling. If None, the code
|
|
1947
|
+
will implicitly use a bias of 0.0 for all residues.
|
|
1948
|
+
- pair_bias (torch.Tensor, optional): [B, L, 21, L, 21] -
|
|
1949
|
+
the per-residue pair bias to use for sampling. If None,
|
|
1950
|
+
the code will implicitly use a pair bias of 0.0 for all
|
|
1951
|
+
residue pairs.
|
|
1952
|
+
- temperature (torch.Tensor, optional): [B, L] - the
|
|
1953
|
+
per-residue temperature to use for sampling. If None,
|
|
1954
|
+
the code will implicitly use a temperature of 1.0.
|
|
1955
|
+
- structure_noise (float): Standard deviation of the
|
|
1956
|
+
Gaussian noise to add to the input coordinates, in
|
|
1957
|
+
Angstroms.
|
|
1958
|
+
- decode_type (str): the type of decoding to use.
|
|
1959
|
+
- "teacher_forcing": Use teacher forcing for the
|
|
1960
|
+
decoder, where the decoder attends to the ground
|
|
1961
|
+
truth sequence S for all previously decoded
|
|
1962
|
+
residues.
|
|
1963
|
+
- "auto_regressive": Use auto-regressive decoding,
|
|
1964
|
+
where the decoder attends to the sequence and
|
|
1965
|
+
decoder representation of residues that have
|
|
1966
|
+
already been decoded (using the predicted sequence).
|
|
1967
|
+
- causality_pattern (str): The pattern of causality to use
|
|
1968
|
+
for the decoder. For all causality patterns, the
|
|
1969
|
+
decoding order is randomized.
|
|
1970
|
+
- "auto_regressive": Use an auto-regressive causality
|
|
1971
|
+
pattern, where residues can attend to the sequence
|
|
1972
|
+
and decoder representation of residues that have
|
|
1973
|
+
already been decoded (NOTE: as mentioned above,
|
|
1974
|
+
this will be randomized).
|
|
1975
|
+
- "unconditional": Residues cannot attend to the
|
|
1976
|
+
sequence or decoder representation of any other
|
|
1977
|
+
residues.
|
|
1978
|
+
- "conditional": Residues can attend to the sequence
|
|
1979
|
+
and decoder representation of all other residues.
|
|
1980
|
+
- "conditional_minus_self": Residues can attend to the
|
|
1981
|
+
sequence and decoder representation of all other
|
|
1982
|
+
residues, except for themselves (as destination
|
|
1983
|
+
nodes).
|
|
1984
|
+
- initialize_sequence_embedding_with_ground_truth (bool):
|
|
1985
|
+
- True: Initialize the sequence embedding with the
|
|
1986
|
+
ground truth sequence S.
|
|
1987
|
+
- If doing auto-regressive decoding, also
|
|
1988
|
+
initialize S_sampled with the ground truth
|
|
1989
|
+
sequence S, which should only affect the
|
|
1990
|
+
application of pair bias.
|
|
1991
|
+
- False: Initialize the sequence embedding with zeros.
|
|
1992
|
+
- If doing auto-regressive decoding, initialize
|
|
1993
|
+
S_sampled with unknown residues.
|
|
1994
|
+
- features_to_return (dict, optional): dictionary
|
|
1995
|
+
determining which features to return from the model. If
|
|
1996
|
+
None, return all features (including modified input
|
|
1997
|
+
features, graph features, encoder features, and decoder
|
|
1998
|
+
features). Otherwise, expects a dictionary with the
|
|
1999
|
+
following key, value pairs:
|
|
2000
|
+
- "input_features": list - the input features to return.
|
|
2001
|
+
- "graph_features": list - the graph features to return.
|
|
2002
|
+
- "encoder_features": list - the encoder features to
|
|
2003
|
+
return.
|
|
2004
|
+
- "decoder_features": list - the decoder features to
|
|
2005
|
+
return.
|
|
2006
|
+
- repeat_sample_num (int, optional): Number of times to
|
|
2007
|
+
repeat the samples along the batch dimension. If None,
|
|
2008
|
+
no repetition is performed. If greater than 1, the
|
|
2009
|
+
samples are repeated along the batch dimension. If
|
|
2010
|
+
greater than 1, B must be 1, since repeating samples
|
|
2011
|
+
along the batch dimension is not supported when more
|
|
2012
|
+
than one sample is provided in the batch.
|
|
2013
|
+
Side Effects:
|
|
2014
|
+
Any changes denoted below to input_features are also mutated on the
|
|
2015
|
+
original input features.
|
|
2016
|
+
Returns:
|
|
2017
|
+
network_output (dict): Output dictionary containing the requested
|
|
2018
|
+
features based on the input features' "features_to_return" key.
|
|
2019
|
+
- input_features (dict): The input features from above, with
|
|
2020
|
+
the following keys added or modified:
|
|
                    - mask_for_loss (torch.Tensor): [B, L] - mask for loss,
                        where True is a residue that is included in the loss
                        calculation, and False is a residue that is not
                        included in the loss calculation.
                    - X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of
                        polymer atoms with added Gaussian noise.
                    - X_pre_noise (torch.Tensor): [B, L, num_atoms, 3] -
                        3D coordinates of polymer atoms before adding Gaussian
                        noise ('X' before noise).
                    - X_backbone (torch.Tensor): [B, L, num_backbone_atoms, 3] -
                        3D coordinates of the backbone atoms for each residue,
                        built from the noisy 'X' coordinates.
                    - X_m_backbone (torch.Tensor): [B, L, num_backbone_atoms] -
                        mask indicating which backbone atoms are valid.
                    - X_virtual_atoms (torch.Tensor):
                        [B, L, num_virtual_atoms, 3] - 3D coordinates of the
                        virtual atoms for each residue, built from the noisy
                        'X' coordinates.
                    - X_m_virtual_atoms (torch.Tensor):
                        [B, L, num_virtual_atoms] - mask indicating which
                        virtual atoms are valid.
                    - X_rep_atoms (torch.Tensor): [B, L, num_rep_atoms, 3] - 3D
                        coordinates of the representative atoms for each
                        residue, built from the noisy 'X' coordinates.
                    - X_m_rep_atoms (torch.Tensor): [B, L, num_rep_atoms] -
                        mask indicating which representative atoms are valid.
                - graph_features (dict): The graph features.
                    - E_idx (torch.Tensor): [B, L, K] - indices of the top K
                        nearest neighbors for each residue.
                    - E (torch.Tensor): [B, L, K, num_edge_output_features] -
                        Edge features for each pair of neighbors.
                - encoder_features (dict): The encoder features.
                    - h_V (torch.Tensor): [B, L, H] - the protein node features
                        after encoding message passing.
                    - h_E (torch.Tensor): [B, L, K, H] - the protein edge
                        features after encoding message passing.
                - decoder_features (dict): The decoder features.
                    - causal_mask (torch.Tensor): [B, L, K, 1] - the causal
                        mask for the decoder.
                    - anti_causal_mask (torch.Tensor): [B, L, K, 1] - the
                        anti-causal mask for the decoder.
                    - decoding_order (torch.Tensor): [B, L] - the order in
                        which the residues should be decoded.
                    - decode_last_mask (torch.Tensor): [B, L] - mask for
                        residues that should be decoded last, where False is a
                        residue that should be decoded first (invalid or
                        fixed), and True is a residue that should not be
                        decoded first (designed residues).
                    - h_V (torch.Tensor): [B, L, H] - the updated node features
                        for the decoder.
                    - logits (torch.Tensor): [B, L, vocab_size] - the logits
                        for the sequence.
                    - log_probs (torch.Tensor): [B, L, vocab_size] - the log
                        probabilities for the sequence.
                    - probs (torch.Tensor): [B, L, vocab_size] - the
                        probabilities for the sequence.
                    - probs_sample (torch.Tensor): [B, L, vocab_size] -
                        the probabilities for the sequence, with the unknown
                        residues zeroed out and the other residues normalized.
                    - S_sampled (torch.Tensor): [B, L] - the predicted
                        sequence, sampled from the probabilities (unknown
                        residues are not sampled).
                    - S_argmax (torch.Tensor): [B, L] - the predicted sequence,
                        obtained by taking the argmax of the probabilities
                        (unknown residues are not selected).
        """
        input_features = network_input["input_features"]

        # Check that the input features contain the necessary keys.
        if "decode_type" not in input_features:
            raise ValueError("Input features must contain 'decode_type' key.")

        # Setup masks (added to the input features).
        self.sample_and_construct_masks(input_features)

        # Graph featurization (also modifies/adds to input_features).
        graph_features = self.graph_featurization(input_features)

        # Run the encoder.
        encoder_features = self.encode(input_features, graph_features)

        # Setup for decoder (repeat features along the batch dimension, modifies
        # input_features, graph_features, and encoder_features).
        self.repeat_along_batch(
            input_features,
            graph_features,
            encoder_features,
        )

        # Set up the causality masks.
        decoder_features = self.setup_causality_masks(input_features, graph_features)

        # Decoder, either teacher forcing or auto-regressive.
        if input_features["decode_type"] == "teacher_forcing":
            self.decode_teacher_forcing(
                input_features, graph_features, encoder_features, decoder_features
            )
        elif input_features["decode_type"] == "auto_regressive":
            self.decode_auto_regressive(
                input_features, graph_features, encoder_features, decoder_features
            )
        else:
            raise ValueError(f"Unknown decode_type: {input_features['decode_type']}.")

        # Create the output dictionary based on the requested features.
        network_output = self.construct_output_dictionary(
            input_features, graph_features, encoder_features, decoder_features
        )

        return network_output

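# ---------------------------------------------------------------------------
# Editor's note - illustrative sketch only, not part of the package. A hedged
# example of how a caller might exercise the forward() contract documented
# above. Only the key names "input_features", "decode_type",
# "features_to_return", and the output groups come from the docstrings; the
# remaining batch keys are produced by the package's own collators and are
# elided, and constructing ProteinMPNN() with its defaults is an assumption.
#
#     model = ProteinMPNN()
#     network_input = {
#         "input_features": {
#             "decode_type": "teacher_forcing",   # or "auto_regressive"
#             "features_to_return": ...,          # which output features to keep
#             # ... featurized batch keys (S, X, residue_mask, ...) ...
#         }
#     }
#     network_output = model(network_input)
#     log_probs = network_output["decoder_features"]["log_probs"]  # [B, L, vocab_size]
#     S_sampled = network_output["decoder_features"]["S_sampled"]  # [B, L]
# ---------------------------------------------------------------------------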
class SolubleMPNN(ProteinMPNN):
    """
    Same as ProteinMPNN, but with different training set and different weights.
    """

    pass


class AntibodyMPNN(ProteinMPNN):
    """
    Same as ProteinMPNN, but with different training set and different weights.
    """

    pass


class MembraneMPNN(ProteinMPNN):
    """
    Class for per-residue and global membrane label version of MPNN.
    """

    HAS_NODE_FEATURES = True

    def __init__(
        self,
        num_node_features=128,
        num_edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        num_neighbors=48,
        dropout_rate=0.1,
        num_positional_embeddings=16,
        min_rbf_mean=2.0,
        max_rbf_mean=22.0,
        num_rbf=16,
        num_membrane_classes=3,
    ):
        """
        Setup the MembraneMPNN model.

        All args are the same as the parent class, except for the following:
        Args:
            num_membrane_classes (int): Number of membrane classes.
        """
        # The only change necessary here is the graph featurization module.
        graph_featurization_module = ProteinFeaturesMembrane(
            num_edge_output_features=num_edge_features,
            num_node_output_features=num_node_features,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
            num_neighbors=num_neighbors,
            num_membrane_classes=num_membrane_classes,
        )

        super(MembraneMPNN, self).__init__(
            num_node_features=num_node_features,
            num_edge_features=num_edge_features,
            hidden_dim=hidden_dim,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            num_neighbors=num_neighbors,
            dropout_rate=dropout_rate,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
            graph_featurization_module=graph_featurization_module,
        )


class PSSMMPNN(ProteinMPNN):
    """
    Class for pssm-aware version of MPNN.
    """

    HAS_NODE_FEATURES = True

    def __init__(
        self,
        num_node_features=128,
        num_edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        num_neighbors=48,
        dropout_rate=0.1,
        num_positional_embeddings=16,
        min_rbf_mean=2.0,
        max_rbf_mean=22.0,
        num_rbf=16,
        num_pssm_features=20,
    ):
        """
        Setup the PSSMMPNN model.

        All args are the same as the parent class, except for the following:
        Args:
            num_pssm_features (int): Number of PSSM features.
        """
        # The only change necessary here is the graph featurization module.
        graph_featurization_module = ProteinFeaturesPSSM(
            num_edge_output_features=num_edge_features,
            num_node_output_features=num_node_features,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
            num_neighbors=num_neighbors,
            num_pssm_features=num_pssm_features,
        )

        super(PSSMMPNN, self).__init__(
            num_node_features=num_node_features,
            num_edge_features=num_edge_features,
            hidden_dim=hidden_dim,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            num_neighbors=num_neighbors,
            dropout_rate=dropout_rate,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
            graph_featurization_module=graph_featurization_module,
        )

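# ---------------------------------------------------------------------------
# Editor's note - illustrative sketch only, not part of the package. The
# variants above differ from ProteinMPNN only in the graph featurization
# module they construct and hand to the parent constructor. A hypothetical
# new variant would follow the same template; "MyProteinFeatures" and
# "MyVariantMPNN" are invented names, and forwarding the remaining
# constructor arguments via **kwargs is editorial shorthand.
#
#     class MyVariantMPNN(ProteinMPNN):
#         HAS_NODE_FEATURES = True  # True when the featurizer emits node features
#
#         def __init__(self, num_node_features=128, num_edge_features=128, **kwargs):
#             graph_featurization_module = MyProteinFeatures(
#                 num_edge_output_features=num_edge_features,
#                 num_node_output_features=num_node_features,
#             )
#             super().__init__(
#                 num_node_features=num_node_features,
#                 num_edge_features=num_edge_features,
#                 graph_featurization_module=graph_featurization_module,
#                 **kwargs,
#             )
# ---------------------------------------------------------------------------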
class LigandMPNN(ProteinMPNN):
    """
    Class for ligand-aware version of MPNN.
    """

    # Although there are node-like features, they are actually the protein-
    # ligand subgraph features, so we set this to False. Note, this is because
    # none of these features are embedded prior to protein encoding.
    HAS_NODE_FEATURES = False

    def __init__(
        self,
        num_node_features=128,
        num_edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        num_neighbors=32,
        dropout_rate=0.1,
        num_positional_embeddings=16,
        min_rbf_mean=2.0,
        max_rbf_mean=22.0,
        num_rbf=16,
        num_context_atoms=25,
        num_context_encoding_layers=2,
        overall_atomize_side_chain_probability=0.5,
        per_residue_atomize_side_chain_probability=0.02,
    ):
        # Pass the num_context_atoms to the graph featurization module.
        graph_featurization_module = ProteinFeaturesLigand(
            num_edge_output_features=num_edge_features,
            num_node_output_features=num_node_features,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
            num_neighbors=num_neighbors,
            num_context_atoms=num_context_atoms,
        )

        super(LigandMPNN, self).__init__(
            num_node_features=num_node_features,
            num_edge_features=num_edge_features,
            hidden_dim=hidden_dim,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            num_neighbors=num_neighbors,
            dropout_rate=dropout_rate,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
            graph_featurization_module=graph_featurization_module,
        )
        self.overall_atomize_side_chain_probability = (
            overall_atomize_side_chain_probability
        )
        self.per_residue_atomize_side_chain_probability = (
            per_residue_atomize_side_chain_probability
        )

        # Linear layer for embedding the protein-ligand edge features.
        self.W_protein_to_ligand_edges_embed = nn.Linear(
            num_node_features, hidden_dim, bias=True
        )

        # Linear layers for embedding the output of the protein encoder.
        self.W_protein_encoding_embed = nn.Linear(hidden_dim, hidden_dim, bias=True)

        # Linear layers for embedding the ligand nodes and edges.
        self.W_ligand_nodes_embed = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.W_ligand_edges_embed = nn.Linear(hidden_dim, hidden_dim, bias=True)

        # Linear layer for the final context embedding.
        self.W_final_context_embed = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Layer norm for the final context embedding.
        self.final_context_norm = nn.LayerNorm(hidden_dim)

        # Save the number of context encoding layers.
        self.num_context_encoding_layers = num_context_encoding_layers

        self.protein_ligand_context_encoder_layers = nn.ModuleList(
            [
                DecLayer(hidden_dim, hidden_dim * 3, dropout=dropout_rate)
                for _ in range(num_context_encoding_layers)
            ]
        )

        self.ligand_context_encoder_layers = nn.ModuleList(
            [
                DecLayer(hidden_dim, hidden_dim * 2, dropout=dropout_rate)
                for _ in range(num_context_encoding_layers)
            ]
        )

    def sample_and_construct_masks(self, input_features):
        """
        Sample and construct masks for the input features.

        Args:
            input_features (dict): Input features containing the residue mask.
                - residue_mask (torch.Tensor): [B, L] - mask for the residues.
                - S (torch.Tensor): [B, L] - sequence of residues.
                - designed_residue_mask (torch.Tensor): [B, L] - mask for the
                    designed residues.
                - atomize_side_chains (bool): Whether to atomize side chains.
        Side Effects:
            input_features["residue_mask"] (torch.Tensor): [B, L] - mask for the
                residues, where True is a residue that is valid and False is a
                residue that is invalid.
            input_features["known_residue_mask"] (torch.Tensor): [B, L] - mask
                for known residues, where True is a residue with one of the
                canonical residue types, and False is a residue with an unknown
                residue type.
            input_features["designed_residue_mask"] (torch.Tensor): [B, L] -
                mask for designed residues, where True is a residue that is
                designed, and False is a residue that is not designed.
            input_features["hide_side_chain_mask"] (torch.Tensor): [B, L] - mask
                for hiding side chains, where True is a residue with hidden side
                chains, and False is a residue with revealed side chains.
            input_features["mask_for_loss"] (torch.Tensor): [B, L] - mask for
                loss, where True is a residue that is included in the loss
                calculation, and False is a residue that is not included in the
                loss calculation.
        """
        # Create the masks for ProteinMPNN.
        super().sample_and_construct_masks(input_features)

        # Check that the input features contain the necessary keys.
        if "atomize_side_chains" not in input_features:
            raise ValueError("Input features must contain 'atomize_side_chains' key.")

        # Create the mask for hiding or revealing side chains.
        # With no side chain atomization, the side chain mask is all ones.
        if input_features["atomize_side_chains"]:
            # If we are training, randomly reveal side chains.
            if self.training:
                # With a probability specified as the overall atomization
                # side chain probability, we reveal some side chains (otherwise,
                # we hide all side chains).
                if (
                    sample_bernoulli_rv(self.overall_atomize_side_chain_probability)
                    == 1
                ):
                    reveal_side_chain_mask = (
                        torch.rand(
                            input_features["S"].shape, device=input_features["S"].device
                        )
                        < self.per_residue_atomize_side_chain_probability
                    )
                    hide_side_chain_mask = ~reveal_side_chain_mask
                else:
                    hide_side_chain_mask = torch.ones(
                        input_features["S"].shape, device=input_features["S"].device
                    ).bool()
            # If we are not training, only the side chains of fixed residues
            # are revealed.
            else:
                hide_side_chain_mask = input_features["designed_residue_mask"].clone()
        else:
            hide_side_chain_mask = torch.ones(
                input_features["S"].shape, device=input_features["S"].device
            ).bool()

        # Save the hide side chain mask in the input features.
        input_features["hide_side_chain_mask"] = hide_side_chain_mask

        # Update the mask for the loss to include the hide side chain mask.
        input_features["mask_for_loss"] = (
            input_features["mask_for_loss"] & input_features["hide_side_chain_mask"]
        )

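    # -----------------------------------------------------------------------
    # Editor's note - illustrative sketch only, not part of the package. The
    # training-time branch above uses a two-level scheme: one Bernoulli draw
    # (overall_atomize_side_chain_probability) decides whether any side chains
    # are revealed at all, and a per-residue draw
    # (per_residue_atomize_side_chain_probability) decides which ones. A
    # standalone plain-torch restatement, with the function name invented here:
    #
    #     def sketch_hide_side_chain_mask(S, p_overall, p_per_residue):
    #         """True = hide the side chain, False = reveal it (training time)."""
    #         if torch.rand(()) < p_overall:
    #             reveal = torch.rand(S.shape, device=S.device) < p_per_residue
    #             return ~reveal
    #         return torch.ones(S.shape, dtype=torch.bool, device=S.device)
    # -----------------------------------------------------------------------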
    def encode(self, input_features, graph_features):
        """
        Encode the protein features with ligand context.

        NOTE: M = self.graph_featurization_module.num_context_atoms, the number
        of ligand atoms in each residue subgraph.

        Args:
            input_features (dict): Input features containing the residue mask.
                - residue_mask (torch.Tensor): [B, L] - mask for the residues.
                - ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - mask for the
                    ligand subgraph nodes.
            graph_features (dict): Graph features containing the featurized
                node and edge inputs.
                - E_protein_to_ligand (torch.Tensor): [B, L, M,
                    self.num_edge_features] - protein-ligand edge features.
                - ligand_subgraph_nodes (torch.Tensor): [B, L, M,
                    self.num_node_features] - ligand subgraph node features.
                - ligand_subgraph_edges (torch.Tensor): [B, L, M, M,
                    self.num_edge_features] - ligand subgraph edge features.
        Returns:
            encoder_features (dict): Encoded features containing the encoded
                protein node and protein edge features.
                - h_V (torch.Tensor): [B, L, self.hidden_dim] - the protein node
                    features after protein encoding and ligand context
                    encoding.
                - h_E (torch.Tensor): [B, L, K, self.hidden_dim] - the protein
                    edge features after protein encoding message passing.
        """
        # Use the parent encode method to get the initial protein encoding.
        encoder_features = super().encode(input_features, graph_features)

        # Check the encoder features.
        if "h_V" not in encoder_features:
            raise ValueError("Encoder features must contain 'h_V' key.")

        # Check that the input features contain the necessary keys.
        if "residue_mask" not in input_features:
            raise ValueError("Input features must contain 'residue_mask' key.")
        if "ligand_subgraph_Y_m" not in input_features:
            raise ValueError("Input features must contain 'ligand_subgraph_Y_m' key.")

        # Check that the graph features contain the necessary keys.
        if "E_protein_to_ligand" not in graph_features:
            raise ValueError("Graph features must contain 'E_protein_to_ligand' key.")
        if "ligand_subgraph_nodes" not in graph_features:
            raise ValueError("Graph features must contain 'ligand_subgraph_nodes' key.")
        if "ligand_subgraph_edges" not in graph_features:
            raise ValueError("Graph features must contain 'ligand_subgraph_edges' key.")

        # Compute the protein-ligand edge feature encoding.
        # h_E_protein_to_ligand [B, L, M, self.hidden_dim] - the embedding of
        # the protein-ligand edge features.
        h_E_protein_to_ligand = self.W_protein_to_ligand_edges_embed(
            graph_features["E_protein_to_ligand"]
        )

        # Construct the starting context features, to aggregate the ligand
        # context; will be updated in the context encoder.
        # h_V_context [B, L, self.hidden_dim] - the embedding of context.
        h_V_context = self.W_protein_encoding_embed(encoder_features["h_V"])

        # Construct the ligand subgraph edge mask.
        # ligand_subgraph_Y_m_edges [B, L, M, M] - the mask for the
        # ligand-ligand subgraph edges.
        ligand_subgraph_Y_m_edges = (
            input_features["ligand_subgraph_Y_m"][:, :, :, None]
            * input_features["ligand_subgraph_Y_m"][:, :, None, :]
        )

        # Embed the ligand nodes.
        # ligand_subgraph_nodes [B, L, M, self.hidden_dim] - the embedding of
        # the ligand nodes in the subgraph.
        h_ligand_subgraph_nodes = self.W_ligand_nodes_embed(
            graph_features["ligand_subgraph_nodes"]
        )

        # Embed the ligand edges.
        # ligand_subgraph_edges [B, L, M, M, self.hidden_dim] - the embedding
        # of the ligand edges in the subgraph.
        h_ligand_subgraph_edges = self.W_ligand_edges_embed(
            graph_features["ligand_subgraph_edges"]
        )

        # Run the context encoder layers for the protein-ligand context.
        for i in range(self.num_context_encoding_layers):
            # Message passing in the ligand subgraph.
            # BUG: to replicate the original LigandMPNN, destination nodes are
            # not concatenated into ligand_subgraph_edges; this breaks message
            # passing in the small molecule graph (no message passing in the
            # ligand subgraph).
            h_ligand_subgraph_nodes = torch.utils.checkpoint.checkpoint(
                self.ligand_context_encoder_layers[i],
                h_ligand_subgraph_nodes,
                h_ligand_subgraph_edges,
                input_features["ligand_subgraph_Y_m"],
                ligand_subgraph_Y_m_edges,
                use_reentrant=False,
            )

            # Concatenate the protein-ligand edge features with the ligand
            # hidden node features (effectively treating the ligand subgraph
            # node features as protein-ligand edge features).
            # h_E_protein_to_ligand_cat [B, L, M, 2 * self.hidden_dim] - the
            # concatenated protein-ligand edge features.
            h_E_protein_to_ligand_cat = torch.cat(
                [h_E_protein_to_ligand, h_ligand_subgraph_nodes], -1
            )

            # h_V_context [B, L, self.hidden_dim] - the updated context node
            # features. Message passing from ligand subgraph to the protein.
            h_V_context = torch.utils.checkpoint.checkpoint(
                self.protein_ligand_context_encoder_layers[i],
                h_V_context,
                h_E_protein_to_ligand_cat,
                input_features["residue_mask"],
                input_features["ligand_subgraph_Y_m"],
                use_reentrant=False,
            )

        # Final context embedding.
        h_V_context = self.W_final_context_embed(h_V_context)

        # Update the protein node features with the context through a residual
        # connection (after applying dropout and layer norm to the context).
        encoder_features["h_V"] = encoder_features["h_V"] + self.final_context_norm(
            self.dropout(h_V_context)
        )

        return encoder_features

    def forward(self, network_input):
        """
        Forward pass for the LigandMPNN model, which uses the same forward
        function as ProteinMPNN; it is repeated here for documentation purposes.

        A NOTE on shapes (in addition to ProteinMPNN):
            - N = number of ligand atoms
            - M = self.num_context_atoms (number of ligand atom neighbors)

        Args:
            network_input (dict): Dictionary containing the input to the
                network.
                - input_features (dict): dictionary containing input features
                    and all necessary information for the model to run, in
                    addition to the input features for the ProteinMPNN model.
                    - Y (torch.Tensor): [B, N, 3] - 3D coordinates of the
                        ligand atoms.
                    - Y_m (torch.Tensor): [B, N] - mask indicating which ligand
                        atoms are valid.
                    - Y_t (torch.Tensor): [B, N] - element types of the ligand
                        atoms.
                    - atomize_side_chains (bool): Whether to atomize side
                        chains of fixed residues.
        Side Effects:
            Any changes to input_features denoted below also mutate the
            original input features.
        Returns:
            network_output (dict): Output dictionary containing the requested
                features based on the input features' "features_to_return" key,
                in addition to the output features from the ProteinMPNN model.
                - input_features (dict): The input features from above, with
                    the following keys added or modified:
                    - hide_side_chain_mask (torch.Tensor): [B, L] - mask for
                        hiding side chains, where True is a residue with hidden
                        side chains, and False is a residue with revealed side
                        chains.
                    - Y (torch.Tensor): [B, N, 3] - 3D coordinates of the ligand
                        atoms with added Gaussian noise.
                    - Y_pre_noise (torch.Tensor): [B, N, 3] -
                        3D coordinates of the ligand atoms before adding
                        Gaussian noise ('Y' before noise).
                    - ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - 3D
                        coordinates of nearest ligand/atomized side chain
                        atoms to the virtual atoms for each residue.
                    - ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - mask
                        indicating which nearest ligand/atomized side chain
                        atoms to the virtual atoms are valid.
                    - ligand_subgraph_Y_t (torch.Tensor): [B, L, M] -
                        element types of the nearest ligand/atomized side chain
                        atoms to the virtual atoms for each residue.
                - graph_features (dict): The graph features.
                    - E_protein_to_ligand (torch.Tensor):
                        [B, L, M, num_node_output_features] - protein to
                        ligand subgraph edges; can also be considered node
                        features of the protein residues (although they are not
                        used as such).
                    - ligand_subgraph_nodes (torch.Tensor):
                        [B, L, M, num_node_output_features] - ligand atom type
                        information, embedded as node features.
                    - ligand_subgraph_edges (torch.Tensor):
                        [B, L, M, M, num_edge_output_features] - embedded and
                        normalized radial basis function embedding of the
                        distances between the ligand atoms in each residue
                        subgraph.
        """
        return super().forward(network_input)
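# ---------------------------------------------------------------------------
# Editor's note - illustrative sketch only, not part of the package. In
# addition to the ProteinMPNN inputs, the LigandMPNN contract documented
# above expects ligand coordinates, masks, element types, and the side-chain
# atomization flag; tensor shapes follow the docstring, and the placeholder
# variables are assumed to come from the package's collators.
#
#     input_features = {
#         "decode_type": "auto_regressive",
#         "Y": Y,                        # [B, N, 3] ligand atom coordinates
#         "Y_m": Y_m,                    # [B, N]    ligand atom validity mask
#         "Y_t": Y_t,                    # [B, N]    ligand atom element types
#         "atomize_side_chains": True,   # atomize side chains of fixed residues
#         # ... remaining ProteinMPNN feature keys ...
#     }
#     network_output = ligand_mpnn({"input_features": input_features})
# ---------------------------------------------------------------------------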