rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
mpnn/model/mpnn.py ADDED
@@ -0,0 +1,2632 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from atomworks.constants import UNKNOWN_AA
4
+ from mpnn.model.layers.graph_embeddings import (
5
+ ProteinFeatures,
6
+ ProteinFeaturesLigand,
7
+ ProteinFeaturesMembrane,
8
+ ProteinFeaturesPSSM,
9
+ )
10
+ from mpnn.model.layers.message_passing import (
11
+ DecLayer,
12
+ EncLayer,
13
+ cat_neighbors_nodes,
14
+ gather_nodes,
15
+ )
16
+ from mpnn.utils.probability import sample_bernoulli_rv
17
+
18
+
19
+ class ProteinMPNN(nn.Module):
20
+ """
21
+ Class for default ProteinMPNN.
22
+ """
23
+
24
+ HAS_NODE_FEATURES = False
25
+
26
+ @staticmethod
27
+ def init_weights(module):
28
+ """
29
+ Initialize the weights of the module.
30
+
31
+ Args:
32
+ module (nn.Module): The module to initialize.
33
+ Side Effects:
34
+ Initializes the weights of the module using Xavier uniform
35
+ initialization for parameters with a dimension greater than 1.
36
+ """
37
+ # Initialize the weights of the module using Xavier uniform, skipping
38
+ # any parameters with a dimension of 1 or less (for example, biases).
39
+ for parameter in module.parameters():
40
+ if parameter.dim() > 1:
41
+ nn.init.xavier_uniform_(parameter)
42
+
43
+ def __init__(
44
+ self,
45
+ num_node_features=128,
46
+ num_edge_features=128,
47
+ hidden_dim=128,
48
+ num_encoder_layers=3,
49
+ num_decoder_layers=3,
50
+ num_neighbors=48,
51
+ dropout_rate=0.1,
52
+ num_positional_embeddings=16,
53
+ min_rbf_mean=2.0,
54
+ max_rbf_mean=22.0,
55
+ num_rbf=16,
56
+ graph_featurization_module=None,
57
+ ):
58
+ """
59
+ Setup the ProteinMPNN model.
60
+
61
+ Args:
62
+ num_node_features (int): Number of node features.
63
+ num_edge_features (int): Number of edge features.
64
+ hidden_dim (int): Hidden dimension size.
65
+ num_encoder_layers (int): Number of encoder layers.
66
+ num_decoder_layers (int): Number of decoder layers.
67
+ num_neighbors (int): Number of neighbors for each polymer residue.
68
+ dropout_rate (float): Dropout rate.
69
+ num_positional_embeddings (int): Number of positional embeddings.
70
+ min_rbf_mean (float): Minimum radial basis function mean.
71
+ max_rbf_mean (float): Maximum radial basis function mean.
72
+ num_rbf (int): Number of radial basis functions.
73
+ graph_featurization_module (nn.Module, optional): Custom graph
74
+ featurization module. If None, the default ProteinFeatures
75
+ module is used.
76
+ """
77
+ super(ProteinMPNN, self).__init__()
78
+
79
+ # Internal dimensions
80
+ self.num_node_features = num_node_features
81
+ self.num_edge_features = num_edge_features
82
+ self.hidden_dim = hidden_dim
83
+
84
+ # Number of layers in the encoder and decoder
85
+ self.num_encoder_layers = num_encoder_layers
86
+ self.num_decoder_layers = num_decoder_layers
87
+
88
+ # Dropout rate
89
+ self.dropout_rate = dropout_rate
90
+
91
+ # Module for featurizing the graph.
92
+ if graph_featurization_module is not None:
93
+ self.graph_featurization_module = graph_featurization_module
94
+ else:
95
+ self.graph_featurization_module = ProteinFeatures(
96
+ num_edge_output_features=num_edge_features,
97
+ num_node_output_features=num_node_features,
98
+ num_positional_embeddings=num_positional_embeddings,
99
+ min_rbf_mean=min_rbf_mean,
100
+ max_rbf_mean=max_rbf_mean,
101
+ num_rbf=num_rbf,
102
+ num_neighbors=num_neighbors,
103
+ )
104
+
105
+ # Provide a shorter reference to the graph featurization token-to-idx
106
+ # mapping.
107
+ self.token_to_idx = self.graph_featurization_module.TOKEN_ENCODING.token_to_idx
108
+
109
+ # Size of the vocabulary
110
+ self.vocab_size = self.graph_featurization_module.TOKEN_ENCODING.n_tokens
111
+
112
+ # Unknown residue token indices, from the TOKEN_ENCODING.
113
+ self.unknown_token_indices = list(
114
+ map(
115
+ lambda token: self.token_to_idx[token],
116
+ self.graph_featurization_module.TOKEN_ENCODING.unknown_tokens,
117
+ )
118
+ )
119
+
120
+ # Linear layer for the edge features.
121
+ self.W_e = nn.Linear(num_edge_features, hidden_dim, bias=True)
122
+
123
+ # Linear layer for the sequence features.
124
+ self.W_s = nn.Embedding(self.vocab_size, hidden_dim)
125
+
126
+ if self.HAS_NODE_FEATURES:
127
+ # Linear layer for the node features.
128
+ self.W_v = nn.Linear(num_node_features, hidden_dim, bias=True)
129
+
130
+ # Dropout layer
131
+ self.dropout = nn.Dropout(dropout_rate)
132
+
133
+ # Encoder layers
134
+ self.encoder_layers = nn.ModuleList(
135
+ [
136
+ EncLayer(hidden_dim, hidden_dim * 3, dropout=dropout_rate)
137
+ for _ in range(num_encoder_layers)
138
+ ]
139
+ )
140
+
141
+ # Decoder layers
142
+ self.decoder_layers = nn.ModuleList(
143
+ [
144
+ DecLayer(hidden_dim, hidden_dim * 4, dropout=dropout_rate)
145
+ for _ in range(num_decoder_layers)
146
+ ]
147
+ )
148
+
149
+ # Linear layer for the output
150
+ self.W_out = nn.Linear(hidden_dim, self.vocab_size, bias=True)
151
+
152
+ def construct_known_residue_mask(self, S):
153
+ """
154
+ Construct a mask for the known residues based on the sequence S.
155
+
156
+ Args:
157
+ S (torch.Tensor): [B, L] - the sequence of residues.
158
+ Returns:
159
+ known_residue_mask (torch.Tensor): [B, L] - mask for known residues,
160
+ where True is a residue with one of the canonical residue types,
161
+ and False is a residue with an unknown residue type.
162
+ """
163
+ # Create a mask for known residues.
164
+ known_residue_mask = torch.isin(
165
+ S,
166
+ torch.tensor(self.unknown_token_indices, device=S.device, dtype=S.dtype),
167
+ invert=True,
168
+ )
169
+
170
+ return known_residue_mask
171
+
172
+ def sample_and_construct_masks(self, input_features):
173
+ """
174
+ Sample and construct masks for the input features.
175
+
176
+ Args:
177
+ input_features (dict): Input features containing the residue mask.
178
+ - residue_mask (torch.Tensor): [B, L] - mask for the residues.
179
+ - S (torch.Tensor): [B, L] - sequence of residues.
180
+ - designed_residue_mask (torch.Tensor): [B, L] - mask for the
181
+ designed residues.
182
+ Side Effects:
183
+ input_features["residue_mask"] (torch.Tensor): [B, L] - mask for the
184
+ residues, where True is a residue that is valid and False is a
185
+ residue that is invalid.
186
+ input_features["known_residue_mask"] (torch.Tensor): [B, L] - mask
187
+ for known residues, where True is a residue with one of the
188
+ canonical residue types, and False is a residue with an unknown
189
+ residue type.
190
+ input_features["designed_residue_mask"] (torch.Tensor): [B, L] -
191
+ mask for designed residues, where True is a residue that is
192
+ designed, and False is a residue that is not designed.
193
+ input_features["mask_for_loss"] (torch.Tensor): [B, L] - mask for
194
+ loss, where True is a residue that is included in the loss
195
+ calculation, and False is a residue that is not included in the
196
+ loss calculation.
197
+ """
198
+ # Check that the input features contain the necessary keys.
199
+ if "residue_mask" not in input_features:
200
+ raise ValueError("Input features must contain 'residue_mask' key.")
201
+ if "S" not in input_features:
202
+ raise ValueError("Input features must contain 'S' key.")
203
+ if "designed_residue_mask" not in input_features:
204
+ raise ValueError("Input features must contain 'designed_residue_mask' key.")
205
+
206
+ # Ensure that the residue_mask is a boolean.
207
+ input_features["residue_mask"] = input_features["residue_mask"].bool()
208
+
209
+ # Mask is true for canonical residues, false for unknown residues.
210
+ input_features["known_residue_mask"] = self.construct_known_residue_mask(
211
+ input_features["S"]
212
+ )
213
+
214
+ # Mask for residues that are designed. If the designed_residue_mask
215
+ # is None, then we assume that all valid residues are designed.
216
+ if input_features["designed_residue_mask"] is None:
217
+ input_features["designed_residue_mask"] = input_features[
218
+ "residue_mask"
219
+ ].clone()
220
+ else:
221
+ input_features["designed_residue_mask"] = input_features[
222
+ "designed_residue_mask"
223
+ ].bool()
224
+
225
+ # Chech that the designed_residue_mask is a subset of valid residues.
226
+ if not torch.all(
227
+ input_features["residue_mask"][input_features["designed_residue_mask"]]
228
+ ):
229
+ raise ValueError("Designed residues must all be valid residues.")
230
+
231
+ # Mask for loss.
232
+ input_features["mask_for_loss"] = (
233
+ input_features["residue_mask"]
234
+ & input_features["known_residue_mask"]
235
+ & input_features["designed_residue_mask"]
236
+ )
237
+
238
+ def graph_featurization(self, input_features):
239
+ """
240
+ Apply the graph featurization to the input features.
241
+
242
+ Args:
243
+ input_features (dict): Input features to be featurized.
244
+ Returns:
245
+ graph_features (dict): Featurized graph features (contains both node
246
+ and edge features).
247
+ """
248
+ graph_features = self.graph_featurization_module(input_features)
249
+
250
+ return graph_features
251
+
252
+ def encode(self, input_features, graph_features):
253
+ """
254
+ Encode the protein features with message passing.
255
+
256
+ # NOTE: K = self.num_neighbors
257
+ Args:
258
+ input_features (dict): Input features containing the residue mask.
259
+ - residue_mask (torch.Tensor): [B, L] - mask for the residues.
260
+ graph_features (dict): Graph features containing the featurized
261
+ node and edge inputs.
262
+ - E (torch.Tensor): [B, L, K, self.num_edge_features] - edge
263
+ features.
264
+ - E_idx (torch.Tensor): [B, L, K] - edge indices.
265
+ - V (torch.Tensor, optional): [B, L, self.num_node_features] -
266
+ node features (if HAS_NODE_FEATURES is True).
267
+ Returns:
268
+ encoder_features (dict): Encoded features containing the encoded
269
+ protein node and protein edge features.
270
+ - h_V (torch.Tensor): [B, L, self.hidden_dim] - the protein node
271
+ features after encoding message passing.
272
+ - h_E (torch.Tensor): [B, L, K, self.hidden_dim] - the protein
273
+ edge features after encoding message passing.
274
+ """
275
+ # Check that the input features contains the necessary keys.
276
+ if "residue_mask" not in input_features:
277
+ raise ValueError("Input features must contain 'residue_mask' key.")
278
+
279
+ # Check that the graph features contains the necessary keys.
280
+ if "E" not in graph_features:
281
+ raise ValueError("Graph features must contain 'E' key.")
282
+ if "E_idx" not in graph_features:
283
+ raise ValueError("Graph features must contain 'E_idx' key.")
284
+
285
+ B, L, _, _ = graph_features["E"].shape
286
+
287
+ # Embed the node features.
288
+ # h_V [B, L, self.num_node_features] - the embedding of the node
289
+ # features.
290
+ if self.HAS_NODE_FEATURES:
291
+ if "V" not in graph_features:
292
+ raise ValueError("Graph features must contain 'V' key.")
293
+ h_V = self.W_v(graph_features["V"])
294
+ else:
295
+ h_V = torch.zeros(
296
+ (B, L, self.num_node_features), device=graph_features["E"].device
297
+ )
298
+
299
+ # Embed the edge features.
300
+ # h_E [B, L, K, self.edge_features] - the embedding of the edge
301
+ # features.
302
+ h_E = self.W_e(graph_features["E"])
303
+
304
+ # Gather the per-residue mask of the nearest neighbors.
305
+ # mask_E [B, L, K] - the mask for the edges, gathered at the
306
+ # neighbor indices.
307
+ mask_E = gather_nodes(
308
+ input_features["residue_mask"].unsqueeze(-1), graph_features["E_idx"]
309
+ ).squeeze(-1)
310
+ mask_E = input_features["residue_mask"].unsqueeze(-1) * mask_E
311
+
312
+ # Perform the message passing in the encoder.
313
+ for layer in self.encoder_layers:
314
+ # h_V [B, L, self.hidden_dim] - the updated node features.
315
+ # h_E [B, L, K, self.hidden_dim] - the updated edge features.
316
+ h_V, h_E = torch.utils.checkpoint.checkpoint(
317
+ layer,
318
+ h_V,
319
+ h_E,
320
+ graph_features["E_idx"],
321
+ mask_V=input_features["residue_mask"],
322
+ mask_E=mask_E,
323
+ use_reentrant=False,
324
+ )
325
+
326
+ # Create the encoder features dictionary.
327
+ encoder_features = {
328
+ "h_V": h_V,
329
+ "h_E": h_E,
330
+ }
331
+
332
+ return encoder_features
333
+
334
+ def setup_causality_masks(self, input_features, graph_features, decoding_eps=1e-4):
335
+ """
336
+ Setup the causality masks for the decoder. This can involve sampling
337
+ the decoding order.
338
+
339
+ Args:
340
+ input_features (dict): Input features containing the residue mask.
341
+ - residue_mask (torch.Tensor): [B, L] - mask for the residues.
342
+ - designed_residue_mask (torch.Tensor): [B, L] - mask for the
343
+ designed residues.
344
+ - symmetry_equivalence_group (torch.Tensor, optional): [B, L] -
345
+ an integer for every residue, indicating the symmetry group
346
+ that it belongs to. If None, the residues are not grouped by
347
+ symmetry. For example, if residue i and j should be decoded
348
+ symmetrically, then symmetry_equivalence_group[i] ==
349
+ symmetry_equivalence_group[j]. Must be torch.int64 to allow
350
+ for use as an index. These values should range from 0 to
351
+ the maximum number of symmetry groups - 1 for each example.
352
+ - causality_pattern (str): The pattern of causality to use for
353
+ the decoder.
354
+ graph_features (dict): Graph features containing the featurized
355
+ node and edge inputs.
356
+ - E_idx (torch.Tensor): [B, L, K] - edge indices.
357
+ decoding_eps (float): Small epsilon value added to the
358
+ decode_last_mask to prevent the case where every randomly
359
+ sampled number is multiplied by zero, which would result
360
+ in an incorrect decoding order.
361
+ Returns:
362
+ decoder_features (dict): Decoding features containing the decoding
363
+ order and masks for the decoder.
364
+ - causal_mask (torch.Tensor): [B, L, K, 1] - the causal mask for
365
+ the decoder.
366
+ - anti_causal_mask (torch.Tensor): [B, L, K, 1] - the
367
+ anti-causal mask for the decoder.
368
+ - decoding_order (torch.Tensor): [B, L] - the order in which the
369
+ residues should be decoded.
370
+ - decode_last_mask (torch.Tensor): [B, L] - mask for residues
371
+ that should be decoded last, where False is a residue that
372
+ should be decoded first (invalid or fixed), and True is a
373
+ residue that should not be decoded first (designed
374
+ residues).
375
+ """
376
+ # Check that the input features contains the necessary keys.
377
+ if "residue_mask" not in input_features:
378
+ raise ValueError("Input features must contain 'residue_mask' key.")
379
+ if "designed_residue_mask" not in input_features:
380
+ raise ValueError("Input features must contain 'designed_residue_mask' key.")
381
+ if "symmetry_equivalence_group" not in input_features:
382
+ raise ValueError(
383
+ "Input features must contain 'symmetry_equivalence_group' key."
384
+ )
385
+ if "causality_pattern" not in input_features:
386
+ raise ValueError("Input features must contain 'causality_pattern' key.")
387
+
388
+ # Check that the encoder features contains the necessary keys.
389
+ if "E_idx" not in graph_features:
390
+ raise ValueError("Graph features must contain 'E_idx' key.")
391
+
392
+ B, L = input_features["residue_mask"].shape
393
+
394
+ # decode_last_mask [B, L] - mask for residues that should be
395
+ # decoded last, where False is a residue that should be decoded first
396
+ # (invalid or fixed), and True is a residue that should not be decoded
397
+ # first (designed residues).
398
+ decode_last_mask = (
399
+ input_features["residue_mask"] & input_features["designed_residue_mask"]
400
+ ).bool()
401
+
402
+ # Compute the noise for the decoding order.
403
+ if input_features["symmetry_equivalence_group"] is None:
404
+ # noise [B, L] - the noise for each residue, sampled from a normal
405
+ # distribution. This is used to randomly sample the decoding order.
406
+ noise = torch.randn((B, L), device=input_features["residue_mask"].device)
407
+ else:
408
+ # Assume that all symmetry groups are non-negative.
409
+ assert input_features["symmetry_equivalence_group"].min() >= 0
410
+
411
+ # Compute the maximum number of symmetry groups.
412
+ G = int(input_features["symmetry_equivalence_group"].max().item()) + 1
413
+
414
+ # noise_per_group [B, G] - the noise for each
415
+ # symmetry group, sampled from a normal distribution.
416
+ noise_per_group = torch.randn(
417
+ (B, G), device=input_features["residue_mask"].device
418
+ )
419
+
420
+ # batch_idx [B, 1] - the batch indices for each example.
421
+ batch_idx = torch.arange(
422
+ B, device=input_features["residue_mask"].device
423
+ ).unsqueeze(-1)
424
+
425
+ # noise [B, L] - the noise for each residue, sampled from a normal
426
+ # distribution, where the noise is the same for residues in the same
427
+ # symmetry group.
428
+ noise = noise_per_group[
429
+ batch_idx, input_features["symmetry_equivalence_group"]
430
+ ]
431
+
432
+ # decoding_order [B, L] - the order in which the residues should be
433
+ # decoded. Specifically, decoding_order[b, i] = j specifies that the
434
+ # jth residue should be decoded ith. Sampled for every example.
435
+ # Numbers will be smaller where decode_last_mask is False (0), and
436
+ # larger where decode_last_mask is True (1), leading to the appropriate
437
+ # index ordering after the argsort.
438
+ decoding_order = torch.argsort(
439
+ (decode_last_mask.float() + decoding_eps) * torch.abs(noise), dim=-1
440
+ )
441
+
442
+ # permutation_matrix_reverse [B, L, L] - the reverse permutation
443
+ # matrix (the transpose/inverse of the permutation matrix) computed from
444
+ # the decoding order; such that permutation_matrix_reverse[i, j] = 1 if
445
+ # the ith entry in the original will be sent to the jth position (the
446
+ # ith row/column in the original all by all causal mask will be sent
447
+ # to the jth row/column in the permuted all by all causal mask).
448
+ permutation_matrix_reverse = torch.nn.functional.one_hot(
449
+ decoding_order, num_classes=L
450
+ ).float()
451
+
452
+ # Create the all by all causal mask for the decoder.
453
+ # causal_mask_all_by_all [L, L] - the all by all causal mask for the
454
+ # decoder, constructed based on the specified causality pattern.
455
+ if input_features["causality_pattern"] == "auto_regressive":
456
+ # left_to_right_causal_mask [L, L] - the causal mask for the
457
+ # left-to-right attention (lower triangular with zeros on the
458
+ # diagonal). Residue at position i can "see" residues at positions
459
+ # j < i, but not at positions j >= i.
460
+ left_to_right_causal_mask = 1 - torch.triu(
461
+ torch.ones(L, L, device=input_features["residue_mask"].device)
462
+ )
463
+
464
+ causal_mask_all_by_all = left_to_right_causal_mask
465
+ elif input_features["causality_pattern"] == "unconditional":
466
+ # zeros_causal_mask [L, L] - the causal mask for the decoder,
467
+ # where all entries are zeros. Residue at position i cannot see
468
+ # any other residues, including itself.
469
+ zeros_causal_mask = torch.zeros(
470
+ (L, L), device=input_features["residue_mask"].device
471
+ )
472
+
473
+ causal_mask_all_by_all = zeros_causal_mask
474
+ elif input_features["causality_pattern"] == "conditional":
475
+ # ones_causal_mask [L, L] - the causal mask for the decoder,
476
+ # where all entries are ones. Residue at position i can see all
477
+ # other residues, including itself.
478
+ ones_causal_mask = torch.ones(
479
+ (L, L), device=input_features["residue_mask"].device
480
+ )
481
+
482
+ causal_mask_all_by_all = ones_causal_mask
483
+ elif input_features["causality_pattern"] == "conditional_minus_self":
484
+ # I [L, L] - the identity matrix, repeated along the batch.
485
+ I = torch.eye(L, device=input_features["residue_mask"].device)
486
+
487
+ # ones_minus_self_causal_mask [L, L] - the causal mask for the
488
+ # decoder, where all entries are ones except for the diagonal
489
+ # entries, which are zeros. Residue at position i can see all other
490
+ # residues, but not itself.
491
+ ones_minus_self_causal_mask = (
492
+ torch.ones((L, L), device=input_features["residue_mask"].device) - I
493
+ )
494
+
495
+ causal_mask_all_by_all = ones_minus_self_causal_mask
496
+ else:
497
+ raise ValueError(
498
+ "Unknown causality pattern: " + f"{input_features['causality_pattern']}"
499
+ )
500
+
501
+ # permuted_causal_mask_all_by_all [B, L, L] - the causal mask for the
502
+ # decoder, permuted according to the decoding order.
503
+ permuted_causal_mask_all_by_all = torch.einsum(
504
+ "ij, biq, bjp->bqp",
505
+ causal_mask_all_by_all,
506
+ permutation_matrix_reverse,
507
+ permutation_matrix_reverse,
508
+ )
509
+
510
+ # If the symmetry equivalence group is not None, then we need to
511
+ # mask out residues that belong to the same symmetry group.
512
+ if input_features["symmetry_equivalence_group"] is not None:
513
+ # same_symmetry_group [B, L, L] - a mask for the residues that
514
+ # belong to the same symmetry group, where True is a residue pair
515
+ # that belongs to the same symmetry group, and False is a residue
516
+ # pair that does not belong to the same symmetry group.
517
+ same_symmetry_group = (
518
+ input_features["symmetry_equivalence_group"][:, :, None]
519
+ == input_features["symmetry_equivalence_group"][:, None, :]
520
+ )
521
+
522
+ permuted_causal_mask_all_by_all[same_symmetry_group] = 0.0
523
+
524
+ # causal_mask_nearest_neighbors [B, L, K, 1] - the causal mask for
525
+ # the decoder, gathered at the neighbor indices. This limits the
526
+ # attention to the nearest neighbors.
527
+ causal_mask_nearest_neighbors = torch.gather(
528
+ permuted_causal_mask_all_by_all, 2, graph_features["E_idx"]
529
+ ).unsqueeze(-1)
530
+
531
+ # causal_mask [B, L, K, 1] - the final causal mask for the decoder;
532
+ # masked version of causal_mask_nearest_neighbors.
533
+ causal_mask = (
534
+ causal_mask_nearest_neighbors
535
+ * input_features["residue_mask"].view([B, L, 1, 1]).float()
536
+ )
537
+
538
+ # anti_causal_mask [B, L, K, 1] - the anti-causal mask for the decoder.
539
+ anti_causal_mask = (1.0 - causal_mask_nearest_neighbors) * input_features[
540
+ "residue_mask"
541
+ ].view([B, L, 1, 1]).float()
542
+
543
+ # Add the masks to the decoder features.
544
+ decoder_features = {
545
+ "causal_mask": causal_mask,
546
+ "anti_causal_mask": anti_causal_mask,
547
+ "decoding_order": decoding_order,
548
+ "decode_last_mask": decode_last_mask,
549
+ }
550
+
551
+ return decoder_features
552
+
553
+ def repeat_along_batch(self, input_features, graph_features, encoder_features):
554
+ """
555
+ Given the input features, graph features, and encoder features,
556
+ repeat the samples along the batch dimension. This is useful during
557
+ inference, to prevent re-running the encoder for every sample (since
558
+ the encoder is deterministic and sequence-agnostic).
559
+
560
+ NOTE: if `repeat_sample_num` is not None and greater than 1, then
561
+ B must be 1, since repeating samples along the batch dimension is not
562
+ supported when more than one sample is provided in the batch.
563
+
564
+ Args:
565
+ input_features (dict): Input features containing the residue mask
566
+ and sequence.
567
+ - S (torch.Tensor): [B, L] - sequence of residues.
568
+ - residue_mask (torch.Tensor): [B, L] - mask for the residues.
569
+ - temperature (torch.Tensor, optional): [B, L] - the per-residue
570
+ temperature to use for sampling. If None, the code will
571
+ implicitly use a temperature of 1.0.
572
+ - bias (torch.Tensor, optional): [B, L, 21] - the per-residue
573
+ bias to use for sampling. If None, the code will implicitly
574
+ use a bias of 0.0 for all residues.
575
+ - pair_bias (torch.Tensor, optional): [B, L, 21, L, 21] - the
576
+ per-residue pair bias to use for sampling. If None, the code
577
+ will implicitly use a pair bias of 0.0 for all residue
578
+ pairs.
579
+ - symmetry_equivalence_group (torch.Tensor, optional): [B, L] -
580
+ an integer for every residue, indicating the symmetry group
581
+ that it belongs to. If None, the residues are not grouped by
582
+ symmetry. For example, if residue i and j should be decoded
583
+ symmetrically, then symmetry_equivalence_group[i] ==
584
+ symmetry_equivalence_group[j]. Must be torch.int64 to allow
585
+ for use as an index. These values should range from 0 to
586
+ the maximum number of symmetry groups - 1 for each example.
587
+ - symmetry_weight (torch.Tensor, optional): [B, L] - the weight
588
+ for the symmetry equivalence group. If None, the code will
589
+ implicitly use a weight of 1.0 for all residues.
590
+ - repeat_sample_num (int, optional): Number of times to repeat
591
+ the samples along the batch dimension. If None, no
592
+ repetition is performed. If greater than 1, the samples
593
+ are repeated along the batch dimension. If greater than 1,
594
+ B must be 1, since repeating samples along the batch
595
+ dimension is not supported when more than one sample is
596
+ provided in the batch.
597
+ graph_features (dict): Graph features containing the featurized
598
+ node and edge inputs.
599
+ - E_idx (torch.Tensor): [B, L, K] - edge indices.
600
+ encoder_features (dict): Encoder features containing the encoded
601
+ protein node and protein edge features.
602
+ - h_V (torch.Tensor): [B, L, H] - the protein node features
603
+ after encoding message passing.
604
+ - h_E (torch.Tensor): [B, L, K, H] - the protein edge features
605
+ after encoding message passing.
606
+ Side Effects:
607
+ input_features["S"] (torch.Tensor): [repeat_sample_num, L] - the
608
+ sequence of residues, repeated along the batch dimension.
609
+ input_features["residue_mask"] (torch.Tensor):
610
+ [repeat_sample_num, L] - the mask for the residues, repeated
611
+ along the batch dimension.
612
+ input_features["mask_for_loss"] (torch.Tensor):
613
+ [repeat_sample_num, L] - the mask for the loss, repeated
614
+ along the batch dimension.
615
+ input_features["designed_residue_mask"] (torch.Tensor):
616
+ [repeat_sample_num, L] - the mask for designed residues,
617
+ repeated along the batch dimension.
618
+ input_features["temperature"] (torch.Tensor, optional):
619
+ [repeat_sample_num, L] - the per-residue temperature to use for
620
+ sampling, repeated along the batch dimension. If None, the code
621
+ will implicitly use a temperature of 1.0.
622
+ input_features["bias"] (torch.Tensor, optional):
623
+ [repeat_sample_num, L, 21] - the per-residue bias to use for
624
+ sampling, repeated along the batch dimension. If None, the code
625
+ will implicitly use a bias of 0.0 for all residues.
626
+ input_features["pair_bias"] (torch.Tensor, optional):
627
+ [repeat_sample_num, L, 21, L, 21] - the per-residue pair bias
628
+ to use for sampling, repeated along the batch dimension. If
629
+ None, the code will implicitly use a pair bias of 0.0 for all
630
+ residue pairs.
631
+ input_features["symmetry_equivalence_group"] (torch.Tensor,
632
+ optional): [repeat_sample_num, L] - the symmetry equivalence
633
+ group for each residue, repeated along the batch dimension.
634
+ input_features["symmetry_weight"] (torch.Tensor, optional):
635
+ [repeat_sample_num, L] - the symmetry weight for each residue,
636
+ repeated along the batch dimension. If None, the code will
637
+ implicitly use a weight of 1.0 for all residues.
638
+ encoder_features["h_V"] (torch.Tensor): [repeat_sample_num, L, H] -
639
+ the protein node features after encoding message passing,
640
+ repeated along the batch dimension.
641
+ encoder_features["h_E"] (torch.Tensor):
642
+ [repeat_sample_num, L, K, H] - the protein edge features
643
+ after encoding message passing, repeated along the batch
644
+ dimension.
645
+ graph_features["E_idx"] (torch.Tensor): [repeat_sample_num, L, K] -
646
+ the edge indices, repeated along the batch dimension.
647
+ """
648
+ # Check that the input features contains the necessary keys.
649
+ if "S" not in input_features:
650
+ raise ValueError("Input features must contain 'S' key.")
651
+ if "residue_mask" not in input_features:
652
+ raise ValueError("Input features must contain 'residue_mask' key.")
653
+ if "mask_for_loss" not in input_features:
654
+ raise ValueError("Input features must contain 'mask_for_loss' key.")
655
+ if "temperature" not in input_features:
656
+ raise ValueError("Input features must contain 'temperature' key.")
657
+ if "bias" not in input_features:
658
+ raise ValueError("Input features must contain 'bias' key.")
659
+ if "pair_bias" not in input_features:
660
+ raise ValueError("Input features must contain 'pair_bias' key.")
661
+ if "symmetry_equivalence_group" not in input_features:
662
+ raise ValueError(
663
+ "Input features must contain 'symmetry_equivalence_group' key."
664
+ )
665
+ if "symmetry_weight" not in input_features:
666
+ raise ValueError("Input features must contain 'symmetry_weight' key.")
667
+ if "repeat_sample_num" not in input_features:
668
+ raise ValueError("Input features must contain 'repeat_sample_num' key.")
669
+
670
+ # Check that the graph features contains the necessary keys.
671
+ if "E_idx" not in graph_features:
672
+ raise ValueError("Graph features must contain 'E_idx' key.")
673
+
674
+ # Check that the encoder features contains the necessary keys.
675
+ if "h_V" not in encoder_features:
676
+ raise ValueError("Encoder features must contain 'h_V' key.")
677
+ if "h_E" not in encoder_features:
678
+ raise ValueError("Encoder features must contain 'h_E' key.")
679
+
680
+ # Repeating a sample along the batch dimension is not supported
681
+ # when more than one sample is provided in the batch.
682
+ if (
683
+ input_features["repeat_sample_num"] is not None
684
+ and input_features["repeat_sample_num"] > 1
685
+ and input_features["S"].shape[0] > 1
686
+ ):
687
+ raise ValueError(
688
+ "Cannot repeat samples when more than one sample "
689
+ + "is provided in the batch."
690
+ )
691
+
692
+ # Repeat the samples along the batch dimension if necessary.
693
+ if (
694
+ input_features["repeat_sample_num"] is not None
695
+ and input_features["repeat_sample_num"] > 1
696
+ ):
697
+ # S [repeat_sample_num, L]
698
+ input_features["S"] = input_features["S"][0].repeat(
699
+ input_features["repeat_sample_num"], 1
700
+ )
701
+ # residue_mask [repeat_sample_num, L]
702
+ input_features["residue_mask"] = input_features["residue_mask"][0].repeat(
703
+ input_features["repeat_sample_num"], 1
704
+ )
705
+ # mask_for_loss [repeat_sample_num, L]
706
+ input_features["mask_for_loss"] = input_features["mask_for_loss"][0].repeat(
707
+ input_features["repeat_sample_num"], 1
708
+ )
709
+ # designed_residue_mask [repeat_sample_num, L]
710
+ input_features["designed_residue_mask"] = input_features[
711
+ "designed_residue_mask"
712
+ ][0].repeat(input_features["repeat_sample_num"], 1)
713
+ if input_features["temperature"] is not None:
714
+ # temperature [repeat_sample_num, L]
715
+ input_features["temperature"] = input_features["temperature"][0].repeat(
716
+ input_features["repeat_sample_num"], 1
717
+ )
718
+ if input_features["bias"] is not None:
719
+ # bias [repeat_sample_num, L, 21]
720
+ input_features["bias"] = input_features["bias"][0].repeat(
721
+ input_features["repeat_sample_num"], 1, 1
722
+ )
723
+ if input_features["pair_bias"] is not None:
724
+ # pair_bias [repeat_sample_num, L, 21, L, 21]
725
+ input_features["pair_bias"] = input_features["pair_bias"][0].repeat(
726
+ input_features["repeat_sample_num"], 1, 1, 1, 1
727
+ )
728
+ if input_features["symmetry_equivalence_group"] is not None:
729
+ # symmetry_equivalence_group [repeat_sample_num, L]
730
+ input_features["symmetry_equivalence_group"] = input_features[
731
+ "symmetry_equivalence_group"
732
+ ][0].repeat(input_features["repeat_sample_num"], 1)
733
+ if input_features["symmetry_weight"] is not None:
734
+ # symmetry_weight [repeat_sample_num, L]
735
+ input_features["symmetry_weight"] = input_features["symmetry_weight"][
736
+ 0
737
+ ].repeat(input_features["repeat_sample_num"], 1)
738
+
739
+ # h_V [repeat_sample_num, L, H]
740
+ encoder_features["h_V"] = encoder_features["h_V"][0].repeat(
741
+ input_features["repeat_sample_num"], 1, 1
742
+ )
743
+ # h_E [repeat_sample_num, L, K, H]
744
+ encoder_features["h_E"] = encoder_features["h_E"][0].repeat(
745
+ input_features["repeat_sample_num"], 1, 1, 1
746
+ )
747
+
748
+ # E_idx [repeat_sample_num, L, K]
749
+ graph_features["E_idx"] = graph_features["E_idx"][0].repeat(
750
+ input_features["repeat_sample_num"], 1, 1
751
+ )
752
+
753
+ def decode_setup(
754
+ self, input_features, graph_features, encoder_features, decoder_features
755
+ ):
756
+ """
757
+ Given the input features, graph features, encoder features, and initial
758
+ decoder features, set up the decoder for the autoregressive decoding.
759
+
760
+ Args:
761
+ input_features (dict): Input features containing the residue mask
762
+ and sequence.
763
+ - S (torch.Tensor): [B, L] - sequence of residues.
764
+ - residue_mask (torch.Tensor): [B, L] - mask for the residues.
765
+ - initialize_sequence_embedding_with_ground_truth (bool):
766
+ If True, initialize the sequence embedding with the ground
767
+ truth sequence S. Else, initialize the sequence
768
+ embedding with zeros.
769
+ graph_features (dict): Graph features containing the featurized
770
+ node and edge inputs.
771
+ - E_idx (torch.Tensor): [B, L, K] - edge indices.
772
+ encoder_features (dict): Encoder features containing the encoded
773
+ protein node and protein edge features.
774
+ - h_V (torch.Tensor): [B, L, H] - the protein node features
775
+ after encoding message passing.
776
+ - h_E (torch.Tensor): [B, L, K, H] - the protein edge features
777
+ after encoding message passing.
778
+ decoder_features (dict): Initial decoder features containing the
779
+ anti-causal mask for the decoder.
780
+ - anti_causal_mask (torch.Tensor): [B, L, K, 1] - the
781
+ anti-causal mask for the decoder.
782
+ Returns:
783
+ h_EXV_encoder_anti_causal (torch.Tensor): [B, L, K, 3H] - the
784
+ encoder embeddings masked with the anti-causal mask.
785
+ mask_E (torch.Tensor): [B, L, K] - the mask for the edges, gathered
786
+ at the neighbor indices.
787
+ h_S (torch.Tensor): [B, L, H] - the sequence embeddings for the
788
+ decoder.
789
+ """
790
+ # Check that the input features contains the necessary keys.
791
+ if "S" not in input_features:
792
+ raise ValueError("Input features must contain 'S' key.")
793
+ if "residue_mask" not in input_features:
794
+ raise ValueError("Input features must contain 'residue_mask' key.")
795
+ if "initialize_sequence_embedding_with_ground_truth" not in input_features:
796
+ raise ValueError(
797
+ "Input features must contain"
798
+ + "'initialize_sequence_embedding_with_ground_truth' key."
799
+ )
800
+
801
+ # Check that the graph features contains the necessary keys.
802
+ if "E_idx" not in graph_features:
803
+ raise ValueError("Graph features must contain 'E_idx' key.")
804
+
805
+ # Check that the encoder features contains the necessary keys.
806
+ if "h_V" not in encoder_features:
807
+ raise ValueError("Encoder features must contain 'h_V' key.")
808
+ if "h_E" not in encoder_features:
809
+ raise ValueError("Encoder features must contain 'h_E' key.")
810
+
811
+ # Check that the decoder features contains the necessary keys.
812
+ if "anti_causal_mask" not in decoder_features:
813
+ raise ValueError("Decoder features must contain 'anti_causal_mask' key.")
814
+
815
+ # Build encoder embeddings.
816
+ # h_EX_encoder [B, L, K, 2H] - h_E_ij cat (0 vector); the edge features
817
+ # concatenated with the zero vector, since there is no sequence
818
+ # information in the encoder.
819
+ h_EX_encoder = cat_neighbors_nodes(
820
+ torch.zeros_like(encoder_features["h_V"]),
821
+ encoder_features["h_E"],
822
+ graph_features["E_idx"],
823
+ )
824
+
825
+ # h_EXV_encoder [B, L, K, 3H] - h_E_ij cat (0 vector) cat h_V_j; the
826
+ # edge features concatenated with the zero vector and the destination
827
+ # node features from the encoder.
828
+ h_EXV_encoder = cat_neighbors_nodes(
829
+ encoder_features["h_V"], h_EX_encoder, graph_features["E_idx"]
830
+ )
831
+
832
+ # h_EXV_encoder_anti_causal [B, L, K, 3H] - the encoder embeddings,
833
+ # masked with the anti-causal mask.
834
+ h_EXV_encoder_anti_causal = h_EXV_encoder * decoder_features["anti_causal_mask"]
835
+
836
+ # Gather the per-residue mask of the nearest neighbors.
837
+ # mask_E [B, L, K] - the mask for the edges, gathered at the
838
+ # neighbor indices.
839
+ mask_E = gather_nodes(
840
+ input_features["residue_mask"].unsqueeze(-1), graph_features["E_idx"]
841
+ ).squeeze(-1)
842
+ mask_E = input_features["residue_mask"].unsqueeze(-1) * mask_E
843
+
844
+ # Build sequence embedding for the decoder.
845
+ # h_S [B, L, H] - the sequence embeddings for the decoder, obtained by
846
+ # embedding the ground truth sequence S.
847
+ if input_features["initialize_sequence_embedding_with_ground_truth"]:
848
+ h_S = self.W_s(input_features["S"])
849
+ else:
850
+ h_S = torch.zeros_like(encoder_features["h_V"])
851
+
852
+ return h_EXV_encoder_anti_causal, mask_E, h_S
853
+
854
+ def logits_to_sample(self, logits, bias, pair_bias, S_for_pair_bias, temperature):
855
+ """
856
+ Convert the logits to log probabilities, probabilities, sampled
857
+ probabilities, predicted sequence, and argmax sequence.
858
+
859
+ Args:
860
+ logits (torch.Tensor): [B, L, self.vocab_size] - the logits for the
861
+ sequence.
862
+ bias (torch.Tensor, optional): [B, L, self.vocab_size] - the
863
+ bias for the sequence. If None, the code will implicitly use a
864
+ bias of 0.0 for all residues.
865
+ pair_bias (torch.Tensor, optional): [B, L, self.vocab_size, L',
866
+ self.vocab_size] - the pair bias for the sequence. Note,
867
+ L is the length for the logits, and L' is the length for the
868
+ S_for_pair_bias. In some cases, L' may be different from L (
869
+ for example, when the logits are only computed for a subset of
870
+ residues). If None, the code will implicitly use a pair bias
871
+ of 0.0 for all residue pairs.
872
+ S_for_pair_bias (torch.Tensor, optional): [B, L'] - the sequence for
873
+ the pair bias. This is used to compute the total pair bias for
874
+ every position. Allowed to be None if pair_bias is None.
875
+ temperature (torch.Tensor, optional): [B, L] - the per-residue
876
+ temperature to use for sampling. If None, the code will
877
+ implicitly use a temperature of 1.0 for all residues.
878
+ Returns:
879
+ sample_dict (dict): A dictionary containing the following keys:
880
+ - log_probs (torch.Tensor): [B, L, self.vocab_size] - the log
881
+ probabilities for the sequence.
882
+ - probs (torch.Tensor): [B, L, self.vocab_size] - the
883
+ probabilities for the sequence.
884
+ - probs_sample (torch.Tensor): [B, L, self.vocab_size] -
885
+ the probabilities for the sequence, with the unknown
886
+ residues zeroed out and the other residues normalized.
887
+ - S_sampled (torch.Tensor): [B, L] - the predicted sequence,
888
+ sampled from the probabilities (unknown residues are not
889
+ sampled).
890
+ - S_argmax (torch.Tensor): [B, L] - the predicted sequence,
891
+ obtained by taking the argmax of the probabilities (unknown
892
+ residues are not selected).
893
+ """
894
+ B, L, vocab_size = logits.shape
895
+
896
+ if pair_bias is not None:
897
+ # pair_bias_total [B, L, self.vocab_size] - the total pair bias to
898
+ # add to the sequence logits, computed for every residue by
899
+ # indexing the pair bias with the sequence (S_for_pair_bias) and
900
+ # summing over the second sequence dimension (L').
901
+ pair_bias_total = torch.gather(
902
+ pair_bias,
903
+ -1,
904
+ S_for_pair_bias[:, None, None, :, None].expand(
905
+ -1, -1, self.vocab_size, -1, -1
906
+ ),
907
+ ).sum(dim=(-2, -1))
908
+ else:
909
+ pair_bias_total = None
910
+
911
+ # modified_logits [B, L, self.vocab_size] - the logits for the
912
+ # sequence, modified by temperature, bias, and total pair bias.
913
+ modified_logits = (
914
+ logits
915
+ + (0.0 if bias is None else bias)
916
+ + (0.0 if pair_bias_total is None else pair_bias_total)
917
+ )
918
+ modified_logits = modified_logits / (
919
+ 1.0 if temperature is None else temperature.unsqueeze(-1)
920
+ )
921
+
922
+ # log_probs [B, L, self.vocab_size] - the log probabilities for the
923
+ # sequence.
924
+ log_probs = torch.nn.functional.log_softmax(modified_logits, dim=-1)
925
+
926
+ # probs [B, L, self.vocab_size] - the probabilities for the sequence.
927
+ probs = torch.nn.functional.softmax(modified_logits, dim=-1)
928
+
929
+ # probs_sample [B, L, self.vocab_size] - the probabilities for the
930
+ # sequence, with the unknown residues zeroed out and the other residues
931
+ # normalized.
932
+ probs_sample = probs.clone()
933
+ probs_sample[:, :, self.unknown_token_indices] = 0.0
934
+ probs_sample = probs_sample / torch.sum(probs_sample, dim=-1, keepdim=True)
935
+
936
+ # probs_sample_flat [B * L, self.vocab_size] - the flattened
937
+ # probabilities for the sequence.
938
+ probs_sample_flat = probs_sample.view(B * L, vocab_size)
939
+
940
+ # S_sampled [B, L] - the predicted sequence, sampled from the
941
+ # probabilities (unknown residues are not sampled).
942
+ S_sampled = torch.multinomial(probs_sample_flat, 1).squeeze(-1).view(B, L)
943
+
944
+ # S_argmax [B, L] - the predicted sequence, obtained by taking the
945
+ # argmax of the probabilities (unknown residues are not selected).
946
+ S_argmax = torch.argmax(probs_sample, dim=-1)
947
+
948
+ sample_dict = {
949
+ "log_probs": log_probs,
950
+ "probs": probs,
951
+ "probs_sample": probs_sample,
952
+ "S_sampled": S_sampled,
953
+ "S_argmax": S_argmax,
954
+ }
955
+
956
+ return sample_dict
957
+
958
+ def decode_teacher_forcing(
959
+ self, input_features, graph_features, encoder_features, decoder_features
960
+ ):
961
+ """
962
+ Given the input features, graph features, encoder features, and
963
+ decoder features, perform the decoding with teacher forcing.
964
+
965
+ Although h_S is computed from the ground truth sequence S, the causal
966
+ mask will ensure that the decoder only attends to the sequence of
967
+ previously decoded residues. Using the ground truth for all previous
968
+ residues is called teacher forcing, and it is a common technique in
969
+ language modeling tasks.
970
+
971
+ Args:
972
+ input_features (dict): Input features containing the residue mask
973
+ and sequence.
974
+ - S (torch.Tensor): [B, L] - sequence of residues.
975
+ - residue_mask (torch.Tensor): [B, L] - mask for the residues.
976
+ - bias (torch.Tensor, optional): [B, L, self.vocab_size] - the
977
+ per-residue bias to use for sampling. If None, the code will
978
+ implicitly use a bias of 0.0 for all residues.
979
+ - pair_bias (torch.Tensor, optional): [B, L, self.vocab_size
980
+ , L, self.vocab_size] - the per-residue pair bias to use
981
+ for sampling. If None, the code will implicitly use a pair
982
+ bias of 0.0 for all residue pairs.
983
+ - temperature (torch.Tensor, optional): [B, L] - the per-residue
984
+ temperature to use for sampling. If None, the code will
985
+ implicitly use a temperature of 1.0.
986
+ - initialize_sequence_embedding_with_ground_truth (bool):
987
+ If True, initialize the sequence embedding with the ground
988
+ truth sequence S. Else, initialize the sequence
989
+ embedding with zeros.
990
+ - symmetry_equivalence_group (torch.Tensor, optional): [B, L] -
991
+ an integer for every residue, indicating the symmetry group
992
+ that it belongs to. If None, the residues are not grouped by
993
+ symmetry. For example, if residue i and j should be decoded
994
+ symmetrically, then symmetry_equivalence_group[i] ==
995
+ symmetry_equivalence_group[j]. Must be torch.int64 to allow
996
+ for use as an index. These values should range from 0 to
997
+ the maximum number of symmetry groups - 1 for each example.
998
+ -symmetry_weight (torch.Tensor, optional): [B, L] - the weights
999
+ for each residue, to be used when aggregating across its
1000
+ respective symmetry group. If None, the weights are
1001
+ assumed to be 1.0 for all residues.
1002
+ graph_features (dict): Graph features containing the featurized
1003
+ node and edge inputs.
1004
+ - E_idx (torch.Tensor): [B, L, K] - edge indices.
1005
+ encoder_features (dict): Encoder features containing the encoded
1006
+ protein node and protein edge features.
1007
+ - h_V (torch.Tensor): [B, L, H] - the protein node
1008
+ features after encoding message passing.
1009
+ - h_E (torch.Tensor): [B, L, K, H] - the
1010
+ protein edge features after encoding message passing.
1011
+ decoder_features (dict): Initial decoder features containing the
1012
+ causal mask for the decoder.
1013
+ - causal_mask (torch.Tensor): [B, L, K, 1] - the
1014
+ causal mask for the decoder.
1015
+ - anti_causal_mask (torch.Tensor): [B, L, K, 1] - the
1016
+ anti-causal mask for the decoder.
1017
+ Side Effects:
1018
+ decoder_features["h_V"] (torch.Tensor): [B, L, H] - the updated
1019
+ node features for the decoder.
1020
+ decoder_features["logits"] (torch.Tensor): [B, L, self.vocab_size] -
1021
+ the sequence logits for the decoder.
1022
+ decoder_features["log_probs"] (torch.Tensor): [B, L,
1023
+ self.vocab_size] - the log probabilities for the sequence.
1024
+ decoder_features["probs"] (torch.Tensor): [B, L, self.vocab_size] -
1025
+ the probabilities for the sequence.
1026
+ decoder_features["probs_sample"] (torch.Tensor): [B, L,
1027
+ self.vocab_size] - the probabilities for the sequence, with the
1028
+ unknown residues zeroed out and the other residues normalized.
1029
+ decoder_features["S_sampled"] (torch.Tensor): [B, L] - the
1030
+ predicted sequence, sampled from the probabilities (unknown
1031
+ residues are not sampled).
1032
+ decoder_features["S_argmax"] (torch.Tensor): [B, L] - the predicted
1033
+ sequence, obtained by taking the argmax of the probabilities
1034
+ (unknown residues are not selected).
1035
+ """
1036
+ # Check that the input features contains the necessary keys.
1037
+ if "residue_mask" not in input_features:
1038
+ raise ValueError("Input features must contain 'residue_mask' key.")
1039
+ if "S" not in input_features:
1040
+ raise ValueError("Input features must contain 'S' key.")
1041
+ if "bias" not in input_features:
1042
+ raise ValueError("Input features must contain 'bias' key.")
1043
+ if "pair_bias" not in input_features:
1044
+ raise ValueError("Input features must contain 'pair_bias' key.")
1045
+ if "temperature" not in input_features:
1046
+ raise ValueError("Input features must contain 'temperature' key.")
1047
+ if "symmetry_equivalence_group" not in input_features:
1048
+ raise ValueError(
1049
+ "Input features must contain 'symmetry_equivalence_group' key."
1050
+ )
1051
+ if "symmetry_weight" not in input_features:
1052
+ raise ValueError("Input features must contain 'symmetry_weight' key.")
1053
+
1054
+ # Check that the graph features contains the necessary keys.
1055
+ if "E_idx" not in graph_features:
1056
+ raise ValueError("Graph features must contain 'E_idx' key.")
1057
+
1058
+ # Check that the encoder features contains the necessary keys.
1059
+ if "h_V" not in encoder_features:
1060
+ raise ValueError("Encoder features must contain 'h_V' key.")
1061
+ if "h_E" not in encoder_features:
1062
+ raise ValueError("Encoder features must contain 'h_E' key.")
1063
+
1064
+ # Check that the decoder features contains the necessary keys.
1065
+ if "causal_mask" not in decoder_features:
1066
+ raise ValueError("Decoder features must contain 'causal_mask' key.")
1067
+
1068
+ # Do the setup for the decoder.
1069
+ # h_EXV_encoder_anti_causal [B, L, K, 3H] - the encoder embeddings,
1070
+ # masked with the anti-causal mask.
1071
+ # mask_E [B, L, K] - the mask for the edges, gathered at the
1072
+ # neighbor indices.
1073
+ # h_S [B, L, H] - the sequence embeddings for the decoder.
1074
+ h_EXV_encoder_anti_causal, mask_E, h_S = self.decode_setup(
1075
+ input_features, graph_features, encoder_features, decoder_features
1076
+ )
1077
+
1078
+ # h_ES [B, L, K, 2H] - h_E_ij cat h_S_j; the edge features concatenated
1079
+ # with the sequence embeddings for the destination nodes.
1080
+ h_ES = cat_neighbors_nodes(
1081
+ h_S, encoder_features["h_E"], graph_features["E_idx"]
1082
+ )
1083
+
1084
+ # Run the decoder layers.
1085
+ h_V_decoder = encoder_features["h_V"]
1086
+ for layer in self.decoder_layers:
1087
+ # h_ESV_decoder [B, L, K, 3H] - h_E_ij cat h_S_j cat h_V_decoder_j;
1088
+ # for the decoder embeddings, the edge features are concatenated
1089
+ # with the destination node sequence embeddings and node features.
1090
+ h_ESV_decoder = cat_neighbors_nodes(
1091
+ h_V_decoder, h_ES, graph_features["E_idx"]
1092
+ )
1093
+
1094
+ # h_ESV [B, L, K, 3H] - the encoder and decoder embeddings,
1095
+ # combined according to the causal and anti-causal masks.
1096
+ # Combine the encoder embeddings with the decoder embeddings,
1097
+ # using the causal and anti-causal masks. When decoding the residue
1098
+ # at position i:
1099
+ # - for residue j, decoded before i:
1100
+ # - h_ESV_ij = h_E_ij cat h_S_j cat h_V_decoder_j
1101
+ # - encoder edge embedding, decoder destination node
1102
+ # sequence embedding, and decoder destination node
1103
+ # embedding.
1104
+ # - for residue j, decoded after i (including i):
1105
+ # - h_ESV_ij = h_E_ij cat (0 vector) cat h_V_j
1106
+ # - encoder edge embedding, zero vector (no sequence
1107
+ # information), and encoder destination node embedding.
1108
+ # This prevents leakage of sequence information.
1109
+ # - NOTE: h_V_j comes from the encoder.
1110
+ # - NOTE: h_E is not updated in the decoder, h_E_ij comes from
1111
+ # the encoder.
1112
+ # - NOTE: within the decoder layer itself, h_V_decoder_i will
1113
+ # be concatenated to h_ESV_ij.
1114
+ h_ESV = (
1115
+ decoder_features["causal_mask"] * h_ESV_decoder
1116
+ + h_EXV_encoder_anti_causal
1117
+ )
1118
+
1119
+ # h_V_decoder [B, L, H] - the updated node features for the
1120
+ # decoder.
1121
+ h_V_decoder = torch.utils.checkpoint.checkpoint(
1122
+ layer,
1123
+ h_V_decoder,
1124
+ h_ESV,
1125
+ mask_V=input_features["residue_mask"],
1126
+ mask_E=mask_E,
1127
+ use_reentrant=False,
1128
+ )
1129
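As a reference for the call above, a minimal sketch (toy linear layer, assuming a PyTorch version that supports use_reentrant=False) of torch.utils.checkpoint.checkpoint: activations inside the wrapped layer are recomputed during the backward pass instead of being stored, trading compute for memory.

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

# The layer's intermediate activations are not kept; they are recomputed
# when backward() reaches this node.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()

print(x.grad.shape)  # torch.Size([4, 8])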
+
1130
+ # logits [B, L, self.vocab_size] - project the final node features to
1131
+ # get the sequence logits.
1132
+ logits = self.W_out(h_V_decoder)
1133
+
1134
+ # Handle symmetry equivalence groups if they are provided, performing
1135
+ # a (possibly weighted) sum of the logits across residues that
1136
+ # belong to the same symmetry group.
1137
+ if input_features["symmetry_equivalence_group"] is not None:
1138
+ # Assume that all symmetry groups are non-negative.
1139
+ assert input_features["symmetry_equivalence_group"].min() >= 0
1140
+
1141
+ B, L, _ = logits.shape
1142
+
1143
+ # The maximum number of symmetry groups in the batch.
1144
+ G = (input_features["symmetry_equivalence_group"].max().item()) + 1
1145
+
1146
+ # symmetry_equivalence_group_one_hot [B, L, G] - one-hot encoding
1147
+ # of the symmetry equivalence group for each residue.
1148
+ symmetry_equivalence_group_one_hot = torch.nn.functional.one_hot(
1149
+ input_features["symmetry_equivalence_group"], num_classes=G
1150
+ ).float()
1151
+
1152
+ # scaled_symmetry_equivalence_group_one_hot [B, L, G] - the one-hot
1153
+ # encoding of the symmetry equivalence group, scaled by the
1154
+ # symmetry weights, if they are provided. If not provided, the
1155
+ # symmetry weights are implicitly assumed to be 1.0 for all
1156
+ # residues.
1157
+ scaled_symmetry_equivalence_group_one_hot = (
1158
+ symmetry_equivalence_group_one_hot
1159
+ * (
1160
+ 1.0
1161
+ if input_features["symmetry_weight"] is None
1162
+ else input_features["symmetry_weight"].unsqueeze(-1)
1163
+ )
1164
+ )
1165
+
1166
+ # weighted_sum_logits [B, G, self.vocab_size] - the logits for the
1167
+ # sequence, summed across the symmetry groups, weighted by the
1168
+ # symmetry weights.
1169
+ weighted_sum_logits = torch.einsum(
1170
+ "blg,blv->bgv", scaled_symmetry_equivalence_group_one_hot, logits
1171
+ )
1172
+
1173
+ # logits [B, L, self.vocab_size] - overwrite the original logits
1174
+ # with the weighted and summed logits for the residues that belong
1175
+ # to the same symmetry group.
1176
+ logits = torch.einsum(
1177
+ "blg,bgv->blv", symmetry_equivalence_group_one_hot, weighted_sum_logits
1178
+ )
1179
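A minimal standalone sketch (hypothetical shapes and values) of the two einsum steps above: a weighted one-hot group matrix sums the logits within each symmetry group, and the second einsum broadcasts the group sums back so that tied residues end up with identical logits.

import torch

B, L, V = 1, 4, 3                      # hypothetical batch, length, vocab sizes
logits = torch.randn(B, L, V)
group = torch.tensor([[0, 1, 0, 1]])   # residues 0/2 and 1/3 are tied together
weight = torch.ones(B, L)              # uniform symmetry weights

G = int(group.max().item()) + 1
one_hot = torch.nn.functional.one_hot(group, num_classes=G).float()  # [B, L, G]
group_sum = torch.einsum("blg,blv->bgv", one_hot * weight.unsqueeze(-1), logits)
tied = torch.einsum("blg,bgv->blv", one_hot, group_sum)

assert torch.allclose(tied[0, 0], tied[0, 2])  # residues in the same group match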
+
1180
+ # Compute the log probabilities, probabilities, sampled probabilities,
1181
+ # predicted sequence, and argmax sequence.
1182
+ sample_dict = self.logits_to_sample(
1183
+ logits,
1184
+ input_features["bias"],
1185
+ input_features["pair_bias"],
1186
+ input_features["S"],
1187
+ input_features["temperature"],
1188
+ )
1189
+
1190
+ # All outputs from logits_to_sample should be the same across
1191
+ # symmetry equivalence groups, except for S_sampled (due to the
1192
+ # sampling being stochastic). Correct for this by overwriting S_sampled
1193
+ # with the first sampled residue in each group.
1194
+ if input_features["symmetry_equivalence_group"] is not None:
1195
+ S_sampled = sample_dict["S_sampled"]
1196
+
1197
+ B, L = S_sampled.shape
1198
+
1199
+ # Compute the maximum number of symmetry groups in the batch.
1200
+ G = (input_features["symmetry_equivalence_group"].max().item()) + 1
1201
+
1202
+ for b in range(B):
1203
+ for g in range(G):
1204
+ # group_mask [L] - mask for the residues that belong to
1205
+ # the symmetry equivalence group g for the batch example b.
1206
+ group_mask = input_features["symmetry_equivalence_group"][b] == g
1207
+
1208
+ # If there are residues in the group, set every S_sampled
1209
+ # in the group to the first S_sampled in the group.
1210
+ if group_mask.any():
1211
+ first = torch.where(group_mask)[0][0]
1212
+ S_sampled[b, group_mask] = S_sampled[b, first]
1213
+
1214
+ sample_dict["S_sampled"] = S_sampled
1215
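A small illustrative sketch (hypothetical tensors) of the correction above: within each symmetry group, every sampled residue is overwritten with the group's first sampled residue so that stochastic sampling cannot break the symmetry.

import torch

S_sampled = torch.tensor([[3, 7, 5, 2]])   # [B, L] sampled residue indices
group = torch.tensor([[0, 1, 0, 1]])       # [B, L] symmetry groups

for b in range(S_sampled.shape[0]):
    for g in range(int(group[b].max().item()) + 1):
        mask = group[b] == g
        if mask.any():
            first = torch.where(mask)[0][0]
            S_sampled[b, mask] = S_sampled[b, first]

print(S_sampled)  # tensor([[3, 7, 3, 7]])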
+
1216
+ # Update the decoder features with the final node features, the computed
1217
+ # logits, log probabilities, probabilities, sampled probabilities,
1218
+ # predicted sequence, and argmax sequence.
1219
+ decoder_features["h_V"] = h_V_decoder
1220
+ decoder_features["logits"] = logits
1221
+ decoder_features["log_probs"] = sample_dict["log_probs"]
1222
+ decoder_features["probs"] = sample_dict["probs"]
1223
+ decoder_features["probs_sample"] = sample_dict["probs_sample"]
1224
+ decoder_features["S_sampled"] = sample_dict["S_sampled"]
1225
+ decoder_features["S_argmax"] = sample_dict["S_argmax"]
1226
+
1227
+ def decode_auto_regressive(
1228
+ self, input_features, graph_features, encoder_features, decoder_features
1229
+ ):
1230
+ """
1231
+ Given the input features, graph features, encoder features, and
1232
+ decoder features, perform the autoregressive decoding.
1233
+
1234
+ Args:
1235
+ input_features (dict): Input features containing the residue mask
1236
+ and sequence.
1237
+ - S (torch.Tensor): [B, L] - sequence of residues.
1238
+ - residue_mask (torch.Tensor): [B, L] - mask for the residues.
1239
+ - bias (torch.Tensor, optional): [B, L, self.vocab_size]
1240
+ - the per-residue bias to use for sampling. If None, the
1241
+ code will implicitly use a bias of 0.0 for all residues.
1242
+ - pair_bias (torch.Tensor, optional): [B, L, self.vocab_size,
1243
+ L, self.vocab_size] - the per-residue pair bias to use
1244
+ for sampling. If None, the code will implicitly use a pair
1245
+ bias of 0.0 for all residue pairs.
1246
+ - temperature (torch.Tensor, optional): [B, L] - the per-residue
1247
+ temperature to use for sampling. If None, the code will
1248
+ implicitly use a temperature of 1.0.
1249
+ - initialize_sequence_embedding_with_ground_truth (bool):
1250
+ If True, initialize the sequence embedding with the ground
1251
+ truth sequence S. Else, initialize the sequence
1252
+ embedding with zeros. Also, if True, initialize S_sampled
1253
+ with the ground truth sequence S, which should only affect
1254
+ the application of pair bias (which relies on the predicted
1255
+ sequence). This is useful if we want to perform
1256
+ auto-regressive redesign.
1257
+ - symmetry_equivalence_group (torch.Tensor, optional): [B, L] -
1258
+ an integer for every residue, indicating the symmetry group
1259
+ that it belongs to. If None, the residues are not grouped by
1260
+ symmetry. For example, if residue i and j should be decoded
1261
+ symmetrically, then symmetry_equivalence_group[i] ==
1262
+ symmetry_equivalence_group[j]. Must be torch.int64 to allow
1263
+ for use as an index. These values should range from 0 to
1264
+ the maximum number of symmetry groups - 1 for each example.
1265
+ NOTE: bias, pair_bias, and temperature should be the same
1266
+ for all residues in the symmetry equivalence group;
1267
+ otherwise, the intended behavior may not be achieved. The
1268
+ residues within a symmetry group should all have the same
1269
+ validity and design/fixed status.
1270
+ - symmetry_weight (torch.Tensor, optional): [B, L] - the weights
1271
+ for each residue, to be used when aggregating across its
1272
+ respective symmetry group. If None, the weights are
1273
+ assumed to be 1.0 for all residues.
1274
+ graph_features (dict): Graph features containing the featurized
1275
+ node and edge inputs.
1276
+ - E_idx (torch.Tensor): [B, L, K] - edge indices.
1277
+ encoder_features (dict): Encoder features containing the encoded
1278
+ protein node and protein edge features.
1279
+ - h_V (torch.Tensor): [B, L, H] - the protein node features
1280
+ after encoding message passing.
1281
+ - h_E (torch.Tensor): [B, L, K, H] - the protein edge features
1282
+ after encoding message passing.
1283
+ decoder_features (dict): Initial decoder features containing the
1284
+ causal mask for the decoder.
1285
+ - decoding_order (torch.Tensor): [B, L] - the order in which
1286
+ the residues should be decoded.
1287
+ - decode_last_mask (torch.Tensor): [B, L] - the mask for which
1288
+ residues should be decoded last, where False is a residue
1289
+ that should be decoded first (invalid or fixed), and True
1290
+ is a residue that should not be decoded first (designed
1291
+ residues).
1292
+ - causal_mask (torch.Tensor): [B, L, K, 1] - the causal mask
1293
+ for the decoder.
1294
+ - anti_causal_mask (torch.Tensor): [B, L, K, 1] - the anti-
1295
+ causal mask for the decoder.
1296
+ Side Effects:
1297
+ decoder_features["h_V"] (torch.Tensor): [B, L, H] - the updated
1298
+ node features for the decoder.
1299
+ decoder_features["logits"] (torch.Tensor): [B, L, self.vocab_size] -
1300
+ the sequence logits for the decoder.
1301
+ decoder_features["log_probs"] (torch.Tensor): [B, L,
1302
+ self.vocab_size] - the log probabilities for the sequence.
1303
+ decoder_features["probs"] (torch.Tensor): [B, L, self.vocab_size] -
1304
+ the probabilities for the sequence.
1305
+ decoder_features["probs_sample"] (torch.Tensor): [B, L,
1306
+ self.vocab_size] - the probabilities for the sequence, with the
1307
+ unknown residues zeroed out and the other residues normalized.
1308
+ decoder_features["S_sampled"] (torch.Tensor): [B, L] - the
1309
+ predicted sequence, sampled from the probabilities (unknown
1310
+ residues are not sampled).
1311
+ decoder_features["S_argmax"] (torch.Tensor): [B, L] - the predicted
1312
+ sequence, obtained by taking the argmax of the probabilities
1313
+ (unknown residues are not selected).
1314
+ """
1315
+ # Check that the input features contains the necessary keys.
1316
+ if "residue_mask" not in input_features:
1317
+ raise ValueError("Input features must contain 'residue_mask' key.")
1318
+ if "S" not in input_features:
1319
+ raise ValueError("Input features must contain 'S' key.")
1320
+ if "temperature" not in input_features:
1321
+ raise ValueError("Input features must contain 'temperature' key.")
1322
+ if "bias" not in input_features:
1323
+ raise ValueError("Input features must contain 'bias' key.")
1324
+ if "pair_bias" not in input_features:
1325
+ raise ValueError("Input features must contain 'pair_bias' key.")
1326
+ if "symmetry_equivalence_group" not in input_features:
1327
+ raise ValueError(
1328
+ "Input features must contain 'symmetry_equivalence_group' key."
1329
+ )
1330
+ if "symmetry_weight" not in input_features:
1331
+ raise ValueError("Input features must contain 'symmetry_weight' key.")
1332
+
1333
+ # Check that the graph features contains the necessary keys.
1334
+ if "E_idx" not in graph_features:
1335
+ raise ValueError("Graph features must contain 'E_idx' key.")
1336
+
1337
+ # Check that the encoder features contains the necessary keys.
1338
+ if "h_V" not in encoder_features:
1339
+ raise ValueError("Encoder features must contain 'h_V' key.")
1340
+ if "h_E" not in encoder_features:
1341
+ raise ValueError("Encoder features must contain 'h_E' key.")
1342
+
1343
+ # Check that the decoder features contains the necessary keys.
1344
+ if "decoding_order" not in decoder_features:
1345
+ raise ValueError("Decoder features must contain 'decoding_order' key.")
1346
+ if "decode_last_mask" not in decoder_features:
1347
+ raise ValueError("Decoder features must contain 'decode_last_mask' key.")
1348
+ if "causal_mask" not in decoder_features:
1349
+ raise ValueError("Decoder features must contain 'causal_mask' key.")
1350
+
1351
+ B, L = input_features["residue_mask"].shape
1352
+
1353
+ # Do the setup for the decoder.
1354
+ # h_EXV_encoder_anti_causal [B, L, K, 3H] - the encoder embeddings,
1355
+ # masked with the anti-causal mask.
1356
+ # mask_E [B, L, K] - the mask for the edges, gathered at the
1357
+ # neighbor indices.
1358
+ # h_S [B, L, H] - the sequence embeddings for the decoder.
1359
+ h_EXV_encoder_anti_causal, mask_E, h_S = self.decode_setup(
1360
+ input_features, graph_features, encoder_features, decoder_features
1361
+ )
1362
+
1363
+ # We can precompute the output dtype depending on automatic mixed
1364
+ # precision settings. This works because the W_out layer is a linear
1365
+ # layer, which has predictable dtype behavior with AMP.
1366
+ device = input_features["residue_mask"].device
1367
+ if device.type in ("cuda", "cpu") and torch.is_autocast_enabled(
1368
+ device_type=device.type
1369
+ ):
1370
+ output_dtype = torch.get_autocast_dtype(device_type=device.type)
1371
+ else:
1372
+ output_dtype = torch.float32
1373
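A minimal sketch (CPU autocast with bfloat16, assuming a PyTorch version that supports torch.autocast on the CPU) of why the output dtype of a linear layer is predictable under AMP: inside an autocast region, matmul-based ops such as nn.Linear run in the autocast dtype, so buffers that will later hold W_out outputs can be preallocated in that dtype.

import torch

linear = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = linear(x)

print(y.dtype)          # torch.bfloat16 inside the autocast region
print(linear(x).dtype)  # torch.float32 outside the region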
+
1374
+ # logits [B, L, self.vocab_size] - the logits for every residue
1375
+ # position and residue type.
1376
+ logits = torch.zeros(
1377
+ (B, L, self.vocab_size),
1378
+ device=input_features["residue_mask"].device,
1379
+ dtype=output_dtype,
1380
+ )
1381
+
1382
+ # logits_i [B, 1, self.vocab_size] - the logits for the
1383
+ # residue at the current decoding index, computed from the
1384
+ # decoded node features. Declared here so it can accumulate logits when
1385
+ # performing symmetry-aware decoding.
1386
+ logits_i = torch.zeros(
1387
+ (B, 1, self.vocab_size),
1388
+ device=input_features["residue_mask"].device,
1389
+ dtype=output_dtype,
1390
+ )
1391
+
1392
+ # log_probs [B, L, self.vocab_size] - the log probabilities for every
1393
+ # residue position and residue type.
1394
+ log_probs = torch.zeros(
1395
+ (B, L, self.vocab_size),
1396
+ device=input_features["residue_mask"].device,
1397
+ dtype=torch.float32,
1398
+ )
1399
+
1400
+ # probs [B, L, self.vocab_size] - the probabilities for every residue
1401
+ # position and residue type.
1402
+ probs = torch.zeros(
1403
+ (B, L, self.vocab_size),
1404
+ device=input_features["residue_mask"].device,
1405
+ dtype=torch.float32,
1406
+ )
1407
+
1408
+ # probs_sample [B, L, self.vocab_size] - the probabilities for every
1409
+ # residue position and residue type, with the unknown residues zeroed
1410
+ # out and the other residues normalized.
1411
+ probs_sample = torch.zeros(
1412
+ (B, L, self.vocab_size),
1413
+ device=input_features["residue_mask"].device,
1414
+ dtype=torch.float32,
1415
+ )
1416
+
1417
+ # S_sampled [B, L] - the predicted/sampled sequence of residues. If
1418
+ # initialize_sequence_embedding_with_ground_truth is True, it is
1419
+ # initialized with the ground truth sequence S; this should only
1420
+ # affect the application of pair bias (which relies on the predicted
1421
+ # sequence), and is useful if we want to perform auto-regressive
1422
+ # redesign. Otherwise, this should have no effect, as we overwrite
1423
+ # S_sampled with the sampled sequence at every decoding step.
1424
+ if input_features["initialize_sequence_embedding_with_ground_truth"]:
1425
+ S_sampled = input_features["S"].clone()
1426
+ else:
1427
+ S_sampled = torch.full(
1428
+ (B, L),
1429
+ fill_value=self.token_to_idx[UNKNOWN_AA],
1430
+ device=input_features["S"].device,
1431
+ dtype=input_features["S"].dtype,
1432
+ )
1433
+
1434
+ # S_argmax [B, L] - the argmax sequence of residues, initialized with
1435
+ # the unknown residue type.
1436
+ S_argmax = torch.full(
1437
+ (B, L),
1438
+ fill_value=self.token_to_idx[UNKNOWN_AA],
1439
+ device=input_features["S"].device,
1440
+ dtype=input_features["S"].dtype,
1441
+ )
1442
+
1443
+ # h_V_decoder_stack - list containing the hidden node embeddings from
1444
+ # each decoder layer; populated iteratively during the decoding.
1445
+ # h_V_decoder_stack[i] [B, L, H] - the hidden node embeddings,
1446
+ # the 0th entry is the initial node embeddings from the encoder, and
1447
+ # the i-th entry is the hidden node embeddings after the i-th decoder
1448
+ # layer.
1449
+ # NOTE: it is necessary to keep the embeddings from all decoder layers,
1450
+ # since later decoding positions rely on the intermediate decoder layer
1451
+ # embeddings of previously decoded residues.
1452
+ h_V_decoder_stack = [encoder_features["h_V"]] + [
1453
+ torch.zeros_like(
1454
+ encoder_features["h_V"], device=input_features["residue_mask"].device
1455
+ )
1456
+ for _ in range(len(self.decoder_layers))
1457
+ ]
1458
+
1459
+ # batch_idx [B, 1] - the batch indices for the decoder.
1460
+ batch_idx = torch.arange(
1461
+ B, device=input_features["residue_mask"].device
1462
+ ).unsqueeze(-1)
1463
+
1464
+ # Iteratively decode, updating the hidden sequence embeddings.
1465
+ for decoding_idx in range(L):
1466
+ # i [B, 1] - the indices of the residues to decode in the
1467
+ # current iteration, based on the decoding order.
1468
+ i = decoder_features["decoding_order"][:, decoding_idx].unsqueeze(-1)
1469
+
1470
+ # decode_last_mask_i [B, 1] - the mask for residues that should be
1471
+ # decoded last, where False is a residue that should be decoded
1472
+ # first (invalid or fixed), and True is a residue that should not be
1473
+ # decoded first (designed residues); at the current decoding
1474
+ # index.
1475
+ decode_last_mask_i = decoder_features["decode_last_mask"][batch_idx, i]
1476
+
1477
+ # residue_mask_i [B, 1] - the mask for the residue at the current
1478
+ # decoding index.
1479
+ residue_mask_i = input_features["residue_mask"][batch_idx, i]
1480
+
1481
+ # mask_E_i [B, 1, K] - the mask for the edges at the current
1482
+ # decoding index, gathered at the neighbor indices.
1483
+ mask_E_i = mask_E[batch_idx, i]
1484
+
1485
+ # S_i [B, 1] - the ground truth sequence for the residue at the
1486
+ # current decoding index (for designed positions, undefined).
1487
+ S_i = input_features["S"][batch_idx, i]
1488
+
1489
+ # Setup the temperature, bias, and pair bias for the current
1490
+ # decoding index.
1491
+ if input_features["temperature"] is not None:
1492
+ # temperature_i [B, 1] - the temperature for the residue at the
1493
+ # current decoding index.
1494
+ temperature_i = input_features["temperature"][batch_idx, i]
1495
+ else:
1496
+ temperature_i = None
1497
+
1498
+ if input_features["bias"] is not None:
1499
+ # bias_i [B, 1, self.vocab_size] - the bias for the residue at
1500
+ # the current decoding index.
1501
+ bias_i = input_features["bias"][batch_idx, i]
1502
+ else:
1503
+ bias_i = None
1504
+
1505
+ if input_features["pair_bias"] is not None:
1506
+ # pair_bias_i [B, 1, self.vocab_size, L, self.vocab_size] - the
1507
+ # pair bias for the residue at the current decoding index.
1508
+ pair_bias_i = input_features["pair_bias"][batch_idx, i]
1509
+ else:
1510
+ pair_bias_i = None
1511
+
1512
+ if input_features["symmetry_equivalence_group"] is not None:
1513
+ # symmetry_equivalence_group_i [B, 1] - the symmetry
1514
+ # equivalence group for the residue at the current decoding
1515
+ # index.
1516
+ symmetry_equivalence_group_i = input_features[
1517
+ "symmetry_equivalence_group"
1518
+ ][batch_idx, i]
1519
+ else:
1520
+ symmetry_equivalence_group_i = None
1521
+
1522
+ if input_features["symmetry_weight"] is not None:
1523
+ # symmetry_weight_i [B, 1] - the symmetry weights for the
1524
+ # residue at the current decoding index.
1525
+ symmetry_weight_i = input_features["symmetry_weight"][batch_idx, i]
1526
+ else:
1527
+ symmetry_weight_i = None
1528
+
1529
+ # Gather the graph, encoder, and sequence features for the
1530
+ # current decoding index.
1531
+ # E_idx_i [B, 1, K] - the edge indices for the residue at the
1532
+ # current decoding index.
1533
+ E_idx_i = graph_features["E_idx"][batch_idx, i]
1534
+
1535
+ # h_E_i [B, 1, K, H] - the post-encoder edge features for the
1536
+ # residue at the current decoding index.
1537
+ h_E_i = encoder_features["h_E"][batch_idx, i]
1538
+
1539
+ # h_ES_i [B, 1, K, 2H] - the edge features concatenated with the
1540
+ # sequence embeddings for the destination nodes, for the residue at
1541
+ # the current decoding index.
1542
+ h_ES_i = cat_neighbors_nodes(h_S, h_E_i, E_idx_i)
1543
+
1544
+ # h_EXV_encoder_anti_causal_i [B, 1, K, 3H] - the encoder
1545
+ # embeddings, masked with the anti-causal mask, for the residue at
1546
+ # the current decoding index.
1547
+ h_EXV_encoder_anti_causal_i = h_EXV_encoder_anti_causal[batch_idx, i]
1548
+
1549
+ # causal_mask_i [B, 1, K, 1] - the causal mask for the residue at
1550
+ # the current decoding index.
1551
+ causal_mask_i = decoder_features["causal_mask"][batch_idx, i]
1552
+
1553
+ # Apply the decoder layers, updating the hidden node embeddings
1554
+ # for the current decoding index.
1555
+ for layer_idx, layer in enumerate(self.decoder_layers):
1556
+ # h_ESV_decoder_i [B, 1, K, 3H] - h_E_ij cat h_S_j cat
1557
+ # h_V_decoder_j; for the decoder embeddings, the edge features
1558
+ # are concatenated with the destination node sequence embeddings
1559
+ # and node features, for the residue at the current decoding
1560
+ # index.
1561
+ h_ESV_decoder_i = cat_neighbors_nodes(
1562
+ h_V_decoder_stack[layer_idx], h_ES_i, E_idx_i
1563
+ )
1564
+
1565
+ # h_ESV_i [B, 1, K, 3H] - the encoder and decoder embeddings,
1566
+ # combined according to the causal and anti-causal masks.
1567
+ # Combine the encoder embeddings with the decoder embeddings,
1568
+ # using the causal and anti-causal masks. When decoding the
1569
+ # residue at position i:
1570
+ # - for residue j, decoded before i:
1571
+ # - h_ESV_ij = h_E_ij cat h_S_j cat h_V_decoder_j
1572
+ # - encoder edge embedding, decoder destination node
1573
+ # sequence embedding, and decoder destination node
1574
+ # embedding.
1575
+ # - for residue j, decoded after i (including i):
1576
+ # - h_ESV_ij = h_E_ij cat (0 vector) cat h_V_j
1577
+ # - encoder edge embedding, zero vector (no sequence
1578
+ # information), and encoder destination node
1579
+ # embedding. This prevents leakage of sequence
1580
+ # information.
1581
+ # - NOTE: h_V_j comes from the encoder.
1582
+ # - NOTE: h_E is not updated in the decoder, h_E_ij comes
1583
+ # from the encoder.
1584
+ # - NOTE: within the decoder layer itself, h_V_decoder_i will
1585
+ # be concatenated to h_ESV_ij.
1586
+ h_ESV_i = causal_mask_i * h_ESV_decoder_i + h_EXV_encoder_anti_causal_i
1587
+
1588
+ # h_V_decoder_i [B, 1, H] - the updated node features for the
1589
+ # decoder, after applying the layer at the current decoding
1590
+ # index.
1591
+ h_V_decoder_i = torch.utils.checkpoint.checkpoint(
1592
+ layer,
1593
+ h_V_decoder_stack[layer_idx][batch_idx, i],
1594
+ h_ESV_i,
1595
+ mask_V=residue_mask_i,
1596
+ mask_E=mask_E_i,
1597
+ use_reentrant=False,
1598
+ )
1599
+
1600
+ # h_V_decoder_stack[layer_idx + 1][batch_idx, i] [B, 1, H] -
1601
+ # the updated node features for the decoder, after applying the
1602
+ # layer at the current decoding index.
1603
+ if not torch.is_grad_enabled():
1604
+ h_V_decoder_stack[layer_idx + 1][batch_idx, i] = h_V_decoder_i
1605
+ else:
1606
+ # For gradient tracking, we can't use in-place operations.
1607
+ h_V_decoder_stack[layer_idx + 1] = h_V_decoder_stack[
1608
+ layer_idx + 1
1609
+ ].scatter(
1610
+ 1,
1611
+ i.unsqueeze(-1).expand(-1, -1, self.hidden_dim),
1612
+ h_V_decoder_i,
1613
+ )
1614
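An illustrative sketch (hypothetical shapes) of the out-of-place update used in the gradient-tracking branch above: Tensor.scatter along dim=1 writes a [B, 1, H] slice into a [B, L, H] buffer and returns a new tensor, avoiding the in-place indexed assignment that the surrounding code avoids when gradients are being tracked.

import torch

B, L, H = 2, 5, 3
buffer = torch.zeros(B, L, H, requires_grad=True)
i = torch.tensor([[1], [3]])                   # [B, 1] target position per example
update = torch.randn(B, 1, H)

index = i.unsqueeze(-1).expand(-1, -1, H)      # [B, 1, H]
new_buffer = buffer.scatter(1, index, update)  # out-of-place copy with the slice written

new_buffer.sum().backward()                    # gradients flow through scatter
print(buffer.grad.shape)                       # torch.Size([2, 5, 3])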
+
1615
+ if input_features["symmetry_equivalence_group"] is None:
1616
+ # logits_i [B, 1, self.vocab_size] - the logits for the
1617
+ # residue at the current decoding index, computed from the
1618
+ # decoded node features.
1619
+ logits_i = self.W_out(h_V_decoder_stack[-1][batch_idx, i])
1620
+ else:
1621
+ # logits_i [B, 1, self.vocab_size] - the logits for the
1622
+ # residue at the current decoding index, computed from the
1623
+ # decoded node features, aggregated across symmetry groups,
1624
+ # weighted by the symmetry weights.
1625
+ logits_i += self.W_out(h_V_decoder_stack[-1][batch_idx, i]) * (
1626
+ 1.0
1627
+ if symmetry_weight_i is None
1628
+ else symmetry_weight_i.unsqueeze(-1)
1629
+ )
1630
+
1631
+ # Compute the log probabilities, probabilities, sampled
1632
+ # probabilities, predicted sequence, and argmax sequence for the
1633
+ # current decoding index.
1634
+ sample_dict = self.logits_to_sample(
1635
+ logits_i, bias_i, pair_bias_i, S_sampled, temperature_i
1636
+ )
1637
+
1638
+ log_probs_i = sample_dict["log_probs"]
1639
+ probs_i = sample_dict["probs"]
1640
+ probs_sample_i = sample_dict["probs_sample"]
1641
+ S_sampled_i = sample_dict["S_sampled"]
1642
+ S_argmax_i = sample_dict["S_argmax"]
1643
+
1644
+ if input_features["symmetry_equivalence_group"] is None:
1645
+ # Save the logits, probabilities, sampled probabilities, and log
1646
+ # probabilities for the current decoding
1647
+ # index. These are saved but not sampled for invalid/fixed
1648
+ # residues.
1649
+ logits[batch_idx, i] = logits_i
1650
+ probs[batch_idx, i] = probs_i
1651
+ probs_sample[batch_idx, i] = probs_sample_i
1652
+ if not torch.is_grad_enabled():
1653
+ log_probs[batch_idx, i] = log_probs_i
1654
+ else:
1655
+ # For gradient tracking, we can't use in-place operations.
1656
+ log_probs = log_probs.scatter(
1657
+ 1, i.unsqueeze(-1).expand(-1, -1, self.vocab_size), log_probs_i
1658
+ )
1659
+
1660
+ # Update the predicted sequence and argmax sequence for the
1661
+ # current decoding index.
1662
+ S_sampled[batch_idx, i] = (S_sampled_i * decode_last_mask_i) + (
1663
+ S_i * (~decode_last_mask_i)
1664
+ )
1665
+ S_argmax[batch_idx, i] = (S_argmax_i * decode_last_mask_i) + (
1666
+ S_i * (~decode_last_mask_i)
1667
+ )
1668
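A tiny sketch (hypothetical values) of the masked blend used just above: the ground-truth residue is kept where decode_last_mask is False (invalid or fixed positions) and the sampled residue is used where it is True (designed positions); the arithmetic form is equivalent to torch.where.

import torch

S_sampled_i = torch.tensor([[4]])              # sampled residue index
S_i = torch.tensor([[9]])                      # ground-truth residue index
decode_last_mask_i = torch.tensor([[False]])   # False: fixed/invalid, keep S_i

blended = (S_sampled_i * decode_last_mask_i) + (S_i * (~decode_last_mask_i))
print(blended)                                            # tensor([[9]])
print(torch.where(decode_last_mask_i, S_sampled_i, S_i))  # same result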
+
1669
+ # h_S_i [B, 1, self.hidden_dim] - the sequence embeddings of the
1670
+ # sampled/fixed residue at the current decoding index.
1671
+ h_S_i = self.W_s(S_sampled[batch_idx, i])
1672
+
1673
+ # Update the decoder sequence embeddings with the predicted
1674
+ # sequence for the current decoding index.
1675
+ if not torch.is_grad_enabled():
1676
+ h_S[batch_idx, i] = h_S_i
1677
+ else:
1678
+ # For gradient tracking, we can't use in-place operations.
1679
+ h_S = h_S.scatter(
1680
+ 1, i.unsqueeze(-1).expand(-1, -1, self.hidden_dim), h_S_i
1681
+ )
1682
+ else:
1683
+ # symm_group_end_mask [B, 1] - mask for the residues that are
1684
+ # at the end of a symmetry group, where True is a residue that
1685
+ # is at the end of a symmetry group, and False is a residue that
1686
+ # is not at the end of a symmetry group. When we are at the last
1687
+ # decoding index, we know that all residues are at the end of a
1688
+ # symmetry group.
1689
+ if decoding_idx == (L - 1):
1690
+ symm_group_end_mask = torch.ones(
1691
+ (B, 1),
1692
+ device=input_features["residue_mask"].device,
1693
+ dtype=torch.bool,
1694
+ )
1695
+ else:
1696
+ # next_i [B, 1] - the indices of the next residues to decode
1697
+ # in the current iteration, based on the decoding order.
1698
+ next_i = decoder_features["decoding_order"][
1699
+ :, decoding_idx + 1
1700
+ ].unsqueeze(-1)
1701
+
1702
+ # symmetry_equivalence_group_next_i [B, 1] - the symmetry
1703
+ # equivalence group for the residue at the next decoding
1704
+ # index.
1705
+ symmetry_equivalence_group_next_i = input_features[
1706
+ "symmetry_equivalence_group"
1707
+ ][batch_idx, next_i]
1708
+
1709
+ symm_group_end_mask = (
1710
+ symmetry_equivalence_group_i
1711
+ != symmetry_equivalence_group_next_i
1712
+ )
1713
+
1714
+ # same_symm_group_mask [B, L] - mask for the residues that
1715
+ # belong to the same symmetry group as the current residue,
1716
+ # where True is a residue that belongs to the same symmetry
1717
+ # group, and False is a residue that does not belong to the same
1718
+ # symmetry group.
1719
+ same_symm_group_mask = (
1720
+ input_features["symmetry_equivalence_group"]
1721
+ == symmetry_equivalence_group_i
1722
+ )
1723
+
1724
+ # symm_end_and_same_mask [B, L] - mask that combines the
1725
+ # symm_group_end_mask and same_symm_group_mask, where all
1726
+ # residues with the same symmetry group as the current residue
1727
+ # (if the current residue is at the end of a symmetry group)
1728
+ # are True, and all other residues are False.
1729
+ symm_end_and_same_mask = symm_group_end_mask & same_symm_group_mask
1730
+
1731
+ # symm_end_and_same_mask_vocab [B, L, self.vocab_size] -
1732
+ # symm_end_and_same_mask projected to the vocabulary size.
1733
+ symm_end_and_same_mask_vocab = symm_end_and_same_mask.unsqueeze(
1734
+ -1
1735
+ ).expand(-1, -1, self.vocab_size)
1736
+
1737
+ # symm_end_and_same_mask_for_hidden [B, L, self.hidden_dim] -
1738
+ # symm_end_and_same_mask projected to the hidden dimension.
1739
+ symm_end_and_same_mask_for_hidden = symm_end_and_same_mask.unsqueeze(
1740
+ -1
1741
+ ).expand(-1, -1, self.hidden_dim)
1742
+
1743
+ # Save the logits, probabilities, sampled probabilities, and log
1744
+ # probabilities for the current decoding index, if the residue
1745
+ # is at the end of a symmetry group.
1746
+ logits[symm_end_and_same_mask_vocab] = logits_i.expand(-1, L, -1)[
1747
+ symm_end_and_same_mask_vocab
1748
+ ]
1749
+ probs[symm_end_and_same_mask_vocab] = probs_i.expand(-1, L, -1)[
1750
+ symm_end_and_same_mask_vocab
1751
+ ]
1752
+ probs_sample[symm_end_and_same_mask_vocab] = probs_sample_i.expand(
1753
+ -1, L, -1
1754
+ )[symm_end_and_same_mask_vocab]
1755
+ if not torch.is_grad_enabled():
1756
+ log_probs[symm_end_and_same_mask_vocab] = log_probs_i.expand(
1757
+ -1, L, -1
1758
+ )[symm_end_and_same_mask_vocab]
1759
+ else:
1760
+ # For gradient tracking, we can't use in-place operations.
1761
+ log_probs = torch.where(
1762
+ symm_end_and_same_mask_vocab,
1763
+ log_probs_i.expand(-1, L, -1),
1764
+ log_probs,
1765
+ )
1766
+
1767
+ # Update the predicted sequence and argmax sequence for the
1768
+ # current decoding index, if the residue is at the end of a
1769
+ # symmetry group.
1770
+ S_sampled[symm_end_and_same_mask] = (
1771
+ (S_sampled_i * decode_last_mask_i) + (S_i * (~decode_last_mask_i))
1772
+ ).expand(-1, L)[symm_end_and_same_mask]
1773
+ S_argmax[symm_end_and_same_mask] = (
1774
+ (S_argmax_i * decode_last_mask_i) + (S_i * (~decode_last_mask_i))
1775
+ ).expand(-1, L)[symm_end_and_same_mask]
1776
+
1777
+ # h_S_i [B, 1, self.hidden_dim] - the sequence embeddings of the
1778
+ # sampled/fixed residue at the current decoding index.
1779
+ h_S_i = self.W_s(S_sampled[batch_idx, i])
1780
+
1781
+ # Update the decoder sequence embeddings with the predicted
1782
+ # sequence for the current decoding index, if the residue is at
1783
+ # the end of a symmetry group.
1784
+ if not torch.is_grad_enabled():
1785
+ h_S[symm_end_and_same_mask_for_hidden] = h_S_i.expand(-1, L, -1)[
1786
+ symm_end_and_same_mask_for_hidden
1787
+ ]
1788
+ else:
1789
+ # For gradient tracking, we can't use in-place operations.
1790
+ h_S = torch.where(
1791
+ symm_end_and_same_mask_for_hidden, h_S_i.expand(-1, L, -1), h_S
1792
+ )
1793
+
1794
+ # Zero out the current position logits (used for accumulation)
1795
+ # if a batch example's current residue is at the end of a
1796
+ # symmetry group.
1797
+ logits_i[symm_group_end_mask] = 0.0
1798
+
1799
+ # Update the decoder features with the final node features, the computed
1800
+ # logits, log probabilities, probabilities, sampled probabilities,
1801
+ # predicted sequence, and argmax sequence.
1802
+ decoder_features["h_V"] = h_V_decoder_stack[-1]
1803
+ decoder_features["logits"] = logits
1804
+ decoder_features["log_probs"] = log_probs
1805
+ decoder_features["probs"] = probs
1806
+ decoder_features["probs_sample"] = probs_sample
1807
+ decoder_features["S_sampled"] = S_sampled
1808
+ decoder_features["S_argmax"] = S_argmax
1809
+
1810
+ def construct_output_dictionary(
1811
+ self, input_features, graph_features, encoder_features, decoder_features
1812
+ ):
1813
+ """
1814
+ Constructs the output dictionary based on the requested features.
1815
+
1816
+ Args:
1817
+ input_features (dict): Input features containing the requested
1818
+ features to return.
1819
+ - features_to_return (dict, optional): dictionary determining
1820
+ which features to return from the model. If None, return all
1821
+ features (including modified input features, graph features,
1822
+ encoder features, and decoder features). Otherwise,
1823
+ expects a dictionary with the following key, value pairs:
1824
+ - "input_features": list - the input features to return.
1825
+ - "graph_features": list - the graph features to return.
1826
+ - "encoder_features": list - the encoder features to return.
1827
+ - "decoder_features": list - the decoder features to return.
1828
+ graph_features (dict): Graph features containing the featurized
1829
+ node and edge inputs.
1830
+ encoder_features (dict): Encoder features containing the encoded
1831
+ protein node and protein edge features.
1832
+ decoder_features (dict): Decoder features containing the post-
1833
+ decoder features (including causal masks, logits, probabilities,
1834
+ predicted sequence, etc.).
1835
+ Returns:
1836
+ output_dict (dict): Output dictionary containing the requested
1837
+ features based on the input features' "features_to_return" key.
1838
+ If "features_to_return" is None, returns all features.
1839
+ """
1840
+ # Check that the input features contains the necessary keys.
1841
+ if "features_to_return" not in input_features:
1842
+ raise ValueError("Input features must contain 'features_to_return' key.")
1843
+
1844
+ # Create the output dictionary based on the requested features.
1845
+ if input_features["features_to_return"] is None:
1846
+ output_dict = {
1847
+ "input_features": input_features,
1848
+ "graph_features": graph_features,
1849
+ "encoder_features": encoder_features,
1850
+ "decoder_features": decoder_features,
1851
+ }
1852
+ else:
1853
+ # Filter the output dictionary based on the requested features.
1854
+ output_dict = dict()
1855
+ output_dict["input_features"] = {
1856
+ key: input_features[key]
1857
+ for key in input_features["features_to_return"].get(
1858
+ "input_features", []
1859
+ )
1860
+ }
1861
+ output_dict["graph_features"] = {
1862
+ key: graph_features[key]
1863
+ for key in input_features["features_to_return"].get(
1864
+ "graph_features", []
1865
+ )
1866
+ }
1867
+ output_dict["encoder_features"] = {
1868
+ key: encoder_features[key]
1869
+ for key in input_features["features_to_return"].get(
1870
+ "encoder_features", []
1871
+ )
1872
+ }
1873
+ output_dict["decoder_features"] = {
1874
+ key: decoder_features[key]
1875
+ for key in input_features["features_to_return"].get(
1876
+ "decoder_features", []
1877
+ )
1878
+ }
1879
+
1880
+ return output_dict
1881
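A small sketch (hypothetical feature dictionaries) of the filtering pattern used in construct_output_dictionary: only the keys listed for each feature group are kept, and a group that is not mentioned in features_to_return defaults to an empty list, yielding an empty sub-dictionary.

decoder_features = {"logits": [0.1, 0.2], "probs": [0.5, 0.5], "h_V": [1.0]}
features_to_return = {"decoder_features": ["logits", "probs"]}

requested = features_to_return.get("decoder_features", [])
filtered = {key: decoder_features[key] for key in requested}
print(filtered)  # {'logits': [0.1, 0.2], 'probs': [0.5, 0.5]}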
+
1882
+ def forward(self, network_input):
1883
+ """
1884
+ Forward pass of the ProteinMPNN model.
1885
+
1886
+ A NOTE on shapes:
1887
+ - B = batch dimension size
1888
+ - L = sequence length (number of residues)
1889
+ - K = number of neighbors per residue
1890
+ - H = hidden dimension size
1891
+ - vocab_size = self.vocab_size
1892
+ - num_atoms =
1893
+ self.graph_featurization_module.TOKEN_ENCODING.n_atoms_per_token
1894
+ - num_backbone_atoms = len(
1895
+ self.graph_featurization_module.BACKBONE_ATOM_NAMES
1896
+ )
1897
+ - num_virtual_atoms = len(
1898
+ self.graph_featurization_module.DATA_TO_CALCULATE_VIRTUAL_ATOMS
1899
+ )
1900
+ - num_rep_atoms = len(
1901
+ self.graph_featurization_module.REPRESENTATIVE_ATOM_NAMES
1902
+ )
1903
+ - num_edge_output_features =
1904
+ self.graph_featurization_module.num_edge_output_features
1905
+ - num_node_output_features =
1906
+ self.graph_featurization_module.num_node_output_features
1907
+
1908
+ Args:
1909
+ network_input (dict): Dictionary containing the input to the
1910
+ network.
1911
+ - input_features (dict): dictionary containing input features
1912
+ and all necessary information for the model to run.
1913
+ - X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of
1914
+ polymer atoms.
1915
+ - X_m (torch.Tensor): [B, L, num_atoms] - Mask indicating
1916
+ which polymer atoms are valid.
1917
+ - S (torch.Tensor): [B, L] - Sequence of the polymer
1918
+ residues.
1919
+ - R_idx (torch.Tensor): [B, L] - indices of the residues.
1920
+ - chain_labels (torch.Tensor): [B, L] - chain labels for
1921
+ each residue.
1922
+ - residue_mask (torch.Tensor): [B, L] - Mask indicating
1923
+ which residues are valid.
1924
+ - designed_residue_mask (torch.Tensor): [B, L] - mask for
1925
+ the designed residues.
1926
+ - symmetry_equivalence_group (torch.Tensor, optional):
1927
+ [B, L] - an integer for every residue, indicating the
1928
+ symmetry group that it belongs to. If None, the
1929
+ residues are not grouped by symmetry. For example, if
1930
+ residue i and j should be decoded symmetrically, then
1931
+ symmetry_equivalence_group[i] ==
1932
+ symmetry_equivalence_group[j]. Must be torch.int64 to
1933
+ allow for use as an index. These values should range
1934
+ from 0 to the maximum number of symmetry groups - 1 for
1935
+ each example. NOTE: bias, pair_bias, and temperature
1936
+ should be the same for all residues in the symmetry
1937
+ equivalence group; otherwise, the intended behavior may
1938
+ not be achieved. The residues within a symmetry group
1939
+ should all have the same validity and design/fixed
1940
+ status.
1941
+ - symmetry_weight (torch.Tensor, optional): [B, L] - the
1942
+ weights for each residue, to be used when aggregating
1943
+ across its respective symmetry group. If None, the
1944
+ weights are assumed to be 1.0 for all residues.
1945
+ - bias (torch.Tensor, optional): [B, L, 21] - the
1946
+ per-residue bias to use for sampling. If None, the code
1947
+ will implicitly use a bias of 0.0 for all residues.
1948
+ - pair_bias (torch.Tensor, optional): [B, L, 21, L, 21] -
1949
+ the per-residue pair bias to use for sampling. If None,
1950
+ the code will implicitly use a pair bias of 0.0 for all
1951
+ residue pairs.
1952
+ - temperature (torch.Tensor, optional): [B, L] - the
1953
+ per-residue temperature to use for sampling. If None,
1954
+ the code will implicitly use a temperature of 1.0.
1955
+ - structure_noise (float): Standard deviation of the
1956
+ Gaussian noise to add to the input coordinates, in
1957
+ Angstroms.
1958
+ - decode_type (str): the type of decoding to use.
1959
+ - "teacher_forcing": Use teacher forcing for the
1960
+ decoder, where the decoder attends to the ground
1961
+ truth sequence S for all previously decoded
1962
+ residues.
1963
+ - "auto_regressive": Use auto-regressive decoding,
1964
+ where the decoder attends to the sequence and
1965
+ decoder representation of residues that have
1966
+ already been decoded (using the predicted sequence).
1967
+ - causality_pattern (str): The pattern of causality to use
1968
+ for the decoder. For all causality patterns, the
1969
+ decoding order is randomized.
1970
+ - "auto_regressive": Use an auto-regressive causality
1971
+ pattern, where residues can attend to the sequence
1972
+ and decoder representation of residues that have
1973
+ already been decoded (NOTE: as mentioned above,
1974
+ this will be randomized).
1975
+ - "unconditional": Residues cannot attend to the
1976
+ sequence or decoder representation of any other
1977
+ residues.
1978
+ - "conditional": Residues can attend to the sequence
1979
+ and decoder representation of all other residues.
1980
+ - "conditional_minus_self": Residues can attend to the
1981
+ sequence and decoder representation of all other
1982
+ residues, except for themselves (as destination
1983
+ nodes).
1984
+ - initialize_sequence_embedding_with_ground_truth (bool):
1985
+ - True: Initialize the sequence embedding with the
1986
+ ground truth sequence S.
1987
+ - If doing auto-regressive decoding, also
1988
+ initialize S_sampled with the ground truth
1989
+ sequence S, which should only affect the
1990
+ application of pair bias.
1991
+ - False: Initialize the sequence embedding with zeros.
1992
+ - If doing auto-regressive decoding, initialize
1993
+ S_sampled with unknown residues.
1994
+ - features_to_return (dict, optional): dictionary
1995
+ determining which features to return from the model. If
1996
+ None, return all features (including modified input
1997
+ features, graph features, encoder features, and decoder
1998
+ features). Otherwise, expects a dictionary with the
1999
+ following key, value pairs:
2000
+ - "input_features": list - the input features to return.
2001
+ - "graph_features": list - the graph features to return.
2002
+ - "encoder_features": list - the encoder features to
2003
+ return.
2004
+ - "decoder_features": list - the decoder features to
2005
+ return.
2006
+ - repeat_sample_num (int, optional): Number of times to
2007
+ repeat the samples along the batch dimension. If None,
2008
+ no repetition is performed. If greater than 1, the
2009
+ samples are repeated along the batch dimension and B
2010
+ must be 1, since repeating samples along the batch
2011
+ dimension is not supported when more than one sample is
2012
+ provided in the batch.
2013
+ Side Effects:
2014
+ Any changes to input_features described below are also applied to
2015
+ the original input_features dictionary (it is mutated in place).
2016
+ Returns:
2017
+ network_output (dict): Output dictionary containing the requested
2018
+ features based on the input features' "features_to_return" key.
2019
+ - input_features (dict): The input features from above, with
2020
+ the following keys added or modified:
2021
+ - mask_for_loss (torch.Tensor): [B, L] - mask for loss,
2022
+ where True is a residue that is included in the loss
2023
+ calculation, and False is a residue that is not
2024
+ included in the loss calculation.
2025
+ - X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of
2026
+ polymer atoms with added Gaussian noise.
2027
+ - X_pre_noise (torch.Tensor): [B, L, num_atoms, 3] -
2028
+ 3D coordinates of polymer atoms before adding Gaussian
2029
+ noise ('X' before noise).
2030
+ - X_backbone (torch.Tensor): [B, L, num_backbone_atoms, 3] -
2031
+ 3D coordinates of the backbone atoms for each residue,
2032
+ built from the noisy 'X' coordinates.
2033
+ - X_m_backbone (torch.Tensor): [B, L, num_backbone_atoms] -
2034
+ mask indicating which backbone atoms are valid.
2035
+ - X_virtual_atoms (torch.Tensor):
2036
+ [B, L, num_virtual_atoms, 3] - 3D coordinates of the
2037
+ virtual atoms for each residue, built from the noisy
2038
+ 'X' coordinates.
2039
+ - X_m_virtual_atoms (torch.Tensor):
2040
+ [B, L, num_virtual_atoms] - mask indicating which
2041
+ virtual atoms are valid.
2042
+ - X_rep_atoms (torch.Tensor): [B, L, num_rep_atoms, 3] - 3D
2043
+ coordinates of the representative atoms for each
2044
+ residue, built from the noisy 'X' coordinates.
2045
+ - X_m_rep_atoms (torch.Tensor): [B, L, num_rep_atoms] -
2046
+ mask indicating which representative atoms are valid.
2047
+ - graph_features (dict): The graph features.
2048
+ - E_idx (torch.Tensor): [B, L, K] - indices of the top K
2049
+ nearest neighbors for each residue.
2050
+ - E (torch.Tensor): [B, L, K, num_edge_output_features] -
2051
+ Edge features for each pair of neighbors.
2052
+ - encoder_features (dict): The encoder features.
2053
+ - h_V (torch.Tensor): [B, L, H] - the protein node features
2054
+ after encoding message passing.
2055
+ - h_E (torch.Tensor): [B, L, K, H] - the protein edge
2056
+ features after encoding message passing.
2057
+ - decoder_features (dict): The decoder features.
2058
+ - causal_mask (torch.Tensor): [B, L, K, 1] - the causal
2059
+ mask for the decoder.
2060
+ - anti_causal_mask (torch.Tensor): [B, L, K, 1] - the
2061
+ anti-causal mask for the decoder.
2062
+ - decoding_order (torch.Tensor): [B, L] - the order in
2063
+ which the residues should be decoded.
2064
+ - decode_last_mask (torch.Tensor): [B, L] - mask for
2065
+ residues that should be decoded last, where False is a
2066
+ residue that should be decoded first (invalid or
2067
+ fixed), and True is a residue that should not be
2068
+ decoded first (designed residues).
2069
+ - h_V (torch.Tensor): [B, L, H] - the updated node features
2070
+ for the decoder.
2071
+ - logits (torch.Tensor): [B, L, vocab_size] - the logits
2072
+ for the sequence.
2073
+ - log_probs (torch.Tensor): [B, L, vocab_size] - the log
2074
+ probabilities for the sequence.
2075
+ - probs (torch.Tensor): [B, L, vocab_size] - the
2076
+ probabilities for the sequence.
2077
+ - probs_sample (torch.Tensor): [B, L, vocab_size] -
2078
+ the probabilities for the sequence, with the unknown
2079
+ residues zeroed out and the other residues normalized.
2080
+ - S_sampled (torch.Tensor): [B, L] - the predicted
2081
+ sequence, sampled from the probabilities (unknown
2082
+ residues are not sampled).
2083
+ - S_argmax (torch.Tensor): [B, L] - the predicted sequence,
2084
+ obtained by taking the argmax of the probabilities
2085
+ (unknown residues are not selected).
2086
+ """
2087
+ input_features = network_input["input_features"]
2088
+
2089
+ # Check that the input features contains the necessary keys.
2090
+ if "decode_type" not in input_features:
2091
+ raise ValueError("Input features must contain 'decode_type' key.")
2092
+
2093
+ # Setup masks (added to the input features).
2094
+ self.sample_and_construct_masks(input_features)
2095
+
2096
+ # Graph featurization (also modifies/adds to input_features).
2097
+ graph_features = self.graph_featurization(input_features)
2098
+
2099
+ # Run the encoder.
2100
+ encoder_features = self.encode(input_features, graph_features)
2101
+
2102
+ # Setup for decoder (repeat features along the batch dimension, modifies
2103
+ # input_features, graph_features, and encoder_features).
2104
+ self.repeat_along_batch(
2105
+ input_features,
2106
+ graph_features,
2107
+ encoder_features,
2108
+ )
2109
+
2110
+ # Set up the causality masks.
2111
+ decoder_features = self.setup_causality_masks(input_features, graph_features)
2112
+
2113
+ # Decoder, either teacher forcing or auto-regressive.
2114
+ if input_features["decode_type"] == "teacher_forcing":
2115
+ self.decode_teacher_forcing(
2116
+ input_features, graph_features, encoder_features, decoder_features
2117
+ )
2118
+ elif input_features["decode_type"] == "auto_regressive":
2119
+ self.decode_auto_regressive(
2120
+ input_features, graph_features, encoder_features, decoder_features
2121
+ )
2122
+ else:
2123
+ raise ValueError(f"Unknown decode_type: {input_features['decode_type']}.")
2124
+
2125
+ # Create the output dictionary based on the requested features.
2126
+ network_output = self.construct_output_dictionary(
2127
+ input_features, graph_features, encoder_features, decoder_features
2128
+ )
2129
+
2130
+ return network_output
2131
+
2132
+
2133
+ class SolubleMPNN(ProteinMPNN):
2134
+ """
2135
+ Same as ProteinMPNN, but with a different training set and different weights.
2136
+ """
2137
+
2138
+ pass
2139
+
2140
+
2141
+ class AntibodyMPNN(ProteinMPNN):
2142
+ """
2143
+ Same as ProteinMPNN, but with a different training set and different weights.
2144
+ """
2145
+
2146
+ pass
2147
+
2148
+
2149
+ class MembraneMPNN(ProteinMPNN):
2150
+ """
2151
+ Class for per-residue and global membrane label version of MPNN.
2152
+ """
2153
+
2154
+ HAS_NODE_FEATURES = True
2155
+
2156
+ def __init__(
2157
+ self,
2158
+ num_node_features=128,
2159
+ num_edge_features=128,
2160
+ hidden_dim=128,
2161
+ num_encoder_layers=3,
2162
+ num_decoder_layers=3,
2163
+ num_neighbors=48,
2164
+ dropout_rate=0.1,
2165
+ num_positional_embeddings=16,
2166
+ min_rbf_mean=2.0,
2167
+ max_rbf_mean=22.0,
2168
+ num_rbf=16,
2169
+ num_membrane_classes=3,
2170
+ ):
2171
+ """
2172
+ Set up the MembraneMPNN model.
2173
+
2174
+ All args are the same as the parent class, except for the following:
2175
+ Args:
2176
+ num_membrane_classes (int): Number of membrane classes.
2177
+ """
2178
+ # The only change necessary here is the graph featurization module.
2179
+ graph_featurization_module = ProteinFeaturesMembrane(
2180
+ num_edge_output_features=num_edge_features,
2181
+ num_node_output_features=num_node_features,
2182
+ num_positional_embeddings=num_positional_embeddings,
2183
+ min_rbf_mean=min_rbf_mean,
2184
+ max_rbf_mean=max_rbf_mean,
2185
+ num_rbf=num_rbf,
2186
+ num_neighbors=num_neighbors,
2187
+ num_membrane_classes=num_membrane_classes,
2188
+ )
2189
+
2190
+ super(MembraneMPNN, self).__init__(
2191
+ num_node_features=num_node_features,
2192
+ num_edge_features=num_edge_features,
2193
+ hidden_dim=hidden_dim,
2194
+ num_encoder_layers=num_encoder_layers,
2195
+ num_decoder_layers=num_decoder_layers,
2196
+ num_neighbors=num_neighbors,
2197
+ dropout_rate=dropout_rate,
2198
+ num_positional_embeddings=num_positional_embeddings,
2199
+ min_rbf_mean=min_rbf_mean,
2200
+ max_rbf_mean=max_rbf_mean,
2201
+ num_rbf=num_rbf,
2202
+ graph_featurization_module=graph_featurization_module,
2203
+ )
2204
+
2205
+
2206
+ class PSSMMPNN(ProteinMPNN):
2207
+ """
2208
+ Class for the PSSM-aware version of MPNN.
2209
+ """
2210
+
2211
+ HAS_NODE_FEATURES = True
2212
+
2213
+ def __init__(
2214
+ self,
2215
+ num_node_features=128,
2216
+ num_edge_features=128,
2217
+ hidden_dim=128,
2218
+ num_encoder_layers=3,
2219
+ num_decoder_layers=3,
2220
+ num_neighbors=48,
2221
+ dropout_rate=0.1,
2222
+ num_positional_embeddings=16,
2223
+ min_rbf_mean=2.0,
2224
+ max_rbf_mean=22.0,
2225
+ num_rbf=16,
2226
+ num_pssm_features=20,
2227
+ ):
2228
+ """
2229
+ Set up the PSSMMPNN model.
2230
+
2231
+ All args are the same as the parent class, except for the following:
2232
+ Args:
2233
+ num_pssm_features (int): Number of PSSM features.
2234
+ """
2235
+ # The only change necessary here is the graph featurization module.
2236
+ graph_featurization_module = ProteinFeaturesPSSM(
2237
+ num_edge_output_features=num_edge_features,
2238
+ num_node_output_features=num_node_features,
2239
+ num_positional_embeddings=num_positional_embeddings,
2240
+ min_rbf_mean=min_rbf_mean,
2241
+ max_rbf_mean=max_rbf_mean,
2242
+ num_rbf=num_rbf,
2243
+ num_neighbors=num_neighbors,
2244
+ num_pssm_features=num_pssm_features,
2245
+ )
2246
+
2247
+ super(PSSMMPNN, self).__init__(
2248
+ num_node_features=num_node_features,
2249
+ num_edge_features=num_edge_features,
2250
+ hidden_dim=hidden_dim,
2251
+ num_encoder_layers=num_encoder_layers,
2252
+ num_decoder_layers=num_decoder_layers,
2253
+ num_neighbors=num_neighbors,
2254
+ dropout_rate=dropout_rate,
2255
+ num_positional_embeddings=num_positional_embeddings,
2256
+ min_rbf_mean=min_rbf_mean,
2257
+ max_rbf_mean=max_rbf_mean,
2258
+ num_rbf=num_rbf,
2259
+ graph_featurization_module=graph_featurization_module,
2260
+ )
2261
+
2262
+
2263
+ class LigandMPNN(ProteinMPNN):
2264
+ """
2265
+ Class for ligand-aware version of MPNN.
2266
+ """
2267
+
2268
+ # Although there are node-like features, they are actually the protein-
2269
+ # ligand subgraph features, so we set this to False, because none of
2270
+ # these features are embedded prior to protein encoding.
2271
+ HAS_NODE_FEATURES = False
2272
+
2273
+ def __init__(
2274
+ self,
2275
+ num_node_features=128,
2276
+ num_edge_features=128,
2277
+ hidden_dim=128,
2278
+ num_encoder_layers=3,
2279
+ num_decoder_layers=3,
2280
+ num_neighbors=32,
2281
+ dropout_rate=0.1,
2282
+ num_positional_embeddings=16,
2283
+ min_rbf_mean=2.0,
2284
+ max_rbf_mean=22.0,
2285
+ num_rbf=16,
2286
+ num_context_atoms=25,
2287
+ num_context_encoding_layers=2,
2288
+ overall_atomize_side_chain_probability=0.5,
2289
+ per_residue_atomize_side_chain_probability=0.02,
2290
+ ):
2291
+ # Pass the num_context_atoms to the graph featurization module.
2292
+ graph_featurization_module = ProteinFeaturesLigand(
2293
+ num_edge_output_features=num_edge_features,
2294
+ num_node_output_features=num_node_features,
2295
+ num_positional_embeddings=num_positional_embeddings,
2296
+ min_rbf_mean=min_rbf_mean,
2297
+ max_rbf_mean=max_rbf_mean,
2298
+ num_rbf=num_rbf,
2299
+ num_neighbors=num_neighbors,
2300
+ num_context_atoms=num_context_atoms,
2301
+ )
2302
+
2303
+ super(LigandMPNN, self).__init__(
2304
+ num_node_features=num_node_features,
2305
+ num_edge_features=num_edge_features,
2306
+ hidden_dim=hidden_dim,
2307
+ num_encoder_layers=num_encoder_layers,
2308
+ num_decoder_layers=num_decoder_layers,
2309
+ num_neighbors=num_neighbors,
2310
+ dropout_rate=dropout_rate,
2311
+ num_positional_embeddings=num_positional_embeddings,
2312
+ min_rbf_mean=min_rbf_mean,
2313
+ max_rbf_mean=max_rbf_mean,
2314
+ num_rbf=num_rbf,
2315
+ graph_featurization_module=graph_featurization_module,
2316
+ )
2317
+ self.overall_atomize_side_chain_probability = (
2318
+ overall_atomize_side_chain_probability
2319
+ )
2320
+ self.per_residue_atomize_side_chain_probability = (
2321
+ per_residue_atomize_side_chain_probability
2322
+ )
2323
+
2324
+ # Linear layer for embedding the protein-ligand edge features.
2325
+ self.W_protein_to_ligand_edges_embed = nn.Linear(
2326
+ num_node_features, hidden_dim, bias=True
2327
+ )
2328
+
2329
+ # Linear layers for embedding the output of the protein encoder.
2330
+ self.W_protein_encoding_embed = nn.Linear(hidden_dim, hidden_dim, bias=True)
2331
+
2332
+ # Linear layers for embedding the ligand nodes and edges.
2333
+ self.W_ligand_nodes_embed = nn.Linear(hidden_dim, hidden_dim, bias=True)
2334
+ self.W_ligand_edges_embed = nn.Linear(hidden_dim, hidden_dim, bias=True)
2335
+
2336
+ # Linear layer for the final context embedding.
2337
+ self.W_final_context_embed = nn.Linear(hidden_dim, hidden_dim, bias=False)
2338
+
2339
+ # Layer norm for the final context embedding.
2340
+ self.final_context_norm = nn.LayerNorm(hidden_dim)
2341
+
2342
+ # Save the number of context encoding layers.
2343
+ self.num_context_encoding_layers = num_context_encoding_layers
2344
+
2345
+ self.protein_ligand_context_encoder_layers = nn.ModuleList(
2346
+ [
2347
+ DecLayer(hidden_dim, hidden_dim * 3, dropout=dropout_rate)
2348
+ for _ in range(num_context_encoding_layers)
2349
+ ]
2350
+ )
2351
+
2352
+ self.ligand_context_encoder_layers = nn.ModuleList(
2353
+ [
2354
+ DecLayer(hidden_dim, hidden_dim * 2, dropout=dropout_rate)
2355
+ for _ in range(num_context_encoding_layers)
2356
+ ]
2357
+ )
2358
+
2359
+    def sample_and_construct_masks(self, input_features):
+        """
+        Sample and construct masks for the input features.
+
+        Args:
+            input_features (dict): Input features containing the following keys:
+                - residue_mask (torch.Tensor): [B, L] - mask for the residues.
+                - S (torch.Tensor): [B, L] - sequence of residues.
+                - designed_residue_mask (torch.Tensor): [B, L] - mask for the
+                    designed residues.
+                - atomize_side_chains (bool): Whether to atomize side chains.
+        Side Effects:
+            input_features["residue_mask"] (torch.Tensor): [B, L] - mask for the
+                residues, where True is a residue that is valid and False is a
+                residue that is invalid.
+            input_features["known_residue_mask"] (torch.Tensor): [B, L] - mask
+                for known residues, where True is a residue with one of the
+                canonical residue types, and False is a residue with an unknown
+                residue type.
+            input_features["designed_residue_mask"] (torch.Tensor): [B, L] -
+                mask for designed residues, where True is a residue that is
+                designed, and False is a residue that is not designed.
+            input_features["hide_side_chain_mask"] (torch.Tensor): [B, L] - mask
+                for hiding side chains, where True is a residue with hidden side
+                chains, and False is a residue with revealed side chains.
+            input_features["mask_for_loss"] (torch.Tensor): [B, L] - mask for
+                loss, where True is a residue that is included in the loss
+                calculation, and False is a residue that is not included in the
+                loss calculation.
+        """
+        # Create the masks for ProteinMPNN.
+        super().sample_and_construct_masks(input_features)
+
+        # Check that the input features contain the necessary keys.
+        if "atomize_side_chains" not in input_features:
+            raise ValueError("Input features must contain 'atomize_side_chains' key.")
+
+        # Create the mask for hiding or revealing side chains.
+        # With no side chain atomization, the side chain mask is all ones.
+        if input_features["atomize_side_chains"]:
+            # If we are training, randomly reveal side chains.
+            if self.training:
+                # With a probability specified as the overall atomization
+                # side chain probability, we reveal some side chains (otherwise,
+                # we hide all side chains).
+                if (
+                    sample_bernoulli_rv(self.overall_atomize_side_chain_probability)
+                    == 1
+                ):
+                    reveal_side_chain_mask = (
+                        torch.rand(
+                            input_features["S"].shape, device=input_features["S"].device
+                        )
+                        < self.per_residue_atomize_side_chain_probability
+                    )
+                    hide_side_chain_mask = ~reveal_side_chain_mask
+                else:
+                    hide_side_chain_mask = torch.ones(
+                        input_features["S"].shape, device=input_features["S"].device
+                    ).bool()
+            # If we are not training, only the side chains of fixed residues
+            # are revealed.
+            else:
+                hide_side_chain_mask = input_features["designed_residue_mask"].clone()
+        else:
+            hide_side_chain_mask = torch.ones(
+                input_features["S"].shape, device=input_features["S"].device
+            ).bool()
+
+        # Save the hide side chain mask in the input features.
+        input_features["hide_side_chain_mask"] = hide_side_chain_mask
+
+        # Update the mask for the loss to include the hide side chain mask.
+        input_features["mask_for_loss"] = (
+            input_features["mask_for_loss"] & input_features["hide_side_chain_mask"]
+        )
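+        # (Residues whose side chains are revealed are thereby excluded from
+        # the loss; only residues with hidden side chains contribute.)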
+
+    def encode(self, input_features, graph_features):
+        """
+        Encode the protein features with ligand context.
+
+        NOTE: M = self.graph_featurization_module.num_context_atoms, the number
+        of ligand atoms in each residue subgraph.
+
+        Args:
+            input_features (dict): Input features containing the following keys:
+                - residue_mask (torch.Tensor): [B, L] - mask for the residues.
+                - ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - mask for the
+                    ligand subgraph nodes.
+            graph_features (dict): Graph features containing the featurized
+                node and edge inputs.
+                - E_protein_to_ligand (torch.Tensor): [B, L, M,
+                    self.num_edge_features] - protein-ligand edge features.
+                - ligand_subgraph_nodes (torch.Tensor): [B, L, M,
+                    self.num_node_features] - ligand subgraph node features.
+                - ligand_subgraph_edges (torch.Tensor): [B, L, M, M,
+                    self.num_edge_features] - ligand subgraph edge features.
+        Returns:
+            encoder_features (dict): Encoded features containing the encoded
+                protein node and protein edge features.
+                - h_V (torch.Tensor): [B, L, self.hidden_dim] - the protein node
+                    features after protein encoding and ligand context
+                    encoding.
+                - h_E (torch.Tensor): [B, L, K, self.hidden_dim] - the protein
+                    edge features after protein encoding message passing.
+        """
+        # Use the parent encode method to get the initial protein encoding.
+        encoder_features = super().encode(input_features, graph_features)
+
+        # Check the encoder features.
+        if "h_V" not in encoder_features:
+            raise ValueError("Encoder features must contain 'h_V' key.")
+
+        # Check that the input features contain the necessary keys.
+        if "residue_mask" not in input_features:
+            raise ValueError("Input features must contain 'residue_mask' key.")
+        if "ligand_subgraph_Y_m" not in input_features:
+            raise ValueError("Input features must contain 'ligand_subgraph_Y_m' key.")
+
+        # Check that the graph features contain the necessary keys.
+        if "E_protein_to_ligand" not in graph_features:
+            raise ValueError("Graph features must contain 'E_protein_to_ligand' key.")
+        if "ligand_subgraph_nodes" not in graph_features:
+            raise ValueError("Graph features must contain 'ligand_subgraph_nodes' key.")
+        if "ligand_subgraph_edges" not in graph_features:
+            raise ValueError("Graph features must contain 'ligand_subgraph_edges' key.")
+
+        # Compute the protein-ligand edge feature encoding.
+        # h_E_protein_to_ligand [B, L, M, self.hidden_dim] - the embedding of
+        # the protein-ligand edge features.
+        h_E_protein_to_ligand = self.W_protein_to_ligand_edges_embed(
+            graph_features["E_protein_to_ligand"]
+        )
+
+        # Construct the starting context features used to aggregate the ligand
+        # context; these will be updated in the context encoder.
+        # h_V_context [B, L, self.hidden_dim] - the embedding of context.
+        h_V_context = self.W_protein_encoding_embed(encoder_features["h_V"])
+
+        # Construct the ligand subgraph edge mask.
+        # ligand_subgraph_Y_m_edges [B, L, M, M] - the mask for the
+        # ligand-ligand subgraph edges.
+        ligand_subgraph_Y_m_edges = (
+            input_features["ligand_subgraph_Y_m"][:, :, :, None]
+            * input_features["ligand_subgraph_Y_m"][:, :, None, :]
+        )
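+        # (Outer product of the atom mask with itself: the edge between
+        # subgraph atoms i and j is valid only when both atoms are valid.)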
+
+        # Embed the ligand nodes.
+        # ligand_subgraph_nodes [B, L, M, self.hidden_dim] - the embedding of
+        # the ligand nodes in the subgraph.
+        h_ligand_subgraph_nodes = self.W_ligand_nodes_embed(
+            graph_features["ligand_subgraph_nodes"]
+        )
+
+        # Embed the ligand edges.
+        # ligand_subgraph_edges [B, L, M, M, self.hidden_dim] - the embedding
+        # of the ligand edges in the subgraph.
+        h_ligand_subgraph_edges = self.W_ligand_edges_embed(
+            graph_features["ligand_subgraph_edges"]
+        )
+
+        # Run the context encoder layers for the protein-ligand context.
+        for i in range(self.num_context_encoding_layers):
+            # Message passing in the ligand subgraph.
+            # BUG: to replicate the original LigandMPNN, destination nodes are
+            # not concatenated into ligand_subgraph_edges; this breaks message
+            # passing in the small-molecule graph (i.e., there is no message
+            # passing within the ligand subgraph).
+            h_ligand_subgraph_nodes = torch.utils.checkpoint.checkpoint(
+                self.ligand_context_encoder_layers[i],
+                h_ligand_subgraph_nodes,
+                h_ligand_subgraph_edges,
+                input_features["ligand_subgraph_Y_m"],
+                ligand_subgraph_Y_m_edges,
+                use_reentrant=False,
+            )
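+            # (Gradient checkpointing recomputes this layer's activations in
+            # the backward pass to reduce activation memory.)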
+
+            # Concatenate the protein-ligand edge features with the ligand
+            # hidden node features (effectively treating the ligand subgraph
+            # node features as protein-ligand edge features).
+            # h_E_protein_to_ligand_cat [B, L, M, 2 * self.hidden_dim] - the
+            # concatenated protein-ligand edge features.
+            h_E_protein_to_ligand_cat = torch.cat(
+                [h_E_protein_to_ligand, h_ligand_subgraph_nodes], -1
+            )
+
+            # h_V_context [B, L, self.hidden_dim] - the updated context node
+            # features. Message passing from the ligand subgraph to the protein.
+            h_V_context = torch.utils.checkpoint.checkpoint(
+                self.protein_ligand_context_encoder_layers[i],
+                h_V_context,
+                h_E_protein_to_ligand_cat,
+                input_features["residue_mask"],
+                input_features["ligand_subgraph_Y_m"],
+                use_reentrant=False,
+            )
+
+        # Final context embedding.
+        h_V_context = self.W_final_context_embed(h_V_context)
+
+        # Update the protein node features with the context via a residual
+        # connection (after applying dropout and layer norm to the context).
+        encoder_features["h_V"] = encoder_features["h_V"] + self.final_context_norm(
+            self.dropout(h_V_context)
+        )
+
+        return encoder_features
+
+    def forward(self, network_input):
+        """
+        Forward pass for the LigandMPNN model. It reuses the parent forward
+        function unchanged and is overridden here only for documentation
+        purposes.
+
+        A NOTE on shapes (in addition to ProteinMPNN):
+        - N = number of ligand atoms
+        - M = self.num_context_atoms (number of ligand atom neighbors)
+
+        Args:
+            network_input (dict): Dictionary containing the input to the
+                network.
+                - input_features (dict): dictionary containing input features
+                    and all necessary information for the model to run, in
+                    addition to the input features for the ProteinMPNN model.
+                    - Y (torch.Tensor): [B, N, 3] - 3D coordinates of the
+                        ligand atoms.
+                    - Y_m (torch.Tensor): [B, N] - mask indicating which ligand
+                        atoms are valid.
+                    - Y_t (torch.Tensor): [B, N] - element types of the ligand
+                        atoms.
+                    - atomize_side_chains (bool): Whether to atomize side
+                        chains of fixed residues.
+        Side Effects:
+            Any changes to input_features noted below are applied in place to
+            the original input_features dictionary.
+        Returns:
+            network_output (dict): Output dictionary containing the requested
+                features based on the input features' "features_to_return" key,
+                in addition to the output features from the ProteinMPNN model.
+                - input_features (dict): The input features from above, with
+                    the following keys added or modified:
+                    - hide_side_chain_mask (torch.Tensor): [B, L] - mask for
+                        hiding side chains, where True is a residue with hidden
+                        side chains, and False is a residue with revealed side
+                        chains.
+                    - Y (torch.Tensor): [B, N, 3] - 3D coordinates of the ligand
+                        atoms with added Gaussian noise.
+                    - Y_pre_noise (torch.Tensor): [B, N, 3] -
+                        3D coordinates of the ligand atoms before adding
+                        Gaussian noise ('Y' before noise).
+                    - ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - 3D
+                        coordinates of nearest ligand/atomized side chain
+                        atoms to the virtual atoms for each residue.
+                    - ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - mask
+                        indicating which nearest ligand/atomized side chain
+                        atoms to the virtual atoms are valid.
+                    - ligand_subgraph_Y_t (torch.Tensor): [B, L, M] -
+                        element types of the nearest ligand/atomized side chain
+                        atoms to the virtual atoms for each residue.
+                - graph_features (dict): The graph features.
+                    - E_protein_to_ligand (torch.Tensor):
+                        [B, L, M, num_node_output_features] - protein to
+                        ligand subgraph edges; can also be considered node
+                        features of the protein residues (although they are not
+                        used as such).
+                    - ligand_subgraph_nodes (torch.Tensor):
+                        [B, L, M, num_node_output_features] - ligand atom type
+                        information, embedded as node features.
+                    - ligand_subgraph_edges (torch.Tensor):
+                        [B, L, M, M, num_edge_output_features] - embedded and
+                        normalized radial basis function embedding of the
+                        distances between the ligand atoms in each residue
+                        subgraph.
+        """
+        return super().forward(network_input)
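+    # Illustrative usage sketch (assumes an already-constructed LigandMPNN
+    # `model` and prepared tensors; shapes are hypothetical and keys follow
+    # the docstring above):
+    #
+    #     network_input = {
+    #         "input_features": {
+    #             # ... ProteinMPNN inputs (S, coordinates, residue_mask, ...)
+    #             "Y": Y,      # [B, N, 3] ligand atom coordinates
+    #             "Y_m": Y_m,  # [B, N] valid-ligand-atom mask
+    #             "Y_t": Y_t,  # [B, N] ligand atom element types
+    #             "atomize_side_chains": True,
+    #         }
+    #     }
+    #     network_output = model(network_input)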