nshtrainer 1.0.0b44__tar.gz → 1.0.0b46__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/PKG-INFO +1 -1
  2. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/pyproject.toml +1 -1
  3. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/_checkpoint/metadata.py +20 -5
  4. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/_checkpoint/saver.py +6 -2
  5. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/checkpoint/_base.py +1 -1
  6. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/metric_validation.py +33 -18
  7. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -2
  8. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/nn/__init__.py +0 -1
  9. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/nn/mlp.py +60 -60
  10. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/trainer.py +0 -1
  11. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/util/path.py +2 -1
  12. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/README.md +0 -0
  13. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/.nshconfig.generated.json +0 -0
  14. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/__init__.py +0 -0
  15. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/_callback.py +0 -0
  16. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/_directory.py +0 -0
  17. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/_experimental/__init__.py +0 -0
  18. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/_hf_hub.py +0 -0
  19. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/__init__.py +0 -0
  20. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/actsave.py +0 -0
  21. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/base.py +0 -0
  22. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  23. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  24. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  25. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  26. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  27. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  28. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  29. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/ema.py +0 -0
  30. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  31. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  32. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/interval.py +0 -0
  33. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  34. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
  35. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  36. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/print_table.py +0 -0
  37. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  38. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  39. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/timer.py +0 -0
  40. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  41. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  42. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/.gitattributes +0 -0
  43. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/__init__.py +0 -0
  44. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  45. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  46. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  47. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  48. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
  49. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  50. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  51. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  52. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  53. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  54. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  55. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  56. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  57. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  58. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  59. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  60. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  61. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  62. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  63. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  64. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
  65. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  66. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  67. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  68. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  69. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  70. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  71. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  72. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  73. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  74. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
  75. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  76. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  77. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  78. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  79. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
  80. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  81. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  82. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  83. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  84. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/nn/__init__.py +0 -0
  85. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  86. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  87. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  88. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  89. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  90. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  91. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  92. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  93. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/trainer/__init__.py +0 -0
  94. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -0
  95. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
  96. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
  97. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
  98. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
  99. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
  100. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
  101. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
  102. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
  103. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/util/__init__.py +0 -0
  104. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  105. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  106. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  107. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  108. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/data/__init__.py +0 -0
  109. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  110. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/data/datamodule.py +0 -0
  111. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/data/transform.py +0 -0
  112. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/loggers/__init__.py +0 -0
  113. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/loggers/actsave.py +0 -0
  114. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/loggers/base.py +0 -0
  115. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/loggers/csv.py +0 -0
  116. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/loggers/tensorboard.py +0 -0
  117. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/loggers/wandb.py +0 -0
  118. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  119. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/lr_scheduler/base.py +0 -0
  120. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  121. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  122. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/metrics/__init__.py +0 -0
  123. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/metrics/_config.py +0 -0
  124. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/model/__init__.py +0 -0
  125. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/model/base.py +0 -0
  126. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/model/mixins/callback.py +0 -0
  127. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/model/mixins/debug.py +0 -0
  128. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/model/mixins/logger.py +0 -0
  129. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/nn/module_dict.py +0 -0
  130. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/nn/module_list.py +0 -0
  131. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/nn/nonlinearity.py +0 -0
  132. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/optimizer.py +0 -0
  133. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/profiler/__init__.py +0 -0
  134. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/profiler/_base.py +0 -0
  135. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/profiler/advanced.py +0 -0
  136. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/profiler/pytorch.py +0 -0
  137. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/profiler/simple.py +0 -0
  138. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/__init__.py +0 -0
  139. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/_config.py +0 -0
  140. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  141. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/accelerator.py +0 -0
  142. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
  143. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/plugin/base.py +0 -0
  144. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/plugin/environment.py +0 -0
  145. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/plugin/io.py +0 -0
  146. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
  147. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/plugin/precision.py +0 -0
  148. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/signal_connector.py +0 -0
  149. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/trainer/strategy.py +0 -0
  150. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/util/_environment_info.py +0 -0
  151. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/util/bf16.py +0 -0
  152. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/util/config/__init__.py +0 -0
  153. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/util/config/dtype.py +0 -0
  154. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/util/config/duration.py +0 -0
  155. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/util/environment.py +0 -0
  156. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/util/seed.py +0 -0
  157. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/util/slurm.py +0 -0
  158. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/util/typed.py +0 -0
  159. {nshtrainer-1.0.0b44 → nshtrainer-1.0.0b46}/src/nshtrainer/util/typing_utils.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.0.0b44
+Version: 1.0.0b46
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com

pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "1.0.0-beta44"
+version = "1.0.0-beta46"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"

src/nshtrainer/_checkpoint/metadata.py
@@ -139,15 +139,30 @@ def remove_checkpoint_metadata(checkpoint_path: Path):
         log.debug(f"Removed {path}")
 
 
+def remove_checkpoint_metadata_link(ckpt_link_path: Path):
+    path = _metadata_path(ckpt_link_path)
+    # If the metadata does not exist, we can safely ignore this
+    if not path.exists(follow_symlinks=False):
+        # This is EXTREMELY important here
+        # Otherwise, we've already deleted the file that the symlink
+        # used to point to, so this always returns False
+        log.debug(f"Metadata file does not exist: {path}")
+        return
+
+    # If the metadata exists, we can remove it
+    try:
+        path.unlink(missing_ok=True)
+    except Exception:
+        log.warning(f"Failed to remove {path}", exc_info=True)
+    else:
+        log.debug(f"Removed {path}")
+
+
 def link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
     # First, remove any existing metadata files
-    remove_checkpoint_metadata(linked_checkpoint_path)
+    remove_checkpoint_metadata_link(linked_checkpoint_path)
 
     # Link the metadata files to the new checkpoint
     path = _metadata_path(checkpoint_path)
     linked_path = _metadata_path(linked_checkpoint_path)
-
-    if not path.exists():
-        raise FileNotFoundError(f"Checkpoint path does not exist: {checkpoint_path}")
-
     try_symlink_or_copy(path, linked_path)
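
The `follow_symlinks=False` check is what lets `remove_checkpoint_metadata_link` clean up a metadata symlink whose target has already been deleted. A small standalone sketch of that situation (hypothetical file names in a temp directory, not part of the package; `Path.exists(follow_symlinks=...)` requires Python 3.12+):

    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "ckpt.metadata.json"  # hypothetical metadata file
        link = Path(tmp) / "last.metadata.json"    # hypothetical symlink to it
        target.write_text("{}")
        link.symlink_to(target)

        target.unlink()  # the real metadata file is removed first

        # Following the now-dangling symlink reports that nothing exists ...
        print(link.exists())                       # False
        # ... but the link itself is still present and must be unlinked.
        print(link.exists(follow_symlinks=False))  # True
        link.unlink(missing_ok=True)
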

src/nshtrainer/_checkpoint/saver.py
@@ -8,7 +8,11 @@ from pathlib import Path
 from lightning.pytorch import Trainer
 
 from ..util.path import try_symlink_or_copy
-from .metadata import link_checkpoint_metadata, remove_checkpoint_metadata
+from .metadata import (
+    link_checkpoint_metadata,
+    remove_checkpoint_metadata,
+    remove_checkpoint_metadata_link,
+)
 
 log = logging.getLogger(__name__)
 
@@ -39,7 +43,7 @@ def link_checkpoint(
         log.debug(f"Removed {linkpath=}")
 
     if metadata:
-        remove_checkpoint_metadata(linkpath)
+        remove_checkpoint_metadata_link(linkpath)
 
     try_symlink_or_copy(filepath, linkpath)
     if metadata:
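
With the import updated, `link_checkpoint` now removes only the stale metadata link before re-pointing it, rather than treating the link like a real checkpoint. A hedged usage sketch of this internal helper (parameter names inferred from the context lines above; the paths are hypothetical):

    from pathlib import Path

    from nshtrainer._checkpoint.saver import link_checkpoint

    ckpt_dir = Path("checkpoints")                 # hypothetical layout
    filepath = ckpt_dir / "epoch=3-step=400.ckpt"  # newest real checkpoint
    linkpath = ckpt_dir / "last.ckpt"              # rolling "last" pointer

    # Removes any existing last.ckpt and its metadata link, then symlinks
    # (or copies) the new checkpoint and its metadata into place.
    link_checkpoint(filepath, linkpath, metadata=True)
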

src/nshtrainer/callbacks/checkpoint/_base.py
@@ -160,7 +160,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
         filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
         trainer.save_checkpoint(filepath, self.config.save_weights_only)
 
-        if trainer.is_global_zero:
+        if trainer.hparams.save_checkpoint_metadata and trainer.is_global_zero:
             # Remove old checkpoints
             self.remove_old_checkpoints(trainer)
 

src/nshtrainer/callbacks/metric_validation.py
@@ -5,8 +5,8 @@ from typing import Literal
 
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from typing_extensions import final, override, assert_never
-
-from .._callback import NTCallbackBase
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import Callback
 
 from ..metrics import MetricConfig
 from .base import CallbackConfigBase, callback_registry
 
@@ -43,33 +43,48 @@ class MetricValidationCallbackConfig(CallbackConfigBase):
         yield MetricValidationCallback(self, metrics)
 
 
-class MetricValidationCallback(NTCallbackBase):
+class MetricValidationCallback(Callback):
     def __init__(
-        self, config: MetricValidationCallbackConfig, metrics: list[MetricConfig]
+        self,
+        config: MetricValidationCallbackConfig,
+        metrics: list[MetricConfig],
     ):
         super().__init__()
 
         self.config = config
         self.metrics = metrics
 
-    @override
-    def on_sanity_check_end(self, trainer, pl_module):
-        super().on_sanity_check_end(trainer, pl_module)
-
-        log.debug("Validating metrics...")
+    def _check_metrics(self, trainer: Trainer):
+        metric_names = ", ".join(metric.validation_monitor for metric in self.metrics)
+        log.info(f"Validating metrics: {metric_names}...")
         logged_metrics = set(trainer.logged_metrics.keys())
-        for metric in self.metrics:
-            if metric.validation_monitor in logged_metrics:
-                continue
 
+        invalid_metrics: list[str] = []
+        for metric in self.metrics:
+            if metric.validation_monitor not in logged_metrics:
+                invalid_metrics.append(metric.validation_monitor)
+
+        if invalid_metrics:
+            msg = (
+                f"The following metrics were not found in logged metrics: {invalid_metrics}\n"
+                f"List of logged metrics: {list(trainer.logged_metrics.keys())}"
+            )
             match self.config.error_behavior:
                 case "raise":
-                    raise MisconfigurationException(
-                        f"Metric '{metric.validation_monitor}' not found in logged metrics."
-                    )
+                    raise MisconfigurationException(msg)
                 case "warn":
-                    log.warning(
-                        f"Metric '{metric.validation_monitor}' not found in logged metrics."
-                    )
+                    log.warning(msg)
                 case _:
                     assert_never(self.config.error_behavior)
+
+    @override
+    def on_sanity_check_end(self, trainer, pl_module):
+        super().on_sanity_check_end(trainer, pl_module)
+
+        self._check_metrics(trainer)
+
+    @override
+    def on_validation_end(self, trainer, pl_module):
+        super().on_validation_end(trainer, pl_module)
+
+        self._check_metrics(trainer)
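
The refactor collects every missing monitor and reports them in a single message, instead of raising or warning once per metric inside the loop. A tiny standalone illustration of that aggregation (plain Python with hypothetical metric names; no Lightning objects involved):

    logged_metrics = {"val/loss", "val/accuracy"}    # what the trainer logged
    monitors = ["val/loss", "val/f1", "val/auroc"]   # what the metric configs expect

    invalid_metrics = [m for m in monitors if m not in logged_metrics]
    if invalid_metrics:
        msg = (
            f"The following metrics were not found in logged metrics: {invalid_metrics}\n"
            f"List of logged metrics: {list(logged_metrics)}"
        )
        print(msg)  # the callback raises or warns here, per config.error_behavior
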

src/nshtrainer/configs/trainer/trainer/__init__.py
@@ -4,14 +4,12 @@ __codegen__ = True
 
 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
 from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
-from nshtrainer.trainer.trainer import PluginConfigBase as PluginConfigBase
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
 from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig
 
 __all__ = [
     "AcceleratorConfigBase",
     "EnvironmentConfig",
-    "PluginConfigBase",
     "StrategyConfigBase",
     "TrainerConfig",
 ]

src/nshtrainer/nn/__init__.py
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from .mlp import MLP as MLP
 from .mlp import MLPConfig as MLPConfig
-from .mlp import MLPConfigDict as MLPConfigDict
 from .mlp import ResidualSequential as ResidualSequential
 from .mlp import custom_seed_context as custom_seed_context
 from .module_dict import TypedModuleDict as TypedModuleDict

src/nshtrainer/nn/mlp.py
@@ -3,12 +3,12 @@ from __future__ import annotations
 import contextlib
 import copy
 from collections.abc import Callable, Sequence
-from typing import Literal, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable
 
 import nshconfig as C
 import torch
 import torch.nn as nn
-from typing_extensions import TypedDict, override
+from typing_extensions import deprecated, override
 
 from .nonlinearity import NonlinearityConfig, NonlinearityConfigBase
 
@@ -26,29 +26,6 @@ class ResidualSequential(nn.Sequential):
         return input + super().forward(input)
 
 
-class MLPConfigDict(TypedDict):
-    bias: bool
-    """Whether to include bias terms in the linear layers."""
-
-    no_bias_scalar: bool
-    """Whether to exclude bias terms when the output dimension is 1."""
-
-    nonlinearity: NonlinearityConfig | None
-    """Activation function to use between layers."""
-
-    ln: bool | Literal["pre", "post"]
-    """Whether to apply layer normalization before or after the linear layers."""
-
-    dropout: float | None
-    """Dropout probability to apply between layers."""
-
-    residual: bool
-    """Whether to use residual connections between layers."""
-
-    seed: int | None
-    """Random seed to use for initialization. If None, the default Torch behavior is used."""
-
-
 class MLPConfig(C.Config):
     bias: bool = True
     """Whether to include bias terms in the linear layers."""
@@ -71,8 +48,15 @@ class MLPConfig(C.Config):
     seed: int | None = None
     """Random seed to use for initialization. If None, the default Torch behavior is used."""
 
-    def to_kwargs(self) -> MLPConfigDict:
-        kwargs: MLPConfigDict = {
+    @deprecated("Use `nt.nn.MLP(config=...)` instead.")
+    def create_module(
+        self,
+        dims: Sequence[int],
+        pre_layers: Sequence[nn.Module] = [],
+        post_layers: Sequence[nn.Module] = [],
+        linear_cls: LinearModuleConstructor = nn.Linear,
+    ):
+        kwargs: dict[str, Any] = {
             "bias": self.bias,
             "no_bias_scalar": self.no_bias_scalar,
             "nonlinearity": self.nonlinearity,
@@ -81,18 +65,9 @@ class MLPConfig(C.Config):
             "residual": self.residual,
             "seed": self.seed,
         }
-        return kwargs
-
-    def create_module(
-        self,
-        dims: Sequence[int],
-        pre_layers: Sequence[nn.Module] = [],
-        post_layers: Sequence[nn.Module] = [],
-        linear_cls: LinearModuleConstructor = nn.Linear,
-    ):
         return MLP(
             dims,
-            **self.to_kwargs(),
+            **kwargs,
             pre_layers=pre_layers,
             post_layers=post_layers,
             linear_cls=linear_cls,
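
`create_module` is kept for compatibility but deprecated in favor of passing the config straight to the `MLP` constructor, as the decorator's message suggests. A hedged sketch of both paths (assuming `nshtrainer` is imported as `nt` and that the unspecified `MLPConfig` fields keep their defaults):

    import nshtrainer as nt

    config = nt.nn.MLPConfig(dropout=0.1)  # other fields assumed to default

    # Deprecated path: still works, but now emits a deprecation warning.
    mlp_old = config.create_module([128, 64, 1])

    # Replacement suggested by the deprecation message.
    mlp_new = nt.nn.MLP([128, 64, 1], config=config)
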
@@ -121,50 +96,73 @@ def MLP(
     | nn.Module
     | Callable[[], nn.Module]
     | None = None,
-    bias: bool = True,
-    no_bias_scalar: bool = True,
-    ln: bool | Literal["pre", "post"] = False,
+    bias: bool | None = None,
+    no_bias_scalar: bool | None = None,
+    ln: bool | Literal["pre", "post"] | None = None,
     dropout: float | None = None,
-    residual: bool = False,
+    residual: bool | None = None,
     pre_layers: Sequence[nn.Module] = [],
     post_layers: Sequence[nn.Module] = [],
     linear_cls: LinearModuleConstructor = nn.Linear,
     seed: int | None = None,
+    config: MLPConfig | None = None,
 ):
     """
     Constructs a multi-layer perceptron (MLP) with the given dimensions and activation function.
 
     Args:
         dims (Sequence[int]): List of integers representing the dimensions of the MLP.
-        nonlinearity (Callable[[], nn.Module]): Activation function to use between layers.
-        activation (Callable[[], nn.Module]): Activation function to use between layers.
-        bias (bool, optional): Whether to include bias terms in the linear layers. Defaults to True.
-        no_bias_scalar (bool, optional): Whether to exclude bias terms when the output dimension is 1. Defaults to True.
-        ln (bool | Literal["pre", "post"], optional): Whether to apply layer normalization before or after the linear layers. Defaults to False.
-        dropout (float | None, optional): Dropout probability to apply between layers. Defaults to None.
-        residual (bool, optional): Whether to use residual connections between layers. Defaults to False.
+        nonlinearity (Callable[[], nn.Module] | None, optional): Activation function to use between layers.
+        activation (Callable[[], nn.Module] | None, optional): Activation function to use between layers.
+        bias (bool | None, optional): Whether to include bias terms in the linear layers.
+        no_bias_scalar (bool | None, optional): Whether to exclude bias terms when the output dimension is 1.
+        ln (bool | Literal["pre", "post"] | None, optional): Whether to apply layer normalization before or after the linear layers.
+        dropout (float | None, optional): Dropout probability to apply between layers.
+        residual (bool | None, optional): Whether to use residual connections between layers.
         pre_layers (Sequence[nn.Module], optional): List of layers to insert before the linear layers. Defaults to [].
         post_layers (Sequence[nn.Module], optional): List of layers to insert after the linear layers. Defaults to [].
        linear_cls (LinearModuleConstructor, optional): Linear module constructor to use. Defaults to nn.Linear.
-        seed (int | None, optional): Random seed to use for initialization. If None, the default Torch behavior is used. Defaults to None.
+        seed (int | None, optional): Random seed to use for initialization. If None, the default Torch behavior is used.
+        config (MLPConfig | None, optional): Configuration object for the MLP. Parameters specified directly take precedence.
 
     Returns:
         nn.Sequential: The constructed MLP.
     """
 
-    with custom_seed_context(seed):
+    # Resolve parameters: arg if not None, otherwise config value if config exists, otherwise default
+    resolved_bias = bias if bias is not None else (config.bias if config else True)
+    resolved_no_bias_scalar = (
+        no_bias_scalar
+        if no_bias_scalar is not None
+        else (config.no_bias_scalar if config else True)
+    )
+    resolved_nonlinearity = (
+        nonlinearity
+        if nonlinearity is not None
+        else (config.nonlinearity if config else None)
+    )
+    resolved_ln = ln if ln is not None else (config.ln if config else False)
+    resolved_dropout = (
+        dropout if dropout is not None else (config.dropout if config else None)
+    )
+    resolved_residual = (
+        residual if residual is not None else (config.residual if config else False)
+    )
+    resolved_seed = seed if seed is not None else (config.seed if config else None)
+
+    with custom_seed_context(resolved_seed):
         if activation is None:
-            activation = nonlinearity
+            activation = resolved_nonlinearity
 
         if len(dims) < 2:
             raise ValueError("mlp requires at least 2 dimensions")
-        if ln is True:
-            ln = "pre"
-        elif isinstance(ln, str) and ln not in ("pre", "post"):
+        if resolved_ln is True:
+            resolved_ln = "pre"
+        elif isinstance(resolved_ln, str) and resolved_ln not in ("pre", "post"):
             raise ValueError("ln must be a boolean or 'pre' or 'post'")
 
         layers: list[nn.Module] = []
-        if ln == "pre":
+        if resolved_ln == "pre":
             layers.append(nn.LayerNorm(dims[0]))
 
         layers.extend(pre_layers)
@@ -172,10 +170,12 @@ def MLP(
         for i in range(len(dims) - 1):
             in_features = dims[i]
             out_features = dims[i + 1]
-            bias_ = bias and not (no_bias_scalar and out_features == 1)
+            bias_ = resolved_bias and not (
+                resolved_no_bias_scalar and out_features == 1
+            )
             layers.append(linear_cls(in_features, out_features, bias=bias_))
-            if dropout is not None:
-                layers.append(nn.Dropout(dropout))
+            if resolved_dropout is not None:
+                layers.append(nn.Dropout(resolved_dropout))
             if i < len(dims) - 2:
                 match activation:
                     case NonlinearityConfigBase():
@@ -192,8 +192,8 @@
 
         layers.extend(post_layers)
 
-        if ln == "post":
+        if resolved_ln == "post":
             layers.append(nn.LayerNorm(dims[-1]))
 
-        cls = ResidualSequential if residual else nn.Sequential
+        cls = ResidualSequential if resolved_residual else nn.Sequential
         return cls(*layers)
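
Per the updated docstring, keyword arguments passed directly to `MLP` take precedence over the values carried by `config`, and any argument left as `None` falls back to the config (or to the old defaults when no config is given). A short sketch of that precedence (hypothetical dimensions; assumes the remaining `MLPConfig` fields default):

    import nshtrainer as nt

    config = nt.nn.MLPConfig(dropout=0.1, residual=True)

    # dropout=0.5 overrides config.dropout == 0.1;
    # residual is not passed, so config.residual == True is used.
    mlp = nt.nn.MLP([32, 64, 1], dropout=0.5, config=config)
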

src/nshtrainer/trainer/trainer.py
@@ -25,7 +25,6 @@ from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import LightningTrainerKwargs, TrainerConfig
 from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .accelerator import AcceleratorConfigBase
-from .plugin import PluginConfigBase
 from .signal_connector import _SignalConnector
 from .strategy import StrategyConfigBase
 

src/nshtrainer/util/path.py
@@ -120,7 +120,8 @@ def try_symlink_or_copy(
             shutil.copy(file_path, link_path)
         else:
             link_path.symlink_to(
-                symlink_target, target_is_directory=target_is_directory
+                symlink_target,
+                target_is_directory=target_is_directory,
             )
     except Exception:
         log.warning(