boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1249 @@
|
|
1
|
+
import gc
|
2
|
+
from typing import Any, Optional
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import torch
|
6
|
+
import torch._dynamo
|
7
|
+
from pytorch_lightning import Callback, LightningModule
|
8
|
+
from torch import Tensor, nn
|
9
|
+
from torchmetrics import MeanMetric
|
10
|
+
|
11
|
+
import boltz.model.layers.initialize as init
|
12
|
+
from boltz.data import const
|
13
|
+
from boltz.data.mol import (
|
14
|
+
minimum_lddt_symmetry_coords,
|
15
|
+
)
|
16
|
+
from boltz.model.layers.pairformer import PairformerModule
|
17
|
+
from boltz.model.loss.bfactor import bfactor_loss_fn
|
18
|
+
from boltz.model.loss.confidencev2 import (
|
19
|
+
confidence_loss,
|
20
|
+
)
|
21
|
+
from boltz.model.loss.distogramv2 import distogram_loss
|
22
|
+
from boltz.model.modules.affinity import AffinityModule
|
23
|
+
from boltz.model.modules.confidencev2 import ConfidenceModule
|
24
|
+
from boltz.model.modules.diffusion_conditioning import DiffusionConditioning
|
25
|
+
from boltz.model.modules.diffusionv2 import AtomDiffusion
|
26
|
+
from boltz.model.modules.encodersv2 import RelativePositionEncoder
|
27
|
+
from boltz.model.modules.trunkv2 import (
|
28
|
+
BFactorModule,
|
29
|
+
ContactConditioning,
|
30
|
+
DistogramModule,
|
31
|
+
InputEmbedder,
|
32
|
+
MSAModule,
|
33
|
+
TemplateModule,
|
34
|
+
TemplateV2Module,
|
35
|
+
)
|
36
|
+
from boltz.model.optim.ema import EMA
|
37
|
+
from boltz.model.optim.scheduler import AlphaFoldLRScheduler
|
38
|
+
|
39
|
+
|
40
|
+
class Boltz2(LightningModule):
|
41
|
+
"""Boltz2 model."""
|
42
|
+
|
43
|
+
def __init__(
    self,
    atom_s: int,
    atom_z: int,
    token_s: int,
    token_z: int,
    num_bins: int,
    training_args: dict[str, Any],
    validation_args: dict[str, Any],
    embedder_args: dict[str, Any],
    msa_args: dict[str, Any],
    pairformer_args: dict[str, Any],
    score_model_args: dict[str, Any],
    diffusion_process_args: dict[str, Any],
    diffusion_loss_args: dict[str, Any],
    confidence_model_args: Optional[dict[str, Any]] = None,
    affinity_model_args: Optional[dict[str, Any]] = None,
    affinity_model_args1: Optional[dict[str, Any]] = None,
    affinity_model_args2: Optional[dict[str, Any]] = None,
    validators: Any = None,
    num_val_datasets: int = 1,
    atom_feature_dim: int = 128,
    template_args: Optional[dict] = None,
    confidence_prediction: bool = True,
    affinity_prediction: bool = False,
    affinity_ensemble: bool = False,
    affinity_mw_correction: bool = True,
    run_trunk_and_structure: bool = True,
    skip_run_structure: bool = False,
    token_level_confidence: bool = True,
    alpha_pae: float = 0.0,
    structure_prediction_training: bool = True,
    validate_structure: bool = True,
    atoms_per_window_queries: int = 32,
    atoms_per_window_keys: int = 128,
    compile_pairformer: bool = False,
    compile_structure: bool = False,
    compile_confidence: bool = False,
    compile_affinity: bool = False,
    compile_msa: bool = False,
    exclude_ions_from_lddt: bool = False,
    ema: bool = False,
    ema_decay: float = 0.999,
    min_dist: float = 2.0,
    max_dist: float = 22.0,
    predict_args: Optional[dict[str, Any]] = None,
    fix_sym_check: bool = False,
    cyclic_pos_enc: bool = False,
    aggregate_distogram: bool = True,
    bond_type_feature: bool = False,
    use_no_atom_char: bool = False,
    no_random_recycling_training: bool = False,
    use_atom_backbone_feat: bool = False,
    use_residue_feats_atoms: bool = False,
    conditioning_cutoff_min: float = 4.0,
    conditioning_cutoff_max: float = 20.0,
    steering_args: Optional[dict] = None,
    use_templates: bool = False,
    compile_templates: bool = False,
    predict_bfactor: bool = False,
    log_loss_every_steps: int = 50,
    checkpoint_diffusion_conditioning: bool = False,
    use_templates_v2: bool = False,
    use_kernels: bool = False,
) -> None:
    """Build the trunk, the structure module and the optional heads.

    Every argument except ``validators`` is recorded through
    ``save_hyperparameters``. Only arguments whose handling is not
    evident from their name are described below.

    Parameters
    ----------
    atom_s, atom_z, token_s, token_z
        Channel sizes forwarded to the sub-modules constructed here.
    num_bins
        Passed to ``DistogramModule`` and, when ``predict_bfactor`` is
        set, to ``BFactorModule``.
    embedder_args
        Extra keyword arguments for ``InputEmbedder``. They are merged
        last, so a key given here overrides the value derived from the
        explicit arguments.
    msa_args, pairformer_args, diffusion_process_args
        Keyword arguments forwarded to ``MSAModule``,
        ``PairformerModule`` and ``AtomDiffusion`` respectively.
    score_model_args
        Must contain ``atom_encoder_depth``, ``atom_encoder_heads``,
        ``token_transformer_depth``, ``token_transformer_heads``,
        ``atom_decoder_depth``, ``atom_decoder_heads`` and
        ``conditioning_transition_layers`` (read by key); the whole
        mapping is also forwarded to ``AtomDiffusion``.
    confidence_model_args, affinity_model_args, affinity_model_args1,
    affinity_model_args2, template_args
        Optional keyword arguments for the corresponding module.
        ``None`` is treated as "no extra arguments".
    validators
        Wrapped in an ``nn.ModuleList`` when ``validate_structure`` is
        true; otherwise unused here.
    structure_prediction_training
        When false, every parameter is frozen except those whose
        top-level attribute name is ``confidence_module`` or
        ``affinity_module`` or whose name contains
        ``out_token_feat_update``.
    compile_*
        Wrap the corresponding sub-module with ``torch.compile``.

    Notes
    -----
    Raises the process-wide ``torch._dynamo`` cache size limits to 512.
    """
    super().__init__()
    self.save_hyperparameters(ignore=["validators"])

    # No random recycling
    self.no_random_recycling_training = no_random_recycling_training

    # ``setup`` reads this flag from the instance, so it has to be stored
    # and not only used locally.
    self.validate_structure = validate_structure

    if validate_structure:
        # Late init at setup time
        self.val_group_mapper = {}  # maps a dataset index to a validation group name
        self.validator_mapper = {}  # maps a dataset index to a validator

        # Validators for each dataset keep track of all metrics,
        # compute validation, aggregate results and log
        self.validators = nn.ModuleList(validators)

    self.num_val_datasets = num_val_datasets
    self.log_loss_every_steps = log_loss_every_steps

    # EMA
    self.use_ema = ema
    self.ema_decay = ema_decay

    # Arguments
    self.training_args = training_args
    self.validation_args = validation_args
    self.diffusion_loss_args = diffusion_loss_args
    self.predict_args = predict_args
    self.steering_args = steering_args

    # Training metrics
    if validate_structure:
        self.train_confidence_loss_logger = MeanMetric()
        self.train_confidence_loss_dict_logger = nn.ModuleDict()
        for m in [
            "plddt_loss",
            "resolved_loss",
            "pde_loss",
            "pae_loss",
        ]:
            self.train_confidence_loss_dict_logger[m] = MeanMetric()

    self.exclude_ions_from_lddt = exclude_ions_from_lddt

    # Distogram
    self.num_bins = num_bins
    self.min_dist = min_dist
    self.max_dist = max_dist
    self.aggregate_distogram = aggregate_distogram

    # Trunk: flags flipped below when a module is wrapped by torch.compile
    self.is_pairformer_compiled = False
    self.is_msa_compiled = False
    self.is_template_compiled = False

    # Kernels
    self.use_kernels = use_kernels

    # Input embeddings (embedder_args comes last and therefore wins on
    # duplicate keys)
    full_embedder_args = {
        "atom_s": atom_s,
        "atom_z": atom_z,
        "token_s": token_s,
        "token_z": token_z,
        "atoms_per_window_queries": atoms_per_window_queries,
        "atoms_per_window_keys": atoms_per_window_keys,
        "atom_feature_dim": atom_feature_dim,
        "use_no_atom_char": use_no_atom_char,
        "use_atom_backbone_feat": use_atom_backbone_feat,
        "use_residue_feats_atoms": use_residue_feats_atoms,
        **embedder_args,
    }
    self.input_embedder = InputEmbedder(**full_embedder_args)

    self.s_init = nn.Linear(token_s, token_s, bias=False)
    self.z_init_1 = nn.Linear(token_s, token_z, bias=False)
    self.z_init_2 = nn.Linear(token_s, token_z, bias=False)

    self.rel_pos = RelativePositionEncoder(
        token_z, fix_sym_check=fix_sym_check, cyclic_pos_enc=cyclic_pos_enc
    )

    self.token_bonds = nn.Linear(1, token_z, bias=False)
    self.bond_type_feature = bond_type_feature
    if bond_type_feature:
        self.token_bonds_type = nn.Embedding(len(const.bond_types) + 1, token_z)

    self.contact_conditioning = ContactConditioning(
        token_z=token_z,
        cutoff_min=conditioning_cutoff_min,
        cutoff_max=conditioning_cutoff_max,
    )

    # Normalization layers
    self.s_norm = nn.LayerNorm(token_s)
    self.z_norm = nn.LayerNorm(token_z)

    # Recycling projections
    self.s_recycle = nn.Linear(token_s, token_s, bias=False)
    self.z_recycle = nn.Linear(token_z, token_z, bias=False)
    init.gating_init_(self.s_recycle.weight)
    init.gating_init_(self.z_recycle.weight)

    # Set compile rules
    # Big models hit the default cache limit (8)
    torch._dynamo.config.cache_size_limit = 512  # noqa: SLF001
    torch._dynamo.config.accumulated_cache_size_limit = 512  # noqa: SLF001

    # Pairwise stack
    self.use_templates = use_templates
    if use_templates:
        # Unpacking ``None`` with ``**`` raises an opaque TypeError, so a
        # missing argument dict is treated as empty.
        template_kwargs = template_args or {}
        if use_templates_v2:
            self.template_module = TemplateV2Module(token_z, **template_kwargs)
        else:
            self.template_module = TemplateModule(token_z, **template_kwargs)
        if compile_templates:
            self.is_template_compiled = True
            self.template_module = torch.compile(
                self.template_module,
                dynamic=False,
                fullgraph=False,
            )

    self.msa_module = MSAModule(
        token_z=token_z,
        token_s=token_s,
        **msa_args,
    )
    if compile_msa:
        self.is_msa_compiled = True
        self.msa_module = torch.compile(
            self.msa_module,
            dynamic=False,
            fullgraph=False,
        )
    self.pairformer_module = PairformerModule(token_s, token_z, **pairformer_args)
    if compile_pairformer:
        self.is_pairformer_compiled = True
        self.pairformer_module = torch.compile(
            self.pairformer_module,
            dynamic=False,
            fullgraph=False,
        )

    self.checkpoint_diffusion_conditioning = checkpoint_diffusion_conditioning
    self.diffusion_conditioning = DiffusionConditioning(
        token_s=token_s,
        token_z=token_z,
        atom_s=atom_s,
        atom_z=atom_z,
        atoms_per_window_queries=atoms_per_window_queries,
        atoms_per_window_keys=atoms_per_window_keys,
        atom_encoder_depth=score_model_args["atom_encoder_depth"],
        atom_encoder_heads=score_model_args["atom_encoder_heads"],
        token_transformer_depth=score_model_args["token_transformer_depth"],
        token_transformer_heads=score_model_args["token_transformer_heads"],
        atom_decoder_depth=score_model_args["atom_decoder_depth"],
        atom_decoder_heads=score_model_args["atom_decoder_heads"],
        atom_feature_dim=atom_feature_dim,
        conditioning_transition_layers=score_model_args[
            "conditioning_transition_layers"
        ],
        use_no_atom_char=use_no_atom_char,
        use_atom_backbone_feat=use_atom_backbone_feat,
        use_residue_feats_atoms=use_residue_feats_atoms,
    )

    # Output modules
    self.structure_module = AtomDiffusion(
        score_model_args={
            "token_s": token_s,
            "atom_s": atom_s,
            "atoms_per_window_queries": atoms_per_window_queries,
            "atoms_per_window_keys": atoms_per_window_keys,
            **score_model_args,
        },
        compile_score=compile_structure,
        **diffusion_process_args,
    )
    self.distogram_module = DistogramModule(
        token_z,
        num_bins,
    )
    self.predict_bfactor = predict_bfactor
    if predict_bfactor:
        self.bfactor_module = BFactorModule(token_s, num_bins)

    self.confidence_prediction = confidence_prediction
    self.affinity_prediction = affinity_prediction
    self.affinity_ensemble = affinity_ensemble
    self.affinity_mw_correction = affinity_mw_correction
    self.run_trunk_and_structure = run_trunk_and_structure
    self.skip_run_structure = skip_run_structure
    self.token_level_confidence = token_level_confidence
    self.alpha_pae = alpha_pae
    self.structure_prediction_training = structure_prediction_training

    if self.confidence_prediction:
        # confidence_prediction defaults to True while
        # confidence_model_args defaults to None, hence the ``or {}``.
        self.confidence_module = ConfidenceModule(
            token_s,
            token_z,
            token_level_confidence=token_level_confidence,
            bond_type_feature=bond_type_feature,
            fix_sym_check=fix_sym_check,
            cyclic_pos_enc=cyclic_pos_enc,
            conditioning_cutoff_min=conditioning_cutoff_min,
            conditioning_cutoff_max=conditioning_cutoff_max,
            **(confidence_model_args or {}),
        )
        if compile_confidence:
            self.confidence_module = torch.compile(
                self.confidence_module, dynamic=False, fullgraph=False
            )

    if self.affinity_prediction:
        if self.affinity_ensemble:
            self.affinity_module1 = AffinityModule(
                token_s,
                token_z,
                **(affinity_model_args1 or {}),
            )
            self.affinity_module2 = AffinityModule(
                token_s,
                token_z,
                **(affinity_model_args2 or {}),
            )
            if compile_affinity:
                self.affinity_module1 = torch.compile(
                    self.affinity_module1, dynamic=False, fullgraph=False
                )
                self.affinity_module2 = torch.compile(
                    self.affinity_module2, dynamic=False, fullgraph=False
                )
        else:
            self.affinity_module = AffinityModule(
                token_s,
                token_z,
                **(affinity_model_args or {}),
            )
            if compile_affinity:
                self.affinity_module = torch.compile(
                    self.affinity_module, dynamic=False, fullgraph=False
                )

    # Remove grad from weights they are not trained for ddp
    # NOTE(review): the exemption matches the exact top-level names below,
    # so the ensemble heads (affinity_module1 / affinity_module2) are
    # frozen as well -- confirm this is intended.
    if not structure_prediction_training:
        for name, param in self.named_parameters():
            if (
                name.split(".")[0] not in ["confidence_module", "affinity_module"]
                and "out_token_feat_update" not in name
            ):
                param.requires_grad = False
