rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rfd3/utils/inference.py
@@ -0,0 +1,648 @@
"""
Utilities for inference input preparation
"""

import logging
import os
from os import PathLike
from typing import Dict

import biotite.structure as struc
import numpy as np
from atomworks import parse
from atomworks.constants import STANDARD_AA, STANDARD_DNA
from atomworks.io.parser import (
    STANDARD_PARSER_ARGS,
)
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from atomworks.ml.preprocessing.utils.structure_utils import (
    get_atom_mask_from_cell_list,
)
from atomworks.ml.utils.token import (
    get_token_starts,
    spread_token_wise,
)
from rfd3.constants import (
    REQUIRED_CONDITIONING_ANNOTATIONS,
)
from rfd3.transforms.conditioning_base import (
    convert_existing_annotations_to_bool,
    set_default_conditioning_annotations,
)
from rfd3.transforms.conditioning_utils import sample_island_tokens

from foundry.common import exists
from foundry.utils.components import (
    fetch_mask_from_component,
    get_name_mask,
    unravel_components,
)
from foundry.utils.ddp import RankedLogger

logging.basicConfig(level=logging.INFO)
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

sequence_encoding = AF3SequenceEncoding()
_aa_like_res_names = sequence_encoding.all_res_names[sequence_encoding.is_aa_like]


#################################################################################
# Setter functions for annotations
#################################################################################

def set_common_annotations(array, set_src_component_to_res_name=True):
    annots = array.get_annotation_categories()
    if "occupancy" not in annots:
        array.set_annotation("occupancy", np.ones(array.shape[0], dtype=float))
    if "b_factor" not in annots:
        array.set_annotation("b_factor", np.zeros(array.shape[0], dtype=float))
    if "charge" not in annots:
        array.set_annotation("charge", np.zeros(array.shape[0], dtype=float))
    if "src_component" not in annots:
        if set_src_component_to_res_name:
            array.set_annotation(
                "src_component",
                np.full(
                    array.shape[0], array.res_name.copy(), dtype=array.res_name.dtype
                ),
            )
        else:
            array.set_annotation(
                "src_component", np.full(array.shape[0], "", dtype=array.res_name.dtype)
            )
    return array


def set_indices(array, chain, res_id_start, molecule_id, component):
    n = array.shape[0]
    array.chain_id = np.full(n, chain, dtype=array.chain_id.dtype)
    array.res_id = np.full(n, res_id_start + array.res_id - 1, dtype=array.res_id.dtype)
    array.molecule_id = np.full(n, molecule_id, dtype=np.int32)
    array.set_annotation(
        "src_component", np.full(n, component, dtype=array.chain_id.dtype)
    )
    return array


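# --- Editorial usage sketch (not part of the released file) --------------------
# The two setters above are typically chained when grafting a component into a
# growing assembly: `set_common_annotations` backfills annotation categories the
# parser may not provide, and `set_indices` renumbers the copy onto its target
# chain. A hypothetical renumbering of a ligand copy onto chain "B" starting at
# residue 10 might look like:
#
#   >>> lig = set_common_annotations(lig)
#   >>> lig = set_indices(lig, chain="B", res_id_start=10, molecule_id=1,
#   ...                   component="HEM")
#   >>> np.unique(lig.chain_id).tolist(), int(lig.res_id.min())
#   (['B'], 10)
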
#################################################################################
# Getters
#################################################################################


def extract_ligand_array(
    atom_array_input,
    ligand,
    fixed_atoms=None,
    set_defaults=True,
    additional_annotations=None,
):
    if not exists(atom_array_input):
        raise ValueError(
            "No input file/atom array provided. Cannot add requested ligand."
        )
    if fixed_atoms is None:  # avoid a mutable default argument
        fixed_atoms = {}

    ligand_arrays = []
    for lig in ligand.split(","):
        for name in unravel_components(
            lig, atom_array=atom_array_input, allow_multiple_matches=True
        ):  # additional nesting to allow multiple indices per ligand
            mask = fetch_mask_from_component(name, atom_array=atom_array_input)
            ligand_array = atom_array_input[mask].copy()

            # ... Set as fully fixed motif
            if set_defaults:
                ligand_array = set_default_conditioning_annotations(
                    ligand_array, motif=True, additional=additional_annotations
                )  # should be pre-set!
                ligand_array = set_common_annotations(ligand_array)

            # ... Unfix every atom name not listed for this ligand in fixed_atoms
            if lig in fixed_atoms or name in fixed_atoms:
                if (lig in fixed_atoms and name in fixed_atoms) and name != lig:
                    raise ValueError(
                        f"Got both ligand name and its pdb indices in fixed_atoms dictionary: {lig} and {name}. Please only provide one."
                    )
                fixed = fixed_atoms.get(lig, fixed_atoms.get(name, None))
                if fixed:
                    fixed_mask = get_name_mask(ligand_array.atom_name, fixed)
                    ligand_array.is_motif_atom_with_fixed_coord[~fixed_mask] = np.zeros(
                        np.sum(~fixed_mask), dtype=int
                    )
                else:
                    ligand_array.is_motif_atom_with_fixed_coord = np.zeros(
                        ligand_array.shape[0], dtype=int
                    )
            ligand_arrays.append(ligand_array)

    ligand_arrays = struc.concatenate(ligand_arrays)
    return ligand_arrays


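# --- Editorial usage sketch (not part of the released file) --------------------
# `fixed_atoms` maps a ligand name (or one of its unraveled component names) to
# the atom names whose coordinates stay fixed; every other atom of that ligand
# is unfixed. Assuming `arr` is an AtomArray containing a heme group, keeping
# only the iron fixed could look like:
#
#   >>> heme = extract_ligand_array(arr, "HEM", fixed_atoms={"HEM": ["FE"]})
#   >>> int(heme.is_motif_atom_with_fixed_coord.sum())  # only FE remains fixed
#   1
#
# An empty list (or None) for a ligand unfixes all of its atoms while still
# including it as a motif.
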
def extract_na_array(atom_array_input):
    # TODO: handle this more robustly, e.g. account for modified NA residues such as 5IU
    if (na_mask := np.isin(atom_array_input.res_name, list(STANDARD_DNA))).any():
        na_array = atom_array_input[na_mask]
        # ... replace chain_id "A" with any other available chain id
        Achain_mask = na_array.chain_id == "A"

        all_nonAchains = np.unique((atom_array_input + na_array).chain_id).tolist()
        if "A" in all_nonAchains:
            all_nonAchains.remove("A")

        if len(all_nonAchains) > 1:
            new_chain = "".join(all_nonAchains)  # join them all, so definitely unique
        elif len(all_nonAchains) == 1:
            new_chain = all_nonAchains[0] + all_nonAchains[0]
        else:
            new_chain = "B"

        na_array.chain_id[Achain_mask] = new_chain
        na_array = set_default_conditioning_annotations(na_array, motif=True)
        return na_array
    else:
        raise ValueError(
            "Could not find any NA tokens in input file, but requested to add all NA"
        )


def _restore_bonds_for_nonstandard_residues(
    atom_array_accum: struc.AtomArray,
    src_atom_array: struc.AtomArray | None,
    source_to_accum_idx: Dict[int, int],
) -> struc.AtomArray:
    """
    Restores and creates bonds for non-standard residues (PTMs, modified AAs, etc.)
    from the source structure and between consecutive residues.

    This function:
        1. Preserves inter-residue bonds from the source structure (if available)
        2. Adds backbone C-N bonds between consecutive residues where at least one is non-standard

    Args:
        atom_array_accum: The accumulated atom array to add bonds to
        src_atom_array: The source atom array containing original bond information
        source_to_accum_idx: Mapping from source atom indices to accumulated array indices

    Returns:
        atom_array_accum with bonds added
    """
    # Initialize bonds if needed
    if atom_array_accum.bonds is None:
        atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length())

    # Step 1: Restore inter-residue bonds from the source atom array (only for non-standard residues)
    if (
        src_atom_array is not None
        and hasattr(src_atom_array, "bonds")
        and src_atom_array.bonds is not None
    ):
        original_bonds = src_atom_array.bonds.as_array()
        if len(original_bonds) > 0:
            # Extract bonds where both atoms are in the accumulated array
            bonds_to_add = []
            for bond in original_bonds:
                atom_i, atom_j, bond_type = bond
                # Check if both atoms are in our mapping
                if (
                    int(atom_i) in source_to_accum_idx
                    and int(atom_j) in source_to_accum_idx
                ):
                    # Check if at least one atom is from a non-standard residue
                    src_res_i = src_atom_array[int(atom_i)].res_name
                    src_res_j = src_atom_array[int(atom_j)].res_name

                    # Only preserve if at least one residue is non-standard
                    if src_res_i not in STANDARD_AA or src_res_j not in STANDARD_AA:
                        new_i = source_to_accum_idx[int(atom_i)]
                        new_j = source_to_accum_idx[int(atom_j)]
                        bonds_to_add.append([new_i, new_j, int(bond_type)])

            if bonds_to_add:
                # Add the preserved bonds
                new_bonds = struc.BondList(
                    atom_array_accum.array_length(),
                    np.array(bonds_to_add, dtype=np.int64),
                )
                atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
                ranked_logger.info(
                    f"Preserved {len(bonds_to_add)} inter-residue bonds involving non-standard residues from source structure"
                )

    # Step 2: Add backbone bonds between consecutive residues where at least one is non-standard
    # This handles: PTM-to-diffused, diffused-to-PTM, PTM-to-PTM, ligand-to-protein
    bonds_to_add = []

    # Group by residue
    token_starts = get_token_starts(atom_array_accum, add_exclusive_stop=True)

    for i in range(
        len(token_starts) - 2
    ):  # -2 because we need pairs and token_starts has an exclusive stop
        curr_start, curr_end = token_starts[i], token_starts[i + 1]
        next_start, next_end = token_starts[i + 1], token_starts[i + 2]

        curr_residue = atom_array_accum[curr_start:curr_end]
        next_residue = atom_array_accum[next_start:next_end]

        # Check if at least one residue is non-standard (PTMs, modified AAs, etc.)
        curr_is_nonstandard = curr_residue.res_name[0] not in STANDARD_AA
        next_is_nonstandard = next_residue.res_name[0] not in STANDARD_AA

        # Only add bonds if at least one residue is non-standard
        if not (curr_is_nonstandard or next_is_nonstandard):
            continue

        # Check if consecutive in the same chain
        if curr_residue.chain_id[0] != next_residue.chain_id[0]:
            continue
        if next_residue.res_id[0] - curr_residue.res_id[0] != 1:
            continue

        # Find the C atom in the current residue (C-terminus connection point)
        c_mask = curr_residue.atom_name == "C"
        if not np.any(c_mask):
            # If a non-standard residue doesn't have a C atom, it can't connect to the next residue.
            # This is expected for some atomized residues or ligands at chain termini.
            if curr_is_nonstandard and next_is_nonstandard:
                # Both are non-standard but no C in current - might be an atomized region without proper termini
                ranked_logger.debug(
                    f"Non-standard residue {curr_residue.res_name[0]} (res_id {curr_residue.res_id[0]}) "
                    f"has no C atom - cannot form backbone bond to next residue"
                )
            continue
        c_idx = curr_start + np.where(c_mask)[0][0]

        # Find the N atom in the next residue (N-terminus connection point)
        n_mask = next_residue.atom_name == "N"
        if not np.any(n_mask):
            # If a non-standard residue doesn't have an N atom, it can't connect to the previous residue.
            # This is expected for some atomized residues or ligands at chain termini.
            if curr_is_nonstandard and next_is_nonstandard:
                # Both are non-standard but no N in next - might be an atomized region without proper termini
                ranked_logger.debug(
                    f"Non-standard residue {next_residue.res_name[0]} (res_id {next_residue.res_id[0]}) "
                    f"has no N atom - cannot form backbone bond from previous residue"
                )
            continue
        n_idx = next_start + np.where(n_mask)[0][0]

        # Check if this bond already exists (from source preservation)
        existing_bonds = atom_array_accum.bonds.as_array()
        bond_exists = False
        if len(existing_bonds) > 0:
            for existing_bond in existing_bonds:
                if (existing_bond[0] == c_idx and existing_bond[1] == n_idx) or (
                    existing_bond[0] == n_idx and existing_bond[1] == c_idx
                ):
                    bond_exists = True
                    break

        if not bond_exists:
            bonds_to_add.append([c_idx, n_idx, struc.BondType.SINGLE])

    if bonds_to_add:
        new_bonds = struc.BondList(
            atom_array_accum.array_length(), np.array(bonds_to_add, dtype=np.int64)
        )
        atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
        ranked_logger.info(
            f"Added {len(bonds_to_add)} backbone bonds involving non-standard residues"
        )

    return atom_array_accum


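# --- Editorial sketch (not part of the released file) --------------------------
# The Step 2 bookkeeping above reduces to collecting (C_idx, N_idx, SINGLE)
# triples and merging them into the accumulated BondList. In isolation, with
# biotite, that merge step looks like:
#
#   >>> import biotite.structure as struc
#   >>> import numpy as np
#   >>> bonds = struc.BondList(4)  # 4 atoms, no bonds yet
#   >>> new = struc.BondList(4, np.array([[2, 3, int(struc.BondType.SINGLE)]]))
#   >>> bonds = bonds.merge(new)
#   >>> bonds.as_array()  # -> [[2, 3, 1]]
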
#################################################################################
# File IO utilities
#################################################################################


def inference_load_(
    file: PathLike, *, assembly_id: str = "1", cif_parser_args: dict | None = None
):
    # Default cif_parser_args to an empty dictionary if not provided
    if cif_parser_args is None:
        cif_parser_args = {}

    # Convenience: default to loading from and saving to the cache if a cache_dir
    # is provided, unless explicitly overridden
    if "cache_dir" in cif_parser_args and cif_parser_args["cache_dir"]:
        cif_parser_args.setdefault("load_from_cache", True)
        cif_parser_args.setdefault("save_to_cache", True)

    merged_cif_parser_args = {
        **STANDARD_PARSER_ARGS,
        **{
            "fix_arginines": False,
            "add_missing_atoms": False,
            "remove_ccds": [],
        },
        **cif_parser_args,
    }
    merged_cif_parser_args["hydrogen_policy"] = "remove"

    # Ensure the required annotations can be loaded
    merged_cif_parser_args["extra_fields"] = list(
        set(
            merged_cif_parser_args.get("extra_fields", [])
            + REQUIRED_CONDITIONING_ANNOTATIONS
        )
    )

    # Use the parse function with the merged CIF parser arguments
    result_dict = parse(
        filename=file,
        build_assembly=(assembly_id,),  # pass as a tuple so the argument is hashable
        **merged_cif_parser_args,
    )

    atom_array = result_dict["assemblies"][assembly_id][0]
    atom_array = convert_existing_annotations_to_bool(atom_array)

    data = {
        "atom_array": atom_array,  # First model
        "chain_info": result_dict["chain_info"],
        "ligand_info": result_dict["ligand_info"],
        "metadata": result_dict["metadata"],
    }

    return data


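# --- Editorial usage sketch (not part of the released file) --------------------
# A hypothetical load of the first bioassembly of a local mmCIF, with caching
# enabled (which also switches on load_from_cache/save_to_cache by default):
#
#   >>> data = inference_load_(
#   ...     "input/target.cif",
#   ...     assembly_id="1",
#   ...     cif_parser_args={"cache_dir": "/tmp/atomworks_cache"},
#   ... )
#   >>> atom_array = data["atom_array"]  # first model of the assembly
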
def ensure_input_is_abspath(args: dict, path: PathLike | None):
    """
    Ensures that the input source in `args` is an absolute path; if it is
    relative, converts it to an absolute path relative to `path`.

    Args:
        args: Inference specification for the atom array
        path: None, or the file that the input path is relative to
    """
    if isinstance(args, str):
        raise ValueError(
            "Expected args to be a dictionary, got a string: {}. If you are using an input JSON ensure it contains dictionaries of arguments".format(
                args
            )
        )
    if "input" not in args or not exists(args["input"]):
        return args
    input_path = args["input"]
    if not os.path.isabs(input_path):
        input_path = os.path.abspath(os.path.join(os.path.dirname(path), input_path))
        ranked_logger.info(
            f"Input source path is relative, converted to absolute path: {input_path}"
        )
        args["input"] = input_path
    return args


#################################################################################
# Custom infer_ori functions
#################################################################################


def infer_ori_from_hotspots(atom_array: struc.AtomArray):
    assert (
        "is_atom_level_hotspot" in atom_array.get_annotation_categories()
    ), "Atom array must contain 'is_atom_level_hotspot' annotation to infer ori from hotspots."
    hotspot_atom_array = atom_array[atom_array.is_atom_level_hotspot.astype(bool)]
    hotspot_com = hotspot_atom_array.coord.mean(axis=0)

    # We can only perform distance computations on atoms with non-NaN coordinates
    nan_coords_mask = np.any(np.isnan(atom_array.coord), axis=1)
    non_nan_atom_array = atom_array[~nan_coords_mask]

    # Perform the distance computation
    # RFD2 used 10 Angstroms instead of 12, but that was for residue-level hotspots
    DISTANCE_CUTOFF = 12.0
    cell_list = struc.CellList(non_nan_atom_array, cell_size=DISTANCE_CUTOFF)
    nearby_atoms_mask = get_atom_mask_from_cell_list(
        hotspot_atom_array.coord,
        cell_list,
        len(non_nan_atom_array),
        cutoff=DISTANCE_CUTOFF,
    )  # (n_query, n_cell_list)

    nearby_atoms_mask = np.any(nearby_atoms_mask, axis=0)  # (n_cell_list,)
    nearby_atoms_com = non_nan_atom_array.coord[nearby_atoms_mask].mean(axis=0)

    vector_from_core_to_hotspot = hotspot_com - nearby_atoms_com
    vector_from_core_to_hotspot = vector_from_core_to_hotspot / np.linalg.norm(
        vector_from_core_to_hotspot
    )

    # This follows RFD2. Both this and the distance cutoff should eventually be configurable with defaults
    DISTANCE_ABOVE_HOTSPOTS = 10.0
    ori_token = hotspot_com + DISTANCE_ABOVE_HOTSPOTS * vector_from_core_to_hotspot

    return ori_token


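# --- Editorial note (not part of the released file) ----------------------------
# The returned token is `hotspot_com + 10 Å * u`, where `u` is the unit vector
# pointing from the local core (COM of all atoms within 12 Å of any hotspot
# atom) toward the hotspot COM, i.e. the origin is placed "above" the hotspot
# patch on its exposed side. For example, with hotspot_com = (0, 0, 5) and
# nearby_atoms_com = (0, 0, 0), u = (0, 0, 1) and the ori lands at (0, 0, 15).
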
def infer_ori_from_com(atom_array):
    xyz = atom_array.coord
    mask = np.isfinite(xyz).all(axis=-1)  # Ensure no NaN coordinates
    com = np.mean(xyz[..., mask, :], axis=0)
    return com


# This can't go in constants.py because that leads to a circular dependency
INFER_ORI_STRATEGIES_TO_FUNCTIONS = {
    "hotspots": infer_ori_from_hotspots,
    "com": infer_ori_from_com,
}
"""
Constant mapping from infer_ori_strategy keys to the corresponding functions. These functions should take an AtomArray
as input and return a three-element list or numpy array of floats.
"""


def set_com(
    atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None
):
    if exists(ori_token):
        center = np.array([float(x) for x in ori_token], dtype=atom_array.coord.dtype)
        atom_array.coord = atom_array.coord - center
        ranked_logger.info(f"Received ori_token argument. Setting origin as {center}.")
        if infer_ori_strategy is not None:
            ranked_logger.warning(
                f"Specified infer_ori_strategy of '{infer_ori_strategy}' will be ignored because an ori_token was provided."
            )
    elif "ORI" in atom_array.res_name:
        center = atom_array[atom_array.res_name == "ORI"].coord
        if center.shape[0] != 1:
            # np.random.choice requires a 1-D array, so sample a row index instead
            center = center[np.random.choice(center.shape[0], size=1, replace=False)]
            ranked_logger.info(f"Found multiple ORI tokens in input. Sampled: {center}")
        center = np.nan_to_num(center.squeeze())
        atom_array.coord = atom_array.coord - center
        ranked_logger.info(
            f"Found ORI token in input. Setting origin as token value ({center})."
        )
        if infer_ori_strategy is not None:
            ranked_logger.warning(
                f"Specified infer_ori_strategy of '{infer_ori_strategy}' will be ignored because an ORI token was found in the input."
            )
    elif infer_ori_strategy is not None:
        if infer_ori_strategy in INFER_ORI_STRATEGIES_TO_FUNCTIONS:
            center = INFER_ORI_STRATEGIES_TO_FUNCTIONS[infer_ori_strategy](atom_array)
            atom_array.coord = atom_array.coord - center
            ranked_logger.info(
                f"Inferred origin using strategy '{infer_ori_strategy}'. Setting origin as {center}."
            )
        else:
            ranked_logger.warning(
                f"Unknown infer_ori_strategy '{infer_ori_strategy}'; expected one of "
                f"{list(INFER_ORI_STRATEGIES_TO_FUNCTIONS)}. No origin offset applied."
            )
    else:
        # No offset
        if np.any(atom_array.is_motif_atom_with_fixed_coord.astype(bool)):
            center = np.nan_to_num(
                np.mean(
                    atom_array.coord[
                        atom_array.is_motif_atom_with_fixed_coord.astype(bool)
                    ],
                    axis=0,
                )
            )
            ranked_logger.info(
                f"No ori_token or infer_ori_strategy provided. Setting origin as COM of fixed motif ({center})."
            )
            atom_array.coord -= center
        else:
            ranked_logger.warning(
                "No ori_token, infer_ori_strategy, or motif provided. Setting [0,0,0] as origin."
            )
            atom_array.coord = np.zeros_like(
                atom_array.coord, dtype=atom_array.coord.dtype
            )
    return atom_array


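# --- Editorial usage sketch (not part of the released file) --------------------
# Origin resolution precedence: an explicit `ori_token` wins, then an "ORI"
# pseudo-residue in the input, then `infer_ori_strategy`, then the COM of the
# fixed motif, and finally all-zero coordinates. E.g., recentering on an
# explicit point:
#
#   >>> arr = set_com(arr, ori_token=[10.0, 0.0, -4.5])
#
# versus inferring the origin from hotspot geometry:
#
#   >>> arr = set_com(arr, infer_ori_strategy="hotspots")
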
#################################################################################
# Custom conditioning functions
#################################################################################


def spoof_helical_bundle_ss_conditioning_fn(atom_array: struc.AtomArray):
    # NOTE: This assumes that all diffused residues are protein residues -- should be updated if that changes!
    # Compute islands within the subset that is diffused and has secondary structure types.
    token_level_array = atom_array[get_token_starts(atom_array)]
    is_diffused_atom_token_level = ~(
        token_level_array.is_motif_atom_with_fixed_coord.astype(bool)
    )

    # Sampling island lengths from 3-7 avoids restricting the model too heavily, since the conditioning
    # is indexed to specific residues and the model will likely extend helices to reasonable lengths
    # once it has started them.
    where_to_show_helices = sample_island_tokens(
        is_diffused_atom_token_level.sum(),
        island_len_min=3,
        island_len_max=7,
        n_islands_min=1,
        n_islands_max=3,
        max_length=None,
    )

    # Convert this to a mask over the entire token-level atom array
    token_level_helix_mask = np.zeros(token_level_array.array_length(), dtype=bool)
    token_level_helix_mask[is_diffused_atom_token_level] = where_to_show_helices

    # Avoid sampling very near the chain tails, as this gets too restrictive for the model
    for chain_id in np.unique(token_level_array.chain_id):
        chain_mask = token_level_array.chain_id == chain_id
        chain_indices = np.where(chain_mask)[0]
        chain_start, chain_end = chain_indices[0], chain_indices[-1] + 1
        chain_length = chain_mask.sum()

        buffer_length = chain_length // 8
        buffer_mask = chain_mask.copy()
        buffer_mask[chain_start + buffer_length : chain_end - buffer_length] = False

        token_level_helix_mask[buffer_mask] = False

    helix_conditioning = np.zeros(atom_array.array_length())
    helix_condition_mask = spread_token_wise(atom_array, token_level_helix_mask)

    helix_conditioning[helix_condition_mask] = 1
    return helix_conditioning


#################################################################################
# Patching of bad inputs
#################################################################################


def generate_idealized_cb_position(
    N: np.ndarray, Ca: np.ndarray, C: np.ndarray
) -> np.ndarray:
    """
    Generate Cb coordinates given (N, Ca, C), as if the given coordinates were from an idealized alanine.

    Args:
        - N (np.ndarray): coordinates of (pseudo) N atoms [..., L, 3]
        - Ca (np.ndarray): coordinates of (pseudo) Ca atoms [..., L, 3]
        - C (np.ndarray): coordinates of (pseudo) C atoms [..., L, 3]

    Returns:
        Cb (np.ndarray): coordinates of (pseudo) Cb atoms [..., L, 3]
        These will be placed at the idealized Cb distance (based on ALA) from Ca, assuming a frame of the following form:
            - x-axis: along the Ca-C bond
            - z-axis: perpendicular to the Ca-N-C plane, right-handed wrt the (Ca-C) & (Ca-N) vectors.
            - y-axis: in the plane of the Ca-N-C bonds, such that the overall frame is right-handed.

    Reference:
        - https://github.com/google-deepmind/alphafold/blob/d95a92aae161240b645fc10e9d030443011d913e/alphafold/common/residue_constants.py#L126-L335
        ALA:
            ['N',  0, (-0.525,  1.363,  0.000)],  # ca-n bond dist: 1.4606142543
            ['CA', 0, ( 0.000,  0.000,  0.000)],
            ['C',  0, ( 1.526,  0.000,  0.000)],  # ca-c bond dist: 1.526
            ['CB', 0, (-0.529, -0.774, -1.205)],  # cb-ca bond dist: 1.5267422834
    """
    if np.linalg.norm(N) == 0 and np.linalg.norm(C) == 0 and np.linalg.norm(Ca) == 0:
        return np.zeros_like(N)

    def _safe_normalize(vec: np.ndarray) -> np.ndarray:
        vec = np.asarray(vec, dtype=float)
        norms = np.linalg.norm(vec, axis=-1, keepdims=True)
        norms = np.where(norms == 0, 1.0, norms)
        return vec / norms

    normalize = _safe_normalize

    # ... get local frame x-axis
    to_C = C - Ca
    frame_x = normalize(to_C)

    # ... get local frame z-axis
    to_N = N - Ca
    to_out_of_plane = np.cross(frame_x, normalize(to_N), axis=-1)
    frame_z = normalize(to_out_of_plane)

    # ... get local frame y-axis
    frame_y = normalize(np.cross(frame_z, frame_x, axis=-1))

    # ... place the virtual Cb at the idealized ALA offset in this frame
    Cb = Ca + (-0.529 * frame_x - 0.774 * frame_y - 1.205 * frame_z)
    return Cb


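# --- Editorial check (not part of the released file) ---------------------------
# Feeding the idealized ALA backbone from the reference table back through the
# frame construction reproduces the tabulated CB offset, since the local frame
# is then exactly the identity:
#
#   >>> N = np.array([-0.525, 1.363, 0.000])
#   >>> Ca = np.array([0.000, 0.000, 0.000])
#   >>> C = np.array([1.526, 0.000, 0.000])
#   >>> generate_idealized_cb_position(N, Ca, C)
#   array([-0.529, -0.774, -1.205])
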
def create_cb_atoms(array):
    # Takes an array of length 4 with atoms N, CA, C, O.
    # Returns an array with a single, ideally placed CB atom.
    if array.atom_name.tolist() != ["N", "CA", "C", "O"]:
        raise ValueError(
            "Input array must contain exactly 4 atoms: N, CA, C, O. Got: {}".format(
                array.atom_name.tolist()
            )
        )
    cb_atoms = array[array.atom_name == "CA"].copy()
    cb_atoms.atom_name = np.array(["CB"], dtype=cb_atoms.atom_name.dtype)
    cb_pos = generate_idealized_cb_position(
        array.coord[array.atom_name == "N"].squeeze(),
        array.coord[array.atom_name == "CA"].squeeze(),
        array.coord[array.atom_name == "C"].squeeze(),
    )
    cb_atoms.coord = cb_pos[None]
    return cb_atoms


def create_o_atoms(array):
    # Takes an array of length 3 with atoms N, CA, C.
    # Returns an array with a single placeholder O atom placed at the C coordinate.
    if array.atom_name.tolist() != ["N", "CA", "C"]:
        raise ValueError(
            "Input array must contain exactly 3 atoms: N, CA, C. Got: {}".format(
                array.atom_name.tolist()
            )
        )

    o_atoms = array[array.atom_name == "CA"].copy()
    o_atoms.atom_name = np.array(["O"], dtype=o_atoms.atom_name.dtype)
    o_atoms.element = np.array(["O"], dtype=o_atoms.element.dtype)
    o_atoms.coord = array.coord[array.atom_name == "C"].squeeze()[None]

    return o_atoms
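
A minimal end-to-end sketch of how these utilities compose during input preparation (editorial; file paths and residue selections are hypothetical, and the conditioning annotations are assumed to be present on the loaded array):

    data = inference_load_("input/target.cif", assembly_id="1")
    arr = data["atom_array"]

    # Keep a heme group as motif, with only the iron coordinate fixed
    heme = extract_ligand_array(arr, "HEM", fixed_atoms={"HEM": ["FE"]})

    # Patch a backbone-only residue (N, CA, C) with a placeholder O and an ideal CB
    backbone = arr[arr.res_id == 42]
    backbone = backbone + create_o_atoms(backbone)   # -> N, CA, C, O
    backbone = backbone + create_cb_atoms(backbone)  # -> N, CA, C, O, CB

    # Recenter on the COM of the fixed motif (the default when no ori is given)
    arr = set_com(arr)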