rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,2372 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from atomworks.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER
4
+ from mpnn.model.layers.message_passing import gather_edges, gather_nodes
5
+ from mpnn.model.layers.positional_encoding import PositionalEncodings
6
+ from mpnn.transforms.feature_aggregation.token_encodings import MPNN_TOKEN_ENCODING
7
+
8
+
9
+ class ProteinFeatures(nn.Module):
10
+ TOKEN_ENCODING = MPNN_TOKEN_ENCODING
11
+ BACKBONE_ATOM_NAMES = ["N", "CA", "C", "O"]
12
+
13
+ REPRESENTATIVE_ATOM_NAMES = ["CA"]
14
+
15
+ DATA_TO_CALCULATE_VIRTUAL_ATOMS = [
16
+ (
17
+ {"center_atom": "CA", "atom_1": "N", "atom_2": "C"},
18
+ {
19
+ "weight_normal": 0.58273431,
20
+ "weight_bond_1": -0.56802827,
21
+ "weight_bond_2": -0.54067466,
22
+ },
23
+ )
24
+ ]
25
+
26
def __init__(
    self,
    num_edge_output_features=128,
    num_node_output_features=128,
    num_positional_embeddings=16,
    min_rbf_mean=2.0,
    max_rbf_mean=22.0,
    num_rbf=16,
    num_neighbors=48,
):
    """
    Set up the layers that turn a protein structure into graph edge
    features.

    Args:
        num_edge_output_features (int): Width of the embedded edge
            features produced by the edge embedding / norm layers.
        num_node_output_features (int): Width of the node features. Stored
            on the module; not used inside this method.
        num_positional_embeddings (int): Width of the relative positional
            encoding concatenated to every edge.
        min_rbf_mean (float): Center of the first radial basis function.
        max_rbf_mean (float): Center of the last radial basis function.
        num_rbf (int): Number of radial basis functions per atom pair.
        num_neighbors (int): Number of nearest neighbors (K) requested per
            residue.
    """
    super(ProteinFeatures, self).__init__()

    # Hyperparameters.
    self.num_edge_output_features = num_edge_output_features
    self.num_node_output_features = num_node_output_features
    self.num_neighbors = num_neighbors
    self.min_rbf_mean = min_rbf_mean
    self.max_rbf_mean = max_rbf_mean
    self.num_rbf = num_rbf
    self.num_positional_embeddings = num_positional_embeddings

    # Atoms per residue that take part in the pairwise RBF features: every
    # backbone atom plus one virtual atom per recipe in
    # DATA_TO_CALCULATE_VIRTUAL_ATOMS.
    self.num_backbone_atoms = len(self.BACKBONE_ATOM_NAMES)
    self.num_virtual_atoms = len(self.DATA_TO_CALCULATE_VIRTUAL_ATOMS)
    atoms_per_residue = self.num_backbone_atoms + self.num_virtual_atoms

    # Raw edge width: positional encoding plus one RBF vector for every
    # (own atom, neighbor atom) pair.
    self.num_edge_input_features = (
        num_positional_embeddings + num_rbf * atoms_per_residue**2
    )

    # Layers (created in a fixed order so parameter registration and
    # initialization order stay stable).
    self.positional_embedding = PositionalEncodings(self.num_positional_embeddings)
    self.edge_embedding = nn.Linear(
        self.num_edge_input_features, self.num_edge_output_features, bias=False
    )
    self.edge_norm = nn.LayerNorm(self.num_edge_output_features)
78
+
79
def construct_X_atoms(self, X, X_m, S, atom_names):
    """
    Select the requested atoms of every residue from the per-token atom
    arrays.

    For each residue, its token index in S and each name in atom_names are
    mapped through self.TOKEN_ENCODING.atom_to_idx to the slot of the
    per-token atom axis that is read.

    Args:
        X (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token, 3] -
            3D coordinates of polymer atoms.
        X_m (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token] -
            Mask indicating which polymer atoms are valid.
        S (torch.Tensor): [B, L] - the sequence of residues (token indices).
        atom_names (list): List of atom names to extract from the
            coordinates.
    Returns:
        X_atoms (torch.Tensor): [B, L, len(atom_names), 3] - 3D coordinates
            of the requested atoms for each residue.
        X_m_atoms (torch.Tensor): [B, L, len(atom_names)] - mask indicating
            which requested atoms are valid for each residue.
    Raises:
        KeyError: If self.TOKEN_ENCODING.atom_to_idx has no entry for some
            (token name, atom name) pair.
    """
    B, L, _, _ = X.shape
    encoding = self.TOKEN_ENCODING
    n_names = len(atom_names)

    # token_and_atom_name_to_atom_idx [n_tokens, len(atom_names)] - maps each
    # token/atom-name pair to the atom slot. It is built as a Python list and
    # copied to the device in one call: filling a device tensor element by
    # element issues one small write per entry on every call. Token indices
    # not present in token_to_idx keep slot 0, exactly as a zero-initialized
    # table would.
    lookup = [[0] * n_names for _ in range(encoding.n_tokens)]
    for token_name, token_idx in encoding.token_to_idx.items():
        lookup[token_idx] = [
            encoding.atom_to_idx[(token_name, atom_name)] for atom_name in atom_names
        ]
    # reshape keeps the [n_tokens, len(atom_names)] shape even when either
    # dimension is zero.
    token_and_atom_name_to_atom_idx = torch.tensor(
        lookup, dtype=torch.int64, device=X.device
    ).reshape(encoding.n_tokens, n_names)

    # batch_idx [B, 1, 1] and position_idx [1, L, 1] broadcast against
    # atom_indices [B, L, len(atom_names)] in the advanced indexing below.
    batch_idx = torch.arange(B, dtype=torch.int64, device=X.device).view(B, 1, 1)
    position_idx = torch.arange(L, dtype=torch.int64, device=X.device).view(1, L, 1)

    # atom_indices [B, L, len(atom_names)] - atom slot per residue and name.
    atom_indices = token_and_atom_name_to_atom_idx[S]

    X_atoms = X[batch_idx, position_idx, atom_indices]
    X_m_atoms = X_m[batch_idx, position_idx, atom_indices]

    return X_atoms, X_m_atoms
134
+
135
def construct_X_rep_atoms(self, X, X_m, S):
    """
    Select the representative atom(s) of every residue.

    Args:
        X (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token, 3] -
            3D coordinates of polymer atoms.
        X_m (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token] -
            Mask indicating which polymer atoms are valid.
        S (torch.Tensor): [B, L] - the sequence of residues.
    Returns:
        X_rep_atoms (torch.Tensor):
            [B, L, len(self.REPRESENTATIVE_ATOM_NAMES), 3] - coordinates of
            the representative atoms.
        X_m_rep_atoms (torch.Tensor):
            [B, L, len(self.REPRESENTATIVE_ATOM_NAMES)] - validity mask of
            the representative atoms.
    Raises:
        ValueError: If any residue has more than one valid representative
            atom.
    """
    representative_names = self.REPRESENTATIVE_ATOM_NAMES
    X_rep_atoms, X_m_rep_atoms = self.construct_X_atoms(
        X, X_m, S, representative_names
    )

    # The representative-atom axis is later collapsed by summation, which is
    # only meaningful when at most one representative atom is valid per
    # residue.
    valid_count = X_m_rep_atoms.sum(dim=-1)
    if (valid_count > 1).any():
        raise ValueError("Each residue should have only one representative atom.")

    return X_rep_atoms, X_m_rep_atoms
164
+
165
def construct_X_backbone(self, X, X_m, S):
    """
    Select the backbone atoms (self.BACKBONE_ATOM_NAMES) of every residue.

    Args:
        X (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token, 3] -
            3D coordinates of polymer atoms.
        X_m (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token] -
            Mask indicating which polymer atoms are valid.
        S (torch.Tensor): [B, L] - the sequence of residues.
    Returns:
        X_backbone (torch.Tensor): [B, L, len(self.BACKBONE_ATOM_NAMES), 3]
            - backbone coordinates, in BACKBONE_ATOM_NAMES order.
        X_m_backbone (torch.Tensor): [B, L, len(self.BACKBONE_ATOM_NAMES)] -
            backbone validity mask.
    """
    coords, mask = self.construct_X_atoms(X, X_m, S, self.BACKBONE_ATOM_NAMES)
    return coords, mask
187
+
188
def construct_X_virtual_atom(
    self,
    X_center_atom,
    X_atom_1,
    X_atom_2,
    weight_normal,
    weight_bond_1,
    weight_bond_2,
):
    """
    Place a virtual atom from a center atom and two other atoms.

    With b1 = X_atom_1 - X_center_atom and b2 = X_atom_2 - X_center_atom,
    the virtual atom is

        X_center_atom + weight_normal * (b1 x b2)
                      + weight_bond_1 * b1 + weight_bond_2 * b2.

    Args:
        X_center_atom (torch.Tensor): [B, L, 3] - center atom coordinates.
        X_atom_1 (torch.Tensor): [B, L, 3] - first atom coordinates.
        X_atom_2 (torch.Tensor): [B, L, 3] - second atom coordinates.
        weight_normal (float): Weight of the cross-product (normal) vector.
        weight_bond_1 (float): Weight of the first bond vector.
        weight_bond_2 (float): Weight of the second bond vector.
    Returns:
        torch.Tensor: [B, L, 3] - coordinates of the virtual atom.
    """
    # Bond vectors pointing away from the center atom.
    b1 = X_atom_1 - X_center_atom
    b2 = X_atom_2 - X_center_atom
    # Normal to the plane spanned by the two bonds.
    n = torch.cross(b1, b2, dim=-1)
    return weight_normal * n + weight_bond_1 * b1 + weight_bond_2 * b2 + X_center_atom
232
+
233
def construct_X_virtual_atoms(self, X, X_m, S):
    """
    Build one virtual atom per recipe in self.DATA_TO_CALCULATE_VIRTUAL_ATOMS.

    Each recipe is a (atom_info, weights) pair. atom_info names the center
    atom and two other atoms, and weights holds weight_normal, weight_bond_1
    and weight_bond_2 for self.construct_X_virtual_atom. A virtual atom is
    valid only where all three of its defining atoms are valid.

    Args:
        X (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token, 3] -
            3D coordinates of polymer atoms.
        X_m (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token] -
            Mask indicating which polymer atoms are valid.
        S (torch.Tensor): [B, L] - the sequence of residues.
    Returns:
        X_virtual_atoms (torch.Tensor):
            [B, L, len(self.DATA_TO_CALCULATE_VIRTUAL_ATOMS), 3] - virtual
            atom coordinates.
        X_m_virtual_atoms (torch.Tensor):
            [B, L, len(self.DATA_TO_CALCULATE_VIRTUAL_ATOMS)] - virtual atom
            validity mask.
    Raises:
        ValueError: If the mask sums to more than one for any residue.
    """
    coords_per_recipe = []
    masks_per_recipe = []
    for atom_info, weights in self.DATA_TO_CALCULATE_VIRTUAL_ATOMS:
        names = [atom_info["center_atom"], atom_info["atom_1"], atom_info["atom_2"]]

        # Fetch all three defining atoms with a single gather.
        X_defining, X_m_defining = self.construct_X_atoms(X, X_m, S, names)
        i_center, i_1, i_2 = (names.index(name) for name in names)

        coords_per_recipe.append(
            self.construct_X_virtual_atom(
                X_defining[:, :, i_center, :],
                X_defining[:, :, i_1, :],
                X_defining[:, :, i_2, :],
                weight_normal=weights["weight_normal"],
                weight_bond_1=weights["weight_bond_1"],
                weight_bond_2=weights["weight_bond_2"],
            )
        )
        masks_per_recipe.append(
            X_m_defining[:, :, i_center]
            * X_m_defining[:, :, i_1]
            * X_m_defining[:, :, i_2]
        )

    X_virtual_atoms = torch.stack(coords_per_recipe, dim=2)
    X_m_virtual_atoms = torch.stack(masks_per_recipe, dim=2)

    # Guard that at most one virtual atom is valid per residue.
    if (X_m_virtual_atoms.sum(dim=-1) > 1).any():
        raise ValueError("Each residue should have only one virtual atom.")

    return X_virtual_atoms, X_m_virtual_atoms
