rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,807 @@
+ """
+ Design Transforms for the Atom14 pipeline
+ """
+
+ from typing import Any, Dict
+
+ import biotite.structure as struc
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from atomworks.constants import (
+     ELEMENT_NAME_TO_ATOMIC_NUMBER,
+ )
+ from atomworks.enums import GroundTruthConformerPolicy
+ from atomworks.io.utils.selection import get_residue_starts
+ from atomworks.ml.transforms._checks import (
+     check_contains_keys,
+     check_is_instance,
+ )
+ from atomworks.ml.transforms.af3_reference_molecule import (
+     _encode_atom_names_like_af3,
+     get_af3_reference_molecule_features,
+ )
+ from atomworks.ml.transforms.base import (
+     Transform,
+ )
+ from atomworks.ml.utils.geometry import (
+     masked_center,
+     random_rigid_augmentation,
+ )
+ from atomworks.ml.utils.token import (
+     apply_token_wise,
+     get_token_starts,
+ )
+ from biotite.structure import AtomArray
+ from rfd3.constants import VIRTUAL_ATOM_ELEMENT_NAME
+ from rfd3.transforms.conditioning_base import (
+     UnindexFlaggedTokens,
+     get_motif_features,
+ )
+ from rfd3.transforms.rasa import discretize_rasa
+ from rfd3.transforms.util_transforms import (
+     AssignTypes,
+     add_backbone_and_sidechain_annotations,
+     get_af3_token_representative_masks,
+ )
+ from rfd3.transforms.virtual_atoms import PadTokensWithVirtualAtoms
+
+ from foundry.utils.ddp import RankedLogger  # noqa
+
+ #####################################################################################################
+ # Other design transforms
+ #####################################################################################################
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ class SubsampleToTypes(Transform):
+     """
+     Remove all atoms whose type is not listed in ``allowed_types``.
+
+     Possible allowed types:
+         - is_protein
+         - is_ligand
+         - is_dna
+         - is_rna
+     """
+
+     requires_previous_transforms = [AssignTypes]
+
+     def __init__(
+         self,
+         allowed_types: list | str = ["is_protein"],
+     ):
+         self.allowed_types = allowed_types
+         if self.allowed_types != "ALL":
+             for k in allowed_types:
+                 if not k.startswith("is_"):
+                     raise ValueError(f"Allowed types must start with 'is_', got {k}")
+
+     def check_input(self, data: dict):
+         check_contains_keys(data, ["atom_array"])
+
+     def forward(self, data):
+         atom_array = data["atom_array"]
+
+         # ... Subsampling
+         if self.allowed_types != "ALL":
+             is_allowed = np.zeros_like(atom_array.is_protein, dtype=bool)
+             for allowed_type in self.allowed_types:
+                 is_allowed = np.logical_or(
+                     np.asarray(is_allowed, dtype=bool),
+                     np.asarray(
+                         atom_array.get_annotation(allowed_type), dtype=bool
+                     ).copy(),
+                 )
+             atom_array = atom_array[is_allowed]
+
+         # ... Assert that any atoms remain at all
+         if atom_array.array_length() == 0:
+             raise ValueError(
+                 "No tokens found in the atom array! Example ID: {}".format(
+                     data.get("example_id", "unknown")
+                 )
+             )
+
+         # ... Assert that protein atoms remain
+         if atom_array.is_protein.sum() == 0:
+             raise ValueError(
+                 "No protein atoms found in the atom array. Example ID: {}".format(
+                     data.get("example_id", "unknown")
+                 )
+             )
+
+         data["atom_array"] = atom_array
+         return data
+
+
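+ # A minimal usage sketch for SubsampleToTypes (illustrative only; assumes the
+ # atom array has already been annotated by AssignTypes upstream):
+ #
+ #     transform = SubsampleToTypes(allowed_types=["is_protein", "is_ligand"])
+ #     data = transform({"atom_array": annotated_atom_array})
+ #     # data["atom_array"] now contains only protein and ligand atoms
+
+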
+ class CreateDesignReferenceFeatures(Transform):
+     """
+     Traditional AF3 creates a number of reference features based on the sequence and
+     molecular identity. For design we do not have access to the sequence, so these
+     features are uninformative.
+
+     However, this is a great place to add atom-level features as explicit conditioning
+     or implicit classifier-free guidance.
+
+     Reduces processing time from ~0.5 s to ~0.1 s on average.
+     """
+
+     requires_previous_transforms = [UnindexFlaggedTokens, AssignTypes]
+
+     def __init__(
+         self,
+         generate_conformers,
+         generate_conformers_for_non_protein_only,
+         provide_reference_conformer_when_unmasked,
+         ground_truth_conformer_policy,
+         provide_elements_for_unindexed_components,
+         use_element_for_atom_names_of_atomized_tokens,
+         **kwargs,
+     ):
+         self.generate_conformers = generate_conformers
+         self.generate_conformers_for_non_protein_only = (
+             generate_conformers_for_non_protein_only
+         )
+         self.provide_reference_conformer_when_unmasked = (
+             provide_reference_conformer_when_unmasked
+         )
+         if provide_reference_conformer_when_unmasked:
+             self.ground_truth_conformer_policy = GroundTruthConformerPolicy[
+                 ground_truth_conformer_policy
+             ]
+         else:
+             self.ground_truth_conformer_policy = GroundTruthConformerPolicy.IGNORE
+
+         self.provide_elements_for_unindexed_components = (
+             provide_elements_for_unindexed_components
+         )
+
+         self.conformer_generation_kwargs = {
+             "conformer_generation_timeout": 2.0,
+             "use_element_for_atom_names_of_atomized_tokens": use_element_for_atom_names_of_atomized_tokens,
+         } | kwargs
+
+     def check_input(self, data: dict):
+         check_contains_keys(data, ["atom_array"])
+
+     def forward(self, data: dict) -> dict:
+         atom_array = data["atom_array"]
+         I = atom_array.array_length()
+         token_starts = get_token_starts(atom_array)
+         token_level_array = atom_array[token_starts]
+         L = token_level_array.array_length()
+
+         # ... Set up default reference features
+         ref_pos = np.zeros_like(atom_array.coord, dtype=np.float32)
+         ref_pos[~atom_array.is_motif_atom_with_fixed_coord, :] = 0
+         ref_mask = np.zeros((I,), dtype=bool)
+         ref_charge = np.zeros((I,), dtype=np.int8)
+         ref_pos_is_ground_truth = np.zeros((I,), dtype=bool)
+
+         # ... For elements, provide only the elements of unindexed components
+         ref_element = np.zeros((I,), dtype=np.int64)
+         # if self.provide_elements_for_unindexed_components:
+         #     ref_element[atom_array.is_motif_atom_unindexed] = atom_array.atomic_number[
+         #         atom_array.is_motif_atom_unindexed
+         #     ]
+
+         # ... For atom names, provide all (spoofed) names in the atom array
+         ref_atom_name_chars = _encode_atom_names_like_af3(atom_array.atom_name)
+         _res_start_ends = get_residue_starts(atom_array, add_exclusive_stop=True)
+         _res_starts, _res_ends = _res_start_ends[:-1], _res_start_ends[1:]
+         ref_space_uid = struc.segments.spread_segment_wise(
+             _res_start_ends, np.arange(len(_res_starts), dtype=np.int64)
+         )
+
+         o = get_motif_features(atom_array)
+         is_motif_atom, is_motif_token = o["is_motif_atom"], o["is_motif_token"]
+         is_motif_token = is_motif_token[token_starts]
+
+         # ... Add a flag for atoms with zero occupancy
+         has_zero_occupancy = atom_array.occupancy == 0.0
+         if data["is_inference"] and has_zero_occupancy.any():
+             ranked_logger.warning(
+                 "Found zero-occupancy atoms in input; treating their occupancy as 1"
+             )
+             has_zero_occupancy = np.full_like(has_zero_occupancy, False)
+
+         # ... Token features for token type;
+         #     [1, 0, 0]: non-motif
+         #     [0, 1, 0]: indexed motif
+         #     [0, 0, 1]: unindexed motif
+         is_motif_token_unindexed = atom_array.is_motif_atom_unindexed[token_starts]
+         motif_token_class = np.zeros((L,), dtype=np.int8)
+         motif_token_class[is_motif_token] = 1
+         motif_token_class[is_motif_token_unindexed] = 2
+         motif_token_type = np.eye(3, dtype=np.int8)[
+             motif_token_class
+         ]  # one-hot, (L, 3)
+
+         # ... Provide GT as reference coordinates even when unfixed
+         motif_pos = np.nan_to_num(atom_array.coord.copy())
+         motif_pos = motif_pos * (is_motif_atom[..., None])
+
+         # ... Create reference features for the unmasked subset (where we may use the GT)
+         has_sequence = (
+             atom_array.is_motif_atom_with_fixed_seq
+             & ~atom_array.is_motif_atom_unindexed
+         )  # (n_atoms,)
+
+         if self.generate_conformers_for_non_protein_only:
+             has_sequence = has_sequence & ~atom_array.is_protein
+
+         if np.any(has_sequence):
+             # Subset atom level
+             atom_array_unmasked = atom_array[has_sequence]
+
+             # We always want to generate conformers if there are ligand atoms that are diffused
+             if (
+                 self.generate_conformers
+                 or np.logical_and(
+                     atom_array_unmasked.is_ligand, ~atom_array_unmasked.is_motif_atom
+                 ).any()
+             ):
+                 atom_array_unmasked.set_annotation(
+                     "ground_truth_conformer_policy",
+                     np.full(
+                         atom_array_unmasked.array_length(),
+                         self.ground_truth_conformer_policy.value,
+                     ),
+                 )
+
+                 # Compute the reference features
+                 # ... Create a copy of atom_array_unmasked and replace the atom names with
+                 #     gt_atom_name for reference conformer generation
+                 atom_array_unmasked_with_gt_atom_name = atom_array_unmasked.copy()
+                 atom_array_unmasked_with_gt_atom_name.atom_name = (
+                     atom_array_unmasked_with_gt_atom_name.gt_atom_name
+                 )
+                 reference_features_unmasked = get_af3_reference_molecule_features(
+                     atom_array_unmasked_with_gt_atom_name,
+                     cached_residue_level_data=data["cached_residue_level_data"]
+                     if "cached_residue_level_data" in data
+                     else None,
+                     **self.conformer_generation_kwargs,
+                 )[0]  # returns a tuple; the features dict is the first element
+
+                 ref_atom_name_chars[has_sequence] = reference_features_unmasked[
+                     "ref_atom_name_chars"
+                 ]
+                 ref_mask[has_sequence] = reference_features_unmasked["ref_mask"]
+                 ref_element[has_sequence] = reference_features_unmasked["ref_element"]
+                 ref_charge[has_sequence] = reference_features_unmasked["ref_charge"]
+                 ref_pos_is_ground_truth[has_sequence] = reference_features_unmasked[
+                     "ref_pos_is_ground_truth"
+                 ]
+
+                 # If requested, include the reference conformers for unmasked atoms
+                 if self.provide_reference_conformer_when_unmasked:
+                     ref_pos[has_sequence] = reference_features_unmasked["ref_pos"]
+             else:
+                 # Generate simple features
+                 ref_charge[has_sequence] = atom_array_unmasked.charge
+                 ref_element[has_sequence] = (
+                     atom_array_unmasked.atomic_number
+                     if "atomic_number" in atom_array.get_annotation_categories()
+                     else np.vectorize(ELEMENT_NAME_TO_ATOMIC_NUMBER.get)(
+                         atom_array_unmasked.element
+                     )
+                 )
+
+         reference_features = {
+             "ref_atom_name_chars": ref_atom_name_chars,  # (n_atoms, 4)
+             "ref_pos": ref_pos,  # (n_atoms, 3)
+             "ref_mask": ref_mask,  # (n_atoms)
+             "ref_element": ref_element,  # (n_atoms)
+             "ref_charge": ref_charge,  # (n_atoms)
+             "ref_space_uid": ref_space_uid,  # (n_atoms)
+             "ref_pos_is_ground_truth": ref_pos_is_ground_truth,  # (n_atoms)
+             "has_zero_occupancy": has_zero_occupancy,  # (n_atoms)
+             # Conditional masks
+             # "ref_is_motif_atom": is_motif_atom,  # (n_atoms, 2)
+             # "ref_is_motif_atom_mask": atom_array.is_motif_atom.copy(),  # (n_atoms)
+             # "ref_is_motif_token": is_motif_token,  # (n_tokens, 2)
+             # "ref_motif_atom_type": motif_atom_type,  # (n_atoms, 3)  # 3 types of atom conditions
+             "ref_is_motif_atom_with_fixed_coord": atom_array.is_motif_atom_with_fixed_coord.copy(),  # (n_atoms)
+             "ref_is_motif_atom_unindexed": atom_array.is_motif_atom_unindexed.copy(),
+             "ref_motif_token_type": motif_token_type,  # (n_tokens, 3)  # 3 types of token
+             "motif_pos": motif_pos,  # (n_atoms, 3)  # GT pos for motif atoms
+         }
+
+         # TEMPORARY HACK TO CREATE MOTIF FEATURES AGAIN
+         # f = get_motif_features(atom_array)
+         # token_starts = get_token_starts(atom_array)
+         # # Annots
+         # atom_array = data["atom_array"]
+         # atom_array.set_annotation("is_motif_atom", f["is_motif_atom"])
+         # atom_array.set_annotation("is_motif_token", f["is_motif_token"])
+         # data["atom_array"] = atom_array
+         # # Ref feats
+         # motif_atom_class = np.zeros((I,), dtype=np.int8)
+         # motif_atom_class[atom_array.is_motif_atom] = 1
+         # motif_atom_class[atom_array.is_motif_atom_unindexed] = 2
+         # motif_atom_type = np.eye(3, dtype=np.int8)[motif_atom_class]  # one-hot, (I, 3)
+         # is_motif_atom = torch.nn.functional.one_hot(
+         #     torch.from_numpy(atom_array.is_motif_atom).long(), num_classes=2
+         # ).numpy()
+         # is_motif_token = torch.nn.functional.one_hot(
+         #     torch.from_numpy(atom_array.is_motif_token[token_starts]).long(),
+         #     num_classes=2,
+         # ).numpy()
+         # reference_features["ref_motif_atom_type"] = motif_atom_type
+         # reference_features["ref_is_motif_atom"] = is_motif_atom
+         # reference_features["ref_is_motif_atom_mask"] = atom_array.is_motif_atom.copy()
+         # reference_features["ref_is_motif_token"] = is_motif_token
+         # reference_features["is_motif_atom"] = atom_array.is_motif_atom.astype(
+         #     bool
+         # ).copy()
+         # reference_features["is_motif_token"] = f["is_motif_token"][token_starts]
+         # END TEMPORARY HACK
+
+         if "feats" not in data:
+             data["feats"] = {}
+         data["feats"].update(reference_features)
+
+         return data
+
+
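+ # Illustrative sketch of the motif-token one-hot encoding used above (class
+ # indices as defined in this file; the toy array below is made up):
+ #
+ #     import numpy as np
+ #     motif_token_class = np.array([0, 1, 2], dtype=np.int8)  # non-motif, indexed, unindexed
+ #     np.eye(3, dtype=np.int8)[motif_token_class]
+ #     # -> [[1, 0, 0],
+ #     #     [0, 1, 0],
+ #     #     [0, 0, 1]]
+
+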
+ class FeaturizeAtoms(Transform):
+     def __init__(self, n_bins=4):
+         self.n_bins = n_bins
+
+     def forward(self, data):
+         atom_array = data["atom_array"]
+
+         if "feats" not in data:
+             data["feats"] = {}
+
+         if (
+             data["is_inference"]
+             and "rasa_bin" in atom_array.get_annotation_categories()
+         ):
+             rasa_binned = atom_array.get_annotation("rasa_bin").copy()
+         elif "rasa" in atom_array.get_annotation_categories():
+             rasa_binned = discretize_rasa(
+                 atom_array,
+                 n_bins=self.n_bins - 1,
+                 keep_protein_motif=data["conditions"]["keep_protein_motif_rasa"],
+             )
+         else:
+             rasa_binned = np.full(
+                 atom_array.array_length(), self.n_bins - 1, dtype=np.int64
+             )
+         rasa_oh = F.one_hot(
+             torch.from_numpy(rasa_binned).long(), num_classes=self.n_bins
+         ).numpy()
+         data["feats"]["ref_atomwise_rasa"] = rasa_oh[
+             ..., :-1
+         ]  # exclude the last bin from being fed to the model
+
+         if "active_donor" in atom_array.get_annotation_categories():
+             data["feats"]["active_donor"] = torch.tensor(
+                 np.float64(atom_array.active_donor)
+             ).long()
+         else:
+             data["feats"]["active_donor"] = torch.tensor(
+                 np.zeros(len(atom_array))
+             ).long()
+
+         if "active_acceptor" in atom_array.get_annotation_categories():
+             data["feats"]["active_acceptor"] = torch.tensor(
+                 np.float64(atom_array.active_acceptor)
+             ).long()
+         else:
+             data["feats"]["active_acceptor"] = torch.tensor(
+                 np.zeros(len(atom_array))
+             ).long()
+
+         return data
+
+
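+ # Illustrative sketch of the RASA one-hot handling above: with n_bins=4, bin 3 is
+ # the "no RASA available" fallback, and dropping the last column feeds the model
+ # an all-zero vector in that case (toy values, not from a real example):
+ #
+ #     import torch
+ #     import torch.nn.functional as F
+ #     rasa_binned = torch.tensor([0, 2, 3])            # 3 = fallback bin
+ #     F.one_hot(rasa_binned, num_classes=4)[..., :-1]
+ #     # -> [[1, 0, 0],
+ #     #     [0, 0, 1],
+ #     #     [0, 0, 0]]  # fallback row becomes all zeros
+
+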
+ class AddIsXFeats(Transform):
+     """
+     Adds boolean masks to the atom array based on the sequence type.
+
+     Types assigned to the atom array (X):
+         - is_backbone
+         - is_sidechain
+         - is_virtual
+         - is_central
+         - is_ca
+     Xs only returned as features (require previous assignment):
+         - is_motif_atom_with_fixed_coord
+         - is_motif_atom_unindexed
+         - is_motif_atom_with_fixed_seq
+     """
+
+     requires_previous_transforms = [
+         AssignTypes,
+         PadTokensWithVirtualAtoms,
+         "UnindexFlaggedTokens",
+     ]
+
+     def __init__(
+         self,
+         X,
+         central_atom,
+         extra_atom_level_feats: list[str] = [],
+         extra_token_level_feats: list[str] = [],
+     ):
+         self.X = X
+         self.central_atom = central_atom
+         self.update_atom_array = False
+         self.extra_atom_level_feats = extra_atom_level_feats
+         self.extra_token_level_feats = extra_token_level_feats
+
+     def check_input(self, data):
+         check_contains_keys(data, ["atom_array", "feats"])
+
+     def forward(self, data: dict) -> dict:
+         atom_array = data["atom_array"]
+         atom_array = add_backbone_and_sidechain_annotations(atom_array)
+         token_level_array = atom_array[get_token_starts(atom_array)]
+         _token_rep_mask = get_af3_token_representative_masks(
+             atom_array, central_atom=self.central_atom
+         )
+         _token_rep_idxs = np.where(_token_rep_mask)[0]
+
+         # ... Basic features
+         if "is_backbone" in self.X:
+             is_backbone = data["atom_array"].get_annotation("is_backbone")
+             data["feats"]["is_backbone"] = torch.from_numpy(is_backbone).to(
+                 dtype=torch.bool
+             )
+
+         if "is_sidechain" in self.X:
+             is_sidechain = data["atom_array"].get_annotation("is_sidechain")
+             data["feats"]["is_sidechain"] = torch.from_numpy(is_sidechain).to(
+                 dtype=torch.bool
+             )
+
+         # Virtual atom feats
+         if "is_virtual" in self.X:
+             data["feats"]["is_virtual"] = (
+                 atom_array.element == VIRTUAL_ATOM_ELEMENT_NAME
+             )
+
+         for x in [
+             "is_motif_atom_with_fixed_coord",
+             "is_motif_atom_with_fixed_seq",
+             "is_motif_atom_unindexed",
+         ]:
+             if x not in self.X:
+                 continue
+             if "atom" in x:
+                 mask = atom_array.get_annotation(x).copy().astype(bool)
+             else:
+                 mask = token_level_array.get_annotation(x).copy().astype(bool)
+             data["feats"][x] = mask
+
+         if "is_motif_token_with_fully_fixed_coord" in self.X:
+             mask = apply_token_wise(
+                 atom_array,
+                 atom_array.is_motif_atom_with_fixed_coord.astype(bool),
+                 function=lambda x: np.all(x, axis=-1),
+             )
+             data["feats"]["is_motif_token_with_fully_fixed_coord"] = mask
+
+         # ... Central and CA
+         if "is_central" in self.X:
+             data["feats"]["is_central"] = _token_rep_mask
+
+         if "is_ca" in self.X:
+             # Split into components to handle separately
+             atom_array_indexed = atom_array[~atom_array.is_motif_atom_unindexed]
+             _token_rep_mask_indexed = get_af3_token_representative_masks(
+                 atom_array_indexed, central_atom="CA"
+             )
+             if atom_array.is_motif_atom_unindexed.any():
+                 atom_array_unindexed = atom_array[atom_array.is_motif_atom_unindexed]
+
+                 # Ensure is_ca marks exactly one atom (the first) per unindexed token
+                 def first_nonzero(n):
+                     assert n > 0
+                     x = np.zeros(n, dtype=bool)
+                     x[0] = True
+                     return x
+
+                 starts = get_token_starts(atom_array_unindexed, add_exclusive_stop=True)
+                 _token_rep_mask_unindexed = np.concatenate(
+                     [
+                         first_nonzero(end - start)
+                         for start, end in zip(starts[:-1], starts[1:])
+                     ]
+                 )
+                 _token_rep_mask = np.concatenate(
+                     [
+                         _token_rep_mask_indexed,
+                         _token_rep_mask_unindexed,
+                     ],
+                     axis=0,
+                 )
+             else:
+                 _token_rep_mask = _token_rep_mask_indexed
+             data["feats"]["is_ca"] = _token_rep_mask
+
+         return data
+
+
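+ # A minimal configuration sketch for AddIsXFeats (values are illustrative; the
+ # feature names must come from the lists in the docstring above):
+ #
+ #     add_is_x = AddIsXFeats(
+ #         X=["is_backbone", "is_sidechain", "is_virtual", "is_central", "is_ca"],
+ #         central_atom="CA",
+ #     )
+ #     data = add_is_x(data)  # populates data["feats"]["is_backbone"], etc.
+
+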
+ PPI_PERTURB_SCALE = 2.0
+ PPI_PERTURB_COM_SCALE = 1.5
+
+
+ class MotifCenterRandomAugmentation(Transform):
+     requires_previous_transforms = ["BatchStructuresForDiffusionNoising"]
+
+     def __init__(
+         self,
+         batch_size,
+         sigma_perturb,
+         center_option,
+     ):
+         """
+         Randomly augments the coordinates of the motif center for diffusion training.
+         During inference, this behaviour is handled by the sampler at every step.
+         """
+         self.batch_size = batch_size
+         self.sigma_perturb = sigma_perturb
+         self.center_option = center_option
+
+     def check_input(self, data: dict):
+         pass
+
+     def forward(self, data):
+         """
+         Applies CenterRandomAugmentation and supplies the same rotated ground-truth
+         coordinates as the input feature.
+         """
+         if data["is_inference"]:
+             return data  # ori token behaviour set when creating atom array & in sampler
+
+         xyz = data["coord_atom_lvl_to_be_noised"]  # (batch_size, n_atoms, 3)
+         mask_atom_lvl = data["ground_truth"]["mask_atom_lvl"]
+         mask_atom_lvl = (
+             mask_atom_lvl & ~data["feats"]["is_motif_atom_unindexed"]
+         )  # Avoid double weighting
+
+         # Handle the different centering options
+         is_motif_atom_with_fixed_coord = torch.tensor(
+             data["atom_array"].is_motif_atom_with_fixed_coord, dtype=torch.bool
+         )
+         if torch.any(is_motif_atom_with_fixed_coord):
+             if self.center_option == "motif":
+                 center_mask = is_motif_atom_with_fixed_coord.clone()
+             elif self.center_option == "diffuse":
+                 center_mask = (~is_motif_atom_with_fixed_coord).clone()
+             else:
+                 center_mask = torch.ones(mask_atom_lvl.shape, dtype=torch.bool)
+         else:
+             center_mask = torch.ones(mask_atom_lvl.shape, dtype=torch.bool)
+
+         mask_atom_lvl = mask_atom_lvl & center_mask
+         mask_atom_lvl_expanded = mask_atom_lvl.expand(xyz.shape[0], -1)
+
+         # Masked center during training (nb not motif mask - just non-zero occupancy)
+         xyz = masked_center(xyz, mask_atom_lvl_expanded)
+
+         # Random offset
+         sigma_perturb = self.sigma_perturb
+         if data["sampled_condition_name"] == "ppi":
+             sigma_perturb = sigma_perturb * PPI_PERTURB_SCALE
+
+         xyz = (
+             xyz
+             + torch.randn(
+                 (
+                     self.batch_size,
+                     3,
+                 ),
+                 device=xyz.device,
+             )[:, None, :]
+             * sigma_perturb
+         )
+
+         # Apply random spin
+         xyz = random_rigid_augmentation(xyz, batch_size=self.batch_size, s=0)
+         data["coord_atom_lvl_to_be_noised"] = xyz
+
+         return data
+
+
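+ # A minimal sketch of the centering-and-offset step above (toy shapes; masked_center
+ # is assumed to subtract the mean of the masked atoms from every atom):
+ #
+ #     import torch
+ #     xyz = torch.randn(2, 5, 3)                # (batch, atoms, 3)
+ #     mask = torch.ones(2, 5, dtype=torch.bool)
+ #     centered = xyz - (xyz * mask[..., None]).sum(1, keepdim=True) / mask.sum(1)[:, None, None]
+ #     offset = torch.randn(2, 3)[:, None, :] * 1.0  # sigma_perturb = 1.0
+ #     xyz_aug = centered + offset                   # one shared offset per batch element
+
+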
+ class AugmentNoise(Transform):
+     requires_previous_transforms = ["SampleEDMNoise", "AddIsXFeats"]
+
+     def __init__(
+         self,
+         sigma_perturb_com,
+         batch_size,
+     ):
+         """
+         Scaled perturbation to the offset between motif and diffused region based on time.
+         """
+         self.sigma_perturb_com = sigma_perturb_com
+         self.batch_size = batch_size
+
+     def check_input(self, data: dict):
+         check_contains_keys(data, ["noise", "coord_atom_lvl_to_be_noised"])
+         check_contains_keys(data, ["feats"])
+
+     def forward(self, data: dict) -> dict:
+         is_motif_atom_with_fixed_coord = data["feats"]["is_motif_atom_with_fixed_coord"]
+         device = data["coord_atom_lvl_to_be_noised"].device
+         data["noise"][..., is_motif_atom_with_fixed_coord, :] = 0.0
+
+         # Add perturbation to the centre-of-mass
+         if not data["is_inference"] or not is_motif_atom_with_fixed_coord.any():
+             sigma_perturb_com = self.sigma_perturb_com
+             if data["sampled_condition_name"] == "ppi":
+                 sigma_perturb_com = sigma_perturb_com * PPI_PERTURB_COM_SCALE
+             eps = torch.randn(self.batch_size, 3, device=device) * sigma_perturb_com
+             maxt = 38
+             eps = eps * torch.clip((data["t"][:, None] / maxt) ** 3, min=0, max=1)
+             data["noise"][..., ~is_motif_atom_with_fixed_coord, :] += eps[:, None, :]
+
+         # ... Zero out noise going into motif
+         data["noise"][..., is_motif_atom_with_fixed_coord, :] = 0.0
+
+         assert data["coord_atom_lvl_to_be_noised"].shape == data["noise"].shape
+         return data
+
+
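+ # The COM perturbation above is ramped in with a cubic schedule: at noise level t
+ # the scale factor is clip((t / 38)**3, 0, 1), so the offset is negligible at low t
+ # and saturates for t >= 38. A quick illustration (toy values):
+ #
+ #     import torch
+ #     t = torch.tensor([1.0, 19.0, 38.0, 80.0])
+ #     torch.clip((t / 38.0) ** 3, min=0, max=1)
+ #     # -> tensor([1.8e-05, 1.25e-01, 1.0, 1.0])
+
+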
+ class AddGroundTruthSequence(Transform):
+     """
+     Adds the token-level sequence to the ground truth.
+
+     Adds:
+         ['ground_truth']['sequence_gt_I'] (torch.Tensor): ground-truth token-level sequence, [L,]
+         ['ground_truth']['sequence_valid_mask'] (torch.Tensor): mask of tokens with a known residue type, [L,]
+     """
+
+     def __init__(self, sequence_encoding):
+         self.sequence_encoding = sequence_encoding
+
+     def check_input(self, data):
+         check_contains_keys(data, ["atom_array"])
+
+     def forward(self, data: dict) -> dict:
+         atom_array = data["atom_array"]
+         token_starts = get_token_starts(atom_array)
+         res_names = atom_array.res_name[token_starts]
+         restype = self.sequence_encoding.encode(res_names)
+
+         if "ground_truth" not in data:
+             data["ground_truth"] = {}
+
+         data["ground_truth"]["sequence_gt_I"] = torch.from_numpy(restype)
+         data["ground_truth"]["sequence_valid_mask"] = torch.from_numpy(
+             ~np.isin(res_names, ["UNK", "X", "DX", "<G>"])
+         )
+
+         return data
+
+
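+ # Sketch of the valid-mask semantics above (toy residue names; the encoder is
+ # whatever `sequence_encoding` object the pipeline is configured with):
+ #
+ #     import numpy as np
+ #     res_names = np.array(["ALA", "UNK", "GLY"])
+ #     ~np.isin(res_names, ["UNK", "X", "DX", "<G>"])
+ #     # -> array([ True, False,  True])
+
+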
+ class AddAdditional1dFeaturesToFeats(Transform):
+     """
+     Adds any net.token_initializer.token_1d_features and
+     net.diffusion_module.diffusion_atom_encoder.atom_1d_features that are present in
+     the atom array but not in data['feats'] to data['feats'].
+
+     Args:
+         - autofill_zeros_if_not_present_in_atomarray: if True, features missing from
+           the atom array are filled with zeros instead of raising an error
+         - token_1d_features: mapping of feature_name -> n_feature_dims. Should be
+           hydra-interpolated from net.token_initializer.token_1d_features
+         - atom_1d_features: mapping of feature_name -> n_feature_dims. Should be
+           hydra-interpolated from net.diffusion_module.diffusion_atom_encoder.atom_1d_features
+     """
+
+     incompatible_previous_transforms = ["AddAdditional1dFeaturesToFeats"]
+
+     def __init__(
+         self,
+         token_1d_features,
+         atom_1d_features,
+         autofill_zeros_if_not_present_in_atomarray=False,
+     ):
+         self.autofill = autofill_zeros_if_not_present_in_atomarray
+         self.token_1d_features = token_1d_features
+         self.atom_1d_features = atom_1d_features
+
+     def check_input(self, data) -> None:
+         check_contains_keys(data, ["atom_array"])
+         check_is_instance(data, "atom_array", AtomArray)
+
+     def generate_feature(self, feature_name, n_dims, data, feature_type):
+         if feature_name in data["feats"].keys():
+             return data
+         elif feature_name in data["atom_array"].get_annotation_categories():
+             feature_array = torch.tensor(
+                 data["atom_array"].get_annotation(feature_name)
+             ).float()
+
+             # Ensure that feature_array is a 2d array with second dim n_dims
+             if len(feature_array.shape) == 1 and n_dims == 1:
+                 feature_array = feature_array.unsqueeze(1)
+             elif len(feature_array.shape) != 2:
+                 raise ValueError(
+                     f"{feature_type} 1d_feature `{feature_name}` must be a 2d array, got {len(feature_array.shape)}d."
+                 )
+             if feature_array.shape[1] != n_dims:
+                 raise ValueError(
+                     f"{feature_type} 1d_feature `{feature_name}` dimensions in atomarray ({feature_array.shape[-1]}) do not match the dimension declared in the config ({n_dims})"
+                 )
+
+         elif self.autofill:
+             feature_array = torch.zeros((len(data["atom_array"]), n_dims)).float()
+
+         # Not in data['feats'], not in the atom array, and autofill is off
+         else:
+             raise ValueError(
+                 f"{feature_type} 1d_feature `{feature_name}` is not present in atomarray, and autofill is False"
+             )
+
+         if feature_type == "token":
+             feature_array = torch.tensor(
+                 apply_token_wise(
+                     data["atom_array"], feature_array.numpy(), np.mean, axis=0
+                 )
+             ).float()
+
+         data["feats"][feature_name] = feature_array
+         return data
+
+     def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Checks if the 1d_features are present in data['feats']. If not present, adds
+         them from the atom array. If the annotation is not present in the atom array,
+         either autofills the feature with 0s or raises an error.
+         """
+         if "feats" not in data.keys():
+             data["feats"] = {}
+
+         for feature_name, n_dims in self.token_1d_features.items():
+             data = self.generate_feature(feature_name, n_dims, data, "token")
+
+         for feature_name, n_dims in self.atom_1d_features.items():
+             data = self.generate_feature(feature_name, n_dims, data, "atom")
+
+         return data
+
+
+ class FeaturizepLDDT(Transform):
+     """
+     Provides:
+         0 for unknown pLDDT
+         +1 for high pLDDT
+         -1 for low pLDDT
+     """
+
+     def __init__(
+         self,
+         skip,
+     ):
+         self.skip = skip
+         self.bsplit = 80  # Threshold for splitting pLDDT into high and low
+
+     def forward(self, data: dict) -> dict:
+         atom_array = data["atom_array"]
+         token_starts = get_token_starts(atom_array)
+         I = len(token_starts)
+         zeros = np.zeros(I, dtype=int)
+         if data["is_inference"]:
+             if "ref_plddt" not in atom_array.get_annotation_categories():
+                 ref_plddt = zeros
+             else:
+                 ref_plddt = atom_array.get_annotation("ref_plddt")[token_starts]
+         elif (
+             self.skip
+             or "b_factor" not in atom_array.get_annotation_categories()
+             or not data["conditions"]["featurize_plddt"]
+         ):
+             ref_plddt = zeros
+         else:
+             plddt = atom_array.get_annotation("b_factor")
+             mean_plddt = np.mean(plddt)
+             ref_plddt = zeros + (1 if mean_plddt >= self.bsplit else -1)
+
+         # Provide only non-zero values for diffused tokens
+         ref_plddt = (
+             ~(get_motif_features(atom_array)["is_motif_token"][token_starts])
+             * ref_plddt
+         )
+         data["feats"]["ref_plddt"] = ref_plddt
+         return data
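+
+
+ # Sketch of the pLDDT flag computed above (toy numbers; 80 is the bsplit threshold
+ # defined in this class):
+ #
+ #     import numpy as np
+ #     mean_plddt = np.mean(np.array([85.0, 90.0, 70.0]))  # 81.67
+ #     flag = 1 if mean_plddt >= 80 else -1                # -> +1 (high confidence)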