boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,677 @@
1
+ # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
2
+
3
+ from __future__ import annotations
4
+
5
+ from math import sqrt
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F # noqa: N812
10
+ from einops import rearrange
11
+ from torch import nn
12
+ from torch.nn import Module
13
+
14
+ import boltz.model.layers.initialize as init
15
+ from boltz.data import const
16
+ from boltz.model.loss.diffusionv2 import (
17
+ smooth_lddt_loss,
18
+ weighted_rigid_align,
19
+ )
20
+ from boltz.model.modules.encodersv2 import (
21
+ AtomAttentionDecoder,
22
+ AtomAttentionEncoder,
23
+ SingleConditioning,
24
+ )
25
+ from boltz.model.modules.transformersv2 import (
26
+ DiffusionTransformer,
27
+ )
28
+ from boltz.model.modules.utils import (
29
+ LinearNoBias,
30
+ center_random_augmentation,
31
+ compute_random_augmentation,
32
+ default,
33
+ log,
34
+ )
35
+ from boltz.model.potentials.potentials import get_potentials
36
+
37
+
38
class DiffusionModule(Module):
    """Score network that maps noisy atom coordinates to a coordinate update.

    The module chains four stages: a single-representation conditioner, a
    windowed atom-attention encoder that aggregates atoms into tokens, a
    token-level transformer, and a windowed atom-attention decoder that
    produces the per-atom update.
    """

    def __init__(
        self,
        token_s: int,
        atom_s: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        sigma_data: int = 16,
        dim_fourier: int = 256,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        conditioning_transition_layers: int = 2,
        activation_checkpointing: bool = False,
        transformer_post_ln: bool = False,
    ) -> None:
        """Build the conditioner, encoder, token transformer and decoder.

        Parameters
        ----------
        token_s : int
            Width of the single (token) representation; the token stack
            itself runs at ``2 * token_s``.
        atom_s : int
            Width of the atom representation.
        atoms_per_window_queries, atoms_per_window_keys : int
            Window sizes handed to the atom attention encoder/decoder.
        sigma_data : int
            Data scale forwarded to the single conditioner.
        dim_fourier : int
            Fourier embedding width forwarded to the single conditioner.
        atom_encoder_depth, atom_encoder_heads : int
            Depth / heads of the atom attention encoder.
        token_transformer_depth, token_transformer_heads : int
            Depth / heads of the token transformer.
        atom_decoder_depth, atom_decoder_heads : int
            Depth / heads of the atom attention decoder.
        conditioning_transition_layers : int
            Number of transition layers in the single conditioner.
        activation_checkpointing : bool
            Forwarded to every sub-module; also enables checkpointing of the
            conditioner call in ``forward`` while training.
        transformer_post_ln : bool
            Only forwarded to the atom attention encoder; the token
            transformer and the decoder do not receive it here.
        """
        super().__init__()

        token_dim = 2 * token_s

        self.sigma_data = sigma_data
        self.activation_checkpointing = activation_checkpointing
        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys

        # NOTE: sub-modules are created in a fixed order so that parameter
        # registration (and any RNG use during init) stays stable.

        # conditioning of the single representation on the noise level
        self.single_conditioner = SingleConditioning(
            sigma_data=sigma_data,
            token_s=token_s,
            dim_fourier=dim_fourier,
            num_transitions=conditioning_transition_layers,
        )

        # atoms -> tokens
        self.atom_attention_encoder = AtomAttentionEncoder(
            atom_s=atom_s,
            token_s=token_s,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_encoder_depth=atom_encoder_depth,
            atom_encoder_heads=atom_encoder_heads,
            structure_prediction=True,
            activation_checkpointing=activation_checkpointing,
            transformer_post_layer_norm=transformer_post_ln,
        )

        # projection of the conditioned single representation that is added
        # onto the token activations
        s_to_a_norm = nn.LayerNorm(token_dim)
        s_to_a_proj = LinearNoBias(token_dim, token_dim)
        self.s_to_a_linear = nn.Sequential(s_to_a_norm, s_to_a_proj)
        init.final_init_(self.s_to_a_linear[1].weight)

        # token-level transformer (post layer norm is intentionally not
        # forwarded here)
        self.token_transformer = DiffusionTransformer(
            dim=token_dim,
            dim_single_cond=token_dim,
            depth=token_transformer_depth,
            heads=token_transformer_heads,
            activation_checkpointing=activation_checkpointing,
        )

        # always a LayerNorm, independent of ``transformer_post_ln``
        self.a_norm = nn.LayerNorm(token_dim)

        # tokens -> atoms (post layer norm is intentionally not forwarded)
        self.atom_attention_decoder = AtomAttentionDecoder(
            atom_s=atom_s,
            token_s=token_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            atom_decoder_depth=atom_decoder_depth,
            atom_decoder_heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        s_inputs,  # Float['b n ts']
        s_trunk,  # Float['b n ts']
        r_noisy,  # Float['bm m 3']
        times,  # Float['bm 1 1']
        feats,
        diffusion_conditioning,
        multiplicity=1,
    ):
        """Predict the coordinate update for a batch of noisy structures.

        Parameters
        ----------
        s_inputs, s_trunk
            Single representations; both are repeated ``multiplicity`` times
            along dim 0 before conditioning.
        r_noisy
            Noisy atom coordinates handed to the atom encoder.
        times
            Noise-level embedding input for the single conditioner.
        feats
            Feature dictionary; ``"token_pad_mask"`` is read here and the
            whole dictionary is forwarded to the encoder and decoder.
        diffusion_conditioning
            Dictionary providing ``"q"``, ``"c"``, ``"atom_enc_bias"``,
            ``"to_keys"``, ``"token_trans_bias"`` and ``"atom_dec_bias"``.
        multiplicity : int
            Number of samples per batch element.

        Returns
        -------
        The decoder output ``r_update``.
        """
        conditioner_inputs = (
            times,
            s_trunk.repeat_interleave(multiplicity, 0),
            s_inputs.repeat_interleave(multiplicity, 0),
        )
        # the second conditioner output (normalised Fourier embedding) is
        # not used by this module
        if self.activation_checkpointing and self.training:
            s, _ = torch.utils.checkpoint.checkpoint(
                self.single_conditioner, *conditioner_inputs
            )
        else:
            s, _ = self.single_conditioner(*conditioner_inputs)

        # sequence-local atom attention, aggregated to coarse-grained tokens
        encoder_out = self.atom_attention_encoder(
            feats=feats,
            q=diffusion_conditioning["q"].float(),
            c=diffusion_conditioning["c"].float(),
            atom_enc_bias=diffusion_conditioning["atom_enc_bias"].float(),
            to_keys=diffusion_conditioning["to_keys"],
            r=r_noisy,
            multiplicity=multiplicity,
        )
        a, q_skip, c_skip, to_keys = encoder_out

        # full self-attention on token level
        a = a + self.s_to_a_linear(s)

        token_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        # the bias is passed un-expanded; multiplicity is handled downstream
        token_bias = diffusion_conditioning["token_trans_bias"].float()
        a = self.token_transformer(
            a,
            mask=token_mask.float(),
            s=s,
            bias=token_bias,
            multiplicity=multiplicity,
        )
        a = self.a_norm(a)

        # broadcast token activations back to atoms and decode the update
        return self.atom_attention_decoder(
            a=a,
            q=q_skip,
            c=c_skip,
            atom_dec_bias=diffusion_conditioning["atom_dec_bias"].float(),
            feats=feats,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )
