rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/data/pipelines.py
ADDED
|
@@ -0,0 +1,558 @@
|
|
|
1
|
+
from os import PathLike
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from atomworks.common import exists
|
|
6
|
+
from atomworks.constants import (
|
|
7
|
+
AF3_EXCLUDED_LIGANDS,
|
|
8
|
+
STANDARD_AA,
|
|
9
|
+
STANDARD_DNA,
|
|
10
|
+
STANDARD_RNA,
|
|
11
|
+
)
|
|
12
|
+
from atomworks.enums import ChainType
|
|
13
|
+
from atomworks.ml.encoding_definitions import RF2AA_ATOM36_ENCODING, AF3SequenceEncoding
|
|
14
|
+
from atomworks.ml.transforms.af3_reference_molecule import (
|
|
15
|
+
GetAF3ReferenceMoleculeFeatures,
|
|
16
|
+
GroundTruthConformerPolicy,
|
|
17
|
+
RandomApplyGroundTruthConformerByChainType,
|
|
18
|
+
)
|
|
19
|
+
from atomworks.ml.transforms.atom_array import (
|
|
20
|
+
AddGlobalAtomIdAnnotation,
|
|
21
|
+
AddGlobalResIdAnnotation,
|
|
22
|
+
AddGlobalTokenIdAnnotation,
|
|
23
|
+
AddWithinChainInstanceResIdx,
|
|
24
|
+
AddWithinPolyResIdxAnnotation,
|
|
25
|
+
ComputeAtomToTokenMap,
|
|
26
|
+
CopyAnnotation,
|
|
27
|
+
)
|
|
28
|
+
from atomworks.ml.transforms.atom_frames import (
|
|
29
|
+
AddAtomFrames,
|
|
30
|
+
AddIsRealAtom,
|
|
31
|
+
AddPolymerFrameIndices,
|
|
32
|
+
)
|
|
33
|
+
from atomworks.ml.transforms.atom_level_embeddings import FeaturizeAtomLevelEmbeddings
|
|
34
|
+
from atomworks.ml.transforms.atomize import (
|
|
35
|
+
AtomizeByCCDName,
|
|
36
|
+
FlagNonPolymersForAtomization,
|
|
37
|
+
)
|
|
38
|
+
from atomworks.ml.transforms.base import (
|
|
39
|
+
AddData,
|
|
40
|
+
ApplyFunction,
|
|
41
|
+
Compose,
|
|
42
|
+
ConditionalRoute,
|
|
43
|
+
ConvertToTorch,
|
|
44
|
+
Identity,
|
|
45
|
+
RandomRoute,
|
|
46
|
+
SubsetToKeys,
|
|
47
|
+
)
|
|
48
|
+
from atomworks.ml.transforms.bfactor_conditioned_transforms import SetOccToZeroOnBfactor
|
|
49
|
+
from atomworks.ml.transforms.bonds import (
|
|
50
|
+
AddAF3TokenBondFeatures,
|
|
51
|
+
)
|
|
52
|
+
from atomworks.ml.transforms.cached_residue_data import (
|
|
53
|
+
LoadCachedResidueLevelData,
|
|
54
|
+
RandomSubsampleCachedConformers,
|
|
55
|
+
)
|
|
56
|
+
from atomworks.ml.transforms.center_random_augmentation import CenterRandomAugmentation
|
|
57
|
+
from atomworks.ml.transforms.chirals import AddAF3ChiralFeatures
|
|
58
|
+
from atomworks.ml.transforms.covalent_modifications import (
|
|
59
|
+
FlagAndReassignCovalentModifications,
|
|
60
|
+
)
|
|
61
|
+
from atomworks.ml.transforms.crop import CropContiguousLikeAF3, CropSpatialLikeAF3
|
|
62
|
+
from atomworks.ml.transforms.diffusion.batch_structures import (
|
|
63
|
+
BatchStructuresForDiffusionNoising,
|
|
64
|
+
)
|
|
65
|
+
from atomworks.ml.transforms.diffusion.edm import SampleEDMNoise
|
|
66
|
+
from atomworks.ml.transforms.encoding import (
|
|
67
|
+
EncodeAF3TokenLevelFeatures,
|
|
68
|
+
EncodeAtomArray,
|
|
69
|
+
)
|
|
70
|
+
from atomworks.ml.transforms.feature_aggregation.af3 import AggregateFeaturesLikeAF3
|
|
71
|
+
from atomworks.ml.transforms.feature_aggregation.confidence import (
|
|
72
|
+
PackageConfidenceFeats,
|
|
73
|
+
)
|
|
74
|
+
from atomworks.ml.transforms.featurize_unresolved_residues import (
|
|
75
|
+
MaskPolymerResiduesWithUnresolvedFrameAtoms,
|
|
76
|
+
PlaceUnresolvedTokenAtomsOnRepresentativeAtom,
|
|
77
|
+
PlaceUnresolvedTokenOnClosestResolvedTokenInSequence,
|
|
78
|
+
)
|
|
79
|
+
from atomworks.ml.transforms.filters import (
|
|
80
|
+
FilterToSpecifiedPNUnits,
|
|
81
|
+
HandleUndesiredResTokens,
|
|
82
|
+
RandomlyRemoveLigands,
|
|
83
|
+
RemoveHydrogens,
|
|
84
|
+
RemoveNucleicAcidTerminalOxygen,
|
|
85
|
+
RemovePolymersWithTooFewResolvedResidues,
|
|
86
|
+
RemoveTerminalOxygen,
|
|
87
|
+
RemoveUnresolvedPNUnits,
|
|
88
|
+
)
|
|
89
|
+
from atomworks.ml.transforms.mirror_transform import RandomlyMirrorInputs
|
|
90
|
+
from atomworks.ml.transforms.msa.msa import (
|
|
91
|
+
EncodeMSA,
|
|
92
|
+
FeaturizeMSALikeAF3,
|
|
93
|
+
FillFullMSAFromEncoded,
|
|
94
|
+
LoadPolymerMSAs,
|
|
95
|
+
PairAndMergePolymerMSAs,
|
|
96
|
+
)
|
|
97
|
+
from atomworks.ml.transforms.random_atomize_residues import RandomAtomizeResidues
|
|
98
|
+
from atomworks.ml.transforms.rdkit_utils import GetRDKitChiralCenters
|
|
99
|
+
from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
|
|
100
|
+
from omegaconf import DictConfig
|
|
101
|
+
from rf3.data.cyclic_transform import AddCyclicBonds
|
|
102
|
+
from rf3.data.extra_xforms import CheckForNaNsInInputs
|
|
103
|
+
from rf3.data.pipeline_utils import (
|
|
104
|
+
annotate_post_crop_hash,
|
|
105
|
+
annotate_pre_crop_hash,
|
|
106
|
+
build_ground_truth_distogram_transform,
|
|
107
|
+
set_to_occupancy_0_where_crop_hashes_differ,
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def TrainingRoute(transform):
    """Wrap ``transform`` so it is applied only to training examples.

    The returned ``ConditionalRoute`` reads ``data["is_inference"]`` (indexed
    directly, with no default). When that flag is ``False`` the wrapped
    ``transform`` runs; when it is ``True`` the example goes through
    ``Identity()`` unchanged.

    Args:
        transform: The transform to run during training only.

    Returns:
        ConditionalRoute: A route keyed on the example's ``is_inference`` flag.
    """

    def _is_inference(data):
        return data["is_inference"]

    routes = {True: Identity(), False: transform}
    return ConditionalRoute(condition_func=_is_inference, transform_map=routes)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def InferenceRoute(transform):
    """Wrap ``transform`` so it is applied only to inference examples.

    This is the mirror image of ``TrainingRoute``. The returned
    ``ConditionalRoute`` reads ``data["is_inference"]`` (indexed directly, with
    no default). When that flag is ``True`` the wrapped ``transform`` runs; when
    it is ``False`` the example goes through ``Identity()`` unchanged.

    Args:
        transform: The transform to run during inference only.

    Returns:
        ConditionalRoute: A route keyed on the example's ``is_inference`` flag.
    """

    def _is_inference(data):
        return data["is_inference"]

    routes = {False: Identity(), True: transform}
    return ConditionalRoute(condition_func=_is_inference, transform_map=routes)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def build_af3_transform_pipeline(
    *,
    # Training or inference (required)
    is_inference: bool,  # If True, we skip cropping, etc.
    # MSA dirs
    protein_msa_dirs: list[dict],
    rna_msa_dirs: list[dict],
    # Recycles
    n_recycles: int = 5,
    # Crop params
    crop_size: int | None = 384,  # None disables cropping entirely
    crop_center_cutoff_distance: float = 15.0,
    crop_contiguous_probability: float = 0.5,
    crop_spatial_probability: float = 0.5,
    max_atoms_in_crop: int | None = None,
    # Undesired res names
    undesired_res_names: list[str] = AF3_EXCLUDED_LIGANDS,
    # Conformer generation params
    conformer_generation_timeout: float = 5.0,  # seconds
    use_element_for_atom_names_of_atomized_tokens: bool = False,
    # MSA parameters
    max_msa_sequences: int = 10_000,  # Paper: 16,000, but we only have 10K stored on disk
    n_msa: int = 10_000,  # AF3 paper value not confirmed here
    dense_msa: bool = True,  # True for AF3
    add_residue_is_paired_feature: bool = False,
    # Cache paths
    msa_cache_dir: PathLike | str | None = None,
    # NOTE(review): this default looks like a site-specific network path that is
    # unlikely to exist on other machines; pass your own directory (or None).
    # Kept unchanged for backward compatibility.
    residue_cache_dir: PathLike
    | str
    | None = "/net/tukwila/lschaaf/datahub/MACE-OMOL-Jul2025/mace_embeddings",
    # Diffusion parameters
    sigma_data: float = 16.0,
    diffusion_batch_size: int = 48,
    # Whether to include features for confidence head
    run_confidence_head: bool = False,
    return_atom_array: bool = True,
    # DNA
    pad_dna_p_skip: float = 0.0,  # Unused (the PadDNA step is currently disabled)
    b_factor_min: float | None = None,
    b_factor_max: float | None = None,
    # ------ Atom-level conditioning ------ #
    p_unconditional: float = 1.0,  # Show no conditioning, anywhere (i.e., unconditional)
    # NOTE: this dict default is created once and shared across calls; it is only
    # read (passed through) here, so do not mutate it downstream.
    template_noise_scales: dict | DictConfig = {
        "atomized": 1e-5,  # No noise (for atomized tokens)
        "not_atomized": 0.2,  # Up to 0.2A of noise (for non-atomized tokens)
    },
    allowed_chain_types_for_conditioning: list[int | str | ChainType]
    | None = ChainType.get_all_types(),  # All chain types (None = no conditioning)
    p_condition_per_token: float = 0.0,  # When sampling with conditions, X% of tokens are conditioned (e.g., X^2% of pairs have conditions)
    p_provide_inter_molecule_distances: float = 0.0,  # When sampling with conditions, X% of the time, show any inter-molecule distances
    # (Reference Conformer)
    p_give_non_polymer_ref_conf: float = 0.0,  # When sampling with conditions, X% of non-polymer chains get a ground-truth reference conformer
    p_give_polymer_ref_conf: float = 0.0,  # When sampling with conditions, X% of polymer chains get a ground-truth reference conformer
    # -------------------------------------- #
    take_first_chiral_subordering: bool = False,
    mirror_prob: float = 0.0,
    input_contains_explicit_msa: bool = False,  # Not referenced in this function
    atomization_prob: float = 0.0,
    ligand_dropout_prob: float = 0.0,
    raise_if_missing_msa_for_protein_of_length_n: int | None = None,
    mask_crop_edges: bool = False,
    p_dropout_atom_level_embeddings: float = 0.0,
    embedding_dim: int = 384,
    n_conformers: int = 8,
    add_cyclic_bonds: bool = True,
    metrics_tags: list[str] | set[str] | None = None,
    p_dropout_ref_conf: float = 0.0,  # Unused
):
    """Build the AF3 pipeline with specified parameters.

    This function constructs a pipeline of transforms for processing biomolecular
    structures in a manner similar to AlphaFold 3. The pipeline includes steps for
    removing hydrogens, adding annotations, atomizing residues, cropping, adding
    ground-truth templates/conformers, loading and featurizing MSAs, encoding
    features, generating reference molecule features, and sampling diffusion noise.

    Training-only steps (mirroring, random atomization, ligand dropout, cropping,
    B-factor masking, ...) are wrapped in ``TrainingRoute``. These steps key off
    the ``is_inference`` flag that the first transform writes into each example.

    Args:
        is_inference (bool): If True, build the inference variant; training-only
            steps such as cropping are skipped.
        protein_msa_dirs (list[dict]): Protein MSA directories, forwarded to
            ``LoadPolymerMSAs``.
        rna_msa_dirs (list[dict]): RNA MSA directories, forwarded to
            ``LoadPolymerMSAs``.
        n_recycles (int, optional): Number of recycles, forwarded to
            ``FeaturizeMSALikeAF3``. Defaults to 5.
        crop_size (int | None, optional): The size of the crop. If None, no
            cropping is applied. Defaults to 384.
        crop_center_cutoff_distance (float, optional): The cutoff distance for
            spatial cropping. Defaults to 15.0.
        crop_contiguous_probability (float, optional): The probability of using
            contiguous cropping. Defaults to 0.5.
        crop_spatial_probability (float, optional): The probability of using
            spatial cropping. Defaults to 0.5.
        max_atoms_in_crop (int | None, optional): Forwarded to both crop
            transforms. Defaults to None.
        undesired_res_names (list[str], optional): Residue names handled by
            ``HandleUndesiredResTokens`` during training. At inference only
            ``"UNX"`` is handled. Defaults to ``AF3_EXCLUDED_LIGANDS``.
        conformer_generation_timeout (float, optional): The timeout for conformer
            generation in seconds. Defaults to 5.0.
        use_element_for_atom_names_of_atomized_tokens (bool, optional): Forwarded
            to ``GetAF3ReferenceMoleculeFeatures``. Defaults to False.
        max_msa_sequences (int, optional): Maximum number of MSA sequences to load
            (subsampled further later). Defaults to 10_000.
        n_msa (int, optional): Forwarded to ``FeaturizeMSALikeAF3``.
            Defaults to 10_000.
        dense_msa (bool, optional): Forwarded to ``PairAndMergePolymerMSAs`` as
            ``dense``. Defaults to True.
        add_residue_is_paired_feature (bool, optional): Forwarded to the MSA
            pairing and filling transforms. Defaults to False.
        msa_cache_dir (PathLike | str | None, optional): MSA cache directory, or
            None for no cache. Defaults to None.
        residue_cache_dir (PathLike | str | None, optional): Directory passed to
            ``LoadCachedResidueLevelData`` (None disables it).
        sigma_data (float, optional): Forwarded to ``SampleEDMNoise``.
            Defaults to 16.0.
        diffusion_batch_size (int, optional): Number of structure copies batched
            for diffusion noising. Defaults to 48.
        run_confidence_head (bool, optional): If True, also build confidence-head
            features and keep ``"confidence_feats"``. Defaults to False.
        return_atom_array (bool, optional): If True, keep ``"atom_array"`` in the
            output. Defaults to True.
        pad_dna_p_skip (float, optional): Currently unused. Defaults to 0.0.
        b_factor_min (float | None, optional): Forwarded to
            ``SetOccToZeroOnBfactor`` (training only). Defaults to None.
        b_factor_max (float | None, optional): Forwarded to
            ``SetOccToZeroOnBfactor`` (training only). Defaults to None.
        p_unconditional (float, optional): During training, the probability an
            example is marked ``is_unconditional=True``. Defaults to 1.0.
        template_noise_scales, allowed_chain_types_for_conditioning,
            p_condition_per_token, p_provide_inter_molecule_distances: Forwarded
            to ``build_ground_truth_distogram_transform``.
        p_give_non_polymer_ref_conf, p_give_polymer_ref_conf (float, optional):
            Per-chain-type probabilities for
            ``RandomApplyGroundTruthConformerByChainType``. Both default to 0.0.
        take_first_chiral_subordering (bool, optional): Forwarded to
            ``AddAF3ChiralFeatures``. Defaults to False.
        mirror_prob (float, optional): Forwarded to ``RandomlyMirrorInputs``
            (training only). Defaults to 0.0.
        input_contains_explicit_msa (bool, optional): Not referenced by this
            function. Defaults to False.
        atomization_prob (float, optional): Forwarded to
            ``RandomAtomizeResidues`` (training only). Defaults to 0.0.
        ligand_dropout_prob (float, optional): Forwarded to
            ``RandomlyRemoveLigands`` (training only). Defaults to 0.0.
        raise_if_missing_msa_for_protein_of_length_n (int | None, optional):
            Forwarded to ``LoadPolymerMSAs``. Defaults to None.
        mask_crop_edges (bool, optional): If True, add
            ``set_to_occupancy_0_where_crop_hashes_differ`` (training only).
            Defaults to False.
        p_dropout_atom_level_embeddings, embedding_dim, n_conformers: Forwarded to
            ``FeaturizeAtomLevelEmbeddings``. ``n_conformers`` is also used by
            ``RandomSubsampleCachedConformers``.
        add_cyclic_bonds (bool, optional): If True, append ``AddCyclicBonds``.
            Defaults to True.
        metrics_tags (list[str] | set[str] | None, optional): Tags to use for
            determining which Metrics apply. Defaults to None (tags not added).
        p_dropout_ref_conf (float, optional): Currently unused. Defaults to 0.0.

    Returns:
        Transform: A composed pipeline of transforms.

    Raises:
        AssertionError: When building a training pipeline with any nonzero crop
            probability, if the crop probabilities do not sum to 1.0, if a
            non-None crop size is not positive, or if the crop center cutoff
            distance is not positive. These checks are explicit raises, so they
            still run under ``python -O``.

    Note:
        The cropping method is chosen randomly based on the provided probabilities.

    References:
        - AlphaFold 3 Supplementary Information.
          https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
    """
    # Validate crop parameters. These are explicit raises rather than `assert`
    # statements so the checks survive `python -O`; AssertionError is kept so
    # existing callers that catch it keep working.
    if (
        crop_contiguous_probability > 0 or crop_spatial_probability > 0
    ) and not is_inference:
        if not np.isclose(
            crop_contiguous_probability + crop_spatial_probability, 1.0, atol=1e-6
        ):
            raise AssertionError("Crop probabilities must sum to 1.0")
        # `crop_size=None` means "no cropping" (see below), so only a concrete
        # size is validated; comparing None with 0 would raise TypeError.
        if crop_size is not None and not crop_size > 0:
            raise AssertionError("Crop size must be greater than 0")
        if not crop_center_cutoff_distance > 0:
            raise AssertionError("Crop center cutoff distance must be greater than 0")

    af3_sequence_encoding = AF3SequenceEncoding()
    rf2aa_sequence_encoding = RF2AA_ATOM36_ENCODING

    transforms = [
        # Must come first: every TrainingRoute below reads `is_inference` from here.
        AddData(
            {"is_inference": is_inference, "run_confidence_head": run_confidence_head}
        ),
        # ... unconditional vs. conditional
        TrainingRoute(
            RandomRoute(
                transforms=[
                    AddData({"is_unconditional": True}),
                    AddData({"is_unconditional": False}),
                ],
                probs=[p_unconditional, 1 - p_unconditional],
            ),
        ),
        RemoveHydrogens(),
        TrainingRoute(
            FilterToSpecifiedPNUnits(
                extra_info_key_with_pn_unit_iids_to_keep="all_pn_unit_iids_after_processing"
            ),
        ),
    ]

    if exists(metrics_tags):
        transforms.append(AddData({"metrics_tags": metrics_tags}))

    transforms.append(TrainingRoute(RandomlyMirrorInputs(mirror_prob)))

    transforms += [
        RemoveTerminalOxygen(),
        TrainingRoute(
            SetOccToZeroOnBfactor(b_factor_min, b_factor_max),
        ),
        TrainingRoute(RemoveUnresolvedPNUnits()),
        RemovePolymersWithTooFewResolvedResidues(min_residues=4),
        MaskPolymerResiduesWithUnresolvedFrameAtoms(),
        ConditionalRoute(
            condition_func=lambda data: data.get("is_inference", False),
            transform_map={
                # UNX causes RDKit to crash (element is "X"), so we exclude even at inference
                True: HandleUndesiredResTokens(undesired_res_tokens=["UNX"]),
                False: HandleUndesiredResTokens(
                    undesired_res_tokens=undesired_res_names
                ),
            },
        ),
        # NOTE: DNA padding (PadDNA with `pad_dna_p_skip`) is used in some training
        # setups but is currently disabled here, so `pad_dna_p_skip` has no effect.
        FlagAndReassignCovalentModifications(),
        FlagNonPolymersForAtomization(),
    ]

    transforms.append(TrainingRoute(RandomAtomizeResidues(atomization_prob)))
    transforms.append(TrainingRoute(RandomlyRemoveLigands(ligand_dropout_prob)))

    transforms += [
        AddGlobalAtomIdAnnotation(),
        AtomizeByCCDName(
            atomize_by_default=True,
            res_names_to_ignore=STANDARD_AA + STANDARD_RNA + STANDARD_DNA,
            move_atomized_part_to_end=False,
            validate_atomize=False,
        ),
        RemoveNucleicAcidTerminalOxygen(),
        AddWithinChainInstanceResIdx(),
        AddWithinPolyResIdxAnnotation(),
    ]

    # Crop

    # ... crop around our query pn_unit(s) early, since we don't need the full structure moving forward
    cropping_transform = Identity()
    if crop_size is not None:
        cropping_transform = RandomRoute(
            transforms=[
                CropContiguousLikeAF3(
                    crop_size=crop_size,
                    keep_uncropped_atom_array=True,
                    max_atoms_in_crop=max_atoms_in_crop,
                ),
                CropSpatialLikeAF3(
                    crop_size=crop_size,
                    crop_center_cutoff_distance=crop_center_cutoff_distance,
                    keep_uncropped_atom_array=True,
                    max_atoms_in_crop=max_atoms_in_crop,
                ),
            ],
            probs=[crop_contiguous_probability, crop_spatial_probability],
        )

    transforms += [
        TrainingRoute(ApplyFunction(annotate_pre_crop_hash)),
        # Cropping only happens during training; inference passes through Identity.
        TrainingRoute(cropping_transform),
        TrainingRoute(ApplyFunction(annotate_post_crop_hash)),
    ]

    if mask_crop_edges:
        transforms += [
            TrainingRoute(ApplyFunction(set_to_occupancy_0_where_crop_hashes_differ)),
        ]

    # +-----------------------------------------------------------+
    # +------------------ GROUND TRUTH TEMPLATE ------------------+
    # +-----------------------------------------------------------+

    # Ground truth template noising (for training)
    transforms.append(
        build_ground_truth_distogram_transform(
            template_noise_scales=template_noise_scales,
            allowed_chain_types_for_conditioning=allowed_chain_types_for_conditioning,
            p_condition_per_token=p_condition_per_token,
            p_provide_inter_molecule_distances=p_provide_inter_molecule_distances,
            is_inference=is_inference,
        )
    )

    # +----------------------------------------------------------------------+
    # +------------------ GROUND TRUTH REFERENCE CONFORMER ------------------+
    # +----------------------------------------------------------------------+

    transforms.append(
        RandomApplyGroundTruthConformerByChainType(
            chain_type_probabilities={
                tuple(ChainType.get_polymers()): p_give_polymer_ref_conf,
                tuple(ChainType.get_non_polymers()): p_give_non_polymer_ref_conf,
            },
            policy=GroundTruthConformerPolicy.ADD,
        )
    )

    transforms += [
        AddGlobalTokenIdAnnotation(),  # required for reference molecule features and TokenToAtomMap
        AddGlobalResIdAnnotation(),
        LoadCachedResidueLevelData(
            dir=Path(residue_cache_dir) if exists(residue_cache_dir) else None,
            sharding_depth=1,
        ),
        RandomSubsampleCachedConformers(n_conformers=n_conformers),
        EncodeAF3TokenLevelFeatures(sequence_encoding=af3_sequence_encoding),
        GetAF3ReferenceMoleculeFeatures(
            conformer_generation_timeout=conformer_generation_timeout,
            use_element_for_atom_names_of_atomized_tokens=use_element_for_atom_names_of_atomized_tokens,
        ),
        FeaturizeAtomLevelEmbeddings(
            mask_rdkit_conformers=False,
            p_dropout_atom_level_embeddings=p_dropout_atom_level_embeddings,
            embedding_dim=embedding_dim,
            n_conformers=n_conformers,
        ),
        FindAutomorphismsWithNetworkX(),  # Adds the "automorphisms" key to the data dictionary
        ComputeAtomToTokenMap(),
        GetRDKitChiralCenters(),
        AddAF3ChiralFeatures(
            take_first_chiral_subordering=take_first_chiral_subordering
        ),
    ]

    transforms += [
        # ... load and pair MSAs
        LoadPolymerMSAs(
            protein_msa_dirs=protein_msa_dirs,
            rna_msa_dirs=rna_msa_dirs,
            max_msa_sequences=max_msa_sequences,  # maximum number of sequences to load (we later subsample further)
            msa_cache_dir=Path(msa_cache_dir) if exists(msa_cache_dir) else None,
            use_paths_in_chain_info=True,  # if there are paths specified in the `chain_info` for a given chain, use them
            raise_if_missing_msa_for_protein_of_length_n=raise_if_missing_msa_for_protein_of_length_n,
        ),
        PairAndMergePolymerMSAs(
            dense=dense_msa, add_residue_is_paired_feature=add_residue_is_paired_feature
        ),
    ]

    transforms += [
        # ... encode MSA to AF-3 format
        EncodeMSA(
            encoding=af3_sequence_encoding,
            token_to_use_for_gap=af3_sequence_encoding.token_to_idx["<G>"],
        ),
        # ... fill MSA, indexing into only the portions of the polymers that are present in the cropped structure
        FillFullMSAFromEncoded(
            pad_token=af3_sequence_encoding.token_to_idx["<G>"],
            add_residue_is_paired_feature=add_residue_is_paired_feature,
        ),
        ConditionalRoute(
            condition_func=lambda data: data.get("is_inference", False),
            transform_map={
                True: AddAF3TokenBondFeatures(np.inf),
                False: AddAF3TokenBondFeatures(),
            },
        ),
    ]

    if add_cyclic_bonds:
        transforms += [
            AddCyclicBonds(),
        ]

    transforms += [
        # ... featurize MSA
        ConvertToTorch(
            keys=[
                "encoded",
                "feats",
                "full_msa_details",
            ]
        ),
        FeaturizeMSALikeAF3(
            encoding=af3_sequence_encoding,
            n_recycles=n_recycles,
            n_msa=n_msa,
        ),
        # Prepare coordinates for noising (without modifying the ground truth)
        # ... add placeholder coordinates for noising
        CopyAnnotation(annotation_to_copy="coord", new_annotation="coord_to_be_noised"),
        # ... handling of unresolved residues (note that these Transforms create the "atom_array_to_noise" dictionary, if not already present)
        PlaceUnresolvedTokenAtomsOnRepresentativeAtom(
            annotation_to_update="coord_to_be_noised"
        ),
        PlaceUnresolvedTokenOnClosestResolvedTokenInSequence(
            annotation_to_update="coord_to_be_noised",
            annotation_to_copy="coord_to_be_noised",
        ),
        # Feature aggregation
        AggregateFeaturesLikeAF3(),
        # ... batching and noise sampling for diffusion
        BatchStructuresForDiffusionNoising(batch_size=diffusion_batch_size),
        CenterRandomAugmentation(batch_size=diffusion_batch_size),
        SampleEDMNoise(
            sigma_data=sigma_data, diffusion_batch_size=diffusion_batch_size
        ),
        CheckForNaNsInInputs(),
    ]

    confidence_transforms = Compose(
        [
            # Additions required for confidence calculation
            EncodeAtomArray(rf2aa_sequence_encoding),
            AddAtomFrames(),
            AddIsRealAtom(rf2aa_sequence_encoding),
            AddPolymerFrameIndices(),
            # wrap it all together
            PackageConfidenceFeats(),
        ]
    )

    transforms.append(
        ConditionalRoute(
            condition_func=lambda data: data.get("run_confidence_head", False),
            transform_map={
                True: confidence_transforms,
                False: Identity(),
            },
        )
    )

    keys_to_keep = [
        "example_id",
        "feats",
        "t",
        "noise",
        "ground_truth",
        "coord_atom_lvl_to_be_noised",
        "automorphisms",
        "symmetry_resolution",
        "extra_info",
    ]

    if run_confidence_head:
        keys_to_keep.append("confidence_feats")

    if return_atom_array:
        keys_to_keep.append("atom_array")

    transforms += [
        # Subset to only keys necessary
        SubsetToKeys(keys_to_keep)
    ]

    # ... compose final pipeline
    pipeline = Compose(transforms)

    return pipeline
|