nshtrainer 0.34.2__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/model/base.py CHANGED
@@ -15,6 +15,7 @@ from typing_extensions import Self, TypeVar, override
15
15
  from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
16
16
  from ..util._environment_info import EnvironmentConfig
17
17
  from .config import BaseConfig
18
+ from .mixins.callback import CallbackModuleMixin
18
19
  from .mixins.logger import LoggerLightningModuleMixin
19
20
 
20
21
  log = logging.getLogger(__name__)
@@ -53,6 +54,7 @@ VALID_REDUCE_OPS = (
53
54
  class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
54
55
  _RLPSanityCheckModuleMixin,
55
56
  LoggerLightningModuleMixin,
57
+ CallbackModuleMixin,
56
58
  LightningModule,
57
59
  ABC,
58
60
  Generic[THparams],
@@ -0,0 +1,74 @@
1
+ import logging
2
+ from collections.abc import Callable, Iterable, Sequence
3
+ from typing import Any, TypeAlias, cast, final
4
+
5
+ from lightning.pytorch import Callback, LightningModule
6
+ from typing_extensions import override
7
+
8
+ from ...util.typing_utils import mixin_base_type
9
+
10
# Logger for this module; used when registered callbacks are added.
log = logging.getLogger(__name__)

# A zero-argument factory that produces a single ``Callback``, an iterable of
# callbacks, or ``None`` when there is nothing to add.
CallbackFn: TypeAlias = Callable[
    [],
    Callback | Iterable[Callback] | None,
]

class CallbackRegistrarModuleMixin:
    """Mixin that lets an object accumulate callback factories.

    The backing storage is created lazily on first use, so this mixin does
    not need to take part in ``__init__`` of the classes it is combined with.
    """

    @property
    def _nshtrainer_callbacks(self) -> list[CallbackFn]:
        """The per-instance list of registered callback factories.

        It is created empty on first access. Later accesses return the same
        list object, so appending to it updates the registry in place.
        """
        try:
            return self._private_nshtrainer_callbacks_list
        except AttributeError:
            # First access on this instance: create the backing list.
            self._private_nshtrainer_callbacks_list = []
            return self._private_nshtrainer_callbacks_list

    def register_callback(
        self,
        callback: Callback | Iterable[Callback] | CallbackFn | None = None,
    ):
        """Record a callback (or several) to be supplied later.

        Args:
            callback: Either of the following.
                - A zero-argument callable returning a ``Callback``, an
                  iterable of callbacks, or ``None``. It is stored as-is.
                - A ``Callback``, an iterable of callbacks, or ``None``. It is
                  wrapped in a factory that returns that same object.

        Note:
            The two cases are told apart with ``callable()``. Any callable
            object is therefore treated as a factory and stored unwrapped.
        """
        if callable(callback):
            factory = callback
        else:
            # Put the value behind a factory so every registry entry is callable.
            factory = cast(CallbackFn, lambda: callback)
        self._nshtrainer_callbacks.append(factory)

class CallbackModuleMixin(
    CallbackRegistrarModuleMixin, mixin_base_type(LightningModule)
):
    """LightningModule mixin that appends registered callbacks to the hook.

    Callback factories are collected from three sources:
    - this module;
    - the trainer's datamodule, when it is a ``CallbackRegistrarModuleMixin``;
    - this module's direct child modules that are registrars.

    The factories are invoked inside ``configure_callbacks``.
    """

    def _gather_all_callbacks(self):
        """Yield every registered callback factory from all sources.

        Order: this module first, then the trainer's datamodule (if it is a
        registrar), then direct children in ``self.children()`` order. Only
        *direct* children are scanned; deeper submodules are not visited.
        """
        # ``self`` always qualifies because this class derives from the
        # registrar mixin, so no isinstance check is needed for it.
        modules: list[CallbackRegistrarModuleMixin] = [self]

        # NOTE(review): this reads ``self.trainer``, so it presumably must run
        # while a Trainer is attached. Confirm for any direct callers.
        # ``isinstance`` is False for ``None``, so a missing datamodule is skipped.
        datamodule = getattr(self.trainer, "datamodule", None)
        if isinstance(datamodule, CallbackRegistrarModuleMixin):
            modules.append(datamodule)

        modules.extend(
            child
            for child in self.children()
            if isinstance(child, CallbackRegistrarModuleMixin)
        )

        for module in modules:
            yield from module._nshtrainer_callbacks

    @final
    @override
    def configure_callbacks(self):
        """Return the base-class callbacks followed by all registered ones.

        The result of ``super().configure_callbacks()`` is normalized to a
        list; a non-sequence result is treated as a single callback. Then
        each factory from ``_gather_all_callbacks`` is called:
        - ``None`` results are skipped;
        - non-iterable results are treated as one callback;
        - iterable results are flattened one level.

        Every added callback is logged at INFO level. Factories are
        re-invoked each time this method runs. Marked ``@final``, so
        subclasses must not override it.
        """
        callbacks = super().configure_callbacks()
        if not isinstance(callbacks, Sequence):
            callbacks = [callbacks]
        callbacks = list(callbacks)

        for callback_fn in self._gather_all_callbacks():
            callback_result = callback_fn()
            if callback_result is None:
                continue

            if not isinstance(callback_result, Iterable):
                callback_result = [callback_result]

            for callback in callback_result:
                # Lazy %-style arguments: the message is only formatted when
                # INFO is enabled. The emitted text matches the old f-string.
                log.info(
                    "Registering %s callback %s",
                    callback.__class__.__qualname__,
                    callback,
                )
                callbacks.append(callback)

        return callbacks
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.34.2
3
+ Version: 0.35.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -63,8 +63,9 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-
63
63
  nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
64
64
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
65
65
  nshtrainer/model/__init__.py,sha256=2i_VEy6u_Y1LUGKljHXWeekvhnUcanZM2QyaaBM1Bmw,261
66
- nshtrainer/model/base.py,sha256=1zVY8ybZTzVKhpp7sUC0t360Ut3YmdGxAW5PZAIBSyw,18535
66
+ nshtrainer/model/base.py,sha256=NasbYZJBuEly6Hm9t9HVZk-CUHmy4T7p1v-Ye981XA4,18609
67
67
  nshtrainer/model/config.py,sha256=Q4Wong6w3cp_Sq7s8iZdABKF-LZBbSCFn_TQPYkhkrI,6572
68
+ nshtrainer/model/mixins/callback.py,sha256=rbe8P22iEjPkH1df6rfEo3Txw7EwSz6Dkm0TWO_AysM,2419
68
69
  nshtrainer/model/mixins/logger.py,sha256=xOymSTofukEYZGkGojXsMEO__ZlBI5lIPZVmlotMEX8,5291
69
70
  nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
70
71
  nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
@@ -97,6 +98,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
97
98
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
98
99
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
99
100
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
100
- nshtrainer-0.34.2.dist-info/METADATA,sha256=DQyYTUO0wpboH1gy3nSRJV6EsWCpY7Kb_ldD8v4BQFY,916
101
- nshtrainer-0.34.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
102
- nshtrainer-0.34.2.dist-info/RECORD,,
101
+ nshtrainer-0.35.1.dist-info/METADATA,sha256=LJUDrvicSUUgu2IfOlKoySGMoIrFMce0ryo4Fvrzwbs,916
102
+ nshtrainer-0.35.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
103
+ nshtrainer-0.35.1.dist-info/RECORD,,