rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,1123 @@
+ import copy
+ import json
+ import logging
+ import os
+ import time
+ import warnings
+ from contextlib import contextmanager
+ from typing import Any, Dict, List, Optional, Union
+
+ import numpy as np
+ from atomworks.constants import STANDARD_AA
+ from atomworks.io.parser import parse_atom_array
+
+ # from atomworks.ml.datasets.datasets import BaseDataset
+ from atomworks.ml.transforms.base import TransformedDict
+ from atomworks.ml.utils.token import (
+     get_token_starts,
+ )
+ from biotite import structure as struc
+ from biotite.structure import AtomArray, BondList, get_residue_starts
+ from pydantic import (
+     BaseModel,
+     ConfigDict,
+     Field,
+     model_validator,
+ )
+ from rfd3.constants import (
+     INFERENCE_ANNOTATIONS,
+     REQUIRED_CONDITIONING_ANNOTATION_VALUES,
+     REQUIRED_INFERENCE_ANNOTATIONS,
+ )
+ from rfd3.inference.legacy_input_parsing import (
+     create_atom_array_from_design_specification_legacy,
+ )
+ from rfd3.inference.parsing import InputSelection
+ from rfd3.inference.symmetry.symmetry_utils import (
+     SymmetryConfig,
+     center_symmetric_src_atom_array,
+     make_symmetric_atom_array,
+ )
+ from rfd3.transforms.conditioning_base import (
+     check_has_required_conditioning_annotations,
+     convert_existing_annotations_to_bool,
+     get_motif_features,
+     set_default_conditioning_annotations,
+ )
+ from rfd3.transforms.util_transforms import assign_types_
+ from rfd3.utils.inference import (
+     _restore_bonds_for_nonstandard_residues,
+     extract_ligand_array,
+     inference_load_,
+     set_com,
+     set_common_annotations,
+     set_indices,
+ )
+
+ from foundry.common import exists
+ from foundry.utils.components import (
+     get_design_pattern_with_constraints,
+     get_motif_components_and_breaks,
+ )
+ from foundry.utils.ddp import RankedLogger
+
+ logging.basicConfig(level=logging.DEBUG)
+
+ logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ #################################################################################
+ # Custom infer_ori functions
+ #################################################################################
+
+
+ class LegacySpecification(BaseModel):
+     """Legacy specification for compatibility with legacy input parsing."""
+
+     model_config = ConfigDict(
+         arbitrary_types_allowed=True,
+         extra="allow",
+     )
+
+     def build(self, *args, **kwargs):
+         """Build atom array using legacy input parsing."""
+         atom_array = create_atom_array_from_design_specification_legacy(
+             **self.model_dump(),
+         )
+         return atom_array, self.model_dump()
+
+     def to_pipeline_input(self, example_id):
+         atom_array, spec_dict = self.build(return_metadata=True)
+
+         # ... Forward into the pipeline
+         data = prepare_pipeline_input_from_atom_array(atom_array)
+         data["example_id"] = example_id
+
+         # ... Wrap up with additional features
+         if "extra" not in spec_dict:
+             spec_dict["extra"] = {}
+         spec_dict["extra"]["example_id"] = example_id
+         data["specification"] = spec_dict
+         return data
+
+
+ # ========================================================================
+ # Input specification
+ # ========================================================================
+
+
+ class DesignInputSpecification(BaseModel):
+     """Validated and parsed input specification before resolution."""
+
+     model_config = ConfigDict(
+         hide_input_in_errors=False,
+         arbitrary_types_allowed=True,
+         validate_assignment=False,
+         str_strip_whitespace=True,
+         str_min_length=1,
+         extra="forbid",
+     )
+     # fmt: off
+     # ========================================================================
+     # Data inputs, motif generation & selection
+     # ========================================================================
+     # Data inputs
+     atom_array_input: Optional[AtomArray] = Field(None, description="Loaded atom array", exclude=True)
+     input: Optional[str] = Field(None, description="Path to input PDB/CIF file")
+     # Motif selection from input file
+     contig: Optional[InputSelection] = Field(None, description="Contig specification string (e.g. 'A1-10,B1-5')")
+     unindex: Optional[InputSelection] = Field(None,
+         description="Unindexed components string (components must not overlap with contig). "
+         "E.g. 'A15-20,B6-10' or dict. We recommend specifying a string.")
+     # Extra args:
+     length: Optional[str] = Field(None, description="Length range as 'min-max' or int. Constrains length of contig if provided")
+     ligand: Optional[str] = Field(None, description="Ligand name or index to include in design.")
+     cif_parser_args: Optional[Dict[str, Any]] = Field(None, description="CIF parser arguments")
+     extra: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Extra metadata to include in output (useful for logging additional info in metadata)")
+     dialect: int = Field(2, description="RFdiffusion3 input dialect. 1: legacy, 2: release.")
+
+     # ========================================================================
+     # Conditioning
+     # ========================================================================
+     # Sequence and coordinate conditioning
+     select_fixed_atoms: Optional[InputSelection] = Field(None,
+         description='''Atoms to fix coordinates for. Examples:
+         - True (default when inputs provided): All atoms pulled from the input are fixed in 3d space
+         - False: All atoms pulled from the input are unfixed in 3d space
+         - ContigStr: Components to fix in 3d space, e.g. "A1-10,B1-3" fixes residues 1-10 in chain A and residues 1-3 in chain B.
+         - {"A1": "N,CA,C,O,CB,CG", "A2-10": "BKBN"} fixes the listed atoms for residue 1 and the backbone for residues 2-10 in chain A.
+         '''.replace('\t\t', '\t')
+     )
+     select_unfixed_sequence: Optional[InputSelection] = Field(None, description='''Components to unfix sequence for.
+         - False (default when inputs provided): All atoms pulled from the input keep fixed sequences.
+         - True: All atoms pulled from the input have diffused sequences.
+         - ContigStr: Components to unfix sequence for, e.g. "A5-10,B1-3" unfixes sequence for residues 5-10 in chain A and residues 1-3 in chain B.
+         - Dictionary: Allowed but not recommended.
+         NOTE: Excludes ligands (ligands / DNA always have fixed sequence).
+         '''.replace('\t\t', '\t')
+     )
+     # Assignments of conditioning annotations
+     # RASA accessibility
+     select_buried: Optional[InputSelection] = Field(None, description="Selection of RASA buried conditioning")
+     select_partially_buried: Optional[InputSelection] = Field(None, description="Selection of RASA partially buried conditioning")
+     select_exposed: Optional[InputSelection] = Field(None, description="Selection of RASA exposed conditioning")
+     # Hotspots & Hbonds
+     select_hbond_acceptor: Optional[InputSelection] = Field(None, description="Atom-wise hydrogen bond acceptor")
+     select_hbond_donor: Optional[InputSelection] = Field(None, description="Atom-wise hydrogen bond donor")
+     select_hotspots: Optional[InputSelection] = Field(None, description="Atom-level or token-level hotspots for PPI")
+     redesign_motif_sidechains: Union[bool, str] = Field(False,
+         description="Perform fixed-backbone sequence design when 'contig' is provided. Changes the default behaviour when not using `select_fixed_atoms`."
+     )
+
+     # ========================================================================
+     # Global conditioning & symmetry
+     # ========================================================================
+     # Symmetry
+     symmetry: Optional[SymmetryConfig] = Field(None, description="Symmetry specification, see docs/symmetry.md")
+     # Centering & COM guidance
+     ori_token: Optional[list[float]] = Field(None, description="Origin coordinates")
+     infer_ori_strategy: Optional[str] = Field(None, description="Strategy for inferring origin; `com` or `hotspots`")
+     # Additional global conditioning
+     plddt_enhanced: Optional[bool] = Field(True, description="Enable pLDDT enhancement")
+     is_non_loopy: Optional[bool] = Field(None, description="Non-loopy conditioning")
+     # Partial diffusion
+     partial_t: Optional[float] = Field(None, ge=0.0, description="Angstroms of noise to add for partial diffusion (None turns off partial diffusion), t <= 15 recommended.")
+     # fmt: on
+
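A sketch of how a specification might be constructed, assuming the package is installed; the field values and file path below are hypothetical:

spec = DesignInputSpecification(
    input="complex.pdb",   # hypothetical input; parsed by the `load_input` validator below
    contig="A1-10,B1-5",   # indexed motif selection, coerced to an InputSelection
    length="80-120",       # constrains contig expansion
    ligand="HEM",          # appended after accumulation by `_append_ligand`
)
atom_array = spec.build()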
+     # ========================================================================
+     # Properties
+     # ========================================================================
+
+     @property
+     def is_partial_diffusion(self) -> bool:
+         """Whether partial diffusion is enabled."""
+         return exists(self.partial_t)
+
+     # ========================================================================
+     # Loading / saving
+     # ========================================================================
+
+     @classmethod
+     def from_json(cls, path):
+         with open(path, "r") as f:
+             data = json.load(f)
+         return cls(**data)
+
+     @classmethod
+     def from_rfd3_out(cls, path: str):
+         """Load from path to rfd3 outputs, either .cif, .cif.gz, .json or denoised / noisy trajectory files"""
+         path = path.replace(".cif.gz", ".cif").replace(".cif", ".json")
+         if not os.path.exists(path):
+             raise FileNotFoundError(f"Output file not found at {path}")
+         with open(path, "r") as f:
+             data = json.load(f)
+         if "input_specification" in data:
+             spec_args = data["input_specification"]
+             return cls(**spec_args)
+         else:
+             raise ValueError(f"No input specification found in json output: {path}")
+
+     def get_dict_to_save(self, exclude_extra: bool = False) -> dict:
+         # Returns dictionary for saving (reproducible) outputs to json
+         return self.model_dump(
+             exclude_defaults=True,
+             exclude={"atom_array_input"} | set({"extra"} if exclude_extra else {}),
+         )
+
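Because `get_dict_to_save` drops defaults and the loaded atom array, a saved specification can be re-instantiated from its own JSON dump. A minimal round-trip sketch, assuming `spec` is a valid specification object:

with open("spec.json", "w") as f:
    json.dump(spec.get_dict_to_save(), f)

spec_again = DesignInputSpecification.from_json("spec.json")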
+     # ========================================================================
+     # Pre-Validation / canonicalization
+     # ========================================================================
+
+     @model_validator(mode="before")
+     @classmethod
+     def validate_input_schema(cls, data: dict) -> dict:
+         if not (
+             exists(data.get("input"))
+             or exists(data.get("contig"))
+             or exists(data.get("length"))
+         ):
+             raise ValueError("Either 'input' or 'contig' / 'length' must be provided.")
+
+         # unused input check
+         if exists(data.get("input")) and not (
+             (
+                 exists(data.get("contig"))
+                 or exists(data.get("unindex"))
+                 or exists(data.get("ligand"))
+             )
+             or exists(data.get("partial_t"))
+         ):
+             raise ValueError("Input provided but unused in composition specification.")
+
+         if not exists(data.get("partial_t")):
+             # non-partial diffusion checks
+             if exists(data.get("unindex")) and not (
+                 exists(data.get("contig")) or exists(data.get("length"))
+             ):
+                 raise ValueError(
+                     "Unindex provided but neither a length nor contig was specified."
+                 )
+         else:
+             # partial diffusion checks
+             if exists(data.get("length")):
+                 raise ValueError(
+                     "Length argument must not be provided during partial diffusion."
+                 )
+             if not (exists(data.get("input")) or exists(data.get("atom_array_input"))):
+                 raise ValueError(
+                     "Partial diffusion requires input file or input atom array."
+                 )
+
+         return data
+
+     @model_validator(mode="before")
+     @classmethod
+     def canonicalize(cls, data: dict) -> dict:
+         # Canonicalize length argument
+         data["length"] = str(data["length"]) if exists(data.get("length")) else None
+
+         # Normalize input to str
+         data["input"] = str(data["input"]) if exists(data.get("input")) else None
+         return data
+
+     @model_validator(mode="before")
+     @classmethod
+     def load_input(cls, data: dict) -> dict:
+         with validator_context("load_input"):
+             # ... Find provided selections
+             selections = [
+                 # Motif
+                 "contig",
+                 "unindex",
+                 # Aux
+                 "select_fixed_atoms",
+                 "select_unfixed_sequence",
+                 # Conditioning
+                 "select_buried",
+                 "select_partially_buried",
+                 "select_exposed",
+                 "select_hbond_acceptor",
+                 "select_hbond_donor",
+                 "select_hotspots",
+             ]
+             selections = [s for s in selections if s in data]
+
+             # ... Early return if no input file provided / atom array input
+             if not exists(data.get("input")) and not exists(
+                 data.get("atom_array_input")
+             ):
+                 if selections:
+                     raise ValueError(
+                         "Atom array input must be provided before parsing selections: {}".format(
+                             selections
+                         )
+                     )
+                 return data
+
+             # ... Load atom array from input file if provided
+             if exists(data["input"]):
+                 if exists(data.get("atom_array_input")):
+                     raise ValueError(
+                         "Both 'input' and 'atom_array_input' provided; please provide only one."
+                     )
+                 atom_array = inference_load_(
+                     data["input"], cif_parser_args=data.get("cif_parser_args")
+                 )["atom_array"]
+
+                 # Center for symmetric design
+                 if exists(data.get("symmetry")) and data["symmetry"].get("id"):
+                     atom_array = center_symmetric_src_atom_array(atom_array)
+
+                 if "atom_id" in atom_array.get_annotation_categories():
+                     atom_array.del_annotation("atom_id")
+
+                 data["atom_array_input"] = atom_array
+
+             atom_array = data["atom_array_input"]
+
+             # ... Set defaults if not provided
+             if not exists(data.get("select_fixed_atoms")):
+                 data["select_fixed_atoms"] = InputSelection.from_any(
+                     True, atom_array=atom_array
+                 )
+             if not exists(data.get("select_unfixed_sequence")):
+                 data["select_unfixed_sequence"] = InputSelection.from_any(
+                     False, atom_array=atom_array
+                 )
+
+             # Coerce selections
+             for sele in selections:
+                 if sele in ["contig", "unindexed_breaks"]:
+                     if exists(data[sele]) and not isinstance(data[sele], str):
+                         raise ValueError(
+                             f"{sele} selection must be a string or None, got {type(data[sele])} instead."
+                         )
+                 if not isinstance(data.get(sele), InputSelection):
+                     data[sele] = InputSelection.from_any(
+                         data[sele], atom_array=atom_array
+                     )
+             return data
+
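Selections arriving as strings, booleans, or dicts are coerced against the loaded atom array. A sketch of the equivalent manual coercion; the InputSelection API here is inferred from its use elsewhere in this file:

sel = InputSelection.from_any("A1-10", atom_array=atom_array)
mask = sel.get_mask()         # per-atom boolean mask over `atom_array`
components = set(sel.keys())  # component keys such as "A1", "A2", ...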
+     # ========================================================================
+     # Post-Validation
+     # ========================================================================
+
+     @model_validator(mode="after")
+     def assert_exclusivity(self):
+         with validator_context("assert_exclusivity"):
+             # ... Assert unindexed and indexed do not overlap
+             if exists(self.contig) and exists(self.unindex):
+                 indexed_set = set(self.contig.keys())
+                 unindexed_set = set(self.unindex.keys())
+                 overlap = indexed_set & unindexed_set
+                 if overlap:
+                     raise ValueError(
+                         f"Indexed and unindexed components must not overlap, got: {overlap}"
+                     )
+
+             # ... Assert mutual exclusivity of rasa binning
+             exclusive_sets = [
+                 ("Motifs", ("contig", "unindex")),
+                 (
+                     "RASA",
+                     ("select_buried", "select_partially_buried", "select_exposed"),
+                 ),
+             ]
+
+             for name, excl_set in exclusive_sets:
+                 masks = [getattr(self, field, None) for field in excl_set]
+                 masks = [m.get_mask() for m in masks if m is not None]
+                 if not masks:
+                     continue
+                 mask_sum = np.zeros_like(masks[0], dtype=int)
+                 for m in masks:
+                     if m is not None:
+                         mask_sum += m.astype(int)
+                 if np.any(mask_sum > 1):
+                     raise ValueError(
+                         f"Selections for `{name}` must be mutually exclusive, got overlapping selections: {excl_set}. Mask sum: {mask_sum}"
+                     )
+
+             return self
+
+     @model_validator(mode="after")
+     def attempt_expansion(self):
+         if self.is_partial_diffusion and exists(self.contig):
+             contig = self.contig
+             length = self.length
+             try:
+                 get_design_pattern_with_constraints(contig.raw, length=length)
+             except Exception as e:
+                 raise ValueError(f"Failed to expand contig ({contig.raw}): {e}")
+         return self
+
+     @model_validator(mode="after")
+     def _assign_types_to_input(self):
+         """Assign conditioning annotations to the input atom array"""
+         aa = self.atom_array_input
+         if not exists(aa):
+             return self
+
+         # ... Selections and their annotation values
+         selection_fields = {
+             # field name: (annotation name, assigned value, non-selected value)
+             "select_fixed_atoms": ("is_motif_atom_with_fixed_coord", True, False),
+             "select_unfixed_sequence": ("is_motif_atom_with_fixed_seq", False, True),
+             "unindex": ("is_motif_atom_unindexed", True, False),
+             "select_hotspots": ("is_atom_level_hotspot", True, False),
+             "select_hbond_acceptor": ("active_acceptor", True, False),
+             "select_hbond_donor": ("active_donor", True, False),
+             "select_buried": ("rasa_bin", 0, 3),
+             "select_partially_buried": ("rasa_bin", 1, 3),
+             "select_exposed": ("rasa_bin", 2, 3),
+         }
+         selection_fields = {
+             k: v for k, v in selection_fields.items() if exists(getattr(self, k, None))
+         }
+
+         # ... Init global defaults for the required conditioning annotations
+         for name, val in REQUIRED_CONDITIONING_ANNOTATION_VALUES.items():
+             aa.set_annotation(name, np.full(aa.array_length(), val, dtype=int))
+
+         # Apply the selections to a single token (residue) span
+         def apply_selections(start, end):
+             chain_id = aa.chain_id[start]
+             res_id = aa.res_id[start]
+
+             # Assign all select fields to atom array annotations.
+             for selection_name, (
+                 annotation_name,
+                 set_value,
+                 default_value,
+             ) in selection_fields.items():
+                 # ... Get input values
+                 selection = getattr(self, selection_name)
+
+                 # Important: selects from the data dictionary based on src chain & res_id (not name!)
+                 atom_names_sele = selection.get(f"{chain_id}{res_id}")
+
+                 if atom_names_sele is None:
+                     continue
+                 mask = np.isin(aa.atom_name[start:end], atom_names_sele)
+                 if annotation_name in aa.get_annotation_categories():
+                     # ... Override only the masked span if the annotation already exists
+                     aa.get_annotation(annotation_name)[start:end] = np.where(
+                         mask, set_value, default_value
+                     ).astype(np.int_)
+                 else:
+                     # ... Otherwise, set the entire annotation and use defaults for unselected
+                     mask_aa = np.zeros(aa.array_length(), dtype=bool)
+                     mask_aa[start:end] = mask
+                     annotation_values = np.where(
+                         mask_aa,
+                         set_value,
+                         default_value,
+                     ).astype(np.int_)
+                     aa.set_annotation(annotation_name, annotation_values)
+
+         # ... Set default assignments per-token based on whether redesigning
+         starts = get_residue_starts(aa, add_exclusive_stop=True)
+         for start, end in zip(starts[:-1], starts[1:]):
+             # ... Relax sequence and sidechains
+             if aa.res_name[start] in STANDARD_AA and self.redesign_motif_sidechains:
+                 is_bkbn = np.isin(aa.atom_name[start:end], ["N", "CA", "C", "O"])
+                 aa.is_motif_atom_with_fixed_coord[start:end] = is_bkbn.astype(int)
+                 aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
+                     is_bkbn, False, dtype=int
+                 )
+
+             # ... Apply selections on top
+             apply_selections(start, end)
+
+         return self
+
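The per-selection assignment above reduces to a masked `np.where` write. A self-contained illustration of the pattern on toy arrays (plain numpy, not rfd3 objects):

import numpy as np

atom_name = np.array(["N", "CA", "C", "O", "CB"])
mask = np.isin(atom_name, ["N", "CA", "C", "O"])  # e.g. a backbone selection
set_value, default_value = 1, 0
annotation = np.where(mask, set_value, default_value).astype(np.int_)
print(annotation)  # [1 1 1 1 0]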
+     # ========================================================================
+     # Building
+     # ========================================================================
+
+     def build(self, return_metadata=False):
+         """Main build pipeline."""
+         atom_array_input_annotated = copy.deepcopy(self.atom_array_input)
+         atom_array = self._build_init(atom_array_input_annotated)
+
+         # Apply post-processing
+         atom_array = self._append_ligand(atom_array, atom_array_input_annotated)
+         atom_array = self._apply_symmetry(atom_array, atom_array_input_annotated)
+
+         # Apply globals to all tokens (including diffused)
+         atom_array = self._set_origin(atom_array)
+         atom_array = self._apply_globals(atom_array)
+
+         # Final validation and cleanup
+         check_has_required_conditioning_annotations(
+             atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
+         )
+         convert_existing_annotations_to_bool(atom_array)
+
+         # ... Route return type
+         if not return_metadata:
+             return copy.deepcopy(atom_array)
+         else:
+             metadata = self.get_dict_to_save()
+             metadata["extra"] = metadata.get("extra", {}) | {
+                 "num_tokens_in": len(get_token_starts(atom_array)),
+                 "num_residues_in": len(get_residue_starts(atom_array)),
+                 "num_chains": len(np.unique(atom_array.chain_id)),
+                 "num_atoms": len(atom_array),
+                 "num_residues": len(
+                     np.unique(list(zip(atom_array.chain_id, atom_array.res_id)))
+                 ),
+             }
+             return copy.deepcopy(atom_array), metadata
+
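Called with `return_metadata=True`, `build` also returns the saved specification dict augmented with size statistics under "extra". A hypothetical call, assuming `spec` is a valid specification:

atom_array, metadata = spec.build(return_metadata=True)
print(metadata["extra"]["num_atoms"], metadata["extra"]["num_chains"])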
+     # ============================================================================
+     # Building functions
+     # ============================================================================
+
+     def _build_init(self, atom_array_input_annotated):
+         # ... Fetch tokens
+         indexed_tokens = (
+             self.contig.get_tokens(atom_array_input_annotated)
+             if exists(self.contig)
+             else {}
+         )
+         unindexed_tokens = (
+             self.unindex.get_tokens(atom_array_input_annotated)
+             if exists(self.unindex)
+             else {}
+         )
+         # Subset to only fixed-coordinate atoms
+         unindexed_tokens = {
+             k: tok[tok.is_motif_atom_with_fixed_coord.astype(bool)]
+             for k, tok in unindexed_tokens.items()
+         }
+         unindexed_components, unindexed_breaks = self.break_unindexed(self.unindex)
+
+         if not self.is_partial_diffusion:
+             # ... Sample the contig string
+             components_to_accumulate = get_design_pattern_with_constraints(
+                 self.contig.raw if exists(self.contig) else self.length,
+                 length=self.length,
+             )
+             self.extra["sampled_contig"] = ",".join(
+                 [str(x) for x in components_to_accumulate]
+             )
+
+             # ... Include unindexed components in accumulation
+             unindexed_breaks = [None] * len(components_to_accumulate) + unindexed_breaks
+             components_to_accumulate += unindexed_components
+
+             # ... Accumulate from scratch
+             atom_array = accumulate_components(
+                 components_to_accumulate,
+                 indexed_tokens=indexed_tokens,
+                 unindexed_tokens=unindexed_tokens,
+                 atom_array_accum=[],
+                 unindexed_breaks=unindexed_breaks,
+                 start_chain="A",
+                 start_resid=1,
+             )
+         else:
+             # ... Set common annotations
+             atom_array_in = assign_types_(copy.deepcopy(atom_array_input_annotated))
+             atom_array_in = set_common_annotations(
+                 atom_array_in, set_src_component_to_res_name=False
+             )
+
+             # ... Override motif annotations from pipeline
+             zeros = np.zeros(atom_array_in.array_length(), dtype=int)
+             atom_array_in.is_motif_atom_unindexed = (
+                 zeros  # reset unindexed annotation since those are copied already.
+             )
+             atom_array_in.is_motif_atom_with_fixed_coord = (
+                 self.select_fixed_atoms.get_mask().astype(int)
+                 if exists(self.select_fixed_atoms)
+                 else zeros
+             )
+             atom_array_in.is_motif_atom_with_fixed_seq = (
+                 ~self.select_unfixed_sequence.get_mask()
+                 if exists(self.select_unfixed_sequence)
+                 else zeros
+             ).astype(int)
+
+             # ... Subset to residues only
+             atom_array_in = atom_array_in[atom_array_in.is_protein]
+
+             # ... Set chain ID for unindexed residues as whatever the input has
+             start_resid = np.max(atom_array_in.res_id) + 1
+             start_chain = atom_array_in.chain_id[0]
+
+             # ... Accumulate from input
+             components_to_accumulate = unindexed_components
+             atom_array = accumulate_components(
+                 # No accumulation of components
+                 components_to_accumulate=components_to_accumulate,
+                 indexed_tokens={},
+                 # Append all inputs to unindexed tokens
+                 unindexed_tokens=unindexed_tokens,
+                 atom_array_accum=[atom_array_in],
+                 start_chain=start_chain,
+                 start_resid=start_resid,
+                 unindexed_breaks=unindexed_breaks,
+             )
+
+         return atom_array
+
+     # ============================================================================
+     # Auxiliary functions
+     # ============================================================================
+
+     @staticmethod
+     def break_unindexed(unindex: InputSelection):
+         if not exists(unindex):
+             return [], []
+
+         # ... If the original type was a string, use that
+         if isinstance(unindex.raw, str):
+             unindexed_string = unindex.raw
+         elif isinstance(unindex.raw, dict):
+             unindexed_string = ",".join(unindex.raw.keys())
+         else:
+             logger.info(
+                 "`Unindex` provided as non-string; separate keys in the dictionary will be considered separate contiguous components"
+             )
+             unindexed_string = ",".join(unindex.keys())
+
+         # ... Break the expected unindexed contig string
+         unindexed_components, breaks = get_motif_components_and_breaks(unindexed_string)
+
+         return unindexed_components, breaks
+
+     # ============================================================================
+     # Setter functions
+     # ============================================================================
+
+     def _append_ligand(self, atom_array, atom_array_input_annotated):
+         """Append ligand if specified."""
+         if exists(self.ligand):
+             ligand_array = extract_ligand_array(
+                 atom_array_input_annotated,
+                 self.ligand,
+                 fixed_atoms={},
+                 set_defaults=False,
+                 additional_annotations=set(
+                     list(atom_array.get_annotation_categories())
+                     + list(atom_array_input_annotated.get_annotation_categories())
+                 ),
+             )
+             # Offset ligand residue ids based on the original input to avoid clashes
+             # with any newly created residues (matches legacy behaviour).
+             ligand_array.res_id = (
+                 ligand_array.res_id
+                 - np.min(ligand_array.res_id)
+                 + np.max(atom_array.res_id)
+                 + 1
+             )
+             atom_array = atom_array + ligand_array
+         return atom_array
+
+     def _apply_symmetry(self, atom_array, atom_array_input_annotated):
+         """Apply symmetry transformation if specified."""
+         if exists(self.symmetry) and self.symmetry.id:
+             atom_array = make_symmetric_atom_array(
+                 atom_array,
+                 self.symmetry,
+                 sm=self.ligand,
+                 src_atom_array=atom_array_input_annotated,
+             )
+         return atom_array
+
+     def _set_origin(self, atom_array):
+         """Set origin token and initialize coordinates."""
+         if self.is_partial_diffusion:
+             # Partial diffusion: use COM, keep all coordinates
+             if exists(self.symmetry) and self.symmetry.id:
+                 # For symmetric structures, avoid COM centering that would collapse chains
+                 logger.info(
+                     "Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
+                 )
+             else:
+                 atom_array = set_com(
+                     atom_array, ori_token=None, infer_ori_strategy="com"
+                 )
+         else:
+             # Standard: set ori token, zero out diffused atoms
+             atom_array = set_com(
+                 atom_array,
+                 ori_token=self.ori_token,
+                 infer_ori_strategy=self.infer_ori_strategy,
+             )
+             # Diffused atoms are always initialized at the origin during regular diffusion (all information removed)
+             atom_array.coord[
+                 ~atom_array.is_motif_atom_with_fixed_coord.astype(bool)
+             ] = 0.0
+         return atom_array
+
+     def _apply_globals(self, atom_array):
+         # Non-loopy conditioning
+         if exists(self.is_non_loopy):
+             is_non_loopy_annot = np.zeros(atom_array.array_length(), dtype=int)
+             is_motif_token = get_motif_features(atom_array)["is_motif_token"]
+             diffused_region_mask = ~(is_motif_token.astype(bool))
+             is_non_loopy_annot[diffused_region_mask] = (
+                 1 if self.is_non_loopy else -1
+             )
+             atom_array.set_annotation("is_non_loopy", is_non_loopy_annot)
+             atom_array.set_annotation("is_non_loopy_atom_level", is_non_loopy_annot)
+         else:
+             zeros = np.zeros(atom_array.array_length(), dtype=int)
+             atom_array.set_annotation("is_non_loopy", zeros)
+             atom_array.set_annotation("is_non_loopy_atom_level", zeros)
+
+         if self.plddt_enhanced:
+             atom_array.set_annotation(
+                 "ref_plddt", np.full((atom_array.array_length(),), True, dtype=int)
+             )
+
+         # Partial diffusion time annotation
+         if self.is_partial_diffusion:
+             atom_array.set_annotation(
+                 "partial_t", np.full(atom_array.shape[0], self.partial_t, dtype=float)
+             )
+         return atom_array
+
+     @classmethod
+     def safe_init(cls, **spec_kwargs):
+         if spec_kwargs.get("dialect", 2) < 2:
+             warn = (
+                 "Using dialect==1, which is deprecated and will be removed in future releases. "
+                 "Please update your input specification to dialect=2 and use the new schema if possible"
+             )
+             warnings.warn(warn, DeprecationWarning)
+             logger.warning(warn)
+             return LegacySpecification(**spec_kwargs)
+         else:
+             return cls(**spec_kwargs)
+
+     def to_pipeline_input(self, example_id):
+         atom_array, spec_dict = self.build(return_metadata=True)
+
+         # ... Forward into the pipeline
+         data = prepare_pipeline_input_from_atom_array(atom_array)
+         data["example_id"] = example_id
+
+         # ... Wrap up with additional features
+         if "extra" not in spec_dict:
+             spec_dict["extra"] = {}
+         spec_dict["extra"]["example_id"] = example_id
+         data["specification"] = spec_dict
+         return data
+
+
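`safe_init` routes between the two schemas on `dialect`. A sketch with hypothetical keyword arguments:

spec = DesignInputSpecification.safe_init(dialect=2, length="50")

# dialect < 2 falls back to LegacySpecification and emits a DeprecationWarning:
legacy = DesignInputSpecification.safe_init(dialect=1)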
+ # ============================================================================
+ # APIs and utils
+ # ============================================================================
+
+
+ def prepare_pipeline_input_from_atom_array(  # see atomworks.ml.datasets.parsers.base.load_example_from_metadata_row
+     atom_array_orig,
+ ) -> dict:
+     """
+     Prepare a pipeline input example from an atom array instantiated with
+     conditioning annotations.
+
+     Args:
+         atom_array_orig: Atom array instantiated with conditioning annotations
+
+     Returns:
+         dict: A dictionary containing the parsed row data and additional loaded CIF data.
+     """
+     _start_parse_time = time.time()
+     # HACK: Set empty bond graph:
+     if atom_array_orig.bonds is None:
+         atom_array_orig.bonds = BondList(atom_array_orig.array_length())
+
+     # Temporary spoof of chain IDs to ensure duplicates aren't dropped:
+     result_dict = parse_atom_array(
+         atom_array_orig,
+         remove_ccds=[],
+         fix_arginines=False,
+         add_missing_atoms=False,
+         extra_fields=INFERENCE_ANNOTATIONS,
+         build_assembly=None,
+         hydrogen_policy="remove",
+     )
+     atom_array = result_dict["asym_unit"][0]
+
+     # HACK: Set iid information manually.
+     # We currently do not preserve this information from the input;
+     # if you want it preserved, the spoofing here would need to be removed.
+     check_has_required_conditioning_annotations(
+         atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
+     )
+     atom_array = convert_existing_annotations_to_bool(atom_array)
+     atom_array.set_annotation("chain_iid", [f"{c}_1" for c in atom_array.chain_id])
+     atom_array.set_annotation("pn_unit_iid", [f"{c}_1" for c in atom_array.pn_unit_id])
+
+     # Ensure motif annotations are removed
+     if "is_motif_token" in atom_array.get_annotation_categories():
+         atom_array.del_annotation("is_motif_token")
+     if "is_motif_atom" in atom_array.get_annotation_categories():
+         atom_array.del_annotation("is_motif_atom")
+
+     data = {
+         "atom_array": atom_array,  # First model
+         "chain_info": result_dict["chain_info"],
+         "ligand_info": result_dict["ligand_info"],
+         "metadata": result_dict["metadata"],
+     }
+     _stop_parse_time = time.time()
+     data = TransformedDict(data)
+     return data
+
+
+ def create_atom_array_from_design_specification(
+     **spec_kwargs,
+ ) -> tuple[AtomArray, dict]:
+     if int(spec_kwargs.get("dialect", 2)) < 2:
+         warn = (
+             "Using dialect==1, which is deprecated and will be removed in future releases. "
+             "Please update your input specification to dialect=2 and use the new schema if possible"
+         )
+         warnings.warn(warn, DeprecationWarning)
+         logger.warning(warn)
+         atom_array = create_atom_array_from_design_specification_legacy(**spec_kwargs)
+         return atom_array, {}
+
+     # Create input specification and build
+     spec = DesignInputSpecification(**spec_kwargs)
+     atom_array, metadata = spec.build(return_metadata=True)
+     return atom_array, metadata
+
+
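The module-level helper mirrors `safe_init` but builds immediately, returning the atom array and metadata. A hypothetical call:

atom_array, metadata = create_atom_array_from_design_specification(
    input="motif.pdb",  # hypothetical file
    contig="A1-10",
    length="60-80",
)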
+ @contextmanager
+ def validator_context(validator_name: str, data: dict = None):
+     """Context manager for validator execution with logging."""
+     logger.debug(f"Starting validator: {validator_name}")
+     try:
+         yield
+         logger.debug(f"✓ Completed validator: {validator_name}")
+     except Exception as e:
+         logger.error(
+             f"✗ Failed in validator: {validator_name}\n"
+             f"  Error: {str(e)}\n"
+             f"  Error type: {type(e).__name__}"
+         )
+         raise e
+
+
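A minimal sketch of the helper in use: exceptions raised inside the block are logged with the validator's name and then re-raised (the `data` argument is currently unused):

with validator_context("example_check"):
    raise ValueError("demo failure")  # logged as "✗ Failed in validator: example_check"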
+ def create_diffused_residues(n, additional_annotations=None):
+     if n <= 0:
+         raise ValueError(f"Negative/null residue count ({n}) not allowed.")
+
+     atoms = []
+     for idx in range(1, n + 1):
+         atoms.extend(
+             [
+                 struc.Atom(
+                     np.array([0.0, 0.0, 0.0], dtype=np.float32),
+                     res_name="ALA",
+                     res_id=idx,
+                 )
+                 for _ in range(5)
+             ]
+         )
+     array = struc.array(atoms)
+     array.set_annotation(
+         "element", np.array(["N", "C", "C", "O", "C"] * n, dtype="<U2")
+     )
+     array.set_annotation(
+         "atom_name", np.array(["N", "CA", "C", "O", "CB"] * n, dtype="<U2")
+     )
+     array = set_default_conditioning_annotations(
+         array, motif=False, additional=additional_annotations
+     )
+     array = set_common_annotations(array)
+     return array
+
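Each diffused residue is emitted as five placeholder ALA atoms (N, CA, C, O, CB) at the origin, so the returned array should contain 5*n atoms. A hypothetical check:

arr = create_diffused_residues(3)
assert arr.array_length() == 15  # 3 residues x 5 atoms
assert set(arr.res_name) == {"ALA"}
assert not arr.coord.any()       # all placeholder atoms start at the origin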
+
+ def create_motif_residue(
+     token,
+     strip_sidechains_by_default: bool,
+ ):
+     if strip_sidechains_by_default and token.res_name[0] in STANDARD_AA:
+         n_atoms = token.shape[0]
+         diffuse_oxygen = False
+         if n_atoms < 3:
+             raise ValueError(
+                 f"Not enough data for {token.chain_id[0]}{token.res_id[0]} in input atom array."
+             )
+         if n_atoms == 3:
+             # Handle cases with N, CA, C only;
+             token = token + create_o_atoms(token.copy())
+             diffuse_oxygen = True  # flag oxygen for generation
+
+         # Subset to the backbone atoms (N, CA, C, O) only
+         token = token[np.isin(token.atom_name, ["N", "CA", "C", "O"])]
+
+         # Exactly N, CA, C, O but no CB: place CB at an idealized position and convert to ALA.
+         # The residue name ALA ensures the padded atoms to be diffused from the fixed backbone
+         # are placed on the CB so as to not leak the identity of the residue.
+         token = token + create_cb_atoms(token.copy())
+
+         # Sequence name must be set to ALA such that the central atom is correctly CB
+         token.res_name = np.full_like(token.res_name, "ALA", dtype=token.res_name.dtype)
+         token.set_annotation(
+             "is_motif_atom_with_fixed_coord",
+             np.where(
+                 np.arange(token.shape[0], dtype=int) < (4 - int(diffuse_oxygen)),
+                 token.is_motif_atom_with_fixed_coord,
+                 0,
+             ),
+         )
+
+     check_has_required_conditioning_annotations(token)
+     token = set_common_annotations(token)
+     token.set_annotation("res_id", np.full(token.shape[0], 1))  # Reset to 1
+
+     return token
+
+
+ def accumulate_components(
+     components_to_accumulate: List[Union[str, int]],
+     *,
+     # Tokens from input
+     indexed_tokens: Dict[str, AtomArray],
+     unindexed_tokens: Dict[str, AtomArray],
+     # Additional parameters
+     atom_array_accum=None,
+     start_chain: str = "A",
+     start_resid: int = 1,
+     unindexed_breaks: Optional[List[bool]] = None,
+     src_atom_array: Optional[AtomArray] = None,
+     strip_sidechains_by_default: bool = False,
+     **kwargs,
+ ) -> AtomArray:
+     # ... Create list of components
+     assert (
+         x := (set(list(indexed_tokens.keys()) + list(unindexed_tokens.keys())))
+     ).issubset(
+         (y := set(components_to_accumulate))
+     ), "Unindexed and indexed set {} is not a subset of components to accumulate {}".format(
+         x, y
+     )
+     all_tokens = indexed_tokens | unindexed_tokens
+     all_annots = []
+     for tok in all_tokens.values():
+         all_annots.extend(list(tok.get_annotation_categories()))
+     all_annots = set(all_annots)
+     atom_array_accum = [] if atom_array_accum is None else atom_array_accum
+     unindexed_breaks = (
+         [None] * len(components_to_accumulate)
+         if unindexed_breaks is None
+         else unindexed_breaks
+     )
+
+     # ... For-loop accum variables
+     unindexed_components_started = (
+         False  # once one unindexed component is added, stop adding diffused residues
+     )
+     chain = start_chain
+     res_id = start_resid
+     molecule_id = 0
+     source_to_accum_idx: Dict[int, int] = {}
+     current_accum_idx = sum(len(arr) for arr in atom_array_accum)
+
+     # ... Insert contig information one by one
+     assert len(components_to_accumulate) == len(
+         unindexed_breaks
+     ), "Mismatch in number of components to accumulate and breaks"
+     for component, is_break in zip(components_to_accumulate, unindexed_breaks):
+         src_indices = None
+         if exists(is_break) and is_break:
+             if not unindexed_components_started:
+                 chain = start_chain
+                 res_id = start_resid
+             unindexed_components_started = True
+
+         if component == "/0":
+             # Reset iterators on next chain
+             chain = chr(ord(chain) + 1)
+             molecule_id += 1
+             res_id = 1
+             continue
+
+         # ... Create array to insert
+         if str(component)[0].isalpha():  # motif (e.g. "A22")
+             n = 1
+
+             # ... Fetch the motif residue
+             token = all_tokens[component]
+             if src_atom_array is not None:
+                 src_mask = fetch_mask_from_idx(component, atom_array=src_atom_array)
+                 src_indices = np.where(src_mask)[0]
+
+             # ... Ensure motif residues are set properly
+             token = create_motif_residue(
+                 token, strip_sidechains_by_default=strip_sidechains_by_default
+             )
+
+             # ... Insert breakpoint when break clause is met
+             if exists(is_break) and is_break:
+                 token.set_annotation(
+                     "is_motif_atom_unindexed_motif_breakpoint",
+                     np.ones(token.shape[0], dtype=int),
+                 )
+             else:
+                 token.set_annotation(
+                     "is_motif_atom_unindexed_motif_breakpoint",
+                     np.zeros(token.shape[0], dtype=int),
+                 )
+         else:
+             n = int(component)
+             # ... Skip if none or unindexed
+             if n == 0 or unindexed_components_started:
+                 res_id += n
+                 continue
+
+             # ... Create diffused residues
+             token = create_diffused_residues(n, all_annots)
+
+         # ... Set index of insertion
+         token = set_indices(
+             array=token,
+             chain=chain,
+             res_id_start=res_id,
+             molecule_id=molecule_id,
+             component=component,
+         )
+
+         assert (
+             len(get_token_starts(token)) == n
+         ), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(token))} in \n{token}"
+
+         if (
+             src_atom_array is not None
+             and str(component)[0].isalpha()
+             and src_indices is not None
+             and len(src_indices) == len(token)
+         ):
+             for i, src_idx in enumerate(src_indices):
+                 source_to_accum_idx[int(src_idx)] = current_accum_idx + i
+
+         # ... Insert & increment residue ID
+         atom_array_accum.append(token)
+         res_id += n
+         current_accum_idx += len(token)
+
+     # ... Concatenate all components
+     atom_array_accum = struc.concatenate(atom_array_accum)
+     atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)
+
+     should_restore_bonds = (
+         src_atom_array is not None
+         and bool(source_to_accum_idx)
+         and _check_has_backbone_connections_to_nonstandard_residues(
+             atom_array_accum, src_atom_array
+         )
+     )
+     if should_restore_bonds:
+         assert not unindexed_tokens, (
+             "PTM backbone bond restoration is not compatible with unindexed components. "
+             "PTMs must be specified as indexed components (using 'contig' parameter, not 'unindex'). "
+             f"Found unindexed components: {list(unindexed_tokens.keys())}"
+         )
+         atom_array_accum = _restore_bonds_for_nonstandard_residues(
+             atom_array_accum, src_atom_array, source_to_accum_idx
+         )
+
+     # Reset res_id for unindexed residues to avoid duplicates
+     is_unindexed = atom_array_accum.is_motif_atom_unindexed.astype(bool)
+     if np.any(is_unindexed) and not np.all(is_unindexed):
+         max_id = np.max(atom_array_accum[~is_unindexed].res_id)
+         min_id_udx = np.min(atom_array_accum[is_unindexed].res_id)
+         atom_array_accum.res_id[is_unindexed] += max_id - min_id_udx + 1
+
+     # ... Bonds
+     if atom_array_accum.bonds is None:
+         atom_array_accum.bonds = BondList(atom_array_accum.array_length())
+     return atom_array_accum
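The chain and residue bookkeeping in the loop above can be hard to follow. A self-contained toy walk-through (plain Python, no rfd3 types) of a hypothetical sampled component list, showing how "/0" starts a new chain and how integer components advance the residue counter:

components = [3, "A22", 2, "/0", 4]  # hypothetical sampled contig
chain, res_id = "A", 1
placed = []
for component in components:
    if component == "/0":  # chain break: next chain letter, reset numbering
        chain = chr(ord(chain) + 1)
        res_id = 1
        continue
    n = 1 if str(component)[0].isalpha() else int(component)
    placed.append((component, chain, res_id))
    res_id += n
print(placed)  # [(3, 'A', 1), ('A22', 'A', 4), (2, 'A', 5), (4, 'B', 1)]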