rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
import biotite.structure as struct
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
from atomworks.ml.preprocessing.constants import ChainType
|
|
5
|
+
from atomworks.ml.transforms._checks import check_contains_keys
|
|
6
|
+
from atomworks.ml.transforms.base import Transform
|
|
7
|
+
from rfd3.transforms.conditioning_base import (
|
|
8
|
+
convert_existing_annotations_to_bool,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
# (L residue name, mirror-image residue name) pairs for the 20 standard amino
# acids. Glycine is achiral, so its mirror image is itself.
_L_TO_D_PAIRS = (
    ("ALA", "DAL"),
    ("SER", "DSN"),
    ("CYS", "DCY"),
    ("PRO", "DPR"),
    ("VAL", "DVA"),
    ("THR", "DTH"),
    ("LEU", "DLE"),
    ("ILE", "DIL"),
    ("ASN", "DSG"),
    ("ASP", "DAS"),
    ("MET", "MED"),
    ("GLN", "DGN"),
    ("GLU", "DGL"),
    ("LYS", "DLY"),
    ("HIS", "DHI"),
    ("PHE", "DPN"),
    ("ARG", "DAR"),
    ("TYR", "DTY"),
    ("TRP", "DTR"),
    ("GLY", "GLY"),
)

# L residue name -> mirror-image (D) residue name, including GLY -> GLY.
MIRROR_IMAGE_MAPPING = dict(_L_TO_D_PAIRS)

# Inverse lookup D -> L. Glycine is left out because it maps to itself.
D_TO_L_MAPPING = {
    d_name: l_name for l_name, d_name in _L_TO_D_PAIRS if l_name != "GLY"
}

# Lookup in both directions: L -> D entries first, then D -> L entries.
TWO_WAY_MIRROR_IMAGE_MAPPING = dict(MIRROR_IMAGE_MAPPING)
TWO_WAY_MIRROR_IMAGE_MAPPING.update(D_TO_L_MAPPING)

# Names of the D-amino acids (glycine excluded), in mapping order.
D_AA = list(D_TO_L_MAPPING)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class RandomlyMirrorInputs(Transform):
    """
    Reflect an example through the XY plane (z -> -z) when its conditions request it.

    Whether to mirror is decided upstream and read from
    ``data["conditions"]["mirror_input"]`` (default ``False``). This transform
    draws no random numbers itself. Examples that contain any DNA, RNA or
    DNA/RNA-hybrid chain are returned unchanged: nucleic acids are never
    mirrored, so neither is anything else in such an example.

    When mirroring is applied (the atom array and ``data`` are modified in place):
      * residues whose name is a key of ``TWO_WAY_MIRROR_IMAGE_MAPPING`` are
        renamed to their mirror image (L <-> D amino acid; GLY stays GLY);
      * any other residue with at least 3 atoms is renamed to a placeholder
        ``"L:<k>"``, with one placeholder per distinct original name.
        Presumably this stops downstream code from treating the mirrored
        component as the original stereoisomer; confirm.
      * z is negated in ``atom_array.coord``, in
        ``data["coord_atom_lvl_to_be_noised"]`` (if present), and in
        ``data["ground_truth"]["coord_atom_lvl"]`` (if present).

    NOTE(review): an existing ``is_d_amino_acid`` annotation is not updated
    here. If it is set before this transform runs, it will be stale after mirroring.
    """

    @staticmethod
    def _rename_mirrored_residues(atom_array) -> None:
        """Rename residues of ``atom_array`` in place to match their mirrored stereochemistry."""
        renamed_map = {}
        # With add_exclusive_stop=True the last entry is the array length,
        # so consecutive pairs give each residue's [start, stop) range.
        bounds = struct.get_residue_starts(atom_array, add_exclusive_stop=True)
        for r_start, r_stop in zip(bounds[:-1], bounds[1:]):
            if r_stop <= r_start:
                # Empty segment; only possible for an empty atom array.
                continue
            resname = atom_array.res_name[r_start]
            if resname in TWO_WAY_MIRROR_IMAGE_MAPPING:
                # Case 1: standard amino acid (L or D) -> its enantiomer.
                atom_array.res_name[r_start:r_stop] = TWO_WAY_MIRROR_IMAGE_MAPPING[
                    resname
                ]
            elif r_stop - r_start >= 3:
                # Case 2: non-standard residue or ligand with >= 3 atoms.
                # NOTE(review): 4 points is the minimum for a chiral arrangement,
                # so ">= 4" may have been intended; the threshold is kept at 3.
                if resname not in renamed_map:
                    renamed_map[resname] = "L:" + str(len(renamed_map))
                atom_array.res_name[r_start:r_stop] = renamed_map[resname]

    @staticmethod
    def _negate_z(xyz: torch.Tensor) -> torch.Tensor:
        """Return ``xyz`` with its last-axis z component negated (same dtype/device)."""
        return xyz * torch.tensor([1, 1, -1], dtype=xyz.dtype, device=xyz.device)

    def forward(self, data: dict) -> dict:
        # This transform must not be used at inference time.
        assert not data.get("is_inference", False)
        mirror_input = data["conditions"].get("mirror_input", False)
        atom_array = data["atom_array"]

        # Never mirror examples that contain nucleic acids.
        nucleic_chain_types = (
            ChainType.DNA,
            ChainType.RNA,
            ChainType.DNA_RNA_HYBRID,
        )
        if any((atom_array.chain_type == t).any() for t in nucleic_chain_types):
            return data

        if not mirror_input:
            return data

        self._rename_mirrored_residues(atom_array)

        # Flip coordinates about Z (reflection through the XY plane).
        atom_array.coord = atom_array.coord * np.array([1, 1, -1.0])

        xyz = data.get("coord_atom_lvl_to_be_noised", None)
        if xyz is not None:
            data["coord_atom_lvl_to_be_noised"] = self._negate_z(xyz)

        ground_truth_coord = (
            data["ground_truth"].get("coord_atom_lvl", None)
            if "ground_truth" in data
            else None
        )
        if ground_truth_coord is not None:
            data["ground_truth"]["coord_atom_lvl"] = self._negate_z(ground_truth_coord)

        return data
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class AddIsDAminoAcidFeat(Transform):
    """
    Add a per-atom boolean ``is_d_amino_acid`` annotation and feature.

    If ``data["atom_array"]`` has no ``is_d_amino_acid`` annotation, one is
    created:
      * atoms whose ``res_name`` is in ``D_AA`` are marked True;
      * glycine is achiral, so each glycine residue is independently marked
        True with probability 0.5 (using NumPy's global RNG). One draw is made
        per residue, so all atoms of a residue get the same value.

    The annotation is then copied into ``data["feats"]["is_d_amino_acid"]``
    unless that feature is already present. An existing annotation or feature
    is left untouched.
    """

    def check_input(self, data) -> None:
        check_contains_keys(data, ["atom_array", "feats"])

    def forward(self, data: dict) -> dict:
        atom_array = data["atom_array"]

        if "is_d_amino_acid" not in atom_array.get_annotation_categories():
            # Residues explicitly named as D-amino acids.
            is_d_aa = np.isin(atom_array.res_name, D_AA)

            # Half the time, mark a glycine residue as D-glycine. The coin is
            # flipped per residue (not per atom) and broadcast to that residue's
            # atoms, so no glycine ends up with mixed labels.
            bounds = struct.get_residue_starts(atom_array, add_exclusive_stop=True)
            residue_lengths = np.diff(bounds)
            residue_coin = np.random.rand(len(residue_lengths)) < 0.5
            atom_coin = np.repeat(residue_coin, residue_lengths)
            glycines = atom_array.res_name == "GLY"
            is_d_aa = np.logical_or(is_d_aa, np.logical_and(glycines, atom_coin))

            atom_array.set_annotation("is_d_amino_acid", is_d_aa)

        # Expose the annotation as a model feature if not already provided.
        if "is_d_amino_acid" not in data["feats"]:
            data["feats"]["is_d_amino_acid"] = atom_array.get_annotation(
                "is_d_amino_acid"
            )

        data["atom_array"] = atom_array

        return data
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class StrtoBoolforIsDAminoAcidFeature(Transform):
    """
    Normalize the ``is_d_amino_acid`` annotation of ``data["atom_array"]`` to bool.

    Delegates to ``convert_existing_annotations_to_bool``. That helper's return
    value is ignored, so it presumably converts the annotation in place
    (confirm). Judging by the class name, the annotation may arrive as strings;
    this is not verified here.
    """

    def forward(self, data):
        arr = data["atom_array"]
        target_annotations = ["is_d_amino_acid"]
        convert_existing_annotations_to_bool(arr, annotations=target_annotations)
        data["atom_array"] = arr
        return data
|