boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,677 @@
|
|
1
|
+
# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from math import sqrt
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
import torch
|
9
|
+
import torch.nn.functional as F # noqa: N812
|
10
|
+
from einops import rearrange
|
11
|
+
from torch import nn
|
12
|
+
from torch.nn import Module
|
13
|
+
|
14
|
+
import boltz.model.layers.initialize as init
|
15
|
+
from boltz.data import const
|
16
|
+
from boltz.model.loss.diffusionv2 import (
|
17
|
+
smooth_lddt_loss,
|
18
|
+
weighted_rigid_align,
|
19
|
+
)
|
20
|
+
from boltz.model.modules.encodersv2 import (
|
21
|
+
AtomAttentionDecoder,
|
22
|
+
AtomAttentionEncoder,
|
23
|
+
SingleConditioning,
|
24
|
+
)
|
25
|
+
from boltz.model.modules.transformersv2 import (
|
26
|
+
DiffusionTransformer,
|
27
|
+
)
|
28
|
+
from boltz.model.modules.utils import (
|
29
|
+
LinearNoBias,
|
30
|
+
center_random_augmentation,
|
31
|
+
compute_random_augmentation,
|
32
|
+
default,
|
33
|
+
log,
|
34
|
+
)
|
35
|
+
from boltz.model.potentials.potentials import get_potentials
|
36
|
+
|
37
|
+
|
38
|
+
class DiffusionModule(Module):
    """Score network predicting an atom-coordinate update from noisy coordinates.

    The forward pass chains four stages:
      1. ``single_conditioner`` mixes the noise level (``times``) into the
         single (token) representation,
      2. ``atom_attention_encoder`` runs windowed atom attention over the
         noisy coordinates and aggregates atoms into tokens,
      3. ``token_transformer`` performs full attention at the token level,
      4. ``atom_attention_decoder`` broadcasts tokens back to atoms and emits
         the per-atom update.

    The token-level width used by stages 2-4 is ``2 * token_s``.
    """

    def __init__(
        self,
        token_s: int,
        atom_s: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        sigma_data: int = 16,
        dim_fourier: int = 256,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        conditioning_transition_layers: int = 2,
        activation_checkpointing: bool = False,
        transformer_post_ln: bool = False,
    ) -> None:
        """Create the conditioning, encoder, token transformer and decoder.

        Args:
            token_s: Single (token) representation width.
            atom_s: Atom representation width.
            atoms_per_window_queries: Query window size for atom attention.
            atoms_per_window_keys: Key window size for atom attention.
            sigma_data: Data standard deviation forwarded to the conditioner.
            dim_fourier: Fourier-embedding width of the noise conditioning.
            atom_encoder_depth: Number of atom encoder layers.
            atom_encoder_heads: Number of atom encoder attention heads.
            token_transformer_depth: Number of token transformer layers.
            token_transformer_heads: Number of token transformer heads.
            atom_decoder_depth: Number of atom decoder layers.
            atom_decoder_heads: Number of atom decoder attention heads.
            conditioning_transition_layers: Transitions in the conditioner.
            activation_checkpointing: Checkpoint activations while training.
            transformer_post_ln: Currently only forwarded to the atom
                attention encoder; the token transformer / decoder variants
                are commented out.
        """
        super().__init__()

        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.sigma_data = sigma_data
        self.activation_checkpointing = activation_checkpointing

        token_dim = 2 * token_s

        # NOTE: submodules are created in a fixed order; changing it would
        # change parameter registration order and random initialization.
        self.single_conditioner = SingleConditioning(
            sigma_data=sigma_data,
            token_s=token_s,
            dim_fourier=dim_fourier,
            num_transitions=conditioning_transition_layers,
        )

        self.atom_attention_encoder = AtomAttentionEncoder(
            atom_s=atom_s,
            token_s=token_s,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_encoder_depth=atom_encoder_depth,
            atom_encoder_heads=atom_encoder_heads,
            structure_prediction=True,
            activation_checkpointing=activation_checkpointing,
            transformer_post_layer_norm=transformer_post_ln,
        )

        # Projects the conditioned single representation onto the token
        # activations; the linear layer gets the project's "final" init.
        self.s_to_a_linear = nn.Sequential(
            nn.LayerNorm(token_dim),
            LinearNoBias(token_dim, token_dim),
        )
        init.final_init_(self.s_to_a_linear[1].weight)

        self.token_transformer = DiffusionTransformer(
            dim=token_dim,
            dim_single_cond=token_dim,
            depth=token_transformer_depth,
            heads=token_transformer_heads,
            activation_checkpointing=activation_checkpointing,
            # post_layer_norm=transformer_post_ln,
        )

        self.a_norm = nn.LayerNorm(token_dim)  # if not transformer_post_ln else nn.Identity()

        self.atom_attention_decoder = AtomAttentionDecoder(
            atom_s=atom_s,
            token_s=token_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            atom_decoder_depth=atom_decoder_depth,
            atom_decoder_heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
            # transformer_post_layer_norm=transformer_post_ln,
        )

    def forward(
        self,
        s_inputs,  # Float['b n ts']
        s_trunk,  # Float['b n ts']
        r_noisy,  # Float['bm m 3']
        times,  # Float['bm 1 1']
        feats,
        diffusion_conditioning,
        multiplicity=1,
    ):
        """Compute the per-atom coordinate update for the given noisy input.

        Args:
            s_inputs: Input single representation, expanded here by
                ``multiplicity`` along dim 0.
            s_trunk: Trunk single representation, expanded likewise.
            r_noisy: Preconditioned noisy atom coordinates.
            times: Noise-level conditioning passed to the conditioner.
            feats: Feature dict; ``token_pad_mask`` is read here and the dict
                is also passed to the atom encoder/decoder.
            diffusion_conditioning: Dict with precomputed ``q``, ``c``,
                ``atom_enc_bias``, ``to_keys``, ``token_trans_bias`` and
                ``atom_dec_bias`` entries.
            multiplicity: Number of samples per input.

        Returns:
            The atom attention decoder output (the coordinate update).
        """
        s_trunk_rep = s_trunk.repeat_interleave(multiplicity, 0)
        s_inputs_rep = s_inputs.repeat_interleave(multiplicity, 0)

        checkpoint_conditioner = self.activation_checkpointing and self.training
        if checkpoint_conditioner:
            s, _normed_fourier = torch.utils.checkpoint.checkpoint(
                self.single_conditioner, times, s_trunk_rep, s_inputs_rep
            )
        else:
            s, _normed_fourier = self.single_conditioner(
                times, s_trunk_rep, s_inputs_rep
            )

        cond = diffusion_conditioning

        # Sequence-local atom attention, aggregated to coarse-grained tokens
        a, q_skip, c_skip, to_keys = self.atom_attention_encoder(
            feats=feats,
            q=cond["q"].float(),
            c=cond["c"].float(),
            atom_enc_bias=cond["atom_enc_bias"].float(),
            to_keys=cond["to_keys"],
            r=r_noisy,  # Float['b m 3'],
            multiplicity=multiplicity,
        )

        # Full self-attention at the token level
        a = a + self.s_to_a_linear(s)
        token_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        a = self.token_transformer(
            a,
            mask=token_mask.float(),
            s=s,
            # note z is not expanded with multiplicity until after bias is computed
            bias=cond["token_trans_bias"].float(),
            multiplicity=multiplicity,
        )
        a = self.a_norm(a)

        # Broadcast token activations back to atoms (sequence-local attention)
        return self.atom_attention_decoder(
            a=a,
            q=q_skip,
            c=c_skip,
            atom_dec_bias=cond["atom_dec_bias"].float(),
            feats=feats,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )
