boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1249 @@
1
+ import gc
2
+ from typing import Any, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch._dynamo
7
+ from pytorch_lightning import Callback, LightningModule
8
+ from torch import Tensor, nn
9
+ from torchmetrics import MeanMetric
10
+
11
+ import boltz.model.layers.initialize as init
12
+ from boltz.data import const
13
+ from boltz.data.mol import (
14
+ minimum_lddt_symmetry_coords,
15
+ )
16
+ from boltz.model.layers.pairformer import PairformerModule
17
+ from boltz.model.loss.bfactor import bfactor_loss_fn
18
+ from boltz.model.loss.confidencev2 import (
19
+ confidence_loss,
20
+ )
21
+ from boltz.model.loss.distogramv2 import distogram_loss
22
+ from boltz.model.modules.affinity import AffinityModule
23
+ from boltz.model.modules.confidencev2 import ConfidenceModule
24
+ from boltz.model.modules.diffusion_conditioning import DiffusionConditioning
25
+ from boltz.model.modules.diffusionv2 import AtomDiffusion
26
+ from boltz.model.modules.encodersv2 import RelativePositionEncoder
27
+ from boltz.model.modules.trunkv2 import (
28
+ BFactorModule,
29
+ ContactConditioning,
30
+ DistogramModule,
31
+ InputEmbedder,
32
+ MSAModule,
33
+ TemplateModule,
34
+ TemplateV2Module,
35
+ )
36
+ from boltz.model.optim.ema import EMA
37
+ from boltz.model.optim.scheduler import AlphaFoldLRScheduler
38
+
39
+
40
+ class Boltz2(LightningModule):
41
+ """Boltz2 model."""
42
+
43
def __init__(
    self,
    atom_s: int,
    atom_z: int,
    token_s: int,
    token_z: int,
    num_bins: int,
    training_args: dict[str, Any],
    validation_args: dict[str, Any],
    embedder_args: dict[str, Any],
    msa_args: dict[str, Any],
    pairformer_args: dict[str, Any],
    score_model_args: dict[str, Any],
    diffusion_process_args: dict[str, Any],
    diffusion_loss_args: dict[str, Any],
    confidence_model_args: Optional[dict[str, Any]] = None,
    affinity_model_args: Optional[dict[str, Any]] = None,
    affinity_model_args1: Optional[dict[str, Any]] = None,
    affinity_model_args2: Optional[dict[str, Any]] = None,
    validators: Any = None,
    num_val_datasets: int = 1,
    atom_feature_dim: int = 128,
    template_args: Optional[dict] = None,
    confidence_prediction: bool = True,
    affinity_prediction: bool = False,
    affinity_ensemble: bool = False,
    affinity_mw_correction: bool = True,
    run_trunk_and_structure: bool = True,
    skip_run_structure: bool = False,
    token_level_confidence: bool = True,
    alpha_pae: float = 0.0,
    structure_prediction_training: bool = True,
    validate_structure: bool = True,
    atoms_per_window_queries: int = 32,
    atoms_per_window_keys: int = 128,
    compile_pairformer: bool = False,
    compile_structure: bool = False,
    compile_confidence: bool = False,
    compile_affinity: bool = False,
    compile_msa: bool = False,
    exclude_ions_from_lddt: bool = False,
    ema: bool = False,
    ema_decay: float = 0.999,
    min_dist: float = 2.0,
    max_dist: float = 22.0,
    predict_args: Optional[dict[str, Any]] = None,
    fix_sym_check: bool = False,
    cyclic_pos_enc: bool = False,
    aggregate_distogram: bool = True,
    bond_type_feature: bool = False,
    use_no_atom_char: bool = False,
    no_random_recycling_training: bool = False,
    use_atom_backbone_feat: bool = False,
    use_residue_feats_atoms: bool = False,
    conditioning_cutoff_min: float = 4.0,
    conditioning_cutoff_max: float = 20.0,
    steering_args: Optional[dict] = None,
    use_templates: bool = False,
    compile_templates: bool = False,
    predict_bfactor: bool = False,
    log_loss_every_steps: int = 50,
    checkpoint_diffusion_conditioning: bool = False,
    use_templates_v2: bool = False,
    use_kernels: bool = False,
) -> None:
    """Build every sub-module of the model and record its configuration.

    All arguments except ``validators`` are saved as Lightning
    hyperparameters.

    Parameters
    ----------
    atom_s, atom_z, token_s, token_z
        Feature sizes forwarded to the embedder, the trunk layers, the
        diffusion conditioning and the output heads.
    num_bins
        Passed to the distogram head (and the B-factor head when
        ``predict_bfactor`` is set).
    training_args, validation_args, diffusion_loss_args, predict_args, steering_args
        Stored unchanged on the instance for use by other methods.
    embedder_args, msa_args, pairformer_args, score_model_args, diffusion_process_args
        Extra keyword arguments splatted into the corresponding
        sub-module constructors.
    confidence_model_args, affinity_model_args, affinity_model_args1, affinity_model_args2, template_args
        Optional keyword arguments for the confidence, affinity and
        template modules; ``None`` is treated as "no extra arguments".
    validators
        Iterable of validator modules, wrapped in an ``nn.ModuleList``
        when ``validate_structure`` is true.
    compile_*
        When true, the matching sub-module is wrapped with
        ``torch.compile(dynamic=False, fullgraph=False)``.
    structure_prediction_training
        When false, every parameter outside ``confidence_module`` /
        ``affinity_module`` (and not named ``out_token_feat_update``) has
        ``requires_grad`` switched off.

    The remaining flags are copied onto same-named (or, for ``ema``,
    ``use_ema``) instance attributes or forwarded to sub-modules as shown
    in the body.
    """
    super().__init__()
    self.save_hyperparameters(ignore=["validators"])

    # No random recycling
    self.no_random_recycling_training = no_random_recycling_training

    # setup() reads this flag from the instance, so it must be stored as
    # an attribute (save_hyperparameters only puts it in self.hparams).
    self.validate_structure = validate_structure

    if validate_structure:
        # Late init at setup time
        self.val_group_mapper = {}  # maps a dataset index to a validation group name
        self.validator_mapper = {}  # maps a dataset index to a validator

        # Validators for each dataset keep track of all metrics,
        # compute validation, aggregate results and log
        self.validators = nn.ModuleList(validators)

    self.num_val_datasets = num_val_datasets
    self.log_loss_every_steps = log_loss_every_steps

    # EMA
    self.use_ema = ema
    self.ema_decay = ema_decay

    # Arguments
    self.training_args = training_args
    self.validation_args = validation_args
    self.diffusion_loss_args = diffusion_loss_args
    self.predict_args = predict_args
    self.steering_args = steering_args

    # Training metrics
    if validate_structure:
        self.train_confidence_loss_logger = MeanMetric()
        self.train_confidence_loss_dict_logger = nn.ModuleDict()
        for m in [
            "plddt_loss",
            "resolved_loss",
            "pde_loss",
            "pae_loss",
        ]:
            self.train_confidence_loss_dict_logger[m] = MeanMetric()

    self.exclude_ions_from_lddt = exclude_ions_from_lddt

    # Distogram
    self.num_bins = num_bins
    self.min_dist = min_dist
    self.max_dist = max_dist
    self.aggregate_distogram = aggregate_distogram

    # Trunk
    self.is_pairformer_compiled = False
    self.is_msa_compiled = False
    self.is_template_compiled = False

    # Kernels
    self.use_kernels = use_kernels

    # Input embeddings
    # embedder_args is splatted last, so its keys override the defaults above it.
    full_embedder_args = {
        "atom_s": atom_s,
        "atom_z": atom_z,
        "token_s": token_s,
        "token_z": token_z,
        "atoms_per_window_queries": atoms_per_window_queries,
        "atoms_per_window_keys": atoms_per_window_keys,
        "atom_feature_dim": atom_feature_dim,
        "use_no_atom_char": use_no_atom_char,
        "use_atom_backbone_feat": use_atom_backbone_feat,
        "use_residue_feats_atoms": use_residue_feats_atoms,
        **embedder_args,
    }
    self.input_embedder = InputEmbedder(**full_embedder_args)

    self.s_init = nn.Linear(token_s, token_s, bias=False)
    self.z_init_1 = nn.Linear(token_s, token_z, bias=False)
    self.z_init_2 = nn.Linear(token_s, token_z, bias=False)

    self.rel_pos = RelativePositionEncoder(
        token_z, fix_sym_check=fix_sym_check, cyclic_pos_enc=cyclic_pos_enc
    )

    self.token_bonds = nn.Linear(1, token_z, bias=False)
    self.bond_type_feature = bond_type_feature
    if bond_type_feature:
        self.token_bonds_type = nn.Embedding(len(const.bond_types) + 1, token_z)

    self.contact_conditioning = ContactConditioning(
        token_z=token_z,
        cutoff_min=conditioning_cutoff_min,
        cutoff_max=conditioning_cutoff_max,
    )

    # Normalization layers
    self.s_norm = nn.LayerNorm(token_s)
    self.z_norm = nn.LayerNorm(token_z)

    # Recycling projections
    self.s_recycle = nn.Linear(token_s, token_s, bias=False)
    self.z_recycle = nn.Linear(token_z, token_z, bias=False)
    init.gating_init_(self.s_recycle.weight)
    init.gating_init_(self.z_recycle.weight)

    # Set compile rules
    # Big models hit the default cache limit (8)
    # NOTE: these are process-wide torch._dynamo settings, not per-model.
    torch._dynamo.config.cache_size_limit = 512  # noqa: SLF001
    torch._dynamo.config.accumulated_cache_size_limit = 512  # noqa: SLF001

    # Pairwise stack
    self.use_templates = use_templates
    if use_templates:
        # template_args defaults to None; treat that as "no extra kwargs"
        # instead of failing on ``**None``.
        template_kwargs = template_args or {}
        if use_templates_v2:
            self.template_module = TemplateV2Module(token_z, **template_kwargs)
        else:
            self.template_module = TemplateModule(token_z, **template_kwargs)
        if compile_templates:
            self.is_template_compiled = True
            self.template_module = torch.compile(
                self.template_module,
                dynamic=False,
                fullgraph=False,
            )

    self.msa_module = MSAModule(
        token_z=token_z,
        token_s=token_s,
        **msa_args,
    )
    if compile_msa:
        self.is_msa_compiled = True
        self.msa_module = torch.compile(
            self.msa_module,
            dynamic=False,
            fullgraph=False,
        )
    self.pairformer_module = PairformerModule(token_s, token_z, **pairformer_args)
    if compile_pairformer:
        self.is_pairformer_compiled = True
        self.pairformer_module = torch.compile(
            self.pairformer_module,
            dynamic=False,
            fullgraph=False,
        )

    self.checkpoint_diffusion_conditioning = checkpoint_diffusion_conditioning
    self.diffusion_conditioning = DiffusionConditioning(
        token_s=token_s,
        token_z=token_z,
        atom_s=atom_s,
        atom_z=atom_z,
        atoms_per_window_queries=atoms_per_window_queries,
        atoms_per_window_keys=atoms_per_window_keys,
        atom_encoder_depth=score_model_args["atom_encoder_depth"],
        atom_encoder_heads=score_model_args["atom_encoder_heads"],
        token_transformer_depth=score_model_args["token_transformer_depth"],
        token_transformer_heads=score_model_args["token_transformer_heads"],
        atom_decoder_depth=score_model_args["atom_decoder_depth"],
        atom_decoder_heads=score_model_args["atom_decoder_heads"],
        atom_feature_dim=atom_feature_dim,
        conditioning_transition_layers=score_model_args[
            "conditioning_transition_layers"
        ],
        use_no_atom_char=use_no_atom_char,
        use_atom_backbone_feat=use_atom_backbone_feat,
        use_residue_feats_atoms=use_residue_feats_atoms,
    )

    # Output modules
    self.structure_module = AtomDiffusion(
        score_model_args={
            "token_s": token_s,
            "atom_s": atom_s,
            "atoms_per_window_queries": atoms_per_window_queries,
            "atoms_per_window_keys": atoms_per_window_keys,
            **score_model_args,
        },
        compile_score=compile_structure,
        **diffusion_process_args,
    )
    self.distogram_module = DistogramModule(
        token_z,
        num_bins,
    )
    self.predict_bfactor = predict_bfactor
    if predict_bfactor:
        self.bfactor_module = BFactorModule(token_s, num_bins)

    self.confidence_prediction = confidence_prediction
    self.affinity_prediction = affinity_prediction
    self.affinity_ensemble = affinity_ensemble
    self.affinity_mw_correction = affinity_mw_correction
    self.run_trunk_and_structure = run_trunk_and_structure
    self.skip_run_structure = skip_run_structure
    self.token_level_confidence = token_level_confidence
    self.alpha_pae = alpha_pae
    self.structure_prediction_training = structure_prediction_training

    if self.confidence_prediction:
        self.confidence_module = ConfidenceModule(
            token_s,
            token_z,
            token_level_confidence=token_level_confidence,
            bond_type_feature=bond_type_feature,
            fix_sym_check=fix_sym_check,
            cyclic_pos_enc=cyclic_pos_enc,
            conditioning_cutoff_min=conditioning_cutoff_min,
            conditioning_cutoff_max=conditioning_cutoff_max,
            # confidence_prediction defaults to True while the args default
            # to None, so guard the splat.
            **(confidence_model_args or {}),
        )
        if compile_confidence:
            self.confidence_module = torch.compile(
                self.confidence_module, dynamic=False, fullgraph=False
            )

    if self.affinity_prediction:
        if self.affinity_ensemble:
            self.affinity_module1 = AffinityModule(
                token_s,
                token_z,
                **(affinity_model_args1 or {}),
            )
            self.affinity_module2 = AffinityModule(
                token_s,
                token_z,
                **(affinity_model_args2 or {}),
            )
            if compile_affinity:
                self.affinity_module1 = torch.compile(
                    self.affinity_module1, dynamic=False, fullgraph=False
                )
                self.affinity_module2 = torch.compile(
                    self.affinity_module2, dynamic=False, fullgraph=False
                )
        else:
            self.affinity_module = AffinityModule(
                token_s,
                token_z,
                **(affinity_model_args or {}),
            )
            if compile_affinity:
                self.affinity_module = torch.compile(
                    self.affinity_module, dynamic=False, fullgraph=False
                )

    # Remove grad from weights they are not trained for ddp
    # NOTE(review): the ensemble heads are registered as affinity_module1 /
    # affinity_module2 and are therefore frozen here too — confirm intended.
    if not structure_prediction_training:
        for name, param in self.named_parameters():
            if (
                name.split(".")[0] not in ["confidence_module", "affinity_module"]
                and "out_token_feat_update" not in name
            ):
                param.requires_grad = False