177
+
178
+
179
class AtomDiffusion(Module):
    """Diffusion wrapper around :class:`DiffusionModule`.

    Provides the sigma-dependent preconditioning of the score network, the
    sampling schedule, the iterative sampler (optionally with potential-based
    steering), the training-time noising step and the training loss.
    """

    def __init__(
        self,
        score_model_args,
        num_sampling_steps: int = 5,  # number of sampling steps
        sigma_min: float = 0.0004,  # min noise level
        sigma_max: float = 160.0,  # max noise level
        sigma_data: float = 16.0,  # standard deviation of data distribution
        rho: float = 7,  # controls the sampling schedule
        P_mean: float = -1.2,  # mean of log-normal distribution from which noise is drawn for training
        P_std: float = 1.5,  # standard deviation of log-normal distribution from which noise is drawn for training
        gamma_0: float = 0.8,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
        step_scale_random: list = None,
        coordinate_augmentation: bool = True,
        coordinate_augmentation_inference=None,
        compile_score: bool = False,
        alignment_reverse_diff: bool = False,
        synchronize_sigmas: bool = False,
    ):
        """Create the score model and store the diffusion hyper-parameters.

        Parameters
        ----------
        score_model_args : dict
            Keyword arguments for :class:`DiffusionModule`; must contain
            ``"token_s"``.
        step_scale_random : list, optional
            If given, ``sample`` draws the step scale from this list while the
            module is in training mode.
        coordinate_augmentation_inference : bool, optional
            Falls back to ``coordinate_augmentation`` when ``None``.
        compile_score : bool
            Wrap the score model in ``torch.compile``.
        alignment_reverse_diff : bool
            Rigidly align the noisy coordinates onto the denoised prediction
            before each reverse-diffusion step in ``sample``.
        synchronize_sigmas : bool
            Share one noise level across the ``multiplicity`` copies of each
            batch element during training.

        The remaining arguments are stored unchanged; see the inline comments
        on the signature.
        """
        super().__init__()
        self.score_model = DiffusionModule(
            **score_model_args,
        )
        if compile_score:
            self.score_model = torch.compile(
                self.score_model, dynamic=False, fullgraph=False
            )

        # parameters
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.rho = rho
        self.P_mean = P_mean
        self.P_std = P_std
        self.num_sampling_steps = num_sampling_steps
        self.gamma_0 = gamma_0
        self.gamma_min = gamma_min
        self.noise_scale = noise_scale
        self.step_scale = step_scale
        self.step_scale_random = step_scale_random
        self.coordinate_augmentation = coordinate_augmentation
        self.coordinate_augmentation_inference = (
            coordinate_augmentation_inference
            if coordinate_augmentation_inference is not None
            else coordinate_augmentation
        )
        self.alignment_reverse_diff = alignment_reverse_diff
        self.synchronize_sigmas = synchronize_sigmas

        self.token_s = score_model_args["token_s"]
        # non-persistent zero scalar used as the default smooth-lDDT loss
        self.register_buffer("zero", torch.tensor(0.0), persistent=False)

    @property
    def device(self):
        """Device of the first score-model parameter."""
        return next(self.score_model.parameters()).device

    def c_skip(self, sigma):
        """Return ``sigma_data^2 / (sigma^2 + sigma_data^2)``."""
        return (self.sigma_data**2) / (sigma**2 + self.sigma_data**2)

    def c_out(self, sigma):
        """Return ``sigma * sigma_data / sqrt(sigma_data^2 + sigma^2)``."""
        return sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2)

    def c_in(self, sigma):
        """Return ``1 / sqrt(sigma^2 + sigma_data^2)``."""
        return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)

    def c_noise(self, sigma):
        """Return ``0.25 * log(sigma / sigma_data)`` (via the project ``log`` helper)."""
        return log(sigma / self.sigma_data) * 0.25

    def preconditioned_network_forward(
        self,
        noised_atom_coords,  #: Float['b m 3'],
        sigma,  #: Float['b'] | Float[' '] | float,
        network_condition_kwargs: dict,
    ):
        """Run the score model with input/output preconditioning.

        Parameters
        ----------
        noised_atom_coords
            Noisy coordinates; the leading dimension is the batch.
        sigma
            Noise level: a Python number, a 0-dim tensor (both broadcast to
            the whole batch) or a 1-D tensor with one value per batch entry.
        network_condition_kwargs : dict
            Extra keyword arguments forwarded to the score model.

        Returns
        -------
        ``c_skip(sigma) * noised + c_out(sigma) * score_model(...)``.
        """
        batch, device = noised_atom_coords.shape[0], noised_atom_coords.device

        # Accept every scalar form advertised in the signature comment; the
        # original only handled Python floats and crashed on ints and 0-dim
        # tensors when reshaping.
        if not torch.is_tensor(sigma):
            sigma = torch.full((batch,), float(sigma), device=device)
        elif sigma.ndim == 0:
            sigma = sigma.expand(batch)

        padded_sigma = rearrange(sigma, "b -> b 1 1")

        r_update = self.score_model(
            r_noisy=self.c_in(padded_sigma) * noised_atom_coords,
            times=self.c_noise(sigma),
            **network_condition_kwargs,
        )

        denoised_coords = (
            self.c_skip(padded_sigma) * noised_atom_coords
            + self.c_out(padded_sigma) * r_update
        )
        return denoised_coords

    def sample_schedule(self, num_sampling_steps=None):
        """Return the noise levels used by :meth:`sample`.

        The schedule interpolates between ``sigma_max`` and ``sigma_min`` in
        ``1 / rho`` space, is scaled by ``sigma_data`` and gets a trailing
        zero, so the result has ``num_sampling_steps + 1`` entries.
        """
        num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
        inv_rho = 1 / self.rho

        steps = torch.arange(
            num_sampling_steps, device=self.device, dtype=torch.float32
        )
        # A single-step schedule would otherwise compute 0 / 0 and yield NaN;
        # with the guard it starts (and stays) at sigma_max.
        denom = max(num_sampling_steps - 1, 1)
        sigmas = (
            self.sigma_max**inv_rho
            + steps / denom * (self.sigma_min**inv_rho - self.sigma_max**inv_rho)
        ) ** self.rho

        sigmas = sigmas * self.sigma_data

        sigmas = F.pad(sigmas, (0, 1), value=0.0)  # last step is sigma value of 0.
        return sigmas

    def sample(
        self,
        atom_mask,
        num_sampling_steps=None,
        multiplicity=1,
        max_parallel_samples=None,
        steering_args=None,
        **network_condition_kwargs,
    ):
        """Iteratively denoise random coordinates into sampled structures.

        Parameters
        ----------
        atom_mask
            Atom padding mask; repeated ``multiplicity`` times along dim 0.
        num_sampling_steps : int, optional
            Defaults to ``self.num_sampling_steps``.
        multiplicity : int
            Samples per input; multiplied by ``steering_args["num_particles"]``
            when FK steering is enabled.
        max_parallel_samples : int, optional
            Upper bound on the number of samples pushed through the score
            model at once. Defaults to all of them.
        steering_args : dict, optional
            Steering configuration. Keys read here: ``"fk_steering"``,
            ``"guidance_update"`` and, when those are enabled,
            ``"num_particles"``, ``"fk_resampling_interval"``, ``"fk_lambda"``
            and ``"num_gd_steps"``. ``None`` disables steering entirely.
        **network_condition_kwargs
            Forwarded to the score model; ``"feats"`` is also handed to the
            potentials when steering is active.

        Returns
        -------
        dict
            ``sample_atom_coords`` (final coordinates) and
            ``diff_token_repr`` (always ``None`` in this implementation).
        """
        # The signature advertises ``None`` as default, so treat it as
        # "no steering" instead of failing on the first subscript.
        if steering_args is None:
            steering_args = {"fk_steering": False, "guidance_update": False}
        fk_steering = steering_args["fk_steering"]
        guidance_enabled = steering_args["guidance_update"]

        potentials = get_potentials()
        if fk_steering:
            multiplicity = multiplicity * steering_args["num_particles"]
            energy_traj = torch.empty((multiplicity, 0), device=self.device)
            resample_weights = torch.ones(multiplicity, device=self.device).reshape(
                -1, steering_args["num_particles"]
            )
        if guidance_enabled:
            scaled_guidance_update = torch.zeros(
                (multiplicity, *atom_mask.shape[1:], 3),
                dtype=torch.float32,
                device=self.device,
            )
        if max_parallel_samples is None:
            max_parallel_samples = multiplicity

        num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        shape = (*atom_mask.shape, 3)

        # noise levels, paired as (current sigma, next sigma, gamma of next)
        sigmas = self.sample_schedule(num_sampling_steps)
        gammas = torch.where(sigmas > self.gamma_min, self.gamma_0, 0.0)
        sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[1:]))
        if self.training and self.step_scale_random is not None:
            step_scale = np.random.choice(self.step_scale_random)
        else:
            step_scale = self.step_scale

        # atom position is noise at the beginning
        init_sigma = sigmas[0]
        atom_coords = init_sigma * torch.randn(shape, device=self.device)
        token_repr = None
        atom_coords_denoised = None

        # Split the samples into ceil(multiplicity / max_parallel_samples)
        # chunks so that no chunk exceeds ``max_parallel_samples``. The former
        # ``multiplicity % max_parallel_samples + 1`` produced a single
        # oversized chunk whenever the sizes divided evenly. The partition is
        # loop-invariant, so it is built once.
        num_chunks = -(-multiplicity // max_parallel_samples)
        sample_ids_chunks = torch.arange(
            multiplicity, device=atom_coords.device
        ).chunk(num_chunks)

        # gradually denoise
        for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(sigmas_and_gammas):
            random_R, random_tr = compute_random_augmentation(
                multiplicity, device=atom_coords.device, dtype=atom_coords.dtype
            )
            atom_coords = atom_coords - atom_coords.mean(dim=-2, keepdims=True)
            atom_coords = (
                torch.einsum("bmd,bds->bms", atom_coords, random_R) + random_tr
            )
            if atom_coords_denoised is not None:
                atom_coords_denoised -= atom_coords_denoised.mean(dim=-2, keepdims=True)
                atom_coords_denoised = (
                    torch.einsum("bmd,bds->bms", atom_coords_denoised, random_R)
                    + random_tr
                )
            if guidance_enabled and scaled_guidance_update is not None:
                # keep the previous guidance step in the rotated frame
                scaled_guidance_update = torch.einsum(
                    "bmd,bds->bms", scaled_guidance_update, random_R
                )

            sigma_tm, sigma_t, gamma = sigma_tm.item(), sigma_t.item(), gamma.item()

            t_hat = sigma_tm * (1 + gamma)
            steering_t = 1.0 - (step_idx / num_sampling_steps)
            noise_var = self.noise_scale**2 * (t_hat**2 - sigma_tm**2)
            eps = sqrt(noise_var) * torch.randn(shape, device=self.device)
            atom_coords_noisy = atom_coords + eps

            is_last_step = step_idx == num_sampling_steps - 1
            # evaluated once; none of its inputs change within a step
            do_resample = fk_steering and (
                (
                    step_idx % steering_args["fk_resampling_interval"] == 0
                    and noise_var > 0
                )
                or is_last_step
            )

            with torch.no_grad():
                atom_coords_denoised = torch.zeros_like(atom_coords_noisy)

                for sample_ids_chunk in sample_ids_chunks:
                    atom_coords_denoised_chunk = self.preconditioned_network_forward(
                        atom_coords_noisy[sample_ids_chunk],
                        t_hat,
                        network_condition_kwargs=dict(
                            multiplicity=sample_ids_chunk.numel(),
                            **network_condition_kwargs,
                        ),
                    )
                    atom_coords_denoised[sample_ids_chunk] = atom_coords_denoised_chunk

                if do_resample:
                    # Compute energy of x_0 prediction
                    energy = torch.zeros(multiplicity, device=self.device)
                    for potential in potentials:
                        parameters = potential.compute_parameters(steering_t)
                        if parameters["resampling_weight"] > 0:
                            component_energy = potential.compute(
                                atom_coords_denoised,
                                network_condition_kwargs["feats"],
                                parameters,
                            )
                            energy += parameters["resampling_weight"] * component_energy
                    energy_traj = torch.cat((energy_traj, energy.unsqueeze(1)), dim=1)

                    # Compute log G values
                    if step_idx == 0:
                        log_G = -1 * energy
                    else:
                        log_G = energy_traj[:, -2] - energy_traj[:, -1]

                    # Compute ll difference between guided and unguided transition distribution
                    if guidance_enabled and noise_var > 0:
                        ll_difference = (
                            eps**2 - (eps + scaled_guidance_update) ** 2
                        ).sum(dim=(-1, -2)) / (2 * noise_var)
                    else:
                        ll_difference = torch.zeros_like(energy)

                    # Compute resampling weights
                    resample_weights = F.softmax(
                        (ll_difference + steering_args["fk_lambda"] * log_G).reshape(
                            -1, steering_args["num_particles"]
                        ),
                        dim=1,
                    )

                # Compute guidance update to x_0 prediction
                if guidance_enabled and not is_last_step:
                    guidance_update = torch.zeros_like(atom_coords_denoised)
                    for guidance_step in range(steering_args["num_gd_steps"]):
                        energy_gradient = torch.zeros_like(atom_coords_denoised)
                        for potential in potentials:
                            parameters = potential.compute_parameters(steering_t)
                            if (
                                parameters["guidance_weight"] > 0
                                and guidance_step % parameters["guidance_interval"] == 0
                            ):
                                energy_gradient += parameters[
                                    "guidance_weight"
                                ] * potential.compute_gradient(
                                    atom_coords_denoised + guidance_update,
                                    network_condition_kwargs["feats"],
                                    parameters,
                                )
                        guidance_update -= energy_gradient
                    atom_coords_denoised += guidance_update
                    # Shift that the guidance induces on the next coordinates.
                    # It must use the same (possibly randomised) step scale as
                    # the update below, not ``self.step_scale``.
                    scaled_guidance_update = (
                        guidance_update * -1 * step_scale * (sigma_t - t_hat) / t_hat
                    )

                if do_resample:
                    particles_per_group = resample_weights.shape[1]
                    # keep one particle per group on the final step
                    num_draws = 1 if is_last_step else particles_per_group
                    group_offsets = particles_per_group * torch.arange(
                        resample_weights.shape[0], device=resample_weights.device
                    ).unsqueeze(-1)
                    resample_indices = (
                        torch.multinomial(
                            resample_weights, num_draws, replacement=True
                        )
                        + group_offsets
                    ).flatten()

                    atom_coords = atom_coords[resample_indices]
                    atom_coords_noisy = atom_coords_noisy[resample_indices]
                    atom_mask = atom_mask[resample_indices]
                    atom_coords_denoised = atom_coords_denoised[resample_indices]
                    energy_traj = energy_traj[resample_indices]
                    if guidance_enabled:
                        scaled_guidance_update = scaled_guidance_update[
                            resample_indices
                        ]
                    if token_repr is not None:
                        token_repr = token_repr[resample_indices]

            if self.alignment_reverse_diff:
                with torch.autocast("cuda", enabled=False):
                    atom_coords_noisy = weighted_rigid_align(
                        atom_coords_noisy.float(),
                        atom_coords_denoised.float(),
                        atom_mask.float(),
                        atom_mask.float(),
                    )

                atom_coords_noisy = atom_coords_noisy.to(atom_coords_denoised)

            denoised_over_sigma = (atom_coords_noisy - atom_coords_denoised) / t_hat
            atom_coords = (
                atom_coords_noisy + step_scale * (sigma_t - t_hat) * denoised_over_sigma
            )

        return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr)

    def loss_weight(self, sigma):
        """Return ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``."""
        return (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2)

    def noise_distribution(self, batch_size):
        """Draw ``batch_size`` training noise levels.

        Each value is ``sigma_data * exp(P_mean + P_std * z)`` with ``z``
        standard normal.
        """
        return (
            self.sigma_data
            * (
                self.P_mean
                + self.P_std * torch.randn((batch_size,), device=self.device)
            ).exp()
        )

    def forward(
        self,
        s_inputs,
        s_trunk,
        feats,
        diffusion_conditioning,
        multiplicity=1,
    ):
        """Run one training diffusion step.

        Reads ``feats["coords"]`` and ``feats["atom_pad_mask"]``, augments and
        noises the coordinates, and denoises them with the score model.

        Returns
        -------
        dict
            ``noised_atom_coords``, ``denoised_atom_coords``, ``sigmas`` and
            ``aligned_true_atom_coords`` (the augmented clean coordinates).
        """
        # training diffusion step
        # assumes feats["coords"] is already expanded by multiplicity along
        # dim 0 (hence the division) - TODO confirm against the data pipeline
        batch_size = feats["coords"].shape[0] // multiplicity

        if self.synchronize_sigmas:
            sigmas = self.noise_distribution(batch_size).repeat_interleave(
                multiplicity, 0
            )
        else:
            sigmas = self.noise_distribution(batch_size * multiplicity)
        padded_sigmas = rearrange(sigmas, "b -> b 1 1")

        atom_coords = feats["coords"]

        atom_mask = feats["atom_pad_mask"]
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        atom_coords = center_random_augmentation(
            atom_coords, atom_mask, augmentation=self.coordinate_augmentation
        )

        noise = torch.randn_like(atom_coords)
        noised_atom_coords = atom_coords + padded_sigmas * noise

        denoised_atom_coords = self.preconditioned_network_forward(
            noised_atom_coords,
            sigmas,
            network_condition_kwargs={
                "s_inputs": s_inputs,
                "s_trunk": s_trunk,
                "feats": feats,
                "multiplicity": multiplicity,
                "diffusion_conditioning": diffusion_conditioning,
            },
        )

        return {
            "noised_atom_coords": noised_atom_coords,
            "denoised_atom_coords": denoised_atom_coords,
            "sigmas": sigmas,
            "aligned_true_atom_coords": atom_coords,
        }

    def compute_loss(
        self,
        feats,
        out_dict,
        add_smooth_lddt_loss=True,
        nucleotide_loss_weight=5.0,
        ligand_loss_weight=10.0,
        multiplicity=1,
        filter_by_plddt=0.0,
    ):
        """Compute the diffusion training loss from :meth:`forward` outputs.

        Parameters
        ----------
        feats : dict
            Reads ``"atom_resolved_mask"``, ``"atom_to_token"``,
            ``"mol_type"``, ``"coords"`` and, when ``filter_by_plddt > 0``,
            ``"plddt"``.
        out_dict : dict
            Output of :meth:`forward`.
        add_smooth_lddt_loss : bool
            Add the auxiliary smooth lDDT term.
        nucleotide_loss_weight, ligand_loss_weight : float
            Extra per-atom weight for DNA/RNA and non-polymer atoms.
        multiplicity : int
            Number of diffusion samples per batch element.
        filter_by_plddt : float
            If positive, atoms whose ``plddt`` is not above this threshold are
            dropped from the loss mask (but not from the alignment mask).

        Returns
        -------
        dict
            ``loss`` (total) and ``loss_breakdown`` with ``mse_loss`` and
            ``smooth_lddt_loss``.
        """
        # NOTE(review): the whole computation is kept inside the
        # autocast-disabled context so it runs on the explicit float32 casts
        # below - confirm against upstream if mixed precision matters.
        with torch.autocast("cuda", enabled=False):
            denoised_atom_coords = out_dict["denoised_atom_coords"].float()
            noised_atom_coords = out_dict["noised_atom_coords"].float()
            sigmas = out_dict["sigmas"].float()

            resolved_atom_mask_uni = feats["atom_resolved_mask"].float()

            if filter_by_plddt > 0:
                plddt_mask = feats["plddt"] > filter_by_plddt
                resolved_atom_mask_uni = resolved_atom_mask_uni * plddt_mask.float()

            resolved_atom_mask = resolved_atom_mask_uni.repeat_interleave(
                multiplicity, 0
            )

            # per-atom molecule type, obtained by mapping the token types
            # through the atom-to-token assignment
            align_weights = noised_atom_coords.new_ones(noised_atom_coords.shape[:2])
            atom_type = (
                torch.bmm(
                    feats["atom_to_token"].float(),
                    feats["mol_type"].unsqueeze(-1).float(),
                )
                .squeeze(-1)
                .long()
            )
            atom_type_mult = atom_type.repeat_interleave(multiplicity, 0)

            is_nucleotide = (
                torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).float()
                + torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).float()
            )
            is_ligand = torch.eq(
                atom_type_mult, const.chain_type_ids["NONPOLYMER"]
            ).float()
            align_weights = (
                align_weights
                * (
                    1
                    + nucleotide_loss_weight * is_nucleotide
                    + ligand_loss_weight * is_ligand
                ).float()
            )

            # align the ground truth onto the prediction (no gradients flow
            # through the alignment inputs)
            atom_coords = out_dict["aligned_true_atom_coords"].float()
            atom_coords_aligned_ground_truth = weighted_rigid_align(
                atom_coords.detach(),
                denoised_atom_coords.detach(),
                align_weights.detach(),
                mask=feats["atom_resolved_mask"]
                .float()
                .repeat_interleave(multiplicity, 0)
                .detach(),
            )

            # Cast back
            atom_coords_aligned_ground_truth = atom_coords_aligned_ground_truth.to(
                denoised_atom_coords
            )

            # weighted MSE loss of denoised atom positions
            mse_loss = (
                (denoised_atom_coords - atom_coords_aligned_ground_truth) ** 2
            ).sum(dim=-1)
            mse_loss = torch.sum(
                mse_loss * align_weights * resolved_atom_mask, dim=-1
            ) / (torch.sum(3 * align_weights * resolved_atom_mask, dim=-1) + 1e-5)

            # weight by sigma factor
            loss_weights = self.loss_weight(sigmas)
            mse_loss = (mse_loss * loss_weights).mean()

            total_loss = mse_loss

            # proposed auxiliary smooth lddt loss
            lddt_loss = self.zero
            if add_smooth_lddt_loss:
                lddt_loss = smooth_lddt_loss(
                    denoised_atom_coords,
                    feats["coords"],
                    torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
                    + torch.eq(atom_type, const.chain_type_ids["RNA"]).float(),
                    coords_mask=resolved_atom_mask_uni,
                    multiplicity=multiplicity,
                )

                total_loss = total_loss + lddt_loss

            loss_breakdown = {
                "mse_loss": mse_loss,
                "smooth_lddt_loss": lddt_loss,
            }

        return {"loss": total_loss, "loss_breakdown": loss_breakdown}