rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/data/pipelines.py ADDED
@@ -0,0 +1,558 @@
1
+ from os import PathLike
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ from atomworks.common import exists
6
+ from atomworks.constants import (
7
+ AF3_EXCLUDED_LIGANDS,
8
+ STANDARD_AA,
9
+ STANDARD_DNA,
10
+ STANDARD_RNA,
11
+ )
12
+ from atomworks.enums import ChainType
13
+ from atomworks.ml.encoding_definitions import RF2AA_ATOM36_ENCODING, AF3SequenceEncoding
14
+ from atomworks.ml.transforms.af3_reference_molecule import (
15
+ GetAF3ReferenceMoleculeFeatures,
16
+ GroundTruthConformerPolicy,
17
+ RandomApplyGroundTruthConformerByChainType,
18
+ )
19
+ from atomworks.ml.transforms.atom_array import (
20
+ AddGlobalAtomIdAnnotation,
21
+ AddGlobalResIdAnnotation,
22
+ AddGlobalTokenIdAnnotation,
23
+ AddWithinChainInstanceResIdx,
24
+ AddWithinPolyResIdxAnnotation,
25
+ ComputeAtomToTokenMap,
26
+ CopyAnnotation,
27
+ )
28
+ from atomworks.ml.transforms.atom_frames import (
29
+ AddAtomFrames,
30
+ AddIsRealAtom,
31
+ AddPolymerFrameIndices,
32
+ )
33
+ from atomworks.ml.transforms.atom_level_embeddings import FeaturizeAtomLevelEmbeddings
34
+ from atomworks.ml.transforms.atomize import (
35
+ AtomizeByCCDName,
36
+ FlagNonPolymersForAtomization,
37
+ )
38
+ from atomworks.ml.transforms.base import (
39
+ AddData,
40
+ ApplyFunction,
41
+ Compose,
42
+ ConditionalRoute,
43
+ ConvertToTorch,
44
+ Identity,
45
+ RandomRoute,
46
+ SubsetToKeys,
47
+ )
48
+ from atomworks.ml.transforms.bfactor_conditioned_transforms import SetOccToZeroOnBfactor
49
+ from atomworks.ml.transforms.bonds import (
50
+ AddAF3TokenBondFeatures,
51
+ )
52
+ from atomworks.ml.transforms.cached_residue_data import (
53
+ LoadCachedResidueLevelData,
54
+ RandomSubsampleCachedConformers,
55
+ )
56
+ from atomworks.ml.transforms.center_random_augmentation import CenterRandomAugmentation
57
+ from atomworks.ml.transforms.chirals import AddAF3ChiralFeatures
58
+ from atomworks.ml.transforms.covalent_modifications import (
59
+ FlagAndReassignCovalentModifications,
60
+ )
61
+ from atomworks.ml.transforms.crop import CropContiguousLikeAF3, CropSpatialLikeAF3
62
+ from atomworks.ml.transforms.diffusion.batch_structures import (
63
+ BatchStructuresForDiffusionNoising,
64
+ )
65
+ from atomworks.ml.transforms.diffusion.edm import SampleEDMNoise
66
+ from atomworks.ml.transforms.encoding import (
67
+ EncodeAF3TokenLevelFeatures,
68
+ EncodeAtomArray,
69
+ )
70
+ from atomworks.ml.transforms.feature_aggregation.af3 import AggregateFeaturesLikeAF3
71
+ from atomworks.ml.transforms.feature_aggregation.confidence import (
72
+ PackageConfidenceFeats,
73
+ )
74
+ from atomworks.ml.transforms.featurize_unresolved_residues import (
75
+ MaskPolymerResiduesWithUnresolvedFrameAtoms,
76
+ PlaceUnresolvedTokenAtomsOnRepresentativeAtom,
77
+ PlaceUnresolvedTokenOnClosestResolvedTokenInSequence,
78
+ )
79
+ from atomworks.ml.transforms.filters import (
80
+ FilterToSpecifiedPNUnits,
81
+ HandleUndesiredResTokens,
82
+ RandomlyRemoveLigands,
83
+ RemoveHydrogens,
84
+ RemoveNucleicAcidTerminalOxygen,
85
+ RemovePolymersWithTooFewResolvedResidues,
86
+ RemoveTerminalOxygen,
87
+ RemoveUnresolvedPNUnits,
88
+ )
89
+ from atomworks.ml.transforms.mirror_transform import RandomlyMirrorInputs
90
+ from atomworks.ml.transforms.msa.msa import (
91
+ EncodeMSA,
92
+ FeaturizeMSALikeAF3,
93
+ FillFullMSAFromEncoded,
94
+ LoadPolymerMSAs,
95
+ PairAndMergePolymerMSAs,
96
+ )
97
+ from atomworks.ml.transforms.random_atomize_residues import RandomAtomizeResidues
98
+ from atomworks.ml.transforms.rdkit_utils import GetRDKitChiralCenters
99
+ from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
100
+ from omegaconf import DictConfig
101
+ from rf3.data.cyclic_transform import AddCyclicBonds
102
+ from rf3.data.extra_xforms import CheckForNaNsInInputs
103
+ from rf3.data.pipeline_utils import (
104
+ annotate_post_crop_hash,
105
+ annotate_pre_crop_hash,
106
+ build_ground_truth_distogram_transform,
107
+ set_to_occupancy_0_where_crop_hashes_differ,
108
+ )
109
+
110
+
111
def TrainingRoute(transform):
    """Wrap ``transform`` so that it only runs when the sample is not flagged as inference.

    The returned ``ConditionalRoute`` reads ``data["is_inference"]`` (the key must be
    present) and dispatches to ``Identity()`` when it is ``True`` and to ``transform``
    when it is ``False``.

    Args:
        transform: The transform to apply on the training (non-inference) branch.

    Returns:
        A ``ConditionalRoute`` wrapping ``transform``.
    """

    def _read_inference_flag(data):
        # Strict lookup on purpose: a missing flag surfaces as a KeyError
        return data["is_inference"]

    branch_by_flag = {True: Identity(), False: transform}
    return ConditionalRoute(
        condition_func=_read_inference_flag,
        transform_map=branch_by_flag,
    )
