rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional, Union
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from biotite.structure import AtomArray, get_residue_starts
|
|
5
|
+
from pydantic import (
|
|
6
|
+
BaseModel,
|
|
7
|
+
ConfigDict,
|
|
8
|
+
Field,
|
|
9
|
+
model_serializer,
|
|
10
|
+
model_validator,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from foundry.utils.components import (
|
|
14
|
+
ComponentStr,
|
|
15
|
+
fetch_mask_from_idx,
|
|
16
|
+
get_name_mask,
|
|
17
|
+
split_contig,
|
|
18
|
+
unravel_components,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# ============================================================================
|
|
22
|
+
# Input Specification & Validation
|
|
23
|
+
# ============================================================================
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class InputSelection(BaseModel):
    """Validated selection of atoms from an ``AtomArray``.

    Wraps a user-supplied selection (a contig/component string, a bool, or a
    dict of component -> atom names) together with the canonicalized
    per-component atom names (``data``) and a boolean mask over the atom
    array (``mask``). Only ``raw`` is emitted when the model is serialized.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        str_strip_whitespace=True,
        str_min_length=1,
    )
    data: Dict[ComponentStr | str, List[str]] = Field(
        ..., description="Validated selection dictionary", exclude=True
    )
    raw: Any = Field(..., description="Original input value")
    mask: np.ndarray[np.bool_] = Field(
        ..., description="Boolean mask over atom array", exclude=True
    )
    tokens: Optional[Dict[ComponentStr | str, AtomArray]] = Field(
        ..., description="Selected atom arrays per component", exclude=True
    )

    @classmethod
    def from_any(
        cls, v: Union[str, bool, dict, None], atom_array: AtomArray
    ) -> Optional["InputSelection"]:
        """Create InputSelection from various input types.

        Args:
            v: Raw selection (str, bool or dict); ``None`` means no selection.
            atom_array: Atom array the selection is resolved against.

        Returns:
            The validated selection, or ``None`` when ``v`` is ``None``.
        """
        if v is None:
            return None
        data, mask, _ = from_any_(v=v, atom_array=atom_array)
        return cls(
            raw=v,
            data=data,
            mask=mask,
            tokens=None,
        )

    @model_validator(mode="after")
    def check_keys(self):
        """Lightweight validation that all keys have contig format.

        Every key must be accepted (truthy result) by ``split_contig``.

        Raises:
            ValueError: If some key is not splittable. Pydantic converts this
                into a ``ValidationError``, exactly as it does for an
                ``AssertionError``; unlike ``assert``, the check is not
                stripped when Python runs with ``-O``.
        """
        # Evaluate every key (like the original list comprehension) so any
        # exception raised by split_contig surfaces unchanged.
        results = [(k, split_contig(k)) for k in self.data.keys()]
        invalid = [k for k, parts in results if not parts]
        if invalid:
            raise ValueError(f"Selection keys are not in contig format: {invalid}")
        return self

    # Wrapper functionality as dict-like
    def __getitem__(self, key: str) -> List[str]:
        """Allow dict-like access."""
        return self.data[key]

    def items(self):
        return self.data.items()

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def get(self, *args, **kwargs):
        return self.data.get(*args, **kwargs)

    # Serialization & repr
    def __repr__(self) -> str:
        num_atoms = self.mask.sum() if hasattr(self.mask, "sum") else 0
        num_tokens = len(self.tokens) if self.tokens else 0
        return (
            f"InputSelection(raw={self.raw!r}, atoms={num_atoms}, tokens={num_tokens})"
        )

    @model_serializer
    def serialize_raw(self) -> Any:
        # Serialize back to the user's original input, not the derived fields.
        return self.raw

    # Listed as separate methods for future changes to parsing.
    def get_mask(self):
        return self.mask

    def get_tokens(self, aa):
        """Re-resolve ``raw`` against ``aa`` and return the per-component atoms."""
        _, _, tokens = from_any_(v=self.raw, atom_array=aa)
        return tokens
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def from_any_(v: Any, atom_array: AtomArray):
    """Resolve a raw selection into atom names, an atom mask and per-component atoms.

    Args:
        v: Raw selection accepted by ``canonicalize_`` (str, bool or dict).
        atom_array: Atom array the selection is resolved against.

    Returns:
        A tuple ``(data_split, mask, tokens)``:
            data_split: component id -> list of selected atom names.
            mask: boolean array of length ``len(atom_array)`` marking every
                selected atom.
            tokens: component id -> selected atoms of that component.
    """
    # Canonicalize to {component: atom-name spec}, e.g. "ALL" or "N,CA"
    data_norm = canonicalize_(v, atom_array)

    data_split = {}
    # np.zeros keeps a bool dtype even for an empty atom array; the previous
    # np.array([False] * n) produced a float64 array when n == 0.
    mask = np.zeros(len(atom_array), dtype=bool)
    tokens = {}
    for idx, atm_names in data_norm.items():
        # Atoms belonging to this component
        comp_mask = fetch_mask_from_idx(idx, atom_array=atom_array)
        token = atom_array[comp_mask]

        # Mask (within the component) of the atoms selected by atm_names
        comp_mask_subset = get_name_mask(
            token.atom_name, atm_names, token.res_name[0]
        )  # [N_atoms_in_token,]

        # Concrete atom names for this component
        data_split[idx] = token.atom_name[comp_mask_subset].tolist()

        # Update mask & token dictionary
        mask[comp_mask] = comp_mask_subset
        tokens[idx] = token[comp_mask_subset]

    return (data_split, mask, tokens)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def canonicalize_(v, atom_array: AtomArray):
    """Canonicalize a raw selection into a ``{component: atom-name spec}`` dict.

    Examples::

        "A11-12" -> {"A11": "ALL", "A12": "ALL"}
        True     -> {"A1": "ALL", "A2": "ALL", ...}   (every residue)
        False    -> {"A1": "", "A2": "", ...}         (every residue, no atoms)
        "LIG"    -> {"B1": "ALL", "C1": "ALL"}        (two ligands named LIG)

    For dict inputs each key is expanded with ``unravel_components`` and every
    resulting component gets the key's value; list values are joined into a
    comma-separated string, other values are stored as-is.

    Args:
        v: A component/contig string, a bool, or a dict.
        atom_array: Atom array used to expand components and residues.

    Returns:
        Dict mapping component ids to atom-name specification strings.

    Raises:
        ValueError: If ``v`` is not a str, bool or dict.
    """
    data = {}
    if isinstance(v, str):
        for component in unravel_components(
            v, atom_array=atom_array, allow_multiple_matches=True
        ):
            # Filter on valid chain IDs; slicing avoids an IndexError on an
            # empty component string.
            if isinstance(component, str) and component[:1].isalpha():
                data[component] = "ALL"

    elif isinstance(v, bool):
        starts = get_residue_starts(atom_array, add_exclusive_stop=True)
        for start, stop in zip(starts[:-1], starts[1:]):
            token = atom_array[start:stop]
            # All atoms selected for every token or None
            data[f"{token.chain_id[0]}{token.res_id[0]}"] = "ALL" if v else ""

    elif isinstance(v, dict):
        # Ensure all values of the dictionary are strings
        for k, vv in v.items():
            for component in unravel_components(
                k, atom_array=atom_array, allow_multiple_matches=True
            ):
                if isinstance(vv, list):
                    data[component] = ",".join(vv)
                else:
                    data[component] = vv
    else:
        raise ValueError(f"Cannot convert {type(v)} to InputSelection")

    return data
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from rfd3.inference.symmetry.frames import (
|
|
3
|
+
decompose_symmetry_frame,
|
|
4
|
+
get_symmetry_frames_from_symmetry_id,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
from foundry.utils.ddp import RankedLogger
|
|
8
|
+
|
|
9
|
+
# Sentinel transform id assigned to fixed (non-symmetrized) motif atoms.
FIXED_TRANSFORM_ID: int = -1
# Sentinel entity id assigned to fixed (non-symmetrized) motif atoms.
FIXED_ENTITY_ID: int = -1
# Module logger, configured with rank_zero_only=True.
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
########################################################
|
|
15
|
+
# Symmetry annotations
|
|
16
|
+
########################################################
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def add_sym_annotations(atom_array, sym_conf):
    """
    Add symmetry base annotations to an atom array.

    Every atom is flagged as part of the asymmetric unit (``is_sym_asu``) and
    tagged with the configured symmetry id (``symmetry_id``).

    Arguments:
        atom_array: atom array of symmetry subunit
        sym_conf: symmetry configuration (dict, "id" key is required)
    Returns:
        atom_array: the same atom array, annotated in place
    """
    n = atom_array.shape[0]
    # which is the asymmetric unit? At this point, we annotate everything as the asu
    is_asu = np.full(n, True, dtype=np.bool_)
    atom_array.set_annotation("is_sym_asu", is_asu)

    # symmetry_id: size the string dtype to the id so that ids longer than six
    # characters are not silently truncated (a fixed "U6" would do that).
    # Ids of up to six characters keep the previous "U6" dtype.
    sym_id = sym_conf.get("id")
    width = max(6, len(str(sym_id)))
    symmetry_ids = np.full(n, sym_id, dtype=f"U{width}")
    atom_array.set_annotation("symmetry_id", symmetry_ids)
    return atom_array
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def add_sym_annotations_to_fixed_motif(atom_array):
    """
    Annotate a motif atom array as fixed, i.e. not belonging to any symmetry copy.

    All atoms receive the decomposition of the identity frame, the sentinel
    FIXED_TRANSFORM_ID / FIXED_ENTITY_ID ids, and is_sym_asu=False.

    Arguments:
        atom_array: atom array of the motif
    Returns:
        atom_array: the same atom array, annotated in place
    """
    num_atoms = atom_array.shape[0]

    # Identity transform: no rotation, no translation.
    origin, x_axis, y_axis = decompose_symmetry_frame((np.eye(3), np.zeros(3)))
    # NOTE(review): np.full(num_atoms, value) requires value to broadcast to
    # shape (num_atoms,); confirm the shapes returned by decompose_symmetry_frame.
    transform_annotations = {
        name: np.full(num_atoms, value)
        for name, value in (
            ("sym_transform_Ori", origin),
            ("sym_transform_X", x_axis),
            ("sym_transform_Y", y_axis),
        )
    }
    for name, values in transform_annotations.items():
        atom_array.set_annotation(name, values)

    atom_array.set_annotation(
        "sym_transform_id", np.full(num_atoms, FIXED_TRANSFORM_ID, dtype=np.int32)
    )
    atom_array.set_annotation(
        "sym_entity_id", np.full(num_atoms, FIXED_ENTITY_ID, dtype=np.int32)
    )
    # The motif must never be treated as the asymmetric unit.
    atom_array.set_annotation(
        "is_sym_asu", np.full(num_atoms, False, dtype=np.bool_)
    )
    return atom_array
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def add_src_sym_component_annotations(atom_array):
    """
    Add a ``src_sym_component`` annotation derived from ``src_component``.

    For each source component whose atoms are all unindexed motif atoms, the
    first character of the component id is replaced, per atom, by that atom's
    chain id (e.g. component "A5" on chain "C" becomes "C5"). Every other atom
    keeps its ``src_component`` value. This maps the original motif id onto
    diffused unindexed motifs.

    Arguments:
        atom_array: atom array with src_component annotated
    Returns:
        atom_array: the same atom array. It is returned unchanged (no new
            annotation) when it has no ``src_component`` annotation.
    """
    if "src_component" not in atom_array.get_annotation_categories():
        return atom_array

    src = atom_array.src_component
    renamed = src.copy()

    for token in np.unique(src):
        # Only ids that start with a letter (a chain id) are re-chained
        if not token or not token[0].isalpha():
            continue

        in_token = src == token
        # Only components made entirely of unindexed motif atoms are re-chained
        if not atom_array.is_motif_atom_unindexed[in_token].all():
            continue

        chains = atom_array.chain_id
        for chain in np.unique(chains[in_token]):
            renamed[in_token & (chains == chain)] = chain + token[1:]

    atom_array.set_annotation("src_sym_component", renamed)
    return atom_array
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def fix_3D_sym_motif_annotations(atom_array):
    """
    Add fixed-motif symmetry annotations to the 3D non-indexed motif atoms.

    Only motif atoms that are not indexed motifs (unindexed motifs and
    ligands) are treated as fixed. Indexed motifs are contiguously connected
    to generated residues, so they are not fixed; instead they are
    symmetrized at each step.

    Arguments:
        atom_array: atom array with ``_is_motif`` and ``_is_indexed_motif``
            annotations
    Returns:
        atom_array: the atom array with the selected atoms re-annotated
    """
    is_fixed = atom_array._is_motif & ~atom_array._is_indexed_motif
    fixed_atoms = add_sym_annotations_to_fixed_motif(atom_array[is_fixed].copy())
    # NOTE(review): stock biotite ``AtomArray.__setitem__`` only accepts a
    # single ``Atom`` at an integer index; confirm that the array type used
    # here supports assigning a sub-array through a boolean mask.
    atom_array[is_fixed] = fixed_atoms
    return atom_array
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def add_sym_transform_annotations(atom_array, transform_id, frame, is_asu=False):
    """
    Add symmetry transform annotations to an atom array.

    Sets sym_transform_Ori / sym_transform_X / sym_transform_Y (the
    decomposition of ``frame``), sym_transform_id, sym_entity_id and
    is_sym_asu on every atom.

    Arguments:
        atom_array: atom array of symmetry subunit
        transform_id: index of the transform frame
        frame: symmetry frame, decomposed with ``decompose_symmetry_frame``
        is_asu: whether this is the asymmetric unit
    Returns:
        atom_array: atom array with symmetry annotations
    """
    Ori, X, Y = decompose_symmetry_frame(frame)
    n = atom_array.shape[0]

    # symmetry transform (decomposed into Ori, X, Y)
    Oris = np.full(n, Ori)
    Xs = np.full(n, X)
    Ys = np.full(n, Y)
    atom_array.set_annotation("sym_transform_Ori", Oris)
    atom_array.set_annotation("sym_transform_X", Xs)
    atom_array.set_annotation("sym_transform_Y", Ys)

    # symmetry transform id
    transform_ids = np.full(n, transform_id, dtype=np.int32)
    atom_array.set_annotation("sym_transform_id", transform_ids)

    # Entity ids help keep track of different multiplicities. Each atom gets
    # the rank of its chain id among the sorted unique chain ids, so (presumably)
    # small molecules on their own chains get ids distinct from protein chains.
    # np.unique's inverse indices are exactly that rank, computed in one
    # vectorized pass instead of an O(n_atoms * n_chains) list.index() scan.
    _, chain_rank = np.unique(atom_array.chain_id, return_inverse=True)
    entity_ids = chain_rank.reshape(-1).astype(int)
    atom_array.set_annotation("sym_entity_id", entity_ids)

    is_sym_asu = np.full(n, is_asu, dtype=np.bool_)
    atom_array.set_annotation("is_sym_asu", is_sym_asu)

    return atom_array
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def apply_symmetry_to_atomarray_coord(atom_array, frame):
    """
    Transform the coordinates of an atom array by a symmetry frame.

    Every coordinate ``x`` becomes ``R @ x + T``. The atom array passed in is
    modified in place and also returned; pass a copy if the original
    coordinates must be kept.

    Arguments:
        atom_array: atom array
        frame: symmetry frame (R, T)
    Returns:
        atom_array: the same atom array with transformed coordinates
    """
    rotation, translation = frame
    rotated = atom_array.coord @ rotation.T
    atom_array.coord = rotated
    # translation should be 0 for most symmetry cases
    atom_array.coord += translation
    return atom_array
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
########################################################
|
|
170
|
+
# Motif functions
|
|
171
|
+
########################################################
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def annotate_unsym_atom_array(atom_array):
    """
    Return an annotated copy of an unsymmetrized motif atom array.

    The work is done on ``atom_array.copy()``. The copy is flagged as not
    being the asymmetric unit, gets chain ids relabeled from "a" upwards, and
    receives the fixed-motif symmetry annotations.

    Arguments:
        atom_array: atom array of the unsym motif
    Returns:
        The annotated copy.
    """
    annotated = atom_array.copy()
    not_asu = np.full(annotated.shape[0], False, dtype=np.bool_)
    annotated._is_asu = not_asu
    annotated.is_sym_asu = annotated._is_asu
    # lowercase chain ids avoid confusion with the symmetry units' chains
    annotated = reset_chain_ids(annotated, start_id="a")
    return add_sym_annotations_to_fixed_motif(annotated)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
########################################################
|
|
192
|
+
# 2D conditioning functions
|
|
193
|
+
########################################################
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def add_2d_entity_annotations(atom_array):
    """
    Encode the 2D-condition annotations as a single ``_2d_entity_id`` annotation.

    Atoms flagged (== 1) in the k-th 2D-condition category, in sorted name
    order, get entity id k (1-based). Atoms in no category keep 0. If an atom
    is flagged in several categories, the last one in sorted order wins.

    Arguments:
        atom_array: atom array with 2D-condition annotations
    Returns:
        atom_array: the same atom array with ``_2d_entity_id`` set
    """
    entity_ids = np.zeros(atom_array.shape[0], dtype=np.int32)
    categories = get_2d_annotation_categories(atom_array)
    for entity_id, category in enumerate(categories, start=1):
        is_member = atom_array.get_annotation(category) == 1
        entity_ids[is_member] = entity_id
    atom_array.set_annotation("_2d_entity_id", entity_ids)
    return atom_array
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def reannotate_2d_entity_ids(atom_array, transform_id):
    """
    Cyclically shift the non-zero ``_2d_entity_id`` values for a symmetry copy.

    Each non-zero id ``e`` becomes ``((e + transform_id - 1) % m) + 1``, where
    ``m`` is the larger of the number of 2D-condition categories and the
    number of symmetry frames. Atoms with id 0 are left untouched.

    Arguments:
        atom_array: atom array of a symmetry unit, annotated with symmetry_id
        transform_id: index of the symmetry unit
    Returns:
        atom_array: the same atom array; returned as-is when it has no
            ``_2d_entity_id`` annotation
    """
    if "_2d_entity_id" not in atom_array.get_annotation_categories():
        return atom_array

    num_categories = len(get_2d_annotation_categories(atom_array))
    num_frames = len(get_symmetry_frames_from_symmetry_id(atom_array.symmetry_id[0]))
    # NOTE: assumes either that the 2D condition was specified within a single
    # subunit, or that all active sites were specified explicitly.
    max_entity_id = max(num_categories, num_frames)

    has_entity = atom_array.get_annotation("_2d_entity_id") != 0
    shifted = (
        atom_array._2d_entity_id[has_entity] + transform_id - 1
    ) % max_entity_id
    atom_array._2d_entity_id[has_entity] = shifted + 1
    return atom_array
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def get_2d_annotation_categories(atom_array):
    """
    Return the names of all 2D-condition annotations, sorted ascending.

    Arguments:
        atom_array: atom array to inspect
    Returns:
        Sorted list of annotation category names containing "2d_condition".
    """
    # Sorting guarantees a stable, ascending category order for callers.
    return sorted(
        name
        for name in atom_array.get_annotation_categories()
        if "2d_condition" in name
    )
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def reannotate_2d_conditions(atom_array):
    """
    Convert ``_2d_entity_id`` back into boolean 2D-condition annotations.

    The i-th non-zero entity id (ascending) is written, as a boolean mask, to
    the i-th 2D-condition category (sorted by name). When there are more
    entity ids than categories, extra categories named
    ``f"{first_category}_{i}"`` are created. Afterwards ``_2d_entity_id`` is
    deleted.

    Arguments:
        atom_array: atom array with ``_2d_entity_id`` annotated
    Returns:
        atom_array: the same atom array
    """
    entity_anno = atom_array.get_annotation("_2d_entity_id")
    nonzero_ids = [e for e in np.unique(entity_anno) if e != 0]
    categories = get_2d_annotation_categories(atom_array)

    num_existing = len(categories)
    # NOTE(review): this indexes categories[0], so it raises IndexError when
    # there are non-zero entity ids but no existing 2D-condition category.
    categories += [
        f"{categories[0]}_{i}" for i in range(num_existing, len(nonzero_ids))
    ]

    for entity_id, category in zip(nonzero_ids, categories):
        atom_array.set_annotation(category, entity_anno == entity_id)
    atom_array.del_annotation("_2d_entity_id")
    return atom_array
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
########################################################
|
|
245
|
+
# Utility functions
|
|
246
|
+
########################################################
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def reset_chain_ids(atom_array, start_id):
    """
    Relabel chains with consecutive single-character ids starting at ``start_id``.

    Chains are relabeled in the sorted order of their current ids. For
    example, with ``start_id="a"`` the chains ("B", "A") become ("b", "a").
    ``pn_unit_iid`` is then set to the new chain ids.

    All chains are relabeled in one vectorized step. A new id therefore can
    never collide with an old id that has not been relabeled yet; relabeling
    chain by chain in place would merge, e.g., chains "A" and "a" when
    starting at "a".

    Arguments:
        atom_array: atom array with chain_ids and pn_unit_iids annotated
        start_id: single character used as the first new chain id
    Returns:
        atom_array: the same atom array with chain ids updated in place
    """
    old_ids, chain_index = np.unique(atom_array.chain_id, return_inverse=True)
    new_ids = np.array(
        [chr(ord(start_id) + i) for i in range(len(old_ids))], dtype=str
    )
    # Write into the existing annotation array in place so its dtype is kept.
    atom_array.chain_id[:] = new_ids[chain_index.reshape(-1)]
    atom_array.pn_unit_iid = atom_array.chain_id
    return atom_array
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def reannotate_chain_ids(atom_array, offset, multiplier=0):
    """
    Shift every single-character chain id by ``offset * multiplier`` code points.

    ``pn_unit_iid`` is set to the same shifted ids. With the default
    ``multiplier=0`` the ids keep their values (but are re-stored).

    Arguments:
        atom_array: protein atom array with chain_ids and pn_unit_iids annotated
        offset: offset to add to the chain ids
        multiplier: multiplier applied to the offset
    Returns:
        atom_array: the same atom array with new chain_id / pn_unit_iid arrays
    """
    shift = offset * multiplier
    # Shift each distinct id once, then broadcast the result back to every atom.
    distinct_ids, atom_to_id = np.unique(atom_array.chain_id, return_inverse=True)
    shifted_ids = np.array([chr(ord(c) + shift) for c in distinct_ids], dtype=str)
    chain_ids = shifted_ids[atom_to_id.reshape(-1)]
    atom_array.chain_id = chain_ids
    atom_array.pn_unit_iid = chain_ids
    return atom_array
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def get_symmetry_unit(asu_atom_array, transform_id, frame):
    """
    Build one symmetry unit from the asymmetric unit (ASU).

    The work is done on a copy of the ASU, in this order:
      1. its chain ids are shifted by ``transform_id * n_chains`` code points,
         where ``n_chains`` is the number of unique ASU chain ids;
      2. its 2D-condition entity ids are rotated for this transform;
      3. it receives the symmetry transform annotations (is_asu only for
         transform_id 0);
      4. its coordinates are transformed by ``frame``.

    Arguments:
        asu_atom_array: atom array of the asymmetric unit, annotated with symmetry_id
        transform_id: index of the symmetry unit
        frame: symmetry frame (R, T)
    Returns:
        The annotated symmetry unit.
    """
    chains_per_unit = len(np.unique(asu_atom_array.chain_id))
    unit = reannotate_chain_ids(asu_atom_array.copy(), chains_per_unit, transform_id)
    unit = reannotate_2d_entity_ids(unit, transform_id)
    unit = add_sym_transform_annotations(
        unit, transform_id, frame, is_asu=(transform_id == 0)
    )
    # Apply the symmetry operation to the coordinates (indexed motifs); at this
    # point the diffused coordinates are at the origin / have no xyz.
    return apply_symmetry_to_atomarray_coord(unit, frame)
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from rfd3.inference.symmetry.contigs import expand_contig_unsym_motif
|
|
3
|
+
from rfd3.transforms.conditioning_base import get_motif_features
|
|
4
|
+
|
|
5
|
+
from foundry.utils.ddp import RankedLogger
|
|
6
|
+
|
|
7
|
+
# Module-level thresholds for symmetry checks.
MIN_ATOMS_ALIGN: int = 100
MAX_TRANSFORMS: int = 10
RMSD_CUT: float = 1.0  # Angstroms

# Module logger, configured with rank_zero_only=True.
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def check_symmetry_config(
    atom_array, sym_conf, sm, has_dist_cond, src_atom_array=None, partial=False
):
    """
    Check that the symmetry configuration is valid. Add all basic checks here.

    Arguments:
        atom_array: atom array of the design input
        sym_conf: symmetry configuration dict. "id" is required; the optional
            keys read here are "is_unsym_motif" (comma-separated motif names)
            and "is_symmetric_motif".
        sm: comma-separated small-molecule names, or a falsy value
        has_dist_cond: whether distance constraints were provided
        src_atom_array: source atom array; required with "is_unsym_motif"
        partial: whether partial diffusion is requested
    Raises:
        AssertionError: if "id" is missing, or "is_unsym_motif" is given
            without ``src_atom_array``
        ValueError: if an unsym motif name is neither a listed small molecule
            nor a source component; if an asymmetric motif lacks distance
            constraints; or if partial diffusion is requested for a
            non-symmetric input
    """
    assert sym_conf.get("id"), "symmetry_id is required. e.g. {'id': 'C2'}"
    # if unsym motif is provided, check that each motif name is in the atom array
    if sym_conf.get("is_unsym_motif"):
        assert (
            src_atom_array is not None
        ), "Source atom array must be provided for symmetric motifs"
        unsym_motif_names = sym_conf["is_unsym_motif"].split(",")
        unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
        sm_names = sm.split(",") if sm else []
        for n in unsym_motif_names:
            # A name is valid if it is a listed small molecule or a source
            # component. When no small molecules are given, the name must
            # still be found among the source components.
            if n not in sm_names and n not in atom_array.src_component:
                raise ValueError(f"Unsym motif {n} not found in atom_array")
    if (
        get_motif_features(atom_array)["is_motif_token"].any()
        and not sym_conf.get("is_symmetric_motif")
        and not has_dist_cond
    ):
        raise ValueError(
            "Asymmetric motif inputs should be distance constrained. "
            "Use atomwise_fixed_dist to constrain the distance between the motif atoms."
        )
    # else: if unconditional symmetry, no need to have symmetric input motif

    if partial and not sym_conf.get("is_symmetric_motif"):
        raise ValueError(
            "Partial diffusion with symmetry is only supported for symmetric inputs."
        )
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def check_atom_array_is_symmetric(atom_array):
    """
    Check if the atom array is symmetric. This is NOT to check that the atom array symmetry matches that of the symmetry_id.

    Hetero atoms are removed first. The array is then rejected if any chain
    differs from the first chain (the ASU) in atom count or in the sequence of
    atom names. Finally, copies of the ASU are generated with the standard
    frames of ``atom_array.symmetry_id[0]`` and superposed onto the input.

    Arguments:
        atom_array: atom arrays to check
    Returns:
        bool: True if the atom array is symmetric, False otherwise
    """
    # TODO: Implement something like this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L303
    # and maybe this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L231

    import biotite.structure as struc
    from rfd3.inference.symmetry.atom_array import (
        apply_symmetry_to_atomarray_coord,
    )
    from rfd3.inference.symmetry.frames import (
        get_symmetry_frames_from_symmetry_id,
    )

    # remove hetero atoms
    atom_array = atom_array[~atom_array.hetero]
    if len(atom_array) == 0:
        ranked_logger.info("Atom array has no protein chains. Please check your input.")
        return False

    chains = np.unique(atom_array.chain_id)
    asu_mask = atom_array.chain_id == chains[0]
    asu_atoms = atom_array[asu_mask].copy()
    asu_names = asu_atoms.atom_name

    # Check that all chains have the same number of atoms
    for chain in chains[1:]:
        n_chain_atoms = np.count_nonzero(atom_array.chain_id == chain)
        if len(asu_atoms) != n_chain_atoms:
            ranked_logger.info(
                f"Atom array has different number of atoms in chain {chain}. {len(asu_atoms)} != {n_chain_atoms}"
            )
            return False

    # Check that all chains have the same atoms, in the same order. Whole
    # name arrays are compared at once instead of re-slicing per atom (O(n^2)).
    for chain in chains[1:]:
        chain_names = atom_array.atom_name[atom_array.chain_id == chain]
        mismatches = np.flatnonzero(asu_names != chain_names)
        if mismatches.size:
            i = mismatches[0]
            ranked_logger.info(
                f"Atom array has different atoms in chain {chain}. {asu_names[i]} != {chain_names[i]}"
            )
            return False

    # Check that the atom array aligns with the standard symmetry frames.
    # apply_symmetry_to_atomarray_coord transforms its argument in place, so
    # each frame is applied to a fresh copy of the ASU. Reusing one object
    # would compose the frames and append the same object several times.
    standard_frames = get_symmetry_frames_from_symmetry_id(atom_array.symmetry_id[0])
    standard_atom_array = struc.concatenate(
        [
            apply_symmetry_to_atomarray_coord(asu_atoms.copy(), frame)
            for frame in standard_frames
        ]
    )

    R_standard_obtained = find_optimal_rotation(
        standard_atom_array.coord, atom_array.coord
    )

    # NOTE(review): find_optimal_rotation returns None only for fewer than 3
    # points or on an exception, so this does not measure alignment quality;
    # consider comparing the post-superposition RMSD against a cutoff.
    if R_standard_obtained is None:
        ranked_logger.info(
            "Atom array does not align with the standard symmetry frames."
        )
        return False

    return True
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def find_optimal_rotation(coords1, coords2, max_points=1000):
    """
    Find optimal rotation matrix between two sets of coordinates using Kabsch algorithm.

    Both point sets are centred on their own centroids first, so any
    translation between them is ignored. The returned matrix ``R`` is the
    proper rotation (``det(R) = +1``) minimising
    ``sum_i ||R @ (p_i - c1) - (q_i - c2)||^2``; a best fit that would be a
    reflection is corrected to a rotation.

    If the inputs differ in length, both are truncated to the shorter length
    before anything else. If more than ``max_points`` paired points remain, a
    random subset of ``max_points`` pairs is drawn with NumPy's global RNG.

    Args:
        coords1: reference coordinates, array-like of shape (N, 3)
        coords2: target coordinates, array-like of shape (M, 3)
        max_points: maximum number of points to use for efficiency

    Returns:
        rotation_matrix: 3x3 rotation matrix, or None if fewer than 3 paired
        points are used or the computation raised an exception.

    Note:
        A non-None result says nothing about how well the two sets align;
        compute the residual RMSD separately if alignment quality matters.
    """
    coords1 = np.asarray(coords1)
    coords2 = np.asarray(coords2)

    # Pair points up *before* subsampling: drawing indices from the longer
    # array first would index past the end of the shorter one.
    min_len = min(len(coords1), len(coords2))
    coords1 = coords1[:min_len]
    coords2 = coords2[:min_len]

    if min_len > max_points:
        indices = np.random.choice(min_len, max_points, replace=False)
        coords1 = coords1[indices]
        coords2 = coords2[indices]

    # A rotation in 3D needs at least 3 paired points to be determined.
    if len(coords1) < 3:
        return None

    # Kabsch algorithm
    try:
        coords1_centered = coords1 - np.mean(coords1, axis=0)
        coords2_centered = coords2 - np.mean(coords2, axis=0)

        # Cross-covariance matrix H = P^T Q
        H = coords1_centered.T @ coords2_centered

        U, _, Vt = np.linalg.svd(H)
        R = Vt.T @ U.T
        # det(R) < 0 means the best orthogonal fit is a reflection; flipping the
        # last right-singular vector yields the closest proper rotation.
        if np.linalg.det(R) < 0:
            Vt[-1, :] *= -1
            R = Vt.T @ U.T
        return R

    except Exception as e:
        ranked_logger.warning(f"Error in rotation calculation: {e}")
        return None
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def check_input_frames_match_symmetry_frames(computed_frames, original_frames) -> None:
    """
    Check that the number of computed frames matches the number of original frames.

    Only the frame counts are compared; the frames themselves are not inspected.

    Arguments:
        computed_frames: list of computed frames
        original_frames: list of original frames

    Raises:
        AssertionError: if the two lists differ in length. It is raised
            explicitly rather than via ``assert`` so the check is not
            stripped when Python runs with ``-O``.
    """
    if len(computed_frames) != len(original_frames):
        raise AssertionError(
            "Number of computed frames does not match number of original frames"
        )
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def check_valid_multiplicity(nids_by_entity) -> None:
    """
    Check if the multiplicity is valid.

    The smallest number of copies of any entity is taken as the multiplicity;
    every entity must then occur a whole multiple of that many times.

    Arguments:
        nids_by_entity: dict mapping entity to ids

    Raises:
        ValueError: if there are no entities or some entity has fewer than two
            copies (no possible symmetry), or if some entity's copy count is
            not divisible by the multiplicity.
    """
    copy_counts = [len(ids) for ids in nids_by_entity.values()]

    # get multiplicities of subunits. An empty input, or an entity with zero
    # copies, cannot be symmetric either; without `default` / the `<= 1` test
    # these would surface as a cryptic min() error or a modulo-by-zero.
    multiplicity = min(copy_counts, default=0)
    if multiplicity <= 1:  # no possible symmetry
        raise ValueError(
            "Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead."
        )

    # Check that the input is not asymmetric
    if any(count % multiplicity != 0 for count in copy_counts):
        raise ValueError("Invalid multiplicities of subunits. Please check your input.")
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def check_valid_subunit_size(nids_by_entity, pn_unit_id) -> None:
    """
    Check that the subunits in the input are of the same size.

    For each entity, the size of every copy is compared with the size of the
    entity's first copy. A copy's size is the number of entries of
    ``pn_unit_id`` equal to that copy's id.

    Arguments:
        nids_by_entity: dict mapping entity to ids
        pn_unit_id: array-like of per-element unit ids, compared element-wise
            against the ids in ``nids_by_entity`` (presumably one entry per
            atom -- confirm against caller)

    Raises:
        ValueError: if any copy of an entity differs in size from the first copy.
    """
    # Accept plain sequences too; element-wise `==` needs an ndarray.
    pn_unit_id = np.asarray(pn_unit_id)
    for ids in nids_by_entity.values():
        if len(ids) < 2:
            continue
        # The first copy's size is loop-invariant for this entity: compute once.
        reference_size = np.count_nonzero(pn_unit_id == ids[0])
        for other_id in ids[1:]:
            if np.count_nonzero(pn_unit_id == other_id) != reference_size:
                raise ValueError("Size mismatch in the input. Please check your file.")
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def check_min_atoms_to_align(natm_per_unique, reference_entity, min_atoms=None) -> None:
    """
    Check that we have enough atoms to align.

    Arguments:
        natm_per_unique: mapping from entity to a count compared against the
            minimum (presumably the entity's number of atoms -- confirm
            against caller)
        reference_entity: key into ``natm_per_unique`` for the entity used as
            the alignment reference
        min_atoms: minimum required count; defaults to the module constant
            ``MIN_ATOMS_ALIGN`` when None

    Raises:
        ValueError: if ``natm_per_unique[reference_entity]`` is below the minimum.
    """
    if min_atoms is None:
        min_atoms = MIN_ATOMS_ALIGN
    if natm_per_unique[reference_entity] < min_atoms:
        raise ValueError("Not enough atoms to align. Please check your input.")
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def check_max_transforms(chains_to_consider, max_transforms=None) -> None:
    """
    Check that we are not exceeding the max number of transforms.

    One transform is counted per entry of ``chains_to_consider``.

    Arguments:
        chains_to_consider: list of chains to consider
        max_transforms: max number of transforms; defaults to the module
            constant ``MAX_TRANSFORMS`` when None

    Raises:
        ValueError: if ``len(chains_to_consider)`` exceeds ``max_transforms``.
    """
    if max_transforms is None:
        max_transforms = MAX_TRANSFORMS
    if len(chains_to_consider) > max_transforms:
        # Report the limit actually enforced instead of a hard-coded number,
        # so the message cannot drift from the constant.
        raise ValueError(
            f"Number of transforms exceeds the max number of transforms ({max_transforms})"
        )
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def check_max_rmsds(rmsds) -> None:
|
|
232
|
+
"""
|
|
233
|
+
Check that the RMSD between the reference molecule and the other molecules is not too big.
|
|
234
|
+
Arguments:
|
|
235
|
+
rmsds: dict mapping chain to RMSD
|
|
236
|
+
"""
|
|
237
|
+
if max(rmsds.values()) > RMSD_CUT:
|
|
238
|
+
ranked_logger.warning(
|
|
239
|
+
f"RMSD between the reference molecule and the other molecules is too big ({max(rmsds.values())} > {RMSD_CUT}). Please provide a symmetric input PDB file."
|
|
240
|
+
)
|
|
241
|
+
# raise ValueError(f"RMSD between the reference molecule and the other molecules is too big ({max(rmsds.values())} > {RMSD_CUT}). Please provide a symmetric input PDB file.")
|