nshtrainer-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/callbacks/early_stopping.py
@@ -0,0 +1,112 @@
+import math
+from logging import getLogger
+
+from lightning.fabric.utilities.rank_zero import _get_rank
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
+from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
+from typing_extensions import override
+
+log = getLogger(__name__)
+
+
+class EarlyStopping(_EarlyStopping):
+    def __init__(
+        self,
+        monitor: str,
+        min_delta: float = 0,
+        min_lr: float | None = None,
+        patience: int = 3,
+        verbose: bool = True,
+        mode: str = "min",
+        strict: bool = True,
+        check_finite: bool = True,
+        stopping_threshold: float | None = None,
+        divergence_threshold: float | None = None,
+        check_on_train_epoch_end: bool | None = None,
+        log_rank_zero_only: bool = False,
+    ):
+        super().__init__(
+            monitor,
+            min_delta,
+            patience,
+            verbose,
+            mode,
+            strict,
+            check_finite,
+            stopping_threshold,
+            divergence_threshold,
+            check_on_train_epoch_end,
+            log_rank_zero_only,
+        )
+
+        self.min_lr = min_lr
+
+    @override
+    @staticmethod
+    def _log_info(
+        trainer: Trainer | None, message: str, log_rank_zero_only: bool
+    ) -> None:
+        rank = _get_rank()
+        if trainer is not None and trainer.world_size <= 1:
+            rank = None
+        message = rank_prefixed_message(message, rank)
+        if rank is None or not log_rank_zero_only or rank == 0:
+            log.critical(message)
+
+    @override
+    def _run_early_stopping_check(self, trainer: Trainer):
+        """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
+        logs = trainer.callback_metrics
+
+        # Disable early_stopping with fast_dev_run
+        if getattr(trainer, "fast_dev_run", False):
+            return
+
+        should_stop, reason = False, None
+
+        if not should_stop:
+            should_stop, reason = self._evaluate_stopping_criteria_min_lr(trainer)
+
+        # If metric present
+        if not should_stop and self._validate_condition_metric(logs):
+            current = logs[self.monitor].squeeze()
+            should_stop, reason = self._evaluate_stopping_criteria(current)
+
+        # stop every ddp process if any world process decides to stop
+        should_stop = trainer.strategy.reduce_boolean_decision(should_stop, all=False)
+        trainer.should_stop = trainer.should_stop or should_stop
+        if should_stop:
+            self.stopped_epoch = trainer.current_epoch
+        if reason and self.verbose:
+            self._log_info(trainer, reason, self.log_rank_zero_only)
+
+    def _evaluate_stopping_criteria_min_lr(
+        self, trainer: Trainer
+    ) -> tuple[bool, str | None]:
+        if self.min_lr is None:
+            return False, None
+
+        # Get the maximum LR across all param groups in all optimizers
+        model_max_lr = max(
+            [
+                param_group["lr"]
+                for optimizer in trainer.optimizers
+                for param_group in optimizer.param_groups
+            ]
+        )
+        if not isinstance(model_max_lr, float) or not math.isfinite(model_max_lr):
+            return False, None
+
+        # If the maximum LR is less than the minimum LR, stop training
+        if model_max_lr >= self.min_lr:
+            return False, None
+
+        return True, (
+            "Stopping threshold reached: "
+            f"The maximum LR of the model across all param groups is {model_max_lr:.2e} "
+            f"which is less than the minimum LR {self.min_lr:.2e}"
+        )
+
+    def on_early_stopping(self, trainer: Trainer):
+        pass
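For orientation, a minimal usage sketch of the EarlyStopping subclass above (not part of the package diff): a plain Lightning Trainer run that also stops once the learning rate has decayed far enough. The module and datamodule names are placeholders, and the metric name and thresholds are illustrative assumptions.

# Hedged usage sketch -- MyLightningModule / my_datamodule are placeholders;
# only the EarlyStopping import and its keyword arguments reflect the diff above.
from lightning.pytorch import Trainer
from nshtrainer.callbacks.early_stopping import EarlyStopping

early_stop = EarlyStopping(
    monitor="val/loss",  # metric the LightningModule is assumed to log
    patience=10,         # validation checks without improvement before stopping
    min_lr=1.0e-7,       # also stop once the max LR over all param groups falls below this
    mode="min",
)
trainer = Trainer(max_epochs=100, callbacks=[early_stop])
# trainer.fit(MyLightningModule(), my_datamodule)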
nshtrainer/callbacks/ema.py
@@ -0,0 +1,383 @@
+import contextlib
+import copy
+import threading
+from collections.abc import Callable, Iterable
+from typing import Any, Literal, overload
+
+import lightning.pytorch as pl
+import torch
+from lightning.pytorch import Callback
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from typing_extensions import override
+
+from .base import CallbackConfigBase
+
+
+class EMA(Callback):
+    """
+    Implements Exponential Moving Averaging (EMA).
+
+    When training a model, this callback will maintain moving averages of the trained parameters.
+    When evaluating, we use the moving averages copy of the trained parameters.
+    When saving, we save an additional set of parameters with the prefix `ema`.
+
+    Args:
+        decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
+        validate_original_weights: Validate the original weights, as opposed to the EMA weights.
+        every_n_steps: Apply EMA every N steps.
+        cpu_offload: Offload weights to CPU.
+    """
+
+    @override
+    def __init__(
+        self,
+        decay: float,
+        validate_original_weights: bool = False,
+        every_n_steps: int = 1,
+        cpu_offload: bool = False,
+    ):
+        if not (0 <= decay <= 1):
+            raise MisconfigurationException("EMA decay value must be between 0 and 1")
+        self.decay = decay
+        self.validate_original_weights = validate_original_weights
+        self.every_n_steps = every_n_steps
+        self.cpu_offload = cpu_offload
+
+    @override
+    def on_fit_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+    ) -> None:
+        device = pl_module.device if not self.cpu_offload else torch.device("cpu")
+        trainer.optimizers = [
+            EMAOptimizer(
+                optim,
+                device=device,
+                decay=self.decay,
+                every_n_steps=self.every_n_steps,
+                current_step=trainer.global_step,
+            )
+            for optim in trainer.optimizers
+            if not isinstance(optim, EMAOptimizer)
+        ]
+
+    @override
+    def on_validation_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+    ) -> None:
+        if self._should_validate_ema_weights(trainer):
+            self.swap_model_weights(trainer)
+
+    @override
+    def on_validation_end(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+    ) -> None:
+        if self._should_validate_ema_weights(trainer):
+            self.swap_model_weights(trainer)
+
+    @override
+    def on_test_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+    ) -> None:
+        if self._should_validate_ema_weights(trainer):
+            self.swap_model_weights(trainer)
+
+    @override
+    def on_test_end(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+    ) -> None:
+        if self._should_validate_ema_weights(trainer):
+            self.swap_model_weights(trainer)
+
+    def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool:
+        return not self.validate_original_weights and self._ema_initialized(trainer)
+
+    def _ema_initialized(self, trainer: "pl.Trainer") -> bool:
+        return any(
+            isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers
+        )
+
+    def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False):
+        for optimizer in trainer.optimizers:
+            assert isinstance(optimizer, EMAOptimizer)
+            optimizer.switch_main_parameter_weights(saving_ema_model)
+
+    @contextlib.contextmanager
+    def save_ema_model(self, trainer: "pl.Trainer"):
+        """
+        Saves an EMA copy of the model + EMA optimizer states for resume.
+        """
+        self.swap_model_weights(trainer, saving_ema_model=True)
+        try:
+            yield
+        finally:
+            self.swap_model_weights(trainer, saving_ema_model=False)
+
+    @contextlib.contextmanager
+    def save_original_optimizer_state(self, trainer: "pl.Trainer"):
+        for optimizer in trainer.optimizers:
+            assert isinstance(optimizer, EMAOptimizer)
+            optimizer.save_original_optimizer_state = True
+        try:
+            yield
+        finally:
+            for optimizer in trainer.optimizers:
+                assert isinstance(optimizer, EMAOptimizer)
+                optimizer.save_original_optimizer_state = False
+
+
+@torch.no_grad()
+def ema_update(ema_model_tuple, current_model_tuple, decay):
+    torch._foreach_mul_(ema_model_tuple, decay)
+    torch._foreach_add_(
+        ema_model_tuple,
+        current_model_tuple,
+        alpha=(1.0 - decay),
+    )
+
+
+def run_ema_update_cpu(
+    ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None
+):
+    if pre_sync_stream is not None:
+        pre_sync_stream.synchronize()
+
+    ema_update(ema_model_tuple, current_model_tuple, decay)
+
+
+class EMAOptimizer(torch.optim.Optimizer):
+    r"""
+    EMAOptimizer is a wrapper for torch.optim.Optimizer that computes
+    Exponential Moving Average of parameters registered in the optimizer.
+
+    EMA parameters are automatically updated after every step of the optimizer
+    with the following formula:
+
+        ema_weight = decay * ema_weight + (1 - decay) * training_weight
+
+    To access EMA parameters, use ``swap_ema_weights()`` context manager to
+    perform a temporary in-place swap of regular parameters with EMA
+    parameters.
+
+    Notes:
+        - EMAOptimizer is not compatible with APEX AMP O2.
+
+    Args:
+        optimizer (torch.optim.Optimizer): optimizer to wrap
+        device (torch.device): device for EMA parameters
+        decay (float): decay factor
+
+    Returns:
+        returns an instance of torch.optim.Optimizer that computes EMA of
+        parameters
+
+    Example:
+        model = Model().to(device)
+        opt = torch.optim.Adam(model.parameters())
+
+        opt = EMAOptimizer(opt, device, 0.9999)
+
+        for epoch in range(epochs):
+            training_loop(model, opt)
+
+            regular_eval_accuracy = evaluate(model)
+
+            with opt.swap_ema_weights():
+                ema_eval_accuracy = evaluate(model)
+    """
+
+    stream: Any | None
+
+    @override
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        decay: float = 0.9999,
+        every_n_steps: int = 1,
+        current_step: int = 0,
+    ):
+        self.optimizer = optimizer
+        self.decay = decay
+        self.device = device
+        self.current_step = current_step
+        self.every_n_steps = every_n_steps
+        self.save_original_optimizer_state = False
+
+        self.first_iteration = True
+        self.rebuild_ema_params = True
+        self.stream = None
+        self.thread = None
+
+        self.ema_params = ()
+        self.in_saving_ema_model_context = False
+
+    def all_parameters(self) -> Iterable[torch.Tensor]:
+        return (param for group in self.param_groups for param in group["params"])
+
+    @overload
+    def step(self, closure: None = ...) -> None: ...
+
+    @overload
+    def step(self, closure: Callable[[], float]) -> float: ...
+
+    @override
+    def step(self, closure: Callable[[], float] | None = None) -> float | None:
+        self.join()
+
+        if self.first_iteration:
+            if any(p.is_cuda for p in self.all_parameters()):
+                self.stream = torch.cuda.Stream()
+
+            self.first_iteration = False
+
+        if self.rebuild_ema_params:
+            opt_params = list(self.all_parameters())
+
+            self.ema_params += tuple(
+                copy.deepcopy(param.data.detach()).to(self.device)
+                for param in opt_params[len(self.ema_params) :]
+            )
+            self.rebuild_ema_params = False
+
+        loss = self.optimizer.step(closure)
+
+        if self._should_update_at_step():
+            self.update()
+        self.current_step += 1
+        return loss
+
+    def _should_update_at_step(self) -> bool:
+        return self.current_step % self.every_n_steps == 0
+
+    @torch.no_grad()
+    def update(self):
+        if self.stream is not None:
+            self.stream.wait_stream(torch.cuda.current_stream())
+
+        with torch.cuda.stream(self.stream):
+            current_model_state = tuple(
+                param.data.to(self.device, non_blocking=True)
+                for param in self.all_parameters()
+            )
+
+            if self.device.type == "cuda":
+                ema_update(self.ema_params, current_model_state, self.decay)
+
+        if self.device.type == "cpu":
+            self.thread = threading.Thread(
+                target=run_ema_update_cpu,
+                args=(
+                    self.ema_params,
+                    current_model_state,
+                    self.decay,
+                    self.stream,
+                ),
+            )
+            self.thread.start()
+
+    def swap_tensors(self, tensor1, tensor2):
+        tmp = torch.empty_like(tensor1)
+        tmp.copy_(tensor1)
+        tensor1.copy_(tensor2)
+        tensor2.copy_(tmp)
+
+    def switch_main_parameter_weights(self, saving_ema_model: bool = False):
+        self.join()
+        self.in_saving_ema_model_context = saving_ema_model
+        for param, ema_param in zip(self.all_parameters(), self.ema_params):
+            self.swap_tensors(param.data, ema_param)
+
+    @contextlib.contextmanager
+    def swap_ema_weights(self, enabled: bool = True):
+        r"""
+        A context manager to in-place swap regular parameters with EMA
+        parameters.
+        It swaps back to the original regular parameters on context manager
+        exit.
+
+        Args:
+            enabled (bool): whether the swap should be performed
+        """
+
+        if enabled:
+            self.switch_main_parameter_weights()
+        try:
+            yield
+        finally:
+            if enabled:
+                self.switch_main_parameter_weights()
+
+    def __getattr__(self, name):
+        return getattr(self.optimizer, name)
+
+    def join(self):
+        if self.stream is not None:
+            self.stream.synchronize()
+
+        if self.thread is not None:
+            self.thread.join()
+
+    @override
+    def state_dict(self):
+        self.join()
+
+        if self.save_original_optimizer_state:
+            return self.optimizer.state_dict()
+
+        # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights
+        ema_params = (
+            self.ema_params
+            if not self.in_saving_ema_model_context
+            else list(self.all_parameters())
+        )
+        state_dict = {
+            "opt": self.optimizer.state_dict(),
+            "ema": ema_params,
+            "current_step": self.current_step,
+            "decay": self.decay,
+            "every_n_steps": self.every_n_steps,
+        }
+        return state_dict
+
+    @override
+    def load_state_dict(self, state_dict):
+        self.join()
+
+        self.optimizer.load_state_dict(state_dict["opt"])
+        self.ema_params = tuple(
+            param.to(self.device) for param in copy.deepcopy(state_dict["ema"])
+        )
+        self.current_step = state_dict["current_step"]
+        self.decay = state_dict["decay"]
+        self.every_n_steps = state_dict["every_n_steps"]
+        self.rebuild_ema_params = False
+
+    @override
+    def add_param_group(self, param_group):
+        self.optimizer.add_param_group(param_group)
+        self.rebuild_ema_params = True
+
+
+class EMAConfig(CallbackConfigBase):
+    name: Literal["ema"] = "ema"
+
+    decay: float
+    """The exponential decay used when calculating the moving average. Has to be between 0-1."""
+
+    validate_original_weights: bool = False
+    """Validate the original weights, as opposed to the EMA weights."""
+
+    every_n_steps: int = 1
+    """Apply EMA every N steps."""
+
+    cpu_offload: bool = False
+    """Offload weights to CPU."""
+
+    @override
+    def construct_callbacks(self, root_config):
+        yield EMA(
+            decay=self.decay,
+            validate_original_weights=self.validate_original_weights,
+            every_n_steps=self.every_n_steps,
+            cpu_offload=self.cpu_offload,
+        )
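A similarly hedged sketch of wiring the EMA callback into a standard Lightning Trainer; the decay value is illustrative and the model/data names are placeholders. The callback wraps each optimizer in an EMAOptimizer on fit start and swaps the EMA weights in for validation and testing.

# Hedged usage sketch -- only the EMA import and constructor arguments come from
# the diff above; everything else is assumed, standard Lightning usage.
from lightning.pytorch import Trainer
from nshtrainer.callbacks.ema import EMA

ema = EMA(
    decay=0.999,        # ema_weight = decay * ema_weight + (1 - decay) * training_weight
    every_n_steps=1,    # update the moving average after every optimizer step
    cpu_offload=False,  # keep the EMA copy on the training device
)
trainer = Trainer(max_epochs=10, callbacks=[ema])
# Validation/test hooks swap the EMA weights in and back out automatically.
# trainer.fit(MyLightningModule(), my_datamodule)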
nshtrainer/callbacks/finite_checks.py
@@ -0,0 +1,75 @@
+from logging import getLogger
+from typing import Literal
+
+import torch
+from lightning.pytorch import Callback, LightningModule, Trainer
+from typing_extensions import override
+
+from .base import CallbackConfigBase
+
+log = getLogger(__name__)
+
+
+def finite_checks(
+    module: LightningModule,
+    nonfinite_grads: bool = True,
+    none_grads: bool = False,
+):
+    for name, param in module.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        if param.grad is None:
+            if none_grads:
+                log.critical(f"Parameter {name} ({param.shape}) has None gradients")
+            continue
+
+        if not nonfinite_grads or torch.isfinite(param.grad.float()).all():
+            continue
+
+        has_nan = torch.isnan(param.grad.float()).any()
+        has_inf = torch.isinf(param.grad.float()).any()
+        kinds = [
+            "NaN" if has_nan else None,
+            "Inf" if has_inf else None,
+        ]
+        kinds = " and ".join(prop for prop in kinds if prop is not None)
+        log.critical(f"{name} ({param.shape}) has {kinds} gradients")
+
+
+class FiniteChecksCallback(Callback):
+    def __init__(
+        self,
+        *,
+        nonfinite_grads: bool = True,
+        none_grads: bool = True,
+    ):
+        super().__init__()
+
+        self._nonfinite_grads = nonfinite_grads
+        self._none_grads = none_grads
+
+    @override
+    def on_after_backward(self, trainer: Trainer, pl_module: LightningModule):
+        finite_checks(
+            pl_module,
+            nonfinite_grads=self._nonfinite_grads,
+            none_grads=self._none_grads,
+        )
+
+
+class FiniteChecksConfig(CallbackConfigBase):
+    name: Literal["finite_checks"] = "finite_checks"
+
+    nonfinite_grads: bool = True
+    """Whether to check for non-finite (i.e. NaN or Inf) gradients"""
+
+    none_grads: bool = True
+    """Whether to check for None gradients"""
+
+    @override
+    def construct_callbacks(self, root_config):
+        yield FiniteChecksCallback(
+            nonfinite_grads=self.nonfinite_grads,
+            none_grads=self.none_grads,
+        )
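Again a hedged sketch, this time for the finite-gradient checks: the callback logs a critical message whenever a trainable parameter ends up with NaN/Inf (or missing) gradients after the backward pass. Placeholder names as before.

# Hedged usage sketch -- only FiniteChecksCallback and its keyword arguments come
# from the diff above.
from lightning.pytorch import Trainer
from nshtrainer.callbacks.finite_checks import FiniteChecksCallback

finite_checks_cb = FiniteChecksCallback(
    nonfinite_grads=True,  # report gradients containing NaN or Inf
    none_grads=True,       # report trainable parameters whose .grad is None
)
trainer = Trainer(max_epochs=10, callbacks=[finite_checks_cb])
# The check runs in on_after_backward, i.e. once per backward pass.
# trainer.fit(MyLightningModule(), my_datamodule)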
nshtrainer/callbacks/gradient_skipping.py
@@ -0,0 +1,103 @@
+from logging import getLogger
+from typing import Literal, Protocol, runtime_checkable
+
+import torch
+import torchmetrics
+from lightning.pytorch import Callback, LightningModule, Trainer
+from torch.optim import Optimizer
+from typing_extensions import override
+
+from .base import CallbackConfigBase
+from .norm_logging import compute_norm
+
+log = getLogger(__name__)
+
+
+@runtime_checkable
+class HasGradSkippedSteps(Protocol):
+    grad_skipped_steps: torchmetrics.SumMetric
+
+
+class GradientSkipping(Callback):
+    def __init__(self, config: "GradientSkippingConfig"):
+        super().__init__()
+
+        self.config = config
+
+    @override
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+        if not isinstance(pl_module, HasGradSkippedSteps):
+            pl_module.grad_skipped_steps = torchmetrics.SumMetric()
+
+    @override
+    def on_before_optimizer_step(
+        self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer
+    ):
+        # This should never happen, but just in case
+        if not isinstance(pl_module, HasGradSkippedSteps):
+            raise TypeError(
+                f"Expected HasGradSkippedSteps, got {type(pl_module)} instead"
+            )
+
+        # Skip the step if the global step is less than the start_after_n_steps
+        # This is because we want to let AMP adjust the loss scale before we start
+        if (
+            self.config.start_after_n_steps is not None
+            and pl_module.global_step < self.config.start_after_n_steps
+        ):
+            return
+
+        norm = compute_norm(
+            pl_module,
+            optimizer,
+            self.config.norm_type,
+            grad=True,
+        )
+
+        # If the norm is NaN/Inf, we don't want to skip the step
+        # because AMP checks for NaN/Inf grads to adjust the loss scale.
+        if self.config.skip_non_finite and not torch.isfinite(norm).all():
+            optimizer.zero_grad()
+            pl_module.grad_skipped_steps(1)
+            log.warning(
+                f"Skipping step at global step {pl_module.global_step} with non-finite norm {norm:.2f}"
+            )
+        elif (norm > self.config.threshold).any():
+            optimizer.zero_grad()
+            pl_module.grad_skipped_steps(1)
+            log.warning(
+                f"Skipping step at global step {pl_module.global_step} with norm {norm:.2f} > {self.config.threshold:.2f}"
+            )
+        else:
+            pl_module.grad_skipped_steps(0)
+
+        pl_module.log(
+            "train/grad_skipped_steps",
+            pl_module.grad_skipped_steps,
+            on_step=True,
+            on_epoch=False,
+        )
+
+
+class GradientSkippingConfig(CallbackConfigBase):
+    name: Literal["gradient_skipping"] = "gradient_skipping"
+
+    threshold: float
+    """Threshold to use for gradient skipping."""
+
+    norm_type: str | float = 2.0
+    """Norm type to use for gradient skipping."""
+
+    start_after_n_steps: int | None = 100
+    """Number of steps to wait before starting gradient skipping."""
+
+    skip_non_finite: bool = False
+    """
+    If False, it doesn't skip steps with non-finite norms. This is useful when using AMP, as AMP checks for NaN/Inf grads to adjust the loss scale. Otherwise, skips steps with non-finite norms.
+
+    Should almost always be False, especially when using AMP (unless you know what you're doing!).
+    """
+
+    @override
+    def construct_callbacks(self, root_config):
+        yield GradientSkipping(self)
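Finally, a hedged sketch for gradient skipping. GradientSkipping takes its config object directly; the sketch assumes CallbackConfigBase (defined in nshtrainer/callbacks/base.py, not shown in this diff) accepts keyword construction like a typical pydantic-style config. The threshold is illustrative.

# Hedged usage sketch -- GradientSkipping / GradientSkippingConfig come from the
# diff above; keyword construction of the config is an assumption about
# CallbackConfigBase, and the trainer/model names are placeholders.
from lightning.pytorch import Trainer
from nshtrainer.callbacks.gradient_skipping import (
    GradientSkipping,
    GradientSkippingConfig,
)

config = GradientSkippingConfig(
    threshold=100.0,          # zero the grads and skip the step when the grad norm exceeds this
    norm_type=2.0,            # L2 norm (the default)
    start_after_n_steps=100,  # let AMP settle its loss scale before skipping anything
    skip_non_finite=False,    # leave NaN/Inf grads to AMP's own handling
)
trainer = Trainer(max_epochs=10, callbacks=[GradientSkipping(config)])
# trainer.fit(MyLightningModule(), my_datamodule)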