nshtrainer 1.0.0b43__tar.gz → 1.0.0b44__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/PKG-INFO +1 -1
  2. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/pyproject.toml +1 -1
  3. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/__init__.py +4 -0
  4. nshtrainer-1.0.0b44/src/nshtrainer/callbacks/metric_validation.py +75 -0
  5. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/__init__.py +4 -0
  6. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/__init__.py +6 -0
  7. nshtrainer-1.0.0b44/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +21 -0
  8. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/trainer/__init__.py +4 -0
  9. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/trainer/_config/__init__.py +4 -0
  10. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/_config.py +6 -0
  11. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/README.md +0 -0
  12. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/.nshconfig.generated.json +0 -0
  13. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/__init__.py +0 -0
  14. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/_callback.py +0 -0
  15. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  16. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/_checkpoint/saver.py +0 -0
  17. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/_directory.py +0 -0
  18. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/_experimental/__init__.py +0 -0
  19. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/_hf_hub.py +0 -0
  20. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/actsave.py +0 -0
  21. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/base.py +0 -0
  22. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  23. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  24. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  25. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  26. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  27. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  28. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  29. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  30. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/ema.py +0 -0
  31. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  32. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  33. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/interval.py +0 -0
  34. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  35. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
  36. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  37. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/print_table.py +0 -0
  38. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  39. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  40. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/timer.py +0 -0
  41. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  42. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  43. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/.gitattributes +0 -0
  44. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  45. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  46. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  47. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  48. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  49. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  50. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  51. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  52. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  53. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  54. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  55. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  56. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  57. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  58. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  59. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  60. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  61. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  62. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  63. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  64. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  65. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  66. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  67. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  68. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  69. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  70. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  71. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  72. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
  73. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  74. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  75. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  76. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  77. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
  78. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  79. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  80. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  81. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  82. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/nn/__init__.py +0 -0
  83. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  84. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  85. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  86. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  87. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  88. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  89. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  90. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  91. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
  92. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
  93. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
  94. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
  95. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
  96. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
  97. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
  98. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
  99. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  100. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/util/__init__.py +0 -0
  101. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  102. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  103. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  104. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  105. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/data/__init__.py +0 -0
  106. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  107. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/data/datamodule.py +0 -0
  108. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/data/transform.py +0 -0
  109. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/loggers/__init__.py +0 -0
  110. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/loggers/actsave.py +0 -0
  111. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/loggers/base.py +0 -0
  112. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/loggers/csv.py +0 -0
  113. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/loggers/tensorboard.py +0 -0
  114. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/loggers/wandb.py +0 -0
  115. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  116. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/lr_scheduler/base.py +0 -0
  117. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  118. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  119. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/metrics/__init__.py +0 -0
  120. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/metrics/_config.py +0 -0
  121. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/model/__init__.py +0 -0
  122. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/model/base.py +0 -0
  123. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/model/mixins/callback.py +0 -0
  124. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/model/mixins/debug.py +0 -0
  125. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/model/mixins/logger.py +0 -0
  126. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/nn/__init__.py +0 -0
  127. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/nn/mlp.py +0 -0
  128. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/nn/module_dict.py +0 -0
  129. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/nn/module_list.py +0 -0
  130. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/nn/nonlinearity.py +0 -0
  131. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/optimizer.py +0 -0
  132. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/profiler/__init__.py +0 -0
  133. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/profiler/_base.py +0 -0
  134. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/profiler/advanced.py +0 -0
  135. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/profiler/pytorch.py +0 -0
  136. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/profiler/simple.py +0 -0
  137. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/__init__.py +0 -0
  138. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  139. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/accelerator.py +0 -0
  140. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
  141. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/plugin/base.py +0 -0
  142. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/plugin/environment.py +0 -0
  143. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/plugin/io.py +0 -0
  144. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
  145. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/plugin/precision.py +0 -0
  146. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/signal_connector.py +0 -0
  147. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/strategy.py +0 -0
  148. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/trainer/trainer.py +0 -0
  149. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/util/_environment_info.py +0 -0
  150. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/util/bf16.py +0 -0
  151. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/util/config/__init__.py +0 -0
  152. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/util/config/dtype.py +0 -0
  153. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/util/config/duration.py +0 -0
  154. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/util/environment.py +0 -0
  155. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/util/path.py +0 -0
  156. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/util/seed.py +0 -0
  157. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/util/slurm.py +0 -0
  158. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/util/typed.py +0 -0
  159. {nshtrainer-1.0.0b43 → nshtrainer-1.0.0b44}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.0.0b43
