nshtrainer 0.11.4__tar.gz → 0.11.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/PKG-INFO +1 -1
  2. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/pyproject.toml +1 -1
  3. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +28 -3
  4. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/README.md +0 -0
  5. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/__init__.py +0 -0
  6. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/_checkpoint/loader.py +0 -0
  7. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  8. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/_checkpoint/saver.py +0 -0
  9. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/_experimental/__init__.py +0 -0
  10. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  11. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  12. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  13. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/__init__.py +0 -0
  14. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  15. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/actsave.py +0 -0
  16. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/base.py +0 -0
  17. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  18. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +0 -0
  19. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/checkpoint/model_checkpoint.py +0 -0
  20. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  21. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  22. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/ema.py +0 -0
  23. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  24. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  25. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/interval.py +0 -0
  26. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  27. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  28. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/print_table.py +0 -0
  29. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  30. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/timer.py +0 -0
  31. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  32. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/data/__init__.py +0 -0
  33. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  34. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/data/transform.py +0 -0
  35. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/__init__.py +0 -0
  36. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/_experimental.py +0 -0
  37. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/actsave.py +0 -0
  38. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/callbacks.py +0 -0
  39. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/config.py +0 -0
  40. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/data.py +0 -0
  41. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/log.py +0 -0
  42. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  43. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/model.py +0 -0
  44. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/nn.py +0 -0
  45. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/optimizer.py +0 -0
  46. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/runner.py +0 -0
  47. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/snapshot.py +0 -0
  48. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/snoop.py +0 -0
  49. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/trainer.py +0 -0
  50. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/typecheck.py +0 -0
  51. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/ll/util.py +0 -0
  52. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  53. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  54. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  55. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  56. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/metrics/__init__.py +0 -0
  57. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/metrics/_config.py +0 -0
  58. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/model/__init__.py +0 -0
  59. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/model/base.py +0 -0
  60. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/model/config.py +0 -0
  61. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/model/modules/callback.py +0 -0
  62. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/model/modules/debug.py +0 -0
  63. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/model/modules/distributed.py +0 -0
  64. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/model/modules/logger.py +0 -0
  65. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/model/modules/profiler.py +0 -0
  66. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  67. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  68. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/nn/__init__.py +0 -0
  69. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/nn/mlp.py +0 -0
  70. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/nn/module_dict.py +0 -0
  71. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/nn/module_list.py +0 -0
  72. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/nn/nonlinearity.py +0 -0
  73. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/optimizer.py +0 -0
  74. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/runner.py +0 -0
  75. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/scripts/find_packages.py +0 -0
  76. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/trainer/__init__.py +0 -0
  77. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  78. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  79. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/trainer/signal_connector.py +0 -0
  80. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/trainer/trainer.py +0 -0
  81. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/util/_environment_info.py +0 -0
  82. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/util/_useful_types.py +0 -0
  83. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/util/environment.py +0 -0
  84. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/util/seed.py +0 -0
  85. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/util/slurm.py +0 -0
  86. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/util/typed.py +0 -0
  87. {nshtrainer-0.11.4 → nshtrainer-0.11.6}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.11.4
3
+ Version: 0.11.6
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.11.4"
3
+ version = "0.11.6"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -72,6 +72,8 @@ class BestCheckpoint(Checkpoint):
72
72
  self.metric = metric
73
73
  self.dirpath = dirpath
74
74
 
75
+ self._last_global_step_saved = 0 # no need to save when no steps were taken
76
+
75
77
  @override
76
78
  def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
77
79
  self._save_best_checkpoint(trainer)
@@ -127,6 +129,11 @@ class BestCheckpoint(Checkpoint):
127
129
  log.debug(f"Created best symlink: {symlink_path}")
128
130
 
129
131
  def _save_best_checkpoint(self, trainer: Trainer):
132
+ # Skip saving the checkpoint if we're not in the fitting state
133
+ if self._should_skip_saving_checkpoint(trainer):
134
+ return
135
+
136
+ # Get the current metric value
130
137
  if (current := self._get_metric_value(trainer.callback_metrics)) is None:
131
138
  log.warning(
132
139
  f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
@@ -150,6 +157,7 @@ class BestCheckpoint(Checkpoint):
150
157
  # Save the current model
151
158
  filepath = self._ckpt_path(trainer)
152
159
  trainer.save_checkpoint(filepath, self.config.save_weights_only)
160
+ log.debug(f"Saved best checkpoint: {filepath}")
153
161
 
154
162
  # Remove worst checkpoint if we've reached save_top_k
155
163
  # NOTE: We add 1 to save_top_k here because we have just saved a new checkpoint
@@ -163,9 +171,26 @@ class BestCheckpoint(Checkpoint):
163
171
  )
164
172
 
165
173
  # Create symlink to best model
166
- _, best_ckpt_path = sorted_ckpts[0]
167
- self._create_symlink(trainer, best_ckpt_path)
168
- log.debug(f"Saved best checkpoint: {filepath}")
174
+ if sorted_ckpts:
175
+ _, best_ckpt_path = sorted_ckpts[0]
176
+ self._create_symlink(trainer, best_ckpt_path)
177
+
178
+ # Update the last global step saved
179
+ self._last_global_step_saved = trainer.global_step
169
180
 
170
181
  # Barrier to ensure all processes have saved the checkpoint before continuing
171
182
  trainer.strategy.barrier()
183
+
184
+ def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
185
+ from lightning.pytorch.trainer.states import TrainerFn
186
+
187
+ return (
188
+ bool(
189
+ getattr(trainer, "fast_dev_run", False)
190
+ ) # disable checkpointing with fast_dev_run
191
+ or trainer.state.fn
192
+ != TrainerFn.FITTING # don't save anything during non-fit
193
+ or trainer.sanity_checking # don't save anything during sanity check
194
+ or self._last_global_step_saved
195
+ == trainer.global_step # already saved at the last step
196
+ )
File without changes