rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,717 @@
1
+ import copy
2
+ import functools
3
+ import logging
4
+ from os import PathLike
5
+
6
+ import biotite.structure as struc
7
+ import numpy as np
8
+ from atomworks.constants import STANDARD_AA
9
+ from atomworks.io.utils.io_utils import to_cif_file
10
+ from atomworks.ml.encoding_definitions import AF3SequenceEncoding
11
+ from atomworks.ml.utils.token import (
12
+ get_token_starts,
13
+ )
14
+ from rfd3.constants import (
15
+ INFERENCE_ANNOTATIONS,
16
+ OPTIONAL_CONDITIONING_VALUES,
17
+ REQUIRED_INFERENCE_ANNOTATIONS,
18
+ )
19
+ from rfd3.inference.symmetry.symmetry_utils import (
20
+ center_symmetric_src_atom_array,
21
+ make_symmetric_atom_array,
22
+ )
23
+ from rfd3.transforms.conditioning_base import (
24
+ check_has_required_conditioning_annotations,
25
+ convert_existing_annotations_to_bool,
26
+ get_motif_features,
27
+ set_default_conditioning_annotations,
28
+ )
29
+ from rfd3.transforms.util_transforms import assign_types_
30
+ from rfd3.utils.inference import (
31
+ create_cb_atoms,
32
+ create_o_atoms,
33
+ extract_ligand_array,
34
+ inference_load_,
35
+ set_com,
36
+ set_common_annotations,
37
+ set_indices,
38
+ )
39
+
40
+ from foundry.common import exists
41
+ from foundry.utils.components import (
42
+ fetch_mask_from_component,
43
+ fetch_mask_from_idx,
44
+ get_design_pattern_with_constraints,
45
+ get_motif_components_and_breaks,
46
+ get_name_mask,
47
+ split_contig,
48
+ )
49
+ from foundry.utils.ddp import RankedLogger
50
+
51
# Configure root logging at INFO level when this module is imported.
# NOTE(review): calling basicConfig from a library module affects the whole
# process's logging configuration; kept because callers may rely on it.
logging.basicConfig(level=logging.INFO)
# Module logger. With `rank_zero_only=True` it presumably only emits on rank 0
# in distributed runs — TODO confirm against foundry.utils.ddp.RankedLogger.
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

# Shared AF3 sequence-encoding instance.
sequence_encoding = AF3SequenceEncoding()
# Residue names that the encoding flags as amino-acid-like.
# NOTE(review): not referenced elsewhere in this module; may be imported by others.
_aa_like_res_names = sequence_encoding.all_res_names[sequence_encoding.is_aa_like]
56
+
57
+
58
def assert_non_intersecting_contigs(indexed_components, unindexed_components):
    """Check that no motif residue is requested as both indexed and unindexed.

    Only motif components (those whose string form starts with a letter, e.g.
    "A20") are checked; numeric components such as diffused-segment lengths may
    legitimately appear in both lists.

    Args:
        indexed_components: components parsed from the indexing contig,
            e.g. [2, "A20", "A21", 2, "/0", 3].
        unindexed_components: components requested to be unindexed, e.g. ["A20"].

    Raises:
        AssertionError: if a motif component appears in both lists. Raised
            explicitly so the check survives ``python -O``.
    """
    overlapping = [
        component
        for component in unindexed_components
        # str(...)[:1] so integer components don't fail on subscripting and empty
        # strings don't raise IndexError; mirrors the motif check used elsewhere.
        if str(component)[:1].isalpha() and component in indexed_components
    ]
    if overlapping:
        raise AssertionError(
            "Unindexed residues must not be part of the indexing contig. "
            "got: {} and {} (overlap: {})".format(
                unindexed_components, indexed_components, overlapping
            )
        )
