rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,2372 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from atomworks.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER
|
|
4
|
+
from mpnn.model.layers.message_passing import gather_edges, gather_nodes
|
|
5
|
+
from mpnn.model.layers.positional_encoding import PositionalEncodings
|
|
6
|
+
from mpnn.transforms.feature_aggregation.token_encodings import MPNN_TOKEN_ENCODING
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ProteinFeatures(nn.Module):
|
|
10
|
+
# Token vocabulary: construct_X_atoms reads its n_tokens, token_to_idx and
# atom_to_idx to map (token name, atom name) pairs to per-token atom slots.
TOKEN_ENCODING = MPNN_TOKEN_ENCODING
# Atoms gathered per residue by construct_X_backbone. Their count (plus the
# number of virtual atoms) sets the size of the atom-pair RBF block of the
# edge input features computed in __init__.
BACKBONE_ATOM_NAMES = ["N", "CA", "C", "O"]

# Atom(s) whose coordinates define residue-residue distances for neighbor
# selection; construct_X_rep_atoms rejects residues with more than one
# valid representative atom.
REPRESENTATIVE_ATOM_NAMES = ["CA"]

# One (atom names, weights) entry per virtual atom. construct_X_virtual_atom
# places it at
#   center + weight_normal * (bond_1 x bond_2)
#          + weight_bond_1 * bond_1 + weight_bond_2 * bond_2,
# with bond_i = atom_i - center.
# NOTE(review): these weights appear to reproduce ProteinMPNN's idealized
# virtual C-beta construction from N, CA and C — confirm before changing.
DATA_TO_CALCULATE_VIRTUAL_ATOMS = [
    (
        {"center_atom": "CA", "atom_1": "N", "atom_2": "C"},
        {
            "weight_normal": 0.58273431,
            "weight_bond_1": -0.56802827,
            "weight_bond_2": -0.54067466,
        },
    )
]
|
|
25
|
+
|
|
26
|
+
def __init__(
    self,
    num_edge_output_features=128,
    num_node_output_features=128,
    num_positional_embeddings=16,
    min_rbf_mean=2.0,
    max_rbf_mean=22.0,
    num_rbf=16,
    num_neighbors=48,
):
    """
    Set up the layers that turn a protein structure into graph features.

    Args:
        num_edge_output_features (int): Width of the edge features produced
            by ``self.edge_embedding`` and normalized by ``self.edge_norm``.
        num_node_output_features (int): Number of output features for the
            nodes (stored as an attribute; not used in this constructor).
        num_positional_embeddings (int): Width of the relative positional
            encoding produced by ``self.positional_embedding``.
        min_rbf_mean (float): Center of the first radial basis function.
        max_rbf_mean (float): Center of the last radial basis function.
        num_rbf (int): Number of radial basis functions per atom pair.
        num_neighbors (int): Number of neighbors to consider for each
            residue.
    """
    super().__init__()

    # Output widths.
    self.num_edge_output_features = num_edge_output_features
    self.num_node_output_features = num_node_output_features

    # Size of each residue's neighborhood.
    self.num_neighbors = num_neighbors

    # Radial basis function grid.
    self.min_rbf_mean = min_rbf_mean
    self.max_rbf_mean = max_rbf_mean
    self.num_rbf = num_rbf

    self.num_positional_embeddings = num_positional_embeddings

    # Atoms per residue that take part in the atom-pair distance features.
    self.num_backbone_atoms = len(self.BACKBONE_ATOM_NAMES)
    self.num_virtual_atoms = len(self.DATA_TO_CALCULATE_VIRTUAL_ATOMS)
    atoms_per_residue = self.num_backbone_atoms + self.num_virtual_atoms

    # Edge input = positional encoding + num_rbf features for every
    # (own atom, neighbor atom) pair.
    self.num_edge_input_features = (
        num_positional_embeddings + num_rbf * atoms_per_residue**2
    )

    # Submodules; creation order fixes parameter / state_dict order.
    self.positional_embedding = PositionalEncodings(self.num_positional_embeddings)
    self.edge_embedding = nn.Linear(
        self.num_edge_input_features, self.num_edge_output_features, bias=False
    )
    self.edge_norm = nn.LayerNorm(self.num_edge_output_features)
|
|
78
|
+
|
|
79
|
+
def construct_X_atoms(self, X, X_m, S, atom_names):
    """
    Select the requested atoms for every residue from per-token atom arrays.

    For each residue, its token index in ``S`` and each requested atom name
    are mapped to an atom slot through ``self.TOKEN_ENCODING.atom_to_idx``.
    The matching coordinates and mask entries are then gathered. Token
    indices that do not appear in ``self.TOKEN_ENCODING.token_to_idx`` map
    every requested atom to slot 0.

    Args:
        X (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token, 3] -
            3D coordinates of polymer atoms.
        X_m (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token] -
            Mask indicating which polymer atoms are valid.
        S (torch.Tensor): [B, L] - integer token index of each residue.
        atom_names (list): Atom names to extract, in output order.
    Returns:
        X_atoms (torch.Tensor): [B, L, len(atom_names), 3] - 3D coordinates
            of the requested atoms for each residue.
        X_m_atoms (torch.Tensor): [B, L, len(atom_names)] - mask indicating
            which requested atoms are valid for each residue.
    Raises:
        KeyError: if a (token name, atom name) pair is missing from
            ``self.TOKEN_ENCODING.atom_to_idx``.
    """
    B, L, _, _ = X.shape
    encoding = self.TOKEN_ENCODING
    num_names = len(atom_names)

    # token_and_atom_name_to_atom_idx [n_tokens, len(atom_names)] - atom slot
    # for each token/atom-name pair. Build it on the host and move it to the
    # device in a single copy. Writing it element by element into a device
    # tensor would cost one tiny host-to-device transfer per entry, on every
    # call.
    lookup = [[0] * num_names for _ in range(encoding.n_tokens)]
    for token_name, token_idx in encoding.token_to_idx.items():
        lookup[token_idx] = [
            encoding.atom_to_idx[(token_name, atom_name)] for atom_name in atom_names
        ]
    # The explicit view keeps the 2-D shape when n_tokens == 0.
    token_and_atom_name_to_atom_idx = torch.tensor(
        lookup, dtype=torch.int64, device=X.device
    ).view(encoding.n_tokens, num_names)

    # batch_idx [B, 1, 1] and position_idx [1, L, 1] broadcast against
    # atom_indices [B, L, len(atom_names)] for advanced indexing.
    batch_idx = torch.arange(B, dtype=torch.int64, device=X.device).view(B, 1, 1)
    position_idx = torch.arange(L, dtype=torch.int64, device=X.device).view(1, L, 1)

    # atom_indices [B, L, len(atom_names)] - atom slot per residue per name.
    atom_indices = token_and_atom_name_to_atom_idx[S]

    # X_atoms [B, L, len(atom_names), 3]; X_m_atoms [B, L, len(atom_names)].
    X_atoms = X[batch_idx, position_idx, atom_indices]
    X_m_atoms = X_m[batch_idx, position_idx, atom_indices]

    return X_atoms, X_m_atoms
|
|
134
|
+
|
|
135
|
+
def construct_X_rep_atoms(self, X, X_m, S):
    """
    Gather the representative atom(s) of every residue.

    The atoms named in ``self.REPRESENTATIVE_ATOM_NAMES`` are extracted with
    ``self.construct_X_atoms``. Each residue may have at most one valid
    representative atom.

    Args:
        X (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token, 3] -
            3D coordinates of polymer atoms.
        X_m (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token] -
            Mask indicating which polymer atoms are valid.
        S (torch.Tensor): [B, L] - the sequence of residues.
    Returns:
        X_rep_atoms (torch.Tensor):
            [B, L, len(self.REPRESENTATIVE_ATOM_NAMES), 3] - coordinates of
            the representative atoms.
        X_m_rep_atoms (torch.Tensor):
            [B, L, len(self.REPRESENTATIVE_ATOM_NAMES)] - validity mask of
            the representative atoms.
    Raises:
        ValueError: if any residue has more than one valid representative
            atom.
    """
    rep_coords, rep_mask = self.construct_X_atoms(
        X, X_m, S, self.REPRESENTATIVE_ATOM_NAMES
    )

    # compute_representative_atom_pairwise_distances sums over the
    # representative-atom axis, which only works with <= 1 valid atom each.
    valid_per_residue = rep_mask.sum(dim=-1)
    if (valid_per_residue > 1).any():
        raise ValueError("Each residue should have only one representative atom.")

    return rep_coords, rep_mask
|
|
164
|
+
|
|
165
|
+
def construct_X_backbone(self, X, X_m, S):
    """
    Gather the atoms listed in ``self.BACKBONE_ATOM_NAMES`` for every residue.

    Args:
        X (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token, 3] -
            3D coordinates of polymer atoms.
        X_m (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token] -
            Mask indicating which polymer atoms are valid.
        S (torch.Tensor): [B, L] - the sequence of residues.
    Returns:
        X_backbone (torch.Tensor): [B, L, len(self.BACKBONE_ATOM_NAMES), 3]
            - coordinates of the backbone atoms, in BACKBONE_ATOM_NAMES order.
        X_m_backbone (torch.Tensor): [B, L, len(self.BACKBONE_ATOM_NAMES)] -
            validity mask of the backbone atoms.
    """
    backbone_xyz, backbone_mask = self.construct_X_atoms(
        X, X_m, S, self.BACKBONE_ATOM_NAMES
    )
    return backbone_xyz, backbone_mask
