rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from rfd3.model.layers.layer_utils import (
|
|
3
|
+
MultiDimLinear,
|
|
4
|
+
RMSNorm,
|
|
5
|
+
Transition,
|
|
6
|
+
linearNoBias,
|
|
7
|
+
)
|
|
8
|
+
from torch import nn
|
|
9
|
+
|
|
10
|
+
from foundry.training.checkpoint import activation_checkpointing
|
|
11
|
+
from foundry.utils.torch import device_of
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AttentionPairBiasPairformerDeepspeed(nn.Module):
    """Gated multi-head self-attention over sequence features with a pair-derived bias.

    Attention logits are ``(q / sqrt(c)) . k + b``. The per-head bias ``b`` is
    linearly projected from the RMS-normalized pair features ``Z_II``, plus an
    optional additive ``Beta_II``. The attended values are multiplied by a
    per-head sigmoid gate, flattened over heads and projected back to ``c_a``
    channels.

    Args:
        c_a: Width of the sequence features. The per-head width is
            ``c_a // n_head``. If ``c_a`` is not divisible by ``n_head``, the
            flattened heads will not match the ``c_a``-wide output projection.
        c_s: Not used by this module.
        c_pair: Width of the pair features ``Z_II``.
        n_head: Number of attention heads.
        kq_norm: Passed as ``norm=`` to the key and value projections.
            NOTE(review): the query projection does not receive it, while the
            value projection does. Confirm this is intended; changing it may
            alter the module's parameters.
    """

    def __init__(self, c_a, c_s, c_pair, n_head, kq_norm=False):
        super().__init__()
        self.n_head = n_head
        self.c_a = c_a
        self.c_pair = c_pair
        self.c = c_a // n_head  # per-head channel width

        self.to_q = MultiDimLinear(c_a, (n_head, self.c))
        self.to_k = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=kq_norm)
        self.to_v = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=kq_norm)
        self.to_b = linearNoBias(c_pair, n_head)  # pair features -> per-head bias
        self.to_g = nn.Sequential(
            MultiDimLinear(c_a, (n_head, self.c), bias=False),
            nn.Sigmoid(),
        )  # per-head output gate in (0, 1)
        self.to_a = linearNoBias(c_a, c_a)  # output projection
        self.ln_0 = RMSNorm((c_pair,))  # applied to Z_II
        self.ln_1 = RMSNorm((c_a,))  # applied to A_I

        # The DeepSpeed Evoformer path is switched off and has no implementation
        # below for L > 24. By default the inputs are still cast to bfloat16.
        self.use_deepspeed_evo = False
        self.force_bfloat16 = True

    def forward(
        self,
        A_I,  # [I, C_a]
        S_I,  # [I, C_a] | None
        Z_II,  # [I, I, C_z]
        Beta_II=None,  # [I, I]
    ):
        """Apply pair-biased, gated self-attention to ``A_I``.

        Args:
            A_I: Sequence features (last dim ``c_a``).
            S_I: Must be ``None``; this variant takes no conditioning signal.
            Z_II: Pair features (last dim ``c_pair``), turned into a per-head bias.
            Beta_II: Optional additive bias. A trailing head axis is appended so
                it broadcasts over heads. If ``None``, no extra bias is added.

        Returns:
            Attention output projected to ``c_a`` channels.

        Raises:
            AssertionError: If ``S_I`` is not ``None``.
            NotImplementedError: If ``use_deepspeed_evo`` is set and the second
                dimension of the pair bias exceeds 24.
        """
        assert S_I is None
        A_I = self.ln_1(A_I)

        if self.use_deepspeed_evo or self.force_bfloat16:
            A_I = A_I.to(torch.bfloat16)

        Q_IH = self.to_q(A_I)
        K_IH = self.to_k(A_I)
        V_IH = self.to_v(A_I)
        B_IIH = self.to_b(self.ln_0(Z_II))
        # Beta_II is optional: the previous code subscripted it unconditionally,
        # so the documented default of None raised a TypeError.
        if Beta_II is not None:
            B_IIH = B_IIH + Beta_II[..., None]
        G_IH = self.to_g(A_I)

        _, L = B_IIH.shape[:2]

        if self.use_deepspeed_evo and L > 24:
            raise NotImplementedError(
                "DeepSpeed Evoformer attention path is not implemented "
                f"(use_deepspeed_evo=True with L={L} > 24)"
            )

        # Scale computed in bfloat16 exactly as before, to preserve existing numerics.
        Q_IH = Q_IH / torch.sqrt(
            torch.tensor(self.c).to(Q_IH.device, torch.bfloat16)
        )
        logits_IIH = torch.einsum("...ihd,...jhd->...ijh", Q_IH, K_IH) + B_IIH
        A_IIH = torch.softmax(logits_IIH, dim=-2)  # normalize over keys j
        # G_IH: [I, H, C], A_IIH: [I, I, H], V_IH: [I, H, C]
        A_I = torch.einsum("...ijh,...jhc->...ihc", A_IIH, V_IH)
        A_I = G_IH * A_I  # [B, I, H, C]
        A_I = A_I.flatten(start_dim=-2)  # [B, I, Ca]

        return self.to_a(A_I)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class PairformerBlock(nn.Module):
    """One pairformer layer.

    The pair track always gets a residual transition. When ``c_s > 0``, the
    block also builds a sequence track: residual pair-biased attention followed
    by a residual transition. (The original note describes this as an attempt
    to replicate the AF3 architecture from scratch.)

    ``p_drop``, ``triangle_multiplication``, ``triangle_attention``,
    ``use_deepspeed_evo``, ``use_triangle_mult`` and ``use_triangle_attn`` are
    accepted for config compatibility but are not used by this class.
    ``attention_pair_bias`` is a mapping of extra keyword arguments forwarded to
    ``AttentionPairBiasPairformerDeepspeed``.
    """

    def __init__(
        self,
        c_s,
        c_z,
        attention_pair_bias,
        p_drop=0.1,
        triangle_multiplication=None,
        triangle_attention=None,
        n_transition=4,
        use_deepspeed_evo=True,
        use_triangle_mult=False,
        use_triangle_attn=False,
    ):
        super().__init__()

        self.z_transition = Transition(c=c_z, n=n_transition)

        # Without a sequence track there is nothing more to build.
        if not c_s > 0:
            return

        self.s_transition = Transition(c=c_s, n=n_transition)
        self.attention_pair_bias = AttentionPairBiasPairformerDeepspeed(
            c_a=c_s, c_s=0, c_pair=c_z, **attention_pair_bias
        )

    @activation_checkpointing
    def forward(self, S_I, Z_II):
        """Update ``Z_II`` and, if given, ``S_I`` under bfloat16 autocast.

        Returns:
            ``(S_I, Z_II)``. ``S_I`` is returned unchanged (``None``) when it
            was not provided.
        """
        bf16_ctx = torch.amp.autocast(
            device_type=device_of(self).type, enabled=True, dtype=torch.bfloat16
        )
        with bf16_ctx:
            Z_II = Z_II + self.z_transition(Z_II)
            if S_I is None:
                return S_I, Z_II
            zero_bias = torch.tensor([0.0], device=Z_II.device)
            attn_update = self.attention_pair_bias(S_I, None, Z_II, Beta_II=zero_bias)
            S_I = S_I + attn_update
            S_I = S_I + self.s_transition(S_I)
        return S_I, Z_II
