nshtrainer 0.10.12__py3-none-any.whl → 0.10.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,7 +36,8 @@ def _link_checkpoint(
36
36
  # fall back to copying the file
37
37
  shutil.copy(filepath, linkpath)
38
38
 
39
- _link_checkpoint_metadata(filepath, linkpath)
39
+ if metadata:
40
+ _link_checkpoint_metadata(filepath, linkpath)
40
41
  if barrier:
41
42
  trainer.strategy.barrier()
42
43
 
@@ -44,9 +45,17 @@ def _link_checkpoint(
44
45
  def _remove_checkpoint(
45
46
  trainer: Trainer,
46
47
  filepath: str | Path | os.PathLike,
47
- remove_metadata: bool = True,
48
+ *,
49
+ metadata: bool,
50
+ barrier: bool,
48
51
  ):
49
52
  if not isinstance(filepath, Path):
50
53
  filepath = Path(filepath)
51
- trainer.strategy.remove_checkpoint(filepath)
52
- _remove_checkpoint_metadata(filepath)
54
+
55
+ if trainer.is_global_zero:
56
+ trainer.strategy.remove_checkpoint(filepath)
57
+ if metadata:
58
+ _remove_checkpoint_metadata(filepath)
59
+
60
+ if barrier:
61
+ trainer.strategy.barrier()
@@ -69,7 +69,7 @@ class LatestEpochCheckpoint(Checkpoint):
69
69
 
70
70
  def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
71
71
  for ckpt_path in ckpt_paths:
72
- _remove_checkpoint(trainer, ckpt_path, remove_metadata=True)
72
+ _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
73
73
 
74
74
  def _remove_old_checkpoints(self, trainer: Trainer):
75
75
  if (latest_k := self.config.latest_k) == "all":
@@ -202,4 +202,4 @@ class ModelCheckpoint(_ModelCheckpoint):
202
202
 
203
203
  @override
204
204
  def _remove_checkpoint(self, trainer: Trainer, filepath: str):
205
- return _remove_checkpoint(trainer, filepath, remove_metadata=True)
205
+ return _remove_checkpoint(trainer, filepath, metadata=True, barrier=False)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.10.12
3
+ Version: 0.10.13
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,7 +1,7 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
2
  nshtrainer/_checkpoint/loader.py,sha256=48flPr1XgQHOgIPaCrRqOEvRuG0SZuV3cQ1vgHLqFqI,11025
3
3
  nshtrainer/_checkpoint/metadata.py,sha256=soK9tXVs6EOpzhlnIxTEF51KmdkaCDUj0Rdyid3uREk,5640
4
- nshtrainer/_checkpoint/saver.py,sha256=KZp9ITUVHwj2Ttu81zXKdlS_h-fKkHearspwuAijDpM,1501
4
+ nshtrainer/_checkpoint/saver.py,sha256=z_c7a91O4Bh4lZZjqJgxT3w25qFlJsOopV3cpJtkHk8,1655
5
5
  nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
6
6
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
7
7
  nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
@@ -15,9 +15,9 @@ nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,1
15
15
  nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
16
16
  nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
17
17
  nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
18
- nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=t4vWa4PvJDO3rKXKZbuegm7iLl7xCEd17wNif0Bp-BA,4138
18
+ nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=zUeYAGfeQby0R6IwQBJH3lng-MD0vkckdX4aIOm-VIc,4146
19
19
  nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
20
- nshtrainer/callbacks/model_checkpoint.py,sha256=MaDkD8Ismcj8u6l2flCFlqJR3-k1Tc4xzhxNWNux4n0,6556
20
+ nshtrainer/callbacks/model_checkpoint.py,sha256=8D0wWLhr_KiksAA1fjfIuby42Mq6XokCvAnVUhjADd8,6564
21
21
  nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
22
22
  nshtrainer/callbacks/on_exception_checkpoint.py,sha256=x42BYZ2ejf2rhqPLCmT5nyWKhA9qBEosiV8ZNhhZ6lI,3355
23
23
  nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
@@ -79,6 +79,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
79
79
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
80
80
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
81
81
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
82
- nshtrainer-0.10.12.dist-info/METADATA,sha256=RDwRSo6cvM5H1OgeCc3C5ZglfwjhAQ0I1nlw4xfLpUw,696
83
- nshtrainer-0.10.12.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
- nshtrainer-0.10.12.dist-info/RECORD,,
82
+ nshtrainer-0.10.13.dist-info/METADATA,sha256=HpGl8_E6q2l2nQrIzU5ibNEyUXj8adF8cxMzouUSpAg,696
83
+ nshtrainer-0.10.13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
+ nshtrainer-0.10.13.dist-info/RECORD,,