nshtrainer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0

nshtrainer/data/balanced_batch_sampler.py
@@ -0,0 +1,132 @@
+import heapq
+from functools import cached_property
+from logging import getLogger
+from typing import Any, Protocol, runtime_checkable
+
+import numpy as np
+import torch
+import torch.distributed
+from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
+from torch.utils.data import BatchSampler, Dataset, DistributedSampler
+from typing_extensions import override
+
+log = getLogger(__name__)
+
+
+def _all_gather(tensor: torch.Tensor, device: torch.device | None = None):
+    gathered = [
+        torch.zeros_like(tensor, device=device)
+        for _ in range(torch.distributed.get_world_size())
+    ]
+    _ = torch.distributed.all_gather(gathered, tensor)
+    return gathered
+
+
+# @numba.njit
+def _balanced_partition(sizes: np.ndarray, num_parts: int):
+    """
+    Greedily partition the given set by always inserting
+    the largest element into the smallest partition.
+    """
+    sort_idx = np.argsort(-sizes)  # Sort in descending order
+    heap = []
+    for idx in sort_idx[:num_parts]:
+        heap.append((sizes[idx], [idx]))
+    heapq.heapify(heap)
+    for idx in sort_idx[num_parts:]:
+        smallest_part = heapq.heappop(heap)
+        new_size = smallest_part[0] + sizes[idx]
+        new_idx = smallest_part[1] + [idx]
+        heapq.heappush(heap, (new_size, new_idx))
+    idx_balanced = [part[1] for part in heap]
+    return idx_balanced
+
+
+@runtime_checkable
+class DatasetWithSizes(Protocol):
+    def data_sizes(self, indices: list[int]) -> np.ndarray: ...
+
+
+class BalancedBatchSampler(BatchSampler):
+    @staticmethod
+    def _ensure_supported(dataset: Any):
+        if not isinstance(dataset, Dataset):
+            raise ValueError(
+                "BalancedBatchSampler requires a dataset that implements `__getitem__`"
+            )
+
+        if not isinstance(dataset, DatasetWithSizes):
+            raise ValueError(
+                "BalancedBatchSampler requires a dataset that implements `data_sizes`"
+            )
+
+        log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
+        return dataset
+
+    @staticmethod
+    def _unwrap_dataset(dataset: Dataset) -> Dataset:
+        if isinstance(dataset, _DatasetSamplerWrapper):
+            if (data_source := getattr(dataset._sampler, "data_source", None)) is None:
+                raise ValueError("Could not unwrap dataset from _DatasetSamplerWrapper")
+            return data_source
+        return dataset
+
+    @property
+    def distributed_sampler(self):
+        if not isinstance(self.sampler, DistributedSampler):
+            raise ValueError(
+                f"Sampler must be a DistributedSampler, got {type(self.sampler)}"
+            )
+        return self.sampler
+
+    @cached_property
+    def dataset(self):
+        return self._ensure_supported(
+            self._unwrap_dataset(self.distributed_sampler.dataset)
+        )
+
+    def __init__(
+        self,
+        sampler: DistributedSampler,
+        *,
+        batch_size: int,
+        device: torch.device,
+        drop_last: bool = False,
+    ):
+        super().__init__(sampler, batch_size, drop_last=drop_last)
+
+        self._device = device
+
+        log.info(
+            f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}"
+        )
+
+    @staticmethod
+    def _dist_enabled():
+        return torch.distributed.is_available() and torch.distributed.is_initialized()
+
+    @override
+    def __iter__(self):
+        if not self._dist_enabled():
+            yield from super().__iter__()
+            return
+
+        for batch_idx in super().__iter__():
+            sizes = self.dataset.data_sizes(batch_idx)
+            idx_sizes = torch.stack(
+                [
+                    torch.tensor(batch_idx, device=self._device),
+                    torch.tensor(sizes, device=self._device),
+                ]
+            )
+            idx_sizes_all = _all_gather(idx_sizes, device=self._device)
+            idx_sizes_all = torch.cat(idx_sizes_all, dim=-1).cpu()
+            idx_all = idx_sizes_all[0]
+            sizes_all = idx_sizes_all[1]
+
+            local_idx_balanced = _balanced_partition(
+                sizes_all.numpy(), num_parts=self.distributed_sampler.num_replicas
+            )
+            # Since DistributedSampler pads the last batch
+            # this should always have an entry for each replica.
+            yield idx_all[local_idx_balanced[self.distributed_sampler.rank]].tolist()
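For illustration, a minimal sketch of wiring this sampler up outside of a trainer (the toy dataset and its per-sample sizes are invented for this example, and the API is assumed to be exactly the one shown in the hunk above). Without an initialized torch.distributed process group, iteration falls back to plain BatchSampler behavior:

import numpy as np
import torch
from torch.utils.data import Dataset, DistributedSampler

class ToyDataset(Dataset):
    """Hypothetical dataset exposing the `data_sizes` protocol used for balancing."""

    def __init__(self, sizes: list[int]):
        self.sizes = sizes

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx: int):
        return torch.zeros(self.sizes[idx])

    def data_sizes(self, indices: list[int]) -> np.ndarray:
        # Per-sample cost that BalancedBatchSampler balances across replicas.
        return np.array([self.sizes[i] for i in indices])

dataset = ToyDataset(sizes=[3, 8, 1, 5, 2, 7, 4, 6])
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
batch_sampler = BalancedBatchSampler(sampler, batch_size=4, device=torch.device("cpu"))
batches = list(batch_sampler)  # [[0, 1, 2, 3], [4, 5, 6, 7]] when not running distributed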

nshtrainer/data/transform.py
@@ -0,0 +1,67 @@
+import copy
+from collections.abc import Callable
+from typing import Any, cast
+
+from typing_extensions import TypeVar
+
+TDataset = TypeVar("TDataset", infer_variance=True)
+
+
+def transform(
+    dataset: TDataset,
+    transform: Callable[[Any], Any],
+    *,
+    deepcopy: bool = False,
+) -> TDataset:
+    """
+    Wraps a dataset with a transform function.
+
+    Args:
+        dataset: The dataset to wrap.
+        transform: The transform function to apply to each item.
+        deepcopy: Whether to deep copy each item before applying the transform.
+    """
+
+    import wrapt
+
+    class _TransformedDataset(wrapt.ObjectProxy):
+        def __getitem__(self, idx):
+            nonlocal deepcopy, transform
+
+            data = self.__wrapped__.__getitem__(idx)
+            if deepcopy:
+                data = copy.deepcopy(data)
+            data = transform(data)
+            return data
+
+    return cast(TDataset, _TransformedDataset(dataset))
+
+
+def transform_with_index(
+    dataset: TDataset,
+    transform: Callable[[Any, int], Any],
+    *,
+    deepcopy: bool = False,
+) -> TDataset:
+    """
+    Wraps a dataset with a transform function that takes an index, in addition to the item.
+
+    Args:
+        dataset: The dataset to wrap.
+        transform: The transform function to apply to each item.
+        deepcopy: Whether to deep copy each item before applying the transform.
+    """
+
+    import wrapt
+
+    class _TransformedWithIndexDataset(wrapt.ObjectProxy):
+        def __getitem__(self, idx: int):
+            nonlocal deepcopy, transform
+
+            data = self.__wrapped__.__getitem__(idx)
+            if deepcopy:
+                data = copy.deepcopy(data)
+            data = transform(data, idx)
+            return data
+
+    return cast(TDataset, _TransformedWithIndexDataset(dataset))
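A short usage sketch for the `transform` helper above (assuming the `wrapt` package is installed; the dataset and lambda are illustrative). Because the wrapper is a `wrapt.ObjectProxy`, the returned object still exposes the wrapped dataset's attributes and methods:

import torch
from torch.utils.data import TensorDataset

base = TensorDataset(torch.arange(4, dtype=torch.float32))
doubled = transform(base, lambda item: tuple(t * 2 for t in item))

(value,) = doubled[1]  # the transform runs on every __getitem__ access
assert value.item() == 2.0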

nshtrainer/lr_scheduler/__init__.py
@@ -0,0 +1,18 @@
+from typing import Annotated, TypeAlias
+
+from ..config import Field
+from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
+from ._base import LRSchedulerMetadata as LRSchedulerMetadata
+from .linear_warmup_cosine import (
+    LinearWarmupCosineAnnealingLR as LinearWarmupCosineAnnealingLR,
+)
+from .linear_warmup_cosine import (
+    LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
+)
+from .reduce_lr_on_plateau import ReduceLROnPlateau as ReduceLROnPlateau
+from .reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
+
+LRSchedulerConfig: TypeAlias = Annotated[
+    LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
+    Field(discriminator="name"),
+]
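For context, the `Field(discriminator="name")` annotation is what lets a single tagged dictionary select the right scheduler config class. A hedged sketch of that behavior, assuming `TypedConfig` (defined in nshtrainer/config.py) is pydantic-based, which this example does not verify:

from pydantic import TypeAdapter

adapter = TypeAdapter(LRSchedulerConfig)
config = adapter.validate_python({"name": "reduce_lr_on_plateau", "patience": 5})
assert isinstance(config, ReduceLROnPlateauConfig)  # selected via the `name` discriminator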

nshtrainer/lr_scheduler/_base.py
@@ -0,0 +1,101 @@
+import math
+from abc import ABC, abstractmethod
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, ClassVar, Literal, TypeAlias
+
+from lightning.pytorch.utilities.types import (
+    LRSchedulerConfigType,
+    LRSchedulerTypeUnion,
+)
+from torch.optim import Optimizer
+from typing_extensions import NotRequired, TypedDict
+
+from ..config import TypedConfig
+
+if TYPE_CHECKING:
+    from ..model.base import LightningModuleBase
+
+
+class LRSchedulerMetadata(TypedDict):
+    interval: Literal["epoch", "step"]
+    """Interval to update the learning rate."""
+
+    name: NotRequired[str | None]
+    """Name of the learning rate scheduler. Default is `None`."""
+
+    frequency: NotRequired[int]
+    """Frequency to update the learning rate. Default is `1`."""
+
+    reduce_on_plateau: NotRequired[bool]
+    """Whether to reduce the learning rate on plateau. Default is `False`."""
+
+    monitor: NotRequired[str | None]
+    """Value to monitor for reducing the learning rate on plateau. Required if `reduce_on_plateau` is `True`.
+    Default is `None`."""
+
+    strict: NotRequired[bool]
+    """Whether to enforce that the monitor exists for reducing the learning rate on plateau. Default is `True`."""
+
+
+class LRSchedulerConfigBase(TypedConfig, ABC):
+    Metadata: ClassVar[TypeAlias] = LRSchedulerMetadata
+
+    @abstractmethod
+    def metadata(self) -> LRSchedulerMetadata: ...
+
+    @abstractmethod
+    def create_scheduler_impl(
+        self,
+        optimizer: Optimizer,
+        lightning_module: "LightningModuleBase",
+        lr: float,
+    ) -> LRSchedulerTypeUnion | LRSchedulerConfigType: ...
+
+    def create_scheduler(
+        self,
+        optimizer: Optimizer,
+        lightning_module: "LightningModuleBase",
+        lr: float,
+    ) -> LRSchedulerConfigType:
+        # Create the scheduler.
+        scheduler = self.create_scheduler_impl(optimizer, lightning_module, lr)
+
+        # If the scheduler is not a `LRSchedulerConfigType`, then make it one.
+        if not isinstance(scheduler, Mapping):
+            scheduler = LRSchedulerConfigType(scheduler=scheduler)
+
+        # Update the scheduler config with the metadata (if not already present).
+        metadata = self.metadata()
+        # - `interval` has to be present.
+        if scheduler.get("interval") is None:
+            scheduler["interval"] = metadata["interval"]
+        # - `name`
+        if scheduler.get("name") is None and "name" in metadata:
+            scheduler["name"] = metadata["name"]
+        # - `frequency`
+        if scheduler.get("frequency") is None and "frequency" in metadata:
+            scheduler["frequency"] = metadata["frequency"]
+        # - `reduce_on_plateau`
+        if (
+            scheduler.get("reduce_on_plateau") is None
+            and "reduce_on_plateau" in metadata
+        ):
+            scheduler["reduce_on_plateau"] = metadata["reduce_on_plateau"]
+        # - `monitor`
+        if scheduler.get("monitor") is None and "monitor" in metadata:
+            scheduler["monitor"] = metadata["monitor"]
+        # - `strict`
+        if scheduler.get("strict") is None and "strict" in metadata:
+            scheduler["strict"] = metadata["strict"]  # type: ignore
+
+        return scheduler
+
+    def compute_num_steps_per_epoch(
+        self, lightning_module: "LightningModuleBase"
+    ) -> int:
+        trainer = lightning_module.trainer
+        # Use the Lightning trainer to convert the epoch-based values to step-based values
+        _ = trainer.estimated_stepping_batches
+        # ^ This is a hack to trigger the computation of the estimated stepping batches
+        # and make sure that the `trainer.num_training_batches` attribute is set.
+        return math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
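The `compute_num_steps_per_epoch` helper above is what allows step-interval schedulers to be configured in epochs. A worked example with made-up trainer numbers:

import math

num_training_batches = 1000   # hypothetical batches per epoch reported by the trainer
accumulate_grad_batches = 4   # hypothetical gradient accumulation factor
steps_per_epoch = math.ceil(num_training_batches / accumulate_grad_batches)
assert steps_per_epoch == 250  # each optimizer step consumes 4 batches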

nshtrainer/lr_scheduler/linear_warmup_cosine.py
@@ -0,0 +1,138 @@
+import math
+import warnings
+from typing import Literal
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from typing_extensions import override
+
+from ..config import Field
+from ._base import LRSchedulerConfigBase
+
+
+class LinearWarmupCosineAnnealingLR(LRScheduler):
+    _get_lr_called_within_step: bool
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        warmup_epochs: int,
+        max_epochs: int,
+        warmup_start_lr: float = 0.0,
+        eta_min: float = 0.0,
+        last_epoch: int = -1,
+        should_restart: bool = True,
+    ) -> None:
+        self.warmup_epochs = warmup_epochs
+        self.max_epochs = max_epochs
+        self.warmup_start_lr = warmup_start_lr
+        self.eta_min = eta_min
+        self.should_restart = should_restart
+
+        super().__init__(optimizer, last_epoch)
+
+    @override
+    def get_lr(self) -> list[float]:  # pyright: ignore[reportIncompatibleMethodOverride]
+        if not self._get_lr_called_within_step:
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, "
+                "please use `get_last_lr()`.",
+                UserWarning,
+            )
+
+        if self.last_epoch == 0:
+            return [self.warmup_start_lr] * len(self.base_lrs)
+        if self.last_epoch < self.warmup_epochs:
+            return [
+                group["lr"]
+                + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
+                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+            ]
+        if self.last_epoch == self.warmup_epochs:
+            return self.base_lrs
+
+        if not self.should_restart and self.last_epoch >= self.max_epochs:
+            return [self.eta_min] * len(self.base_lrs)
+
+        if (self.last_epoch - 1 - self.max_epochs) % (
+            2 * (self.max_epochs - self.warmup_epochs)
+        ) == 0:
+            return [
+                group["lr"]
+                + (base_lr - self.eta_min)
+                * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs)))
+                / 2
+                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+            ]
+
+        return [
+            (
+                1
+                + math.cos(
+                    math.pi
+                    * (self.last_epoch - self.warmup_epochs)
+                    / (self.max_epochs - self.warmup_epochs)
+                )
+            )
+            / (
+                1
+                + math.cos(
+                    math.pi
+                    * (self.last_epoch - self.warmup_epochs - 1)
+                    / (self.max_epochs - self.warmup_epochs)
+                )
+            )
+            * (group["lr"] - self.eta_min)
+            + self.eta_min
+            for group in self.optimizer.param_groups
+        ]
+
+
+class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
+    name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
+
+    warmup_epochs: int = Field(ge=0)
+    r"""The number of epochs for the linear warmup phase.
+    The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this number of epochs."""
+
+    max_epochs: int = Field(gt=0)
+    r"""The total number of epochs.
+    The learning rate is decayed to `min_lr` over this number of epochs."""
+
+    warmup_start_lr_factor: float = 0.0
+    r"""The initial learning rate for the linear warmup phase, as a factor of the initial learning rate.
+    The learning rate is linearly increased from this value to the initial learning rate over `warmup_epochs` epochs."""
+
+    min_lr_factor: float = 0.0
+    r"""The minimum learning rate, as a factor of the initial learning rate.
+    The learning rate is decayed to this value over `max_epochs` epochs."""
+
+    annealing: bool = False
+    r"""Whether to restart the learning rate schedule after `max_epochs` epochs.
+    If `False`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be set to `min_lr` for all subsequent epochs.
+    If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""

+    @override
+    def metadata(self) -> LRSchedulerConfigBase.Metadata:
+        return {
+            "interval": "step",
+        }
+
+    @override
+    def create_scheduler_impl(self, optimizer, lightning_module, lr):
+        num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
+        warmup_steps = self.warmup_epochs * num_steps_per_epoch
+        max_steps = self.max_epochs * num_steps_per_epoch
+        warmup_start_lr = self.warmup_start_lr_factor * lr
+        min_lr = self.min_lr_factor * lr
+
+        # Create the scheduler
+        scheduler = LinearWarmupCosineAnnealingLR(
+            optimizer=optimizer,
+            warmup_epochs=warmup_steps,
+            max_epochs=max_steps,
+            warmup_start_lr=warmup_start_lr,
+            eta_min=min_lr,
+            should_restart=self.annealing,
+        )
+        return scheduler
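A minimal sketch of driving `LinearWarmupCosineAnnealingLR` by hand, outside Lightning (the model, optimizer, and hyperparameters are illustrative). Note that the config above converts its epoch counts into optimizer steps before constructing the scheduler, so `warmup_epochs` and `max_epochs` here are effectively counts of `scheduler.step()` calls:

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer, warmup_epochs=5, max_epochs=50, warmup_start_lr=1e-4, eta_min=1e-5
)
for _ in range(50):
    optimizer.step()
    scheduler.step()  # linear warmup over the first 5 steps, then cosine decay toward eta_min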

nshtrainer/lr_scheduler/reduce_lr_on_plateau.py
@@ -0,0 +1,73 @@
+from typing import TYPE_CHECKING, Literal, cast
+
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from typing_extensions import override
+
+from ._base import LRSchedulerMetadata
+
+from ..model.config import MetricConfig
+from ._base import LRSchedulerConfigBase
+
+if TYPE_CHECKING:
+    from ..model.base import BaseConfig
+
+
+class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
+    """Reduce learning rate when a metric has stopped improving."""
+
+    name: Literal["reduce_lr_on_plateau"] = "reduce_lr_on_plateau"
+
+    metric: MetricConfig | None = None
+    """Metric to monitor.
+    If not provided, the primary metric of the runner will be used."""
+
+    patience: int = 10
+    r"""Number of epochs with no improvement after which learning rate will be reduced."""
+
+    factor: float = 0.1
+    r"""Factor by which the learning rate will be reduced. new_lr = lr * factor."""
+
+    min_lr: float | list[float] = 0.0
+    r"""A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively."""
+
+    eps: float = 1.0e-8
+    r"""Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored."""
+
+    cooldown: int = 0
+    r"""Number of epochs to wait before resuming normal operation after lr has been reduced."""
+
+    threshold: float = 1.0e-4
+    r"""Threshold for measuring the new optimum, to only focus on significant changes."""
+
+    threshold_mode: str = "rel"
+    r"""One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * (1 + threshold) in `max` mode or best * (1 - threshold) in `min` mode. In `abs` mode, dynamic_threshold = best + threshold in `max` mode or best - threshold in `min` mode. Default: `rel`."""
+
+    @override
+    def create_scheduler_impl(self, optimizer, lightning_module, lr):
+        if (metric := self.metric) is None:
+            lm_config = cast("BaseConfig", lightning_module.config)
+            assert (
+                metric := lm_config.primary_metric
+            ) is not None, "Primary metric must be provided if metric is not specified."
+
+        lr_scheduler = ReduceLROnPlateau(
+            optimizer,
+            mode=metric.mode,
+            factor=self.factor,
+            patience=self.patience,
+            threshold=self.threshold,
+            threshold_mode=self.threshold_mode,
+            cooldown=self.cooldown,
+            min_lr=self.min_lr,
+            eps=self.eps,
+        )
+        return {
+            "scheduler": lr_scheduler,
+            "monitor": metric.validation_monitor,
+        }
+
+    @override
+    def metadata(self) -> LRSchedulerMetadata:
+        return {
+            "interval": "epoch",
+        }
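For reference, the config above wraps torch's `ReduceLROnPlateau`, taking `mode` from the monitored metric and wiring `monitor` to the metric's validation key. A small sketch of the underlying torch scheduler driven manually (the model, optimizer, and losses are made up):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
plateau = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=2)
for val_loss in [1.0, 0.9, 0.9, 0.9, 0.9]:
    plateau.step(val_loss)  # after `patience` epochs without improvement, lr is scaled by `factor`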

nshtrainer/model/__init__.py
@@ -0,0 +1,44 @@
+from typing_extensions import TypeAlias
+
+from .base import Base as Base
+from .base import LightningDataModuleBase as LightningDataModuleBase
+from .base import LightningModuleBase as LightningModuleBase
+from .config import ActSaveConfig as ActSaveConfig
+from .config import BaseConfig as BaseConfig
+from .config import BaseLoggerConfig as BaseLoggerConfig
+from .config import BaseProfilerConfig as BaseProfilerConfig
+from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
+from .config import CheckpointSavingConfig as CheckpointSavingConfig
+from .config import DirectoryConfig as DirectoryConfig
+from .config import EarlyStoppingConfig as EarlyStoppingConfig
+from .config import (
+    EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
+)
+from .config import EnvironmentConfig as EnvironmentConfig
+from .config import (
+    EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
+)
+from .config import (
+    EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
+)
+from .config import GradientClippingConfig as GradientClippingConfig
+from .config import (
+    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
+)
+from .config import LoggingConfig as LoggingConfig
+from .config import MetricConfig as MetricConfig
+from .config import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
+from .config import (
+    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+)
+from .config import OptimizationConfig as OptimizationConfig
+from .config import PrimaryMetricConfig as PrimaryMetricConfig
+from .config import PythonLogging as PythonLogging
+from .config import ReproducibilityConfig as ReproducibilityConfig
+from .config import RunnerConfig as RunnerConfig
+from .config import SanityCheckingConfig as SanityCheckingConfig
+from .config import SeedConfig as SeedConfig
+from .config import TrainerConfig as TrainerConfig
+from .config import WandbWatchConfig as WandbWatchConfig
+
+ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]