nshtrainer 0.8.7__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -1
- nshtrainer/callbacks/__init__.py +17 -1
- nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
- nshtrainer/callbacks/base.py +7 -5
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
- nshtrainer/callbacks/model_checkpoint.py +187 -0
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/throughput_monitor.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/ll/__init__.py +0 -1
- nshtrainer/ll/actsave.py +2 -1
- nshtrainer/metrics/__init__.py +1 -0
- nshtrainer/metrics/_config.py +37 -0
- nshtrainer/model/__init__.py +11 -11
- nshtrainer/model/_environment.py +777 -0
- nshtrainer/model/base.py +5 -114
- nshtrainer/model/config.py +92 -507
- nshtrainer/model/modules/logger.py +11 -6
- nshtrainer/runner.py +3 -6
- nshtrainer/trainer/_checkpoint_metadata.py +102 -0
- nshtrainer/trainer/_checkpoint_resolver.py +319 -0
- nshtrainer/trainer/_runtime_callback.py +120 -0
- nshtrainer/trainer/checkpoint_connector.py +63 -0
- nshtrainer/trainer/signal_connector.py +12 -9
- nshtrainer/trainer/trainer.py +111 -31
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
- nshtrainer/actsave/__init__.py +0 -3
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
|
@@ -11,10 +11,14 @@ from lightning.pytorch.utilities.types import _METRIC
|
|
|
11
11
|
from lightning_utilities.core.rank_zero import rank_zero_warn
|
|
12
12
|
from typing_extensions import override
|
|
13
13
|
|
|
14
|
-
from ...actsave import ActSave
|
|
15
14
|
from ...util.typing_utils import mixin_base_type
|
|
16
15
|
from ..config import BaseConfig
|
|
17
16
|
|
|
17
|
+
try:
|
|
18
|
+
from nshutils import ActSave # type: ignore
|
|
19
|
+
except ImportError:
|
|
20
|
+
ActSave = None
|
|
21
|
+
|
|
18
22
|
|
|
19
23
|
@dataclass(frozen=True, kw_only=True)
|
|
20
24
|
class _LogContext:
|
|
@@ -155,14 +159,15 @@ class LoggerLightningModuleMixin(LoggerModuleMixin, mixin_base_type(LightningMod
|
|
|
155
159
|
|
|
156
160
|
def _logger_actsave(self, name: str, value: _METRIC) -> None:
|
|
157
161
|
hparams = cast(BaseConfig, self.hparams)
|
|
158
|
-
if
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
+
if not hparams.trainer.logging.actsave_logged_metrics:
|
|
163
|
+
return
|
|
164
|
+
|
|
165
|
+
if ActSave is None:
|
|
166
|
+
rank_zero_warn("ActSave is not available, skipping logging of metrics")
|
|
162
167
|
return
|
|
163
168
|
|
|
164
169
|
ActSave.save(
|
|
165
|
-
{
|
|
170
|
+
lambda: {
|
|
166
171
|
f"logger.{name}": lambda: value.compute()
|
|
167
172
|
if isinstance(value, torchmetrics.Metric)
|
|
168
173
|
else value
|
nshtrainer/runner.py
CHANGED
|
@@ -5,6 +5,7 @@ from typing import Generic
|
|
|
5
5
|
|
|
6
6
|
from nshrunner import RunInfo
|
|
7
7
|
from nshrunner import Runner as _Runner
|
|
8
|
+
from nshrunner._submit import screen
|
|
8
9
|
from nshrunner.snapshot import SnapshotArgType
|
|
9
10
|
from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
|
|
10
11
|
|
|
@@ -89,6 +90,7 @@ class Runner(
|
|
|
89
90
|
def fast_dev_run_session(
|
|
90
91
|
self,
|
|
91
92
|
runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
|
|
93
|
+
options: screen.ScreenJobKwargs = {},
|
|
92
94
|
n_batches: int = 1,
|
|
93
95
|
*,
|
|
94
96
|
snapshot: SnapshotArgType,
|
|
@@ -99,10 +101,7 @@ class Runner(
|
|
|
99
101
|
]
|
|
100
102
|
| None = None,
|
|
101
103
|
activate_venv: bool = True,
|
|
102
|
-
session_name: str = "nshrunner",
|
|
103
|
-
attach: bool = True,
|
|
104
104
|
print_command: bool = True,
|
|
105
|
-
pause_before_exit: bool = False,
|
|
106
105
|
):
|
|
107
106
|
transforms = transforms or []
|
|
108
107
|
transforms.append(
|
|
@@ -110,13 +109,11 @@ class Runner(
|
|
|
110
109
|
)
|
|
111
110
|
return self.session(
|
|
112
111
|
runs,
|
|
112
|
+
options,
|
|
113
113
|
snapshot=snapshot,
|
|
114
114
|
setup_commands=setup_commands,
|
|
115
115
|
env=env,
|
|
116
116
|
transforms=transforms,
|
|
117
117
|
activate_venv=activate_venv,
|
|
118
|
-
session_name=session_name,
|
|
119
|
-
attach=attach,
|
|
120
118
|
print_command=print_command,
|
|
121
|
-
pause_before_exit=pause_before_exit,
|
|
122
119
|
)
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import datetime
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
6
|
+
|
|
7
|
+
import nshconfig as C
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from ..model._environment import EnvironmentConfig
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from ..model import BaseConfig, LightningModuleBase
|
|
15
|
+
from .trainer import Trainer
|
|
16
|
+
|
|
17
|
+
log = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
METADATA_PATH_SUFFIX = ".metadata.json"
|
|
21
|
+
HPARAMS_PATH_SUFFIX = ".hparams.json"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class CheckpointMetadata(C.Config):
    # Sidecar record describing a single saved checkpoint. Instances are
    # serialized to JSON next to the checkpoint file and read back with
    # `from_file`.

    # Where the checkpoint lives and the bare file name of that checkpoint.
    checkpoint_path: Path
    checkpoint_filename: str

    # Identity of the run that produced the checkpoint.
    run_id: str
    name: str
    project: str | None

    # Wall-clock time of the save, and of the run start when it is known.
    checkpoint_timestamp: datetime.datetime
    start_timestamp: datetime.datetime | None

    # Training progress and logged metric values at save time.
    epoch: int
    global_step: int
    training_time: datetime.timedelta
    metrics: dict[str, Any]

    # Environment configuration copied from the run's root config.
    environment: EnvironmentConfig

    @classmethod
    def from_file(cls, path: Path):
        """Parse and validate the metadata JSON stored at `path`."""
        raw_json = path.read_text()
        return cls.model_validate_json(raw_json)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _generate_checkpoint_metadata(
|
|
46
|
+
config: "BaseConfig", trainer: "Trainer", checkpoint_path: Path
|
|
47
|
+
):
|
|
48
|
+
checkpoint_timestamp = datetime.datetime.now()
|
|
49
|
+
start_timestamp = trainer.start_time()
|
|
50
|
+
training_time = trainer.time_elapsed()
|
|
51
|
+
|
|
52
|
+
metrics: dict[str, Any] = {}
|
|
53
|
+
for name, metric in copy.deepcopy(trainer.callback_metrics).items():
|
|
54
|
+
match metric:
|
|
55
|
+
case torch.Tensor() | np.ndarray():
|
|
56
|
+
metrics[name] = metric.detach().cpu().item()
|
|
57
|
+
case _:
|
|
58
|
+
metrics[name] = metric
|
|
59
|
+
|
|
60
|
+
return CheckpointMetadata(
|
|
61
|
+
checkpoint_path=checkpoint_path,
|
|
62
|
+
checkpoint_filename=checkpoint_path.name,
|
|
63
|
+
run_id=config.id,
|
|
64
|
+
name=config.run_name,
|
|
65
|
+
project=config.project,
|
|
66
|
+
checkpoint_timestamp=checkpoint_timestamp,
|
|
67
|
+
start_timestamp=start_timestamp.datetime
|
|
68
|
+
if start_timestamp is not None
|
|
69
|
+
else None,
|
|
70
|
+
epoch=trainer.current_epoch,
|
|
71
|
+
global_step=trainer.global_step,
|
|
72
|
+
training_time=training_time,
|
|
73
|
+
metrics=metrics,
|
|
74
|
+
environment=config.environment,
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _write_checkpoint_metadata(
    trainer: "Trainer",
    model: "LightningModuleBase",
    checkpoint_path: Path,
):
    """Write the metadata and hparams sidecar files for a checkpoint.

    Two JSON files are written next to `checkpoint_path`, obtained by
    replacing its suffix with `METADATA_PATH_SUFFIX` and
    `HPARAMS_PATH_SUFFIX` respectively. Each write is best-effort: a failure
    is logged as a warning and does not propagate.

    Args:
        trainer: Trainer used to generate the metadata.
        model: Module whose `config` is dumped as the hparams file.
        checkpoint_path: Path of the checkpoint the sidecar files describe.
    """
    config = cast("BaseConfig", model.config)
    metadata = _generate_checkpoint_metadata(config, trainer, checkpoint_path)

    # Write the metadata to the checkpoint directory
    try:
        metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
        metadata_path.write_text(metadata.model_dump_json(indent=4))
    except Exception as e:
        log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
    else:
        log.info(f"Checkpoint metadata written to {checkpoint_path}")

    # Write the hparams to the checkpoint directory
    try:
        hparams_path = checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX)
        hparams_path.write_text(config.model_dump_json(indent=4))
    except Exception as e:
        log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
    else:
        # Report hparams here; this message previously said "metadata".
        log.info(f"Checkpoint hparams written to {checkpoint_path}")
|
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections.abc import Iterable, Sequence
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, overload
|
|
6
|
+
|
|
7
|
+
import nshconfig as C
|
|
8
|
+
from lightning.pytorch import Trainer as LightningTrainer
|
|
9
|
+
from lightning.pytorch.trainer.states import TrainerFn
|
|
10
|
+
from typing_extensions import assert_never
|
|
11
|
+
|
|
12
|
+
from ..metrics._config import MetricConfig
|
|
13
|
+
from ._checkpoint_metadata import METADATA_PATH_SUFFIX, CheckpointMetadata
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from ..model.config import BaseConfig
|
|
17
|
+
|
|
18
|
+
log = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BestCheckpointStrategyConfig(C.Config):
    # Strategy that picks the checkpoint with the best value of a metric.
    # `name` is the discriminator value used to tell strategy configs apart.
    name: Literal["best"] = "best"

    metric: MetricConfig | None = None
    """The metric to use for selecting the best checkpoint. If `None`, the primary metric will be used."""

    additional_candidates: Iterable[Path] = []
    """Additional checkpoint candidates to consider when selecting the best checkpoint."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class UserProvidedPathCheckpointStrategyConfig(C.Config):
    # Strategy that loads one explicitly given checkpoint path.
    # `name` is the discriminator value used to tell strategy configs apart.
    name: Literal["user_provided_path"] = "user_provided_path"

    path: Path
    """The path to the checkpoint to load."""

    on_error: Literal["warn", "raise"] = "warn"
    """The behavior when the checkpoint does not belong to the current run.

    - `warn`: Log a warning and skip the checkpoint.
    - `raise`: Raise an error.
    """
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class LastCheckpointStrategyConfig(C.Config):
    # Strategy that picks the most recent checkpoint according to `criterion`.
    # `name` is the discriminator value used to tell strategy configs apart.
    name: Literal["last"] = "last"

    criterion: Literal["global_step", "runtime"] = "global_step"
    """The criterion to use for selecting the last checkpoint.

    - `global_step`: The checkpoint with the highest global step will be selected.
    - `runtime`: The checkpoint with the highest runtime will be selected.
    """

    additional_candidates: Iterable[Path] = []
    """Additional checkpoint candidates to consider when selecting the last checkpoint."""
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
CheckpointLoadingStrategyConfig: TypeAlias = Annotated[
|
|
60
|
+
BestCheckpointStrategyConfig
|
|
61
|
+
| LastCheckpointStrategyConfig
|
|
62
|
+
| UserProvidedPathCheckpointStrategyConfig,
|
|
63
|
+
C.Field(discriminator="name"),
|
|
64
|
+
]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class CheckpointLoadingConfig(C.Config):
    strategies: Sequence[CheckpointLoadingStrategyConfig]
    """The strategies to use for loading checkpoints.

    The order of the strategies determines the priority of the strategies.
    The first strategy that resolves a checkpoint will be used.
    """

    include_hpc: bool
    """Whether to include checkpoints from HPC pre-emption."""

    @classmethod
    def _auto_train(cls, ckpt: Literal["best", "last"] | str | Path | None):
        # Default loading config for fitting; no selector means "last".
        selector = "last" if ckpt is None else ckpt

        if selector == "best":
            return cls(
                strategies=[BestCheckpointStrategyConfig()],
                include_hpc=True,
            )
        if selector == "last":
            return cls(
                strategies=[LastCheckpointStrategyConfig()],
                include_hpc=True,
            )
        if isinstance(selector, (Path, str)):
            # An explicit path is first offered as an extra "last" candidate,
            # then used directly as a fallback.
            user_path = Path(selector)
            return cls(
                strategies=[
                    LastCheckpointStrategyConfig(additional_candidates=[user_path]),
                    UserProvidedPathCheckpointStrategyConfig(path=user_path),
                ],
                include_hpc=True,
            )
        assert_never(selector)

    @classmethod
    def _auto_eval(cls, ckpt: Literal["best", "last"] | str | Path | None):
        # Default loading config for evaluation; a selector is mandatory and
        # HPC checkpoints are never included.
        if ckpt is None:
            raise ValueError("Checkpoint path must be provided for evaluation.")

        if ckpt == "best":
            return cls(
                strategies=[BestCheckpointStrategyConfig()],
                include_hpc=False,
            )
        if ckpt == "last":
            return cls(
                strategies=[LastCheckpointStrategyConfig()],
                include_hpc=False,
            )
        if isinstance(ckpt, (Path, str)):
            user_path = Path(ckpt)
            return cls(
                strategies=[UserProvidedPathCheckpointStrategyConfig(path=user_path)],
                include_hpc=False,
            )
        assert_never(ckpt)

    @classmethod
    def auto(
        cls,
        ckpt: Literal["best", "last"] | str | Path | None,
        trainer_mode: TrainerFn,
    ):
        """Build the default loading config for `ckpt` under `trainer_mode`.

        Fitting uses the training defaults; validating, testing and
        predicting use the evaluation defaults.
        """
        if trainer_mode == TrainerFn.FITTING:
            return cls._auto_train(ckpt)
        if (
            trainer_mode == TrainerFn.VALIDATING
            or trainer_mode == TrainerFn.TESTING
            or trainer_mode == TrainerFn.PREDICTING
        ):
            return cls._auto_eval(ckpt)
        assert_never(trainer_mode)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@dataclass
class _CkptCandidate:
    """A checkpoint candidate: parsed metadata plus the metadata file's path."""

    meta: CheckpointMetadata
    meta_path: Path

    @property
    def ckpt_path(self):
        """Checkpoint file path: the metadata file's sibling named by the metadata."""
        filename = self.meta.checkpoint_filename
        return self.meta_path.with_name(filename)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
@overload
|
|
156
|
+
def _load_ckpt_meta(
|
|
157
|
+
path: Path,
|
|
158
|
+
root_config: "BaseConfig",
|
|
159
|
+
on_error: Literal["warn"] = "warn",
|
|
160
|
+
) -> _CkptCandidate | None: ...
|
|
161
|
+
@overload
|
|
162
|
+
def _load_ckpt_meta(
|
|
163
|
+
path: Path,
|
|
164
|
+
root_config: "BaseConfig",
|
|
165
|
+
on_error: Literal["raise"],
|
|
166
|
+
) -> _CkptCandidate: ...
|
|
167
|
+
def _load_ckpt_meta(
|
|
168
|
+
path: Path,
|
|
169
|
+
root_config: "BaseConfig",
|
|
170
|
+
on_error: Literal["warn", "raise"] = "warn",
|
|
171
|
+
):
|
|
172
|
+
meta = CheckpointMetadata.from_file(path)
|
|
173
|
+
if root_config.id != meta.run_id:
|
|
174
|
+
error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
|
|
175
|
+
match on_error:
|
|
176
|
+
case "warn":
|
|
177
|
+
log.warn(error_msg)
|
|
178
|
+
case "raise":
|
|
179
|
+
raise ValueError(error_msg)
|
|
180
|
+
case _:
|
|
181
|
+
assert_never(on_error)
|
|
182
|
+
return None
|
|
183
|
+
return _CkptCandidate(meta, path)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _checkpoint_candidates(
    root_config: "BaseConfig",
    trainer: LightningTrainer,
    *,
    include_hpc: bool = True,
):
    """Yield the checkpoint candidates found for the current run.

    Candidates are discovered through the metadata files in the run's
    checkpoint directory. When `include_hpc` is true and the trainer's
    checkpoint connector reports an HPC resume path, that checkpoint's
    metadata is loaded as well. Candidates for which `_load_ckpt_meta`
    returns `None` are dropped.

    Raises:
        FileNotFoundError: If the checkpoint directory does not exist.
    """
    # A missing directory indicates a non-standard setup; fail instead of
    # guessing where the checkpoints might be.
    ckpt_dir = root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
    if not ckpt_dir.is_dir():
        raise FileNotFoundError(
            f"Checkpoint directory {ckpt_dir} not found. "
            "Please ensure that the checkpoint directory exists."
        )

    # Every checkpoint is discovered through its metadata file.
    for meta_file in ckpt_dir.glob(f"*{METADATA_PATH_SUFFIX}"):
        candidate = _load_ckpt_meta(meta_file, root_config)
        if candidate is None:
            continue
        yield candidate

    # Optionally add the pre-empted (HPC) checkpoint.
    if not include_hpc:
        return
    hpc_path = trainer._checkpoint_connector._hpc_resume_path
    if not hpc_path:
        return
    hpc_meta_path = Path(hpc_path).with_suffix(METADATA_PATH_SUFFIX)
    candidate = _load_ckpt_meta(hpc_meta_path, root_config)
    if candidate is not None:
        yield candidate
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _additional_candidates(
    additional_candidates: Iterable[Path], root_config: "BaseConfig"
):
    """Yield candidates for extra checkpoint paths, dropping those that load as `None`."""
    for candidate_path in additional_candidates:
        # The metadata file sits next to the checkpoint, with its suffix swapped.
        meta_path = candidate_path.with_suffix(METADATA_PATH_SUFFIX)
        loaded = _load_ckpt_meta(meta_path, root_config)
        if loaded is not None:
            yield loaded
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _resolve_checkpoint(
|
|
227
|
+
config: CheckpointLoadingConfig,
|
|
228
|
+
root_config: "BaseConfig",
|
|
229
|
+
trainer: LightningTrainer,
|
|
230
|
+
):
|
|
231
|
+
# We lazily load the checkpoint candidates to avoid loading them
|
|
232
|
+
# if they are not needed.
|
|
233
|
+
_ckpt_candidates: list[_CkptCandidate] | None = None
|
|
234
|
+
|
|
235
|
+
def ckpt_candidates():
|
|
236
|
+
nonlocal _ckpt_candidates, root_config, trainer
|
|
237
|
+
|
|
238
|
+
if _ckpt_candidates is None:
|
|
239
|
+
_ckpt_candidates = list(
|
|
240
|
+
_checkpoint_candidates(
|
|
241
|
+
root_config, trainer, include_hpc=config.include_hpc
|
|
242
|
+
)
|
|
243
|
+
)
|
|
244
|
+
return _ckpt_candidates
|
|
245
|
+
|
|
246
|
+
# Iterate over the strategies and try to resolve the checkpoint.
|
|
247
|
+
for strategy in config.strategies:
|
|
248
|
+
match strategy:
|
|
249
|
+
case UserProvidedPathCheckpointStrategyConfig():
|
|
250
|
+
meta = _load_ckpt_meta(
|
|
251
|
+
strategy.path.with_suffix(METADATA_PATH_SUFFIX),
|
|
252
|
+
root_config,
|
|
253
|
+
on_error=strategy.on_error,
|
|
254
|
+
)
|
|
255
|
+
if meta is None:
|
|
256
|
+
continue
|
|
257
|
+
return meta.ckpt_path
|
|
258
|
+
case BestCheckpointStrategyConfig():
|
|
259
|
+
candidates = [
|
|
260
|
+
*ckpt_candidates(),
|
|
261
|
+
*_additional_candidates(
|
|
262
|
+
strategy.additional_candidates, root_config
|
|
263
|
+
),
|
|
264
|
+
]
|
|
265
|
+
if not candidates:
|
|
266
|
+
log.warn(
|
|
267
|
+
"No checkpoint candidates found for `best` checkpoint strategy."
|
|
268
|
+
)
|
|
269
|
+
continue
|
|
270
|
+
|
|
271
|
+
if (metric := strategy.metric or root_config.primary_metric) is None:
|
|
272
|
+
log.warn(
|
|
273
|
+
"No metric specified for `best` checkpoint strategy, "
|
|
274
|
+
"and no primary metric is set in the configuration. "
|
|
275
|
+
"Skipping strategy."
|
|
276
|
+
)
|
|
277
|
+
continue
|
|
278
|
+
|
|
279
|
+
# Find the best checkpoint based on the metric.
|
|
280
|
+
def metric_value(ckpt: _CkptCandidate):
|
|
281
|
+
assert metric is not None
|
|
282
|
+
if (
|
|
283
|
+
value := ckpt.meta.metrics.get(metric.validation_monitor)
|
|
284
|
+
) is None:
|
|
285
|
+
raise ValueError(
|
|
286
|
+
f"Metric {metric.validation_monitor} not found in checkpoint metadata. "
|
|
287
|
+
f"Available metrics: {ckpt.meta.metrics.keys()}"
|
|
288
|
+
)
|
|
289
|
+
return value
|
|
290
|
+
|
|
291
|
+
best_candidate = metric.best(candidates, key=metric_value)
|
|
292
|
+
return best_candidate.ckpt_path
|
|
293
|
+
case LastCheckpointStrategyConfig():
|
|
294
|
+
candidates = [
|
|
295
|
+
*ckpt_candidates(),
|
|
296
|
+
*_additional_candidates(
|
|
297
|
+
strategy.additional_candidates, root_config
|
|
298
|
+
),
|
|
299
|
+
]
|
|
300
|
+
if not candidates:
|
|
301
|
+
log.warn(
|
|
302
|
+
"No checkpoint candidates found for `last` checkpoint strategy."
|
|
303
|
+
)
|
|
304
|
+
continue
|
|
305
|
+
|
|
306
|
+
# Find the last checkpoint based on the criterion.
|
|
307
|
+
def criterion_value(ckpt: _CkptCandidate):
|
|
308
|
+
match strategy.criterion:
|
|
309
|
+
case "global_step":
|
|
310
|
+
return ckpt.meta.global_step
|
|
311
|
+
case "runtime":
|
|
312
|
+
return ckpt.meta.training_time.total_seconds()
|
|
313
|
+
case _:
|
|
314
|
+
assert_never(strategy.criterion)
|
|
315
|
+
|
|
316
|
+
last_candidate = max(candidates, key=criterion_value)
|
|
317
|
+
return last_candidate.ckpt_path
|
|
318
|
+
case _:
|
|
319
|
+
assert_never(strategy)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import logging
|
|
3
|
+
import time
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any, Literal, TypeAlias
|
|
6
|
+
|
|
7
|
+
from lightning.pytorch.callbacks.callback import Callback
|
|
8
|
+
from typing_extensions import override
|
|
9
|
+
|
|
10
|
+
log = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
Stage: TypeAlias = Literal["train", "validate", "test", "predict"]
|
|
13
|
+
ALL_STAGES = ("train", "validate", "test", "predict")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class TimeInfo:
    """A point in time recorded as both a wall-clock and a monotonic reading."""

    # Wall-clock timestamp.
    datetime: datetime.datetime
    # Monotonic clock reading, used for computing durations.
    monotonic: float
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class RuntimeTrackerCallback(Callback):
    """Track start time, end time and elapsed time for each trainer stage.

    Elapsed time is measured with the monotonic clock. The elapsed time
    stored in a loaded state dict is kept as a per-stage offset and added to
    whatever is measured afterwards.
    """

    def __init__(self):
        super().__init__()
        self._start_time: dict[Stage, TimeInfo] = {}
        self._end_time: dict[Stage, TimeInfo] = {}
        # Elapsed time carried over from a loaded state dict, per stage.
        self._offsets = {stage: datetime.timedelta() for stage in ALL_STAGES}

    def start_time(self, stage: Stage) -> TimeInfo | None:
        """Return the start time of a particular stage"""
        return self._start_time.get(stage)

    def end_time(self, stage: Stage) -> TimeInfo | None:
        """Return the end time of a particular stage"""
        return self._end_time.get(stage)

    def time_elapsed(self, stage: Stage) -> datetime.timedelta:
        """Return the time elapsed for a particular stage"""
        start = self.start_time(stage)
        end = self.end_time(stage)
        offset = self._offsets[stage]
        if start is None:
            # Stage never started here: only the carried-over offset counts.
            return offset
        if end is None:
            # Stage is still running: measure up to now.
            return datetime.timedelta(seconds=time.monotonic() - start.monotonic) + offset
        return datetime.timedelta(seconds=end.monotonic - start.monotonic) + offset

    def _record_time(self, stage: Stage, time_dict: dict[Stage, TimeInfo]):
        time_dict[stage] = TimeInfo(datetime.datetime.now(), time.monotonic())

    def _record_start(self, stage: Stage):
        # Drop the end time left over from a previous run of this stage.
        # Otherwise `time_elapsed` would pair the new start with the old end
        # and report a negative duration while the stage is running again.
        self._end_time.pop(stage, None)
        self._record_time(stage, self._start_time)

    @override
    def on_train_start(self, trainer, pl_module):
        self._record_start("train")

    @override
    def on_train_end(self, trainer, pl_module):
        self._record_time("train", self._end_time)

    @override
    def on_validation_start(self, trainer, pl_module):
        self._record_start("validate")

    @override
    def on_validation_end(self, trainer, pl_module):
        self._record_time("validate", self._end_time)

    @override
    def on_test_start(self, trainer, pl_module):
        self._record_start("test")

    @override
    def on_test_end(self, trainer, pl_module):
        self._record_time("test", self._end_time)

    @override
    def on_predict_start(self, trainer, pl_module):
        self._record_start("predict")

    @override
    def on_predict_end(self, trainer, pl_module):
        self._record_time("predict", self._end_time)

    @override
    def state_dict(self) -> dict[str, Any]:
        return {
            "time_elapsed": {
                stage: self.time_elapsed(stage).total_seconds() for stage in ALL_STAGES
            },
            "start_times": {
                stage: (info.datetime.isoformat(), info.monotonic)
                for stage, info in self._start_time.items()
            },
            "end_times": {
                stage: (info.datetime.isoformat(), info.monotonic)
                for stage, info in self._end_time.items()
            },
        }

    @override
    def load_state_dict(self, state_dict: dict[str, Any]):
        # Previously accumulated elapsed time becomes the per-stage offset.
        time_elapsed: dict[Stage, float] = state_dict.get("time_elapsed", {})
        for stage in ALL_STAGES:
            self._offsets[stage] = datetime.timedelta(
                seconds=time_elapsed.get(stage, 0)
            )

        # NOTE(review): the restored monotonic readings come from whichever
        # process saved the state. Until the stage starts again they are
        # combined with the offset above, which already includes that span.
        # Confirm whether restoring them is intended.
        start_times: dict[Stage, tuple[str, float]] = state_dict.get("start_times", {})
        for stage, (dt_str, monotonic) in start_times.items():
            self._start_time[stage] = TimeInfo(
                datetime.datetime.fromisoformat(dt_str), monotonic
            )

        end_times: dict[Stage, tuple[str, float]] = state_dict.get("end_times", {})
        for stage, (dt_str, monotonic) in end_times.items():
            self._end_time[stage] = TimeInfo(
                datetime.datetime.fromisoformat(dt_str), monotonic
            )
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import TYPE_CHECKING, cast
|
|
4
|
+
|
|
5
|
+
from lightning.pytorch.trainer.connectors.checkpoint_connector import (
|
|
6
|
+
_CheckpointConnector,
|
|
7
|
+
)
|
|
8
|
+
from lightning.pytorch.trainer.states import TrainerFn
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
11
|
+
from ._checkpoint_resolver import CheckpointLoadingConfig, _resolve_checkpoint
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from ..model.config import BaseConfig
|
|
15
|
+
log = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CheckpointConnector(_CheckpointConnector):
    """Checkpoint connector that first tries the checkpoint-loading config.

    If the configured strategies resolve a checkpoint, that path is used.
    Otherwise the parent class's `_parse_ckpt_path` decides.
    """

    def __resolve_auto_ckpt_path(
        self,
        ckpt_path: str | Path | None,
        state_fn: TrainerFn,
    ):
        """Resolve a checkpoint path from the root config's loading settings.

        Returns:
            The resolved checkpoint path, or `None` when the trainer is not
            an `nshtrainer` trainer or no strategy resolved a checkpoint.
        """
        # Imported here rather than at module level (presumably to avoid an
        # import cycle with `.trainer` — TODO confirm).
        from .trainer import Trainer

        # If this isn't an `nshtrainer` trainer (which I don't know why it wouldn't be),
        # then we just default to the parent class's implementation of `_parse_ckpt_path`.
        trainer = self.trainer
        if not isinstance(trainer, Trainer):
            return None

        # Now, resolve the checkpoint loader config.
        root_config = cast("BaseConfig", trainer._base_module.config)
        if (ckpt_loader_config := root_config.trainer.checkpoint_loading) == "auto":
            ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
        log.debug(f"Checkpoint loader config: {ckpt_loader_config}")

        # Use the config to resolve the checkpoint.
        if (
            ckpt_path := _resolve_checkpoint(ckpt_loader_config, root_config, trainer)
        ) is None:
            log.info(
                "No checkpoint found for the current trainer state. "
                "Training will start from scratch."
            )
            # Return here so the "Loading checkpoint from: None" line below
            # is not logged as well.
            return None

        log.info(f"Loading checkpoint from: {ckpt_path}")
        return ckpt_path

    @override
    def _parse_ckpt_path(
        self,
        state_fn: TrainerFn,
        ckpt_path: str | Path | None,
        model_provided: bool,
        model_connected: bool,
    ):
        # Prefer the path resolved from the loading config, if there is one.
        if (p := self.__resolve_auto_ckpt_path(ckpt_path, state_fn)) is not None:
            return p

        return super()._parse_ckpt_path(
            state_fn, ckpt_path, model_provided, model_connected
        )
|