boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1286 @@
|
|
1
|
+
import gc
|
2
|
+
import random
|
3
|
+
from typing import Any, Optional
|
4
|
+
|
5
|
+
import torch
|
6
|
+
import torch._dynamo
|
7
|
+
from pytorch_lightning import LightningModule
|
8
|
+
from torch import Tensor, nn
|
9
|
+
from torchmetrics import MeanMetric
|
10
|
+
|
11
|
+
import boltz.model.layers.initialize as init
|
12
|
+
from boltz.data import const
|
13
|
+
from boltz.data.feature.symmetry import (
|
14
|
+
minimum_lddt_symmetry_coords,
|
15
|
+
minimum_symmetry_coords,
|
16
|
+
)
|
17
|
+
from boltz.model.loss.confidence import confidence_loss
|
18
|
+
from boltz.model.loss.distogram import distogram_loss
|
19
|
+
from boltz.model.loss.validation import (
|
20
|
+
compute_pae_mae,
|
21
|
+
compute_pde_mae,
|
22
|
+
compute_plddt_mae,
|
23
|
+
factored_lddt_loss,
|
24
|
+
factored_token_lddt_dist_loss,
|
25
|
+
weighted_minimum_rmsd,
|
26
|
+
)
|
27
|
+
from boltz.model.modules.confidence import ConfidenceModule
|
28
|
+
from boltz.model.modules.diffusion import AtomDiffusion
|
29
|
+
from boltz.model.modules.encoders import RelativePositionEncoder
|
30
|
+
from boltz.model.modules.trunk import (
|
31
|
+
DistogramModule,
|
32
|
+
InputEmbedder,
|
33
|
+
MSAModule,
|
34
|
+
PairformerModule,
|
35
|
+
)
|
36
|
+
from boltz.model.modules.utils import ExponentialMovingAverage
|
37
|
+
from boltz.model.optim.scheduler import AlphaFoldLRScheduler
|
38
|
+
|
39
|
+
|
40
|
+
class Boltz1(LightningModule):
|
41
|
+
"""Boltz1 model."""
|
42
|
+
|
43
|
+
def __init__(  # noqa: PLR0915, C901, PLR0912
    self,
    atom_s: int,
    atom_z: int,
    token_s: int,
    token_z: int,
    num_bins: int,
    training_args: dict[str, Any],
    validation_args: dict[str, Any],
    embedder_args: dict[str, Any],
    msa_args: dict[str, Any],
    pairformer_args: dict[str, Any],
    score_model_args: dict[str, Any],
    diffusion_process_args: dict[str, Any],
    diffusion_loss_args: dict[str, Any],
    confidence_model_args: dict[str, Any],
    atom_feature_dim: int = 128,
    confidence_prediction: bool = False,
    confidence_imitate_trunk: bool = False,
    alpha_pae: float = 0.0,
    structure_prediction_training: bool = True,
    atoms_per_window_queries: int = 32,
    atoms_per_window_keys: int = 128,
    compile_pairformer: bool = False,
    compile_structure: bool = False,
    compile_confidence: bool = False,
    nucleotide_rmsd_weight: float = 5.0,
    ligand_rmsd_weight: float = 10.0,
    no_msa: bool = False,
    no_atom_encoder: bool = False,
    ema: bool = False,
    ema_decay: float = 0.999,
    min_dist: float = 2.0,
    max_dist: float = 22.0,
    predict_args: Optional[dict[str, Any]] = None,
    steering_args: Optional[dict[str, Any]] = None,
    use_kernels: bool = False,
) -> None:
    """Build the metric containers and every sub-module of the model.

    The ``*_args`` dictionaries are forwarded as keyword arguments to the
    corresponding sub-module constructors. Flags handled directly here:

    - ``confidence_prediction``: also create the confidence metrics and
      the confidence module.
    - ``confidence_imitate_trunk``: hand the trunk configuration
      (pairformer, embedder and MSA arguments) to the confidence module.
    - ``no_msa``: skip creation of ``self.msa_module`` entirely.
    - ``compile_pairformer`` / ``compile_confidence``: wrap the module in
      ``torch.compile``; ``compile_structure`` is passed on to the
      structure module as ``compile_score``.
    - ``structure_prediction_training``: when False, every parameter
      outside ``confidence_module`` has ``requires_grad`` switched off.
    """
    super().__init__()

    self.save_hyperparameters()

    # Names of the metric containers, in creation order. The confidence
    # related ones only exist when confidence prediction is enabled.
    metric_dict_names = ["lddt", "disto_lddt", "complex_lddt"]
    if confidence_prediction:
        metric_dict_names += [
            "top1_lddt",
            "iplddt_top1_lddt",
            "ipde_top1_lddt",
            "pde_top1_lddt",
            "ptm_top1_lddt",
            "iptm_top1_lddt",
            "ligand_iptm_top1_lddt",
            "protein_iptm_top1_lddt",
            "avg_lddt",
            "plddt_mae",
            "pde_mae",
            "pae_mae",
        ]
    for dict_name in metric_dict_names:
        setattr(self, dict_name, nn.ModuleDict())

    # ``plddt_mae`` is keyed by the single types; all other containers
    # are keyed by the out types plus the pocket entry.
    per_type_names = [n for n in metric_dict_names if n != "plddt_mae"]
    for m in const.out_types + ["pocket_ligand_protein"]:
        for dict_name in per_type_names:
            getattr(self, dict_name)[m] = MeanMetric()
    for m in const.out_single_types:
        if confidence_prediction:
            self.plddt_mae[m] = MeanMetric()
    self.rmsd = MeanMetric()
    self.best_rmsd = MeanMetric()

    # Running means of the confidence loss and its components
    self.train_confidence_loss_logger = MeanMetric()
    self.train_confidence_loss_dict_logger = nn.ModuleDict()
    for m in ("plddt_loss", "resolved_loss", "pde_loss", "pae_loss"):
        self.train_confidence_loss_dict_logger[m] = MeanMetric()

    # EMA settings are only stored here
    self.ema = None
    self.use_ema = ema
    self.ema_decay = ema_decay

    # Plain configuration kept for the training / validation steps
    self.training_args = training_args
    self.validation_args = validation_args
    self.diffusion_loss_args = diffusion_loss_args
    self.predict_args = predict_args
    self.steering_args = steering_args

    self.use_kernels = use_kernels

    self.nucleotide_rmsd_weight = nucleotide_rmsd_weight
    self.ligand_rmsd_weight = ligand_rmsd_weight

    self.num_bins = num_bins
    self.min_dist = min_dist
    self.max_dist = max_dist
    self.is_pairformer_compiled = False

    # Input projections
    s_input_dim = (
        token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
    )
    self.s_init = nn.Linear(s_input_dim, token_s, bias=False)
    self.z_init_1 = nn.Linear(s_input_dim, token_z, bias=False)
    self.z_init_2 = nn.Linear(s_input_dim, token_z, bias=False)

    # Input embeddings; entries of ``embedder_args`` win on key clashes
    full_embedder_args = {
        "atom_s": atom_s,
        "atom_z": atom_z,
        "token_s": token_s,
        "token_z": token_z,
        "atoms_per_window_queries": atoms_per_window_queries,
        "atoms_per_window_keys": atoms_per_window_keys,
        "atom_feature_dim": atom_feature_dim,
        "no_atom_encoder": no_atom_encoder,
        **embedder_args,
    }
    self.input_embedder = InputEmbedder(**full_embedder_args)
    self.rel_pos = RelativePositionEncoder(token_z)
    self.token_bonds = nn.Linear(1, token_z, bias=False)

    # Normalization layers
    self.s_norm = nn.LayerNorm(token_s)
    self.z_norm = nn.LayerNorm(token_z)

    # Recycling projections
    self.s_recycle = nn.Linear(token_s, token_s, bias=False)
    self.z_recycle = nn.Linear(token_z, token_z, bias=False)
    for recycle_layer in (self.s_recycle, self.z_recycle):
        init.gating_init_(recycle_layer.weight)

    # Pairwise stack
    self.no_msa = no_msa
    if not no_msa:
        self.msa_module = MSAModule(
            token_z=token_z,
            s_input_dim=s_input_dim,
            **msa_args,
        )
    self.pairformer_module = PairformerModule(token_s, token_z, **pairformer_args)
    if compile_pairformer:
        # Big models hit the default cache limit (8)
        self.is_pairformer_compiled = True
        torch._dynamo.config.cache_size_limit = 512
        torch._dynamo.config.accumulated_cache_size_limit = 512
        self.pairformer_module = torch.compile(
            self.pairformer_module,
            dynamic=False,
            fullgraph=False,
        )

    # Output modules
    use_accumulate_token_repr = (
        confidence_prediction
        and "use_s_diffusion" in confidence_model_args
        and confidence_model_args["use_s_diffusion"]
    )
    # Entries of ``score_model_args`` win on key clashes
    full_score_model_args = {
        "token_z": token_z,
        "token_s": token_s,
        "atom_z": atom_z,
        "atom_s": atom_s,
        "atoms_per_window_queries": atoms_per_window_queries,
        "atoms_per_window_keys": atoms_per_window_keys,
        "atom_feature_dim": atom_feature_dim,
        **score_model_args,
    }
    self.structure_module = AtomDiffusion(
        score_model_args=full_score_model_args,
        compile_score=compile_structure,
        accumulate_token_repr=use_accumulate_token_repr,
        **diffusion_process_args,
    )
    self.distogram_module = DistogramModule(token_z, num_bins)
    self.confidence_prediction = confidence_prediction
    self.alpha_pae = alpha_pae

    self.structure_prediction_training = structure_prediction_training
    self.confidence_imitate_trunk = confidence_imitate_trunk
    if self.confidence_prediction:
        # Extra keyword arguments that are only supplied when the
        # confidence module should imitate the trunk.
        imitation_kwargs = {}
        if self.confidence_imitate_trunk:
            imitation_kwargs = {
                "imitate_trunk": True,
                "pairformer_args": pairformer_args,
                "full_embedder_args": full_embedder_args,
                "msa_args": msa_args,
            }
        self.confidence_module = ConfidenceModule(
            token_s,
            token_z,
            compute_pae=alpha_pae > 0,
            **imitation_kwargs,
            **confidence_model_args,
        )
        if compile_confidence:
            self.confidence_module = torch.compile(
                self.confidence_module, dynamic=False, fullgraph=False
            )

    # Remove grad from weights they are not trained for ddp
    if not structure_prediction_training:
        for name, param in self.named_parameters():
            owner = name.partition(".")[0]
            if owner != "confidence_module":
                param.requires_grad = False
