rc-foundry 0.1.6__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/inference_engines/checkpoint_registry.py +58 -11
- foundry/utils/alignment.py +10 -2
- foundry/utils/ddp.py +1 -1
- foundry/utils/logging.py +1 -1
- foundry/version.py +2 -2
- foundry_cli/download_checkpoints.py +66 -66
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/METADATA +30 -21
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/RECORD +31 -31
- rf3/cli.py +13 -4
- rf3/inference.py +3 -1
- rfd3/configs/datasets/train/pdb/af3_train_interface.yaml +1 -1
- rfd3/configs/inference_engine/rfdiffusion3.yaml +2 -2
- rfd3/configs/model/samplers/symmetry.yaml +1 -1
- rfd3/engine.py +28 -12
- rfd3/inference/datasets.py +1 -1
- rfd3/inference/input_parsing.py +32 -1
- rfd3/inference/legacy_input_parsing.py +17 -1
- rfd3/inference/parsing.py +1 -0
- rfd3/inference/symmetry/atom_array.py +78 -13
- rfd3/inference/symmetry/checks.py +62 -29
- rfd3/inference/symmetry/frames.py +256 -5
- rfd3/inference/symmetry/symmetry_utils.py +39 -61
- rfd3/model/inference_sampler.py +11 -1
- rfd3/model/layers/block_utils.py +33 -33
- rfd3/model/layers/chunked_pairwise.py +84 -82
- rfd3/run_inference.py +3 -1
- rfd3/transforms/symmetry.py +16 -7
- rfd3/utils/inference.py +21 -22
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/licenses/LICENSE.md +0 -0
rf3/cli.py
CHANGED
|
@@ -23,10 +23,19 @@ def fold(
|
|
|
23
23
|
configure_minimal_inference_logging()
|
|
24
24
|
|
|
25
25
|
# Find the RF3 configs directory relative to this file
|
|
26
|
-
#
|
|
27
|
-
#
|
|
28
|
-
|
|
29
|
-
|
|
26
|
+
# In development: models/rf3/src/rf3/cli.py -> models/rf3/configs/
|
|
27
|
+
# When installed: site-packages/rf3/cli.py -> site-packages/rf3/configs/
|
|
28
|
+
rf3_file_dir = Path(__file__).parent
|
|
29
|
+
|
|
30
|
+
# Check if we're in installed mode (configs are sibling to this file)
|
|
31
|
+
# or development mode (configs are ../../../configs)
|
|
32
|
+
if (rf3_file_dir / "configs").exists():
|
|
33
|
+
# Installed mode
|
|
34
|
+
config_path = str(rf3_file_dir / "configs")
|
|
35
|
+
else:
|
|
36
|
+
# Development mode
|
|
37
|
+
rf3_package_dir = rf3_file_dir.parent.parent # Go up to models/rf3/
|
|
38
|
+
config_path = str(rf3_package_dir / "configs")
|
|
30
39
|
|
|
31
40
|
# Get all arguments
|
|
32
41
|
args = ctx.params.get("args", []) + ctx.args
|
rf3/inference.py
CHANGED
|
@@ -16,7 +16,9 @@ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
16
16
|
|
|
17
17
|
load_dotenv(override=True)
|
|
18
18
|
|
|
19
|
-
_config_path = os.path.join(
|
|
19
|
+
_config_path = os.path.join(
|
|
20
|
+
os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
|
|
21
|
+
)
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
@hydra.main(
|
|
@@ -7,7 +7,7 @@ dataset:
|
|
|
7
7
|
base_dir: ${paths.data.pdb_data_dir}
|
|
8
8
|
dataset:
|
|
9
9
|
name: interface
|
|
10
|
-
data: ${paths.data.pdb_parquet_dir}/
|
|
10
|
+
data: ${paths.data.pdb_parquet_dir}/interfaces_df.parquet
|
|
11
11
|
filters:
|
|
12
12
|
# filters common across all PDB datasets
|
|
13
13
|
- "deposition_date < '2021-09-30'"
|
|
@@ -7,7 +7,7 @@ _target_: rfd3.engine.RFD3InferenceEngine
|
|
|
7
7
|
|
|
8
8
|
out_dir: ???
|
|
9
9
|
inputs: ??? # null, json, pdb or
|
|
10
|
-
ckpt_path:
|
|
10
|
+
ckpt_path: rfd3
|
|
11
11
|
json_keys_subset: null
|
|
12
12
|
skip_existing: True
|
|
13
13
|
|
|
@@ -61,5 +61,5 @@ global_prefix: null
|
|
|
61
61
|
dump_prediction_metadata_json: True
|
|
62
62
|
dump_trajectories: False
|
|
63
63
|
align_trajectory_structures: False
|
|
64
|
-
prevalidate_inputs:
|
|
64
|
+
prevalidate_inputs: False
|
|
65
65
|
low_memory_mode: False # False for standard mode, True for memory efficient tokenization mode
|
rfd3/engine.py
CHANGED
|
@@ -21,9 +21,14 @@ from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
|
|
|
21
21
|
from rfd3.inference.datasets import (
|
|
22
22
|
assemble_distributed_inference_loader_from_json,
|
|
23
23
|
)
|
|
24
|
-
from rfd3.inference.input_parsing import
|
|
24
|
+
from rfd3.inference.input_parsing import (
|
|
25
|
+
DesignInputSpecification,
|
|
26
|
+
ensure_input_is_abspath,
|
|
27
|
+
)
|
|
25
28
|
from rfd3.model.inference_sampler import SampleDiffusionConfig
|
|
26
|
-
from rfd3.utils.inference import
|
|
29
|
+
from rfd3.utils.inference import (
|
|
30
|
+
ensure_inference_sampler_matches_design_spec,
|
|
31
|
+
)
|
|
27
32
|
from rfd3.utils.io import (
|
|
28
33
|
CIF_LIKE_EXTENSIONS,
|
|
29
34
|
build_stack_from_atom_array_and_batched_coords,
|
|
@@ -171,6 +176,7 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
171
176
|
)
|
|
172
177
|
# save
|
|
173
178
|
self.specification_overrides = dict(specification or {})
|
|
179
|
+
self.inference_sampler_overrides = dict(inference_sampler or {})
|
|
174
180
|
|
|
175
181
|
# Setup output directories and args
|
|
176
182
|
self.global_prefix = global_prefix
|
|
@@ -210,6 +216,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
210
216
|
inputs=inputs,
|
|
211
217
|
n_batches=n_batches,
|
|
212
218
|
)
|
|
219
|
+
ensure_inference_sampler_matches_design_spec(
|
|
220
|
+
design_specifications, self.inference_sampler_overrides
|
|
221
|
+
)
|
|
213
222
|
# init before
|
|
214
223
|
self.initialize()
|
|
215
224
|
outputs = self._run_multi(design_specifications)
|
|
@@ -383,6 +392,15 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
383
392
|
# Based on inputs, construct the specifications to loop through
|
|
384
393
|
design_specifications = {}
|
|
385
394
|
for prefix, example_spec in inputs.items():
|
|
395
|
+
# Record task name in the specification
|
|
396
|
+
if isinstance(example_spec, DesignInputSpecification):
|
|
397
|
+
example_spec.extra = example_spec.extra or {}
|
|
398
|
+
example_spec.extra["task_name"] = prefix
|
|
399
|
+
else:
|
|
400
|
+
if "extra" not in example_spec:
|
|
401
|
+
example_spec["extra"] = {}
|
|
402
|
+
example_spec["extra"]["task_name"] = prefix
|
|
403
|
+
|
|
386
404
|
# ... Create n_batches for example
|
|
387
405
|
for batch_id in range((n_batches) if exists(n_batches) else 1):
|
|
388
406
|
# ... Example ID
|
|
@@ -524,21 +542,19 @@ def process_input(
|
|
|
524
542
|
|
|
525
543
|
|
|
526
544
|
def _reshape_trajectory(traj, align_structures: bool):
|
|
527
|
-
traj = [traj[i] for i in range(len(traj))]
|
|
528
|
-
n_steps = len(traj)
|
|
545
|
+
traj = [traj[i] for i in range(len(traj))] # make list of arrays
|
|
529
546
|
max_frames = 100
|
|
530
|
-
|
|
547
|
+
if len(traj) > max_frames:
|
|
548
|
+
selected_indices = torch.linspace(0, len(traj) - 1, max_frames).long().tolist()
|
|
549
|
+
traj = [traj[i] for i in selected_indices]
|
|
531
550
|
if align_structures:
|
|
532
551
|
# ... align the trajectories on the last prediction
|
|
533
|
-
for step in range(
|
|
552
|
+
for step in range(len(traj) - 1):
|
|
534
553
|
traj[step] = weighted_rigid_align(
|
|
535
|
-
X_L=traj[-1],
|
|
536
|
-
X_gt_L=traj[step],
|
|
537
|
-
)
|
|
554
|
+
X_L=traj[-1][None],
|
|
555
|
+
X_gt_L=traj[step][None],
|
|
556
|
+
).squeeze(0)
|
|
538
557
|
traj = traj[::-1] # reverse to go from noised -> denoised
|
|
539
|
-
if n_steps > max_frames:
|
|
540
|
-
selected_indices = torch.linspace(0, n_steps - 1, max_frames).long().tolist()
|
|
541
|
-
traj = [traj[i] for i in selected_indices]
|
|
542
558
|
|
|
543
559
|
traj = torch.stack(traj).cpu().numpy()
|
|
544
560
|
return traj
|
rfd3/inference/datasets.py
CHANGED
|
@@ -14,8 +14,8 @@ from atomworks.ml.transforms.base import Compose, Transform
|
|
|
14
14
|
from omegaconf import DictConfig, OmegaConf
|
|
15
15
|
from rfd3.inference.input_parsing import (
|
|
16
16
|
DesignInputSpecification,
|
|
17
|
+
ensure_input_is_abspath,
|
|
17
18
|
)
|
|
18
|
-
from rfd3.utils.inference import ensure_input_is_abspath
|
|
19
19
|
from torch.utils.data import (
|
|
20
20
|
DataLoader,
|
|
21
21
|
SequentialSampler,
|
rfd3/inference/input_parsing.py
CHANGED
|
@@ -5,6 +5,7 @@ import os
|
|
|
5
5
|
import time
|
|
6
6
|
import warnings
|
|
7
7
|
from contextlib import contextmanager
|
|
8
|
+
from os import PathLike
|
|
8
9
|
from typing import Any, Dict, List, Optional, Union
|
|
9
10
|
|
|
10
11
|
import numpy as np
|
|
@@ -696,7 +697,7 @@ class DesignInputSpecification(BaseModel):
|
|
|
696
697
|
# Partial diffusion: use COM, keep all coordinates
|
|
697
698
|
if exists(self.symmetry) and self.symmetry.id:
|
|
698
699
|
# For symmetric structures, avoid COM centering that would collapse chains
|
|
699
|
-
|
|
700
|
+
logger.info(
|
|
700
701
|
"Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
|
|
701
702
|
)
|
|
702
703
|
else:
|
|
@@ -1121,3 +1122,33 @@ def accumulate_components(
|
|
|
1121
1122
|
if atom_array_accum.bonds is None:
|
|
1122
1123
|
atom_array_accum.bonds = BondList(atom_array_accum.array_length())
|
|
1123
1124
|
return atom_array_accum
|
|
1125
|
+
|
|
1126
|
+
|
|
1127
|
+
def ensure_input_is_abspath(args: Dict[str, Any], path: PathLike | None):
|
|
1128
|
+
"""
|
|
1129
|
+
Ensures the input source, if one is given, is an absolute path; a relative path is converted to an absolute one.
|
|
1130
|
+
|
|
1131
|
+
args:
|
|
1132
|
+
args: Inference specification for atom array
|
|
1133
|
+
path: None, or the file whose directory a relative input path is resolved against.
|
|
1134
|
+
"""
|
|
1135
|
+
if isinstance(args, str):
|
|
1136
|
+
raise ValueError(
|
|
1137
|
+
"Expected args to be a dictionary, got a string: {}. If you are using an input JSON ensure it contains dictionaries of arguments".format(
|
|
1138
|
+
args
|
|
1139
|
+
)
|
|
1140
|
+
)
|
|
1141
|
+
if "input" not in args or not exists(args["input"]):
|
|
1142
|
+
return args
|
|
1143
|
+
input = str(args["input"])
|
|
1144
|
+
if not os.path.isabs(input):
|
|
1145
|
+
if path is None:
|
|
1146
|
+
raise ValueError(
|
|
1147
|
+
"Input path is relative, but no base path was provided to resolve it against."
|
|
1148
|
+
)
|
|
1149
|
+
input = os.path.abspath(os.path.join(os.path.dirname(str(path)), input))
|
|
1150
|
+
logger.info(
|
|
1151
|
+
f"Input source path is relative, converted to absolute path: {input}"
|
|
1152
|
+
)
|
|
1153
|
+
args["input"] = input
|
|
1154
|
+
return args
|
|
@@ -139,13 +139,18 @@ def fetch_motif_residue_(
|
|
|
139
139
|
subarray, motif=True, unindexed=False, dtype=int
|
|
140
140
|
) # all values init to True (fix all)
|
|
141
141
|
|
|
142
|
+
to_unindex = f"{src_chain}{src_resid}" in unindexed_components
|
|
143
|
+
to_index = f"{src_chain}{src_resid}" in components
|
|
144
|
+
|
|
142
145
|
# Assign is motif atom and sequence
|
|
143
146
|
if exists(atoms := fixed_atoms.get(f"{src_chain}{src_resid}")):
|
|
147
|
+
# If specified, we set fixed atoms in the residue to be motif atoms
|
|
144
148
|
atom_mask = get_name_mask(subarray.atom_name, atoms, res_name)
|
|
145
149
|
subarray.set_annotation("is_motif_atom", atom_mask)
|
|
146
150
|
# subarray.set_annotation("is_motif_atom_with_fixed_coord", atom_mask) # BUGFIX: uncomment
|
|
147
151
|
|
|
148
152
|
elif redesign_motif_sidechains and res_name in STANDARD_AA:
|
|
153
|
+
# If redesign_motif_sidechains is True, only the backbone atoms are set as motif atoms
|
|
149
154
|
n_atoms = subarray.shape[0]
|
|
150
155
|
diffuse_oxygen = False
|
|
151
156
|
if n_atoms < 3:
|
|
@@ -178,6 +183,18 @@ def fetch_motif_residue_(
|
|
|
178
183
|
subarray.set_annotation(
|
|
179
184
|
"is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
|
|
180
185
|
)
|
|
186
|
+
elif to_index or to_unindex:
|
|
187
|
+
# If the residue is in the contig or unindexed components,
|
|
188
|
+
# we set all atoms in the residue to be motif atoms
|
|
189
|
+
subarray.set_annotation("is_motif_atom", np.ones(subarray.shape[0], dtype=int))
|
|
190
|
+
else:
|
|
191
|
+
if to_unindex and not (
|
|
192
|
+
unfix_all or f"{src_chain}{src_resid}" in unfix_residues
|
|
193
|
+
):
|
|
194
|
+
raise ValueError(
|
|
195
|
+
f"{src_chain}{src_resid} is not found in fixed_atoms, contig or unindex contig."
|
|
196
|
+
"Please check your input and contig specification."
|
|
197
|
+
)
|
|
181
198
|
if unfix_all or f"{src_chain}{src_resid}" in unfix_residues:
|
|
182
199
|
subarray.set_annotation(
|
|
183
200
|
"is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
|
|
@@ -197,7 +214,6 @@ def fetch_motif_residue_(
|
|
|
197
214
|
subarray.set_annotation(
|
|
198
215
|
"is_flexible_motif_atom", np.zeros(subarray.shape[0], dtype=bool)
|
|
199
216
|
)
|
|
200
|
-
to_unindex = f"{src_chain}{src_resid}" in unindexed_components
|
|
201
217
|
if to_unindex:
|
|
202
218
|
subarray.set_annotation(
|
|
203
219
|
"is_motif_atom_unindexed", subarray.is_motif_atom.copy()
|
rfd3/inference/parsing.py
CHANGED
|
@@ -117,6 +117,7 @@ def from_any_(v: Any, atom_array: AtomArray):
|
|
|
117
117
|
|
|
118
118
|
# Split to atom names
|
|
119
119
|
data_split[idx] = token.atom_name[comp_mask_subset].tolist()
|
|
120
|
+
# TODO: there is a bug where, when you select specific atoms within a ligand, the output ligand is fragmented
|
|
120
121
|
|
|
121
122
|
# Update mask & token dictionary
|
|
122
123
|
mask[comp_mask] = comp_mask_subset
|
|
@@ -1,14 +1,74 @@
|
|
|
1
|
+
import string
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
from rfd3.inference.symmetry.frames import (
|
|
3
5
|
decompose_symmetry_frame,
|
|
4
6
|
get_symmetry_frames_from_symmetry_id,
|
|
5
7
|
)
|
|
6
8
|
|
|
7
|
-
from foundry.utils.ddp import RankedLogger
|
|
8
|
-
|
|
9
9
|
FIXED_TRANSFORM_ID = -1
|
|
10
10
|
FIXED_ENTITY_ID = -1
|
|
11
|
-
|
|
11
|
+
|
|
12
|
+
# Alphabet for chain ID generation (uppercase letters only, per wwPDB convention)
|
|
13
|
+
_CHAIN_ALPHABET = string.ascii_uppercase
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def index_to_chain_id(index: int) -> str:
|
|
17
|
+
"""
|
|
18
|
+
Convert a zero-based index to a chain ID following wwPDB convention.
|
|
19
|
+
|
|
20
|
+
The naming follows the wwPDB-assigned chain ID system:
|
|
21
|
+
- 0-25: A-Z (single letter)
|
|
22
|
+
- 26-701: AA-ZZ (double letter)
|
|
23
|
+
- 702-18277: AAA-ZZZ (triple letter)
|
|
24
|
+
- And so on...
|
|
25
|
+
|
|
26
|
+
This is similar to Excel column naming (A, B, ..., Z, AA, AB, ...).
|
|
27
|
+
|
|
28
|
+
Arguments:
|
|
29
|
+
index: zero-based index (0 -> 'A', 25 -> 'Z', 26 -> 'AA', etc.)
|
|
30
|
+
Returns:
|
|
31
|
+
chain_id: string chain identifier
|
|
32
|
+
"""
|
|
33
|
+
if index < 0:
|
|
34
|
+
raise ValueError(f"Chain index must be non-negative, got {index}")
|
|
35
|
+
|
|
36
|
+
result = ""
|
|
37
|
+
remaining = index
|
|
38
|
+
|
|
39
|
+
# Convert to bijective base-26 (like Excel columns)
|
|
40
|
+
while True:
|
|
41
|
+
result = _CHAIN_ALPHABET[remaining % 26] + result
|
|
42
|
+
remaining = remaining // 26 - 1
|
|
43
|
+
if remaining < 0:
|
|
44
|
+
break
|
|
45
|
+
|
|
46
|
+
return result
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def chain_id_to_index(chain_id: str) -> int:
|
|
50
|
+
"""
|
|
51
|
+
Convert a chain ID back to a zero-based index.
|
|
52
|
+
|
|
53
|
+
Inverse of index_to_chain_id.
|
|
54
|
+
|
|
55
|
+
Arguments:
|
|
56
|
+
chain_id: string chain identifier (e.g., 'A', 'Z', 'AA', 'AB')
|
|
57
|
+
Returns:
|
|
58
|
+
index: zero-based index
|
|
59
|
+
"""
|
|
60
|
+
if not chain_id or not all(c in _CHAIN_ALPHABET for c in chain_id):
|
|
61
|
+
raise ValueError(f"Invalid chain ID: {chain_id}")
|
|
62
|
+
|
|
63
|
+
# Offset for all shorter chain IDs (26 + 26^2 + ... + 26^(len-1))
|
|
64
|
+
offset = sum(26**k for k in range(1, len(chain_id)))
|
|
65
|
+
|
|
66
|
+
# Value within the current length group (standard base-26)
|
|
67
|
+
value = 0
|
|
68
|
+
for char in chain_id:
|
|
69
|
+
value = value * 26 + _CHAIN_ALPHABET.index(char)
|
|
70
|
+
|
|
71
|
+
return offset + value
|
|
12
72
|
|
|
13
73
|
|
|
14
74
|
########################################################
|
|
@@ -28,7 +88,7 @@ def add_sym_annotations(atom_array, sym_conf):
|
|
|
28
88
|
is_asu = np.full(n, True, dtype=np.bool_)
|
|
29
89
|
atom_array.set_annotation("is_sym_asu", is_asu)
|
|
30
90
|
# symmetry_id
|
|
31
|
-
symmetry_ids = np.full(n, sym_conf.
|
|
91
|
+
symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
|
|
32
92
|
atom_array.set_annotation("symmetry_id", symmetry_ids)
|
|
33
93
|
return atom_array
|
|
34
94
|
|
|
@@ -251,11 +311,13 @@ def reset_chain_ids(atom_array, start_id):
|
|
|
251
311
|
Reset the chain ids and pn_unit_iids of an atom array to start from the given id.
|
|
252
312
|
Arguments:
|
|
253
313
|
atom_array: atom array with chain_ids and pn_unit_iids annotated
|
|
314
|
+
start_id: starting chain ID (e.g., 'A')
|
|
254
315
|
"""
|
|
255
316
|
chain_ids = np.unique(atom_array.chain_id)
|
|
256
|
-
|
|
257
|
-
for
|
|
258
|
-
|
|
317
|
+
start_index = chain_id_to_index(start_id)
|
|
318
|
+
for i, old_id in enumerate(chain_ids):
|
|
319
|
+
new_id = index_to_chain_id(start_index + i)
|
|
320
|
+
atom_array.chain_id[atom_array.chain_id == old_id] = new_id
|
|
259
321
|
atom_array.pn_unit_iid = atom_array.chain_id
|
|
260
322
|
return atom_array
|
|
261
323
|
|
|
@@ -263,15 +325,18 @@ def reset_chain_ids(atom_array, start_id):
|
|
|
263
325
|
def reannotate_chain_ids(atom_array, offset, multiplier=0):
|
|
264
326
|
"""
|
|
265
327
|
Reannotate the chain ids and pn_unit_iids of an atom array.
|
|
328
|
+
|
|
329
|
+
Uses wwPDB-style chain IDs (A-Z, AA-ZZ, AAA-ZZZ, ...) to support
|
|
330
|
+
any number of chains.
|
|
331
|
+
|
|
266
332
|
Arguments:
|
|
267
333
|
atom_array: protein atom array with chain_ids and pn_unit_iids annotated
|
|
268
|
-
offset: offset to add to the chain ids
|
|
269
|
-
multiplier: multiplier
|
|
334
|
+
offset: offset to add to the chain ids (typically num_chains in ASU)
|
|
335
|
+
multiplier: multiplier for the offset (typically transform index)
|
|
270
336
|
"""
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
)
|
|
274
|
-
chain_ids = np.array([chr(id) for id in chain_ids_int], dtype=str)
|
|
337
|
+
chain_ids_indices = np.array([chain_id_to_index(c) for c in atom_array.chain_id])
|
|
338
|
+
new_indices = chain_ids_indices + offset * multiplier
|
|
339
|
+
chain_ids = np.array([index_to_chain_id(idx) for idx in new_indices], dtype="U4")
|
|
275
340
|
atom_array.chain_id = chain_ids
|
|
276
341
|
atom_array.pn_unit_iid = chain_ids
|
|
277
342
|
return atom_array
|
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
-
from rfd3.inference.symmetry.contigs import
|
|
2
|
+
from rfd3.inference.symmetry.contigs import (
|
|
3
|
+
expand_contig_unsym_motif,
|
|
4
|
+
get_unsym_motif_mask,
|
|
5
|
+
)
|
|
3
6
|
from rfd3.transforms.conditioning_base import get_motif_features
|
|
4
7
|
|
|
5
8
|
from foundry.utils.ddp import RankedLogger
|
|
6
9
|
|
|
7
|
-
MIN_ATOMS_ALIGN =
|
|
10
|
+
MIN_ATOMS_ALIGN = 30
|
|
8
11
|
MAX_TRANSFORMS = 10
|
|
9
12
|
RMSD_CUT = 1.0 # Angstroms
|
|
10
13
|
|
|
@@ -18,32 +21,44 @@ def check_symmetry_config(
|
|
|
18
21
|
Check if the symmetry configuration is valid. Add all basic checks here.
|
|
19
22
|
"""
|
|
20
23
|
|
|
21
|
-
assert sym_conf.
|
|
24
|
+
assert sym_conf.id, "symmetry_id is required. e.g. {'id': 'C2'}"
|
|
22
25
|
# if unsym motif is provided, check that each motif name is in the atom array
|
|
23
|
-
|
|
26
|
+
|
|
27
|
+
is_motif_atom = get_motif_features(atom_array)["is_motif_atom"]
|
|
28
|
+
is_unsym_motif = np.zeros(atom_array.shape[0], dtype=bool)
|
|
29
|
+
|
|
30
|
+
if not is_motif_atom.any():
|
|
31
|
+
sym_conf.is_symmetric_motif = None
|
|
32
|
+
ranked_logger.warning(
|
|
33
|
+
"No motifs found in atom array. Setting is_symmetric_motif to None."
|
|
34
|
+
)
|
|
35
|
+
return sym_conf
|
|
36
|
+
|
|
37
|
+
if sym_conf.is_unsym_motif:
|
|
24
38
|
assert (
|
|
25
39
|
src_atom_array is not None
|
|
26
40
|
), "Source atom array must be provided for symmetric motifs"
|
|
27
|
-
unsym_motif_names = sym_conf
|
|
41
|
+
unsym_motif_names = sym_conf.is_unsym_motif.split(",")
|
|
28
42
|
unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
|
|
43
|
+
is_unsym_motif = get_unsym_motif_mask(atom_array, unsym_motif_names)
|
|
29
44
|
for n in unsym_motif_names:
|
|
30
45
|
if (sm and n not in sm.split(",")) and (n not in atom_array.src_component):
|
|
31
46
|
raise ValueError(f"Unsym motif {n} not found in atom_array")
|
|
47
|
+
|
|
32
48
|
if (
|
|
33
|
-
|
|
34
|
-
and not sym_conf.
|
|
49
|
+
is_motif_atom[~is_unsym_motif].any()
|
|
50
|
+
and not sym_conf.is_symmetric_motif
|
|
35
51
|
and not has_dist_cond
|
|
36
52
|
):
|
|
37
53
|
raise ValueError(
|
|
38
|
-
"Asymmetric motif inputs
|
|
39
|
-
"Use atomwise_fixed_dist to constrain the distance between the motif atoms."
|
|
54
|
+
"Asymmetric motif inputs are not supported yet. Please provide a symmetric motif."
|
|
40
55
|
)
|
|
41
|
-
# else: if unconditional symmetry, no need to have symmetric input motif
|
|
42
56
|
|
|
43
|
-
if partial and not sym_conf.
|
|
57
|
+
if partial and not sym_conf.is_symmetric_motif:
|
|
44
58
|
raise ValueError(
|
|
45
59
|
"Partial diffusion with symmetry is only supported for symmetric inputs."
|
|
46
60
|
)
|
|
61
|
+
return sym_conf
|
|
47
62
|
|
|
48
63
|
|
|
49
64
|
def check_atom_array_is_symmetric(atom_array):
|
|
@@ -54,9 +69,6 @@ def check_atom_array_is_symmetric(atom_array):
|
|
|
54
69
|
Returns:
|
|
55
70
|
bool: True if the atom array is symmetric, False otherwise
|
|
56
71
|
"""
|
|
57
|
-
# TODO: Implement something like this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L303
|
|
58
|
-
# and maybe this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L231
|
|
59
|
-
|
|
60
72
|
import biotite.structure as struc
|
|
61
73
|
from rfd3.inference.symmetry.atom_array import (
|
|
62
74
|
apply_symmetry_to_atomarray_coord,
|
|
@@ -68,8 +80,10 @@ def check_atom_array_is_symmetric(atom_array):
|
|
|
68
80
|
# remove hetero atoms
|
|
69
81
|
atom_array = atom_array[~atom_array.hetero]
|
|
70
82
|
if len(atom_array) == 0:
|
|
71
|
-
ranked_logger.
|
|
72
|
-
|
|
83
|
+
ranked_logger.warning(
|
|
84
|
+
"Atom array has no protein chains. Please check your input."
|
|
85
|
+
)
|
|
86
|
+
return True
|
|
73
87
|
|
|
74
88
|
chains = np.unique(atom_array.chain_id)
|
|
75
89
|
asu_mask = atom_array.chain_id == chains[0]
|
|
@@ -162,16 +176,22 @@ def find_optimal_rotation(coords1, coords2, max_points=1000):
|
|
|
162
176
|
return None
|
|
163
177
|
|
|
164
178
|
|
|
165
|
-
def check_input_frames_match_symmetry_frames(
|
|
179
|
+
def check_input_frames_match_symmetry_frames(
|
|
180
|
+
computed_frames, original_frames, nids_by_entity
|
|
181
|
+
) -> None:
|
|
166
182
|
"""
|
|
167
183
|
Check that the number of computed frames matches the number of original frames.
|
|
168
184
|
Arguments:
|
|
169
185
|
computed_frames: list of computed frames
|
|
170
186
|
original_frames: list of original frames
|
|
171
187
|
"""
|
|
172
|
-
assert len(computed_frames) == len(
|
|
173
|
-
|
|
174
|
-
|
|
188
|
+
assert len(computed_frames) == len(original_frames), (
|
|
189
|
+
"Number of computed frames does not match number of original frames.\n"
|
|
190
|
+
f"Computed Frames: {len(computed_frames)}. Original Frames: {len(original_frames)}.\n"
|
|
191
|
+
"If the computed frames are not as expected, please check if you have one-to-one mapping "
|
|
192
|
+
"(size, sequence, folding) of an entity across all chains.\n"
|
|
193
|
+
f"Computed Entity Mapping: {nids_by_entity}."
|
|
194
|
+
)
|
|
175
195
|
|
|
176
196
|
|
|
177
197
|
def check_valid_multiplicity(nids_by_entity) -> None:
|
|
@@ -184,25 +204,35 @@ def check_valid_multiplicity(nids_by_entity) -> None:
|
|
|
184
204
|
multiplicity = min([len(i) for i in nids_by_entity.values()])
|
|
185
205
|
if multiplicity == 1: # no possible symmetry
|
|
186
206
|
raise ValueError(
|
|
187
|
-
"Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead
|
|
207
|
+
"Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead.\n"
|
|
208
|
+
"Multiplicity: 1"
|
|
188
209
|
)
|
|
189
210
|
|
|
190
211
|
# Check that the input is not asymmetric
|
|
191
212
|
multiplicity_good = [len(i) % multiplicity == 0 for i in nids_by_entity.values()]
|
|
192
213
|
if not all(multiplicity_good):
|
|
193
|
-
raise ValueError(
|
|
214
|
+
raise ValueError(
|
|
215
|
+
"Expected multiplicity does not match for some entities.\n"
|
|
216
|
+
"Please modify your input to have one-to-one mapping (size, sequence, folding) of an entity across all chains.\n"
|
|
217
|
+
f"Expected Multiplicity: {multiplicity}.\n"
|
|
218
|
+
f"Computed Entity Mapping: {nids_by_entity}."
|
|
219
|
+
)
|
|
194
220
|
|
|
195
221
|
|
|
196
222
|
def check_valid_subunit_size(nids_by_entity, pn_unit_id) -> None:
|
|
197
223
|
"""
|
|
198
224
|
Check that the subunits in the input are of the same size.
|
|
199
225
|
Arguments:
|
|
200
|
-
nids_by_entity: dict mapping entity to ids
|
|
226
|
+
nids_by_entity: dict mapping entity to ids. e.g. {0: (['A_1', 'B_1', 'C_1']), 1: (['A_2', 'B_2', 'C_2'])}
|
|
227
|
+
pn_unit_id: array of ids. e.g. ['A_1', 'B_1', 'C_1', 'A_2', 'B_2', 'C_2']
|
|
201
228
|
"""
|
|
202
|
-
for
|
|
203
|
-
for
|
|
204
|
-
if (pn_unit_id == js[0]).sum() != (pn_unit_id ==
|
|
205
|
-
raise ValueError(
|
|
229
|
+
for js in nids_by_entity.values():
|
|
230
|
+
for js_i in js[1:]:
|
|
231
|
+
if (pn_unit_id == js[0]).sum() != (pn_unit_id == js_i).sum():
|
|
232
|
+
raise ValueError(
|
|
233
|
+
f"Size mismatch between chain {js[0]} ({(pn_unit_id == js[0]).sum()} atoms) "
|
|
234
|
+
f"and chain {js_i} ({(pn_unit_id == js_i).sum()} atoms). Please check your input file."
|
|
235
|
+
)
|
|
206
236
|
|
|
207
237
|
|
|
208
238
|
def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
|
|
@@ -212,7 +242,10 @@ def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
|
|
|
212
242
|
nids_by_entity: dict mapping entity to ids
|
|
213
243
|
"""
|
|
214
244
|
if natm_per_unique[reference_entity] < MIN_ATOMS_ALIGN:
|
|
215
|
-
raise ValueError(
|
|
245
|
+
raise ValueError(
|
|
246
|
+
f"Not enough atoms to align < {MIN_ATOMS_ALIGN} atoms."
|
|
247
|
+
f"Please provide a input with at least {MIN_ATOMS_ALIGN} atoms."
|
|
248
|
+
)
|
|
216
249
|
|
|
217
250
|
|
|
218
251
|
def check_max_transforms(chains_to_consider) -> None:
|
|
@@ -224,7 +257,7 @@ def check_max_transforms(chains_to_consider) -> None:
|
|
|
224
257
|
"""
|
|
225
258
|
if len(chains_to_consider) > MAX_TRANSFORMS:
|
|
226
259
|
raise ValueError(
|
|
227
|
-
"Number of transforms exceeds the max number of transforms (
|
|
260
|
+
f"Number of transforms exceeds the max number of transforms ({MAX_TRANSFORMS})."
|
|
228
261
|
)
|
|
229
262
|
|
|
230
263
|
|