nshtrainer 0.11.6__py3-none-any.whl → 0.11.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -236,7 +236,7 @@ def _load_ckpt_meta(
  error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
  match on_error:
  case "warn":
- log.warn(error_msg)
+ log.warning(error_msg)
  case "raise":
  raise ValueError(error_msg)
  case _:
@@ -325,13 +325,13 @@ def _resolve_checkpoint(
  ),
  ]
  if not candidates:
- log.warn(
+ log.warning(
  "No checkpoint candidates found for `best` checkpoint strategy."
  )
  continue

  if (metric := strategy.metric or root_config.primary_metric) is None:
- log.warn(
+ log.warning(
  "No metric specified for `best` checkpoint strategy, "
  "and no primary metric is set in the configuration. "
  "Skipping strategy."
@@ -360,7 +360,7 @@ def _resolve_checkpoint(
  ),
  ]
  if not candidates:
- log.warn(
+ log.warning(
  "No checkpoint candidates found for `last` checkpoint strategy."
  )
  continue
@@ -4,7 +4,7 @@ import logging
  import shutil
  from collections.abc import Callable
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, cast
+ from typing import TYPE_CHECKING, Any, ClassVar, cast

  import nshconfig as C
  import numpy as np
@@ -20,10 +20,11 @@ log = logging.getLogger(__name__)


  METADATA_PATH_SUFFIX = ".metadata.json"
- HPARAMS_PATH_SUFFIX = ".hparams.json"


  class CheckpointMetadata(C.Config):
+ PATH_SUFFIX: ClassVar[str] = METADATA_PATH_SUFFIX
+
  checkpoint_path: Path
  checkpoint_filename: str

@@ -39,6 +40,8 @@ class CheckpointMetadata(C.Config):
  metrics: dict[str, Any]
  environment: EnvironmentConfig

+ hparams: dict[str, Any] | None
+
  @classmethod
  def from_file(cls, path: Path):
  return cls.model_validate_json(path.read_text())
@@ -55,7 +58,10 @@ class CheckpointMetadata(C.Config):


  def _generate_checkpoint_metadata(
- config: "BaseConfig", trainer: "Trainer", checkpoint_path: Path
+ config: "BaseConfig",
+ trainer: "Trainer",
+ checkpoint_path: Path,
+ metadata_path: Path,
  ):
  checkpoint_timestamp = datetime.datetime.now()
  start_timestamp = trainer.start_time()
@@ -70,7 +76,11 @@ def _generate_checkpoint_metadata(
  metrics[name] = metric

  return CheckpointMetadata(
- checkpoint_path=checkpoint_path,
+ # checkpoint_path=checkpoint_path,
+ # We should store the path as a relative path
+ # to the metadata file to avoid issues with
+ # moving the checkpoint directory
+ checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
  checkpoint_filename=checkpoint_path.name,
  run_id=config.id,
  name=config.run_name,
@@ -84,6 +94,7 @@ def _generate_checkpoint_metadata(
  training_time=training_time,
  metrics=metrics,
  environment=config.environment,
+ hparams=config.model_dump(mode="json"),
  )


@@ -93,36 +104,28 @@ def _write_checkpoint_metadata(
  checkpoint_path: Path,
  ):
  config = cast("BaseConfig", model.config)
- metadata = _generate_checkpoint_metadata(config, trainer, checkpoint_path)
+ metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+ metadata = _generate_checkpoint_metadata(
+ config, trainer, checkpoint_path, metadata_path
+ )

  # Write the metadata to the checkpoint directory
  try:
- metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
  metadata_path.write_text(metadata.model_dump_json(indent=4))
  except Exception as e:
  log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
  else:
  log.debug(f"Checkpoint metadata written to {checkpoint_path}")

- # Write the hparams to the checkpoint directory
+
+ def _remove_checkpoint_metadata(checkpoint_path: Path):
+ path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
  try:
- hparams_path = checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX)
- hparams_path.write_text(config.model_dump_json(indent=4))
+ path.unlink(missing_ok=True)
  except Exception as e:
- log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
+ log.warning(f"Failed to remove {path}: {e}")
  else:
- log.debug(f"Checkpoint metadata written to {checkpoint_path}")
-
-
- def _remove_checkpoint_metadata(checkpoint_path: Path):
- for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
- path = checkpoint_path.with_suffix(suffix)
- try:
- path.unlink(missing_ok=True)
- except Exception as e:
- log.warning(f"Failed to remove {path}: {e}")
- else:
- log.debug(f"Removed {path}")
+ log.debug(f"Removed {path}")


  def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
@@ -130,20 +133,19 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
  _remove_checkpoint_metadata(linked_checkpoint_path)

  # Link the metadata files to the new checkpoint
- for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
- path = checkpoint_path.with_suffix(suffix)
- linked_path = linked_checkpoint_path.with_suffix(suffix)
+ path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+ linked_path = linked_checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+ try:
  try:
- try:
- linked_path.symlink_to(path)
- except OSError:
- # on Windows, special permissions are required to create symbolic links as a regular user
- # fall back to copying the file
- shutil.copy(path, linked_path)
- except Exception as e:
- log.warning(f"Failed to link {path} to {linked_path}: {e}")
- else:
- log.debug(f"Linked {path} to {linked_path}")
+ linked_path.symlink_to(path)
+ except OSError:
+ # on Windows, special permissions are required to create symbolic links as a regular user
+ # fall back to copying the file
+ shutil.copy(path, linked_path)
+ except Exception as e:
+ log.warning(f"Failed to link {path} to {linked_path}: {e}")
+ else:
+ log.debug(f"Linked {path} to {linked_path}")


  def _sort_ckpts_by_metadata(
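
Note on the metadata hunks above: checkpoint_path is now stored relative to the metadata file's directory instead of as an absolute path, so a checkpoint directory can be moved without breaking its metadata. A minimal sketch of how a consumer might resolve the stored path back to an absolute one (the resolve_checkpoint_path helper below is hypothetical, not part of the package):

    from pathlib import Path

    from nshtrainer._checkpoint.metadata import CheckpointMetadata

    def resolve_checkpoint_path(metadata_file: Path) -> Path:
        # CheckpointMetadata.from_file() is shown in the diff above; as of 0.11.8,
        # checkpoint_path is relative to the directory containing the metadata file.
        meta = CheckpointMetadata.from_file(metadata_file)
        return (metadata_file.parent / meta.checkpoint_path).resolve()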
@@ -8,11 +8,9 @@ from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata


  def _link_checkpoint(
- trainer: Trainer,
  filepath: str | Path | os.PathLike,
  linkpath: str | Path | os.PathLike,
  *,
- barrier: bool,
  metadata: bool,
  ):
  if not isinstance(filepath, Path):
@@ -20,26 +18,23 @@ def _link_checkpoint(
  if not isinstance(linkpath, Path):
  linkpath = Path(linkpath)

- if trainer.is_global_zero:
- if linkpath.exists():
- if linkpath.is_symlink() or linkpath.is_file():
- linkpath.unlink()
- elif linkpath.is_dir():
- shutil.rmtree(linkpath)
- _remove_checkpoint_metadata(linkpath)
+ if linkpath.exists():
+ if linkpath.is_symlink() or linkpath.is_file():
+ linkpath.unlink()
+ elif linkpath.is_dir():
+ shutil.rmtree(linkpath)
+ _remove_checkpoint_metadata(linkpath)

- try:
- target_path = filepath.relative_to(linkpath.parent)
- linkpath.symlink_to(target_path)
- except OSError:
- # on Windows, special permissions are required to create symbolic links as a regular user
- # fall back to copying the file
- shutil.copy(filepath, linkpath)
+ try:
+ target_path = filepath.relative_to(linkpath.parent)
+ linkpath.symlink_to(target_path)
+ except OSError:
+ # on Windows, special permissions are required to create symbolic links as a regular user
+ # fall back to copying the file
+ shutil.copy(filepath, linkpath)

- if metadata:
- _link_checkpoint_metadata(filepath, linkpath)
- if barrier:
- trainer.strategy.barrier()
+ if metadata:
+ _link_checkpoint_metadata(filepath, linkpath)


  def _remove_checkpoint(
@@ -47,15 +42,10 @@ def _remove_checkpoint(
  filepath: str | Path | os.PathLike,
  *,
  metadata: bool,
- barrier: bool,
  ):
  if not isinstance(filepath, Path):
  filepath = Path(filepath)

- if trainer.is_global_zero:
- trainer.strategy.remove_checkpoint(filepath)
- if metadata:
- _remove_checkpoint_metadata(filepath)
-
- if barrier:
- trainer.strategy.barrier()
+ trainer.strategy.remove_checkpoint(filepath)
+ if metadata:
+ _remove_checkpoint_metadata(filepath)
@@ -6,6 +6,8 @@ from . import checkpoint as checkpoint
  from .base import CallbackConfigBase as CallbackConfigBase
  from .checkpoint import BestCheckpoint as BestCheckpoint
  from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
+ from .checkpoint import LastCheckpoint as LastCheckpoint
+ from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
  from .checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
  from .checkpoint import (
  LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
@@ -46,6 +48,7 @@ CallbackConfig = Annotated[
  | GradientSkippingConfig
  | EMAConfig
  | BestCheckpointCallbackConfig
+ | LastCheckpointCallbackConfig
  | ModelCheckpointCallbackConfig
  | LatestEpochCheckpointCallbackConfig
  | OnExceptionCheckpointCallbackConfig
@@ -2,6 +2,10 @@ from .best_checkpoint import BestCheckpoint as BestCheckpoint
  from .best_checkpoint import (
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
  )
+ from .last_checkpoint import LastCheckpoint as LastCheckpoint
+ from .last_checkpoint import (
+ LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
+ )
  from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
  from .latest_epoch_checkpoint import (
  LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
@@ -0,0 +1,175 @@
+ import logging
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Generic, Literal
+
+ import numpy as np
+ import torch
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.callbacks import Checkpoint
+ from typing_extensions import TypeVar, override
+
+ from ..._checkpoint.metadata import CheckpointMetadata
+ from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+ from ..base import CallbackConfigBase
+
+ if TYPE_CHECKING:
+ from ...model.config import BaseConfig
+
+ log = logging.getLogger(__name__)
+
+
+ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
+ dirpath: str | Path | None = None
+ """Directory path to save the checkpoint file."""
+
+ filename: str | None = None
+ """Checkpoint filename. This must not include the extension.
+ If None, the default filename will be used."""
+
+ save_weights_only: bool = False
+ """Whether to save only the model's weights or the entire model object."""
+
+ save_symlink: bool = True
+ """Whether to create a symlink to the saved checkpoint."""
+
+ topk: int | Literal["all"] = 1
+ """The number of checkpoints to keep."""
+
+ @abstractmethod
+ def create_checkpoint(
+ self,
+ root_config: "BaseConfig",
+ dirpath: Path,
+ ) -> "CheckpointBase": ...
+
+ @override
+ def create_callbacks(self, root_config):
+ dirpath = Path(
+ self.dirpath
+ or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+ )
+
+ yield self.create_checkpoint(root_config, dirpath)
+
+
+ TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
+
+
+ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
+ def __init__(self, config: TConfig, dirpath: Path):
+ super().__init__()
+
+ self.config = config
+ self.dirpath = dirpath / self.name()
+ self.symlink_dirpath = dirpath
+
+ self._last_global_step_saved = 0
+
+ @abstractmethod
+ def default_filename(self) -> str: ...
+
+ @abstractmethod
+ def name(self) -> str: ...
+
+ def extension(self) -> str:
+ return ".ckpt"
+
+ @abstractmethod
+ def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
+
+ def symlink_path(self):
+ if not self.config.save_symlink:
+ return None
+
+ return self.symlink_dirpath / f"{self.name()}{self.extension()}"
+
+ def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
+ if (filename := self.config.filename) is None:
+ filename = self.default_filename()
+ filename = filename.format(**current_metrics)
+ return self.dirpath / f"{filename}{self.extension()}"
+
+ def remove_old_checkpoints(self, trainer: Trainer):
+ if (topk := self.config.topk) == "all":
+ return
+
+ # Get all the checkpoint metadata
+ metas = [
+ CheckpointMetadata.from_file(p)
+ for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+ ]
+
+ # Sort by the topk sort key
+ metas = sorted(metas, key=self.topk_sort_key)
+
+ # Now, the metas are sorted from the best to the worst,
+ # so we can remove the worst checkpoints
+ for meta in metas[topk:]:
+ if not (old_ckpt_path := self.dirpath / meta.checkpoint_filename).exists():
+ log.warning(
+ f"Checkpoint file not found: {old_ckpt_path}\n"
+ "Skipping removal of the checkpoint metadata."
+ )
+ continue
+
+ _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
+ log.debug(f"Removed old checkpoint: {old_ckpt_path}")
+
+ def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
+ current_metrics: dict[str, Any] = {
+ "epoch": trainer.current_epoch,
+ "step": trainer.global_step,
+ }
+
+ for name, value in trainer.callback_metrics.items():
+ match value:
+ case torch.Tensor() if value.numel() == 1:
+ value = value.detach().cpu().item()
+ case np.ndarray() if value.size == 1:
+ value = value.item()
+ case _:
+ pass
+
+ current_metrics[name] = value
+
+ return current_metrics
+
+ def save_checkpoints(self, trainer: Trainer):
+ if self._should_skip_saving_checkpoint(trainer):
+ return
+
+ # Save the new checkpoint
+ filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
+ trainer.save_checkpoint(filepath, self.config.save_weights_only)
+
+ if trainer.is_global_zero:
+ # Remove old checkpoints
+ self.remove_old_checkpoints(trainer)
+
+ # Create the latest symlink
+ if (symlink_filename := self.symlink_path()) is not None:
+ symlink_path = self.dirpath / symlink_filename
+ _link_checkpoint(filepath, symlink_path, metadata=True)
+ log.debug(f"Created latest symlink: {symlink_path}")
+
+ # Barrier to ensure all processes have saved the checkpoint,
+ # deleted the old checkpoints, and created the symlink before continuing
+ trainer.strategy.barrier()
+
+ # Set the last global step saved
+ self._last_global_step_saved = trainer.global_step
+
+ def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+ from lightning.pytorch.trainer.states import TrainerFn
+
+ return (
+ bool(
+ getattr(trainer, "fast_dev_run", False)
+ ) # disable checkpointing with fast_dev_run
+ or trainer.state.fn
+ != TrainerFn.FITTING # don't save anything during non-fit
+ or trainer.sanity_checking # don't save anything during sanity check
+ or self._last_global_step_saved
+ == trainer.global_step # already saved at the last step
+ )
@@ -1,47 +1,27 @@
  import logging
  from pathlib import Path
- from typing import Any, Literal
+ from typing import Literal

  from lightning.pytorch import LightningModule, Trainer
- from lightning.pytorch.callbacks import Checkpoint
- from typing_extensions import override
+ from typing_extensions import final, override
+
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata

- from ..._checkpoint.metadata import _sort_ckpts_by_metadata
- from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
  from ...metrics._config import MetricConfig
- from ..base import CallbackConfigBase
+ from ._base import BaseCheckpointCallbackConfig, CheckpointBase

  log = logging.getLogger(__name__)


- class BestCheckpointCallbackConfig(CallbackConfigBase):
+ @final
+ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
  name: Literal["best_checkpoint"] = "best_checkpoint"

- dirpath: str | Path | None = None
- """Directory path to save the checkpoint file."""
-
- filename: str = "epoch{epoch:02d}_step{step:04d}"
- """Checkpoint filename. This must not include the extension."""
-
- save_weights_only: bool = False
- """Whether to save only the model's weights or the entire model object."""
-
  metric: MetricConfig | None = None
  """Metric to monitor, or `None` to use the default metric."""

- best_symlink_filename: str | None = "best"
- """Filename for the best symlink. If None, no symlink will be created."""
-
- save_top_k: int | Literal["all"] = 1
- """The number of best checkpoints to keep."""
-
  @override
- def create_callbacks(self, root_config):
- dirpath = Path(
- self.dirpath
- or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
- )
-
+ def create_checkpoint(self, root_config, dirpath):
  # Resolve metric
  if (metric := self.metric) is None and (
  metric := root_config.primary_metric
@@ -50,147 +30,41 @@ class BestCheckpointCallbackConfig(CallbackConfigBase):
  "No metric provided and no primary metric found in the root config"
  )

- yield BestCheckpoint(self, metric, dirpath)
-
- @property
- def _save_top_k_value(self):
- return float("inf" if self.save_top_k == "all" else self.save_top_k)
+ return BestCheckpoint(self, dirpath, metric)


- class BestCheckpoint(Checkpoint):
- PREFIX = "best_"
- EXTENSION = ".ckpt"
+ @final
+ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
+ @property
+ def _metric_name_normalized(self):
+ return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")

+ @override
  def __init__(
  self,
  config: BestCheckpointCallbackConfig,
- metric: MetricConfig,
  dirpath: Path,
+ metric: MetricConfig,
  ):
- super().__init__()
- self.config = config
+ super().__init__(config, dirpath)
  self.metric = metric
- self.dirpath = dirpath
-
- self._last_global_step_saved = 0 # no need to save when no steps were taken

  @override
- def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
- self._save_best_checkpoint(trainer)
+ def name(self):
+ return f"best_{self._metric_name_normalized}"

- def _best_symlink_filename(self):
- if (filename := self.config.best_symlink_filename) is None:
- return None
- return f"{filename}{self.EXTENSION}"
-
- def _ckpt_path(self, trainer: Trainer):
- filename = self.config.filename.format(
- epoch=trainer.current_epoch, step=trainer.global_step
- )
- filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
- return self.dirpath / filename
-
- def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
- for ckpt_path in ckpt_paths:
- _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
+ @override
+ def default_filename(self):
+ return f"epoch{{epoch:03d}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"

- def _get_metric_value(self, metrics: dict[str, Any]):
- return metrics.get(
+ @override
+ def topk_sort_key(self, metadata: CheckpointMetadata):
+ return metadata.metrics.get(
  self.metric.validation_monitor,
  float("-inf" if self.metric.mode == "max" else "inf"),
  )

- def _sorted_ckpts(self):
- ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
- return _sort_ckpts_by_metadata(
- ckpt_paths,
- key=lambda meta, _: self._get_metric_value(meta.metrics),
- reverse=(self.metric.mode == "min"),
- )
-
- def _create_symlink(self, trainer: Trainer, best_ckpt_path: Path):
- # Resolve the symlink filename
- if (symlink_filename := self._best_symlink_filename()) is None:
- return
-
- # If the symlink already exists and points to the best checkpoint,
- # then we don't need to create a new symlink.
- symlink_path = self.dirpath / symlink_filename
- if symlink_path.exists() and symlink_path.resolve() == best_ckpt_path:
- return
-
- _link_checkpoint(
- trainer,
- best_ckpt_path,
- symlink_path,
- metadata=True,
- barrier=False,
- )
- log.debug(f"Created best symlink: {symlink_path}")
-
- def _save_best_checkpoint(self, trainer: Trainer):
- # Skip saving the checkpoint if we're not in the fitting state
- if self._should_skip_saving_checkpoint(trainer):
- return
-
- # Get the current metric value
- if (current := self._get_metric_value(trainer.callback_metrics)) is None:
- log.warning(
- f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
- )
- return
-
- # Get sorted checkpoints
- sorted_ckpts = self._sorted_ckpts()
-
- # If the current model is worse than the worst checkpoint,
- # and we have already saved the maximum number of checkpoints,
- # then don't save the current model.
- if len(
- sorted_ckpts
- ) >= self.config._save_top_k_value and not self.metric.is_better(
- current,
- self._get_metric_value(sorted_ckpts[-1][0].metrics),
- ):
- return
-
- # Save the current model
- filepath = self._ckpt_path(trainer)
- trainer.save_checkpoint(filepath, self.config.save_weights_only)
- log.debug(f"Saved best checkpoint: {filepath}")
-
- # Remove worst checkpoint if we've reached save_top_k
- # NOTE: We add 1 to save_top_k here because we have just saved a new checkpoint
- if len(sorted_ckpts) + 1 > self.config._save_top_k_value:
- # Get the sorted checkpoints again because now we have added a new checkpoint.
- # We could optimize this by adding the new checkpoint to the sorted list,
- # and then sorting it in place, but this is simpler.
- sorted_ckpts = self._sorted_ckpts()
- self._remove_checkpoints(
- trainer, [p for _, p in sorted_ckpts[self.config.save_top_k :]]
- )
-
- # Create symlink to best model
- if sorted_ckpts:
- _, best_ckpt_path = sorted_ckpts[0]
- self._create_symlink(trainer, best_ckpt_path)
-
- # Update the last global step saved
- self._last_global_step_saved = trainer.global_step
-
- # Barrier to ensure all processes have saved the checkpoint before continuing
- trainer.strategy.barrier()
-
- def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
- from lightning.pytorch.trainer.states import TrainerFn
-
- return (
- bool(
- getattr(trainer, "fast_dev_run", False)
- ) # disable checkpointing with fast_dev_run
- or trainer.state.fn
- != TrainerFn.FITTING # don't save anything during non-fit
- or trainer.sanity_checking # don't save anything during sanity check
- or self._last_global_step_saved
- == trainer.global_step # already saved at the last step
- )
+ # Events
+ @override
+ def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
+ self.save_checkpoints(trainer)
@@ -0,0 +1,39 @@
+ import logging
+ from typing import Literal
+
+ from lightning.pytorch import LightningModule, Trainer
+ from typing_extensions import final, override
+
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata
+
+ from ._base import BaseCheckpointCallbackConfig, CheckpointBase
+
+ log = logging.getLogger(__name__)
+
+
+ @final
+ class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+ name: Literal["last_checkpoint"] = "last_checkpoint"
+
+ @override
+ def create_checkpoint(self, root_config, dirpath):
+ return LastCheckpoint(self, dirpath)
+
+
+ @final
+ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
+ @override
+ def name(self):
+ return "last"
+
+ @override
+ def default_filename(self):
+ return "epoch{epoch:03d}-step{step:07d}"
+
+ @override
+ def topk_sort_key(self, metadata: CheckpointMetadata):
+ return metadata.checkpoint_timestamp
+
+ @override
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+ self.save_checkpoints(trainer)
@@ -19,7 +19,7 @@ class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
  dirpath: str | Path | None = None
  """Directory path to save the checkpoint file."""

- filename: str = "epoch{epoch:02d}_step{step:04d}"
+ filename: str = "epoch{epoch:03d}_step{step:07d}"
  """Checkpoint filename. This must not include the extension."""

  save_weights_only: bool = False
@@ -51,6 +51,8 @@ class LatestEpochCheckpoint(Checkpoint):
  self.config = config
  self.dirpath = dirpath

+ self._last_global_step_saved = 0
+
  @override
  def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
  self._save_new_checkpoint(trainer)
@@ -67,53 +69,63 @@ class LatestEpochCheckpoint(Checkpoint):
  filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
  return self.dirpath / filename

- def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
- for ckpt_path in ckpt_paths:
- _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
-
  def _remove_old_checkpoints(self, trainer: Trainer):
  if (latest_k := self.config.latest_k) == "all":
  return

- # NOTE: We add 1 to the latest_k here because
- # we're about to save a new checkpoint.
- latest_k += 1
-
  # Get all configs, ignoring the latest symlink
- ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
+ ckpts = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
  # Ignore the latest symlink
  if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
- ckpt_paths = [p for p in ckpt_paths if p.name != latest_symlink_filename]
+ ckpts = [p for p in ckpts if p.name != latest_symlink_filename]

  # Sort by epoch, then step, then last modified
- metadata_and_ckpt_paths = _sort_ckpts_by_metadata(
- ckpt_paths,
+ ckpts = _sort_ckpts_by_metadata(
+ ckpts,
  key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
  reverse=True,
  )

  # Remove all but the latest k checkpoints
- ckpts_to_remove = metadata_and_ckpt_paths[latest_k:]
- self._remove_checkpoints(trainer, [p for _, p in ckpts_to_remove])
+ # NOTE: We add 1 to the latest_k here because
+ # we're about to save a new checkpoint.
+ for _, ckpt_path in ckpts[latest_k:]:
+ _remove_checkpoint(trainer, ckpt_path, metadata=True)

  def _save_new_checkpoint(self, trainer: Trainer):
- # Remove old checkpoints
- if trainer.is_global_zero:
- self._remove_old_checkpoints(trainer)
- trainer.strategy.barrier()
+ if self._should_skip_saving_checkpoint(trainer):
+ return

  # Save the new checkpoint
  filepath = self._ckpt_path(trainer)
  trainer.save_checkpoint(filepath, self.config.save_weights_only)

- # Create the latest symlink
- if (symlink_filename := self._latest_symlink_filename()) is not None:
- symlink_path = self.dirpath / symlink_filename
- _link_checkpoint(
- trainer,
- filepath,
- symlink_path,
- barrier=True,
- metadata=True,
- )
- log.debug(f"Created latest symlink: {symlink_path}")
+ if trainer.is_global_zero:
+ # Remove old checkpoints
+ self._remove_old_checkpoints(trainer)
+
+ # Create the latest symlink
+ if (symlink_filename := self._latest_symlink_filename()) is not None:
+ symlink_path = self.dirpath / symlink_filename
+ _link_checkpoint(filepath, symlink_path, metadata=True)
+ log.debug(f"Created latest symlink: {symlink_path}")
+
+ # Set the last global step saved
+ self._last_global_step_saved = trainer.global_step
+
+ # Barrier to ensure all processes have saved the checkpoint before continuing
+ trainer.strategy.barrier()
+
+ def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+ from lightning.pytorch.trainer.states import TrainerFn
+
+ return (
+ bool(
+ getattr(trainer, "fast_dev_run", False)
+ ) # disable checkpointing with fast_dev_run
+ or trainer.state.fn
+ != TrainerFn.FITTING # don't save anything during non-fit
+ or trainer.sanity_checking # don't save anything during sanity check
+ or self._last_global_step_saved
+ == trainer.global_step # already saved at the last step
+ )
@@ -198,19 +198,10 @@ class ModelCheckpoint(_ModelCheckpoint):

  @override
  def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str): # pyright: ignore[reportIncompatibleMethodOverride]
- return _link_checkpoint(
- trainer,
- filepath,
- linkpath,
- barrier=True,
- metadata=True,
- )
+ if trainer.is_global_zero:
+ _link_checkpoint(filepath, linkpath, metadata=True)
+ trainer.strategy.barrier()

  @override
  def _remove_checkpoint(self, trainer: Trainer, filepath: str):
- return _ckpt_saver_remove_checkpoint(
- trainer,
- filepath,
- metadata=True,
- barrier=False,
- )
+ _ckpt_saver_remove_checkpoint(trainer, filepath, metadata=True)
@@ -39,6 +39,7 @@ from .._checkpoint.loader import CheckpointLoadingConfig
  from ..callbacks import (
  BestCheckpointCallbackConfig,
  CallbackConfig,
+ LastCheckpointCallbackConfig,
  LatestEpochCheckpointCallbackConfig,
  ModelCheckpointCallbackConfig,
  OnExceptionCheckpointCallbackConfig,
@@ -773,6 +774,7 @@ class ReproducibilityConfig(C.Config):
  CheckpointCallbackConfig: TypeAlias = Annotated[
  ModelCheckpointCallbackConfig
  | BestCheckpointCallbackConfig
+ | LastCheckpointCallbackConfig
  | LatestEpochCheckpointCallbackConfig
  | OnExceptionCheckpointCallbackConfig,
  C.Field(discriminator="name"),
@@ -786,7 +788,7 @@ class CheckpointSavingConfig(CallbackConfigBase):
  checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
  # ModelCheckpointCallbackConfig(),
  BestCheckpointCallbackConfig(),
- LatestEpochCheckpointCallbackConfig(),
+ LastCheckpointCallbackConfig(),
  OnExceptionCheckpointCallbackConfig(),
  ]
  """Checkpoint callback configurations."""
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.11.6
+ Version: 0.11.8
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com
@@ -1,19 +1,21 @@
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
- nshtrainer/_checkpoint/loader.py,sha256=_3jBf-k-fJCFfmU8wjDwbnE9rb4WoKYEyQiKGsBOCi4,13777
- nshtrainer/_checkpoint/metadata.py,sha256=M9eAZ2xMs36Z1G1xULu9MHZhsHxN8_9mNt3Iv7wuq-I,5069
- nshtrainer/_checkpoint/saver.py,sha256=z_c7a91O4Bh4lZZjqJgxT3w25qFlJsOopV3cpJtkHk8,1655
+ nshtrainer/_checkpoint/loader.py,sha256=DSaNR8194kWon4O1svslNsCcN_8vlyLbF0LNCPfUpzI,13789
+ nshtrainer/_checkpoint/metadata.py,sha256=n2PMGdA3Jn3BuoyDXVCF9dUnNjuuru5CL9jqMD-X4Vk,4918
+ nshtrainer/_checkpoint/saver.py,sha256=TuSAP39DOOVvSnSukQ9RitMV60JnDg6L27fMRc2uVJc,1358
  nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
  nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
  nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
- nshtrainer/callbacks/__init__.py,sha256=5lFHe7bNdKxqvw8VRHG18W3R5l34n02Z1wVuQ38_PTg,2488
+ nshtrainer/callbacks/__init__.py,sha256=0ZYo2pX_MfTlPeoixOIHU2RgX4NcpxeQL4YrU2FE63Q,2665
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
  nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
- nshtrainer/callbacks/checkpoint/__init__.py,sha256=zrEVCGFikfkt0iOMceOFzXsZG2-6QrqY79RKBCS7bu4,738
- nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=Ygblsf9NdLHxQPJUM47W0nGxlabj-ZnEBIMpvk-QMS8,7124
- nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=NES-acaslPBiZQIMAdk_YwtnBrkm_y_BJQ8Ian0UKP0,4294
- nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=mLFMbNzeMiBer3BCb7o3ucswKpOCQlYyN3wdB92N-LY,6884
+ nshtrainer/callbacks/checkpoint/__init__.py,sha256=2COllQ7Rfe9IWwdLjgG0Bxc-lxVJJ_aU7A9zFObTwwY,899
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=9HQSa-toOyjtuDldQ71gaDVBRdryAaB_nRv5Y554tIk,5938
+ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=O6d4SqIYsxpnRj_IYX8A9VLgOBwxTdz-j2FV_nn3BT8,2067
+ nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CM8f37dwaYHkjQFfJNTZTzSoF45zEjFRm-Fg1CzYmP4,1037
+ nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=II-WZHAk7leK1Vgjza0PVifrF3QetR9Nn3n1qhqtuVo,4819
+ nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=JS1z2YuEiQxk61HgZU1jySzF_pzdfXYO54_qHo-q3CQ,6776
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
  nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
@@ -54,7 +56,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
  nshtrainer/model/__init__.py,sha256=NpvyQHmGaHB8xdraHmm8l7kDHLmvJSgBNQKkfYqtgyI,1454
  nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
- nshtrainer/model/config.py,sha256=48Vx4RiPUGjqEhoHEX4ukOGy6KlI7RmhnShJQaRQ3io,54885
+ nshtrainer/model/config.py,sha256=u6pUbgR_qqo_xP1lclhavKG1KbsQ9run_P0Am2XcsQw,54947
  nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -82,6 +84,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
- nshtrainer-0.11.6.dist-info/METADATA,sha256=tHGQ69o-paHEvlLUgo46bWeMlvuTlb8Q-upA00NxoKE,860
- nshtrainer-0.11.6.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- nshtrainer-0.11.6.dist-info/RECORD,,
+ nshtrainer-0.11.8.dist-info/METADATA,sha256=l5Y-rA6Kxo50NbQQJ3kKp5le6Y579NY6Jmjruomfbfs,860
+ nshtrainer-0.11.8.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ nshtrainer-0.11.8.dist-info/RECORD,,