rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
from contextlib import ExitStack
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
from rfd3.model.layers.block_utils import (
|
|
9
|
+
bucketize_scaled_distogram,
|
|
10
|
+
create_attention_indices,
|
|
11
|
+
)
|
|
12
|
+
from rfd3.model.layers.blocks import (
|
|
13
|
+
CompactStreamingDecoder,
|
|
14
|
+
Downcast,
|
|
15
|
+
LinearEmbedWithPool,
|
|
16
|
+
LinearSequenceHead,
|
|
17
|
+
LocalAtomTransformer,
|
|
18
|
+
LocalTokenTransformer,
|
|
19
|
+
)
|
|
20
|
+
from rfd3.model.layers.encoders import (
|
|
21
|
+
DiffusionTokenEncoder,
|
|
22
|
+
)
|
|
23
|
+
from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
|
|
24
|
+
|
|
25
|
+
from foundry.model.layers.blocks import (
|
|
26
|
+
FourierEmbedding,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class RFD3DiffusionModule(nn.Module):
|
|
33
|
+
def __init__(
    self,
    *,
    c_atom,
    c_atompair,
    c_token,
    c_s,
    c_z,
    c_t_embed,
    sigma_data,
    f_pred,
    n_attn_seq_neighbours,
    n_attn_keys,
    n_recycle,
    atom_attention_encoder,
    diffusion_token_encoder,
    diffusion_transformer,
    atom_attention_decoder,
    # upcast,
    downcast,
    use_local_token_attention=True,
    **_,
):
    """Assemble the RFD3 diffusion denoising module.

    Keyword Args:
        c_atom / c_atompair: channel widths of per-atom and atom-pair features.
        c_token / c_s / c_z: channel widths of token, single, and pair reps.
        c_t_embed: width of the Fourier time embedding.
        sigma_data: data scale used by the EDM-style pre/post conditioning
            (see scale_positions_in / scale_positions_out).
        f_pred: prediction parameterization; one of "edm", "unconditioned",
            "noise_pred".
        n_attn_seq_neighbours / n_attn_keys: local-attention neighbourhood
            sizes, forwarded to create_attention_indices in forward().
        n_recycle: recycle count used when not training (see
            forward_with_recycle).
        atom_attention_encoder / diffusion_token_encoder /
            diffusion_transformer / atom_attention_decoder / downcast:
            config dicts splatted into the corresponding sub-layers.
        use_local_token_attention: stored flag; not read in this module's
            visible code — presumably consumed elsewhere. TODO confirm.
        **_: extra config keys are accepted and ignored (config pass-through).
    """
    super().__init__()
    self.sigma_data = sigma_data
    self.c_atom = c_atom
    self.c_atompair = c_atompair
    self.c_token = c_token
    self.c_s = c_s
    self.c_z = c_z
    self.f_pred = f_pred
    self.n_attn_seq_neighbours = n_attn_seq_neighbours
    self.n_attn_keys = n_attn_keys
    self.use_local_token_attention = use_local_token_attention

    # Auxiliary: project 3D coordinates into atom features (process_r),
    # read a 3D position update back out (to_r_update), and a linear
    # sequence head over token features.
    self.process_r = linearNoBias(3, c_atom)
    self.to_r_update = nn.Sequential(RMSNorm((c_atom,)), linearNoBias(c_atom, 3))
    self.sequence_head = LinearSequenceHead(c_token=c_token)

    self.n_recycle = n_recycle
    # Self-conditioning distogram: 65 bins over distances in [1, 30]
    # (units presumably Angstrom — TODO confirm against bucketize_scaled_distogram).
    self.n_bins = 65
    self.bucketize_fn = functools.partial(
        bucketize_scaled_distogram,
        min_dist=1,
        max_dist=30,
        sigma_data=1,
        n_bins=self.n_bins,
    )

    # Time processing: separate Fourier embeddings and projections for
    # atom-level (index 0 -> c_atom) and token-level (index 1 -> c_s)
    # noise-time features; indexed by `i` in process_time_.
    self.fourier_embedding = nn.ModuleList(
        [FourierEmbedding(c_t_embed), FourierEmbedding(c_t_embed)]
    )
    self.process_n = nn.ModuleList(
        [
            nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_atom)),
            nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)),
        ]
    )
    # Two downcast paths: conditioning features -> single rep (no c_s input)
    # and atom queries -> token rep (with c_s input).
    self.downcast_c = Downcast(c_atom=c_atom, c_token=c_s, c_s=None, **downcast)
    self.downcast_q = Downcast(c_atom=c_atom, c_token=c_token, c_s=c_s, **downcast)
    self.process_a = LinearEmbedWithPool(c_token)
    self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom))

    # UNet-like architecture for processing across tokens and atoms
    self.encoder = LocalAtomTransformer(
        c_atom=c_atom, c_s=c_atom, c_atompair=c_atompair, **atom_attention_encoder
    )

    self.diffusion_token_encoder = DiffusionTokenEncoder(
        c_s=c_s,
        c_token=c_token,
        c_z=c_z,
        c_atompair=c_atompair,
        **diffusion_token_encoder,
    )

    self.diffusion_transformer = LocalTokenTransformer(
        c_token=c_token,
        c_tokenpair=c_z,
        c_s=c_s,
        **diffusion_transformer,
    )

    self.decoder = CompactStreamingDecoder(
        c_atom=c_atom,
        c_atompair=c_atompair,
        c_token=c_token,
        c_s=c_s,
        c_tokenpair=c_z,
        **atom_attention_decoder,
    )
