rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from foundry.utils.components import fetch_mask_from_idx
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def expand_contig_to_resid_from_string(contig_string):
    """
    Expand a contig range string to a list of residue identifiers.

    Arguments:
        contig_string: string of the form "<chain><start>-<end>", e.g. "A1-5".
            The chain is the single leading character; start and end are
            inclusive integer residue indices.
    Returns:
        list of residue identifiers, e.g. ["A1", "A2", "A3", "A4", "A5"]
    Raises:
        ValueError: if the part after the chain is not exactly two integers
            separated by a single "-".
    """
    chain = contig_string[0]
    res_range = contig_string[1:].split("-")
    # Previously extra segments (e.g. "A1-5-7") were silently ignored and a
    # missing "-" surfaced as an IndexError; reject malformed ranges explicitly.
    if len(res_range) != 2:
        raise ValueError(
            "Expected a contig of the form '<chain><start>-<end>' "
            f"(e.g. 'A1-5'), got {contig_string!r}"
        )
    res_start, res_end = (int(part) for part in res_range)
    return [f"{chain}{i}" for i in range(res_start, res_end + 1)]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def expand_contig_unsym_motif(unsym_motif_names):
    """
    Expand contig range entries in a list of unsymmetrized motif names.

    Entries containing "-" are treated as contig ranges (e.g. "A1-5") and
    expanded with ``expand_contig_to_resid_from_string``; every other entry
    is kept unchanged.

    Arguments:
        unsym_motif_names: list of motif names and/or contig range strings
    Returns:
        a new list: the non-range entries in their original order, followed
        by the expanded residue identifiers of the range entries, in order
    """
    # Expand the range entries first (same evaluation order as before).
    expanded_residues = [
        residue
        for name in unsym_motif_names
        if "-" in name
        for residue in expand_contig_to_resid_from_string(name)
    ]
    plain_names = [name for name in unsym_motif_names if "-" not in name]
    return plain_names + expanded_residues
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_unsym_motif_mask(atom_array, unsym_motif_names):
    """
    Build a per-atom boolean mask selecting the unsymmetrized motif atoms.

    For each name in ``unsym_motif_names`` an atom is selected when:
      * its ``res_name`` equals the name; and additionally either
      * the array has a ``src_component`` annotation that contains the name,
        in which case atoms whose ``src_component`` equals it are selected; or
      * otherwise, if the name looks like "<letter><digits>" (e.g. "A12"),
        the atoms selected by ``fetch_mask_from_idx`` for that name.

    Arguments:
        atom_array: atom array supporting ``len()``, ``res_name`` and
            ``get_annotation_categories()``
        unsym_motif_names: names / residue identifiers to select
    Returns:
        boolean numpy array with one entry per atom
    """
    selected = np.zeros(len(atom_array), dtype=bool)
    for name in unsym_motif_names:
        selected = np.logical_or(selected, atom_array.res_name == name)

        matches_component = (
            "src_component" in atom_array.get_annotation_categories()
            and name in atom_array.src_component
        )
        if matches_component:
            selected = np.logical_or(selected, atom_array.src_component == name)
            continue

        # Single-letter chain followed by a residue number, e.g. "A12".
        if name[0].isalpha() and name[1:].isdigit():
            residue_mask = fetch_mask_from_idx(name, atom_array=atom_array)
            selected = np.logical_or(selected, residue_mask)
    return selected
