rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,502 @@
1
+ from collections import Counter, OrderedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ from atomworks.ml.encoding_definitions import AF3SequenceEncoding
6
+ from atomworks.ml.utils.token import (
7
+ get_token_starts,
8
+ spread_token_wise,
9
+ )
10
+ from biotite.structure import concatenate, infer_elements
11
+ from jaxtyping import Float, Int
12
+ from rfd3.constants import (
13
+ ATOM14_ATOM_NAMES,
14
+ VIRTUAL_ATOM_ELEMENT_NAME,
15
+ association_schemes,
16
+ association_schemes_stripped,
17
+ )
18
+ from rfd3.utils.io import (
19
+ build_stack_from_atom_array_and_batched_coords,
20
+ )
21
+ from scipy.optimize import linear_sum_assignment
22
+
23
+ from foundry.common import exists
24
+ from foundry.utils.ddp import RankedLogger
25
+
26
+ global_logger = RankedLogger(__name__, rank_zero_only=False)
27
+
28
+ #######################################################################
29
+ # Pythonic Helper functions
30
+ #######################################################################
31
+
32
+
33
def _remap_outputs(
    xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"]
) -> Float[torch.Tensor, "D L 3"]:
    """Scatter each batch's coordinates into the slots given by ``mapping``.

    For every batch ``b``, position ``mapping[b, i]`` receives the coordinates
    that were previously at position ``i``. The tensor is modified in place
    and also returned for convenience.
    """
    n_batch = xyz.shape[0]
    for batch_idx in range(n_batch):
        # Snapshot first: the scatter below overwrites rows we still read from.
        original = xyz[batch_idx].clone()
        xyz[batch_idx, mapping[batch_idx]] = original
    return xyz
40
+
41
+
42
+ def _reorder_dict(d: dict) -> OrderedDict:
43
+ """
44
+ Reorders keys in the dictionary to ensure 'metrics' and 'specification' are last (in that order if both present).
45
+ """
46
+ ordered = OrderedDict()
47
+ first_keys = ["task", "diffused_index_map"]
48
+ last_keys = ["metrics", "specification", "inference_sampler"]
49
+ # First
50
+ for k in first_keys:
51
+ if k in d:
52
+ ordered[k] = d[k]
53
+ # Middle
54
+ for k in d:
55
+ if k not in last_keys and k not in first_keys:
56
+ ordered[k] = d[k]
57
+ # Last
58
+ for k in last_keys:
59
+ if k in d:
60
+ ordered[k] = d[k]
61
+ return ordered
62
+
63
+
64
+ #######################################################################
65
+ # Biotite-related helper functions
66
+ #######################################################################
67
+
68
+
69
def _build_atom_array_stack(
    coords,
    src_atom_array,
    sequence_indices,
    sequence_logits,
    allow_sequence_outputs=True,
    read_sequence_from_sequence_head=True,
    association_scheme: str = "atom14",
):
    """
    Build a stack of atom arrays from batched coordinates and post-process it.

    Wraps build_stack_from_atom_array_and_batched_coords and additionally:
    renames unknown protein residues to alanine, then (when sequence outputs
    are allowed) assigns residue names either from the sequence-head logits
    or by reading the sequence back from the predicted structure.

    Returns the raw stack when ``allow_sequence_outputs`` is False, otherwise
    a plain list of per-model atom arrays.
    """
    stack = build_stack_from_atom_array_and_batched_coords(
        coords, src_atom_array.copy()
    )

    # ... Spoof empty sequences to alanines
    unknown_protein = stack.is_protein & (stack.res_name == "UNK")
    stack.res_name[unknown_protein] = "ALA"

    if not allow_sequence_outputs:
        return stack

    processed = []
    if read_sequence_from_sequence_head and exists(sequence_logits):
        encoding = AF3SequenceEncoding()
        for arr, seq_indices, seq_logits in zip(
            stack, sequence_indices, sequence_logits
        ):
            # Residue names for diffused (non-fixed-sequence) positions come
            # from the decoded sequence-head indices.
            diffused_mask = ~arr.is_motif_atom_with_fixed_seq
            decoded = encoding.decode(
                seq_indices.cpu().numpy().astype(int)
            )  # [I]
            arr.res_name[diffused_mask] = decoded[arr.token_id][diffused_mask]  # [L]

            # Store the per-residue entropy of the sequence logits in the
            # b-factor column as a confidence-like signal.
            probs = torch.softmax(seq_logits, dim=-1).cpu().numpy()  # shape (L, 32)
            entropy = -np.sum(probs * np.log(probs + 1e-10), axis=-1)  # shape (L,)
            arr.b_factor = spread_token_wise(arr, entropy)
            processed.append(arr.copy())
    else:
        # This automatically deletes virtual atoms and assigns resname,
        # atom name, and elements.
        for arr in stack:
            processed.append(
                _readout_seq_from_struc(arr, association_scheme=association_scheme)
            )

    # Return as list
    return processed