|
rfd3/run_inference.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import hydra
|
|
6
|
+
import rootutils
|
|
7
|
+
from dotenv import load_dotenv
|
|
8
|
+
from omegaconf import DictConfig, OmegaConf
|
|
9
|
+
|
|
10
|
+
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
|
|
11
|
+
|
|
12
|
+
# Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Load a `.env` file; with override=True its values replace variables already in the environment.
load_dotenv(override=True)

# Build the Hydra config directory from `PROJECT_ROOT`.
# NOTE(review): `PROJECT_PATH` is NOT consulted here (an earlier comment claimed it was).
# Because `load_dotenv(override=True)` runs after `setup_root`, a `PROJECT_ROOT` entry in
# `.env` would take precedence over the value rootutils set.
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@hydra.main(
    config_path=_config_path,
    config_name="inference",
    version_base="1.3",
)
def run_inference(cfg: DictConfig) -> None:
    """Construct an RFD3 inference engine from the Hydra config and run it.

    Top-level keys ``inputs``, ``n_batches`` and ``out_dir`` become keyword
    arguments to ``RFD3InferenceEngine.run``. Every other top-level key, except
    ``_target_``, goes into ``RFD3InferenceConfig``, which then constructs the
    engine.
    """
    run_keys = {"inputs", "n_batches", "out_dir"}

    # NOTE(review): run-time arguments are read from the raw DictConfig, so
    # nested values may still be OmegaConf containers. The constructor arguments
    # below come from the fully resolved plain-Python container instead.
    run_params = {}
    for key, value in cfg.items():
        if key in run_keys:
            run_params[key] = value

    resolved = OmegaConf.to_container(cfg, resolve=True)
    skipped_keys = run_keys | {"_target_"}
    init_kwargs = {
        key: value for key, value in resolved.items() if key not in skipped_keys
    }

    engine_config = RFD3InferenceConfig(**init_kwargs)
    engine = RFD3InferenceEngine(**engine_config)
    engine.run(**run_params)