|
359
|
+
|
360
|
+
def setup(self, stage: str) -> None:
|
361
|
+
"""Set the model for training, validation and inference."""
|
362
|
+
if stage == "predict" and not (
|
363
|
+
torch.cuda.is_available()
|
364
|
+
and torch.cuda.get_device_properties(torch.device("cuda")).major >= 8.0 # noqa: PLR2004
|
365
|
+
):
|
366
|
+
self.use_kernels = False
|
367
|
+
|
368
|
+
if (
|
369
|
+
stage != "predict"
|
370
|
+
and hasattr(self.trainer, "datamodule")
|
371
|
+
and self.trainer.datamodule
|
372
|
+
and self.validate_structure
|
373
|
+
):
|
374
|
+
self.val_group_mapper.update(self.trainer.datamodule.val_group_mapper)
|
375
|
+
|
376
|
+
l1 = len(self.val_group_mapper)
|
377
|
+
l2 = self.num_val_datasets
|
378
|
+
msg = (
|
379
|
+
f"Number of validation datasets num_val_datasets={l2} "
|
380
|
+
f"does not match the number of val_group_mapper entries={l1}."
|
381
|
+
)
|
382
|
+
assert l1 == l2, msg
|
383
|
+
|
384
|
+
# Map an index to a validator, and double check val names
|
385
|
+
# match from datamodule
|
386
|
+
all_validator_names = []
|
387
|
+
for validator in self.validators:
|
388
|
+
for val_name in validator.val_names:
|
389
|
+
msg = f"Validator {val_name} duplicated in validators."
|
390
|
+
assert val_name not in all_validator_names, msg
|
391
|
+
all_validator_names.append(val_name)
|
392
|
+
for val_idx, val_group in self.val_group_mapper.items():
|
393
|
+
if val_name == val_group["label"]:
|
394
|
+
self.validator_mapper[val_idx] = validator
|
395
|
+
|
396
|
+
msg = "Mismatch between validator names and val_group_mapper values."
|
397
|
+
assert set(all_validator_names) == {
|
398
|
+
x["label"] for x in self.val_group_mapper.values()
|
399
|
+
}, msg
|
400
|
+
|
401
|
+
def forward(
    self,
    feats: dict[str, Tensor],
    recycling_steps: int = 0,
    num_sampling_steps: Optional[int] = None,
    multiplicity_diffusion_train: int = 1,
    diffusion_samples: int = 1,
    max_parallel_samples: Optional[int] = None,
    run_confidence_sequentially: bool = False,
) -> dict[str, Tensor]:
    """Run the trunk, the structure module and the optional heads.

    Parameters
    ----------
    feats
        Feature dictionary. It is passed as a whole to the sub-modules;
        the keys read directly in this method are ``token_bonds``,
        ``type_bonds`` (only with ``bond_type_feature``),
        ``token_pad_mask``, ``atom_pad_mask``, ``coords`` (training) and
        ``mol_type`` / ``affinity_token_mask`` / ``affinity_mw``
        (affinity prediction). During training ``feats["coords"]`` is
        overwritten in place with a reshaped tensor.
    recycling_steps
        The pairwise stack is run ``recycling_steps + 1`` times.
    num_sampling_steps
        Forwarded to ``structure_module.sample``.
    multiplicity_diffusion_train
        Forwarded as ``multiplicity`` to the structure module in the
        training branch; the ground-truth coordinates are repeated to
        match it.
    diffusion_samples
        Forwarded as ``multiplicity`` to ``structure_module.sample`` and
        to the confidence module.
    max_parallel_samples
        Forwarded to ``structure_module.sample``.
    run_confidence_sequentially
        Forwarded to the confidence module as ``run_sequentially``.

    Returns
    -------
    dict[str, Tensor]
        Always contains ``pdistogram``. Depending on the configured
        flags it is extended with the outputs of the structure module,
        ``pbfactor``, the outputs of the confidence module and the
        ``affinity_*`` entries.
    """
    # Gradients for the trunk are only tracked when the structure
    # predictor itself is being trained.
    with torch.set_grad_enabled(
        self.training and self.structure_prediction_training
    ):
        s_inputs = self.input_embedder(feats)

        # Initialize the sequence embeddings
        s_init = self.s_init(s_inputs)

        # Initialize pairwise embeddings
        z_init = (
            self.z_init_1(s_inputs)[:, :, None]
            + self.z_init_2(s_inputs)[:, None, :]
        )
        relative_position_encoding = self.rel_pos(feats)
        z_init = z_init + relative_position_encoding
        z_init = z_init + self.token_bonds(feats["token_bonds"].float())
        if self.bond_type_feature:
            z_init = z_init + self.token_bonds_type(feats["type_bonds"].long())
        z_init = z_init + self.contact_conditioning(feats)

        # Perform rounds of the pairwise stack
        s = torch.zeros_like(s_init)
        z = torch.zeros_like(z_init)

        # Compute pairwise mask
        mask = feats["token_pad_mask"].float()
        pair_mask = mask[:, :, None] * mask[:, None, :]
        if self.run_trunk_and_structure:
            for i in range(recycling_steps + 1):
                # Only the last recycling iteration tracks gradients.
                with torch.set_grad_enabled(
                    self.training
                    and self.structure_prediction_training
                    and (i == recycling_steps)
                ):
                    # Issue with unused parameters in autocast
                    if (
                        self.training
                        and (i == recycling_steps)
                        and torch.is_autocast_enabled()
                    ):
                        torch.clear_autocast_cache()

                    # Apply recycling
                    s = s_init + self.s_recycle(self.s_norm(s))
                    z = z_init + self.z_recycle(self.z_norm(z))

                    # Compute pairwise stack
                    if self.use_templates:
                        # Outside training, bypass the torch.compile
                        # wrapper and call the wrapped module directly.
                        if self.is_template_compiled and not self.training:
                            template_module = self.template_module._orig_mod  # noqa: SLF001
                        else:
                            template_module = self.template_module

                        z = z + template_module(
                            z, feats, pair_mask, use_kernels=self.use_kernels
                        )

                    # Same compiled/uncompiled selection for the MSA module
                    if self.is_msa_compiled and not self.training:
                        msa_module = self.msa_module._orig_mod  # noqa: SLF001
                    else:
                        msa_module = self.msa_module

                    z = z + msa_module(
                        z, s_inputs, feats, use_kernels=self.use_kernels
                    )

                    # Revert to uncompiled version for validation
                    if self.is_pairformer_compiled and not self.training:
                        pairformer_module = self.pairformer_module._orig_mod  # noqa: SLF001
                    else:
                        pairformer_module = self.pairformer_module

                    s, z = pairformer_module(
                        s,
                        z,
                        mask=mask,
                        pair_mask=pair_mask,
                        use_kernels=self.use_kernels,
                    )

        pdistogram = self.distogram_module(z)
        dict_out = {"pdistogram": pdistogram}

        # Sampling branch: taken outside training, or during training
        # when the confidence head needs sampled coordinates.
        if (
            self.run_trunk_and_structure
            and ((not self.training) or self.confidence_prediction)
            and (not self.skip_run_structure)
        ):
            if self.checkpoint_diffusion_conditioning and self.training:
                # TODO decide whether this should be with bf16 or not
                q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias = (
                    torch.utils.checkpoint.checkpoint(
                        self.diffusion_conditioning,
                        s,
                        z,
                        relative_position_encoding,
                        feats,
                    )
                )
            else:
                q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias = (
                    self.diffusion_conditioning(
                        s_trunk=s,
                        z_trunk=z,
                        relative_position_encoding=relative_position_encoding,
                        feats=feats,
                    )
                )
            diffusion_conditioning = {
                "q": q,
                "c": c,
                "to_keys": to_keys,
                "atom_enc_bias": atom_enc_bias,
                "atom_dec_bias": atom_dec_bias,
                "token_trans_bias": token_trans_bias,
            }

            # Sampling runs with CUDA autocast disabled on float32 inputs
            with torch.autocast("cuda", enabled=False):
                struct_out = self.structure_module.sample(
                    s_trunk=s.float(),
                    s_inputs=s_inputs.float(),
                    feats=feats,
                    num_sampling_steps=num_sampling_steps,
                    atom_mask=feats["atom_pad_mask"].float(),
                    multiplicity=diffusion_samples,
                    max_parallel_samples=max_parallel_samples,
                    steering_args=self.steering_args,
                    diffusion_conditioning=diffusion_conditioning,
                )
                dict_out.update(struct_out)

            if self.predict_bfactor:
                pbfactor = self.bfactor_module(s)
                dict_out["pbfactor"] = pbfactor

        if self.training and self.confidence_prediction:
            assert len(feats["coords"].shape) == 4
            assert feats["coords"].shape[1] == 1, (
                "Only one conformation is supported for confidence"
            )

        # Compute structure module
        if self.training and self.structure_prediction_training:
            atom_coords = feats["coords"]
            B, K, L = atom_coords.shape[0:3]
            assert K in (
                multiplicity_diffusion_train,
                1,
            )  # TODO make check somewhere else, expand to m % N == 0, m > N
            # Fold the second axis into the first, then repeat each row
            # so the first axis matches the training multiplicity.
            atom_coords = atom_coords.reshape(B * K, L, 3)
            atom_coords = atom_coords.repeat_interleave(
                multiplicity_diffusion_train // K, 0
            )
            # Overwrites the entry in the caller's feature dict
            feats["coords"] = atom_coords  # (multiplicity, L, 3)
            assert len(feats["coords"].shape) == 3

            # NOTE(review): `diffusion_conditioning` is only assigned in
            # the sampling branch above; when that branch is skipped
            # (e.g. training with confidence_prediction disabled) this
            # call looks like it would raise NameError -- confirm.
            with torch.autocast("cuda", enabled=False):
                struct_out = self.structure_module(
                    s_trunk=s.float(),
                    s_inputs=s_inputs.float(),
                    feats=feats,
                    multiplicity=multiplicity_diffusion_train,
                    diffusion_conditioning=diffusion_conditioning,
                )
                dict_out.update(struct_out)

        elif self.training:
            # Drop the singleton second axis (asserted above when the
            # confidence head is enabled).
            feats["coords"] = feats["coords"].squeeze(1)
            assert len(feats["coords"].shape) == 3

    # The confidence head runs outside the grad-gated region and only
    # receives detached trunk outputs.
    if self.confidence_prediction:
        dict_out.update(
            self.confidence_module(
                s_inputs=s_inputs.detach(),
                s=s.detach(),
                z=z.detach(),
                # Sampled coordinates when available, otherwise the
                # input coordinates repeated once per diffusion sample.
                x_pred=(
                    dict_out["sample_atom_coords"].detach()
                    if not self.skip_run_structure
                    else feats["coords"].repeat_interleave(diffusion_samples, 0)
                ),
                feats=feats,
                pred_distogram_logits=(
                    dict_out["pdistogram"][
                        :, :, :, 0
                    ].detach()  # TODO only implemented for 1 distogram
                ),
                multiplicity=diffusion_samples,
                run_sequentially=run_confidence_sequentially,
                use_kernels=self.use_kernels,
            )
        )

    if self.affinity_prediction:
        # Only the first batch element is used below
        # (presumably batch size 1 -- TODO confirm).
        pad_token_mask = feats["token_pad_mask"][0]
        # NOTE(review): mol_type == 0 presumably marks receptor/protein
        # tokens -- verify against boltz.data.const.
        rec_mask = feats["mol_type"][0] == 0
        rec_mask = rec_mask * pad_token_mask
        lig_mask = feats["affinity_token_mask"][0].to(torch.bool)
        lig_mask = lig_mask * pad_token_mask
        # Keep ligand-receptor, receptor-ligand and ligand-ligand pairs;
        # receptor-receptor pairs are zeroed out.
        cross_pair_mask = (
            lig_mask[:, None] * rec_mask[None, :]
            + rec_mask[:, None] * lig_mask[None, :]
            + lig_mask[:, None] * lig_mask[None, :]
        )
        z_affinity = z * cross_pair_mask[None, :, :, None]

        # Pick the sample with the highest "iptm" entry.
        # NOTE(review): "iptm" is not produced in this method; presumably
        # it comes from the confidence module, so affinity prediction
        # would require confidence_prediction -- confirm.
        argsort = torch.argsort(dict_out["iptm"], descending=True)
        best_idx = argsort[0].item()
        # Coordinates of that sample with two leading singleton axes
        coords_affinity = dict_out["sample_atom_coords"].detach()[best_idx][
            None, None
        ]
        # Re-embed the inputs in the embedder's affinity mode
        s_inputs = self.input_embedder(feats, affinity=True)

        with torch.autocast("cuda", enabled=False):
            if self.affinity_ensemble:
                dict_out_affinity1 = self.affinity_module1(
                    s_inputs=s_inputs.detach(),
                    z=z_affinity.detach(),
                    x_pred=coords_affinity,
                    feats=feats,
                    multiplicity=1,
                    use_kernels=self.use_kernels,
                )

                dict_out_affinity1["affinity_probability_binary"] = (
                    torch.nn.functional.sigmoid(
                        dict_out_affinity1["affinity_logits_binary"]
                    )
                )
                dict_out_affinity2 = self.affinity_module2(
                    s_inputs=s_inputs.detach(),
                    z=z_affinity.detach(),
                    x_pred=coords_affinity,
                    feats=feats,
                    multiplicity=1,
                    use_kernels=self.use_kernels,
                )
                dict_out_affinity2["affinity_probability_binary"] = (
                    torch.nn.functional.sigmoid(
                        dict_out_affinity2["affinity_logits_binary"]
                    )
                )

                # Ensemble output: arithmetic mean of the two heads
                dict_out_affinity_ensemble = {
                    "affinity_pred_value": (
                        dict_out_affinity1["affinity_pred_value"]
                        + dict_out_affinity2["affinity_pred_value"]
                    )
                    / 2,
                    "affinity_probability_binary": (
                        dict_out_affinity1["affinity_probability_binary"]
                        + dict_out_affinity2["affinity_probability_binary"]
                    )
                    / 2,
                }

                # Re-key the per-head outputs with a 1 / 2 suffix so they
                # can be reported next to the ensemble values.
                dict_out_affinity1 = {
                    "affinity_pred_value1": dict_out_affinity1[
                        "affinity_pred_value"
                    ],
                    "affinity_probability_binary1": dict_out_affinity1[
                        "affinity_probability_binary"
                    ],
                }
                dict_out_affinity2 = {
                    "affinity_pred_value2": dict_out_affinity2[
                        "affinity_pred_value"
                    ],
                    "affinity_probability_binary2": dict_out_affinity2[
                        "affinity_probability_binary"
                    ],
                }
                if self.affinity_mw_correction:
                    # Linear correction of the ensemble value using
                    # affinity_mw ** 0.3; the origin of the coefficients
                    # is not visible in this file.
                    model_coef = 1.03525938
                    mw_coef = -0.59992683
                    bias = 2.83288489
                    mw = feats["affinity_mw"][0] ** 0.3
                    dict_out_affinity_ensemble["affinity_pred_value"] = (
                        model_coef
                        * dict_out_affinity_ensemble["affinity_pred_value"]
                        + mw_coef * mw
                        + bias
                    )

                dict_out.update(dict_out_affinity_ensemble)
                dict_out.update(dict_out_affinity1)
                dict_out.update(dict_out_affinity2)
            else:
                dict_out_affinity = self.affinity_module(
                    s_inputs=s_inputs.detach(),
                    z=z_affinity.detach(),
                    x_pred=coords_affinity,
                    feats=feats,
                    multiplicity=1,
                    use_kernels=self.use_kernels,
                )
                # Single head: no molecular-weight correction is applied
                dict_out.update(
                    {
                        "affinity_pred_value": dict_out_affinity[
                            "affinity_pred_value"
                        ],
                        "affinity_probability_binary": torch.nn.functional.sigmoid(
                            dict_out_affinity["affinity_logits_binary"]
                        ),
                    }
                )

    return dict_out