116
+
117
+
118
def InferenceRoute(transform):
    """Wrap ``transform`` so that it only runs when the sample is flagged as inference.

    The returned ``ConditionalRoute`` reads ``data["is_inference"]`` (the key must be
    present) and dispatches to ``Identity()`` when it is ``False`` and to ``transform``
    when it is ``True``.

    Args:
        transform: The transform to apply on the inference branch.

    Returns:
        A ``ConditionalRoute`` wrapping ``transform``.
    """

    def _read_inference_flag(data):
        # Strict lookup on purpose: a missing flag surfaces as a KeyError
        return data["is_inference"]

    branch_by_flag = {False: Identity(), True: transform}
    return ConditionalRoute(
        condition_func=_read_inference_flag,
        transform_map=branch_by_flag,
    )
123
+
124
+
125
def build_af3_transform_pipeline(
    *,
    # Training or inference (required)
    is_inference: bool,  # If True, we skip cropping, etc.
    # MSA dirs
    protein_msa_dirs: list[dict],
    rna_msa_dirs: list[dict],
    # Recycles
    n_recycles: int = 5,
    # Crop params
    crop_size: int | None = 384,  # None disables cropping
    crop_center_cutoff_distance: float = 15.0,
    crop_contiguous_probability: float = 0.5,
    crop_spatial_probability: float = 0.5,
    max_atoms_in_crop: int | None = None,
    # Undesired res names
    undesired_res_names: list[str] = AF3_EXCLUDED_LIGANDS,
    # Conformer generation params
    conformer_generation_timeout: float = 5.0,  # seconds
    use_element_for_atom_names_of_atomized_tokens: bool = False,
    # MSA parameters
    max_msa_sequences: int = 10_000,  # Paper: 16,000, but we only have 10K stored on disk
    n_msa: int = 10_000,  # Paper: ?? I think ~12K?
    dense_msa: bool = True,  # True for AF3
    add_residue_is_paired_feature: bool = False,
    # Cache paths
    msa_cache_dir: PathLike | str | None = None,
    residue_cache_dir: PathLike
    | str
    | None = "/net/tukwila/lschaaf/datahub/MACE-OMOL-Jul2025/mace_embeddings",
    # Diffusion parameters
    sigma_data: float = 16.0,
    diffusion_batch_size: int = 48,
    # Whether to include features for confidence head
    run_confidence_head: bool = False,
    return_atom_array: bool = True,
    # DNA
    pad_dna_p_skip: float = 0.0,
    b_factor_min: float | None = None,
    b_factor_max: float | None = None,
    # ------ Atom-level conditioning ------ #
    p_unconditional: float = 1.0,  # Show no conditioning, anywhere (i.e., unconditional)
    template_noise_scales: dict | DictConfig = {
        "atomized": 1e-5,  # No noise (for atomized tokens)
        "not_atomized": 0.2,  # Up to 0.2A of noise (for non-atomized tokens)
    },
    allowed_chain_types_for_conditioning: list[int | str | ChainType]
    | None = ChainType.get_all_types(),  # All chain types (None = no conditioning)
    p_condition_per_token: float = 0.0,  # When sampling with conditions, X% of tokens are conditioned (e.g., X^2% of pairs have conditions)
    p_provide_inter_molecule_distances: float = 0.0,  # When sampling with conditions, X% of the time, show any inter-molecule distances
    # (Reference Conformer)
    p_give_non_polymer_ref_conf: float = 0.0,  # When sampling with conditions, X% of non-polymer chains get a ground-truth reference conformer
    p_give_polymer_ref_conf: float = 0.0,  # When sampling with conditions, X% of polymer chains get a ground-truth reference conformer
    # -------------------------------------- #
    take_first_chiral_subordering: bool = False,
    mirror_prob: float = 0.0,
    input_contains_explicit_msa: bool = False,
    atomization_prob: float = 0.0,
    ligand_dropout_prob: float = 0.0,
    raise_if_missing_msa_for_protein_of_length_n: int | None = None,
    mask_crop_edges: bool = False,
    p_dropout_atom_level_embeddings: float = 0.0,
    embedding_dim: int = 384,
    n_conformers: int = 8,
    add_cyclic_bonds: bool = True,
    metrics_tags: list[str] | set[str] | None = None,
    p_dropout_ref_conf: float = 0.0,  # Unused
):
    """Build the AF3 pipeline with specified parameters.

    This function constructs a pipeline of transforms for processing protein structures
    in a manner similar to AlphaFold 3. The pipeline includes steps for removing hydrogens,
    adding annotations, atomizing residues, cropping, adding templates, encoding features,
    and generating reference molecule features.

    All arguments are keyword-only.

    Args:
        is_inference (bool): Stored in the data dict under ``"is_inference"``. When True, the
            training-only steps are skipped (unconditional/conditional routing, filtering to
            specified PN units, mirroring, B-factor occupancy masking, removal of unresolved
            PN units, random atomization, ligand dropout, cropping and crop-hash annotation),
            only ``"UNX"`` is handled as an undesired residue token, and token bond features
            are built with ``AddAF3TokenBondFeatures(np.inf)``. The crop-parameter validation
            is skipped as well.
        protein_msa_dirs (list[dict]): Forwarded to ``LoadPolymerMSAs``.
        rna_msa_dirs (list[dict]): Forwarded to ``LoadPolymerMSAs``.
        n_recycles (int, optional): Forwarded to ``FeaturizeMSALikeAF3``. Defaults to 5.
        crop_size (int | None, optional): The size of the crop. ``None`` disables cropping
            (an ``Identity`` transform is used instead). Defaults to 384.
        crop_center_cutoff_distance (float, optional): The cutoff distance for spatial cropping.
            Defaults to 15.0.
        crop_contiguous_probability (float, optional): The probability of using contiguous cropping.
            Defaults to 0.5.
        crop_spatial_probability (float, optional): The probability of using spatial cropping.
            Defaults to 0.5.
        max_atoms_in_crop (int | None, optional): Forwarded to both crop transforms.
            Defaults to None.
        undesired_res_names (list[str], optional): Residue tokens passed to
            ``HandleUndesiredResTokens`` on the training branch. Defaults to
            ``AF3_EXCLUDED_LIGANDS`` (the shared module-level constant, not a copy).
        conformer_generation_timeout (float, optional): The timeout for conformer generation in seconds.
            Defaults to 5.0.
        use_element_for_atom_names_of_atomized_tokens (bool, optional): Forwarded to
            ``GetAF3ReferenceMoleculeFeatures``. Defaults to False.
        max_msa_sequences (int, optional): Maximum number of sequences loaded by
            ``LoadPolymerMSAs``. Defaults to 10_000.
        n_msa (int, optional): Forwarded to ``FeaturizeMSALikeAF3``. Defaults to 10_000.
        dense_msa (bool, optional): Forwarded as ``dense`` to ``PairAndMergePolymerMSAs``.
            Defaults to True.
        add_residue_is_paired_feature (bool, optional): Forwarded to ``PairAndMergePolymerMSAs``
            and ``FillFullMSAFromEncoded``. Defaults to False.
        msa_cache_dir (PathLike | str | None, optional): Converted to ``Path`` (if given) and
            forwarded to ``LoadPolymerMSAs``. Defaults to None.
        residue_cache_dir (PathLike | str | None, optional): Converted to ``Path`` (if given) and
            forwarded as ``dir`` to ``LoadCachedResidueLevelData``. NOTE(review): the default is
            a site-specific absolute path; callers on other systems presumably need to override
            it or pass None.
        sigma_data (float, optional): Forwarded to ``SampleEDMNoise``. Defaults to 16.0.
        diffusion_batch_size (int, optional): Forwarded to ``BatchStructuresForDiffusionNoising``,
            ``CenterRandomAugmentation`` and ``SampleEDMNoise``. Defaults to 48.
        run_confidence_head (bool, optional): Stored in the data dict under
            ``"run_confidence_head"``. When True, the confidence-feature transforms run and
            ``"confidence_feats"`` is kept in the output. Defaults to False.
        return_atom_array (bool, optional): Whether ``"atom_array"`` is kept in the output.
            Defaults to True.
        pad_dna_p_skip (float, optional): Currently unused (the ``PadDNA`` step is commented out).
        b_factor_min (float | None, optional): Forwarded to ``SetOccToZeroOnBfactor`` (training only).
        b_factor_max (float | None, optional): Forwarded to ``SetOccToZeroOnBfactor`` (training only).
        p_unconditional (float, optional): During training, probability with which
            ``"is_unconditional"`` is set to True (otherwise False). Defaults to 1.0.
        template_noise_scales (dict | DictConfig, optional): Forwarded to
            ``build_ground_truth_distogram_transform``. NOTE: the default is a single dict
            object shared across calls.
        allowed_chain_types_for_conditioning (list | None, optional): Forwarded to
            ``build_ground_truth_distogram_transform``. Defaults to all chain types.
        p_condition_per_token (float, optional): Forwarded to
            ``build_ground_truth_distogram_transform``. Defaults to 0.0.
        p_provide_inter_molecule_distances (float, optional): Forwarded to
            ``build_ground_truth_distogram_transform``. Defaults to 0.0.
        p_give_non_polymer_ref_conf (float, optional): Probability assigned to the non-polymer
            chain types in ``RandomApplyGroundTruthConformerByChainType``. Defaults to 0.0.
        p_give_polymer_ref_conf (float, optional): Probability assigned to the polymer chain
            types in ``RandomApplyGroundTruthConformerByChainType``. Defaults to 0.0.
        take_first_chiral_subordering (bool, optional): Forwarded to ``AddAF3ChiralFeatures``.
            Defaults to False.
        mirror_prob (float, optional): Forwarded to ``RandomlyMirrorInputs`` (training only).
            Defaults to 0.0.
        input_contains_explicit_msa (bool, optional): Currently unused by this function.
        atomization_prob (float, optional): Forwarded to ``RandomAtomizeResidues`` (training only).
            Defaults to 0.0.
        ligand_dropout_prob (float, optional): Forwarded to ``RandomlyRemoveLigands``
            (training only). Defaults to 0.0.
        raise_if_missing_msa_for_protein_of_length_n (int | None, optional): Forwarded to
            ``LoadPolymerMSAs``. Defaults to None.
        mask_crop_edges (bool, optional): If True, adds a training-only step applying
            ``set_to_occupancy_0_where_crop_hashes_differ``. Defaults to False.
        p_dropout_atom_level_embeddings (float, optional): Forwarded to
            ``FeaturizeAtomLevelEmbeddings``. Defaults to 0.0.
        embedding_dim (int, optional): Forwarded to ``FeaturizeAtomLevelEmbeddings``.
            Defaults to 384.
        n_conformers (int, optional): Forwarded to ``RandomSubsampleCachedConformers`` and
            ``FeaturizeAtomLevelEmbeddings``. Defaults to 8.
        add_cyclic_bonds (bool, optional): If True, ``AddCyclicBonds`` is added after the token
            bond features. Defaults to True.
        metrics_tags (list[str] | set[str] | None, optional): Tags to use for determining which Metrics apply.
            Defaults to None (tags not added).
        p_dropout_ref_conf (float, optional): Unused.

    Returns:
        Transform: A composed pipeline of transforms.

    Raises:
        AssertionError: If the crop probabilities do not sum to 1.0, if the crop size is not positive,
            or if the crop center cutoff distance is not positive. These checks only run when
            building a training pipeline with at least one non-zero crop probability.

    Note:
        The cropping method is chosen randomly based on the provided probabilities.
        The pipeline includes steps for processing the structure, adding annotations,
        and generating features required for AF3-like predictions.

    References:
        - AlphaFold 3 Supplementary Information.
          https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
    """

    if (
        crop_contiguous_probability > 0 or crop_spatial_probability > 0
    ) and not is_inference:
        assert np.isclose(
            crop_contiguous_probability + crop_spatial_probability, 1.0, atol=1e-6
        ), "Crop probabilities must sum to 1.0"
        # `crop_size=None` is a supported way to disable cropping (see "Crop" below);
        # without the `is None` guard the comparison would raise a TypeError instead.
        assert crop_size is None or crop_size > 0, "Crop size must be greater than 0"
        assert (
            crop_center_cutoff_distance > 0
        ), "Crop center cutoff distance must be greater than 0"

    def _is_inference(data):
        # Lenient lookup used by the two-sided routes below (both branches do real work)
        return data.get("is_inference", False)

    af3_sequence_encoding = AF3SequenceEncoding()
    rf2aa_sequence_encoding = RF2AA_ATOM36_ENCODING

    # NOTE: the first transform always writes "is_inference" into the data dict, so the
    # strict `data["is_inference"]` lookup inside `TrainingRoute` is safe for every
    # training-only step that follows.
    transforms = [
        AddData(
            {"is_inference": is_inference, "run_confidence_head": run_confidence_head}
        ),
        # ... unconditional vs. conditional
        TrainingRoute(
            RandomRoute(
                transforms=[
                    AddData({"is_unconditional": True}),
                    AddData({"is_unconditional": False}),
                ],
                probs=[p_unconditional, 1 - p_unconditional],
            ),
        ),
        RemoveHydrogens(),
        TrainingRoute(
            FilterToSpecifiedPNUnits(
                extra_info_key_with_pn_unit_iids_to_keep="all_pn_unit_iids_after_processing"
            ),
        ),
    ]

    if exists(metrics_tags):
        transforms.append(AddData({"metrics_tags": metrics_tags}))

    transforms.append(TrainingRoute(RandomlyMirrorInputs(mirror_prob)))

    transforms += [
        RemoveTerminalOxygen(),
        TrainingRoute(
            SetOccToZeroOnBfactor(b_factor_min, b_factor_max),
        ),
        TrainingRoute(RemoveUnresolvedPNUnits()),
        RemovePolymersWithTooFewResolvedResidues(min_residues=4),
        MaskPolymerResiduesWithUnresolvedFrameAtoms(),
        ConditionalRoute(
            condition_func=_is_inference,
            transform_map={
                # UNX causes RDKit to crash (element is "X"), so we exclude even at inference
                True: HandleUndesiredResTokens(undesired_res_tokens=["UNX"]),
                False: HandleUndesiredResTokens(
                    undesired_res_tokens=undesired_res_names
                ),
            },
        ),
        # NOTE: this is used in training to pad DNA sequences, but we don't use it in inference
        # TrainingRoute(
        #     PadDNA(p_skip=pad_dna_p_skip),
        # ),
        FlagAndReassignCovalentModifications(),
        FlagNonPolymersForAtomization(),
    ]

    # Training-only augmentations: random atomization and ligand dropout
    transforms.append(TrainingRoute(RandomAtomizeResidues(atomization_prob)))
    transforms.append(TrainingRoute(RandomlyRemoveLigands(ligand_dropout_prob)))

    transforms += [
        AddGlobalAtomIdAnnotation(),
        AtomizeByCCDName(
            atomize_by_default=True,
            res_names_to_ignore=STANDARD_AA + STANDARD_RNA + STANDARD_DNA,
            move_atomized_part_to_end=False,
            validate_atomize=False,
        ),
        RemoveNucleicAcidTerminalOxygen(),
        AddWithinChainInstanceResIdx(),
        AddWithinPolyResIdxAnnotation(),
    ]

    # Crop

    # ... crop around our query pn_unit(s) early, since we don't need the full structure moving forward
    cropping_transform = Identity()
    if crop_size is not None:
        # NOTE(review): if both crop probabilities are 0 the validation above is skipped and
        # `probs=[0, 0]` is handed to RandomRoute as-is — confirm that this is intended.
        cropping_transform = RandomRoute(
            transforms=[
                CropContiguousLikeAF3(
                    crop_size=crop_size,
                    keep_uncropped_atom_array=True,
                    max_atoms_in_crop=max_atoms_in_crop,
                ),
                CropSpatialLikeAF3(
                    crop_size=crop_size,
                    crop_center_cutoff_distance=crop_center_cutoff_distance,
                    keep_uncropped_atom_array=True,
                    max_atoms_in_crop=max_atoms_in_crop,
                ),
            ],
            probs=[crop_contiguous_probability, crop_spatial_probability],
        )

    transforms += [
        TrainingRoute(ApplyFunction(annotate_pre_crop_hash)),
        # Cropping is skipped (Identity) during inference (`is_inference == True`)
        TrainingRoute(cropping_transform),
        TrainingRoute(ApplyFunction(annotate_post_crop_hash)),
    ]

    if mask_crop_edges:
        transforms += [
            TrainingRoute(ApplyFunction(set_to_occupancy_0_where_crop_hashes_differ)),
        ]

    # +-----------------------------------------------------------+
    # +------------------ GROUND TRUTH TEMPLATE ------------------+
    # +-----------------------------------------------------------+

    # Ground truth template noising (for training)
    transforms.append(
        build_ground_truth_distogram_transform(
            template_noise_scales=template_noise_scales,
            allowed_chain_types_for_conditioning=allowed_chain_types_for_conditioning,
            p_condition_per_token=p_condition_per_token,
            p_provide_inter_molecule_distances=p_provide_inter_molecule_distances,
            is_inference=is_inference,
        )
    )

    # +----------------------------------------------------------------------+
    # +------------------ GROUND TRUTH REFERENCE CONFORMER ------------------+
    # +----------------------------------------------------------------------+

    transforms.append(
        RandomApplyGroundTruthConformerByChainType(
            chain_type_probabilities={
                tuple(ChainType.get_polymers()): p_give_polymer_ref_conf,
                tuple(ChainType.get_non_polymers()): p_give_non_polymer_ref_conf,
            },
            policy=GroundTruthConformerPolicy.ADD,
        )
    )

    transforms += [
        AddGlobalTokenIdAnnotation(),  # required for reference molecule features and TokenToAtomMap
        AddGlobalResIdAnnotation(),
        LoadCachedResidueLevelData(
            dir=Path(residue_cache_dir) if exists(residue_cache_dir) else None,
            sharding_depth=1,
        ),
        RandomSubsampleCachedConformers(n_conformers=n_conformers),
        EncodeAF3TokenLevelFeatures(sequence_encoding=af3_sequence_encoding),
        GetAF3ReferenceMoleculeFeatures(
            conformer_generation_timeout=conformer_generation_timeout,
            use_element_for_atom_names_of_atomized_tokens=use_element_for_atom_names_of_atomized_tokens,
        ),
        FeaturizeAtomLevelEmbeddings(
            mask_rdkit_conformers=False,
            p_dropout_atom_level_embeddings=p_dropout_atom_level_embeddings,
            embedding_dim=embedding_dim,
            n_conformers=n_conformers,
        ),
        FindAutomorphismsWithNetworkX(),  # Adds the "automorphisms" key to the data dictionary
        ComputeAtomToTokenMap(),
        GetRDKitChiralCenters(),
        AddAF3ChiralFeatures(
            take_first_chiral_subordering=take_first_chiral_subordering
        ),
    ]

    transforms += [
        # ... load and pair MSAs
        LoadPolymerMSAs(
            protein_msa_dirs=protein_msa_dirs,
            rna_msa_dirs=rna_msa_dirs,
            max_msa_sequences=max_msa_sequences,  # maximum number of sequences to load (we later subsample further)
            msa_cache_dir=Path(msa_cache_dir) if exists(msa_cache_dir) else None,
            use_paths_in_chain_info=True,  # if there are paths specified in the `chain_info` for a given chain, use them
            raise_if_missing_msa_for_protein_of_length_n=raise_if_missing_msa_for_protein_of_length_n,
        ),
        PairAndMergePolymerMSAs(
            dense=dense_msa, add_residue_is_paired_feature=add_residue_is_paired_feature
        ),
    ]

    transforms += [
        # ... encode MSA to AF-3 format
        EncodeMSA(
            encoding=af3_sequence_encoding,
            token_to_use_for_gap=af3_sequence_encoding.token_to_idx["<G>"],
        ),
        # ... fill MSA, indexing into only the portions of the polymers that are present in the cropped structure
        FillFullMSAFromEncoded(
            pad_token=af3_sequence_encoding.token_to_idx["<G>"],
            add_residue_is_paired_feature=add_residue_is_paired_feature,
        ),
        ConditionalRoute(
            condition_func=_is_inference,
            transform_map={
                True: AddAF3TokenBondFeatures(np.inf),
                False: AddAF3TokenBondFeatures(),
            },
        ),
    ]

    if add_cyclic_bonds:
        transforms += [
            AddCyclicBonds(),
        ]

    transforms += [
        # ... featurize MSA
        ConvertToTorch(
            keys=[
                "encoded",
                "feats",
                "full_msa_details",
            ]
        ),
        FeaturizeMSALikeAF3(
            encoding=af3_sequence_encoding,
            n_recycles=n_recycles,
            n_msa=n_msa,
        ),
        # Prepare coordinates for noising (without modifying the ground truth)
        # ... add placeholder coordinates for noising
        CopyAnnotation(annotation_to_copy="coord", new_annotation="coord_to_be_noised"),
        # ... handling of unresolved residues (note that these Transforms create the "atom_array_to_noise" dictionary, if not already present)
        PlaceUnresolvedTokenAtomsOnRepresentativeAtom(
            annotation_to_update="coord_to_be_noised"
        ),
        PlaceUnresolvedTokenOnClosestResolvedTokenInSequence(
            annotation_to_update="coord_to_be_noised",
            annotation_to_copy="coord_to_be_noised",
        ),
        # Feature aggregation
        AggregateFeaturesLikeAF3(),
        # ... batching and noise sampling for diffusion
        BatchStructuresForDiffusionNoising(batch_size=diffusion_batch_size),
        CenterRandomAugmentation(batch_size=diffusion_batch_size),
        SampleEDMNoise(
            sigma_data=sigma_data, diffusion_batch_size=diffusion_batch_size
        ),
        CheckForNaNsInInputs(),
    ]

    confidence_transforms = Compose(
        [
            # Additions required for confidence calculation
            EncodeAtomArray(rf2aa_sequence_encoding),
            AddAtomFrames(),
            AddIsRealAtom(rf2aa_sequence_encoding),
            AddPolymerFrameIndices(),
            # wrap it all together
            PackageConfidenceFeats(),
        ]
    )

    transforms.append(
        ConditionalRoute(
            condition_func=lambda data: data.get("run_confidence_head", False),
            transform_map={
                True: confidence_transforms,
                False: Identity(),
            },
        )
    )

    keys_to_keep = [
        "example_id",
        "feats",
        "t",
        "noise",
        "ground_truth",
        "coord_atom_lvl_to_be_noised",
        "automorphisms",
        "symmetry_resolution",
        "extra_info",
    ]

    if run_confidence_head:
        keys_to_keep.append("confidence_feats")

    if return_atom_array:  # and is_inference:
        keys_to_keep.append("atom_array")

    transforms += [
        # Subset to only keys necessary
        SubsetToKeys(keys_to_keep)
    ]

    # ... compose final pipeline
    pipeline = Compose(transforms)

    return pipeline