nshtrainer 0.11.7__py3-none-any.whl → 0.11.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/loader.py +4 -4
- nshtrainer/_checkpoint/metadata.py +37 -35
- nshtrainer/callbacks/__init__.py +3 -8
- nshtrainer/callbacks/checkpoint/__init__.py +3 -7
- nshtrainer/callbacks/checkpoint/_base.py +175 -0
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +29 -151
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +39 -0
- nshtrainer/model/__init__.py +2 -4
- nshtrainer/model/config.py +4 -37
- {nshtrainer-0.11.7.dist-info → nshtrainer-0.11.9.dist-info}/METADATA +1 -1
- {nshtrainer-0.11.7.dist-info → nshtrainer-0.11.9.dist-info}/RECORD +12 -12
- nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +0 -131
- nshtrainer/callbacks/checkpoint/model_checkpoint.py +0 -207
- {nshtrainer-0.11.7.dist-info → nshtrainer-0.11.9.dist-info}/WHEEL +0 -0
nshtrainer/_checkpoint/loader.py
CHANGED
@@ -236,7 +236,7 @@ def _load_ckpt_meta(
     error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
     match on_error:
         case "warn":
-            log.
+            log.warning(error_msg)
         case "raise":
            raise ValueError(error_msg)
         case _:

@@ -325,13 +325,13 @@ def _resolve_checkpoint(
             ),
         ]
         if not candidates:
-            log.
+            log.warning(
                 "No checkpoint candidates found for `best` checkpoint strategy."
             )
             continue

         if (metric := strategy.metric or root_config.primary_metric) is None:
-            log.
+            log.warning(
                 "No metric specified for `best` checkpoint strategy, "
                 "and no primary metric is set in the configuration. "
                 "Skipping strategy."

@@ -360,7 +360,7 @@ def _resolve_checkpoint(
             ),
         ]
         if not candidates:
-            log.
+            log.warning(
                 "No checkpoint candidates found for `last` checkpoint strategy."
             )
             continue
nshtrainer/_checkpoint/metadata.py
CHANGED

@@ -4,7 +4,7 @@ import logging
 import shutil
 from collections.abc import Callable
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, ClassVar, cast

 import nshconfig as C
 import numpy as np

@@ -20,10 +20,11 @@ log = logging.getLogger(__name__)


 METADATA_PATH_SUFFIX = ".metadata.json"
-HPARAMS_PATH_SUFFIX = ".hparams.json"


 class CheckpointMetadata(C.Config):
+    PATH_SUFFIX: ClassVar[str] = METADATA_PATH_SUFFIX
+
     checkpoint_path: Path
     checkpoint_filename: str


@@ -39,6 +40,8 @@ class CheckpointMetadata(C.Config):
     metrics: dict[str, Any]
     environment: EnvironmentConfig

+    hparams: dict[str, Any] | None
+
     @classmethod
     def from_file(cls, path: Path):
         return cls.model_validate_json(path.read_text())

@@ -55,7 +58,10 @@ class CheckpointMetadata(C.Config):


 def _generate_checkpoint_metadata(
-    config: "BaseConfig",
+    config: "BaseConfig",
+    trainer: "Trainer",
+    checkpoint_path: Path,
+    metadata_path: Path,
 ):
     checkpoint_timestamp = datetime.datetime.now()
     start_timestamp = trainer.start_time()

@@ -70,7 +76,11 @@ def _generate_checkpoint_metadata(
         metrics[name] = metric

     return CheckpointMetadata(
-        checkpoint_path=checkpoint_path,
+        # checkpoint_path=checkpoint_path,
+        # We should store the path as a relative path
+        # to the metadata file to avoid issues with
+        # moving the checkpoint directory
+        checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
         run_id=config.id,
         name=config.run_name,

@@ -84,6 +94,7 @@ def _generate_checkpoint_metadata(
         training_time=training_time,
         metrics=metrics,
         environment=config.environment,
+        hparams=config.model_dump(mode="json"),
     )


@@ -93,36 +104,28 @@ def _write_checkpoint_metadata(
     checkpoint_path: Path,
 ):
     config = cast("BaseConfig", model.config)
-
+    metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    metadata = _generate_checkpoint_metadata(
+        config, trainer, checkpoint_path, metadata_path
+    )

     # Write the metadata to the checkpoint directory
     try:
-        metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
         metadata_path.write_text(metadata.model_dump_json(indent=4))
     except Exception as e:
         log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
     else:
         log.debug(f"Checkpoint metadata written to {checkpoint_path}")

-
+
+def _remove_checkpoint_metadata(checkpoint_path: Path):
+    path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
     try:
-
-        hparams_path.write_text(config.model_dump_json(indent=4))
+        path.unlink(missing_ok=True)
     except Exception as e:
-        log.warning(f"Failed to
+        log.warning(f"Failed to remove {path}: {e}")
     else:
-        log.debug(f"
-
-
-def _remove_checkpoint_metadata(checkpoint_path: Path):
-    for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
-        path = checkpoint_path.with_suffix(suffix)
-        try:
-            path.unlink(missing_ok=True)
-        except Exception as e:
-            log.warning(f"Failed to remove {path}: {e}")
-        else:
-            log.debug(f"Removed {path}")
+        log.debug(f"Removed {path}")


@@ -130,20 +133,19 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
     _remove_checkpoint_metadata(linked_checkpoint_path)

     # Link the metadata files to the new checkpoint
-
-
-
+    path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    linked_path = linked_checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    try:
         try:
-
-
-
-
-
-
-
-
-
-            log.debug(f"Linked {path} to {linked_path}")
+            linked_path.symlink_to(path)
+        except OSError:
+            # on Windows, special permissions are required to create symbolic links as a regular user
+            # fall back to copying the file
+            shutil.copy(path, linked_path)
+    except Exception as e:
+        log.warning(f"Failed to link {path} to {linked_path}: {e}")
+    else:
+        log.debug(f"Linked {path} to {linked_path}")


 def _sort_ckpts_by_metadata(
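Note: the metadata now records `checkpoint_path` relative to the metadata file rather than as an absolute path, so a reader has to re-anchor it. A minimal sketch of that consumer side, assuming only the `CheckpointMetadata.from_file` and `PATH_SUFFIX` API shown above (the helper function below is hypothetical, not part of the package):

    from pathlib import Path

    from nshtrainer._checkpoint.metadata import CheckpointMetadata

    def load_metadata_and_checkpoint(metadata_path: Path):
        # Parse the .metadata.json file written next to the checkpoint.
        meta = CheckpointMetadata.from_file(metadata_path)
        # checkpoint_path is stored relative to the metadata file's directory,
        # so resolve it against that directory before using it.
        ckpt_path = (metadata_path.parent / meta.checkpoint_path).resolve()
        return meta, ckpt_path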
nshtrainer/callbacks/__init__.py
CHANGED
@@ -6,12 +6,8 @@ from . import checkpoint as checkpoint
 from .base import CallbackConfigBase as CallbackConfigBase
 from .checkpoint import BestCheckpoint as BestCheckpoint
 from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
-from .checkpoint import
-from .checkpoint import
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
-from .checkpoint import ModelCheckpoint as ModelCheckpoint
-from .checkpoint import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
+from .checkpoint import LastCheckpoint as LastCheckpoint
+from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
 from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
 from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,

@@ -46,8 +42,7 @@ CallbackConfig = Annotated[
     | GradientSkippingConfig
     | EMAConfig
     | BestCheckpointCallbackConfig
-
-    | LatestEpochCheckpointCallbackConfig
+    | LastCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig
     | WandbWatchConfig,
     C.Field(discriminator="name"),
nshtrainer/callbacks/checkpoint/__init__.py
CHANGED

@@ -2,13 +2,9 @@ from .best_checkpoint import BestCheckpoint as BestCheckpoint
 from .best_checkpoint import (
     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
 )
-from .
-from .
-
-)
-from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
-from .model_checkpoint import (
-    ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
+from .last_checkpoint import LastCheckpoint as LastCheckpoint
+from .last_checkpoint import (
+    LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
 from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
 from .on_exception_checkpoint import (
nshtrainer/callbacks/checkpoint/_base.py
ADDED

@@ -0,0 +1,175 @@
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generic, Literal
+
+import numpy as np
+import torch
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import TypeVar, override
+
+from ..._checkpoint.metadata import CheckpointMetadata
+from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from ..base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ...model.config import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str | None = None
+    """Checkpoint filename. This must not include the extension.
+    If None, the default filename will be used."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    save_symlink: bool = True
+    """Whether to create a symlink to the saved checkpoint."""
+
+    topk: int | Literal["all"] = 1
+    """The number of checkpoints to keep."""
+
+    @abstractmethod
+    def create_checkpoint(
+        self,
+        root_config: "BaseConfig",
+        dirpath: Path,
+    ) -> "CheckpointBase": ...
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = Path(
+            self.dirpath
+            or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+        )
+
+        yield self.create_checkpoint(root_config, dirpath)
+
+
+TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
+
+
+class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
+    def __init__(self, config: TConfig, dirpath: Path):
+        super().__init__()
+
+        self.config = config
+        self.dirpath = dirpath / self.name()
+        self.symlink_dirpath = dirpath
+
+        self._last_global_step_saved = 0
+
+    @abstractmethod
+    def default_filename(self) -> str: ...
+
+    @abstractmethod
+    def name(self) -> str: ...
+
+    def extension(self) -> str:
+        return ".ckpt"
+
+    @abstractmethod
+    def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
+
+    def symlink_path(self):
+        if not self.config.save_symlink:
+            return None
+
+        return self.symlink_dirpath / f"{self.name()}{self.extension()}"
+
+    def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
+        if (filename := self.config.filename) is None:
+            filename = self.default_filename()
+        filename = filename.format(**current_metrics)
+        return self.dirpath / f"{filename}{self.extension()}"
+
+    def remove_old_checkpoints(self, trainer: Trainer):
+        if (topk := self.config.topk) == "all":
+            return
+
+        # Get all the checkpoint metadata
+        metas = [
+            CheckpointMetadata.from_file(p)
+            for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+        ]
+
+        # Sort by the topk sort key
+        metas = sorted(metas, key=self.topk_sort_key)
+
+        # Now, the metas are sorted from the best to the worst,
+        # so we can remove the worst checkpoints
+        for meta in metas[topk:]:
+            if not (old_ckpt_path := self.dirpath / meta.checkpoint_filename).exists():
+                log.warning(
+                    f"Checkpoint file not found: {old_ckpt_path}\n"
+                    "Skipping removal of the checkpoint metadata."
+                )
+                continue
+
+            _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
+            log.debug(f"Removed old checkpoint: {old_ckpt_path}")
+
+    def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
+        current_metrics: dict[str, Any] = {
+            "epoch": trainer.current_epoch,
+            "step": trainer.global_step,
+        }
+
+        for name, value in trainer.callback_metrics.items():
+            match value:
+                case torch.Tensor() if value.numel() == 1:
+                    value = value.detach().cpu().item()
+                case np.ndarray() if value.size == 1:
+                    value = value.item()
+                case _:
+                    pass
+
+            current_metrics[name] = value
+
+        return current_metrics
+
+    def save_checkpoints(self, trainer: Trainer):
+        if self._should_skip_saving_checkpoint(trainer):
+            return
+
+        # Save the new checkpoint
+        filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)
+
+        if trainer.is_global_zero:
+            # Remove old checkpoints
+            self.remove_old_checkpoints(trainer)
+
+            # Create the latest symlink
+            if (symlink_filename := self.symlink_path()) is not None:
+                symlink_path = self.dirpath / symlink_filename
+                _link_checkpoint(filepath, symlink_path, metadata=True)
+                log.debug(f"Created latest symlink: {symlink_path}")
+
+        # Barrier to ensure all processes have saved the checkpoint,
+        # deleted the old checkpoints, and created the symlink before continuing
+        trainer.strategy.barrier()
+
+        # Set the last global step saved
+        self._last_global_step_saved = trainer.global_step
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or self._last_global_step_saved
+            == trainer.global_step  # already saved at the last step
+        )
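The new `_base.py` turns checkpoint saving into a small template: a config subclass supplies `create_checkpoint`, and a `CheckpointBase` subclass supplies `name`, `default_filename`, and `topk_sort_key`, then calls `save_checkpoints` from whichever Lightning hook it cares about. A minimal hypothetical subclass (not part of the package, shown only to illustrate the extension point under the API above):

    from typing import Literal

    from lightning.pytorch import LightningModule, Trainer
    from typing_extensions import final, override

    from nshtrainer._checkpoint.metadata import CheckpointMetadata
    from nshtrainer.callbacks.checkpoint._base import (
        BaseCheckpointCallbackConfig,
        CheckpointBase,
    )

    @final
    class StepCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
        name: Literal["step_checkpoint"] = "step_checkpoint"

        @override
        def create_checkpoint(self, root_config, dirpath):
            return StepCheckpoint(self, dirpath)

    @final
    class StepCheckpoint(CheckpointBase[StepCheckpointCallbackConfig]):
        @override
        def name(self):
            return "step"

        @override
        def default_filename(self):
            return "step{step:07d}"

        @override
        def topk_sort_key(self, metadata: CheckpointMetadata):
            # Used by remove_old_checkpoints to decide which checkpoints
            # to prune once more than `topk` exist.
            return metadata.global_step

        @override
        def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, *args, **kwargs):
            # Any hook works; _should_skip_saving_checkpoint deduplicates per global step.
            self.save_checkpoints(trainer)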
nshtrainer/callbacks/checkpoint/best_checkpoint.py
CHANGED

@@ -1,47 +1,27 @@
 import logging
 from pathlib import Path
-from typing import
+from typing import Literal

 from lightning.pytorch import LightningModule, Trainer
-from
-
+from typing_extensions import final, override
+
+from nshtrainer._checkpoint.metadata import CheckpointMetadata

-from ..._checkpoint.metadata import _sort_ckpts_by_metadata
-from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
 from ...metrics._config import MetricConfig
-from
+from ._base import BaseCheckpointCallbackConfig, CheckpointBase

 log = logging.getLogger(__name__)


-
+@final
+class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
     name: Literal["best_checkpoint"] = "best_checkpoint"

-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str = "epoch{epoch:02d}_step{step:04d}"
-    """Checkpoint filename. This must not include the extension."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
     metric: MetricConfig | None = None
     """Metric to monitor, or `None` to use the default metric."""

-    best_symlink_filename: str | None = "best"
-    """Filename for the best symlink. If None, no symlink will be created."""
-
-    save_top_k: int | Literal["all"] = 1
-    """The number of best checkpoints to keep."""
-
     @override
-    def
-        dirpath = Path(
-            self.dirpath
-            or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
-        )
-
+    def create_checkpoint(self, root_config, dirpath):
         # Resolve metric
         if (metric := self.metric) is None and (
             metric := root_config.primary_metric

@@ -50,143 +30,41 @@ class BestCheckpointCallbackConfig(CallbackConfigBase):
             "No metric provided and no primary metric found in the root config"
         )

-
-
-    @property
-    def _save_top_k_value(self):
-        return float("inf" if self.save_top_k == "all" else self.save_top_k)
+        return BestCheckpoint(self, dirpath, metric)


-
-
-
+@final
+class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
+    @property
+    def _metric_name_normalized(self):
+        return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")

+    @override
     def __init__(
         self,
         config: BestCheckpointCallbackConfig,
-        metric: MetricConfig,
         dirpath: Path,
+        metric: MetricConfig,
     ):
-        super().__init__()
-        self.config = config
+        super().__init__(config, dirpath)
         self.metric = metric
-        self.dirpath = dirpath
-
-        self._last_global_step_saved = 0  # no need to save when no steps were taken

     @override
-    def
-        self.
-
-    def _best_symlink_filename(self):
-        if (filename := self.config.best_symlink_filename) is None:
-            return None
-        return f"{filename}{self.EXTENSION}"
+    def name(self):
+        return f"best_{self._metric_name_normalized}"

-
-
-
-        )
-        filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
-        return self.dirpath / filename
+    @override
+    def default_filename(self):
+        return f"epoch{{epoch:03d}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"

-
-
+    @override
+    def topk_sort_key(self, metadata: CheckpointMetadata):
+        return metadata.metrics.get(
             self.metric.validation_monitor,
             float("-inf" if self.metric.mode == "max" else "inf"),
         )

-
-
-
-
-        Sort order: best -> worst
-        """
-        ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
-        return _sort_ckpts_by_metadata(
-            ckpt_paths,
-            key=lambda meta, _: self._get_metric_value(meta.metrics),
-            reverse=(self.metric.mode == "max"),
-        )
-
-    def _create_symlink(self, trainer: Trainer, best_ckpt_path: Path):
-        # Resolve the symlink filename
-        if (symlink_filename := self._best_symlink_filename()) is None:
-            return
-
-        # If the symlink already exists and points to the best checkpoint,
-        # then we don't need to create a new symlink.
-        symlink_path = self.dirpath / symlink_filename
-        if symlink_path.exists() and symlink_path.resolve() == best_ckpt_path:
-            return
-
-        _link_checkpoint(best_ckpt_path, symlink_path, metadata=True)
-        log.debug(f"Created best symlink: {symlink_path}")
-
-    def _save_best_checkpoint(self, trainer: Trainer):
-        # Skip saving the checkpoint if we're not in the fitting state
-        if self._should_skip_saving_checkpoint(trainer):
-            return
-
-        # Get the current metric value
-        if (current := self._get_metric_value(trainer.callback_metrics)) is None:
-            log.warning(
-                f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
-            )
-            return
-
-        # Get sorted checkpoints
-        sorted_ckpts = self._sorted_ckpts()
-
-        # If the current model is worse than the worst checkpoint,
-        # and we have already saved the maximum number of checkpoints,
-        # then don't save the current model.
-        if len(
-            sorted_ckpts
-        ) >= self.config._save_top_k_value and not self.metric.is_better(
-            current,
-            self._get_metric_value(sorted_ckpts[-1][0].metrics),
-        ):
-            return
-
-        # Save the current model
-        filepath = self._ckpt_path(trainer)
-        trainer.save_checkpoint(filepath, self.config.save_weights_only)
-        log.debug(f"Saved best checkpoint: {filepath}")
-
-        if trainer.is_global_zero:
-            # Get the sorted checkpoints again because now we have added a new checkpoint.
-            # We could optimize this by adding the new checkpoint to the sorted list,
-            # and then sorting it in place, but this is simpler.
-            sorted_ckpts = self._sorted_ckpts()
-
-            # Remove worst checkpoint if we've reached save_top_k
-            if (topk := self.config.save_top_k) != "all" and len(sorted_ckpts) > topk:
-                # NOTE: Sort order is best -> worst. Let's get the worst checkpoints.
-                for _, ckpt_path in sorted_ckpts[topk:]:
-                    _remove_checkpoint(trainer, ckpt_path, metadata=True)
-
-            # Create symlink to best model
-            if sorted_ckpts:
-                _, best_ckpt_path = sorted_ckpts[0]
-                self._create_symlink(trainer, best_ckpt_path)
-
-        # Update the last global step saved
-        self._last_global_step_saved = trainer.global_step
-
-        # Barrier to ensure all processes have saved the checkpoint before continuing
-        trainer.strategy.barrier()
-
-    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
-        from lightning.pytorch.trainer.states import TrainerFn
-
-        return (
-            bool(
-                getattr(trainer, "fast_dev_run", False)
-            )  # disable checkpointing with fast_dev_run
-            or trainer.state.fn
-            != TrainerFn.FITTING  # don't save anything during non-fit
-            or trainer.sanity_checking  # don't save anything during sanity check
-            or self._last_global_step_saved
-            == trainer.global_step  # already saved at the last step
-        )
+    # Events
+    @override
+    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
+        self.save_checkpoints(trainer)
nshtrainer/callbacks/checkpoint/last_checkpoint.py
ADDED

@@ -0,0 +1,39 @@
+import logging
+from typing import Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from typing_extensions import final, override
+
+from nshtrainer._checkpoint.metadata import CheckpointMetadata
+
+from ._base import BaseCheckpointCallbackConfig, CheckpointBase
+
+log = logging.getLogger(__name__)
+
+
+@final
+class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+    name: Literal["last_checkpoint"] = "last_checkpoint"
+
+    @override
+    def create_checkpoint(self, root_config, dirpath):
+        return LastCheckpoint(self, dirpath)
+
+
+@final
+class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
+    @override
+    def name(self):
+        return "last"
+
+    @override
+    def default_filename(self):
+        return "epoch{epoch:03d}-step{step:07d}"
+
+    @override
+    def topk_sort_key(self, metadata: CheckpointMetadata):
+        return metadata.checkpoint_timestamp
+
+    @override
+    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+        self.save_checkpoints(trainer)
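Taken together with the config changes below, the latest-epoch/model-checkpoint pair is replaced by a best/last/on-exception trio. A configuration sketch, assuming the default field values shown elsewhere in this diff:

    from nshtrainer.model import (
        BestCheckpointCallbackConfig,
        CheckpointSavingConfig,
        LastCheckpointCallbackConfig,
        OnExceptionCheckpointCallbackConfig,
    )

    # Mirrors the new defaults of CheckpointSavingConfig.checkpoint_callbacks:
    # keep the single best checkpoint for the monitored metric, the most recent
    # checkpoint, and a checkpoint written when an exception is raised.
    saving = CheckpointSavingConfig(
        checkpoint_callbacks=[
            BestCheckpointCallbackConfig(topk=1),
            LastCheckpointCallbackConfig(topk=1),
            OnExceptionCheckpointCallbackConfig(),
        ]
    )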
nshtrainer/model/__init__.py
CHANGED
@@ -5,17 +5,15 @@ from .base import LightningModuleBase as LightningModuleBase
 from .config import BaseConfig as BaseConfig
 from .config import BaseLoggerConfig as BaseLoggerConfig
 from .config import BaseProfilerConfig as BaseProfilerConfig
+from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
 from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
 from .config import CheckpointSavingConfig as CheckpointSavingConfig
 from .config import DirectoryConfig as DirectoryConfig
 from .config import EarlyStoppingConfig as EarlyStoppingConfig
 from .config import GradientClippingConfig as GradientClippingConfig
-from .config import
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
+from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
 from .config import LoggingConfig as LoggingConfig
 from .config import MetricConfig as MetricConfig
-from .config import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
 from .config import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
nshtrainer/model/config.py
CHANGED
@@ -39,8 +39,7 @@ from .._checkpoint.loader import CheckpointLoadingConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
-
-    ModelCheckpointCallbackConfig,
+    LastCheckpointCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
     WandbWatchConfig,
 )

@@ -771,9 +770,8 @@ class ReproducibilityConfig(C.Config):


 CheckpointCallbackConfig: TypeAlias = Annotated[
-
-
-    | LatestEpochCheckpointCallbackConfig
+    BestCheckpointCallbackConfig
+    | LastCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig,
     C.Field(discriminator="name"),
 ]

@@ -784,9 +782,8 @@ class CheckpointSavingConfig(CallbackConfigBase):
     """Enable checkpoint saving."""

     checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
-        # ModelCheckpointCallbackConfig(),
         BestCheckpointCallbackConfig(),
-
+        LastCheckpointCallbackConfig(),
         OnExceptionCheckpointCallbackConfig(),
     ]
     """Checkpoint callback configurations."""

@@ -804,36 +801,6 @@ class CheckpointSavingConfig(CallbackConfigBase):

         return True

-    @property
-    def model_checkpoint(self) -> ModelCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, ModelCheckpointCallbackConfig)
-            ),
-        )
-
-    @property
-    def latest_epoch_checkpoint(self) -> LatestEpochCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, LatestEpochCheckpointCallbackConfig)
-            ),
-        )
-
-    @property
-    def on_exception_checkpoint(self) -> OnExceptionCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, OnExceptionCheckpointCallbackConfig)
-            ),
-        )
-
     @override
     def create_callbacks(self, root_config: "BaseConfig"):
         if not self.should_save_checkpoints(root_config):
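The `model_checkpoint`, `latest_epoch_checkpoint`, and `on_exception_checkpoint` convenience properties are removed from `CheckpointSavingConfig`. Code that relied on them can run the same isinstance scan over `checkpoint_callbacks` directly; a minimal sketch, assuming a `CheckpointSavingConfig` instance named `saving`:

    from nshtrainer.model import LastCheckpointCallbackConfig

    # Find the first callback config of a given type, or None if it is not configured.
    last_cfg = next(
        (
            callback
            for callback in saving.checkpoint_callbacks
            if isinstance(callback, LastCheckpointCallbackConfig)
        ),
        None,
    )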
{nshtrainer-0.11.7.dist-info → nshtrainer-0.11.9.dist-info}/RECORD
CHANGED

@@ -1,19 +1,19 @@
 nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
-nshtrainer/_checkpoint/loader.py,sha256=
-nshtrainer/_checkpoint/metadata.py,sha256=
+nshtrainer/_checkpoint/loader.py,sha256=DSaNR8194kWon4O1svslNsCcN_8vlyLbF0LNCPfUpzI,13789
+nshtrainer/_checkpoint/metadata.py,sha256=n2PMGdA3Jn3BuoyDXVCF9dUnNjuuru5CL9jqMD-X4Vk,4918
 nshtrainer/_checkpoint/saver.py,sha256=TuSAP39DOOVvSnSukQ9RitMV60JnDg6L27fMRc2uVJc,1358
 nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
 nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
 nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
 nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
-nshtrainer/callbacks/__init__.py,sha256=
+nshtrainer/callbacks/__init__.py,sha256=k-DbpIlH2t5-oR3gHGHr8KiyCd_Twers4PcIUM1noqQ,2262
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
-nshtrainer/callbacks/checkpoint/__init__.py,sha256=
-nshtrainer/callbacks/checkpoint/
-nshtrainer/callbacks/checkpoint/
-nshtrainer/callbacks/checkpoint/
+nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
+nshtrainer/callbacks/checkpoint/_base.py,sha256=9HQSa-toOyjtuDldQ71gaDVBRdryAaB_nRv5Y554tIk,5938
+nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=O6d4SqIYsxpnRj_IYX8A9VLgOBwxTdz-j2FV_nn3BT8,2067
+nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CM8f37dwaYHkjQFfJNTZTzSoF45zEjFRm-Fg1CzYmP4,1037
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
 nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131

@@ -52,9 +52,9 @@ nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9b
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
 nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
-nshtrainer/model/__init__.py,sha256=
+nshtrainer/model/__init__.py,sha256=RlGW5a46DZcqK6cYICYxDaKpZIEj-8zLxoMrl432tno,1429
 nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
-nshtrainer/model/config.py,sha256=
+nshtrainer/model/config.py,sha256=F-doUiqPVw4yepT6FRqhflEiiMEWr_PuFC6lKzFWktA,53809
 nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
 nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725

@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.11.
-nshtrainer-0.11.
-nshtrainer-0.11.
+nshtrainer-0.11.9.dist-info/METADATA,sha256=rFmi4wYXJz8srZhqFl_5ROxfmgFru7jhYUpUp7ZZjMg,860
+nshtrainer-0.11.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.11.9.dist-info/RECORD,,
nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py
DELETED

@@ -1,131 +0,0 @@
-import logging
-from pathlib import Path
-from typing import Literal
-
-from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.callbacks import Checkpoint
-from typing_extensions import override
-
-from ..._checkpoint.metadata import _sort_ckpts_by_metadata
-from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
-from ..base import CallbackConfigBase
-
-log = logging.getLogger(__name__)
-
-
-class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
-    name: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str = "epoch{epoch:02d}_step{step:04d}"
-    """Checkpoint filename. This must not include the extension."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    latest_symlink_filename: str | None = "latest"
-    """Filename for the latest symlink. If None, no symlink will be created."""
-
-    latest_k: int | Literal["all"] = 1
-    """Number of latest checkpoints to keep. If "all", all checkpoints are kept."""
-
-    @override
-    def create_callbacks(self, root_config):
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-        dirpath = Path(dirpath)
-
-        yield LatestEpochCheckpoint(self, dirpath)
-
-
-class LatestEpochCheckpoint(Checkpoint):
-    PREFIX = "latest_"
-    EXTENSION = ".ckpt"
-
-    def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
-        super().__init__()
-
-        self.config = config
-        self.dirpath = dirpath
-
-        self._last_global_step_saved = 0
-
-    @override
-    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
-        self._save_new_checkpoint(trainer)
-
-    def _latest_symlink_filename(self):
-        if (filename := self.config.latest_symlink_filename) is None:
-            return None
-        return f"{filename}{self.EXTENSION}"
-
-    def _ckpt_path(self, trainer: Trainer):
-        filename = self.config.filename.format(
-            epoch=trainer.current_epoch, step=trainer.global_step
-        )
-        filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
-        return self.dirpath / filename
-
-    def _remove_old_checkpoints(self, trainer: Trainer):
-        if (latest_k := self.config.latest_k) == "all":
-            return
-
-        # Get all configs, ignoring the latest symlink
-        ckpts = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
-        # Ignore the latest symlink
-        if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
-            ckpts = [p for p in ckpts if p.name != latest_symlink_filename]
-
-        # Sort by epoch, then step, then last modified
-        ckpts = _sort_ckpts_by_metadata(
-            ckpts,
-            key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
-            reverse=True,
-        )
-
-        # Remove all but the latest k checkpoints
-        # NOTE: We add 1 to the latest_k here because
-        # we're about to save a new checkpoint.
-        for _, ckpt_path in ckpts[latest_k:]:
-            _remove_checkpoint(trainer, ckpt_path, metadata=True)
-
-    def _save_new_checkpoint(self, trainer: Trainer):
-        if self._should_skip_saving_checkpoint(trainer):
-            return
-
-        # Save the new checkpoint
-        filepath = self._ckpt_path(trainer)
-        trainer.save_checkpoint(filepath, self.config.save_weights_only)
-
-        if trainer.is_global_zero:
-            # Remove old checkpoints
-            self._remove_old_checkpoints(trainer)
-
-            # Create the latest symlink
-            if (symlink_filename := self._latest_symlink_filename()) is not None:
-                symlink_path = self.dirpath / symlink_filename
-                _link_checkpoint(filepath, symlink_path, metadata=True)
-                log.debug(f"Created latest symlink: {symlink_path}")
-
-        # Set the last global step saved
-        self._last_global_step_saved = trainer.global_step
-
-        # Barrier to ensure all processes have saved the checkpoint before continuing
-        trainer.strategy.barrier()
-
-    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
-        from lightning.pytorch.trainer.states import TrainerFn
-
-        return (
-            bool(
-                getattr(trainer, "fast_dev_run", False)
-            )  # disable checkpointing with fast_dev_run
-            or trainer.state.fn
-            != TrainerFn.FITTING  # don't save anything during non-fit
-            or trainer.sanity_checking  # don't save anything during sanity check
-            or self._last_global_step_saved
-            == trainer.global_step  # already saved at the last step
-        )
nshtrainer/callbacks/checkpoint/model_checkpoint.py
DELETED

@@ -1,207 +0,0 @@
-import logging
-import re
-from datetime import timedelta
-from pathlib import Path
-from typing import TYPE_CHECKING, Literal
-
-from lightning.pytorch import Trainer
-from lightning.pytorch.callbacks.model_checkpoint import (
-    ModelCheckpoint as _ModelCheckpoint,
-)
-from typing_extensions import override
-
-from ..._checkpoint.saver import _link_checkpoint
-from ..._checkpoint.saver import _remove_checkpoint as _ckpt_saver_remove_checkpoint
-from ...metrics import MetricConfig
-from ..base import CallbackConfigBase
-
-if TYPE_CHECKING:
-    from ...model.config import BaseConfig
-
-log = logging.getLogger(__name__)
-
-
-def _convert_string(input_string: str):
-    # Find all variables enclosed in curly braces
-    variables = re.findall(r"\{(.*?)\}", input_string)
-
-    # Replace each variable with its corresponding key-value pair
-    output_string = input_string
-    for variable in variables:
-        # If the name is something like {variable:format}, we shouldn't process the format.
-        key_name = variable
-        if ":" in variable:
-            key_name, _ = variable.split(":", 1)
-            continue
-
-        # Replace '/' with '_' in the key name
-        key_name = key_name.replace("/", "_")
-        output_string = output_string.replace(
-            f"{{{variable}}}", f"{key_name}={{{variable}}}"
-        )
-
-    return output_string
-
-
-class ModelCheckpointCallbackConfig(CallbackConfigBase):
-    """Arguments for the ModelCheckpoint callback."""
-
-    name: Literal["model_checkpoint"] = "model_checkpoint"
-
-    dirpath: str | Path | None = None
-    """
-    Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
-    """
-
-    filename: str | None = None
-    """
-    Checkpoint filename.
-    If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
-    """
-
-    metric: MetricConfig | None = None
-    """
-    Metric to monitor for saving checkpoints.
-    If None, the primary metric of the runner will be used, if available.
-    """
-
-    verbose: bool = False
-    """Verbosity mode. If True, print additional information about checkpoints."""
-
-    save_last: Literal[True, False, "link"] | None = "link"
-    """
-    Whether to save the last checkpoint.
-    If True, saves a copy of the last checkpoint separately.
-    If "link", creates a symbolic link to the last checkpoint.
-    """
-
-    save_top_k: int | Literal["all"] = 1
-    """
-    Number of best models to save.
-    If "all" or -1, all models are saved.
-    If 0, no models are saved.
-    """
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    auto_insert_metric_name: bool = True
-    """Whether to automatically insert the metric name in the checkpoint filename."""
-
-    every_n_train_steps: int | None = None
-    """
-    Number of training steps between checkpoints.
-    If None or 0, no checkpoints are saved during training.
-    """
-
-    train_time_interval: timedelta | None = None
-    """
-    Time interval between checkpoints during training.
-    If None, no checkpoints are saved during training based on time.
-    """
-
-    every_n_epochs: int | None = None
-    """
-    Number of epochs between checkpoints.
-    If None or 0, no checkpoints are saved at the end of epochs.
-    """
-
-    save_on_train_epoch_end: bool | None = None
-    """
-    Whether to run checkpointing at the end of the training epoch.
-    If False, checkpointing runs at the end of the validation.
-    """
-
-    enable_version_counter: bool = True
-    """Whether to append a version to the existing file name."""
-
-    auto_append_metric: bool = True
-    """If enabled, this will automatically add "-{monitor}" to the filename."""
-
-    def metric_or_default(self, root_config: "BaseConfig"):
-        if self.metric is not None:
-            return self.metric
-        if root_config.primary_metric is not None:
-            return root_config.primary_metric
-        raise ValueError("Primary metric must be provided if metric is not specified.")
-
-    def resolve_filename(self, root_config: "BaseConfig"):
-        metric = self.metric_or_default(root_config)
-
-        filename = self.filename
-        if not filename:
-            filename = "{epoch}-{step}"
-            if self.auto_append_metric:
-                filename = f"{filename}-{{{metric.validation_monitor}}}"
-
-        if self.auto_insert_metric_name and filename:
-            new_filename = _convert_string(filename)
-            log.critical(
-                f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
-            )
-            filename = new_filename
-
-        return filename
-
-    @override
-    def create_callbacks(self, root_config):
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        metric = self.metric_or_default(root_config)
-        filename = self.resolve_filename(root_config)
-
-        yield ModelCheckpoint(
-            self,
-            dirpath=Path(dirpath),
-            filename=filename,
-            metric=metric,
-        )
-
-    def _save_top_k_model_ckpt_input(self):
-        if self.save_top_k == "all":
-            return -1
-        return self.save_top_k
-
-
-class ModelCheckpoint(_ModelCheckpoint):
-    CHECKPOINT_NAME_LAST = "best"
-
-    @override
-    def __init__(
-        self,
-        config: ModelCheckpointCallbackConfig,
-        dirpath: Path,
-        filename: str,
-        metric: MetricConfig,
-    ):
-        self.config = config
-        del config
-
-        super().__init__(
-            dirpath=dirpath,
-            filename=filename,
-            monitor=metric.validation_monitor,
-            mode=metric.mode,
-            verbose=self.config.verbose,
-            save_last=self.config.save_last,
-            save_top_k=self.config._save_top_k_model_ckpt_input(),
-            save_weights_only=self.config.save_weights_only,
-            auto_insert_metric_name=False,
-            every_n_train_steps=self.config.every_n_train_steps,
-            train_time_interval=self.config.train_time_interval,
-            every_n_epochs=self.config.every_n_epochs,
-            save_on_train_epoch_end=self.config.save_on_train_epoch_end,
-            enable_version_counter=self.config.enable_version_counter,
-        )
-
-    @override
-    def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str):  # pyright: ignore[reportIncompatibleMethodOverride]
-        if trainer.is_global_zero:
-            _link_checkpoint(filepath, linkpath, metadata=True)
-        trainer.strategy.barrier()
-
-    @override
-    def _remove_checkpoint(self, trainer: Trainer, filepath: str):
-        _ckpt_saver_remove_checkpoint(trainer, filepath, metadata=True)
{nshtrainer-0.11.7.dist-info → nshtrainer-0.11.9.dist-info}/WHEEL
File without changes