|
719
|
+
|
720
|
+
def get_true_coordinates(
    self,
    batch: dict[str, Tensor],
    out: dict[str, Tensor],
    diffusion_samples: int,
    symmetry_correction: bool,
    expand_to_diffusion_samples: bool = True,
):
    """Collect the reference coordinates and resolved mask for one batch element.

    Parameters
    ----------
    batch : dict[str, Tensor]
        Input features. ``coords`` is always read and must have a leading
        dimension of 1. ``token_index`` is read when ``symmetry_correction``
        is set, ``atom_resolved_mask`` otherwise.
    out : dict[str, Tensor]
        Model output. Only ``sample_atom_coords`` is read, and only when
        ``symmetry_correction`` is set.
    diffusion_samples : int
        Number of diffusion samples the reference is matched against or
        repeated for.
    symmetry_correction : bool
        If true, pick the reference per sample through
        ``minimum_lddt_symmetry_coords``.
    expand_to_diffusion_samples : bool, optional
        If true (default), repeat the reference once per diffusion sample.
        Must be true when ``symmetry_correction`` is set.

    Returns
    -------
    dict
        ``true_coords``, ``true_coords_resolved_mask`` and the placeholder
        entries ``rmsds``, ``best_rmsd_recall`` and ``best_rmsd_precision``
        (all ``0``). The same keys are returned by both branches.

    Raises
    ------
    AssertionError
        If symmetry correction is requested without expansion, or if the
        batch holds more than one element.

    """
    if symmetry_correction:
        msg = "expand_to_diffusion_samples must be true for symmetry correction."
        assert expand_to_diffusion_samples, msg

    assert batch["coords"].shape[0] == 1, (
        f"Validation is not supported for batch sizes={batch['coords'].shape[0]}"
    )

    if symmetry_correction:
        per_sample_coords = []
        per_sample_masks = []
        for idx in range(batch["token_index"].shape[0]):
            for rep in range(diffusion_samples):
                # Samples are laid out batch-major in ``sample_atom_coords``.
                i = idx * diffusion_samples + rep
                best_true_coords, best_true_coords_resolved_mask = (
                    minimum_lddt_symmetry_coords(
                        coords=out["sample_atom_coords"][i : i + 1],
                        feats=batch,
                        index_batch=idx,
                    )
                )
                per_sample_coords.append(best_true_coords)
                per_sample_masks.append(best_true_coords_resolved_mask)

        # Insert a singleton dimension at position 1, matching the K axis
        # of the non-symmetric branch.
        true_coords = torch.cat(per_sample_coords, dim=0).unsqueeze(1)
        true_coords_resolved_mask = torch.cat(per_sample_masks, dim=0)

    else:
        K, L = batch["coords"].shape[1:3]

        true_coords_resolved_mask = batch["atom_resolved_mask"]
        true_coords = batch["coords"].squeeze(0)
        if expand_to_diffusion_samples:
            true_coords = true_coords.repeat((diffusion_samples, 1, 1)).reshape(
                diffusion_samples, K, L, 3
            )
            # Since all masks are the same across conformers and diffusion
            # samples, the mask can simply be repeated once per sample.
            true_coords_resolved_mask = true_coords_resolved_mask.repeat_interleave(
                diffusion_samples, dim=0
            )
        else:
            true_coords_resolved_mask = true_coords_resolved_mask.squeeze(0)

    # Both branches return the same keys, so consumers can read
    # ``best_rmsd_precision`` regardless of ``symmetry_correction``.
    return {
        "true_coords": true_coords,
        "true_coords_resolved_mask": true_coords_resolved_mask,
        "rmsds": 0,
        "best_rmsd_recall": 0,
        "best_rmsd_precision": 0,
    }
