nshtrainer 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/callbacks/__init__.py CHANGED
@@ -12,6 +12,8 @@ from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
 from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
+from .debug_flag import DebugFlagCallback as DebugFlagCallback
+from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
 from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
 from .directory_setup import DirectorySetupConfig as DirectorySetupConfig
 from .early_stopping import EarlyStopping as EarlyStopping
@@ -41,7 +43,8 @@ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
 from .wandb_watch import WandbWatchConfig as WandbWatchConfig
 
 CallbackConfig = Annotated[
-    EarlyStoppingConfig
+    DebugFlagCallbackConfig
+    | EarlyStoppingConfig
     | ThroughputMonitorConfig
    | EpochTimerConfig
    | PrintTableMetricsConfig
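
Note: `DebugFlagCallbackConfig` now joins the `CallbackConfig` union. The discriminator wiring is outside this hunk, so the following is only a hypothetical sketch of how a `name`-discriminated union like this typically resolves (plain pydantic v2 shown; the package's own config base may differ):

    # Hypothetical sketch of a name-discriminated callback config union.
    from typing import Annotated, Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter


    class DebugFlagCallbackConfig(BaseModel):
        name: Literal["debug_flag"] = "debug_flag"
        enabled: bool = True


    class EarlyStoppingConfig(BaseModel):
        name: Literal["early_stopping"] = "early_stopping"


    CallbackConfig = Annotated[
        Union[DebugFlagCallbackConfig, EarlyStoppingConfig],
        Field(discriminator="name"),
    ]

    # The "name" literal routes parsing to the matching config class.
    config = TypeAdapter(CallbackConfig).validate_python({"name": "debug_flag"})
    assert isinstance(config, DebugFlagCallbackConfig)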
nshtrainer/callbacks/debug_flag.py ADDED
@@ -0,0 +1,70 @@
+import logging
+from typing import TYPE_CHECKING, Literal, cast
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from typing_extensions import override
+
+from .base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class DebugFlagCallbackConfig(CallbackConfigBase):
+    name: Literal["debug_flag"] = "debug_flag"
+
+    enabled: bool = True
+    """Whether to enable the callback."""
+
+    def __bool__(self):
+        return self.enabled
+
+    @override
+    def create_callbacks(self, root_config):
+        if not self:
+            return
+
+        yield DebugFlagCallback(self)
+
+
+class DebugFlagCallback(Callback):
+    """
+    Sets the debug flag to true in the following circumstances:
+    - fast_dev_run is enabled
+    - sanity check is running
+    """
+
+    @override
+    def __init__(self, config: DebugFlagCallbackConfig):
+        super().__init__()
+
+        self.config = config
+        del config
+
+    @override
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
+        if not getattr(trainer, "fast_dev_run", False):
+            return
+
+        hparams = cast("BaseConfig", pl_module.hparams)
+        if not hparams.debug:
+            log.critical("Fast dev run detected, setting debug flag to True.")
+            hparams.debug = True
+
+    @override
+    def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
+        hparams = cast("BaseConfig", pl_module.hparams)
+        self._debug = hparams.debug
+        if not self._debug:
+            log.critical("Enabling debug flag during sanity check routine.")
+            hparams.debug = True
+
+    @override
+    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
+        hparams = cast("BaseConfig", pl_module.hparams)
+        if not self._debug:
+            log.critical("Sanity check routine complete, disabling debug flag.")
+            hparams.debug = self._debug
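
Note: `create_callbacks` yields the callback only when the config is truthy, so a disabled config is a silent no-op. A small hypothetical usage sketch (the `root_config` argument is ignored by this particular override):

    # Hypothetical usage sketch of the new config/callback pair.
    from nshtrainer.callbacks.debug_flag import DebugFlagCallback, DebugFlagCallbackConfig

    config = DebugFlagCallbackConfig(enabled=True)
    callbacks = list(config.create_callbacks(None))
    assert len(callbacks) == 1 and isinstance(callbacks[0], DebugFlagCallback)

    disabled = DebugFlagCallbackConfig(enabled=False)
    assert not list(disabled.create_callbacks(None))  # __bool__ gates creation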
nshtrainer/config.py CHANGED
@@ -11,6 +11,7 @@ from nshtrainer.callbacks.checkpoint._base import BaseCheckpointCallbackConfig a
 from nshtrainer.callbacks.checkpoint.best_checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
 from nshtrainer.callbacks.checkpoint.last_checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
 from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig
+from nshtrainer.callbacks.debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
 from nshtrainer.callbacks.directory_setup import DirectorySetupConfig as DirectorySetupConfig
 from nshtrainer.callbacks.early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
 from nshtrainer.callbacks.ema import EMAConfig as EMAConfig
nshtrainer/loggers/wandb.py CHANGED
@@ -7,8 +7,8 @@ from lightning.pytorch import Callback, LightningModule, Trainer
 from packaging import version
 from typing_extensions import override
 
-from ..callbacks import WandbWatchConfig
 from ..callbacks.base import CallbackConfigBase
+from ..callbacks.wandb_watch import WandbWatchConfig
 from ._base import BaseLoggerConfig
 
 if TYPE_CHECKING:
nshtrainer/model/base.py CHANGED
@@ -7,8 +7,7 @@ from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast
 import torch
 import torch.distributed
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
-from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.callbacks import Callback
+from lightning.pytorch import LightningModule
 from lightning.pytorch.profilers import PassThroughProfiler, Profiler
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from typing_extensions import Self, TypeVar, override
@@ -16,7 +15,6 @@ from typing_extensions import Self, TypeVar, override
 from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
 from ..util._environment_info import EnvironmentConfig
 from .config import BaseConfig
-from .mixins.callback import CallbackModuleMixin
 from .mixins.logger import LoggerLightningModuleMixin
 
 log = logging.getLogger(__name__)
@@ -24,39 +22,6 @@ log = logging.getLogger(__name__)
 THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)
 
 
-class DebugFlagCallback(Callback):
-    """
-    Sets the debug flag to true in the following circumstances:
-    - fast_dev_run is enabled
-    - sanity check is running
-    """
-
-    @override
-    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
-        if not getattr(trainer, "fast_dev_run", False):
-            return
-
-        hparams = cast(BaseConfig, pl_module.hparams)
-        if not hparams.debug:
-            log.critical("Fast dev run detected, setting debug flag to True.")
-            hparams.debug = True
-
-    @override
-    def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
-        hparams = cast(BaseConfig, pl_module.hparams)
-        self._debug = hparams.debug
-        if not self._debug:
-            log.critical("Enabling debug flag during sanity check routine.")
-            hparams.debug = True
-
-    @override
-    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
-        hparams = cast(BaseConfig, pl_module.hparams)
-        if not self._debug:
-            log.critical("Sanity check routine complete, disabling debug flag.")
-            hparams.debug = self._debug
-
-
 T = TypeVar("T", infer_variance=True)
 
 ReduceOpStr = Literal[
@@ -88,7 +53,6 @@ VALID_REDUCE_OPS = (
 class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
     _RLPSanityCheckModuleMixin,
     LoggerLightningModuleMixin,
-    CallbackModuleMixin,
     LightningModule,
     ABC,
     Generic[THparams],
@@ -288,10 +252,8 @@ class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
         hparams.environment = EnvironmentConfig.from_current_environment(hparams, self)
         hparams = self.pre_init_update_hparams(hparams)
 
-        super().__init__(hparams)
-
+        super().__init__()
         self.save_hyperparameters(hparams)
-        self.register_callback(lambda: DebugFlagCallback())
 
     def zero_loss(self):
        """
nshtrainer/trainer/_config.py CHANGED
@@ -35,6 +35,7 @@ from ..callbacks import (
     OnExceptionCheckpointCallbackConfig,
 )
 from ..callbacks.base import CallbackConfigBase
+from ..callbacks.debug_flag import DebugFlagCallbackConfig
 from ..callbacks.rlp_sanity_checks import RLPSanityChecksConfig
 from ..callbacks.shared_parameters import SharedParametersConfig
 from ..loggers import (
@@ -751,6 +752,11 @@ class TrainerConfig(C.Config):
     """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
     save_checkpoint_metadata: bool = True
     """If enabled, will save additional metadata whenever a checkpoint is saved."""
+    auto_set_debug_flag: DebugFlagCallbackConfig | None = DebugFlagCallbackConfig()
+    """If enabled, will automatically set the debug flag to True if:
+    - The trainer is running in fast_dev_run mode.
+    - The trainer is running a sanity check (which happens before starting the training routine).
+    """
 
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
@@ -775,4 +781,7 @@ class TrainerConfig(C.Config):
         yield self.logging
         yield self.optimizer
         yield self.hf_hub
+        yield self.shared_parameters
+        yield self.reduce_lr_on_plateau_sanity_checking
+        yield self.auto_set_debug_flag
         yield from self.callbacks
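
Note: the new field defaults to an enabled config. A hypothetical sketch of opting out, assuming `TrainerConfig` accepts keyword construction like the package's other config models:

    # Hypothetical usage sketch; both forms disable the automatic debug flag.
    from nshtrainer.callbacks.debug_flag import DebugFlagCallbackConfig
    from nshtrainer.trainer._config import TrainerConfig

    config = TrainerConfig(auto_set_debug_flag=None)  # drop the callback entirely
    config = TrainerConfig(auto_set_debug_flag=DebugFlagCallbackConfig(enabled=False))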
{nshtrainer-0.31.0.dist-info → nshtrainer-0.32.1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.31.0
+Version: 0.32.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-0.31.0.dist-info → nshtrainer-0.32.1.dist-info}/RECORD RENAMED
@@ -6,7 +6,7 @@ nshtrainer/_checkpoint/saver.py,sha256=MbX_WjkDtHHAf9Ms-KXDlknkjiPXVoGIe2ciO28Ad
 nshtrainer/_directory.py,sha256=RjnW6vKTeKlz2vQWT3cG0Jje5BkFXA7HpUubDhcSiq4,2993
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
 nshtrainer/_hf_hub.py,sha256=0bkXkqhve5D1onMW-fCfuvVKlTn0i6jv_6uMNgZ7OHQ,12974
-nshtrainer/callbacks/__init__.py,sha256=NpEV8bMU12ClFN2sLKLBDXnuwIHYyZOCNxDZgjrV104,2892
+nshtrainer/callbacks/__init__.py,sha256=1SBLpMsx7BzgimO35MwQViYBcbgxlkyvTMz1JKUKK-0,3060
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
@@ -15,6 +15,7 @@ nshtrainer/callbacks/checkpoint/_base.py,sha256=vvlwuD-20NozYVIolGGShmUdkkNYeuwN
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
+nshtrainer/callbacks/debug_flag.py,sha256=T7rkY9hYQ_-PsPo2XiQ4eVZ9bBsTd2knpZWctCbjxXc,2011
 nshtrainer/callbacks/directory_setup.py,sha256=c0uY0oTqLcQ3egInHO7G6BeQQgk_xvOLoHH8FR-9U0U,2629
 nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
@@ -29,7 +30,7 @@ nshtrainer/callbacks/shared_parameters.py,sha256=fqlDweFDXPV_bfcAWpRgaJIad9i5Aeh
 nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
 nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
 nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
-nshtrainer/config.py,sha256=W6nAmn5Y1GVZto9vkx4v8i5XdikMSdVYDiq7kbDEWAg,5900
+nshtrainer/config.py,sha256=HJWKMFGNFHmuk92KlpYpEIhY01Ysnqr4HOWx4npGVH0,5995
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
 nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
@@ -54,7 +55,7 @@ nshtrainer/loggers/__init__.py,sha256=C_xk0A3_qKbNdTmzK85AgjRHFD3w-jPRS2ig-iPhfE
 nshtrainer/loggers/_base.py,sha256=xiZKEK0ALJkcqf4OpVNRY0QbZsamR_WR7x7m_68YHXQ,705
 nshtrainer/loggers/csv.py,sha256=D_lYyd94bZ8jAgnRo-ARtFgVcInaD9zktxtsUD9RWCI,1052
 nshtrainer/loggers/tensorboard.py,sha256=wL2amRSdP68zbslZvBeM0ZQBnjF3hIKsz-_lBbdomaM,2216
-nshtrainer/loggers/wandb.py,sha256=8B2BMMzILRSUEiCkmp_fBpcXs69euRKViTiaV__DJZk,5128
+nshtrainer/loggers/wandb.py,sha256=C-yGX9e2FUSfbUxur7-meNUjpB3D8hIdVCOgPzGm3QM,5140
 nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
 nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
 nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=YQm84Sb4SWrofpBwa39DCslJvu2uorjbpWaGWyys1l4,5352
@@ -62,9 +63,8 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-
 nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
 nshtrainer/model/__init__.py,sha256=2i_VEy6u_Y1LUGKljHXWeekvhnUcanZM2QyaaBM1Bmw,261
-nshtrainer/model/base.py,sha256=hT27FtzwKQiEL0C8RcaTKYXlanfvzTxHOJpHUcWiItk,19891
+nshtrainer/model/base.py,sha256=1zVY8ybZTzVKhpp7sUC0t360Ut3YmdGxAW5PZAIBSyw,18535
 nshtrainer/model/config.py,sha256=Q4Wong6w3cp_Sq7s8iZdABKF-LZBbSCFn_TQPYkhkrI,6572
-nshtrainer/model/mixins/callback.py,sha256=lh3imlw1H3ESIG4WFA5frooSlWi6-RPUUDRFGRzEg4A,8571
 nshtrainer/model/mixins/logger.py,sha256=xOymSTofukEYZGkGojXsMEO__ZlBI5lIPZVmlotMEX8,5291
 nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
 nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
@@ -80,7 +80,7 @@ nshtrainer/profiler/simple.py,sha256=MbMfsJvligd0mtGiltxJ0T8MQVDP9T9BzQZFwswl66Y
 nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
 nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
-nshtrainer/trainer/_config.py,sha256=2SgO1L5VBfWQ5g7Dg2dTx_vq2_Wo7dTqt2A4GlQaGo0,28673
+nshtrainer/trainer/_config.py,sha256=ZIodM5Ek1lpkWFhQ_VfmKR7q1mZFFwtjfx8FH72H8WM,29174
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
 nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
 nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
@@ -95,6 +95,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.31.0.dist-info/METADATA,sha256=99b-8IvPlMmTrjyb5EK1kKsgKj8lWhGw4gZvM5sKyzc,916
-nshtrainer-0.31.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.31.0.dist-info/RECORD,,
+nshtrainer-0.32.1.dist-info/METADATA,sha256=zGSKc6CY965hgKixgUgeAHv3VvIOsDJ4NdCeDIKTTAs,916
+nshtrainer-0.32.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.32.1.dist-info/RECORD,,
nshtrainer/model/mixins/callback.py DELETED
@@ -1,206 +0,0 @@
-import logging
-from collections.abc import Callable, Iterable, Sequence
-from typing import Any, TypeAlias, cast, final, overload
-
-from lightning.pytorch import Callback, LightningModule
-from lightning.pytorch.callbacks import LambdaCallback
-from typing_extensions import override
-
-from ...util.typing_utils import mixin_base_type
-
-log = logging.getLogger(__name__)
-
-CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
-
-
-class CallbackRegistrarModuleMixin:
-    @override
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self._nshtrainer_callbacks: list[CallbackFn] = []
-
-    @overload
-    def register_callback(
-        self, callback: Callback | Iterable[Callback] | CallbackFn | None = None, /
-    ): ...
-
-    @overload
-    def register_callback(
-        self,
-        /,
-        *,
-        setup: Callable | None = None,
-        teardown: Callable | None = None,
-        on_fit_start: Callable | None = None,
-        on_fit_end: Callable | None = None,
-        on_sanity_check_start: Callable | None = None,
-        on_sanity_check_end: Callable | None = None,
-        on_train_batch_start: Callable | None = None,
-        on_train_batch_end: Callable | None = None,
-        on_train_epoch_start: Callable | None = None,
-        on_train_epoch_end: Callable | None = None,
-        on_validation_epoch_start: Callable | None = None,
-        on_validation_epoch_end: Callable | None = None,
-        on_test_epoch_start: Callable | None = None,
-        on_test_epoch_end: Callable | None = None,
-        on_validation_batch_start: Callable | None = None,
-        on_validation_batch_end: Callable | None = None,
-        on_test_batch_start: Callable | None = None,
-        on_test_batch_end: Callable | None = None,
-        on_train_start: Callable | None = None,
-        on_train_end: Callable | None = None,
-        on_validation_start: Callable | None = None,
-        on_validation_end: Callable | None = None,
-        on_test_start: Callable | None = None,
-        on_test_end: Callable | None = None,
-        on_exception: Callable | None = None,
-        on_save_checkpoint: Callable | None = None,
-        on_load_checkpoint: Callable | None = None,
-        on_before_backward: Callable | None = None,
-        on_after_backward: Callable | None = None,
-        on_before_optimizer_step: Callable | None = None,
-        on_before_zero_grad: Callable | None = None,
-        on_predict_start: Callable | None = None,
-        on_predict_end: Callable | None = None,
-        on_predict_batch_start: Callable | None = None,
-        on_predict_batch_end: Callable | None = None,
-        on_predict_epoch_start: Callable | None = None,
-        on_predict_epoch_end: Callable | None = None,
-    ): ...
-
-    def register_callback(
-        self,
-        callback: Callback | Iterable[Callback] | CallbackFn | None = None,
-        /,
-        *,
-        setup: Callable | None = None,
-        teardown: Callable | None = None,
-        on_fit_start: Callable | None = None,
-        on_fit_end: Callable | None = None,
-        on_sanity_check_start: Callable | None = None,
-        on_sanity_check_end: Callable | None = None,
-        on_train_batch_start: Callable | None = None,
-        on_train_batch_end: Callable | None = None,
-        on_train_epoch_start: Callable | None = None,
-        on_train_epoch_end: Callable | None = None,
-        on_validation_epoch_start: Callable | None = None,
-        on_validation_epoch_end: Callable | None = None,
-        on_test_epoch_start: Callable | None = None,
-        on_test_epoch_end: Callable | None = None,
-        on_validation_batch_start: Callable | None = None,
-        on_validation_batch_end: Callable | None = None,
-        on_test_batch_start: Callable | None = None,
-        on_test_batch_end: Callable | None = None,
-        on_train_start: Callable | None = None,
-        on_train_end: Callable | None = None,
-        on_validation_start: Callable | None = None,
-        on_validation_end: Callable | None = None,
-        on_test_start: Callable | None = None,
-        on_test_end: Callable | None = None,
-        on_exception: Callable | None = None,
-        on_save_checkpoint: Callable | None = None,
-        on_load_checkpoint: Callable | None = None,
-        on_before_backward: Callable | None = None,
-        on_after_backward: Callable | None = None,
-        on_before_optimizer_step: Callable | None = None,
-        on_before_zero_grad: Callable | None = None,
-        on_predict_start: Callable | None = None,
-        on_predict_end: Callable | None = None,
-        on_predict_batch_start: Callable | None = None,
-        on_predict_batch_end: Callable | None = None,
-        on_predict_epoch_start: Callable | None = None,
-        on_predict_epoch_end: Callable | None = None,
-    ):
-        if callback is None:
-            callback = LambdaCallback(
-                setup=setup,
-                teardown=teardown,
-                on_fit_start=on_fit_start,
-                on_fit_end=on_fit_end,
-                on_sanity_check_start=on_sanity_check_start,
-                on_sanity_check_end=on_sanity_check_end,
-                on_train_batch_start=on_train_batch_start,
-                on_train_batch_end=on_train_batch_end,
-                on_train_epoch_start=on_train_epoch_start,
-                on_train_epoch_end=on_train_epoch_end,
-                on_validation_epoch_start=on_validation_epoch_start,
-                on_validation_epoch_end=on_validation_epoch_end,
-                on_test_epoch_start=on_test_epoch_start,
-                on_test_epoch_end=on_test_epoch_end,
-                on_validation_batch_start=on_validation_batch_start,
-                on_validation_batch_end=on_validation_batch_end,
-                on_test_batch_start=on_test_batch_start,
-                on_test_batch_end=on_test_batch_end,
-                on_train_start=on_train_start,
-                on_train_end=on_train_end,
-                on_validation_start=on_validation_start,
-                on_validation_end=on_validation_end,
-                on_test_start=on_test_start,
-                on_test_end=on_test_end,
-                on_exception=on_exception,
-                on_save_checkpoint=on_save_checkpoint,
-                on_load_checkpoint=on_load_checkpoint,
-                on_before_backward=on_before_backward,
-                on_after_backward=on_after_backward,
-                on_before_optimizer_step=on_before_optimizer_step,
-                on_before_zero_grad=on_before_zero_grad,
-                on_predict_start=on_predict_start,
-                on_predict_end=on_predict_end,
-                on_predict_batch_start=on_predict_batch_start,
-                on_predict_batch_end=on_predict_batch_end,
-                on_predict_epoch_start=on_predict_epoch_start,
-                on_predict_epoch_end=on_predict_epoch_end,
-            )
-
-        if not callable(callback):
-            callback_ = cast(CallbackFn, lambda: callback)
-        else:
-            callback_ = callback
-
-        self._nshtrainer_callbacks.append(callback_)
-
-
-class CallbackModuleMixin(
-    CallbackRegistrarModuleMixin,
-    mixin_base_type(LightningModule),
-):
-    def _nshtrainer_gather_all_callbacks(self):
-        modules: list[Any] = []
-        if isinstance(self, CallbackRegistrarModuleMixin):
-            modules.append(self)
-        if (
-            datamodule := getattr(self.trainer, "datamodule", None)
-        ) is not None and isinstance(datamodule, CallbackRegistrarModuleMixin):
-            modules.append(datamodule)
-        modules.extend(
-            module
-            for module in self.children()
-            if isinstance(module, CallbackRegistrarModuleMixin)
-        )
-        for module in modules:
-            yield from module._nshtrainer_callbacks
-
-    @final
-    @override
-    def configure_callbacks(self):
-        callbacks = super().configure_callbacks()
-        if not isinstance(callbacks, Sequence):
-            callbacks = [callbacks]
-
-        callbacks = list(callbacks)
-        for callback_fn in self._nshtrainer_gather_all_callbacks():
-            callback_result = callback_fn()
-            if callback_result is None:
-                continue
-
-            if not isinstance(callback_result, Iterable):
-                callback_result = [callback_result]
-
-            for callback in callback_result:
-                log.info(
-                    f"Registering {callback.__class__.__qualname__} callback {callback}"
-                )
-                callbacks.append(callback)
-
-        return callbacks
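
Note: with `register_callback` and `CallbackModuleMixin` removed in 0.32.1, per-module callbacks fall back to Lightning's built-in hook, which the deleted mixin had wrapped. A hypothetical migration sketch (`MyCallback` and `MyModule` are illustrative names, not part of this package):

    # Hypothetical migration sketch using Lightning's native configure_callbacks.
    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.callbacks import Callback


    class MyCallback(Callback):  # placeholder callback for illustration
        def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
            print("fit started")


    class MyModule(LightningModule):
        # Previously: self.register_callback(lambda: MyCallback()) in __init__.
        def configure_callbacks(self) -> list[Callback]:
            return [MyCallback()]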