boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,844 @@
1
+ # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
2
+
3
+ from __future__ import annotations
4
+
5
+ from math import sqrt
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ from torch import nn
11
+ from torch.nn import Module
12
+
13
+ import boltz.model.layers.initialize as init
14
+ from boltz.data import const
15
+ from boltz.model.loss.diffusion import (
16
+ smooth_lddt_loss,
17
+ weighted_rigid_align,
18
+ )
19
+ from boltz.model.modules.utils import center_random_augmentation
20
+ from boltz.model.modules.encoders import (
21
+ AtomAttentionDecoder,
22
+ AtomAttentionEncoder,
23
+ FourierEmbedding,
24
+ PairwiseConditioning,
25
+ SingleConditioning,
26
+ )
27
+ from boltz.model.modules.transformers import (
28
+ ConditionedTransitionBlock,
29
+ DiffusionTransformer,
30
+ )
31
+ from boltz.model.modules.utils import (
32
+ LinearNoBias,
33
+ compute_random_augmentation,
34
+ center_random_augmentation,
35
+ default,
36
+ log,
37
+ )
38
+ from boltz.model.potentials.potentials import get_potentials
39
+
40
+
41
class DiffusionModule(Module):
    """Score network of the atom diffusion model.

    Conditions the token features on the noise level, lifts noisy atom
    coordinates to token activations with an atom attention encoder, runs a
    token-level diffusion transformer, and maps the result back to per-atom
    coordinate updates with an atom attention decoder.
    """

    def __init__(
        self,
        token_s: int,
        token_z: int,
        atom_s: int,
        atom_z: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        sigma_data: int = 16,
        dim_fourier: int = 256,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        atom_feature_dim: int = 128,
        conditioning_transition_layers: int = 2,
        activation_checkpointing: bool = False,
        offload_to_cpu: bool = False,
        **kwargs,
    ) -> None:
        """Build the sub-networks of the score model.

        Parameters
        ----------
        token_s : int
            Width of the single (per-token) representation.
        token_z : int
            Width of the pair representation.
        atom_s : int
            Width of the per-atom single representation.
        atom_z : int
            Width of the per-atom pair representation.
        atoms_per_window_queries : int, optional
            Atoms per attention window on the query side, by default 32.
        atoms_per_window_keys : int, optional
            Atoms per attention window on the key side, by default 128.
        sigma_data : int, optional
            Data standard deviation, stored and forwarded to the single
            conditioner, by default 16.
        dim_fourier : int, optional
            Size of the Fourier noise-level embedding, by default 256.
        atom_encoder_depth : int, optional
            Number of atom encoder layers, by default 3.
        atom_encoder_heads : int, optional
            Attention heads in the atom encoder, by default 4.
        token_transformer_depth : int, optional
            Number of token transformer layers, by default 24.
        token_transformer_heads : int, optional
            Attention heads in the token transformer, by default 8.
        atom_decoder_depth : int, optional
            Number of atom decoder layers, by default 3.
        atom_decoder_heads : int, optional
            Attention heads in the atom decoder, by default 4.
        atom_feature_dim : int, optional
            Atom feature dimension passed to the encoder, by default 128.
        conditioning_transition_layers : int, optional
            Transition layers in the single and pair conditioners, by default 2.
        activation_checkpointing : bool, optional
            Enable activation checkpointing in the encoder, token transformer
            and decoder, by default False.
        offload_to_cpu : bool, optional
            Forwarded to the token transformer to offload activations to CPU,
            by default False.
        **kwargs
            Accepted and ignored.

        """
        super().__init__()

        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.sigma_data = sigma_data

        # Token activations inside the score model are twice the single width.
        token_width = 2 * token_s

        # NOTE: sub-modules are built in a fixed order; keep it, because the
        # order determines how the random weight initialisation is consumed.

        # Noise-level conditioning of the single and pair representations.
        self.single_conditioner = SingleConditioning(
            sigma_data=sigma_data,
            token_s=token_s,
            dim_fourier=dim_fourier,
            num_transitions=conditioning_transition_layers,
        )
        self.pairwise_conditioner = PairwiseConditioning(
            token_z=token_z,
            dim_token_rel_pos_feats=token_z,
            num_transitions=conditioning_transition_layers,
        )

        # Atoms -> tokens.
        self.atom_attention_encoder = AtomAttentionEncoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            token_z=token_z,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_feature_dim=atom_feature_dim,
            atom_encoder_depth=atom_encoder_depth,
            atom_encoder_heads=atom_encoder_heads,
            structure_prediction=True,
            activation_checkpointing=activation_checkpointing,
        )

        # Projection of the conditioned single repr. added to the token
        # activations; its weight is initialised with init.final_init_.
        projection = LinearNoBias(token_width, token_width)
        self.s_to_a_linear = nn.Sequential(nn.LayerNorm(token_width), projection)
        init.final_init_(self.s_to_a_linear[1].weight)

        # Token-level transformer.
        self.token_transformer = DiffusionTransformer(
            dim=token_width,
            dim_single_cond=token_width,
            dim_pairwise=token_z,
            depth=token_transformer_depth,
            heads=token_transformer_heads,
            activation_checkpointing=activation_checkpointing,
            offload_to_cpu=offload_to_cpu,
        )

        self.a_norm = nn.LayerNorm(token_width)

        # Tokens -> atoms.
        self.atom_attention_decoder = AtomAttentionDecoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            atom_decoder_depth=atom_decoder_depth,
            atom_decoder_heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        s_inputs,
        s_trunk,
        z_trunk,
        r_noisy,
        times,
        relative_position_encoding,
        feats,
        multiplicity=1,
        model_cache=None,
    ):
        """Predict the coordinate update for a batch of noisy structures.

        Parameters
        ----------
        s_inputs, s_trunk : torch.Tensor
            Single representations; repeated ``multiplicity`` times along
            dim 0 before the single conditioning.
        z_trunk : torch.Tensor
            Pair representation from the trunk.
        r_noisy : torch.Tensor
            Noisy atom coordinates.
        times : torch.Tensor
            Noise-level input of the single conditioner.
        relative_position_encoding : torch.Tensor
            Relative position features for the pair conditioner.
        feats : dict
            Input features; ``token_pad_mask`` is read here.
        multiplicity : int, optional
            Number of samples per input.
        model_cache : dict, optional
            Inference cache. When it is non-empty the pair conditioning is
            skipped and ``z`` is passed on as None.

        Returns
        -------
        dict
            ``r_update`` (decoder output) and ``token_a`` (detached token
            activations after the final layer norm).

        """
        # Single conditioning, one copy of the token features per sample.
        s, _ = self.single_conditioner(
            times=times,
            s_trunk=s_trunk.repeat_interleave(multiplicity, 0),
            s_inputs=s_inputs.repeat_interleave(multiplicity, 0),
        )

        # A warm (non-empty) cache already holds everything derived from z.
        cache_is_warm = bool(model_cache)
        if cache_is_warm:
            z = None
        else:
            z = self.pairwise_conditioner(
                z_trunk=z_trunk, token_rel_pos_feats=relative_position_encoding
            )

        # Atom encoder; also yields the skip tensors consumed by the decoder.
        a, q_skip, c_skip, p_skip, to_keys = self.atom_attention_encoder(
            feats=feats,
            s_trunk=s_trunk,
            z=z,
            r=r_noisy,
            multiplicity=multiplicity,
            model_cache=model_cache,
        )

        # Token-level self-attention conditioned on s (and z).
        a = a + self.s_to_a_linear(s)
        token_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        transformed = self.token_transformer(
            a,
            mask=token_mask.float(),
            s=s,
            z=z,  # z is expanded to multiplicity only after the bias is computed
            multiplicity=multiplicity,
            model_cache=model_cache,
        )
        a = self.a_norm(transformed)

        # Broadcast token activations back to atoms.
        r_update = self.atom_attention_decoder(
            a=a,
            q=q_skip,
            c=c_skip,
            p=p_skip,
            feats=feats,
            multiplicity=multiplicity,
            to_keys=to_keys,
            model_cache=model_cache,
        )

        return {"r_update": r_update, "token_a": a.detach()}
