nshtrainer 0.34.2__tar.gz → 0.35.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/PKG-INFO +1 -1
  2. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/pyproject.toml +1 -1
  3. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/model/base.py +2 -0
  4. nshtrainer-0.35.1/src/nshtrainer/model/mixins/callback.py +74 -0
  5. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/README.md +0 -0
  6. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/__init__.py +0 -0
  7. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/_callback.py +0 -0
  8. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/_checkpoint/loader.py +0 -0
  9. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  10. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/_checkpoint/saver.py +0 -0
  11. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/_directory.py +0 -0
  12. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/_experimental/__init__.py +0 -0
  13. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/_hf_hub.py +0 -0
  14. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/__init__.py +0 -0
  15. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  16. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/actsave.py +0 -0
  17. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/base.py +0 -0
  18. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  19. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  20. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  21. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  22. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  23. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  24. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  25. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  26. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/ema.py +0 -0
  27. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  28. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  29. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/interval.py +0 -0
  30. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  31. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  32. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/print_table.py +0 -0
  33. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  34. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  35. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  36. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/timer.py +0 -0
  37. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  38. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/config.py +0 -0
  39. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/data/__init__.py +0 -0
  40. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  41. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/data/transform.py +0 -0
  42. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/__init__.py +0 -0
  43. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/_experimental.py +0 -0
  44. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/actsave.py +0 -0
  45. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/callbacks.py +0 -0
  46. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/config.py +0 -0
  47. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/data.py +0 -0
  48. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/log.py +0 -0
  49. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  50. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/model.py +0 -0
  51. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/nn.py +0 -0
  52. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/optimizer.py +0 -0
  53. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/runner.py +0 -0
  54. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/snapshot.py +0 -0
  55. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/snoop.py +0 -0
  56. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/trainer.py +0 -0
  57. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/typecheck.py +0 -0
  58. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/ll/util.py +0 -0
  59. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/loggers/__init__.py +0 -0
  60. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/loggers/_base.py +0 -0
  61. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/loggers/csv.py +0 -0
  62. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/loggers/tensorboard.py +0 -0
  63. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/loggers/wandb.py +0 -0
  64. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  65. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  66. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  67. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  68. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/metrics/__init__.py +0 -0
  69. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/metrics/_config.py +0 -0
  70. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/model/__init__.py +0 -0
  71. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/model/config.py +0 -0
  72. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/model/mixins/logger.py +0 -0
  73. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/nn/__init__.py +0 -0
  74. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/nn/mlp.py +0 -0
  75. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/nn/module_dict.py +0 -0
  76. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/nn/module_list.py +0 -0
  77. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
  78. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/optimizer.py +0 -0
  79. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/profiler/__init__.py +0 -0
  80. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/profiler/_base.py +0 -0
  81. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/profiler/advanced.py +0 -0
  82. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/profiler/pytorch.py +0 -0
  83. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/profiler/simple.py +0 -0
  84. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/runner.py +0 -0
  85. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/scripts/find_packages.py +0 -0
  86. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/trainer/__init__.py +0 -0
  87. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/trainer/_config.py +0 -0
  88. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  89. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  90. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
  91. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/trainer/trainer.py +0 -0
  92. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/_environment_info.py +0 -0
  93. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/_useful_types.py +0 -0
  94. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/bf16.py +0 -0
  95. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/config/__init__.py +0 -0
  96. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/config/dtype.py +0 -0
  97. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/config/duration.py +0 -0
  98. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/environment.py +0 -0
  99. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/path.py +0 -0
  100. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/seed.py +0 -0
  101. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/slurm.py +0 -0
  102. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/typed.py +0 -0
  103. {nshtrainer-0.34.2 → nshtrainer-0.35.1}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.34.2
3
+ Version: 0.35.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.34.2"
3
+ version = "0.35.1"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -15,6 +15,7 @@ from typing_extensions import Self, TypeVar, override
15
15
  from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
16
16
  from ..util._environment_info import EnvironmentConfig
17
17
  from .config import BaseConfig
18
+ from .mixins.callback import CallbackModuleMixin
18
19
  from .mixins.logger import LoggerLightningModuleMixin
19
20
 
20
21
  log = logging.getLogger(__name__)
