rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rfd3/utils/inference.py
ADDED
|
@@ -0,0 +1,648 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for inference input preparation
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
from os import PathLike
|
|
8
|
+
from typing import Dict
|
|
9
|
+
|
|
10
|
+
import biotite.structure as struc
|
|
11
|
+
import numpy as np
|
|
12
|
+
from atomworks import parse
|
|
13
|
+
from atomworks.constants import STANDARD_AA, STANDARD_DNA
|
|
14
|
+
from atomworks.io.parser import (
|
|
15
|
+
STANDARD_PARSER_ARGS,
|
|
16
|
+
)
|
|
17
|
+
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
|
18
|
+
from atomworks.ml.preprocessing.utils.structure_utils import (
|
|
19
|
+
get_atom_mask_from_cell_list,
|
|
20
|
+
)
|
|
21
|
+
from atomworks.ml.utils.token import (
|
|
22
|
+
get_token_starts,
|
|
23
|
+
spread_token_wise,
|
|
24
|
+
)
|
|
25
|
+
from rfd3.constants import (
|
|
26
|
+
REQUIRED_CONDITIONING_ANNOTATIONS,
|
|
27
|
+
)
|
|
28
|
+
from rfd3.transforms.conditioning_base import (
|
|
29
|
+
convert_existing_annotations_to_bool,
|
|
30
|
+
set_default_conditioning_annotations,
|
|
31
|
+
)
|
|
32
|
+
from rfd3.transforms.conditioning_utils import sample_island_tokens
|
|
33
|
+
|
|
34
|
+
from foundry.common import exists
|
|
35
|
+
from foundry.utils.components import (
|
|
36
|
+
fetch_mask_from_component,
|
|
37
|
+
get_name_mask,
|
|
38
|
+
unravel_components,
|
|
39
|
+
)
|
|
40
|
+
from foundry.utils.ddp import RankedLogger
|
|
41
|
+
|
|
42
|
+
logging.basicConfig(level=logging.INFO)
# Rank-aware logger: only rank zero emits, to keep multi-process logs clean.
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

