nshtrainer 0.42.0__py3-none-any.whl → 0.44.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -0
- nshtrainer/_callback.py +2 -0
- nshtrainer/_checkpoint/loader.py +2 -0
- nshtrainer/_checkpoint/metadata.py +2 -0
- nshtrainer/_checkpoint/saver.py +2 -0
- nshtrainer/_directory.py +4 -2
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_hf_hub.py +2 -0
- nshtrainer/callbacks/__init__.py +45 -29
- nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
- nshtrainer/callbacks/actsave.py +2 -0
- nshtrainer/callbacks/base.py +2 -0
- nshtrainer/callbacks/checkpoint/__init__.py +6 -2
- nshtrainer/callbacks/checkpoint/_base.py +2 -0
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
- nshtrainer/callbacks/debug_flag.py +2 -0
- nshtrainer/callbacks/directory_setup.py +4 -2
- nshtrainer/callbacks/early_stopping.py +6 -4
- nshtrainer/callbacks/ema.py +5 -3
- nshtrainer/callbacks/finite_checks.py +3 -1
- nshtrainer/callbacks/gradient_skipping.py +6 -4
- nshtrainer/callbacks/interval.py +2 -0
- nshtrainer/callbacks/log_epoch.py +13 -1
- nshtrainer/callbacks/norm_logging.py +4 -2
- nshtrainer/callbacks/print_table.py +3 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
- nshtrainer/callbacks/shared_parameters.py +4 -2
- nshtrainer/callbacks/throughput_monitor.py +2 -0
- nshtrainer/callbacks/timer.py +5 -3
- nshtrainer/callbacks/wandb_upload_code.py +4 -2
- nshtrainer/callbacks/wandb_watch.py +4 -2
- nshtrainer/config/__init__.py +130 -90
- nshtrainer/config/_checkpoint/loader/__init__.py +10 -8
- nshtrainer/config/_checkpoint/metadata/__init__.py +6 -4
- nshtrainer/config/_directory/__init__.py +9 -3
- nshtrainer/config/_hf_hub/__init__.py +6 -4
- nshtrainer/config/callbacks/__init__.py +82 -42
- nshtrainer/config/callbacks/actsave/__init__.py +4 -2
- nshtrainer/config/callbacks/base/__init__.py +2 -0
- nshtrainer/config/callbacks/checkpoint/__init__.py +6 -4
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +6 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +2 -0
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +6 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +6 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +7 -5
- nshtrainer/config/callbacks/early_stopping/__init__.py +9 -7
- nshtrainer/config/callbacks/ema/__init__.py +5 -3
- nshtrainer/config/callbacks/finite_checks/__init__.py +7 -5
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +7 -5
- nshtrainer/config/callbacks/norm_logging/__init__.py +9 -5
- nshtrainer/config/callbacks/print_table/__init__.py +7 -5
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +7 -5
- nshtrainer/config/callbacks/shared_parameters/__init__.py +7 -5
- nshtrainer/config/callbacks/throughput_monitor/__init__.py +6 -4
- nshtrainer/config/callbacks/timer/__init__.py +9 -5
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +7 -5
- nshtrainer/config/callbacks/wandb_watch/__init__.py +9 -5
- nshtrainer/config/loggers/__init__.py +18 -10
- nshtrainer/config/loggers/_base/__init__.py +2 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -0
- nshtrainer/config/loggers/tensorboard/__init__.py +2 -0
- nshtrainer/config/loggers/wandb/__init__.py +18 -10
- nshtrainer/config/lr_scheduler/__init__.py +2 -0
- nshtrainer/config/lr_scheduler/_base/__init__.py +2 -0
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +2 -0
- nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -4
- nshtrainer/config/metrics/__init__.py +2 -0
- nshtrainer/config/metrics/_config/__init__.py +2 -0
- nshtrainer/config/model/__init__.py +8 -6
- nshtrainer/config/model/base/__init__.py +4 -2
- nshtrainer/config/model/config/__init__.py +8 -6
- nshtrainer/config/model/mixins/logger/__init__.py +2 -0
- nshtrainer/config/nn/__init__.py +16 -14
- nshtrainer/config/nn/mlp/__init__.py +2 -0
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -24
- nshtrainer/config/optimizer/__init__.py +2 -0
- nshtrainer/config/profiler/__init__.py +2 -0
- nshtrainer/config/profiler/_base/__init__.py +2 -0
- nshtrainer/config/profiler/advanced/__init__.py +6 -4
- nshtrainer/config/profiler/pytorch/__init__.py +6 -4
- nshtrainer/config/profiler/simple/__init__.py +6 -4
- nshtrainer/config/runner/__init__.py +2 -0
- nshtrainer/config/trainer/_config/__init__.py +43 -39
- nshtrainer/config/trainer/checkpoint_connector/__init__.py +2 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -18
- nshtrainer/config/util/config/__init__.py +2 -0
- nshtrainer/config/util/config/dtype/__init__.py +2 -0
- nshtrainer/config/util/config/duration/__init__.py +2 -0
- nshtrainer/data/__init__.py +2 -0
- nshtrainer/data/balanced_batch_sampler.py +2 -0
- nshtrainer/data/datamodule.py +2 -0
- nshtrainer/data/transform.py +2 -0
- nshtrainer/ll/__init__.py +2 -0
- nshtrainer/ll/_experimental.py +2 -0
- nshtrainer/ll/actsave.py +2 -0
- nshtrainer/ll/callbacks.py +2 -0
- nshtrainer/ll/config.py +2 -0
- nshtrainer/ll/data.py +2 -0
- nshtrainer/ll/log.py +2 -0
- nshtrainer/ll/lr_scheduler.py +2 -0
- nshtrainer/ll/model.py +2 -0
- nshtrainer/ll/nn.py +2 -0
- nshtrainer/ll/optimizer.py +2 -0
- nshtrainer/ll/runner.py +2 -0
- nshtrainer/ll/snapshot.py +2 -0
- nshtrainer/ll/snoop.py +2 -0
- nshtrainer/ll/trainer.py +2 -0
- nshtrainer/ll/typecheck.py +2 -0
- nshtrainer/ll/util.py +2 -0
- nshtrainer/loggers/__init__.py +2 -0
- nshtrainer/loggers/_base.py +2 -0
- nshtrainer/loggers/csv.py +2 -0
- nshtrainer/loggers/tensorboard.py +2 -0
- nshtrainer/loggers/wandb.py +6 -4
- nshtrainer/lr_scheduler/__init__.py +2 -0
- nshtrainer/lr_scheduler/_base.py +8 -11
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +18 -17
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +8 -6
- nshtrainer/metrics/__init__.py +2 -0
- nshtrainer/metrics/_config.py +2 -0
- nshtrainer/model/__init__.py +2 -0
- nshtrainer/model/base.py +2 -0
- nshtrainer/model/config.py +2 -0
- nshtrainer/model/mixins/callback.py +2 -0
- nshtrainer/model/mixins/logger.py +2 -0
- nshtrainer/nn/__init__.py +2 -0
- nshtrainer/nn/mlp.py +2 -0
- nshtrainer/nn/module_dict.py +2 -0
- nshtrainer/nn/module_list.py +2 -0
- nshtrainer/nn/nonlinearity.py +2 -0
- nshtrainer/optimizer.py +2 -0
- nshtrainer/profiler/__init__.py +2 -0
- nshtrainer/profiler/_base.py +2 -0
- nshtrainer/profiler/advanced.py +2 -0
- nshtrainer/profiler/pytorch.py +2 -0
- nshtrainer/profiler/simple.py +2 -0
- nshtrainer/runner.py +2 -0
- nshtrainer/scripts/find_packages.py +2 -0
- nshtrainer/trainer/__init__.py +2 -0
- nshtrainer/trainer/_config.py +16 -13
- nshtrainer/trainer/_runtime_callback.py +2 -0
- nshtrainer/trainer/checkpoint_connector.py +2 -0
- nshtrainer/trainer/signal_connector.py +2 -0
- nshtrainer/trainer/trainer.py +2 -0
- nshtrainer/util/_environment_info.py +2 -0
- nshtrainer/util/bf16.py +2 -0
- nshtrainer/util/config/__init__.py +2 -0
- nshtrainer/util/config/dtype.py +2 -0
- nshtrainer/util/config/duration.py +2 -0
- nshtrainer/util/environment.py +2 -0
- nshtrainer/util/path.py +2 -0
- nshtrainer/util/seed.py +2 -0
- nshtrainer/util/slurm.py +3 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +2 -0
- {nshtrainer-0.42.0.dist-info → nshtrainer-0.44.0.dist-info}/METADATA +1 -1
- nshtrainer-0.44.0.dist-info/RECORD +162 -0
- nshtrainer-0.42.0.dist-info/RECORD +0 -162
- {nshtrainer-0.42.0.dist-info → nshtrainer-0.44.0.dist-info}/WHEEL +0 -0
nshtrainer/config/util/_environment_info/__init__.py CHANGED

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 __codegen__ = True
 
 from typing import TYPE_CHECKING
@@ -43,30 +45,30 @@ else:
 
         if name in globals():
             return globals()[name]
-        if name == "EnvironmentPackageConfig":
-            return importlib.import_module(
-                "nshtrainer.util._environment_info"
-            ).EnvironmentPackageConfig
-        if name == "EnvironmentSnapshotConfig":
+        if name == "EnvironmentLinuxEnvironmentConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).EnvironmentSnapshotConfig
+            ).EnvironmentLinuxEnvironmentConfig
         if name == "EnvironmentLSFInformationConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
             ).EnvironmentLSFInformationConfig
-        if name == "…":
+        if name == "EnvironmentGPUConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).…
-        if name == "…":
+            ).EnvironmentGPUConfig
+        if name == "EnvironmentPackageConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).…
-        if name == "…":
+            ).EnvironmentPackageConfig
+        if name == "EnvironmentHardwareConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).…
+            ).EnvironmentHardwareConfig
+        if name == "EnvironmentSnapshotConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentSnapshotConfig
         if name == "EnvironmentClassInformationConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
@@ -75,18 +77,18 @@ else:
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
             ).GitRepositoryConfig
-        if name == "…":
+        if name == "EnvironmentConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).…
-        if name == "…":
+            ).EnvironmentConfig
+        if name == "EnvironmentCUDAConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).…
-        if name == "…":
+            ).EnvironmentCUDAConfig
+        if name == "EnvironmentSLURMInformationConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).…
+            ).EnvironmentSLURMInformationConfig
         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
 
 # Submodule exports
nshtrainer/data/__init__.py CHANGED
nshtrainer/data/datamodule.py CHANGED
nshtrainer/data/transform.py CHANGED
nshtrainer/ll/__init__.py CHANGED
nshtrainer/ll/_experimental.py CHANGED
nshtrainer/ll/actsave.py CHANGED
nshtrainer/ll/callbacks.py CHANGED
nshtrainer/ll/config.py CHANGED
nshtrainer/ll/data.py CHANGED
nshtrainer/ll/log.py CHANGED
nshtrainer/ll/lr_scheduler.py CHANGED
nshtrainer/ll/model.py CHANGED
nshtrainer/ll/nn.py CHANGED
nshtrainer/ll/optimizer.py CHANGED
nshtrainer/ll/runner.py CHANGED
nshtrainer/ll/snapshot.py CHANGED
nshtrainer/ll/snoop.py CHANGED
nshtrainer/ll/trainer.py CHANGED
nshtrainer/ll/typecheck.py CHANGED
nshtrainer/ll/util.py CHANGED
nshtrainer/loggers/__init__.py CHANGED
nshtrainer/loggers/_base.py CHANGED
nshtrainer/loggers/csv.py CHANGED
nshtrainer/loggers/wandb.py CHANGED

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import importlib.metadata
 import logging
 from typing import TYPE_CHECKING, Literal
@@ -8,8 +10,8 @@ from packaging import version
 from typing_extensions import assert_never, override
 
 from ..callbacks.base import CallbackConfigBase
-from ..callbacks.wandb_upload_code import …
-from ..callbacks.wandb_watch import …
+from ..callbacks.wandb_upload_code import WandbUploadCodeCallbackConfig
+from ..callbacks.wandb_watch import WandbWatchCallbackConfig
 from ._base import BaseLoggerConfig
 
 if TYPE_CHECKING:
@@ -92,10 +94,10 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
     - "none" or False: Do not log any checkpoints
     """
 
-    log_code: …
+    log_code: WandbUploadCodeCallbackConfig | None = WandbUploadCodeCallbackConfig()
     """WandB code upload configuration. Used to upload code to WandB."""
 
-    watch: …
+    watch: WandbWatchCallbackConfig | None = WandbWatchCallbackConfig()
     """WandB model watch configuration. Used to log model architecture, gradients, and parameters."""
 
     offline: bool = False
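With this change, `log_code` and `watch` default to fully populated callback configs rather than the previous bare annotations. A hedged usage sketch, assuming `WandbLoggerConfig` accepts these fields as keyword arguments and that any other required constructor fields are supplied elsewhere:

```python
# Sketch only: field names come from the diff above; any additional required
# arguments of WandbLoggerConfig are intentionally omitted here.
from nshtrainer.callbacks.wandb_watch import WandbWatchCallbackConfig
from nshtrainer.loggers.wandb import WandbLoggerConfig

logger_config = WandbLoggerConfig(
    log_code=None,                     # opt out of uploading code to WandB
    watch=WandbWatchCallbackConfig(),  # keep the default model-watch behavior
    offline=False,
)
```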
nshtrainer/lr_scheduler/_base.py CHANGED

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import math
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
@@ -9,7 +11,7 @@ from lightning.pytorch.utilities.types import (
     LRSchedulerTypeUnion,
 )
 from torch.optim import Optimizer
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import Never, NotRequired, TypedDict
 
 if TYPE_CHECKING:
     from ..model.base import LightningModuleBase
@@ -42,20 +44,17 @@ class LRSchedulerConfigBase(C.Config, ABC):
 
     @abstractmethod
     def create_scheduler_impl(
-        self,
-        optimizer: Optimizer,
-        lightning_module: "LightningModuleBase",
-        lr: float,
+        self, optimizer: Optimizer, lightning_module: LightningModuleBase
     ) -> LRSchedulerTypeUnion | LRSchedulerConfigType: ...
 
     def create_scheduler(
         self,
         optimizer: Optimizer,
-        lightning_module: …,
-        lr: …,
+        lightning_module: LightningModuleBase,
+        lr: Never,  # Backward compatibility, should be removed in the future
     ) -> LRSchedulerConfigType:
         # Create the scheduler.
-        scheduler = self.create_scheduler_impl(optimizer, lightning_module…
+        scheduler = self.create_scheduler_impl(optimizer, lightning_module)
 
         # If the scheduler is not a `LRSchedulerConfigType`, then make it one.
         if not isinstance(scheduler, Mapping):
@@ -87,9 +86,7 @@ class LRSchedulerConfigBase(C.Config, ABC):
 
         return scheduler
 
-    def compute_num_steps_per_epoch(
-        self, lightning_module: "LightningModuleBase"
-    ) -> int:
+    def compute_num_steps_per_epoch(self, lightning_module: LightningModuleBase) -> int:
         trainer = lightning_module.trainer
         # Use the Lightning trainer to convert the epoch-based values to step-based values
         _ = trainer.estimated_stepping_batches
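The practical effect for downstream code is that `create_scheduler_impl` no longer receives an `lr` argument; subclasses now read the base learning rate from the optimizer's param groups, and the leftover `lr` parameter on `create_scheduler` is typed `Never` purely for backward compatibility. A sketch of a subclass written against the new signature; `StepLRConfig` and its fields are hypothetical and not part of nshtrainer:

```python
# Hypothetical subclass illustrating the new create_scheduler_impl signature.
from torch.optim import Optimizer
from torch.optim.lr_scheduler import StepLR

from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase


class StepLRConfig(LRSchedulerConfigBase):
    step_size: int = 10
    gamma: float = 0.1

    def create_scheduler_impl(self, optimizer: Optimizer, lightning_module):
        # The base LR is taken from the optimizer's param groups; it is no
        # longer passed in as a separate `lr` argument.
        return StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)
```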
nshtrainer/lr_scheduler/linear_warmup_cosine.py CHANGED

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import math
 import warnings
 from typing import Literal
@@ -18,21 +20,21 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
         optimizer: Optimizer,
         warmup_epochs: int,
         max_epochs: int,
-        …
-        …
+        warmup_start_lr_factor: float = 0.0,
+        eta_min_factor: float = 0.0,
         last_epoch: int = -1,
         should_restart: bool = True,
     ) -> None:
         self.warmup_epochs = warmup_epochs
         self.max_epochs = max_epochs
-        self.…
-        self.…
+        self.warmup_start_lr_factor = warmup_start_lr_factor
+        self.eta_min_factor = eta_min_factor
         self.should_restart = should_restart
 
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
+    def get_lr(self) -> list[float]:
         if not self._get_lr_called_within_step:
             warnings.warn(
                 "To get the last learning rate computed by the scheduler, "
@@ -41,25 +43,26 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
             )
 
         if self.last_epoch == 0:
-            return [self.…]
+            return [self.warmup_start_lr_factor * base_lr for base_lr in self.base_lrs]
         if self.last_epoch < self.warmup_epochs:
             return [
                 group["lr"]
-                + (base_lr - self.…
+                + (base_lr - self.warmup_start_lr_factor * base_lr)
+                / (self.warmup_epochs - 1)
                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
             ]
         if self.last_epoch == self.warmup_epochs:
             return self.base_lrs
 
         if not self.should_restart and self.last_epoch >= self.max_epochs:
-            return [self.…]
+            return [self.eta_min_factor * base_lr for base_lr in self.base_lrs]
 
         if (self.last_epoch - 1 - self.max_epochs) % (
             2 * (self.max_epochs - self.warmup_epochs)
         ) == 0:
             return [
                 group["lr"]
-                + (base_lr - self.…
+                + (base_lr - self.eta_min_factor * base_lr)
                 * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs)))
                 / 2
                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
@@ -82,9 +85,9 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
                     / (self.max_epochs - self.warmup_epochs)
                 )
             )
-            * (group["lr"] - self.…
-            + self.…
-            for group in self.optimizer.param_groups
+            * (group["lr"] - self.eta_min_factor * base_lr)
+            + self.eta_min_factor * base_lr
+            for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
         ]
 
 
@@ -119,12 +122,10 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
     }
 
     @override
-    def create_scheduler_impl(self, optimizer, lightning_module…
+    def create_scheduler_impl(self, optimizer, lightning_module):
         num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
         warmup_steps = self.warmup_duration.to_steps(num_steps_per_epoch).value
         max_steps = self.max_duration.to_steps(num_steps_per_epoch).value
-        warmup_start_lr = self.warmup_start_lr_factor * lr
-        min_lr = self.min_lr_factor * lr
 
         # Warmup and max steps should be at least 1.
         warmup_steps = max(warmup_steps, 1)
@@ -135,8 +136,8 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
             optimizer=optimizer,
             warmup_epochs=warmup_steps,
             max_epochs=max_steps,
-            …
-            …
+            warmup_start_lr_factor=self.warmup_start_lr_factor,
+            eta_min_factor=self.min_lr_factor,
             should_restart=self.annealing,
         )
         return scheduler
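In other words, the scheduler now expresses its warmup start and annealing floor as factors of each param group's base learning rate, replacing the absolute values it took before (the old parameter names are truncated in the diff above). A hedged example of constructing it directly with the new parameters; everything beyond the parameter names shown in the diff is assumed:

```python
# Sketch: constructing the scheduler with the new factor-based parameters.
import torch
from nshtrainer.lr_scheduler.linear_warmup_cosine import LinearWarmupCosineAnnealingLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

scheduler = LinearWarmupCosineAnnealingLR(
    optimizer,
    warmup_epochs=5,
    max_epochs=100,
    warmup_start_lr_factor=0.1,  # warmup starts at 0.1 * base_lr = 1e-3
    eta_min_factor=0.01,         # anneals toward 0.01 * base_lr = 1e-4
)
```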
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py CHANGED

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING, Literal, cast
 
 from lightning.pytorch.utilities.types import LRSchedulerConfigType
@@ -20,21 +22,21 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
     """Metric to monitor.
     If not provided, the primary metric of the runner will be used."""
 
-    patience: int
+    patience: int
     r"""Number of epochs with no improvement after which learning rate will be reduced."""
 
-    factor: float
+    factor: float
     r"""Factor by which the learning rate will be reduced. new_lr = lr * factor."""
 
+    cooldown: int = 0
+    r"""Number of epochs to wait before resuming normal operation after lr has been reduced."""
+
     min_lr: float | list[float] = 0.0
     r"""A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively."""
 
     eps: float = 1.0e-8
     r"""Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored."""
 
-    cooldown: int = 0
-    r"""Number of epochs to wait before resuming normal operation after lr has been reduced."""
-
     threshold: float = 1.0e-4
     r"""Threshold for measuring the new optimum, to only focus on significant changes."""
 
@@ -43,7 +45,7 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
 
     @override
     def create_scheduler_impl(
-        self, optimizer, lightning_module…
+        self, optimizer, lightning_module
     ) -> LRSchedulerConfigType:
         if (metric := self.metric) is None:
             lm_config = cast("BaseConfig", lightning_module.config)
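For reference, a hedged sketch of configuring the scheduler with the fields shown above; whether `patience` and `factor` carry defaults in 0.44.0 is not visible in this diff, so they are passed explicitly:

```python
# Sketch: field names follow the diff above; keyword construction is assumed.
from nshtrainer.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateauConfig

lr_scheduler = ReduceLROnPlateauConfig(
    patience=5,   # epochs with no improvement before reducing the LR
    factor=0.5,   # new_lr = lr * factor
    cooldown=2,   # epochs to wait after a reduction
    min_lr=1e-6,
)
```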
nshtrainer/metrics/__init__.py CHANGED
nshtrainer/metrics/_config.py CHANGED
nshtrainer/model/__init__.py CHANGED
nshtrainer/model/base.py CHANGED
nshtrainer/model/config.py CHANGED
nshtrainer/nn/__init__.py CHANGED
nshtrainer/nn/mlp.py CHANGED
nshtrainer/nn/module_dict.py CHANGED
nshtrainer/nn/module_list.py CHANGED
nshtrainer/nn/nonlinearity.py CHANGED
nshtrainer/optimizer.py CHANGED
nshtrainer/profiler/__init__.py CHANGED
nshtrainer/profiler/_base.py CHANGED
nshtrainer/profiler/advanced.py CHANGED
nshtrainer/profiler/pytorch.py CHANGED
nshtrainer/profiler/simple.py CHANGED
nshtrainer/runner.py CHANGED
nshtrainer/trainer/__init__.py CHANGED