125
+
126
+
127
def _cleanup_virtual_atoms_and_assign_atom_name_elements(
    atom_array, association_scheme: str = "atom14"
):
    """
    Remove virtual atoms based on the predicted residue type and assign the
    correct atom names and elements.

    Residues with known sequence (fixed-sequence or unindexed motif atoms)
    keep all their atoms and take names from ``gt_atom_name``. For predicted
    residues, the per-residue association scheme (e.g. "atom14") decides
    which atom slots are real (kept and renamed) and which are virtual
    (dropped). Predicted residues whose res_name has no scheme entry are kept
    unchanged and renamed to "UNK".

    Returns the masked atom array; if the collected names do not line up with
    the array length, returns the input unmodified when a warning was already
    issued, otherwise raises ValueError.
    """
    ## remove virtual atoms based on predicted residue and assign correct atom name and elements
    ret_mask = []
    atom_names = []
    # This is used to indicate which residue is unidentified, probably due to an invalid structure.
    # This is different from the ref_mask, which is used to delete virtual atoms, but this one is used to assign UNK resname for invalid residues.
    invalid_mask = []

    # ... Iterate through each residue.
    # Here we iterate through res_id instead of token_id to avoid some atomization cases or something else.
    res_ids = atom_array.res_id
    res_start_indices = np.concatenate(
        [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
    )
    res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])
    warning_issued = False
    for start, end in zip(res_start_indices, res_end_indices):
        res_array = atom_array[start:end]

        # Sequence counts as known when the whole residue is a fixed-sequence
        # motif or an unindexed motif.
        is_seq_known = all(
            np.array(res_array.is_motif_atom_with_fixed_seq, dtype=bool)
        ) or all(np.array(res_array.is_motif_atom_unindexed, dtype=bool))

        # ... If sequence is known for the original atom array, just skip
        if is_seq_known:
            ret_mask += [True] * len(res_array)
            invalid_mask += [False] * len(res_array)
            res_name = res_array[0].res_name  # NOTE(review): unused in this branch
            atom_names += res_array.gt_atom_name.tolist()
            continue

        # ... If sequence is unknown for the original atom array, use the predicted / inferred sequence
        res_name = res_array[0].res_name
        if res_name not in association_schemes[association_scheme]:
            global_logger.warning(
                "Model predicted non-protein sequence for diffused residue. Cannot clean up outputs. Assigning unknown residue token."
            )
            warning_issued = True
            # Keep the residue intact but flag it so it gets renamed to UNK below.
            ret_mask += [True] * len(res_array)
            invalid_mask += [True] * len(res_array)
            atom_names += res_array.atom_name.tolist()
            continue

        # Scheme entries are padded atom names, or None for virtual slots.
        scheme = association_schemes[association_scheme][res_name]
        ret_mask += [True if item is not None else False for item in scheme]
        atom_names += [item.strip() if item is not None else "VX" for item in scheme]
        invalid_mask += [False] * len(scheme)

    # Sanity check: the per-residue scheme lengths must tile the whole array.
    if len(atom_names) != atom_array.array_length():
        global_logger.warning(
            f"{atom_names=}\n{atom_array.atom_name=}\nAtom names length {len(atom_names)} does not match original array length {atom_array.array_length()}."
            "\nCould not cleanup atom array!!!"
        )
        if not warning_issued:
            raise ValueError("Atom names length does not match original array length. ")
        return atom_array
    atom_array.atom_name = atom_names
    # Replace placeholder virtual-atom elements with elements inferred from the new names.
    atom_array.element = np.where(
        atom_array.element == VIRTUAL_ATOM_ELEMENT_NAME,
        infer_elements(atom_names),
        atom_array.element,
    )
    atom_array.res_name[invalid_mask] = np.array(["UNK"] * sum(invalid_mask))
    return atom_array[ret_mask]
