nshtrainer 1.0.0b11__tar.gz → 1.0.0b13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (146)
  1. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/PKG-INFO +1 -1
  2. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/pyproject.toml +1 -1
  3. nshtrainer-1.0.0b13/src/nshtrainer/callbacks/lr_monitor.py +31 -0
  4. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/__init__.py +5 -13
  5. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/__init__.py +8 -0
  6. nshtrainer-1.0.0b13/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +31 -0
  7. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/__init__.py +19 -15
  8. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/_config/__init__.py +19 -15
  9. nshtrainer-1.0.0b13/src/nshtrainer/data/datamodule.py +124 -0
  10. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/base.py +100 -2
  11. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/_config.py +95 -147
  12. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/trainer.py +48 -76
  13. nshtrainer-1.0.0b11/src/nshtrainer/data/datamodule.py +0 -57
  14. nshtrainer-1.0.0b11/src/nshtrainer/scripts/find_packages.py +0 -52
  15. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/README.md +0 -0
  16. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/__init__.py +0 -0
  17. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_callback.py +0 -0
  18. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/loader.py +0 -0
  19. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  20. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/saver.py +0 -0
  21. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_directory.py +0 -0
  22. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_experimental/__init__.py +0 -0
  23. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_hf_hub.py +0 -0
  24. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/__init__.py +0 -0
  25. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/actsave.py +0 -0
  26. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/base.py +0 -0
  27. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  28. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  29. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  30. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  31. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  32. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  33. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  34. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  35. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/ema.py +0 -0
  36. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  37. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  38. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/interval.py +0 -0
  39. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  40. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  41. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/print_table.py +0 -0
  42. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  43. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  44. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/timer.py +0 -0
  45. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  46. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  47. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  48. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/loader/__init__.py +0 -0
  49. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  50. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  51. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  52. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  53. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  54. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  55. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  56. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  57. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  58. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  59. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  60. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  61. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  62. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  63. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  64. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  65. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  66. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  67. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  68. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  69. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  70. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  71. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  72. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  73. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  74. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/_base/__init__.py +0 -0
  75. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  76. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  77. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  78. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  79. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  80. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -0
  81. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  82. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  83. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  84. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  85. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/__init__.py +0 -0
  86. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  87. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  88. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  89. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  90. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  91. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  92. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  93. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  94. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/checkpoint_connector/__init__.py +0 -0
  95. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  96. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/__init__.py +0 -0
  97. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  98. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  99. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  100. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  101. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/data/__init__.py +0 -0
  102. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  103. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/data/transform.py +0 -0
  104. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/__init__.py +0 -0
  105. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/_base.py +0 -0
  106. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/actsave.py +0 -0
  107. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/csv.py +0 -0
  108. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/tensorboard.py +0 -0
  109. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/wandb.py +0 -0
  110. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  111. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  112. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  113. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  114. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/metrics/__init__.py +0 -0
  115. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/metrics/_config.py +0 -0
  116. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/__init__.py +0 -0
  117. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/callback.py +0 -0
  118. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/debug.py +0 -0
  119. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/logger.py +0 -0
  120. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/__init__.py +0 -0
  121. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/mlp.py +0 -0
  122. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/module_dict.py +0 -0
  123. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/module_list.py +0 -0
  124. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/nonlinearity.py +0 -0
  125. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/optimizer.py +0 -0
  126. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/__init__.py +0 -0
  127. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/_base.py +0 -0
  128. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/advanced.py +0 -0
  129. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/pytorch.py +0 -0
  130. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/simple.py +0 -0
  131. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/__init__.py +0 -0
  132. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  133. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  134. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/signal_connector.py +0 -0
  135. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/_environment_info.py +0 -0
  136. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/_useful_types.py +0 -0
  137. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/bf16.py +0 -0
  138. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/__init__.py +0 -0
  139. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/dtype.py +0 -0
  140. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/duration.py +0 -0
  141. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/environment.py +0 -0
  142. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/path.py +0 -0
  143. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/seed.py +0 -0
  144. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/slurm.py +0 -0
  145. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/typed.py +0 -0
  146. {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 1.0.0b11
+ Version: 1.0.0b13
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "nshtrainer"
- version = "1.0.0-beta11"
+ version = "1.0.0-beta13"
  description = ""
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
  readme = "README.md"
nshtrainer-1.0.0b13/src/nshtrainer/callbacks/lr_monitor.py
@@ -0,0 +1,31 @@
+ from __future__ import annotations
+
+ from typing import Literal
+
+ from lightning.pytorch.callbacks import LearningRateMonitor
+
+ from .base import CallbackConfigBase
+
+
+ class LearningRateMonitorConfig(CallbackConfigBase):
+     logging_interval: Literal["step", "epoch"] | None = None
+     """
+     Set to 'epoch' or 'step' to log 'lr' of all optimizers at the same interval, set to None to log at individual interval according to the 'interval' key of each scheduler. Defaults to None.
+     """
+
+     log_momentum: bool = False
+     """
+     Option to also log the momentum values of the optimizer, if the optimizer has the 'momentum' or 'betas' attribute. Defaults to False.
+     """
+
+     log_weight_decay: bool = False
+     """
+     Option to also log the weight decay values of the optimizer. Defaults to False.
+     """
+
+     def create_callbacks(self, trainer_config):
+         yield LearningRateMonitor(
+             logging_interval=self.logging_interval,
+             log_momentum=self.log_momentum,
+             log_weight_decay=self.log_weight_decay,
+         )
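
This config is a thin wrapper that exposes Lightning's LearningRateMonitor through nshtrainer's callback-config interface. A minimal usage sketch, not taken from the diff: it constructs the config directly and materializes the callback by hand; passing None as the trainer config is only for illustration, since create_callbacks above does not use it.

from nshtrainer.callbacks.lr_monitor import LearningRateMonitorConfig

# Hypothetical standalone usage; in a real run the trainer collects callbacks
# from its registered callback configs automatically.
config = LearningRateMonitorConfig(logging_interval="step", log_momentum=True)
callbacks = list(config.create_callbacks(None))  # [LearningRateMonitor(...)]
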
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/__init__.py
@@ -132,10 +132,8 @@ if TYPE_CHECKING:
      from nshtrainer.trainer._config import (
          GradientClippingConfig as GradientClippingConfig,
      )
-     from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
-     from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
      from nshtrainer.trainer._config import (
-         ReproducibilityConfig as ReproducibilityConfig,
+         LearningRateMonitorConfig as LearningRateMonitorConfig,
      )
      from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
      from nshtrainer.util._environment_info import (
@@ -325,6 +323,10 @@ else:
              ).LastCheckpointStrategyConfig
          if name == "LeakyReLUNonlinearityConfig":
              return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
+         if name == "LearningRateMonitorConfig":
+             return importlib.import_module(
+                 "nshtrainer.trainer._config"
+             ).LearningRateMonitorConfig
          if name == "LinearWarmupCosineDecayLRSchedulerConfig":
              return importlib.import_module(
                  "nshtrainer.lr_scheduler"
@@ -333,8 +335,6 @@ else:
              return importlib.import_module(
                  "nshtrainer.callbacks"
              ).LogEpochCallbackConfig
-         if name == "LoggingConfig":
-             return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
          if name == "MLPConfig":
              return importlib.import_module("nshtrainer.nn").MLPConfig
          if name == "MetricConfig":
@@ -349,10 +349,6 @@ else:
              return importlib.import_module(
                  "nshtrainer.callbacks"
              ).OnExceptionCheckpointCallbackConfig
-         if name == "OptimizationConfig":
-             return importlib.import_module(
-                 "nshtrainer.trainer._config"
-             ).OptimizationConfig
          if name == "OptimizerConfigBase":
              return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
          if name == "PReLUConfig":
@@ -373,10 +369,6 @@ else:
              return importlib.import_module(
                  "nshtrainer.lr_scheduler"
              ).ReduceLROnPlateauConfig
-         if name == "ReproducibilityConfig":
-             return importlib.import_module(
-                 "nshtrainer.trainer._config"
-             ).ReproducibilityConfig
          if name == "SanityCheckingConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/__init__.py
@@ -62,6 +62,9 @@ if TYPE_CHECKING:
          CheckpointMetadata as CheckpointMetadata,
      )
      from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
+     from nshtrainer.callbacks.lr_monitor import (
+         LearningRateMonitorConfig as LearningRateMonitorConfig,
+     )
  else:

      def __getattr__(name):
@@ -115,6 +118,10 @@ else:
              return importlib.import_module(
                  "nshtrainer.callbacks"
              ).LastCheckpointCallbackConfig
+         if name == "LearningRateMonitorConfig":
+             return importlib.import_module(
+                 "nshtrainer.callbacks.lr_monitor"
+             ).LearningRateMonitorConfig
          if name == "LogEpochCallbackConfig":
              return importlib.import_module(
                  "nshtrainer.callbacks"
@@ -167,6 +174,7 @@ from . import ema as ema
  from . import finite_checks as finite_checks
  from . import gradient_skipping as gradient_skipping
  from . import log_epoch as log_epoch
+ from . import lr_monitor as lr_monitor
  from . import norm_logging as norm_logging
  from . import print_table as print_table
  from . import rlp_sanity_checks as rlp_sanity_checks
nshtrainer-1.0.0b13/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py
@@ -0,0 +1,31 @@
+ from __future__ import annotations
+
+ __codegen__ = True
+
+ from typing import TYPE_CHECKING
+
+ # Config/alias imports
+
+ if TYPE_CHECKING:
+     from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfigBase
+     from nshtrainer.callbacks.lr_monitor import (
+         LearningRateMonitorConfig as LearningRateMonitorConfig,
+     )
+ else:
+
+     def __getattr__(name):
+         import importlib
+
+         if name in globals():
+             return globals()[name]
+         if name == "CallbackConfigBase":
+             return importlib.import_module(
+                 "nshtrainer.callbacks.lr_monitor"
+             ).CallbackConfigBase
+         if name == "LearningRateMonitorConfig":
+             return importlib.import_module(
+                 "nshtrainer.callbacks.lr_monitor"
+             ).LearningRateMonitorConfig
+         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+ # Submodule exports
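
The generated config modules touched in this release all follow the pattern shown above: the TYPE_CHECKING branch gives static type checkers ordinary re-exports, while at runtime a module-level __getattr__ (PEP 562) defers the real import until an attribute is first accessed. A standalone sketch of the same idea with hypothetical names (mypkg, heavy_module, ExpensiveConfig), not code from nshtrainer:

# mypkg/configs.py: a minimal sketch of the TYPE_CHECKING + __getattr__ pattern.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type checkers see a plain re-export with full type information.
    from mypkg.heavy_module import ExpensiveConfig as ExpensiveConfig
else:

    def __getattr__(name):
        # At runtime the heavy module is imported lazily, so importing
        # `mypkg.configs` stays cheap until the attribute is actually used.
        import importlib

        if name == "ExpensiveConfig":
            return importlib.import_module("mypkg.heavy_module").ExpensiveConfig
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
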
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/__init__.py
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING
  if TYPE_CHECKING:
      from nshtrainer.trainer import TrainerConfig as TrainerConfig
      from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+     from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
      from nshtrainer.trainer._config import (
          BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
      )
@@ -39,20 +40,21 @@ if TYPE_CHECKING:
      from nshtrainer.trainer._config import (
          LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
      )
+     from nshtrainer.trainer._config import (
+         LearningRateMonitorConfig as LearningRateMonitorConfig,
+     )
      from nshtrainer.trainer._config import (
          LogEpochCallbackConfig as LogEpochCallbackConfig,
      )
      from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-     from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
      from nshtrainer.trainer._config import MetricConfig as MetricConfig
      from nshtrainer.trainer._config import (
-         OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+         NormLoggingCallbackConfig as NormLoggingCallbackConfig,
      )
-     from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-     from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
      from nshtrainer.trainer._config import (
-         ReproducibilityConfig as ReproducibilityConfig,
+         OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
      )
+     from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
      from nshtrainer.trainer._config import (
          RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
      )
@@ -75,6 +77,10 @@ else:
              return importlib.import_module(
                  "nshtrainer.trainer._config"
              ).ActSaveLoggerConfig
+         if name == "BaseLoggerConfig":
+             return importlib.import_module(
+                 "nshtrainer.trainer._config"
+             ).BaseLoggerConfig
          if name == "BestCheckpointCallbackConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
              return importlib.import_module(
                  "nshtrainer.trainer._config"
              ).LastCheckpointCallbackConfig
+         if name == "LearningRateMonitorConfig":
+             return importlib.import_module(
+                 "nshtrainer.trainer._config"
+             ).LearningRateMonitorConfig
          if name == "LogEpochCallbackConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
              ).LogEpochCallbackConfig
-         if name == "LoggingConfig":
-             return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
          if name == "MetricConfig":
              return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-         if name == "OnExceptionCheckpointCallbackConfig":
+         if name == "NormLoggingCallbackConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
-             ).OnExceptionCheckpointCallbackConfig
-         if name == "OptimizationConfig":
+             ).NormLoggingCallbackConfig
+         if name == "OnExceptionCheckpointCallbackConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
-             ).OptimizationConfig
+             ).OnExceptionCheckpointCallbackConfig
          if name == "RLPSanityChecksCallbackConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
              ).RLPSanityChecksCallbackConfig
-         if name == "ReproducibilityConfig":
-             return importlib.import_module(
-                 "nshtrainer.trainer._config"
-             ).ReproducibilityConfig
          if name == "SanityCheckingConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/_config/__init__.py
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING

  if TYPE_CHECKING:
      from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+     from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
      from nshtrainer.trainer._config import (
          BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
      )
@@ -38,20 +39,21 @@ if TYPE_CHECKING:
      from nshtrainer.trainer._config import (
          LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
      )
+     from nshtrainer.trainer._config import (
+         LearningRateMonitorConfig as LearningRateMonitorConfig,
+     )
      from nshtrainer.trainer._config import (
          LogEpochCallbackConfig as LogEpochCallbackConfig,
      )
      from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-     from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
      from nshtrainer.trainer._config import MetricConfig as MetricConfig
      from nshtrainer.trainer._config import (
-         OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+         NormLoggingCallbackConfig as NormLoggingCallbackConfig,
      )
-     from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-     from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
      from nshtrainer.trainer._config import (
-         ReproducibilityConfig as ReproducibilityConfig,
+         OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
      )
+     from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
      from nshtrainer.trainer._config import (
          RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
      )
@@ -75,6 +77,10 @@ else:
              return importlib.import_module(
                  "nshtrainer.trainer._config"
              ).ActSaveLoggerConfig
+         if name == "BaseLoggerConfig":
+             return importlib.import_module(
+                 "nshtrainer.trainer._config"
+             ).BaseLoggerConfig
          if name == "BestCheckpointCallbackConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
              return importlib.import_module(
                  "nshtrainer.trainer._config"
              ).LastCheckpointCallbackConfig
+         if name == "LearningRateMonitorConfig":
+             return importlib.import_module(
+                 "nshtrainer.trainer._config"
+             ).LearningRateMonitorConfig
          if name == "LogEpochCallbackConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
              ).LogEpochCallbackConfig
-         if name == "LoggingConfig":
-             return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
          if name == "MetricConfig":
              return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-         if name == "OnExceptionCheckpointCallbackConfig":
+         if name == "NormLoggingCallbackConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
-             ).OnExceptionCheckpointCallbackConfig
-         if name == "OptimizationConfig":
+             ).NormLoggingCallbackConfig
+         if name == "OnExceptionCheckpointCallbackConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
-             ).OptimizationConfig
+             ).OnExceptionCheckpointCallbackConfig
          if name == "RLPSanityChecksCallbackConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
              ).RLPSanityChecksCallbackConfig
-         if name == "ReproducibilityConfig":
-             return importlib.import_module(
-                 "nshtrainer.trainer._config"
-             ).ReproducibilityConfig
          if name == "SanityCheckingConfig":
              return importlib.import_module(
                  "nshtrainer.trainer._config"
nshtrainer-1.0.0b13/src/nshtrainer/data/datamodule.py
@@ -0,0 +1,124 @@
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable, Mapping
+ from pathlib import Path
+ from typing import Any, Generic, cast
+
+ import nshconfig as C
+ import torch
+ from lightning.pytorch import LightningDataModule
+ from typing_extensions import Never, TypeVar, deprecated, override
+
+ from ..model.mixins.callback import CallbackRegistrarModuleMixin
+ from ..model.mixins.debug import _DebugModuleMixin
+
+ THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)
+
+
+ class LightningDataModuleBase(
+     _DebugModuleMixin,
+     CallbackRegistrarModuleMixin,
+     LightningDataModule,
+     ABC,
+     Generic[THparams],
+ ):
+     @property
+     @override
+     def hparams(self) -> THparams:  # pyright: ignore[reportIncompatibleMethodOverride]
+         return cast(THparams, super().hparams)
+
+     @property
+     @override
+     def hparams_initial(self):  # pyright: ignore[reportIncompatibleMethodOverride]
+         hparams = cast(THparams, super().hparams_initial)
+         return cast(Never, {"datamodule": hparams.model_dump(mode="json")})
+
+     @property
+     @deprecated("Use `hparams` instead")
+     def config(self):
+         return cast(Never, self.hparams)
+
+     @classmethod
+     @abstractmethod
+     def hparams_cls(cls) -> type[THparams]: ...
+
+     @override
+     def __init__(self, hparams: THparams | Mapping[str, Any]):
+         super().__init__()
+
+         # Validate and save hyperparameters
+         hparams_cls = self.hparams_cls()
+         if isinstance(hparams, Mapping):
+             hparams = hparams_cls.model_validate(hparams)
+         elif not isinstance(hparams, hparams_cls):
+             raise TypeError(
+                 f"Expected hparams to be either a Mapping or an instance of {hparams_cls}, got {type(hparams)}"
+             )
+         hparams = hparams.model_deep_validate()
+         self.save_hyperparameters(hparams)
+
+     @override
+     @classmethod
+     def load_from_checkpoint(cls, *args, **kwargs) -> Never:
+         raise ValueError("This method is not supported. Use `from_checkpoint` instead.")
+
+     @classmethod
+     def hparams_from_checkpoint(
+         cls,
+         ckpt_or_path: dict[str, Any] | str | Path,
+         /,
+         strict: bool | None = None,
+         *,
+         update_hparams: Callable[[THparams], THparams] | None = None,
+         update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+     ):
+         if isinstance(ckpt_or_path, dict):
+             ckpt = ckpt_or_path
+         else:
+             ckpt = torch.load(ckpt_or_path, map_location="cpu")
+
+         if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+             raise ValueError(
+                 f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+             )
+         if update_hparams_dict is not None:
+             hparams = update_hparams_dict(hparams)
+
+         hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+         if update_hparams is not None:
+             hparams = update_hparams(hparams)
+
+         return hparams
+
+     @classmethod
+     def from_checkpoint(
+         cls,
+         ckpt_or_path: dict[str, Any] | str | Path,
+         /,
+         strict: bool | None = None,
+         map_location: torch.serialization.MAP_LOCATION = None,
+         *,
+         update_hparams: Callable[[THparams], THparams] | None = None,
+         update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+     ):
+         # Load checkpoint
+         if isinstance(ckpt_or_path, Mapping):
+             ckpt = ckpt_or_path
+         else:
+             ckpt = torch.load(ckpt_or_path, map_location=map_location)
+
+         # Load hyperparameters from checkpoint
+         hparams = cls.hparams_from_checkpoint(
+             ckpt,
+             strict=strict,
+             update_hparams=update_hparams,
+             update_hparams_dict=update_hparams_dict,
+         )
+
+         # Load datamodule from checkpoint
+         datamodule = cls(hparams)
+         if datamodule.__class__.__qualname__ in ckpt:
+             datamodule.load_state_dict(ckpt[datamodule.__class__.__qualname__])
+
+         return datamodule
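
The new LightningDataModuleBase validates hyperparameters against an nshconfig model on construction and replaces Lightning's load_from_checkpoint with from_checkpoint / hparams_from_checkpoint. A small usage sketch; MyDataModuleConfig, MyDataModule, and the checkpoint path are hypothetical, and the mixin base classes are assumed to need no extra constructor arguments.

import nshconfig as C

from nshtrainer.data.datamodule import LightningDataModuleBase


class MyDataModuleConfig(C.Config):
    batch_size: int = 32
    num_workers: int = 4


class MyDataModule(LightningDataModuleBase[MyDataModuleConfig]):
    @classmethod
    def hparams_cls(cls):
        return MyDataModuleConfig


# Plain mappings are validated against MyDataModuleConfig on construction.
dm = MyDataModule({"batch_size": 64})
print(dm.hparams.batch_size)  # 64

# Restoring from a saved run (placeholder path):
# dm = MyDataModule.from_checkpoint("runs/example/last.ckpt")
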
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/base.py
@@ -2,7 +2,8 @@ from __future__ import annotations

  import logging
  from abc import ABC, abstractmethod
- from collections.abc import Mapping
+ from collections.abc import Callable, Mapping
+ from pathlib import Path
  from typing import Any, Generic, Literal, cast

  import nshconfig as C
@@ -10,11 +11,13 @@ import torch
  import torch.distributed
  from lightning.pytorch import LightningModule
  from lightning.pytorch.profilers import PassThroughProfiler, Profiler
+ from lightning.pytorch.utilities.model_helpers import is_overridden
+ from lightning.pytorch.utilities.rank_zero import rank_zero_warn
  from typing_extensions import Never, TypeVar, deprecated, override

  from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
  from .mixins.callback import CallbackModuleMixin
- from .mixins.debug import _DebugModuleMixin, _trainer
+ from .mixins.debug import _DebugModuleMixin
  from .mixins.logger import LoggerLightningModuleMixin

  log = logging.getLogger(__name__)
@@ -241,3 +244,98 @@ class LightningModuleBase(
          loss = sum((0.0 * v).sum() for v in self.parameters() if v.requires_grad)
          loss = cast(torch.Tensor, loss)
          return loss
+
+     @override
+     @classmethod
+     def load_from_checkpoint(cls, *args, **kwargs) -> Never:
+         raise ValueError("This method is not supported. Use `from_checkpoint` instead.")
+
+     @classmethod
+     def hparams_from_checkpoint(
+         cls,
+         ckpt_or_path: dict[str, Any] | str | Path,
+         /,
+         strict: bool | None = None,
+         *,
+         update_hparams: Callable[[THparams], THparams] | None = None,
+         update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+     ):
+         if isinstance(ckpt_or_path, dict):
+             ckpt = ckpt_or_path
+         else:
+             ckpt = torch.load(ckpt_or_path, map_location="cpu")
+
+         if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+             raise ValueError(
+                 f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+             )
+         if update_hparams_dict is not None:
+             hparams = update_hparams_dict(hparams)
+
+         hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+         if update_hparams is not None:
+             hparams = update_hparams(hparams)
+
+         return hparams
+
+     @classmethod
+     def from_checkpoint(
+         cls,
+         ckpt_or_path: dict[str, Any] | str | Path,
+         /,
+         strict: bool | None = None,
+         map_location: torch.serialization.MAP_LOCATION = None,
+         *,
+         update_hparams: Callable[[THparams], THparams] | None = None,
+         update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+     ):
+         # Load checkpoint
+         if isinstance(ckpt_or_path, Mapping):
+             ckpt = ckpt_or_path
+         else:
+             ckpt = torch.load(ckpt_or_path, map_location=map_location)
+
+         # Load hyperparameters from checkpoint
+         hparams = cls.hparams_from_checkpoint(
+             ckpt,
+             strict=strict,
+             update_hparams=update_hparams,
+             update_hparams_dict=update_hparams_dict,
+         )
+
+         # Load model from checkpoint
+         model = cls(hparams)
+
+         # Load model state from checkpoint
+         if (
+             model._strict_loading is not None
+             and strict is not None
+             and strict != model.strict_loading
+         ):
+             raise ValueError(
+                 f"You set `.load_from_checkpoint(..., strict={strict!r})` which is in conflict with"
+                 f" `{cls.__name__}.strict_loading={model.strict_loading!r}. Please set the same value for both of them."
+             )
+         strict = model.strict_loading if strict is None else strict
+
+         if is_overridden("configure_model", model):
+             model.configure_model()
+
+         # give model a chance to load something
+         model.on_load_checkpoint(ckpt)
+
+         # load the state_dict on the model automatically
+
+         keys = model.load_state_dict(ckpt["state_dict"], strict=strict)
+
+         if not strict:
+             if keys.missing_keys:
+                 rank_zero_warn(
+                     f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
+                 )
+             if keys.unexpected_keys:
+                 rank_zero_warn(
+                     f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
+                 )
+
+         return model
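
LightningModuleBase gains the same checkpoint helpers, with additional handling for strict_loading, configure_model, and missing or unexpected state-dict keys. A self-contained sketch using hypothetical names (TinyConfig, TinyModel) that round-trips through an in-memory checkpoint dict; it assumes the base __init__ accepts a mapping and validates it against hparams_cls, as the datamodule base above does.

import nshconfig as C
import torch.nn as nn

from nshtrainer.model.base import LightningModuleBase


class TinyConfig(C.Config):
    hidden: int = 8


class TinyModel(LightningModuleBase[TinyConfig]):
    @classmethod
    def hparams_cls(cls):
        return TinyConfig

    def __init__(self, hparams):
        super().__init__(hparams)
        self.layer = nn.Linear(self.hparams.hidden, 1)


model = TinyModel({"hidden": 8})

# Build the minimal checkpoint dict that `from_checkpoint` expects: the saved
# hyperparameters plus a state dict.
ckpt = {
    TinyModel.CHECKPOINT_HYPER_PARAMS_KEY: model.hparams.model_dump(mode="json"),
    "state_dict": model.state_dict(),
}

restored = TinyModel.from_checkpoint(ckpt, map_location="cpu")

# Lightning's load_from_checkpoint is intentionally disabled on this base class:
# TinyModel.load_from_checkpoint("last.ckpt")  # raises ValueError
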