|
263
|
+
|
264
|
+
def setup(self, stage: str) -> None:
|
265
|
+
"""Set the model for training, validation and inference."""
|
266
|
+
if stage == "predict" and not (
|
267
|
+
torch.cuda.is_available()
|
268
|
+
and torch.cuda.get_device_properties(torch.device("cuda")).major >= 8.0 # noqa: PLR2004
|
269
|
+
):
|
270
|
+
self.use_kernels = False
|
271
|
+
|
272
|
+
def forward(
    self,
    feats: dict[str, Tensor],
    recycling_steps: int = 0,
    num_sampling_steps: Optional[int] = None,
    multiplicity_diffusion_train: int = 1,
    diffusion_samples: int = 1,
    max_parallel_samples: Optional[int] = None,
    run_confidence_sequentially: bool = False,
) -> dict[str, Tensor]:
    """Run the trunk, then the structure and confidence heads.

    The trunk is executed ``recycling_steps + 1`` times; gradients are
    only enabled for the last pass, and only in training mode.

    Returns a dictionary that always holds ``"pdistogram"`` and is then
    extended with whatever the structure module (training with structure
    prediction), its ``sample`` method (evaluation or confidence
    prediction) and the confidence module return. The
    ``"diff_token_repr"`` entry is dropped again once the confidence
    module has consumed it.
    """
    # Compute input embeddings
    trunk_grad = self.training and self.structure_prediction_training
    with torch.set_grad_enabled(trunk_grad):
        s_inputs = self.input_embedder(feats)

        # Initialize the sequence and pairwise embeddings
        s_init = self.s_init(s_inputs)
        pair_rows = self.z_init_1(s_inputs)[:, :, None]
        pair_cols = self.z_init_2(s_inputs)[:, None, :]
        z_init = pair_rows + pair_cols
        relative_position_encoding = self.rel_pos(feats)
        z_init = z_init + relative_position_encoding
        bond_embedding = self.token_bonds(feats["token_bonds"].float())
        z_init = z_init + bond_embedding

        # Recycled representations start from zero
        s = torch.zeros_like(s_init)
        z = torch.zeros_like(z_init)

        # Compute pairwise mask
        mask = feats["token_pad_mask"].float()
        pair_mask = mask[:, :, None] * mask[:, None, :]

        for cycle in range(recycling_steps + 1):
            final_training_cycle = self.training and (cycle == recycling_steps)
            with torch.set_grad_enabled(final_training_cycle):
                # Fixes an issue with unused parameters in autocast
                if final_training_cycle and torch.is_autocast_enabled():
                    torch.clear_autocast_cache()

                # Apply recycling
                s = s_init + self.s_recycle(self.s_norm(s))
                z = z_init + self.z_recycle(self.z_norm(z))

                # Compute pairwise stack
                if not self.no_msa:
                    msa_update = self.msa_module(
                        z, s_inputs, feats, use_kernels=self.use_kernels
                    )
                    z = z + msa_update

                # Revert to uncompiled version for validation
                use_uncompiled = self.is_pairformer_compiled and not self.training
                pairformer_module = (
                    self.pairformer_module._orig_mod  # noqa: SLF001
                    if use_uncompiled
                    else self.pairformer_module
                )

                s, z = pairformer_module(
                    s,
                    z,
                    mask=mask,
                    pair_mask=pair_mask,
                    use_kernels=self.use_kernels,
                )

        dict_out = {"pdistogram": self.distogram_module(z)}

    # Compute structure module (training objective)
    if self.training and self.structure_prediction_training:
        training_out = self.structure_module(
            s_trunk=s,
            z_trunk=z,
            s_inputs=s_inputs,
            feats=feats,
            relative_position_encoding=relative_position_encoding,
            multiplicity=multiplicity_diffusion_train,
        )
        dict_out.update(training_out)

    # Sample structures for evaluation and / or the confidence head
    if (not self.training) or self.confidence_prediction:
        sampled_out = self.structure_module.sample(
            s_trunk=s,
            z_trunk=z,
            s_inputs=s_inputs,
            feats=feats,
            relative_position_encoding=relative_position_encoding,
            num_sampling_steps=num_sampling_steps,
            atom_mask=feats["atom_pad_mask"],
            multiplicity=diffusion_samples,
            max_parallel_samples=max_parallel_samples,
            train_accumulate_token_repr=self.training,
            steering_args=self.steering_args,
        )
        dict_out.update(sampled_out)

    if self.confidence_prediction:
        uses_s_diffusion = self.confidence_module.use_s_diffusion
        s_diffusion = dict_out["diff_token_repr"] if uses_s_diffusion else None
        # Trunk outputs are detached before entering the confidence head
        confidence_out = self.confidence_module(
            s_inputs=s_inputs.detach(),
            s=s.detach(),
            z=z.detach(),
            s_diffusion=s_diffusion,
            x_pred=dict_out["sample_atom_coords"].detach(),
            feats=feats,
            pred_distogram_logits=dict_out["pdistogram"].detach(),
            multiplicity=diffusion_samples,
            run_sequentially=run_confidence_sequentially,
            use_kernels=self.use_kernels,
        )
        dict_out.update(confidence_out)
        if self.confidence_module.use_s_diffusion:
            dict_out.pop("diff_token_repr", None)
    return dict_out
