rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/utils/recycling.py ADDED
@@ -0,0 +1,42 @@
1
+ import math
2
+
3
+ import torch
4
+ from atomworks.ml.utils.rng import create_rng_state_from_seeds, rng_state
5
+
6
+
7
def get_recycle_schedule(
    max_cycle: int,
    n_epochs: int,
    n_train: int,
    world_size: int,
    seed: int = 42,
) -> torch.Tensor:
    """Generate a schedule for recycling iterations over multiple epochs.

    Used to ensure that each GPU has the same number of recycles within a given batch.

    Args:
        max_cycle (int): Maximum number of recycling iterations (n_recycle).
        n_epochs (int): Number of training epochs.
        n_train (int): The total number of training examples per epoch (across all GPUs).
        world_size (int): The number of distributed training processes.
        seed (int, optional): The seed for random number generation. Defaults to 42.

    Returns:
        torch.Tensor: A tensor containing the recycling schedule for each epoch,
            with dimensions `(n_epochs, n_train // world_size)`.

    References:
        AF-2 Supplement, Algorithm 31
    """
    # Number of examples each rank processes per epoch (rounded up so that
    # every example is covered even when n_train is not divisible).
    examples_per_rank = math.ceil(n_train / world_size)

    # Sample inside a seeded RNG context so the global torch RNG state is
    # left untouched; one row of recycle counts per epoch.
    with rng_state(create_rng_state_from_seeds(torch_seed=seed)):
        per_epoch_schedules = [
            torch.randint(1, max_cycle + 1, (examples_per_rank,))
            for _ in range(n_epochs)
        ]

    return torch.stack(per_epoch_schedules, dim=0)
rf3/validate.py ADDED
@@ -0,0 +1,140 @@
1
+ #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'
2
+
3
+ import logging
4
+ import os
5
+
6
+ import hydra
7
+ import rootutils
8
+ from dotenv import load_dotenv
9
+ from omegaconf import DictConfig
10
+
11
+ from foundry.utils.logging import suppress_warnings
12
+
13
# Load environment variables from any `.env` file, letting it override existing values.
load_dotenv(override=True)

# Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Hydra config directory for RF3, resolved relative to the project root set above.
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")

