boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,844 @@
|
|
1
|
+
# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from math import sqrt
|
6
|
+
|
7
|
+
import torch
|
8
|
+
import torch.nn.functional as F
|
9
|
+
from einops import rearrange
|
10
|
+
from torch import nn
|
11
|
+
from torch.nn import Module
|
12
|
+
|
13
|
+
import boltz.model.layers.initialize as init
|
14
|
+
from boltz.data import const
|
15
|
+
from boltz.model.loss.diffusion import (
|
16
|
+
smooth_lddt_loss,
|
17
|
+
weighted_rigid_align,
|
18
|
+
)
|
19
|
+
from boltz.model.modules.utils import center_random_augmentation
|
20
|
+
from boltz.model.modules.encoders import (
|
21
|
+
AtomAttentionDecoder,
|
22
|
+
AtomAttentionEncoder,
|
23
|
+
FourierEmbedding,
|
24
|
+
PairwiseConditioning,
|
25
|
+
SingleConditioning,
|
26
|
+
)
|
27
|
+
from boltz.model.modules.transformers import (
|
28
|
+
ConditionedTransitionBlock,
|
29
|
+
DiffusionTransformer,
|
30
|
+
)
|
31
|
+
from boltz.model.modules.utils import (
|
32
|
+
LinearNoBias,
|
33
|
+
compute_random_augmentation,
|
34
|
+
center_random_augmentation,
|
35
|
+
default,
|
36
|
+
log,
|
37
|
+
)
|
38
|
+
from boltz.model.potentials.potentials import get_potentials
|
39
|
+
|
40
|
+
|
41
|
+
class DiffusionModule(Module):
    """Score network of the diffusion process.

    Chains single/pairwise conditioning, an atom attention encoder, a
    token-level diffusion transformer and an atom attention decoder.
    """

    def __init__(
        self,
        token_s: int,
        token_z: int,
        atom_s: int,
        atom_z: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        sigma_data: int = 16,
        dim_fourier: int = 256,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        atom_feature_dim: int = 128,
        conditioning_transition_layers: int = 2,
        activation_checkpointing: bool = False,
        offload_to_cpu: bool = False,
        **kwargs,
    ) -> None:
        """Build the sub-networks of the diffusion module.

        Parameters
        ----------
        token_s : int
            The single representation dimension.
        token_z : int
            The pair representation dimension.
        atom_s : int
            The atom single representation dimension.
        atom_z : int
            The atom pair representation dimension.
        atoms_per_window_queries : int, optional
            The number of atoms per window for queries, by default 32.
        atoms_per_window_keys : int, optional
            The number of atoms per window for keys, by default 128.
        sigma_data : int, optional
            The standard deviation of the data distribution, by default 16.
        dim_fourier : int, optional
            The dimension of the fourier embedding, by default 256.
        atom_encoder_depth : int, optional
            The depth of the atom encoder, by default 3.
        atom_encoder_heads : int, optional
            The number of heads in the atom encoder, by default 4.
        token_transformer_depth : int, optional
            The depth of the token transformer, by default 24.
        token_transformer_heads : int, optional
            The number of heads in the token transformer, by default 8.
        atom_decoder_depth : int, optional
            The depth of the atom decoder, by default 3.
        atom_decoder_heads : int, optional
            The number of heads in the atom decoder, by default 4.
        atom_feature_dim : int, optional
            The atom feature dimension, by default 128.
        conditioning_transition_layers : int, optional
            The number of transition layers for conditioning, by default 2.
        activation_checkpointing : bool, optional
            Whether to use activation checkpointing, by default False.
        offload_to_cpu : bool, optional
            Whether to offload the activations to CPU, by default False.
        **kwargs
            Accepted and ignored by this constructor.

        """
        super().__init__()

        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.sigma_data = sigma_data

        # The token-level layers built below all operate at width 2 * token_s.
        token_dim = 2 * token_s

        # NOTE: submodules are created in a fixed order so that parameter
        # registration (and any random initialisation) stays reproducible.
        self.single_conditioner = SingleConditioning(
            sigma_data=sigma_data,
            token_s=token_s,
            dim_fourier=dim_fourier,
            num_transitions=conditioning_transition_layers,
        )
        self.pairwise_conditioner = PairwiseConditioning(
            token_z=token_z,
            dim_token_rel_pos_feats=token_z,
            num_transitions=conditioning_transition_layers,
        )

        self.atom_attention_encoder = AtomAttentionEncoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            token_z=token_z,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_feature_dim=atom_feature_dim,
            atom_encoder_depth=atom_encoder_depth,
            atom_encoder_heads=atom_encoder_heads,
            structure_prediction=True,
            activation_checkpointing=activation_checkpointing,
        )

        # Projection of the conditioned single representation onto the
        # token activations; its weight goes through ``init.final_init_``.
        s_norm = nn.LayerNorm(token_dim)
        s_proj = LinearNoBias(token_dim, token_dim)
        self.s_to_a_linear = nn.Sequential(s_norm, s_proj)
        init.final_init_(self.s_to_a_linear[1].weight)

        self.token_transformer = DiffusionTransformer(
            dim=token_dim,
            dim_single_cond=token_dim,
            dim_pairwise=token_z,
            depth=token_transformer_depth,
            heads=token_transformer_heads,
            activation_checkpointing=activation_checkpointing,
            offload_to_cpu=offload_to_cpu,
        )

        self.a_norm = nn.LayerNorm(token_dim)

        self.atom_attention_decoder = AtomAttentionDecoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            atom_decoder_depth=atom_decoder_depth,
            atom_decoder_heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        s_inputs,
        s_trunk,
        z_trunk,
        r_noisy,
        times,
        relative_position_encoding,
        feats,
        multiplicity=1,
        model_cache=None,
    ):
        """Predict a coordinate update for noisy atom positions.

        Parameters
        ----------
        s_inputs, s_trunk
            Single representations; both are repeated ``multiplicity`` times
            along dim 0 before conditioning.
        z_trunk
            Pair representation passed to the pairwise conditioner.
        r_noisy
            Noisy atom coordinates handed to the atom encoder.
        times
            Noise-level input of the single conditioner.
        relative_position_encoding
            Relative position features for the pairwise conditioner.
        feats : dict
            Feature dictionary; ``feats["token_pad_mask"]`` is read here.
        multiplicity : int, optional
            Number of samples per input, by default 1.
        model_cache : dict or None, optional
            When it is a non-empty dict the pairwise conditioning is skipped
            and ``z=None`` is forwarded to the sub-networks together with the
            cache.

        Returns
        -------
        dict
            ``"r_update"`` (decoder output) and ``"token_a"`` (detached,
            normalised token activations).

        """
        s, _ = self.single_conditioner(
            times=times,
            s_trunk=s_trunk.repeat_interleave(multiplicity, 0),
            s_inputs=s_inputs.repeat_interleave(multiplicity, 0),
        )

        # Pairwise conditioning is only computed while the cache is absent
        # or still empty.
        cache_populated = model_cache is not None and len(model_cache) != 0
        if cache_populated:
            z = None
        else:
            z = self.pairwise_conditioner(
                z_trunk=z_trunk, token_rel_pos_feats=relative_position_encoding
            )

        # Atom attention encoder and aggregation to coarse-grained tokens.
        encoder_out = self.atom_attention_encoder(
            feats=feats,
            s_trunk=s_trunk,
            z=z,
            r=r_noisy,
            multiplicity=multiplicity,
            model_cache=model_cache,
        )
        a, q_skip, c_skip, p_skip, to_keys = encoder_out

        # Full self-attention on token level.
        a = a + self.s_to_a_linear(s)

        token_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        a = self.token_transformer(
            a,
            mask=token_mask.float(),
            s=s,
            z=z,  # z is not expanded with multiplicity until after bias is computed
            multiplicity=multiplicity,
            model_cache=model_cache,
        )
        a = self.a_norm(a)

        # Broadcast token activations to atoms and run sequence-local
        # atom attention.
        r_update = self.atom_attention_decoder(
            a=a,
            q=q_skip,
            c=c_skip,
            p=p_skip,
            feats=feats,
            multiplicity=multiplicity,
            to_keys=to_keys,
            model_cache=model_cache,
        )

        return dict(r_update=r_update, token_a=a.detach())
