rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,632 @@
1
+ """
2
+ The Atom14 data pipeline for training and inference
3
+ """
4
+
5
+ import warnings
6
+ from pathlib import Path
7
+ from typing import List
8
+
9
+ import numpy as np
10
+ from atomworks.constants import (
11
+ AF3_EXCLUDED_LIGANDS,
12
+ GAP,
13
+ STANDARD_AA,
14
+ STANDARD_DNA,
15
+ STANDARD_RNA,
16
+ )
17
+ from atomworks.ml.encoding_definitions import AF3SequenceEncoding
18
+ from atomworks.ml.transforms.atom_array import (
19
+ AddGlobalAtomIdAnnotation,
20
+ AddGlobalTokenIdAnnotation,
21
+ AddProteinTerminiAnnotation,
22
+ AddWithinChainInstanceResIdx,
23
+ AddWithinPolyResIdxAnnotation,
24
+ ComputeAtomToTokenMap,
25
+ CopyAnnotation,
26
+ )
27
+ from atomworks.ml.transforms.atomize import (
28
+ AtomizeByCCDName,
29
+ FlagNonPolymersForAtomization,
30
+ )
31
+ from atomworks.ml.transforms.base import (
32
+ AddData,
33
+ Compose,
34
+ ConditionalRoute,
35
+ ConvertToTorch,
36
+ Identity,
37
+ RandomRoute,
38
+ SubsetToKeys,
39
+ )
40
+ from atomworks.ml.transforms.bfactor_conditioned_transforms import SetOccToZeroOnBfactor
41
+ from atomworks.ml.transforms.bonds import AddAF3TokenBondFeatures
42
+ from atomworks.ml.transforms.cached_residue_data import LoadCachedResidueLevelData
43
+ from atomworks.ml.transforms.covalent_modifications import (
44
+ FlagAndReassignCovalentModifications,
45
+ )
46
+ from atomworks.ml.transforms.crop import CropContiguousLikeAF3, CropSpatialLikeAF3
47
+ from atomworks.ml.transforms.diffusion.batch_structures import (
48
+ BatchStructuresForDiffusionNoising,
49
+ )
50
+ from atomworks.ml.transforms.diffusion.edm import SampleEDMNoise
51
+ from atomworks.ml.transforms.featurize_unresolved_residues import (
52
+ MaskPolymerResiduesWithUnresolvedFrameAtoms,
53
+ PlaceUnresolvedTokenAtomsOnRepresentativeAtom,
54
+ PlaceUnresolvedTokenOnClosestResolvedTokenInSequence,
55
+ )
56
+ from atomworks.ml.transforms.filters import (
57
+ FilterToSpecifiedPNUnits,
58
+ HandleUndesiredResTokens,
59
+ RemoveHydrogens,
60
+ RemoveNucleicAcidTerminalOxygen,
61
+ RemovePolymersWithTooFewResolvedResidues,
62
+ RemoveTerminalOxygen,
63
+ RemoveUnresolvedLigandAtomsIfTooMany,
64
+ RemoveUnresolvedPNUnits,
65
+ )
66
+ from atomworks.ml.utils.token import get_token_count
67
+ from rfd3.transforms.conditioning_base import (
68
+ SampleConditioningFlags,
69
+ SampleConditioningType,
70
+ StrtoBoolforIsXFeatures,
71
+ UnindexFlaggedTokens,
72
+ )
73
+ from rfd3.transforms.design_transforms import (
74
+ AddAdditional1dFeaturesToFeats,
75
+ AddGroundTruthSequence,
76
+ AddIsXFeats,
77
+ AssignTypes,
78
+ AugmentNoise,
79
+ CreateDesignReferenceFeatures,
80
+ FeaturizeAtoms,
81
+ FeaturizepLDDT,
82
+ MotifCenterRandomAugmentation,
83
+ SubsampleToTypes,
84
+ )
85
+ from rfd3.transforms.dna_crop import ProteinDNAContactContiguousCrop
86
+ from rfd3.transforms.hbonds_hbplus import CalculateHbondsPlus
87
+ from rfd3.transforms.ppi_transforms import (
88
+ Add1DSSFeature,
89
+ AddGlobalIsNonLoopyFeature,
90
+ AddPPIHotspotFeature,
91
+ PPIFullBinderCropSpatial,
92
+ )
93
+ from rfd3.transforms.rasa import (
94
+ CalculateRASA,
95
+ SetZeroOccOnDeltaRASA,
96
+ )
97
+ from rfd3.transforms.symmetry import AddSymmetryFeats
98
+ from rfd3.transforms.util_transforms import (
99
+ IPDB,
100
+ AggregateFeaturesLikeAF3WithoutMSA,
101
+ EncodeAF3TokenLevelFeatures,
102
+ RemoveTokensWithoutCorrespondingCentralAtom,
103
+ )
104
+ from rfd3.transforms.virtual_atoms import PadTokensWithVirtualAtoms
105
+
106
+ from foundry.common import exists
107
+
108
######################################################################################
# Common transforms
######################################################################################
# One module-level AF3 sequence encoding instance, reused by every pipeline built in this
# file (condition sampling, token-level encoding and ground-truth sequence transforms).
af3_sequence_encoding = AF3SequenceEncoding()


