nshtrainer-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,112 @@
+ import math
+ from logging import getLogger
+
+ from lightning.fabric.utilities.rank_zero import _get_rank
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
+ from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
+ from typing_extensions import override
+
+ log = getLogger(__name__)
+
+
+ class EarlyStopping(_EarlyStopping):
+     def __init__(
+         self,
+         monitor: str,
+         min_delta: float = 0,
+         min_lr: float | None = None,
+         patience: int = 3,
+         verbose: bool = True,
+         mode: str = "min",
+         strict: bool = True,
+         check_finite: bool = True,
+         stopping_threshold: float | None = None,
+         divergence_threshold: float | None = None,
+         check_on_train_epoch_end: bool | None = None,
+         log_rank_zero_only: bool = False,
+     ):
+         super().__init__(
+             monitor,
+             min_delta,
+             patience,
+             verbose,
+             mode,
+             strict,
+             check_finite,
+             stopping_threshold,
+             divergence_threshold,
+             check_on_train_epoch_end,
+             log_rank_zero_only,
+         )
+
+         self.min_lr = min_lr
+
+     @override
+     @staticmethod
+     def _log_info(
+         trainer: Trainer | None, message: str, log_rank_zero_only: bool
+     ) -> None:
+         rank = _get_rank()
+         if trainer is not None and trainer.world_size <= 1:
+             rank = None
+         message = rank_prefixed_message(message, rank)
+         if rank is None or not log_rank_zero_only or rank == 0:
+             log.critical(message)
+
+     @override
+     def _run_early_stopping_check(self, trainer: Trainer):
+         """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
+         logs = trainer.callback_metrics
+
+         # Disable early_stopping with fast_dev_run
+         if getattr(trainer, "fast_dev_run", False):
+             return
+
+         should_stop, reason = False, None
+
+         if not should_stop:
+             should_stop, reason = self._evaluate_stopping_criteria_min_lr(trainer)
+
+         # If metric present
+         if not should_stop and self._validate_condition_metric(logs):
+             current = logs[self.monitor].squeeze()
+             should_stop, reason = self._evaluate_stopping_criteria(current)
+
+         # stop every ddp process if any world process decides to stop
+         should_stop = trainer.strategy.reduce_boolean_decision(should_stop, all=False)
+         trainer.should_stop = trainer.should_stop or should_stop
+         if should_stop:
+             self.stopped_epoch = trainer.current_epoch
+         if reason and self.verbose:
+             self._log_info(trainer, reason, self.log_rank_zero_only)
+
+     def _evaluate_stopping_criteria_min_lr(
+         self, trainer: Trainer
+     ) -> tuple[bool, str | None]:
+         if self.min_lr is None:
+             return False, None
+
+         # Get the maximum LR across all param groups in all optimizers
+         model_max_lr = max(
+             [
+                 param_group["lr"]
+                 for optimizer in trainer.optimizers
+                 for param_group in optimizer.param_groups
+             ]
+         )
+         if not isinstance(model_max_lr, float) or not math.isfinite(model_max_lr):
+             return False, None
+
+         # If the maximum LR is less than the minimum LR, stop training
+         if model_max_lr >= self.min_lr:
+             return False, None
+
+         return True, (
+             "Stopping threshold reached: "
+             f"The maximum LR of the model across all param groups is {model_max_lr:.2e} "
+             f"which is less than the minimum LR {self.min_lr:.2e}"
+         )
+
+     def on_early_stopping(self, trainer: Trainer):
+         pass
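
The `min_lr` option above extends Lightning's stock early stopping: training also halts once every optimizer param group's learning rate has decayed below the given floor. A minimal usage sketch (not part of the packaged files; `model` and `datamodule` are placeholder objects):

    from lightning.pytorch import Trainer
    from nshtrainer.callbacks.early_stopping import EarlyStopping

    # Stop when val/loss stops improving for 5 checks, or once every
    # param group's LR has decayed below 1e-6.
    early_stop = EarlyStopping(monitor="val/loss", mode="min", patience=5, min_lr=1e-6)
    trainer = Trainer(max_epochs=100, callbacks=[early_stop])
    trainer.fit(model, datamodule=datamodule)  # placeholder LightningModule / DataModule
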
@@ -0,0 +1,383 @@
+ import contextlib
+ import copy
+ import threading
+ from collections.abc import Callable, Iterable
+ from typing import Any, Literal, overload
+
+ import lightning.pytorch as pl
+ import torch
+ from lightning.pytorch import Callback
+ from lightning.pytorch.utilities.exceptions import MisconfigurationException
+ from typing_extensions import override
+
+ from .base import CallbackConfigBase
+
+
+ class EMA(Callback):
+     """
+     Implements Exponential Moving Averaging (EMA).
+
+     When training a model, this callback will maintain moving averages of the trained parameters.
+     When evaluating, we use the moving-average copy of the trained parameters.
+     When saving, we save an additional set of parameters with the prefix `ema`.
+
+     Args:
+         decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
+         validate_original_weights: Validate the original weights, as opposed to the EMA weights.
+         every_n_steps: Apply EMA every N steps.
+         cpu_offload: Offload weights to CPU.
+     """
+
+     @override
+     def __init__(
+         self,
+         decay: float,
+         validate_original_weights: bool = False,
+         every_n_steps: int = 1,
+         cpu_offload: bool = False,
+     ):
+         if not (0 <= decay <= 1):
+             raise MisconfigurationException("EMA decay value must be between 0 and 1")
+         self.decay = decay
+         self.validate_original_weights = validate_original_weights
+         self.every_n_steps = every_n_steps
+         self.cpu_offload = cpu_offload
+
+     @override
+     def on_fit_start(
+         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+     ) -> None:
+         device = pl_module.device if not self.cpu_offload else torch.device("cpu")
+         trainer.optimizers = [
+             EMAOptimizer(
+                 optim,
+                 device=device,
+                 decay=self.decay,
+                 every_n_steps=self.every_n_steps,
+                 current_step=trainer.global_step,
+             )
+             for optim in trainer.optimizers
+             if not isinstance(optim, EMAOptimizer)
+         ]
+
+     @override
+     def on_validation_start(
+         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+     ) -> None:
+         if self._should_validate_ema_weights(trainer):
+             self.swap_model_weights(trainer)
+
+     @override
+     def on_validation_end(
+         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+     ) -> None:
+         if self._should_validate_ema_weights(trainer):
+             self.swap_model_weights(trainer)
+
+     @override
+     def on_test_start(
+         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+     ) -> None:
+         if self._should_validate_ema_weights(trainer):
+             self.swap_model_weights(trainer)
+
+     @override
+     def on_test_end(
+         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+     ) -> None:
+         if self._should_validate_ema_weights(trainer):
+             self.swap_model_weights(trainer)
+
+     def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool:
+         return not self.validate_original_weights and self._ema_initialized(trainer)
+
+     def _ema_initialized(self, trainer: "pl.Trainer") -> bool:
+         return any(
+             isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers
+         )
+
+     def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False):
+         for optimizer in trainer.optimizers:
+             assert isinstance(optimizer, EMAOptimizer)
+             optimizer.switch_main_parameter_weights(saving_ema_model)
+
+     @contextlib.contextmanager
+     def save_ema_model(self, trainer: "pl.Trainer"):
+         """
+         Saves an EMA copy of the model + EMA optimizer states for resume.
+         """
+         self.swap_model_weights(trainer, saving_ema_model=True)
+         try:
+             yield
+         finally:
+             self.swap_model_weights(trainer, saving_ema_model=False)
+
+     @contextlib.contextmanager
+     def save_original_optimizer_state(self, trainer: "pl.Trainer"):
+         for optimizer in trainer.optimizers:
+             assert isinstance(optimizer, EMAOptimizer)
+             optimizer.save_original_optimizer_state = True
+         try:
+             yield
+         finally:
+             for optimizer in trainer.optimizers:
+                 assert isinstance(optimizer, EMAOptimizer)
+                 optimizer.save_original_optimizer_state = False
+
+
+ @torch.no_grad()
+ def ema_update(ema_model_tuple, current_model_tuple, decay):
+     torch._foreach_mul_(ema_model_tuple, decay)
+     torch._foreach_add_(
+         ema_model_tuple,
+         current_model_tuple,
+         alpha=(1.0 - decay),
+     )
+
+
+ def run_ema_update_cpu(
+     ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None
+ ):
+     if pre_sync_stream is not None:
+         pre_sync_stream.synchronize()
+
+     ema_update(ema_model_tuple, current_model_tuple, decay)
+
+
+ class EMAOptimizer(torch.optim.Optimizer):
+     r"""
+     EMAOptimizer is a wrapper for torch.optim.Optimizer that computes
+     Exponential Moving Average of parameters registered in the optimizer.
+
+     EMA parameters are automatically updated after every step of the optimizer
+     with the following formula:
+
+         ema_weight = decay * ema_weight + (1 - decay) * training_weight
+
+     To access EMA parameters, use the ``swap_ema_weights()`` context manager to
+     perform a temporary in-place swap of regular parameters with EMA
+     parameters.
+
+     Notes:
+         - EMAOptimizer is not compatible with APEX AMP O2.
+
+     Args:
+         optimizer (torch.optim.Optimizer): optimizer to wrap
+         device (torch.device): device for EMA parameters
+         decay (float): decay factor
+
+     Returns:
+         returns an instance of torch.optim.Optimizer that computes EMA of
+         parameters
+
+     Example:
+         model = Model().to(device)
+         opt = torch.optim.Adam(model.parameters())
+
+         opt = EMAOptimizer(opt, device, 0.9999)
+
+         for epoch in range(epochs):
+             training_loop(model, opt)
+
+             regular_eval_accuracy = evaluate(model)
+
+             with opt.swap_ema_weights():
+                 ema_eval_accuracy = evaluate(model)
+     """
+
+     stream: Any | None
+
+     @override
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         device: torch.device,
+         decay: float = 0.9999,
+         every_n_steps: int = 1,
+         current_step: int = 0,
+     ):
+         self.optimizer = optimizer
+         self.decay = decay
+         self.device = device
+         self.current_step = current_step
+         self.every_n_steps = every_n_steps
+         self.save_original_optimizer_state = False
+
+         self.first_iteration = True
+         self.rebuild_ema_params = True
+         self.stream = None
+         self.thread = None
+
+         self.ema_params = ()
+         self.in_saving_ema_model_context = False
+
+     def all_parameters(self) -> Iterable[torch.Tensor]:
+         return (param for group in self.param_groups for param in group["params"])
+
+     @overload
+     def step(self, closure: None = ...) -> None: ...
+
+     @overload
+     def step(self, closure: Callable[[], float]) -> float: ...
+
+     @override
+     def step(self, closure: Callable[[], float] | None = None) -> float | None:
+         self.join()
+
+         if self.first_iteration:
+             if any(p.is_cuda for p in self.all_parameters()):
+                 self.stream = torch.cuda.Stream()
+
+             self.first_iteration = False
+
+         if self.rebuild_ema_params:
+             opt_params = list(self.all_parameters())
+
+             self.ema_params += tuple(
+                 copy.deepcopy(param.data.detach()).to(self.device)
+                 for param in opt_params[len(self.ema_params) :]
+             )
+             self.rebuild_ema_params = False
+
+         loss = self.optimizer.step(closure)
+
+         if self._should_update_at_step():
+             self.update()
+         self.current_step += 1
+         return loss
+
+     def _should_update_at_step(self) -> bool:
+         return self.current_step % self.every_n_steps == 0
+
+     @torch.no_grad()
+     def update(self):
+         if self.stream is not None:
+             self.stream.wait_stream(torch.cuda.current_stream())
+
+         with torch.cuda.stream(self.stream):
+             current_model_state = tuple(
+                 param.data.to(self.device, non_blocking=True)
+                 for param in self.all_parameters()
+             )
+
+             if self.device.type == "cuda":
+                 ema_update(self.ema_params, current_model_state, self.decay)
+
+         if self.device.type == "cpu":
+             self.thread = threading.Thread(
+                 target=run_ema_update_cpu,
+                 args=(
+                     self.ema_params,
+                     current_model_state,
+                     self.decay,
+                     self.stream,
+                 ),
+             )
+             self.thread.start()
+
+     def swap_tensors(self, tensor1, tensor2):
+         tmp = torch.empty_like(tensor1)
+         tmp.copy_(tensor1)
+         tensor1.copy_(tensor2)
+         tensor2.copy_(tmp)
+
+     def switch_main_parameter_weights(self, saving_ema_model: bool = False):
+         self.join()
+         self.in_saving_ema_model_context = saving_ema_model
+         for param, ema_param in zip(self.all_parameters(), self.ema_params):
+             self.swap_tensors(param.data, ema_param)
+
+     @contextlib.contextmanager
+     def swap_ema_weights(self, enabled: bool = True):
+         r"""
+         A context manager to in-place swap regular parameters with EMA
+         parameters.
+         It swaps back to the original regular parameters on context manager
+         exit.
+
+         Args:
+             enabled (bool): whether the swap should be performed
+         """
+
+         if enabled:
+             self.switch_main_parameter_weights()
+         try:
+             yield
+         finally:
+             if enabled:
+                 self.switch_main_parameter_weights()
+
+     def __getattr__(self, name):
+         return getattr(self.optimizer, name)
+
+     def join(self):
+         if self.stream is not None:
+             self.stream.synchronize()
+
+         if self.thread is not None:
+             self.thread.join()
+
+     @override
+     def state_dict(self):
+         self.join()
+
+         if self.save_original_optimizer_state:
+             return self.optimizer.state_dict()
+
+         # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights
+         ema_params = (
+             self.ema_params
+             if not self.in_saving_ema_model_context
+             else list(self.all_parameters())
+         )
+         state_dict = {
+             "opt": self.optimizer.state_dict(),
+             "ema": ema_params,
+             "current_step": self.current_step,
+             "decay": self.decay,
+             "every_n_steps": self.every_n_steps,
+         }
+         return state_dict
+
+     @override
+     def load_state_dict(self, state_dict):
+         self.join()
+
+         self.optimizer.load_state_dict(state_dict["opt"])
+         self.ema_params = tuple(
+             param.to(self.device) for param in copy.deepcopy(state_dict["ema"])
+         )
+         self.current_step = state_dict["current_step"]
+         self.decay = state_dict["decay"]
+         self.every_n_steps = state_dict["every_n_steps"]
+         self.rebuild_ema_params = False
+
+     @override
+     def add_param_group(self, param_group):
+         self.optimizer.add_param_group(param_group)
+         self.rebuild_ema_params = True
+
+
+ class EMAConfig(CallbackConfigBase):
+     name: Literal["ema"] = "ema"
+
+     decay: float
+     """The exponential decay used when calculating the moving average. Has to be between 0-1."""
+
+     validate_original_weights: bool = False
+     """Validate the original weights, as opposed to the EMA weights."""
+
+     every_n_steps: int = 1
+     """Apply EMA every N steps."""
+
+     cpu_offload: bool = False
+     """Offload weights to CPU."""
+
+     @override
+     def construct_callbacks(self, root_config):
+         yield EMA(
+             decay=self.decay,
+             validate_original_weights=self.validate_original_weights,
+             every_n_steps=self.every_n_steps,
+             cpu_offload=self.cpu_offload,
+         )
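
For orientation, a minimal sketch of how the EMA callback above is meant to be used (illustrative only; `model` is a placeholder LightningModule, and reading `trainer.optimizers` after `fit` assumes Lightning keeps the wrapped optimizers accessible):

    from lightning.pytorch import Trainer
    from nshtrainer.callbacks.ema import EMA

    # on_fit_start wraps every optimizer in an EMAOptimizer; with the default
    # validate_original_weights=False, validation and test run on the EMA weights.
    trainer = Trainer(max_epochs=10, callbacks=[EMA(decay=0.999)])
    trainer.fit(model)  # placeholder LightningModule

    # The EMA weights can also be borrowed explicitly via the wrapped optimizer:
    ema_opt = trainer.optimizers[0]
    with ema_opt.swap_ema_weights():
        pass  # parameters temporarily hold the EMA values here
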
@@ -0,0 +1,75 @@
+ from logging import getLogger
+ from typing import Literal
+
+ import torch
+ from lightning.pytorch import Callback, LightningModule, Trainer
+ from typing_extensions import override
+
+ from .base import CallbackConfigBase
+
+ log = getLogger(__name__)
+
+
+ def finite_checks(
+     module: LightningModule,
+     nonfinite_grads: bool = True,
+     none_grads: bool = False,
+ ):
+     for name, param in module.named_parameters():
+         if not param.requires_grad:
+             continue
+
+         if param.grad is None:
+             if none_grads:
+                 log.critical(f"Parameter {name} ({param.shape}) has None gradients")
+             continue
+
+         if not nonfinite_grads or torch.isfinite(param.grad.float()).all():
+             continue
+
+         has_nan = torch.isnan(param.grad.float()).any()
+         has_inf = torch.isinf(param.grad.float()).any()
+         kinds = [
+             "NaN" if has_nan else None,
+             "Inf" if has_inf else None,
+         ]
+         kinds = " and ".join(prop for prop in kinds if prop is not None)
+         log.critical(f"{name} ({param.shape}) has {kinds} gradients")
+
+
+ class FiniteChecksCallback(Callback):
+     def __init__(
+         self,
+         *,
+         nonfinite_grads: bool = True,
+         none_grads: bool = True,
+     ):
+         super().__init__()
+
+         self._nonfinite_grads = nonfinite_grads
+         self._none_grads = none_grads
+
+     @override
+     def on_after_backward(self, trainer: Trainer, pl_module: LightningModule):
+         finite_checks(
+             pl_module,
+             nonfinite_grads=self._nonfinite_grads,
+             none_grads=self._none_grads,
+         )
+
+
+ class FiniteChecksConfig(CallbackConfigBase):
+     name: Literal["finite_checks"] = "finite_checks"
+
+     nonfinite_grads: bool = True
+     """Whether to check for non-finite (i.e. NaN or Inf) gradients"""
+
+     none_grads: bool = True
+     """Whether to check for None gradients"""
+
+     @override
+     def construct_callbacks(self, root_config):
+         yield FiniteChecksCallback(
+             nonfinite_grads=self.nonfinite_grads,
+             none_grads=self.none_grads,
+         )
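
A short usage sketch for the finite-checks callback above (illustrative only; `model` is a placeholder LightningModule): it runs `finite_checks()` after every backward pass and logs a critical message for any parameter whose gradient is NaN, Inf, or (optionally) None.

    from lightning.pytorch import Trainer
    from nshtrainer.callbacks.finite_checks import FiniteChecksCallback

    # Flag NaN/Inf gradients but ignore parameters that received no gradient at all.
    trainer = Trainer(callbacks=[FiniteChecksCallback(nonfinite_grads=True, none_grads=False)])
    trainer.fit(model)  # placeholder LightningModule
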
@@ -0,0 +1,103 @@
+ from logging import getLogger
+ from typing import Literal, Protocol, runtime_checkable
+
+ import torch
+ import torchmetrics
+ from lightning.pytorch import Callback, LightningModule, Trainer
+ from torch.optim import Optimizer
+ from typing_extensions import override
+
+ from .base import CallbackConfigBase
+ from .norm_logging import compute_norm
+
+ log = getLogger(__name__)
+
+
+ @runtime_checkable
+ class HasGradSkippedSteps(Protocol):
+     grad_skipped_steps: torchmetrics.SumMetric
+
+
+ class GradientSkipping(Callback):
+     def __init__(self, config: "GradientSkippingConfig"):
+         super().__init__()
+
+         self.config = config
+
+     @override
+     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+         if not isinstance(pl_module, HasGradSkippedSteps):
+             pl_module.grad_skipped_steps = torchmetrics.SumMetric()
+
+     @override
+     def on_before_optimizer_step(
+         self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer
+     ):
+         # This should never happen, but just in case
+         if not isinstance(pl_module, HasGradSkippedSteps):
+             raise TypeError(
+                 f"Expected HasGradSkippedSteps, got {type(pl_module)} instead"
+             )
+
+         # Skip the check while the global step is below start_after_n_steps.
+         # This lets AMP adjust its loss scale before gradient skipping kicks in.
+         if (
+             self.config.start_after_n_steps is not None
+             and pl_module.global_step < self.config.start_after_n_steps
+         ):
+             return
+
+         norm = compute_norm(
+             pl_module,
+             optimizer,
+             self.config.norm_type,
+             grad=True,
+         )
+
+         # By default (skip_non_finite=False) we do not skip steps with non-finite norms,
+         # because AMP relies on NaN/Inf grads to adjust its loss scale.
+         if self.config.skip_non_finite and not torch.isfinite(norm).all():
+             optimizer.zero_grad()
+             pl_module.grad_skipped_steps(1)
+             log.warning(
+                 f"Skipping step at global step {pl_module.global_step} with non-finite norm {norm:.2f}"
+             )
+         elif (norm > self.config.threshold).any():
+             optimizer.zero_grad()
+             pl_module.grad_skipped_steps(1)
+             log.warning(
+                 f"Skipping step at global step {pl_module.global_step} with norm {norm:.2f} > {self.config.threshold:.2f}"
+             )
+         else:
+             pl_module.grad_skipped_steps(0)
+
+         pl_module.log(
+             "train/grad_skipped_steps",
+             pl_module.grad_skipped_steps,
+             on_step=True,
+             on_epoch=False,
+         )
+
+
+ class GradientSkippingConfig(CallbackConfigBase):
+     name: Literal["gradient_skipping"] = "gradient_skipping"
+
+     threshold: float
+     """Threshold to use for gradient skipping."""
+
+     norm_type: str | float = 2.0
+     """Norm type to use for gradient skipping."""
+
+     start_after_n_steps: int | None = 100
+     """Number of steps to wait before starting gradient skipping."""
+
+     skip_non_finite: bool = False
+     """
+     If True, steps with non-finite gradient norms are skipped; if False, they are not.
+
+     Should almost always be False, especially when using AMP, because AMP checks for NaN/Inf grads to adjust the loss scale (unless you know what you're doing!).
+     """
+
+     @override
+     def construct_callbacks(self, root_config):
+         yield GradientSkipping(self)
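
Finally, a sketch of wiring up gradient skipping directly, without going through the package's root config (illustrative only; it assumes `CallbackConfigBase` supports keyword construction, which is not shown in this diff, and `model` is a placeholder LightningModule):

    from lightning.pytorch import Trainer
    from nshtrainer.callbacks.gradient_skipping import (
        GradientSkipping,
        GradientSkippingConfig,
    )

    # Skip any optimizer step whose gradient 2-norm exceeds 100, after an
    # initial 100-step grace period that lets AMP settle its loss scale.
    config = GradientSkippingConfig(threshold=100.0, norm_type=2.0, start_after_n_steps=100)
    trainer = Trainer(callbacks=[GradientSkipping(config)])
    trainer.fit(model)  # placeholder LightningModule
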