|
230
|
+
|
231
|
+
|
232
|
+
class OutTokenFeatUpdate(Module):
    """Residual update of an accumulated token representation."""

    def __init__(
        self,
        sigma_data: float,
        token_s=384,
        dim_fourier=256,
    ):
        """Initialize the Output token feature update for confidence model.

        Parameters
        ----------
        sigma_data : float
            The standard deviation of the data distribution.
        token_s : int, optional
            The token dimension, by default 384.
        dim_fourier : int, optional
            The dimension of the fourier embedding, by default 256.

        """
        super().__init__()
        self.sigma_data = sigma_data

        # Token activations handled here are 2 * token_s wide.
        token_dim = 2 * token_s

        self.norm_next = nn.LayerNorm(token_dim)
        self.fourier_embed = FourierEmbedding(dim_fourier)
        self.norm_fourier = nn.LayerNorm(dim_fourier)
        # Conditioning = accumulated representation + fourier features.
        self.transition_block = ConditionedTransitionBlock(
            token_dim, token_dim + dim_fourier
        )

    def forward(
        self,
        times,
        acc_a,
        next_a,
    ):
        """Fold ``next_a`` into the running representation ``acc_a``.

        Parameters
        ----------
        times
            Noise-level input of the fourier embedding.
        acc_a
            Representation accumulated so far.
        next_a
            New token activations; layer-normalised before use.

        Returns
        -------
        ``acc_a`` plus the output of the conditioned transition block.

        """
        next_a = self.norm_next(next_a)

        # Embed the noise level and repeat it along dim 1 to match next_a.
        time_feats = self.norm_fourier(self.fourier_embed(times))
        time_feats = time_feats.unsqueeze(1).expand(-1, next_a.shape[1], -1)

        conditioning = torch.cat((acc_a, time_feats), dim=-1)
        return acc_a + self.transition_block(next_a, conditioning)
