nshtrainer 0.44.0__tar.gz → 1.0.0b9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/PKG-INFO +2 -2
  2. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/pyproject.toml +10 -3
  3. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/__init__.py +6 -3
  4. nshtrainer-1.0.0b9/src/nshtrainer/_callback.py +337 -0
  5. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_checkpoint/loader.py +23 -30
  6. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_checkpoint/metadata.py +22 -18
  7. nshtrainer-1.0.0b9/src/nshtrainer/_experimental/__init__.py +1 -0
  8. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_hf_hub.py +25 -26
  9. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/__init__.py +1 -3
  10. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/actsave.py +22 -20
  11. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/base.py +7 -7
  12. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  13. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/checkpoint/_base.py +8 -5
  14. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  15. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  16. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  17. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/debug_flag.py +14 -19
  18. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/directory_setup.py +6 -11
  19. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/early_stopping.py +3 -3
  20. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/ema.py +1 -1
  21. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/finite_checks.py +1 -1
  22. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/gradient_skipping.py +1 -1
  23. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/log_epoch.py +1 -1
  24. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/norm_logging.py +1 -1
  25. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/print_table.py +1 -1
  26. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  27. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/shared_parameters.py +1 -1
  28. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/timer.py +1 -1
  29. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/wandb_upload_code.py +1 -1
  30. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/wandb_watch.py +1 -1
  31. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/__init__.py +189 -189
  32. nshtrainer-1.0.0b9/src/nshtrainer/config/_checkpoint/__init__.py +70 -0
  33. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/__init__.py +44 -44
  34. nshtrainer-1.0.0b9/src/nshtrainer/config/callbacks/log_epoch/__init__.py +31 -0
  35. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/loggers/__init__.py +10 -6
  36. nshtrainer-1.0.0b9/src/nshtrainer/config/loggers/actsave/__init__.py +29 -0
  37. nshtrainer-1.0.0b9/src/nshtrainer/config/trainer/__init__.py +180 -0
  38. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/trainer/_config/__init__.py +59 -36
  39. nshtrainer-1.0.0b9/src/nshtrainer/config/trainer/trainer/__init__.py +27 -0
  40. nshtrainer-1.0.0b9/src/nshtrainer/config/util/__init__.py +109 -0
  41. nshtrainer-1.0.0b9/src/nshtrainer/data/datamodule.py +56 -0
  42. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/loggers/__init__.py +2 -1
  43. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/loggers/_base.py +5 -2
  44. nshtrainer-1.0.0b9/src/nshtrainer/loggers/actsave.py +59 -0
  45. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/loggers/csv.py +5 -5
  46. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/loggers/tensorboard.py +5 -5
  47. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/loggers/wandb.py +17 -16
  48. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/lr_scheduler/_base.py +2 -1
  49. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  50. nshtrainer-1.0.0b9/src/nshtrainer/model/__init__.py +3 -0
  51. nshtrainer-1.0.0b9/src/nshtrainer/model/base.py +243 -0
  52. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/model/mixins/callback.py +24 -5
  53. nshtrainer-1.0.0b9/src/nshtrainer/model/mixins/debug.py +86 -0
  54. nshtrainer-1.0.0b9/src/nshtrainer/model/mixins/logger.py +163 -0
  55. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/profiler/_base.py +2 -2
  56. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/profiler/advanced.py +4 -4
  57. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/profiler/pytorch.py +4 -4
  58. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/profiler/simple.py +4 -4
  59. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/__init__.py +1 -0
  60. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/_config.py +164 -17
  61. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/checkpoint_connector.py +23 -8
  62. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/trainer.py +194 -76
  63. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/_environment_info.py +21 -13
  64. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/config/dtype.py +4 -4
  65. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/typing_utils.py +1 -1
  66. nshtrainer-0.44.0/src/nshtrainer/_callback.py +0 -42
  67. nshtrainer-0.44.0/src/nshtrainer/_experimental/__init__.py +0 -3
  68. nshtrainer-0.44.0/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  69. nshtrainer-0.44.0/src/nshtrainer/callbacks/throughput_monitor.py +0 -58
  70. nshtrainer-0.44.0/src/nshtrainer/config/callbacks/throughput_monitor/__init__.py +0 -33
  71. nshtrainer-0.44.0/src/nshtrainer/config/model/__init__.py +0 -41
  72. nshtrainer-0.44.0/src/nshtrainer/config/model/base/__init__.py +0 -25
  73. nshtrainer-0.44.0/src/nshtrainer/config/model/config/__init__.py +0 -37
  74. nshtrainer-0.44.0/src/nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  75. nshtrainer-0.44.0/src/nshtrainer/config/runner/__init__.py +0 -22
  76. nshtrainer-0.44.0/src/nshtrainer/data/datamodule.py +0 -7
  77. nshtrainer-0.44.0/src/nshtrainer/ll/__init__.py +0 -59
  78. nshtrainer-0.44.0/src/nshtrainer/ll/_experimental.py +0 -3
  79. nshtrainer-0.44.0/src/nshtrainer/ll/actsave.py +0 -6
  80. nshtrainer-0.44.0/src/nshtrainer/ll/callbacks.py +0 -3
  81. nshtrainer-0.44.0/src/nshtrainer/ll/config.py +0 -6
  82. nshtrainer-0.44.0/src/nshtrainer/ll/data.py +0 -3
  83. nshtrainer-0.44.0/src/nshtrainer/ll/log.py +0 -5
  84. nshtrainer-0.44.0/src/nshtrainer/ll/lr_scheduler.py +0 -3
  85. nshtrainer-0.44.0/src/nshtrainer/ll/model.py +0 -21
  86. nshtrainer-0.44.0/src/nshtrainer/ll/nn.py +0 -3
  87. nshtrainer-0.44.0/src/nshtrainer/ll/optimizer.py +0 -3
  88. nshtrainer-0.44.0/src/nshtrainer/ll/runner.py +0 -5
  89. nshtrainer-0.44.0/src/nshtrainer/ll/snapshot.py +0 -3
  90. nshtrainer-0.44.0/src/nshtrainer/ll/snoop.py +0 -3
  91. nshtrainer-0.44.0/src/nshtrainer/ll/trainer.py +0 -3
  92. nshtrainer-0.44.0/src/nshtrainer/ll/typecheck.py +0 -3
  93. nshtrainer-0.44.0/src/nshtrainer/ll/util.py +0 -3
  94. nshtrainer-0.44.0/src/nshtrainer/model/__init__.py +0 -7
  95. nshtrainer-0.44.0/src/nshtrainer/model/base.py +0 -526
  96. nshtrainer-0.44.0/src/nshtrainer/model/config.py +0 -218
  97. nshtrainer-0.44.0/src/nshtrainer/model/mixins/logger.py +0 -166
  98. nshtrainer-0.44.0/src/nshtrainer/runner.py +0 -101
  99. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/README.md +0 -0
  100. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_checkpoint/saver.py +0 -0
  101. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_directory.py +0 -0
  102. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/interval.py +0 -0
  103. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  104. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/_checkpoint/metadata/__init__.py +0 -0
  105. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/_directory/__init__.py +2 -2
  106. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/_hf_hub/__init__.py +2 -2
  107. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/actsave/__init__.py +0 -0
  108. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/base/__init__.py +0 -0
  109. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  110. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  111. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  112. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  113. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  114. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  115. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  116. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  117. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/ema/__init__.py +2 -2
  118. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  119. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  120. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  121. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  122. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  123. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  124. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/timer/__init__.py +4 -4
  125. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  126. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  127. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/loggers/_base/__init__.py +0 -0
  128. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/loggers/csv/__init__.py +2 -2
  129. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/loggers/tensorboard/__init__.py +0 -0
  130. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/loggers/wandb/__init__.py +6 -6
  131. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/lr_scheduler/__init__.py +0 -0
  132. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/lr_scheduler/_base/__init__.py +0 -0
  133. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  134. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  135. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/metrics/__init__.py +0 -0
  136. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/metrics/_config/__init__.py +0 -0
  137. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/nn/__init__.py +18 -18
  138. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/nn/mlp/__init__.py +0 -0
  139. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  140. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/optimizer/__init__.py +2 -2
  141. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/profiler/__init__.py +2 -2
  142. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/profiler/_base/__init__.py +0 -0
  143. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/profiler/advanced/__init__.py +0 -0
  144. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  145. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/profiler/simple/__init__.py +4 -4
  146. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/trainer/checkpoint_connector/__init__.py +0 -0
  147. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/util/_environment_info/__init__.py +20 -20
  148. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/util/config/__init__.py +2 -2
  149. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/util/config/dtype/__init__.py +0 -0
  150. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/util/config/duration/__init__.py +0 -0
  151. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/data/__init__.py +0 -0
  152. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  153. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/data/transform.py +0 -0
  154. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  155. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  156. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/metrics/__init__.py +0 -0
  157. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/metrics/_config.py +0 -0
  158. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/nn/__init__.py +0 -0
  159. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/nn/mlp.py +0 -0
  160. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/nn/module_dict.py +0 -0
  161. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/nn/module_list.py +0 -0
  162. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/nn/nonlinearity.py +0 -0
  163. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/optimizer.py +0 -0
  164. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/profiler/__init__.py +0 -0
  165. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/scripts/find_packages.py +0 -0
  166. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  167. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/signal_connector.py +0 -0
  168. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/_useful_types.py +0 -0
  169. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/bf16.py +0 -0
  170. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/config/__init__.py +0 -0
  171. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/config/duration.py +0 -0
  172. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/environment.py +0 -0
  173. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/path.py +0 -0
  174. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/seed.py +0 -0
  175. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/slurm.py +0 -0
  176. {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/typed.py +0 -0
--- nshtrainer-0.44.0/PKG-INFO
+++ nshtrainer-1.0.0b9/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.44.0
+Version: 1.0.0b9
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -15,7 +15,7 @@ Requires-Dist: huggingface-hub ; extra == "extra"
 Requires-Dist: lightning
 Requires-Dist: nshconfig
 Requires-Dist: nshrunner
-Requires-Dist: nshutils
+Requires-Dist: nshutils ; extra == "extra"
 Requires-Dist: numpy
 Requires-Dist: packaging
 Requires-Dist: psutil
--- nshtrainer-0.44.0/pyproject.toml
+++ nshtrainer-1.0.0b9/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.44.0"
+version = "1.0.0-beta9"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -9,7 +9,7 @@ readme = "README.md"
 python = "^3.10"
 nshrunner = "*"
 nshconfig = "*"
-nshutils = "*"
+nshutils = { version = "*", optional = true }
 psutil = "*"
 numpy = "*"
 torch = "*"
@@ -50,4 +50,11 @@ ignore = ["F722", "F821", "E731", "E741"]
 required-imports = ["from __future__ import annotations"]
 
 [tool.poetry.extras]
-extra = ["wrapt", "GitPython", "wandb", "tensorboard", "huggingface-hub"]
+extra = [
+    "wrapt",
+    "GitPython",
+    "wandb",
+    "tensorboard",
+    "huggingface-hub",
+    "nshutils",
+]
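With nshutils demoted from a hard dependency to the "extra" group, code that wants its functionality can no longer assume it is installed. A minimal sketch of how a consumer might guard for this (the guard itself is an assumption about usage, not code from the package):

    # Sketch: tolerate a missing nshutils now that it is only installed via
    # the "extra" group (e.g. nshtrainer[extra]).
    try:
        import nshutils

        HAS_NSHUTILS = True
    except ImportError:
        nshutils = None
        HAS_NSHUTILS = False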
--- nshtrainer-0.44.0/src/nshtrainer/__init__.py
+++ nshtrainer-1.0.0b9/src/nshtrainer/__init__.py
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from . import _experimental as _experimental
 from . import callbacks as callbacks
-from . import config as config
 from . import data as data
 from . import lr_scheduler as lr_scheduler
 from . import metrics as metrics
@@ -12,7 +11,11 @@ from . import optimizer as optimizer
 from . import profiler as profiler
 from .data import LightningDataModuleBase as LightningDataModuleBase
 from .metrics import MetricConfig as MetricConfig
-from .model import BaseConfig as BaseConfig
 from .model import LightningModuleBase as LightningModuleBase
-from .runner import Runner as Runner
 from .trainer import Trainer as Trainer
+from .trainer import TrainerConfig as TrainerConfig
+
+try:
+    from . import config as config
+except BaseException:
+    pass
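For downstream code, the visible change is that BaseConfig and Runner are no longer re-exported at the top level, TrainerConfig now is, and the generated config module is imported best-effort. A short import sketch against the 1.0.0b9 surface shown above:

    # These names are re-exported by src/nshtrainer/__init__.py in 1.0.0b9.
    from nshtrainer import LightningModuleBase, Trainer, TrainerConfig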
--- /dev/null
+++ nshtrainer-1.0.0b9/src/nshtrainer/_callback.py
@@ -0,0 +1,337 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+import torch
+from lightning.pytorch.callbacks import Callback as _LightningCallback
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch.optim import Optimizer
+
+if TYPE_CHECKING:
+    from .model import LightningModuleBase
+    from .trainer import Trainer
+
+
+class NTCallbackBase(_LightningCallback):
+    def setup(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase, stage: str
+    ) -> None:
+        """Called when fit, validate, test, predict, or tune begins."""
+
+    def teardown(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase, stage: str
+    ) -> None:
+        """Called when fit, validate, test, predict, or tune ends."""
+
+    def on_fit_start(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when fit begins."""
+
+    def on_fit_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when fit ends."""
+
+    def on_sanity_check_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the validation sanity check starts."""
+
+    def on_sanity_check_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the validation sanity check ends."""
+
+    def on_train_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        batch: Any,
+        batch_idx: int,
+    ) -> None:
+        """Called when the train batch begins."""
+
+    def on_train_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+    ) -> None:
+        """Called when the train batch ends.
+
+        Note:
+            The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the
+            loss returned from ``training_step``.
+
+        """
+
+    def on_train_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the train epoch begins."""
+
+    def on_train_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the train epoch ends.
+
+        To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
+        :class:`lightning.pytorch.core.LightningModule` and access them in this hook:
+
+        .. code-block:: python
+
+            class MyLightningModule(L.LightningModule):
+                def __init__(self):
+                    super().__init__()
+                    self.training_step_outputs = []
+
+                def training_step(self):
+                    loss = ...
+                    self.training_step_outputs.append(loss)
+                    return loss
+
+
+            class MyCallback(L.Callback):
+                def on_train_epoch_end(self, trainer, pl_module):
+                    # do something with all training_step outputs, for example:
+                    epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
+                    pl_module.log("training_epoch_mean", epoch_mean)
+                    # free up the memory
+                    pl_module.training_step_outputs.clear()
+
+        """
+
+    def on_validation_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the val epoch begins."""
+
+    def on_validation_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the val epoch ends."""
+
+    def on_test_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the test epoch begins."""
+
+    def on_test_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the test epoch ends."""
+
+    def on_predict_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the predict epoch begins."""
+
+    def on_predict_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the predict epoch ends."""
+
+    def on_validation_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the validation batch begins."""
+
+    def on_validation_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the validation batch ends."""
+
+    def on_test_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the test batch begins."""
+
+    def on_test_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the test batch ends."""
+
+    def on_predict_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the predict batch begins."""
+
+    def on_predict_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the predict batch ends."""
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when the train begins."""
+
+    def on_train_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when the train ends."""
+
+    def on_validation_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the validation loop begins."""
+
+    def on_validation_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the validation loop ends."""
+
+    def on_test_start(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when the test begins."""
+
+    def on_test_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when the test ends."""
+
+    def on_predict_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the predict begins."""
+
+    def on_predict_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when predict ends."""
+
+    def on_exception(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        exception: BaseException,
+    ) -> None:
+        """Called when any trainer execution is interrupted by an exception."""
+
+    def state_dict(self) -> dict[str, Any]:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when saving a checkpoint, implement to generate callback's ``state_dict``.
+
+        Returns:
+            A dictionary containing callback state.
+
+        """
+        return {}
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``.
+
+        Args:
+            state_dict: the callback state returned by ``state_dict``.
+
+        """
+        pass
+
+    def on_save_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        checkpoint: dict[str, Any],
+    ) -> None:
+        r"""Called when saving a checkpoint to give you a chance to store anything else you might want to save.
+
+        Args:
+            trainer: the current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
+            pl_module: the current :class:`~lightning.pytorch.core.LightningModule` instance.
+            checkpoint: the checkpoint dictionary that will be saved.
+
+        """
+
+    def on_load_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        checkpoint: dict[str, Any],
+    ) -> None:
+        r"""Called when loading a model checkpoint, use to reload state.
+
+        Args:
+            trainer: the current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
+            pl_module: the current :class:`~lightning.pytorch.core.LightningModule` instance.
+            checkpoint: the full checkpoint dictionary that got loaded by the Trainer.
+
+        """
+
+    def on_before_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase, loss: torch.Tensor
+    ) -> None:
+        """Called before ``loss.backward()``."""
+
+    def on_after_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called after ``loss.backward()`` and before optimizers are stepped."""
+
+    def on_before_optimizer_step(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        optimizer: Optimizer,
+    ) -> None:
+        """Called before ``optimizer.step()``."""
+
+    def on_before_zero_grad(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        optimizer: Optimizer,
+    ) -> None:
+        """Called before ``optimizer.zero_grad()``."""
+
+    def on_checkpoint_saved(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        ckpt_path: Path,
+        metadata_path: Path | None,
+        trainer: "Trainer",
+        pl_module: "LightningModuleBase",
+    ) -> None:
+        """Called after a checkpoint is saved."""
+        pass
+
+
+def _call_on_checkpoint_saved(
+    trainer: "Trainer",
+    ckpt_path: str | Path,
+    metadata_path: str | Path | None,
+):
+    ckpt_path = Path(ckpt_path)
+    metadata_path = Path(metadata_path) if metadata_path else None
+
+    for callback in trainer.callbacks:
+        if not isinstance(callback, NTCallbackBase):
+            continue
+
+        callback.on_checkpoint_saved(
+            ckpt_path,
+            metadata_path,
+            trainer,
+            trainer._base_module,
+        )
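Relative to lightning.pytorch.callbacks.Callback, the notable addition here is the on_checkpoint_saved hook: _call_on_checkpoint_saved normalizes the paths and then invokes it on every NTCallbackBase attached to the trainer once a checkpoint (and its optional metadata file) has been written. A minimal subclass sketch, importing from the private module added above (the logging behaviour is illustrative, not part of the package):

    import logging
    from pathlib import Path

    from nshtrainer._callback import NTCallbackBase

    log = logging.getLogger(__name__)


    class CheckpointAuditCallback(NTCallbackBase):
        """Illustrative callback that records every checkpoint the trainer writes."""

        def on_checkpoint_saved(self, ckpt_path: Path, metadata_path: Path | None, trainer, pl_module) -> None:
            # Invoked by _call_on_checkpoint_saved after the checkpoint hits disk.
            log.info("Checkpoint saved at %s (metadata: %s)", ckpt_path, metadata_path)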
--- nshtrainer-0.44.0/src/nshtrainer/_checkpoint/loader.py
+++ nshtrainer-1.0.0b9/src/nshtrainer/_checkpoint/loader.py
@@ -7,7 +7,6 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, overload
 
 import nshconfig as C
-from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.trainer.states import TrainerFn
 from typing_extensions import assert_never
 
@@ -15,7 +14,8 @@ from ..metrics._config import MetricConfig
 from .metadata import CheckpointMetadata
 
 if TYPE_CHECKING:
-    from ..model.config import BaseConfig
+    from ..trainer import Trainer
+    from ..trainer._config import TrainerConfig
 
 log = logging.getLogger(__name__)
 
@@ -228,22 +228,22 @@ class _CkptCandidate:
 @overload
 def _load_ckpt_meta(
     path: Path,
-    root_config: "BaseConfig",
+    trainer_config: TrainerConfig,
     on_error: Literal["warn"] = "warn",
 ) -> _CkptCandidate | None: ...
 @overload
 def _load_ckpt_meta(
     path: Path,
-    root_config: "BaseConfig",
+    trainer_config: TrainerConfig,
     on_error: Literal["raise"],
 ) -> _CkptCandidate: ...
 def _load_ckpt_meta(
     path: Path,
-    root_config: "BaseConfig",
+    trainer_config: TrainerConfig,
     on_error: Literal["warn", "raise"] = "warn",
 ):
     meta = CheckpointMetadata.from_file(path)
-    if root_config.id != meta.run_id:
+    if trainer_config.id != meta.run_id:
         error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
         match on_error:
             case "warn":
@@ -256,16 +256,13 @@ def _load_ckpt_meta(
     return _CkptCandidate(meta, path)
 
 
-def _checkpoint_candidates(
-    root_config: "BaseConfig",
-    trainer: LightningTrainer,
-    *,
-    include_hpc: bool = True,
-):
+def _checkpoint_candidates(trainer: Trainer, *, include_hpc: bool = True):
     # Load the checkpoint directory, and throw if it doesn't exist.
     # This indicates a non-standard setup, and we don't want to guess
     # where the checkpoints are.
-    ckpt_dir = root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+    ckpt_dir = trainer.hparams.directory.resolve_subdirectory(
+        trainer.hparams.id, "checkpoint"
+    )
     if not ckpt_dir.is_dir():
         raise FileNotFoundError(
             f"Checkpoint directory {ckpt_dir} not found. "
@@ -275,46 +272,40 @@ def _checkpoint_candidates(
     # Load all checkpoints in the directory.
     # We can do this by looking for metadata files.
     for path in ckpt_dir.glob(f"*{CheckpointMetadata.PATH_SUFFIX}"):
-        if (meta := _load_ckpt_meta(path, root_config)) is not None:
+        if (meta := _load_ckpt_meta(path, trainer.hparams)) is not None:
             yield meta
 
     # If we have a pre-empted checkpoint, load it
     if include_hpc and (hpc_path := trainer._checkpoint_connector._hpc_resume_path):
         hpc_meta_path = Path(hpc_path).with_suffix(CheckpointMetadata.PATH_SUFFIX)
-        if (meta := _load_ckpt_meta(hpc_meta_path, root_config)) is not None:
+        if (meta := _load_ckpt_meta(hpc_meta_path, trainer.hparams)) is not None:
            yield meta
 
 
 def _additional_candidates(
-    additional_candidates: Iterable[Path], root_config: "BaseConfig"
+    additional_candidates: Iterable[Path], trainer_config: TrainerConfig
 ):
     for path in additional_candidates:
         if (
             meta := _load_ckpt_meta(
-                path.with_suffix(CheckpointMetadata.PATH_SUFFIX), root_config
+                path.with_suffix(CheckpointMetadata.PATH_SUFFIX), trainer_config
             )
         ) is None:
             continue
         yield meta
 
 
-def _resolve_checkpoint(
-    config: CheckpointLoadingConfig,
-    root_config: "BaseConfig",
-    trainer: LightningTrainer,
-):
+def _resolve_checkpoint(config: CheckpointLoadingConfig, trainer: Trainer):
     # We lazily load the checkpoint candidates to avoid loading them
     # if they are not needed.
     _ckpt_candidates: list[_CkptCandidate] | None = None
 
     def ckpt_candidates():
-        nonlocal _ckpt_candidates, root_config, trainer
+        nonlocal _ckpt_candidates, trainer
 
         if _ckpt_candidates is None:
             _ckpt_candidates = list(
-                _checkpoint_candidates(
-                    root_config, trainer, include_hpc=config.include_hpc
-                )
+                _checkpoint_candidates(trainer, include_hpc=config.include_hpc)
             )
         return _ckpt_candidates
 
@@ -324,7 +315,7 @@ def _resolve_checkpoint(
            case UserProvidedPathCheckpointStrategyConfig():
                meta = _load_ckpt_meta(
                    strategy.path.with_suffix(CheckpointMetadata.PATH_SUFFIX),
-                    root_config,
+                    trainer.hparams,
                    on_error=strategy.on_error,
                )
                if meta is None:
@@ -334,7 +325,7 @@ def _resolve_checkpoint(
                candidates = [
                    *ckpt_candidates(),
                    *_additional_candidates(
-                        strategy.additional_candidates, root_config
+                        strategy.additional_candidates, trainer.hparams
                    ),
                ]
                if not candidates:
@@ -343,7 +334,9 @@ def _resolve_checkpoint(
                    )
                    continue
 
-                if (metric := strategy.metric or root_config.primary_metric) is None:
+                if (
+                    metric := strategy.metric or trainer.hparams.primary_metric
+                ) is None:
                    log.warning(
                        "No metric specified for `best` checkpoint strategy, "
                        "and no primary metric is set in the configuration. "
@@ -369,7 +362,7 @@ def _resolve_checkpoint(
                candidates = [
                    *ckpt_candidates(),
                    *_additional_candidates(
-                        strategy.additional_candidates, root_config
+                        strategy.additional_candidates, trainer.hparams
                    ),
                ]
                if not candidates:
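The practical effect of this refactor is that checkpoint resolution no longer takes a separate root BaseConfig: the run id, checkpoint directory, and primary metric are all read from trainer.hparams (the TrainerConfig). A hedged call sketch, assuming CheckpointLoadingConfig is importable from the same module as its use in the signatures above suggests:

    from nshtrainer._checkpoint.loader import CheckpointLoadingConfig, _resolve_checkpoint
    from nshtrainer.trainer import Trainer


    def pick_resume_checkpoint(loading_config: CheckpointLoadingConfig, trainer: Trainer):
        # 0.44.0 also required the root BaseConfig here; in 1.0.0b9 the trainer
        # config is read from trainer.hparams instead.
        return _resolve_checkpoint(loading_config, trainer)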
--- nshtrainer-0.44.0/src/nshtrainer/_checkpoint/metadata.py
+++ nshtrainer-1.0.0b9/src/nshtrainer/_checkpoint/metadata.py
@@ -5,7 +5,7 @@ import datetime
 import logging
 from collections.abc import Callable
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ClassVar, cast
+from typing import TYPE_CHECKING, Any, ClassVar
 
 import nshconfig as C
 import numpy as np
@@ -15,7 +15,6 @@ from ..util._environment_info import EnvironmentConfig
 from ..util.path import compute_file_checksum, try_symlink_or_copy
 
 if TYPE_CHECKING:
-    from ..model import BaseConfig, LightningModuleBase
     from ..trainer.trainer import Trainer
 
 log = logging.getLogger(__name__)
@@ -24,6 +23,19 @@ log = logging.getLogger(__name__)
 METADATA_PATH_SUFFIX = ".metadata.json"
 
 
+def _full_hparams_dict(trainer: Trainer):
+    hparams = {}
+    hparams["trainer"] = trainer.hparams.model_dump(mode="json")
+
+    if trainer.lightning_module is not None:
+        from ..model import LightningModuleBase
+
+        if isinstance(trainer.lightning_module, LightningModuleBase):
+            hparams["model"] = trainer.lightning_module.hparams.model_dump(mode="json")
+
+    return hparams
+
+
 class CheckpointMetadata(C.Config):
     PATH_SUFFIX: ClassVar[str] = METADATA_PATH_SUFFIX
 
@@ -59,8 +71,7 @@ class CheckpointMetadata(C.Config):
 
 
 def _generate_checkpoint_metadata(
-    config: "BaseConfig",
-    trainer: "Trainer",
+    trainer: Trainer,
     checkpoint_path: Path,
     metadata_path: Path,
 ):
@@ -84,9 +95,9 @@ def _generate_checkpoint_metadata(
         checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
         checkpoint_checksum=compute_file_checksum(checkpoint_path),
-        run_id=config.id,
-        name=config.run_name,
-        project=config.project,
+        run_id=trainer.hparams.id,
+        name=trainer.hparams.full_name,
+        project=trainer.hparams.project,
         checkpoint_timestamp=checkpoint_timestamp,
         start_timestamp=start_timestamp.datetime
         if start_timestamp is not None
@@ -95,8 +106,8 @@ def _generate_checkpoint_metadata(
         global_step=trainer.global_step,
         training_time=training_time,
         metrics=metrics,
-        environment=config.environment,
-        hparams=config.model_dump(),
+        environment=trainer.hparams.environment,
+        hparams=_full_hparams_dict(trainer),
     )
 
 
@@ -104,16 +115,9 @@ def _metadata_path(checkpoint_path: Path):
     return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
 
 
-def _write_checkpoint_metadata(
-    trainer: "Trainer",
-    model: "LightningModuleBase",
-    checkpoint_path: Path,
-):
-    config = cast("BaseConfig", model.config)
+def _write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
     metadata_path = _metadata_path(checkpoint_path)
-    metadata = _generate_checkpoint_metadata(
-        config, trainer, checkpoint_path, metadata_path
-    )
+    metadata = _generate_checkpoint_metadata(trainer, checkpoint_path, metadata_path)
 
     # Write the metadata to the checkpoint directory
     try:
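With _full_hparams_dict, the hparams field of the checkpoint metadata now nests the trainer hyperparameters and, when the attached module is an nshtrainer LightningModuleBase, the model hyperparameters, instead of dumping a single BaseConfig. Roughly this shape (a sketch of the structure, not actual output):

    # Approximate shape of CheckpointMetadata.hparams after this change.
    hparams = {
        # trainer.hparams.model_dump(mode="json"), i.e. the serialized TrainerConfig
        "trainer": {},
        # module hparams; present only when the module is an nshtrainer LightningModuleBase
        "model": {},
    }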
--- /dev/null
+++ nshtrainer-1.0.0b9/src/nshtrainer/_experimental/__init__.py
@@ -0,0 +1 @@
+from __future__ import annotations