rc-foundry 0.1.1 (rc_foundry-0.1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
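
To cross-check this listing against a local copy of the wheel, the standard library is enough. The path below is hypothetical (the wheel filename is inferred from the dist-info name in this diff; the wheel itself is not part of the diff):

```python
import zipfile

# Hypothetical local path; adjust to wherever you downloaded the wheel.
wheel_path = "rc_foundry-0.1.1-py3-none-any.whl"

with zipfile.ZipFile(wheel_path) as whl:
    for info in whl.infolist():
        # Print uncompressed size and archive path for each member.
        print(f"{info.file_size:>9}  {info.filename}")
```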
rf3/model/layers/attention.py
ADDED
@@ -0,0 +1,313 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from opt_einsum import contract as einsum
from rf3.util_module import init_lecun_normal

from foundry import SHOULD_USE_CUEQUIVARIANCE
from foundry.training.checkpoint import activation_checkpointing

if SHOULD_USE_CUEQUIVARIANCE:
    import cuequivariance_torch as cuet


class TriangleAttention(nn.Module):
    """Implementation of Triangle Attention from AlphaFold3.

    Routes to either cuEquivariance or vanilla implementation based on configuration.

    Args:
        pair: Pair representation tensor of shape (B, L, L, d_pair)

    Returns:
        Updated pair representation tensor of shape (B, L, L, d_pair)

    """

    def __init__(
        self,
        d_pair,
        n_head=4,
        d_hidden=32,
        p_drop=0.1,  # noqa: E402
        start_node=True,
        use_cuequivariance=True,
    ):
        super(TriangleAttention, self).__init__()

        self.norm = nn.LayerNorm(d_pair)

        self.to_q = nn.Linear(d_pair, n_head * d_hidden, bias=False)
        self.to_k = nn.Linear(d_pair, n_head * d_hidden, bias=False)
        self.to_v = nn.Linear(d_pair, n_head * d_hidden, bias=False)
        self.to_b = nn.Linear(d_pair, n_head, bias=False)
        self.to_g = nn.Linear(d_pair, n_head * d_hidden)
        self.to_out = nn.Linear(n_head * d_hidden, d_pair)

        self.scaling = 1 / math.sqrt(d_hidden)

        self.h = n_head
        self.dim = d_hidden
        self.start_node = start_node

        self.use_cuequivariance = use_cuequivariance

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projection: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # bias: normal distribution
        self.to_b = init_lecun_normal(self.to_b)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out: right before residual connection: zero initialize -- so the residual update
        # is the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    @activation_checkpointing
    def forward(self, pair):
        """Forward pass of triangle attention."""
        pair = self.norm(pair)
        bias = self.to_b(pair)  # (B, L, L, h)

        if not self.start_node:
            pair = rearrange(pair, "b i j d -> b j i d")

        # Route to appropriate implementation
        if self.use_cuequivariance and SHOULD_USE_CUEQUIVARIANCE:
            out = self._forward_cuequivariance(pair, bias)
        else:
            out = self._forward_vanilla(pair, bias)

        if not self.start_node:
            out = rearrange(out, "b i j d -> b j i d")

        # output projection
        out = self.to_out(out)
        return out

    def _forward_cuequivariance(self, pair, bias):
        """cuEquivariance triangle attention implementation."""
        # Handle autocast conversion
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_dtype("cuda")
            pair = pair.to(dtype=dtype)
            bias = bias.to(dtype=dtype)

        assert (
            pair.dtype == torch.bfloat16 and bias.dtype == torch.bfloat16
        ), f"cuEquivariance requires bfloat16 inputs (got pair={pair.dtype}, bias={bias.dtype})"

        # Gate computation
        gate = torch.sigmoid(self.to_g(pair))  # (B, L, L, h*dim)

        # Project and reshape to cuEquivariance format: (B, L, H, L, D)
        query = rearrange(self.to_q(pair), "b i j (h d) -> b i h j d", h=self.h)
        key = rearrange(self.to_k(pair), "b i k (h d) -> b i h k d", h=self.h)
        value = rearrange(self.to_v(pair), "b i k (h d) -> b i h k d", h=self.h)

        # Bias: (B, L, L, H) -> (B, 1, H, L, L)
        bias_cueq = rearrange(bias, "b i j h -> b 1 h i j")

        # Call cuEquivariance triangle attention
        out_cueq = cuet.triangle_attention(
            query, key, value, bias=bias_cueq, scale=self.scaling
        )

        # Reshape back: (B, L, H, L, D) -> (B, L, L, H*D)
        out = rearrange(out_cueq, "b i h j d -> b i j (h d)")
        out = gate * out  # gated attention
        return out

    def _forward_vanilla(self, pair, bias):
        """Vanilla PyTorch triangle attention implementation."""
        B, L = pair.shape[:2]

        # Gate computation
        gate = torch.sigmoid(self.to_g(pair))  # (B, L, L, h*dim)

        # Project and reshape to vanilla format: (B, L, L, H, D)
        query = self.to_q(pair).reshape(B, L, L, self.h, -1)
        key = self.to_k(pair).reshape(B, L, L, self.h, -1)
        value = self.to_v(pair).reshape(B, L, L, self.h, -1)

        query = query * self.scaling

        attn = einsum("bijhd,bikhd->bijkh", query, key)
        attn = attn + bias.unsqueeze(1).expand(-1, L, -1, -1, -1)  # (bijkh)
        attn = F.softmax(attn, dim=-2)

        out = einsum("bijkh,bikhd->bijhd", attn, value).reshape(B, L, L, -1)
        out = gate * out  # gated attention

        return out


class TriangleMultiplication(nn.Module):
    """Implementation of Triangle Multiplicative Update from AlphaFold3.

    Routes to either cuEquivariance or naive implementation based on configuration.

    Args:
        d_pair: Pair representation dimension (must equal d_hidden for cuEquivariance)
        d_hidden: Hidden dimension (must equal d_pair for cuEquivariance)
        direction: "outgoing" or "incoming" triangle multiplication direction
        bias: Whether to use bias in normalization layers
        use_cuequivariance: Whether to use cuEquivariance fused kernel when available
    """

    def __init__(
        self,
        d_pair,
        d_hidden=None,
        direction="outgoing",
        bias=True,
        use_cuequivariance=True,
    ):
        super(TriangleMultiplication, self).__init__()

        # Set d_hidden to d_pair if not specified
        if d_hidden is None:
            d_hidden = d_pair

        self.d_pair = d_pair
        self.d_hidden = d_hidden

        # Validate direction parameter
        if direction not in ["outgoing", "incoming"]:
            raise ValueError(
                f"direction must be 'outgoing' or 'incoming', got '{direction}'"
            )
        self.direction = direction

        self.use_cuequivariance = use_cuequivariance

        if self.use_cuequivariance:
            # cuEquivariance kernel requires d_pair == d_hidden...
            assert (
                d_pair == d_hidden
            ), "cuEquivariance triangle multiplication requires d_pair == d_hidden"
            # ... and d_pair must be a multiple of 32
            assert (
                d_pair % 32 == 0
            ), "cuEquivariance triangle multiplication requires d_pair to be a multiple of 32"

        # Input normalization (optional bias)
        self.norm_in = nn.LayerNorm(d_pair, bias=bias)

        # Input projections: combine left and right projections (2*d_hidden, d_pair) (no bias)
        self.p_in = nn.Linear(d_pair, 2 * d_hidden, bias=False)

        # Input gating: combine left and right gates (2*d_hidden, d_pair) (no bias)
        self.g_in = nn.Linear(d_pair, 2 * d_hidden, bias=False)

        # Output normalization (optional bias)
        self.norm_out = nn.LayerNorm(d_hidden, bias=bias)

        # Output projection (no bias)
        self.p_out = nn.Linear(d_hidden, d_pair, bias=False)

        # Output gating (no bias)
        self.g_out = nn.Linear(d_pair, d_pair, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        # Input projections: lecun normal distribution for regular linear weights
        self.p_in = init_lecun_normal(self.p_in)

        # We use default PyTorch initialization for the other parameters, as AF-3 does not specify its
        # weight initialization schemes. Without a bias term, e.g., the gate initialization from AF-2 does not apply.

    def forward(
        self, pair: Float[torch.Tensor, "B N N D"]
    ) -> Float[torch.Tensor, "B N N D"]:
        """Forward pass of triangle multiplication."""
        # Route to appropriate implementation
        if self.use_cuequivariance and SHOULD_USE_CUEQUIVARIANCE:
            return self._forward_cuequivariance(pair)
        else:
            return self._forward_vanilla(pair)

    def _forward_vanilla(
        self, pair: Float[torch.Tensor, "B N N D"]
    ) -> Float[torch.Tensor, "B N N D"]:
        """Vanilla PyTorch triangle multiplication implementation."""
        B, L = pair.shape[:2]

        # Input normalization
        pair_norm = self.norm_in(pair)

        # Input projections: get combined output and split
        p_combined = self.p_in(pair_norm)  # (B, L, L, 2*d_hidden)
        left = p_combined[..., : self.d_hidden]  # (B, L, L, d_hidden)
        right = p_combined[..., self.d_hidden :]  # (B, L, L, d_hidden)

        # Input gating: get combined output and split
        g_combined = self.g_in(pair_norm)  # (B, L, L, 2*d_hidden)
        left_gate = torch.sigmoid(g_combined[..., : self.d_hidden])
        right_gate = torch.sigmoid(g_combined[..., self.d_hidden :])

        # Apply gating
        left = left_gate * left
        right = right_gate * right

        # Triangle multiplication based on direction
        if self.direction == "outgoing":
            out = torch.einsum("bikd,bjkd->bijd", left, right / float(L))
        else:  # incoming
            out = torch.einsum("bkid,bkjd->bijd", left, right / float(L))

        # Output normalization
        out = self.norm_out(out)

        # Output projection
        out = self.p_out(out)

        # Output gating
        gate = torch.sigmoid(self.g_out(pair_norm))
        out = gate * out

        return out

    def _forward_cuequivariance(
        self, pair: Float[torch.Tensor, "B N N D"]
    ) -> Float[torch.Tensor, "B N N D"]:
        """cuEquivariance triangle multiplication implementation."""
        # Handle autocast conversion
        # (Use bfloat16 for optimal performance)
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_dtype("cuda")
            pair = pair.to(dtype=dtype)

        assert (
            pair.dtype == torch.bfloat16
        ), "cuEquivariance requires bfloat16 inputs for optimal performance"

        output = cuet.triangle_multiplicative_update(
            x=pair,
            direction=self.direction,
            mask=None,
            norm_in_weight=self.norm_in.weight,
            norm_in_bias=self.norm_in.bias,
            p_in_weight=self.p_in.weight,  # (2*d_hidden, d_pair)
            g_in_weight=self.g_in.weight,  # (2*d_hidden, d_pair)
            norm_out_weight=self.norm_out.weight,
            norm_out_bias=self.norm_out.bias,
            p_out_weight=self.p_out.weight,  # (d_pair, d_pair) since d_hidden == d_pair
            g_out_weight=self.g_out.weight,  # (d_pair, d_pair)
            eps=1e-5,
        )

        return output
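
For orientation, here is a minimal, self-contained sketch of the shapes and einsums in the vanilla triangle-attention path above (`_forward_vanilla`). The random tensors stand in for the module's learned projections and the sizes are arbitrary; nothing below is part of the package:

```python
import torch
import torch.nn.functional as F

B, L, D, H, Dh = 1, 8, 32, 4, 8           # batch, length, d_pair, n_head, d_hidden
pair = torch.randn(B, L, L, D)

# Random stand-ins for the to_q / to_k / to_v / to_b weights.
w_q, w_k, w_v = (torch.randn(D, H * Dh) for _ in range(3))
w_b = torch.randn(D, H)

q = (pair @ w_q).reshape(B, L, L, H, Dh) / Dh ** 0.5
k = (pair @ w_k).reshape(B, L, L, H, Dh)
v = (pair @ w_v).reshape(B, L, L, H, Dh)
bias = pair @ w_b                         # (B, L, L, H), the pair bias b_jk

attn = torch.einsum("bijhd,bikhd->bijkh", q, k)
attn = attn + bias.unsqueeze(1)           # broadcast b_jk over the query row i
attn = F.softmax(attn, dim=-2)            # softmax over k
out = torch.einsum("bijkh,bikhd->bijhd", attn, v).reshape(B, L, L, H * Dh)
print(out.shape)                          # torch.Size([1, 8, 8, 32])
```

The outgoing variant of `TriangleMultiplication._forward_vanilla` likewise reduces to a single einsum over the shared index k, with a 1/L normalization:

```python
left = torch.randn(B, L, L, D)
right = torch.randn(B, L, L, D)
upd = torch.einsum("bikd,bjkd->bijd", left, right / float(L))   # (B, L, L, D)
```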
rf3/model/layers/layer_utils.py
ADDED
@@ -0,0 +1,127 @@
from functools import partial

import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import silu

from foundry.training.checkpoint import activation_checkpointing

linearNoBias = partial(torch.nn.Linear, bias=False)


def collapse(x, L):
    return x.reshape((L, x.numel() // L))


class MultiDimLinear(nn.Linear):
    def __init__(self, in_features, out_shape, **kwargs):
        self.out_shape = out_shape
        out_features = np.prod(out_shape)
        super().__init__(in_features, out_features, **kwargs)
        self.reset_parameters()

    def reset_parameters(self, **kwargs) -> None:
        super().reset_parameters()
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        out = super().forward(x)
        return out.reshape(x.shape[:-1] + self.out_shape)


class LinearBiasInit(nn.Linear):
    def __init__(self, *args, biasinit, **kwargs):
        assert biasinit == -2.0  # Sanity check
        self.biasinit = biasinit
        super().__init__(*args, **kwargs)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        self.bias.data.fill_(self.biasinit)


class Transition(nn.Module):
    def __init__(self, n, c):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(c)
        self.linear_1 = linearNoBias(c, n * c)
        self.linear_2 = linearNoBias(c, n * c)
        self.linear_3 = linearNoBias(n * c, c)

    @activation_checkpointing
    def forward(
        self,
        X,
    ):
        X = self.layer_norm_1(X)
        A = self.linear_1(X)
        B = self.linear_2(X)
        X = self.linear_3(silu(A) * B)
        return X


class AdaLN(nn.Module):
    def __init__(self, c_a, c_s, n=2):
        super().__init__()
        self.ln_a = nn.LayerNorm(normalized_shape=(c_a,), elementwise_affine=False)
        self.ln_s = nn.LayerNorm(normalized_shape=(c_s,), bias=False)
        self.to_gain = nn.Sequential(
            nn.Linear(c_s, c_a),
            nn.Sigmoid(),
        )
        self.to_bias = linearNoBias(c_s, c_a)

    def forward(
        self,
        Ai,  # [B, I, C_a]
        Si,  # [B, I, C_s]
    ):
        """
        Output:
            [B, I, C_a]
        """
        Ai = self.ln_a(Ai)
        Si = self.ln_s(Si)
        return self.to_gain(Si) * Ai + self.to_bias(Si)


def create_batch_dimension_if_not_present(batched_n_dim):
    """
    Decorator for adapting a function that expects batched arguments with ndim `batched_n_dim`
    so that it also accepts unbatched arguments.
    """

    def wrap(f):
        def _wrap(arg):
            inserted_batch_dim = False
            if arg.ndim == batched_n_dim - 1:
                arg = arg[None]
                inserted_batch_dim = True
            elif arg.ndim == batched_n_dim:
                pass
            else:
                raise Exception(
                    f"arg must have {batched_n_dim - 1} or {batched_n_dim} dimensions, got shape {arg.shape=}"
                )
            o = f(arg)

            if inserted_batch_dim:
                assert o.shape[0] == 1, f"{o.shape=}[0] != 1"
                return o[0]
            return o

        return _wrap

    return wrap


def unpack_args_for_checkpointing(arg_names):
    def wrap(f):
        def _wrap(*args):
            f = args[0]
            return f(**dict(zip(arg_names, args)))

        return _wrap

    return wrap
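
The `Transition` block above is a SwiGLU-style feed-forward: silu(W1 x) * (W2 x), projected back by W3. A minimal sketch of the same computation with random weights (the sizes are arbitrary, not package defaults):

```python
import torch
from torch.nn.functional import silu

n, c, L = 4, 16, 10                       # expansion factor, channels, tokens
x = torch.randn(L, c)
w1, w2, w3 = torch.randn(c, n * c), torch.randn(c, n * c), torch.randn(n * c, c)

h = silu(x @ w1) * (x @ w2)               # gated expansion to n*c channels
y = h @ w3                                # project back to c channels
print(y.shape)                            # torch.Size([10, 16])
```

`AdaLN` follows a similar conditioning pattern: the output is sigmoid(Linear(s)) * LayerNorm(a) + Linear(s), i.e. a gain and bias predicted from the layer-normalized conditioning signal s.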
rf3/model/layers/mlff.py
ADDED
@@ -0,0 +1,118 @@
import torch
import torch.nn as nn
from einops import rearrange
from jaxtyping import Float


class ConformerEmbeddingWeightedAverage(nn.Module):
    """Learned weighted average of reference conformer embeddings.

    Args:
        atom_level_embedding_dim: Dimension of the input atom-level embeddings (default: 384, for EGRET)
        c_atompair: Dimension of the atom-pair embeddings (default: 16)
        c_atom: Dimension of the output atom embeddings (default: 128)
        n_conformers: Number of conformers to expect (default: 8)
        dropout_rate: Dropout rate for regularization (default: 0.1)
        use_layer_norm: Whether to apply layer normalization to the output (default: True)
    """

    def __init__(
        self,
        atom_level_embedding_dim: int,
        c_atompair: int,
        c_atom: int,
        n_conformers: int = 8,
        dropout_rate: float = 0.1,
        use_layer_norm: bool = True,
    ):
        super().__init__()

        self.n_conformers = n_conformers
        self.atom_level_embedding_dim = atom_level_embedding_dim

        # Downcast MLP from atom_level_embedding_dim to c_atompair
        self.process_atom_level_embedding = nn.Sequential(
            nn.Linear(atom_level_embedding_dim, atom_level_embedding_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(atom_level_embedding_dim // 2, atom_level_embedding_dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(atom_level_embedding_dim // 4, atom_level_embedding_dim // 8),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(atom_level_embedding_dim // 8, c_atompair),
        )

        # Final MLP to convert from (n_conformers * c_atompair) to c_atom
        self.conformers_to_atom_single_embedding = nn.Sequential(
            nn.Linear(n_conformers * c_atompair, c_atom, bias=False),
            nn.LayerNorm(c_atom) if use_layer_norm else nn.Identity(),
        )

        # Zero-init the final linear layer to ensure the model starts with identity function (output ≈ 0)
        nn.init.zeros_(self.conformers_to_atom_single_embedding[0].weight)

    def forward(
        self,
        atom_level_embeddings: Float[
            torch.Tensor, "n_conformers n_atom atom_level_embedding_dim"
        ],
    ) -> Float[torch.Tensor, "n_atom c_atom"]:
        """Forward pass: process atom-level embeddings and return the processed result.

        Args:
            atom_level_embeddings: Input tensor of shape [n_conformers, n_atom, atom_level_embedding_dim]

        Returns:
            Processed tensor of shape [n_atom, c_atom] ready for residual addition
        """
        assert (
            atom_level_embeddings.shape[0] == self.n_conformers
        ), "Number of conformers must be consistent"

        # Subset to [:atom_level_embedding_dim]
        if atom_level_embeddings.shape[-1] > self.atom_level_embedding_dim:
            atom_level_embeddings = atom_level_embeddings[
                ..., : self.atom_level_embedding_dim
            ]
        elif atom_level_embeddings.shape[-1] < self.atom_level_embedding_dim:
            raise ValueError(
                f"Atom-level embedding dimension {atom_level_embeddings.shape[-1]} is less than the expected dimension {self.atom_level_embedding_dim}"
            )

        # Process atom-level embeddings to get shape [n_conformers, n_atom, c_atompair]
        processed_embeddings: Float[torch.Tensor, "n_conformers n_atom c_atompair"] = (
            self.process_atom_level_embedding(atom_level_embeddings)
        )

        # Pad with zeros if we don't have enough conformers
        current_n_conformers = processed_embeddings.shape[0]
        if current_n_conformers < self.n_conformers:
            # Pad with zeros at the beginning
            padding_size = self.n_conformers - current_n_conformers
            padding: Float[torch.Tensor, "padding_size n_atom c_atompair"] = (
                torch.zeros(
                    padding_size,
                    processed_embeddings.shape[1],
                    processed_embeddings.shape[2],
                    device=processed_embeddings.device,
                    dtype=processed_embeddings.dtype,
                )
            )
            processed_embeddings = torch.cat([padding, processed_embeddings], dim=0)
        elif current_n_conformers > self.n_conformers:
            # Truncate to n_conformers
            processed_embeddings = processed_embeddings[: self.n_conformers]

        # Reshape from [n_conformers, n_atom, c_atompair] to [n_atom, n_conformers * c_atompair]
        reshaped_embeddings: Float[torch.Tensor, "n_atom n_conformers*c_atompair"] = (
            rearrange(processed_embeddings, "c n d -> n (c d)")
        )

        # Final MLP to get [n_atom, c_atom]
        result: Float[torch.Tensor, "n_atom c_atom"] = (
            self.conformers_to_atom_single_embedding(reshaped_embeddings)
        )

        return result
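
A shape walk-through of the conformer aggregation in `ConformerEmbeddingWeightedAverage.forward`, assuming only 5 of the expected 8 conformers are available (an editorial sketch with arbitrary values, not package code):

```python
import torch
from einops import rearrange

n_conformers, n_atom, c_atompair, c_atom = 8, 100, 16, 128
processed = torch.randn(5, n_atom, c_atompair)        # downcast-MLP output, 5 conformers

# Pad with zeros at the front so the conformer axis always has length n_conformers.
pad = torch.zeros(n_conformers - processed.shape[0], n_atom, c_atompair)
processed = torch.cat([pad, processed], dim=0)        # (8, 100, 16)

# Flatten conformers into the channel dimension, then a single linear maps to c_atom.
flat = rearrange(processed, "c n d -> n (c d)")       # (100, 128)
w = torch.zeros(n_conformers * c_atompair, c_atom)    # zero-init, as in the module
print((flat @ w).shape)                               # torch.Size([100, 128])
```

Because the final linear layer is zero-initialized, the module's contribution starts at zero, so the downstream residual addition is initially a no-op, as the comment in the source notes.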
rf3/model/layers/outer_product.py
ADDED
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
from rf3.util_module import init_lecun_normal

from foundry.training.checkpoint import activation_checkpointing


class OuterProductMean(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, d_hidden=16, p_drop=0.15):
        super(OuterProductMean, self).__init__()
        self.norm = nn.LayerNorm(d_msa)
        self.proj_left = nn.Linear(d_msa, d_hidden)
        self.proj_right = nn.Linear(d_msa, d_hidden)
        self.proj_out = nn.Linear(d_hidden * d_hidden, d_pair)

        self.reset_parameter()

    def reset_parameter(self):
        # normal initialization
        self.proj_left = init_lecun_normal(self.proj_left)
        self.proj_right = init_lecun_normal(self.proj_right)
        nn.init.zeros_(self.proj_left.bias)
        nn.init.zeros_(self.proj_right.bias)

        # zero initialize output
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)

    def forward(self, msa):
        B, N, L = msa.shape[:3]
        msa = self.norm(msa)
        left = self.proj_left(msa)
        right = self.proj_right(msa)
        right = right / float(N)
        out = torch.einsum("bsli,bsmj->blmij", left, right).reshape(B, L, L, -1)
        out = self.proj_out(out)

        return out


class OuterProductMean_AF3(nn.Module):
    def __init__(self, c_msa_embed, c_outer_product, c_out):
        super(OuterProductMean_AF3, self).__init__()
        self.norm = nn.LayerNorm(c_msa_embed)
        self.proj_left = nn.Linear(c_msa_embed, c_outer_product)
        self.proj_right = nn.Linear(c_msa_embed, c_outer_product)
        self.proj_out = nn.Linear(c_outer_product * c_outer_product, c_out)

    @activation_checkpointing
    def forward(self, msa):
        B, N, L = msa.shape[:3]
        msa = self.norm(msa)
        left = self.proj_left(msa)
        right = self.proj_right(msa)
        right = right / float(N)
        out = torch.einsum("bsli,bsmj->blmij", left, right).reshape(B, L, L, -1)
        out = self.proj_out(out)

        return out