rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,717 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import functools
|
|
3
|
+
import logging
|
|
4
|
+
from os import PathLike
|
|
5
|
+
|
|
6
|
+
import biotite.structure as struc
|
|
7
|
+
import numpy as np
|
|
8
|
+
from atomworks.constants import STANDARD_AA
|
|
9
|
+
from atomworks.io.utils.io_utils import to_cif_file
|
|
10
|
+
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
|
11
|
+
from atomworks.ml.utils.token import (
|
|
12
|
+
get_token_starts,
|
|
13
|
+
)
|
|
14
|
+
from rfd3.constants import (
|
|
15
|
+
INFERENCE_ANNOTATIONS,
|
|
16
|
+
OPTIONAL_CONDITIONING_VALUES,
|
|
17
|
+
REQUIRED_INFERENCE_ANNOTATIONS,
|
|
18
|
+
)
|
|
19
|
+
from rfd3.inference.symmetry.symmetry_utils import (
|
|
20
|
+
center_symmetric_src_atom_array,
|
|
21
|
+
make_symmetric_atom_array,
|
|
22
|
+
)
|
|
23
|
+
from rfd3.transforms.conditioning_base import (
|
|
24
|
+
check_has_required_conditioning_annotations,
|
|
25
|
+
convert_existing_annotations_to_bool,
|
|
26
|
+
get_motif_features,
|
|
27
|
+
set_default_conditioning_annotations,
|
|
28
|
+
)
|
|
29
|
+
from rfd3.transforms.util_transforms import assign_types_
|
|
30
|
+
from rfd3.utils.inference import (
|
|
31
|
+
create_cb_atoms,
|
|
32
|
+
create_o_atoms,
|
|
33
|
+
extract_ligand_array,
|
|
34
|
+
inference_load_,
|
|
35
|
+
set_com,
|
|
36
|
+
set_common_annotations,
|
|
37
|
+
set_indices,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
from foundry.common import exists
|
|
41
|
+
from foundry.utils.components import (
|
|
42
|
+
fetch_mask_from_component,
|
|
43
|
+
fetch_mask_from_idx,
|
|
44
|
+
get_design_pattern_with_constraints,
|
|
45
|
+
get_motif_components_and_breaks,
|
|
46
|
+
get_name_mask,
|
|
47
|
+
split_contig,
|
|
48
|
+
)
|
|
49
|
+
from foundry.utils.ddp import RankedLogger
|
|
50
|
+
|
|
51
|
+
# NOTE(review): basicConfig configures the root logger as a side effect of
# importing this module; consider moving it to an application entry point.
logging.basicConfig(level=logging.INFO)
# Rank-aware logger; rank_zero_only=True presumably limits output to rank 0
# (see foundry.utils.ddp.RankedLogger).
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

# Module-level AF3 sequence encoding and the residue names it flags as amino-acid-like.
sequence_encoding = AF3SequenceEncoding()
# NOTE(review): not referenced in the visible part of this module; confirm it is still needed.
_aa_like_res_names = sequence_encoding.all_res_names[sequence_encoding.is_aa_like]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def assert_non_intersecting_contigs(indexed_components, unindexed_components):
    """Check that no motif residue is requested as both indexed and unindexed.

    Only motif components (those whose string form starts with a letter,
    e.g. ``"A20"``) are checked. Numeric components such as diffused-residue
    counts may legitimately appear in both lists.

    Arguments:
        - indexed_components: components parsed from the indexing contig,
          e.g. ``[2, "A20", "A21", 2]``.
        - unindexed_components: components to unindex, e.g. ``["A25"]``.

    Raises:
        AssertionError: if any unindexed motif component also appears in
            ``indexed_components``. Raised explicitly (not via ``assert``) so
            the check is not stripped under ``python -O``.
    """
    overlapping = [
        component
        for component in unindexed_components
        # Check the motif prefix first; ``str(...)[:1]`` tolerates ints and empty
        # strings instead of raising TypeError / IndexError on ``component[0]``.
        if str(component)[:1].isalpha() and component in indexed_components
    ]
    if overlapping:
        raise AssertionError(
            "Unindexed residues must not be part of the indexing contig. got: {} and {}".format(
                unindexed_components, indexed_components
            )
        )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def set_atom_level_argument(atom_array, args, name: str):
    """Set a per-atom annotation ``name`` from a nested user specification.

    Every atom starts at ``OPTIONAL_CONDITIONING_VALUES.get(name, np.nan)``.
    Atoms selected by ``args`` are then overwritten with the requested value.

    Arguments:
        - atom_array: atom array to annotate (annotated in place and returned).
        - args: ``{component: {"ATOM1,ATOM2": value, ...}, ...}`` or None. Each
          component is resolved with ``fetch_mask_from_component``. Each inner key
          is a comma-separated list of atom names within that component.
          Whitespace around names and empty entries (e.g. trailing commas) are
          ignored.
        - name: name of the annotation to set.

    Returns:
        The same ``atom_array`` with annotation ``name`` set.

    Raises:
        AssertionError: if a requested atom-name list is empty, or the number of
            matched atoms differs from the number of names requested.
    """
    default_value = OPTIONAL_CONDITIONING_VALUES.get(name, np.nan)
    # dtype is inferred from the default value (float for the NaN fallback).
    atom_values = np.full(atom_array.array_length(), default_value)

    for component_name, atom_spec in (args or {}).items():
        component_mask = fetch_mask_from_component(
            component_name, atom_array=atom_array
        )
        for names, value in atom_spec.items():
            # Tolerate "N, CA" and trailing commas rather than failing the count check.
            requested = [n.strip() for n in names.split(",") if n.strip()]
            mask = component_mask & np.isin(atom_array.atom_name, requested)
            # Explicit raise so the validation survives `python -O`.
            if not requested or mask.sum() != len(requested):
                raise AssertionError(
                    f"Not all atoms in {names} found in {atom_array.atom_name}"
                )
            atom_values[mask] = value

    atom_array.set_annotation(name, atom_values)
    return atom_array
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def fetch_motif_residue_(
    src_chain,
    src_resid,
    *,
    components,
    src_atom_array,
    redesign_motif_sidechains,
    unindexed_components,
    unfixed_sequence_components,
    fixed_atoms,
    unfix_all,
    flexible_backbone,
    unfix_residues,
):
    """
    Given source chain and resid, return the residue from the source atom array,
    annotated with its motif conditioning flags.

    NB: When sidechains are redesigned for a standard amino acid, the residue is reduced to
    its backbone plus an idealized CB and renamed ALA. This avoids leaking the identity of
    the original residue (e.g. whether it is a glycine) if sequence is masked during inference.

    Arguments:
        - src_chain, src_resid: chain ID and residue ID of the motif residue. Together they
          form the component ID (e.g. "A20") used for every lookup below.
        - components: full list of contig components (only used in the error message).
        - src_atom_array: source atom array to fetch the residue from; must not be None.
        - redesign_motif_sidechains: bool, or list of component IDs whose sidechains are
          redesigned.
        - unindexed_components: component IDs to unindex; such residues keep only their
          motif atoms.
        - unfixed_sequence_components: component IDs whose sequence is not fixed (or None).
        - fixed_atoms: dict mapping component IDs to the atom selection (passed to
          `get_name_mask`) that forms the motif.
        - unfix_all: unfix the coordinates of every motif residue.
        - flexible_backbone: flag N, CA, C, O as flexible motif atoms.
        - unfix_residues: component IDs whose coordinates are unfixed.

    Returns:
        - Atom array of the residue with conditioning annotations set and res_id reset to 1.

    Raises:
        - AssertionError: if `src_atom_array` is None.
        - ValueError: if no atoms are selected for the residue, or if sidechains are
          redesigned and fewer than 3 atoms are present.
    """
    component_id = f"{src_chain}{src_resid}"

    # Report the (missing) source array rather than the builtin `input` function.
    assert (
        src_atom_array is not None
    ), "Motif provided in contigs, but no input provided. input={} contig={}".format(
        src_atom_array, components
    )

    # ... Fetch residue in the input atom array
    mask = fetch_mask_from_idx(component_id, atom_array=src_atom_array)
    subarray = src_atom_array[mask]
    if subarray.shape[0] == 0:
        raise ValueError(f"Residue {component_id} not found in input atom array.")
    res_name = subarray.res_name[0]

    # A list selects which residues get their sidechains redesigned
    if isinstance(redesign_motif_sidechains, list):
        redesign_motif_sidechains = component_id in redesign_motif_sidechains

    # Assign base properties
    subarray = set_default_conditioning_annotations(
        subarray, motif=True, unindexed=False, dtype=int
    )  # all values init to True (fix all)

    # Assign is motif atom and sequence
    if exists(atoms := fixed_atoms.get(component_id)):
        # Only the selected atoms are motif atoms. The fixed-coordinate flag is
        # derived from `is_motif_atom` further below.
        atom_mask = get_name_mask(subarray.atom_name, atoms, res_name)
        subarray.set_annotation("is_motif_atom", atom_mask)

    elif redesign_motif_sidechains and res_name in STANDARD_AA:
        n_atoms = subarray.shape[0]
        diffuse_oxygen = False
        if n_atoms < 3:
            raise ValueError(
                f"Not enough data for {src_chain}{src_resid} in input atom array."
            )
        if n_atoms == 3:
            # Handle cases with N, CA, C only; the synthesized O is flagged for generation
            subarray = subarray + create_o_atoms(subarray.copy())
            diffuse_oxygen = True

        # Subset to the backbone atoms (N, CA, C, O) only
        subarray = subarray[np.isin(subarray.atom_name, ["N", "CA", "C", "O"])]

        # Exactly N, CA, C, O but no CB. Place CB onto an idealized position and convert to ALA.
        # Sequence name ALA ensures the padded atoms to be diffused from the fixed backbone
        # are placed on the CB so as to not leak the identity of the residue.
        subarray = subarray + create_cb_atoms(subarray.copy())

        # Sequence name must be set to ALA such that the central atom is correctly CB
        subarray.res_name = np.full_like(
            subarray.res_name, "ALA", dtype=subarray.res_name.dtype
        )
        # The first 4 atoms (3 if O was synthesized) are motif atoms; the rest are diffused.
        subarray.set_annotation(
            "is_motif_atom",
            (
                np.arange(subarray.shape[0], dtype=int) < (4 - int(diffuse_oxygen))
            ).astype(int),
        )
        subarray.set_annotation(
            "is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
        )

    # Coordinates: unfix everything if requested, otherwise fix exactly the motif atoms
    if unfix_all or component_id in unfix_residues:
        subarray.set_annotation(
            "is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
        )
    else:
        subarray.set_annotation(
            "is_motif_atom_with_fixed_coord", subarray.is_motif_atom.copy()
        )

    if flexible_backbone:
        backbone_atoms = ["N", "CA", "C", "O"]
        is_flexible_motif_atom = np.isin(subarray.atom_name, backbone_atoms)
        subarray.set_annotation("is_flexible_motif_atom", is_flexible_motif_atom)
    else:
        subarray.set_annotation(
            "is_flexible_motif_atom", np.zeros(subarray.shape[0], dtype=bool)
        )

    if component_id in unindexed_components:
        subarray.set_annotation(
            "is_motif_atom_unindexed", subarray.is_motif_atom.copy()
        )
        # Unindexed residues keep only their motif atoms
        subarray = subarray[subarray.is_motif_atom.astype(bool)]

    # ... Relax sequence constraint if provided
    if (
        exists(unfixed_sequence_components)
        and component_id in unfixed_sequence_components
    ):
        ranked_logger.info(
            "Unfixing sequence for motif {}{}".format(src_chain, src_resid)
        )
        subarray.set_annotation(
            "is_motif_atom_with_fixed_seq",
            np.zeros(subarray.shape[0], dtype=int),
        )

    # ... Double check that required annotations are set within the scope of this function only
    check_has_required_conditioning_annotations(subarray)
    subarray = set_common_annotations(subarray)
    subarray.set_annotation("res_id", np.full(subarray.shape[0], 1))  # Reset to 1
    return subarray
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def create_diffused_residues_(n):
    """Create ``n`` placeholder ALA residues whose atoms are to be diffused.

    Each residue is built with five atoms (N, CA, C, O, CB) at the origin and
    res_ids ``1..n``. Default non-motif conditioning annotations and the common
    annotations are then applied.

    Arguments:
        - n: number of residues to create; must be positive.

    Returns:
        The annotated atom array.

    Raises:
        ValueError: if ``n <= 0``.
    """
    if n <= 0:
        raise ValueError(f"Negative/null residue count ({n}) not allowed.")

    atom_names = ["N", "CA", "C", "O", "CB"]
    elements = ["N", "C", "C", "O", "C"]

    # Build the atom list directly (the original used a list comprehension only
    # for its `atoms.extend` side effect).
    atoms = [
        struc.Atom(np.zeros(3, dtype=np.float32), res_name="ALA", res_id=res_id)
        for res_id in range(1, n + 1)
        for _ in atom_names
    ]
    array = struc.array(atoms)
    array.set_annotation("element", np.array(elements * n, dtype="<U2"))
    array.set_annotation("atom_name", np.array(atom_names * n, dtype="<U2"))
    array = set_default_conditioning_annotations(array, motif=False)
    array = set_common_annotations(array)
    return array
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def accumulate_components(
    components,
    src_atom_array,
    redesign_motif_sidechains,
    unindexed_components: list[str],
    unfixed_sequence_components: list[str],
    breaks: list,
    fixed_atoms: dict,
    unfix_all: bool,
    optional_conditions: list[str],
    flexible_backbone: bool,
    *,
    start_chain="A",
    unfix_residues: list[str],
    start_resid=1,
):
    """
    Accumulate contig components into a single atom array.

    Components have three types: the end of a chain ("/0"), a motif (e.g. "A20" or "A21"),
    or a number indicating how many diffused residues to create.

    Arguments:
        - components: list of components, e.g. [2, A20, A21, 2, A25, 3, A30, /0, 3]
        - src_atom_array: the source atom array to fetch motifs from, or None if no input is provided.
        - redesign_motif_sidechains: whether to diffuse the sidechains of the input motifs, or a list
          of motif components for which to do so.
        - unindexed_components: list of components to unindex e.g. [A20, A21]
        - unfixed_sequence_components: motif components whose sequence is not fixed.
        - breaks: per-component flags aligned with `components` (None means no flags).
          A truthy flag on a motif annotates it with `is_motif_atom_unindexed_motif_breakpoint`.
          The first such flag also resets the chain to `start_chain`, and from then on numeric
          components only advance res_id instead of creating diffused residues.
        - fixed_atoms: dictionary of fixed atoms for each component (previously called `contig_atoms`)
        - unfix_all: whether to fully unfix the motif coordinates
        - optional_conditions: annotation names initialized on diffused residues with their
          `OPTIONAL_CONDITIONING_VALUES` default.
        - flexible_backbone: whether to allow flexible backbone for motifs
        - start_chain: chain ID of the first chain; incremented at every "/0".
        - unfix_residues: list of residues to unfix. Overrides `unfix_all` for specific residues.
        - start_resid: res_id assigned to the first residue.

    Returns:
        - Accumulated atom array with components, and is_motif labels

    Raises:
        - ValueError: if the components produce no residues at all.
    """
    # ... Create component assignment functions
    fetch_motif_residue = functools.partial(
        fetch_motif_residue_,
        components=components,
        src_atom_array=src_atom_array,
        redesign_motif_sidechains=redesign_motif_sidechains,
        unindexed_components=unindexed_components,
        unfixed_sequence_components=unfixed_sequence_components,
        fixed_atoms=fixed_atoms,
        unfix_all=unfix_all,
        flexible_backbone=flexible_backbone,
        unfix_residues=unfix_residues,
    )

    # ... For loop accum variables
    if breaks is None:
        breaks = [None] * len(components)
    # Once one unindexed component is added, stop adding diffused residues
    unindexed_components_started = False
    atom_array_accum = []
    chain = start_chain
    res_id = start_resid
    molecule_id = 0

    # Insert contig information one by one
    for component, is_break in zip(components, breaks):
        if component == "/0":
            # Chain break: move to the next chain letter and restart numbering
            chain = chr(ord(chain) + 1)
            molecule_id += 1
            res_id = 1
            continue

        # Create array to insert
        if str(component)[0].isalpha():  # motif (e.g. "A22")
            atom_array_insert = fetch_motif_residue(*split_contig(component))
            n = 1
            if is_break:  # None / False mean "no break"
                if not unindexed_components_started:
                    chain = start_chain
                    unindexed_components_started = True
                atom_array_insert.set_annotation(
                    "is_motif_atom_unindexed_motif_breakpoint",
                    np.ones(atom_array_insert.shape[0], dtype=int),
                )
        else:
            n = int(component)
            if n == 0 or unindexed_components_started:
                res_id += n
                continue
            atom_array_insert = create_diffused_residues_(n)
            # NOTE(review): unlike set_atom_level_argument (which falls back to NaN via
            # `.get`), this lookup raises KeyError for keys missing from
            # OPTIONAL_CONDITIONING_VALUES; confirm all optional condition keys are present.
            for key in optional_conditions:
                atom_array_insert.set_annotation(
                    key,
                    np.full(
                        atom_array_insert.array_length(),
                        OPTIONAL_CONDITIONING_VALUES[key],
                        dtype=int,
                    ),
                )

        # ... Set index of insertion
        atom_array_insert = set_indices(
            array=atom_array_insert,
            chain=chain,
            res_id_start=res_id,
            molecule_id=molecule_id,
            component=component,
        )

        n_tokens = len(get_token_starts(atom_array_insert))
        assert (
            n_tokens == n
        ), f"Mismatch in number of residues: expected {n}, got {n_tokens} in \n{atom_array_insert}"

        # ... Insert & Increment residue ID
        atom_array_accum.append(atom_array_insert)
        res_id += n

    if not atom_array_accum:
        raise ValueError(f"No residues were created from components {components}.")

    atom_array_accum = struc.concatenate(atom_array_accum)
    atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)

    # Shift unindexed residues' res_ids past the largest indexed res_id to avoid duplicates
    is_unindexed = atom_array_accum.is_motif_atom_unindexed.astype(bool)
    if is_unindexed.any() and not is_unindexed.all():
        max_id = np.max(atom_array_accum.res_id[~is_unindexed])
        min_id_udx = np.min(atom_array_accum.res_id[is_unindexed])
        atom_array_accum.res_id[is_unindexed] += max_id - min_id_udx + 1

    return atom_array_accum
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
#################################################################################
|
|
396
|
+
# Custom conditioning functions
|
|
397
|
+
#################################################################################
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def create_atom_array_from_design_specification_legacy(
|
|
401
|
+
*,
|
|
402
|
+
# Specification args:
|
|
403
|
+
input: PathLike = None,
|
|
404
|
+
length: str = "100-300",
|
|
405
|
+
contig: str = None,
|
|
406
|
+
fixed_atoms: dict = None,
|
|
407
|
+
unindex: str = None,
|
|
408
|
+
unfix_sequence: str = None,
|
|
409
|
+
redesign_motif_sidechains: bool = False,
|
|
410
|
+
unfix_all=False,
|
|
411
|
+
unfix_specific: str = None,
|
|
412
|
+
flexible_backbone: bool = False,
|
|
413
|
+
# Args for biomolecular design (Enzymes, DNA/PNA):
|
|
414
|
+
ligand: str = None,
|
|
415
|
+
ori_token: list[float] = None,
|
|
416
|
+
infer_ori_strategy: str | None = None,
|
|
417
|
+
atomwise_rasa: dict = None,
|
|
418
|
+
atomwise_hbond: dict = None,
|
|
419
|
+
# Additional args:
|
|
420
|
+
out_path=None,
|
|
421
|
+
cif_parser_args=None,
|
|
422
|
+
# PPI Kwargs
|
|
423
|
+
atom_level_hotspots: dict | None = None,
|
|
424
|
+
# SS conditioning kwargs
|
|
425
|
+
is_helix: dict | None = None,
|
|
426
|
+
is_sheet: dict | None = None,
|
|
427
|
+
is_loop: dict | None = None,
|
|
428
|
+
spoof_helical_bundle_ss_conditioning: bool = False,
|
|
429
|
+
symmetry: dict = None,
|
|
430
|
+
# Low-temperature global conditioning args
|
|
431
|
+
plddt_enhanced: bool = True,
|
|
432
|
+
is_non_loopy: bool | None = None,
|
|
433
|
+
# Partial diff args:
|
|
434
|
+
partial_t: float | None = None, # Optional noise scale for partial diffusion
|
|
435
|
+
**_, # dump additional args
|
|
436
|
+
):
|
|
437
|
+
"""
|
|
438
|
+
Create pre-pipeline CIF file.
|
|
439
|
+
|
|
440
|
+
Arguments:
|
|
441
|
+
- input: path to input pdb containing coordinate data
|
|
442
|
+
- contig: your typical contig string '10-10,A20-21,5-5,A25-25,5-5,A30-30,10-10'.
|
|
443
|
+
- unindex: string of residue indices to unindex, e.g. "A20,A21" or "A20-21". Note the latter will be treated as two contiguous
|
|
444
|
+
residues whereas the former will end up as two uncorrelated residues.
|
|
445
|
+
- unfix_sequence: contig string of components to unfix sequence for.
|
|
446
|
+
- unfix_specific: comma separated residues to unfix coordinates for. "ALL" to unfix every motif.
|
|
447
|
+
- length: required total length (optional)
|
|
448
|
+
- ligand: name of ligand to keep from input pdb, or path to a cif file containing the ligand
|
|
449
|
+
- ori_token: coordinates for origin relative to coordinates in input file.
|
|
450
|
+
- infer_ori_strategy: string argument controlling how the ori token is inferred if not otherwise specified.
|
|
451
|
+
If None, the ori token will be set to the COM of the motif, or to [0,0,0] for unconditional generation.
|
|
452
|
+
Currently supported strategies:
|
|
453
|
+
- "hotspots": move 10A along an outward normal vector from the COM of the hotspots.
|
|
454
|
+
|
|
455
|
+
Returns:
|
|
456
|
+
- atom_array with all required conditioning annotations set appropriately.
|
|
457
|
+
"""
|
|
458
|
+
###########################################################################################################################
|
|
459
|
+
|
|
460
|
+
# 1) Load input data if provided
|
|
461
|
+
if exists(input):
|
|
462
|
+
atom_array_input = inference_load_(input, cif_parser_args=cif_parser_args)[
|
|
463
|
+
"atom_array"
|
|
464
|
+
]
|
|
465
|
+
# If we are doing symmetric design, we need to center the full input atom array at the origin (for getting symmetry frames)
|
|
466
|
+
if exists(symmetry) and symmetry.get("id"):
|
|
467
|
+
atom_array_input = center_symmetric_src_atom_array(atom_array_input)
|
|
468
|
+
elif exists(contig) or exists(length):
|
|
469
|
+
atom_array_input = None
|
|
470
|
+
else:
|
|
471
|
+
raise ValueError("Either 'input' or 'contig' / 'length' must be provided.")
|
|
472
|
+
if isinstance(length, int):
|
|
473
|
+
length = f"{length}-{length}"
|
|
474
|
+
if exists(length) and not exists(contig):
|
|
475
|
+
# Handle cases where contigs aren't specified
|
|
476
|
+
if not exists(unindex) and not exists(flexible_backbone):
|
|
477
|
+
if exists(fixed_atoms):
|
|
478
|
+
# ensure that fixed atoms are in the input, else raise error
|
|
479
|
+
_ = [
|
|
480
|
+
fetch_mask_from_component(key, atom_array=atom_array_input)
|
|
481
|
+
for key in fixed_atoms.keys()
|
|
482
|
+
]
|
|
483
|
+
ranked_logger.warning(
|
|
484
|
+
"No input contig specified and no motif, running unconditional generation"
|
|
485
|
+
)
|
|
486
|
+
indexed_components_provided = False
|
|
487
|
+
contig = length
|
|
488
|
+
else:
|
|
489
|
+
indexed_components_provided = True
|
|
490
|
+
if not exists(fixed_atoms):
|
|
491
|
+
fixed_atoms = {}
|
|
492
|
+
|
|
493
|
+
optional_conditions = []
|
|
494
|
+
if exists(atomwise_rasa):
|
|
495
|
+
set_atom_level_argument(atom_array_input, atomwise_rasa, "rasa_bin")
|
|
496
|
+
optional_conditions.append("rasa_bin")
|
|
497
|
+
if exists(atomwise_hbond):
|
|
498
|
+
for key, value in atomwise_hbond.items():
|
|
499
|
+
set_atom_level_argument(atom_array_input, value, key)
|
|
500
|
+
optional_conditions.append(key)
|
|
501
|
+
if exists(atom_level_hotspots):
|
|
502
|
+
set_atom_level_argument(
|
|
503
|
+
atom_array_input, atom_level_hotspots, "is_atom_level_hotspot"
|
|
504
|
+
)
|
|
505
|
+
optional_conditions.append("is_atom_level_hotspot")
|
|
506
|
+
|
|
507
|
+
# 2) Parse contigs into components
|
|
508
|
+
indexed_components = get_design_pattern_with_constraints(
|
|
509
|
+
contig, length
|
|
510
|
+
) # e.g. [2, A20, A21, 2, A25, 3, A30, /0, 3]
|
|
511
|
+
|
|
512
|
+
# Parse redesign_motif_sidechains if necessary
|
|
513
|
+
if isinstance(redesign_motif_sidechains, str):
|
|
514
|
+
redesign_motif_sidechains = get_design_pattern_with_constraints(
|
|
515
|
+
redesign_motif_sidechains
|
|
516
|
+
)
|
|
517
|
+
###########################################################################################################################
|
|
518
|
+
|
|
519
|
+
# ... Add unindexed components
|
|
520
|
+
unindexed_components, unindexed_breaks = (
|
|
521
|
+
get_motif_components_and_breaks(unindex) if exists(unindex) else ([], [])
|
|
522
|
+
)
|
|
523
|
+
breaks = [None] * len(indexed_components) + unindexed_breaks
|
|
524
|
+
assert_non_intersecting_contigs(indexed_components, unindexed_components)
|
|
525
|
+
|
|
526
|
+
# ... Expand unfix_sequence into components
|
|
527
|
+
unfixed_sequence_components = (
|
|
528
|
+
get_design_pattern_with_constraints(unfix_sequence) if unfix_sequence else []
|
|
529
|
+
)
|
|
530
|
+
|
|
531
|
+
# Determine which residues to unfix
|
|
532
|
+
unfix_residues = []
|
|
533
|
+
if isinstance(unfix_specific, list):
|
|
534
|
+
unfix_residues = [str(u) for u in unfix_specific]
|
|
535
|
+
elif isinstance(unfix_specific, str):
|
|
536
|
+
if unfix_specific.upper() == "ALL":
|
|
537
|
+
unfix_all = True
|
|
538
|
+
elif unfix_specific:
|
|
539
|
+
unfix_residues, _ = get_motif_components_and_breaks(
|
|
540
|
+
unfix_specific, index_all=True
|
|
541
|
+
)
|
|
542
|
+
|
|
543
|
+
# 3) Create atom array from components
|
|
544
|
+
if exists(partial_t):
|
|
545
|
+
ranked_logger.info("Using partial diffusion with t=%s", partial_t)
|
|
546
|
+
atom_array = assign_types_(copy.deepcopy(atom_array_input))
|
|
547
|
+
atom_array = atom_array[atom_array.is_protein]
|
|
548
|
+
|
|
549
|
+
# Set the whole thing without constraints
|
|
550
|
+
atom_array = set_default_conditioning_annotations(
|
|
551
|
+
atom_array, motif=False, unindexed=False
|
|
552
|
+
)
|
|
553
|
+
atom_array = set_common_annotations(
|
|
554
|
+
atom_array, set_src_component_to_res_name=False
|
|
555
|
+
)
|
|
556
|
+
|
|
557
|
+
# Fix parts in the atom array as fixed components
|
|
558
|
+
set_default_conditioning_annotations(atom_array, motif=False, unindexed=False)
|
|
559
|
+
if indexed_components and indexed_components_provided:
|
|
560
|
+
for component in indexed_components:
|
|
561
|
+
if str(component)[0].isalpha():
|
|
562
|
+
mask = fetch_mask_from_component(component, atom_array=atom_array)
|
|
563
|
+
|
|
564
|
+
# Set the component as a motif token
|
|
565
|
+
set_default_conditioning_annotations(
|
|
566
|
+
atom_array, motif=True, unindexed=False, mask=mask
|
|
567
|
+
)
|
|
568
|
+
|
|
569
|
+
# Set the fixed atoms of the component
|
|
570
|
+
if mask.any():
|
|
571
|
+
# Also handle fixed atoms
|
|
572
|
+
if component in fixed_atoms:
|
|
573
|
+
atom_mask = get_name_mask(
|
|
574
|
+
atom_array.atom_name[mask],
|
|
575
|
+
fixed_atoms[component],
|
|
576
|
+
atom_array.res_name[mask][0],
|
|
577
|
+
)
|
|
578
|
+
# If specific fixed atoms are selected, set fixed coordinates to those specified
|
|
579
|
+
atom_array.is_motif_atom_with_fixed_coord[mask] = atom_mask
|
|
580
|
+
else:
|
|
581
|
+
# Otherwise fix the entire token.
|
|
582
|
+
atom_array.is_motif_atom_with_fixed_coord[mask] = 1
|
|
583
|
+
|
|
584
|
+
# Append unindexed components from input specifcation
|
|
585
|
+
if unindexed_components:
|
|
586
|
+
start_resid = np.max(atom_array.res_id) + 1
|
|
587
|
+
start_chain = atom_array.chain_id[
|
|
588
|
+
0
|
|
589
|
+
] # HACK: set chain ID for unindexed residues as whatever the input has
|
|
590
|
+
atom_array_append = accumulate_components(
|
|
591
|
+
# Normal stuff:
|
|
592
|
+
components=unindexed_components,
|
|
593
|
+
breaks=unindexed_breaks,
|
|
594
|
+
# Regular other stuff
|
|
595
|
+
src_atom_array=atom_array_input,
|
|
596
|
+
redesign_motif_sidechains=redesign_motif_sidechains,
|
|
597
|
+
unindexed_components=unindexed_components,
|
|
598
|
+
unfixed_sequence_components=unfixed_sequence_components,
|
|
599
|
+
fixed_atoms=fixed_atoms,
|
|
600
|
+
unfix_all=unfix_all,
|
|
601
|
+
optional_conditions=optional_conditions,
|
|
602
|
+
flexible_backbone=flexible_backbone,
|
|
603
|
+
unfix_residues=unfix_residues,
|
|
604
|
+
start_chain=start_chain,
|
|
605
|
+
start_resid=start_resid,
|
|
606
|
+
)
|
|
607
|
+
atom_array = atom_array + atom_array_append
|
|
608
|
+
else:
|
|
609
|
+
atom_array = accumulate_components(
|
|
610
|
+
components=indexed_components + unindexed_components,
|
|
611
|
+
src_atom_array=atom_array_input,
|
|
612
|
+
redesign_motif_sidechains=redesign_motif_sidechains,
|
|
613
|
+
unindexed_components=unindexed_components,
|
|
614
|
+
unfixed_sequence_components=unfixed_sequence_components,
|
|
615
|
+
breaks=breaks,
|
|
616
|
+
fixed_atoms=fixed_atoms,
|
|
617
|
+
unfix_all=unfix_all,
|
|
618
|
+
optional_conditions=optional_conditions,
|
|
619
|
+
flexible_backbone=flexible_backbone,
|
|
620
|
+
unfix_residues=unfix_residues,
|
|
621
|
+
)
|
|
622
|
+
|
|
623
|
+
# Spoof assignments for is_motif_token
|
|
624
|
+
f = get_motif_features(atom_array)
|
|
625
|
+
is_motif_token = f["is_motif_token"]
|
|
626
|
+
atom_array.set_annotation("is_motif_token", is_motif_token.astype(int))
|
|
627
|
+
is_motif_atom = f["is_motif_atom"]
|
|
628
|
+
atom_array.set_annotation("is_motif_atom", is_motif_atom.astype(int))
|
|
629
|
+
|
|
630
|
+
# ... If ligand, post-pend it
|
|
631
|
+
if exists(ligand):
|
|
632
|
+
ligand_array = extract_ligand_array(
|
|
633
|
+
atom_array_input,
|
|
634
|
+
ligand,
|
|
635
|
+
fixed_atoms,
|
|
636
|
+
additional_annotations=set(
|
|
637
|
+
list(atom_array.get_annotation_categories())
|
|
638
|
+
+ list(atom_array_input.get_annotation_categories())
|
|
639
|
+
+ ["is_motif_atom", "is_motif_token"]
|
|
640
|
+
),
|
|
641
|
+
)
|
|
642
|
+
ligand_array.res_id = (
|
|
643
|
+
ligand_array.res_id
|
|
644
|
+
- np.min(ligand_array.res_id)
|
|
645
|
+
+ np.max(atom_array.res_id)
|
|
646
|
+
+ 1
|
|
647
|
+
)
|
|
648
|
+
atom_array = atom_array + ligand_array
|
|
649
|
+
|
|
650
|
+
# ... Apply symmetry if it exists ahead of any other processing
|
|
651
|
+
if exists(symmetry) and symmetry.get("id"):
|
|
652
|
+
atom_array = make_symmetric_atom_array(
|
|
653
|
+
atom_array, symmetry, sm=ligand, src_atom_array=atom_array_input
|
|
654
|
+
)
|
|
655
|
+
|
|
656
|
+
# ... Input frame and ORI token handling
|
|
657
|
+
if exists(partial_t):
|
|
658
|
+
# For symmetric structures, avoid COM centering that would collapse chains
|
|
659
|
+
if exists(symmetry) and symmetry.get("id"):
|
|
660
|
+
ranked_logger.info(
|
|
661
|
+
"Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
|
|
662
|
+
)
|
|
663
|
+
else:
|
|
664
|
+
atom_array = set_com(atom_array, ori_token=None, infer_ori_strategy="com")
|
|
665
|
+
atom_array.set_annotation(
|
|
666
|
+
"partial_t", np.full(atom_array.shape[0], partial_t, dtype=float)
|
|
667
|
+
)
|
|
668
|
+
else:
|
|
669
|
+
atom_array = set_com(
|
|
670
|
+
atom_array, ori_token=ori_token, infer_ori_strategy=infer_ori_strategy
|
|
671
|
+
)
|
|
672
|
+
# diffused atoms initialized at origin
|
|
673
|
+
atom_array.coord[~atom_array.is_motif_atom_with_fixed_coord.astype(bool), :] = (
|
|
674
|
+
0.0
|
|
675
|
+
)
|
|
676
|
+
|
|
677
|
+
# This is an annotation on the diffused regions, so must be added after accumulate_components
|
|
678
|
+
if spoof_helical_bundle_ss_conditioning:
|
|
679
|
+
is_helix = spoof_helical_bundle_ss_conditioning_fn(atom_array)
|
|
680
|
+
is_sheet = None
|
|
681
|
+
is_loop = None
|
|
682
|
+
if exists(is_helix):
|
|
683
|
+
set_atom_level_argument(atom_array, is_helix, "is_helix")
|
|
684
|
+
if exists(is_sheet):
|
|
685
|
+
set_atom_level_argument(atom_array, is_sheet, "is_sheet")
|
|
686
|
+
optional_conditions.append("is_sheet")
|
|
687
|
+
if exists(is_loop):
|
|
688
|
+
set_atom_level_argument(atom_array, is_loop, "is_loop")
|
|
689
|
+
optional_conditions.append("is_loop")
|
|
690
|
+
|
|
691
|
+
is_non_loopy_annot = np.zeros(atom_array.array_length(), dtype=int)
|
|
692
|
+
diffused_region_mask = ~(atom_array.is_motif_token.astype(bool))
|
|
693
|
+
if exists(is_non_loopy):
|
|
694
|
+
is_non_loopy_annot[diffused_region_mask] = 1 if is_non_loopy else -1
|
|
695
|
+
|
|
696
|
+
atom_array.set_annotation("is_non_loopy", is_non_loopy_annot)
|
|
697
|
+
atom_array.set_annotation("is_non_loopy_atom_level", is_non_loopy_annot)
|
|
698
|
+
|
|
699
|
+
if plddt_enhanced:
|
|
700
|
+
atom_array.set_annotation(
|
|
701
|
+
"ref_plddt", np.ones((atom_array.array_length(),), dtype=int)
|
|
702
|
+
)
|
|
703
|
+
|
|
704
|
+
# Ensure correct annotations before saving
|
|
705
|
+
check_has_required_conditioning_annotations(
|
|
706
|
+
atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
|
|
707
|
+
)
|
|
708
|
+
convert_existing_annotations_to_bool(atom_array)
|
|
709
|
+
|
|
710
|
+
if "atom_id" in atom_array.get_annotation_categories():
|
|
711
|
+
ranked_logger.info("Removing atom_id annotation...")
|
|
712
|
+
atom_array.del_annotation("atom_id")
|
|
713
|
+
|
|
714
|
+
if out_path is not None:
|
|
715
|
+
to_cif_file(atom_array, out_path, extra_fields=INFERENCE_ANNOTATIONS)
|
|
716
|
+
|
|
717
|
+
return atom_array
|