|
282
|
+
|
283
|
+
|
284
|
+
class AtomDiffusion(Module):
|
285
|
+
"""Atom diffusion module"""
|
286
|
+
|
287
|
+
def __init__(
    self,
    score_model_args,
    num_sampling_steps=5,
    sigma_min=0.0004,
    sigma_max=160.0,
    sigma_data=16.0,
    rho=7,
    P_mean=-1.2,
    P_std=1.5,
    gamma_0=0.8,
    gamma_min=1.0,
    noise_scale=1.003,
    step_scale=1.5,
    coordinate_augmentation=True,
    compile_score=False,
    alignment_reverse_diff=False,
    synchronize_sigmas=False,
    use_inference_model_cache=False,
    accumulate_token_repr=False,
    **kwargs,
):
    """Initialize the atom diffusion module.

    Parameters
    ----------
    score_model_args : dict
        The arguments for the score model (``DiffusionModule``). Must
        contain ``"token_s"``; ``"dim_fourier"`` is optional.
    num_sampling_steps : int, optional
        The number of sampling steps, by default 5.
    sigma_min : float, optional
        The minimum sigma value, by default 0.0004.
    sigma_max : float, optional
        The maximum sigma value, by default 160.0.
    sigma_data : float, optional
        The standard deviation of the data distribution, by default 16.0.
    rho : int, optional
        The rho value, by default 7.
    P_mean : float, optional
        The mean value of P, by default -1.2.
    P_std : float, optional
        The standard deviation of P, by default 1.5.
    gamma_0 : float, optional
        The gamma value, by default 0.8.
    gamma_min : float, optional
        Sigma threshold used by ``sample``: ``gamma_0`` is applied only to
        schedule entries whose sigma exceeds this value, by default 1.0.
    noise_scale : float, optional
        The noise scale, by default 1.003.
    step_scale : float, optional
        The step scale, by default 1.5.
    coordinate_augmentation : bool, optional
        Whether to use coordinate augmentation, by default True.
    compile_score : bool, optional
        Whether to compile the score model, by default False.
    alignment_reverse_diff : bool, optional
        Whether to use alignment reverse diff, by default False.
    synchronize_sigmas : bool, optional
        Whether to synchronize the sigmas, by default False.
    use_inference_model_cache : bool, optional
        Whether to use the inference model cache, by default False.
    accumulate_token_repr : bool, optional
        Whether to accumulate the token representation, by default False.
    **kwargs
        Accepted and ignored by this constructor.

    """
    super().__init__()
    self.score_model = DiffusionModule(
        **score_model_args,
    )
    if compile_score:
        self.score_model = torch.compile(
            self.score_model, dynamic=False, fullgraph=False
        )

    # parameters
    self.sigma_min = sigma_min
    self.sigma_max = sigma_max
    self.sigma_data = sigma_data
    self.rho = rho
    self.P_mean = P_mean
    self.P_std = P_std
    self.num_sampling_steps = num_sampling_steps
    self.gamma_0 = gamma_0
    self.gamma_min = gamma_min
    self.noise_scale = noise_scale
    self.step_scale = step_scale
    self.coordinate_augmentation = coordinate_augmentation
    self.alignment_reverse_diff = alignment_reverse_diff
    self.synchronize_sigmas = synchronize_sigmas
    self.use_inference_model_cache = use_inference_model_cache

    self.accumulate_token_repr = accumulate_token_repr
    self.token_s = score_model_args["token_s"]
    if self.accumulate_token_repr:
        token_update_kwargs = {
            "sigma_data": sigma_data,
            "token_s": self.token_s,
        }
        # ``dim_fourier`` is optional for DiffusionModule (it has a default),
        # so a config that omits it must not fail here with a KeyError.
        # When absent, OutTokenFeatUpdate's own default is used, which is
        # the same value as DiffusionModule's default.
        if "dim_fourier" in score_model_args:
            token_update_kwargs["dim_fourier"] = score_model_args["dim_fourier"]
        self.out_token_feat_update = OutTokenFeatUpdate(**token_update_kwargs)

    # Non-persistent buffer: follows the module across devices but is not
    # written to the state dict.
    self.register_buffer("zero", torch.tensor(0.0), persistent=False)
