nshtrainer 0.43.0.tar.gz → 0.44.1.tar.gz

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as published.
Files changed (162)
  1. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/PKG-INFO +1 -1
  2. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/pyproject.toml +1 -1
  3. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/lr_scheduler/_base.py +7 -11
  4. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +16 -17
  5. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +6 -6
  6. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/README.md +0 -0
  7. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/__init__.py +0 -0
  8. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/_callback.py +0 -0
  9. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/_checkpoint/loader.py +0 -0
  10. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  11. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/_checkpoint/saver.py +0 -0
  12. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/_directory.py +0 -0
  13. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/_experimental/__init__.py +0 -0
  14. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/_hf_hub.py +0 -0
  15. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/__init__.py +0 -0
  16. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  17. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/actsave.py +0 -0
  18. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/base.py +0 -0
  19. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  20. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  21. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  22. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  23. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  24. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  25. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  26. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  27. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/ema.py +0 -0
  28. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  29. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  30. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/interval.py +0 -0
  31. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  32. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  33. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/print_table.py +0 -0
  34. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  35. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  36. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  37. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/timer.py +0 -0
  38. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  39. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  40. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/__init__.py +0 -0
  41. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/_checkpoint/loader/__init__.py +0 -0
  42. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/_checkpoint/metadata/__init__.py +0 -0
  43. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/_directory/__init__.py +0 -0
  44. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/_hf_hub/__init__.py +0 -0
  45. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/__init__.py +0 -0
  46. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/actsave/__init__.py +0 -0
  47. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/base/__init__.py +0 -0
  48. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/checkpoint/__init__.py +0 -0
  49. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/checkpoint/_base/__init__.py +0 -0
  50. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  51. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  52. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  53. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/debug_flag/__init__.py +0 -0
  54. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/directory_setup/__init__.py +0 -0
  55. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/early_stopping/__init__.py +0 -0
  56. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/ema/__init__.py +0 -0
  57. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/finite_checks/__init__.py +0 -0
  58. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/gradient_skipping/__init__.py +0 -0
  59. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/norm_logging/__init__.py +0 -0
  60. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/print_table/__init__.py +0 -0
  61. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +0 -0
  62. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/shared_parameters/__init__.py +0 -0
  63. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/throughput_monitor/__init__.py +0 -0
  64. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/timer/__init__.py +0 -0
  65. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/wandb_upload_code/__init__.py +0 -0
  66. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/callbacks/wandb_watch/__init__.py +0 -0
  67. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/loggers/__init__.py +0 -0
  68. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/loggers/_base/__init__.py +0 -0
  69. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/loggers/csv/__init__.py +0 -0
  70. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/loggers/tensorboard/__init__.py +0 -0
  71. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/loggers/wandb/__init__.py +0 -0
  72. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/lr_scheduler/__init__.py +0 -0
  73. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/lr_scheduler/_base/__init__.py +0 -0
  74. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  75. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  76. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/metrics/__init__.py +0 -0
  77. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/metrics/_config/__init__.py +0 -0
  78. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/model/__init__.py +0 -0
  79. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/model/base/__init__.py +0 -0
  80. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/model/config/__init__.py +0 -0
  81. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/model/mixins/logger/__init__.py +0 -0
  82. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/nn/__init__.py +0 -0
  83. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/nn/mlp/__init__.py +0 -0
  84. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/nn/nonlinearity/__init__.py +0 -0
  85. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/optimizer/__init__.py +0 -0
  86. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/profiler/__init__.py +0 -0
  87. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/profiler/_base/__init__.py +0 -0
  88. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/profiler/advanced/__init__.py +0 -0
  89. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/profiler/pytorch/__init__.py +0 -0
  90. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/profiler/simple/__init__.py +0 -0
  91. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/runner/__init__.py +0 -0
  92. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/trainer/_config/__init__.py +0 -0
  93. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/trainer/checkpoint_connector/__init__.py +0 -0
  94. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/util/_environment_info/__init__.py +0 -0
  95. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/util/config/__init__.py +0 -0
  96. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/util/config/dtype/__init__.py +0 -0
  97. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/config/util/config/duration/__init__.py +0 -0
  98. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/data/__init__.py +0 -0
  99. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  100. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/data/datamodule.py +0 -0
  101. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/data/transform.py +0 -0
  102. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/__init__.py +0 -0
  103. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/_experimental.py +0 -0
  104. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/actsave.py +0 -0
  105. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/callbacks.py +0 -0
  106. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/config.py +0 -0
  107. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/data.py +0 -0
  108. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/log.py +0 -0
  109. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  110. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/model.py +0 -0
  111. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/nn.py +0 -0
  112. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/optimizer.py +0 -0
  113. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/runner.py +0 -0
  114. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/snapshot.py +0 -0
  115. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/snoop.py +0 -0
  116. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/trainer.py +0 -0
  117. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/typecheck.py +0 -0
  118. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/ll/util.py +0 -0
  119. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/loggers/__init__.py +0 -0
  120. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/loggers/_base.py +0 -0
  121. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/loggers/csv.py +0 -0
  122. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/loggers/tensorboard.py +0 -0
  123. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/loggers/wandb.py +0 -0
  124. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  125. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/metrics/__init__.py +0 -0
  126. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/metrics/_config.py +0 -0
  127. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/model/__init__.py +0 -0
  128. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/model/base.py +0 -0
  129. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/model/config.py +0 -0
  130. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/model/mixins/callback.py +0 -0
  131. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/model/mixins/logger.py +0 -0
  132. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/nn/__init__.py +0 -0
  133. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/nn/mlp.py +0 -0
  134. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/nn/module_dict.py +0 -0
  135. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/nn/module_list.py +0 -0
  136. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
  137. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/optimizer.py +0 -0
  138. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/profiler/__init__.py +0 -0
  139. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/profiler/_base.py +0 -0
  140. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/profiler/advanced.py +0 -0
  141. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/profiler/pytorch.py +0 -0
  142. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/profiler/simple.py +0 -0
  143. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/runner.py +0 -0
  144. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/scripts/find_packages.py +0 -0
  145. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/trainer/__init__.py +0 -0
  146. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/trainer/_config.py +0 -0
  147. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  148. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  149. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
  150. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/trainer/trainer.py +0 -0
  151. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/_environment_info.py +0 -0
  152. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/_useful_types.py +0 -0
  153. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/bf16.py +0 -0
  154. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/config/__init__.py +0 -0
  155. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/config/dtype.py +0 -0
  156. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/config/duration.py +0 -0
  157. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/environment.py +0 -0
  158. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/path.py +0 -0
  159. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/seed.py +0 -0
  160. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/slurm.py +0 -0
  161. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/typed.py +0 -0
  162. {nshtrainer-0.43.0 → nshtrainer-0.44.1}/src/nshtrainer/util/typing_utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.43.0
+Version: 0.44.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.43.0"
+version = "0.44.1"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
src/nshtrainer/lr_scheduler/_base.py
@@ -11,7 +11,7 @@ from lightning.pytorch.utilities.types import (
     LRSchedulerTypeUnion,
 )
 from torch.optim import Optimizer
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import Never, NotRequired, TypedDict
 
 if TYPE_CHECKING:
     from ..model.base import LightningModuleBase
@@ -44,20 +44,18 @@ class LRSchedulerConfigBase(C.Config, ABC):
 
     @abstractmethod
     def create_scheduler_impl(
-        self,
-        optimizer: Optimizer,
-        lightning_module: "LightningModuleBase",
-        lr: float,
+        self, optimizer: Optimizer, lightning_module: LightningModuleBase
     ) -> LRSchedulerTypeUnion | LRSchedulerConfigType: ...
 
     def create_scheduler(
         self,
         optimizer: Optimizer,
-        lightning_module: "LightningModuleBase",
-        lr: float,
+        lightning_module: LightningModuleBase,
+        lr: Never
+        | None = None,  # Backward compatibility, should be removed in the future
     ) -> LRSchedulerConfigType:
         # Create the scheduler.
-        scheduler = self.create_scheduler_impl(optimizer, lightning_module, lr)
+        scheduler = self.create_scheduler_impl(optimizer, lightning_module)
 
         # If the scheduler is not a `LRSchedulerConfigType`, then make it one.
         if not isinstance(scheduler, Mapping):
@@ -89,9 +87,7 @@ class LRSchedulerConfigBase(C.Config, ABC):
 
         return scheduler
 
-    def compute_num_steps_per_epoch(
-        self, lightning_module: "LightningModuleBase"
-    ) -> int:
+    def compute_num_steps_per_epoch(self, lightning_module: LightningModuleBase) -> int:
         trainer = lightning_module.trainer
         # Use the Lightning trainer to convert the epoch-based values to step-based values
         _ = trainer.estimated_stepping_batches
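
Note on the API change above: create_scheduler_impl no longer receives the base learning rate, and the lr parameter survives only on create_scheduler, typed Never | None, as a deprecated backward-compatibility stub. Implementations are now expected to read base LRs from the optimizer's param groups. A minimal sketch of a subclass written against the new signature; the StepLRConfig name, its fields, and the import path are illustrative assumptions, not part of the package:

    from torch.optim import Optimizer
    from torch.optim.lr_scheduler import StepLR

    # Assumed import path, based on the diffed file layout.
    from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase


    class StepLRConfig(LRSchedulerConfigBase):  # hypothetical, for illustration
        step_size: int
        gamma: float = 0.1

        def create_scheduler_impl(self, optimizer: Optimizer, lightning_module):
            # No `lr` argument anymore: the base LRs already live on
            # optimizer.param_groups, so nothing needs to be passed in.
            return StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)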
src/nshtrainer/lr_scheduler/linear_warmup_cosine.py
@@ -20,21 +20,21 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
         optimizer: Optimizer,
         warmup_epochs: int,
         max_epochs: int,
-        warmup_start_lr: float = 0.0,
-        eta_min: float = 0.0,
+        warmup_start_lr_factor: float = 0.0,
+        eta_min_factor: float = 0.0,
         last_epoch: int = -1,
         should_restart: bool = True,
     ) -> None:
         self.warmup_epochs = warmup_epochs
         self.max_epochs = max_epochs
-        self.warmup_start_lr = warmup_start_lr
-        self.eta_min = eta_min
+        self.warmup_start_lr_factor = warmup_start_lr_factor
+        self.eta_min_factor = eta_min_factor
         self.should_restart = should_restart
 
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:  # pyright: ignore[reportIncompatibleMethodOverride]
+    def get_lr(self) -> list[float]:
         if not self._get_lr_called_within_step:
             warnings.warn(
                 "To get the last learning rate computed by the scheduler, "
@@ -43,25 +43,26 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
             )
 
         if self.last_epoch == 0:
-            return [self.warmup_start_lr] * len(self.base_lrs)
+            return [self.warmup_start_lr_factor * base_lr for base_lr in self.base_lrs]
         if self.last_epoch < self.warmup_epochs:
             return [
                 group["lr"]
-                + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
+                + (base_lr - self.warmup_start_lr_factor * base_lr)
+                / (self.warmup_epochs - 1)
                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
             ]
         if self.last_epoch == self.warmup_epochs:
             return self.base_lrs
 
         if not self.should_restart and self.last_epoch >= self.max_epochs:
-            return [self.eta_min] * len(self.base_lrs)
+            return [self.eta_min_factor * base_lr for base_lr in self.base_lrs]
 
         if (self.last_epoch - 1 - self.max_epochs) % (
             2 * (self.max_epochs - self.warmup_epochs)
         ) == 0:
             return [
                 group["lr"]
-                + (base_lr - self.eta_min)
+                + (base_lr - self.eta_min_factor * base_lr)
                 * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs)))
                 / 2
                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
@@ -84,9 +85,9 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
                     / (self.max_epochs - self.warmup_epochs)
                 )
             )
-            * (group["lr"] - self.eta_min)
-            + self.eta_min
-            for group in self.optimizer.param_groups
+            * (group["lr"] - self.eta_min_factor * base_lr)
+            + self.eta_min_factor * base_lr
+            for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
         ]
 
 
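The change running through get_lr above: warmup_start_lr and eta_min were absolute learning rates shared by every param group, while warmup_start_lr_factor and eta_min_factor are multipliers applied to each group's own base LR. A small runnable sketch of the new per-group behaviour (import path assumed from the diffed file layout):

    import torch
    from nshtrainer.lr_scheduler.linear_warmup_cosine import LinearWarmupCosineAnnealingLR

    # Two param groups with different base LRs.
    opt = torch.optim.SGD(
        [
            {"params": [torch.nn.Parameter(torch.zeros(1))], "lr": 1e-2},
            {"params": [torch.nn.Parameter(torch.zeros(1))], "lr": 1e-3},
        ]
    )
    sched = LinearWarmupCosineAnnealingLR(
        opt,
        warmup_epochs=5,
        max_epochs=100,
        warmup_start_lr_factor=0.1,  # warmup starts at 0.1 * base_lr, per group
        eta_min_factor=0.01,         # cosine floor is 0.01 * base_lr, per group
    )
    # At step 0: [1e-3, 1e-4], i.e. 0.1 * each base LR. Under 0.43.0 the old
    # warmup_start_lr argument would have forced both groups to one value.
    print(sched.get_last_lr())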
src/nshtrainer/lr_scheduler/linear_warmup_cosine.py (continued)
@@ -121,12 +122,10 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
         }
 
     @override
-    def create_scheduler_impl(self, optimizer, lightning_module, lr):
+    def create_scheduler_impl(self, optimizer, lightning_module):
         num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
         warmup_steps = self.warmup_duration.to_steps(num_steps_per_epoch).value
         max_steps = self.max_duration.to_steps(num_steps_per_epoch).value
-        warmup_start_lr = self.warmup_start_lr_factor * lr
-        min_lr = self.min_lr_factor * lr
 
         # Warmup and max steps should be at least 1.
         warmup_steps = max(warmup_steps, 1)
@@ -137,8 +136,8 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
             optimizer=optimizer,
             warmup_epochs=warmup_steps,
             max_epochs=max_steps,
-            warmup_start_lr=warmup_start_lr,
-            eta_min=min_lr,
+            warmup_start_lr_factor=self.warmup_start_lr_factor,
+            eta_min_factor=self.min_lr_factor,
             should_restart=self.annealing,
         )
         return scheduler
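
Net effect of the config-side hunks: in 0.43.0 create_scheduler_impl multiplied the factors by the single lr argument and passed absolute values down, so all param groups shared one warmup/min LR; in 0.44.1 the factors are forwarded untouched and scaled per group inside the scheduler. A sketch of the arithmetic (values illustrative): the two versions agree exactly when a group's base LR equals the old lr, and diverge otherwise:

    lr = 1e-3                     # the base LR the 0.43.0 code received explicitly
    warmup_start_lr_factor = 0.1
    min_lr_factor = 0.01

    # 0.43.0: absolute values precomputed once in the config.
    warmup_start_lr = warmup_start_lr_factor * lr  # 1e-4, shared by all groups
    min_lr = min_lr_factor * lr                    # 1e-5, shared by all groups

    # 0.44.1: factor * base_lr, per group. Identical when base_lr == lr ...
    assert warmup_start_lr_factor * lr == warmup_start_lr
    # ... but proportional, not shared, for a group with a different base LR.
    print(warmup_start_lr_factor * 1e-2)           # 1e-3 rather than 1e-4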
src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py
@@ -22,21 +22,21 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
     """Metric to monitor.
     If not provided, the primary metric of the runner will be used."""
 
-    patience: int = 10
+    patience: int
     r"""Number of epochs with no improvement after which learning rate will be reduced."""
 
-    factor: float = 0.1
+    factor: float
     r"""Factor by which the learning rate will be reduced. new_lr = lr * factor."""
 
+    cooldown: int = 0
+    r"""Number of epochs to wait before resuming normal operation after lr has been reduced."""
+
     min_lr: float | list[float] = 0.0
     r"""A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively."""
 
     eps: float = 1.0e-8
     r"""Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored."""
 
-    cooldown: int = 0
-    r"""Number of epochs to wait before resuming normal operation after lr has been reduced."""
-
     threshold: float = 1.0e-4
     r"""Threshold for measuring the new optimum, to only focus on significant changes."""
 
@@ -45,7 +45,7 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
 
     @override
     def create_scheduler_impl(
-        self, optimizer, lightning_module, lr
+        self, optimizer, lightning_module
     ) -> LRSchedulerConfigType:
         if (metric := self.metric) is None:
             lm_config = cast("BaseConfig", lightning_module.config)
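
patience and factor lose their defaults in this hunk, so configs that previously relied on the implicit 10 and 0.1 must now set them explicitly. A minimal sketch, assuming the config has no other required fields and the import path follows the diffed module layout:

    from nshtrainer.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateauConfig

    # The old defaults, now spelled out explicitly; the remaining fields keep
    # their defaults (cooldown=0, min_lr=0.0, eps=1e-8, threshold=1e-4).
    config = ReduceLROnPlateauConfig(patience=10, factor=0.1)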