nshtrainer 0.11.5__py3-none-any.whl → 0.11.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -72,6 +72,8 @@ class BestCheckpoint(Checkpoint):
|
|
|
72
72
|
self.metric = metric
|
|
73
73
|
self.dirpath = dirpath
|
|
74
74
|
|
|
75
|
+
self._last_global_step_saved = 0 # no need to save when no steps were taken
|
|
76
|
+
|
|
75
77
|
@override
|
|
76
78
|
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
|
|
77
79
|
self._save_best_checkpoint(trainer)
|
|
@@ -127,6 +129,11 @@ class BestCheckpoint(Checkpoint):
|
|
|
127
129
|
log.debug(f"Created best symlink: {symlink_path}")
|
|
128
130
|
|
|
129
131
|
def _save_best_checkpoint(self, trainer: Trainer):
|
|
132
|
+
# Skip saving the checkpoint if we're not in the fitting state
|
|
133
|
+
if self._should_skip_saving_checkpoint(trainer):
|
|
134
|
+
return
|
|
135
|
+
|
|
136
|
+
# Get the current metric value
|
|
130
137
|
if (current := self._get_metric_value(trainer.callback_metrics)) is None:
|
|
131
138
|
log.warning(
|
|
132
139
|
f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
|
|
@@ -168,5 +175,22 @@ class BestCheckpoint(Checkpoint):
|
|
|
168
175
|
_, best_ckpt_path = sorted_ckpts[0]
|
|
169
176
|
self._create_symlink(trainer, best_ckpt_path)
|
|
170
177
|
|
|
178
|
+
# Update the last global step saved
|
|
179
|
+
self._last_global_step_saved = trainer.global_step
|
|
180
|
+
|
|
171
181
|
# Barrier to ensure all processes have saved the checkpoint before continuing
|
|
172
182
|
trainer.strategy.barrier()
|
|
183
|
+
|
|
184
|
+
def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
|
|
185
|
+
from lightning.pytorch.trainer.states import TrainerFn
|
|
186
|
+
|
|
187
|
+
return (
|
|
188
|
+
bool(
|
|
189
|
+
getattr(trainer, "fast_dev_run", False)
|
|
190
|
+
) # disable checkpointing with fast_dev_run
|
|
191
|
+
or trainer.state.fn
|
|
192
|
+
!= TrainerFn.FITTING # don't save anything during non-fit
|
|
193
|
+
or trainer.sanity_checking # don't save anything during sanity check
|
|
194
|
+
or self._last_global_step_saved
|
|
195
|
+
== trainer.global_step # already saved at the last step
|
|
196
|
+
)
|
|
@@ -11,7 +11,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
|
|
|
11
11
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
12
12
|
nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
|
|
13
13
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=zrEVCGFikfkt0iOMceOFzXsZG2-6QrqY79RKBCS7bu4,738
|
|
14
|
-
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=
|
|
14
|
+
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=Ygblsf9NdLHxQPJUM47W0nGxlabj-ZnEBIMpvk-QMS8,7124
|
|
15
15
|
nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=NES-acaslPBiZQIMAdk_YwtnBrkm_y_BJQ8Ian0UKP0,4294
|
|
16
16
|
nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=mLFMbNzeMiBer3BCb7o3ucswKpOCQlYyN3wdB92N-LY,6884
|
|
17
17
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
|
|
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
82
82
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
83
83
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
84
84
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
85
|
-
nshtrainer-0.11.
|
|
86
|
-
nshtrainer-0.11.
|
|
87
|
-
nshtrainer-0.11.
|
|
85
|
+
nshtrainer-0.11.6.dist-info/METADATA,sha256=tHGQ69o-paHEvlLUgo46bWeMlvuTlb8Q-upA00NxoKE,860
|
|
86
|
+
nshtrainer-0.11.6.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
87
|
+
nshtrainer-0.11.6.dist-info/RECORD,,
|
|
File without changes
|