rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rfd3/trainer/recycling.py
ADDED
@@ -0,0 +1,42 @@
import math

import torch
from atomworks.ml.utils.rng import create_rng_state_from_seeds, rng_state


def get_recycle_schedule(
    max_cycle: int,
    n_epochs: int,
    n_train: int,
    world_size: int,
    seed: int = 42,
) -> torch.Tensor:
    """Generate a schedule for recycling iterations over multiple epochs.

    Used to ensure that each GPU has the same number of recycles within a given batch.

    Args:
        max_cycle (int): Maximum number of recycling iterations (n_recycle).
        n_epochs (int): Number of training epochs.
        n_train (int): The total number of training examples per epoch (across all GPUs).
        world_size (int): The number of distributed training processes.
        seed (int, optional): The seed for random number generation. Defaults to 42.

    Returns:
        torch.Tensor: A tensor containing the recycling schedule for each epoch,
            with dimensions `(n_epochs, n_train // world_size)`.

    References:
        AF-2 Supplement, Algorithm 31
    """
    # We use a context manager to avoid modifying the global RNG state
    with rng_state(create_rng_state_from_seeds(torch_seed=seed)):
        # ...generate a recycling schedule for each epoch
        recycle_schedule = []
        for i in range(n_epochs):
            schedule = torch.randint(
                1, max_cycle + 1, (math.ceil(n_train / world_size),)
            )
            recycle_schedule.append(schedule)

    return torch.stack(recycle_schedule, dim=0)
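Because the schedule is derived from a fixed seed in an isolated RNG state, every rank can rebuild an identical tensor locally and index it by (epoch, batch index), which is how the trainer below consumes it. A minimal self-contained sketch of the same idea, using a plain torch.Generator instead of the atomworks rng_state helpers (function and variable names here are hypothetical, not part of the package):

import math
import torch

def toy_recycle_schedule(max_cycle, n_epochs, n_train, world_size, seed=42):
    # Same idea as get_recycle_schedule, but with a local generator so the
    # global RNG state is untouched.
    gen = torch.Generator().manual_seed(seed)
    steps_per_rank = math.ceil(n_train / world_size)
    return torch.stack(
        [
            torch.randint(1, max_cycle + 1, (steps_per_rank,), generator=gen)
            for _ in range(n_epochs)
        ]
    )

# Every rank that calls this with the same seed builds the same tensor, so
# schedule[epoch, batch_idx] agrees across GPUs for a given batch.
schedule = toy_recycle_schedule(max_cycle=4, n_epochs=2, n_train=10, world_size=2, seed=42)
assert schedule.shape == (2, 5)
n_cycle = schedule[0, 3].item()  # number of recycles for epoch 0, batch 3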
rfd3/trainer/rfd3.py
ADDED
@@ -0,0 +1,485 @@
import numpy as np
import torch
from beartype.typing import Any, List, Union
from biotite.structure import AtomArray, AtomArrayStack
from biotite.structure.residues import get_residue_starts
from einops import repeat
from lightning_utilities import apply_to_collection
from omegaconf import DictConfig
from rfd3.metrics.design_metrics import get_all_backbone_metrics
from rfd3.metrics.hbonds_hbplus_metrics import get_hbond_metrics
from rfd3.trainer.recycling import get_recycle_schedule
from rfd3.trainer.trainer_utils import (
    _build_atom_array_stack,
    _cleanup_virtual_atoms_and_assign_atom_name_elements,
    _reassign_unindexed_token_chains,
    _reorder_dict,
    process_unindexed_outputs,
)
from rfd3.utils.io import (
    build_stack_from_atom_array_and_batched_coords,
)

from foundry.metrics.losses import Loss
from foundry.metrics.metric import MetricManager
from foundry.trainers.fabric import FabricTrainer
from foundry.utils.ddp import RankedLogger
from foundry.utils.torch import assert_no_nans, assert_same_shape

global_logger = RankedLogger(__name__, rank_zero_only=False)


class AADesignTrainer(FabricTrainer):
    """Mostly for unique things like saving outputs and parsing inputs

    Args:
        allow_sequence_outputs (bool): Whether to allow sequence outputs in the model.
        convert_non_protein_designed_res_to_ala (bool): Convert non-protein designed residues to ALA. Useful if the
            sequence head spuriously predicts NA residues (when it's performing very poorly).
        cleanup_inference_outputs (bool): Not implemented yet.
        load_sequence_head_weights_if_present (bool): Whether to load the sequence head weights from the checkpoint.
        association_scheme (str): Association scheme to use for the sequence head. Defaults to "atom14".
        seed (int | None): The random seed used for this design, which will be dumped in the output JSON.
            If None, no value will be dumped.
    """

    def __init__(
        self,
        allow_sequence_outputs,
        cleanup_guideposts,
        cleanup_virtual_atoms,
        read_sequence_from_sequence_head,
        output_full_json,
        association_scheme,
        compute_non_clash_metrics_for_diffused_region_only=False,
        seed=None,  # Deprecated
        n_recycles_train: int | None = None,
        loss: DictConfig | dict | None = None,
        metrics: DictConfig | dict | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.allow_sequence_outputs = allow_sequence_outputs
        self.cleanup_guideposts = cleanup_guideposts
        self.cleanup_virtual_atoms = cleanup_virtual_atoms
        self.read_sequence_from_sequence_head = read_sequence_from_sequence_head
        self.output_full_json = output_full_json
        self.compute_non_clash_metrics_for_diffused_region_only = (
            compute_non_clash_metrics_for_diffused_region_only
        )
        self.association_scheme = association_scheme
        self.seed = None

        # (Initialize recycle schedule upfront so all GPUs can sample the same number of recycles within a batch)
        self.n_recycles_train = n_recycles_train
        self.recycle_schedule = get_recycle_schedule(
            max_cycle=n_recycles_train,
            n_epochs=self.max_epochs,  # Set by FabricTrainer
            n_train=self.n_examples_per_epoch,  # Set by FabricTrainer
            world_size=self.fabric.world_size,
        )  # [n_epochs, n_examples_per_epoch // world_size]

        # Metrics
        # (We could have instantiated loss and metrics recursively, but we prioritize being explicit)
        self.metrics = (
            MetricManager.instantiate_from_hydra(metrics_cfg=metrics)
            if metrics
            else None
        )
        # Loss (full precision)
        with torch.autocast(device_type=self.fabric.device.type, enabled=False):
            self.loss = Loss(**loss) if loss else None
    def _assemble_network_inputs(self, example: dict) -> dict:
        """Assemble and validate the network inputs."""
        assert_same_shape(example["coord_atom_lvl_to_be_noised"], example["noise"])
        network_input = {
            "X_noisy_L": example["coord_atom_lvl_to_be_noised"] + example["noise"],
            "t": example["t"],
            "f": example["feats"],
        }

        try:
            assert_no_nans(
                network_input["X_noisy_L"],
                msg=f"network_input (X_noisy_L) for example_id: {example['example_id']}",
            )
        except AssertionError as e:
            if self.state["model"].training:
                # In some cases, we may indeed have NaNs in the noisy coordinates; we can safely replace them with zeros,
                # and begin noising of those coordinates (which will not have their loss computed) from the origin.
                # Such a situation could occur if there was a chain in the crop with no resolved residues (but that contained resolved
                # residues outside the crop); we then would not be able to resolve the missing coordinates to their "closest resolved neighbor"
                # within the same chain.
                network_input["X_noisy_L"] = torch.nan_to_num(
                    network_input["X_noisy_L"]
                )
                global_logger.warning(str(e))
            else:
                # During validation, since we do not crop, there should be no NaNs in the coordinates to noise
                # (They were either removed, as is done with fully unresolved chains, or resolved according to our pipeline's rules)
                raise e

        assert_no_nans(
            network_input["f"],
            msg=f"NaN detected in `feats` for example_id: {example['example_id']}",
        )

        return network_input
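The NaN-replacement branch above relies on torch.nan_to_num zero-filling unresolved coordinates so that their noising effectively starts from the origin. A tiny self-contained illustration of that behavior on dummy tensors (not package code):

import torch

# One resolved atom and one unresolved (NaN) atom, plus Gaussian noise,
# mirroring the X_noisy_L assembly in _assemble_network_inputs.
coords = torch.tensor([[0.0, 0.0, 0.0], [float("nan"), float("nan"), float("nan")]])
noise = torch.randn_like(coords)
x_noisy = coords + noise              # NaNs propagate through the addition
x_noisy = torch.nan_to_num(x_noisy)   # unresolved atom is reset to (0, 0, 0)
assert not torch.isnan(x_noisy).any()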
    def training_step(
        self,
        batch: Any,
        batch_idx: int,
        is_accumulating: bool,
    ) -> None:
        """Training step, running forward and backward passes.

        Args:
            batch: The current batch; can be of any form.
            batch_idx: The index of the current batch.
            is_accumulating: Whether we are accumulating gradients (i.e., not yet calling optimizer.step()).
                If this is the case, we should skip the synchronization during the backward pass.

        Returns:
            None; we call `loss.backward()` directly, and store the outputs in `self._current_train_return`.
        """
        model = self.state["model"]
        assert model.training, "Model must be training!"

        # Recycling
        # (Number of recycles for the current batch; shared across all GPUs within a distributed batch)
        n_cycle = self.recycle_schedule[self.state["current_epoch"], batch_idx].item()

        with self.fabric.no_backward_sync(model, enabled=is_accumulating):
            # (We assume batch size of 1 for structure predictions)
            example = batch[0] if not isinstance(batch, dict) else batch

            network_input = self._assemble_network_inputs(example)

            # Forward pass (without rollout)
            network_output = model.forward(input=network_input, n_cycle=n_cycle)
            assert_no_nans(
                network_output,
                msg=f"network_output for example_id: {example['example_id']}",
            )

            loss_extra_info = self._assemble_loss_extra_info(example)

            total_loss, loss_dict_batched = self.loss(
                network_input=network_input,
                network_output=network_output,
                # TODO: Rename `loss_input` to `extra_info` to pattern-match metrics
                loss_input=loss_extra_info,
            )

            # Backward pass
            self.fabric.backward(total_loss)

        # ... store the outputs without gradients for use in logging, callbacks, learning rate schedulers, etc.
        self._current_train_return = apply_to_collection(
            {"total_loss": total_loss, "loss_dict": loss_dict_batched},
            dtype=torch.Tensor,
            function=lambda x: x.detach(),
        )
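The `is_accumulating` flag is presumably supplied by the surrounding loop in foundry's FabricTrainer, which is not shown in this hunk. A hedged sketch of how a generic Lightning Fabric loop typically pairs such a flag with `no_backward_sync` and deferred optimizer steps (toy model, hypothetical values; not the package's trainer):

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)

grad_accum_steps = 4  # hypothetical; a real value would come from the trainer config
for batch_idx in range(16):
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    # True while gradients are still being accumulated for this optimizer step.
    is_accumulating = (batch_idx + 1) % grad_accum_steps != 0

    # Skip gradient synchronization across ranks while accumulating,
    # exactly the contract training_step() above documents.
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        loss = torch.nn.functional.mse_loss(model(x), y)
        fabric.backward(loss)

    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()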
    def validation_step(
        self,
        batch: Any,
        batch_idx: int,
        compute_metrics: bool = True,
    ) -> dict:
        """Validation step, running forward pass and computing validation metrics.

        Args:
            batch: The current batch; can be of any form.
            batch_idx: The index of the current batch.
            compute_metrics: Whether to compute metrics. If False, we will not compute metrics, and the output will be None.
                Set to False during the inference pipeline, where we need the network output but cannot compute metrics (since we
                do not have the ground truth).

        Returns:
            dict: Output dictionary containing the validation metrics and network output.
        """
        model = self.state["model"]
        assert not model.training, "Model must be in evaluation mode during validation!"

        example = batch[0] if not isinstance(batch, dict) else batch

        network_input = self._assemble_network_inputs(example)

        assert_no_nans(
            network_input,
            msg=f"network_input for example_id: {example['example_id']}",
        )

        # ... forward pass (with rollout)
        # (Note that forward() passes to the EMA/shadow model if the model is not training)
        network_output = model.forward(
            input=network_input,
            coord_atom_lvl_to_be_noised=example["coord_atom_lvl_to_be_noised"],
        )

        assert_no_nans(
            network_output,
            msg=f"network_output for example_id: {example['example_id']}",
        )

        # ... Convert output to a stack of atom arrays
        predicted_atom_array_stack, prediction_metadata = (
            self._build_predicted_atom_array_stack(network_output, example)
        )

        metrics_output = {}
        if compute_metrics:
            assert self.metrics is not None, "Metrics are not defined!"

            metrics_extra_info = self._assemble_metrics_extra_info(
                example, network_output
            )

            metrics_output = self.metrics(
                network_input=network_input,
                network_output=network_output,
                extra_info=metrics_extra_info,
                # (Uses the permuted ground truth after symmetry resolution)
                ground_truth_atom_array_stack=build_stack_from_atom_array_and_batched_coords(
                    metrics_extra_info["X_gt_L"], example.get("atom_array", None)
                ),
                predicted_atom_array_stack=predicted_atom_array_stack,
                prediction_metadata=prediction_metadata,
            )

            if "X_gt_index_to_X" in metrics_extra_info:
                # Remap outputs to minimize error with ground truth
                # TODO: Remap before computing metrics, so that we can avoid passing `extra_info` to metrics (we instead just pass the remapped prediction)
                mapping = metrics_extra_info["X_gt_index_to_X"]  # [D, L]
                network_output["X_L"] = _remap_outputs(network_output["X_L"], mapping)

        # Avoid gradients in stored values to prevent memory leaks
        if metrics_output is not None:
            metrics_output = apply_to_collection(
                metrics_output, torch.Tensor, lambda x: x.detach()
            )

        if network_output is not None:
            network_output = apply_to_collection(
                network_output, torch.Tensor, lambda x: x.detach()
            )

        return {
            "metrics_output": metrics_output,
            "network_output": network_output,
            "predicted_atom_array_stack": predicted_atom_array_stack,
            "prediction_metadata": prediction_metadata,
        }
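Note that `_remap_outputs` is called above but is neither defined nor imported in this file, so it presumably lives elsewhere in the package. A minimal, hypothetical sketch of the index-based remapping it appears to perform on [D, L, 3] coordinates given a [D, L] mapping (the name and exact semantics are assumptions):

import torch

def _remap_outputs_sketch(X_L: torch.Tensor, mapping: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for `_remap_outputs`: gather predicted coordinates so that
    position i of the output holds the prediction at index mapping[d, i].

    X_L:     [D, L, 3] predicted coordinates
    mapping: [D, L]    integer indices along the L dimension
    """
    index = mapping.long().unsqueeze(-1).expand(-1, -1, X_L.shape[-1])  # [D, L, 3]
    return torch.gather(X_L, dim=1, index=index)

# Toy check: a mapping that reverses the atom order reverses the coordinates.
X = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
mapping = torch.tensor([[3, 2, 1, 0], [3, 2, 1, 0]])
assert torch.equal(_remap_outputs_sketch(X, mapping)[0, 0], X[0, 3])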
    def _assemble_loss_extra_info(self, example: dict) -> dict:
        """Assembles metadata arguments to the loss function (incremental to the network inputs and outputs)."""

        # ... reshape
        diffusion_batch_size = example["coord_atom_lvl_to_be_noised"].shape[0]
        X_gt_L = repeat(
            example["ground_truth"]["coord_atom_lvl"],
            "l c -> d l c",
            d=diffusion_batch_size,
        )  # [L, 3] -> [D, L, 3] with broadcasting

        return {
            "X_gt_L": X_gt_L,  # [D, L, 3]
            "X_gt_L_in_input_frame": example[
                "coord_atom_lvl_to_be_noised"
            ],  # [D, L, 3] for no-align loss
            "crd_mask_L": example["ground_truth"]["mask_atom_lvl"],  # [D, L]
            "is_original_unindexed_token": example["ground_truth"][
                "is_original_unindexed_token"
            ],  # [I,]
            # Sequence information:
            "seq_token_lvl": example["ground_truth"]["sequence_gt_I"],  # [I, 32]
            "sequence_valid_mask": example["ground_truth"][
                "sequence_valid_mask"
            ],  # [I,]
        }
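A small self-contained check of the [L, 3] -> [D, L, 3] broadcast performed with einops.repeat above, on dummy shapes (not package code):

import torch
from einops import repeat

coords = torch.randn(5, 3)                    # [L, 3] ground-truth coordinates
X_gt_L = repeat(coords, "l c -> d l c", d=4)  # [D, L, 3], one copy per diffusion sample
assert X_gt_L.shape == (4, 5, 3)
assert torch.equal(X_gt_L[0], X_gt_L[3])      # every D slice sees the same ground truth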
    def _assemble_metrics_extra_info(self, example: dict, network_output: dict) -> dict:
        """Prepares the extra info for the metrics"""
        # We need the same information as for the loss...
        metrics_extra_info = self._assemble_loss_extra_info(example)

        # ... and possibly some additional metadata from the example dictionary
        # TODO: Generalize, so we always use the `extra_info` key, rather than unpacking the ground truth as well
        metrics_extra_info.update(
            {
                # TODO: Remove, instead using `extra_info` for all keys
                **{
                    k: example["ground_truth"][k]
                    for k in [
                        "interfaces_to_score",
                        "pn_units_to_score",
                        "chain_iid_token_lvl",
                    ]
                    if k in example["ground_truth"]
                },
                "example_id": example[
                    "example_id"
                ],  # We require the example ID for logging
                # (From the parser)
                **example.get("extra_info", {}),
            }
        )

        # (Create a shallow copy to avoid modifying the original dictionary)
        return {**metrics_extra_info}
    def _build_predicted_atom_array_stack(
        self, network_output: dict, example: dict
    ) -> Union[AtomArrayStack, List[AtomArray]]:
        atom_array = example["atom_array"]
        f = example["feats"]

        # ... Cleanup atom array:
        atom_array.bonds = None
        atom_array.res_name[~atom_array.is_motif_atom_with_fixed_seq] = (
            "UNK"  # Ensure non-motif residues set to UNK
        )
        atom_array = _reassign_unindexed_token_chains(atom_array)

        # ... Build output atom array stack
        atom_array_stack = _build_atom_array_stack(
            network_output["X_L"],
            atom_array,
            sequence_logits=network_output.get("sequence_logits_I"),
            sequence_indices=network_output.get("sequence_indices_I"),
            allow_sequence_outputs=self.allow_sequence_outputs,
            read_sequence_from_sequence_head=self.read_sequence_from_sequence_head,
            association_scheme=self.association_scheme,
        )  # NB: Will be either list (when sequences are saved) or stack

        arrays = atom_array_stack
        metadata_dict = {i: {"metrics": {}} for i in range(len(arrays))}

        # Add the seed to the metadata dictionary if provided
        if self.seed is not None:
            for i in range(len(arrays)):
                metadata_dict[i]["seed"] = self.seed

        atom_array_stack = []
        for i, atom_array in enumerate(arrays):
            # ... Create essential outputs for metadata dictionary
            if "example" in example["specification"]:
                metadata_dict[i] |= {"task": example["specification"]["example"]}

            # ... Add original specification to metadata
            if self.output_full_json:
                metadata_dict[i] |= {
                    "specification": example["specification"],
                }
                if (
                    hasattr(self, "inference_sampler_overrides")
                    and self.inference_sampler_overrides
                ):
                    metadata_dict[i] |= {
                        "inference_sampler": self.inference_sampler_overrides
                    }

            if np.any(atom_array.is_motif_atom_unindexed):
                # ... insert unindexed motif to output
                atom_array_processed, metadata = process_unindexed_outputs(
                    atom_array,
                    insert_guideposts=self.cleanup_guideposts,
                )
                global_logger.info(
                    f"Inserted unindexed motif atoms for example {i} with RMSD {metadata['insertion_rmsd']:.3f} A"
                )
                if self.cleanup_guideposts:
                    atom_array = atom_array_processed

                diffused_index_map = metadata.pop("diffused_index_map", None)
                metadata_dict[i]["metrics"] |= metadata
                if diffused_index_map is not None:
                    metadata_dict[i]["diffused_index_map"] = diffused_index_map
            else:
                metadata_dict[i]["diffused_index_map"] = {}

            # Also record where indexed motifs ended up
            residue_start_atoms = atom_array[get_residue_starts(atom_array)]
            indexed_residue_starts_non_ligand = residue_start_atoms[
                ~residue_start_atoms.is_motif_atom_unindexed
                & ~residue_start_atoms.is_ligand
            ]

            # If the src_component starts with an alphabetic character, it's from an external source
            external_src_mask = np.array(
                [
                    (s[0].isalpha() if len(s) > 0 else False)
                    for s in indexed_residue_starts_non_ligand.src_component
                ]
            )
            indexed_residue_starts_from_external_src = (
                indexed_residue_starts_non_ligand[external_src_mask]
            )

            for token in indexed_residue_starts_from_external_src:
                metadata_dict[i]["diffused_index_map"][token.src_component] = (
                    f"{token.chain_id}{token.res_id}"
                )

            # ... Delete virtual atoms and assign atom names and elements
            if self.cleanup_virtual_atoms:
                atom_array = _cleanup_virtual_atoms_and_assign_atom_name_elements(
                    atom_array, association_scheme=self.association_scheme
                )

            # ... When cleaning up virtual atoms, we can also calculate native array metrics
            metadata_dict[i]["metrics"] |= get_all_backbone_metrics(
                atom_array,
                compute_non_clash_metrics_for_diffused_region_only=self.compute_non_clash_metrics_for_diffused_region_only,
            )

            if (
                "active_donor" in atom_array.get_annotation_categories()
                or "active_acceptor" in atom_array.get_annotation_categories()
            ):
                metadata_dict[i]["metrics"] |= get_hbond_metrics(atom_array)

            if "partial_t" in f:
                # Try to calculate a CA RMSD to input:
                aa_in = example["atom_array"]
                xyz_ca_input = aa_in.coord[np.isin(aa_in.atom_name, "CA")]
                xyz_ca_output = atom_array.coord[np.isin(atom_array.atom_name, "CA")]

                # Align CA atoms and calculate RMSD:
                if xyz_ca_input.shape == xyz_ca_output.shape:
                    try:
                        from rfd3.utils.alignment import weighted_rigid_align

                        xyz_ca_output_aligned = (
                            weighted_rigid_align(
                                torch.from_numpy(xyz_ca_input)[None],
                                torch.from_numpy(xyz_ca_output)[None],
                            )
                            .squeeze(0)
                            .numpy()
                        )
                        metadata_dict[i]["metrics"] |= {
                            "ca_rmsd_to_input": float(
                                np.sqrt(
                                    np.mean(
                                        np.square(
                                            xyz_ca_input - xyz_ca_output_aligned
                                        ).sum(-1)
                                    )
                                )
                            )
                        }
                    except Exception as e:
                        global_logger.warning(
                            f"Failed to calculate CA RMSD for partial diffusion output: {e}"
                        )

            atom_array_stack.append(atom_array)

        # Reorder metadata dictionaries to ensure 'metrics' and 'specification' are last
        metadata_dict = {k: _reorder_dict(d) for k, d in metadata_dict.items()}
        return atom_array_stack, metadata_dict
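The CA RMSD above depends on `weighted_rigid_align` from rfd3/utils/alignment.py, which is not shown in this diff. A self-contained sketch of the align-then-RMSD computation, assuming an unweighted Kabsch superposition as a stand-in (names and toy data are hypothetical):

import numpy as np

def kabsch_align(P: np.ndarray, Q: np.ndarray) -> np.ndarray:
    """Rigidly align Q onto P (both [N, 3]); an unweighted stand-in for weighted_rigid_align."""
    P_c, Q_c = P - P.mean(axis=0), Q - Q.mean(axis=0)
    H = Q_c.T @ P_c                              # 3x3 covariance of centered coordinates
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))       # guard against reflections
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    return (R @ Q_c.T).T + P.mean(axis=0)

# Toy CA coordinates: the "output" is a rotated/translated copy of the "input",
# so the aligned RMSD should be ~0.
rng = np.random.default_rng(0)
xyz_ca_input = rng.normal(size=(10, 3))
theta = np.pi / 5
Rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0,            0.0,           1.0]])
xyz_ca_output = xyz_ca_input @ Rz.T + np.array([1.0, -2.0, 0.5])

aligned = kabsch_align(xyz_ca_input, xyz_ca_output)
# Same RMSD formula as ca_rmsd_to_input in _build_predicted_atom_array_stack.
ca_rmsd_to_input = float(np.sqrt(np.mean(np.square(xyz_ca_input - aligned).sum(-1))))
assert ca_rmsd_to_input < 1e-6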