rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/trainers/rf3.py
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
1
|
+
import hydra
|
|
2
|
+
import torch
|
|
3
|
+
from beartype.typing import Any
|
|
4
|
+
from einops import repeat
|
|
5
|
+
from jaxtyping import Float, Int
|
|
6
|
+
from lightning_utilities import apply_to_collection
|
|
7
|
+
from omegaconf import DictConfig
|
|
8
|
+
from rf3.loss.af3_losses import (
|
|
9
|
+
ResidueSymmetryResolution,
|
|
10
|
+
SubunitSymmetryResolution,
|
|
11
|
+
)
|
|
12
|
+
from rf3.model.RF3 import ShouldEarlyStopFn
|
|
13
|
+
from rf3.utils.io import build_stack_from_atom_array_and_batched_coords
|
|
14
|
+
from rf3.utils.recycling import get_recycle_schedule
|
|
15
|
+
|
|
16
|
+
from foundry.common import exists
|
|
17
|
+
from foundry.metrics.losses import Loss
|
|
18
|
+
from foundry.metrics.metric import MetricManager
|
|
19
|
+
from foundry.trainers.fabric import FabricTrainer
|
|
20
|
+
from foundry.training.EMA import EMA
|
|
21
|
+
from foundry.utils.ddp import RankedLogger
|
|
22
|
+
from foundry.utils.torch import assert_no_nans, assert_same_shape
|
|
23
|
+
|
|
24
|
+
# Module-level logger shared by the trainers below. Constructed with `rank_zero_only=True`
# (presumably so that only rank 0 emits records in distributed runs — confirm against `RankedLogger`).
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _remap_outputs(
|
|
28
|
+
xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"]
|
|
29
|
+
) -> Float[torch.Tensor, "D L 3"]:
|
|
30
|
+
"""Helper function to remap outputs using a mapping tensor."""
|
|
31
|
+
for i in range(xyz.shape[0]):
|
|
32
|
+
xyz[i, mapping[i]] = xyz[i].clone()
|
|
33
|
+
return xyz
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class RF3Trainer(FabricTrainer):
    """Standard Trainer for AF3-style models.

    Builds on `FabricTrainer` and adds:
        - a recycle schedule that is computed once in `__init__` (`self.recycle_schedule`),
        - an optional `Loss` (`self.loss`) and an optional `MetricManager` (`self.metrics`),
        - subunit- and residue-level symmetry resolution, applied to the metrics inputs in `validation_step`.
    """

    def __init__(
        self,
        *,
        n_recycles_train: int | None = None,
        loss: DictConfig | dict | None = None,
        metrics: DictConfig | dict | MetricManager | None = None,
        seed=None,  # NOTE(review): accepted but never read in this constructor
        **kwargs,
    ):
        """See `FabricTrainer` for the additional initialization arguments.

        Args:
            n_recycles_train: Maximum number of recycles (per-batch), for models that support recycling. During training, the model will be recycled a
                random number of times between 1 and `n_recycles_train`. During inference, we determine the number of recycles from the MSA stack shape. However,
                for training, we must sample the number of recycles upfront, so all GPUs within a distributed batch can sample the same number of recycles.
            loss: Configuration for the loss function; unpacked as keyword arguments into `Loss`. If None (or otherwise falsy, e.g. an empty dict),
                the loss function will not be instantiated and `self.loss` is None.
            metrics: Metrics configuration. Can be:
                - DictConfig/dict with Hydra configs (instantiated internally)
                - Pre-instantiated MetricManager
                - None (no metrics)
            seed: Accepted for signature compatibility; not used by this constructor.
            **kwargs: Forwarded unchanged to `FabricTrainer.__init__`.
        """
        super().__init__(**kwargs)

        # (Initialize recycle schedule upfront so all GPUs can sample the same number of recycles within a batch)
        self.n_recycles_train = n_recycles_train
        self.recycle_schedule = get_recycle_schedule(
            max_cycle=n_recycles_train,
            n_epochs=self.max_epochs,  # Set by FabricTrainer
            n_train=self.n_examples_per_epoch,  # Set by FabricTrainer
            world_size=self.fabric.world_size,
        )  # [n_epochs, n_examples_per_epoch // world_size]

        # Metrics
        if isinstance(metrics, MetricManager):
            # Already instantiated
            self.metrics = metrics
        elif metrics is not None:
            # Hydra config - instantiate
            self.metrics = MetricManager.instantiate_from_hydra(metrics_cfg=metrics)
        else:
            # No metrics
            self.metrics = None

        # Loss
        self.loss = Loss(**loss) if loss else None

        # (Symmetry resolution)
        self.subunit_symm_resolve = SubunitSymmetryResolution()
        self.residue_symm_resolve = ResidueSymmetryResolution()

    def construct_model(self):
        """Construct the model and optionally wrap with EMA.

        Reads the network config from `self.state["train_cfg"].model.net`, wraps the result in `EMA` when
        `self.state["train_cfg"].model.ema` is not None, and registers the model in the trainer state under the key "model".
        """
        # ... instantiate model with Hydra and Fabric
        # NOTE(review): the nesting under `init_module()` was reconstructed from a listing without indentation — confirm
        with self.fabric.init_module():
            ranked_logger.info("Instantiating model...")

            # `_recursive_=False`: nested configs are handed to the network un-instantiated
            model = hydra.utils.instantiate(
                self.state["train_cfg"].model.net,
                _recursive_=False,
            )

            # Optionally, wrap the model with EMA
            if self.state["train_cfg"].model.ema is not None:
                ranked_logger.info("Wrapping model with EMA...")
                model = EMA(model, **self.state["train_cfg"].model.ema)

        self.initialize_or_update_trainer_state({"model": model})

    def _assemble_network_inputs(self, example: dict) -> dict:
        """Assemble and validate the network inputs.

        Args:
            example: A single example; this method reads the keys "coord_atom_lvl_to_be_noised", "noise", "t", "feats" and "example_id".

        Returns:
            A dict with:
                - "X_noisy_L": `coord_atom_lvl_to_be_noised + noise`
                - "t": `example["t"]`
                - "f": `example["feats"]`

        Raises:
            AssertionError: If the NaN check on "X_noisy_L" fails while the model is not in training mode.
                (The NaN check on the features and the shape check are expected to raise as well — presumably
                AssertionError, given how the first check is handled; confirm against `foundry.utils.torch`.)
        """
        assert_same_shape(example["coord_atom_lvl_to_be_noised"], example["noise"])
        network_input = {
            "X_noisy_L": example["coord_atom_lvl_to_be_noised"] + example["noise"],
            "t": example["t"],
            "f": example["feats"],
        }

        try:
            assert_no_nans(
                network_input["X_noisy_L"],
                msg=f"network_input (X_noisy_L) for example_id: {example['example_id']}",
            )
        except AssertionError as e:
            if self.state["model"].training:
                # In some cases, we may indeed have NaNs in the noisy coordinates; we can safely replace them with zeros,
                # and begin noising of those coordinates (which will not have their loss computed) from the origin.
                # Such a situation could occur if there was a chain in the crop with no resolved residues (but that contained resolved
                # residues outside the crop); we then would not be able to resolve the missing coordinates to their "closest resolved neighbor"
                # within the same chain.
                network_input["X_noisy_L"] = torch.nan_to_num(
                    network_input["X_noisy_L"]
                )
                # Surface the original assertion message, but keep training
                ranked_logger.warning(str(e))
            else:
                # During validation, since we do not crop, there should be no NaN's in the coordinates to noise
                # (They were either removed, as is done with fully unresolved chains, or resolved according to our pipeline's rules)
                raise e

        # Unlike the coordinates, NaNs in the features are never tolerated (no fallback here)
        assert_no_nans(
            network_input["f"],
            msg=f"NaN detected in `feats` for example_id: {example['example_id']}",
        )

        return network_input

    def _assemble_loss_extra_info(self, example: dict) -> dict:
        """Assembles metadata arguments to the loss function (incremental to the network inputs and outputs).

        The ground-truth coordinates and mask are repeated along a new leading axis whose size is taken from
        `example["coord_atom_lvl_to_be_noised"].shape[0]`, and then every entry of `example["ground_truth"]` is merged in.
        Because the merge happens last, a ground-truth key named "X_gt_L" or "crd_mask_L" would overwrite the repeated tensors.

        Args:
            example: A single example; reads "coord_atom_lvl_to_be_noised" and "ground_truth".

        Returns:
            A new dict with "X_gt_L", "crd_mask_L" and all keys of `example["ground_truth"]`.
        """
        # ... reshape
        diffusion_batch_size = example["coord_atom_lvl_to_be_noised"].shape[0]
        X_gt_L = repeat(
            example["ground_truth"]["coord_atom_lvl"],
            "l c -> d l c",
            d=diffusion_batch_size,
        )  # [L, 3] -> [D, L, 3] with broadcasting
        crd_mask_L = repeat(
            example["ground_truth"]["mask_atom_lvl"],
            "l -> d l",
            d=diffusion_batch_size,
        )  # [L] -> [D, L] with broadcasting

        loss_extra_info = {
            "X_gt_L": X_gt_L,  # [D, L, 3]
            "crd_mask_L": crd_mask_L,  # [D, L]
        }

        # ... merge with ground_truth key
        loss_extra_info.update(example["ground_truth"])

        return loss_extra_info

    def _assemble_metrics_extra_info(self, example: dict, network_output: dict) -> dict:
        """Prepares the extra info for the metrics.

        Starts from `self._assemble_loss_extra_info(example)` and then adds, in this order (later entries win on key clashes):
        selected ground-truth keys (when present), the example ID, everything under `example["extra_info"]` (when present),
        and finally "metrics_tags" (defaulting to an empty set).

        Args:
            example: A single example.
            network_output: Not read by this implementation; part of the signature so that overrides may use it.

        Returns:
            A new dict with the assembled extra info.
        """
        # We need the same information as for the loss...
        metrics_extra_info = self._assemble_loss_extra_info(example)

        # ... and possibly some additional metadata from the example dictionary
        # TODO: Generalize, so we always use the `extra_info` key, rather than unpacking the ground truth as well
        metrics_extra_info.update(
            {
                # TODO: Remove, instead using `extra_info` for all keys
                **{
                    k: example["ground_truth"][k]
                    for k in [
                        "interfaces_to_score",
                        "pn_units_to_score",
                        "chain_iid_token_lvl",
                    ]
                    if k in example["ground_truth"]
                },
                "example_id": example[
                    "example_id"
                ],  # We require the example ID for logging
                # (From the parser)
                **example.get("extra_info", {}),
            }
        )

        # Record metrics_tags for this example
        metrics_extra_info["metrics_tags"] = example.get("metrics_tags", set())

        # (Create a shallow copy to avoid modifying the original dictionary)
        return {**metrics_extra_info}

    def training_step(
        self,
        batch: Any,
        batch_idx: int,
        is_accumulating: bool,
    ) -> None:
        """Training step, running forward and backward passes.

        Requires `self.loss` to be configured (it is None when no `loss` config was passed to `__init__`).

        Args:
            batch: The current batch; can be of any form. If it is not a dict, its first element is used as the example.
            batch_idx: The index of the current batch; also used to index into the recycle schedule.
            is_accumulating: Whether we are accumulating gradients (i.e., not yet calling optimizer.step()).
                If this is the case, we should skip the synchronization during the backward pass.

        Returns:
            None; we call `loss.backward()` directly, and store the outputs in `self._current_train_return`.
        """
        model = self.state["model"]
        assert model.training, "Model must be training!"

        # Recycling
        # (Number of recycles for the current batch; shared across all GPUs within a distributed batch)
        n_cycle = self.recycle_schedule[self.state["current_epoch"], batch_idx].item()

        with self.fabric.no_backward_sync(model, enabled=is_accumulating):
            # (We assume batch size of 1 for structure predictions)
            example = batch[0] if not isinstance(batch, dict) else batch

            network_input = self._assemble_network_inputs(example)

            # Forward pass (without rollout)
            network_output = model.forward(input=network_input, n_cycle=n_cycle)
            assert_no_nans(
                network_output,
                msg=f"network_output for example_id: {example['example_id']}",
            )

            loss_extra_info = self._assemble_loss_extra_info(example)

            total_loss, loss_dict_batched = self.loss(
                network_input=network_input,
                network_output=network_output,
                # TODO: Rename `loss_input` to `extra_info` to pattern-match metrics
                loss_input=loss_extra_info,
            )

            # Backward pass
            self.fabric.backward(total_loss)

        # ... store the outputs without gradients for use in logging, callbacks, learning rate schedulers, etc.
        self._current_train_return = apply_to_collection(
            {"total_loss": total_loss, "loss_dict": loss_dict_batched},
            dtype=torch.Tensor,
            function=lambda x: x.detach(),
        )

    def validation_step(
        self,
        batch: Any,
        batch_idx: int,
        compute_metrics: bool = True,
    ) -> dict:
        """Validation step, running forward pass and computing validation metrics.

        Args:
            batch: The current batch; can be of any form. If it is not a dict, its first element is used as the example.
            batch_idx: The index of the current batch (not read by this implementation).
            compute_metrics: Whether to compute metrics. If False (or if no metrics were configured), metrics are skipped and
                "metrics_output" is an empty dict.
                Set to False during the inference pipeline, where we need the network output but cannot compute metrics (since we
                do not have the ground truth).

        Returns:
            dict: `{"metrics_output": ..., "network_output": ...}`, with every tensor in both entries detached.
        """
        model = self.state["model"]
        assert not model.training, "Model must be in evaluation mode during validation!"

        example = batch[0] if not isinstance(batch, dict) else batch

        network_input = self._assemble_network_inputs(example)

        # (Checks the full input collection, in addition to the per-key checks done while assembling)
        assert_no_nans(
            network_input,
            msg=f"network_input for example_id: {example['example_id']}",
        )

        # ... forward pass (with rollout)
        # (Note that forward() passes to the EMA/shadow model if the model is not training)
        network_output = model.forward(
            input=network_input,
            n_cycle=example["feats"]["msa_stack"].shape[
                0
            ],  # Determine the number of recycles from the MSA stack shape
            coord_atom_lvl_to_be_noised=example["coord_atom_lvl_to_be_noised"],
        )

        assert_no_nans(
            network_output,
            msg=f"network_output for example_id: {example['example_id']}",
        )

        metrics_output = {}
        if compute_metrics and exists(self.metrics):
            metrics_extra_info = self._assemble_metrics_extra_info(
                example, network_output
            )

            # Symmetry resolution
            # TODO: Refactor such that symmetry returns the ideal coordinate permutation, we apply permutation, and pass adjusted prediction to metrics
            # (without needing to use `extra_info` as we are now)
            # TODO: Update symmetry resolution to be functional (vs. using class variable), take explicit inputs (vs. all from network_output), and use extra_info for the keys it needs
            metrics_extra_info = self.subunit_symm_resolve(
                network_output,
                metrics_extra_info,
                example["symmetry_resolution"],
            )

            metrics_extra_info = self.residue_symm_resolve(
                network_output,
                metrics_extra_info,
                example["automorphisms"],
            )

            metrics_output = self.metrics(
                network_input=network_input,
                network_output=network_output,
                extra_info=metrics_extra_info,
                # (Uses the permuted ground truth after symmetry resolution)
                ground_truth_atom_array_stack=build_stack_from_atom_array_and_batched_coords(
                    metrics_extra_info["X_gt_L"], example.get("atom_array", None)
                ),
                predicted_atom_array_stack=build_stack_from_atom_array_and_batched_coords(
                    network_output["X_L"], example.get("atom_array", None)
                ),
            )

        # Avoid gradients in stored values to prevent memory leaks
        if metrics_output is not None:
            metrics_output = apply_to_collection(
                metrics_output, torch.Tensor, lambda x: x.detach()
            )

        network_output = apply_to_collection(
            network_output, torch.Tensor, lambda x: x.detach()
        )

        return {"metrics_output": metrics_output, "network_output": network_output}
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
class RF3TrainerWithConfidence(RF3Trainer):
    """AF-3 trainer with rollout and confidence prediction.

    Differences from `RF3Trainer` visible in this class:
        - after model construction, every parameter whose name does not contain "model.confidence_head" is frozen,
        - the network inputs and the loss/metrics extra info carry additional confidence-specific entries,
        - `network_output["X_L"]` is replaced by the rollout prediction (`"X_pred_rollout_L"`) before symmetry resolution,
        - `validation_step` accepts an optional early-stopping callback and skips metrics when the output is flagged as early-stopped.
    """

    def construct_model(self) -> None:
        """Construct the model via the parent class, then freeze everything except the confidence head."""
        super().construct_model()

        # Freeze gradients for all modules except the confidence head
        # (Selection is by substring match on the parameter name)
        for name, param in self.state["model"].named_parameters():
            if "model.confidence_head" not in name:
                param.requires_grad = False

    def _assemble_network_inputs(self, example: dict) -> dict:
        """Assemble the base network inputs and add the keys "seq", "rep_atom_idxs" and "frame_atom_idxs".

        The extra entries are read from `example["confidence_feats"]` and `example["ground_truth"]`.
        """
        # assemble the base network inputs...
        network_input = super()._assemble_network_inputs(example)
        # ... and then add the confidence-specific inputs
        network_input.update(
            {
                "seq": example["confidence_feats"]["rf2aa_seq"],
                "rep_atom_idxs": example["ground_truth"]["rep_atom_idxs"],
                "frame_atom_idxs": example["confidence_feats"][
                    "pae_frame_idx_token_lvl_from_atom_lvl"
                ],
            }
        )

        return network_input

    def _assemble_loss_extra_info(self, example: dict) -> dict:
        """Assemble the base loss extra info and add the confidence-specific entries.

        Adds "seq", "atom_frames", "tok_idx", "is_real_atom", "rep_atom_idxs" and "frame_atom_idxs", read from
        `example["confidence_feats"]`, `example["feats"]` and `example["ground_truth"]`.
        """
        # assemble the base loss extra info...
        loss_extra_info = super()._assemble_loss_extra_info(example)
        # ... and then add the confidence-specific inputs
        loss_extra_info.update(
            {
                # TODO: We are duplicating network_input here; we should be able to significantly trim this dictionary
                "seq": example["confidence_feats"]["rf2aa_seq"],
                "atom_frames": example["confidence_feats"]["atom_frames"],
                "tok_idx": example["feats"]["atom_to_token_map"],
                "is_real_atom": example["confidence_feats"]["is_real_atom"],
                "rep_atom_idxs": example["ground_truth"]["rep_atom_idxs"],
                "frame_atom_idxs": example["confidence_feats"][
                    "pae_frame_idx_token_lvl_from_atom_lvl"
                ],
            }
        )

        return loss_extra_info

    def _assemble_metrics_extra_info(self, example: dict, network_output: dict) -> dict:
        """Assemble the base metrics extra info and add "is_real_atom", "is_ligand" and "confidence_loss".

        "confidence_loss" is taken from the training config (`self.state["train_cfg"].trainer.loss.confidence_loss`),
        so a training config with that entry must be present in the trainer state.
        """
        # assemble the base metrics extra info...
        metrics_extra_info = super()._assemble_metrics_extra_info(
            example, network_output
        )
        # ... and then add the confidence-specific inputs
        # TODO: Refactor; we should not need to pass the confidence loss config through metrics extra info, it should be a property of the Metric (e.g., passed at `__init__` using Hydra interpolation from the relevant loss config)
        metrics_extra_info.update(
            {
                "is_real_atom": example["confidence_feats"]["is_real_atom"],
                "is_ligand": example["feats"]["is_ligand"],
                # TODO: Refactor so that we pass the relevant values from the config directly to the Metric upon instantiation (reference in Hydra through interpolation)
                "confidence_loss": self.state["train_cfg"].trainer.loss.confidence_loss,
            }
        )

        return metrics_extra_info

    def training_step(
        self,
        batch: Any,
        batch_idx: int,
        is_accumulating: bool,
    ) -> None:
        """Perform mini-rollout and assess gradient of the confidence head parameters with respect to the confidence loss.

        Args:
            batch: The current batch. If it is not a dict, its first element is used as the example.
            batch_idx: The index of the current batch; also used to index into the recycle schedule.
            is_accumulating: Whether gradients are being accumulated; passed as `enabled` to `fabric.no_backward_sync`.

        Returns:
            None; the detached total loss and loss dict are stored in `self._current_train_return`.
        """
        model = self.state["model"]
        assert model.training, "Model must be training!"

        # Recycling
        # (Number of recycles for the current batch; shared across all GPUs within a distributed batch)
        n_cycle = self.recycle_schedule[self.state["current_epoch"], batch_idx].item()

        with self.fabric.no_backward_sync(model, enabled=is_accumulating):
            # (We assume batch size of 1 for structure predictions)
            example = batch[0] if not isinstance(batch, dict) else batch

            network_input = self._assemble_network_inputs(example)

            # Forward pass (with mini-rollout)
            # NOTE: We use the non-EMA weights for structure prediction; this approach is theoretically sub-optimal, since
            # we should be using the EMA weights for structure prediction (given those parameters are frozen) and the non-EMA weights
            # for the confidence head, to better match the inference-time task
            network_output = model.forward(
                input=network_input,
                n_cycle=n_cycle,
                coord_atom_lvl_to_be_noised=example["coord_atom_lvl_to_be_noised"],
            )
            assert_no_nans(
                network_output,
                msg=f"network_output for example_id: {example['example_id']}",
            )

            loss_extra_info = self._assemble_loss_extra_info(example)

            # Remap X_L to the rollout X_L so ground truth matches rollout batch dimension during the symmetry resolution
            # NOTE: Since `X_L` derives from the mini-rollout, we cannot compute standard training loss and perform gradient updates
            network_output["X_L"] = network_output["X_pred_rollout_L"]

            # (Symmetry resolution)
            loss_extra_info = self.subunit_symm_resolve(
                network_output, loss_extra_info, example["symmetry_resolution"]
            )
            loss_extra_info = self.residue_symm_resolve(
                network_output, loss_extra_info, example["automorphisms"]
            )

            # We only assess the confidence loss
            total_loss, loss_dict_batched = self.loss(
                network_input=network_input,
                network_output=network_output,
                # TODO: Rename `loss_input` to `extra_info` to pattern-match metrics
                loss_input=loss_extra_info,
            )

            # Backward pass
            self.fabric.backward(total_loss)

        # ... store the outputs without gradients for use in logging, callbacks, learning rate schedulers, etc.
        self._current_train_return = apply_to_collection(
            {"total_loss": total_loss, "loss_dict": loss_dict_batched},
            dtype=torch.Tensor,
            function=lambda x: x.detach(),
        )

    def validation_step(
        self,
        batch: Any,
        batch_idx: int,
        compute_metrics: bool = True,
        should_early_stop_fn: ShouldEarlyStopFn | None = None,
    ) -> dict:
        """Validation step, running forward pass (with full rollout) and computing validation metrics, including confidence.

        Args:
            batch: The current batch. If it is not a dict, its first element is used as the example.
            batch_idx: The index of the current batch (not read by this implementation).
            compute_metrics: Whether to compute metrics. Metrics are also skipped when no metrics are configured or when
                `network_output` carries a truthy "early_stopped" entry; in those cases "metrics_output" is an empty dict.
            should_early_stop_fn: Optional callback forwarded unchanged to `model.forward`.

        Returns:
            dict: `{"metrics_output": ..., "network_output": ...}`, with every tensor in both entries detached.
        """
        model = self.state["model"]
        assert not model.training, "Model must be in evaluation mode during validation!"

        example = batch[0] if not isinstance(batch, dict) else batch

        network_input = self._assemble_network_inputs(example)

        assert_no_nans(
            network_input,
            msg=f"network_input for example_id: {example['example_id']}",
        )

        # ... forward pass (with FULL rollout)
        # (Note that forward() passes to the EMA/shadow model if the model is not training)
        network_output = model.forward(
            input=network_input,
            n_cycle=example["feats"]["msa_stack"].shape[
                0
            ],  # Determine the number of recycles from the MSA stack shape
            coord_atom_lvl_to_be_noised=example["coord_atom_lvl_to_be_noised"],
            should_early_stop_fn=should_early_stop_fn,
        )

        assert_no_nans(
            network_output,
            msg=f"network_output for example_id: {example['example_id']}",
        )

        # Remap X_L to the rollout X_L
        # (`.get` is used, so "X_L" becomes None when the output has no "X_pred_rollout_L" entry)
        network_output["X_L"] = network_output.get("X_pred_rollout_L")

        metrics_output = {}
        if (
            compute_metrics
            and exists(self.metrics)
            and not network_output.get("early_stopped", False)
        ):
            # Assemble the base metrics extra info and add confidence-specific inputs
            metrics_extra_info = self._assemble_metrics_extra_info(
                example, network_output
            )

            # Symmetry resolution
            metrics_extra_info = self.subunit_symm_resolve(
                network_output,
                metrics_extra_info,
                example["symmetry_resolution"],
            )

            metrics_extra_info = self.residue_symm_resolve(
                network_output,
                metrics_extra_info,
                example["automorphisms"],
            )

            metrics_output = self.metrics(
                network_input=network_input,
                network_output=network_output,
                extra_info=metrics_extra_info,
                # (Uses the permuted ground truth after symmetry resolution)
                ground_truth_atom_array_stack=build_stack_from_atom_array_and_batched_coords(
                    metrics_extra_info["X_gt_L"], example.get("atom_array", None)
                ),
                predicted_atom_array_stack=build_stack_from_atom_array_and_batched_coords(
                    network_output["X_L"], example.get("atom_array", None)
                ),
            )

        # Avoid gradients in stored values to prevent memory leaks
        if metrics_output is not None:
            metrics_output = apply_to_collection(
                metrics_output, torch.Tensor, lambda x: x.detach()
            )

        # NOTE(review): `network_output` was already subscripted above, so the `None` branch below looks unreachable — confirm
        network_output = (
            apply_to_collection(network_output, torch.Tensor, lambda x: x.detach())
            if network_output is not None
            else None
        )

        return {"metrics_output": metrics_output, "network_output": network_output}