193
+
194
+
195
def _readout_seq_from_struc(
    atom_array, central_atom="CB", threshold=0.5, association_scheme: str = "atom14"
):
    """
    Infer residue identities for sequence-unknown residues from the structure.

    For each residue whose sequence is not fixed, atoms lying within
    ``threshold`` of the residue's central atom are treated as virtual
    (padding) atoms; the remaining atom names are matched against the
    association scheme to pick a residue type. Residues matching no scheme
    entry are renamed to "UNK". Residues with known sequence are passed
    through unchanged.

    NOTE(review): virtual atoms are detected but not removed here — each
    residue array is appended whole; confirm that virtual-atom deletion
    happens downstream (e.g. in the cleanup helper).
    """
    cur_atom_array_list = []

    # Iterate through each residue
    res_ids = atom_array.res_id
    res_start_indices = np.concatenate(
        [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
    )
    res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])

    for start, end in zip(res_start_indices, res_end_indices):
        # ... Check if the current residue is after padding (seq unknown):
        cur_res_atom_array = atom_array[start:end]
        is_seq_known = all(
            np.array(cur_res_atom_array.is_motif_atom_with_fixed_seq, dtype=bool)
        )

        # Here it assumes that every non-protein part has its sequence shown (not padded)
        if not is_seq_known:
            # For Glycine: it doesn't have CB, so set the virtual atom as CA.
            # The current way to handle this is to check if predicted CA and CB are too close, because in the case of glycine and we pad virtual atoms based on CB, CB's coords are set as CA.
            # There might be a better way to do this.
            CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
            CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
            if np.linalg.norm(CA_coord - CB_coord) < threshold:
                cur_central_atom = "CA"
            else:
                cur_central_atom = central_atom

            central_mask = cur_res_atom_array.atom_name == cur_central_atom

            # ... Calculate the distance to the central atom
            central_coord = cur_res_atom_array.coord[central_mask][
                0
            ]  # Should only have one central atom anyway
            dists = np.linalg.norm(cur_res_atom_array.coord - central_coord, axis=-1)

            # ... Select virtual atom by the distance. Shouldn't count the central atom itself.
            is_virtual = (dists < threshold) & ~central_mask

            # ... Throw away virtual atoms
            cur_res_atom_array_wo_virtual = cur_res_atom_array[~is_virtual]
            cur_pred_res_atom_names = (
                cur_res_atom_array_wo_virtual.atom_name
            )  # e.g. [N, CA, C, O, CB, V6, V2]

            # ... Iterate over the possible restypes and find the matched one if there is any
            has_restype_assigned = False
            for restype, atom_names in association_schemes_stripped[
                association_scheme
            ].items():
                atom_names = np.array(atom_names)

                # Shouldn't match these two
                if restype in ["UNK", "MSK"]:
                    continue

                # ... Find the index of virtual atom names in the standard atom14 names
                atom_name_idx_in_atom14_scheme = np.array(
                    [
                        np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
                        for atom_name in cur_pred_res_atom_names
                    ]
                )  # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7]
                atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
                atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True

                # ... Find the matched restype by checking if all the non-None posititons and None positions match
                # This is designed to keep virtual atoms and doesn't assign the atom names for now, which will be handled later.
                if all(x is not None for x in atom_names[atom14_scheme_mask]) and all(
                    x is None for x in atom_names[~atom14_scheme_mask]
                ):
                    cur_res_atom_array.res_name = np.array(
                        [restype] * len(cur_res_atom_array)
                    )
                    cur_atom_array_list.append(cur_res_atom_array)
                    has_restype_assigned = True
                    break
        else:
            # Sequence already known: keep the residue as-is.
            cur_atom_array_list.append(cur_res_atom_array)
            has_restype_assigned = True

        # ... Give UNK as the residue name if the mapping fails (unrealistic sidechain)
        if not has_restype_assigned:
            cur_res_atom_array.res_name = np.array(["UNK"] * len(cur_res_atom_array))
            cur_atom_array_list.append(cur_res_atom_array)

    cur_atom_array = concatenate(cur_atom_array_list)

    return cur_atom_array
287
+
288
+
289
+ #######################################################################
290
+ # Unindexed output parsing
291
+ #######################################################################
292
+
293
+
294
def _reassign_unindexed_token_chains(atom_array):
    """
    Move unindexed motif tokens onto their own chains.

    Unindexed tokens share res_ids with diffused residues, so they are given
    back their original residue ids and split into chains named ``X{group}``,
    where groups are delimited by unindexed-motif breakpoints.
    """
    unindexed = atom_array.is_motif_atom_unindexed
    if np.any(unindexed):
        # HACK: Since res_ids are the same, we should save them with a different chain index.
        atom_array.chain_id[unindexed] = "X"
        atom_array.res_id[unindexed] = atom_array.orig_res_id[unindexed]

        # Parse to separate chains
        token_starts = get_token_starts(atom_array)
        unindexed_token_starts = token_starts[unindexed[token_starts]]
        breakpoints = atom_array[
            unindexed_token_starts
        ].is_motif_atom_unindexed_motif_breakpoint
        group_ids = np.cumsum(breakpoints, dtype=int)  # Group by motif breaks
        group_chain_ids = np.array([f"X{i}" for i in group_ids])

        per_token_chain = atom_array.chain_id[token_starts]
        per_token_chain[unindexed[token_starts]] = group_chain_ids
        atom_array.chain_id = spread_token_wise(atom_array, per_token_chain)
    return atom_array