# Bare expression statement referencing the imported IPDB so linters do not report the
# import as unused. NOTE(review): presumably kept so IPDB stays importable from this
# module (re-export / debugging convenience) — confirm before removing.
IPDB  # noqa
115
+
116
+
117
def TrainingRoute(transform):
    """Gate ``transform`` so it only runs while building training examples.

    Dispatches on ``data["is_inference"]``: when the flag is ``False`` the data is
    routed through ``transform``; when it is ``True`` it is routed through a no-op
    ``Identity`` transform.

    Args:
        transform: The transform to apply during training only.

    Returns:
        A ``ConditionalRoute`` implementing the gate.
    """

    def _read_inference_flag(data):
        return data["is_inference"]

    routes = {True: Identity(), False: transform}
    return ConditionalRoute(condition_func=_read_inference_flag, transform_map=routes)
122
+
123
+
124
def InferenceRoute(transform):
    """Gate ``transform`` so it only runs while building inference examples.

    Mirror image of ``TrainingRoute``: when ``data["is_inference"]`` is ``True`` the
    data is routed through ``transform``; when it is ``False`` it is routed through
    a no-op ``Identity`` transform.

    Args:
        transform: The transform to apply during inference only.

    Returns:
        A ``ConditionalRoute`` implementing the gate.
    """

    def _read_inference_flag(data):
        return data["is_inference"]

    routes = {False: Identity(), True: transform}
    return ConditionalRoute(condition_func=_read_inference_flag, transform_map=routes)
129
+
130
+
131
def TrainingConditionRoute(condition, transform):
    """Apply ``transform`` during training when a sampled condition flag is set.

    The inner route looks up ``data["conditions"][condition]`` and dispatches on it:
    ``True`` applies ``transform``, ``False`` passes the data through ``Identity``.
    The whole route is wrapped in ``TrainingRoute``, so at inference the condition
    lookup never happens and the data passes through unchanged.

    Args:
        condition: Key into ``data["conditions"]`` holding the boolean flag.
        transform: The transform to apply when the flag is ``True``.

    Returns:
        The training-gated, condition-gated route.
    """

    def _condition_is_set(data):
        return data["conditions"][condition]

    gated_by_condition = ConditionalRoute(
        condition_func=_condition_is_set,
        transform_map={True: transform, False: Identity()},
    )
    return TrainingRoute(gated_by_condition)
