nshtrainer-0.11.1-py3-none-any.whl → nshtrainer-0.11.2-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
--- a/nshtrainer/_checkpoint/loader.py
+++ b/nshtrainer/_checkpoint/loader.py
@@ -133,6 +133,68 @@ class CheckpointLoadingConfig(C.Config):
         ckpt: Literal["best", "last"] | str | Path | None,
         trainer_mode: TrainerFn,
     ):
+        """
+        Automatically create a CheckpointLoadingConfig based on the provided checkpoint option and trainer mode.
+
+        This method provides a convenient way to generate a checkpoint loading configuration
+        tailored to different training and evaluation scenarios.
+
+        Parameters:
+        -----------
+        ckpt : Literal["best", "last"] | str | Path | None
+            Specifies the checkpoint loading preference:
+            - "best": Use the best checkpoint based on the primary metric.
+            - "last": Use the most recent checkpoint.
+            - str or Path: Path to a specific checkpoint file.
+            - None: Defaults to "last" for training, raises an error for evaluation.
+
+        trainer_mode : TrainerFn
+            The mode in which the trainer is operating. This affects how the configuration is created.
+            - TrainerFn.FITTING: Used for training scenarios.
+            - TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING: Used for evaluation scenarios.
+
+        Returns:
+        --------
+        CheckpointLoadingConfig
+            A configuration object for checkpoint loading based on the given parameters.
+
+        Behavior:
+        ---------
+        1. For training (TrainerFn.FITTING):
+           - Includes HPC pre-emption checkpoints.
+           - If ckpt is None, defaults to "last".
+           - For "best" or "last", creates a single-strategy configuration that loads the best or last checkpoint.
+           - For a specific path, creates a two-strategy configuration:
+             a) Tries to load the checkpoint as the last checkpoint.
+             b) Falls back to loading it as a user-provided path.
+
+        2. For evaluation (VALIDATING, TESTING, PREDICTING):
+           - Does not include HPC pre-emption checkpoints.
+           - Requires ckpt to be specified (raises ValueError if None).
+           - Creates a single-strategy configuration based on the ckpt value.
+
+        Raises:
+        -------
+        ValueError
+            If ckpt is None during evaluation modes.
+
+        Examples:
+        ---------
+        # Training mode, use last checkpoint
+        config = CheckpointLoadingConfig.auto("last", TrainerFn.FITTING)
+
+        # Evaluation mode, use best checkpoint
+        config = CheckpointLoadingConfig.auto("best", TrainerFn.TESTING)
+
+        # Training mode, use specific checkpoint
+        config = CheckpointLoadingConfig.auto("/path/to/checkpoint.ckpt", TrainerFn.FITTING)
+
+        Notes:
+        ------
+        - The method internally calls _auto_train or _auto_eval based on the trainer_mode.
+        - The resulting configuration always includes strategies as a sequence, even if there's only one strategy.
+        """
+        # Implementation remains the same...
         match trainer_mode:
             case TrainerFn.FITTING:
                 return cls._auto_train(ckpt)
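The hunk is truncated after the FITTING case, but the docstring pins down the rest of the dispatch. Below is a minimal sketch of how the full match statement plausibly reads; the evaluation branch is reconstructed from the documented behavior, and the exact error message is an assumption (only _auto_train, _auto_eval, and the ValueError requirement are confirmed by the diff):

    # Sketch: the FITTING case is shown in the diff; the evaluation case is
    # inferred from the docstring ("raises ValueError if ckpt is None").
    match trainer_mode:
        case TrainerFn.FITTING:
            return cls._auto_train(ckpt)
        case TrainerFn.VALIDATING | TrainerFn.TESTING | TrainerFn.PREDICTING:
            if ckpt is None:
                raise ValueError(f"ckpt must be specified when trainer_mode={trainer_mode}")
            return cls._auto_eval(ckpt)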
--- a/nshtrainer/callbacks/__init__.py
+++ b/nshtrainer/callbacks/__init__.py
@@ -2,7 +2,18 @@ from typing import Annotated
 
 import nshconfig as C
 
+from . import checkpoint as checkpoint
 from .base import CallbackConfigBase as CallbackConfigBase
+from .checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
+from .checkpoint import (
+    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
+)
+from .checkpoint import ModelCheckpoint as ModelCheckpoint
+from .checkpoint import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
+from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+from .checkpoint import (
+    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+)
 from .early_stopping import EarlyStopping as EarlyStopping
 from .ema import EMA as EMA
 from .ema import EMAConfig as EMAConfig
@@ -13,21 +24,9 @@ from .gradient_skipping import GradientSkippingConfig as GradientSkippingConfig
 from .interval import EpochIntervalCallback as EpochIntervalCallback
 from .interval import IntervalCallback as IntervalCallback
 from .interval import StepIntervalCallback as StepIntervalCallback
-from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
-from .latest_epoch_checkpoint import (
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
 from .log_epoch import LogEpochCallback as LogEpochCallback
-from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
-from .model_checkpoint import (
-    ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
-)
 from .norm_logging import NormLoggingCallback as NormLoggingCallback
 from .norm_logging import NormLoggingConfig as NormLoggingConfig
-from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
-from .on_exception_checkpoint import (
-    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
-)
 from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
 from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
 from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
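A note on style: the `from .checkpoint import X as X` spelling used throughout this file is the conventional marker for an intentional re-export. Under PEP 484's stub semantics, which mypy (with --no-implicit-reexport) and pyright apply in strict mode, only names imported with this redundant alias count as part of the module's public API:

    # The redundant alias tells strict type checkers that ModelCheckpoint is a
    # deliberate re-export of this package, not an incidental private import.
    from .checkpoint import ModelCheckpoint as ModelCheckpoint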
--- /dev/null
+++ b/nshtrainer/callbacks/checkpoint/__init__.py
@@ -0,0 +1,12 @@
+from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
+from .latest_epoch_checkpoint import (
+    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
+)
+from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
+from .model_checkpoint import (
+    ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
+)
+from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+from .on_exception_checkpoint import (
+    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+)
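Because callbacks/__init__.py re-exports every moved name, code that imports from the package root keeps working across this release; only deep imports of the old module paths break. For example (paths per the RECORD changes at the end of this diff):

    # Works in both 0.11.1 and 0.11.2, via the package-root re-exports:
    from nshtrainer.callbacks import ModelCheckpoint
    # New canonical location in 0.11.2:
    from nshtrainer.callbacks.checkpoint import ModelCheckpoint
    # Breaks in 0.11.2 -- the module moved into the checkpoint/ subpackage:
    # from nshtrainer.callbacks.model_checkpoint import ModelCheckpoint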
--- a/nshtrainer/callbacks/latest_epoch_checkpoint.py
+++ b/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py
@@ -6,9 +6,9 @@ from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Checkpoint
 from typing_extensions import override
 
-from .._checkpoint.metadata import _sort_ckpts_by_metadata
-from .._checkpoint.saver import _link_checkpoint, _remove_checkpoint
-from .base import CallbackConfigBase
+from ..._checkpoint.metadata import _sort_ckpts_by_metadata
+from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from ..base import CallbackConfigBase
 
 log = logging.getLogger(__name__)
 
@@ -75,6 +75,10 @@ class LatestEpochCheckpoint(Checkpoint):
         if (latest_k := self.config.latest_k) == "all":
             return
 
+        # NOTE: We add 1 to the latest_k here because
+        # we're about to save a new checkpoint.
+        latest_k += 1
+
         # Get all configs, ignoring the latest symlink
         ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
         # Ignore the latest symlink
@@ -90,8 +94,7 @@ class LatestEpochCheckpoint(Checkpoint):
         )
 
         # Remove all but the latest k checkpoints
-        ckpts_to_remove = ckpt_paths[:-latest_k]
-        self._remove_checkpoints(trainer, ckpts_to_remove)
+        self._remove_checkpoints(trainer, ckpt_paths[:-latest_k])
 
     def _save_new_checkpoint(self, trainer: Trainer):
         # Remove old checkpoints
@@ -113,4 +116,4 @@ class LatestEpochCheckpoint(Checkpoint):
             barrier=True,
             metadata=True,
         )
-        log.info(f"Created latest symlink: {symlink_path}")
+        log.debug(f"Created latest symlink: {symlink_path}")
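The `latest_k += 1` change is easiest to see with concrete numbers. The prune runs at the top of `_save_new_checkpoint`, just before a new checkpoint is written, so the budget is bumped by one to account for the file that is about to appear. A toy sketch of the slice arithmetic (file names invented for illustration, and assuming `_sort_ckpts_by_metadata` orders oldest-to-newest):

    ckpt_paths = ["epoch=1.ckpt", "epoch=2.ckpt", "epoch=3.ckpt", "epoch=4.ckpt"]
    latest_k = 3
    latest_k += 1  # reserve a slot for the checkpoint about to be saved
    print(ckpt_paths[:-latest_k])  # removal set is now [] instead of ['epoch=1.ckpt']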
--- a/nshtrainer/callbacks/model_checkpoint.py
+++ b/nshtrainer/callbacks/checkpoint/model_checkpoint.py
@@ -10,12 +10,13 @@ from lightning.pytorch.callbacks.model_checkpoint import (
 )
 from typing_extensions import override
 
-from .._checkpoint.saver import _link_checkpoint, _remove_checkpoint
-from ..metrics import MetricConfig
-from .base import CallbackConfigBase
+from ..._checkpoint.saver import _link_checkpoint
+from ..._checkpoint.saver import _remove_checkpoint as _ckpt_saver_remove_checkpoint
+from ...metrics import MetricConfig
+from ..base import CallbackConfigBase
 
 if TYPE_CHECKING:
-    from ..model.config import BaseConfig
+    from ...model.config import BaseConfig
 
 log = logging.getLogger(__name__)
 
@@ -74,10 +75,10 @@ class ModelCheckpointCallbackConfig(CallbackConfigBase):
     If "link", creates a symbolic link to the last checkpoint.
     """
 
-    save_top_k: int = 1
+    save_top_k: int | Literal["all"] = 1
     """
     Number of best models to save.
-    If -1, all models are saved.
+    If "all" or -1, all models are saved.
     If 0, no models are saved.
     """
 
@@ -158,6 +159,11 @@ class ModelCheckpointCallbackConfig(CallbackConfigBase):
             metric=metric,
         )
 
+    def _save_top_k_model_ckpt_input(self):
+        if self.save_top_k == "all":
+            return -1
+        return self.save_top_k
+
 
 class ModelCheckpoint(_ModelCheckpoint):
     CHECKPOINT_NAME_LAST = "best"
@@ -180,7 +186,7 @@ class ModelCheckpoint(_ModelCheckpoint):
             mode=metric.mode,
             verbose=self.config.verbose,
             save_last=self.config.save_last,
-            save_top_k=self.config.save_top_k,
+            save_top_k=self.config._save_top_k_model_ckpt_input(),
             save_weights_only=self.config.save_weights_only,
             auto_insert_metric_name=False,
             every_n_train_steps=self.config.every_n_train_steps,
@@ -202,4 +208,9 @@ class ModelCheckpoint(_ModelCheckpoint):
 
     @override
     def _remove_checkpoint(self, trainer: Trainer, filepath: str):
-        return _remove_checkpoint(trainer, filepath, metadata=True, barrier=False)
+        return _ckpt_saver_remove_checkpoint(
+            trainer,
+            filepath,
+            metadata=True,
+            barrier=False,
+        )
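Two small things happen at once in this file. First, the imported saver function is renamed to `_ckpt_saver_remove_checkpoint` so it no longer collides with the `_remove_checkpoint` method being overridden a few lines below. Second, the friendlier `save_top_k="all"` literal is translated back to the `-1` sentinel that Lightning's ModelCheckpoint expects, right at the boundary where the config is handed to the parent class. A quick sketch of that round trip (assuming the config's other fields have usable defaults):

    config = ModelCheckpointCallbackConfig(save_top_k="all")
    assert config._save_top_k_model_ckpt_input() == -1  # Lightning's "keep everything"

    config = ModelCheckpointCallbackConfig(save_top_k=3)
    assert config._save_top_k_model_ckpt_input() == 3  # plain ints pass through unchanged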
--- a/nshtrainer/callbacks/on_exception_checkpoint.py
+++ b/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py
@@ -9,7 +9,7 @@ from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import OnExceptionCheckpoint as _OnExceptionCheckpoint
 from typing_extensions import override
 
-from .base import CallbackConfigBase
+from ..base import CallbackConfigBase
 
 log = logging.getLogger(__name__)
 
@@ -53,8 +53,6 @@ class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
 
     @override
    def create_callbacks(self, root_config):
-        from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint
-
         dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
             root_config.id, "checkpoint"
         )
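With the callback and its config now living in the same module, the deferred import inside create_callbacks is dead weight: `OnExceptionCheckpoint` resolves from module scope. (The inline import was presumably there to sidestep an import cycle through the callbacks package; that is an inference, not something the diff states.) Roughly:

    # After the move, no inline import is needed (sketch; constructor call elided):
    @override
    def create_callbacks(self, root_config):
        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
            root_config.id, "checkpoint"
        )
        ...  # construct OnExceptionCheckpoint directly from module scope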
--- a/nshtrainer/trainer/checkpoint_connector.py
+++ b/nshtrainer/trainer/checkpoint_connector.py
@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, cast
 
 from lightning.pytorch.trainer.connectors.checkpoint_connector import (
-    _CheckpointConnector,
+    _CheckpointConnector as _LightningCheckpointConnector,
 )
 from lightning.pytorch.trainer.states import TrainerFn
 from typing_extensions import override
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)
 
 
-class CheckpointConnector(_CheckpointConnector):
+class _CheckpointConnector(_LightningCheckpointConnector):
     def __resolve_auto_ckpt_path(
         self,
         ckpt_path: str | Path | None,
--- a/nshtrainer/trainer/trainer.py
+++ b/nshtrainer/trainer/trainer.py
@@ -26,6 +26,7 @@ from ..model.config import (
     StrategyConfigProtocol,
 )
 from ._runtime_callback import RuntimeTrackerCallback, Stage
+from .checkpoint_connector import _CheckpointConnector
 from .signal_connector import _SignalConnector
 
 log = logging.getLogger(__name__)
@@ -297,6 +298,9 @@ class Trainer(LightningTrainer):
         # Replace the signal connector with our own.
         self._signal_connector = _SignalConnector(self)
 
+        # Replace the checkpoint connector with our own.
+        self._checkpoint_connector = _CheckpointConnector(self)
+
         # Print out the log dir, so that we can easily find it in the logs.
         if log_dir := self.log_dir:
             log_dir = str(Path(log_dir).resolve())
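Taken together with the checkpoint_connector.py hunks above, this completes a subclass-and-swap of a Lightning internal, mirroring what the trainer already does for the signal connector one line earlier. In outline (a sketch of the pattern only; the body of `__resolve_auto_ckpt_path` is not shown in this diff, and its tie-in to `CheckpointLoadingConfig.auto` from loader.py is an assumption):

    from lightning.pytorch.trainer.connectors.checkpoint_connector import (
        _CheckpointConnector as _LightningCheckpointConnector,
    )

    class _CheckpointConnector(_LightningCheckpointConnector):
        # Custom resolution of "auto" checkpoint paths hooks in here.
        ...

    # Inside Trainer.__init__, after the Lightning base class has built its own:
    self._checkpoint_connector = _CheckpointConnector(self)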
--- a/nshtrainer-0.11.1.dist-info/METADATA
+++ b/nshtrainer-0.11.2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.11.1
+Version: 0.11.2
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
--- a/nshtrainer-0.11.1.dist-info/RECORD
+++ b/nshtrainer-0.11.2.dist-info/RECORD
@@ -1,25 +1,26 @@
 nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
-nshtrainer/_checkpoint/loader.py,sha256=48flPr1XgQHOgIPaCrRqOEvRuG0SZuV3cQ1vgHLqFqI,11025
+nshtrainer/_checkpoint/loader.py,sha256=_3jBf-k-fJCFfmU8wjDwbnE9rb4WoKYEyQiKGsBOCi4,13777
 nshtrainer/_checkpoint/metadata.py,sha256=3yxGxHLIVwKh5K4L8LYOEK3GQ6HQXy89CGcy9zarApo,5583
 nshtrainer/_checkpoint/saver.py,sha256=z_c7a91O4Bh4lZZjqJgxT3w25qFlJsOopV3cpJtkHk8,1655
 nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
 nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
 nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
 nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
-nshtrainer/callbacks/__init__.py,sha256=ifXQRwtccznl4lMKwKLSuuAQC4bKFBgfzQ4rx9gOqjE,2345
+nshtrainer/callbacks/__init__.py,sha256=4WxCc0KwWJRxgwiDo95S8awd8E2NuLAB0EMP2CYkFoQ,2311
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
+nshtrainer/callbacks/checkpoint/__init__.py,sha256=7-vcG0RgLyjZmvVcglFkzc026OR-49VGl9eAouKBSyo,577
+nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=7iCLw2Bi8js-05xIOQXFRy4TAjig5Y46UB7V-8eQsOs,4306
+nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=mLFMbNzeMiBer3BCb7o3ucswKpOCQlYyN3wdB92N-LY,6884
+nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
 nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
 nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
 nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
 nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
-nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=5JC-JCdgWNnunl0jv4Q9LhkEspLAn0x8VpCMJZi7-ow,4219
 nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
-nshtrainer/callbacks/model_checkpoint.py,sha256=8D0wWLhr_KiksAA1fjfIuby42Mq6XokCvAnVUhjADd8,6564
 nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
-nshtrainer/callbacks/on_exception_checkpoint.py,sha256=x42BYZ2ejf2rhqPLCmT5nyWKhA9qBEosiV8ZNhhZ6lI,3355
 nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
 nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
 nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
@@ -70,15 +71,15 @@ nshtrainer/runner.py,sha256=6qfE5FBONzD79kVHuWYKEvK0J_Qi5dMBbHQhRMmnIhE,3649
 nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwBfBeWeA,1467
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
-nshtrainer/trainer/checkpoint_connector.py,sha256=xoqI2dcPnlNFPPLVIU6dBOvRPC9PtfX5qu__xV1lx0Y,2124
+nshtrainer/trainer/checkpoint_connector.py,sha256=F2tkHogbMAa5U7335sm77sZBkjEDa5v46XbJCH9Mg6c,2167
 nshtrainer/trainer/signal_connector.py,sha256=llwc8pdKAWxREFpjdi14Bpy8rGVMEJsmJx_s2p4gI8E,10689
-nshtrainer/trainer/trainer.py,sha256=MrSG83TC1woQ-NqzxcWUerJ3JoFi_gOTh2IMnjNO65Y,16920
+nshtrainer/trainer/trainer.py,sha256=IHEtuDVVBradVQOKSP9zYAalkn2sguXUZixzvS8P4UY,17097
 nshtrainer/util/_environment_info.py,sha256=yPtAbgjCY4tkvh5wp9sjNsF0Z45TYwzEAM_N2_b5BbY,23123
 nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
 nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.11.1.dist-info/METADATA,sha256=lnInZUp-YIr3dp53nyGDQSRFFB2ecLYbYcb_vydhvUs,860
-nshtrainer-0.11.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.11.1.dist-info/RECORD,,
+nshtrainer-0.11.2.dist-info/METADATA,sha256=s34LitkStDa3ixSvsXsw7jXjKaIZ3CuGFnC4Z47tcuk,860
+nshtrainer-0.11.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.11.2.dist-info/RECORD,,