rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/utils/recycling.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from atomworks.ml.utils.rng import create_rng_state_from_seeds, rng_state
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_recycle_schedule(
    max_cycle: int,
    n_epochs: int,
    n_train: int,
    world_size: int,
    seed: int = 42,
) -> torch.Tensor:
    """Generate a schedule for recycling iterations over multiple epochs.

    Used to ensure that each GPU has the same number of recycles within a given batch.

    Args:
        max_cycle (int): Maximum number of recycling iterations (n_recycle).
        n_epochs (int): Number of training epochs.
        n_train (int): The total number of training examples per epoch (across all GPUs).
        world_size (int): The number of distributed training processes.
        seed (int, optional): The seed for random number generation. Defaults to 42.

    Returns:
        torch.Tensor: A tensor containing the recycling schedule for each epoch,
            with dimensions `(n_epochs, ceil(n_train / world_size))`.

    References:
        AF-2 Supplement, Algorithm 31
    """
    # Number of examples each rank sees per epoch (last rank may be padded up)
    examples_per_rank = math.ceil(n_train / world_size)

    # Draw inside an RNG context manager so the global torch RNG state is untouched
    with rng_state(create_rng_state_from_seeds(torch_seed=seed)):
        per_epoch_schedules = [
            torch.randint(1, max_cycle + 1, (examples_per_rank,))
            for _ in range(n_epochs)
        ]

    return torch.stack(per_epoch_schedules, dim=0)
|
rf3/validate.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'

import logging
import os

import hydra
import rootutils
from dotenv import load_dotenv
from omegaconf import DictConfig

from foundry.utils.logging import suppress_warnings

# Load environment variables from a local `.env` file, overriding any already set
load_dotenv(override=True)

# Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Hydra config directory, resolved relative to PROJECT_ROOT set by rootutils above
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")

