nshtrainer 0.8.7__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -1
- nshtrainer/callbacks/__init__.py +17 -1
- nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
- nshtrainer/callbacks/base.py +7 -5
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
- nshtrainer/callbacks/model_checkpoint.py +187 -0
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/throughput_monitor.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/ll/__init__.py +0 -1
- nshtrainer/ll/actsave.py +2 -1
- nshtrainer/metrics/__init__.py +1 -0
- nshtrainer/metrics/_config.py +37 -0
- nshtrainer/model/__init__.py +11 -11
- nshtrainer/model/_environment.py +777 -0
- nshtrainer/model/base.py +5 -114
- nshtrainer/model/config.py +92 -507
- nshtrainer/model/modules/logger.py +11 -6
- nshtrainer/runner.py +3 -6
- nshtrainer/trainer/_checkpoint_metadata.py +102 -0
- nshtrainer/trainer/_checkpoint_resolver.py +319 -0
- nshtrainer/trainer/_runtime_callback.py +120 -0
- nshtrainer/trainer/checkpoint_connector.py +63 -0
- nshtrainer/trainer/signal_connector.py +12 -9
- nshtrainer/trainer/trainer.py +111 -31
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
- nshtrainer/actsave/__init__.py +0 -3
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py
CHANGED
@@ -2,13 +2,14 @@ from . import _experimental as _experimental
 from . import callbacks as callbacks
 from . import data as data
 from . import lr_scheduler as lr_scheduler
+from . import metrics as metrics
 from . import model as model
 from . import nn as nn
 from . import optimizer as optimizer
+from .metrics import MetricConfig as MetricConfig
 from .model import Base as Base
 from .model import BaseConfig as BaseConfig
 from .model import ConfigList as ConfigList
 from .model import LightningModuleBase as LightningModuleBase
-from .model import MetricConfig as MetricConfig
 from .runner import Runner as Runner
 from .trainer import Trainer as Trainer
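The practical upshot of this hunk: `MetricConfig` moves from `nshtrainer.model` to the new `nshtrainer.metrics` module, while the top-level re-export is preserved. A quick sketch of what that means for imports (the commented-out line is an assumption that no compatibility alias was left behind in `nshtrainer.model`):

```python
# Still valid in 0.10.0: the top-level re-export is kept.
from nshtrainer import MetricConfig

# New canonical location as of 0.10.0.
from nshtrainer.metrics import MetricConfig

# Worked in 0.8.7; presumably raises ImportError in 0.10.0
# (assumption: nshtrainer.model no longer aliases it).
# from nshtrainer.model import MetricConfig
```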
nshtrainer/callbacks/__init__.py
CHANGED
@@ -14,15 +14,27 @@ from .interval import EpochIntervalCallback as EpochIntervalCallback
 from .interval import IntervalCallback as IntervalCallback
 from .interval import StepIntervalCallback as StepIntervalCallback
 from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
+from .latest_epoch_checkpoint import (
+    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
+)
 from .log_epoch import LogEpochCallback as LogEpochCallback
+from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
+from .model_checkpoint import (
+    ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
+)
 from .norm_logging import NormLoggingCallback as NormLoggingCallback
 from .norm_logging import NormLoggingConfig as NormLoggingConfig
 from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+from .on_exception_checkpoint import (
+    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+)
 from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
 from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
 from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
 from .timer import EpochTimer as EpochTimer
 from .timer import EpochTimerConfig as EpochTimerConfig
+from .wandb_watch import WandbWatchCallback as WandbWatchCallback
+from .wandb_watch import WandbWatchConfig as WandbWatchConfig
 
 CallbackConfig = Annotated[
     ThroughputMonitorConfig
@@ -31,6 +43,10 @@ CallbackConfig = Annotated[
     | FiniteChecksConfig
     | NormLoggingConfig
     | GradientSkippingConfig
-    | EMAConfig,
+    | EMAConfig
+    | ModelCheckpointCallbackConfig
+    | LatestEpochCheckpointCallbackConfig
+    | OnExceptionCheckpointCallbackConfig
+    | WandbWatchConfig,
     C.Field(discriminator="name"),
 ]
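`CallbackConfig` is a tagged union: a serialized config dict is routed to the concrete config class by its discriminator field alone. Note that the union discriminates on `"name"` while the new config classes added in this release declare a `kind: Literal[...]` tag; the sketch below is a generic pydantic-style illustration of the pattern with a single consistent tag field, not nshtrainer's actual `C.Config` machinery:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


# Hypothetical stand-ins for two members of the union.
class EMAConfig(BaseModel):
    name: Literal["ema"] = "ema"
    decay: float = 0.999


class FiniteChecksConfig(BaseModel):
    name: Literal["finite_checks"] = "finite_checks"
    nonfinite_grads: bool = True


CallbackConfig = Annotated[
    Union[EMAConfig, FiniteChecksConfig],
    Field(discriminator="name"),
]

# The discriminator value alone selects the concrete class.
adapter = TypeAdapter(CallbackConfig)
config = adapter.validate_python({"name": "ema", "decay": 0.995})
assert isinstance(config, EMAConfig)
```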
nshtrainer/{actsave/_callback.py → callbacks/actsave.py}
RENAMED
@@ -1,28 +1,87 @@
 import contextlib
-from
+from pathlib import Path
+from typing import Literal
 
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.callback import Callback
-from nshutils.actsave import ActSave
 from typing_extensions import TypeAlias, override
 
-
-
+from .base import CallbackConfigBase
+
+try:
+    from nshutils import ActSave  # type: ignore
+except ImportError:
+    ActSave = None
 
 Stage: TypeAlias = Literal["train", "validation", "test", "predict"]
 
 
+class ActSaveConfig(CallbackConfigBase):
+    enabled: bool = True
+    """Enable activation saving."""
+
+    save_dir: Path | None = None
+    """Directory to save activations to. If None, will use the activation directory set in `config.directory`."""
+
+    def __bool__(self):
+        return self.enabled
+
+    @override
+    def create_callbacks(self, root_config):
+        yield ActSaveCallback(
+            self,
+            self.save_dir
+            or root_config.directory.resolve_subdirectory(root_config.id, "activation"),
+        )
+
+
 class ActSaveCallback(Callback):
-    def __init__(self):
+    def __init__(self, config: ActSaveConfig, save_dir: Path):
         super().__init__()
 
+        self.config = config
+        self.save_dir = save_dir
+        self._enabled_context: contextlib._GeneratorContextManager | None = None
         self._active_contexts: dict[Stage, contextlib._GeneratorContextManager] = {}
 
+    @override
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+        super().setup(trainer, pl_module, stage)
+
+        if not self.config:
+            return
+
+        if ActSave is None:
+            raise ImportError(
+                "ActSave is not installed. Please install nshutils to use the ActSaveCallback."
+            )
+
+        context = ActSave.enabled(self.save_dir)
+        context.__enter__()
+        self._enabled_context = context
+
+    @override
+    def teardown(
+        self, trainer: Trainer, pl_module: LightningModule, stage: str
+    ) -> None:
+        super().teardown(trainer, pl_module, stage)
+
+        if not self.config:
+            return
+
+        if self._enabled_context is not None:
+            self._enabled_context.__exit__(None, None, None)
+            self._enabled_context = None
+
     def _on_start(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
-
-        if not hparams.trainer.actsave:
+        if not self.config:
             return
 
+        if ActSave is None:
+            raise ImportError(
+                "ActSave is not installed. Please install nshutils to use the ActSaveCallback."
+            )
+
         # If we have an active context manager for this stage, exit it
         if active_contexts := self._active_contexts.get(stage):
             active_contexts.__exit__(None, None, None)
@@ -33,12 +92,11 @@ class ActSaveCallback(Callback):
             self._active_contexts[stage] = context
 
     def _on_end(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
-
-        if not hparams.trainer.actsave:
+        if not self.config:
             return
 
         # If we have an active context manager for this stage, exit it
-        if active_contexts := self._active_contexts.
+        if (active_contexts := self._active_contexts.pop(stage, None)) is not None:
            active_contexts.__exit__(None, None, None)
 
     @override
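The rename turns the ActSave integration into a regular config-driven callback with a guarded optional import: when `nshutils` is absent, `ActSave` is `None` and the callback raises a clear `ImportError` only if actually enabled. A minimal usage sketch, assuming the module path from the rename and no re-export elsewhere:

```python
from pathlib import Path

from nshtrainer.callbacks.actsave import ActSaveConfig

# Explicit directory: activations are saved under ./activations.
config = ActSaveConfig(enabled=True, save_dir=Path("./activations"))

# ActSaveConfig is truthy only when enabled; this is exactly the
# check behind the callback's `if not self.config: return` guards.
assert bool(config)
assert not ActSaveConfig(enabled=False)

# With save_dir=None, create_callbacks() instead falls back to
# root_config.directory.resolve_subdirectory(root_config.id, "activation").
```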
nshtrainer/callbacks/base.py
CHANGED
@@ -46,16 +46,16 @@ class CallbackConfigBase(C.Config, ABC):
     )
 
     @abstractmethod
-    def
+    def create_callbacks(
         self, root_config: "BaseConfig"
     ) -> Iterable[Callback | CallbackWithMetadata]: ...
 
 
 # region Config resolution helpers
-def
+def _create_callbacks_with_metadata(
     config: CallbackConfigBase, root_config: "BaseConfig"
 ) -> Iterable[CallbackWithMetadata]:
-    for callback in config.
+    for callback in config.create_callbacks(root_config):
         if isinstance(callback, CallbackWithMetadata):
             yield callback
             continue
@@ -99,12 +99,14 @@ def _process_and_filter_callbacks(
 
 def resolve_all_callbacks(root_config: "BaseConfig"):
     callback_configs = [
-        config
+        config
+        for config in root_config._nshtrainer_all_callback_configs()
+        if config is not None
     ]
     callbacks = _process_and_filter_callbacks(
         callback
         for callback_config in callback_configs
-        for callback in
+        for callback in _create_callbacks_with_metadata(callback_config, root_config)
     )
     return callbacks
 
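Together these helpers define the contract a callback config has to satisfy: implement `create_callbacks` as a generator, and `resolve_all_callbacks` will collect the yielded callbacks from every non-`None` config on the root config. A minimal sketch of a custom config under that contract (the `PrintGreeting*` names are hypothetical; registering the config in the `CallbackConfig` union is a separate step):

```python
from typing import Literal

from lightning.pytorch.callbacks import Callback

from nshtrainer.callbacks.base import CallbackConfigBase


class PrintGreetingCallback(Callback):
    def __init__(self, message: str):
        super().__init__()
        self.message = message

    def on_fit_start(self, trainer, pl_module) -> None:
        print(self.message)


class PrintGreetingConfig(CallbackConfigBase):
    kind: Literal["print_greeting"] = "print_greeting"
    message: str = "hello"

    def create_callbacks(self, root_config):
        # Yield rather than return: one config may produce several
        # callbacks, and resolve_all_callbacks flattens them all.
        yield PrintGreetingCallback(self.message)
```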
nshtrainer/callbacks/ema.py
CHANGED
@@ -374,7 +374,7 @@ class EMAConfig(CallbackConfigBase):
     """Offload weights to CPU."""
 
     @override
-    def
+    def create_callbacks(self, root_config):
         yield EMA(
             decay=self.decay,
             validate_original_weights=self.validate_original_weights,
nshtrainer/callbacks/finite_checks.py
CHANGED
@@ -68,7 +68,7 @@ class FiniteChecksConfig(CallbackConfigBase):
     """Whether to check for None gradients"""
 
     @override
-    def
+    def create_callbacks(self, root_config):
         yield FiniteChecksCallback(
             nonfinite_grads=self.nonfinite_grads,
             none_grads=self.none_grads,
nshtrainer/callbacks/latest_epoch_checkpoint.py
CHANGED
@@ -1,35 +1,54 @@
 import logging
 from pathlib import Path
+from typing import Literal
 
-from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Checkpoint
 from typing_extensions import override
 
+from .base import CallbackConfigBase
+
 log = logging.getLogger(__name__)
 
 
+class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
+    kind: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
+
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
+    """Checkpoint filename. This must not include the extension."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    latest_symlink_filename: str | None = "latest.ckpt"
+    """Filename for the latest symlink. If None, no symlink will be created."""
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+        dirpath = Path(dirpath)
+
+        yield LatestEpochCheckpoint(self, dirpath)
+
+
 class LatestEpochCheckpoint(Checkpoint):
-
-
-    def __init__(
-        self,
-        dirpath: _PATH,
-        filename: str | None = None,
-        save_weights_only: bool = False,
-    ):
+    def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
         super().__init__()
 
-        self.
-        self.
-        self._save_weights_only = save_weights_only
+        self.config = config
+        self.dirpath = dirpath
 
         # Also, we hold a reference to the last checkpoint path
         # to be able to remove it when a new checkpoint is saved.
         self._last_ckpt_path: Path | None = None
 
     def _ckpt_path(self, trainer: Trainer):
-        return self.
+        return self.dirpath / self.config.filename.format(
             epoch=trainer.current_epoch, step=trainer.global_step
         )
 
@@ -41,5 +60,22 @@ class LatestEpochCheckpoint(Checkpoint):
 
         # Save the new checkpoint
         filepath = self._ckpt_path(trainer)
-        trainer.save_checkpoint(filepath, self.
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)
         self._last_ckpt_path = filepath
+
+        # Create the latest symlink
+        if (symlink_filename := self.config.latest_symlink_filename) is not None:
+            symlink_path = self.dirpath / symlink_filename
+            if symlink_path.exists():
+                symlink_path.unlink()
+            symlink_path.symlink_to(filepath.name)
+            log.info(f"Created latest symlink: {symlink_path}")
+
+    def latest_checkpoint(self):
+        if (symlink_filename := self.config.latest_symlink_filename) is None:
+            return None
+
+        if not (symlink_path := self.dirpath / symlink_filename).exists():
+            return None
+
+        return symlink_path
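With the defaults above, every save writes a file named from the `filename` template and repoints a relative `latest.ckpt` symlink at it. The naming is plain `str.format`, so the behavior can be checked without a trainer:

```python
from pathlib import Path

filename = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"

# What LatestEpochCheckpoint._ckpt_path computes for epoch 7, step 1530:
print(Path("checkpoints") / filename.format(epoch=7, step=1530))
# checkpoints/latest_epoch07_step1530.ckpt
```

Because the symlink targets `filepath.name` (a relative path), the checkpoint directory can be moved or copied without breaking `latest.ckpt`.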
nshtrainer/callbacks/model_checkpoint.py
ADDED
@@ -0,0 +1,187 @@
+import re
+from datetime import timedelta
+from logging import getLogger
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal
+
+from lightning.pytorch.callbacks.model_checkpoint import (
+    ModelCheckpoint as _ModelCheckpoint,
+)
+from typing_extensions import override
+
+from ..metrics import MetricConfig
+from .base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
+log = getLogger(__name__)
+
+
+def _convert_string(input_string: str):
+    # Find all variables enclosed in curly braces
+    variables = re.findall(r"\{(.*?)\}", input_string)
+
+    # Replace each variable with its corresponding key-value pair
+    output_string = input_string
+    for variable in variables:
+        # If the name is something like {variable:format}, we shouldn't process the format.
+        key_name = variable
+        if ":" in variable:
+            key_name, _ = variable.split(":", 1)
+            continue
+
+        # Replace '/' with '_' in the key name
+        key_name = key_name.replace("/", "_")
+        output_string = output_string.replace(
+            f"{{{variable}}}", f"{key_name}={{{variable}}}"
+        )
+
+    return output_string
+
+
+class ModelCheckpointCallbackConfig(CallbackConfigBase):
+    """Arguments for the ModelCheckpoint callback."""
+
+    kind: Literal["model_checkpoint"] = "model_checkpoint"
+
+    dirpath: str | Path | None = None
+    """
+    Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
+    """
+
+    filename: str | None = None
+    """
+    Checkpoint filename.
+    If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
+    """
+
+    metric: MetricConfig | None = None
+    """
+    Metric to monitor for saving checkpoints.
+    If None, the primary metric of the runner will be used, if available.
+    """
+
+    verbose: bool = False
+    """Verbosity mode. If True, print additional information about checkpoints."""
+
+    save_last: Literal[True, False, "link"] | None = "link"
+    """
+    Whether to save the last checkpoint.
+    If True, saves a copy of the last checkpoint separately.
+    If "link", creates a symbolic link to the last checkpoint.
+    """
+
+    save_top_k: int = 1
+    """
+    Number of best models to save.
+    If -1, all models are saved.
+    If 0, no models are saved.
+    """
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    auto_insert_metric_name: bool = True
+    """Whether to automatically insert the metric name in the checkpoint filename."""
+
+    every_n_train_steps: int | None = None
+    """
+    Number of training steps between checkpoints.
+    If None or 0, no checkpoints are saved during training.
+    """
+
+    train_time_interval: timedelta | None = None
+    """
+    Time interval between checkpoints during training.
+    If None, no checkpoints are saved during training based on time.
+    """
+
+    every_n_epochs: int | None = None
+    """
+    Number of epochs between checkpoints.
+    If None or 0, no checkpoints are saved at the end of epochs.
+    """
+
+    save_on_train_epoch_end: bool | None = None
+    """
+    Whether to run checkpointing at the end of the training epoch.
+    If False, checkpointing runs at the end of the validation.
+    """
+
+    enable_version_counter: bool = True
+    """Whether to append a version to the existing file name."""
+
+    auto_append_metric: bool = True
+    """If enabled, this will automatically add "-{monitor}" to the filename."""
+
+    def metric_or_default(self, root_config: "BaseConfig"):
+        if self.metric is not None:
+            return self.metric
+        if root_config.primary_metric is not None:
+            return root_config.primary_metric
+        raise ValueError("Primary metric must be provided if metric is not specified.")
+
+    def resolve_filename(self, root_config: "BaseConfig"):
+        metric = self.metric_or_default(root_config)
+
+        filename = self.filename
+        if not filename:
+            filename = "{epoch}-{step}"
+            if self.auto_append_metric:
+                filename = f"{filename}-{{{metric.validation_monitor}}}"
+
+        if self.auto_insert_metric_name and filename:
+            new_filename = _convert_string(filename)
+            log.critical(
+                f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
+            )
+            filename = new_filename
+
+        return filename
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+
+        metric = self.metric_or_default(root_config)
+        filename = self.resolve_filename(root_config)
+
+        yield ModelCheckpoint(
+            self,
+            dirpath=Path(dirpath),
+            filename=filename,
+            metric=metric,
+        )
+
+
+class ModelCheckpoint(_ModelCheckpoint):
+    @override
+    def __init__(
+        self,
+        config: ModelCheckpointCallbackConfig,
+        dirpath: Path,
+        filename: str,
+        metric: MetricConfig,
+    ):
+        self.config = config
+        del config
+
+        super().__init__(
+            dirpath=dirpath,
+            filename=filename,
+            monitor=metric.validation_monitor,
+            mode=metric.mode,
+            verbose=self.config.verbose,
+            save_last=self.config.save_last,
+            save_top_k=self.config.save_top_k,
+            save_weights_only=self.config.save_weights_only,
+            auto_insert_metric_name=False,
+            every_n_train_steps=self.config.every_n_train_steps,
+            train_time_interval=self.config.train_time_interval,
+            every_n_epochs=self.config.every_n_epochs,
+            save_on_train_epoch_end=self.config.save_on_train_epoch_end,
+            enable_version_counter=self.config.enable_version_counter,
+        )
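`_convert_string` is what `auto_insert_metric_name` applies to the filename template: a bare `{var}` placeholder gains a `key=` prefix (with `/` mapped to `_` in the key), while a placeholder carrying a format spec hits the early `continue` and is left untouched. A worked example of the function as written:

```python
from nshtrainer.callbacks.model_checkpoint import _convert_string

print(_convert_string("{epoch}-{step}-{val/loss}"))
# epoch={epoch}-step={step}-val_loss={val/loss}

# Placeholders with a format spec are skipped entirely:
print(_convert_string("{epoch}-{val/loss:.2f}"))
# epoch={epoch}-{val/loss:.2f}
```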
nshtrainer/callbacks/on_exception_checkpoint.py
CHANGED
@@ -1,16 +1,82 @@
+import contextlib
 import datetime
 import logging
 import os
-from
+from pathlib import Path
+from typing import Any, Literal
 
-from lightning.pytorch import Trainer
+from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import OnExceptionCheckpoint as _OnExceptionCheckpoint
 from typing_extensions import override
 
+from .base import CallbackConfigBase
+
 log = logging.getLogger(__name__)
 
 
+@contextlib.contextmanager
+def _monkey_patch_disable_barrier(trainer: LightningTrainer):
+    """
+    Monkey-patch the strategy instance to make the barrier operation a no-op.
+
+    We do this because `save_checkpoint` calls `barrier`. This is okay in most
+    cases, but when we want to save a checkpoint in the case of an exception,
+    `barrier` causes a deadlock. So we monkey-patch the strategy instance to
+    make the barrier operation a no-op.
+    """
+
+    # We monkey-patch the barrier method to do nothing.
+    original_barrier = trainer.strategy.barrier
+
+    def new_barrier(*args, **kwargs):
+        log.warning("Monkey-patched no-op barrier.")
+        pass
+
+    trainer.strategy.barrier = new_barrier
+    log.warning("Monkey-patched barrier to no-op.")
+
+    try:
+        yield
+    finally:
+        trainer.strategy.barrier = original_barrier
+        log.warning("Reverted monkey-patched barrier.")
+
+
+class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
+    kind: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"
+
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str | None = None
+    """Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""
+
+    @override
+    def create_callbacks(self, root_config):
+        from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint
+
+        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+
+        if not (filename := self.filename):
+            filename = f"on_exception_{root_config.id}"
+        yield OnExceptionCheckpoint(self, dirpath=Path(dirpath), filename=filename)
+
+
 class OnExceptionCheckpoint(_OnExceptionCheckpoint):
+    @override
+    def __init__(
+        self,
+        config: OnExceptionCheckpointCallbackConfig,
+        dirpath: Path,
+        filename: str,
+    ):
+        self.config = config
+        del config
+
+        super().__init__(dirpath, filename)
+
     @property
     @override
     def ckpt_path(self) -> str:
@@ -22,23 +88,11 @@ class OnExceptionCheckpoint(_OnExceptionCheckpoint):
         return f"{ckpt_path}_{timestamp}{ext}"
 
     @override
-    def on_exception(self, trainer:
-        #
-        #
-
-        #
-
-
-
-            "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
-            " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
-        )
-        checkpoint = trainer._checkpoint_connector.dump_checkpoint(weights_only=False)
-        trainer.strategy.save_checkpoint(
-            checkpoint, self.ckpt_path, storage_options=None
-        )
-        # self.strategy.barrier("Trainer.save_checkpoint") # <-- This is disabled
-
-    @override
-    def teardown(self, trainer: Trainer, *_: Any, **__: Any) -> None:
-        trainer.strategy.remove_checkpoint(self.ckpt_path)
+    def on_exception(self, trainer: LightningTrainer, *args: Any, **kwargs: Any):
+        # Monkey-patch the strategy instance to make the barrier operation a no-op.
+        # We do this because `save_checkpoint` calls `barrier`. This is okay in most
+        # cases, but when we want to save a checkpoint in the case of an exception,
+        # `barrier` causes a deadlock. So we monkey-patch the strategy instance to
+        # make the barrier operation a no-op.
+        with _monkey_patch_disable_barrier(trainer):
+            return super().on_exception(trainer, *args, **kwargs)
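The rewrite replaces the previous copy of Lightning's `save_checkpoint` body with a context manager that swaps `trainer.strategy.barrier` out for the duration of the call and restores it afterwards, even on error. The same pattern in isolation, as a generic runnable sketch with a stand-in strategy object (not nshtrainer code):

```python
import contextlib


class Strategy:
    def barrier(self) -> None:
        print("real barrier: would block until all ranks arrive")


@contextlib.contextmanager
def disable_barrier(strategy: Strategy):
    # Swap the method out; the finally block restores it even if
    # the body raises, mirroring _monkey_patch_disable_barrier.
    original = strategy.barrier
    strategy.barrier = lambda *args, **kwargs: None
    try:
        yield
    finally:
        strategy.barrier = original


strategy = Strategy()
with disable_barrier(strategy):
    strategy.barrier()  # no-op inside the context
strategy.barrier()  # real barrier again
```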
nshtrainer/callbacks/print_table.py
CHANGED
@@ -86,5 +86,5 @@ class PrintTableMetricsConfig(CallbackConfigBase):
     """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
 
     @override
-    def
+    def create_callbacks(self, root_config):
         yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
nshtrainer/callbacks/throughput_monitor.py
CHANGED
@@ -52,5 +52,5 @@ class ThroughputMonitorConfig(CallbackConfigBase):
     """Number of batches to use for a rolling average."""
 
     @override
-    def
+    def create_callbacks(self, root_config):
         yield ThroughputMonitor(window_size=self.window_size)
nshtrainer/callbacks/timer.py
CHANGED
nshtrainer/ll/__init__.py
CHANGED
@@ -21,7 +21,6 @@ from .log import init_python_logging as init_python_logging
 from .log import lovely as lovely
 from .log import pretty as pretty
 from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
-from .model import ActSaveConfig as ActSaveConfig
 from .model import Base as Base
 from .model import BaseConfig as BaseConfig
 from .model import BaseLoggerConfig as BaseLoggerConfig