boltz_vsynthes-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/modules/utils.py
ADDED
@@ -0,0 +1,303 @@
+# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch.nn import (
+    Linear,
+    Module,
+)
+from torch.types import Device
+
+LinearNoBias = partial(Linear, bias=False)
+
+
+def exists(v):
+    return v is not None
+
+
+def default(v, d):
+    return v if exists(v) else d
+
+
+def log(t, eps=1e-20):
+    return torch.log(t.clamp(min=eps))
+
+
+class SwiGLU(Module):
+    def forward(
+        self,
+        x,  #: Float['... d']
+    ):  # -> Float[' ... (d//2)']:
+        x, gates = x.chunk(2, dim=-1)
+        return F.silu(gates) * x
+
+
+def center(atom_coords, atom_mask):
+    atom_mean = torch.sum(
+        atom_coords * atom_mask[:, :, None], dim=1, keepdim=True
+    ) / torch.sum(atom_mask[:, :, None], dim=1, keepdim=True)
+    atom_coords = atom_coords - atom_mean
+    return atom_coords
+
+
+def compute_random_augmentation(
+    multiplicity, s_trans=1.0, device=None, dtype=torch.float32
+):
+    R = random_rotations(multiplicity, dtype=dtype, device=device)
+    random_trans = (
+        torch.randn((multiplicity, 1, 3), dtype=dtype, device=device) * s_trans
+    )
+    return R, random_trans
+
+
+def randomly_rotate(coords, return_second_coords=False, second_coords=None):
+    R = random_rotations(len(coords), coords.dtype, coords.device)
+
+    if return_second_coords:
+        return torch.einsum("bmd,bds->bms", coords, R), torch.einsum(
+            "bmd,bds->bms", second_coords, R
+        ) if second_coords is not None else None
+
+    return torch.einsum("bmd,bds->bms", coords, R)
+
+
+def center_random_augmentation(
+    atom_coords,
+    atom_mask,
+    s_trans=1.0,
+    augmentation=True,
+    centering=True,
+    return_second_coords=False,
+    second_coords=None,
+):
+    """Algorithm 19"""
+    if centering:
+        atom_mean = torch.sum(
+            atom_coords * atom_mask[:, :, None], dim=1, keepdim=True
+        ) / torch.sum(atom_mask[:, :, None], dim=1, keepdim=True)
+        atom_coords = atom_coords - atom_mean
+
+        if second_coords is not None:
+            # apply same transformation also to this input
+            second_coords = second_coords - atom_mean
+
+    if augmentation:
+        atom_coords, second_coords = randomly_rotate(
+            atom_coords, return_second_coords=True, second_coords=second_coords
+        )
+        random_trans = torch.randn_like(atom_coords[:, 0:1, :]) * s_trans
+        atom_coords = atom_coords + random_trans
+
+        if second_coords is not None:
+            second_coords = second_coords + random_trans
+
+    if return_second_coords:
+        return atom_coords, second_coords
+
+    return atom_coords
+
+
+class ExponentialMovingAverage:
+    """from https://github.com/yang-song/score_sde_pytorch/blob/main/models/ema.py, Apache-2.0 license
+    Maintains (exponential) moving average of a set of parameters."""
+
+    def __init__(self, parameters, decay, use_num_updates=True):
+        """
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; usually the result of
+            `model.parameters()`.
+          decay: The exponential decay.
+          use_num_updates: Whether to use number of updates when computing
+            averages.
+        """
+        if decay < 0.0 or decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+        self.decay = decay
+        self.num_updates = 0 if use_num_updates else None
+        self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
+        self.collected_params = []
+
+    def update(self, parameters):
+        """
+        Update currently maintained parameters.
+        Call this every time the parameters are updated, such as the result of
+        the `optimizer.step()` call.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+            parameters used to initialize this object.
+        """
+        decay = self.decay
+        if self.num_updates is not None:
+            self.num_updates += 1
+            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
+        one_minus_decay = 1.0 - decay
+        with torch.no_grad():
+            parameters = [p for p in parameters if p.requires_grad]
+            for s_param, param in zip(self.shadow_params, parameters):
+                s_param.sub_(one_minus_decay * (s_param - param))
+
+    def compatible(self, parameters):
+        if len(self.shadow_params) != len(parameters):
+            print(
+                f"Model has {len(self.shadow_params)} parameter tensors, the incoming ema {len(parameters)}"
+            )
+            return False
+
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.data.shape != s_param.data.shape:
+                print(
+                    f"Model has parameter tensor of shape {s_param.data.shape} , the incoming ema {param.data.shape}"
+                )
+                return False
+        return True
+
+    def copy_to(self, parameters):
+        """
+        Copy current parameters into given collection of parameters.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            updated with the stored moving averages.
+        """
+        parameters = [p for p in parameters if p.requires_grad]
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                param.data.copy_(s_param.data)
+
+    def store(self, parameters):
+        """
+        Save the current parameters for restoring later.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            temporarily stored.
+        """
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters):
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            updated with the stored parameters.
+        """
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
+
+    def state_dict(self):
+        return dict(
+            decay=self.decay,
+            num_updates=self.num_updates,
+            shadow_params=self.shadow_params,
+        )
+
+    def load_state_dict(self, state_dict, device):
+        self.decay = state_dict["decay"]
+        self.num_updates = state_dict["num_updates"]
+        self.shadow_params = [
+            tensor.to(device) for tensor in state_dict["shadow_params"]
+        ]
+
+    def to(self, device):
+        self.shadow_params = [tensor.to(device) for tensor in self.shadow_params]
+
+
+# the following is copied from Torch3D, BSD License, Copyright (c) Meta Platforms, Inc. and affiliates.
+
+
+def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+    """
+    Return a tensor where each element has the absolute value taken from the,
+    corresponding element of a, with sign taken from the corresponding
+    element of b. This is like the standard copysign floating-point operation,
+    but is not careful about negative 0 and NaN.
+
+    Args:
+        a: source tensor.
+        b: tensor whose signs will be used, of the same shape as a.
+
+    Returns:
+        Tensor of the same shape as a with the signs of b.
+    """
+    signs_differ = (a < 0) != (b < 0)
+    return torch.where(signs_differ, -a, a)
+
+
+def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+    """
+    Convert rotations given as quaternions to rotation matrices.
+
+    Args:
+        quaternions: quaternions with real part first,
+            as tensor of shape (..., 4).
+
+    Returns:
+        Rotation matrices as tensor of shape (..., 3, 3).
+    """
+    r, i, j, k = torch.unbind(quaternions, -1)
+    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+    two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+    o = torch.stack(
+        (
+            1 - two_s * (j * j + k * k),
+            two_s * (i * j - k * r),
+            two_s * (i * k + j * r),
+            two_s * (i * j + k * r),
+            1 - two_s * (i * i + k * k),
+            two_s * (j * k - i * r),
+            two_s * (i * k - j * r),
+            two_s * (j * k + i * r),
+            1 - two_s * (i * i + j * j),
+        ),
+        -1,
+    )
+    return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def random_quaternions(
+    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+) -> torch.Tensor:
+    """
+    Generate random quaternions representing rotations,
+    i.e. versors with nonnegative real part.
+
+    Args:
+        n: Number of quaternions in a batch to return.
+        dtype: Type to return.
+        device: Desired device of returned tensor. Default:
+            uses the current device for the default tensor type.
+
+    Returns:
+        Quaternions as tensor of shape (N, 4).
+    """
+    if isinstance(device, str):
+        device = torch.device(device)
+    o = torch.randn((n, 4), dtype=dtype, device=device)
+    s = (o * o).sum(1)
+    o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
+    return o
+
+
+def random_rotations(
+    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+) -> torch.Tensor:
+    """
+    Generate random rotations as 3x3 rotation matrices.
+
+    Args:
+        n: Number of rotation matrices in a batch to return.
+        dtype: Type to return.
+        device: Device of returned tensor. Default: if None,
+            uses the current device for the default tensor type.
+
+    Returns:
+        Rotation matrices as tensor of shape (n, 3, 3).
+    """
+    quaternions = random_quaternions(n, dtype=dtype, device=device)
+    return quaternion_to_matrix(quaternions)
boltz/model/optim/__init__.py
File without changes
boltz/model/optim/ema.py
ADDED
@@ -0,0 +1,389 @@
+# --------------------------------------------------------------------------------------
+# Modified from Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion):
+# Modified from : https://github.com/BioinfoMachineLearning/bio-diffusion/blob/main/src/utils/__init__.py
+# --------------------------------------------------------------------------------------
+
+from typing import Any, Optional
+
+import torch
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+
+
+class EMA(Callback):
+    """Implements Exponential Moving Averaging (EMA).
+
+    When training a model, this callback maintains moving averages
+    of the trained parameters. When evaluating, we use the moving
+    averages copy of the trained parameters. When saving, we save
+    an additional set of parameters with the prefix `ema`.
+
+    Adapted from:
+    https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
+    https://github.com/BioinfoMachineLearning/bio-diffusion/blob/main/src/utils/__init__.py
+
+    """
+
+    def __init__(
+        self,
+        decay: float = 0.999,
+        apply_ema_every_n_steps: int = 1,
+        start_step: int = 0,
+        eval_with_ema: bool = True,
+        warm_start: bool = True,
+    ) -> None:
+        """Initialize the EMA callback.
+
+        Parameters
+        ----------
+        decay: float
+            The exponential decay, has to be between 0-1.
+        apply_ema_every_n_steps: int, optional (default=1)
+            Apply EMA every n global steps.
+        start_step: int, optional (default=0)
+            Start applying EMA from ``start_step`` global step onwards.
+        eval_with_ema: bool, optional (default=True)
+            Validate the EMA weights instead of the original weights.
+            Note this means that when saving the model, the
+            validation metrics are calculated with the EMA weights.
+
+        """
+        if not (0 <= decay <= 1):
+            msg = "EMA decay value must be between 0 and 1"
+            raise MisconfigurationException(msg)
+
+        self._ema_weights: Optional[dict[str, torch.Tensor]] = None
+        self._cur_step: Optional[int] = None
+        self._weights_buffer: Optional[dict[str, torch.Tensor]] = None
+        self.apply_ema_every_n_steps = apply_ema_every_n_steps
+        self.start_step = start_step
+        self.eval_with_ema = eval_with_ema
+        self.decay = decay
+        self.warm_start = warm_start
+
+    @property
+    def ema_initialized(self) -> bool:
+        """Check if EMA weights have been initialized.
+
+        Returns
+        -------
+        bool
+            Whether the EMA weights have been initialized.
+
+        """
+        return self._ema_weights is not None
+
+    def state_dict(self) -> dict[str, Any]:
+        """Return the current state of the callback.
+
+        Returns
+        -------
+        dict[str, Any]
+            The current state of the callback.
+
+        """
+        return {
+            "cur_step": self._cur_step,
+            "ema_weights": self._ema_weights,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        """Load the state of the callback.
+
+        Parameters
+        ----------
+        state_dict: dict[str, Any]
+            The state of the callback to load.
+
+        """
+        self._cur_step = state_dict["cur_step"]
+        self._ema_weights = state_dict["ema_weights"]
+
+    def should_apply_ema(self, step: int) -> bool:
+        """Check if EMA should be applied at the current step.
+
+        Parameters
+        ----------
+        step: int
+            The current global step.
+
+        Returns
+        -------
+        bool
+            True if EMA should be applied, False otherwise.
+
+        """
+        return (
+            step != self._cur_step
+            and step >= self.start_step
+            and step % self.apply_ema_every_n_steps == 0
+        )
+
+    def apply_ema(self, pl_module: LightningModule) -> None:
+        """Apply EMA to the model weights.
+
+        Parameters
+        ----------
+        pl_module: LightningModule
+            The LightningModule instance.
+
+        """
+        decay = self.decay
+        if self.warm_start:
+            decay = min(decay, (1 + self._cur_step) / (10 + self._cur_step))
+
+        for k, orig_weight in pl_module.state_dict().items():
+            ema_weight = self._ema_weights[k]
+            if (
+                ema_weight.data.dtype != torch.long  # noqa: PLR1714
+                and orig_weight.data.dtype != torch.long  # skip non-trainable weights
+            ):
+                diff = ema_weight.data - orig_weight.data
+                diff.mul_(1.0 - decay)
+                ema_weight.sub_(diff)
+
+    def on_load_checkpoint(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,  # noqa: ARG002
+        checkpoint: dict[str, Any],
+    ) -> None:
+        """Load the EMA weights from the checkpoint.
+
+        Parameters
+        ----------
+        trainer: Trainer
+            The Trainer instance.
+        pl_module: LightningModule
+            The LightningModule instance.
+        checkpoint: dict[str, Any]
+            The checkpoint to load.
+
+        """
+        if "ema" in checkpoint:
+            print("LOADING CHECKPOINT RUNNING")
+            self.load_state_dict(checkpoint["ema"])
+
+    def on_save_checkpoint(
+        self,
+        trainer: Trainer,  # noqa: ARG002
+        pl_module: LightningModule,  # noqa: ARG002
+        checkpoint: dict[str, Any],
+    ) -> None:
+        """Save the EMA weights to the checkpoint.
+
+        Parameters
+        ----------
+        trainer: Trainer
+            The Trainer instance.
+        pl_module: LightningModule
+            The LightningModule instance.
+        checkpoint: dict[str, Any]
+            The checkpoint to save.
+
+        """
+        if self.ema_initialized:
+            checkpoint["ema"] = self.state_dict()
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # noqa: ARG002
+        """Initialize EMA weights and move to device.
+
+        Parameters
+        ----------
+        trainer: pl.Trainer
+            The Trainer instance.
+        pl_module: pl.LightningModule
+            The LightningModule instance.
+
+        """
+        # Create EMA weights if not already initialized
+        if not self.ema_initialized:
+            self._ema_weights = {
+                k: p.detach().clone() for k, p in pl_module.state_dict().items()
+            }
+
+        # Move EMA weights to the correct device
+        self._ema_weights = {
+            k: p.to(pl_module.device) for k, p in self._ema_weights.items()
+        }
+
+    def on_train_batch_end(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        outputs: STEP_OUTPUT,  # noqa: ARG002
+        batch: Any,  # noqa: ARG002
+        batch_idx: int,  # noqa: ARG002
+    ) -> None:
+        """Apply EMA to the model weights at the end of each training batch.
+
+        Parameters
+        ----------
+        trainer: Trainer
+            The Trainer instance.
+        pl_module: LightningModule
+            The LightningModule instance.
+        outputs: STEP_OUTPUT
+            The outputs of the model.
+        batch: Any
+            The current batch.
+        batch_idx: int
+            The index of the current batch.
+
+        """
+        if self.should_apply_ema(trainer.global_step):
+            self._cur_step = trainer.global_step
+            self.apply_ema(pl_module)
+
+    def replace_model_weights(self, pl_module: LightningModule) -> None:
+        """Replace model weights with EMA weights.
+
+        Parameters
+        ----------
+        pl_module: LightningModule
+            The LightningModule instance.
+
+        """
+        self._weights_buffer = {
+            k: p.detach().clone().to("cpu") for k, p in pl_module.state_dict().items()
+        }
+        pl_module.load_state_dict(self._ema_weights, strict=False)
+
+    def restore_original_weights(self, pl_module: LightningModule) -> None:
+        """Restore model weights to original weights.
+
+        Parameters
+        ----------
+        pl_module: LightningModule
+            The LightningModule instance.
+
+        """
+        pl_module.load_state_dict(self._weights_buffer, strict=False)
+        del self._weights_buffer
+
+    def _on_eval_start(self, pl_module: LightningModule) -> None:
+        """Use EMA weights for evaluation.
+
+        Parameters
+        ----------
+        pl_module: LightningModule
+            The LightningModule instance.
+
+        """
+        if self.ema_initialized and self.eval_with_ema:
+            self.replace_model_weights(pl_module)
+
+    def _on_eval_end(self, pl_module: LightningModule) -> None:
+        """Restore original weights after evaluation.
+
+        Parameters
+        ----------
+        pl_module: LightningModule
+            The LightningModule instance.
+
+        """
+        if self.ema_initialized and self.eval_with_ema:
+            self.restore_original_weights(pl_module)
+
+    def on_validation_start(
+        self,
+        trainer: Trainer,  # noqa: ARG002
+        pl_module: LightningModule,
+    ) -> None:
+        """Use EMA weights for validation.
+
+        Parameters
+        ----------
+        trainer: Trainer
+            The Trainer instance.
+        pl_module: LightningModule
+            The LightningModule instance.
+
+        """
+        self._on_eval_start(pl_module)
+
+    def on_validation_end(
+        self,
+        trainer: Trainer,  # noqa: ARG002
+        pl_module: LightningModule,
+    ) -> None:
+        """Restore original weights after validation.
+
+        Parameters
+        ----------
+        trainer: Trainer
+            The Trainer instance.
+        pl_module: LightningModule
+            The LightningModule instance.
+
+        """
+        self._on_eval_end(pl_module)
+
+    def on_test_start(
+        self,
+        trainer: Trainer,  # noqa: ARG002
+        pl_module: LightningModule,
+    ) -> None:
+        """Use EMA weights for testing.
+
+        Parameters
+        ----------
+        trainer: Trainer
+            The Trainer instance.
+        pl_module: LightningModule
+            The LightningModule instance.
+
+        """
+        self._on_eval_start(pl_module)
+
+    def on_test_end(
+        self,
+        trainer: Trainer,  # noqa: ARG002
+        pl_module: LightningModule,
+    ) -> None:
+        """Restore original weights after testing.
+
+        Parameters
+        ----------
+        trainer: Trainer
+            The Trainer instance.
+        pl_module: LightningModule
+            The LightningModule instance.
+
+        """
+        self._on_eval_end(pl_module)
+
+    def on_predict_start(
+        self,
+        trainer: Trainer,  # noqa: ARG002
+        pl_module: LightningModule,
+    ) -> None:
+        """Use EMA weights for prediction.
+
+        Parameters
+        ----------
+        trainer: Trainer
+            The Trainer instance.
+        pl_module: LightningModule
+            The LightningModule instance.
+
+        """
+        self._on_eval_start(pl_module)
+
+    def on_predict_end(
+        self,
+        trainer: Trainer,  # noqa: ARG002
+        pl_module: LightningModule,
+    ) -> None:
+        """Restore original weights after prediction.
+
+        Parameters
+        ----------
+        trainer: Trainer
+            The Trainer instance.
+        pl_module: LightningModule
+            The LightningModule instance.
+
+        """
+        self._on_eval_end(pl_module)
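
The EMA callback above follows the standard shadow-weight pattern: it copies pl_module.state_dict() at on_train_start, blends ema <- ema - (1 - decay) * (ema - weights) after each training batch (equivalently decay * ema + (1 - decay) * weights, with a warm-start ramp on decay), and swaps the averaged weights in for validation, test, and predict when eval_with_ema=True. A minimal wiring sketch follows; it is not part of the package, requires pytorch-lightning plus the installed wheel, and the toy LightningModule, dataset, and hyperparameters are illustrative assumptions.

# Usage sketch (illustrative only, not shipped in the wheel).
import torch
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, TensorDataset

from boltz.model.optim.ema import EMA


class ToyRegressor(LightningModule):  # toy module, stand-in for a real Boltz model
    def __init__(self) -> None:
        super().__init__()
        self.net = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
loader = DataLoader(data, batch_size=8)

# decay=0.999 with warm_start=True ramps the effective decay up over the first
# optimizer steps; eval_with_ema=True (default) swaps EMA weights in for eval loops.
ema_callback = EMA(decay=0.999, apply_ema_every_n_steps=1, start_step=0)

trainer = Trainer(max_epochs=1, callbacks=[ema_callback], logger=False,
                  enable_checkpointing=False)
trainer.fit(ToyRegressor(), train_dataloaders=loader)

Because the EMA state is written into the checkpoint under the "ema" key by on_save_checkpoint, resuming training from a checkpoint restores both the shadow weights and the step counter.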