boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
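(The hunk below is the content of boltz/model/models/boltz1.py, the only entry above with 1,286 added lines; the diffs of the remaining files are not reproduced in this section.)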
@@ -0,0 +1,1286 @@
+ import gc
+ import random
+ from typing import Any, Optional
+
+ import torch
+ import torch._dynamo
+ from pytorch_lightning import LightningModule
+ from torch import Tensor, nn
+ from torchmetrics import MeanMetric
+
+ import boltz.model.layers.initialize as init
+ from boltz.data import const
+ from boltz.data.feature.symmetry import (
+     minimum_lddt_symmetry_coords,
+     minimum_symmetry_coords,
+ )
+ from boltz.model.loss.confidence import confidence_loss
+ from boltz.model.loss.distogram import distogram_loss
+ from boltz.model.loss.validation import (
+     compute_pae_mae,
+     compute_pde_mae,
+     compute_plddt_mae,
+     factored_lddt_loss,
+     factored_token_lddt_dist_loss,
+     weighted_minimum_rmsd,
+ )
+ from boltz.model.modules.confidence import ConfidenceModule
+ from boltz.model.modules.diffusion import AtomDiffusion
+ from boltz.model.modules.encoders import RelativePositionEncoder
+ from boltz.model.modules.trunk import (
+     DistogramModule,
+     InputEmbedder,
+     MSAModule,
+     PairformerModule,
+ )
+ from boltz.model.modules.utils import ExponentialMovingAverage
+ from boltz.model.optim.scheduler import AlphaFoldLRScheduler
+
+
+ class Boltz1(LightningModule):
+     """Boltz1 model."""
+
+     def __init__(  # noqa: PLR0915, C901, PLR0912
+         self,
+         atom_s: int,
+         atom_z: int,
+         token_s: int,
+         token_z: int,
+         num_bins: int,
+         training_args: dict[str, Any],
+         validation_args: dict[str, Any],
+         embedder_args: dict[str, Any],
+         msa_args: dict[str, Any],
+         pairformer_args: dict[str, Any],
+         score_model_args: dict[str, Any],
+         diffusion_process_args: dict[str, Any],
+         diffusion_loss_args: dict[str, Any],
+         confidence_model_args: dict[str, Any],
+         atom_feature_dim: int = 128,
+         confidence_prediction: bool = False,
+         confidence_imitate_trunk: bool = False,
+         alpha_pae: float = 0.0,
+         structure_prediction_training: bool = True,
+         atoms_per_window_queries: int = 32,
+         atoms_per_window_keys: int = 128,
+         compile_pairformer: bool = False,
+         compile_structure: bool = False,
+         compile_confidence: bool = False,
+         nucleotide_rmsd_weight: float = 5.0,
+         ligand_rmsd_weight: float = 10.0,
+         no_msa: bool = False,
+         no_atom_encoder: bool = False,
+         ema: bool = False,
+         ema_decay: float = 0.999,
+         min_dist: float = 2.0,
+         max_dist: float = 22.0,
+         predict_args: Optional[dict[str, Any]] = None,
+         steering_args: Optional[dict[str, Any]] = None,
+         use_kernels: bool = False,
+     ) -> None:
+         super().__init__()
+
+         self.save_hyperparameters()
+
+         self.lddt = nn.ModuleDict()
+         self.disto_lddt = nn.ModuleDict()
+         self.complex_lddt = nn.ModuleDict()
+         if confidence_prediction:
+             self.top1_lddt = nn.ModuleDict()
+             self.iplddt_top1_lddt = nn.ModuleDict()
+             self.ipde_top1_lddt = nn.ModuleDict()
+             self.pde_top1_lddt = nn.ModuleDict()
+             self.ptm_top1_lddt = nn.ModuleDict()
+             self.iptm_top1_lddt = nn.ModuleDict()
+             self.ligand_iptm_top1_lddt = nn.ModuleDict()
+             self.protein_iptm_top1_lddt = nn.ModuleDict()
+             self.avg_lddt = nn.ModuleDict()
+             self.plddt_mae = nn.ModuleDict()
+             self.pde_mae = nn.ModuleDict()
+             self.pae_mae = nn.ModuleDict()
+         for m in const.out_types + ["pocket_ligand_protein"]:
+             self.lddt[m] = MeanMetric()
+             self.disto_lddt[m] = MeanMetric()
+             self.complex_lddt[m] = MeanMetric()
+             if confidence_prediction:
+                 self.top1_lddt[m] = MeanMetric()
+                 self.iplddt_top1_lddt[m] = MeanMetric()
+                 self.ipde_top1_lddt[m] = MeanMetric()
+                 self.pde_top1_lddt[m] = MeanMetric()
+                 self.ptm_top1_lddt[m] = MeanMetric()
+                 self.iptm_top1_lddt[m] = MeanMetric()
+                 self.ligand_iptm_top1_lddt[m] = MeanMetric()
+                 self.protein_iptm_top1_lddt[m] = MeanMetric()
+                 self.avg_lddt[m] = MeanMetric()
+                 self.pde_mae[m] = MeanMetric()
+                 self.pae_mae[m] = MeanMetric()
+         for m in const.out_single_types:
+             if confidence_prediction:
+                 self.plddt_mae[m] = MeanMetric()
+         self.rmsd = MeanMetric()
+         self.best_rmsd = MeanMetric()
+
+         self.train_confidence_loss_logger = MeanMetric()
+         self.train_confidence_loss_dict_logger = nn.ModuleDict()
+         for m in [
+             "plddt_loss",
+             "resolved_loss",
+             "pde_loss",
+             "pae_loss",
+         ]:
+             self.train_confidence_loss_dict_logger[m] = MeanMetric()
+
+         self.ema = None
+         self.use_ema = ema
+         self.ema_decay = ema_decay
+
+         self.training_args = training_args
+         self.validation_args = validation_args
+         self.diffusion_loss_args = diffusion_loss_args
+         self.predict_args = predict_args
+         self.steering_args = steering_args
+
+         self.use_kernels = use_kernels
+
+         self.nucleotide_rmsd_weight = nucleotide_rmsd_weight
+         self.ligand_rmsd_weight = ligand_rmsd_weight
+
+         self.num_bins = num_bins
+         self.min_dist = min_dist
+         self.max_dist = max_dist
+         self.is_pairformer_compiled = False
+
+         # Input projections
+         s_input_dim = (
+             token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
+         )
+         self.s_init = nn.Linear(s_input_dim, token_s, bias=False)
+         self.z_init_1 = nn.Linear(s_input_dim, token_z, bias=False)
+         self.z_init_2 = nn.Linear(s_input_dim, token_z, bias=False)
+
+         # Input embeddings
+         full_embedder_args = {
+             "atom_s": atom_s,
+             "atom_z": atom_z,
+             "token_s": token_s,
+             "token_z": token_z,
+             "atoms_per_window_queries": atoms_per_window_queries,
+             "atoms_per_window_keys": atoms_per_window_keys,
+             "atom_feature_dim": atom_feature_dim,
+             "no_atom_encoder": no_atom_encoder,
+             **embedder_args,
+         }
+         self.input_embedder = InputEmbedder(**full_embedder_args)
+         self.rel_pos = RelativePositionEncoder(token_z)
+         self.token_bonds = nn.Linear(1, token_z, bias=False)
+
+         # Normalization layers
+         self.s_norm = nn.LayerNorm(token_s)
+         self.z_norm = nn.LayerNorm(token_z)
+
+         # Recycling projections
+         self.s_recycle = nn.Linear(token_s, token_s, bias=False)
+         self.z_recycle = nn.Linear(token_z, token_z, bias=False)
+         init.gating_init_(self.s_recycle.weight)
+         init.gating_init_(self.z_recycle.weight)
+
+         # Pairwise stack
+         self.no_msa = no_msa
+         if not no_msa:
+             self.msa_module = MSAModule(
+                 token_z=token_z,
+                 s_input_dim=s_input_dim,
+                 **msa_args,
+             )
+         self.pairformer_module = PairformerModule(token_s, token_z, **pairformer_args)
+         if compile_pairformer:
+             # Big models hit the default cache limit (8)
+             self.is_pairformer_compiled = True
+             torch._dynamo.config.cache_size_limit = 512
+             torch._dynamo.config.accumulated_cache_size_limit = 512
+             self.pairformer_module = torch.compile(
+                 self.pairformer_module,
+                 dynamic=False,
+                 fullgraph=False,
+             )
+
+         # Output modules
+         use_accumulate_token_repr = (
+             confidence_prediction
+             and "use_s_diffusion" in confidence_model_args
+             and confidence_model_args["use_s_diffusion"]
+         )
+         self.structure_module = AtomDiffusion(
+             score_model_args={
+                 "token_z": token_z,
+                 "token_s": token_s,
+                 "atom_z": atom_z,
+                 "atom_s": atom_s,
+                 "atoms_per_window_queries": atoms_per_window_queries,
+                 "atoms_per_window_keys": atoms_per_window_keys,
+                 "atom_feature_dim": atom_feature_dim,
+                 **score_model_args,
+             },
+             compile_score=compile_structure,
+             accumulate_token_repr=use_accumulate_token_repr,
+             **diffusion_process_args,
+         )
+         self.distogram_module = DistogramModule(token_z, num_bins)
+         self.confidence_prediction = confidence_prediction
+         self.alpha_pae = alpha_pae
+
+         self.structure_prediction_training = structure_prediction_training
+         self.confidence_imitate_trunk = confidence_imitate_trunk
+         if self.confidence_prediction:
+             if self.confidence_imitate_trunk:
+                 self.confidence_module = ConfidenceModule(
+                     token_s,
+                     token_z,
+                     compute_pae=alpha_pae > 0,
+                     imitate_trunk=True,
+                     pairformer_args=pairformer_args,
+                     full_embedder_args=full_embedder_args,
+                     msa_args=msa_args,
+                     **confidence_model_args,
+                 )
+             else:
+                 self.confidence_module = ConfidenceModule(
+                     token_s,
+                     token_z,
+                     compute_pae=alpha_pae > 0,
+                     **confidence_model_args,
+                 )
+             if compile_confidence:
+                 self.confidence_module = torch.compile(
+                     self.confidence_module, dynamic=False, fullgraph=False
+                 )
+
+         # Remove gradients from weights that are not trained (needed for DDP)
+         if not structure_prediction_training:
+             for name, param in self.named_parameters():
+                 if name.split(".")[0] != "confidence_module":
+                     param.requires_grad = False
+
+     def setup(self, stage: str) -> None:
+         """Set the model for training, validation and inference."""
+         if stage == "predict" and not (
+             torch.cuda.is_available()
+             and torch.cuda.get_device_properties(torch.device("cuda")).major >= 8.0  # noqa: PLR2004
+         ):
+             self.use_kernels = False
+
+     def forward(
+         self,
+         feats: dict[str, Tensor],
+         recycling_steps: int = 0,
+         num_sampling_steps: Optional[int] = None,
+         multiplicity_diffusion_train: int = 1,
+         diffusion_samples: int = 1,
+         max_parallel_samples: Optional[int] = None,
+         run_confidence_sequentially: bool = False,
+     ) -> dict[str, Tensor]:
+         dict_out = {}
+
+         # Compute input embeddings
+         with torch.set_grad_enabled(
+             self.training and self.structure_prediction_training
+         ):
+             s_inputs = self.input_embedder(feats)
+
+             # Initialize the sequence and pairwise embeddings
+             s_init = self.s_init(s_inputs)
+             z_init = (
+                 self.z_init_1(s_inputs)[:, :, None]
+                 + self.z_init_2(s_inputs)[:, None, :]
+             )
+             relative_position_encoding = self.rel_pos(feats)
+             z_init = z_init + relative_position_encoding
+             z_init = z_init + self.token_bonds(feats["token_bonds"].float())
+
+             # Perform rounds of the pairwise stack
+             s = torch.zeros_like(s_init)
+             z = torch.zeros_like(z_init)
+
+             # Compute pairwise mask
+             mask = feats["token_pad_mask"].float()
+             pair_mask = mask[:, :, None] * mask[:, None, :]
+
+             for i in range(recycling_steps + 1):
+                 with torch.set_grad_enabled(self.training and (i == recycling_steps)):
+                     # Fixes an issue with unused parameters in autocast
+                     if (
+                         self.training
+                         and (i == recycling_steps)
+                         and torch.is_autocast_enabled()
+                     ):
+                         torch.clear_autocast_cache()
+
+                     # Apply recycling
+                     s = s_init + self.s_recycle(self.s_norm(s))
+                     z = z_init + self.z_recycle(self.z_norm(z))
+
+                     # Compute pairwise stack
+                     if not self.no_msa:
+                         z = z + self.msa_module(
+                             z, s_inputs, feats, use_kernels=self.use_kernels
+                         )
+
+                     # Revert to uncompiled version for validation
+                     if self.is_pairformer_compiled and not self.training:
+                         pairformer_module = self.pairformer_module._orig_mod  # noqa: SLF001
+                     else:
+                         pairformer_module = self.pairformer_module
+
+                     s, z = pairformer_module(
+                         s,
+                         z,
+                         mask=mask,
+                         pair_mask=pair_mask,
+                         use_kernels=self.use_kernels,
+                     )
+
+             pdistogram = self.distogram_module(z)
+             dict_out = {"pdistogram": pdistogram}
+
+         # Compute structure module
+         if self.training and self.structure_prediction_training:
+             dict_out.update(
+                 self.structure_module(
+                     s_trunk=s,
+                     z_trunk=z,
+                     s_inputs=s_inputs,
+                     feats=feats,
+                     relative_position_encoding=relative_position_encoding,
+                     multiplicity=multiplicity_diffusion_train,
+                 )
+             )
+
+         if (not self.training) or self.confidence_prediction:
+             dict_out.update(
+                 self.structure_module.sample(
+                     s_trunk=s,
+                     z_trunk=z,
+                     s_inputs=s_inputs,
+                     feats=feats,
+                     relative_position_encoding=relative_position_encoding,
+                     num_sampling_steps=num_sampling_steps,
+                     atom_mask=feats["atom_pad_mask"],
+                     multiplicity=diffusion_samples,
+                     max_parallel_samples=max_parallel_samples,
+                     train_accumulate_token_repr=self.training,
+                     steering_args=self.steering_args,
+                 )
+             )
+
+         if self.confidence_prediction:
+             dict_out.update(
+                 self.confidence_module(
+                     s_inputs=s_inputs.detach(),
+                     s=s.detach(),
+                     z=z.detach(),
+                     s_diffusion=(
+                         dict_out["diff_token_repr"]
+                         if self.confidence_module.use_s_diffusion
+                         else None
+                     ),
+                     x_pred=dict_out["sample_atom_coords"].detach(),
+                     feats=feats,
+                     pred_distogram_logits=dict_out["pdistogram"].detach(),
+                     multiplicity=diffusion_samples,
+                     run_sequentially=run_confidence_sequentially,
+                     use_kernels=self.use_kernels,
+                 )
+             )
+         if self.confidence_prediction and self.confidence_module.use_s_diffusion:
+             dict_out.pop("diff_token_repr", None)
+         return dict_out
+
+     def get_true_coordinates(
+         self,
+         batch,
+         out,
+         diffusion_samples,
+         symmetry_correction,
+         lddt_minimization=True,
+     ):
+         if symmetry_correction:
+             min_coords_routine = (
+                 minimum_lddt_symmetry_coords
+                 if lddt_minimization
+                 else minimum_symmetry_coords
+             )
+             true_coords = []
+             true_coords_resolved_mask = []
+             rmsds, best_rmsds = [], []
+             for idx in range(batch["token_index"].shape[0]):
+                 best_rmsd = float("inf")
+                 for rep in range(diffusion_samples):
+                     i = idx * diffusion_samples + rep
+                     best_true_coords, rmsd, best_true_coords_resolved_mask = (
+                         min_coords_routine(
+                             coords=out["sample_atom_coords"][i : i + 1],
+                             feats=batch,
+                             index_batch=idx,
+                             nucleotide_weight=self.nucleotide_rmsd_weight,
+                             ligand_weight=self.ligand_rmsd_weight,
+                         )
+                     )
+                     rmsds.append(rmsd)
+                     true_coords.append(best_true_coords)
+                     true_coords_resolved_mask.append(best_true_coords_resolved_mask)
+                     if rmsd < best_rmsd:
+                         best_rmsd = rmsd
+                 best_rmsds.append(best_rmsd)
+             true_coords = torch.cat(true_coords, dim=0)
+             true_coords_resolved_mask = torch.cat(true_coords_resolved_mask, dim=0)
+         else:
+             true_coords = (
+                 batch["coords"].squeeze(1).repeat_interleave(diffusion_samples, 0)
+             )
+
+             true_coords_resolved_mask = batch["atom_resolved_mask"].repeat_interleave(
+                 diffusion_samples, 0
+             )
+             rmsds, best_rmsds = weighted_minimum_rmsd(
+                 out["sample_atom_coords"],
+                 batch,
+                 multiplicity=diffusion_samples,
+                 nucleotide_weight=self.nucleotide_rmsd_weight,
+                 ligand_weight=self.ligand_rmsd_weight,
+             )
+
+         return true_coords, rmsds, best_rmsds, true_coords_resolved_mask
+
+     def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor:
+         # Sample recycling steps
+         recycling_steps = random.randint(0, self.training_args.recycling_steps)
+
+         # Compute the forward pass
+         out = self(
+             feats=batch,
+             recycling_steps=recycling_steps,
+             num_sampling_steps=self.training_args.sampling_steps,
+             multiplicity_diffusion_train=self.training_args.diffusion_multiplicity,
+             diffusion_samples=self.training_args.diffusion_samples,
+         )
+
+         # Compute losses
+         if self.structure_prediction_training:
+             disto_loss, _ = distogram_loss(
+                 out,
+                 batch,
+             )
+             try:
+                 diffusion_loss_dict = self.structure_module.compute_loss(
+                     batch,
+                     out,
+                     multiplicity=self.training_args.diffusion_multiplicity,
+                     **self.diffusion_loss_args,
+                 )
+             except Exception as e:
+                 print(f"Skipping batch {batch_idx} due to error: {e}")
+                 return None
+
+         else:
+             disto_loss = 0.0
+             diffusion_loss_dict = {"loss": 0.0, "loss_breakdown": {}}
+
+         if self.confidence_prediction:
+             # Confidence model symmetry correction
+             true_coords, _, _, true_coords_resolved_mask = self.get_true_coordinates(
+                 batch,
+                 out,
+                 diffusion_samples=self.training_args.diffusion_samples,
+                 symmetry_correction=self.training_args.symmetry_correction,
+             )
+
+             confidence_loss_dict = confidence_loss(
+                 out,
+                 batch,
+                 true_coords,
+                 true_coords_resolved_mask,
+                 alpha_pae=self.alpha_pae,
+                 multiplicity=self.training_args.diffusion_samples,
+             )
+         else:
+             confidence_loss_dict = {
+                 "loss": torch.tensor(0.0).to(batch["token_index"].device),
+                 "loss_breakdown": {},
+             }
+
+         # Aggregate losses
+         loss = (
+             self.training_args.confidence_loss_weight * confidence_loss_dict["loss"]
+             + self.training_args.diffusion_loss_weight * diffusion_loss_dict["loss"]
+             + self.training_args.distogram_loss_weight * disto_loss
+         )
+         # Log losses
+         self.log("train/distogram_loss", disto_loss)
+         self.log("train/diffusion_loss", diffusion_loss_dict["loss"])
+         for k, v in diffusion_loss_dict["loss_breakdown"].items():
+             self.log(f"train/{k}", v)
+
+         if self.confidence_prediction:
+             self.train_confidence_loss_logger.update(
+                 confidence_loss_dict["loss"].detach()
+             )
+
+             for k in self.train_confidence_loss_dict_logger.keys():
+                 self.train_confidence_loss_dict_logger[k].update(
+                     confidence_loss_dict["loss_breakdown"][k].detach()
+                     if torch.is_tensor(confidence_loss_dict["loss_breakdown"][k])
+                     else confidence_loss_dict["loss_breakdown"][k]
+                 )
+         self.log("train/loss", loss)
+         self.training_log()
+         return loss
+
+     def training_log(self):
+         self.log("train/grad_norm", self.gradient_norm(self), prog_bar=False)
+         self.log("train/param_norm", self.parameter_norm(self), prog_bar=False)
+
+         lr = self.trainer.optimizers[0].param_groups[0]["lr"]
+         self.log("lr", lr, prog_bar=False)
+
+         self.log(
+             "train/grad_norm_msa_module",
+             self.gradient_norm(self.msa_module),
+             prog_bar=False,
+         )
+         self.log(
+             "train/param_norm_msa_module",
+             self.parameter_norm(self.msa_module),
+             prog_bar=False,
+         )
+
+         self.log(
+             "train/grad_norm_pairformer_module",
+             self.gradient_norm(self.pairformer_module),
+             prog_bar=False,
+         )
+         self.log(
+             "train/param_norm_pairformer_module",
+             self.parameter_norm(self.pairformer_module),
+             prog_bar=False,
+         )
+
+         self.log(
+             "train/grad_norm_structure_module",
+             self.gradient_norm(self.structure_module),
+             prog_bar=False,
+         )
+         self.log(
+             "train/param_norm_structure_module",
+             self.parameter_norm(self.structure_module),
+             prog_bar=False,
+         )
+
+         if self.confidence_prediction:
+             self.log(
+                 "train/grad_norm_confidence_module",
+                 self.gradient_norm(self.confidence_module),
+                 prog_bar=False,
+             )
+             self.log(
+                 "train/param_norm_confidence_module",
+                 self.parameter_norm(self.confidence_module),
+                 prog_bar=False,
+             )
+
+     def on_train_epoch_end(self):
+         self.log(
+             "train/confidence_loss",
+             self.train_confidence_loss_logger,
+             prog_bar=False,
+             on_step=False,
+             on_epoch=True,
+         )
+         for k, v in self.train_confidence_loss_dict_logger.items():
+             self.log(f"train/{k}", v, prog_bar=False, on_step=False, on_epoch=True)
+
+     def gradient_norm(self, module) -> float:
+         # Only compute over parameters that are being trained
+         parameters = filter(lambda p: p.requires_grad, module.parameters())
+         parameters = filter(lambda p: p.grad is not None, parameters)
+         norm = torch.tensor([p.grad.norm(p=2) ** 2 for p in parameters]).sum().sqrt()
+         return norm
+
+     def parameter_norm(self, module) -> float:
+         # Only compute over parameters that are being trained
+         parameters = filter(lambda p: p.requires_grad, module.parameters())
+         norm = torch.tensor([p.norm(p=2) ** 2 for p in parameters]).sum().sqrt()
+         return norm
+
+     def validation_step(self, batch: dict[str, Tensor], batch_idx: int):
+         # Compute the forward pass
+         n_samples = self.validation_args.diffusion_samples
+         try:
+             out = self(
+                 batch,
+                 recycling_steps=self.validation_args.recycling_steps,
+                 num_sampling_steps=self.validation_args.sampling_steps,
+                 diffusion_samples=n_samples,
+                 run_confidence_sequentially=self.validation_args.run_confidence_sequentially,
+             )
+
+         except RuntimeError as e:  # catch out of memory exceptions
+             if "out of memory" in str(e):
+                 print("| WARNING: ran out of memory, skipping batch")
+                 torch.cuda.empty_cache()
+                 gc.collect()
+                 return
+             else:
+                 raise e
+
+         try:
+             # Compute distogram LDDT
+             boundaries = torch.linspace(2, 22.0, 63)
+             lower = torch.tensor([1.0])
+             upper = torch.tensor([22.0 + 5.0])
+             exp_boundaries = torch.cat((lower, boundaries, upper))
+             mid_points = ((exp_boundaries[:-1] + exp_boundaries[1:]) / 2).to(
+                 out["pdistogram"]
+             )
+
+             # Compute predicted dists
+             preds = out["pdistogram"]
+             pred_softmax = torch.softmax(preds, dim=-1)
+             pred_softmax = pred_softmax.argmax(dim=-1)
+             pred_softmax = torch.nn.functional.one_hot(
+                 pred_softmax, num_classes=preds.shape[-1]
+             )
+             pred_dist = (pred_softmax * mid_points).sum(dim=-1)
+             true_center = batch["disto_center"]
+             true_dists = torch.cdist(true_center, true_center)
+
+             # Compute lddt's
+             batch["token_disto_mask"] = batch["token_disto_mask"]
+             disto_lddt_dict, disto_total_dict = factored_token_lddt_dist_loss(
+                 feats=batch,
+                 true_d=true_dists,
+                 pred_d=pred_dist,
+             )
+
+             true_coords, rmsds, best_rmsds, true_coords_resolved_mask = (
+                 self.get_true_coordinates(
+                     batch=batch,
+                     out=out,
+                     diffusion_samples=n_samples,
+                     symmetry_correction=self.validation_args.symmetry_correction,
+                 )
+             )
+
+             all_lddt_dict, all_total_dict = factored_lddt_loss(
+                 feats=batch,
+                 atom_mask=true_coords_resolved_mask,
+                 true_atom_coords=true_coords,
+                 pred_atom_coords=out["sample_atom_coords"],
+                 multiplicity=n_samples,
+             )
+         except RuntimeError as e:  # catch out of memory exceptions
+             if "out of memory" in str(e):
+                 print("| WARNING: ran out of memory, skipping batch")
+                 torch.cuda.empty_cache()
+                 gc.collect()
+                 return
+             else:
+                 raise e
+         # If the multiplicity used is > 1, we take the best LDDT of the different samples;
+         # AF3 combines this with the confidence-based filtering.
+         best_lddt_dict, best_total_dict = {}, {}
+         best_complex_lddt_dict, best_complex_total_dict = {}, {}
+         B = true_coords.shape[0] // n_samples
+         if n_samples > 1:
+             # NOTE: we can change the way we aggregate the lddt
+             complex_total = 0
+             complex_lddt = 0
+             for key in all_lddt_dict.keys():
+                 complex_lddt += all_lddt_dict[key] * all_total_dict[key]
+                 complex_total += all_total_dict[key]
+             complex_lddt /= complex_total + 1e-7
+             best_complex_idx = complex_lddt.reshape(-1, n_samples).argmax(dim=1)
+             for key in all_lddt_dict:
+                 best_idx = all_lddt_dict[key].reshape(-1, n_samples).argmax(dim=1)
+                 best_lddt_dict[key] = all_lddt_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), best_idx
+                 ]
+                 best_total_dict[key] = all_total_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), best_idx
+                 ]
+                 best_complex_lddt_dict[key] = all_lddt_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), best_complex_idx
+                 ]
+                 best_complex_total_dict[key] = all_total_dict[key].reshape(
+                     -1, n_samples
+                 )[torch.arange(B), best_complex_idx]
+         else:
+             best_lddt_dict = all_lddt_dict
+             best_total_dict = all_total_dict
+             best_complex_lddt_dict = all_lddt_dict
+             best_complex_total_dict = all_total_dict
+
+         # Filtering based on confidence
+         if self.confidence_prediction and n_samples > 1:
+             # NOTE: for now we don't have PAE predictions, so we have to use pLDDT instead of pTM;
+             # also, while AF3 differentiates the best prediction per confidence type, we currently do not.
+             # Consider this in the future, as well as weighting the different pLDDT types before aggregation.
+             mae_plddt_dict, total_mae_plddt_dict = compute_plddt_mae(
+                 pred_atom_coords=out["sample_atom_coords"],
+                 feats=batch,
+                 true_atom_coords=true_coords,
+                 pred_lddt=out["plddt"],
+                 true_coords_resolved_mask=true_coords_resolved_mask,
+                 multiplicity=n_samples,
+             )
+             mae_pde_dict, total_mae_pde_dict = compute_pde_mae(
+                 pred_atom_coords=out["sample_atom_coords"],
+                 feats=batch,
+                 true_atom_coords=true_coords,
+                 pred_pde=out["pde"],
+                 true_coords_resolved_mask=true_coords_resolved_mask,
+                 multiplicity=n_samples,
+             )
+             mae_pae_dict, total_mae_pae_dict = compute_pae_mae(
+                 pred_atom_coords=out["sample_atom_coords"],
+                 feats=batch,
+                 true_atom_coords=true_coords,
+                 pred_pae=out["pae"],
+                 true_coords_resolved_mask=true_coords_resolved_mask,
+                 multiplicity=n_samples,
+             )
+
+             plddt = out["complex_plddt"].reshape(-1, n_samples)
+             top1_idx = plddt.argmax(dim=1)
+             iplddt = out["complex_iplddt"].reshape(-1, n_samples)
+             iplddt_top1_idx = iplddt.argmax(dim=1)
+             pde = out["complex_pde"].reshape(-1, n_samples)
+             pde_top1_idx = pde.argmin(dim=1)
+             ipde = out["complex_ipde"].reshape(-1, n_samples)
+             ipde_top1_idx = ipde.argmin(dim=1)
+             ptm = out["ptm"].reshape(-1, n_samples)
+             ptm_top1_idx = ptm.argmax(dim=1)
+             iptm = out["iptm"].reshape(-1, n_samples)
+             iptm_top1_idx = iptm.argmax(dim=1)
+             ligand_iptm = out["ligand_iptm"].reshape(-1, n_samples)
+             ligand_iptm_top1_idx = ligand_iptm.argmax(dim=1)
+             protein_iptm = out["protein_iptm"].reshape(-1, n_samples)
+             protein_iptm_top1_idx = protein_iptm.argmax(dim=1)
+
+             for key in all_lddt_dict:
+                 top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), top1_idx
+                 ]
+                 top1_total = all_total_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), top1_idx
+                 ]
+                 iplddt_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), iplddt_top1_idx
+                 ]
+                 iplddt_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), iplddt_top1_idx
+                 ]
+                 pde_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), pde_top1_idx
+                 ]
+                 pde_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), pde_top1_idx
+                 ]
+                 ipde_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), ipde_top1_idx
+                 ]
+                 ipde_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), ipde_top1_idx
+                 ]
+                 ptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), ptm_top1_idx
+                 ]
+                 ptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), ptm_top1_idx
+                 ]
+                 iptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), iptm_top1_idx
+                 ]
+                 iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), iptm_top1_idx
+                 ]
+                 ligand_iptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), ligand_iptm_top1_idx
+                 ]
+                 ligand_iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), ligand_iptm_top1_idx
+                 ]
+                 protein_iptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), protein_iptm_top1_idx
+                 ]
+                 protein_iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+                     torch.arange(B), protein_iptm_top1_idx
+                 ]
+
+                 self.top1_lddt[key].update(top1_lddt, top1_total)
+                 self.iplddt_top1_lddt[key].update(iplddt_top1_lddt, iplddt_top1_total)
+                 self.pde_top1_lddt[key].update(pde_top1_lddt, pde_top1_total)
+                 self.ipde_top1_lddt[key].update(ipde_top1_lddt, ipde_top1_total)
+                 self.ptm_top1_lddt[key].update(ptm_top1_lddt, ptm_top1_total)
+                 self.iptm_top1_lddt[key].update(iptm_top1_lddt, iptm_top1_total)
+                 self.ligand_iptm_top1_lddt[key].update(
+                     ligand_iptm_top1_lddt, ligand_iptm_top1_total
+                 )
+                 self.protein_iptm_top1_lddt[key].update(
+                     protein_iptm_top1_lddt, protein_iptm_top1_total
+                 )
+
+                 self.avg_lddt[key].update(all_lddt_dict[key], all_total_dict[key])
+                 self.pde_mae[key].update(mae_pde_dict[key], total_mae_pde_dict[key])
+                 self.pae_mae[key].update(mae_pae_dict[key], total_mae_pae_dict[key])
+
+             for key in mae_plddt_dict:
+                 self.plddt_mae[key].update(
+                     mae_plddt_dict[key], total_mae_plddt_dict[key]
+                 )
+
+         for m in const.out_types:
+             if m == "ligand_protein":
+                 if torch.any(
+                     batch["pocket_feature"][
+                         :, :, const.pocket_contact_info["POCKET"]
+                     ].bool()
+                 ):
+                     self.lddt["pocket_ligand_protein"].update(
+                         best_lddt_dict[m], best_total_dict[m]
+                     )
+                     self.disto_lddt["pocket_ligand_protein"].update(
+                         disto_lddt_dict[m], disto_total_dict[m]
+                     )
+                     self.complex_lddt["pocket_ligand_protein"].update(
+                         best_complex_lddt_dict[m], best_complex_total_dict[m]
+                     )
+                 else:
+                     self.lddt["ligand_protein"].update(
+                         best_lddt_dict[m], best_total_dict[m]
+                     )
+                     self.disto_lddt["ligand_protein"].update(
+                         disto_lddt_dict[m], disto_total_dict[m]
+                     )
+                     self.complex_lddt["ligand_protein"].update(
+                         best_complex_lddt_dict[m], best_complex_total_dict[m]
+                     )
+             else:
+                 self.lddt[m].update(best_lddt_dict[m], best_total_dict[m])
+                 self.disto_lddt[m].update(disto_lddt_dict[m], disto_total_dict[m])
+                 self.complex_lddt[m].update(
+                     best_complex_lddt_dict[m], best_complex_total_dict[m]
+                 )
+         self.rmsd.update(rmsds)
+         self.best_rmsd.update(best_rmsds)
+
+     def on_validation_epoch_end(self):
+         avg_lddt = {}
+         avg_disto_lddt = {}
+         avg_complex_lddt = {}
+         if self.confidence_prediction:
+             avg_top1_lddt = {}
+             avg_iplddt_top1_lddt = {}
+             avg_pde_top1_lddt = {}
+             avg_ipde_top1_lddt = {}
+             avg_ptm_top1_lddt = {}
+             avg_iptm_top1_lddt = {}
+             avg_ligand_iptm_top1_lddt = {}
+             avg_protein_iptm_top1_lddt = {}
+
+             avg_avg_lddt = {}
+             avg_mae_plddt = {}
+             avg_mae_pde = {}
+             avg_mae_pae = {}
+
+         for m in const.out_types + ["pocket_ligand_protein"]:
+             avg_lddt[m] = self.lddt[m].compute()
+             avg_lddt[m] = 0.0 if torch.isnan(avg_lddt[m]) else avg_lddt[m].item()
+             self.lddt[m].reset()
+             self.log(f"val/lddt_{m}", avg_lddt[m], prog_bar=False, sync_dist=True)
+
+             avg_disto_lddt[m] = self.disto_lddt[m].compute()
+             avg_disto_lddt[m] = (
+                 0.0 if torch.isnan(avg_disto_lddt[m]) else avg_disto_lddt[m].item()
+             )
+             self.disto_lddt[m].reset()
+             self.log(
+                 f"val/disto_lddt_{m}", avg_disto_lddt[m], prog_bar=False, sync_dist=True
+             )
+             avg_complex_lddt[m] = self.complex_lddt[m].compute()
+             avg_complex_lddt[m] = (
+                 0.0 if torch.isnan(avg_complex_lddt[m]) else avg_complex_lddt[m].item()
+             )
+             self.complex_lddt[m].reset()
+             self.log(
+                 f"val/complex_lddt_{m}",
+                 avg_complex_lddt[m],
+                 prog_bar=False,
+                 sync_dist=True,
+             )
+             if self.confidence_prediction:
+                 avg_top1_lddt[m] = self.top1_lddt[m].compute()
+                 avg_top1_lddt[m] = (
+                     0.0 if torch.isnan(avg_top1_lddt[m]) else avg_top1_lddt[m].item()
+                 )
+                 self.top1_lddt[m].reset()
+                 self.log(
+                     f"val/top1_lddt_{m}",
+                     avg_top1_lddt[m],
+                     prog_bar=False,
+                     sync_dist=True,
+                 )
+                 avg_iplddt_top1_lddt[m] = self.iplddt_top1_lddt[m].compute()
+                 avg_iplddt_top1_lddt[m] = (
+                     0.0
+                     if torch.isnan(avg_iplddt_top1_lddt[m])
+                     else avg_iplddt_top1_lddt[m].item()
+                 )
+                 self.iplddt_top1_lddt[m].reset()
+                 self.log(
+                     f"val/iplddt_top1_lddt_{m}",
+                     avg_iplddt_top1_lddt[m],
+                     prog_bar=False,
+                     sync_dist=True,
+                 )
+                 avg_pde_top1_lddt[m] = self.pde_top1_lddt[m].compute()
+                 avg_pde_top1_lddt[m] = (
+                     0.0
+                     if torch.isnan(avg_pde_top1_lddt[m])
+                     else avg_pde_top1_lddt[m].item()
+                 )
+                 self.pde_top1_lddt[m].reset()
+                 self.log(
+                     f"val/pde_top1_lddt_{m}",
+                     avg_pde_top1_lddt[m],
+                     prog_bar=False,
+                     sync_dist=True,
+                 )
+                 avg_ipde_top1_lddt[m] = self.ipde_top1_lddt[m].compute()
+                 avg_ipde_top1_lddt[m] = (
+                     0.0
+                     if torch.isnan(avg_ipde_top1_lddt[m])
+                     else avg_ipde_top1_lddt[m].item()
+                 )
+                 self.ipde_top1_lddt[m].reset()
+                 self.log(
+                     f"val/ipde_top1_lddt_{m}",
+                     avg_ipde_top1_lddt[m],
+                     prog_bar=False,
+                     sync_dist=True,
+                 )
+                 avg_ptm_top1_lddt[m] = self.ptm_top1_lddt[m].compute()
+                 avg_ptm_top1_lddt[m] = (
+                     0.0
+                     if torch.isnan(avg_ptm_top1_lddt[m])
+                     else avg_ptm_top1_lddt[m].item()
+                 )
+                 self.ptm_top1_lddt[m].reset()
+                 self.log(
+                     f"val/ptm_top1_lddt_{m}",
+                     avg_ptm_top1_lddt[m],
+                     prog_bar=False,
+                     sync_dist=True,
+                 )
+                 avg_iptm_top1_lddt[m] = self.iptm_top1_lddt[m].compute()
+                 avg_iptm_top1_lddt[m] = (
+                     0.0
+                     if torch.isnan(avg_iptm_top1_lddt[m])
+                     else avg_iptm_top1_lddt[m].item()
+                 )
+                 self.iptm_top1_lddt[m].reset()
+                 self.log(
+                     f"val/iptm_top1_lddt_{m}",
+                     avg_iptm_top1_lddt[m],
+                     prog_bar=False,
+                     sync_dist=True,
+                 )
+
+                 avg_ligand_iptm_top1_lddt[m] = self.ligand_iptm_top1_lddt[m].compute()
+                 avg_ligand_iptm_top1_lddt[m] = (
+                     0.0
+                     if torch.isnan(avg_ligand_iptm_top1_lddt[m])
+                     else avg_ligand_iptm_top1_lddt[m].item()
+                 )
+                 self.ligand_iptm_top1_lddt[m].reset()
+                 self.log(
+                     f"val/ligand_iptm_top1_lddt_{m}",
+                     avg_ligand_iptm_top1_lddt[m],
+                     prog_bar=False,
+                     sync_dist=True,
+                 )
+
+                 avg_protein_iptm_top1_lddt[m] = self.protein_iptm_top1_lddt[m].compute()
+                 avg_protein_iptm_top1_lddt[m] = (
+                     0.0
+                     if torch.isnan(avg_protein_iptm_top1_lddt[m])
+                     else avg_protein_iptm_top1_lddt[m].item()
+                 )
+                 self.protein_iptm_top1_lddt[m].reset()
+                 self.log(
+                     f"val/protein_iptm_top1_lddt_{m}",
+                     avg_protein_iptm_top1_lddt[m],
+                     prog_bar=False,
+                     sync_dist=True,
+                 )
+
+                 avg_avg_lddt[m] = self.avg_lddt[m].compute()
+                 avg_avg_lddt[m] = (
+                     0.0 if torch.isnan(avg_avg_lddt[m]) else avg_avg_lddt[m].item()
+                 )
+                 self.avg_lddt[m].reset()
+                 self.log(
+                     f"val/avg_lddt_{m}", avg_avg_lddt[m], prog_bar=False, sync_dist=True
+                 )
+                 avg_mae_pde[m] = self.pde_mae[m].compute().item()
+                 self.pde_mae[m].reset()
+                 self.log(
+                     f"val/MAE_pde_{m}",
+                     avg_mae_pde[m],
+                     prog_bar=False,
+                     sync_dist=True,
+                 )
+                 avg_mae_pae[m] = self.pae_mae[m].compute().item()
+                 self.pae_mae[m].reset()
+                 self.log(
+                     f"val/MAE_pae_{m}",
+                     avg_mae_pae[m],
+                     prog_bar=False,
+                     sync_dist=True,
+                 )
+
+         for m in const.out_single_types:
+             if self.confidence_prediction:
+                 avg_mae_plddt[m] = self.plddt_mae[m].compute().item()
+                 self.plddt_mae[m].reset()
+                 self.log(
+                     f"val/MAE_plddt_{m}",
+                     avg_mae_plddt[m],
+                     prog_bar=False,
+                     sync_dist=True,
+                 )
+
+         overall_disto_lddt = sum(
+             avg_disto_lddt[m] * w for (m, w) in const.out_types_weights.items()
+         ) / sum(const.out_types_weights.values())
+         self.log("val/disto_lddt", overall_disto_lddt, prog_bar=True, sync_dist=True)
+
+         overall_lddt = sum(
+             avg_lddt[m] * w for (m, w) in const.out_types_weights.items()
+         ) / sum(const.out_types_weights.values())
+         self.log("val/lddt", overall_lddt, prog_bar=True, sync_dist=True)
+
+         overall_complex_lddt = sum(
+             avg_complex_lddt[m] * w for (m, w) in const.out_types_weights.items()
+         ) / sum(const.out_types_weights.values())
+         self.log(
+             "val/complex_lddt", overall_complex_lddt, prog_bar=True, sync_dist=True
+         )
+
+         if self.confidence_prediction:
+             overall_top1_lddt = sum(
+                 avg_top1_lddt[m] * w for (m, w) in const.out_types_weights.items()
+             ) / sum(const.out_types_weights.values())
+             self.log("val/top1_lddt", overall_top1_lddt, prog_bar=True, sync_dist=True)
+
+             overall_iplddt_top1_lddt = sum(
+                 avg_iplddt_top1_lddt[m] * w
+                 for (m, w) in const.out_types_weights.items()
+             ) / sum(const.out_types_weights.values())
+             self.log(
+                 "val/iplddt_top1_lddt",
+                 overall_iplddt_top1_lddt,
+                 prog_bar=True,
+                 sync_dist=True,
+             )
+
+             overall_pde_top1_lddt = sum(
+                 avg_pde_top1_lddt[m] * w for (m, w) in const.out_types_weights.items()
+             ) / sum(const.out_types_weights.values())
+             self.log(
+                 "val/pde_top1_lddt",
+                 overall_pde_top1_lddt,
+                 prog_bar=True,
+                 sync_dist=True,
+             )
+
+             overall_ipde_top1_lddt = sum(
+                 avg_ipde_top1_lddt[m] * w for (m, w) in const.out_types_weights.items()
+             ) / sum(const.out_types_weights.values())
+             self.log(
+                 "val/ipde_top1_lddt",
+                 overall_ipde_top1_lddt,
+                 prog_bar=True,
+                 sync_dist=True,
+             )
+
+             overall_ptm_top1_lddt = sum(
+                 avg_ptm_top1_lddt[m] * w for (m, w) in const.out_types_weights.items()
+             ) / sum(const.out_types_weights.values())
+             self.log(
+                 "val/ptm_top1_lddt",
+                 overall_ptm_top1_lddt,
+                 prog_bar=True,
+                 sync_dist=True,
+             )
+
+             overall_iptm_top1_lddt = sum(
+                 avg_iptm_top1_lddt[m] * w for (m, w) in const.out_types_weights.items()
+             ) / sum(const.out_types_weights.values())
+             self.log(
+                 "val/iptm_top1_lddt",
+                 overall_iptm_top1_lddt,
+                 prog_bar=True,
+                 sync_dist=True,
+             )
+
+             overall_avg_lddt = sum(
+                 avg_avg_lddt[m] * w for (m, w) in const.out_types_weights.items()
+             ) / sum(const.out_types_weights.values())
+             self.log("val/avg_lddt", overall_avg_lddt, prog_bar=True, sync_dist=True)
+
+         self.log("val/rmsd", self.rmsd.compute(), prog_bar=True, sync_dist=True)
+         self.rmsd.reset()
+
+         self.log(
+             "val/best_rmsd", self.best_rmsd.compute(), prog_bar=True, sync_dist=True
+         )
+         self.best_rmsd.reset()
+
+     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+         try:
+             out = self(
+                 batch,
+                 recycling_steps=self.predict_args["recycling_steps"],
+                 num_sampling_steps=self.predict_args["sampling_steps"],
+                 diffusion_samples=self.predict_args["diffusion_samples"],
+                 max_parallel_samples=self.predict_args["diffusion_samples"],
+                 run_confidence_sequentially=True,
+             )
+             pred_dict = {"exception": False}
+             pred_dict["masks"] = batch["atom_pad_mask"]
+             pred_dict["coords"] = out["sample_atom_coords"]
+             if self.predict_args.get("write_confidence_summary", True):
+                 pred_dict["confidence_score"] = (
+                     4 * out["complex_plddt"]
+                     + (
+                         out["iptm"]
+                         if not torch.allclose(
+                             out["iptm"], torch.zeros_like(out["iptm"])
+                         )
+                         else out["ptm"]
+                     )
+                 ) / 5
+                 for key in [
+                     "ptm",
+                     "iptm",
+                     "ligand_iptm",
+                     "protein_iptm",
+                     "pair_chains_iptm",
+                     "complex_plddt",
+                     "complex_iplddt",
+                     "complex_pde",
+                     "complex_ipde",
+                     "plddt",
+                 ]:
+                     pred_dict[key] = out[key]
+             if self.predict_args.get("write_full_pae", True):
+                 pred_dict["pae"] = out["pae"]
+             if self.predict_args.get("write_full_pde", False):
+                 pred_dict["pde"] = out["pde"]
+             return pred_dict
+
+         except RuntimeError as e:  # catch out of memory exceptions
+             if "out of memory" in str(e):
+                 print("| WARNING: ran out of memory, skipping batch")
+                 torch.cuda.empty_cache()
+                 gc.collect()
+                 return {"exception": True}
+             else:
+                 raise
+
+     def configure_optimizers(self):
+         """Configure the optimizer."""
+
+         if self.structure_prediction_training:
+             parameters = [p for p in self.parameters() if p.requires_grad]
+         else:
+             parameters = [
+                 p for p in self.confidence_module.parameters() if p.requires_grad
+             ] + [
+                 p
+                 for p in self.structure_module.out_token_feat_update.parameters()
+                 if p.requires_grad
+             ]
+
+         optimizer = torch.optim.Adam(
+             parameters,
+             betas=(self.training_args.adam_beta_1, self.training_args.adam_beta_2),
+             eps=self.training_args.adam_eps,
+             lr=self.training_args.base_lr,
+         )
+         if self.training_args.lr_scheduler == "af3":
+             scheduler = AlphaFoldLRScheduler(
+                 optimizer,
+                 base_lr=self.training_args.base_lr,
+                 max_lr=self.training_args.max_lr,
+                 warmup_no_steps=self.training_args.lr_warmup_no_steps,
+                 start_decay_after_n_steps=self.training_args.lr_start_decay_after_n_steps,
+                 decay_every_n_steps=self.training_args.lr_decay_every_n_steps,
+                 decay_factor=self.training_args.lr_decay_factor,
+             )
+             return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
+
+         return optimizer
+
+     def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+         if self.use_ema:
+             checkpoint["ema"] = self.ema.state_dict()
+
+     def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+         if self.use_ema and "ema" in checkpoint:
+             self.ema = ExponentialMovingAverage(
+                 parameters=self.parameters(), decay=self.ema_decay
+             )
+             if self.ema.compatible(checkpoint["ema"]["shadow_params"]):
+                 self.ema.load_state_dict(checkpoint["ema"], device=torch.device("cpu"))
+             else:
+                 self.ema = None
+                 print(
+                     "Warning: EMA state not loaded due to incompatible model parameters."
+                 )
+
+     def on_train_start(self):
+         if self.use_ema and self.ema is None:
+             self.ema = ExponentialMovingAverage(
+                 parameters=self.parameters(), decay=self.ema_decay
+             )
+         elif self.use_ema:
+             self.ema.to(self.device)
+
+     def on_train_epoch_start(self) -> None:
+         if self.use_ema:
+             self.ema.restore(self.parameters())
+
+     def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None:
+         # Updates EMA parameters after optimizer.step()
+         if self.use_ema:
+             self.ema.update(self.parameters())
+
+     def prepare_eval(self) -> None:
+         if self.use_ema and self.ema is None:
+             self.ema = ExponentialMovingAverage(
+                 parameters=self.parameters(), decay=self.ema_decay
+             )
+
+         if self.use_ema:
+             self.ema.store(self.parameters())
+             self.ema.copy_to(self.parameters())
+
+     def on_validation_start(self):
+         self.prepare_eval()
+
+     def on_predict_start(self) -> None:
+         self.prepare_eval()
+
+     def on_test_start(self) -> None:
+         self.prepare_eval()
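
For orientation, a minimal sketch of how this LightningModule could be driven at inference time. This is a hedged illustration, not the shipped entry point: the package's real configuration is built elsewhere (boltz/main.py), the checkpoint path and the predict_args values below are placeholders, and my_dataloader stands for any user-supplied inference DataLoader. Loading via Boltz1.load_from_checkpoint relies on the save_hyperparameters() call in __init__, and predict_step then reads predict_args and returns one pred_dict per batch; note how it composes confidence_score as (4 * complex_plddt + iptm) / 5, falling back to ptm when iptm is identically zero (e.g. single-chain inputs).

# Hypothetical usage sketch: the values below are illustrative, not the released defaults.
import torch
from pytorch_lightning import Trainer

from boltz.model.models.boltz1 import Boltz1

# Keys mirror those read in predict_step().
predict_args = {
    "recycling_steps": 3,        # placeholder
    "sampling_steps": 200,       # placeholder
    "diffusion_samples": 1,      # placeholder
    "write_confidence_summary": True,
    "write_full_pae": False,
    "write_full_pde": False,
}

# save_hyperparameters() in __init__ lets Lightning restore the constructor
# arguments from the checkpoint; predict_args is overridden here for inference.
model = Boltz1.load_from_checkpoint(
    "boltz1.ckpt",               # hypothetical checkpoint path
    map_location="cpu",
    strict=False,
    predict_args=predict_args,
)
model.eval()

trainer = Trainer(accelerator="auto", devices=1)
# Each element is the pred_dict returned by predict_step,
# or {"exception": True} if the batch was skipped after an OOM.
predictions = trainer.predict(model, dataloaders=my_dataloader)  # my_dataloader: user-supplied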