rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import biotite.structure as struc
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
from pydantic import (
|
|
7
|
+
BaseModel,
|
|
8
|
+
ConfigDict,
|
|
9
|
+
Field,
|
|
10
|
+
)
|
|
11
|
+
from rfd3.inference.symmetry.atom_array import (
|
|
12
|
+
FIXED_ENTITY_ID,
|
|
13
|
+
FIXED_TRANSFORM_ID,
|
|
14
|
+
add_2d_entity_annotations,
|
|
15
|
+
add_src_sym_component_annotations,
|
|
16
|
+
add_sym_annotations,
|
|
17
|
+
annotate_unsym_atom_array,
|
|
18
|
+
fix_3D_sym_motif_annotations,
|
|
19
|
+
get_symmetry_unit,
|
|
20
|
+
reannotate_2d_conditions,
|
|
21
|
+
)
|
|
22
|
+
from rfd3.inference.symmetry.checks import (
|
|
23
|
+
check_symmetry_config,
|
|
24
|
+
)
|
|
25
|
+
from rfd3.inference.symmetry.contigs import (
|
|
26
|
+
expand_contig_unsym_motif,
|
|
27
|
+
get_unsym_motif_mask,
|
|
28
|
+
)
|
|
29
|
+
from rfd3.inference.symmetry.frames import (
|
|
30
|
+
get_symmetry_frames_from_atom_array,
|
|
31
|
+
get_symmetry_frames_from_symmetry_id,
|
|
32
|
+
)
|
|
33
|
+
from rfd3.transforms.conditioning_base import get_motif_features
|
|
34
|
+
|
|
35
|
+
from foundry.utils.components import fetch_mask_from_component
|
|
36
|
+
from foundry.utils.ddp import RankedLogger
|
|
37
|
+
|
|
38
|
+
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class SymmetryConfig(BaseModel):
    """
    Pydantic model holding the symmetry settings.

    Only ``id`` is declared as a field. The model is configured with
    ``extra="allow"``, so additional keys supplied by the caller are accepted
    rather than rejected.
    """

    # AM / HE TODO: feel free to flesh this out and add validation as needed
    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    # Symmetry identifier; optional and None when not provided.
    id: Optional[str] = Field(default=None)

    # Placeholders kept from the original; these are not declared fields yet:
    # is_unsym_motif: Optional[np.ndarray[bool]] = Field(...)
    # is_symmetric_motif: bool = Field(...)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def make_symmetric_atom_array(
    asu_atom_array, sym_conf: SymmetryConfig, sm=None, has_2d=False, src_atom_array=None
):
    """
    Apply symmetry to an atom array.

    Arguments:
        asu_atom_array: atom array of the asymmetric unit
        sym_conf: symmetry configuration. A SymmetryConfig is dumped to a dict
            internally; an already-dumped dict is used as is. The keys read here
            are "is_symmetric_motif" and "is_unsym_motif"; the whole config is
            also forwarded to the validation / frame / annotation helpers.
        sm: optional small molecule names (str, comma separated)
        has_2d: whether to add 2d entity annotations
        src_atom_array: source atom array from which the symmetry frames are
            derived. Must not be None, since only symmetric motifs are supported.
    Returns:
        symmetrized_atom_array: concatenation of one symmetry unit per frame,
            followed by the atoms that were excluded from symmetrization
            (unsymmetric motifs and small molecules), if any.
    Raises:
        NotImplementedError: if "is_symmetric_motif" is not set (truthy) in the config.
        AssertionError: if src_atom_array is None.
    """
    # TODO: JB: keep as symmetry config for cleaner syntax(?)
    # Accept both the pydantic model and a plain dict so that the `.get(...)`
    # lookups below work either way.
    if isinstance(sym_conf, BaseModel):
        sym_conf = sym_conf.model_dump()
    ranked_logger.info(f"Symmetry Configs: {sym_conf}")

    # Making sure that the symmetry config is valid
    check_symmetry_config(
        asu_atom_array,
        sym_conf,
        sm,
        has_dist_cond=has_2d,
        src_atom_array=src_atom_array,
    )
    # Adding utility annotations to the asu atom array
    asu_atom_array = _add_util_annotations(asu_atom_array, sym_conf, sm)

    # NB: this will only work for asymmetric motifs at the moment - need to add
    # functionality for symmetric motifs
    if has_2d:
        asu_atom_array = add_2d_entity_annotations(asu_atom_array)

    frames = get_symmetry_frames_from_symmetry_id(sym_conf)

    # Only symmetric motifs are supported: the frames are then taken from the
    # source atom array instead.
    if not sym_conf.get("is_symmetric_motif"):
        raise NotImplementedError(
            "Asymmetric motif inputs are not implemented yet. please symmetrize the motif."
        )
    assert (
        src_atom_array is not None
    ), "Source atom array must be provided for symmetric motifs"
    frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)

    # Add symmetry annotations to the asu atom array
    asu_atom_array = add_sym_annotations(asu_atom_array, sym_conf)

    # Extracting all things at this moment that we will not want to symmetrize.
    # This includes: 1) unsym motifs, 2) ligands
    unsym_atom_arrays = []
    if sym_conf.get("is_unsym_motif"):
        unsym_mask = asu_atom_array._is_unsym_motif
        unsym_atom_arrays.append(asu_atom_array[unsym_mask])
        # Now remove the unsym motifs from the asu atom array
        asu_atom_array = asu_atom_array[~unsym_mask]
    if sm:
        # The mask is read after the unsym-motif split, so it matches the
        # (possibly already reduced) array.
        sm_mask = asu_atom_array._is_sm
        unsym_atom_arrays.append(asu_atom_array[sm_mask])
        asu_atom_array = asu_atom_array[~sm_mask]
    unsym_atom_array = (
        struc.concatenate(unsym_atom_arrays) if len(unsym_atom_arrays) > 0 else None
    )

    # Build one symmetry unit per frame; together they form the fully
    # symmetrized atom array.
    symmetry_unit_list = [
        get_symmetry_unit(asu_atom_array, transform_id, frame)
        for transform_id, frame in enumerate(frames)
    ]
    if unsym_atom_array:  # only if exists
        unsym_atom_array = annotate_unsym_atom_array(unsym_atom_array)
        # add the motifs to the end of the list (motifs at end of atom array)
        symmetry_unit_list.append(unsym_atom_array)
    # build the full symmetrized atom array
    symmetrized_atom_array = struc.concatenate(symmetry_unit_list)

    # add 2D conditioning annotations
    if has_2d:
        symmetrized_atom_array = reannotate_2d_conditions(symmetrized_atom_array)

    # set all motifs to not have any symmetrization applied to them
    # TODO: this needs to be adapted to work with 2D cond (in 2D cond, we WANT to
    # apply symmetry to the motifs since they move in space)
    symmetrized_atom_array = fix_3D_sym_motif_annotations(symmetrized_atom_array)

    # This is needed to output correct motif residue mappings in the output json
    symmetrized_atom_array = add_src_sym_component_annotations(symmetrized_atom_array)
    # remove utility annotations
    symmetrized_atom_array = _del_util_annotations(symmetrized_atom_array)
    return symmetrized_atom_array
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def make_symmetric_atom_array_for_partial_diffusion(atom_array, sym_conf):
    """
    Apply symmetry to an atom array with partial diffusion.

    Coordinates are left untouched; only symmetry annotations are added. Chains
    (sorted unique ``chain_id`` values) are assigned to the symmetry frames
    cyclically, and the first chain is flagged as the asymmetric unit.

    Arguments:
        atom_array: atom array of the (already symmetric) input structure
        sym_conf: symmetry configuration ("id" key is required). Either a dict
            or a SymmetryConfig, which is dumped to a dict internally.
    Returns:
        atom_array: atom array with the annotations "symmetry_id",
            "sym_transform_Ori", "sym_transform_X", "sym_transform_Y",
            "sym_transform_id", "sym_entity_id" and "is_sym_asu" set.
    Raises:
        AssertionError: if the atom array is not symmetric according to
            check_atom_array_is_symmetric.
        ValueError: if the atom array contains no chains.
    """
    # TODO: clean up this function

    # For partial diffusion with symmetric inputs, preserve exact positioning
    ranked_logger.info(
        "Partial diffusion with symmetry - preserving exact input coordinates"
    )
    ranked_logger.info("SKIPPING symmetry reconstruction to preserve input structure")
    # Add full symmetry annotations without changing coordinates
    from rfd3.inference.symmetry.checks import (
        check_atom_array_is_symmetric,
    )
    from rfd3.inference.symmetry.frames import (
        decompose_symmetry_frame,
    )

    # Same normalisation as in make_symmetric_atom_array: the code below uses
    # dict-style `.get(...)` access, which a pydantic model does not provide.
    if isinstance(sym_conf, BaseModel):
        sym_conf = sym_conf.model_dump()

    check_symmetry_config(
        atom_array,
        sym_conf,
        sm=None,
        has_dist_cond=False,
        src_atom_array=None,
        partial=True,
    )

    atom_array = add_sym_annotations(atom_array, sym_conf)
    assert check_atom_array_is_symmetric(atom_array), "Atom array is not symmetric"

    n = atom_array.shape[0]
    chain_ids = np.unique(atom_array.chain_id)
    if len(chain_ids) == 0:
        # Without this guard the per-chain arrays below would never be created
        # and the function would fail later with an UnboundLocalError.
        raise ValueError(
            "Cannot add symmetry annotations: atom array contains no chains"
        )
    frames = get_symmetry_frames_from_symmetry_id(sym_conf)

    # Add symmetry ID (fixed-width unicode, at most 6 characters are stored)
    symmetry_ids = np.full(n, sym_conf.get("id"), dtype="U6")
    atom_array.set_annotation("symmetry_id", symmetry_ids)

    # Initialize transform annotations (use same format as original system)
    symmetry_transform_id = np.zeros(n, dtype=np.int32)
    symmetry_entity_id = np.zeros(n, dtype=np.int32)
    is_asu = np.zeros(n, dtype=bool)

    # Add transform annotations for each chain (same format as add_symmetry_transform_annotations)
    for i, chain_id in enumerate(chain_ids):
        chain_mask = atom_array.chain_id == chain_id
        transform_id = i % len(frames)  # Cycle through available frames

        # Decompose frame to packed scalars
        Ori, X, Y = decompose_symmetry_frame(frames[transform_id])

        if i == 0:
            # First chain: create the arrays (their dtype follows the decomposed
            # values) pre-filled with this chain's values, and mark it as the ASU.
            sym_transform_Ori = np.full(n, Ori)
            sym_transform_X = np.full(n, X)
            sym_transform_Y = np.full(n, Y)
            is_asu[chain_mask] = True
        else:
            # Subsequent chains: overwrite only this chain's atoms
            sym_transform_Ori[chain_mask] = Ori
            sym_transform_X[chain_mask] = X
            sym_transform_Y[chain_mask] = Y

        symmetry_transform_id[chain_mask] = transform_id
        symmetry_entity_id[chain_mask] = 0  # All chains same entity for C9

    # Set all annotations
    atom_array.set_annotation("sym_transform_Ori", sym_transform_Ori)
    atom_array.set_annotation("sym_transform_X", sym_transform_X)
    atom_array.set_annotation("sym_transform_Y", sym_transform_Y)
    atom_array.set_annotation("sym_transform_id", symmetry_transform_id)
    atom_array.set_annotation("sym_entity_id", symmetry_entity_id)
    atom_array.set_annotation("is_sym_asu", is_asu)

    ranked_logger.info(
        f"Added full symmetry annotations to {len(chain_ids)} existing chains WITHOUT changing coordinates"
    )

    return atom_array
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
########################################################
|
|
233
|
+
# Private functions only used in make_symmetric_atom_array
|
|
234
|
+
########################################################
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def _add_util_annotations(asu_atom_array, sym_conf, sm):
    """
    Attach the temporary, underscore-prefixed helper masks used during
    symmetrization to the asu atom array.

    Arguments:
        asu_atom_array: atom array of the asymmetric unit
        sym_conf: symmetry configuration (dict-like; "is_unsym_motif" is read
            as a comma separated string when present)
        sm: small molecule names (str, comma separated)
    Returns:
        asu_atom_array: the same atom array with the annotations "_is_asu",
            "_is_motif", "_is_sm", "_is_indexed_motif", "_is_unindexed_motif"
            and "_is_unsym_motif" set.
    """
    num_atoms = asu_atom_array.shape[0]
    motif_mask = get_motif_features(asu_atom_array)["is_motif_atom"].astype(np.bool_)

    # Small molecules: union of the masks of all listed components
    if sm:
        ligand_masks = [
            fetch_mask_from_component(lig, atom_array=asu_atom_array)
            for lig in sm.split(",")
        ]
        sm_mask = np.logical_or.reduce(ligand_masks)
    else:
        sm_mask = np.zeros(num_atoms, dtype=bool)

    # assign unsym motifs
    if sym_conf.get("is_unsym_motif"):
        motif_names = expand_contig_unsym_motif(sym_conf["is_unsym_motif"].split(","))
        unsym_mask = get_unsym_motif_mask(asu_atom_array, motif_names)
    else:
        unsym_mask = np.zeros(num_atoms, dtype=bool)

    unindexed_mask = asu_atom_array.is_motif_atom_unindexed.astype(np.bool_)
    # Indexed motif atoms: motif atoms that are neither small molecule nor unindexed
    indexed_mask = ~sm_mask & ~unindexed_mask & motif_mask

    annotations = (
        # "_is_asu" is currently not used but will be needed for 2D cond
        ("_is_asu", np.ones(num_atoms, dtype=bool)),
        ("_is_motif", motif_mask),
        ("_is_sm", sm_mask),
        ("_is_indexed_motif", indexed_mask),
        ("_is_unindexed_motif", unindexed_mask),
        ("_is_unsym_motif", unsym_mask),
    )
    for annotation_name, mask in annotations:
        asu_atom_array.set_annotation(annotation_name, mask)
    return asu_atom_array
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def _del_util_annotations(aary):
|
|
280
|
+
"""
|
|
281
|
+
Delete symmetry-specific utility annotations from the atom array.
|
|
282
|
+
Arguments:
|
|
283
|
+
aary: atom array
|
|
284
|
+
"""
|
|
285
|
+
aary.del_annotation("_is_asu") # Currently not used but will needed for 2D cond
|
|
286
|
+
aary.del_annotation("_is_motif")
|
|
287
|
+
aary.del_annotation("_is_sm")
|
|
288
|
+
aary.del_annotation("_is_indexed_motif")
|
|
289
|
+
aary.del_annotation("_is_unindexed_motif")
|
|
290
|
+
aary.del_annotation("_is_unsym_motif")
|
|
291
|
+
return aary
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
#########################
|
|
295
|
+
# Symmetrization functions
|
|
296
|
+
#########################
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def center_symmetric_src_atom_array(src_atom_array):
    """
    Center the src atom array at the origin.

    The shift is the unweighted mean coordinate of the atoms whose
    ``chain_type`` equals 6 (the "protein only" selection); it is subtracted
    from the coordinates of *all* atoms, in place.

    Arguments:
        src_atom_array: atom array of the source
    Returns:
        src_atom_array: the same atom array, centered at the origin
    Raises:
        ValueError: if no atom has chain_type 6. Without this check the mean
            of an empty selection would be NaN and every coordinate would
            silently become NaN.
    """
    # NOTE(review): 6 is presumably the chain-type code for (L-)polypeptides -
    # confirm against the chain type enum used to annotate `chain_type`.
    protein_chain_type = 6
    is_protein = src_atom_array.chain_type == protein_chain_type
    if not np.any(is_protein):
        raise ValueError(
            "Cannot center source atom array: no atoms with chain_type == "
            f"{protein_chain_type} (protein) found"
        )
    # Compute COM of the src atom array (protein only elements)
    src_atom_array_com = np.mean(src_atom_array[is_protein].coord, axis=0)
    # center the src atom array
    src_atom_array.coord -= src_atom_array_com
    return src_atom_array
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def apply_symmetry_to_xyz_atomwise(X_L, sym_feats, partial_diffusion=False):
    """
    Apply symmetry to the xyz coordinates.

    For every symmetry entity (except the fixed one), the coordinates of its
    asymmetric unit are copied onto each of its subunits after applying that
    subunit's transform ``asu_xyz @ R + t``, where ``R`` and ``t`` are the first
    and second item of the subunit's entry in ``sym_feats["sym_transform"]``.

    Arguments:
        X_L: [B, L, 3] xyz coordinates. NOTE: unless partial_diffusion is set,
            the non-fixed atoms of X_L are re-centered IN PLACE before the
            symmetry is applied.
        sym_feats: dictionary containing symmetry features
            ("sym_entity_id", "sym_transform_id", "is_sym_asu", "sym_transform")
        partial_diffusion: if True, skip the center-of-mass correction
    Returns:
        sym_X_L: [B, L, 3] xyz coordinates with symmetry applied (a new tensor)
    """
    sym_entity_id = sym_feats["sym_entity_id"]
    sym_transform_id = sym_feats["sym_transform_id"]
    is_sym_asu = sym_feats["is_sym_asu"]
    fixed_motif_mask = sym_entity_id == FIXED_ENTITY_ID
    # Re-key the transforms by int id and drop the fixed (no-op) transform
    sym_transforms = {
        int(k): v
        for k, v in sym_feats["sym_transform"].items()
        if int(k) != FIXED_TRANSFORM_ID
    }

    # COM correction (in case there is drift); only the non-fixed atoms move
    if not partial_diffusion:
        movable_mask = ~fixed_motif_mask
        movable_xyz = X_L[:, movable_mask, :]
        X_L[:, movable_mask, :] = movable_xyz - movable_xyz.mean(dim=1, keepdim=True)
    sym_X_L = X_L.clone()

    # Loop through each symmetry entity id - making sure that we apply the
    # matching symmetry transform to asu id
    unique_entity_id = torch.unique(sym_entity_id)
    unique_entity_id = unique_entity_id[unique_entity_id != FIXED_ENTITY_ID]
    for entity_id in unique_entity_id:
        # Mask for this entity id
        entity_id_mask = sym_entity_id == entity_id  # [L]
        # ASU that corresponds to this entity only
        entity_asu_mask = is_sym_asu & entity_id_mask
        if not entity_asu_mask.any():
            continue
        asu_xyz = X_L[:, entity_asu_mask, :]  # [B, Lasu, 3]

        # Open to suggestions for making this more efficient
        for target_id in torch.unique(sym_transform_id[entity_id_mask]).tolist():
            # Mask of this specific subunit in the entire atom array
            this_subunit = entity_id_mask & (sym_transform_id == target_id)
            transform = sym_transforms[target_id]
            # `.to(asu_xyz)` matches both dtype and device of the coordinates
            rotation = transform[0].to(asu_xyz)
            translation = transform[1].to(asu_xyz)
            # Apply this subunit's symmetry transform to its corresponding ASU
            sym_X_L[:, this_subunit, :] = (
                torch.einsum("blc,cd->bld", asu_xyz, rotation) + translation
            )

    _log_inter_entity_distances(sym_X_L, sym_entity_id)

    return sym_X_L


