rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/cli.py
ADDED
@@ -0,0 +1,77 @@
+from pathlib import Path
+
+import typer
+from hydra import compose, initialize_config_dir
+
+app = typer.Typer(pretty_exceptions_enable=False)
+
+
+@app.command(
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
+)
+def fold(
+    ctx: typer.Context,
+    verbose: bool = typer.Option(
+        False, "--verbose", "-v", help="Show detailed logging output"
+    ),
+):
+    """Run structure prediction using hydra config overrides or simple input file."""
+    # Configure logging BEFORE any heavy imports
+    if not verbose:
+        from foundry.utils.logging import configure_minimal_inference_logging
+
+        configure_minimal_inference_logging()
+
+    # Find the RF3 configs directory relative to this file
+    # This file is at: models/rf3/src/rf3/cli.py
+    # Configs are at: models/rf3/configs/
+    rf3_package_dir = Path(__file__).parent.parent.parent  # Go up to models/rf3/
+    config_path = str(rf3_package_dir / "configs")
+
+    # Get all arguments
+    args = ctx.params.get("args", []) + ctx.args
+
+    # Parse arguments
+    hydra_overrides = []
+
+    if len(args) == 1 and "=" not in args[0]:
+        # Old style: single positional argument assumed to be inputs
+        hydra_overrides.append(f"inputs={args[0]}")
+    else:
+        # New style: all arguments are hydra overrides
+        hydra_overrides.extend(args)
+
+    # Ensure we have at least a default inference_engine if not specified
+    has_inference_engine = any(
+        arg.startswith("inference_engine=") for arg in hydra_overrides
+    )
+    if not has_inference_engine:
+        hydra_overrides.append("inference_engine=rf3")
+
+    # Handle verbose flag
+    if verbose:
+        hydra_overrides.append("verbose=true")
+
+    with initialize_config_dir(config_dir=config_path, version_base="1.3"):
+        cfg = compose(config_name="inference", overrides=hydra_overrides)
+        # Lazy import to avoid loading heavy dependencies at CLI startup
+        from rf3.inference import run_inference
+
+        run_inference(cfg)
+
+
+@app.command(
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
+)
+def predict(
+    ctx: typer.Context,
+    verbose: bool = typer.Option(
+        False, "--verbose", "-v", help="Show detailed logging output"
+    ),
+):
+    """Alias for fold command."""
+    fold(ctx, verbose=verbose)
+
+
+if __name__ == "__main__":
+    app()
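A minimal sketch of how `fold` maps command-line arguments onto Hydra overrides, mirroring the branching above; the helper name `to_overrides` is hypothetical and for illustration only:

def to_overrides(args: list[str], verbose: bool = False) -> list[str]:
    overrides = []
    if len(args) == 1 and "=" not in args[0]:
        overrides.append(f"inputs={args[0]}")  # old style: bare input path
    else:
        overrides.extend(args)  # new style: key=value Hydra overrides
    if not any(a.startswith("inference_engine=") for a in overrides):
        overrides.append("inference_engine=rf3")  # default engine
    if verbose:
        overrides.append("verbose=true")
    return overrides

assert to_overrides(["complex.json"]) == ["inputs=complex.json", "inference_engine=rf3"]
assert to_overrides(["inputs=a.json", "verbose=x"]) == ["inputs=a.json", "verbose=x", "inference_engine=rf3"]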
rf3/data/cyclic_transform.py
ADDED
@@ -0,0 +1,78 @@
+import logging
+
+import numpy as np
+from atomworks.ml.transforms.base import Transform
+
+logger = logging.getLogger("atomworks.ml")
+
+
+class AddCyclicBonds(Transform):
+    """
+    Transform that detects and adds cyclic (head-to-tail) peptide bonds in protein chains.
+    This transform analyzes the atom-level structure of each chain in the input data to identify
+    cyclic bonds between the N-terminal nitrogen of the first residue and the C-terminal carbon
+    of the last residue, based on spatial proximity. If such a bond is detected (0.5 Å < distance < 1.5 Å),
+    it updates the token-level bond features to reflect the presence of the cyclic bond and flags that
+    chain as being cyclic.
+
+    Requirements:
+        - Must be applied after "AddAF3TokenBondFeatures" and "EncodeAF3TokenLevelFeatures" transforms.
+    """
+
+    # need to do it after this transform, because we want to include poly-poly bonds, which AF3TokenBondFeatures does not.
+    requires_previous_transforms = [
+        "AddAF3TokenBondFeatures",
+        "EncodeAF3TokenLevelFeatures",
+        "AddGlobalTokenIdAnnotation",
+    ]
+
+    def forward(self, data: dict) -> dict:
+        atom_array = data["atom_array"]
+        token_bonds = data["feats"]["token_bonds"]
+        asym_ids = data["feats"]["asym_id"]
+
+        cyclic_token_bonds = np.zeros_like(token_bonds, dtype=bool)
+        cyclic_asym_ids = set()
+
+        # check for any cyclic bonds
+        for chain in np.unique(atom_array.chain_id):
+            chain_mask = atom_array.chain_id == chain
+            if not np.any(chain_mask):
+                continue
+            chain_array = atom_array[chain_mask]
+            residue_ids = np.unique(chain_array.res_id)
+            if len(residue_ids) < 2:
+                continue
+            first_residue = residue_ids[0]
+            last_residue = residue_ids[-1]
+            first_nitrogen_mask = (chain_array.res_id == first_residue) & (
+                chain_array.atom_name == "N"
+            )
+            last_carbon_mask = (chain_array.res_id == last_residue) & (
+                chain_array.atom_name == "C"
+            )
+            if first_nitrogen_mask.sum() == 1 and last_carbon_mask.sum() == 1:
+                first_nitrogen = chain_array[first_nitrogen_mask]
+                last_carbon = chain_array[last_carbon_mask]
+                distance = np.linalg.norm(
+                    first_nitrogen.coord[0] - last_carbon.coord[0]
+                )
+                if 0.5 < distance < 1.5:  # peptide-bond length-ish
+                    cyclic_token_bonds[
+                        first_nitrogen.token_id[0], last_carbon.token_id[0]
+                    ] = True
+                    cyclic_token_bonds[
+                        last_carbon.token_id[0], first_nitrogen.token_id[0]
+                    ] = True
+                    logger.warning(
+                        f"Detected cyclic bond in chain {chain} of {data['example_id']} between residues {first_residue} and {last_residue} with distance {distance:.2f} Å"
+                    )
+
+        cyclic_asym_ids.update(
+            asym_ids[np.where(np.any(cyclic_token_bonds, axis=0))[0]].tolist()
+        )
+        token_bonds |= cyclic_token_bonds
+        data["feats"]["token_bonds"] = token_bonds
+        data["feats"]["cyclic_asym_ids"] = list(cyclic_asym_ids)
+
+        return data
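A standalone illustration (plain NumPy, synthetic coordinates) of the geometric criterion used by AddCyclicBonds above: a head-to-tail bond is accepted when the N-to-C distance lies strictly between 0.5 Å and 1.5 Å.

import numpy as np

# N of the first residue and C of the last residue; 1.33 Å is a typical
# peptide C-N bond length, so this pair passes the check.
first_nitrogen = np.array([0.0, 0.0, 0.0])
last_carbon = np.array([1.33, 0.0, 0.0])

distance = np.linalg.norm(first_nitrogen - last_carbon)
assert 0.5 < distance < 1.5  # same test as in AddCyclicBonds.forward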
rf3/data/extra_xforms.py
ADDED
@@ -0,0 +1,36 @@
+import torch
+from atomworks.ml.transforms._checks import (
+    check_contains_keys,
+)
+from atomworks.ml.transforms.base import Transform
+
+
+class CheckForNaNsInInputs(Transform):
+    """
+    Sanity-check transform that asserts the network inputs contain no NaNs.
+
+    Verifies that 'coord_atom_lvl_to_be_noised' and 'noise' are free of NaNs
+    before they are passed to the network.
+
+    During inference, this transform also replaces the coordinates to be
+    noised with pure Gaussian noise (see the TODO in `forward`).
+    """
+
+    def check_input(self, data: dict):
+        check_contains_keys(data, ["coord_atom_lvl_to_be_noised"])
+        check_contains_keys(data, ["noise"])
+
+    def forward(self, data: dict) -> dict:
+        # During inference, replace coordinates with true noise
+        # TODO: Move elsewhere in pipeline; placing it here is a short-term hack
+        if data.get("is_inference", False):
+            data["coord_atom_lvl_to_be_noised"] = torch.randn_like(
+                data["coord_atom_lvl_to_be_noised"]
+            )
+
+        assert not torch.isnan(
+            data["coord_atom_lvl_to_be_noised"]
+        ).any(), "NaN found in network input"
+        assert not torch.isnan(data["noise"]).any(), "NaN found in network noise"
+
+        return data
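A usage sketch for CheckForNaNsInInputs on a toy data dict, assuming the class is importable from rf3.data.extra_xforms as packaged above; the key names follow check_input, the shapes are arbitrary:

import torch

from rf3.data.extra_xforms import CheckForNaNsInInputs

transform = CheckForNaNsInInputs()
data = {
    "coord_atom_lvl_to_be_noised": torch.zeros(5, 3),
    "noise": torch.randn(5, 3),
    "is_inference": True,  # coordinates get replaced with pure Gaussian noise
}
data = transform.forward(data)
assert not torch.isnan(data["coord_atom_lvl_to_be_noised"]).any()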
rf3/data/ground_truth_template.py
ADDED
@@ -0,0 +1,463 @@
+import logging
+from dataclasses import dataclass
+
+import numpy as np
+import torch
+from atomworks.enums import ChainType
+from atomworks.ml.transforms._checks import (
+    check_atom_array_annotation,
+    check_contains_keys,
+    check_is_instance,
+)
+from atomworks.ml.transforms.atomize import AtomizeByCCDName
+from atomworks.ml.transforms.base import Transform
+from atomworks.ml.utils.token import (
+    get_af3_token_center_masks,
+    get_token_starts,
+)
+from beartype.typing import Any, Callable, Final, Sequence
+from biotite.structure import AtomArray
+from jaxtyping import Bool, Float, Shaped
+from torch import Tensor
+
+from foundry.utils.torch import assert_no_nans
+
+logger = logging.getLogger(__name__)
+
+MaskingFunction = Callable[[AtomArray], Bool[Shaped, "n"]]
+"""A function that takes in an AtomArray and returns a boolean mask."""
+
+NoiseScaleSampler = Callable[[Sequence[int]], Float[Tensor, "..."] | float]
+"""
+A noise scale sampler that, when given a shape-tuple, returns a sample of
+noise scales of the appropriate shape.
+
+Examples:
+    - partial(torch.normal, mean=0.0, std=1.0)
+    - af3_noise_scale_distribution
+    - af3_noise_scale_distribution_wrapped
+"""
+
+
+@dataclass
+class TokenGroupNoiseScaleSampler:
+    mask_and_sampling_fns: tuple[tuple[MaskingFunction, NoiseScaleSampler], ...]
+
+    def __call__(self, atom_array: AtomArray) -> Tensor:
+        # ... determine token centers
+        token_center_mask = get_af3_token_center_masks(atom_array)  # [n_token] (bool)
+        token_array = atom_array[token_center_mask]  # [n_token] (AtomArray)
+
+        # ... sample a noise scale for each token group
+        noise_scales = torch.full(
+            size=(len(token_array),),
+            fill_value=float("nan"),
+            dtype=torch.float32,
+        )
+        for mask_fn, sampling_fn in self.mask_and_sampling_fns:
+            mask = mask_fn(token_array)
+            n_tokens_to_sample = mask.sum()
+            if n_tokens_to_sample > 0:
+                # ... all tokens in that group receive the same noise scale
+                noise_scales[mask] = sampling_fn((1,))
+
+        return noise_scales
+
+
+DEFAULT_DISTOGRAM_BINS: Final[Float[Tensor, "63"]] = torch.concat(
+    (
+        torch.arange(1.0, 4.0, 0.1, device="cpu"),
+        torch.arange(4.0, 20.5, 0.5, device="cpu"),
+    )
+)
+"""
+Default bins for discretizing distances in the distogram (in Angstrom).
+- 0.1A resolution from 1.0 - 4.0 A
+- 0.5A resolution from 4.0 - 20.0 A
+Total number of bins: 64 (i.e. 63 bin boundaries above)
+"""
+
+
+def wrap_probability_distribution(
+    samples: Float[Tensor, "..."],
+    lower: float = float("-inf"),
+    upper: float = float("inf"),
+) -> Float[Tensor, "..."]:
+    """
+    Wrap a probability distribution around lower and upper bounds to create
+    samples from the corresponding wrapped probability distribution.
+
+    Args:
+        - samples: Input tensor of samples to wrap
+        - lower: Lower bound for wrapping (inclusive, unless infinite)
+        - upper: Upper bound for wrapping (inclusive, unless infinite)
+
+    Returns:
+        - samples: Samples wrapped around the lower and upper bounds within
+            the interval ]lower, upper[.
+    Reference:
+        - https://en.wikipedia.org/wiki/Wrapped_distribution
+    """
+    if lower > float("-inf") and upper < float("inf"):
+        return ((samples - lower) % (upper - lower)) + lower
+    elif lower > float("-inf"):
+        return lower + (samples - lower).abs()
+    elif upper < float("inf"):
+        return upper - (samples - upper).abs()
+    return samples
+
+
+def wrapped_normal(
+    mean: float,
+    std: float,
+    size: Sequence[int],
+    *,
+    lower: float = float("-inf"),
+    upper: float = float("inf"),
+    **normal_kwargs,
+) -> Float[Tensor, "..."]:
+    """Sample from a wrapped normal distribution."""
+    samples = torch.normal(mean=mean, std=std, size=size, **normal_kwargs)
+    return wrap_probability_distribution(samples, lower, upper)
+
+
+def af3_noise_scale_to_noise_level(
+    noise_scale: Tensor | float, eps: float = 1e-8
+) -> Tensor:
+    """Converts AlphaFold3 noise scale (t^) in Angstroms to noise level (t).
+
+    This function converts from a noise scale in Angstroms (t^) to the
+    corresponding standard normal noise level (t) using the formula:
+        t = (log(t^/16.0) + 1.2) / 1.5
+
+    Args:
+        - noise_scale (Tensor): The noise scale (t^) in Angstroms,
+            representing the standard deviation of positional noise.
+
+    Returns:
+        - noise_level (Tensor): The corresponding noise level (t) as a
+            standard normal random variable.
+
+    Notes:
+        - We use the term 'noise-level' to refer to the standard normal random
+            variable `t` in the AF3 paper and 'noise-scale' to refer to the variable
+            `t^` which denotes the noise scale in Angstrom. This is the inverse
+            operation of af3_noise_level_to_noise_scale().
+        - To avoid taking the log of zero, we clamp the noise scale from below
+            at a small constant (eps) before dividing by 16.0.
+    """
+    noise_scale_tensor = torch.as_tensor(noise_scale)
+    return (torch.log(torch.clamp(noise_scale_tensor, min=eps) / 16.0) + 1.2) / 1.5
+
+
+def af3_noise_level_to_noise_scale(noise_level: Tensor | float) -> Tensor:
+    """Convert AlphaFold3 noise level (t) to noise scale (t^) in Angstroms.
+
+    This function converts from a standard normal noise level (t) to the
+    corresponding noise scale in Angstroms (t^) using the formula:
+        t^ = 16.0 * exp(1.5t - 1.2), i.e. t^ ~ log-N(log(16.0) - 1.2, 1.5^2)
+
+    Args:
+        - noise_level (Tensor): The noise level (t) as a standard normal random
+            variable, sampled from N(0,1).
+
+    Returns:
+        - noise_scale (Tensor): The corresponding noise scale (t^) in Angstroms,
+            representing the standard deviation of positional noise to apply.
+
+    Note:
+        This is the inverse operation of af3_noise_scale_to_noise_level(). The
+        transformation is designed to convert between a normal distribution and
+        a log-normal distribution with specific parameters chosen by AlphaFold3.
+    """
+    return 16.0 * torch.exp((torch.as_tensor(noise_level) * 1.5) - 1.2)
+
+
+def af3_noise_scale_distribution(size: Sequence[int], **kwargs) -> Tensor:
+    """
+    The log-normal noise-scale distribution used in AF3 (in Angstrom).
+
+        t^ = 16.0 * exp(1.5t - 1.2),
+    where:
+        - t = noise-level ~ N(0,1)
+        - t^ = noise-scale ~ log-N(log(16.0) - 1.2, 1.5^2)
+    """
+    noise_level = torch.normal(mean=0.0, std=1.0, size=size, **kwargs)
+    return af3_noise_level_to_noise_scale(noise_level)
+
+
+def af3_noise_scale_distribution_wrapped(
+    size: Sequence[int],
+    *,
+    lower_noise_level: float = float("-inf"),
+    upper_noise_level: float = float("inf"),
+    **kwargs,
+) -> Tensor:
+    """
+    The noise-scale distribution used in AF3 (in Angstrom), wrapped around the lower
+    and upper bounds.
+
+    WARNING: The lower/upper here correspond to the noise-level (t), not the noise-scale (t^);
+    wrapping happens in the noise-level space before converting to the corresponding
+    log-normal noise-scale distribution (t^) in Angstroms.
+    """
+    noise_level = wrapped_normal(
+        mean=0.0,
+        std=1.0,
+        size=size,
+        lower=lower_noise_level,
+        upper=upper_noise_level,
+        **kwargs,
+    )
+    return af3_noise_level_to_noise_scale(noise_level)
+
+
+def featurize_noised_ground_truth_as_template_distogram(
+    atom_array: AtomArray,
+    *,
+    noise_scale: Float[Tensor, "n_token"] | float,
+    distogram_bins: Float[Tensor, "n_bin_edges"],
+    allowed_chain_types: list[ChainType],
+    is_unconditional: bool = True,
+    p_condition_per_token: float = 0.0,
+    p_provide_inter_molecule_distances: float = 0.0,
+    existing_annotation_to_check: str = "is_input_file_templated",
+) -> dict[str, Tensor]:
+    """Featurize noised ground truth as a template distogram for conditioning.
+
+    Used to leak ground-truth information into the model.
+
+    Args:
+        atom_array (AtomArray): The input atom array. Must have 'chain_type', 'occupancy', and 'molecule_id' annotations.
+        noise_scale (Tensor | float): Standard deviation of the noise to add to the ground truth.
+            Different tokens may have different noise scales (e.g. one noise scale for
+            side-chains, one for ligand atoms and one for backbone atoms).
+            Units are in Angstrom. If given as tensor, must be of shape [n_token] (float).
+        allowed_chain_types (list): List of allowed chain types. Only token pairs where BOTH
+            tokens have a chain type in this list will have a distogram condition.
+        distogram_bins (Tensor): Bins for discretizing distances in the distogram (in Angstrom).
+            Shape: [n_bin].
+        is_unconditional (bool): Whether we are sampling unconditionally.
+            See Classifier-Free Diffusion Guidance (Ho et al., 2022) for details.
+            Default: True (no conditioning).
+        p_condition_per_token (float, optional):
+            Probability of conditioning each eligible token. Default: 0.0 (no conditioning)
+        p_provide_inter_molecule_distances (float, optional):
+            Probability of providing inter-molecule (inter-chain) distances. Default: 0.0 (mask all inter-molecule pairs).
+        existing_annotation_to_check (str):
+            If this annotation exists in the AtomArray, we ALWAYS template where it is True.
+            Useful for inference.
+
+    Returns:
+        dict[str, Tensor]:
+            Dictionary with the following keys:
+            - 'distogram_condition_noise_scale': Float[Tensor, "n_token"]. Noise scale for each token (0 for unconditioned tokens).
+            - 'has_distogram_condition': Bool[Tensor, "n_token n_token"]. Mask indicating which token pairs are conditioned.
+            - 'distogram_condition': Float[Tensor, "n_token n_token n_bins"]. One-hot encoded distogram for each token pair.
+
+    NOTE:
+        - We use the center atom for each token (CA for proteins, C1' for nucleic acids) in the token-level conditioning.
+        - If a token is not conditioned, its noise scale is set to 0 and its pairwise distances are masked.
+    """
+    MASK_VALUE = float("nan")
+
+    # Get full atom array token starts (useful for going from atom-level -> token-level annotations)
+    _a_token_starts = get_token_starts(atom_array)  # [n_token] (int)
+    _n_token = len(_a_token_starts)
+
+    # Create one blank template (ground truth), initialized to mask tokens (we will only use the distogram, and ignore the other features)
+    template_distogram = torch.full((_n_token, _n_token), fill_value=MASK_VALUE)
+
+    # Sample Gaussian noise according to the noise scale for each token
+    # NOTE: If a scalar noise scale is provided, it will be broadcasted to all tokens
+    # NOTE: We sample noise independently for each token; no two tokens will have the exact same noise
+    noise = torch.normal(mean=0.0, std=1.0, size=(_n_token, 3)) * torch.as_tensor(
+        noise_scale
+    ).unsqueeze(-1)
+
+    # Get the center coordinates of the tokens (CA for proteins, C1' for nucleic acids), and add noise
+    center_token_mask = get_af3_token_center_masks(atom_array)  # [n_atom] (bool)
+    noisy_center_coords = (
+        torch.from_numpy(atom_array.coord[center_token_mask]) + noise
+    )  # [n_token, 3] (float)
+
+    # Create a mask of supported chain types...
+    tokens_with_supported_chain_types_mask = np.isin(
+        atom_array.chain_type[center_token_mask], allowed_chain_types
+    )  # [n_token] (bool)
+
+    # ... and mask of tokens with resolved center atoms
+    resolved_tokens_mask = (
+        atom_array.occupancy[center_token_mask] > 0
+    )  # [n_token] (bool)
+
+    # The tokens to fill are those with supported chain types, resolved center atoms, and non-NaN noise
+    token_to_fill_mask = (
+        tokens_with_supported_chain_types_mask
+        & resolved_tokens_mask
+        & torch.isfinite(noise).all(dim=-1).numpy()
+    )  # [n_token] (bool)
+
+    # Check if existing annotation exists and force templating where it's True
+    if (
+        existing_annotation_to_check
+        and existing_annotation_to_check in atom_array.get_annotation_categories()
+    ):
+        existing_annotation_values = atom_array.get_annotation(
+            existing_annotation_to_check
+        )[center_token_mask]
+        forced_template_mask = np.asarray(existing_annotation_values, dtype=bool)
+    else:
+        forced_template_mask = np.full(_n_token, False, dtype=bool)
+
+    # If unconditional, discard all conditioning...
+    if is_unconditional:
+        token_to_fill_mask = np.full_like(token_to_fill_mask, False)
+    else:
+        # Probability of masking each token
+        _should_apply_condition = np.random.rand(_n_token) < p_condition_per_token
+        token_to_fill_mask = (
+            token_to_fill_mask & _should_apply_condition
+        ) | forced_template_mask
+
+    token_idxs_to_fill = np.where(token_to_fill_mask)[0]  # [n_token_to_fill] (int)
+
+    # ... fill the template_distogram
+    ix1, ix2 = np.ix_(token_idxs_to_fill, token_idxs_to_fill)
+    template_distogram[ix1.astype(int), ix2.astype(int)] = torch.cdist(
+        noisy_center_coords[token_to_fill_mask],
+        noisy_center_coords[token_to_fill_mask],
+        compute_mode="donot_use_mm_for_euclid_dist",  # Important for numerical stability
+    )
+
+    # (Create n_token x n_token mask, where True indicates a condition); e.g., True for all non-mask tokens
+    token_to_fill_mask_II = token_to_fill_mask[:, None] & token_to_fill_mask[None, :]
+
+    # ... mask inter-molecule distances, if required
+    if np.random.rand() > p_provide_inter_molecule_distances:
+        # Create a mask of tokens that belong to different molecules
+        is_inter_molecule = (
+            atom_array.molecule_id[center_token_mask][:, None]
+            != atom_array.molecule_id[center_token_mask]
+        )
+
+        # ... mask inter-molecule distances
+        token_to_fill_mask_II[is_inter_molecule] = False
+        template_distogram[is_inter_molecule] = MASK_VALUE
+
+    # Discretize distances into bins (NaNs go to last bin)
+    template_distogram_binned: Tensor = torch.bucketize(
+        template_distogram, boundaries=distogram_bins
+    )  # (n_token, n_token)
+    n_bins: int = len(distogram_bins) + 1
+    template_distogram_onehot: Float[Tensor, "n_token n_token n_bins"] = (
+        torch.nn.functional.one_hot(
+            template_distogram_binned, num_classes=n_bins
+        ).to(torch.float32)
+    )
+
+    # Expand noise_scale to (n_token,) if needed
+    expanded_noise_scale: Float[Tensor, "n_token"] = (
+        noise_scale.expand(_n_token).clone()
+        if isinstance(noise_scale, Tensor)
+        else torch.full((_n_token,), fill_value=noise_scale)
+    )
+    # Set noise scale to 0 for unconditioned tokens
+    expanded_noise_scale[~token_to_fill_mask] = 0.0
+
+    out: dict[str, Tensor] = {
+        "distogram_condition_noise_scale": expanded_noise_scale,  # (n_token,)
+        "has_distogram_condition": torch.as_tensor(
+            token_to_fill_mask_II, dtype=torch.bool
+        ),  # (n_token, n_token)
+        "distogram_condition": template_distogram_onehot,  # (n_token, n_token, n_bins)
+    }
+
+    assert_no_nans(out, msg="Conditioning features contain NaNs!")
+
+    return out
+
+
+class FeaturizeNoisedGroundTruthAsTemplateDistogram(Transform):
+    """Add noised ground truth as a template distogram.
+
+    Creates template features by adding Gaussian noise to the ground truth
+    coordinates and converting the resulting distances into a discretized
+    distogram.
+
+    Args:
+        noise_scale_distribution (Callable): Function that returns the standard
+            deviation of noise to add to the ground truth coordinates. Should take
+            a sequence of dimensions and return a tensor or float. Default is
+            af3_noise_scale_distribution.
+        distogram_bins (Tensor): Bin boundaries for discretizing distances in
+            the distogram. Shape [n_bins-1].
+        allowed_chain_types (list): List of allowed chain types. Default is all chain types.
+        p_condition_per_token (float): Per-token probability of conditioning, for those tokens that satisfy all other conditions.
+            Default is 0.0 (no conditioning).
+        p_provide_inter_molecule_distances (float): Probability of providing inter-molecule (inter-chain) distances.
+            Default is 0.0 (no inter-molecule distances provided).
+        existing_annotation_to_check (str): Name of an annotation in the AtomArray that,
+            if present and True for a token, will force that token to be templated regardless of other conditions.
+            Default is "is_input_file_templated".
+
+    Adds the following features to the `feats` dict:
+        - "distogram_condition_noise_scale": Noise scale for each
+            token [n_token] (float)
+        - "has_distogram_condition": Mask indicating which token pairs have a distogram
+            condition [n_token, n_token] (bool)
+        - "distogram_condition": One-hot encoded distogram
+            [n_token, n_token, n_bins] (float)
+    """
+
+    requires_previous_transforms = [AtomizeByCCDName]
+
+    def __init__(
+        self,
+        noise_scale_distribution: NoiseScaleSampler
+        | TokenGroupNoiseScaleSampler = af3_noise_scale_distribution,
+        distogram_bins: torch.Tensor = DEFAULT_DISTOGRAM_BINS,
+        allowed_chain_types: list[ChainType] = ChainType.get_all_types(),
+        p_condition_per_token: float = 0.0,
+        p_provide_inter_molecule_distances: float = 0.0,
+        existing_annotation_to_check: str = "is_input_file_templated",
+    ):
+        self.distogram_bins = distogram_bins
+        self.noise_scale_distribution = noise_scale_distribution
+        self.p_provide_inter_molecule_distances = p_provide_inter_molecule_distances
+        self.allowed_chain_types = allowed_chain_types
+        self.p_condition_per_token = p_condition_per_token
+        self.existing_annotation_to_check = existing_annotation_to_check
+
+    def check_input(self, data: dict[str, Any]) -> None:
+        check_contains_keys(data, ["atom_array"])
+        check_is_instance(data, "atom_array", AtomArray)
+        check_atom_array_annotation(data, ["chain_type", "occupancy"])
+
+    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
+        atom_array = data["atom_array"]
+
+        if isinstance(self.noise_scale_distribution, TokenGroupNoiseScaleSampler):
+            # ... different noise scale for each token group
+            noise_scale = self.noise_scale_distribution(atom_array)  # [n_token] (float)
+        else:
+            # ... same noise scale for all tokens
+            noise_scale = self.noise_scale_distribution(size=(1,))  # [1] (float)
+
+        template_features = featurize_noised_ground_truth_as_template_distogram(
+            atom_array=atom_array,
+            noise_scale=noise_scale,
+            allowed_chain_types=self.allowed_chain_types,
+            distogram_bins=self.distogram_bins,
+            p_provide_inter_molecule_distances=self.p_provide_inter_molecule_distances,
+            is_unconditional=data.get("is_unconditional", False),
+            p_condition_per_token=self.p_condition_per_token,
+            existing_annotation_to_check=self.existing_annotation_to_check,
+        )
+
+        # Add the template features to the `feats` dict
+        if "feats" not in data:
+            data["feats"] = {}
+        data["feats"].update(template_features)
+
+        return data
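A quick standalone check (plain PyTorch, constants copied from the two conversion functions above) that af3_noise_level_to_noise_scale and af3_noise_scale_to_noise_level invert each other:

import torch

t = torch.tensor([-1.0, 0.0, 1.0])       # noise levels, t ~ N(0, 1)
t_hat = 16.0 * torch.exp(1.5 * t - 1.2)  # noise scales in Angstrom
t_back = (torch.log(t_hat / 16.0) + 1.2) / 1.5
assert torch.allclose(t, t_back)         # round trip recovers t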