308
+
309
def compute_representative_atom_pairwise_distances(
    self, X_rep_atoms, X_m_rep_atoms, residue_mask, eps=1e-6
):
    """
    Find each residue's nearest neighbors by representative-atom distance.

    Pairs involving an invalid residue are assigned the largest distance in
    their row, so they sort no earlier than the farthest valid neighbor.

    Args:
        X_rep_atoms (torch.Tensor):
            [B, L, len(self.REPRESENTATIVE_ATOM_NAMES), 3] - coordinates of
            the representative atoms for each residue.
        X_m_rep_atoms (torch.Tensor):
            [B, L, len(self.REPRESENTATIVE_ATOM_NAMES)] - mask indicating
            which representative atoms are valid.
        residue_mask (torch.Tensor): [B, L] - mask indicating which residues
            are valid.
        eps (float): Small value added to every squared distance before the
            square root.
    Returns:
        D_rep_neighbors (torch.Tensor): [B, L, K] - the K smallest adjusted
            distances per residue, in ascending order, where
            K = min(self.num_neighbors, L).
        E_idx (torch.Tensor): [B, L, K] - residue indices of those
            neighbors.
    """
    _, num_residues, _, _ = X_rep_atoms.shape

    # pair_mask [B, L, L] - True only where both residues are valid.
    pair_mask = (residue_mask[:, None, :] * residue_mask[:, :, None]).bool()

    # X_rep [B, L, 3] - the representative-atom axis is summed away; masked
    # atoms contribute zero. construct_X_rep_atoms allows at most one valid
    # representative atom per residue, which makes this sum a selection.
    X_rep = (X_rep_atoms * X_m_rep_atoms[:, :, :, None]).sum(dim=2)

    # diff [B, L, L, 3] - pairwise displacement vectors.
    diff = X_rep[:, None, :, :] - X_rep[:, :, None, :]

    # dist [B, L, L] - eps-regularized Euclidean distances, zeroed on masked
    # pairs.
    dist = pair_mask * torch.sqrt((diff**2).sum(dim=3) + eps)

    # row_max [B, L, 1] - largest distance in each residue's row.
    row_max = dist.max(dim=-1, keepdim=True).values

    # Move masked pairs to the row maximum before taking the K smallest.
    dist_adjusted = dist + (~pair_mask) * row_max

    k = min(self.num_neighbors, num_residues)
    D_rep_neighbors, E_idx = torch.topk(dist_adjusted, k, dim=-1, largest=False)

    return D_rep_neighbors, E_idx
380
+
381
def compute_rbf_embedding_from_distances(self, D):
    """
    Embed distances with a bank of Gaussian radial basis functions.

    The centers are self.num_rbf evenly spaced values from
    self.min_rbf_mean to self.max_rbf_mean, and every Gaussian has width
    (max_rbf_mean - min_rbf_mean) / num_rbf:

        rbf_embedding[..., r] = exp(-((D - mu_r) / sigma) ** 2)

    Args:
        D (torch.Tensor): Distances of any shape, e.g. [B, L, K] or
            [B, L, K, num_atoms, num_atoms].
    Returns:
        rbf_embedding (torch.Tensor): D.shape + (num_rbf,) - radial basis
            function values at each distance.
    """
    # rbf_mus [num_rbf] - kept 1-D so that it broadcasts against
    # D[..., None] for inputs of any rank and the output shape is exactly
    # D.shape + (num_rbf,). A fixed [1, 1, 1, num_rbf] view would prepend
    # spurious singleton dimensions whenever D has fewer than three dims.
    rbf_mus = torch.linspace(
        self.min_rbf_mean, self.max_rbf_mean, self.num_rbf, device=D.device
    )

    # Shared standard deviation of the radial basis functions.
    rbf_sigma = (self.max_rbf_mean - self.min_rbf_mean) / self.num_rbf

    # D_expand [..., 1] - trailing axis broadcasts against the centers.
    D_expand = torch.unsqueeze(D, -1)

    rbf_embedding = torch.exp(-(((D_expand - rbf_mus) / rbf_sigma) ** 2))

    return rbf_embedding
412
+
413
def compute_pairwise_residue_rbf_encoding(self, X, E_idx, X_m, eps=1e-6):
    """
    RBF-encode every atom-to-atom distance between a residue and each of
    its K neighbors.

    NOTE: num_atoms = self.num_backbone_atoms + self.num_virtual_atoms

    Args:
        X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of
            polymer atoms.
        E_idx (torch.Tensor): [B, L, K] - Indices of the top K neighbors.
        X_m (torch.Tensor): [B, L, num_atoms] - mask indicating which
            polymer atoms are valid.
        eps (float): Small value added to every squared distance before the
            square root.
    Returns:
        RBF_all (torch.Tensor): [B, L, K, num_atoms * num_atoms * num_rbf] -
            RBF features for each (own atom, neighbor atom) pair. When any
            atom is masked, features involving a masked atom are zeroed.
    """
    batch_size, num_residues, num_atoms = X.shape[0], X.shape[1], X.shape[2]
    num_neighbors = E_idx.shape[2]

    # Gather the neighbors' atoms by temporarily flattening (atom, xyz)
    # into a single feature axis: X_neighbors [B, L, K, num_atoms, 3].
    X_neighbors = gather_nodes(X.reshape(batch_size, num_residues, -1), E_idx)
    X_neighbors = X_neighbors.reshape(
        batch_size, num_residues, num_neighbors, num_atoms, 3
    )

    # displacement [B, L, K, num_atoms, num_atoms, 3] - own atom minus
    # neighbor atom; D is its eps-regularized norm.
    displacement = X[:, :, None, :, None, :] - X_neighbors[:, :, :, None, :, :]
    D = torch.sqrt((displacement**2).sum(dim=-1) + eps)

    # rbf [B, L, K, num_atoms, num_atoms, num_rbf]
    rbf = self.compute_rbf_embedding_from_distances(D)

    # Masking is skipped when every atom is present.
    has_missing_atoms = not torch.all(X_m == 1)
    if has_missing_atoms:
        # X_m_neighbors [B, L, K, num_atoms]
        X_m_neighbors = gather_nodes(X_m, E_idx)
        rbf = (
            rbf
            * X_m[:, :, None, :, None, None]
            * X_m_neighbors[:, :, :, None, :, None]
        )

    return rbf.view(batch_size, num_residues, num_neighbors, -1)
