nshtrainer 1.0.0b12__tar.gz → 1.0.0b14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/PKG-INFO +1 -1
  2. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/pyproject.toml +1 -1
  3. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +3 -3
  4. nshtrainer-1.0.0b14/src/nshtrainer/callbacks/lr_monitor.py +31 -0
  5. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/__init__.py +7 -50
  6. nshtrainer-1.0.0b14/src/nshtrainer/configs/_checkpoint/__init__.py +31 -0
  7. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/__init__.py +8 -0
  8. nshtrainer-1.0.0b14/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +31 -0
  9. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/trainer/__init__.py +19 -23
  10. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/trainer/_config/__init__.py +19 -22
  11. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/data/datamodule.py +0 -2
  12. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/_config.py +95 -153
  13. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/trainer.py +11 -24
  14. nshtrainer-1.0.0b12/src/nshtrainer/_checkpoint/loader.py +0 -387
  15. nshtrainer-1.0.0b12/src/nshtrainer/configs/_checkpoint/__init__.py +0 -70
  16. nshtrainer-1.0.0b12/src/nshtrainer/configs/_checkpoint/loader/__init__.py +0 -62
  17. nshtrainer-1.0.0b12/src/nshtrainer/configs/trainer/checkpoint_connector/__init__.py +0 -26
  18. nshtrainer-1.0.0b12/src/nshtrainer/trainer/checkpoint_connector.py +0 -86
  19. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/README.md +0 -0
  20. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/__init__.py +0 -0
  21. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/_callback.py +0 -0
  22. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  23. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/_checkpoint/saver.py +0 -0
  24. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/_directory.py +0 -0
  25. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/_experimental/__init__.py +0 -0
  26. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/_hf_hub.py +0 -0
  27. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/__init__.py +0 -0
  28. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/actsave.py +0 -0
  29. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/base.py +0 -0
  30. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  31. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  32. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  33. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  34. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  35. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  36. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  37. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/ema.py +0 -0
  38. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  39. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  40. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/interval.py +0 -0
  41. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  42. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  43. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/print_table.py +0 -0
  44. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  45. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  46. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/timer.py +0 -0
  47. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  48. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  49. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  50. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  51. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  52. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  53. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  54. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  55. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  56. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  57. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  58. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  59. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  60. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  61. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  62. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  63. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  64. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  65. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  66. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  67. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  68. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  69. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  70. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  71. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  72. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  73. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  74. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/_base/__init__.py +0 -0
  75. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  76. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  77. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  78. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  79. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  80. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -0
  81. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  82. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  83. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  84. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  85. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/nn/__init__.py +0 -0
  86. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  87. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  88. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  89. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  90. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  91. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  92. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  93. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  94. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  95. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/__init__.py +0 -0
  96. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  97. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  98. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  99. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  100. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/data/__init__.py +0 -0
  101. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  102. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/data/transform.py +0 -0
  103. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/__init__.py +0 -0
  104. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/_base.py +0 -0
  105. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/actsave.py +0 -0
  106. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/csv.py +0 -0
  107. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/tensorboard.py +0 -0
  108. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/wandb.py +0 -0
  109. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  110. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  111. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  112. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  113. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/metrics/__init__.py +0 -0
  114. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/metrics/_config.py +0 -0
  115. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/model/__init__.py +0 -0
  116. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/model/base.py +0 -0
  117. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/model/mixins/callback.py +0 -0
  118. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/model/mixins/debug.py +0 -0
  119. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/model/mixins/logger.py +0 -0
  120. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/__init__.py +0 -0
  121. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/mlp.py +0 -0
  122. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/module_dict.py +0 -0
  123. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/module_list.py +0 -0
  124. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/nonlinearity.py +0 -0
  125. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/optimizer.py +0 -0
  126. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/__init__.py +0 -0
  127. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/_base.py +0 -0
  128. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/advanced.py +0 -0
  129. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/pytorch.py +0 -0
  130. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/simple.py +0 -0
  131. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/__init__.py +0 -0
  132. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  133. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/signal_connector.py +0 -0
  134. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/_environment_info.py +0 -0
  135. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/_useful_types.py +0 -0
  136. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/bf16.py +0 -0
  137. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/config/__init__.py +0 -0
  138. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/config/dtype.py +0 -0
  139. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/config/duration.py +0 -0
  140. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/environment.py +0 -0
  141. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/path.py +0 -0
  142. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/seed.py +0 -0
  143. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/slurm.py +0 -0
  144. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/typed.py +0 -0
  145. {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b14}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 1.0.0b12