|
397
|
+
|
398
|
+
def get_true_coordinates(
    self,
    batch,
    out,
    diffusion_samples,
    symmetry_correction,
    lddt_minimization=True,
):
    """Return reference coordinates matched to every diffusion sample.

    Without ``symmetry_correction`` the coordinates and resolved mask
    from ``batch`` are repeated ``diffusion_samples`` times and the RMSDs
    come from ``weighted_minimum_rmsd``. With it, each sample in
    ``out["sample_atom_coords"]`` is passed on its own to the symmetry
    routine (the lDDT based one when ``lddt_minimization`` is set) and
    the per-sample results are concatenated.

    Returns the tuple
    ``(true_coords, rmsds, best_rmsds, true_coords_resolved_mask)``.
    In the symmetry branch ``rmsds`` and ``best_rmsds`` are Python lists,
    with one ``best_rmsds`` entry per batch element.
    """
    if not symmetry_correction:
        true_coords = (
            batch["coords"].squeeze(1).repeat_interleave(diffusion_samples, 0)
        )
        true_coords_resolved_mask = batch["atom_resolved_mask"].repeat_interleave(
            diffusion_samples, 0
        )
        rmsds, best_rmsds = weighted_minimum_rmsd(
            out["sample_atom_coords"],
            batch,
            multiplicity=diffusion_samples,
            nucleotide_weight=self.nucleotide_rmsd_weight,
            ligand_weight=self.ligand_rmsd_weight,
        )
        return true_coords, rmsds, best_rmsds, true_coords_resolved_mask

    if lddt_minimization:
        min_coords_routine = minimum_lddt_symmetry_coords
    else:
        min_coords_routine = minimum_symmetry_coords

    coords_per_sample = []
    masks_per_sample = []
    rmsds = []
    best_rmsds = []
    num_items = batch["token_index"].shape[0]
    for idx in range(num_items):
        best_rmsd = float("inf")
        # Samples of one batch element sit next to each other in ``out``
        first_sample = idx * diffusion_samples
        for i in range(first_sample, first_sample + diffusion_samples):
            sample_coords, rmsd, sample_mask = min_coords_routine(
                coords=out["sample_atom_coords"][i : i + 1],
                feats=batch,
                index_batch=idx,
                nucleotide_weight=self.nucleotide_rmsd_weight,
                ligand_weight=self.ligand_rmsd_weight,
            )
            rmsds.append(rmsd)
            coords_per_sample.append(sample_coords)
            masks_per_sample.append(sample_mask)
            if rmsd < best_rmsd:
                best_rmsd = rmsd
        best_rmsds.append(best_rmsd)

    true_coords = torch.cat(coords_per_sample, dim=0)
    true_coords_resolved_mask = torch.cat(masks_per_sample, dim=0)
    return true_coords, rmsds, best_rmsds, true_coords_resolved_mask