|
788
|
+
|
789
|
+
def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor:
    """Run one training iteration and return the weighted total loss.

    Parameters
    ----------
    batch : dict[str, Tensor]
        Input features. When confidence prediction is enabled the entries
        ``frames_idx`` and ``frame_resolved_mask`` are replaced in place by
        versions with dimension 1 squeezed out.
    batch_idx : int
        Index of the batch; only used in the skip message.

    Returns
    -------
    Tensor
        Weighted sum of the confidence, diffusion, distogram and b-factor
        losses. ``None`` is returned instead when the diffusion loss or the
        reference-coordinate lookup raises; the batch is reported as skipped
        on stdout.

    """
    # Sample recycling steps
    if self.no_random_recycling_training:
        recycling_steps = self.training_args.recycling_steps
    else:
        # The generator is seeded with the global step, so the draw is a
        # deterministic function of the step number.
        rgn = np.random.default_rng(self.global_step)
        recycling_steps = rgn.integers(
            0, self.training_args.recycling_steps + 1
        ).item()

    # Optionally draw the number of sampling steps from a configured list,
    # again with a generator seeded by the global step.
    if self.training_args.get("sampling_steps_random", None) is not None:
        rgn_samplng_steps = np.random.default_rng(self.global_step)
        sampling_steps = rgn_samplng_steps.choice(
            self.training_args.sampling_steps_random
        )
    else:
        sampling_steps = self.training_args.sampling_steps

    # Compute the forward pass
    out = self(
        feats=batch,
        recycling_steps=recycling_steps,
        num_sampling_steps=sampling_steps,
        multiplicity_diffusion_train=self.training_args.diffusion_multiplicity,
        diffusion_samples=self.training_args.diffusion_samples,
    )

    # Compute losses
    if self.structure_prediction_training:
        disto_loss, _ = distogram_loss(
            out,
            batch,
            aggregate_distogram=self.aggregate_distogram,
        )
        try:
            diffusion_loss_dict = self.structure_module.compute_loss(
                batch,
                out,
                multiplicity=self.training_args.diffusion_multiplicity,
                **self.diffusion_loss_args,
            )
        except Exception as e:
            # Best effort: any failure in the diffusion loss skips the batch.
            print(f"Skipping batch {batch_idx} due to error: {e}")
            return None

        if self.predict_bfactor:
            bfactor_loss = bfactor_loss_fn(out, batch)
        else:
            bfactor_loss = 0.0

    else:
        # Structure losses are disabled; use neutral placeholders so the
        # aggregation and logging below still work.
        disto_loss = 0.0
        bfactor_loss = 0.0
        diffusion_loss_dict = {"loss": 0.0, "loss_breakdown": {}}

    if self.confidence_prediction:
        try:
            # confidence model symmetry correction
            return_dict = self.get_true_coordinates(
                batch,
                out,
                diffusion_samples=self.training_args.diffusion_samples,
                symmetry_correction=self.training_args.symmetry_correction,
            )
        except Exception as e:
            # Best effort: skip the batch if the reference lookup fails.
            print(f"Skipping batch with id {batch['pdb_id']} due to error: {e}")
            return None

        true_coords = return_dict["true_coords"]
        true_coords_resolved_mask = return_dict["true_coords_resolved_mask"]

        # TODO remove once multiple conformers are supported
        K = true_coords.shape[1]
        assert K == 1, (
            f"Confidence_prediction is not supported for num_ensembles_val={K}."
        )

        # For now, just take the only conformer.
        true_coords = true_coords.squeeze(1)  # (S, L, 3)
        batch["frames_idx"] = batch["frames_idx"].squeeze(
            1
        )  # remove conformer dimension
        batch["frame_resolved_mask"] = batch["frame_resolved_mask"].squeeze(
            1
        )  # remove conformer dimension

        confidence_loss_dict = confidence_loss(
            out,
            batch,
            true_coords,
            true_coords_resolved_mask,
            token_level_confidence=self.token_level_confidence,
            alpha_pae=self.alpha_pae,
            multiplicity=self.training_args.diffusion_samples,
        )

    else:
        # Zero tensor on the batch's device so ``.detach()`` and the weighted
        # sum below behave the same as with a real confidence loss.
        confidence_loss_dict = {
            "loss": torch.tensor(0.0, device=batch["token_index"].device),
            "loss_breakdown": {},
        }

    # Aggregate losses
    # NOTE: we already have an implicit weight in the losses induced by dataset sampling
    # NOTE: this logic works only for datasets with confidence labels
    loss = (
        self.training_args.confidence_loss_weight * confidence_loss_dict["loss"]
        + self.training_args.diffusion_loss_weight * diffusion_loss_dict["loss"]
        + self.training_args.distogram_loss_weight * disto_loss
        + self.training_args.get("bfactor_loss_weight", 0.0) * bfactor_loss
    )

    # Only log on steps that are a multiple of ``log_loss_every_steps``.
    if not (self.global_step % self.log_loss_every_steps):
        # Log losses
        if self.validate_structure:
            self.log("train/distogram_loss", disto_loss)
            self.log("train/diffusion_loss", diffusion_loss_dict["loss"])
            for k, v in diffusion_loss_dict["loss_breakdown"].items():
                self.log(f"train/{k}", v)

        if self.confidence_prediction:
            self.train_confidence_loss_logger.update(
                confidence_loss_dict["loss"].detach()
            )
            for k in self.train_confidence_loss_dict_logger:
                # Breakdown entries may be tensors or plain numbers; only
                # tensors need detaching before they are accumulated.
                self.train_confidence_loss_dict_logger[k].update(
                    (
                        confidence_loss_dict["loss_breakdown"][k].detach()
                        if torch.is_tensor(
                            confidence_loss_dict["loss_breakdown"][k]
                        )
                        else confidence_loss_dict["loss_breakdown"][k]
                    )
                )
        # NOTE(review): the nesting of the two calls below under the logging
        # gate was reconstructed from a whitespace-stripped listing - confirm.
        self.log("train/loss", loss)
        self.training_log()
    return loss
