nshtrainer 1.0.0b12__tar.gz → 1.0.0b13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/PKG-INFO +1 -1
  2. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/pyproject.toml +1 -1
  3. nshtrainer-1.0.0b13/src/nshtrainer/callbacks/lr_monitor.py +31 -0
  4. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/__init__.py +5 -13
  5. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/__init__.py +8 -0
  6. nshtrainer-1.0.0b13/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +31 -0
  7. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/__init__.py +19 -15
  8. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/_config/__init__.py +19 -15
  9. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/data/datamodule.py +0 -2
  10. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/_config.py +95 -146
  11. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/trainer.py +10 -13
  12. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/README.md +0 -0
  13. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/__init__.py +0 -0
  14. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_callback.py +0 -0
  15. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/loader.py +0 -0
  16. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  17. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/saver.py +0 -0
  18. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_directory.py +0 -0
  19. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_experimental/__init__.py +0 -0
  20. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_hf_hub.py +0 -0
  21. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/__init__.py +0 -0
  22. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/actsave.py +0 -0
  23. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/base.py +0 -0
  24. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  25. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  26. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  27. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  28. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  29. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  30. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  31. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  32. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/ema.py +0 -0
  33. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  34. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  35. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/interval.py +0 -0
  36. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  37. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  38. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/print_table.py +0 -0
  39. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  40. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  41. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/timer.py +0 -0
  42. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  43. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  44. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  45. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/loader/__init__.py +0 -0
  46. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  47. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  48. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  49. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  50. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  51. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  52. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  53. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  54. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  55. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  56. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  57. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  58. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  59. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  60. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  61. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  62. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  63. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  64. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  65. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  66. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  67. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  68. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  69. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  70. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  71. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/_base/__init__.py +0 -0
  72. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  73. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  74. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  75. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  76. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  77. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -0
  78. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  79. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  80. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  81. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  82. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/__init__.py +0 -0
  83. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  84. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  85. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  86. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  87. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  88. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  89. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  90. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  91. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/checkpoint_connector/__init__.py +0 -0
  92. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  93. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/__init__.py +0 -0
  94. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  95. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  96. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  97. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  98. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/data/__init__.py +0 -0
  99. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  100. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/data/transform.py +0 -0
  101. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/__init__.py +0 -0
  102. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/_base.py +0 -0
  103. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/actsave.py +0 -0
  104. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/csv.py +0 -0
  105. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/tensorboard.py +0 -0
  106. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/wandb.py +0 -0
  107. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  108. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  109. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  110. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  111. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/metrics/__init__.py +0 -0
  112. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/metrics/_config.py +0 -0
  113. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/model/__init__.py +0 -0
  114. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/model/base.py +0 -0
  115. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/callback.py +0 -0
  116. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/debug.py +0 -0
  117. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/logger.py +0 -0
  118. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/__init__.py +0 -0
  119. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/mlp.py +0 -0
  120. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/module_dict.py +0 -0
  121. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/module_list.py +0 -0
  122. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/nonlinearity.py +0 -0
  123. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/optimizer.py +0 -0
  124. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/__init__.py +0 -0
  125. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/_base.py +0 -0
  126. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/advanced.py +0 -0
  127. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/pytorch.py +0 -0
  128. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/simple.py +0 -0
  129. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/__init__.py +0 -0
  130. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  131. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  132. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/signal_connector.py +0 -0
  133. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/_environment_info.py +0 -0
  134. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/_useful_types.py +0 -0
  135. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/bf16.py +0 -0
  136. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/__init__.py +0 -0
  137. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/dtype.py +0 -0
  138. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/duration.py +0 -0
  139. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/environment.py +0 -0
  140. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/path.py +0 -0
  141. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/seed.py +0 -0
  142. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/slurm.py +0 -0
  143. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/typed.py +0 -0
  144. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 1.0.0b12
+Version: 1.0.0b13
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "1.0.0-beta12"
+version = "1.0.0-beta13"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from lightning.pytorch.callbacks import LearningRateMonitor
+
+from .base import CallbackConfigBase
+
+
+class LearningRateMonitorConfig(CallbackConfigBase):
+    logging_interval: Literal["step", "epoch"] | None = None
+    """
+    Set to 'epoch' or 'step' to log 'lr' of all optimizers at the same interval, set to None to log at individual interval according to the 'interval' key of each scheduler. Defaults to None.
+    """
+
+    log_momentum: bool = False
+    """
+    Option to also log the momentum values of the optimizer, if the optimizer has the 'momentum' or 'betas' attribute. Defaults to False.
+    """
+
+    log_weight_decay: bool = False
+    """
+    Option to also log the weight decay values of the optimizer. Defaults to False.
+    """
+
+    def create_callbacks(self, trainer_config):
+        yield LearningRateMonitor(
+            logging_interval=self.logging_interval,
+            log_momentum=self.log_momentum,
+            log_weight_decay=self.log_weight_decay,
+        )
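
The new LearningRateMonitorConfig slots into the flattened lr_monitor field that this release adds to TrainerConfig (see the trainer/_config.py diff below). A minimal usage sketch, assuming an already-constructed TrainerConfig instance named config; the surrounding trainer setup is omitted:

    from nshtrainer.callbacks.lr_monitor import LearningRateMonitorConfig

    # Log the learning rate once per optimizer step and also track momentum.
    config.lr_monitor = LearningRateMonitorConfig(
        logging_interval="step",
        log_momentum=True,
    )

    # The config yields a standard Lightning LearningRateMonitor callback.
    callbacks = list(config.lr_monitor.create_callbacks(config))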
@@ -132,10 +132,8 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         GradientClippingConfig as GradientClippingConfig,
     )
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
     from nshtrainer.trainer._config import (
-        ReproducibilityConfig as ReproducibilityConfig,
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
     )
     from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
     from nshtrainer.util._environment_info import (
@@ -325,6 +323,10 @@ else:
             ).LastCheckpointStrategyConfig
         if name == "LeakyReLUNonlinearityConfig":
             return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LinearWarmupCosineDecayLRSchedulerConfig":
             return importlib.import_module(
                 "nshtrainer.lr_scheduler"
@@ -333,8 +335,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MLPConfig":
             return importlib.import_module("nshtrainer.nn").MLPConfig
         if name == "MetricConfig":
@@ -349,10 +349,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).OnExceptionCheckpointCallbackConfig
-        if name == "OptimizationConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).OptimizationConfig
         if name == "OptimizerConfigBase":
             return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
         if name == "PReLUConfig":
@@ -373,10 +369,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.lr_scheduler"
             ).ReduceLROnPlateauConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -62,6 +62,9 @@ if TYPE_CHECKING:
         CheckpointMetadata as CheckpointMetadata,
     )
     from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
 else:
 
     def __getattr__(name):
@@ -115,6 +118,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.callbacks"
@@ -167,6 +174,7 @@ from . import ema as ema
 from . import finite_checks as finite_checks
 from . import gradient_skipping as gradient_skipping
 from . import log_epoch as log_epoch
+from . import lr_monitor as lr_monitor
 from . import norm_logging as norm_logging
 from . import print_table as print_table
 from . import rlp_sanity_checks as rlp_sanity_checks
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from typing import TYPE_CHECKING
+
+# Config/alias imports
+
+if TYPE_CHECKING:
+    from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfigBase
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
+else:
+
+    def __getattr__(name):
+        import importlib
+
+        if name in globals():
+            return globals()[name]
+        if name == "CallbackConfigBase":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).CallbackConfigBase
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).LearningRateMonitorConfig
+        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+# Submodule exports
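
As with the other generated modules under nshtrainer.configs, the re-export is lazy: the real nshtrainer.callbacks.lr_monitor module is only imported when the attribute is first accessed through the __getattr__ shim above. A short sketch of the downstream import this enables (behavior inferred from the shim, not stated elsewhere in the diff):

    # Resolved at attribute access time via the module-level __getattr__.
    from nshtrainer.configs.callbacks.lr_monitor import LearningRateMonitorConfig

    monitor = LearningRateMonitorConfig(logging_interval="epoch")
    print(monitor.log_momentum)  # False by default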
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from nshtrainer.trainer import TrainerConfig as TrainerConfig
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -39,20 +40,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
     )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-        ReproducibilityConfig as ReproducibilityConfig,
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).ActSaveLoggerConfig
+        if name == "BaseLoggerConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).BaseLoggerConfig
         if name == "BestCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MetricConfig":
             return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-        if name == "OnExceptionCheckpointCallbackConfig":
+        if name == "NormLoggingCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).OnExceptionCheckpointCallbackConfig
-        if name == "OptimizationConfig":
+            ).NormLoggingCallbackConfig
+        if name == "OnExceptionCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).OptimizationConfig
+            ).OnExceptionCheckpointCallbackConfig
         if name == "RLPSanityChecksCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).RLPSanityChecksCallbackConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -38,20 +39,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
     )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-        ReproducibilityConfig as ReproducibilityConfig,
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).ActSaveLoggerConfig
+        if name == "BaseLoggerConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).BaseLoggerConfig
         if name == "BestCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MetricConfig":
             return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-        if name == "OnExceptionCheckpointCallbackConfig":
+        if name == "NormLoggingCallbackConfig":
            return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).OnExceptionCheckpointCallbackConfig
-        if name == "OptimizationConfig":
+            ).NormLoggingCallbackConfig
+        if name == "OnExceptionCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).OptimizationConfig
+            ).OnExceptionCheckpointCallbackConfig
         if name == "RLPSanityChecksCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).RLPSanityChecksCallbackConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -8,8 +8,6 @@ from typing import Any, Generic, cast
 import nshconfig as C
 import torch
 from lightning.pytorch import LightningDataModule
-from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from typing_extensions import Never, TypeVar, deprecated, override
 
 from ..model.mixins.callback import CallbackRegistrarModuleMixin
@@ -40,11 +40,13 @@ from ..callbacks import (
     CallbackConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
+    NormLoggingCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
 )
 from ..callbacks.base import CallbackConfigBase
 from ..callbacks.debug_flag import DebugFlagCallbackConfig
 from ..callbacks.log_epoch import LogEpochCallbackConfig
+from ..callbacks.lr_monitor import LearningRateMonitorConfig
 from ..callbacks.rlp_sanity_checks import RLPSanityChecksCallbackConfig
 from ..callbacks.shared_parameters import SharedParametersCallbackConfig
 from ..loggers import (
@@ -53,6 +55,7 @@ from ..loggers import (
     TensorboardLoggerConfig,
     WandbLoggerConfig,
 )
+from ..loggers._base import BaseLoggerConfig
 from ..loggers.actsave import ActSaveLoggerConfig
 from ..metrics._config import MetricConfig
 from ..profiler import ProfilerConfig
@@ -61,103 +64,6 @@ from ..util._environment_info import EnvironmentConfig
 log = logging.getLogger(__name__)
 
 
-class LoggingConfig(CallbackConfigBase):
-    enabled: bool = True
-    """Enable experiment tracking."""
-
-    loggers: Sequence[LoggerConfig] = [
-        WandbLoggerConfig(),
-        CSVLoggerConfig(),
-        TensorboardLoggerConfig(),
-    ]
-    """Loggers to use for experiment tracking."""
-
-    log_lr: bool | Literal["step", "epoch"] = True
-    """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
-    log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
-    """If enabled, will log the fractional epoch number to the logger."""
-
-    actsave_logger: ActSaveLoggerConfig | None = None
-    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
-
-    @property
-    def wandb(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, WandbLoggerConfig)
-            ),
-            None,
-        )
-
-    @property
-    def csv(self):
-        return next(
-            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
-            None,
-        )
-
-    @property
-    def tensorboard(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, TensorboardLoggerConfig)
-            ),
-            None,
-        )
-
-    def create_loggers(self, trainer_config: TrainerConfig):
-        """
-        Constructs and returns a list of loggers based on the provided root configuration.
-
-        Args:
-            trainer_config (TrainerConfig): The root configuration object.
-
-        Returns:
-            list[Logger]: A list of constructed loggers.
-        """
-        if not self.enabled:
-            return
-
-        for logger_config in sorted(
-            self.loggers,
-            key=lambda x: x.priority,
-            reverse=True,
-        ):
-            if not logger_config.enabled:
-                continue
-            if (logger := logger_config.create_logger(trainer_config)) is None:
-                continue
-            yield logger
-
-        # If the actsave_metrics is enabled, add the ActSave logger
-        if self.actsave_logger:
-            yield self.actsave_logger.create_logger(trainer_config)
-
-    @override
-    def create_callbacks(self, trainer_config):
-        if self.log_lr:
-            from lightning.pytorch.callbacks import LearningRateMonitor
-
-            logging_interval: str | None = None
-            if isinstance(self.log_lr, str):
-                logging_interval = self.log_lr
-
-            yield LearningRateMonitor(logging_interval=logging_interval)
-
-        if self.log_epoch:
-            yield from self.log_epoch.create_callbacks(trainer_config)
-
-        for logger in self.loggers:
-            if not logger or not isinstance(logger, CallbackConfigBase):
-                continue
-
-            yield from logger.create_callbacks(trainer_config)
-
-
 class GradientClippingConfig(C.Config):
     enabled: bool = True
     """Enable gradient clipping."""
167
73
  """Norm type to use for gradient clipping."""
168
74
 
169
75
 
170
- class OptimizationConfig(CallbackConfigBase):
171
- log_grad_norm: bool | str | float = False
172
- """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
173
- log_grad_norm_per_param: bool | str | float = False
174
- """If enabled, will log the gradient norm for each model parameter to the logger."""
175
-
176
- log_param_norm: bool | str | float = False
177
- """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
178
- log_param_norm_per_param: bool | str | float = False
179
- """If enabled, will log the parameter norm for each model parameter to the logger."""
180
-
181
- gradient_clipping: GradientClippingConfig | None = None
182
- """Gradient clipping configuration, or None to disable gradient clipping."""
183
-
184
- @override
185
- def create_callbacks(self, trainer_config):
186
- from ..callbacks.norm_logging import NormLoggingCallbackConfig
187
-
188
- yield from NormLoggingCallbackConfig(
189
- log_grad_norm=self.log_grad_norm,
190
- log_grad_norm_per_param=self.log_grad_norm_per_param,
191
- log_param_norm=self.log_param_norm,
192
- log_param_norm_per_param=self.log_param_norm_per_param,
193
- ).create_callbacks(trainer_config)
194
-
195
-
196
76
  TPlugin = TypeVar(
197
77
  "TPlugin",
198
78
  Precision,
@@ -252,15 +132,6 @@ StrategyLiteral: TypeAlias = Literal[
 ]
 
 
-class ReproducibilityConfig(C.Config):
-    deterministic: bool | Literal["warn"] | None = None
-    """
-    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
-    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
-    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
-    """
-
-
 CheckpointCallbackConfig: TypeAlias = Annotated[
     BestCheckpointCallbackConfig
     | LastCheckpointCallbackConfig
@@ -634,14 +505,34 @@ class TrainerConfig(C.Config):
     hf_hub: HuggingFaceHubConfig = HuggingFaceHubConfig()
     """Hugging Face Hub configuration options."""
 
-    logging: LoggingConfig = LoggingConfig()
-    """Logging/experiment tracking (e.g., WandB) configuration options."""
+    loggers: Sequence[LoggerConfig] = [
+        WandbLoggerConfig(),
+        CSVLoggerConfig(),
+        TensorboardLoggerConfig(),
+    ]
+    """Loggers to use for experiment tracking."""
 
-    optimizer: OptimizationConfig = OptimizationConfig()
-    """Optimization configuration options."""
+    actsave_logger: ActSaveLoggerConfig | None = None
+    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
 
-    reproducibility: ReproducibilityConfig = ReproducibilityConfig()
-    """Reproducibility configuration options."""
+    lr_monitor: LearningRateMonitorConfig | None = LearningRateMonitorConfig()
+    """Learning rate monitoring configuration options."""
+
+    log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
+    """If enabled, will log the fractional epoch number to the logger."""
+
+    gradient_clipping: GradientClippingConfig | None = None
+    """Gradient clipping configuration, or None to disable gradient clipping."""
+
+    log_norms: NormLoggingCallbackConfig | None = None
+    """Norm logging configuration options."""
+
+    deterministic: bool | Literal["warn"] | None = None
+    """
+    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
+    """
 
     reduce_lr_on_plateau_sanity_checking: RLPSanityChecksCallbackConfig | None = (
         RLPSanityChecksCallbackConfig()
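
Taken together with the class removals above, the nested LoggingConfig, OptimizationConfig, and ReproducibilityConfig wrappers are replaced by top-level TrainerConfig fields. A hedged before/after sketch of how a 1.0.0b12-style configuration might map onto 1.0.0b13; the field values are illustrative only and assume an existing TrainerConfig instance named config:

    # 1.0.0b12 (nested wrapper configs, removed in this release):
    #   config.logging.log_lr = "step"
    #   config.optimizer.log_grad_norm = True
    #   config.reproducibility.deterministic = "warn"

    # 1.0.0b13 (flattened fields on TrainerConfig):
    from nshtrainer.callbacks import NormLoggingCallbackConfig
    from nshtrainer.callbacks.lr_monitor import LearningRateMonitorConfig

    config.lr_monitor = LearningRateMonitorConfig(logging_interval="step")
    config.log_norms = NormLoggingCallbackConfig(log_grad_norm=True)
    config.deterministic = "warn"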
@@ -856,27 +747,87 @@ class TrainerConfig(C.Config):
     set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None
     """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
 
+    @property
+    def wandb_logger(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, WandbLoggerConfig)
+            ),
+            None,
+        )
+
+    @property
+    def csv_logger(self):
+        return next(
+            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
+            None,
+        )
+
+    @property
+    def tensorboard_logger(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, TensorboardLoggerConfig)
+            ),
+            None,
+        )
+
     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
         yield self.early_stopping
         yield self.checkpoint_saving
-        yield self.logging
-        yield self.optimizer
+        yield self.lr_monitor
+        yield from (
+            logger_config
+            for logger_config in self.loggers
+            if logger_config is not None
+            and isinstance(logger_config, CallbackConfigBase)
+        )
+        yield self.log_epoch
+        yield self.log_norms
         yield self.hf_hub
         yield self.shared_parameters
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield from self.callbacks
 
+    def _nshtrainer_all_logger_configs(self) -> Iterable[BaseLoggerConfig | None]:
+        yield from self.loggers
+        yield self.actsave_logger
+
     # region Helper Methods
+    def fast_dev_run_(self, value: int | bool = True, /):
+        """
+        Enables fast_dev_run mode for the trainer.
+        This will run the training loop for a specified number of batches,
+        if an integer is provided, or for a single batch if True is provided.
+        """
+        self.fast_dev_run = value
+        return self
+
     def with_fast_dev_run(self, value: int | bool = True, /):
         """
         Enables fast_dev_run mode for the trainer.
         This will run the training loop for a specified number of batches,
         if an integer is provided, or for a single batch if True is provided.
         """
-        config = copy.deepcopy(self)
-        config.fast_dev_run = value
-        return config
+        return copy.deepcopy(self).fast_dev_run_(value)
+
+    def project_root_(self, project_root: str | Path | os.PathLike):
+        """
+        Set the project root directory for the trainer.
+
+        Args:
+            project_root (Path): The base directory to use.
+
+        Returns:
+            self: The current instance of the class.
+        """
+        self.directory.project_root = Path(project_root)
+        return self
 
     def with_project_root(self, project_root: str | Path | os.PathLike):
         """
@@ -888,9 +839,7 @@ class TrainerConfig(C.Config):
         Returns:
             self: The current instance of the class.
         """
-        config = copy.deepcopy(self)
-        config.directory.project_root = Path(project_root)
-        return config
+        return copy.deepcopy(self).project_root_(project_root)
 
     def reset_run(
         self,
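
The trailing-underscore helpers added here (fast_dev_run_, project_root_) mutate the config in place and return self, while the existing with_* methods keep their copy-then-modify behavior by delegating to the new helpers on a deep copy. A short usage sketch, assuming an existing TrainerConfig instance named config; the path is a placeholder:

    # In-place: mutates `config` and returns it, so calls can be chained.
    config.fast_dev_run_(8).project_root_("/tmp/nshtrainer-example")

    # Copying: leaves `config` untouched and returns a modified deep copy.
    debug_config = config.with_fast_dev_run(True)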