70
+
71
+
72
def set_atom_level_argument(atom_array, args, name: str):
    """Set the per-atom annotation ``name`` from a component-keyed specification.

    Every atom starts at ``OPTIONAL_CONDITIONING_VALUES[name]`` (``np.nan`` when
    the condition has no registered default). Atoms selected by ``args`` are then
    overwritten with the requested value.

    Args:
        atom_array: atom array to annotate. It is modified in place and returned.
        args: mapping ``{component: {atom_names: value}}``. ``component`` is
            resolved with ``fetch_mask_from_component``. ``atom_names`` is a
            comma-separated list of atom names, e.g. ``{"A20": {"N,CA": 1}}``;
            whitespace around each name is ignored. ``None`` leaves every atom at
            the default value.
        name: annotation category to set.

    Returns:
        The same ``atom_array`` with the ``name`` annotation set.

    Raises:
        AssertionError: if the number of atoms matched in a component differs from
            the number of requested names. Raised explicitly so it survives ``-O``.
    """
    default_value = OPTIONAL_CONDITIONING_VALUES.get(name, np.nan)
    atom_values = np.full(atom_array.array_length(), default_value)

    if args is not None:
        for component_name, atom_spec in args.items():
            component_mask = fetch_mask_from_component(
                component_name, atom_array=atom_array
            )
            for names, value in atom_spec.items():
                requested = [atom_name.strip() for atom_name in names.split(",")]
                mask = component_mask & np.isin(atom_array.atom_name, requested)
                if mask.sum() != len(requested):
                    available = atom_array.atom_name[component_mask]
                    missing = sorted(set(requested) - set(available.tolist()))
                    raise AssertionError(
                        f"Not all atoms in {names} found in {component_name}: "
                        f"missing {missing}, available {available}"
                    )
                # Boolean-mask assignment; equivalent to indexing with np.arange()[mask].
                atom_values[mask] = value

    atom_array.set_annotation(name, atom_values)
    return atom_array
