nshtrainer 1.0.0b47__tar.gz → 1.0.0b50__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (161)
  1. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/PKG-INFO +1 -1
  2. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/pyproject.toml +1 -1
  3. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +3 -3
  4. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/early_stopping.py +1 -1
  5. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/metric_validation.py +3 -3
  6. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  7. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/data/datamodule.py +2 -2
  8. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/__init__.py +0 -1
  9. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +18 -11
  10. nshtrainer-1.0.0b50/src/nshtrainer/metrics/_config.py +25 -0
  11. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/model/base.py +4 -4
  12. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/model/mixins/debug.py +1 -1
  13. nshtrainer-1.0.0b50/src/nshtrainer/model/mixins/logger.py +275 -0
  14. nshtrainer-1.0.0b47/src/nshtrainer/metrics/_config.py +0 -42
  15. nshtrainer-1.0.0b47/src/nshtrainer/model/mixins/logger.py +0 -181
  16. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/README.md +0 -0
  17. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/.nshconfig.generated.json +0 -0
  18. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/__init__.py +0 -0
  19. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_callback.py +0 -0
  20. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  21. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_checkpoint/saver.py +0 -0
  22. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_directory.py +0 -0
  23. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_experimental/__init__.py +0 -0
  24. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_hf_hub.py +0 -0
  25. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/__init__.py +0 -0
  26. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/actsave.py +0 -0
  27. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/base.py +0 -0
  28. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  29. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  30. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  31. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  32. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  33. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  34. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/ema.py +0 -0
  35. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  36. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  37. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/interval.py +0 -0
  38. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  39. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
  40. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  41. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/print_table.py +0 -0
  42. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  43. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/timer.py +0 -0
  44. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  45. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  46. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/.gitattributes +0 -0
  47. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/__init__.py +0 -0
  48. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  49. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  50. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  51. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  52. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
  53. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  54. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  55. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  56. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  57. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  58. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  59. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  60. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  61. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  62. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  63. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  64. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  65. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  66. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  67. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  68. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
  69. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  70. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  71. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  72. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  73. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  74. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  75. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  76. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  77. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  78. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
  79. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  80. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  81. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  82. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  83. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
  84. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  85. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  86. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  87. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  88. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/nn/__init__.py +0 -0
  89. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  90. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  91. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  92. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  93. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  94. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  95. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  96. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  97. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/__init__.py +0 -0
  98. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -0
  99. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
  100. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
  101. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
  102. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
  103. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
  104. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
  105. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
  106. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
  107. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  108. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/util/__init__.py +0 -0
  109. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  110. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  111. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  112. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  113. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/data/__init__.py +0 -0
  114. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  115. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/data/transform.py +0 -0
  116. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/actsave.py +0 -0
  117. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/base.py +0 -0
  118. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/csv.py +0 -0
  119. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/tensorboard.py +0 -0
  120. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/wandb.py +0 -0
  121. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  122. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/lr_scheduler/base.py +0 -0
  123. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  124. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/metrics/__init__.py +0 -0
  125. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/model/__init__.py +0 -0
  126. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/model/mixins/callback.py +0 -0
  127. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/nn/__init__.py +0 -0
  128. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/nn/mlp.py +0 -0
  129. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/nn/module_dict.py +0 -0
  130. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/nn/module_list.py +0 -0
  131. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/nn/nonlinearity.py +0 -0
  132. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/optimizer.py +0 -0
  133. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/profiler/__init__.py +0 -0
  134. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/profiler/_base.py +0 -0
  135. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/profiler/advanced.py +0 -0
  136. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/profiler/pytorch.py +0 -0
  137. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/profiler/simple.py +0 -0
  138. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/__init__.py +0 -0
  139. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/_config.py +0 -0
  140. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  141. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/accelerator.py +0 -0
  142. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
  143. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/base.py +0 -0
  144. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/environment.py +0 -0
  145. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/io.py +0 -0
  146. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
  147. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/precision.py +0 -0
  148. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/signal_connector.py +0 -0
  149. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/strategy.py +0 -0
  150. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/trainer.py +0 -0
  151. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/_environment_info.py +0 -0
  152. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/bf16.py +0 -0
  153. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/config/__init__.py +0 -0
  154. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/config/dtype.py +0 -0
  155. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/config/duration.py +0 -0
  156. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/environment.py +0 -0
  157. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/path.py +0 -0
  158. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/seed.py +0 -0
  159. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/slurm.py +0 -0
  160. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/typed.py +0 -0
  161. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.0.0b47
+Version: 1.0.0b50
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "1.0.0-beta47"
+version = "1.0.0-beta50"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
@@ -51,7 +51,7 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
 class BestCheckpointCallback(CheckpointBase[BestCheckpointCallbackConfig]):
     @property
     def _metric_name_normalized(self):
-        return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")
+        return self.metric.monitor.replace("/", "_").replace(" ", "_").replace(".", "_")

     @override
     def __init__(
@@ -69,12 +69,12 @@ class BestCheckpointCallback(CheckpointBase[BestCheckpointCallbackConfig]):

     @override
     def default_filename(self):
-        return f"epoch{{epoch}}-step{{step}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
+        return f"epoch{{epoch}}-step{{step}}-{self._metric_name_normalized}{{{self.metric.monitor}}}"

     @override
     def topk_sort_key(self, metadata: CheckpointMetadata):
         return metadata.metrics.get(
-            self.metric.validation_monitor,
+            self.metric.monitor,
             float("-inf" if self.metric.mode == "max" else "inf"),
         )

{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/early_stopping.py
@@ -68,7 +68,7 @@ class EarlyStoppingCallback(_EarlyStopping):
         del config, metric

         super().__init__(
-            monitor=self.metric.validation_monitor,
+            monitor=self.metric.monitor,
             mode=self.metric.mode,
             patience=self.config.patience,
             min_delta=self.config.min_delta,
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/metric_validation.py
@@ -55,14 +55,14 @@ class MetricValidationCallback(Callback):
         self.metrics = metrics

     def _check_metrics(self, trainer: Trainer):
-        metric_names = ", ".join(metric.validation_monitor for metric in self.metrics)
+        metric_names = ", ".join(metric.monitor for metric in self.metrics)
         log.info(f"Validating metrics: {metric_names}...")
         logged_metrics = set(trainer.logged_metrics.keys())

         invalid_metrics: list[str] = []
         for metric in self.metrics:
-            if metric.validation_monitor not in logged_metrics:
-                invalid_metrics.append(metric.validation_monitor)
+            if metric.monitor not in logged_metrics:
+                invalid_metrics.append(metric.monitor)

         if invalid_metrics:
             msg = (
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/rlp_sanity_checks.py
@@ -171,7 +171,7 @@ class CustomRLPImplementation(Protocol):
     __reduce_lr_on_plateau__: bool


-class _RLPSanityCheckModuleMixin(LightningModule):
+class RLPSanityCheckModuleMixin(LightningModule):
     def reduce_lr_on_plateau_config(
         self,
         lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/data/datamodule.py
@@ -11,13 +11,13 @@ from lightning.pytorch import LightningDataModule
 from typing_extensions import Never, TypeVar, deprecated, override

 from ..model.mixins.callback import CallbackRegistrarModuleMixin
-from ..model.mixins.debug import _DebugModuleMixin
+from ..model.mixins.debug import DebugModuleMixin

 THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)


 class LightningDataModuleBase(
-    _DebugModuleMixin,
+    DebugModuleMixin,
     CallbackRegistrarModuleMixin,
     LightningDataModule,
     ABC,
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/__init__.py
@@ -2,7 +2,6 @@ from __future__ import annotations

 from typing import Annotated

-import nshconfig as C
 from typing_extensions import TypeAliasType

 from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py
@@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
 from typing_extensions import final, override

 from ..metrics._config import MetricConfig
+from ..util.config import EpochsConfig
 from .base import LRSchedulerConfigBase, LRSchedulerMetadata, lr_scheduler_registry


@@ -21,13 +22,13 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
     """Metric to monitor.
     If not provided, the primary metric of the runner will be used."""

-    patience: int
+    patience: int | EpochsConfig
     r"""Number of epochs with no improvement after which learning rate will be reduced."""

     factor: float
     r"""Factor by which the learning rate will be reduced. new_lr = lr * factor."""

-    cooldown: int = 0
+    cooldown: int | EpochsConfig = 0
     r"""Number of epochs to wait before resuming normal operation after lr has been reduced."""

     min_lr: float | list[float] = 0.0
@@ -49,28 +50,34 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
         if (metric := self.metric) is None:
             from ..trainer import Trainer

-            assert isinstance(
-                trainer := lightning_module.trainer, Trainer
-            ), "The trainer must be a `nshtrainer.Trainer` instance."
+            assert isinstance(trainer := lightning_module.trainer, Trainer), (
+                "The trainer must be a `nshtrainer.Trainer` instance."
+            )

-            assert (
-                metric := trainer.hparams.primary_metric
-            ) is not None, "Primary metric must be provided if metric is not specified."
+            assert (metric := trainer.hparams.primary_metric) is not None, (
+                "Primary metric must be provided if metric is not specified."
+            )
+
+        if isinstance(patience := self.patience, EpochsConfig):
+            patience = int(patience.value)
+
+        if isinstance(cooldown := self.cooldown, EpochsConfig):
+            cooldown = int(cooldown.value)

         lr_scheduler = ReduceLROnPlateau(
             optimizer,
             mode=metric.mode,
             factor=self.factor,
-            patience=self.patience,
+            patience=patience,
             threshold=self.threshold,
             threshold_mode=self.threshold_mode,
-            cooldown=self.cooldown,
+            cooldown=cooldown,
             min_lr=self.min_lr,
             eps=self.eps,
         )
         return {
             "scheduler": lr_scheduler,
-            "monitor": metric.validation_monitor,
+            "monitor": metric.monitor,
         }

     @override
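Note on the hunks above: `patience` and `cooldown` now also accept an `EpochsConfig`, which `create_scheduler` coerces to a plain epoch count via `int(<config>.value)`. A minimal usage sketch, not taken from this diff; the import paths follow the package layout shown here, and the `EpochsConfig(value=...)` constructor is an assumption (only the `.value` attribute is visible in the hunk), with all other config fields left at their defaults as far as this diff shows:

from nshtrainer.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateauConfig
from nshtrainer.util.config import EpochsConfig

# A plain int still works exactly as before.
lr_config = ReduceLROnPlateauConfig(patience=5, factor=0.5)

# With the new union type, a duration config can be passed instead; it is
# converted to int(config.value) epochs inside create_scheduler().
lr_config = ReduceLROnPlateauConfig(
    patience=EpochsConfig(value=5),  # assumed constructor
    cooldown=EpochsConfig(value=2),  # assumed constructor
    factor=0.5,
)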
nshtrainer-1.0.0b50/src/nshtrainer/metrics/_config.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+import builtins
+from typing import Any, Literal
+
+import nshconfig as C
+
+
+class MetricConfig(C.Config):
+    monitor: str
+    """The name of the metric to monitor."""
+
+    mode: Literal["min", "max"]
+    """
+    The mode of the primary metric:
+    - "min" for metrics that should be minimized (e.g., loss)
+    - "max" for metrics that should be maximized (e.g., accuracy)
+    """
+
+    @property
+    def best(self):
+        return builtins.min if self.mode == "min" else builtins.max
+
+    def is_better(self, a: Any, b: Any):
+        return self.best(a, b) == a
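The new `MetricConfig` above replaces the old `name` field and its implicit `validation_monitor` (`"val/" + name`) with a single fully qualified `monitor` field; the `loss()` classmethod and the `__post_init__` split-prefix check are gone (see the removed file at the end of this diff). A hedged migration sketch, not from the diff itself; the import uses the module path shown above:

from nshtrainer.metrics._config import MetricConfig

# 1.0.0b47: MetricConfig(name="loss", mode="min") implicitly monitored "val/loss".
# 1.0.0b50: pass the logged metric name exactly as it appears in trainer.logged_metrics.
metric = MetricConfig(monitor="val/loss", mode="min")

assert metric.is_better(0.1, 0.2)  # "min" mode: lower values win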
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/model/base.py
@@ -15,9 +15,9 @@ from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from typing_extensions import Never, TypeVar, deprecated, override

-from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
+from ..callbacks.rlp_sanity_checks import RLPSanityCheckModuleMixin
 from .mixins.callback import CallbackModuleMixin
-from .mixins.debug import _DebugModuleMixin
+from .mixins.debug import DebugModuleMixin
 from .mixins.logger import LoggerLightningModuleMixin

 log = logging.getLogger(__name__)
@@ -54,8 +54,8 @@ VALID_REDUCE_OPS = (


 class LightningModuleBase(
-    _DebugModuleMixin,
-    _RLPSanityCheckModuleMixin,
+    DebugModuleMixin,
+    RLPSanityCheckModuleMixin,
     LoggerLightningModuleMixin,
     CallbackModuleMixin,
     LightningModule,
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/model/mixins/debug.py
@@ -28,7 +28,7 @@ def _trainer(module: Any):
     return trainer


-class _DebugModuleMixin:
+class DebugModuleMixin:
     @property
     def nshtrainer_or_none(self):
         return _trainer(self)
nshtrainer-1.0.0b50/src/nshtrainer/model/mixins/logger.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+import dataclasses
+from collections import deque
+from collections.abc import Callable, Generator, Mapping
+from contextlib import contextmanager
+from typing import Any, ClassVar
+
+import torchmetrics
+from lightning.pytorch import LightningModule
+from lightning.pytorch.utilities.types import _METRIC
+from lightning_utilities.core.rank_zero import rank_zero_warn
+from typing_extensions import override
+
+from ...util.typing_utils import mixin_base_type
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class _LogContextKwargs:
+    __ignore_fields__: ClassVar[set[str]] = {"prefix", "disabled"}
+
+    prefix: str | None = None
+    disabled: bool | None = None
+    prog_bar: bool | None = None
+    logger: bool | None = None
+    on_step: bool | None = None
+    on_epoch: bool | None = None
+    reduce_fx: str | Callable | None = None
+    enable_graph: bool | None = None
+    sync_dist: bool | None = None
+    sync_dist_group: Any | None = None
+    add_dataloader_idx: bool | None = None
+    batch_size: int | None = None
+    rank_zero_only: bool | None = None
+
+    def to_dict(self):
+        d = dataclasses.asdict(self)
+        for field in self.__ignore_fields__:
+            d.pop(field, None)
+
+        # Pop all None values
+        for k in list(d.keys()):
+            if d[k] is None:
+                d.pop(k)
+
+        return d
+
+
+class LoggerLightningModuleMixin(mixin_base_type(LightningModule)):
+    @override
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._logger_prefix_stack = deque[_LogContextKwargs]()
+
+    @property
+    def logging_enabled(self) -> bool:
+        # Logging is disabled in barebones mode.
+        if (trainer := self._trainer) is not None and trainer.barebones:
+            # Warn the user once that logging is disabled in barebones mode.
+            if not hasattr(self, "_barebones_logging_warned"):
+                rank_zero_warn(
+                    "Logging is disabled in barebones mode. "
+                    "This is to reduce the overhead of logging in barebones mode. "
+                    "If you want to enable logging, set `barebones=False` in the Trainer.",
+                )
+                self._barebones_logging_warned = True
+            return False
+
+        # If no loggers are registered, then logging is disabled.
+        if not self.logger:
+            return False
+
+        # Check if the topmost non-null context is disabled
+        for context in reversed(self._logger_prefix_stack):
+            if context.disabled is not None:
+                return not context.disabled
+
+        # Otherwise, logging is enabled.
+        return True
+
+    @contextmanager
+    def log_context(
+        self,
+        prefix: str | None = None,
+        disabled: bool | None = None,
+        prog_bar: bool | None = None,
+        logger: bool | None = None,
+        on_step: bool | None = None,
+        on_epoch: bool | None = None,
+        reduce_fx: str | Callable | None = None,
+        enable_graph: bool | None = None,
+        sync_dist: bool | None = None,
+        sync_dist_group: Any | None = None,
+        add_dataloader_idx: bool | None = None,
+        batch_size: int | None = None,
+        rank_zero_only: bool | None = None,
+    ) -> Generator[None, None, None]:
+        self._logger_prefix_stack.append(
+            _LogContextKwargs(
+                prefix=prefix,
+                disabled=disabled,
+                prog_bar=prog_bar,
+                logger=logger,
+                on_step=on_step,
+                on_epoch=on_epoch,
+                reduce_fx=reduce_fx,
+                enable_graph=enable_graph,
+                sync_dist=sync_dist,
+                sync_dist_group=sync_dist_group,
+                add_dataloader_idx=add_dataloader_idx,
+                batch_size=batch_size,
+                rank_zero_only=rank_zero_only,
+            )
+        )
+        try:
+            yield
+        finally:
+            _ = self._logger_prefix_stack.pop()
+
+    def _make_prefix_and_kwargs_dict(self, kwargs: _LogContextKwargs):
+        prefix = "".join(c.prefix for c in self._logger_prefix_stack if c.prefix)
+
+        fn_kwargs: dict[str, Any] = {}
+        for c in self._logger_prefix_stack:
+            fn_kwargs.update(c.to_dict())
+
+        fn_kwargs.update(kwargs.to_dict())
+        return prefix, fn_kwargs
+
+    @override
+    def log(
+        self,
+        name: str,
+        value: _METRIC,
+        prog_bar: bool | None = None,
+        logger: bool | None = None,
+        on_step: bool | None = None,
+        on_epoch: bool | None = None,
+        reduce_fx: str | Callable | None = None,
+        enable_graph: bool | None = None,
+        sync_dist: bool | None = None,
+        sync_dist_group: Any | None = None,
+        add_dataloader_idx: bool | None = None,
+        batch_size: int | None = None,
+        metric_attribute: str | None = None,
+        rank_zero_only: bool | None = None,
+    ) -> None:
+        """Log a key, value pair.
+
+        Example::
+
+            self.log('train_loss', loss)
+
+        The default behavior per hook is documented here: :ref:`extensions/logging:Automatic Logging`.
+
+        Args:
+            name: key to log. Must be identical across all processes if using DDP or any other distributed strategy.
+            value: value to log. Can be a ``float``, ``Tensor``, or a ``Metric``.
+            prog_bar: if ``True`` logs to the progress bar.
+            logger: if ``True`` logs to the logger.
+            on_step: if ``True`` logs at this step. The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            on_epoch: if ``True`` logs epoch accumulated metrics. The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
+            enable_graph: if ``True``, will not auto detach the graph.
+            sync_dist: if ``True``, reduces the metric across devices. Use with care as this may lead to a significant
+                communication overhead.
+            sync_dist_group: the DDP group to sync across.
+            add_dataloader_idx: if ``True``, appends the index of the current dataloader to
+                the name (when using multiple dataloaders). If False, user needs to give unique names for
+                each dataloader to not mix the values.
+            batch_size: Current batch_size. This will be directly inferred from the loaded batch,
+                but for some data structures you might need to explicitly provide it.
+            metric_attribute: To restore the metric state, Lightning requires the reference of the
+                :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute.
+            rank_zero_only: Tells Lightning if you are calling ``self.log`` from every process (default) or only from
+                rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
+                (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
+                :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.
+
+        """
+        # If logging is disabled, then do nothing.
+        if not self.logging_enabled:
+            return
+
+        prefix, fn_kwargs = self._make_prefix_and_kwargs_dict(
+            _LogContextKwargs(
+                prog_bar=prog_bar,
+                logger=logger,
+                on_step=on_step,
+                on_epoch=on_epoch,
+                reduce_fx=reduce_fx,
+                enable_graph=enable_graph,
+                sync_dist=sync_dist,
+                sync_dist_group=sync_dist_group,
+                add_dataloader_idx=add_dataloader_idx,
+                batch_size=batch_size,
+                rank_zero_only=rank_zero_only,
+            )
+        )
+        name = f"{prefix}{name}"
+        return super().log(name, value, metric_attribute=metric_attribute, **fn_kwargs)
+
+    def log_dict(
+        self,
+        dictionary: Mapping[str, _METRIC] | torchmetrics.MetricCollection,
+        prog_bar: bool | None = None,
+        logger: bool | None = None,
+        on_step: bool | None = None,
+        on_epoch: bool | None = None,
+        reduce_fx: str | Callable | None = None,
+        enable_graph: bool | None = None,
+        sync_dist: bool | None = None,
+        sync_dist_group: Any | None = None,
+        add_dataloader_idx: bool | None = None,
+        batch_size: int | None = None,
+        rank_zero_only: bool | None = None,
+    ) -> None:
+        """Log a dictionary of values at once.
+
+        Example::
+
+            values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
+            self.log_dict(values)
+
+        Args:
+            dictionary: key value pairs.
+                Keys must be identical across all processes if using DDP or any other distributed strategy.
+                The values can be a ``float``, ``Tensor``, ``Metric``, or ``MetricCollection``.
+            prog_bar: if ``True`` logs to the progress base.
+            logger: if ``True`` logs to the logger.
+            on_step: if ``True`` logs at this step.
+                ``None`` auto-logs for training_step but not validation/test_step.
+                The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            on_epoch: if ``True`` logs epoch accumulated metrics.
+                ``None`` auto-logs for val/test step but not ``training_step``.
+                The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
+            enable_graph: if ``True``, will not auto-detach the graph
+            sync_dist: if ``True``, reduces the metric across GPUs/TPUs. Use with care as this may lead to a significant
+                communication overhead.
+            sync_dist_group: the ddp group to sync across.
+            add_dataloader_idx: if ``True``, appends the index of the current dataloader to
+                the name (when using multiple). If ``False``, user needs to give unique names for
+                each dataloader to not mix values.
+            batch_size: Current batch size. This will be directly inferred from the loaded batch,
+                but some data structures might need to explicitly provide it.
+            rank_zero_only: Tells Lightning if you are calling ``self.log`` from every process (default) or only from
+                rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
+                (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
+                :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.
+
+        """
+
+        _, fn_kwargs = self._make_prefix_and_kwargs_dict(
+            _LogContextKwargs(
+                prog_bar=prog_bar,
+                logger=logger,
+                on_step=on_step,
+                on_epoch=on_epoch,
+                reduce_fx=reduce_fx,
+                enable_graph=enable_graph,
+                sync_dist=sync_dist,
+                sync_dist_group=sync_dist_group,
+                add_dataloader_idx=add_dataloader_idx,
+                batch_size=batch_size,
+                rank_zero_only=rank_zero_only,
+            )
+        )
+        # NOTE: Prefix will be handled by the individual log calls.
+        return super().log_dict(dictionary, **fn_kwargs)
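The new `LoggerLightningModuleMixin` above keeps a stack of `_LogContextKwargs`: `log_context()` pushes a prefix plus default log kwargs, and every `self.log` / `self.log_dict` call made inside the context gets the concatenated prefix and the merged kwargs (explicit call arguments win). A minimal usage sketch, not from this diff; the subclass, its loss computation, and the exact hook wiring are hypothetical placeholders, while `LightningModuleBase` is imported from the `model/base.py` shown above:

from nshtrainer.model.base import LightningModuleBase

class MyModel(LightningModuleBase):
    def validation_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        with self.log_context(prefix="val/", on_epoch=True, sync_dist=True):
            # Logged as "val/loss"; on_epoch/sync_dist come from the context
            # unless overridden in this call.
            self.log("loss", loss, prog_bar=True)
        return loss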
nshtrainer-1.0.0b47/src/nshtrainer/metrics/_config.py
@@ -1,42 +0,0 @@
-from __future__ import annotations
-
-import builtins
-from typing import Any, Literal
-
-import nshconfig as C
-
-
-class MetricConfig(C.Config):
-    name: str
-    """The name of the primary metric."""
-
-    mode: Literal["min", "max"]
-    """
-    The mode of the primary metric:
-    - "min" for metrics that should be minimized (e.g., loss)
-    - "max" for metrics that should be maximized (e.g., accuracy)
-    """
-
-    @property
-    def validation_monitor(self) -> str:
-        return f"val/{self.name}"
-
-    def __post_init__(self):
-        for split in ("train", "val", "test", "predict"):
-            if self.name.startswith(f"{split}/"):
-                raise ValueError(
-                    f"Primary metric name should not start with '{split}/'. "
-                    f"Just use '{self.name[len(split) + 1:]}' instead. "
-                    "The split name is automatically added depending on the context."
-                )
-
-    @classmethod
-    def loss(cls, mode: Literal["min", "max"] = "min"):
-        return cls(name="loss", mode=mode)
-
-    @property
-    def best(self):
-        return builtins.min if self.mode == "min" else builtins.max
-
-    def is_better(self, a: Any, b: Any):
-        return self.best(a, b) == a