359
+
360
+ def setup(self, stage: str) -> None:
361
+ """Set the model for training, validation and inference."""
362
+ if stage == "predict" and not (
363
+ torch.cuda.is_available()
364
+ and torch.cuda.get_device_properties(torch.device("cuda")).major >= 8.0 # noqa: PLR2004
365
+ ):
366
+ self.use_kernels = False
367
+
368
+ if (
369
+ stage != "predict"
370
+ and hasattr(self.trainer, "datamodule")
371
+ and self.trainer.datamodule
372
+ and self.validate_structure
373
+ ):
374
+ self.val_group_mapper.update(self.trainer.datamodule.val_group_mapper)
375
+
376
+ l1 = len(self.val_group_mapper)
377
+ l2 = self.num_val_datasets
378
+ msg = (
379
+ f"Number of validation datasets num_val_datasets={l2} "
380
+ f"does not match the number of val_group_mapper entries={l1}."
381
+ )
382
+ assert l1 == l2, msg
383
+
384
+ # Map an index to a validator, and double check val names
385
+ # match from datamodule
386
+ all_validator_names = []
387
+ for validator in self.validators:
388
+ for val_name in validator.val_names:
389
+ msg = f"Validator {val_name} duplicated in validators."
390
+ assert val_name not in all_validator_names, msg
391
+ all_validator_names.append(val_name)
392
+ for val_idx, val_group in self.val_group_mapper.items():
393
+ if val_name == val_group["label"]:
394
+ self.validator_mapper[val_idx] = validator
395
+
396
+ msg = "Mismatch between validator names and val_group_mapper values."
397
+ assert set(all_validator_names) == {
398
+ x["label"] for x in self.val_group_mapper.values()
399
+ }, msg
400
+
401
def forward(
    self,
    feats: dict[str, Tensor],
    recycling_steps: int = 0,
    num_sampling_steps: Optional[int] = None,
    multiplicity_diffusion_train: int = 1,
    diffusion_samples: int = 1,
    max_parallel_samples: Optional[int] = None,
    run_confidence_sequentially: bool = False,
) -> dict[str, Tensor]:
    """Run the trunk and the structure, confidence and affinity heads.

    Parameters
    ----------
    feats
        Input feature dictionary. During training ``feats["coords"]`` is
        reshaped in place (the conformer axis is folded away).
    recycling_steps
        Number of extra trunk passes; the trunk runs ``recycling_steps + 1``
        times and gradients are only enabled on the last pass.
    num_sampling_steps, max_parallel_samples
        Forwarded to ``structure_module.sample``.
    multiplicity_diffusion_train
        Multiplicity passed to the structure module's training forward.
    diffusion_samples
        Multiplicity used for sampling and for the confidence head.
    run_confidence_sequentially
        Forwarded to the confidence module as ``run_sequentially``.

    Returns
    -------
    dict
        Always contains ``"pdistogram"``; the outputs of whichever of the
        sampling, training, B-factor, confidence and affinity branches ran
        are merged into it.
    """
    with torch.set_grad_enabled(
        self.training and self.structure_prediction_training
    ):
        s_inputs = self.input_embedder(feats)

        # Initialize the sequence embeddings
        s_init = self.s_init(s_inputs)

        # Initialize pairwise embeddings
        z_init = (
            self.z_init_1(s_inputs)[:, :, None]
            + self.z_init_2(s_inputs)[:, None, :]
        )
        relative_position_encoding = self.rel_pos(feats)
        z_init = z_init + relative_position_encoding
        z_init = z_init + self.token_bonds(feats["token_bonds"].float())
        if self.bond_type_feature:
            z_init = z_init + self.token_bonds_type(feats["type_bonds"].long())
        z_init = z_init + self.contact_conditioning(feats)

        # Perform rounds of the pairwise stack
        s = torch.zeros_like(s_init)
        z = torch.zeros_like(z_init)

        # Compute pairwise mask
        mask = feats["token_pad_mask"].float()
        pair_mask = mask[:, :, None] * mask[:, None, :]
        if self.run_trunk_and_structure:
            for i in range(recycling_steps + 1):
                with torch.set_grad_enabled(
                    self.training
                    and self.structure_prediction_training
                    and (i == recycling_steps)
                ):
                    # Issue with unused parameters in autocast
                    if (
                        self.training
                        and (i == recycling_steps)
                        and torch.is_autocast_enabled()
                    ):
                        torch.clear_autocast_cache()

                    # Apply recycling
                    s = s_init + self.s_recycle(self.s_norm(s))
                    z = z_init + self.z_recycle(self.z_norm(z))

                    # Compute pairwise stack
                    if self.use_templates:
                        if self.is_template_compiled and not self.training:
                            template_module = self.template_module._orig_mod  # noqa: SLF001
                        else:
                            template_module = self.template_module

                        z = z + template_module(
                            z, feats, pair_mask, use_kernels=self.use_kernels
                        )

                    if self.is_msa_compiled and not self.training:
                        msa_module = self.msa_module._orig_mod  # noqa: SLF001
                    else:
                        msa_module = self.msa_module

                    z = z + msa_module(
                        z, s_inputs, feats, use_kernels=self.use_kernels
                    )

                    # Revert to uncompiled version for validation
                    if self.is_pairformer_compiled and not self.training:
                        pairformer_module = self.pairformer_module._orig_mod  # noqa: SLF001
                    else:
                        pairformer_module = self.pairformer_module

                    s, z = pairformer_module(
                        s,
                        z,
                        mask=mask,
                        pair_mask=pair_mask,
                        use_kernels=self.use_kernels,
                    )

        pdistogram = self.distogram_module(z)
        dict_out = {"pdistogram": pdistogram}

        run_sampling = (
            self.run_trunk_and_structure
            and ((not self.training) or self.confidence_prediction)
            and (not self.skip_run_structure)
        )
        run_structure_training = (
            self.training and self.structure_prediction_training
        )

        # The conditioning is consumed by both the sampling branch and the
        # structure-training branch, so it is built whenever either runs.
        # Previously it was only bound inside the sampling branch, leaving
        # it undefined when training the structure module without sampling.
        diffusion_conditioning = None
        if run_sampling or run_structure_training:
            if self.checkpoint_diffusion_conditioning and self.training:
                # TODO decide whether this should be with bf16 or not
                q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias = (
                    torch.utils.checkpoint.checkpoint(
                        self.diffusion_conditioning,
                        s,
                        z,
                        relative_position_encoding,
                        feats,
                    )
                )
            else:
                q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias = (
                    self.diffusion_conditioning(
                        s_trunk=s,
                        z_trunk=z,
                        relative_position_encoding=relative_position_encoding,
                        feats=feats,
                    )
                )
            diffusion_conditioning = {
                "q": q,
                "c": c,
                "to_keys": to_keys,
                "atom_enc_bias": atom_enc_bias,
                "atom_dec_bias": atom_dec_bias,
                "token_trans_bias": token_trans_bias,
            }

        if run_sampling:
            # Sampling runs with CUDA autocast disabled on float32 inputs.
            with torch.autocast("cuda", enabled=False):
                struct_out = self.structure_module.sample(
                    s_trunk=s.float(),
                    s_inputs=s_inputs.float(),
                    feats=feats,
                    num_sampling_steps=num_sampling_steps,
                    atom_mask=feats["atom_pad_mask"].float(),
                    multiplicity=diffusion_samples,
                    max_parallel_samples=max_parallel_samples,
                    steering_args=self.steering_args,
                    diffusion_conditioning=diffusion_conditioning,
                )
                dict_out.update(struct_out)

            if self.predict_bfactor:
                pbfactor = self.bfactor_module(s)
                dict_out["pbfactor"] = pbfactor

        if self.training and self.confidence_prediction:
            assert len(feats["coords"].shape) == 4
            assert feats["coords"].shape[1] == 1, (
                "Only one conformation is supported for confidence"
            )

        # Compute structure module
        if run_structure_training:
            atom_coords = feats["coords"]
            B, K, L = atom_coords.shape[0:3]
            assert K in (
                multiplicity_diffusion_train,
                1,
            )  # TODO make check somewhere else, expand to m % N == 0, m > N
            atom_coords = atom_coords.reshape(B * K, L, 3)
            atom_coords = atom_coords.repeat_interleave(
                multiplicity_diffusion_train // K, 0
            )
            feats["coords"] = atom_coords  # (multiplicity, L, 3)
            assert len(feats["coords"].shape) == 3

            with torch.autocast("cuda", enabled=False):
                struct_out = self.structure_module(
                    s_trunk=s.float(),
                    s_inputs=s_inputs.float(),
                    feats=feats,
                    multiplicity=multiplicity_diffusion_train,
                    diffusion_conditioning=diffusion_conditioning,
                )
                dict_out.update(struct_out)

        elif self.training:
            feats["coords"] = feats["coords"].squeeze(1)
            assert len(feats["coords"].shape) == 3

    # The heads below run outside the trunk's grad context and only ever
    # receive detached trunk activations.
    if self.confidence_prediction:
        dict_out.update(
            self.confidence_module(
                s_inputs=s_inputs.detach(),
                s=s.detach(),
                z=z.detach(),
                x_pred=(
                    dict_out["sample_atom_coords"].detach()
                    if not self.skip_run_structure
                    else feats["coords"].repeat_interleave(diffusion_samples, 0)
                ),
                feats=feats,
                pred_distogram_logits=(
                    dict_out["pdistogram"][
                        :, :, :, 0
                    ].detach()  # TODO only implemented for 1 distogram
                ),
                multiplicity=diffusion_samples,
                run_sequentially=run_confidence_sequentially,
                use_kernels=self.use_kernels,
            )
        )

    if self.affinity_prediction:
        # Only the first batch element is used to build the masks.
        pad_token_mask = feats["token_pad_mask"][0]
        # NOTE(review): mol_type == 0 is presumably the receptor/protein
        # type id — confirm against boltz.data.const.
        rec_mask = feats["mol_type"][0] == 0
        rec_mask = rec_mask * pad_token_mask
        lig_mask = feats["affinity_token_mask"][0].to(torch.bool)
        lig_mask = lig_mask * pad_token_mask
        # Keep ligand-receptor, receptor-ligand and ligand-ligand pairs.
        cross_pair_mask = (
            lig_mask[:, None] * rec_mask[None, :]
            + rec_mask[:, None] * lig_mask[None, :]
            + lig_mask[:, None] * lig_mask[None, :]
        )
        z_affinity = z * cross_pair_mask[None, :, :, None]

        # Use the sampled structure with the highest ipTM; this requires the
        # confidence and sampling outputs to be present in dict_out.
        argsort = torch.argsort(dict_out["iptm"], descending=True)
        best_idx = argsort[0].item()
        coords_affinity = dict_out["sample_atom_coords"].detach()[best_idx][
            None, None
        ]
        s_inputs = self.input_embedder(feats, affinity=True)

        with torch.autocast("cuda", enabled=False):
            if self.affinity_ensemble:
                dict_out_affinity1 = self.affinity_module1(
                    s_inputs=s_inputs.detach(),
                    z=z_affinity.detach(),
                    x_pred=coords_affinity,
                    feats=feats,
                    multiplicity=1,
                    use_kernels=self.use_kernels,
                )

                dict_out_affinity1["affinity_probability_binary"] = (
                    torch.nn.functional.sigmoid(
                        dict_out_affinity1["affinity_logits_binary"]
                    )
                )
                dict_out_affinity2 = self.affinity_module2(
                    s_inputs=s_inputs.detach(),
                    z=z_affinity.detach(),
                    x_pred=coords_affinity,
                    feats=feats,
                    multiplicity=1,
                    use_kernels=self.use_kernels,
                )
                dict_out_affinity2["affinity_probability_binary"] = (
                    torch.nn.functional.sigmoid(
                        dict_out_affinity2["affinity_logits_binary"]
                    )
                )

                # Ensemble = plain average of the two heads.
                dict_out_affinity_ensemble = {
                    "affinity_pred_value": (
                        dict_out_affinity1["affinity_pred_value"]
                        + dict_out_affinity2["affinity_pred_value"]
                    )
                    / 2,
                    "affinity_probability_binary": (
                        dict_out_affinity1["affinity_probability_binary"]
                        + dict_out_affinity2["affinity_probability_binary"]
                    )
                    / 2,
                }

                dict_out_affinity1 = {
                    "affinity_pred_value1": dict_out_affinity1[
                        "affinity_pred_value"
                    ],
                    "affinity_probability_binary1": dict_out_affinity1[
                        "affinity_probability_binary"
                    ],
                }
                dict_out_affinity2 = {
                    "affinity_pred_value2": dict_out_affinity2[
                        "affinity_pred_value"
                    ],
                    "affinity_probability_binary2": dict_out_affinity2[
                        "affinity_probability_binary"
                    ],
                }
                if self.affinity_mw_correction:
                    # Linear correction of the ensemble value using
                    # affinity_mw ** 0.3; coefficients are fixed constants.
                    model_coef = 1.03525938
                    mw_coef = -0.59992683
                    bias = 2.83288489
                    mw = feats["affinity_mw"][0] ** 0.3
                    dict_out_affinity_ensemble["affinity_pred_value"] = (
                        model_coef
                        * dict_out_affinity_ensemble["affinity_pred_value"]
                        + mw_coef * mw
                        + bias
                    )

                dict_out.update(dict_out_affinity_ensemble)
                dict_out.update(dict_out_affinity1)
                dict_out.update(dict_out_affinity2)
            else:
                dict_out_affinity = self.affinity_module(
                    s_inputs=s_inputs.detach(),
                    z=z_affinity.detach(),
                    x_pred=coords_affinity,
                    feats=feats,
                    multiplicity=1,
                    use_kernels=self.use_kernels,
                )
                dict_out.update(
                    {
                        "affinity_pred_value": dict_out_affinity[
                            "affinity_pred_value"
                        ],
                        "affinity_probability_binary": torch.nn.functional.sigmoid(
                            dict_out_affinity["affinity_logits_binary"]
                        ),
                    }
                )

    return dict_out
