nshtrainer 0.10.5__tar.gz → 0.10.7__tar.gz

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/PKG-INFO +1 -1
  2. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/pyproject.toml +1 -1
  3. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +5 -12
  4. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/model/config.py +3 -0
  5. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/trainer/trainer.py +4 -0
  6. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/README.md +0 -0
  7. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/__init__.py +0 -0
  8. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/_checkpoint/loader.py +0 -0
  9. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  10. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/_experimental/__init__.py +0 -0
  11. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  12. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  13. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  14. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/__init__.py +0 -0
  15. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  16. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/actsave.py +0 -0
  17. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/base.py +0 -0
  18. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  19. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/ema.py +0 -0
  20. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  21. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  22. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/interval.py +0 -0
  23. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  24. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/model_checkpoint.py +0 -0
  25. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  26. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
  27. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/print_table.py +0 -0
  28. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  29. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/timer.py +0 -0
  30. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  31. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/data/__init__.py +0 -0
  32. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  33. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/data/transform.py +0 -0
  34. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/__init__.py +0 -0
  35. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/_experimental.py +0 -0
  36. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/actsave.py +0 -0
  37. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/callbacks.py +0 -0
  38. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/config.py +0 -0
  39. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/data.py +0 -0
  40. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/log.py +0 -0
  41. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  42. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/model.py +0 -0
  43. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/nn.py +0 -0
  44. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/optimizer.py +0 -0
  45. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/runner.py +0 -0
  46. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/snapshot.py +0 -0
  47. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/snoop.py +0 -0
  48. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/trainer.py +0 -0
  49. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/typecheck.py +0 -0
  50. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/ll/util.py +0 -0
  51. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  52. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  53. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  54. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  55. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/metrics/__init__.py +0 -0
  56. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/metrics/_config.py +0 -0
  57. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/model/__init__.py +0 -0
  58. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/model/_environment.py +0 -0
  59. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/model/base.py +0 -0
  60. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/model/modules/callback.py +0 -0
  61. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/model/modules/debug.py +0 -0
  62. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/model/modules/distributed.py +0 -0
  63. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/model/modules/logger.py +0 -0
  64. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/model/modules/profiler.py +0 -0
  65. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  66. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  67. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/nn/__init__.py +0 -0
  68. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/nn/mlp.py +0 -0
  69. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/nn/module_dict.py +0 -0
  70. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/nn/module_list.py +0 -0
  71. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/nn/nonlinearity.py +0 -0
  72. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/optimizer.py +0 -0
  73. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/runner.py +0 -0
  74. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/scripts/check_env.py +0 -0
  75. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/scripts/find_packages.py +0 -0
  76. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/trainer/__init__.py +0 -0
  77. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  78. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  79. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/trainer/signal_connector.py +0 -0
  80. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/util/environment.py +0 -0
  81. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/util/seed.py +0 -0
  82. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/util/slurm.py +0 -0
  83. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/util/typed.py +0 -0
  84. {nshtrainer-0.10.5 → nshtrainer-0.10.7}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.10.5
3
+ Version: 0.10.7
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.10.5"
3
+ version = "0.10.7"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -43,10 +43,6 @@ class LatestEpochCheckpoint(Checkpoint):
43
43
  self.config = config
44
44
  self.dirpath = dirpath
45
45
 
46
- # Also, we hold a reference to the last checkpoint path
47
- # to be able to remove it when a new checkpoint is saved.
48
- self._last_ckpt_path: Path | None = None
49
-
50
46
  def _ckpt_path(self, trainer: Trainer):
51
47
  return self.dirpath / self.config.filename.format(
52
48
  epoch=trainer.current_epoch, step=trainer.global_step
@@ -54,20 +50,17 @@ class LatestEpochCheckpoint(Checkpoint):
54
50
 
55
51
  @override
56
52
  def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
57
- # Remove the last checkpoint if it exists
58
- if self._last_ckpt_path is not None:
59
- trainer.strategy.remove_checkpoint(self._last_ckpt_path)
60
-
61
53
  # Save the new checkpoint
62
54
  filepath = self._ckpt_path(trainer)
63
55
  trainer.save_checkpoint(filepath, self.config.save_weights_only)
64
- self._last_ckpt_path = filepath
65
56
 
66
57
  # Create the latest symlink
67
- if (symlink_filename := self.config.latest_symlink_filename) is not None:
58
+ if (
59
+ trainer.is_global_zero
60
+ and (symlink_filename := self.config.latest_symlink_filename) is not None
61
+ ):
68
62
  symlink_path = self.dirpath / symlink_filename
69
- if symlink_path.exists():
70
- symlink_path.unlink()
63
+ symlink_path.unlink(missing_ok=True)
71
64
  symlink_path.symlink_to(filepath.name)
72
65
  log.info(f"Created latest symlink: {symlink_path}")
73
66
 
@@ -1121,6 +1121,9 @@ class SanityCheckingConfig(C.Config):
1121
1121
 
1122
1122
 
1123
1123
  class TrainerConfig(C.Config):
1124
+ ckpt_path: str | Path | None = None
1125
+ """Path to a checkpoint to load and resume training from."""
1126
+
1124
1127
  checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
1125
1128
  """Checkpoint loading configuration options."""
1126
1129
 
@@ -304,6 +304,10 @@ class Trainer(LightningTrainer):
304
304
  log_dir = str(Path(log_dir).resolve())
305
305
  log.critical(f"LightningTrainer log directory: {self.log_dir}.")
306
306
 
307
+ # Set the checkpoint
308
+ if (ckpt_path := config.trainer.ckpt_path) is not None:
309
+ self.ckpt_path = str(Path(ckpt_path).resolve().absolute())
310
+
307
311
  def __runtime_tracker(self):
308
312
  return next(
309
313
  (
File without changes