@@ -53,6 +54,7 @@ VALID_REDUCE_OPS = (
53
54
  class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
54
55
  _RLPSanityCheckModuleMixin,
55
56
  LoggerLightningModuleMixin,
57
+ CallbackModuleMixin,
56
58
  LightningModule,
57
59
  ABC,
58
60
  Generic[THparams],
@@ -0,0 +1,74 @@
1
+ import logging
2
+ from collections.abc import Callable, Iterable, Sequence
3
+ from typing import Any, TypeAlias, cast, final
4
+
5
+ from lightning.pytorch import Callback, LightningModule
6
+ from typing_extensions import override
7
+
8
+ from ...util.typing_utils import mixin_base_type
9
+
10
log = logging.getLogger(__name__)

# A zero-argument factory producing a single callback, an iterable of
# callbacks, or ``None`` (meaning "nothing to add"). Entries stored by
# ``register_callback`` have this shape; ``configure_callbacks`` calls each
# one when gathering callbacks.
CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
13
+
14
+
15
class CallbackRegistrarModuleMixin:
    """Mixin that lets an object register Lightning callbacks or callback factories.

    Every registration is normalised to a zero-argument ``CallbackFn`` and
    appended to a per-instance list. ``CallbackModuleMixin.configure_callbacks``
    later calls each factory and collects the callbacks it returns.
    """

    @property
    def _nshtrainer_callbacks(self) -> list[CallbackFn]:
        """Per-instance list of registered callback factories, created on first access."""
        # Presumably created lazily (instead of in ``__init__``) so the mixin
        # does not depend on cooperative ``super().__init__()`` calls.
        if not hasattr(self, "_private_nshtrainer_callbacks_list"):
            self._private_nshtrainer_callbacks_list = []
        return self._private_nshtrainer_callbacks_list

    def register_callback(
        self,
        callback: Callback | Iterable[Callback] | CallbackFn | None = None,
    ) -> None:
        """Register a callback, several callbacks, or a callback factory.

        Args:
            callback: One of:

                - a ``Callback`` instance (always registered as-is, even if the
                  subclass happens to define ``__call__``);
                - an iterable of ``Callback`` instances (one-shot iterators such
                  as generators are snapshotted into a list so that every
                  gather sees the same callbacks);
                - a zero-argument callable returning a callback, an iterable of
                  callbacks, or ``None``; it is called on every gather, so it
                  may build fresh callback instances each time;
                - ``None``, which contributes nothing when callbacks are gathered.
        """
        from collections.abc import Iterator

        if callable(callback) and not isinstance(callback, Callback):
            # A factory: stored as-is and invoked each time callbacks are gathered.
            callback_fn = cast(CallbackFn, callback)
        else:
            static_value = callback
            if isinstance(static_value, Iterator) and not isinstance(
                static_value, Callback
            ):
                # An iterator would be exhausted after the first gather; freeze it.
                static_value = list(static_value)
            callback_fn = cast(CallbackFn, lambda: static_value)

        self._nshtrainer_callbacks.append(callback_fn)
32
+
33
+
34
class CallbackModuleMixin(
    CallbackRegistrarModuleMixin, mixin_base_type(LightningModule)
):
    """LightningModule mixin that appends registered callbacks in ``configure_callbacks``.

    Callbacks are collected from the sources below, but only when each is a
    ``CallbackRegistrarModuleMixin``:

    - this module;
    - the trainer's datamodule;
    - this module's direct child modules.
    """

    def _gather_all_callbacks(self) -> Iterable[CallbackFn]:
        """Yield every registered callback factory from all registrar sources.

        Order: this module, then ``self.trainer.datamodule`` (if it is a
        registrar), then each direct child from ``self.children()``. The child
        search is not recursive.
        """
        # ``self`` is always a registrar: this class inherits from
        # ``CallbackRegistrarModuleMixin``.
        modules: list[CallbackRegistrarModuleMixin] = [self]

        # ``isinstance(None, ...)`` is False, so a missing datamodule is skipped.
        datamodule = getattr(self.trainer, "datamodule", None)
        if isinstance(datamodule, CallbackRegistrarModuleMixin):
            modules.append(datamodule)

        modules.extend(
            child
            for child in self.children()
            if isinstance(child, CallbackRegistrarModuleMixin)
        )

        for module in modules:
            yield from module._nshtrainer_callbacks

    @final
    @override
    def configure_callbacks(self):
        """Return the parent's callbacks followed by every registered callback.

        Each registered factory is called on every invocation of this method.
        A factory returning ``None`` contributes nothing. A single ``Callback``
        result is treated as a one-element list; any other iterable result is
        expanded.
        """
        base_callbacks = super().configure_callbacks()
        callbacks = (
            list(base_callbacks)
            if isinstance(base_callbacks, Sequence)
            else [base_callbacks]
        )

        for callback_fn in self._gather_all_callbacks():
            if (callback_result := callback_fn()) is None:
                continue

            # Check for a Callback first so an iterable Callback subclass is not
            # mistakenly expanded (mirrors ``register_callback``).
            if isinstance(callback_result, Callback) or not isinstance(
                callback_result, Iterable
            ):
                callback_result = [callback_result]

            for callback in callback_result:
                log.info(
                    "Registering %s callback %s",
                    callback.__class__.__qualname__,
                    callback,
                )
                callbacks.append(callback)

        return callbacks
File without changes