719
+
720
def get_true_coordinates(
    self,
    batch: dict[str, Tensor],
    out: dict[str, Tensor],
    diffusion_samples: int,
    symmetry_correction: bool,
    expand_to_diffusion_samples: bool = True,
) -> dict[str, Any]:
    """Return the ground-truth coordinates to compare predictions against.

    Parameters
    ----------
    batch
        Feature dictionary; ``batch["coords"]`` must have a leading batch
        dimension of exactly 1.
    out
        Model outputs. ``out["sample_atom_coords"]`` is only read when
        ``symmetry_correction`` is true.
    diffusion_samples
        Number of predicted samples per input.
    symmetry_correction
        When true, ``minimum_lddt_symmetry_coords`` is called once per
        predicted sample to select the reference coordinates and resolved
        mask for that sample.
    expand_to_diffusion_samples
        When true (required for symmetry correction) the unsymmetrised
        reference is tiled to one copy per diffusion sample; otherwise the
        batch dimension is simply squeezed away.

    Returns
    -------
    dict
        ``"true_coords"``, ``"true_coords_resolved_mask"`` and the
        placeholder metrics ``"rmsds"``, ``"best_rmsd_recall"`` and
        ``"best_rmsd_precision"`` (all ``0``). Both branches return the
        same set of keys.

    Raises
    ------
    AssertionError
        If symmetry correction is requested without expansion, or the
        batch size is not 1.
    """
    if symmetry_correction:
        msg = "expand_to_diffusion_samples must be true for symmetry correction."
        assert expand_to_diffusion_samples, msg

    assert batch["coords"].shape[0] == 1, (
        f"Validation is not supported for batch sizes={batch['coords'].shape[0]}"
    )

    if symmetry_correction:
        per_sample_coords = []
        per_sample_masks = []
        for idx in range(batch["token_index"].shape[0]):
            for rep in range(diffusion_samples):
                # Samples are laid out contiguously per batch element.
                i = idx * diffusion_samples + rep
                best_true_coords, best_true_coords_resolved_mask = (
                    minimum_lddt_symmetry_coords(
                        coords=out["sample_atom_coords"][i : i + 1],
                        feats=batch,
                        index_batch=idx,
                    )
                )
                per_sample_coords.append(best_true_coords)
                per_sample_masks.append(best_true_coords_resolved_mask)

        # Insert a singleton axis at position 1 after concatenating samples.
        true_coords = torch.cat(per_sample_coords, dim=0).unsqueeze(1)
        true_coords_resolved_mask = torch.cat(per_sample_masks, dim=0)
    else:
        K, L = batch["coords"].shape[1:3]

        true_coords_resolved_mask = batch["atom_resolved_mask"]
        true_coords = batch["coords"].squeeze(0)
        if expand_to_diffusion_samples:
            true_coords = true_coords.repeat((diffusion_samples, 1, 1)).reshape(
                diffusion_samples, K, L, 3
            )

            # since all masks are the same across conformers and diffusion
            # samples, can just repeat S times
            true_coords_resolved_mask = true_coords_resolved_mask.repeat_interleave(
                diffusion_samples, dim=0
            )
        else:
            true_coords_resolved_mask = true_coords_resolved_mask.squeeze(0)

    return {
        "true_coords": true_coords,
        "true_coords_resolved_mask": true_coords_resolved_mask,
        "rmsds": 0,
        "best_rmsd_recall": 0,
        "best_rmsd_precision": 0,
    }