3
+ Version: 1.0.0b14
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "1.0.0-beta12"
3
+ version = "1.0.0-beta14"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -96,8 +96,8 @@ class OnExceptionCheckpointCallback(_OnExceptionCheckpoint):
96
96
  def on_exception(self, trainer: LightningTrainer, *args: Any, **kwargs: Any):
97
97
  # Monkey-patch the strategy instance to make the barrier operation a no-op.
98
98
  # We do this because `save_checkpoint` calls `barrier`. This is okay in most
99
- # cases, but when we want to save a checkpoint in the case of an exception,
100
- # `barrier` causes a deadlock. So we monkey-patch the strategy instance to
101
- # make the barrier operation a no-op.
99
+ # cases, but when we want to save a checkpoint in the case of an exception,
100
+ # `barrier` causes a deadlock. So we monkey-patch the strategy instance to
101
+ # make the barrier operation a no-op.
102
102
  with _monkey_patch_disable_barrier(trainer):
103
103
  return super().on_exception(trainer, *args, **kwargs)
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Literal
4
+
5
+ from lightning.pytorch.callbacks import LearningRateMonitor
6
+
7
+ from .base import CallbackConfigBase
8
+
9
+
10
+ class LearningRateMonitorConfig(CallbackConfigBase):
11
+ logging_interval: Literal["step", "epoch"] | None = None
12
+ """
13
+ Set to 'epoch' or 'step' to log 'lr' of all optimizers at the same interval, set to None to log at individual interval according to the 'interval' key of each scheduler. Defaults to None.
14
+ """
15
+
16
+ log_momentum: bool = False
17
+ """
18
+ Option to also log the momentum values of the optimizer, if the optimizer has the 'momentum' or 'betas' attribute. Defaults to False.
19
+ """
20
+
21
+ log_weight_decay: bool = False
22
+ """
23
+ Option to also log the weight decay values of the optimizer. Defaults to False.
24
+ """
25
+
26
+ def create_callbacks(self, trainer_config):
27
+ yield LearningRateMonitor(
28
+ logging_interval=self.logging_interval,
29
+ log_momentum=self.log_momentum,
30
+ log_weight_decay=self.log_weight_decay,
31
+ )
@@ -9,19 +9,7 @@ from typing import TYPE_CHECKING
9
9
  if TYPE_CHECKING:
10
10
  from nshtrainer import MetricConfig as MetricConfig
11
11
  from nshtrainer import TrainerConfig as TrainerConfig
12
- from nshtrainer._checkpoint.loader import (
13
- BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
14
- )
15
- from nshtrainer._checkpoint.loader import (
16
- CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
17
- )
18
- from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
19
- from nshtrainer._checkpoint.loader import (
20
- LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
21
- )
22
- from nshtrainer._checkpoint.loader import (
23
- UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
24
- )
12
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
25
13
  from nshtrainer._directory import DirectoryConfig as DirectoryConfig
26
14
  from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