|
|
@@ -0,0 +1,355 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def get_symmetry_frames_from_symmetry_id(symmetry_id):
    """
    Get symmetry frames from a symmetry id.

    Arguments:
        symmetry_id: either a string id or a config dict carrying the id under
            "id". Supported ids (case-insensitive): "C<n>" (cyclic),
            "D<n>" (dihedral) with integer n >= 1, and "input_defined"
            (which additionally requires "symmetry_file" in the dict).
    Returns:
        frames: list of (rotation_matrix, translation_vector) tuples
    Raises:
        ValueError: if the id is missing/not a string, unsupported, or has a
            non-integer or non-positive order.
        AssertionError: if "input_defined" is requested without a
            "symmetry_file", or a generated frame is not a valid rotation.
        NotImplementedError: for "input_defined" (file loading is not
            implemented yet).
    """
    sym_conf = {}
    if isinstance(symmetry_id, dict):
        sym_conf = symmetry_id
        symmetry_id = symmetry_id.get("id")

    # A missing "id" (None) or empty string used to fail with an opaque error.
    if not isinstance(symmetry_id, str) or not symmetry_id:
        raise ValueError(f"Symmetry id {symmetry_id} not supported")

    sym_type = symmetry_id.lower()
    if sym_type[0] in ("c", "d"):
        try:
            order = int(symmetry_id[1:])
        except ValueError:
            raise ValueError(
                f"Symmetry id {symmetry_id} not supported: "
                f"expected an integer order after '{symmetry_id[0]}'"
            ) from None
        # Orders <= 0 previously produced an empty frame list silently.
        if order < 1:
            raise ValueError(
                f"Symmetry id {symmetry_id} not supported: order must be >= 1"
            )
        if sym_type[0] == "c":
            frames = get_cyclic_frames(order)
        else:
            frames = get_dihedral_frames(order)
    elif sym_type == "input_defined":
        assert (
            "symmetry_file" in sym_conf
        ), "symmetry_file is required for input_defined symmetry"
        frames = get_frames_from_file(sym_conf.get("symmetry_file"))
    else:
        raise ValueError(f"Symmetry id {symmetry_id} not supported")

    # Check that the frames are valid rotation matrices
    for R, _ in frames:
        assert is_valid_rotation_matrix(R), f"Frame {R} is not a valid rotation matrix"

    return frames
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def get_symmetry_frames_from_atom_array(src_atom_array, input_frames):
    """
    Derive symmetry frames from the chains of an input atom array.

    Adapted from code from FD. Only atoms with ``chain_type == 6`` are used
    (described in the original code as removing non-protein elements). Among
    entities present exactly once per asymmetric unit, the one with the most
    atoms per pn_unit is the reference; its first pn_unit is aligned (with
    ``_align``) onto every pn_unit of that entity, and the resulting rotations
    become the frames. A series of checks from
    ``rfd3.inference.symmetry.checks`` is run along the way.

    Arguments:
        src_atom_array: atom array with coordinates and ``chain_type``,
            ``pn_unit_entity`` and ``pn_unit_iid`` annotations
        input_frames: list of (rotation_matrix, translation_vector) tuples the
            computed frames are checked against
    Returns:
        computed_frames: list of (rotation_matrix, zero translation) tuples,
            one per pn_unit of the reference entity
    Note:
        Translations from the alignment are discarded, so this presumably
        assumes the symmetry axes pass through the origin.
        NOTE(review): confirm callers centre the input.
    """
    # Imported here rather than at module level to avoid a circular import.
    from rfd3.inference.symmetry.checks import (
        check_input_frames_match_symmetry_frames,
        check_max_rmsds,
        check_max_transforms,
        check_min_atoms_to_align,
        check_valid_multiplicity,
        check_valid_subunit_size,
    )

    protein = src_atom_array[src_atom_array.chain_type == 6]

    # Group pn_unit instance ids by their entity.
    entity_of_atom = protein.get_annotation("pn_unit_entity")
    unit_of_atom = protein.get_annotation("pn_unit_iid")
    units_by_entity = {
        entity: np.unique(unit_of_atom[entity_of_atom == entity])
        for entity in np.unique(entity_of_atom)
    }
    xyz = protein.coord

    check_valid_multiplicity(units_by_entity)
    multiplicity = min(len(units) for units in units_by_entity.values())
    copies_per_asu = {
        entity: len(units) // multiplicity
        for entity, units in units_by_entity.items()
    }

    check_valid_subunit_size(units_by_entity, unit_of_atom)

    # Atom counts (first pn_unit) of entities appearing once per ASU.
    atoms_per_single_copy_entity = {}
    for entity, units in units_by_entity.items():
        if copies_per_asu[entity] == 1:
            atoms_per_single_copy_entity[entity] = (unit_of_atom == units[0]).sum()
    reference_entity = max(
        atoms_per_single_copy_entity, key=atoms_per_single_copy_entity.get
    )

    check_min_atoms_to_align(atoms_per_single_copy_entity, reference_entity)

    # Every pn_unit of the reference entity yields one frame.
    reference_units = units_by_entity[reference_entity]
    reference_unit = units_by_entity[reference_entity][0]

    check_max_transforms(reference_units)

    reference_xyz = xyz[unit_of_atom == reference_unit]
    transforms = {
        unit: _align(xyz[unit_of_atom == unit], reference_xyz)
        for unit in reference_units
    }
    rmsds = {
        unit: _rms(xyz[unit_of_atom == unit], reference_xyz, *transform)
        for unit, transform in transforms.items()
    }

    check_max_rmsds(rmsds)

    rotations = [rotation for _, rotation, _ in transforms.values()]
    for rotation in rotations:
        assert is_valid_rotation_matrix(
            rotation
        ), f"Computed frame {rotation} is not a valid rotation matrix"
    computed_frames = [(rotation, np.array([0, 0, 0])) for rotation in rotations]

    check_input_frames_match_symmetry_frames(computed_frames, input_frames)

    return computed_frames
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _align(X_fixed, X_moving):
|
|
129
|
+
"""
|
|
130
|
+
Align two sets of coordinates using Kabsch algorithm.
|
|
131
|
+
Arguments:
|
|
132
|
+
X_fixed: fixed coordinates
|
|
133
|
+
X_moving: moving coordinates
|
|
134
|
+
Returns:
|
|
135
|
+
u_X_moving: mean of the moving coordinates
|
|
136
|
+
R: rotation matrix
|
|
137
|
+
u_X_fixed: mean of the fixed coordinates
|
|
138
|
+
"""
|
|
139
|
+
is_torch = isinstance(X_fixed, torch.Tensor)
|
|
140
|
+
|
|
141
|
+
def _mean_along_dim(X, dim):
|
|
142
|
+
if is_torch:
|
|
143
|
+
return X.mean(dim=dim)
|
|
144
|
+
else:
|
|
145
|
+
return X.mean(axis=dim)
|
|
146
|
+
|
|
147
|
+
assert X_fixed.shape == X_moving.shape
|
|
148
|
+
|
|
149
|
+
if X_fixed.ndim == 2:
|
|
150
|
+
X_fixed = X_fixed[None, ...]
|
|
151
|
+
X_moving = X_moving[None, ...]
|
|
152
|
+
B = X_fixed.shape[0]
|
|
153
|
+
|
|
154
|
+
if is_torch:
|
|
155
|
+
mask = (~torch.isnan(X_fixed) & ~torch.isnan(X_moving)).all(dim=-1).all(dim=0)
|
|
156
|
+
else:
|
|
157
|
+
mask = (~np.isnan(X_fixed) & ~np.isnan(X_moving)).all(axis=-1).all(axis=0)
|
|
158
|
+
|
|
159
|
+
X_fixed = X_fixed[:, mask]
|
|
160
|
+
X_moving = X_moving[:, mask]
|
|
161
|
+
|
|
162
|
+
u_X_fixed = _mean_along_dim(X_fixed, dim=-2)
|
|
163
|
+
u_X_moving = _mean_along_dim(X_moving, dim=-2)
|
|
164
|
+
|
|
165
|
+
X_fixed_centered = X_fixed - u_X_fixed[..., None, :]
|
|
166
|
+
X_moving_centered = X_moving - u_X_moving[..., None, :]
|
|
167
|
+
|
|
168
|
+
if is_torch:
|
|
169
|
+
C = torch.einsum("...ji,...jk->...ik", X_fixed_centered, X_moving_centered)
|
|
170
|
+
U, S, V = torch.linalg.svd(C, full_matrices=False)
|
|
171
|
+
else:
|
|
172
|
+
C = np.einsum("...ji,...jk->...ik", X_fixed_centered, X_moving_centered)
|
|
173
|
+
U, S, V = np.linalg.svd(C, full_matrices=False)
|
|
174
|
+
|
|
175
|
+
R = U @ V
|
|
176
|
+
if is_torch:
|
|
177
|
+
F = torch.eye(3, 3, device=R.device).expand(B, 3, 3).clone()
|
|
178
|
+
F[..., -1, -1] = torch.sign(torch.linalg.det(R))
|
|
179
|
+
else:
|
|
180
|
+
F = np.broadcast_to(np.eye(3, 3), (B, 3, 3)).copy()
|
|
181
|
+
F[..., -1, -1] = np.sign(np.linalg.det(R))
|
|
182
|
+
R = U @ F @ V
|
|
183
|
+
|
|
184
|
+
if R.shape[0] == 1:
|
|
185
|
+
return u_X_moving[0], R[0], u_X_fixed[0]
|
|
186
|
+
|
|
187
|
+
return u_X_moving, R, u_X_fixed
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _rms(X_fixed, X_moving, t_pre, R, t_post):
|
|
191
|
+
"""
|
|
192
|
+
Calculate the RMSD between two sets of coordinates.
|
|
193
|
+
Arguments:
|
|
194
|
+
X_fixed: fixed coordinates
|
|
195
|
+
X_moving: moving coordinates
|
|
196
|
+
t_pre: pre-rotation translation
|
|
197
|
+
R: rotation matrix
|
|
198
|
+
t_post: post-rotation translation
|
|
199
|
+
Returns:
|
|
200
|
+
rms: RMSD
|
|
201
|
+
"""
|
|
202
|
+
mask = (~np.isnan(X_fixed) & ~np.isnan(X_moving)).all(axis=-1)
|
|
203
|
+
X_fixed = X_fixed[mask]
|
|
204
|
+
X_moving = X_moving[mask]
|
|
205
|
+
|
|
206
|
+
X_moving_aln = np.einsum("ij,bj->bi", R, (X_moving - t_pre[None])) + t_post[None]
|
|
207
|
+
rms = np.sqrt(np.sum(np.square(X_moving_aln - X_fixed)) / X_moving_aln.shape[-2])
|
|
208
|
+
return rms
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def is_valid_rotation_matrix(R):
    """
    Check whether a matrix is a proper 3x3 rotation matrix.

    A proper rotation is orthogonal (``R @ R.T`` equals the identity within
    ``atol=1e-6``) and has determinant +1. Orthogonal matrices with determinant
    -1 (reflections / improper rotations) are rejected.

    Arguments:
        R: rotation matrix (array-like, expected shape (3, 3))
    Returns:
        bool: True if R is a valid rotation matrix, False otherwise
    """
    R = np.asarray(R)
    if R.shape != (3, 3):
        return False
    if not np.allclose(R @ R.T, np.eye(3), atol=1e-6):
        return False
    # Orthogonality alone admits reflections; det is +/-1 here, so the sign
    # is enough to distinguish a rotation.
    return bool(np.linalg.det(R) > 0)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def get_cyclic_frames(order):
    """
    Build the frames of the cyclic group C_order about the z axis.

    Arguments:
        order: number of subunits
    Returns:
        frames: list of ``order`` (rotation_matrix, translation_vector) tuples;
            frame k rotates by 2*pi*k/order about z, translations are zero
    """

    def _z_rotation(theta):
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        return np.array(
            [
                [cos_t, -sin_t, 0],
                [sin_t, cos_t, 0],
                [0, 0, 1],
            ]
        )

    return [
        (_z_rotation(2 * np.pi * k / order), np.array([0, 0, 0]))
        for k in range(order)
    ]
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def get_dihedral_frames(order):
    """
    Build the frames of the dihedral group D_order.

    Arguments:
        order: order n of the principal (z) axis; the group has 2 * n
            subunits (since each dihedral has two frames)
    Returns:
        frames: list of 2 * order (rotation_matrix, translation_vector)
            tuples ordered [R_0, R_0 @ F, R_1, R_1 @ F, ...], where R_i rotates
            by 2*pi*i/order about z and F is a 180-degree rotation about a
            fixed in-plane axis at angle pi/order. Translations are zero.
    """
    frames = []
    if order < 1:
        return frames

    # The 2-fold axis must be the SAME for every i. R(theta) @ F(phi) equals
    # F(phi + theta/2), so {R_i @ F} then covers all n distinct in-plane
    # axes. Rotating the axis with i (phi = angle + pi/order) placed the flips
    # at (3i+1)*pi/n, which duplicates frames whenever n is divisible by 3.
    phi = np.pi / order
    u = np.array([np.cos(phi), np.sin(phi), 0])
    flip = -np.eye(3) + 2 * np.outer(u, u)

    for i in range(order):
        angle = 2 * np.pi * i / order
        R = np.array(
            [
                [np.cos(angle), -np.sin(angle), 0],
                [np.sin(angle), np.cos(angle), 0],
                [0, 0, 1],
            ]
        )

        # add both frames for the dihedral
        frames.append((R, np.array([0, 0, 0])))
        frames.append((R @ flip, np.array([0, 0, 0])))

    return frames
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def get_frames_from_file(file_path):
    """
    Load symmetry frames from a file (placeholder).

    Input-defined symmetry is not implemented yet, so this always raises.

    Arguments:
        file_path: path to the symmetry file (currently unused)
    Raises:
        NotImplementedError: always
    """
    message = "Input defined symmetry not implemented"
    raise NotImplementedError(message)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
