nshtrainer 0.11.5__py3-none-any.whl → 0.11.7__py3-none-any.whl
This diff shows the content of publicly released versions of this package, as published to their respective public registries. It is provided for informational purposes only.
- nshtrainer/_checkpoint/saver.py +18 -28
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +42 -22
- nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +41 -29
- nshtrainer/callbacks/checkpoint/model_checkpoint.py +4 -13
- {nshtrainer-0.11.5.dist-info → nshtrainer-0.11.7.dist-info}/METADATA +1 -1
- {nshtrainer-0.11.5.dist-info → nshtrainer-0.11.7.dist-info}/RECORD +7 -7
- {nshtrainer-0.11.5.dist-info → nshtrainer-0.11.7.dist-info}/WHEEL +0 -0
nshtrainer/_checkpoint/saver.py
CHANGED
@@ -8,11 +8,9 @@ from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
 
 
 def _link_checkpoint(
-    trainer: Trainer,
     filepath: str | Path | os.PathLike,
     linkpath: str | Path | os.PathLike,
     *,
-    barrier: bool,
     metadata: bool,
 ):
     if not isinstance(filepath, Path):
@@ -20,26 +18,23 @@ def _link_checkpoint(
     if not isinstance(linkpath, Path):
         linkpath = Path(linkpath)
 
-    if trainer.is_global_zero:
-        if linkpath.exists():
-            if linkpath.is_symlink() or linkpath.is_file():
-                linkpath.unlink()
-            elif linkpath.is_dir():
-                shutil.rmtree(linkpath)
-            _remove_checkpoint_metadata(linkpath)
+    if linkpath.exists():
+        if linkpath.is_symlink() or linkpath.is_file():
+            linkpath.unlink()
+        elif linkpath.is_dir():
+            shutil.rmtree(linkpath)
+        _remove_checkpoint_metadata(linkpath)
 
-        try:
-            target_path = filepath.relative_to(linkpath.parent)
-            linkpath.symlink_to(target_path)
-        except OSError:
-            # on Windows, special permissions are required to create symbolic links as a regular user
-            # fall back to copying the file
-            shutil.copy(filepath, linkpath)
+    try:
+        target_path = filepath.relative_to(linkpath.parent)
+        linkpath.symlink_to(target_path)
+    except OSError:
+        # on Windows, special permissions are required to create symbolic links as a regular user
+        # fall back to copying the file
+        shutil.copy(filepath, linkpath)
 
-        if metadata:
-            _link_checkpoint_metadata(filepath, linkpath)
-    if barrier:
-        trainer.strategy.barrier()
+    if metadata:
+        _link_checkpoint_metadata(filepath, linkpath)
 
 
 def _remove_checkpoint(
@@ -47,15 +42,10 @@ def _remove_checkpoint(
     filepath: str | Path | os.PathLike,
     *,
     metadata: bool,
-    barrier: bool,
 ):
     if not isinstance(filepath, Path):
         filepath = Path(filepath)
 
-    if trainer.is_global_zero:
-        trainer.strategy.remove_checkpoint(filepath)
-        if metadata:
-            _remove_checkpoint_metadata(filepath)
-
-    if barrier:
-        trainer.strategy.barrier()
+    trainer.strategy.remove_checkpoint(filepath)
+    if metadata:
+        _remove_checkpoint_metadata(filepath)
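The rewritten _link_checkpoint no longer takes a trainer or a barrier flag: it replaces whatever sits at the link path, prefers a relative symlink, and degrades to a file copy where symlinks need elevated privileges. A minimal standalone sketch of that symlink-or-copy pattern (the name link_or_copy is illustrative, not nshtrainer API):

import shutil
from pathlib import Path


def link_or_copy(src: Path, dst: Path) -> None:
    # Clear whatever currently occupies the destination: a stale symlink
    # (possibly broken), a regular file, or a directory.
    if dst.is_symlink() or dst.is_file():
        dst.unlink()
    elif dst.is_dir():
        shutil.rmtree(dst)

    try:
        # A relative target keeps the link valid if the whole checkpoint
        # directory is moved; assumes src lives under dst's parent.
        dst.symlink_to(src.relative_to(dst.parent))
    except OSError:
        # On Windows, regular users usually lack symlink privileges,
        # so fall back to copying, as the diff above does.
        shutil.copy(src, dst)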
nshtrainer/callbacks/checkpoint/best_checkpoint.py
CHANGED
@@ -72,6 +72,8 @@ class BestCheckpoint(Checkpoint):
         self.metric = metric
         self.dirpath = dirpath
 
+        self._last_global_step_saved = 0  # no need to save when no steps were taken
+
     @override
     def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
         self._save_best_checkpoint(trainer)
@@ -88,10 +90,6 @@ class BestCheckpoint(Checkpoint):
         filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
         return self.dirpath / filename
 
-    def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
-        for ckpt_path in ckpt_paths:
-            _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
-
     def _get_metric_value(self, metrics: dict[str, Any]):
         return metrics.get(
             self.metric.validation_monitor,
@@ -99,11 +97,16 @@ class BestCheckpoint(Checkpoint):
         )
 
     def _sorted_ckpts(self):
+        """
+        Get sorted checkpoints by the metric value.
+
+        Sort order: best -> worst
+        """
         ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
         return _sort_ckpts_by_metadata(
             ckpt_paths,
             key=lambda meta, _: self._get_metric_value(meta.metrics),
-            reverse=(self.metric.mode == "
+            reverse=(self.metric.mode == "max"),
         )
 
     def _create_symlink(self, trainer: Trainer, best_ckpt_path: Path):
@@ -117,16 +120,15 @@ class BestCheckpoint(Checkpoint):
         if symlink_path.exists() and symlink_path.resolve() == best_ckpt_path:
             return
 
-        _link_checkpoint(
-            trainer,
-            best_ckpt_path,
-            symlink_path,
-            metadata=True,
-            barrier=False,
-        )
+        _link_checkpoint(best_ckpt_path, symlink_path, metadata=True)
         log.debug(f"Created best symlink: {symlink_path}")
 
     def _save_best_checkpoint(self, trainer: Trainer):
+        # Skip saving the checkpoint if we're not in the fitting state
+        if self._should_skip_saving_checkpoint(trainer):
+            return
+
+        # Get the current metric value
         if (current := self._get_metric_value(trainer.callback_metrics)) is None:
             log.warning(
                 f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
@@ -152,21 +154,39 @@ class BestCheckpoint(Checkpoint):
         trainer.save_checkpoint(filepath, self.config.save_weights_only)
         log.debug(f"Saved best checkpoint: {filepath}")
 
-
-        # NOTE: We add 1 to save_top_k here because we have just saved a new checkpoint
-        if len(sorted_ckpts) + 1 > self.config._save_top_k_value:
+        if trainer.is_global_zero:
             # Get the sorted checkpoints again because now we have added a new checkpoint.
             # We could optimize this by adding the new checkpoint to the sorted list,
             # and then sorting it in place, but this is simpler.
             sorted_ckpts = self._sorted_ckpts()
-            self._remove_checkpoints(
-                trainer, [p for _, p in sorted_ckpts[self.config.save_top_k :]]
-            )
 
-        # Create symlink to best model
-        if sorted_ckpts:
-            _, best_ckpt_path = sorted_ckpts[0]
-            self._create_symlink(trainer, best_ckpt_path)
+            # Remove worst checkpoint if we've reached save_top_k
+            if (topk := self.config.save_top_k) != "all" and len(sorted_ckpts) > topk:
+                # NOTE: Sort order is best -> worst. Let's get the worst checkpoints.
+                for _, ckpt_path in sorted_ckpts[topk:]:
+                    _remove_checkpoint(trainer, ckpt_path, metadata=True)
+
+            # Create symlink to best model
+            if sorted_ckpts:
+                _, best_ckpt_path = sorted_ckpts[0]
+                self._create_symlink(trainer, best_ckpt_path)
+
+        # Update the last global step saved
+        self._last_global_step_saved = trainer.global_step
 
         # Barrier to ensure all processes have saved the checkpoint before continuing
         trainer.strategy.barrier()
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or self._last_global_step_saved
+            == trainer.global_step  # already saved at the last step
+        )
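The pruning logic above depends on _sorted_ckpts returning checkpoints best-first, so everything from index save_top_k onward can be deleted. A self-contained sketch of that selection, with plain tuples standing in for nshtrainer's metadata objects (prune_to_top_k is a hypothetical name):

from pathlib import Path


def prune_to_top_k(
    ckpts: list[tuple[float, Path]],  # (metric value, checkpoint path)
    save_top_k: int | str,
    mode: str = "max",
) -> list[Path]:
    # Return the paths to delete, keeping only the best k.
    if save_top_k == "all":
        return []
    # Best -> worst, mirroring _sorted_ckpts: for a "max" metric the
    # largest value ranks first, hence reverse=True.
    ranked = sorted(ckpts, key=lambda t: t[0], reverse=(mode == "max"))
    return [path for _, path in ranked[save_top_k:]]


# Keep the 2 best checkpoints out of 4:
doomed = prune_to_top_k(
    [(0.91, Path("a.ckpt")), (0.88, Path("b.ckpt")),
     (0.95, Path("c.ckpt")), (0.80, Path("d.ckpt"))],
    save_top_k=2,
)
assert doomed == [Path("b.ckpt"), Path("d.ckpt")]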
nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py
CHANGED
@@ -51,6 +51,8 @@ class LatestEpochCheckpoint(Checkpoint):
         self.config = config
         self.dirpath = dirpath
 
+        self._last_global_step_saved = 0
+
     @override
     def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
         self._save_new_checkpoint(trainer)
@@ -67,53 +69,63 @@ class LatestEpochCheckpoint(Checkpoint):
         filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
         return self.dirpath / filename
 
-    def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
-        for ckpt_path in ckpt_paths:
-            _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
-
     def _remove_old_checkpoints(self, trainer: Trainer):
         if (latest_k := self.config.latest_k) == "all":
             return
 
-        # NOTE: We add 1 to the latest_k here because
-        # we're about to save a new checkpoint.
-        latest_k += 1
-
         # Get all configs, ignoring the latest symlink
-        ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
+        ckpts = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
         # Ignore the latest symlink
         if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
-            ckpt_paths = [p for p in ckpt_paths if p.name != latest_symlink_filename]
+            ckpts = [p for p in ckpts if p.name != latest_symlink_filename]
 
         # Sort by epoch, then step, then last modified
-        ckpt_paths = _sort_ckpts_by_metadata(
-            ckpt_paths,
+        ckpts = _sort_ckpts_by_metadata(
+            ckpts,
             key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
             reverse=True,
         )
 
         # Remove all but the latest k checkpoints
-        self._remove_checkpoints(trainer, [p for _, p in ckpt_paths[latest_k:]])
-
+        # NOTE: We add 1 to the latest_k here because
+        # we're about to save a new checkpoint.
+        for _, ckpt_path in ckpts[latest_k:]:
+            _remove_checkpoint(trainer, ckpt_path, metadata=True)
 
     def _save_new_checkpoint(self, trainer: Trainer):
-        # Remove old checkpoints
-        if trainer.is_global_zero:
-            self._remove_old_checkpoints(trainer)
-        trainer.strategy.barrier()
+        if self._should_skip_saving_checkpoint(trainer):
+            return
 
         # Save the new checkpoint
         filepath = self._ckpt_path(trainer)
         trainer.save_checkpoint(filepath, self.config.save_weights_only)
 
-        # Create the latest symlink
-        if (symlink_filename := self._latest_symlink_filename()) is not None:
-            symlink_path = self.dirpath / symlink_filename
-            _link_checkpoint(
-                trainer,
-                filepath,
-                symlink_path,
-                metadata=True,
-                barrier=False,
-            )
-            log.debug(f"Created latest symlink: {symlink_path}")
+        if trainer.is_global_zero:
+            # Remove old checkpoints
+            self._remove_old_checkpoints(trainer)
+
+            # Create the latest symlink
+            if (symlink_filename := self._latest_symlink_filename()) is not None:
+                symlink_path = self.dirpath / symlink_filename
+                _link_checkpoint(filepath, symlink_path, metadata=True)
+                log.debug(f"Created latest symlink: {symlink_path}")
+
+        # Set the last global step saved
+        self._last_global_step_saved = trainer.global_step
+
+        # Barrier to ensure all processes have saved the checkpoint before continuing
+        trainer.strategy.barrier()
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or self._last_global_step_saved
+            == trainer.global_step  # already saved at the last step
+        )
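_remove_old_checkpoints now deletes in-line instead of going through the removed _remove_checkpoints helper, keeping only the newest latest_k files ranked by (epoch, global_step, mtime). A sketch of that ordering with a stand-in metadata record (Meta and latest_k_survivors are made-up names):

from dataclasses import dataclass
from pathlib import Path


@dataclass
class Meta:
    epoch: int
    global_step: int
    mtime: float  # stand-in for path.stat().st_mtime


def latest_k_survivors(ckpts: list[tuple[Meta, Path]], latest_k: int) -> list[Path]:
    # mtime is the tie-breaker when two checkpoints share an epoch and
    # step, e.g. when a resumed run re-saves the same epoch.
    ordered = sorted(
        ckpts,
        key=lambda t: (t[0].epoch, t[0].global_step, t[0].mtime),
        reverse=True,  # newest first
    )
    return [path for _, path in ordered[:latest_k]]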
nshtrainer/callbacks/checkpoint/model_checkpoint.py
CHANGED
@@ -198,19 +198,10 @@ class ModelCheckpoint(_ModelCheckpoint):
 
     @override
     def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str):  # pyright: ignore[reportIncompatibleMethodOverride]
-        _link_checkpoint(
-            trainer,
-            filepath,
-            linkpath,
-            barrier=True,
-            metadata=True,
-        )
+        if trainer.is_global_zero:
+            _link_checkpoint(filepath, linkpath, metadata=True)
+        trainer.strategy.barrier()
 
     @override
     def _remove_checkpoint(self, trainer: Trainer, filepath: str):
-        _ckpt_saver_remove_checkpoint(
-            trainer,
-            filepath,
-            metadata=True,
-            barrier=False,
-        )
+        _ckpt_saver_remove_checkpoint(trainer, filepath, metadata=True)
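Every mutating path above now follows the same discipline: touch the filesystem on rank zero only, then barrier so no rank races ahead against a half-updated checkpoint directory. A hedged sketch of that pattern with raw torch.distributed (nshtrainer itself routes this through trainer.strategy; rank_zero_then_barrier is an illustrative name):

import torch.distributed as dist


def rank_zero_then_barrier(fn, *args, **kwargs):
    # Single-process runs have no process group; just call fn directly.
    is_distributed = dist.is_available() and dist.is_initialized()
    if not is_distributed or dist.get_rank() == 0:
        fn(*args, **kwargs)  # e.g. create the symlink or delete a file
    if is_distributed:
        # Every rank must reach the barrier, so it sits outside the rank-zero
        # guard, for the same reason trainer.strategy.barrier() stays
        # unconditional in the diffs above.
        dist.barrier()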
{nshtrainer-0.11.5.dist-info → nshtrainer-0.11.7.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
 nshtrainer/_checkpoint/loader.py,sha256=_3jBf-k-fJCFfmU8wjDwbnE9rb4WoKYEyQiKGsBOCi4,13777
 nshtrainer/_checkpoint/metadata.py,sha256=M9eAZ2xMs36Z1G1xULu9MHZhsHxN8_9mNt3Iv7wuq-I,5069
-nshtrainer/_checkpoint/saver.py,sha256=
+nshtrainer/_checkpoint/saver.py,sha256=TuSAP39DOOVvSnSukQ9RitMV60JnDg6L27fMRc2uVJc,1358
 nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
 nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
 nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
@@ -11,9 +11,9 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=zrEVCGFikfkt0iOMceOFzXsZG2-6QrqY79RKBCS7bu4,738
-nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=
-nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=
-nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=
+nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=BUO0sWqlwfyxD1UeII5DZ-01SGLiawJAEsL8HjGX4XA,7018
+nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=CQ0IqhuPI7zAxpQLy48kK8qqfVfwXEJoHGRqI4h8xNk,4819
+nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=JS1z2YuEiQxk61HgZU1jySzF_pzdfXYO54_qHo-q3CQ,6776
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
 nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.11.5.dist-info/METADATA,sha256=
-nshtrainer-0.11.5.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.11.5.dist-info/RECORD,,
+nshtrainer-0.11.7.dist-info/METADATA,sha256=htPbfKNDbqr1taf0bEvSbl-hQOaPfjKkJhvrWoMz2r0,860
+nshtrainer-0.11.7.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.11.7.dist-info/RECORD,,
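A RECORD line has the form path,sha256=<digest>,size, where the digest is the urlsafe base64 of the raw SHA-256 with the trailing = padding stripped, per the wheel spec. A small helper to recompute an entry's hash (the example path is illustrative):

import base64
import hashlib
from pathlib import Path


def record_hash(path: Path) -> str:
    # Encode the digest the way RECORD does: urlsafe base64, no padding.
    digest = hashlib.sha256(path.read_bytes()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


# Against an unpacked 0.11.7 wheel, this should reproduce the entry above:
# record_hash(Path("nshtrainer/_checkpoint/saver.py"))
# -> "sha256=TuSAP39DOOVvSnSukQ9RitMV60JnDg6L27fMRc2uVJc"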
{nshtrainer-0.11.5.dist-info → nshtrainer-0.11.7.dist-info}/WHEEL
File without changes