nshtrainer 1.4.0__tar.gz → 1.5.0__tar.gz

This diff shows the contents of two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (166)
  1. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/PKG-INFO +2 -2
  2. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/pyproject.toml +8 -3
  3. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_callback.py +50 -3
  4. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/__init__.py +1 -1
  5. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/_base.py +2 -2
  6. nshtrainer-1.5.0/src/nshtrainer/callbacks/log_epoch.py +136 -0
  7. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/print_table.py +2 -2
  8. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +1 -0
  9. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/__init__.py +0 -2
  10. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/optimizer/__init__.py +0 -2
  11. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/__init__.py +1 -2
  12. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/actsave.py +7 -1
  13. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/wandb.py +5 -5
  14. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/lr_scheduler/base.py +1 -1
  15. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/model/mixins/callback.py +0 -17
  16. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/model/mixins/logger.py +1 -0
  17. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/nn/module_dict.py +4 -4
  18. nshtrainer-1.5.0/src/nshtrainer/nn/module_list.py +52 -0
  19. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/nn/nonlinearity.py +15 -2
  20. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/optimizer.py +2 -4
  21. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/_config.py +1 -1
  22. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/accelerator.py +1 -2
  23. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/__init__.py +1 -2
  24. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/code_upload.py +1 -1
  25. nshtrainer-1.4.0/src/nshtrainer/callbacks/log_epoch.py +0 -49
  26. nshtrainer-1.4.0/src/nshtrainer/nn/module_list.py +0 -52
  27. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/README.md +0 -0
  28. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/.nshconfig.generated.json +0 -0
  29. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/__init__.py +0 -0
  30. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  31. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  32. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  33. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_hf_hub.py +0 -0
  34. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  35. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/base.py +0 -0
  36. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  37. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  38. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  39. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  40. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  41. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  42. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/distributed_prediction_writer.py +0 -0
  43. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  44. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/ema.py +0 -0
  45. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  46. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  47. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/interval.py +0 -0
  48. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
  49. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/metric_validation.py +0 -0
  50. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  51. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  52. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/timer.py +0 -0
  53. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  54. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  55. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/.gitattributes +0 -0
  56. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  57. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  58. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  59. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
  60. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  61. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  62. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  63. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  64. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  65. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  66. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  67. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  68. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  69. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py +0 -0
  70. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  71. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  72. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  73. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  74. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  75. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  76. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
  77. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  78. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  79. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  80. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  81. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  82. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  83. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  84. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  85. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  86. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
  87. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  88. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  89. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  90. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  91. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
  92. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  93. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  94. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  95. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  96. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/nn/__init__.py +0 -0
  97. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  98. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  99. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/nn/rng/__init__.py +0 -0
  100. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  101. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  102. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  103. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  104. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  105. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/__init__.py +0 -0
  106. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -0
  107. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
  108. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
  109. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
  110. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
  111. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
  112. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
  113. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
  114. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
  115. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  116. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/util/__init__.py +0 -0
  117. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  118. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  119. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  120. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  121. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/data/__init__.py +0 -0
  122. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  123. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/data/datamodule.py +0 -0
  124. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/data/transform.py +0 -0
  125. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/base.py +0 -0
  126. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/csv.py +0 -0
  127. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  128. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  129. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  130. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  131. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/metrics/__init__.py +0 -0
  132. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/metrics/_config.py +0 -0
  133. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/model/__init__.py +0 -0
  134. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/model/base.py +0 -0
  135. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/model/mixins/debug.py +0 -0
  136. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/nn/__init__.py +0 -0
  137. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/nn/mlp.py +0 -0
  138. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/nn/rng.py +0 -0
  139. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/profiler/__init__.py +0 -0
  140. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/profiler/_base.py +0 -0
  141. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/profiler/advanced.py +0 -0
  142. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/profiler/pytorch.py +0 -0
  143. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/profiler/simple.py +0 -0
  144. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/__init__.py +0 -0
  145. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/_distributed_prediction_result.py +0 -0
  146. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/_log_hparams.py +0 -0
  147. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  148. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/base.py +0 -0
  149. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/environment.py +0 -0
  150. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/io.py +0 -0
  151. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
  152. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/precision.py +0 -0
  153. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  154. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/strategy.py +0 -0
  155. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/trainer.py +0 -0
  156. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/_environment_info.py +0 -0
  157. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/bf16.py +0 -0
  158. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/config/__init__.py +0 -0
  159. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/config/dtype.py +0 -0
  160. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/config/duration.py +0 -0
  161. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/environment.py +0 -0
  162. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/path.py +0 -0
  163. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/seed.py +0 -0
  164. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/slurm.py +0 -0
  165. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/typed.py +0 -0
  166. {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.4.0
+Version: 1.5.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -14,7 +14,7 @@ Provides-Extra: extra
 Requires-Dist: GitPython ; extra == "extra"
 Requires-Dist: huggingface-hub ; extra == "extra"
 Requires-Dist: lightning
-Requires-Dist: nshconfig (>0.39)
+Requires-Dist: nshconfig (>=0.43)
 Requires-Dist: nshrunner ; extra == "extra"
 Requires-Dist: nshutils ; extra == "extra"
 Requires-Dist: numpy
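
Note: the dependency floor moves from an exclusive `>0.39` to an inclusive `>=0.43`. A quick check with the `packaging` library (not part of this package) illustrates the PEP 440 semantics of the two specifiers:

    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    old = SpecifierSet(">0.39")   # 1.4.0: exclusive lower bound
    new = SpecifierSet(">=0.43")  # 1.5.0: inclusive lower bound

    print(Version("0.39") in old)  # False: the bound itself is excluded
    print(Version("0.43") in new)  # True: the new floor is allowed
    print(Version("0.42") in new)  # False: 0.40-0.42 no longer satisfy the pin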
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/pyproject.toml
@@ -1,13 +1,13 @@
 [project]
 name = "nshtrainer"
-version = "1.4.0"
+version = "1.5.0"
 description = ""
 authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
 requires-python = ">=3.10,<4.0"
 readme = "README.md"
 
 dependencies = [
-    "nshconfig>0.39",
+    "nshconfig>=0.43",
     "psutil",
     "numpy",
     "torch",
@@ -47,7 +47,12 @@ deprecateTypingAliases = true
 strictListInference = true
 strictDictionaryInference = true
 strictSetInference = true
-reportPrivateImportUsage = false
+reportPrivateImportUsage = "none"
+reportMatchNotExhaustive = "warning"
+reportOverlappingOverload = "warning"
+reportUnnecessaryTypeIgnoreComment = "warning"
+reportImplicitOverride = "warning"
+reportIncompatibleMethodOverride = "information"
 
 [tool.ruff.lint]
 select = ["FA102", "FA100"]
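
Note: the pyright settings switch from booleans to explicit severity levels and enable several new checks. A minimal sketch (not from the package) of what the new `reportImplicitOverride = "warning"` setting flags, and how the `typing_extensions.override` decorator added throughout 1.5.0 satisfies it:

    from typing_extensions import override


    class Base:
        def run(self) -> None: ...


    class Explicit(Base):
        @override  # explicit override: no diagnostic
        def run(self) -> None: ...


    class Implicit(Base):
        def run(self) -> None: ...  # warned under reportImplicitOverride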
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_callback.py
@@ -8,38 +8,46 @@ from lightning.pytorch import LightningModule
 from lightning.pytorch.callbacks import Callback as _LightningCallback
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch.optim import Optimizer
+from typing_extensions import override
 
 if TYPE_CHECKING:
     from .trainer import Trainer
 
 
 class NTCallbackBase(_LightningCallback):
+    @override
     def setup(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule, stage: str
     ) -> None:
         """Called when fit, validate, test, predict, or tune begins."""
 
+    @override
     def teardown(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule, stage: str
     ) -> None:
         """Called when fit, validate, test, predict, or tune ends."""
 
+    @override
     def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when fit begins."""
 
+    @override
     def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when fit ends."""
 
+    @override
     def on_sanity_check_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation sanity check starts."""
 
+    @override
     def on_sanity_check_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation sanity check ends."""
 
+    @override
     def on_train_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -49,6 +57,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the train batch begins."""
 
+    @override
     def on_train_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -65,11 +74,13 @@ class NTCallbackBase(_LightningCallback):
 
         """
 
+    @override
     def on_train_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the train epoch begins."""
 
+    @override
     def on_train_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
@@ -81,10 +92,12 @@ class NTCallbackBase(_LightningCallback):
         .. code-block:: python
 
             class MyLightningModule(L.LightningModule):
+                @override
                 def __init__(self):
                     super().__init__()  # pyright: ignore[reportIncompatibleMethodOverride]
                     self.training_step_outputs = []
 
+                @override
                 def training_step(self):
                     loss = ...  # pyright: ignore[reportIncompatibleMethodOverride]
                     self.training_step_outputs.append(loss)
@@ -92,6 +105,7 @@ class NTCallbackBase(_LightningCallback):
 
 
             class MyCallback(L.Callback):
+                @override
                 def on_train_epoch_end(self, trainer, pl_module):
                     # do something with all training_step outputs, for example:  # pyright: ignore[reportIncompatibleMethodOverride]
                     epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
@@ -101,36 +115,43 @@ class NTCallbackBase(_LightningCallback):
 
         """
 
+    @override
     def on_validation_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the val epoch begins."""
 
+    @override
     def on_validation_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the val epoch ends."""
 
+    @override
     def on_test_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the test epoch begins."""
 
+    @override
     def on_test_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the test epoch ends."""
 
+    @override
     def on_predict_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict epoch begins."""
 
+    @override
     def on_predict_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict epoch ends."""
 
+    @override
     def on_validation_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -141,6 +162,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the validation batch begins."""
 
+    @override
     def on_validation_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -152,6 +174,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the validation batch ends."""
 
+    @override
     def on_test_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -162,6 +185,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the test batch begins."""
 
+    @override
     def on_test_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -173,6 +197,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the test batch ends."""
 
+    @override
     def on_predict_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -183,6 +208,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the predict batch begins."""
 
+    @override
     def on_predict_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -194,36 +220,45 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the predict batch ends."""
 
+    @override
     def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the train begins."""
 
+    @override
     def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the train ends."""
 
+    @override
     def on_validation_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation loop begins."""
 
+    @override
     def on_validation_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation loop ends."""
 
+    @override
     def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the test begins."""
 
+    @override
     def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the test ends."""
 
+    @override
     def on_predict_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict begins."""
 
+    @override
     def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when predict ends."""
 
+    @override
     def on_exception(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -232,7 +267,8 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when any trainer execution is interrupted by an exception."""
 
-    def state_dict(self) -> dict[str, Any]:  # pyright: ignore[reportIncompatibleMethodOverride]
+    @override
+    def state_dict(self) -> dict[str, Any]:
         """Called when saving a checkpoint, implement to generate callback's ``state_dict``.
 
         Returns:
@@ -241,7 +277,8 @@ class NTCallbackBase(_LightningCallback):
         """
         return {}
 
-    def load_state_dict(self, state_dict: dict[str, Any]) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``.
 
         Args:
@@ -250,6 +287,7 @@ class NTCallbackBase(_LightningCallback):
         """
         pass
 
+    @override
     def on_save_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -265,6 +303,7 @@ class NTCallbackBase(_LightningCallback):
 
         """
 
+    @override
     def on_load_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -280,16 +319,19 @@ class NTCallbackBase(_LightningCallback):
 
         """
 
+    @override
     def on_before_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule, loss: torch.Tensor
     ) -> None:
         """Called before ``loss.backward()``."""
 
+    @override
     def on_after_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called after ``loss.backward()`` and before optimizers are stepped."""
 
+    @override
     def on_before_optimizer_step(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -298,6 +340,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called before ``optimizer.step()``."""
 
+    @override
     def on_before_zero_grad(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -306,7 +349,10 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called before ``optimizer.zero_grad()``."""
 
-    def on_checkpoint_saved(  # pyright: ignore[reportIncompatibleMethodOverride]
+    # =================================================================
+    # Our own new callbacks
+    # =================================================================
+    def on_checkpoint_saved(
         self,
         ckpt_path: Path,
         metadata_path: Path | None,
@@ -317,6 +363,7 @@ class NTCallbackBase(_LightningCallback):
         pass
 
 
+@override
 def _call_on_checkpoint_saved(
     trainer: Trainer,
     ckpt_path: str | Path,
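
Note: beyond the mechanical `@override` additions, the substantive change is that `on_checkpoint_saved` is no longer framed as a Lightning override but sectioned off as nshtrainer's own hook. A hypothetical subclass (not in the diff; the hook's trailing parameters are truncated in the hunk above, so they are elided here):

    from nshtrainer._callback import NTCallbackBase


    class PrintCheckpointPath(NTCallbackBase):
        def on_checkpoint_saved(self, ckpt_path, metadata_path, *args, **kwargs):
            # Invoked by nshtrainer after a checkpoint (and its metadata) is written.
            print(f"checkpoint written to {ckpt_path}")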
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/__init__.py
@@ -75,5 +75,5 @@ from .wandb_watch import WandbWatchCallbackConfig as WandbWatchCallbackConfig
 
 CallbackConfig = TypeAliasType(
     "CallbackConfig",
-    Annotated[CallbackConfigBase, callback_registry.DynamicResolution()],
+    Annotated[CallbackConfigBase, callback_registry],
 )
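
Note: this alias and the `LoggerConfig` alias further down now pass the registry object itself as `Annotated` metadata instead of a `registry.DynamicResolution()` wrapper, presumably relying on behavior added in nshconfig>=0.43 (the registry protocol is nshconfig's and is only assumed here). The general shape of the pattern:

    from typing import Annotated

    from typing_extensions import TypeAliasType


    class CallbackConfigBase: ...


    callback_registry = object()  # stand-in for nshconfig's registry object

    CallbackConfig = TypeAliasType(
        "CallbackConfig", Annotated[CallbackConfigBase, callback_registry]
    )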
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/_base.py
@@ -5,13 +5,13 @@ import string
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Literal
 
 import numpy as np
 import torch
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import Checkpoint
-from typing_extensions import override
+from typing_extensions import TypeVar, override
 
 from ..._checkpoint.metadata import CheckpointMetadata, _generate_checkpoint_metadata
 from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
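
Note: swapping `typing.TypeVar` for `typing_extensions.TypeVar` is most plausibly about PEP 696 features such as `default=`, which the stdlib version lacks before Python 3.13. A sketch under that assumption (the names below are illustrative, not the module's):

    from typing import Generic

    from typing_extensions import TypeVar

    TConfig = TypeVar("TConfig", default=dict)  # `default=` needs typing_extensions


    class Checkpointer(Generic[TConfig]):
        def __init__(self, config: TConfig) -> None:
            self.config = config


    c: Checkpointer = Checkpointer({})  # unparameterized use falls back to dict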
nshtrainer-1.5.0/src/nshtrainer/callbacks/log_epoch.py (new file)
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import logging
+import math
+from typing import Any, Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from typing_extensions import final, override
+
+from .base import CallbackConfigBase, callback_registry
+
+log = logging.getLogger(__name__)
+
+
+@final
+@callback_registry.register
+class LogEpochCallbackConfig(CallbackConfigBase):
+    name: Literal["log_epoch"] = "log_epoch"
+
+    metric_name: str = "computed_epoch"
+    """The name of the metric to log the epoch as."""
+
+    train: bool = True
+    """Whether to log the epoch during training."""
+
+    val: bool = True
+    """Whether to log the epoch during validation."""
+
+    test: bool = True
+    """Whether to log the epoch during testing."""
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield LogEpochCallback(self)
+
+
+def _worker_fn(
+    trainer: Trainer,
+    pl_module: LightningModule,
+    num_batches_prop: str,
+    dataloader_idx: int | None = None,
+    *,
+    metric_name: str,
+):
+    if trainer.logger is None:
+        return
+
+    # If trainer.num_{training/val/test}_batches is not set or is nan/inf, we cannot calculate the epoch
+    if not (num_batches := getattr(trainer, num_batches_prop, None)):
+        log.warning(f"Trainer has no valid `{num_batches_prop}`. Cannot log epoch.")
+        return
+
+    # If the trainer has a dataloader_idx, num_batches is a list of num_batches for each dataloader.
+    if dataloader_idx is not None:
+        assert isinstance(num_batches, list), (
+            f"Expected num_batches to be a list, got {type(num_batches)}"
+        )
+        assert 0 <= dataloader_idx < len(num_batches), (
+            f"Expected dataloader_idx to be between 0 and {len(num_batches)}, got {dataloader_idx}"
+        )
+        num_batches = num_batches[dataloader_idx]
+
+    if (
+        not isinstance(num_batches, (int, float))
+        or math.isnan(num_batches)
+        or math.isinf(num_batches)
+    ):
+        log.warning(
+            f"Trainer has no valid `{num_batches_prop}` (got {num_batches=}). Cannot log epoch."
+        )
+        return
+
+    epoch = pl_module.global_step / num_batches
+    pl_module.log(metric_name, epoch, on_step=True, on_epoch=False)
+
+
+class LogEpochCallback(Callback):
+    def __init__(self, config: LogEpochCallbackConfig):
+        super().__init__()
+
+        self.config = config
+
+    @override
+    def on_train_batch_start(
+        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
+    ):
+        if trainer.logger is None or not self.config.train:
+            return
+
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_training_batches",
+            metric_name=self.config.metric_name,
+        )
+
+    @override
+    def on_validation_batch_start(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if trainer.logger is None or not self.config.val:
+            return
+
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_val_batches",
+            dataloader_idx=dataloader_idx,
+            metric_name=self.config.metric_name,
+        )
+
+    @override
+    def on_test_batch_start(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if trainer.logger is None or not self.config.test:
+            return
+
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_test_batches",
+            dataloader_idx=dataloader_idx,
+            metric_name=self.config.metric_name,
+        )
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/print_table.py
@@ -67,14 +67,14 @@ class PrintTableMetricsCallback(Callback):
         }
         self.metrics.append(metrics_dict)
 
-        from rich.console import Console  # type: ignore[reportMissingImports] # noqa
+        from rich.console import Console  # pyright: ignore[reportMissingImports] # noqa
 
         console = Console()
         table = self.create_metrics_table()
         console.print(table)
 
     def create_metrics_table(self):
-        from rich.table import Table  # type: ignore[reportMissingImports] # noqa
+        from rich.table import Table  # pyright: ignore[reportMissingImports] # noqa
 
         table = Table(show_header=True, header_style="bold magenta")
 
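Note: the comment swap matters because pyright treats a bare or bracketed `# type: ignore` as suppressing every diagnostic on the line (the bracketed codes are mypy's), whereas `# pyright: ignore[rule]` suppresses only the named pyright rule; together with the new `reportUnnecessaryTypeIgnoreComment = "warning"` setting, the narrower form keeps suppressions honest. Roughly:

    import rich.console  # type: ignore  # suppresses all pyright diagnostics on this line
    import rich.table  # pyright: ignore[reportMissingImports]  # suppresses only this rule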
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py
@@ -38,6 +38,7 @@ class RLPSanityChecksCallbackConfig(CallbackConfigBase):
     def __bool__(self):
         return self.enabled
 
+    @override
     def create_callbacks(self, trainer_config):
         if not self:
             return
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/__init__.py
@@ -111,7 +111,6 @@ from nshtrainer.optimizer import RAdamConfig as RAdamConfig
 from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
 from nshtrainer.optimizer import RpropConfig as RpropConfig
 from nshtrainer.optimizer import SGDConfig as SGDConfig
-from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
 from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
@@ -355,7 +354,6 @@ __all__ = [
     "TorchSyncBatchNormPlugin",
     "TrainerConfig",
     "TransformerEnginePluginConfig",
-    "Union",
     "WandbLoggerConfig",
     "WandbUploadCodeCallbackConfig",
     "WandbWatchCallbackConfig",
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/optimizer/__init__.py
@@ -16,7 +16,6 @@ from nshtrainer.optimizer import RAdamConfig as RAdamConfig
 from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
 from nshtrainer.optimizer import RpropConfig as RpropConfig
 from nshtrainer.optimizer import SGDConfig as SGDConfig
-from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 
 __all__ = [
@@ -34,6 +33,5 @@ __all__ = [
     "RMSpropConfig",
     "RpropConfig",
     "SGDConfig",
-    "Union",
     "optimizer_registry",
 ]
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/__init__.py
@@ -12,6 +12,5 @@ from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
 from .wandb import WandbLoggerConfig as WandbLoggerConfig
 
 LoggerConfig = TypeAliasType(
-    "LoggerConfig",
-    Annotated[LoggerConfigBase, logger_registry.DynamicResolution()],
+    "LoggerConfig", Annotated[LoggerConfigBase, logger_registry]
 )
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/actsave.py
@@ -5,7 +5,7 @@ from typing import Any, Literal
 
 import numpy as np
 from lightning.pytorch.loggers import Logger
-from typing_extensions import final
+from typing_extensions import final, override
 
 from .base import LoggerConfigBase, logger_registry
 
@@ -15,6 +15,7 @@ from .base import LoggerConfigBase, logger_registry
 class ActSaveLoggerConfig(LoggerConfigBase):
     name: Literal["actsave"] = "actsave"
 
+    @override
     def create_logger(self, trainer_config):
         if not self.enabled:
             return None
@@ -24,10 +25,12 @@ class ActSaveLoggerConfig(LoggerConfigBase):
 
 class ActSaveLogger(Logger):
     @property
+    @override
     def name(self):
         return None
 
     @property
+    @override
     def version(self):
         from nshutils import ActSave
 
@@ -37,6 +40,7 @@ class ActSaveLogger(Logger):
         return ActSave._saver._id
 
     @property
+    @override
     def save_dir(self):
         from nshutils import ActSave
 
@@ -45,6 +49,7 @@ class ActSaveLogger(Logger):
 
         return str(ActSave._saver._save_dir)
 
+    @override
    def log_hyperparams(
         self,
         params: dict[str, Any] | Namespace,
@@ -56,6 +61,7 @@ class ActSaveLogger(Logger):
         # Wrap the hparams as a object-dtype np array
         return ActSave.save({"hyperparameters": np.array(params, dtype=object)})
 
+    @override
     def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None:
         from nshutils import ActSave
 
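Note: in the property hunks, `@override` is stacked under `@property` so it decorates the getter function itself, which is the placement `typing_extensions` expects. In general form (not package code):

    from typing_extensions import override


    class Base:
        @property
        def name(self) -> str | None:
            return "base"


    class Child(Base):
        @property
        @override
        def name(self) -> str | None:
            return None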
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/wandb.py
@@ -63,7 +63,7 @@ class FinishWandbOnTeardownCallback(Callback):
         stage: str,
     ):
         try:
-            import wandb  # type: ignore
+            import wandb
         except ImportError:
             return
 
@@ -139,7 +139,7 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
         # If `wandb-core` is enabled, we should use the new backend.
         if self.use_wandb_core:
             try:
-                import wandb  # type: ignore
+                import wandb
 
                 # The minimum version that supports the new backend is 0.17.5
                 wandb_version = version.parse(importlib.metadata.version("wandb"))
@@ -151,7 +151,7 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
                     )
                 # W&B versions 0.18.0 use wandb-core by default
                 elif wandb_version < version.parse("0.18.0"):
-                    wandb.require("core")  # type: ignore
+                    wandb.require("core")
                     log.critical("Using the `wandb-core` backend for WandB.")
             except ImportError:
                 pass
@@ -166,9 +166,9 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
                 "If you want to use the new `wandb-core` backend, set `use_wandb_core=True`."
             )
             try:
-                import wandb  # type: ignore
+                import wandb
 
-                wandb.require("legacy-service")  # type: ignore
+                wandb.require("legacy-service")
             except ImportError:
                 pass
 
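Note: for context, the version gating visible in the hunks above condenses to the following (standard `packaging` and `wandb.require` usage; the surrounding config logic is elided):

    import importlib.metadata

    from packaging import version

    wandb_version = version.parse(importlib.metadata.version("wandb"))
    if wandb_version < version.parse("0.17.5"):
        pass  # wandb-core backend not available before 0.17.5
    elif wandb_version < version.parse("0.18.0"):
        import wandb

        wandb.require("core")  # opt in; 0.18.0+ defaults to wandb-core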
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/lr_scheduler/base.py
@@ -81,7 +81,7 @@ class LRSchedulerConfigBase(C.Config, ABC):
             scheduler["monitor"] = metadata["monitor"]
         # - `strict`
         if scheduler.get("strict") is None and "strict" in metadata:
-            scheduler["strict"] = metadata["strict"]  # type: ignore
+            scheduler["strict"] = metadata["strict"]
 
         return scheduler
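
Note: the `scheduler` mapping being filled in here follows Lightning's `lr_scheduler` config dict, where `monitor` and `strict` are the keys Lightning reads for metric-driven schedulers. For reference (standard Lightning API, not from this diff):

    import torch
    from lightning.pytorch import LightningModule


    class Model(LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",  # metric the scheduler watches
                    "strict": True,  # raise instead of warn if the metric is missing
                },
            }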