142
+
143
+
144
def get_pre_crop_transforms(
    central_atom: str,
    b_factor_min: float | None,
):
    """Return the ordered list of transforms applied before cropping.

    The list is assembled from four consecutive stages (input clean-up, resolution /
    token filtering, atomization, annotation) and concatenated in that order.
    Several steps are wrapped in ``TrainingRoute`` and therefore only run during
    training.

    Args:
        central_atom: Atom name passed to ``RemoveTokensWithoutCorrespondingCentralAtom``
            (training only).
        b_factor_min: First argument to ``SetOccToZeroOnBfactor`` (training only).
            NOTE(review): presumably a lower B-factor threshold below which occupancy
            is zeroed; ``None`` is forwarded as-is — confirm semantics.

    Returns:
        A list of transforms.
    """
    input_cleanup = [
        InferenceRoute(StrtoBoolforIsXFeatures()),
        RemoveHydrogens(),
        # Filter to non-clashing PN units
        FilterToSpecifiedPNUnits(
            extra_info_key_with_pn_unit_iids_to_keep="all_pn_unit_iids_after_processing"
        ),
        RemoveTerminalOxygen(),
    ]

    resolution_filters = [
        # Remove PN units that are unresolved early (and also after cropping)
        TrainingRoute(SetOccToZeroOnBfactor(b_factor_min, None)),
        RemoveUnresolvedPNUnits(),
        # Remove polymers with too few resolved residues
        TrainingRoute(RemovePolymersWithTooFewResolvedResidues(min_residues=4)),
        MaskPolymerResiduesWithUnresolvedFrameAtoms(),
        # Undesired residue names are only filtered in training: at inference their
        # presence in the input is intentional.
        TrainingRoute(HandleUndesiredResTokens(AF3_EXCLUDED_LIGANDS)),
        # Bulk removal of unresolved atoms
        TrainingRoute(
            RemoveUnresolvedLigandAtomsIfTooMany(unresolved_ligand_atom_limit=5)
        ),
        # Training drops tokens lacking a central atom; at inference, padding ensures
        # every residue has one.
        TrainingRoute(
            RemoveTokensWithoutCorrespondingCentralAtom(central_atom=central_atom),
        ),
    ]

    atomization = [
        FlagAndReassignCovalentModifications(),
        FlagNonPolymersForAtomization(),
        AddGlobalAtomIdAnnotation(),
        AtomizeByCCDName(
            atomize_by_default=True,
            res_names_to_ignore=STANDARD_AA + STANDARD_RNA + STANDARD_DNA,
            move_atomized_part_to_end=False,
            validate_atomize=False,
        ),
    ]

    annotations = [
        RemoveNucleicAcidTerminalOxygen(),
        AddWithinChainInstanceResIdx(),
        AddWithinPolyResIdxAnnotation(),
        AddProteinTerminiAnnotation(),
    ]

    return input_cleanup + resolution_filters + atomization + annotations