|
926
|
+
|
927
|
+
def training_log(self):
    """Log gradient and parameter norms plus the current learning rate.

    Whole-model norms are logged first, then the learning rate of the first
    parameter group of the first optimizer, then the parameter norm of each
    trunk sub-module. The confidence module's gradient and parameter norms
    are logged only when confidence prediction is enabled.
    """
    quiet = {"prog_bar": False}

    self.log("train/grad_norm", self.gradient_norm(self), **quiet)
    self.log("train/param_norm", self.parameter_norm(self), **quiet)

    first_group = self.trainer.optimizers[0].param_groups[0]
    self.log("lr", first_group["lr"], **quiet)

    # (log key, attribute holding the sub-module), in logging order.
    tracked_submodules = (
        ("train/param_norm_msa_module", "msa_module"),
        ("train/param_norm_pairformer_module", "pairformer_module"),
        ("train/param_norm_structure_module", "structure_module"),
    )
    for log_key, attribute in tracked_submodules:
        self.log(log_key, self.parameter_norm(getattr(self, attribute)), **quiet)

    if not self.confidence_prediction:
        return

    self.log(
        "train/grad_norm_confidence_module",
        self.gradient_norm(self.confidence_module),
        **quiet,
    )
    self.log(
        "train/param_norm_confidence_module",
        self.parameter_norm(self.confidence_module),
        **quiet,
    )