3
+ Version: 1.0.0b44
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "1.0.0-beta43"
3
+ version = "1.0.0-beta44"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -40,6 +40,10 @@ from .log_epoch import LogEpochCallback as LogEpochCallback
40
40
  from .log_epoch import LogEpochCallbackConfig as LogEpochCallbackConfig
41
41
  from .lr_monitor import LearningRateMonitor as LearningRateMonitor
42
42
  from .lr_monitor import LearningRateMonitorConfig as LearningRateMonitorConfig
43
+ from .metric_validation import MetricValidationCallback as MetricValidationCallback
44
+ from .metric_validation import (
45
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
46
+ )
43
47
  from .norm_logging import NormLoggingCallback as NormLoggingCallback
44
48
  from .norm_logging import NormLoggingCallbackConfig as NormLoggingCallbackConfig
45
49
  from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
@@ -0,0 +1,75 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Literal
5
+
6
+ from lightning.pytorch.utilities.exceptions import MisconfigurationException
7
+ from typing_extensions import final, override, assert_never
8
+
9
+ from .._callback import NTCallbackBase
10
+ from ..metrics import MetricConfig
11
+ from .base import CallbackConfigBase, callback_registry
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
+ @final
17
+ @callback_registry.register
18
+ class MetricValidationCallbackConfig(CallbackConfigBase):
19
+ name: Literal["metric_validation"] = "metric_validation"
20
+
21
+ error_behavior: Literal["raise", "warn"] = "raise"
22
+ """
23
+ Behavior when an error occurs during validation:
24
+ - "raise": Raise an error and stop the training.
25
+ - "warn": Log a warning and continue the training.
26
+ """
27
+
28
+ validate_default_metric: bool = True
29
+ """Whether to validate the default metric from the root config."""
30
+
31
+ metrics: list[MetricConfig] = []
32
+ """List of metrics to validate."""
33
+
34
+ @override
35
+ def create_callbacks(self, trainer_config):
36
+ metrics = self.metrics.copy()
37
+ if (
38
+ self.validate_default_metric
39
+ and (default_metric := trainer_config.primary_metric) is not None
40
+ ):
41
+ metrics.append(default_metric)
42
+
43
+ yield MetricValidationCallback(self, metrics)
44
+
45
+
46
+ class MetricValidationCallback(NTCallbackBase):
47
+ def __init__(
48
+ self, config: MetricValidationCallbackConfig, metrics: list[MetricConfig]
49
+ ):
50
+ super().__init__()
51
+
52
+ self.config = config
53
+ self.metrics = metrics
54
+
55
+ @override
56
+ def on_sanity_check_end(self, trainer, pl_module):
57
+ super().on_sanity_check_end(trainer, pl_module)
58
+
59
+ log.debug("Validating metrics...")
60
+ logged_metrics = set(trainer.logged_metrics.keys())
61
+ for metric in self.metrics:
62
+ if metric.validation_monitor in logged_metrics:
63
+ continue
64
+
65
+ match self.config.error_behavior:
66
+ case "raise":
67
+ raise MisconfigurationException(
68
+ f"Metric '{metric.validation_monitor}' not found in logged metrics."
69
+ )
70
+ case "warn":
71
+ log.warning(
72
+ f"Metric '{metric.validation_monitor}' not found in logged metrics."
73
+ )
74
+ case _:
75
+ assert_never(self.config.error_behavior)
@@ -39,6 +39,9 @@ from nshtrainer.callbacks import (
39
39
  )
40
40
  from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
41
41
  from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
42
+ from nshtrainer.callbacks import (
43
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
44
+ )
42
45
  from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
