nshtrainer 1.0.0b47__tar.gz → 1.0.0b48__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (160)
  1. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/PKG-INFO +1 -1
  2. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/pyproject.toml +1 -1
  3. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +3 -3
  4. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/early_stopping.py +1 -1
  5. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/metric_validation.py +3 -3
  6. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  7. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/data/datamodule.py +2 -2
  8. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/loggers/__init__.py +0 -1
  9. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +7 -7
  10. nshtrainer-1.0.0b48/src/nshtrainer/metrics/_config.py +25 -0
  11. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/model/base.py +4 -4
  12. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/model/mixins/debug.py +1 -1
  13. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/model/mixins/logger.py +12 -6
  14. nshtrainer-1.0.0b47/src/nshtrainer/metrics/_config.py +0 -42
  15. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/README.md +0 -0
  16. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/.nshconfig.generated.json +0 -0
  17. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/__init__.py +0 -0
  18. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/_callback.py +0 -0
  19. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  20. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/_checkpoint/saver.py +0 -0
  21. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/_directory.py +0 -0
  22. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/_experimental/__init__.py +0 -0
  23. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/_hf_hub.py +0 -0
  24. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/__init__.py +0 -0
  25. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/actsave.py +0 -0
  26. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/base.py +0 -0
  27. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  28. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  29. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  30. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  31. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  32. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  33. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/ema.py +0 -0
  34. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  35. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  36. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/interval.py +0 -0
  37. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  38. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
  39. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  40. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/print_table.py +0 -0
  41. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  42. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/timer.py +0 -0
  43. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  44. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  45. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/.gitattributes +0 -0
  46. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/__init__.py +0 -0
  47. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  48. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  49. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  50. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  51. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
  52. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  53. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  54. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  55. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  56. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  57. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  58. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  59. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  60. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  61. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  62. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  63. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  64. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  65. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  66. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  67. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
  68. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  69. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  70. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  71. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  72. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  73. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  74. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  75. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  76. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  77. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
  78. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  79. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  80. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  81. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  82. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
  83. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  84. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  85. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  86. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  87. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/nn/__init__.py +0 -0
  88. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  89. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  90. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  91. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  92. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  93. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  94. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  95. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  96. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/trainer/__init__.py +0 -0
  97. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -0
  98. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
  99. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
  100. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
  101. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
  102. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
  103. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
  104. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
  105. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
  106. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  107. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/util/__init__.py +0 -0
  108. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  109. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  110. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  111. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  112. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/data/__init__.py +0 -0
  113. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  114. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/data/transform.py +0 -0
  115. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/loggers/actsave.py +0 -0
  116. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/loggers/base.py +0 -0
  117. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/loggers/csv.py +0 -0
  118. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/loggers/tensorboard.py +0 -0
  119. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/loggers/wandb.py +0 -0
  120. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  121. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/lr_scheduler/base.py +0 -0
  122. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  123. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/metrics/__init__.py +0 -0
  124. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/model/__init__.py +0 -0
  125. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/model/mixins/callback.py +0 -0
  126. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/nn/__init__.py +0 -0
  127. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/nn/mlp.py +0 -0
  128. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/nn/module_dict.py +0 -0
  129. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/nn/module_list.py +0 -0
  130. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/nn/nonlinearity.py +0 -0
  131. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/optimizer.py +0 -0
  132. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/profiler/__init__.py +0 -0
  133. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/profiler/_base.py +0 -0
  134. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/profiler/advanced.py +0 -0
  135. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/profiler/pytorch.py +0 -0
  136. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/profiler/simple.py +0 -0
  137. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/__init__.py +0 -0
  138. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/_config.py +0 -0
  139. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  140. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/accelerator.py +0 -0
  141. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
  142. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/plugin/base.py +0 -0
  143. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/plugin/environment.py +0 -0
  144. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/plugin/io.py +0 -0
  145. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
  146. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/plugin/precision.py +0 -0
  147. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/signal_connector.py +0 -0
  148. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/strategy.py +0 -0
  149. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/trainer/trainer.py +0 -0
  150. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/util/_environment_info.py +0 -0
  151. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/util/bf16.py +0 -0
  152. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/util/config/__init__.py +0 -0
  153. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/util/config/dtype.py +0 -0
  154. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/util/config/duration.py +0 -0
  155. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/util/environment.py +0 -0
  156. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/util/path.py +0 -0
  157. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/util/seed.py +0 -0
  158. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/util/slurm.py +0 -0
  159. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/util/typed.py +0 -0
  160. {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b48}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.0.0b47
3
+ Version: 1.0.0b48
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "1.0.0-beta47"
3
+ version = "1.0.0-beta48"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -51,7 +51,7 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
51
51
  class BestCheckpointCallback(CheckpointBase[BestCheckpointCallbackConfig]):
52
52
  @property
53
53
  def _metric_name_normalized(self):
