nshtrainer 1.0.0b13__tar.gz → 1.0.0b14__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (145)
  1. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/PKG-INFO +1 -1
  2. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/pyproject.toml +1 -1
  3. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +3 -3
  4. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/__init__.py +2 -37
  5. nshtrainer-1.0.0b14/src/nshtrainer/configs/_checkpoint/__init__.py +31 -0
  6. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/trainer/__init__.py +0 -8
  7. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -7
  8. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/_config.py +0 -7
  9. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/trainer.py +1 -11
  10. nshtrainer-1.0.0b13/src/nshtrainer/_checkpoint/loader.py +0 -387
  11. nshtrainer-1.0.0b13/src/nshtrainer/configs/_checkpoint/__init__.py +0 -70
  12. nshtrainer-1.0.0b13/src/nshtrainer/configs/_checkpoint/loader/__init__.py +0 -62
  13. nshtrainer-1.0.0b13/src/nshtrainer/configs/trainer/checkpoint_connector/__init__.py +0 -26
  14. nshtrainer-1.0.0b13/src/nshtrainer/trainer/checkpoint_connector.py +0 -86
  15. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/README.md +0 -0
  16. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/__init__.py +0 -0
  17. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_callback.py +0 -0
  18. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  19. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_checkpoint/saver.py +0 -0
  20. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_directory.py +0 -0
  21. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_experimental/__init__.py +0 -0
  22. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_hf_hub.py +0 -0
  23. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/__init__.py +0 -0
  24. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/actsave.py +0 -0
  25. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/base.py +0 -0
  26. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  27. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  28. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  29. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  30. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  31. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  32. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  33. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/ema.py +0 -0
  34. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  35. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  36. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/interval.py +0 -0
  37. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  38. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
  39. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  40. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/print_table.py +0 -0
  41. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  42. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  43. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/timer.py +0 -0
  44. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  45. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  46. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  47. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  48. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  49. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
  50. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  51. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  52. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
  53. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  54. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  55. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  56. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  57. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  58. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  59. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  60. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  61. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  62. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  63. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  64. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  65. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  66. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  67. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  68. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  69. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  70. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  71. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  72. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  73. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/_base/__init__.py +0 -0
  74. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  75. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  76. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  77. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  78. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  79. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -0
  80. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  81. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  82. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  83. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  84. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/nn/__init__.py +0 -0
  85. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  86. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  87. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  88. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  89. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  90. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  91. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  92. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  93. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  94. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/__init__.py +0 -0
  95. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  96. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  97. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  98. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  99. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/data/__init__.py +0 -0
  100. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  101. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/data/datamodule.py +0 -0
  102. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/data/transform.py +0 -0
  103. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/__init__.py +0 -0
  104. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/_base.py +0 -0
  105. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/actsave.py +0 -0
  106. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/csv.py +0 -0
  107. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/tensorboard.py +0 -0
  108. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/wandb.py +0 -0
  109. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  110. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  111. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  112. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  113. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/metrics/__init__.py +0 -0
  114. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/metrics/_config.py +0 -0
  115. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/model/__init__.py +0 -0
  116. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/model/base.py +0 -0
  117. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/model/mixins/callback.py +0 -0
  118. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/model/mixins/debug.py +0 -0
  119. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/model/mixins/logger.py +0 -0
  120. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/__init__.py +0 -0
  121. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/mlp.py +0 -0
  122. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/module_dict.py +0 -0
  123. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/module_list.py +0 -0
  124. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/nonlinearity.py +0 -0
  125. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/optimizer.py +0 -0
  126. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/__init__.py +0 -0
  127. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/_base.py +0 -0
  128. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/advanced.py +0 -0
  129. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/pytorch.py +0 -0
  130. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/simple.py +0 -0
  131. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/__init__.py +0 -0
  132. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  133. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/signal_connector.py +0 -0
  134. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/_environment_info.py +0 -0
  135. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/_useful_types.py +0 -0
  136. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/bf16.py +0 -0
  137. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/config/__init__.py +0 -0
  138. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/config/dtype.py +0 -0
  139. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/config/duration.py +0 -0
  140. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/environment.py +0 -0
  141. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/path.py +0 -0
  142. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/seed.py +0 -0
  143. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/slurm.py +0 -0
  144. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/typed.py +0 -0
  145. {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 1.0.0b13
+Version: 1.0.0b14
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "1.0.0-beta13"
+version = "1.0.0-beta14"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -96,8 +96,8 @@ class OnExceptionCheckpointCallback(_OnExceptionCheckpoint):
     def on_exception(self, trainer: LightningTrainer, *args: Any, **kwargs: Any):
         # Monkey-patch the strategy instance to make the barrier operation a no-op.
         # We do this because `save_checkpoint` calls `barrier`. This is okay in most
-        # cases, but when we want to save a checkpoint in the case of an exception,
-        # `barrier` causes a deadlock. So we monkey-patch the strategy instance to
-        # make the barrier operation a no-op.
+        # cases, but when we want to save a checkpoint in the case of an exception,
+        # `barrier` causes a deadlock. So we monkey-patch the strategy instance to
+        # make the barrier operation a no-op.
         with _monkey_patch_disable_barrier(trainer):
             return super().on_exception(trainer, *args, **kwargs)
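
The comment above explains the motivation: `save_checkpoint` calls `strategy.barrier`, and when an exception has been raised on only some ranks, the other ranks may never reach that barrier, so waiting deadlocks. Below is a minimal sketch of what such a barrier-disabling context manager can look like, assuming a Lightning-style `trainer.strategy` object; it is an illustration, not nshtrainer's actual implementation.

# Sketch only: assumes a Lightning-style trainer with a `strategy` attribute;
# not nshtrainer's actual _monkey_patch_disable_barrier.
import contextlib


@contextlib.contextmanager
def _monkey_patch_disable_barrier(trainer):
    strategy = trainer.strategy
    original_barrier = strategy.barrier

    def _noop_barrier(*args, **kwargs):
        # Deliberately skip synchronization: during exception handling,
        # other ranks may never reach the barrier, so waiting would deadlock.
        pass

    # Shadow the bound method with an instance attribute...
    strategy.barrier = _noop_barrier
    try:
        yield
    finally:
        # ...and always restore the real barrier, even if saving fails.
        strategy.barrier = original_barrier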
@@ -9,19 +9,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from nshtrainer import MetricConfig as MetricConfig
     from nshtrainer import TrainerConfig as TrainerConfig
-    from nshtrainer._checkpoint.loader import (
-        BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
-    )
-    from nshtrainer._checkpoint.loader import (
-        CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
-    )
-    from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
-    from nshtrainer._checkpoint.loader import (
-        LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
-    )
-    from nshtrainer._checkpoint.loader import (
-        UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
-    )
+    from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
     from nshtrainer._directory import DirectoryConfig as DirectoryConfig
     from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
     from nshtrainer._hf_hub import (
@@ -122,9 +110,6 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         CheckpointCallbackConfig as CheckpointCallbackConfig,
     )
-    from nshtrainer.trainer._config import (
-        CheckpointLoadingConfig as CheckpointLoadingConfig,
-    )
     from nshtrainer.trainer._config import (
         CheckpointSavingConfig as CheckpointSavingConfig,
     )
@@ -199,21 +184,13 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).BestCheckpointCallbackConfig
-        if name == "BestCheckpointStrategyConfig":
-            return importlib.import_module(
-                "nshtrainer._checkpoint.loader"
-            ).BestCheckpointStrategyConfig
         if name == "CSVLoggerConfig":
             return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
         if name == "CallbackConfigBase":
             return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
-        if name == "CheckpointLoadingConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).CheckpointLoadingConfig
         if name == "CheckpointMetadata":
             return importlib.import_module(
-                "nshtrainer._checkpoint.loader"
+                "nshtrainer._checkpoint.metadata"
             ).CheckpointMetadata
         if name == "CheckpointSavingConfig":
             return importlib.import_module(
@@ -317,10 +294,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).LastCheckpointCallbackConfig
-        if name == "LastCheckpointStrategyConfig":
-            return importlib.import_module(
-                "nshtrainer._checkpoint.loader"
-            ).LastCheckpointStrategyConfig
         if name == "LeakyReLUNonlinearityConfig":
             return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
         if name == "LearningRateMonitorConfig":
@@ -403,10 +376,6 @@ else:
             return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
         if name == "TrainerConfig":
             return importlib.import_module("nshtrainer").TrainerConfig
-        if name == "UserProvidedPathCheckpointStrategyConfig":
-            return importlib.import_module(
-                "nshtrainer._checkpoint.loader"
-            ).UserProvidedPathCheckpointStrategyConfig
         if name == "WandbLoggerConfig":
             return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
         if name == "WandbUploadCodeCallbackConfig":
@@ -423,10 +392,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).CheckpointCallbackConfig
-        if name == "CheckpointLoadingStrategyConfig":
-            return importlib.import_module(
-                "nshtrainer._checkpoint.loader"
-            ).CheckpointLoadingStrategyConfig
         if name == "DurationConfig":
             return importlib.import_module("nshtrainer.util.config").DurationConfig
         if name == "LRSchedulerConfig":
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from typing import TYPE_CHECKING
+
+# Config/alias imports
+
+if TYPE_CHECKING:
+    from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
+    from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
+else:
+
+    def __getattr__(name):
+        import importlib
+
+        if name in globals():
+            return globals()[name]
+        if name == "CheckpointMetadata":
+            return importlib.import_module(
+                "nshtrainer._checkpoint.metadata"
+            ).CheckpointMetadata
+        if name == "EnvironmentConfig":
+            return importlib.import_module(
+                "nshtrainer._checkpoint.metadata"
+            ).EnvironmentConfig
+        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+# Submodule exports
+from . import metadata as metadata
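
The new generated module above relies on the module-level `__getattr__` hook from PEP 562: type checkers see the real imports in the `TYPE_CHECKING` branch, while at runtime the import is deferred until an attribute is first accessed. A standalone illustration of the pattern follows; the module and imported names are hypothetical, only the structure mirrors the generated file.

# lazy_mod.py -- PEP 562 lazy-import pattern; names here are hypothetical.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static analysis sees an ordinary import...
    from json import JSONDecoder as JSONDecoder
else:

    def __getattr__(name):
        # ...but at runtime nothing is imported until first attribute access.
        import importlib

        if name == "JSONDecoder":
            return importlib.import_module("json").JSONDecoder
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

With this in place, `from lazy_mod import JSONDecoder` triggers the `__getattr__` hook and imports `json` only at that moment, so `import lazy_mod` itself stays cheap.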
@@ -18,9 +18,6 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         CheckpointCallbackConfig as CheckpointCallbackConfig,
     )
-    from nshtrainer.trainer._config import (
-        CheckpointLoadingConfig as CheckpointLoadingConfig,
-    )
     from nshtrainer.trainer._config import (
         CheckpointSavingConfig as CheckpointSavingConfig,
     )
@@ -91,10 +88,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).CallbackConfigBase
-        if name == "CheckpointLoadingConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).CheckpointLoadingConfig
         if name == "CheckpointSavingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -180,5 +173,4 @@ else:
 
 # Submodule exports
 from . import _config as _config
-from . import checkpoint_connector as checkpoint_connector
 from . import trainer as trainer
@@ -17,9 +17,6 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         CheckpointCallbackConfig as CheckpointCallbackConfig,
     )
-    from nshtrainer.trainer._config import (
-        CheckpointLoadingConfig as CheckpointLoadingConfig,
-    )
     from nshtrainer.trainer._config import (
         CheckpointSavingConfig as CheckpointSavingConfig,
     )
@@ -91,10 +88,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).CallbackConfigBase
-        if name == "CheckpointLoadingConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).CheckpointLoadingConfig
         if name == "CheckpointSavingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -32,7 +32,6 @@ from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.strategies.strategy import Strategy
 from typing_extensions import TypedDict, TypeVar, override
 
-from .._checkpoint.loader import CheckpointLoadingConfig
 from .._directory import DirectoryConfig
 from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
@@ -493,12 +492,6 @@ class TrainerConfig(C.Config):
     ckpt_path: Literal["none"] | str | Path | None = None
     """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""
 
-    checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
-    """Checkpoint loading configuration options.
-    `"auto"` will automatically determine the best checkpoint loading strategy based on the provided.
-    `"none"` will disable checkpoint loading.
-    """
-
     checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
     """Checkpoint saving configuration options."""
 
@@ -29,7 +29,6 @@ from ._config import (
     TrainerConfig,
 )
 from ._runtime_callback import RuntimeTrackerCallback, Stage
-from .checkpoint_connector import _CheckpointConnector
 from .signal_connector import _SignalConnector
 
 log = logging.getLogger(__name__)
@@ -314,9 +313,6 @@ class Trainer(LightningTrainer):
         # Replace the signal connector with our own.
         self._signal_connector = _SignalConnector(self)
 
-        # Replace the checkpoint connector with our own.
-        self._checkpoint_connector = _CheckpointConnector(self)
-
         # Print out the log dir, so that we can easily find it in the logs.
         if log_dir := self.log_dir:
             log_dir = str(Path(log_dir).resolve())
@@ -441,19 +437,13 @@ class Trainer(LightningTrainer):
     ):
         filepath = Path(filepath)
 
-        # List of files that we should upload to HF
-        written_files: list[Path] = [filepath]
-
         super().save_checkpoint(filepath, weights_only, storage_options)
 
         # Save the checkpoint metadata
        metadata_path = None
         if self.hparams.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
-            if (
-                metadata_path := _write_checkpoint_metadata(self, filepath)
-            ) is not None:
-                written_files.append(metadata_path)
+            metadata_path = _write_checkpoint_metadata(self, filepath)
 
         # Call the `on_checkpoint_saved` method on all callbacks
         from .. import _callback
@@ -1,387 +0,0 @@
-from __future__ import annotations
-
-import logging
-from collections.abc import Iterable, Sequence
-from dataclasses import dataclass
-from pathlib import Path
-from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, overload
-
-import nshconfig as C
-from lightning.pytorch.trainer.states import TrainerFn
-from typing_extensions import assert_never
-
-from ..metrics._config import MetricConfig
-from .metadata import CheckpointMetadata
-
-if TYPE_CHECKING:
-    from ..trainer import Trainer
-    from ..trainer._config import TrainerConfig
-
-log = logging.getLogger(__name__)
-
-
-class BestCheckpointStrategyConfig(C.Config):
-    name: Literal["best"] = "best"
-
-    metric: MetricConfig | None = None
-    """The metric to use for selecting the best checkpoint. If `None`, the primary metric will be used."""
-
-    additional_candidates: Iterable[Path] = []
-    """Additional checkpoint candidates to consider when selecting the last checkpoint."""
-
-
-class UserProvidedPathCheckpointStrategyConfig(C.Config):
-    name: Literal["user_provided_path"] = "user_provided_path"
-
-    path: Path
-    """The path to the checkpoint to load."""
-
-    on_error: Literal["warn", "raise"] = "warn"
-    """The behavior when the checkpoint does not belong to the current run.
-
-    - `warn`: Log a warning and skip the checkpoint.
-    - `raise`: Raise an error.
-    """
-
-
-class LastCheckpointStrategyConfig(C.Config):
-    name: Literal["last"] = "last"
-
-    criterion: Literal["global_step", "runtime"] = "global_step"
-    """The criterion to use for selecting the last checkpoint.
-
-    - `global_step`: The checkpoint with the highest global step will be selected.
-    - `runtime`: The checkpoint with the highest runtime will be selected.
-    """
-
-    additional_candidates: Iterable[Path] = []
-    """Additional checkpoint candidates to consider when selecting the last checkpoint."""
-
-
-CheckpointLoadingStrategyConfig: TypeAlias = Annotated[
-    BestCheckpointStrategyConfig
-    | LastCheckpointStrategyConfig
-    | UserProvidedPathCheckpointStrategyConfig,
-    C.Field(discriminator="name"),
-]
-
-
-class CheckpointLoadingConfig(C.Config):
-    strategies: Sequence[CheckpointLoadingStrategyConfig]
-    """The strategies to use for loading checkpoints.
-
-    The order of the strategies determines the priority of the strategies.
-    The first strategy that resolves a checkpoint will be used.
-    """
-
-    include_hpc: bool
-    """Whether to include checkpoints from HPC pre-emption."""
-
-    @classmethod
-    def none(cls, include_hpc: bool = False):
-        return cls(strategies=[], include_hpc=include_hpc)
-
-    @classmethod
-    def _auto_train(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
-        if ckpt is None:
-            ckpt = "last"
-        match ckpt:
-            case "best":
-                return cls(
-                    strategies=[BestCheckpointStrategyConfig()],
-                    include_hpc=True,
-                )
-            case "last":
-                return cls(
-                    strategies=[LastCheckpointStrategyConfig()],
-                    include_hpc=True,
-                )
-            case "none":
-                return cls.none()
-            case Path() | str():
-                ckpt = Path(ckpt)
-                return cls(
-                    strategies=[
-                        LastCheckpointStrategyConfig(additional_candidates=[ckpt]),
-                        UserProvidedPathCheckpointStrategyConfig(path=ckpt),
-                    ],
-                    include_hpc=True,
-                )
-            case _:
-                assert_never(ckpt)
-
-    @classmethod
-    def _auto_eval(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
-        if ckpt is None:
-            log.warn("No checkpoint specified for evaluation. Defaulting to `last`.")
-            ckpt = "last"
-
-        match ckpt:
-            case "best":
-                return cls(
-                    strategies=[BestCheckpointStrategyConfig()],
-                    include_hpc=False,
-                )
-            case "last":
-                return cls(
-                    strategies=[LastCheckpointStrategyConfig()],
-                    include_hpc=False,
-                )
-            case "none":
-                return cls.none(include_hpc=False)
-            case Path() | str():
-                ckpt = Path(ckpt)
-                return cls(
-                    strategies=[UserProvidedPathCheckpointStrategyConfig(path=ckpt)],
-                    include_hpc=False,
-                )
-            case _:
-                assert_never(ckpt)
-
-    @classmethod
-    def auto(
-        cls,
-        ckpt: Literal["best", "last", "none"] | str | Path | None,
-        trainer_mode: TrainerFn,
-    ):
-        """
-        Automatically create a CheckpointLoadingConfig based on the provided checkpoint option and trainer mode.
-
-        This method provides a convenient way to generate a checkpoint loading configuration
-        tailored to different training and evaluation scenarios.
-
-        Parameters:
-        -----------
-        ckpt : Literal["best", "last", "none"] | str | Path | None
-            Specifies the checkpoint loading preference:
-            - "best": Use the best checkpoint based on the primary metric.
-            - "last": Use the most recent checkpoint.
-            - str or Path: Path to a specific checkpoint file.
-            - None: Defaults to "last" for training, raises an error for evaluation.
-
-        trainer_mode : TrainerFn
-            The mode in which the trainer is operating. This affects how the configuration is created.
-            - TrainerFn.FITTING: Used for training scenarios.
-            - TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING: Used for evaluation scenarios.
-
-        Returns:
-        --------
-        CheckpointLoadingConfig
-            A configuration object for checkpoint loading based on the given parameters.
-
-        Behavior:
-        ---------
-        1. For training (TrainerFn.FITTING):
-           - Includes HPC pre-emption checkpoints.
-           - If ckpt is None, defaults to "last".
-           - For "best" or "last", creates a single-strategy configuration that loads the best or last checkpoint.
-           - For a specific path, creates a two-strategy configuration:
-             a) Tries to load the checkpoint as the last checkpoint.
-             b) Falls back to loading it as a user-provided path.
-
-        2. For evaluation (VALIDATING, TESTING, PREDICTING):
-           - Does not include HPC pre-emption checkpoints.
-           - Requires ckpt to be specified (raises ValueError if None).
-           - Creates a single-strategy configuration based on the ckpt value.
-
-        Raises:
-        -------
-        ValueError
-            If ckpt is None during evaluation modes.
-
-        Examples:
-        ---------
-        # Training mode, use last checkpoint
-        config = CheckpointLoadingConfig.auto("last", TrainerFn.FITTING)
-
-        # Evaluation mode, use best checkpoint
-        config = CheckpointLoadingConfig.auto("best", TrainerFn.TESTING)
-
-        # Training mode, use specific checkpoint
-        config = CheckpointLoadingConfig.auto("/path/to/checkpoint.ckpt", TrainerFn.FITTING)
-
-        Notes:
-        ------
-        - The method internally calls _auto_train or _auto_eval based on the trainer_mode.
-        - The resulting configuration always includes strategies as a sequence, even if there's only one strategy.
-        """
-        # Implementation remains the same...
-        match trainer_mode:
-            case TrainerFn.FITTING:
-                return cls._auto_train(ckpt)
-            case TrainerFn.VALIDATING | TrainerFn.TESTING | TrainerFn.PREDICTING:
-                return cls._auto_eval(ckpt)
-            case _:
-                assert_never(trainer_mode)
-
-
-@dataclass
-class _CkptCandidate:
-    meta: CheckpointMetadata
-    meta_path: Path
-
-    @property
-    def ckpt_path(self):
-        return self.meta_path.with_name(self.meta.checkpoint_filename)
-
-
-@overload
-def _load_ckpt_meta(
-    path: Path,
-    trainer_config: TrainerConfig,
-    on_error: Literal["warn"] = "warn",
-) -> _CkptCandidate | None: ...
-@overload
-def _load_ckpt_meta(
-    path: Path,
-    trainer_config: TrainerConfig,
-    on_error: Literal["raise"],
-) -> _CkptCandidate: ...
-def _load_ckpt_meta(
-    path: Path,
-    trainer_config: TrainerConfig,
-    on_error: Literal["warn", "raise"] = "warn",
-):
-    meta = CheckpointMetadata.from_file(path)
-    if trainer_config.id != meta.run_id:
-        error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
-        match on_error:
-            case "warn":
-                log.warning(error_msg)
-            case "raise":
-                raise ValueError(error_msg)
-            case _:
-                assert_never(on_error)
-        return None
-    return _CkptCandidate(meta, path)
-
-
-def _checkpoint_candidates(trainer: Trainer, *, include_hpc: bool = True):
-    # Load the checkpoint directory, and throw if it doesn't exist.
-    # This indicates a non-standard setup, and we don't want to guess
-    # where the checkpoints are.
-    ckpt_dir = trainer.hparams.directory.resolve_subdirectory(
-        trainer.hparams.id, "checkpoint"
-    )
-    if not ckpt_dir.is_dir():
-        raise FileNotFoundError(
-            f"Checkpoint directory {ckpt_dir} not found. "
-            "Please ensure that the checkpoint directory exists."
-        )
-
-    # Load all checkpoints in the directory.
-    # We can do this by looking for metadata files.
-    for path in ckpt_dir.glob(f"*{CheckpointMetadata.PATH_SUFFIX}"):
-        if (meta := _load_ckpt_meta(path, trainer.hparams)) is not None:
-            yield meta
-
-    # If we have a pre-empted checkpoint, load it
-    if include_hpc and (hpc_path := trainer._checkpoint_connector._hpc_resume_path):
-        hpc_meta_path = Path(hpc_path).with_suffix(CheckpointMetadata.PATH_SUFFIX)
-        if (meta := _load_ckpt_meta(hpc_meta_path, trainer.hparams)) is not None:
-            yield meta
-
-
-def _additional_candidates(
-    additional_candidates: Iterable[Path], trainer_config: TrainerConfig
-):
-    for path in additional_candidates:
-        if (
-            meta := _load_ckpt_meta(
-                path.with_suffix(CheckpointMetadata.PATH_SUFFIX), trainer_config
-            )
-        ) is None:
-            continue
-        yield meta
-
-
-def _resolve_checkpoint(config: CheckpointLoadingConfig, trainer: Trainer):
-    # We lazily load the checkpoint candidates to avoid loading them
-    # if they are not needed.
-    _ckpt_candidates: list[_CkptCandidate] | None = None
-
-    def ckpt_candidates():
-        nonlocal _ckpt_candidates, trainer
-
-        if _ckpt_candidates is None:
-            _ckpt_candidates = list(
-                _checkpoint_candidates(trainer, include_hpc=config.include_hpc)
-            )
-        return _ckpt_candidates
-
-    # Iterate over the strategies and try to resolve the checkpoint.
-    for strategy in config.strategies:
-        match strategy:
-            case UserProvidedPathCheckpointStrategyConfig():
-                meta = _load_ckpt_meta(
-                    strategy.path.with_suffix(CheckpointMetadata.PATH_SUFFIX),
-                    trainer.hparams,
-                    on_error=strategy.on_error,
-                )
-                if meta is None:
-                    continue
-                return meta.ckpt_path
-            case BestCheckpointStrategyConfig():
-                candidates = [
-                    *ckpt_candidates(),
-                    *_additional_candidates(
-                        strategy.additional_candidates, trainer.hparams
-                    ),
-                ]
-                if not candidates:
-                    log.warning(
-                        "No checkpoint candidates found for `best` checkpoint strategy."
-                    )
-                    continue
-
-                if (
-                    metric := strategy.metric or trainer.hparams.primary_metric
-                ) is None:
-                    log.warning(
-                        "No metric specified for `best` checkpoint strategy, "
-                        "and no primary metric is set in the configuration. "
-                        "Skipping strategy."
-                    )
-                    continue
-
-                # Find the best checkpoint based on the metric.
-                def metric_value(ckpt: _CkptCandidate):
-                    assert metric is not None
-                    if (
-                        value := ckpt.meta.metrics.get(metric.validation_monitor)
-                    ) is None:
-                        raise ValueError(
-                            f"Metric {metric.validation_monitor} not found in checkpoint metadata. "
-                            f"Available metrics: {ckpt.meta.metrics.keys()}"
-                        )
-                    return value
-
-                best_candidate = metric.best(candidates, key=metric_value)
-                return best_candidate.ckpt_path
-            case LastCheckpointStrategyConfig():
-                candidates = [
-                    *ckpt_candidates(),
-                    *_additional_candidates(
-                        strategy.additional_candidates, trainer.hparams
-                    ),
-                ]
-                if not candidates:
-                    log.warning(
-                        "No checkpoint candidates found for `last` checkpoint strategy."
-                    )
-                    continue
-
-                # Find the last checkpoint based on the criterion.
-                def criterion_value(ckpt: _CkptCandidate):
-                    match strategy.criterion:
-                        case "global_step":
-                            return ckpt.meta.global_step
-                        case "runtime":
-                            return ckpt.meta.training_time.total_seconds()
-                        case _:
-                            assert_never(strategy.criterion)
-
-                last_candidate = max(candidates, key=criterion_value)
-                return last_candidate.ckpt_path
-            case _:
-                assert_never(strategy)