nshtrainer 1.3.6__tar.gz → 1.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/PKG-INFO +2 -2
  2. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/pyproject.toml +2 -2
  3. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/_checkpoint/metadata.py +4 -1
  4. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/_hf_hub.py +3 -0
  5. nshtrainer-1.4.1/src/nshtrainer/callbacks/checkpoint/_base.py +320 -0
  6. nshtrainer-1.4.1/src/nshtrainer/callbacks/log_epoch.py +136 -0
  7. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/lr_monitor.py +9 -1
  8. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/_config.py +9 -3
  9. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/trainer.py +10 -2
  10. nshtrainer-1.3.6/src/nshtrainer/callbacks/checkpoint/_base.py +0 -187
  11. nshtrainer-1.3.6/src/nshtrainer/callbacks/log_epoch.py +0 -49
  12. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/README.md +0 -0
  13. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/.nshconfig.generated.json +0 -0
  14. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/__init__.py +0 -0
  15. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/_callback.py +0 -0
  16. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/_checkpoint/saver.py +0 -0
  17. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/_experimental/__init__.py +0 -0
  18. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/__init__.py +0 -0
  19. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/actsave.py +0 -0
  20. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/base.py +0 -0
  21. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  22. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  23. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  24. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  25. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  26. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  27. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/distributed_prediction_writer.py +0 -0
  28. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  29. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/ema.py +0 -0
  30. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  31. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  32. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/interval.py +0 -0
  33. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/metric_validation.py +0 -0
  34. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  35. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/print_table.py +0 -0
  36. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  37. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  38. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/timer.py +0 -0
  39. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  40. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  41. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/.gitattributes +0 -0
  42. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/__init__.py +0 -0
  43. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  44. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  45. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  46. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
  47. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  48. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  49. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  50. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  51. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  52. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  53. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  54. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  55. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  56. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py +0 -0
  57. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  58. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  59. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  60. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  61. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  62. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  63. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
  64. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  65. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  66. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  67. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  68. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  69. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  70. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  71. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  72. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  73. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
  74. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  75. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  76. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  77. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  78. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
  79. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  80. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  81. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  82. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  83. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/nn/__init__.py +0 -0
  84. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  85. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  86. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/nn/rng/__init__.py +0 -0
  87. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  88. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  89. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  90. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  91. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  92. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  93. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/trainer/__init__.py +0 -0
  94. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -0
  95. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
  96. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
  97. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
  98. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
  99. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
  100. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
  101. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
  102. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
  103. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  104. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/util/__init__.py +0 -0
  105. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  106. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  107. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  108. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  109. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/data/__init__.py +0 -0
  110. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  111. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/data/datamodule.py +0 -0
  112. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/data/transform.py +0 -0
  113. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/loggers/__init__.py +0 -0
  114. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/loggers/actsave.py +0 -0
  115. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/loggers/base.py +0 -0
  116. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/loggers/csv.py +0 -0
  117. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/loggers/tensorboard.py +0 -0
  118. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/loggers/wandb.py +0 -0
  119. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  120. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/lr_scheduler/base.py +0 -0
  121. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  122. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  123. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/metrics/__init__.py +0 -0
  124. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/metrics/_config.py +0 -0
  125. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/model/__init__.py +0 -0
  126. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/model/base.py +0 -0
  127. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/model/mixins/callback.py +0 -0
  128. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/model/mixins/debug.py +0 -0
  129. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/model/mixins/logger.py +0 -0
  130. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/nn/__init__.py +0 -0
  131. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/nn/mlp.py +0 -0
  132. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/nn/module_dict.py +0 -0
  133. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/nn/module_list.py +0 -0
  134. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
  135. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/nn/rng.py +0 -0
  136. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/optimizer.py +0 -0
  137. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/profiler/__init__.py +0 -0
  138. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/profiler/_base.py +0 -0
  139. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/profiler/advanced.py +0 -0
  140. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/profiler/pytorch.py +0 -0
  141. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/profiler/simple.py +0 -0
  142. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/__init__.py +0 -0
  143. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/_distributed_prediction_result.py +0 -0
  144. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/_log_hparams.py +0 -0
  145. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  146. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/accelerator.py +0 -0
  147. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
  148. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/plugin/base.py +0 -0
  149. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/plugin/environment.py +0 -0
  150. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/plugin/io.py +0 -0
  151. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
  152. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/plugin/precision.py +0 -0
  153. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
  154. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/strategy.py +0 -0
  155. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/_environment_info.py +0 -0
  156. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/bf16.py +0 -0
  157. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/code_upload.py +0 -0
  158. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/config/__init__.py +0 -0
  159. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/config/dtype.py +0 -0
  160. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/config/duration.py +0 -0
  161. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/environment.py +0 -0
  162. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/path.py +0 -0
  163. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/seed.py +0 -0
  164. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/slurm.py +0 -0
  165. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/typed.py +0 -0
  166. {nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-1.3.6 → nshtrainer-1.4.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.3.6
+Version: 1.4.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -14,7 +14,7 @@ Provides-Extra: extra
 Requires-Dist: GitPython ; extra == "extra"
 Requires-Dist: huggingface-hub ; extra == "extra"
 Requires-Dist: lightning
-Requires-Dist: nshconfig (>0.39)
+Requires-Dist: nshconfig (>=0.43)
 Requires-Dist: nshrunner ; extra == "extra"
 Requires-Dist: nshutils ; extra == "extra"
 Requires-Dist: numpy
{nshtrainer-1.3.6 → nshtrainer-1.4.1}/pyproject.toml

@@ -1,13 +1,13 @@
 [project]
 name = "nshtrainer"
-version = "1.3.6"
+version = "1.4.1"
 description = ""
 authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
 requires-python = ">=3.10,<4.0"
 readme = "README.md"

 dependencies = [
-    "nshconfig>0.39",
+    "nshconfig>=0.43",
     "psutil",
     "numpy",
     "torch",
{nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/_checkpoint/metadata.py

@@ -85,6 +85,7 @@ def _generate_checkpoint_metadata(
     trainer: Trainer,
     checkpoint_path: Path,
     metadata_path: Path,
+    compute_checksum: bool = True,
 ):
     checkpoint_timestamp = datetime.datetime.now()
     start_timestamp = trainer.start_time()
@@ -105,7 +106,9 @@ def _generate_checkpoint_metadata(
         # moving the checkpoint directory
         checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
-        checkpoint_checksum=compute_file_checksum(checkpoint_path),
+        checkpoint_checksum=compute_file_checksum(checkpoint_path)
+        if compute_checksum
+        else "",
         run_id=trainer.hparams.id,
         name=trainer.hparams.full_name,
         project=trainer.hparams.project,
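Note on the new parameter: compute_checksum lets a caller build metadata without hashing the checkpoint file, which matters when the metadata is only generated speculatively (as the new checkpoint base callback below does for its top-k comparison). A minimal hedged sketch of such a call, using only names from the hunks above; the surrounding trainer and checkpoint_path objects are placeholders:

# Hedged sketch: metadata without the (potentially expensive) file hash.
# `trainer` and `checkpoint_path` are placeholders; all other names appear above.
meta = _generate_checkpoint_metadata(
    trainer=trainer,
    checkpoint_path=checkpoint_path,
    metadata_path=checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX),
    compute_checksum=False,  # checkpoint_checksum is then stored as ""
)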
{nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/_hf_hub.py

@@ -91,6 +91,9 @@ class HuggingFaceHubConfig(CallbackConfigBase):

     @override
     def create_callbacks(self, trainer_config):
+        if not self:
+            return
+
         # Attempt to login. If it fails, we'll log a warning or error based on the configuration.
         try:
             api = _api(self.token)
nshtrainer-1.4.1/src/nshtrainer/callbacks/checkpoint/_base.py (new file)

@@ -0,0 +1,320 @@
+from __future__ import annotations
+
+import logging
+import string
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
+
+import numpy as np
+import torch
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import override
+
+from ..._checkpoint.metadata import CheckpointMetadata, _generate_checkpoint_metadata
+from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
+from ..base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ...trainer._config import TrainerConfig
+
+
+log = logging.getLogger(__name__)
+
+
+class _FormatDict(dict):
+    """A dictionary that returns an empty string for missing keys when formatting."""
+
+    def __missing__(self, key):
+        log.debug(
+            f"Missing format key '{key}' in checkpoint filename, using empty string"
+        )
+        return ""
+
+
+def _get_checkpoint_metadata(dirpath: Path) -> list[CheckpointMetadata]:
+    """Get all checkpoint metadata from a directory."""
+    return [
+        CheckpointMetadata.from_file(p)
+        for p in dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+        if p.is_file() and not p.is_symlink()
+    ]
+
+
+def _sort_checkpoint_metadata(
+    metas: list[CheckpointMetadata],
+    key_fn: Callable[[CheckpointMetadata], Any],
+    reverse: bool = False,
+) -> list[CheckpointMetadata]:
+    """Sort checkpoint metadata by the given key function."""
+    return sorted(metas, key=key_fn, reverse=reverse)
+
+
+def _remove_checkpoints(
+    trainer: Trainer,
+    dirpath: Path,
+    metas_to_remove: list[CheckpointMetadata],
+) -> None:
+    """Remove checkpoint files and their metadata."""
+    for meta in metas_to_remove:
+        ckpt_path = dirpath / meta.checkpoint_filename
+        if not ckpt_path.exists():
+            log.warning(
+                f"Checkpoint file not found: {ckpt_path}\n"
+                "Skipping removal of the checkpoint metadata."
+            )
+            continue
+
+        remove_checkpoint(trainer, ckpt_path, metadata=True)
+        log.debug(f"Removed checkpoint: {ckpt_path}")
+
+
+def _update_symlink(
+    dirpath: Path,
+    symlink_path: Path | None,
+    sort_key_fn: Callable[[CheckpointMetadata], Any],
+    sort_reverse: bool,
+) -> None:
+    """Update symlink to point to the best checkpoint."""
+    if symlink_path is None:
+        return
+
+    # Get all checkpoint metadata after any removals
+    remaining_metas = _get_checkpoint_metadata(dirpath)
+
+    if remaining_metas:
+        # Sort by the key function
+        remaining_metas = _sort_checkpoint_metadata(
+            remaining_metas, sort_key_fn, sort_reverse
+        )
+
+        # Link to the best checkpoint
+        best_meta = remaining_metas[0]
+        best_filepath = dirpath / best_meta.checkpoint_filename
+        link_checkpoint(best_filepath, symlink_path, metadata=True)
+        log.debug(f"Updated symlink {symlink_path.name} -> {best_filepath.name}")
+    else:
+        log.warning(f"No checkpoints found in {dirpath} to create symlink.")
+
+
+class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str | None = None
+    """Checkpoint filename. This must not include the extension.
+    If None, the default filename will be used."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    save_symlink: bool = True
+    """Whether to create a symlink to the saved checkpoint."""
+
+    topk: int | Literal["all"] = 1
+    """The number of checkpoints to keep."""
+
+    @abstractmethod
+    def create_checkpoint(
+        self,
+        trainer_config: TrainerConfig,
+        dirpath: Path,
+    ) -> "CheckpointBase | None": ...
+
+    @override
+    def create_callbacks(self, trainer_config):
+        dirpath = Path(
+            self.dirpath
+            or trainer_config.directory.resolve_subdirectory(
+                trainer_config.id, "checkpoint"
+            )
+        )
+
+        if (callback := self.create_checkpoint(trainer_config, dirpath)) is not None:
+            yield callback
+
+
+TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
+
+
+class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
+    def __init__(self, config: TConfig, dirpath: Path):
+        super().__init__()
+
+        self.config = config
+        self.dirpath = dirpath / self.name()
+        self.dirpath.mkdir(parents=True, exist_ok=True)
+        self.symlink_dirpath = dirpath
+
+    @abstractmethod
+    def default_filename(self) -> str: ...
+
+    @abstractmethod
+    def name(self) -> str: ...
+
+    def extension(self) -> str:
+        return ".ckpt"
+
+    @abstractmethod
+    def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
+
+    @abstractmethod
+    def topk_sort_reverse(self) -> bool: ...
+
+    def symlink_path(self):
+        if not self.config.save_symlink:
+            return None
+
+        return self.symlink_dirpath / f"{self.name()}{self.extension()}"
+
+    def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
+        if (filename := self.config.filename) is None:
+            filename = self.default_filename()
+
+        # Extract all field names from the format string
+        field_names = [
+            fname for _, fname, _, _ in string.Formatter().parse(filename) if fname
+        ]
+
+        # Filter current_metrics to only include keys that are in the format string
+        format_dict = {k: v for k, v in current_metrics.items() if k in field_names}
+
+        try:
+            formatted_filename = filename.format(**format_dict)
+        except KeyError as e:
+            log.warning(
+                f"Missing key {e} in {filename=} with {format_dict=}. Using default values."
+            )
+            # Provide a simple fallback for missing keys
+            formatted_filename = string.Formatter().vformat(
+                filename, (), _FormatDict(format_dict)
+            )
+
+        return self.dirpath / f"{formatted_filename}{self.extension()}"
+
+    def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
+        current_metrics: dict[str, Any] = {
+            "epoch": trainer.current_epoch,
+            "step": trainer.global_step,
+        }
+
+        for name, value in trainer.callback_metrics.items():
+            match value:
+                case torch.Tensor() if value.numel() == 1:
+                    value = value.detach().cpu().item()
+                case np.ndarray() if value.size == 1:
+                    value = value.item()
+                case _:
+                    pass
+
+            current_metrics[name] = value
+
+        log.debug(
+            f"Current metrics: {current_metrics}, {trainer.callback_metrics=}, {trainer.logged_metrics=}"
+        )
+        return current_metrics
+
+    def save_checkpoints(self, trainer: Trainer):
+        log.debug(
+            f"{type(self).__name__}.save_checkpoints() called at {trainer.current_epoch=}, {trainer.global_step=}"
+        )
+        # Also print out the current stack trace for debugging
+        if log.isEnabledFor(logging.DEBUG):
+            import traceback
+
+            stack = traceback.extract_stack()
+            log.debug(f"Stack trace: {''.join(traceback.format_list(stack))}")
+
+        if self._should_skip_saving_checkpoint(trainer):
+            return
+
+        from ...trainer import Trainer as NTTrainer
+
+        if not isinstance(trainer, NTTrainer):
+            raise TypeError(
+                f"Trainer must be an instance of {NTTrainer.__name__}, "
+                f"but got {type(trainer).__name__}"
+            )
+
+        current_metrics = self.current_metrics(trainer)
+        filepath = self.resolve_checkpoint_path(current_metrics)
+
+        # Get all existing checkpoint metadata
+        existing_metas = _get_checkpoint_metadata(self.dirpath)
+
+        # Determine which checkpoints to remove
+        to_remove: list[CheckpointMetadata] = []
+        should_save = True
+
+        # Check if we should save this checkpoint
+        if (topk := self.config.topk) != "all" and len(existing_metas) >= topk:
+            # Generate hypothetical metadata for the current checkpoint
+            hypothetical_meta = _generate_checkpoint_metadata(
+                trainer=trainer,
+                checkpoint_path=filepath,
+                metadata_path=filepath.with_suffix(CheckpointMetadata.PATH_SUFFIX),
+                compute_checksum=False,
+            )
+
+            # Add the hypothetical metadata to the list and sort
+            metas = _sort_checkpoint_metadata(
+                [*existing_metas, hypothetical_meta],
+                self.topk_sort_key,
+                self.topk_sort_reverse(),
+            )
+
+            # If the hypothetical metadata is not in the top-k, skip saving
+            if hypothetical_meta not in metas[:topk]:
+                log.debug(
+                    f"Skipping checkpoint save: would not make top {topk} "
+                    f"based on {self.topk_sort_key.__name__}"
+                )
+                should_save = False
+            else:
+                # Determine which existing checkpoints to remove
+                to_remove = metas[topk:]
+                assert hypothetical_meta not in to_remove, (
+                    "Hypothetical metadata should not be in the to_remove list."
+                )
+                log.debug(
+                    f"Removing checkpoints: {[meta.checkpoint_filename for meta in to_remove]} "
+                    f"and saving the new checkpoint: {hypothetical_meta.checkpoint_filename}"
+                )
+
+        # Only save if it would make it into the top-k
+        if should_save:
+            # Save the new checkpoint
+            trainer.save_checkpoint(
+                filepath,
+                weights_only=self.config.save_weights_only,
+            )
+
+            if trainer.is_global_zero:
+                # Remove old checkpoints that should be deleted
+                if to_remove:
+                    _remove_checkpoints(trainer, self.dirpath, to_remove)
+
+                # Update the symlink to point to the best checkpoint
+                _update_symlink(
+                    self.dirpath,
+                    self.symlink_path(),
+                    self.topk_sort_key,
+                    self.topk_sort_reverse(),
+                )
+
+        # Barrier to ensure all processes have completed checkpoint operations
+        trainer.strategy.barrier()
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+        )
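The shipped subclasses live in best_checkpoint.py and last_checkpoint.py (unchanged in this release); for orientation, here is a small hypothetical subclass sketch showing how the abstract hooks fit together. The class names, the "val/loss" metric, and the assumption that CheckpointMetadata exposes a metrics mapping are illustrative, not part of the package:

# Hypothetical sketch (not in nshtrainer): keep the top-k checkpoints with the
# lowest validation loss, saving at the end of every validation epoch.
class LowestLossCheckpointConfig(BaseCheckpointCallbackConfig):
    def create_checkpoint(self, trainer_config, dirpath):
        return LowestLossCheckpoint(self, dirpath)


class LowestLossCheckpoint(CheckpointBase[LowestLossCheckpointConfig]):
    def name(self) -> str:
        return "lowest_loss"  # checkpoints are written under <dirpath>/lowest_loss/

    def default_filename(self) -> str:
        return "epoch{epoch}-step{step}"  # formatted by resolve_checkpoint_path()

    def topk_sort_key(self, metadata):
        # Assumes the metadata records logged metrics; missing values sort last.
        return metadata.metrics.get("val/loss", float("inf"))

    def topk_sort_reverse(self) -> bool:
        return False  # ascending: smaller loss ranks higher

    def on_validation_end(self, trainer, pl_module):
        self.save_checkpoints(trainer)

save_checkpoints() then handles the top-k comparison, removal of displaced checkpoints, and the lowest_loss.ckpt symlink on its own.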
nshtrainer-1.4.1/src/nshtrainer/callbacks/log_epoch.py (new file)

@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import logging
+import math
+from typing import Any, Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from typing_extensions import final, override
+
+from .base import CallbackConfigBase, callback_registry
+
+log = logging.getLogger(__name__)
+
+
+@final
+@callback_registry.register
+class LogEpochCallbackConfig(CallbackConfigBase):
+    name: Literal["log_epoch"] = "log_epoch"
+
+    metric_name: str = "computed_epoch"
+    """The name of the metric to log the epoch as."""
+
+    train: bool = True
+    """Whether to log the epoch during training."""
+
+    val: bool = True
+    """Whether to log the epoch during validation."""
+
+    test: bool = True
+    """Whether to log the epoch during testing."""
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield LogEpochCallback(self)
+
+
+def _worker_fn(
+    trainer: Trainer,
+    pl_module: LightningModule,
+    num_batches_prop: str,
+    dataloader_idx: int | None = None,
+    *,
+    metric_name: str,
+):
+    if trainer.logger is None:
+        return
+
+    # If trainer.num_{training/val/test}_batches is not set or is nan/inf, we cannot calculate the epoch
+    if not (num_batches := getattr(trainer, num_batches_prop, None)):
+        log.warning(f"Trainer has no valid `{num_batches_prop}`. Cannot log epoch.")
+        return
+
+    # If the trainer has a dataloader_idx, num_batches is a list of num_batches for each dataloader.
+    if dataloader_idx is not None:
+        assert isinstance(num_batches, list), (
+            f"Expected num_batches to be a list, got {type(num_batches)}"
+        )
+        assert 0 <= dataloader_idx < len(num_batches), (
+            f"Expected dataloader_idx to be between 0 and {len(num_batches)}, got {dataloader_idx}"
+        )
+        num_batches = num_batches[dataloader_idx]
+
+    if (
+        not isinstance(num_batches, (int, float))
+        or math.isnan(num_batches)
+        or math.isinf(num_batches)
+    ):
+        log.warning(
+            f"Trainer has no valid `{num_batches_prop}` (got {num_batches=}). Cannot log epoch."
+        )
+        return
+
+    epoch = pl_module.global_step / num_batches
+    pl_module.log(metric_name, epoch, on_step=True, on_epoch=False)
+
+
+class LogEpochCallback(Callback):
+    def __init__(self, config: LogEpochCallbackConfig):
+        super().__init__()
+
+        self.config = config
+
+    @override
+    def on_train_batch_start(
+        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
+    ):
+        if trainer.logger is None or not self.config.train:
+            return
+
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_training_batches",
+            metric_name=self.config.metric_name,
+        )
+
+    @override
+    def on_validation_batch_start(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if trainer.logger is None or not self.config.val:
+            return
+
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_val_batches",
+            dataloader_idx=dataloader_idx,
+            metric_name=self.config.metric_name,
+        )
+
+    @override
+    def on_test_batch_start(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if trainer.logger is None or not self.config.test:
+            return
+
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_test_batches",
+            dataloader_idx=dataloader_idx,
+            metric_name=self.config.metric_name,
+        )
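This callback logs a fractional epoch counter rather than the integer trainer.current_epoch. A short hedged sketch of what it produces, using only names introduced in the new file above (how the config is attached to a TrainerConfig is not shown in this diff):

# Hedged sketch of the logged value.
config = LogEpochCallbackConfig(metric_name="computed_epoch", val=False, test=False)
callback = LogEpochCallback(config)

# On every training batch, _worker_fn logs
#     computed_epoch = pl_module.global_step / trainer.num_training_batches
# e.g. global_step=250 with 1000 training batches per epoch logs 0.25.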
{nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/callbacks/lr_monitor.py

@@ -1,12 +1,15 @@
 from __future__ import annotations

+import logging
 from typing import Literal

 from lightning.pytorch.callbacks import LearningRateMonitor
-from typing_extensions import final
+from typing_extensions import final, override

 from .base import CallbackConfigBase, callback_registry

+log = logging.getLogger(__name__)
+

 @final
 @callback_registry.register
@@ -28,7 +31,12 @@ class LearningRateMonitorConfig(CallbackConfigBase):
     Option to also log the weight decay values of the optimizer. Defaults to False.
     """

+    @override
     def create_callbacks(self, trainer_config):
+        if not list(trainer_config.enabled_loggers()):
+            log.warning("No loggers enabled. LearningRateMonitor will not be used.")
+            return
+
         yield LearningRateMonitor(
             logging_interval=self.logging_interval,
             log_momentum=self.log_momentum,
{nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/_config.py

@@ -419,7 +419,7 @@ class DirectoryConfig(C.Config):

 class TrainerConfig(C.Config):
     # region Active Run Configuration
-    id: Annotated[str, C.AllowMissing()] = C.MISSING
+    id: C.AllowMissing[str] = C.MISSING
     """ID of the run."""
     name: list[str] = []
     """Run name in parts. Full name is constructed by joining the parts with spaces."""
@@ -717,8 +717,9 @@ class TrainerConfig(C.Config):

     auto_set_default_root_dir: bool = True
     """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
-    save_checkpoint_metadata: bool = True
-    """If enabled, will save additional metadata whenever a checkpoint is saved."""
+    save_checkpoint_metadata: Literal[True] = True
+    """Will save additional metadata whenever a checkpoint is saved.
+    This is a core feature of nshtrainer and cannot be disabled."""
     auto_set_debug_flag: DebugFlagCallbackConfig | None = DebugFlagCallbackConfig()
     """If enabled, will automatically set the debug flag to True if:
     - The trainer is running in fast_dev_run mode.
@@ -1308,6 +1309,11 @@ class TrainerConfig(C.Config):
         if self.barebones and self.shared_parameters:
             raise ValueError("shared_parameters is not supported under barebones mode")

+        if not self.save_checkpoint_metadata:
+            raise ValueError(
+                "save_checkpoint_metadata must be True. This is a core feature of nshtrainer and cannot be disabled."
+            )
+
     def _nshtrainer_set_id_if_missing(self):
         """
         Set the ID for the configuration object if it is missing.
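Taken together, the Literal[True] annotation and the validation above turn the old opt-out into a hard error: checkpoint metadata is always written. A hedged illustration (the exact exception raised for the type-level rejection depends on nshconfig's validation):

# Hedged sketch: the default still works, the opt-out no longer does.
config = TrainerConfig(name=["example"])                                  # ok, defaults to True
config = TrainerConfig(name=["example"], save_checkpoint_metadata=True)  # ok, explicit
# TrainerConfig(name=["example"], save_checkpoint_metadata=False)
#   -> rejected, either as a Literal[True] validation error from nshconfig
#      or via the explicit ValueError added in the hunk above.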
{nshtrainer-1.3.6 → nshtrainer-1.4.1}/src/nshtrainer/trainer/trainer.py

@@ -45,6 +45,9 @@ patch_log_hparams_function()


 class Trainer(LightningTrainer):
+    profiler: Profiler
+    """Profiler used for profiling the training process."""
+
     CHECKPOINT_HYPER_PARAMS_KEY = "trainer_hyper_parameters"

     @property
@@ -469,6 +472,11 @@ class Trainer(LightningTrainer):
         weights_only: bool = False,
         storage_options: Any | None = None,
     ):
+        assert self.hparams.save_checkpoint_metadata, (
+            "Checkpoint metadata is not enabled. "
+            "Please set `hparams.save_checkpoint_metadata=True`."
+        )
+
         filepath = Path(filepath)

         if self.model is None:
@@ -476,7 +484,7 @@ class Trainer(LightningTrainer):
                 "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
                 " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
             )
-        with self.profiler.profile("save_checkpoint"):  # type: ignore
+        with self.profiler.profile("save_checkpoint"):
             checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
             # Update the checkpoint for the trainer hyperparameters
             checkpoint[self.CHECKPOINT_HYPER_PARAMS_KEY] = self.hparams.model_dump(
@@ -489,7 +497,7 @@ class Trainer(LightningTrainer):

         # Save the checkpoint metadata
         metadata_path = None
-        if self.hparams.save_checkpoint_metadata and self.is_global_zero:
+        if self.is_global_zero:
             # Generate the metadata and write to disk
             metadata_path = write_checkpoint_metadata(self, filepath)
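Net effect of the trainer changes: save_checkpoint() now asserts that metadata is enabled and always writes a metadata file on the global-zero rank. A hedged sketch of the resulting on-disk layout; the metadata file's exact suffix comes from CheckpointMetadata.PATH_SUFFIX, which is not shown in this diff:

# Hedged sketch: every checkpoint saved through nshtrainer's Trainer gets a
# sidecar metadata file written by write_checkpoint_metadata() on rank zero.
trainer.save_checkpoint("checkpoints/last.ckpt")
# checkpoints/last.ckpt                               <- Lightning checkpoint
# checkpoints/last<CheckpointMetadata.PATH_SUFFIX>    <- nshtrainer metadata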