788
+
789
def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor:
    """Run one training iteration and return the weighted total loss.

    Parameters
    ----------
    batch : dict[str, Tensor]
        Features of the current batch. When confidence prediction is
        enabled, ``frames_idx`` and ``frame_resolved_mask`` are squeezed
        along dim 1 and written back into this dict.
    batch_idx : int
        Index of the batch; only used in the skip message.

    Returns
    -------
    Tensor
        Weighted sum of the confidence, diffusion, distogram and b-factor
        losses. ``None`` is returned instead when computing the diffusion
        loss or the true coordinates raised, so the batch is skipped.

    """
    args = self.training_args

    # Number of recycling steps: fixed, or drawn from an RNG seeded with
    # the global step so that the draw is reproducible per step.
    if not self.no_random_recycling_training:
        recycle_rng = np.random.default_rng(self.global_step)
        recycling_steps = recycle_rng.integers(0, args.recycling_steps + 1).item()
    else:
        recycling_steps = args.recycling_steps

    # Number of sampling steps: fixed, or picked from the configured choices.
    if args.get("sampling_steps_random", None) is None:
        sampling_steps = args.sampling_steps
    else:
        sampling_rng = np.random.default_rng(self.global_step)
        sampling_steps = sampling_rng.choice(args.sampling_steps_random)

    # Forward pass
    out = self(
        feats=batch,
        recycling_steps=recycling_steps,
        num_sampling_steps=sampling_steps,
        multiplicity_diffusion_train=args.diffusion_multiplicity,
        diffusion_samples=args.diffusion_samples,
    )

    # Structure losses
    if not self.structure_prediction_training:
        disto_loss = 0.0
        bfactor_loss = 0.0
        diffusion_loss_dict = {"loss": 0.0, "loss_breakdown": {}}
    else:
        disto_loss, _ = distogram_loss(
            out,
            batch,
            aggregate_distogram=self.aggregate_distogram,
        )
        try:
            diffusion_loss_dict = self.structure_module.compute_loss(
                batch,
                out,
                multiplicity=args.diffusion_multiplicity,
                **self.diffusion_loss_args,
            )
        except Exception as e:
            # Best effort: a failing batch is reported and skipped.
            print(f"Skipping batch {batch_idx} due to error: {e}")
            return None

        bfactor_loss = bfactor_loss_fn(out, batch) if self.predict_bfactor else 0.0

    # Confidence loss
    if not self.confidence_prediction:
        confidence_loss_dict = {
            "loss": torch.tensor(0.0, device=batch["token_index"].device),
            "loss_breakdown": {},
        }
    else:
        try:
            # confidence model symmetry correction
            return_dict = self.get_true_coordinates(
                batch,
                out,
                diffusion_samples=args.diffusion_samples,
                symmetry_correction=args.symmetry_correction,
            )
        except Exception as e:
            print(f"Skipping batch with id {batch['pdb_id']} due to error: {e}")
            return None

        true_coords = return_dict["true_coords"]
        true_coords_resolved_mask = return_dict["true_coords_resolved_mask"]

        # TODO remove once multiple conformers are supported
        K = true_coords.shape[1]
        assert K == 1, (
            f"Confidence_prediction is not supported for num_ensembles_val={K}."
        )

        # For now, just take the only conformer: drop the conformer dimension.
        true_coords = true_coords.squeeze(1)  # (S, L, 3)
        for conformer_key in ("frames_idx", "frame_resolved_mask"):
            batch[conformer_key] = batch[conformer_key].squeeze(1)

        confidence_loss_dict = confidence_loss(
            out,
            batch,
            true_coords,
            true_coords_resolved_mask,
            token_level_confidence=self.token_level_confidence,
            alpha_pae=self.alpha_pae,
            multiplicity=args.diffusion_samples,
        )

    # Aggregate losses
    # NOTE: we already have an implicit weight in the losses induced by dataset sampling
    # NOTE: this logic works only for datasets with confidence labels
    loss = (
        args.confidence_loss_weight * confidence_loss_dict["loss"]
        + args.diffusion_loss_weight * diffusion_loss_dict["loss"]
        + args.distogram_loss_weight * disto_loss
        + args.get("bfactor_loss_weight", 0.0) * bfactor_loss
    )

    # Periodic logging of the individual losses
    if self.global_step % self.log_loss_every_steps == 0:
        if self.validate_structure:
            self.log("train/distogram_loss", disto_loss)
            self.log("train/diffusion_loss", diffusion_loss_dict["loss"])
            for k, v in diffusion_loss_dict["loss_breakdown"].items():
                self.log(f"train/{k}", v)

        if self.confidence_prediction:
            self.train_confidence_loss_logger.update(
                confidence_loss_dict["loss"].detach()
            )
            breakdown = confidence_loss_dict["loss_breakdown"]
            for k in self.train_confidence_loss_dict_logger:
                term = breakdown[k]
                if torch.is_tensor(term):
                    term = term.detach()
                self.train_confidence_loss_dict_logger[k].update(term)

    # NOTE(review): assumed to run on every step, outside the periodic
    # logging guard above - confirm against the unflattened file.
    self.log("train/loss", loss)
    self.training_log()
    return loss
