nshtrainer 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/__init__.py +4 -1
- nshtrainer/callbacks/debug_flag.py +72 -0
- nshtrainer/model/base.py +2 -40
- nshtrainer/trainer/_config.py +9 -0
- {nshtrainer-0.31.0.dist-info → nshtrainer-0.32.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.31.0.dist-info → nshtrainer-0.32.0.dist-info}/RECORD +7 -7
- nshtrainer/model/mixins/callback.py +0 -206
- {nshtrainer-0.31.0.dist-info → nshtrainer-0.32.0.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/__init__.py
CHANGED
@@ -12,6 +12,8 @@ from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
 from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
+from .debug_flag import DebugFlagCallback as DebugFlagCallback
+from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
 from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
 from .directory_setup import DirectorySetupConfig as DirectorySetupConfig
 from .early_stopping import EarlyStopping as EarlyStopping
@@ -41,7 +43,8 @@ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
 from .wandb_watch import WandbWatchConfig as WandbWatchConfig

 CallbackConfig = Annotated[
-    EarlyStoppingConfig
+    DebugFlagCallbackConfig
+    | EarlyStoppingConfig
     | ThroughputMonitorConfig
     | EpochTimerConfig
     | PrintTableMetricsConfig
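Both new symbols are re-exported from the package's callback namespace, so downstream code can import them directly. A minimal sketch using only the names added above:

    from nshtrainer.callbacks import DebugFlagCallback, DebugFlagCallbackConfig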
nshtrainer/callbacks/debug_flag.py
ADDED
@@ -0,0 +1,72 @@
+import logging
+from typing import TYPE_CHECKING, Literal, cast
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from typing_extensions import override
+
+from nshtrainer.model.config import BaseConfig
+
+from .base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class DebugFlagCallbackConfig(CallbackConfigBase):
+    name: Literal["debug_flag"] = "debug_flag"
+
+    enabled: bool = True
+    """Whether to enable the callback."""
+
+    def __bool__(self):
+        return self.enabled
+
+    @override
+    def create_callbacks(self, root_config):
+        if not self:
+            return
+
+        yield DebugFlagCallback(self)
+
+
+class DebugFlagCallback(Callback):
+    """
+    Sets the debug flag to true in the following circumstances:
+    - fast_dev_run is enabled
+    - sanity check is running
+    """
+
+    @override
+    def __init__(self, config: DebugFlagCallbackConfig):
+        super().__init__()
+
+        self.config = config
+        del config
+
+    @override
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
+        if not getattr(trainer, "fast_dev_run", False):
+            return
+
+        hparams = cast("BaseConfig", pl_module.hparams)
+        if not hparams.debug:
+            log.critical("Fast dev run detected, setting debug flag to True.")
+            hparams.debug = True
+
+    @override
+    def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
+        hparams = cast("BaseConfig", pl_module.hparams)
+        self._debug = hparams.debug
+        if not self._debug:
+            log.critical("Enabling debug flag during sanity check routine.")
+            hparams.debug = True
+
+    @override
+    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
+        hparams = cast("BaseConfig", pl_module.hparams)
+        if not self._debug:
+            log.critical("Sanity check routine complete, disabling debug flag.")
+            hparams.debug = self._debug
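The new module follows nshtrainer's config-plus-callback pattern: DebugFlagCallbackConfig.__bool__ returns enabled, so create_callbacks yields a DebugFlagCallback only when the config is enabled. A minimal sketch of that gating behavior (keyword construction and passing None for root_config are assumptions for illustration; the implementation above ignores root_config):

    # Sketch only: exercises the __bool__ gate from the file above.
    from nshtrainer.callbacks import DebugFlagCallback, DebugFlagCallbackConfig

    disabled = DebugFlagCallbackConfig(enabled=False)
    assert list(disabled.create_callbacks(None)) == []  # disabled config yields nothing

    enabled = DebugFlagCallbackConfig()  # enabled=True by default
    (callback,) = enabled.create_callbacks(None)
    assert isinstance(callback, DebugFlagCallback)

At runtime the callback sets hparams.debug to True under fast_dev_run, and temporarily enables it for the sanity-check window, restoring the saved value in on_sanity_check_end.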
nshtrainer/model/base.py
CHANGED
@@ -7,8 +7,7 @@ from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast
 import torch
 import torch.distributed
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
-from lightning.pytorch import LightningModule
-from lightning.pytorch.callbacks import Callback
+from lightning.pytorch import LightningModule
 from lightning.pytorch.profilers import PassThroughProfiler, Profiler
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from typing_extensions import Self, TypeVar, override
@@ -16,7 +15,6 @@ from typing_extensions import Self, TypeVar, override
 from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
 from ..util._environment_info import EnvironmentConfig
 from .config import BaseConfig
-from .mixins.callback import CallbackModuleMixin
 from .mixins.logger import LoggerLightningModuleMixin

 log = logging.getLogger(__name__)
@@ -24,39 +22,6 @@ log = logging.getLogger(__name__)
 THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)


-class DebugFlagCallback(Callback):
-    """
-    Sets the debug flag to true in the following circumstances:
-    - fast_dev_run is enabled
-    - sanity check is running
-    """
-
-    @override
-    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
-        if not getattr(trainer, "fast_dev_run", False):
-            return
-
-        hparams = cast(BaseConfig, pl_module.hparams)
-        if not hparams.debug:
-            log.critical("Fast dev run detected, setting debug flag to True.")
-            hparams.debug = True
-
-    @override
-    def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
-        hparams = cast(BaseConfig, pl_module.hparams)
-        self._debug = hparams.debug
-        if not self._debug:
-            log.critical("Enabling debug flag during sanity check routine.")
-            hparams.debug = True
-
-    @override
-    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
-        hparams = cast(BaseConfig, pl_module.hparams)
-        if not self._debug:
-            log.critical("Sanity check routine complete, disabling debug flag.")
-            hparams.debug = self._debug
-
-
 T = TypeVar("T", infer_variance=True)

 ReduceOpStr = Literal[
@@ -88,7 +53,6 @@ VALID_REDUCE_OPS = (
 class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
     _RLPSanityCheckModuleMixin,
     LoggerLightningModuleMixin,
-    CallbackModuleMixin,
     LightningModule,
     ABC,
     Generic[THparams],
@@ -288,10 +252,8 @@ class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
         hparams.environment = EnvironmentConfig.from_current_environment(hparams, self)
         hparams = self.pre_init_update_hparams(hparams)

-        super().__init__(
-
+        super().__init__()
         self.save_hyperparameters(hparams)
-        self.register_callback(lambda: DebugFlagCallback())

     def zero_loss(self):
         """
nshtrainer/trainer/_config.py
CHANGED
@@ -35,6 +35,7 @@ from ..callbacks import (
     OnExceptionCheckpointCallbackConfig,
 )
 from ..callbacks.base import CallbackConfigBase
+from ..callbacks.debug_flag import DebugFlagCallbackConfig
 from ..callbacks.rlp_sanity_checks import RLPSanityChecksConfig
 from ..callbacks.shared_parameters import SharedParametersConfig
 from ..loggers import (
@@ -751,6 +752,11 @@ class TrainerConfig(C.Config):
     """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
     save_checkpoint_metadata: bool = True
     """If enabled, will save additional metadata whenever a checkpoint is saved."""
+    auto_set_debug_flag: DebugFlagCallbackConfig | None = DebugFlagCallbackConfig()
+    """If enabled, will automatically set the debug flag to True if:
+    - The trainer is running in fast_dev_run mode.
+    - The trainer is running a sanity check (which happens before starting the training routine).
+    """

     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
@@ -775,4 +781,7 @@ class TrainerConfig(C.Config):
         yield self.logging
         yield self.optimizer
         yield self.hf_hub
+        yield self.shared_parameters
+        yield self.reduce_lr_on_plateau_sanity_checking
+        yield self.auto_set_debug_flag
         yield from self.callbacks
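With this change the trainer config owns the debug-flag behavior: auto_set_debug_flag defaults to an enabled DebugFlagCallbackConfig and is yielded alongside the other callback configs, so the callback is active out of the box. A schematic way to opt out (any other required TrainerConfig fields are elided; construction details are assumptions, not part of the diff):

    # Schematic: two ways to turn the automatic debug flag off.
    config = TrainerConfig(auto_set_debug_flag=None)  # drop the callback entirely
    config = TrainerConfig(
        auto_set_debug_flag=DebugFlagCallbackConfig(enabled=False),  # keep it, disabled
    )

Note that the same hunk also starts yielding shared_parameters and reduce_lr_on_plateau_sanity_checking, routing those two configs through the same callback-collection path.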
{nshtrainer-0.31.0.dist-info → nshtrainer-0.32.0.dist-info}/RECORD
RENAMED
@@ -6,7 +6,7 @@ nshtrainer/_checkpoint/saver.py,sha256=MbX_WjkDtHHAf9Ms-KXDlknkjiPXVoGIe2ciO28Ad
 nshtrainer/_directory.py,sha256=RjnW6vKTeKlz2vQWT3cG0Jje5BkFXA7HpUubDhcSiq4,2993
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
 nshtrainer/_hf_hub.py,sha256=0bkXkqhve5D1onMW-fCfuvVKlTn0i6jv_6uMNgZ7OHQ,12974
-nshtrainer/callbacks/__init__.py,sha256=
+nshtrainer/callbacks/__init__.py,sha256=1SBLpMsx7BzgimO35MwQViYBcbgxlkyvTMz1JKUKK-0,3060
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
@@ -15,6 +15,7 @@ nshtrainer/callbacks/checkpoint/_base.py,sha256=vvlwuD-20NozYVIolGGShmUdkkNYeuwN
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
+nshtrainer/callbacks/debug_flag.py,sha256=Mo69CtJqPWMlFBvgBEuYls8Vfp5v1QFiyMRTiMStdec,2059
 nshtrainer/callbacks/directory_setup.py,sha256=c0uY0oTqLcQ3egInHO7G6BeQQgk_xvOLoHH8FR-9U0U,2629
 nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
@@ -62,9 +63,8 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-
 nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
 nshtrainer/model/__init__.py,sha256=2i_VEy6u_Y1LUGKljHXWeekvhnUcanZM2QyaaBM1Bmw,261
-nshtrainer/model/base.py,sha256=
+nshtrainer/model/base.py,sha256=1zVY8ybZTzVKhpp7sUC0t360Ut3YmdGxAW5PZAIBSyw,18535
 nshtrainer/model/config.py,sha256=Q4Wong6w3cp_Sq7s8iZdABKF-LZBbSCFn_TQPYkhkrI,6572
-nshtrainer/model/mixins/callback.py,sha256=lh3imlw1H3ESIG4WFA5frooSlWi6-RPUUDRFGRzEg4A,8571
 nshtrainer/model/mixins/logger.py,sha256=xOymSTofukEYZGkGojXsMEO__ZlBI5lIPZVmlotMEX8,5291
 nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
 nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
@@ -80,7 +80,7 @@ nshtrainer/profiler/simple.py,sha256=MbMfsJvligd0mtGiltxJ0T8MQVDP9T9BzQZFwswl66Y
 nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
 nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
-nshtrainer/trainer/_config.py,sha256=
+nshtrainer/trainer/_config.py,sha256=ZIodM5Ek1lpkWFhQ_VfmKR7q1mZFFwtjfx8FH72H8WM,29174
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
 nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
 nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
@@ -95,6 +95,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.
-nshtrainer-0.
-nshtrainer-0.
+nshtrainer-0.32.0.dist-info/METADATA,sha256=pe-TVRS0ZmZ9kx5NBQ8-0C6m4ZzaH_MalJZmh31mUNQ,916
+nshtrainer-0.32.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.32.0.dist-info/RECORD,,
nshtrainer/model/mixins/callback.py
REMOVED
@@ -1,206 +0,0 @@
-import logging
-from collections.abc import Callable, Iterable, Sequence
-from typing import Any, TypeAlias, cast, final, overload
-
-from lightning.pytorch import Callback, LightningModule
-from lightning.pytorch.callbacks import LambdaCallback
-from typing_extensions import override
-
-from ...util.typing_utils import mixin_base_type
-
-log = logging.getLogger(__name__)
-
-CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
-
-
-class CallbackRegistrarModuleMixin:
-    @override
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self._nshtrainer_callbacks: list[CallbackFn] = []
-
-    @overload
-    def register_callback(
-        self, callback: Callback | Iterable[Callback] | CallbackFn | None = None, /
-    ): ...
-
-    @overload
-    def register_callback(
-        self,
-        /,
-        *,
-        setup: Callable | None = None,
-        teardown: Callable | None = None,
-        on_fit_start: Callable | None = None,
-        on_fit_end: Callable | None = None,
-        on_sanity_check_start: Callable | None = None,
-        on_sanity_check_end: Callable | None = None,
-        on_train_batch_start: Callable | None = None,
-        on_train_batch_end: Callable | None = None,
-        on_train_epoch_start: Callable | None = None,
-        on_train_epoch_end: Callable | None = None,
-        on_validation_epoch_start: Callable | None = None,
-        on_validation_epoch_end: Callable | None = None,
-        on_test_epoch_start: Callable | None = None,
-        on_test_epoch_end: Callable | None = None,
-        on_validation_batch_start: Callable | None = None,
-        on_validation_batch_end: Callable | None = None,
-        on_test_batch_start: Callable | None = None,
-        on_test_batch_end: Callable | None = None,
-        on_train_start: Callable | None = None,
-        on_train_end: Callable | None = None,
-        on_validation_start: Callable | None = None,
-        on_validation_end: Callable | None = None,
-        on_test_start: Callable | None = None,
-        on_test_end: Callable | None = None,
-        on_exception: Callable | None = None,
-        on_save_checkpoint: Callable | None = None,
-        on_load_checkpoint: Callable | None = None,
-        on_before_backward: Callable | None = None,
-        on_after_backward: Callable | None = None,
-        on_before_optimizer_step: Callable | None = None,
-        on_before_zero_grad: Callable | None = None,
-        on_predict_start: Callable | None = None,
-        on_predict_end: Callable | None = None,
-        on_predict_batch_start: Callable | None = None,
-        on_predict_batch_end: Callable | None = None,
-        on_predict_epoch_start: Callable | None = None,
-        on_predict_epoch_end: Callable | None = None,
-    ): ...
-
-    def register_callback(
-        self,
-        callback: Callback | Iterable[Callback] | CallbackFn | None = None,
-        /,
-        *,
-        setup: Callable | None = None,
-        teardown: Callable | None = None,
-        on_fit_start: Callable | None = None,
-        on_fit_end: Callable | None = None,
-        on_sanity_check_start: Callable | None = None,
-        on_sanity_check_end: Callable | None = None,
-        on_train_batch_start: Callable | None = None,
-        on_train_batch_end: Callable | None = None,
-        on_train_epoch_start: Callable | None = None,
-        on_train_epoch_end: Callable | None = None,
-        on_validation_epoch_start: Callable | None = None,
-        on_validation_epoch_end: Callable | None = None,
-        on_test_epoch_start: Callable | None = None,
-        on_test_epoch_end: Callable | None = None,
-        on_validation_batch_start: Callable | None = None,
-        on_validation_batch_end: Callable | None = None,
-        on_test_batch_start: Callable | None = None,
-        on_test_batch_end: Callable | None = None,
-        on_train_start: Callable | None = None,
-        on_train_end: Callable | None = None,
-        on_validation_start: Callable | None = None,
-        on_validation_end: Callable | None = None,
-        on_test_start: Callable | None = None,
-        on_test_end: Callable | None = None,
-        on_exception: Callable | None = None,
-        on_save_checkpoint: Callable | None = None,
-        on_load_checkpoint: Callable | None = None,
-        on_before_backward: Callable | None = None,
-        on_after_backward: Callable | None = None,
-        on_before_optimizer_step: Callable | None = None,
-        on_before_zero_grad: Callable | None = None,
-        on_predict_start: Callable | None = None,
-        on_predict_end: Callable | None = None,
-        on_predict_batch_start: Callable | None = None,
-        on_predict_batch_end: Callable | None = None,
-        on_predict_epoch_start: Callable | None = None,
-        on_predict_epoch_end: Callable | None = None,
-    ):
-        if callback is None:
-            callback = LambdaCallback(
-                setup=setup,
-                teardown=teardown,
-                on_fit_start=on_fit_start,
-                on_fit_end=on_fit_end,
-                on_sanity_check_start=on_sanity_check_start,
-                on_sanity_check_end=on_sanity_check_end,
-                on_train_batch_start=on_train_batch_start,
-                on_train_batch_end=on_train_batch_end,
-                on_train_epoch_start=on_train_epoch_start,
-                on_train_epoch_end=on_train_epoch_end,
-                on_validation_epoch_start=on_validation_epoch_start,
-                on_validation_epoch_end=on_validation_epoch_end,
-                on_test_epoch_start=on_test_epoch_start,
-                on_test_epoch_end=on_test_epoch_end,
-                on_validation_batch_start=on_validation_batch_start,
-                on_validation_batch_end=on_validation_batch_end,
-                on_test_batch_start=on_test_batch_start,
-                on_test_batch_end=on_test_batch_end,
-                on_train_start=on_train_start,
-                on_train_end=on_train_end,
-                on_validation_start=on_validation_start,
-                on_validation_end=on_validation_end,
-                on_test_start=on_test_start,
-                on_test_end=on_test_end,
-                on_exception=on_exception,
-                on_save_checkpoint=on_save_checkpoint,
-                on_load_checkpoint=on_load_checkpoint,
-                on_before_backward=on_before_backward,
-                on_after_backward=on_after_backward,
-                on_before_optimizer_step=on_before_optimizer_step,
-                on_before_zero_grad=on_before_zero_grad,
-                on_predict_start=on_predict_start,
-                on_predict_end=on_predict_end,
-                on_predict_batch_start=on_predict_batch_start,
-                on_predict_batch_end=on_predict_batch_end,
-                on_predict_epoch_start=on_predict_epoch_start,
-                on_predict_epoch_end=on_predict_epoch_end,
-            )
-
-        if not callable(callback):
-            callback_ = cast(CallbackFn, lambda: callback)
-        else:
-            callback_ = callback
-
-        self._nshtrainer_callbacks.append(callback_)
-
-
-class CallbackModuleMixin(
-    CallbackRegistrarModuleMixin,
-    mixin_base_type(LightningModule),
-):
-    def _nshtrainer_gather_all_callbacks(self):
-        modules: list[Any] = []
-        if isinstance(self, CallbackRegistrarModuleMixin):
-            modules.append(self)
-        if (
-            datamodule := getattr(self.trainer, "datamodule", None)
-        ) is not None and isinstance(datamodule, CallbackRegistrarModuleMixin):
-            modules.append(datamodule)
-        modules.extend(
-            module
-            for module in self.children()
-            if isinstance(module, CallbackRegistrarModuleMixin)
-        )
-        for module in modules:
-            yield from module._nshtrainer_callbacks
-
-    @final
-    @override
-    def configure_callbacks(self):
-        callbacks = super().configure_callbacks()
-        if not isinstance(callbacks, Sequence):
-            callbacks = [callbacks]
-
-        callbacks = list(callbacks)
-        for callback_fn in self._nshtrainer_gather_all_callbacks():
-            callback_result = callback_fn()
-            if callback_result is None:
-                continue
-
-            if not isinstance(callback_result, Iterable):
-                callback_result = [callback_result]
-
-            for callback in callback_result:
-                log.info(
-                    f"Registering {callback.__class__.__qualname__} callback {callback}"
-                )
-                callbacks.append(callback)
-
-        return callbacks
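This removal retires the imperative register_callback API and the LambdaCallback-based hook registration built on it; under 0.32.0 the config-driven route shown above (callback configs yielded by TrainerConfig and instantiated via create_callbacks) takes its place. Modules that still need ad-hoc callbacks can use Lightning's standard hook directly; a sketch with placeholder names (MyCallback and MyModule are illustrative, not part of nshtrainer):

    from lightning.pytorch import LightningModule
    from lightning.pytorch.callbacks import Callback

    class MyCallback(Callback):  # placeholder for an ad-hoc callback
        ...

    class MyModule(LightningModule):  # placeholder module
        def configure_callbacks(self):
            # Standard Lightning hook; the removed mixin previously funneled
            # register_callback registrations through an override of this method.
            return [MyCallback()]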
{nshtrainer-0.31.0.dist-info → nshtrainer-0.32.0.dist-info}/WHEEL
RENAMED
File without changes