nshtrainer 1.2.1.tar.gz → 1.3.0.tar.gz

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (165)
  1. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/PKG-INFO +1 -1
  2. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/pyproject.toml +1 -1
  3. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/distributed_prediction_writer.py +22 -11
  4. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/__init__.py +3 -3
  5. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -4
  6. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/trainer/__init__.py +4 -0
  7. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/_config.py +0 -10
  8. nshtrainer-1.3.0/src/nshtrainer/trainer/_distributed_prediction_result.py +80 -0
  9. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/trainer.py +66 -2
  10. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/README.md +0 -0
  11. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/.nshconfig.generated.json +0 -0
  12. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/__init__.py +0 -0
  13. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/_callback.py +0 -0
  14. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  15. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  16. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/_directory.py +0 -0
  17. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  18. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/_hf_hub.py +0 -0
  19. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/__init__.py +0 -0
  20. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  21. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/base.py +0 -0
  22. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  23. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  24. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  25. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  26. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  27. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  28. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  29. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  30. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/ema.py +0 -0
  31. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  32. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  33. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/interval.py +0 -0
  34. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  35. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
  36. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/metric_validation.py +0 -0
  37. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  38. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  39. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  40. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  41. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/timer.py +0 -0
  42. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  43. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  44. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/.gitattributes +0 -0
  45. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/__init__.py +0 -0
  46. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  47. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  48. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  49. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  50. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
  51. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  52. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  53. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  54. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  55. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  56. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  57. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  58. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  59. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  60. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py +0 -0
  61. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  62. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  63. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  64. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  65. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  66. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  67. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
  68. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  69. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  70. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  71. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  72. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  73. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  74. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  75. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  76. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  77. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
  78. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  79. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  80. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  81. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  82. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
  83. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  84. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  85. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  86. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  87. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/nn/__init__.py +0 -0
  88. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  89. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  90. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/nn/rng/__init__.py +0 -0
  91. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  92. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  93. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  94. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  95. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  96. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  97. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
  98. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
  99. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
  100. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
  101. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
  102. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
  103. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
  104. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
  105. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/util/__init__.py +0 -0
  106. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  107. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  108. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  109. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  110. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/data/__init__.py +0 -0
  111. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  112. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/data/datamodule.py +0 -0
  113. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/data/transform.py +0 -0
  114. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/loggers/__init__.py +0 -0
  115. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/loggers/actsave.py +0 -0
  116. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/loggers/base.py +0 -0
  117. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/loggers/csv.py +0 -0
  118. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  119. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/loggers/wandb.py +0 -0
  120. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  121. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/lr_scheduler/base.py +0 -0
  122. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  123. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  124. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/metrics/__init__.py +0 -0
  125. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/metrics/_config.py +0 -0
  126. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/model/__init__.py +0 -0
  127. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/model/base.py +0 -0
  128. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/model/mixins/callback.py +0 -0
  129. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/model/mixins/debug.py +0 -0
  130. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/model/mixins/logger.py +0 -0
  131. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/nn/__init__.py +0 -0
  132. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/nn/mlp.py +0 -0
  133. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/nn/module_dict.py +0 -0
  134. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/nn/module_list.py +0 -0
  135. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  136. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/nn/rng.py +0 -0
  137. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/optimizer.py +0 -0
  138. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/profiler/__init__.py +0 -0
  139. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/profiler/_base.py +0 -0
  140. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/profiler/advanced.py +0 -0
  141. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/profiler/pytorch.py +0 -0
  142. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/profiler/simple.py +0 -0
  143. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/__init__.py +0 -0
  144. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/_log_hparams.py +0 -0
  145. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  146. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/accelerator.py +0 -0
  147. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
  148. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/plugin/base.py +0 -0
  149. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/plugin/environment.py +0 -0
  150. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/plugin/io.py +0 -0
  151. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
  152. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/plugin/precision.py +0 -0
  153. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  154. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/strategy.py +0 -0
  155. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/util/_environment_info.py +0 -0
  156. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/util/bf16.py +0 -0
  157. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/util/config/__init__.py +0 -0
  158. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/util/config/dtype.py +0 -0
  159. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/util/config/duration.py +0 -0
  160. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/util/environment.py +0 -0
  161. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/util/path.py +0 -0
  162. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/util/seed.py +0 -0
  163. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/util/slurm.py +0 -0
  164. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/util/typed.py +0 -0
  165. {nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-1.2.1 → nshtrainer-1.3.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.2.1
+Version: 1.3.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-1.2.1 → nshtrainer-1.3.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nshtrainer"
-version = "1.2.1"
+version = "1.3.0"
 description = ""
 authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
 requires-python = ">=3.10,<4.0"
{nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/callbacks/distributed_prediction_writer.py
@@ -4,15 +4,19 @@ import functools
 import logging
 from collections.abc import Iterator, Sequence
 from pathlib import Path
-from typing import Any, ClassVar, Literal, overload
+from typing import TYPE_CHECKING, ClassVar, Generic, Literal, cast, overload

 import torch
 from lightning.fabric.utilities.apply_func import move_data_to_device
 from lightning.pytorch.callbacks import BasePredictionWriter
-from typing_extensions import final, override
+from typing_extensions import TypeVar, final, override

 from .base import CallbackConfigBase, CallbackMetadataConfig, callback_registry

+if TYPE_CHECKING:
+    from ..model.base import IndividualSample
+
+
 log = logging.getLogger(__name__)


@@ -130,7 +134,15 @@ class DistributedPredictionWriter(BasePredictionWriter):
         save(sample, output_dir / f"{sample['index']}.pt")


-class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
+SampleT = TypeVar(
+    "SampleT",
+    bound="IndividualSample",
+    default="IndividualSample",
+    infer_variance=True,
+)
+
+
+class DistributedPredictionReader(Sequence[SampleT], Generic[SampleT]):
     def __init__(self, output_dir: Path):
         self.output_dir = output_dir

@@ -139,15 +151,13 @@ class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
         return len(list(self.output_dir.glob("*.pt")))

     @overload
-    def __getitem__(self, index: int) -> tuple[Any, Any]: ...
+    def __getitem__(self, index: int) -> SampleT: ...

     @overload
-    def __getitem__(self, index: slice) -> list[tuple[Any, Any]]: ...
+    def __getitem__(self, index: slice) -> list[SampleT]: ...

     @override
-    def __getitem__(
-        self, index: int | slice
-    ) -> tuple[Any, Any] | list[tuple[Any, Any]]:
+    def __getitem__(self, index: int | slice) -> SampleT | list[SampleT]:
         if isinstance(index, slice):
             # Handle slice indexing
             indices = range(*index.indices(len(self)))
@@ -157,10 +167,11 @@ class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
         path = self.output_dir / f"{index}.pt"
         if not path.exists():
             raise FileNotFoundError(f"File {path} does not exist.")
-        sample = torch.load(path)
-        return sample["batch"], sample["prediction"]
+
+        sample = cast(SampleT, torch.load(path))
+        return sample

     @override
-    def __iter__(self) -> Iterator[tuple[Any, Any]]:
+    def __iter__(self) -> Iterator[SampleT]:
         for i in range(len(self)):
             yield self[i]
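The reader is now generic over the sample type and yields whole sample objects instead of (batch, prediction) tuples. A minimal usage sketch under assumptions: the directory path below is illustrative, and the files are expected to follow the "{index}.pt" layout produced by the writer above.

from pathlib import Path

from nshtrainer.callbacks.distributed_prediction_writer import (
    DistributedPredictionReader,
)

# Point the reader at a directory of "{index}.pt" files written by the callback.
reader = DistributedPredictionReader(Path("predictions/processed/dataloader_0"))

print(len(reader))        # number of saved samples
sample = reader[0]        # integer indexing loads a single sample
first_ten = reader[0:10]  # slice indexing returns a list of samples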
{nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/__init__.py
@@ -22,9 +22,6 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
-from nshtrainer.trainer._config import (
-    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
-)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -126,6 +123,9 @@ from nshtrainer.trainer.plugin.precision import (
 )
 from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
+from nshtrainer.trainer.trainer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase

 from . import _config as _config
{nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/_config/__init__.py
@@ -18,9 +18,6 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
-from nshtrainer.trainer._config import (
-    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
-)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -73,7 +70,6 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
-    "DistributedPredictionWriterConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",
{nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/configs/trainer/trainer/__init__.py
@@ -3,12 +3,16 @@ from __future__ import annotations
 __codegen__ = True

 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
+from nshtrainer.trainer.trainer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
 from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig

 __all__ = [
     "AcceleratorConfigBase",
+    "DistributedPredictionWriterConfig",
     "EnvironmentConfig",
     "StrategyConfigBase",
     "TrainerConfig",
{nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/_config.py
@@ -31,7 +31,6 @@ from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
-    DistributedPredictionWriterConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
     NormLoggingCallbackConfig,
@@ -702,14 +701,6 @@ class TrainerConfig(C.Config):
     auto_validate_metrics: MetricValidationCallbackConfig | None = None
     """If enabled, will automatically validate the metrics before starting the training routine."""

-    distributed_predict: DistributedPredictionWriterConfig | None = (
-        DistributedPredictionWriterConfig()
-    )
-    """If enabled, will use a custom BasePredictionWriter callback to automatically
-    handle distributed prediction. This is useful for running prediction on multiple GPUs
-    seamlessly.
-    """
-
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
     Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
@@ -778,7 +769,6 @@ class TrainerConfig(C.Config):
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield self.auto_validate_metrics
-        yield self.distributed_predict
         yield from self.callbacks

     def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
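With the distributed_predict field removed from TrainerConfig, the writer callback is no longer attached by default; it is now requested per call through Trainer.distributed_predict (shown in the trainer.py diff below). A hedged migration sketch, assuming an existing trainer, model, and datamodule:

# 1.2.1: the writer was enabled implicitly via the trainer config.
# config = TrainerConfig(distributed_predict=DistributedPredictionWriterConfig())

# 1.3.0: request distributed prediction explicitly when needed.
result = trainer.distributed_predict(model, datamodule=datamodule)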
nshtrainer-1.3.0/src/nshtrainer/trainer/_distributed_prediction_result.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class DistributedPredictionResult:
+    """Represents the results of a distributed prediction run.
+
+    This dataclass provides easy access to both raw and processed prediction data.
+    """
+
+    root_dir: Path
+    """Root directory where predictions are stored."""
+
+    @property
+    def raw_dir(self) -> Path:
+        """Directory containing raw prediction data."""
+        return self.root_dir / "raw"
+
+    @property
+    def processed_dir(self) -> Path:
+        """Directory containing processed prediction data."""
+        return self.root_dir / "processed"
+
+    def get_raw_predictions(self, dataloader_idx: int = 0) -> Path:
+        """Get the directory containing raw predictions for a specific dataloader.
+
+        Args:
+            dataloader_idx: Index of the dataloader
+
+        Returns:
+            Path to the raw predictions directory for the specified dataloader
+        """
+        raw_loader_dir = self.raw_dir / f"dataloader_{dataloader_idx}"
+        if not raw_loader_dir.exists():
+            log.warning(f"Raw predictions directory {raw_loader_dir} does not exist.")
+        return raw_loader_dir
+
+    def get_processed_reader(self, dataloader_idx: int = 0):
+        """Get a reader for processed predictions from a specific dataloader.
+
+        Args:
+            dataloader_idx: Index of the dataloader
+
+        Returns:
+            A DistributedPredictionReader for the processed predictions, or None if no data exists
+        """
+        from ..callbacks.distributed_prediction_writer import (
+            DistributedPredictionReader,
+        )
+
+        processed_loader_dir = self.processed_dir / f"dataloader_{dataloader_idx}"
+        if not processed_loader_dir.exists():
+            log.warning(
+                f"Processed predictions directory {processed_loader_dir} does not exist."
+            )
+            return None
+
+        return DistributedPredictionReader(processed_loader_dir)
+
+    @classmethod
+    def load(cls, path: Path | str):
+        """Load prediction results from a directory.
+
+        Args:
+            path: Path to the predictions directory
+
+        Returns:
+            A DistributedPredictionResult instance
+        """
+        path = Path(path)
+        if not path.exists():
+            raise FileNotFoundError(f"Predictions directory {path} does not exist.")
+
+        return cls(root_dir=path)
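A short sketch of the new result object, assuming a completed prediction run whose outputs live under the illustrative ./predictions directory:

from nshtrainer.trainer._distributed_prediction_result import (
    DistributedPredictionResult,
)

result = DistributedPredictionResult.load("./predictions")
print(result.raw_dir)        # <root>/raw
print(result.processed_dir)  # <root>/processed

# Reader over dataloader 0's processed samples; None if nothing was written.
if (reader := result.get_processed_reader(dataloader_idx=0)) is not None:
    for sample in reader:
        ...  # each item is one saved sample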
{nshtrainer-1.2.1 → nshtrainer-1.3.0}/src/nshtrainer/trainer/trainer.py
@@ -4,7 +4,7 @@ import logging
 import os
 from collections.abc import Callable, Mapping, Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, cast, overload

 import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
@@ -24,9 +24,14 @@ from typing_extensions import Never, Unpack, assert_never, deprecated, override

 from .._checkpoint.metadata import write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
+from ..callbacks.distributed_prediction_writer import (
+    DistributedPredictionWriter,
+    DistributedPredictionWriterConfig,
+)
 from ..util._environment_info import EnvironmentConfig
 from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import LightningTrainerKwargs, TrainerConfig
+from ._distributed_prediction_result import DistributedPredictionResult
 from ._log_hparams import patch_log_hparams_function
 from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .accelerator import AcceleratorConfigBase
@@ -537,13 +542,66 @@ class Trainer(LightningTrainer):
         )
         return cls(hparams)

+    @overload
     def distributed_predict(
         self,
         model: LightningModule | None = None,
         dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
         datamodule: LightningDataModule | None = None,
         ckpt_path: str | Path | None = None,
-    ):
+        *,
+        config: DistributedPredictionWriterConfig,
+    ) -> DistributedPredictionResult: ...
+
+    @overload
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+        *,
+        dirpath: Path | None = None,
+        move_to_cpu_on_save: bool = True,
+        save_raw: bool = True,
+        save_processed: bool = True,
+    ) -> DistributedPredictionResult: ...
+
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+        *,
+        config: DistributedPredictionWriterConfig | None = None,
+        dirpath: Path | None = None,
+        move_to_cpu_on_save: bool = True,
+        save_raw: bool = True,
+        save_processed: bool = True,
+    ) -> DistributedPredictionResult:
+        if config is None:
+            config = DistributedPredictionWriterConfig(
+                dirpath=dirpath,
+                move_to_cpu_on_save=move_to_cpu_on_save,
+                save_raw=save_raw,
+                save_processed=save_processed,
+            )
+
+        # Remove any DistributedPredictionWriter callbacks that are already set
+        # and add the new one.
+        callbacks = self.callbacks.copy()
+        callbacks = [
+            callback
+            for callback in callbacks
+            if not isinstance(callback, DistributedPredictionWriter)
+        ]
+        writer_callbacks = list(config.create_callbacks(self.hparams))
+        assert len(writer_callbacks) == 1
+        callback = writer_callbacks[0]
+        callbacks.append(callback)
+        self.callbacks = self._callback_connector._reorder_callbacks(callbacks)
+
         self.predict(
             model,
             dataloaders,
@@ -551,3 +609,9 @@ class Trainer(LightningTrainer):
             return_predictions=False,
             ckpt_path=ckpt_path,
         )
+
+        # Wait for all processes to finish
+        self.strategy.barrier("Trainer.distributed_predict")
+
+        # Return an object that contains information about the predictions
+        return DistributedPredictionResult(root_dir=callback.output_dir)
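End to end, distributed_predict now returns a handle to the written predictions on every rank once the barrier clears. A sketch under assumed names (config, model, datamodule, and the paths are placeholders):

from pathlib import Path

from nshtrainer.trainer.trainer import Trainer

trainer = Trainer(config)
result = trainer.distributed_predict(
    model,
    datamodule=datamodule,
    ckpt_path="checkpoints/last.ckpt",  # hypothetical checkpoint
    dirpath=Path("predictions"),        # illustrative output directory
)
# The barrier above means all ranks have finished writing at this point.
reader = result.get_processed_reader()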