nshtrainer 0.11.8__py3-none-any.whl → 0.11.9__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
- nshtrainer/callbacks/__init__.py +0 -8
- nshtrainer/callbacks/checkpoint/__init__.py +0 -8
- nshtrainer/model/__init__.py +2 -4
- nshtrainer/model/config.py +1 -36
- {nshtrainer-0.11.8.dist-info → nshtrainer-0.11.9.dist-info}/METADATA +1 -1
- {nshtrainer-0.11.8.dist-info → nshtrainer-0.11.9.dist-info}/RECORD +7 -9
- nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +0 -131
- nshtrainer/callbacks/checkpoint/model_checkpoint.py +0 -207
- {nshtrainer-0.11.8.dist-info → nshtrainer-0.11.9.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/__init__.py
CHANGED
@@ -8,12 +8,6 @@ from .checkpoint import BestCheckpoint as BestCheckpoint
 from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
 from .checkpoint import LastCheckpoint as LastCheckpoint
 from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
-from .checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
-from .checkpoint import (
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
-from .checkpoint import ModelCheckpoint as ModelCheckpoint
-from .checkpoint import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
 from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
 from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
@@ -49,8 +43,6 @@ CallbackConfig = Annotated[
     | EMAConfig
     | BestCheckpointCallbackConfig
     | LastCheckpointCallbackConfig
-    | ModelCheckpointCallbackConfig
-    | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig
     | WandbWatchConfig,
     C.Field(discriminator="name"),
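These removals mean that importing the ModelCheckpoint or LatestEpochCheckpoint symbols from nshtrainer.callbacks now fails. A minimal migration sketch, based only on the exports visible in this diff:

# Removed in 0.11.9; these imports now raise ImportError:
# from nshtrainer.callbacks import ModelCheckpoint, ModelCheckpointCallbackConfig
# from nshtrainer.callbacks import LatestEpochCheckpoint, LatestEpochCheckpointCallbackConfig

# Still exported in 0.11.9:
from nshtrainer.callbacks import (
    BestCheckpointCallbackConfig,
    LastCheckpointCallbackConfig,
    OnExceptionCheckpointCallbackConfig,
)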
nshtrainer/callbacks/checkpoint/__init__.py
CHANGED
@@ -6,14 +6,6 @@ from .last_checkpoint import LastCheckpoint as LastCheckpoint
 from .last_checkpoint import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
-from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
-from .latest_epoch_checkpoint import (
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
-from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
-from .model_checkpoint import (
-    ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
-)
 from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
 from .on_exception_checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
nshtrainer/model/__init__.py
CHANGED
@@ -5,17 +5,15 @@ from .base import LightningModuleBase as LightningModuleBase
 from .config import BaseConfig as BaseConfig
 from .config import BaseLoggerConfig as BaseLoggerConfig
 from .config import BaseProfilerConfig as BaseProfilerConfig
+from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
 from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
 from .config import CheckpointSavingConfig as CheckpointSavingConfig
 from .config import DirectoryConfig as DirectoryConfig
 from .config import EarlyStoppingConfig as EarlyStoppingConfig
 from .config import GradientClippingConfig as GradientClippingConfig
-from .config import (
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
+from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
 from .config import LoggingConfig as LoggingConfig
 from .config import MetricConfig as MetricConfig
-from .config import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
 from .config import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
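Net effect on the nshtrainer.model re-exports, derived from the hunk above:

# Newly re-exported in 0.11.9:
from nshtrainer.model import BestCheckpointCallbackConfig
from nshtrainer.model import LastCheckpointCallbackConfig

# No longer importable from nshtrainer.model:
# from nshtrainer.model import LatestEpochCheckpointCallbackConfig
# from nshtrainer.model import ModelCheckpointCallbackConfig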
nshtrainer/model/config.py
CHANGED
@@ -40,8 +40,6 @@ from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
     LastCheckpointCallbackConfig,
-    LatestEpochCheckpointCallbackConfig,
-    ModelCheckpointCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
     WandbWatchConfig,
 )
@@ -772,10 +770,8 @@ class ReproducibilityConfig(C.Config):
 
 
 CheckpointCallbackConfig: TypeAlias = Annotated[
-    ModelCheckpointCallbackConfig
-    | BestCheckpointCallbackConfig
+    BestCheckpointCallbackConfig
     | LastCheckpointCallbackConfig
-    | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig,
     C.Field(discriminator="name"),
 ]
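With the narrowed CheckpointCallbackConfig union, a checkpoint-saving configuration in 0.11.9 admits exactly three discriminated variants. A sketch, using only names this diff shows are re-exported from nshtrainer.model:

from nshtrainer.model import (
    BestCheckpointCallbackConfig,
    CheckpointSavingConfig,
    LastCheckpointCallbackConfig,
    OnExceptionCheckpointCallbackConfig,
)

# Mirrors the new default of CheckpointSavingConfig.checkpoint_callbacks
# (see the hunk below); ModelCheckpointCallbackConfig and
# LatestEpochCheckpointCallbackConfig are no longer valid entries.
saving = CheckpointSavingConfig(
    checkpoint_callbacks=[
        BestCheckpointCallbackConfig(),
        LastCheckpointCallbackConfig(),
        OnExceptionCheckpointCallbackConfig(),
    ]
)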
@@ -786,7 +782,6 @@ class CheckpointSavingConfig(CallbackConfigBase):
     """Enable checkpoint saving."""
 
     checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
-        # ModelCheckpointCallbackConfig(),
         BestCheckpointCallbackConfig(),
         LastCheckpointCallbackConfig(),
         OnExceptionCheckpointCallbackConfig(),
@@ -806,36 +801,6 @@ class CheckpointSavingConfig(CallbackConfigBase):
 
         return True
 
-    @property
-    def model_checkpoint(self) -> ModelCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, ModelCheckpointCallbackConfig)
-            ),
-        )
-
-    @property
-    def latest_epoch_checkpoint(self) -> LatestEpochCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, LatestEpochCheckpointCallbackConfig)
-            ),
-        )
-
-    @property
-    def on_exception_checkpoint(self) -> OnExceptionCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, OnExceptionCheckpointCallbackConfig)
-            ),
-        )
-
     @override
     def create_callbacks(self, root_config: "BaseConfig"):
         if not self.should_save_checkpoints(root_config):
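Callers that relied on the removed model_checkpoint, latest_epoch_checkpoint, and on_exception_checkpoint properties can recover the same lookup by filtering checkpoint_callbacks directly. A hypothetical replacement helper (with an explicit None default, which the removed properties omitted, so an empty sequence returns None instead of raising StopIteration):

from nshtrainer.model import (
    CheckpointSavingConfig,
    OnExceptionCheckpointCallbackConfig,
)


def on_exception_checkpoint(
    saving: CheckpointSavingConfig,
) -> OnExceptionCheckpointCallbackConfig | None:
    # Return the first matching callback config, or None if absent.
    return next(
        (
            callback
            for callback in saving.checkpoint_callbacks
            if isinstance(callback, OnExceptionCheckpointCallbackConfig)
        ),
        None,
    )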
{nshtrainer-0.11.8.dist-info → nshtrainer-0.11.9.dist-info}/RECORD
CHANGED
@@ -6,16 +6,14 @@ nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ…
 nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
 nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
 nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
-nshtrainer/callbacks/__init__.py,sha256=…
+nshtrainer/callbacks/__init__.py,sha256=k-DbpIlH2t5-oR3gHGHr8KiyCd_Twers4PcIUM1noqQ,2262
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
-nshtrainer/callbacks/checkpoint/__init__.py,sha256=…
+nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
 nshtrainer/callbacks/checkpoint/_base.py,sha256=9HQSa-toOyjtuDldQ71gaDVBRdryAaB_nRv5Y554tIk,5938
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=O6d4SqIYsxpnRj_IYX8A9VLgOBwxTdz-j2FV_nn3BT8,2067
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CM8f37dwaYHkjQFfJNTZTzSoF45zEjFRm-Fg1CzYmP4,1037
-nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=II-WZHAk7leK1Vgjza0PVifrF3QetR9Nn3n1qhqtuVo,4819
-nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=JS1z2YuEiQxk61HgZU1jySzF_pzdfXYO54_qHo-q3CQ,6776
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
 nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
@@ -54,9 +52,9 @@ nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9b…
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
 nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
-nshtrainer/model/__init__.py,sha256=…
+nshtrainer/model/__init__.py,sha256=RlGW5a46DZcqK6cYICYxDaKpZIEj-8zLxoMrl432tno,1429
 nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
-nshtrainer/model/config.py,sha256=…
+nshtrainer/model/config.py,sha256=F-doUiqPVw4yepT6FRqhflEiiMEWr_PuFC6lKzFWktA,53809
 nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
 nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -84,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.11.…
-nshtrainer-0.11.…
-nshtrainer-0.11.…
+nshtrainer-0.11.9.dist-info/METADATA,sha256=rFmi4wYXJz8srZhqFl_5ROxfmgFru7jhYUpUp7ZZjMg,860
+nshtrainer-0.11.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.11.9.dist-info/RECORD,,
nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py
DELETED
@@ -1,131 +0,0 @@
-import logging
-from pathlib import Path
-from typing import Literal
-
-from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.callbacks import Checkpoint
-from typing_extensions import override
-
-from ..._checkpoint.metadata import _sort_ckpts_by_metadata
-from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
-from ..base import CallbackConfigBase
-
-log = logging.getLogger(__name__)
-
-
-class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
-    name: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str = "epoch{epoch:03d}_step{step:07d}"
-    """Checkpoint filename. This must not include the extension."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    latest_symlink_filename: str | None = "latest"
-    """Filename for the latest symlink. If None, no symlink will be created."""
-
-    latest_k: int | Literal["all"] = 1
-    """Number of latest checkpoints to keep. If "all", all checkpoints are kept."""
-
-    @override
-    def create_callbacks(self, root_config):
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-        dirpath = Path(dirpath)
-
-        yield LatestEpochCheckpoint(self, dirpath)
-
-
-class LatestEpochCheckpoint(Checkpoint):
-    PREFIX = "latest_"
-    EXTENSION = ".ckpt"
-
-    def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
-        super().__init__()
-
-        self.config = config
-        self.dirpath = dirpath
-
-        self._last_global_step_saved = 0
-
-    @override
-    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
-        self._save_new_checkpoint(trainer)
-
-    def _latest_symlink_filename(self):
-        if (filename := self.config.latest_symlink_filename) is None:
-            return None
-        return f"{filename}{self.EXTENSION}"
-
-    def _ckpt_path(self, trainer: Trainer):
-        filename = self.config.filename.format(
-            epoch=trainer.current_epoch, step=trainer.global_step
-        )
-        filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
-        return self.dirpath / filename
-
-    def _remove_old_checkpoints(self, trainer: Trainer):
-        if (latest_k := self.config.latest_k) == "all":
-            return
-
-        # Get all configs, ignoring the latest symlink
-        ckpts = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
-        # Ignore the latest symlink
-        if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
-            ckpts = [p for p in ckpts if p.name != latest_symlink_filename]
-
-        # Sort by epoch, then step, then last modified
-        ckpts = _sort_ckpts_by_metadata(
-            ckpts,
-            key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
-            reverse=True,
-        )
-
-        # Remove all but the latest k checkpoints
-        # NOTE: We add 1 to the latest_k here because
-        # we're about to save a new checkpoint.
-        for _, ckpt_path in ckpts[latest_k:]:
-            _remove_checkpoint(trainer, ckpt_path, metadata=True)
-
-    def _save_new_checkpoint(self, trainer: Trainer):
-        if self._should_skip_saving_checkpoint(trainer):
-            return
-
-        # Save the new checkpoint
-        filepath = self._ckpt_path(trainer)
-        trainer.save_checkpoint(filepath, self.config.save_weights_only)
-
-        if trainer.is_global_zero:
-            # Remove old checkpoints
-            self._remove_old_checkpoints(trainer)
-
-            # Create the latest symlink
-            if (symlink_filename := self._latest_symlink_filename()) is not None:
-                symlink_path = self.dirpath / symlink_filename
-                _link_checkpoint(filepath, symlink_path, metadata=True)
-                log.debug(f"Created latest symlink: {symlink_path}")
-
-        # Set the last global step saved
-        self._last_global_step_saved = trainer.global_step
-
-        # Barrier to ensure all processes have saved the checkpoint before continuing
-        trainer.strategy.barrier()
-
-    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
-        from lightning.pytorch.trainer.states import TrainerFn
-
-        return (
-            bool(
-                getattr(trainer, "fast_dev_run", False)
-            )  # disable checkpointing with fast_dev_run
-            or trainer.state.fn
-            != TrainerFn.FITTING  # don't save anything during non-fit
-            or trainer.sanity_checking  # don't save anything during sanity check
-            or self._last_global_step_saved
-            == trainer.global_step  # already saved at the last step
-        )
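The deleted callback saved a latest_epochNNN_stepNNNNNNN.ckpt file at each epoch boundary, pruned all but the newest latest_k files, and maintained a latest.ckpt symlink. Within nshtrainer 0.11.9 the remaining LastCheckpointCallbackConfig appears to fill a similar role; outside of it, a rough approximation with Lightning's stock ModelCheckpoint (a sketch, assuming lightning >= 2.2 for save_last="link"; it does not reproduce nshtrainer's sidecar metadata files or the exact symlink name):

from lightning.pytorch.callbacks import ModelCheckpoint

# With no monitored metric, save_top_k=1 keeps only the most recent
# checkpoint; save_last="link" maintains a stable last.ckpt symlink.
latest_epoch_ckpt = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="latest_epoch{epoch:03d}_step{step:07d}",
    auto_insert_metric_name=False,
    save_top_k=1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
    save_last="link",
)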
nshtrainer/callbacks/checkpoint/model_checkpoint.py
DELETED
@@ -1,207 +0,0 @@
-import logging
-import re
-from datetime import timedelta
-from pathlib import Path
-from typing import TYPE_CHECKING, Literal
-
-from lightning.pytorch import Trainer
-from lightning.pytorch.callbacks.model_checkpoint import (
-    ModelCheckpoint as _ModelCheckpoint,
-)
-from typing_extensions import override
-
-from ..._checkpoint.saver import _link_checkpoint
-from ..._checkpoint.saver import _remove_checkpoint as _ckpt_saver_remove_checkpoint
-from ...metrics import MetricConfig
-from ..base import CallbackConfigBase
-
-if TYPE_CHECKING:
-    from ...model.config import BaseConfig
-
-log = logging.getLogger(__name__)
-
-
-def _convert_string(input_string: str):
-    # Find all variables enclosed in curly braces
-    variables = re.findall(r"\{(.*?)\}", input_string)
-
-    # Replace each variable with its corresponding key-value pair
-    output_string = input_string
-    for variable in variables:
-        # If the name is something like {variable:format}, we shouldn't process the format.
-        key_name = variable
-        if ":" in variable:
-            key_name, _ = variable.split(":", 1)
-            continue
-
-        # Replace '/' with '_' in the key name
-        key_name = key_name.replace("/", "_")
-        output_string = output_string.replace(
-            f"{{{variable}}}", f"{key_name}={{{variable}}}"
-        )
-
-    return output_string
-
-
-class ModelCheckpointCallbackConfig(CallbackConfigBase):
-    """Arguments for the ModelCheckpoint callback."""
-
-    name: Literal["model_checkpoint"] = "model_checkpoint"
-
-    dirpath: str | Path | None = None
-    """
-    Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
-    """
-
-    filename: str | None = None
-    """
-    Checkpoint filename.
-    If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
-    """
-
-    metric: MetricConfig | None = None
-    """
-    Metric to monitor for saving checkpoints.
-    If None, the primary metric of the runner will be used, if available.
-    """
-
-    verbose: bool = False
-    """Verbosity mode. If True, print additional information about checkpoints."""
-
-    save_last: Literal[True, False, "link"] | None = "link"
-    """
-    Whether to save the last checkpoint.
-    If True, saves a copy of the last checkpoint separately.
-    If "link", creates a symbolic link to the last checkpoint.
-    """
-
-    save_top_k: int | Literal["all"] = 1
-    """
-    Number of best models to save.
-    If "all" or -1, all models are saved.
-    If 0, no models are saved.
-    """
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    auto_insert_metric_name: bool = True
-    """Whether to automatically insert the metric name in the checkpoint filename."""
-
-    every_n_train_steps: int | None = None
-    """
-    Number of training steps between checkpoints.
-    If None or 0, no checkpoints are saved during training.
-    """
-
-    train_time_interval: timedelta | None = None
-    """
-    Time interval between checkpoints during training.
-    If None, no checkpoints are saved during training based on time.
-    """
-
-    every_n_epochs: int | None = None
-    """
-    Number of epochs between checkpoints.
-    If None or 0, no checkpoints are saved at the end of epochs.
-    """
-
-    save_on_train_epoch_end: bool | None = None
-    """
-    Whether to run checkpointing at the end of the training epoch.
-    If False, checkpointing runs at the end of the validation.
-    """
-
-    enable_version_counter: bool = True
-    """Whether to append a version to the existing file name."""
-
-    auto_append_metric: bool = True
-    """If enabled, this will automatically add "-{monitor}" to the filename."""
-
-    def metric_or_default(self, root_config: "BaseConfig"):
-        if self.metric is not None:
-            return self.metric
-        if root_config.primary_metric is not None:
-            return root_config.primary_metric
-        raise ValueError("Primary metric must be provided if metric is not specified.")
-
-    def resolve_filename(self, root_config: "BaseConfig"):
-        metric = self.metric_or_default(root_config)
-
-        filename = self.filename
-        if not filename:
-            filename = "{epoch}-{step}"
-        if self.auto_append_metric:
-            filename = f"{filename}-{{{metric.validation_monitor}}}"
-
-        if self.auto_insert_metric_name and filename:
-            new_filename = _convert_string(filename)
-            log.critical(
-                f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
-            )
-            filename = new_filename
-
-        return filename
-
-    @override
-    def create_callbacks(self, root_config):
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        metric = self.metric_or_default(root_config)
-        filename = self.resolve_filename(root_config)
-
-        yield ModelCheckpoint(
-            self,
-            dirpath=Path(dirpath),
-            filename=filename,
-            metric=metric,
-        )
-
-    def _save_top_k_model_ckpt_input(self):
-        if self.save_top_k == "all":
-            return -1
-        return self.save_top_k
-
-
-class ModelCheckpoint(_ModelCheckpoint):
-    CHECKPOINT_NAME_LAST = "best"
-
-    @override
-    def __init__(
-        self,
-        config: ModelCheckpointCallbackConfig,
-        dirpath: Path,
-        filename: str,
-        metric: MetricConfig,
-    ):
-        self.config = config
-        del config
-
-        super().__init__(
-            dirpath=dirpath,
-            filename=filename,
-            monitor=metric.validation_monitor,
-            mode=metric.mode,
-            verbose=self.config.verbose,
-            save_last=self.config.save_last,
-            save_top_k=self.config._save_top_k_model_ckpt_input(),
-            save_weights_only=self.config.save_weights_only,
-            auto_insert_metric_name=False,
-            every_n_train_steps=self.config.every_n_train_steps,
-            train_time_interval=self.config.train_time_interval,
-            every_n_epochs=self.config.every_n_epochs,
-            save_on_train_epoch_end=self.config.save_on_train_epoch_end,
-            enable_version_counter=self.config.enable_version_counter,
-        )
-
-    @override
-    def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str):  # pyright: ignore[reportIncompatibleMethodOverride]
-        if trainer.is_global_zero:
-            _link_checkpoint(filepath, linkpath, metadata=True)
-        trainer.strategy.barrier()
-
-    @override
-    def _remove_checkpoint(self, trainer: Trainer, filepath: str):
-        _ckpt_saver_remove_checkpoint(trainer, filepath, metadata=True)
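For reference, the removed _convert_string helper is what implemented auto_insert_metric_name: it rewrote each {name} placeholder to name={name} (mapping "/" to "_"), while skipping any placeholder carrying a format spec. A worked example of its behavior:

# _convert_string("{epoch}-{step}-{val/loss}")
#   -> "epoch={epoch}-step={step}-val_loss={val/loss}"
# _convert_string("{epoch:03d}")
#   -> "{epoch:03d}"  # the ":" triggers the `continue`, so it passes through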
{nshtrainer-0.11.8.dist-info → nshtrainer-0.11.9.dist-info}/WHEEL
File without changes