43
46
  from nshtrainer.callbacks import (
44
47
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
@@ -287,6 +290,7 @@ __all__ = [
287
290
  "MPIEnvironmentPlugin",
288
291
  "MPSAcceleratorConfig",
289
292
  "MetricConfig",
293
+ "MetricValidationCallbackConfig",
290
294
  "MishNonlinearityConfig",
291
295
  "MixedPrecisionPluginConfig",
292
296
  "NonlinearityConfig",
@@ -28,6 +28,9 @@ from nshtrainer.callbacks import (
28
28
  )
29
29
  from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
30
30
  from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
31
+ from nshtrainer.callbacks import (
32
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
33
+ )
31
34
  from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
32
35
  from nshtrainer.callbacks import (
33
36
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
@@ -65,6 +68,7 @@ from . import finite_checks as finite_checks
65
68
  from . import gradient_skipping as gradient_skipping
66
69
  from . import log_epoch as log_epoch
67
70
  from . import lr_monitor as lr_monitor
71
+ from . import metric_validation as metric_validation
68
72
  from . import norm_logging as norm_logging
69
73
  from . import print_table as print_table
70
74
  from . import rlp_sanity_checks as rlp_sanity_checks
@@ -91,6 +95,7 @@ __all__ = [
91
95
  "LearningRateMonitorConfig",
92
96
  "LogEpochCallbackConfig",
93
97
  "MetricConfig",
98
+ "MetricValidationCallbackConfig",
94
99
  "NormLoggingCallbackConfig",
95
100
  "OnExceptionCheckpointCallbackConfig",
96
101
  "PrintTableMetricsCallbackConfig",
@@ -110,6 +115,7 @@ __all__ = [
110
115
  "gradient_skipping",
111
116
  "log_epoch",
112
117
  "lr_monitor",
118
+ "metric_validation",
113
119
  "norm_logging",
114
120
  "print_table",
115
121
  "rlp_sanity_checks",
@@ -0,0 +1,21 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.callbacks.metric_validation import (
6
+ CallbackConfigBase as CallbackConfigBase,
7
+ )
8
+ from nshtrainer.callbacks.metric_validation import MetricConfig as MetricConfig
9
+ from nshtrainer.callbacks.metric_validation import (
10
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
11
+ )
12
+ from nshtrainer.callbacks.metric_validation import (
13
+ callback_registry as callback_registry,
14
+ )
15
+
16
+ __all__ = [
17
+ "CallbackConfigBase",
18
+ "MetricConfig",
19
+ "MetricValidationCallbackConfig",
20
+ "callback_registry",
21
+ ]
@@ -38,6 +38,9 @@ from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbac
38
38
  from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
39
39
  from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
40
40
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
41
+ from nshtrainer.trainer._config import (
42
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
43
+ )
41
44
  from nshtrainer.trainer._config import (
42
45
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
43
46
  )
@@ -164,6 +167,7 @@ __all__ = [
164
167
  "MPIEnvironmentPlugin",
165
168
  "MPSAcceleratorConfig",
166
169
  "MetricConfig",
170
+ "MetricValidationCallbackConfig",
167
171
  "MixedPrecisionPluginConfig",
168
172
  "NormLoggingCallbackConfig",
169
173
  "OnExceptionCheckpointCallbackConfig",
@@ -34,6 +34,9 @@ from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbac
34
34
  from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
35
35
  from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
36
36
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
37
+ from nshtrainer.trainer._config import (
38
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
39
+ )
37
40
  from nshtrainer.trainer._config import (
38
41
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
39
42
  )
@@ -77,6 +80,7 @@ __all__ = [
77
80
  "LoggerConfig",
78
81
  "LoggerConfigBase",
79
82
  "MetricConfig",
83
+ "MetricValidationCallbackConfig",
80
84
  "NormLoggingCallbackConfig",
81
85
  "OnExceptionCheckpointCallbackConfig",
82
86
  "PluginConfig",
@@ -40,6 +40,7 @@ from ..callbacks.base import CallbackConfigBase
40
40
  from ..callbacks.debug_flag import DebugFlagCallbackConfig
41
41
  from ..callbacks.log_epoch import LogEpochCallbackConfig
42
42
  from ..callbacks.lr_monitor import LearningRateMonitorConfig
43
+ from ..callbacks.metric_validation import MetricValidationCallbackConfig
43
44
  from ..callbacks.rlp_sanity_checks import RLPSanityChecksCallbackConfig
44
45
  from ..callbacks.shared_parameters import SharedParametersCallbackConfig
45
46
  from ..loggers import (
@@ -697,6 +698,10 @@ class TrainerConfig(C.Config):
697
698
  - The trainer is running in fast_dev_run mode.
698
699
  - The trainer is running a sanity check (which happens before starting the training routine).
699
700
  """
701
+ auto_validate_metrics: MetricValidationCallbackConfig | None = (
702
+ MetricValidationCallbackConfig()
703
+ )
704
+ """If enabled, will automatically validate the metrics before starting the training routine."""
700
705
 
701
706
  lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
702
707
  """
@@ -768,6 +773,7 @@ class TrainerConfig(C.Config):
768
773
  yield self.shared_parameters
769
774
  yield self.reduce_lr_on_plateau_sanity_checking
770
775
  yield self.auto_set_debug_flag
776
+ yield self.auto_validate_metrics
771
777
  yield from self.callbacks
772
778
 
773
779
  def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
File without changes