rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,77 @@
+ """Resolvers for Hydra configuration files.
+
+ Documentation on custom resolvers:
+ - https://omegaconf.readthedocs.io/en/latest/custom_resolvers.html
+ """
+
+ import importlib
+
+ from atomworks.enums import ChainType, ChainTypeInfo
+ from beartype.typing import Any
+ from omegaconf import OmegaConf
+
+ from ..common import run_once
+
+
+ # (Custom resolvers)
+ @run_once
+ def register_resolvers():
+     resolvers = {
+         "resolve_import": resolve_import,
+         "chain_type_info_to_regex": chain_type_info_to_regex,
+     }
+
+     for name, resolver in resolvers.items():
+         OmegaConf.register_new_resolver(name, resolver)
+
+
+ def resolve_import(module_path: str, attribute_path: str | None = None) -> Any:
+     """
+     Import a module and access a specific attribute from it.
+
+     Args:
+         module_path (str): The path to the module.
+         attribute_path (str, optional): Dotted path to the attribute within the module; if None, the module itself is returned.
+
+     Returns:
+         The imported attribute.
+     """
+     module = importlib.import_module(module_path)
+     if attribute_path is not None:
+         # Split the attribute path to navigate through nested attributes
+         attributes = attribute_path.split(".")
+         attr = module
+         for attr_name in attributes:
+             attr = getattr(attr, attr_name)
+         return attr
+     else:
+         return module
+
+
+ def chain_type_info_to_regex(*args) -> Any:
+     """Convert a combination of ChainType or ChainTypeInfo attributes to a regex string.
+
+     Primarily used for filtering a dataset by chain type prior to training/validation.
+
+     Example filter:
+     - "pn_unit_1_type.astype('str').str.match('${chain_type_info_to_regex:PROTEINS}')"
+
+     """
+     regex_str = ""
+
+     for arg in args:
+         if hasattr(ChainType, arg):
+             regex_str += f"{getattr(ChainType, arg).value}|"
+         elif hasattr(ChainTypeInfo, arg):
+             chain_types_list = getattr(ChainTypeInfo, arg)
+             for ct in chain_types_list:
+                 regex_str += f"{ct.value}|"
+         else:
+             raise ValueError(
+                 f"Attribute not found for ChainType or ChainTypeInfo: {arg}."
+             )
+
+     # Remove the trailing '|'
+     regex_str = regex_str[:-1]
+
+     return regex_str
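
For orientation, here is a minimal usage sketch (not part of the package) of how these resolvers behave once registered. It assumes standard OmegaConf interpolation syntax and that atomworks defines a PROTEINS group on ChainTypeInfo, as the docstring's example filter suggests:

    from omegaconf import OmegaConf
    from foundry.hydra.resolvers import register_resolvers

    register_resolvers()  # safe to call more than once thanks to @run_once

    cfg = OmegaConf.create({
        # Resolves to the callable torch.nn.functional.relu when accessed.
        "activation": "${resolve_import:torch.nn.functional,relu}",
        # Resolves to an "|"-joined regex of chain-type values; the exact string
        # depends on atomworks' ChainType/ChainTypeInfo definitions.
        "protein_regex": "${chain_type_info_to_regex:PROTEINS}",
    })

    activation_fn = cfg.activation     # triggers resolve_import
    protein_regex = cfg.protein_regex  # triggers chain_type_info_to_regex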
@@ -0,0 +1,235 @@
+ import logging
+ import os
+ from os import PathLike
+ from pathlib import Path
+ from typing import Any, Dict
+
+ import hydra
+ import torch
+ from biotite.structure import AtomArray
+ from lightning.fabric import seed_everything
+ from omegaconf import OmegaConf
+
+ from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
+ from foundry.utils.ddp import RankedLogger, set_accelerator_based_on_availability
+ from foundry.utils.logging import (
+     configure_minimal_inference_logging,
+     print_config_tree,
+ )
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+     datefmt="%H:%M:%S",
+ )
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ def merge(cfg, overrides: dict):
+     return OmegaConf.merge(cfg, OmegaConf.create(overrides))
+
+
+ class BaseInferenceEngine:
+     """
+     Base inference engine.
+     Separates model setup (expensive, done once) from inference (can run multiple times).
+     """
+
+     def __init__(
+         self,
+         ckpt_path: PathLike,
+         num_nodes: int = 1,
+         devices_per_node: int = 1,
+         # Config overrides
+         transform_overrides={},
+         inference_sampler_overrides={},
+         trainer_overrides={},
+         # Debug
+         verbose: bool = False,
+         seed: int | None = None,
+     ):
+         """Initialize the inference engine and load the model.
+
+         The model config is loaded from the checkpoint and overridden with the parameters provided here.
+
+         Args:
+             ckpt_path: Path to a model checkpoint file, or the name of a registered checkpoint.
+             seed: Random seed. If None, uses external RNG state. Defaults to ``None``.
+             num_nodes: Number of nodes for distributed inference. Defaults to ``1``.
+             devices_per_node: Number of devices per node. Defaults to ``1``.
+             verbose: If True, show detailed logging and config trees. Defaults to ``False``.
+         """
+         if not verbose:
+             configure_minimal_inference_logging()
+
+         # Set attrs
+         self.initialized_ = False
+         self.trainer = None
+         self.pipeline = None
+         self.verbose = verbose
+
+         # Resolve checkpoint path
+         if "." not in str(ckpt_path):
+             # No file extension: assume a registered checkpoint name
+             name = str(ckpt_path)
+             assert name in REGISTERED_CHECKPOINTS, f"Checkpoint '{name}' is not a path and is not a registered checkpoint name"
+             ckpt = REGISTERED_CHECKPOINTS[name]
+
+             ckpt_path = ckpt.get_default_path()
+             ranked_logger.info(f"Using checkpoint from default installation directory: {ckpt_path}")
+             assert os.path.exists(ckpt_path), f"Registered checkpoint '{name}' not found at the default installation location: {ckpt_path}"
+         self.ckpt_path = Path(ckpt_path).resolve()
+
+         # Set random seed (only if seed is not None)
+         if seed is not None:
+             ranked_logger.info(f"Seeding everything with seed={seed}...")
+             seed_everything(seed, workers=True, verbose=True)
+         else:
+             ranked_logger.info("Seed is None - using external RNG state")
+         self.seed = seed
+
+         # Stored for later
+         self.transform_overrides = transform_overrides
+         self.overrides: dict[str, Any] = {}
+
+         base_overrides = {
+             "trainer.seed": seed,
+             "trainer.metrics": {},
+             "trainer.loss": None,
+             "trainer.num_nodes": num_nodes,
+             "trainer.devices_per_node": devices_per_node,
+         }
+         for key, value in base_overrides.items():
+             self._assign_override(key, value)
+
+         for key, value in trainer_overrides.items():
+             self._assign_override(f"trainer.{key}", value)
+
+         for key, value in inference_sampler_overrides.items():
+             self._assign_override(f"model.net.inference_sampler.{key}", value)
+
+     ###################################################################################
+     # Required subclass methods
+     ###################################################################################
+
+     def initialize(self):
+         if self.initialized_:
+             return getattr(self, "cfg", None)
+
+         # Load checkpoint and config
+         ranked_logger.info(
+             f"Loading checkpoint from {Path(self.ckpt_path).resolve()}..."
+         )
+         checkpoint = torch.load(self.ckpt_path, "cpu", weights_only=False)
+         cfg = self._override_checkpoint_config(checkpoint["train_cfg"])
+
+         # Load pipeline first, before trainer/model
+         self._construct_pipeline(cfg)
+         self._construct_trainer(cfg, checkpoint=checkpoint)
+
+         ranked_logger.info("Model loaded and ready for inference.")
+         self.initialized_ = True
+         return cfg
+
+     def run(
+         self,
+         inputs: (
+             Dict[str, dict] | AtomArray | list[AtomArray] | PathLike | list[PathLike]
+         ),
+         *_,
+     ) -> dict[str, dict] | None:
+         self.initialize()
+         raise NotImplementedError(
+             "Subclasses must implement inference logic in the `run` method."
+         )
+
+     ###################################################################################
+     # Util methods
+     ###################################################################################
+
+     def _override_checkpoint_config(self, cfg):
+         cfg = merge(cfg, self.overrides)
+         cfg = set_accelerator_based_on_availability(cfg)
+         return cfg
+
+     def _construct_trainer(self, cfg, checkpoint=None):
+         """
+         Sets attr self.trainer
+         """
+         # Instantiate trainer
+         ranked_logger.info("Instantiating trainer...")
+         if self.verbose:
+             print_config_tree(
+                 cfg.trainer, resolve=True, title="INFERENCE TRAINER CONFIGURATION"
+             )
+         trainer = hydra.utils.instantiate(
+             cfg.trainer,
+             _convert_="partial",
+             _recursive_=False,
+         )
+
+         # Setup model
+         ranked_logger.info("Setting up model...")
+         trainer.fabric.launch()
+         trainer.initialize_or_update_trainer_state(
+             {"train_cfg": cfg}
+         )  # config from training stores net params
+         trainer.construct_model()
+
+         ranked_logger.info("Loading model weights from checkpoint...")
+         trainer.load_checkpoint(checkpoint=checkpoint or self.ckpt_path)
+
+         # Ensure the optimizer isn't loaded
+         trainer.state["optimizer"] = None
+         trainer.state["train_cfg"].model.optimizer = None
+         trainer.setup_model_optimizers_and_schedulers()
+         trainer.state["model"].eval()
+         self.trainer = trainer
+
+     def _assign_override(self, dotted_key: str, value: Any) -> None:
+         """Assign ``value`` into ``self.overrides`` using a dotted path."""
+         target = self.overrides
+         keys = dotted_key.split(".")
+         for key in keys[:-1]:
+             if key not in target or not isinstance(target[key], dict):
+                 target[key] = {}
+             target = target[key]
+         target[keys[-1]] = value
+
+     def _construct_pipeline(self, cfg):
+         """
+         Sets attr self.pipeline
+         """
+         # Construct pipeline
+         ranked_logger.info("Building Transform pipeline...")
+         first_val_dataset_key, first_val_dataset = next(iter(cfg.datasets.val.items()))
+         ranked_logger.info(
+             f"Using settings from validation dataset: {first_val_dataset_key}."
+         )
+         transform = first_val_dataset.dataset.transform
+         transform = merge(transform, self.transform_overrides)
+
+         if self.verbose:
+             print_config_tree(
+                 transform,
+                 resolve=True,
+                 title="INFERENCE TRANSFORM PIPELINE",
+             )
+
+         self.pipeline = hydra.utils.instantiate(transform)
+
+     # aliases for run
+     def forward(self, *args, **kwargs):
+         return self.run(*args, **kwargs)
+
+     def __call__(self, *args, **kwargs):
+         return self.run(*args, **kwargs)
+
+     # for use as a context manager: e.g. `with BaseInferenceEngine(...) as engine:` to automatically clean up
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc, tb):
+         self.trainer = None
+         self.pipeline = None
+         self.initialized_ = False
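
The intended calling pattern, sketched below with a hypothetical subclass (MyEngine and its run body are illustrative only; the concrete engines under rf3/, rfd3/, and mpnn/ implement their own run): construction and initialize() pay the one-time model-loading cost, run (also reachable via __call__ and forward) can then be invoked repeatedly, and the context-manager form drops the trainer and pipeline references on exit.

    from foundry.inference_engines.base import BaseInferenceEngine


    class MyEngine(BaseInferenceEngine):
        """Hypothetical subclass used only to illustrate the API."""

        def run(self, inputs, *_):
            self.initialize()  # no-op after the first call
            # How inputs are fed through self.pipeline and self.trainer.state["model"]
            # is engine-specific and not prescribed by the base class.
            return {}


    # "rfd3" contains no ".", so it is resolved via REGISTERED_CHECKPOINTS
    # to its default installation location.
    with MyEngine("rfd3", seed=0, verbose=False) as engine:
        outputs = engine("example_input.cif")  # __call__ is an alias for run()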
@@ -0,0 +1,66 @@
+ """Management of checkpoints."""
+ import os
+ from dataclasses import dataclass
+ from pathlib import Path
+
+
+ def get_default_checkpoint_dir() -> Path:
+     """Get the default checkpoint directory.
+
+     Priority:
+     1. FOUNDRY_CHECKPOINTS_DIR environment variable
+     2. ~/.foundry/checkpoints
+     """
+     if os.environ.get("FOUNDRY_CHECKPOINTS_DIR"):
+         return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]).absolute()
+     return Path.home() / ".foundry" / "checkpoints"
+
+ @dataclass
+ class RegisteredCheckpoint:
+     url: str
+     filename: str
+     description: str
+     sha256: str | None = None  # Optional: checksum for verification
+
+     def get_default_path(self) -> Path:
+         return get_default_checkpoint_dir() / self.filename
+
+
+ REGISTERED_CHECKPOINTS = {
+     "rfd3": RegisteredCheckpoint(
+         url="https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
+         filename="rfd3_latest.ckpt",
+         description="RFdiffusion3 checkpoint",
+     ),
+     "rf3": RegisteredCheckpoint(
+         url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt",
+         filename="rf3_foundry_01_24_latest_remapped.ckpt",
+         description="Latest RF3 checkpoint trained with data until 1/2024 (expect best performance)",
+     ),
+     "proteinmpnn": RegisteredCheckpoint(
+         url="https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt",
+         filename="proteinmpnn_v_48_020.pt",
+         description="ProteinMPNN checkpoint",
+     ),
+     "ligandmpnn": RegisteredCheckpoint(
+         url="https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt",
+         filename="ligandmpnn_v_32_010_25.pt",
+         description="LigandMPNN checkpoint",
+     ),
+     # Other models
+     "rf3_preprint_921": RegisteredCheckpoint(
+         url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt",
+         filename="rf3_foundry_09_21_preprint_remapped.ckpt",
+         description="RF3 preprint checkpoint trained with data until 9/2021",
+     ),
+     "rf3_preprint_124": RegisteredCheckpoint(
+         url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt",
+         filename="rf3_foundry_01_24_preprint_remapped.ckpt",
+         description="RF3 preprint checkpoint trained with data until 1/2024",
+     ),
+     "solublempnn": RegisteredCheckpoint(
+         url="https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
+         filename="solublempnn_v_48_020.pt",
+         description="SolubleMPNN checkpoint",
+     ),
+ }
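
A short sketch of how the registry is meant to be consulted (the /data/checkpoints path below is just an example; downloading the files themselves is presumably handled elsewhere, e.g. by foundry_cli/download_checkpoints.py):

    import os

    from foundry.inference_engines.checkpoint_registry import (
        REGISTERED_CHECKPOINTS,
        get_default_checkpoint_dir,
    )

    # Optionally redirect where checkpoints are looked up.
    os.environ["FOUNDRY_CHECKPOINTS_DIR"] = "/data/checkpoints"  # example path

    ckpt = REGISTERED_CHECKPOINTS["rf3"]
    print(ckpt.url)                  # download URL for the checkpoint file
    print(ckpt.get_default_path())   # /data/checkpoints/rf3_foundry_01_24_latest_remapped.ckpt
    print(get_default_checkpoint_dir())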
@@ -0,0 +1,12 @@
+ """Metrics for model evaluation.
+
+ This module provides the base metric framework.
+ """
+
+ from foundry.metrics.metric import Metric, MetricInputError, MetricManager
+
+ __all__ = [
+     "Metric",
+     "MetricManager",
+     "MetricInputError",
+ ]
@@ -0,0 +1,30 @@
+ import hydra
+ import torch.nn as nn
+ from omegaconf import DictConfig
+
+
+ class Loss(nn.Module):
+     def __init__(self, **losses):
+         super().__init__()
+         self.to_compute = []
+         for loss_name, loss in losses.items():
+             loss_fn = hydra.utils.instantiate(loss)
+             assert not isinstance(
+                 loss_fn, DictConfig
+             ), f"Loss {loss_name} was instantiated as a DictConfig. Is _target_ present?"
+             self.to_compute.append(loss_fn)
+
+     def forward(
+         self,
+         network_input,
+         network_output,
+         loss_input,
+     ):
+         loss_dict = {}
+         loss = 0
+         for loss_fn in self.to_compute:
+             loss_, loss_dict_ = loss_fn(network_input, network_output, loss_input)
+             loss += loss_
+             loss_dict.update(loss_dict_)
+         loss_dict["total_loss"] = loss.detach()
+         return loss, loss_dict
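
To show how Loss composes per-term losses from Hydra-style configs, here is a minimal sketch. DummyMSE is illustrative only (the "__main__.DummyMSE" target assumes the class is defined in the executing script), but it follows the calling convention Loss.forward expects: each term takes (network_input, network_output, loss_input) and returns (loss, loss_dict).

    import torch
    import torch.nn as nn
    from omegaconf import OmegaConf

    from foundry.metrics.losses import Loss


    class DummyMSE(nn.Module):
        """Illustrative loss term following the (input, output, loss_input) convention."""

        def __init__(self, weight: float = 1.0):
            super().__init__()
            self.weight = weight

        def forward(self, network_input, network_output, loss_input):
            loss = self.weight * ((network_output - loss_input) ** 2).mean()
            return loss, {"mse": loss.detach()}


    # Each keyword argument to Loss is a config node with a _target_ for hydra.utils.instantiate.
    cfg = OmegaConf.create({"mse": {"_target_": "__main__.DummyMSE", "weight": 0.5}})
    criterion = Loss(**cfg)

    pred, target = torch.randn(4, 3), torch.randn(4, 3)
    total, logs = criterion(None, pred, target)
    print(total, logs["total_loss"], logs["mse"])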