boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
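
All of the listed modules sit under the single top-level package `boltz` (plus the dist-info metadata), so the paths above map directly onto import paths once the wheel is installed. A minimal orientation sketch, assuming a standard pip install of the local wheel file; the two imports mirror entries 100 and 102 in the list:

# Assumption: the wheel file is available locally, e.g.
#   pip install boltz_vsynthes-1.0.0-py3-none-any.whl
from boltz.model.modules.utils import center_random_augmentation  # entry 100 above
from boltz.model.optim.ema import EMA  # entry 102 above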
boltz/model/modules/utils.py
@@ -0,0 +1,303 @@
+ # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+ from functools import partial
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch.nn import (
+     Linear,
+     Module,
+ )
+ from torch.types import Device
+
+ LinearNoBias = partial(Linear, bias=False)
+
+
+ def exists(v):
+     return v is not None
+
+
+ def default(v, d):
+     return v if exists(v) else d
+
+
+ def log(t, eps=1e-20):
+     return torch.log(t.clamp(min=eps))
+
+
+ class SwiGLU(Module):
+     def forward(
+         self,
+         x,  #: Float['... d']
+     ):  # -> Float[' ... (d//2)']:
+         x, gates = x.chunk(2, dim=-1)
+         return F.silu(gates) * x
+
+
+ def center(atom_coords, atom_mask):
+     atom_mean = torch.sum(
+         atom_coords * atom_mask[:, :, None], dim=1, keepdim=True
+     ) / torch.sum(atom_mask[:, :, None], dim=1, keepdim=True)
+     atom_coords = atom_coords - atom_mean
+     return atom_coords
+
+
+ def compute_random_augmentation(
+     multiplicity, s_trans=1.0, device=None, dtype=torch.float32
+ ):
+     R = random_rotations(multiplicity, dtype=dtype, device=device)
+     random_trans = (
+         torch.randn((multiplicity, 1, 3), dtype=dtype, device=device) * s_trans
+     )
+     return R, random_trans
+
+
+ def randomly_rotate(coords, return_second_coords=False, second_coords=None):
+     R = random_rotations(len(coords), coords.dtype, coords.device)
+
+     if return_second_coords:
+         return torch.einsum("bmd,bds->bms", coords, R), torch.einsum(
+             "bmd,bds->bms", second_coords, R
+         ) if second_coords is not None else None
+
+     return torch.einsum("bmd,bds->bms", coords, R)
+
+
+ def center_random_augmentation(
+     atom_coords,
+     atom_mask,
+     s_trans=1.0,
+     augmentation=True,
+     centering=True,
+     return_second_coords=False,
+     second_coords=None,
+ ):
+     """Algorithm 19"""
+     if centering:
+         atom_mean = torch.sum(
+             atom_coords * atom_mask[:, :, None], dim=1, keepdim=True
+         ) / torch.sum(atom_mask[:, :, None], dim=1, keepdim=True)
+         atom_coords = atom_coords - atom_mean
+
+         if second_coords is not None:
+             # apply same transformation also to this input
+             second_coords = second_coords - atom_mean
+
+     if augmentation:
+         atom_coords, second_coords = randomly_rotate(
+             atom_coords, return_second_coords=True, second_coords=second_coords
+         )
+         random_trans = torch.randn_like(atom_coords[:, 0:1, :]) * s_trans
+         atom_coords = atom_coords + random_trans
+
+         if second_coords is not None:
+             second_coords = second_coords + random_trans
+
+     if return_second_coords:
+         return atom_coords, second_coords
+
+     return atom_coords
+
+
+ class ExponentialMovingAverage:
+     """from https://github.com/yang-song/score_sde_pytorch/blob/main/models/ema.py, Apache-2.0 license
+     Maintains (exponential) moving average of a set of parameters."""
+
+     def __init__(self, parameters, decay, use_num_updates=True):
+         """
+         Args:
+             parameters: Iterable of `torch.nn.Parameter`; usually the result of
+                 `model.parameters()`.
+             decay: The exponential decay.
+             use_num_updates: Whether to use number of updates when computing
+                 averages.
+         """
+         if decay < 0.0 or decay > 1.0:
+             raise ValueError("Decay must be between 0 and 1")
+         self.decay = decay
+         self.num_updates = 0 if use_num_updates else None
+         self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
+         self.collected_params = []
+
+     def update(self, parameters):
+         """
+         Update currently maintained parameters.
+         Call this every time the parameters are updated, such as the result of
+         the `optimizer.step()` call.
+         Args:
+             parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+                 parameters used to initialize this object.
+         """
+         decay = self.decay
+         if self.num_updates is not None:
+             self.num_updates += 1
+             decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
+         one_minus_decay = 1.0 - decay
+         with torch.no_grad():
+             parameters = [p for p in parameters if p.requires_grad]
+             for s_param, param in zip(self.shadow_params, parameters):
+                 s_param.sub_(one_minus_decay * (s_param - param))
+
+     def compatible(self, parameters):
+         if len(self.shadow_params) != len(parameters):
+             print(
+                 f"Model has {len(self.shadow_params)} parameter tensors, the incoming ema {len(parameters)}"
+             )
+             return False
+
+         for s_param, param in zip(self.shadow_params, parameters):
+             if param.data.shape != s_param.data.shape:
+                 print(
+                     f"Model has parameter tensor of shape {s_param.data.shape} , the incoming ema {param.data.shape}"
+                 )
+                 return False
+         return True
+
+     def copy_to(self, parameters):
+         """
+         Copy current parameters into given collection of parameters.
+         Args:
+             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                 updated with the stored moving averages.
+         """
+         parameters = [p for p in parameters if p.requires_grad]
+         for s_param, param in zip(self.shadow_params, parameters):
+             if param.requires_grad:
+                 param.data.copy_(s_param.data)
+
+     def store(self, parameters):
+         """
+         Save the current parameters for restoring later.
+         Args:
+             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                 temporarily stored.
+         """
+         self.collected_params = [param.clone() for param in parameters]
+
+     def restore(self, parameters):
+         """
+         Restore the parameters stored with the `store` method.
+         Useful to validate the model with EMA parameters without affecting the
+         original optimization process. Store the parameters before the
+         `copy_to` method. After validation (or model saving), use this to
+         restore the former parameters.
+         Args:
+             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                 updated with the stored parameters.
+         """
+         for c_param, param in zip(self.collected_params, parameters):
+             param.data.copy_(c_param.data)
+
+     def state_dict(self):
+         return dict(
+             decay=self.decay,
+             num_updates=self.num_updates,
+             shadow_params=self.shadow_params,
+         )
+
+     def load_state_dict(self, state_dict, device):
+         self.decay = state_dict["decay"]
+         self.num_updates = state_dict["num_updates"]
+         self.shadow_params = [
+             tensor.to(device) for tensor in state_dict["shadow_params"]
+         ]
+
+     def to(self, device):
+         self.shadow_params = [tensor.to(device) for tensor in self.shadow_params]
+
+
+ # the following is copied from Torch3D, BSD License, Copyright (c) Meta Platforms, Inc. and affiliates.
+
+
+ def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+     """
+     Return a tensor where each element has the absolute value taken from the
+     corresponding element of a, with sign taken from the corresponding
+     element of b. This is like the standard copysign floating-point operation,
+     but is not careful about negative 0 and NaN.
+
+     Args:
+         a: source tensor.
+         b: tensor whose signs will be used, of the same shape as a.
+
+     Returns:
+         Tensor of the same shape as a with the signs of b.
+     """
+     signs_differ = (a < 0) != (b < 0)
+     return torch.where(signs_differ, -a, a)
+
+
+ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+     """
+     Convert rotations given as quaternions to rotation matrices.
+
+     Args:
+         quaternions: quaternions with real part first,
+             as tensor of shape (..., 4).
+
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+     r, i, j, k = torch.unbind(quaternions, -1)
+     # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+     two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+     o = torch.stack(
+         (
+             1 - two_s * (j * j + k * k),
+             two_s * (i * j - k * r),
+             two_s * (i * k + j * r),
+             two_s * (i * j + k * r),
+             1 - two_s * (i * i + k * k),
+             two_s * (j * k - i * r),
+             two_s * (i * k - j * r),
+             two_s * (j * k + i * r),
+             1 - two_s * (i * i + j * j),
+         ),
+         -1,
+     )
+     return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+ def random_quaternions(
+     n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+ ) -> torch.Tensor:
+     """
+     Generate random quaternions representing rotations,
+     i.e. versors with nonnegative real part.
+
+     Args:
+         n: Number of quaternions in a batch to return.
+         dtype: Type to return.
+         device: Desired device of returned tensor. Default:
+             uses the current device for the default tensor type.
+
+     Returns:
+         Quaternions as tensor of shape (N, 4).
+     """
+     if isinstance(device, str):
+         device = torch.device(device)
+     o = torch.randn((n, 4), dtype=dtype, device=device)
+     s = (o * o).sum(1)
+     o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
+     return o
+
+
+ def random_rotations(
+     n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+ ) -> torch.Tensor:
+     """
+     Generate random rotations as 3x3 rotation matrices.
+
+     Args:
+         n: Number of rotation matrices in a batch to return.
+         dtype: Type to return.
+         device: Device of returned tensor. Default: if None,
+             uses the current device for the default tensor type.
+
+     Returns:
+         Rotation matrices as tensor of shape (n, 3, 3).
+     """
+     quaternions = random_quaternions(n, dtype=dtype, device=device)
+     return quaternion_to_matrix(quaternions)
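
The hunks on this page are rendered without file headers; judging by its +303 line count, the hunk above corresponds to boltz/model/modules/utils.py (entry 100 in the list). Under that assumption, a minimal sanity-check sketch of its geometric helpers, using the (batch, atoms, 3) coordinate layout the functions index by:

import torch

from boltz.model.modules.utils import center_random_augmentation, random_rotations

coords = torch.randn(4, 128, 3)  # (batch, atoms, xyz)
mask = torch.ones(4, 128)        # 1.0 where an atom is present
out = center_random_augmentation(coords, mask, s_trans=1.0)
assert out.shape == coords.shape  # centering, rotation and translation preserve the shape

R = random_rotations(4)  # (4, 3, 3) random rotation matrices
rrt = torch.einsum("bij,bkj->bik", R, R)  # R @ R^T should be close to the identity
assert torch.allclose(rrt, torch.eye(3).expand(4, 3, 3), atol=1e-5)

Note also that `ExponentialMovingAverage.update` applies `s -= (1 - decay) * (s - p)`, which is algebraically the familiar shadow update `s = decay * s + (1 - decay) * p`.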
boltz/model/optim/__init__.py (file without changes)
boltz/model/optim/ema.py
@@ -0,0 +1,389 @@
+ # --------------------------------------------------------------------------------------
+ # Modified from Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion):
+ # Modified from : https://github.com/BioinfoMachineLearning/bio-diffusion/blob/main/src/utils/__init__.py
+ # --------------------------------------------------------------------------------------
+
+ from typing import Any, Optional
+
+ import torch
+ from pytorch_lightning import Callback, LightningModule, Trainer
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
+
+
+ class EMA(Callback):
+     """Implements Exponential Moving Averaging (EMA).
+
+     When training a model, this callback maintains moving averages
+     of the trained parameters. When evaluating, we use the moving
+     averages copy of the trained parameters. When saving, we save
+     an additional set of parameters with the prefix `ema`.
+
+     Adapted from:
+     https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
+     https://github.com/BioinfoMachineLearning/bio-diffusion/blob/main/src/utils/__init__.py
+
+     """
+
+     def __init__(
+         self,
+         decay: float = 0.999,
+         apply_ema_every_n_steps: int = 1,
+         start_step: int = 0,
+         eval_with_ema: bool = True,
+         warm_start: bool = True,
+     ) -> None:
+         """Initialize the EMA callback.
+
+         Parameters
+         ----------
+         decay: float
+             The exponential decay, has to be between 0-1.
+         apply_ema_every_n_steps: int, optional (default=1)
+             Apply EMA every n global steps.
+         start_step: int, optional (default=0)
+             Start applying EMA from ``start_step`` global step onwards.
+         eval_with_ema: bool, optional (default=True)
+             Validate the EMA weights instead of the original weights.
+             Note this means that when saving the model, the
+             validation metrics are calculated with the EMA weights.
+         warm_start: bool, optional (default=True)
+             Warm up the effective decay from the current global step,
+             so early updates track the raw weights more closely.
+
+         """
+         if not (0 <= decay <= 1):
+             msg = "EMA decay value must be between 0 and 1"
+             raise MisconfigurationException(msg)
+
+         self._ema_weights: Optional[dict[str, torch.Tensor]] = None
+         self._cur_step: Optional[int] = None
+         self._weights_buffer: Optional[dict[str, torch.Tensor]] = None
+         self.apply_ema_every_n_steps = apply_ema_every_n_steps
+         self.start_step = start_step
+         self.eval_with_ema = eval_with_ema
+         self.decay = decay
+         self.warm_start = warm_start
+
+     @property
+     def ema_initialized(self) -> bool:
+         """Check if EMA weights have been initialized.
+
+         Returns
+         -------
+         bool
+             Whether the EMA weights have been initialized.
+
+         """
+         return self._ema_weights is not None
+
+     def state_dict(self) -> dict[str, Any]:
+         """Return the current state of the callback.
+
+         Returns
+         -------
+         dict[str, Any]
+             The current state of the callback.
+
+         """
+         return {
+             "cur_step": self._cur_step,
+             "ema_weights": self._ema_weights,
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         """Load the state of the callback.
+
+         Parameters
+         ----------
+         state_dict: dict[str, Any]
+             The state of the callback to load.
+
+         """
+         self._cur_step = state_dict["cur_step"]
+         self._ema_weights = state_dict["ema_weights"]
+
+     def should_apply_ema(self, step: int) -> bool:
+         """Check if EMA should be applied at the current step.
+
+         Parameters
+         ----------
+         step: int
+             The current global step.
+
+         Returns
+         -------
+         bool
+             True if EMA should be applied, False otherwise.
+
+         """
+         return (
+             step != self._cur_step
+             and step >= self.start_step
+             and step % self.apply_ema_every_n_steps == 0
+         )
+
+     def apply_ema(self, pl_module: LightningModule) -> None:
+         """Apply EMA to the model weights.
+
+         Parameters
+         ----------
+         pl_module: LightningModule
+             The LightningModule instance.
+
+         """
+         decay = self.decay
+         if self.warm_start:
+             decay = min(decay, (1 + self._cur_step) / (10 + self._cur_step))
+
+         for k, orig_weight in pl_module.state_dict().items():
+             ema_weight = self._ema_weights[k]
+             if (
+                 ema_weight.data.dtype != torch.long  # noqa: PLR1714
+                 and orig_weight.data.dtype != torch.long  # skip non-trainable weights
+             ):
+                 diff = ema_weight.data - orig_weight.data
+                 diff.mul_(1.0 - decay)
+                 ema_weight.sub_(diff)
+
+     def on_load_checkpoint(
+         self,
+         trainer: Trainer,
+         pl_module: LightningModule,  # noqa: ARG002
+         checkpoint: dict[str, Any],
+     ) -> None:
+         """Load the EMA weights from the checkpoint.
+
+         Parameters
+         ----------
+         trainer: Trainer
+             The Trainer instance.
+         pl_module: LightningModule
+             The LightningModule instance.
+         checkpoint: dict[str, Any]
+             The checkpoint to load.
+
+         """
+         if "ema" in checkpoint:
+             print("LOADING CHECKPOINT RUNNING")
+             self.load_state_dict(checkpoint["ema"])
+
+     def on_save_checkpoint(
+         self,
+         trainer: Trainer,  # noqa: ARG002
+         pl_module: LightningModule,  # noqa: ARG002
+         checkpoint: dict[str, Any],
+     ) -> None:
+         """Save the EMA weights to the checkpoint.
+
+         Parameters
+         ----------
+         trainer: Trainer
+             The Trainer instance.
+         pl_module: LightningModule
+             The LightningModule instance.
+         checkpoint: dict[str, Any]
+             The checkpoint to save.
+
+         """
+         if self.ema_initialized:
+             checkpoint["ema"] = self.state_dict()
+
+     def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # noqa: ARG002
+         """Initialize EMA weights and move to device.
+
+         Parameters
+         ----------
+         trainer: pl.Trainer
+             The Trainer instance.
+         pl_module: pl.LightningModule
+             The LightningModule instance.
+
+         """
+         # Create EMA weights if not already initialized
+         if not self.ema_initialized:
+             self._ema_weights = {
+                 k: p.detach().clone() for k, p in pl_module.state_dict().items()
+             }
+
+         # Move EMA weights to the correct device
+         self._ema_weights = {
+             k: p.to(pl_module.device) for k, p in self._ema_weights.items()
+         }
+
+     def on_train_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: LightningModule,
+         outputs: STEP_OUTPUT,  # noqa: ARG002
+         batch: Any,  # noqa: ARG002
+         batch_idx: int,  # noqa: ARG002
+     ) -> None:
+         """Apply EMA to the model weights at the end of each training batch.
+
+         Parameters
+         ----------
+         trainer: Trainer
+             The Trainer instance.
+         pl_module: LightningModule
+             The LightningModule instance.
+         outputs: STEP_OUTPUT
+             The outputs of the model.
+         batch: Any
+             The current batch.
+         batch_idx: int
+             The index of the current batch.
+
+         """
+         if self.should_apply_ema(trainer.global_step):
+             self._cur_step = trainer.global_step
+             self.apply_ema(pl_module)
+
+     def replace_model_weights(self, pl_module: LightningModule) -> None:
+         """Replace model weights with EMA weights.
+
+         Parameters
+         ----------
+         pl_module: LightningModule
+             The LightningModule instance.
+
+         """
+         self._weights_buffer = {
+             k: p.detach().clone().to("cpu") for k, p in pl_module.state_dict().items()
+         }
+         pl_module.load_state_dict(self._ema_weights, strict=False)
+
+     def restore_original_weights(self, pl_module: LightningModule) -> None:
+         """Restore model weights to original weights.
+
+         Parameters
+         ----------
+         pl_module: LightningModule
+             The LightningModule instance.
+
+         """
+         pl_module.load_state_dict(self._weights_buffer, strict=False)
+         del self._weights_buffer
+
+     def _on_eval_start(self, pl_module: LightningModule) -> None:
+         """Use EMA weights for evaluation.
+
+         Parameters
+         ----------
+         pl_module: LightningModule
+             The LightningModule instance.
+
+         """
+         if self.ema_initialized and self.eval_with_ema:
+             self.replace_model_weights(pl_module)
+
+     def _on_eval_end(self, pl_module: LightningModule) -> None:
+         """Restore original weights after evaluation.
+
+         Parameters
+         ----------
+         pl_module: LightningModule
+             The LightningModule instance.
+
+         """
+         if self.ema_initialized and self.eval_with_ema:
+             self.restore_original_weights(pl_module)
+
+     def on_validation_start(
+         self,
+         trainer: Trainer,  # noqa: ARG002
+         pl_module: LightningModule,
+     ) -> None:
+         """Use EMA weights for validation.
+
+         Parameters
+         ----------
+         trainer: Trainer
+             The Trainer instance.
+         pl_module: LightningModule
+             The LightningModule instance.
+
+         """
+         self._on_eval_start(pl_module)
+
+     def on_validation_end(
+         self,
+         trainer: Trainer,  # noqa: ARG002
+         pl_module: LightningModule,
+     ) -> None:
+         """Restore original weights after validation.
+
+         Parameters
+         ----------
+         trainer: Trainer
+             The Trainer instance.
+         pl_module: LightningModule
+             The LightningModule instance.
+
+         """
+         self._on_eval_end(pl_module)
+
+     def on_test_start(
+         self,
+         trainer: Trainer,  # noqa: ARG002
+         pl_module: LightningModule,
+     ) -> None:
+         """Use EMA weights for testing.
+
+         Parameters
+         ----------
+         trainer: Trainer
+             The Trainer instance.
+         pl_module: LightningModule
+             The LightningModule instance.
+
+         """
+         self._on_eval_start(pl_module)
+
+     def on_test_end(
+         self,
+         trainer: Trainer,  # noqa: ARG002
+         pl_module: LightningModule,
+     ) -> None:
+         """Restore original weights after testing.
+
+         Parameters
+         ----------
+         trainer: Trainer
+             The Trainer instance.
+         pl_module: LightningModule
+             The LightningModule instance.
+
+         """
+         self._on_eval_end(pl_module)
+
+     def on_predict_start(
+         self,
+         trainer: Trainer,  # noqa: ARG002
+         pl_module: LightningModule,
+     ) -> None:
+         """Use EMA weights for prediction.
+
+         Parameters
+         ----------
+         trainer: Trainer
+             The Trainer instance.
+         pl_module: LightningModule
+             The LightningModule instance.
+
+         """
+         self._on_eval_start(pl_module)
+
+     def on_predict_end(
+         self,
+         trainer: Trainer,  # noqa: ARG002
+         pl_module: LightningModule,
+     ) -> None:
+         """Restore original weights after prediction.
+
+         Parameters
+         ----------
+         trainer: Trainer
+             The Trainer instance.
+         pl_module: LightningModule
+             The LightningModule instance.
+
+         """
+         self._on_eval_end(pl_module)
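
By its +389 line count, the hunk above corresponds to boltz/model/optim/ema.py (entry 102 in the list). For context, a minimal sketch of how a Lightning callback like this is typically wired up; the model and data module are placeholders, not part of this package:

import pytorch_lightning as pl

from boltz.model.optim.ema import EMA

ema = EMA(decay=0.999, apply_ema_every_n_steps=1, start_step=0, eval_with_ema=True)
trainer = pl.Trainer(max_epochs=10, callbacks=[ema])
# trainer.fit(model, datamodule=datamodule)  # placeholder LightningModule / LightningDataModule
# Validation, test and predict loops then run on the EMA weights, and every checkpoint
# gains an "ema" entry written by on_save_checkpoint.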