rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rfd3/transforms/conditioning_base.py
@@ -0,0 +1,508 @@
+ """
+ Contains (a) global conditioning syntax and (b) transforms for pipeline
+
+ Conditioning pipeline:
+ inference --- create_atom_array_from_design_specification ---|
+                                                              |---> CreateConditionedArray
+ training --- SampleConditioningFlags ---|
+ """
+
+ import ast
+ import copy
+ import logging
+
+ import biotite.structure as struc
+ import hydra
+ import networkx as nx
+ import numpy as np
+ from atomworks.ml.transforms._checks import (
+     check_atom_array_annotation,
+     check_contains_keys,
+     check_is_instance,
+ )
+ from atomworks.ml.transforms.atom_array import (
+     add_global_token_id_annotation,
+     add_protein_termini_annotation,
+ )
+ from atomworks.ml.transforms.base import Transform
+ from atomworks.ml.utils.token import (
+     apply_and_spread_token_wise,
+     get_token_count,
+     get_token_starts,
+ )
+ from biotite.structure import AtomArray
+ from rfd3.constants import (
+     OPTIONAL_CONDITIONING_VALUES,
+     REQUIRED_CONDITIONING_ANNOTATIONS,
+ )
+ from rfd3.transforms.conditioning_utils import random_condition
+ from rfd3.transforms.util_transforms import (
+     add_representative_atom,
+ )
+
+ from foundry.common import exists
+
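+ # networkx >= 3.0 removed `from_numpy_matrix`; the line below aliases the old
+ # name to `from_numpy_array` for downstream code that still calls it.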
+ nx.from_numpy_matrix = nx.from_numpy_array
+ logger = logging.getLogger(__name__)
+ NHEAVYPROT = 14  # max heavy atoms in a standard amino-acid residue (Trp)
+
+
+ #################################################################################
+ # Base conditioning definitions
+ #################################################################################
+
+
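+ # The conditioning state of each atom is carried by boolean annotations, chiefly:
+ #   is_motif_atom_with_fixed_coord -- coordinates are fixed (given to the model)
+ #   is_motif_atom_with_fixed_seq   -- residue identity is fixed
+ #   is_motif_atom_unindexed        -- atom belongs to an unindexed ("guidepost") motif
+ # An atom counts as a motif atom if any of these flags is set.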
+ def get_motif_features(atom_array):
+     is_fixed = atom_array.is_motif_atom_with_fixed_coord.astype(bool)
+     is_sequence_fixed = atom_array.is_motif_atom_with_fixed_seq.astype(bool)
+     is_unindexed = atom_array.is_motif_atom_unindexed.astype(bool)
+
+     # Motif atom if has any conditioning
+     is_motif_atom = is_fixed | is_sequence_fixed | is_unindexed
+     is_motif_token = apply_and_spread_token_wise(
+         atom_array, is_motif_atom, function=lambda x: np.any(x)
+     )  # Has any atoms with conditioning
+
+     return {"is_motif_atom": is_motif_atom, "is_motif_token": is_motif_token}
+
+
+ def set_default_conditioning_annotations(
+     atom_array,
+     motif=False,
+     unindexed=False,
+     mask=None,
+     dtype=bool,
+     additional: set | list = None,
+ ):
+     """
+     Adds default conditioning annotations to the atom array.
+
+     Args:
+         motif: True for the defaults of a fully fixed motif, False for a fully diffused one
+         unindexed: True if the tokens in the atom array should be flagged as unindexed motif
+         mask: boolean mask selecting which atoms the assignments are applied to.
+         NB: In both cases, the defaults for unindexed are False
+     """
+
+     # All annotations set to true for motif
+     fill = True if motif else False
+     if mask is not None:
+         # TODO: support defaulting to nulls
+         check_has_required_conditioning_annotations(atom_array)
+         trues = np.full(mask.sum(), True, dtype=dtype)
+         falses = np.full(mask.sum(), False, dtype=dtype)
+
+         atom_array.is_motif_atom_unindexed[mask] = trues if unindexed else falses
+         atom_array.is_motif_atom_unindexed_motif_breakpoint[mask] = falses
+
+         # Others:
+         for annotation in REQUIRED_CONDITIONING_ANNOTATIONS:
+             if annotation in [
+                 "is_motif_atom_unindexed",
+                 "is_motif_atom_unindexed_motif_breakpoint",
+             ]:
+                 continue
+
+             vals = copy.deepcopy(atom_array.get_annotation(annotation))
+             vals[mask] = trues if fill else falses
+             atom_array.set_annotation(annotation, vals)
+     else:
+         for annotation in REQUIRED_CONDITIONING_ANNOTATIONS:
+             if annotation in [
+                 "is_motif_atom_unindexed",
+             ]:
+                 atom_array.set_annotation(
+                     annotation,
+                     np.full(atom_array.array_length(), unindexed, dtype=dtype),
+                 )
+             elif annotation in [
+                 "is_motif_atom_unindexed_motif_breakpoint",
+             ]:
+                 atom_array.set_annotation(
+                     annotation, np.full(atom_array.array_length(), False, dtype=dtype)
+                 )
+             else:
+                 atom_array.set_annotation(
+                     annotation, np.full(atom_array.array_length(), fill, dtype=dtype)
+                 )
+
+     if additional is not None:
+         for annot, val in OPTIONAL_CONDITIONING_VALUES.items():
+             if (
+                 annot in additional
+                 and annot not in atom_array.get_annotation_categories()
+             ):
+                 atom_array.set_annotation(
+                     annot, np.full(atom_array.array_length(), val)
+                 )
+
+     return atom_array
+
+
+ def check_has_required_conditioning_annotations(
+     atom_array, required=REQUIRED_CONDITIONING_ANNOTATIONS
+ ):
+     """
+     Checks if the atom array has the correct conditioning annotations
+     """
+     received = atom_array.get_annotation_categories()
+     for required_annotation in required:
+         if required_annotation not in received:
+             raise InvalidSampledConditionException(
+                 f"Missing annotation category in atom_array: {required_annotation}"
+             )
+     return True
+
+
+ def convert_existing_annotations_to_bool(
+     atom_array, annotations=REQUIRED_CONDITIONING_ANNOTATIONS
+ ):
+     # When loading from cif, annotations are loaded as strings when they should be boolean
+     for annotation in annotations:
+         if annotation not in atom_array.get_annotation_categories():
+             continue
+         tmp = atom_array.get_annotation(annotation).copy()
+         atom_array.get_annotation(annotation).dtype = bool
+         if isinstance(tmp[0], (str, np.str_, np.dtypes.StrDType)):
+             tmp = np.array([ast.literal_eval(x) for x in tmp], dtype=bool)
+         else:
+             tmp = np.asarray(tmp, dtype=bool)
+         atom_array.set_annotation(annotation, tmp)
+     return atom_array
+
+
+ def convert_existing_annotations_to_int(
+     atom_array, annotations=REQUIRED_CONDITIONING_ANNOTATIONS
+ ):
+     # When loading from cif, annotations are loaded as strings/bools when they should be integers
+     for annotation in annotations:
+         if annotation not in atom_array.get_annotation_categories():
+             continue
+         tmp = atom_array.get_annotation(annotation).copy()
+         if isinstance(tmp[0], (str, np.str_, np.bool_, bool, np.dtypes.BoolDType)):
+             tmp = np.array([int(x) for x in tmp], dtype=int)
+         atom_array.set_annotation(annotation, tmp)
+     return atom_array
+
+
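+ # Transform wrapper so the string -> bool conversion above can be dropped into a
+ # pipeline: "is_*" annotations round-tripped through CIF arrive as strings.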
+ class StrtoBoolforIsXFeatures(Transform):
+     def check_input(self, *args, **kwargs):
+         pass
+
+     def __init__(self):
+         pass
+
+     def forward(self, data):
+         atom_array = data["atom_array"]
+         convert_existing_annotations_to_bool(atom_array)
+         data["atom_array"] = atom_array
+         return data
+
+
+ class InvalidSampledConditionException(Exception):
+     def __init__(self, message="Error during sampling of condition."):
+         self.message = message
+         super().__init__(self.message)
+
+
+ #################################################################################
+ # Transform for pipeline (training & inference)
+ #################################################################################
+
+
+ class SampleConditioningType(Transform):
+     """
+     Applies conditional assignments
+
+     Args:
+         train_conditions: List[RandomMask]
+         seed (int): random seed, for controlling the masking results
+
+     Return:
+         atom_array with three more annotations:
+             - is_motif_token: tokens to be motif
+             - is_motif_atom: atoms to be motif
+             - is_motif_atom_with_fixed_seq: for which atoms we know the true restype
+     """
+
+     requires_previous_transforms = [
+         "AssignTypes",
+     ]
+
+     def __init__(
+         self,
+         *,
+         train_conditions: dict,
+         meta_conditioning_probabilities: dict,
+         sequence_encoding,
+     ):
+         if exists(train_conditions):
+             train_conditions = hydra.utils.instantiate(
+                 train_conditions, _recursive_=True
+             )
+         self.meta_conditioning_probabilities = meta_conditioning_probabilities
+         self.train_conditions = train_conditions
+         self.sequence_encoding = sequence_encoding
+
+     def check_input(self, data: dict):
+         assert not data["is_inference"], "This transform is only used during training!"
+         check_contains_keys(data, ["atom_array"])
+         check_is_instance(data, "atom_array", AtomArray)
+         check_atom_array_annotation(data, ["pn_unit_id", "pn_unit_iid"])
+         existing = [
+             cat in REQUIRED_CONDITIONING_ANNOTATIONS
+             for cat in data["atom_array"].get_annotation_categories()
+         ]
+         assert not any(
+             existing
+         ), "Conditioning annotations already set! found {}".format(existing)
+         assert "conditions" in data, "Conditioning dict not initialized"
+
+     def forward(self, data):
+         valid_conditions = [
+             cond
+             for cond in self.train_conditions.values()
+             if cond.frequency > 0 and cond.is_valid_for_example(data)
+         ]
+
+         if len(valid_conditions) == 0:
+             raise InvalidSampledConditionException("No valid condition was found.")
+
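+         # Frequencies act as unnormalized sampling weights, e.g. (illustrative)
+         # two valid conditions with frequencies [1, 3] are drawn with p = [0.25, 0.75].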
+         p_cond = np.array([cond.frequency for cond in valid_conditions])
+         if p_cond.sum() == 0:
+             raise InvalidSampledConditionException(
+                 "No valid condition was found with non-zero frequency."
+             )
+         p_cond = p_cond.astype(np.float64)
+         p_cond /= p_cond.sum()
+         i_cond = np.random.choice(np.arange(len(p_cond)), p=p_cond)
+         cond = valid_conditions[i_cond]
+
+         data["sampled_condition"] = cond
+         data["sampled_condition_name"] = cond.name
+         data["sampled_condition_cls"] = cond.__class__
+
+         # Sample canonical conditioning flags for downstream processing
+         for k, p in self.meta_conditioning_probabilities.items():
+             data["conditions"][k] = random_condition(p)
+
+         return data
+
+
+ class SampleConditioningFlags(Transform):
+     requires_previous_transforms = [
+         "FlagAndReassignCovalentModifications",
+         "AssignTypes",
+         "SampleConditioningType",
+     ]  # We use is_protein in the PPI training condition
+
+     def check_input(self, data):
+         assert not data[
+             "is_inference"
+         ], "This transform is only used during training! Validation using sampled conditions is not implemented yet"
+         assert "sampled_condition" in data
+
+     def forward(self, data: dict) -> dict:
+         cond = data["sampled_condition"]
+
+         # Sample canonical conditioning flags for atom array
+         atom_array = cond.sample(data)
+         data["atom_array"] = atom_array
+
+         return data
+
+
+ class UnindexFlaggedTokens(Transform):
+     """
+     Serves as the merge point between the training / inference conditioning pipelines
+     """
+
+     def __init__(self, central_atom):
+         """
+         Args:
+             central_atom: The atom to use as the central atom for unindexed motifs.
+         """
+         super().__init__()
+         self.central_atom = central_atom
+
+     def check_input(self, data: dict):
+         check_contains_keys(data, ["atom_array"])
+         check_is_instance(data, "atom_array", AtomArray)
+
+     def expand_unindexed_motifs(
+         self, atom_array: AtomArray, pop_orig_tokens: bool
+     ) -> AtomArray:
+         """
+         Takes the atom array and motif indices and pads the atom array to include unindexed motif atoms.
+
+         is_motif_atom_unindexed - Whether an atom is flagged to be a guidepost.
+         During training, the original coordinates are left behind for the model to learn to diffuse;
+         during inference, the original tokens are removed by default.
+         """
+         # back up original residue id for training metrics
+         atom_array.set_annotation("orig_res_id", atom_array.res_id.copy())
+         is_motif_atom_unindexed = atom_array.is_motif_atom_unindexed.copy()
+         if not np.any(is_motif_atom_unindexed):
+             return atom_array
+
+         # ... A token is to be unindexed if any atoms in the token are unindexed
+         max_resid = np.max(atom_array.res_id)
+         starts = struc.get_residue_starts(atom_array, add_exclusive_stop=True)
+         token_to_unindex = struc.spread_residue_wise(
+             atom_array,
+             struc.apply_residue_wise(
+                 atom_array,
+                 is_motif_atom_unindexed,
+                 function=lambda x: np.any(x),
+             ),
+         )
+         assert token_to_unindex.sum() > 0, "No tokens to unindex!"
+         idxs = np.arange(atom_array.array_length())
+         unindexed_tokens = []
+         for i, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
+             if not token_to_unindex[start]:
+                 continue
+             subset_mask = np.isin(idxs, idxs[start:end])
+             token = copy.deepcopy(atom_array[subset_mask])
+             token = token[token.is_motif_atom_unindexed]
+             token.res_id = token.res_id + max_resid
+             token.is_C_terminus[:] = False
+             token.is_N_terminus[:] = False
+             assert token.is_protein.all(), f"Cannot unindex non-protein token: {token}"
+             token = add_representative_atom(token, central_atom=self.central_atom)
+             unindexed_tokens.append(token)
+
+         # ... Remove original tokens e.g. during inference
+         if pop_orig_tokens:
+             atom_array = atom_array[~token_to_unindex]
+             # Reassign termini features
+             atom_array = add_protein_termini_annotation(atom_array)
+         else:
+             # Reset is_motif_atom and is_motif_atom_unindexed to contain no motif annotations where unindexed,
+             # i.e. the model should view the original tokens the same as every other diffused token
+             atom_array.is_motif_atom[token_to_unindex] = False
+             atom_array.is_motif_atom_with_fixed_coord[token_to_unindex] = False
+             atom_array.is_motif_token[token_to_unindex] = False
+             atom_array.is_motif_atom_with_fixed_seq[token_to_unindex] = False
+             atom_array.is_motif_atom_unindexed[token_to_unindex] = False
+             atom_array.is_motif_atom_unindexed_motif_breakpoint[token_to_unindex] = (
+                 False
+             )
+
+         # Concatenate unindexed parts to the end
+         atom_array_full = struc.concatenate([atom_array] + unindexed_tokens)
+         atom_array_to_concat = struc.concatenate(unindexed_tokens)
+         # Ensure tokens are recognised as separate
+         n_unindexed_tokens = get_token_count(atom_array_to_concat)
+         assert n_unindexed_tokens == len(
+             unindexed_tokens
+         ), f"Expected {len(unindexed_tokens)} but got {n_unindexed_tokens}"
+         assert (
+             get_token_count(atom_array_full)
+             == get_token_count(atom_array) + n_unindexed_tokens
+         ), (
+             f"Failed to create uniquely recognised tokens after concatenation.\n"
+             f"Concatenated tokens: {get_token_count(atom_array_full)}, unindexed: {n_unindexed_tokens}"
+         )
+
+         return atom_array_full
+
+     def create_unindexed_masks(
+         self,
+         atom_array,
+         is_inference=False,
+     ):
+         """
+         Create an (L, L) boolean matrix indicating the tokens which should absolutely
+         not know the relative positions of one another.
+
+         False where positional leakage is allowed,
+         True where positional leakage is disallowed.
+
+         Used as input to the model's relative position encoding.
+
+         breaks:
+             boolean atom-wise array indicating which tokens break the group ids up.
+             If all are False, all indices are leaked. If the first break of the unindexed tokens is
+             True, the cross-motif couplings are leaked but not the global index.
+
+         atom_array: padded atom array
+         """
+         token_starts = get_token_starts(atom_array)
+         token_level_array = atom_array[token_starts]
+         is_motif_token_unindexed = token_level_array.is_motif_atom_unindexed
+
+         # ... Grab breaks from the token level array
+         unindexed_token_level_array = token_level_array[is_motif_token_unindexed]
+         breaks = unindexed_token_level_array.is_motif_atom_unindexed_motif_breakpoint
+
+         leak_all = not np.any(breaks)
+         if leak_all:
+             if is_inference and np.any(is_motif_token_unindexed):
+                 logger.info("Indexing all unindexed components")
+             L = len(token_starts)
+             return np.zeros((L, L), dtype=bool), is_motif_token_unindexed
+
+         # ... First component of mask is that no unindexed atoms should talk to indexed ones.
+         mask = (
+             is_motif_token_unindexed[:, None] == ~is_motif_token_unindexed[None, :]
+         )  # [intra indexed + intra unindexed]
+
+         # ... Then, within unindexed tokens, separate the islands based on where the token id breaks
+         unindexed_all_LL = (
+             is_motif_token_unindexed[:, None] & is_motif_token_unindexed[None, :]
+         )  # [intra unindexed]
+
+         ########################################################################################
+         # Determine intra-unindexed resid leakage
+         ########################################################################################
+         # ... Mask out intra-unindexed off-diagonals as necessary
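+         # e.g. (illustrative): breaks = [F, T, F] -> group_ids = [0, 1, 1], so the
+         # first unindexed token shares no index frame with the other two (mask True),
+         # while the second and third may know their relative positions to each other.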
+         group_ids = np.cumsum(breaks)
+         mask_unindexed_MM = group_ids[:, None] != group_ids[None, :]
+         mask[unindexed_all_LL] = mask_unindexed_MM.flatten()
+
+         return mask, is_motif_token_unindexed
+
+     def forward(self, data: dict):
+         atom_array = data["atom_array"]
+         if "feats" not in data:
+             data["feats"] = {}
+
+         # ... Ensure conditioning flags are set correctly
+         # NOTE: Join point for inference and training conditioning pipelines
+         check_has_required_conditioning_annotations(atom_array)
+
+         is_unindexed_token = apply_and_spread_token_wise(
+             atom_array,
+             atom_array.is_motif_atom_unindexed.copy(),
+             function=lambda x: np.any(x),
+         )
+
+         # Expand unindexed motifs if necessary
+         atom_array_expanded = self.expand_unindexed_motifs(
+             atom_array,
+             pop_orig_tokens=data["is_inference"],
+         )
+
+         # Provide the atom-wise mask for the regions which should be diffused into the guideposts;
+         # the original token was unindexed if any of its atoms were unindexed
+         n_expanded_atoms = (
+             atom_array_expanded.array_length() - atom_array.array_length()
+         )
+         mask = np.concatenate([is_unindexed_token, np.zeros(n_expanded_atoms)])
+         if "ground_truth" not in data:
+             data["ground_truth"] = {}
+         data["ground_truth"]["is_original_unindexed_token"] = mask.astype(bool)
+
+         # Reset global token IDs after possible padding
+         atom_array_expanded = add_global_token_id_annotation(atom_array_expanded)
+
+         # For unindexed scaffolding, we must provide an unindexing pair mask to ensure original positions aren't leaked to:
+         # (I) the RPE of the token initializer and (II) the atom attention base sequence mask
+         mask_II, mask_I = self.create_unindexed_masks(
+             atom_array_expanded, is_inference=data["is_inference"]
+         )
+         data["feats"]["unindexing_pair_mask"] = mask_II
+         data["feats"]["is_motif_token_unindexed"] = mask_I
+         data["atom_array"] = atom_array_expanded
+         return data
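+
+ # Illustrative summary of how these pieces compose (per the module docstring;
+ # the actual pipelines are assembled in rfd3/transforms/pipelines.py):
+ #   training:  SampleConditioningType -> SampleConditioningFlags -> UnindexFlaggedTokens
+ #   inference: create_atom_array_from_design_specification -> ... -> UnindexFlaggedTokens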