nshtrainer-0.11.12-py3-none-any.whl → nshtrainer-0.11.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/checkpoint/_base.py +7 -4
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -0
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -0
- {nshtrainer-0.11.12.dist-info → nshtrainer-0.11.13.dist-info}/METADATA +1 -1
- {nshtrainer-0.11.12.dist-info → nshtrainer-0.11.13.dist-info}/RECORD +6 -6
- {nshtrainer-0.11.12.dist-info → nshtrainer-0.11.13.dist-info}/WHEEL +0 -0
|
@@ -79,6 +79,9 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
79
79
|
@abstractmethod
|
|
80
80
|
def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
|
|
81
81
|
|
|
82
|
+
@abstractmethod
|
|
83
|
+
def topk_sort_reverse(self) -> bool: ...
|
|
84
|
+
|
|
82
85
|
def symlink_path(self):
|
|
83
86
|
if not self.config.save_symlink:
|
|
84
87
|
return None
|
|
@@ -102,7 +105,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
102
105
|
]
|
|
103
106
|
|
|
104
107
|
# Sort by the topk sort key
|
|
105
|
-
metas = sorted(metas, key=self.topk_sort_key)
|
|
108
|
+
metas = sorted(metas, key=self.topk_sort_key, reverse=self.topk_sort_reverse())
|
|
106
109
|
|
|
107
110
|
# Now, the metas are sorted from the best to the worst,
|
|
108
111
|
# so we can remove the worst checkpoints
|
|
@@ -145,15 +148,15 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
145
148
|
trainer.save_checkpoint(filepath, self.config.save_weights_only)
|
|
146
149
|
|
|
147
150
|
if trainer.is_global_zero:
|
|
148
|
-
# Remove old checkpoints
|
|
149
|
-
self.remove_old_checkpoints(trainer)
|
|
150
|
-
|
|
151
151
|
# Create the latest symlink
|
|
152
152
|
if (symlink_filename := self.symlink_path()) is not None:
|
|
153
153
|
symlink_path = self.dirpath / symlink_filename
|
|
154
154
|
_link_checkpoint(filepath, symlink_path, metadata=True)
|
|
155
155
|
log.debug(f"Created latest symlink: {symlink_path}")
|
|
156
156
|
|
|
157
|
+
# Remove old checkpoints
|
|
158
|
+
self.remove_old_checkpoints(trainer)
|
|
159
|
+
|
|
157
160
|
# Barrier to ensure all processes have saved the checkpoint,
|
|
158
161
|
# deleted the old checkpoints, and created the symlink before continuing
|
|
159
162
|
trainer.strategy.barrier()
|
|
@@ -64,6 +64,10 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
|
|
|
64
64
|
float("-inf" if self.metric.mode == "max" else "inf"),
|
|
65
65
|
)
|
|
66
66
|
|
|
67
|
+
@override
|
|
68
|
+
def topk_sort_reverse(self):
|
|
69
|
+
return self.metric.mode == "max"
|
|
70
|
+
|
|
67
71
|
# Events
|
|
68
72
|
@override
|
|
69
73
|
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
|
|
@@ -34,6 +34,10 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
|
|
|
34
34
|
def topk_sort_key(self, metadata: CheckpointMetadata):
|
|
35
35
|
return metadata.checkpoint_timestamp
|
|
36
36
|
|
|
37
|
+
@override
|
|
38
|
+
def topk_sort_reverse(self):
|
|
39
|
+
return True
|
|
40
|
+
|
|
37
41
|
@override
|
|
38
42
|
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
|
|
39
43
|
self.save_checkpoints(trainer)
|
|
@@ -11,9 +11,9 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
|
|
|
11
11
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
12
12
|
nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
|
|
13
13
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
|
|
14
|
-
nshtrainer/callbacks/checkpoint/_base.py,sha256=
|
|
15
|
-
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=
|
|
16
|
-
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=
|
|
14
|
+
nshtrainer/callbacks/checkpoint/_base.py,sha256=yBUZ63P4a3m5jviIUGxBd2WXq_TigpKt3w4KdzqrzLs,6094
|
|
15
|
+
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=qajV_GxeUg0GXeOtiimmPabMJnkNu_I1prZb2ksPOG8,2156
|
|
16
|
+
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=ctwl2bmHC79enpg9wi-iHWQYIkP-iQIeyEvJUUJ5AW8,1105
|
|
17
17
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
|
|
18
18
|
nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
|
|
19
19
|
nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
|
|
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
82
82
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
83
83
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
84
84
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
85
|
-
nshtrainer-0.11.
|
|
86
|
-
nshtrainer-0.11.
|
|
87
|
-
nshtrainer-0.11.
|
|
85
|
+
nshtrainer-0.11.13.dist-info/METADATA,sha256=wKFqCeZ6hxeHznkFksP3-kqF6vhG7ErudiM-auKKEJE,861
|
|
86
|
+
nshtrainer-0.11.13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
87
|
+
nshtrainer-0.11.13.dist-info/RECORD,,
|
|
File without changes
|