926
+
927
def training_log(self):
    """Log gradient/parameter norms and the current learning rate.

    Norms are logged for the whole model, then parameter norms for the MSA,
    pairformer and structure modules, and finally (only when confidence
    prediction is enabled) both norms of the confidence module.
    """
    quiet = {"prog_bar": False}

    self.log("train/grad_norm", self.gradient_norm(self), **quiet)
    self.log("train/param_norm", self.parameter_norm(self), **quiet)

    # Learning rate of the first parameter group of the first optimizer.
    first_group = self.trainer.optimizers[0].param_groups[0]
    self.log("lr", first_group["lr"], **quiet)

    for submodule in ("msa_module", "pairformer_module", "structure_module"):
        self.log(
            f"train/param_norm_{submodule}",
            self.parameter_norm(getattr(self, submodule)),
            **quiet,
        )

    if not self.confidence_prediction:
        return

    confidence = self.confidence_module
    self.log(
        "train/grad_norm_confidence_module",
        self.gradient_norm(confidence),
        **quiet,
    )
    self.log(
        "train/param_norm_confidence_module",
        self.parameter_norm(self.confidence_module),
        **quiet,
    )
963
+
964
def on_train_epoch_end(self):
    """Log the accumulated confidence losses once per epoch.

    Nothing is logged when confidence prediction is disabled.
    """
    if not self.confidence_prediction:
        return

    # Epoch-level only: never on the progress bar, never per step.
    epoch_only = {"prog_bar": False, "on_step": False, "on_epoch": True}

    self.log(
        "train/confidence_loss",
        self.train_confidence_loss_logger,
        **epoch_only,
    )
    for name, tracker in self.train_confidence_loss_dict_logger.items():
        self.log(f"train/{name}", tracker, **epoch_only)
