nshtrainer 0.31.0__tar.gz → 0.32.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/PKG-INFO +1 -1
  2. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/pyproject.toml +1 -1
  3. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/__init__.py +4 -1
  4. nshtrainer-0.32.0/src/nshtrainer/callbacks/debug_flag.py +72 -0
  5. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/model/base.py +2 -40
  6. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/trainer/_config.py +9 -0
  7. nshtrainer-0.31.0/src/nshtrainer/model/mixins/callback.py +0 -206
  8. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/README.md +0 -0
  9. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/__init__.py +0 -0
  10. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/_callback.py +0 -0
  11. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
  12. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  13. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  14. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/_directory.py +0 -0
  15. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  16. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/_hf_hub.py +0 -0
  17. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  18. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  19. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/base.py +0 -0
  20. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  21. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  22. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  23. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  24. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  25. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  26. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  27. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/ema.py +0 -0
  28. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  29. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  30. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/interval.py +0 -0
  31. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  32. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  33. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  34. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  35. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  36. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  37. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/timer.py +0 -0
  38. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  39. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/config.py +0 -0
  40. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/data/__init__.py +0 -0
  41. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  42. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/data/transform.py +0 -0
  43. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/__init__.py +0 -0
  44. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/_experimental.py +0 -0
  45. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/actsave.py +0 -0
  46. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/callbacks.py +0 -0
  47. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/config.py +0 -0
  48. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/data.py +0 -0
  49. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/log.py +0 -0
  50. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  51. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/model.py +0 -0
  52. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/nn.py +0 -0
  53. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/optimizer.py +0 -0
  54. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/runner.py +0 -0
  55. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/snapshot.py +0 -0
  56. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/snoop.py +0 -0
  57. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/trainer.py +0 -0
  58. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/typecheck.py +0 -0
  59. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/ll/util.py +0 -0
  60. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/loggers/__init__.py +0 -0
  61. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/loggers/_base.py +0 -0
  62. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/loggers/csv.py +0 -0
  63. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  64. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/loggers/wandb.py +0 -0
  65. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  66. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  67. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  68. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  69. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/metrics/__init__.py +0 -0
  70. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/metrics/_config.py +0 -0
  71. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/model/__init__.py +0 -0
  72. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/model/config.py +0 -0
  73. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/model/mixins/logger.py +0 -0
  74. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/nn/__init__.py +0 -0
  75. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/nn/mlp.py +0 -0
  76. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/nn/module_dict.py +0 -0
  77. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/nn/module_list.py +0 -0
  78. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  79. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/optimizer.py +0 -0
  80. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/profiler/__init__.py +0 -0
  81. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/profiler/_base.py +0 -0
  82. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/profiler/advanced.py +0 -0
  83. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/profiler/pytorch.py +0 -0
  84. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/profiler/simple.py +0 -0
  85. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/runner.py +0 -0
  86. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/scripts/find_packages.py +0 -0
  87. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/trainer/__init__.py +0 -0
  88. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  89. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  90. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  91. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/trainer/trainer.py +0 -0
  92. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/util/_environment_info.py +0 -0
  93. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/util/_useful_types.py +0 -0
  94. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/util/config/__init__.py +0 -0
  95. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/util/config/duration.py +0 -0
  96. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/util/environment.py +0 -0
  97. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/util/path.py +0 -0
  98. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/util/seed.py +0 -0
  99. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/util/slurm.py +0 -0
  100. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/util/typed.py +0 -0
  101. {nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.31.0 → nshtrainer-0.32.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.31.0
+Version: 0.32.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-0.31.0 → nshtrainer-0.32.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.31.0"
+version = "0.32.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
{nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/__init__.py
@@ -12,6 +12,8 @@ from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
 from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
+from .debug_flag import DebugFlagCallback as DebugFlagCallback
+from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
 from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
 from .directory_setup import DirectorySetupConfig as DirectorySetupConfig
 from .early_stopping import EarlyStopping as EarlyStopping
@@ -41,7 +43,8 @@ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
 from .wandb_watch import WandbWatchConfig as WandbWatchConfig

 CallbackConfig = Annotated[
-    EarlyStoppingConfig
+    DebugFlagCallbackConfig
+    | EarlyStoppingConfig
     | ThroughputMonitorConfig
     | EpochTimerConfig
     | PrintTableMetricsConfig
nshtrainer-0.32.0/src/nshtrainer/callbacks/debug_flag.py
@@ -0,0 +1,72 @@
+import logging
+from typing import TYPE_CHECKING, Literal, cast
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from typing_extensions import override
+
+from nshtrainer.model.config import BaseConfig
+
+from .base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class DebugFlagCallbackConfig(CallbackConfigBase):
+    name: Literal["debug_flag"] = "debug_flag"
+
+    enabled: bool = True
+    """Whether to enable the callback."""
+
+    def __bool__(self):
+        return self.enabled
+
+    @override
+    def create_callbacks(self, root_config):
+        if not self:
+            return
+
+        yield DebugFlagCallback(self)
+
+
+class DebugFlagCallback(Callback):
+    """
+    Sets the debug flag to true in the following circumstances:
+    - fast_dev_run is enabled
+    - sanity check is running
+    """
+
+    @override
+    def __init__(self, config: DebugFlagCallbackConfig):
+        super().__init__()
+
+        self.config = config
+        del config
+
+    @override
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
+        if not getattr(trainer, "fast_dev_run", False):
+            return
+
+        hparams = cast("BaseConfig", pl_module.hparams)
+        if not hparams.debug:
+            log.critical("Fast dev run detected, setting debug flag to True.")
+            hparams.debug = True
+
+    @override
+    def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
+        hparams = cast("BaseConfig", pl_module.hparams)
+        self._debug = hparams.debug
+        if not self._debug:
+            log.critical("Enabling debug flag during sanity check routine.")
+            hparams.debug = True
+
+    @override
+    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
+        hparams = cast("BaseConfig", pl_module.hparams)
+        if not self._debug:
+            log.critical("Sanity check routine complete, disabling debug flag.")
+            hparams.debug = self._debug
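The new module makes the debug-flag behavior config-driven: DebugFlagCallbackConfig carries a name literal (its tag in the CallbackConfig union above) plus an enabled switch, and create_callbacks yields the callback only when enabled. A minimal usage sketch follows; constructing callbacks by hand like this is purely illustrative, and passing None for root_config only works because this particular config ignores it:

    from nshtrainer.callbacks import DebugFlagCallback, DebugFlagCallbackConfig

    # Disabled config: __bool__ is False, so create_callbacks() yields nothing.
    disabled = DebugFlagCallbackConfig(enabled=False)
    assert not disabled
    assert list(disabled.create_callbacks(None)) == []

    # Default config (enabled=True): yields exactly one DebugFlagCallback.
    (callback,) = DebugFlagCallbackConfig().create_callbacks(None)
    assert isinstance(callback, DebugFlagCallback)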
{nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/model/base.py
@@ -7,8 +7,7 @@ from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast
 import torch
 import torch.distributed
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
-from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.callbacks import Callback
+from lightning.pytorch import LightningModule
 from lightning.pytorch.profilers import PassThroughProfiler, Profiler
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from typing_extensions import Self, TypeVar, override
@@ -16,7 +15,6 @@ from typing_extensions import Self, TypeVar, override
 from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
 from ..util._environment_info import EnvironmentConfig
 from .config import BaseConfig
-from .mixins.callback import CallbackModuleMixin
 from .mixins.logger import LoggerLightningModuleMixin

 log = logging.getLogger(__name__)
@@ -24,39 +22,6 @@ log = logging.getLogger(__name__)
 THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)


-class DebugFlagCallback(Callback):
-    """
-    Sets the debug flag to true in the following circumstances:
-    - fast_dev_run is enabled
-    - sanity check is running
-    """
-
-    @override
-    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
-        if not getattr(trainer, "fast_dev_run", False):
-            return
-
-        hparams = cast(BaseConfig, pl_module.hparams)
-        if not hparams.debug:
-            log.critical("Fast dev run detected, setting debug flag to True.")
-            hparams.debug = True
-
-    @override
-    def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
-        hparams = cast(BaseConfig, pl_module.hparams)
-        self._debug = hparams.debug
-        if not self._debug:
-            log.critical("Enabling debug flag during sanity check routine.")
-            hparams.debug = True
-
-    @override
-    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
-        hparams = cast(BaseConfig, pl_module.hparams)
-        if not self._debug:
-            log.critical("Sanity check routine complete, disabling debug flag.")
-            hparams.debug = self._debug
-
-
 T = TypeVar("T", infer_variance=True)

 ReduceOpStr = Literal[
@@ -88,7 +53,6 @@ VALID_REDUCE_OPS = (
 class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
     _RLPSanityCheckModuleMixin,
     LoggerLightningModuleMixin,
-    CallbackModuleMixin,
     LightningModule,
     ABC,
     Generic[THparams],
@@ -288,10 +252,8 @@ class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
         hparams.environment = EnvironmentConfig.from_current_environment(hparams, self)
         hparams = self.pre_init_update_hparams(hparams)

-        super().__init__(hparams)
-
+        super().__init__()
         self.save_hyperparameters(hparams)
-        self.register_callback(lambda: DebugFlagCallback())

     def zero_loss(self):
         """
{nshtrainer-0.31.0 → nshtrainer-0.32.0}/src/nshtrainer/trainer/_config.py
@@ -35,6 +35,7 @@ from ..callbacks import (
     OnExceptionCheckpointCallbackConfig,
 )
 from ..callbacks.base import CallbackConfigBase
+from ..callbacks.debug_flag import DebugFlagCallbackConfig
 from ..callbacks.rlp_sanity_checks import RLPSanityChecksConfig
 from ..callbacks.shared_parameters import SharedParametersConfig
 from ..loggers import (
@@ -751,6 +752,11 @@ class TrainerConfig(C.Config):
     """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
     save_checkpoint_metadata: bool = True
     """If enabled, will save additional metadata whenever a checkpoint is saved."""
+    auto_set_debug_flag: DebugFlagCallbackConfig | None = DebugFlagCallbackConfig()
+    """If enabled, will automatically set the debug flag to True if:
+    - The trainer is running in fast_dev_run mode.
+    - The trainer is running a sanity check (which happens before starting the training routine).
+    """

     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
@@ -775,4 +781,7 @@ class TrainerConfig(C.Config):
         yield self.logging
         yield self.optimizer
         yield self.hf_hub
+        yield self.shared_parameters
+        yield self.reduce_lr_on_plateau_sanity_checking
+        yield self.auto_set_debug_flag
         yield from self.callbacks
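The three added yields route these configs through the same collection path as the explicit callbacks list. A sketch of how such a generator is typically consumed; the helper below is an assumption, since only the yield statements appear in the hunk:

    # Hypothetical consumer: each yielded entry is None or a CallbackConfigBase
    # whose create_callbacks() produces zero or more Lightning callbacks.
    def collect_callbacks(root_config, callback_configs):
        callbacks = []
        for config in callback_configs:
            if config is None:  # e.g. auto_set_debug_flag=None
                continue
            callbacks.extend(config.create_callbacks(root_config))
        return callbacks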
nshtrainer-0.31.0/src/nshtrainer/model/mixins/callback.py
@@ -1,206 +0,0 @@
-import logging
-from collections.abc import Callable, Iterable, Sequence
-from typing import Any, TypeAlias, cast, final, overload
-
-from lightning.pytorch import Callback, LightningModule
-from lightning.pytorch.callbacks import LambdaCallback
-from typing_extensions import override
-
-from ...util.typing_utils import mixin_base_type
-
-log = logging.getLogger(__name__)
-
-CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
-
-
-class CallbackRegistrarModuleMixin:
-    @override
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self._nshtrainer_callbacks: list[CallbackFn] = []
-
-    @overload
-    def register_callback(
-        self, callback: Callback | Iterable[Callback] | CallbackFn | None = None, /
-    ): ...
-
-    @overload
-    def register_callback(
-        self,
-        /,
-        *,
-        setup: Callable | None = None,
-        teardown: Callable | None = None,
-        on_fit_start: Callable | None = None,
-        on_fit_end: Callable | None = None,
-        on_sanity_check_start: Callable | None = None,
-        on_sanity_check_end: Callable | None = None,
-        on_train_batch_start: Callable | None = None,
-        on_train_batch_end: Callable | None = None,
-        on_train_epoch_start: Callable | None = None,
-        on_train_epoch_end: Callable | None = None,
-        on_validation_epoch_start: Callable | None = None,
-        on_validation_epoch_end: Callable | None = None,
-        on_test_epoch_start: Callable | None = None,
-        on_test_epoch_end: Callable | None = None,
-        on_validation_batch_start: Callable | None = None,
-        on_validation_batch_end: Callable | None = None,
-        on_test_batch_start: Callable | None = None,
-        on_test_batch_end: Callable | None = None,
-        on_train_start: Callable | None = None,
-        on_train_end: Callable | None = None,
-        on_validation_start: Callable | None = None,
-        on_validation_end: Callable | None = None,
-        on_test_start: Callable | None = None,
-        on_test_end: Callable | None = None,
-        on_exception: Callable | None = None,
-        on_save_checkpoint: Callable | None = None,
-        on_load_checkpoint: Callable | None = None,
-        on_before_backward: Callable | None = None,
-        on_after_backward: Callable | None = None,
-        on_before_optimizer_step: Callable | None = None,
-        on_before_zero_grad: Callable | None = None,
-        on_predict_start: Callable | None = None,
-        on_predict_end: Callable | None = None,
-        on_predict_batch_start: Callable | None = None,
-        on_predict_batch_end: Callable | None = None,
-        on_predict_epoch_start: Callable | None = None,
-        on_predict_epoch_end: Callable | None = None,
-    ): ...
-
-    def register_callback(
-        self,
-        callback: Callback | Iterable[Callback] | CallbackFn | None = None,
-        /,
-        *,
-        setup: Callable | None = None,
-        teardown: Callable | None = None,
-        on_fit_start: Callable | None = None,
-        on_fit_end: Callable | None = None,
-        on_sanity_check_start: Callable | None = None,
-        on_sanity_check_end: Callable | None = None,
-        on_train_batch_start: Callable | None = None,
-        on_train_batch_end: Callable | None = None,
-        on_train_epoch_start: Callable | None = None,
-        on_train_epoch_end: Callable | None = None,
-        on_validation_epoch_start: Callable | None = None,
-        on_validation_epoch_end: Callable | None = None,
-        on_test_epoch_start: Callable | None = None,
-        on_test_epoch_end: Callable | None = None,
-        on_validation_batch_start: Callable | None = None,
-        on_validation_batch_end: Callable | None = None,
-        on_test_batch_start: Callable | None = None,
-        on_test_batch_end: Callable | None = None,
-        on_train_start: Callable | None = None,
-        on_train_end: Callable | None = None,
-        on_validation_start: Callable | None = None,
-        on_validation_end: Callable | None = None,
-        on_test_start: Callable | None = None,
-        on_test_end: Callable | None = None,
-        on_exception: Callable | None = None,
-        on_save_checkpoint: Callable | None = None,
-        on_load_checkpoint: Callable | None = None,
-        on_before_backward: Callable | None = None,
-        on_after_backward: Callable | None = None,
-        on_before_optimizer_step: Callable | None = None,
-        on_before_zero_grad: Callable | None = None,
-        on_predict_start: Callable | None = None,
-        on_predict_end: Callable | None = None,
-        on_predict_batch_start: Callable | None = None,
-        on_predict_batch_end: Callable | None = None,
-        on_predict_epoch_start: Callable | None = None,
-        on_predict_epoch_end: Callable | None = None,
-    ):
-        if callback is None:
-            callback = LambdaCallback(
-                setup=setup,
-                teardown=teardown,
-                on_fit_start=on_fit_start,
-                on_fit_end=on_fit_end,
-                on_sanity_check_start=on_sanity_check_start,
-                on_sanity_check_end=on_sanity_check_end,
-                on_train_batch_start=on_train_batch_start,
-                on_train_batch_end=on_train_batch_end,
-                on_train_epoch_start=on_train_epoch_start,
-                on_train_epoch_end=on_train_epoch_end,
-                on_validation_epoch_start=on_validation_epoch_start,
-                on_validation_epoch_end=on_validation_epoch_end,
-                on_test_epoch_start=on_test_epoch_start,
-                on_test_epoch_end=on_test_epoch_end,
-                on_validation_batch_start=on_validation_batch_start,
-                on_validation_batch_end=on_validation_batch_end,
-                on_test_batch_start=on_test_batch_start,
-                on_test_batch_end=on_test_batch_end,
-                on_train_start=on_train_start,
-                on_train_end=on_train_end,
-                on_validation_start=on_validation_start,
-                on_validation_end=on_validation_end,
-                on_test_start=on_test_start,
-                on_test_end=on_test_end,
-                on_exception=on_exception,
-                on_save_checkpoint=on_save_checkpoint,
-                on_load_checkpoint=on_load_checkpoint,
-                on_before_backward=on_before_backward,
-                on_after_backward=on_after_backward,
-                on_before_optimizer_step=on_before_optimizer_step,
-                on_before_zero_grad=on_before_zero_grad,
-                on_predict_start=on_predict_start,
-                on_predict_end=on_predict_end,
-                on_predict_batch_start=on_predict_batch_start,
-                on_predict_batch_end=on_predict_batch_end,
-                on_predict_epoch_start=on_predict_epoch_start,
-                on_predict_epoch_end=on_predict_epoch_end,
-            )
-
-        if not callable(callback):
-            callback_ = cast(CallbackFn, lambda: callback)
-        else:
-            callback_ = callback
-
-        self._nshtrainer_callbacks.append(callback_)
-
-
-class CallbackModuleMixin(
-    CallbackRegistrarModuleMixin,
-    mixin_base_type(LightningModule),
-):
-    def _nshtrainer_gather_all_callbacks(self):
-        modules: list[Any] = []
-        if isinstance(self, CallbackRegistrarModuleMixin):
-            modules.append(self)
-        if (
-            datamodule := getattr(self.trainer, "datamodule", None)
-        ) is not None and isinstance(datamodule, CallbackRegistrarModuleMixin):
-            modules.append(datamodule)
-        modules.extend(
-            module
-            for module in self.children()
-            if isinstance(module, CallbackRegistrarModuleMixin)
-        )
-        for module in modules:
-            yield from module._nshtrainer_callbacks
-
-    @final
-    @override
-    def configure_callbacks(self):
-        callbacks = super().configure_callbacks()
-        if not isinstance(callbacks, Sequence):
-            callbacks = [callbacks]
-
-        callbacks = list(callbacks)
-        for callback_fn in self._nshtrainer_gather_all_callbacks():
-            callback_result = callback_fn()
-            if callback_result is None:
-                continue
-
-            if not isinstance(callback_result, Iterable):
-                callback_result = [callback_result]
-
-            for callback in callback_result:
-                log.info(
-                    f"Registering {callback.__class__.__qualname__} callback {callback}"
-                )
-                callbacks.append(callback)
-
-        return callbacks
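For context, the removed mixin let a LightningModule (or its datamodule and child modules) queue callbacks imperatively, which the final configure_callbacks override then gathered for Lightning. A minimal sketch of the now-removed API; the subclass is hypothetical, and DebugFlagCallback here is the old zero-argument version deleted from model/base.py above:

    # Pre-0.32 pattern, removed in this release:
    class MyModel(LightningModuleBase):  # hypothetical subclass
        def __init__(self, hparams):
            super().__init__(hparams)
            # Register a callback factory...
            self.register_callback(lambda: DebugFlagCallback())
            # ...or bare hooks, which get wrapped into a LambdaCallback.
            self.register_callback(on_train_start=lambda trainer, pl_module: None)

In 0.32 the same wiring is declarative instead: callback configs such as DebugFlagCallbackConfig are yielded from TrainerConfig and instantiated through create_callbacks.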