|
|
126
|
+
|
|
127
|
+
def scale_positions_in(self, X_noisy_L, t):
    """
    Scale noisy atom positions into the network's input frame.

    Args:
        X_noisy_L: Noisy coordinates, shape [B, n_atoms, 3].
        t: Noise level(s); shape [B] (uniform per batch) or [B, n_atoms]
            (per-atom). Broadcast against the coordinate tensor.

    Returns:
        Scaled positions with the same shape as ``X_noisy_L``.

    Raises:
        ValueError: If ``self.f_pred`` is not a recognized parameterization.
    """
    # Broadcast t to [B, n_atoms, 3]-compatible shape.
    if t.ndim == 1:
        t = t[..., None, None]  # [B, (n_atoms), (3)]
    elif t.ndim == 2:
        t = t[..., None]  # [B, n_atoms, (3)]

    if self.f_pred == "edm":
        # EDM input preconditioning: c_in = 1 / sqrt(t^2 + sigma_data^2).
        R_noisy_L = X_noisy_L / torch.sqrt(t**2 + self.sigma_data**2)
    elif self.f_pred == "unconditioned":
        # Network receives no positional signal at all.
        R_noisy_L = torch.zeros_like(X_noisy_L)
    elif self.f_pred == "noise_pred":
        # Raw positions pass through; the network predicts the noise.
        R_noisy_L = X_noisy_L
    else:
        # ValueError is the idiomatic type here (bad configuration value);
        # still caught by any caller catching Exception.
        raise ValueError(f"{self.f_pred=} unrecognized")
    return R_noisy_L
|
|
142
|
+
|
|
143
|
+
def scale_positions_out(self, R_update_L, X_noisy_L, t):
    """
    Combine the network output with the noisy input into denoised coordinates.

    Args:
        R_update_L: Network position update, shape [B, n_atoms, 3].
        X_noisy_L: Noisy input coordinates, shape [B, n_atoms, 3].
        t: Noise level(s); shape [B] or [B, n_atoms], broadcast to coords.

    Returns:
        Denoised coordinates ``X_out_L`` with the same shape as ``X_noisy_L``.

    Raises:
        ValueError: If ``self.f_pred`` is not a recognized parameterization.
    """
    # Broadcast t to [B, n_atoms, 3]-compatible shape.
    if t.ndim == 1:
        t = t[..., None, None]
    elif t.ndim == 2:
        t = t[..., None]  # [B, n_atoms, (3)]

    if self.f_pred == "edm":
        # EDM output preconditioning: X_out = c_skip * X_noisy + c_out * R,
        # with c_skip = sigma_d^2 / (sigma_d^2 + t^2) and
        # c_out = sigma_d * t / sqrt(sigma_d^2 + t^2).
        c_skip = self.sigma_data**2 / (self.sigma_data**2 + t**2)
        c_out = self.sigma_data * t / (self.sigma_data**2 + t**2) ** 0.5
        X_out_L = c_skip * X_noisy_L + c_out * R_update_L
    elif self.f_pred == "unconditioned":
        # Network predicts absolute positions directly.
        X_out_L = R_update_L
    elif self.f_pred == "noise_pred":
        # Network predicts an additive residual.
        X_out_L = X_noisy_L + R_update_L
    else:
        # ValueError is the idiomatic type here (bad configuration value);
        # still caught by any caller catching Exception.
        raise ValueError(f"{self.f_pred=} unrecognized")
    return X_out_L