185
+
186
+
187
+ def get_crop_transform(
188
+ crop_size: int,
189
+ crop_center_cutoff_distance: float,
190
+ crop_contiguous_probability: float,
191
+ crop_spatial_probability: float,
192
+ dna_contact_crop_probability: float,
193
+ keep_full_binder_in_spatial_crop: bool,
194
+ max_binder_length: int,
195
+ max_atoms_in_crop: int | None,
196
+ allowed_types: List[str],
197
+ ):
198
+ if (
199
+ crop_contiguous_probability > 0
200
+ or crop_spatial_probability > 0
201
+ or dna_contact_crop_probability > 0
202
+ ):
203
+ assert np.isclose(
204
+ crop_contiguous_probability
205
+ + crop_spatial_probability
206
+ + dna_contact_crop_probability,
207
+ 1.0,
208
+ atol=1e-6,
209
+ ), "Crop probabilities must sum to 1.0"
210
+ assert crop_size > 0, "Crop size must be greater than 0"
211
+ assert (
212
+ crop_center_cutoff_distance > 0
213
+ ), "Crop center cutoff distance must be greater than 0"
214
+
215
+ pre_crop_transforms = [
216
+ SubsampleToTypes(allowed_types=allowed_types),
217
+ ]
218
+
219
+ cropping_transform = RandomRoute(
220
+ transforms=[
221
+ CropContiguousLikeAF3(
222
+ crop_size=crop_size,
223
+ keep_uncropped_atom_array=True,
224
+ max_atoms_in_crop=max_atoms_in_crop,
225
+ ),
226
+ ConditionalRoute(
227
+ condition_func=lambda data: (
228
+ keep_full_binder_in_spatial_crop
229
+ and data["sampled_condition_name"] == "ppi"
230
+ and get_token_count(
231
+ data["atom_array"][data["atom_array"].is_binder_pn_unit]
232
+ )
233
+ < max_binder_length
234
+ and data["conditions"]["full_binder_crop"]
235
+ ),
236
+ transform_map={
237
+ True: PPIFullBinderCropSpatial(
238
+ crop_size=crop_size,
239
+ crop_center_cutoff_distance=crop_center_cutoff_distance,
240
+ keep_uncropped_atom_array=True,
241
+ max_atoms_in_crop=max_atoms_in_crop,
242
+ ),
243
+ False: CropSpatialLikeAF3(
244
+ crop_size=crop_size,
245
+ crop_center_cutoff_distance=crop_center_cutoff_distance,
246
+ keep_uncropped_atom_array=True,
247
+ max_atoms_in_crop=max_atoms_in_crop,
248
+ ),
249
+ },
250
+ ),
251
+ ProteinDNAContactContiguousCrop(
252
+ protein_contact_type="all",
253
+ dna_contact_type="base",
254
+ max_atoms_in_crop=max_atoms_in_crop,
255
+ ),
256
+ ],
257
+ probs=[
258
+ crop_contiguous_probability,
259
+ crop_spatial_probability,
260
+ dna_contact_crop_probability,
261
+ ],
262
+ )
263
+
264
+ post_crop_transforms = [
265
+ # ... Handling of remaining unresolved residues (NOTE: usually best done after inputs are processed.)
266
+ TrainingRoute(
267
+ PlaceUnresolvedTokenAtomsOnRepresentativeAtom(annotation_to_update="coord")
268
+ ),
269
+ TrainingRoute(
270
+ PlaceUnresolvedTokenOnClosestResolvedTokenInSequence(
271
+ annotation_to_update="coord",
272
+ annotation_to_copy="coord",
273
+ )
274
+ ),
275
+ ]
276
+
277
+ transform = (
278
+ pre_crop_transforms
279
+ + [
280
+ TrainingRoute(cropping_transform),
281
+ ]
282
+ + post_crop_transforms
283
+ )
284
+ return transform
285
+
286
+
287
def get_diffusion_transforms(
    *,
    sigma_data: float,
    diffusion_batch_size: int,
):
    """Return the final transforms that prepare an example for EDM-style diffusion.

    Stages, in order:
    1. compute the atom-to-token map and convert ``"encoded"`` / ``"feats"`` to torch;
    2. copy ``coord`` into ``coord_to_be_noised`` so noising never modifies the
       ground-truth coordinates, then aggregate features (AF3-style, without MSA);
    3. batch structures for noising and sample EDM noise.

    Args:
        sigma_data: Forwarded to ``SampleEDMNoise``.
        diffusion_batch_size: Batch size for ``BatchStructuresForDiffusionNoising``
            and ``SampleEDMNoise``.

    Returns:
        A list of transforms.
    """
    to_torch = [
        ComputeAtomToTokenMap(),
        ConvertToTorch(keys=["encoded", "feats"]),
    ]
    prepare_noising_inputs = [
        CopyAnnotation(annotation_to_copy="coord", new_annotation="coord_to_be_noised"),
        AggregateFeaturesLikeAF3WithoutMSA(),
    ]
    noising = [
        BatchStructuresForDiffusionNoising(batch_size=diffusion_batch_size),
        SampleEDMNoise(
            sigma_data=sigma_data, diffusion_batch_size=diffusion_batch_size
        ),
    ]
    return [*to_torch, *prepare_noising_inputs, *noising]
