nshtrainer 0.30.0__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -2
- nshtrainer/_directory.py +85 -0
- nshtrainer/callbacks/__init__.py +8 -0
- nshtrainer/callbacks/directory_setup.py +85 -0
- nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
- nshtrainer/callbacks/shared_parameters.py +87 -0
- nshtrainer/config.py +67 -0
- nshtrainer/ll/__init__.py +5 -4
- nshtrainer/ll/model.py +7 -0
- nshtrainer/loggers/wandb.py +1 -1
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +3 -8
- nshtrainer/model/__init__.py +0 -21
- nshtrainer/model/base.py +139 -44
- nshtrainer/model/config.py +7 -1025
- nshtrainer/model/{modules → mixins}/callback.py +2 -2
- nshtrainer/model/{modules → mixins}/logger.py +13 -16
- nshtrainer/profiler/__init__.py +13 -0
- nshtrainer/profiler/_base.py +29 -0
- nshtrainer/profiler/advanced.py +37 -0
- nshtrainer/profiler/pytorch.py +83 -0
- nshtrainer/profiler/simple.py +36 -0
- nshtrainer/trainer/_config.py +778 -0
- nshtrainer/trainer/trainer.py +16 -17
- nshtrainer/{config → util/config}/__init__.py +1 -0
- {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/RECORD +28 -22
- nshtrainer/model/modules/debug.py +0 -42
- nshtrainer/model/modules/distributed.py +0 -70
- nshtrainer/model/modules/profiler.py +0 -24
- nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
- nshtrainer/model/modules/shared_parameters.py +0 -72
- /nshtrainer/{config → util/config}/duration.py +0 -0
- {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/WHEEL +0 -0
nshtrainer/lr_scheduler/linear_warmup_cosine.py
CHANGED

@@ -6,7 +6,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from typing_extensions import override
 
-from ..config import Duration
+from ..util.config import Duration
 from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
 
 
@@ -121,13 +121,8 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
     @override
     def create_scheduler_impl(self, optimizer, lightning_module, lr):
         num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
-        warmup_steps = (
-            self.warmup_duration.to_steps(num_steps_per_epoch).value
-            * num_steps_per_epoch
-        )
-        max_steps = (
-            self.max_duration.to_steps(num_steps_per_epoch).value * num_steps_per_epoch
-        )
+        warmup_steps = self.warmup_duration.to_steps(num_steps_per_epoch).value
+        max_steps = self.max_duration.to_steps(num_steps_per_epoch).value
         warmup_start_lr = self.warmup_start_lr_factor * lr
         min_lr = self.min_lr_factor * lr
 
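Two things change in nshtrainer/lr_scheduler/linear_warmup_cosine.py: the `Duration` helper is now imported from the relocated `nshtrainer/util/config` package (matching the `{config → util/config}` move in the file list), and the warmup/max step computation no longer multiplies the already-converted step count by `num_steps_per_epoch` a second time. The sketch below illustrates the arithmetic fix with a hypothetical epoch-based stand-in for the real `Duration` config; only the `to_steps(...).value` call shape is taken from the diff.

```python
from dataclasses import dataclass


@dataclass
class _Steps:
    value: int


@dataclass
class EpochsDuration:
    """Hypothetical stand-in for nshtrainer's epoch-based Duration config."""

    epochs: int

    def to_steps(self, num_steps_per_epoch: int) -> _Steps:
        # Convert an epoch count into an optimizer-step count.
        return _Steps(self.epochs * num_steps_per_epoch)


num_steps_per_epoch = 100
warmup_duration = EpochsDuration(epochs=5)

# 0.30.0 behavior: the epoch-to-step conversion was effectively applied twice.
old_warmup_steps = (
    warmup_duration.to_steps(num_steps_per_epoch).value * num_steps_per_epoch
)
assert old_warmup_steps == 50_000

# 0.31.0 behavior: the converted value is used directly.
new_warmup_steps = warmup_duration.to_steps(num_steps_per_epoch).value
assert new_warmup_steps == 500
```

With the fix, a 5-epoch warmup at 100 steps per epoch resolves to 500 warmup steps instead of 50,000.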
nshtrainer/model/__init__.py
CHANGED
@@ -1,26 +1,5 @@
-from typing_extensions import TypeAlias
-
-from .base import Base as Base
 from .base import LightningModuleBase as LightningModuleBase
 from .config import BaseConfig as BaseConfig
-from .config import BaseProfilerConfig as BaseProfilerConfig
-from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
-from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
-from .config import CheckpointSavingConfig as CheckpointSavingConfig
 from .config import DirectoryConfig as DirectoryConfig
-from .config import EarlyStoppingConfig as EarlyStoppingConfig
-from .config import GradientClippingConfig as GradientClippingConfig
-from .config import HuggingFaceHubConfig as HuggingFaceHubConfig
-from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
-from .config import LoggingConfig as LoggingConfig
 from .config import MetricConfig as MetricConfig
-from .config import (
-    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
-)
-from .config import OptimizationConfig as OptimizationConfig
-from .config import PrimaryMetricConfig as PrimaryMetricConfig
-from .config import ReproducibilityConfig as ReproducibilityConfig
-from .config import SanityCheckingConfig as SanityCheckingConfig
 from .config import TrainerConfig as TrainerConfig
-
-ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
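nshtrainer/model/__init__.py is trimmed to a handful of re-exports: the trainer-oriented config classes and the `ConfigList` alias are no longer exposed from this module (the file list suggests trainer configuration now lives under `nshtrainer/trainer/_config.py`, but the new import locations are not shown in this diff). A minimal sketch, based only on the context lines kept above, of what still imports from `nshtrainer.model` in 0.31.0:

```python
# Re-exports that remain in nshtrainer.model after this change
# (taken from the unchanged context lines of the hunk above).
from nshtrainer.model import (
    BaseConfig,
    DirectoryConfig,
    LightningModuleBase,
    MetricConfig,
    TrainerConfig,
)
```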
nshtrainer/model/base.py
CHANGED
@@ -2,61 +2,28 @@ import inspect
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import MutableMapping
-from typing import IO, TYPE_CHECKING, Any, Generic, cast
+from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast
 
 import torch
+import torch.distributed
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.profilers import PassThroughProfiler, Profiler
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from typing_extensions import Self, TypeVar, override
 
+from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
 from ..util._environment_info import EnvironmentConfig
 from .config import BaseConfig
-from .modules.callback import CallbackModuleMixin
-from .modules.debug import DebugModuleMixin
-from .modules.distributed import DistributedMixin
-from .modules.logger import LoggerLightningModuleMixin
-from .modules.profiler import ProfilerMixin
-from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
-from .modules.shared_parameters import SharedParametersModuleMixin
+from .mixins.callback import CallbackModuleMixin
+from .mixins.logger import LoggerLightningModuleMixin
 
 log = logging.getLogger(__name__)
 
 THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)
 
 
-class Base(DebugModuleMixin, Generic[THparams]):
-    @torch.jit.unused
-    @property
-    def config(self) -> THparams:
-        return self.hparams
-
-    @torch.jit.unused
-    @property
-    def C(self) -> THparams:
-        return self.hparams
-
-    @property
-    def debug(self) -> bool:
-        if torch.jit.is_scripting():
-            return False
-        return self.config.debug
-
-    @property
-    def dev(self) -> bool:
-        if torch.jit.is_scripting():
-            return False
-        return self.config.debug
-
-    @override
-    def __init__(self, hparams: THparams):
-        super().__init__()
-
-        if not hasattr(self, "hparams"):
-            self.hparams = hparams
-
-
 class DebugFlagCallback(Callback):
     """
     Sets the debug flag to true in the following circumstances:
@@ -90,18 +57,146 @@ class DebugFlagCallback(Callback):
         hparams.debug = self._debug
 
 
+T = TypeVar("T", infer_variance=True)
+
+ReduceOpStr = Literal[
+    "avg",
+    "mean",
+    "band",
+    "bor",
+    "bxor",
+    "max",
+    "min",
+    "premul_sum",
+    "product",
+    "sum",
+]
+VALID_REDUCE_OPS = (
+    "avg",
+    "mean",
+    "band",
+    "bor",
+    "bxor",
+    "max",
+    "min",
+    "premul_sum",
+    "product",
+    "sum",
+)
+
+
 class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
-    ProfilerMixin,
-    RLPSanityCheckModuleMixin,
+    _RLPSanityCheckModuleMixin,
     LoggerLightningModuleMixin,
-    SharedParametersModuleMixin,
-    DistributedMixin,
     CallbackModuleMixin,
-    Base[THparams],
     LightningModule,
     ABC,
     Generic[THparams],
 ):
+    # region Config
+    @torch.jit.unused
+    @property
+    def config(self) -> THparams:
+        return self.hparams
+
+    @property
+    def debug(self) -> bool:
+        if torch.jit.is_scripting():
+            return False
+        return self.config.debug
+
+    # endregion
+
+    # region Debug
+
+    @torch.jit.unused
+    def breakpoint(self, rank_zero_only: bool = True):
+        if (
+            not rank_zero_only
+            or not torch.distributed.is_initialized()
+            or torch.distributed.get_rank() == 0
+        ):
+            breakpoint()
+
+        if rank_zero_only and torch.distributed.is_initialized():
+            _ = torch.distributed.barrier()
+
+    @torch.jit.unused
+    def ensure_finite(
+        self,
+        tensor: torch.Tensor,
+        name: str | None = None,
+        throw: bool = False,
+    ):
+        name_parts: list[str] = ["Tensor"]
+        if name is not None:
+            name_parts.append(name)
+        name = " ".join(name_parts)
+
+        not_finite = ~torch.isfinite(tensor)
+        if not_finite.any():
+            msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
+            if throw:
+                raise RuntimeError(msg)
+            else:
+                log.warning(msg)
+            return False
+        return True
+
+    # endregion
+
+    # region Profiler
+    @property
+    def profiler(self) -> Profiler:
+        if (trainer := self._trainer) is None:
+            raise RuntimeError("trainer is not defined")
+
+        if not hasattr(trainer, "profiler"):
+            raise RuntimeError("trainer does not have profiler")
+
+        if (profiler := getattr(trainer, "profiler")) is None:
+            profiler = PassThroughProfiler()
+
+        return profiler
+
+    # endregion
+
+    # region Distributed
+    def all_gather_object(
+        self,
+        object: T,
+        group: torch.distributed.ProcessGroup | None = None,
+    ) -> list[T]:
+        if (
+            not torch.distributed.is_available()
+            or not torch.distributed.is_initialized()
+        ):
+            return [object]
+
+        object_list = [cast(T, None) for _ in range(self.trainer.world_size)]
+        torch.distributed.all_gather_object(object_list, object, group=group)
+        return object_list
+
+    def barrier(self, name: str | None = None):
+        self.trainer.strategy.barrier(name=name)
+
+    def reduce(
+        self,
+        tensor: torch.Tensor,
+        reduce_op: torch.distributed.ReduceOp.RedOpType | ReduceOpStr,
+        group: Any | None = None,
+    ) -> torch.Tensor:
+        if isinstance(reduce_op, str):
+            # validate reduce_op
+            if reduce_op not in VALID_REDUCE_OPS:
+                raise ValueError(
+                    f"reduce_op must be one of {VALID_REDUCE_OPS}, got {reduce_op}"
+                )
+
+        return self.trainer.strategy.reduce(tensor, group=group, reduce_op=reduce_op)
+
+    # endregion
+
     # Our own custom __repr__ method.
     # Torch's __repr__ method is too verbose and doesn't provide any useful information.
     @override