nshtrainer-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/model/modules/callback.py
@@ -0,0 +1,157 @@
+from collections import abc
+from collections.abc import Callable, Iterable
+from logging import getLogger
+from typing import Any, TypeAlias, cast, final
+
+from lightning.pytorch import Callback, LightningModule
+from lightning.pytorch.callbacks import LambdaCallback
+from typing_extensions import override
+
+from ...util.typing_utils import mixin_base_type
+
+log = getLogger(__name__)
+
+CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
+
+
+class CallbackRegistrarModuleMixin:
+    @override
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._ll_callbacks: list[CallbackFn] = []
+
+    def register_callback(
+        self,
+        callback: Callback | Iterable[Callback] | CallbackFn | None = None,
+        *,
+        setup: Callable | None = None,
+        teardown: Callable | None = None,
+        on_fit_start: Callable | None = None,
+        on_fit_end: Callable | None = None,
+        on_sanity_check_start: Callable | None = None,
+        on_sanity_check_end: Callable | None = None,
+        on_train_batch_start: Callable | None = None,
+        on_train_batch_end: Callable | None = None,
+        on_train_epoch_start: Callable | None = None,
+        on_train_epoch_end: Callable | None = None,
+        on_validation_epoch_start: Callable | None = None,
+        on_validation_epoch_end: Callable | None = None,
+        on_test_epoch_start: Callable | None = None,
+        on_test_epoch_end: Callable | None = None,
+        on_validation_batch_start: Callable | None = None,
+        on_validation_batch_end: Callable | None = None,
+        on_test_batch_start: Callable | None = None,
+        on_test_batch_end: Callable | None = None,
+        on_train_start: Callable | None = None,
+        on_train_end: Callable | None = None,
+        on_validation_start: Callable | None = None,
+        on_validation_end: Callable | None = None,
+        on_test_start: Callable | None = None,
+        on_test_end: Callable | None = None,
+        on_exception: Callable | None = None,
+        on_save_checkpoint: Callable | None = None,
+        on_load_checkpoint: Callable | None = None,
+        on_before_backward: Callable | None = None,
+        on_after_backward: Callable | None = None,
+        on_before_optimizer_step: Callable | None = None,
+        on_before_zero_grad: Callable | None = None,
+        on_predict_start: Callable | None = None,
+        on_predict_end: Callable | None = None,
+        on_predict_batch_start: Callable | None = None,
+        on_predict_batch_end: Callable | None = None,
+        on_predict_epoch_start: Callable | None = None,
+        on_predict_epoch_end: Callable | None = None,
+    ):
+        if callback is None:
+            callback = LambdaCallback(
+                setup=setup,
+                teardown=teardown,
+                on_fit_start=on_fit_start,
+                on_fit_end=on_fit_end,
+                on_sanity_check_start=on_sanity_check_start,
+                on_sanity_check_end=on_sanity_check_end,
+                on_train_batch_start=on_train_batch_start,
+                on_train_batch_end=on_train_batch_end,
+                on_train_epoch_start=on_train_epoch_start,
+                on_train_epoch_end=on_train_epoch_end,
+                on_validation_epoch_start=on_validation_epoch_start,
+                on_validation_epoch_end=on_validation_epoch_end,
+                on_test_epoch_start=on_test_epoch_start,
+                on_test_epoch_end=on_test_epoch_end,
+                on_validation_batch_start=on_validation_batch_start,
+                on_validation_batch_end=on_validation_batch_end,
+                on_test_batch_start=on_test_batch_start,
+                on_test_batch_end=on_test_batch_end,
+                on_train_start=on_train_start,
+                on_train_end=on_train_end,
+                on_validation_start=on_validation_start,
+                on_validation_end=on_validation_end,
+                on_test_start=on_test_start,
+                on_test_end=on_test_end,
+                on_exception=on_exception,
+                on_save_checkpoint=on_save_checkpoint,
+                on_load_checkpoint=on_load_checkpoint,
+                on_before_backward=on_before_backward,
+                on_after_backward=on_after_backward,
+                on_before_optimizer_step=on_before_optimizer_step,
+                on_before_zero_grad=on_before_zero_grad,
+                on_predict_start=on_predict_start,
+                on_predict_end=on_predict_end,
+                on_predict_batch_start=on_predict_batch_start,
+                on_predict_batch_end=on_predict_batch_end,
+                on_predict_epoch_start=on_predict_epoch_start,
+                on_predict_epoch_end=on_predict_epoch_end,
+            )
+
+        if not callable(callback):
+            callback_ = cast(CallbackFn, lambda: callback)
+        else:
+            callback_ = callback
+
+        self._ll_callbacks.append(callback_)
+
+
+class CallbackModuleMixin(
+    CallbackRegistrarModuleMixin,
+    mixin_base_type(LightningModule),
+):
+    def _gather_all_callbacks(self):
+        modules: list[Any] = []
+        if isinstance(self, CallbackRegistrarModuleMixin):
+            modules.append(self)
+        if (
+            datamodule := getattr(self.trainer, "datamodule", None)
+        ) is not None and isinstance(datamodule, CallbackRegistrarModuleMixin):
+            modules.append(datamodule)
+        modules.extend(
+            module
+            for module in self.children()
+            if isinstance(module, CallbackRegistrarModuleMixin)
+        )
+        for module in modules:
+            yield from module._ll_callbacks
+
+    @final
+    @override
+    def configure_callbacks(self):
+        callbacks = super().configure_callbacks()
+        if not isinstance(callbacks, abc.Sequence):
+            callbacks = [callbacks]
+
+        callbacks = list(callbacks)
+        for callback_fn in self._gather_all_callbacks():
+            callback_result = callback_fn()
+            if callback_result is None:
+                continue
+
+            if not isinstance(callback_result, abc.Iterable):
+                callback_result = [callback_result]
+
+            for callback in callback_result:
+                log.info(
+                    f"Registering {callback.__class__.__qualname__} callback {callback}"
+                )
+                callbacks.append(callback)
+
+        return callbacks
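
For orientation, register_callback accepts either an already-constructed Callback (or an iterable or zero-argument factory of callbacks) or bare hook functions, which it wraps in a LambdaCallback; CallbackModuleMixin.configure_callbacks then collects everything registered on the module, its datamodule, and its child modules. A minimal sketch of the intended usage, assuming the mixin composes with LightningModule in the usual order; EpochPrinter and its hook body are illustrative, not part of the package:

from lightning.pytorch import LightningModule

from nshtrainer.model.modules.callback import CallbackModuleMixin


# Illustrative module, not part of nshtrainer itself.
class EpochPrinter(CallbackModuleMixin, LightningModule):
    def __init__(self):
        super().__init__()
        # Bare hook functions are wrapped in a LambdaCallback; Lightning picks
        # the callback up through configure_callbacks() when fit() starts.
        self.register_callback(
            on_train_epoch_end=lambda trainer, pl_module: print(
                f"finished epoch {trainer.current_epoch}"
            )
        )
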
nshtrainer/model/modules/debug.py
@@ -0,0 +1,42 @@
+from logging import getLogger
+
+import torch
+import torch.distributed
+
+log = getLogger(__name__)
+
+
+class DebugModuleMixin:
+    @torch.jit.unused
+    def breakpoint(self, rank_zero_only: bool = True):
+        if (
+            not rank_zero_only
+            or not torch.distributed.is_initialized()
+            or torch.distributed.get_rank() == 0
+        ):
+            breakpoint()
+
+        if rank_zero_only and torch.distributed.is_initialized():
+            _ = torch.distributed.barrier()
+
+    @torch.jit.unused
+    def ensure_finite(
+        self,
+        tensor: torch.Tensor,
+        name: str | None = None,
+        throw: bool = False,
+    ):
+        name_parts: list[str] = ["Tensor"]
+        if name is not None:
+            name_parts.append(name)
+        name = " ".join(name_parts)
+
+        not_finite = ~torch.isfinite(tensor)
+        if not_finite.any():
+            msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
+            if throw:
+                raise RuntimeError(msg)
+            else:
+                log.warning(msg)
+            return False
+        return True
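
In practice, ensure_finite is a drop-in guard for NaN/Inf values: it logs a warning and returns False (or raises when throw=True), while breakpoint only enters the debugger on rank 0 by default and then waits at a barrier. A small sketch; the Debuggable class exists only for this example:

import torch

from nshtrainer.model.modules.debug import DebugModuleMixin


# Illustrative subclass; the mixin has no required base class.
class Debuggable(DebugModuleMixin):
    pass


m = Debuggable()
bad = torch.tensor([1.0, float("nan"), 3.0])
m.ensure_finite(bad, name="loss")   # logs a warning, returns False
m.ensure_finite(torch.ones(3))      # returns True
# m.ensure_finite(bad, name="loss", throw=True) would raise RuntimeError instead.
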
nshtrainer/model/modules/distributed.py
@@ -0,0 +1,70 @@
+from typing import Any, Literal, cast
+
+import torch.distributed
+from lightning.pytorch import LightningModule
+from torch.distributed import ReduceOp
+from typing_extensions import TypeVar
+
+from ...util.typing_utils import mixin_base_type
+
+T = TypeVar("T", infer_variance=True)
+
+ReduceOpStr = Literal[
+    "avg",
+    "mean",
+    "band",
+    "bor",
+    "bxor",
+    "max",
+    "min",
+    "premul_sum",
+    "product",
+    "sum",
+]
+VALID_REDUCE_OPS = (
+    "avg",
+    "mean",
+    "band",
+    "bor",
+    "bxor",
+    "max",
+    "min",
+    "premul_sum",
+    "product",
+    "sum",
+)
+
+
+class DistributedMixin(mixin_base_type(LightningModule)):
+    def all_gather_object(
+        self,
+        object: T,
+        group: torch.distributed.ProcessGroup | None = None,
+    ) -> list[T]:
+        if (
+            not torch.distributed.is_available()
+            or not torch.distributed.is_initialized()
+        ):
+            return [object]
+
+        object_list = [cast(T, None) for _ in range(self.trainer.world_size)]
+        torch.distributed.all_gather_object(object_list, object, group=group)
+        return object_list
+
+    def barrier(self, name: str | None = None):
+        self.trainer.strategy.barrier(name=name)
+
+    def reduce(
+        self,
+        tensor: torch.Tensor,
+        reduce_op: ReduceOp.RedOpType | ReduceOpStr,
+        group: Any | None = None,
+    ) -> torch.Tensor:
+        if isinstance(reduce_op, str):
+            # validate reduce_op
+            if reduce_op not in VALID_REDUCE_OPS:
+                raise ValueError(
+                    f"reduce_op must be one of {VALID_REDUCE_OPS}, got {reduce_op}"
+                )
+
+        return self.trainer.strategy.reduce(tensor, group=group, reduce_op=reduce_op)
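
As a usage note, all_gather_object degrades to a single-element list when torch.distributed is not initialized, and reduce validates string ops against VALID_REDUCE_OPS before delegating to the trainer strategy. A sketch of the intended call pattern; DistributedStatsModule and its seen_samples counter are hypothetical:

import torch
from lightning.pytorch import LightningModule

from nshtrainer.model.modules.distributed import DistributedMixin


# Illustrative module, not part of nshtrainer itself.
class DistributedStatsModule(DistributedMixin, LightningModule):
    def __init__(self):
        super().__init__()
        self.seen_samples = 0  # hypothetical per-rank statistic

    def on_validation_epoch_end(self):
        # Gather arbitrary picklable objects from every rank (falls back to a
        # one-element list outside of distributed runs).
        per_rank = self.all_gather_object({"num_samples": self.seen_samples})
        # Sum a tensor across ranks; "sum" is checked against VALID_REDUCE_OPS.
        total = self.reduce(torch.tensor(float(self.seen_samples)), reduce_op="sum")
        self.barrier("val-epoch-summary")
        self.print(f"per-rank counts: {per_rank}, total: {total.item()}")
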
nshtrainer/model/modules/logger.py
@@ -0,0 +1,170 @@
+from collections import deque
+from collections.abc import Callable, Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, cast
+
+import torchmetrics
+from lightning.pytorch import LightningDataModule, LightningModule
+from lightning.pytorch.utilities.types import _METRIC
+from lightning_utilities.core.rank_zero import rank_zero_warn
+from typing_extensions import override
+
+from ...actsave import ActSave
+from ...util.typing_utils import mixin_base_type
+from ..config import BaseConfig
+
+
+@dataclass(frozen=True, kw_only=True)
+class _LogContext:
+    prefix: str | None = None
+    disabled: bool | None = None
+    kwargs: dict[str, Any] = field(default_factory=dict)
+
+
+class LoggerModuleMixin:
+    @property
+    def log_dir(self):
+        if not isinstance(self, (LightningModule, LightningDataModule)):
+            raise TypeError(
+                "log_dir can only be used on LightningModule or LightningDataModule"
+            )
+
+        if (trainer := self.trainer) is None:
+            raise RuntimeError("trainer is not defined")
+
+        if (logger := trainer.logger) is None:
+            raise RuntimeError("trainer.logger is not defined")
+
+        if (log_dir := logger.log_dir) is None:
+            raise RuntimeError("trainer.logger.log_dir is not defined")
+
+        return Path(log_dir)
+
+    @property
+    def should_update_logs(self):
+        if not isinstance(self, (LightningModule, LightningDataModule)):
+            raise TypeError(
+                "should_update_logs can only be used on LightningModule or LightningDataModule"
+            )
+
+        trainer = self._trainer if isinstance(self, LightningModule) else self.trainer
+        if trainer is None:
+            return True
+
+        return trainer._logger_connector.should_update_logs
+
+
+class LoggerLightningModuleMixin(LoggerModuleMixin, mixin_base_type(LightningModule)):
+    @override
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._logger_prefix_stack = deque[_LogContext]()
+
+    if TYPE_CHECKING:
+
+        @contextmanager
+        def log_context(
+            self,
+            prefix: str | None = None,
+            *,
+            disabled: bool | None = None,
+            prog_bar: bool | None = None,
+            logger: bool | None = None,
+            on_step: bool | None = None,
+            on_epoch: bool | None = None,
+            reduce_fx: str | Callable | None = None,
+            enable_graph: bool | None = None,
+            sync_dist: bool | None = None,
+            sync_dist_group: Any | None = None,
+            add_dataloader_idx: bool | None = None,
+            batch_size: int | None = None,
+            rank_zero_only: bool | None = None,
+        ) -> Generator[None, None, None]: ...
+
+    else:
+
+        @contextmanager
+        def log_context(
+            self, prefix: str | None = None, *, disabled: bool | None = None, **kwargs
+        ) -> Generator[None, None, None]:
+            self._logger_prefix_stack.append(
+                _LogContext(
+                    prefix=prefix,
+                    disabled=disabled,
+                    kwargs=kwargs,
+                )
+            )
+            try:
+                yield
+            finally:
+                _ = self._logger_prefix_stack.pop()
+
+    if TYPE_CHECKING:
+
+        @override
+        def log(  # type: ignore[override]
+            self,
+            name: str,
+            value: _METRIC,
+            *,
+            prog_bar: bool = False,
+            logger: bool | None = None,
+            on_step: bool | None = None,
+            on_epoch: bool | None = None,
+            reduce_fx: str | Callable = "mean",
+            enable_graph: bool = False,
+            sync_dist: bool = False,
+            sync_dist_group: Any | None = None,
+            add_dataloader_idx: bool = True,
+            batch_size: int | None = None,
+            metric_attribute: str | None = None,
+            rank_zero_only: bool = False,
+        ) -> None: ...
+
+    else:
+
+        @override
+        def log(self, name: str, value: _METRIC, **kwargs) -> None:
+            # join all prefixes
+            prefix = "".join(c.prefix for c in self._logger_prefix_stack if c.prefix)
+            name = f"{prefix}{name}"
+
+            # check for disabled context:
+            # if the topmost non-null context is disabled, then we don't log
+            for c in reversed(self._logger_prefix_stack):
+                if c.disabled is not None:
+                    if c.disabled:
+                        rank_zero_warn(
+                            f"Skipping logging of {name} due to disabled context"
+                        )
+                        return
+                    else:
+                        break
+
+            fn_kwargs = {}
+            for c in self._logger_prefix_stack:
+                fn_kwargs.update(c.kwargs)
+            fn_kwargs.update(kwargs)
+
+            self._logger_actsave(name, value)
+
+            return super().log(name, value, **fn_kwargs)
+
+    def _logger_actsave(self, name: str, value: _METRIC) -> None:
+        hparams = cast(BaseConfig, self.hparams)
+        if (
+            not hparams.trainer.actsave
+            or not hparams.trainer.actsave.auto_save_logged_metrics
+        ):
+            return
+
+        ActSave.save(
+            {
+                f"logger.{name}": lambda: value.compute()
+                if isinstance(value, torchmetrics.Metric)
+                else value
+            }
+        )
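
To illustrate the logging context: prefixes from nested log_context blocks are concatenated onto the metric name, their keyword arguments are merged into each self.log call (inner contexts and the call site taking precedence), and the innermost context that sets disabled decides whether the call is skipped. A hypothetical training_step using the mixin; compute_loss is a placeholder, and in the package these mixins appear to be composed into nshtrainer's base model (nshtrainer/model/base.py), which would supply the BaseConfig hparams the ActSave branch expects:

from lightning.pytorch import LightningModule

from nshtrainer.model.modules.logger import LoggerLightningModuleMixin


# Illustrative module, not part of nshtrainer itself.
class PrefixedLoggingModule(LoggerLightningModuleMixin, LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        with self.log_context("train/", on_step=True, prog_bar=True):
            self.log("loss", loss)  # logged as "train/loss", on_step/prog_bar applied
            # Inner prefixes are appended; disabled=True skips the log call.
            with self.log_context("aux/", disabled=batch_idx % 10 != 0):
                self.log("loss_detached", loss.detach())  # "train/aux/loss_detached"
        return loss
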
nshtrainer/model/modules/profiler.py
@@ -0,0 +1,24 @@
+from lightning.pytorch import LightningDataModule, LightningModule
+from lightning.pytorch.profilers import PassThroughProfiler
+
+from ...util.typing_utils import mixin_base_type
+
+
+class ProfilerMixin(mixin_base_type(LightningModule)):
+    @property
+    def profiler(self):
+        if not isinstance(self, (LightningModule, LightningDataModule)):
+            raise TypeError(
+                "`profiler` can only be used on LightningModule or LightningDataModule"
+            )
+
+        if (trainer := self.trainer) is None:
+            raise RuntimeError("trainer is not defined")
+
+        if not hasattr(trainer, "profiler"):
+            raise RuntimeError("trainer does not have profiler")
+
+        if (profiler := getattr(trainer, "profiler")) is None:
+            profiler = PassThroughProfiler()
+
+        return profiler
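
The profiler property resolves to the trainer's configured profiler, falling back to a PassThroughProfiler, so model code can wrap hot sections unconditionally. A hypothetical forward pass using the mixin; backbone is a placeholder submodule:

import torch
from lightning.pytorch import LightningModule

from nshtrainer.model.modules.profiler import ProfilerMixin


# Illustrative module, not part of nshtrainer itself.
class ProfiledModule(ProfilerMixin, LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(16, 16)

    def forward(self, x):
        # PassThroughProfiler makes this a no-op when no profiler is configured.
        with self.profiler.profile("ProfiledModule.forward"):
            return self.backbone(x)
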
nshtrainer/model/modules/rlp_sanity_checks.py
@@ -0,0 +1,202 @@
+from collections.abc import Mapping
+from logging import getLogger
+from typing import cast
+
+import torch
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.utilities.types import (
+    LRSchedulerConfigType,
+    LRSchedulerTypeUnion,
+)
+from typing_extensions import Protocol, override, runtime_checkable
+
+from ...util.typing_utils import mixin_base_type
+from ..config import BaseConfig
+from .callback import CallbackModuleMixin
+
+log = getLogger(__name__)
+
+
+def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule):
+    # If we're in PL's "sanity check" mode, we don't need to run this check
+    if trainer.sanity_checking:
+        return
+
+    config = cast(BaseConfig, pl_module.hparams)
+    if config.trainer.sanity_checking.reduce_lr_on_plateau == "disable":
+        return
+
+    # if no lr schedulers, return
+    if not trainer.lr_scheduler_configs:
+        return
+
+    errors: list[str] = []
+    disable_message = (
+        "Otherwise, set `config.trainer.sanity_checking.reduce_lr_on_plateau='disable'` "
+        "to disable this sanity check."
+    )
+
+    for lr_scheduler_config in trainer.lr_scheduler_configs:
+        if not lr_scheduler_config.reduce_on_plateau:
+            continue
+
+        match lr_scheduler_config.interval:
+            case "epoch":
+                # we need to make sure that the trainer runs val every `frequency` epochs
+
+                # If `trainer.check_val_every_n_epoch` is None, then Lightning
+                # will run val every `int(trainer.val_check_interval)` steps.
+                # So, first we need to make sure that `trainer.val_check_interval` is not None first.
+                if trainer.check_val_every_n_epoch is None:
+                    errors.append(
+                        "Trainer is not running validation at epoch intervals "
+                        "(i.e., `trainer.check_val_every_n_epoch` is None) but "
+                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                        f"Please set `config.trainer.check_val_every_n_epoch={lr_scheduler_config.frequency}`. "
+                        + disable_message
+                    )
+
+                # Second, we make sure that the trainer runs val at least every `frequency` epochs
+                if (
+                    trainer.check_val_every_n_epoch is not None
+                    and lr_scheduler_config.frequency % trainer.check_val_every_n_epoch
+                    != 0
+                ):
+                    errors.append(
+                        f"Trainer is not running validation every {lr_scheduler_config.frequency} epochs but "
+                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
+                        f"Please set `config.trainer.check_val_every_n_epoch` to a multiple of {lr_scheduler_config.frequency}. "
+                        + disable_message
+                    )
+
+            case "step":
+                # In this case, we need to make sure that the trainer runs val at step intervals
+                # that are multiples of `frequency`.
+
+                # First, we make sure that validation is run at step intervals
+                if trainer.check_val_every_n_epoch is not None:
+                    errors.append(
+                        "Trainer is running validation at epoch intervals "
+                        "(i.e., `trainer.check_val_every_n_epoch` is not None) but "
+                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                        "Please set `config.trainer.check_val_every_n_epoch=None` "
+                        f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                        + disable_message
+                    )
+
+                # Second, we make sure `trainer.val_check_interval` is an integer
+                if not isinstance(trainer.val_check_interval, int):
+                    errors.append(
+                        f"Trainer is not running validation at step intervals "
+                        f"(i.e., `trainer.val_check_interval` is not an integer) but "
+                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                        "Please set `config.trainer.val_check_interval=None` "
+                        f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                        + disable_message
+                    )
+
+                # Third, we make sure that the trainer runs val at least every `frequency` steps
+                if (
+                    isinstance(trainer.val_check_interval, int)
+                    and trainer.val_check_interval % lr_scheduler_config.frequency != 0
+                ):
+                    errors.append(
+                        f"Trainer is not running validation every {lr_scheduler_config.frequency} steps but "
+                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
+                        "Please set `config.trainer.val_check_interval` "
+                        f"to a multiple of {lr_scheduler_config.frequency}. "
+                        + disable_message
+                    )
+
+            case _:
+                pass
+
+    if not errors:
+        return
+
+    message = (
+        "ReduceLRPlateau sanity checks failed with the following errors:\n"
+        + "\n".join(errors)
+    )
+    match config.trainer.sanity_checking.reduce_lr_on_plateau:
+        case "warn":
+            log.warning(message)
+        case "error":
+            raise ValueError(message)
+        case _:
+            pass
+
+
+@runtime_checkable
+class CustomRLPImplementation(Protocol):
+    __reduce_lr_on_plateau__: bool
+
+
+class RLPSanityCheckModuleMixin(mixin_base_type(CallbackModuleMixin)):
+    @override
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        global _on_train_start_callback
+        self.register_callback(on_train_start=_on_train_start_callback)
+
+    def reduce_lr_on_plateau_config(
+        self,
+        lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
+    ) -> LRSchedulerConfigType:
+        if (trainer := self._trainer) is None:
+            raise RuntimeError(
+                "Could not determine the frequency of ReduceLRPlateau scheduler "
+                "because `self.trainer` is None."
+            )
+
+        # First, resolve the LR scheduler from the provided config.
+        lr_scheduler_config: LRSchedulerConfigType
+        match lr_scheduler:
+            case Mapping():
+                lr_scheduler_config = cast(LRSchedulerConfigType, lr_scheduler)
+            case _:
+                lr_scheduler_config = {"scheduler": lr_scheduler}
+
+        # Make sure the scheduler is a ReduceLRPlateau scheduler. Otherwise, warn the user.
+        if (
+            not isinstance(
+                lr_scheduler_config["scheduler"],
+                torch.optim.lr_scheduler.ReduceLROnPlateau,
+            )
+        ) and (
+            not isinstance(lr_scheduler_config["scheduler"], CustomRLPImplementation)
+            or not lr_scheduler_config["scheduler"].__reduce_lr_on_plateau__
+        ):
+            log.warning(
+                "`reduce_lr_on_plateau_config` should only be used with a ReduceLRPlateau scheduler. "
+                f"The provided scheduler, {lr_scheduler_config['scheduler']}, does not subclass "
+                "`torch.optim.lr_scheduler.ReduceLROnPlateau`. "
+                "Please ensure that the scheduler is a ReduceLRPlateau scheduler. "
+                "If you are using a custom ReduceLRPlateau scheduler implementation, "
+                "please either (1) make sure that it subclasses `torch.optim.lr_scheduler.ReduceLROnPlateau`, "
+                "or (2) set the scheduler's `__reduce_lr_on_plateau__` attribute to `True`."
+            )
+
+        # If trainer.check_val_every_n_epoch is an integer, then we run val at epoch intervals.
+        if trainer.check_val_every_n_epoch is not None:
+            return {
+                "reduce_on_plateau": True,
+                "interval": "epoch",
+                "frequency": trainer.check_val_every_n_epoch,
+                **lr_scheduler_config,
+            }
+
+        # Otherwise, we run val at step intervals.
+        if not isinstance(trainer.val_check_batch, int):
+            raise ValueError(
+                "Could not determine the frequency of ReduceLRPlateau scheduler "
+                f"because {trainer.val_check_batch=} is not an integer."
+            )
+
+        return {
+            "reduce_on_plateau": True,
+            "interval": "step",
+            "frequency": trainer.val_check_batch,
+            **lr_scheduler_config,
+        }
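
Tying this together, reduce_lr_on_plateau_config derives interval and frequency from the trainer's validation cadence so the scheduler only steps when its monitored metric is actually produced, and the registered on_train_start callback re-checks that wiring against the BaseConfig hparams. A hypothetical configure_optimizers using the mixin; the optimizer, learning rate, and "val/loss" monitor key are illustrative choices, not package defaults:

import torch
from lightning.pytorch import LightningModule

from nshtrainer.model.modules.rlp_sanity_checks import RLPSanityCheckModuleMixin


# Illustrative module, not part of nshtrainer itself.
class PlateauModule(RLPSanityCheckModuleMixin, LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        return {
            "optimizer": optimizer,
            # reduce_on_plateau/interval/frequency are filled in from the
            # trainer's validation settings; "monitor" passes through to Lightning.
            "lr_scheduler": self.reduce_lr_on_plateau_config(
                {"scheduler": scheduler, "monitor": "val/loss"}
            ),
        }
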