# Module-level AF3 sequence encoding, used to identify amino-acid-like residues.
sequence_encoding = AF3SequenceEncoding()
_aa_like_res_names = sequence_encoding.all_res_names[sequence_encoding.is_aa_like]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
#################################################################################
|
|
50
|
+
# Setter functions for annotations
|
|
51
|
+
#################################################################################
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def set_common_annotations(array, set_src_component_to_res_name=True):
    """Ensure the per-atom annotations shared by all inference inputs exist.

    Adds default values for ``occupancy`` (1.0), ``b_factor`` (0.0),
    ``charge`` (0.0) and ``src_component`` when missing; annotations that are
    already present are left untouched.

    Args:
        array: Atom array (biotite-style) to annotate in place.
        set_src_component_to_res_name: When ``src_component`` is missing,
            initialize it from each atom's ``res_name`` instead of the
            empty string.

    Returns:
        The same ``array`` with the annotations guaranteed to be present.
    """
    n_atoms = array.shape[0]
    annots = array.get_annotation_categories()
    if "occupancy" not in annots:
        array.set_annotation("occupancy", np.ones(n_atoms, dtype=float))
    if "b_factor" not in annots:
        array.set_annotation("b_factor", np.zeros(n_atoms, dtype=float))
    if "charge" not in annots:
        array.set_annotation("charge", np.zeros(n_atoms, dtype=float))
    if "src_component" not in annots:
        if set_src_component_to_res_name:
            # A copy of res_name already has the right shape and dtype; the
            # original wrapped it in np.full, which only re-broadcast it.
            array.set_annotation("src_component", array.res_name.copy())
        else:
            array.set_annotation(
                "src_component", np.full(n_atoms, "", dtype=array.res_name.dtype)
            )
    return array
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def set_indices(array, chain, res_id_start, molecule_id, component):
    """Re-index an atom array in place.

    Assigns a single chain id and molecule id to every atom, shifts residue
    numbering so the first residue starts at ``res_id_start``, and records the
    source component annotation. Returns the same array for chaining.
    """
    n_atoms = array.shape[0]
    chain_dtype = array.chain_id.dtype

    # Shift residue numbering; preserve the original integer dtype.
    shifted = np.asarray(res_id_start + array.res_id - 1)
    array.res_id = shifted.astype(array.res_id.dtype)

    array.chain_id = np.full(n_atoms, chain, dtype=chain_dtype)
    array.molecule_id = np.full(n_atoms, molecule_id, dtype=np.int32)
    array.set_annotation("src_component", np.full(n_atoms, component, dtype=chain_dtype))
    return array
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
#################################################################################
|
|
89
|
+
# Getters
|
|
90
|
+
#################################################################################
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def extract_ligand_array(
    atom_array_input,
    ligand,
    fixed_atoms=None,
    set_defaults=True,
    additional_annotations=None,
):
    """Extract the atoms of the requested ligand(s) from an input atom array.

    Args:
        atom_array_input: Source atom array to pull ligands from.
        ligand: Comma-separated ligand specifiers; each may resolve to
            several matching components.
        fixed_atoms: Optional mapping from a ligand name (or its resolved
            component name) to the atom names whose coordinates stay fixed.
            Addressing the same ligand by both keys is an error. Defaults to
            no per-atom overrides.
        set_defaults: Apply default conditioning/common annotations
            (fully fixed motif) to each extracted ligand.
        additional_annotations: Extra annotations forwarded to
            ``set_default_conditioning_annotations``.

    Returns:
        A single concatenated atom array of all extracted ligands.

    Raises:
        ValueError: If no input array is given, or ``fixed_atoms`` contains
            both a ligand name and its resolved component name.
    """
    # BUGFIX: the default used to be a shared mutable dict ({}), which is
    # shared across calls; use None as the sentinel instead.
    if fixed_atoms is None:
        fixed_atoms = {}
    if not exists(atom_array_input):
        raise ValueError(
            "No input file/atom array provided. Cannot add requested ligand."
        )

    ligand_arrays = []
    for lig in ligand.split(","):
        for name in unravel_components(
            lig, atom_array=atom_array_input, allow_multiple_matches=True
        ):  # additional nesting to allow multiple indices per ligand
            mask = fetch_mask_from_component(name, atom_array=atom_array_input)
            ligand_array = atom_array_input[mask].copy()

            # ... Set as fully fixed motif
            if set_defaults:
                ligand_array = set_default_conditioning_annotations(
                    ligand_array, motif=True, additional=additional_annotations
                )  # should be pre-set!
                ligand_array = set_common_annotations(ligand_array)

            # ... Unfix all names not specified if specified in motif_atoms
            if lig in fixed_atoms or name in fixed_atoms:
                if (lig in fixed_atoms and name in fixed_atoms) and name != lig:
                    raise ValueError(
                        f"Got both ligand name and its pdb indices in fixed_atoms dictionary: {lig} and {name}. Please only provide one."
                    )
                fixed = fixed_atoms.get(lig, fixed_atoms.get(name, None))
                if fixed:
                    # Only the listed atom names keep fixed coordinates.
                    fixed_mask = get_name_mask(ligand_array.atom_name, fixed)
                    ligand_array.is_motif_atom_with_fixed_coord[~fixed_mask] = np.zeros(
                        np.sum(~fixed_mask), dtype=int
                    )
                else:
                    # Falsy entry (empty list/None): unfix the whole ligand.
                    ligand_array.is_motif_atom_with_fixed_coord = np.zeros(
                        ligand_array.shape[0], dtype=int
                    )
            ligand_arrays.append(ligand_array)

    return struc.concatenate(ligand_arrays)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def extract_na_array(atom_array_input):
    """Extract all standard-DNA atoms from the input atom array.

    Any NA chain named "A" is renamed to a chain id guaranteed not to collide
    with the existing chains (a concatenation of all other chain ids, a
    doubled single chain id, or "B" when "A" is the only chain).

    Returns:
        The NA sub-array with default motif conditioning annotations set.

    Raises:
        ValueError: If the input contains no standard NA residues.
    """
    # TODO : do it more nicely, take into account modifications to NA reses e.g. 5IU
    na_mask = np.isin(atom_array_input.res_name, list(STANDARD_DNA))
    if not na_mask.any():
        raise ValueError(
            "Could not find any NA tokens in input file, but requested to add all NA"
        )

    na_array = atom_array_input[na_mask]
    # ...replace chain_id A with literally anything else available
    Achain_mask = na_array.chain_id == "A"

    # na_array is a subset of the input, so the input's chain ids already
    # cover every chain in play (the original concatenated the two arrays,
    # which cannot change the unique set).
    all_nonAchains = np.unique(atom_array_input.chain_id).tolist()
    # BUGFIX: list.remove raises ValueError when "A" is absent; guard it.
    if "A" in all_nonAchains:
        all_nonAchains.remove("A")

    if len(all_nonAchains) > 1:
        new_chain = "".join(all_nonAchains)  # join_them_all !! so definitely unique
    elif len(all_nonAchains) == 1:
        new_chain = all_nonAchains[0] + all_nonAchains[0]
    else:
        new_chain = "B"

    na_array.chain_id[Achain_mask] = new_chain
    na_array = set_default_conditioning_annotations(na_array, motif=True)
    return na_array
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _restore_bonds_for_nonstandard_residues(
    atom_array_accum: struc.AtomArray,
    src_atom_array: struc.AtomArray | None,
    source_to_accum_idx: Dict[int, int],
) -> struc.AtomArray:
    """
    Restores and creates bonds for non-standard residues (PTMs, modified AAs, etc.)
    from source structure and between consecutive residues.

    This function:
    1. Preserves inter-residue bonds from the source structure (if available)
    2. Adds backbone C-N bonds between consecutive residues where at least one is non-standard

    Args:
        atom_array_accum: The accumulated atom array to add bonds to
        src_atom_array: The source atom array containing original bond information
        source_to_accum_idx: Mapping from source atom indices to accumulated array indices

    Returns:
        atom_array_accum with bonds added
    """
    # Initialize bonds if needed
    if atom_array_accum.bonds is None:
        atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length())

    # Step 1: Restore inter-residue bonds from the source atom array (only for non-standard residues)
    if (
        src_atom_array is not None
        and hasattr(src_atom_array, "bonds")
        and src_atom_array.bonds is not None
    ):
        original_bonds = src_atom_array.bonds.as_array()
        if len(original_bonds) > 0:
            # Extract bonds where both atoms are in the accumulated array
            bonds_to_add = []
            for atom_i, atom_j, bond_type in original_bonds:
                atom_i, atom_j = int(atom_i), int(atom_j)
                # Both endpoints must survive into the accumulated array
                if atom_i in source_to_accum_idx and atom_j in source_to_accum_idx:
                    src_res_i = src_atom_array[atom_i].res_name
                    src_res_j = src_atom_array[atom_j].res_name
                    # Only preserve if at least one residue is non-standard
                    if src_res_i not in STANDARD_AA or src_res_j not in STANDARD_AA:
                        bonds_to_add.append(
                            [
                                source_to_accum_idx[atom_i],
                                source_to_accum_idx[atom_j],
                                int(bond_type),
                            ]
                        )

            if bonds_to_add:
                new_bonds = struc.BondList(
                    atom_array_accum.array_length(),
                    np.array(bonds_to_add, dtype=np.int64),
                )
                atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
                # BUGFIX: this module defines `ranked_logger`, not `logger`;
                # the original `logger.info(...)` raised NameError here.
                ranked_logger.info(
                    f"Preserved {len(bonds_to_add)} inter-residue bonds involving non-standard residues from source structure"
                )

    # Step 2: Add backbone bonds between consecutive residues where at least one is non-standard
    # This handles: PTM-to-diffused, diffused-to-PTM, PTM-to-PTM, ligand-to-protein
    bonds_to_add = []

    # Group by residue
    token_starts = get_token_starts(atom_array_accum, add_exclusive_stop=True)

    # The accumulated bond list does not change inside this loop, so build the
    # membership set once (the original rescanned the full bond array per pair).
    existing_pairs = {
        frozenset((int(b[0]), int(b[1]))) for b in atom_array_accum.bonds.as_array()
    }

    for i in range(
        len(token_starts) - 2
    ):  # -2 because we need pairs and token_starts has exclusive stop
        curr_start, curr_end = token_starts[i], token_starts[i + 1]
        next_start, next_end = token_starts[i + 1], token_starts[i + 2]

        curr_residue = atom_array_accum[curr_start:curr_end]
        next_residue = atom_array_accum[next_start:next_end]

        # Check if at least one residue is non-standard (PTMs, modified AAs, etc.)
        curr_is_nonstandard = curr_residue.res_name[0] not in STANDARD_AA
        next_is_nonstandard = next_residue.res_name[0] not in STANDARD_AA

        # Only add bonds if at least one residue is non-standard
        if not (curr_is_nonstandard or next_is_nonstandard):
            continue

        # Check if consecutive in same chain
        if curr_residue.chain_id[0] != next_residue.chain_id[0]:
            continue
        if next_residue.res_id[0] - curr_residue.res_id[0] != 1:
            continue

        # Find C atom in current residue (C-terminus connection point)
        c_mask = curr_residue.atom_name == "C"
        if not np.any(c_mask):
            # No C atom: expected for some atomized residues or ligands at chain termini
            if curr_is_nonstandard and next_is_nonstandard:
                ranked_logger.debug(  # BUGFIX: was the undefined name `logger`
                    f"Non-standard residue {curr_residue.res_name[0]} (res_id {curr_residue.res_id[0]}) "
                    f"has no C atom - cannot form backbone bond to next residue"
                )
            continue
        c_idx = curr_start + np.where(c_mask)[0][0]

        # Find N atom in next residue (N-terminus connection point)
        n_mask = next_residue.atom_name == "N"
        if not np.any(n_mask):
            # No N atom: expected for some atomized residues or ligands at chain termini
            if curr_is_nonstandard and next_is_nonstandard:
                ranked_logger.debug(
                    f"Non-standard residue {next_residue.res_name[0]} (res_id {next_residue.res_id[0]}) "
                    f"has no N atom - cannot form backbone bond from previous residue"
                )
            continue
        n_idx = next_start + np.where(n_mask)[0][0]

        # Skip bonds that already exist (e.g. preserved from source in step 1)
        if frozenset((int(c_idx), int(n_idx))) not in existing_pairs:
            bonds_to_add.append([c_idx, n_idx, struc.BondType.SINGLE])

    if bonds_to_add:
        new_bonds = struc.BondList(
            atom_array_accum.array_length(), np.array(bonds_to_add, dtype=np.int64)
        )
        atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
        ranked_logger.info(
            f"Added {len(bonds_to_add)} backbone bonds involving non-standard residues"
        )

    return atom_array_accum
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
#################################################################################
|
|
313
|
+
# File IO utilities
|
|
314
|
+
#################################################################################
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def inference_load_(
    file: PathLike, *, assembly_id: str = "1", cif_parser_args: dict | None = None
):
    """Parse a structure file for inference and return its first model plus metadata.

    Caller-supplied parser arguments are layered over the standard defaults
    (with arginine fixing, missing-atom completion and CCD removal disabled).
    Hydrogens are always removed, and the required conditioning annotations are
    always requested as extra fields.
    """
    if cif_parser_args is None:
        cif_parser_args = {}

    # Convenience: if a cache_dir is given, default to both reading from and
    # writing to the cache unless the caller explicitly overrode those flags.
    if cif_parser_args.get("cache_dir"):
        cif_parser_args.setdefault("load_from_cache", True)
        cif_parser_args.setdefault("save_to_cache", True)

    # Precedence (lowest to highest): standard args < inference defaults < caller args.
    merged_cif_parser_args = dict(STANDARD_PARSER_ARGS)
    merged_cif_parser_args.update(
        fix_arginines=False,
        add_missing_atoms=False,
        remove_ccds=[],
    )
    merged_cif_parser_args.update(cif_parser_args)
    merged_cif_parser_args["hydrogen_policy"] = "remove"

    # Ensure the required annotations can be loaded
    requested_fields = merged_cif_parser_args.get("extra_fields", [])
    merged_cif_parser_args["extra_fields"] = list(
        set(requested_fields + REQUIRED_CONDITIONING_ANNOTATIONS)
    )

    # Use the parse function with the merged CIF parser arguments
    result_dict = parse(
        filename=file,
        build_assembly=(assembly_id,),  # Convert list to tuple (make hashable)
        **merged_cif_parser_args,
    )

    atom_array = convert_existing_annotations_to_bool(
        result_dict["assemblies"][assembly_id][0]
    )

    return {
        "atom_array": atom_array,  # First model
        "chain_info": result_dict["chain_info"],
        "ligand_info": result_dict["ligand_info"],
        "metadata": result_dict["metadata"],
    }
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def ensure_input_is_abspath(args: dict, path: PathLike | None):
|
|
369
|
+
"""
|
|
370
|
+
Ensures the input source is an absolute path if exists, if not it will convert
|
|
371
|
+
|
|
372
|
+
args:
|
|
373
|
+
spec: Inference specification for atom array
|
|
374
|
+
path: None or file to which the input is relative to.
|
|
375
|
+
"""
|
|
376
|
+
if isinstance(args, str):
|
|
377
|
+
raise ValueError(
|
|
378
|
+
"Expected args to be a dictionary, got a string: {}. If you are using an input JSON ensure it contains dictionaries of arguments".format(
|
|
379
|
+
args
|
|
380
|
+
)
|
|
381
|
+
)
|
|
382
|
+
if "input" not in args or not exists(args["input"]):
|
|
383
|
+
return args
|
|
384
|
+
input = args["input"]
|
|
385
|
+
if not os.path.isabs(input):
|
|
386
|
+
input = os.path.abspath(os.path.join(os.path.dirname(path), input))
|
|
387
|
+
ranked_logger.info(
|
|
388
|
+
f"Input source path is relative, converted to absolute path: {input}"
|
|
389
|
+
)
|
|
390
|
+
args["input"] = input
|
|
391
|
+
return args
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
#################################################################################
|
|
395
|
+
# Custom infer_ori functions
|
|
396
|
+
#################################################################################
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def infer_ori_from_hotspots(atom_array: struc.AtomArray):
    """Infer an origin token positioned "above" the hotspot patch.

    Computes the hotspot center of mass, the COM of all finite-coordinate
    atoms within a distance cutoff of any hotspot, and returns a point a fixed
    distance beyond the hotspots along the core-to-hotspot direction.
    """
    assert (
        "is_atom_level_hotspot" in atom_array.get_annotation_categories()
    ), "Atom array must contain 'is_atom_level_hotspot' annotation to infer ori from hotspots."

    hotspot_atoms = atom_array[atom_array.is_atom_level_hotspot.astype(bool)]
    hotspot_com = hotspot_atoms.coord.mean(axis=0)

    # Distance computations are only valid on atoms with finite coordinates.
    has_nan_coord = np.any(np.isnan(atom_array.coord), axis=1)
    finite_atoms = atom_array[~has_nan_coord]

    # Perform the distance computation.
    # RFD2 used 10 Angstroms instead of 12, but was for residue-level hotspots
    DISTANCE_CUTOFF = 12.0
    cell_list = struc.CellList(finite_atoms, cell_size=DISTANCE_CUTOFF)
    near_hotspot = get_atom_mask_from_cell_list(
        hotspot_atoms.coord,
        cell_list,
        len(finite_atoms),
        cutoff=DISTANCE_CUTOFF,
    )  # (n_query, n_cell_list)

    # Collapse per-hotspot masks into one mask over the finite-coordinate atoms.
    environment_com = finite_atoms.coord[np.any(near_hotspot, axis=0)].mean(axis=0)

    core_to_hotspot = hotspot_com - environment_com
    core_to_hotspot = core_to_hotspot / np.linalg.norm(core_to_hotspot)

    # This is following RFD2. Both this and the distance cutoff should definitely be configs with defaults
    DISTANCE_ABOVE_HOTSPOTS = 10.0
    return hotspot_com + DISTANCE_ABOVE_HOTSPOTS * core_to_hotspot
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def infer_ori_from_com(atom_array):
    """Return the center of mass over all atoms with finite coordinates."""
    coords = atom_array.coord
    finite_rows = np.isfinite(coords).all(axis=-1)  # drop NaN/inf positions
    return coords[..., finite_rows, :].mean(axis=0)
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
# This can't go in constants.py because that leads to a circular dependency
# (constants.py would need to import the two functions defined above).
INFER_ORI_STRATEGIES_TO_FUNCTIONS = {
    "hotspots": infer_ori_from_hotspots,
    "com": infer_ori_from_com,
}
"""
Constant mapping from infer_ori_strategy keys to the corresponding functions. These functions should take an AtomArray
as input and return a three-element list or numpy array of floats.
"""
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def set_com(
    atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None
):
    """Translate ``atom_array`` so the chosen origin sits at (0, 0, 0).

    Origin precedence:
        1. An explicit ``ori_token`` ([x, y, z]).
        2. An "ORI" residue found in the input (sampled at random if several).
        3. A registered ``infer_ori_strategy`` (see
           ``INFER_ORI_STRATEGIES_TO_FUNCTIONS``).
        4. The COM of the fixed-motif atoms; failing that, all coordinates are
           zeroed out.

    Returns:
        The (translated) ``atom_array``.
    """
    if exists(ori_token):
        center = np.array([float(x) for x in ori_token], dtype=atom_array.coord.dtype)
        atom_array.coord = atom_array.coord - center
        ranked_logger.info(f"Received ori_token argument. Setting origin as {center}.")
        if infer_ori_strategy is not None:
            ranked_logger.warning(
                f"Specified infer_ori_strategy of '{infer_ori_strategy}' will be ignored because an ori_token was provided."
            )
    elif "ORI" in atom_array.res_name:
        center = atom_array[atom_array.res_name == "ORI"].coord
        if center.shape[0] != 1:
            # BUGFIX: np.random.choice requires a 1-D array, so sampling the
            # (n, 3) coordinate array directly raised ValueError. Sample a row
            # index instead.
            row = np.random.choice(center.shape[0], size=1, replace=False)
            center = center[row]
            ranked_logger.info(f"Found multiple ORI tokens in input. Sampled: {center}")
        center = np.nan_to_num(center.squeeze())
        atom_array.coord = atom_array.coord - center
        ranked_logger.info(
            f"Found ORI token in input. Setting origin as token value ({center})."
        )
        if infer_ori_strategy is not None:
            ranked_logger.warning(
                f"Specified infer_ori_strategy of '{infer_ori_strategy}' will be ignored because an ori_token was provided."
            )
    elif infer_ori_strategy is not None:
        if infer_ori_strategy in INFER_ORI_STRATEGIES_TO_FUNCTIONS:
            center = INFER_ORI_STRATEGIES_TO_FUNCTIONS[infer_ori_strategy](atom_array)
            atom_array.coord = atom_array.coord - center
            ranked_logger.info(
                f"Inferred origin using strategy '{infer_ori_strategy}'. Setting origin as {center}."
            )
        else:
            # Previously an unknown strategy was silently ignored; surface it.
            ranked_logger.warning(
                f"Unknown infer_ori_strategy '{infer_ori_strategy}'; no origin offset applied."
            )
    else:
        # No offset requested: fall back to the fixed motif's COM, or zero.
        if np.any(atom_array.is_motif_atom_with_fixed_coord.astype(bool)):
            center = np.nan_to_num(
                np.mean(
                    atom_array.coord[
                        atom_array.is_motif_atom_with_fixed_coord.astype(bool)
                    ],
                    axis=0,
                )
            )
            ranked_logger.info(
                f"No ori_token or infer_ori_strategy provided. Setting origin as COM of fixed motif ({center})."
            )
            atom_array.coord -= center
        else:
            # NOTE(review): this zeroes *all* coordinates rather than merely
            # translating -- presumably fine because fully-diffused inputs get
            # their coordinates resampled downstream; confirm.
            ranked_logger.warning(
                "No ori_token, infer_ori_strategy, or motif provided. Setting [0,0,0] as origin."
            )
            atom_array.coord = np.zeros_like(
                atom_array.coord, dtype=atom_array.coord.dtype
            )
    return atom_array
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
#################################################################################
|
|
512
|
+
# Custom conditioning functions
|
|
513
|
+
#################################################################################
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def spoof_helical_bundle_ss_conditioning_fn(atom_array: struc.AtomArray):
    """Fabricate helix secondary-structure conditioning over diffused tokens.

    Samples a few short islands (3-7 tokens long, 1-3 islands) among the
    diffused tokens, suppresses any island tokens falling within the first or
    last eighth of each chain, and spreads the token-level mask back to atom
    level as a 0/1 float array.
    """
    # NOTE: This assumes that all diffused residues are protein residues -- should be updated if that changes!
    # Compute islands within the subset that is diffused and has secondary structure types.
    token_array = atom_array[get_token_starts(atom_array)]
    diffused_tokens = ~(token_array.is_motif_atom_with_fixed_coord.astype(bool))

    # My reason for sampling from 3-7 is that I don't want to restrict the model too heavily since this is
    # indexed to specific residues, and it will likely extend helices to reasonable lengths once it has started them.
    island_mask = sample_island_tokens(
        diffused_tokens.sum(),
        island_len_min=3,
        island_len_max=7,
        n_islands_min=1,
        n_islands_max=3,
        max_length=None,
    )

    # Lift the per-diffused-token sample onto the full token axis.
    helix_tokens = np.zeros(token_array.array_length(), dtype=bool)
    helix_tokens[diffused_tokens] = island_mask

    # I don't want to sample very near the tails, as this gets too restrictive for the model
    for chain in np.unique(token_array.chain_id):
        in_chain = token_array.chain_id == chain
        chain_idx = np.where(in_chain)[0]
        first, last = chain_idx[0], chain_idx[-1] + 1
        margin = in_chain.sum() // 8

        # Keep only the terminal eighths of this chain in the suppression mask.
        terminal_zone = in_chain.copy()
        terminal_zone[first + margin : last - margin] = False
        helix_tokens[terminal_zone] = False

    conditioning = np.zeros(atom_array.array_length())
    conditioning[spread_token_wise(atom_array, helix_tokens)] = 1
    return conditioning
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
#################################################################################
|
|
560
|
+
# Patching of bad inputs
|
|
561
|
+
#################################################################################
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
def generate_idealized_cb_position(N: np.ndarray, Ca: np.ndarray, C: np.ndarray) -> np.ndarray:
    """
    Generate Cb coordinates given (N, CA, C) as if the given coordinates were from an idealized Alanine.

    Args:
        - N (np.ndarray): coordinates of (pseudo) N atoms [..., L, 3]
        - Ca (np.ndarray): coordinates of (pseudo) Ca atoms [..., L, 3]
        - C (np.ndarray): coordinates of (pseudo) C atoms [..., L, 3]

    Returns:
        Cb (np.ndarray): coordinates of (pseudo) Cb atoms [..., L, 3]
        These will be placed at the idealized Cb distance (based on ALA) from Ca, assuming a frame of the following form:
        - x-axis: along the Ca-C bond
        - z-axis: perpendicular to the Ca-N-C plane, right-handed wrt to (Ca-C) & (Ca-N) vectors.
        - y-axis: in the plane of the Ca-N-C bonds, such that the overall frame is right-handed.

    Reference:
        - https://github.com/google-deepmind/alphafold/blob/d95a92aae161240b645fc10e9d030443011d913e/alphafold/common/residue_constants.py#L126-L335
          ALA:
            ['N', 0, (-0.525, 1.363, 0.000)],   # ca-n bond dist: 1.4606142543
            ['CA', 0, ( 0.000, 0.000, 0.000)],
            ['C', 0, ( 1.526, 0.000, 0.000)],   # ca-c bond dist: 1.526
            ['CB', 0, (-0.529, -0.774, -1.205)],  # cb-ca bond dist: 1.5267422834
    """
    # All-zero inputs are treated as "coordinates missing": return zeros rather
    # than building a garbage frame out of degenerate (zero-length) vectors.
    if np.linalg.norm(N) == 0 and np.linalg.norm(C) == 0 and np.linalg.norm(Ca) == 0:
        return np.zeros_like(N)

    def _safe_normalize(vec: np.ndarray) -> np.ndarray:
        # Normalize along the last axis; zero-length vectors are divided by 1.0
        # (left as zeros) instead of producing NaN from division by zero.
        vec = np.asarray(vec, dtype=float)
        norms = np.linalg.norm(vec, axis=-1, keepdims=True)
        norms = np.where(norms == 0, 1.0, norms)
        return vec / norms

    # ... get local frame x-axis: unit vector along the Ca->C bond
    frame_x = _safe_normalize(C - Ca)

    # ... get local frame z-axis: normal to the N-Ca-C plane
    to_out_of_plane = np.cross(frame_x, _safe_normalize(N - Ca), axis=-1)
    frame_z = _safe_normalize(to_out_of_plane)

    # ... get local frame y-axis: completes the right-handed frame
    frame_y = _safe_normalize(np.cross(frame_z, frame_x, axis=-1))

    # ... place virtual Cb at the idealized ALA offset expressed in this local frame
    Cb = Ca + (-0.529 * frame_x - 0.774 * frame_y - 1.205 * frame_z)
    return Cb
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
def create_cb_atoms(array):
    """
    Build a CB atom at the idealized ALA position for a 4-atom backbone.

    Args:
        array: atom array containing exactly the four atoms N, CA, C, O (in order).

    Returns:
        A length-1 atom array holding the new CB atom (a copy of the CA atom,
        renamed and repositioned).

    Raises:
        ValueError: if the input is not exactly [N, CA, C, O].
    """
    names = array.atom_name.tolist()
    if names != ["N", "CA", "C", "O"]:
        raise ValueError(
            "Input array must contain exactly 4 atoms: N, CA, C, O. Got : {}".format(
                names
            )
        )

    def _coord_of(atom_name):
        # Coordinates of the single atom with this name, squeezed to shape (3,).
        return array.coord[array.atom_name == atom_name].squeeze()

    # Copy the CA atom so residue/chain annotations carry over, then rename it.
    cb_atoms = array[array.atom_name == "CA"].copy()
    cb_atoms.atom_name = np.array(["CB"], dtype=cb_atoms.atom_name.dtype)

    ideal_cb = generate_idealized_cb_position(
        _coord_of("N"), _coord_of("CA"), _coord_of("C")
    )
    cb_atoms.coord = ideal_cb[None]
    return cb_atoms
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def create_o_atoms(array):
    """
    Build a placeholder backbone O atom for a residue that only has N, CA, C.

    The new atom is a copy of the CA atom (so residue/chain annotations carry
    over), renamed to "O" with element "O", and placed at the C atom's
    coordinates.
    NOTE(review): placing O exactly on top of C looks like a deliberate
    placeholder rather than a geometric reconstruction — confirm with callers.

    Args:
        array: atom array containing exactly the three atoms N, CA, C (in order).

    Returns:
        A length-1 atom array holding the new O atom.

    Raises:
        ValueError: if the input is not exactly [N, CA, C].
    """
    if array.atom_name.tolist() != ["N", "CA", "C"]:
        # Bug fix: the message previously claimed 4 atoms (N, CA, C, O) were
        # required, contradicting the actual 3-atom check above.
        raise ValueError(
            "Input array must contain exactly 3 atoms: N, CA, C. Got : {}".format(
                array.atom_name.tolist()
            )
        )

    o_atoms = array[array.atom_name == "CA"].copy()
    o_atoms.atom_name = np.array(["O"], dtype=o_atoms.atom_name.dtype)
    o_atoms.element = np.array(["O"], dtype=o_atoms.element.dtype)
    o_atoms.coord = array.coord[array.atom_name == "C"].squeeze()[None]

    return o_atoms
|