nshtrainer 0.11.3__py3-none-any.whl → 0.11.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -106,6 +106,26 @@ class BestCheckpoint(Checkpoint):
|
|
|
106
106
|
reverse=(self.metric.mode == "min"),
|
|
107
107
|
)
|
|
108
108
|
|
|
109
|
+
def _create_symlink(self, trainer: Trainer, best_ckpt_path: Path):
|
|
110
|
+
# Resolve the symlink filename
|
|
111
|
+
if (symlink_filename := self._best_symlink_filename()) is None:
|
|
112
|
+
return
|
|
113
|
+
|
|
114
|
+
# If the symlink already exists and points to the best checkpoint,
|
|
115
|
+
# then we don't need to create a new symlink.
|
|
116
|
+
symlink_path = self.dirpath / symlink_filename
|
|
117
|
+
if symlink_path.exists() and symlink_path.resolve() == best_ckpt_path:
|
|
118
|
+
return
|
|
119
|
+
|
|
120
|
+
_link_checkpoint(
|
|
121
|
+
trainer,
|
|
122
|
+
best_ckpt_path,
|
|
123
|
+
symlink_path,
|
|
124
|
+
metadata=True,
|
|
125
|
+
barrier=False,
|
|
126
|
+
)
|
|
127
|
+
log.debug(f"Created best symlink: {symlink_path}")
|
|
128
|
+
|
|
109
129
|
def _save_best_checkpoint(self, trainer: Trainer):
|
|
110
130
|
if (current := self._get_metric_value(trainer.callback_metrics)) is None:
|
|
111
131
|
log.warning(
|
|
@@ -143,13 +163,9 @@ class BestCheckpoint(Checkpoint):
|
|
|
143
163
|
)
|
|
144
164
|
|
|
145
165
|
# Create symlink to best model
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
barrier=True,
|
|
153
|
-
metadata=True,
|
|
154
|
-
)
|
|
155
|
-
log.debug(f"Created best symlink: {symlink_path}")
|
|
166
|
+
_, best_ckpt_path = sorted_ckpts[0]
|
|
167
|
+
self._create_symlink(trainer, best_ckpt_path)
|
|
168
|
+
log.debug(f"Saved best checkpoint: {filepath}")
|
|
169
|
+
|
|
170
|
+
# Barrier to ensure all processes have saved the checkpoint before continuing
|
|
171
|
+
trainer.strategy.barrier()
|
|
@@ -11,7 +11,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
|
|
|
11
11
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
12
12
|
nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
|
|
13
13
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=zrEVCGFikfkt0iOMceOFzXsZG2-6QrqY79RKBCS7bu4,738
|
|
14
|
-
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=
|
|
14
|
+
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=SfM0VGCSDxYutJ21tY8m283eElN8jyKh2hhj6b7-1-s,6121
|
|
15
15
|
nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=NES-acaslPBiZQIMAdk_YwtnBrkm_y_BJQ8Ian0UKP0,4294
|
|
16
16
|
nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=mLFMbNzeMiBer3BCb7o3ucswKpOCQlYyN3wdB92N-LY,6884
|
|
17
17
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
|
|
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
82
82
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
83
83
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
84
84
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
85
|
-
nshtrainer-0.11.
|
|
86
|
-
nshtrainer-0.11.
|
|
87
|
-
nshtrainer-0.11.
|
|
85
|
+
nshtrainer-0.11.4.dist-info/METADATA,sha256=6s-tc2dutPpIQaxv7j8cdZWiLU3l0Y_aS5f-k5JvAfM,860
|
|
86
|
+
nshtrainer-0.11.4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
87
|
+
nshtrainer-0.11.4.dist-info/RECORD,,
|
|
File without changes
|