230
+
231
+
232
class OutTokenFeatUpdate(Module):
    """Noise-conditioned update of an accumulated token representation.

    Used to fold the per-step token activations of the sampler into a
    running representation for the confidence model.
    """

    def __init__(
        self,
        sigma_data: float,
        token_s=384,
        dim_fourier=256,
    ):
        """Create the normalisations, time embedding and transition block.

        Parameters
        ----------
        sigma_data : float
            Standard deviation of the data distribution (stored only).
        token_s : int, optional
            Single token width; the token activations are ``2 * token_s``
            wide, by default 384.
        dim_fourier : int, optional
            Size of the Fourier time embedding, by default 256.

        """
        super().__init__()
        self.sigma_data = sigma_data

        width = 2 * token_s
        # Construction order is kept fixed (it fixes the init RNG sequence).
        self.norm_next = nn.LayerNorm(width)
        self.fourier_embed = FourierEmbedding(dim_fourier)
        self.norm_fourier = nn.LayerNorm(dim_fourier)
        self.transition_block = ConditionedTransitionBlock(
            width, width + dim_fourier
        )

    def forward(
        self,
        times,
        acc_a,
        next_a,
    ):
        """Return ``acc_a`` plus a transition of the normalised ``next_a``.

        The transition is conditioned on ``acc_a`` concatenated with the
        normalised Fourier embedding of ``times``, broadcast over dim 1.
        """
        normed_next = self.norm_next(next_a)
        num_tokens = normed_next.shape[1]

        time_feats = self.norm_fourier(self.fourier_embed(times))
        time_feats = time_feats.unsqueeze(1).expand(-1, num_tokens, -1)

        conditioning = torch.cat([acc_a, time_feats], dim=-1)
        return acc_a + self.transition_block(normed_next, conditioning)