|
|
160
|
+
|
|
161
|
+
def process_time_(self, t_L, i):
    """Embed per-position noise levels into feature space.

    Uses Fourier embedding / projection pair ``i`` (0: atom-level features,
    1: token-level features). Positions with ``t == 0`` (fixed coordinates)
    get their time features zeroed out.
    """
    # Log-scale the noise level; clamp guards against log(0).
    log_t = torch.log(torch.clamp(t_L, min=1e-20) / self.sigma_data)
    scaled = 1 / 4 * log_t
    embedded = self.fourier_embedding[i](scaled)
    features = self.process_n[i](embedded)
    # Zero out features wherever t == 0 so fixed positions carry no
    # time signal.
    mask = (t_L > 0).float()[..., None]  # [B, L, 1]
    return features * mask
|
|
170
|
+
|
|
171
|
+
def forward(
    self,
    X_noisy_L,
    t,
    f,
    # Features from initialization
    Q_L_init,
    C_L,
    P_LL,
    S_I,
    Z_II,
    n_recycle=None,
    # Chunked memory optimization parameters
    chunked_pairwise_embedder=None,
    initializer_outputs=None,
    **kwargs,
):
    """
    Diffusion forward pass with recycling.
    Computes denoised positions given encoded features and noisy coordinates.

    Args:
        X_noisy_L: Noisy atom coordinates, [B, L, 3].
        t: Per-batch noise level, shape [B].
        f: Feature dict; must contain "atom_to_token_map",
            "is_motif_atom_with_fixed_coord",
            "is_motif_token_with_fully_fixed_coord".
            NOTE: mutated in place — "attn_indices" is added below.
        Q_L_init, C_L, P_LL, S_I, Z_II: atom/conditioning/pair/single/token-pair
            representations from the initializer (unbatched; batch dim added here).
        n_recycle: recycle count (required when training; ignored at eval —
            see forward_with_recycle).
        chunked_pairwise_embedder / initializer_outputs: when set, the
            encoder/decoder recompute pairwise features in chunks instead of
            consuming the full P_LL (memory optimization).

    Returns:
        dict with "X_L" ([B, L, 3] denoised positions), "sequence_indices_I",
        and "sequence_logits_I" from the final recycle.
    """
    # ... Collect inputs
    tok_idx = f["atom_to_token_map"]
    L = len(tok_idx)
    I = tok_idx.max() + 1  # Number of tokens
    # Local-attention neighbour indices depend on the current noisy
    # coordinates, so they are rebuilt every forward call and stashed on f.
    f["attn_indices"] = create_attention_indices(
        X_L=X_noisy_L,
        f=f,
        n_attn_keys=self.n_attn_keys,
        n_attn_seq_neighbours=self.n_attn_seq_neighbours,
    )

    # ... Expand t tensors: fixed motif atoms/tokens get t = 0, which also
    # makes process_time_ zero out their time features.
    t_L = t.unsqueeze(-1).expand(-1, L) * (
        ~f["is_motif_atom_with_fixed_coord"]
    ).float().unsqueeze(0)
    t_I = t.unsqueeze(-1).expand(-1, I) * (
        ~f["is_motif_token_with_fully_fixed_coord"]
    ).float().unsqueeze(0)

    # ... Create scaled positions: one scaling with the uniform batch t,
    # one with the per-atom (motif-masked) t.
    R_L_uniform = self.scale_positions_in(X_noisy_L, t)
    R_noisy_L = self.scale_positions_in(X_noisy_L, t_L)

    # ... Pool initial representation to sequence level
    A_I = self.process_a(R_noisy_L, tok_idx=tok_idx)
    S_I = self.downcast_c(C_L, S_I, tok_idx=tok_idx)

    # ... Add batch-wise features to inputs (time embeddings use index 0 for
    # atom-level and index 1 for token-level projections).
    Q_L = Q_L_init.unsqueeze(0) + self.process_r(R_noisy_L)
    C_L = C_L.unsqueeze(0) + self.process_time_(t_L, i=0)
    S_I = S_I.unsqueeze(0) + self.process_time_(t_I, i=1)
    C_L = C_L + self.process_c(C_L)

    # ... Run Local-Atom Self Attention and Pool
    if chunked_pairwise_embedder is not None:
        # Chunked mode: pass chunked embedder and feature dict
        Q_L = self.encoder(
            Q_L,
            C_L,
            P_LL=None,
            indices=f["attn_indices"],
            f=f,  # Pass feature dict for chunked computation
            chunked_pairwise_embedder=chunked_pairwise_embedder,
            initializer_outputs=initializer_outputs,
        )
    else:
        # Standard mode: use full P_LL
        Q_L = self.encoder(Q_L, C_L, P_LL, indices=f["attn_indices"])
    A_I = self.downcast_q(Q_L, A_I=A_I, S_I=S_I, tok_idx=tok_idx)

    # Debug chunked parameters

    # ... Run forward with recycling
    recycled_features = self.forward_with_recycle(
        n_recycle,
        X_noisy_L=X_noisy_L,
        R_L_uniform=R_L_uniform,
        t_L=t_L,
        f=f,
        Q_L=Q_L,
        C_L=C_L,
        P_LL=P_LL,
        A_I=A_I,
        S_I=S_I,
        Z_II=Z_II,
        chunked_pairwise_embedder=chunked_pairwise_embedder,
        initializer_outputs=initializer_outputs,
    )

    # ... Collect outputs
    outputs = {
        "X_L": recycled_features["X_L"],  # [B, L, 3] denoised positions
        "sequence_indices_I": recycled_features["sequence_indices_I"],
        "sequence_logits_I": recycled_features["sequence_logits_I"],
    }
    return outputs