|
453
|
+
|
454
|
+
def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor:
    """Run one optimisation step and return the weighted total loss.

    A recycling depth is drawn uniformly between 0 and
    ``training_args.recycling_steps`` before the forward pass. The total
    loss is the weighted sum of the confidence, diffusion and distogram
    losses; components whose training is disabled contribute zero.

    Returns None (after printing a message) when the diffusion loss
    computation raises, so that the batch is skipped.
    """
    args = self.training_args

    # Sample recycling steps
    recycling_steps = random.randint(0, args.recycling_steps)

    # Compute the forward pass
    out = self(
        feats=batch,
        recycling_steps=recycling_steps,
        num_sampling_steps=args.sampling_steps,
        multiplicity_diffusion_train=args.diffusion_multiplicity,
        diffusion_samples=args.diffusion_samples,
    )

    # Structure losses
    if not self.structure_prediction_training:
        disto_loss = 0.0
        diffusion_loss_dict = {"loss": 0.0, "loss_breakdown": {}}
    else:
        disto_loss, _ = distogram_loss(
            out,
            batch,
        )
        try:
            diffusion_loss_dict = self.structure_module.compute_loss(
                batch,
                out,
                multiplicity=args.diffusion_multiplicity,
                **self.diffusion_loss_args,
            )
        except Exception as e:
            print(f"Skipping batch {batch_idx} due to error: {e}")
            return None

    # Confidence loss
    if not self.confidence_prediction:
        zero = torch.tensor(0.0).to(batch["token_index"].device)
        confidence_loss_dict = {
            "loss": zero,
            "loss_breakdown": {},
        }
    else:
        # confidence model symmetry correction
        true_coords, _, _, true_coords_resolved_mask = self.get_true_coordinates(
            batch,
            out,
            diffusion_samples=args.diffusion_samples,
            symmetry_correction=args.symmetry_correction,
        )

        confidence_loss_dict = confidence_loss(
            out,
            batch,
            true_coords,
            true_coords_resolved_mask,
            alpha_pae=self.alpha_pae,
            multiplicity=args.diffusion_samples,
        )

    # Aggregate losses
    loss = (
        args.confidence_loss_weight * confidence_loss_dict["loss"]
        + args.diffusion_loss_weight * diffusion_loss_dict["loss"]
        + args.distogram_loss_weight * disto_loss
    )

    # Log losses
    self.log("train/distogram_loss", disto_loss)
    self.log("train/diffusion_loss", diffusion_loss_dict["loss"])
    for name, value in diffusion_loss_dict["loss_breakdown"].items():
        self.log(f"train/{name}", value)

    if self.confidence_prediction:
        self.train_confidence_loss_logger.update(
            confidence_loss_dict["loss"].detach()
        )

        breakdown = confidence_loss_dict["loss_breakdown"]
        for key, metric in self.train_confidence_loss_dict_logger.items():
            component = breakdown[key]
            # Tensors are detached; plain numbers are passed through
            if torch.is_tensor(component):
                component = component.detach()
            metric.update(component)

    self.log("train/loss", loss)
    self.training_log()
    return loss
|
537
|
+
|
538
|
+
def training_log(self):
    """Log gradient / parameter norms and the current learning rate.

    Norms are logged for the whole model and then per sub-module, in the
    order MSA module, pairformer, structure module, confidence module.
    The MSA module is skipped when the model was built with ``no_msa``
    (the attribute does not exist in that case), and the confidence
    module is skipped when confidence prediction is disabled.
    """
    self.log("train/grad_norm", self.gradient_norm(self), prog_bar=False)
    self.log("train/param_norm", self.parameter_norm(self), prog_bar=False)

    # Learning rate of the first parameter group of the first optimizer
    lr = self.trainer.optimizers[0].param_groups[0]["lr"]
    self.log("lr", lr, prog_bar=False)

    # Only touch sub-modules that were actually constructed
    tracked = []
    if not self.no_msa:
        tracked.append(("msa_module", self.msa_module))
    tracked.append(("pairformer_module", self.pairformer_module))
    tracked.append(("structure_module", self.structure_module))
    if self.confidence_prediction:
        tracked.append(("confidence_module", self.confidence_module))

    for label, module in tracked:
        self.log(
            f"train/grad_norm_{label}",
            self.gradient_norm(module),
            prog_bar=False,
        )
        self.log(
            f"train/param_norm_{label}",
            self.parameter_norm(module),
            prog_bar=False,
        )