|
|
187
|
+
|
|
188
|
+
def construct_X_virtual_atom(
    self,
    X_center_atom,
    X_atom_1,
    X_atom_2,
    weight_normal,
    weight_bond_1,
    weight_bond_2,
):
    """
    Place a virtual atom as a fixed linear combination of local geometry.

    With bond_1 = X_atom_1 - X_center_atom and bond_2 = X_atom_2 -
    X_center_atom, the result is

        weight_normal * (bond_1 x bond_2) + weight_bond_1 * bond_1
        + weight_bond_2 * bond_2 + X_center_atom

    Args:
        X_center_atom (torch.Tensor): [B, L, 3] - center atom coordinates.
        X_atom_1 (torch.Tensor): [B, L, 3] - first reference atom.
        X_atom_2 (torch.Tensor): [B, L, 3] - second reference atom.
        weight_normal (float): Coefficient of the cross-product direction.
        weight_bond_1 (float): Coefficient of bond_1.
        weight_bond_2 (float): Coefficient of bond_2.
    Returns:
        torch.Tensor: [B, L, 3] - coordinates of the virtual atom.
    """
    to_first = X_atom_1 - X_center_atom
    to_second = X_atom_2 - X_center_atom

    # Direction perpendicular to the plane spanned by the two bonds.
    out_of_plane = torch.cross(to_first, to_second, dim=-1)

    # Accumulated left to right; this order fixes the float rounding.
    position = weight_normal * out_of_plane
    position = position + weight_bond_1 * to_first
    position = position + weight_bond_2 * to_second
    return position + X_center_atom
|
|
232
|
+
|
|
233
|
+
def construct_X_virtual_atoms(self, X, X_m, S):
    """
    Build every virtual atom listed in ``self.DATA_TO_CALCULATE_VIRTUAL_ATOMS``.

    Each entry is an ``(atom_info, weights)`` pair. ``atom_info`` names a
    center atom and two reference atoms. ``weights`` holds the
    ``weight_normal`` / ``weight_bond_1`` / ``weight_bond_2`` coefficients
    passed to ``self.construct_X_virtual_atom``. A virtual atom is valid only
    where all three of its source atoms are valid.

    Args:
        X (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token, 3] -
            3D coordinates of polymer atoms.
        X_m (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token] -
            Mask indicating which polymer atoms are valid.
        S (torch.Tensor): [B, L] - the sequence of residues.
    Returns:
        X_virtual_atoms (torch.Tensor):
            [B, L, len(self.DATA_TO_CALCULATE_VIRTUAL_ATOMS), 3] - virtual
            atom coordinates, one slot per entry.
        X_m_virtual_atoms (torch.Tensor):
            [B, L, len(self.DATA_TO_CALCULATE_VIRTUAL_ATOMS)] - validity mask
            of the virtual atoms.
    Raises:
        ValueError: if any residue has more than one valid virtual atom.
    """
    coords_per_virtual = []
    masks_per_virtual = []
    for atom_info, weights in self.DATA_TO_CALCULATE_VIRTUAL_ATOMS:
        source_names = [
            atom_info["center_atom"],
            atom_info["atom_1"],
            atom_info["atom_2"],
        ]
        # A single gather fetches all three source atoms.
        source_xyz, source_mask = self.construct_X_atoms(X, X_m, S, source_names)

        # Slot of the first occurrence of each name within the gathered axis.
        center_slot, first_slot, second_slot = (
            source_names.index(name) for name in source_names
        )

        coords_per_virtual.append(
            self.construct_X_virtual_atom(
                source_xyz[:, :, center_slot, :],
                source_xyz[:, :, first_slot, :],
                source_xyz[:, :, second_slot, :],
                weight_normal=weights["weight_normal"],
                weight_bond_1=weights["weight_bond_1"],
                weight_bond_2=weights["weight_bond_2"],
            )
        )

        # Valid only where the center atom and both reference atoms exist.
        masks_per_virtual.append(
            source_mask[:, :, center_slot]
            * source_mask[:, :, first_slot]
            * source_mask[:, :, second_slot]
        )

    # Stack along a new atom axis: [B, L, n_virtual, 3] and [B, L, n_virtual].
    X_virtual_atoms = torch.stack(coords_per_virtual, dim=2)
    X_m_virtual_atoms = torch.stack(masks_per_virtual, dim=2)

    if (X_m_virtual_atoms.sum(dim=-1) > 1).any():
        raise ValueError("Each residue should have only one virtual atom.")

    return X_virtual_atoms, X_m_virtual_atoms
|
|
308
|
+
|
|
309
|
+
def compute_representative_atom_pairwise_distances(
    self, X_rep_atoms, X_m_rep_atoms, residue_mask, eps=1e-6
):
    """
    Select the K nearest neighbors of every residue by representative-atom
    distance, with K = min(self.num_neighbors, L).

    Pairs involving a masked residue (per ``residue_mask``) first get a
    distance of zero. They are then shifted up by the largest distance in
    their row, so they rank at or after every valid pair in the top-K
    selection.

    Args:
        X_rep_atoms (torch.Tensor):
            [B, L, len(self.REPRESENTATIVE_ATOM_NAMES), 3] - coordinates of
            the representative atoms for each residue.
        X_m_rep_atoms (torch.Tensor):
            [B, L, len(self.REPRESENTATIVE_ATOM_NAMES)] - mask indicating
            which representative atoms are valid.
        residue_mask (torch.Tensor): [B, L] - mask indicating which residues
            are valid.
        eps (float): Added to every squared distance before the square root,
            keeping it differentiable at zero distance.
    Returns:
        D_rep_neighbors (torch.Tensor): [B, L, K] - the K smallest (adjusted)
            distances per residue, in ascending order.
        E_idx (torch.Tensor): [B, L, K] - indices of those neighbors.
    """
    _, num_res, _, _ = X_rep_atoms.shape

    # [B, L, L] - True where both residues of the pair are valid.
    pair_mask = (residue_mask[:, None, :] * residue_mask[:, :, None]).bool()

    # [B, L, 3] - one coordinate per residue. Summing over the
    # representative-atom axis is safe because construct_X_rep_atoms allows
    # at most one valid representative atom per residue.
    rep_xyz = (X_rep_atoms * X_m_rep_atoms.unsqueeze(3)).sum(dim=2)

    # [B, L, L, 3] - displacement between every pair of residues.
    pair_delta = rep_xyz.unsqueeze(1) - rep_xyz.unsqueeze(2)

    # [B, L, L] - Euclidean distances, zeroed for pairs with a masked residue.
    pair_dist = pair_mask * torch.sqrt((pair_delta**2).sum(dim=3) + eps)

    # [B, L, 1] - largest distance in each row (per residue, not per batch).
    row_max = pair_dist.max(dim=-1, keepdim=True).values

    # Lift masked pairs to their row maximum so they are not picked early.
    ranked_dist = pair_dist + (~pair_mask) * row_max

    D_rep_neighbors, E_idx = torch.topk(
        ranked_dist, min(self.num_neighbors, num_res), dim=-1, largest=False
    )
    return D_rep_neighbors, E_idx
|
|
380
|
+
|
|
381
|
+
def compute_rbf_embedding_from_distances(self, D):
    """
    Expand distances into Gaussian radial basis features.

    ``self.num_rbf`` Gaussians have centers evenly spaced from
    ``self.min_rbf_mean`` to ``self.max_rbf_mean``. All share the width
    sigma = (max_rbf_mean - min_rbf_mean) / num_rbf. Feature i of a distance
    d is exp(-((d - mu_i) / sigma) ** 2).

    Args:
        D (torch.Tensor): Distances, e.g. [B, L, K], or [B, L, K, A, A] as
            passed by compute_pairwise_residue_rbf_encoding.
    Returns:
        rbf_embedding (torch.Tensor): D.shape + (num_rbf,) for inputs with at
            least three dimensions. The centers are shaped [1, 1, 1, num_rbf],
            so lower-rank inputs gain leading singleton axes.
    """
    centers = torch.linspace(
        self.min_rbf_mean, self.max_rbf_mean, self.num_rbf, device=D.device
    ).reshape(1, 1, 1, self.num_rbf)
    width = (self.max_rbf_mean - self.min_rbf_mean) / self.num_rbf

    # Standardized offset of every distance from every center.
    z = (D.unsqueeze(-1) - centers) / width
    return torch.exp(-(z**2))
|
|
412
|
+
|
|
413
|
+
def compute_pairwise_residue_rbf_encoding(self, X, E_idx, X_m, eps=1e-6):
    """
    RBF-encode every atom-to-atom distance between each residue and each of
    its K neighbors.

    NOTE: num_atoms = self.num_backbone_atoms + self.num_virtual_atoms

    Args:
        X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of
            polymer atoms.
        E_idx (torch.Tensor): [B, L, K] - Indices of the top K neighbors.
        X_m (torch.Tensor): [B, L, num_atoms] - mask indicating which
            polymer atoms are valid. Unless it is all ones, the features of
            any atom pair involving an invalid atom are zeroed.
        eps (float): Added to every squared distance before the square root.
    Returns:
        RBF_all (torch.Tensor): [B, L, K, num_atoms * num_atoms * num_rbf] -
            RBF features flattened in (own atom, neighbor atom, rbf) order.
    """
    B = X.shape[0]
    L = X.shape[1]
    K = E_idx.shape[2]
    num_atoms = X.shape[2]

    # Flatten atoms*xyz into one feature axis, gather the neighbors' rows,
    # then restore [B, L, K, num_atoms, 3].
    neighbor_xyz = gather_nodes(X.reshape(B, L, -1), E_idx).reshape(
        B, L, K, num_atoms, 3
    )

    # [B, L, K, num_atoms, num_atoms] - entry (i, j) is the distance from
    # atom i of the residue to atom j of the neighbor.
    own_xyz = X[:, :, None, :, None, :]
    other_xyz = neighbor_xyz[:, :, :, None, :, :]
    pair_dist = torch.sqrt(((own_xyz - other_xyz) ** 2).sum(dim=-1) + eps)

    # [B, L, K, num_atoms, num_atoms, num_rbf]
    rbf = self.compute_rbf_embedding_from_distances(pair_dist)

    # Masking is skipped entirely when every atom is valid.
    if (X_m != 1).any():
        neighbor_mask = gather_nodes(X_m, E_idx)
        rbf = (
            rbf
            * X_m[:, :, None, :, None, None]
            * neighbor_mask[:, :, :, None, :, None]
        )

    return rbf.view(B, L, K, -1)
