rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,502 @@
|
|
|
1
|
+
from collections import Counter, OrderedDict
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
|
6
|
+
from atomworks.ml.utils.token import (
|
|
7
|
+
get_token_starts,
|
|
8
|
+
spread_token_wise,
|
|
9
|
+
)
|
|
10
|
+
from biotite.structure import concatenate, infer_elements
|
|
11
|
+
from jaxtyping import Float, Int
|
|
12
|
+
from rfd3.constants import (
|
|
13
|
+
ATOM14_ATOM_NAMES,
|
|
14
|
+
VIRTUAL_ATOM_ELEMENT_NAME,
|
|
15
|
+
association_schemes,
|
|
16
|
+
association_schemes_stripped,
|
|
17
|
+
)
|
|
18
|
+
from rfd3.utils.io import (
|
|
19
|
+
build_stack_from_atom_array_and_batched_coords,
|
|
20
|
+
)
|
|
21
|
+
from scipy.optimize import linear_sum_assignment
|
|
22
|
+
|
|
23
|
+
from foundry.common import exists
|
|
24
|
+
from foundry.utils.ddp import RankedLogger
|
|
25
|
+
|
|
26
|
+
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
|
27
|
+
|
|
28
|
+
#######################################################################
|
|
29
|
+
# Pythonic Helper functions
|
|
30
|
+
#######################################################################
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _remap_outputs(
    xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"]
) -> Float[torch.Tensor, "D L 3"]:
    """Scatter each batch's rows to the positions given by ``mapping``.

    Operates in place: for every batch index d, row j of the input is written
    to row ``mapping[d, j]`` of that same batch. The mutated tensor is also
    returned for convenience.
    """
    n_batches = xyz.shape[0]
    for batch_idx in range(n_batches):
        # Snapshot the source rows first so the scatter does not read
        # positions it has already overwritten.
        src = xyz[batch_idx].clone()
        xyz[batch_idx, mapping[batch_idx]] = src
    return xyz
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _reorder_dict(d: dict) -> OrderedDict:
|
|
43
|
+
"""
|
|
44
|
+
Reorders keys in the dictionary to ensure 'metrics' and 'specification' are last (in that order if both present).
|
|
45
|
+
"""
|
|
46
|
+
ordered = OrderedDict()
|
|
47
|
+
first_keys = ["task", "diffused_index_map"]
|
|
48
|
+
last_keys = ["metrics", "specification", "inference_sampler"]
|
|
49
|
+
# First
|
|
50
|
+
for k in first_keys:
|
|
51
|
+
if k in d:
|
|
52
|
+
ordered[k] = d[k]
|
|
53
|
+
# Middle
|
|
54
|
+
for k in d:
|
|
55
|
+
if k not in last_keys and k not in first_keys:
|
|
56
|
+
ordered[k] = d[k]
|
|
57
|
+
# Last
|
|
58
|
+
for k in last_keys:
|
|
59
|
+
if k in d:
|
|
60
|
+
ordered[k] = d[k]
|
|
61
|
+
return ordered
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
#######################################################################
|
|
65
|
+
# Biotite-related helper functions
|
|
66
|
+
#######################################################################
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _build_atom_array_stack(
    coords,
    src_atom_array,
    sequence_indices,
    sequence_logits,
    allow_sequence_outputs=True,
    read_sequence_from_sequence_head=True,
    association_scheme: str = "atom14",
):
    """
    Wraps around build_atom_array_and_batched_coords to also include additional modifications to atom array.

    Args:
        coords: Batched coordinates to splice into copies of ``src_atom_array``
            (passed straight through to
            ``build_stack_from_atom_array_and_batched_coords``).
        src_atom_array: Template atom array; copied before modification.
        sequence_indices: Per-sample predicted sequence indices, decodable by
            ``AF3SequenceEncoding`` — assumed token-level; TODO confirm shape.
        sequence_logits: Per-sample sequence logits used for the b-factor
            entropy; may be ``None``, in which case the structural readout
            branch is used instead.
        allow_sequence_outputs: If False, the stack is returned unmodified
            (apart from the UNK→ALA spoofing below).
        read_sequence_from_sequence_head: If True (and logits exist), residue
            names come from the sequence head; otherwise they are inferred
            from the structure via ``_readout_seq_from_struc``.
        association_scheme: Forwarded to ``_readout_seq_from_struc``.

    Returns:
        The stack object when ``allow_sequence_outputs`` is False, otherwise a
        plain list of per-sample atom arrays.
    """
    atom_array_stack = build_stack_from_atom_array_and_batched_coords(
        coords, src_atom_array.copy()
    )

    # ... Spoof empty sequences to alanines (protein UNK residues only)
    atom_array_stack.res_name[
        atom_array_stack.is_protein & (atom_array_stack.res_name == "UNK")
    ] = "ALA"

    # ... Add sequence if available
    if allow_sequence_outputs:
        array_list = []
        if read_sequence_from_sequence_head and exists(sequence_logits):
            sequence_encoding = AF3SequenceEncoding()
            for i, (atom_array, seq_indices, seq_logits) in enumerate(
                zip(atom_array_stack, sequence_indices, sequence_logits)
            ):
                # Set residue names only where the sequence was not fixed by the motif
                diffused_mask = ~atom_array.is_motif_atom_with_fixed_seq
                three_letter_sequence = sequence_encoding.decode(
                    seq_indices.cpu().numpy().astype(int)
                )  # [I]

                # Spread token-level names onto atoms via token_id, then mask
                atom_array.res_name[diffused_mask] = three_letter_sequence[
                    atom_array.token_id
                ][diffused_mask]  # [L]

                # Set bfactor column as entropy of sequence logits
                # (1e-10 guards against log(0))
                p = torch.softmax(seq_logits, dim=-1).cpu().numpy()  # shape (L, 32)
                res_entropy = -np.sum(p * np.log(p + 1e-10), axis=-1)  # shape (L,)
                atom_array.b_factor = spread_token_wise(atom_array, res_entropy)
                array_list.append(atom_array.copy())
        else:
            # This automatically deletes virtual atoms and assigns resname, atom name, and elements
            # NOTE(review): _readout_seq_from_struc appears to keep virtual
            # atoms and assign res_name only — confirm against that helper.
            for atom_array in atom_array_stack:
                atom_array = _readout_seq_from_struc(
                    atom_array, association_scheme=association_scheme
                )
                array_list.append(atom_array)

        # Return as list
        atom_array_stack = array_list

    return atom_array_stack
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _cleanup_virtual_atoms_and_assign_atom_name_elements(
    atom_array, association_scheme: str = "atom14"
):
    """
    Remove virtual atoms based on the predicted residue identity and assign
    the correct atom names and elements.

    For each residue: if its sequence was fixed (motif / unindexed motif),
    keep every atom and restore the ground-truth atom names; otherwise look
    the predicted ``res_name`` up in ``association_schemes`` and keep only the
    scheme positions that are real atoms (``None`` entries mark virtual
    slots). Residues whose predicted name is not in the scheme are kept
    whole and later renamed to "UNK".

    Args:
        atom_array: Biotite AtomArray with ``is_motif_atom_with_fixed_seq``,
            ``is_motif_atom_unindexed``, ``gt_atom_name`` annotations.
        association_scheme: Key into ``association_schemes`` (e.g. "atom14").

    Returns:
        The cleaned atom array (virtual atoms dropped). If the collected atom
        names do not line up with the array and a warning was already issued,
        the array is returned unmodified (best-effort); otherwise a
        ValueError is raised.

    Raises:
        ValueError: If the reconstructed atom-name list length mismatches the
            array without a prior recoverable warning.
    """
    # Mask of atoms to keep (virtual slots are dropped).
    ret_mask = []
    atom_names = []
    # This is used to indicate which residue is unidentified, probably due to an invalid structure.
    # This is different from the ref_mask, which is used to delete virtual atoms, but this one is used to assign UNK resname for invalid residues.
    invalid_mask = []

    # ... Iterate through each residue.
    # Here we iterate through res_id instead of token_id to avoid some atomization cases or something else.
    res_ids = atom_array.res_id
    res_start_indices = np.concatenate(
        [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
    )
    res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])
    warning_issued = False
    for start, end in zip(res_start_indices, res_end_indices):
        res_array = atom_array[start:end]

        is_seq_known = all(
            np.array(res_array.is_motif_atom_with_fixed_seq, dtype=bool)
        ) or all(np.array(res_array.is_motif_atom_unindexed, dtype=bool))

        # ... If sequence is known for the original atom array, just skip
        # (keep all atoms and restore the ground-truth atom names).
        if is_seq_known:
            ret_mask += [True] * len(res_array)
            invalid_mask += [False] * len(res_array)
            atom_names += res_array.gt_atom_name.tolist()
            continue

        # ... If sequence is unknown for the original atom array, use the predicted / inferred sequence
        res_name = res_array[0].res_name
        if res_name not in association_schemes[association_scheme]:
            global_logger.warning(
                "Model predicted non-protein sequence for diffused residue. Cannot clean up outputs. Assigning unknown residue token."
            )
            warning_issued = True
            ret_mask += [True] * len(res_array)
            invalid_mask += [True] * len(res_array)
            atom_names += res_array.atom_name.tolist()
            continue

        # Scheme entries are atom names; None marks a virtual slot to drop.
        scheme = association_schemes[association_scheme][res_name]
        ret_mask += [item is not None for item in scheme]
        atom_names += [item.strip() if item is not None else "VX" for item in scheme]
        invalid_mask += [False] * len(scheme)

    if len(atom_names) != atom_array.array_length():
        global_logger.warning(
            f"{atom_names=}\n{atom_array.atom_name=}\nAtom names length {len(atom_names)} does not match original array length {atom_array.array_length()}."
            "\nCould not cleanup atom array!!!"
        )
        if not warning_issued:
            raise ValueError("Atom names length does not match original array length. ")
        # Best-effort: a recoverable warning was already issued, so return
        # the array untouched rather than corrupting it.
        return atom_array
    atom_array.atom_name = atom_names
    # Only overwrite elements that were virtual placeholders.
    atom_array.element = np.where(
        atom_array.element == VIRTUAL_ATOM_ELEMENT_NAME,
        infer_elements(atom_names),
        atom_array.element,
    )
    atom_array.res_name[invalid_mask] = np.array(["UNK"] * sum(invalid_mask))
    return atom_array[ret_mask]
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _readout_seq_from_struc(
    atom_array, central_atom="CB", threshold=0.5, association_scheme: str = "atom14"
):
    """
    Infer residue identities from the predicted structure.

    For every residue whose sequence was not fixed by the motif, virtual atoms
    are identified as atoms that collapsed onto the central atom (distance
    below ``threshold``), discarded for matching purposes, and the remaining
    atom-name set is compared against each residue type's layout in
    ``association_schemes_stripped``. Matching residues get their ``res_name``
    assigned; residues that match nothing become "UNK". Virtual atoms are kept
    in the returned array (only the names are assigned here).

    Args:
        atom_array: Biotite AtomArray with ``is_motif_atom_with_fixed_seq``,
            ``atom_name``, ``res_id``, ``coord`` annotations.
        central_atom: Atom the virtual atoms were padded around ("CB").
        threshold: Distance cutoff for both the glycine CA/CB check and
            virtual-atom detection — presumably in Angstroms; TODO confirm.
        association_scheme: Key into ``association_schemes_stripped``.

    Returns:
        Concatenation of the per-residue slices with ``res_name`` assigned.
    """
    cur_atom_array_list = []

    # Iterate through each residue (contiguous runs of equal res_id)
    res_ids = atom_array.res_id
    res_start_indices = np.concatenate(
        [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
    )
    res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])

    for start, end in zip(res_start_indices, res_end_indices):
        # ... Check if the current residue is after padding (seq unknown):
        cur_res_atom_array = atom_array[start:end]
        is_seq_known = all(
            np.array(cur_res_atom_array.is_motif_atom_with_fixed_seq, dtype=bool)
        )

        # Here it assumes that every non-protein part has its sequence shown (not padded)
        if not is_seq_known:
            # For Glycine: it doesn't have CB, so set the virtual atom as CA.
            # The current way to handle this is to check if predicted CA and CB are too close, because in the case of glycine and we pad virtual atoms based on CB, CB's coords are set as CA.
            # There might be a better way to do this.
            CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
            CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
            if np.linalg.norm(CA_coord - CB_coord) < threshold:
                cur_central_atom = "CA"
            else:
                cur_central_atom = central_atom

            central_mask = cur_res_atom_array.atom_name == cur_central_atom

            # ... Calculate the distance to the central atom
            central_coord = cur_res_atom_array.coord[central_mask][
                0
            ]  # Should only have one central atom anyway
            dists = np.linalg.norm(cur_res_atom_array.coord - central_coord, axis=-1)

            # ... Select virtual atom by the distance. Shouldn't count the central atom itself.
            is_virtual = (dists < threshold) & ~central_mask

            # ... Throw away virtual atoms
            cur_res_atom_array_wo_virtual = cur_res_atom_array[~is_virtual]
            cur_pred_res_atom_names = (
                cur_res_atom_array_wo_virtual.atom_name
            )  # e.g. [N, CA, C, O, CB, V6, V2]

            # ... Iterate over the possible restypes and find the matched one if there is any
            has_restype_assigned = False
            for restype, atom_names in association_schemes_stripped[
                association_scheme
            ].items():
                atom_names = np.array(atom_names)

                # Shouldn't match these two
                if restype in ["UNK", "MSK"]:
                    continue

                # ... Find the index of virtual atom names in the standard atom14 names
                atom_name_idx_in_atom14_scheme = np.array(
                    [
                        np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
                        for atom_name in cur_pred_res_atom_names
                    ]
                )  # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7]
                atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
                atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True

                # ... Find the matched restype by checking if all the non-None posititons and None positions match
                # This is designed to keep virtual atoms and doesn't assign the atom names for now, which will be handled later.
                if all(x is not None for x in atom_names[atom14_scheme_mask]) and all(
                    x is None for x in atom_names[~atom14_scheme_mask]
                ):
                    cur_res_atom_array.res_name = np.array(
                        [restype] * len(cur_res_atom_array)
                    )
                    cur_atom_array_list.append(cur_res_atom_array)
                    has_restype_assigned = True
                    break
        else:
            # Sequence already known: keep the residue exactly as-is.
            cur_atom_array_list.append(cur_res_atom_array)
            has_restype_assigned = True

        # ... Give UNK as the residue name if the mapping fails (unrealistic sidechain)
        if not has_restype_assigned:
            cur_res_atom_array.res_name = np.array(["UNK"] * len(cur_res_atom_array))
            cur_atom_array_list.append(cur_res_atom_array)

    cur_atom_array = concatenate(cur_atom_array_list)

    return cur_atom_array
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
#######################################################################
|
|
290
|
+
# Unindexed output parsing
|
|
291
|
+
#######################################################################
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def _reassign_unindexed_token_chains(atom_array):
    """
    Move unindexed motif tokens onto their own chain IDs.

    Unindexed tokens share res_ids with other atoms, so they are first
    parked on chain "X" with their original res_ids restored, then split
    into per-group chains "X0", "X1", ... at the annotated motif
    breakpoints. Mutates ``atom_array`` in place and returns it.
    """
    if np.any((mask := atom_array.is_motif_atom_unindexed)):
        # HACK: Since res_ids are the same, we should save them with a different chain index.
        atom_array.chain_id[mask] = "X"
        atom_array.res_id[mask] = atom_array.orig_res_id[mask]

        # Parse to separate chains
        starts = get_token_starts(atom_array)
        unindexed_starts = starts[mask[starts]]
        token_breaks = atom_array[
            unindexed_starts
        ].is_motif_atom_unindexed_motif_breakpoint
        token_group_id = np.cumsum(token_breaks, dtype=int)  # Group by motif breaks
        token_chain_id = np.array([f"X{i}" for i in token_group_id])

        # Assign per-group chain ids token-wise, leaving indexed tokens alone
        chains = atom_array.chain_id[starts]
        chains[mask[starts]] = token_chain_id
        atom_array.chain_id = spread_token_wise(atom_array, chains)
    return atom_array
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def process_unindexed_outputs(
    atom_array,
    match_atom_names=True,
    insert_guideposts=False,
    verbose=False,
):
    """
    Process design outputs containing unindexed tokens.

    Each unindexed token is matched to a diffused residue via a linear-sum
    assignment on inter-atom distances (optionally restricted to identical
    atom names), then re-assigned within the single majority residue.

    Returns metadata such as the assigned positional indices from the input indices
    and the RMSD of the unindexed tokens.

    Args:
        atom_array: Full output array incl. unindexed motif tokens; requires
            the 'src_component' annotation (inference-time only).
        match_atom_names: Only allow pairings between identically named atoms.
        insert_guideposts: Copy the unindexed token coordinates (and, for
            fixed-sequence tokens, res_names) into the diffused array.
        verbose: Keep the per-restype breakdown keys in the metadata.

    Returns:
        - Diffused atom array (without additional unindexed tokens)
        - Metadata:
            - diffused_index_map: keys = original (contig) indices, values = diffused indices
            - insertion_rmsd: overall RMSD of insertion
            - insertion_rmsd_by_residue: RMSD of insertion for each token

    Raises:
        ValueError: If the 'src_component' annotation is missing.

    TODO: Add additional geometry metrics such as bond angle non-ideality, clashes etc.
    TODO: atom1d conditioning adherence - does the output contain HBonds in the right places, correct rasa values?
    """
    # ... Find assignments based on greedy search
    starts = get_token_starts(atom_array, add_exclusive_stop=True)

    # [N_diffused,]
    atom_array_diffused = atom_array[~atom_array.is_motif_atom_unindexed].copy()
    # Indices of the diffused atoms in the original (full) array
    global_idx = np.arange(atom_array.array_length())[
        ~atom_array.is_motif_atom_unindexed
    ]

    metadata = {
        "diffused_index_map": {},
        "insertion_rmsd_by_token": {},
        "join_point_rmsd_by_token": {},
        "insertion_rmsd_by_restype": {},
    }
    token_maes = []
    token_rmcds = []
    n_conjoined_residues = 0

    # Initialize an empty array (no diffused residues claimed yet)
    inserted_mask = np.full_like(atom_array_diffused.is_motif_atom_unindexed, False)

    for start, end in zip(starts[:-1], starts[1:]):
        token = atom_array[start:end]
        if not token.is_motif_atom_unindexed.all():
            continue

        if "src_component" in token.get_annotation_categories():
            token_pdb_id = token.src_component[0]
        else:
            raise ValueError(
                "Missing annotation 'src_component' in token. Is this inference?"
            )

        if "src_sym_component" in token.get_annotation_categories():
            # if symmetry, token_pdb_id are updated to match the symmetrized component
            token_pdb_id = token.src_sym_component[0]

        res_name = token.res_name[0]

        # ... Calculate [N_unindex, N_diffused] distance matrix
        dists = np.linalg.norm(
            token.coord[:, None] - atom_array_diffused.coord[None, :], axis=-1
        )

        # ... Match atom indices based on atom names (mask out non-identical) and remove already inserted
        dists[:, inserted_mask.copy()] = np.inf
        if match_atom_names:
            matching_atom_name = (
                token.atom_name[:, None] == atom_array_diffused.atom_name[None, :]
            )
            dists[~matching_atom_name] = np.inf

        # ... Find the res_id's in the diffused regions belonging to the diffused indices
        row_ind, col_ind = linear_sum_assignment(dists)
        res_id, chain_id, is_conjoined = indices_to_components_(
            atom_array_diffused, col_ind
        )
        n_conjoined_residues += int(is_conjoined)

        # ... Recompute distance indices based on single residue pairings only
        token_match = (atom_array_diffused.res_id == res_id) & (
            atom_array_diffused.chain_id == chain_id
        )
        dists[:, ~token_match] = np.nan
        # Replace nan/inf with a large finite cost so the assignment is solvable
        BIG = 1e12
        dists = np.nan_to_num(dists, nan=BIG, posinf=BIG, neginf=BIG)
        row_ind, col_ind = linear_sum_assignment(dists)
        res_id_, chain_id_, _ = indices_to_components_(atom_array_diffused, col_ind)

        # Sanity check: restricting to one residue must not change the pairing
        assert (res_id_ == res_id) & (chain_id_ == chain_id)
        inserted_mask = np.logical_or(inserted_mask, token_match)

        # ... Compute metrics based on the new distances
        diff = token.coord[row_ind] - atom_array_diffused.coord[col_ind]
        token_rmsd = float(np.sqrt((diff**2).sum(-1).mean()))
        token_rmcd = float(np.cbrt((np.abs(diff) ** 3).sum(-1).mean()))
        token_mae = float((np.abs(diff)).sum(-1).mean())

        metadata["insertion_rmsd_by_token"][token_pdb_id] = token_rmsd
        token_maes.append(token_mae)
        token_rmcds.append(token_rmcd)

        if res_name not in metadata["insertion_rmsd_by_restype"]:
            metadata["insertion_rmsd_by_restype"][res_name] = []
        metadata["insertion_rmsd_by_restype"][res_name].append(token_rmsd)
        # Join-point distance only makes sense for tokens without backbone atoms
        if not np.any(np.isin(token.atom_name, ["N", "CA", "C", "O"])):
            if np.sum(token.atomize) == 1:
                join_atom = np.where(token.atomize)[0][0]
            elif "CB" in token.atom_name:
                join_atom = np.where(token.atom_name == "CB")[0][0]
            else:
                join_atom = None

            if join_atom is None:
                global_logger.warning(
                    f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}."
                )
            else:
                dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
                metadata["join_point_rmsd_by_token"][token_pdb_id] = dist

        metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"

        # ... Decide whether to cleanup guideposts or not
        if insert_guideposts:
            atom_array_diffused.coord[global_idx[col_ind]] = token.coord[row_ind]
            if token.is_motif_atom_with_fixed_seq[0]:
                atom_array_diffused.res_name[token_match] = token.res_name[0]
            # atom_array_diffused.is_motif_token[token_match] = True
            # atom_array_diffused.is_motif_atom[global_idx[col_ind]] = True
            atom_array_diffused.is_motif_atom_with_fixed_coord[global_idx[col_ind]] = (
                True
            )

    # ... Calculate global metrics
    def safe_mean(x):
        """Return nan-safe mean for empty or nan arrays."""
        x = np.asarray(x, float)
        if x.size == 0 or not np.isfinite(x).any():
            return float("nan")
        return float(np.nanmean(x))

    metadata["insertion.mae"] = safe_mean(token_maes)
    metadata["insertion.rmcd"] = safe_mean(token_rmcds)
    metadata["insertion_rmsd"] = safe_mean(
        list(metadata["insertion_rmsd_by_token"].values())
    )
    metadata["join_point_rmsd"] = safe_mean(
        list(metadata["join_point_rmsd_by_token"].values())
    )
    metadata["insertion_rmsd_by_restype"] = {
        a: safe_mean(v) for a, v in metadata["insertion_rmsd_by_restype"].items()
    }
    metadata["n_conjoined_residues"] = n_conjoined_residues

    if not verbose:
        # Drop the per-token/per-restype breakdowns, keep aggregates
        metadata = {
            k: v for k, v in metadata.items() if not k.startswith("insertion_rmsd_by_")
        }

    return atom_array_diffused, metadata
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def indices_to_components_(atom_array, col_ind):
    """
    Fetch the chain id and residue id in ``atom_array`` for a set of raw atom
    indices.

    Returns a ``(res_id, chain_id, conjoined)`` triple. ``conjoined`` is True
    when the indices span more than one residue or chain; in that case a
    warning is logged and the majority ``(chain_id, res_id)`` pair is
    returned.
    """
    res_ids = atom_array.res_id[col_ind]
    chain_ids = atom_array.chain_id[col_ind]

    unique_res = set(res_ids.tolist())
    unique_chains = set(chain_ids.tolist())

    if len(unique_res) <= 1 and len(unique_chains) <= 1:
        # All indices fall inside a single residue — unambiguous mapping.
        return res_ids[0], chain_ids[0], False

    global_logger.warning(
        f"Unindexed token mapped its atoms to multiple diffused residues: {res_ids.tolist()} and chains {chain_ids.tolist()}."
    )
    # Handle by majority: pick the most frequent (chain, residue) pair
    pair_counts = Counter(zip(chain_ids.tolist(), res_ids.tolist()))
    (chain_id, res_id), _ = pair_counts.most_common(1)[0]
    return res_id, chain_id, True
|