nshtrainer 0.10.15__py3-none-any.whl → 0.10.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -105,10 +105,8 @@ def _write_checkpoint_metadata(
105
105
 
106
106
 
107
107
  def _remove_checkpoint_metadata(checkpoint_path: Path):
108
- for path in (
109
- checkpoint_path.with_suffix(METADATA_PATH_SUFFIX),
110
- checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX),
111
- ):
108
+ for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
109
+ path = checkpoint_path.with_suffix(suffix)
112
110
  try:
113
111
  path.unlink(missing_ok=True)
114
112
  except Exception as e:
@@ -122,11 +120,9 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
122
120
  _remove_checkpoint_metadata(linked_checkpoint_path)
123
121
 
124
122
  # Link the metadata files to the new checkpoint
125
- for path in (
126
- checkpoint_path.with_suffix(METADATA_PATH_SUFFIX),
127
- checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX),
128
- ):
129
- linked_path = linked_checkpoint_path.with_suffix(path.suffix)
123
+ for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
124
+ path = checkpoint_path.with_suffix(suffix)
125
+ linked_path = linked_checkpoint_path.with_suffix(suffix)
130
126
  try:
131
127
  try:
132
128
  linked_path.symlink_to(path)
@@ -64,7 +64,7 @@ class LatestEpochCheckpoint(Checkpoint):
64
64
  filename = self.config.filename.format(
65
65
  epoch=trainer.current_epoch, step=trainer.global_step
66
66
  )
67
- filename = f"{self.PREFIX}{filename}.{self.EXTENSION}"
67
+ filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
68
68
  return self.dirpath / filename
69
69
 
70
70
  def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
@@ -95,7 +95,9 @@ class LatestEpochCheckpoint(Checkpoint):
95
95
 
96
96
  def _save_new_checkpoint(self, trainer: Trainer):
97
97
  # Remove old checkpoints
98
- self._remove_old_checkpoints(trainer)
98
+ if trainer.is_global_zero:
99
+ self._remove_old_checkpoints(trainer)
100
+ trainer.strategy.barrier()
99
101
 
100
102
  # Save the new checkpoint
101
103
  filepath = self._ckpt_path(trainer)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.10.15
3
+ Version: 0.10.16
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
2
  nshtrainer/_checkpoint/loader.py,sha256=48flPr1XgQHOgIPaCrRqOEvRuG0SZuV3cQ1vgHLqFqI,11025
3
- nshtrainer/_checkpoint/metadata.py,sha256=soK9tXVs6EOpzhlnIxTEF51KmdkaCDUj0Rdyid3uREk,5640
3
+ nshtrainer/_checkpoint/metadata.py,sha256=GlhlAyJh5gcp3R8l2Y3eAUQtQzBnitFlB0xdx-khEUQ,5579
4
4
  nshtrainer/_checkpoint/saver.py,sha256=z_c7a91O4Bh4lZZjqJgxT3w25qFlJsOopV3cpJtkHk8,1655
5
5
  nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
6
6
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
@@ -15,7 +15,7 @@ nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,1
15
15
  nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
16
16
  nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
17
17
  nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
18
- nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=zUeYAGfeQby0R6IwQBJH3lng-MD0vkckdX4aIOm-VIc,4146
18
+ nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=5JC-JCdgWNnunl0jv4Q9LhkEspLAn0x8VpCMJZi7-ow,4219
19
19
  nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
20
20
  nshtrainer/callbacks/model_checkpoint.py,sha256=8D0wWLhr_KiksAA1fjfIuby42Mq6XokCvAnVUhjADd8,6564
21
21
  nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
@@ -79,6 +79,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
79
79
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
80
80
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
81
81
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
82
- nshtrainer-0.10.15.dist-info/METADATA,sha256=lBdMigvT3LEgOyWtMBwaRvru8XRTU8K5GQ-ll3kqwE8,696
83
- nshtrainer-0.10.15.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
- nshtrainer-0.10.15.dist-info/RECORD,,
82
+ nshtrainer-0.10.16.dist-info/METADATA,sha256=8jgjZDL82cNf_ys1xKUuqfKXAol8m2dWYB909W239fk,696
83
+ nshtrainer-0.10.16.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
+ nshtrainer-0.10.16.dist-info/RECORD,,