# Logger for the spawning (pre-launch) process; per-rank logging is set up
# later inside `validate` once the heavy dependencies have been imported.
_spawning_process_logger = logging.getLogger(__name__)
22
+
23
+
24
@hydra.main(config_path=_config_path, config_name="validate", version_base="1.3")
def validate(cfg: DictConfig) -> None:
    """Validate a model checkpoint against the configured validation datasets.

    Heavy dependencies are imported lazily inside the function so that Hydra
    config generation (e.g. ``--help`` / tab completion) stays fast.

    Args:
        cfg: Resolved Hydra configuration (config name ``validate``).

    Raises:
        ValueError: If neither ``cfg.ckpt_path`` nor ``cfg.ckpt_config.path``
            provides a checkpoint to validate.
    """
    # ==============================================================================
    # Import dependencies and resolve Hydra configuration
    # ==============================================================================

    _spawning_process_logger.info("Importing dependencies...")

    # Lazy imports to make config generation fast
    import torch
    from lightning.fabric import seed_everything
    from lightning.fabric.loggers import Logger

    # If training on DIGS L40, set precision of matrix multiplication to balance speed and accuracy
    # Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
    torch.set_float32_matmul_precision("medium")

    from foundry.callbacks.callback import BaseCallback  # noqa
    from foundry.utils.instantiators import instantiate_loggers, instantiate_callbacks  # noqa
    from foundry.utils.logging import print_config_tree  # noqa
    from foundry.utils.ddp import RankedLogger, set_accelerator_based_on_availability  # noqa
    from foundry.utils.ddp import is_rank_zero  # noqa
    from foundry.utils.datasets import assemble_val_loader_dict  # noqa

    set_accelerator_based_on_availability(cfg)

    ranked_logger = RankedLogger(__name__, rank_zero_only=True)
    _spawning_process_logger.info("Completed dependency imports ...")

    # ... print the configuration tree (NOTE: Only prints for rank 0)
    print_config_tree(cfg, resolve=True)

    # ==============================================================================
    # Logging and Callback instantiation
    # ==============================================================================

    # Reduce the logging level for all dataset and sampler loggers (unless rank 0)
    # We will still see messages from Rank 0; they are identical, since all ranks load and sample from the same datasets
    if not is_rank_zero():
        dataset_logger = logging.getLogger("datasets")
        sampler_logger = logging.getLogger("atomworks.ml.samplers")
        dataset_logger.setLevel(logging.WARNING)
        sampler_logger.setLevel(logging.ERROR)

    # ... seed everything (NOTE: By setting `workers=True`, we ensure that the dataloaders are seeded as well)
    # (`PL_GLOBAL_SEED` environment variable will be passed to the spawned subprocesses; e.g., through `ddp_spawn` backend)
    if cfg.get("seed"):
        ranked_logger.info(f"Seeding everything with seed={cfg.seed}...")
        seed_everything(cfg.seed, workers=True, verbose=True)
    else:
        ranked_logger.warning("No seed provided - Not seeding anything!")

    ranked_logger.info("Instantiating loggers...")
    loggers: list[Logger] = instantiate_loggers(cfg.get("logger"))

    ranked_logger.info("Instantiating callbacks...")
    callbacks: list[BaseCallback] = instantiate_callbacks(cfg.get("callbacks"))

    # ==============================================================================
    # Trainer and model instantiation
    # ==============================================================================

    # ... instantiate the trainer
    trainer = hydra.utils.instantiate(
        cfg.trainer,
        loggers=loggers or None,
        callbacks=callbacks or None,
        _convert_="partial",
        _recursive_=False,
    )
    # (Store the Hydra configuration in the trainer state)
    trainer.initialize_or_update_trainer_state({"train_cfg": cfg})

    # ... spawn processes for distributed training
    # (We spawn here, rather than within `fit`, so we can use Fabric's `init_module` to efficiently initialize the model on the appropriate device)
    ranked_logger.info(
        f"Spawning {trainer.fabric.world_size} processes from {trainer.fabric.global_rank}..."
    )
    trainer.fabric.launch()

    # ... construct the model
    trainer.construct_model()

    # ==============================================================================
    # Dataset instantiation
    # ==============================================================================

    # Compose the validation loader(s)
    val_loaders = assemble_val_loader_dict(
        cfg=cfg.datasets.val,
        rank=trainer.fabric.global_rank,
        world_size=trainer.fabric.world_size,
        loader_cfg=cfg.dataloader["val"],
    )

    # ... load the checkpoint configuration, regardless of whether it's a path or a config
    if "ckpt_path" in cfg and cfg.ckpt_path:
        ckpt_path = cfg.ckpt_path
    elif "ckpt_config" in cfg and cfg.ckpt_config:
        assert (
            "path" in cfg.ckpt_config
        ), "No checkpoint path provided in `ckpt_config`!"
        ckpt_path = cfg.ckpt_config.path
    else:
        # BUGFIX: `ckpt_path` was previously left unbound on this branch, which
        # surfaced as an opaque NameError at the `trainer.validate` call below.
        # Fail fast with an actionable message instead.
        raise ValueError(
            "No checkpoint provided: set `ckpt_path` or `ckpt_config.path` in the config."
        )

    # ... validate the model
    ranked_logger.info("Validating model...")
    with suppress_warnings():
        trainer.validate(
            val_loaders=val_loaders,
            ckpt_path=ckpt_path,
        )

    ranked_logger.info("Validation complete!")


if __name__ == "__main__":
    validate()
rfd3/.gitignore ADDED
@@ -0,0 +1,7 @@
1
# Local test outputs, legacy configs/transforms, and benchmark artifacts
# that should never be committed.
tests/outs
tests/test_data/mcsa_41/
configs/datasets/val/data
configs/model/old
transforms/old
benchmarks
old
rfd3/Makefile ADDED
@@ -0,0 +1,76 @@
1
.PHONY: clean format

