rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,1123 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
import warnings
|
|
7
|
+
from contextlib import contextmanager
|
|
8
|
+
from typing import Any, Dict, List, Optional, Union
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from atomworks.constants import STANDARD_AA
|
|
12
|
+
from atomworks.io.parser import parse_atom_array
|
|
13
|
+
|
|
14
|
+
# from atomworks.ml.datasets.datasets import BaseDataset
|
|
15
|
+
from atomworks.ml.transforms.base import TransformedDict
|
|
16
|
+
from atomworks.ml.utils.token import (
|
|
17
|
+
get_token_starts,
|
|
18
|
+
)
|
|
19
|
+
from biotite import structure as struc
|
|
20
|
+
from biotite.structure import AtomArray, BondList, get_residue_starts
|
|
21
|
+
from pydantic import (
|
|
22
|
+
BaseModel,
|
|
23
|
+
ConfigDict,
|
|
24
|
+
Field,
|
|
25
|
+
model_validator,
|
|
26
|
+
)
|
|
27
|
+
from rfd3.constants import (
|
|
28
|
+
INFERENCE_ANNOTATIONS,
|
|
29
|
+
REQUIRED_CONDITIONING_ANNOTATION_VALUES,
|
|
30
|
+
REQUIRED_INFERENCE_ANNOTATIONS,
|
|
31
|
+
)
|
|
32
|
+
from rfd3.inference.legacy_input_parsing import (
|
|
33
|
+
create_atom_array_from_design_specification_legacy,
|
|
34
|
+
)
|
|
35
|
+
from rfd3.inference.parsing import InputSelection
|
|
36
|
+
from rfd3.inference.symmetry.symmetry_utils import (
|
|
37
|
+
SymmetryConfig,
|
|
38
|
+
center_symmetric_src_atom_array,
|
|
39
|
+
make_symmetric_atom_array,
|
|
40
|
+
)
|
|
41
|
+
from rfd3.transforms.conditioning_base import (
|
|
42
|
+
check_has_required_conditioning_annotations,
|
|
43
|
+
convert_existing_annotations_to_bool,
|
|
44
|
+
get_motif_features,
|
|
45
|
+
set_default_conditioning_annotations,
|
|
46
|
+
)
|
|
47
|
+
from rfd3.transforms.util_transforms import assign_types_
|
|
48
|
+
from rfd3.utils.inference import (
|
|
49
|
+
_restore_bonds_for_nonstandard_residues,
|
|
50
|
+
extract_ligand_array,
|
|
51
|
+
inference_load_,
|
|
52
|
+
set_com,
|
|
53
|
+
set_common_annotations,
|
|
54
|
+
set_indices,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
from foundry.common import exists
|
|
58
|
+
from foundry.utils.components import (
|
|
59
|
+
get_design_pattern_with_constraints,
|
|
60
|
+
get_motif_components_and_breaks,
|
|
61
|
+
)
|
|
62
|
+
from foundry.utils.ddp import RankedLogger
|
|
63
|
+
|
|
64
|
+
logging.basicConfig(level=logging.DEBUG)
|
|
65
|
+
|
|
66
|
+
logger = RankedLogger(__name__, rank_zero_only=True)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
#################################################################################
|
|
70
|
+
# Custom infer_ori functions
|
|
71
|
+
#################################################################################
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class LegacySpecification(BaseModel):
    """Legacy specification for compatibility with legacy input parsing.

    A permissive pydantic shim (``extra="allow"``) that forwards every field
    it was constructed with straight into the legacy builder; it performs no
    validation of its own.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="allow",
    )

    def build(self, *args, **kwargs):
        """Build atom array using legacy input parsing.

        Returns:
            Tuple of (atom_array, spec_dict) where spec_dict is this model's
            ``model_dump()``. Positional/keyword arguments are accepted for
            signature compatibility with the non-legacy ``build`` but ignored.
        """
        atom_array = create_atom_array_from_design_specification_legacy(
            **self.model_dump(),
        )
        return atom_array, self.model_dump()

    def to_pipeline_input(self, example_id):
        """Build the atom array and wrap it as a pipeline input dict.

        NOTE(review): ``return_metadata=True`` is swallowed by ``build``'s
        ``**kwargs`` — ``build`` always returns (atom_array, spec_dict)
        regardless; confirm this is intentional.
        """
        atom_array, spec_dict = self.build(return_metadata=True)

        # ... Forward into the shared pipeline-input builder (defined later
        # in this module).
        data = prepare_pipeline_input_from_atom_array(atom_array)
        data["example_id"] = example_id

        # ... Wrap up with additional features: record the example id in the
        # spec's "extra" bucket so it survives into saved metadata.
        if "extra" not in spec_dict:
            spec_dict["extra"] = {}
        spec_dict["extra"]["example_id"] = example_id
        data["specification"] = spec_dict
        return data
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# ========================================================================
|
|
105
|
+
# Input specification
|
|
106
|
+
# ========================================================================
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class DesignInputSpecification(BaseModel):
|
|
110
|
+
"""Validated and parsed input specification before resolution."""
|
|
111
|
+
|
|
112
|
+
model_config = ConfigDict(
|
|
113
|
+
hide_input_in_errors=False,
|
|
114
|
+
arbitrary_types_allowed=True,
|
|
115
|
+
validate_assignment=False,
|
|
116
|
+
str_strip_whitespace=True,
|
|
117
|
+
str_min_length=1,
|
|
118
|
+
extra="forbid",
|
|
119
|
+
)
|
|
120
|
+
# fmt: off
|
|
121
|
+
# ========================================================================
|
|
122
|
+
# Data inputs, motif generation & selection
|
|
123
|
+
# ========================================================================
|
|
124
|
+
# Data inputs
|
|
125
|
+
atom_array_input: Optional[AtomArray] = Field(None, description="Loaded atom array", exclude=True)
|
|
126
|
+
input: Optional[str] = Field(None, description="Path to input PDB/CIF file")
|
|
127
|
+
# Motif selection from input file
|
|
128
|
+
contig: Optional[InputSelection] = Field(None, description="Contig specification string (e.g. 'A1-10,B1-5')")
|
|
129
|
+
unindex: Optional[InputSelection] = Field(None,
|
|
130
|
+
description="Unindexed components string (components must not overlap with contig). "\
|
|
131
|
+
"E.g. 'A15-20,B6-10' or dict. We recommend specifying")
|
|
132
|
+
# Extra args:
|
|
133
|
+
length: Optional[str] = Field(None, description="Length range as 'min-max' or int. Constrains length of contig if provided")
|
|
134
|
+
ligand: Optional[str] = Field(None, description="Ligand name or index to include in design.")
|
|
135
|
+
cif_parser_args: Optional[Dict[str, Any]] = Field(None, description="CIF parser arguments")
|
|
136
|
+
extra: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Extra metadata to include in output (useful for logging additional info in metadata)")
|
|
137
|
+
dialect: int = Field(2, description="RFdiffusion3 input dialect. 1: legacy, 2: release.")
|
|
138
|
+
|
|
139
|
+
# ========================================================================
|
|
140
|
+
# Conditioning
|
|
141
|
+
# ========================================================================
|
|
142
|
+
# Sequence and coordinate conditioning
|
|
143
|
+
select_fixed_atoms: Optional[InputSelection] = Field(None,
|
|
144
|
+
description='''Atoms to fix coordinates for. Examples:
|
|
145
|
+
- True (default when inputs provided): All atoms pulled from the input are fixed in 3d space
|
|
146
|
+
- False: All atoms pulled from the input are unfixed in 3d space
|
|
147
|
+
- ContigStr: Components to fix in 3d space, e.g. "A1-10,B1-3" fixes residues 1-10 in chain A and residues 1-3 in chain B.
|
|
148
|
+
- {"A1": "N,CA,C,O,CB,CG", "A2-10": "BKBN"} fixes backbone and CB for residues 1 and 2, and all atoms for residues 3-10 in chain A.
|
|
149
|
+
'''.replace('\t\t', '\t')
|
|
150
|
+
)
|
|
151
|
+
select_unfixed_sequence: Optional[InputSelection] = Field(None, description='''Components to unfix sequence for.
|
|
152
|
+
- True (default when inputs provided): All atoms from the input have fixed sequences by default.
|
|
153
|
+
- False: All atoms pulled from the input have diffused sequences by default.
|
|
154
|
+
- ContigStr: Components to unfix sequence for, e.g. "A5-10,B1-3" unfixes sequence for residues 5-10 in chain A and residues 1-3 in chain B.
|
|
155
|
+
- Dictionary: Allowed but not recommended.
|
|
156
|
+
NOTE: Excludes ligands (ligands / DNA always has fixed sequence).
|
|
157
|
+
'''.replace('\t\t', '\t')
|
|
158
|
+
)
|
|
159
|
+
# Assignments of conditioning annotations
|
|
160
|
+
# RASA accessibilty
|
|
161
|
+
select_buried: Optional[InputSelection] = Field(None, description="Selection of RASA buried conditioning")
|
|
162
|
+
select_partially_buried: Optional[InputSelection] = Field(None, description="Selection of RASA partially buried conditioning")
|
|
163
|
+
select_exposed: Optional[InputSelection] = Field(None, description="Selection of RASA exposed conditioning")
|
|
164
|
+
# Hotspots & Hbonds
|
|
165
|
+
select_hbond_acceptor: Optional[InputSelection] = Field(None, description="Atom-wise hydrogen bond acceptor")
|
|
166
|
+
select_hbond_donor: Optional[InputSelection] = Field(None, description="Atom-wise hydrogen bond donor")
|
|
167
|
+
select_hotspots: Optional[InputSelection] = Field(None, description="Atom-level or token-level hotspots for PPI")
|
|
168
|
+
redesign_motif_sidechains: Union[bool, str] = Field(False,
|
|
169
|
+
description="Perform fixed-backbone sequence design on when 'contig' is provided. Changes the default behaviour when not using `select_fixed_atoms`."
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
# ========================================================================
|
|
173
|
+
# Global conditioning & symmetry
|
|
174
|
+
# ========================================================================
|
|
175
|
+
# Symmetry
|
|
176
|
+
symmetry: Optional[SymmetryConfig] = Field(None, description="Symmetry specification, see docs/symmetry.md")
|
|
177
|
+
# Centering & COM guidance
|
|
178
|
+
ori_token: Optional[list[float]] = Field(None, description="Origin coordinates")
|
|
179
|
+
infer_ori_strategy: Optional[str] = Field(None, description="Strategy for inferring origin; `com` or `hotspots`")
|
|
180
|
+
# Additional global conditioning
|
|
181
|
+
plddt_enhanced: Optional[bool] = Field(True, description="Enable pLDDT enhancement")
|
|
182
|
+
is_non_loopy: Optional[bool] = Field(None, description="Non-loopy conditioning")
|
|
183
|
+
# Partial diffusion
|
|
184
|
+
partial_t: Optional[float] = Field(None, ge=0.0, description="Angstroms of noise to add for partial diffusion (None turns off partial diffusion), t <= 15 recommended.")
|
|
185
|
+
# fmt: on
|
|
186
|
+
|
|
187
|
+
# ========================================================================
|
|
188
|
+
# Properties
|
|
189
|
+
# ========================================================================
|
|
190
|
+
|
|
191
|
+
    @property
    def is_partial_diffusion(self) -> bool:
        """Whether partial diffusion is enabled (i.e. ``partial_t`` was set)."""
        return exists(self.partial_t)
|
|
195
|
+
|
|
196
|
+
# ========================================================================
|
|
197
|
+
# Loading / saving
|
|
198
|
+
# ========================================================================
|
|
199
|
+
|
|
200
|
+
@classmethod
|
|
201
|
+
def from_json(cls, path):
|
|
202
|
+
with open(path, "r") as f:
|
|
203
|
+
data = json.load(f)
|
|
204
|
+
return cls(**data)
|
|
205
|
+
|
|
206
|
+
@classmethod
|
|
207
|
+
def from_rfd3_out(cls, path: str):
|
|
208
|
+
"""Load from path to rfd3 outputs, either .cif, .cif.gz, .json or denoised / noisy trajectory files"""
|
|
209
|
+
path = path.replace(".cif.gz", ".cif").replace(".cif", ".json")
|
|
210
|
+
if not os.path.exists(path):
|
|
211
|
+
raise FileNotFoundError(f"Output file not found at {path}")
|
|
212
|
+
with open(path, "r") as f:
|
|
213
|
+
data = json.load(f)
|
|
214
|
+
if "input_specification" in data:
|
|
215
|
+
spec_args = data["input_specification"]
|
|
216
|
+
return cls(**spec_args)
|
|
217
|
+
else:
|
|
218
|
+
raise ValueError(f"No input specification found in json output: {path}")
|
|
219
|
+
|
|
220
|
+
    def get_dict_to_save(self, exclude_extra: bool = False) -> dict:
        """Return a JSON-serializable dict for saving (reproducible) outputs.

        Only non-default fields are emitted; the in-memory atom array is
        always excluded, and the "extra" bucket is excluded on request.
        """
        # Returns dictionary for saving (reproducible) outputs to json
        return self.model_dump(
            exclude_defaults=True,
            exclude={"atom_array_input"} | set({"extra"} if exclude_extra else {}),
        )
|
|
226
|
+
|
|
227
|
+
# ========================================================================
|
|
228
|
+
# Pre-Validation / canonicalization
|
|
229
|
+
# ========================================================================
|
|
230
|
+
|
|
231
|
+
@model_validator(mode="before")
|
|
232
|
+
@classmethod
|
|
233
|
+
def validate_input_schema(cls, data: dict) -> dict:
|
|
234
|
+
if not (
|
|
235
|
+
exists(data.get("input"))
|
|
236
|
+
or exists(data.get("contig"))
|
|
237
|
+
or exists(data.get("length"))
|
|
238
|
+
):
|
|
239
|
+
raise ValueError("Either 'input' or 'contig' / 'length' must be provided.")
|
|
240
|
+
|
|
241
|
+
# unused input check
|
|
242
|
+
if exists(data.get("input")) and not (
|
|
243
|
+
(
|
|
244
|
+
exists(data.get("contig"))
|
|
245
|
+
or exists(data.get("unindex"))
|
|
246
|
+
or exists(data.get("ligand"))
|
|
247
|
+
)
|
|
248
|
+
or exists(data.get("partial_t"))
|
|
249
|
+
):
|
|
250
|
+
raise ValueError("Input provided but unused in composition specification.")
|
|
251
|
+
|
|
252
|
+
if not exists(data.get("partial_t")):
|
|
253
|
+
# non-partial diffusion checks
|
|
254
|
+
if exists(data.get("unindex")) and not (
|
|
255
|
+
exists(data.get("contig")) or exists(data.get("length"))
|
|
256
|
+
):
|
|
257
|
+
raise ValueError(
|
|
258
|
+
"Unindex provided but neither a length nor contig was specified."
|
|
259
|
+
)
|
|
260
|
+
else:
|
|
261
|
+
# partial diffusion checks
|
|
262
|
+
if exists(data.get("length")):
|
|
263
|
+
raise ValueError(
|
|
264
|
+
"Length argument must not be provided during partial diffusion."
|
|
265
|
+
)
|
|
266
|
+
if not (exists(data.get("input")) or exists(data.get("atom_array_input"))):
|
|
267
|
+
raise ValueError(
|
|
268
|
+
"Partial diffusion requires input file or input atom array."
|
|
269
|
+
)
|
|
270
|
+
|
|
271
|
+
return data
|
|
272
|
+
|
|
273
|
+
@model_validator(mode="before")
|
|
274
|
+
@classmethod
|
|
275
|
+
def canonicalize(cls, data: dict) -> dict:
|
|
276
|
+
# Canonicalize length argument
|
|
277
|
+
data["length"] = str(data["length"]) if exists(data.get("length")) else None
|
|
278
|
+
|
|
279
|
+
# Normalize input to str
|
|
280
|
+
data["input"] = str(data["input"]) if exists(data.get("input")) else None
|
|
281
|
+
return data
|
|
282
|
+
|
|
283
|
+
    @model_validator(mode="before")
    @classmethod
    def load_input(cls, data: dict) -> dict:
        """Load the input structure (if any) and coerce all selection fields
        to ``InputSelection`` instances bound to that atom array.

        Runs before field parsing. Mutates and returns ``data``.
        """
        with validator_context("load_input"):
            # ... Find provided selections
            selections = [
                # Motif
                "contig",
                "unindex",
                # Aux
                "select_fixed_atoms",
                "select_unfixed_sequence",
                # Conditioning
                "select_buried",
                "select_partially_buried",
                "select_exposed",
                "select_hbond_acceptor",
                "select_hbond_donor",
                "select_hotspots",
            ]
            selections = [s for s in selections if s in data]

            # ... Early return if no input file provided / atom array input.
            # Selections are meaningless without a structure to select from.
            if not exists(data.get("input")) and not exists(
                data.get("atom_array_input")
            ):
                if selections:
                    raise ValueError(
                        "Atom array input must be provided before parsing selections: {}".format(
                            selections
                        )
                    )
                return data

            # ... Load atom array from input file if provided
            # NOTE(review): direct ``data["input"]`` indexing assumes the key
            # exists at this point (e.g. set to None by `canonicalize`) —
            # confirm validator ordering guarantees this.
            if exists(data["input"]):
                if exists(data.get("atom_array_input")):
                    raise ValueError(
                        "Both 'input' and 'atom_array_input' provided; please provide only one."
                    )
                atom_array = inference_load_(
                    data["input"], cif_parser_args=data.get("cif_parser_args")
                )["atom_array"]

                # Center for symmetric design
                if exists(data.get("symmetry")) and data["symmetry"].get("id"):
                    atom_array = center_symmetric_src_atom_array(atom_array)

                # Drop stale atom ids carried over from the input file.
                if "atom_id" in atom_array.get_annotation_categories():
                    atom_array.del_annotation("atom_id")

                data["atom_array_input"] = atom_array

            atom_array = data["atom_array_input"]

            # ... Set defaults if not provided: atoms pulled from an input are
            # fixed in space and have fixed sequence by default.
            if not exists(data.get("select_fixed_atoms")):
                data["select_fixed_atoms"] = InputSelection.from_any(
                    True, atom_array=atom_array
                )
            if not exists(data.get("select_unfixed_sequence")):
                data["select_unfixed_sequence"] = InputSelection.from_any(
                    False, atom_array=atom_array
                )

            # Coerce selections
            for sele in selections:
                # NOTE(review): "unindexed_breaks" never appears in the
                # `selections` list above (which uses "unindex"), so this
                # string-type guard only ever fires for "contig" — looks like
                # a stale field name; confirm whether "unindex" was intended.
                if sele in ["contig", "unindexed_breaks"]:
                    if exists(data[sele]) and not isinstance(data[sele], str):
                        raise ValueError(
                            f"{sele} selection must be a string or None, got {type(data[sele])} instead."
                        )
                if not isinstance(data.get(sele), InputSelection):
                    data[sele] = InputSelection.from_any(
                        data[sele], atom_array=atom_array
                    )
            return data
|
|
360
|
+
|
|
361
|
+
# ========================================================================
|
|
362
|
+
# Post-Validation
|
|
363
|
+
# ========================================================================
|
|
364
|
+
|
|
365
|
+
    @model_validator(mode="after")
    def assert_exclusivity(self):
        """Enforce mutual exclusivity between selection groups.

        Checks (1) that indexed (`contig`) and unindexed (`unindex`)
        components name disjoint keys, and (2) that within each exclusive
        group (motifs, RASA bins) no atom is selected more than once.
        """
        with validator_context("assert_exclusivity"):
            # ... Assert indexed and unindexed components do not overlap
            if exists(self.contig) and exists(self.unindex):
                indexed_set = set(self.contig.keys())
                unindexed_set = set(self.unindex.keys())
                overlap = indexed_set & unindexed_set
                if overlap:
                    raise ValueError(
                        f"Indexed and unindexed components must not overlap, got: {overlap}"
                    )

            # ... Assert mutual exclusivity within each named group
            exclusive_sets = [
                ("Motifs", ("contig", "unindex")),
                (
                    "RASA",
                    ("select_buried", "select_partially_buried", "select_exposed"),
                ),
            ]

            for name, excl_set in exclusive_sets:
                masks = [getattr(self, field, None) for field in excl_set]
                masks = [m.get_mask() for m in masks if m is not None]
                if not masks:
                    continue
                # Sum per-atom selection counts; any atom counted twice means
                # two selections in the group overlap.
                mask_sum = np.zeros_like(masks[0], dtype=int)
                for m in masks:
                    # NOTE(review): `m` can no longer be None here (filtered
                    # above), so this inner guard is redundant.
                    if m is not None:
                        mask_sum += m.astype(int)
                if np.any(mask_sum > 1):
                    raise ValueError(
                        f"Selections for `{name}` must be mutually exclusive, got overlapping selections: {excl_set}. Mask sum: {mask_sum}"
                    )

            return self
|
|
402
|
+
|
|
403
|
+
    @model_validator(mode="after")
    def attempt_expansion(self):
        """Dry-run the contig expansion so malformed contigs fail at
        validation time with a clear error rather than deep in the build.

        NOTE(review): the guard runs only when partial diffusion is active
        AND a contig is given — confirm the `is_partial_diffusion` condition
        is not inverted (the non-partial path expands the contig for real in
        `_build_init`, which would surface the same failure later).
        """
        if self.is_partial_diffusion and exists(self.contig):
            contig = self.contig
            length = self.length
            try:
                # Result is discarded; we only care whether expansion raises.
                get_design_pattern_with_constraints(contig.raw, length=length)
            except Exception as e:
                raise ValueError(f"Failed to expand contig ({contig.raw}): {e}")
        return self
|
|
413
|
+
|
|
414
|
+
    @model_validator(mode="after")
    def _assign_types_to_input(self):
        """Assign conditioning annotations to the input atom array.

        Translates each `select_*` / motif selection field into per-atom
        annotation values on `atom_array_input`, applying per-residue
        defaults first (e.g. sidechain redesign) and then overlaying the
        explicit selections. Mutates the atom array in place; returns self.
        """
        aa = self.atom_array_input
        if not exists(aa):
            # Nothing to annotate (pure de-novo composition).
            return self

        # ... Selections and their annotation values
        selection_fields = {
            # field name: (annotation name, assigned value, non-selected value)
            "select_fixed_atoms": ("is_motif_atom_with_fixed_coord", True, False),
            "select_unfixed_sequence": ("is_motif_atom_with_fixed_seq", False, True),
            "unindex": ("is_motif_atom_unindexed", True, False),
            "select_hotspots": ("is_atom_level_hotspot", True, False),
            "select_hbond_acceptor": ("active_acceptor", True, False),
            "select_hbond_donor": ("active_donor", True, False),
            # RASA bins share one annotation; bin 3 is the "unspecified" value.
            "select_buried": ("rasa_bin", 0, 3),
            "select_partially_buried": ("rasa_bin", 1, 3),
            "select_exposed": ("rasa_bin", 2, 3),
        }
        # Only process fields the user actually provided.
        selection_fields = {
            k: v for k, v in selection_fields.items() if exists(getattr(self, k, None))
        }

        # ... Init global annotation defaults over the whole array.
        [
            aa.set_annotation(name, np.full(aa.array_length(), val, dtype=int))
            for name, val in REQUIRED_CONDITIONING_ANNOTATION_VALUES.items()
        ]

        # Application of selections to each token fn;
        def apply_selections(start, end):
            # `start:end` spans one residue of `aa` (see loop below).
            chain_id = aa.chain_id[start]
            res_id = aa.res_id[start]

            # Assign all select fields to atom array annotations.
            for selection_name, (
                annotation_name,
                set_value,
                default_value,
            ) in selection_fields.items():
                # ... Get input values
                selection = getattr(self, selection_name)

                # Important line: selects from data dictionary based on src chain & res_id (Not name!)
                atom_names_sele = selection.get(f"{chain_id}{res_id}")

                if atom_names_sele is None:
                    # Residue not covered by this selection — leave as-is.
                    continue
                mask = np.isin(aa.atom_name[start:end], atom_names_sele)
                if annotation_name in aa.get_annotation_categories():
                    # ... Set only mask overridden features if exists in atom array
                    aa.get_annotation(annotation_name)[start:end] = np.where(
                        mask, set_value, default_value
                    ).astype(np.int_)
                else:
                    # ... Otherwise, set the entire annotation and use defaults for unselected
                    mask_aa = np.zeros(aa.array_length(), dtype=bool)
                    mask_aa[start:end] = mask
                    annotation_values = np.where(
                        mask_aa,
                        set_value,
                        default_value,
                    ).astype(np.int_)
                    aa.set_annotation(annotation_name, annotation_values)

        # ... Set default assignments per-token based on whether redesigning
        starts = get_residue_starts(aa, add_exclusive_stop=True)
        for start, end in zip(starts[:-1], starts[1:]):
            # ... Relax sequence and sidechains: keep only the backbone fixed
            # for standard amino acids when redesigning motif sidechains.
            if aa.res_name[start] in STANDARD_AA and self.redesign_motif_sidechains:
                is_bkbn = np.isin(aa.atom_name[start:end], ["N", "CA", "C", "O"])
                aa.is_motif_atom_with_fixed_coord[start:end] = is_bkbn.astype(int)
                aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
                    is_bkbn, False, dtype=int
                )

            # ... Apply selections on top (explicit selections win over defaults)
            apply_selections(start, end)

        return self
|
|
496
|
+
|
|
497
|
+
# ========================================================================
|
|
498
|
+
# Building
|
|
499
|
+
# ========================================================================
|
|
500
|
+
|
|
501
|
+
def build(self, return_metadata=False):
|
|
502
|
+
"""Main build pipeline."""
|
|
503
|
+
atom_array_input_annotated = copy.deepcopy(self.atom_array_input)
|
|
504
|
+
atom_array = self._build_init(atom_array_input_annotated)
|
|
505
|
+
|
|
506
|
+
# Apply post-processing
|
|
507
|
+
atom_array = self._append_ligand(atom_array, atom_array_input_annotated)
|
|
508
|
+
atom_array = self._apply_symmetry(atom_array, atom_array_input_annotated)
|
|
509
|
+
|
|
510
|
+
# Apply globals to all tokens (including diffused)
|
|
511
|
+
atom_array = self._set_origin(atom_array)
|
|
512
|
+
atom_array = self._apply_globals(atom_array)
|
|
513
|
+
|
|
514
|
+
# Final validation and cleanup
|
|
515
|
+
check_has_required_conditioning_annotations(
|
|
516
|
+
atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
|
|
517
|
+
)
|
|
518
|
+
convert_existing_annotations_to_bool(atom_array)
|
|
519
|
+
|
|
520
|
+
# ... Route return type
|
|
521
|
+
if not return_metadata:
|
|
522
|
+
return copy.deepcopy(atom_array)
|
|
523
|
+
else:
|
|
524
|
+
metadata = self.get_dict_to_save()
|
|
525
|
+
metadata["extra"] = metadata.get("extra", {}) | {
|
|
526
|
+
"num_tokens_in": len(get_token_starts(atom_array)),
|
|
527
|
+
"num_residues_in": len(get_residue_starts(atom_array)),
|
|
528
|
+
"num_chains": len(np.unique(atom_array.chain_id)),
|
|
529
|
+
"num_atoms": len(atom_array),
|
|
530
|
+
"num_residues": len(
|
|
531
|
+
np.unique(list(zip(atom_array.chain_id, atom_array.res_id)))
|
|
532
|
+
),
|
|
533
|
+
}
|
|
534
|
+
return copy.deepcopy(atom_array), metadata
|
|
535
|
+
|
|
536
|
+
# ============================================================================
|
|
537
|
+
# Building functions
|
|
538
|
+
# ============================================================================
|
|
539
|
+
|
|
540
|
+
def _build_init(self, atom_array_input_annotated):
    """Build the initial atom array from the design specification.

    Two regimes:
      * Regular diffusion: sample a contig design pattern and accumulate
        every component (indexed motifs, diffused residues, unindexed
        motifs) from scratch.
      * Partial diffusion: keep the annotated input array as-is (protein
        atoms only) and append only the unindexed components.

    Args:
        atom_array_input_annotated: Deep copy of the input atom array
            carrying conditioning annotations.

    Returns:
        The accumulated AtomArray for the design.
    """
    # ... Fetch tokens
    indexed_tokens = (
        self.contig.get_tokens(atom_array_input_annotated)
        if exists(self.contig)
        else {}
    )
    unindexed_tokens = (
        self.unindex.get_tokens(atom_array_input_annotated)
        if exists(self.unindex)
        else {}
    )
    # Subset to only fixed-coordinate atoms
    unindexed_tokens = {
        k: tok[tok.is_motif_atom_with_fixed_coord.astype(bool)]
        for k, tok in unindexed_tokens.items()
    }
    unindexed_components, unindexed_breaks = self.break_unindexed(self.unindex)

    if not self.is_partial_diffusion:
        # ... Sample the contig string
        components_to_accumulate = get_design_pattern_with_constraints(
            self.contig.raw if exists(self.contig) else self.length,
            length=self.length,
        )
        # Record the concrete sampled contig for reproducibility/metadata.
        self.extra["sampled_contig"] = ",".join(
            [str(x) for x in components_to_accumulate]
        )

        # ... Include unindexed components in accumulation
        # Sampled components carry no break info (None); breaks apply only
        # to the appended unindexed components.
        unindexed_breaks = [None] * len(components_to_accumulate) + unindexed_breaks
        components_to_accumulate += unindexed_components

        # ... Accumulate from scratch
        atom_array = accumulate_components(
            components_to_accumulate,
            indexed_tokens=indexed_tokens,
            unindexed_tokens=unindexed_tokens,
            atom_array_accum=[],
            unindexed_breaks=unindexed_breaks,
            start_chain="A",
            start_resid=1,
        )
    else:
        # ... Set common annotations
        atom_array_in = assign_types_(copy.deepcopy(atom_array_input_annotated))
        atom_array_in = set_common_annotations(
            atom_array_in, set_src_component_to_res_name=False
        )

        # ... Override motif annotations from pipeline
        zeros = np.zeros(atom_array_in.array_length(), dtype=int)
        atom_array_in.is_motif_atom_unindexed = (
            zeros  # reset unindexed annotation since those are copied already.
        )
        atom_array_in.is_motif_atom_with_fixed_coord = (
            self.select_fixed_atoms.get_mask().astype(int)
            if exists(self.select_fixed_atoms)
            else zeros
        )
        # NOTE: mask is inverted — "unfixed sequence" selection marks atoms
        # whose sequence is NOT fixed.
        atom_array_in.is_motif_atom_with_fixed_seq = (
            ~self.select_unfixed_sequence.get_mask()
            if exists(self.select_unfixed_sequence)
            else zeros
        ).astype(int)

        # ... Subset to residues only
        atom_array_in = atom_array_in[atom_array_in.is_protein]

        # ... Set chain ID for unindexed residues as whatever the input has
        start_resid = np.max(atom_array_in.res_id) + 1
        start_chain = atom_array_in.chain_id[0]

        # ... Accumulate from input
        components_to_accumulate = unindexed_components
        atom_array = accumulate_components(
            # No accumulation of components
            components_to_accumulate=components_to_accumulate,
            indexed_tokens={},
            # Append all inputs to unindexed tokens
            unindexed_tokens=unindexed_tokens,
            atom_array_accum=[atom_array_in],
            start_chain=start_chain,
            start_resid=start_resid,
            unindexed_breaks=unindexed_breaks,
        )

    return atom_array
|
|
628
|
+
|
|
629
|
+
# ============================================================================
|
|
630
|
+
# Auxiliary functions
|
|
631
|
+
# ============================================================================
|
|
632
|
+
|
|
633
|
+
@staticmethod
def break_unindexed(unindex: InputSelection):
    """Split the unindexed selection into contiguous components and breaks.

    Args:
        unindex: The unindexed input selection; may be None/absent.

    Returns:
        Tuple ``(unindexed_components, breaks)`` — both empty lists when
        no unindexed selection is provided.
    """
    if not exists(unindex):
        return [], []

    # ... If original type was string, use that
    if isinstance(unindex.raw, str):
        unindexed_string = unindex.raw
    elif isinstance(unindex.raw, dict):
        unindexed_string = ",".join(unindex.raw.keys())
    else:
        logger.info(
            "`Unindex` provided as non-string, separate keys in dictionary will be considered separate contiguous components"
        )
        # BUGFIX: previously called `unindex.keys()`, but `unindex` is an
        # InputSelection, not a mapping — the mapping-like payload lives in
        # `unindex.raw` (consistent with the dict branch above).
        unindexed_string = ",".join(unindex.raw.keys())

    # ... Break expected unindexed contig string
    unindexed_components, breaks = get_motif_components_and_breaks(unindexed_string)

    return unindexed_components, breaks
|
|
653
|
+
|
|
654
|
+
# ============================================================================
|
|
655
|
+
# Setter functions
|
|
656
|
+
# ============================================================================
|
|
657
|
+
|
|
658
|
+
def _append_ligand(self, atom_array, atom_array_input_annotated):
    """Append ligand if specified."""
    if not exists(self.ligand):
        return atom_array

    # Carry over every annotation category present on either array so the
    # concatenation below has a consistent annotation set.
    annotation_union = set(
        list(atom_array.get_annotation_categories())
        + list(atom_array_input_annotated.get_annotation_categories())
    )
    ligand_array = extract_ligand_array(
        atom_array_input_annotated,
        self.ligand,
        fixed_atoms={},
        set_defaults=False,
        additional_annotations=annotation_union,
    )

    # Offset ligand residue ids based on the original input to avoid clashes
    # with any newly created residues (matches legacy behaviour).
    res_id_offset = np.max(atom_array.res_id) + 1 - np.min(ligand_array.res_id)
    ligand_array.res_id = ligand_array.res_id + res_id_offset

    return atom_array + ligand_array
|
|
681
|
+
|
|
682
|
+
def _apply_symmetry(self, atom_array, atom_array_input_annotated):
    """Apply symmetry transformation if specified."""
    # No-op unless a symmetry with a non-empty id was requested.
    if not (exists(self.symmetry) and self.symmetry.id):
        return atom_array

    return make_symmetric_atom_array(
        atom_array,
        self.symmetry,
        sm=self.ligand,
        src_atom_array=atom_array_input_annotated,
    )
|
|
692
|
+
|
|
693
|
+
def _set_origin(self, atom_array):
    """Set origin token and initialize coordinates."""
    if not self.is_partial_diffusion:
        # Standard: set ori token, zero out diffused atoms
        atom_array = set_com(
            atom_array,
            ori_token=self.ori_token,
            infer_ori_strategy=self.infer_ori_strategy,
        )
        # Diffused atoms are always initialized at origin during regular diffusion (all information removed)
        diffused_mask = ~atom_array.is_motif_atom_with_fixed_coord.astype(bool)
        atom_array.coord[diffused_mask] = 0.0
        return atom_array

    # Partial diffusion: use COM, keep all coordinates
    if exists(self.symmetry) and self.symmetry.id:
        # For symmetric structures, avoid COM centering that would collapse chains
        ranked_logger.info(
            "Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
        )
        return atom_array

    return set_com(atom_array, ori_token=None, infer_ori_strategy="com")
|
|
718
|
+
|
|
719
|
+
def _apply_globals(self, atom_array):
    """Apply example-wide conditioning annotations to all atoms.

    Sets the loopiness conditioning (``is_non_loopy``), optional pLDDT
    enhancement flag, and the partial-diffusion time annotation.

    Args:
        atom_array: The accumulated AtomArray.

    Returns:
        The same AtomArray with global annotations attached.
    """
    n_atoms = atom_array.array_length()

    # Temperature conditioning
    if exists(self.is_non_loopy):
        # NOTE: the original code re-checked `exists(self.is_non_loopy)`
        # inside this branch — that inner check was redundant and removed.
        is_non_loopy_annot = np.zeros(n_atoms, dtype=int)
        is_motif_token = get_motif_features(atom_array)["is_motif_token"]
        diffused_region_mask = ~(is_motif_token.astype(bool))
        # Only the diffused region is tagged: +1 requests non-loopy
        # structure, -1 explicitly requests loopy structure.
        is_non_loopy_annot[diffused_region_mask] = 1 if self.is_non_loopy else -1
        atom_array.set_annotation("is_non_loopy", is_non_loopy_annot)
        atom_array.set_annotation("is_non_loopy_atom_level", is_non_loopy_annot)
    else:
        zeros = np.zeros(n_atoms, dtype=int)
        atom_array.set_annotation("is_non_loopy", zeros)
        atom_array.set_annotation("is_non_loopy_atom_level", zeros)

    if self.plddt_enhanced:
        atom_array.set_annotation(
            "ref_plddt", np.full((n_atoms,), True, dtype=int)
        )

    # Partial diffusion time annotation
    if self.is_partial_diffusion:
        atom_array.set_annotation(
            "partial_t", np.full(n_atoms, self.partial_t, dtype=float)
        )
    return atom_array
|
|
747
|
+
|
|
748
|
+
@classmethod
def safe_init(cls, **spec_kwargs):
    """Construct a specification, falling back to the legacy schema.

    Dispatches on the ``dialect`` key (default 2): dialect < 2 builds a
    deprecated :class:`LegacySpecification`; otherwise the modern class.

    Args:
        **spec_kwargs: Keyword arguments forwarded to the chosen
            specification constructor.

    Returns:
        A LegacySpecification (dialect < 2) or ``cls`` instance.
    """
    # BUGFIX/consistency: cast to int before comparing, matching
    # create_atom_array_from_design_specification — a string-valued
    # dialect (e.g. from CLI/YAML) would otherwise raise TypeError.
    if int(spec_kwargs.get("dialect", 2)) < 2:
        warn = (
            "Using dialect==1, which is deprecated and will be removed in future releases. "
            "Please update your input specification to dialect=2 and use the new schema if possible"
        )
        warnings.warn(warn, DeprecationWarning)
        logger.warning(warn)
        return LegacySpecification(**spec_kwargs)
    else:
        return cls(**spec_kwargs)
|
|
760
|
+
|
|
761
|
+
def to_pipeline_input(self, example_id):
    """Build the design and package it as a pipeline input dict."""
    built_array, spec_dict = self.build(return_metadata=True)

    # ... Forward into
    pipeline_data = prepare_pipeline_input_from_atom_array(built_array)
    pipeline_data["example_id"] = example_id

    # ... Wrap up with additional features
    spec_dict.setdefault("extra", {})["example_id"] = example_id
    pipeline_data["specification"] = spec_dict
    return pipeline_data
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
# ============================================================================
|
|
777
|
+
# APIs and utils
|
|
778
|
+
# ============================================================================
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
def prepare_pipeline_input_from_atom_array(  # see atomworks.ml.datasets.parsers.base.load_example_from_metadata_row
    atom_array_orig,
) -> dict:
    """
    Load or create an example from a metadata dictionary.

    If the file path is not provided in the metadata dictionary, create a spoofed CIF file based on the length.
    Args:
        atom_array_orig: Atom array instantiated with conditioning annotations

    Returns:
        dict: A dictionary containing the parsed row data and additional loaded CIF data.
    """
    # HACK: Set empty bond graph:
    if atom_array_orig.bonds is None:
        atom_array_orig.bonds = BondList(atom_array_orig.array_length())

    # Temporary spoof of chain IDs to ensure duplicates aren't dropped:
    result_dict = parse_atom_array(
        atom_array_orig,
        remove_ccds=[],
        fix_arginines=False,
        add_missing_atoms=False,
        extra_fields=INFERENCE_ANNOTATIONS,
        build_assembly=None,
        hydrogen_policy="remove",
    )
    atom_array = result_dict["asym_unit"][0]

    # HACK: Set iid information manually
    # We currently do not preserve this information from the input,
    # if you want these we'd need to remove the spoofing here
    check_has_required_conditioning_annotations(
        atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
    )
    atom_array = convert_existing_annotations_to_bool(atom_array)
    atom_array.set_annotation("chain_iid", [f"{c}_1" for c in atom_array.chain_id])
    atom_array.set_annotation("pn_unit_iid", [f"{c}_1" for c in atom_array.pn_unit_id])

    # Ensure motif annotations are removed.
    # (Was a pair of `expr if cond else None` statements — rewritten as
    # plain conditionals; also dropped unused _start/_stop_parse_time locals.)
    for annotation in ("is_motif_token", "is_motif_atom"):
        if annotation in atom_array.get_annotation_categories():
            atom_array.del_annotation(annotation)

    data = {
        "atom_array": atom_array,  # First model
        "chain_info": result_dict["chain_info"],
        "ligand_info": result_dict["ligand_info"],
        "metadata": result_dict["metadata"],
    }
    return TransformedDict(data)
|
|
837
|
+
|
|
838
|
+
|
|
839
|
+
def create_atom_array_from_design_specification(
    **spec_kwargs,
) -> tuple[AtomArray, dict]:
    """Build an atom array (and metadata) from a design specification.

    Dispatches on the ``dialect`` key: dialect < 2 uses the deprecated
    legacy builder (with empty metadata); otherwise the modern
    DesignInputSpecification pipeline.
    """
    dialect = int(spec_kwargs.get("dialect", 2))
    if dialect < 2:
        warn = (
            "Using dialect==1, which is deprecated and will be removed in future releases. "
            "Please update your input specification to dialect=2 and use the new schema if possible"
        )
        warnings.warn(warn, DeprecationWarning)
        logger.warning(warn)
        legacy_array = create_atom_array_from_design_specification_legacy(**spec_kwargs)
        return legacy_array, {}

    # Create input specfication and build
    specification = DesignInputSpecification(**spec_kwargs)
    return specification.build(return_metadata=True)
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
@contextmanager
def validator_context(validator_name: str, data: Optional[dict] = None):
    """Context manager for validator execution with logging.

    Args:
        validator_name: Human-readable name used in the log messages.
        data: Unused; retained for backward compatibility with callers
            that pass it positionally or by keyword.

    Raises:
        Re-raises any exception from the managed block after logging it.
    """
    logger.debug(f"Starting validator: {validator_name}")
    try:
        yield
        logger.debug(f"✓ Completed validator: {validator_name}")
    except Exception as e:
        logger.error(
            f"✗ Failed in validator: {validator_name}\n"
            f"  Error: {str(e)}\n"
            f"  Error type: {type(e).__name__}"
        )
        # Bare `raise` preserves the original traceback without adding
        # an extra frame (the original used `raise e`).
        raise
|
|
872
|
+
|
|
873
|
+
|
|
874
|
+
def create_diffused_residues(n, additional_annotations=None):
    """Create ``n`` placeholder ALA residues at the origin for diffusion.

    Each residue consists of five atoms (N, CA, C, O, CB) with zeroed
    coordinates; default conditioning and common annotations are applied.

    Args:
        n: Number of residues to create; must be positive.
        additional_annotations: Extra annotation categories to initialize
            with defaults (passed through to the annotation setter).

    Returns:
        An AtomArray of ``5 * n`` atoms.

    Raises:
        ValueError: If ``n`` is zero or negative.
    """
    if n <= 0:
        raise ValueError(f"Negative/null residue count ({n}) not allowed.")

    # Five placeholder atoms per residue, all at the origin.
    # (Was a list comprehension used purely for its .extend() side effect —
    # rewritten as a genuine comprehension preserving atom order.)
    atoms = [
        struc.Atom(
            np.array([0.0, 0.0, 0.0], dtype=np.float32),
            res_name="ALA",
            res_id=idx,
        )
        for idx in range(1, n + 1)
        for _ in range(5)
    ]
    array = struc.array(atoms)
    array.set_annotation(
        "element", np.array(["N", "C", "C", "O", "C"] * n, dtype="<U2")
    )
    array.set_annotation(
        "atom_name", np.array(["N", "CA", "C", "O", "CB"] * n, dtype="<U2")
    )
    array = set_default_conditioning_annotations(
        array, motif=False, additional=additional_annotations
    )
    array = set_common_annotations(array)
    return array
|
|
904
|
+
|
|
905
|
+
|
|
906
|
+
def create_motif_residue(
    token,
    strip_sidechains_by_default: bool,
):
    """Normalize a single motif residue token before accumulation.

    When sidechain stripping is enabled for standard amino acids, the token
    is reduced to its backbone (N, CA, C, O), an idealized CB is added, and
    the residue is renamed to ALA so padded atoms diffuse from the CB
    without leaking the original identity.

    Args:
        token: AtomArray of one motif residue.
        strip_sidechains_by_default: Whether to strip sidechains of
            standard amino acids down to backbone + idealized CB.

    Returns:
        The normalized token with common annotations set and res_id reset.

    Raises:
        ValueError: If fewer than 3 atoms (N, CA, C) are available.
    """
    if strip_sidechains_by_default and token.res_name in STANDARD_AA:
        n_atoms = token.shape[0]
        diffuse_oxygen = False
        if n_atoms < 3:
            # BUGFIX: original message interpolated undefined names
            # (src_chain/src_resid), raising NameError instead of the
            # intended ValueError. Use the token's own identifiers.
            raise ValueError(
                f"Not enough data for {token.chain_id[0]}{token.res_id[0]} in input atom array."
            )
        if n_atoms == 3:
            # Handle cases with N, CA, C only;
            token = token + create_o_atoms(token.copy())
            diffuse_oxygen = True  # flag oxygen for generation

        # Subset to the first 4 atoms (N, CA, C, O) only
        token = token[np.isin(token.atom_name, ["N", "CA", "C", "O"])]

        # exactly N, CA, C, O but no CB. Place CB onto idealized position and conver to ALA
        # Sequence name ALA ensures the padded atoms to be diffused from the fixed backbone
        # are placed on the CB so as to not leak the identity of the residue.
        token = token + create_cb_atoms(token.copy())

        # Sequence name must be set to ALA such that the central atom is correctly CB
        token.res_name = np.full_like(token.res_name, "ALA", dtype=token.res_name.dtype)
        # Backbone atoms keep their fixed-coord flag; the appended CB (and a
        # synthesized O, when diffuse_oxygen) are forced to diffused (0).
        token.set_annotation(
            "is_motif_atom_with_fixed_coord",
            np.where(
                np.arange(token.shape[0], dtype=int) < (4 - int(diffuse_oxygen)),
                token.is_motif_atom_with_fixed_coord,
                0,
            ),
        )

    check_has_required_conditioning_annotations(token)
    token = set_common_annotations(token)
    token.set_annotation("res_id", np.full(token.shape[0], 1))  # Reset to 1

    return token
|
|
946
|
+
|
|
947
|
+
|
|
948
|
+
def accumulate_components(
    components_to_accumulate: List[Union[str, int]],
    *,
    # Tokens from input
    indexed_tokens: Dict[str, AtomArray],
    unindexed_tokens: Dict[str, AtomArray],
    # Additional parameters
    # BUGFIX: defaults were mutable ([]); atom_array_accum is appended to
    # below, so the shared default list leaked state across calls, and the
    # [] default for unindexed_breaks bypassed its own `is None` guard and
    # tripped the length assert. None defaults restore the intended behavior.
    atom_array_accum=None,
    start_chain: str = "A",
    start_resid: int = 1,
    unindexed_breaks: Optional[List[bool]] = None,
    src_atom_array: Optional[AtomArray] = None,
    strip_sidechains_by_default: bool = False,
    **kwargs,
) -> AtomArray:
    """Accumulate design components into a single AtomArray.

    Walks the (possibly sampled) contig component list, materializing
    diffused residue stretches and motif residues in order, then
    concatenates everything, restores nonstandard-residue backbone bonds
    when needed, and de-duplicates res_ids of unindexed residues.

    Args:
        components_to_accumulate: Ordered components — motif ids (e.g.
            "A22"), integer diffused-stretch lengths, or "/0" chain breaks.
        indexed_tokens: Motif tokens placed at contig positions.
        unindexed_tokens: Motif tokens appended without contig indexing.
        atom_array_accum: Optional pre-seeded list of AtomArrays to
            accumulate onto (e.g. the input array for partial diffusion).
        start_chain: Chain id for the first accumulated component.
        start_resid: Residue id for the first accumulated component.
        unindexed_breaks: Per-component break flags (None entries for
            non-breaking components); must match components in length.
        src_atom_array: Source array used to map motif atoms back to their
            original indices for bond restoration.
        strip_sidechains_by_default: Forwarded to create_motif_residue.

    Returns:
        The concatenated AtomArray with a bond list attached.
    """
    # ... Create list of components
    assert (
        x := (set(list(indexed_tokens.keys()) + list(unindexed_tokens.keys())))
    ).issubset(
        (y := set(components_to_accumulate))
    ), "Unindexed and indexed set {} is not subset of components to accumulate {}".format(
        x, y
    )
    all_tokens = indexed_tokens | unindexed_tokens
    # Union of every annotation category present on any token
    # (was a side-effect list comprehension).
    all_annots = set()
    for tok in all_tokens.values():
        all_annots.update(tok.get_annotation_categories())
    atom_array_accum = [] if atom_array_accum is None else atom_array_accum
    unindexed_breaks = (
        [None] * len(components_to_accumulate)
        if unindexed_breaks is None
        else unindexed_breaks
    )

    # ... For-loop accum variables
    unindexed_components_started = (
        False  # once one unindexed component is added, stop adding diffused residues
    )
    chain = start_chain
    res_id = start_resid
    molecule_id = 0
    source_to_accum_idx: Dict[int, int] = {}
    current_accum_idx = sum(len(arr) for arr in atom_array_accum)

    # ... Insert contig information one- by one-
    assert len(components_to_accumulate) == len(
        unindexed_breaks
    ), "Mismatch in number of components to accumulate and breaks"
    for component, is_break in zip(components_to_accumulate, unindexed_breaks):
        src_indices = None
        if exists(is_break) and is_break:
            # First unindexed component resets chain/res_id bookkeeping.
            if not unindexed_components_started:
                chain = start_chain
                res_id = start_resid
            unindexed_components_started = True

        if component == "/0":
            # Reset iterators on next chain
            chain = chr(ord(chain) + 1)
            molecule_id += 1
            res_id = 1
            continue

        # ... Create array to insert
        if str(component)[0].isalpha():  # motif (e.g. "A22")
            n = 1

            # ... Fetch the motif residue
            token = all_tokens[component]
            if src_atom_array is not None:
                src_mask = fetch_mask_from_idx(component, atom_array=src_atom_array)
                src_indices = np.where(src_mask)[0]

            # ... Ensure motif residues are set properly
            token = create_motif_residue(
                token, strip_sidechains_by_default=strip_sidechains_by_default
            )

            # ... Insert breakpoint when break clause is met
            if exists(is_break) and is_break:
                token.set_annotation(
                    "is_motif_atom_unindexed_motif_breakpoint",
                    np.ones(token.shape[0], dtype=int),
                )
            else:
                token.set_annotation(
                    "is_motif_atom_unindexed_motif_breakpoint",
                    np.zeros(token.shape[0], dtype=int),
                )
        else:
            n = int(component)
            # ... Skip if none or unindexed
            if n == 0 or unindexed_components_started:
                res_id += n
                continue

            # ... Create diffused residues
            token = create_diffused_residues(n, all_annots)

        # ... Set index of insertion
        token = set_indices(
            array=token,
            chain=chain,
            res_id_start=res_id,
            molecule_id=molecule_id,
            component=component,
        )

        assert (
            len(get_token_starts(token)) == n
        ), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(token))} in \n{token}"

        # Record source->accumulated index mapping for bond restoration.
        if (
            src_atom_array is not None
            and str(component)[0].isalpha()
            and src_indices is not None
            and len(src_indices) == len(token)
        ):
            for i, src_idx in enumerate(src_indices):
                source_to_accum_idx[int(src_idx)] = current_accum_idx + i

        # ... Insert & Increment residue ID
        atom_array_accum.append(token)
        res_id += n
        current_accum_idx += len(token)

    # ... Concatenate all components
    atom_array_accum = struc.concatenate(atom_array_accum)
    atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)

    should_restore_bonds = (
        src_atom_array is not None
        and bool(source_to_accum_idx)
        and _check_has_backbone_connections_to_nonstandard_residues(
            atom_array_accum, src_atom_array
        )
    )
    if should_restore_bonds:
        assert not unindexed_tokens, (
            "PTM backbone bond restoration is not compatible with unindexed components. "
            "PTMs must be specified as indexed components (using 'contig' parameter, not 'unindex'). "
            f"Found unindexed components: {list(unindexed_tokens.keys())}"
        )
        atom_array_accum = _restore_bonds_for_nonstandard_residues(
            atom_array_accum, src_atom_array, source_to_accum_idx
        )

    # Reset res_id for unindexed residues to avoid duplicates (ridiculously long lines of code, cleanup later)
    if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all(
        atom_array_accum.is_motif_atom_unindexed.astype(bool)
    ):
        max_id = np.max(
            atom_array_accum[
                ~atom_array_accum.is_motif_atom_unindexed.astype(bool)
            ].res_id
        )
        min_id_udx = np.min(
            atom_array_accum[
                atom_array_accum.is_motif_atom_unindexed.astype(bool)
            ].res_id
        )
        atom_array_accum.res_id[
            atom_array_accum.is_motif_atom_unindexed.astype(bool)
        ] += max_id - min_id_udx + 1

    # ... Bonds
    if atom_array_accum.bonds is None:
        atom_array_accum.bonds = BondList(atom_array_accum.array_length())
    return atom_array_accum
|