975
+
976
def gradient_norm(self, module):
    """Return the global L2 norm of the gradients of ``module``.

    Only parameters that require a gradient and currently hold one are
    taken into account. When there is no such parameter a scalar zero is
    returned, placed on CUDA if it is available and on the CPU otherwise.
    """
    squared_norms = []
    for param in module.parameters():
        if not param.requires_grad or param.grad is None:
            continue
        squared_norms.append(param.grad.norm(p=2) ** 2)

    if not squared_norms:
        fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.tensor(0.0, device=fallback_device)

    return torch.stack(squared_norms).sum().sqrt()
988
+
989
def parameter_norm(self, module):
    """Return the global L2 norm of the trainable parameters of ``module``.

    Parameters that do not require a gradient are ignored. When no
    parameter qualifies a scalar zero is returned, placed on CUDA if it is
    available and on the CPU otherwise.
    """
    squared_norms = []
    for param in module.parameters():
        if param.requires_grad:
            squared_norms.append(param.norm(p=2) ** 2)

    if not squared_norms:
        fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.tensor(0.0, device=fallback_device)

    return torch.stack(squared_norms).sum().sqrt()
997
+
998
def validation_step(self, batch: dict[str, Tensor], batch_idx: int):
    """Run one validation step.

    With structure validation enabled, the validator registered for the
    batch's dataset runs the model and processes its output. Otherwise
    only a plain forward pass with the validation arguments is executed.

    A ``RuntimeError`` whose message contains "out of memory" is reported,
    the CUDA cache is emptied and the batch is skipped; any other
    ``RuntimeError`` is re-raised.
    """
    use_validators = self.validate_structure
    try:
        if use_validators:
            msg = "Only batch=1 is supported for validation"
            assert batch["idx_dataset"].shape[0] == 1, msg

            # Select validator based on dataset
            idx_dataset = batch["idx_dataset"][0].item()
            validator = self.validator_mapper[idx_dataset]

            # Forward pass, then the validator's own bookkeeping
            out = validator.run_model(
                model=self, batch=batch, idx_dataset=idx_dataset
            )
            validator.process(
                model=self, batch=batch, out=out, idx_dataset=idx_dataset
            )
        else:
            val_args = self.validation_args
            self(
                batch,
                recycling_steps=val_args.recycling_steps,
                num_sampling_steps=val_args.sampling_steps,
                diffusion_samples=val_args.diffusion_samples,
                run_confidence_sequentially=val_args.get(
                    "run_confidence_sequentially", False
                ),
            )
    except RuntimeError as e:  # catch out of memory exceptions
        idx_dataset = batch["idx_dataset"][0].item()
        if "out of memory" not in str(e):
            raise e
        msg = f"| WARNING: ran out of memory, skipping batch, {idx_dataset}"
        print(msg)
        torch.cuda.empty_cache()
        gc.collect()
        return