#################################################################################
# COMMANDS                                                                      #
#################################################################################

## Delete all compiled Python files
clean:
	find . -type f -name "*.py[co]" -delete
	find . -type d -name "__pycache__" -delete

## Format src directory using ruff
format:
	ruff format .
	ruff check --fix .

#################################################################################
# Self Documenting Commands                                                     #
#################################################################################

# Running bare `make` prints the help listing below.
.DEFAULT_GOAL := help

# `make help` collects every `## ` doc comment that precedes a target in this
# Makefile, sorts the entries, and word-wraps them to the terminal width.
# Inspired by <http://marmelab.com/blog/2016/02/29/auto-documented-makefile.html>
# sed script explained:
# /^##/:
# 	* save line in hold space
# 	* purge line
# 	* Loop:
# 		* append newline + line to hold space
# 		* go to next line
# 		* if line starts with doc comment, strip comment character off and loop
# 	* remove target prerequisites
# 	* append hold space (+ newline) to line
# 	* replace newline plus comments by `---`
# 	* print line
# Separate expressions are necessary because labels cannot be delimited by
# semicolon; see <http://stackoverflow.com/a/11799865/1968>
.PHONY: help
help:
	@echo "$$(tput bold)Available rules:$$(tput sgr0)"
	@echo
	@sed -n -e "/^## / { \
		h; \
		s/.*//; \
		:doc" \
		-e "H; \
		n; \
		s/^## //; \
		t doc" \
		-e "s/:.*//; \
		G; \
		s/\\n## /---/; \
		s/\\n/ /g; \
		p; \
	}" ${MAKEFILE_LIST} \
	| LC_ALL='C' sort --ignore-case \
	| awk -F '---' \
		-v ncol=$$(tput cols) \
		-v indent=19 \
		-v col_on="$$(tput setaf 6)" \
		-v col_off="$$(tput sgr0)" \
	'{ \
		printf "%s%*s%s ", col_on, -indent, $$1, col_off; \
		n = split($$2, words, " "); \
		line_length = ncol - indent; \
		for (i = 1; i <= n; i++) { \
			line_length -= length(words[i]) + 1; \
			if (line_length <= 0) { \
				line_length = ncol - indent - length(words[i]) - 1; \
				printf "\n%*s ", -indent, " "; \
			} \
			printf "%s ", words[i]; \
		} \
		printf "\n"; \
	}' \
	| more $(shell test $(shell uname) = Darwin && echo '--no-init --raw-control-chars')
rfd3/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ """RFD3 - RosettaFold-diffusion model implementation."""
2
+
3
+ import pydantic
4
+ from packaging.version import Version
5
+
6
# Guard against pydantic v1: the rest of the package relies on the v2 API.
_installed_pydantic = Version(pydantic.__version__)
_minimum_pydantic = Version("2.0")

if _installed_pydantic < _minimum_pydantic:
    raise RuntimeError(
        f"Pydantic >=2.0 is required; found {pydantic.__version__}. "
        "Pin pydantic>=2,<3 and upgrade dependent packages."
    )