|
|
268
|
+
|
|
269
|
+
def forward_with_recycle(
    self,
    n_recycle,
    **kwargs,
):
    """Run ``process_`` repeatedly, feeding each cycle's self-conditioning
    outputs (distogram and coordinates) into the next cycle.

    Only the final cycle tracks gradients; earlier cycles run under
    ``torch.no_grad()``. At eval time the module's configured ``n_recycle``
    always overrides the argument.
    """
    if self.training:
        assert n_recycle is not None
    else:
        # Evaluation ignores the caller-provided count.
        n_recycle = self.n_recycle

    features = {}
    for cycle in range(n_recycle):
        is_last = cycle == n_recycle - 1
        with ExitStack() as ctx:
            if not is_last:
                # Intermediate cycles are detached from the graph.
                ctx.enter_context(torch.no_grad())

            # Clear the autocast cache if gradients are enabled
            # (workaround for an autocast bug).
            # See: https://github.com/pytorch/pytorch/issues/65766
            if torch.is_grad_enabled():
                torch.clear_autocast_cache()

            # Run forward; self-conditioning inputs come from the
            # previous cycle (None on the first cycle).
            features = self.process_(
                D_II_self=features.get("D_II_self"),
                X_L_self=features.get("X_L"),
                **kwargs,
            )

    return features
|
|
299
|
+
|
|
300
|
+
def process_(
    self,
    D_II_self,
    X_L_self,
    *,
    R_L_uniform,
    X_noisy_L,
    t_L,
    f,
    Q_L,
    C_L,
    P_LL,
    A_I,
    S_I,
    Z_II,
    chunked_pairwise_embedder=None,
    initializer_outputs=None,
    **_,
):
    """Single denoising pass (one recycle iteration).

    Args:
        D_II_self: Self-conditioning distogram from the previous recycle
            (None on the first cycle).
        X_L_self: Denoised coordinates from the previous recycle
            (None on the first cycle).
        R_L_uniform: Positions scaled with the uniform batch t.
        X_noisy_L, t_L: Noisy coordinates and per-atom noise levels, used to
            rescale the output update.
        f: Feature dict; must provide "is_ca", "atom_to_token_map",
            "attn_indices".
        Q_L, C_L, P_LL, A_I, S_I, Z_II: atom / conditioning / atom-pair /
            token / single / token-pair representations.
        chunked_pairwise_embedder / initializer_outputs: enable chunked
            pairwise recomputation in the decoder instead of full P_LL.

    Returns:
        dict with denoised "X_L", next-cycle "D_II_self", sequence head
        outputs, and any extra outputs ``o`` from the decoder (merged in).
    """
    # ... Embed token level features with atom level encodings
    S_I, Z_II = self.diffusion_token_encoder(
        f=f,
        R_L=R_L_uniform,
        D_II_self=D_II_self,
        S_init_I=S_I,
        Z_init_II=Z_II,
        C_L=C_L,
        P_LL=P_LL,
    )

    # ... Diffusion transformer. Token-level attention is conditioned on
    # CA-atom coordinates: the noisy input on the first cycle, the previous
    # cycle's prediction afterwards.
    A_I = self.diffusion_transformer(
        A_I,
        S_I,
        Z_II,
        f=f,
        X_L=(
            X_noisy_L[..., f["is_ca"], :]
            if X_L_self is None
            else X_L_self[..., f["is_ca"], :]
        ),
        # Env var opts into the reduced-memory attention path.
        full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"),
    )

    # ... Decoder readout
    # Check if using chunked P_LL mode

    if chunked_pairwise_embedder is not None:
        # Chunked mode: pass embedder and no P_LL
        A_I, Q_L, o = self.decoder(
            A_I,
            S_I,
            Z_II,
            Q_L,
            C_L,
            P_LL=None,  # Not used in chunked mode
            tok_idx=f["atom_to_token_map"],
            indices=f["attn_indices"],
            f=f,  # Pass f for chunked computation
            chunked_pairwise_embedder=chunked_pairwise_embedder,
            initializer_outputs=initializer_outputs,
        )
    else:
        # Original mode: use full P_LL
        A_I, Q_L, o = self.decoder(
            A_I,
            S_I,
            Z_II,
            Q_L,
            C_L,
            P_LL=P_LL,
            tok_idx=f["atom_to_token_map"],
            indices=f["attn_indices"],
        )

    # ... Process outputs to positions update
    R_update_L = self.to_r_update(Q_L)
    X_out_L = self.scale_positions_out(R_update_L, X_noisy_L, t_L)

    sequence_logits_I, sequence_indices_I = self.sequence_head(A_I=A_I)
    # Next-cycle self-conditioning distogram over CA atoms; detached so
    # recycling does not backprop through earlier cycles.
    D_II_self = self.bucketize_fn(X_out_L[..., f["is_ca"], :].detach())

    return {
        "X_L": X_out_L,
        "D_II_self": D_II_self,
        "sequence_logits_I": sequence_logits_I,
        "sequence_indices_I": sequence_indices_I,
    } | o
