nshtrainer 0.11.12__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -79,6 +79,9 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
79
79
  @abstractmethod
80
80
  def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
81
81
 
82
+ @abstractmethod
83
+ def topk_sort_reverse(self) -> bool: ...
84
+
82
85
  def symlink_path(self):
83
86
  if not self.config.save_symlink:
84
87
  return None
@@ -102,7 +105,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
102
105
  ]
103
106
 
104
107
  # Sort by the topk sort key
105
- metas = sorted(metas, key=self.topk_sort_key)
108
+ metas = sorted(metas, key=self.topk_sort_key, reverse=self.topk_sort_reverse())
106
109
 
107
110
  # Now, the metas are sorted from the best to the worst,
108
111
  # so we can remove the worst checkpoints
@@ -145,15 +148,15 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
145
148
  trainer.save_checkpoint(filepath, self.config.save_weights_only)
146
149
 
147
150
  if trainer.is_global_zero:
148
- # Remove old checkpoints
149
- self.remove_old_checkpoints(trainer)
150
-
151
151
  # Create the latest symlink
152
152
  if (symlink_filename := self.symlink_path()) is not None:
153
153
  symlink_path = self.dirpath / symlink_filename
154
154
  _link_checkpoint(filepath, symlink_path, metadata=True)
155
155
  log.debug(f"Created latest symlink: {symlink_path}")
156
156
 
157
+ # Remove old checkpoints
158
+ self.remove_old_checkpoints(trainer)
159
+
157
160
  # Barrier to ensure all processes have saved the checkpoint,
158
161
  # deleted the old checkpoints, and created the symlink before continuing
159
162
  trainer.strategy.barrier()
@@ -55,7 +55,7 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
55
55
 
56
56
  @override
57
57
  def default_filename(self):
58
- return f"epoch{{epoch:03d}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
58
+ return f"epoch{{epoch}}-step{{step}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
59
59
 
60
60
  @override
61
61
  def topk_sort_key(self, metadata: CheckpointMetadata):
@@ -64,6 +64,10 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
64
64
  float("-inf" if self.metric.mode == "max" else "inf"),
65
65
  )
66
66
 
67
+ @override
68
+ def topk_sort_reverse(self):
69
+ return self.metric.mode == "max"
70
+
67
71
  # Events
68
72
  @override
69
73
  def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
@@ -28,12 +28,16 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
28
28
 
29
29
  @override
30
30
  def default_filename(self):
31
- return "epoch{epoch:03d}-step{step:07d}"
31
+ return "epoch{epoch}-step{step}"
32
32
 
33
33
  @override
34
34
  def topk_sort_key(self, metadata: CheckpointMetadata):
35
35
  return metadata.checkpoint_timestamp
36
36
 
37
+ @override
38
+ def topk_sort_reverse(self):
39
+ return True
40
+
37
41
  @override
38
42
  def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
39
43
  self.save_checkpoints(trainer)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.11.12
3
+ Version: 0.12.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -11,9 +11,9 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
11
11
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
12
12
  nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
13
13
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
14
- nshtrainer/callbacks/checkpoint/_base.py,sha256=wb5ARqqSslZXB5FSwNd9a_7qtt4-F9dC7E1vNHKCK1o,5994
15
- nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=ySeyALxc-YJaS5IhmW0MkAhr41Mxsm_BllHuufXGy4Y,2067
16
- nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CM8f37dwaYHkjQFfJNTZTzSoF45zEjFRm-Fg1CzYmP4,1037
14
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=yBUZ63P4a3m5jviIUGxBd2WXq_TigpKt3w4KdzqrzLs,6094
15
+ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
16
+ nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
17
17
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
18
18
  nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
19
19
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
82
82
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
83
83
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
84
84
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
85
- nshtrainer-0.11.12.dist-info/METADATA,sha256=Lthmwj2EpbkCNWlDg2n7z4zS197Pn5Kt1aUMayB70oo,861
86
- nshtrainer-0.11.12.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
87
- nshtrainer-0.11.12.dist-info/RECORD,,
85
+ nshtrainer-0.12.0.dist-info/METADATA,sha256=YKxr6_dhF7jHF1njUUJi0CZRR0BuOPDNrQiwiIO7iTU,860
86
+ nshtrainer-0.12.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
87
+ nshtrainer-0.12.0.dist-info/RECORD,,