|
177
|
+
|
178
|
+
|
179
|
+
class AtomDiffusion(Module):
    """Diffusion process over atom coordinates wrapping a :class:`DiffusionModule`.

    Provides the noise schedule and preconditioning coefficients, the reverse
    sampling loop (:meth:`sample`, with optional particle resampling via
    ``fk_steering`` and potential-based ``guidance_update``), the
    training-time noising step (:meth:`forward`) and the weighted MSE /
    smooth-lDDT loss (:meth:`compute_loss`).
    """

    def __init__(
        self,
        score_model_args,
        num_sampling_steps: int = 5,  # number of sampling steps
        sigma_min: float = 0.0004,  # min noise level
        sigma_max: float = 160.0,  # max noise level
        sigma_data: float = 16.0,  # standard deviation of data distribution
        rho: float = 7,  # controls the sampling schedule
        P_mean: float = -1.2,  # mean of log-normal distribution from which noise is drawn for training
        P_std: float = 1.5,  # standard deviation of log-normal distribution from which noise is drawn for training
        gamma_0: float = 0.8,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
        step_scale_random: list = None,
        coordinate_augmentation: bool = True,
        coordinate_augmentation_inference=None,
        compile_score: bool = False,
        alignment_reverse_diff: bool = False,
        synchronize_sigmas: bool = False,
    ):
        """Build the score model and store the diffusion hyper-parameters.

        Args:
            score_model_args: Keyword arguments for :class:`DiffusionModule`;
                must contain ``token_s``.
            num_sampling_steps: Default number of reverse-diffusion steps.
            sigma_min: Smallest noise level of the schedule (before scaling
                by ``sigma_data``).
            sigma_max: Largest noise level of the schedule (before scaling).
            sigma_data: Data standard deviation used for preconditioning.
            rho: Exponent shaping the sampling schedule.
            P_mean: Mean of the log-normal training noise distribution.
            P_std: Std of the log-normal training noise distribution.
            gamma_0: Stochasticity factor used for a sampling step whose
                target sigma exceeds ``gamma_min`` (0 otherwise).
            gamma_min: Threshold on the (scaled) sigma for using ``gamma_0``.
            noise_scale: Multiplier on the std of the noise injected per step.
            step_scale: Multiplier on the step size of each sampling update.
            step_scale_random: If set, during training a step scale is drawn
                from this list with ``np.random.choice`` on each ``sample``.
            coordinate_augmentation: Augmentation flag used in :meth:`forward`.
            coordinate_augmentation_inference: Stored as an attribute; falls
                back to ``coordinate_augmentation`` when ``None``.
            compile_score: Wrap the score model with ``torch.compile``.
            alignment_reverse_diff: Rigidly align noisy coordinates onto the
                denoised prediction at every sampling step.
            synchronize_sigmas: In :meth:`forward`, share one sigma across the
                ``multiplicity`` copies of each input.
        """
        super().__init__()
        self.score_model = DiffusionModule(
            **score_model_args,
        )
        if compile_score:
            self.score_model = torch.compile(
                self.score_model, dynamic=False, fullgraph=False
            )

        # parameters
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.rho = rho
        self.P_mean = P_mean
        self.P_std = P_std
        self.num_sampling_steps = num_sampling_steps
        self.gamma_0 = gamma_0
        self.gamma_min = gamma_min
        self.noise_scale = noise_scale
        self.step_scale = step_scale
        self.step_scale_random = step_scale_random
        self.coordinate_augmentation = coordinate_augmentation
        self.coordinate_augmentation_inference = (
            coordinate_augmentation_inference
            if coordinate_augmentation_inference is not None
            else coordinate_augmentation
        )
        self.alignment_reverse_diff = alignment_reverse_diff
        self.synchronize_sigmas = synchronize_sigmas

        self.token_s = score_model_args["token_s"]
        # Non-persistent scalar used as the "no lDDT loss" placeholder.
        self.register_buffer("zero", torch.tensor(0.0), persistent=False)

    @property
    def device(self):
        """Device of the score model's first parameter."""
        return next(self.score_model.parameters()).device

    # EDM-style preconditioning coefficients (functions of the noise level).
    def c_skip(self, sigma):
        """Skip-connection weight ``sigma_data^2 / (sigma^2 + sigma_data^2)``."""
        return (self.sigma_data**2) / (sigma**2 + self.sigma_data**2)

    def c_out(self, sigma):
        """Output scale ``sigma * sigma_data / sqrt(sigma_data^2 + sigma^2)``."""
        return sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2)

    def c_in(self, sigma):
        """Input scale ``1 / sqrt(sigma^2 + sigma_data^2)``."""
        return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)

    def c_noise(self, sigma):
        """Noise conditioning ``0.25 * log(sigma / sigma_data)``."""
        return log(sigma / self.sigma_data) * 0.25

    def preconditioned_network_forward(
        self,
        noised_atom_coords,  #: Float['b m 3'],
        sigma,  #: Float['b'] | Float[' '] | float,
        network_condition_kwargs: dict,
    ):
        """Denoise coordinates with the preconditioned score model.

        Computes ``c_skip * x + c_out * F(c_in * x, c_noise(sigma))``.

        Args:
            noised_atom_coords: Noisy coordinates; dim 0 is the sample axis.
            sigma: Noise level(s): a Python number or 0-dim tensor (broadcast
                to every sample) or a 1-D tensor with one value per sample.
            network_condition_kwargs: Extra keyword arguments for the score
                model (``s_inputs``, ``s_trunk``, ``feats``, ...).

        Returns:
            The denoised coordinate prediction.
        """
        batch, device = noised_atom_coords.shape[0], noised_atom_coords.device

        if isinstance(sigma, (int, float)):
            sigma = torch.full((batch,), float(sigma), device=device)
        elif isinstance(sigma, torch.Tensor) and sigma.ndim == 0:
            # A scalar tensor is shared by every sample in the batch.
            sigma = sigma.expand(batch)

        padded_sigma = rearrange(sigma, "b -> b 1 1")

        r_update = self.score_model(
            r_noisy=self.c_in(padded_sigma) * noised_atom_coords,
            times=self.c_noise(sigma),
            **network_condition_kwargs,
        )

        denoised_coords = (
            self.c_skip(padded_sigma) * noised_atom_coords
            + self.c_out(padded_sigma) * r_update
        )
        return denoised_coords

    def sample_schedule(self, num_sampling_steps=None):
        """Return the decreasing sigma schedule used for sampling.

        Interpolates linearly between ``sigma_max ** (1/rho)`` and
        ``sigma_min ** (1/rho)``, raises the result to ``rho``, scales by
        ``sigma_data`` and appends a final sigma of 0.

        Args:
            num_sampling_steps: Number of steps; defaults to
                ``self.num_sampling_steps``.

        Returns:
            Float32 tensor of length ``num_sampling_steps + 1``.
        """
        if num_sampling_steps is None:
            num_sampling_steps = self.num_sampling_steps
        inv_rho = 1 / self.rho

        steps = torch.arange(
            num_sampling_steps, device=self.device, dtype=torch.float32
        )
        # Guard the single-step schedule: (n - 1) == 0 would give 0 / 0 = NaN.
        denom = max(num_sampling_steps - 1, 1)
        sigmas = (
            self.sigma_max**inv_rho
            + steps
            / denom
            * (self.sigma_min**inv_rho - self.sigma_max**inv_rho)
        ) ** self.rho

        sigmas = sigmas * self.sigma_data

        sigmas = F.pad(sigmas, (0, 1), value=0.0)  # last step is sigma value of 0.
        return sigmas

    @staticmethod
    def _resolve_steering_args(steering_args):
        """Return a copy of ``steering_args`` with steering/guidance off by default."""
        resolved = {"fk_steering": False, "guidance_update": False}
        if steering_args is not None:
            resolved.update(steering_args)
        return resolved

    @staticmethod
    def _split_sample_ids(num_samples, max_parallel_samples, device=None):
        """Split ``0..num_samples-1`` into chunks of at most ``max_parallel_samples`` ids."""
        return torch.arange(num_samples, device=device).split(max_parallel_samples)

    def sample(
        self,
        atom_mask,
        num_sampling_steps=None,
        multiplicity=1,
        max_parallel_samples=None,
        steering_args=None,
        **network_condition_kwargs,
    ):
        """Generate atom coordinates by running the reverse diffusion loop.

        Args:
            atom_mask: Atom padding mask, repeated ``multiplicity`` times along
                dim 0 (after particle expansion).
            num_sampling_steps: Number of denoising steps; defaults to
                ``self.num_sampling_steps``.
            multiplicity: Samples per input. With ``fk_steering`` enabled it
                is multiplied by ``steering_args["num_particles"]``.
            max_parallel_samples: Maximum number of samples sent through the
                score model at once; defaults to all samples.
            steering_args: Optional dict of steering options. ``None`` (or
                missing ``fk_steering`` / ``guidance_update`` keys) disables
                the corresponding feature.
            **network_condition_kwargs: Forwarded to the score model; must
                contain ``feats`` when steering or guidance is enabled.

        Returns:
            dict with ``sample_atom_coords`` and ``diff_token_repr`` (which is
            never set here and stays ``None``). With ``fk_steering`` the final
            step keeps one particle per group of ``num_particles``.
        """
        steering_args = self._resolve_steering_args(steering_args)
        use_fk = steering_args["fk_steering"]
        use_guidance = steering_args["guidance_update"]

        potentials = get_potentials()
        if use_fk:
            multiplicity = multiplicity * steering_args["num_particles"]
            energy_traj = torch.empty((multiplicity, 0), device=self.device)
            resample_weights = torch.ones(multiplicity, device=self.device).reshape(
                -1, steering_args["num_particles"]
            )
        if use_guidance:
            scaled_guidance_update = torch.zeros(
                (multiplicity, *atom_mask.shape[1:], 3),
                dtype=torch.float32,
                device=self.device,
            )
        if max_parallel_samples is None:
            max_parallel_samples = multiplicity

        num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
        last_step = num_sampling_steps - 1
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        shape = (*atom_mask.shape, 3)

        # Pair every sigma with the next sigma and the gamma of that next step.
        sigmas = self.sample_schedule(num_sampling_steps)
        gammas = torch.where(sigmas > self.gamma_min, self.gamma_0, 0.0)
        sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[1:]))
        if self.training and self.step_scale_random is not None:
            step_scale = np.random.choice(self.step_scale_random)
        else:
            step_scale = self.step_scale

        # atom position is noise at the beginning
        init_sigma = sigmas[0]
        atom_coords = init_sigma * torch.randn(shape, device=self.device)
        token_repr = None
        atom_coords_denoised = None

        # gradually denoise
        for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(sigmas_and_gammas):
            # Re-center and apply a fresh random rigid transform each step.
            random_R, random_tr = compute_random_augmentation(
                multiplicity, device=atom_coords.device, dtype=atom_coords.dtype
            )
            atom_coords = atom_coords - atom_coords.mean(dim=-2, keepdim=True)
            atom_coords = (
                torch.einsum("bmd,bds->bms", atom_coords, random_R) + random_tr
            )
            if atom_coords_denoised is not None:
                atom_coords_denoised -= atom_coords_denoised.mean(dim=-2, keepdim=True)
                atom_coords_denoised = (
                    torch.einsum("bmd,bds->bms", atom_coords_denoised, random_R)
                    + random_tr
                )
            if use_guidance and scaled_guidance_update is not None:
                # A displacement is only rotated, never translated.
                scaled_guidance_update = torch.einsum(
                    "bmd,bds->bms", scaled_guidance_update, random_R
                )

            sigma_tm, sigma_t, gamma = sigma_tm.item(), sigma_t.item(), gamma.item()

            t_hat = sigma_tm * (1 + gamma)
            steering_t = 1.0 - (step_idx / num_sampling_steps)
            noise_var = self.noise_scale**2 * (t_hat**2 - sigma_tm**2)
            eps = sqrt(noise_var) * torch.randn(shape, device=self.device)
            atom_coords_noisy = atom_coords + eps

            # Energies are evaluated, and particles resampled, every
            # fk_resampling_interval noisy steps and always on the last step.
            resample_now = use_fk and (
                (
                    step_idx % steering_args["fk_resampling_interval"] == 0
                    and noise_var > 0
                )
                or step_idx == last_step
            )

            with torch.no_grad():
                # Run the score model in chunks of <= max_parallel_samples.
                atom_coords_denoised = torch.zeros_like(atom_coords_noisy)
                for sample_ids_chunk in self._split_sample_ids(
                    multiplicity, max_parallel_samples, atom_coords_noisy.device
                ):
                    atom_coords_denoised[sample_ids_chunk] = (
                        self.preconditioned_network_forward(
                            atom_coords_noisy[sample_ids_chunk],
                            t_hat,
                            network_condition_kwargs=dict(
                                multiplicity=sample_ids_chunk.numel(),
                                **network_condition_kwargs,
                            ),
                        )
                    )

                if resample_now:
                    # Compute energy of x_0 prediction
                    energy = torch.zeros(multiplicity, device=self.device)
                    for potential in potentials:
                        parameters = potential.compute_parameters(steering_t)
                        if parameters["resampling_weight"] > 0:
                            component_energy = potential.compute(
                                atom_coords_denoised,
                                network_condition_kwargs["feats"],
                                parameters,
                            )
                            energy += parameters["resampling_weight"] * component_energy
                    energy_traj = torch.cat((energy_traj, energy.unsqueeze(1)), dim=1)

                    # Compute log G values: -E on the first evaluation, the
                    # energy decrease since the previous evaluation afterwards.
                    if energy_traj.shape[1] == 1:
                        log_G = -1 * energy
                    else:
                        log_G = energy_traj[:, -2] - energy_traj[:, -1]

                    # Compute ll difference between guided and unguided transition distribution
                    if use_guidance and noise_var > 0:
                        ll_difference = (
                            eps**2 - (eps + scaled_guidance_update) ** 2
                        ).sum(dim=(-1, -2)) / (2 * noise_var)
                    else:
                        ll_difference = torch.zeros_like(energy)

                    # Compute resampling weights (softmax within each particle group)
                    resample_weights = F.softmax(
                        (ll_difference + steering_args["fk_lambda"] * log_G).reshape(
                            -1, steering_args["num_particles"]
                        ),
                        dim=1,
                    )

                # Compute guidance update to x_0 prediction
                if use_guidance and step_idx < last_step:
                    guidance_update = torch.zeros_like(atom_coords_denoised)
                    for guidance_step in range(steering_args["num_gd_steps"]):
                        energy_gradient = torch.zeros_like(atom_coords_denoised)
                        for potential in potentials:
                            parameters = potential.compute_parameters(steering_t)
                            if (
                                parameters["guidance_weight"] > 0
                                and guidance_step % parameters["guidance_interval"]
                                == 0
                            ):
                                energy_gradient += parameters[
                                    "guidance_weight"
                                ] * potential.compute_gradient(
                                    atom_coords_denoised + guidance_update,
                                    network_condition_kwargs["feats"],
                                    parameters,
                                )
                        guidance_update -= energy_gradient
                    atom_coords_denoised += guidance_update
                    # Displacement the guidance induces on the next iterate;
                    # uses the same (possibly randomized) step_scale as the
                    # actual update below.
                    scaled_guidance_update = (
                        guidance_update * -1 * step_scale * (sigma_t - t_hat) / t_hat
                    )

                if resample_now:
                    # Draw particles within each group; on the last step keep
                    # a single particle per group.
                    num_draws = (
                        resample_weights.shape[1] if step_idx < last_step else 1
                    )
                    group_offsets = resample_weights.shape[1] * torch.arange(
                        resample_weights.shape[0], device=resample_weights.device
                    ).unsqueeze(-1)
                    resample_indices = (
                        torch.multinomial(
                            resample_weights, num_draws, replacement=True
                        )
                        + group_offsets
                    ).flatten()

                    atom_coords = atom_coords[resample_indices]
                    atom_coords_noisy = atom_coords_noisy[resample_indices]
                    atom_mask = atom_mask[resample_indices]
                    if atom_coords_denoised is not None:
                        atom_coords_denoised = atom_coords_denoised[resample_indices]
                    energy_traj = energy_traj[resample_indices]
                    if use_guidance:
                        scaled_guidance_update = scaled_guidance_update[
                            resample_indices
                        ]
                    if token_repr is not None:
                        token_repr = token_repr[resample_indices]

            if self.alignment_reverse_diff:
                with torch.autocast("cuda", enabled=False):
                    atom_coords_noisy = weighted_rigid_align(
                        atom_coords_noisy.float(),
                        atom_coords_denoised.float(),
                        atom_mask.float(),
                        atom_mask.float(),
                    )

                atom_coords_noisy = atom_coords_noisy.to(atom_coords_denoised)

            denoised_over_sigma = (atom_coords_noisy - atom_coords_denoised) / t_hat
            atom_coords_next = (
                atom_coords_noisy + step_scale * (sigma_t - t_hat) * denoised_over_sigma
            )

            atom_coords = atom_coords_next

        return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr)

    def loss_weight(self, sigma):
        """Per-sample loss weight ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``."""
        return (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2)

    def noise_distribution(self, batch_size):
        """Draw ``batch_size`` training sigmas ``sigma_data * exp(P_mean + P_std * N(0, 1))``."""
        return (
            self.sigma_data
            * (
                self.P_mean
                + self.P_std * torch.randn((batch_size,), device=self.device)
            ).exp()
        )

    def forward(
        self,
        s_inputs,
        s_trunk,
        feats,
        diffusion_conditioning,
        multiplicity=1,
    ):
        """Run one training diffusion step: noise the targets and denoise them.

        Args:
            s_inputs: Input single representation for the score model.
            s_trunk: Trunk single representation for the score model.
            feats: Feature dict; ``coords`` and ``atom_pad_mask`` are read.
            diffusion_conditioning: Precomputed conditioning for the score model.
            multiplicity: Number of noisy copies per input. ``feats["coords"]``
                is assumed (TODO confirm with callers) to already contain
                ``batch * multiplicity`` rows.

        Returns:
            dict with ``noised_atom_coords``, ``denoised_atom_coords``,
            ``sigmas`` and ``aligned_true_atom_coords`` (the augmented targets).
        """
        # training diffusion step
        batch_size = feats["coords"].shape[0] // multiplicity

        if self.synchronize_sigmas:
            sigmas = self.noise_distribution(batch_size).repeat_interleave(
                multiplicity, 0
            )
        else:
            sigmas = self.noise_distribution(batch_size * multiplicity)
        padded_sigmas = rearrange(sigmas, "b -> b 1 1")

        atom_coords = feats["coords"]

        atom_mask = feats["atom_pad_mask"]
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        atom_coords = center_random_augmentation(
            atom_coords, atom_mask, augmentation=self.coordinate_augmentation
        )

        noise = torch.randn_like(atom_coords)
        noised_atom_coords = atom_coords + padded_sigmas * noise

        denoised_atom_coords = self.preconditioned_network_forward(
            noised_atom_coords,
            sigmas,
            network_condition_kwargs={
                "s_inputs": s_inputs,
                "s_trunk": s_trunk,
                "feats": feats,
                "multiplicity": multiplicity,
                "diffusion_conditioning": diffusion_conditioning,
            },
        )

        return {
            "noised_atom_coords": noised_atom_coords,
            "denoised_atom_coords": denoised_atom_coords,
            "sigmas": sigmas,
            "aligned_true_atom_coords": atom_coords,
        }

    def compute_loss(
        self,
        feats,
        out_dict,
        add_smooth_lddt_loss=True,
        nucleotide_loss_weight=5.0,
        ligand_loss_weight=10.0,
        multiplicity=1,
        filter_by_plddt=0.0,
    ):
        """Compute the diffusion training loss in float32 (autocast disabled).

        The ground truth is rigidly aligned onto the (detached) prediction
        using all resolved atoms. The weighted MSE uses per-atom weights
        ``1 + nucleotide_loss_weight`` for DNA/RNA atoms and
        ``1 + ligand_loss_weight`` for NONPOLYMER atoms (1 otherwise), and is
        scaled by :meth:`loss_weight` of each sample's sigma. Optionally a
        smooth lDDT loss is added.

        Args:
            feats: Feature dict (``atom_resolved_mask``, ``atom_to_token``,
                ``mol_type``, ``coords`` and, when filtering, ``plddt``).
            out_dict: Output of :meth:`forward`.
            add_smooth_lddt_loss: Whether to add the smooth lDDT term.
            nucleotide_loss_weight: Extra MSE weight for DNA/RNA atoms.
            ligand_loss_weight: Extra MSE weight for NONPOLYMER atoms.
            multiplicity: Number of samples per input.
            filter_by_plddt: If > 0, atoms with ``plddt <= filter_by_plddt``
                are dropped from the MSE and lDDT masks (not from the
                alignment mask).

        Returns:
            dict with ``loss`` and ``loss_breakdown`` (``mse_loss``,
            ``smooth_lddt_loss``).
        """
        with torch.autocast("cuda", enabled=False):
            denoised_atom_coords = out_dict["denoised_atom_coords"].float()
            noised_atom_coords = out_dict["noised_atom_coords"].float()
            sigmas = out_dict["sigmas"].float()

            resolved_atom_mask_uni = feats["atom_resolved_mask"].float()

            if filter_by_plddt > 0:
                plddt_mask = feats["plddt"] > filter_by_plddt
                resolved_atom_mask_uni = resolved_atom_mask_uni * plddt_mask.float()

            resolved_atom_mask = resolved_atom_mask_uni.repeat_interleave(
                multiplicity, 0
            )

            # Per-atom chain type obtained by mapping token mol_type to atoms.
            align_weights = noised_atom_coords.new_ones(noised_atom_coords.shape[:2])
            atom_type = (
                torch.bmm(
                    feats["atom_to_token"].float(),
                    feats["mol_type"].unsqueeze(-1).float(),
                )
                .squeeze(-1)
                .long()
            )
            atom_type_mult = atom_type.repeat_interleave(multiplicity, 0)

            align_weights = (
                align_weights
                * (
                    1
                    + nucleotide_loss_weight
                    * (
                        torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).float()
                        + torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).float()
                    )
                    + ligand_loss_weight
                    * torch.eq(
                        atom_type_mult, const.chain_type_ids["NONPOLYMER"]
                    ).float()
                ).float()
            )

            atom_coords = out_dict["aligned_true_atom_coords"].float()
            atom_coords_aligned_ground_truth = weighted_rigid_align(
                atom_coords.detach(),
                denoised_atom_coords.detach(),
                align_weights.detach(),
                mask=feats["atom_resolved_mask"]
                .float()
                .repeat_interleave(multiplicity, 0)
                .detach(),
            )

            # Cast back
            atom_coords_aligned_ground_truth = atom_coords_aligned_ground_truth.to(
                denoised_atom_coords
            )

            # weighted MSE loss of denoised atom positions
            mse_loss = (
                (denoised_atom_coords - atom_coords_aligned_ground_truth) ** 2
            ).sum(dim=-1)
            mse_loss = torch.sum(
                mse_loss * align_weights * resolved_atom_mask, dim=-1
            ) / (torch.sum(3 * align_weights * resolved_atom_mask, dim=-1) + 1e-5)

            # weight by sigma factor
            loss_weights = self.loss_weight(sigmas)
            mse_loss = (mse_loss * loss_weights).mean()

            total_loss = mse_loss

            # proposed auxiliary smooth lddt loss
            lddt_loss = self.zero
            if add_smooth_lddt_loss:
                lddt_loss = smooth_lddt_loss(
                    denoised_atom_coords,
                    feats["coords"],
                    torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
                    + torch.eq(atom_type, const.chain_type_ids["RNA"]).float(),
                    coords_mask=resolved_atom_mask_uni,
                    multiplicity=multiplicity,
                )

            total_loss = total_loss + lddt_loss

            loss_breakdown = {
                "mse_loss": mse_loss,
                "smooth_lddt_loss": lddt_loss,
            }

        return {"loss": total_loss, "loss_breakdown": loss_breakdown}
|