27
15
  from nshtrainer._hf_hub import (
@@ -122,9 +110,6 @@ if TYPE_CHECKING:
122
110
  from nshtrainer.trainer._config import (
123
111
  CheckpointCallbackConfig as CheckpointCallbackConfig,
124
112
  )
125
- from nshtrainer.trainer._config import (
126
- CheckpointLoadingConfig as CheckpointLoadingConfig,
127
- )
128
113
  from nshtrainer.trainer._config import (
129
114
  CheckpointSavingConfig as CheckpointSavingConfig,
130
115
  )
@@ -132,10 +117,8 @@ if TYPE_CHECKING:
132
117
  from nshtrainer.trainer._config import (
133
118
  GradientClippingConfig as GradientClippingConfig,
134
119
  )
135
- from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
136
- from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
137
120
  from nshtrainer.trainer._config import (
138
- ReproducibilityConfig as ReproducibilityConfig,
121
+ LearningRateMonitorConfig as LearningRateMonitorConfig,
139
122
  )
140
123
  from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
141
124
  from nshtrainer.util._environment_info import (
@@ -201,21 +184,13 @@ else:
201
184
  return importlib.import_module(
202
185
  "nshtrainer.callbacks"
203
186
  ).BestCheckpointCallbackConfig
204
- if name == "BestCheckpointStrategyConfig":
205
- return importlib.import_module(
206
- "nshtrainer._checkpoint.loader"
207
- ).BestCheckpointStrategyConfig
208
187
  if name == "CSVLoggerConfig":
209
188
  return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
210
189
  if name == "CallbackConfigBase":
211
190
  return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
212
- if name == "CheckpointLoadingConfig":
213
- return importlib.import_module(
214
- "nshtrainer.trainer._config"
215
- ).CheckpointLoadingConfig
216
191
  if name == "CheckpointMetadata":
217
192
  return importlib.import_module(
218
- "nshtrainer._checkpoint.loader"
193
+ "nshtrainer._checkpoint.metadata"
219
194
  ).CheckpointMetadata
220
195
  if name == "CheckpointSavingConfig":
221
196
  return importlib.import_module(
@@ -319,12 +294,12 @@ else:
319
294
  return importlib.import_module(
320
295
  "nshtrainer.callbacks"
321
296
  ).LastCheckpointCallbackConfig
322
- if name == "LastCheckpointStrategyConfig":
323
- return importlib.import_module(
324
- "nshtrainer._checkpoint.loader"
325
- ).LastCheckpointStrategyConfig
326
297
  if name == "LeakyReLUNonlinearityConfig":
327
298
  return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
299
+ if name == "LearningRateMonitorConfig":
300
+ return importlib.import_module(
301
+ "nshtrainer.trainer._config"
302
+ ).LearningRateMonitorConfig
328
303
  if name == "LinearWarmupCosineDecayLRSchedulerConfig":
329
304
  return importlib.import_module(
330
305
  "nshtrainer.lr_scheduler"
@@ -333,8 +308,6 @@ else:
333
308
  return importlib.import_module(
334
309
  "nshtrainer.callbacks"
335
310
  ).LogEpochCallbackConfig
336
- if name == "LoggingConfig":
337
- return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
338
311
  if name == "MLPConfig":
339
312
  return importlib.import_module("nshtrainer.nn").MLPConfig
340
313
  if name == "MetricConfig":
@@ -349,10 +322,6 @@ else:
349
322
  return importlib.import_module(
350
323
  "nshtrainer.callbacks"
351
324
  ).OnExceptionCheckpointCallbackConfig
352
- if name == "OptimizationConfig":
353
- return importlib.import_module(
354
- "nshtrainer.trainer._config"
355
- ).OptimizationConfig
356
325
  if name == "OptimizerConfigBase":
357
326
  return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
358
327
  if name == "PReLUConfig":
@@ -373,10 +342,6 @@ else:
373
342
  return importlib.import_module(
374
343
  "nshtrainer.lr_scheduler"
375
344
  ).ReduceLROnPlateauConfig
376
- if name == "ReproducibilityConfig":
377
- return importlib.import_module(
378
- "nshtrainer.trainer._config"
379
- ).ReproducibilityConfig
380
345
  if name == "SanityCheckingConfig":
381
346
  return importlib.import_module(
382
347
  "nshtrainer.trainer._config"
@@ -411,10 +376,6 @@ else:
411
376
  return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
412
377
  if name == "TrainerConfig":
413
378
  return importlib.import_module("nshtrainer").TrainerConfig
414
- if name == "UserProvidedPathCheckpointStrategyConfig":
415
- return importlib.import_module(
416
- "nshtrainer._checkpoint.loader"
417
- ).UserProvidedPathCheckpointStrategyConfig
418
379
  if name == "WandbLoggerConfig":
419
380
  return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
420
381
  if name == "WandbUploadCodeCallbackConfig":
@@ -431,10 +392,6 @@ else:
431
392
  return importlib.import_module(
432
393
  "nshtrainer.trainer._config"
433
394
  ).CheckpointCallbackConfig
434
- if name == "CheckpointLoadingStrategyConfig":
435
- return importlib.import_module(
436
- "nshtrainer._checkpoint.loader"
437
- ).CheckpointLoadingStrategyConfig
438
395
  if name == "DurationConfig":
439
396
  return importlib.import_module("nshtrainer.util.config").DurationConfig
440
397
  if name == "LRSchedulerConfig":
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
11
+ from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
12
+ else:
13
+
14
+ def __getattr__(name):
15
+ import importlib
16
+
17
+ if name in globals():
18
+ return globals()[name]
19
+ if name == "CheckpointMetadata":
20
+ return importlib.import_module(
21
+ "nshtrainer._checkpoint.metadata"
22
+ ).CheckpointMetadata
23
+ if name == "EnvironmentConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer._checkpoint.metadata"
26
+ ).EnvironmentConfig
27
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
28
+
29
+
30
+ # Submodule exports
31
+ from . import metadata as metadata
@@ -62,6 +62,9 @@ if TYPE_CHECKING:
62
62
  CheckpointMetadata as CheckpointMetadata,
63
63
  )
64
64
  from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
65
+ from nshtrainer.callbacks.lr_monitor import (
66
+ LearningRateMonitorConfig as LearningRateMonitorConfig,
67
+ )
65
68
  else:
66
69
 
67
70
  def __getattr__(name):
@@ -115,6 +118,10 @@ else:
115
118
  return importlib.import_module(
116
119
  "nshtrainer.callbacks"
117
120
  ).LastCheckpointCallbackConfig
121
+ if name == "LearningRateMonitorConfig":
122
+ return importlib.import_module(
123
+ "nshtrainer.callbacks.lr_monitor"
124
+ ).LearningRateMonitorConfig
118
125
  if name == "LogEpochCallbackConfig":
119
126
  return importlib.import_module(
120
127
  "nshtrainer.callbacks"
@@ -167,6 +174,7 @@ from . import ema as ema
167
174
  from . import finite_checks as finite_checks
168
175
  from . import gradient_skipping as gradient_skipping
169
176
  from . import log_epoch as log_epoch
177
+ from . import lr_monitor as lr_monitor
170
178
  from . import norm_logging as norm_logging
171
179
  from . import print_table as print_table
172
180
  from . import rlp_sanity_checks as rlp_sanity_checks
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfigBase
11
+ from nshtrainer.callbacks.lr_monitor import (
12
+ LearningRateMonitorConfig as LearningRateMonitorConfig,
13
+ )
14
+ else:
15
+
16
+ def __getattr__(name):
17
+ import importlib
18
+
19
+ if name in globals():
20
+ return globals()[name]
21
+ if name == "CallbackConfigBase":
22
+ return importlib.import_module(
23
+ "nshtrainer.callbacks.lr_monitor"
24
+ ).CallbackConfigBase
25
+ if name == "LearningRateMonitorConfig":
26
+ return importlib.import_module(
27
+ "nshtrainer.callbacks.lr_monitor"
28
+ ).LearningRateMonitorConfig
29
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
+
31
+ # Submodule exports
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING
9
9
  if TYPE_CHECKING:
10
10
  from nshtrainer.trainer import TrainerConfig as TrainerConfig
11
11
  from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
12
+ from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
12
13
  from nshtrainer.trainer._config import (
13
14
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
14
15
  )
@@ -17,9 +18,6 @@ if TYPE_CHECKING:
17
18
  from nshtrainer.trainer._config import (
18
19
  CheckpointCallbackConfig as CheckpointCallbackConfig,
19
20
  )
20
- from nshtrainer.trainer._config import (
21
- CheckpointLoadingConfig as CheckpointLoadingConfig,
22
- )
23
21
  from nshtrainer.trainer._config import (
24
22
  CheckpointSavingConfig as CheckpointSavingConfig,
25
23
  )
@@ -39,20 +37,21 @@ if TYPE_CHECKING:
39
37
  from nshtrainer.trainer._config import (
40
38
  LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
41
39
  )
40
+ from nshtrainer.trainer._config import (
41
+ LearningRateMonitorConfig as LearningRateMonitorConfig,
42
+ )
42
43
  from nshtrainer.trainer._config import (
43
44
  LogEpochCallbackConfig as LogEpochCallbackConfig,
44
45
  )
45
46
  from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
46
- from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
47
47
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
48
48
  from nshtrainer.trainer._config import (
49
- OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
49
+ NormLoggingCallbackConfig as NormLoggingCallbackConfig,
50
50
  )
51
- from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
52
- from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
53
51
  from nshtrainer.trainer._config import (
54
- ReproducibilityConfig as ReproducibilityConfig,
52
+ OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
55
53
  )
54
+ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
56
55
  from nshtrainer.trainer._config import (
57
56
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
58
57
  )
@@ -75,6 +74,10 @@ else:
75
74
  return importlib.import_module(
76
75
  "nshtrainer.trainer._config"
77
76
  ).ActSaveLoggerConfig
77
+ if name == "BaseLoggerConfig":
78
+ return importlib.import_module(
79
+ "nshtrainer.trainer._config"
80
+ ).BaseLoggerConfig
78
81
  if name == "BestCheckpointCallbackConfig":
79
82
  return importlib.import_module(
80
83
  "nshtrainer.trainer._config"
@@ -85,10 +88,6 @@ else:
85
88
  return importlib.import_module(
86
89
  "nshtrainer.trainer._config"
87
90
  ).CallbackConfigBase
88
- if name == "CheckpointLoadingConfig":
89
- return importlib.import_module(
90
- "nshtrainer.trainer._config"
91
- ).CheckpointLoadingConfig
92
91
  if name == "CheckpointSavingConfig":
93
92
  return importlib.import_module(
94
93
  "nshtrainer.trainer._config"
@@ -119,30 +118,28 @@ else:
119
118
  return importlib.import_module(
120
119
  "nshtrainer.trainer._config"
121
120
  ).LastCheckpointCallbackConfig
121
+ if name == "LearningRateMonitorConfig":
122
+ return importlib.import_module(
123
+ "nshtrainer.trainer._config"
124
+ ).LearningRateMonitorConfig
122
125
  if name == "LogEpochCallbackConfig":
123
126
  return importlib.import_module(
124
127
  "nshtrainer.trainer._config"
125
128
  ).LogEpochCallbackConfig
126
- if name == "LoggingConfig":
127
- return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
128
129
  if name == "MetricConfig":
129
130
  return importlib.import_module("nshtrainer.trainer._config").MetricConfig
130
- if name == "OnExceptionCheckpointCallbackConfig":
131
+ if name == "NormLoggingCallbackConfig":
131
132
  return importlib.import_module(
132
133
  "nshtrainer.trainer._config"
133
- ).OnExceptionCheckpointCallbackConfig
134
- if name == "OptimizationConfig":
134
+ ).NormLoggingCallbackConfig
135
+ if name == "OnExceptionCheckpointCallbackConfig":
135
136
  return importlib.import_module(
136
137
  "nshtrainer.trainer._config"
137
- ).OptimizationConfig
138
+ ).OnExceptionCheckpointCallbackConfig
138
139
  if name == "RLPSanityChecksCallbackConfig":
139
140
  return importlib.import_module(
140
141
  "nshtrainer.trainer._config"
141
142
  ).RLPSanityChecksCallbackConfig
142
- if name == "ReproducibilityConfig":
143
- return importlib.import_module(
144
- "nshtrainer.trainer._config"
145
- ).ReproducibilityConfig
146
143
  if name == "SanityCheckingConfig":
147
144
  return importlib.import_module(
148
145
  "nshtrainer.trainer._config"
@@ -176,5 +173,4 @@ else:
176
173
 
177
174
  # Submodule exports
178
175
  from . import _config as _config
179
- from . import checkpoint_connector as checkpoint_connector
180
176
  from . import trainer as trainer
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
8
8
 
9
9
  if TYPE_CHECKING:
10
10
  from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
11
+ from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
11
12
  from nshtrainer.trainer._config import (
12
13
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
13
14
  )
@@ -16,9 +17,6 @@ if TYPE_CHECKING:
16
17
  from nshtrainer.trainer._config import (
17
18
  CheckpointCallbackConfig as CheckpointCallbackConfig,
18
19
  )
19
- from nshtrainer.trainer._config import (
20
- CheckpointLoadingConfig as CheckpointLoadingConfig,
21
- )
22
20
  from nshtrainer.trainer._config import (
23
21
  CheckpointSavingConfig as CheckpointSavingConfig,
24
22
  )
@@ -38,20 +36,21 @@ if TYPE_CHECKING:
38
36
  from nshtrainer.trainer._config import (
39
37
  LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
40
38
  )
39
+ from nshtrainer.trainer._config import (
40
+ LearningRateMonitorConfig as LearningRateMonitorConfig,
41
+ )
41
42
  from nshtrainer.trainer._config import (
42
43
  LogEpochCallbackConfig as LogEpochCallbackConfig,
43
44
  )
44
45
  from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
45
- from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
46
46
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
47
47
  from nshtrainer.trainer._config import (
48
- OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
48
+ NormLoggingCallbackConfig as NormLoggingCallbackConfig,
49
49
  )
50
- from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
51
- from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
52
50
  from nshtrainer.trainer._config import (
53
- ReproducibilityConfig as ReproducibilityConfig,
51
+ OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
54
52
  )
53
+ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
55
54
  from nshtrainer.trainer._config import (
56
55
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
57
56
  )
@@ -75,6 +74,10 @@ else:
75
74
  return importlib.import_module(
76
75
  "nshtrainer.trainer._config"
77
76
  ).ActSaveLoggerConfig
77
+ if name == "BaseLoggerConfig":
78
+ return importlib.import_module(
79
+ "nshtrainer.trainer._config"
80
+ ).BaseLoggerConfig
78
81
  if name == "BestCheckpointCallbackConfig":
79
82
  return importlib.import_module(
80
83
  "nshtrainer.trainer._config"
@@ -85,10 +88,6 @@ else:
85
88
  return importlib.import_module(
86
89
  "nshtrainer.trainer._config"
87
90
  ).CallbackConfigBase
88
- if name == "CheckpointLoadingConfig":
89
- return importlib.import_module(
90
- "nshtrainer.trainer._config"
91
- ).CheckpointLoadingConfig
92
91
  if name == "CheckpointSavingConfig":
93
92
  return importlib.import_module(
94
93
  "nshtrainer.trainer._config"
@@ -119,30 +118,28 @@ else:
119
118
  return importlib.import_module(
120
119
  "nshtrainer.trainer._config"
121
120
  ).LastCheckpointCallbackConfig
121
+ if name == "LearningRateMonitorConfig":
122
+ return importlib.import_module(
123
+ "nshtrainer.trainer._config"
124
+ ).LearningRateMonitorConfig
122
125
  if name == "LogEpochCallbackConfig":
123
126
  return importlib.import_module(
124
127
  "nshtrainer.trainer._config"
125
128
  ).LogEpochCallbackConfig
126
- if name == "LoggingConfig":
127
- return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
128
129
  if name == "MetricConfig":
129
130
  return importlib.import_module("nshtrainer.trainer._config").MetricConfig
130
- if name == "OnExceptionCheckpointCallbackConfig":
131
+ if name == "NormLoggingCallbackConfig":
131
132
  return importlib.import_module(
132
133
  "nshtrainer.trainer._config"
133
- ).OnExceptionCheckpointCallbackConfig
134
- if name == "OptimizationConfig":
134
+ ).NormLoggingCallbackConfig
135
+ if name == "OnExceptionCheckpointCallbackConfig":
135
136
  return importlib.import_module(
136
137
  "nshtrainer.trainer._config"
137
- ).OptimizationConfig
138
+ ).OnExceptionCheckpointCallbackConfig
138
139
  if name == "RLPSanityChecksCallbackConfig":
139
140
  return importlib.import_module(
140
141
  "nshtrainer.trainer._config"
141
142
  ).RLPSanityChecksCallbackConfig
142
- if name == "ReproducibilityConfig":
143
- return importlib.import_module(
144
- "nshtrainer.trainer._config"
145
- ).ReproducibilityConfig
146
143
  if name == "SanityCheckingConfig":
147
144
  return importlib.import_module(
148
145
  "nshtrainer.trainer._config"
@@ -8,8 +8,6 @@ from typing import Any, Generic, cast
8
8
  import nshconfig as C
9
9
  import torch
10
10
  from lightning.pytorch import LightningDataModule
11
- from lightning.pytorch.utilities.model_helpers import is_overridden
12
- from lightning.pytorch.utilities.rank_zero import rank_zero_warn
13
11
  from typing_extensions import Never, TypeVar, deprecated, override
14
12
 
15
13
  from ..model.mixins.callback import CallbackRegistrarModuleMixin