|
589
|
+
|
590
|
+
def on_train_epoch_end(self):
    """Log the accumulated confidence losses once per epoch.

    The total confidence loss metric is logged first, followed by every
    entry of ``train_confidence_loss_dict_logger`` under ``train/<key>``.
    """
    epoch_only = {"prog_bar": False, "on_step": False, "on_epoch": True}

    self.log(
        "train/confidence_loss",
        self.train_confidence_loss_logger,
        **epoch_only,
    )
    for key, metric in self.train_confidence_loss_dict_logger.items():
        self.log(f"train/{key}", metric, **epoch_only)
|
600
|
+
|
601
|
+
def gradient_norm(self, module) -> float:
    """Return the global L2 norm of the gradients of ``module``.

    Only parameters that are trainable and currently hold a gradient
    contribute. The result is a zero-dimensional tensor (zero when no
    parameter qualifies).
    """
    squared_norms = [
        param.grad.norm(p=2) ** 2
        for param in module.parameters()
        if param.requires_grad and param.grad is not None
    ]
    return torch.tensor(squared_norms).sum().sqrt()
|
607
|
+
|
608
|
+
def parameter_norm(self, module) -> float:
    """Return the global L2 norm of the trainable parameters of ``module``.

    Parameters with ``requires_grad`` switched off are ignored. The
    result is a zero-dimensional tensor (zero when nothing is trainable).
    """
    squared_norms = [
        param.norm(p=2) ** 2
        for param in module.parameters()
        if param.requires_grad
    ]
    return torch.tensor(squared_norms).sum().sqrt()