###################################
|
|
285
|
+
# Kinematics
|
|
286
|
+
###################################
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
# fd - two routines that convert between:
|
|
290
|
+
# a) a "virtual frame" consisting of three atoms; and
|
|
291
|
+
# b) a translation and rotation
|
|
292
|
+
# uses Gram-Schmidt orthogonalization, handles stacked/unstacked
|
|
293
|
+
# support np and torch inputs
|
|
294
|
+
def RTs_to_framecoords(Rs, ts, sig=1.0):
    """
    Convert rotation(s) and translation(s) into three "virtual frame" points.

    Arguments:
        Rs: rotation matrix/matrices, shape (..., 3, 3); numpy array or torch
            tensor
        ts: translation vector(s), shape (..., 3); numpy array or torch tensor
        sig: distance of the X and Y points from the origin point
    Returns:
        Ori: ``ts`` as a torch tensor
        X: ``Ori + sig * (unit vector along row 0 of Rs)``
        Y: ``Ori + sig * (unit vector along row 1 of Rs)``
    """
    # Convert each input independently: previously only Rs was type-checked,
    # so a numpy Rs with a torch ts (or vice versa) failed or mixed types.
    if isinstance(Rs, np.ndarray):
        Rs = torch.from_numpy(Rs)
    if isinstance(ts, np.ndarray):
        ts = torch.from_numpy(ts)
    Ori = ts
    # The 1e-6 keeps the normalization finite for zero-length rows.
    X = Ori + sig * Rs[..., 0, :] / (
        torch.norm(Rs[..., 0, :], dim=-1, keepdim=True) + 1e-6
    )
    Y = Ori + sig * Rs[..., 1, :] / (
        torch.norm(Rs[..., 1, :], dim=-1, keepdim=True) + 1e-6
    )
    return Ori, X, Y
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
# framecoords_to_RTs is used in loss and expects torch inputs
|
|
309
|
+
# (and must support backwards)
|
|
310
|
+
def framecoords_to_RTs(Ori, X, Y, eps=1e-6):
    """
    Recover rotation(s) and translation(s) from three virtual-frame points.

    Builds an orthonormal basis by Gram-Schmidt: e1 from ``X - Ori``, e2 from
    the part of ``Y - Ori`` orthogonal to e1, and e3 = e1 x e2. Torch-only
    (uses torch ops throughout so gradients flow).

    Arguments:
        Ori, X, Y: torch tensors of matching shape (..., 3)
        eps: small constant; added to one component of each vector and to
            each norm to keep the normalization finite for degenerate inputs
    Returns:
        R: tensor of shape (..., 3, 3) whose rows are e1, e2, e3
        T: ``Ori`` itself
    """

    def _unit(vec, shift, device):
        # Divide by the norm of the *unshifted* vector (plus eps).
        return (vec + torch.tensor(shift, device=device)) / (
            torch.linalg.norm(vec, axis=-1, keepdims=True) + eps
        )

    d1 = X - Ori
    e1 = _unit(d1, [eps, 0, 0], d1.device)

    y_rel = Y - Ori
    along_e1 = torch.sum(y_rel * e1, axis=-1, keepdims=True) * e1
    e2 = _unit(y_rel - along_e1, [0, eps, 0], e1.device)

    e3 = torch.cross(e1, e2, dim=-1)

    basis = torch.stack([e1, e2, e3], axis=-2)
    return basis, Ori
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def pack_vector(v: np.ndarray) -> np.ndarray:
    """
    Wrap a length-3 vector in a one-element structured array.

    Arguments:
        v: 1-D array of shape (3,) and arbitrary dtype
    Returns:
        structured array of shape (1,) with a single field "x" of dtype
        ``(v.dtype, (3,))`` holding ``v``
    """
    record_dtype = np.dtype([("x", v.dtype, (3,))])
    packed = np.zeros(1, dtype=record_dtype)
    packed["x"][0] = v
    return packed
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def unpack_vector(a: np.ndarray) -> np.ndarray:
    """
    Extract the "x" field from a structured array such as ``pack_vector`` builds.

    Arguments:
        a: structured array of shape (1,) with a field "x"
    Returns:
        the "x" field; for a shape-(1,) input with a (3,)-shaped field this has
        shape (1, 3), i.e. the leading length-1 axis is kept
    """
    field_name = "x"
    return a[field_name]
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def decompose_symmetry_frame(frame):
    """
    Convert a (rotation, translation) frame into three packed frame points.

    Arguments:
        frame: (R, T) pair accepted by ``RTs_to_framecoords``
    Returns:
        tuple (Ori, X, Y), each a one-element structured array produced by
        ``pack_vector`` from the corresponding point
    """
    rotation, translation = frame
    origin, x_point, y_point = RTs_to_framecoords(rotation, translation)
    return tuple(pack_vector(point.numpy()) for point in (origin, x_point, y_point))
|