1045
+
1046
def on_validation_epoch_end(self):
    """Aggregate all metrics for each validator."""
    if not self.validate_structure:
        return

    # Each validator aggregates, computes and logs its own metrics.
    for dataset_validator in self.validator_mapper.values():
        dataset_validator.on_epoch_end(model=self)
1052
+
1053
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> dict:
    """Run inference on one batch and collect the requested outputs.

    Returns
    -------
    dict
        ``{"exception": False, ...}`` with the padding masks, the sampled
        coordinates and, depending on the enabled heads, the confidence and
        affinity outputs. ``{"exception": True}`` is returned when a
        ``RuntimeError`` mentioning "out of memory" was caught; any other
        ``RuntimeError`` is re-raised.

    """
    try:
        settings = self.predict_args
        out = self(
            batch,
            recycling_steps=settings["recycling_steps"],
            num_sampling_steps=settings["sampling_steps"],
            diffusion_samples=settings["diffusion_samples"],
            max_parallel_samples=settings["max_parallel_samples"],
            run_confidence_sequentially=True,
        )

        pred_dict = {"exception": False}

        # Optional pass-through of selected batch entries
        if "keys_dict_batch" in self.predict_args:
            for key in self.predict_args["keys_dict_batch"]:
                pred_dict[key] = batch[key]

        pred_dict["masks"] = batch["atom_pad_mask"]
        pred_dict["token_masks"] = batch["token_pad_mask"]

        # Optional pass-through of selected model outputs
        if "keys_dict_out" in self.predict_args:
            for key in self.predict_args["keys_dict_out"]:
                pred_dict[key] = out[key]
        pred_dict["coords"] = out["sample_atom_coords"]

        if self.confidence_prediction:
            # pred_dict["confidence"] = out.get("ablation_confidence", None)
            pred_dict["pde"] = out["pde"]
            pred_dict["plddt"] = out["plddt"]

            # 4:1 blend of complex pLDDT with ipTM; pTM is used instead
            # when ipTM is all zeros.
            weighted_plddt = 4 * out["complex_plddt"]
            if torch.allclose(out["iptm"], torch.zeros_like(out["iptm"])):
                tm_term = out["ptm"]
            else:
                tm_term = out["iptm"]
            pred_dict["confidence_score"] = (weighted_plddt + tm_term) / 5

            for key in (
                "complex_plddt",
                "complex_iplddt",
                "complex_pde",
                "complex_ipde",
            ):
                pred_dict[key] = out[key]

            # NOTE(review): assumes every PAE-derived score below is gated
            # on alpha_pae - confirm against the unflattened file.
            if self.alpha_pae > 0:
                for key in (
                    "pae",
                    "ptm",
                    "iptm",
                    "ligand_iptm",
                    "protein_iptm",
                    "pair_chains_iptm",
                ):
                    pred_dict[key] = out[key]

        if self.affinity_prediction:
            for key in ("affinity_pred_value", "affinity_probability_binary"):
                pred_dict[key] = out[key]
            if self.affinity_ensemble:
                for key in (
                    "affinity_pred_value1",
                    "affinity_probability_binary1",
                    "affinity_pred_value2",
                    "affinity_probability_binary2",
                ):
                    pred_dict[key] = out[key]

        return pred_dict

    except RuntimeError as e:  # catch out of memory exceptions
        if "out of memory" not in str(e):
            raise e
        print("| WARNING: ran out of memory, skipping batch")
        torch.cuda.empty_cache()
        gc.collect()
        return {"exception": True}
