nshtrainer-0.11.8.tar.gz → nshtrainer-0.11.10.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/PKG-INFO +1 -1
  2. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/pyproject.toml +1 -1
  3. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/__init__.py +0 -8
  4. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -8
  5. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/__init__.py +2 -4
  6. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/config.py +1 -36
  7. nshtrainer-0.11.8/src/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +0 -131
  8. nshtrainer-0.11.8/src/nshtrainer/callbacks/checkpoint/model_checkpoint.py +0 -207
  9. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/README.md +0 -0
  10. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/__init__.py +0 -0
  11. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/_checkpoint/loader.py +0 -0
  12. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  13. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/_checkpoint/saver.py +0 -0
  14. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/_experimental/__init__.py +0 -0
  15. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  16. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  17. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  18. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  19. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/actsave.py +0 -0
  20. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/base.py +0 -0
  21. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  22. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -1
  23. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  24. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  25. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  26. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/ema.py +0 -0
  27. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  28. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  29. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/interval.py +0 -0
  30. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  31. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  32. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/print_table.py +0 -0
  33. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  34. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/timer.py +0 -0
  35. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  36. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/data/__init__.py +0 -0
  37. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  38. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/data/transform.py +0 -0
  39. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/__init__.py +0 -0
  40. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/_experimental.py +0 -0
  41. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/actsave.py +0 -0
  42. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/callbacks.py +0 -0
  43. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/config.py +0 -0
  44. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/data.py +0 -0
  45. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/log.py +0 -0
  46. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  47. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/model.py +0 -0
  48. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/nn.py +0 -0
  49. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/optimizer.py +0 -0
  50. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/runner.py +0 -0
  51. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/snapshot.py +0 -0
  52. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/snoop.py +0 -0
  53. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/trainer.py +0 -0
  54. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/typecheck.py +0 -0
  55. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/ll/util.py +0 -0
  56. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  57. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  58. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  59. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  60. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/metrics/__init__.py +0 -0
  61. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/metrics/_config.py +0 -0
  62. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/base.py +0 -0
  63. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/modules/callback.py +0 -0
  64. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/modules/debug.py +0 -0
  65. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/modules/distributed.py +0 -0
  66. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/modules/logger.py +0 -0
  67. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/modules/profiler.py +0 -0
  68. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  69. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  70. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/nn/__init__.py +0 -0
  71. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/nn/mlp.py +0 -0
  72. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/nn/module_dict.py +0 -0
  73. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/nn/module_list.py +0 -0
  74. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/nn/nonlinearity.py +0 -0
  75. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/optimizer.py +0 -0
  76. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/runner.py +0 -0
  77. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/scripts/find_packages.py +0 -0
  78. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/trainer/__init__.py +0 -0
  79. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  80. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  81. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/trainer/signal_connector.py +0 -0
  82. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/trainer/trainer.py +0 -0
  83. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/util/_environment_info.py +0 -0
  84. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/util/_useful_types.py +0 -0
  85. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/util/environment.py +0 -0
  86. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/util/seed.py +0 -0
  87. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/util/slurm.py +0 -0
  88. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/util/typed.py +0 -0
  89. {nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.11.8 → nshtrainer-0.11.10}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.11.8
+Version: 0.11.10
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-0.11.8 → nshtrainer-0.11.10}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.11.8"
+version = "0.11.10"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
{nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/__init__.py
@@ -8,12 +8,6 @@ from .checkpoint import BestCheckpoint as BestCheckpoint
 from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
 from .checkpoint import LastCheckpoint as LastCheckpoint
 from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
-from .checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
-from .checkpoint import (
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
-from .checkpoint import ModelCheckpoint as ModelCheckpoint
-from .checkpoint import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
 from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
 from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
@@ -49,8 +43,6 @@ CallbackConfig = Annotated[
     | EMAConfig
     | BestCheckpointCallbackConfig
     | LastCheckpointCallbackConfig
-    | ModelCheckpointCallbackConfig
-    | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig
     | WandbWatchConfig,
     C.Field(discriminator="name"),
{nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/checkpoint/__init__.py
@@ -6,14 +6,6 @@ from .last_checkpoint import LastCheckpoint as LastCheckpoint
 from .last_checkpoint import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
-from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
-from .latest_epoch_checkpoint import (
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
-from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
-from .model_checkpoint import (
-    ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
-)
 from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
 from .on_exception_checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
{nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/__init__.py
@@ -5,17 +5,15 @@ from .base import LightningModuleBase as LightningModuleBase
 from .config import BaseConfig as BaseConfig
 from .config import BaseLoggerConfig as BaseLoggerConfig
 from .config import BaseProfilerConfig as BaseProfilerConfig
+from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
 from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
 from .config import CheckpointSavingConfig as CheckpointSavingConfig
 from .config import DirectoryConfig as DirectoryConfig
 from .config import EarlyStoppingConfig as EarlyStoppingConfig
 from .config import GradientClippingConfig as GradientClippingConfig
-from .config import (
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
+from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
 from .config import LoggingConfig as LoggingConfig
 from .config import MetricConfig as MetricConfig
-from .config import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
 from .config import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
{nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/model/config.py
@@ -40,8 +40,6 @@ from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
     LastCheckpointCallbackConfig,
-    LatestEpochCheckpointCallbackConfig,
-    ModelCheckpointCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
     WandbWatchConfig,
 )
@@ -772,10 +770,8 @@ class ReproducibilityConfig(C.Config):


 CheckpointCallbackConfig: TypeAlias = Annotated[
-    ModelCheckpointCallbackConfig
-    | BestCheckpointCallbackConfig
+    BestCheckpointCallbackConfig
     | LastCheckpointCallbackConfig
-    | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig,
     C.Field(discriminator="name"),
 ]
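For readers unfamiliar with the pattern: `C.Field(discriminator="name")` turns this `Annotated` union into a tagged union, where the literal `name` field on each member decides which config class a raw dict is parsed into. Below is a minimal sketch of the same idea using plain pydantic v2, with illustrative classes and tag values (not nshtrainer's actual ones), assuming the union behaves like a standard pydantic discriminated union:

from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class BestCheckpointConfig(BaseModel):
    # Hypothetical stand-in for BestCheckpointCallbackConfig.
    name: Literal["best_checkpoint"] = "best_checkpoint"
    save_top_k: int = 1


class LastCheckpointConfig(BaseModel):
    # Hypothetical stand-in for LastCheckpointCallbackConfig.
    name: Literal["last_checkpoint"] = "last_checkpoint"


# The "name" literal selects the concrete class, mirroring C.Field(discriminator="name").
CheckpointConfig = Annotated[
    BestCheckpointConfig | LastCheckpointConfig,
    Field(discriminator="name"),
]

cfg = TypeAdapter(CheckpointConfig).validate_python(
    {"name": "best_checkpoint", "save_top_k": 3}
)
assert isinstance(cfg, BestCheckpointConfig)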
@@ -786,7 +782,6 @@ class CheckpointSavingConfig(CallbackConfigBase):
     """Enable checkpoint saving."""

     checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
-        # ModelCheckpointCallbackConfig(),
         BestCheckpointCallbackConfig(),
         LastCheckpointCallbackConfig(),
         OnExceptionCheckpointCallbackConfig(),
@@ -806,36 +801,6 @@ class CheckpointSavingConfig(CallbackConfigBase):

         return True

-    @property
-    def model_checkpoint(self) -> ModelCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, ModelCheckpointCallbackConfig)
-            ),
-        )
-
-    @property
-    def latest_epoch_checkpoint(self) -> LatestEpochCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, LatestEpochCheckpointCallbackConfig)
-            ),
-        )
-
-    @property
-    def on_exception_checkpoint(self) -> OnExceptionCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, OnExceptionCheckpointCallbackConfig)
-            ),
-        )
-
     @override
     def create_callbacks(self, root_config: "BaseConfig"):
         if not self.should_save_checkpoints(root_config):
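With the `model_checkpoint`, `latest_epoch_checkpoint`, and `on_exception_checkpoint` convenience properties removed, code that needs a particular callback config can filter `checkpoint_callbacks` directly. A hypothetical helper sketching that lookup (not part of the package; it assumes the config classes shown above and returns `None` instead of raising when nothing matches):

from typing import TypeVar

T = TypeVar("T")


def find_checkpoint_callback(
    saving: "CheckpointSavingConfig",  # assumed: nshtrainer.model.config.CheckpointSavingConfig
    config_cls: type[T],
) -> T | None:
    # First configured callback of the requested type, or None if absent.
    return next(
        (cb for cb in saving.checkpoint_callbacks if isinstance(cb, config_cls)),
        None,
    )


# e.g. find_checkpoint_callback(saving, OnExceptionCheckpointCallbackConfig)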
nshtrainer-0.11.8/src/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py (removed)
@@ -1,131 +0,0 @@
-import logging
-from pathlib import Path
-from typing import Literal
-
-from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.callbacks import Checkpoint
-from typing_extensions import override
-
-from ..._checkpoint.metadata import _sort_ckpts_by_metadata
-from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
-from ..base import CallbackConfigBase
-
-log = logging.getLogger(__name__)
-
-
-class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
-    name: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str = "epoch{epoch:03d}_step{step:07d}"
-    """Checkpoint filename. This must not include the extension."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    latest_symlink_filename: str | None = "latest"
-    """Filename for the latest symlink. If None, no symlink will be created."""
-
-    latest_k: int | Literal["all"] = 1
-    """Number of latest checkpoints to keep. If "all", all checkpoints are kept."""
-
-    @override
-    def create_callbacks(self, root_config):
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-        dirpath = Path(dirpath)
-
-        yield LatestEpochCheckpoint(self, dirpath)
-
-
-class LatestEpochCheckpoint(Checkpoint):
-    PREFIX = "latest_"
-    EXTENSION = ".ckpt"
-
-    def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
-        super().__init__()
-
-        self.config = config
-        self.dirpath = dirpath
-
-        self._last_global_step_saved = 0
-
-    @override
-    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
-        self._save_new_checkpoint(trainer)
-
-    def _latest_symlink_filename(self):
-        if (filename := self.config.latest_symlink_filename) is None:
-            return None
-        return f"{filename}{self.EXTENSION}"
-
-    def _ckpt_path(self, trainer: Trainer):
-        filename = self.config.filename.format(
-            epoch=trainer.current_epoch, step=trainer.global_step
-        )
-        filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
-        return self.dirpath / filename
-
-    def _remove_old_checkpoints(self, trainer: Trainer):
-        if (latest_k := self.config.latest_k) == "all":
-            return
-
-        # Get all configs, ignoring the latest symlink
-        ckpts = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
-        # Ignore the latest symlink
-        if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
-            ckpts = [p for p in ckpts if p.name != latest_symlink_filename]
-
-        # Sort by epoch, then step, then last modified
-        ckpts = _sort_ckpts_by_metadata(
-            ckpts,
-            key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
-            reverse=True,
-        )
-
-        # Remove all but the latest k checkpoints
-        # NOTE: We add 1 to the latest_k here because
-        # we're about to save a new checkpoint.
-        for _, ckpt_path in ckpts[latest_k:]:
-            _remove_checkpoint(trainer, ckpt_path, metadata=True)
-
-    def _save_new_checkpoint(self, trainer: Trainer):
-        if self._should_skip_saving_checkpoint(trainer):
-            return
-
-        # Save the new checkpoint
-        filepath = self._ckpt_path(trainer)
-        trainer.save_checkpoint(filepath, self.config.save_weights_only)
-
-        if trainer.is_global_zero:
-            # Remove old checkpoints
-            self._remove_old_checkpoints(trainer)
-
-            # Create the latest symlink
-            if (symlink_filename := self._latest_symlink_filename()) is not None:
-                symlink_path = self.dirpath / symlink_filename
-                _link_checkpoint(filepath, symlink_path, metadata=True)
-                log.debug(f"Created latest symlink: {symlink_path}")
-
-        # Set the last global step saved
-        self._last_global_step_saved = trainer.global_step
-
-        # Barrier to ensure all processes have saved the checkpoint before continuing
-        trainer.strategy.barrier()
-
-    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
-        from lightning.pytorch.trainer.states import TrainerFn
-
-        return (
-            bool(
-                getattr(trainer, "fast_dev_run", False)
-            )  # disable checkpointing with fast_dev_run
-            or trainer.state.fn
-            != TrainerFn.FITTING  # don't save anything during non-fit
-            or trainer.sanity_checking  # don't save anything during sanity check
-            or self._last_global_step_saved
-            == trainer.global_step  # already saved at the last step
-        )
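A note on the pruning logic in the removed `_remove_old_checkpoints` above: checkpoints are sorted newest-first by (epoch, step, mtime), and everything from index `latest_k` onward is deleted, so at most `latest_k` existing checkpoints survive before the new one is written. A tiny standalone sketch of that sort-and-slice step, with dummy (epoch, step) tuples standing in for real checkpoint metadata:

# Dummy (epoch, global_step) pairs in place of checkpoint files.
ckpts = [(3, 300), (1, 100), (4, 400), (2, 200)]
latest_k = 2

# Newest first, mirroring _sort_ckpts_by_metadata(..., reverse=True).
ckpts = sorted(ckpts, reverse=True)

keep, remove = ckpts[:latest_k], ckpts[latest_k:]
print(keep)    # [(4, 400), (3, 300)]  -> kept on disk
print(remove)  # [(2, 200), (1, 100)]  -> would be passed to _remove_checkpoint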
nshtrainer-0.11.8/src/nshtrainer/callbacks/checkpoint/model_checkpoint.py (removed)
@@ -1,207 +0,0 @@
-import logging
-import re
-from datetime import timedelta
-from pathlib import Path
-from typing import TYPE_CHECKING, Literal
-
-from lightning.pytorch import Trainer
-from lightning.pytorch.callbacks.model_checkpoint import (
-    ModelCheckpoint as _ModelCheckpoint,
-)
-from typing_extensions import override
-
-from ..._checkpoint.saver import _link_checkpoint
-from ..._checkpoint.saver import _remove_checkpoint as _ckpt_saver_remove_checkpoint
-from ...metrics import MetricConfig
-from ..base import CallbackConfigBase
-
-if TYPE_CHECKING:
-    from ...model.config import BaseConfig
-
-log = logging.getLogger(__name__)
-
-
-def _convert_string(input_string: str):
-    # Find all variables enclosed in curly braces
-    variables = re.findall(r"\{(.*?)\}", input_string)
-
-    # Replace each variable with its corresponding key-value pair
-    output_string = input_string
-    for variable in variables:
-        # If the name is something like {variable:format}, we shouldn't process the format.
-        key_name = variable
-        if ":" in variable:
-            key_name, _ = variable.split(":", 1)
-            continue
-
-        # Replace '/' with '_' in the key name
-        key_name = key_name.replace("/", "_")
-        output_string = output_string.replace(
-            f"{{{variable}}}", f"{key_name}={{{variable}}}"
-        )
-
-    return output_string
-
-
-class ModelCheckpointCallbackConfig(CallbackConfigBase):
-    """Arguments for the ModelCheckpoint callback."""
-
-    name: Literal["model_checkpoint"] = "model_checkpoint"
-
-    dirpath: str | Path | None = None
-    """
-    Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
-    """
-
-    filename: str | None = None
-    """
-    Checkpoint filename.
-    If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
-    """
-
-    metric: MetricConfig | None = None
-    """
-    Metric to monitor for saving checkpoints.
-    If None, the primary metric of the runner will be used, if available.
-    """
-
-    verbose: bool = False
-    """Verbosity mode. If True, print additional information about checkpoints."""
-
-    save_last: Literal[True, False, "link"] | None = "link"
-    """
-    Whether to save the last checkpoint.
-    If True, saves a copy of the last checkpoint separately.
-    If "link", creates a symbolic link to the last checkpoint.
-    """
-
-    save_top_k: int | Literal["all"] = 1
-    """
-    Number of best models to save.
-    If "all" or -1, all models are saved.
-    If 0, no models are saved.
-    """
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    auto_insert_metric_name: bool = True
-    """Whether to automatically insert the metric name in the checkpoint filename."""
-
-    every_n_train_steps: int | None = None
-    """
-    Number of training steps between checkpoints.
-    If None or 0, no checkpoints are saved during training.
-    """
-
-    train_time_interval: timedelta | None = None
-    """
-    Time interval between checkpoints during training.
-    If None, no checkpoints are saved during training based on time.
-    """
-
-    every_n_epochs: int | None = None
-    """
-    Number of epochs between checkpoints.
-    If None or 0, no checkpoints are saved at the end of epochs.
-    """
-
-    save_on_train_epoch_end: bool | None = None
-    """
-    Whether to run checkpointing at the end of the training epoch.
-    If False, checkpointing runs at the end of the validation.
-    """
-
-    enable_version_counter: bool = True
-    """Whether to append a version to the existing file name."""
-
-    auto_append_metric: bool = True
-    """If enabled, this will automatically add "-{monitor}" to the filename."""
-
-    def metric_or_default(self, root_config: "BaseConfig"):
-        if self.metric is not None:
-            return self.metric
-        if root_config.primary_metric is not None:
-            return root_config.primary_metric
-        raise ValueError("Primary metric must be provided if metric is not specified.")
-
-    def resolve_filename(self, root_config: "BaseConfig"):
-        metric = self.metric_or_default(root_config)
-
-        filename = self.filename
-        if not filename:
-            filename = "{epoch}-{step}"
-        if self.auto_append_metric:
-            filename = f"{filename}-{{{metric.validation_monitor}}}"
-
-        if self.auto_insert_metric_name and filename:
-            new_filename = _convert_string(filename)
-            log.critical(
-                f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
-            )
-            filename = new_filename
-
-        return filename
-
-    @override
-    def create_callbacks(self, root_config):
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        metric = self.metric_or_default(root_config)
-        filename = self.resolve_filename(root_config)
-
-        yield ModelCheckpoint(
-            self,
-            dirpath=Path(dirpath),
-            filename=filename,
-            metric=metric,
-        )
-
-    def _save_top_k_model_ckpt_input(self):
-        if self.save_top_k == "all":
-            return -1
-        return self.save_top_k
-
-
-class ModelCheckpoint(_ModelCheckpoint):
-    CHECKPOINT_NAME_LAST = "best"
-
-    @override
-    def __init__(
-        self,
-        config: ModelCheckpointCallbackConfig,
-        dirpath: Path,
-        filename: str,
-        metric: MetricConfig,
-    ):
-        self.config = config
-        del config
-
-        super().__init__(
-            dirpath=dirpath,
-            filename=filename,
-            monitor=metric.validation_monitor,
-            mode=metric.mode,
-            verbose=self.config.verbose,
-            save_last=self.config.save_last,
-            save_top_k=self.config._save_top_k_model_ckpt_input(),
-            save_weights_only=self.config.save_weights_only,
-            auto_insert_metric_name=False,
-            every_n_train_steps=self.config.every_n_train_steps,
-            train_time_interval=self.config.train_time_interval,
-            every_n_epochs=self.config.every_n_epochs,
-            save_on_train_epoch_end=self.config.save_on_train_epoch_end,
-            enable_version_counter=self.config.enable_version_counter,
-        )
-
-    @override
-    def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str):  # pyright: ignore[reportIncompatibleMethodOverride]
-        if trainer.is_global_zero:
-            _link_checkpoint(filepath, linkpath, metadata=True)
-        trainer.strategy.barrier()
-
-    @override
-    def _remove_checkpoint(self, trainer: Trainer, filepath: str):
-        _ckpt_saver_remove_checkpoint(trainer, filepath, metadata=True)
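For reference, the removed `_convert_string` helper (used when `auto_insert_metric_name` is enabled) rewrites each plain `{key}` placeholder to `key={key}`, replacing `/` with `_` in the key, while leaving `{key:format}` placeholders untouched. A short worked example reusing the function as defined in the removed module above; the template and metric name are illustrative:

template = "{epoch}-{step}-{val/loss}"
print(_convert_string(template))
# -> "epoch={epoch}-step={step}-val_loss={val/loss}"

# Placeholders with a format spec are skipped:
print(_convert_string("{epoch:03d}"))
# -> "{epoch:03d}"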
{nshtrainer-0.11.8 → nshtrainer-0.11.10}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
@@ -46,8 +46,8 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
         dirpath: Path,
         metric: MetricConfig,
     ):
-        super().__init__(config, dirpath)
         self.metric = metric
+        super().__init__(config, dirpath)

     @override
     def name(self):
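The reordering in `BestCheckpoint.__init__` above is the kind of fix needed when a base-class constructor calls back into methods the subclass overrides: if `CheckpointBase.__init__` invokes something like the overridden `name()` and that method reads `self.metric`, the attribute must exist before `super().__init__()` runs. A minimal, hypothetical illustration of the pitfall (not nshtrainer code):

class Base:
    def __init__(self) -> None:
        # Base constructor calls a hook that subclasses may override.
        print(f"initializing callback: {self.name()}")

    def name(self) -> str:
        return "base"


class Child(Base):
    def __init__(self, metric: str) -> None:
        self.metric = metric  # must be assigned first ...
        super().__init__()    # ... because Base.__init__ calls self.name()

    def name(self) -> str:
        return f"best_{self.metric}"


Child("val_loss")  # prints "initializing callback: best_val_loss"
# Calling super().__init__() before assigning self.metric would raise AttributeError.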