__version__ = "0.1.0"
rfd3/callbacks.py ADDED
@@ -0,0 +1,66 @@
1
+ import pandas as pd
2
+ from beartype.typing import Any
3
+
4
+ from foundry.callbacks.callback import BaseCallback
5
+ from foundry.utils.ddp import RankedLogger
6
+ from foundry.utils.logging import print_df_as_table
7
+
8
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
9
+
10
+
11
class LogDesignValidationMetricsCallback(BaseCallback):
    """Pretty-print and log per-dataset design validation metrics.

    Reads the CSV written to ``trainer.validation_results_path`` (produced by
    StoreValidationMetricsInDFCallback), restricts it to the most recent
    epoch, and for each dataset prints a table of per-metric means and logs
    them via ``trainer.fabric``. Runs only on the global-zero rank.
    """

    def on_validation_epoch_end(self, trainer: Any):
        # Only log metrics to disk if this is the global zero rank
        if not trainer.fabric.is_global_zero:
            return

        assert hasattr(
            trainer, "validation_results_path"
        ), "Results path not found! Ensure that StoreValidationMetricsInDFCallback is called first."
        df = pd.read_csv(trainer.validation_results_path)

        # ... filter to most recent epoch, drop epoch column
        # BUGFIX: take an explicit `.copy()` so the in-place drop below mutates
        # a real frame, not a view of the filtered slice (this previously
        # triggered pandas' SettingWithCopyWarning / chained-assignment hazard).
        df = df[df["epoch"] == df["epoch"].max()].copy()
        df.drop(columns=["epoch"], inplace=True)

        for dataset in df["dataset"].unique():
            dataset_df = df[df["dataset"] == dataset].copy()
            dataset_df.drop(columns=["dataset"], inplace=True)

            print(f"\n+{' ' + dataset + ' ':-^150}+\n")

            # Aggregate over all columns except the example identifier
            remaining_cols = [
                col for col in dataset_df.columns if col not in ["example_id"]
            ]
            remaining_df = dataset_df[remaining_cols].copy()
            remaining_df = remaining_df.dropna(how="all")
            numeric_cols = remaining_df.select_dtypes(include="number").columns

            # Compute means and non-NaN counts for numeric columns
            final_means = remaining_df[numeric_cols].mean()
            non_nan_counts = remaining_df[numeric_cols].count()

            # Convert the Series to a DataFrame and add the count as a new column
            final_means_df = final_means.to_frame(name="mean")
            final_means_df["Count"] = non_nan_counts

            print_df_as_table(
                final_means_df.reset_index(),
                f"{dataset} — {trainer.state['current_epoch']} — Design Validation Metrics",
            )
            # (fabric is necessarily truthy here — the rank check above already
            # dereferenced it — so this guard is defensive only)
            if trainer.fabric:
                trainer.fabric.log_dict(
                    {f"val/{dataset}/{col}": final_means[col] for col in numeric_cols},
                    step=trainer.state["current_epoch"],
                )

            # For small validation sets, additionally log per-example means
            if len(dataset_df["example_id"].unique()) <= 25:
                for eid, df_ in dataset_df.groupby("example_id"):
                    df_ = df_[numeric_cols].mean()
                    trainer.fabric.log_dict(
                        {
                            f"val/{dataset}/{col}/{eid}": df_[col]
                            for col in numeric_cols
                        },
                        step=trainer.state["current_epoch"],
                    )
rfd3/cli.py ADDED
@@ -0,0 +1,41 @@
1
+ from pathlib import Path
2
+
3
+ import typer
4
+ from hydra import compose, initialize_config_dir
5
+
6
+ app = typer.Typer()
7
+
8
+
9
@app.command(
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
)
def design(ctx: typer.Context):
    """Run design using hydra config overrides and input files."""
    # Locate the RFD3 configs directory relative to this file:
    # this file lives at models/rfd3/src/rfd3/cli.py and the configs
    # live at models/rfd3/configs/, three levels up.
    package_root = Path(__file__).parent.parent.parent
    config_dir = str(package_root / "configs")

    # Collect every CLI token as a hydra override, dropping the subcommand words.
    overrides = [
        token
        for token in ctx.params.get("args", []) + ctx.args
        if token not in ["design", "fold"]
    ]

    # Fall back to the default inference engine when none was pinned explicitly.
    if not any(token.startswith("inference_engine=") for token in overrides):
        overrides.append("inference_engine=rfdiffusion3")

    with initialize_config_dir(config_dir=config_dir, version_base="1.3"):
        cfg = compose(config_name="inference", overrides=overrides)

    # Lazy import to avoid loading heavy dependencies at CLI startup
    from foundry.utils.logging import suppress_warnings
    from rfd3.run_inference import run_inference

    with suppress_warnings(is_inference=True):
        run_inference(cfg)