|
rfd3/model/cfg_utils.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def strip_f(
    f,
    cfg_features,
):
    """
    Strips conditioning features from 'f' for classifier-free guidance.

    Crops away "unindexed" motif atoms/tokens (assumed to occupy a
    contiguous suffix of their dimension) and zeroes every feature whose
    key is listed in ``cfg_features``.

    Args:
        f (dict): Conditioning features
        cfg_features (list): List of features to be set to 0

    Returns:
        dict: Stripped conditioning features
    """
    # Sizes used to recognize token vs. atom features independent of their
    # variable names (this way we only need to hardcode these two keys).
    token_dim = f["is_motif_token_unindexed"].shape[0]
    atom_dim = f["is_motif_atom_unindexed"].shape[0]

    # Identify the first atom and token to be cropped.
    # NOTE(review): both crop indices are gated on the *atom* mask — this
    # assumes unindexed atoms and unindexed tokens always co-occur; confirm
    # against the upstream featurization.
    crop = torch.any(f["is_motif_atom_unindexed"]).item()
    atom_crop_index = (
        torch.where(f["is_motif_atom_unindexed"])[0][0]
        if crop
        else f["is_motif_atom_unindexed"].shape[0]
    )
    token_crop_index = (
        torch.where(f["is_motif_token_unindexed"])[0][0]
        if crop
        else f["is_motif_token_unindexed"].shape[0]
    )

    # ... Mask out conditioning features
    f_stripped = f.copy()

    # Crop features based on them being atom or token features and based on
    # them being 1d or 2d (square pairwise) features.
    for k, v in f.items():
        # Default: features matching neither dimension pass through as-is.
        v_cropped = v

        # Handle token features. NOTE(review): if token_dim == atom_dim the
        # atom branch below silently wins; sizes are presumably distinct in
        # practice — confirm.
        if token_dim in v.shape:
            # Check if it's a 2D feature (square matrix)
            if len(v.shape) == 2 and v.shape[0] == v.shape[1]:
                v_cropped = v[:token_crop_index, :token_crop_index]
            else:
                v_cropped = v[:token_crop_index]
        # Handle atom features
        if atom_dim in v.shape:
            # Check if it's a 2D feature (square matrix)
            if len(v.shape) == 2 and v.shape[0] == v.shape[1]:
                v_cropped = v[:atom_crop_index, :atom_crop_index]
            else:
                v_cropped = v[:atom_crop_index]

        # Zero out classifier-free-guidance features. torch.zeros_like
        # already matches device and dtype, so no extra .to() is needed.
        if k in cfg_features:
            v_cropped = torch.zeros_like(v_cropped)

        # update the feature in the dictionary
        f_stripped[k] = v_cropped

    return f_stripped
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def strip_X(X_L, f_stripped):
    """
    Strips X_L unindexed atoms from X_L

    Keeps only as many leading atoms as remain in the stripped atom mask,
    so coordinates stay aligned with the features returned by ``strip_f``.

    Args:
        X_L (torch.Tensor): Atom coordinates
        f_stripped (dict): Stripped conditioning features

    Returns:
        torch.Tensor: Atom coordinates with unindexed atoms removed
    """
    n_atoms_kept = f_stripped["is_motif_atom_unindexed"].shape[0]
    return X_L[..., :n_atoms_kept, :]
|