480
+
481
def compute_pairwise_positional_encoding(self, R_idx, E_idx, chain_labels):
    """
    Encode the relative sequence position and chain identity of each
    residue's K neighbors.

    Args:
        R_idx (torch.Tensor): [B, L] - indices of the residues.
        E_idx (torch.Tensor): [B, L, K] - indices of the top K neighbors.
        chain_labels (torch.Tensor): [B, L] - chain labels for each residue.
    Returns:
        positional_encoding (torch.Tensor):
            [B, L, K, num_positional_embeddings] - output of
            self.positional_embedding for every neighbor.
    """

    def _neighbor_entries(pairwise):
        # [B, L, L] -> [B, L, K]: keep only the columns named by E_idx.
        return gather_edges(pairwise[:, :, :, None], E_idx)[:, :, :, 0]

    # Signed residue-index offset R_idx[i] - R_idx[j], as int64.
    offset_neighbors = _neighbor_entries(
        (R_idx[:, :, None] - R_idx[:, None, :]).long()
    )

    # True where residue i and residue j carry the same chain label.
    same_chain_neighbors = _neighbor_entries(
        (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0
    )

    return self.positional_embedding(offset_neighbors, same_chain_neighbors)
523
+
524
+ def featurize_edges(self, input_features):
525
+ """
526
+ Given input features, construct the edge features for the protein.
527
+
528
+ Args:
529
+ input_features (dict): Dictionary containing the input features.
530
+ - residue_mask (torch.Tensor): [B, L] - Mask indicating which
531
+ residues are valid.
532
+ - R_idx (torch.Tensor): [B, L] - Indices of the residues.
533
+ - chain_labels (torch.Tensor): [B, L] - Chain labels for each
534
+ residue.
535
+ - X_backbone (torch.Tensor):
536
+ [B, L, len(self.BACKBONE_ATOM_NAMES), 3] - 3D
537
+ coordinates of the backbone atoms for each residue.
538
+ - X_m_backbone (torch.Tensor):
539
+ [B, L, len(self.BACKBONE_ATOM_NAMES)] - mask
540
+ indicating which backbone atoms are valid.
541
+ - X_virtual_atoms (torch.Tensor):
542
+ [B, L, len(self.DATA_TO_CALCULATE_VIRTUAL_ATOMS), 3] -
543
+ 3D coordinates of the virtual atoms for each residue.
544
+ - X_m_virtual_atoms (torch.Tensor):
545
+ [B, L, len(self.DATA_TO_CALCULATE_VIRTUAL_ATOMS)] -
546
+ mask indicating which virtual atoms are valid.
547
+ - X_rep_atoms (torch.Tensor):
548
+ [B, L, len(self.REPRESENTATIVE_ATOM_NAMES), 3] - 3D
549
+ coordinates of the representative atoms for each residue.
550
+ - X_m_rep_atoms (torch.Tensor):
551
+ [B, L, len(self.REPRESENTATIVE_ATOM_NAMES)] - mask
552
+ indicating which representative atoms are valid.
553
+ Returns:
554
+ edge_features (dict): Dictionary containing the edge features.
555
+ - E_idx (torch.Tensor): [B, L, K] - Indices of the top K
556
+ neighbors.
557
+ - E (torch.Tensor): [B, L, K, num_edge_output_features] -
558
+ Edge features for each pair of neighbors.
559
+ """
560
+ # The following features should come from data loading.
561
+ if "residue_mask" not in input_features:
562
+ raise ValueError("Input features must contain 'residue_mask' key.")
563
+ if "R_idx" not in input_features:
564
+ raise ValueError("Input features must contain 'R_idx' key.")
565
+ if "chain_labels" not in input_features:
566
+ raise ValueError("Input features must contain 'chain_labels' key.")
567
+
568
+ # The following features should be computed by the forward function.
569
+ if "X_backbone" not in input_features:
570
+ raise ValueError("Input features must contain 'X_backbone' key.")
571
+ if "X_m_backbone" not in input_features:
572
+ raise ValueError("Input features must contain 'X_m_backbone' key.")
573
+ if "X_virtual_atoms" not in input_features:
574
+ raise ValueError("Input features must contain 'X_virtual_atoms' key.")
575
+ if "X_m_virtual_atoms" not in input_features:
576
+ raise ValueError("Input features must contain 'X_m_virtual_atoms' key.")
577
+ if "X_rep_atoms" not in input_features:
578
+ raise ValueError("Input features must contain 'X_rep_atoms' key.")
579
+ if "X_m_rep_atoms" not in input_features:
580
+ raise ValueError("Input features must contain 'X_m_rep_atoms' key.")
581
+
582
+ # Compute the pairwise distances between the representative atoms.
583
+ # D_rep_neighbors [B, L, K] - pairwise distances between each residue's
584
+ # representative atom, masked by the 2D mask.
585
+ # E_idx [B, L, K] - indices of the top K neighbors.
586
+ D_rep_neighbors, E_idx = self.compute_representative_atom_pairwise_distances(
587
+ input_features["X_rep_atoms"],
588
+ input_features["X_m_rep_atoms"],
589
+ input_features["residue_mask"],
590
+ )
591
+
592
+ # Concatenate the backbone and virtual atom coordinates.
593
+ # X_backbone_with_virtual_atoms [B, L, num_atoms + num_virtual_atoms, 3]
594
+ # - 3D coordinates of the backbone and virtual atoms for each residue.
595
+ # X_m_backbone_with_virtual_atoms [B, L, num_atoms + num_virtual_atoms]
596
+ # - mask indicating which backbone and virtual atoms are valid.
597
+ X_backbone_with_virtual_atoms = torch.cat(
598
+ (input_features["X_backbone"], input_features["X_virtual_atoms"]), dim=-2
599
+ )
600
+ X_m_backbone_with_virtual_atoms = torch.cat(
601
+ (input_features["X_m_backbone"], input_features["X_m_virtual_atoms"]),
602
+ dim=-1,
603
+ )
604
+
605
+ # Compute the RBF features for the atomwise distances for each pair of
606
+ # neighbors.
607
+ # RBF_all [B, L, K, num_atoms * num_atoms * num_rbf] - radial basis
608
+ # function embedding of the pairwise atomic distances for each pair of
609
+ # residue neighbors.
610
+ RBF_all = self.compute_pairwise_residue_rbf_encoding(
611
+ X_backbone_with_virtual_atoms, E_idx, X_m_backbone_with_virtual_atoms
612
+ )
613
+
614
+ # Compute the positional encoding for the top K neighbors.
615
+ # positional_encoding [B, L, K, num_positional_embeddings] - the
616
+ # positional encoding of the top K neighbors.
617
+ positional_encoding = self.compute_pairwise_positional_encoding(
618
+ input_features["R_idx"], E_idx, input_features["chain_labels"]
619
+ )
620
+
621
+ # Concatenate the positional encoding and the RBF features.
622
+ # E [B, L, K, num_positional_embeddings + num_atoms * num_atoms *
623
+ # num_rbf] - the edge features for each pair of neighbors.
624
+ E_raw = torch.cat((positional_encoding, RBF_all), dim=-1)
625
+
626
+ # Embed and normalize the edge features.
627
+ # E [B, L, K, num_edge_output_features] - the edge features for each
628
+ # pair of neighbors.
629
+ E = self.edge_embedding(E_raw)
630
+ E = self.edge_norm(E)
631
+
632
+ edge_features = {
633
+ "E_idx": E_idx,
634
+ "E": E,
635
+ }
636
+
637
+ return edge_features
638
+
639
+ def featurize_nodes(self, input_features, edge_features):
640
+ """
641
+ The default ProteinMPNN does not have any node features.
642
+
643
+ Args:
644
+ input_features (dict): Dictionary containing the input features.
645
+ edge_features (dict): Dictionary containing the edge features.
646
+ Returns:
647
+ node_features (dict): Dictionary containing the node features.
648
+ """
649
+ node_features = {}
650
+ return node_features
651
+
652
+ def noise_structure(self, input_features):
653
+ """
654
+ Given input features containing 3D coordinates of atoms, add Gaussian
655
+ noise to the coordinates.
656
+
657
+ Args:
658
+ input_features (dict): Dictionary containing the input features.
659
+ - X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of
660
+ polymer atoms.
661
+ - structure_noise (float): Standard deviation of the
662
+ Gaussian noise to add to the input coordinates, in
663
+ Angstroms.
664
+ Side Effects:
665
+ input_features["X_pre_noise"] (torch.Tensor): [B, L, num_atoms, 3] -
666
+ 3D coordinates of polymer atoms before adding noise.
667
+ input_features["X"] (torch.Tensor): [B, L, num_atoms, 3] - 3D
668
+ coordinates of polymer atoms with added Gaussian noise.
669
+ """
670
+ if "X" not in input_features:
671
+ raise ValueError("Input features must contain 'X' key.")
672
+ if "structure_noise" not in input_features:
673
+ raise ValueError("Input features must contain 'structure_noise' key.")
674
+
675
+ structure_noise = input_features["structure_noise"]
676
+
677
+ # If the noise is non-zero, add Gaussian noise to the input
678
+ # coordinates.
679
+ if structure_noise > 0:
680
+ # Copy the original coordinates before adding noise.
681
+ input_features["X_pre_noise"] = input_features["X"].clone()
682
+
683
+ # Add Gaussian noise to the input coordinates.
684
+ input_features["X"] = input_features[
685
+ "X"
686
+ ] + structure_noise * torch.randn_like(input_features["X"])
687
+ else:
688
+ input_features["X_pre_noise"] = input_features["X"].clone()
689
+
690
+ def forward(self, input_features):
691
+ """
692
+ Given input features, construct the graph features for the protein.
693
+
694
+ Args:
695
+ input_features (dict): Dictionary containing the input features.
696
+ - X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of
697
+ polymer atoms.
698
+ - X_m (torch.Tensor): [B, L, num_atoms] - Mask indicating
699
+ which polymer atoms are valid.
700
+ Returns:
701
+ graph_features (dict): Dictionary containing the graph features.
702
+ Union of edge and node features (see the repsective featurize
703
+ functions).
704
+ """
705
+ if "X" not in input_features:
706
+ raise ValueError("Input features must contain 'X' key.")
707
+ if "X_m" not in input_features:
708
+ raise ValueError("Input features must contain 'X_m' key.")
709
+ if "S" not in input_features:
710
+ raise ValueError("Input features must contain 'S' key.")
711
+
712
+ # Add Gaussian noise to the input coordinates.
713
+ self.noise_structure(input_features)
714
+
715
+ # Get the backbone atoms and mask.
716
+ # X_backbone [B, L, len(self.BACKBONE_ATOM_NAMES), 3] - 3D coordinates
717
+ # of the backbone atoms for each residue.
718
+ # X_m_backbone [B, L, len(self.BACKBONE_ATOM_NAMES)] - mask indicating
719
+ # which backbone atoms are valid.
720
+ X_backbone, X_m_backbone = self.construct_X_backbone(
721
+ input_features["X"], input_features["X_m"], input_features["S"]
722
+ )
723
+ input_features["X_backbone"] = X_backbone
724
+ input_features["X_m_backbone"] = X_m_backbone
725
+
726
+ # Get the virtual atoms and mask.
727
+ # X_virtual_atoms [B, L, len(self.DATA_TO_CALCULATE_VIRTUAL_ATOMS), 3] -
728
+ # 3D coordinates of the virtual atoms for each residue.
729
+ # X_m_virtual_atoms [B, L, len(self.DATA_TO_CALCULATE_VIRTUAL_ATOMS)] -
730
+ # mask indicating which virtual atoms are valid.
731
+ X_virtual_atoms, X_m_virtual_atoms = self.construct_X_virtual_atoms(
732
+ input_features["X"], input_features["X_m"], input_features["S"]
733
+ )
734
+ input_features["X_virtual_atoms"] = X_virtual_atoms
735
+ input_features["X_m_virtual_atoms"] = X_m_virtual_atoms
736
+
737
+ # Get the representative atoms.
738
+ # X_rep_atoms [B, L, len(self.REPRESENTATIVE_ATOM_NAMES), 3] - 3D
739
+ # coordinates of the representative atoms for each residue.
740
+ # X_m_rep_atoms [B, L, len(self.REPRESENTATIVE_ATOM_NAMES)] - mask
741
+ # indicating which representative atoms are valid.
742
+ X_rep_atoms, X_m_rep_atoms = self.construct_X_rep_atoms(
743
+ input_features["X"], input_features["X_m"], input_features["S"]
744
+ )
745
+ input_features["X_rep_atoms"] = X_rep_atoms
746
+ input_features["X_m_rep_atoms"] = X_m_rep_atoms
747
+
748
+ # Featurize the edges.
749
+ edge_features = self.featurize_edges(input_features)
750
+
751
+ # Featurize the nodes.
752
+ # Edge features are sometimes needed for node feature calculation;
753
+ # for instance, for gathering nearest neighbor side chain atoms for
754
+ # per-residue ligand subgraphs in LigandMPNN.
755
+ node_features = self.featurize_nodes(input_features, edge_features)
756
+
757
+ # Construct the graph features.
758
+ graph_features = {**edge_features, **node_features}
759
+
760
+ return graph_features
761
+
762
+
763
+ class ProteinFeaturesMembrane(ProteinFeatures):
764
+ def __init__(self, num_membrane_classes=3, **kwargs):
765
+ """
766
+ Given a protein structure, extract the features for the graph
767
+ representation of the protein. This class is aware of membrane labels.
768
+
769
+ All args are the same as the parents class, except for the following:
770
+ Args:
771
+ num_membrane_classes (int): Number of membrane classes.
772
+ """
773
+ super(ProteinFeaturesMembrane, self).__init__(**kwargs)
774
+ self.num_membrane_classes = num_membrane_classes
775
+
776
+ self.node_embedding = nn.Linear(
777
+ self.num_classes, self.num_node_output_features, bias=False
778
+ )
779
+ self.node_norm = nn.LayerNorm(self.num_node_output_features)
780
+
781
+ def featurize_nodes(self, input_features, edge_features):
782
+ """
783
+ Given input features, construct the node features for the protein.
784
+
785
+ Args:
786
+ input_features (dict): Dictionary containing the input features.
787
+ - membrane_per_residue_labels (torch.Tensor): [B, L] - Class
788
+ labels for each residue.
789
+ edge_features (dict): Dictionary containing the edge features.
790
+ Returns:
791
+ node_features (dict): Dictionary containing the node features.
792
+ - V (torch.Tensor): [B, L, self.num_node_output_features] - Node
793
+ features for each residue.
794
+ """
795
+ if "membrane_per_residue_labels" not in input_features:
796
+ raise ValueError(
797
+ "Input features must contain 'membrane_per_residue_labels' key."
798
+ )
799
+
800
+ # Turn the class labels into one-hot vectors.
801
+ # class_one_hot [B, L, self.num_membrane_classes] - one-hot encoding
802
+ # of the class
803
+ class_one_hot = torch.nn.functional.one_hot(
804
+ input_features["membrane_per_residue_labels"],
805
+ num_classes=self.num_membrane_classes,
806
+ ).float()
807
+
808
+ # Embed and normalize the node features.
809
+ # V [B, L, self.num_node_output_features] - the node features for each
810
+ # residue.
811
+ V = self.node_embedding(class_one_hot)
812
+ V = self.node_norm(V)
813
+
814
+ node_features = {"V": V}
815
+
816
+ return node_features
817
+
818
+
819
+ class ProteinFeaturesPSSM(ProteinFeatures):
820
+ def __init__(self, num_pssm_features=20, **kwargs):
821
+ """
822
+ Given a protein structure, extract the features for the graph
823
+ representation of the protein. This class is aware of PSSM features.
824
+
825
+ All args are the same as the parents class, except for the following:
826
+ Args:
827
+ num_pssm_features (int): Number of PSSM features.
828
+ """
829
+ super(ProteinFeaturesPSSM, self).__init__(**kwargs)
830
+ self.num_pssm_features = num_pssm_features
831
+
832
+ self.node_embedding = nn.Linear(
833
+ self.num_pssm_features, self.num_node_output_features, bias=False
834
+ )
835
+ self.node_norm = nn.LayerNorm(self.num_node_output_features)
836
+
837
+ def featurize_nodes(self, input_features, edge_features):
838
+ """
839
+ Given input features, construct the node features for the protein.
840
+
841
+ Args:
842
+ input_features (dict): Dictionary containing the input features.
843
+ - pssm (torch.Tensor): [B, L, self.num_pssm_features] - PSSM
844
+ features for each residue.
845
+ edge_features (dict): Dictionary containing the edge features.
846
+ Returns:
847
+ node_features (dict): Dictionary containing the node features.
848
+ - V (torch.Tensor): [B, L, self.num_node_output_features] - Node
849
+ features for each residue.
850
+ """
851
+ if "pssm" not in input_features:
852
+ raise ValueError("Input features must contain 'pssm' key.")
853
+
854
+ # Embed and normalize the node features.
855
+ # V [B, L, self.num_node_output_features] - the node features for each
856
+ # residue.
857
+ V = self.node_embedding(input_features["pssm"])
858
+ V = self.node_norm(V)
859
+
860
+ node_features = {"V": V}
861
+
862
+ return node_features
863
+
864
+
865
+ class ProteinFeaturesLigand(ProteinFeatures):
866
+ # Note, CB is excluded from the side chain atoms due to the use of the
867
+ # virtual CB.
868
+ SIDE_CHAIN_ATOM_NAMES = [
869
+ "CG",
870
+ "CG1",
871
+ "CG2",
872
+ "OG",
873
+ "OG1",
874
+ "SG",
875
+ "CD",
876
+ "CD1",
877
+ "CD2",
878
+ "ND1",
879
+ "ND2",
880
+ "OD1",
881
+ "OD2",
882
+ "SD",
883
+ "CE",
884
+ "CE1",
885
+ "CE2",
886
+ "CE3",
887
+ "NE",
888
+ "NE1",
889
+ "NE2",
890
+ "OE1",
891
+ "OE2",
892
+ "CH2",
893
+ "NH1",
894
+ "NH2",
895
+ "OH",
896
+ "CZ",
897
+ "CZ2",
898
+ "CZ3",
899
+ "NZ",
900
+ "OXT",
901
+ ]
902
+
903
+ # Mapping of side chain atom name to element name.
904
+ SIDE_CHAIN_ATOM_NAME_TO_ELEMENT_NAME = {
905
+ "CG": "C",
906
+ "CG1": "C",
907
+ "CG2": "C",
908
+ "OG": "O",
909
+ "OG1": "O",
910
+ "SG": "S",
911
+ "CD": "C",
912
+ "CD1": "C",
913
+ "CD2": "C",
914
+ "ND1": "N",
915
+ "ND2": "N",
916
+ "OD1": "O",
917
+ "OD2": "O",
918
+ "SD": "S",
919
+ "CE": "C",
920
+ "CE1": "C",
921
+ "CE2": "C",
922
+ "CE3": "C",
923
+ "NE": "N",
924
+ "NE1": "N",
925
+ "NE2": "N",
926
+ "OE1": "O",
927
+ "OE2": "O",
928
+ "CH2": "C",
929
+ "NH1": "N",
930
+ "NH2": "N",
931
+ "OH": "O",
932
+ "CZ": "C",
933
+ "CZ2": "C",
934
+ "CZ3": "C",
935
+ "NZ": "N",
936
+ "OXT": "O",
937
+ }
938
+
939
+ def __init__(self, num_neighbors=32, num_context_atoms=25, **kwargs):
940
+ """
941
+ Given a protein structure and ligand structure, extract the features for
942
+ the graph representation of the protein and ligand. This class is aware
943
+ of ligand features.
944
+
945
+ All args are the same as the parents class, except for the following:
946
+ Args:
947
+ num_neighbors (int): Number of neighbors to consider, default
948
+ changed from 48 to 32.
949
+ num_context_atoms (int): Number of ligand plus side chain atoms to
950
+ consider for each polymer residue.
951
+ """
952
+ super(ProteinFeaturesLigand, self).__init__(
953
+ num_neighbors=num_neighbors, **kwargs
954
+ )
955
+ self.num_context_atoms = num_context_atoms
956
+
957
+ # Number of side chain atoms.
958
+ self.num_side_chain_atoms = len(self.SIDE_CHAIN_ATOM_NAMES)
959
+
960
+ # Features for atom type (periodic table features):
961
+ # There is a null group, period, and atomic number.
962
+ self.num_periodic_table_groups = 1 + 18
963
+ self.num_periodic_table_periods = 1 + 7
964
+ self.num_atomic_numbers = 1 + 118
965
+ self.num_atom_type_input_features = (
966
+ self.num_periodic_table_groups
967
+ + self.num_periodic_table_periods
968
+ + self.num_atomic_numbers
969
+ )
970
+
971
+ # Number of nearest neighbors residue to consider for finding atomized
972
+ # side chain atoms.
973
+ self.num_neighbors_for_atomized_side_chain = 16
974
+
975
+ # Max distance for finding nearest ligand atom neighbors.
976
+ self.max_distance_for_ligand_atoms = 10000.0
977
+
978
+ # Projection of the atom type features to the embedding space.
979
+ self.num_atom_type_output_features = 64
980
+
981
+ # Number of angle features.
982
+ self.num_angle_features = 4
983
+
984
+ # Node features (protein-ligand subgraph edge features).
985
+ # 1. RBF features for ligand atom to each backbone atom and virtual atom
986
+ # 2. Atom type features for the ligand atom
987
+ # 3. Angle features for the ligand atom and the backbone atoms
988
+ self.num_node_input_features = (
989
+ (self.num_backbone_atoms + self.num_virtual_atoms) * self.num_rbf
990
+ + self.num_atom_type_output_features
991
+ + self.num_angle_features
992
+ )
993
+
994
+ # Layers for the protein-ligand subgraph edge features.
995
+ self.embed_atom_type_features = nn.Linear(
996
+ self.num_atom_type_input_features,
997
+ self.num_atom_type_output_features,
998
+ bias=True,
999
+ )
1000
+ self.node_embedding = nn.Linear(
1001
+ self.num_node_input_features, self.num_node_output_features, bias=True
1002
+ )
1003
+ self.node_norm = nn.LayerNorm(self.num_node_output_features)
1004
+
1005
+ # Layers for the ligand subgraphs.
1006
+ self.ligand_subgraph_node_embedding = nn.Linear(
1007
+ self.num_atom_type_input_features, self.num_node_output_features, bias=False
1008
+ )
1009
+ self.ligand_subgraph_node_norm = nn.LayerNorm(self.num_node_output_features)
1010
+ self.ligand_subgraph_edge_embedding = nn.Linear(
1011
+ self.num_rbf, self.num_node_output_features, bias=False
1012
+ )
1013
+ self.ligand_subgraph_edge_norm = nn.LayerNorm(self.num_node_output_features)
1014
+
1015
+ # Numeric encoding of the atom type (atomic number for the last 32 atoms
1016
+ # in the 37 atom representation).
1017
+ self.register_buffer(
1018
+ "side_chain_atom_types",
1019
+ torch.tensor(
1020
+ [
1021
+ ELEMENT_NAME_TO_ATOMIC_NUMBER[
1022
+ self.SIDE_CHAIN_ATOM_NAME_TO_ELEMENT_NAME[atom_name]
1023
+ ]
1024
+ for atom_name in self.SIDE_CHAIN_ATOM_NAMES
1025
+ ]
1026
+ ),
1027
+ )
1028
+
1029
+ # Atomic number, period, and group for the periodic table.
1030
+ self.register_buffer(
1031
+ "periodic_table_groups",
1032
+ torch.tensor(
1033
+ [
1034
+ 0,
1035
+ 1,
1036
+ 18,
1037
+ 1,
1038
+ 2,
1039
+ 13,
1040
+ 14,
1041
+ 15,
1042
+ 16,
1043
+ 17,
1044
+ 18,
1045
+ 1,
1046
+ 2,
1047
+ 13,
1048
+ 14,
1049
+ 15,
1050
+ 16,
1051
+ 17,
1052
+ 18,
1053
+ 1,
1054
+ 2,
1055
+ 3,
1056
+ 4,
1057
+ 5,
1058
+ 6,
1059
+ 7,
1060
+ 8,
1061
+ 9,
1062
+ 10,
1063
+ 11,
1064
+ 12,
1065
+ 13,
1066
+ 14,
1067
+ 15,
1068
+ 16,
1069
+ 17,
1070
+ 18,
1071
+ 1,
1072
+ 2,
1073
+ 3,
1074
+ 4,
1075
+ 5,
1076
+ 6,
1077
+ 7,
1078
+ 8,
1079
+ 9,
1080
+ 10,
1081
+ 11,
1082
+ 12,
1083
+ 13,
1084
+ 14,
1085
+ 15,
1086
+ 16,
1087
+ 17,
1088
+ 18,
1089
+ 1,
1090
+ 2,
1091
+ 3,
1092
+ 3,
1093
+ 3,
1094
+ 3,
1095
+ 3,
1096
+ 3,
1097
+ 3,
1098
+ 3,
1099
+ 3,
1100
+ 3,
1101
+ 3,
1102
+ 3,
1103
+ 3,
1104
+ 3,
1105
+ 3,
1106
+ 4,
1107
+ 5,
1108
+ 6,
1109
+ 7,
1110
+ 8,
1111
+ 9,
1112
+ 10,
1113
+ 11,
1114
+ 12,
1115
+ 13,
1116
+ 14,
1117
+ 15,
1118
+ 16,
1119
+ 17,
1120
+ 18,
1121
+ 1,
1122
+ 2,
1123
+ 3,
1124
+ 3,
1125
+ 3,
1126
+ 3,
1127
+ 3,
1128
+ 3,
1129
+ 3,
1130
+ 3,
1131
+ 3,
1132
+ 3,
1133
+ 3,
1134
+ 3,
1135
+ 3,
1136
+ 3,
1137
+ 3,
1138
+ 4,
1139
+ 5,
1140
+ 6,
1141
+ 7,
1142
+ 8,
1143
+ 9,
1144
+ 10,
1145
+ 11,
1146
+ 12,
1147
+ 13,
1148
+ 14,
1149
+ 15,
1150
+ 16,
1151
+ 17,
1152
+ 18,
1153
+ ],
1154
+ dtype=torch.long,
1155
+ ),
1156
+ )
1157
+ self.register_buffer(
1158
+ "periodic_table_periods",
1159
+ torch.tensor(
1160
+ [
1161
+ 0,
1162
+ 1,
1163
+ 1,
1164
+ 2,
1165
+ 2,
1166
+ 2,
1167
+ 2,
1168
+ 2,
1169
+ 2,
1170
+ 2,
1171
+ 2,
1172
+ 3,
1173
+ 3,
1174
+ 3,
1175
+ 3,
1176
+ 3,
1177
+ 3,
1178
+ 3,
1179
+ 3,
1180
+ 4,
1181
+ 4,
1182
+ 4,
1183
+ 4,
1184
+ 4,
1185
+ 4,
1186
+ 4,
1187
+ 4,
1188
+ 4,
1189
+ 4,
1190
+ 4,
1191
+ 4,
1192
+ 4,
1193
+ 4,
1194
+ 4,
1195
+ 4,
1196
+ 4,
1197
+ 4,
1198
+ 5,
1199
+ 5,
1200
+ 5,
1201
+ 5,
1202
+ 5,
1203
+ 5,
1204
+ 5,
1205
+ 5,
1206
+ 5,
1207
+ 5,
1208
+ 5,
1209
+ 5,
1210
+ 5,
1211
+ 5,
1212
+ 5,
1213
+ 5,
1214
+ 5,
1215
+ 5,
1216
+ 6,
1217
+ 6,
1218
+ 6,
1219
+ 6,
1220
+ 6,
1221
+ 6,
1222
+ 6,
1223
+ 6,
1224
+ 6,
1225
+ 6,
1226
+ 6,
1227
+ 6,
1228
+ 6,
1229
+ 6,
1230
+ 6,
1231
+ 6,
1232
+ 6,
1233
+ 6,
1234
+ 6,
1235
+ 6,
1236
+ 6,
1237
+ 6,
1238
+ 6,
1239
+ 6,
1240
+ 6,
1241
+ 6,
1242
+ 6,
1243
+ 6,
1244
+ 6,
1245
+ 6,
1246
+ 6,
1247
+ 6,
1248
+ 7,
1249
+ 7,
1250
+ 7,
1251
+ 7,
1252
+ 7,
1253
+ 7,
1254
+ 7,
1255
+ 7,
1256
+ 7,
1257
+ 7,
1258
+ 7,
1259
+ 7,
1260
+ 7,
1261
+ 7,
1262
+ 7,
1263
+ 7,
1264
+ 7,
1265
+ 7,
1266
+ 7,
1267
+ 7,
1268
+ 7,
1269
+ 7,
1270
+ 7,
1271
+ 7,
1272
+ 7,
1273
+ 7,
1274
+ 7,
1275
+ 7,
1276
+ 7,
1277
+ 7,
1278
+ 7,
1279
+ 7,
1280
+ ],
1281
+ dtype=torch.long,
1282
+ ),
1283
+ )
1284
+
1285
+ def construct_X_side_chain(self, X, X_m, S):
1286
+ """
1287
+ Given the 3D coordinates of the atoms and the mask, construct the
1288
+ side chain atoms and their mask.
1289
+
1290
+ Args:
1291
+ X (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token, 3] -
1292
+ 3D coordinates of polymer atoms.
1293
+ X_m (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token] -
1294
+ Mask indicating which polymer atoms are valid.
1295
+ S (torch.Tensor): [B, L] - Sequence of the polymer residues.
1296
+ Returns:
1297
+ X_side_chain (torch.Tensor):
1298
+ [B, L, len(self.SIDE_CHAIN_ATOM_NAMES), 3] -
1299
+ 3D coordinates of the side chain atoms for each residue.
1300
+ X_m_side_chain (torch.Tensor):
1301
+ [B, L, len(self.SIDE_CHAIN_ATOM_NAMES)] -
1302
+ Mask indicating which side chain atoms are valid.
1303
+ """
1304
+ X_side_chain, X_m_side_chain = self.construct_X_atoms(
1305
+ X, X_m, S, self.SIDE_CHAIN_ATOM_NAMES
1306
+ )
1307
+
1308
+ return X_side_chain, X_m_side_chain
1309
+
1310
+ def construct_angle_features(
1311
+ self, center_atom, atom_1, atom_2, ligand_subgraph_Y, eps=1e-8
1312
+ ):
1313
+ """
1314
+ Given the 3D coordinates of the center atom, the first atom, the second
1315
+ atom, and the ligand atoms, compute the angle features for the ligand
1316
+ atom with respect to the center atom and the two atoms.
1317
+
1318
+ NOTE: M = self.num_context_atoms, the number of ligand atoms in each
1319
+ residue subgraph.
1320
+
1321
+ Args:
1322
+ center_atom (torch.Tensor): [B, L, 3] - 3D coordinates of the center
1323
+ atom.
1324
+ atom_1 (torch.Tensor): [B, L, 3] - 3D coordinates of the first atom.
1325
+ atom_2 (torch.Tensor): [B, L, 3] - 3D coordinates of the second
1326
+ atom.
1327
+ ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - 3D coordinates of
1328
+ the M closest ligand atoms to each residue.
1329
+ eps (float): Small value added to distances that are zero.
1330
+ Returns:
1331
+ angle_features (torch.Tensor):
1332
+ [B, L, M, self.num_angle_features] - Angle features for the
1333
+ ligand atom with respect to the center atom and the two atoms.
1334
+ cos_azimuthal_xy_angle [B, L, M] -
1335
+ Cosine of the azimuthal angle in the local x-y plane.
1336
+ sin_azimuthal_xy_angle [B, L, M] -
1337
+ Sine of the azimuthal angle in the local x-y plane.
1338
+ cos_inclination_angle [B, L, M] -
1339
+ Cosine of the inclination angle (polar angle).
1340
+ sin_inclination_angle [B, L, M] -
1341
+ Sine of the inclination angle (polar angle).
1342
+ """
1343
+ # Compute the bond vectors.
1344
+ # bond_1 [B, L, 3] - vector from the center atom to the first atom.
1345
+ # bond_2 [B, L, 3] - vector from the center atom to the second atom.
1346
+ bond_1 = atom_1 - center_atom
1347
+ bond_2 = atom_2 - center_atom
1348
+
1349
+ # Construct an orthonormal basis from the bond vectors.
1350
+ # The first vector in the basis, the normalized bond_1 vector.
1351
+ # basis_vector_1 [B, L, 3] - normalized bond_1 vector.
1352
+ basis_vector_1 = torch.nn.functional.normalize(bond_1, dim=-1)
1353
+
1354
+ # Project bond_2 onto the first vector in the basis.
1355
+ # length_bond_2_proj [B, L, 1] - length of the projection of bond_2 onto
1356
+ # basis_vector_1.
1357
+ length_bond_2_proj = torch.einsum("bli, bli -> bl", basis_vector_1, bond_2)[
1358
+ ..., None
1359
+ ]
1360
+
1361
+ # bond_2_orthogonal_component [B, L, 3] - component of bond_2 vector
1362
+ # orthogonal to basis_vector_1.
1363
+ bond_2_orthogonal_component = bond_2 - basis_vector_1 * length_bond_2_proj
1364
+
1365
+ # basis_vector_2 [B, L, 3] - normalized bond_2 orthogonal component.
1366
+ basis_vector_2 = torch.nn.functional.normalize(
1367
+ bond_2_orthogonal_component, dim=-1
1368
+ )
1369
+
1370
+ # basis_vector_3 [B, L, 3] - cross product of the first two basis
1371
+ # vectors.
1372
+ basis_vector_3 = torch.cross(basis_vector_1, basis_vector_2, dim=-1)
1373
+
1374
+ # By construction, basis_vector_1, basis_vector_2, and basis_vector_3
1375
+ # form an orthonormal basis. We stack them together to form a
1376
+ # rotation matrix. This rotation matrix can be used to transform from
1377
+ # the local coordinate system defined by the bond vectors to the global
1378
+ # coordinate system:
1379
+ # v_global = R_residue @ v_local + center_atom
1380
+ # R_residue [B, L, 3, 3] - rotation matrix.
1381
+ R_residue = torch.cat(
1382
+ (
1383
+ basis_vector_1[:, :, :, None],
1384
+ basis_vector_2[:, :, :, None],
1385
+ basis_vector_3[:, :, :, None],
1386
+ ),
1387
+ dim=-1,
1388
+ )
1389
+
1390
+ # Compute the local coordinates of the ligand atoms with respect to
1391
+ # the center atom.
1392
+ # ligand_subgraph_Y_local [B, L, M, 3] - local coordinates of the ligand
1393
+ # atoms with respect to the center atom.
1394
+ ligand_subgraph_Y_local = torch.einsum(
1395
+ "blqp, blyq -> blyp",
1396
+ R_residue,
1397
+ ligand_subgraph_Y - center_atom[:, :, None, :],
1398
+ )
1399
+
1400
+ # Compute the length of the local vectors projected onto the local x-y
1401
+ # plane.
1402
+ # ligand_subgraph_Y_proj_xy_local_length [B, L, M] - length of the local
1403
+ # vectors projected on the local x-y plane.
1404
+ ligand_subgraph_Y_proj_xy_local_length = torch.sqrt(
1405
+ ligand_subgraph_Y_local[..., 0] ** 2
1406
+ + ligand_subgraph_Y_local[..., 1] ** 2
1407
+ + eps
1408
+ )
1409
+
1410
+ # Compute the cosine and sine of the azimuthal angle.
1411
+ # cos_azimuthal_xy_angle [B, L, M] - cosine of the azimuthal angle in
1412
+ # the local x-y plane.
1413
+ # sin_azimuthal_xy_angle [B, L, M] - sine of the azimuthal angle in the
1414
+ # local x-y plane.
1415
+ cos_azimuthal_xy_angle = (
1416
+ ligand_subgraph_Y_local[..., 0] / ligand_subgraph_Y_proj_xy_local_length
1417
+ )
1418
+ sin_azimuthal_xy_angle = (
1419
+ ligand_subgraph_Y_local[..., 1] / ligand_subgraph_Y_proj_xy_local_length
1420
+ )
1421
+
1422
+ # Compute the length of the local vectors.
1423
+ # ligand_subgraph_Y_local_length [B, L, M] - length of the local
1424
+ # vectors.
1425
+ ligand_subgraph_Y_local_length = (
1426
+ torch.norm(ligand_subgraph_Y_local, dim=-1) + eps
1427
+ )
1428
+
1429
+ # Compute the cosine and sine of the inclination angle (polar angle).
1430
+ # cos_inclination_angle [B, L, M] - cosine of the inclination angle
1431
+ # (polar angle).
1432
+ # sin_inclination_angle [B, L, M] - sine of the inclination angle
1433
+ # (polar angle).
1434
+ cos_inclination_angle = (
1435
+ ligand_subgraph_Y_proj_xy_local_length / ligand_subgraph_Y_local_length
1436
+ )
1437
+ sin_inclination_angle = (
1438
+ ligand_subgraph_Y_local[..., 2] / ligand_subgraph_Y_local_length
1439
+ )
1440
+
1441
+ # Concatenate the angle features.
1442
+ angle_features = torch.cat(
1443
+ (
1444
+ cos_azimuthal_xy_angle[..., None],
1445
+ sin_azimuthal_xy_angle[..., None],
1446
+ cos_inclination_angle[..., None],
1447
+ sin_inclination_angle[..., None],
1448
+ ),
1449
+ dim=-1,
1450
+ )
1451
+
1452
+ return angle_features
1453
+
1454
+ def gather_nearest_per_residue_atoms(
1455
+ self,
1456
+ per_residue_ligand_coords,
1457
+ per_residue_ligand_mask,
1458
+ per_residue_ligand_types,
1459
+ X_virtual_atoms,
1460
+ X_m_virtual_atoms,
1461
+ residue_mask,
1462
+ ):
1463
+ """
1464
+ Given the 3D coordinates of the ligand atoms, their mask, and the
1465
+ virtual atoms, gather the nearest ligand atoms to the virtual atoms for
1466
+ each residue.
1467
+
1468
+ NOTE:
1469
+ num_ligand_atoms = N
1470
+ when called in self.gather_nearest_ligand_atoms.
1471
+ num_ligand_atoms = M
1472
+ when called in self.combine_ligand_and_side_chain_atoms.
1473
+
1474
+ NOTE: M = self.num_context_atoms, the number of ligand atoms in each
1475
+ residue subgraph.
1476
+
1477
+ Args:
1478
+ per_residue_ligand_coords (torch.Tensor): [B, L, num_ligand_atoms,
1479
+ 3] - per residue 3D coordinates of the ligand atoms.
1480
+ per_residue_ligand_mask (torch.Tensor): [B, L, num_ligand_atoms] -
1481
+ per residue mask indicating which ligand atoms are valid.
1482
+ per_residue_ligand_types (torch.Tensor): [B, L, num_ligand_atoms] -
1483
+ per residue element types of the ligand atoms (atomic numbers).
1484
+ X_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms, 3] - 3D
1485
+ coordinates of the virtual atoms for each residue.
1486
+ X_m_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms] -
1487
+ mask indicating which virtual atoms are valid.
1488
+ residue_mask (torch.Tensor): [B, L] - mask indicating which residues
1489
+ are valid.
1490
+ Returns:
1491
+ ligand_subgraph_Y (torch.Tensor):
1492
+ [B, L, M, 3] - 3D coordinates of the nearest ligand atoms to the
1493
+ virtual atoms for each residue.
1494
+ ligand_subgraph_Y_m (torch.Tensor):
1495
+ [B, L, M] - mask indicating which nearest ligand atoms to the
1496
+ virtual atoms are valid.
1497
+ ligand_subgraph_Y_t (torch.Tensor):
1498
+ [B, L, M] - element types of the nearest ligand atoms to the
1499
+ virtual atoms for each residue.
1500
+ """
1501
+ B, L, num_ligand_atoms, _ = per_residue_ligand_coords.shape
1502
+
1503
+ # X_virtual_atoms_collapsed [B, L, 3] - collapse the virtual atom
1504
+ # dimension.
1505
+ # NOTE: collapsing along this dimension is okay because the
1506
+ # self.construct_X_virtual_atoms function ensures that there is only one
1507
+ # virtual atom per residue.
1508
+ X_virtual_atoms_collapsed = torch.sum(
1509
+ X_virtual_atoms * X_m_virtual_atoms[:, :, :, None], dim=2
1510
+ )
1511
+
1512
+ # ligand_to_virtual_atom_distances [B, L, num_ligand_atoms] -
1513
+ # distance between the ligand atoms and the virtual atoms.
1514
+ ligand_to_virtual_atom_distances = torch.sqrt(
1515
+ torch.sum(
1516
+ (X_virtual_atoms_collapsed[:, :, None, :] - per_residue_ligand_coords)
1517
+ ** 2,
1518
+ dim=-1,
1519
+ )
1520
+ )
1521
+
1522
+ # residue_and_ligand_mask [B, L, num_ligand_atoms] - mask indicating
1523
+ # which residue-ligand atom pairs are valid.
1524
+ residue_and_ligand_mask = (
1525
+ residue_mask[:, :, None] * per_residue_ligand_mask
1526
+ ).bool()
1527
+
1528
+ # ligand_to_virtual_atom_distances_adjusted [B, L, num_ligand_atoms] -
1529
+ # distances between the virtual atoms and the ligand atoms, with
1530
+ # invalid residue-ligand atom pairs adjusted to a maximum distance.
1531
+ ligand_to_virtual_atom_distances_adjusted = (
1532
+ ligand_to_virtual_atom_distances * residue_and_ligand_mask
1533
+ + (~residue_and_ligand_mask) * self.max_distance_for_ligand_atoms
1534
+ )
1535
+
1536
+ # E_idx_ligand_subgraph [B, L, M] - indices of the closest ligand atoms
1537
+ # to the virtual atoms.
1538
+ _, E_idx_ligand_subgraph = torch.topk(
1539
+ ligand_to_virtual_atom_distances_adjusted,
1540
+ min(self.num_context_atoms, num_ligand_atoms),
1541
+ dim=-1,
1542
+ largest=False,
1543
+ )
1544
+
1545
+ # Gather the ligand atom coordinates, mask, and types based on the
1546
+ # indices of the closest ligand atoms to the virtual atoms.
1547
+ # ligand_subgraph_Y [B, L, M, 3] - 3D coordinates of the nearest ligand
1548
+ # atoms to the virtual atoms for each residue.
1549
+ ligand_subgraph_Y = torch.gather(
1550
+ per_residue_ligand_coords,
1551
+ dim=2,
1552
+ index=E_idx_ligand_subgraph[:, :, :, None].expand(-1, -1, -1, 3),
1553
+ )
1554
+
1555
+ # ligand_subgraph_Y_m [B, L, M] - mask indicating which nearest ligand
1556
+ # atoms to the virtual atoms are valid.
1557
+ ligand_subgraph_Y_m = torch.gather(
1558
+ per_residue_ligand_mask, dim=2, index=E_idx_ligand_subgraph
1559
+ )
1560
+
1561
+ # ligand_subgraph_Y_t [B, L, M] - element types of the nearest ligand
1562
+ # atoms to the virtual atoms for each residue.
1563
+ ligand_subgraph_Y_t = torch.gather(
1564
+ per_residue_ligand_types, dim=2, index=E_idx_ligand_subgraph
1565
+ )
1566
+
1567
+ return ligand_subgraph_Y, ligand_subgraph_Y_m, ligand_subgraph_Y_t
1568
+
1569
def gather_nearest_ligand_atoms(
    self, Y, Y_m, Y_t, X_virtual_atoms, X_m_virtual_atoms, residue_mask
):
    """
    Select, for every residue, the ligand atoms closest to that residue's
    virtual atoms.

    The whole ligand of each batch element is broadcast along the residue
    axis and passed to ``self.gather_nearest_per_residue_atoms``, which does
    the actual nearest-atom selection.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph.

    Args:
        Y (torch.Tensor): [B, N, 3] - ligand atom coordinates.
        Y_m (torch.Tensor): [B, N] - mask of valid ligand atoms.
        Y_t (torch.Tensor): [B, N] - ligand element types (atomic numbers).
        X_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms, 3] -
            per-residue virtual atom coordinates.
        X_m_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms] - mask
            of valid virtual atoms.
        residue_mask (torch.Tensor): [B, L] - mask of valid residues.
    Returns:
        tuple:
            ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - coordinates of
                the selected ligand atoms for each residue.
            ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - validity mask of
                the selected ligand atoms.
            ligand_subgraph_Y_t (torch.Tensor): [B, L, M] - element types of
                the selected ligand atoms.
    """
    # Only the residue count is needed; unpacking all four dims keeps the
    # implicit check that X_virtual_atoms is 4-D.
    _, num_residues, _, _ = X_virtual_atoms.shape

    # Broadcast (as views, via expand) the per-example ligand to every
    # residue so each residue sees the full set of candidate atoms.
    Y_per_residue = Y[:, None, :, :].expand(-1, num_residues, -1, -1)
    Y_m_per_residue = Y_m[:, None, :].expand(-1, num_residues, -1)
    Y_t_per_residue = Y_t[:, None, :].expand(-1, num_residues, -1)

    nearest = self.gather_nearest_per_residue_atoms(
        Y_per_residue,
        Y_m_per_residue,
        Y_t_per_residue,
        X_virtual_atoms,
        X_m_virtual_atoms,
        residue_mask,
    )
    ligand_subgraph_Y, ligand_subgraph_Y_m, ligand_subgraph_Y_t = nearest

    return ligand_subgraph_Y, ligand_subgraph_Y_m, ligand_subgraph_Y_t
1625
+
1626
def gather_nearest_atomized_side_chain_atoms(
    self, X, X_m, S, E_idx, hide_side_chain_mask
):
    """
    Collect the side chain atoms of each residue's nearest neighbors so they
    can be used as extra ("atomized") ligand-like context atoms.

    Only the first ``self.num_neighbors_for_atomized_side_chain`` neighbors
    in ``E_idx`` are used. Atoms belonging to a neighbor whose side chain is
    hidden are masked out.

    Args:
        X (torch.Tensor): [B, L, num_atoms, 3] - polymer atom coordinates.
        X_m (torch.Tensor): [B, L, num_atoms] - mask of valid polymer atoms.
        S (torch.Tensor): [B, L] - polymer residue sequence.
        E_idx (torch.Tensor): [B, L, K] - indices of the K nearest neighbors
            of each residue.
        hide_side_chain_mask (torch.Tensor): [B, L] - True where a residue's
            side chain is hidden, False where it is revealed.
    Returns:
        tuple, with P = num_neighbors_for_atomized_side_chain *
        num_side_chain_atoms:
            ligand_subgraph_R (torch.Tensor): [B, L, P, 3] - coordinates of
                the neighbors' side chain atoms.
            ligand_subgraph_R_m (torch.Tensor): [B, L, P] - True for side
                chain atoms that are valid and belong to a revealed side
                chain.
            ligand_subgraph_R_t (torch.Tensor): [B, L, P] - element types of
                those side chain atoms.
    """
    B, L, _, _ = X.shape

    # Side chain atoms per residue. These exclude CB, because the virtual CB
    # atom is used in other places.
    X_side_chain, X_m_side_chain = self.construct_X_side_chain(X, X_m, S)

    n_neighbors = self.num_neighbors_for_atomized_side_chain
    n_sc_atoms = self.num_side_chain_atoms
    n_context = n_neighbors * n_sc_atoms

    # Keep only the closest neighbors.
    neighbor_idx = E_idx[:, :, :n_neighbors]

    # Flatten each residue's side chain coordinates into one vector so that
    # gather_nodes moves whole residues, then unflatten into atoms.
    flat_side_chain = X_side_chain.view(B, L, n_sc_atoms * 3)
    ligand_subgraph_R = gather_nodes(flat_side_chain, neighbor_idx).view(
        B, L, n_context, 3
    )

    # An atomized atom is usable only if it exists and its residue's side
    # chain is revealed (hide_side_chain_mask is False).
    revealed = ~(hide_side_chain_mask[:, :, None].bool())
    ligand_subgraph_R_m = gather_nodes(
        X_m_side_chain & revealed, neighbor_idx
    ).view(B, L, n_context)

    # Every neighbor slot has the same per-atom element type pattern, so the
    # types are tiled rather than gathered.
    ligand_subgraph_R_t = (
        self.side_chain_atom_types[None, None, None, :]
        .expand(B, L, n_neighbors, -1)
        .reshape(B, L, n_context)
    )

    return ligand_subgraph_R, ligand_subgraph_R_m, ligand_subgraph_R_t
1712
+
1713
def combine_ligand_and_atomized_side_chain_atoms(
    self,
    ligand_subgraph_Y,
    ligand_subgraph_Y_m,
    ligand_subgraph_Y_t,
    ligand_subgraph_R,
    ligand_subgraph_R_m,
    ligand_subgraph_R_t,
    X_virtual_atoms,
    X_m_virtual_atoms,
    residue_mask,
):
    """
    Merge per-residue ligand candidates with atomized side chain candidates
    and keep the ones nearest to each residue's virtual atoms.

    NOTE: M = self.num_context_atoms, the number of atoms in each residue
    subgraph. P = self.num_neighbors_for_atomized_side_chain *
    self.num_side_chain_atoms.

    Args:
        ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - ligand candidate
            coordinates.
        ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - ligand candidate
            validity mask.
        ligand_subgraph_Y_t (torch.Tensor): [B, L, M] - ligand candidate
            element types.
        ligand_subgraph_R (torch.Tensor): [B, L, P, 3] - side chain
            candidate coordinates.
        ligand_subgraph_R_m (torch.Tensor): [B, L, P] - side chain candidate
            validity mask.
        ligand_subgraph_R_t (torch.Tensor): [B, L, P] - side chain candidate
            element types.
        X_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms, 3] -
            per-residue virtual atom coordinates.
        X_m_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms] - mask
            of valid virtual atoms.
        residue_mask (torch.Tensor): [B, L] - mask of valid residues.
    Returns:
        tuple:
            ligand_subgraph_Y_and_R (torch.Tensor): [B, L, M, 3] - selected
                coordinates.
            ligand_subgraph_Y_m_and_R_m (torch.Tensor): [B, L, M] - selected
                validity mask.
            ligand_subgraph_Y_t_and_R_t (torch.Tensor): [B, L, M] - selected
                element types.
    """
    # Append the side chain candidates after the ligand candidates along the
    # atom axis (dim 2), giving M + P candidates per residue for each of
    # coordinates, mask, and types.
    combined_Y, combined_Y_m, combined_Y_t = [
        torch.cat(pair, dim=2)
        for pair in (
            (ligand_subgraph_Y, ligand_subgraph_R),
            (ligand_subgraph_Y_m, ligand_subgraph_R_m),
            (ligand_subgraph_Y_t, ligand_subgraph_R_t),
        )
    ]

    # Reduce the combined pool back to the nearest candidates per residue.
    nearest_Y, nearest_Y_m, nearest_Y_t = self.gather_nearest_per_residue_atoms(
        combined_Y,
        combined_Y_m,
        combined_Y_t,
        X_virtual_atoms,
        X_m_virtual_atoms,
        residue_mask,
    )

    return (nearest_Y, nearest_Y_m, nearest_Y_t)
1823
+
1824
def featurize_ligand_atom_type_information(self, ligand_subgraph_Y_t):
    """
    One-hot encode the atomic number, periodic table group, and periodic
    table period of each ligand atom and concatenate the three encodings.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph.

    Args:
        ligand_subgraph_Y_t (torch.Tensor): [B, L, M] - element types
            (atomic numbers) of the ligand atoms; cast to long before use.
    Returns:
        torch.Tensor: [B, L, M, self.num_atomic_numbers +
            self.num_periodic_table_groups +
            self.num_periodic_table_periods] - concatenated one-hot
            encodings, in the order atomic number, group, period.
    """
    atomic_numbers = ligand_subgraph_Y_t.long()

    # Lookup tables indexed by atomic number: 18 groups plus 1 null group,
    # and 7 periods plus 1 null period.
    groups = self.periodic_table_groups[atomic_numbers]
    periods = self.periodic_table_periods[atomic_numbers]

    one_hot = torch.nn.functional.one_hot
    group_one_hot = one_hot(groups, self.num_periodic_table_groups)
    period_one_hot = one_hot(periods, self.num_periodic_table_periods)
    atomic_number_one_hot = one_hot(atomic_numbers, self.num_atomic_numbers)

    return torch.cat(
        [atomic_number_one_hot, group_one_hot, period_one_hot], dim=-1
    )
1889
+
1890
def featurize_protein_to_ligand_subgraph_edges(
    self,
    ligand_subgraph_Y_t_concat_one_hot,
    X_backbone,
    X_m_backbone,
    X_virtual_atoms,
    X_m_virtual_atoms,
    ligand_subgraph_Y,
    ligand_subgraph_Y_m,
    eps=1e-6,
):
    """
    Build protein-to-ligand edge features for every residue and each ligand
    atom in that residue's subgraph.

    Each feature vector concatenates three parts:
    - masked RBF embeddings of the distances from the ligand atom to the
      residue's backbone and virtual atoms;
    - an embedding of the ligand atom's type one-hot;
    - angle features computed from the residue's CA, N and C atoms and the
      ligand atom position.
    The result is then passed through ``self.node_embedding`` and
    ``self.node_norm``.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph.

    Args:
        ligand_subgraph_Y_t_concat_one_hot (torch.Tensor): [B, L, M,
            num_atomic_numbers + num_periodic_table_groups +
            num_periodic_table_periods] - ligand atom type one-hots.
        X_backbone (torch.Tensor): [B, L, self.num_backbone_atoms, 3] -
            backbone atom coordinates.
        X_m_backbone (torch.Tensor): [B, L, self.num_backbone_atoms] -
            mask of valid backbone atoms.
        X_virtual_atoms (torch.Tensor): [B, L, self.num_virtual_atoms, 3] -
            virtual atom coordinates.
        X_m_virtual_atoms (torch.Tensor): [B, L, self.num_virtual_atoms] -
            mask of valid virtual atoms.
        ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - ligand atom
            coordinates.
        ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - mask of valid ligand
            atoms.
        eps (float): Added to the squared distance before the square root,
            so distances and their gradients stay finite when two points
            coincide.
    Returns:
        E_protein_to_ligand (torch.Tensor): [B, L, M,
            self.num_node_output_features] - protein-to-ligand subgraph
            edges. They can also be viewed as node features of the protein
            residues, although they are not used as such.
    """
    B, L, M, _ = ligand_subgraph_Y_t_concat_one_hot.shape

    # Embedded ligand atom types: [B, L, M, num_atom_type_output_features].
    atom_type_embed = self.embed_atom_type_features(
        ligand_subgraph_Y_t_concat_one_hot.float()
    )

    # Per-residue "anchor" atoms: backbone atoms followed by virtual atoms.
    anchors = torch.cat((X_backbone, X_virtual_atoms), dim=2)
    anchors_m = torch.cat((X_m_backbone, X_m_virtual_atoms), dim=2)

    # Ligand-to-anchor distances: [B, L, M, num_backbone + num_virtual].
    offsets = ligand_subgraph_Y[:, :, :, None, :] - anchors[:, :, None, :, :]
    dist = torch.sqrt(torch.sum(offsets**2, dim=-1) + eps)

    # RBF embedding, zeroed wherever the ligand atom or the anchor is invalid:
    # [B, L, M, num_backbone + num_virtual, num_rbf].
    rbf = self.compute_rbf_embedding_from_distances(dist)
    rbf = (
        rbf
        * ligand_subgraph_Y_m[:, :, :, None, None]
        * anchors_m[:, :, None, :, None]
    )
    num_anchor_atoms = self.num_backbone_atoms + self.num_virtual_atoms
    rbf = rbf.view(B, L, M, num_anchor_atoms * self.num_rbf)

    # Angle features ([B, L, M, 4]) from the residue's CA, N and C atoms.
    backbone_names = self.BACKBONE_ATOM_NAMES
    ca, n, c = (
        X_backbone[:, :, backbone_names.index(atom_name), :]
        for atom_name in ("CA", "N", "C")
    )
    angle_features = self.construct_angle_features(ca, n, c, ligand_subgraph_Y)

    features = torch.cat((rbf, atom_type_embed, angle_features), dim=-1)

    # Project to [B, L, M, self.num_node_output_features] and normalize.
    return self.node_norm(self.node_embedding(features))
2033
+
2034
def featurize_ligand_subgraph_nodes(self, ligand_subgraph_Y_t_concat_one_hot):
    """
    Turn the ligand atom type one-hots into ligand subgraph node features.

    The one-hots are cast to float, passed through
    ``self.ligand_subgraph_node_embedding``, and then through
    ``self.ligand_subgraph_node_norm``.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph.

    Args:
        ligand_subgraph_Y_t_concat_one_hot (torch.Tensor): [B, L, M,
            self.num_atomic_numbers + self.num_periodic_table_groups +
            self.num_periodic_table_periods] - concatenated one-hots of
            atomic number, periodic group, and periodic period.
    Returns:
        ligand_subgraph_nodes (torch.Tensor): [B, L, M, D] - embedded and
            normalized node features. D is the output width of
            ``self.ligand_subgraph_node_embedding``; the surrounding code
            documents it as self.num_node_output_features (confirm against
            the layer's construction).
    """
    one_hot_features = ligand_subgraph_Y_t_concat_one_hot.float()
    embedded = self.ligand_subgraph_node_embedding(one_hot_features)
    return self.ligand_subgraph_node_norm(embedded)
2063
+
2064
def featurize_ligand_subgraph_edges(
    self, ligand_subgraph_Y, ligand_subgraph_Y_m, eps=1e-6
):
    """
    Build edge features between every pair of ligand atoms inside each
    residue's subgraph.

    Pairwise distances are RBF-embedded, zeroed for pairs involving an
    invalid atom, then passed through
    ``self.ligand_subgraph_edge_embedding`` and
    ``self.ligand_subgraph_edge_norm``.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph.

    Args:
        ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - ligand atom
            coordinates.
        ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - mask of valid ligand
            atoms.
        eps (float): Added to the squared distance before the square root,
            so self-distances (and their gradients) stay finite.
    Returns:
        ligand_subgraph_edges (torch.Tensor): [B, L, M, M,
            self.num_edge_output_features] - embedded and normalized RBF
            features of the intra-subgraph ligand distances.
    """
    coords = ligand_subgraph_Y
    mask = ligand_subgraph_Y_m

    # Pairwise distances: [B, L, M, M].
    pair_offsets = coords[:, :, :, None, :] - coords[:, :, None, :, :]
    pair_dist = torch.sqrt(torch.sum(pair_offsets**2, dim=-1) + eps)

    # RBF features [B, L, M, M, num_rbf], zeroed when either atom is invalid.
    rbf = self.compute_rbf_embedding_from_distances(pair_dist)
    rbf = rbf * mask[:, :, :, None, None] * mask[:, :, None, :, None]

    edges = self.ligand_subgraph_edge_embedding(rbf)
    return self.ligand_subgraph_edge_norm(edges)
2123
+
2124
def featurize_nodes(self, input_features, edge_features):
    """
    Given the input features and edge features, compute the node features
    for the ligand atoms and the protein to ligand subgraph edges.

    NOTE: N = the total number of ligand atoms.
          M = self.num_context_atoms, the number of ligand atoms in each
          residue subgraph.

    Args:
        input_features (dict): Dictionary containing the input features.
            - X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of the
                polymer atoms.
            - X_m (torch.Tensor): [B, L, num_atoms] - mask indicating which
                polymer atoms are valid.
            - S (torch.Tensor): [B, L] - sequence of the polymer residues.
                Required only when atomize_side_chains is truthy.
            - hide_side_chain_mask (torch.Tensor): [B, L] - mask
                indicating which residue side chains are hidden and which
                are revealed. True indicates that the side chain is hidden
                and False indicates that the side chain is revealed.
            - Y (torch.Tensor): [B, N, 3] - 3D coordinates of the ligand
                atoms.
            - Y_m (torch.Tensor): [B, N] - mask indicating which ligand
                atoms are valid.
            - Y_t (torch.Tensor): [B, N] - element types of the ligand
                atoms.
            - X_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms, 3] -
                3D coordinates of the virtual atoms for each residue.
            - X_m_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms] -
                mask indicating which virtual atoms are valid.
            - residue_mask (torch.Tensor): [B, L] - mask indicating which
                residues are valid.
            - X_backbone (torch.Tensor): [B, L, num_backbone_atoms, 3] -
                3D coordinates of the backbone atoms for each residue.
            - X_m_backbone (torch.Tensor): [B, L, num_backbone_atoms] -
                mask indicating which backbone atoms are valid.
            - atomize_side_chains (bool): Whether to atomize the side chains
                of the residues. If True, the side chains of the residues
                not specified in the hide side chain mask will be
                atomized and added as ligand atoms.
        edge_features (dict): Dictionary containing the edge features.
            - E_idx (torch.Tensor): [B, L, K] - indices of the top K
                nearest neighbors for each residue.
    Returns:
        node_features (dict): Dictionary containing the node features.
            - E_protein_to_ligand (torch.Tensor):
                [B, L, M, self.num_node_output_features] - protein to
                ligand subgraph edges; can also be considered node features
                of the protein residues (although they are not used as
                such).
            - ligand_subgraph_nodes (torch.Tensor):
                [B, L, M, self.num_node_output_features] - ligand atom type
                information, embedded as node features.
            - ligand_subgraph_edges (torch.Tensor):
                [B, L, M, M, self.num_edge_output_features] - embedded and
                normalized radial basis function embedding of the distances
                between the ligand atoms in each residue subgraph.
    Side Effects:
        Writes input_features["ligand_subgraph_Y"],
        input_features["ligand_subgraph_Y_m"] and
        input_features["ligand_subgraph_Y_t"] (the per-residue context
        atoms used to build the features).
    Raises:
        ValueError: if a required input or edge feature is missing,
            including 'S' when atomize_side_chains is truthy.
    """
    # Validate required keys up front so a missing key yields a descriptive
    # ValueError instead of a KeyError deep inside the featurization. The
    # order matches the documented inputs, so the first missing key is the
    # one reported.
    required_input_keys = (
        "X",
        "X_m",
        "hide_side_chain_mask",
        "Y",
        "Y_m",
        "Y_t",
        "X_virtual_atoms",
        "X_m_virtual_atoms",
        "residue_mask",
        "X_backbone",
        "X_m_backbone",
        "atomize_side_chains",
    )
    for key in required_input_keys:
        if key not in input_features:
            raise ValueError(f"Input features must contain '{key}' key.")

    if "E_idx" not in edge_features:
        raise ValueError("Edge features must contain 'E_idx' key.")

    atomize_side_chains = input_features["atomize_side_chains"]

    # The sequence is consumed only by the side chain atomization below, so
    # it is required only in that case (previously a missing 'S' surfaced as
    # a bare KeyError after partial computation).
    if atomize_side_chains and "S" not in input_features:
        raise ValueError("Input features must contain 'S' key.")

    # Ligand atoms closest to each residue's virtual atoms:
    # coordinates [B, L, M, 3], mask [B, L, M], element types [B, L, M].
    ligand_subgraph_Y, ligand_subgraph_Y_m, ligand_subgraph_Y_t = (
        self.gather_nearest_ligand_atoms(
            input_features["Y"],
            input_features["Y_m"],
            input_features["Y_t"],
            input_features["X_virtual_atoms"],
            input_features["X_m_virtual_atoms"],
            input_features["residue_mask"],
        )
    )

    if atomize_side_chains:
        # Side chain atoms of the nearest neighbors, treated as extra
        # ligand-like candidates: [B, L, P, 3], [B, L, P], [B, L, P] with
        # P = num_neighbors_for_atomized_side_chain * num_side_chain_atoms.
        ligand_subgraph_R, ligand_subgraph_R_m, ligand_subgraph_R_t = (
            self.gather_nearest_atomized_side_chain_atoms(
                input_features["X"],
                input_features["X_m"],
                input_features["S"],
                edge_features["E_idx"],
                input_features["hide_side_chain_mask"],
            )
        )

        # Re-select the closest context atoms from the merged pool,
        # replacing the ligand-only selection.
        ligand_subgraph_Y, ligand_subgraph_Y_m, ligand_subgraph_Y_t = (
            self.combine_ligand_and_atomized_side_chain_atoms(
                ligand_subgraph_Y,
                ligand_subgraph_Y_m,
                ligand_subgraph_Y_t,
                ligand_subgraph_R,
                ligand_subgraph_R_m,
                ligand_subgraph_R_t,
                input_features["X_virtual_atoms"],
                input_features["X_m_virtual_atoms"],
                input_features["residue_mask"],
            )
        )

    # Expose the selected context atoms to later stages.
    input_features["ligand_subgraph_Y"] = ligand_subgraph_Y
    input_features["ligand_subgraph_Y_m"] = ligand_subgraph_Y_m
    input_features["ligand_subgraph_Y_t"] = ligand_subgraph_Y_t

    # One-hot atomic number, periodic group, and periodic period:
    # [B, L, M, num_atomic_numbers + num_periodic_table_groups +
    # num_periodic_table_periods].
    ligand_subgraph_Y_t_concat_one_hot = (
        self.featurize_ligand_atom_type_information(ligand_subgraph_Y_t)
    )

    # [B, L, M, self.num_node_output_features].
    E_protein_to_ligand = self.featurize_protein_to_ligand_subgraph_edges(
        ligand_subgraph_Y_t_concat_one_hot,
        input_features["X_backbone"],
        input_features["X_m_backbone"],
        input_features["X_virtual_atoms"],
        input_features["X_m_virtual_atoms"],
        ligand_subgraph_Y,
        ligand_subgraph_Y_m,
    )

    # [B, L, M, self.num_node_output_features].
    ligand_subgraph_nodes = self.featurize_ligand_subgraph_nodes(
        ligand_subgraph_Y_t_concat_one_hot
    )

    # [B, L, M, M, self.num_edge_output_features].
    ligand_subgraph_edges = self.featurize_ligand_subgraph_edges(
        ligand_subgraph_Y, ligand_subgraph_Y_m
    )

    return {
        "E_protein_to_ligand": E_protein_to_ligand,
        "ligand_subgraph_nodes": ligand_subgraph_nodes,
        "ligand_subgraph_edges": ligand_subgraph_edges,
    }
2322
+
2323
+ def noise_structure(self, input_features):
2324
+ """
2325
+ Given input features containing 3D coordinates of atoms, add Gaussian
2326
+ noise to the coordinates.
2327
+
2328
+ Args:
2329
+ input_features (dict): Dictionary containing the input features.
2330
+ - X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of
2331
+ polymer atoms.
2332
+ - Y (torch.Tensor): [B, N, 3] - 3D coordinates of the ligand
2333
+ atoms.
2334
+ - structure_noise (float): Standard deviation of the
2335
+ Gaussian noise to add to the input coordinates, in
2336
+ Angstroms.
2337
+ Side Effects:
2338
+ input_features["X"] (torch.Tensor): [B, L, num_atoms, 3] - 3D
2339
+ coordinates of atoms with added Gaussian noise.
2340
+ input_features["Y"] (torch.Tensor): [B, N, 3] - 3D coordinates
2341
+ of the ligand atoms with added Gaussian noise.
2342
+ input_features["X_pre_noise"] (torch.Tensor): [B, L, num_atoms, 3] -
2343
+ 3D coordinates of polymer atoms before adding noise.
2344
+ input_features["Y_pre_noise"] (torch.Tensor): [B, N, 3] -
2345
+ 3D coordinates of the ligand atoms before adding noise.
2346
+ """
2347
+ if "X" not in input_features:
2348
+ raise ValueError("Input features must contain 'X' key.")
2349
+ if "Y" not in input_features:
2350
+ raise ValueError("Input features must contain 'Y' key.")
2351
+ if "structure_noise" not in input_features:
2352
+ raise ValueError("Input features must contain 'structure_noise' key.")
2353
+
2354
+ structure_noise = input_features["structure_noise"]
2355
+
2356
+ # If the noise is non-zero, add Gaussian noise to the input
2357
+ # coordinates.
2358
+ if structure_noise > 0:
2359
+ # Copy the original coordinates before adding noise.
2360
+ input_features["X_pre_noise"] = input_features["X"].clone()
2361
+ input_features["Y_pre_noise"] = input_features["Y"].clone()
2362
+
2363
+ # Add Gaussian noise to the input coordinates.
2364
+ input_features["X"] = input_features[
2365
+ "X"
2366
+ ] + structure_noise * torch.randn_like(input_features["X"])
2367
+ input_features["Y"] = input_features[
2368
+ "Y"
2369
+ ] + structure_noise * torch.randn_like(input_features["Y"])
2370
+ else:
2371
+ input_features["X_pre_noise"] = input_features["X"].clone()
2372
+ input_features["Y_pre_noise"] = input_features["Y"].clone()