nshtrainer 0.11.7__tar.gz → 0.11.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/PKG-INFO +1 -1
  2. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/pyproject.toml +1 -1
  3. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_checkpoint/loader.py +4 -4
  4. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_checkpoint/metadata.py +37 -35
  5. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/__init__.py +3 -8
  6. nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/__init__.py +12 -0
  7. nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/_base.py +175 -0
  8. nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +70 -0
  9. nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +39 -0
  10. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/__init__.py +2 -4
  11. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/config.py +4 -37
  12. nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -16
  13. nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -192
  14. nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +0 -131
  15. nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/model_checkpoint.py +0 -207
  16. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/README.md +0 -0
  17. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/__init__.py +0 -0
  18. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_checkpoint/saver.py +0 -0
  19. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_experimental/__init__.py +0 -0
  20. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  21. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  22. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  23. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  24. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/actsave.py +0 -0
  25. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/base.py +0 -0
  26. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  27. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  28. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/ema.py +0 -0
  29. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  30. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  31. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/interval.py +0 -0
  32. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  33. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  34. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/print_table.py +0 -0
  35. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  36. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/timer.py +0 -0
  37. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  38. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/data/__init__.py +0 -0
  39. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  40. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/data/transform.py +0 -0
  41. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/__init__.py +0 -0
  42. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/_experimental.py +0 -0
  43. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/actsave.py +0 -0
  44. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/callbacks.py +0 -0
  45. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/config.py +0 -0
  46. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/data.py +0 -0
  47. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/log.py +0 -0
  48. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  49. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/model.py +0 -0
  50. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/nn.py +0 -0
  51. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/optimizer.py +0 -0
  52. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/runner.py +0 -0
  53. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/snapshot.py +0 -0
  54. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/snoop.py +0 -0
  55. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/trainer.py +0 -0
  56. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/typecheck.py +0 -0
  57. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/util.py +0 -0
  58. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  59. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  60. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  61. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  62. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/metrics/__init__.py +0 -0
  63. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/metrics/_config.py +0 -0
  64. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/base.py +0 -0
  65. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/callback.py +0 -0
  66. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/debug.py +0 -0
  67. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/distributed.py +0 -0
  68. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/logger.py +0 -0
  69. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/profiler.py +0 -0
  70. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  71. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  72. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/nn/__init__.py +0 -0
  73. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/nn/mlp.py +0 -0
  74. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/nn/module_dict.py +0 -0
  75. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/nn/module_list.py +0 -0
  76. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/nn/nonlinearity.py +0 -0
  77. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/optimizer.py +0 -0
  78. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/runner.py +0 -0
  79. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/scripts/find_packages.py +0 -0
  80. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/trainer/__init__.py +0 -0
  81. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  82. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  83. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/trainer/signal_connector.py +0 -0
  84. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/trainer/trainer.py +0 -0
  85. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/_environment_info.py +0 -0
  86. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/_useful_types.py +0 -0
  87. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/environment.py +0 -0
  88. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/seed.py +0 -0
  89. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/slurm.py +0 -0
  90. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/typed.py +0 -0
  91. {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/typing_utils.py +0 -0

{nshtrainer-0.11.7 → nshtrainer-0.11.9}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.11.7
+ Version: 0.11.9
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com

{nshtrainer-0.11.7 → nshtrainer-0.11.9}/pyproject.toml

@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "nshtrainer"
- version = "0.11.7"
+ version = "0.11.9"
  description = ""
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
  readme = "README.md"

{nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_checkpoint/loader.py

@@ -236,7 +236,7 @@ def _load_ckpt_meta(
          error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
          match on_error:
              case "warn":
-                 log.warn(error_msg)
+                 log.warning(error_msg)
              case "raise":
                  raise ValueError(error_msg)
              case _:
@@ -325,13 +325,13 @@ def _resolve_checkpoint(
              ),
          ]
          if not candidates:
-             log.warn(
+             log.warning(
                  "No checkpoint candidates found for `best` checkpoint strategy."
              )
              continue

          if (metric := strategy.metric or root_config.primary_metric) is None:
-             log.warn(
+             log.warning(
                  "No metric specified for `best` checkpoint strategy, "
                  "and no primary metric is set in the configuration. "
                  "Skipping strategy."
@@ -360,7 +360,7 @@ def _resolve_checkpoint(
              ),
          ]
          if not candidates:
-             log.warn(
+             log.warning(
                  "No checkpoint candidates found for `last` checkpoint strategy."
              )
              continue

{nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_checkpoint/metadata.py

@@ -4,7 +4,7 @@ import logging
  import shutil
  from collections.abc import Callable
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, cast
+ from typing import TYPE_CHECKING, Any, ClassVar, cast

  import nshconfig as C
  import numpy as np
@@ -20,10 +20,11 @@ log = logging.getLogger(__name__)


  METADATA_PATH_SUFFIX = ".metadata.json"
- HPARAMS_PATH_SUFFIX = ".hparams.json"


  class CheckpointMetadata(C.Config):
+     PATH_SUFFIX: ClassVar[str] = METADATA_PATH_SUFFIX
+
      checkpoint_path: Path
      checkpoint_filename: str

@@ -39,6 +40,8 @@ class CheckpointMetadata(C.Config):
      metrics: dict[str, Any]
      environment: EnvironmentConfig

+     hparams: dict[str, Any] | None
+
      @classmethod
      def from_file(cls, path: Path):
          return cls.model_validate_json(path.read_text())
@@ -55,7 +58,10 @@ class CheckpointMetadata(C.Config):


  def _generate_checkpoint_metadata(
-     config: "BaseConfig", trainer: "Trainer", checkpoint_path: Path
+     config: "BaseConfig",
+     trainer: "Trainer",
+     checkpoint_path: Path,
+     metadata_path: Path,
  ):
      checkpoint_timestamp = datetime.datetime.now()
      start_timestamp = trainer.start_time()
@@ -70,7 +76,11 @@ def _generate_checkpoint_metadata(
          metrics[name] = metric

      return CheckpointMetadata(
-         checkpoint_path=checkpoint_path,
+         # checkpoint_path=checkpoint_path,
+         # We should store the path as a relative path
+         # to the metadata file to avoid issues with
+         # moving the checkpoint directory
+         checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
          checkpoint_filename=checkpoint_path.name,
          run_id=config.id,
          name=config.run_name,
@@ -84,6 +94,7 @@ def _generate_checkpoint_metadata(
          training_time=training_time,
          metrics=metrics,
          environment=config.environment,
+         hparams=config.model_dump(mode="json"),
      )


@@ -93,36 +104,28 @@ def _write_checkpoint_metadata(
      checkpoint_path: Path,
  ):
      config = cast("BaseConfig", model.config)
-     metadata = _generate_checkpoint_metadata(config, trainer, checkpoint_path)
+     metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+     metadata = _generate_checkpoint_metadata(
+         config, trainer, checkpoint_path, metadata_path
+     )

      # Write the metadata to the checkpoint directory
      try:
-         metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
          metadata_path.write_text(metadata.model_dump_json(indent=4))
      except Exception as e:
          log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
      else:
          log.debug(f"Checkpoint metadata written to {checkpoint_path}")

-     # Write the hparams to the checkpoint directory
+
+ def _remove_checkpoint_metadata(checkpoint_path: Path):
+     path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
      try:
-         hparams_path = checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX)
-         hparams_path.write_text(config.model_dump_json(indent=4))
+         path.unlink(missing_ok=True)
      except Exception as e:
-         log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
+         log.warning(f"Failed to remove {path}: {e}")
      else:
-         log.debug(f"Checkpoint metadata written to {checkpoint_path}")
-
-
- def _remove_checkpoint_metadata(checkpoint_path: Path):
-     for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
-         path = checkpoint_path.with_suffix(suffix)
-         try:
-             path.unlink(missing_ok=True)
-         except Exception as e:
-             log.warning(f"Failed to remove {path}: {e}")
-         else:
-             log.debug(f"Removed {path}")
+         log.debug(f"Removed {path}")


@@ -130,20 +133,19 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
      _remove_checkpoint_metadata(linked_checkpoint_path)

      # Link the metadata files to the new checkpoint
-     for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
-         path = checkpoint_path.with_suffix(suffix)
-         linked_path = linked_checkpoint_path.with_suffix(suffix)
+     path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+     linked_path = linked_checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+     try:
          try:
-             try:
-                 linked_path.symlink_to(path)
-             except OSError:
-                 # on Windows, special permissions are required to create symbolic links as a regular user
-                 # fall back to copying the file
-                 shutil.copy(path, linked_path)
-         except Exception as e:
-             log.warning(f"Failed to link {path} to {linked_path}: {e}")
-         else:
-             log.debug(f"Linked {path} to {linked_path}")
+             linked_path.symlink_to(path)
+         except OSError:
+             # on Windows, special permissions are required to create symbolic links as a regular user
+             # fall back to copying the file
+             shutil.copy(path, linked_path)
+     except Exception as e:
+         log.warning(f"Failed to link {path} to {linked_path}: {e}")
+     else:
+         log.debug(f"Linked {path} to {linked_path}")


  def _sort_ckpts_by_metadata(
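
Note on the metadata change above: the separate .hparams.json file is gone in 0.11.9 (hyperparameters are embedded in the single .metadata.json file via the new hparams field), and checkpoint_path is now stored relative to the metadata file's directory. A minimal sketch of reading such a metadata file back, assuming a checkpoint directory written by 0.11.9; the path below is illustrative only:

from pathlib import Path

from nshtrainer._checkpoint.metadata import CheckpointMetadata

# Illustrative path; any file ending in CheckpointMetadata.PATH_SUFFIX works.
meta_path = Path("checkpoints/last/epoch005-step0001000.metadata.json")
meta = CheckpointMetadata.from_file(meta_path)

# checkpoint_path is relative to the metadata file's directory in 0.11.9,
# so resolve it against that directory to get a usable path.
ckpt_path = meta_path.parent / meta.checkpoint_path

# Hyperparameters now live on the metadata object itself.
hparams = meta.hparams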

{nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/__init__.py

@@ -6,12 +6,8 @@ from . import checkpoint as checkpoint
  from .base import CallbackConfigBase as CallbackConfigBase
  from .checkpoint import BestCheckpoint as BestCheckpoint
  from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
- from .checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
- from .checkpoint import (
-     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
- )
- from .checkpoint import ModelCheckpoint as ModelCheckpoint
- from .checkpoint import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
+ from .checkpoint import LastCheckpoint as LastCheckpoint
+ from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
  from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
  from .checkpoint import (
      OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
@@ -46,8 +42,7 @@ CallbackConfig = Annotated[
      | GradientSkippingConfig
      | EMAConfig
      | BestCheckpointCallbackConfig
-     | ModelCheckpointCallbackConfig
-     | LatestEpochCheckpointCallbackConfig
+     | LastCheckpointCallbackConfig
      | OnExceptionCheckpointCallbackConfig
      | WandbWatchConfig,
      C.Field(discriminator="name"),

nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/__init__.py (new file)

@@ -0,0 +1,12 @@
+ from .best_checkpoint import BestCheckpoint as BestCheckpoint
+ from .best_checkpoint import (
+     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
+ )
+ from .last_checkpoint import LastCheckpoint as LastCheckpoint
+ from .last_checkpoint import (
+     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
+ )
+ from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+ from .on_exception_checkpoint import (
+     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+ )

nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/_base.py (new file)

@@ -0,0 +1,175 @@
+ import logging
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Generic, Literal
+
+ import numpy as np
+ import torch
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.callbacks import Checkpoint
+ from typing_extensions import TypeVar, override
+
+ from ..._checkpoint.metadata import CheckpointMetadata
+ from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+ from ..base import CallbackConfigBase
+
+ if TYPE_CHECKING:
+     from ...model.config import BaseConfig
+
+ log = logging.getLogger(__name__)
+
+
+ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
+     dirpath: str | Path | None = None
+     """Directory path to save the checkpoint file."""
+
+     filename: str | None = None
+     """Checkpoint filename. This must not include the extension.
+     If None, the default filename will be used."""
+
+     save_weights_only: bool = False
+     """Whether to save only the model's weights or the entire model object."""
+
+     save_symlink: bool = True
+     """Whether to create a symlink to the saved checkpoint."""
+
+     topk: int | Literal["all"] = 1
+     """The number of checkpoints to keep."""
+
+     @abstractmethod
+     def create_checkpoint(
+         self,
+         root_config: "BaseConfig",
+         dirpath: Path,
+     ) -> "CheckpointBase": ...
+
+     @override
+     def create_callbacks(self, root_config):
+         dirpath = Path(
+             self.dirpath
+             or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+         )
+
+         yield self.create_checkpoint(root_config, dirpath)
+
+
+ TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
+
+
+ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
+     def __init__(self, config: TConfig, dirpath: Path):
+         super().__init__()
+
+         self.config = config
+         self.dirpath = dirpath / self.name()
+         self.symlink_dirpath = dirpath
+
+         self._last_global_step_saved = 0
+
+     @abstractmethod
+     def default_filename(self) -> str: ...
+
+     @abstractmethod
+     def name(self) -> str: ...
+
+     def extension(self) -> str:
+         return ".ckpt"
+
+     @abstractmethod
+     def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
+
+     def symlink_path(self):
+         if not self.config.save_symlink:
+             return None
+
+         return self.symlink_dirpath / f"{self.name()}{self.extension()}"
+
+     def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
+         if (filename := self.config.filename) is None:
+             filename = self.default_filename()
+         filename = filename.format(**current_metrics)
+         return self.dirpath / f"{filename}{self.extension()}"
+
+     def remove_old_checkpoints(self, trainer: Trainer):
+         if (topk := self.config.topk) == "all":
+             return
+
+         # Get all the checkpoint metadata
+         metas = [
+             CheckpointMetadata.from_file(p)
+             for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+         ]
+
+         # Sort by the topk sort key
+         metas = sorted(metas, key=self.topk_sort_key)
+
+         # Now, the metas are sorted from the best to the worst,
+         # so we can remove the worst checkpoints
+         for meta in metas[topk:]:
+             if not (old_ckpt_path := self.dirpath / meta.checkpoint_filename).exists():
+                 log.warning(
+                     f"Checkpoint file not found: {old_ckpt_path}\n"
+                     "Skipping removal of the checkpoint metadata."
+                 )
+                 continue
+
+             _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
+             log.debug(f"Removed old checkpoint: {old_ckpt_path}")
+
+     def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
+         current_metrics: dict[str, Any] = {
+             "epoch": trainer.current_epoch,
+             "step": trainer.global_step,
+         }
+
+         for name, value in trainer.callback_metrics.items():
+             match value:
+                 case torch.Tensor() if value.numel() == 1:
+                     value = value.detach().cpu().item()
+                 case np.ndarray() if value.size == 1:
+                     value = value.item()
+                 case _:
+                     pass
+
+             current_metrics[name] = value
+
+         return current_metrics
+
+     def save_checkpoints(self, trainer: Trainer):
+         if self._should_skip_saving_checkpoint(trainer):
+             return
+
+         # Save the new checkpoint
+         filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
+         trainer.save_checkpoint(filepath, self.config.save_weights_only)
+
+         if trainer.is_global_zero:
+             # Remove old checkpoints
+             self.remove_old_checkpoints(trainer)
+
+             # Create the latest symlink
+             if (symlink_filename := self.symlink_path()) is not None:
+                 symlink_path = self.dirpath / symlink_filename
+                 _link_checkpoint(filepath, symlink_path, metadata=True)
+                 log.debug(f"Created latest symlink: {symlink_path}")
+
+         # Barrier to ensure all processes have saved the checkpoint,
+         # deleted the old checkpoints, and created the symlink before continuing
+         trainer.strategy.barrier()
+
+         # Set the last global step saved
+         self._last_global_step_saved = trainer.global_step
+
+     def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+         from lightning.pytorch.trainer.states import TrainerFn
+
+         return (
+             bool(
+                 getattr(trainer, "fast_dev_run", False)
+             )  # disable checkpointing with fast_dev_run
+             or trainer.state.fn
+             != TrainerFn.FITTING  # don't save anything during non-fit
+             or trainer.sanity_checking  # don't save anything during sanity check
+             or self._last_global_step_saved
+             == trainer.global_step  # already saved at the last step
+         )
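
The new CheckpointBase consolidates the checkpoint-saving machinery that was previously spread across the removed callbacks: it resolves a filename from the current metrics, saves via trainer.save_checkpoint, prunes old checkpoints by reading the sidecar .metadata.json files and sorting them with topk_sort_key, and optionally maintains a symlink. A hedged sketch of a custom subclass, purely illustrative and not part of the package, assuming only the extension points shown above:

from typing import Literal

from lightning.pytorch import LightningModule, Trainer
from typing_extensions import final, override

from nshtrainer._checkpoint.metadata import CheckpointMetadata
from nshtrainer.callbacks.checkpoint._base import (
    BaseCheckpointCallbackConfig,
    CheckpointBase,
)


@final
class LowestTrainLossCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
    # Hypothetical discriminator value; not a name the package defines.
    name: Literal["lowest_train_loss_checkpoint"] = "lowest_train_loss_checkpoint"

    @override
    def create_checkpoint(self, root_config, dirpath):
        return LowestTrainLossCheckpoint(self, dirpath)


@final
class LowestTrainLossCheckpoint(
    CheckpointBase[LowestTrainLossCheckpointCallbackConfig]
):
    @override
    def name(self):
        return "lowest_train_loss"

    @override
    def default_filename(self):
        return "epoch{epoch:03d}-step{step:07d}"

    @override
    def topk_sort_key(self, metadata: CheckpointMetadata):
        # Ascending sort keeps the lowest recorded values first;
        # "train/loss" is a hypothetical logged metric name.
        return metadata.metrics.get("train/loss", float("inf"))

    @override
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        self.save_checkpoints(trainer)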

nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py (new file)

@@ -0,0 +1,70 @@
+ import logging
+ from pathlib import Path
+ from typing import Literal
+
+ from lightning.pytorch import LightningModule, Trainer
+ from typing_extensions import final, override
+
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata
+
+ from ...metrics._config import MetricConfig
+ from ._base import BaseCheckpointCallbackConfig, CheckpointBase
+
+ log = logging.getLogger(__name__)
+
+
+ @final
+ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+     name: Literal["best_checkpoint"] = "best_checkpoint"
+
+     metric: MetricConfig | None = None
+     """Metric to monitor, or `None` to use the default metric."""
+
+     @override
+     def create_checkpoint(self, root_config, dirpath):
+         # Resolve metric
+         if (metric := self.metric) is None and (
+             metric := root_config.primary_metric
+         ) is None:
+             raise ValueError(
+                 "No metric provided and no primary metric found in the root config"
+             )
+
+         return BestCheckpoint(self, dirpath, metric)
+
+
+ @final
+ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
+     @property
+     def _metric_name_normalized(self):
+         return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")
+
+     @override
+     def __init__(
+         self,
+         config: BestCheckpointCallbackConfig,
+         dirpath: Path,
+         metric: MetricConfig,
+     ):
+         super().__init__(config, dirpath)
+         self.metric = metric
+
+     @override
+     def name(self):
+         return f"best_{self._metric_name_normalized}"
+
+     @override
+     def default_filename(self):
+         return f"epoch{{epoch:03d}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
+
+     @override
+     def topk_sort_key(self, metadata: CheckpointMetadata):
+         return metadata.metrics.get(
+             self.metric.validation_monitor,
+             float("-inf" if self.metric.mode == "max" else "inf"),
+         )
+
+     # Events
+     @override
+     def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
+         self.save_checkpoints(trainer)
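
Checkpoints written by BestCheckpoint land under a per-metric subdirectory (best_<metric>), with top-k pruning driven by the metric value recorded in each checkpoint's metadata. A hedged configuration sketch, assuming MetricConfig accepts name and mode keyword arguments (only the attribute accesses above are confirmed by this diff); the metric name is illustrative:

from nshtrainer.callbacks import BestCheckpointCallbackConfig
from nshtrainer.metrics._config import MetricConfig

# Keep the three best checkpoints by a hypothetical "val/loss" metric.
best_ckpt = BestCheckpointCallbackConfig(
    metric=MetricConfig(name="val/loss", mode="min"),
    topk=3,
)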

nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py (new file)

@@ -0,0 +1,39 @@
+ import logging
+ from typing import Literal
+
+ from lightning.pytorch import LightningModule, Trainer
+ from typing_extensions import final, override
+
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata
+
+ from ._base import BaseCheckpointCallbackConfig, CheckpointBase
+
+ log = logging.getLogger(__name__)
+
+
+ @final
+ class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+     name: Literal["last_checkpoint"] = "last_checkpoint"
+
+     @override
+     def create_checkpoint(self, root_config, dirpath):
+         return LastCheckpoint(self, dirpath)
+
+
+ @final
+ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
+     @override
+     def name(self):
+         return "last"
+
+     @override
+     def default_filename(self):
+         return "epoch{epoch:03d}-step{step:07d}"
+
+     @override
+     def topk_sort_key(self, metadata: CheckpointMetadata):
+         return metadata.checkpoint_timestamp
+
+     @override
+     def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+         self.save_checkpoints(trainer)
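
LastCheckpoint replaces both LatestEpochCheckpoint and the 0.11.7 ModelCheckpoint callback: it saves at the end of every training epoch and, with the default topk=1, prunes the directory down to a single checkpoint, maintaining a last.ckpt symlink alongside when save_symlink is enabled. A hedged migration sketch for user configs that referenced the removed classes; the topk value is illustrative:

from nshtrainer.callbacks import LastCheckpointCallbackConfig

# 0.11.7 (removed in 0.11.9):
#   LatestEpochCheckpointCallbackConfig(...)
#   ModelCheckpointCallbackConfig(...)
# 0.11.9 equivalent, keeping the two most recent epoch checkpoints:
last_ckpt = LastCheckpointCallbackConfig(topk=2)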

{nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/__init__.py

@@ -5,17 +5,15 @@ from .base import LightningModuleBase as LightningModuleBase
  from .config import BaseConfig as BaseConfig
  from .config import BaseLoggerConfig as BaseLoggerConfig
  from .config import BaseProfilerConfig as BaseProfilerConfig
+ from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
  from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
  from .config import CheckpointSavingConfig as CheckpointSavingConfig
  from .config import DirectoryConfig as DirectoryConfig
  from .config import EarlyStoppingConfig as EarlyStoppingConfig
  from .config import GradientClippingConfig as GradientClippingConfig
- from .config import (
-     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
- )
+ from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
  from .config import LoggingConfig as LoggingConfig
  from .config import MetricConfig as MetricConfig
- from .config import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
  from .config import (
      OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
  )

{nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/config.py

@@ -39,8 +39,7 @@ from .._checkpoint.loader import CheckpointLoadingConfig
  from ..callbacks import (
      BestCheckpointCallbackConfig,
      CallbackConfig,
-     LatestEpochCheckpointCallbackConfig,
-     ModelCheckpointCallbackConfig,
+     LastCheckpointCallbackConfig,
      OnExceptionCheckpointCallbackConfig,
      WandbWatchConfig,
  )
@@ -771,9 +770,8 @@ class ReproducibilityConfig(C.Config):


  CheckpointCallbackConfig: TypeAlias = Annotated[
-     ModelCheckpointCallbackConfig
-     | BestCheckpointCallbackConfig
-     | LatestEpochCheckpointCallbackConfig
+     BestCheckpointCallbackConfig
+     | LastCheckpointCallbackConfig
      | OnExceptionCheckpointCallbackConfig,
      C.Field(discriminator="name"),
  ]
@@ -784,9 +782,8 @@ class CheckpointSavingConfig(CallbackConfigBase):
      """Enable checkpoint saving."""

      checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
-         # ModelCheckpointCallbackConfig(),
          BestCheckpointCallbackConfig(),
-         LatestEpochCheckpointCallbackConfig(),
+         LastCheckpointCallbackConfig(),
          OnExceptionCheckpointCallbackConfig(),
      ]
      """Checkpoint callback configurations."""
@@ -804,36 +801,6 @@ class CheckpointSavingConfig(CallbackConfigBase):

          return True

-     @property
-     def model_checkpoint(self) -> ModelCheckpointCallbackConfig | None:
-         return next(
-             (
-                 callback
-                 for callback in self.checkpoint_callbacks
-                 if isinstance(callback, ModelCheckpointCallbackConfig)
-             ),
-         )
-
-     @property
-     def latest_epoch_checkpoint(self) -> LatestEpochCheckpointCallbackConfig | None:
-         return next(
-             (
-                 callback
-                 for callback in self.checkpoint_callbacks
-                 if isinstance(callback, LatestEpochCheckpointCallbackConfig)
-             ),
-         )
-
-     @property
-     def on_exception_checkpoint(self) -> OnExceptionCheckpointCallbackConfig | None:
-         return next(
-             (
-                 callback
-                 for callback in self.checkpoint_callbacks
-                 if isinstance(callback, OnExceptionCheckpointCallbackConfig)
-             ),
-         )
-
      @override
      def create_callbacks(self, root_config: "BaseConfig"):
          if not self.should_save_checkpoints(root_config):
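
The convenience properties (model_checkpoint, latest_epoch_checkpoint, on_exception_checkpoint) are dropped along with the old callback configs, so code that needs a specific callback config now filters checkpoint_callbacks itself. A hedged sketch of overriding the 0.11.9 defaults, assuming keyword construction as the defaults above suggest; values are illustrative:

from nshtrainer.callbacks import (
    BestCheckpointCallbackConfig,
    LastCheckpointCallbackConfig,
    OnExceptionCheckpointCallbackConfig,
)
from nshtrainer.model import CheckpointSavingConfig

saving = CheckpointSavingConfig(
    checkpoint_callbacks=[
        BestCheckpointCallbackConfig(topk=3),
        LastCheckpointCallbackConfig(),
        OnExceptionCheckpointCallbackConfig(),
    ]
)

# Code that previously used `saving.latest_epoch_checkpoint` would now do:
last_cfg = next(
    (cb for cb in saving.checkpoint_callbacks
     if isinstance(cb, LastCheckpointCallbackConfig)),
    None,
)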

nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/__init__.py (removed file)

@@ -1,16 +0,0 @@
- from .best_checkpoint import BestCheckpoint as BestCheckpoint
- from .best_checkpoint import (
-     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
- )
- from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
- from .latest_epoch_checkpoint import (
-     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
- )
- from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
- from .model_checkpoint import (
-     ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
- )
- from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
- from .on_exception_checkpoint import (
-     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
- )