nshtrainer-1.0.0b37.tar.gz → nshtrainer-1.0.0b40.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/PKG-INFO +2 -2
  2. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/pyproject.toml +19 -1
  3. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/_directory.py +1 -1
  4. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +13 -12
  5. nshtrainer-1.0.0b40/src/nshtrainer/configs/.gitattributes +1 -0
  6. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/loggers/base.py +9 -0
  7. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/nn/mlp.py +64 -45
  8. nshtrainer-1.0.0b40/src/nshtrainer/nn/tests/test_mlp.py +55 -0
  9. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/README.md +0 -0
  10. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/.nshconfig.generated.json +0 -0
  11. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/__init__.py +0 -0
  12. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/_callback.py +0 -0
  13. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  14. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/_checkpoint/saver.py +0 -0
  15. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/_experimental/__init__.py +0 -0
  16. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/_hf_hub.py +0 -0
  17. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/__init__.py +0 -0
  18. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/actsave.py +0 -0
  19. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/base.py +0 -0
  20. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  21. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  22. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  23. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  24. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  25. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  26. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  27. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/ema.py +0 -0
  28. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  29. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  30. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/interval.py +0 -0
  31. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  32. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
  33. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  34. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/print_table.py +0 -0
  35. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  36. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  37. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/timer.py +0 -0
  38. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  39. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  40. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/__init__.py +0 -0
  41. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  42. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  43. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  44. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  45. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
  46. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  47. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  48. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  49. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  50. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  51. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  52. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  53. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  54. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  55. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  56. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  57. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  58. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  59. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  60. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  61. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  62. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  63. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  64. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  65. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  66. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  67. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  68. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  69. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  70. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
  71. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  72. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  73. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  74. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  75. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
  76. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  77. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  78. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  79. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  80. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/nn/__init__.py +0 -0
  81. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  82. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  83. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  84. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  85. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  86. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  87. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  88. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  89. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/trainer/__init__.py +0 -0
  90. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -0
  91. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
  92. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
  93. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
  94. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
  95. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
  96. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
  97. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
  98. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
  99. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  100. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/util/__init__.py +0 -0
  101. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  102. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  103. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  104. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  105. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/data/__init__.py +0 -0
  106. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  107. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/data/datamodule.py +0 -0
  108. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/data/transform.py +0 -0
  109. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/loggers/__init__.py +0 -0
  110. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/loggers/actsave.py +0 -0
  111. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/loggers/csv.py +0 -0
  112. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/loggers/tensorboard.py +0 -0
  113. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/loggers/wandb.py +0 -0
  114. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  115. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/lr_scheduler/base.py +0 -0
  116. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  117. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  118. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/metrics/__init__.py +0 -0
  119. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/metrics/_config.py +0 -0
  120. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/model/__init__.py +0 -0
  121. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/model/base.py +0 -0
  122. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/model/mixins/callback.py +0 -0
  123. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/model/mixins/debug.py +0 -0
  124. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/model/mixins/logger.py +0 -0
  125. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/nn/__init__.py +0 -0
  126. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/nn/module_dict.py +0 -0
  127. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/nn/module_list.py +0 -0
  128. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/nn/nonlinearity.py +0 -0
  129. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/optimizer.py +0 -0
  130. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/profiler/__init__.py +0 -0
  131. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/profiler/_base.py +0 -0
  132. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/profiler/advanced.py +0 -0
  133. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/profiler/pytorch.py +0 -0
  134. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/profiler/simple.py +0 -0
  135. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/__init__.py +0 -0
  136. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/_config.py +0 -0
  137. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  138. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/accelerator.py +0 -0
  139. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
  140. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/plugin/base.py +0 -0
  141. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/plugin/environment.py +0 -0
  142. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/plugin/io.py +0 -0
  143. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
  144. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/plugin/precision.py +0 -0
  145. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/signal_connector.py +0 -0
  146. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/strategy.py +0 -0
  147. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/trainer/trainer.py +0 -0
  148. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/util/_environment_info.py +0 -0
  149. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/util/bf16.py +0 -0
  150. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/util/config/__init__.py +0 -0
  151. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/util/config/dtype.py +0 -0
  152. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/util/config/duration.py +0 -0
  153. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/util/environment.py +0 -0
  154. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/util/path.py +0 -0
  155. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/util/seed.py +0 -0
  156. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/util/slurm.py +0 -0
  157. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/util/typed.py +0 -0
  158. {nshtrainer-1.0.0b37 → nshtrainer-1.0.0b40}/src/nshtrainer/util/typing_utils.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.0.0b37
+Version: 1.0.0b40
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com

pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "1.0.0-beta37"
+version = "1.0.0-beta40"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -29,6 +29,8 @@ pyright = "*"
 ruff = "*"
 ipykernel = "*"
 ipywidgets = "*"
+pytest = "^8.3.5"
+pytest-cov = "^6.0.0"
 
 [build-system]
 requires = ["poetry-core"]
@@ -58,3 +60,19 @@ extra = [
     "huggingface-hub",
     "nshutils",
 ]
+
+[tool.pytest]
+testpaths = ["tests"]
+python_files = "test_*.py"
+python_functions = "test_*"
+python_classes = "Test*"
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+# addopts = "--cov=src/nshtrainer --cov-report=term-missing --cov-report=xml --"
+addopts = [
+    "--import-mode=importlib",
+    "--cov=src/nshtrainer",
+    "--cov-report=term-missing",
+    "--cov-report=xml",
+]
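
The new dev dependencies (pytest, pytest-cov) and the `[tool.pytest.ini_options].addopts` list mean a bare `pytest` invocation now runs with importlib import mode and coverage enabled. As a minimal sketch (not part of this package), the same run can be reproduced programmatically via pytest's documented `pytest.main` entry point, with arguments that simply mirror the addopts list above:

# Minimal sketch, assuming pytest and pytest-cov are installed.
# The flags mirror the addopts list in pyproject.toml; nothing here is nshtrainer API.
import pytest

raise SystemExit(
    pytest.main(
        [
            "--import-mode=importlib",  # import test modules without requiring __init__.py
            "--cov=src/nshtrainer",  # measure coverage over the package source
            "--cov-report=term-missing",  # print uncovered lines to the terminal
            "--cov-report=xml",  # also write coverage.xml for CI consumers
        ]
    )
)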

src/nshtrainer/_directory.py
@@ -81,7 +81,7 @@ class DirectoryConfig(C.Config):
 
         # Save to nshtrainer/{id}/log/{logger name}
         log_dir = self.resolve_subdirectory(run_id, "log")
-        log_dir = log_dir / getattr(logger, "name")
+        log_dir = log_dir / logger.resolve_logger_dirname()
         # ^ NOTE: Logger must have a `name` attribute, as this is
         #   the discriminator for the logger registry
         log_dir.mkdir(exist_ok=True)

src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
@@ -21,11 +21,8 @@ log = logging.getLogger(__name__)
 class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
     name: Literal["last_checkpoint"] = "last_checkpoint"
 
-    save_on_time_interval: bool = True
-    """Whether to save checkpoints based on time interval."""
-
-    interval: timedelta = timedelta(hours=12)
-    """Time interval between checkpoints when save_on_time_interval is True."""
+    save_on_time_interval: timedelta | None = None
+    """Save a checkpoint every `save_on_time_interval` seconds. If `None`, this feature is disabled."""
 
     @override
     def create_checkpoint(self, trainer_config, dirpath):
@@ -38,8 +35,6 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
         super().__init__(config, dirpath)
         self.start_time = time.time()
         self.last_checkpoint_time = self.start_time
-        self.interval_seconds = config.interval.total_seconds()
-        self.save_on_time_interval = config.save_on_time_interval
 
     @override
     def name(self):
@@ -57,12 +52,18 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
     def topk_sort_reverse(self):
         return True
 
-    def _should_checkpoint(self) -> bool:
-        if not self.save_on_time_interval:
+    def _local_should_checkpoint(self) -> bool:
+        if (interval := self.config.save_on_time_interval) is None:
             return False
+
         current_time = time.time()
         elapsed_time = current_time - self.last_checkpoint_time
-        return elapsed_time >= self.interval_seconds
+        return elapsed_time >= interval.total_seconds()
+
+    def _should_checkpoint(self, trainer: Trainer):
+        if self.config.save_on_time_interval is None:
+            return False
+        return trainer.strategy.broadcast(self._local_should_checkpoint(), src=0)
 
     def _format_duration(self, seconds: float) -> str:
         """Format duration in seconds to a human-readable string."""
@@ -98,7 +99,7 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
         *args,
         **kwargs,
     ):
-        if not self._should_checkpoint():
+        if not self._should_checkpoint(trainer):
            return
         self.save_checkpoints(trainer)
 
@@ -110,5 +111,5 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
     def save_checkpoints(self, trainer):
         super().save_checkpoints(trainer)
 
-        if self.save_on_time_interval:
+        if self.config.save_on_time_interval is not None:
             self.last_checkpoint_time = time.time()
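
This change collapses the old `save_on_time_interval: bool` / `interval: timedelta` pair into a single optional `save_on_time_interval: timedelta | None`, and `_should_checkpoint` now takes the trainer so the rank-0 decision is broadcast via `trainer.strategy.broadcast(..., src=0)`, keeping all distributed ranks checkpointing in lockstep. A minimal sketch of the new configuration surface (the field name comes from the diff above; the interval value is illustrative):

# Minimal sketch, using only the config class shown in the diff above.
from datetime import timedelta

from nshtrainer.callbacks.checkpoint.last_checkpoint import (
    LastCheckpointCallbackConfig,
)

# b37 equivalent: save_on_time_interval=True, interval=timedelta(hours=6)
timed = LastCheckpointCallbackConfig(save_on_time_interval=timedelta(hours=6))

# b40 default: time-based checkpointing is disabled unless an interval is given.
disabled = LastCheckpointCallbackConfig()  # save_on_time_interval=None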

nshtrainer-1.0.0b40/src/nshtrainer/configs/.gitattributes (new file)
@@ -0,0 +1 @@
+* linguist-generated=true

src/nshtrainer/loggers/base.py
@@ -30,5 +30,14 @@ class LoggerConfigBase(C.Config, ABC):
     def __bool__(self):
         return self.enabled
 
+    def resolve_logger_dirname(self) -> str:
+        if not (name := getattr(self, "name", None)):
+            raise ValueError(
+                "Logger must have a name attribute to resolve the directory name.\n"
+                "Otherwise, you must override `resolve_logger_dirname`."
+            )
+
+        return name
+
 
 logger_registry = C.Registry(LoggerConfigBase, discriminator="name")
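
`resolve_logger_dirname` gives `DirectoryConfig` (see the `_directory.py` hunk above) a sanctioned hook in place of the bare `getattr(logger, "name")`: by default it returns the registry discriminator `name`, and a subclass can override it to pick a different on-disk directory. A hedged sketch of such an override; `MyLoggerConfig` is hypothetical, and a real subclass would also need to implement the base class's abstract members:

# Hypothetical subclass; only `resolve_logger_dirname` comes from the diff above.
from typing import Literal

from nshtrainer.loggers.base import LoggerConfigBase


class MyLoggerConfig(LoggerConfigBase):
    name: Literal["my_logger"] = "my_logger"

    def resolve_logger_dirname(self) -> str:
        # Logs land under nshtrainer/{id}/log/custom instead of .../my_logger.
        return "custom"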

src/nshtrainer/nn/mlp.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 import copy
 from collections.abc import Callable, Sequence
 from typing import Literal, Protocol, runtime_checkable
@@ -44,6 +45,9 @@ class MLPConfigDict(TypedDict):
     residual: bool
     """Whether to use residual connections between layers."""
 
+    seed: int | None
+    """Random seed to use for initialization. If None, the default Torch behavior is used."""
+
 
 class MLPConfig(C.Config):
     bias: bool = True
@@ -64,15 +68,20 @@ class MLPConfig(C.Config):
     residual: bool = False
     """Whether to use residual connections between layers."""
 
+    seed: int | None = None
+    """Random seed to use for initialization. If None, the default Torch behavior is used."""
+
     def to_kwargs(self) -> MLPConfigDict:
-        return {
+        kwargs: MLPConfigDict = {
             "bias": self.bias,
             "no_bias_scalar": self.no_bias_scalar,
             "nonlinearity": self.nonlinearity,
             "ln": self.ln,
             "dropout": self.dropout,
             "residual": self.residual,
+            "seed": self.seed,
         }
+        return kwargs
 
     def create_module(
         self,
@@ -108,6 +117,7 @@ def MLP(
     pre_layers: Sequence[nn.Module] = [],
     post_layers: Sequence[nn.Module] = [],
     linear_cls: LinearModuleConstructor = nn.Linear,
+    seed: int | None = None,
 ):
     """
     Constructs a multi-layer perceptron (MLP) with the given dimensions and activation function.
@@ -123,52 +133,61 @@
         residual (bool, optional): Whether to use residual connections between layers. Defaults to False.
         pre_layers (Sequence[nn.Module], optional): List of layers to insert before the linear layers. Defaults to [].
         post_layers (Sequence[nn.Module], optional): List of layers to insert after the linear layers. Defaults to [].
+        linear_cls (LinearModuleConstructor, optional): Linear module constructor to use. Defaults to nn.Linear.
+        seed (int | None, optional): Random seed to use for initialization. If None, the default Torch behavior is used. Defaults to None.
 
     Returns:
         nn.Sequential: The constructed MLP.
     """
 
-    if activation is None:
-        activation = nonlinearity
-
-    if len(dims) < 2:
-        raise ValueError("mlp requires at least 2 dimensions")
-    if ln is True:
-        ln = "pre"
-    elif isinstance(ln, str) and ln not in ("pre", "post"):
-        raise ValueError("ln must be a boolean or 'pre' or 'post'")
-
-    layers: list[nn.Module] = []
-    if ln == "pre":
-        layers.append(nn.LayerNorm(dims[0]))
-
-    layers.extend(pre_layers)
-
-    for i in range(len(dims) - 1):
-        in_features = dims[i]
-        out_features = dims[i + 1]
-        bias_ = bias and not (no_bias_scalar and out_features == 1)
-        layers.append(linear_cls(in_features, out_features, bias=bias_))
-        if dropout is not None:
-            layers.append(nn.Dropout(dropout))
-        if i < len(dims) - 2:
-            match activation:
-                case NonlinearityConfigBase():
-                    layers.append(activation.create_module())
-                case nn.Module():
-                    # In this case, we create a deep copy of the module to avoid sharing parameters (if any).
-                    layers.append(copy.deepcopy(activation))
-                case Callable():
-                    layers.append(activation())
-                case _:
-                    raise ValueError(
-                        "Either `nonlinearity` or `activation` must be provided"
-                    )
-
-    layers.extend(post_layers)
-
-    if ln == "post":
-        layers.append(nn.LayerNorm(dims[-1]))
-
-    cls = ResidualSequential if residual else nn.Sequential
-    return cls(*layers)
+    with contextlib.ExitStack() as stack:
+        if seed is not None:
+            stack.enter_context(
+                torch.random.fork_rng(devices=range(torch.cuda.device_count()))
+            )
+            torch.manual_seed(seed)
+
+        if activation is None:
+            activation = nonlinearity
+
+        if len(dims) < 2:
+            raise ValueError("mlp requires at least 2 dimensions")
+        if ln is True:
+            ln = "pre"
+        elif isinstance(ln, str) and ln not in ("pre", "post"):
+            raise ValueError("ln must be a boolean or 'pre' or 'post'")
+
+        layers: list[nn.Module] = []
+        if ln == "pre":
+            layers.append(nn.LayerNorm(dims[0]))
+
+        layers.extend(pre_layers)
+
+        for i in range(len(dims) - 1):
+            in_features = dims[i]
+            out_features = dims[i + 1]
+            bias_ = bias and not (no_bias_scalar and out_features == 1)
+            layers.append(linear_cls(in_features, out_features, bias=bias_))
+            if dropout is not None:
+                layers.append(nn.Dropout(dropout))
+            if i < len(dims) - 2:
+                match activation:
+                    case NonlinearityConfigBase():
+                        layers.append(activation.create_module())
+                    case nn.Module():
+                        # In this case, we create a deep copy of the module to avoid sharing parameters (if any).
+                        layers.append(copy.deepcopy(activation))
+                    case Callable():
+                        layers.append(activation())
+                    case _:
+                        raise ValueError(
+                            "Either `nonlinearity` or `activation` must be provided"
+                        )
+
+        layers.extend(post_layers)
+
+        if ln == "post":
+            layers.append(nn.LayerNorm(dims[-1]))
+
+        cls = ResidualSequential if residual else nn.Sequential
+        return cls(*layers)
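
The `seed` plumbing wraps construction in `torch.random.fork_rng`, so a seeded MLP is reproducible without perturbing the caller's global RNG stream (the forked state is restored when the `ExitStack` unwinds). A minimal sketch of both properties, using only the `MLP` signature shown above (the dimensions are illustrative):

# Minimal sketch; uses only the MLP signature from the diff above.
import torch

from nshtrainer.nn.mlp import MLP

a = MLP([8, 16, 4], activation=torch.nn.ReLU(), seed=0)
b = MLP([8, 16, 4], activation=torch.nn.ReLU(), seed=0)
assert torch.equal(a[0].weight, b[0].weight)  # same seed -> identical init

state = torch.random.get_rng_state()
_ = MLP([8, 16, 4], activation=torch.nn.ReLU(), seed=0)
# fork_rng restores the global CPU RNG state, so the caller's stream is untouched.
assert torch.equal(torch.random.get_rng_state(), state)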

nshtrainer-1.0.0b40/src/nshtrainer/nn/tests/test_mlp.py (new file)
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from typing import cast
+
+import pytest
+import torch
+
+from nshtrainer.nn.mlp import MLP
+
+
+def test_mlp_seed_reproducibility():
+    """Test that the seed parameter in MLP ensures reproducible weights."""
+
+    # Test dimensions
+    dims = [10, 20, 5]
+
+    # Create two MLPs with the same seed
+    seed1 = 42
+    mlp1 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
+    mlp2 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
+
+    # Create an MLP with a different seed
+    seed2 = 123
+    mlp3 = MLP(dims, activation=torch.nn.ReLU(), seed=seed2)
+
+    # Check first layer weights
+    layer1_weights1 = cast(torch.Tensor, mlp1[0].weight)
+    layer1_weights2 = cast(torch.Tensor, mlp2[0].weight)
+    layer1_weights3 = cast(torch.Tensor, mlp3[0].weight)
+
+    # Same seed should produce identical weights
+    assert torch.allclose(layer1_weights1, layer1_weights2)
+
+    # Different seeds should produce different weights
+    assert not torch.allclose(layer1_weights1, layer1_weights3)
+
+    # Check second layer weights
+    layer2_weights1 = cast(torch.Tensor, mlp1[2].weight)
+    layer2_weights2 = cast(torch.Tensor, mlp2[2].weight)
+    layer2_weights3 = cast(torch.Tensor, mlp3[2].weight)
+
+    # Same seed should produce identical weights for all layers
+    assert torch.allclose(layer2_weights1, layer2_weights2)
+
+    # Different seeds should produce different weights for all layers
+    assert not torch.allclose(layer2_weights1, layer2_weights3)
+
+    # Test that not providing a seed gives different results each time
+    mlp4 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
+    mlp5 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
+
+    # Without seeds, weights should be different
+    assert not torch.allclose(
+        cast(torch.Tensor, mlp4[0].weight), cast(torch.Tensor, mlp5[0].weight)
+    )