nshtrainer-0.11.5-py3-none-any.whl → nshtrainer-0.11.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/nshtrainer/_checkpoint/saver.py
+++ b/nshtrainer/_checkpoint/saver.py
@@ -8,11 +8,9 @@ from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
 
 
 def _link_checkpoint(
-    trainer: Trainer,
     filepath: str | Path | os.PathLike,
     linkpath: str | Path | os.PathLike,
     *,
-    barrier: bool,
     metadata: bool,
 ):
     if not isinstance(filepath, Path):
@@ -20,26 +18,23 @@ def _link_checkpoint(
     if not isinstance(linkpath, Path):
         linkpath = Path(linkpath)
 
-    if trainer.is_global_zero:
-        if linkpath.exists():
-            if linkpath.is_symlink() or linkpath.is_file():
-                linkpath.unlink()
-            elif linkpath.is_dir():
-                shutil.rmtree(linkpath)
-        _remove_checkpoint_metadata(linkpath)
+    if linkpath.exists():
+        if linkpath.is_symlink() or linkpath.is_file():
+            linkpath.unlink()
+        elif linkpath.is_dir():
+            shutil.rmtree(linkpath)
+    _remove_checkpoint_metadata(linkpath)
 
-        try:
-            target_path = filepath.relative_to(linkpath.parent)
-            linkpath.symlink_to(target_path)
-        except OSError:
-            # on Windows, special permissions are required to create symbolic links as a regular user
-            # fall back to copying the file
-            shutil.copy(filepath, linkpath)
+    try:
+        target_path = filepath.relative_to(linkpath.parent)
+        linkpath.symlink_to(target_path)
+    except OSError:
+        # on Windows, special permissions are required to create symbolic links as a regular user
+        # fall back to copying the file
+        shutil.copy(filepath, linkpath)
 
-        if metadata:
-            _link_checkpoint_metadata(filepath, linkpath)
-    if barrier:
-        trainer.strategy.barrier()
+    if metadata:
+        _link_checkpoint_metadata(filepath, linkpath)
 
 
 def _remove_checkpoint(
@@ -47,15 +42,10 @@ def _remove_checkpoint(
     filepath: str | Path | os.PathLike,
     *,
     metadata: bool,
-    barrier: bool,
 ):
     if not isinstance(filepath, Path):
         filepath = Path(filepath)
 
-    if trainer.is_global_zero:
-        trainer.strategy.remove_checkpoint(filepath)
-        if metadata:
-            _remove_checkpoint_metadata(filepath)
-
-    if barrier:
-        trainer.strategy.barrier()
+    trainer.strategy.remove_checkpoint(filepath)
+    if metadata:
+        _remove_checkpoint_metadata(filepath)
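
Net effect of the saver.py hunks above: `_link_checkpoint` and `_remove_checkpoint` no longer take a `barrier` flag, and `_link_checkpoint` no longer takes `trainer`; rank-zero gating and barriers become the callers' responsibility. A minimal, self-contained sketch of the symlink-with-copy-fallback behavior the helper keeps (the name `make_link` is a stand-in, not part of nshtrainer's API):

    import shutil
    from pathlib import Path

    def make_link(filepath: Path, linkpath: Path) -> None:
        # Clear whatever currently occupies the link location.
        if linkpath.exists():
            if linkpath.is_symlink() or linkpath.is_file():
                linkpath.unlink()
            elif linkpath.is_dir():
                shutil.rmtree(linkpath)
        try:
            # Link relative to the link's parent so the checkpoint
            # directory stays relocatable as a whole.
            linkpath.symlink_to(filepath.relative_to(linkpath.parent))
        except OSError:
            # On Windows, regular users may lack symlink permissions;
            # fall back to copying the file.
            shutil.copy(filepath, linkpath)
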
--- a/nshtrainer/callbacks/checkpoint/best_checkpoint.py
+++ b/nshtrainer/callbacks/checkpoint/best_checkpoint.py
@@ -72,6 +72,8 @@ class BestCheckpoint(Checkpoint):
         self.metric = metric
         self.dirpath = dirpath
 
+        self._last_global_step_saved = 0  # no need to save when no steps were taken
+
     @override
     def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
         self._save_best_checkpoint(trainer)
@@ -88,10 +90,6 @@ class BestCheckpoint(Checkpoint):
         filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
         return self.dirpath / filename
 
-    def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
-        for ckpt_path in ckpt_paths:
-            _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
-
     def _get_metric_value(self, metrics: dict[str, Any]):
         return metrics.get(
             self.metric.validation_monitor,
@@ -99,11 +97,16 @@ class BestCheckpoint(Checkpoint):
         )
 
     def _sorted_ckpts(self):
+        """
+        Get sorted checkpoints by the metric value.
+
+        Sort order: best -> worst
+        """
         ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
         return _sort_ckpts_by_metadata(
             ckpt_paths,
             key=lambda meta, _: self._get_metric_value(meta.metrics),
-            reverse=(self.metric.mode == "min"),
+            reverse=(self.metric.mode == "max"),
         )
 
     def _create_symlink(self, trainer: Trainer, best_ckpt_path: Path):
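
The `reverse` flip above corrects the sort direction: with `mode == "max"` a larger metric is better, so a descending sort puts the best checkpoint first, matching the new docstring. A self-contained illustration with made-up stand-in values:

    metrics = [("ckpt_a", 0.91), ("ckpt_b", 0.87), ("ckpt_c", 0.95)]

    mode = "max"  # e.g. validation accuracy: higher is better
    best_to_worst = sorted(metrics, key=lambda kv: kv[1], reverse=(mode == "max"))
    assert best_to_worst[0][0] == "ckpt_c"  # highest value first

    mode = "min"  # e.g. validation loss: lower is better
    best_to_worst = sorted(metrics, key=lambda kv: kv[1], reverse=(mode == "max"))
    assert best_to_worst[0][0] == "ckpt_b"  # lowest value first
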
@@ -117,16 +120,15 @@ class BestCheckpoint(Checkpoint):
         if symlink_path.exists() and symlink_path.resolve() == best_ckpt_path:
             return
 
-        _link_checkpoint(
-            trainer,
-            best_ckpt_path,
-            symlink_path,
-            metadata=True,
-            barrier=False,
-        )
+        _link_checkpoint(best_ckpt_path, symlink_path, metadata=True)
         log.debug(f"Created best symlink: {symlink_path}")
 
     def _save_best_checkpoint(self, trainer: Trainer):
+        # Skip saving the checkpoint if we're not in the fitting state
+        if self._should_skip_saving_checkpoint(trainer):
+            return
+
+        # Get the current metric value
         if (current := self._get_metric_value(trainer.callback_metrics)) is None:
             log.warning(
                 f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
@@ -152,21 +154,39 @@ class BestCheckpoint(Checkpoint):
         trainer.save_checkpoint(filepath, self.config.save_weights_only)
         log.debug(f"Saved best checkpoint: {filepath}")
 
-        # Remove worst checkpoint if we've reached save_top_k
-        # NOTE: We add 1 to save_top_k here because we have just saved a new checkpoint
-        if len(sorted_ckpts) + 1 > self.config._save_top_k_value:
+        if trainer.is_global_zero:
             # Get the sorted checkpoints again because now we have added a new checkpoint.
             # We could optimize this by adding the new checkpoint to the sorted list,
             # and then sorting it in place, but this is simpler.
             sorted_ckpts = self._sorted_ckpts()
-            self._remove_checkpoints(
-                trainer, [p for _, p in sorted_ckpts[self.config.save_top_k :]]
-            )
 
-        # Create symlink to best model
-        if sorted_ckpts:
-            _, best_ckpt_path = sorted_ckpts[0]
-            self._create_symlink(trainer, best_ckpt_path)
+            # Remove worst checkpoint if we've reached save_top_k
+            if (topk := self.config.save_top_k) != "all" and len(sorted_ckpts) > topk:
+                # NOTE: Sort order is best -> worst. Let's get the worst checkpoints.
+                for _, ckpt_path in sorted_ckpts[topk:]:
+                    _remove_checkpoint(trainer, ckpt_path, metadata=True)
+
+            # Create symlink to best model
+            if sorted_ckpts:
+                _, best_ckpt_path = sorted_ckpts[0]
+                self._create_symlink(trainer, best_ckpt_path)
+
+        # Update the last global step saved
+        self._last_global_step_saved = trainer.global_step
 
         # Barrier to ensure all processes have saved the checkpoint before continuing
         trainer.strategy.barrier()
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or self._last_global_step_saved
+            == trainer.global_step  # already saved at the last step
+        )
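
Besides skipping saves during `fast_dev_run`, non-fit stages, and sanity checks, the new guard deduplicates saves: once `_last_global_step_saved` equals `trainer.global_step`, a second trigger at the same step is a no-op. A minimal, self-contained sketch of that bookkeeping (the class below is a stand-in, not nshtrainer's API):

    class SaverSketch:
        def __init__(self) -> None:
            self._last_global_step_saved = 0  # nothing to save before any steps ran
            self.saved: list[int] = []

        def maybe_save(self, global_step: int) -> None:
            if self._last_global_step_saved == global_step:
                return  # already saved at this step
            self.saved.append(global_step)
            self._last_global_step_saved = global_step

    saver = SaverSketch()
    saver.maybe_save(100)
    saver.maybe_save(100)  # duplicate trigger at the same step: skipped
    saver.maybe_save(200)
    assert saver.saved == [100, 200]
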
--- a/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py
+++ b/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py
@@ -51,6 +51,8 @@ class LatestEpochCheckpoint(Checkpoint):
         self.config = config
         self.dirpath = dirpath
 
+        self._last_global_step_saved = 0
+
     @override
     def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
         self._save_new_checkpoint(trainer)
@@ -67,53 +69,63 @@ class LatestEpochCheckpoint(Checkpoint):
         filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
         return self.dirpath / filename
 
-    def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
-        for ckpt_path in ckpt_paths:
-            _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
-
     def _remove_old_checkpoints(self, trainer: Trainer):
         if (latest_k := self.config.latest_k) == "all":
             return
 
-        # NOTE: We add 1 to the latest_k here because
-        # we're about to save a new checkpoint.
-        latest_k += 1
-
         # Get all configs, ignoring the latest symlink
-        ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
+        ckpts = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
         # Ignore the latest symlink
         if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
-            ckpt_paths = [p for p in ckpt_paths if p.name != latest_symlink_filename]
+            ckpts = [p for p in ckpts if p.name != latest_symlink_filename]
 
         # Sort by epoch, then step, then last modified
-        metadata_and_ckpt_paths = _sort_ckpts_by_metadata(
-            ckpt_paths,
+        ckpts = _sort_ckpts_by_metadata(
+            ckpts,
             key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
             reverse=True,
         )
 
         # Remove all but the latest k checkpoints
-        ckpts_to_remove = metadata_and_ckpt_paths[latest_k:]
-        self._remove_checkpoints(trainer, [p for _, p in ckpts_to_remove])
+        # NOTE: We add 1 to the latest_k here because
+        # we're about to save a new checkpoint.
+        for _, ckpt_path in ckpts[latest_k:]:
+            _remove_checkpoint(trainer, ckpt_path, metadata=True)
 
     def _save_new_checkpoint(self, trainer: Trainer):
-        # Remove old checkpoints
-        if trainer.is_global_zero:
-            self._remove_old_checkpoints(trainer)
-        trainer.strategy.barrier()
+        if self._should_skip_saving_checkpoint(trainer):
+            return
 
         # Save the new checkpoint
         filepath = self._ckpt_path(trainer)
         trainer.save_checkpoint(filepath, self.config.save_weights_only)
 
-        # Create the latest symlink
-        if (symlink_filename := self._latest_symlink_filename()) is not None:
-            symlink_path = self.dirpath / symlink_filename
-            _link_checkpoint(
-                trainer,
-                filepath,
-                symlink_path,
-                barrier=True,
-                metadata=True,
-            )
-            log.debug(f"Created latest symlink: {symlink_path}")
+        if trainer.is_global_zero:
+            # Remove old checkpoints
+            self._remove_old_checkpoints(trainer)
+
+            # Create the latest symlink
+            if (symlink_filename := self._latest_symlink_filename()) is not None:
+                symlink_path = self.dirpath / symlink_filename
+                _link_checkpoint(filepath, symlink_path, metadata=True)
+                log.debug(f"Created latest symlink: {symlink_path}")
+
+        # Set the last global step saved
+        self._last_global_step_saved = trainer.global_step
+
+        # Barrier to ensure all processes have saved the checkpoint before continuing
+        trainer.strategy.barrier()
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or self._last_global_step_saved
+            == trainer.global_step  # already saved at the last step
+        )
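
The pruning above boils down to: sort checkpoints newest-first by `(epoch, global_step, mtime)` and delete everything past `latest_k`. A self-contained sketch with made-up tuples standing in for nshtrainer's checkpoint metadata:

    # (name, (epoch, global_step, mtime)) stand-ins for real checkpoint metadata
    ckpts = [
        ("epoch3.ckpt", (3, 300, 1_700_000_300.0)),
        ("epoch1.ckpt", (1, 100, 1_700_000_100.0)),
        ("epoch2.ckpt", (2, 200, 1_700_000_200.0)),
    ]
    latest_k = 2

    ckpts.sort(key=lambda item: item[1], reverse=True)  # newest first
    to_remove = ckpts[latest_k:]  # everything past the latest k
    assert [name for name, _ in to_remove] == ["epoch1.ckpt"]
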
--- a/nshtrainer/callbacks/checkpoint/model_checkpoint.py
+++ b/nshtrainer/callbacks/checkpoint/model_checkpoint.py
@@ -198,19 +198,10 @@ class ModelCheckpoint(_ModelCheckpoint):
 
     @override
     def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str):  # pyright: ignore[reportIncompatibleMethodOverride]
-        return _link_checkpoint(
-            trainer,
-            filepath,
-            linkpath,
-            barrier=True,
-            metadata=True,
-        )
+        if trainer.is_global_zero:
+            _link_checkpoint(filepath, linkpath, metadata=True)
+        trainer.strategy.barrier()
 
     @override
     def _remove_checkpoint(self, trainer: Trainer, filepath: str):
-        return _ckpt_saver_remove_checkpoint(
-            trainer,
-            filepath,
-            metadata=True,
-            barrier=False,
-        )
+        _ckpt_saver_remove_checkpoint(trainer, filepath, metadata=True)
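
The `_link_checkpoint` override now follows the rank-zero-then-barrier idiom: only rank 0 mutates the filesystem, then every rank synchronizes before continuing. A hedged sketch of the same idiom using `torch.distributed` directly in place of `trainer.strategy.barrier()` (`link_on_rank_zero` is a hypothetical name, not part of nshtrainer):

    from typing import Callable

    import torch.distributed as dist

    def link_on_rank_zero(
        link_fn: Callable[[str, str], None], filepath: str, linkpath: str
    ) -> None:
        # Only rank 0 creates the link; single-process runs count as rank 0.
        if not dist.is_initialized() or dist.get_rank() == 0:
            link_fn(filepath, linkpath)
        # All ranks wait until the link exists before proceeding.
        if dist.is_initialized():
            dist.barrier()
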
--- a/nshtrainer-0.11.5.dist-info/METADATA
+++ b/nshtrainer-0.11.7.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.11.5
+Version: 0.11.7
 Summary: 
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
--- a/nshtrainer-0.11.5.dist-info/RECORD
+++ b/nshtrainer-0.11.7.dist-info/RECORD
@@ -1,7 +1,7 @@
 nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
 nshtrainer/_checkpoint/loader.py,sha256=_3jBf-k-fJCFfmU8wjDwbnE9rb4WoKYEyQiKGsBOCi4,13777
 nshtrainer/_checkpoint/metadata.py,sha256=M9eAZ2xMs36Z1G1xULu9MHZhsHxN8_9mNt3Iv7wuq-I,5069
-nshtrainer/_checkpoint/saver.py,sha256=z_c7a91O4Bh4lZZjqJgxT3w25qFlJsOopV3cpJtkHk8,1655
+nshtrainer/_checkpoint/saver.py,sha256=TuSAP39DOOVvSnSukQ9RitMV60JnDg6L27fMRc2uVJc,1358
 nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
 nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
 nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
@@ -11,9 +11,9 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=zrEVCGFikfkt0iOMceOFzXsZG2-6QrqY79RKBCS7bu4,738
-nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=w99O5GWRcV89XBe4j__v2TvNEHys0x_r3tSTr-6Lhec,6154
-nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=NES-acaslPBiZQIMAdk_YwtnBrkm_y_BJQ8Ian0UKP0,4294
-nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=mLFMbNzeMiBer3BCb7o3ucswKpOCQlYyN3wdB92N-LY,6884
+nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=BUO0sWqlwfyxD1UeII5DZ-01SGLiawJAEsL8HjGX4XA,7018
+nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=CQ0IqhuPI7zAxpQLy48kK8qqfVfwXEJoHGRqI4h8xNk,4819
+nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=JS1z2YuEiQxk61HgZU1jySzF_pzdfXYO54_qHo-q3CQ,6776
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
 nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.11.5.dist-info/METADATA,sha256=KHgvYOhQXbc37awWeLbpbdVQbSEU4J7KoC7Lr5286KE,860
-nshtrainer-0.11.5.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.11.5.dist-info/RECORD,,
+nshtrainer-0.11.7.dist-info/METADATA,sha256=htPbfKNDbqr1taf0bEvSbl-hQOaPfjKkJhvrWoMz2r0,860
+nshtrainer-0.11.7.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.11.7.dist-info/RECORD,,