|
963
|
+
|
964
|
+
def on_train_epoch_end(self):
    """Log the accumulated confidence losses once per epoch.

    Does nothing unless confidence prediction is enabled. Otherwise the
    total confidence-loss logger and every entry of the per-term logger
    mapping are logged with epoch-level aggregation.
    """
    if not self.confidence_prediction:
        return

    epoch_only = {"prog_bar": False, "on_step": False, "on_epoch": True}
    self.log("train/confidence_loss", self.train_confidence_loss_logger, **epoch_only)
    for term, term_logger in self.train_confidence_loss_dict_logger.items():
        self.log(f"train/{term}", term_logger, **epoch_only)
|
975
|
+
|
976
|
+
def gradient_norm(self, module):
    """Return the global L2 norm of the gradients of ``module``.

    Parameters
    ----------
    module
        Object exposing ``parameters()``. Only parameters that require
        gradients and currently hold one contribute.

    Returns
    -------
    Tensor
        Square root of the summed squared per-parameter gradient norms, or
        a scalar zero when no parameter has a gradient.

    """
    squared_norms = []
    fallback_device = None
    for param in module.parameters():
        if fallback_device is None:
            fallback_device = param.device
        if param.requires_grad and param.grad is not None:
            squared_norms.append(param.grad.norm(p=2) ** 2)

    if not squared_norms:
        # Place the zero on the module's own device rather than assuming CUDA
        # whenever it happens to be available; only a module without any
        # parameters falls back to the previous heuristic.
        if fallback_device is None:
            fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.tensor(0.0, device=fallback_device)

    return torch.stack(squared_norms).sum().sqrt()