95
+
96
+
97
def fetch_motif_residue_(
    src_chain,
    src_resid,
    *,
    components,
    src_atom_array,
    redesign_motif_sidechains,
    unindexed_components,
    unfixed_sequence_components,
    fixed_atoms,
    unfix_all,
    flexible_backbone,
    unfix_residues,
):
    """
    Given source chain and resid, returns the residue if present in the source atom array,
    annotated with its motif conditioning flags.

    NB: For glycines, we extend the array with a CB position so as to not leak whether
    the original residue is a glycine if sequence is masked during inference.

    Arguments:
        - src_chain, src_resid: chain ID and residue ID of the motif residue.
        - components: full contig component list (only used in the error message).
        - src_atom_array: atom array to fetch the residue from; must not be None.
        - redesign_motif_sidechains: bool, or a list of residue keys (e.g. "A20") whose
          sidechains are rebuilt as ALA. Only applies to residues in STANDARD_AA that
          have no `fixed_atoms` entry.
        - unindexed_components: residue keys to mark as unindexed.
        - unfixed_sequence_components: residue keys whose sequence is not fixed (may be None).
        - fixed_atoms: dict mapping residue key -> atom selection passed to `get_name_mask`.
        - unfix_all / unfix_residues: unfix the coordinates of all / of the listed residues.
        - flexible_backbone: flag N, CA, C, O as flexible motif atoms.

    Returns:
        - The residue sub-array restricted to its motif atoms, with res_id reset to 1.

    Raises:
        - AssertionError: if `src_atom_array` is None.
        - ValueError: if the residue selection is empty, or (when redesigning the
          sidechain) the residue has fewer than 3 atoms.
    """
    residue_key = f"{src_chain}{src_resid}"

    # Explicit raise (not `assert`) so the check survives `python -O`. The message
    # previously formatted the builtin `input` function; report the motif instead.
    if src_atom_array is None:
        raise AssertionError(
            "Motif provided in contigs, but no input provided. motif={} contig={}".format(
                residue_key, components
            )
        )

    # ... Fetch residue in the input atom array
    mask = fetch_mask_from_idx(residue_key, atom_array=src_atom_array)
    subarray = src_atom_array[mask]
    if subarray.shape[0] == 0:
        raise ValueError(f"Motif residue {residue_key} not found in input atom array.")
    res_name = subarray.res_name[0]

    # A list of residue keys means "redesign sidechains only for these residues"
    if isinstance(redesign_motif_sidechains, list):
        redesign_motif_sidechains = residue_key in redesign_motif_sidechains

    # Assign base properties
    subarray = set_default_conditioning_annotations(
        subarray, motif=True, unindexed=False, dtype=int
    )  # all values init to True (fix all)

    # Assign is motif atom and sequence
    if exists(atoms := fixed_atoms.get(residue_key)):
        atom_mask = get_name_mask(subarray.atom_name, atoms, res_name)
        subarray.set_annotation("is_motif_atom", atom_mask)

    elif redesign_motif_sidechains and res_name in STANDARD_AA:
        n_atoms = subarray.shape[0]
        diffuse_oxygen = False
        if n_atoms < 3:
            raise ValueError(
                f"Not enough data for {residue_key} in input atom array."
            )
        if n_atoms == 3:
            # Handle cases with N, CA, C only; add an O that will be generated.
            subarray = subarray + create_o_atoms(subarray.copy())
            diffuse_oxygen = True  # flag oxygen for generation

        # Keep only the backbone atoms (N, CA, C, O)
        subarray = subarray[np.isin(subarray.atom_name, ["N", "CA", "C", "O"])]

        # Place CB at an idealized position and rename to ALA so the padded atoms
        # diffused from the fixed backbone sit on the CB and do not leak the
        # identity of the residue.
        subarray = subarray + create_cb_atoms(subarray.copy())
        subarray.res_name = np.full_like(
            subarray.res_name, "ALA", dtype=subarray.res_name.dtype
        )
        # Only the leading backbone atoms are motif atoms; an O created above is
        # excluded so it is generated instead.
        subarray.set_annotation(
            "is_motif_atom",
            (
                np.arange(subarray.shape[0], dtype=int) < (4 - int(diffuse_oxygen))
            ).astype(int),
        )
        subarray.set_annotation(
            "is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
        )

    # Coordinates: fixed for motif atoms unless unfixed globally or per residue
    if unfix_all or residue_key in unfix_residues:
        subarray.set_annotation(
            "is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
        )
    else:
        subarray.set_annotation(
            "is_motif_atom_with_fixed_coord", subarray.is_motif_atom.copy()
        )

    if flexible_backbone:
        backbone_atoms = ["N", "CA", "C", "O"]
        subarray.set_annotation(
            "is_flexible_motif_atom",
            np.isin(subarray.atom_name, backbone_atoms),
        )
    else:
        subarray.set_annotation(
            "is_flexible_motif_atom", np.zeros(subarray.shape[0], dtype=bool)
        )

    if residue_key in unindexed_components:
        subarray.set_annotation(
            "is_motif_atom_unindexed", subarray.is_motif_atom.copy()
        )

    # Subset to desired motif atoms
    subarray = subarray[subarray.is_motif_atom.astype(bool)]

    # ... Relax sequence constraint if provided
    if (
        exists(unfixed_sequence_components)
        and residue_key in unfixed_sequence_components
    ):
        ranked_logger.info(
            "Unfixing sequence for motif {}{}".format(src_chain, src_resid)
        )
        subarray.set_annotation(
            "is_motif_atom_with_fixed_seq",
            np.zeros(subarray.shape[0], dtype=int),
        )

    # ... Double check that required annotations are set within the scope of this function only
    check_has_required_conditioning_annotations(subarray)
    subarray = set_common_annotations(subarray)
    subarray.set_annotation("res_id", np.full(subarray.shape[0], 1))  # Reset to 1
    return subarray