# Logger used by the process that spawns the distributed workers
_spawning_process_logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@hydra.main(config_path=_config_path, config_name="validate", version_base="1.3")
def validate(cfg: DictConfig) -> None:
    """Validate a model checkpoint against the configured validation datasets.

    Args:
        cfg: Composed Hydra configuration. Must supply a checkpoint via
            ``ckpt_path`` or ``ckpt_config.path``; may supply ``seed``,
            ``logger``, ``callbacks``, ``trainer``, ``datasets.val`` and
            ``dataloader.val`` sections.

    Raises:
        ValueError: If neither ``ckpt_path`` nor ``ckpt_config`` provides a
            checkpoint to validate.
    """
    # ==============================================================================
    # Import dependencies and resolve Hydra configuration
    # ==============================================================================

    _spawning_process_logger.info("Importing dependencies...")

    # Lazy imports to make config generation fast
    import torch
    from lightning.fabric import seed_everything
    from lightning.fabric.loggers import Logger

    # If training on DIGS L40, set precision of matrix multiplication to balance speed and accuracy
    # Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
    torch.set_float32_matmul_precision("medium")

    from foundry.callbacks.callback import BaseCallback  # noqa
    from foundry.utils.instantiators import instantiate_loggers, instantiate_callbacks  # noqa
    from foundry.utils.logging import print_config_tree  # noqa
    from foundry.utils.ddp import RankedLogger, set_accelerator_based_on_availability  # noqa
    from foundry.utils.ddp import is_rank_zero  # noqa
    from foundry.utils.datasets import assemble_val_loader_dict  # noqa

    set_accelerator_based_on_availability(cfg)

    ranked_logger = RankedLogger(__name__, rank_zero_only=True)
    _spawning_process_logger.info("Completed dependency imports ...")

    # ... print the configuration tree (NOTE: Only prints for rank 0)
    print_config_tree(cfg, resolve=True)

    # ==============================================================================
    # Logging and Callback instantiation
    # ==============================================================================

    # Reduce the logging level for all dataset and sampler loggers (unless rank 0)
    # We will still see messages from Rank 0; they are identical, since all ranks load and sample from the same datasets
    if not is_rank_zero():
        dataset_logger = logging.getLogger("datasets")
        sampler_logger = logging.getLogger("atomworks.ml.samplers")
        dataset_logger.setLevel(logging.WARNING)
        sampler_logger.setLevel(logging.ERROR)

    # ... seed everything (NOTE: By setting `workers=True`, we ensure that the dataloaders are seeded as well)
    # (`PL_GLOBAL_SEED` environment variable will be passed to the spawned subprocesses; e.g., through `ddp_spawn` backend)
    if cfg.get("seed"):
        ranked_logger.info(f"Seeding everything with seed={cfg.seed}...")
        seed_everything(cfg.seed, workers=True, verbose=True)
    else:
        ranked_logger.warning("No seed provided - Not seeding anything!")

    ranked_logger.info("Instantiating loggers...")
    loggers: list[Logger] = instantiate_loggers(cfg.get("logger"))

    ranked_logger.info("Instantiating callbacks...")
    callbacks: list[BaseCallback] = instantiate_callbacks(cfg.get("callbacks"))

    # ==============================================================================
    # Trainer and model instantiation
    # ==============================================================================

    # ... instantiate the trainer
    trainer = hydra.utils.instantiate(
        cfg.trainer,
        loggers=loggers or None,
        callbacks=callbacks or None,
        _convert_="partial",
        _recursive_=False,
    )
    # (Store the Hydra configuration in the trainer state)
    trainer.initialize_or_update_trainer_state({"train_cfg": cfg})

    # ... spawn processes for distributed training
    # (We spawn here, rather than within `fit`, so we can use Fabric's `init_module` to efficiently initialize the model on the appropriate device)
    ranked_logger.info(
        f"Spawning {trainer.fabric.world_size} processes from {trainer.fabric.global_rank}..."
    )
    trainer.fabric.launch()

    # ... construct the model
    trainer.construct_model()

    # ==============================================================================
    # Dataset instantiation
    # ==============================================================================

    # Compose the validation loader(s)
    val_loaders = assemble_val_loader_dict(
        cfg=cfg.datasets.val,
        rank=trainer.fabric.global_rank,
        world_size=trainer.fabric.world_size,
        loader_cfg=cfg.dataloader["val"],
    )

    # ... load the checkpoint configuration, regardless of whether it's a path or a config
    if "ckpt_path" in cfg and cfg.ckpt_path:
        ckpt_path = cfg.ckpt_path
    elif "ckpt_config" in cfg and cfg.ckpt_config:
        assert (
            "path" in cfg.ckpt_config
        ), "No checkpoint path provided in `ckpt_config`!"
        ckpt_path = cfg.ckpt_config.path
    else:
        # BUG FIX: `ckpt_path` was previously left unbound on this path,
        # producing an opaque NameError at `trainer.validate` below.
        # Fail fast with a clear error instead.
        raise ValueError(
            "No checkpoint provided: set `ckpt_path` or `ckpt_config.path` in the config."
        )

    # ... validate the model
    ranked_logger.info("Validating model...")
    with suppress_warnings():
        trainer.validate(
            val_loaders=val_loaders,
            ckpt_path=ckpt_path,
        )

    ranked_logger.info("Validation complete!")
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
if __name__ == "__main__":
    # Hydra entry point: parses CLI overrides, composes the config, and runs validation
    validate()
|
rfd3/.gitignore
ADDED
rfd3/Makefile
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
.PHONY: clean format

#################################################################################
# COMMANDS                                                                      #
#################################################################################
# NOTE: the `## ` doc-comment above each target is harvested by the `help`
# target below; keep exactly one such line per documented target.

## Delete all compiled Python files
clean:
	find . -type f -name "*.py[co]" -delete
	find . -type d -name "__pycache__" -delete

## Format src directory using ruff
format:
	ruff format .
	ruff check --fix .
|
|
16
|
+
|
|
17
|
+
#################################################################################
# Self Documenting Commands                                                     #
#################################################################################

.DEFAULT_GOAL := help