306
+
307
+
308
+ ######################################################################################
309
+ # Pipelines
310
+ ######################################################################################
311
+
312
+
313
def build_atom14_base_pipeline_(
    *,
    # Training or inference (required)
    is_inference: bool,  # If True, we skip cropping, etc.
    return_atom_array: bool,
    # Crop params
    allowed_types: List[str] | str,  # the inference wrapper passes "ALL"
    crop_size: int,
    crop_center_cutoff_distance: float,
    crop_contiguous_probability: float,
    crop_spatial_probability: float,
    dna_contact_crop_probability: float,
    keep_full_binder_in_spatial_crop: bool,
    max_binder_length: int,  # Only relevant when keep_full_binder_in_spatial_crop is True
    max_atoms_in_crop: int | None,
    b_factor_min: float | None,
    zero_occ_on_exposure_after_cropping: bool,
    # Training Hypers
    sigma_data: float,
    diffusion_batch_size: int,
    # Reference conformer policy
    generate_conformers: bool,
    generate_conformers_for_non_protein_only: bool,
    provide_reference_conformer_when_unmasked: bool,
    ground_truth_conformer_policy: str,
    provide_elements_for_unindexed_components: bool,
    use_element_for_atom_names_of_atomized_tokens: bool,
    residue_cache_dir: str | Path | None,  # was mis-annotated as ``bool``
    # Conditioning
    train_conditions: dict,
    meta_conditioning_probabilities: dict,
    # Atom14/Model
    n_atoms_per_token: int,
    central_atom: str,
    sigma_perturb: float,
    sigma_perturb_com: float,
    association_scheme: str | None,
    center_option: str,
    atom_1d_features: dict | None,
    token_1d_features: dict | None,
    # PPI features
    max_ppi_hotspots_frac_to_provide: float,
    ppi_hotspot_max_distance: float,
    # Secondary structure features
    max_ss_frac_to_provide: float,
    min_ss_island_len: int,
    max_ss_island_len: int,
    **_,  # dump additional kwargs (e.g. msa stuff)
):
    """All-Atom design pipeline.

    Builds a ``Compose`` of transforms in this order:

    1. seed the data dict (``is_inference``, ``sampled_condition_name``, empty
       ``conditions``) and assign types;
    2. (training) sample a conditioning type;
    3. pre-crop filtering / atomization (``get_pre_crop_transforms``);
    4. subsampling and (training) cropping (``get_crop_transform``), with optional
       RASA-based occupancy zeroing around the crop;
    5. global token ids, (training) conditioning flags and hbond calculation;
    6. design transforms: residue cache, unindexing, virtual-atom padding, optional
       PPI / secondary-structure / non-loopy features, token and reference
       features, ``is_X`` features, bond features, symmetry features;
    7. diffusion preparation (``get_diffusion_transforms``) and motif-aware random
       augmentation;
    8. ``SubsetToKeys`` restricting the output dict.

    Args:
        is_inference: Stored in the data dict; training-only steps are skipped
            when True.
        return_atom_array: When True, ``atom_array`` and ``specification`` are also
            kept in the output.
        residue_cache_dir: Directory for ``LoadCachedResidueLevelData``, converted
            to ``Path`` when set; ``None`` disables it.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        The composed pipeline.

    Note:
        Calling this installs process-wide ``warnings`` filters that ignore every
        ``RuntimeWarning`` and ``DeprecationWarning``; they persist after the call.
    """
    # NOTE(review): process-global side effect, not scoped to the pipeline.
    warnings.filterwarnings("ignore", category=RuntimeWarning)
    warnings.filterwarnings("ignore", category=DeprecationWarning)

    # Add any data necessary for downstream transforms
    transforms = [
        AddData(
            {
                "is_inference": is_inference,
                "sampled_condition_name": None,
                "conditions": {},
            }
        ),
        AssignTypes(),
    ]
    # During training, sample condition | adds 'condition': TrainingCondition to data dict
    transforms += [
        TrainingRoute(
            SampleConditioningType(
                train_conditions=train_conditions,
                meta_conditioning_probabilities=meta_conditioning_probabilities,
                sequence_encoding=af3_sequence_encoding,
            ),
        ),
    ]

    # Pre-crop transforms
    transforms += get_pre_crop_transforms(
        central_atom=central_atom,
        b_factor_min=b_factor_min,
    )
    if zero_occ_on_exposure_after_cropping:
        # RASA on the uncropped structure; presumably the reference that
        # SetZeroOccOnDeltaRASA compares against after cropping.
        transforms.append(TrainingRoute(CalculateRASA(requires_ligand=False)))

    transforms += get_crop_transform(
        crop_size=crop_size,
        crop_center_cutoff_distance=crop_center_cutoff_distance,
        crop_contiguous_probability=crop_contiguous_probability,
        crop_spatial_probability=crop_spatial_probability,
        dna_contact_crop_probability=dna_contact_crop_probability,
        keep_full_binder_in_spatial_crop=keep_full_binder_in_spatial_crop,
        max_binder_length=max_binder_length,
        max_atoms_in_crop=max_atoms_in_crop,
        allowed_types=allowed_types,
    )

    if zero_occ_on_exposure_after_cropping:
        # Optional: Zero out sidechain occupancy for atoms that have become exposed
        transforms.append(TrainingRoute(SetZeroOccOnDeltaRASA()))
    else:
        # RASA calculated after cropping
        transforms.append(
            TrainingConditionRoute(
                "calculate_rasa", CalculateRASA(requires_ligand=True)
            )
        )

    # ... Add global token features (since number of tokens is fixed after cropping)
    transforms.append(AddGlobalTokenIdAnnotation())
    # ... Create masks (NOTE: Modulates token count, and resets global token id if necessary)
    # Condition flags must be sampled before the hbond calculation so that the
    # is-motif-atom annotations exist and the full motif can be used for hbonds.
    transforms.append(TrainingRoute(SampleConditioningFlags()))

    # Post-crop transforms
    transforms.append(
        TrainingConditionRoute(
            "calculate_hbonds",
            CalculateHbondsPlus(
                cutoff_HA_dist=3,
                cutoff_DA_distance=3.5,
            ),
        )
    )

    # Design Transforms
    transforms += [
        LoadCachedResidueLevelData(
            dir=Path(residue_cache_dir) if exists(residue_cache_dir) else None,
            sharding_depth=1,
        ),
        # ... Fuse inference and training conditioning assignments
        UnindexFlaggedTokens(central_atom=central_atom),
        # ... Virtual atom padding (NOTE: Last transform which modulates atom count)
        PadTokensWithVirtualAtoms(
            n_atoms_per_token=n_atoms_per_token,
            atom_to_pad_from=central_atom,
            association_scheme=association_scheme,
        ),  # 0.1 s
        # Possibly add hotspots (training, "ppi" condition with hotspots enabled)
        TrainingRoute(
            ConditionalRoute(
                condition_func=lambda data: data["sampled_condition_name"] == "ppi"
                and data["conditions"]["add_ppi_hotspots"],
                transform_map={
                    True: AddPPIHotspotFeature(
                        max_hotspots_frac_to_provide=max_ppi_hotspots_frac_to_provide,
                        hotspot_max_distance=ppi_hotspot_max_distance,
                    ),
                    False: Identity(),
                },
            )
        ),
        TrainingRoute(
            Add1DSSFeature(
                max_secondary_structure_frac_to_provide=max_ss_frac_to_provide,
                min_ss_island_len=min_ss_island_len,
                max_ss_island_len=max_ss_island_len,
            ),
        ),
        TrainingRoute(
            ConditionalRoute(
                condition_func=lambda data: data["conditions"][
                    "add_global_is_non_loopy_feature"
                ],
                transform_map={
                    True: AddGlobalIsNonLoopyFeature(),
                    False: Identity(),
                },
            )
        ),
        # ... AF3 token level encoding with sequence masking
        EncodeAF3TokenLevelFeatures(
            sequence_encoding=af3_sequence_encoding, encode_residues_to=GAP
        ),
        # ... Atom-level reference features
        CreateDesignReferenceFeatures(
            generate_conformers=generate_conformers,
            generate_conformers_for_non_protein_only=generate_conformers_for_non_protein_only,
            provide_reference_conformer_when_unmasked=provide_reference_conformer_when_unmasked,
            ground_truth_conformer_policy=ground_truth_conformer_policy,
            provide_elements_for_unindexed_components=provide_elements_for_unindexed_components,
            use_element_for_atom_names_of_atomized_tokens=use_element_for_atom_names_of_atomized_tokens,
        ),
        # ... Add useful features for losses / metrics
        AddIsXFeats(
            X=[
                # Basic
                "is_backbone",
                "is_sidechain",
                # Virtual atom
                "is_virtual",
                "is_central",
                "is_ca",
                # Conditioning
                "is_motif_atom_with_fixed_coord",
                "is_motif_atom_unindexed",
                "is_motif_atom_with_fixed_seq",
                "is_motif_token_with_fully_fixed_coord",
            ],
            central_atom=central_atom,
        ),
        FeaturizeAtoms(),
        # pLDDT featurization is skipped whenever a b-factor threshold is configured.
        FeaturizepLDDT(skip=b_factor_min is not None),
        AddAdditional1dFeaturesToFeats(
            autofill_zeros_if_not_present_in_atomarray=True,
            token_1d_features=token_1d_features,
            atom_1d_features=atom_1d_features,
        ),
        AddAF3TokenBondFeatures(),
        AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),
        # Symmetry features only when the atom array carries a symmetry_id annotation
        ConditionalRoute(
            condition_func=lambda data: "symmetry_id"
            in data["atom_array"].get_annotation_categories(),
            transform_map={
                True: AddSymmetryFeats(),
                False: Identity(),
            },
        ),
    ]

    # EDM-style wrap-up (no additional features added at this point)
    transforms += get_diffusion_transforms(
        sigma_data=sigma_data,
        diffusion_batch_size=diffusion_batch_size,
    )

    # ... Random augmentation accounting for motif
    transforms += [
        MotifCenterRandomAugmentation(
            batch_size=diffusion_batch_size,
            sigma_perturb=sigma_perturb,
            center_option=center_option,
        ),
        AugmentNoise(
            sigma_perturb_com=sigma_perturb_com,
            batch_size=diffusion_batch_size,
        ),
    ]

    # Subset to necessary keys only
    keys_to_keep = [
        "example_id",
        "feats",
        "t",
        "noise",
        "ground_truth",
        "coord_atom_lvl_to_be_noised",
        "extra_info",
        "sampled_condition_name",
        "log_dict",
    ]
    if return_atom_array:
        keys_to_keep.extend(
            [
                "atom_array",
                "specification",
            ]
        )
    # For debugging & tests:
    if not is_inference:
        keys_to_keep.append("conditions")
    transforms.append(SubsetToKeys(keys_to_keep))

    pipeline = Compose(transforms)
    return pipeline