|
387
|
+
|
388
|
+
@property
def device(self):
    """Device of the first parameter yielded by the score model."""
    first_param = next(iter(self.score_model.parameters()))
    return first_param.device
|
391
|
+
|
392
|
+
def c_skip(self, sigma):
    """Return ``sigma_data**2 / (sigma**2 + sigma_data**2)``.

    This is the weight applied to the noisy coordinates in
    ``preconditioned_network_forward``.
    """
    data_variance = self.sigma_data**2
    return data_variance / (sigma**2 + data_variance)
|
394
|
+
|
395
|
+
def c_out(self, sigma):
    """Return ``sigma * sigma_data / sqrt(sigma_data**2 + sigma**2)``.

    This is the weight applied to the network output in
    ``preconditioned_network_forward``. ``sigma`` must be a tensor because
    ``torch.sqrt`` is applied to an expression built from it.
    """
    numerator = sigma * self.sigma_data
    denominator = torch.sqrt(self.sigma_data**2 + sigma**2)
    return numerator / denominator
|
397
|
+
|
398
|
+
def c_in(self, sigma):
    """Return ``1 / sqrt(sigma**2 + sigma_data**2)``.

    This is the scale applied to the noisy coordinates before they enter
    the score model. ``sigma`` must be a tensor (``torch.sqrt`` is used).
    """
    total_variance = sigma**2 + self.sigma_data**2
    return 1 / torch.sqrt(total_variance)
|
400
|
+
|
401
|
+
def c_noise(self, sigma):
    """Return ``0.25 * log(sigma / sigma_data)``.

    The result is what ``preconditioned_network_forward`` passes to the
    score model as ``times``.
    """
    relative_sigma = sigma / self.sigma_data
    return log(relative_sigma) * 0.25
|
403
|
+
|
404
|
+
def preconditioned_network_forward(
    self,
    noised_atom_coords,
    sigma,
    network_condition_kwargs: dict,
    training: bool = True,
):
    """Run the score model with input/output preconditioning.

    Parameters
    ----------
    noised_atom_coords
        Noisy coordinates; dim 0 is treated as the batch dimension.
    sigma : float or 1-D tensor
        Noise level. A Python float is expanded to one value per batch
        element on the device of ``noised_atom_coords``.
    network_condition_kwargs : dict
        Extra keyword arguments forwarded unchanged to the score model.
    training : bool, optional
        Not used inside this method.

    Returns
    -------
    tuple
        ``(denoised_coords, token_a)`` where ``denoised_coords`` is
        ``c_skip * noised + c_out * r_update`` and ``token_a`` is taken
        from the score model output.

    """
    batch = noised_atom_coords.shape[0]
    device = noised_atom_coords.device

    if isinstance(sigma, float):
        sigma = torch.full((batch,), sigma, device=device)

    # Give sigma two trailing singleton dims so it broadcasts over the
    # coordinate tensor.
    sigma_b11 = rearrange(sigma, "b -> b 1 1")

    scaled_input = self.c_in(sigma_b11) * noised_atom_coords
    noise_level = self.c_noise(sigma)
    net_out = self.score_model(
        r_noisy=scaled_input,
        times=noise_level,
        **network_condition_kwargs,
    )

    skip_term = self.c_skip(sigma_b11) * noised_atom_coords
    update_term = self.c_out(sigma_b11) * net_out["r_update"]
    denoised_coords = skip_term + update_term
    return denoised_coords, net_out["token_a"]
|
429
|
+
|
430
|
+
def sample_schedule(self, num_sampling_steps=None):
    """Build the decreasing sigma schedule used for sampling.

    The schedule interpolates between ``sigma_max`` and ``sigma_min`` in
    ``sigma ** (1 / rho)`` space, is scaled by ``sigma_data`` and gets a
    trailing ``0.0`` appended.

    Parameters
    ----------
    num_sampling_steps : int, optional
        Number of steps; falls back to ``self.num_sampling_steps``.

    Returns
    -------
    torch.Tensor
        1-D float32 tensor with ``num_sampling_steps + 1`` entries on
        ``self.device``.

    """
    num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
    inv_rho = 1 / self.rho

    steps = torch.arange(
        num_sampling_steps, device=self.device, dtype=torch.float32
    )

    # With a single step the interpolation denominator would be zero and
    # 0 / 0 would fill the schedule with NaN. Clamp it so that one step
    # simply starts at sigma_max; schedules with >= 2 steps are unchanged.
    num_intervals = max(num_sampling_steps - 1, 1)

    sigmas = (
        self.sigma_max**inv_rho
        + steps
        / num_intervals
        * (self.sigma_min**inv_rho - self.sigma_max**inv_rho)
    ) ** self.rho

    sigmas = sigmas * self.sigma_data

    sigmas = F.pad(sigmas, (0, 1), value=0.0)  # last step is sigma value of 0.
    return sigmas