# `help` renders every `## ` doc-comment in this Makefile as a colored,
# word-wrapped table; it is the default goal so a bare `make` shows usage.
# Inspired by <http://marmelab.com/blog/2016/02/29/auto-documented-makefile.html>
# sed script explained:
# /^##/:
# 	* save line in hold space
# 	* purge line
# 	* Loop:
# 		* append newline + line to hold space
# 		* go to next line
# 	* if line starts with doc comment, strip comment character off and loop
# 	* remove target prerequisites
# 	* append hold space (+ newline) to line
# 	* replace newline plus comments by `---`
# 	* print line
# Separate expressions are necessary because labels cannot be delimited by
# semicolon; see <http://stackoverflow.com/a/11799865/1968>
.PHONY: help
help:
	@echo "$$(tput bold)Available rules:$$(tput sgr0)"
	@echo
	@sed -n -e "/^## / { \
		h; \
		s/.*//; \
		:doc" \
		-e "H; \
		n; \
		s/^## //; \
		t doc" \
		-e "s/:.*//; \
		G; \
		s/\\n## /---/; \
		s/\\n/ /g; \
		p; \
	}" ${MAKEFILE_LIST} \
	| LC_ALL='C' sort --ignore-case \
	| awk -F '---' \
		-v ncol=$$(tput cols) \
		-v indent=19 \
		-v col_on="$$(tput setaf 6)" \
		-v col_off="$$(tput sgr0)" \
	'{ \
		printf "%s%*s%s ", col_on, -indent, $$1, col_off; \
		n = split($$2, words, " "); \
		line_length = ncol - indent; \
		for (i = 1; i <= n; i++) { \
			line_length -= length(words[i]) + 1; \
			if (line_length <= 0) { \
				line_length = ncol - indent - length(words[i]) - 1; \
				printf "\n%*s ", -indent, " "; \
			} \
			printf "%s ", words[i]; \
		} \
		printf "\n"; \
	}' \
	| more $(shell test $(shell uname) = Darwin && echo '--no-init --raw-control-chars')
