nshtrainer 1.1.1b1.tar.gz → 1.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/PKG-INFO +1 -1
  2. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/pyproject.toml +4 -3
  3. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_directory.py +3 -3
  4. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/__init__.py +6 -0
  5. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/base.py +22 -3
  6. nshtrainer-1.2.0/src/nshtrainer/callbacks/distributed_prediction_writer.py +166 -0
  7. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/__init__.py +28 -0
  8. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/__init__.py +6 -0
  9. nshtrainer-1.2.0/src/nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py +19 -0
  10. nshtrainer-1.2.0/src/nshtrainer/configs/optimizer/__init__.py +39 -0
  11. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/__init__.py +4 -0
  12. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/_config/__init__.py +4 -0
  13. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/model/base.py +60 -2
  14. nshtrainer-1.2.0/src/nshtrainer/optimizer.py +626 -0
  15. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/_config.py +10 -4
  16. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/trainer.py +21 -2
  17. nshtrainer-1.1.1b1/src/nshtrainer/configs/optimizer/__init__.py +0 -15
  18. nshtrainer-1.1.1b1/src/nshtrainer/optimizer.py +0 -68
  19. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/README.md +0 -0
  20. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/.nshconfig.generated.json +0 -0
  21. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/__init__.py +0 -0
  22. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_callback.py +0 -0
  23. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  24. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  25. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  26. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_hf_hub.py +0 -0
  27. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  28. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  29. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  30. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  31. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  32. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  33. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  34. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  35. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  36. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/ema.py +0 -0
  37. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  38. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  39. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/interval.py +0 -0
  40. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  41. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
  42. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/metric_validation.py +0 -0
  43. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  44. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  45. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  46. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  47. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/timer.py +0 -0
  48. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  49. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  50. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/.gitattributes +0 -0
  51. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  52. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  53. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  54. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  55. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  56. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  57. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  58. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  59. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  60. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  61. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  62. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  63. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  64. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  65. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  66. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  67. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  68. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  69. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  70. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
  71. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  72. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  73. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  74. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  75. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  76. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  77. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  78. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  79. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  80. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
  81. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  82. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  83. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  84. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  85. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
  86. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  87. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  88. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  89. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  90. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/nn/__init__.py +0 -0
  91. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  92. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  93. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/nn/rng/__init__.py +0 -0
  94. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  95. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  96. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  97. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  98. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  99. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
  100. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
  101. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
  102. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
  103. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
  104. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
  105. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
  106. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
  107. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  108. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/util/__init__.py +0 -0
  109. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  110. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  111. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  112. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  113. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/data/__init__.py +0 -0
  114. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  115. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/data/datamodule.py +0 -0
  116. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/data/transform.py +0 -0
  117. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/__init__.py +0 -0
  118. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/actsave.py +0 -0
  119. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/base.py +0 -0
  120. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/csv.py +0 -0
  121. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  122. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/wandb.py +0 -0
  123. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  124. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/lr_scheduler/base.py +0 -0
  125. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  126. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  127. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/metrics/__init__.py +0 -0
  128. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/metrics/_config.py +0 -0
  129. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/model/__init__.py +0 -0
  130. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/model/mixins/callback.py +0 -0
  131. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/model/mixins/debug.py +0 -0
  132. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/model/mixins/logger.py +0 -0
  133. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/__init__.py +0 -0
  134. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/mlp.py +0 -0
  135. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/module_dict.py +0 -0
  136. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/module_list.py +0 -0
  137. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  138. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/rng.py +0 -0
  139. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/profiler/__init__.py +0 -0
  140. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/profiler/_base.py +0 -0
  141. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/profiler/advanced.py +0 -0
  142. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/profiler/pytorch.py +0 -0
  143. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/profiler/simple.py +0 -0
  144. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/__init__.py +0 -0
  145. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/_log_hparams.py +0 -0
  146. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  147. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/accelerator.py +0 -0
  148. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
  149. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/base.py +0 -0
  150. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/environment.py +0 -0
  151. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/io.py +0 -0
  152. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
  153. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/precision.py +0 -0
  154. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  155. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/strategy.py +0 -0
  156. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/_environment_info.py +0 -0
  157. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/bf16.py +0 -0
  158. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/config/__init__.py +0 -0
  159. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/config/dtype.py +0 -0
  160. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/config/duration.py +0 -0
  161. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/environment.py +0 -0
  162. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/path.py +0 -0
  163. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/seed.py +0 -0
  164. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/slurm.py +0 -0
  165. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/typed.py +0 -0
  166. {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/typing_utils.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.1.1b1
+Version: 1.2.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com

pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nshtrainer"
-version = "1.1.1-beta.1"
+version = "1.2.0"
 description = ""
 authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
 requires-python = ">=3.10,<4.0"
@@ -33,8 +33,9 @@ basedpyright = "*"
 ruff = "*"
 ipykernel = "*"
 ipywidgets = "*"
-pytest = "^8.3.5"
-pytest-cov = "^6.0.0"
+pytest = "*"
+pytest-cov = "*"
+pytest-forked = "*"
 
 [build-system]
 requires = ["poetry-core"]

src/nshtrainer/_directory.py
@@ -65,9 +65,9 @@ class DirectoryConfig(C.Config):
     ) -> Path:
         # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
         if (subdir := getattr(self, subdirectory, None)) is not None:
-            assert isinstance(
-                subdir, Path
-            ), f"Expected a Path for {subdirectory}, got {type(subdir)}"
+            assert isinstance(subdir, Path), (
+                f"Expected a Path for {subdirectory}, got {type(subdir)}"
+            )
             return subdir
 
         dir = self.resolve_run_root_directory(run_id)

src/nshtrainer/callbacks/__init__.py
@@ -23,6 +23,12 @@ from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
 from .directory_setup import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from .distributed_prediction_writer import (
+    DistributedPredictionWriter as DistributedPredictionWriter,
+)
+from .distributed_prediction_writer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from .early_stopping import EarlyStoppingCallback as EarlyStoppingCallback
 from .early_stopping import EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig
 from .ema import EMACallback as EMACallback

src/nshtrainer/callbacks/base.py
@@ -23,6 +23,10 @@ class CallbackMetadataConfig(TypedDict, total=False):
     """Priority of the callback. Callbacks with higher priority will be loaded first.
     Default is `0`."""
 
+    enabled_for_barebones: bool
+    """Whether this callback is enabled for barebones mode.
+    Default is `False`."""
+
 
 @dataclass(frozen=True)
 class CallbackWithMetadata:
@@ -91,10 +95,20 @@ def _filter_ignore_if_exists(callbacks: list[CallbackWithMetadata]):
 
 
 def _process_and_filter_callbacks(
+    trainer_config: TrainerConfig,
     callbacks: Iterable[CallbackWithMetadata],
 ) -> list[Callback]:
     callbacks = list(callbacks)
 
+    # If we're in barebones mode, used the callback metadata
+    # to decide to keep/remove the callback.
+    if trainer_config.barebones:
+        callbacks = [
+            callback
+            for callback in callbacks
+            if callback.metadata.get("enabled_for_barebones", False)
+        ]
+
     # Sort by priority (higher priority first)
     callbacks.sort(
         key=lambda callback: callback.metadata.get("priority", 0),
@@ -114,9 +128,14 @@ def resolve_all_callbacks(trainer_config: TrainerConfig):
         if config is not None
     ]
     callbacks = _process_and_filter_callbacks(
-        callback
-        for callback_config in callback_configs
-        for callback in _create_callbacks_with_metadata(callback_config, trainer_config)
+        trainer_config,
+        (
+            callback
+            for callback_config in callback_configs
+            for callback in _create_callbacks_with_metadata(
+                callback_config, trainer_config
+            )
+        ),
     )
     return callbacks
 
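
With this change, a run in barebones mode drops every configured callback unless its metadata opts in via `enabled_for_barebones`. A config opts in through the `metadata` ClassVar on its config class, exactly as the new prediction-writer config in the next file does. A minimal sketch of the pattern (the `MyCallbackConfig` name and the `"my_callback"` tag are illustrative, not part of the package):

from typing import ClassVar, Literal

from lightning.pytorch.callbacks import Callback

from nshtrainer.callbacks.base import (
    CallbackConfigBase,
    CallbackMetadataConfig,
    callback_registry,
)


@callback_registry.register
class MyCallbackConfig(CallbackConfigBase):
    # Opt in so this callback survives the barebones filter above.
    metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig(
        enabled_for_barebones=True
    )

    name: Literal["my_callback"] = "my_callback"

    def create_callbacks(self, trainer_config):
        yield Callback()  # replace with a real lightning Callback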

src/nshtrainer/callbacks/distributed_prediction_writer.py (new file)
@@ -0,0 +1,166 @@
+from __future__ import annotations
+
+import functools
+import logging
+from collections.abc import Iterator, Sequence
+from pathlib import Path
+from typing import Any, ClassVar, Literal, overload
+
+import torch
+from lightning.fabric.utilities.apply_func import move_data_to_device
+from lightning.pytorch.callbacks import BasePredictionWriter
+from typing_extensions import final, override
+
+from .base import CallbackConfigBase, CallbackMetadataConfig, callback_registry
+
+log = logging.getLogger(__name__)
+
+
+@final
+@callback_registry.register
+class DistributedPredictionWriterConfig(CallbackConfigBase):
+    metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig(
+        enabled_for_barebones=True
+    )
+    """Metadata for the callback."""
+
+    name: Literal["distributed_prediction_writer"] = "distributed_prediction_writer"
+
+    dirpath: Path | None = None
+    """Directory to save the predictions to. If None, will use the default directory."""
+
+    move_to_cpu_on_save: bool = True
+    """Whether to move the predictions to CPU before saving. Default is True."""
+
+    save_raw: bool = True
+    """Whether to save the raw predictions."""
+
+    save_processed: bool = True
+    """Whether to process and save the predictions.
+
+    "Processing" means that the model's batched predictions are split into individual predictions
+    and saved as a list of tensors.
+    """
+
+    @override
+    def create_callbacks(self, trainer_config):
+        if (dirpath := self.dirpath) is None:
+            dirpath = trainer_config.directory.resolve_subdirectory(
+                trainer_config.id, "predictions"
+            )
+
+        yield DistributedPredictionWriter(self, dirpath)
+
+
+def _move_and_save(data, path: Path, move_to_cpu: bool):
+    if move_to_cpu:
+        data = move_data_to_device(data, "cpu")
+
+    # Save the data to the specified path
+    torch.save(data, path)
+
+
+class DistributedPredictionWriter(BasePredictionWriter):
+    def __init__(
+        self,
+        config: DistributedPredictionWriterConfig,
+        output_dir: Path,
+    ):
+        self.config = config
+
+        super().__init__(write_interval="batch")
+
+        self.output_dir = output_dir
+
+    @override
+    def write_on_batch_end(
+        self,
+        trainer,
+        pl_module,
+        prediction,
+        batch_indices,
+        batch,
+        batch_idx,
+        dataloader_idx,
+    ):
+        save = functools.partial(
+            _move_and_save,
+            move_to_cpu=self.config.move_to_cpu_on_save,
+        )
+
+        # Regular, unstructured writing.
+        if self.config.save_raw:
+            output_dir = (
+                self.output_dir
+                / "raw"
+                / f"dataloader_{dataloader_idx}"
+                / f"rank_{trainer.global_rank}"
+                / f"batch_{batch_idx}"
+            )
+            output_dir.mkdir(parents=True, exist_ok=True)
+            save(prediction, output_dir / "predictions.pt")
+            save(batch, output_dir / "batch.pt")
+            save(batch_indices, output_dir / "batch_indices.pt")
+
+        if self.config.save_processed:
+            # Processed writing.
+            from ..model.base import LightningModuleBase
+
+            if not isinstance(pl_module, LightningModuleBase):
+                raise ValueError(
+                    "The model must be a subclass of LightningModuleBase to use the distributed prediction writer."
+                )
+
+            output_dir = self.output_dir / "processed" / f"dataloader_{dataloader_idx}"
+            output_dir.mkdir(parents=True, exist_ok=True)
+
+            # Split into individual predictions
+            assert batch_indices is not None, (
+                "Batch indices must be provided for processed writing."
+            )
+            for sample in pl_module.split_batched_predictions(
+                batch, prediction, batch_indices
+            ):
+                sample = {
+                    **sample,
+                    "global_rank": trainer.global_rank,
+                    "world_size": trainer.world_size,
+                    "is_global_zero": trainer.is_global_zero,
+                }
+                save(sample, output_dir / f"{sample['index']}.pt")
+
+
+class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
+    def __init__(self, output_dir: Path):
+        self.output_dir = output_dir
+
+    @override
+    def __len__(self) -> int:
+        return len(list(self.output_dir.glob("*.pt")))
+
+    @overload
+    def __getitem__(self, index: int) -> tuple[Any, Any]: ...
+
+    @overload
+    def __getitem__(self, index: slice) -> list[tuple[Any, Any]]: ...
+
+    @override
+    def __getitem__(
+        self, index: int | slice
+    ) -> tuple[Any, Any] | list[tuple[Any, Any]]:
+        if isinstance(index, slice):
+            # Handle slice indexing
+            indices = range(*index.indices(len(self)))
+            return [self.__getitem__(i) for i in indices]
+
+        # Handle integer indexing
+        path = self.output_dir / f"{index}.pt"
+        if not path.exists():
+            raise FileNotFoundError(f"File {path} does not exist.")
+        sample = torch.load(path)
+        return sample["batch"], sample["prediction"]
+
+    @override
+    def __iter__(self) -> Iterator[tuple[Any, Any]]:
+        for i in range(len(self)):
+            yield self[i]
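
For each sample, the processed writer stores a dict containing the un-batched `batch` and `prediction`, the dataset `index`, and the writing rank's info, at `<dirpath>/processed/dataloader_<i>/<index>.pt`; `DistributedPredictionReader` above indexes that directory by sample index. A minimal read-back sketch (the `predictions_dir` path is illustrative; by default the callback resolves `dirpath` from the trainer's directory config, and integer indexing assumes the written sample indices are contiguous from 0):

from pathlib import Path

from nshtrainer.callbacks.distributed_prediction_writer import (
    DistributedPredictionReader,
)

# Directory the callback wrote for the first prediction dataloader.
predictions_dir = Path("path/to/predictions") / "processed" / "dataloader_0"

reader = DistributedPredictionReader(predictions_dir)
print(f"{len(reader)} samples written across all ranks")

# Each item is the un-batched (batch, prediction) pair for one dataset sample.
for batch, prediction in reader:
    ...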

src/nshtrainer/configs/__init__.py
@@ -21,6 +21,9 @@ from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackCon
 from nshtrainer.callbacks import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from nshtrainer.callbacks import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.callbacks import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -95,9 +98,21 @@ from nshtrainer.nn.nonlinearity import (
     SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
 )
 from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
+from nshtrainer.optimizer import AdadeltaConfig as AdadeltaConfig
+from nshtrainer.optimizer import AdafactorConfig as AdafactorConfig
+from nshtrainer.optimizer import AdagradConfig as AdagradConfig
+from nshtrainer.optimizer import AdamaxConfig as AdamaxConfig
+from nshtrainer.optimizer import AdamConfig as AdamConfig
 from nshtrainer.optimizer import AdamWConfig as AdamWConfig
+from nshtrainer.optimizer import ASGDConfig as ASGDConfig
+from nshtrainer.optimizer import NAdamConfig as NAdamConfig
 from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
 from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+from nshtrainer.optimizer import RAdamConfig as RAdamConfig
+from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
+from nshtrainer.optimizer import RpropConfig as RpropConfig
+from nshtrainer.optimizer import SGDConfig as SGDConfig
+from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
 from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
@@ -225,11 +240,17 @@ from . import trainer as trainer
 from . import util as util
 
 __all__ = [
+    "ASGDConfig",
     "AcceleratorConfig",
     "AcceleratorConfigBase",
     "ActSaveConfig",
     "ActSaveLoggerConfig",
+    "AdadeltaConfig",
+    "AdafactorConfig",
+    "AdagradConfig",
+    "AdamConfig",
     "AdamWConfig",
+    "AdamaxConfig",
     "AdvancedProfilerConfig",
     "AsyncCheckpointIOPlugin",
     "BaseCheckpointCallbackConfig",
@@ -249,6 +270,7 @@ __all__ = [
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
     "DirectorySetupCallbackConfig",
+    "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "DurationConfig",
     "ELUNonlinearityConfig",
@@ -294,6 +316,7 @@ __all__ = [
     "MetricValidationCallbackConfig",
     "MishNonlinearityConfig",
    "MixedPrecisionPluginConfig",
+    "NAdamConfig",
     "NonlinearityConfig",
     "NonlinearityConfigBase",
     "NormLoggingCallbackConfig",
@@ -306,10 +329,14 @@ __all__ = [
     "PrintTableMetricsCallbackConfig",
     "ProfilerConfig",
     "PyTorchProfilerConfig",
+    "RAdamConfig",
     "RLPSanityChecksCallbackConfig",
+    "RMSpropConfig",
     "RNGConfig",
     "ReLUNonlinearityConfig",
     "ReduceLROnPlateauConfig",
+    "RpropConfig",
+    "SGDConfig",
     "SLURMEnvironmentPlugin",
     "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
@@ -331,6 +358,7 @@ __all__ = [
     "TorchSyncBatchNormPlugin",
     "TrainerConfig",
     "TransformerEnginePluginConfig",
+    "Union",
     "WandbLoggerConfig",
     "WandbUploadCodeCallbackConfig",
     "WandbWatchCallbackConfig",

src/nshtrainer/configs/callbacks/__init__.py
@@ -12,6 +12,9 @@ from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackCon
 from nshtrainer.callbacks import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from nshtrainer.callbacks import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.callbacks import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -62,6 +65,7 @@ from . import base as base
 from . import checkpoint as checkpoint
 from . import debug_flag as debug_flag
 from . import directory_setup as directory_setup
+from . import distributed_prediction_writer as distributed_prediction_writer
 from . import early_stopping as early_stopping
 from . import ema as ema
 from . import finite_checks as finite_checks
@@ -86,6 +90,7 @@ __all__ = [
     "CheckpointMetadata",
     "DebugFlagCallbackConfig",
     "DirectorySetupCallbackConfig",
+    "DistributedPredictionWriterConfig",
     "EMACallbackConfig",
     "EarlyStoppingCallbackConfig",
     "EpochTimerCallbackConfig",
@@ -109,6 +114,7 @@ __all__ = [
     "checkpoint",
     "debug_flag",
     "directory_setup",
+    "distributed_prediction_writer",
     "early_stopping",
     "ema",
     "finite_checks",

src/nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py (new file)
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    CallbackConfigBase as CallbackConfigBase,
+)
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    callback_registry as callback_registry,
+)
+
+__all__ = [
+    "CallbackConfigBase",
+    "DistributedPredictionWriterConfig",
+    "callback_registry",
+]

src/nshtrainer/configs/optimizer/__init__.py (new file)
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.optimizer import AdadeltaConfig as AdadeltaConfig
+from nshtrainer.optimizer import AdafactorConfig as AdafactorConfig
+from nshtrainer.optimizer import AdagradConfig as AdagradConfig
+from nshtrainer.optimizer import AdamaxConfig as AdamaxConfig
+from nshtrainer.optimizer import AdamConfig as AdamConfig
+from nshtrainer.optimizer import AdamWConfig as AdamWConfig
+from nshtrainer.optimizer import ASGDConfig as ASGDConfig
+from nshtrainer.optimizer import NAdamConfig as NAdamConfig
+from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
+from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+from nshtrainer.optimizer import RAdamConfig as RAdamConfig
+from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
+from nshtrainer.optimizer import RpropConfig as RpropConfig
+from nshtrainer.optimizer import SGDConfig as SGDConfig
+from nshtrainer.optimizer import Union as Union
+from nshtrainer.optimizer import optimizer_registry as optimizer_registry
+
+__all__ = [
+    "ASGDConfig",
+    "AdadeltaConfig",
+    "AdafactorConfig",
+    "AdagradConfig",
+    "AdamConfig",
+    "AdamWConfig",
+    "AdamaxConfig",
+    "NAdamConfig",
+    "OptimizerConfig",
+    "OptimizerConfigBase",
+    "RAdamConfig",
+    "RMSpropConfig",
+    "RpropConfig",
+    "SGDConfig",
+    "Union",
+    "optimizer_registry",
+]
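
Version 1.2.0 expands `nshtrainer.optimizer` from a single `AdamWConfig` to registered configs covering most of `torch.optim`. A hedged usage sketch (the `lr` field and the `create_optimizer` method name are assumptions inferred from the pre-existing `AdamWConfig` pattern, not verified against the new 626-line `optimizer.py`):

import torch.nn as nn

from nshtrainer.optimizer import SGDConfig

model = nn.Linear(16, 1)

# Assumed schema: the config fields are expected to mirror torch.optim.SGD's arguments.
optimizer_config = SGDConfig(lr=1e-2)

# Assumed method name: the config is expected to construct the corresponding torch.optim.SGD.
optimizer = optimizer_config.create_optimizer(model.parameters())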

src/nshtrainer/configs/trainer/__init__.py
@@ -22,6 +22,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -149,6 +152,7 @@ __all__ = [
     "DebugFlagCallbackConfig",
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
+    "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",

src/nshtrainer/configs/trainer/_config/__init__.py
@@ -18,6 +18,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -70,6 +73,7 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
+    "DistributedPredictionWriterConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",

src/nshtrainer/model/base.py
@@ -2,9 +2,9 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Mapping
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from pathlib import Path
-from typing import Any, Generic, Literal, cast
+from typing import Any, Generic, Literal, TypedDict, cast
 
 import nshconfig as C
 import torch
@@ -53,6 +53,47 @@ VALID_REDUCE_OPS = (
 )
 
 
+class IndividualSample(TypedDict):
+    """
+    A dictionary that contains the individual sample.
+    This is used to split the batched predictions into individual predictions.
+    """
+
+    index: int
+    """The index of the sample in the batch."""
+
+    batch: Any
+    """The batch to split."""
+
+    prediction: Any
+    """The batched prediction to split."""
+
+
+def default_split_batched_predictions(
+    batch: Any,
+    prediction: Any,
+    batch_indices: Sequence[Any],
+) -> Iterable[IndividualSample]:
+    """
+    Splits the batched predictions into a list of individual predictions.
+    Args:
+        batch: The batch to split.
+        prediction: The batched prediction to split.
+        batch_indices: The indices of the batches.
+    Returns:
+        A tuple of two sequences: the corresponding batches and the individual predictions.
+    """
+    import torch.utils._pytree as tree
+
+    for sample_idx, batch_idx in enumerate(batch_indices):
+        # Create a dictionary for each sample
+        yield IndividualSample(
+            index=batch_idx,
+            batch=tree.tree_map(lambda x: x[sample_idx], batch),
+            prediction=tree.tree_map(lambda x: x[sample_idx], prediction),
+        )
+
+
 class LightningModuleBase(
     DebugModuleMixin,
     RLPSanityCheckModuleMixin,
@@ -171,6 +212,23 @@ class LightningModuleBase(
         loss = cast(torch.Tensor, loss)
         return loss
 
+    def split_batched_predictions(
+        self,
+        batch: Any,
+        prediction: Any,
+        batch_indices: Sequence[Any],
+    ) -> Iterable[IndividualSample]:
+        """
+        Splits the batched predictions into a list of individual predictions.
+        Args:
+            batch: The batch to split.
+            prediction: The batched prediction to split.
+            batch_indices: The indices of the batches.
+        Returns:
+            A tuple of two sequences: the corresponding batches and the individual predictions.
+        """
+        return default_split_batched_predictions(batch, prediction, batch_indices)
+
     @override
     @classmethod
     def load_from_checkpoint(cls, *args, **kwargs) -> Never:
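
To make the splitting behaviour concrete, here is a small standalone sketch that mirrors what `default_split_batched_predictions` does for a batched dict of tensors, reusing the same `torch.utils._pytree.tree_map` indexing trick without importing nshtrainer:

import torch
import torch.utils._pytree as tree

# A toy batch of 4 samples and its batched prediction.
batch = {"x": torch.randn(4, 8), "id": torch.arange(4)}
prediction = {"energy": torch.randn(4)}
batch_indices = [10, 11, 12, 13]  # dataset indices of the samples in this batch

for sample_idx, dataset_idx in enumerate(batch_indices):
    sample = {
        "index": dataset_idx,
        # tree_map applies the indexing to every tensor leaf, turning the
        # batched structures into per-sample structures.
        "batch": tree.tree_map(lambda t: t[sample_idx], batch),
        "prediction": tree.tree_map(lambda t: t[sample_idx], prediction),
    }
    print(sample["index"], sample["batch"]["x"].shape)  # e.g. 10 torch.Size([8])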