nshtrainer 0.11.5__py3-none-any.whl → 0.11.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -72,6 +72,8 @@ class BestCheckpoint(Checkpoint):
72
72
  self.metric = metric
73
73
  self.dirpath = dirpath
74
74
 
75
+ self._last_global_step_saved = 0 # no need to save when no steps were taken
76
+
75
77
  @override
76
78
  def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
77
79
  self._save_best_checkpoint(trainer)
@@ -127,6 +129,11 @@ class BestCheckpoint(Checkpoint):
127
129
  log.debug(f"Created best symlink: {symlink_path}")
128
130
 
129
131
  def _save_best_checkpoint(self, trainer: Trainer):
132
+ # Skip saving the checkpoint if we're not in the fitting state
133
+ if self._should_skip_saving_checkpoint(trainer):
134
+ return
135
+
136
+ # Get the current metric value
130
137
  if (current := self._get_metric_value(trainer.callback_metrics)) is None:
131
138
  log.warning(
132
139
  f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
@@ -168,5 +175,22 @@ class BestCheckpoint(Checkpoint):
168
175
  _, best_ckpt_path = sorted_ckpts[0]
169
176
  self._create_symlink(trainer, best_ckpt_path)
170
177
 
178
+ # Update the last global step saved
179
+ self._last_global_step_saved = trainer.global_step
180
+
171
181
  # Barrier to ensure all processes have saved the checkpoint before continuing
172
182
  trainer.strategy.barrier()
183
+
184
+ def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
185
+ from lightning.pytorch.trainer.states import TrainerFn
186
+
187
+ return (
188
+ bool(
189
+ getattr(trainer, "fast_dev_run", False)
190
+ ) # disable checkpointing with fast_dev_run
191
+ or trainer.state.fn
192
+ != TrainerFn.FITTING # don't save anything during non-fit
193
+ or trainer.sanity_checking # don't save anything during sanity check
194
+ or self._last_global_step_saved
195
+ == trainer.global_step # already saved at the last step
196
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.11.5
3
+ Version: 0.11.6
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -11,7 +11,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
11
11
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
12
12
  nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
13
13
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=zrEVCGFikfkt0iOMceOFzXsZG2-6QrqY79RKBCS7bu4,738
14
- nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=w99O5GWRcV89XBe4j__v2TvNEHys0x_r3tSTr-6Lhec,6154
14
+ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=Ygblsf9NdLHxQPJUM47W0nGxlabj-ZnEBIMpvk-QMS8,7124
15
15
  nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=NES-acaslPBiZQIMAdk_YwtnBrkm_y_BJQ8Ian0UKP0,4294
16
16
  nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=mLFMbNzeMiBer3BCb7o3ucswKpOCQlYyN3wdB92N-LY,6884
17
17
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
82
82
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
83
83
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
84
84
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
85
- nshtrainer-0.11.5.dist-info/METADATA,sha256=KHgvYOhQXbc37awWeLbpbdVQbSEU4J7KoC7Lr5286KE,860
86
- nshtrainer-0.11.5.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
87
- nshtrainer-0.11.5.dist-info/RECORD,,
85
+ nshtrainer-0.11.6.dist-info/METADATA,sha256=tHGQ69o-paHEvlLUgo46bWeMlvuTlb8Q-upA00NxoKE,860
86
+ nshtrainer-0.11.6.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
87
+ nshtrainer-0.11.6.dist-info/RECORD,,