nshtrainer 0.42.0__py3-none-any.whl → 0.44.0__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions exactly as they appear in their respective public registries.
Files changed (162)
  1. nshtrainer/__init__.py +2 -0
  2. nshtrainer/_callback.py +2 -0
  3. nshtrainer/_checkpoint/loader.py +2 -0
  4. nshtrainer/_checkpoint/metadata.py +2 -0
  5. nshtrainer/_checkpoint/saver.py +2 -0
  6. nshtrainer/_directory.py +4 -2
  7. nshtrainer/_experimental/__init__.py +2 -0
  8. nshtrainer/_hf_hub.py +2 -0
  9. nshtrainer/callbacks/__init__.py +45 -29
  10. nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
  11. nshtrainer/callbacks/actsave.py +2 -0
  12. nshtrainer/callbacks/base.py +2 -0
  13. nshtrainer/callbacks/checkpoint/__init__.py +6 -2
  14. nshtrainer/callbacks/checkpoint/_base.py +2 -0
  15. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  16. nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
  17. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
  18. nshtrainer/callbacks/debug_flag.py +2 -0
  19. nshtrainer/callbacks/directory_setup.py +4 -2
  20. nshtrainer/callbacks/early_stopping.py +6 -4
  21. nshtrainer/callbacks/ema.py +5 -3
  22. nshtrainer/callbacks/finite_checks.py +3 -1
  23. nshtrainer/callbacks/gradient_skipping.py +6 -4
  24. nshtrainer/callbacks/interval.py +2 -0
  25. nshtrainer/callbacks/log_epoch.py +13 -1
  26. nshtrainer/callbacks/norm_logging.py +4 -2
  27. nshtrainer/callbacks/print_table.py +3 -1
  28. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  29. nshtrainer/callbacks/shared_parameters.py +4 -2
  30. nshtrainer/callbacks/throughput_monitor.py +2 -0
  31. nshtrainer/callbacks/timer.py +5 -3
  32. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  33. nshtrainer/callbacks/wandb_watch.py +4 -2
  34. nshtrainer/config/__init__.py +130 -90
  35. nshtrainer/config/_checkpoint/loader/__init__.py +10 -8
  36. nshtrainer/config/_checkpoint/metadata/__init__.py +6 -4
  37. nshtrainer/config/_directory/__init__.py +9 -3
  38. nshtrainer/config/_hf_hub/__init__.py +6 -4
  39. nshtrainer/config/callbacks/__init__.py +82 -42
  40. nshtrainer/config/callbacks/actsave/__init__.py +4 -2
  41. nshtrainer/config/callbacks/base/__init__.py +2 -0
  42. nshtrainer/config/callbacks/checkpoint/__init__.py +6 -4
  43. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +6 -4
  44. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +2 -0
  45. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +6 -4
  46. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -4
  47. nshtrainer/config/callbacks/debug_flag/__init__.py +6 -4
  48. nshtrainer/config/callbacks/directory_setup/__init__.py +7 -5
  49. nshtrainer/config/callbacks/early_stopping/__init__.py +9 -7
  50. nshtrainer/config/callbacks/ema/__init__.py +5 -3
  51. nshtrainer/config/callbacks/finite_checks/__init__.py +7 -5
  52. nshtrainer/config/callbacks/gradient_skipping/__init__.py +7 -5
  53. nshtrainer/config/callbacks/norm_logging/__init__.py +9 -5
  54. nshtrainer/config/callbacks/print_table/__init__.py +7 -5
  55. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +7 -5
  56. nshtrainer/config/callbacks/shared_parameters/__init__.py +7 -5
  57. nshtrainer/config/callbacks/throughput_monitor/__init__.py +6 -4
  58. nshtrainer/config/callbacks/timer/__init__.py +9 -5
  59. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +7 -5
  60. nshtrainer/config/callbacks/wandb_watch/__init__.py +9 -5
  61. nshtrainer/config/loggers/__init__.py +18 -10
  62. nshtrainer/config/loggers/_base/__init__.py +2 -0
  63. nshtrainer/config/loggers/csv/__init__.py +2 -0
  64. nshtrainer/config/loggers/tensorboard/__init__.py +2 -0
  65. nshtrainer/config/loggers/wandb/__init__.py +18 -10
  66. nshtrainer/config/lr_scheduler/__init__.py +2 -0
  67. nshtrainer/config/lr_scheduler/_base/__init__.py +2 -0
  68. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +2 -0
  69. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -4
  70. nshtrainer/config/metrics/__init__.py +2 -0
  71. nshtrainer/config/metrics/_config/__init__.py +2 -0
  72. nshtrainer/config/model/__init__.py +8 -6
  73. nshtrainer/config/model/base/__init__.py +4 -2
  74. nshtrainer/config/model/config/__init__.py +8 -6
  75. nshtrainer/config/model/mixins/logger/__init__.py +2 -0
  76. nshtrainer/config/nn/__init__.py +16 -14
  77. nshtrainer/config/nn/mlp/__init__.py +2 -0
  78. nshtrainer/config/nn/nonlinearity/__init__.py +26 -24
  79. nshtrainer/config/optimizer/__init__.py +2 -0
  80. nshtrainer/config/profiler/__init__.py +2 -0
  81. nshtrainer/config/profiler/_base/__init__.py +2 -0
  82. nshtrainer/config/profiler/advanced/__init__.py +6 -4
  83. nshtrainer/config/profiler/pytorch/__init__.py +6 -4
  84. nshtrainer/config/profiler/simple/__init__.py +6 -4
  85. nshtrainer/config/runner/__init__.py +2 -0
  86. nshtrainer/config/trainer/_config/__init__.py +43 -39
  87. nshtrainer/config/trainer/checkpoint_connector/__init__.py +2 -0
  88. nshtrainer/config/util/_environment_info/__init__.py +20 -18
  89. nshtrainer/config/util/config/__init__.py +2 -0
  90. nshtrainer/config/util/config/dtype/__init__.py +2 -0
  91. nshtrainer/config/util/config/duration/__init__.py +2 -0
  92. nshtrainer/data/__init__.py +2 -0
  93. nshtrainer/data/balanced_batch_sampler.py +2 -0
  94. nshtrainer/data/datamodule.py +2 -0
  95. nshtrainer/data/transform.py +2 -0
  96. nshtrainer/ll/__init__.py +2 -0
  97. nshtrainer/ll/_experimental.py +2 -0
  98. nshtrainer/ll/actsave.py +2 -0
  99. nshtrainer/ll/callbacks.py +2 -0
  100. nshtrainer/ll/config.py +2 -0
  101. nshtrainer/ll/data.py +2 -0
  102. nshtrainer/ll/log.py +2 -0
  103. nshtrainer/ll/lr_scheduler.py +2 -0
  104. nshtrainer/ll/model.py +2 -0
  105. nshtrainer/ll/nn.py +2 -0
  106. nshtrainer/ll/optimizer.py +2 -0
  107. nshtrainer/ll/runner.py +2 -0
  108. nshtrainer/ll/snapshot.py +2 -0
  109. nshtrainer/ll/snoop.py +2 -0
  110. nshtrainer/ll/trainer.py +2 -0
  111. nshtrainer/ll/typecheck.py +2 -0
  112. nshtrainer/ll/util.py +2 -0
  113. nshtrainer/loggers/__init__.py +2 -0
  114. nshtrainer/loggers/_base.py +2 -0
  115. nshtrainer/loggers/csv.py +2 -0
  116. nshtrainer/loggers/tensorboard.py +2 -0
  117. nshtrainer/loggers/wandb.py +6 -4
  118. nshtrainer/lr_scheduler/__init__.py +2 -0
  119. nshtrainer/lr_scheduler/_base.py +8 -11
  120. nshtrainer/lr_scheduler/linear_warmup_cosine.py +18 -17
  121. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +8 -6
  122. nshtrainer/metrics/__init__.py +2 -0
  123. nshtrainer/metrics/_config.py +2 -0
  124. nshtrainer/model/__init__.py +2 -0
  125. nshtrainer/model/base.py +2 -0
  126. nshtrainer/model/config.py +2 -0
  127. nshtrainer/model/mixins/callback.py +2 -0
  128. nshtrainer/model/mixins/logger.py +2 -0
  129. nshtrainer/nn/__init__.py +2 -0
  130. nshtrainer/nn/mlp.py +2 -0
  131. nshtrainer/nn/module_dict.py +2 -0
  132. nshtrainer/nn/module_list.py +2 -0
  133. nshtrainer/nn/nonlinearity.py +2 -0
  134. nshtrainer/optimizer.py +2 -0
  135. nshtrainer/profiler/__init__.py +2 -0
  136. nshtrainer/profiler/_base.py +2 -0
  137. nshtrainer/profiler/advanced.py +2 -0
  138. nshtrainer/profiler/pytorch.py +2 -0
  139. nshtrainer/profiler/simple.py +2 -0
  140. nshtrainer/runner.py +2 -0
  141. nshtrainer/scripts/find_packages.py +2 -0
  142. nshtrainer/trainer/__init__.py +2 -0
  143. nshtrainer/trainer/_config.py +16 -13
  144. nshtrainer/trainer/_runtime_callback.py +2 -0
  145. nshtrainer/trainer/checkpoint_connector.py +2 -0
  146. nshtrainer/trainer/signal_connector.py +2 -0
  147. nshtrainer/trainer/trainer.py +2 -0
  148. nshtrainer/util/_environment_info.py +2 -0
  149. nshtrainer/util/bf16.py +2 -0
  150. nshtrainer/util/config/__init__.py +2 -0
  151. nshtrainer/util/config/dtype.py +2 -0
  152. nshtrainer/util/config/duration.py +2 -0
  153. nshtrainer/util/environment.py +2 -0
  154. nshtrainer/util/path.py +2 -0
  155. nshtrainer/util/seed.py +2 -0
  156. nshtrainer/util/slurm.py +3 -0
  157. nshtrainer/util/typed.py +2 -0
  158. nshtrainer/util/typing_utils.py +2 -0
  159. {nshtrainer-0.42.0.dist-info → nshtrainer-0.44.0.dist-info}/METADATA +1 -1
  160. nshtrainer-0.44.0.dist-info/RECORD +162 -0
  161. nshtrainer-0.42.0.dist-info/RECORD +0 -162
  162. {nshtrainer-0.42.0.dist-info → nshtrainer-0.44.0.dist-info}/WHEEL +0 -0
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 __codegen__ = True
 
 from typing import TYPE_CHECKING
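Nearly every module in this release gains the same two-line prologue shown above. As a quick illustration of what PEP 563's postponed evaluation of annotations buys (the `Node` class below is hypothetical, not from nshtrainer):

    from __future__ import annotations


    class Node:
        # Without the future import, this self-reference would need to be
        # quoted as "Node" or guarded by TYPE_CHECKING; with it, annotations
        # are stored as strings and never evaluated at import time.
        def child(self) -> Node: ...


    print(Node.child.__annotations__["return"])  # prints "Node" (a string, not the class)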
@@ -43,30 +45,30 @@ else:
 
         if name in globals():
             return globals()[name]
-        if name == "EnvironmentPackageConfig":
-            return importlib.import_module(
-                "nshtrainer.util._environment_info"
-            ).EnvironmentPackageConfig
-        if name == "EnvironmentSnapshotConfig":
+        if name == "EnvironmentLinuxEnvironmentConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).EnvironmentSnapshotConfig
+            ).EnvironmentLinuxEnvironmentConfig
         if name == "EnvironmentLSFInformationConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
             ).EnvironmentLSFInformationConfig
-        if name == "EnvironmentLinuxEnvironmentConfig":
+        if name == "EnvironmentGPUConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).EnvironmentLinuxEnvironmentConfig
-        if name == "EnvironmentSLURMInformationConfig":
+            ).EnvironmentGPUConfig
+        if name == "EnvironmentPackageConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).EnvironmentSLURMInformationConfig
-        if name == "EnvironmentConfig":
+            ).EnvironmentPackageConfig
+        if name == "EnvironmentHardwareConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).EnvironmentConfig
+            ).EnvironmentHardwareConfig
+        if name == "EnvironmentSnapshotConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentSnapshotConfig
         if name == "EnvironmentClassInformationConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
@@ -75,18 +77,18 @@ else:
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
             ).GitRepositoryConfig
-        if name == "EnvironmentCUDAConfig":
+        if name == "EnvironmentConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).EnvironmentCUDAConfig
-        if name == "EnvironmentGPUConfig":
+            ).EnvironmentConfig
+        if name == "EnvironmentCUDAConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).EnvironmentGPUConfig
-        if name == "EnvironmentHardwareConfig":
+            ).EnvironmentCUDAConfig
+        if name == "EnvironmentSLURMInformationConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).EnvironmentHardwareConfig
+            ).EnvironmentSLURMInformationConfig
         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
 
 # Submodule exports
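The reordered branches above belong to a module-level `__getattr__` hook (PEP 562) that these generated config packages use to defer imports until an attribute is first accessed; only the generated ordering of the branches changed, not the mechanism. A stripped-down sketch of the pattern as it appears in the hunk:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Static type checkers see the real import...
        from nshtrainer.util._environment_info import EnvironmentConfig
    else:
        # ...while at runtime the submodule is only imported on first access.
        def __getattr__(name):
            import importlib

            if name in globals():
                return globals()[name]
            if name == "EnvironmentConfig":
                return importlib.import_module(
                    "nshtrainer.util._environment_info"
                ).EnvironmentConfig
            raise AttributeError(f"module '{__name__}' has no attribute '{name}'")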
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 __codegen__ = True
 
 from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 __codegen__ = True
 
 from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 __codegen__ = True
 
 from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from . import transform as dataset_transform
 from .balanced_batch_sampler import BalancedBatchSampler as BalancedBatchSampler
 from .datamodule import LightningDataModuleBase as LightningDataModuleBase
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import heapq
 import logging
 from typing import Any, Protocol, runtime_checkable
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from lightning.pytorch import LightningDataModule
 
 
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import copy
 from collections.abc import Callable
 from typing import Any, cast
nshtrainer/ll/__init__.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TypeAlias
 
 from . import _experimental as _experimental
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from nshtrainer._experimental import *  # noqa: F403
nshtrainer/ll/actsave.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from nshutils.actsave import *  # type: ignore # noqa: F403
 
 from nshtrainer.callbacks.actsave import ActSaveCallback as ActSaveCallback
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from nshtrainer.callbacks import *  # noqa: F403
nshtrainer/ll/config.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from nshconfig import *  # type: ignore # noqa: F403
 from nshconfig import Config as TypedConfig  # type: ignore # noqa: F401
 
nshtrainer/ll/data.py CHANGED
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from nshtrainer.data import *  # noqa: F403
nshtrainer/ll/log.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from nshutils import init_python_logging as init_python_logging
 from nshutils import lovely as lovely
 from nshutils import pretty as pretty
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from nshtrainer.lr_scheduler import *  # noqa: F403
nshtrainer/ll/model.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from nshtrainer.model import *  # noqa: F403
 
 from ..trainer._config import CheckpointLoadingConfig as CheckpointLoadingConfig
nshtrainer/ll/nn.py CHANGED
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from nshtrainer.nn import *  # noqa: F403
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from nshtrainer.optimizer import *  # noqa: F403
nshtrainer/ll/runner.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from nshrunner import SnapshotConfig as SnapshotConfig
 
 from nshtrainer.runner import *  # type: ignore # noqa: F403
nshtrainer/ll/snapshot.py CHANGED
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from nshsnap import *  # pyright: ignore[reportWildcardImportFromLibrary] # noqa: F403
nshtrainer/ll/snoop.py CHANGED
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from nshutils import snoop as snoop
nshtrainer/ll/trainer.py CHANGED
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from nshtrainer.trainer import *  # noqa: F403
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from nshutils.typecheck import *  # type: ignore # noqa: F403
nshtrainer/ll/util.py CHANGED
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from nshtrainer.util import *  # noqa: F403
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Annotated, TypeAlias
 
 import nshconfig as C
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING
 
nshtrainer/loggers/csv.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Literal
 
 from typing_extensions import override
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from typing import Literal
 
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import importlib.metadata
 import logging
 from typing import TYPE_CHECKING, Literal
@@ -8,8 +10,8 @@ from packaging import version
 from typing_extensions import assert_never, override
 
 from ..callbacks.base import CallbackConfigBase
-from ..callbacks.wandb_upload_code import WandbUploadCodeConfig
-from ..callbacks.wandb_watch import WandbWatchConfig
+from ..callbacks.wandb_upload_code import WandbUploadCodeCallbackConfig
+from ..callbacks.wandb_watch import WandbWatchCallbackConfig
 from ._base import BaseLoggerConfig
 
 if TYPE_CHECKING:
@@ -92,10 +94,10 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
     - "none" or False: Do not log any checkpoints
     """
 
-    log_code: WandbUploadCodeConfig | None = WandbUploadCodeConfig()
+    log_code: WandbUploadCodeCallbackConfig | None = WandbUploadCodeCallbackConfig()
     """WandB code upload configuration. Used to upload code to WandB."""
 
-    watch: WandbWatchConfig | None = WandbWatchConfig()
+    watch: WandbWatchCallbackConfig | None = WandbWatchCallbackConfig()
     """WandB model watch configuration. Used to log model architecture, gradients, and parameters."""
 
     offline: bool = False
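For downstream code, this rename means `WandbUploadCodeConfig` and `WandbWatchConfig` give way to their `...CallbackConfig` counterparts. A sketch of the updated usage, assuming the import paths shown in the hunk (other `WandbLoggerConfig` fields omitted; actual required fields may differ):

    from nshtrainer.callbacks.wandb_upload_code import WandbUploadCodeCallbackConfig
    from nshtrainer.callbacks.wandb_watch import WandbWatchCallbackConfig
    from nshtrainer.loggers.wandb import WandbLoggerConfig

    logger_config = WandbLoggerConfig(
        log_code=WandbUploadCodeCallbackConfig(),  # was: WandbUploadCodeConfig()
        watch=WandbWatchCallbackConfig(),  # was: WandbWatchConfig()
    )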
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Annotated, TypeAlias
 
 import nshconfig as C
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import math
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
@@ -9,7 +11,7 @@ from lightning.pytorch.utilities.types import (
     LRSchedulerTypeUnion,
 )
 from torch.optim import Optimizer
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import Never, NotRequired, TypedDict
 
 if TYPE_CHECKING:
     from ..model.base import LightningModuleBase
@@ -42,20 +44,17 @@ class LRSchedulerConfigBase(C.Config, ABC):
 
     @abstractmethod
    def create_scheduler_impl(
-        self,
-        optimizer: Optimizer,
-        lightning_module: "LightningModuleBase",
-        lr: float,
+        self, optimizer: Optimizer, lightning_module: LightningModuleBase
     ) -> LRSchedulerTypeUnion | LRSchedulerConfigType: ...
 
     def create_scheduler(
         self,
         optimizer: Optimizer,
-        lightning_module: "LightningModuleBase",
-        lr: float,
+        lightning_module: LightningModuleBase,
+        lr: Never,  # Backward compatibility, should be removed in the future
     ) -> LRSchedulerConfigType:
         # Create the scheduler.
-        scheduler = self.create_scheduler_impl(optimizer, lightning_module, lr)
+        scheduler = self.create_scheduler_impl(optimizer, lightning_module)
 
         # If the scheduler is not a `LRSchedulerConfigType`, then make it one.
         if not isinstance(scheduler, Mapping):
@@ -87,9 +86,7 @@ class LRSchedulerConfigBase(C.Config, ABC):
 
         return scheduler
 
-    def compute_num_steps_per_epoch(
-        self, lightning_module: "LightningModuleBase"
-    ) -> int:
+    def compute_num_steps_per_epoch(self, lightning_module: LightningModuleBase) -> int:
         trainer = lightning_module.trainer
         # Use the Lightning trainer to convert the epoch-based values to step-based values
         _ = trainer.estimated_stepping_batches
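Net effect of this hunk: `create_scheduler_impl` no longer receives an explicit `lr`, and `create_scheduler` keeps `lr` only as a `Never`-typed placeholder so type checkers flag callers that still pass it. A hypothetical subclass under the new two-argument contract (the class itself is illustrative, not part of the package):

    from torch.optim.lr_scheduler import LambdaLR
    from typing_extensions import override


    class ConstantLRConfig(LRSchedulerConfigBase):  # hypothetical subclass
        @override
        def create_scheduler_impl(self, optimizer, lightning_module):
            # Base LRs now come from the optimizer's param groups rather
            # than from an explicit `lr` argument.
            return LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0)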
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import math
 import warnings
 from typing import Literal
@@ -18,21 +20,21 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
         optimizer: Optimizer,
         warmup_epochs: int,
         max_epochs: int,
-        warmup_start_lr: float = 0.0,
-        eta_min: float = 0.0,
+        warmup_start_lr_factor: float = 0.0,
+        eta_min_factor: float = 0.0,
         last_epoch: int = -1,
         should_restart: bool = True,
     ) -> None:
         self.warmup_epochs = warmup_epochs
         self.max_epochs = max_epochs
-        self.warmup_start_lr = warmup_start_lr
-        self.eta_min = eta_min
+        self.warmup_start_lr_factor = warmup_start_lr_factor
+        self.eta_min_factor = eta_min_factor
         self.should_restart = should_restart
 
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:  # pyright: ignore[reportIncompatibleMethodOverride]
+    def get_lr(self) -> list[float]:
         if not self._get_lr_called_within_step:
             warnings.warn(
                 "To get the last learning rate computed by the scheduler, "
@@ -41,25 +43,26 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
             )
 
         if self.last_epoch == 0:
-            return [self.warmup_start_lr] * len(self.base_lrs)
+            return [self.warmup_start_lr_factor * base_lr for base_lr in self.base_lrs]
         if self.last_epoch < self.warmup_epochs:
             return [
                 group["lr"]
-                + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
+                + (base_lr - self.warmup_start_lr_factor * base_lr)
+                / (self.warmup_epochs - 1)
                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
             ]
         if self.last_epoch == self.warmup_epochs:
             return self.base_lrs
 
         if not self.should_restart and self.last_epoch >= self.max_epochs:
-            return [self.eta_min] * len(self.base_lrs)
+            return [self.eta_min_factor * base_lr for base_lr in self.base_lrs]
 
         if (self.last_epoch - 1 - self.max_epochs) % (
             2 * (self.max_epochs - self.warmup_epochs)
         ) == 0:
             return [
                 group["lr"]
-                + (base_lr - self.eta_min)
+                + (base_lr - self.eta_min_factor * base_lr)
                 * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs)))
                 / 2
                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
@@ -82,9 +85,9 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
                     / (self.max_epochs - self.warmup_epochs)
                 )
             )
-            * (group["lr"] - self.eta_min)
-            + self.eta_min
-            for group in self.optimizer.param_groups
+            * (group["lr"] - self.eta_min_factor * base_lr)
+            + self.eta_min_factor * base_lr
+            for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
         ]
 
 
@@ -119,12 +122,10 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
         }
 
     @override
-    def create_scheduler_impl(self, optimizer, lightning_module, lr):
+    def create_scheduler_impl(self, optimizer, lightning_module):
         num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
         warmup_steps = self.warmup_duration.to_steps(num_steps_per_epoch).value
         max_steps = self.max_duration.to_steps(num_steps_per_epoch).value
-        warmup_start_lr = self.warmup_start_lr_factor * lr
-        min_lr = self.min_lr_factor * lr
 
         # Warmup and max steps should be at least 1.
         warmup_steps = max(warmup_steps, 1)
@@ -135,8 +136,8 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
             optimizer=optimizer,
             warmup_epochs=warmup_steps,
             max_epochs=max_steps,
-            warmup_start_lr=warmup_start_lr,
-            eta_min=min_lr,
+            warmup_start_lr_factor=self.warmup_start_lr_factor,
+            eta_min_factor=self.min_lr_factor,
             should_restart=self.annealing,
         )
         return scheduler
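Arithmetic check of the factor-based semantics introduced above: warmup now starts at `warmup_start_lr_factor * base_lr` and annealing bottoms out at `eta_min_factor * base_lr`, computed per param group, instead of at single absolute values shared across all groups. With illustrative numbers (not library defaults):

    base_lr = 1e-3  # one param group's base LR, taken from the optimizer
    warmup_start_lr_factor = 0.1
    eta_min_factor = 0.01

    print(warmup_start_lr_factor * base_lr)  # ~1e-4, the LR at warmup step 0
    print(eta_min_factor * base_lr)  # ~1e-5, the annealing floor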
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING, Literal, cast
 
 from lightning.pytorch.utilities.types import LRSchedulerConfigType
@@ -20,21 +22,21 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
     """Metric to monitor.
     If not provided, the primary metric of the runner will be used."""
 
-    patience: int = 10
+    patience: int
     r"""Number of epochs with no improvement after which learning rate will be reduced."""
 
-    factor: float = 0.1
+    factor: float
     r"""Factor by which the learning rate will be reduced. new_lr = lr * factor."""
 
+    cooldown: int = 0
+    r"""Number of epochs to wait before resuming normal operation after lr has been reduced."""
+
     min_lr: float | list[float] = 0.0
     r"""A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively."""
 
     eps: float = 1.0e-8
     r"""Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored."""
 
-    cooldown: int = 0
-    r"""Number of epochs to wait before resuming normal operation after lr has been reduced."""
-
     threshold: float = 1.0e-4
     r"""Threshold for measuring the new optimum, to only focus on significant changes."""
 
@@ -43,7 +45,7 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
 
     @override
     def create_scheduler_impl(
-        self, optimizer, lightning_module, lr
+        self, optimizer, lightning_module
     ) -> LRSchedulerConfigType:
         if (metric := self.metric) is None:
             lm_config = cast("BaseConfig", lightning_module.config)
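Since `patience` and `factor` lose their defaults here (previously 10 and 0.1), existing configs must now set both explicitly or validation will fail. A minimal sketch; any other fields the config may require are omitted:

    # Both fields are now required; these values simply reproduce the old defaults.
    config = ReduceLROnPlateauConfig(patience=10, factor=0.1)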
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from ._config import MetricConfig as MetricConfig
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import builtins
 from typing import Literal
 
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from .base import LightningModuleBase as LightningModuleBase
 from .config import BaseConfig as BaseConfig
 from .config import DirectoryConfig as DirectoryConfig
nshtrainer/model/base.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import inspect
 import logging
 from abc import ABC, abstractmethod
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import copy
 import logging
 import os
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from collections.abc import Callable, Iterable, Sequence
 from typing import Any, TypeAlias, cast, final
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections import deque
 from collections.abc import Callable, Generator
 from contextlib import contextmanager
nshtrainer/nn/__init__.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from .mlp import MLP as MLP
 from .mlp import MLPConfig as MLPConfig
 from .mlp import MLPConfigDict as MLPConfigDict
nshtrainer/nn/mlp.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import copy
 from collections.abc import Callable, Sequence
 from typing import Literal, Protocol, runtime_checkable
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Iterable, Mapping
 from typing import Generic, cast
 
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Iterable, Iterator
 from typing import Generic, TypeVar, overload
 
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from typing import Annotated, Literal
 
nshtrainer/optimizer.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import Annotated, Any, Literal, TypeAlias
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Annotated, TypeAlias
 
 import nshconfig as C
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from typing import Literal
 
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from typing import Any, Literal
 
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from typing import Literal
 
nshtrainer/runner.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import copy
 import logging
 from collections.abc import Callable, Iterable, Mapping, Sequence
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import ast
 import glob
@@ -1 +1,3 @@
+from __future__ import annotations
+
 from .trainer import Trainer as Trainer