|
rf3/util_module.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def init_lecun_normal(module, scale=1.0):
    """Replace ``module.weight`` with truncated-normal (LeCun-style) samples.

    The new weight is drawn from a normal distribution truncated to ``[-2, 2]`` and
    multiplied by ``sqrt(scale / fan_in) / 0.87962566103423978``, where ``fan_in`` is
    the last dimension of the weight's shape. The divisor is presumably the standard
    deviation of a unit normal truncated to ``[-2, 2]``, so that the resulting weights
    have variance of roughly ``scale / fan_in``.

    Args:
        module: Any object with a ``weight`` tensor attribute (e.g. ``torch.nn.Linear``).
        scale: Variance scaling factor. Previously this argument was accepted but never
            forwarded to the sampler (so it silently behaved as ``1.0``); it is now
            honoured. Calls that rely on the default are unaffected.

    Returns:
        The same ``module``, whose ``weight`` is now a freshly created ``torch.nn.Parameter``.
    """

    def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
        """Map uniform samples in [0, 1) to a normal(mu, sigma) truncated to [a, b] via the inverse CDF."""
        normal = torch.distributions.normal.Normal(0, 1)

        # Truncation bounds expressed in standard-normal units
        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma

        # Rescale the uniform samples into the CDF range [Phi(alpha), Phi(beta)]
        alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
        p = (
            alpha_normal_cdf
            + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
        )

        # Inverse standard-normal CDF: sqrt(2) * erfinv(2p - 1); the clamp keeps erfinv away from +/-1
        v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
        x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
        # Guard against numerical overshoot past the truncation bounds
        x = torch.clamp(x, a, b)

        return x

    def sample_truncated_normal(shape, scale=1.0):
        """Draw a tensor of the given shape with standard deviation of about sqrt(scale / fan_in)."""
        stddev = np.sqrt(scale / shape[-1]) / 0.87962566103423978  # shape[-1] = fan_in
        return stddev * truncated_normal(torch.rand(shape))

    # Forward the caller's `scale` to the sampler (it used to be dropped here)
    module.weight = torch.nn.Parameter(
        sample_truncated_normal(module.weight.shape, scale=scale)
    )
    return module
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# for gradient checkpointing
|
|
33
|
+
def create_custom_forward(module, **kwargs):
    """Return a callable that invokes `module` with `kwargs` pre-bound.

    Helper for gradient checkpointing: the returned function takes positional inputs
    only and forwards them, together with the keyword arguments captured here, to `module`.

    Args:
        module: The callable to wrap.
        **kwargs: Keyword arguments passed to `module` on every call.

    Returns:
        A function `f(*inputs)` equivalent to `module(*inputs, **kwargs)`.
    """
    bound_kwargs = kwargs

    def _call_with_bound_kwargs(*positional):
        return module(*positional, **bound_kwargs)

    return _call_with_bound_kwargs
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def rbf(D, D_min=0.0, D_count=64, D_sigma=0.5):
    """Expand distances into Gaussian radial basis features.

    `D_count` centers are spaced `D_sigma` apart starting at `D_min`; each distance is
    turned into `exp(-((d - center) / D_sigma) ** 2)` for every center.

    Args:
        D: Tensor of distances.
        D_min: Position of the first center.
        D_count: Number of centers (size of the appended last dimension).
        D_sigma: Spacing between centers, also used as the Gaussian width.

    Returns:
        Tensor with the shape of `D` plus one trailing dimension of size `D_count`.
    """
    # Last center, such that consecutive centers are exactly `D_sigma` apart
    upper = D_min + (D_count - 1) * D_sigma
    centers = torch.linspace(D_min, upper, D_count).to(D.device).unsqueeze(0)
    scaled_offset = (D.unsqueeze(-1) - centers) / D_sigma
    return torch.exp(-scaled_offset.pow(2))
|