nshtrainer 0.41.1__tar.gz → 0.43.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/PKG-INFO +1 -1
  2. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/pyproject.toml +9 -5
  3. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/__init__.py +2 -0
  4. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_callback.py +2 -0
  5. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_checkpoint/loader.py +2 -0
  6. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_checkpoint/metadata.py +2 -0
  7. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_checkpoint/saver.py +2 -0
  8. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_directory.py +4 -2
  9. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_experimental/__init__.py +2 -0
  10. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_hf_hub.py +2 -0
  11. nshtrainer-0.43.0/src/nshtrainer/callbacks/__init__.py +81 -0
  12. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
  13. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/actsave.py +2 -0
  14. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/base.py +2 -0
  15. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +6 -2
  16. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/_base.py +2 -0
  17. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  18. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
  19. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
  20. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/debug_flag.py +2 -0
  21. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/directory_setup.py +4 -2
  22. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/early_stopping.py +6 -4
  23. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/ema.py +5 -3
  24. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/finite_checks.py +3 -1
  25. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/gradient_skipping.py +6 -4
  26. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/interval.py +2 -0
  27. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/log_epoch.py +13 -1
  28. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/norm_logging.py +4 -2
  29. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/print_table.py +3 -1
  30. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  31. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/shared_parameters.py +4 -2
  32. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/throughput_monitor.py +2 -0
  33. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/timer.py +5 -3
  34. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/wandb_upload_code.py +4 -2
  35. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/wandb_watch.py +4 -2
  36. nshtrainer-0.43.0/src/nshtrainer/config/__init__.py +465 -0
  37. nshtrainer-0.43.0/src/nshtrainer/config/_checkpoint/loader/__init__.py +62 -0
  38. nshtrainer-0.43.0/src/nshtrainer/config/_checkpoint/metadata/__init__.py +29 -0
  39. nshtrainer-0.43.0/src/nshtrainer/config/_directory/__init__.py +32 -0
  40. nshtrainer-0.43.0/src/nshtrainer/config/_hf_hub/__init__.py +32 -0
  41. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/__init__.py +176 -0
  42. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/actsave/__init__.py +27 -0
  43. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/base/__init__.py +24 -0
  44. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/checkpoint/__init__.py +73 -0
  45. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/checkpoint/_base/__init__.py +40 -0
  46. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +47 -0
  47. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +40 -0
  48. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +33 -0
  49. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/debug_flag/__init__.py +31 -0
  50. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/directory_setup/__init__.py +33 -0
  51. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/early_stopping/__init__.py +38 -0
  52. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/ema/__init__.py +27 -0
  53. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/finite_checks/__init__.py +33 -0
  54. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/gradient_skipping/__init__.py +33 -0
  55. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/norm_logging/__init__.py +33 -0
  56. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/print_table/__init__.py +33 -0
  57. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +33 -0
  58. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/shared_parameters/__init__.py +33 -0
  59. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/throughput_monitor/__init__.py +33 -0
  60. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/timer/__init__.py +31 -0
  61. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/wandb_upload_code/__init__.py +33 -0
  62. nshtrainer-0.43.0/src/nshtrainer/config/callbacks/wandb_watch/__init__.py +33 -0
  63. nshtrainer-0.43.0/src/nshtrainer/config/loggers/__init__.py +58 -0
  64. nshtrainer-0.43.0/src/nshtrainer/config/loggers/_base/__init__.py +22 -0
  65. nshtrainer-0.43.0/src/nshtrainer/config/loggers/csv/__init__.py +25 -0
  66. nshtrainer-0.43.0/src/nshtrainer/config/loggers/tensorboard/__init__.py +31 -0
  67. nshtrainer-0.43.0/src/nshtrainer/config/loggers/wandb/__init__.py +44 -0
  68. nshtrainer-0.43.0/src/nshtrainer/config/lr_scheduler/__init__.py +59 -0
  69. nshtrainer-0.43.0/src/nshtrainer/config/lr_scheduler/_base/__init__.py +26 -0
  70. nshtrainer-0.43.0/src/nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +40 -0
  71. nshtrainer-0.43.0/src/nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +40 -0
  72. nshtrainer-0.43.0/src/nshtrainer/config/metrics/__init__.py +24 -0
  73. nshtrainer-0.43.0/src/nshtrainer/config/metrics/_config/__init__.py +22 -0
  74. nshtrainer-0.43.0/src/nshtrainer/config/model/__init__.py +41 -0
  75. nshtrainer-0.43.0/src/nshtrainer/config/model/base/__init__.py +25 -0
  76. nshtrainer-0.43.0/src/nshtrainer/config/model/config/__init__.py +37 -0
  77. nshtrainer-0.43.0/src/nshtrainer/config/model/mixins/logger/__init__.py +22 -0
  78. nshtrainer-0.43.0/src/nshtrainer/config/nn/__init__.py +77 -0
  79. nshtrainer-0.43.0/src/nshtrainer/config/nn/mlp/__init__.py +28 -0
  80. nshtrainer-0.43.0/src/nshtrainer/config/nn/nonlinearity/__init__.py +125 -0
  81. nshtrainer-0.43.0/src/nshtrainer/config/optimizer/__init__.py +28 -0
  82. nshtrainer-0.43.0/src/nshtrainer/config/profiler/__init__.py +39 -0
  83. nshtrainer-0.43.0/src/nshtrainer/config/profiler/_base/__init__.py +24 -0
  84. nshtrainer-0.43.0/src/nshtrainer/config/profiler/advanced/__init__.py +31 -0
  85. nshtrainer-0.43.0/src/nshtrainer/config/profiler/pytorch/__init__.py +31 -0
  86. nshtrainer-0.43.0/src/nshtrainer/config/profiler/simple/__init__.py +29 -0
  87. nshtrainer-0.43.0/src/nshtrainer/config/runner/__init__.py +22 -0
  88. nshtrainer-0.43.0/src/nshtrainer/config/trainer/_config/__init__.py +153 -0
  89. nshtrainer-0.43.0/src/nshtrainer/config/trainer/checkpoint_connector/__init__.py +26 -0
  90. nshtrainer-0.43.0/src/nshtrainer/config/util/_environment_info/__init__.py +94 -0
  91. nshtrainer-0.43.0/src/nshtrainer/config/util/config/__init__.py +34 -0
  92. nshtrainer-0.43.0/src/nshtrainer/config/util/config/dtype/__init__.py +22 -0
  93. nshtrainer-0.43.0/src/nshtrainer/config/util/config/duration/__init__.py +34 -0
  94. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/data/__init__.py +2 -0
  95. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/data/balanced_batch_sampler.py +2 -0
  96. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/data/datamodule.py +2 -0
  97. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/data/transform.py +2 -0
  98. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/__init__.py +2 -0
  99. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/_experimental.py +2 -0
  100. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/actsave.py +2 -0
  101. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/callbacks.py +2 -0
  102. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/config.py +2 -0
  103. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/data.py +2 -0
  104. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/log.py +2 -0
  105. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/lr_scheduler.py +2 -0
  106. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/model.py +2 -0
  107. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/nn.py +2 -0
  108. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/optimizer.py +2 -0
  109. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/runner.py +2 -0
  110. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/snapshot.py +2 -0
  111. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/snoop.py +2 -0
  112. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/trainer.py +2 -0
  113. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/typecheck.py +2 -0
  114. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/util.py +2 -0
  115. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/loggers/__init__.py +2 -0
  116. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/loggers/_base.py +2 -0
  117. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/loggers/csv.py +2 -0
  118. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/loggers/tensorboard.py +2 -0
  119. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/loggers/wandb.py +6 -4
  120. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/lr_scheduler/__init__.py +2 -0
  121. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/lr_scheduler/_base.py +2 -0
  122. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +2 -0
  123. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +2 -0
  124. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/metrics/__init__.py +2 -0
  125. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/metrics/_config.py +2 -0
  126. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/model/__init__.py +2 -0
  127. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/model/base.py +2 -0
  128. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/model/config.py +2 -0
  129. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/model/mixins/callback.py +2 -0
  130. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/model/mixins/logger.py +2 -0
  131. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/nn/__init__.py +2 -0
  132. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/nn/mlp.py +2 -0
  133. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/nn/module_dict.py +2 -0
  134. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/nn/module_list.py +2 -0
  135. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/nn/nonlinearity.py +2 -0
  136. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/optimizer.py +2 -0
  137. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/profiler/__init__.py +2 -0
  138. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/profiler/_base.py +2 -0
  139. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/profiler/advanced.py +2 -0
  140. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/profiler/pytorch.py +2 -0
  141. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/profiler/simple.py +2 -0
  142. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/runner.py +2 -0
  143. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/scripts/find_packages.py +2 -0
  144. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/__init__.py +2 -0
  145. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/_config.py +16 -13
  146. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/_runtime_callback.py +2 -0
  147. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/checkpoint_connector.py +2 -0
  148. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/signal_connector.py +2 -0
  149. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/trainer.py +2 -0
  150. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/_environment_info.py +2 -0
  151. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/bf16.py +2 -0
  152. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/config/__init__.py +2 -0
  153. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/config/dtype.py +2 -0
  154. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/config/duration.py +2 -0
  155. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/environment.py +2 -0
  156. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/path.py +2 -0
  157. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/seed.py +2 -0
  158. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/slurm.py +3 -0
  159. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/typed.py +2 -0
  160. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/typing_utils.py +2 -0
  161. nshtrainer-0.41.1/src/nshtrainer/callbacks/__init__.py +0 -65
  162. nshtrainer-0.41.1/src/nshtrainer/config/__init__.py +0 -114
  163. nshtrainer-0.41.1/src/nshtrainer/config/_checkpoint/loader/__init__.py +0 -18
  164. nshtrainer-0.41.1/src/nshtrainer/config/_checkpoint/metadata/__init__.py +0 -13
  165. nshtrainer-0.41.1/src/nshtrainer/config/_directory/__init__.py +0 -14
  166. nshtrainer-0.41.1/src/nshtrainer/config/_hf_hub/__init__.py +0 -14
  167. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/__init__.py +0 -51
  168. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/actsave/__init__.py +0 -13
  169. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/base/__init__.py +0 -12
  170. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/checkpoint/__init__.py +0 -22
  171. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/checkpoint/_base/__init__.py +0 -14
  172. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +0 -15
  173. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +0 -14
  174. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -13
  175. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/debug_flag/__init__.py +0 -13
  176. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/directory_setup/__init__.py +0 -13
  177. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/early_stopping/__init__.py +0 -14
  178. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/ema/__init__.py +0 -13
  179. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/finite_checks/__init__.py +0 -13
  180. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/gradient_skipping/__init__.py +0 -13
  181. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/norm_logging/__init__.py +0 -13
  182. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/print_table/__init__.py +0 -13
  183. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +0 -13
  184. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/shared_parameters/__init__.py +0 -13
  185. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/throughput_monitor/__init__.py +0 -13
  186. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/timer/__init__.py +0 -13
  187. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/wandb_upload_code/__init__.py +0 -13
  188. nshtrainer-0.41.1/src/nshtrainer/config/callbacks/wandb_watch/__init__.py +0 -13
  189. nshtrainer-0.41.1/src/nshtrainer/config/loggers/__init__.py +0 -23
  190. nshtrainer-0.41.1/src/nshtrainer/config/loggers/_base/__init__.py +0 -12
  191. nshtrainer-0.41.1/src/nshtrainer/config/loggers/csv/__init__.py +0 -13
  192. nshtrainer-0.41.1/src/nshtrainer/config/loggers/tensorboard/__init__.py +0 -13
  193. nshtrainer-0.41.1/src/nshtrainer/config/loggers/wandb/__init__.py +0 -16
  194. nshtrainer-0.41.1/src/nshtrainer/config/lr_scheduler/__init__.py +0 -20
  195. nshtrainer-0.41.1/src/nshtrainer/config/lr_scheduler/_base/__init__.py +0 -12
  196. nshtrainer-0.41.1/src/nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +0 -14
  197. nshtrainer-0.41.1/src/nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -14
  198. nshtrainer-0.41.1/src/nshtrainer/config/metrics/__init__.py +0 -13
  199. nshtrainer-0.41.1/src/nshtrainer/config/metrics/_config/__init__.py +0 -12
  200. nshtrainer-0.41.1/src/nshtrainer/config/model/__init__.py +0 -20
  201. nshtrainer-0.41.1/src/nshtrainer/config/model/base/__init__.py +0 -13
  202. nshtrainer-0.41.1/src/nshtrainer/config/model/config/__init__.py +0 -17
  203. nshtrainer-0.41.1/src/nshtrainer/config/model/mixins/logger/__init__.py +0 -12
  204. nshtrainer-0.41.1/src/nshtrainer/config/nn/__init__.py +0 -30
  205. nshtrainer-0.41.1/src/nshtrainer/config/nn/mlp/__init__.py +0 -14
  206. nshtrainer-0.41.1/src/nshtrainer/config/nn/nonlinearity/__init__.py +0 -27
  207. nshtrainer-0.41.1/src/nshtrainer/config/optimizer/__init__.py +0 -14
  208. nshtrainer-0.41.1/src/nshtrainer/config/profiler/__init__.py +0 -20
  209. nshtrainer-0.41.1/src/nshtrainer/config/profiler/_base/__init__.py +0 -12
  210. nshtrainer-0.41.1/src/nshtrainer/config/profiler/advanced/__init__.py +0 -13
  211. nshtrainer-0.41.1/src/nshtrainer/config/profiler/pytorch/__init__.py +0 -13
  212. nshtrainer-0.41.1/src/nshtrainer/config/profiler/simple/__init__.py +0 -13
  213. nshtrainer-0.41.1/src/nshtrainer/config/runner/__init__.py +0 -12
  214. nshtrainer-0.41.1/src/nshtrainer/config/trainer/_config/__init__.py +0 -35
  215. nshtrainer-0.41.1/src/nshtrainer/config/trainer/checkpoint_connector/__init__.py +0 -12
  216. nshtrainer-0.41.1/src/nshtrainer/config/util/_environment_info/__init__.py +0 -22
  217. nshtrainer-0.41.1/src/nshtrainer/config/util/config/__init__.py +0 -17
  218. nshtrainer-0.41.1/src/nshtrainer/config/util/config/dtype/__init__.py +0 -12
  219. nshtrainer-0.41.1/src/nshtrainer/config/util/config/duration/__init__.py +0 -14
  220. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/README.md +0 -0
  221. {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/_useful_types.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.41.1
3
+ Version: 0.43.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.41.1"
3
+ version = "0.43.0"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -25,10 +25,10 @@ tensorboard = { version = "*", optional = true }
25
25
  huggingface-hub = { version = "*", optional = true }
26
26
 
27
27
  [tool.poetry.group.dev.dependencies]
28
- pyright = "^1.1.372"
29
- ruff = "^0.5.4"
30
- ipykernel = "^6.29.5"
31
- ipywidgets = "^8.1.3"
28
+ pyright = "*"
29
+ ruff = "*"
30
+ ipykernel = "*"
31
+ ipywidgets = "*"
32
32
 
33
33
  [build-system]
34
34
  requires = ["poetry-core"]
@@ -43,7 +43,11 @@ strictSetInference = true
43
43
  reportPrivateImportUsage = false
44
44
 
45
45
  [tool.ruff.lint]
46
+ select = ["FA102", "FA100"]
46
47
  ignore = ["F722", "F821", "E731", "E741"]
47
48
 
49
+ [tool.ruff.lint.isort]
50
+ required-imports = ["from __future__ import annotations"]
51
+
48
52
  [tool.poetry.extras]
49
53
  extra = ["wrapt", "GitPython", "wandb", "tensorboard", "huggingface-hub"]
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from . import _experimental as _experimental
2
4
  from . import callbacks as callbacks
3
5
  from . import config as config
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from pathlib import Path
2
4
  from typing import TYPE_CHECKING
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from collections.abc import Iterable, Sequence
3
5
  from dataclasses import dataclass
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  import datetime
3
5
  import logging
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import os
3
5
  import shutil
@@ -1,9 +1,11 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from pathlib import Path
3
5
 
4
6
  import nshconfig as C
5
7
 
6
- from .callbacks.directory_setup import DirectorySetupConfig
8
+ from .callbacks.directory_setup import DirectorySetupCallbackConfig
7
9
  from .loggers import LoggerConfig
8
10
 
9
11
  log = logging.getLogger(__name__)
@@ -32,7 +34,7 @@ class DirectoryConfig(C.Config):
32
34
  profile: Path | None = None
33
35
  """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
34
36
 
35
- setup_callback: DirectorySetupConfig = DirectorySetupConfig()
37
+ setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
36
38
  """Configuration for the directory setup PyTorch Lightning callback."""
37
39
 
38
40
  def resolve_run_root_directory(self, run_id: str) -> Path:
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from lightning.fabric.utilities.throughput import measure_flops as measure_flops
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
2
4
  import logging
3
5
  import os
@@ -0,0 +1,81 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Annotated
4
+
5
+ import nshconfig as C
6
+
7
+ from . import checkpoint as checkpoint
8
+ from .base import CallbackConfigBase as CallbackConfigBase
9
+ from .checkpoint import BestCheckpoint as BestCheckpoint
10
+ from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
11
+ from .checkpoint import LastCheckpointCallback as LastCheckpointCallback
12
+ from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
13
+ from .checkpoint import OnExceptionCheckpointCallback as OnExceptionCheckpointCallback
14
+ from .checkpoint import (
15
+ OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
16
+ )
17
+ from .debug_flag import DebugFlagCallback as DebugFlagCallback
18
+ from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
19
+ from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
20
+ from .directory_setup import (
21
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
22
+ )
23
+ from .early_stopping import EarlyStoppingCallback as EarlyStoppingCallback
24
+ from .early_stopping import EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig
25
+ from .ema import EMACallback as EMACallback
26
+ from .ema import EMACallbackConfig as EMACallbackConfig
27
+ from .finite_checks import FiniteChecksCallback as FiniteChecksCallback
28
+ from .finite_checks import FiniteChecksCallbackConfig as FiniteChecksCallbackConfig
29
+ from .gradient_skipping import GradientSkippingCallback as GradientSkippingCallback
30
+ from .gradient_skipping import (
31
+ GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
32
+ )
33
+ from .interval import EpochIntervalCallback as EpochIntervalCallback
34
+ from .interval import IntervalCallback as IntervalCallback
35
+ from .interval import StepIntervalCallback as StepIntervalCallback
36
+ from .log_epoch import LogEpochCallback as LogEpochCallback
37
+ from .log_epoch import LogEpochCallbackConfig as LogEpochCallbackConfig
38
+ from .norm_logging import NormLoggingCallback as NormLoggingCallback
39
+ from .norm_logging import NormLoggingCallbackConfig as NormLoggingCallbackConfig
40
+ from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
41
+ from .print_table import (
42
+ PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
43
+ )
44
+ from .rlp_sanity_checks import RLPSanityChecksCallback as RLPSanityChecksCallback
45
+ from .rlp_sanity_checks import (
46
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
47
+ )
48
+ from .shared_parameters import SharedParametersCallback as SharedParametersCallback
49
+ from .shared_parameters import (
50
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
51
+ )
52
+ from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
53
+ from .timer import EpochTimerCallback as EpochTimerCallback
54
+ from .timer import EpochTimerCallbackConfig as EpochTimerCallbackConfig
55
+ from .wandb_upload_code import WandbUploadCodeCallback as WandbUploadCodeCallback
56
+ from .wandb_upload_code import (
57
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
58
+ )
59
+ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
60
+ from .wandb_watch import WandbWatchCallbackConfig as WandbWatchCallbackConfig
61
+
62
+ CallbackConfig = Annotated[
63
+ DebugFlagCallbackConfig
64
+ | EarlyStoppingCallbackConfig
65
+ | ThroughputMonitorConfig
66
+ | EpochTimerCallbackConfig
67
+ | PrintTableMetricsCallbackConfig
68
+ | FiniteChecksCallbackConfig
69
+ | NormLoggingCallbackConfig
70
+ | GradientSkippingCallbackConfig
71
+ | LogEpochCallbackConfig
72
+ | EMACallbackConfig
73
+ | BestCheckpointCallbackConfig
74
+ | LastCheckpointCallbackConfig
75
+ | OnExceptionCheckpointCallbackConfig
76
+ | SharedParametersCallbackConfig
77
+ | RLPSanityChecksCallbackConfig
78
+ | WandbWatchCallbackConfig
79
+ | WandbUploadCodeCallbackConfig,
80
+ C.Field(discriminator="name"),
81
+ ]
@@ -12,6 +12,8 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
+ from __future__ import annotations
16
+
15
17
  import time
16
18
  from collections import deque
17
19
  from typing import (
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
2
4
  from pathlib import Path
3
5
  from typing import Literal
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from abc import ABC, abstractmethod
2
4
  from collections import Counter
3
5
  from collections.abc import Iterable
@@ -1,12 +1,16 @@
1
+ from __future__ import annotations
2
+
1
3
  from .best_checkpoint import BestCheckpoint as BestCheckpoint
2
4
  from .best_checkpoint import (
3
5
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
4
6
  )
5
- from .last_checkpoint import LastCheckpoint as LastCheckpoint
7
+ from .last_checkpoint import LastCheckpointCallback as LastCheckpointCallback
6
8
  from .last_checkpoint import (
7
9
  LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
8
10
  )
9
- from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
11
+ from .on_exception_checkpoint import (
12
+ OnExceptionCheckpointCallback as OnExceptionCheckpointCallback,
13
+ )
10
14
  from .on_exception_checkpoint import (
11
15
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
12
16
  )
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from abc import ABC, abstractmethod
3
5
  from pathlib import Path
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from pathlib import Path
3
5
  from typing import Literal
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Literal
3
5
 
@@ -17,11 +19,11 @@ class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
17
19
 
18
20
  @override
19
21
  def create_checkpoint(self, root_config, dirpath):
20
- return LastCheckpoint(self, dirpath)
22
+ return LastCheckpointCallback(self, dirpath)
21
23
 
22
24
 
23
25
  @final
24
- class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
26
+ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
25
27
  @override
26
28
  def name(self):
27
29
  return "last"
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
2
4
  import datetime
3
5
  import logging
@@ -59,10 +61,12 @@ class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
59
61
 
60
62
  if not (filename := self.filename):
61
63
  filename = f"on_exception_{root_config.id}"
62
- yield OnExceptionCheckpoint(self, dirpath=Path(dirpath), filename=filename)
64
+ yield OnExceptionCheckpointCallback(
65
+ self, dirpath=Path(dirpath), filename=filename
66
+ )
63
67
 
64
68
 
65
- class OnExceptionCheckpoint(_OnExceptionCheckpoint):
69
+ class OnExceptionCheckpointCallback(_OnExceptionCheckpoint):
66
70
  @override
67
71
  def __init__(
68
72
  self,
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import TYPE_CHECKING, Literal, cast
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import os
3
5
  from pathlib import Path
@@ -41,7 +43,7 @@ def _create_symlink_to_nshrunner(base_dir: Path):
41
43
  symlink_path.symlink_to(session_dir)
42
44
 
43
45
 
44
- class DirectorySetupConfig(CallbackConfigBase):
46
+ class DirectorySetupCallbackConfig(CallbackConfigBase):
45
47
  name: Literal["directory_setup"] = "directory_setup"
46
48
 
47
49
  enabled: bool = True
@@ -62,7 +64,7 @@ class DirectorySetupConfig(CallbackConfigBase):
62
64
 
63
65
  class DirectorySetupCallback(Callback):
64
66
  @override
65
- def __init__(self, config: DirectorySetupConfig):
67
+ def __init__(self, config: DirectorySetupCallbackConfig):
66
68
  super().__init__()
67
69
 
68
70
  self.config = config
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import math
3
5
  from typing import Literal
@@ -14,7 +16,7 @@ from .base import CallbackConfigBase
14
16
  log = logging.getLogger(__name__)
15
17
 
16
18
 
17
- class EarlyStoppingConfig(CallbackConfigBase):
19
+ class EarlyStoppingCallbackConfig(CallbackConfigBase):
18
20
  name: Literal["early_stopping"] = "early_stopping"
19
21
 
20
22
  metric: MetricConfig | None = None
@@ -54,11 +56,11 @@ class EarlyStoppingConfig(CallbackConfigBase):
54
56
  "Either `metric` or `root_config.primary_metric` must be set to use EarlyStopping."
55
57
  )
56
58
 
57
- yield EarlyStopping(self, metric)
59
+ yield EarlyStoppingCallback(self, metric)
58
60
 
59
61
 
60
- class EarlyStopping(_EarlyStopping):
61
- def __init__(self, config: EarlyStoppingConfig, metric: MetricConfig):
62
+ class EarlyStoppingCallback(_EarlyStopping):
63
+ def __init__(self, config: EarlyStoppingCallbackConfig, metric: MetricConfig):
62
64
  self.config = config
63
65
  self.metric = metric
64
66
  del config, metric
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
2
4
  import copy
3
5
  import threading
@@ -13,7 +15,7 @@ from typing_extensions import override
13
15
  from .base import CallbackConfigBase
14
16
 
15
17
 
16
- class EMA(Callback):
18
+ class EMACallback(Callback):
17
19
  """
18
20
  Implements Exponential Moving Averaging (EMA).
19
21
 
@@ -358,7 +360,7 @@ class EMAOptimizer(torch.optim.Optimizer):
358
360
  self.rebuild_ema_params = True
359
361
 
360
362
 
361
- class EMAConfig(CallbackConfigBase):
363
+ class EMACallbackConfig(CallbackConfigBase):
362
364
  name: Literal["ema"] = "ema"
363
365
 
364
366
  decay: float
@@ -375,7 +377,7 @@ class EMAConfig(CallbackConfigBase):
375
377
 
376
378
  @override
377
379
  def create_callbacks(self, root_config):
378
- yield EMA(
380
+ yield EMACallback(
379
381
  decay=self.decay,
380
382
  validate_original_weights=self.validate_original_weights,
381
383
  every_n_steps=self.every_n_steps,
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Literal
3
5
 
@@ -58,7 +60,7 @@ class FiniteChecksCallback(Callback):
58
60
  )
59
61
 
60
62
 
61
- class FiniteChecksConfig(CallbackConfigBase):
63
+ class FiniteChecksCallbackConfig(CallbackConfigBase):
62
64
  name: Literal["finite_checks"] = "finite_checks"
63
65
 
64
66
  nonfinite_grads: bool = True
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Any, Literal, Protocol, runtime_checkable
3
5
 
@@ -18,8 +20,8 @@ class HasGradSkippedSteps(Protocol):
18
20
  grad_skipped_steps: Any
19
21
 
20
22
 
21
- class GradientSkipping(Callback):
22
- def __init__(self, config: "GradientSkippingConfig"):
23
+ class GradientSkippingCallback(Callback):
24
+ def __init__(self, config: "GradientSkippingCallbackConfig"):
23
25
  super().__init__()
24
26
  self.config = config
25
27
 
@@ -73,7 +75,7 @@ class GradientSkipping(Callback):
73
75
  )
74
76
 
75
77
 
76
- class GradientSkippingConfig(CallbackConfigBase):
78
+ class GradientSkippingCallbackConfig(CallbackConfigBase):
77
79
  name: Literal["gradient_skipping"] = "gradient_skipping"
78
80
 
79
81
  threshold: float
@@ -94,4 +96,4 @@ class GradientSkippingConfig(CallbackConfigBase):
94
96
 
95
97
  @override
96
98
  def create_callbacks(self, root_config):
97
- yield GradientSkipping(self)
99
+ yield GradientSkippingCallback(self)
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from collections.abc import Callable
2
4
  from typing import Literal
3
5
 
@@ -1,14 +1,26 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import math
3
- from typing import Any
5
+ from typing import Any, Literal
4
6
 
5
7
  from lightning.pytorch import LightningModule, Trainer
6
8
  from lightning.pytorch.callbacks import Callback
7
9
  from typing_extensions import override
8
10
 
11
+ from .base import CallbackConfigBase
12
+
9
13
  log = logging.getLogger(__name__)
10
14
 
11
15
 
16
+ class LogEpochCallbackConfig(CallbackConfigBase):
17
+ name: Literal["log_epoch"] = "log_epoch"
18
+
19
+ @override
20
+ def create_callbacks(self, root_config):
21
+ yield LogEpochCallback()
22
+
23
+
12
24
  class LogEpochCallback(Callback):
13
25
  def __init__(self, metric_name: str = "computed_epoch"):
14
26
  super().__init__()
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Literal, cast
3
5
 
@@ -96,7 +98,7 @@ def compute_norm(
96
98
 
97
99
 
98
100
  class NormLoggingCallback(Callback):
99
- def __init__(self, config: "NormLoggingConfig"):
101
+ def __init__(self, config: "NormLoggingCallbackConfig"):
100
102
  super().__init__()
101
103
 
102
104
  self.config = config
@@ -155,7 +157,7 @@ class NormLoggingCallback(Callback):
155
157
  )
156
158
 
157
159
 
158
- class NormLoggingConfig(CallbackConfigBase):
160
+ class NormLoggingCallbackConfig(CallbackConfigBase):
159
161
  name: Literal["norm_logging"] = "norm_logging"
160
162
 
161
163
  log_grad_norm: bool | str | float = False
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  import fnmatch
3
5
  import importlib.util
@@ -74,7 +76,7 @@ class PrintTableMetricsCallback(Callback):
74
76
  return table
75
77
 
76
78
 
77
- class PrintTableMetricsConfig(CallbackConfigBase):
79
+ class PrintTableMetricsCallbackConfig(CallbackConfigBase):
78
80
  """Configuration class for PrintTableMetricsCallback."""
79
81
 
80
82
  name: Literal["print_table_metrics"] = "print_table_metrics"
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from collections.abc import Mapping
3
5
  from typing import Literal, cast
@@ -16,7 +18,7 @@ from .base import CallbackConfigBase
16
18
  log = logging.getLogger(__name__)
17
19
 
18
20
 
19
- class RLPSanityChecksConfig(CallbackConfigBase):
21
+ class RLPSanityChecksCallbackConfig(CallbackConfigBase):
20
22
  """
21
23
  If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
22
24
  - If the ``interval`` is step, it makes sure that validation is called every ``frequency`` steps.
@@ -43,7 +45,7 @@ class RLPSanityChecksConfig(CallbackConfigBase):
43
45
 
44
46
  class RLPSanityChecksCallback(Callback):
45
47
  @override
46
- def __init__(self, config: RLPSanityChecksConfig):
48
+ def __init__(self, config: RLPSanityChecksCallbackConfig):
47
49
  super().__init__()
48
50
 
49
51
  self.config = config
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from collections.abc import Iterable
3
5
  from typing import Literal, Protocol, TypeAlias, runtime_checkable
@@ -17,7 +19,7 @@ def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
17
19
  return [mapping[id(p)] for p in parameters]
18
20
 
19
21
 
20
- class SharedParametersConfig(CallbackConfigBase):
22
+ class SharedParametersCallbackConfig(CallbackConfigBase):
21
23
  """A callback that allows scaling the gradients of shared parameters that
22
24
  are registered in the ``self.shared_parameters`` list of the root module.
23
25
 
@@ -43,7 +45,7 @@ class ModuleWithSharedParameters(Protocol):
43
45
 
44
46
  class SharedParametersCallback(Callback):
45
47
  @override
46
- def __init__(self, config: SharedParametersConfig):
48
+ def __init__(self, config: SharedParametersCallbackConfig):
47
49
  super().__init__()
48
50
 
49
51
  self.config = config
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Any, Literal, Protocol, TypedDict, cast, runtime_checkable
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import time
3
5
  from typing import Any, Literal
@@ -12,7 +14,7 @@ from .base import CallbackConfigBase
12
14
  log = logging.getLogger(__name__)
13
15
 
14
16
 
15
- class EpochTimer(Callback):
17
+ class EpochTimerCallback(Callback):
16
18
  def __init__(self):
17
19
  super().__init__()
18
20
 
@@ -149,9 +151,9 @@ class EpochTimer(Callback):
149
151
  self._total_batches = state_dict["total_batches"]
150
152
 
151
153
 
152
- class EpochTimerConfig(CallbackConfigBase):
154
+ class EpochTimerCallbackConfig(CallbackConfigBase):
153
155
  name: Literal["epoch_timer"] = "epoch_timer"
154
156
 
155
157
  @override
156
158
  def create_callbacks(self, root_config):
157
- yield EpochTimer()
159
+ yield EpochTimerCallback()
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import os
3
5
  from pathlib import Path
@@ -14,7 +16,7 @@ from .base import CallbackConfigBase
14
16
  log = logging.getLogger(__name__)
15
17
 
16
18
 
17
- class WandbUploadCodeConfig(CallbackConfigBase):
19
+ class WandbUploadCodeCallbackConfig(CallbackConfigBase):
18
20
  name: Literal["wandb_upload_code"] = "wandb_upload_code"
19
21
 
20
22
  enabled: bool = True
@@ -32,7 +34,7 @@ class WandbUploadCodeConfig(CallbackConfigBase):
32
34
 
33
35
 
34
36
  class WandbUploadCodeCallback(Callback):
35
- def __init__(self, config: WandbUploadCodeConfig):
37
+ def __init__(self, config: WandbUploadCodeCallbackConfig):
36
38
  super().__init__()
37
39
 
38
40
  self.config = config
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Literal, Protocol, cast, runtime_checkable
3
5
 
@@ -12,7 +14,7 @@ from .base import CallbackConfigBase
12
14
  log = logging.getLogger(__name__)
13
15
 
14
16
 
15
- class WandbWatchConfig(CallbackConfigBase):
17
+ class WandbWatchCallbackConfig(CallbackConfigBase):
16
18
  name: Literal["wandb_watch"] = "wandb_watch"
17
19
 
18
20
  enabled: bool = True
@@ -41,7 +43,7 @@ class _HasWandbLogModuleProtocol(Protocol):
41
43
 
42
44
 
43
45
  class WandbWatchCallback(Callback):
44
- def __init__(self, config: WandbWatchConfig):
46
+ def __init__(self, config: WandbWatchCallbackConfig):
45
47
  super().__init__()
46
48
 
47
49
  self.config = config