282
+
283
+
284
class AtomDiffusion(Module):
    """Atom diffusion module.

    Wraps a ``DiffusionModule`` score network with noise-level
    preconditioning (``c_skip``/``c_out``/``c_in``/``c_noise``), a training
    noise distribution and loss, and a stochastic sampler with optional
    FK-steering particle resampling and potential-based guidance.
    """

    def __init__(
        self,
        score_model_args,
        num_sampling_steps=5,
        sigma_min=0.0004,
        sigma_max=160.0,
        sigma_data=16.0,
        rho=7,
        P_mean=-1.2,
        P_std=1.5,
        gamma_0=0.8,
        gamma_min=1.0,
        noise_scale=1.003,
        step_scale=1.5,
        coordinate_augmentation=True,
        compile_score=False,
        alignment_reverse_diff=False,
        synchronize_sigmas=False,
        use_inference_model_cache=False,
        accumulate_token_repr=False,
        **kwargs,
    ):
        """Initialize the atom diffusion module.

        Parameters
        ----------
        score_model_args : dict
            The arguments for the score model; ``token_s`` is always read and
            ``dim_fourier`` is read when ``accumulate_token_repr`` is set.
        num_sampling_steps : int, optional
            The number of sampling steps, by default 5.
        sigma_min : float, optional
            The minimum sigma value, by default 0.0004.
        sigma_max : float, optional
            The maximum sigma value, by default 160.0.
        sigma_data : float, optional
            The standard deviation of the data distribution, by default 16.0.
        rho : int, optional
            The rho value of the sampling schedule, by default 7.
        P_mean : float, optional
            The mean of the log-normal training noise distribution, by default -1.2.
        P_std : float, optional
            The standard deviation of the log-normal training noise
            distribution, by default 1.5.
        gamma_0 : float, optional
            The gamma value, by default 0.8.
        gamma_min : float, optional
            Sigmas must exceed this value for gamma_0 to apply, by default 1.0.
        noise_scale : float, optional
            The noise scale, by default 1.003.
        step_scale : float, optional
            The step scale, by default 1.5.
        coordinate_augmentation : bool, optional
            Whether to use coordinate augmentation, by default True.
        compile_score : bool, optional
            Whether to compile the score model, by default False.
        alignment_reverse_diff : bool, optional
            Whether to rigidly align the noisy coordinates to the denoised
            prediction at every sampling step, by default False.
        synchronize_sigmas : bool, optional
            Whether all samples of one input share the training sigma, by
            default False.
        use_inference_model_cache : bool, optional
            Whether to use the inference model cache, by default False.
        accumulate_token_repr : bool, optional
            Whether to accumulate the token representation, by default False.

        """
        super().__init__()
        self.score_model = DiffusionModule(
            **score_model_args,
        )
        if compile_score:
            self.score_model = torch.compile(
                self.score_model, dynamic=False, fullgraph=False
            )

        # parameters
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.rho = rho
        self.P_mean = P_mean
        self.P_std = P_std
        self.num_sampling_steps = num_sampling_steps
        self.gamma_0 = gamma_0
        self.gamma_min = gamma_min
        self.noise_scale = noise_scale
        self.step_scale = step_scale
        self.coordinate_augmentation = coordinate_augmentation
        self.alignment_reverse_diff = alignment_reverse_diff
        self.synchronize_sigmas = synchronize_sigmas
        self.use_inference_model_cache = use_inference_model_cache

        self.accumulate_token_repr = accumulate_token_repr
        self.token_s = score_model_args["token_s"]
        if self.accumulate_token_repr:
            self.out_token_feat_update = OutTokenFeatUpdate(
                sigma_data=sigma_data,
                token_s=score_model_args["token_s"],
                dim_fourier=score_model_args["dim_fourier"],
            )

        self.register_buffer("zero", torch.tensor(0.0), persistent=False)

    @property
    def device(self):
        """Device of the score model's first parameter."""
        return next(self.score_model.parameters()).device

    def c_skip(self, sigma):
        """Weight of the noisy input in the denoised estimate."""
        return (self.sigma_data**2) / (sigma**2 + self.sigma_data**2)

    def c_out(self, sigma):
        """Weight of the network output in the denoised estimate."""
        return sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2)

    def c_in(self, sigma):
        """Scale applied to the noisy coordinates before the network."""
        return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)

    def c_noise(self, sigma):
        """Noise-level input to the network: ``log(sigma / sigma_data) / 4``."""
        return log(sigma / self.sigma_data) * 0.25

    def preconditioned_network_forward(
        self,
        noised_atom_coords,
        sigma,
        network_condition_kwargs: dict,
        training: bool = True,
    ):
        """Denoise ``noised_atom_coords`` at noise level ``sigma``.

        Parameters
        ----------
        noised_atom_coords : torch.Tensor
            Noisy coordinates; dim 0 is the batch.
        sigma : float or torch.Tensor
            Noise level; a python number is broadcast to one value per batch
            element, a tensor is expected to have one value per element.
        network_condition_kwargs : dict
            Extra keyword arguments for the score model.
        training : bool, optional
            Unused; kept for interface compatibility.

        Returns
        -------
        tuple
            ``(denoised_coords, token_a)``.

        """
        batch, device = noised_atom_coords.shape[0], noised_atom_coords.device

        if isinstance(sigma, (int, float)):
            sigma = torch.full((batch,), float(sigma), device=device)

        padded_sigma = rearrange(sigma, "b -> b 1 1")

        net_out = self.score_model(
            r_noisy=self.c_in(padded_sigma) * noised_atom_coords,
            times=self.c_noise(sigma),
            **network_condition_kwargs,
        )

        denoised_coords = (
            self.c_skip(padded_sigma) * noised_atom_coords
            + self.c_out(padded_sigma) * net_out["r_update"]
        )
        return denoised_coords, net_out["token_a"]

    def sample_schedule(self, num_sampling_steps=None):
        """Return the sampling noise levels, largest first, followed by 0.

        Interpolates between ``sigma_max`` and ``sigma_min`` in
        ``sigma ** (1 / rho)`` space, scales by ``sigma_data`` and appends a
        final 0, giving ``num_sampling_steps + 1`` values.
        """
        if num_sampling_steps is None:
            num_sampling_steps = self.num_sampling_steps
        inv_rho = 1 / self.rho

        steps = torch.arange(
            num_sampling_steps, device=self.device, dtype=torch.float32
        )
        # Guard the denominator: with a single step the original 0/0 produced
        # NaN; now the schedule is [sigma_max * sigma_data, 0].
        denom = max(num_sampling_steps - 1, 1)
        sigmas = (
            self.sigma_max**inv_rho
            + steps / denom * (self.sigma_min**inv_rho - self.sigma_max**inv_rho)
        ) ** self.rho

        sigmas = sigmas * self.sigma_data

        sigmas = F.pad(sigmas, (0, 1), value=0.0)  # last step is sigma value of 0.
        return sigmas

    @staticmethod
    def _num_sample_chunks(multiplicity, max_parallel_samples):
        """Return how many chunks keep every chunk at most ``max_parallel_samples``.

        ``None`` means no limit (one chunk). ``torch.chunk`` with
        ``ceil(multiplicity / max_parallel_samples)`` chunks yields chunks of
        at most ``max_parallel_samples`` elements.
        """
        if max_parallel_samples is None:
            return 1
        if max_parallel_samples <= 0:
            raise ValueError(
                f"max_parallel_samples must be positive, got {max_parallel_samples}"
            )
        # Integer ceil division.
        return max(1, -(-multiplicity // max_parallel_samples))

    def sample(
        self,
        atom_mask,
        num_sampling_steps=None,
        multiplicity=1,
        max_parallel_samples=None,
        train_accumulate_token_repr=False,
        steering_args=None,
        **network_condition_kwargs,
    ):
        """Sample atom coordinates by iterative denoising.

        Each step randomly rotates/translates the current state, adds noise
        up to ``t_hat = sigma * (1 + gamma)``, denoises with the score model
        and moves ``step_scale`` times toward the denoised estimate.

        Parameters
        ----------
        atom_mask : torch.Tensor
            Atom padding mask; repeated ``multiplicity`` times along dim 0.
        num_sampling_steps : int, optional
            Number of denoising steps; defaults to ``self.num_sampling_steps``.
        multiplicity : int, optional
            Samples per input (multiplied by ``num_particles`` when FK
            steering is enabled).
        max_parallel_samples : int, optional
            Maximum number of samples sent through the score model in one
            call; ``None`` sends all of them at once.
        train_accumulate_token_repr : bool, optional
            Whether gradients are enabled while accumulating the token
            representation (only used with ``accumulate_token_repr``).
        steering_args : dict, optional
            Steering configuration; keys read here: ``fk_steering``,
            ``guidance_update``, ``num_particles``, ``fk_resampling_interval``,
            ``fk_lambda`` and ``num_gd_steps``.
        **network_condition_kwargs
            Forwarded to the score model; must contain ``feats`` with a
            ``token_index`` entry.

        Returns
        -------
        dict
            ``sample_atom_coords`` and ``diff_token_repr`` (None unless
            ``accumulate_token_repr`` is set).

        """
        use_fk_steering = steering_args is not None and bool(
            steering_args["fk_steering"]
        )
        use_guidance = steering_args is not None and bool(
            steering_args["guidance_update"]
        )
        if use_fk_steering or use_guidance:
            potentials = get_potentials()
        if use_fk_steering:
            multiplicity = multiplicity * steering_args["num_particles"]
            energy_traj = torch.empty((multiplicity, 0), device=self.device)
            resample_weights = torch.ones(multiplicity, device=self.device).reshape(
                -1, steering_args["num_particles"]
            )
        if use_guidance:
            scaled_guidance_update = torch.zeros(
                (multiplicity, *atom_mask.shape[1:], 3),
                dtype=torch.float32,
                device=self.device,
            )

        if num_sampling_steps is None:
            num_sampling_steps = self.num_sampling_steps
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        shape = (*atom_mask.shape, 3)
        token_repr_shape = (
            multiplicity,
            network_condition_kwargs["feats"]["token_index"].shape[1],
            2 * self.token_s,
        )

        # Pair every sigma with the next sigma and the next gamma.
        sigmas = self.sample_schedule(num_sampling_steps)
        gammas = torch.where(sigmas > self.gamma_min, self.gamma_0, 0.0)
        sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[1:]))

        # atom position is noise at the beginning
        init_sigma = sigmas[0]
        atom_coords = init_sigma * torch.randn(shape, device=self.device)
        atom_coords_denoised = None
        model_cache = {} if self.use_inference_model_cache else None

        # Score-model calls are split so that no call sees more than
        # max_parallel_samples samples (the sample count only changes after
        # the final FK resampling, when the loop is already over).
        sample_ids_chunks = torch.arange(multiplicity, device=self.device).chunk(
            self._num_sample_chunks(multiplicity, max_parallel_samples)
        )

        token_repr = None
        token_a = None

        # gradually denoise
        for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(sigmas_and_gammas):
            # Random rigid augmentation of the state and of everything that
            # lives in the same frame.
            random_R, random_tr = compute_random_augmentation(
                multiplicity, device=atom_coords.device, dtype=atom_coords.dtype
            )
            atom_coords = atom_coords - atom_coords.mean(dim=-2, keepdims=True)
            atom_coords = (
                torch.einsum("bmd,bds->bms", atom_coords, random_R) + random_tr
            )
            if atom_coords_denoised is not None:
                atom_coords_denoised -= atom_coords_denoised.mean(dim=-2, keepdims=True)
                atom_coords_denoised = (
                    torch.einsum("bmd,bds->bms", atom_coords_denoised, random_R)
                    + random_tr
                )
            if use_guidance:
                # A displacement is only rotated, never translated.
                scaled_guidance_update = torch.einsum(
                    "bmd,bds->bms", scaled_guidance_update, random_R
                )

            sigma_tm, sigma_t, gamma = sigma_tm.item(), sigma_t.item(), gamma.item()

            t_hat = sigma_tm * (1 + gamma)
            steering_t = 1.0 - (step_idx / num_sampling_steps)
            noise_var = self.noise_scale**2 * (t_hat**2 - sigma_tm**2)
            eps = sqrt(noise_var) * torch.randn(shape, device=self.device)
            atom_coords_noisy = atom_coords + eps

            # FK resampling happens every fk_resampling_interval-th step that
            # adds noise, and always on the final step.
            fk_resample_step = use_fk_steering and (
                (
                    step_idx % steering_args["fk_resampling_interval"] == 0
                    and noise_var > 0
                )
                or step_idx == num_sampling_steps - 1
            )

            with torch.no_grad():
                atom_coords_denoised = torch.zeros_like(atom_coords_noisy)
                token_a = torch.zeros(token_repr_shape).to(atom_coords_noisy)

                for sample_ids_chunk in sample_ids_chunks:
                    atom_coords_denoised_chunk, token_a_chunk = (
                        self.preconditioned_network_forward(
                            atom_coords_noisy[sample_ids_chunk],
                            t_hat,
                            training=False,
                            network_condition_kwargs=dict(
                                multiplicity=sample_ids_chunk.numel(),
                                model_cache=model_cache,
                                **network_condition_kwargs,
                            ),
                        )
                    )
                    atom_coords_denoised[sample_ids_chunk] = atom_coords_denoised_chunk
                    token_a[sample_ids_chunk] = token_a_chunk

                if fk_resample_step:
                    # Compute energy of x_0 prediction
                    energy = torch.zeros(multiplicity, device=self.device)
                    for potential in potentials:
                        parameters = potential.compute_parameters(steering_t)
                        if parameters["resampling_weight"] > 0:
                            component_energy = potential.compute(
                                atom_coords_denoised,
                                network_condition_kwargs["feats"],
                                parameters,
                            )
                            energy += parameters["resampling_weight"] * component_energy
                    energy_traj = torch.cat((energy_traj, energy.unsqueeze(1)), dim=1)

                    # Compute log G values
                    if step_idx == 0:
                        log_G = -1 * energy
                    else:
                        log_G = energy_traj[:, -2] - energy_traj[:, -1]

                    # Compute ll difference between guided and unguided
                    # transition distribution
                    if use_guidance and noise_var > 0:
                        ll_difference = (
                            eps**2 - (eps + scaled_guidance_update) ** 2
                        ).sum(dim=(-1, -2)) / (2 * noise_var)
                    else:
                        ll_difference = torch.zeros_like(energy)

                    # Compute resampling weights (per group of particles)
                    resample_weights = F.softmax(
                        (ll_difference + steering_args["fk_lambda"] * log_G).reshape(
                            -1, steering_args["num_particles"]
                        ),
                        dim=1,
                    )

                # Compute guidance update to x_0 prediction
                if use_guidance and step_idx < num_sampling_steps - 1:
                    guidance_update = torch.zeros_like(atom_coords_denoised)
                    for guidance_step in range(steering_args["num_gd_steps"]):
                        energy_gradient = torch.zeros_like(atom_coords_denoised)
                        for potential in potentials:
                            parameters = potential.compute_parameters(steering_t)
                            if (
                                parameters["guidance_weight"] > 0
                                and guidance_step % parameters["guidance_interval"]
                                == 0
                            ):
                                energy_gradient += parameters[
                                    "guidance_weight"
                                ] * potential.compute_gradient(
                                    atom_coords_denoised + guidance_update,
                                    network_condition_kwargs["feats"],
                                    parameters,
                                )
                        guidance_update -= energy_gradient
                    atom_coords_denoised += guidance_update
                    scaled_guidance_update = (
                        guidance_update
                        * -1
                        * self.step_scale
                        * (sigma_t - t_hat)
                        / t_hat
                    )

                if fk_resample_step:
                    # Draw particles within each group; on the final step keep
                    # a single particle per group.
                    resample_indices = (
                        torch.multinomial(
                            resample_weights,
                            resample_weights.shape[1]
                            if step_idx < num_sampling_steps - 1
                            else 1,
                            replacement=True,
                        )
                        + resample_weights.shape[1]
                        * torch.arange(
                            resample_weights.shape[0], device=resample_weights.device
                        ).unsqueeze(-1)
                    ).flatten()

                    atom_coords = atom_coords[resample_indices]
                    atom_coords_noisy = atom_coords_noisy[resample_indices]
                    atom_mask = atom_mask[resample_indices]
                    if atom_coords_denoised is not None:
                        atom_coords_denoised = atom_coords_denoised[resample_indices]
                    energy_traj = energy_traj[resample_indices]
                    if use_guidance:
                        scaled_guidance_update = scaled_guidance_update[
                            resample_indices
                        ]
                    if token_repr is not None:
                        token_repr = token_repr[resample_indices]
                    if token_a is not None:
                        token_a = token_a[resample_indices]

            if self.accumulate_token_repr:
                if token_repr is None:
                    token_repr = torch.zeros_like(token_a)

                with torch.set_grad_enabled(train_accumulate_token_repr):
                    sigma = torch.full(
                        (atom_coords_denoised.shape[0],),
                        t_hat,
                        device=atom_coords_denoised.device,
                    )
                    token_repr = self.out_token_feat_update(
                        times=self.c_noise(sigma), acc_a=token_repr, next_a=token_a
                    )

            if self.alignment_reverse_diff:
                with torch.autocast("cuda", enabled=False):
                    atom_coords_noisy = weighted_rigid_align(
                        atom_coords_noisy.float(),
                        atom_coords_denoised.float(),
                        atom_mask.float(),
                        atom_mask.float(),
                    )

                atom_coords_noisy = atom_coords_noisy.to(atom_coords_denoised)

            denoised_over_sigma = (atom_coords_noisy - atom_coords_denoised) / t_hat
            atom_coords_next = (
                atom_coords_noisy
                + self.step_scale * (sigma_t - t_hat) * denoised_over_sigma
            )

            atom_coords = atom_coords_next

        return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr)

    def loss_weight(self, sigma):
        """Per-sample weight of the MSE loss at noise level ``sigma``."""
        return (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2)

    def noise_distribution(self, batch_size):
        """Draw training noise levels ``sigma_data * exp(P_mean + P_std * N(0, 1))``."""
        return (
            self.sigma_data
            * (
                self.P_mean
                + self.P_std * torch.randn((batch_size,), device=self.device)
            ).exp()
        )

    def forward(
        self,
        s_inputs,
        s_trunk,
        z_trunk,
        relative_position_encoding,
        feats,
        multiplicity=1,
    ):
        """Run one training denoising step.

        Draws noise levels, noises the centred/augmented ground-truth
        coordinates and denoises them with the preconditioned score model.

        Note: ``feats["coords"]`` is overwritten in place with the reshaped,
        repeated coordinates; ``compute_loss`` reads that entry back.

        Returns
        -------
        dict
            ``noised_atom_coords``, ``denoised_atom_coords``, ``sigmas`` and
            ``aligned_true_atom_coords``.

        """
        # training diffusion step
        batch_size = feats["coords"].shape[0]

        if self.synchronize_sigmas:
            sigmas = self.noise_distribution(batch_size).repeat_interleave(
                multiplicity, 0
            )
        else:
            sigmas = self.noise_distribution(batch_size * multiplicity)
        padded_sigmas = rearrange(sigmas, "b -> b 1 1")

        atom_coords = feats["coords"]
        B, N, L = atom_coords.shape[0:3]
        atom_coords = atom_coords.reshape(B * N, L, 3)
        atom_coords = atom_coords.repeat_interleave(multiplicity // N, 0)
        feats["coords"] = atom_coords

        atom_mask = feats["atom_pad_mask"]
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        atom_coords = center_random_augmentation(
            atom_coords, atom_mask, augmentation=self.coordinate_augmentation
        )

        noise = torch.randn_like(atom_coords)
        noised_atom_coords = atom_coords + padded_sigmas * noise

        denoised_atom_coords, _ = self.preconditioned_network_forward(
            noised_atom_coords,
            sigmas,
            training=True,
            network_condition_kwargs=dict(
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                relative_position_encoding=relative_position_encoding,
                feats=feats,
                multiplicity=multiplicity,
            ),
        )

        return dict(
            noised_atom_coords=noised_atom_coords,
            denoised_atom_coords=denoised_atom_coords,
            sigmas=sigmas,
            aligned_true_atom_coords=atom_coords,
        )

    def compute_loss(
        self,
        feats,
        out_dict,
        add_smooth_lddt_loss=True,
        nucleotide_loss_weight=5.0,
        ligand_loss_weight=10.0,
        multiplicity=1,
    ):
        """Weighted, sigma-scaled MSE plus optional smooth-LDDT loss.

        DNA/RNA atoms get ``1 + nucleotide_loss_weight`` and non-polymer atoms
        ``1 + ligand_loss_weight`` in both the rigid alignment and the MSE.

        Returns
        -------
        dict
            ``loss`` (total) and ``loss_breakdown`` (``mse_loss``,
            ``smooth_lddt_loss``).

        """
        denoised_atom_coords = out_dict["denoised_atom_coords"]
        noised_atom_coords = out_dict["noised_atom_coords"]
        sigmas = out_dict["sigmas"]

        resolved_atom_mask = feats["atom_resolved_mask"]
        resolved_atom_mask = resolved_atom_mask.repeat_interleave(multiplicity, 0)

        align_weights = noised_atom_coords.new_ones(noised_atom_coords.shape[:2])
        atom_type = (
            torch.bmm(
                feats["atom_to_token"].float(), feats["mol_type"].unsqueeze(-1).float()
            )
            .squeeze(-1)
            .long()
        )
        atom_type_mult = atom_type.repeat_interleave(multiplicity, 0)

        align_weights = align_weights * (
            1
            + nucleotide_loss_weight
            * (
                torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).float()
                + torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).float()
            )
            + ligand_loss_weight
            * torch.eq(atom_type_mult, const.chain_type_ids["NONPOLYMER"]).float()
        )

        with torch.no_grad(), torch.autocast("cuda", enabled=False):
            atom_coords = out_dict["aligned_true_atom_coords"]
            atom_coords_aligned_ground_truth = weighted_rigid_align(
                atom_coords.detach().float(),
                denoised_atom_coords.detach().float(),
                align_weights.detach().float(),
                mask=resolved_atom_mask.detach().float(),
            )

        # Cast back
        atom_coords_aligned_ground_truth = atom_coords_aligned_ground_truth.to(
            denoised_atom_coords
        )

        # weighted MSE loss of denoised atom positions
        mse_loss = ((denoised_atom_coords - atom_coords_aligned_ground_truth) ** 2).sum(
            dim=-1
        )
        norm = torch.sum(3 * align_weights * resolved_atom_mask, dim=-1)
        # A sample with no resolved atoms has norm == 0 (and a zero numerator);
        # divide by 1 there so it contributes 0 instead of NaN to the mean.
        norm = torch.where(norm == 0, torch.ones_like(norm), norm)
        mse_loss = torch.sum(mse_loss * align_weights * resolved_atom_mask, dim=-1) / norm

        # weight by sigma factor
        loss_weights = self.loss_weight(sigmas)
        mse_loss = (mse_loss * loss_weights).mean()

        total_loss = mse_loss

        # proposed auxiliary smooth lddt loss
        lddt_loss = self.zero
        if add_smooth_lddt_loss:
            lddt_loss = smooth_lddt_loss(
                denoised_atom_coords,
                feats["coords"],
                torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
                + torch.eq(atom_type, const.chain_type_ids["RNA"]).float(),
                coords_mask=feats["atom_resolved_mask"],
                multiplicity=multiplicity,
            )

            total_loss = total_loss + lddt_loss

        loss_breakdown = dict(
            mse_loss=mse_loss,
            smooth_lddt_loss=lddt_loss,
        )

        return dict(loss=total_loss, loss_breakdown=loss_breakdown)