nshtrainer 0.11.7__tar.gz → 0.11.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/PKG-INFO +1 -1
  2. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/pyproject.toml +1 -1
  3. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/_checkpoint/loader.py +4 -4
  4. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/_checkpoint/metadata.py +37 -35
  5. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/__init__.py +3 -0
  6. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/checkpoint/__init__.py +4 -0
  7. nshtrainer-0.11.8/src/nshtrainer/callbacks/checkpoint/_base.py +175 -0
  8. nshtrainer-0.11.8/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +70 -0
  9. nshtrainer-0.11.8/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +39 -0
  10. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +1 -1
  11. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/model/config.py +3 -1
  12. nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -192
  13. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/README.md +0 -0
  14. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/__init__.py +0 -0
  15. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/_checkpoint/saver.py +0 -0
  16. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/_experimental/__init__.py +0 -0
  17. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  18. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  19. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  20. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  21. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/actsave.py +0 -0
  22. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/base.py +0 -0
  23. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/checkpoint/model_checkpoint.py +0 -0
  24. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  25. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  26. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/ema.py +0 -0
  27. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  28. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  29. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/interval.py +0 -0
  30. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  31. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  32. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/print_table.py +0 -0
  33. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  34. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/timer.py +0 -0
  35. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  36. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/data/__init__.py +0 -0
  37. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  38. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/data/transform.py +0 -0
  39. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/__init__.py +0 -0
  40. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/_experimental.py +0 -0
  41. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/actsave.py +0 -0
  42. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/callbacks.py +0 -0
  43. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/config.py +0 -0
  44. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/data.py +0 -0
  45. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/log.py +0 -0
  46. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  47. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/model.py +0 -0
  48. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/nn.py +0 -0
  49. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/optimizer.py +0 -0
  50. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/runner.py +0 -0
  51. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/snapshot.py +0 -0
  52. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/snoop.py +0 -0
  53. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/trainer.py +0 -0
  54. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/typecheck.py +0 -0
  55. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/ll/util.py +0 -0
  56. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  57. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  58. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  59. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  60. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/metrics/__init__.py +0 -0
  61. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/metrics/_config.py +0 -0
  62. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/model/__init__.py +0 -0
  63. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/model/base.py +0 -0
  64. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/model/modules/callback.py +0 -0
  65. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/model/modules/debug.py +0 -0
  66. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/model/modules/distributed.py +0 -0
  67. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/model/modules/logger.py +0 -0
  68. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/model/modules/profiler.py +0 -0
  69. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  70. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  71. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/nn/__init__.py +0 -0
  72. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/nn/mlp.py +0 -0
  73. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/nn/module_dict.py +0 -0
  74. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/nn/module_list.py +0 -0
  75. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/nn/nonlinearity.py +0 -0
  76. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/optimizer.py +0 -0
  77. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/runner.py +0 -0
  78. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/scripts/find_packages.py +0 -0
  79. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/trainer/__init__.py +0 -0
  80. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  81. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  82. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/trainer/signal_connector.py +0 -0
  83. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/trainer/trainer.py +0 -0
  84. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/util/_environment_info.py +0 -0
  85. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/util/_useful_types.py +0 -0
  86. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/util/environment.py +0 -0
  87. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/util/seed.py +0 -0
  88. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/util/slurm.py +0 -0
  89. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/util/typed.py +0 -0
  90. {nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.11.7 → nshtrainer-0.11.8}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.11.7
+Version: 0.11.8
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-0.11.7 → nshtrainer-0.11.8}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.11.7"
+version = "0.11.8"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
{nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/_checkpoint/loader.py
@@ -236,7 +236,7 @@ def _load_ckpt_meta(
         error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
         match on_error:
             case "warn":
-                log.warn(error_msg)
+                log.warning(error_msg)
             case "raise":
                 raise ValueError(error_msg)
             case _:
@@ -325,13 +325,13 @@ def _resolve_checkpoint(
             ),
         ]
         if not candidates:
-            log.warn(
+            log.warning(
                "No checkpoint candidates found for `best` checkpoint strategy."
            )
            continue

        if (metric := strategy.metric or root_config.primary_metric) is None:
-            log.warn(
+            log.warning(
                "No metric specified for `best` checkpoint strategy, "
                "and no primary metric is set in the configuration. "
                "Skipping strategy."
@@ -360,7 +360,7 @@ def _resolve_checkpoint(
             ),
         ]
         if not candidates:
-            log.warn(
+            log.warning(
                "No checkpoint candidates found for `last` checkpoint strategy."
            )
            continue
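
Every `log.warn(...)` call in this file is replaced with `log.warning(...)`. In the standard library, `logging.Logger.warn` is only a deprecated alias of `Logger.warning`, so the change is behavior-preserving. A minimal illustration, with a made-up logger name:

    import logging

    log = logging.getLogger("nshtrainer.example")  # hypothetical logger name
    log.warn("old spelling")     # deprecated alias; may emit a DeprecationWarning
    log.warning("new spelling")  # preferred spelling, identical output
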
{nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/_checkpoint/metadata.py
@@ -4,7 +4,7 @@ import logging
 import shutil
 from collections.abc import Callable
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, ClassVar, cast

 import nshconfig as C
 import numpy as np
@@ -20,10 +20,11 @@ log = logging.getLogger(__name__)


 METADATA_PATH_SUFFIX = ".metadata.json"
-HPARAMS_PATH_SUFFIX = ".hparams.json"


 class CheckpointMetadata(C.Config):
+    PATH_SUFFIX: ClassVar[str] = METADATA_PATH_SUFFIX
+
     checkpoint_path: Path
     checkpoint_filename: str

@@ -39,6 +40,8 @@ class CheckpointMetadata(C.Config):
     metrics: dict[str, Any]
     environment: EnvironmentConfig

+    hparams: dict[str, Any] | None
+
     @classmethod
     def from_file(cls, path: Path):
         return cls.model_validate_json(path.read_text())
@@ -55,7 +58,10 @@


 def _generate_checkpoint_metadata(
-    config: "BaseConfig", trainer: "Trainer", checkpoint_path: Path
+    config: "BaseConfig",
+    trainer: "Trainer",
+    checkpoint_path: Path,
+    metadata_path: Path,
 ):
     checkpoint_timestamp = datetime.datetime.now()
     start_timestamp = trainer.start_time()
@@ -70,7 +76,11 @@ def _generate_checkpoint_metadata(
         metrics[name] = metric

     return CheckpointMetadata(
-        checkpoint_path=checkpoint_path,
+        # checkpoint_path=checkpoint_path,
+        # We should store the path as a relative path
+        # to the metadata file to avoid issues with
+        # moving the checkpoint directory
+        checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
         run_id=config.id,
         name=config.run_name,
@@ -84,6 +94,7 @@ def _generate_checkpoint_metadata(
         training_time=training_time,
         metrics=metrics,
         environment=config.environment,
+        hparams=config.model_dump(mode="json"),
     )


@@ -93,36 +104,28 @@ def _write_checkpoint_metadata(
     checkpoint_path: Path,
 ):
     config = cast("BaseConfig", model.config)
-    metadata = _generate_checkpoint_metadata(config, trainer, checkpoint_path)
+    metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    metadata = _generate_checkpoint_metadata(
+        config, trainer, checkpoint_path, metadata_path
+    )

     # Write the metadata to the checkpoint directory
     try:
-        metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
         metadata_path.write_text(metadata.model_dump_json(indent=4))
     except Exception as e:
         log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
     else:
         log.debug(f"Checkpoint metadata written to {checkpoint_path}")

-    # Write the hparams to the checkpoint directory
+
+def _remove_checkpoint_metadata(checkpoint_path: Path):
+    path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
     try:
-        hparams_path = checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX)
-        hparams_path.write_text(config.model_dump_json(indent=4))
+        path.unlink(missing_ok=True)
     except Exception as e:
-        log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
+        log.warning(f"Failed to remove {path}: {e}")
     else:
-        log.debug(f"Checkpoint metadata written to {checkpoint_path}")
-
-
-def _remove_checkpoint_metadata(checkpoint_path: Path):
-    for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
-        path = checkpoint_path.with_suffix(suffix)
-        try:
-            path.unlink(missing_ok=True)
-        except Exception as e:
-            log.warning(f"Failed to remove {path}: {e}")
-        else:
-            log.debug(f"Removed {path}")
+        log.debug(f"Removed {path}")


 def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
@@ -130,20 +133,19 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
     _remove_checkpoint_metadata(linked_checkpoint_path)

     # Link the metadata files to the new checkpoint
-    for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
-        path = checkpoint_path.with_suffix(suffix)
-        linked_path = linked_checkpoint_path.with_suffix(suffix)
+    path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    linked_path = linked_checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    try:
         try:
-            try:
-                linked_path.symlink_to(path)
-            except OSError:
-                # on Windows, special permissions are required to create symbolic links as a regular user
-                # fall back to copying the file
-                shutil.copy(path, linked_path)
-        except Exception as e:
-            log.warning(f"Failed to link {path} to {linked_path}: {e}")
-        else:
-            log.debug(f"Linked {path} to {linked_path}")
+            linked_path.symlink_to(path)
+        except OSError:
+            # on Windows, special permissions are required to create symbolic links as a regular user
+            # fall back to copying the file
+            shutil.copy(path, linked_path)
+    except Exception as e:
+        log.warning(f"Failed to link {path} to {linked_path}: {e}")
+    else:
+        log.debug(f"Linked {path} to {linked_path}")


 def _sort_ckpts_by_metadata(
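
Taken together, the metadata changes above fold the former `.hparams.json` sidecar into the `.metadata.json` file (new `hparams` field) and store `checkpoint_path` relative to the metadata file, so a checkpoint directory can be moved without invalidating the recorded path. A minimal sketch of reading such a file back; the helper function and the example path are hypothetical, and only `CheckpointMetadata.from_file` and the fields shown in the diff are assumed:

    from pathlib import Path

    from nshtrainer._checkpoint.metadata import CheckpointMetadata


    def load_checkpoint_info(metadata_path: Path) -> tuple[CheckpointMetadata, Path]:
        # Parse the .metadata.json sidecar written next to the checkpoint
        meta = CheckpointMetadata.from_file(metadata_path)
        # checkpoint_path is stored relative to the metadata file's parent directory,
        # so resolve it against that directory to get the absolute checkpoint path
        return meta, (metadata_path.parent / meta.checkpoint_path).resolve()


    meta, ckpt = load_checkpoint_info(
        Path("checkpoints/last/epoch001-step0000100.metadata.json")  # hypothetical path
    )
    print(meta.run_id, meta.metrics, meta.hparams is not None, ckpt)
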
{nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/__init__.py
@@ -6,6 +6,8 @@ from . import checkpoint as checkpoint
 from .base import CallbackConfigBase as CallbackConfigBase
 from .checkpoint import BestCheckpoint as BestCheckpoint
 from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
+from .checkpoint import LastCheckpoint as LastCheckpoint
+from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
 from .checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
 from .checkpoint import (
     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
@@ -46,6 +48,7 @@ CallbackConfig = Annotated[
     | GradientSkippingConfig
     | EMAConfig
     | BestCheckpointCallbackConfig
+    | LastCheckpointCallbackConfig
     | ModelCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig
{nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/checkpoint/__init__.py
@@ -2,6 +2,10 @@ from .best_checkpoint import BestCheckpoint as BestCheckpoint
 from .best_checkpoint import (
     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
 )
+from .last_checkpoint import LastCheckpoint as LastCheckpoint
+from .last_checkpoint import (
+    LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
+)
 from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
 from .latest_epoch_checkpoint import (
     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
nshtrainer-0.11.8/src/nshtrainer/callbacks/checkpoint/_base.py
@@ -0,0 +1,175 @@
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generic, Literal
+
+import numpy as np
+import torch
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import TypeVar, override
+
+from ..._checkpoint.metadata import CheckpointMetadata
+from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from ..base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ...model.config import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str | None = None
+    """Checkpoint filename. This must not include the extension.
+    If None, the default filename will be used."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    save_symlink: bool = True
+    """Whether to create a symlink to the saved checkpoint."""
+
+    topk: int | Literal["all"] = 1
+    """The number of checkpoints to keep."""
+
+    @abstractmethod
+    def create_checkpoint(
+        self,
+        root_config: "BaseConfig",
+        dirpath: Path,
+    ) -> "CheckpointBase": ...
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = Path(
+            self.dirpath
+            or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+        )
+
+        yield self.create_checkpoint(root_config, dirpath)
+
+
+TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
+
+
+class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
+    def __init__(self, config: TConfig, dirpath: Path):
+        super().__init__()
+
+        self.config = config
+        self.dirpath = dirpath / self.name()
+        self.symlink_dirpath = dirpath
+
+        self._last_global_step_saved = 0
+
+    @abstractmethod
+    def default_filename(self) -> str: ...
+
+    @abstractmethod
+    def name(self) -> str: ...
+
+    def extension(self) -> str:
+        return ".ckpt"
+
+    @abstractmethod
+    def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
+
+    def symlink_path(self):
+        if not self.config.save_symlink:
+            return None
+
+        return self.symlink_dirpath / f"{self.name()}{self.extension()}"
+
+    def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
+        if (filename := self.config.filename) is None:
+            filename = self.default_filename()
+        filename = filename.format(**current_metrics)
+        return self.dirpath / f"{filename}{self.extension()}"
+
+    def remove_old_checkpoints(self, trainer: Trainer):
+        if (topk := self.config.topk) == "all":
+            return
+
+        # Get all the checkpoint metadata
+        metas = [
+            CheckpointMetadata.from_file(p)
+            for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+        ]
+
+        # Sort by the topk sort key
+        metas = sorted(metas, key=self.topk_sort_key)
+
+        # Now, the metas are sorted from the best to the worst,
+        # so we can remove the worst checkpoints
+        for meta in metas[topk:]:
+            if not (old_ckpt_path := self.dirpath / meta.checkpoint_filename).exists():
+                log.warning(
+                    f"Checkpoint file not found: {old_ckpt_path}\n"
+                    "Skipping removal of the checkpoint metadata."
+                )
+                continue
+
+            _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
+            log.debug(f"Removed old checkpoint: {old_ckpt_path}")
+
+    def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
+        current_metrics: dict[str, Any] = {
+            "epoch": trainer.current_epoch,
+            "step": trainer.global_step,
+        }
+
+        for name, value in trainer.callback_metrics.items():
+            match value:
+                case torch.Tensor() if value.numel() == 1:
+                    value = value.detach().cpu().item()
+                case np.ndarray() if value.size == 1:
+                    value = value.item()
+                case _:
+                    pass
+
+            current_metrics[name] = value
+
+        return current_metrics
+
+    def save_checkpoints(self, trainer: Trainer):
+        if self._should_skip_saving_checkpoint(trainer):
+            return
+
+        # Save the new checkpoint
+        filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)
+
+        if trainer.is_global_zero:
+            # Remove old checkpoints
+            self.remove_old_checkpoints(trainer)
+
+            # Create the latest symlink
+            if (symlink_filename := self.symlink_path()) is not None:
+                symlink_path = self.dirpath / symlink_filename
+                _link_checkpoint(filepath, symlink_path, metadata=True)
+                log.debug(f"Created latest symlink: {symlink_path}")
+
+        # Barrier to ensure all processes have saved the checkpoint,
+        # deleted the old checkpoints, and created the symlink before continuing
+        trainer.strategy.barrier()
+
+        # Set the last global step saved
+        self._last_global_step_saved = trainer.global_step
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or self._last_global_step_saved
+            == trainer.global_step  # already saved at the last step
+        )
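
The two concrete callbacks that follow (`BestCheckpoint`, `LastCheckpoint`) are thin subclasses of this base: a subclass supplies a directory name, a default filename pattern, a sort key for top-k pruning, and the Lightning hook that calls `save_checkpoints`. A rough sketch of what a third, hypothetical subclass could look like; the class names and the "every_validation" naming are invented and not part of the package:

    from pathlib import Path
    from typing import Literal

    from lightning.pytorch import LightningModule, Trainer
    from typing_extensions import final, override

    from nshtrainer._checkpoint.metadata import CheckpointMetadata
    from nshtrainer.callbacks.checkpoint._base import (
        BaseCheckpointCallbackConfig,
        CheckpointBase,
    )


    @final
    class EveryValidationCheckpointConfig(BaseCheckpointCallbackConfig):
        # hypothetical discriminator value, mirroring the built-in configs
        name: Literal["every_validation_checkpoint"] = "every_validation_checkpoint"

        @override
        def create_checkpoint(self, root_config, dirpath: Path):
            return EveryValidationCheckpoint(self, dirpath)


    @final
    class EveryValidationCheckpoint(CheckpointBase[EveryValidationCheckpointConfig]):
        @override
        def name(self) -> str:
            # checkpoints land in <checkpoint dir>/every_validation/
            return "every_validation"

        @override
        def default_filename(self) -> str:
            # placeholders are filled from CheckpointBase.current_metrics()
            return "epoch{epoch:03d}-step{step:07d}"

        @override
        def topk_sort_key(self, metadata: CheckpointMetadata):
            # order checkpoints by their save timestamp, as LastCheckpoint does
            return metadata.checkpoint_timestamp

        @override
        def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
            self.save_checkpoints(trainer)
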
nshtrainer-0.11.8/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
@@ -0,0 +1,70 @@
+import logging
+from pathlib import Path
+from typing import Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from typing_extensions import final, override
+
+from nshtrainer._checkpoint.metadata import CheckpointMetadata
+
+from ...metrics._config import MetricConfig
+from ._base import BaseCheckpointCallbackConfig, CheckpointBase
+
+log = logging.getLogger(__name__)
+
+
+@final
+class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+    name: Literal["best_checkpoint"] = "best_checkpoint"
+
+    metric: MetricConfig | None = None
+    """Metric to monitor, or `None` to use the default metric."""
+
+    @override
+    def create_checkpoint(self, root_config, dirpath):
+        # Resolve metric
+        if (metric := self.metric) is None and (
+            metric := root_config.primary_metric
+        ) is None:
+            raise ValueError(
+                "No metric provided and no primary metric found in the root config"
+            )
+
+        return BestCheckpoint(self, dirpath, metric)
+
+
+@final
+class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
+    @property
+    def _metric_name_normalized(self):
+        return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")
+
+    @override
+    def __init__(
+        self,
+        config: BestCheckpointCallbackConfig,
+        dirpath: Path,
+        metric: MetricConfig,
+    ):
+        super().__init__(config, dirpath)
+        self.metric = metric
+
+    @override
+    def name(self):
+        return f"best_{self._metric_name_normalized}"
+
+    @override
+    def default_filename(self):
+        return f"epoch{{epoch:03d}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
+
+    @override
+    def topk_sort_key(self, metadata: CheckpointMetadata):
+        return metadata.metrics.get(
+            self.metric.validation_monitor,
+            float("-inf" if self.metric.mode == "max" else "inf"),
+        )
+
+    # Events
+    @override
+    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
+        self.save_checkpoints(trainer)
nshtrainer-0.11.8/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
@@ -0,0 +1,39 @@
+import logging
+from typing import Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from typing_extensions import final, override
+
+from nshtrainer._checkpoint.metadata import CheckpointMetadata
+
+from ._base import BaseCheckpointCallbackConfig, CheckpointBase
+
+log = logging.getLogger(__name__)
+
+
+@final
+class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+    name: Literal["last_checkpoint"] = "last_checkpoint"
+
+    @override
+    def create_checkpoint(self, root_config, dirpath):
+        return LastCheckpoint(self, dirpath)
+
+
+@final
+class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
+    @override
+    def name(self):
+        return "last"
+
+    @override
+    def default_filename(self):
+        return "epoch{epoch:03d}-step{step:07d}"
+
+    @override
+    def topk_sort_key(self, metadata: CheckpointMetadata):
+        return metadata.checkpoint_timestamp
+
+    @override
+    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+        self.save_checkpoints(trainer)
{nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py
@@ -19,7 +19,7 @@ class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
     dirpath: str | Path | None = None
     """Directory path to save the checkpoint file."""

-    filename: str = "epoch{epoch:02d}_step{step:04d}"
+    filename: str = "epoch{epoch:03d}_step{step:07d}"
     """Checkpoint filename. This must not include the extension."""

     save_weights_only: bool = False
{nshtrainer-0.11.7 → nshtrainer-0.11.8}/src/nshtrainer/model/config.py
@@ -39,6 +39,7 @@ from .._checkpoint.loader import CheckpointLoadingConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
+    LastCheckpointCallbackConfig,
     LatestEpochCheckpointCallbackConfig,
     ModelCheckpointCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
@@ -773,6 +774,7 @@ class ReproducibilityConfig(C.Config):
 CheckpointCallbackConfig: TypeAlias = Annotated[
     ModelCheckpointCallbackConfig
     | BestCheckpointCallbackConfig
+    | LastCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig,
     C.Field(discriminator="name"),
@@ -786,7 +788,7 @@ class CheckpointSavingConfig(CallbackConfigBase):
     checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
         # ModelCheckpointCallbackConfig(),
         BestCheckpointCallbackConfig(),
-        LatestEpochCheckpointCallbackConfig(),
+        LastCheckpointCallbackConfig(),
         OnExceptionCheckpointCallbackConfig(),
     ]
     """Checkpoint callback configurations."""
nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
@@ -1,192 +0,0 @@
-import logging
-from pathlib import Path
-from typing import Any, Literal
-
-from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.callbacks import Checkpoint
-from typing_extensions import override
-
-from ..._checkpoint.metadata import _sort_ckpts_by_metadata
-from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
-from ...metrics._config import MetricConfig
-from ..base import CallbackConfigBase
-
-log = logging.getLogger(__name__)
-
-
-class BestCheckpointCallbackConfig(CallbackConfigBase):
-    name: Literal["best_checkpoint"] = "best_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str = "epoch{epoch:02d}_step{step:04d}"
-    """Checkpoint filename. This must not include the extension."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    metric: MetricConfig | None = None
-    """Metric to monitor, or `None` to use the default metric."""
-
-    best_symlink_filename: str | None = "best"
-    """Filename for the best symlink. If None, no symlink will be created."""
-
-    save_top_k: int | Literal["all"] = 1
-    """The number of best checkpoints to keep."""
-
-    @override
-    def create_callbacks(self, root_config):
-        dirpath = Path(
-            self.dirpath
-            or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
-        )
-
-        # Resolve metric
-        if (metric := self.metric) is None and (
-            metric := root_config.primary_metric
-        ) is None:
-            raise ValueError(
-                "No metric provided and no primary metric found in the root config"
-            )
-
-        yield BestCheckpoint(self, metric, dirpath)
-
-    @property
-    def _save_top_k_value(self):
-        return float("inf" if self.save_top_k == "all" else self.save_top_k)
-
-
-class BestCheckpoint(Checkpoint):
-    PREFIX = "best_"
-    EXTENSION = ".ckpt"
-
-    def __init__(
-        self,
-        config: BestCheckpointCallbackConfig,
-        metric: MetricConfig,
-        dirpath: Path,
-    ):
-        super().__init__()
-        self.config = config
-        self.metric = metric
-        self.dirpath = dirpath
-
-        self._last_global_step_saved = 0  # no need to save when no steps were taken
-
-    @override
-    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
-        self._save_best_checkpoint(trainer)
-
-    def _best_symlink_filename(self):
-        if (filename := self.config.best_symlink_filename) is None:
-            return None
-        return f"{filename}{self.EXTENSION}"
-
-    def _ckpt_path(self, trainer: Trainer):
-        filename = self.config.filename.format(
-            epoch=trainer.current_epoch, step=trainer.global_step
-        )
-        filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
-        return self.dirpath / filename
-
-    def _get_metric_value(self, metrics: dict[str, Any]):
-        return metrics.get(
-            self.metric.validation_monitor,
-            float("-inf" if self.metric.mode == "max" else "inf"),
-        )
-
-    def _sorted_ckpts(self):
-        """
-        Get sorted checkpoints by the metric value.
-
-        Sort order: best -> worst
-        """
-        ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
-        return _sort_ckpts_by_metadata(
-            ckpt_paths,
-            key=lambda meta, _: self._get_metric_value(meta.metrics),
-            reverse=(self.metric.mode == "max"),
-        )
-
-    def _create_symlink(self, trainer: Trainer, best_ckpt_path: Path):
-        # Resolve the symlink filename
-        if (symlink_filename := self._best_symlink_filename()) is None:
-            return
-
-        # If the symlink already exists and points to the best checkpoint,
-        # then we don't need to create a new symlink.
-        symlink_path = self.dirpath / symlink_filename
-        if symlink_path.exists() and symlink_path.resolve() == best_ckpt_path:
-            return
-
-        _link_checkpoint(best_ckpt_path, symlink_path, metadata=True)
-        log.debug(f"Created best symlink: {symlink_path}")
-
-    def _save_best_checkpoint(self, trainer: Trainer):
-        # Skip saving the checkpoint if we're not in the fitting state
-        if self._should_skip_saving_checkpoint(trainer):
-            return
-
-        # Get the current metric value
-        if (current := self._get_metric_value(trainer.callback_metrics)) is None:
-            log.warning(
-                f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
-            )
-            return
-
-        # Get sorted checkpoints
-        sorted_ckpts = self._sorted_ckpts()
-
-        # If the current model is worse than the worst checkpoint,
-        # and we have already saved the maximum number of checkpoints,
-        # then don't save the current model.
-        if len(
-            sorted_ckpts
-        ) >= self.config._save_top_k_value and not self.metric.is_better(
-            current,
-            self._get_metric_value(sorted_ckpts[-1][0].metrics),
-        ):
-            return
-
-        # Save the current model
-        filepath = self._ckpt_path(trainer)
-        trainer.save_checkpoint(filepath, self.config.save_weights_only)
-        log.debug(f"Saved best checkpoint: {filepath}")
-
-        if trainer.is_global_zero:
-            # Get the sorted checkpoints again because now we have added a new checkpoint.
-            # We could optimize this by adding the new checkpoint to the sorted list,
-            # and then sorting it in place, but this is simpler.
-            sorted_ckpts = self._sorted_ckpts()
-
-            # Remove worst checkpoint if we've reached save_top_k
-            if (topk := self.config.save_top_k) != "all" and len(sorted_ckpts) > topk:
-                # NOTE: Sort order is best -> worst. Let's get the worst checkpoints.
-                for _, ckpt_path in sorted_ckpts[topk:]:
-                    _remove_checkpoint(trainer, ckpt_path, metadata=True)
-
-            # Create symlink to best model
-            if sorted_ckpts:
-                _, best_ckpt_path = sorted_ckpts[0]
-                self._create_symlink(trainer, best_ckpt_path)
-
-        # Update the last global step saved
-        self._last_global_step_saved = trainer.global_step
-
-        # Barrier to ensure all processes have saved the checkpoint before continuing
-        trainer.strategy.barrier()
-
-    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
-        from lightning.pytorch.trainer.states import TrainerFn
-
-        return (
-            bool(
-                getattr(trainer, "fast_dev_run", False)
-            )  # disable checkpointing with fast_dev_run
-            or trainer.state.fn
-            != TrainerFn.FITTING  # don't save anything during non-fit
-            or trainer.sanity_checking  # don't save anything during sanity check
-            or self._last_global_step_saved
-            == trainer.global_step  # already saved at the last step
-        )