rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rfd3/utils/io.py ADDED
@@ -0,0 +1,245 @@
+ import json
+ import re
+ from collections.abc import Callable
+ from os import PathLike
+ from pathlib import Path
+ from typing import Literal
+
+ import numpy as np
+ import torch
+ from atomworks.io.utils.io_utils import to_cif_file
+ from biotite.structure import AtomArray, AtomArrayStack, stack
+
+ from foundry.utils.alignment import weighted_rigid_align
+
+ DICTIONARY_LIKE_EXTENSIONS = {".json", ".yaml", ".yml", ".pkl"}
+ CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}
+
+
+ def dump_structures(
+     atom_arrays: AtomArrayStack | list[AtomArray] | AtomArray,
+     base_path: PathLike,
+     one_model_per_file: bool,
+     extra_fields: list[str] | Literal["all"] = [],
+     **kwargs,
+ ) -> None:
+     """Dump structures to CIF files, given the coordinates and input AtomArray.
+
+     Args:
+         atom_arrays (AtomArrayStack | list[AtomArray] | AtomArray): Either an AtomArrayStack, a list of AtomArray
+             objects, or a single AtomArray object to be dumped to CIF file(s).
+         base_path (PathLike): Base path where the output files will be saved.
+         one_model_per_file (bool): Flag to determine if each model should be dumped into a separate file. Has no
+             effect if `atom_arrays` is a list of AtomArrays.
+         extra_fields (list[str] | Literal["all"]): List of extra fields to include in the CIF file.
+     """
+     base_path = Path(base_path)
+     if one_model_per_file:
+         assert isinstance(
+             atom_arrays, (AtomArrayStack, list)
+         ), "AtomArrayStack or list of AtomArray required when one_model_per_file is True"
+         # One model per file -> loop over the diffusion batch
+         for i in range(len(atom_arrays)):
+             path = f"{base_path}_model_{i}"
+             to_cif_file(
+                 atom_arrays[i],
+                 path,
+                 file_type="cif.gz",
+                 include_entity_poly=False,
+                 extra_fields=extra_fields,
+                 **kwargs,
+             )
+     else:
+         # Include all models in a single CIF file
+         to_cif_file(
+             atom_arrays,
+             base_path,
+             file_type="cif.gz",
+             include_entity_poly=False,
+             extra_fields=extra_fields,
+             **kwargs,
+         )
+
+
+ def dump_metadata(
+     prediction_metadata: dict,
+     base_path: PathLike,
+     one_model_per_file: bool,
+ ):
+     """Dump JSONs of prediction metadata to disk.
+
+     Args:
+         prediction_metadata (dict): Dictionary containing metadata for the predictions.
+             Instantiated after the models' predictions in the trainer. Keys are model indices (from 0).
+         base_path (PathLike): Base path where the output files will be saved.
+         one_model_per_file (bool): If True, save each model's metadata in a separate file.
+     """
+     if one_model_per_file:
+         # One model per file -> loop over the diffusion batch
+         for i in prediction_metadata:
+             path = f"{base_path}_model_{i}"
+             with open(f"{path}.json", "w") as f:
+                 json.dump(prediction_metadata[i], f, indent=4)
+     else:
+         # Include all models in a single JSON file
+         with open(f"{base_path}.json", "w") as f:
+             json.dump(prediction_metadata, f, indent=4)
+
+
+ def dump_trajectories(
+     trajectory_list: list[torch.Tensor | np.ndarray],
+     atom_array: AtomArray,
+     base_path: Path,
+     align_structures: bool = True,
+     coord_atom_lvl_to_be_noised: torch.Tensor | None = None,
+     is_motif_atom_with_fixed_pos: torch.Tensor | None = None,
+     max_frames: int = 100,
+ ) -> None:
+     """Write denoising trajectories to CIF files.
+
+     Args:
+         trajectory_list (list[torch.Tensor | np.ndarray]): List of length n_steps holding the diffusion trajectory
+             at each step. Each tensor has shape [D, L, 3], where D is the diffusion batch size and L is the number
+             of atoms.
+         atom_array (AtomArray): Atom array corresponding to the coordinates.
+         base_path (Path): Base path where the output files will be saved.
+         align_structures (bool): Flag to determine if the structures should be aligned on the final prediction.
+             If False, each step may have a different alignment.
+         max_frames (int): Maximum number of frames to write per trajectory.
+     """
+     n_steps = len(trajectory_list)
+
+     if align_structures:
+         # ... align the trajectories on the last prediction
+         w_L = torch.ones(*trajectory_list[0].shape[:2]).to(trajectory_list[0].device)
+         X_exists_L = torch.ones(trajectory_list[0].shape[1], dtype=torch.bool).to(
+             trajectory_list[0].device
+         )
+         for step in range(n_steps - 1):
+             trajectory_list[step] = weighted_rigid_align(
+                 X_L=trajectory_list[-1],
+                 X_gt_L=trajectory_list[step],
+                 X_exists_L=X_exists_L,
+                 w_L=w_L,
+             )
+
+     # ... invert the list, to make the trajectory compatible with PyMOL (which builds the bond graph from the first frame)
+     trajectory_list = trajectory_list[::-1]
+
+     # ... select a subset of frames if necessary
+     if n_steps > max_frames:
+         selected_indices = torch.linspace(0, n_steps - 1, max_frames).long().tolist()
+         trajectory_list = [trajectory_list[i] for i in selected_indices]
+         n_steps = len(trajectory_list)
+
+     # ... iterate over the range of D (diffusion batch size; e.g., 5 during validation)
+     # (We want to convert `trajectory_list` to a list of length D where each item is a tensor of shape [n_steps, L, 3])
+     trajectories_split_by_model = []
+     for d in range(trajectory_list[0].shape[0]):
+         trajectory_for_single_model = torch.stack(
+             [trajectory_list[step][d] for step in range(n_steps)], dim=0
+         )
+         trajectories_split_by_model.append(trajectory_for_single_model)
+
+     # ... write the trajectories to CIF files, named by epoch, dataset, example_id, and model index (within the diffusion batch)
+     for i, trajectory in enumerate(trajectories_split_by_model):
+         if isinstance(trajectory, torch.Tensor):
+             trajectory = trajectory.cpu().numpy()
+         atom_array_stack = build_stack_from_atom_array_and_batched_coords(
+             trajectory, atom_array
+         )
+
+         path = f"{base_path}_model_{i}"
+         to_cif_file(
+             atom_array_stack, path, file_type="cif.gz", include_entity_poly=False
+         )
+
+
+ def build_stack_from_atom_array_and_batched_coords(
+     coords: np.ndarray | torch.Tensor,
+     atom_array: AtomArray,
+ ) -> AtomArrayStack:
+     """Build an AtomArrayStack from an AtomArray and a set of coordinates with a batch dimension.
+
+     Additionally, handles the case where the AtomArray contains multiple transformations and we must adjust
+     the chain_id.
+
+     Args:
+         coords (np.ndarray | torch.Tensor): The coordinates to be assigned to the AtomArrayStack. Must have
+             shape (n_batch, n_atoms, 3).
+         atom_array (AtomArray): The AtomArray to be stacked. Must have shape (n_atoms,).
+     """
+     if isinstance(coords, torch.Tensor):
+         coords = coords.cpu().numpy()
+
+     assert (
+         coords.shape[-2] == atom_array.array_length()
+     ), f"Number of batched coordinates {coords.shape} != {atom_array.array_length()}"
+
+     # (Diffusion batch size will become the number of models)
+     n_batch = coords.shape[0]
+
+     # Build the stack and assign the coordinates
+     atom_array_stack = stack([atom_array for _ in range(n_batch)])
+     atom_array_stack.coord = coords
+
+     # Adjust chain_id if there are multiple transformations
+     # (Otherwise, we will have ambiguous bond annotations, since only `chain_id` is used for the bond annotations)
+     if (
+         "transformation_id" in atom_array.get_annotation_categories()
+         and len(np.unique(atom_array_stack.transformation_id)) > 1
+     ):
+         new_chain_ids = np.char.add(
+             atom_array_stack.chain_id, atom_array_stack.transformation_id
+         )
+         atom_array_stack.set_annotation("chain_id", new_chain_ids)
+
+     return atom_array_stack
+
+
+ def find_files_with_extension(path: PathLike, supported_file_types: list) -> list[Path]:
+     """Find all files with the given extensions at the specified path.
+
+     Args:
+         path (PathLike): Path to the directory containing the files.
+         supported_file_types (list): List of supported file extensions.
+
+     Returns:
+         list[Path]: List of files with the given extensions.
+     """
+     files_with_supported_types = []
+     path = Path(path)
+
+     # Check if the path is a directory
+     if path.is_dir():
+         # Search for files with each supported extension
+         for file_type in supported_file_types:
+             files_with_supported_types.extend(path.glob(f"*{file_type}"))
+     elif path.is_file() and any(path.name.endswith(ext) for ext in supported_file_types):
+         # If it's a file with a supported extension, add it to the list
+         # (endswith handles compound extensions such as ".cif.gz", which `path.suffix` would not)
+         files_with_supported_types.append(path)
+
+     return files_with_supported_types
+
+
+ def create_example_id_extractor(
+     extensions: set | list = CIF_LIKE_EXTENSIONS,
+ ) -> Callable[[PathLike], str]:
+     """Create a closure that extracts example_ids from file paths with the specified extensions.
+
+     Example:
+         >>> extractor = create_example_id_extractor({".cif", ".cif.gz"})
+         >>> extractor("example.path.example_id.cif.gz")
+         'example_id'
+     """
+     pattern = re.compile(
+         "(" + "|".join(re.escape(ext) + "$" for ext in extensions) + ")"
+     )
+
+     def extract_id(file_path: PathLike) -> str:
+         """Extract example_id from a file path."""
+         # Remove the extension, then take the last dot-separated part of the name
+         without_ext = pattern.sub("", Path(file_path).name)
+         return without_ext.split(".")[-1]
+
+     return extract_id
+
+
+ def extract_example_id_from_path(file_path: PathLike, extensions: set | list) -> str:
+     """Extract example_id from a file path with the specified extensions."""
+     extractor = create_example_id_extractor(extensions)
+     return extractor(file_path)
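
Below is a minimal usage sketch of the I/O helpers above. It is illustrative rather than part of the package: the toy two-atom `AtomArray`, the random coordinates, and the `out/pred` output prefix are all assumptions (the output directory must exist, and `atomworks` must be installed for the CIF writing to work).

```python
# Minimal sketch (assumed inputs): broadcast D=3 coordinate sets over one
# AtomArray and write one gzipped CIF per model.
import biotite.structure as struc
import numpy as np

from rfd3.utils.io import (
    build_stack_from_atom_array_and_batched_coords,
    create_example_id_extractor,
    dump_structures,
)

# Toy 2-atom structure (stand-in for a real parsed prediction)
atom_array = struc.array(
    [
        struc.Atom([0.0, 0.0, 0.0], chain_id="A", res_id=1, res_name="GLY", atom_name="N", element="N"),
        struc.Atom([1.5, 0.0, 0.0], chain_id="A", res_id=1, res_name="GLY", atom_name="CA", element="C"),
    ]
)
coords = np.random.rand(3, atom_array.array_length(), 3)  # [D, L, 3]

# One AtomArrayStack with D models, then one file per model:
# out/pred_model_0.cif.gz, out/pred_model_1.cif.gz, out/pred_model_2.cif.gz
model_stack = build_stack_from_atom_array_and_batched_coords(coords, atom_array)
dump_structures(model_stack, "out/pred", one_model_per_file=True)

# Recover the example_id embedded in a structure filename:
extract_id = create_example_id_extractor({".cif", ".cif.gz"})
assert extract_id("run3.7abc.cif.gz") == "7abc"
```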
rfd3/utils/vizualize.py ADDED
@@ -0,0 +1,276 @@
+ #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../scripts/shebang/modelhub_exec.sh" "$0" "$@"'
+ """
+ If you add `/path/to/visualize.py` to your .bashrc/.zshrc like this:
+
+ ```bash
+ viz () {
+     /path/to/visualize.py "$@"
+ }
+ ```
+
+ Then you can run `viz /path/to/file.cif` to visualize the structures via pymol-remote.
+ """
+
+ import pathlib
+ import sys
+
+ # NOTE: This is needed here to enable `viz` to be used as a script
+ if (project_dir := str(pathlib.Path(__file__).parents[3])) not in sys.path:
+     sys.path.append(project_dir)
+
+ import logging
+
+ import biotite.structure as struc
+ import numpy as np
+ from atomworks.io.utils.visualize import get_pymol_session, view_pymol
+ from atomworks.ml.conditions import C_DIS, Condition
+ from atomworks.ml.conditions.base import Level
+
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.WARNING)
+
+
+ def get_atom_array_style_cmd(
+     atom_array: struc.AtomArray | struc.AtomArrayStack,
+     obj: str,
+     label: bool = False,
+     max_distances: int = 100,
+     grid_slot: int | None = None,
+ ) -> str:
+     """Generate PyMOL commands to style an atom array visualization.
+
+     Creates a series of PyMOL commands that style different parts of a molecular structure, including:
+     - Applying a color spectrum to polymer chains
+     - Showing backbone atoms as sticks and CA atoms as spheres
+     - Styling non-polymer atoms with different colors and representations
+     - Highlighting specially annotated atoms with different colors and visualizations
+
+     Args:
+         atom_array: The biotite AtomArray or AtomArrayStack to be styled
+         obj: PyMOL object name to apply the styling to
+         label: Whether to label all atoms with their 0-indexed atom_id
+         max_distances: Maximum number of distance lines to show (PyMOL hangs when there are too many distance objects)
+         grid_slot: PyMOL grid slot for the distance objects (a random slot is chosen if None)
+
+     Returns:
+         str: A PyMOL command string that styles the atom array
+     """
+     grid_slot = grid_slot or np.random.randint(0, 10_000)
+     commands = [f"hide everything, {obj}"]
+     annotations = atom_array.get_annotation_categories()
+
+     offset = 1  # ... default offset since PyMOL 1-indexes atom ids
+     atom_ids = np.arange(1, atom_array.array_length() + 1)
+     if "atom_id" in annotations:
+         offset = 0
+         atom_ids = atom_array.get_annotation("atom_id")
+
+     # Conversion of MAS residues to ALA for compatibility (currently disabled)
+     atom_array = atom_array.copy()
+     # atom_array.res_name[atom_array.res_name == "MAS"] = "ALA"
+
+     # Style the backbone for each polymer chain with a color spectrum
+     for chain_id in struc.get_chains(atom_array):
+         if (~atom_array.hetero[atom_array.chain_id == chain_id]).any():
+             commands.append(
+                 f"spectrum resi, RFd_darkblue RFd_blue RFd_lightblue RFd_purple RFd_pink RFd_melon RFd_navaho, "
+                 f"{obj} and chain {chain_id} and elem C"
+             )
+
+     # Add basic styling commands for protein backbone and non-polymer components
+     commands.extend(
+         [
+             f"show sticks, model {obj} and name n+c+ca+cb",
+             f"show spheres, model {obj} and name ca",
+             f"set sphere_scale, 0.23, model {obj} and name ca",
+             f"set sphere_transparency, 0, model {obj} and name ca",
+             f"color grey60, model {obj} and not polymer and elem C",
+             f"show nb_spheres, model {obj} and not polymer",
+             f"show sticks, model {obj} and not polymer",
+         ]
+     )
+     if label:
+         # Label atoms 0-indexed for correspondence with biotite
+         if offset == 0:
+             commands.append("label all, ID")
+         else:
+             commands.append(f"label all, ID-{offset}")
+
+     # Style atoms marked for "atomize" if present
+     if "atomize" in annotations and atom_array.atomize.any():
+         atomize_ids = np.where(atom_array.atomize)[0]
+         atomize_ids = atom_ids[atomize_ids]
+         commands.extend(
+             [
+                 f"select {obj}_atomize, model {obj} and id {'+'.join(str(id) for id in atomize_ids)}",
+                 f"show sticks, {obj}_atomize and byres {obj}_atomize",
+             ]
+         )
+
+     # Style constraints, if present:
+     # 2-body:
+     # ... add distance lines between atoms if specified in annotations
+     if hasattr(atom_array, "_annot_2d"):
+         distance_commands = []
+         if C_DIS.full_name in atom_array._annot_2d:
+             constraint_data = C_DIS.annotation(atom_array).as_array()
+             if len(constraint_data) > 0:
+                 _atom_idxs = np.unique(constraint_data[:, :2].flatten()).astype(int)
+                 _atom_ids = atom_ids[_atom_idxs]
+                 _selection = f"{obj} and id {'+'.join(str(id) for id in _atom_ids)}"
+                 commands.extend(
+                     [
+                         f"delete m2d_{obj}",
+                         f"select m2d_{obj}, {_selection}",
+                         f"show spheres, m2d_{obj}",
+                         f"set sphere_color, lime, m2d_{obj}",
+                         f"set sphere_scale, 0.25, m2d_{obj}",
+                         f"set sphere_transparency, 0.5, m2d_{obj}",
+                         f"show sticks, byres m2d_{obj}",
+                     ]
+                 )
+
+                 if len(constraint_data) > max_distances:
+                     logger.warning(
+                         f"Too many distance conditions ({len(constraint_data)}), sampling {max_distances}."
+                     )
+                     constraint_idxs = np.arange(len(constraint_data))
+                     constraint_idxs = np.random.choice(
+                         constraint_idxs, max_distances, replace=False
+                     )
+                     constraint_data = constraint_data[constraint_idxs]
+
+                 for row in constraint_data:
+                     idx_i, idx_j, value = row
+                     if (idx_i > idx_j) or (value == 0):
+                         continue
+
+                     i, j = atom_ids[idx_i], atom_ids[idx_j]
+                     # ... if we have a stack, we grab the first frame for the distance computation
+                     if isinstance(atom_array, struc.AtomArrayStack):
+                         distance = struc.distance(
+                             atom_array[0, idx_i], atom_array[0, idx_j]
+                         )
+                     else:
+                         distance = struc.distance(atom_array[idx_i], atom_array[idx_j])
+
+                     distance_name = f"d{idx_i}-{idx_j}_{value:.2f}_{distance:.2f}"
+                     distance_commands.extend(
+                         [
+                             f"distance {distance_name}, model {obj} and id {i}, model {obj} and id {j}",
+                             f"set grid_slot, {grid_slot}, {distance_name}",
+                         ]
+                     )
+
+         commands.extend(distance_commands)
+
+     # Handle 1-D conditions (only display masks)
+     for cond in Condition:
+         if cond.n_body == 1 and cond.mask(atom_array, default="generate").any():
+             _atom_ids = atom_ids[np.where(cond.mask(atom_array, default="generate"))[0]]
+             if cond.level == Level.ATOM:
+                 _selection = (
+                     f"model {obj} and id {'+'.join(str(id) for id in _atom_ids)}"
+                 )
+                 commands.extend([f"select {cond.mask_name}_{obj}, {_selection}"])
+             elif cond.level in (Level.RESIDUE, Level.TOKEN):
+                 _selection = f"model {obj} and byres (id {'+'.join(str(id) for id in _atom_ids)})"
+                 commands.extend([f"select {cond.mask_name}_{obj}, {_selection}"])
+
+     return "\n".join(commands)
+
+
+ def viz(
+     atom_array: struc.AtomArray | struc.AtomArrayStack,
+     id: str = "obj",
+     clear: bool = True,
+     label: bool = True,
+     max_distances: int = 100,
+     view_ori_token: bool = False,
+ ) -> None:
+     """Quickly visualize a molecular structure in PyMOL with predefined styling.
+
+     This function creates a PyMOL session, loads the atom array structure, and applies
+     a set of styling commands to make the visualization informative and aesthetically pleasing.
+     The styling highlights different structural components and annotated features.
+
+     Args:
+         atom_array: The biotite AtomArray or AtomArrayStack to visualize
+         id: PyMOL object identifier (default: "obj")
+         clear: Whether to clear existing PyMOL objects before visualization (default: True)
+         label: Whether to label all atoms with their 0-indexed atom_id (default: True)
+         max_distances: Maximum number of distance lines to show (default: 100)
+         view_ori_token: Whether to view the ori token (default: False)
+
+     Example:
+         >>> from biotite.structure import AtomArray
+         >>> # Create or load an atom array
+         >>> structure = AtomArray(...)
+         >>> # Visualize the structure in PyMOL
+         >>> viz(structure)
+     """
+     atom_array = atom_array.copy()
+     # PyMOL only considers the chain_id annotation, which can produce odd-looking artifacts if two different
+     # chains share the same chain_id. We always disambiguate via the chain_iid annotation, so have PyMOL use
+     # that as well.
+     if "chain_iid" in atom_array.get_annotation_categories():
+         atom_array.chain_id = atom_array.get_annotation("chain_iid")
+         atom_array.chain_id = atom_array.chain_id.astype(str)
+
+     pymol = get_pymol_session()
+     pymol.do("set valence, 1; set connect_mode, 2;")
+     if clear:
+         pymol.do("delete d*")
+         pymol.delete("all")
+     slot = np.random.randint(0, 10_000)
+     obj_name = view_pymol(atom_array, id=id, grid_slot=slot)
+     cmd = get_atom_array_style_cmd(
+         atom_array, obj_name, label=label, grid_slot=slot, max_distances=max_distances
+     )
+     pymol.do(cmd)
+
+     if view_ori_token:
+         pymol.do(f"pseudoatom ori_{obj_name}, pos=[0,0,0]")
+         pymol.do(
+             [
+                 f"set grid_slot, {slot}, ori_{obj_name}",
+                 f"show spheres, ori_{obj_name}",
+                 f"set sphere_color, white, ori_{obj_name}",
+                 f"set sphere_scale, 0.5, ori_{obj_name}",
+                 f"set sphere_transparency, 0.5, ori_{obj_name}",
+             ]
+         )
+
+
+ def _viz_from_file(
+     file_path: str,
+     id: str = "obj",
+     clear: bool = True,
+     label: bool = True,
+     max_distances: int = 100,
+ ):
+     """Load a structure from a pickle or CIF-like file and visualize it via `viz`."""
+     if file_path.endswith(".pkl.gz"):
+         import gzip
+         import pickle
+
+         with gzip.open(file_path, "rb") as f:
+             atom_array = pickle.load(f)
+     elif file_path.endswith(".pkl"):
+         import pickle
+
+         with open(file_path, "rb") as f:
+             atom_array = pickle.load(f)
+     elif file_path.endswith((".cif", ".cif.gz", ".bcif", ".bcif.gz")):
+         from atomworks.io.utils.io_utils import get_structure, read_any
+         from rfd3.utils.inference import (
+             _add_design_annotations_from_cif_block_metadata,
+         )
+
+         cif_file = read_any(file_path)
+         atom_array = get_structure(cif_file, include_bonds=True, extra_fields="all")
+         atom_array = _add_design_annotations_from_cif_block_metadata(
+             atom_array, cif_file.block
+         )
+     else:
+         raise ValueError(f"Unsupported file type: {file_path}")
+     viz(atom_array, id=id, clear=clear, label=label, max_distances=max_distances)
+
+
+ if __name__ == "__main__":
+     import fire
+
+     fire.Fire(_viz_from_file)
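
For completeness, a minimal sketch of driving this module from Python, assuming a pymol-remote session is already reachable (started on the PyMOL side) and that `design.cif.gz` exists; both are illustrative. Equivalently, with the shell alias from the module docstring, `viz design.cif.gz` dispatches through `fire.Fire(_viz_from_file)`.

```python
# Minimal sketch (assumed inputs): load a CIF and push it to a live PyMOL
# session with the styling defined above. "design.cif.gz" is a hypothetical path.
from atomworks.io.utils.io_utils import get_structure, read_any

from rfd3.utils.vizualize import viz

cif_file = read_any("design.cif.gz")
atom_array = get_structure(cif_file, include_bonds=True, extra_fields="all")

# Connects to the pymol-remote session, applies the spectrum/stick/sphere
# styling, and labels atoms with their 0-indexed biotite ids.
viz(atom_array, id="design", label=True, max_distances=50)
```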