rc-foundry 0.1.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rfd3/utils/io.py
ADDED
@@ -0,0 +1,245 @@

```python
import json
import re
from collections.abc import Callable
from os import PathLike
from pathlib import Path
from typing import Literal

import numpy as np
import torch
from atomworks.io.utils.io_utils import to_cif_file
from biotite.structure import AtomArray, AtomArrayStack, stack

from foundry.utils.alignment import weighted_rigid_align

DICTIONARY_LIKE_EXTENSIONS = {".json", ".yaml", ".yml", ".pkl"}
CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}


def dump_structures(
    atom_arrays: AtomArrayStack | list[AtomArray] | AtomArray,
    base_path: PathLike,
    one_model_per_file: bool,
    extra_fields: list[str] | Literal["all"] = [],
    **kwargs,
) -> None:
    """Dump structures to CIF files, given the coordinates and input AtomArray.

    Args:
        atom_arrays (AtomArrayStack | list[AtomArray] | AtomArray): Either an AtomArrayStack, a list of AtomArray objects,
            or a single AtomArray object to be dumped to CIF file(s).
        base_path (PathLike): Base path where the output files will be saved.
        one_model_per_file (bool): Flag to determine if each model should be dumped into a separate file. Has no effect if
            `atom_arrays` is a list of AtomArrays.
        extra_fields (list[str] | Literal["all"]): List of extra fields to include in the CIF file.
    """
    base_path = Path(base_path)
    if one_model_per_file:
        assert (
            isinstance(atom_arrays, AtomArrayStack) or isinstance(atom_arrays, list)
        ), "AtomArrayStack or list of AtomArray required when one_model_per_file is True"
        # One model per file -> loop over the diffusion batch
        for i in range(len(atom_arrays)):
            path = f"{base_path}_model_{i}"
            to_cif_file(
                atom_arrays[i],
                path,
                file_type="cif.gz",
                include_entity_poly=False,
                extra_fields=extra_fields,
                **kwargs,
            )
    else:
        # Include all models in a single CIF file
        to_cif_file(
            atom_arrays,
            base_path,
            file_type="cif.gz",
            include_entity_poly=False,
            extra_fields=extra_fields,
            **kwargs,
        )


def dump_metadata(
    prediction_metadata: dict,
    base_path: PathLike,
    one_model_per_file: bool,
):
    """
    Dump JSONs of prediction metadata to disk.

    Args:
        prediction_metadata (dict): Dictionary containing metadata for the predictions.
            Instantiated after the models' predictions in the trainer. Keys are model indices (starting from 0).
        base_path (PathLike): Base path where the output files will be saved.
        one_model_per_file (bool): If True, save each model's metadata in a separate file.
    """
    if one_model_per_file:
        # One model per file -> loop over the diffusion batch
        for i in prediction_metadata:
            path = f"{base_path}_model_{i}"
            with open(f"{path}.json", "w") as f:
                json.dump(prediction_metadata[i], f, indent=4)
    else:
        # Include all models in a single JSON file
        with open(f"{base_path}.json", "w") as f:
            json.dump(prediction_metadata, f, indent=4)


def dump_trajectories(
    trajectory_list: list[torch.Tensor | np.ndarray],
    atom_array: AtomArray,
    base_path: Path,
    align_structures: bool = True,
    coord_atom_lvl_to_be_noised: torch.Tensor | None = None,
    is_motif_atom_with_fixed_pos: torch.Tensor | None = None,
    max_frames: int = 100,
) -> None:
    """Write denoising trajectories to CIF files.

    Args:
        trajectory_list (list[torch.Tensor]): List of tensors of length n_steps representing the diffusion trajectory at each step.
            Each tensor has shape [D, L, 3], where D is the diffusion batch size and L is the number of atoms.
        atom_array (AtomArray): Atom array corresponding to the coordinates.
        base_path (Path): Base path where the output files will be saved.
        align_structures (bool): Flag to determine if the structures should be aligned on the final prediction.
            If False, each step may have a different alignment.
        coord_atom_lvl_to_be_noised (torch.Tensor | None): Optional atom-level coordinate mask (unused in this function).
        is_motif_atom_with_fixed_pos (torch.Tensor | None): Optional fixed-position motif mask (unused in this function).
        max_frames (int): Maximum number of frames to write; longer trajectories are subsampled evenly.
    """
    n_steps = len(trajectory_list)

    if align_structures:
        # ... align the trajectories on the last prediction
        w_L = torch.ones(*trajectory_list[0].shape[:2]).to(trajectory_list[0].device)
        X_exists_L = torch.ones(trajectory_list[0].shape[1], dtype=torch.bool).to(
            trajectory_list[0].device
        )
        for step in range(n_steps - 1):
            trajectory_list[step] = weighted_rigid_align(
                X_L=trajectory_list[-1],
                X_gt_L=trajectory_list[step],
                X_exists_L=X_exists_L,
                w_L=w_L,
            )

    # ... invert the list, to make the trajectory compatible with PyMOL (which builds the bond graph from the first frame)
    trajectory_list = trajectory_list[::-1]

    # ... select a subset of frames if necessary
    if n_steps > max_frames:
        selected_indices = torch.linspace(0, n_steps - 1, max_frames).long().tolist()
        trajectory_list = [trajectory_list[i] for i in selected_indices]
        n_steps = len(trajectory_list)

    # ... iterate over the range of D (diffusion batch size; e.g., 5 during validation)
    # (We want to convert `trajectory_list` to a list of length D where each item is a tensor of shape [n_steps, L, 3])
    trajectories_split_by_model = []
    for d in range(trajectory_list[0].shape[0]):
        trajectory_for_single_model = torch.stack(
            [trajectory_list[step][d] for step in range(n_steps)], dim=0
        )
        trajectories_split_by_model.append(trajectory_for_single_model)

    # ... write the trajectories to CIF files, named by epoch, dataset, example_id, and model index (within the diffusion batch)
    for i, trajectory in enumerate(trajectories_split_by_model):
        if isinstance(trajectory, torch.Tensor):
            trajectory = trajectory.cpu().numpy()
        atom_array_stack = build_stack_from_atom_array_and_batched_coords(
            trajectory, atom_array
        )

        path = f"{base_path}_model_{i}"
        to_cif_file(
            atom_array_stack, path, file_type="cif.gz", include_entity_poly=False
        )


def build_stack_from_atom_array_and_batched_coords(
    coords: np.ndarray | torch.Tensor,
    atom_array: AtomArray,
) -> AtomArrayStack:
    """Builds an AtomArrayStack from an AtomArray and a set of coordinates with a batch dimension.

    Additionally, handles the case where the AtomArray contains multiple transformations and we must adjust the chain_id.

    Args:
        coords (np.ndarray): The coordinates to be assigned to the AtomArrayStack. Must have shape (n_batch, n_atoms, 3).
        atom_array (AtomArray): The AtomArray to be stacked. Must have shape (n_atoms,).
    """
    if isinstance(coords, torch.Tensor):
        coords = coords.cpu().numpy()

    assert (
        coords.shape[-2] == atom_array.array_length()
    ), f"N batched coordinates {coords.shape} != {atom_array.array_length()}"

    # (Diffusion batch size will become the number of models)
    n_batch = coords.shape[0]

    # Build the stack and assign the coordinates
    atom_array_stack = stack([atom_array for _ in range(n_batch)])
    atom_array_stack.coord = coords

    # Adjust chain_id if there are multiple transformations
    # (Otherwise, we will have ambiguous bond annotations, since only `chain_id` is used for the bond annotations)
    if (
        "transformation_id" in atom_array.get_annotation_categories()
        and len(np.unique(atom_array_stack.transformation_id)) > 1
    ):
        new_chain_ids = np.char.add(
            atom_array_stack.chain_id, atom_array_stack.transformation_id
        )
        atom_array_stack.set_annotation("chain_id", new_chain_ids)

    return atom_array_stack


def find_files_with_extension(path: PathLike, supported_file_types: list) -> list[Path]:
    """Find all files with the given extensions in the specified path (non-recursive).

    Args:
        path (PathLike): Path to the directory containing the files.
        supported_file_types (list): List of supported file extensions.

    Returns:
        list[Path]: List of files with the given extensions.
    """
    files_with_supported_types = []
    path = Path(path)

    # Check if the path is a directory
    if path.is_dir():
        # Search for files with each supported extension
        for file_type in supported_file_types:
            files_with_supported_types.extend(path.glob(f"*{file_type}"))
    elif path.is_file() and path.suffix in supported_file_types:
        # If it's a file and has a supported extension, add to the list
        # NOTE: `Path.suffix` only captures the final suffix (e.g., ".gz" for ".cif.gz")
        files_with_supported_types.append(path)

    return files_with_supported_types


def create_example_id_extractor(
    extensions: set | list = CIF_LIKE_EXTENSIONS,
) -> Callable[[PathLike], str]:
    """Create a function with a closure that extracts example_ids from file paths with specified extensions.

    Example:
        >>> extractor = create_example_id_extractor({".cif", ".cif.gz"})
        >>> extractor("example.path.example_id.cif.gz")
        'example_id'
    """
    pattern = re.compile(
        "(" + "|".join(re.escape(ext) + "$" for ext in extensions) + ")"
    )

    def extract_id(file_path: PathLike) -> str:
        """Extract example_id from file path."""
        # Remove the extension and take the last dot-separated part
        without_ext = pattern.sub("", Path(file_path).name)
        return without_ext.split(".")[-1]

    return extract_id


def extract_example_id_from_path(file_path: PathLike, extensions: set | list) -> str:
    """Extract example_id from file path with specified extensions."""
    extractor = create_example_id_extractor(extensions)
    return extractor(file_path)
```
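For orientation, here is a minimal usage sketch of the helpers above. The toy `AtomArray` and the file path are hypothetical; only `build_stack_from_atom_array_and_batched_coords`, `extract_example_id_from_path`, and `CIF_LIKE_EXTENSIONS` come from the file as shown.

```python
import numpy as np
from biotite.structure import AtomArray

from rfd3.utils.io import (
    CIF_LIKE_EXTENSIONS,
    build_stack_from_atom_array_and_batched_coords,
    extract_example_id_from_path,
)

# Toy 3-atom topology (hypothetical; real arrays come from parsed structures)
atoms = AtomArray(3)
atoms.coord = np.zeros((3, 3), dtype=np.float32)
atoms.chain_id = np.array(["A", "A", "A"])

# Five "models" of batched coordinates, shape (n_batch, n_atoms, 3)
coords = np.random.rand(5, 3, 3).astype(np.float32)
models = build_stack_from_atom_array_and_batched_coords(coords, atoms)
assert models.stack_depth() == 5

# Recover the example_id embedded in an output file name
example_id = extract_example_id_from_path(
    "outputs/epoch_3.val.7abc.cif.gz", CIF_LIKE_EXTENSIONS
)
print(example_id)  # "7abc"
```

Note that the extension set is matched against the end of the file name, so a compound suffix like `.cif.gz` is stripped in full before the last dot-separated token is taken as the id.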
rfd3/utils/vizualize.py
ADDED
@@ -0,0 +1,276 @@

````python
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../scripts/shebang/modelhub_exec.sh" "$0" "$@"'
"""
If you add `/path/to/visualize.py` to your .bashrc/.zshrc like this:

```bash
viz () {
    /path/to/visualize.py "$@"
}
```

Then you can run `viz /path/to/file.cif` to visualize the structures via pymol-remote.
"""

import pathlib
import sys

# NOTE: This is needed here to enable `viz` to be used as a script
if (project_dir := str(pathlib.Path(__file__).parents[3])) not in sys.path:
    sys.path.append(project_dir)

import logging

import biotite.structure as struc
import numpy as np
from atomworks.io.utils.visualize import get_pymol_session, view_pymol
from atomworks.ml.conditions import C_DIS, Condition
from atomworks.ml.conditions.base import Level

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


def get_atom_array_style_cmd(
    atom_array: struc.AtomArray | struc.AtomArrayStack,
    obj: str,
    label: bool = False,
    max_distances: int = 100,
    grid_slot: int | None = None,
) -> str:
    """Generate PyMOL commands to style an atom array visualization.

    Creates a series of PyMOL commands that style different parts of a molecular structure, including:
    - Applying a color spectrum to polymer chains
    - Showing backbone atoms as sticks and CA atoms as spheres
    - Styling non-polymer atoms with different colors and representations
    - Highlighting specially annotated atoms with different colors and visualizations

    Args:
        atom_array: The biotite AtomArray or AtomArrayStack to be styled
        obj: PyMOL object name to apply the styling to
        label: Whether to label all atoms with their 0-indexed atom_id
        max_distances: Maximum number of distance lines to show (PyMOL hangs when there are too many distance objects)
        grid_slot: Grid slot to place distance objects in (random if not given)

    Returns:
        str: A PyMOL command string that styles the atom array
    """
    grid_slot = grid_slot or np.random.randint(0, 10_000)
    commands = [f"hide everything, {obj}"]
    annotations = atom_array.get_annotation_categories()

    offset = 1  # ... default offset since PyMOL 1-indexes atom ids
    atom_ids = np.arange(1, atom_array.array_length() + 1)
    if "atom_id" in annotations:
        offset = 0
        atom_ids = atom_array.get_annotation("atom_id")

    # Convert MAS residues to ALA for compatibility (currently disabled)
    atom_array = atom_array.copy()
    # atom_array.res_name[atom_array.res_name == "MAS"] = "ALA"

    # Style the backbone for each polymer chain with a color spectrum
    for chain_id in struc.get_chains(atom_array):
        if (~atom_array.hetero[atom_array.chain_id == chain_id]).any():
            commands.append(
                f"spectrum resi, RFd_darkblue RFd_blue RFd_lightblue RFd_purple RFd_pink RFd_melon RFd_navaho, "
                f"{obj} and chain {chain_id} and elem C"
            )

    # Add basic styling commands for protein backbone and non-polymer components
    commands.extend(
        [
            f"show sticks, model {obj} and name n+c+ca+cb",
            f"show spheres, model {obj} and name ca",
            f"set sphere_scale, 0.23, model {obj} and name ca",
            f"set sphere_transparency, 0, model {obj} and name ca",
            f"color grey60, model {obj} and not polymer and elem C",
            f"show nb_spheres, model {obj} and not polymer",
            f"show sticks, model {obj} and not polymer",
        ]
    )
    if label:
        # Label 0-indexed for correspondence with biotite
        if offset == 0:
            commands.append("label all, ID")
        else:
            commands.append(f"label all, ID-{offset}")

    # Style atoms marked for "atomize" if present
    if "atomize" in annotations and atom_array.atomize.any():
        atomize_ids = np.where(atom_array.atomize)[0]
        atomize_ids = atom_ids[atomize_ids]
        commands.extend(
            [
                f"select {obj}_atomize, model {obj} and id {'+'.join(str(id) for id in atomize_ids)}",
                f"show sticks, {obj}_atomize and byres {obj}_atomize",
            ]
        )

    # Style constraints, if present:
    # 2-body:
    # ... add distance lines between atoms if specified in annotations
    if hasattr(atom_array, "_annot_2d"):
        distance_commands = []
        if C_DIS.full_name in atom_array._annot_2d:
            constraint_data = C_DIS.annotation(atom_array).as_array()
            if len(constraint_data) > 0:
                _atom_idxs = np.unique(constraint_data[:, :2].flatten()).astype(int)
                _atom_ids = atom_ids[_atom_idxs]
                _selection = f"{obj} and id {'+'.join(str(id) for id in _atom_ids)}"
                commands.extend(
                    [
                        f"delete m2d_{obj}",
                        f"select m2d_{obj}, {_selection}",
                        f"show spheres, m2d_{obj}",
                        f"set sphere_color, lime, m2d_{obj}",
                        f"set sphere_scale, 0.25, m2d_{obj}",
                        f"set sphere_transparency, 0.5, m2d_{obj}",
                        f"show sticks, byres m2d_{obj}",
                    ]
                )

                if len(constraint_data) > max_distances:
                    logger.warning(
                        f"Too many distance conditions ({len(constraint_data)}), sampling {max_distances}."
                    )
                    constraint_idxs = np.arange(len(constraint_data))
                    constraint_idxs = np.random.choice(
                        constraint_idxs, max_distances, replace=False
                    )
                    constraint_data = constraint_data[constraint_idxs]

                for row in constraint_data:
                    idx_i, idx_j, value = row
                    if (idx_i > idx_j) or (value == 0):
                        continue

                    i, j = atom_ids[idx_i], atom_ids[idx_j]
                    # ... if we have a stack, we grab the first frame for the distance computation
                    if isinstance(atom_array, struc.AtomArrayStack):
                        distance = struc.distance(
                            atom_array[0, idx_i], atom_array[0, idx_j]
                        )
                    else:
                        distance = struc.distance(atom_array[idx_i], atom_array[idx_j])

                    distance_name = f"d{idx_i}-{idx_j}_{value:.2f}_{distance:.2f}"
                    distance_commands.extend(
                        [
                            f"distance {distance_name}, model {obj} and id {i}, model {obj} and id {j}",
                            f"set grid_slot, {grid_slot}, {distance_name}",
                        ]
                    )

        commands.extend(distance_commands)

    # Handle 1-D conditions (only display masks)
    for cond in Condition:
        if cond.n_body == 1 and cond.mask(atom_array, default="generate").any():
            _atom_ids = atom_ids[np.where(cond.mask(atom_array, default="generate"))[0]]
            if cond.level == Level.ATOM:
                _selection = (
                    f"model {obj} and id {'+'.join(str(id) for id in _atom_ids)}"
                )
                commands.extend([f"select {cond.mask_name}_{obj}, {_selection}"])
            elif cond.level == Level.RESIDUE or cond.level == Level.TOKEN:
                _selection = f"model {obj} and byres (id {'+'.join(str(id) for id in _atom_ids)})"
                commands.extend([f"select {cond.mask_name}_{obj}, {_selection}"])

    return "\n".join(commands)


def viz(
    atom_array: struc.AtomArray | struc.AtomArrayStack,
    id: str = "obj",
    clear: bool = True,
    label: bool = True,
    max_distances: int = 100,
    view_ori_token: bool = False,
) -> None:
    """Quickly visualize a molecular structure in PyMOL with predefined styling.

    This function creates a PyMOL session, loads the atom array structure, and applies
    a set of styling commands to make the visualization informative and aesthetically pleasing.
    The styling highlights different structural components and annotated features.

    Args:
        atom_array: The biotite AtomArray or AtomArrayStack to visualize
        id: PyMOL object identifier (default: "obj")
        clear: Whether to clear existing PyMOL objects before visualization (default: True)
        label: Whether to label all atoms with their 0-indexed atom_id (default: True)
        max_distances: Maximum number of distance lines to show (default: 100)
        view_ori_token: Whether to view the ori token (default: False)

    Example:
        >>> from biotite.structure import AtomArray
        >>> # Create or load an atom array
        >>> structure = AtomArray(...)
        >>> # Visualize the structure in PyMOL
        >>> viz(structure)
    """
    atom_array = atom_array.copy()
    # PyMOL only considers the chain_id annotation, which can produce odd-looking artifacts when two
    # different chains share the same chain_id. We always disambiguate with the chain_iid annotation,
    # so we have PyMOL use it as well.
    if "chain_iid" in atom_array.get_annotation_categories():
        atom_array.chain_id = atom_array.get_annotation("chain_iid")
        atom_array.chain_id = atom_array.chain_id.astype(str)

    pymol = get_pymol_session()
    pymol.do("set valence, 1; set connect_mode, 2;")
    if clear:
        pymol.do("delete d*")
        pymol.delete("all")
    slot = np.random.randint(0, 10_000)
    obj_name = view_pymol(atom_array, id=id, grid_slot=slot)
    cmd = get_atom_array_style_cmd(
        atom_array, obj_name, label=label, grid_slot=slot, max_distances=max_distances
    )
    pymol.do(cmd)

    if view_ori_token:
        pymol.do(f"pseudoatom ori_{obj_name}, pos=[0,0,0]")
        pymol.do(
            [
                f"set grid_slot, {slot}, ori_{obj_name}",
                f"show spheres, ori_{obj_name}",
                f"set sphere_color, white, ori_{obj_name}",
                f"set sphere_scale, 0.5, ori_{obj_name}",
                f"set sphere_transparency, 0.5, ori_{obj_name}",
            ]
        )


def _viz_from_file(
    file_path: str,
    id: str = "obj",
    clear: bool = True,
    label: bool = True,
    max_distances: int = 100,
):
    if file_path.endswith(".pkl.gz"):
        import gzip
        import pickle

        with gzip.open(file_path, "rb") as f:
            atom_array = pickle.load(f)
    elif file_path.endswith(".pkl"):
        import pickle

        with open(file_path, "rb") as f:
            atom_array = pickle.load(f)
    elif file_path.endswith((".cif", ".cif.gz", ".bcif", ".bcif.gz")):
        from atomworks.io.utils.io_utils import get_structure, read_any
        from rfd3.utils.inference import (
            _add_design_annotations_from_cif_block_metadata,
        )

        cif_file = read_any(file_path)
        atom_array = get_structure(cif_file, include_bonds=True, extra_fields="all")
        atom_array = _add_design_annotations_from_cif_block_metadata(
            atom_array, cif_file.block
        )
    viz(atom_array, id=id, clear=clear, label=label, max_distances=max_distances)


if __name__ == "__main__":
    import fire

    fire.Fire(_viz_from_file)
````