rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,549 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from atomworks.constants import (
|
|
8
|
+
DICT_THREE_TO_ONE,
|
|
9
|
+
PROTEIN_BACKBONE_ATOM_NAMES,
|
|
10
|
+
UNKNOWN_AA,
|
|
11
|
+
)
|
|
12
|
+
from atomworks.ml.utils.token import get_token_starts, spread_token_wise
|
|
13
|
+
from biotite.structure import AtomArray
|
|
14
|
+
from mpnn.collate.feature_collator import FeatureCollator
|
|
15
|
+
from mpnn.metrics.sequence_recovery import (
|
|
16
|
+
InterfaceSequenceRecovery,
|
|
17
|
+
SequenceRecovery,
|
|
18
|
+
)
|
|
19
|
+
from mpnn.model.mpnn import LigandMPNN, ProteinMPNN
|
|
20
|
+
from mpnn.pipelines.mpnn import build_mpnn_transform_pipeline
|
|
21
|
+
from mpnn.transforms.feature_aggregation.token_encodings import MPNN_TOKEN_ENCODING
|
|
22
|
+
from mpnn.utils.inference import (
|
|
23
|
+
MPNN_GLOBAL_INFERENCE_DEFAULTS,
|
|
24
|
+
MPNNInferenceInput,
|
|
25
|
+
MPNNInferenceOutput,
|
|
26
|
+
_absolute_path_or_none,
|
|
27
|
+
)
|
|
28
|
+
from mpnn.utils.weights import load_legacy_weights
|
|
29
|
+
|
|
30
|
+
from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
|
|
31
|
+
from foundry.metrics.metric import MetricManager
|
|
32
|
+
from foundry.utils.ddp import RankedLogger
|
|
33
|
+
|
|
34
|
+
# Rank-zero logger: in multi-process (DDP) runs each message is emitted once.
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class MPNNInferenceEngine:
    """Inference engine for ProteinMPNN/LigandMPNN.

    Wraps model construction, checkpoint loading, metric setup, and
    (optional) writing of CIF/FASTA outputs behind a single ``run`` call.
    """

    def __init__(
        self,
        *,
        model_type: str = MPNN_GLOBAL_INFERENCE_DEFAULTS["model_type"],
        checkpoint_path: str = MPNN_GLOBAL_INFERENCE_DEFAULTS["checkpoint_path"],
        is_legacy_weights: bool = MPNN_GLOBAL_INFERENCE_DEFAULTS["is_legacy_weights"],
        out_directory: str | None = MPNN_GLOBAL_INFERENCE_DEFAULTS["out_directory"],
        write_fasta: bool = MPNN_GLOBAL_INFERENCE_DEFAULTS["write_fasta"],
        write_structures: bool = MPNN_GLOBAL_INFERENCE_DEFAULTS["write_structures"],
        device: str | torch.device | None = None,
    ):
        """Configure, validate, and build the inference engine.

        Args:
            model_type: One of ``{"protein_mpnn", "ligand_mpnn"}``.
            checkpoint_path: Path to a model checkpoint. May be falsy
                (empty/None), in which case the default registered checkpoint
                for ``model_type`` is used.
            is_legacy_weights: Whether the checkpoint stores legacy-format
                MPNN weights (loaded via ``load_legacy_weights``).
            out_directory: Directory for written outputs; required when any
                ``write_*`` flag is True.
            write_fasta: Write one FASTA file per input name.
            write_structures: Write one CIF file per design.
            device: Torch device spec; defaults to CUDA when available.

        Raises:
            ValueError: If ``model_type`` is not supported, or if writing is
                requested without an output directory.
            TypeError: If configuration values have the wrong types.
            FileNotFoundError: If the resolved checkpoint does not exist.
        """
        # Store raw configuration.
        self.model_type = model_type
        self.is_legacy_weights = is_legacy_weights
        self.out_directory = out_directory
        self.write_fasta = write_fasta
        self.write_structures = write_structures

        # Set up allowed model types *before* any lookup keyed on model_type,
        # so an unsupported model type fails with a clear ValueError instead
        # of an opaque KeyError from the checkpoint registry below.
        self.allowed_model_types = {"protein_mpnn", "ligand_mpnn"}

        # allow null for checkpoint path when foundry-installed
        # TODO: Currently this assumes the model type is the key in the registered path. Rework needed
        if not checkpoint_path:
            if self.model_type not in self.allowed_model_types:
                raise ValueError(
                    f"model_type must be one of {self.allowed_model_types}; "
                    f"got {self.model_type!r}"
                )
            self.checkpoint_path = str(
                REGISTERED_CHECKPOINTS[self.model_type.replace("_", "")].get_default_path()
            )
        else:
            self.checkpoint_path = checkpoint_path

        # Determine the device.
        if device is not None:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Validate the user configuration.
        self._validate_all()

        # Post-process the configuration (making absolute paths, etc).
        self._post_process_engine_config()

        # Build and load the model.
        self.model = self._build_and_load_model().to(self.device)

        # Construct metrics manager.
        self.metrics = self._build_metrics_manager()
