rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,632 @@
|
|
|
1
|
+
"""
|
|
2
|
+
The Atom14 data pipeline for training and inference
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import warnings
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import List
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from atomworks.constants import (
|
|
11
|
+
AF3_EXCLUDED_LIGANDS,
|
|
12
|
+
GAP,
|
|
13
|
+
STANDARD_AA,
|
|
14
|
+
STANDARD_DNA,
|
|
15
|
+
STANDARD_RNA,
|
|
16
|
+
)
|
|
17
|
+
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
|
18
|
+
from atomworks.ml.transforms.atom_array import (
|
|
19
|
+
AddGlobalAtomIdAnnotation,
|
|
20
|
+
AddGlobalTokenIdAnnotation,
|
|
21
|
+
AddProteinTerminiAnnotation,
|
|
22
|
+
AddWithinChainInstanceResIdx,
|
|
23
|
+
AddWithinPolyResIdxAnnotation,
|
|
24
|
+
ComputeAtomToTokenMap,
|
|
25
|
+
CopyAnnotation,
|
|
26
|
+
)
|
|
27
|
+
from atomworks.ml.transforms.atomize import (
|
|
28
|
+
AtomizeByCCDName,
|
|
29
|
+
FlagNonPolymersForAtomization,
|
|
30
|
+
)
|
|
31
|
+
from atomworks.ml.transforms.base import (
|
|
32
|
+
AddData,
|
|
33
|
+
Compose,
|
|
34
|
+
ConditionalRoute,
|
|
35
|
+
ConvertToTorch,
|
|
36
|
+
Identity,
|
|
37
|
+
RandomRoute,
|
|
38
|
+
SubsetToKeys,
|
|
39
|
+
)
|
|
40
|
+
from atomworks.ml.transforms.bfactor_conditioned_transforms import SetOccToZeroOnBfactor
|
|
41
|
+
from atomworks.ml.transforms.bonds import AddAF3TokenBondFeatures
|
|
42
|
+
from atomworks.ml.transforms.cached_residue_data import LoadCachedResidueLevelData
|
|
43
|
+
from atomworks.ml.transforms.covalent_modifications import (
|
|
44
|
+
FlagAndReassignCovalentModifications,
|
|
45
|
+
)
|
|
46
|
+
from atomworks.ml.transforms.crop import CropContiguousLikeAF3, CropSpatialLikeAF3
|
|
47
|
+
from atomworks.ml.transforms.diffusion.batch_structures import (
|
|
48
|
+
BatchStructuresForDiffusionNoising,
|
|
49
|
+
)
|
|
50
|
+
from atomworks.ml.transforms.diffusion.edm import SampleEDMNoise
|
|
51
|
+
from atomworks.ml.transforms.featurize_unresolved_residues import (
|
|
52
|
+
MaskPolymerResiduesWithUnresolvedFrameAtoms,
|
|
53
|
+
PlaceUnresolvedTokenAtomsOnRepresentativeAtom,
|
|
54
|
+
PlaceUnresolvedTokenOnClosestResolvedTokenInSequence,
|
|
55
|
+
)
|
|
56
|
+
from atomworks.ml.transforms.filters import (
|
|
57
|
+
FilterToSpecifiedPNUnits,
|
|
58
|
+
HandleUndesiredResTokens,
|
|
59
|
+
RemoveHydrogens,
|
|
60
|
+
RemoveNucleicAcidTerminalOxygen,
|
|
61
|
+
RemovePolymersWithTooFewResolvedResidues,
|
|
62
|
+
RemoveTerminalOxygen,
|
|
63
|
+
RemoveUnresolvedLigandAtomsIfTooMany,
|
|
64
|
+
RemoveUnresolvedPNUnits,
|
|
65
|
+
)
|
|
66
|
+
from atomworks.ml.utils.token import get_token_count
|
|
67
|
+
from rfd3.transforms.conditioning_base import (
|
|
68
|
+
SampleConditioningFlags,
|
|
69
|
+
SampleConditioningType,
|
|
70
|
+
StrtoBoolforIsXFeatures,
|
|
71
|
+
UnindexFlaggedTokens,
|
|
72
|
+
)
|
|
73
|
+
from rfd3.transforms.design_transforms import (
|
|
74
|
+
AddAdditional1dFeaturesToFeats,
|
|
75
|
+
AddGroundTruthSequence,
|
|
76
|
+
AddIsXFeats,
|
|
77
|
+
AssignTypes,
|
|
78
|
+
AugmentNoise,
|
|
79
|
+
CreateDesignReferenceFeatures,
|
|
80
|
+
FeaturizeAtoms,
|
|
81
|
+
FeaturizepLDDT,
|
|
82
|
+
MotifCenterRandomAugmentation,
|
|
83
|
+
SubsampleToTypes,
|
|
84
|
+
)
|
|
85
|
+
from rfd3.transforms.dna_crop import ProteinDNAContactContiguousCrop
|
|
86
|
+
from rfd3.transforms.hbonds_hbplus import CalculateHbondsPlus
|
|
87
|
+
from rfd3.transforms.ppi_transforms import (
|
|
88
|
+
Add1DSSFeature,
|
|
89
|
+
AddGlobalIsNonLoopyFeature,
|
|
90
|
+
AddPPIHotspotFeature,
|
|
91
|
+
PPIFullBinderCropSpatial,
|
|
92
|
+
)
|
|
93
|
+
from rfd3.transforms.rasa import (
|
|
94
|
+
CalculateRASA,
|
|
95
|
+
SetZeroOccOnDeltaRASA,
|
|
96
|
+
)
|
|
97
|
+
from rfd3.transforms.symmetry import AddSymmetryFeats
|
|
98
|
+
from rfd3.transforms.util_transforms import (
|
|
99
|
+
IPDB,
|
|
100
|
+
AggregateFeaturesLikeAF3WithoutMSA,
|
|
101
|
+
EncodeAF3TokenLevelFeatures,
|
|
102
|
+
RemoveTokensWithoutCorrespondingCentralAtom,
|
|
103
|
+
)
|
|
104
|
+
from rfd3.transforms.virtual_atoms import PadTokensWithVirtualAtoms
|
|
105
|
+
|
|
106
|
+
from foundry.common import exists
|
|
107
|
+
|
|
108
|
+
######################################################################################
|
|
109
|
+
# Common transforms
|
|
110
|
+
######################################################################################
|
|
111
|
+
# Single shared AF3 sequence encoding instance; it is passed as
# `sequence_encoding` to SampleConditioningType, EncodeAF3TokenLevelFeatures
# and AddGroundTruthSequence in the pipeline builder of this module.
af3_sequence_encoding = AF3SequenceEncoding()


# Bare expression statement referencing the imported `IPDB` so linters do not
# report it as an unused import. NOTE(review): presumably kept so `IPDB` stays
# importable from this module (e.g. for debugging) — confirm it is still needed.
IPDB  # noqa
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def TrainingRoute(transform):
    """Apply ``transform`` only to training examples.

    The route is selected by ``data["is_inference"]``. ``False`` (training)
    sends the example through ``transform``. ``True`` (inference) sends it
    through an ``Identity`` no-op.

    Args:
        transform: Transform that should run during training only.

    Returns:
        A ``ConditionalRoute`` wrapping ``transform``.
    """

    def _is_inference(data):
        return data["is_inference"]

    routes = {True: Identity(), False: transform}
    return ConditionalRoute(condition_func=_is_inference, transform_map=routes)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def InferenceRoute(transform):
    """Apply ``transform`` only to inference examples.

    The route is selected by ``data["is_inference"]``. ``True`` (inference)
    sends the example through ``transform``. ``False`` (training) sends it
    through an ``Identity`` no-op.

    Args:
        transform: Transform that should run during inference only.

    Returns:
        A ``ConditionalRoute`` wrapping ``transform``.
    """

    def _is_inference(data):
        return data["is_inference"]

    routes = {False: Identity(), True: transform}
    return ConditionalRoute(condition_func=_is_inference, transform_map=routes)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def TrainingConditionRoute(condition, transform):
    """Apply ``transform`` during training when a sampled condition flag is set.

    The transform runs only if both of these hold:
    - ``data["is_inference"]`` is False (enforced by ``TrainingRoute``);
    - ``data["conditions"][condition]`` is True.
    Otherwise the example passes through ``Identity``.

    Args:
        condition: Key looked up in ``data["conditions"]``.
        transform: Transform to gate behind that flag.

    Returns:
        The nested route (training gate around the condition gate).
    """

    def _condition_enabled(data):
        return data["conditions"][condition]

    gated = ConditionalRoute(
        condition_func=_condition_enabled,
        transform_map={True: transform, False: Identity()},
    )
    return TrainingRoute(gated)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def get_pre_crop_transforms(
    central_atom: str,
    b_factor_min: float | None,
):
    """Return the ordered list of transforms applied before cropping.

    Steps wrapped in ``TrainingRoute`` run only during training. Steps wrapped
    in ``InferenceRoute`` run only during inference. All other steps always run.

    Args:
        central_atom: Atom name forwarded to
            ``RemoveTokensWithoutCorrespondingCentralAtom`` (training only).
        b_factor_min: Forwarded as the first argument of
            ``SetOccToZeroOnBfactor`` (training only).

    Returns:
        A list of transforms, in application order.
    """
    # Input clean-up and restriction to the PN units that should be kept.
    cleanup = [
        InferenceRoute(StrtoBoolforIsXFeatures()),
        RemoveHydrogens(),
        # Filter to non-clashing PN units
        FilterToSpecifiedPNUnits(
            extra_info_key_with_pn_unit_iids_to_keep="all_pn_unit_iids_after_processing"
        ),
        RemoveTerminalOxygen(),
    ]

    # Resolution-based filtering: remove PN units that are unresolved early
    # (and also after cropping), then polymers with too few resolved residues.
    resolution_filters = [
        TrainingRoute(SetOccToZeroOnBfactor(b_factor_min, None)),
        RemoveUnresolvedPNUnits(),
        TrainingRoute(RemovePolymersWithTooFewResolvedResidues(min_residues=4)),
        MaskPolymerResiduesWithUnresolvedFrameAtoms(),
        # Undesired residue names are only filtered during training; at
        # inference their presence in the input is intentional.
        TrainingRoute(HandleUndesiredResTokens(AF3_EXCLUDED_LIGANDS)),
        # Bulk removal of unresolved ligand atoms.
        TrainingRoute(
            RemoveUnresolvedLigandAtomsIfTooMany(unresolved_ligand_atom_limit=5)
        ),
        # Tokens without a central atom are dropped during training; at
        # inference, padding ensures each residue has a central atom.
        TrainingRoute(
            RemoveTokensWithoutCorrespondingCentralAtom(central_atom=central_atom),
        ),
    ]

    # Covalent-modification handling and atomization by CCD name (standard
    # amino-acid / RNA / DNA residue names are ignored by the atomizer).
    atomization = [
        FlagAndReassignCovalentModifications(),
        FlagNonPolymersForAtomization(),
        AddGlobalAtomIdAnnotation(),
        AtomizeByCCDName(
            atomize_by_default=True,
            res_names_to_ignore=STANDARD_AA + STANDARD_RNA + STANDARD_DNA,
            move_atomized_part_to_end=False,
            validate_atomize=False,
        ),
    ]

    # Final clean-up plus residue-index and termini annotations.
    annotations = [
        RemoveNucleicAcidTerminalOxygen(),
        AddWithinChainInstanceResIdx(),
        AddWithinPolyResIdxAnnotation(),
        AddProteinTerminiAnnotation(),
    ]

    return [*cleanup, *resolution_filters, *atomization, *annotations]
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def get_crop_transform(
    crop_size: int,
    crop_center_cutoff_distance: float,
    crop_contiguous_probability: float,
    crop_spatial_probability: float,
    dna_contact_crop_probability: float,
    keep_full_binder_in_spatial_crop: bool,
    max_binder_length: int,
    max_atoms_in_crop: int | None,
    allowed_types: List[str] | str,
):
    """Build the cropping stage of the pipeline.

    The returned list contains, in order:

    1. ``SubsampleToTypes(allowed_types=...)``, applied in training and inference.
    2. A training-only random choice of crop. The options are contiguous
       cropping, spatial cropping, and protein-DNA contact cropping, weighted
       by the three probabilities. For PPI examples with the
       ``full_binder_crop`` condition and a short enough binder, spatial
       cropping keeps the full binder. If every probability is zero, cropping
       is disabled and this stage is omitted.
    3. Training-only placement of remaining unresolved atoms/tokens.

    Args:
        crop_size: Crop size forwarded to every crop transform.
        crop_center_cutoff_distance: Forwarded to the spatial crop transforms.
        crop_contiguous_probability: Weight of ``CropContiguousLikeAF3``.
        crop_spatial_probability: Weight of the (PPI-aware) spatial crop.
        dna_contact_crop_probability: Weight of ``ProteinDNAContactContiguousCrop``.
        keep_full_binder_in_spatial_crop: Enables ``PPIFullBinderCropSpatial``
            for eligible PPI examples.
        max_binder_length: Binders must have fewer tokens than this to be
            kept whole.
        max_atoms_in_crop: Optional atom cap forwarded to every crop transform.
        allowed_types: Forwarded to ``SubsampleToTypes``. The inference
            wrapper passes the string ``"ALL"``.

    Returns:
        A list of transforms, in application order.

    Raises:
        AssertionError: If cropping is enabled and any of these hold: the
            probabilities do not sum to 1, ``crop_size`` is not positive, or
            ``crop_center_cutoff_distance`` is not positive.
    """
    # An all-zero configuration means "no cropping"; only validate otherwise.
    cropping_enabled = (
        crop_contiguous_probability > 0
        or crop_spatial_probability > 0
        or dna_contact_crop_probability > 0
    )
    if cropping_enabled:
        assert np.isclose(
            crop_contiguous_probability
            + crop_spatial_probability
            + dna_contact_crop_probability,
            1.0,
            atol=1e-6,
        ), "Crop probabilities must sum to 1.0"
        assert crop_size > 0, "Crop size must be greater than 0"
        assert (
            crop_center_cutoff_distance > 0
        ), "Crop center cutoff distance must be greater than 0"

    pre_crop_transforms = [
        SubsampleToTypes(allowed_types=allowed_types),
    ]

    def _use_full_binder_crop(data):
        # Short-circuits exactly like the original inline condition, so the
        # binder token count is only computed for eligible PPI examples.
        return (
            keep_full_binder_in_spatial_crop
            and data["sampled_condition_name"] == "ppi"
            and get_token_count(
                data["atom_array"][data["atom_array"].is_binder_pn_unit]
            )
            < max_binder_length
            and data["conditions"]["full_binder_crop"]
        )

    crop_stage = []
    if cropping_enabled:
        # Sampling from an all-zero distribution is undefined, so the random
        # route is only built when at least one probability is positive.
        cropping_transform = RandomRoute(
            transforms=[
                CropContiguousLikeAF3(
                    crop_size=crop_size,
                    keep_uncropped_atom_array=True,
                    max_atoms_in_crop=max_atoms_in_crop,
                ),
                ConditionalRoute(
                    condition_func=_use_full_binder_crop,
                    transform_map={
                        True: PPIFullBinderCropSpatial(
                            crop_size=crop_size,
                            crop_center_cutoff_distance=crop_center_cutoff_distance,
                            keep_uncropped_atom_array=True,
                            max_atoms_in_crop=max_atoms_in_crop,
                        ),
                        False: CropSpatialLikeAF3(
                            crop_size=crop_size,
                            crop_center_cutoff_distance=crop_center_cutoff_distance,
                            keep_uncropped_atom_array=True,
                            max_atoms_in_crop=max_atoms_in_crop,
                        ),
                    },
                ),
                ProteinDNAContactContiguousCrop(
                    protein_contact_type="all",
                    dna_contact_type="base",
                    max_atoms_in_crop=max_atoms_in_crop,
                ),
            ],
            probs=[
                crop_contiguous_probability,
                crop_spatial_probability,
                dna_contact_crop_probability,
            ],
        )
        crop_stage.append(TrainingRoute(cropping_transform))

    post_crop_transforms = [
        # ... Handling of remaining unresolved residues (NOTE: usually best done after inputs are processed.)
        TrainingRoute(
            PlaceUnresolvedTokenAtomsOnRepresentativeAtom(annotation_to_update="coord")
        ),
        TrainingRoute(
            PlaceUnresolvedTokenOnClosestResolvedTokenInSequence(
                annotation_to_update="coord",
                annotation_to_copy="coord",
            )
        ),
    ]

    transforms = pre_crop_transforms + crop_stage + post_crop_transforms
    return transforms
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def get_diffusion_transforms(
    *,
    sigma_data: float,
    diffusion_batch_size: int,
):
    """Return the final EDM-style diffusion preparation transforms.

    Args:
        sigma_data: Forwarded to ``SampleEDMNoise``.
        diffusion_batch_size: Number of noised copies; forwarded to
            ``BatchStructuresForDiffusionNoising`` and ``SampleEDMNoise``.

    Returns:
        A list of transforms, in application order.
    """
    # Token map and conversion of the "encoded" / "feats" entries to torch.
    tensor_prep = [
        ComputeAtomToTokenMap(),
        ConvertToTorch(keys=["encoded", "feats"]),
    ]
    # Copy `coord` into a placeholder annotation that gets noised, so the
    # ground truth coordinates are not modified; then aggregate features.
    noising_inputs = [
        CopyAnnotation(annotation_to_copy="coord", new_annotation="coord_to_be_noised"),
        AggregateFeaturesLikeAF3WithoutMSA(),
    ]
    # Batching and noise sampling for diffusion.
    noise_sampling = [
        BatchStructuresForDiffusionNoising(batch_size=diffusion_batch_size),
        SampleEDMNoise(
            sigma_data=sigma_data, diffusion_batch_size=diffusion_batch_size
        ),
    ]
    return [*tensor_prep, *noising_inputs, *noise_sampling]
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
######################################################################################
|
|
309
|
+
# Pipelines
|
|
310
|
+
######################################################################################
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def build_atom14_base_pipeline_(
    *,
    # Training or inference (required)
    is_inference: bool,  # If True, we skip cropping, etc.
    return_atom_array: bool,
    # Crop params
    allowed_types: List[str] | str,
    crop_size: int,
    crop_center_cutoff_distance: float,
    crop_contiguous_probability: float,
    crop_spatial_probability: float,
    dna_contact_crop_probability: float,
    keep_full_binder_in_spatial_crop: bool,
    max_binder_length: int,  # Only relevant when keep_full_binder_in_spatial_crop is True
    max_atoms_in_crop: int | None,
    b_factor_min: float | None,
    zero_occ_on_exposure_after_cropping: bool,
    # Training Hypers
    sigma_data: float,
    diffusion_batch_size: int,
    # Reference conformer policy
    generate_conformers: bool,
    generate_conformers_for_non_protein_only: bool,
    provide_reference_conformer_when_unmasked: bool,
    ground_truth_conformer_policy: str,
    provide_elements_for_unindexed_components: bool,
    use_element_for_atom_names_of_atomized_tokens: bool,
    residue_cache_dir: str | Path | None,
    # Conditioning
    train_conditions: dict,
    meta_conditioning_probabilities: dict,
    # Atom14/Model
    n_atoms_per_token: int,
    central_atom: str,
    sigma_perturb: float,
    sigma_perturb_com: float,
    association_scheme: str | None,
    center_option: str,
    atom_1d_features: dict | None,
    token_1d_features: dict | None,
    # PPI features
    max_ppi_hotspots_frac_to_provide: float,
    ppi_hotspot_max_distance: float,
    # Secondary structure features
    max_ss_frac_to_provide: float,
    min_ss_island_len: int,
    max_ss_island_len: int,
    **_,  # dump additional kwargs (e.g. msa stuff)
):
    """All-Atom (Atom14) design pipeline for training and inference.

    Builds a single ``Compose`` of transforms. Steps wrapped in
    ``TrainingRoute`` / ``TrainingConditionRoute`` only run when
    ``is_inference`` is False; all other steps run in both modes. In order,
    the pipeline:

    1. seeds the data dict (``is_inference``, ``sampled_condition_name``,
       ``conditions``) and assigns types;
    2. (training) samples a conditioning type;
    3. applies the pre-crop filtering / atomization transforms;
    4. crops (training only), optionally zeroing occupancy of atoms exposed
       by the crop via RASA;
    5. adds global token ids, (training) conditioning flags and hbonds;
    6. applies design featurization: residue cache, unindexing, virtual-atom
       padding, PPI / secondary-structure features, token and atom encodings,
       reference features, symmetry features;
    7. converts to torch, samples EDM noise and applies random augmentation;
    8. subsets the data dict to the keys needed downstream.

    NOTE: calling this installs process-wide ``warnings`` filters that ignore
    every ``RuntimeWarning`` and ``DeprecationWarning``.

    Args:
        is_inference: Stored as ``data["is_inference"]``; selects the
            training/inference routes.
        return_atom_array: If True, ``atom_array`` and ``specification`` are
            kept in the output.
        allowed_types, crop_size, crop_center_cutoff_distance,
        crop_contiguous_probability, crop_spatial_probability,
        dna_contact_crop_probability, keep_full_binder_in_spatial_crop,
        max_binder_length, max_atoms_in_crop: Forwarded to
            ``get_crop_transform``.
        b_factor_min: Forwarded to ``get_pre_crop_transforms``. When not
            None, pLDDT featurization is skipped.
        zero_occ_on_exposure_after_cropping: If True, RASA is computed before
            cropping and ``SetZeroOccOnDeltaRASA`` runs after it. Otherwise
            RASA is computed after cropping when the ``calculate_rasa``
            condition is set.
        sigma_data, diffusion_batch_size: Forwarded to the diffusion transforms.
        generate_conformers, generate_conformers_for_non_protein_only,
        provide_reference_conformer_when_unmasked,
        ground_truth_conformer_policy,
        provide_elements_for_unindexed_components,
        use_element_for_atom_names_of_atomized_tokens: Forwarded to
            ``CreateDesignReferenceFeatures``.
        residue_cache_dir: Directory for ``LoadCachedResidueLevelData``.
            ``None`` is forwarded as ``dir=None``.
        train_conditions, meta_conditioning_probabilities: Forwarded to
            ``SampleConditioningType``.
        n_atoms_per_token, central_atom, association_scheme: Virtual-atom
            padding settings; ``central_atom`` is also used for unindexing and
            the ``AddIsXFeats`` features.
        sigma_perturb, center_option: Forwarded to
            ``MotifCenterRandomAugmentation``.
        sigma_perturb_com: Forwarded to ``AugmentNoise``.
        atom_1d_features, token_1d_features: Forwarded to
            ``AddAdditional1dFeaturesToFeats``.
        max_ppi_hotspots_frac_to_provide, ppi_hotspot_max_distance: Forwarded
            to ``AddPPIHotspotFeature``.
        max_ss_frac_to_provide, min_ss_island_len, max_ss_island_len:
            Forwarded to ``Add1DSSFeature``.
        **_: Extra keyword arguments (e.g. MSA settings) are ignored.

    Returns:
        A ``Compose`` of all transforms.
    """
    warnings.filterwarnings("ignore", category=RuntimeWarning)
    warnings.filterwarnings("ignore", category=DeprecationWarning)

    # Add any data necessary for downstream transforms.
    # NOTE(review): the inner `{}` is created once per pipeline and handed to
    # AddData. If AddData does not copy it and a later transform mutates
    # data["conditions"] in place, state could leak between examples — confirm.
    transforms = [
        AddData(
            {
                "is_inference": is_inference,
                "sampled_condition_name": None,
                "conditions": {},
            }
        ),
        AssignTypes(),
    ]
    # During training, sample the conditioning type. This presumably fills
    # data["sampled_condition_name"] / data["conditions"], which the routes
    # below read.
    transforms += [
        TrainingRoute(
            SampleConditioningType(
                train_conditions=train_conditions,
                meta_conditioning_probabilities=meta_conditioning_probabilities,
                sequence_encoding=af3_sequence_encoding,
            ),
        ),
    ]

    # Pre-crop transforms
    transforms += get_pre_crop_transforms(
        central_atom=central_atom,
        b_factor_min=b_factor_min,
    )
    if zero_occ_on_exposure_after_cropping:
        # RASA of the uncropped structure, presumably used as the reference
        # for SetZeroOccOnDeltaRASA after cropping.
        transforms.append(TrainingRoute(CalculateRASA(requires_ligand=False)))

    transforms += get_crop_transform(
        crop_size=crop_size,
        crop_center_cutoff_distance=crop_center_cutoff_distance,
        crop_contiguous_probability=crop_contiguous_probability,
        crop_spatial_probability=crop_spatial_probability,
        dna_contact_crop_probability=dna_contact_crop_probability,
        keep_full_binder_in_spatial_crop=keep_full_binder_in_spatial_crop,
        max_binder_length=max_binder_length,
        max_atoms_in_crop=max_atoms_in_crop,
        allowed_types=allowed_types,
    )

    if zero_occ_on_exposure_after_cropping:
        # Optional: Zero out sidechain occupancy for atoms that have become exposed
        transforms.append(TrainingRoute(SetZeroOccOnDeltaRASA()))
    else:
        # RASA calculated after cropping
        transforms.append(
            TrainingConditionRoute(
                "calculate_rasa", CalculateRASA(requires_ligand=True)
            )
        )

    # ... Add global token features (since number of tokens is fixed after cropping)
    transforms.append(AddGlobalTokenIdAnnotation())
    # ... Create masks (NOTE: Modulates token count, and resets global token id if necessary).
    # Condition flags must be added before the hbond calculation so that the
    # is-motif atom annotations exist and the full motif can be used for hbonds.
    transforms.append(TrainingRoute(SampleConditioningFlags()))

    # Post-crop transforms
    transforms.append(
        TrainingConditionRoute(
            "calculate_hbonds",
            CalculateHbondsPlus(
                cutoff_HA_dist=3,
                cutoff_DA_distance=3.5,
            ),
        )
    )

    # Design Transforms
    transforms += [
        LoadCachedResidueLevelData(
            dir=Path(residue_cache_dir) if exists(residue_cache_dir) else None,
            sharding_depth=1,
        ),
        # ... Fuse inference and training conditioning assignments
        UnindexFlaggedTokens(central_atom=central_atom),
        # ... Virtual atom padding (NOTE: Last transform which modulates atom count)
        PadTokensWithVirtualAtoms(
            n_atoms_per_token=n_atoms_per_token,
            atom_to_pad_from=central_atom,
            association_scheme=association_scheme,
        ),  # 0.1 s
        # Possibly add hotspots
        TrainingRoute(
            ConditionalRoute(
                condition_func=lambda data: data["sampled_condition_name"] == "ppi"
                and data["conditions"]["add_ppi_hotspots"],
                transform_map={
                    True: AddPPIHotspotFeature(
                        max_hotspots_frac_to_provide=max_ppi_hotspots_frac_to_provide,
                        hotspot_max_distance=ppi_hotspot_max_distance,
                    ),
                    False: Identity(),
                },
            )
        ),
        TrainingRoute(
            Add1DSSFeature(
                max_secondary_structure_frac_to_provide=max_ss_frac_to_provide,
                min_ss_island_len=min_ss_island_len,
                max_ss_island_len=max_ss_island_len,
            ),
        ),
        TrainingRoute(
            ConditionalRoute(
                condition_func=lambda data: data["conditions"][
                    "add_global_is_non_loopy_feature"
                ],
                transform_map={
                    True: AddGlobalIsNonLoopyFeature(),
                    False: Identity(),
                },
            )
        ),
        # ... AF3 token level encoding with sequence masking
        EncodeAF3TokenLevelFeatures(
            sequence_encoding=af3_sequence_encoding, encode_residues_to=GAP
        ),
        # ... Atom-level reference features
        CreateDesignReferenceFeatures(
            generate_conformers=generate_conformers,
            generate_conformers_for_non_protein_only=generate_conformers_for_non_protein_only,
            provide_reference_conformer_when_unmasked=provide_reference_conformer_when_unmasked,
            ground_truth_conformer_policy=ground_truth_conformer_policy,
            provide_elements_for_unindexed_components=provide_elements_for_unindexed_components,
            use_element_for_atom_names_of_atomized_tokens=use_element_for_atom_names_of_atomized_tokens,
        ),
        # ... Add useful features for losses / metrics
        AddIsXFeats(
            X=[
                # Basic
                "is_backbone",
                "is_sidechain",
                # Virtual atom
                "is_virtual",
                "is_central",
                "is_ca",
                # Conditioning
                "is_motif_atom_with_fixed_coord",
                "is_motif_atom_unindexed",
                "is_motif_atom_with_fixed_seq",
                "is_motif_token_with_fully_fixed_coord",
            ],
            central_atom=central_atom,
        ),
        FeaturizeAtoms(),
        FeaturizepLDDT(skip=b_factor_min is not None),
        AddAdditional1dFeaturesToFeats(
            autofill_zeros_if_not_present_in_atomarray=True,
            token_1d_features=token_1d_features,
            atom_1d_features=atom_1d_features,
        ),
        AddAF3TokenBondFeatures(),
        AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),
        # Symmetry features only when the atom array carries a symmetry_id annotation.
        ConditionalRoute(
            condition_func=lambda data: "symmetry_id"
            in data["atom_array"].get_annotation_categories(),
            transform_map={
                True: AddSymmetryFeats(),
                False: Identity(),
            },
        ),
    ]

    # EDM-style wrap-up (no additional features added at this point)
    transforms += get_diffusion_transforms(
        sigma_data=sigma_data,
        diffusion_batch_size=diffusion_batch_size,
    )

    # ... Random augmentation accounting for motif
    transforms += [
        MotifCenterRandomAugmentation(
            batch_size=diffusion_batch_size,
            sigma_perturb=sigma_perturb,
            center_option=center_option,
        ),
        AugmentNoise(
            sigma_perturb_com=sigma_perturb_com,
            batch_size=diffusion_batch_size,
        ),
    ]

    # Subset to necessary keys only
    keys_to_keep = [
        "example_id",
        "feats",
        "t",
        "noise",
        "ground_truth",
        "coord_atom_lvl_to_be_noised",
        "extra_info",
        "sampled_condition_name",
        "log_dict",
    ]
    if return_atom_array:
        keys_to_keep.extend(
            [
                "atom_array",
                "specification",
            ]
        )
    # For debugging & tests:
    if not is_inference:
        keys_to_keep.append("conditions")
    transforms.append(SubsetToKeys(keys_to_keep))

    pipeline = Compose(transforms)
    return pipeline
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
def build_atom14_base_pipeline(
    is_inference: bool,
    *,
    # Dumped args:
    protein_msa_dirs=None,
    rna_msa_dirs=None,
    n_recycles=None,
    n_msa=None,
    # Catch all other arguments:
    **kwargs,
):
    """Public entry point that fills in inference defaults before building.

    In training mode, ``kwargs`` are forwarded unchanged. In inference mode,
    any training-only argument the caller did not supply gets an explicit
    default, which keeps older configs and checkpoints working. Values the
    caller passes always take precedence over these defaults. The MSA /
    recycling arguments are accepted only so they can be discarded.

    Args:
        is_inference: Whether the pipeline is built for inference.
        protein_msa_dirs, rna_msa_dirs, n_recycles, n_msa: Ignored.
        **kwargs: Forwarded to ``build_atom14_base_pipeline_``.

    Returns:
        The ``Compose`` pipeline returned by ``build_atom14_base_pipeline_``.
    """
    if not is_inference:
        return build_atom14_base_pipeline_(is_inference=is_inference, **kwargs)

    inference_defaults = {
        # Explicit defaults for training-only args
        "crop_size": 512,
        "crop_center_cutoff_distance": 10.0,
        "crop_contiguous_probability": 1.0,
        "crop_spatial_probability": 0.0,
        "dna_contact_crop_probability": 0.0,
        "max_atoms_in_crop": None,
        "keep_full_binder_in_spatial_crop": True,
        "max_ppi_hotspots_frac_to_provide": 0,
        "ppi_hotspot_max_distance": 15,
        "max_ss_frac_to_provide": 0.0,
        "min_ss_island_len": 0,
        "max_ss_island_len": 999,
        "max_binder_length": 999,
        "b_factor_min": None,
        "zero_occ_on_exposure_after_cropping": False,
        "meta_conditioning_probabilities": {},
        "association_scheme": "dense",
        "sigma_perturb": 0.0,
        "sigma_perturb_com": 0.0,
        "allowed_types": "ALL",
        "train_conditions": {},
        "residue_cache_dir": None,
        # TODO: Delete these once all checkpoints are updated with the latest defaults
        "generate_conformers_for_non_protein_only": True,
        "return_atom_array": True,
        "provide_elements_for_unindexed_components": False,
        "center_option": "all",
    }
    # Caller-supplied values override the defaults (same effect as setdefault).
    merged_kwargs = {**inference_defaults, **kwargs}
    return build_atom14_base_pipeline_(is_inference=is_inference, **merged_kwargs)
|