nshtrainer 1.0.0b33__tar.gz → 1.0.0b36__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/PKG-INFO +1 -1
  2. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/pyproject.toml +1 -1
  3. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/__init__.py +1 -0
  4. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/_hf_hub.py +8 -1
  5. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/__init__.py +10 -23
  6. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/actsave.py +6 -2
  7. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/base.py +3 -0
  8. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -4
  9. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  10. nshtrainer-1.0.0b33/src/nshtrainer/callbacks/checkpoint/time_checkpoint.py → nshtrainer-1.0.0b36/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +31 -31
  11. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -2
  12. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/debug_flag.py +4 -2
  13. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/directory_setup.py +23 -21
  14. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/early_stopping.py +4 -2
  15. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/ema.py +29 -27
  16. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/finite_checks.py +21 -19
  17. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/gradient_skipping.py +29 -27
  18. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/log_epoch.py +4 -2
  19. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/lr_monitor.py +6 -1
  20. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/norm_logging.py +36 -34
  21. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/print_table.py +20 -18
  22. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  23. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/shared_parameters.py +9 -7
  24. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/timer.py +12 -10
  25. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/wandb_upload_code.py +4 -2
  26. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/wandb_watch.py +4 -2
  27. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/__init__.py +4 -8
  28. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/_hf_hub/__init__.py +2 -0
  29. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/__init__.py +4 -8
  30. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/actsave/__init__.py +2 -0
  31. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/base/__init__.py +2 -0
  32. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +4 -6
  33. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +4 -0
  34. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +4 -0
  35. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -0
  36. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +2 -0
  37. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +2 -0
  38. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +2 -0
  39. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/ema/__init__.py +2 -0
  40. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +2 -0
  41. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +4 -0
  42. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +2 -0
  43. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +2 -0
  44. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +2 -0
  45. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/print_table/__init__.py +2 -0
  46. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +4 -0
  47. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +4 -0
  48. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/timer/__init__.py +2 -0
  49. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +4 -0
  50. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +2 -0
  51. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/trainer/__init__.py +2 -4
  52. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -8
  53. nshtrainer-1.0.0b36/src/nshtrainer/trainer/__init__.py +7 -0
  54. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/_config.py +4 -42
  55. nshtrainer-1.0.0b33/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -44
  56. nshtrainer-1.0.0b33/src/nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +0 -19
  57. nshtrainer-1.0.0b33/src/nshtrainer/trainer/__init__.py +0 -6
  58. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/README.md +0 -0
  59. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/.nshconfig.generated.json +0 -0
  60. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/_callback.py +0 -0
  61. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  62. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/_checkpoint/saver.py +0 -0
  63. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/_directory.py +0 -0
  64. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/_experimental/__init__.py +0 -0
  65. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  66. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/callbacks/interval.py +0 -0
  67. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  68. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  69. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  70. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  71. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  72. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/loggers/_base/__init__.py +0 -0
  73. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  74. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  75. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  76. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  77. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  78. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -0
  79. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  80. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  81. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  82. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  83. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/nn/__init__.py +0 -0
  84. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  85. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  86. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  87. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  88. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  89. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  90. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  91. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  92. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
  93. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
  94. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
  95. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
  96. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
  97. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
  98. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
  99. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
  100. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  101. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/util/__init__.py +0 -0
  102. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  103. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  104. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  105. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  106. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/data/__init__.py +0 -0
  107. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  108. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/data/datamodule.py +0 -0
  109. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/data/transform.py +0 -0
  110. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/loggers/__init__.py +0 -0
  111. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/loggers/_base.py +0 -0
  112. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/loggers/actsave.py +0 -0
  113. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/loggers/csv.py +0 -0
  114. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/loggers/tensorboard.py +0 -0
  115. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/loggers/wandb.py +0 -0
  116. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  117. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  118. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  119. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  120. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/metrics/__init__.py +0 -0
  121. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/metrics/_config.py +0 -0
  122. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/model/__init__.py +0 -0
  123. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/model/base.py +0 -0
  124. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/model/mixins/callback.py +0 -0
  125. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/model/mixins/debug.py +0 -0
  126. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/model/mixins/logger.py +0 -0
  127. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/nn/__init__.py +0 -0
  128. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/nn/mlp.py +0 -0
  129. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/nn/module_dict.py +0 -0
  130. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/nn/module_list.py +0 -0
  131. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/nn/nonlinearity.py +0 -0
  132. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/optimizer.py +0 -0
  133. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/profiler/__init__.py +0 -0
  134. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/profiler/_base.py +0 -0
  135. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/profiler/advanced.py +0 -0
  136. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/profiler/pytorch.py +0 -0
  137. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/profiler/simple.py +0 -0
  138. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  139. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/accelerator.py +0 -0
  140. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
  141. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/plugin/base.py +0 -0
  142. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/plugin/environment.py +0 -0
  143. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/plugin/io.py +0 -0
  144. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
  145. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/plugin/precision.py +0 -0
  146. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/signal_connector.py +0 -0
  147. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/strategy.py +0 -0
  148. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/trainer/trainer.py +0 -0
  149. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/util/_environment_info.py +0 -0
  150. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/util/bf16.py +0 -0
  151. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/util/config/__init__.py +0 -0
  152. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/util/config/dtype.py +0 -0
  153. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/util/config/duration.py +0 -0
  154. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/util/environment.py +0 -0
  155. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/util/path.py +0 -0
  156. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/util/seed.py +0 -0
  157. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/util/slurm.py +0 -0
  158. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/util/typed.py +0 -0
  159. {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b36}/src/nshtrainer/util/typing_utils.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 1.0.0b33
+Version: 1.0.0b36
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com

pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "1.0.0-beta33"
+version = "1.0.0-beta36"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"

src/nshtrainer/__init__.py

@@ -15,6 +15,7 @@ from .model import LightningModuleBase as LightningModuleBase
 from .trainer import Trainer as Trainer
 from .trainer import TrainerConfig as TrainerConfig
 from .trainer import accelerator_registry as accelerator_registry
+from .trainer import callback_registry as callback_registry
 from .trainer import plugin_registry as plugin_registry

 try:

src/nshtrainer/_hf_hub.py

@@ -14,7 +14,11 @@ from nshrunner._env import SNAPSHOT_DIR
 from typing_extensions import assert_never, override

 from ._callback import NTCallbackBase
-from .callbacks.base import CallbackConfigBase, CallbackMetadataConfig
+from .callbacks.base import (
+    CallbackConfigBase,
+    CallbackMetadataConfig,
+    callback_registry,
+)

 if TYPE_CHECKING:
     from huggingface_hub import HfApi  # noqa: F401
@@ -39,9 +43,12 @@ class HuggingFaceHubAutoCreateConfig(C.Config):
         return self.enabled


+@callback_registry.register
 class HuggingFaceHubConfig(CallbackConfigBase):
     """Configuration options for Hugging Face Hub integration."""

+    name: Literal["hf_hub"] = "hf_hub"
+
     metadata: ClassVar[CallbackMetadataConfig] = {"ignore_if_exists": True}

     enabled: bool = False

src/nshtrainer/callbacks/__init__.py

@@ -2,10 +2,13 @@ from __future__ import annotations

 from typing import Annotated

-import nshconfig as C
+from typing_extensions import TypeAliasType

 from . import checkpoint as checkpoint
+from .actsave import ActSaveCallback as ActSaveCallback
+from .actsave import ActSaveConfig as ActSaveConfig
 from .base import CallbackConfigBase as CallbackConfigBase
+from .base import callback_registry as callback_registry
 from .checkpoint import BestCheckpointCallback as BestCheckpointCallback
 from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
 from .checkpoint import LastCheckpointCallback as LastCheckpointCallback
@@ -14,8 +17,6 @@ from .checkpoint import OnExceptionCheckpointCallback as OnExceptionCheckpointCa
 from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
-from .checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
-from .checkpoint import TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig
 from .debug_flag import DebugFlagCallback as DebugFlagCallback
 from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
 from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
@@ -37,6 +38,8 @@ from .interval import IntervalCallback as IntervalCallback
 from .interval import StepIntervalCallback as StepIntervalCallback
 from .log_epoch import LogEpochCallback as LogEpochCallback
 from .log_epoch import LogEpochCallbackConfig as LogEpochCallbackConfig
+from .lr_monitor import LearningRateMonitor as LearningRateMonitor
+from .lr_monitor import LearningRateMonitorConfig as LearningRateMonitorConfig
 from .norm_logging import NormLoggingCallback as NormLoggingCallback
 from .norm_logging import NormLoggingCallbackConfig as NormLoggingCallbackConfig
 from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
@@ -60,23 +63,7 @@ from .wandb_upload_code import (
 from .wandb_watch import WandbWatchCallback as WandbWatchCallback
 from .wandb_watch import WandbWatchCallbackConfig as WandbWatchCallbackConfig

-CallbackConfig = Annotated[
-    DebugFlagCallbackConfig
-    | EarlyStoppingCallbackConfig
-    | EpochTimerCallbackConfig
-    | PrintTableMetricsCallbackConfig
-    | FiniteChecksCallbackConfig
-    | NormLoggingCallbackConfig
-    | GradientSkippingCallbackConfig
-    | LogEpochCallbackConfig
-    | EMACallbackConfig
-    | BestCheckpointCallbackConfig
-    | LastCheckpointCallbackConfig
-    | OnExceptionCheckpointCallbackConfig
-    | TimeCheckpointCallbackConfig
-    | SharedParametersCallbackConfig
-    | RLPSanityChecksCallbackConfig
-    | WandbWatchCallbackConfig
-    | WandbUploadCodeCallbackConfig,
-    C.Field(discriminator="name"),
-]
+CallbackConfig = TypeAliasType(
+    "CallbackConfig",
+    Annotated[CallbackConfigBase, callback_registry.DynamicResolution()],
+)

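A note on the last hunk: CallbackConfig is no longer a hand-maintained union; any config class registered with callback_registry participates in name-discriminated resolution. The following is a rough, hypothetical sketch (not part of the diff; the class and its "print_start" name are invented, and it assumes the registry accepts registrations from outside the package) of how a downstream config could plug into the same mechanism:

    from typing import Literal

    from lightning.pytorch import Callback
    from typing_extensions import final, override

    from nshtrainer.callbacks import CallbackConfigBase, callback_registry


    class PrintStartCallback(Callback):
        """Toy callback used only for this sketch."""

        def on_train_start(self, trainer, pl_module):
            print("training started")


    @final
    @callback_registry.register
    class PrintStartCallbackConfig(CallbackConfigBase):
        # "name" is the discriminator the registry resolves on; this value is invented.
        name: Literal["print_start"] = "print_start"

        @override
        def create_callbacks(self, trainer_config):
            yield PrintStartCallback()
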
src/nshtrainer/callbacks/actsave.py

@@ -4,15 +4,19 @@ import contextlib
 from pathlib import Path
 from typing import Literal

-from typing_extensions import TypeAliasType, override
+from typing_extensions import TypeAliasType, final, override

 from .._callback import NTCallbackBase
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 Stage = TypeAliasType("Stage", Literal["train", "validation", "test", "predict"])


+@final
+@callback_registry.register
 class ActSaveConfig(CallbackConfigBase):
+    name: Literal["act_save"] = "act_save"
+
     enabled: bool = True
     """Enable activation saving."""

src/nshtrainer/callbacks/base.py

@@ -55,6 +55,9 @@ class CallbackConfigBase(C.Config, ABC):
     ) -> Iterable[Callback | CallbackWithMetadata]: ...


+callback_registry = C.Registry(CallbackConfigBase, discriminator="name")
+
+
 # region Config resolution helpers
 def _create_callbacks_with_metadata(
     config: CallbackConfigBase, trainer_config: TrainerConfig

src/nshtrainer/callbacks/checkpoint/__init__.py

@@ -14,7 +14,3 @@ from .on_exception_checkpoint import (
 from .on_exception_checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
-from .time_checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
-from .time_checkpoint import (
-    TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
-)

src/nshtrainer/callbacks/checkpoint/best_checkpoint.py

@@ -9,12 +9,14 @@ from typing_extensions import final, override

 from ..._checkpoint.metadata import CheckpointMetadata
 from ...metrics._config import MetricConfig
+from ..base import callback_registry
 from ._base import BaseCheckpointCallbackConfig, CheckpointBase

 log = logging.getLogger(__name__)


 @final
+@callback_registry.register
 class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
     name: Literal["best_checkpoint"] = "best_checkpoint"

src/nshtrainer/callbacks/checkpoint/time_checkpoint.py → src/nshtrainer/callbacks/checkpoint/last_checkpoint.py

@@ -9,36 +9,41 @@ from typing import Any, Literal
 from lightning.pytorch import LightningModule, Trainer
 from typing_extensions import final, override

-from nshtrainer._checkpoint.metadata import CheckpointMetadata
-
+from ..._checkpoint.metadata import CheckpointMetadata
+from ..base import callback_registry
 from ._base import BaseCheckpointCallbackConfig, CheckpointBase

 log = logging.getLogger(__name__)


 @final
-class TimeCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
-    name: Literal["time_checkpoint"] = "time_checkpoint"
+@callback_registry.register
+class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+    name: Literal["last_checkpoint"] = "last_checkpoint"
+
+    save_on_time_interval: bool = True
+    """Whether to save checkpoints based on time interval."""

     interval: timedelta = timedelta(hours=12)
-    """Time interval between checkpoints."""
+    """Time interval between checkpoints when save_on_time_interval is True."""

     @override
     def create_checkpoint(self, trainer_config, dirpath):
-        return TimeCheckpointCallback(self, dirpath)
+        return LastCheckpointCallback(self, dirpath)


 @final
-class TimeCheckpointCallback(CheckpointBase[TimeCheckpointCallbackConfig]):
-    def __init__(self, config: TimeCheckpointCallbackConfig, dirpath: Path):
+class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
+    def __init__(self, config: LastCheckpointCallbackConfig, dirpath: Path):
         super().__init__(config, dirpath)
         self.start_time = time.time()
         self.last_checkpoint_time = self.start_time
         self.interval_seconds = config.interval.total_seconds()
+        self.save_on_time_interval = config.save_on_time_interval

     @override
     def name(self):
-        return "time"
+        return "last"

     @override
     def default_filename(self):
@@ -53,6 +58,8 @@ class TimeCheckpointCallback(CheckpointBase[TimeCheckpointCallbackConfig]):
         return True

     def _should_checkpoint(self) -> bool:
+        if not self.save_on_time_interval:
+            return False
         current_time = time.time()
         elapsed_time = current_time - self.last_checkpoint_time
         return elapsed_time >= self.interval_seconds
@@ -85,30 +92,23 @@ class TimeCheckpointCallback(CheckpointBase[TimeCheckpointCallbackConfig]):

     @override
     def on_train_batch_end(
-        self, trainer: Trainer, pl_module: LightningModule, *args, **kwargs
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        *args,
+        **kwargs,
     ):
-        if self._should_checkpoint():
-            self.save_checkpoints(trainer)
-            self.last_checkpoint_time = time.time()
+        if not self._should_checkpoint():
+            return
+        self.save_checkpoints(trainer)

     @override
-    def state_dict(self) -> dict[str, Any]:
-        """Save the timer state for checkpoint resumption.
-
-        Returns:
-            Dictionary containing the start time and last checkpoint time.
-        """
-        return {
-            "start_time": self.start_time,
-            "last_checkpoint_time": self.last_checkpoint_time,
-        }
+    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+        self.save_checkpoints(trainer)

     @override
-    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
-        """Restore the timer state when resuming from a checkpoint.
-
-        Args:
-            state_dict: Dictionary containing the previously saved timer state.
-        """
-        self.start_time = state_dict["start_time"]
-        self.last_checkpoint_time = state_dict["last_checkpoint_time"]
+    def save_checkpoints(self, trainer):
+        super().save_checkpoints(trainer)
+
+        if self.save_on_time_interval:
+            self.last_checkpoint_time = time.time()

src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py

@@ -9,9 +9,9 @@ from typing import Any, Literal

 from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import OnExceptionCheckpoint as _OnExceptionCheckpoint
-from typing_extensions import override
+from typing_extensions import final, override

-from ..base import CallbackConfigBase
+from ..base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)

@@ -44,6 +44,8 @@ def _monkey_patch_disable_barrier(trainer: LightningTrainer):
         log.warning("Reverted monkey-patched barrier.")


+@final
+@callback_registry.register
 class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
     name: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"

src/nshtrainer/callbacks/debug_flag.py

@@ -3,14 +3,16 @@ from __future__ import annotations
 import logging
 from typing import Literal

-from typing_extensions import override
+from typing_extensions import final, override

 from .._callback import NTCallbackBase
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class DebugFlagCallbackConfig(CallbackConfigBase):
     name: Literal["debug_flag"] = "debug_flag"

src/nshtrainer/callbacks/directory_setup.py

@@ -5,14 +5,35 @@ import os
 from pathlib import Path
 from typing import Literal

-from typing_extensions import override
+from typing_extensions import final, override

 from .._callback import NTCallbackBase
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class DirectorySetupCallbackConfig(CallbackConfigBase):
+    name: Literal["directory_setup"] = "directory_setup"
+
+    enabled: bool = True
+    """Whether to enable the directory setup callback."""
+
+    create_symlink_to_nshrunner_root: bool = True
+    """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
+
+    def __bool__(self):
+        return self.enabled
+
+    def create_callbacks(self, trainer_config):
+        if not self:
+            return
+
+        yield DirectorySetupCallback(self)
+
+
 def _create_symlink_to_nshrunner(base_dir: Path):
     # Resolve the current nshrunner session directory
     if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
@@ -43,25 +64,6 @@ def _create_symlink_to_nshrunner(base_dir: Path):
     symlink_path.symlink_to(session_dir)


-class DirectorySetupCallbackConfig(CallbackConfigBase):
-    name: Literal["directory_setup"] = "directory_setup"
-
-    enabled: bool = True
-    """Whether to enable the directory setup callback."""
-
-    create_symlink_to_nshrunner_root: bool = True
-    """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
-
-    def __bool__(self):
-        return self.enabled
-
-    def create_callbacks(self, trainer_config):
-        if not self:
-            return
-
-        yield DirectorySetupCallback(self)
-
-
 class DirectorySetupCallback(NTCallbackBase):
     @override
     def __init__(self, config: DirectorySetupCallbackConfig):

src/nshtrainer/callbacks/early_stopping.py

@@ -8,14 +8,16 @@ from lightning.fabric.utilities.rank_zero import _get_rank
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
 from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
-from typing_extensions import override
+from typing_extensions import final, override

 from ..metrics._config import MetricConfig
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class EarlyStoppingCallbackConfig(CallbackConfigBase):
     name: Literal["early_stopping"] = "early_stopping"

src/nshtrainer/callbacks/ema.py

@@ -10,9 +10,36 @@ import lightning.pytorch as pl
 import torch
 from lightning.pytorch import Callback
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry
+
+
+@final
+@callback_registry.register
+class EMACallbackConfig(CallbackConfigBase):
+    name: Literal["ema"] = "ema"
+
+    decay: float
+    """The exponential decay used when calculating the moving average. Has to be between 0-1."""
+
+    validate_original_weights: bool = False
+    """Validate the original weights, as apposed to the EMA weights."""
+
+    every_n_steps: int = 1
+    """Apply EMA every N steps."""
+
+    cpu_offload: bool = False
+    """Offload weights to CPU."""
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield EMACallback(
+            decay=self.decay,
+            validate_original_weights=self.validate_original_weights,
+            every_n_steps=self.every_n_steps,
+            cpu_offload=self.cpu_offload,
+        )


 class EMACallback(Callback):
@@ -358,28 +385,3 @@ class EMAOptimizer(torch.optim.Optimizer):
     def add_param_group(self, param_group):
         self.optimizer.add_param_group(param_group)
         self.rebuild_ema_params = True
-
-
-class EMACallbackConfig(CallbackConfigBase):
-    name: Literal["ema"] = "ema"
-
-    decay: float
-    """The exponential decay used when calculating the moving average. Has to be between 0-1."""
-
-    validate_original_weights: bool = False
-    """Validate the original weights, as apposed to the EMA weights."""
-
-    every_n_steps: int = 1
-    """Apply EMA every N steps."""
-
-    cpu_offload: bool = False
-    """Offload weights to CPU."""
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield EMACallback(
-            decay=self.decay,
-            validate_original_weights=self.validate_original_weights,
-            every_n_steps=self.every_n_steps,
-            cpu_offload=self.cpu_offload,
-        )

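For orientation only (this sketch is not part of the diff): per the relocated definition above, decay is the only EMACallbackConfig field without a default, and this particular create_callbacks implementation never reads its trainer_config argument, so a minimal, hypothetical usage could look like:

    from nshtrainer.callbacks.ema import EMACallbackConfig

    # decay is the only required field; the others keep the defaults shown above.
    ema = EMACallbackConfig(decay=0.999)

    # create_callbacks() is a generator yielding the underlying EMACallback;
    # None stands in for the trainer config, which this implementation ignores.
    (ema_callback,) = ema.create_callbacks(None)
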
src/nshtrainer/callbacks/finite_checks.py

@@ -5,13 +5,32 @@ from typing import Literal

 import torch
 from lightning.pytorch import Callback, LightningModule, Trainer
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class FiniteChecksCallbackConfig(CallbackConfigBase):
+    name: Literal["finite_checks"] = "finite_checks"
+
+    nonfinite_grads: bool = True
+    """Whether to check for non-finite (i.e. NaN or Inf) gradients"""
+
+    none_grads: bool = True
+    """Whether to check for None gradients"""
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield FiniteChecksCallback(
+            nonfinite_grads=self.nonfinite_grads,
+            none_grads=self.none_grads,
+        )
+
+
 def finite_checks(
     module: LightningModule,
     nonfinite_grads: bool = True,
@@ -58,20 +77,3 @@ class FiniteChecksCallback(Callback):
             nonfinite_grads=self._nonfinite_grads,
             none_grads=self._none_grads,
         )
-
-
-class FiniteChecksCallbackConfig(CallbackConfigBase):
-    name: Literal["finite_checks"] = "finite_checks"
-
-    nonfinite_grads: bool = True
-    """Whether to check for non-finite (i.e. NaN or Inf) gradients"""
-
-    none_grads: bool = True
-    """Whether to check for None gradients"""
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield FiniteChecksCallback(
-            nonfinite_grads=self.nonfinite_grads,
-            none_grads=self.none_grads,
-        )

src/nshtrainer/callbacks/gradient_skipping.py

@@ -7,21 +7,47 @@ import torch
 import torchmetrics
 from lightning.pytorch import Callback, LightningModule, Trainer
 from torch.optim import Optimizer
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry
 from .norm_logging import compute_norm

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class GradientSkippingCallbackConfig(CallbackConfigBase):
+    name: Literal["gradient_skipping"] = "gradient_skipping"
+
+    threshold: float
+    """Threshold to use for gradient skipping."""
+
+    norm_type: str | float = 2.0
+    """Norm type to use for gradient skipping."""
+
+    start_after_n_steps: int | None = 100
+    """Number of steps to wait before starting gradient skipping."""
+
+    skip_non_finite: bool = False
+    """
+    If False, it doesn't skip steps with non-finite norms. This is useful when using AMP, as AMP checks for NaN/Inf grads to adjust the loss scale. Otherwise, skips steps with non-finite norms.
+
+    Should almost always be False, especially when using AMP (unless you know what you're doing!).
+    """
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield GradientSkippingCallback(self)
+
+
 @runtime_checkable
 class HasGradSkippedSteps(Protocol):
     grad_skipped_steps: Any


 class GradientSkippingCallback(Callback):
-    def __init__(self, config: "GradientSkippingCallbackConfig"):
+    def __init__(self, config: GradientSkippingCallbackConfig):
         super().__init__()
         self.config = config

@@ -73,27 +99,3 @@ class GradientSkippingCallback(Callback):
             on_step=True,
             on_epoch=False,
         )
-
-
-class GradientSkippingCallbackConfig(CallbackConfigBase):
-    name: Literal["gradient_skipping"] = "gradient_skipping"
-
-    threshold: float
-    """Threshold to use for gradient skipping."""
-
-    norm_type: str | float = 2.0
-    """Norm type to use for gradient skipping."""
-
-    start_after_n_steps: int | None = 100
-    """Number of steps to wait before starting gradient skipping."""
-
-    skip_non_finite: bool = False
-    """
-    If False, it doesn't skip steps with non-finite norms. This is useful when using AMP, as AMP checks for NaN/Inf grads to adjust the loss scale. Otherwise, skips steps with non-finite norms.
-
-    Should almost always be False, especially when using AMP (unless you know what you're doing!).
-    """
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield GradientSkippingCallback(self)

src/nshtrainer/callbacks/log_epoch.py

@@ -6,13 +6,15 @@ from typing import Any, Literal

 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class LogEpochCallbackConfig(CallbackConfigBase):
     name: Literal["log_epoch"] = "log_epoch"

src/nshtrainer/callbacks/lr_monitor.py

@@ -3,11 +3,16 @@ from __future__ import annotations
 from typing import Literal

 from lightning.pytorch.callbacks import LearningRateMonitor
+from typing_extensions import final

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry


+@final
+@callback_registry.register
 class LearningRateMonitorConfig(CallbackConfigBase):
+    name: Literal["learning_rate_monitor"] = "learning_rate_monitor"
+
     logging_interval: Literal["step", "epoch"] | None = None
     """
     Set to 'epoch' or 'step' to log 'lr' of all optimizers at the same interval, set to None to log at individual interval according to the 'interval' key of each scheduler. Defaults to None.