nshtrainer-0.10.10.tar.gz → nshtrainer-0.10.11.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/PKG-INFO +1 -1
  2. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/pyproject.toml +1 -1
  3. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/_checkpoint/metadata.py +73 -0
  4. nshtrainer-0.10.11/src/nshtrainer/_checkpoint/saver.py +52 -0
  5. nshtrainer-0.10.11/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +114 -0
  6. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/model_checkpoint.py +18 -0
  7. nshtrainer-0.10.10/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +0 -74
  8. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/README.md +0 -0
  9. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/__init__.py +0 -0
  10. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/_checkpoint/loader.py +0 -0
  11. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/_experimental/__init__.py +0 -0
  12. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  13. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  14. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  15. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/__init__.py +0 -0
  16. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  17. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/actsave.py +0 -0
  18. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/base.py +0 -0
  19. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  20. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/ema.py +0 -0
  21. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  22. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  23. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/interval.py +0 -0
  24. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  25. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  26. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
  27. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/print_table.py +0 -0
  28. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  29. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/timer.py +0 -0
  30. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  31. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/data/__init__.py +0 -0
  32. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  33. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/data/transform.py +0 -0
  34. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/__init__.py +0 -0
  35. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/_experimental.py +0 -0
  36. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/actsave.py +0 -0
  37. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/callbacks.py +0 -0
  38. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/config.py +0 -0
  39. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/data.py +0 -0
  40. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/log.py +0 -0
  41. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  42. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/model.py +0 -0
  43. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/nn.py +0 -0
  44. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/optimizer.py +0 -0
  45. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/runner.py +0 -0
  46. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/snapshot.py +0 -0
  47. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/snoop.py +0 -0
  48. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/trainer.py +0 -0
  49. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/typecheck.py +0 -0
  50. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/ll/util.py +0 -0
  51. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  52. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  53. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  54. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  55. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/metrics/__init__.py +0 -0
  56. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/metrics/_config.py +0 -0
  57. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/model/__init__.py +0 -0
  58. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/model/_environment.py +0 -0
  59. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/model/base.py +0 -0
  60. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/model/config.py +0 -0
  61. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/callback.py +0 -0
  62. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/debug.py +0 -0
  63. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/distributed.py +0 -0
  64. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/logger.py +0 -0
  65. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/profiler.py +0 -0
  66. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  67. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  68. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/nn/__init__.py +0 -0
  69. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/nn/mlp.py +0 -0
  70. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/nn/module_dict.py +0 -0
  71. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/nn/module_list.py +0 -0
  72. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/nn/nonlinearity.py +0 -0
  73. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/optimizer.py +0 -0
  74. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/runner.py +0 -0
  75. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/scripts/find_packages.py +0 -0
  76. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/trainer/__init__.py +0 -0
  77. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  78. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  79. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/trainer/signal_connector.py +0 -0
  80. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/trainer/trainer.py +0 -0
  81. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/util/environment.py +0 -0
  82. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/util/seed.py +0 -0
  83. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/util/slurm.py +0 -0
  84. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/util/typed.py +0 -0
  85. {nshtrainer-0.10.10 → nshtrainer-0.10.11}/src/nshtrainer/util/typing_utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.10.10
+Version: 0.10.11
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.10.10"
+version = "0.10.11"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
src/nshtrainer/_checkpoint/metadata.py
@@ -1,6 +1,8 @@
 import copy
 import datetime
 import logging
+import shutil
+from collections.abc import Callable
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
 
@@ -100,3 +102,74 @@ def _write_checkpoint_metadata(
         log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
     else:
         log.info(f"Checkpoint metadata written to {checkpoint_path}")
+
+
+def _remove_checkpoint_metadata(checkpoint_path: Path):
+    for path in (
+        checkpoint_path.with_suffix(METADATA_PATH_SUFFIX),
+        checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX),
+    ):
+        try:
+            path.unlink(missing_ok=True)
+        except Exception as e:
+            log.warning(f"Failed to remove {path}: {e}")
+        else:
+            log.info(f"Removed {path}")
+
+
+def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
+    # First, remove any existing metadata files
+    _remove_checkpoint_metadata(linked_checkpoint_path)
+
+    # Link the metadata files to the new checkpoint
+    for path in (
+        checkpoint_path.with_suffix(METADATA_PATH_SUFFIX),
+        checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX),
+    ):
+        linked_path = linked_checkpoint_path.with_suffix(path.suffix)
+        try:
+            try:
+                linked_path.symlink_to(path)
+            except OSError:
+                # on Windows, special permissions are required to create symbolic links as a regular user
+                # fall back to copying the file
+                shutil.copy(path, linked_path)
+        except Exception as e:
+            log.warning(f"Failed to link {path} to {linked_path}: {e}")
+        else:
+            log.info(f"Linked {path} to {linked_path}")
+
+
+def _checkpoint_sort_key_fn(key: Callable[[CheckpointMetadata, Path], Any]):
+    def sort_key_fn(checkpoint_path: Path):
+        if not (p := checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)).exists():
+            raise FileNotFoundError(f"Metadata file not found: {p}")
+
+        nonlocal key
+        return key(CheckpointMetadata.from_file(p), p)
+
+    return sort_key_fn
+
+
+def _sort_ckpts_by_metadata(
+    checkpoint_paths: list[Path],
+    key: Callable[[CheckpointMetadata, Path], Any],
+    fallback_key: Callable[[Path], Any],
+):
+    # First, let's make sure all the metadata files exist.
+    # If not, use the fallback function to sort the checkpoints.
+    no_metadata_paths: list[Path] = []
+    for path in checkpoint_paths:
+        if (path.with_suffix(METADATA_PATH_SUFFIX)).exists():
+            continue
+
+        no_metadata_paths.append(path)
+
+    if no_metadata_paths:
+        log.warning(
+            f"Metadata file not found on {len(no_metadata_paths)} checkpoints: {no_metadata_paths}\n"
+            "Falling back to sorting by last modified time."
+        )
+        return sorted(checkpoint_paths, key=fallback_key)
+
+    return sorted(checkpoint_paths, key=_checkpoint_sort_key_fn(key))
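Note: the following is a minimal, hypothetical sketch (not part of the diff) of how the new internal _sort_ckpts_by_metadata helper behaves. It assumes a "checkpoints" directory of *.ckpt files whose metadata sidecars were written by _write_checkpoint_metadata; the directory name and the key choices are illustration only.

from pathlib import Path

from nshtrainer._checkpoint.metadata import _sort_ckpts_by_metadata

ckpt_dir = Path("checkpoints")  # assumed location, for illustration
ckpts = list(ckpt_dir.glob("*.ckpt"))
ordered = _sort_ckpts_by_metadata(
    ckpts,
    key=lambda meta, path: (meta.epoch, meta.global_step),
    fallback_key=lambda path: path.stat().st_mtime,  # used when any sidecar is missing
)
# `ordered` runs oldest-to-newest; ordered[-1] is the most recent checkpoint.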
src/nshtrainer/_checkpoint/saver.py (new file)
@@ -0,0 +1,52 @@
+import os
+import shutil
+from pathlib import Path
+
+from lightning.pytorch import Trainer
+
+from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
+
+
+def _link_checkpoint(
+    trainer: Trainer,
+    filepath: str | Path | os.PathLike,
+    linkpath: str | Path | os.PathLike,
+    *,
+    barrier: bool,
+    metadata: bool,
+):
+    if not isinstance(filepath, Path):
+        filepath = Path(filepath)
+    if not isinstance(linkpath, Path):
+        linkpath = Path(linkpath)
+
+    if trainer.is_global_zero:
+        if linkpath.exists():
+            if linkpath.is_symlink() or linkpath.is_file():
+                linkpath.unlink()
+            elif linkpath.is_dir():
+                shutil.rmtree(linkpath)
+            _remove_checkpoint_metadata(linkpath)
+
+        try:
+            target_path = filepath.relative_to(linkpath.parent)
+            linkpath.symlink_to(target_path)
+        except OSError:
+            # on Windows, special permissions are required to create symbolic links as a regular user
+            # fall back to copying the file
+            shutil.copy(filepath, linkpath)
+
+        _link_checkpoint_metadata(filepath, linkpath)
+    if barrier:
+        trainer.strategy.barrier()
+
+
+def _remove_checkpoint(
+    trainer: Trainer,
+    filepath: str | Path | os.PathLike,
+    remove_metadata: bool = True,
+):
+    if not isinstance(filepath, Path):
+        filepath = Path(filepath)
+    trainer.strategy.remove_checkpoint(filepath)
+    _remove_checkpoint_metadata(filepath)
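Note: a hypothetical sketch (not part of the diff) of how a callback might call the new saver helpers. The wrapper names refresh_latest_link and drop_checkpoint are made up for illustration; the helpers themselves are internal, underscore-prefixed functions from this new module.

from pathlib import Path

from lightning.pytorch import Trainer

from nshtrainer._checkpoint.saver import _link_checkpoint, _remove_checkpoint


def refresh_latest_link(trainer: Trainer, ckpt: Path, link: Path) -> None:
    # Point a stable symlink (a file copy on Windows) at the newest checkpoint
    # and mirror its metadata/hparams sidecars next to the link.
    _link_checkpoint(trainer, ckpt, link, barrier=True, metadata=True)


def drop_checkpoint(trainer: Trainer, ckpt: Path) -> None:
    # Delete a stale checkpoint together with its sidecar files.
    _remove_checkpoint(trainer, ckpt, remove_metadata=True)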
src/nshtrainer/callbacks/latest_epoch_checkpoint.py (new file)
@@ -0,0 +1,114 @@
+import logging
+from pathlib import Path
+from typing import Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import override
+
+from .._checkpoint.metadata import _sort_ckpts_by_metadata
+from .._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from .base import CallbackConfigBase
+
+log = logging.getLogger(__name__)
+
+
+class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
+    name: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
+
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str = "epoch{epoch:02d}_step{step:04d}"
+    """Checkpoint filename. This must not include the extension."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    latest_symlink_filename: str | None = "latest"
+    """Filename for the latest symlink. If None, no symlink will be created."""
+
+    latest_k: int | Literal["all"] = 1
+    """Number of latest checkpoints to keep. If "all", all checkpoints are kept."""
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+        dirpath = Path(dirpath)
+
+        yield LatestEpochCheckpoint(self, dirpath)
+
+
+class LatestEpochCheckpoint(Checkpoint):
+    PREFIX = "latest_"
+    EXTENSION = ".ckpt"
+
+    def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
+        super().__init__()
+
+        self.config = config
+        self.dirpath = dirpath
+
+    @override
+    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+        self._save_new_checkpoint(trainer)
+
+    def _latest_symlink_filename(self):
+        if (filename := self.config.latest_symlink_filename) is None:
+            return None
+        return f"{filename}{self.EXTENSION}"
+
+    def _ckpt_path(self, trainer: Trainer):
+        filename = self.config.filename.format(
+            epoch=trainer.current_epoch, step=trainer.global_step
+        )
+        filename = f"{self.PREFIX}{filename}.{self.EXTENSION}"
+        return self.dirpath / filename
+
+    def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
+        for ckpt_path in ckpt_paths:
+            _remove_checkpoint(trainer, ckpt_path, remove_metadata=True)
+
+    def _remove_old_checkpoints(self, trainer: Trainer):
+        if (latest_k := self.config.latest_k) == "all":
+            return
+
+        # Get all checkpoints, ignoring the latest symlink
+        ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
+        # Ignore the latest symlink
+        if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
+            ckpt_paths = [p for p in ckpt_paths if p.name != latest_symlink_filename]
+
+        # Sort by epoch, then step, then last modified
+        ckpt_paths = _sort_ckpts_by_metadata(
+            ckpt_paths,
+            key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
+            fallback_key=lambda p: p.stat().st_mtime,
+            # ^ Called if metadata is not found on all checkpoints
+        )
+
+        # Remove all but the latest k checkpoints
+        ckpts_to_remove = ckpt_paths[:-latest_k]
+        self._remove_checkpoints(trainer, ckpts_to_remove)
+
+    def _save_new_checkpoint(self, trainer: Trainer):
+        # Remove old checkpoints
+        self._remove_old_checkpoints(trainer)
+
+        # Save the new checkpoint
+        filepath = self._ckpt_path(trainer)
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)
+
+        # Create the latest symlink
+        if (symlink_filename := self._latest_symlink_filename()) is not None:
+            symlink_path = self.dirpath / symlink_filename
+            _link_checkpoint(
+                trainer,
+                filepath,
+                symlink_path,
+                barrier=True,
+                metadata=True,
+            )
+            log.info(f"Created latest symlink: {symlink_path}")
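Note: a hypothetical sketch (not part of the diff) of configuring the reworked callback to keep the three newest epoch checkpoints plus a "latest" symlink. The field names follow the config class above; the values and the commented create_callbacks call (root_config comes from the usual nshtrainer run setup) are illustration only.

from nshtrainer.callbacks.latest_epoch_checkpoint import (
    LatestEpochCheckpointCallbackConfig,
)

config = LatestEpochCheckpointCallbackConfig(
    latest_k=3,                        # prune everything but the 3 newest checkpoints
    latest_symlink_filename="latest",  # also maintain a "latest.ckpt" symlink
)
# callbacks = list(config.create_callbacks(root_config))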
src/nshtrainer/callbacks/model_checkpoint.py
@@ -4,11 +4,13 @@ from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal
 
+from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks.model_checkpoint import (
     ModelCheckpoint as _ModelCheckpoint,
 )
 from typing_extensions import override
 
+from .._checkpoint.saver import _link_checkpoint, _remove_checkpoint
 from ..metrics import MetricConfig
 from .base import CallbackConfigBase
 
@@ -158,6 +160,8 @@ class ModelCheckpointCallbackConfig(CallbackConfigBase):
 
 
 class ModelCheckpoint(_ModelCheckpoint):
+    CHECKPOINT_NAME_LAST = "best"
+
     @override
     def __init__(
         self,
@@ -185,3 +189,17 @@ class ModelCheckpoint(_ModelCheckpoint):
             save_on_train_epoch_end=self.config.save_on_train_epoch_end,
             enable_version_counter=self.config.enable_version_counter,
         )
+
+    @override
+    def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str):  # pyright: ignore[reportIncompatibleMethodOverride]
+        return _link_checkpoint(
+            trainer,
+            filepath,
+            linkpath,
+            barrier=True,
+            metadata=True,
+        )
+
+    @override
+    def _remove_checkpoint(self, trainer: Trainer, filepath: str):
+        return _remove_checkpoint(trainer, filepath, remove_metadata=True)
nshtrainer-0.10.10/src/nshtrainer/callbacks/latest_epoch_checkpoint.py (removed)
@@ -1,74 +0,0 @@
-import logging
-from pathlib import Path
-from typing import Literal
-
-from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.callbacks import Checkpoint
-from typing_extensions import override
-
-from .base import CallbackConfigBase
-
-log = logging.getLogger(__name__)
-
-
-class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
-    name: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
-    """Checkpoint filename. This must not include the extension."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    latest_symlink_filename: str | None = "latest.ckpt"
-    """Filename for the latest symlink. If None, no symlink will be created."""
-
-    @override
-    def create_callbacks(self, root_config):
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-        dirpath = Path(dirpath)
-
-        yield LatestEpochCheckpoint(self, dirpath)
-
-
-class LatestEpochCheckpoint(Checkpoint):
-    def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
-        super().__init__()
-
-        self.config = config
-        self.dirpath = dirpath
-
-    def _ckpt_path(self, trainer: Trainer):
-        return self.dirpath / self.config.filename.format(
-            epoch=trainer.current_epoch, step=trainer.global_step
-        )
-
-    @override
-    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
-        # Save the new checkpoint
-        filepath = self._ckpt_path(trainer)
-        trainer.save_checkpoint(filepath, self.config.save_weights_only)
-
-        # Create the latest symlink
-        if (
-            trainer.is_global_zero
-            and (symlink_filename := self.config.latest_symlink_filename) is not None
-        ):
-            symlink_path = self.dirpath / symlink_filename
-            symlink_path.unlink(missing_ok=True)
-            symlink_path.symlink_to(filepath.name)
-            log.info(f"Created latest symlink: {symlink_path}")
-
-    def latest_checkpoint(self):
-        if (symlink_filename := self.config.latest_symlink_filename) is None:
-            return None
-
-        if not (symlink_path := self.dirpath / symlink_filename).exists():
-            return None
-
-        return symlink_path