nshtrainer 1.0.0b10__tar.gz → 1.0.0b12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/PKG-INFO +1 -1
  2. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/pyproject.toml +1 -1
  3. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/__init__.py +1 -1
  4. nshtrainer-1.0.0b12/src/nshtrainer/data/datamodule.py +126 -0
  5. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/model/base.py +100 -2
  6. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/_config.py +0 -1
  7. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/trainer.py +38 -63
  8. nshtrainer-1.0.0b10/src/nshtrainer/data/datamodule.py +0 -57
  9. nshtrainer-1.0.0b10/src/nshtrainer/scripts/find_packages.py +0 -52
  10. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/README.md +0 -0
  11. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/_callback.py +0 -0
  12. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/_checkpoint/loader.py +0 -0
  13. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  14. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/_checkpoint/saver.py +0 -0
  15. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/_directory.py +0 -0
  16. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/_experimental/__init__.py +0 -0
  17. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/_hf_hub.py +0 -0
  18. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/__init__.py +0 -0
  19. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/actsave.py +0 -0
  20. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/base.py +0 -0
  21. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  22. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  23. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  24. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  25. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  26. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  27. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  28. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  29. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/ema.py +0 -0
  30. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  31. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  32. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/interval.py +0 -0
  33. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  34. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  35. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/print_table.py +0 -0
  36. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  37. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  38. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/timer.py +0 -0
  39. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  40. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  41. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/__init__.py +0 -0
  42. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/_checkpoint/__init__.py +0 -0
  43. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/_checkpoint/loader/__init__.py +0 -0
  44. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/_checkpoint/metadata/__init__.py +0 -0
  45. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/_directory/__init__.py +0 -0
  46. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/_hf_hub/__init__.py +0 -0
  47. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/__init__.py +0 -0
  48. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/actsave/__init__.py +0 -0
  49. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/base/__init__.py +0 -0
  50. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/checkpoint/__init__.py +0 -0
  51. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/checkpoint/_base/__init__.py +0 -0
  52. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  53. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  54. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  55. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/debug_flag/__init__.py +0 -0
  56. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/directory_setup/__init__.py +0 -0
  57. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/early_stopping/__init__.py +0 -0
  58. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/ema/__init__.py +0 -0
  59. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/finite_checks/__init__.py +0 -0
  60. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/gradient_skipping/__init__.py +0 -0
  61. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/log_epoch/__init__.py +0 -0
  62. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/norm_logging/__init__.py +0 -0
  63. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/print_table/__init__.py +0 -0
  64. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/rlp_sanity_checks/__init__.py +0 -0
  65. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/shared_parameters/__init__.py +0 -0
  66. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/timer/__init__.py +0 -0
  67. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/wandb_upload_code/__init__.py +0 -0
  68. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/callbacks/wandb_watch/__init__.py +0 -0
  69. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/loggers/__init__.py +0 -0
  70. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/loggers/_base/__init__.py +0 -0
  71. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/loggers/actsave/__init__.py +0 -0
  72. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/loggers/csv/__init__.py +0 -0
  73. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/loggers/tensorboard/__init__.py +0 -0
  74. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/loggers/wandb/__init__.py +0 -0
  75. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/lr_scheduler/__init__.py +0 -0
  76. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/lr_scheduler/_base/__init__.py +0 -0
  77. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  78. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  79. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/metrics/__init__.py +0 -0
  80. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/metrics/_config/__init__.py +0 -0
  81. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/nn/__init__.py +0 -0
  82. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/nn/mlp/__init__.py +0 -0
  83. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/nn/nonlinearity/__init__.py +0 -0
  84. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/optimizer/__init__.py +0 -0
  85. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/profiler/__init__.py +0 -0
  86. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/profiler/_base/__init__.py +0 -0
  87. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/profiler/advanced/__init__.py +0 -0
  88. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/profiler/pytorch/__init__.py +0 -0
  89. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/profiler/simple/__init__.py +0 -0
  90. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/trainer/__init__.py +0 -0
  91. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/trainer/_config/__init__.py +0 -0
  92. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/trainer/checkpoint_connector/__init__.py +0 -0
  93. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/trainer/trainer/__init__.py +0 -0
  94. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/util/__init__.py +0 -0
  95. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/util/_environment_info/__init__.py +0 -0
  96. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/util/config/__init__.py +0 -0
  97. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/util/config/dtype/__init__.py +0 -0
  98. {nshtrainer-1.0.0b10/src/nshtrainer/config → nshtrainer-1.0.0b12/src/nshtrainer/configs}/util/config/duration/__init__.py +0 -0
  99. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/data/__init__.py +0 -0
  100. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  101. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/data/transform.py +0 -0
  102. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/__init__.py +0 -0
  103. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/_base.py +0 -0
  104. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/actsave.py +0 -0
  105. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/csv.py +0 -0
  106. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/tensorboard.py +0 -0
  107. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/wandb.py +0 -0
  108. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  109. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  110. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  111. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  112. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/metrics/__init__.py +0 -0
  113. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/metrics/_config.py +0 -0
  114. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/model/__init__.py +0 -0
  115. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/model/mixins/callback.py +0 -0
  116. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/model/mixins/debug.py +0 -0
  117. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/model/mixins/logger.py +0 -0
  118. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/nn/__init__.py +0 -0
  119. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/nn/mlp.py +0 -0
  120. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/nn/module_dict.py +0 -0
  121. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/nn/module_list.py +0 -0
  122. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/nn/nonlinearity.py +0 -0
  123. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/optimizer.py +0 -0
  124. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/profiler/__init__.py +0 -0
  125. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/profiler/_base.py +0 -0
  126. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/profiler/advanced.py +0 -0
  127. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/profiler/pytorch.py +0 -0
  128. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/profiler/simple.py +0 -0
  129. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/__init__.py +0 -0
  130. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  131. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  132. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/signal_connector.py +0 -0
  133. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/_environment_info.py +0 -0
  134. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/_useful_types.py +0 -0
  135. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/bf16.py +0 -0
  136. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/config/__init__.py +0 -0
  137. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/config/dtype.py +0 -0
  138. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/config/duration.py +0 -0
  139. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/environment.py +0 -0
  140. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/path.py +0 -0
  141. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/seed.py +0 -0
  142. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/slurm.py +0 -0
  143. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/typed.py +0 -0
  144. {nshtrainer-1.0.0b10 → nshtrainer-1.0.0b12}/src/nshtrainer/util/typing_utils.py +0 -0
--- nshtrainer-1.0.0b10/PKG-INFO
+++ nshtrainer-1.0.0b12/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 1.0.0b10
+Version: 1.0.0b12
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
--- nshtrainer-1.0.0b10/pyproject.toml
+++ nshtrainer-1.0.0b12/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "1.0.0-beta10"
+version = "1.0.0-beta12"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
--- nshtrainer-1.0.0b10/src/nshtrainer/__init__.py
+++ nshtrainer-1.0.0b12/src/nshtrainer/__init__.py
@@ -16,6 +16,6 @@ from .trainer import Trainer as Trainer
 from .trainer import TrainerConfig as TrainerConfig
 
 try:
-    from . import config as config
+    from . import configs as configs
 except BaseException:
     pass
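Note: the top-level config module is renamed to configs in this release, so downstream imports presumably need the matching rename. A minimal sketch based only on the rename shown above:

    # nshtrainer 1.0.0b10:
    #   from nshtrainer import config
    # nshtrainer 1.0.0b12:
    from nshtrainer import configs as configs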
--- /dev/null
+++ nshtrainer-1.0.0b12/src/nshtrainer/data/datamodule.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Mapping
+from pathlib import Path
+from typing import Any, Generic, cast
+
+import nshconfig as C
+import torch
+from lightning.pytorch import LightningDataModule
+from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.rank_zero import rank_zero_warn
+from typing_extensions import Never, TypeVar, deprecated, override
+
+from ..model.mixins.callback import CallbackRegistrarModuleMixin
+from ..model.mixins.debug import _DebugModuleMixin
+
+THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)
+
+
+class LightningDataModuleBase(
+    _DebugModuleMixin,
+    CallbackRegistrarModuleMixin,
+    LightningDataModule,
+    ABC,
+    Generic[THparams],
+):
+    @property
+    @override
+    def hparams(self) -> THparams:  # pyright: ignore[reportIncompatibleMethodOverride]
+        return cast(THparams, super().hparams)
+
+    @property
+    @override
+    def hparams_initial(self):  # pyright: ignore[reportIncompatibleMethodOverride]
+        hparams = cast(THparams, super().hparams_initial)
+        return cast(Never, {"datamodule": hparams.model_dump(mode="json")})
+
+    @property
+    @deprecated("Use `hparams` instead")
+    def config(self):
+        return cast(Never, self.hparams)
+
+    @classmethod
+    @abstractmethod
+    def hparams_cls(cls) -> type[THparams]: ...
+
+    @override
+    def __init__(self, hparams: THparams | Mapping[str, Any]):
+        super().__init__()
+
+        # Validate and save hyperparameters
+        hparams_cls = self.hparams_cls()
+        if isinstance(hparams, Mapping):
+            hparams = hparams_cls.model_validate(hparams)
+        elif not isinstance(hparams, hparams_cls):
+            raise TypeError(
+                f"Expected hparams to be either a Mapping or an instance of {hparams_cls}, got {type(hparams)}"
+            )
+        hparams = hparams.model_deep_validate()
+        self.save_hyperparameters(hparams)
+
+    @override
+    @classmethod
+    def load_from_checkpoint(cls, *args, **kwargs) -> Never:
+        raise ValueError("This method is not supported. Use `from_checkpoint` instead.")
+
+    @classmethod
+    def hparams_from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        if isinstance(ckpt_or_path, dict):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location="cpu")
+
+        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+            raise ValueError(
+                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+            )
+        if update_hparams_dict is not None:
+            hparams = update_hparams_dict(hparams)
+
+        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+        if update_hparams is not None:
+            hparams = update_hparams(hparams)
+
+        return hparams
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        map_location: torch.serialization.MAP_LOCATION = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        # Load checkpoint
+        if isinstance(ckpt_or_path, Mapping):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location=map_location)
+
+        # Load hyperparameters from checkpoint
+        hparams = cls.hparams_from_checkpoint(
+            ckpt,
+            strict=strict,
+            update_hparams=update_hparams,
+            update_hparams_dict=update_hparams_dict,
+        )
+
+        # Load datamodule from checkpoint
+        datamodule = cls(hparams)
+        if datamodule.__class__.__qualname__ in ckpt:
+            datamodule.load_state_dict(ckpt[datamodule.__class__.__qualname__])
+
+        return datamodule
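The new data/datamodule.py adds checkpoint-loading helpers to LightningDataModuleBase and disables Lightning's stock load_from_checkpoint (it now raises). A hedged usage sketch; MyDataModule and the checkpoint path below are illustrative placeholders, not part of the package:

    # MyDataModule is a hypothetical subclass of LightningDataModuleBase
    # that implements hparams_cls(); "checkpoints/last.ckpt" is a placeholder path.
    dm = MyDataModule.from_checkpoint("checkpoints/last.ckpt")

    # Two-step form: read and optionally patch the stored hparams first.
    hparams = MyDataModule.hparams_from_checkpoint(
        "checkpoints/last.ckpt",
        update_hparams=lambda hp: hp,  # hook to adjust the validated config
    )
    dm = MyDataModule(hparams)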
--- nshtrainer-1.0.0b10/src/nshtrainer/model/base.py
+++ nshtrainer-1.0.0b12/src/nshtrainer/model/base.py
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
+from pathlib import Path
 from typing import Any, Generic, Literal, cast
 
 import nshconfig as C
@@ -10,11 +11,13 @@ import torch
 import torch.distributed
 from lightning.pytorch import LightningModule
 from lightning.pytorch.profilers import PassThroughProfiler, Profiler
+from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from typing_extensions import Never, TypeVar, deprecated, override
 
 from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
 from .mixins.callback import CallbackModuleMixin
-from .mixins.debug import _DebugModuleMixin, _trainer
+from .mixins.debug import _DebugModuleMixin
 from .mixins.logger import LoggerLightningModuleMixin
 
 log = logging.getLogger(__name__)
@@ -241,3 +244,98 @@ class LightningModuleBase(
         loss = sum((0.0 * v).sum() for v in self.parameters() if v.requires_grad)
         loss = cast(torch.Tensor, loss)
         return loss
+
+    @override
+    @classmethod
+    def load_from_checkpoint(cls, *args, **kwargs) -> Never:
+        raise ValueError("This method is not supported. Use `from_checkpoint` instead.")
+
+    @classmethod
+    def hparams_from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        if isinstance(ckpt_or_path, dict):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location="cpu")
+
+        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+            raise ValueError(
+                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+            )
+        if update_hparams_dict is not None:
+            hparams = update_hparams_dict(hparams)
+
+        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+        if update_hparams is not None:
+            hparams = update_hparams(hparams)
+
+        return hparams
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        map_location: torch.serialization.MAP_LOCATION = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        # Load checkpoint
+        if isinstance(ckpt_or_path, Mapping):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location=map_location)
+
+        # Load hyperparameters from checkpoint
+        hparams = cls.hparams_from_checkpoint(
+            ckpt,
+            strict=strict,
+            update_hparams=update_hparams,
+            update_hparams_dict=update_hparams_dict,
+        )
+
+        # Load model from checkpoint
+        model = cls(hparams)
+
+        # Load model state from checkpoint
+        if (
+            model._strict_loading is not None
+            and strict is not None
+            and strict != model.strict_loading
+        ):
+            raise ValueError(
+                f"You set `.load_from_checkpoint(..., strict={strict!r})` which is in conflict with"
+                f" `{cls.__name__}.strict_loading={model.strict_loading!r}. Please set the same value for both of them."
+            )
+        strict = model.strict_loading if strict is None else strict
+
+        if is_overridden("configure_model", model):
+            model.configure_model()
+
+        # give model a chance to load something
+        model.on_load_checkpoint(ckpt)
+
+        # load the state_dict on the model automatically
+
+        keys = model.load_state_dict(ckpt["state_dict"], strict=strict)
+
+        if not strict:
+            if keys.missing_keys:
+                rank_zero_warn(
+                    f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
+                )
+            if keys.unexpected_keys:
+                rank_zero_warn(
+                    f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
+                )
+
+        return model
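LightningModuleBase gains the same pair of classmethods, and from_checkpoint additionally restores the model weights: it calls configure_model when overridden, gives on_load_checkpoint a chance to run, then loads the state_dict and warns about missing or unexpected keys when strict is False. A hedged usage sketch; MyModel and the checkpoint path are illustrative placeholders:

    # MyModel is a hypothetical subclass of nshtrainer's LightningModuleBase.
    model = MyModel.from_checkpoint(
        "checkpoints/last.ckpt",   # placeholder path
        strict=False,              # mismatched keys are reported via rank_zero_warn
        map_location="cpu",        # forwarded to torch.load
    )
    # The stock Lightning entry point is intentionally disabled and raises:
    # MyModel.load_from_checkpoint("checkpoints/last.ckpt")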
@@ -9,7 +9,6 @@ from collections.abc import Iterable, Sequence
9
9
  from datetime import timedelta
10
10
  from pathlib import Path
11
11
  from typing import (
12
- TYPE_CHECKING,
13
12
  Annotated,
14
13
  Any,
15
14
  ClassVar,
--- nshtrainer-1.0.0b10/src/nshtrainer/trainer/trainer.py
+++ nshtrainer-1.0.0b12/src/nshtrainer/trainer/trainer.py
@@ -2,28 +2,19 @@ from __future__ import annotations
 
 import logging
 import os
-from collections.abc import Mapping, Sequence
+from collections.abc import Callable, Mapping, Sequence
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, cast
 
 import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
-from lightning.fabric.utilities.cloud_io import _load as pl_load
-from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.pytorch import LightningModule
 from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.core.saving import (
-    _default_map_location,
-    load_hparams_from_tags_csv,
-    load_hparams_from_yaml,
-)
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.migration import pl_legacy_patch
-from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
 from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
@@ -473,62 +464,46 @@ class Trainer(LightningTrainer):
             _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
 
     @classmethod
-    def load_from_checkpoint(
+    def hparams_from_checkpoint(
         cls,
-        checkpoint_path: _PATH | IO,
-        map_location: _MAP_LOCATION_TYPE = None,
-        hparams_file: _PATH | None = None,
-        **kwargs: Any,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[TrainerConfig], TrainerConfig] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ):
-        loaded = _load_from_checkpoint(
-            checkpoint_path,
-            map_location=map_location,
-            hparams_file=hparams_file,
-            **kwargs,
-        )
-        return loaded
-
-
-def _load_from_checkpoint(
-    checkpoint_path: _PATH | IO,
-    map_location: _MAP_LOCATION_TYPE = None,
-    hparams_file: _PATH | None = None,
-    **kwargs: Any,
-):
-    map_location = map_location or _default_map_location
-    with pl_legacy_patch():
-        checkpoint = pl_load(checkpoint_path, map_location=map_location)
-
-    # convert legacy checkpoints to the new format
-    checkpoint = _pl_migrate_checkpoint(
-        checkpoint,
-        checkpoint_path=(
-            checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None
-        ),
-    )
-
-    if hparams_file is not None:
-        extension = str(hparams_file).split(".")[-1]
-        if extension.lower() == "csv":
-            hparams = load_hparams_from_tags_csv(hparams_file)
-        elif extension.lower() in ("yml", "yaml"):
-            hparams = load_hparams_from_yaml(hparams_file)
+        if isinstance(ckpt_or_path, dict):
+            ckpt = ckpt_or_path
         else:
-            raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")
+            ckpt = torch.load(ckpt_or_path, map_location="cpu")
 
-        # overwrite hparams by the given file
-        checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
+        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+            raise ValueError(
+                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+            )
+        if update_hparams_dict is not None:
+            hparams = update_hparams_dict(hparams)
 
-    # for past checkpoint need to add the new key
-    checkpoint.setdefault(Trainer.CHECKPOINT_HYPER_PARAMS_KEY, {})
-    # override the hparams with values that were passed in
-    checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
+        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+        if update_hparams is not None:
+            hparams = update_hparams(hparams)
 
-    # load the hparams
-    hparams = Trainer.hparams_cls().model_validate(
-        checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY]
-    )
+        return hparams
 
-    # create the trainer
-    trainer = Trainer(hparams)
-    return trainer
+    @classmethod
+    def from_checkpoint(
+        cls,
+        path: str | Path,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[TrainerConfig], TrainerConfig] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        hparams = cls.hparams_from_checkpoint(
+            path,
+            strict=strict,
+            update_hparams=update_hparams,
+            update_hparams_dict=update_hparams_dict,
+        )
+        return cls(hparams)
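Trainer replaces its previous load_from_checkpoint path (pl_load, legacy-checkpoint migration, hparams_file handling) with the same hparams_from_checkpoint / from_checkpoint pair. A minimal sketch of rebuilding a trainer from a saved checkpoint; the path is a placeholder:

    from nshtrainer import Trainer

    trainer = Trainer.from_checkpoint("checkpoints/last.ckpt")

    # Equivalent two-step form, allowing the TrainerConfig to be edited first:
    hparams = Trainer.hparams_from_checkpoint("checkpoints/last.ckpt")
    trainer = Trainer(hparams)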
@@ -1,57 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from abc import ABC, abstractmethod
4
- from collections.abc import Mapping
5
- from typing import Any, Generic, cast
6
-
7
- import nshconfig as C
8
- from lightning.pytorch import LightningDataModule
9
- from typing_extensions import Never, TypeVar, deprecated, override
10
-
11
- from ..model.mixins.callback import CallbackRegistrarModuleMixin
12
- from ..model.mixins.debug import _DebugModuleMixin
13
-
14
- THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)
15
-
16
-
17
- class LightningDataModuleBase(
18
- _DebugModuleMixin,
19
- CallbackRegistrarModuleMixin,
20
- LightningDataModule,
21
- ABC,
22
- Generic[THparams],
23
- ):
24
- @property
25
- @override
26
- def hparams(self) -> THparams: # pyright: ignore[reportIncompatibleMethodOverride]
27
- return cast(THparams, super().hparams)
28
-
29
- @property
30
- @override
31
- def hparams_initial(self): # pyright: ignore[reportIncompatibleMethodOverride]
32
- hparams = cast(THparams, super().hparams_initial)
33
- return cast(Never, {"datamodule": hparams.model_dump(mode="json")})
34
-
35
- @property
36
- @deprecated("Use `hparams` instead")
37
- def config(self):
38
- return cast(Never, self.hparams)
39
-
40
- @classmethod
41
- @abstractmethod
42
- def hparams_cls(cls) -> type[THparams]: ...
43
-
44
- @override
45
- def __init__(self, hparams: THparams | Mapping[str, Any]):
46
- super().__init__()
47
-
48
- # Validate and save hyperparameters
49
- hparams_cls = self.hparams_cls()
50
- if isinstance(hparams, Mapping):
51
- hparams = hparams_cls.model_validate(hparams)
52
- elif not isinstance(hparams, hparams_cls):
53
- raise TypeError(
54
- f"Expected hparams to be either a Mapping or an instance of {hparams_cls}, got {type(hparams)}"
55
- )
56
- hparams = hparams.model_deep_validate()
57
- self.save_hyperparameters(hparams)
--- nshtrainer-1.0.0b10/src/nshtrainer/scripts/find_packages.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from __future__ import annotations
-
-import argparse
-import ast
-import glob
-import sys
-from pathlib import Path
-
-
-def get_imports(file_path: Path):
-    with open(file_path, "r") as file:
-        try:
-            tree = ast.parse(file.read())
-        except SyntaxError:
-            print(f"Syntax error in file: {file_path}", file=sys.stderr)
-            return set()
-
-    imports = set()
-    for node in ast.walk(tree):
-        if isinstance(node, ast.Import):
-            for alias in node.names:
-                imports.add(alias.name.split(".")[0])
-        elif isinstance(node, ast.ImportFrom):
-            if node.level == 0 and node.module:  # Absolute import
-                imports.add(node.module.split(".")[0])
-    return imports
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Find unique Python packages used in files."
-    )
-    parser.add_argument("glob_pattern", help="Glob pattern to match files")
-    parser.add_argument(
-        "--exclude-std", action="store_true", help="Exclude Python standard libraries"
-    )
-    args = parser.parse_args()
-
-    all_imports = set()
-    for file_path in glob.glob(args.glob_pattern, recursive=True):
-        all_imports.update(get_imports(Path(file_path)))
-
-    if args.exclude_std:
-        std_libs = set(sys.stdlib_module_names)
-        all_imports = all_imports - std_libs
-
-    for package in sorted(all_imports):
-        print(package)
-
-
-if __name__ == "__main__":
-    main()