nshtrainer 0.8.7__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (35)
  1. nshtrainer/__init__.py +2 -1
  2. nshtrainer/callbacks/__init__.py +17 -1
  3. nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
  4. nshtrainer/callbacks/base.py +7 -5
  5. nshtrainer/callbacks/ema.py +1 -1
  6. nshtrainer/callbacks/finite_checks.py +1 -1
  7. nshtrainer/callbacks/gradient_skipping.py +1 -1
  8. nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
  9. nshtrainer/callbacks/model_checkpoint.py +187 -0
  10. nshtrainer/callbacks/norm_logging.py +1 -1
  11. nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
  12. nshtrainer/callbacks/print_table.py +1 -1
  13. nshtrainer/callbacks/throughput_monitor.py +1 -1
  14. nshtrainer/callbacks/timer.py +1 -1
  15. nshtrainer/callbacks/wandb_watch.py +1 -1
  16. nshtrainer/ll/__init__.py +0 -1
  17. nshtrainer/ll/actsave.py +2 -1
  18. nshtrainer/metrics/__init__.py +1 -0
  19. nshtrainer/metrics/_config.py +37 -0
  20. nshtrainer/model/__init__.py +11 -11
  21. nshtrainer/model/_environment.py +777 -0
  22. nshtrainer/model/base.py +5 -114
  23. nshtrainer/model/config.py +92 -507
  24. nshtrainer/model/modules/logger.py +11 -6
  25. nshtrainer/runner.py +3 -6
  26. nshtrainer/trainer/_checkpoint_metadata.py +102 -0
  27. nshtrainer/trainer/_checkpoint_resolver.py +319 -0
  28. nshtrainer/trainer/_runtime_callback.py +120 -0
  29. nshtrainer/trainer/checkpoint_connector.py +63 -0
  30. nshtrainer/trainer/signal_connector.py +12 -9
  31. nshtrainer/trainer/trainer.py +111 -31
  32. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
  33. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
  34. nshtrainer/actsave/__init__.py +0 -3
  35. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
nshtrainer/model/modules/logger.py CHANGED
@@ -11,10 +11,14 @@ from lightning.pytorch.utilities.types import _METRIC
 from lightning_utilities.core.rank_zero import rank_zero_warn
 from typing_extensions import override
 
-from ...actsave import ActSave
 from ...util.typing_utils import mixin_base_type
 from ..config import BaseConfig
 
+try:
+    from nshutils import ActSave  # type: ignore
+except ImportError:
+    ActSave = None
+
 
 @dataclass(frozen=True, kw_only=True)
 class _LogContext:
@@ -155,14 +159,15 @@ class LoggerLightningModuleMixin(LoggerModuleMixin, mixin_base_type(LightningMod
 
     def _logger_actsave(self, name: str, value: _METRIC) -> None:
         hparams = cast(BaseConfig, self.hparams)
-        if (
-            not hparams.trainer.actsave
-            or not hparams.trainer.actsave.auto_save_logged_metrics
-        ):
+        if not hparams.trainer.logging.actsave_logged_metrics:
+            return
+
+        if ActSave is None:
+            rank_zero_warn("ActSave is not available, skipping logging of metrics")
             return
 
         ActSave.save(
-            {
+            lambda: {
                 f"logger.{name}": lambda: value.compute()
                 if isinstance(value, torchmetrics.Metric)
                 else value
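
The logger mixin now treats ActSave as an optional dependency: the nshutils import degrades to None, and the saved payload is passed as a lambda, presumably so the dict is only built when activation saving is actually active. A minimal, self-contained sketch of that guarded-import-plus-lazy-payload pattern (the fancy_saver module and log_metric helper are made-up stand-ins, not the nshutils API):

try:
    from fancy_saver import save as maybe_save  # type: ignore
except ImportError:
    # Optional dependency is missing; callers check for None before use.
    maybe_save = None


def log_metric(name: str, value: float) -> None:
    if maybe_save is None:
        print("saver not installed, skipping metric save")
        return
    # Passing a callable defers building the payload until the saver wants it.
    maybe_save(lambda: {f"logger.{name}": value})


log_metric("loss", 0.25)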
nshtrainer/runner.py CHANGED
@@ -5,6 +5,7 @@ from typing import Generic
 
 from nshrunner import RunInfo
 from nshrunner import Runner as _Runner
+from nshrunner._submit import screen
 from nshrunner.snapshot import SnapshotArgType
 from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
 
@@ -89,6 +90,7 @@ class Runner(
     def fast_dev_run_session(
         self,
         runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
+        options: screen.ScreenJobKwargs = {},
         n_batches: int = 1,
         *,
         snapshot: SnapshotArgType,
@@ -99,10 +101,7 @@
         ]
         | None = None,
         activate_venv: bool = True,
-        session_name: str = "nshrunner",
-        attach: bool = True,
         print_command: bool = True,
-        pause_before_exit: bool = False,
     ):
         transforms = transforms or []
         transforms.append(
@@ -110,13 +109,11 @@
         )
         return self.session(
             runs,
+            options,
             snapshot=snapshot,
             setup_commands=setup_commands,
             env=env,
             transforms=transforms,
             activate_venv=activate_venv,
-            session_name=session_name,
-            attach=attach,
             print_command=print_command,
-            pause_before_exit=pause_before_exit,
         )
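
fast_dev_run_session no longer takes session_name, attach, and pause_before_exit as separate keyword arguments; they move behind a single options mapping typed as nshrunner's screen.ScreenJobKwargs and forwarded to self.session. A self-contained sketch of that options-dict pattern, using a stand-in TypedDict rather than the real nshrunner type:

from typing import TypedDict


class ScreenJobKwargs(TypedDict, total=False):
    # Illustrative fields only; consult nshrunner._submit.screen for the real schema.
    session_name: str
    attach: bool
    pause_before_exit: bool


def fast_dev_run_session(runs: list[str], options: ScreenJobKwargs = {}) -> None:
    # Per-call screen settings now travel in one mapping instead of separate kwargs.
    session_name = options.get("session_name", "nshrunner")
    attach = options.get("attach", True)
    print(f"launching {len(runs)} run(s) in screen session {session_name!r} (attach={attach})")


fast_dev_run_session(["run-0"], {"session_name": "debug", "attach": False})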
nshtrainer/trainer/_checkpoint_metadata.py ADDED
@@ -0,0 +1,102 @@
+import copy
+import datetime
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, cast
+
+import nshconfig as C
+import numpy as np
+import torch
+
+from ..model._environment import EnvironmentConfig
+
+if TYPE_CHECKING:
+    from ..model import BaseConfig, LightningModuleBase
+    from .trainer import Trainer
+
+log = logging.getLogger(__name__)
+
+
+METADATA_PATH_SUFFIX = ".metadata.json"
+HPARAMS_PATH_SUFFIX = ".hparams.json"
+
+
+class CheckpointMetadata(C.Config):
+    checkpoint_path: Path
+    checkpoint_filename: str
+
+    run_id: str
+    name: str
+    project: str | None
+    checkpoint_timestamp: datetime.datetime
+    start_timestamp: datetime.datetime | None
+
+    epoch: int
+    global_step: int
+    training_time: datetime.timedelta
+    metrics: dict[str, Any]
+    environment: EnvironmentConfig
+
+    @classmethod
+    def from_file(cls, path: Path):
+        return cls.model_validate_json(path.read_text())
+
+
+def _generate_checkpoint_metadata(
+    config: "BaseConfig", trainer: "Trainer", checkpoint_path: Path
+):
+    checkpoint_timestamp = datetime.datetime.now()
+    start_timestamp = trainer.start_time()
+    training_time = trainer.time_elapsed()
+
+    metrics: dict[str, Any] = {}
+    for name, metric in copy.deepcopy(trainer.callback_metrics).items():
+        match metric:
+            case torch.Tensor() | np.ndarray():
+                metrics[name] = metric.detach().cpu().item()
+            case _:
+                metrics[name] = metric
+
+    return CheckpointMetadata(
+        checkpoint_path=checkpoint_path,
+        checkpoint_filename=checkpoint_path.name,
+        run_id=config.id,
+        name=config.run_name,
+        project=config.project,
+        checkpoint_timestamp=checkpoint_timestamp,
+        start_timestamp=start_timestamp.datetime
+        if start_timestamp is not None
+        else None,
+        epoch=trainer.current_epoch,
+        global_step=trainer.global_step,
+        training_time=training_time,
+        metrics=metrics,
+        environment=config.environment,
+    )
+
+
+def _write_checkpoint_metadata(
+    trainer: "Trainer",
+    model: "LightningModuleBase",
+    checkpoint_path: Path,
+):
+    config = cast("BaseConfig", model.config)
+    metadata = _generate_checkpoint_metadata(config, trainer, checkpoint_path)
+
+    # Write the metadata to the checkpoint directory
+    try:
+        metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+        metadata_path.write_text(metadata.model_dump_json(indent=4))
+    except Exception as e:
+        log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
+    else:
+        log.info(f"Checkpoint metadata written to {checkpoint_path}")
+
+    # Write the hparams to the checkpoint directory
+    try:
+        hparams_path = checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX)
+        hparams_path.write_text(config.model_dump_json(indent=4))
+    except Exception as e:
+        log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
+    else:
+        log.info(f"Checkpoint metadata written to {checkpoint_path}")
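
The new metadata writer pairs every checkpoint file with two JSON sidecars derived from the checkpoint path via Path.with_suffix. A small runnable sketch of the naming convention, with illustrative paths:

from pathlib import Path

METADATA_PATH_SUFFIX = ".metadata.json"
HPARAMS_PATH_SUFFIX = ".hparams.json"

ckpt = Path("checkpoints/epoch=3-step=1200.ckpt")  # illustrative checkpoint path
print(ckpt.with_suffix(METADATA_PATH_SUFFIX))  # checkpoints/epoch=3-step=1200.metadata.json
print(ckpt.with_suffix(HPARAMS_PATH_SUFFIX))   # checkpoints/epoch=3-step=1200.hparams.json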
nshtrainer/trainer/_checkpoint_resolver.py ADDED
@@ -0,0 +1,319 @@
+import logging
+from collections.abc import Iterable, Sequence
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, overload
+
+import nshconfig as C
+from lightning.pytorch import Trainer as LightningTrainer
+from lightning.pytorch.trainer.states import TrainerFn
+from typing_extensions import assert_never
+
+from ..metrics._config import MetricConfig
+from ._checkpoint_metadata import METADATA_PATH_SUFFIX, CheckpointMetadata
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class BestCheckpointStrategyConfig(C.Config):
+    name: Literal["best"] = "best"
+
+    metric: MetricConfig | None = None
+    """The metric to use for selecting the best checkpoint. If `None`, the primary metric will be used."""
+
+    additional_candidates: Iterable[Path] = []
+    """Additional checkpoint candidates to consider when selecting the last checkpoint."""
+
+
+class UserProvidedPathCheckpointStrategyConfig(C.Config):
+    name: Literal["user_provided_path"] = "user_provided_path"
+
+    path: Path
+    """The path to the checkpoint to load."""
+
+    on_error: Literal["warn", "raise"] = "warn"
+    """The behavior when the checkpoint does not belong to the current run.
+
+    - `warn`: Log a warning and skip the checkpoint.
+    - `raise`: Raise an error.
+    """
+
+
+class LastCheckpointStrategyConfig(C.Config):
+    name: Literal["last"] = "last"
+
+    criterion: Literal["global_step", "runtime"] = "global_step"
+    """The criterion to use for selecting the last checkpoint.
+
+    - `global_step`: The checkpoint with the highest global step will be selected.
+    - `runtime`: The checkpoint with the highest runtime will be selected.
+    """
+
+    additional_candidates: Iterable[Path] = []
+    """Additional checkpoint candidates to consider when selecting the last checkpoint."""
+
+
+CheckpointLoadingStrategyConfig: TypeAlias = Annotated[
+    BestCheckpointStrategyConfig
+    | LastCheckpointStrategyConfig
+    | UserProvidedPathCheckpointStrategyConfig,
+    C.Field(discriminator="name"),
+]
+
+
+class CheckpointLoadingConfig(C.Config):
+    strategies: Sequence[CheckpointLoadingStrategyConfig]
+    """The strategies to use for loading checkpoints.
+
+    The order of the strategies determines the priority of the strategies.
+    The first strategy that resolves a checkpoint will be used.
+    """
+
+    include_hpc: bool
+    """Whether to include checkpoints from HPC pre-emption."""
+
+    @classmethod
+    def _auto_train(cls, ckpt: Literal["best", "last"] | str | Path | None):
+        if ckpt is None:
+            ckpt = "last"
+        match ckpt:
+            case "best":
+                return cls(
+                    strategies=[BestCheckpointStrategyConfig()],
+                    include_hpc=True,
+                )
+            case "last":
+                return cls(
+                    strategies=[LastCheckpointStrategyConfig()],
+                    include_hpc=True,
+                )
+            case Path() | str():
+                ckpt = Path(ckpt)
+                return cls(
+                    strategies=[
+                        LastCheckpointStrategyConfig(additional_candidates=[ckpt]),
+                        UserProvidedPathCheckpointStrategyConfig(path=ckpt),
+                    ],
+                    include_hpc=True,
+                )
+            case _:
+                assert_never(ckpt)
+
+    @classmethod
+    def _auto_eval(cls, ckpt: Literal["best", "last"] | str | Path | None):
+        if ckpt is None:
+            raise ValueError("Checkpoint path must be provided for evaluation.")
+
+        match ckpt:
+            case "best":
+                return cls(
+                    strategies=[BestCheckpointStrategyConfig()],
+                    include_hpc=False,
+                )
+            case "last":
+                return cls(
+                    strategies=[LastCheckpointStrategyConfig()],
+                    include_hpc=False,
+                )
+            case Path() | str():
+                ckpt = Path(ckpt)
+                return cls(
+                    strategies=[UserProvidedPathCheckpointStrategyConfig(path=ckpt)],
+                    include_hpc=False,
+                )
+            case _:
+                assert_never(ckpt)
+
+    @classmethod
+    def auto(
+        cls,
+        ckpt: Literal["best", "last"] | str | Path | None,
+        trainer_mode: TrainerFn,
+    ):
+        match trainer_mode:
+            case TrainerFn.FITTING:
+                return cls._auto_train(ckpt)
+            case TrainerFn.VALIDATING | TrainerFn.TESTING | TrainerFn.PREDICTING:
+                return cls._auto_eval(ckpt)
+            case _:
+                assert_never(trainer_mode)
+
+
+@dataclass
+class _CkptCandidate:
+    meta: CheckpointMetadata
+    meta_path: Path
+
+    @property
+    def ckpt_path(self):
+        return self.meta_path.with_name(self.meta.checkpoint_filename)
+
+
+@overload
+def _load_ckpt_meta(
+    path: Path,
+    root_config: "BaseConfig",
+    on_error: Literal["warn"] = "warn",
+) -> _CkptCandidate | None: ...
+@overload
+def _load_ckpt_meta(
+    path: Path,
+    root_config: "BaseConfig",
+    on_error: Literal["raise"],
+) -> _CkptCandidate: ...
+def _load_ckpt_meta(
+    path: Path,
+    root_config: "BaseConfig",
+    on_error: Literal["warn", "raise"] = "warn",
+):
+    meta = CheckpointMetadata.from_file(path)
+    if root_config.id != meta.run_id:
+        error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
+        match on_error:
+            case "warn":
+                log.warn(error_msg)
+            case "raise":
+                raise ValueError(error_msg)
+            case _:
+                assert_never(on_error)
+        return None
+    return _CkptCandidate(meta, path)
+
+
+def _checkpoint_candidates(
+    root_config: "BaseConfig",
+    trainer: LightningTrainer,
+    *,
+    include_hpc: bool = True,
+):
+    # Load the checkpoint directory, and throw if it doesn't exist.
+    # This indicates a non-standard setup, and we don't want to guess
+    # where the checkpoints are.
+    ckpt_dir = root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+    if not ckpt_dir.is_dir():
+        raise FileNotFoundError(
+            f"Checkpoint directory {ckpt_dir} not found. "
+            "Please ensure that the checkpoint directory exists."
+        )
+
+    # Load all checkpoints in the directory.
+    # We can do this by looking for metadata files.
+    for path in ckpt_dir.glob(f"*{METADATA_PATH_SUFFIX}"):
+        if (meta := _load_ckpt_meta(path, root_config)) is not None:
+            yield meta
+
+    # If we have a pre-empted checkpoint, load it
+    if include_hpc and (hpc_path := trainer._checkpoint_connector._hpc_resume_path):
+        hpc_meta_path = Path(hpc_path).with_suffix(METADATA_PATH_SUFFIX)
+        if (meta := _load_ckpt_meta(hpc_meta_path, root_config)) is not None:
+            yield meta
+
+
+def _additional_candidates(
+    additional_candidates: Iterable[Path], root_config: "BaseConfig"
+):
+    for path in additional_candidates:
+        if (
+            meta := _load_ckpt_meta(path.with_suffix(METADATA_PATH_SUFFIX), root_config)
+        ) is None:
+            continue
+        yield meta
+
+
+def _resolve_checkpoint(
+    config: CheckpointLoadingConfig,
+    root_config: "BaseConfig",
+    trainer: LightningTrainer,
+):
+    # We lazily load the checkpoint candidates to avoid loading them
+    # if they are not needed.
+    _ckpt_candidates: list[_CkptCandidate] | None = None
+
+    def ckpt_candidates():
+        nonlocal _ckpt_candidates, root_config, trainer
+
+        if _ckpt_candidates is None:
+            _ckpt_candidates = list(
+                _checkpoint_candidates(
+                    root_config, trainer, include_hpc=config.include_hpc
+                )
+            )
+        return _ckpt_candidates
+
+    # Iterate over the strategies and try to resolve the checkpoint.
+    for strategy in config.strategies:
+        match strategy:
+            case UserProvidedPathCheckpointStrategyConfig():
+                meta = _load_ckpt_meta(
+                    strategy.path.with_suffix(METADATA_PATH_SUFFIX),
+                    root_config,
+                    on_error=strategy.on_error,
+                )
+                if meta is None:
+                    continue
+                return meta.ckpt_path
+            case BestCheckpointStrategyConfig():
+                candidates = [
+                    *ckpt_candidates(),
+                    *_additional_candidates(
+                        strategy.additional_candidates, root_config
+                    ),
+                ]
+                if not candidates:
+                    log.warn(
+                        "No checkpoint candidates found for `best` checkpoint strategy."
+                    )
+                    continue
+
+                if (metric := strategy.metric or root_config.primary_metric) is None:
+                    log.warn(
+                        "No metric specified for `best` checkpoint strategy, "
+                        "and no primary metric is set in the configuration. "
+                        "Skipping strategy."
+                    )
+                    continue
+
+                # Find the best checkpoint based on the metric.
+                def metric_value(ckpt: _CkptCandidate):
+                    assert metric is not None
+                    if (
+                        value := ckpt.meta.metrics.get(metric.validation_monitor)
+                    ) is None:
+                        raise ValueError(
+                            f"Metric {metric.validation_monitor} not found in checkpoint metadata. "
+                            f"Available metrics: {ckpt.meta.metrics.keys()}"
+                        )
+                    return value
+
+                best_candidate = metric.best(candidates, key=metric_value)
+                return best_candidate.ckpt_path
+            case LastCheckpointStrategyConfig():
+                candidates = [
+                    *ckpt_candidates(),
+                    *_additional_candidates(
+                        strategy.additional_candidates, root_config
+                    ),
+                ]
+                if not candidates:
+                    log.warn(
+                        "No checkpoint candidates found for `last` checkpoint strategy."
+                    )
+                    continue
+
+                # Find the last checkpoint based on the criterion.
+                def criterion_value(ckpt: _CkptCandidate):
+                    match strategy.criterion:
+                        case "global_step":
+                            return ckpt.meta.global_step
+                        case "runtime":
+                            return ckpt.meta.training_time.total_seconds()
+                        case _:
+                            assert_never(strategy.criterion)
+
+                last_candidate = max(candidates, key=criterion_value)
+                return last_candidate.ckpt_path
+            case _:
+                assert_never(strategy)
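
The resolver walks the configured strategies in order and returns the checkpoint produced by the first one that succeeds; a strategy that finds nothing is skipped and the next one is tried. A minimal, self-contained sketch of that first-match-wins loop, using simplified stand-in dataclasses rather than the real nshconfig strategy classes:

from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class UserProvidedPath:
    path: Path


@dataclass
class Last:
    candidates: list[Path] = field(default_factory=list)


def resolve(strategies: list[UserProvidedPath | Last]) -> Path | None:
    # Strategies are tried in order; the first one that yields a path wins.
    for strategy in strategies:
        match strategy:
            case UserProvidedPath(path=p) if p.exists():
                return p
            case Last(candidates=c) if c:
                # "last" here is simply the most recently modified candidate.
                return max(c, key=lambda p: p.stat().st_mtime)
    return None  # nothing resolved; the caller starts from scratch


print(resolve([UserProvidedPath(Path("missing.ckpt")), Last()]))  # None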
nshtrainer/trainer/_runtime_callback.py ADDED
@@ -0,0 +1,120 @@
+import datetime
+import logging
+import time
+from dataclasses import dataclass
+from typing import Any, Literal, TypeAlias
+
+from lightning.pytorch.callbacks.callback import Callback
+from typing_extensions import override
+
+log = logging.getLogger(__name__)
+
+Stage: TypeAlias = Literal["train", "validate", "test", "predict"]
+ALL_STAGES = ("train", "validate", "test", "predict")
+
+
+@dataclass
+class TimeInfo:
+    datetime: datetime.datetime
+    monotonic: float
+
+
+class RuntimeTrackerCallback(Callback):
+    def __init__(self):
+        super().__init__()
+        self._start_time: dict[Stage, TimeInfo] = {}
+        self._end_time: dict[Stage, TimeInfo] = {}
+        self._offsets = {stage: datetime.timedelta() for stage in ALL_STAGES}
+
+    def start_time(self, stage: Stage) -> TimeInfo | None:
+        """Return the start time of a particular stage"""
+        return self._start_time.get(stage)
+
+    def end_time(self, stage: Stage) -> TimeInfo | None:
+        """Return the end time of a particular stage"""
+        return self._end_time.get(stage)
+
+    def time_elapsed(self, stage: Stage) -> datetime.timedelta:
+        """Return the time elapsed for a particular stage"""
+        start = self.start_time(stage)
+        end = self.end_time(stage)
+        offset = self._offsets[stage]
+        if start is None:
+            return offset
+        if end is None:
+            current = TimeInfo(datetime.datetime.now(), time.monotonic())
+            return (
+                datetime.timedelta(seconds=current.monotonic - start.monotonic) + offset
+            )
+        return datetime.timedelta(seconds=end.monotonic - start.monotonic) + offset
+
+    def _record_time(self, stage: Stage, time_dict: dict[Stage, TimeInfo]):
+        time_dict[stage] = TimeInfo(datetime.datetime.now(), time.monotonic())
+
+    @override
+    def on_train_start(self, trainer, pl_module):
+        self._record_time("train", self._start_time)
+
+    @override
+    def on_train_end(self, trainer, pl_module):
+        self._record_time("train", self._end_time)
+
+    @override
+    def on_validation_start(self, trainer, pl_module):
+        self._record_time("validate", self._start_time)
+
+    @override
+    def on_validation_end(self, trainer, pl_module):
+        self._record_time("validate", self._end_time)
+
+    @override
+    def on_test_start(self, trainer, pl_module):
+        self._record_time("test", self._start_time)
+
+    @override
+    def on_test_end(self, trainer, pl_module):
+        self._record_time("test", self._end_time)
+
+    @override
+    def on_predict_start(self, trainer, pl_module):
+        self._record_time("predict", self._start_time)
+
+    @override
+    def on_predict_end(self, trainer, pl_module):
+        self._record_time("predict", self._end_time)
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "time_elapsed": {
+                stage: self.time_elapsed(stage).total_seconds() for stage in ALL_STAGES
+            },
+            "start_times": {
+                stage: (info.datetime.isoformat(), info.monotonic)
+                for stage, info in self._start_time.items()
+            },
+            "end_times": {
+                stage: (info.datetime.isoformat(), info.monotonic)
+                for stage, info in self._end_time.items()
+            },
+        }
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]):
+        time_elapsed: dict[Stage, float] = state_dict.get("time_elapsed", {})
+        for stage in ALL_STAGES:
+            self._offsets[stage] = datetime.timedelta(
+                seconds=time_elapsed.get(stage, 0)
+            )
+
+        start_times: dict[Stage, tuple[str, float]] = state_dict.get("start_times", {})
+        for stage, (dt_str, monotonic) in start_times.items():
+            self._start_time[stage] = TimeInfo(
+                datetime.datetime.fromisoformat(dt_str), monotonic
+            )
+
+        end_times: dict[Stage, tuple[str, float]] = state_dict.get("end_times", {})
+        for stage, (dt_str, monotonic) in end_times.items():
+            self._end_time[stage] = TimeInfo(
+                datetime.datetime.fromisoformat(dt_str), monotonic
+            )
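
RuntimeTrackerCallback survives restarts by persisting elapsed time through state_dict/load_state_dict: the previous run's total becomes a per-stage offset, and the current run's monotonic delta is added on top of it. A small runnable sketch of that offset bookkeeping (values are illustrative):

import datetime
import time

offset = datetime.timedelta(seconds=90.0)  # elapsed time restored via load_state_dict
start_monotonic = time.monotonic()         # recorded at on_train_start

time.sleep(0.1)  # stand-in for training work done in the current process

elapsed = datetime.timedelta(seconds=time.monotonic() - start_monotonic) + offset
print(elapsed)  # roughly 0:01:30.1 -- previous runtime plus time spent in this run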
nshtrainer/trainer/checkpoint_connector.py ADDED
@@ -0,0 +1,63 @@
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, cast
+
+from lightning.pytorch.trainer.connectors.checkpoint_connector import (
+    _CheckpointConnector,
+)
+from lightning.pytorch.trainer.states import TrainerFn
+from typing_extensions import override
+
+from ._checkpoint_resolver import CheckpointLoadingConfig, _resolve_checkpoint
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+log = logging.getLogger(__name__)
+
+
+class CheckpointConnector(_CheckpointConnector):
+    def __resolve_auto_ckpt_path(
+        self,
+        ckpt_path: str | Path | None,
+        state_fn: TrainerFn,
+    ):
+        from .trainer import Trainer
+
+        # If this isn't an `nshtrainer` trainer (which I don't know why it wouldn't be),
+        # then we just default to the parent class's implementation of `_parse_ckpt_path`.
+        trainer = self.trainer
+        if not isinstance(trainer, Trainer):
+            return None
+
+        # Now, resolve the checkpoint loader config.
+        root_config = cast("BaseConfig", trainer._base_module.config)
+        if (ckpt_loader_config := root_config.trainer.checkpoint_loading) == "auto":
+            ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
+        log.debug(f"Checkpoint loader config: {ckpt_loader_config}")
+
+        # Use the config to resolve the checkpoint.
+        if (
+            ckpt_path := _resolve_checkpoint(ckpt_loader_config, root_config, trainer)
+        ) is None:
+            log.info(
+                "No checkpoint found for the current trainer state. "
+                "Training will start from scratch."
+            )
+
+        log.info(f"Loading checkpoint from: {ckpt_path}")
+        return ckpt_path
+
+    @override
+    def _parse_ckpt_path(
+        self,
+        state_fn: TrainerFn,
+        ckpt_path: str | Path | None,
+        model_provided: bool,
+        model_connected: bool,
+    ):
+        if (p := self.__resolve_auto_ckpt_path(ckpt_path, state_fn)) is not None:
+            return p
+
+        return super()._parse_ckpt_path(
+            state_fn, ckpt_path, model_provided, model_connected
+        )
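
The connector override takes a resolver-first approach: _parse_ckpt_path asks the nshtrainer resolver for a path and only falls back to Lightning's stock implementation when the resolver declines by returning None. A minimal, self-contained sketch of that try-then-defer pattern, with stand-in classes rather than Lightning's actual connector:

from pathlib import Path


class BaseConnector:
    def _parse_ckpt_path(self, ckpt_path: str | Path | None) -> Path | None:
        # Stand-in for the stock behaviour: pass the path through unchanged.
        return Path(ckpt_path) if ckpt_path is not None else None


class AutoResolvingConnector(BaseConnector):
    def _resolve_auto(self, ckpt_path: str | Path | None) -> Path | None:
        # Pretend resolver: only knows how to handle the sentinel "last".
        return Path("checkpoints/last.ckpt") if ckpt_path == "last" else None

    def _parse_ckpt_path(self, ckpt_path: str | Path | None) -> Path | None:
        if (resolved := self._resolve_auto(ckpt_path)) is not None:
            return resolved
        return super()._parse_ckpt_path(ckpt_path)


print(AutoResolvingConnector()._parse_ckpt_path("last"))        # checkpoints/last.ckpt
print(AutoResolvingConnector()._parse_ckpt_path("other.ckpt"))  # other.ckpt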