if __name__ == "__main__":
    app()
rfd3/constants.py ADDED
@@ -0,0 +1,212 @@
1
+ import numpy as np
2
+
3
+ from foundry.constants import TIP_BY_RESTYPE
4
+
5
+ TIP_BY_RESTYPE
6
+
7
# Annot: default (diffused default)
REQUIRED_CONDITIONING_ANNOTATION_VALUES = {
    "is_motif_atom_with_fixed_seq": True,
    "is_motif_atom_with_fixed_coord": True,
    "is_motif_atom_unindexed": False,
    "is_motif_atom_unindexed_motif_breakpoint": False,
}
REQUIRED_CONDITIONING_ANNOTATIONS = [*REQUIRED_CONDITIONING_ANNOTATION_VALUES]
REQUIRED_INFERENCE_ANNOTATIONS = [*REQUIRED_CONDITIONING_ANNOTATIONS, "src_component"]
"""Annotations assigned to every valid atom array"""

OPTIONAL_CONDITIONING_VALUES = {
    "is_atom_level_hotspot": 0,
    "is_helix_conditioning": 0,
    "is_sheet_conditioning": 0,
    "is_loop_conditioning": 0,
    "active_donor": 0,
    "active_acceptor": 0,
    "rasa_bin": 3,
    "ref_plddt": 0,
    "is_non_loopy": 0,
    "partial_t": np.nan,
    # kept for legacy reasons
    "is_motif_token": 1,
    "is_motif_atom": 1,
}
"""Optional conditioning annotations and their default values if not provided."""

CONDITIONING_VALUES = {
    **REQUIRED_CONDITIONING_ANNOTATION_VALUES,
    **OPTIONAL_CONDITIONING_VALUES,
}
"""Annotations that must be present in the AtomArray at inference time."""

INFERENCE_ANNOTATIONS = [
    *REQUIRED_INFERENCE_ANNOTATIONS,
    *OPTIONAL_CONDITIONING_VALUES,
]
"""All annotations that might be desired at inference time. Determines what AtomArray annotations will be preserved."""

