nshtrainer 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -2
- nshtrainer/_directory.py +85 -0
- nshtrainer/callbacks/__init__.py +12 -1
- nshtrainer/callbacks/debug_flag.py +72 -0
- nshtrainer/callbacks/directory_setup.py +85 -0
- nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
- nshtrainer/callbacks/shared_parameters.py +87 -0
- nshtrainer/config.py +67 -0
- nshtrainer/ll/__init__.py +5 -4
- nshtrainer/ll/model.py +7 -0
- nshtrainer/loggers/wandb.py +1 -1
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
- nshtrainer/model/__init__.py +0 -21
- nshtrainer/model/base.py +124 -67
- nshtrainer/model/config.py +7 -1025
- nshtrainer/model/{modules → mixins}/logger.py +13 -16
- nshtrainer/profiler/__init__.py +13 -0
- nshtrainer/profiler/_base.py +29 -0
- nshtrainer/profiler/advanced.py +37 -0
- nshtrainer/profiler/pytorch.py +83 -0
- nshtrainer/profiler/simple.py +36 -0
- nshtrainer/trainer/_config.py +787 -0
- nshtrainer/trainer/trainer.py +16 -17
- nshtrainer/{config → util/config}/__init__.py +1 -0
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/RECORD +28 -22
- nshtrainer/model/modules/callback.py +0 -206
- nshtrainer/model/modules/debug.py +0 -42
- nshtrainer/model/modules/distributed.py +0 -70
- nshtrainer/model/modules/profiler.py +0 -24
- nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
- nshtrainer/model/modules/shared_parameters.py +0 -72
- /nshtrainer/{config → util/config}/duration.py +0 -0
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/WHEEL +0 -0
nshtrainer/ll/__init__.py
CHANGED
```diff
@@ -1,3 +1,5 @@
+from typing import TypeAlias
+
 from . import _experimental as _experimental
 from . import actsave as actsave
 from . import callbacks as callbacks
@@ -21,12 +23,9 @@ from .log import init_python_logging as init_python_logging
 from .log import lovely as lovely
 from .log import pretty as pretty
 from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
-from .model import Base as Base
 from .model import BaseConfig as BaseConfig
-from .model import BaseProfilerConfig as BaseProfilerConfig
 from .model import CheckpointLoadingConfig as CheckpointLoadingConfig
 from .model import CheckpointSavingConfig as CheckpointSavingConfig
-from .model import ConfigList as ConfigList
 from .model import DirectoryConfig as DirectoryConfig
 from .model import (
     EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
@@ -43,7 +42,6 @@ from .model import LightningModuleBase as LightningModuleBase
 from .model import LoggingConfig as LoggingConfig
 from .model import MetricConfig as MetricConfig
 from .model import OptimizationConfig as OptimizationConfig
-from .model import PrimaryMetricConfig as PrimaryMetricConfig
 from .model import ReproducibilityConfig as ReproducibilityConfig
 from .model import SanityCheckingConfig as SanityCheckingConfig
 from .model import TrainerConfig as TrainerConfig
@@ -54,3 +52,6 @@ from .runner import Runner as Runner
 from .runner import SnapshotConfig as SnapshotConfig
 from .snoop import snoop as snoop
 from .trainer import Trainer as Trainer
+
+PrimaryMetricConfig: TypeAlias = MetricConfig
+ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
```
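For downstream code, the removed re-exports survive as aliases: `PrimaryMetricConfig` now points at `MetricConfig`, and `ConfigList` is a plain type alias. A minimal sketch of what keeps working after the upgrade (the empty `configs` list is illustrative only):

```python
from nshtrainer.ll import ConfigList, MetricConfig, PrimaryMetricConfig

# PrimaryMetricConfig is a TypeAlias for MetricConfig, so both names refer to
# the same class object at runtime.
assert PrimaryMetricConfig is MetricConfig

# ConfigList is still usable as an annotation for (config, module class) pairs.
configs: ConfigList = []
```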
nshtrainer/ll/model.py
CHANGED
```diff
@@ -1,5 +1,12 @@
 from nshtrainer.model import * # noqa: F403
 
+from ..trainer._config import CheckpointLoadingConfig as CheckpointLoadingConfig
+from ..trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
+from ..trainer._config import GradientClippingConfig as GradientClippingConfig
+from ..trainer._config import LoggingConfig as LoggingConfig
+from ..trainer._config import OptimizationConfig as OptimizationConfig
+from ..trainer._config import ReproducibilityConfig as ReproducibilityConfig
+from ..trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from ..util._environment_info import (
     EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
 )
```
nshtrainer/loggers/wandb.py
CHANGED
```diff
@@ -129,7 +129,7 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
                     "Please either upgrade to a newer version of WandB or disable the `use_wandb_core` option."
                 )
             else:
-                wandb.require("core")
+                wandb.require("core")  # type: ignore
                 log.critical("Using the `wandb-core` backend for WandB.")
         except ImportError:
             pass
```
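`wandb.require("core")` is WandB's opt-in switch for the `wandb-core` backend; the added `# type: ignore` only silences the type checker, presumably because `require` is not covered by type annotations. A minimal sketch of the same guarded opt-in pattern as a standalone helper (the function name is ours, not part of nshtrainer):

```python
import logging

log = logging.getLogger(__name__)


def try_enable_wandb_core() -> None:
    """Best-effort opt-in to the wandb-core backend; a no-op if wandb is absent."""
    try:
        import wandb

        wandb.require("core")  # type: ignore
        log.critical("Using the `wandb-core` backend for WandB.")
    except ImportError:
        pass
```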
nshtrainer/lr_scheduler/linear_warmup_cosine.py
CHANGED
```diff
@@ -6,7 +6,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from typing_extensions import override
 
-from ..config import Duration
+from ..util.config import Duration
 from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
 
 
```
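The only change here is the import path: the duration config types moved from `nshtrainer.config` to `nshtrainer.util.config` (see the `{config → util/config}` entries in the file list). A minimal migration sketch for downstream imports, assuming `Duration` is re-exported from the package `__init__` as before:

```python
# Before (0.30.1):
# from nshtrainer.config import Duration

# After (0.32.0):
from nshtrainer.util.config import Duration
```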
nshtrainer/model/__init__.py
CHANGED
```diff
@@ -1,26 +1,5 @@
-from typing_extensions import TypeAlias
-
-from .base import Base as Base
 from .base import LightningModuleBase as LightningModuleBase
 from .config import BaseConfig as BaseConfig
-from .config import BaseProfilerConfig as BaseProfilerConfig
-from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
-from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
-from .config import CheckpointSavingConfig as CheckpointSavingConfig
 from .config import DirectoryConfig as DirectoryConfig
-from .config import EarlyStoppingConfig as EarlyStoppingConfig
-from .config import GradientClippingConfig as GradientClippingConfig
-from .config import HuggingFaceHubConfig as HuggingFaceHubConfig
-from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
-from .config import LoggingConfig as LoggingConfig
 from .config import MetricConfig as MetricConfig
-from .config import (
-    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
-)
-from .config import OptimizationConfig as OptimizationConfig
-from .config import PrimaryMetricConfig as PrimaryMetricConfig
-from .config import ReproducibilityConfig as ReproducibilityConfig
-from .config import SanityCheckingConfig as SanityCheckingConfig
 from .config import TrainerConfig as TrainerConfig
-
-ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
```
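`nshtrainer.model` now exports only `LightningModuleBase`, `BaseConfig`, `DirectoryConfig`, `MetricConfig`, and `TrainerConfig`. The trainer-related configs dropped here are the same ones that `nshtrainer/ll/model.py` above re-exports from `nshtrainer.trainer._config`, so code that imported them from `nshtrainer.model` would migrate roughly as follows (a sketch; the `nshtrainer.ll` re-exports shown earlier also remain available):

```python
# Before (0.30.1):
# from nshtrainer.model import CheckpointLoadingConfig, LoggingConfig, OptimizationConfig

# After (0.32.0): import from the trainer config module (or via nshtrainer.ll.model).
from nshtrainer.trainer._config import (
    CheckpointLoadingConfig,
    LoggingConfig,
    OptimizationConfig,
)
```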
nshtrainer/model/base.py
CHANGED
```diff
@@ -2,39 +2,65 @@ import inspect
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import MutableMapping
-from typing import IO, TYPE_CHECKING, Any, Generic, cast
+from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast
 
 import torch
+import torch.distributed
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
-from lightning.pytorch import LightningModule
-from lightning.pytorch.
+from lightning.pytorch import LightningModule
+from lightning.pytorch.profilers import PassThroughProfiler, Profiler
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from typing_extensions import Self, TypeVar, override
 
+from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
 from ..util._environment_info import EnvironmentConfig
 from .config import BaseConfig
-from .
-from .modules.debug import DebugModuleMixin
-from .modules.distributed import DistributedMixin
-from .modules.logger import LoggerLightningModuleMixin
-from .modules.profiler import ProfilerMixin
-from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
-from .modules.shared_parameters import SharedParametersModuleMixin
+from .mixins.logger import LoggerLightningModuleMixin
 
 log = logging.getLogger(__name__)
 
 THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)
 
 
-
-
-
-
-
+T = TypeVar("T", infer_variance=True)
+
+ReduceOpStr = Literal[
+    "avg",
+    "mean",
+    "band",
+    "bor",
+    "bxor",
+    "max",
+    "min",
+    "premul_sum",
+    "product",
+    "sum",
+]
+VALID_REDUCE_OPS = (
+    "avg",
+    "mean",
+    "band",
+    "bor",
+    "bxor",
+    "max",
+    "min",
+    "premul_sum",
+    "product",
+    "sum",
+)
+
 
+class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
+    _RLPSanityCheckModuleMixin,
+    LoggerLightningModuleMixin,
+    LightningModule,
+    ABC,
+    Generic[THparams],
+):
+    # region Config
     @torch.jit.unused
     @property
-    def
+    def config(self) -> THparams:
         return self.hparams
 
     @property
@@ -43,65 +69,98 @@ class Base(DebugModuleMixin, Generic[THparams]):
             return False
         return self.config.debug
 
-
-
-
+    # endregion
+
+    # region Debug
+
+    @torch.jit.unused
+    def breakpoint(self, rank_zero_only: bool = True):
+        if (
+            not rank_zero_only
+            or not torch.distributed.is_initialized()
+            or torch.distributed.get_rank() == 0
+        ):
+            breakpoint()
+
+        if rank_zero_only and torch.distributed.is_initialized():
+            _ = torch.distributed.barrier()
+
+    @torch.jit.unused
+    def ensure_finite(
+        self,
+        tensor: torch.Tensor,
+        name: str | None = None,
+        throw: bool = False,
+    ):
+        name_parts: list[str] = ["Tensor"]
+        if name is not None:
+            name_parts.append(name)
+        name = " ".join(name_parts)
+
+        not_finite = ~torch.isfinite(tensor)
+        if not_finite.any():
+            msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
+            if throw:
+                raise RuntimeError(msg)
+            else:
+                log.warning(msg)
             return False
-        return
+        return True
 
-
-    def __init__(self, hparams: THparams):
-        super().__init__()
+    # endregion
 
-
-
+    # region Profiler
+    @property
+    def profiler(self) -> Profiler:
+        if (trainer := self._trainer) is None:
+            raise RuntimeError("trainer is not defined")
 
+        if not hasattr(trainer, "profiler"):
+            raise RuntimeError("trainer does not have profiler")
 
-
-
-    Sets the debug flag to true in the following circumstances:
-    - fast_dev_run is enabled
-    - sanity check is running
-    """
+        if (profiler := getattr(trainer, "profiler")) is None:
+            profiler = PassThroughProfiler()
 
-
-    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
-        if not getattr(trainer, "fast_dev_run", False):
-            return
+        return profiler
 
-
-        if not hparams.debug:
-            log.critical("Fast dev run detected, setting debug flag to True.")
-            hparams.debug = True
+    # endregion
 
-
-    def
-
-
-
-
-
+    # region Distributed
+    def all_gather_object(
+        self,
+        object: T,
+        group: torch.distributed.ProcessGroup | None = None,
+    ) -> list[T]:
+        if (
+            not torch.distributed.is_available()
+            or not torch.distributed.is_initialized()
+        ):
+            return [object]
+
+        object_list = [cast(T, None) for _ in range(self.trainer.world_size)]
+        torch.distributed.all_gather_object(object_list, object, group=group)
+        return object_list
+
+    def barrier(self, name: str | None = None):
+        self.trainer.strategy.barrier(name=name)
+
+    def reduce(
+        self,
+        tensor: torch.Tensor,
+        reduce_op: torch.distributed.ReduceOp.RedOpType | ReduceOpStr,
+        group: Any | None = None,
+    ) -> torch.Tensor:
+        if isinstance(reduce_op, str):
+            # validate reduce_op
+            if reduce_op not in VALID_REDUCE_OPS:
+                raise ValueError(
+                    f"reduce_op must be one of {VALID_REDUCE_OPS}, got {reduce_op}"
+                )
 
-
-    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
-        hparams = cast(BaseConfig, pl_module.hparams)
-        if not self._debug:
-            log.critical("Sanity check routine complete, disabling debug flag.")
-            hparams.debug = self._debug
+        return self.trainer.strategy.reduce(tensor, group=group, reduce_op=reduce_op)
 
+    # endregion
 
-class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
-    ProfilerMixin,
-    RLPSanityCheckModuleMixin,
-    LoggerLightningModuleMixin,
-    SharedParametersModuleMixin,
-    DistributedMixin,
-    CallbackModuleMixin,
-    Base[THparams],
-    LightningModule,
-    ABC,
-    Generic[THparams],
-):
     # Our own custom __repr__ method.
     # Torch's __repr__ method is too verbose and doesn't provide any useful information.
     @override
@@ -193,10 +252,8 @@ class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
         hparams.environment = EnvironmentConfig.from_current_environment(hparams, self)
         hparams = self.pre_init_update_hparams(hparams)
 
-        super().__init__(
-
+        super().__init__()
         self.save_hyperparameters(hparams)
-        self.register_callback(lambda: DebugFlagCallback())
 
     def zero_loss(self):
         """
```
|