nshtrainer-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/model/modules/callback.py
@@ -0,0 +1,157 @@
+ from collections import abc
+ from collections.abc import Callable, Iterable
+ from logging import getLogger
+ from typing import Any, TypeAlias, cast, final
+
+ from lightning.pytorch import Callback, LightningModule
+ from lightning.pytorch.callbacks import LambdaCallback
+ from typing_extensions import override
+
+ from ...util.typing_utils import mixin_base_type
+
+ log = getLogger(__name__)
+
+ CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
+
+
+ class CallbackRegistrarModuleMixin:
+     @override
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         self._ll_callbacks: list[CallbackFn] = []
+
+     def register_callback(
+         self,
+         callback: Callback | Iterable[Callback] | CallbackFn | None = None,
+         *,
+         setup: Callable | None = None,
+         teardown: Callable | None = None,
+         on_fit_start: Callable | None = None,
+         on_fit_end: Callable | None = None,
+         on_sanity_check_start: Callable | None = None,
+         on_sanity_check_end: Callable | None = None,
+         on_train_batch_start: Callable | None = None,
+         on_train_batch_end: Callable | None = None,
+         on_train_epoch_start: Callable | None = None,
+         on_train_epoch_end: Callable | None = None,
+         on_validation_epoch_start: Callable | None = None,
+         on_validation_epoch_end: Callable | None = None,
+         on_test_epoch_start: Callable | None = None,
+         on_test_epoch_end: Callable | None = None,
+         on_validation_batch_start: Callable | None = None,
+         on_validation_batch_end: Callable | None = None,
+         on_test_batch_start: Callable | None = None,
+         on_test_batch_end: Callable | None = None,
+         on_train_start: Callable | None = None,
+         on_train_end: Callable | None = None,
+         on_validation_start: Callable | None = None,
+         on_validation_end: Callable | None = None,
+         on_test_start: Callable | None = None,
+         on_test_end: Callable | None = None,
+         on_exception: Callable | None = None,
+         on_save_checkpoint: Callable | None = None,
+         on_load_checkpoint: Callable | None = None,
+         on_before_backward: Callable | None = None,
+         on_after_backward: Callable | None = None,
+         on_before_optimizer_step: Callable | None = None,
+         on_before_zero_grad: Callable | None = None,
+         on_predict_start: Callable | None = None,
+         on_predict_end: Callable | None = None,
+         on_predict_batch_start: Callable | None = None,
+         on_predict_batch_end: Callable | None = None,
+         on_predict_epoch_start: Callable | None = None,
+         on_predict_epoch_end: Callable | None = None,
+     ):
+         if callback is None:
+             callback = LambdaCallback(
+                 setup=setup,
+                 teardown=teardown,
+                 on_fit_start=on_fit_start,
+                 on_fit_end=on_fit_end,
+                 on_sanity_check_start=on_sanity_check_start,
+                 on_sanity_check_end=on_sanity_check_end,
+                 on_train_batch_start=on_train_batch_start,
+                 on_train_batch_end=on_train_batch_end,
+                 on_train_epoch_start=on_train_epoch_start,
+                 on_train_epoch_end=on_train_epoch_end,
+                 on_validation_epoch_start=on_validation_epoch_start,
+                 on_validation_epoch_end=on_validation_epoch_end,
+                 on_test_epoch_start=on_test_epoch_start,
+                 on_test_epoch_end=on_test_epoch_end,
+                 on_validation_batch_start=on_validation_batch_start,
+                 on_validation_batch_end=on_validation_batch_end,
+                 on_test_batch_start=on_test_batch_start,
+                 on_test_batch_end=on_test_batch_end,
+                 on_train_start=on_train_start,
+                 on_train_end=on_train_end,
+                 on_validation_start=on_validation_start,
+                 on_validation_end=on_validation_end,
+                 on_test_start=on_test_start,
+                 on_test_end=on_test_end,
+                 on_exception=on_exception,
+                 on_save_checkpoint=on_save_checkpoint,
+                 on_load_checkpoint=on_load_checkpoint,
+                 on_before_backward=on_before_backward,
+                 on_after_backward=on_after_backward,
+                 on_before_optimizer_step=on_before_optimizer_step,
+                 on_before_zero_grad=on_before_zero_grad,
+                 on_predict_start=on_predict_start,
+                 on_predict_end=on_predict_end,
+                 on_predict_batch_start=on_predict_batch_start,
+                 on_predict_batch_end=on_predict_batch_end,
+                 on_predict_epoch_start=on_predict_epoch_start,
+                 on_predict_epoch_end=on_predict_epoch_end,
+             )
+
+         if not callable(callback):
+             callback_ = cast(CallbackFn, lambda: callback)
+         else:
+             callback_ = callback
+
+         self._ll_callbacks.append(callback_)
+
+
+ class CallbackModuleMixin(
+     CallbackRegistrarModuleMixin,
+     mixin_base_type(LightningModule),
+ ):
+     def _gather_all_callbacks(self):
+         modules: list[Any] = []
+         if isinstance(self, CallbackRegistrarModuleMixin):
+             modules.append(self)
+         if (
+             datamodule := getattr(self.trainer, "datamodule", None)
+         ) is not None and isinstance(datamodule, CallbackRegistrarModuleMixin):
+             modules.append(datamodule)
+         modules.extend(
+             module
+             for module in self.children()
+             if isinstance(module, CallbackRegistrarModuleMixin)
+         )
+         for module in modules:
+             yield from module._ll_callbacks
+
+     @final
+     @override
+     def configure_callbacks(self):
+         callbacks = super().configure_callbacks()
+         if not isinstance(callbacks, abc.Sequence):
+             callbacks = [callbacks]
+
+         callbacks = list(callbacks)
+         for callback_fn in self._gather_all_callbacks():
+             callback_result = callback_fn()
+             if callback_result is None:
+                 continue
+
+             if not isinstance(callback_result, abc.Iterable):
+                 callback_result = [callback_result]
+
+             for callback in callback_result:
+                 log.info(
+                     f"Registering {callback.__class__.__qualname__} callback {callback}"
+                 )
+                 callbacks.append(callback)
+
+         return callbacks
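
A minimal usage sketch of register_callback, assuming the mixin is importable from nshtrainer.model.modules.callback; the module and callback classes below are hypothetical. Hook-only registrations are wrapped in a LambdaCallback, while callback instances are wrapped in a zero-argument factory and collected later by configure_callbacks.

    from lightning.pytorch import Callback
    from nshtrainer.model.modules.callback import CallbackModuleMixin


    class PrintEpochCallback(Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            print(f"finished epoch {trainer.current_epoch}")


    class MyModule(CallbackModuleMixin):
        def __init__(self):
            super().__init__()
            # Hook-only registration: wrapped in a LambdaCallback internally.
            self.register_callback(
                on_fit_start=lambda trainer, pl_module: print("starting fit")
            )
            # Callback-instance registration: wrapped in a zero-argument factory
            # and returned from configure_callbacks() alongside any others.
            self.register_callback(PrintEpochCallback())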
nshtrainer/model/modules/debug.py
@@ -0,0 +1,42 @@
+ from logging import getLogger
+
+ import torch
+ import torch.distributed
+
+ log = getLogger(__name__)
+
+
+ class DebugModuleMixin:
+     @torch.jit.unused
+     def breakpoint(self, rank_zero_only: bool = True):
+         if (
+             not rank_zero_only
+             or not torch.distributed.is_initialized()
+             or torch.distributed.get_rank() == 0
+         ):
+             breakpoint()
+
+         if rank_zero_only and torch.distributed.is_initialized():
+             _ = torch.distributed.barrier()
+
+     @torch.jit.unused
+     def ensure_finite(
+         self,
+         tensor: torch.Tensor,
+         name: str | None = None,
+         throw: bool = False,
+     ):
+         name_parts: list[str] = ["Tensor"]
+         if name is not None:
+             name_parts.append(name)
+         name = " ".join(name_parts)
+
+         not_finite = ~torch.isfinite(tensor)
+         if not_finite.any():
+             msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
+             if throw:
+                 raise RuntimeError(msg)
+             else:
+                 log.warning(msg)
+             return False
+         return True
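
A minimal sketch of ensure_finite, assuming the module path above; the loss handling is illustrative only. The helper warns (or raises when throw=True) if the tensor contains NaN or Inf values and returns whether the tensor was finite.

    import torch
    from nshtrainer.model.modules.debug import DebugModuleMixin


    class MyModule(DebugModuleMixin):
        def guard_loss(self, loss: torch.Tensor) -> torch.Tensor:
            # Returns False (and logs a warning) if loss has any NaN/Inf values;
            # with throw=True it raises a RuntimeError instead.
            if not self.ensure_finite(loss, name="loss"):
                loss = torch.zeros_like(loss)
            return loss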
nshtrainer/model/modules/distributed.py
@@ -0,0 +1,70 @@
+ from typing import Any, Literal, cast
+
+ import torch.distributed
+ from lightning.pytorch import LightningModule
+ from torch.distributed import ReduceOp
+ from typing_extensions import TypeVar
+
+ from ...util.typing_utils import mixin_base_type
+
+ T = TypeVar("T", infer_variance=True)
+
+ ReduceOpStr = Literal[
+     "avg",
+     "mean",
+     "band",
+     "bor",
+     "bxor",
+     "max",
+     "min",
+     "premul_sum",
+     "product",
+     "sum",
+ ]
+ VALID_REDUCE_OPS = (
+     "avg",
+     "mean",
+     "band",
+     "bor",
+     "bxor",
+     "max",
+     "min",
+     "premul_sum",
+     "product",
+     "sum",
+ )
+
+
+ class DistributedMixin(mixin_base_type(LightningModule)):
+     def all_gather_object(
+         self,
+         object: T,
+         group: torch.distributed.ProcessGroup | None = None,
+     ) -> list[T]:
+         if (
+             not torch.distributed.is_available()
+             or not torch.distributed.is_initialized()
+         ):
+             return [object]
+
+         object_list = [cast(T, None) for _ in range(self.trainer.world_size)]
+         torch.distributed.all_gather_object(object_list, object, group=group)
+         return object_list
+
+     def barrier(self, name: str | None = None):
+         self.trainer.strategy.barrier(name=name)
+
+     def reduce(
+         self,
+         tensor: torch.Tensor,
+         reduce_op: ReduceOp.RedOpType | ReduceOpStr,
+         group: Any | None = None,
+     ) -> torch.Tensor:
+         if isinstance(reduce_op, str):
+             # validate reduce_op
+             if reduce_op not in VALID_REDUCE_OPS:
+                 raise ValueError(
+                     f"reduce_op must be one of {VALID_REDUCE_OPS}, got {reduce_op}"
+                 )
+
+         return self.trainer.strategy.reduce(tensor, group=group, reduce_op=reduce_op)
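
A minimal sketch of the DistributedMixin helpers, assuming the module path above; the hook and the gathered attribute are illustrative. all_gather_object falls back to a single-element list when torch.distributed is not initialized, and string reduce ops are validated against VALID_REDUCE_OPS before being handed to the trainer strategy.

    import torch
    from nshtrainer.model.modules.distributed import DistributedMixin


    class MyModule(DistributedMixin):
        def on_validation_epoch_end(self):
            # self.local_sample_ids is a hypothetical per-rank attribute.
            # Gathers arbitrary picklable objects from every rank; returns
            # [obj] when not running distributed.
            all_ids = self.all_gather_object(self.local_sample_ids)
            # Reduces a tensor across ranks via the trainer strategy.
            total = self.reduce(torch.tensor(float(len(all_ids))), "sum")
            # Synchronizes all ranks before continuing.
            self.barrier("validation-epoch-end")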
nshtrainer/model/modules/logger.py
@@ -0,0 +1,170 @@
+ from collections import deque
+ from collections.abc import Callable, Generator
+ from contextlib import contextmanager
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, cast
+
+ import torchmetrics
+ from lightning.pytorch import LightningDataModule, LightningModule
+ from lightning.pytorch.utilities.types import _METRIC
+ from lightning_utilities.core.rank_zero import rank_zero_warn
+ from typing_extensions import override
+
+ from ...actsave import ActSave
+ from ...util.typing_utils import mixin_base_type
+ from ..config import BaseConfig
+
+
+ @dataclass(frozen=True, kw_only=True)
+ class _LogContext:
+     prefix: str | None = None
+     disabled: bool | None = None
+     kwargs: dict[str, Any] = field(default_factory=dict)
+
+
+ class LoggerModuleMixin:
+     @property
+     def log_dir(self):
+         if not isinstance(self, (LightningModule, LightningDataModule)):
+             raise TypeError(
+                 "log_dir can only be used on LightningModule or LightningDataModule"
+             )
+
+         if (trainer := self.trainer) is None:
+             raise RuntimeError("trainer is not defined")
+
+         if (logger := trainer.logger) is None:
+             raise RuntimeError("trainer.logger is not defined")
+
+         if (log_dir := logger.log_dir) is None:
+             raise RuntimeError("trainer.logger.log_dir is not defined")
+
+         return Path(log_dir)
+
+     @property
+     def should_update_logs(self):
+         if not isinstance(self, (LightningModule, LightningDataModule)):
+             raise TypeError(
+                 "should_update_logs can only be used on LightningModule or LightningDataModule"
+             )
+
+         trainer = self._trainer if isinstance(self, LightningModule) else self.trainer
+         if trainer is None:
+             return True
+
+         return trainer._logger_connector.should_update_logs
+
+
+ class LoggerLightningModuleMixin(LoggerModuleMixin, mixin_base_type(LightningModule)):
+     @override
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         self._logger_prefix_stack = deque[_LogContext]()
+
+     if TYPE_CHECKING:
+
+         @contextmanager
+         def log_context(
+             self,
+             prefix: str | None = None,
+             *,
+             disabled: bool | None = None,
+             prog_bar: bool | None = None,
+             logger: bool | None = None,
+             on_step: bool | None = None,
+             on_epoch: bool | None = None,
+             reduce_fx: str | Callable | None = None,
+             enable_graph: bool | None = None,
+             sync_dist: bool | None = None,
+             sync_dist_group: Any | None = None,
+             add_dataloader_idx: bool | None = None,
+             batch_size: int | None = None,
+             rank_zero_only: bool | None = None,
+         ) -> Generator[None, None, None]: ...
+
+     else:
+
+         @contextmanager
+         def log_context(
+             self, prefix: str | None = None, *, disabled: bool | None = None, **kwargs
+         ) -> Generator[None, None, None]:
+             self._logger_prefix_stack.append(
+                 _LogContext(
+                     prefix=prefix,
+                     disabled=disabled,
+                     kwargs=kwargs,
+                 )
+             )
+             try:
+                 yield
+             finally:
+                 _ = self._logger_prefix_stack.pop()
+
+     if TYPE_CHECKING:
+
+         @override
+         def log(  # type: ignore[override]
+             self,
+             name: str,
+             value: _METRIC,
+             *,
+             prog_bar: bool = False,
+             logger: bool | None = None,
+             on_step: bool | None = None,
+             on_epoch: bool | None = None,
+             reduce_fx: str | Callable = "mean",
+             enable_graph: bool = False,
+             sync_dist: bool = False,
+             sync_dist_group: Any | None = None,
+             add_dataloader_idx: bool = True,
+             batch_size: int | None = None,
+             metric_attribute: str | None = None,
+             rank_zero_only: bool = False,
+         ) -> None: ...
+
+     else:
+
+         @override
+         def log(self, name: str, value: _METRIC, **kwargs) -> None:
+             # join all prefixes
+             prefix = "".join(c.prefix for c in self._logger_prefix_stack if c.prefix)
+             name = f"{prefix}{name}"
+
+             # check for disabled context:
+             # if the topmost non-null context is disabled, then we don't log
+             for c in reversed(self._logger_prefix_stack):
+                 if c.disabled is not None:
+                     if c.disabled:
+                         rank_zero_warn(
+                             f"Skipping logging of {name} due to disabled context"
+                         )
+                         return
+                     else:
+                         break
+
+             fn_kwargs = {}
+             for c in self._logger_prefix_stack:
+                 fn_kwargs.update(c.kwargs)
+             fn_kwargs.update(kwargs)
+
+             self._logger_actsave(name, value)
+
+             return super().log(name, value, **fn_kwargs)
+
+     def _logger_actsave(self, name: str, value: _METRIC) -> None:
+         hparams = cast(BaseConfig, self.hparams)
+         if (
+             not hparams.trainer.actsave
+             or not hparams.trainer.actsave.auto_save_logged_metrics
+         ):
+             return
+
+         ActSave.save(
+             {
+                 f"logger.{name}": lambda: value.compute()
+                 if isinstance(value, torchmetrics.Metric)
+                 else value
+             }
+         )
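
A minimal sketch of log_context, assuming the module path above; the metric computation helper is hypothetical. Prefixes from nested contexts are concatenated onto the metric name, and context kwargs become defaults for each log call inside the block.

    from nshtrainer.model.modules.logger import LoggerLightningModuleMixin


    class MyModule(LoggerLightningModuleMixin):
        def training_step(self, batch, batch_idx):
            loss, acc = self.compute_metrics(batch)  # hypothetical helper
            # Every log() call inside this block is logged under "train/..."
            # and inherits prog_bar=True unless overridden per call.
            with self.log_context("train/", prog_bar=True):
                self.log("loss", loss)  # logged as "train/loss"
                self.log("acc", acc)    # logged as "train/acc"
            return loss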
nshtrainer/model/modules/profiler.py
@@ -0,0 +1,24 @@
+ from lightning.pytorch import LightningDataModule, LightningModule
+ from lightning.pytorch.profilers import PassThroughProfiler
+
+ from ...util.typing_utils import mixin_base_type
+
+
+ class ProfilerMixin(mixin_base_type(LightningModule)):
+     @property
+     def profiler(self):
+         if not isinstance(self, (LightningModule, LightningDataModule)):
+             raise TypeError(
+                 "`profiler` can only be used on LightningModule or LightningDataModule"
+             )
+
+         if (trainer := self.trainer) is None:
+             raise RuntimeError("trainer is not defined")
+
+         if not hasattr(trainer, "profiler"):
+             raise RuntimeError("trainer does not have profiler")
+
+         if (profiler := getattr(trainer, "profiler")) is None:
+             profiler = PassThroughProfiler()
+
+         return profiler
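
A minimal sketch of the profiler property, assuming the module path above; the submodule is illustrative. When the trainer has no profiler attached, a PassThroughProfiler is returned, so the context manager is always safe to use.

    from nshtrainer.model.modules.profiler import ProfilerMixin


    class MyModule(ProfilerMixin):
        def forward(self, x):
            # Records the action under the trainer's profiler, or does nothing
            # when only the PassThroughProfiler fallback is available.
            with self.profiler.profile("my_module.forward"):
                return self.backbone(x)  # hypothetical submodule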
nshtrainer/model/modules/rlp_sanity_checks.py
@@ -0,0 +1,202 @@
+ from collections.abc import Mapping
+ from logging import getLogger
+ from typing import cast
+
+ import torch
+ from lightning.pytorch import LightningModule, Trainer
+ from lightning.pytorch.utilities.types import (
+     LRSchedulerConfigType,
+     LRSchedulerTypeUnion,
+ )
+ from typing_extensions import Protocol, override, runtime_checkable
+
+ from ...util.typing_utils import mixin_base_type
+ from ..config import BaseConfig
+ from .callback import CallbackModuleMixin
+
+ log = getLogger(__name__)
+
+
+ def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule):
+     # If we're in PL's "sanity check" mode, we don't need to run this check
+     if trainer.sanity_checking:
+         return
+
+     config = cast(BaseConfig, pl_module.hparams)
+     if config.trainer.sanity_checking.reduce_lr_on_plateau == "disable":
+         return
+
+     # if no lr schedulers, return
+     if not trainer.lr_scheduler_configs:
+         return
+
+     errors: list[str] = []
+     disable_message = (
+         "Otherwise, set `config.trainer.sanity_checking.reduce_lr_on_plateau='disable'` "
+         "to disable this sanity check."
+     )
+
+     for lr_scheduler_config in trainer.lr_scheduler_configs:
+         if not lr_scheduler_config.reduce_on_plateau:
+             continue
+
+         match lr_scheduler_config.interval:
+             case "epoch":
+                 # we need to make sure that the trainer runs val every `frequency` epochs
+
+                 # If `trainer.check_val_every_n_epoch` is None, then Lightning
+                 # will run val every `int(trainer.val_check_interval)` steps.
+                 # So, first we need to make sure that `trainer.check_val_every_n_epoch` is not None.
+                 if trainer.check_val_every_n_epoch is None:
+                     errors.append(
+                         "Trainer is not running validation at epoch intervals "
+                         "(i.e., `trainer.check_val_every_n_epoch` is None) but "
+                         f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used. "
+                         f"Please set `config.trainer.check_val_every_n_epoch={lr_scheduler_config.frequency}`. "
+                         + disable_message
+                     )
+
+                 # Second, we make sure that the trainer runs val at least every `frequency` epochs
+                 if (
+                     trainer.check_val_every_n_epoch is not None
+                     and lr_scheduler_config.frequency % trainer.check_val_every_n_epoch
+                     != 0
+                 ):
+                     errors.append(
+                         f"Trainer is not running validation every {lr_scheduler_config.frequency} epochs but "
+                         f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used. "
+                         f"Please set `config.trainer.check_val_every_n_epoch` to a multiple of {lr_scheduler_config.frequency}. "
+                         + disable_message
+                     )
+
+             case "step":
+                 # In this case, we need to make sure that the trainer runs val at step intervals
+                 # that are multiples of `frequency`.
+
+                 # First, we make sure that validation is run at step intervals
+                 if trainer.check_val_every_n_epoch is not None:
+                     errors.append(
+                         "Trainer is running validation at epoch intervals "
+                         "(i.e., `trainer.check_val_every_n_epoch` is not None) but "
+                         f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used. "
+                         "Please set `config.trainer.check_val_every_n_epoch=None` "
+                         f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                         + disable_message
+                     )
+
+                 # Second, we make sure `trainer.val_check_interval` is an integer
+                 if not isinstance(trainer.val_check_interval, int):
+                     errors.append(
+                         f"Trainer is not running validation at step intervals "
+                         f"(i.e., `trainer.val_check_interval` is not an integer) but "
+                         f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used. "
+                         "Please set `config.trainer.check_val_every_n_epoch=None` "
+                         f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                         + disable_message
+                     )
+
+                 # Third, we make sure that the trainer runs val at least every `frequency` steps
+                 if (
+                     isinstance(trainer.val_check_interval, int)
+                     and trainer.val_check_interval % lr_scheduler_config.frequency != 0
+                 ):
+                     errors.append(
+                         f"Trainer is not running validation every {lr_scheduler_config.frequency} steps but "
+                         f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used. "
+                         "Please set `config.trainer.val_check_interval` "
+                         f"to a multiple of {lr_scheduler_config.frequency}. "
+                         + disable_message
+                     )
+
+             case _:
+                 pass
+
+     if not errors:
+         return
+
+     message = (
+         "ReduceLRPlateau sanity checks failed with the following errors:\n"
+         + "\n".join(errors)
+     )
+     match config.trainer.sanity_checking.reduce_lr_on_plateau:
+         case "warn":
+             log.warning(message)
+         case "error":
+             raise ValueError(message)
+         case _:
+             pass
+
+
+ @runtime_checkable
+ class CustomRLPImplementation(Protocol):
+     __reduce_lr_on_plateau__: bool
+
+
+ class RLPSanityCheckModuleMixin(mixin_base_type(CallbackModuleMixin)):
+     @override
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         global _on_train_start_callback
+         self.register_callback(on_train_start=_on_train_start_callback)
+
+     def reduce_lr_on_plateau_config(
+         self,
+         lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
+     ) -> LRSchedulerConfigType:
+         if (trainer := self._trainer) is None:
+             raise RuntimeError(
+                 "Could not determine the frequency of ReduceLRPlateau scheduler "
+                 "because `self.trainer` is None."
+             )
+
+         # First, resolve the LR scheduler from the provided config.
+         lr_scheduler_config: LRSchedulerConfigType
+         match lr_scheduler:
+             case Mapping():
+                 lr_scheduler_config = cast(LRSchedulerConfigType, lr_scheduler)
+             case _:
+                 lr_scheduler_config = {"scheduler": lr_scheduler}
+
+         # Make sure the scheduler is a ReduceLRPlateau scheduler. Otherwise, warn the user.
+         if (
+             not isinstance(
+                 lr_scheduler_config["scheduler"],
+                 torch.optim.lr_scheduler.ReduceLROnPlateau,
+             )
+         ) and (
+             not isinstance(lr_scheduler_config["scheduler"], CustomRLPImplementation)
+             or not lr_scheduler_config["scheduler"].__reduce_lr_on_plateau__
+         ):
+             log.warning(
+                 "`reduce_lr_on_plateau_config` should only be used with a ReduceLRPlateau scheduler. "
+                 f"The provided scheduler, {lr_scheduler_config['scheduler']}, does not subclass "
+                 "`torch.optim.lr_scheduler.ReduceLROnPlateau`. "
+                 "Please ensure that the scheduler is a ReduceLRPlateau scheduler. "
+                 "If you are using a custom ReduceLRPlateau scheduler implementation, "
+                 "please either (1) make sure that it subclasses `torch.optim.lr_scheduler.ReduceLROnPlateau`, "
+                 "or (2) set the scheduler's `__reduce_lr_on_plateau__` attribute to `True`."
+             )
+
+         # If trainer.check_val_every_n_epoch is an integer, then we run val at epoch intervals.
+         if trainer.check_val_every_n_epoch is not None:
+             return {
+                 "reduce_on_plateau": True,
+                 "interval": "epoch",
+                 "frequency": trainer.check_val_every_n_epoch,
+                 **lr_scheduler_config,
+             }
+
+         # Otherwise, we run val at step intervals.
+         if not isinstance(trainer.val_check_batch, int):
+             raise ValueError(
+                 "Could not determine the frequency of ReduceLRPlateau scheduler "
+                 f"because {trainer.val_check_batch=} is not an integer."
+             )
+
+         return {
+             "reduce_on_plateau": True,
+             "interval": "step",
+             "frequency": trainer.val_check_batch,
+             **lr_scheduler_config,
+         }
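
A minimal sketch of reduce_lr_on_plateau_config, assuming the module path above; the optimizer, learning rate, and monitored metric are illustrative. The helper fills in reduce_on_plateau, interval, and frequency from the trainer's validation cadence, and the registered on_train_start callback later verifies that validation actually runs at that cadence.

    import torch
    from nshtrainer.model.modules.rlp_sanity_checks import RLPSanityCheckModuleMixin


    class MyModule(RLPSanityCheckModuleMixin):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
            return {
                "optimizer": optimizer,
                # interval/frequency are derived from check_val_every_n_epoch or
                # val_check_batch, so the scheduler steps whenever validation runs.
                "lr_scheduler": self.reduce_lr_on_plateau_config(
                    {"scheduler": scheduler, "monitor": "val/loss"}
                ),
            }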