|
rfd3/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""RFD3 - RosettaFold-diffusion model implementation."""
|
|
2
|
+
|
|
3
|
+
import pydantic
|
|
4
|
+
from packaging.version import Version
|
|
5
|
+
|
|
6
|
+
if Version(pydantic.__version__) < Version("2.0"):
|
|
7
|
+
raise RuntimeError(
|
|
8
|
+
f"Pydantic >=2.0 is required; found {pydantic.__version__}. "
|
|
9
|
+
"Pin pydantic>=2,<3 and upgrade dependent packages."
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
__version__ = "0.1.0"
|
rfd3/callbacks.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from beartype.typing import Any
|
|
3
|
+
|
|
4
|
+
from foundry.callbacks.callback import BaseCallback
|
|
5
|
+
from foundry.utils.ddp import RankedLogger
|
|
6
|
+
from foundry.utils.logging import print_df_as_table
|
|
7
|
+
|
|
8
|
+
# Module-level logger that only emits from the rank-zero process
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LogDesignValidationMetricsCallback(BaseCallback):
    """Print and log per-dataset design validation metrics at epoch end.

    Reads the metrics CSV referenced by ``trainer.validation_results_path``
    (written earlier by ``StoreValidationMetricsInDFCallback``), restricts it to
    the most recent epoch, averages the numeric columns per dataset, prints the
    result as a table, and logs the means through Fabric.
    """

    def on_validation_epoch_end(self, trainer: Any):
        """Aggregate and log validation metrics; runs on the global-zero rank only.

        Args:
            trainer: The trainer object. Must expose ``fabric``, ``state`` and
                (via a prior callback) ``validation_results_path``.
        """
        # Only log metrics to disk if this is the global zero rank
        if not trainer.fabric.is_global_zero:
            return

        assert hasattr(
            trainer, "validation_results_path"
        ), "Results path not found! Ensure that StoreValidationMetricsInDFCallback is called first."
        df = pd.read_csv(trainer.validation_results_path)

        # ... filter to most recent epoch, drop epoch column
        df = df[df["epoch"] == df["epoch"].max()]
        df.drop(columns=["epoch"], inplace=True)

        for dataset in df["dataset"].unique():
            dataset_df = df[df["dataset"] == dataset].copy()
            dataset_df.drop(columns=["dataset"], inplace=True)

            # Visual separator between datasets in stdout
            print(f"\n+{' ' + dataset + ' ':-^150}+\n")

            remaining_cols = [
                col for col in dataset_df.columns if col not in ["example_id"]
            ]
            remaining_df = dataset_df[remaining_cols].copy()
            remaining_df = remaining_df.dropna(how="all")
            numeric_cols = remaining_df.select_dtypes(include="number").columns

            # Compute means and non-NaN counts for numeric columns
            final_means = remaining_df[numeric_cols].mean()
            non_nan_counts = remaining_df[numeric_cols].count()

            # Convert the Series to a DataFrame and add the count as a new column
            final_means_df = final_means.to_frame(name="mean")
            final_means_df["Count"] = non_nan_counts

            print_df_as_table(
                final_means_df.reset_index(),
                f"{dataset} — {trainer.state['current_epoch']} — Design Validation Metrics",
            )
            if trainer.fabric:
                trainer.fabric.log_dict(
                    {f"val/{dataset}/{col}": final_means[col] for col in numeric_cols},
                    step=trainer.state["current_epoch"],
                )

                # BUG FIX: this per-example logging previously ran OUTSIDE the
                # `if trainer.fabric:` guard above, so it would call
                # `trainer.fabric.log_dict` even when the first log path was
                # deliberately skipped for a missing/falsy fabric. Keep both
                # log paths under the same guard for consistency.
                # (Only log per-example metrics for small validation sets to
                # avoid flooding the logger backend.)
                if len(dataset_df["example_id"].unique()) <= 25:
                    for eid, df_ in dataset_df.groupby("example_id"):
                        df_ = df_[numeric_cols].mean()
                        trainer.fabric.log_dict(
                            {
                                f"val/{dataset}/{col}/{eid}": df_[col]
                                for col in numeric_cols
                            },
                            step=trainer.state["current_epoch"],
                        )
|
rfd3/cli.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import typer
|
|
4
|
+
from hydra import compose, initialize_config_dir
|
|
5
|
+
|
|
6
|
+
# Typer CLI application; subcommands are registered via @app.command below
app = typer.Typer()
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@app.command(
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
)
def design(ctx: typer.Context):
    """Run design using hydra config overrides and input files."""
    # Resolve the RFD3 configs directory relative to this file.
    # This file lives at models/rfd3/src/rfd3/cli.py; the configs live at
    # models/rfd3/configs/, so walk three parents up to models/rfd3/.
    configs_dir = str(Path(__file__).parent.parent.parent / "configs")

    # Gather every CLI token, discarding the subcommand names themselves
    overrides = [
        token
        for token in ctx.params.get("args", []) + ctx.args
        if token not in ["design", "fold"]
    ]

    # Fall back to the default inference engine when none was specified
    if not any(token.startswith("inference_engine=") for token in overrides):
        overrides.append("inference_engine=rfdiffusion3")

    with initialize_config_dir(config_dir=configs_dir, version_base="1.3"):
        cfg = compose(config_name="inference", overrides=overrides)

    # Lazy import to avoid loading heavy dependencies at CLI startup
    from foundry.utils.logging import suppress_warnings
    from rfd3.run_inference import run_inference

    with suppress_warnings(is_inference=True):
        run_inference(cfg)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
if __name__ == "__main__":
    # Allow running the CLI module directly (equivalent to the installed entry point)
    app()
|
rfd3/constants.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from foundry.constants import TIP_BY_RESTYPE
|
|
4
|
+
|
|
5
|
+
TIP_BY_RESTYPE  # NOTE(review): bare reference — presumably marks the import as "used" so it can be re-exported from this module; confirm intent
|
|
6
|
+
|
|
7
|
+
# Annot: default (diffused default)
# Conditioning annotations that every atom array must carry, mapped to the
# value assigned to plain (diffused) atoms when the annotation is absent.
REQUIRED_CONDITIONING_ANNOTATION_VALUES = {
    "is_motif_atom_with_fixed_seq": True,
    "is_motif_atom_with_fixed_coord": True,
    "is_motif_atom_unindexed": False,
    "is_motif_atom_unindexed_motif_breakpoint": False,
}
# Annotation names alone, and the inference-time superset (adds `src_component`)
REQUIRED_CONDITIONING_ANNOTATIONS = list(REQUIRED_CONDITIONING_ANNOTATION_VALUES.keys())
REQUIRED_INFERENCE_ANNOTATIONS = REQUIRED_CONDITIONING_ANNOTATIONS + ["src_component"]
"""Annotations assigned to every valid atom array"""

OPTIONAL_CONDITIONING_VALUES = {
    "is_atom_level_hotspot": 0,
    "is_helix_conditioning": 0,
    "is_sheet_conditioning": 0,
    "is_loop_conditioning": 0,
    "active_donor": 0,
    "active_acceptor": 0,
    "rasa_bin": 3,
    "ref_plddt": 0,
    "is_non_loopy": 0,
    "partial_t": np.nan,
    # kept for legacy reasons
    "is_motif_token": 1,
    "is_motif_atom": 1,
}
"""Optional conditioning annotations and their default values if not provided."""

# Merged mapping of required + optional annotation defaults (dict union)
CONDITIONING_VALUES = (
    REQUIRED_CONDITIONING_ANNOTATION_VALUES | OPTIONAL_CONDITIONING_VALUES
)
"""Annotations that must be present in the AtomArray at inference time."""

INFERENCE_ANNOTATIONS = REQUIRED_INFERENCE_ANNOTATIONS + list(
    OPTIONAL_CONDITIONING_VALUES.keys()
)
"""All annotations that might be desired at inference time. Determines what AtomArray annotations will be preserved."""
|
|
44
|
+
|
|
45
|
+
# Subset of conditioning annotations carried through to output files
SAVED_CONDITIONING_ANNOTATIONS = [
    # "is_motif_atom_with_fixed_coord",
    "is_motif_atom_with_fixed_seq",
]
"""Annotations for conditioning to save in output files"""
|
|
50
|
+
|
|
51
|
+
# fmt: off
# Lookup table: residue name -> fixed-width, 4-character heavy-atom names in
# CCD order, padded with None to a uniform width of 14 slots per residue.
ccd_ordering_atomchar = {
    'TRP': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2"), # trp
    'HIS': (" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None), # his
    'TYR': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None), # tyr
    'PHE': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None), # phe
    'ASN': (" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None), # asn
    'ASP': (" N "," CA "," C "," O "," CB "," CG "," OD1"," OD2", None, None, None, None, None, None), # asp
    'GLN': (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None), # gln
    'GLU': (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," OE2", None, None, None, None, None), # glu
    'CYS': (" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None), # cys
    'SER': (" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None), # ser
    'THR': (" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None), # thr
    'LEU': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None), # leu
    'VAL': (" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None), # val
    'ILE': (" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None), # ile
    'MET': (" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None), # met
    'LYS': (" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None), # lys
    'ARG': (" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2", None, None, None), # arg
    'PRO': (" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None), # pro
    'ALA': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
    'GLY': (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
    'UNK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk
    'MSK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask
}
"""Canonical ordering of amino acid atom names in the CCD."""
|
|
77
|
+
|
|
78
|
+
# Pairs of chemically equivalent atoms per residue (e.g. the two ring carbons
# of Tyr/Phe, the two carboxylate oxygens of Asp/Glu); names are fixed-width
# 4-character strings matching `ccd_ordering_atomchar` above.
symmetric_atomchar = {
    "TYR": [[" CE1", " CE2"], [" CD1", " CD2"]],
    "PHE": [[" CE1", " CE2"], [" CD1", " CD2"]],
    "ASP": [[" OD1", " OD2"]],
    "GLU": [[" OE1", " OE2"]],
    "LEU": [[" CD1", " CD2"]],
    "VAL": [[" CG1", " CG2"]],
}
"""Maps residues to their pairs of atom names corresponding to symmetric atoms."""
|
|
87
|
+
|
|
88
|
+
association_schemes = {
|
|
89
|
+
'atom14': {
|
|
90
|
+
# | Backbone atoms |sp2-L1|sp2-R1|sp2-L2|sp2-R2|sp2-CZ|O-/S-|beta-OH|sp3-CG|sp2-CG|
|
|
91
|
+
# 0 1 2 3 4 V0 V1 V2 V3 V4 V5 V6 V7 V8
|
|
92
|
+
# Aromatics
|
|
93
|
+
'TRP': (" N "," CA "," C "," O "," CB "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2"," CG "), # trp
|
|
94
|
+
'HIS': (" N "," CA "," C "," O "," CB "," ND1"," CD2"," CE1"," NE2", None, None, None, None," CG "), # his
|
|
95
|
+
'TYR': (" N "," CA "," C "," O "," CB "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None," CG "), # tyr*
|
|
96
|
+
'PHE': (" N "," CA "," C "," O "," CB "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None," CG "), # phe*
|
|
97
|
+
|
|
98
|
+
# Carboxylates & amines
|
|
99
|
+
'ASN': (" N "," CA "," C "," O "," CB ", None, None, None, None," ND2"," OD1", None, None," CG "), # asn
|
|
100
|
+
'ASP': (" N "," CA "," C "," O "," CB ", None, None," OD1"," OD2", None, None, None, None," CG "), # asp*
|
|
101
|
+
'GLN': (" N "," CA "," C "," O "," CB ", None, None, None, None," NE2"," OE1", None," CD "," CG "), # gln
|
|
102
|
+
'GLU': (" N "," CA "," C "," O "," CB ", None, None," OE2"," OE1", None, None, None," CD "," CG "), # glu*
|
|
103
|
+
|
|
104
|
+
# CB-OH and CB-SG
|
|
105
|
+
'CYS': (" N "," CA "," C "," O "," CB ", None, None, None, None, None," SG ", None, None, None), # cys
|
|
106
|
+
'SER': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None," OG ", None, None), # ser
|
|
107
|
+
'THR': (" N "," CA "," C "," O "," CB "," CG2", None, None, None, None, None," OG1", None, None), # thr
|
|
108
|
+
|
|
109
|
+
# Ile/Leu/Val have a common C backbone but different placements of branching C
|
|
110
|
+
'LEU': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None), # leu*
|
|
111
|
+
'VAL': (" N "," CA "," C "," O "," CB "," CG1", None, None," CG2", None, None, None, None, None), # val*
|
|
112
|
+
'ILE': (" N "," CA "," C "," O "," CB "," CG1"," CD1", None," CG2", None, None, None, None, None), # ile
|
|
113
|
+
|
|
114
|
+
# MET / LYS have a common C backbone but heteroatoms inbetween
|
|
115
|
+
'MET': (" N "," CA "," C "," O "," CB "," CG ", None," CE ", None, None," SD ", None, None, None), # met
|
|
116
|
+
'LYS': (" N "," CA "," C "," O "," CB "," CG "," CD "," CE ", None," NZ ", None, None, None, None), # lys
|
|
117
|
+
|
|
118
|
+
# Weird ones
|
|
119
|
+
'ARG': (" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," NH1"," CZ "," NH2", None, None, None), # arg*
|
|
120
|
+
'PRO': (" N "," CA "," C "," O "," CB "," CG ", None, None, None, None, None, None," CD ", None), # pro
|
|
121
|
+
|
|
122
|
+
# Other
|
|
123
|
+
'UNK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk
|
|
124
|
+
'ALA': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
|
|
125
|
+
'MSK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask
|
|
126
|
+
'GLY': (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
|
|
127
|
+
},
|
|
128
|
+
|
|
129
|
+
"permute_ambiguous_only": {
|
|
130
|
+
# "CYS": [6, 5,], # SER | Permute *CB and SG (*CB and OG) # CB = next virtual atom since otherwise things get messy
|
|
131
|
+
# "ASP": [8, 7], # [6, 5], # ASN | Permute CG and OD2 (CG and OD1)
|
|
132
|
+
# "GLU": [9, 8], # [7, 6], # GLN | Permute CD and OE2 (CD and OE1)
|
|
133
|
+
|
|
134
|
+
# Ambiguous, modified
|
|
135
|
+
'CYS': (" N "," CA "," C "," O "," CB ", None, " SG ", None, None, None, None, None, None, None), # cys
|
|
136
|
+
'ASP': (" N "," CA "," C "," O "," CB "," CG "," OD1", None, " OD2", None, None, None, None, None), # asp
|
|
137
|
+
'GLU': (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1", None, " OE2", None, None, None, None), # glu
|
|
138
|
+
|
|
139
|
+
# Ambiguous, unmodified
|
|
140
|
+
'SER': (" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None), # ser
|
|
141
|
+
'ASN': (" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None), # asn
|
|
142
|
+
'GLN': (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None), # gln
|
|
143
|
+
|
|
144
|
+
# Unambiguous
|
|
145
|
+
'TRP': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2"), # trp
|
|
146
|
+
'HIS': (" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None), # his
|
|
147
|
+
'TYR': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None), # tyr
|
|
148
|
+
'PHE': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None), # phe
|
|
149
|
+
'THR': (" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None), # thr
|
|
150
|
+
'LEU': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None), # leu
|
|
151
|
+
'VAL': (" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None), # val
|
|
152
|
+
'ILE': (" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None), # ile
|
|
153
|
+
'MET': (" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None), # met
|
|
154
|
+
'LYS': (" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None), # lys
|
|
155
|
+
'ARG': (" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2", None, None, None), # arg
|
|
156
|
+
'PRO': (" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None), # pro
|
|
157
|
+
'ALA': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
|
|
158
|
+
'GLY': (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
|
|
159
|
+
'UNK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk
|
|
160
|
+
'MSK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask
|
|
161
|
+
},
|
|
162
|
+
|
|
163
|
+
'ccd': ccd_ordering_atomchar,
|
|
164
|
+
}
|
|
165
|
+
# Derive the 'atom14-new' scheme from 'atom14', then override the residues whose
# slot assignments differ in the new scheme.
association_schemes['atom14-new'] = association_schemes['atom14'].copy()
association_schemes['atom14-new'].update({
    # Optional: Break TYR oxygen from GLN / ASN groups - not implemented for rfd3 since it might be useful for people to use
    # 'TYR': (" N "," CA "," C "," O "," CB "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None," OH "," CG "), # tyr*
    # Fixed carboxylate / amide groups:
    'GLN': (" N ", " CA ", " C ", " O ", " CB ", None, None, None, None, " NE2", " OE1", None, " CG ", " CD "),  # gln
    'GLU': (" N ", " CA ", " C ", " O ", " CB ", None, None, " OE2", " OE1", None, None, None, " CG ", " CD "),  # glu*
    # Break connection with carboxylates
    'HIS': (" N ", " CA ", " C ", " O ", " CB ", " ND1", " CD2", " CE1", None, None, None, " NE2", None, " CG "),  # his
})
# 'dense' reuses the ambiguous-permutation layout verbatim (shallow copy).
association_schemes['dense'] = dict(association_schemes['permute_ambiguous_only'])
+
|
|
177
|
+
# fmt: on
|
|
178
|
+
VIRTUAL_ATOM_ELEMENT_NAME = "VX"
"""The element name annotation that will be assigned to virtual atoms"""

# Number of virtual (placeholder) slots in the 14-atom representation:
# the first 5 slots hold the concrete atoms N, CA, C, O, CB.
_NUM_VIRTUAL_SLOTS = 14 - 5

ATOM14_ATOM_NAMES = np.array(
    ["N", "CA", "C", "O", "CB"] + ["V%d" % slot for slot in range(_NUM_VIRTUAL_SLOTS)]
)
"""Atom14 atom names (e.g. CA, V1)"""

ATOM14_ATOM_ELEMENTS = np.array(
    ["N", "C", "C", "O", "C"] + [VIRTUAL_ATOM_ELEMENT_NAME] * _NUM_VIRTUAL_SLOTS
)
"""Atom14 element names (e.g. C, VX)"""

ATOM14_ATOM_NAME_TO_ELEMENT = dict(zip(ATOM14_ATOM_NAMES, ATOM14_ATOM_ELEMENTS))
"""Mapping from atom14 atom names (e.g. CA, V1) to their corresponding element names (e.g. C, VX)"""
|
|
195
|
+
|
|
196
|
+
def strip_list(x):  # noqa
    """Return a copy of *x* with ``str.strip`` applied to every non-None entry.

    ``None`` entries (empty atom slots in an association scheme) are passed
    through unchanged.  Replaces the original lambda assignment (PEP 8 E731),
    whose comprehension variable also shadowed its own parameter ``x``.
    """
    return [(item.strip() if item is not None else None) for item in x]
|
|
197
|
+
|
|
198
|
+
# Same schemes but with the fixed-width padding removed from every atom name
# (e.g. " CA " -> "CA"); None placeholders are preserved.
association_schemes_stripped = {}
for _scheme_name, _scheme in association_schemes.items():
    association_schemes_stripped[_scheme_name] = {
        residue: strip_list(atom_names) for residue, atom_names in _scheme.items()
    }
|
|
202
|
+
# Polymer-entity classifications treated as protein chains
# (presumably mmCIF/PDB entity-poly types — TODO confirm against the caller).
SELECTION_PROTEIN = ["POLYPEPTIDE(D)", "POLYPEPTIDE(L)"]
# Everything else: nucleic acids, small molecules, and exotic polymer types.
SELECTION_NONPROTEIN = [
    "POLYDEOXYRIBONUCLEOTIDE",
    "POLYRIBONUCLEOTIDE",
    "PEPTIDE NUCLEIC ACID",
    "OTHER",
    "NON-POLYMER",
    "CYCLIC-PSEUDO-PEPTIDE",
    "MACROLIDE",
    "POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
]
|