|
448
|
+
|
449
|
+
def sample(
    self,
    atom_mask,
    num_sampling_steps=None,
    multiplicity=1,
    max_parallel_samples=None,
    train_accumulate_token_repr=False,
    steering_args=None,
    **network_condition_kwargs,
):
    """Sample atom coordinates by iteratively denoising Gaussian noise.

    Parameters
    ----------
    atom_mask
        Atom padding mask; repeated ``multiplicity`` times along dim 0 and
        used to derive the coordinate tensor shape ``(*mask.shape, 3)``.
    num_sampling_steps : int, optional
        Number of denoising steps; falls back to ``self.num_sampling_steps``.
    multiplicity : int, optional
        Number of samples per input. Multiplied by
        ``steering_args["num_particles"]`` when FK steering is enabled.
    max_parallel_samples : int, optional
        Upper bound on how many samples go through the score model in a
        single call. ``None`` (default) runs all samples in one call.
    train_accumulate_token_repr : bool, optional
        Whether gradients are enabled while accumulating the token
        representation.
    steering_args : dict, optional
        Steering configuration. Keys read here: ``fk_steering``,
        ``guidance_update``, ``num_particles``, ``fk_resampling_interval``,
        ``fk_lambda`` and ``num_gd_steps``.
    **network_condition_kwargs
        Forwarded to the score model. Must contain ``feats`` with a
        ``token_index`` entry.

    Returns
    -------
    dict
        ``sample_atom_coords`` (final coordinates) and ``diff_token_repr``
        (accumulated token representation, ``None`` unless
        ``accumulate_token_repr`` is set).

    """
    fk_steering = steering_args is not None and steering_args["fk_steering"]
    guidance_update_on = (
        steering_args is not None and steering_args["guidance_update"]
    )

    if fk_steering or guidance_update_on:
        potentials = get_potentials()
    if fk_steering:
        multiplicity = multiplicity * steering_args["num_particles"]
        energy_traj = torch.empty((multiplicity, 0), device=self.device)
        resample_weights = torch.ones(multiplicity, device=self.device).reshape(
            -1, steering_args["num_particles"]
        )
    if guidance_update_on:
        scaled_guidance_update = torch.zeros(
            (multiplicity, *atom_mask.shape[1:], 3),
            dtype=torch.float32,
            device=self.device,
        )

    # The default of None previously crashed on ``multiplicity % None``.
    # Treat it as "no limit", i.e. evaluate all samples in one call.
    if max_parallel_samples is None:
        max_parallel_samples = multiplicity
    # Ceiling division so that no chunk exceeds max_parallel_samples
    # (Tensor.chunk yields chunks of size ceil(total / num_chunks)).
    num_chunks = max(
        1, (multiplicity + max_parallel_samples - 1) // max_parallel_samples
    )

    num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
    atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

    shape = (*atom_mask.shape, 3)
    token_repr_shape = (
        multiplicity,
        network_condition_kwargs["feats"]["token_index"].shape[1],
        2 * self.token_s,
    )

    # Pair every sigma with its successor and the successor's gamma.
    sigmas = self.sample_schedule(num_sampling_steps)
    gammas = torch.where(sigmas > self.gamma_min, self.gamma_0, 0.0)
    sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[1:]))

    # atom position is noise at the beginning
    init_sigma = sigmas[0]
    atom_coords = init_sigma * torch.randn(shape, device=self.device)
    atom_coords_denoised = None
    model_cache = {} if self.use_inference_model_cache else None

    token_repr = None
    token_a = None

    # gradually denoise
    for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(sigmas_and_gammas):
        # Re-center and apply a fresh random rigid transform each step; the
        # same transform is applied to every tensor carried across steps.
        random_R, random_tr = compute_random_augmentation(
            multiplicity, device=atom_coords.device, dtype=atom_coords.dtype
        )
        atom_coords = atom_coords - atom_coords.mean(dim=-2, keepdims=True)
        atom_coords = (
            torch.einsum("bmd,bds->bms", atom_coords, random_R) + random_tr
        )
        if atom_coords_denoised is not None:
            atom_coords_denoised -= atom_coords_denoised.mean(dim=-2, keepdims=True)
            atom_coords_denoised = (
                torch.einsum("bmd,bds->bms", atom_coords_denoised, random_R)
                + random_tr
            )
        if guidance_update_on and scaled_guidance_update is not None:
            # Displacement, not a position: rotate only, no translation.
            scaled_guidance_update = torch.einsum(
                "bmd,bds->bms", scaled_guidance_update, random_R
            )

        sigma_tm, sigma_t, gamma = sigma_tm.item(), sigma_t.item(), gamma.item()

        # Raise the noise level from sigma_tm to t_hat before denoising.
        t_hat = sigma_tm * (1 + gamma)
        steering_t = 1.0 - (step_idx / num_sampling_steps)
        noise_var = self.noise_scale**2 * (t_hat**2 - sigma_tm**2)
        eps = sqrt(noise_var) * torch.randn(shape, device=self.device)
        atom_coords_noisy = atom_coords + eps

        is_last_step = step_idx == num_sampling_steps - 1
        # Resampling condition is evaluated once; nothing it depends on
        # changes before it is used a second time below.
        do_fk_resampling = fk_steering and (
            (
                step_idx % steering_args["fk_resampling_interval"] == 0
                and noise_var > 0
            )
            or is_last_step
        )

        with torch.no_grad():
            atom_coords_denoised = torch.zeros_like(atom_coords_noisy)
            token_a = torch.zeros(token_repr_shape).to(atom_coords_noisy)

            # Evaluate the score model chunk by chunk to bound memory use.
            sample_ids = torch.arange(multiplicity).to(atom_coords_noisy.device)
            sample_ids_chunks = sample_ids.chunk(num_chunks)
            for sample_ids_chunk in sample_ids_chunks:
                atom_coords_denoised_chunk, token_a_chunk = (
                    self.preconditioned_network_forward(
                        atom_coords_noisy[sample_ids_chunk],
                        t_hat,
                        training=False,
                        network_condition_kwargs=dict(
                            multiplicity=sample_ids_chunk.numel(),
                            model_cache=model_cache,
                            **network_condition_kwargs,
                        ),
                    )
                )
                atom_coords_denoised[sample_ids_chunk] = atom_coords_denoised_chunk
                token_a[sample_ids_chunk] = token_a_chunk

            if do_fk_resampling:
                # Compute energy of x_0 prediction
                energy = torch.zeros(multiplicity, device=self.device)
                for potential in potentials:
                    parameters = potential.compute_parameters(steering_t)
                    if parameters["resampling_weight"] > 0:
                        component_energy = potential.compute(
                            atom_coords_denoised,
                            network_condition_kwargs["feats"],
                            parameters,
                        )
                        energy += parameters["resampling_weight"] * component_energy
                energy_traj = torch.cat((energy_traj, energy.unsqueeze(1)), dim=1)

                # Compute log G values
                if step_idx == 0:
                    log_G = -1 * energy
                else:
                    log_G = energy_traj[:, -2] - energy_traj[:, -1]

                # Compute ll difference between guided and unguided
                # transition distribution
                if guidance_update_on and noise_var > 0:
                    ll_difference = (
                        eps**2 - (eps + scaled_guidance_update) ** 2
                    ).sum(dim=(-1, -2)) / (2 * noise_var)
                else:
                    ll_difference = torch.zeros_like(energy)

                # Compute resampling weights
                resample_weights = F.softmax(
                    (ll_difference + steering_args["fk_lambda"] * log_G).reshape(
                        -1, steering_args["num_particles"]
                    ),
                    dim=1,
                )

            # Compute guidance update to x_0 prediction
            if guidance_update_on and step_idx < num_sampling_steps - 1:
                guidance_update = torch.zeros_like(atom_coords_denoised)
                for guidance_step in range(steering_args["num_gd_steps"]):
                    energy_gradient = torch.zeros_like(atom_coords_denoised)
                    for potential in potentials:
                        parameters = potential.compute_parameters(steering_t)
                        if (
                            parameters["guidance_weight"] > 0
                            and guidance_step % parameters["guidance_interval"] == 0
                        ):
                            energy_gradient += parameters[
                                "guidance_weight"
                            ] * potential.compute_gradient(
                                atom_coords_denoised + guidance_update,
                                network_condition_kwargs["feats"],
                                parameters,
                            )
                    guidance_update -= energy_gradient
                atom_coords_denoised += guidance_update
                scaled_guidance_update = (
                    guidance_update
                    * -1
                    * self.step_scale
                    * (sigma_t - t_hat)
                    / t_hat
                )

            if do_fk_resampling:
                # Draw particle indices per group (a single one per group on
                # the last step) and shift them to global sample indices.
                resample_indices = (
                    torch.multinomial(
                        resample_weights,
                        resample_weights.shape[1]
                        if step_idx < num_sampling_steps - 1
                        else 1,
                        replacement=True,
                    )
                    + resample_weights.shape[1]
                    * torch.arange(
                        resample_weights.shape[0], device=resample_weights.device
                    ).unsqueeze(-1)
                ).flatten()

                atom_coords = atom_coords[resample_indices]
                atom_coords_noisy = atom_coords_noisy[resample_indices]
                atom_mask = atom_mask[resample_indices]
                if atom_coords_denoised is not None:
                    atom_coords_denoised = atom_coords_denoised[resample_indices]
                energy_traj = energy_traj[resample_indices]
                if guidance_update_on:
                    scaled_guidance_update = scaled_guidance_update[
                        resample_indices
                    ]
                if token_repr is not None:
                    token_repr = token_repr[resample_indices]
                if token_a is not None:
                    token_a = token_a[resample_indices]

        if self.accumulate_token_repr:
            if token_repr is None:
                token_repr = torch.zeros_like(token_a)

            with torch.set_grad_enabled(train_accumulate_token_repr):
                sigma = torch.full(
                    (atom_coords_denoised.shape[0],),
                    t_hat,
                    device=atom_coords_denoised.device,
                )
                token_repr = self.out_token_feat_update(
                    times=self.c_noise(sigma), acc_a=token_repr, next_a=token_a
                )

        if self.alignment_reverse_diff:
            # Alignment runs in float32 with autocast disabled.
            with torch.autocast("cuda", enabled=False):
                atom_coords_noisy = weighted_rigid_align(
                    atom_coords_noisy.float(),
                    atom_coords_denoised.float(),
                    atom_mask.float(),
                    atom_mask.float(),
                )

            atom_coords_noisy = atom_coords_noisy.to(atom_coords_denoised)

        # Step from noise level t_hat towards sigma_t.
        denoised_over_sigma = (atom_coords_noisy - atom_coords_denoised) / t_hat
        atom_coords_next = (
            atom_coords_noisy
            + self.step_scale * (sigma_t - t_hat) * denoised_over_sigma
        )

        atom_coords = atom_coords_next

    return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr)