226
+
227
+
228
def create_diffused_residues_(n):
    """Create ``n`` placeholder ALA residues to be diffused.

    Each residue has five atoms named N, CA, C, O, CB (elements N, C, C, O, C),
    all at the origin, with ``res_id`` running from 1 to ``n``. Non-motif default
    conditioning annotations and the common annotations are then applied.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"Negative/null residue count ({n}) not allowed.")

    # Five placeholder atoms per residue, residue-major order (matches the
    # per-residue name/element pattern assigned below).
    atoms = [
        struc.Atom(
            np.array([0.0, 0.0, 0.0], dtype=np.float32),
            res_name="ALA",
            res_id=idx,
        )
        for idx in range(1, n + 1)
        for _ in range(5)
    ]
    array = struc.array(atoms)
    array.set_annotation(
        "element", np.array(["N", "C", "C", "O", "C"] * n, dtype="<U2")
    )
    array.set_annotation(
        "atom_name", np.array(["N", "CA", "C", "O", "CB"] * n, dtype="<U2")
    )
    array = set_default_conditioning_annotations(array, motif=False)
    array = set_common_annotations(array)
    return array
256
+
257
+
258
def accumulate_components(
    components,
    src_atom_array,
    redesign_motif_sidechains,
    unindexed_components: list[str],
    unfixed_sequence_components: list[str],
    breaks: list,
    fixed_atoms: dict,
    unfix_all: bool,
    optional_conditions: list[str],
    flexible_backbone: bool,
    *,
    start_chain="A",
    unfix_residues: list[str],
    start_resid=1,
):
    """
    Accumulate contig components into a single atom array.

    Components have three types: the end of a chain ("/0"), a motif residue
    (e.g. "A20" or "A21"), or an integer giving the number of diffused residues to create.

    Arguments:
        - components: list of components, e.g. [2, "A20", "A21", 2, "A25", 3, "A30", "/0", 3]
        - src_atom_array: the source atom array to fetch motifs from, or None if no input is provided.
        - redesign_motif_sidechains: bool, or list of motif residues whose sidechains are diffused
        - unindexed_components: list of components to unindex e.g. ["A20", "A21"]
        - unfixed_sequence_components: motif residues whose sequence is not fixed
        - breaks: per-component flags aligned with `components` (None means no breaks).
          A truthy entry on a motif marks it as an unindexed-motif breakpoint; from the
          first such breakpoint on, numeric components only advance the residue counter.
        - fixed_atoms: dictionary of fixed atoms for each component (previously called `contig_atoms`)
        - unfix_all: whether to fully unfix the motif coordinates
        - optional_conditions: annotation names initialised on diffused residues with their
          `OPTIONAL_CONDITIONING_VALUES` entry
        - flexible_backbone: whether to allow flexible backbone for motifs
        - start_chain: chain ID of the first chain; the chain is reset to it at the first breakpoint
        - unfix_residues: list of residues to unfix. Overrides `unfix_all` for specific residues.
        - start_resid: residue ID assigned to the first component

    Returns:
        - Accumulated atom array with components, and is_motif labels

    Raises:
        - ValueError: if the components produce no residues at all.
    """
    # ... Create component assignment functions
    fetch_motif_residue = functools.partial(
        fetch_motif_residue_,
        components=components,
        src_atom_array=src_atom_array,
        redesign_motif_sidechains=redesign_motif_sidechains,
        unindexed_components=unindexed_components,
        unfixed_sequence_components=unfixed_sequence_components,
        fixed_atoms=fixed_atoms,
        unfix_all=unfix_all,
        flexible_backbone=flexible_backbone,
        unfix_residues=unfix_residues,
    )

    # ... For loop accum variables
    breaks = [None] * len(components) if breaks is None else breaks
    # once one unindexed component is added, stop adding diffused residues
    unindexed_components_started = False
    atom_array_accum = []
    chain = start_chain
    res_id = start_resid
    molecule_id = 0

    # Insert contig information one by one
    for component, is_break in zip(components, breaks):
        if component == "/0":
            # reset iterators on next chain
            chain = chr(ord(chain) + 1)
            molecule_id += 1
            res_id = 1
            continue

        # Create array to insert
        if str(component)[0].isalpha():  # motif (e.g. "A22")
            atom_array_insert = fetch_motif_residue(*split_contig(component))
            n = 1
            if exists(is_break) and is_break:
                if not unindexed_components_started:
                    chain = start_chain
                    unindexed_components_started = True
                atom_array_insert.set_annotation(
                    "is_motif_atom_unindexed_motif_breakpoint",
                    np.ones(atom_array_insert.shape[0], dtype=int),
                )
        else:
            n = int(component)
            if n == 0 or unindexed_components_started:
                res_id += n
                continue
            atom_array_insert = create_diffused_residues_(n)
            for key in optional_conditions:
                atom_array_insert.set_annotation(
                    key,
                    np.full(
                        atom_array_insert.array_length(),
                        OPTIONAL_CONDITIONING_VALUES[key],
                        dtype=int,
                    ),
                )

        # ... Set index of insertion
        atom_array_insert = set_indices(
            array=atom_array_insert,
            chain=chain,
            res_id_start=res_id,
            molecule_id=molecule_id,
            component=component,
        )

        assert (
            len(get_token_starts(atom_array_insert)) == n
        ), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(atom_array_insert))} in \n{atom_array_insert}"

        # ... Insert & Increment residue ID
        atom_array_accum.append(atom_array_insert)
        res_id += n

    if not atom_array_accum:
        raise ValueError(
            f"No residues were created from components {components}; "
            "check the contig/length specification."
        )

    atom_array_accum = struc.concatenate(atom_array_accum)
    atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)

    # Shift res_id of unindexed residues past the indexed ones to avoid duplicates
    is_unindexed = atom_array_accum.is_motif_atom_unindexed.astype(bool)
    if is_unindexed.any() and not is_unindexed.all():
        max_id = np.max(atom_array_accum.res_id[~is_unindexed])
        min_id_udx = np.min(atom_array_accum.res_id[is_unindexed])
        atom_array_accum.res_id[is_unindexed] += max_id - min_id_udx + 1

    return atom_array_accum
393
+
394
+
395
+ #################################################################################
396
+ # Custom conditioning functions
397
+ #################################################################################
398
+
399
+
400
def create_atom_array_from_design_specification_legacy(
    *,
    # Specification args:
    input: PathLike = None,
    length: str = "100-300",
    contig: str = None,
    fixed_atoms: dict = None,
    unindex: str = None,
    unfix_sequence: str = None,
    redesign_motif_sidechains: bool = False,
    unfix_all=False,
    unfix_specific: str = None,
    flexible_backbone: bool = False,
    # Args for biomolecular design (Enzymes, DNA/PNA):
    ligand: str = None,
    ori_token: list[float] = None,
    infer_ori_strategy: str | None = None,
    atomwise_rasa: dict = None,
    atomwise_hbond: dict = None,
    # Additional args:
    out_path=None,
    cif_parser_args=None,
    # PPI Kwargs
    atom_level_hotspots: dict | None = None,
    # SS conditioning kwargs
    is_helix: dict | None = None,
    is_sheet: dict | None = None,
    is_loop: dict | None = None,
    spoof_helical_bundle_ss_conditioning: bool = False,
    symmetry: dict = None,
    # Low-temperature global conditioning args
    plddt_enhanced: bool = True,
    is_non_loopy: bool | None = None,
    # Partial diff args:
    partial_t: float | None = None,  # Optional noise scale for partial diffusion
    **_,  # dump additional args
):
    """
    Create pre-pipeline CIF file.

    Arguments:
        - input: path to input pdb containing coordinate data
        - contig: your typical contig string '10-10,A20-21,5-5,A25-25,5-5,A30-30,10-10'.
        - unindex: string of residue indices to unindex, e.g. "A20,A21" or "A20-21". Note the latter will be treated as two contiguous
            residues whereas the former will end up as two uncorrelated residues.
        - unfix_sequence: contig string of components to unfix sequence for.
        - unfix_specific: comma separated residues to unfix coordinates for. "ALL" to unfix every motif.
        - length: required total length (optional)
        - ligand: name of ligand to keep from input pdb, or path to a cif file containing the ligand
        - ori_token: coordinates for origin relative to coordinates in input file.
        - infer_ori_strategy: string argument controlling how the ori token is inferred if not otherwise specified.
            If None, the ori token will be set to the COM of the motif, or to [0,0,0] for unconditional generation.
            Currently supported strategies:
            - "hotspots": move 10A along an outward normal vector from the COM of the hotspots.
        - partial_t: if set, partially diffuse the protein atoms of `input` (requires `input`).
        - out_path: if set, the resulting atom array is also written there as a CIF file.

    Returns:
        - atom_array with all required conditioning annotations set appropriately.

    Raises:
        - ValueError: if neither `input` nor `contig` / `length` is provided, or if
          `partial_t` is provided without `input`.
        - NotImplementedError: if `spoof_helical_bundle_ss_conditioning` is requested.
    """
    # `spoof_helical_bundle_ss_conditioning_fn` is neither defined nor imported in this
    # module, so this option used to fail late with a NameError. Fail early and explicitly.
    if spoof_helical_bundle_ss_conditioning:
        raise NotImplementedError(
            "spoof_helical_bundle_ss_conditioning is not supported by the legacy input parser."
        )

    # 1) Load input data if provided
    if exists(input):
        atom_array_input = inference_load_(input, cif_parser_args=cif_parser_args)[
            "atom_array"
        ]
        # If we are doing symmetric design, we need to center the full input atom array at the origin (for getting symmetry frames)
        if exists(symmetry) and symmetry.get("id"):
            atom_array_input = center_symmetric_src_atom_array(atom_array_input)
    elif exists(contig) or exists(length):
        atom_array_input = None
    else:
        raise ValueError("Either 'input' or 'contig' / 'length' must be provided.")
    if isinstance(length, int):
        length = f"{length}-{length}"

    # Bound on every path (previously unbound on some, causing a NameError in the
    # partial-diffusion branch); only unconditional generation sets it to False.
    indexed_components_provided = True
    if exists(length) and not exists(contig):
        # Handle cases where contigs aren't specified.
        # NOTE(review): `flexible_backbone` defaults to False, so `not exists(flexible_backbone)`
        # is only true when None is passed explicitly (if `exists` is an `is not None` check).
        # Confirm whether `not flexible_backbone` was intended.
        if not exists(unindex) and not exists(flexible_backbone):
            if exists(fixed_atoms):
                # ensure that fixed atoms are in the input, else raise error
                for key in fixed_atoms.keys():
                    fetch_mask_from_component(key, atom_array=atom_array_input)
            ranked_logger.warning(
                "No input contig specified and no motif, running unconditional generation"
            )
            indexed_components_provided = False
            contig = length
    if not exists(fixed_atoms):
        fixed_atoms = {}

    optional_conditions = []
    if exists(atomwise_rasa):
        set_atom_level_argument(atom_array_input, atomwise_rasa, "rasa_bin")
        optional_conditions.append("rasa_bin")
    if exists(atomwise_hbond):
        for key, value in atomwise_hbond.items():
            set_atom_level_argument(atom_array_input, value, key)
            optional_conditions.append(key)
    if exists(atom_level_hotspots):
        set_atom_level_argument(
            atom_array_input, atom_level_hotspots, "is_atom_level_hotspot"
        )
        optional_conditions.append("is_atom_level_hotspot")

    # 2) Parse contigs into components
    indexed_components = get_design_pattern_with_constraints(
        contig, length
    )  # e.g. [2, A20, A21, 2, A25, 3, A30, /0, 3]

    # Parse redesign_motif_sidechains if necessary
    if isinstance(redesign_motif_sidechains, str):
        redesign_motif_sidechains = get_design_pattern_with_constraints(
            redesign_motif_sidechains
        )

    # ... Add unindexed components
    unindexed_components, unindexed_breaks = (
        get_motif_components_and_breaks(unindex) if exists(unindex) else ([], [])
    )
    breaks = [None] * len(indexed_components) + unindexed_breaks
    assert_non_intersecting_contigs(indexed_components, unindexed_components)

    # ... Expand unfix_sequence into components
    unfixed_sequence_components = (
        get_design_pattern_with_constraints(unfix_sequence) if unfix_sequence else []
    )

    # Determine which residues to unfix
    unfix_residues = []
    if isinstance(unfix_specific, list):
        unfix_residues = [str(u) for u in unfix_specific]
    elif isinstance(unfix_specific, str):
        if unfix_specific.upper() == "ALL":
            unfix_all = True
        elif unfix_specific:
            unfix_residues, _ = get_motif_components_and_breaks(
                unfix_specific, index_all=True
            )

    # 3) Create atom array from components
    if exists(partial_t):
        if atom_array_input is None:
            raise ValueError(
                "Partial diffusion (partial_t) requires an 'input' structure."
            )
        ranked_logger.info("Using partial diffusion with t=%s", partial_t)
        atom_array = assign_types_(copy.deepcopy(atom_array_input))
        atom_array = atom_array[atom_array.is_protein]

        # Set the whole thing without constraints
        atom_array = set_default_conditioning_annotations(
            atom_array, motif=False, unindexed=False
        )
        atom_array = set_common_annotations(
            atom_array, set_src_component_to_res_name=False
        )

        # Fix parts in the atom array as fixed components
        set_default_conditioning_annotations(atom_array, motif=False, unindexed=False)
        if indexed_components and indexed_components_provided:
            for component in indexed_components:
                if str(component)[0].isalpha():
                    mask = fetch_mask_from_component(component, atom_array=atom_array)

                    # Set the component as a motif token
                    set_default_conditioning_annotations(
                        atom_array, motif=True, unindexed=False, mask=mask
                    )

                    # Set the fixed atoms of the component
                    if mask.any():
                        if component in fixed_atoms:
                            atom_mask = get_name_mask(
                                atom_array.atom_name[mask],
                                fixed_atoms[component],
                                atom_array.res_name[mask][0],
                            )
                            # If specific fixed atoms are selected, set fixed coordinates to those specified
                            atom_array.is_motif_atom_with_fixed_coord[mask] = atom_mask
                        else:
                            # Otherwise fix the entire token.
                            atom_array.is_motif_atom_with_fixed_coord[mask] = 1

        # Append unindexed components from input specification
        if unindexed_components:
            start_resid = np.max(atom_array.res_id) + 1
            # HACK: set chain ID for unindexed residues as whatever the input has
            start_chain = atom_array.chain_id[0]
            atom_array_append = accumulate_components(
                components=unindexed_components,
                breaks=unindexed_breaks,
                src_atom_array=atom_array_input,
                redesign_motif_sidechains=redesign_motif_sidechains,
                unindexed_components=unindexed_components,
                unfixed_sequence_components=unfixed_sequence_components,
                fixed_atoms=fixed_atoms,
                unfix_all=unfix_all,
                optional_conditions=optional_conditions,
                flexible_backbone=flexible_backbone,
                unfix_residues=unfix_residues,
                start_chain=start_chain,
                start_resid=start_resid,
            )
            atom_array = atom_array + atom_array_append
    else:
        atom_array = accumulate_components(
            components=indexed_components + unindexed_components,
            src_atom_array=atom_array_input,
            redesign_motif_sidechains=redesign_motif_sidechains,
            unindexed_components=unindexed_components,
            unfixed_sequence_components=unfixed_sequence_components,
            breaks=breaks,
            fixed_atoms=fixed_atoms,
            unfix_all=unfix_all,
            optional_conditions=optional_conditions,
            flexible_backbone=flexible_backbone,
            unfix_residues=unfix_residues,
        )

    # Spoof assignments for is_motif_token
    f = get_motif_features(atom_array)
    atom_array.set_annotation("is_motif_token", f["is_motif_token"].astype(int))
    atom_array.set_annotation("is_motif_atom", f["is_motif_atom"].astype(int))

    # ... If ligand, post-pend it
    if exists(ligand):
        ligand_array = extract_ligand_array(
            atom_array_input,
            ligand,
            fixed_atoms,
            additional_annotations=set(
                list(atom_array.get_annotation_categories())
                + list(atom_array_input.get_annotation_categories())
                + ["is_motif_atom", "is_motif_token"]
            ),
        )
        # Number ligand residues after the last residue already present
        ligand_array.res_id = (
            ligand_array.res_id
            - np.min(ligand_array.res_id)
            + np.max(atom_array.res_id)
            + 1
        )
        atom_array = atom_array + ligand_array

    # ... Apply symmetry if it exists ahead of any other processing
    if exists(symmetry) and symmetry.get("id"):
        atom_array = make_symmetric_atom_array(
            atom_array, symmetry, sm=ligand, src_atom_array=atom_array_input
        )

    # ... Input frame and ORI token handling
    if exists(partial_t):
        # For symmetric structures, avoid COM centering that would collapse chains
        if exists(symmetry) and symmetry.get("id"):
            ranked_logger.info(
                "Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
            )
        else:
            atom_array = set_com(atom_array, ori_token=None, infer_ori_strategy="com")
        atom_array.set_annotation(
            "partial_t", np.full(atom_array.shape[0], partial_t, dtype=float)
        )
    else:
        atom_array = set_com(
            atom_array, ori_token=ori_token, infer_ori_strategy=infer_ori_strategy
        )
        # diffused atoms initialized at origin
        atom_array.coord[~atom_array.is_motif_atom_with_fixed_coord.astype(bool), :] = (
            0.0
        )

    # SS annotations cover the diffused regions, so they are added after accumulate_components
    if exists(is_helix):
        set_atom_level_argument(atom_array, is_helix, "is_helix")
        optional_conditions.append("is_helix")
    if exists(is_sheet):
        set_atom_level_argument(atom_array, is_sheet, "is_sheet")
        optional_conditions.append("is_sheet")
    if exists(is_loop):
        set_atom_level_argument(atom_array, is_loop, "is_loop")
        optional_conditions.append("is_loop")

    # 0 = unconditioned; on diffused tokens, 1 / -1 request non-loopy / loopy
    is_non_loopy_annot = np.zeros(atom_array.array_length(), dtype=int)
    diffused_region_mask = ~(atom_array.is_motif_token.astype(bool))
    if exists(is_non_loopy):
        is_non_loopy_annot[diffused_region_mask] = 1 if is_non_loopy else -1

    atom_array.set_annotation("is_non_loopy", is_non_loopy_annot)
    atom_array.set_annotation("is_non_loopy_atom_level", is_non_loopy_annot)

    if plddt_enhanced:
        atom_array.set_annotation(
            "ref_plddt", np.ones((atom_array.array_length(),), dtype=int)
        )

    # Ensure correct annotations before saving
    check_has_required_conditioning_annotations(
        atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
    )
    convert_existing_annotations_to_bool(atom_array)

    if "atom_id" in atom_array.get_annotation_categories():
        ranked_logger.info("Removing atom_id annotation...")
        atom_array.del_annotation("atom_id")

    if out_path is not None:
        to_cif_file(atom_array, out_path, extra_fields=INFERENCE_ANNOTATIONS)

    return atom_array