|
|
83
|
+
|
|
84
|
+
def _validate_model_config(self) -> None:
|
|
85
|
+
"""Validate model-type and checkpoint-related configuration."""
|
|
86
|
+
# Model type.
|
|
87
|
+
if self.model_type not in self.allowed_model_types:
|
|
88
|
+
raise ValueError(
|
|
89
|
+
f"model_type must be one of {self.allowed_model_types}; "
|
|
90
|
+
f"got {self.model_type!r}"
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
# Checkpoint path.
|
|
94
|
+
if not isinstance(self.checkpoint_path, str):
|
|
95
|
+
raise TypeError("checkpoint_path must be a string path.")
|
|
96
|
+
|
|
97
|
+
# Check that the checkpoint path exists.
|
|
98
|
+
ckpt_path = Path(_absolute_path_or_none(self.checkpoint_path))
|
|
99
|
+
if not ckpt_path.is_file():
|
|
100
|
+
raise FileNotFoundError(
|
|
101
|
+
f"checkpoint_path does not exist: {self.checkpoint_path}"
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
# Legacy-weight flag.
|
|
105
|
+
if not isinstance(self.is_legacy_weights, bool):
|
|
106
|
+
raise TypeError("is_legacy_weights must be a bool.")
|
|
107
|
+
|
|
108
|
+
def _validate_output_config(self) -> None:
|
|
109
|
+
"""Validate output-directory and writing-related configuration."""
|
|
110
|
+
# Output directory.
|
|
111
|
+
if self.out_directory is not None:
|
|
112
|
+
# Must be a string.
|
|
113
|
+
if not isinstance(self.out_directory, str):
|
|
114
|
+
raise TypeError("out_directory must be a string when provided.")
|
|
115
|
+
|
|
116
|
+
# Boolean writing flags.
|
|
117
|
+
for name in ("write_fasta", "write_structures"):
|
|
118
|
+
value = getattr(self, name)
|
|
119
|
+
if not isinstance(value, bool):
|
|
120
|
+
raise TypeError(f"{name} must be a bool.")
|
|
121
|
+
|
|
122
|
+
# If asked to write outputs, out_directory must be set.
|
|
123
|
+
if value and self.out_directory is None:
|
|
124
|
+
raise ValueError(f"{name} is True, but out_directory is not set.")
|
|
125
|
+
|
|
126
|
+
def _validate_all(self) -> None:
|
|
127
|
+
"""Run validation on the user-specified engine config variables."""
|
|
128
|
+
# Validate the model configuration.
|
|
129
|
+
self._validate_model_config()
|
|
130
|
+
|
|
131
|
+
# Validate the output configuration.
|
|
132
|
+
self._validate_output_config()
|
|
133
|
+
|
|
134
|
+
    def _post_process_engine_config(self) -> None:
        """Normalize paths into absolute paths.

        Runs after validation, so checkpoint_path is known to be a string and
        out_directory (when set) is known to be a string.
        """
        # Make checkpoint path absolute.
        self.checkpoint_path = _absolute_path_or_none(self.checkpoint_path)

        # Make output directory absolute (it is optional, so only when set).
        if self.out_directory is not None:
            self.out_directory = _absolute_path_or_none(self.out_directory)
|
|
142
|
+
|
|
143
|
+
    def _build_and_load_model(self) -> torch.nn.Module:
        """Instantiate the configured MPNN architecture and load its weights.

        Returns:
            The model with weights loaded, switched to eval mode. (The caller
            is responsible for moving it to the target device.)

        Raises:
            ValueError: If the model type is unsupported (defensive; the
                constructor validates this earlier).
            TypeError: If a non-legacy checkpoint is not a dict with a
                'model' key.
        """
        # Load model architecture.
        if self.model_type == "protein_mpnn":
            model = ProteinMPNN()
        elif self.model_type == "ligand_mpnn":
            model = LigandMPNN()
        else:
            # Defensive: _validate_model_config should have rejected this already.
            raise ValueError(f"Unsupported model_type: {self.model_type}")

        # Load weights.
        if self.is_legacy_weights:
            # Legacy checkpoints are converted/mapped by the dedicated helper.
            ranked_logger.info("Loading legacy MPNN weights.")
            load_legacy_weights(model, self.checkpoint_path)
        else:
            ranked_logger.info("Loading MPNN weights.")

            # Load the checkpoint.
            # NOTE(review): weights_only=False unpickles arbitrary Python
            # objects — only load checkpoints from trusted sources.
            checkpoint = torch.load(
                self.checkpoint_path, map_location="cpu", weights_only=False
            )

            # Check that checkpoint is a dict with the expected layout.
            if not isinstance(checkpoint, dict) or "model" not in checkpoint:
                raise TypeError("Expected checkpoint to be a dict with a 'model' key.")

            state_dict = checkpoint["model"]

            # strict=True: every key must match the architecture exactly.
            model.load_state_dict(state_dict, strict=True)

        # Set model to eval mode (inference-only engine).
        model.eval()

        return model
|
|
176
|
+
|
|
177
|
+
def _build_metrics_manager(self) -> MetricManager:
|
|
178
|
+
"""Build the metrics manager for inference."""
|
|
179
|
+
|
|
180
|
+
# Construct metrics dict.
|
|
181
|
+
metrics: dict[str, Any] = {
|
|
182
|
+
"sequence_recovery": SequenceRecovery(return_per_example_metrics=True),
|
|
183
|
+
}
|
|
184
|
+
if self.model_type == "ligand_mpnn":
|
|
185
|
+
metrics["interface_sequence_recovery"] = InterfaceSequenceRecovery(
|
|
186
|
+
return_per_example_metrics=True
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
# Construct the MetricManager.
|
|
190
|
+
metric_manager = MetricManager.from_metrics(metrics, raise_errors=True)
|
|
191
|
+
|
|
192
|
+
return metric_manager
|
|
193
|
+
|
|
194
|
+
# ------------------------------------------------------------------ #
|
|
195
|
+
# Public API
|
|
196
|
+
# ------------------------------------------------------------------ #
|
|
197
|
+
def run(
|
|
198
|
+
self,
|
|
199
|
+
*,
|
|
200
|
+
input_dicts: list[dict[str, Any]] | None = None,
|
|
201
|
+
atom_arrays: list[AtomArray] | None = None,
|
|
202
|
+
) -> list[MPNNInferenceOutput]:
|
|
203
|
+
"""Run inference and return a list of MPNNInferenceOutput objects.
|
|
204
|
+
|
|
205
|
+
Parameters
|
|
206
|
+
----------
|
|
207
|
+
input_dicts:
|
|
208
|
+
Optional list of per-input JSON-like dictionaries (one per
|
|
209
|
+
input). If None, 'atom_arrays' must be provided.
|
|
210
|
+
atom_arrays:
|
|
211
|
+
Optional list of externally provided AtomArray objects. If given,
|
|
212
|
+
must align one-to-one with 'input_dicts'. If None, 'input_dicts'
|
|
213
|
+
must be sufficient to resolve structures internally.
|
|
214
|
+
|
|
215
|
+
Returns
|
|
216
|
+
-------
|
|
217
|
+
list[MPNNInferenceOutput]
|
|
218
|
+
A flat list of per-design MPNNInferenceOutput objects. Writing
|
|
219
|
+
of CIF/FASTA outputs is handled internally based on engine-level
|
|
220
|
+
configuration.
|
|
221
|
+
"""
|
|
222
|
+
if input_dicts is None and atom_arrays is None:
|
|
223
|
+
raise ValueError(
|
|
224
|
+
"At least one of 'input_dicts' or 'atom_arrays' must be provided."
|
|
225
|
+
)
|
|
226
|
+
if atom_arrays is not None and input_dicts is not None:
|
|
227
|
+
if len(atom_arrays) != len(input_dicts):
|
|
228
|
+
raise ValueError(
|
|
229
|
+
"'atom_arrays' and 'input_dicts' must have the same length."
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
# Determine the number of inputs.
|
|
233
|
+
num_inputs = len(input_dicts) if input_dicts is not None else len(atom_arrays)
|
|
234
|
+
results: list[MPNNInferenceOutput] = []
|
|
235
|
+
for input_idx in range(num_inputs):
|
|
236
|
+
# Construct the per-input MPNNInferenceInput.
|
|
237
|
+
inference_input = MPNNInferenceInput.from_atom_array_and_dict(
|
|
238
|
+
atom_array=atom_arrays[input_idx] if atom_arrays is not None else None,
|
|
239
|
+
input_dict=input_dicts[input_idx] if input_dicts is not None else None,
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
# Optional per-input RNG seeding for deterministic sampling across
|
|
243
|
+
# batches. Initialize the seed at the beginning of the batches.
|
|
244
|
+
seed = inference_input.input_dict["seed"]
|
|
245
|
+
if seed is not None:
|
|
246
|
+
torch.manual_seed(seed)
|
|
247
|
+
np.random.seed(seed)
|
|
248
|
+
if torch.cuda.is_available():
|
|
249
|
+
torch.cuda.manual_seed_all(seed)
|
|
250
|
+
|
|
251
|
+
# Run the batches for this input.
|
|
252
|
+
for batch_idx in range(inference_input.input_dict["number_of_batches"]):
|
|
253
|
+
ranked_logger.info(
|
|
254
|
+
f"Running MPNN inference for input {input_idx}, "
|
|
255
|
+
f"batch {batch_idx}..."
|
|
256
|
+
)
|
|
257
|
+
|
|
258
|
+
# Run a single batch.
|
|
259
|
+
result = self._run_batch(
|
|
260
|
+
atom_array=inference_input.atom_array,
|
|
261
|
+
input_dict=inference_input.input_dict,
|
|
262
|
+
batch_idx=batch_idx,
|
|
263
|
+
)
|
|
264
|
+
results.extend(result)
|
|
265
|
+
|
|
266
|
+
# Write outputs if requested.
|
|
267
|
+
self._write_outputs(results)
|
|
268
|
+
|
|
269
|
+
return results
|
|
270
|
+
|
|
271
|
+
    def _run_batch(
        self,
        atom_array: AtomArray,
        input_dict: dict[str, Any],
        batch_idx: int | None = None,
    ) -> list[MPNNInferenceOutput]:
        """
        Run a single batch (possibly multiple designs) through the pipeline.

        This function:
        - builds the transform pipeline based on 'input_dict',
        - runs the pipeline and collator,
        - executes the model forward pass,
        - decodes sequences and applies them to the pipeline output
          AtomArray,
        - constructs 'MPNNInferenceOutput' objects

        Returns one MPNNInferenceOutput per design in the batch
        (input_dict["batch_size"] of them).
        """
        # Overriding of default pipeline args from input_dict: only settings
        # the user actually provided (non-None) are forwarded, so the
        # pipeline's own defaults apply otherwise.
        pipeline_args = dict()
        if input_dict["occupancy_threshold_sidechain"] is not None:
            pipeline_args["occupancy_threshold_sidechain"] = input_dict[
                "occupancy_threshold_sidechain"
            ]
        if input_dict["occupancy_threshold_backbone"] is not None:
            pipeline_args["occupancy_threshold_backbone"] = input_dict[
                "occupancy_threshold_backbone"
            ]
        if input_dict["undesired_res_names"] is not None:
            pipeline_args["undesired_res_names"] = input_dict["undesired_res_names"]

        # Construct the pipeline. Rebuilt per batch; presumably cheap relative
        # to the forward pass — NOTE(review): could be hoisted if profiling
        # shows otherwise.
        pipeline = build_mpnn_transform_pipeline(
            model_type=self.model_type,
            is_inference=True,
            minimal_return=True,
            device=self.device,
            **pipeline_args,
        )

        # Construct the collator.
        collator = FeatureCollator()

        # Data dict for pipeline: atom_array plus scalar user-settings.
        # The atom array is copied so pipeline transforms cannot mutate the
        # caller's structure across batches.
        data: dict[str, Any] = {
            "atom_array": atom_array.copy(),
            # Scalar user settings.
            "structure_noise": input_dict["structure_noise"],
            "decode_type": input_dict["decode_type"],
            "causality_pattern": input_dict["causality_pattern"],
            "initialize_sequence_embedding_with_ground_truth": input_dict[
                "initialize_sequence_embedding_with_ground_truth"
            ],
            "atomize_side_chains": input_dict["atomize_side_chains"],
            "repeat_sample_num": input_dict["repeat_sample_num"],
            "features_to_return": input_dict["features_to_return"],
        }

        # Run the pipeline.
        pipeline_output = pipeline(data)

        # Construct the collated network input (single-example "batch"; the
        # model's per-design batching comes from the pipeline features).
        network_input = collator([pipeline_output])

        # Run the model forward pass (inference only, no autograd graph).
        with torch.no_grad():
            network_output = self.model(network_input)

        # Compute metrics once per batch.
        metrics_output = self.metrics(
            network_input=network_input,
            network_output=network_output,
            extra_info={},
        )

        # Extract the sampled sequences.
        # S_sampled: [B = batch_size, L = sequence length]
        S_sampled = (
            network_output["decoder_features"]["S_sampled"].detach().cpu().numpy()
        )
        # L is unpacked for shape clarity but intentionally unused below.
        B, L = S_sampled.shape
        if B != input_dict["batch_size"]:
            raise ValueError(
                "Mismatch between network output batch size and input_dict batch_size."
            )

        # Extract the metrics (per-example arrays indexed by design).
        sequence_recovery_per_design = (
            metrics_output["sequence_recovery.sequence_recovery_per_example_sampled"]
            .detach()
            .cpu()
            .numpy()
        )
        # Interface recovery only exists for the ligand model (see
        # _build_metrics_manager).
        if self.model_type == "ligand_mpnn":
            interface_sequence_recovery_per_design = (
                metrics_output[
                    "interface_sequence_recovery.interface_sequence_recovery_per_example_sampled"
                ]
                .detach()
                .cpu()
                .numpy()
            )
        else:
            interface_sequence_recovery_per_design = None

        # Grab the index to token mapping from the model.
        idx_to_token = MPNN_TOKEN_ENCODING.idx_to_token

        # Construct the output objects.
        outputs: list[MPNNInferenceOutput] = []
        for design_idx in range(input_dict["batch_size"]):
            # Per design, copy the atom array.
            design_atom_array = pipeline_output["atom_array"].copy()

            # Grab the non-atomized atom and token level arrays. This mimics
            # the logic in the pipeline for token level extraction, so it
            # should lead to a one-to-one mapping between decoded tokens and
            # non-atomized residues.
            design_non_atomized_array = design_atom_array[~design_atom_array.atomize]
            design_non_atomized_token_starts = get_token_starts(
                design_non_atomized_array
            )
            design_non_atomized_token_level = design_non_atomized_array[
                design_non_atomized_token_starts
            ]

            # Create the res_name array for the design (decoded token indices
            # mapped back to three-letter residue names).
            designed_resnames = np.array(
                [idx_to_token[int(token_idx)] for token_idx in S_sampled[design_idx]],
                dtype=design_atom_array.res_name.dtype,
            )

            # Sanity check: decoded sequence length must match number of
            # non-atomized tokens.
            if len(design_non_atomized_token_level) != len(designed_resnames):
                raise ValueError(
                    "Mismatch between number of non-atomized tokens and "
                    "decoded sequence length."
                )

            # Spread token-level residue names back to atom level, but only
            # over the non-atomized subset.
            designed_resnames_atom = spread_token_wise(
                design_non_atomized_array,
                designed_resnames,
            )

            # Create a full res_name array (atomized entries keep their
            # original names; non-atomized ones get the designed names).
            full_resnames = design_atom_array.res_name.copy()
            full_resnames[~design_atom_array.atomize] = designed_resnames_atom

            # Overwrite with designed residue names.
            design_atom_array.set_annotation("res_name", full_resnames)

            # We need to remove any non-atomized residue atoms that no
            # longer belong (i.e. old side chain atoms). We want to keep any
            # atom that is atomized, any atom that is a backbone atom, and
            # any atom that was fixed.
            design_is_backbone_atom = np.isin(
                design_atom_array.atom_name,
                PROTEIN_BACKBONE_ATOM_NAMES,
            )
            # The designed-residue mask annotation is only present when the
            # pipeline was configured to fix residues.
            if (
                "mpnn_designed_residue_mask"
                in design_atom_array.get_annotation_categories()
            ):
                design_is_fixed_atom = ~design_atom_array.mpnn_designed_residue_mask
            else:
                design_is_fixed_atom = np.zeros(len(design_atom_array), dtype=bool)
            design_atom_array = design_atom_array[
                design_atom_array.atomize
                | design_is_backbone_atom
                | design_is_fixed_atom
            ]

            # Construct one letter sequence and recovery metrics for
            # output dict. Unknown residue names fall back to the
            # unknown-amino-acid one-letter code.
            one_letter_seq = "".join(
                [
                    DICT_THREE_TO_ONE.get(res_name, DICT_THREE_TO_ONE[UNKNOWN_AA])
                    for res_name in designed_resnames
                ]
            )
            sequence_recovery = float(sequence_recovery_per_design[design_idx])
            if interface_sequence_recovery_per_design is not None:
                ligand_interface_sequence_recovery = float(
                    interface_sequence_recovery_per_design[design_idx]
                )
            else:
                ligand_interface_sequence_recovery = None

            # Build the output dict (per-design record plus engine provenance).
            output_dict = {
                "batch_idx": batch_idx,
                "design_idx": design_idx,
                "designed_sequence": one_letter_seq,
                "sequence_recovery": sequence_recovery,
                "ligand_interface_sequence_recovery": (
                    ligand_interface_sequence_recovery
                ),
                "model_type": self.model_type,
                "checkpoint_path": self.checkpoint_path,
                "is_legacy_weights": self.is_legacy_weights,
            }

            # Deep-copy input_dict so each output owns an immutable snapshot
            # of the settings that produced it.
            outputs.append(
                MPNNInferenceOutput(
                    atom_array=design_atom_array,
                    output_dict=output_dict,
                    input_dict=copy.deepcopy(input_dict),
                )
            )

        return outputs
|
|
484
|
+
|
|
485
|
+
def _write_outputs(self, results: list[MPNNInferenceOutput]) -> None:
|
|
486
|
+
"""Write CIF and/or FASTA outputs based on engine-level settings."""
|
|
487
|
+
out_directory = self.out_directory
|
|
488
|
+
|
|
489
|
+
# If no output directory and writing requested, raise error.
|
|
490
|
+
if not out_directory and (self.write_fasta or self.write_structures):
|
|
491
|
+
raise ValueError(
|
|
492
|
+
"Output directory is not set, but writing of outputs was requested."
|
|
493
|
+
)
|
|
494
|
+
elif not out_directory:
|
|
495
|
+
# Nothing to do.
|
|
496
|
+
return
|
|
497
|
+
|
|
498
|
+
# Make the output directory if it does not exist.
|
|
499
|
+
out_dir_path = Path(out_directory)
|
|
500
|
+
out_dir_path.mkdir(parents=True, exist_ok=True)
|
|
501
|
+
|
|
502
|
+
if self.write_structures:
|
|
503
|
+
# One CIF per design.
|
|
504
|
+
for idx, result in enumerate(results):
|
|
505
|
+
name = result.input_dict["name"]
|
|
506
|
+
batch_idx = result.output_dict["batch_idx"]
|
|
507
|
+
design_idx = result.output_dict["design_idx"]
|
|
508
|
+
|
|
509
|
+
# Can't write without a name.
|
|
510
|
+
if name is None:
|
|
511
|
+
raise ValueError(
|
|
512
|
+
f"Cannot write structure for result {idx}: 'name' is "
|
|
513
|
+
"not set in input_dict."
|
|
514
|
+
)
|
|
515
|
+
|
|
516
|
+
# Construct the output file path.
|
|
517
|
+
file_stem = f"{name}_b{batch_idx}_d{design_idx}"
|
|
518
|
+
base_path = out_dir_path / file_stem
|
|
519
|
+
|
|
520
|
+
# Use the MPNNInferenceOutput helper for writing.
|
|
521
|
+
result.write_structure(
|
|
522
|
+
base_path=base_path,
|
|
523
|
+
)
|
|
524
|
+
|
|
525
|
+
# Write FASTA outputs if requested, one per input name.
|
|
526
|
+
if self.write_fasta:
|
|
527
|
+
# Group results by input name.
|
|
528
|
+
grouped: dict[str, list[MPNNInferenceOutput]] = {}
|
|
529
|
+
for result in results:
|
|
530
|
+
name = result.input_dict["name"]
|
|
531
|
+
|
|
532
|
+
# Can't write without a name.
|
|
533
|
+
if name is None:
|
|
534
|
+
raise ValueError(
|
|
535
|
+
"Cannot write FASTA output: 'name' is not set in input_dict."
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
if name not in grouped:
|
|
539
|
+
grouped[name] = []
|
|
540
|
+
|
|
541
|
+
grouped[name].append(result)
|
|
542
|
+
|
|
543
|
+
# Write one FASTA file per input name.
|
|
544
|
+
for name, group in grouped.items():
|
|
545
|
+
fasta_path = out_dir_path / f"{name}.fa"
|
|
546
|
+
# Append mode so that multiple runs can accumulate designs.
|
|
547
|
+
with fasta_path.open("a") as handle:
|
|
548
|
+
for result in group:
|
|
549
|
+
result.write_fasta(handle=handle)
|
mpnn/loss/nll_loss.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class LabelSmoothedNLLLoss(nn.Module):
    """Label-smoothed negative log likelihood loss for Protein/Ligand MPNN.

    The loss is NOT averaged per sample or per token; instead the masked
    per-residue losses are summed over batch and length and divided by a
    fixed ``normalization_constant``.
    """

    def __init__(
        self,
        label_smoothing_eps: float = 0.1,
        normalization_constant: float = 6000.0,
    ) -> None:
        """
        Args:
            label_smoothing_eps (float): The label smoothing factor. Default is
                0.1.
            normalization_constant (float): The normalization constant for the
                loss. As opposed to averaging per sample in the batch, or
                averaging across all tokens, this constant is used to normalize
                the loss. Default is 6000.0.
        """
        super().__init__()

        self.label_smoothing_eps = label_smoothing_eps
        self.normalization_constant = normalization_constant

    def forward(
        self,
        network_input: dict,
        network_output: dict,
        loss_input: dict,
    ) -> tuple[torch.Tensor, dict]:
        """Compute the label-smoothed NLL loss from model input and output.

        Args:
            network_input (dict): The input to the network.
                - input_features (dict): Contains the input features.
                    - S (torch.Tensor): [B, L] - the sequence of residues.
            network_output (dict): The output of the network, a dictionary
                containing several sub-dictionaries; the necessary sub-
                dictionaries and their needed keys are listed below:
                - input_features (dict): Contains the modified input features.
                    - mask_for_loss (torch.Tensor): [B, L] - the mask for the
                      loss computation.
                - decoder_features (dict): Contains the decoder features.
                    - log_probs (torch.Tensor): [B, L, vocab_size] - the log
                      probabilities for the sequence.
            loss_input (dict): Dictionary containing additional inputs needed
                for the loss computation. Unused here.

        Returns:
            The loss and a dictionary containing the loss values.
            - label_smoothed_nll_loss_agg (torch.Tensor): [1] - the
              aggregated label smoothed negative log likelihood loss,
              masked by the mask for the loss, summed across the batch and
              length dimensions, and normalized by the normalization
              constant. This is the final loss value returned by the loss
              function.
            - loss_dict (dict): A dictionary containing the loss outputs.
                - label_smoothed_nll_loss_per_residue (torch.Tensor): [B, L]
                  - the per-residue label smoothed negative log likelihood
                  loss, masked by the mask for loss.
                - label_smoothed_nll_loss_agg (torch.Tensor): [1] - the
                  aggregated loss described above (detached copy).

        Raises:
            ValueError: If any required key is missing from
                ``network_input["input_features"]`` or ``network_output``.
        """
        input_features = network_input["input_features"]

        # Check that the input features contain the necessary keys.
        if "S" not in input_features:
            raise ValueError("Input features must contain 'S' key.")

        # Check that the network output contains the necessary keys.
        if "input_features" not in network_output:
            raise ValueError("Network output must contain 'input_features' key.")
        if "mask_for_loss" not in network_output["input_features"]:
            raise ValueError(
                "Network output must contain "
                "'mask_for_loss' key in 'input_features'."
            )
        if "decoder_features" not in network_output:
            raise ValueError("Network output must contain 'decoder_features' key.")
        if "log_probs" not in network_output["decoder_features"]:
            raise ValueError(
                "Network output must contain "
                "'log_probs' key in 'decoder_features'."
            )

        B, L, vocab_size = network_output["decoder_features"]["log_probs"].shape

        # S_onehot [B, L, vocab_size] - the one-hot encoded sequence.
        S_onehot = torch.nn.functional.one_hot(
            input_features["S"], num_classes=vocab_size
        ).float()

        # label_smoothed_S_onehot [B, L, vocab_size] - the label smoothed
        # encoded sequence: (1 - eps) on the true class, eps spread uniformly
        # over the whole vocabulary.
        label_smoothed_S_onehot = (
            1 - self.label_smoothing_eps
        ) * S_onehot + self.label_smoothing_eps / vocab_size

        # label_smoothed_nll_loss_per_residue [B, L] - the per-residue label
        # smoothed negative log likelihood loss, masked by the mask for loss.
        label_smoothed_nll_loss_per_residue = (
            -torch.sum(
                label_smoothed_S_onehot
                * network_output["decoder_features"]["log_probs"],
                dim=-1,
            )
            * network_output["input_features"]["mask_for_loss"]
        )

        # label_smoothed_nll_loss_agg - the aggregated label smoothed
        # negative log likelihood loss, summed across the batch and
        # length dimensions, and normalized by the normalization constant.
        # This is the final loss value returned by the loss function.
        label_smoothed_nll_loss_agg = (
            torch.sum(label_smoothed_nll_loss_per_residue) / self.normalization_constant
        )

        # Construct the output loss dictionary. Values are detached so they
        # can be logged without retaining the autograd graph.
        loss_dict = {
            "label_smoothed_nll_loss_per_residue": label_smoothed_nll_loss_per_residue.detach(),
            "label_smoothed_nll_loss_agg": label_smoothed_nll_loss_agg.detach(),
        }

        return label_smoothed_nll_loss_agg, loss_dict
|