|
613
|
+
|
614
|
+
def validation_step(self, batch: dict[str, Tensor], batch_idx: int):
    """Run one validation batch and accumulate the validation metrics.

    The model is sampled ``validation_args.diffusion_samples`` times per
    input. LDDT is computed from the distogram head and from the sampled
    coordinates, then reduced over samples in several ways:

    * best sample per output type (``self.lddt``),
    * best sample by the total-weighted LDDT over all types
      (``self.complex_lddt``),
    * when confidence prediction is enabled and more than one sample was
      drawn, the sample ranked first by each confidence output
      (``self.*_top1_lddt``), plus the confidence MAE metrics.

    A ``RuntimeError`` whose message contains "out of memory" makes the batch
    be skipped (the method returns ``None``); any other error propagates.
    """
    val_args = self.validation_args
    n_samples = val_args.diffusion_samples

    # ---- forward pass ------------------------------------------------------
    try:
        out = self(
            batch,
            recycling_steps=val_args.recycling_steps,
            num_sampling_steps=val_args.sampling_steps,
            diffusion_samples=n_samples,
            run_confidence_sequentially=val_args.run_confidence_sequentially,
        )
    except RuntimeError as e:  # catch out of memory exceptions
        if "out of memory" not in str(e):
            raise e
        print("| WARNING: ran out of memory, skipping batch")
        torch.cuda.empty_cache()
        gc.collect()
        return

    # ---- LDDT from the distogram and from the sampled coordinates ----------
    try:
        logits = out["pdistogram"]

        # Bin edges: 63 points on [2, 22] padded with 1.0 below and 27.0
        # above; each bin is represented by the midpoint of its edges.
        bin_edges = torch.cat(
            (
                torch.tensor([1.0]),
                torch.linspace(2, 22.0, 63),
                torch.tensor([22.0 + 5.0]),
            )
        )
        bin_centers = ((bin_edges[:-1] + bin_edges[1:]) / 2).to(logits)

        # Predicted distance = midpoint of the most probable bin.
        top_bin = torch.softmax(logits, dim=-1).argmax(dim=-1)
        top_bin_one_hot = torch.nn.functional.one_hot(
            top_bin, num_classes=logits.shape[-1]
        )
        pred_dist = (top_bin_one_hot * bin_centers).sum(dim=-1)

        true_center = batch["disto_center"]
        true_dists = torch.cdist(true_center, true_center)

        # NOTE(review): self-assignment kept from the original; it looks like
        # a no-op for a plain dict batch.
        batch["token_disto_mask"] = batch["token_disto_mask"]
        disto_lddt_dict, disto_total_dict = factored_token_lddt_dist_loss(
            feats=batch,
            true_d=true_dists,
            pred_d=pred_dist,
        )

        true_coords, rmsds, best_rmsds, true_coords_resolved_mask = (
            self.get_true_coordinates(
                batch=batch,
                out=out,
                diffusion_samples=n_samples,
                symmetry_correction=val_args.symmetry_correction,
            )
        )

        all_lddt_dict, all_total_dict = factored_lddt_loss(
            feats=batch,
            atom_mask=true_coords_resolved_mask,
            true_atom_coords=true_coords,
            pred_atom_coords=out["sample_atom_coords"],
            multiplicity=n_samples,
        )
    except RuntimeError as e:  # catch out of memory exceptions
        if "out of memory" not in str(e):
            raise e
        print("| WARNING: ran out of memory, skipping batch")
        torch.cuda.empty_cache()
        gc.collect()
        return

    B = true_coords.shape[0] // n_samples

    def pick(per_sample, sample_idx):
        # View the flat values as (-1, n_samples) and keep one sample per row.
        return per_sample.reshape(-1, n_samples)[torch.arange(B), sample_idx]

    # ---- reduce over diffusion samples -------------------------------------
    # With several samples we keep the best LDDT among them; AF3 combines
    # this with the confidence-based filtering done further below.
    if n_samples > 1:
        # NOTE: we can change the way we aggregate the lddt
        weighted_lddt = 0
        weight_sum = 0
        for key in all_lddt_dict:
            weighted_lddt += all_lddt_dict[key] * all_total_dict[key]
            weight_sum += all_total_dict[key]
        weighted_lddt /= weight_sum + 1e-7
        best_complex_idx = weighted_lddt.reshape(-1, n_samples).argmax(dim=1)

        best_lddt_dict, best_total_dict = {}, {}
        best_complex_lddt_dict, best_complex_total_dict = {}, {}
        for key in all_lddt_dict:
            lddt = all_lddt_dict[key]
            total = all_total_dict[key]
            best_idx = lddt.reshape(-1, n_samples).argmax(dim=1)
            best_lddt_dict[key] = pick(lddt, best_idx)
            best_total_dict[key] = pick(total, best_idx)
            best_complex_lddt_dict[key] = pick(lddt, best_complex_idx)
            best_complex_total_dict[key] = pick(total, best_complex_idx)
    else:
        best_lddt_dict = best_complex_lddt_dict = all_lddt_dict
        best_total_dict = best_complex_total_dict = all_total_dict

    # ---- confidence-based ranking ------------------------------------------
    if self.confidence_prediction and n_samples > 1:
        mae_plddt_dict, total_mae_plddt_dict = compute_plddt_mae(
            pred_atom_coords=out["sample_atom_coords"],
            feats=batch,
            true_atom_coords=true_coords,
            pred_lddt=out["plddt"],
            true_coords_resolved_mask=true_coords_resolved_mask,
            multiplicity=n_samples,
        )
        mae_pde_dict, total_mae_pde_dict = compute_pde_mae(
            pred_atom_coords=out["sample_atom_coords"],
            feats=batch,
            true_atom_coords=true_coords,
            pred_pde=out["pde"],
            true_coords_resolved_mask=true_coords_resolved_mask,
            multiplicity=n_samples,
        )
        mae_pae_dict, total_mae_pae_dict = compute_pae_mae(
            pred_atom_coords=out["sample_atom_coords"],
            feats=batch,
            true_atom_coords=true_coords,
            pred_pae=out["pae"],
            true_coords_resolved_mask=true_coords_resolved_mask,
            multiplicity=n_samples,
        )

        def per_sample(name):
            return out[name].reshape(-1, n_samples)

        # (metric collection, index of the sample ranked first by that head).
        # The pde/ipde heads are minimised, every other head is maximised.
        ranked_first = (
            (self.top1_lddt, per_sample("complex_plddt").argmax(dim=1)),
            (self.iplddt_top1_lddt, per_sample("complex_iplddt").argmax(dim=1)),
            (self.pde_top1_lddt, per_sample("complex_pde").argmin(dim=1)),
            (self.ipde_top1_lddt, per_sample("complex_ipde").argmin(dim=1)),
            (self.ptm_top1_lddt, per_sample("ptm").argmax(dim=1)),
            (self.iptm_top1_lddt, per_sample("iptm").argmax(dim=1)),
            (self.ligand_iptm_top1_lddt, per_sample("ligand_iptm").argmax(dim=1)),
            (self.protein_iptm_top1_lddt, per_sample("protein_iptm").argmax(dim=1)),
        )

        for key in all_lddt_dict:
            lddt = all_lddt_dict[key]
            total = all_total_dict[key]
            for metrics, sample_idx in ranked_first:
                metrics[key].update(pick(lddt, sample_idx), pick(total, sample_idx))

            self.avg_lddt[key].update(lddt, total)
            self.pde_mae[key].update(mae_pde_dict[key], total_mae_pde_dict[key])
            self.pae_mae[key].update(mae_pae_dict[key], total_mae_pae_dict[key])

        for key in mae_plddt_dict:
            self.plddt_mae[key].update(
                mae_plddt_dict[key], total_mae_plddt_dict[key]
            )

    # ---- per-type accumulation ---------------------------------------------
    for m in const.out_types:
        target = m
        if m == "ligand_protein":
            # Ligand-protein scores go to a separate bucket when the batch
            # carries any pocket annotation.
            has_pocket = torch.any(
                batch["pocket_feature"][
                    :, :, const.pocket_contact_info["POCKET"]
                ].bool()
            )
            if has_pocket:
                target = "pocket_ligand_protein"

        self.lddt[target].update(best_lddt_dict[m], best_total_dict[m])
        self.disto_lddt[target].update(disto_lddt_dict[m], disto_total_dict[m])
        self.complex_lddt[target].update(
            best_complex_lddt_dict[m], best_complex_total_dict[m]
        )

    self.rmsd.update(rmsds)
    self.best_rmsd.update(best_rmsds)
