rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
from rf3.model.layers.af3_diffusion_transformer import (
|
|
6
|
+
AtomAttentionEncoderDiffusion,
|
|
7
|
+
AtomTransformer,
|
|
8
|
+
DiffusionTransformer,
|
|
9
|
+
)
|
|
10
|
+
from rf3.model.layers.layer_utils import Transition, linearNoBias
|
|
11
|
+
from rf3.model.layers.pairformer_layers import (
|
|
12
|
+
MSAModule,
|
|
13
|
+
PairformerBlock,
|
|
14
|
+
RelativePositionEncoding,
|
|
15
|
+
RF3TemplateEmbedder,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from foundry.model.layers.blocks import FourierEmbedding
|
|
19
|
+
from foundry.training.checkpoint import activation_checkpointing
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)

# Glossary of the dimension suffixes used in tensor names in this module:
#   I: number of tokens (coarse representation)
#   L: number of atoms (fine representation)
#   M: MSA dimension
#   T: number of templates
#   D: diffusion structure batch dimension
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class AtomAttentionDecoder(nn.Module):
    """Decode token-level activations into per-atom position updates.

    Each atom receives the activation of its parent token (looked up through
    ``f["atom_to_token_map"]``), projected to the atom width and combined with
    the encoder skip connections. The result is refined by an
    ``AtomTransformer`` and projected to a 3-vector per atom.
    """

    def __init__(self, c_token, c_atom, c_atompair, atom_transformer):
        super().__init__()
        # Submodules are created in a fixed order; this order determines
        # parameter ordering (optimizer state) and initialisation RNG use.
        self.atom_transformer = AtomTransformer(
            c_atom=c_atom,
            c_atompair=c_atompair,
            **atom_transformer,
        )
        self.linear_1 = linearNoBias(c_token, c_atom)
        update_norm = nn.LayerNorm((c_atom,))
        update_proj = linearNoBias(c_atom, 3)
        self.to_r_update = nn.Sequential(update_norm, update_proj)

    def forward(
        self,
        f,  # feature dict; must contain "atom_to_token_map"
        Ai,  # token activations, [..., I, C_token]
        Ql_skip,  # atom query skip connection, [..., L, C_atom]
        Cl_skip,  # atom conditioning skip connection, [..., L, C_atom]
        Plm_skip,  # atom-pair skip connection, [..., L, L, C_atompair]
    ):
        """Return per-atom position updates (last dimension of size 3)."""
        atom_to_token = f["atom_to_token_map"]

        @activation_checkpointing
        def _decode(tokens, q_skip, c_skip, p_skip, token_index):
            # Broadcast per-token activations to their atoms, then add the skip.
            queries = self.linear_1(tokens[..., token_index, :]) + q_skip
            # Sequence-local attention over atoms.
            queries = self.atom_transformer(queries, c_skip, p_skip)
            # Project each atom to a 3D update.
            return self.to_r_update(queries)

        return _decode(Ai, Ql_skip, Cl_skip, Plm_skip, atom_to_token)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class DiffusionModule(nn.Module):
    """Denoising network used by the diffusion sampler.

    ``f_pred`` selects how noisy coordinates enter the network and how the
    network output is turned into coordinates:

    * ``"edm"``: inputs are scaled by ``1 / sqrt(t**2 + sigma_data**2)``. The
      output is ``c_skip * X_noisy + c_out * update`` with
      ``c_skip = sigma_data**2 / (sigma_data**2 + t**2)`` and
      ``c_out = sigma_data * t / sqrt(sigma_data**2 + t**2)``.
    * ``"unconditioned"``: coordinates are replaced by zeros on input, and the
      raw network update is returned.
    * ``"noise_pred"``: raw coordinates in; ``X_noisy + update`` out.

    Any other value raises ``Exception`` when ``forward`` is called.
    """

    def __init__(
        self,
        sigma_data,
        c_atom,
        c_atompair,
        c_token,
        c_s,
        c_z,
        f_pred,
        diffusion_conditioning,
        atom_attention_encoder,
        diffusion_transformer,
        atom_attention_decoder,
    ):
        super().__init__()
        # Scalar hyper-parameters.
        self.sigma_data = sigma_data
        self.c_atom = c_atom
        self.c_atompair = c_atompair
        self.c_token = c_token
        self.c_s = c_s
        self.f_pred = f_pred

        # Submodules, in a fixed creation order (parameter order / init RNG).
        self.diffusion_conditioning = DiffusionConditioning(
            sigma_data=sigma_data,
            c_s=c_s,
            c_z=c_z,
            **diffusion_conditioning,
        )
        self.atom_attention_encoder = AtomAttentionEncoderDiffusion(
            c_token=c_token,
            c_s=c_s,
            c_atom=c_atom,
            c_atompair=c_atompair,
            **atom_attention_encoder,
        )
        s_norm = nn.LayerNorm((c_s,))
        s_proj = linearNoBias(c_s, c_token)
        self.process_s = nn.Sequential(s_norm, s_proj)
        self.diffusion_transformer = DiffusionTransformer(
            c_token=c_token,
            c_s=c_s,
            c_tokenpair=c_z,
            **diffusion_transformer,
        )
        self.layer_norm_1 = nn.LayerNorm(c_token)
        self.atom_attention_decoder = AtomAttentionDecoder(
            c_token=c_token,
            c_atom=c_atom,
            c_atompair=c_atompair,
            **atom_attention_decoder,
        )

    def forward(
        self,
        X_noisy_L,  # [B, L, 3] noisy coordinates
        t,  # [B] noise level (0 is ground truth)
        f,  # input feature dictionary
        S_inputs_I,  # [B, I, C_S_input]
        S_trunk_I,  # [B, I, C_S_trunk]
        Z_trunk_II,  # [B, I, I, C_Z]
    ):
        """Return denoised coordinates with the same shape as ``X_noisy_L``."""
        mode = self.f_pred
        sigma = self.sigma_data

        # 1) Single / pair conditioning, computed in fp32.
        S_I, Z_II = self.diffusion_conditioning(
            t, f, S_inputs_I.float(), S_trunk_I.float(), Z_trunk_II.float()
        )

        # 2) Bring the noisy coordinates into the network's input space.
        if mode == "edm":
            R_in_L = X_noisy_L / torch.sqrt(t[..., None, None] ** 2 + sigma**2)
        elif mode == "unconditioned":
            R_in_L = torch.zeros_like(X_noisy_L)
        elif mode == "noise_pred":
            R_in_L = X_noisy_L
        else:
            raise Exception(f"{self.f_pred=} unrecognized")

        # 3) Atom encoder -> token-level transformer -> atom decoder.
        A_I, Q_skip_L, C_skip_L, P_skip_LL = self.atom_attention_encoder(
            f, R_in_L, S_trunk_I.float(), Z_II
        )
        A_I = self.diffusion_transformer(
            A_I + self.process_s(S_I), S_I, Z_II, Beta_II=None
        )
        A_I = self.layer_norm_1(A_I)
        R_update_L = self.atom_attention_decoder(
            f, A_I.float(), Q_skip_L, C_skip_L, P_skip_LL
        )

        # 4) Combine the network update with the noisy input.
        if mode == "edm":
            denom = sigma**2 + t**2
            c_skip = (sigma**2 / denom)[..., None, None]
            c_out = (sigma * t / denom**0.5)[..., None, None]
            return c_skip * X_noisy_L + c_out * R_update_L
        if mode == "unconditioned":
            return R_update_L
        if mode == "noise_pred":
            return X_noisy_L + R_update_L
        raise Exception(f"{self.f_pred=} unrecognized")
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class DiffusionConditioning(nn.Module):
    """Single (S_I) and pair (Z_II) conditioning for the diffusion module.

    Pair conditioning concatenates the trunk pair representation with a
    relative position encoding, projects back to ``c_z`` and applies two
    residual transitions. Single conditioning concatenates the trunk and input
    single features, projects to ``c_s``, adds an embedding of the noise level
    ``t`` and applies two residual transitions.
    """

    # Width of the Fourier noise-level embedding. It has always been fixed to
    # this value regardless of the ``c_t_embed`` argument (presumably to stay
    # compatible with trained weights -- TODO confirm before making it
    # configurable).
    _FOURIER_EMBED_DIM = 256

    def __init__(
        self, sigma_data, c_z, c_s, c_s_inputs, c_t_embed, relative_position_encoding
    ):
        """
        Args:
            sigma_data: data standard deviation used to normalise ``t``.
            c_z: pair representation width.
            c_s: single representation width.
            c_s_inputs: width of the input single features ``S_inputs_I``.
            c_t_embed: requested Fourier-embedding width. NOTE: this value is
                not used; the width is always ``_FOURIER_EMBED_DIM`` (256). A
                warning is logged when a different value is requested, instead
                of silently ignoring it.
            relative_position_encoding: keyword arguments for
                ``RelativePositionEncoding``.
        """
        super().__init__()
        self.sigma_data = sigma_data
        self.relative_position_encoding = RelativePositionEncoding(
            c_z=c_z, **relative_position_encoding
        )
        # Operates on concat(Z_trunk_II [..., c_z], relative position encoding [..., c_z]).
        self.to_zii = nn.Sequential(
            nn.LayerNorm(c_z * 2),
            linearNoBias(c_z * 2, c_z),
        )
        self.transition_1 = nn.ModuleList(
            [Transition(c=c_z, n=2) for _ in range(2)]
        )
        self.to_si = nn.Sequential(
            nn.LayerNorm(c_s + c_s_inputs), linearNoBias(c_s + c_s_inputs, c_s)
        )

        if c_t_embed != self._FOURIER_EMBED_DIM:
            logger.warning(
                "DiffusionConditioning: ignoring c_t_embed=%s; the Fourier noise "
                "embedding width is fixed at %d.",
                c_t_embed,
                self._FOURIER_EMBED_DIM,
            )
        c_t_embed = self._FOURIER_EMBED_DIM
        self.fourier_embedding = FourierEmbedding(c_t_embed)
        self.process_n = nn.Sequential(
            nn.LayerNorm(c_t_embed), linearNoBias(c_t_embed, c_s)
        )
        self.transition_2 = nn.ModuleList(
            [Transition(c=c_s, n=2) for _ in range(2)]
        )

    def forward(self, t, f, S_inputs_I, S_trunk_I, Z_trunk_II):
        """Return the ``(S_I, Z_II)`` conditioning tensors.

        Args:
            t: noise level(s).
            f: feature dictionary passed to the relative position encoding.
            S_inputs_I: input single features.
            S_trunk_I: trunk single representation.
            Z_trunk_II: trunk pair representation.
        """
        # Pair conditioning input: trunk pair features + relative positions.
        Z_II = torch.cat([Z_trunk_II, self.relative_position_encoding(f)], dim=-1)

        @activation_checkpointing
        def _run_conditioning(Z_II, S_trunk_I, S_inputs_I):
            Z_II = self.to_zii(Z_II)
            for transition in self.transition_1:
                Z_II = Z_II + transition(Z_II)

            # Single conditioning.
            S_I = self.to_si(torch.cat([S_trunk_I, S_inputs_I], dim=-1))
            # NOTE(review): torch.log(0) is -inf, so t == 0 would feed -inf into
            # the Fourier embedding; callers presumably pass t > 0 -- confirm.
            N_D = self.fourier_embedding(1 / 4 * torch.log(t / self.sigma_data))
            # Broadcast the per-sample noise embedding over the token axis.
            S_I = self.process_n(N_D).unsqueeze(-2) + S_I
            for transition in self.transition_2:
                S_I = S_I + transition(S_I)

            return S_I, Z_II

        return _run_conditioning(Z_II, S_trunk_I, S_inputs_I)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class DistogramHead(nn.Module):
    """Project pair features to distogram logits.

    Pair features are symmetrised (``Z_ij + Z_ji``) before the projection, so
    the logits are symmetric in the two token axes. The projection is
    zero-initialised, so all logits start at zero.
    """

    def __init__(self, c_z, bins):
        super().__init__()
        self.predictor = nn.Linear(c_z, bins)
        self.reset_parameters()

    def reset_parameters(self):
        """Zero the weight and bias of the final logit projection."""
        for param in (self.predictor.weight, self.predictor.bias):
            nn.init.zeros_(param)

    def forward(self, Z_II):
        """Return logits of shape ``Z_II.shape[:-1] + (bins,)``."""
        Z_sym = Z_II + Z_II.transpose(-2, -3)
        return self.predictor(Z_sym)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class Recycler(nn.Module):
    """One recycling pass of the trunk.

    Re-injects the previous single/pair representations on top of their
    initial values, adds template and MSA information to the pair track, and
    runs a stack of Pairformer blocks.
    """

    def __init__(
        self,
        c_s,
        c_z,
        template_embedder,
        msa_module,
        n_pairformer_blocks,
        pairformer_block,
    ):
        super().__init__()
        self.c_z = c_z
        # Submodules, in a fixed creation order (parameter order / init RNG).
        self.process_zh = nn.Sequential(nn.LayerNorm(c_z), linearNoBias(c_z, c_z))
        self.template_embedder = RF3TemplateEmbedder(c_z=c_z, **template_embedder)
        self.msa_module = MSAModule(**msa_module)
        self.process_sh = nn.Sequential(nn.LayerNorm(c_s), linearNoBias(c_s, c_s))
        blocks = [
            PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
            for _ in range(n_pairformer_blocks)
        ]
        self.pairformer_stack = nn.ModuleList(blocks)

    def forward(
        self,
        f,
        S_inputs_I,
        S_init_I,
        Z_init_II,
        S_I,
        Z_II,
    ):
        """Return the updated ``(S_I, Z_II)`` after one recycling pass."""
        # Pair track: initial pair features plus the processed previous recycle.
        pair = Z_init_II + self.process_zh(Z_II)
        pair = pair + self.template_embedder(f, pair)
        # NOTE: Implementing bugfix from the Protenix Technical report, where residual-connecting the MSA module is redundant
        # Reference: https://github.com/bytedance/Protenix/blob/main/Protenix_Technical_Report.pdf
        pair = self.msa_module(f, pair, S_inputs_I)

        # Single track: initial single features plus the processed previous recycle.
        single = S_init_I + self.process_sh(S_I)

        for block in self.pairformer_stack:
            single, pair = block(single, pair)
        return single, pair
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from rf3.model.RF3_structure import PairformerBlock, linearNoBias
|
|
5
|
+
|
|
6
|
+
# TODO: Get from RF2AA encoding instead
# Minimal legacy chemistry constants used by this module:
#   NHEAVY - number of per-token atom slots; sizes the pLDDT and
#            experimentally-resolved logit heads.
#   aa2num - legacy residue-type indices (UNK, GLY, MAS) whose C-beta
#            deviation feature is zeroed in calc_Cb_distances.
CHEM_DATA_LEGACY = dict(
    NHEAVY=23,
    aa2num=dict(UNK=20, GLY=7, MAS=21),
)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def discretize_distance_matrix(
    distance_matrix, num_bins=38, min_distance=3.25, max_distance=50.75
):
    """Map each distance to an integer bin index in ``[0, num_bins]``.

    The bin edges are ``num_bins`` evenly spaced values starting at
    ``min_distance`` with step ``(max_distance - min_distance) / num_bins``, so
    the last edge is ``max_distance - step``. With ``torch.bucketize``'s
    default (``right=False``), a distance ``d`` gets index ``i`` such that
    ``edges[i - 1] < d <= edges[i]``. Distances at or below the first edge map
    to 0, and distances beyond the last edge map to ``num_bins``. The result
    therefore has ``num_bins + 1`` possible classes.
    """
    step = (max_distance - min_distance) / num_bins
    edges = torch.arange(num_bins, device=distance_matrix.device) * step + min_distance
    return torch.bucketize(distance_matrix, edges)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ConfidenceHead(nn.Module):
    """Confidence head ("Algorithm 31").

    Embeds the detached trunk representations together with binned distances
    between predicted representative atoms, refines them with a small
    Pairformer stack, and linearly projects to PDE, PAE, pLDDT and
    experimentally-resolved logits.

    ``use_af3_style_binning_and_final_layer_norms`` selects between two variants:

    * ``False``:
      - 10 distance bins (+1 overflow class);
      - a residual connection around the Pairformer stack;
      - no final layer norms;
      - PDE features are symmetrised before projection.
    * ``True``:
      - 39 distance bins (+1 overflow class);
      - no residual;
      - a LayerNorm before each output projection;
      - PDE logits are symmetrised after projection.
    """

    def __init__(
        self,
        c_s,
        c_z,
        n_pairformer_layers,
        pairformer,
        n_bins_pae,
        n_bins_pde,
        n_bins_plddt,
        n_bins_exp_resolved,
        use_Cb_distances=False,
        use_af3_style_binning_and_final_layer_norms=False,
        symmetrize_Cb_logits=True,
        layer_norm_along_feature_dimension=False,
        c_s_inputs=449,
    ):
        """
        Args:
            c_s: single representation width.
            c_z: pair representation width.
            n_pairformer_layers: number of Pairformer blocks in the stack.
            pairformer: extra keyword arguments for each ``PairformerBlock``.
            n_bins_pae: number of PAE logit bins.
            n_bins_pde: number of PDE logit bins.
            n_bins_plddt: number of pLDDT bins per atom slot. The head emits
                ``CHEM_DATA_LEGACY["NHEAVY"]`` slots per token.
            n_bins_exp_resolved: number of experimentally-resolved bins per
                atom slot.
            use_Cb_distances: also embed the deviation of each token's
                representative atom from an ideal C-beta position.
            use_af3_style_binning_and_final_layer_norms: see class docstring.
            symmetrize_Cb_logits: add the C-beta embedding along both token
                axes of the pair representation. Otherwise it is added as a
                term that depends only on the second token index.
            layer_norm_along_feature_dimension: normalise inputs over the last
                (feature) dimension only. Otherwise (legacy) normalise over the
                whole tensor.
            c_s_inputs: width of ``S_inputs_I``. Defaults to 449, the value
                that used to be hard-coded.
        """
        super().__init__()
        # Two independent projections of S_inputs_I whose outer sum seeds Z.
        # (Submodule creation order is kept stable: parameter order / init RNG.)
        self.process_s_inputs_right = linearNoBias(c_s_inputs, c_z)
        self.process_s_inputs_left = linearNoBias(c_s_inputs, c_z)
        self.use_af3_style_binning_and_final_layer_norms = (
            use_af3_style_binning_and_final_layer_norms
        )
        self.layer_norm_along_feature_dimension = layer_norm_along_feature_dimension
        if self.use_af3_style_binning_and_final_layer_norms:
            self.layernorm_pde = nn.LayerNorm(c_z)
            self.layernorm_pae = nn.LayerNorm(c_z)
            self.layernorm_plddt = nn.LayerNorm(c_s)
            self.layernorm_exp_resolved = nn.LayerNorm(c_s)
            # 39 distance bins + 1 overflow class.
            self.process_pred_distances = linearNoBias(40, c_z)
        else:
            # 10 distance bins + 1 overflow class.
            self.process_pred_distances = linearNoBias(11, c_z)

        self.pairformer = nn.ModuleList(
            [
                PairformerBlock(c_s=c_s, c_z=c_z, **pairformer)
                for _ in range(n_pairformer_layers)
            ]
        )

        self.predict_pae = linearNoBias(c_z, n_bins_pae)
        self.predict_pde = linearNoBias(c_z, n_bins_pde)
        self.predict_plddt = linearNoBias(
            c_s, CHEM_DATA_LEGACY["NHEAVY"] * n_bins_plddt
        )
        self.predict_exp_resolved = linearNoBias(
            c_s, CHEM_DATA_LEGACY["NHEAVY"] * n_bins_exp_resolved
        )
        self.use_Cb_distances = use_Cb_distances
        if self.use_Cb_distances:
            # 24 C-beta deviation bins + 1 overflow class.
            self.process_Cb_distances = linearNoBias(25, c_z)
        self.symmetrize_Cb_logits = symmetrize_Cb_logits

    def reset_parameters(self):
        """Xavier-uniform init for every Linear weight; zero any biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def _normalize_inputs(S_trunk_I, Z_trunk_II, S_inputs_I, along_feature_dimension):
        """Apply a parameter-free layer norm to the three input tensors.

        With ``along_feature_dimension`` each tensor is normalised over its
        last dimension only. Otherwise (legacy default) each tensor is
        normalised over *all* of its dimensions, batch and token axes included.
        """
        tensors = (S_trunk_I, Z_trunk_II, S_inputs_I)
        if along_feature_dimension:
            # F.layer_norm expects a sequence of ints (List[int]). The previous
            # ``(x.shape[-1])`` was a bare int, not a 1-tuple.
            return tuple(
                F.layer_norm(x, normalized_shape=(x.shape[-1],)) for x in tensors
            )
        return tuple(F.layer_norm(x, normalized_shape=x.shape) for x in tensors)

    def forward(
        self,
        S_inputs_I,
        S_trunk_I,
        Z_trunk_II,
        X_pred_L,
        seq,
        rep_atoms,
        frame_atom_idxs=None,
    ):
        """Compute confidence logits.

        Args:
            S_inputs_I: input single features, last dim ``c_s_inputs``.
            S_trunk_I: trunk single representation.
            Z_trunk_II: trunk pair representation.
            X_pred_L: predicted atom coordinates ``[B, n_atoms, 3]``, or None to
                skip all distance-based embeddings.
            seq: per-token residue types (legacy encoding), used for C-beta
                validity.
            rep_atoms: per-token representative-atom indices into the atom
                axis of ``X_pred_L``.
            frame_atom_idxs: per-token (N, CA, C) atom indices. Required when
                ``use_Cb_distances`` is enabled.

        Returns:
            dict with ``pde_logits``, ``pae_logits``, ``plddt_logits`` and
            ``exp_resolved_logits``.
        """
        # Stop gradients into the trunk and the predicted structure.
        # NOTE(review): an earlier comment (citing section 4.3.5) said
        # S_inputs_I should NOT be detached, but it is detached here. The
        # behaviour is kept as-is -- confirm which is intended.
        S_trunk_I = S_trunk_I.detach().float()
        Z_trunk_II = Z_trunk_II.detach().float()
        if X_pred_L is not None:
            X_pred_L = X_pred_L.detach().float()
        S_inputs_I = S_inputs_I.detach().float()
        seq = seq.detach()

        S_trunk_I, Z_trunk_II, S_inputs_I = self._normalize_inputs(
            S_trunk_I,
            Z_trunk_II,
            S_inputs_I,
            self.layer_norm_along_feature_dimension,
        )

        # Outer *sum* of two projections: Z[..., i, j, :] += right_i + left_j.
        # Both projections are independent learned layers, so which one is
        # attached to i vs j is a labelling choice.
        S_inputs_I_right = self.process_s_inputs_right(S_inputs_I)
        S_inputs_I_left = self.process_s_inputs_left(S_inputs_I)
        Z_trunk_II = Z_trunk_II + (
            S_inputs_I_right.unsqueeze(-2) + S_inputs_I_left.unsqueeze(-3)
        )

        # Distance-based embeddings need coordinates; skip them when none are given.
        if X_pred_L is not None:
            X_pred_rep_I = X_pred_L.index_select(1, rep_atoms)
            dist = torch.cdist(X_pred_rep_I, X_pred_rep_I)
            if self.use_af3_style_binning_and_final_layer_norms:
                # 39 evenly spaced edges from 3.25 (step 47.5 / 39), as in the
                # published code, plus one overflow class.
                binned = discretize_distance_matrix(
                    dist, min_distance=3.25, max_distance=50.75, num_bins=39
                )
                num_classes = 40
            else:
                # Edges 3.375, 5.125, ..., 19.125 (step 1.75), plus one overflow class.
                binned = discretize_distance_matrix(
                    dist, min_distance=3.375, max_distance=20.875, num_bins=10
                )
                num_classes = 11
            dist_one_hot = F.one_hot(binned, num_classes=num_classes)
            Z_trunk_II = Z_trunk_II + self.process_pred_distances(dist_one_hot.float())

            if self.use_Cb_distances:
                # Embed the deviation of the representative atom from an ideal C-beta.
                Cb_distances = calc_Cb_distances(
                    X_pred_L, seq, rep_atoms, frame_atom_idxs
                )
                Cb_distances_one_hot = F.one_hot(
                    discretize_distance_matrix(
                        Cb_distances,
                        min_distance=0.0001,
                        max_distance=0.25,
                        num_bins=24,
                    ),
                    num_classes=25,
                )
                Cb_logits = self.process_Cb_distances(Cb_distances_one_hot.float())
                if self.symmetrize_Cb_logits:
                    # Z[:, i, j] += Cb_j + Cb_i
                    Cb_logits = Cb_logits[:, None, :, :] + Cb_logits[:, :, None, :]
                else:
                    # Z[:, i, j] += Cb_j
                    Cb_logits = Cb_logits[:, None, :, :]
                Z_trunk_II = Z_trunk_II + Cb_logits

        if not self.use_af3_style_binning_and_final_layer_norms:
            S_trunk_residual_I = S_trunk_I.clone()
            Z_trunk_residual_II = Z_trunk_II.clone()

        for block in self.pairformer:
            S_trunk_I, Z_trunk_II = block(S_trunk_I, Z_trunk_II)

        if not self.use_af3_style_binning_and_final_layer_norms:
            # The pseudocode adds this residual back; the published AF3 code does not.
            S_trunk_I = S_trunk_residual_I + S_trunk_I
            Z_trunk_II = Z_trunk_residual_II + Z_trunk_II

            # BUG: needs to be symmetrized correctly -- this symmetrises the
            # features before projection. It is kept as-is here, presumably
            # because existing weights were trained this way -- confirm.
            pde_logits = self.predict_pde(Z_trunk_II + Z_trunk_II.transpose(-2, -3))
            pae_logits = self.predict_pae(Z_trunk_II)
            plddt_logits = self.predict_plddt(S_trunk_I)
            exp_resolved_logits = self.predict_exp_resolved(S_trunk_I)
        else:
            # Published-code variant: no residual, extra layer norms before the
            # projections, and PDE symmetrised after the projection.
            left_distance_logits = self.predict_pde(self.layernorm_pde(Z_trunk_II))
            right_distance_logits = left_distance_logits.transpose(-2, -3)
            pde_logits = left_distance_logits + right_distance_logits

            pae_logits = self.predict_pae(self.layernorm_pae(Z_trunk_II))
            plddt_logits = self.predict_plddt(self.layernorm_plddt(S_trunk_I))
            exp_resolved_logits = self.predict_exp_resolved(
                self.layernorm_exp_resolved(S_trunk_I)
            )

        return dict(
            pde_logits=pde_logits,
            pae_logits=pae_logits,
            plddt_logits=plddt_logits,
            exp_resolved_logits=exp_resolved_logits,
        )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def calc_Cb_distances(X_pred_L, seq, rep_atoms, frame_atom_idxs):
    """Distance between each token's representative atom and its ideal C-beta.

    The ideal C-beta is built from the token's N, CA and C atoms with fixed
    coefficients.

    Args:
        X_pred_L: predicted coordinates ``[B, n_atoms, 3]``.
        seq: per-token residue types in the legacy encoding.
        rep_atoms: per-token representative-atom indices into the atom axis.
        frame_atom_idxs: 2D per-token index tensor whose columns 0, 1, 2 are
            the N, CA and C atom indices. Shared across the batch.

    Returns:
        ``[B, I]`` distances. Entries for tokens without a valid C-beta are 0.

    Raises:
        ValueError: if ``frame_atom_idxs`` is None.
    """
    if frame_atom_idxs is None:
        raise ValueError(
            "calc_Cb_distances requires frame_atom_idxs (per-token N/CA/C atom indices)"
        )

    frame_atom_idxs = frame_atom_idxs.unsqueeze(0).expand(X_pred_L.shape[0], -1, -1)

    def _gather_frame_atom(column):
        # Per-token coordinates of the frame atom stored in `column`.
        idx = frame_atom_idxs[..., column].unsqueeze(-1).expand(-1, -1, 3)
        return torch.gather(X_pred_L, 1, idx)

    N = _gather_frame_atom(0)
    Ca = _gather_frame_atom(1)
    C = _gather_frame_atom(2)
    Cb = X_pred_L.index_select(1, rep_atoms)

    # CHEM_DATA_LEGACY is a plain dict, so use item access (attribute access
    # such as `.aa2num` raises AttributeError).
    aa2num = CHEM_DATA_LEGACY["aa2num"]
    is_valid_Cb = (
        (seq != aa2num["UNK"]) & (seq != aa2num["GLY"]) & (seq != aa2num["MAS"])
    )

    # Legacy behaviour: this is a single scalar flag. If ANY token lies outside
    # [0, 20), no token gets a C-beta feature.
    is_protein = (seq >= 0).all() & (seq < 20).all()
    is_valid_Cb = is_valid_Cb & is_protein

    b = Ca - N
    c = C - Ca
    a = torch.cross(b, c, dim=-1)

    ideal_Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca

    Cb_distances = torch.norm(Cb - ideal_Cb, dim=-1)
    Cb_distances[:, ~is_valid_Cb] = 0.0

    return Cb_distances
|