|
|
480
|
+
|
|
481
|
+
def compute_pairwise_positional_encoding(self, R_idx, E_idx, chain_labels):
    """
    Encode the relative position of each residue's K neighbors.

    The residue-index offset and a same-chain flag are computed for every
    residue pair, restricted to the K neighbors in ``E_idx``, and passed to
    ``self.positional_embedding``.

    Args:
        R_idx (torch.Tensor): [B, L] - indices of the residues.
        E_idx (torch.Tensor): [B, L, K] - indices of the top K neighbors.
        chain_labels (torch.Tensor): [B, L] - chain labels for each residue.
    Returns:
        positional_encoding (torch.Tensor):
            [B, L, K, num_positional_embeddings] - the positional encoding
            of the top K neighbors.
    """
    # [B, L, L] - entry (i, j) is R_idx[i] - R_idx[j], cast to int64.
    offsets = (R_idx[:, :, None] - R_idx[:, None, :]).long()

    # [B, L, L] - True where residues i and j carry the same chain label.
    same_chain = (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0

    # Add a trailing channel axis for gather_edges, keep only the K neighbor
    # columns, then drop the channel axis again -> [B, L, K].
    neighbor_offsets = gather_edges(offsets.unsqueeze(-1), E_idx)[:, :, :, 0]
    neighbor_same_chain = gather_edges(same_chain.unsqueeze(-1), E_idx)[:, :, :, 0]

    return self.positional_embedding(neighbor_offsets, neighbor_same_chain)
|
|
523
|
+
|
|
524
|
+
def featurize_edges(self, input_features):
    """
    Build the edge features of the protein graph.

    Neighbours are selected from representative-atom distances. Each edge
    is then described by RBF-encoded distances between the backbone plus
    virtual atoms of both residues, concatenated with a sequence-separation
    / same-chain positional encoding. The result is projected and
    layer-normalised.

    Args:
        input_features (dict): Must provide
            - residue_mask [B, L], R_idx [B, L], chain_labels [B, L]
              (from data loading), and
            - X_backbone, X_m_backbone, X_virtual_atoms, X_m_virtual_atoms,
              X_rep_atoms, X_m_rep_atoms (filled in by forward).
    Returns:
        edge_features (dict):
            - E_idx (torch.Tensor): [B, L, K] - indices of the top K
              neighbours.
            - E (torch.Tensor): [B, L, K, num_edge_output_features] - edge
              embeddings.
    Raises:
        ValueError: if any required key is absent (checked in the order
            listed above).
    """
    required_keys = (
        "residue_mask",
        "R_idx",
        "chain_labels",
        "X_backbone",
        "X_m_backbone",
        "X_virtual_atoms",
        "X_m_virtual_atoms",
        "X_rep_atoms",
        "X_m_rep_atoms",
    )
    for key in required_keys:
        if key not in input_features:
            raise ValueError(f"Input features must contain '{key}' key.")

    # Only the neighbour indices are needed here; the masked distances
    # returned alongside them are discarded.
    _, neighbor_idx = self.compute_representative_atom_pairwise_distances(
        input_features["X_rep_atoms"],
        input_features["X_m_rep_atoms"],
        input_features["residue_mask"],
    )

    # Backbone atoms followed by virtual atoms, along the atom axis.
    atom_coords = torch.cat(
        [input_features["X_backbone"], input_features["X_virtual_atoms"]], dim=-2
    )
    atom_mask = torch.cat(
        [input_features["X_m_backbone"], input_features["X_m_virtual_atoms"]],
        dim=-1,
    )

    rbf_features = self.compute_pairwise_residue_rbf_encoding(
        atom_coords, neighbor_idx, atom_mask
    )
    positional_features = self.compute_pairwise_positional_encoding(
        input_features["R_idx"], neighbor_idx, input_features["chain_labels"]
    )

    raw_edges = torch.cat([positional_features, rbf_features], dim=-1)
    embedded_edges = self.edge_norm(self.edge_embedding(raw_edges))

    return {"E_idx": neighbor_idx, "E": embedded_edges}
|
|
638
|
+
|
|
639
|
+
def featurize_nodes(self, input_features, edge_features):
    """
    Return the node features of the plain ProteinMPNN graph.

    Vanilla ProteinMPNN has no per-residue node features, so this hook
    always yields an empty dictionary. Subclasses override it to add a
    "V" entry.

    Args:
        input_features (dict): Input features (unused here).
        edge_features (dict): Edge features (unused here).
    Returns:
        dict: A new, empty dictionary.
    """
    return dict()
|
|
651
|
+
|
|
652
|
+
def noise_structure(self, input_features):
    """
    Optionally perturb the atom coordinates with isotropic Gaussian noise.

    The dictionary is modified in place and nothing is returned.

    Args:
        input_features (dict): Must contain
            - X (torch.Tensor): [B, L, num_atoms, 3] - polymer atom
              coordinates.
            - structure_noise (float): standard deviation of the noise, in
              Angstroms. Noise is applied only when it is > 0.
    Side Effects:
        input_features["X_pre_noise"] is always set to a clone of the
        incoming X. When structure_noise > 0, input_features["X"] is
        replaced by X + structure_noise * N(0, 1).
    Raises:
        ValueError: if "X" or "structure_noise" is missing.
    """
    for key in ("X", "structure_noise"):
        if key not in input_features:
            raise ValueError(f"Input features must contain '{key}' key.")

    sigma = input_features["structure_noise"]
    clean_coords = input_features["X"]

    # Always keep a detached-from-aliasing copy of the clean coordinates, so
    # "X_pre_noise" exists whether or not noise is applied.
    input_features["X_pre_noise"] = clean_coords.clone()

    if sigma > 0:
        input_features["X"] = clean_coords + sigma * torch.randn_like(clean_coords)
|
|
689
|
+
|
|
690
|
+
def forward(self, input_features):
    """
    Build the graph (edge + node) features for a protein.

    The pipeline is:
    1. Add noise to the coordinates via self.noise_structure.
    2. Derive backbone, virtual and representative atoms and their masks.
       These are written into input_features.
    3. Featurize edges, then nodes. Node featurization may use the edge
       features, e.g. the neighbour indices.

    Args:
        input_features (dict): Must contain
            - X (torch.Tensor): [B, L, num_atoms, 3] - polymer atom
              coordinates.
            - X_m (torch.Tensor): [B, L, num_atoms] - atom validity mask.
            - S (torch.Tensor): [B, L] - residue sequence.
            plus whatever noise_structure / featurize_edges /
            featurize_nodes require.
    Returns:
        graph_features (dict): union of the edge-feature and node-feature
        dictionaries. Node entries win on key collisions.
    Raises:
        ValueError: if "X", "X_m" or "S" is missing.
    """
    for key in ("X", "X_m", "S"):
        if key not in input_features:
            raise ValueError(f"Input features must contain '{key}' key.")

    # May replace input_features["X"]; everything below uses the result.
    self.noise_structure(input_features)

    coords = input_features["X"]
    coord_mask = input_features["X_m"]
    sequence = input_features["S"]

    # (coordinate key, mask key, constructor), processed in this order.
    atom_groups = (
        ("X_backbone", "X_m_backbone", self.construct_X_backbone),
        ("X_virtual_atoms", "X_m_virtual_atoms", self.construct_X_virtual_atoms),
        ("X_rep_atoms", "X_m_rep_atoms", self.construct_X_rep_atoms),
    )
    for coords_key, mask_key, construct in atom_groups:
        group_coords, group_mask = construct(coords, coord_mask, sequence)
        input_features[coords_key] = group_coords
        input_features[mask_key] = group_mask

    edge_features = self.featurize_edges(input_features)
    node_features = self.featurize_nodes(input_features, edge_features)

    return {**edge_features, **node_features}
|
|
761
|
+
|
|
762
|
+
|
|
763
|
+
class ProteinFeaturesMembrane(ProteinFeatures):
    def __init__(self, num_membrane_classes=3, **kwargs):
        """
        Given a protein structure, extract the features for the graph
        representation of the protein. This class is aware of membrane labels.

        All args are the same as the parent class, except for the following:
        Args:
            num_membrane_classes (int): Number of membrane classes. This is
                the width of the one-hot encoding built in featurize_nodes,
                and therefore the input width of the node embedding.
        """
        super(ProteinFeaturesMembrane, self).__init__(**kwargs)
        self.num_membrane_classes = num_membrane_classes

        # The embedding consumes the one-hot vectors produced in
        # featurize_nodes (width num_membrane_classes), so its input size
        # must be num_membrane_classes. The previous reference to
        # self.num_classes is not defined by this class.
        self.node_embedding = nn.Linear(
            self.num_membrane_classes, self.num_node_output_features, bias=False
        )
        self.node_norm = nn.LayerNorm(self.num_node_output_features)

    def featurize_nodes(self, input_features, edge_features):
        """
        Given input features, construct the node features for the protein.

        Args:
            input_features (dict): Dictionary containing the input features.
                - membrane_per_residue_labels (torch.Tensor): [B, L] - Integer
                  class label of each residue, in
                  [0, self.num_membrane_classes).
            edge_features (dict): Dictionary containing the edge features
                (unused here).
        Returns:
            node_features (dict): Dictionary containing the node features.
                - V (torch.Tensor): [B, L, self.num_node_output_features] -
                  Node features for each residue.
        Raises:
            ValueError: if "membrane_per_residue_labels" is missing.
        """
        if "membrane_per_residue_labels" not in input_features:
            raise ValueError(
                "Input features must contain 'membrane_per_residue_labels' key."
            )

        # class_one_hot [B, L, self.num_membrane_classes] - one-hot encoding
        # of the class. one_hot only accepts int64 indices, so labels stored
        # with a narrower integer dtype are cast first.
        class_one_hot = torch.nn.functional.one_hot(
            input_features["membrane_per_residue_labels"].long(),
            num_classes=self.num_membrane_classes,
        ).float()

        # V [B, L, self.num_node_output_features] - embedded and normalized
        # node features for each residue.
        V = self.node_norm(self.node_embedding(class_one_hot))

        return {"V": V}
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
class ProteinFeaturesPSSM(ProteinFeatures):
    """Graph featurizer that adds PSSM-derived per-residue node features."""

    def __init__(self, num_pssm_features=20, **kwargs):
        """
        Set up the parent featurizer plus a PSSM node embedding.

        All keyword arguments except the one below are forwarded unchanged to
        the parent class.

        Args:
            num_pssm_features (int): Width of the last axis of the "pssm"
                input, i.e. the input size of the node embedding.
        """
        super().__init__(**kwargs)
        self.num_pssm_features = num_pssm_features

        # Bias-free projection of the PSSM profile, followed by LayerNorm.
        self.node_embedding = nn.Linear(
            self.num_pssm_features,
            self.num_node_output_features,
            bias=False,
        )
        self.node_norm = nn.LayerNorm(self.num_node_output_features)

    def featurize_nodes(self, input_features, edge_features):
        """
        Embed the per-residue PSSM profile as node features.

        Args:
            input_features (dict): Must contain
                - pssm (torch.Tensor): [B, L, self.num_pssm_features].
            edge_features (dict): Edge features (unused here).
        Returns:
            dict: {"V": tensor of shape [B, L, self.num_node_output_features]}.
        Raises:
            ValueError: if "pssm" is missing.
        """
        if "pssm" not in input_features:
            raise ValueError("Input features must contain 'pssm' key.")

        projected = self.node_embedding(input_features["pssm"])
        return {"V": self.node_norm(projected)}
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
class ProteinFeaturesLigand(ProteinFeatures):
|
|
866
|
+
# Side chain atom names used as context atoms. CB is intentionally excluded
# because a virtual CB is used instead.
SIDE_CHAIN_ATOM_NAMES = [
    "CG", "CG1", "CG2", "OG", "OG1", "SG",
    "CD", "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD",
    "CE", "CE1", "CE2", "CE3", "NE", "NE1", "NE2", "OE1", "OE2",
    "CH2", "NH1", "NH2", "OH",
    "CZ", "CZ2", "CZ3", "NZ",
    "OXT",
]

# Side chain atom name -> element symbol, in the same order as
# SIDE_CHAIN_ATOM_NAMES. Every name above starts with its (single-letter)
# element symbol, so the first character is the element. Revisit this if
# a two-letter element (e.g. "SE") is ever added to the list.
SIDE_CHAIN_ATOM_NAME_TO_ELEMENT_NAME = {
    atom_name: atom_name[0] for atom_name in SIDE_CHAIN_ATOM_NAMES
}
|
|
938
|
+
|
|
939
|
+
def __init__(self, num_neighbors=32, num_context_atoms=25, **kwargs):
    """
    Given a protein structure and ligand structure, extract the features for
    the graph representation of the protein and ligand. This class is aware
    of ligand features.

    Besides the parent-class setup, this constructor defines:
    - the sizes of the atom-type (periodic table) feature vocabularies,
    - linear/LayerNorm layers for the protein-ligand node features and for
      the ligand subgraphs,
    - buffers holding the atomic number of each side chain atom and the
      periodic-table group / period of every atomic number.

    All args are the same as the parent class, except for the following:
    Args:
        num_neighbors (int): Number of neighbors to consider, default
            changed from 48 to 32.
        num_context_atoms (int): Number of ligand plus side chain atoms to
            consider for each polymer residue.
    """
    super(ProteinFeaturesLigand, self).__init__(
        num_neighbors=num_neighbors, **kwargs
    )
    self.num_context_atoms = num_context_atoms

    # Number of side chain atoms (length of SIDE_CHAIN_ATOM_NAMES).
    self.num_side_chain_atoms = len(self.SIDE_CHAIN_ATOM_NAMES)

    # Sizes of the atom-type (periodic table) vocabularies. The "1 +" reserves
    # index 0 as a null entry in each; the buffers below store 0 at index 0.
    # The three sizes are summed into the atom-type input width (presumably
    # concatenated one-hot encodings; confirm in the featurizer).
    self.num_periodic_table_groups = 1 + 18
    self.num_periodic_table_periods = 1 + 7
    self.num_atomic_numbers = 1 + 118
    self.num_atom_type_input_features = (
        self.num_periodic_table_groups
        + self.num_periodic_table_periods
        + self.num_atomic_numbers
    )

    # Number of nearest neighbor residues to consider for finding atomized
    # side chain atoms.
    self.num_neighbors_for_atomized_side_chain = 16

    # Max distance for finding nearest ligand atom neighbors. The value is
    # very large; presumably it acts as "no cutoff" in practice. Confirm at
    # the use site.
    self.max_distance_for_ligand_atoms = 10000.0

    # Output width of the projection of the atom type features.
    self.num_atom_type_output_features = 64

    # Number of angle features per ligand atom (cos/sin of the azimuth and
    # of the elevation, as returned by construct_angle_features).
    self.num_angle_features = 4

    # Node features (protein-ligand subgraph edge features), concatenated:
    # 1. RBF features for ligand atom to each backbone atom and virtual atom
    # 2. Projected atom type features for the ligand atom
    # 3. Angle features for the ligand atom and the backbone atoms
    self.num_node_input_features = (
        (self.num_backbone_atoms + self.num_virtual_atoms) * self.num_rbf
        + self.num_atom_type_output_features
        + self.num_angle_features
    )

    # Layers for the protein-ligand subgraph edge features.
    self.embed_atom_type_features = nn.Linear(
        self.num_atom_type_input_features,
        self.num_atom_type_output_features,
        bias=True,
    )
    self.node_embedding = nn.Linear(
        self.num_node_input_features, self.num_node_output_features, bias=True
    )
    self.node_norm = nn.LayerNorm(self.num_node_output_features)

    # Layers for the ligand subgraphs: node embeddings from raw atom-type
    # features, edge embeddings from num_rbf RBF features.
    self.ligand_subgraph_node_embedding = nn.Linear(
        self.num_atom_type_input_features, self.num_node_output_features, bias=False
    )
    self.ligand_subgraph_node_norm = nn.LayerNorm(self.num_node_output_features)
    self.ligand_subgraph_edge_embedding = nn.Linear(
        self.num_rbf, self.num_node_output_features, bias=False
    )
    self.ligand_subgraph_edge_norm = nn.LayerNorm(self.num_node_output_features)

    # Atomic number of every entry of SIDE_CHAIN_ATOM_NAMES, in list order
    # (the last 32 atoms in the 37 atom representation).
    self.register_buffer(
        "side_chain_atom_types",
        torch.tensor(
            [
                ELEMENT_NAME_TO_ATOMIC_NUMBER[
                    self.SIDE_CHAIN_ATOM_NAME_TO_ELEMENT_NAME[atom_name]
                ]
                for atom_name in self.SIDE_CHAIN_ATOM_NAMES
            ]
        ),
    )

    # Periodic-table group (1-18) indexed by atomic number (0-118). Index 0
    # is the null entry. Lanthanides and actinides are all assigned group 3.
    self.register_buffer(
        "periodic_table_groups",
        torch.tensor(
            [
                # Null entry.
                0,
                # Period 1: H, He.
                1, 18,
                # Period 2: Li-Ne.
                1, 2, 13, 14, 15, 16, 17, 18,
                # Period 3: Na-Ar.
                1, 2, 13, 14, 15, 16, 17, 18,
                # Period 4: K-Kr.
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                # Period 5: Rb-Xe.
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                # Period 6: Cs, Ba; La-Lu (15 elements, group 3); Hf-Rn.
                1, 2,
                3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                # Period 7: Fr, Ra; Ac-Lr (15 elements, group 3); Rf-Og.
                1, 2,
                3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            ],
            dtype=torch.long,
        ),
    )
    # Periodic-table period (1-7) indexed by atomic number (0-118). Index 0
    # is the null entry.
    self.register_buffer(
        "periodic_table_periods",
        torch.tensor(
            [
                # Null entry.
                0,
                # Period 1 (2 elements).
                1, 1,
                # Period 2 (8 elements).
                2, 2, 2, 2, 2, 2, 2, 2,
                # Period 3 (8 elements).
                3, 3, 3, 3, 3, 3, 3, 3,
                # Period 4 (18 elements).
                4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
                # Period 5 (18 elements).
                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
                # Period 6 (32 elements).
                6, 6, 6, 6, 6, 6, 6, 6,
                6, 6, 6, 6, 6, 6, 6, 6,
                6, 6, 6, 6, 6, 6, 6, 6,
                6, 6, 6, 6, 6, 6, 6, 6,
                # Period 7 (32 elements).
                7, 7, 7, 7, 7, 7, 7, 7,
                7, 7, 7, 7, 7, 7, 7, 7,
                7, 7, 7, 7, 7, 7, 7, 7,
                7, 7, 7, 7, 7, 7, 7, 7,
            ],
            dtype=torch.long,
        ),
    )
|
|
1284
|
+
|
|
1285
|
+
def construct_X_side_chain(self, X, X_m, S):
    """
    Select the side chain atoms (SIDE_CHAIN_ATOM_NAMES) for every residue.

    This is a thin wrapper: the selection itself is delegated to
    self.construct_X_atoms with SIDE_CHAIN_ATOM_NAMES as the atom list.

    Args:
        X (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token, 3] -
            polymer atom coordinates.
        X_m (torch.Tensor): [B, L, self.TOKEN_ENCODING.n_atoms_per_token] -
            polymer atom validity mask.
        S (torch.Tensor): [B, L] - residue sequence.
    Returns:
        tuple: (X_side_chain [B, L, len(self.SIDE_CHAIN_ATOM_NAMES), 3],
            X_m_side_chain [B, L, len(self.SIDE_CHAIN_ATOM_NAMES)]), exactly
            as returned by construct_X_atoms.
    """
    return self.construct_X_atoms(X, X_m, S, self.SIDE_CHAIN_ATOM_NAMES)
|
|
1309
|
+
|
|
1310
|
+
def construct_angle_features(
    self, center_atom, atom_1, atom_2, ligand_subgraph_Y, eps=1e-8
):
    """
    Describe the direction of each ligand atom in a residue-local frame.

    A right-handed orthonormal frame is built at center_atom:
    - e1 points from center_atom towards atom_1,
    - e2 is the (normalised) part of atom_2 - center_atom orthogonal to e1
      (Gram-Schmidt),
    - e3 = e1 x e2.
    Every ligand atom offset (Y - center_atom) is expressed in that frame as
    local (x, y, z).

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph.

    Args:
        center_atom (torch.Tensor): [B, L, 3] - frame origin.
        atom_1 (torch.Tensor): [B, L, 3] - defines the local x axis.
        atom_2 (torch.Tensor): [B, L, 3] - together with atom_1, defines the
            local x-y plane.
        ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - coordinates of the
            M closest ligand atoms to each residue.
        eps (float): Stabiliser, always added. It goes inside the square
            root of the in-plane radius and onto the full vector length, so
            atoms at or directly above the origin do not divide by zero.
    Returns:
        angle_features (torch.Tensor): [B, L, M, 4], with the last axis
        ordered as
            0: x / r_xy  - cosine of the azimuth in the local x-y plane
            1: y / r_xy  - sine of the azimuth
            2: r_xy / r  - cosine of the elevation above the x-y plane
            3: z / r     - sine of the elevation
        where r_xy = sqrt(x^2 + y^2 + eps) and r = |(x, y, z)| + eps.
        Features 2 and 3 are historically called the "inclination" angle.
        They are measured from the x-y plane, not from e3, so they equal
        sin / cos of the polar angle respectively.
    """
    # Frame axes, each [B, L, 3].
    to_first = atom_1 - center_atom
    to_second = atom_2 - center_atom
    axis_x = torch.nn.functional.normalize(to_first, dim=-1)
    # Remove the axis_x component of to_second before normalising.
    along_x = torch.einsum("bli, bli -> bl", axis_x, to_second)[..., None]
    axis_y = torch.nn.functional.normalize(to_second - axis_x * along_x, dim=-1)
    axis_z = torch.cross(axis_x, axis_y, dim=-1)

    # frame [B, L, 3, 3] holds the axes as columns. Contracting over the
    # first (global) index applies frame^T, mapping global offsets to local
    # coordinates.
    frame = torch.stack((axis_x, axis_y, axis_z), dim=-1)
    offsets = ligand_subgraph_Y - center_atom[:, :, None, :]
    local = torch.einsum("blqp, blyq -> blyp", frame, offsets)
    x, y, z = local.unbind(dim=-1)

    radius_xy = torch.sqrt(x**2 + y**2 + eps)
    radius = torch.norm(local, dim=-1) + eps

    return torch.stack(
        (x / radius_xy, y / radius_xy, radius_xy / radius, z / radius), dim=-1
    )
|
|
1453
|
+
|
|
1454
|
+
def gather_nearest_per_residue_atoms(
    self,
    per_residue_ligand_coords,
    per_residue_ligand_mask,
    per_residue_ligand_types,
    X_virtual_atoms,
    X_m_virtual_atoms,
    residue_mask,
):
    """
    Select, for every residue, the candidate atoms nearest to that residue's
    virtual atom and return their coordinates, mask, and element types.

    The candidate set along dim 2 (num_ligand_atoms) depends on the caller:
        - self.gather_nearest_ligand_atoms passes every ligand atom,
          broadcast to each residue (num_ligand_atoms = N).
        - self.combine_ligand_and_atomized_side_chain_atoms passes the
          previously selected ligand atoms concatenated with the atomized
          side chain atoms of neighboring residues.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph. Exactly min(M, num_ligand_atoms) candidates are kept,
    so the outputs have fewer than M entries along dim 2 when fewer
    candidates are available.

    Ranking rules:
        - A (residue, candidate) pair is invalid if either the residue or the
          candidate is masked out.
        - Invalid pairs are ranked with the distance
          self.max_distance_for_ligand_atoms instead of their real distance.
        - The returned mask is gathered from per_residue_ligand_mask only.
          residue_mask influences the ranking but is not folded into it.

    Args:
        per_residue_ligand_coords (torch.Tensor): [B, L, num_ligand_atoms,
            3] - per residue 3D coordinates of the candidate atoms.
        per_residue_ligand_mask (torch.Tensor): [B, L, num_ligand_atoms] -
            per residue mask indicating which candidate atoms are valid.
        per_residue_ligand_types (torch.Tensor): [B, L, num_ligand_atoms] -
            per residue element types (atomic numbers) of the candidates.
        X_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms, 3] - 3D
            coordinates of the virtual atoms for each residue.
        X_m_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms] - mask
            indicating which virtual atoms are valid.
        residue_mask (torch.Tensor): [B, L] - mask indicating which residues
            are valid.
    Returns:
        ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - coordinates of the
            selected atoms, in ascending order of ranking distance
            (torch.topk with its default sorted=True).
        ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - mask values of the
            selected atoms.
        ligand_subgraph_Y_t (torch.Tensor): [B, L, M] - element types of the
            selected atoms.
    """
    _, _, num_candidates, _ = per_residue_ligand_coords.shape
    num_selected = min(self.num_context_atoms, num_candidates)

    # self.construct_X_virtual_atoms provides a single virtual atom per
    # residue, so a mask-weighted sum over dim 2 collapses it to [B, L, 3].
    virtual_center = (X_virtual_atoms * X_m_virtual_atoms.unsqueeze(3)).sum(dim=2)

    # Euclidean distance from each residue's virtual atom to each of its
    # candidates: [B, L, num_ligand_atoms]. Only used for ranking.
    deltas = virtual_center.unsqueeze(2) - per_residue_ligand_coords
    distances = deltas.pow(2).sum(dim=-1).sqrt()

    # A pair is usable only when both the residue and the candidate are valid.
    pair_is_valid = (residue_mask.unsqueeze(2) * per_residue_ligand_mask).bool()

    # Unusable pairs get the sentinel distance, so they rank behind any
    # usable candidate closer than that distance.
    ranking_distances = (
        distances * pair_is_valid
        + (~pair_is_valid) * self.max_distance_for_ligand_atoms
    )

    # Indices of the closest candidates: [B, L, num_selected].
    nearest = torch.topk(
        ranking_distances, num_selected, dim=-1, largest=False
    ).indices

    selected_coords = per_residue_ligand_coords.gather(
        2, nearest.unsqueeze(3).expand(-1, -1, -1, 3)
    )
    selected_mask = per_residue_ligand_mask.gather(2, nearest)
    selected_types = per_residue_ligand_types.gather(2, nearest)

    return selected_coords, selected_mask, selected_types
|
|
1568
|
+
|
|
1569
|
+
def gather_nearest_ligand_atoms(
    self, Y, Y_m, Y_t, X_virtual_atoms, X_m_virtual_atoms, residue_mask
):
    """
    Select, for each residue, the ligand atoms nearest to its virtual atom.

    Every residue considers the full set of N ligand atoms. The [B, N, ...]
    ligand tensors are broadcast along a new residue axis to [B, L, N, ...].
    They are then handed to self.gather_nearest_per_residue_atoms, which
    does the ranking and gathering.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph (min(M, N) atoms are actually returned).

    Args:
        Y (torch.Tensor): [B, N, 3] - 3D coordinates of the ligand atoms.
        Y_m (torch.Tensor): [B, N] - Mask indicating which ligand atoms
            are valid.
        Y_t (torch.Tensor): [B, N] - Element types of the ligand atoms
            (atomic numbers).
        X_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms, 3] - 3D
            coordinates of the virtual atoms for each residue.
        X_m_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms] -
            Mask indicating which virtual atoms are valid.
        residue_mask (torch.Tensor): [B, L] - Mask indicating which residues
            are valid.
    Returns:
        ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - 3D coordinates of
            the nearest ligand atoms to the virtual atoms for each residue.
        ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - Mask indicating
            which nearest ligand atoms to the virtual atoms are valid.
        ligand_subgraph_Y_t (torch.Tensor): [B, L, M] - Element types of
            the nearest ligand atoms to the virtual atoms for each residue.
    """
    _, num_residues, _, _ = X_virtual_atoms.shape

    # expand() broadcasts the new singleton residue axis without copying.
    coords_per_residue = Y.unsqueeze(1).expand(-1, num_residues, -1, -1)
    mask_per_residue = Y_m.unsqueeze(1).expand(-1, num_residues, -1)
    types_per_residue = Y_t.unsqueeze(1).expand(-1, num_residues, -1)

    nearest_Y, nearest_Y_m, nearest_Y_t = self.gather_nearest_per_residue_atoms(
        coords_per_residue,
        mask_per_residue,
        types_per_residue,
        X_virtual_atoms,
        X_m_virtual_atoms,
        residue_mask,
    )
    return nearest_Y, nearest_Y_m, nearest_Y_t
|
|
1625
|
+
|
|
1626
|
+
def gather_nearest_atomized_side_chain_atoms(
    self, X, X_m, S, E_idx, hide_side_chain_mask
):
    """
    Gather the atomized side chain atoms of each residue's nearest neighbors.

    This function collects the side chain atoms of the first K_sc
    neighbors of every residue. Its inputs are:
        - the 3D coordinates of the polymer atoms and their mask,
        - the sequence,
        - the indices of the top K nearest neighbors for each residue,
        - a mask indicating which side chains are hidden.

    The neighbor count is

        K_sc = min(self.num_neighbors_for_atomized_side_chain, K)

    (K = E_idx.shape[2]). The gathered atoms are used to construct the
    atomized side chain atoms for each residue. Atoms from hidden side
    chains are returned with a False mask entry.

    Args:
        X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of the
            polymer atoms.
        X_m (torch.Tensor): [B, L, num_atoms] - Mask indicating which
            polymer atoms are valid.
        S (torch.Tensor): [B, L] - Sequence of the polymer residues.
        E_idx (torch.Tensor): [B, L, K] - Indices of the top K nearest
            neighbors for each residue.
        hide_side_chain_mask (torch.Tensor): [B, L] - Mask indicating which
            residue side chains are hidden and which are revealed. True
            indicates that the side chain is hidden and False indicates
            that the side chain is revealed.
    Returns:
        ligand_subgraph_R (torch.Tensor): [B, L,
            K_sc * self.num_side_chain_atoms, 3] - 3D coordinates of the
            nearest neighbors side chain atoms for each residue.
        ligand_subgraph_R_m (torch.Tensor): [B, L,
            K_sc * self.num_side_chain_atoms] - mask indicating which
            nearest neighbors side chain atoms are valid (and revealed).
        ligand_subgraph_R_t (torch.Tensor): [B, L,
            K_sc * self.num_side_chain_atoms] - Element types of the
            nearest neighbors side chain atoms for each residue.
    """
    B, L, _, _ = X.shape

    # X_side_chain [B, L, len(self.SIDE_CHAIN_ATOM_NAMES), 3] - 3D
    # coordinates of the side chain atoms for each residue.
    # X_m_side_chain [B, L, len(self.SIDE_CHAIN_ATOM_NAMES)] - mask
    # indicating which side chain atoms are valid.
    # NOTE: the side chain atoms exclude the CB atom, since in other
    # places, we use the virtual CB atom. The reshapes below assume
    # len(self.SIDE_CHAIN_ATOM_NAMES) == self.num_side_chain_atoms.
    X_side_chain, X_m_side_chain = self.construct_X_side_chain(X, X_m, S)

    # E_idx_sub [B, L, K_sc] - indices of the nearest neighbors to consider
    # for atomized side chain atoms.
    E_idx_sub = E_idx[:, :, : self.num_neighbors_for_atomized_side_chain]

    # E_idx can have fewer than self.num_neighbors_for_atomized_side_chain
    # columns. NOTE(review): e.g. when the upstream top-K neighbor search is
    # capped at the number of residues; confirm against the caller. Size
    # every output from the slice actually taken; using the configured value
    # makes the reshapes below fail in that case.
    num_neighbors = E_idx_sub.shape[2]
    num_gathered_atoms = num_neighbors * self.num_side_chain_atoms

    # ligand_subgraph_R [B, L, K_sc * self.num_side_chain_atoms, 3] - 3D
    # coordinates of the nearest neighbors side chain atoms.
    # reshape (rather than view) also accepts a non-contiguous X_side_chain.
    ligand_subgraph_R = gather_nodes(
        X_side_chain.reshape(B, L, self.num_side_chain_atoms * 3), E_idx_sub
    ).reshape(B, L, num_gathered_atoms, 3)

    # Mask out hidden side chains before gathering, so that a neighbor's
    # hidden side chain atoms come back with a False mask.
    # ligand_subgraph_R_m [B, L, K_sc * self.num_side_chain_atoms]
    revealed_side_chain_mask = X_m_side_chain & (
        ~(hide_side_chain_mask[:, :, None].bool())
    )
    ligand_subgraph_R_m = gather_nodes(
        revealed_side_chain_mask, E_idx_sub
    ).reshape(B, L, num_gathered_atoms)

    # ligand_subgraph_R_t [B, L, K_sc * self.num_side_chain_atoms] - the
    # fixed per-slot side chain element types, repeated for each neighbor.
    ligand_subgraph_R_t = (
        self.side_chain_atom_types[None, None, None, :]
        .expand(B, L, num_neighbors, -1)
        .reshape(B, L, num_gathered_atoms)
    )

    return ligand_subgraph_R, ligand_subgraph_R_m, ligand_subgraph_R_t
|
|
1712
|
+
|
|
1713
|
+
def combine_ligand_and_atomized_side_chain_atoms(
    self,
    ligand_subgraph_Y,
    ligand_subgraph_Y_m,
    ligand_subgraph_Y_t,
    ligand_subgraph_R,
    ligand_subgraph_R_m,
    ligand_subgraph_R_t,
    X_virtual_atoms,
    X_m_virtual_atoms,
    residue_mask,
):
    """
    Merge the per-residue ligand atoms with the atomized neighbor side chain
    atoms and keep the ones closest to each residue's virtual atom.

    The two candidate sets (coordinates, masks, element types) are
    concatenated along the atom axis. The pool is then re-ranked with
    self.gather_nearest_per_residue_atoms, so ligand and side chain atoms
    compete for the same M slots.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph. At most M atoms are returned.

    Args:
        ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - 3D coordinates of
            the nearest ligand atoms to the virtual atoms for each residue.
        ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - mask indicating
            which nearest ligand atoms to the virtual atoms are valid.
        ligand_subgraph_Y_t (torch.Tensor): [B, L, M] - element types of the
            nearest ligand atoms to the virtual atoms for each residue.
        ligand_subgraph_R (torch.Tensor): [B, L,
            self.num_neighbors_for_atomized_side_chain *
            self.num_side_chain_atoms, 3] - 3D coordinates of the nearest
            neighbors side chain atoms for each residue.
        ligand_subgraph_R_m (torch.Tensor): [B, L,
            self.num_neighbors_for_atomized_side_chain *
            self.num_side_chain_atoms] - mask indicating which nearest
            neighbors side chain atoms are valid.
        ligand_subgraph_R_t (torch.Tensor): [B, L,
            self.num_neighbors_for_atomized_side_chain *
            self.num_side_chain_atoms] - element types of the nearest
            neighbors side chain atoms for each residue.
        X_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms, 3] - 3D
            coordinates of the virtual atoms for each residue.
        X_m_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms] - mask
            indicating which virtual atoms are valid.
        residue_mask (torch.Tensor): [B, L] - mask indicating which
            residues are valid.
    Returns:
        ligand_subgraph_Y_and_R (torch.Tensor): [B, L, M, 3] - 3D
            coordinates of the nearest ligand or side chain atoms to the
            virtual atoms for each residue.
        ligand_subgraph_Y_m_and_R_m (torch.Tensor): [B, L, M] - mask
            indicating which nearest ligand or side chain atoms to the
            virtual atoms are valid.
        ligand_subgraph_Y_t_and_R_t (torch.Tensor): [B, L, M] - element
            types of the nearest ligand or side chain atoms to the virtual
            atoms for each residue.
    """
    # Pool both candidate sets along the atom axis (dim=2): ligand atoms
    # first, then the atomized side chain atoms.
    pooled_coords = torch.cat((ligand_subgraph_Y, ligand_subgraph_R), dim=2)
    pooled_mask = torch.cat((ligand_subgraph_Y_m, ligand_subgraph_R_m), dim=2)
    pooled_types = torch.cat((ligand_subgraph_Y_t, ligand_subgraph_R_t), dim=2)

    # Re-rank the pooled candidates by distance to each residue's virtual
    # atom and keep the closest ones.
    nearest_coords, nearest_mask, nearest_types = (
        self.gather_nearest_per_residue_atoms(
            pooled_coords,
            pooled_mask,
            pooled_types,
            X_virtual_atoms,
            X_m_virtual_atoms,
            residue_mask,
        )
    )
    return nearest_coords, nearest_mask, nearest_types
|
|
1823
|
+
|
|
1824
|
+
def featurize_ligand_atom_type_information(self, ligand_subgraph_Y_t):
    """
    One-hot encode each ligand atom's atomic number, periodic table group,
    and periodic table period, then concatenate the three encodings.

    Group and period are looked up from self.periodic_table_groups and
    self.periodic_table_periods, indexed by atomic number.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph.

    Args:
        ligand_subgraph_Y_t (torch.Tensor): [B, L, M] - element types
            (atomic numbers) of the ligand atoms; cast to long internally.
    Returns:
        ligand_subgraph_Y_t_concat_one_hot (torch.Tensor):
            [B, L, M, self.num_atomic_numbers +
            self.num_periodic_table_groups +
            self.num_periodic_table_periods] - atomic number, periodic
            group, and periodic period one-hots, concatenated in that order.
    """
    atomic_numbers = ligand_subgraph_Y_t.long()

    # Table lookups first: 18 groups + 1 null group, 7 periods + 1 null
    # period. Each result is [B, L, M].
    groups = self.periodic_table_groups[atomic_numbers]
    periods = self.periodic_table_periods[atomic_numbers]

    one_hot = torch.nn.functional.one_hot
    group_one_hot = one_hot(groups, self.num_periodic_table_groups)
    period_one_hot = one_hot(periods, self.num_periodic_table_periods)
    atomic_number_one_hot = one_hot(atomic_numbers, self.num_atomic_numbers)

    # Output layout along the last axis: [atomic number | group | period].
    return torch.cat(
        (atomic_number_one_hot, group_one_hot, period_one_hot), dim=-1
    )
|
|
1889
|
+
|
|
1890
|
+
def featurize_protein_to_ligand_subgraph_edges(
    self,
    ligand_subgraph_Y_t_concat_one_hot,
    X_backbone,
    X_m_backbone,
    X_virtual_atoms,
    X_m_virtual_atoms,
    ligand_subgraph_Y,
    ligand_subgraph_Y_m,
    eps=1e-6,
):
    """
    Build the protein-to-ligand subgraph edge features for every residue.

    Each residue is linked to the M ligand atoms of its subgraph. The raw
    feature of a (residue, ligand atom) pair concatenates, in this order:
        1. RBF embeddings of the distances from the ligand atom to each of
           the residue's backbone and virtual atoms. These are zeroed where
           the ligand atom or the backbone/virtual atom is masked out.
        2. The embedded atom type features of the ligand atom
           (self.embed_atom_type_features).
        3. The angle features from self.construct_angle_features.
    The result is projected with self.node_embedding and normalized with
    self.node_norm.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph.

    Args:
        ligand_subgraph_Y_t_concat_one_hot (torch.Tensor):
            [B, L, M, self.num_atomic_numbers +
            self.num_periodic_table_groups +
            self.num_periodic_table_periods] - atomic number,
            periodic group, and periodic period of the ligand atoms.
        X_backbone (torch.Tensor): [B, L, self.num_backbone_atoms, 3] - 3D
            coordinates of the backbone atoms for each residue.
        X_m_backbone (torch.Tensor): [B, L, self.num_backbone_atoms] - mask
            indicating which backbone atoms are valid.
        X_virtual_atoms (torch.Tensor): [B, L, self.num_virtual_atoms, 3] -
            3D coordinates of the virtual atoms for each residue.
        X_m_virtual_atoms (torch.Tensor): [B, L, self.num_virtual_atoms] -
            mask indicating which virtual atoms are valid.
        ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - 3D coordinates of
            the ligand atoms.
        ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - mask indicating
            which ligand atoms are valid.
        eps (float): Small value added inside the square root of the
            distances so that zero distances stay differentiable.
    Returns:
        E_protein_to_ligand (torch.Tensor):
            [B, L, M, self.num_node_output_features] -
            protein to ligand subgraph edges. They can also be considered
            node features of the protein residues, although they are not
            used as such.
    """
    B, L, M, _ = ligand_subgraph_Y_t_concat_one_hot.shape

    # [B, L, M, self.num_atom_type_output_features]
    atom_type_embedding = self.embed_atom_type_features(
        ligand_subgraph_Y_t_concat_one_hot.float()
    )

    # Reference atoms of each residue: backbone atoms followed by virtual
    # atoms.
    # reference_coords [B, L, num_backbone_atoms + num_virtual_atoms, 3]
    # reference_mask   [B, L, num_backbone_atoms + num_virtual_atoms]
    reference_coords = torch.cat((X_backbone, X_virtual_atoms), dim=2)
    reference_mask = torch.cat((X_m_backbone, X_m_virtual_atoms), dim=2)

    # Distance from each ligand atom to each reference atom:
    # [B, L, M, num_backbone_atoms + num_virtual_atoms].
    displacement = ligand_subgraph_Y.unsqueeze(3) - reference_coords.unsqueeze(2)
    distances = torch.sqrt(displacement.pow(2).sum(dim=-1) + eps)

    # rbf [B, L, M, num_reference_atoms, num_rbf], masked on both ends, then
    # flattened to [B, L, M, num_reference_atoms * num_rbf].
    rbf = self.compute_rbf_embedding_from_distances(distances)
    rbf = (
        rbf
        * ligand_subgraph_Y_m[:, :, :, None, None]
        * reference_mask[:, :, None, :, None]
    )
    rbf = rbf.view(
        B, L, M, (self.num_backbone_atoms + self.num_virtual_atoms) * self.num_rbf
    )

    # Angle features of the ligand atoms relative to each residue's
    # backbone (CA, N, C): [B, L, M, 4].
    backbone_atom_names = self.BACKBONE_ATOM_NAMES
    ca_coords = X_backbone[:, :, backbone_atom_names.index("CA"), :]
    n_coords = X_backbone[:, :, backbone_atom_names.index("N"), :]
    c_coords = X_backbone[:, :, backbone_atom_names.index("C"), :]
    angle_features = self.construct_angle_features(
        ca_coords, n_coords, c_coords, ligand_subgraph_Y
    )

    # [B, L, M, num_reference_atoms * num_rbf +
    #  self.num_atom_type_output_features + 4]
    edge_features = torch.cat((rbf, atom_type_embedding, angle_features), dim=-1)

    # These are edges, but they are projected with the node embedding:
    # [B, L, M, self.num_node_output_features].
    return self.node_norm(self.node_embedding(edge_features))
|
|
2033
|
+
|
|
2034
|
+
def featurize_ligand_subgraph_nodes(self, ligand_subgraph_Y_t_concat_one_hot):
    """
    Turn the one-hot ligand atom type features into ligand subgraph node
    features. The features are cast to float, projected with
    self.ligand_subgraph_node_embedding, and normalized with
    self.ligand_subgraph_node_norm.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph.

    Args:
        ligand_subgraph_Y_t_concat_one_hot (torch.Tensor):
            [B, L, M, self.num_atomic_numbers +
            self.num_periodic_table_groups +
            self.num_periodic_table_periods] - atomic number,
            periodic group, and periodic period of the ligand atoms.
    Returns:
        ligand_subgraph_nodes (torch.Tensor): [B, L, M,
            self.num_node_output_features] - ligand atom type information,
            embedded as node features.
    """
    # The one-hot input is integer-typed; the embedding expects floats.
    embedded = self.ligand_subgraph_node_embedding(
        ligand_subgraph_Y_t_concat_one_hot.float()
    )
    return self.ligand_subgraph_node_norm(embedded)
|
|
2063
|
+
|
|
2064
|
+
def featurize_ligand_subgraph_edges(
    self, ligand_subgraph_Y, ligand_subgraph_Y_m, eps=1e-6
):
    """
    Compute edge features between the ligand atoms within each residue
    subgraph. The steps are:
        1. Compute all pairwise distances.
        2. Expand them with the RBF embedding.
        3. Zero out pairs involving an invalid atom.
        4. Project with self.ligand_subgraph_edge_embedding and normalize
           with self.ligand_subgraph_edge_norm.

    NOTE: M = self.num_context_atoms, the number of ligand atoms in each
    residue subgraph.

    Args:
        ligand_subgraph_Y (torch.Tensor): [B, L, M, 3] - 3D coordinates of
            the ligand atoms.
        ligand_subgraph_Y_m (torch.Tensor): [B, L, M] - mask indicating
            which ligand atoms are valid.
        eps (float): Small value added inside the square root of the
            distances, so that the zero self-distances stay differentiable.
    Returns:
        ligand_subgraph_edges (torch.Tensor):
            [B, L, M, M, self.num_edge_output_features] - embedded and
            normalized radial basis function embedding of the distances
            between the ligand atoms in each residue subgraph.
    """
    # pairwise_distances [B, L, M, M]: distances between every pair of
    # ligand atoms in each residue subgraph.
    pairwise_delta = ligand_subgraph_Y.unsqueeze(3) - ligand_subgraph_Y.unsqueeze(2)
    pairwise_distances = torch.sqrt((pairwise_delta**2).sum(dim=-1) + eps)

    # [B, L, M, M, num_rbf], kept only where both atoms of the pair are valid.
    rbf = self.compute_rbf_embedding_from_distances(pairwise_distances)
    rbf = (
        rbf
        * ligand_subgraph_Y_m[:, :, :, None, None]
        * ligand_subgraph_Y_m[:, :, None, :, None]
    )

    # Project and normalize: [B, L, M, M, self.num_edge_output_features].
    edges = self.ligand_subgraph_edge_embedding(rbf)
    return self.ligand_subgraph_edge_norm(edges)
|
|
2123
|
+
|
|
2124
|
+
def featurize_nodes(self, input_features, edge_features):
    """
    Given the input features and edge features, compute the node features
    for the ligand atoms and the protein to ligand subgraph edges.

    NOTE: N = the total number of ligand atoms.
        M = self.num_context_atoms, the number of ligand atoms in each
            residue subgraph.

    Args:
        input_features (dict): Dictionary containing the input features.
            - X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of the
                polymer atoms.
            - X_m (torch.Tensor): [B, L, num_atoms] - mask indicating which
                polymer atoms are valid.
            - S (torch.Tensor): residue sequence/type tensor forwarded to
                self.gather_nearest_atomized_side_chain_atoms. Only required
                when atomize_side_chains is truthy. (Presumably [B, L];
                TODO confirm against the caller.)
            - hide_side_chain_mask (torch.Tensor): [B, L] - mask
                indicating which residue side chains are hidden and which
                are revealed. True indicates that the side chain is hidden
                and False indicates that the side chain is revealed.
            - Y (torch.Tensor): [B, N, 3] - 3D coordinates of the ligand
                atoms.
            - Y_m (torch.Tensor): [B, N] - mask indicating which ligand
                atoms are valid.
            - Y_t (torch.Tensor): [B, N] - element types of the ligand
                atoms.
            - X_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms, 3] -
                3D coordinates of the virtual atoms for each residue.
            - X_m_virtual_atoms (torch.Tensor): [B, L, num_virtual_atoms] -
                mask indicating which virtual atoms are valid.
            - residue_mask (torch.Tensor): [B, L] - mask indicating which
                residues are valid.
            - X_backbone (torch.Tensor): [B, L, num_backbone_atoms, 3] -
                3D coordinates of the backbone atoms for each residue.
            - X_m_backbone (torch.Tensor): [B, L, num_backbone_atoms] -
                mask indicating which backbone atoms are valid.
            - atomize_side_chains (bool): Whether to atomize the side chains
                of the residues. If True, the side chains of the residues
                not specified in the hide side chain mask will be
                atomized and added as ligand atoms.
        edge_features (dict): Dictionary containing the edge features.
            - E_idx (torch.Tensor): [B, L, K] - indices of the top K
                nearest neighbors for each residue.

    Returns:
        node_features (dict): Dictionary containing the node features.
            - E_protein_to_ligand (torch.Tensor):
                [B, L, M, self.num_node_output_features] - protein to
                ligand subgraph edges; can also be considered node features
                of the protein residues (although they are not used as
                such).
            - ligand_subgraph_nodes (torch.Tensor):
                [B, L, M, self.num_node_output_features] - ligand atom type
                information, embedded as node features.
            - ligand_subgraph_edges (torch.Tensor):
                [B, L, M, M, self.num_edge_output_features] - embedded and
                normalized radial basis function embedding of the distances
                between the ligand atoms in each residue subgraph.

    Raises:
        ValueError: If a required input feature key or the 'E_idx' edge
            feature key is missing, or if atomize_side_chains is truthy and
            the 'S' key is missing.

    Side Effects:
        input_features["ligand_subgraph_Y"] (torch.Tensor): [B, L, M, 3] -
            coordinates of the ligand (and, if atomized, side chain) atoms
            selected for each residue subgraph.
        input_features["ligand_subgraph_Y_m"] (torch.Tensor): [B, L, M] -
            validity mask for those atoms.
        input_features["ligand_subgraph_Y_t"] (torch.Tensor): [B, L, M] -
            element types of those atoms.
    """
    # Check that the needed input features are present. The order matches the
    # original per-key checks, so the first missing key reported is unchanged.
    required_input_keys = (
        "X",
        "X_m",
        "hide_side_chain_mask",
        "Y",
        "Y_m",
        "Y_t",
        "X_virtual_atoms",
        "X_m_virtual_atoms",
        "residue_mask",
        "X_backbone",
        "X_m_backbone",
        "atomize_side_chains",
    )
    for key in required_input_keys:
        if key not in input_features:
            raise ValueError(f"Input features must contain '{key}' key.")

    # Check that the needed edge features are present.
    if "E_idx" not in edge_features:
        raise ValueError("Edge features must contain 'E_idx' key.")

    atomize_side_chains = input_features["atomize_side_chains"]

    # The sequence is only consumed when side chains are atomized. Fail early,
    # in the same ValueError style, rather than with a KeyError after the
    # ligand gather has already run.
    if atomize_side_chains and "S" not in input_features:
        raise ValueError(
            "Input features must contain 'S' key when 'atomize_side_chains' "
            "is True."
        )

    # Gather the coordinates, mask and types of the ligand atoms closest
    # to the virtual atoms.
    # ligand_subgraph_Y [B, L, M, 3] - 3D coordinates of the ligand atoms
    # closest to the virtual atoms for each residue.
    # ligand_subgraph_Y_m [B, L, M] - mask indicating which ligand
    # atoms closest to the virtual atoms are valid.
    # ligand_subgraph_Y_t [B, L, M] - element types of the
    # ligand atoms closest to the virtual atoms for each residue.
    ligand_subgraph_Y, ligand_subgraph_Y_m, ligand_subgraph_Y_t = (
        self.gather_nearest_ligand_atoms(
            input_features["Y"],
            input_features["Y_m"],
            input_features["Y_t"],
            input_features["X_virtual_atoms"],
            input_features["X_m_virtual_atoms"],
            input_features["residue_mask"],
        )
    )

    # Add atomized side chain atoms as ligand atoms.
    if atomize_side_chains:
        # Gather the atomized side chain atoms' coordinates, mask, and types.
        # ligand_subgraph_R [B, L, num_neighbors_for_atomized_side_chain *
        # num_side_chain_atoms, 3] - 3D coordinates of the nearest neighbors'
        # side chain atoms for each residue.
        # ligand_subgraph_R_m [B, L, num_neighbors_for_atomized_side_chain *
        # num_side_chain_atoms] - mask indicating which nearest neighbors'
        # side chain atoms are valid.
        # ligand_subgraph_R_t [B, L, num_neighbors_for_atomized_side_chain *
        # num_side_chain_atoms] - element types of the nearest neighbors'
        # side chain atoms for each residue.
        ligand_subgraph_R, ligand_subgraph_R_m, ligand_subgraph_R_t = (
            self.gather_nearest_atomized_side_chain_atoms(
                input_features["X"],
                input_features["X_m"],
                input_features["S"],
                edge_features["E_idx"],
                input_features["hide_side_chain_mask"],
            )
        )

        # Get the self.num_context_atoms closest ligand or atomized side
        # chain atoms to the virtual atoms, overwriting the original
        # ligand_subgraph_Y, ligand_subgraph_Y_m, and ligand_subgraph_Y_t.
        ligand_subgraph_Y, ligand_subgraph_Y_m, ligand_subgraph_Y_t = (
            self.combine_ligand_and_atomized_side_chain_atoms(
                ligand_subgraph_Y,
                ligand_subgraph_Y_m,
                ligand_subgraph_Y_t,
                ligand_subgraph_R,
                ligand_subgraph_R_m,
                ligand_subgraph_R_t,
                input_features["X_virtual_atoms"],
                input_features["X_m_virtual_atoms"],
                input_features["residue_mask"],
            )
        )

    # Save the ligand subgraph coordinates, mask, and types in the input
    # features (callers may read them back after this call).
    input_features["ligand_subgraph_Y"] = ligand_subgraph_Y
    input_features["ligand_subgraph_Y_m"] = ligand_subgraph_Y_m
    input_features["ligand_subgraph_Y_t"] = ligand_subgraph_Y_t

    # Get the concatenated one hot type information for the ligand atoms.
    # ligand_subgraph_Y_t_concat_one_hot [B, L, M, self.num_atomic_numbers +
    # self.num_periodic_table_groups + self.num_periodic_table_periods] -
    # atomic number, periodic group, and periodic period of the ligand
    # atoms.
    ligand_subgraph_Y_t_concat_one_hot = (
        self.featurize_ligand_atom_type_information(ligand_subgraph_Y_t)
    )

    # Get the protein to ligand subgraph edges.
    # E_protein_to_ligand [B, L, M, self.num_node_output_features] - protein
    # to ligand subgraph edges.
    E_protein_to_ligand = self.featurize_protein_to_ligand_subgraph_edges(
        ligand_subgraph_Y_t_concat_one_hot,
        input_features["X_backbone"],
        input_features["X_m_backbone"],
        input_features["X_virtual_atoms"],
        input_features["X_m_virtual_atoms"],
        ligand_subgraph_Y,
        ligand_subgraph_Y_m,
    )

    # ligand_subgraph_nodes [B, L, M, self.num_node_output_features] -
    # ligand atom type information, embedded as node features.
    ligand_subgraph_nodes = self.featurize_ligand_subgraph_nodes(
        ligand_subgraph_Y_t_concat_one_hot
    )

    # ligand_subgraph_edges [B, L, M, M, self.num_edge_output_features] -
    # embedded and normalized radial basis function embedding of the
    # distances between the ligand atoms in each residue subgraph.
    ligand_subgraph_edges = self.featurize_ligand_subgraph_edges(
        ligand_subgraph_Y, ligand_subgraph_Y_m
    )

    # Gather the node features.
    node_features = {
        "E_protein_to_ligand": E_protein_to_ligand,
        "ligand_subgraph_nodes": ligand_subgraph_nodes,
        "ligand_subgraph_edges": ligand_subgraph_edges,
    }

    return node_features
|
|
2322
|
+
|
|
2323
|
+
def noise_structure(self, input_features):
    """
    Perturb the polymer and ligand coordinates with Gaussian noise, in place.

    The clean coordinates are always snapshotted first, whether or not any
    noise is applied. Callers can therefore rely on "X_pre_noise" and
    "Y_pre_noise" existing after this call. No mask is consulted, so every
    coordinate entry (including any padded/invalid ones) is perturbed when
    noise is enabled.

    Args:
        input_features (dict): Dictionary containing the input features.
            - X (torch.Tensor): [B, L, num_atoms, 3] - 3D coordinates of
                polymer atoms.
            - Y (torch.Tensor): [B, N, 3] - 3D coordinates of the ligand
                atoms.
            - structure_noise (float): Standard deviation of the Gaussian
                noise, in Angstroms. Noise is added only if it is > 0.

    Raises:
        ValueError: If "X", "Y", or "structure_noise" is missing.

    Side Effects:
        input_features["X_pre_noise"] (torch.Tensor): [B, L, num_atoms, 3] -
            copy of the polymer coordinates before any noise.
        input_features["Y_pre_noise"] (torch.Tensor): [B, N, 3] - copy of
            the ligand coordinates before any noise.
        input_features["X"] / input_features["Y"]: replaced by noisy
            tensors when structure_noise > 0; left untouched otherwise.
    """
    for required_key in ("X", "Y", "structure_noise"):
        if required_key not in input_features:
            raise ValueError(f"Input features must contain '{required_key}' key.")

    sigma = input_features["structure_noise"]
    # Evaluate the noise switch before touching input_features.
    apply_noise = bool(sigma > 0)

    # Snapshot the clean coordinates (needed on both paths).
    input_features["X_pre_noise"] = input_features["X"].clone()
    input_features["Y_pre_noise"] = input_features["Y"].clone()

    if not apply_noise:
        return

    # X is perturbed before Y, so the order of random draws is fixed.
    for coord_key in ("X", "Y"):
        coords = input_features[coord_key]
        input_features[coord_key] = coords + sigma * torch.randn_like(coords)
|