1125
+
1126
def configure_optimizers(self) -> torch.optim.Optimizer:
    """Configure the optimizer.

    Only parameters that require a gradient are optimised. When structure
    prediction training is disabled, the selection is further restricted
    to parameters whose name contains ``out_token_feat_update`` or
    ``confidence_module``.

    With a positive ``weight_decay`` and ``weight_decay_exclude`` set, the
    parameters are split into two groups: one decayed with ``weight_decay``
    and one (names matching the exclusion markers below) with no decay.

    Returns
    -------
    torch.optim.Optimizer
        The AdamW optimizer. When ``lr_scheduler`` is ``"af3"`` a pair
        ``([optimizer], [scheduler_config])`` stepping the scheduler every
        optimisation step is returned instead.

    """
    confidence_only_markers = ("out_token_feat_update", "confidence_module")
    no_decay_markers = (
        "norm",
        "rel_pos",
        ".s_init",
        ".z_init_",
        "token_bonds",
        "embed_atom_features",
        "dist_bin_pairwise_embed",
    )

    # Single walk over the parameters; the dict keeps registration order.
    param_dict = dict(self.named_parameters())
    train_everything = self.structure_prediction_training
    all_parameter_names = [
        pn
        for pn, p in param_dict.items()
        if p.requires_grad
        and (
            train_everything
            or any(marker in pn for marker in confidence_only_markers)
        )
    ]

    adamw_kwargs = {
        "betas": (self.training_args.adam_beta_1, self.training_args.adam_beta_2),
        "eps": self.training_args.adam_eps,
        "lr": self.training_args.base_lr,
    }

    w_decay = self.training_args.get("weight_decay", 0.0)
    if w_decay > 0 and self.training_args.get("weight_decay_exclude", False):
        # Classify each parameter once instead of testing membership in a
        # list of names, which was quadratic in the number of parameters.
        decay_params = []
        nodecay_params = []
        for pn in all_parameter_names:
            excluded = any(marker in pn for marker in no_decay_markers)
            (nodecay_params if excluded else decay_params).append(param_dict[pn])
        params = [
            {"params": decay_params, "weight_decay": w_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        # Every group carries its own weight decay, so no optimizer-level
        # value is passed here (as before).
    else:
        params = [param_dict[pn] for pn in all_parameter_names]
        adamw_kwargs["weight_decay"] = w_decay

    optimizer = torch.optim.AdamW(params, **adamw_kwargs)

    if self.training_args.lr_scheduler == "af3":
        scheduler = AlphaFoldLRScheduler(
            optimizer,
            base_lr=self.training_args.base_lr,
            max_lr=self.training_args.max_lr,
            warmup_no_steps=self.training_args.lr_warmup_no_steps,
            start_decay_after_n_steps=self.training_args.lr_start_decay_after_n_steps,
            decay_every_n_steps=self.training_args.lr_decay_every_n_steps,
            decay_factor=self.training_args.lr_decay_factor,
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    return optimizer
1211
+
1212
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Override checkpointed optimisation settings with the current ones.

    The checkpoint is modified in place:

    - every optimizer parameter group gets the configured ``max_lr`` and
      weight decay;
    - every lr-scheduler state gets ``max_lr`` as ``max_lr``, ``base_lrs``
      and ``_last_lr``;
    - the stored ``training_args`` hyper-parameters get the current
      ``max_lr``, ``diffusion_multiplicity``, ``recycling_steps`` and
      weight decay.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The loaded checkpoint dictionary.

    """
    # Ignore the lr from the checkpoint
    lr = self.training_args.max_lr
    # ``weight_decay`` is optional when the optimizer is configured
    # (``.get("weight_decay", 0.0)``), so it is read the same way here;
    # a config without the key must not break resuming.
    weight_decay = self.training_args.get("weight_decay", 0.0)

    if "optimizer_states" in checkpoint:
        for state in checkpoint["optimizer_states"]:
            for group in state["param_groups"]:
                group["lr"] = lr
                group["weight_decay"] = weight_decay

    if "lr_schedulers" in checkpoint:
        for scheduler in checkpoint["lr_schedulers"]:
            scheduler["max_lr"] = lr
            scheduler["base_lrs"] = [lr] * len(scheduler["base_lrs"])
            scheduler["_last_lr"] = [lr] * len(scheduler["_last_lr"])

    # Ignore the training diffusion_multiplicity and recycling steps from the checkpoint
    if "hyper_parameters" in checkpoint:
        stored_args = checkpoint["hyper_parameters"]["training_args"]
        stored_args["max_lr"] = lr
        stored_args["diffusion_multiplicity"] = (
            self.training_args.diffusion_multiplicity
        )
        stored_args["recycling_steps"] = self.training_args.recycling_steps
        stored_args["weight_decay"] = weight_decay
1239
+
1240
def configure_callbacks(self) -> list[Callback]:
    """Configure model callbacks.

    Returns
    -------
    List[Callback]
        A single EMA callback built with ``self.ema_decay`` when
        ``self.use_ema`` is set, otherwise an empty list.

    """
    if not self.use_ema:
        return []
    return [EMA(self.ema_decay)]