313
+
314
+
315
def process_unindexed_outputs(
    atom_array,
    match_atom_names=True,
    insert_guideposts=False,
    verbose=False,
):
    """
    Process design outputs containing unindexed tokens.
    Returns metadata such as the assigned positional indices from the input indices
    and the RMSD of the unindexed tokens.

    Each unindexed (guidepost) token is paired with a diffused residue via a
    two-stage linear-sum assignment: a first global assignment picks the
    residue (majority vote when atoms map to several), then a second
    assignment restricted to that residue fixes the per-atom pairing used for
    the reported metrics. Each diffused residue can be claimed at most once.

    Args:
        atom_array: full output array including unindexed motif tokens.
        match_atom_names: only allow pairings between identically named atoms.
        insert_guideposts: copy guidepost coordinates (and res_name for
            fixed-sequence tokens) onto the matched diffused residue and mark
            it as fixed-coordinate motif.
        verbose: keep the per-restype breakdown entries in the metadata.

    Returns:
        - Diffused atom array (without additional unindexed tokens)
        - Metadata:
            - diffused_index_map: keys = original (contig) indices, values = diffused indices
            - insertion_rmsd: overall RMSD of insertion
            - insertion_rmsd_by_token: RMSD of insertion for each token

    TODO: Add additional geometry metrics such as bond angle non-ideality, clashes etc.
    TODO: atom1d conditioning adherence - does the output contain HBonds in the right places, correct rasa values?
    """
    # ... Find assignments based on greedy search
    starts = get_token_starts(atom_array, add_exclusive_stop=True)

    # [N_diffused,]
    atom_array_diffused = atom_array[~atom_array.is_motif_atom_unindexed].copy()
    global_idx = np.arange(atom_array.array_length())[
        ~atom_array.is_motif_atom_unindexed
    ]

    metadata = {
        "diffused_index_map": {},
        "insertion_rmsd_by_token": {},
        "join_point_rmsd_by_token": {},
        "insertion_rmsd_by_restype": {},
    }
    token_maes = []
    token_rmcds = []
    n_conjoined_residues = 0

    # Initialize an empty array; marks diffused atoms already claimed by a token.
    inserted_mask = np.full_like(atom_array_diffused.is_motif_atom_unindexed, False)

    for start, end in zip(starts[:-1], starts[1:]):
        token = atom_array[start:end]
        # Only unindexed tokens are processed here.
        if not token.is_motif_atom_unindexed.all():
            continue

        if "src_component" in token.get_annotation_categories():
            token_pdb_id = token.src_component[0]
        else:
            raise ValueError(
                "Missing annotation 'src_component' in token. Is this inference?"
            )

        if "src_sym_component" in token.get_annotation_categories():
            # if symmetry, token_pdb_id are updated to match the symmetrized component
            token_pdb_id = token.src_sym_component[0]

        res_name = token.res_name[0]

        # ... Calculate [N_unindex, N_diffused] distance matrix
        dists = np.linalg.norm(
            token.coord[:, None] - atom_array_diffused.coord[None, :], axis=-1
        )

        # ... Match atom indices based on atom names (mask out non-identical) and remove already inserted
        dists[:, inserted_mask.copy()] = np.inf
        if match_atom_names:
            matching_atom_name = (
                token.atom_name[:, None] == atom_array_diffused.atom_name[None, :]
            )
            dists[~matching_atom_name] = np.inf

        # ... Find the res_id's in the diffused regions belonging to the diffused indices
        row_ind, col_ind = linear_sum_assignment(dists)
        res_id, chain_id, is_conjoined = indices_to_components_(
            atom_array_diffused, col_ind
        )
        n_conjoined_residues += int(is_conjoined)

        # ... Recompute distance indices based on single residue pairings only
        token_match = (atom_array_diffused.res_id == res_id) & (
            atom_array_diffused.chain_id == chain_id
        )
        dists[:, ~token_match] = np.nan
        # Finite sentinel so linear_sum_assignment never sees nan/inf costs.
        BIG = 1e12
        dists = np.nan_to_num(dists, nan=BIG, posinf=BIG, neginf=BIG)
        row_ind, col_ind = linear_sum_assignment(dists)
        res_id_, chain_id_, _ = indices_to_components_(atom_array_diffused, col_ind)

        # Restricted assignment must land on the residue chosen above.
        assert (res_id_ == res_id) & (chain_id_ == chain_id)
        inserted_mask = np.logical_or(inserted_mask, token_match)

        # ... Compute metrics based on the new distances
        diff = token.coord[row_ind] - atom_array_diffused.coord[col_ind]
        token_rmsd = float(np.sqrt((diff**2).sum(-1).mean()))
        token_rmcd = float(np.cbrt((np.abs(diff) ** 3).sum(-1).mean()))
        token_mae = float((np.abs(diff)).sum(-1).mean())

        metadata["insertion_rmsd_by_token"][token_pdb_id] = token_rmsd
        token_maes.append(token_mae)
        token_rmcds.append(token_rmcd)

        if res_name not in metadata["insertion_rmsd_by_restype"]:
            metadata["insertion_rmsd_by_restype"][res_name] = []
        metadata["insertion_rmsd_by_restype"][res_name].append(token_rmsd)

        # Join-point distance only for tokens without backbone atoms
        # (side-chain-only guideposts): join atom is the single atomized atom,
        # falling back to CB.
        if not np.any(np.isin(token.atom_name, ["N", "CA", "C", "O"])):
            if np.sum(token.atomize) == 1:
                join_atom = np.where(token.atomize)[0][0]
            elif "CB" in token.atom_name:
                join_atom = np.where(token.atom_name == "CB")[0][0]
            else:
                join_atom = None

            if join_atom is None:
                global_logger.warning(
                    f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}."
                )
            else:
                dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
                metadata["join_point_rmsd_by_token"][token_pdb_id] = dist

        metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"

        # ... Decide whether to cleanup guideposts or not
        if insert_guideposts:
            atom_array_diffused.coord[global_idx[col_ind]] = token.coord[row_ind]
            if token.is_motif_atom_with_fixed_seq[0]:
                atom_array_diffused.res_name[token_match] = token.res_name[0]
            # atom_array_diffused.is_motif_token[token_match] = True
            # atom_array_diffused.is_motif_atom[global_idx[col_ind]] = True
            atom_array_diffused.is_motif_atom_with_fixed_coord[global_idx[col_ind]] = (
                True
            )

    # ... Calculate global metrics
    def safe_mean(x):
        """Return nan-safe mean for empty or nan arrays."""
        x = np.asarray(x, float)
        if x.size == 0 or not np.isfinite(x).any():
            return float("nan")
        return float(np.nanmean(x))

    metadata["insertion.mae"] = safe_mean(token_maes)
    metadata["insertion.rmcd"] = safe_mean(token_rmcds)
    metadata["insertion_rmsd"] = safe_mean(
        list(metadata["insertion_rmsd_by_token"].values())
    )
    metadata["join_point_rmsd"] = safe_mean(
        list(metadata["join_point_rmsd_by_token"].values())
    )
    metadata["insertion_rmsd_by_restype"] = {
        a: safe_mean(v) for a, v in metadata["insertion_rmsd_by_restype"].items()
    }
    metadata["n_conjoined_residues"] = n_conjoined_residues

    # Drop per-token/per-restype breakdowns unless explicitly requested.
    if not verbose:
        metadata = {
            k: v for k, v in metadata.items() if not k.startswith("insertion_rmsd_by_")
        }

    return atom_array_diffused, metadata
478
+
479
+
480
def indices_to_components_(atom_array, col_ind):
    """
    Fetch the chain id and res id in ``atom_array`` for a set of raw atom
    indices.

    When the indices span more than one (chain, residue) pair, the majority
    pair is returned and the third element of the result ('conjoined') is
    True; otherwise the unique pair is returned with False.
    """
    res_ids = atom_array.res_id[col_ind]
    chain_ids = atom_array.chain_id[col_ind]
    pairs = list(zip(chain_ids.tolist(), res_ids.tolist()))

    if len(set(pairs)) > 1:
        global_logger.warning(
            f"Unindexed token mapped its atoms to multiple diffused residues: {res_ids.tolist()} and chains {chain_ids.tolist()}."
        )
        # Handle by majority vote over (chain, residue) pairs.
        (chain_id, res_id), _ = Counter(pairs).most_common(1)[0]
        return res_id, chain_id, True

    return res_ids[0], chain_ids[0], False