nshtrainer 0.9.1__tar.gz → 0.10.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/PKG-INFO +3 -1
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/pyproject.toml +3 -2
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/__init__.py +2 -1
- nshtrainer-0.10.1/src/nshtrainer/_checkpoint/loader.py +319 -0
- nshtrainer-0.10.1/src/nshtrainer/_checkpoint/metadata.py +102 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/__init__.py +17 -1
- nshtrainer-0.9.1/src/nshtrainer/actsave/_callback.py → nshtrainer-0.10.1/src/nshtrainer/callbacks/actsave.py +68 -10
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/base.py +7 -5
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/ema.py +1 -1
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/finite_checks.py +1 -1
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer-0.10.1/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +81 -0
- nshtrainer-0.10.1/src/nshtrainer/callbacks/model_checkpoint.py +187 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer-0.10.1/src/nshtrainer/callbacks/on_exception_checkpoint.py +98 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/print_table.py +1 -1
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/throughput_monitor.py +1 -1
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/timer.py +1 -1
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/wandb_watch.py +1 -1
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/__init__.py +0 -1
- nshtrainer-0.10.1/src/nshtrainer/ll/actsave.py +4 -0
- nshtrainer-0.10.1/src/nshtrainer/metrics/__init__.py +1 -0
- nshtrainer-0.10.1/src/nshtrainer/metrics/_config.py +37 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/model/__init__.py +11 -11
- nshtrainer-0.10.1/src/nshtrainer/model/_environment.py +777 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/model/base.py +5 -114
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/model/config.py +49 -501
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/model/modules/logger.py +11 -6
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/runner.py +3 -6
- nshtrainer-0.10.1/src/nshtrainer/trainer/_runtime_callback.py +120 -0
- nshtrainer-0.10.1/src/nshtrainer/trainer/checkpoint_connector.py +63 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/trainer/signal_connector.py +12 -9
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/trainer/trainer.py +111 -31
- nshtrainer-0.9.1/src/nshtrainer/actsave/__init__.py +0 -3
- nshtrainer-0.9.1/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +0 -45
- nshtrainer-0.9.1/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -44
- nshtrainer-0.9.1/src/nshtrainer/ll/actsave.py +0 -3
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/README.md +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/scripts/check_env.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.9.1 → nshtrainer-0.10.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.9.1
+Version: 0.10.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -9,11 +9,13 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Requires-Dist: GitPython
 Requires-Dist: lightning
 Requires-Dist: nshconfig
 Requires-Dist: nshrunner
 Requires-Dist: nshutils
 Requires-Dist: numpy
+Requires-Dist: psutil
 Requires-Dist: pytorch-lightning
 Requires-Dist: torch
 Requires-Dist: torchmetrics

{nshtrainer-0.9.1 → nshtrainer-0.10.1}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.9.1"
+version = "0.10.1"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -10,6 +10,7 @@ python = "^3.10"
 nshrunner = "*"
 nshconfig = "*"
 nshutils = "*"
+psutil = "*"
 torch = "*"
 typing-extensions = "*"
 lightning = "*"
@@ -17,7 +18,7 @@ pytorch-lightning = "*"
 torchmetrics = "*"
 numpy = "*"
 wrapt = "*"
-
+GitPython = "*"
 
 [tool.poetry.group.dev.dependencies]
 pyright = "^1.1.372"

{nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/__init__.py

@@ -2,13 +2,14 @@ from . import _experimental as _experimental
 from . import callbacks as callbacks
 from . import data as data
 from . import lr_scheduler as lr_scheduler
+from . import metrics as metrics
 from . import model as model
 from . import nn as nn
 from . import optimizer as optimizer
+from .metrics import MetricConfig as MetricConfig
 from .model import Base as Base
 from .model import BaseConfig as BaseConfig
 from .model import ConfigList as ConfigList
 from .model import LightningModuleBase as LightningModuleBase
-from .model import MetricConfig as MetricConfig
 from .runner import Runner as Runner
 from .trainer import Trainer as Trainer
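
Note: `MetricConfig` now lives in the new `nshtrainer.metrics` package and is re-exported from the top-level namespace, while the re-export via `.model` is dropped from `__init__.py`. A minimal sketch of the import paths implied by this hunk (assuming 0.10.1 is installed; whether `nshtrainer.model.MetricConfig` still resolves depends on the `model/__init__.py` changes not shown here):

# Both of these follow directly from the imports added above.
from nshtrainer import MetricConfig
from nshtrainer.metrics import MetricConfig

# In 0.9.1 the top-level re-export came from the model package instead:
# from nshtrainer.model import MetricConfig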

nshtrainer-0.10.1/src/nshtrainer/_checkpoint/loader.py (new file)

@@ -0,0 +1,319 @@
+import logging
+from collections.abc import Iterable, Sequence
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, overload
+
+import nshconfig as C
+from lightning.pytorch import Trainer as LightningTrainer
+from lightning.pytorch.trainer.states import TrainerFn
+from typing_extensions import assert_never
+
+from ..metrics._config import MetricConfig
+from .metadata import METADATA_PATH_SUFFIX, CheckpointMetadata
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class BestCheckpointStrategyConfig(C.Config):
+    name: Literal["best"] = "best"
+
+    metric: MetricConfig | None = None
+    """The metric to use for selecting the best checkpoint. If `None`, the primary metric will be used."""
+
+    additional_candidates: Iterable[Path] = []
+    """Additional checkpoint candidates to consider when selecting the last checkpoint."""
+
+
+class UserProvidedPathCheckpointStrategyConfig(C.Config):
+    name: Literal["user_provided_path"] = "user_provided_path"
+
+    path: Path
+    """The path to the checkpoint to load."""
+
+    on_error: Literal["warn", "raise"] = "warn"
+    """The behavior when the checkpoint does not belong to the current run.
+
+    - `warn`: Log a warning and skip the checkpoint.
+    - `raise`: Raise an error.
+    """
+
+
+class LastCheckpointStrategyConfig(C.Config):
+    name: Literal["last"] = "last"
+
+    criterion: Literal["global_step", "runtime"] = "global_step"
+    """The criterion to use for selecting the last checkpoint.
+
+    - `global_step`: The checkpoint with the highest global step will be selected.
+    - `runtime`: The checkpoint with the highest runtime will be selected.
+    """
+
+    additional_candidates: Iterable[Path] = []
+    """Additional checkpoint candidates to consider when selecting the last checkpoint."""
+
+
+CheckpointLoadingStrategyConfig: TypeAlias = Annotated[
+    BestCheckpointStrategyConfig
+    | LastCheckpointStrategyConfig
+    | UserProvidedPathCheckpointStrategyConfig,
+    C.Field(discriminator="name"),
+]
+
+
+class CheckpointLoadingConfig(C.Config):
+    strategies: Sequence[CheckpointLoadingStrategyConfig]
+    """The strategies to use for loading checkpoints.
+
+    The order of the strategies determines the priority of the strategies.
+    The first strategy that resolves a checkpoint will be used.
+    """
+
+    include_hpc: bool
+    """Whether to include checkpoints from HPC pre-emption."""
+
+    @classmethod
+    def _auto_train(cls, ckpt: Literal["best", "last"] | str | Path | None):
+        if ckpt is None:
+            ckpt = "last"
+        match ckpt:
+            case "best":
+                return cls(
+                    strategies=[BestCheckpointStrategyConfig()],
+                    include_hpc=True,
+                )
+            case "last":
+                return cls(
+                    strategies=[LastCheckpointStrategyConfig()],
+                    include_hpc=True,
+                )
+            case Path() | str():
+                ckpt = Path(ckpt)
+                return cls(
+                    strategies=[
+                        LastCheckpointStrategyConfig(additional_candidates=[ckpt]),
+                        UserProvidedPathCheckpointStrategyConfig(path=ckpt),
+                    ],
+                    include_hpc=True,
+                )
+            case _:
+                assert_never(ckpt)
+
+    @classmethod
+    def _auto_eval(cls, ckpt: Literal["best", "last"] | str | Path | None):
+        if ckpt is None:
+            raise ValueError("Checkpoint path must be provided for evaluation.")
+
+        match ckpt:
+            case "best":
+                return cls(
+                    strategies=[BestCheckpointStrategyConfig()],
+                    include_hpc=False,
+                )
+            case "last":
+                return cls(
+                    strategies=[LastCheckpointStrategyConfig()],
+                    include_hpc=False,
+                )
+            case Path() | str():
+                ckpt = Path(ckpt)
+                return cls(
+                    strategies=[UserProvidedPathCheckpointStrategyConfig(path=ckpt)],
+                    include_hpc=False,
+                )
+            case _:
+                assert_never(ckpt)
+
+    @classmethod
+    def auto(
+        cls,
+        ckpt: Literal["best", "last"] | str | Path | None,
+        trainer_mode: TrainerFn,
+    ):
+        match trainer_mode:
+            case TrainerFn.FITTING:
+                return cls._auto_train(ckpt)
+            case TrainerFn.VALIDATING | TrainerFn.TESTING | TrainerFn.PREDICTING:
+                return cls._auto_eval(ckpt)
+            case _:
+                assert_never(trainer_mode)
+
+
+@dataclass
+class _CkptCandidate:
+    meta: CheckpointMetadata
+    meta_path: Path
+
+    @property
+    def ckpt_path(self):
+        return self.meta_path.with_name(self.meta.checkpoint_filename)
+
+
+@overload
+def _load_ckpt_meta(
+    path: Path,
+    root_config: "BaseConfig",
+    on_error: Literal["warn"] = "warn",
+) -> _CkptCandidate | None: ...
+@overload
+def _load_ckpt_meta(
+    path: Path,
+    root_config: "BaseConfig",
+    on_error: Literal["raise"],
+) -> _CkptCandidate: ...
+def _load_ckpt_meta(
+    path: Path,
+    root_config: "BaseConfig",
+    on_error: Literal["warn", "raise"] = "warn",
+):
+    meta = CheckpointMetadata.from_file(path)
+    if root_config.id != meta.run_id:
+        error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
+        match on_error:
+            case "warn":
+                log.warn(error_msg)
+            case "raise":
+                raise ValueError(error_msg)
+            case _:
+                assert_never(on_error)
+        return None
+    return _CkptCandidate(meta, path)
+
+
+def _checkpoint_candidates(
+    root_config: "BaseConfig",
+    trainer: LightningTrainer,
+    *,
+    include_hpc: bool = True,
+):
+    # Load the checkpoint directory, and throw if it doesn't exist.
+    # This indicates a non-standard setup, and we don't want to guess
+    # where the checkpoints are.
+    ckpt_dir = root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+    if not ckpt_dir.is_dir():
+        raise FileNotFoundError(
+            f"Checkpoint directory {ckpt_dir} not found. "
+            "Please ensure that the checkpoint directory exists."
+        )
+
+    # Load all checkpoints in the directory.
+    # We can do this by looking for metadata files.
+    for path in ckpt_dir.glob(f"*{METADATA_PATH_SUFFIX}"):
+        if (meta := _load_ckpt_meta(path, root_config)) is not None:
+            yield meta
+
+    # If we have a pre-empted checkpoint, load it
+    if include_hpc and (hpc_path := trainer._checkpoint_connector._hpc_resume_path):
+        hpc_meta_path = Path(hpc_path).with_suffix(METADATA_PATH_SUFFIX)
+        if (meta := _load_ckpt_meta(hpc_meta_path, root_config)) is not None:
+            yield meta
+
+
+def _additional_candidates(
+    additional_candidates: Iterable[Path], root_config: "BaseConfig"
+):
+    for path in additional_candidates:
+        if (
+            meta := _load_ckpt_meta(path.with_suffix(METADATA_PATH_SUFFIX), root_config)
+        ) is None:
+            continue
+        yield meta
+
+
+def _resolve_checkpoint(
+    config: CheckpointLoadingConfig,
+    root_config: "BaseConfig",
+    trainer: LightningTrainer,
+):
+    # We lazily load the checkpoint candidates to avoid loading them
+    # if they are not needed.
+    _ckpt_candidates: list[_CkptCandidate] | None = None
+
+    def ckpt_candidates():
+        nonlocal _ckpt_candidates, root_config, trainer
+
+        if _ckpt_candidates is None:
+            _ckpt_candidates = list(
+                _checkpoint_candidates(
+                    root_config, trainer, include_hpc=config.include_hpc
+                )
+            )
+        return _ckpt_candidates
+
+    # Iterate over the strategies and try to resolve the checkpoint.
+    for strategy in config.strategies:
+        match strategy:
+            case UserProvidedPathCheckpointStrategyConfig():
+                meta = _load_ckpt_meta(
+                    strategy.path.with_suffix(METADATA_PATH_SUFFIX),
+                    root_config,
+                    on_error=strategy.on_error,
+                )
+                if meta is None:
+                    continue
+                return meta.ckpt_path
+            case BestCheckpointStrategyConfig():
+                candidates = [
+                    *ckpt_candidates(),
+                    *_additional_candidates(
+                        strategy.additional_candidates, root_config
+                    ),
+                ]
+                if not candidates:
+                    log.warn(
+                        "No checkpoint candidates found for `best` checkpoint strategy."
+                    )
+                    continue
+
+                if (metric := strategy.metric or root_config.primary_metric) is None:
+                    log.warn(
+                        "No metric specified for `best` checkpoint strategy, "
+                        "and no primary metric is set in the configuration. "
+                        "Skipping strategy."
+                    )
+                    continue
+
+                # Find the best checkpoint based on the metric.
+                def metric_value(ckpt: _CkptCandidate):
+                    assert metric is not None
+                    if (
+                        value := ckpt.meta.metrics.get(metric.validation_monitor)
+                    ) is None:
+                        raise ValueError(
+                            f"Metric {metric.validation_monitor} not found in checkpoint metadata. "
+                            f"Available metrics: {ckpt.meta.metrics.keys()}"
+                        )
+                    return value
+
+                best_candidate = metric.best(candidates, key=metric_value)
+                return best_candidate.ckpt_path
+            case LastCheckpointStrategyConfig():
+                candidates = [
+                    *ckpt_candidates(),
+                    *_additional_candidates(
+                        strategy.additional_candidates, root_config
+                    ),
+                ]
+                if not candidates:
+                    log.warn(
+                        "No checkpoint candidates found for `last` checkpoint strategy."
+                    )
+                    continue
+
+                # Find the last checkpoint based on the criterion.
+                def criterion_value(ckpt: _CkptCandidate):
+                    match strategy.criterion:
+                        case "global_step":
+                            return ckpt.meta.global_step
+                        case "runtime":
+                            return ckpt.meta.training_time.total_seconds()
+                        case _:
+                            assert_never(strategy.criterion)
+
+                last_candidate = max(candidates, key=criterion_value)
+                return last_candidate.ckpt_path
+            case _:
+                assert_never(strategy)
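
Note: taken together, these strategies form a small priority list: `_resolve_checkpoint` walks `config.strategies` in order and returns the path from the first strategy that resolves a candidate. A minimal sketch of how `CheckpointLoadingConfig.auto` assembles that list for the common cases, mirroring the code above (the checkpoint path is hypothetical; the import path follows the file layout in the file list):

from pathlib import Path

from lightning.pytorch.trainer.states import TrainerFn

from nshtrainer._checkpoint.loader import CheckpointLoadingConfig

# A user-supplied checkpoint path (hypothetical).
ckpt_path = Path("checkpoints/epoch=4-step=1000.ckpt")

# Training: the path is tried as a "last" candidate first, then as an explicit
# user-provided path, and HPC pre-emption checkpoints are also considered.
train_cfg = CheckpointLoadingConfig.auto(ckpt_path, TrainerFn.FITTING)
# -> strategies = [LastCheckpointStrategyConfig(additional_candidates=[ckpt_path]),
#                  UserProvidedPathCheckpointStrategyConfig(path=ckpt_path)],
#    include_hpc = True

# Evaluation: "best" resolves against the configured (or primary) metric,
# and HPC checkpoints are excluded.
eval_cfg = CheckpointLoadingConfig.auto("best", TrainerFn.VALIDATING)
# -> strategies = [BestCheckpointStrategyConfig()], include_hpc = False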

nshtrainer-0.10.1/src/nshtrainer/_checkpoint/metadata.py (new file)

@@ -0,0 +1,102 @@
+import copy
+import datetime
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, cast
+
+import nshconfig as C
+import numpy as np
+import torch
+
+from ..model._environment import EnvironmentConfig
+
+if TYPE_CHECKING:
+    from ..model import BaseConfig, LightningModuleBase
+    from ..trainer.trainer import Trainer
+
+log = logging.getLogger(__name__)
+
+
+METADATA_PATH_SUFFIX = ".metadata.json"
+HPARAMS_PATH_SUFFIX = ".hparams.json"
+
+
+class CheckpointMetadata(C.Config):
+    checkpoint_path: Path
+    checkpoint_filename: str
+
+    run_id: str
+    name: str
+    project: str | None
+    checkpoint_timestamp: datetime.datetime
+    start_timestamp: datetime.datetime | None
+
+    epoch: int
+    global_step: int
+    training_time: datetime.timedelta
+    metrics: dict[str, Any]
+    environment: EnvironmentConfig
+
+    @classmethod
+    def from_file(cls, path: Path):
+        return cls.model_validate_json(path.read_text())
+
+
+def _generate_checkpoint_metadata(
+    config: "BaseConfig", trainer: "Trainer", checkpoint_path: Path
+):
+    checkpoint_timestamp = datetime.datetime.now()
+    start_timestamp = trainer.start_time()
+    training_time = trainer.time_elapsed()
+
+    metrics: dict[str, Any] = {}
+    for name, metric in copy.deepcopy(trainer.callback_metrics).items():
+        match metric:
+            case torch.Tensor() | np.ndarray():
+                metrics[name] = metric.detach().cpu().item()
+            case _:
+                metrics[name] = metric
+
+    return CheckpointMetadata(
+        checkpoint_path=checkpoint_path,
+        checkpoint_filename=checkpoint_path.name,
+        run_id=config.id,
+        name=config.run_name,
+        project=config.project,
+        checkpoint_timestamp=checkpoint_timestamp,
+        start_timestamp=start_timestamp.datetime
+        if start_timestamp is not None
+        else None,
+        epoch=trainer.current_epoch,
+        global_step=trainer.global_step,
+        training_time=training_time,
+        metrics=metrics,
+        environment=config.environment,
+    )
+
+
+def _write_checkpoint_metadata(
+    trainer: "Trainer",
+    model: "LightningModuleBase",
+    checkpoint_path: Path,
+):
+    config = cast("BaseConfig", model.config)
+    metadata = _generate_checkpoint_metadata(config, trainer, checkpoint_path)
+
+    # Write the metadata to the checkpoint directory
+    try:
+        metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+        metadata_path.write_text(metadata.model_dump_json(indent=4))
+    except Exception as e:
+        log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
+    else:
+        log.info(f"Checkpoint metadata written to {checkpoint_path}")
+
+    # Write the hparams to the checkpoint directory
+    try:
+        hparams_path = checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX)
+        hparams_path.write_text(config.model_dump_json(indent=4))
+    except Exception as e:
+        log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
+    else:
+        log.info(f"Checkpoint metadata written to {checkpoint_path}")
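
Note: every checkpoint written in 0.10.1 therefore gets two sidecar files next to it: a `.metadata.json` with run identity, progress, runtime, and metrics, and a `.hparams.json` with the full config. A minimal sketch of reading the metadata back for an existing checkpoint (the path is hypothetical; the import path follows the file list above):

from pathlib import Path

from nshtrainer._checkpoint.metadata import METADATA_PATH_SUFFIX, CheckpointMetadata

ckpt = Path("checkpoints/epoch=4-step=1000.ckpt")  # hypothetical checkpoint file
meta = CheckpointMetadata.from_file(ckpt.with_suffix(METADATA_PATH_SUFFIX))

# Run identity and training progress recorded at save time.
print(meta.run_id, meta.epoch, meta.global_step)
# Total training time and the callback metrics captured with the checkpoint.
print(meta.training_time.total_seconds(), meta.metrics)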

{nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/__init__.py

@@ -14,15 +14,27 @@ from .interval import EpochIntervalCallback as EpochIntervalCallback
 from .interval import IntervalCallback as IntervalCallback
 from .interval import StepIntervalCallback as StepIntervalCallback
 from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
+from .latest_epoch_checkpoint import (
+    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
+)
 from .log_epoch import LogEpochCallback as LogEpochCallback
+from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
+from .model_checkpoint import (
+    ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
+)
 from .norm_logging import NormLoggingCallback as NormLoggingCallback
 from .norm_logging import NormLoggingConfig as NormLoggingConfig
 from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+from .on_exception_checkpoint import (
+    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+)
 from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
 from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
 from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
 from .timer import EpochTimer as EpochTimer
 from .timer import EpochTimerConfig as EpochTimerConfig
+from .wandb_watch import WandbWatchCallback as WandbWatchCallback
+from .wandb_watch import WandbWatchConfig as WandbWatchConfig
 
 CallbackConfig = Annotated[
     ThroughputMonitorConfig
@@ -31,6 +43,10 @@ CallbackConfig = Annotated[
     | FiniteChecksConfig
     | NormLoggingConfig
     | GradientSkippingConfig
-    | EMAConfig,
+    | EMAConfig
+    | ModelCheckpointCallbackConfig
+    | LatestEpochCheckpointCallbackConfig
+    | OnExceptionCheckpointCallbackConfig
+    | WandbWatchConfig,
     C.Field(discriminator="name"),
 ]

nshtrainer-0.9.1/src/nshtrainer/actsave/_callback.py → nshtrainer-0.10.1/src/nshtrainer/callbacks/actsave.py

@@ -1,28 +1,87 @@
 import contextlib
-from
+from pathlib import Path
+from typing import Literal
 
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.callback import Callback
-from nshutils.actsave import ActSave
 from typing_extensions import TypeAlias, override
 
-
-
+from .base import CallbackConfigBase
+
+try:
+    from nshutils import ActSave  # type: ignore
+except ImportError:
+    ActSave = None
 
 Stage: TypeAlias = Literal["train", "validation", "test", "predict"]
 
 
+class ActSaveConfig(CallbackConfigBase):
+    enabled: bool = True
+    """Enable activation saving."""
+
+    save_dir: Path | None = None
+    """Directory to save activations to. If None, will use the activation directory set in `config.directory`."""
+
+    def __bool__(self):
+        return self.enabled
+
+    @override
+    def create_callbacks(self, root_config):
+        yield ActSaveCallback(
+            self,
+            self.save_dir
+            or root_config.directory.resolve_subdirectory(root_config.id, "activation"),
+        )
+
+
 class ActSaveCallback(Callback):
-    def __init__(self):
+    def __init__(self, config: ActSaveConfig, save_dir: Path):
         super().__init__()
 
+        self.config = config
+        self.save_dir = save_dir
+        self._enabled_context: contextlib._GeneratorContextManager | None = None
         self._active_contexts: dict[Stage, contextlib._GeneratorContextManager] = {}
 
+    @override
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+        super().setup(trainer, pl_module, stage)
+
+        if not self.config:
+            return
+
+        if ActSave is None:
+            raise ImportError(
+                "ActSave is not installed. Please install nshutils to use the ActSaveCallback."
+            )
+
+        context = ActSave.enabled(self.save_dir)
+        context.__enter__()
+        self._enabled_context = context
+
+    @override
+    def teardown(
+        self, trainer: Trainer, pl_module: LightningModule, stage: str
+    ) -> None:
+        super().teardown(trainer, pl_module, stage)
+
+        if not self.config:
+            return
+
+        if self._enabled_context is not None:
+            self._enabled_context.__exit__(None, None, None)
+            self._enabled_context = None
+
     def _on_start(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
-
-        if not hparams.trainer.actsave:
+        if not self.config:
             return
 
+        if ActSave is None:
+            raise ImportError(
+                "ActSave is not installed. Please install nshutils to use the ActSaveCallback."
+            )
+
         # If we have an active context manager for this stage, exit it
         if active_contexts := self._active_contexts.get(stage):
             active_contexts.__exit__(None, None, None)
@@ -33,12 +92,11 @@ class ActSaveCallback(Callback):
         self._active_contexts[stage] = context
 
     def _on_end(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
-
-        if not hparams.trainer.actsave:
+        if not self.config:
             return
 
         # If we have an active context manager for this stage, exit it
-        if active_contexts := self._active_contexts.
+        if (active_contexts := self._active_contexts.pop(stage, None)) is not None:
             active_contexts.__exit__(None, None, None)
 
     @override
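
Note: the ActSave callback is now driven by a proper `ActSaveConfig` instead of the old ad-hoc `hparams.trainer.actsave` lookup, and `nshutils` becomes an optional import. A minimal sketch of the config's gating behavior, using only names defined in the hunk above (the save directory is hypothetical; base-class fields inherited from `CallbackConfigBase` are assumed to have defaults):

from pathlib import Path

from nshtrainer.callbacks.actsave import ActSaveConfig

# Enabled config with an explicit save directory; when save_dir is None,
# create_callbacks falls back to the run's "activation" subdirectory.
actsave = ActSaveConfig(save_dir=Path("./activations"))
assert actsave  # __bool__ mirrors `enabled`, so the callback gates on the config itself

# A disabled config keeps the callback inert without removing it from the config tree.
assert not ActSaveConfig(enabled=False)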

{nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/base.py

@@ -46,16 +46,16 @@ class CallbackConfigBase(C.Config, ABC):
     )
 
     @abstractmethod
-    def
+    def create_callbacks(
         self, root_config: "BaseConfig"
     ) -> Iterable[Callback | CallbackWithMetadata]: ...
 
 
 # region Config resolution helpers
-def
+def _create_callbacks_with_metadata(
     config: CallbackConfigBase, root_config: "BaseConfig"
 ) -> Iterable[CallbackWithMetadata]:
-    for callback in config.
+    for callback in config.create_callbacks(root_config):
         if isinstance(callback, CallbackWithMetadata):
             yield callback
             continue
@@ -99,12 +99,14 @@ def _process_and_filter_callbacks(
 
 def resolve_all_callbacks(root_config: "BaseConfig"):
     callback_configs = [
-        config
+        config
+        for config in root_config._nshtrainer_all_callback_configs()
+        if config is not None
     ]
     callbacks = _process_and_filter_callbacks(
         callback
         for callback_config in callback_configs
-        for callback in
+        for callback in _create_callbacks_with_metadata(callback_config, root_config)
     )
     return callbacks
 
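
Note: the abstract hook on `CallbackConfigBase` is now spelled `create_callbacks`, and `resolve_all_callbacks` gathers every non-`None` config returned by `root_config._nshtrainer_all_callback_configs()`. A minimal sketch of a user-defined callback config under the renamed hook; `MyCallback` and `MyCallbackConfig` are hypothetical, only the `create_callbacks` signature comes from `base.py`:

from typing import Literal

from lightning.pytorch.callbacks.callback import Callback

from nshtrainer.callbacks.base import CallbackConfigBase


class MyCallback(Callback):
    """Hypothetical callback used only for illustration."""


class MyCallbackConfig(CallbackConfigBase):
    name: Literal["my_callback"] = "my_callback"

    def create_callbacks(self, root_config):
        # Yield zero or more callbacks; resolve_all_callbacks collects them from
        # every non-None callback config the root config reports.
        yield MyCallback()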

{nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/ema.py

@@ -374,7 +374,7 @@ class EMAConfig(CallbackConfigBase):
     """Offload weights to CPU."""
 
     @override
-    def
+    def create_callbacks(self, root_config):
         yield EMA(
             decay=self.decay,
             validate_original_weights=self.validate_original_weights,

{nshtrainer-0.9.1 → nshtrainer-0.10.1}/src/nshtrainer/callbacks/finite_checks.py

@@ -68,7 +68,7 @@ class FiniteChecksConfig(CallbackConfigBase):
     """Whether to check for None gradients"""
 
     @override
-    def
+    def create_callbacks(self, root_config):
         yield FiniteChecksCallback(
             nonfinite_grads=self.nonfinite_grads,
             none_grads=self.none_grads,