rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/model/RF3.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
from contextlib import ExitStack
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.utils.checkpoint as checkpoint
|
|
6
|
+
from beartype.typing import Any, Generator, Protocol
|
|
7
|
+
from omegaconf import DictConfig
|
|
8
|
+
from rf3.diffusion_samplers.inference_sampler import (
|
|
9
|
+
SampleDiffusion,
|
|
10
|
+
SamplePartialDiffusion,
|
|
11
|
+
)
|
|
12
|
+
from rf3.model.layers.pairformer_layers import (
|
|
13
|
+
FeatureInitializer,
|
|
14
|
+
)
|
|
15
|
+
from rf3.model.RF3_structure import DiffusionModule, DistogramHead, Recycler
|
|
16
|
+
from torch import nn
|
|
17
|
+
|
|
18
|
+
from foundry.training.checkpoint import create_custom_forward
|
|
19
|
+
|
|
20
|
+
"""
|
|
21
|
+
Shape Annotation Glossary:
|
|
22
|
+
I: # tokens (coarse representation)
|
|
23
|
+
L: # atoms (fine representation)
|
|
24
|
+
M: # msa
|
|
25
|
+
T: # templates
|
|
26
|
+
D: # diffusion structure batch dim
|
|
27
|
+
|
|
28
|
+
C_s: # Token-level single representation channel dimension
|
|
29
|
+
C_z: # Token-level pair representation channel dimension
|
|
30
|
+
C_atom: # Atom-level single representation channel dimension
|
|
31
|
+
C_atompair: # Atom-level pair representation channel dimension
|
|
32
|
+
|
|
33
|
+
Tensor Name Glossary:
|
|
34
|
+
S: Token-level single representation (I, C_s)
|
|
35
|
+
Z: Token-level pair representation (I, I, C_z)
|
|
36
|
+
Q: Atom-level single representation (L, C_atom)
|
|
37
|
+
P: Atom-level pair representation (L, L, C_atompair)
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ShouldEarlyStopFn(Protocol):
    """Structural (duck-typed) protocol for an early-stopping predicate.

    Implementations decide, from the confidence-head outputs computed after the
    first recycle, whether a full inference rollout should be skipped.
    """

    def __call__(
        self, confidence_outputs: dict[str, Any], first_recycle_outputs: dict[str, Any]
    ) -> tuple[bool, dict[str, Any]]:
        """Duck-typed function Protocol for early stopping based on confidence outputs.

        Args:
            confidence_outputs: Outputs of the confidence head run on the first-recycle
                trunk representations (structure omitted).
            first_recycle_outputs: Trunk outputs after the first recycle
                (S_inputs_I, S_init_I, Z_init_II, S_I, Z_II).

        Returns:
            tuple: A pair containing:
                - should_stop (bool): Whether to stop early.
                - additional_data (dict): Metadata for the user, if any
        """
        ...
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class RF3(nn.Module):
    """RF3 Network module.

    We adhere to the PyTorch Lightning Style Guide; see (1).

    References:
        (1) PyTorch Lightning Style Guide: https://lightning.ai/docs/pytorch/latest/starter/style_guide.html
    """

    def __init__(
        self,
        *,
        # Arguments for modules that will be instantiated
        feature_initializer: DictConfig | dict,
        recycler: DictConfig | dict,
        diffusion_module: DictConfig | dict,
        distogram_head: DictConfig | dict,
        inference_sampler: DictConfig | dict,
        # Channel dimensions
        c_s: int,  # AF-3: 384,
        c_z: int,  # AF-3: 128,
        c_atom: int,  # AF-3: 128,
        c_atompair: int,  # AF-3: 16,
        c_s_inputs: int,  # AF-3: 449,
    ):
        """Initializes the AF3 model.

        Args:
            feature_initializer: Arguments for FeatureInitializer
            recycler: Arguments for Recycler
            diffusion_module: Arguments for DiffusionModule
            distogram_head: Arguments for DistogramHead
            inference_sampler: Arguments for the SampleDiffusion class, used for inference (contains no trainable parameters)
            c_s: Token-level single representation channel dimension
            c_z: Token-level pair representation channel dimension
            c_atom: Atom-level single representation channel dimension
            c_atompair: Atom-level pair representation channel dimension
            c_s_inputs: Output dimension of the InputFeatureEmbedder
        """
        super().__init__()

        # ... initialize the FeatureInitializer, which creates the initial token-level representations and conditioning
        self.feature_initializer = FeatureInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_s_inputs=c_s_inputs,
            **feature_initializer,
        )

        # ... initialize the Recycler, which runs the trunk repeatedly with shared weights
        self.recycler = Recycler(c_s=c_s, c_z=c_z, **recycler)
        self.diffusion_module = DiffusionModule(
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_s=c_s,
            c_z=c_z,
            **diffusion_module,
        )
        self.distogram_head = DistogramHead(c_z=c_z, **distogram_head)

        # ... initialize the inference sampler, which performs a full diffusion rollout during inference.
        # A truthy "partial_t" key in the config selects partial (re-)noising instead of a full rollout.
        self.inference_sampler = (
            SampleDiffusion(**inference_sampler)
            if not inference_sampler.get("partial_t", False)
            else SamplePartialDiffusion(**inference_sampler)
        )

    def forward(
        self,
        input: dict,
        n_cycle: int,
        coord_atom_lvl_to_be_noised: torch.Tensor | None = None,
    ) -> dict:
        """Complete forward pass of the model.

        Runs recycling with gradients only on final recycle.

        Args:
            input (dict): Dictionary of model inputs. Must contain "f" (feature dict) and "t"
                (noise-level tensor whose leading dim sets the diffusion batch size); during
                training, also "X_noisy_L".
            n_cycle (int): Number of recycling cycles for the trunk
            coord_atom_lvl_to_be_noised (torch.Tensor | None): Atom-level coordinates to be noised further. Optional;
                only used during inference for partial denoising.

        Returns:
            dict: Dictionary of model outputs, including:
                - X_L: Predicted atomic coordinates [D, L, 3]
                - distogram: Predicted distogram [I, I, C], where C is the number of bins in the distogram
                - If not training, additional lists are returned, each of length T:
                    * X_noisy_L_traj: List of noisy atomic coordinates at each timestep [D, L, 3]
                    * X_denoised_L_traj: List of denoised atomic coordinates at each timestep [D, L, 3]
                    * t_hats: List of tensor scalars representing the noise schedule at each timestep
        """
        # Cast features to lower precision if autocast is enabled
        # (NOTE: mutates input["f"] in place)
        if torch.is_autocast_enabled():
            autocast_dtype = torch.get_autocast_dtype("cuda")
            for x in [
                "msa_stack",
                "profile",
                "template_distogram",
                "template_restype",
                "template_unit_vector",
            ]:
                if x in input["f"]:
                    input["f"][x] = input["f"][x].to(autocast_dtype)

        # ... recycling
        # Gives dictionary of outputs S_inputs_I, S_init_I, Z_init_II, S_I, Z_II
        # (We use `deque` with maxlen=1 to ensure that we only keep the last output in memory)
        try:
            recycling_outputs = deque(
                self.trunk_forward_with_recycling(f=input["f"], n_recycles=n_cycle),
                maxlen=1,
            ).pop()
        except IndexError:
            # Handle the case where the generator is empty (e.g., n_cycle == 0)
            raise RuntimeError("Recycling generator produced no outputs")

        # Predict the distogram from the pair representation
        distogram_pred = self.distogram_head(recycling_outputs["Z_II"])

        # ... post-recycling (diffusion module)
        if self.training:
            # Single denoising step
            X_pred = self.diffusion_module(
                X_noisy_L=input["X_noisy_L"],
                t=input["t"],
                f=input["f"],
                S_inputs_I=recycling_outputs["S_inputs_I"],
                S_trunk_I=recycling_outputs["S_I"],
                Z_trunk_II=recycling_outputs["Z_II"],
            )  # [D, L, 3]
            return dict(
                X_L=X_pred,
                distogram=distogram_pred,
            )
        else:
            # Full diffusion rollout (no gradients, or will OOM)
            sample_diffusion_outs = self.inference_sampler.sample_diffusion_like_af3(
                f=input["f"],
                S_inputs_I=recycling_outputs["S_inputs_I"],
                S_trunk_I=recycling_outputs["S_I"],
                Z_trunk_II=recycling_outputs["Z_II"],
                diffusion_module=self.diffusion_module,
                diffusion_batch_size=input["t"].shape[0],
                coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
            )
            return dict(
                X_L=sample_diffusion_outs["X_L"],
                distogram=distogram_pred,
                # For reporting, inference (validation or testing) only
                X_noisy_L_traj=sample_diffusion_outs["X_noisy_L_traj"],
                X_denoised_L_traj=sample_diffusion_outs["X_denoised_L_traj"],
                t_hats=sample_diffusion_outs["t_hats"],
            )

    def trunk_forward_with_recycling(
        self, f: dict, n_recycles: int
    ) -> Generator[dict[str, torch.Tensor], None, None]:
        """Forward pass of the AF-3 trunk.

        (e.g., the recycling process, including the MSAModule, PairformerStack, etc.).

        Notes:
            - We run with gradients ONLY on the final recycle
            - All recycles use shared weights (ResNet-style)
            - We yield results after each recycle to support use cases such as e.g., early stopping during inference

        Args:
            f: Feature dictionary. Must contain "msa_stack" (one MSA per recycle); the
                "msa" entry is overwritten in place each cycle.
            n_recycles: Number of recycles to run

        Yields:
            dict: Recycling outputs after each cycle, with keys:
                - S_inputs_I: Token-level single representation input, prior to AtomAttention [I, c_s_inputs]
                - S_init_I: Token-level single representation initialization [I, c_s], after AtomAttention but before recycling stack
                - Z_init_II: Token-level pair representation initialization [I, I, c_z], after AtomAttention but before recycling stack
                - S_I: Token-level single representation [I, c_s], after recycling stack
                - Z_II: Token-level pair representation [I, I, c_z], after recycling stack
        """
        # ... initialize the recycling process (feature initialization)
        # Gives S_inputs_I, S_init_I, Z_init_II, S_I, Z_II
        initialized_features = self.pre_recycle(f)

        # ... collect the recycling inputs, which will be updated in place
        recycling_inputs = {**initialized_features, "f": f}

        for i_cycle in range(n_recycles):
            with ExitStack() as stack:
                # For the first n_recycles - 1 cycles (all but the last recycle), we run without gradients
                if i_cycle < n_recycles - 1:
                    stack.enter_context(torch.no_grad())

                # Clear the autocast cache if gradients are enabled (workaround for autocast bug)
                # See: https://github.com/pytorch/pytorch/issues/65766
                if torch.is_grad_enabled():
                    torch.clear_autocast_cache()

                # Select the MSA for the current recycle (we sample an i.i.d. MSA for each recycle)
                recycling_inputs["f"]["msa"] = f["msa_stack"][i_cycle]

                # Run the model trunk (MSAModule, PairformerStack, etc.)
                # We alter the S_I and Z_II in place such that the next iteration uses the updated values
                recycling_inputs = self.recycle(**recycling_inputs)

                # Yield after each recycle
                yield {
                    "S_inputs_I": recycling_inputs["S_inputs_I"],
                    "S_init_I": recycling_inputs["S_init_I"],
                    "Z_init_II": recycling_inputs["Z_init_II"],
                    "S_I": recycling_inputs["S_I"],
                    "Z_II": recycling_inputs["Z_II"],
                }

    def pre_recycle(self, f: dict) -> dict:
        """Prepare feature inputs for recycling.

        Includes:
            - Feature initialization (S_inputs_I, S_init_I, Z_init_II)
            - Initializing S_I and Z_II to zeros

        Args:
            f: Feature dictionary passed to the FeatureInitializer.

        Returns:
            dict: Dictionary of recycling inputs, including:
                - S_inputs_I: Token-level single representation input (prior to AtomAttention) [I, c_s_inputs]
                - S_init_I: Token-level single representation initialization [I, c_s] (after round of AtomAttention)
                - Z_init_II: Token-level pair representation initialization [I, I, c_z] (after round of AtomAttention)
                - S_I: Token-level single representation [I, c_s], initialized to zeros
                - Z_II: Token-level pair representation [I, I, c_z], initialized to zeros
        """
        S_inputs_I, S_init_I, Z_init_II = self.feature_initializer(f)
        S_I = torch.zeros_like(S_init_I)
        Z_II = torch.zeros_like(Z_init_II)

        return dict(
            S_inputs_I=S_inputs_I,
            S_init_I=S_init_I,
            Z_init_II=Z_init_II,
            S_I=S_I,
            Z_II=Z_II,
        )

    def recycle(
        self,
        # TODO: Jax typing (shape-annotated tensor types)
        S_inputs_I: torch.Tensor,
        S_init_I: torch.Tensor,
        Z_init_II: torch.Tensor,
        S_I: torch.Tensor,
        Z_II: torch.Tensor,
        f: dict,
    ) -> dict:
        """Run one recycle of the trunk and return the updated recycling inputs.

        Only S_I and Z_II are updated (by the Recycler); all other entries are
        passed through unchanged so the returned dict can be fed directly back
        into the next call via ``self.recycle(**recycling_inputs)``.
        """
        S_I, Z_II = self.recycler(
            f=f,
            S_inputs_I=S_inputs_I,
            S_init_I=S_init_I,
            Z_init_II=Z_init_II,
            S_I=S_I,
            Z_II=Z_II,
        )
        return dict(
            S_inputs_I=S_inputs_I,
            S_init_I=S_init_I,
            Z_init_II=Z_init_II,
            S_I=S_I,
            Z_II=Z_II,
            f=f,
        )
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
class RF3WithConfidence(RF3):
    """Model for training and inference with confidence metric computation."""

    def __init__(
        self,
        confidence_head: DictConfig | dict,
        mini_rollout_sampler: DictConfig | dict,
        **kwargs,
    ):
        """
        Args:
            (... all arguments from the AF3 class)
            confidence_head: Hydra configuration for the confidence head architecture
            mini_rollout_sampler: Hydra configuration for the mini-rollout sampler (e.g., SampleDiffusion with 20 rather than
                200 timesteps. Note that the `inference_sampler` argument in the AF3 class will still be used for full
                rollouts during inference)
        """
        # (Lazy import)
        from rf3.model.layers.af3_auxiliary_heads import ConfidenceHead  # noqa

        super().__init__(**kwargs)

        self.confidence_head = ConfidenceHead(**confidence_head)
        self.mini_rollout_sampler = SampleDiffusion(**mini_rollout_sampler)

    def forward(
        self,
        input: dict,
        n_cycle: int,
        coord_atom_lvl_to_be_noised: torch.Tensor | None = None,
        should_early_stop_fn: ShouldEarlyStopFn | None = None,
    ) -> dict:
        """Complete forward pass of the model with confidence head.

        Notes:
            - Performs a mini-rollout without gradients during training (e.g., 20 timesteps) and a full rollout (e.g., 200 timesteps) during inference
            - Runs the trunk forward without gradients to conserve memory (which departs from the AF-3 implementation)
            - Runs the forward pass (with gradients) for the confidence model

        Args:
            input (dict): Dictionary of model inputs. In addition to the standard AF-3 model inputs, we expect:
                - rep_atom_idxs: representative-atom indices passed to the confidence head — TODO document exact semantics
                - frame_atom_idxs: frame-atom indices passed to the confidence head — TODO document exact semantics
            n_cycle (int): Number of recycling cycles for the trunk
            coord_atom_lvl_to_be_noised (torch.Tensor | None): Atom-level coordinates to be noised further. Optional;
                only used during inference for partial denoising.
            should_early_stop_fn (ShouldEarlyStopFn | None): Function that takes the confidence and trunk outputs after the first recycle and returns a boolean
                indicating whether to stop early and a dictionary with additional information (e.g., value and threshold).
                If None, no early stopping is performed. Optional; only used during inference.

        Returns:
            dict: Dictionary of model outputs. On early stop: {"early_stopped": True} merged with the
            early-stop metadata. Otherwise includes:
                - early_stopped: False
                - X_L: None (rollout coordinates are returned as X_pred_rollout_L instead)
                - distogram: Predicted distogram from the pair representation
                - X_noisy_L_traj / X_denoised_L_traj / t_hats: rollout trajectories (reporting only)
                - X_pred_rollout_L: Rollout coordinates [D, L, 3]
                - plddt / pae / pde / exp_resolved: confidence-head logits
        """
        # Cast features to lower precision if autocast is enabled
        # (NOTE: mutates input["f"] in place)
        if torch.is_autocast_enabled():
            autocast_dtype = torch.get_autocast_dtype("cuda")
            for x in [
                "msa_stack",
                "profile",
                "template_distogram",
                "template_restype",
                "template_unit_vector",
            ]:
                if x in input["f"]:
                    input["f"][x] = input["f"][x].to(autocast_dtype)

        diffusion_batch_size = input["t"].shape[0]
        with torch.no_grad():
            # ... recycling
            recycling_output_generator = self.trunk_forward_with_recycling(
                f=input["f"], n_recycles=n_cycle
            )
            if should_early_stop_fn:
                assert (
                    not self.training
                ), "Early stopping is not supported during training!"
                # ... get the recycling outputs after the first recycle
                first_recycle_outputs = next(recycling_output_generator)

                # ... compute confidence metrics (without structure)
                confidence_outputs = checkpoint.checkpoint(
                    create_custom_forward(
                        self.confidence_head, frame_atom_idxs=input["frame_atom_idxs"]
                    ),
                    first_recycle_outputs["S_inputs_I"],
                    first_recycle_outputs["S_I"],
                    first_recycle_outputs["Z_II"],
                    None,  # Omit structure
                    input["seq"],
                    input["rep_atom_idxs"],
                    use_reentrant=False,
                )

                should_early_stop, early_stop_data = should_early_stop_fn(
                    confidence_outputs=confidence_outputs,
                    first_recycle_outputs=first_recycle_outputs,
                )
                if should_early_stop:
                    result = {"early_stopped": True}
                    return result | early_stop_data

            # Drain the remaining recycles
            # (We use `deque` with maxlen=1 to ensure that we only keep the last output in memory)
            try:
                recycling_outputs = deque(recycling_output_generator, maxlen=1).pop()
            except IndexError:
                # Handle the case where the generator is empty
                raise RuntimeError("Recycling generator produced no outputs")

            # Predict the distogram from the pair representation
            # (NOTE: Not necessary for confidence head training, but helpful for reporting)
            distogram_pred = self.distogram_head(recycling_outputs["Z_II"])

            # ... post-recycling (diffusion module)
            if self.training:
                # Mini-rollout (no gradients still)
                sample_diffusion_outs = (
                    self.mini_rollout_sampler.sample_diffusion_like_af3(
                        f=input["f"],
                        S_inputs_I=recycling_outputs["S_inputs_I"],
                        S_trunk_I=recycling_outputs["S_I"],
                        Z_trunk_II=recycling_outputs["Z_II"],
                        diffusion_module=self.diffusion_module,
                        diffusion_batch_size=diffusion_batch_size,
                        coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
                    )
                )
            else:
                # Full diffusion rollout (no gradients still)
                sample_diffusion_outs = (
                    self.inference_sampler.sample_diffusion_like_af3(
                        f=input["f"],
                        S_inputs_I=recycling_outputs["S_inputs_I"],
                        S_trunk_I=recycling_outputs["S_I"],
                        Z_trunk_II=recycling_outputs["Z_II"],
                        diffusion_module=self.diffusion_module,
                        diffusion_batch_size=diffusion_batch_size,
                        coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
                    )
                )

        # ... run non-batched confidence head (outside no_grad: this pass carries gradients)
        # TODO: Write a version of the confidence head that splits into batches based on memory available
        # (Currently, we OOM with the full batch size, so we loop, which is slow)
        D = sample_diffusion_outs["X_L"].shape[0]
        confidence_stack = {}
        for i in range(D):
            confidence = checkpoint.checkpoint(
                create_custom_forward(
                    self.confidence_head, frame_atom_idxs=input["frame_atom_idxs"]
                ),
                recycling_outputs["S_inputs_I"],
                recycling_outputs["S_I"],
                recycling_outputs["Z_II"],
                sample_diffusion_outs["X_L"][i].unsqueeze(0),
                input["seq"],
                input["rep_atom_idxs"],
                use_reentrant=False,
            )

            # Accumulate per-structure outputs along the diffusion batch dim
            for k, v in confidence.items():
                if k in confidence_stack:
                    confidence_stack[k] = torch.cat((confidence_stack[k], v), dim=0)
                else:
                    confidence_stack[k] = v
        confidence = confidence_stack

        # NOTE: A single batched confidence-head call (passing sample_diffusion_outs["X_L"]
        # directly) was removed because it used too much memory at training time.

        # TODO: Return outputs in a more structured way (e.g., a dataclass)
        return dict(
            early_stopped=False,
            # We return X_L from diffusion sampling as X_pred_rollout_L to support future joint training with the confidence head (where we would have both X_L and X_pred_rollout_L)
            X_L=None,
            distogram=distogram_pred,
            # For reporting, inference (validation or testing) only
            X_noisy_L_traj=sample_diffusion_outs["X_noisy_L_traj"],
            X_denoised_L_traj=sample_diffusion_outs["X_denoised_L_traj"],
            t_hats=sample_diffusion_outs["t_hats"],
            # Confidence outputs
            X_pred_rollout_L=sample_diffusion_outs["X_L"],
            plddt=confidence["plddt_logits"],
            pae=confidence["pae_logits"],
            pde=confidence["pde_logits"],
            exp_resolved=confidence["exp_resolved_logits"],
        )
|
rf3/model/RF3_blocks.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
from foundry.training.checkpoint import activation_checkpointing
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MSASubsampleEmbedder(nn.Module):
    """Embed a pre-subsampled MSA and add the projected token-level input representation.

    The MSA is subsampled upstream (in the data loader) to bound memory, so this
    module only performs two bias-free linear projections and a sum.
    """

    def __init__(
        self,
        num_sequences: int,
        dim_raw_msa: int,
        c_msa_embed: int,
        c_s_inputs: int,
    ):
        """
        Args:
            num_sequences: Number of MSA sequences expected after subsampling
                (stored for reference; not used in the forward pass).
            dim_raw_msa: Raw per-position MSA feature dimension (e.g., 34 = 32 tokens
                + has_deletion + deletion value).
            c_msa_embed: Output embedding channel dimension.
            c_s_inputs: Channel dimension of the token-level input representation.
        """
        # Modern zero-argument super() (was super(MSASubsampleEmbedder, self))
        super().__init__()
        self.num_sequences = num_sequences
        self.emb_msa = nn.Linear(dim_raw_msa, c_msa_embed, bias=False)
        self.emb_S_inputs = nn.Linear(c_s_inputs, c_msa_embed, bias=False)

    @activation_checkpointing
    def forward(
        self,
        msa_SI,  # (S, I, dim_raw_msa), e.g. (S, I, 34) (32 tokens + has_deletion + deletion value)
        S_inputs,  # (I, c_s_inputs) token-level input representation — TODO confirm leading dims
    ):
        """Return the embedded MSA, (S, I, c_msa_embed).

        Embeds the subsampled MSA and adds the (broadcast) projection of S_inputs.
        (NOTE: We subsample in the data loader to avoid memory issues.)
        """
        msa_SI = self.emb_msa(msa_SI)
        msa_SI = msa_SI + self.emb_S_inputs(S_inputs)
        return msa_SI
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class MSAPairWeightedAverage(nn.Module):
    """MSA pair-weighted averaging with gating.

    Implements Algorithm 10 from the AF3 paper: per-head attention-like weights are
    derived solely from the pair representation (softmax over the second token axis),
    used to average MSA value vectors, and the result is gated and projected back.
    """

    def __init__(
        self,
        c_weighted_average,
        n_heads,
        c_msa_embed,
        c_z,
        separate_gate_for_every_channel,
    ):
        """
        Args:
            c_weighted_average: Per-head channel dimension of the averaged values.
            n_heads: Number of averaging heads.
            c_msa_embed: MSA embedding channel dimension (input and output).
            c_z: Pair representation channel dimension.
            separate_gate_for_every_channel: If True, one sigmoid gate per
                (head, channel); otherwise one gate per head, broadcast over channels.
        """
        # Modern zero-argument super() (was super(MSAPairWeightedAverage, self))
        super().__init__()
        self.weighted_average_channels = c_weighted_average
        self.n_heads = n_heads
        self.msa_channels = c_msa_embed
        self.pair_channels = c_z
        self.norm_msa = nn.LayerNorm(self.msa_channels)
        self.to_v = nn.Linear(
            self.msa_channels, self.n_heads * self.weighted_average_channels, bias=False
        )
        self.norm_pair = nn.LayerNorm(self.pair_channels)
        self.to_bias = nn.Linear(self.pair_channels, self.n_heads, bias=False)

        self.separate_gate_for_every_channel = separate_gate_for_every_channel
        if self.separate_gate_for_every_channel:
            self.to_gate = nn.Linear(
                self.msa_channels,
                self.weighted_average_channels * self.n_heads,
                bias=False,
            )
        else:
            self.to_gate = nn.Linear(self.msa_channels, self.n_heads, bias=False)

        self.to_out = nn.Linear(
            self.weighted_average_channels * self.n_heads, self.msa_channels, bias=False
        )

    @activation_checkpointing
    def forward(self, msa_SI, pair_II):
        """Return the MSA update, shape (S, I, c_msa_embed).

        Args:
            msa_SI: MSA embedding, (S, I, c_msa_embed).
            pair_II: Pair representation, (I, I, c_z).
        """
        S, I = msa_SI.shape[:2]

        # normalize inputs
        msa_SI = self.norm_msa(msa_SI)

        # construct values, bias and weights
        v_SIH = self.to_v(msa_SI).reshape(
            S, I, self.n_heads, self.weighted_average_channels
        )
        bias_IIH = self.to_bias(self.norm_pair(pair_II))
        # Softmax over the second token axis (j): weights for averaging over source tokens
        w_IIH = F.softmax(bias_IIH, dim=-2)

        # construct gate
        gate_SIH = torch.sigmoid(self.to_gate(msa_SI))

        # compute weighted average & apply gate
        if self.separate_gate_for_every_channel:
            # gate has one value per (head, channel): flatten heads before gating
            weights = torch.einsum("ijh,sjhc->sihc", w_IIH, v_SIH).reshape(S, I, -1)
            o_SIH = gate_SIH * weights
        else:
            # gate has one value per head: broadcast over the channel dim
            weights = torch.einsum("ijh,sjhc->sihc", w_IIH, v_SIH)
            o_SIH = gate_SIH[..., None] * weights

        # concatenate heads and project
        msa_update_SI = self.to_out(o_SIH.reshape(S, I, -1))
        return msa_update_SI
|