nshtrainer 1.3.5__tar.gz → 1.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/PKG-INFO +1 -1
  2. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/pyproject.toml +1 -1
  3. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/__init__.py +14 -0
  4. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/_checkpoint/metadata.py +4 -1
  5. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/_hf_hub.py +3 -0
  6. nshtrainer-1.4.0/src/nshtrainer/callbacks/checkpoint/_base.py +320 -0
  7. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/lr_monitor.py +9 -1
  8. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/__init__.py +1 -5
  9. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/__init__.py +4 -2
  10. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/_config/__init__.py +4 -2
  11. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/_config.py +525 -73
  12. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/trainer.py +11 -2
  13. nshtrainer-1.3.5/src/nshtrainer/_directory.py +0 -72
  14. nshtrainer-1.3.5/src/nshtrainer/callbacks/checkpoint/_base.py +0 -187
  15. nshtrainer-1.3.5/src/nshtrainer/configs/_directory/__init__.py +0 -15
  16. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/README.md +0 -0
  17. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/.nshconfig.generated.json +0 -0
  18. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/_callback.py +0 -0
  19. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  20. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  21. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/__init__.py +0 -0
  22. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  23. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/base.py +0 -0
  24. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  25. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  26. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  27. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  28. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  29. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  30. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/distributed_prediction_writer.py +0 -0
  31. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  32. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/ema.py +0 -0
  33. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  34. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  35. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/interval.py +0 -0
  36. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  37. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/metric_validation.py +0 -0
  38. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  39. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  40. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  41. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  42. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/timer.py +0 -0
  43. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  44. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  45. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/.gitattributes +0 -0
  46. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  47. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  48. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  49. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
  50. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  51. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  52. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  53. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  54. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  55. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  56. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  57. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  58. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  59. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py +0 -0
  60. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  61. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  62. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  63. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  64. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  65. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  66. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
  67. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  68. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  69. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  70. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  71. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  72. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  73. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  74. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  75. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  76. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
  77. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  78. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  79. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  80. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  81. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
  82. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  83. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  84. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  85. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  86. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/nn/__init__.py +0 -0
  87. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  88. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  89. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/nn/rng/__init__.py +0 -0
  90. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  91. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  92. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  93. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  94. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  95. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  96. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
  97. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
  98. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
  99. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
  100. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
  101. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
  102. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
  103. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
  104. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  105. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/util/__init__.py +0 -0
  106. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  107. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  108. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  109. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  110. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/data/__init__.py +0 -0
  111. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  112. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/data/datamodule.py +0 -0
  113. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/data/transform.py +0 -0
  114. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/loggers/__init__.py +0 -0
  115. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/loggers/actsave.py +0 -0
  116. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/loggers/base.py +0 -0
  117. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/loggers/csv.py +0 -0
  118. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  119. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/loggers/wandb.py +0 -0
  120. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  121. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/lr_scheduler/base.py +0 -0
  122. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  123. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  124. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/metrics/__init__.py +0 -0
  125. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/metrics/_config.py +0 -0
  126. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/model/__init__.py +0 -0
  127. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/model/base.py +0 -0
  128. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/model/mixins/callback.py +0 -0
  129. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/model/mixins/debug.py +0 -0
  130. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/model/mixins/logger.py +0 -0
  131. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/nn/__init__.py +0 -0
  132. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/nn/mlp.py +0 -0
  133. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/nn/module_dict.py +0 -0
  134. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/nn/module_list.py +0 -0
  135. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  136. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/nn/rng.py +0 -0
  137. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/optimizer.py +0 -0
  138. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/profiler/__init__.py +0 -0
  139. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/profiler/_base.py +0 -0
  140. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/profiler/advanced.py +0 -0
  141. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/profiler/pytorch.py +0 -0
  142. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/profiler/simple.py +0 -0
  143. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/__init__.py +0 -0
  144. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/_distributed_prediction_result.py +0 -0
  145. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/_log_hparams.py +0 -0
  146. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  147. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/accelerator.py +0 -0
  148. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
  149. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/base.py +0 -0
  150. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/environment.py +0 -0
  151. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/io.py +0 -0
  152. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
  153. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/precision.py +0 -0
  154. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  155. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/trainer/strategy.py +0 -0
  156. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/_environment_info.py +0 -0
  157. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/bf16.py +0 -0
  158. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/code_upload.py +0 -0
  159. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/config/__init__.py +0 -0
  160. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/config/dtype.py +0 -0
  161. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/config/duration.py +0 -0
  162. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/environment.py +0 -0
  163. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/path.py +0 -0
  164. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/seed.py +0 -0
  165. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/slurm.py +0 -0
  166. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/typed.py +0 -0
  167. {nshtrainer-1.3.5 → nshtrainer-1.4.0}/src/nshtrainer/util/typing_utils.py +0 -0

--- nshtrainer-1.3.5/PKG-INFO
+++ nshtrainer-1.4.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.3.5
+Version: 1.4.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com

--- nshtrainer-1.3.5/pyproject.toml
+++ nshtrainer-1.4.0/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nshtrainer"
-version = "1.3.5"
+version = "1.4.0"
 description = ""
 authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
 requires-python = ">=3.10,<4.0"

--- nshtrainer-1.3.5/src/nshtrainer/__init__.py
+++ nshtrainer-1.4.0/src/nshtrainer/__init__.py
@@ -19,3 +19,17 @@ try:
     from . import configs as configs
 except BaseException:
     pass
+
+try:
+    from importlib.metadata import PackageNotFoundError, version
+except ImportError:
+    # For Python <3.8
+    from importlib_metadata import (  # pyright: ignore[reportMissingImports]
+        PackageNotFoundError,
+        version,
+    )
+
+try:
+    __version__ = version(__name__)
+except PackageNotFoundError:
+    __version__ = "unknown"
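
The net effect of this hunk is that the installed package now reports its own version, resolved from distribution metadata at import time. A minimal check, assuming nshtrainer 1.4.0 is installed from a built distribution:

    import nshtrainer

    # Prints "1.4.0" for an installed 1.4.0 dist; falls back to "unknown"
    # when the distribution metadata cannot be found.
    print(nshtrainer.__version__)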

--- nshtrainer-1.3.5/src/nshtrainer/_checkpoint/metadata.py
+++ nshtrainer-1.4.0/src/nshtrainer/_checkpoint/metadata.py
@@ -85,6 +85,7 @@ def _generate_checkpoint_metadata(
     trainer: Trainer,
     checkpoint_path: Path,
     metadata_path: Path,
+    compute_checksum: bool = True,
 ):
     checkpoint_timestamp = datetime.datetime.now()
     start_timestamp = trainer.start_time()
@@ -105,7 +106,9 @@ def _generate_checkpoint_metadata(
         # moving the checkpoint directory
         checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
-        checkpoint_checksum=compute_file_checksum(checkpoint_path),
+        checkpoint_checksum=compute_file_checksum(checkpoint_path)
+        if compute_checksum
+        else "",
         run_id=trainer.hparams.id,
         name=trainer.hparams.full_name,
         project=trainer.hparams.project,
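
The new compute_checksum keyword lets callers skip the file hash, which matters when the metadata describes a checkpoint that has not been written to disk yet; the new checkpoint base class below passes compute_checksum=False for exactly that case. A standalone sketch of the guard, using a hypothetical helper name and hash algorithm rather than nshtrainer's compute_file_checksum:

    import hashlib
    from pathlib import Path


    def file_checksum(path: Path, compute_checksum: bool = True) -> str:
        # Hypothetical stand-in: return an empty string instead of hashing
        # when the caller opts out, mirroring the guarded call above.
        if not compute_checksum:
            return ""
        return hashlib.sha256(path.read_bytes()).hexdigest()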

--- nshtrainer-1.3.5/src/nshtrainer/_hf_hub.py
+++ nshtrainer-1.4.0/src/nshtrainer/_hf_hub.py
@@ -91,6 +91,9 @@ class HuggingFaceHubConfig(CallbackConfigBase):
 
     @override
     def create_callbacks(self, trainer_config):
+        if not self:
+            return
+
         # Attempt to login. If it fails, we'll log a warning or error based on the configuration.
         try:
             api = _api(self.token)
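
With the added guard, create_callbacks becomes a no-op when the config evaluates falsy (i.e. the Hugging Face Hub integration is disabled) instead of attempting a login. A hedged sketch of the pattern; the enabled field and the __bool__ definition are illustrative assumptions, not nshtrainer's actual config surface:

    from collections.abc import Iterator


    class HubConfigSketch:
        def __init__(self, enabled: bool = True):
            self.enabled = enabled  # assumed field, for illustration only

        def __bool__(self) -> bool:
            return self.enabled

        def create_callbacks(self) -> Iterator[object]:
            if not self:  # mirrors the new early return
                return
            yield object()  # stand-in for the real callback


    assert list(HubConfigSketch(enabled=False).create_callbacks()) == []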

--- /dev/null
+++ nshtrainer-1.4.0/src/nshtrainer/callbacks/checkpoint/_base.py
@@ -0,0 +1,320 @@
+from __future__ import annotations
+
+import logging
+import string
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
+
+import numpy as np
+import torch
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import override
+
+from ..._checkpoint.metadata import CheckpointMetadata, _generate_checkpoint_metadata
+from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
+from ..base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ...trainer._config import TrainerConfig
+
+
+log = logging.getLogger(__name__)
+
+
+class _FormatDict(dict):
+    """A dictionary that returns an empty string for missing keys when formatting."""
+
+    def __missing__(self, key):
+        log.debug(
+            f"Missing format key '{key}' in checkpoint filename, using empty string"
+        )
+        return ""
+
+
+def _get_checkpoint_metadata(dirpath: Path) -> list[CheckpointMetadata]:
+    """Get all checkpoint metadata from a directory."""
+    return [
+        CheckpointMetadata.from_file(p)
+        for p in dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+        if p.is_file() and not p.is_symlink()
+    ]
+
+
+def _sort_checkpoint_metadata(
+    metas: list[CheckpointMetadata],
+    key_fn: Callable[[CheckpointMetadata], Any],
+    reverse: bool = False,
+) -> list[CheckpointMetadata]:
+    """Sort checkpoint metadata by the given key function."""
+    return sorted(metas, key=key_fn, reverse=reverse)
+
+
+def _remove_checkpoints(
+    trainer: Trainer,
+    dirpath: Path,
+    metas_to_remove: list[CheckpointMetadata],
+) -> None:
+    """Remove checkpoint files and their metadata."""
+    for meta in metas_to_remove:
+        ckpt_path = dirpath / meta.checkpoint_filename
+        if not ckpt_path.exists():
+            log.warning(
+                f"Checkpoint file not found: {ckpt_path}\n"
+                "Skipping removal of the checkpoint metadata."
+            )
+            continue
+
+        remove_checkpoint(trainer, ckpt_path, metadata=True)
+        log.debug(f"Removed checkpoint: {ckpt_path}")
+
+
+def _update_symlink(
+    dirpath: Path,
+    symlink_path: Path | None,
+    sort_key_fn: Callable[[CheckpointMetadata], Any],
+    sort_reverse: bool,
+) -> None:
+    """Update symlink to point to the best checkpoint."""
+    if symlink_path is None:
+        return
+
+    # Get all checkpoint metadata after any removals
+    remaining_metas = _get_checkpoint_metadata(dirpath)
+
+    if remaining_metas:
+        # Sort by the key function
+        remaining_metas = _sort_checkpoint_metadata(
+            remaining_metas, sort_key_fn, sort_reverse
+        )
+
+        # Link to the best checkpoint
+        best_meta = remaining_metas[0]
+        best_filepath = dirpath / best_meta.checkpoint_filename
+        link_checkpoint(best_filepath, symlink_path, metadata=True)
+        log.debug(f"Updated symlink {symlink_path.name} -> {best_filepath.name}")
+    else:
+        log.warning(f"No checkpoints found in {dirpath} to create symlink.")
+
+
+class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str | None = None
+    """Checkpoint filename. This must not include the extension.
+    If None, the default filename will be used."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    save_symlink: bool = True
+    """Whether to create a symlink to the saved checkpoint."""
+
+    topk: int | Literal["all"] = 1
+    """The number of checkpoints to keep."""
+
+    @abstractmethod
+    def create_checkpoint(
+        self,
+        trainer_config: TrainerConfig,
+        dirpath: Path,
+    ) -> "CheckpointBase | None": ...
+
+    @override
+    def create_callbacks(self, trainer_config):
+        dirpath = Path(
+            self.dirpath
+            or trainer_config.directory.resolve_subdirectory(
+                trainer_config.id, "checkpoint"
+            )
+        )
+
+        if (callback := self.create_checkpoint(trainer_config, dirpath)) is not None:
+            yield callback
+
+
+TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
+
+
+class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
+    def __init__(self, config: TConfig, dirpath: Path):
+        super().__init__()
+
+        self.config = config
+        self.dirpath = dirpath / self.name()
+        self.dirpath.mkdir(parents=True, exist_ok=True)
+        self.symlink_dirpath = dirpath
+
+    @abstractmethod
+    def default_filename(self) -> str: ...
+
+    @abstractmethod
+    def name(self) -> str: ...
+
+    def extension(self) -> str:
+        return ".ckpt"
+
+    @abstractmethod
+    def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
+
+    @abstractmethod
+    def topk_sort_reverse(self) -> bool: ...
+
+    def symlink_path(self):
+        if not self.config.save_symlink:
+            return None
+
+        return self.symlink_dirpath / f"{self.name()}{self.extension()}"
+
+    def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
+        if (filename := self.config.filename) is None:
+            filename = self.default_filename()
+
+        # Extract all field names from the format string
+        field_names = [
+            fname for _, fname, _, _ in string.Formatter().parse(filename) if fname
+        ]
+
+        # Filter current_metrics to only include keys that are in the format string
+        format_dict = {k: v for k, v in current_metrics.items() if k in field_names}
+
+        try:
+            formatted_filename = filename.format(**format_dict)
+        except KeyError as e:
+            log.warning(
+                f"Missing key {e} in {filename=} with {format_dict=}. Using default values."
+            )
+            # Provide a simple fallback for missing keys
+            formatted_filename = string.Formatter().vformat(
+                filename, (), _FormatDict(format_dict)
+            )
+
+        return self.dirpath / f"{formatted_filename}{self.extension()}"
+
+    def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
+        current_metrics: dict[str, Any] = {
+            "epoch": trainer.current_epoch,
+            "step": trainer.global_step,
+        }
+
+        for name, value in trainer.callback_metrics.items():
+            match value:
+                case torch.Tensor() if value.numel() == 1:
+                    value = value.detach().cpu().item()
+                case np.ndarray() if value.size == 1:
+                    value = value.item()
+                case _:
+                    pass
+
+            current_metrics[name] = value
+
+        log.debug(
+            f"Current metrics: {current_metrics}, {trainer.callback_metrics=}, {trainer.logged_metrics=}"
+        )
+        return current_metrics
+
+    def save_checkpoints(self, trainer: Trainer):
+        log.debug(
+            f"{type(self).__name__}.save_checkpoints() called at {trainer.current_epoch=}, {trainer.global_step=}"
+        )
+        # Also print out the current stack trace for debugging
+        if log.isEnabledFor(logging.DEBUG):
+            import traceback
+
+            stack = traceback.extract_stack()
+            log.debug(f"Stack trace: {''.join(traceback.format_list(stack))}")
+
+        if self._should_skip_saving_checkpoint(trainer):
+            return
+
+        from ...trainer import Trainer as NTTrainer
+
+        if not isinstance(trainer, NTTrainer):
+            raise TypeError(
+                f"Trainer must be an instance of {NTTrainer.__name__}, "
+                f"but got {type(trainer).__name__}"
+            )
+
+        current_metrics = self.current_metrics(trainer)
+        filepath = self.resolve_checkpoint_path(current_metrics)
+
+        # Get all existing checkpoint metadata
+        existing_metas = _get_checkpoint_metadata(self.dirpath)
+
+        # Determine which checkpoints to remove
+        to_remove: list[CheckpointMetadata] = []
+        should_save = True
+
+        # Check if we should save this checkpoint
+        if (topk := self.config.topk) != "all" and len(existing_metas) >= topk:
+            # Generate hypothetical metadata for the current checkpoint
+            hypothetical_meta = _generate_checkpoint_metadata(
+                trainer=trainer,
+                checkpoint_path=filepath,
+                metadata_path=filepath.with_suffix(CheckpointMetadata.PATH_SUFFIX),
+                compute_checksum=False,
+            )
+
+            # Add the hypothetical metadata to the list and sort
+            metas = _sort_checkpoint_metadata(
+                [*existing_metas, hypothetical_meta],
+                self.topk_sort_key,
+                self.topk_sort_reverse(),
+            )
+
+            # If the hypothetical metadata is not in the top-k, skip saving
+            if hypothetical_meta not in metas[:topk]:
+                log.debug(
+                    f"Skipping checkpoint save: would not make top {topk} "
+                    f"based on {self.topk_sort_key.__name__}"
+                )
+                should_save = False
+            else:
+                # Determine which existing checkpoints to remove
+                to_remove = metas[topk:]
+                assert hypothetical_meta not in to_remove, (
+                    "Hypothetical metadata should not be in the to_remove list."
+                )
+                log.debug(
+                    f"Removing checkpoints: {[meta.checkpoint_filename for meta in to_remove]} "
+                    f"and saving the new checkpoint: {hypothetical_meta.checkpoint_filename}"
+                )
+
+        # Only save if it would make it into the top-k
+        if should_save:
+            # Save the new checkpoint
+            trainer.save_checkpoint(
+                filepath,
+                weights_only=self.config.save_weights_only,
+            )
+
+            if trainer.is_global_zero:
+                # Remove old checkpoints that should be deleted
+                if to_remove:
+                    _remove_checkpoints(trainer, self.dirpath, to_remove)
+
+                # Update the symlink to point to the best checkpoint
+                _update_symlink(
+                    self.dirpath,
+                    self.symlink_path(),
+                    self.topk_sort_key,
+                    self.topk_sort_reverse(),
+                )
+
+            # Barrier to ensure all processes have completed checkpoint operations
+            trainer.strategy.barrier()
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+        )
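
To show how the new base is meant to be extended, here is a minimal sketch of a concrete config/callback pair. It is not part of the package: the shipped implementations live in nshtrainer.callbacks.checkpoint (best_checkpoint.py, last_checkpoint.py) and may differ, and the checkpoint_timestamp attribute used for sorting is an assumed CheckpointMetadata field:

    from __future__ import annotations

    from pathlib import Path
    from typing import Any

    from nshtrainer._checkpoint.metadata import CheckpointMetadata
    from nshtrainer.callbacks.checkpoint._base import (
        BaseCheckpointCallbackConfig,
        CheckpointBase,
    )


    class NewestCheckpointConfig(BaseCheckpointCallbackConfig):
        """Hypothetical config: keep the `topk` most recent checkpoints."""

        def create_checkpoint(self, trainer_config, dirpath: Path):
            return NewestCheckpoint(self, dirpath)


    class NewestCheckpoint(CheckpointBase[NewestCheckpointConfig]):
        def name(self) -> str:
            return "newest"

        def default_filename(self) -> str:
            return "epoch{epoch}-step{step}"

        def topk_sort_key(self, metadata: CheckpointMetadata) -> Any:
            # Assumed field: rank checkpoints by when they were written.
            return metadata.checkpoint_timestamp

        def topk_sort_reverse(self) -> bool:
            return True  # newest first

        def on_train_epoch_end(self, trainer, pl_module):
            # CheckpointBase.save_checkpoints handles top-k pruning, metadata,
            # and the symlink update; the subclass only decides when to call it.
            self.save_checkpoints(trainer)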

--- nshtrainer-1.3.5/src/nshtrainer/callbacks/lr_monitor.py
+++ nshtrainer-1.4.0/src/nshtrainer/callbacks/lr_monitor.py
@@ -1,12 +1,15 @@
 from __future__ import annotations
 
+import logging
 from typing import Literal
 
 from lightning.pytorch.callbacks import LearningRateMonitor
-from typing_extensions import final
+from typing_extensions import final, override
 
 from .base import CallbackConfigBase, callback_registry
 
+log = logging.getLogger(__name__)
+
 
 @final
 @callback_registry.register
@@ -28,7 +31,12 @@ class LearningRateMonitorConfig(CallbackConfigBase):
     Option to also log the weight decay values of the optimizer. Defaults to False.
     """
 
+    @override
     def create_callbacks(self, trainer_config):
+        if not list(trainer_config.enabled_loggers()):
+            log.warning("No loggers enabled. LearningRateMonitor will not be used.")
+            return
+
         yield LearningRateMonitor(
             logging_interval=self.logging_interval,
             log_momentum=self.log_momentum,
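
The guard means LearningRateMonitor is only instantiated when at least one logger is enabled, since the callback has nowhere to log otherwise. A small illustration of the same check outside of nshtrainer's config machinery:

    from lightning.pytorch.callbacks import LearningRateMonitor


    def maybe_lr_monitor(enabled_loggers: list[object]) -> LearningRateMonitor | None:
        # Mirrors the new early return: skip the callback when no logger is configured.
        if not enabled_loggers:
            return None
        return LearningRateMonitor(logging_interval="epoch")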
@@ -5,7 +5,6 @@ __codegen__ = True
5
5
  from nshtrainer import MetricConfig as MetricConfig
6
6
  from nshtrainer import TrainerConfig as TrainerConfig
7
7
  from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
8
- from nshtrainer._directory import DirectoryConfig as DirectoryConfig
9
8
  from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
10
9
  from nshtrainer._hf_hub import (
11
10
  HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
@@ -126,9 +125,9 @@ from nshtrainer.trainer._config import (
126
125
  CheckpointCallbackConfig as CheckpointCallbackConfig,
127
126
  )
128
127
  from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
128
+ from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
129
129
  from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
130
130
  from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
131
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
132
131
  from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
133
132
  from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
134
133
  from nshtrainer.trainer.accelerator import (
@@ -227,7 +226,6 @@ from nshtrainer.util.config import EpochsConfig as EpochsConfig
227
226
  from nshtrainer.util.config import StepsConfig as StepsConfig
228
227
 
229
228
  from . import _checkpoint as _checkpoint
230
- from . import _directory as _directory
231
229
  from . import _hf_hub as _hf_hub
232
230
  from . import callbacks as callbacks
233
231
  from . import loggers as loggers
@@ -338,7 +336,6 @@ __all__ = [
338
336
  "RpropConfig",
339
337
  "SGDConfig",
340
338
  "SLURMEnvironmentPlugin",
341
- "SanityCheckingConfig",
342
339
  "SharedParametersCallbackConfig",
343
340
  "SiLUNonlinearityConfig",
344
341
  "SigmoidNonlinearityConfig",
@@ -367,7 +364,6 @@ __all__ = [
367
364
  "XLAEnvironmentPlugin",
368
365
  "XLAPluginConfig",
369
366
  "_checkpoint",
370
- "_directory",
371
367
  "_hf_hub",
372
368
  "accelerator_registry",
373
369
  "callback_registry",

--- nshtrainer-1.3.5/src/nshtrainer/configs/trainer/__init__.py
+++ nshtrainer-1.4.0/src/nshtrainer/configs/trainer/__init__.py
@@ -22,6 +22,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -51,7 +54,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
 from nshtrainer.trainer._config import (
     RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
 )
-from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
@@ -152,6 +154,7 @@ __all__ = [
     "DebugFlagCallbackConfig",
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
+    "DirectorySetupCallbackConfig",
     "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "EarlyStoppingCallbackConfig",
@@ -180,7 +183,6 @@ __all__ = [
     "ProfilerConfig",
     "RLPSanityChecksCallbackConfig",
     "SLURMEnvironmentPlugin",
-    "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
     "StrategyConfig",
     "StrategyConfigBase",

--- nshtrainer-1.3.5/src/nshtrainer/configs/trainer/_config/__init__.py
+++ nshtrainer-1.4.0/src/nshtrainer/configs/trainer/_config/__init__.py
@@ -18,6 +18,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -48,7 +51,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
 from nshtrainer.trainer._config import (
     RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
 )
-from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
@@ -70,6 +72,7 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
+    "DirectorySetupCallbackConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",
@@ -86,7 +89,6 @@ __all__ = [
     "PluginConfig",
     "ProfilerConfig",
     "RLPSanityChecksCallbackConfig",
-    "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
     "StrategyConfig",
     "TensorboardLoggerConfig",