def _log_inter_entity_distances(sym_X_L, sym_entity_id):
    """Log debug distances between the first two symmetry entities (batch element 0)."""
    if sym_X_L.shape[1] <= 100:  # Only for large structures
        return

    # Use symmetry entity annotations to find different chains
    unique_entities = torch.unique(sym_entity_id)
    if len(unique_entities) < 2:
        return

    # Get atoms from first two different entities
    entity_0_mask = sym_entity_id == unique_entities[0]
    entity_1_mask = sym_entity_id == unique_entities[1]
    if not (entity_0_mask.sum() > 0 and entity_1_mask.sum() > 0):
        return
    entity_0_atoms = sym_X_L[0, entity_0_mask, :]
    entity_1_atoms = sym_X_L[0, entity_1_mask, :]

    # Sample subset (first 50 atoms at most) to avoid memory issues
    entity_0_sample = entity_0_atoms[:50, :]
    entity_1_sample = entity_1_atoms[:50, :]

    min_distance = torch.cdist(entity_0_sample, entity_1_sample).min().item()
    ranked_logger.info(
        f"Min inter-chain distance after symmetry: {min_distance:.2f} Å"
    )

    # Also log the centers of each entity
    entity_0_center = entity_0_atoms.mean(dim=0)
    entity_1_center = entity_1_atoms.mean(dim=0)
    center_distance = torch.norm(entity_0_center - entity_1_center).item()
    ranked_logger.info(
        f"Distance between chain centers: {center_distance:.2f} Å"
    )
|