nshtrainer 0.11.7__py3-none-any.whl → 0.11.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File: nshtrainer/_checkpoint/loader.py

@@ -236,7 +236,7 @@ def _load_ckpt_meta(
  error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
  match on_error:
  case "warn":
- log.warn(error_msg)
+ log.warning(error_msg)
  case "raise":
  raise ValueError(error_msg)
  case _:
@@ -325,13 +325,13 @@ def _resolve_checkpoint(
  ),
  ]
  if not candidates:
- log.warn(
+ log.warning(
  "No checkpoint candidates found for `best` checkpoint strategy."
  )
  continue

  if (metric := strategy.metric or root_config.primary_metric) is None:
- log.warn(
+ log.warning(
  "No metric specified for `best` checkpoint strategy, "
  "and no primary metric is set in the configuration. "
  "Skipping strategy."
@@ -360,7 +360,7 @@ def _resolve_checkpoint(
  ),
  ]
  if not candidates:
- log.warn(
+ log.warning(
  "No checkpoint candidates found for `last` checkpoint strategy."
  )
  continue

File: nshtrainer/_checkpoint/metadata.py

@@ -4,7 +4,7 @@ import logging
  import shutil
  from collections.abc import Callable
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, cast
+ from typing import TYPE_CHECKING, Any, ClassVar, cast

  import nshconfig as C
  import numpy as np
@@ -20,10 +20,11 @@ log = logging.getLogger(__name__)


  METADATA_PATH_SUFFIX = ".metadata.json"
- HPARAMS_PATH_SUFFIX = ".hparams.json"


  class CheckpointMetadata(C.Config):
+ PATH_SUFFIX: ClassVar[str] = METADATA_PATH_SUFFIX
+
  checkpoint_path: Path
  checkpoint_filename: str

@@ -39,6 +40,8 @@ class CheckpointMetadata(C.Config):
  metrics: dict[str, Any]
  environment: EnvironmentConfig

+ hparams: dict[str, Any] | None
+
  @classmethod
  def from_file(cls, path: Path):
  return cls.model_validate_json(path.read_text())
@@ -55,7 +58,10 @@ class CheckpointMetadata(C.Config):


  def _generate_checkpoint_metadata(
- config: "BaseConfig", trainer: "Trainer", checkpoint_path: Path
+ config: "BaseConfig",
+ trainer: "Trainer",
+ checkpoint_path: Path,
+ metadata_path: Path,
  ):
  checkpoint_timestamp = datetime.datetime.now()
  start_timestamp = trainer.start_time()
@@ -70,7 +76,11 @@ def _generate_checkpoint_metadata(
  metrics[name] = metric

  return CheckpointMetadata(
- checkpoint_path=checkpoint_path,
+ # checkpoint_path=checkpoint_path,
+ # We should store the path as a relative path
+ # to the metadata file to avoid issues with
+ # moving the checkpoint directory
+ checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
  checkpoint_filename=checkpoint_path.name,
  run_id=config.id,
  name=config.run_name,
@@ -84,6 +94,7 @@ def _generate_checkpoint_metadata(
  training_time=training_time,
  metrics=metrics,
  environment=config.environment,
+ hparams=config.model_dump(mode="json"),
  )


@@ -93,36 +104,28 @@ def _write_checkpoint_metadata(
  checkpoint_path: Path,
  ):
  config = cast("BaseConfig", model.config)
- metadata = _generate_checkpoint_metadata(config, trainer, checkpoint_path)
+ metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+ metadata = _generate_checkpoint_metadata(
+ config, trainer, checkpoint_path, metadata_path
+ )

  # Write the metadata to the checkpoint directory
  try:
- metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
  metadata_path.write_text(metadata.model_dump_json(indent=4))
  except Exception as e:
  log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
  else:
  log.debug(f"Checkpoint metadata written to {checkpoint_path}")

- # Write the hparams to the checkpoint directory
+
+ def _remove_checkpoint_metadata(checkpoint_path: Path):
+ path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
  try:
- hparams_path = checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX)
- hparams_path.write_text(config.model_dump_json(indent=4))
+ path.unlink(missing_ok=True)
  except Exception as e:
- log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
+ log.warning(f"Failed to remove {path}: {e}")
  else:
- log.debug(f"Checkpoint metadata written to {checkpoint_path}")
-
-
- def _remove_checkpoint_metadata(checkpoint_path: Path):
- for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
- path = checkpoint_path.with_suffix(suffix)
- try:
- path.unlink(missing_ok=True)
- except Exception as e:
- log.warning(f"Failed to remove {path}: {e}")
- else:
- log.debug(f"Removed {path}")
+ log.debug(f"Removed {path}")


@@ -130,20 +133,19 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
  _remove_checkpoint_metadata(linked_checkpoint_path)

  # Link the metadata files to the new checkpoint
- for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
- path = checkpoint_path.with_suffix(suffix)
- linked_path = linked_checkpoint_path.with_suffix(suffix)
+ path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+ linked_path = linked_checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+ try:
  try:
- try:
- linked_path.symlink_to(path)
- except OSError:
- # on Windows, special permissions are required to create symbolic links as a regular user
- # fall back to copying the file
- shutil.copy(path, linked_path)
- except Exception as e:
- log.warning(f"Failed to link {path} to {linked_path}: {e}")
- else:
- log.debug(f"Linked {path} to {linked_path}")
+ linked_path.symlink_to(path)
+ except OSError:
+ # on Windows, special permissions are required to create symbolic links as a regular user
+ # fall back to copying the file
+ shutil.copy(path, linked_path)
+ except Exception as e:
+ log.warning(f"Failed to link {path} to {linked_path}: {e}")
+ else:
+ log.debug(f"Linked {path} to {linked_path}")


  def _sort_ckpts_by_metadata(
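
Note that checkpoint_path is now stored relative to the metadata file rather than as an absolute path. A minimal sketch of how a consumer could resolve it back to an absolute path (the helper name is hypothetical and not part of the package; CheckpointMetadata.from_file is the loader shown above):

from pathlib import Path

from nshtrainer._checkpoint.metadata import CheckpointMetadata


def resolve_checkpoint_path(metadata_path: Path) -> Path:
    # Hypothetical helper, not part of nshtrainer.
    meta = CheckpointMetadata.from_file(metadata_path)
    # checkpoint_path is stored relative to the metadata file's parent directory,
    # so this join keeps working if the checkpoint directory is moved as a whole.
    return (metadata_path.parent / meta.checkpoint_path).resolve()
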
File: nshtrainer/callbacks/__init__.py

@@ -6,12 +6,8 @@ from . import checkpoint as checkpoint
  from .base import CallbackConfigBase as CallbackConfigBase
  from .checkpoint import BestCheckpoint as BestCheckpoint
  from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
- from .checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
- from .checkpoint import (
- LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
- )
- from .checkpoint import ModelCheckpoint as ModelCheckpoint
- from .checkpoint import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
+ from .checkpoint import LastCheckpoint as LastCheckpoint
+ from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
  from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
  from .checkpoint import (
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
@@ -46,8 +42,7 @@ CallbackConfig = Annotated[
  | GradientSkippingConfig
  | EMAConfig
  | BestCheckpointCallbackConfig
- | ModelCheckpointCallbackConfig
- | LatestEpochCheckpointCallbackConfig
+ | LastCheckpointCallbackConfig
  | OnExceptionCheckpointCallbackConfig
  | WandbWatchConfig,
  C.Field(discriminator="name"),

File: nshtrainer/callbacks/checkpoint/__init__.py

@@ -2,13 +2,9 @@ from .best_checkpoint import BestCheckpoint as BestCheckpoint
  from .best_checkpoint import (
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
  )
- from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
- from .latest_epoch_checkpoint import (
- LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
- )
- from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
- from .model_checkpoint import (
- ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
+ from .last_checkpoint import LastCheckpoint as LastCheckpoint
+ from .last_checkpoint import (
+ LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
  )
  from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
  from .on_exception_checkpoint import (

File: nshtrainer/callbacks/checkpoint/_base.py (new file)

@@ -0,0 +1,175 @@
+ import logging
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Generic, Literal
+
+ import numpy as np
+ import torch
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.callbacks import Checkpoint
+ from typing_extensions import TypeVar, override
+
+ from ..._checkpoint.metadata import CheckpointMetadata
+ from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+ from ..base import CallbackConfigBase
+
+ if TYPE_CHECKING:
+ from ...model.config import BaseConfig
+
+ log = logging.getLogger(__name__)
+
+
+ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
+ dirpath: str | Path | None = None
+ """Directory path to save the checkpoint file."""
+
+ filename: str | None = None
+ """Checkpoint filename. This must not include the extension.
+ If None, the default filename will be used."""
+
+ save_weights_only: bool = False
+ """Whether to save only the model's weights or the entire model object."""
+
+ save_symlink: bool = True
+ """Whether to create a symlink to the saved checkpoint."""
+
+ topk: int | Literal["all"] = 1
+ """The number of checkpoints to keep."""
+
+ @abstractmethod
+ def create_checkpoint(
+ self,
+ root_config: "BaseConfig",
+ dirpath: Path,
+ ) -> "CheckpointBase": ...
+
+ @override
+ def create_callbacks(self, root_config):
+ dirpath = Path(
+ self.dirpath
+ or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+ )
+
+ yield self.create_checkpoint(root_config, dirpath)
+
+
+ TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
+
+
+ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
+ def __init__(self, config: TConfig, dirpath: Path):
+ super().__init__()
+
+ self.config = config
+ self.dirpath = dirpath / self.name()
+ self.symlink_dirpath = dirpath
+
+ self._last_global_step_saved = 0
+
+ @abstractmethod
+ def default_filename(self) -> str: ...
+
+ @abstractmethod
+ def name(self) -> str: ...
+
+ def extension(self) -> str:
+ return ".ckpt"
+
+ @abstractmethod
+ def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
+
+ def symlink_path(self):
+ if not self.config.save_symlink:
+ return None
+
+ return self.symlink_dirpath / f"{self.name()}{self.extension()}"
+
+ def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
+ if (filename := self.config.filename) is None:
+ filename = self.default_filename()
+ filename = filename.format(**current_metrics)
+ return self.dirpath / f"{filename}{self.extension()}"
+
+ def remove_old_checkpoints(self, trainer: Trainer):
+ if (topk := self.config.topk) == "all":
+ return
+
+ # Get all the checkpoint metadata
+ metas = [
+ CheckpointMetadata.from_file(p)
+ for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+ ]
+
+ # Sort by the topk sort key
+ metas = sorted(metas, key=self.topk_sort_key)
+
+ # Now, the metas are sorted from the best to the worst,
+ # so we can remove the worst checkpoints
+ for meta in metas[topk:]:
+ if not (old_ckpt_path := self.dirpath / meta.checkpoint_filename).exists():
+ log.warning(
+ f"Checkpoint file not found: {old_ckpt_path}\n"
+ "Skipping removal of the checkpoint metadata."
+ )
+ continue
+
+ _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
+ log.debug(f"Removed old checkpoint: {old_ckpt_path}")
+
+ def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
+ current_metrics: dict[str, Any] = {
+ "epoch": trainer.current_epoch,
+ "step": trainer.global_step,
+ }
+
+ for name, value in trainer.callback_metrics.items():
+ match value:
+ case torch.Tensor() if value.numel() == 1:
+ value = value.detach().cpu().item()
+ case np.ndarray() if value.size == 1:
+ value = value.item()
+ case _:
+ pass
+
+ current_metrics[name] = value
+
+ return current_metrics
+
+ def save_checkpoints(self, trainer: Trainer):
+ if self._should_skip_saving_checkpoint(trainer):
+ return
+
+ # Save the new checkpoint
+ filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
+ trainer.save_checkpoint(filepath, self.config.save_weights_only)
+
+ if trainer.is_global_zero:
+ # Remove old checkpoints
+ self.remove_old_checkpoints(trainer)
+
+ # Create the latest symlink
+ if (symlink_filename := self.symlink_path()) is not None:
+ symlink_path = self.dirpath / symlink_filename
+ _link_checkpoint(filepath, symlink_path, metadata=True)
+ log.debug(f"Created latest symlink: {symlink_path}")
+
+ # Barrier to ensure all processes have saved the checkpoint,
+ # deleted the old checkpoints, and created the symlink before continuing
+ trainer.strategy.barrier()
+
+ # Set the last global step saved
+ self._last_global_step_saved = trainer.global_step
+
+ def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+ from lightning.pytorch.trainer.states import TrainerFn
+
+ return (
+ bool(
+ getattr(trainer, "fast_dev_run", False)
+ ) # disable checkpointing with fast_dev_run
+ or trainer.state.fn
+ != TrainerFn.FITTING # don't save anything during non-fit
+ or trainer.sanity_checking # don't save anything during sanity check
+ or self._last_global_step_saved
+ == trainer.global_step # already saved at the last step
+ )
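
The file above provides the shared machinery; a concrete callback only picks a name, a default filename template, and a pruning sort key, and calls save_checkpoints() from whichever Lightning hook it cares about. The following is a minimal sketch of such a subclass (hypothetical, not part of the package; the real in-tree implementations are BestCheckpoint and LastCheckpoint further below):

from typing import Literal

from lightning.pytorch import LightningModule, Trainer
from typing_extensions import final, override

from nshtrainer._checkpoint.metadata import CheckpointMetadata
from nshtrainer.callbacks.checkpoint._base import (
    BaseCheckpointCallbackConfig,
    CheckpointBase,
)


@final
class EveryValidationCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
    # Hypothetical discriminator value; not a name registered in the package.
    name: Literal["every_validation_checkpoint"] = "every_validation_checkpoint"

    @override
    def create_checkpoint(self, root_config, dirpath):
        return EveryValidationCheckpoint(self, dirpath)


@final
class EveryValidationCheckpoint(
    CheckpointBase[EveryValidationCheckpointCallbackConfig]
):
    @override
    def name(self):
        # Subdirectory (under the run's checkpoint dir) and symlink stem.
        return "every_validation"

    @override
    def default_filename(self):
        # Formatted with CheckpointBase.current_metrics(), which always
        # provides "epoch" and "step".
        return "epoch{epoch:03d}-step{step:07d}"

    @override
    def topk_sort_key(self, metadata: CheckpointMetadata):
        # Ordering key used by remove_old_checkpoints() when pruning beyond
        # `topk`; the in-tree LastCheckpoint uses this same key.
        return metadata.checkpoint_timestamp

    @override
    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
        self.save_checkpoints(trainer)
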
File: nshtrainer/callbacks/checkpoint/best_checkpoint.py

@@ -1,47 +1,27 @@
  import logging
  from pathlib import Path
- from typing import Any, Literal
+ from typing import Literal

  from lightning.pytorch import LightningModule, Trainer
- from lightning.pytorch.callbacks import Checkpoint
- from typing_extensions import override
+ from typing_extensions import final, override
+
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata

- from ..._checkpoint.metadata import _sort_ckpts_by_metadata
- from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
  from ...metrics._config import MetricConfig
- from ..base import CallbackConfigBase
+ from ._base import BaseCheckpointCallbackConfig, CheckpointBase

  log = logging.getLogger(__name__)


- class BestCheckpointCallbackConfig(CallbackConfigBase):
+ @final
+ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
  name: Literal["best_checkpoint"] = "best_checkpoint"

- dirpath: str | Path | None = None
- """Directory path to save the checkpoint file."""
-
- filename: str = "epoch{epoch:02d}_step{step:04d}"
- """Checkpoint filename. This must not include the extension."""
-
- save_weights_only: bool = False
- """Whether to save only the model's weights or the entire model object."""
-
  metric: MetricConfig | None = None
  """Metric to monitor, or `None` to use the default metric."""

- best_symlink_filename: str | None = "best"
- """Filename for the best symlink. If None, no symlink will be created."""
-
- save_top_k: int | Literal["all"] = 1
- """The number of best checkpoints to keep."""
-
  @override
- def create_callbacks(self, root_config):
- dirpath = Path(
- self.dirpath
- or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
- )
-
+ def create_checkpoint(self, root_config, dirpath):
  # Resolve metric
  if (metric := self.metric) is None and (
  metric := root_config.primary_metric
@@ -50,143 +30,41 @@ class BestCheckpointCallbackConfig(CallbackConfigBase):
  "No metric provided and no primary metric found in the root config"
  )

- yield BestCheckpoint(self, metric, dirpath)
-
- @property
- def _save_top_k_value(self):
- return float("inf" if self.save_top_k == "all" else self.save_top_k)
+ return BestCheckpoint(self, dirpath, metric)


- class BestCheckpoint(Checkpoint):
- PREFIX = "best_"
- EXTENSION = ".ckpt"
+ @final
+ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
+ @property
+ def _metric_name_normalized(self):
+ return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")

+ @override
  def __init__(
  self,
  config: BestCheckpointCallbackConfig,
- metric: MetricConfig,
  dirpath: Path,
+ metric: MetricConfig,
  ):
- super().__init__()
- self.config = config
+ super().__init__(config, dirpath)
  self.metric = metric
- self.dirpath = dirpath
-
- self._last_global_step_saved = 0 # no need to save when no steps were taken

  @override
- def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
- self._save_best_checkpoint(trainer)
-
- def _best_symlink_filename(self):
- if (filename := self.config.best_symlink_filename) is None:
- return None
- return f"{filename}{self.EXTENSION}"
+ def name(self):
+ return f"best_{self._metric_name_normalized}"

- def _ckpt_path(self, trainer: Trainer):
- filename = self.config.filename.format(
- epoch=trainer.current_epoch, step=trainer.global_step
- )
- filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
- return self.dirpath / filename
+ @override
+ def default_filename(self):
+ return f"epoch{{epoch:03d}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"

- def _get_metric_value(self, metrics: dict[str, Any]):
- return metrics.get(
+ @override
+ def topk_sort_key(self, metadata: CheckpointMetadata):
+ return metadata.metrics.get(
  self.metric.validation_monitor,
  float("-inf" if self.metric.mode == "max" else "inf"),
  )

- def _sorted_ckpts(self):
- """
- Get sorted checkpoints by the metric value.
-
- Sort order: best -> worst
- """
- ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
- return _sort_ckpts_by_metadata(
- ckpt_paths,
- key=lambda meta, _: self._get_metric_value(meta.metrics),
- reverse=(self.metric.mode == "max"),
- )
-
- def _create_symlink(self, trainer: Trainer, best_ckpt_path: Path):
- # Resolve the symlink filename
- if (symlink_filename := self._best_symlink_filename()) is None:
- return
-
- # If the symlink already exists and points to the best checkpoint,
- # then we don't need to create a new symlink.
- symlink_path = self.dirpath / symlink_filename
- if symlink_path.exists() and symlink_path.resolve() == best_ckpt_path:
- return
-
- _link_checkpoint(best_ckpt_path, symlink_path, metadata=True)
- log.debug(f"Created best symlink: {symlink_path}")
-
- def _save_best_checkpoint(self, trainer: Trainer):
- # Skip saving the checkpoint if we're not in the fitting state
- if self._should_skip_saving_checkpoint(trainer):
- return
-
- # Get the current metric value
- if (current := self._get_metric_value(trainer.callback_metrics)) is None:
- log.warning(
- f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
- )
- return
-
- # Get sorted checkpoints
- sorted_ckpts = self._sorted_ckpts()
-
- # If the current model is worse than the worst checkpoint,
- # and we have already saved the maximum number of checkpoints,
- # then don't save the current model.
- if len(
- sorted_ckpts
- ) >= self.config._save_top_k_value and not self.metric.is_better(
- current,
- self._get_metric_value(sorted_ckpts[-1][0].metrics),
- ):
- return
-
- # Save the current model
- filepath = self._ckpt_path(trainer)
- trainer.save_checkpoint(filepath, self.config.save_weights_only)
- log.debug(f"Saved best checkpoint: {filepath}")
-
- if trainer.is_global_zero:
- # Get the sorted checkpoints again because now we have added a new checkpoint.
- # We could optimize this by adding the new checkpoint to the sorted list,
- # and then sorting it in place, but this is simpler.
- sorted_ckpts = self._sorted_ckpts()
-
- # Remove worst checkpoint if we've reached save_top_k
- if (topk := self.config.save_top_k) != "all" and len(sorted_ckpts) > topk:
- # NOTE: Sort order is best -> worst. Let's get the worst checkpoints.
- for _, ckpt_path in sorted_ckpts[topk:]:
- _remove_checkpoint(trainer, ckpt_path, metadata=True)
-
- # Create symlink to best model
- if sorted_ckpts:
- _, best_ckpt_path = sorted_ckpts[0]
- self._create_symlink(trainer, best_ckpt_path)
-
- # Update the last global step saved
- self._last_global_step_saved = trainer.global_step
-
- # Barrier to ensure all processes have saved the checkpoint before continuing
- trainer.strategy.barrier()
-
- def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
- from lightning.pytorch.trainer.states import TrainerFn
-
- return (
- bool(
- getattr(trainer, "fast_dev_run", False)
- ) # disable checkpointing with fast_dev_run
- or trainer.state.fn
- != TrainerFn.FITTING # don't save anything during non-fit
- or trainer.sanity_checking # don't save anything during sanity check
- or self._last_global_step_saved
- == trainer.global_step # already saved at the last step
- )
+ # Events
+ @override
+ def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
+ self.save_checkpoints(trainer)

File: nshtrainer/callbacks/checkpoint/last_checkpoint.py (new file)

@@ -0,0 +1,39 @@
+ import logging
+ from typing import Literal
+
+ from lightning.pytorch import LightningModule, Trainer
+ from typing_extensions import final, override
+
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata
+
+ from ._base import BaseCheckpointCallbackConfig, CheckpointBase
+
+ log = logging.getLogger(__name__)
+
+
+ @final
+ class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+ name: Literal["last_checkpoint"] = "last_checkpoint"
+
+ @override
+ def create_checkpoint(self, root_config, dirpath):
+ return LastCheckpoint(self, dirpath)
+
+
+ @final
+ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
+ @override
+ def name(self):
+ return "last"
+
+ @override
+ def default_filename(self):
+ return "epoch{epoch:03d}-step{step:07d}"
+
+ @override
+ def topk_sort_key(self, metadata: CheckpointMetadata):
+ return metadata.checkpoint_timestamp
+
+ @override
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+ self.save_checkpoints(trainer)

File: nshtrainer/model/__init__.py

@@ -5,17 +5,15 @@ from .base import LightningModuleBase as LightningModuleBase
  from .config import BaseConfig as BaseConfig
  from .config import BaseLoggerConfig as BaseLoggerConfig
  from .config import BaseProfilerConfig as BaseProfilerConfig
+ from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
  from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
  from .config import CheckpointSavingConfig as CheckpointSavingConfig
  from .config import DirectoryConfig as DirectoryConfig
  from .config import EarlyStoppingConfig as EarlyStoppingConfig
  from .config import GradientClippingConfig as GradientClippingConfig
- from .config import (
- LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
- )
+ from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
  from .config import LoggingConfig as LoggingConfig
  from .config import MetricConfig as MetricConfig
- from .config import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
  from .config import (
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
  )

File: nshtrainer/model/config.py

@@ -39,8 +39,7 @@ from .._checkpoint.loader import CheckpointLoadingConfig
  from ..callbacks import (
  BestCheckpointCallbackConfig,
  CallbackConfig,
- LatestEpochCheckpointCallbackConfig,
- ModelCheckpointCallbackConfig,
+ LastCheckpointCallbackConfig,
  OnExceptionCheckpointCallbackConfig,
  WandbWatchConfig,
  )
@@ -771,9 +770,8 @@ class ReproducibilityConfig(C.Config):


  CheckpointCallbackConfig: TypeAlias = Annotated[
- ModelCheckpointCallbackConfig
- | BestCheckpointCallbackConfig
- | LatestEpochCheckpointCallbackConfig
+ BestCheckpointCallbackConfig
+ | LastCheckpointCallbackConfig
  | OnExceptionCheckpointCallbackConfig,
  C.Field(discriminator="name"),
  ]
@@ -784,9 +782,8 @@ class CheckpointSavingConfig(CallbackConfigBase):
  """Enable checkpoint saving."""

  checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
- # ModelCheckpointCallbackConfig(),
  BestCheckpointCallbackConfig(),
- LatestEpochCheckpointCallbackConfig(),
+ LastCheckpointCallbackConfig(),
  OnExceptionCheckpointCallbackConfig(),
  ]
  """Checkpoint callback configurations."""
@@ -804,36 +801,6 @@ class CheckpointSavingConfig(CallbackConfigBase):

  return True

- @property
- def model_checkpoint(self) -> ModelCheckpointCallbackConfig | None:
- return next(
- (
- callback
- for callback in self.checkpoint_callbacks
- if isinstance(callback, ModelCheckpointCallbackConfig)
- ),
- )
-
- @property
- def latest_epoch_checkpoint(self) -> LatestEpochCheckpointCallbackConfig | None:
- return next(
- (
- callback
- for callback in self.checkpoint_callbacks
- if isinstance(callback, LatestEpochCheckpointCallbackConfig)
- ),
- )
-
- @property
- def on_exception_checkpoint(self) -> OnExceptionCheckpointCallbackConfig | None:
- return next(
- (
- callback
- for callback in self.checkpoint_callbacks
- if isinstance(callback, OnExceptionCheckpointCallbackConfig)
- ),
- )
-
  @override
  def create_callbacks(self, root_config: "BaseConfig"):
  if not self.should_save_checkpoints(root_config):
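
The new defaults above can also be overridden explicitly. A hedged usage sketch (the config classes are re-exported from nshtrainer.model per the __init__ changes earlier in this diff; wiring the resulting CheckpointSavingConfig into a BaseConfig is not shown and is assumed to work as before):

from nshtrainer.model import (
    BestCheckpointCallbackConfig,
    CheckpointSavingConfig,
    LastCheckpointCallbackConfig,
    OnExceptionCheckpointCallbackConfig,
)

# Keep up to three checkpoints ranked by the primary metric, the most recent
# epoch-end checkpoint, and an emergency checkpoint written on exceptions.
checkpoint_saving = CheckpointSavingConfig(
    checkpoint_callbacks=[
        BestCheckpointCallbackConfig(topk=3),
        LastCheckpointCallbackConfig(),
        OnExceptionCheckpointCallbackConfig(),
    ]
)
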
File: METADATA (dist-info)

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.11.7
+ Version: 0.11.9
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com

File: RECORD (dist-info)

@@ -1,19 +1,19 @@
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
- nshtrainer/_checkpoint/loader.py,sha256=_3jBf-k-fJCFfmU8wjDwbnE9rb4WoKYEyQiKGsBOCi4,13777
- nshtrainer/_checkpoint/metadata.py,sha256=M9eAZ2xMs36Z1G1xULu9MHZhsHxN8_9mNt3Iv7wuq-I,5069
+ nshtrainer/_checkpoint/loader.py,sha256=DSaNR8194kWon4O1svslNsCcN_8vlyLbF0LNCPfUpzI,13789
+ nshtrainer/_checkpoint/metadata.py,sha256=n2PMGdA3Jn3BuoyDXVCF9dUnNjuuru5CL9jqMD-X4Vk,4918
  nshtrainer/_checkpoint/saver.py,sha256=TuSAP39DOOVvSnSukQ9RitMV60JnDg6L27fMRc2uVJc,1358
  nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
  nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
  nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
- nshtrainer/callbacks/__init__.py,sha256=5lFHe7bNdKxqvw8VRHG18W3R5l34n02Z1wVuQ38_PTg,2488
+ nshtrainer/callbacks/__init__.py,sha256=k-DbpIlH2t5-oR3gHGHr8KiyCd_Twers4PcIUM1noqQ,2262
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
  nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
- nshtrainer/callbacks/checkpoint/__init__.py,sha256=zrEVCGFikfkt0iOMceOFzXsZG2-6QrqY79RKBCS7bu4,738
- nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=BUO0sWqlwfyxD1UeII5DZ-01SGLiawJAEsL8HjGX4XA,7018
- nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=CQ0IqhuPI7zAxpQLy48kK8qqfVfwXEJoHGRqI4h8xNk,4819
- nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=JS1z2YuEiQxk61HgZU1jySzF_pzdfXYO54_qHo-q3CQ,6776
+ nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=9HQSa-toOyjtuDldQ71gaDVBRdryAaB_nRv5Y554tIk,5938
+ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=O6d4SqIYsxpnRj_IYX8A9VLgOBwxTdz-j2FV_nn3BT8,2067
+ nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CM8f37dwaYHkjQFfJNTZTzSoF45zEjFRm-Fg1CzYmP4,1037
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
  nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
@@ -52,9 +52,9 @@ nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9b
  nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
  nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
- nshtrainer/model/__init__.py,sha256=NpvyQHmGaHB8xdraHmm8l7kDHLmvJSgBNQKkfYqtgyI,1454
+ nshtrainer/model/__init__.py,sha256=RlGW5a46DZcqK6cYICYxDaKpZIEj-8zLxoMrl432tno,1429
  nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
- nshtrainer/model/config.py,sha256=48Vx4RiPUGjqEhoHEX4ukOGy6KlI7RmhnShJQaRQ3io,54885
+ nshtrainer/model/config.py,sha256=F-doUiqPVw4yepT6FRqhflEiiMEWr_PuFC6lKzFWktA,53809
  nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
- nshtrainer-0.11.7.dist-info/METADATA,sha256=htPbfKNDbqr1taf0bEvSbl-hQOaPfjKkJhvrWoMz2r0,860
- nshtrainer-0.11.7.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- nshtrainer-0.11.7.dist-info/RECORD,,
+ nshtrainer-0.11.9.dist-info/METADATA,sha256=rFmi4wYXJz8srZhqFl_5ROxfmgFru7jhYUpUp7ZZjMg,860
+ nshtrainer-0.11.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ nshtrainer-0.11.9.dist-info/RECORD,,

File: nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py (deleted file)

@@ -1,131 +0,0 @@
- import logging
- from pathlib import Path
- from typing import Literal
-
- from lightning.pytorch import LightningModule, Trainer
- from lightning.pytorch.callbacks import Checkpoint
- from typing_extensions import override
-
- from ..._checkpoint.metadata import _sort_ckpts_by_metadata
- from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
- from ..base import CallbackConfigBase
-
- log = logging.getLogger(__name__)
-
-
- class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
- name: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
-
- dirpath: str | Path | None = None
- """Directory path to save the checkpoint file."""
-
- filename: str = "epoch{epoch:02d}_step{step:04d}"
- """Checkpoint filename. This must not include the extension."""
-
- save_weights_only: bool = False
- """Whether to save only the model's weights or the entire model object."""
-
- latest_symlink_filename: str | None = "latest"
- """Filename for the latest symlink. If None, no symlink will be created."""
-
- latest_k: int | Literal["all"] = 1
- """Number of latest checkpoints to keep. If "all", all checkpoints are kept."""
-
- @override
- def create_callbacks(self, root_config):
- dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
- root_config.id, "checkpoint"
- )
- dirpath = Path(dirpath)
-
- yield LatestEpochCheckpoint(self, dirpath)
-
-
- class LatestEpochCheckpoint(Checkpoint):
- PREFIX = "latest_"
- EXTENSION = ".ckpt"
-
- def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
- super().__init__()
-
- self.config = config
- self.dirpath = dirpath
-
- self._last_global_step_saved = 0
-
- @override
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
- self._save_new_checkpoint(trainer)
-
- def _latest_symlink_filename(self):
- if (filename := self.config.latest_symlink_filename) is None:
- return None
- return f"{filename}{self.EXTENSION}"
-
- def _ckpt_path(self, trainer: Trainer):
- filename = self.config.filename.format(
- epoch=trainer.current_epoch, step=trainer.global_step
- )
- filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
- return self.dirpath / filename
-
- def _remove_old_checkpoints(self, trainer: Trainer):
- if (latest_k := self.config.latest_k) == "all":
- return
-
- # Get all configs, ignoring the latest symlink
- ckpts = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
- # Ignore the latest symlink
- if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
- ckpts = [p for p in ckpts if p.name != latest_symlink_filename]
-
- # Sort by epoch, then step, then last modified
- ckpts = _sort_ckpts_by_metadata(
- ckpts,
- key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
- reverse=True,
- )
-
- # Remove all but the latest k checkpoints
- # NOTE: We add 1 to the latest_k here because
- # we're about to save a new checkpoint.
- for _, ckpt_path in ckpts[latest_k:]:
- _remove_checkpoint(trainer, ckpt_path, metadata=True)
-
- def _save_new_checkpoint(self, trainer: Trainer):
- if self._should_skip_saving_checkpoint(trainer):
- return
-
- # Save the new checkpoint
- filepath = self._ckpt_path(trainer)
- trainer.save_checkpoint(filepath, self.config.save_weights_only)
-
- if trainer.is_global_zero:
- # Remove old checkpoints
- self._remove_old_checkpoints(trainer)
-
- # Create the latest symlink
- if (symlink_filename := self._latest_symlink_filename()) is not None:
- symlink_path = self.dirpath / symlink_filename
- _link_checkpoint(filepath, symlink_path, metadata=True)
- log.debug(f"Created latest symlink: {symlink_path}")
-
- # Set the last global step saved
- self._last_global_step_saved = trainer.global_step
-
- # Barrier to ensure all processes have saved the checkpoint before continuing
- trainer.strategy.barrier()
-
- def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
- from lightning.pytorch.trainer.states import TrainerFn
-
- return (
- bool(
- getattr(trainer, "fast_dev_run", False)
- ) # disable checkpointing with fast_dev_run
- or trainer.state.fn
- != TrainerFn.FITTING # don't save anything during non-fit
- or trainer.sanity_checking # don't save anything during sanity check
- or self._last_global_step_saved
- == trainer.global_step # already saved at the last step
- )

File: nshtrainer/callbacks/checkpoint/model_checkpoint.py (deleted file)

@@ -1,207 +0,0 @@
- import logging
- import re
- from datetime import timedelta
- from pathlib import Path
- from typing import TYPE_CHECKING, Literal
-
- from lightning.pytorch import Trainer
- from lightning.pytorch.callbacks.model_checkpoint import (
- ModelCheckpoint as _ModelCheckpoint,
- )
- from typing_extensions import override
-
- from ..._checkpoint.saver import _link_checkpoint
- from ..._checkpoint.saver import _remove_checkpoint as _ckpt_saver_remove_checkpoint
- from ...metrics import MetricConfig
- from ..base import CallbackConfigBase
-
- if TYPE_CHECKING:
- from ...model.config import BaseConfig
-
- log = logging.getLogger(__name__)
-
-
- def _convert_string(input_string: str):
- # Find all variables enclosed in curly braces
- variables = re.findall(r"\{(.*?)\}", input_string)
-
- # Replace each variable with its corresponding key-value pair
- output_string = input_string
- for variable in variables:
- # If the name is something like {variable:format}, we shouldn't process the format.
- key_name = variable
- if ":" in variable:
- key_name, _ = variable.split(":", 1)
- continue
-
- # Replace '/' with '_' in the key name
- key_name = key_name.replace("/", "_")
- output_string = output_string.replace(
- f"{{{variable}}}", f"{key_name}={{{variable}}}"
- )
-
- return output_string
-
-
- class ModelCheckpointCallbackConfig(CallbackConfigBase):
- """Arguments for the ModelCheckpoint callback."""
-
- name: Literal["model_checkpoint"] = "model_checkpoint"
-
- dirpath: str | Path | None = None
- """
- Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
- """
-
- filename: str | None = None
- """
- Checkpoint filename.
- If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
- """
-
- metric: MetricConfig | None = None
- """
- Metric to monitor for saving checkpoints.
- If None, the primary metric of the runner will be used, if available.
- """
-
- verbose: bool = False
- """Verbosity mode. If True, print additional information about checkpoints."""
-
- save_last: Literal[True, False, "link"] | None = "link"
- """
- Whether to save the last checkpoint.
- If True, saves a copy of the last checkpoint separately.
- If "link", creates a symbolic link to the last checkpoint.
- """
-
- save_top_k: int | Literal["all"] = 1
- """
- Number of best models to save.
- If "all" or -1, all models are saved.
- If 0, no models are saved.
- """
-
- save_weights_only: bool = False
- """Whether to save only the model's weights or the entire model object."""
-
- auto_insert_metric_name: bool = True
- """Whether to automatically insert the metric name in the checkpoint filename."""
-
- every_n_train_steps: int | None = None
- """
- Number of training steps between checkpoints.
- If None or 0, no checkpoints are saved during training.
- """
-
- train_time_interval: timedelta | None = None
- """
- Time interval between checkpoints during training.
- If None, no checkpoints are saved during training based on time.
- """
-
- every_n_epochs: int | None = None
- """
- Number of epochs between checkpoints.
- If None or 0, no checkpoints are saved at the end of epochs.
- """
-
- save_on_train_epoch_end: bool | None = None
- """
- Whether to run checkpointing at the end of the training epoch.
- If False, checkpointing runs at the end of the validation.
- """
-
- enable_version_counter: bool = True
- """Whether to append a version to the existing file name."""
-
- auto_append_metric: bool = True
- """If enabled, this will automatically add "-{monitor}" to the filename."""
-
- def metric_or_default(self, root_config: "BaseConfig"):
- if self.metric is not None:
- return self.metric
- if root_config.primary_metric is not None:
- return root_config.primary_metric
- raise ValueError("Primary metric must be provided if metric is not specified.")
-
- def resolve_filename(self, root_config: "BaseConfig"):
- metric = self.metric_or_default(root_config)
-
- filename = self.filename
- if not filename:
- filename = "{epoch}-{step}"
- if self.auto_append_metric:
- filename = f"{filename}-{{{metric.validation_monitor}}}"
-
- if self.auto_insert_metric_name and filename:
- new_filename = _convert_string(filename)
- log.critical(
- f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
- )
- filename = new_filename
-
- return filename
-
- @override
- def create_callbacks(self, root_config):
- dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
- root_config.id, "checkpoint"
- )
-
- metric = self.metric_or_default(root_config)
- filename = self.resolve_filename(root_config)
-
- yield ModelCheckpoint(
- self,
- dirpath=Path(dirpath),
- filename=filename,
- metric=metric,
- )
-
- def _save_top_k_model_ckpt_input(self):
- if self.save_top_k == "all":
- return -1
- return self.save_top_k
-
-
- class ModelCheckpoint(_ModelCheckpoint):
- CHECKPOINT_NAME_LAST = "best"
-
- @override
- def __init__(
- self,
- config: ModelCheckpointCallbackConfig,
- dirpath: Path,
- filename: str,
- metric: MetricConfig,
- ):
- self.config = config
- del config
-
- super().__init__(
- dirpath=dirpath,
- filename=filename,
- monitor=metric.validation_monitor,
- mode=metric.mode,
- verbose=self.config.verbose,
- save_last=self.config.save_last,
- save_top_k=self.config._save_top_k_model_ckpt_input(),
- save_weights_only=self.config.save_weights_only,
- auto_insert_metric_name=False,
- every_n_train_steps=self.config.every_n_train_steps,
- train_time_interval=self.config.train_time_interval,
- every_n_epochs=self.config.every_n_epochs,
- save_on_train_epoch_end=self.config.save_on_train_epoch_end,
- enable_version_counter=self.config.enable_version_counter,
- )
-
- @override
- def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str): # pyright: ignore[reportIncompatibleMethodOverride]
- if trainer.is_global_zero:
- _link_checkpoint(filepath, linkpath, metadata=True)
- trainer.strategy.barrier()
-
- @override
- def _remove_checkpoint(self, trainer: Trainer, filepath: str):
- _ckpt_saver_remove_checkpoint(trainer, filepath, metadata=True)