rc-foundry 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/inference_engines/checkpoint_registry.py +58 -11
- foundry/utils/alignment.py +10 -2
- foundry/version.py +2 -2
- foundry_cli/download_checkpoints.py +66 -66
- {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/METADATA +25 -20
- rc_foundry-0.1.7.dist-info/RECORD +311 -0
- rf3/configs/callbacks/default.yaml +5 -0
- rf3/configs/callbacks/dump_validation_structures.yaml +6 -0
- rf3/configs/callbacks/metrics_logging.yaml +10 -0
- rf3/configs/callbacks/train_logging.yaml +16 -0
- rf3/configs/dataloader/default.yaml +15 -0
- rf3/configs/datasets/base.yaml +31 -0
- rf3/configs/datasets/pdb_and_distillation.yaml +58 -0
- rf3/configs/datasets/pdb_only.yaml +17 -0
- rf3/configs/datasets/train/disorder_distillation.yaml +48 -0
- rf3/configs/datasets/train/domain_distillation.yaml +50 -0
- rf3/configs/datasets/train/monomer_distillation.yaml +49 -0
- rf3/configs/datasets/train/na_complex_distillation.yaml +50 -0
- rf3/configs/datasets/train/pdb/af3_weighted_sampling.yaml +8 -0
- rf3/configs/datasets/train/pdb/base.yaml +32 -0
- rf3/configs/datasets/train/pdb/plinder.yaml +54 -0
- rf3/configs/datasets/train/pdb/train_interface.yaml +51 -0
- rf3/configs/datasets/train/pdb/train_pn_unit.yaml +46 -0
- rf3/configs/datasets/train/rna_monomer_distillation.yaml +56 -0
- rf3/configs/datasets/val/af3_ab_set.yaml +11 -0
- rf3/configs/datasets/val/af3_validation.yaml +11 -0
- rf3/configs/datasets/val/base.yaml +32 -0
- rf3/configs/datasets/val/runs_and_poses.yaml +12 -0
- rf3/configs/debug/default.yaml +66 -0
- rf3/configs/debug/train_specific_examples.yaml +21 -0
- rf3/configs/experiment/pretrained/rf3.yaml +50 -0
- rf3/configs/experiment/pretrained/rf3_with_confidence.yaml +13 -0
- rf3/configs/experiment/quick-rf3-with-confidence.yaml +15 -0
- rf3/configs/experiment/quick-rf3.yaml +61 -0
- rf3/configs/hydra/default.yaml +18 -0
- rf3/configs/hydra/no_logging.yaml +7 -0
- rf3/configs/inference.yaml +7 -0
- rf3/configs/inference_engine/base.yaml +23 -0
- rf3/configs/inference_engine/rf3.yaml +33 -0
- rf3/configs/logger/csv.yaml +6 -0
- rf3/configs/logger/default.yaml +3 -0
- rf3/configs/logger/wandb.yaml +15 -0
- rf3/configs/model/components/ema.yaml +1 -0
- rf3/configs/model/components/rf3_net.yaml +177 -0
- rf3/configs/model/components/rf3_net_with_confidence_head.yaml +45 -0
- rf3/configs/model/optimizers/adam.yaml +5 -0
- rf3/configs/model/rf3.yaml +43 -0
- rf3/configs/model/rf3_with_confidence.yaml +7 -0
- rf3/configs/model/schedulers/af3.yaml +6 -0
- rf3/configs/paths/data/default.yaml +43 -0
- rf3/configs/paths/default.yaml +21 -0
- rf3/configs/train.yaml +42 -0
- rf3/configs/trainer/cpu.yaml +6 -0
- rf3/configs/trainer/ddp.yaml +5 -0
- rf3/configs/trainer/loss/losses/confidence_loss.yaml +29 -0
- rf3/configs/trainer/loss/losses/diffusion_loss.yaml +9 -0
- rf3/configs/trainer/loss/losses/distogram_loss.yaml +2 -0
- rf3/configs/trainer/loss/structure_prediction.yaml +4 -0
- rf3/configs/trainer/loss/structure_prediction_with_confidence.yaml +2 -0
- rf3/configs/trainer/metrics/structure_prediction.yaml +14 -0
- rf3/configs/trainer/rf3.yaml +20 -0
- rf3/configs/trainer/rf3_with_confidence.yaml +13 -0
- rf3/configs/validate.yaml +45 -0
- rfd3/cli.py +10 -4
- rfd3/configs/__init__.py +0 -0
- rfd3/configs/callbacks/design_callbacks.yaml +10 -0
- rfd3/configs/callbacks/metrics_logging.yaml +20 -0
- rfd3/configs/callbacks/train_logging.yaml +24 -0
- rfd3/configs/dataloader/default.yaml +15 -0
- rfd3/configs/dataloader/fast.yaml +11 -0
- rfd3/configs/datasets/conditions/dna_condition.yaml +3 -0
- rfd3/configs/datasets/conditions/island.yaml +28 -0
- rfd3/configs/datasets/conditions/ppi.yaml +2 -0
- rfd3/configs/datasets/conditions/sequence_design.yaml +17 -0
- rfd3/configs/datasets/conditions/tipatom.yaml +28 -0
- rfd3/configs/datasets/conditions/unconditional.yaml +21 -0
- rfd3/configs/datasets/design_base.yaml +97 -0
- rfd3/configs/datasets/train/pdb/af3_train_interface.yaml +46 -0
- rfd3/configs/datasets/train/pdb/af3_train_pn_unit.yaml +42 -0
- rfd3/configs/datasets/train/pdb/base.yaml +14 -0
- rfd3/configs/datasets/train/pdb/base_no_weights.yaml +19 -0
- rfd3/configs/datasets/train/pdb/base_transform_args.yaml +59 -0
- rfd3/configs/datasets/train/pdb/na_complex_distillation.yaml +20 -0
- rfd3/configs/datasets/train/pdb/pdb_base.yaml +11 -0
- rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml +22 -0
- rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml +23 -0
- rfd3/configs/datasets/train/rfd3_monomer_distillation.yaml +38 -0
- rfd3/configs/datasets/val/bcov_ppi_easy_medium.yaml +9 -0
- rfd3/configs/datasets/val/design_validation_base.yaml +40 -0
- rfd3/configs/datasets/val/dna_binder_design5.yaml +9 -0
- rfd3/configs/datasets/val/dna_binder_long.yaml +13 -0
- rfd3/configs/datasets/val/dna_binder_short.yaml +13 -0
- rfd3/configs/datasets/val/indexed.yaml +9 -0
- rfd3/configs/datasets/val/mcsa_41.yaml +9 -0
- rfd3/configs/datasets/val/mcsa_41_short_rigid.yaml +10 -0
- rfd3/configs/datasets/val/ppi_inference.yaml +7 -0
- rfd3/configs/datasets/val/sm_binder_hbonds.yaml +13 -0
- rfd3/configs/datasets/val/sm_binder_hbonds_short.yaml +15 -0
- rfd3/configs/datasets/val/unconditional.yaml +9 -0
- rfd3/configs/datasets/val/unconditional_deep.yaml +9 -0
- rfd3/configs/datasets/val/unindexed.yaml +8 -0
- rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori.yaml +151 -0
- rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori_spoof_helical_bundle.yaml +7 -0
- rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori_varying_lengths.yaml +28 -0
- rfd3/configs/datasets/val/val_examples/bpem_ori_hb.yaml +212 -0
- rfd3/configs/debug/default.yaml +64 -0
- rfd3/configs/debug/train_specific_examples.yaml +21 -0
- rfd3/configs/dev.yaml +9 -0
- rfd3/configs/experiment/debug.yaml +14 -0
- rfd3/configs/experiment/pretrain.yaml +31 -0
- rfd3/configs/experiment/test-uncond.yaml +10 -0
- rfd3/configs/experiment/test-unindexed.yaml +21 -0
- rfd3/configs/hydra/default.yaml +18 -0
- rfd3/configs/hydra/no_logging.yaml +7 -0
- rfd3/configs/inference.yaml +9 -0
- rfd3/configs/inference_engine/base.yaml +15 -0
- rfd3/configs/inference_engine/dev.yaml +20 -0
- rfd3/configs/inference_engine/rfdiffusion3.yaml +65 -0
- rfd3/configs/logger/csv.yaml +6 -0
- rfd3/configs/logger/default.yaml +2 -0
- rfd3/configs/logger/wandb.yaml +15 -0
- rfd3/configs/model/components/ema.yaml +1 -0
- rfd3/configs/model/components/rfd3_net.yaml +131 -0
- rfd3/configs/model/optimizers/adam.yaml +5 -0
- rfd3/configs/model/rfd3_base.yaml +8 -0
- rfd3/configs/model/samplers/edm.yaml +21 -0
- rfd3/configs/model/samplers/symmetry.yaml +10 -0
- rfd3/configs/model/schedulers/af3.yaml +6 -0
- rfd3/configs/paths/data/default.yaml +18 -0
- rfd3/configs/paths/default.yaml +22 -0
- rfd3/configs/train.yaml +28 -0
- rfd3/configs/trainer/cpu.yaml +6 -0
- rfd3/configs/trainer/ddp.yaml +5 -0
- rfd3/configs/trainer/loss/losses/diffusion_loss.yaml +12 -0
- rfd3/configs/trainer/loss/losses/sequence_loss.yaml +3 -0
- rfd3/configs/trainer/metrics/design_metrics.yaml +22 -0
- rfd3/configs/trainer/rfd3_base.yaml +35 -0
- rfd3/configs/validate.yaml +34 -0
- rfd3/engine.py +19 -11
- rfd3/inference/input_parsing.py +1 -1
- rfd3/inference/legacy_input_parsing.py +17 -1
- rfd3/inference/parsing.py +1 -0
- rfd3/inference/symmetry/atom_array.py +1 -5
- rfd3/inference/symmetry/checks.py +53 -28
- rfd3/inference/symmetry/frames.py +8 -5
- rfd3/inference/symmetry/symmetry_utils.py +38 -60
- rfd3/run_inference.py +3 -1
- rfd3/utils/inference.py +23 -0
- rc_foundry-0.1.5.dist-info/RECORD +0 -180
- {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -139,13 +139,18 @@ def fetch_motif_residue_(
|
|
|
139
139
|
subarray, motif=True, unindexed=False, dtype=int
|
|
140
140
|
) # all values init to True (fix all)
|
|
141
141
|
|
|
142
|
+
to_unindex = f"{src_chain}{src_resid}" in unindexed_components
|
|
143
|
+
to_index = f"{src_chain}{src_resid}" in components
|
|
144
|
+
|
|
142
145
|
# Assign is motif atom and sequence
|
|
143
146
|
if exists(atoms := fixed_atoms.get(f"{src_chain}{src_resid}")):
|
|
147
|
+
# If specified, we set fixed atoms in the residue to be motif atoms
|
|
144
148
|
atom_mask = get_name_mask(subarray.atom_name, atoms, res_name)
|
|
145
149
|
subarray.set_annotation("is_motif_atom", atom_mask)
|
|
146
150
|
# subarray.set_annotation("is_motif_atom_with_fixed_coord", atom_mask) # BUGFIX: uncomment
|
|
147
151
|
|
|
148
152
|
elif redesign_motif_sidechains and res_name in STANDARD_AA:
|
|
153
|
+
# If redesign_motif_sidechains is True, we only make the backbone atoms to be motif atoms
|
|
149
154
|
n_atoms = subarray.shape[0]
|
|
150
155
|
diffuse_oxygen = False
|
|
151
156
|
if n_atoms < 3:
|
|
@@ -178,6 +183,18 @@ def fetch_motif_residue_(
|
|
|
178
183
|
subarray.set_annotation(
|
|
179
184
|
"is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
|
|
180
185
|
)
|
|
186
|
+
elif to_index or to_unindex:
|
|
187
|
+
# If the residue is in the contig or unindexed components,
|
|
188
|
+
# we set all atoms in the residue to be motif atoms
|
|
189
|
+
subarray.set_annotation("is_motif_atom", np.ones(subarray.shape[0], dtype=int))
|
|
190
|
+
else:
|
|
191
|
+
if to_unindex and not (
|
|
192
|
+
unfix_all or f"{src_chain}{src_resid}" in unfix_residues
|
|
193
|
+
):
|
|
194
|
+
raise ValueError(
|
|
195
|
+
f"{src_chain}{src_resid} is not found in fixed_atoms, contig or unindex contig."
|
|
196
|
+
"Please check your input and contig specification."
|
|
197
|
+
)
|
|
181
198
|
if unfix_all or f"{src_chain}{src_resid}" in unfix_residues:
|
|
182
199
|
subarray.set_annotation(
|
|
183
200
|
"is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
|
|
@@ -197,7 +214,6 @@ def fetch_motif_residue_(
|
|
|
197
214
|
subarray.set_annotation(
|
|
198
215
|
"is_flexible_motif_atom", np.zeros(subarray.shape[0], dtype=bool)
|
|
199
216
|
)
|
|
200
|
-
to_unindex = f"{src_chain}{src_resid}" in unindexed_components
|
|
201
217
|
if to_unindex:
|
|
202
218
|
subarray.set_annotation(
|
|
203
219
|
"is_motif_atom_unindexed", subarray.is_motif_atom.copy()
|
rfd3/inference/parsing.py
CHANGED
|
@@ -117,6 +117,7 @@ def from_any_(v: Any, atom_array: AtomArray):
|
|
|
117
117
|
|
|
118
118
|
# Split to atom names
|
|
119
119
|
data_split[idx] = token.atom_name[comp_mask_subset].tolist()
|
|
120
|
+
# TODO: there is a bug where when you select specifc atoms within a ligand, output ligand is fragmented
|
|
120
121
|
|
|
121
122
|
# Update mask & token dictionary
|
|
122
123
|
mask[comp_mask] = comp_mask_subset
|
|
@@ -4,12 +4,8 @@ from rfd3.inference.symmetry.frames import (
|
|
|
4
4
|
get_symmetry_frames_from_symmetry_id,
|
|
5
5
|
)
|
|
6
6
|
|
|
7
|
-
from foundry.utils.ddp import RankedLogger
|
|
8
|
-
|
|
9
7
|
FIXED_TRANSFORM_ID = -1
|
|
10
8
|
FIXED_ENTITY_ID = -1
|
|
11
|
-
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
12
|
-
|
|
13
9
|
|
|
14
10
|
########################################################
|
|
15
11
|
# Symmetry annotations
|
|
@@ -28,7 +24,7 @@ def add_sym_annotations(atom_array, sym_conf):
|
|
|
28
24
|
is_asu = np.full(n, True, dtype=np.bool_)
|
|
29
25
|
atom_array.set_annotation("is_sym_asu", is_asu)
|
|
30
26
|
# symmetry_id
|
|
31
|
-
symmetry_ids = np.full(n, sym_conf.
|
|
27
|
+
symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
|
|
32
28
|
atom_array.set_annotation("symmetry_id", symmetry_ids)
|
|
33
29
|
return atom_array
|
|
34
30
|
|
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
-
from rfd3.inference.symmetry.contigs import
|
|
2
|
+
from rfd3.inference.symmetry.contigs import (
|
|
3
|
+
expand_contig_unsym_motif,
|
|
4
|
+
get_unsym_motif_mask,
|
|
5
|
+
)
|
|
3
6
|
from rfd3.transforms.conditioning_base import get_motif_features
|
|
4
7
|
|
|
5
8
|
from foundry.utils.ddp import RankedLogger
|
|
6
9
|
|
|
7
|
-
MIN_ATOMS_ALIGN =
|
|
10
|
+
MIN_ATOMS_ALIGN = 30
|
|
8
11
|
MAX_TRANSFORMS = 10
|
|
9
12
|
RMSD_CUT = 1.0 # Angstroms
|
|
10
13
|
|
|
@@ -18,29 +21,33 @@ def check_symmetry_config(
|
|
|
18
21
|
Check if the symmetry configuration is valid. Add all basic checks here.
|
|
19
22
|
"""
|
|
20
23
|
|
|
21
|
-
assert sym_conf.
|
|
24
|
+
assert sym_conf.id, "symmetry_id is required. e.g. {'id': 'C2'}"
|
|
22
25
|
# if unsym motif is provided, check that each motif name is in the atom array
|
|
23
|
-
|
|
26
|
+
|
|
27
|
+
is_unsym_motif = np.zeros(atom_array.shape[0], dtype=bool)
|
|
28
|
+
if sym_conf.is_unsym_motif:
|
|
24
29
|
assert (
|
|
25
30
|
src_atom_array is not None
|
|
26
31
|
), "Source atom array must be provided for symmetric motifs"
|
|
27
|
-
unsym_motif_names = sym_conf
|
|
32
|
+
unsym_motif_names = sym_conf.is_unsym_motif.split(",")
|
|
28
33
|
unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
|
|
34
|
+
is_unsym_motif = get_unsym_motif_mask(atom_array, unsym_motif_names)
|
|
29
35
|
for n in unsym_motif_names:
|
|
30
36
|
if (sm and n not in sm.split(",")) and (n not in atom_array.src_component):
|
|
31
37
|
raise ValueError(f"Unsym motif {n} not found in atom_array")
|
|
38
|
+
|
|
39
|
+
is_motif_token = get_motif_features(atom_array)["is_motif_token"]
|
|
32
40
|
if (
|
|
33
|
-
|
|
34
|
-
and not sym_conf.
|
|
41
|
+
is_motif_token[~is_unsym_motif].any()
|
|
42
|
+
and not sym_conf.is_symmetric_motif
|
|
35
43
|
and not has_dist_cond
|
|
36
44
|
):
|
|
37
45
|
raise ValueError(
|
|
38
|
-
"Asymmetric motif inputs should be distance constrained.
|
|
46
|
+
"Asymmetric motif inputs should be distance constrained."
|
|
39
47
|
"Use atomwise_fixed_dist to constrain the distance between the motif atoms."
|
|
40
48
|
)
|
|
41
|
-
# else: if unconditional symmetry, no need to have symmetric input motif
|
|
42
49
|
|
|
43
|
-
if partial and not sym_conf.
|
|
50
|
+
if partial and not sym_conf.is_symmetric_motif:
|
|
44
51
|
raise ValueError(
|
|
45
52
|
"Partial diffusion with symmetry is only supported for symmetric inputs."
|
|
46
53
|
)
|
|
@@ -54,9 +61,6 @@ def check_atom_array_is_symmetric(atom_array):
|
|
|
54
61
|
Returns:
|
|
55
62
|
bool: True if the atom array is symmetric, False otherwise
|
|
56
63
|
"""
|
|
57
|
-
# TODO: Implement something like this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L303
|
|
58
|
-
# and maybe this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L231
|
|
59
|
-
|
|
60
64
|
import biotite.structure as struc
|
|
61
65
|
from rfd3.inference.symmetry.atom_array import (
|
|
62
66
|
apply_symmetry_to_atomarray_coord,
|
|
@@ -68,8 +72,10 @@ def check_atom_array_is_symmetric(atom_array):
|
|
|
68
72
|
# remove hetero atoms
|
|
69
73
|
atom_array = atom_array[~atom_array.hetero]
|
|
70
74
|
if len(atom_array) == 0:
|
|
71
|
-
ranked_logger.
|
|
72
|
-
|
|
75
|
+
ranked_logger.warning(
|
|
76
|
+
"Atom array has no protein chains. Please check your input."
|
|
77
|
+
)
|
|
78
|
+
return True
|
|
73
79
|
|
|
74
80
|
chains = np.unique(atom_array.chain_id)
|
|
75
81
|
asu_mask = atom_array.chain_id == chains[0]
|
|
@@ -162,16 +168,22 @@ def find_optimal_rotation(coords1, coords2, max_points=1000):
|
|
|
162
168
|
return None
|
|
163
169
|
|
|
164
170
|
|
|
165
|
-
def check_input_frames_match_symmetry_frames(
|
|
171
|
+
def check_input_frames_match_symmetry_frames(
|
|
172
|
+
computed_frames, original_frames, nids_by_entity
|
|
173
|
+
) -> None:
|
|
166
174
|
"""
|
|
167
175
|
Check if the atom array matches the symmetry_id.
|
|
168
176
|
Arguments:
|
|
169
177
|
computed_frames: list of computed frames
|
|
170
178
|
original_frames: list of original frames
|
|
171
179
|
"""
|
|
172
|
-
assert len(computed_frames) == len(
|
|
173
|
-
|
|
174
|
-
|
|
180
|
+
assert len(computed_frames) == len(original_frames), (
|
|
181
|
+
"Number of computed frames does not match number of original frames.\n"
|
|
182
|
+
f"Computed Frames: {len(computed_frames)}. Original Frames: {len(original_frames)}.\n"
|
|
183
|
+
"If the computed frames are not as expected, please check if you have one-to-one mapping "
|
|
184
|
+
"(size, sequence, folding) of an entity across all chains.\n"
|
|
185
|
+
f"Computed Entity Mapping: {nids_by_entity}."
|
|
186
|
+
)
|
|
175
187
|
|
|
176
188
|
|
|
177
189
|
def check_valid_multiplicity(nids_by_entity) -> None:
|
|
@@ -184,25 +196,35 @@ def check_valid_multiplicity(nids_by_entity) -> None:
|
|
|
184
196
|
multiplicity = min([len(i) for i in nids_by_entity.values()])
|
|
185
197
|
if multiplicity == 1: # no possible symmetry
|
|
186
198
|
raise ValueError(
|
|
187
|
-
"Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead
|
|
199
|
+
"Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead.\n"
|
|
200
|
+
"Multiplicity: 1"
|
|
188
201
|
)
|
|
189
202
|
|
|
190
203
|
# Check that the input is not asymmetric
|
|
191
204
|
multiplicity_good = [len(i) % multiplicity == 0 for i in nids_by_entity.values()]
|
|
192
205
|
if not all(multiplicity_good):
|
|
193
|
-
raise ValueError(
|
|
206
|
+
raise ValueError(
|
|
207
|
+
"Expected multiplicity does not match for some entities.\n"
|
|
208
|
+
"Please modify your input to have one-to-one mapping (size, sequence, folding) of an entity across all chains.\n"
|
|
209
|
+
f"Expected Multiplicity: {multiplicity}.\n"
|
|
210
|
+
f"Computed Entity Mapping: {nids_by_entity}."
|
|
211
|
+
)
|
|
194
212
|
|
|
195
213
|
|
|
196
214
|
def check_valid_subunit_size(nids_by_entity, pn_unit_id) -> None:
|
|
197
215
|
"""
|
|
198
216
|
Check that the subunits in the input are of the same size.
|
|
199
217
|
Arguments:
|
|
200
|
-
nids_by_entity: dict mapping entity to ids
|
|
218
|
+
nids_by_entity: dict mapping entity to ids. e.g. {0: (['A_1', 'B_1', 'C_1']), 1: (['A_2', 'B_2', 'C_2'])}
|
|
219
|
+
pn_unit_id: array of ids. e.g. ['A_1', 'B_1', 'C_1', 'A_2', 'B_2', 'C_2']
|
|
201
220
|
"""
|
|
202
|
-
for
|
|
203
|
-
for
|
|
204
|
-
if (pn_unit_id == js[0]).sum() != (pn_unit_id ==
|
|
205
|
-
raise ValueError(
|
|
221
|
+
for js in nids_by_entity.values():
|
|
222
|
+
for js_i in js[1:]:
|
|
223
|
+
if (pn_unit_id == js[0]).sum() != (pn_unit_id == js_i).sum():
|
|
224
|
+
raise ValueError(
|
|
225
|
+
f"Size mismatch between chain {js[0]} ({(pn_unit_id == js[0]).sum()} atoms) "
|
|
226
|
+
f"and chain {js_i} ({(pn_unit_id == js_i).sum()} atoms). Please check your input file."
|
|
227
|
+
)
|
|
206
228
|
|
|
207
229
|
|
|
208
230
|
def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
|
|
@@ -212,7 +234,10 @@ def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
|
|
|
212
234
|
nids_by_entity: dict mapping entity to ids
|
|
213
235
|
"""
|
|
214
236
|
if natm_per_unique[reference_entity] < MIN_ATOMS_ALIGN:
|
|
215
|
-
raise ValueError(
|
|
237
|
+
raise ValueError(
|
|
238
|
+
f"Not enough atoms to align < {MIN_ATOMS_ALIGN} atoms."
|
|
239
|
+
f"Please provide a input with at least {MIN_ATOMS_ALIGN} atoms."
|
|
240
|
+
)
|
|
216
241
|
|
|
217
242
|
|
|
218
243
|
def check_max_transforms(chains_to_consider) -> None:
|
|
@@ -224,7 +249,7 @@ def check_max_transforms(chains_to_consider) -> None:
|
|
|
224
249
|
"""
|
|
225
250
|
if len(chains_to_consider) > MAX_TRANSFORMS:
|
|
226
251
|
raise ValueError(
|
|
227
|
-
"Number of transforms exceeds the max number of transforms (
|
|
252
|
+
f"Number of transforms exceeds the max number of transforms ({MAX_TRANSFORMS})."
|
|
228
253
|
)
|
|
229
254
|
|
|
230
255
|
|
|
@@ -10,12 +10,13 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
|
|
|
10
10
|
Returns:
|
|
11
11
|
frames: list of rotation matrices
|
|
12
12
|
"""
|
|
13
|
+
from rfd3.inference.symmetry.symmetry_utils import SymmetryConfig
|
|
13
14
|
|
|
14
15
|
# Get frames from symmetry id
|
|
15
16
|
sym_conf = {}
|
|
16
|
-
if isinstance(symmetry_id,
|
|
17
|
+
if isinstance(symmetry_id, SymmetryConfig):
|
|
17
18
|
sym_conf = symmetry_id
|
|
18
|
-
symmetry_id = symmetry_id.
|
|
19
|
+
symmetry_id = symmetry_id.id
|
|
19
20
|
|
|
20
21
|
if symmetry_id.lower().startswith("c"):
|
|
21
22
|
order = int(symmetry_id[1:])
|
|
@@ -25,9 +26,9 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
|
|
|
25
26
|
frames = get_dihedral_frames(order)
|
|
26
27
|
elif symmetry_id.lower() == "input_defined":
|
|
27
28
|
assert (
|
|
28
|
-
|
|
29
|
+
sym_conf.symmetry_file is not None
|
|
29
30
|
), "symmetry_file is required for input_defined symmetry"
|
|
30
|
-
frames = get_frames_from_file(sym_conf.
|
|
31
|
+
frames = get_frames_from_file(sym_conf.symmetry_file)
|
|
31
32
|
else:
|
|
32
33
|
raise ValueError(f"Symmetry id {symmetry_id} not supported")
|
|
33
34
|
|
|
@@ -120,7 +121,9 @@ def get_symmetry_frames_from_atom_array(src_atom_array, input_frames):
|
|
|
120
121
|
computed_frames = [(R, np.array([0, 0, 0])) for R in Rs]
|
|
121
122
|
|
|
122
123
|
# check that the computed frames match the input frames
|
|
123
|
-
check_input_frames_match_symmetry_frames(
|
|
124
|
+
check_input_frames_match_symmetry_frames(
|
|
125
|
+
computed_frames, input_frames, nids_by_entity
|
|
126
|
+
)
|
|
124
127
|
|
|
125
128
|
return computed_frames
|
|
126
129
|
|
|
@@ -39,18 +39,36 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
class SymmetryConfig(BaseModel):
|
|
42
|
-
# AM / HE TODO: feel free to flesh this out and add validation as needed
|
|
43
42
|
model_config = ConfigDict(
|
|
44
43
|
arbitrary_types_allowed=True,
|
|
45
44
|
extra="allow",
|
|
46
45
|
)
|
|
47
|
-
id: Optional[str] = Field(
|
|
48
|
-
|
|
49
|
-
|
|
46
|
+
id: Optional[str] = Field(
|
|
47
|
+
None,
|
|
48
|
+
description="Symmetry group ID. e.g. 'C3', 'D2'. Only C and D symmetry types are supported currently.",
|
|
49
|
+
)
|
|
50
|
+
is_unsym_motif: Optional[str] = Field(
|
|
51
|
+
None,
|
|
52
|
+
description="Comma separated list of contig/ligand names that should not be symmetrized such as DNA strands. \
|
|
53
|
+
e.g. 'HEM' or 'Y1-11,Z16-25'",
|
|
54
|
+
)
|
|
55
|
+
is_symmetric_motif: bool = Field(
|
|
56
|
+
True,
|
|
57
|
+
description="If True, the input motifs are expected to be already symmetric and won't be symmetrized. \
|
|
58
|
+
If False, the all input motifs are expected to be ASU and will be symmetrized.",
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def convery_sym_conf_to_symmetry_config(sym_conf: dict):
|
|
63
|
+
return SymmetryConfig(**sym_conf)
|
|
50
64
|
|
|
51
65
|
|
|
52
66
|
def make_symmetric_atom_array(
|
|
53
|
-
asu_atom_array,
|
|
67
|
+
asu_atom_array,
|
|
68
|
+
sym_conf: SymmetryConfig | dict,
|
|
69
|
+
sm=None,
|
|
70
|
+
has_dist_cond=False,
|
|
71
|
+
src_atom_array=None,
|
|
54
72
|
):
|
|
55
73
|
"""
|
|
56
74
|
apply symmetry to an atom array.
|
|
@@ -58,39 +76,33 @@ def make_symmetric_atom_array(
|
|
|
58
76
|
asu_atom_array: atom array of the asymmetric unit
|
|
59
77
|
sym_conf: symmetry configuration (dict, "id" key is required)
|
|
60
78
|
sm: optional small molecule names (str, comma separated)
|
|
61
|
-
|
|
79
|
+
has_dist_cond: whether to add 2d entity annotations
|
|
62
80
|
Returns:
|
|
63
81
|
new_asu_atom_array: atom array with symmetry applied
|
|
64
82
|
"""
|
|
65
|
-
|
|
66
|
-
sym_conf
|
|
67
|
-
) # TODO: JB: remove this line to keep as symmetry config for cleaner syntax(?)
|
|
68
|
-
ranked_logger.info(f"Symmetry Configs: {sym_conf}")
|
|
83
|
+
if not isinstance(sym_conf, SymmetryConfig):
|
|
84
|
+
sym_conf = convery_sym_conf_to_symmetry_config(sym_conf)
|
|
69
85
|
|
|
70
|
-
# Making sure that the symmetry config is valid
|
|
71
86
|
check_symmetry_config(
|
|
72
|
-
asu_atom_array,
|
|
73
|
-
sym_conf,
|
|
74
|
-
sm,
|
|
75
|
-
has_dist_cond=has_2d,
|
|
76
|
-
src_atom_array=src_atom_array,
|
|
87
|
+
asu_atom_array, sym_conf, sm, has_dist_cond, src_atom_array=src_atom_array
|
|
77
88
|
)
|
|
78
89
|
# Adding utility annotations to the asu atom array
|
|
79
90
|
asu_atom_array = _add_util_annotations(asu_atom_array, sym_conf, sm)
|
|
80
91
|
|
|
81
|
-
if
|
|
92
|
+
if has_dist_cond: # NB: this will only work for asymmetric motifs at the moment - need to add functionality for symmetric motifs
|
|
82
93
|
asu_atom_array = add_2d_entity_annotations(asu_atom_array)
|
|
83
94
|
|
|
84
95
|
frames = get_symmetry_frames_from_symmetry_id(sym_conf)
|
|
85
96
|
|
|
86
97
|
# If the motif is symmetric, we get the frames instead from the source atom array.
|
|
87
|
-
if sym_conf.
|
|
98
|
+
if sym_conf.is_symmetric_motif:
|
|
88
99
|
assert (
|
|
89
100
|
src_atom_array is not None
|
|
90
101
|
), "Source atom array must be provided for symmetric motifs"
|
|
91
|
-
# if symmetric motif is provided, get the frames from the src atom array
|
|
102
|
+
# if symmetric motif is provided, get the frames from the src atom array.
|
|
92
103
|
frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
|
|
93
|
-
|
|
104
|
+
elif (asu_atom_array._is_motif[~asu_atom_array._is_unsym_motif]).any():
|
|
105
|
+
# if the motifs that's not unsym motifs are present.
|
|
94
106
|
raise NotImplementedError(
|
|
95
107
|
"Asymmetric motif inputs are not implemented yet. please symmetrize the motif."
|
|
96
108
|
)
|
|
@@ -101,7 +113,7 @@ def make_symmetric_atom_array(
|
|
|
101
113
|
# Extracting all things at this moment that we will not want to symmetrize.
|
|
102
114
|
# This includes: 1) unsym motifs, 2) ligands
|
|
103
115
|
unsym_atom_arrays = []
|
|
104
|
-
if sym_conf.
|
|
116
|
+
if sym_conf.is_unsym_motif:
|
|
105
117
|
# unsym_motif_atom_array = get_unsym_motif(asu_atom_array, asu_atom_array._is_unsym_motif)
|
|
106
118
|
# Now remove the unsym motifs from the asu atom array
|
|
107
119
|
unsym_atom_arrays.append(asu_atom_array[asu_atom_array._is_unsym_motif])
|
|
@@ -128,7 +140,7 @@ def make_symmetric_atom_array(
|
|
|
128
140
|
symmetrized_atom_array = struc.concatenate(symmetry_unit_list)
|
|
129
141
|
|
|
130
142
|
# add 2D conditioning annotations
|
|
131
|
-
if
|
|
143
|
+
if has_dist_cond:
|
|
132
144
|
symmetrized_atom_array = reannotate_2d_conditions(symmetrized_atom_array)
|
|
133
145
|
|
|
134
146
|
# set all motifs to not have any symmetrization applied to them
|
|
@@ -183,7 +195,7 @@ def make_symmetric_atom_array_for_partial_diffusion(atom_array, sym_conf):
|
|
|
183
195
|
frames = get_symmetry_frames_from_symmetry_id(sym_conf)
|
|
184
196
|
|
|
185
197
|
# Add symmetry ID
|
|
186
|
-
symmetry_ids = np.full(n, sym_conf.
|
|
198
|
+
symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
|
|
187
199
|
atom_array.set_annotation("symmetry_id", symmetry_ids)
|
|
188
200
|
|
|
189
201
|
# Initialize transform annotations (use same format as original system)
|
|
@@ -244,7 +256,7 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
|
|
|
244
256
|
"""
|
|
245
257
|
n = asu_atom_array.shape[0]
|
|
246
258
|
is_motif = get_motif_features(asu_atom_array)["is_motif_atom"].astype(np.bool_)
|
|
247
|
-
is_sm = np.zeros(
|
|
259
|
+
is_sm = np.zeros(n, dtype=bool)
|
|
248
260
|
is_asu = np.ones(n, dtype=bool)
|
|
249
261
|
is_unsym_motif = np.zeros(n, dtype=bool)
|
|
250
262
|
|
|
@@ -257,8 +269,8 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
|
|
|
257
269
|
)
|
|
258
270
|
|
|
259
271
|
# assign unsym motifs
|
|
260
|
-
if sym_conf.
|
|
261
|
-
unsym_motif_names = sym_conf
|
|
272
|
+
if sym_conf.is_unsym_motif:
|
|
273
|
+
unsym_motif_names = sym_conf.is_unsym_motif.split(",")
|
|
262
274
|
unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
|
|
263
275
|
is_unsym_motif = get_unsym_motif_mask(asu_atom_array, unsym_motif_names)
|
|
264
276
|
|
|
@@ -361,38 +373,4 @@ def apply_symmetry_to_xyz_atomwise(X_L, sym_feats, partial_diffusion=False):
|
|
|
361
373
|
"blc,cd->bld", asu_xyz, sym_transforms[target_id][0].to(asu_xyz.dtype)
|
|
362
374
|
) + sym_transforms[target_id][1].to(asu_xyz.dtype)
|
|
363
375
|
|
|
364
|
-
# Log inter-chain distances for debugging - use actual chain annotations
|
|
365
|
-
if sym_X_L.shape[1] > 100: # Only for large structures
|
|
366
|
-
# Use symmetry entity annotations to find different chains
|
|
367
|
-
sym_entity_id = sym_feats["sym_entity_id"]
|
|
368
|
-
unique_entities = torch.unique(sym_entity_id)
|
|
369
|
-
|
|
370
|
-
if len(unique_entities) >= 2:
|
|
371
|
-
# Get atoms from first two different entities
|
|
372
|
-
entity_0_mask = sym_entity_id == unique_entities[0]
|
|
373
|
-
entity_1_mask = sym_entity_id == unique_entities[1]
|
|
374
|
-
|
|
375
|
-
if entity_0_mask.sum() > 0 and entity_1_mask.sum() > 0:
|
|
376
|
-
entity_0_atoms = sym_X_L[0, entity_0_mask, :]
|
|
377
|
-
entity_1_atoms = sym_X_L[0, entity_1_mask, :]
|
|
378
|
-
|
|
379
|
-
# Sample subset to avoid memory issues
|
|
380
|
-
entity_0_sample = entity_0_atoms[: min(50, entity_0_atoms.shape[0]), :]
|
|
381
|
-
entity_1_sample = entity_1_atoms[: min(50, entity_1_atoms.shape[0]), :]
|
|
382
|
-
|
|
383
|
-
min_distance = (
|
|
384
|
-
torch.cdist(entity_0_sample, entity_1_sample).min().item()
|
|
385
|
-
)
|
|
386
|
-
ranked_logger.info(
|
|
387
|
-
f"Min inter-chain distance after symmetry: {min_distance:.2f} Å"
|
|
388
|
-
)
|
|
389
|
-
|
|
390
|
-
# Also log the centers of each entity
|
|
391
|
-
entity_0_center = entity_0_atoms.mean(dim=0)
|
|
392
|
-
entity_1_center = entity_1_atoms.mean(dim=0)
|
|
393
|
-
center_distance = torch.norm(entity_0_center - entity_1_center).item()
|
|
394
|
-
ranked_logger.info(
|
|
395
|
-
f"Distance between chain centers: {center_distance:.2f} Å"
|
|
396
|
-
)
|
|
397
|
-
|
|
398
376
|
return sym_X_L
|
rfd3/run_inference.py
CHANGED
|
@@ -12,7 +12,9 @@ load_dotenv(override=True)
|
|
|
12
12
|
|
|
13
13
|
# For pip-installed package, configs should be relative to this file
|
|
14
14
|
# Adjust this path based on where configs are bundled in the package
|
|
15
|
-
_config_path = os.path.join(
|
|
15
|
+
_config_path = os.path.join(
|
|
16
|
+
os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
|
|
17
|
+
)
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
@hydra.main(
|
rfd3/utils/inference.py
CHANGED
|
@@ -391,6 +391,29 @@ def ensure_input_is_abspath(args: dict, path: PathLike | None):
|
|
|
391
391
|
return args
|
|
392
392
|
|
|
393
393
|
|
|
394
|
+
def ensure_inference_sampler_matches_design_spec(
|
|
395
|
+
design_spec: dict, inference_sampler: dict | None = None
|
|
396
|
+
):
|
|
397
|
+
"""
|
|
398
|
+
Ensure the inference sampler is set to the correct sampler for the design specification.
|
|
399
|
+
Args:
|
|
400
|
+
design_spec: Design specification dictionary
|
|
401
|
+
inference_sampler: Inference sampler dictionary
|
|
402
|
+
"""
|
|
403
|
+
has_symmetry_specification = [
|
|
404
|
+
True if "symmetry" in item.keys() else False for item in design_spec.values()
|
|
405
|
+
]
|
|
406
|
+
if any(has_symmetry_specification):
|
|
407
|
+
if (
|
|
408
|
+
inference_sampler is None
|
|
409
|
+
or inference_sampler.get("kind", "default") != "symmetry"
|
|
410
|
+
):
|
|
411
|
+
raise ValueError(
|
|
412
|
+
"You requested for symmetric designs, but inference sampler is not set to symmetry. "
|
|
413
|
+
"Please add inference_sampler.kind='symmetry' to your command."
|
|
414
|
+
)
|
|
415
|
+
|
|
416
|
+
|
|
394
417
|
#################################################################################
|
|
395
418
|
# Custom infer_ori functions
|
|
396
419
|
#################################################################################
|