SAVED_CONDITIONING_ANNOTATIONS = [
    # "is_motif_atom_with_fixed_coord",
    "is_motif_atom_with_fixed_seq",
]
"""Annotations for conditioning to save in output files"""
50
+
51
# fmt: off
# NOTE: Atom names below are stored in fixed-width (4-character, PDB-style)
# padded form; `None` marks an unused slot in the fixed-length tuple.
ccd_ordering_atomchar = {
    'TRP': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2"), # trp
    'HIS': (" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None), # his
    'TYR': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None), # tyr
    'PHE': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None), # phe
    'ASN': (" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None), # asn
    'ASP': (" N "," CA "," C "," O "," CB "," CG "," OD1"," OD2", None, None, None, None, None, None), # asp
    'GLN': (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None), # gln
    'GLU': (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," OE2", None, None, None, None, None), # glu
    'CYS': (" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None), # cys
    'SER': (" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None), # ser
    'THR': (" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None), # thr
    'LEU': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None), # leu
    'VAL': (" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None), # val
    'ILE': (" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None), # ile
    'MET': (" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None), # met
    'LYS': (" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None), # lys
    'ARG': (" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2", None, None, None), # arg
    'PRO': (" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None), # pro
    'ALA': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
    'GLY': (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
    'UNK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk
    'MSK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask
}
"""Canonical ordering of amino acid atom names in the CCD."""

symmetric_atomchar = {
    "TYR": [[" CE1", " CE2"], [" CD1", " CD2"]],
    "PHE": [[" CE1", " CE2"], [" CD1", " CD2"]],
    "ASP": [[" OD1", " OD2"]],
    "GLU": [[" OE1", " OE2"]],
    "LEU": [[" CD1", " CD2"]],
    "VAL": [[" CG1", " CG2"]],
}
"""Maps residues to their pairs of atom names corresponding to symmetric atoms."""

# Alternative slot assignments ("association schemes") mapping residues to
# fixed-length atom-name tuples; each scheme groups chemically analogous
# atoms into the same slot index across residues.
association_schemes = {
    'atom14': {
        # | Backbone atoms |sp2-L1|sp2-R1|sp2-L2|sp2-R2|sp2-CZ|O-/S-|beta-OH|sp3-CG|sp2-CG|
        #  0     1     2     3     4     V0    V1    V2    V3    V4    V5    V6    V7    V8
        # Aromatics
        'TRP': (" N "," CA "," C "," O "," CB "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2"," CG "), # trp
        'HIS': (" N "," CA "," C "," O "," CB "," ND1"," CD2"," CE1"," NE2", None, None, None, None," CG "), # his
        'TYR': (" N "," CA "," C "," O "," CB "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None," CG "), # tyr*
        'PHE': (" N "," CA "," C "," O "," CB "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None," CG "), # phe*

        # Carboxylates & amines
        'ASN': (" N "," CA "," C "," O "," CB ", None, None, None, None," ND2"," OD1", None, None," CG "), # asn
        'ASP': (" N "," CA "," C "," O "," CB ", None, None," OD1"," OD2", None, None, None, None," CG "), # asp*
        'GLN': (" N "," CA "," C "," O "," CB ", None, None, None, None," NE2"," OE1", None," CD "," CG "), # gln
        'GLU': (" N "," CA "," C "," O "," CB ", None, None," OE2"," OE1", None, None, None," CD "," CG "), # glu*

        # CB-OH and CB-SG
        'CYS': (" N "," CA "," C "," O "," CB ", None, None, None, None, None," SG ", None, None, None), # cys
        'SER': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None," OG ", None, None), # ser
        'THR': (" N "," CA "," C "," O "," CB "," CG2", None, None, None, None, None," OG1", None, None), # thr

        # Ile/Leu/Val have a common C backbone but different placements of branching C
        'LEU': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None), # leu*
        'VAL': (" N "," CA "," C "," O "," CB "," CG1", None, None," CG2", None, None, None, None, None), # val*
        'ILE': (" N "," CA "," C "," O "," CB "," CG1"," CD1", None," CG2", None, None, None, None, None), # ile

        # MET / LYS have a common C backbone but heteroatoms inbetween
        'MET': (" N "," CA "," C "," O "," CB "," CG ", None," CE ", None, None," SD ", None, None, None), # met
        'LYS': (" N "," CA "," C "," O "," CB "," CG "," CD "," CE ", None," NZ ", None, None, None, None), # lys

        # Weird ones
        'ARG': (" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," NH1"," CZ "," NH2", None, None, None), # arg*
        'PRO': (" N "," CA "," C "," O "," CB "," CG ", None, None, None, None, None, None," CD ", None), # pro

        # Other
        'UNK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk
        'ALA': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
        'MSK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask
        'GLY': (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
    },

    "permute_ambiguous_only": {
        # "CYS": [6, 5,], # SER | Permute *CB and SG (*CB and OG) # CB = next virtual atom since otherwise things get messy
        # "ASP": [8, 7], # [6, 5], # ASN | Permute CG and OD2 (CG and OD1)
        # "GLU": [9, 8], # [7, 6], # GLN | Permute CD and OE2 (CD and OE1)

        # Ambiguous, modified
        'CYS': (" N "," CA "," C "," O "," CB ", None, " SG ", None, None, None, None, None, None, None), # cys
        'ASP': (" N "," CA "," C "," O "," CB "," CG "," OD1", None, " OD2", None, None, None, None, None), # asp
        'GLU': (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1", None, " OE2", None, None, None, None), # glu

        # Ambiguous, unmodified
        'SER': (" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None), # ser
        'ASN': (" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None), # asn
        'GLN': (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None), # gln

        # Unambiguous
        'TRP': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2"), # trp
        'HIS': (" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None), # his
        'TYR': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None), # tyr
        'PHE': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None), # phe
        'THR': (" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None), # thr
        'LEU': (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None), # leu
        'VAL': (" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None), # val
        'ILE': (" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None), # ile
        'MET': (" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None), # met
        'LYS': (" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None), # lys
        'ARG': (" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2", None, None, None), # arg
        'PRO': (" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None), # pro
        'ALA': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
        'GLY': (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
        'UNK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk
        'MSK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask
    },

    'ccd': ccd_ordering_atomchar,
}
# 'atom14-new' starts from 'atom14' and overrides a handful of residues.
association_schemes['atom14-new'] = association_schemes['atom14'].copy()
association_schemes['atom14-new'] |= {
    # Optional: Break TYR oxygen from GLN / ASN groups - not implemented for rfd3 since it might be useful for people to use
    # 'TYR': (" N "," CA "," C "," O "," CB "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None," OH "," CG "), # tyr*
    # Fixed carboxylate / amide groups:
    'GLN': (" N "," CA "," C "," O "," CB ", None, None, None, None," NE2"," OE1", None," CG "," CD "), # gln
    'GLU': (" N "," CA "," C "," O "," CB ", None, None," OE2"," OE1", None, None, None," CG "," CD "), # glu*
    # Break connection with carboxylates
    'HIS': (" N "," CA "," C "," O "," CB "," ND1"," CD2"," CE1", None, None, None," NE2", None," CG "), # his
}
# 'dense' is currently an alias-copy of 'permute_ambiguous_only'.
association_schemes['dense'] = association_schemes['permute_ambiguous_only'].copy()

# fmt: on
178
VIRTUAL_ATOM_ELEMENT_NAME = "VX"
"""The element name annotation that will be assigned to virtual atoms"""

# Atom14 layout: 5 concrete slots (N, CA, C, O, CB) followed by 9 virtual slots.
ATOM14_ATOM_NAMES = np.array(
    ["N", "CA", "C", "O", "CB"] + [f"V{slot}" for slot in range(9)]
)
"""Atom14 atom names (e.g. CA, V1)"""

ATOM14_ATOM_ELEMENTS = np.array(
    ["N", "C", "C", "O", "C"] + [VIRTUAL_ATOM_ELEMENT_NAME] * 9
)
"""Atom14 element names (e.g. C, VX)"""

ATOM14_ATOM_NAME_TO_ELEMENT = dict(zip(ATOM14_ATOM_NAMES, ATOM14_ATOM_ELEMENTS))
"""Mapping from atom14 atom names (e.g. CA, V1) to their corresponding element names (e.g. C, VX)"""
195
+
196
def strip_list(names):
    """Strip the fixed-width padding from each atom name in ``names``.

    ``None`` placeholders are preserved as-is. Replaces the previous
    lambda assignment (ruff E731), whose comprehension also shadowed the
    argument with its own loop variable (``for x in x``).
    """
    return [(name.strip() if name is not None else None) for name in names]


# Same schemes as `association_schemes`, but with whitespace-free atom names.
association_schemes_stripped = {
    name: {k: strip_list(v) for k, v in scheme.items()}
    for name, scheme in association_schemes.items()
}
202
# Polymer-type labels treated as protein chains.
# NOTE(review): these look like mmCIF/CCD entity polymer-type strings — confirm
# against the code that reads them before extending either list.
SELECTION_PROTEIN = ["POLYPEPTIDE(D)", "POLYPEPTIDE(L)"]
# Every other supported polymer/entity type (nucleic acids, ligands,
# macrocycles, hybrids) — i.e. the non-protein complement of the above.
SELECTION_NONPROTEIN = [
    "POLYDEOXYRIBONUCLEOTIDE",
    "POLYRIBONUCLEOTIDE",
    "PEPTIDE NUCLEIC ACID",
    "OTHER",
    "NON-POLYMER",
    "CYCLIC-PSEUDO-PEPTIDE",
    "MACROLIDE",
    "POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
]