rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,673 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
import einops
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import torch
|
|
8
|
+
import tree
|
|
9
|
+
from beartype.typing import Any
|
|
10
|
+
from biotite.structure import AtomArray, AtomArrayStack
|
|
11
|
+
from omegaconf import DictConfig
|
|
12
|
+
from rf3.chemical import NHEAVY
|
|
13
|
+
from rf3.metrics.metric_utils import (
|
|
14
|
+
compute_mean_over_subsampled_pairs,
|
|
15
|
+
compute_min_over_subsampled_pairs,
|
|
16
|
+
create_chainwise_masks_1d,
|
|
17
|
+
create_chainwise_masks_2d,
|
|
18
|
+
create_interface_masks_2d,
|
|
19
|
+
spread_batch_into_dictionary,
|
|
20
|
+
unbin_logits,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_mean_atomwise_plddt(
    plddt_logits: torch.Tensor,
    is_real_atom: torch.Tensor,
    max_value: float,
) -> torch.Tensor:
    """Aggregate per-atom pLDDT predictions into one mean value per batch element.

    Args:
        plddt_logits: Tensor of shape [B, n_token, max_atoms_in_a_token * n_bins] with logits.
            ``max_atoms_in_a_token`` is taken to be ``NHEAVY``; ``n_bins`` is inferred from the
            size of the last dimension.
        is_real_atom: Boolean mask indicating which atom slots are real (i.e., not padding).
            Either unbatched, [n_token, n_atom_slots] (shared by every batch element), or
            batched, [B, n_token, n_atom_slots]. Only the first ``NHEAVY`` atom slots (last
            dimension) are used.
        max_value: Maximum value for pLDDT (assigned to the last bin).

    Returns:
        plddt: Tensor of shape [B,] with the mean atom-wise pLDDT for each batch element.
    """
    assert (
        plddt_logits.ndim == 3
    ), "plddt_logits must be a 3D tensor (B, n_token, max_atoms_in_a_token * n_bins)"

    # TODO: Replace with the last dimension of is_real_atom; right now that number is too large (36) because it includes hydrogens
    max_atoms_in_a_token = NHEAVY

    # Since the pLDDT logits have the last dimension (max_atoms_in_a_token * n_bins), we can calculate n_bins directly
    assert (
        plddt_logits.shape[-1] % max_atoms_in_a_token == 0
    ), "The last dimension of plddt_logits must be divisible by max_atoms_in_a_token!"
    n_bins = plddt_logits.shape[-1] // max_atoms_in_a_token

    # ... reshape to match what unbin_logits expects
    reshaped_plddt_logits = einops.rearrange(
        plddt_logits,
        "... n_token (max_atoms_in_a_token n_bins) -> ... n_bins n_token max_atoms_in_a_token",
        max_atoms_in_a_token=max_atoms_in_a_token,
        n_bins=n_bins,
    ).float()  # [..., n_token, n_bins * max_atoms_in_a_token] -> [ ..., n_bins, n_token, max_atoms_in_a_token]

    # Bins are collapsed here; the result is expected to be [B, n_token, max_atoms_in_a_token]
    plddt = unbin_logits(
        reshaped_plddt_logits,
        max_value,
        n_bins,
    )

    # Slice the atom (last) axis. Slicing axis 1 would truncate tokens when the mask is batched.
    mask = is_real_atom.to(device=plddt.device)[..., :max_atoms_in_a_token]
    if mask.ndim == 2:
        # Unbatched mask: broadcast the same per-token atom mask over every batch element
        mask = mask.unsqueeze(0)

    # ... average only over real (non-padding) atoms
    atomwise_plddt_mean = (plddt * mask).sum(dim=(1, 2)) / mask.sum(dim=(1, 2))

    return atomwise_plddt_mean
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def compile_af3_confidence_outputs(
    plddt_logits: torch.Tensor,
    pae_logits: torch.Tensor,
    pde_logits: torch.Tensor,
    chain_iid_token_lvl: torch.Tensor,
    is_real_atom: torch.Tensor,
    example_id: str,
    confidence_loss_cfg: DictConfig | dict,
) -> dict[str, Any]:
    """Given the confidence logits, computes the confidence metrics for the model's predictions.

    Args:
        plddt_logits: pLDDT logits. They are reshaped to (B, n_token, NHEAVY, n_bins) before
            unbinning, so the last dimension must equal NHEAVY * confidence_loss_cfg.plddt.n_bins.
        pae_logits: pAE logits with the bins in the last dimension (moved to dim 1 for unbinning).
        pde_logits: pDE logits with the bins in the last dimension (moved to dim 1 for unbinning).
        chain_iid_token_lvl: Token-level chain instance IDs. Torch tensors are copied to the host
            and converted to NumPy, because ``np.unique`` is used to lay out the rows.
        is_real_atom: Mask of real (non-padding) atom slots. Only the first NHEAVY slots along the
            last dimension are used.
        example_id: Identifier copied into every output row.
        confidence_loss_cfg: Config with ``plddt``, ``pae`` and ``pde`` sections, each providing
            ``max_value`` and ``n_bins``. Plain dicts are wrapped in a ``DictConfig``.

    Returns:
        dict[str, Any]: A dictionary containing the following:
            - confidence_df: A DataFrame with one row per (batch element, chain) and one row per
              (batch element, chain pair), holding the chain- and interface-level aggregates
            - plddt: The unbinned pLDDT values (not logits)
            - pae: The unbinned pAE values (not logits)
            - pde: The unbinned pDE values (not logits)
    """
    # TODO: Refactor to accept an AtomArray
    # TODO: Taking the confidence_loss_cfg does not align with functional programming best-practices; we should instead take the max_value and n_bins as arguments

    if isinstance(confidence_loss_cfg, dict):
        # The config is read via attribute access below, which plain dicts do not support
        confidence_loss_cfg = DictConfig(confidence_loss_cfg)

    # np.unique (used below to enumerate chains) needs host-side data
    if isinstance(chain_iid_token_lvl, torch.Tensor):
        chain_iid_token_lvl = chain_iid_token_lvl.cpu().numpy()

    # Reorder the input tensors to be in (B, n_bins, ...) format for unbinning
    plddt = unbin_logits(
        plddt_logits.reshape(
            -1,
            plddt_logits.shape[1],
            NHEAVY,
            confidence_loss_cfg.plddt.n_bins,
        )
        .permute(0, 3, 1, 2)
        .float(),
        confidence_loss_cfg.plddt.max_value,
        confidence_loss_cfg.plddt.n_bins,
    )

    # Unbin the pae and pde logits
    pae = unbin_logits(
        pae_logits.permute(0, 3, 1, 2).float(),
        confidence_loss_cfg.pae.max_value,
        confidence_loss_cfg.pae.n_bins,
    )
    pde = unbin_logits(
        pde_logits.permute(0, 3, 1, 2).float(),
        confidence_loss_cfg.pde.max_value,
        confidence_loss_cfg.pde.n_bins,
    )

    # Keep the atom mask on the same device as the unbinned pLDDT values
    is_real_atom = is_real_atom.to(device=plddt.device)
    real_atom_mask = is_real_atom[..., :NHEAVY]

    def _reduce_per_mask(reducer, values, masks):
        """Apply ``reducer(values, mask)`` for every mask and spread the batch dim into a dict."""
        return {
            key: spread_batch_into_dictionary(reducer(values, mask))
            for key, mask in masks.items()
        }

    # Calculate interface metrics
    interface_masks = create_interface_masks_2d(chain_iid_token_lvl, device=pae.device)
    pae_interface = _reduce_per_mask(compute_mean_over_subsampled_pairs, pae, interface_masks)
    pde_interface = _reduce_per_mask(compute_mean_over_subsampled_pairs, pde, interface_masks)
    pae_interface_min = _reduce_per_mask(compute_min_over_subsampled_pairs, pae, interface_masks)
    pde_interface_min = _reduce_per_mask(compute_min_over_subsampled_pairs, pde, interface_masks)

    # Calculate chainwise metrics
    chain_masks_2d = create_chainwise_masks_2d(chain_iid_token_lvl, device=pae.device)
    pae_chainwise = _reduce_per_mask(compute_mean_over_subsampled_pairs, pae, chain_masks_2d)
    pde_chainwise = _reduce_per_mask(compute_mean_over_subsampled_pairs, pde, chain_masks_2d)

    chain_masks_1d = create_chainwise_masks_1d(
        chain_iid_token_lvl, device=is_real_atom.device
    )
    # Restrict the real-atom mask to the tokens of each chain
    plddt_chainwise = {
        k: spread_batch_into_dictionary(
            compute_mean_over_subsampled_pairs(plddt, real_atom_mask * v[:, None])
        )
        for k, v in chain_masks_1d.items()
    }

    # Aggregate confidence data
    confidence_data = {
        "example_id": example_id,
        "mean_plddt": spread_batch_into_dictionary(
            compute_mean_over_subsampled_pairs(plddt, real_atom_mask)
        ),
        "mean_pae": spread_batch_into_dictionary(pae.mean(dim=(-1, -2))),
        "mean_pde": spread_batch_into_dictionary(pde.mean(dim=(-1, -2))),
        "chain_wise_mean_plddt": plddt_chainwise,
        "chain_wise_mean_pae": pae_chainwise,
        "chain_wise_mean_pde": pde_chainwise,
        "interface_wise_mean_pae": pae_interface,
        "interface_wise_mean_pde": pde_interface,
        "interface_wise_min_pae": pae_interface_min,
        "interface_wise_min_pde": pde_interface_min,
    }

    # Generate DataFrame rows
    num_batches = plddt.shape[0]
    chains = np.unique(chain_iid_token_lvl)
    chain_pairs = list(itertools.combinations(chains, 2))

    # For every batch, chain, and interface (chain pair), generate a dataframe row
    chain_rows = [
        {
            "example_id": example_id,
            "chain_chainwise": chain,
            "chainwise_plddt": confidence_data["chain_wise_mean_plddt"][chain][
                batch_idx
            ],
            "chainwise_pde": confidence_data["chain_wise_mean_pde"][chain][batch_idx],
            "chainwise_pae": confidence_data["chain_wise_mean_pae"][chain][batch_idx],
            "overall_plddt": confidence_data["mean_plddt"][batch_idx],
            "overall_pde": confidence_data["mean_pde"][batch_idx],
            "overall_pae": confidence_data["mean_pae"][batch_idx],
            "batch_idx": batch_idx,
        }
        for batch_idx in range(num_batches)
        for chain in chains
    ]

    interface_rows = [
        {
            "example_id": example_id,
            "chain_i_interface": chain_i,
            "chain_j_interface": chain_j,
            "pae_interface": confidence_data["interface_wise_mean_pae"][
                (chain_i, chain_j)
            ][batch_idx],
            "pde_interface": confidence_data["interface_wise_mean_pde"][
                (chain_i, chain_j)
            ][batch_idx],
            "min_pae_interface": confidence_data["interface_wise_min_pae"][
                (chain_i, chain_j)
            ][batch_idx],
            "min_pde_interface": confidence_data["interface_wise_min_pde"][
                (chain_i, chain_j)
            ][batch_idx],
            "overall_plddt": confidence_data["mean_plddt"][batch_idx],
            "overall_pde": confidence_data["mean_pde"][batch_idx],
            "overall_pae": confidence_data["mean_pae"][batch_idx],
            "batch_idx": batch_idx,
        }
        for batch_idx in range(num_batches)
        for (chain_i, chain_j) in chain_pairs
    ]

    return {
        "confidence_df": pd.DataFrame([*chain_rows, *interface_rows]),
        "plddt": plddt,
        "pae": pae,
        "pde": pde,
    }
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def compile_af3_style_confidence_outputs(
    plddt_logits: torch.Tensor,
    pae_logits: torch.Tensor,
    pde_logits: torch.Tensor,
    chain_iid_token_lvl: torch.Tensor | np.ndarray,
    is_real_atom: torch.Tensor,
    atom_array: AtomArray,
    confidence_loss_cfg: DictConfig | dict,
    batch_idx: int = 0,
) -> dict[str, Any]:
    """Compile confidence outputs in AlphaFold3-compatible format.

    Args:
        plddt_logits: pLDDT logits. They are reshaped to (B, n_token, NHEAVY, n_bins) before
            unbinning.
        pae_logits: pAE logits with the bins in the last dimension.
        pde_logits: pDE logits with the bins in the last dimension.
        chain_iid_token_lvl: Token-level chain instance IDs. Torch tensors are converted to NumPy.
        is_real_atom: Mask of real (non-padding) atom slots. Only the first NHEAVY slots along the
            last dimension are used.
        atom_array: Structure whose ``chain_id`` annotation supplies the per-atom chain IDs.
        confidence_loss_cfg: Config with ``plddt``, ``pae`` and ``pde`` sections, each providing
            ``max_value`` and ``n_bins``. Plain dicts are wrapped in a ``DictConfig``.
        batch_idx: Which batch element the JSON-style summaries describe.

    Returns a dict with:
        - summary_confidences: Dict for {name}_summary_confidences.json
        - confidences: Dict for {name}_confidences.json (per-atom data)
        - plddt, pae, pde: Unbinned tensors (all batch elements) for further processing
    """
    if isinstance(confidence_loss_cfg, dict):
        # The config is read via attribute access below, which plain dicts do not support
        confidence_loss_cfg = DictConfig(confidence_loss_cfg)

    # Unbin logits (bins are moved to dim 1 first)
    plddt = unbin_logits(
        plddt_logits.reshape(
            -1,
            plddt_logits.shape[1],
            NHEAVY,
            confidence_loss_cfg.plddt.n_bins,
        )
        .permute(0, 3, 1, 2)
        .float(),
        confidence_loss_cfg.plddt.max_value,
        confidence_loss_cfg.plddt.n_bins,
    )

    pae = unbin_logits(
        pae_logits.permute(0, 3, 1, 2).float(),
        confidence_loss_cfg.pae.max_value,
        confidence_loss_cfg.pae.n_bins,
    )
    pde = unbin_logits(
        pde_logits.permute(0, 3, 1, 2).float(),
        confidence_loss_cfg.pde.max_value,
        confidence_loss_cfg.pde.n_bins,
    )

    # Keep the atom mask on the same device as the unbinned pLDDT values
    is_real_atom = is_real_atom.to(device=plddt.device)
    real_atom_mask = is_real_atom[..., :NHEAVY]

    # Get chain information
    if isinstance(chain_iid_token_lvl, torch.Tensor):
        chain_iid_token_lvl = chain_iid_token_lvl.cpu().numpy()
    chains = list(np.unique(chain_iid_token_lvl))
    n_chains = len(chains)

    # Chain-level pLDDT: real atoms restricted to each chain's tokens
    chain_masks_1d = create_chainwise_masks_1d(
        chain_iid_token_lvl, device=is_real_atom.device
    )
    chain_plddt = {
        chain: compute_mean_over_subsampled_pairs(
            plddt, real_atom_mask * mask[:, None]
        )[batch_idx].item()
        for chain, mask in chain_masks_1d.items()
    }

    # Chain-pair PAE/PDE (inter-chain, for iptm-like metric)
    interface_masks = create_interface_masks_2d(chain_iid_token_lvl, device=pae.device)
    chain_pair_pae = {}
    chain_pair_pae_min = {}
    chain_pair_pde = {}
    chain_pair_pde_min = {}
    for (chain_i, chain_j), mask in interface_masks.items():
        key = (chain_i, chain_j)
        chain_pair_pae[key] = compute_mean_over_subsampled_pairs(pae, mask)[
            batch_idx
        ].item()
        chain_pair_pae_min[key] = compute_min_over_subsampled_pairs(pae, mask)[
            batch_idx
        ].item()
        chain_pair_pde[key] = compute_mean_over_subsampled_pairs(pde, mask)[
            batch_idx
        ].item()
        chain_pair_pde_min[key] = compute_min_over_subsampled_pairs(pde, mask)[
            batch_idx
        ].item()

    # Overall metrics for this batch
    overall_plddt = compute_mean_over_subsampled_pairs(plddt, real_atom_mask)[
        batch_idx
    ].item()
    overall_pae = pae[batch_idx].mean().item()
    overall_pde = pde[batch_idx].mean().item()

    def _chain_pair_matrix(values: dict) -> list:
        """Lay out per-pair values as an NxN list, None on the diagonal and for absent pairs."""
        matrix = [[None] * n_chains for _ in range(n_chains)]
        for i, chain_i in enumerate(chains):
            for j, chain_j in enumerate(chains):
                if i != j and (chain_i, chain_j) in values:
                    matrix[i][j] = round(values[(chain_i, chain_j)], 2)
        return matrix

    # Extract per-atom pLDDT values
    atom_plddts = plddt[batch_idx][real_atom_mask].cpu().tolist()

    # Extract atom/token chain and residue info from atom_array.
    # .tolist() yields native Python values, so the result is JSON-serializable.
    atom_chain_ids = atom_array.chain_id.tolist()
    token_chain_ids = np.asarray(chain_iid_token_lvl).tolist()
    token_res_ids = list(
        range(len(chain_iid_token_lvl))
    )  # Simplified; could map to actual res_id

    # PAE matrix for this batch
    pae_matrix = pae[batch_idx].cpu().tolist()

    chain_plddt_values = [round(chain_plddt.get(c, 0.0), 2) for c in chains]

    # Build summary_confidences (AlphaFold3-style + RF3 extensions)
    summary_confidences = {
        # NOTE: despite its AF3-derived name, "chain_ptm" holds the chain-mean pLDDT computed
        # above (no pTM is computed in this function). The key is kept for backward
        # compatibility, and the same values are also exposed as "chain_plddt".
        "chain_ptm": chain_plddt_values,
        "chain_plddt": list(chain_plddt_values),
        "chain_pair_pae_min": _chain_pair_matrix(chain_pair_pae_min),
        "chain_pair_pde_min": _chain_pair_matrix(chain_pair_pde_min),
        "chain_pair_pae": _chain_pair_matrix(chain_pair_pae),
        "chain_pair_pde": _chain_pair_matrix(chain_pair_pde),
        "overall_plddt": round(overall_plddt, 4),
        "overall_pde": round(overall_pde, 4),
        "overall_pae": round(overall_pae, 4),
        # Note: ptm, iptm, has_clash should be populated from metrics_output
    }

    # Build full confidences (per-atom data)
    confidences = {
        "atom_chain_ids": atom_chain_ids,
        "atom_plddts": [round(p, 2) for p in atom_plddts],
        "pae": [[round(v, 2) for v in row] for row in pae_matrix],
        "token_chain_ids": token_chain_ids,
        "token_res_ids": token_res_ids,
    }

    return {
        "summary_confidences": summary_confidences,
        "confidences": confidences,
        "plddt": plddt,
        "pae": pae,
        "pde": pde,
    }
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def compute_batch_indices_with_lowest_predicted_error(
    plddt: torch.Tensor,
    is_real_atom: torch.Tensor,
    pae: torch.Tensor,
    confidence_loss_cfg: dict | DictConfig,
    chain_iid_token_lvl: torch.Tensor,
    is_ligand: torch.Tensor,
    interfaces_to_score: list[tuple],
    pn_units_to_score: list[tuple],
    pde: torch.Tensor | None = None,
) -> dict[str, Any]:
    """Given the confidence logits, computes the index within the diffusion batch of the best predicted structure.

    Metrics include pAE, pLDDT, and pDE, among others.

    Args:
        plddt: pLDDT *logits*. They are reshaped to (B, n_token, NHEAVY, n_bins) and unbinned here.
        is_real_atom: Mask of real atom slots. Only the first NHEAVY slots of the last dim are used.
        pae: pAE *logits* with the bins in the last dimension.
        confidence_loss_cfg: Config with ``plddt``, ``pae`` and ``pde`` sections, each providing
            ``max_value`` and ``n_bins``. Plain dicts are wrapped in a ``DictConfig``.
        chain_iid_token_lvl: Token-level chain labels, forwarded to the mask builders.
        is_ligand: Per-token ligand flags, forwarded to ``_create_interface_masks``.
        interfaces_to_score: Pairs of units whose interfaces should be scored.
        pn_units_to_score: Entries whose first element names a chain to score.
        pde: Optional pDE *logits* with the bins in the last dimension. When omitted, ``pde_idx``
            falls back to the legacy behaviour of unbinning the pAE logits with the pDE bin
            settings, which is not a true pDE ranking.

    Returns:
        dict[str, Any]: A dictionary containing the following keys:
            - pae_idx: The index within the diffusion batch of the structure with the best overall pAE (Predicted Aligned Error)
            - pde_idx: The index within the diffusion batch of the structure with the best overall pDE (Predicted Distance Error)
            - plddt_idx: The index within the diffusion batch of the structure with the best overall pLDDT (Predicted Local Distance
                Difference Test)
            - best_chain_to_all_idx: The index within the diffusion batch of the structure with the best pAE subsampled over any
                pair (i,j) where i == chain or j == chain
            - best_chain_to_self_idx: The index within the diffusion batch of the structure with the best pAE subsampled over any
                pair (i,j) where i == chain and j == chain
            - best_interface_idx: For each interface between two scored PN Units, the index within the diffusion batch of the
                structure with the best mean pAE for all (i,j) where i == interface_chain or j == interface_chain and i != j
            - best_lig_ipae_idx: The index within the diffusion batch for the best pAE subsampled over any pair (i,j)
                where i == chain or j == chain and i != j and i or j is a ligand
    """
    # TODO: Have this function take an `AtomArray` as input so we quickly build masks with much less code
    # TODO: Explore how we can write this function more concisely
    if isinstance(confidence_loss_cfg, dict):
        # The config is read via attribute access below, which plain dicts do not support
        confidence_loss_cfg = DictConfig(confidence_loss_cfg)

    return_dict = {}

    # AF3's ranking metrics work like this, but using ptm instead of ipae:
    scored_chains, interfaces, interface_chains = _select_scored_units(
        interfaces_to_score, pn_units_to_score
    )

    chain_to_all_masks = _create_chain_to_all_masks(chain_iid_token_lvl, scored_chains)
    chain_to_self_masks = _create_chain_to_self_masks(
        chain_iid_token_lvl, scored_chains
    )
    interface_masks, lig_chains = _create_interface_masks(
        chain_iid_token_lvl, interfaces, is_ligand
    )

    # Map every mask (including the atom mask) onto the device holding the logits
    gpu = plddt.device

    def _to_device(x):
        return x.to(gpu) if hasattr(x, "cpu") else x

    chain_to_all_masks = tree.map_structure(_to_device, chain_to_all_masks)
    chain_to_self_masks = tree.map_structure(_to_device, chain_to_self_masks)
    interface_masks = tree.map_structure(_to_device, interface_masks)
    real_atom_mask = is_real_atom[..., :NHEAVY].to(gpu)

    # Reshape pLDDT logits to (B, n_bins, n_token, NHEAVY) for unbinning
    plddt_logits = (
        plddt.reshape(
            -1,
            plddt.shape[1],
            NHEAVY,
            confidence_loss_cfg.plddt.n_bins,
        )
        .permute(0, 3, 1, 2)
        .float()
    )
    # Move the bin dimension of the pae and pde logits to dim 1 for unbinning
    pae_logits = pae.permute(0, 3, 1, 2).float()
    if pde is not None:
        pde_logits = pde.permute(0, 3, 1, 2).float()
    else:
        # Legacy fallback for callers that do not pass pDE logits: reuse the pAE logits
        pde_logits = pae_logits

    pae_logits_unbinned = unbin_logits(
        pae_logits, confidence_loss_cfg.pae.max_value, confidence_loss_cfg.pae.n_bins
    )
    plddt_logits_unbinned = unbin_logits(
        plddt_logits,
        confidence_loss_cfg.plddt.max_value,
        confidence_loss_cfg.plddt.n_bins,
    )
    pde_logits_unbinned = unbin_logits(
        pde_logits, confidence_loss_cfg.pde.max_value, confidence_loss_cfg.pde.n_bins
    )

    # Whole-complex scores: one value per diffusion sample
    complex_pae = pae_logits_unbinned.mean(dim=(1, 2))
    complex_pde = pde_logits_unbinned.mean(dim=(1, 2))
    complex_plddt = (plddt_logits_unbinned * real_atom_mask).sum(
        dim=(1, 2)
    ) / real_atom_mask.sum()

    # Lower error is better for pAE/pDE; higher confidence is better for pLDDT
    return_dict["pae_idx"] = torch.argmin(complex_pae)
    return_dict["pde_idx"] = torch.argmin(complex_pde)
    return_dict["plddt_idx"] = torch.argmax(complex_plddt)

    chain_to_self_paes = _get_masked_error_per_chain(
        scored_chains, chain_to_self_masks, pae_logits_unbinned
    )
    chain_to_all_paes = _get_masked_error_per_chain(
        scored_chains, chain_to_all_masks, pae_logits_unbinned
    )
    interface_chain_paes = _get_masked_error_per_chain(
        interface_chains, interface_masks, pae_logits_unbinned
    )
    # average over both interfaces
    average_interface_paes = _get_average_error_per_interface(
        interfaces, lig_chains, interface_chain_paes
    )

    return_dict["best_chain_to_all_idx"] = _get_lowest_error_indices(chain_to_all_paes)
    return_dict["best_chain_to_self_idx"] = _get_lowest_error_indices(
        chain_to_self_paes
    )
    return_dict["best_interface_idx"] = _get_lowest_error_indices(
        average_interface_paes
    )
    # for ligands, we don't average the error
    return_dict["best_lig_ipae_idx"] = _get_lowest_error_ligand_indices(
        interface_chain_paes, interfaces, lig_chains
    )
    return return_dict
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def annotate_atom_array_b_factor_with_plddt(
    atom_array: AtomArray | AtomArrayStack,
    plddt: torch.Tensor,
    is_real_atom: torch.Tensor,
) -> List[AtomArray]:
    """Write per-atom pLDDT values into the ``b_factor`` annotation of each structure.

    Args:
        atom_array: The AtomArray or AtomArrayStack to annotate. For an AtomArrayStack, model ``i``
            receives ``plddt[i]``. A single AtomArray requires ``plddt`` to have batch size 1.
        plddt: The pLDDT tensor of shape (B, I, NHEAVY)
        is_real_atom: A mask indicating which atoms are in the structure, of shape (I, NHEAVY).
            Atom slots beyond NHEAVY are ignored.

    Returns:
        list[AtomArray]: The annotated list of AtomArrays. We must return a list of AtomArrays
            because the AtomArray class does not support setting different values as annotations
            other than the coordinate feature.

    Raises:
        AssertionError: If the number of real atoms differs from ``atom_array.array_length()``,
            if a single AtomArray is paired with more than one pLDDT row, or if any resulting
            b_factor is NaN.
    """
    real_atom_mask = is_real_atom[..., :NHEAVY].to(device=plddt.device)
    # Detach so tensors that still track gradients can be converted to NumPy
    atom_wise_plddt = plddt[:, real_atom_mask].detach().cpu().numpy()
    assert atom_wise_plddt.shape[1] == atom_array.array_length()

    # biotite's AtomArray does not support setting different values as annotations other than
    # the coordinate feature, so we convert atom_array to a list of AtomArrays
    if isinstance(atom_array, AtomArrayStack):
        atom_array_list = []
        for i, aa in enumerate(atom_array):
            aa.set_annotation("b_factor", atom_wise_plddt[i])
            atom_array_list.append(aa)
    else:
        assert atom_wise_plddt.shape[0] == 1
        atom_array.set_annotation("b_factor", atom_wise_plddt[0])
        atom_array_list = [atom_array]

    for aa in atom_array_list:
        assert not np.isnan(aa.b_factor).any()

    return atom_array_list
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def _select_scored_units(
|
|
560
|
+
interfaces_to_score: list[tuple], pn_units_to_score: list[tuple]
|
|
561
|
+
):
|
|
562
|
+
scored_chains = []
|
|
563
|
+
interfaces = []
|
|
564
|
+
interface_chains = []
|
|
565
|
+
for k in interfaces_to_score:
|
|
566
|
+
interfaces.append(f"{k[0]}-{k[1]}")
|
|
567
|
+
interface_chains.append(k[0])
|
|
568
|
+
interface_chains.append(k[1])
|
|
569
|
+
for k in pn_units_to_score:
|
|
570
|
+
scored_chains.append(k[0])
|
|
571
|
+
|
|
572
|
+
return scored_chains, interfaces, interface_chains
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
def _create_chain_to_all_masks(ch_label, chains_to_score):
|
|
576
|
+
unique_chains = np.unique(ch_label)
|
|
577
|
+
I = len(ch_label)
|
|
578
|
+
chain_to_all_masks = {}
|
|
579
|
+
for chain in unique_chains:
|
|
580
|
+
if chain in chains_to_score:
|
|
581
|
+
indices = torch.from_numpy((ch_label == chain))
|
|
582
|
+
mask = indices.unsqueeze(0) | indices.unsqueeze(1)
|
|
583
|
+
# set the diagonal to false
|
|
584
|
+
mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool)
|
|
585
|
+
chain_to_all_masks[chain] = mask
|
|
586
|
+
return chain_to_all_masks
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def _create_chain_to_self_masks(ch_label, chains_to_score):
|
|
590
|
+
unique_chains = np.unique(ch_label)
|
|
591
|
+
I = len(ch_label)
|
|
592
|
+
chain_to_self_masks = {}
|
|
593
|
+
for chain in unique_chains:
|
|
594
|
+
if chain in chains_to_score:
|
|
595
|
+
indices = torch.from_numpy((ch_label == chain))
|
|
596
|
+
mask = indices.unsqueeze(0) & indices.unsqueeze(1)
|
|
597
|
+
# set the diagonal to false
|
|
598
|
+
mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool)
|
|
599
|
+
chain_to_self_masks[chain] = mask
|
|
600
|
+
return chain_to_self_masks
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def _create_interface_masks(ch_label, interfaces, is_ligand):
|
|
604
|
+
interface_masks = {}
|
|
605
|
+
interface_chains = []
|
|
606
|
+
ligand_chains = []
|
|
607
|
+
for interface in interfaces:
|
|
608
|
+
interface_chains.append(interface.split("-")[0])
|
|
609
|
+
interface_chains.append(interface.split("-")[1])
|
|
610
|
+
interface_chains = set(interface_chains)
|
|
611
|
+
for chain in interface_chains:
|
|
612
|
+
chain_indices = torch.from_numpy((ch_label == chain))
|
|
613
|
+
|
|
614
|
+
to_self = chain_indices.unsqueeze(0) & chain_indices.unsqueeze(1)
|
|
615
|
+
to_all = chain_indices.unsqueeze(0) | chain_indices.unsqueeze(1)
|
|
616
|
+
interface_mask = to_all & ~to_self
|
|
617
|
+
interface_masks[chain] = interface_mask
|
|
618
|
+
|
|
619
|
+
if torch.all(is_ligand[chain_indices]):
|
|
620
|
+
ligand_chains.append(chain)
|
|
621
|
+
|
|
622
|
+
return interface_masks, ligand_chains
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
def _get_masked_error_per_chain(chains, masks, unbinned_logits):
    """Reduce ``unbinned_logits`` over each chain's pair mask.

    Args:
        chains: Iterable of chain labels to evaluate.
        masks: Mapping from chain label to its pair mask.
        unbinned_logits: Per-pair values passed to
            ``compute_mean_over_subsampled_pairs``.

    Returns:
        dict mapping each chain label to the result of
        ``compute_mean_over_subsampled_pairs(unbinned_logits, masks[chain])``.
    """
    return {
        chain_id: compute_mean_over_subsampled_pairs(unbinned_logits, masks[chain_id])
        for chain_id in chains
    }
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def _get_average_error_per_interface(interfaces, lig_chains, interface_errors):
|
|
636
|
+
average_error = {}
|
|
637
|
+
for interface in interfaces:
|
|
638
|
+
chain_a = interface.split("-")[0]
|
|
639
|
+
chain_b = interface.split("-")[1]
|
|
640
|
+
average_error[interface] = (
|
|
641
|
+
interface_errors[chain_a] + interface_errors[chain_b]
|
|
642
|
+
) / 2
|
|
643
|
+
|
|
644
|
+
return average_error
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
def _get_lowest_error_indices(errors):
|
|
648
|
+
lowest_error_indices = {}
|
|
649
|
+
for k, v in errors.items():
|
|
650
|
+
lowest_error_indices[k] = torch.argmin(v)
|
|
651
|
+
|
|
652
|
+
return lowest_error_indices
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
def _get_lowest_error_ligand_indices(errors, interfaces, lig_chains):
|
|
656
|
+
# ligands are a special case in AF3, where they only consider the ligand chain's error and not the average for the interface
|
|
657
|
+
lowest_error_indices = {}
|
|
658
|
+
for interface in interfaces:
|
|
659
|
+
chain_a = interface.split("-")[0]
|
|
660
|
+
chain_b = interface.split("-")[1]
|
|
661
|
+
if chain_a in lig_chains or chain_b in lig_chains:
|
|
662
|
+
if chain_a in lig_chains:
|
|
663
|
+
lig_chain = chain_a
|
|
664
|
+
elif chain_b in lig_chains:
|
|
665
|
+
lig_chain = chain_b
|
|
666
|
+
|
|
667
|
+
lowest_error_indices[interface] = torch.argmin(errors[lig_chain])
|
|
668
|
+
else:
|
|
669
|
+
# assign a random value to avoid key errors downstream; sorting ligand interfaces
|
|
670
|
+
# from other types is handles in analysis
|
|
671
|
+
lowest_error_indices[interface] = 0
|
|
672
|
+
|
|
673
|
+
return lowest_error_indices
|