54
- return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")
54
+ return self.metric.monitor.replace("/", "_").replace(" ", "_").replace(".", "_")
55
55
 
56
56
  @override
57
57
  def __init__(
@@ -69,12 +69,12 @@ class BestCheckpointCallback(CheckpointBase[BestCheckpointCallbackConfig]):
69
69
 
70
70
  @override
71
71
  def default_filename(self):
72
- return f"epoch{{epoch}}-step{{step}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
72
+ return f"epoch{{epoch}}-step{{step}}-{self._metric_name_normalized}{{{self.metric.monitor}}}"
73
73
 
74
74
  @override
75
75
  def topk_sort_key(self, metadata: CheckpointMetadata):
76
76
  return metadata.metrics.get(
77
- self.metric.validation_monitor,
77
+ self.metric.monitor,
78
78
  float("-inf" if self.metric.mode == "max" else "inf"),
79
79
  )
80
80
 
@@ -68,7 +68,7 @@ class EarlyStoppingCallback(_EarlyStopping):
68
68
  del config, metric
69
69
 
70
70
  super().__init__(
71
- monitor=self.metric.validation_monitor,
71
+ monitor=self.metric.monitor,
72
72
  mode=self.metric.mode,
73
73
  patience=self.config.patience,
74
74
  min_delta=self.config.min_delta,
@@ -55,14 +55,14 @@ class MetricValidationCallback(Callback):
55
55
  self.metrics = metrics
56
56
 
57
57
  def _check_metrics(self, trainer: Trainer):
58
- metric_names = ", ".join(metric.validation_monitor for metric in self.metrics)
58
+ metric_names = ", ".join(metric.monitor for metric in self.metrics)
59
59
  log.info(f"Validating metrics: {metric_names}...")
60
60
  logged_metrics = set(trainer.logged_metrics.keys())
61
61
 
62
62
  invalid_metrics: list[str] = []
63
63
  for metric in self.metrics:
64
- if metric.validation_monitor not in logged_metrics:
65
- invalid_metrics.append(metric.validation_monitor)
64
+ if metric.monitor not in logged_metrics:
65
+ invalid_metrics.append(metric.monitor)
66
66
 
67
67
  if invalid_metrics:
68
68
  msg = (
@@ -171,7 +171,7 @@ class CustomRLPImplementation(Protocol):
171
171
  __reduce_lr_on_plateau__: bool
172
172
 
173
173
 
174
- class _RLPSanityCheckModuleMixin(LightningModule):
174
+ class RLPSanityCheckModuleMixin(LightningModule):
175
175
  def reduce_lr_on_plateau_config(
176
176
  self,
177
177
  lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
@@ -11,13 +11,13 @@ from lightning.pytorch import LightningDataModule
11
11
  from typing_extensions import Never, TypeVar, deprecated, override
12
12
 
13
13
  from ..model.mixins.callback import CallbackRegistrarModuleMixin
14
- from ..model.mixins.debug import _DebugModuleMixin
14
+ from ..model.mixins.debug import DebugModuleMixin
15
15
 
16
16
  THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)
17
17
 
18
18
 
19
19
  class LightningDataModuleBase(
20
- _DebugModuleMixin,
20
+ DebugModuleMixin,
21
21
  CallbackRegistrarModuleMixin,
22
22
  LightningDataModule,
23
23
  ABC,
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  from typing import Annotated
4
4
 
5
- import nshconfig as C
6
5
  from typing_extensions import TypeAliasType
7
6
 
8
7
  from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
@@ -49,13 +49,13 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
49
49
  if (metric := self.metric) is None:
50
50
  from ..trainer import Trainer
51
51
 
52
- assert isinstance(
53
- trainer := lightning_module.trainer, Trainer
54
- ), "The trainer must be a `nshtrainer.Trainer` instance."
52
+ assert isinstance(trainer := lightning_module.trainer, Trainer), (
53
+ "The trainer must be a `nshtrainer.Trainer` instance."
54
+ )
55
55
 
56
- assert (
57
- metric := trainer.hparams.primary_metric
58
- ) is not None, "Primary metric must be provided if metric is not specified."
56
+ assert (metric := trainer.hparams.primary_metric) is not None, (
57
+ "Primary metric must be provided if metric is not specified."
58
+ )
59
59
 
60
60
  lr_scheduler = ReduceLROnPlateau(
61
61
  optimizer,
@@ -70,7 +70,7 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
70
70
  )
71
71
  return {
72
72
  "scheduler": lr_scheduler,
73
- "monitor": metric.validation_monitor,
73
+ "monitor": metric.monitor,
74
74
  }
75
75
 
76
76
  @override
@@ -0,0 +1,25 @@
1
+ from __future__ import annotations
2
+
3
+ import builtins
4
+ from typing import Any, Literal
5
+
6
+ import nshconfig as C
7
+
8
+
9
+ class MetricConfig(C.Config):
10
+ monitor: str
11
+ """The name of the metric to monitor."""
12
+
13
+ mode: Literal["min", "max"]
14
+ """
15
+ The mode of the primary metric:
16
+ - "min" for metrics that should be minimized (e.g., loss)
17
+ - "max" for metrics that should be maximized (e.g., accuracy)
18
+ """
19
+
20
+ @property
21
+ def best(self):
22
+ return builtins.min if self.mode == "min" else builtins.max
23
+
24
+ def is_better(self, a: Any, b: Any):
25
+ return self.best(a, b) == a
@@ -15,9 +15,9 @@ from lightning.pytorch.utilities.model_helpers import is_overridden
15
15
  from lightning.pytorch.utilities.rank_zero import rank_zero_warn
16
16
  from typing_extensions import Never, TypeVar, deprecated, override
17
17
 
18
- from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
18
+ from ..callbacks.rlp_sanity_checks import RLPSanityCheckModuleMixin
19
19
  from .mixins.callback import CallbackModuleMixin
20
- from .mixins.debug import _DebugModuleMixin
20
+ from .mixins.debug import DebugModuleMixin
21
21
  from .mixins.logger import LoggerLightningModuleMixin
22
22
 
23
23
  log = logging.getLogger(__name__)
@@ -54,8 +54,8 @@ VALID_REDUCE_OPS = (
54
54
 
55
55
 
56
56
  class LightningModuleBase(
57
- _DebugModuleMixin,
58
- _RLPSanityCheckModuleMixin,
57
+ DebugModuleMixin,
58
+ RLPSanityCheckModuleMixin,
59
59
  LoggerLightningModuleMixin,
60
60
  CallbackModuleMixin,
61
61
  LightningModule,
@@ -28,7 +28,7 @@ def _trainer(module: Any):
28
28
  return trainer
29
29
 
30
30
 
31
- class _DebugModuleMixin:
31
+ class DebugModuleMixin:
32
32
  @property
33
33
  def nshtrainer_or_none(self):
34
34
  return _trainer(self)
@@ -54,6 +54,12 @@ class _LogContextKwargs:
54
54
  d = dataclasses.asdict(self)
55
55
  for field in self.__ignore_fields__:
56
56
  d.pop(field, None)
57
+
58
+ # Pop all None values
59
+ for k in list(d.keys()):
60
+ if d[k] is None:
61
+ d.pop(k)
62
+
57
63
  return d
58
64
 
59
65
 
@@ -134,18 +140,18 @@ class LoggerLightningModuleMixin(mixin_base_type(LightningModule)):
134
140
  self,
135
141
  name: str,
136
142
  value: _METRIC,
137
- prog_bar: bool = False,
143
+ prog_bar: bool | None = None,
138
144
  logger: bool | None = None,
139
145
  on_step: bool | None = None,
140
146
  on_epoch: bool | None = None,
141
- reduce_fx: str | Callable = "mean",
142
- enable_graph: bool = False,
143
- sync_dist: bool = False,
147
+ reduce_fx: str | Callable | None = None,
148
+ enable_graph: bool | None = None,
149
+ sync_dist: bool | None = None,
144
150
  sync_dist_group: Any | None = None,
145
- add_dataloader_idx: bool = True,
151
+ add_dataloader_idx: bool | None = None,
146
152
  batch_size: int | None = None,
147
153
  metric_attribute: str | None = None,
148
- rank_zero_only: bool = False,
154
+ rank_zero_only: bool | None = None,
149
155
  ) -> None:
150
156
  # If logging is disabled, then do nothing.
151
157
  if not self.logging_enabled:
@@ -1,42 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import builtins
4
- from typing import Any, Literal
5
-
6
- import nshconfig as C
7
-
8
-
9
- class MetricConfig(C.Config):
10
- name: str
11
- """The name of the primary metric."""
12
-
13
- mode: Literal["min", "max"]
14
- """
15
- The mode of the primary metric:
16
- - "min" for metrics that should be minimized (e.g., loss)
17
- - "max" for metrics that should be maximized (e.g., accuracy)
18
- """
19
-
20
- @property
21
- def validation_monitor(self) -> str:
22
- return f"val/{self.name}"
23
-
24
- def __post_init__(self):
25
- for split in ("train", "val", "test", "predict"):
26
- if self.name.startswith(f"{split}/"):
27
- raise ValueError(
28
- f"Primary metric name should not start with '{split}/'. "
29
- f"Just use '{self.name[len(split) + 1:]}' instead. "
30
- "The split name is automatically added depending on the context."
31
- )
32
-
33
- @classmethod
34
- def loss(cls, mode: Literal["min", "max"] = "min"):
35
- return cls(name="loss", mode=mode)
36
-
37
- @property
38
- def best(self):
39
- return builtins.min if self.mode == "min" else builtins.max
40
-
41
- def is_better(self, a: Any, b: Any):
42
- return self.best(a, b) == a
File without changes