|
988
|
+
|
989
|
+
def parameter_norm(self, module):
    """Return the global L2 norm of the trainable parameters of ``module``.

    Parameters
    ----------
    module
        Object exposing ``parameters()``. Only parameters with
        ``requires_grad`` set contribute.

    Returns
    -------
    Tensor
        Square root of the summed squared per-parameter norms, or a scalar
        zero when the module has no trainable parameter.

    """
    squared_norms = []
    fallback_device = None
    for param in module.parameters():
        if fallback_device is None:
            fallback_device = param.device
        if param.requires_grad:
            squared_norms.append(param.norm(p=2) ** 2)

    if not squared_norms:
        # Place the zero on the module's own device rather than assuming CUDA
        # whenever it happens to be available; only a module without any
        # parameters falls back to the previous heuristic.
        if fallback_device is None:
            fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.tensor(0.0, device=fallback_device)

    return torch.stack(squared_norms).sum().sqrt()
|
997
|
+
|
998
|
+
def validation_step(self, batch: dict[str, Tensor], batch_idx: int):
    """Run validation on one batch, skipping it on an out-of-memory error.

    With ``validate_structure`` set, the validator registered for the
    batch's dataset index runs the model and processes the output.
    Otherwise only a forward pass with the validation settings is executed.
    A ``RuntimeError`` whose message contains "out of memory" is reported,
    caches are released and the batch is skipped; any other ``RuntimeError``
    is re-raised.
    """
    structured = self.validate_structure
    try:
        if structured:
            msg = "Only batch=1 is supported for validation"
            assert batch["idx_dataset"].shape[0] == 1, msg

            # Pick the validator that belongs to this dataset.
            dataset_id = batch["idx_dataset"][0].item()
            validator = self.validator_mapper[dataset_id]

            # Forward pass, then metric computation.
            prediction = validator.run_model(
                model=self, batch=batch, idx_dataset=dataset_id
            )
            validator.process(
                model=self, batch=batch, out=prediction, idx_dataset=dataset_id
            )
        else:
            val_args = self.validation_args
            self(
                batch,
                recycling_steps=val_args.recycling_steps,
                num_sampling_steps=val_args.sampling_steps,
                diffusion_samples=val_args.diffusion_samples,
                run_confidence_sequentially=val_args.get(
                    "run_confidence_sequentially", False
                ),
            )
    except RuntimeError as e:  # catch out of memory exceptions
        idx_dataset = batch["idx_dataset"][0].item()
        if "out of memory" not in str(e):
            raise e
        print(f"| WARNING: ran out of memory, skipping batch, {idx_dataset}")
        torch.cuda.empty_cache()
        gc.collect()
        return
|
1045
|
+
|
1046
|
+
def on_validation_epoch_end(self):
    """Aggregate all metrics for each validator.

    Nothing happens unless ``validate_structure`` is set; otherwise every
    registered validator is asked to finish the epoch.
    """
    if not self.validate_structure:
        return

    # Each validator aggregates, computes and logs its own metrics.
    for dataset_validator in self.validator_mapper.values():
        dataset_validator.on_epoch_end(model=self)
