nshtrainer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/data/balanced_batch_sampler.py
@@ -0,0 +1,132 @@
+ import heapq
+ from functools import cached_property
+ from logging import getLogger
+ from typing import Any, Protocol, runtime_checkable
+
+ import numpy as np
+ import torch
+ import torch.distributed
+ from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
+ from torch.utils.data import BatchSampler, Dataset, DistributedSampler
+ from typing_extensions import override
+
+ log = getLogger(__name__)
+
+
+ def _all_gather(tensor: torch.Tensor, device: torch.device | None = None):
+     gathered = [
+         torch.zeros_like(tensor, device=device)
+         for _ in range(torch.distributed.get_world_size())
+     ]
+     _ = torch.distributed.all_gather(gathered, tensor)
+     return gathered
+
+
+ # @numba.njit
+ def _balanced_partition(sizes: np.ndarray, num_parts: int):
+     """
+     Greedily partition the given set by always inserting
+     the largest element into the smallest partition.
+     """
+     sort_idx = np.argsort(-sizes)  # Sort in descending order
+     heap = []
+     for idx in sort_idx[:num_parts]:
+         heap.append((sizes[idx], [idx]))
+     heapq.heapify(heap)
+     for idx in sort_idx[num_parts:]:
+         smallest_part = heapq.heappop(heap)
+         new_size = smallest_part[0] + sizes[idx]
+         new_idx = smallest_part[1] + [idx]
+         heapq.heappush(heap, (new_size, new_idx))
+     idx_balanced = [part[1] for part in heap]
+     return idx_balanced
+
+
+ @runtime_checkable
+ class DatasetWithSizes(Protocol):
+     def data_sizes(self, indices: list[int]) -> np.ndarray: ...
+
+
+ class BalancedBatchSampler(BatchSampler):
+     @staticmethod
+     def _ensure_supported(dataset: Any):
+         if not isinstance(dataset, Dataset):
+             raise ValueError(
+                 "BalancedBatchSampler requires a dataset that implements `__getitem__`"
+             )
+
+         if not isinstance(dataset, DatasetWithSizes):
+             raise ValueError(
+                 "BalancedBatchSampler requires a dataset that implements `data_sizes`"
+             )
+
+         log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
+         return dataset
+
+     @staticmethod
+     def _unwrap_dataset(dataset: Dataset) -> Dataset:
+         if isinstance(dataset, _DatasetSamplerWrapper):
+             if (data_source := getattr(dataset._sampler, "data_source", None)) is None:
+                 raise ValueError("Could not unwrap dataset from _DatasetSamplerWrapper")
+             return data_source
+         return dataset
+
+     @property
+     def distributed_sampler(self):
+         if not isinstance(self.sampler, DistributedSampler):
+             raise ValueError(
+                 f"Sampler must be a DistributedSampler, got {type(self.sampler)}"
+             )
+         return self.sampler
+
+     @cached_property
+     def dataset(self):
+         return self._ensure_supported(
+             self._unwrap_dataset(self.distributed_sampler.dataset)
+         )
+
+     def __init__(
+         self,
+         sampler: DistributedSampler,
+         *,
+         batch_size: int,
+         device: torch.device,
+         drop_last: bool = False,
+     ):
+         super().__init__(sampler, batch_size, drop_last=drop_last)
+
+         self._device = device
+
+         log.info(
+             f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}"
+         )
+
+     @staticmethod
+     def _dist_enabled():
+         return torch.distributed.is_available() and torch.distributed.is_initialized()
+
+     @override
+     def __iter__(self):
+         if not self._dist_enabled():
+             yield from super().__iter__()
+             return
+
+         for batch_idx in super().__iter__():
+             sizes = self.dataset.data_sizes(batch_idx)
+             idx_sizes = torch.stack(
+                 [
+                     torch.tensor(batch_idx, device=self._device),
+                     torch.tensor(sizes, device=self._device),
+                 ]
+             )
+             idx_sizes_all = _all_gather(idx_sizes, device=self._device)
+             idx_sizes_all = torch.cat(idx_sizes_all, dim=-1).cpu()
+             idx_all = idx_sizes_all[0]
+             sizes_all = idx_sizes_all[1]
+
+             local_idx_balanced = _balanced_partition(
+                 sizes_all.numpy(), num_parts=self.distributed_sampler.num_replicas
+             )
+             # Since DistributedSampler pads the last batch
+             # this should always have an entry for each replica.
+             yield idx_all[local_idx_balanced[self.distributed_sampler.rank]].tolist()
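For reference, a minimal sketch of how `BalancedBatchSampler` might be wired up. The `VariableSizeDataset` below is hypothetical (not part of nshtrainer) and exists only to satisfy the `DatasetWithSizes` protocol; the size-balancing path itself only activates once `torch.distributed` is initialized, so this single-process example only shows the plumbing.

import numpy as np
import torch
from torch.utils.data import Dataset, DistributedSampler

from nshtrainer.data.balanced_batch_sampler import BalancedBatchSampler


class VariableSizeDataset(Dataset):
    """Toy dataset whose items have different sizes (e.g. graphs with varying node counts)."""

    def __init__(self, sizes: list[int]):
        self._sizes = np.asarray(sizes)

    def __len__(self) -> int:
        return len(self._sizes)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.zeros(int(self._sizes[idx]))

    # Satisfies the `DatasetWithSizes` protocol checked by BalancedBatchSampler.
    def data_sizes(self, indices: list[int]) -> np.ndarray:
        return self._sizes[indices]


dataset = VariableSizeDataset(sizes=[10, 200, 15, 180, 90, 30, 120, 60])
# In a real distributed run the sampler picks up num_replicas/rank from the
# process group; they are passed explicitly here so the sketch runs standalone.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
batch_sampler = BalancedBatchSampler(sampler, batch_size=4, device=torch.device("cpu"))

for batch_indices in batch_sampler:
    print(batch_indices)  # e.g. [0, 1, 2, 3] then [4, 5, 6, 7]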
nshtrainer/data/transform.py
@@ -0,0 +1,67 @@
+ import copy
+ from collections.abc import Callable
+ from typing import Any, cast
+
+ from typing_extensions import TypeVar
+
+ TDataset = TypeVar("TDataset", infer_variance=True)
+
+
+ def transform(
+     dataset: TDataset,
+     transform: Callable[[Any], Any],
+     *,
+     deepcopy: bool = False,
+ ) -> TDataset:
+     """
+     Wraps a dataset with a transform function.
+
+     Args:
+         dataset: The dataset to wrap.
+         transform: The transform function to apply to each item.
+         deepcopy: Whether to deep copy each item before applying the transform.
+     """
+
+     import wrapt
+
+     class _TransformedDataset(wrapt.ObjectProxy):
+         def __getitem__(self, idx):
+             nonlocal deepcopy, transform
+
+             data = self.__wrapped__.__getitem__(idx)
+             if deepcopy:
+                 data = copy.deepcopy(data)
+             data = transform(data)
+             return data
+
+     return cast(TDataset, _TransformedDataset(dataset))
+
+
+ def transform_with_index(
+     dataset: TDataset,
+     transform: Callable[[Any, int], Any],
+     *,
+     deepcopy: bool = False,
+ ) -> TDataset:
+     """
+     Wraps a dataset with a transform function that takes an index, in addition to the item.
+
+     Args:
+         dataset: The dataset to wrap.
+         transform: The transform function to apply to each item.
+         deepcopy: Whether to deep copy each item before applying the transform.
+     """
+
+     import wrapt
+
+     class _TransformedWithIndexDataset(wrapt.ObjectProxy):
+         def __getitem__(self, idx: int):
+             nonlocal deepcopy, transform
+
+             data = self.__wrapped__.__getitem__(idx)
+             if deepcopy:
+                 data = copy.deepcopy(data)
+             data = transform(data, idx)
+             return data
+
+     return cast(TDataset, _TransformedWithIndexDataset(dataset))
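For reference, a small usage sketch of the two wrappers above (they import `wrapt` lazily, so it must be installed); the dataset and lambdas here are illustrative only.

import torch
from torch.utils.data import TensorDataset

from nshtrainer.data.transform import transform, transform_with_index

base = TensorDataset(torch.arange(8).float().unsqueeze(-1))

# Per-item transform applied on access; the base dataset is left untouched.
scaled = transform(base, lambda item: tuple(x * 2.0 for x in item))

# The index-aware variant also receives the requested index.
tagged = transform_with_index(base, lambda item, idx: (*item, idx))

print(scaled[3])  # (tensor([6.]),)
print(tagged[3])  # (tensor([3.]), 3)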
nshtrainer/lr_scheduler/__init__.py
@@ -0,0 +1,18 @@
+ from typing import Annotated, TypeAlias
+
+ from ..config import Field
+ from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
+ from ._base import LRSchedulerMetadata as LRSchedulerMetadata
+ from .linear_warmup_cosine import (
+     LinearWarmupCosineAnnealingLR as LinearWarmupCosineAnnealingLR,
+ )
+ from .linear_warmup_cosine import (
+     LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
+ )
+ from .reduce_lr_on_plateau import ReduceLROnPlateau as ReduceLROnPlateau
+ from .reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
+
+ LRSchedulerConfig: TypeAlias = Annotated[
+     LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
+     Field(discriminator="name"),
+ ]
nshtrainer/lr_scheduler/_base.py
@@ -0,0 +1,101 @@
+ import math
+ from abc import ABC, abstractmethod
+ from collections.abc import Mapping
+ from typing import TYPE_CHECKING, ClassVar, Literal, TypeAlias
+
+ from lightning.pytorch.utilities.types import (
+     LRSchedulerConfigType,
+     LRSchedulerTypeUnion,
+ )
+ from torch.optim import Optimizer
+ from typing_extensions import NotRequired, TypedDict
+
+ from ..config import TypedConfig
+
+ if TYPE_CHECKING:
+     from ..model.base import LightningModuleBase
+
+
+ class LRSchedulerMetadata(TypedDict):
+     interval: Literal["epoch", "step"]
+     """Interval to update the learning rate."""
+
+     name: NotRequired[str | None]
+     """Name of the learning rate scheduler. Default is `None`."""
+
+     frequency: NotRequired[int]
+     """Frequency to update the learning rate. Default is `1`."""
+
+     reduce_on_plateau: NotRequired[bool]
+     """Whether to reduce the learning rate on plateau. Default is `False`."""
+
+     monitor: NotRequired[str | None]
+     """Value to monitor for reducing the learning rate on plateau. Required if `reduce_on_plateau` is `True`.
+     Default is `None`."""
+
+     strict: NotRequired[bool]
+     """Whether to enforce that the monitor exists for reducing the learning rate on plateau. Default is `True`."""
+
+
+ class LRSchedulerConfigBase(TypedConfig, ABC):
+     Metadata: ClassVar[TypeAlias] = LRSchedulerMetadata
+
+     @abstractmethod
+     def metadata(self) -> LRSchedulerMetadata: ...
+
+     @abstractmethod
+     def create_scheduler_impl(
+         self,
+         optimizer: Optimizer,
+         lightning_module: "LightningModuleBase",
+         lr: float,
+     ) -> LRSchedulerTypeUnion | LRSchedulerConfigType: ...
+
+     def create_scheduler(
+         self,
+         optimizer: Optimizer,
+         lightning_module: "LightningModuleBase",
+         lr: float,
+     ) -> LRSchedulerConfigType:
+         # Create the scheduler.
+         scheduler = self.create_scheduler_impl(optimizer, lightning_module, lr)
+
+         # If the scheduler is not a `LRSchedulerConfigType`, then make it one.
+         if not isinstance(scheduler, Mapping):
+             scheduler = LRSchedulerConfigType(scheduler=scheduler)
+
+         # Update the scheduler config with the metadata (if not already present).
+         metadata = self.metadata()
+         # - `interval` has to be present.
+         if scheduler.get("interval") is None:
+             scheduler["interval"] = metadata["interval"]
+         # - `name`
+         if scheduler.get("name") is None and "name" in metadata:
+             scheduler["name"] = metadata["name"]
+         # - `frequency`
+         if scheduler.get("frequency") is None and "frequency" in metadata:
+             scheduler["frequency"] = metadata["frequency"]
+         # - `reduce_on_plateau`
+         if (
+             scheduler.get("reduce_on_plateau") is None
+             and "reduce_on_plateau" in metadata
+         ):
+             scheduler["reduce_on_plateau"] = metadata["reduce_on_plateau"]
+         # - `monitor`
+         if scheduler.get("monitor") is None and "monitor" in metadata:
+             scheduler["monitor"] = metadata["monitor"]
+         # - `strict`
+         if scheduler.get("strict") is None and "strict" in metadata:
+             scheduler["strict"] = metadata["strict"]  # type: ignore
+
+         return scheduler
+
+     def compute_num_steps_per_epoch(
+         self, lightning_module: "LightningModuleBase"
+     ) -> int:
+         trainer = lightning_module.trainer
+         # Use the Lightning trainer to convert the epoch-based values to step-based values
+         _ = trainer.estimated_stepping_batches
+         # ^ This is a hack to trigger the computation of the estimated stepping batches
+         # and make sure that the `trainer.num_training_batches` attribute is set.
+         return math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
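To illustrate the contract above, here is a sketch of a hypothetical concrete config (not part of nshtrainer) that plugs a plain torch `StepLR` into this base class; it assumes `TypedConfig` supports pydantic-style field declarations, as the subclasses in this package do.

from typing import Literal

from torch.optim import Optimizer
from torch.optim.lr_scheduler import StepLR
from typing_extensions import override

from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase


class StepLRSchedulerConfig(LRSchedulerConfigBase):
    """Hypothetical example config, shown only to illustrate the base class."""

    name: Literal["step"] = "step"
    step_size: int = 10
    gamma: float = 0.1

    @override
    def metadata(self) -> LRSchedulerConfigBase.Metadata:
        # `interval` is the only required key; `create_scheduler` copies any
        # missing metadata into the Lightning scheduler dict.
        return {"interval": "epoch"}

    @override
    def create_scheduler_impl(self, optimizer: Optimizer, lightning_module, lr: float):
        # Returning a bare scheduler is fine: `create_scheduler` wraps it into
        # an `LRSchedulerConfigType` mapping and fills in the metadata above.
        return StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)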
nshtrainer/lr_scheduler/linear_warmup_cosine.py
@@ -0,0 +1,138 @@
+ import math
+ import warnings
+ from typing import Literal
+
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import LRScheduler
+ from typing_extensions import override
+
+ from ..config import Field
+ from ._base import LRSchedulerConfigBase
+
+
+ class LinearWarmupCosineAnnealingLR(LRScheduler):
+     _get_lr_called_within_step: bool
+
+     def __init__(
+         self,
+         optimizer: Optimizer,
+         warmup_epochs: int,
+         max_epochs: int,
+         warmup_start_lr: float = 0.0,
+         eta_min: float = 0.0,
+         last_epoch: int = -1,
+         should_restart: bool = True,
+     ) -> None:
+         self.warmup_epochs = warmup_epochs
+         self.max_epochs = max_epochs
+         self.warmup_start_lr = warmup_start_lr
+         self.eta_min = eta_min
+         self.should_restart = should_restart
+
+         super().__init__(optimizer, last_epoch)
+
+     @override
+     def get_lr(self) -> list[float]:  # pyright: ignore[reportIncompatibleMethodOverride]
+         if not self._get_lr_called_within_step:
+             warnings.warn(
+                 "To get the last learning rate computed by the scheduler, "
+                 "please use `get_last_lr()`.",
+                 UserWarning,
+             )
+
+         if self.last_epoch == 0:
+             return [self.warmup_start_lr] * len(self.base_lrs)
+         if self.last_epoch < self.warmup_epochs:
+             return [
+                 group["lr"]
+                 + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
+                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+             ]
+         if self.last_epoch == self.warmup_epochs:
+             return self.base_lrs
+
+         if not self.should_restart and self.last_epoch >= self.max_epochs:
+             return [self.eta_min] * len(self.base_lrs)
+
+         if (self.last_epoch - 1 - self.max_epochs) % (
+             2 * (self.max_epochs - self.warmup_epochs)
+         ) == 0:
+             return [
+                 group["lr"]
+                 + (base_lr - self.eta_min)
+                 * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs)))
+                 / 2
+                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+             ]
+
+         return [
+             (
+                 1
+                 + math.cos(
+                     math.pi
+                     * (self.last_epoch - self.warmup_epochs)
+                     / (self.max_epochs - self.warmup_epochs)
+                 )
+             )
+             / (
+                 1
+                 + math.cos(
+                     math.pi
+                     * (self.last_epoch - self.warmup_epochs - 1)
+                     / (self.max_epochs - self.warmup_epochs)
+                 )
+             )
+             * (group["lr"] - self.eta_min)
+             + self.eta_min
+             for group in self.optimizer.param_groups
+         ]
+
+
+ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
+     name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
+
+     warmup_epochs: int = Field(ge=0)
+     r"""The number of epochs for the linear warmup phase.
+     The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this number of epochs."""
+
+     max_epochs: int = Field(gt=0)
+     r"""The total number of epochs.
+     The learning rate is decayed to `min_lr` over this number of epochs."""
+
+     warmup_start_lr_factor: float = 0.0
+     r"""The initial learning rate for the linear warmup phase, as a factor of the initial learning rate.
+     The learning rate is linearly increased from this value to the initial learning rate over `warmup_epochs` epochs."""
+
+     min_lr_factor: float = 0.0
+     r"""The minimum learning rate, as a factor of the initial learning rate.
+     The learning rate is decayed to this value over `max_epochs` epochs."""
+
+     annealing: bool = False
+     r"""Whether to restart the learning rate schedule after `max_epochs` epochs.
+     If `False`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be set to `min_lr` for all subsequent epochs.
+     If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""
+
+     @override
+     def metadata(self) -> LRSchedulerConfigBase.Metadata:
+         return {
+             "interval": "step",
+         }
+
+     @override
+     def create_scheduler_impl(self, optimizer, lightning_module, lr):
+         num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
+         warmup_steps = self.warmup_epochs * num_steps_per_epoch
+         max_steps = self.max_epochs * num_steps_per_epoch
+         warmup_start_lr = self.warmup_start_lr_factor * lr
+         min_lr = self.min_lr_factor * lr
+
+         # Create the scheduler
+         scheduler = LinearWarmupCosineAnnealingLR(
+             optimizer=optimizer,
+             warmup_epochs=warmup_steps,
+             max_epochs=max_steps,
+             warmup_start_lr=warmup_start_lr,
+             eta_min=min_lr,
+             should_restart=self.annealing,
+         )
+         return scheduler
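The scheduler class itself can also be exercised outside Lightning. A minimal sketch of the warmup-then-cosine shape follows; the config above would normally construct it with step counts derived via `compute_num_steps_per_epoch`, whereas the numbers here are arbitrary.

import torch

from nshtrainer.lr_scheduler.linear_warmup_cosine import LinearWarmupCosineAnnealingLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

# Linear warmup from 0.0 to lr=1.0 over 5 steps, then cosine decay towards
# eta_min=0.01 over steps 5..20. With should_restart=True (the class default),
# the schedule continues past max_epochs instead of clamping at eta_min.
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer, warmup_epochs=5, max_epochs=20, warmup_start_lr=0.0, eta_min=0.01
)

for step in range(20):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())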
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py
@@ -0,0 +1,73 @@
+ from typing import TYPE_CHECKING, Literal, cast
+
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+ from typing_extensions import override
+
+ from ._base import LRSchedulerMetadata
+
+ from ..model.config import MetricConfig
+ from ._base import LRSchedulerConfigBase
+
+ if TYPE_CHECKING:
+     from ..model.base import BaseConfig
+
+
+ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
+     """Reduce learning rate when a metric has stopped improving."""
+
+     name: Literal["reduce_lr_on_plateau"] = "reduce_lr_on_plateau"
+
+     metric: MetricConfig | None = None
+     """Metric to monitor.
+     If not provided, the primary metric of the runner will be used."""
+
+     patience: int = 10
+     r"""Number of epochs with no improvement after which the learning rate will be reduced."""
+
+     factor: float = 0.1
+     r"""Factor by which the learning rate will be reduced. new_lr = lr * factor."""
+
+     min_lr: float | list[float] = 0.0
+     r"""A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively."""
+
+     eps: float = 1.0e-8
+     r"""Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored."""
+
+     cooldown: int = 0
+     r"""Number of epochs to wait before resuming normal operation after lr has been reduced."""
+
+     threshold: float = 1.0e-4
+     r"""Threshold for measuring the new optimum, to only focus on significant changes."""
+
+     threshold_mode: str = "rel"
+     r"""One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * (1 + threshold) in `max` mode or best * (1 - threshold) in `min` mode. In `abs` mode, dynamic_threshold = best + threshold in `max` mode or best - threshold in `min` mode. Default: `rel`."""
+
+     @override
+     def create_scheduler_impl(self, optimizer, lightning_module, lr):
+         if (metric := self.metric) is None:
+             lm_config = cast("BaseConfig", lightning_module.config)
+             assert (
+                 metric := lm_config.primary_metric
+             ) is not None, "Primary metric must be provided if metric is not specified."
+
+         lr_scheduler = ReduceLROnPlateau(
+             optimizer,
+             mode=metric.mode,
+             factor=self.factor,
+             patience=self.patience,
+             threshold=self.threshold,
+             threshold_mode=self.threshold_mode,
+             cooldown=self.cooldown,
+             min_lr=self.min_lr,
+             eps=self.eps,
+         )
+         return {
+             "scheduler": lr_scheduler,
+             "monitor": metric.validation_monitor,
+         }
+
+     @override
+     def metadata(self) -> LRSchedulerMetadata:
+         return {
+             "interval": "epoch",
+         }
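For orientation, the config above is a thin wrapper over torch's `ReduceLROnPlateau`. A standalone sketch with the same defaults follows; in nshtrainer the `mode` and monitored value come from the resolved `MetricConfig` rather than being hard-coded, and the validation loss here is a stand-in.

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Mirrors ReduceLROnPlateauConfig's defaults: factor=0.1, patience=10, etc.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.1,
    patience=10,
    threshold=1e-4,
    threshold_mode="rel",
    cooldown=0,
    min_lr=0.0,
    eps=1e-8,
)

for epoch in range(3):
    val_loss = 1.0 / (epoch + 1)  # stand-in for a real validation metric
    scheduler.step(val_loss)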
nshtrainer/model/__init__.py
@@ -0,0 +1,44 @@
+ from typing_extensions import TypeAlias
+
+ from .base import Base as Base
+ from .base import LightningDataModuleBase as LightningDataModuleBase
+ from .base import LightningModuleBase as LightningModuleBase
+ from .config import ActSaveConfig as ActSaveConfig
+ from .config import BaseConfig as BaseConfig
+ from .config import BaseLoggerConfig as BaseLoggerConfig
+ from .config import BaseProfilerConfig as BaseProfilerConfig
+ from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
+ from .config import CheckpointSavingConfig as CheckpointSavingConfig
+ from .config import DirectoryConfig as DirectoryConfig
+ from .config import EarlyStoppingConfig as EarlyStoppingConfig
+ from .config import (
+     EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
+ )
+ from .config import EnvironmentConfig as EnvironmentConfig
+ from .config import (
+     EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
+ )
+ from .config import (
+     EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
+ )
+ from .config import GradientClippingConfig as GradientClippingConfig
+ from .config import (
+     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
+ )
+ from .config import LoggingConfig as LoggingConfig
+ from .config import MetricConfig as MetricConfig
+ from .config import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
+ from .config import (
+     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+ )
+ from .config import OptimizationConfig as OptimizationConfig
+ from .config import PrimaryMetricConfig as PrimaryMetricConfig
+ from .config import PythonLogging as PythonLogging
+ from .config import ReproducibilityConfig as ReproducibilityConfig
+ from .config import RunnerConfig as RunnerConfig
+ from .config import SanityCheckingConfig as SanityCheckingConfig
+ from .config import SeedConfig as SeedConfig
+ from .config import TrainerConfig as TrainerConfig
+ from .config import WandbWatchConfig as WandbWatchConfig
+
+ ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]