nshtrainer 0.11.4__py3-none-any.whl → 0.11.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -72,6 +72,8 @@ class BestCheckpoint(Checkpoint):
|
|
|
72
72
|
self.metric = metric
|
|
73
73
|
self.dirpath = dirpath
|
|
74
74
|
|
|
75
|
+
self._last_global_step_saved = 0 # no need to save when no steps were taken
|
|
76
|
+
|
|
75
77
|
@override
|
|
76
78
|
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
|
|
77
79
|
self._save_best_checkpoint(trainer)
|
|
@@ -127,6 +129,11 @@ class BestCheckpoint(Checkpoint):
|
|
|
127
129
|
log.debug(f"Created best symlink: {symlink_path}")
|
|
128
130
|
|
|
129
131
|
def _save_best_checkpoint(self, trainer: Trainer):
|
|
132
|
+
# Skip saving the checkpoint if we're not in the fitting state
|
|
133
|
+
if self._should_skip_saving_checkpoint(trainer):
|
|
134
|
+
return
|
|
135
|
+
|
|
136
|
+
# Get the current metric value
|
|
130
137
|
if (current := self._get_metric_value(trainer.callback_metrics)) is None:
|
|
131
138
|
log.warning(
|
|
132
139
|
f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
|
|
@@ -150,6 +157,7 @@ class BestCheckpoint(Checkpoint):
|
|
|
150
157
|
# Save the current model
|
|
151
158
|
filepath = self._ckpt_path(trainer)
|
|
152
159
|
trainer.save_checkpoint(filepath, self.config.save_weights_only)
|
|
160
|
+
log.debug(f"Saved best checkpoint: {filepath}")
|
|
153
161
|
|
|
154
162
|
# Remove worst checkpoint if we've reached save_top_k
|
|
155
163
|
# NOTE: We add 1 to save_top_k here because we have just saved a new checkpoint
|
|
@@ -163,9 +171,26 @@ class BestCheckpoint(Checkpoint):
|
|
|
163
171
|
)
|
|
164
172
|
|
|
165
173
|
# Create symlink to best model
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
174
|
+
if sorted_ckpts:
|
|
175
|
+
_, best_ckpt_path = sorted_ckpts[0]
|
|
176
|
+
self._create_symlink(trainer, best_ckpt_path)
|
|
177
|
+
|
|
178
|
+
# Update the last global step saved
|
|
179
|
+
self._last_global_step_saved = trainer.global_step
|
|
169
180
|
|
|
170
181
|
# Barrier to ensure all processes have saved the checkpoint before continuing
|
|
171
182
|
trainer.strategy.barrier()
|
|
183
|
+
|
|
184
|
+
def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
|
|
185
|
+
from lightning.pytorch.trainer.states import TrainerFn
|
|
186
|
+
|
|
187
|
+
return (
|
|
188
|
+
bool(
|
|
189
|
+
getattr(trainer, "fast_dev_run", False)
|
|
190
|
+
) # disable checkpointing with fast_dev_run
|
|
191
|
+
or trainer.state.fn
|
|
192
|
+
!= TrainerFn.FITTING # don't save anything during non-fit
|
|
193
|
+
or trainer.sanity_checking # don't save anything during sanity check
|
|
194
|
+
or self._last_global_step_saved
|
|
195
|
+
== trainer.global_step # already saved at the last step
|
|
196
|
+
)
|
|
@@ -11,7 +11,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
|
|
|
11
11
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
12
12
|
nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
|
|
13
13
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=zrEVCGFikfkt0iOMceOFzXsZG2-6QrqY79RKBCS7bu4,738
|
|
14
|
-
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=
|
|
14
|
+
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=Ygblsf9NdLHxQPJUM47W0nGxlabj-ZnEBIMpvk-QMS8,7124
|
|
15
15
|
nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=NES-acaslPBiZQIMAdk_YwtnBrkm_y_BJQ8Ian0UKP0,4294
|
|
16
16
|
nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=mLFMbNzeMiBer3BCb7o3ucswKpOCQlYyN3wdB92N-LY,6884
|
|
17
17
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
|
|
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
82
82
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
83
83
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
84
84
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
85
|
-
nshtrainer-0.11.
|
|
86
|
-
nshtrainer-0.11.
|
|
87
|
-
nshtrainer-0.11.4.dist-info/RECORD,,
|
|
85
|
+
nshtrainer-0.11.6.dist-info/METADATA,sha256=tHGQ69o-paHEvlLUgo46bWeMlvuTlb8Q-upA00NxoKE,860
|
|
86
|
+
nshtrainer-0.11.6.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
87
|
+
nshtrainer-0.11.6.dist-info/RECORD,,
|
|
File without changes
|