|
875
|
+
|
876
|
+
def on_validation_epoch_end(self):
    """Compute, log and reset every validation metric at the end of an epoch.

    For each output type (``const.out_types`` plus ``pocket_ligand_protein``)
    the accumulated LDDT metrics are computed, NaN results are replaced by
    ``0.0``, the metric is reset and the value is logged under ``val/...``.
    With confidence prediction enabled, the confidence-ranked LDDT metrics and
    the pde / pae / plddt MAE metrics are handled as well (MAE values are
    logged without NaN replacement).

    Overall scores are weighted means over ``const.out_types_weights`` and are
    shown on the progress bar, as are ``val/rmsd`` and ``val/best_rmsd``.
    """

    def flush(metric, name):
        # compute -> NaN becomes 0.0 -> reset -> log; returns a Python float
        value = metric.compute()
        value = 0.0 if torch.isnan(value) else value.item()
        metric.reset()
        self.log(name, value, prog_bar=False, sync_dist=True)
        return value

    def flush_mae(metric, name):
        # Same as ``flush`` but without the NaN replacement.
        value = metric.compute().item()
        metric.reset()
        self.log(name, value, prog_bar=False, sync_dist=True)

    def log_overall(name, per_type):
        # Weighted mean of the per-type averages, shown on the progress bar.
        weights = const.out_types_weights
        overall = sum(per_type[m] * w for (m, w) in weights.items()) / sum(
            weights.values()
        )
        self.log(name, overall, prog_bar=True, sync_dist=True)

    with_confidence = self.confidence_prediction

    avg_lddt = {}
    avg_disto_lddt = {}
    avg_complex_lddt = {}
    if with_confidence:
        avg_top1_lddt = {}
        avg_iplddt_top1_lddt = {}
        avg_pde_top1_lddt = {}
        avg_ipde_top1_lddt = {}
        avg_ptm_top1_lddt = {}
        avg_iptm_top1_lddt = {}
        avg_avg_lddt = {}

    for m in const.out_types + ["pocket_ligand_protein"]:
        avg_lddt[m] = flush(self.lddt[m], f"val/lddt_{m}")
        avg_disto_lddt[m] = flush(self.disto_lddt[m], f"val/disto_lddt_{m}")
        avg_complex_lddt[m] = flush(self.complex_lddt[m], f"val/complex_lddt_{m}")

        if not with_confidence:
            continue

        avg_top1_lddt[m] = flush(self.top1_lddt[m], f"val/top1_lddt_{m}")
        avg_iplddt_top1_lddt[m] = flush(
            self.iplddt_top1_lddt[m], f"val/iplddt_top1_lddt_{m}"
        )
        avg_pde_top1_lddt[m] = flush(
            self.pde_top1_lddt[m], f"val/pde_top1_lddt_{m}"
        )
        avg_ipde_top1_lddt[m] = flush(
            self.ipde_top1_lddt[m], f"val/ipde_top1_lddt_{m}"
        )
        avg_ptm_top1_lddt[m] = flush(
            self.ptm_top1_lddt[m], f"val/ptm_top1_lddt_{m}"
        )
        avg_iptm_top1_lddt[m] = flush(
            self.iptm_top1_lddt[m], f"val/iptm_top1_lddt_{m}"
        )
        # These two are logged per type only; no overall score is derived.
        flush(self.ligand_iptm_top1_lddt[m], f"val/ligand_iptm_top1_lddt_{m}")
        flush(self.protein_iptm_top1_lddt[m], f"val/protein_iptm_top1_lddt_{m}")

        avg_avg_lddt[m] = flush(self.avg_lddt[m], f"val/avg_lddt_{m}")
        flush_mae(self.pde_mae[m], f"val/MAE_pde_{m}")
        flush_mae(self.pae_mae[m], f"val/MAE_pae_{m}")

    if with_confidence:
        for m in const.out_single_types:
            flush_mae(self.plddt_mae[m], f"val/MAE_plddt_{m}")

    log_overall("val/disto_lddt", avg_disto_lddt)
    log_overall("val/lddt", avg_lddt)
    log_overall("val/complex_lddt", avg_complex_lddt)

    if with_confidence:
        log_overall("val/top1_lddt", avg_top1_lddt)
        log_overall("val/iplddt_top1_lddt", avg_iplddt_top1_lddt)
        log_overall("val/pde_top1_lddt", avg_pde_top1_lddt)
        log_overall("val/ipde_top1_lddt", avg_ipde_top1_lddt)
        log_overall("val/ptm_top1_lddt", avg_ptm_top1_lddt)
        log_overall("val/iptm_top1_lddt", avg_iptm_top1_lddt)
        log_overall("val/avg_lddt", avg_avg_lddt)

    self.log("val/rmsd", self.rmsd.compute(), prog_bar=True, sync_dist=True)
    self.rmsd.reset()

    self.log(
        "val/best_rmsd", self.best_rmsd.compute(), prog_bar=True, sync_dist=True
    )
    self.best_rmsd.reset()
|
1148
|
+
|
1149
|
+
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
    """Run inference on one batch and collect the outputs to be written.

    Returns a dict with ``"exception": False``, the atom padding mask
    (``"masks"``) and the sampled coordinates (``"coords"``). Depending on
    ``self.predict_args`` it also contains the confidence summary
    (``write_confidence_summary``, default on), the full PAE
    (``write_full_pae``, default on) and the full PDE (``write_full_pde``,
    default off).

    ``confidence_score`` is ``(4 * complex_plddt + iptm) / 5``; ``ptm`` is
    used in place of ``iptm`` when ``iptm`` is all-close to zero.

    A ``RuntimeError`` mentioning "out of memory" is turned into
    ``{"exception": True}``; any other ``RuntimeError`` is re-raised.
    """
    try:
        args = self.predict_args
        out = self(
            batch,
            recycling_steps=args["recycling_steps"],
            num_sampling_steps=args["sampling_steps"],
            diffusion_samples=args["diffusion_samples"],
            max_parallel_samples=args["diffusion_samples"],
            run_confidence_sequentially=True,
        )

        pred_dict = {
            "exception": False,
            "masks": batch["atom_pad_mask"],
            "coords": out["sample_atom_coords"],
        }

        if args.get("write_confidence_summary", True):
            iptm_is_zero = torch.allclose(
                out["iptm"], torch.zeros_like(out["iptm"])
            )
            tm_term = out["ptm"] if iptm_is_zero else out["iptm"]
            pred_dict["confidence_score"] = (4 * out["complex_plddt"] + tm_term) / 5

            summary_keys = (
                "ptm",
                "iptm",
                "ligand_iptm",
                "protein_iptm",
                "pair_chains_iptm",
                "complex_plddt",
                "complex_iplddt",
                "complex_pde",
                "complex_ipde",
                "plddt",
            )
            pred_dict.update({key: out[key] for key in summary_keys})

        if args.get("write_full_pae", True):
            pred_dict["pae"] = out["pae"]
        if args.get("write_full_pde", False):
            pred_dict["pde"] = out["pde"]
        return pred_dict

    except RuntimeError as e:  # catch out of memory exceptions
        if "out of memory" not in str(e):
            raise
        print("| WARNING: ran out of memory, skipping batch")
        torch.cuda.empty_cache()
        gc.collect()
        return {"exception": True}