|
690
|
+
|
691
|
+
def loss_weight(self, sigma):
    """Return ``(sigma**2 + sigma_data**2) / (sigma * sigma_data)**2``."""
    numerator = sigma**2 + self.sigma_data**2
    denominator = (sigma * self.sigma_data) ** 2
    return numerator / denominator
|
693
|
+
|
694
|
+
def noise_distribution(self, batch_size):
    """Draw ``batch_size`` noise levels.

    Each value is ``sigma_data * exp(P_mean + P_std * z)`` with ``z`` a
    standard normal sample created on ``self.device``.
    """
    standard_normal = torch.randn((batch_size,), device=self.device)
    log_scale = self.P_mean + self.P_std * standard_normal
    return self.sigma_data * log_scale.exp()
|
702
|
+
|
703
|
+
def forward(
    self,
    s_inputs,
    s_trunk,
    z_trunk,
    relative_position_encoding,
    feats,
    multiplicity=1,
):
    """Run one training-time diffusion step.

    Draws a noise level per diffusion sample, augments and noises the
    ground-truth coordinates, and asks the preconditioned network to denoise
    them.

    Args:
        s_inputs, s_trunk, z_trunk, relative_position_encoding: Conditioning
            inputs forwarded unchanged to ``preconditioned_network_forward``.
        feats: Feature dict; ``"coords"`` and ``"atom_pad_mask"`` are read.
            NOTE: ``feats["coords"]`` is overwritten in place with the
            flattened, repeated coordinates (``compute_loss`` reads it back).
        multiplicity: Number of diffusion samples per batch element.

    Returns:
        dict with ``noised_atom_coords``, ``denoised_atom_coords``,
        ``sigmas`` and ``aligned_true_atom_coords`` (the augmented
        ground-truth coordinates before noise was added).
    """
    # training diffusion step
    n_batch = feats["coords"].shape[0]

    # Synchronised: one sigma per batch element shared by all of its
    # samples; otherwise an independent sigma per sample.
    synchronized = self.synchronize_sigmas
    n_draws = n_batch if synchronized else n_batch * multiplicity
    sigmas = self.noise_distribution(n_draws)
    if synchronized:
        sigmas = sigmas.repeat_interleave(multiplicity, 0)
    padded_sigmas = rearrange(sigmas, "b -> b 1 1")

    # Fold the first two axes together, then repeat along the new leading
    # axis. (presumably the axes are batch / conformer / atom -- TODO confirm)
    raw_coords = feats["coords"]
    dim0, dim1, n_atoms = raw_coords.shape[0:3]
    flat_coords = raw_coords.reshape(dim0 * dim1, n_atoms, 3)
    true_coords = flat_coords.repeat_interleave(multiplicity // dim1, 0)
    feats["coords"] = true_coords

    pad_mask = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0)

    true_coords = center_random_augmentation(
        true_coords, pad_mask, augmentation=self.coordinate_augmentation
    )

    # Perturb the augmented coordinates with Gaussian noise scaled by sigma.
    gaussian = torch.randn_like(true_coords)
    noisy_coords = true_coords + padded_sigmas * gaussian

    conditioning = dict(
        s_inputs=s_inputs,
        s_trunk=s_trunk,
        z_trunk=z_trunk,
        relative_position_encoding=relative_position_encoding,
        feats=feats,
        multiplicity=multiplicity,
    )
    # The second element of the returned pair is not needed here.
    denoised_coords, _ = self.preconditioned_network_forward(
        noisy_coords,
        sigmas,
        training=True,
        network_condition_kwargs=conditioning,
    )

    return {
        "noised_atom_coords": noisy_coords,
        "denoised_atom_coords": denoised_coords,
        "sigmas": sigmas,
        "aligned_true_atom_coords": true_coords,
    }