|
1052
|
+
|
1053
|
+
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> dict:
    """Run inference on one batch and collect the outputs to be written.

    Returns
    -------
    dict
        ``{"exception": False, ...}`` holding the requested batch and output
        entries, the padding masks, the sampled coordinates and, depending
        on the enabled heads, confidence and affinity predictions. When the
        forward pass raises a ``RuntimeError`` mentioning "out of memory",
        ``{"exception": True}`` is returned instead; other ``RuntimeError``
        instances are re-raised.

    """
    try:
        settings = self.predict_args
        out = self(
            batch,
            recycling_steps=settings["recycling_steps"],
            num_sampling_steps=settings["sampling_steps"],
            diffusion_samples=settings["diffusion_samples"],
            max_parallel_samples=settings["max_parallel_samples"],
            run_confidence_sequentially=True,
        )

        result = {"exception": False}

        # Extra batch entries requested by the caller.
        if "keys_dict_batch" in self.predict_args:
            for name in self.predict_args["keys_dict_batch"]:
                result[name] = batch[name]

        result["masks"] = batch["atom_pad_mask"]
        result["token_masks"] = batch["token_pad_mask"]

        # Extra model outputs requested by the caller.
        if "keys_dict_out" in self.predict_args:
            for name in self.predict_args["keys_dict_out"]:
                result[name] = out[name]
        result["coords"] = out["sample_atom_coords"]

        if self.confidence_prediction:
            for name in ("pde", "plddt"):
                result[name] = out[name]

            # Weighted 4:1 mix of complex pLDDT and ipTM; pTM stands in for
            # ipTM when the latter is all zeros.
            weighted_plddt = 4 * out["complex_plddt"]
            iptm = out["iptm"]
            if torch.allclose(iptm, torch.zeros_like(iptm)):
                interface_term = out["ptm"]
            else:
                interface_term = iptm
            result["confidence_score"] = (weighted_plddt + interface_term) / 5

            for name in (
                "complex_plddt",
                "complex_iplddt",
                "complex_pde",
                "complex_ipde",
            ):
                result[name] = out[name]

            # NOTE(review): nesting of the PAE/pTM outputs under this check
            # was reconstructed from a whitespace-stripped listing - confirm.
            if self.alpha_pae > 0:
                for name in (
                    "pae",
                    "ptm",
                    "iptm",
                    "ligand_iptm",
                    "protein_iptm",
                    "pair_chains_iptm",
                ):
                    result[name] = out[name]

        if self.affinity_prediction:
            affinity_names = ["affinity_pred_value", "affinity_probability_binary"]
            if self.affinity_ensemble:
                affinity_names += [
                    "affinity_pred_value1",
                    "affinity_probability_binary1",
                    "affinity_pred_value2",
                    "affinity_probability_binary2",
                ]
            for name in affinity_names:
                result[name] = out[name]

        return result

    except RuntimeError as e:  # catch out of memory exceptions
        if "out of memory" not in str(e):
            raise e
        print("| WARNING: ran out of memory, skipping batch")
        torch.cuda.empty_cache()
        gc.collect()
        return {"exception": True}
|
1125
|
+
|
1126
|
+
def configure_optimizers(self) -> torch.optim.Optimizer:
    """Configure the optimizer.

    All trainable parameters are optimised when structure-prediction
    training is enabled; otherwise only trainable parameters whose name
    contains ``out_token_feat_update`` or ``confidence_module``. With a
    positive ``weight_decay`` and ``weight_decay_exclude`` set, parameters
    whose name matches one of the exclusion markers go into a second group
    with zero decay.

    Returns
    -------
    torch.optim.Optimizer
        The AdamW optimizer, or ``([optimizer], [scheduler_config])`` when
        ``lr_scheduler`` is ``"af3"``.

    """
    cfg = self.training_args
    train_everything = self.structure_prediction_training

    selected = [
        (name, param)
        for name, param in self.named_parameters()
        if param.requires_grad
        and (
            train_everything
            or "out_token_feat_update" in name
            or "confidence_module" in name
        )
    ]

    weight_decay = cfg.get("weight_decay", 0.0)
    adam_options = {
        "betas": (cfg.adam_beta_1, cfg.adam_beta_2),
        "eps": cfg.adam_eps,
        "lr": cfg.base_lr,
    }

    if weight_decay > 0 and cfg.get("weight_decay_exclude", False):
        # Name fragments of parameters that must not be decayed.
        no_decay_markers = (
            "norm",
            "rel_pos",
            ".s_init",
            ".z_init_",
            "token_bonds",
            "embed_atom_features",
            "dist_bin_pairwise_embed",
        )
        decayed, undecayed = [], []
        for name, param in selected:
            if any(marker in name for marker in no_decay_markers):
                undecayed.append(param)
            else:
                decayed.append(param)
        optimizer = torch.optim.AdamW(
            [
                {"params": decayed, "weight_decay": weight_decay},
                {"params": undecayed, "weight_decay": 0.0},
            ],
            **adam_options,
        )
    else:
        optimizer = torch.optim.AdamW(
            [param for _, param in selected],
            weight_decay=weight_decay,
            **adam_options,
        )

    if cfg.lr_scheduler != "af3":
        return optimizer

    scheduler = AlphaFoldLRScheduler(
        optimizer,
        base_lr=cfg.base_lr,
        max_lr=cfg.max_lr,
        warmup_no_steps=cfg.lr_warmup_no_steps,
        start_decay_after_n_steps=cfg.lr_start_decay_after_n_steps,
        decay_every_n_steps=cfg.lr_decay_every_n_steps,
        decay_factor=cfg.lr_decay_factor,
    )
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
|
1211
|
+
|
1212
|
+
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Override stale optimisation settings stored in ``checkpoint``.

    The checkpoint dictionary is modified in place so that the learning
    rate, weight decay, diffusion multiplicity and recycling steps of the
    current ``training_args`` win over the values that were saved.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        Loaded checkpoint. The optional sections ``optimizer_states``,
        ``lr_schedulers`` and ``hyper_parameters`` are updated when present.

    """
    # Ignore the lr from the checkpoint
    lr = self.training_args.max_lr
    # ``weight_decay`` is optional in ``configure_optimizers`` (default 0.0),
    # so it is read the same way here instead of failing on resume.
    weight_decay = self.training_args.get("weight_decay", 0.0)
    # Mirrors the condition under which ``configure_optimizers`` builds a
    # decayed group followed by a zero-decay group.
    split_groups = weight_decay > 0 and self.training_args.get(
        "weight_decay_exclude", False
    )

    if "optimizer_states" in checkpoint:
        for state in checkpoint["optimizer_states"]:
            groups = state["param_groups"]
            for position, group in enumerate(groups):
                group["lr"] = lr
                # Keep the second of the two groups at zero decay; writing
                # the global value into it would cancel the exclusion.
                is_no_decay_group = (
                    split_groups and len(groups) == 2 and position == 1
                )
                group["weight_decay"] = 0.0 if is_no_decay_group else weight_decay

    if "lr_schedulers" in checkpoint:
        for scheduler in checkpoint["lr_schedulers"]:
            scheduler["max_lr"] = lr
            scheduler["base_lrs"] = [lr] * len(scheduler["base_lrs"])
            scheduler["_last_lr"] = [lr] * len(scheduler["_last_lr"])

    # Ignore the training diffusion_multiplicity and recycling steps from the checkpoint
    if "hyper_parameters" in checkpoint:
        saved_args = checkpoint["hyper_parameters"]["training_args"]
        saved_args["max_lr"] = lr
        saved_args["diffusion_multiplicity"] = (
            self.training_args.diffusion_multiplicity
        )
        saved_args["recycling_steps"] = self.training_args.recycling_steps
        saved_args["weight_decay"] = weight_decay
|
1239
|
+
|
1240
|
+
def configure_callbacks(self) -> list[Callback]:
    """Configure model callbacks.

    Returns
    -------
    List[Callback]
        A single EMA callback built from ``ema_decay`` when ``use_ema`` is
        set, otherwise an empty list.

    """
    if not self.use_ema:
        return []
    return [EMA(self.ema_decay)]
|