579
+
580
+
581
def build_atom14_base_pipeline(
    is_inference: bool,
    *,
    # Dumped args:
    protein_msa_dirs=None,
    rna_msa_dirs=None,
    n_recycles=None,
    n_msa=None,
    # Catch all other arguments:
    **kwargs,
):
    """Build the Atom14 pipeline, filling in inference-time defaults.

    The MSA / recycling keyword arguments are accepted and discarded, so configs
    that carry them can be passed straight through.

    In training mode ``kwargs`` are forwarded untouched, so every required
    parameter of ``build_atom14_base_pipeline_`` must be supplied by the caller.
    In inference mode, missing training-only arguments get fixed defaults.
    Explicitly passed values always win. This keeps older inference configs
    working (backward compatibility).
    """
    if not is_inference:
        return build_atom14_base_pipeline_(is_inference=is_inference, **kwargs)

    # Built per call so the mutable ``{}`` defaults are never shared between calls.
    inference_defaults = {
        # Explicit defaults for training-only args
        "crop_size": 512,
        "crop_center_cutoff_distance": 10.0,
        "crop_contiguous_probability": 1.0,
        "crop_spatial_probability": 0.0,
        "dna_contact_crop_probability": 0.0,
        "max_atoms_in_crop": None,
        "keep_full_binder_in_spatial_crop": True,
        "max_ppi_hotspots_frac_to_provide": 0,
        "ppi_hotspot_max_distance": 15,
        "max_ss_frac_to_provide": 0.0,
        "min_ss_island_len": 0,
        "max_ss_island_len": 999,
        "max_binder_length": 999,
        "b_factor_min": None,
        "zero_occ_on_exposure_after_cropping": False,
        "meta_conditioning_probabilities": {},
        "association_scheme": "dense",
        "sigma_perturb": 0.0,
        "sigma_perturb_com": 0.0,
        "allowed_types": "ALL",
        "train_conditions": {},
        "residue_cache_dir": None,
        # TODO: Delete these once all checkpoints are updated with the latest defaults
        "generate_conformers_for_non_protein_only": True,
        "return_atom_array": True,
        "provide_elements_for_unindexed_components": False,
        "center_option": "all",
    }
    # Caller-supplied values override the defaults (same effect as setdefault).
    resolved_kwargs = {**inference_defaults, **kwargs}
    return build_atom14_base_pipeline_(is_inference=is_inference, **resolved_kwargs)