if __name__ == "__main__":
    # Hydra's decorator parses CLI overrides and supplies `cfg`.
    run_inference()
|
rfd3/testing/debug.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../scripts/shebang/modelhub_exec.sh" "$0" "$@"'
|
|
2
|
+
# JBs debugging file, please create your own and go crazy!
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
import time
|
|
7
|
+
|
|
8
|
+
import hydra
|
|
9
|
+
import ipdb # noqa: F401
|
|
10
|
+
import numpy as np
|
|
11
|
+
import rootutils
|
|
12
|
+
import torch
|
|
13
|
+
import tree
|
|
14
|
+
from atomworks.ml.utils.token import (
|
|
15
|
+
get_token_starts,
|
|
16
|
+
)
|
|
17
|
+
from rfd3.testing.testing_utils import (
|
|
18
|
+
TEST_CFG_TRAIN,
|
|
19
|
+
TEST_JSON_DATA,
|
|
20
|
+
build_pipelines,
|
|
21
|
+
instantiate_example,
|
|
22
|
+
load_train_or_val_cfg,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from foundry.utils.ddp import set_accelerator_based_on_availability
|
|
26
|
+
|
|
27
|
+
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Same as train.py
# Locate the project root (marked by `.project-root`) via rootutils, starting from
# `__file__ + "/../.."`, and make it importable.
rootutils.setup_root(__file__ + "/../..", indicator=".project-root", pythonpath=True)
# Config dir: `PROJECT_PATH` if set, else `PROJECT_ROOT`, else "../.." -- joined with "configs".
# NOTE(review): `_config_path` is only printed in this script; nothing visible here consumes it.
_config_path = os.path.join(
    os.environ.get("PROJECT_PATH", os.environ.get("PROJECT_ROOT", "../..")), "configs"
)
print(f"Config path: {_config_path}")
# NOTE(review): this reports PROJECT_ROOT even when PROJECT_PATH was used above.
print(f"Project root: {os.environ.get('PROJECT_ROOT', '../..')}")


# The debug example is built at import time from the bundled test JSON data.
is_inference = True
args = TEST_JSON_DATA["1qys-1-refactored"]
# NOTE(review): `input` shadows the builtin of the same name at module scope.
input = instantiate_example(args, is_inference=is_inference)


# Optional CLI override: if an argument is given, load the train/val config named by the
# text after the last "=" in argv[1] (e.g. `name=foo` -> "foo"); otherwise keep the
# imported TEST_CFG_TRAIN.
TEST_CFG_TRAIN = (
    load_train_or_val_cfg(name=sys.argv[1].split("=")[-1])
    if len(sys.argv) > 1
    else TEST_CFG_TRAIN
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def forward(example, trainer, model, is_inference=is_inference, device=None):
    """Run one forward pass of ``model`` on a pipeline-processed example.

    Args:
        example: Pipeline output. Must contain ``coord_atom_lvl_to_be_noised``;
            it is also passed to ``trainer._assemble_network_inputs``.
        trainer: Trainer providing ``_assemble_network_inputs``.
        model: Model to run.
        is_inference: If True, put the model in eval mode and run under
            ``torch.no_grad()``. Otherwise put it in train mode and use the
            ambient grad mode.
        device: Device to move inputs to. Defaults to ``"cuda:0"`` when CUDA is
            available, else ``"cpu"``. The old hard-coded ``"cuda:0"`` crashed
            on CPU-only hosts.

    Returns:
        The model's output.
    """
    import contextlib

    network_input = trainer._assemble_network_inputs(example)

    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    def _inmap(path, x):
        # Move tensor-like leaves to `device`, except the MSA stack, which is left as-is.
        if hasattr(x, "cpu") and path != ("f", "msa_stack"):
            return x.to(device)
        return x

    network_input = tree.map_structure_with_path(_inmap, network_input)

    if is_inference:
        model.eval()
    else:
        model.train()

    # Only the grad context differs between the two modes; the call is identical.
    grad_ctx = torch.no_grad() if is_inference else contextlib.nullcontext()
    with grad_ctx:
        network_output = model.forward(
            input=network_input,
            n_cycle=1,
            coord_atom_lvl_to_be_noised=example["coord_atom_lvl_to_be_noised"].to(
                device
            ),
        )
    return network_output
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _launch_fabric_with_port_fallback(trainer):
    """Launch ``trainer.fabric``; on failure, retry once on a random MASTER_PORT."""
    try:
        trainer.fabric.launch()
    except Exception as e:
        print(f"Error: {e}")
        print("Switching port")
        os.environ["MASTER_PORT"] = str(1024 + np.random.randint(64512))
        trainer.fabric.launch()


def prep_forward(cfg):
    """Instantiate the trainer from ``cfg``, launch Fabric and construct the model.

    Returns:
        ``(model, trainer)``, where ``model`` is ``trainer.state["model"]``.
    """
    trainer = hydra.utils.instantiate(
        cfg.trainer,
        loggers=None,
        callbacks=None,
        _convert_="partial",
        _recursive_=False,
    )
    set_accelerator_based_on_availability(cfg)
    trainer.initialize_or_update_trainer_state({"train_cfg": cfg})

    # NOTE(review): the accelerator selection above and these device/node overrides run
    # after the trainer was instantiated from cfg.trainer, so they cannot affect its
    # constructor arguments. Confirm intended.
    cfg.trainer.devices_per_node = 1
    cfg.trainer.num_nodes = 1

    _launch_fabric_with_port_fallback(trainer)

    trainer.construct_model()
    return trainer.state["model"], trainer
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def test_conditional_forward():
    """Push the module-level example through an unindexed-motif pipeline, then run the model.

    The "test-unindexed" config is modified before the pipelines are built:
    - the island condition's frequency is set to 1e10 (presumably so it is
      effectively always chosen);
    - ``p_unindex_motif_tokens`` is set to 1.0.

    The processed example is dumped to a debug CIF. The model is built from
    ``TEST_CFG_TRAIN`` and a single forward pass is run.
    """
    cfg = load_train_or_val_cfg("test-unindexed")
    cfg.datasets.global_transform_args.train_conditions.island.frequency = 1e10
    cfg.datasets.global_transform_args.train_conditions.island.p_unindex_motif_tokens = 1.0
    pipelines = build_pipelines(composed_config=cfg)

    start = time.time()
    example = pipelines[is_inference](input)
    example["example_id"] = "debug_example"
    elapsed = time.time() - start
    print(f"Time taken to process example: {elapsed}")

    # Token-level view kept around for interactive inspection.
    atom_array = example["atom_array"]
    token_array = atom_array[get_token_starts(atom_array)]  # noqa: F841

    from rfd3.testing.debug_utils import pipe_out_to_file

    pipe_out_to_file(example, save=True)

    print("Preparing model")
    model, trainer = prep_forward(TEST_CFG_TRAIN)
    if is_inference:
        model.eval()
        trainer.state["model"].eval()
    network_output = forward(example, trainer, model, is_inference=is_inference)  # noqa: F841


if __name__ == "__main__":
    test_conditional_forward()
    print("Finished main")
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from atomworks.common import sum_string_arrays
|
|
3
|
+
from atomworks.io.utils.io_utils import to_cif_file
|
|
4
|
+
from atomworks.ml.transforms.center_random_augmentation import CenterRandomAugmentation
|
|
5
|
+
from biotite.structure import AtomArrayStack
|
|
6
|
+
from rfd3.trainer.rfd3 import _reassign_unindexed_token_chains
|
|
7
|
+
from rfd3.transforms.design_transforms import (
|
|
8
|
+
MotifCenterRandomAugmentation,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def pipe_out_to_file(output, save=True):
    """Build a carbon-only ``AtomArrayStack`` of the noised coordinates of a pipeline output.

    Models are ordered by increasing ``output["t"]``. Each model's coordinates
    are ``coord_atom_lvl_to_be_noised`` plus the matching ``noise``, except the
    first (lowest-``t``) model, whose noise is zeroed. Assumes ``t`` and
    ``noise`` share the leading dimension of the coordinates; TODO confirm.

    Args:
        output: Pipeline output dict with ``atom_array``,
            ``coord_atom_lvl_to_be_noised``, ``t``, ``noise`` and
            ``feats["atom_to_token_map"]`` (the last is used as ``res_id``).
        save: If True, also write the stack to
            ``"<example_id>_debug_out.cif"`` in the working directory.

    Returns:
        The constructed ``AtomArrayStack``. It is returned in both modes;
        previously ``None`` was returned when ``save`` was True.
    """
    atom_array = output["atom_array"]
    xyz = output["coord_atom_lvl_to_be_noised"]
    n_models, n_atoms = xyz.shape[0], xyz.shape[1]

    order = np.argsort(output["t"].numpy())
    # Fancy indexing copies, so zeroing below does not touch output["noise"].
    eps = output["noise"].numpy()[order]
    eps[0] = eps[0] * 0

    x = AtomArrayStack(n_models, n_atoms)
    x.coord = xyz[order].numpy() + eps

    x.set_annotation("chain_id", ["A"] * n_atoms)
    x.set_annotation("atom_name", [f"C{i}" for i in range(x.shape[-1])])
    x.set_annotation("res_id", output["feats"]["atom_to_token_map"])
    x.set_annotation("element", ["C"] * x.shape[-1])
    x.set_annotation("res_name", [atom_array.res_name[i] for i in range(x.shape[-1])])

    if save:
        f = f"{output.get('example_id', 'example')}_debug_out.cif"
        to_cif_file(
            x,
            f,
            id="x",
        )
        print("Saved cif file to:", f)
    return x
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def save_pipe_out(atom_array):
    """Reassign unindexed-token chains on ``atom_array`` and write it to ``debug_out.cif``."""
    out_path = "debug_out.cif"
    reassigned = _reassign_unindexed_token_chains(atom_array)
    to_cif_file(reassigned, out_path, id="x")
    print("Saved cif file to:", out_path)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def to_debug_pipe(pipe):
    """Strip random-centering augmentations from ``pipe`` in place and return it.

    Removes every ``CenterRandomAugmentation`` and
    ``MotifCenterRandomAugmentation`` from ``pipe.transforms``, keeping the order
    of the remaining transforms.
    """
    augmentation_types = (CenterRandomAugmentation, MotifCenterRandomAugmentation)
    kept = []
    for transform in pipe.transforms:
        if isinstance(transform, augmentation_types):
            continue
        kept.append(transform)
    pipe.transforms = kept
    return pipe
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def save_debug_cif(atom_array, filepath, name="debug_out.cif"):
    """Write a copy of ``atom_array`` to CIF with chain IDs made unique per transformation.

    On a copy, each chain ID is rewritten as ``"<chain_id>-<transformation_id>"``.
    Atoms that would otherwise share identifiers across transformations can then
    be written out and inspected while debugging. The input array is not
    modified.

    Args:
        atom_array: Array with ``chain_id`` and ``transformation_id`` annotations.
        filepath: Output directory or filename prefix (``str`` or path-like).
            If it names an existing directory, ``name`` is joined onto it.
            Otherwise ``name`` is appended directly, as before. Previously an
            existing directory without a trailing separator produced
            ``"<dir><name>"``, and path-like objects raised ``TypeError``.
        name: File name (or suffix) to use.
    """
    import os

    dummy_array = atom_array.copy()
    dummy_array.chain_id = sum_string_arrays(
        dummy_array.chain_id, "-", dummy_array.transformation_id.astype(str)
    )

    filepath = os.fspath(filepath)
    if os.path.isdir(filepath):
        f = os.path.join(filepath, name)
    else:
        f = filepath + name
    to_cif_file(
        dummy_array,
        f,
    )
    print("Saved cif file to:", f)
|