nshtrainer 0.9.1__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -1
- nshtrainer/_checkpoint/loader.py +319 -0
- nshtrainer/_checkpoint/metadata.py +102 -0
- nshtrainer/callbacks/__init__.py +17 -1
- nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
- nshtrainer/callbacks/base.py +7 -5
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
- nshtrainer/callbacks/model_checkpoint.py +187 -0
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/throughput_monitor.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/ll/__init__.py +0 -1
- nshtrainer/ll/actsave.py +2 -1
- nshtrainer/metrics/__init__.py +1 -0
- nshtrainer/metrics/_config.py +37 -0
- nshtrainer/model/__init__.py +11 -11
- nshtrainer/model/_environment.py +777 -0
- nshtrainer/model/base.py +5 -114
- nshtrainer/model/config.py +49 -501
- nshtrainer/model/modules/logger.py +11 -6
- nshtrainer/runner.py +3 -6
- nshtrainer/trainer/_runtime_callback.py +120 -0
- nshtrainer/trainer/checkpoint_connector.py +63 -0
- nshtrainer/trainer/signal_connector.py +12 -9
- nshtrainer/trainer/trainer.py +111 -31
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/METADATA +3 -1
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/RECORD +34 -27
- nshtrainer/actsave/__init__.py +0 -3
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/WHEEL +0 -0
nshtrainer/model/modules/logger.py CHANGED

@@ -11,10 +11,14 @@ from lightning.pytorch.utilities.types import _METRIC
 from lightning_utilities.core.rank_zero import rank_zero_warn
 from typing_extensions import override
 
-from ...actsave import ActSave
 from ...util.typing_utils import mixin_base_type
 from ..config import BaseConfig
 
+try:
+    from nshutils import ActSave  # type: ignore
+except ImportError:
+    ActSave = None
+
 
 @dataclass(frozen=True, kw_only=True)
 class _LogContext:
@@ -155,14 +159,15 @@ class LoggerLightningModuleMixin(LoggerModuleMixin, mixin_base_type(LightningMod
 
     def _logger_actsave(self, name: str, value: _METRIC) -> None:
         hparams = cast(BaseConfig, self.hparams)
-        if
-
-
-
+        if not hparams.trainer.logging.actsave_logged_metrics:
+            return
+
+        if ActSave is None:
+            rank_zero_warn("ActSave is not available, skipping logging of metrics")
            return
 
        ActSave.save(
-            {
+            lambda: {
                f"logger.{name}": lambda: value.compute()
                if isinstance(value, torchmetrics.Metric)
                else value
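Taken together, these two hunks make ActSave an optional dependency and defer all metric work until a save is actually performed: the mapping handed to ActSave.save is now built inside a lambda, so value.compute() never runs while activation saving is disabled. A minimal sketch of that deferred-save pattern, using a hypothetical maybe_save helper in place of nshutils.ActSave.save (not the real API):

from collections.abc import Callable, Mapping
from typing import Any

# Hypothetical stand-in for an ActSave-style saver: the lazily built mapping
# (and any per-value thunks inside it) is only evaluated when saving is enabled.
SAVING_ENABLED = False


def maybe_save(build: Callable[[], Mapping[str, Any]]) -> None:
    if not SAVING_ENABLED:
        return  # nothing is computed at all while saving is off
    for key, value in build().items():
        print(key, value() if callable(value) else value)


def log_metric(name: str, value: float) -> None:
    # Mirrors the diff: the outer lambda defers building the dict, the inner
    # lambda defers the (potentially expensive) metric computation.
    maybe_save(lambda: {f"logger.{name}": lambda: value * 1.0})


log_metric("loss", 0.25)  # prints nothing while SAVING_ENABLED is False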
nshtrainer/runner.py CHANGED

@@ -5,6 +5,7 @@ from typing import Generic
 
 from nshrunner import RunInfo
 from nshrunner import Runner as _Runner
+from nshrunner._submit import screen
 from nshrunner.snapshot import SnapshotArgType
 from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
 
@@ -89,6 +90,7 @@ class Runner(
     def fast_dev_run_session(
         self,
         runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
+        options: screen.ScreenJobKwargs = {},
         n_batches: int = 1,
         *,
         snapshot: SnapshotArgType,
@@ -99,10 +101,7 @@ class Runner(
         ]
         | None = None,
         activate_venv: bool = True,
-        session_name: str = "nshrunner",
-        attach: bool = True,
         print_command: bool = True,
-        pause_before_exit: bool = False,
     ):
        transforms = transforms or []
        transforms.append(
@@ -110,13 +109,11 @@ class Runner(
        )
        return self.session(
            runs,
+            options,
            snapshot=snapshot,
            setup_commands=setup_commands,
            env=env,
            transforms=transforms,
            activate_venv=activate_venv,
-            session_name=session_name,
-            attach=attach,
            print_command=print_command,
-            pause_before_exit=pause_before_exit,
        )
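The three screen-session keyword arguments removed here (session_name, attach, pause_before_exit) are evidently meant to travel inside the new options mapping typed as screen.ScreenJobKwargs. A sketch of that "bundle the knobs into a typed options dict" shape, with a hypothetical TypedDict and function standing in for nshrunner's actual ScreenJobKwargs and session API:

from typing import TypedDict


class ScreenJobOptions(TypedDict, total=False):
    # Hypothetical keys, mirroring the keyword arguments removed in this diff;
    # the real nshrunner ScreenJobKwargs may differ.
    session_name: str
    attach: bool
    pause_before_exit: bool


def run_in_screen_session(runs: list[str], options: ScreenJobOptions | None = None) -> None:
    # Unset keys simply fall back to the callee's defaults.
    opts = options or {}
    name = opts.get("session_name", "nshrunner")
    print(f"running {runs} in screen session {name!r} (attach={opts.get('attach', True)})")


run_in_screen_session(["fast-dev-run"], {"session_name": "dev", "attach": False})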
nshtrainer/trainer/_runtime_callback.py ADDED

@@ -0,0 +1,120 @@
+import datetime
+import logging
+import time
+from dataclasses import dataclass
+from typing import Any, Literal, TypeAlias
+
+from lightning.pytorch.callbacks.callback import Callback
+from typing_extensions import override
+
+log = logging.getLogger(__name__)
+
+Stage: TypeAlias = Literal["train", "validate", "test", "predict"]
+ALL_STAGES = ("train", "validate", "test", "predict")
+
+
+@dataclass
+class TimeInfo:
+    datetime: datetime.datetime
+    monotonic: float
+
+
+class RuntimeTrackerCallback(Callback):
+    def __init__(self):
+        super().__init__()
+        self._start_time: dict[Stage, TimeInfo] = {}
+        self._end_time: dict[Stage, TimeInfo] = {}
+        self._offsets = {stage: datetime.timedelta() for stage in ALL_STAGES}
+
+    def start_time(self, stage: Stage) -> TimeInfo | None:
+        """Return the start time of a particular stage"""
+        return self._start_time.get(stage)
+
+    def end_time(self, stage: Stage) -> TimeInfo | None:
+        """Return the end time of a particular stage"""
+        return self._end_time.get(stage)
+
+    def time_elapsed(self, stage: Stage) -> datetime.timedelta:
+        """Return the time elapsed for a particular stage"""
+        start = self.start_time(stage)
+        end = self.end_time(stage)
+        offset = self._offsets[stage]
+        if start is None:
+            return offset
+        if end is None:
+            current = TimeInfo(datetime.datetime.now(), time.monotonic())
+            return (
+                datetime.timedelta(seconds=current.monotonic - start.monotonic) + offset
+            )
+        return datetime.timedelta(seconds=end.monotonic - start.monotonic) + offset
+
+    def _record_time(self, stage: Stage, time_dict: dict[Stage, TimeInfo]):
+        time_dict[stage] = TimeInfo(datetime.datetime.now(), time.monotonic())
+
+    @override
+    def on_train_start(self, trainer, pl_module):
+        self._record_time("train", self._start_time)
+
+    @override
+    def on_train_end(self, trainer, pl_module):
+        self._record_time("train", self._end_time)
+
+    @override
+    def on_validation_start(self, trainer, pl_module):
+        self._record_time("validate", self._start_time)
+
+    @override
+    def on_validation_end(self, trainer, pl_module):
+        self._record_time("validate", self._end_time)
+
+    @override
+    def on_test_start(self, trainer, pl_module):
+        self._record_time("test", self._start_time)
+
+    @override
+    def on_test_end(self, trainer, pl_module):
+        self._record_time("test", self._end_time)
+
+    @override
+    def on_predict_start(self, trainer, pl_module):
+        self._record_time("predict", self._start_time)
+
+    @override
+    def on_predict_end(self, trainer, pl_module):
+        self._record_time("predict", self._end_time)
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "time_elapsed": {
+                stage: self.time_elapsed(stage).total_seconds() for stage in ALL_STAGES
+            },
+            "start_times": {
+                stage: (info.datetime.isoformat(), info.monotonic)
+                for stage, info in self._start_time.items()
+            },
+            "end_times": {
+                stage: (info.datetime.isoformat(), info.monotonic)
+                for stage, info in self._end_time.items()
+            },
+        }
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]):
+        time_elapsed: dict[Stage, float] = state_dict.get("time_elapsed", {})
+        for stage in ALL_STAGES:
+            self._offsets[stage] = datetime.timedelta(
+                seconds=time_elapsed.get(stage, 0)
+            )
+
+        start_times: dict[Stage, tuple[str, float]] = state_dict.get("start_times", {})
+        for stage, (dt_str, monotonic) in start_times.items():
+            self._start_time[stage] = TimeInfo(
+                datetime.datetime.fromisoformat(dt_str), monotonic
+            )
+
+        end_times: dict[Stage, tuple[str, float]] = state_dict.get("end_times", {})
+        for stage, (dt_str, monotonic) in end_times.items():
+            self._end_time[stage] = TimeInfo(
+                datetime.datetime.fromisoformat(dt_str), monotonic
+            )
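A short sketch of how the new callback behaves when its hooks are driven by hand rather than by a Lightning Trainer; the import path is the file added above (assuming nshtrainer 0.10.1 is installed), and the resume example feeds load_state_dict a hand-written payload to show that persisted totals become offsets:

import time

from nshtrainer.trainer._runtime_callback import RuntimeTrackerCallback

cb = RuntimeTrackerCallback()
cb.on_train_start(trainer=None, pl_module=None)
time.sleep(0.1)
cb.on_train_end(trainer=None, pl_module=None)

# Elapsed time is computed from the monotonic clock, so it is immune to
# wall-clock adjustments; the datetime field is only kept for reporting.
print(cb.time_elapsed("train"))                  # ~0:00:00.100000
print(cb.state_dict()["time_elapsed"]["train"])  # ~0.1 (seconds, as a float)

# On resume, load_state_dict turns previously persisted totals into offsets
# that future measurements are added onto.
resumed = RuntimeTrackerCallback()
resumed.load_state_dict({"time_elapsed": {"train": 42.0}})
print(resumed.time_elapsed("train"))             # 0:00:42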
nshtrainer/trainer/checkpoint_connector.py ADDED

@@ -0,0 +1,63 @@
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, cast
+
+from lightning.pytorch.trainer.connectors.checkpoint_connector import (
+    _CheckpointConnector,
+)
+from lightning.pytorch.trainer.states import TrainerFn
+from typing_extensions import override
+
+from .._checkpoint.loader import CheckpointLoadingConfig, _resolve_checkpoint
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+log = logging.getLogger(__name__)
+
+
+class CheckpointConnector(_CheckpointConnector):
+    def __resolve_auto_ckpt_path(
+        self,
+        ckpt_path: str | Path | None,
+        state_fn: TrainerFn,
+    ):
+        from .trainer import Trainer
+
+        # If this isn't an `nshtrainer` trainer (which I don't know why it wouldn't be),
+        # then we just default to the parent class's implementation of `_parse_ckpt_path`.
+        trainer = self.trainer
+        if not isinstance(trainer, Trainer):
+            return None
+
+        # Now, resolve the checkpoint loader config.
+        root_config = cast("BaseConfig", trainer._base_module.config)
+        if (ckpt_loader_config := root_config.trainer.checkpoint_loading) == "auto":
+            ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
+        log.debug(f"Checkpoint loader config: {ckpt_loader_config}")
+
+        # Use the config to resolve the checkpoint.
+        if (
+            ckpt_path := _resolve_checkpoint(ckpt_loader_config, root_config, trainer)
+        ) is None:
+            log.info(
+                "No checkpoint found for the current trainer state. "
+                "Training will start from scratch."
+            )
+
+        log.info(f"Loading checkpoint from: {ckpt_path}")
+        return ckpt_path
+
+    @override
+    def _parse_ckpt_path(
+        self,
+        state_fn: TrainerFn,
+        ckpt_path: str | Path | None,
+        model_provided: bool,
+        model_connected: bool,
+    ):
+        if (p := self.__resolve_auto_ckpt_path(ckpt_path, state_fn)) is not None:
+            return p
+
+        return super()._parse_ckpt_path(
+            state_fn, ckpt_path, model_provided, model_connected
+        )
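The override is written so that Lightning's stock behaviour remains the fallback: the private resolver returns None whenever it declines to decide, and only then does _parse_ckpt_path defer to the parent implementation. A toy sketch of that resolve-or-fall-back shape, with made-up class names (not the Lightning connector API):

from pathlib import Path


class BaseConnector:
    def parse_ckpt_path(self, ckpt_path: str | None) -> str | None:
        # Stand-in for the upstream behaviour: pass the value through unchanged.
        return ckpt_path


class AutoConnector(BaseConnector):
    def __init__(self, search_dir: Path) -> None:
        self.search_dir = search_dir

    def _resolve_auto(self, ckpt_path: str | None) -> str | None:
        # Return None whenever we cannot decide, so the caller falls back.
        if ckpt_path is not None:
            return None
        candidates = sorted(self.search_dir.glob("*.ckpt"))
        return str(candidates[-1]) if candidates else None

    def parse_ckpt_path(self, ckpt_path: str | None) -> str | None:
        if (p := self._resolve_auto(ckpt_path)) is not None:
            return p
        return super().parse_ckpt_path(ckpt_path)


print(AutoConnector(Path(".")).parse_ckpt_path("explicit.ckpt"))  # explicit.ckpt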
nshtrainer/trainer/signal_connector.py CHANGED

@@ -11,6 +11,7 @@ from pathlib import Path
 from types import FrameType
 from typing import Any, TypeAlias
 
+import nshrunner as nr
 import torch.utils.data
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
@@ -27,20 +28,22 @@ _HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handler
 
 
 def _resolve_requeue_signals():
-
-
-    if timeout_signal_name := os.environ.get("NSHRUNNER_TIMEOUT_SIGNAL"):
-        signals.append(signal.Signals[timeout_signal_name])
-
-    if preempt_signal_name := os.environ.get("NSHRUNNER_PREEMPT_SIGNAL"):
-        signals.append(signal.Signals[preempt_signal_name])
+    if (session := nr.Session.from_current_session()) is None:
+        return None
 
+    signals: list[signal.Signals] = []
+    if session.submit_timeout_signal:
+        signals.append(session.submit_timeout_signal)
+    if session.submit_preempt_signal:
+        signals.append(session.submit_preempt_signal)
     return signals
 
 
 class _SignalConnector(_LightningSignalConnector):
-    def _auto_requeue_signals(self) -> list[signal.Signals]:
-        signals
+    def _auto_requeue_signals(self) -> list[signal.Signals] | None:
+        if not (signals := _resolve_requeue_signals()):
+            return None
+
        signals_set = set(signals)
        valid_signals: set[signal.Signals] = signal.valid_signals()
        assert signals_set.issubset(
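Requeue signals are no longer parsed out of the NSHRUNNER_TIMEOUT_SIGNAL / NSHRUNNER_PREEMPT_SIGNAL environment variables; they now come from the active nshrunner session object, and _auto_requeue_signals returns None when there is no session or it declares no signals. For reference, a self-contained sketch of the retired name-to-signal lookup that the removed lines performed:

import os
import signal

# Emulate the environment nshrunner used to set, e.g. NSHRUNNER_TIMEOUT_SIGNAL=SIGUSR1.
os.environ["NSHRUNNER_TIMEOUT_SIGNAL"] = "SIGUSR1"

signals: list[signal.Signals] = []
if timeout_signal_name := os.environ.get("NSHRUNNER_TIMEOUT_SIGNAL"):
    # signal.Signals is an enum, so indexing by name yields the signal member.
    signals.append(signal.Signals[timeout_signal_name])

print(signals)  # [<Signals.SIGUSR1: ...>] on POSIX platforms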
nshtrainer/trainer/trainer.py CHANGED

@@ -3,7 +3,7 @@ import logging
 import os
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, cast
 
 import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
@@ -11,11 +11,13 @@ from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
 from lightning.pytorch import LightningModule
 from lightning.pytorch import Trainer as LightningTrainer
+from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.profilers import Profiler
+from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
 from typing_extensions import Unpack, assert_never, override
 
-from ..
+from .._checkpoint.metadata import _write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
 from ..model.config import (
     AcceleratorConfigProtocol,
@@ -24,6 +26,7 @@ from ..model.config import (
     LightningTrainerKwargs,
     StrategyConfigProtocol,
 )
+from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .signal_connector import _SignalConnector
 
 log = logging.getLogger(__name__)
@@ -168,12 +171,12 @@ class Trainer(LightningTrainer):
 
        if (accelerator := config.trainer.accelerator) is not None:
            if isinstance(accelerator, AcceleratorConfigProtocol):
-                accelerator = accelerator.
+                accelerator = accelerator.create_accelerator()
            _update_kwargs(accelerator=accelerator)
 
        if (strategy := config.trainer.strategy) is not None:
            if isinstance(strategy, StrategyConfigProtocol):
-                strategy = strategy.
+                strategy = strategy.create_strategy()
            _update_kwargs(strategy=strategy)
 
        if (precision := config.trainer.precision) is not None:
@@ -220,7 +223,7 @@ class Trainer(LightningTrainer):
        if profiler := config.trainer.profiler:
            # If the profiler is an ProfilerConfig instance, then we instantiate it.
            if isinstance(profiler, BaseProfilerConfig):
-                profiler = profiler.
+                profiler = profiler.create_profiler(config)
            # Make sure that the profiler is an instance of `Profiler`.
            if not isinstance(profiler, Profiler):
                raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
@@ -236,7 +239,7 @@ class Trainer(LightningTrainer):
        if plugin_configs := config.trainer.plugins:
            _update_kwargs(
                plugins=[
-                    plugin_config.
+                    plugin_config.create_plugin() for plugin_config in plugin_configs
                ]
            )
 
@@ -244,7 +247,7 @@ class Trainer(LightningTrainer):
            log.critical(f"Disabling logger because {config.trainer.logging.enabled=}.")
            kwargs["logger"] = False
        else:
-            _update_kwargs(logger=config.trainer.logging.
+            _update_kwargs(logger=config.trainer.logging.create_loggers(config))
 
        if config.trainer.auto_determine_num_nodes:
            # When num_nodes is auto, we need to detect the number of nodes.
@@ -275,6 +278,9 @@ class Trainer(LightningTrainer):
 
        return kwargs
 
+    if TYPE_CHECKING:
+        callbacks: list[Callback]
+
    @override
    def __init__(
        self,
@@ -282,12 +288,14 @@ class Trainer(LightningTrainer):
        /,
        **kwargs: Unpack[LightningTrainerKwargs],
    ):
-        self._ll_config = config
        kwargs = self._update_kwargs(config, kwargs)
        log.critical(f"LightningTrainer.__init__ with {kwargs=}.")
 
        super().__init__(**kwargs)
 
+        # Add our own start time callback to measure the start time.
+        self.callbacks.append(RuntimeTrackerCallback())
+
        # Replace the signal connector with our own.
        self._signal_connector = _SignalConnector(self)
 
@@ -296,34 +304,89 @@ class Trainer(LightningTrainer):
        log_dir = str(Path(log_dir).resolve())
        log.critical(f"LightningTrainer log directory: {self.log_dir}.")
 
-
-
-
-
-
+    def __runtime_tracker(self):
+        return next(
+            (
+                callback
+                for callback in self.callbacks
+                if isinstance(callback, RuntimeTrackerCallback)
+            ),
+            None,
+        )
+
+    def __current_stage(self) -> Stage:
+        match self.state.fn:
+            case None:
+                raise ValueError(
+                    "Trainer state function is not set. "
+                    "You must call `fit`, `validate`, `test`, or `predict`, "
+                    "or explicitly provide a stage."
+                )
+            case TrainerFn.FITTING:
+                return "train"
+            case TrainerFn.VALIDATING:
+                return "validate"
+            case TrainerFn.TESTING:
+                return "test"
+            case TrainerFn.PREDICTING:
+                return "predict"
+            case _:
+                assert_never(self.state.fn)
+
+    def start_time(self, stage: Stage | None = None):
+        """Return the start time of the run"""
+        if (tracker := self.__runtime_tracker()) is None:
+            raise ValueError(
+                "RuntimeTrackerCallback is not set. Cannot get start time."
+            )
+        if stage is None:
+            stage = self.__current_stage()
 
-
-
-
-
-
-
+        return tracker.start_time(stage)
+
+    def end_time(self, stage: Stage | None = None):
+        """Return the end time of the run"""
+        if (tracker := self.__runtime_tracker()) is None:
+            raise ValueError(
+                "RuntimeTrackerCallback is not set. Cannot get start time."
+            )
+        if stage is None:
+            stage = self.__current_stage()
+
+        return tracker.end_time(stage)
+
+    def time_elapsed(self, stage: Stage | None = None):
+        """Return the time elapsed for the run"""
+        if (tracker := self.__runtime_tracker()) is None:
+            raise ValueError(
+                "RuntimeTrackerCallback is not set. Cannot get start time."
+            )
+        if stage is None:
+            stage = self.__current_stage()
 
-
-
-
+        return tracker.time_elapsed(stage)
+
+    @property
+    def _base_module(self):
+        if self.lightning_module is None:
+            raise ValueError("LightningModule is not set.")
+
+        from ..model.base import LightningModuleBase
+
+        if not isinstance(self.lightning_module, LightningModuleBase):
+            raise ValueError(
+                f"LightningModule is not an instance of {LightningModuleBase}."
+            )
+
+        return self.lightning_module
 
    @override
    def _run(
        self, model: LightningModule, ckpt_path: str | Path | None = None
    ) -> _EVALUATE_OUTPUT | _PREDICT_OUTPUT | None:
-        """
-
-
-        We patch the `Trainer._run` method to throw if gradient clipping is enabled
-        and `model.automatic_optimization` is False.
-
-        2. We actually set up actsave here.
+        """Lightning doesn't support gradient clipping with manual optimization.
+        We patch the `Trainer._run` method to throw if gradient clipping is enabled
+        and `model.automatic_optimization` is False.
        """
 
        if not model.automatic_optimization and (
@@ -336,5 +399,22 @@ class Trainer(LightningTrainer):
                "or disable automatic gradient clipping. "
            )
 
-
-
+        return super()._run(model, ckpt_path)
+
+    @override
+    def save_checkpoint(
+        self,
+        filepath: str | Path,
+        weights_only: bool = False,
+        storage_options: Any | None = None,
+    ):
+        filepath = Path(filepath)
+        ret_val = super().save_checkpoint(filepath, weights_only, storage_options)
+
+        # Save the checkpoint metadata
+        lm = self._base_module
+        if lm.config.trainer.save_checkpoint_metadata and self.is_global_zero:
+            # Generate the metadata and write to disk
+            _write_checkpoint_metadata(self, lm, filepath)
+
+        return ret_val
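Two small patterns from this file, shown standalone: the runtime tracker is located with next(..., None) over the callback list, and the public start_time/end_time/time_elapsed helpers raise if it is missing before delegating to it. A dependency-free sketch (RuntimeTracker here is a stand-in, not the class from _runtime_callback.py):

import datetime


class RuntimeTracker:
    """Stand-in for RuntimeTrackerCallback, returning a fixed value."""

    def time_elapsed(self, stage: str) -> datetime.timedelta:
        return datetime.timedelta(seconds=12.5)


class OtherCallback:
    pass


callbacks = [OtherCallback(), RuntimeTracker()]

# Same shape as Trainer.__runtime_tracker / Trainer.time_elapsed: find the first
# tracker among the callbacks, fail loudly if it is missing, then delegate.
if (tracker := next((c for c in callbacks if isinstance(c, RuntimeTracker)), None)) is None:
    raise ValueError("RuntimeTrackerCallback is not set. Cannot get start time.")
print(tracker.time_elapsed("train"))  # 0:00:12.500000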
{nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.9.1
+Version: 0.10.1
 Summary: 
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -9,11 +9,13 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Requires-Dist: GitPython
 Requires-Dist: lightning
 Requires-Dist: nshconfig
 Requires-Dist: nshrunner
 Requires-Dist: nshutils
 Requires-Dist: numpy
+Requires-Dist: psutil
 Requires-Dist: pytorch-lightning
 Requires-Dist: torch
 Requires-Dist: torchmetrics
{nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/RECORD CHANGED

@@ -1,32 +1,34 @@
-nshtrainer/__init__.py,sha256=
+nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
+nshtrainer/_checkpoint/loader.py,sha256=48flPr1XgQHOgIPaCrRqOEvRuG0SZuV3cQ1vgHLqFqI,11025
+nshtrainer/_checkpoint/metadata.py,sha256=C7je_soYyEbZjiq7p2_pSVFkgcXnz2J2H5sMy8oskx0,3051
 nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
 nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
 nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
 nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
-nshtrainer/
-nshtrainer/actsave/_callback.py,sha256=mnHOtuG9vtHEzz9q4vCvDNC6VvjZsgb4MSSuOoUDh3M,2778
-nshtrainer/callbacks/__init__.py,sha256=I6W33ityL9Ko8jjqHh3WH_8miV59SAe9LxInhoqX5XE,1665
+nshtrainer/callbacks/__init__.py,sha256=ifXQRwtccznl4lMKwKLSuuAQC4bKFBgfzQ4rx9gOqjE,2345
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
-nshtrainer/callbacks/
+nshtrainer/callbacks/actsave.py,sha256=aY6T_NAzaFAVU8WMHOXnWL5wd2bi8eVxeU2S0iAs70c,4446
+nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
 nshtrainer/callbacks/early_stopping.py,sha256=jriSU761wf_qTJ9Bos0D3h5aDvZHYpRqK62Ne8aWp5I,3768
-nshtrainer/callbacks/ema.py,sha256=
-nshtrainer/callbacks/finite_checks.py,sha256=
-nshtrainer/callbacks/gradient_skipping.py,sha256=
+nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
+nshtrainer/callbacks/finite_checks.py,sha256=AO5fa51uANAjAkeJfTquOjK6W_4RSU5Kky3f5jmAPlQ,2084
+nshtrainer/callbacks/gradient_skipping.py,sha256=fSJpjgHbztFKz7w3qFuCHZpmbEt9BCLAy-sU0B4xJQI,3474
 nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
-nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=
+nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=p0zeDK3PLWWl485e9o08ywEEARCfuZ5it47tNCtR4ec,2838
 nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
-nshtrainer/callbacks/
-nshtrainer/callbacks/
-nshtrainer/callbacks/
-nshtrainer/callbacks/
-nshtrainer/callbacks/
-nshtrainer/callbacks/
+nshtrainer/callbacks/model_checkpoint.py,sha256=4zYycpXHGRyL4svWLP6GmG3WJs5m3B5PRCOzXC3m_qg,5955
+nshtrainer/callbacks/norm_logging.py,sha256=EWyrfkp8iHjQi9iAAXHxb0xStw2RwkdpKG2_gLarQRA,6281
+nshtrainer/callbacks/on_exception_checkpoint.py,sha256=zna_QF_x4HwD7Es5XxrHLDED43NU1GpcDNoL139HEOs,3355
+nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
+nshtrainer/callbacks/throughput_monitor.py,sha256=4EF3b79HdHiRgBGIFDyD4O1oywb5h1tV8nml7NuuDjU,1845
+nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
+nshtrainer/callbacks/wandb_watch.py,sha256=bicXS3nZfPGoN7Owu1XIBS-1bw7yeIJdYJTnRN0dp2E,2934
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=bcJBcQjh1hB1yKF_xSlT9AtEWv0BJjYc1CuH2BF-ea8,4392
 nshtrainer/data/transform.py,sha256=JeGxvytQly8hougrsdMmKG8gJ6qvFPDglJCO4Tp6STk,1795
-nshtrainer/ll/__init__.py,sha256=
+nshtrainer/ll/__init__.py,sha256=dD0ISxHJ2lg1HLSM0b3db7TBlsPpQCtChnuYO-c2oqI,2635
 nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
-nshtrainer/ll/actsave.py,sha256=
+nshtrainer/ll/actsave.py,sha256=2lbiseSrjcwFT6AiyLNWarTWl1bnzliVWlu1iOfnP30,209
 nshtrainer/ll/callbacks.py,sha256=AxyUmc8aGRSjx6WwwgXYCmdJ73rwLuEAEH0AGRosojQ,49
 nshtrainer/ll/config.py,sha256=fKumJf42HY2FITX1QUM1OTXkYD6U2np2ciyd4PFRPZ8,145
 nshtrainer/ll/data.py,sha256=zRG0FRje-jtSHximVzkHIHzpwsyQxpHCoACFihNKLPM,44
@@ -45,13 +47,16 @@ nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzq
 nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
 nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9btIxMRWigUHUTlUYCSw,5221
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
-nshtrainer/
-nshtrainer/
-nshtrainer/model/
+nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
+nshtrainer/metrics/_config.py,sha256=hWWS4IXENRyH3RmJ7z1Wx1n3Lt1sNMlGOrcU6PW15o0,1104
+nshtrainer/model/__init__.py,sha256=TbexTxiE20WHYg5q3L88Hysk4LlHeKk_isv33aSBREA,1918
+nshtrainer/model/_environment.py,sha256=s3JFnigbssFRJTwH33K7DcAYVhLOFCC1OZgFNXJgjuw,22317
+nshtrainer/model/base.py,sha256=Bmw-t70TydDbE9P0ee-lTibGoUhrCx5Qke-upa7FGVM,17512
+nshtrainer/model/config.py,sha256=B1XkKYbhpAm6RmF4n4eR66hMh-kCXwIQB2pQuhR9TZE,53177
 nshtrainer/model/modules/callback.py,sha256=JF59U9-CjJsAIspEhTJbVaGN0wGctZG7UquE3IS7R8A,6408
 nshtrainer/model/modules/debug.py,sha256=DTVty8cKnzj1GCULRyGx_sWTTsq9NLi30dzqjRTnuCU,1127
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
-nshtrainer/model/modules/logger.py,sha256=
+nshtrainer/model/modules/logger.py,sha256=YYhehQysqTjuVFcd_EREYDh57CIlezidFBS2Ohp_xKo,5661
 nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
 nshtrainer/model/modules/rlp_sanity_checks.py,sha256=o6gUceFwsuDHmL8eLOYuT3JGXFzq_qc4awl2RWaBygU,8900
 nshtrainer/model/modules/shared_parameters.py,sha256=mD5wrlBE3c025vzVdTpnSyC8yxzuI-aUWMmPhqPT0a0,2694
@@ -61,17 +66,19 @@ nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,
 nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
 nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
 nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
-nshtrainer/runner.py,sha256=
+nshtrainer/runner.py,sha256=6qfE5FBONzD79kVHuWYKEvK0J_Qi5dMBbHQhRMmnIhE,3649
 nshtrainer/scripts/check_env.py,sha256=IMl6dSqsLYppI0XuCsVq8lK4bYqXwY9KHJkzsShz4Kg,806
 nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwBfBeWeA,1467
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
-nshtrainer/trainer/
-nshtrainer/trainer/
+nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
+nshtrainer/trainer/checkpoint_connector.py,sha256=xoqI2dcPnlNFPPLVIU6dBOvRPC9PtfX5qu__xV1lx0Y,2124
+nshtrainer/trainer/signal_connector.py,sha256=llwc8pdKAWxREFpjdi14Bpy8rGVMEJsmJx_s2p4gI8E,10689
+nshtrainer/trainer/trainer.py,sha256=n3T9Iz3eaDostxEdjapWImAsVMxyU9WBdhlPl0THX-g,16785
 nshtrainer/util/environment.py,sha256=_SEtiQ_s5bL5pllUlf96AOUv15kNvCPvocVC13S7mIk,4166
 nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.
-nshtrainer-0.
-nshtrainer-0.
+nshtrainer-0.10.1.dist-info/METADATA,sha256=O8wMPb0ksoZajyes8dsq4IIjsfP_jQaxGYpW3rYE9Ro,695
+nshtrainer-0.10.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.10.1.dist-info/RECORD,,
nshtrainer/actsave/__init__.py DELETED

File without changes