|
1200
|
+
|
1201
|
+
def configure_optimizers(self):
    """Build the Adam optimizer and, optionally, the "af3" LR scheduler.

    When ``self.structure_prediction_training`` is set, every trainable
    parameter of the model is optimized. Otherwise only the trainable
    parameters of ``self.confidence_module`` and of
    ``self.structure_module.out_token_feat_update`` are, in that order.

    Returns the optimizer alone, or ``([optimizer], [scheduler_config])``
    with a per-step ``AlphaFoldLRScheduler`` when
    ``training_args.lr_scheduler == "af3"``.
    """
    if self.structure_prediction_training:
        trained_modules = (self,)
    else:
        trained_modules = (
            self.confidence_module,
            self.structure_module.out_token_feat_update,
        )

    parameters = []
    for module in trained_modules:
        parameters.extend(p for p in module.parameters() if p.requires_grad)

    args = self.training_args
    optimizer = torch.optim.Adam(
        parameters,
        betas=(args.adam_beta_1, args.adam_beta_2),
        eps=args.adam_eps,
        lr=args.base_lr,
    )

    if args.lr_scheduler != "af3":
        return optimizer

    scheduler = AlphaFoldLRScheduler(
        optimizer,
        base_lr=args.base_lr,
        max_lr=args.max_lr,
        warmup_no_steps=args.lr_warmup_no_steps,
        start_decay_after_n_steps=args.lr_start_decay_after_n_steps,
        decay_every_n_steps=args.lr_decay_every_n_steps,
        decay_factor=args.lr_decay_factor,
    )
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
|
1234
|
+
|
1235
|
+
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Store the EMA state under ``checkpoint["ema"]`` when EMA is enabled."""
    if not self.use_ema:
        return
    checkpoint["ema"] = self.ema.state_dict()
|
1238
|
+
|
1239
|
+
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Restore the EMA state from ``checkpoint["ema"]`` when possible.

    Nothing happens unless EMA is enabled and the checkpoint has an ``"ema"``
    entry. A fresh ``ExponentialMovingAverage`` is built over the current
    parameters; if the stored shadow parameters are not compatible with it,
    ``self.ema`` is set to ``None`` and a warning is printed, otherwise the
    stored state is loaded onto the CPU.
    """
    if not (self.use_ema and "ema" in checkpoint):
        return

    self.ema = ExponentialMovingAverage(
        parameters=self.parameters(), decay=self.ema_decay
    )
    if not self.ema.compatible(checkpoint["ema"]["shadow_params"]):
        self.ema = None
        print(
            "Warning: EMA state not loaded due to incompatible model parameters."
        )
        return

    self.ema.load_state_dict(checkpoint["ema"], device=torch.device("cpu"))
|
1251
|
+
|
1252
|
+
def on_train_start(self):
    """Make sure the EMA helper exists and sits on the module's device.

    With EMA enabled, a missing ``self.ema`` is created from the current
    parameters; an existing one is moved to ``self.device``.
    """
    if not self.use_ema:
        return

    if self.ema is None:
        self.ema = ExponentialMovingAverage(
            parameters=self.parameters(), decay=self.ema_decay
        )
    else:
        self.ema.to(self.device)
|
1259
|
+
|
1260
|
+
def on_train_epoch_start(self) -> None:
    """Call ``self.ema.restore`` on the model parameters when EMA is enabled."""
    if not self.use_ema:
        return
    self.ema.restore(self.parameters())
|
1263
|
+
|
1264
|
+
def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None:
    """Update the EMA parameters after each training batch.

    This runs after ``optimizer.step()``; it does nothing when EMA is
    disabled.
    """
    if not self.use_ema:
        return
    self.ema.update(self.parameters())
|
1268
|
+
|
1269
|
+
def prepare_eval(self) -> None:
    """Switch the model to its EMA parameters for evaluation.

    With EMA enabled, a missing ``self.ema`` is created first; then
    ``self.ema.store`` and ``self.ema.copy_to`` are called, in that order,
    on the model parameters. Does nothing when EMA is disabled.
    """
    if not self.use_ema:
        return

    if self.ema is None:
        self.ema = ExponentialMovingAverage(
            parameters=self.parameters(), decay=self.ema_decay
        )

    self.ema.store(self.parameters())
    self.ema.copy_to(self.parameters())
|
1278
|
+
|
1279
|
+
def on_validation_start(self):
    """Run ``prepare_eval`` when validation starts."""
    self.prepare_eval()
|
1281
|
+
|
1282
|
+
def on_predict_start(self) -> None:
    """Run ``prepare_eval`` when prediction starts."""
    self.prepare_eval()
|
1284
|
+
|
1285
|
+
def on_test_start(self) -> None:
    """Run ``prepare_eval`` when testing starts."""
    self.prepare_eval()
|