|
759
|
+
|
760
|
+
def compute_loss(
    self,
    feats,
    out_dict,
    add_smooth_lddt_loss=True,
    nucleotide_loss_weight=5.0,
    ligand_loss_weight=10.0,
    multiplicity=1,
):
    """Compute the diffusion training loss from the output of ``forward``.

    The loss is a per-atom weighted MSE between the denoised coordinates and
    the rigidly aligned ground truth, scaled by ``loss_weight(sigmas)``,
    plus an optional smooth-lDDT term.

    Args:
        feats: Feature dict; reads ``"atom_resolved_mask"``,
            ``"atom_to_token"``, ``"mol_type"`` and ``"coords"``.
        out_dict: Dict holding ``"denoised_atom_coords"``,
            ``"noised_atom_coords"``, ``"sigmas"`` and
            ``"aligned_true_atom_coords"``.
        add_smooth_lddt_loss: Whether to add the smooth-lDDT term.
        nucleotide_loss_weight: Extra weight added for DNA/RNA atoms.
        ligand_loss_weight: Extra weight added for NONPOLYMER atoms.
        multiplicity: Number of diffusion samples per batch element.

    Returns:
        dict with ``loss`` (the total) and ``loss_breakdown`` (containing
        ``mse_loss`` and ``smooth_lddt_loss``).
    """
    pred_coords = out_dict["denoised_atom_coords"]
    noisy_coords = out_dict["noised_atom_coords"]
    sigmas = out_dict["sigmas"]

    resolved = feats["atom_resolved_mask"].repeat_interleave(multiplicity, 0)

    base_weights = noisy_coords.new_ones(noisy_coords.shape[:2])

    # Carry the per-token molecule type over to atoms.
    # (assumes atom_to_token is a one-hot assignment -- TODO confirm)
    token_types = feats["mol_type"].unsqueeze(-1).float()
    atom_type = torch.bmm(feats["atom_to_token"].float(), token_types)
    atom_type = atom_type.squeeze(-1).long()
    atom_type_mult = atom_type.repeat_interleave(multiplicity, 0)

    # Every atom has weight 1; nucleotide and ligand atoms get extra weight.
    nucleotide_atoms = (
        torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).float()
        + torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).float()
    )
    ligand_atoms = torch.eq(
        atom_type_mult, const.chain_type_ids["NONPOLYMER"]
    ).float()
    align_weights = base_weights * (
        1
        + nucleotide_loss_weight * nucleotide_atoms
        + ligand_loss_weight * ligand_atoms
    )

    # Align the ground truth onto the prediction in full precision and
    # without tracking gradients.
    with torch.no_grad(), torch.autocast("cuda", enabled=False):
        true_coords = out_dict["aligned_true_atom_coords"]
        aligned_truth = weighted_rigid_align(
            true_coords.detach().float(),
            pred_coords.detach().float(),
            align_weights.detach().float(),
            mask=resolved.detach().float(),
        )

    # Cast back
    aligned_truth = aligned_truth.to(pred_coords)

    # weighted MSE loss of denoised atom positions
    squared_error = ((pred_coords - aligned_truth) ** 2).sum(dim=-1)
    weighted_error = torch.sum(squared_error * align_weights * resolved, dim=-1)
    normaliser = torch.sum(3 * align_weights * resolved, dim=-1)
    mse_loss = weighted_error / normaliser

    # weight by sigma factor
    sigma_weights = self.loss_weight(sigmas)
    mse_loss = (mse_loss * sigma_weights).mean()

    total_loss = mse_loss

    # proposed auxiliary smooth lddt loss
    lddt_loss = self.zero
    if add_smooth_lddt_loss:
        # This mask uses the un-repeated atom types; multiplicity is passed
        # to smooth_lddt_loss separately.
        nucleotide_flags = (
            torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
            + torch.eq(atom_type, const.chain_type_ids["RNA"]).float()
        )
        lddt_loss = smooth_lddt_loss(
            pred_coords,
            feats["coords"],
            nucleotide_flags,
            coords_mask=feats["atom_resolved_mask"],
            multiplicity=multiplicity,
        )
        total_loss = total_loss + lddt_loss

    loss_breakdown = {
        "mse_loss": mse_loss,
        "smooth_lddt_loss": lddt_loss,
    }

    return {"loss": total_loss, "loss_breakdown": loss_breakdown}
|