nshtrainer 1.0.0b26__tar.gz → 1.0.0b27__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/PKG-INFO +1 -1
  2. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/pyproject.toml +1 -1
  3. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/actsave.py +2 -2
  4. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/base.py +5 -3
  5. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/shared_parameters.py +5 -3
  6. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/__init__.py +4 -0
  7. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/__init__.py +4 -0
  8. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +6 -0
  9. nshtrainer-1.0.0b27/src/nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +19 -0
  10. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/trainer/__init__.py +4 -0
  11. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/trainer/_config/__init__.py +4 -0
  12. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/loggers/__init__.py +12 -5
  13. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/lr_scheduler/__init__.py +9 -5
  14. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/model/mixins/callback.py +6 -4
  15. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/optimizer.py +5 -3
  16. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/profiler/__init__.py +9 -5
  17. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/trainer/_config.py +47 -42
  18. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/trainer/_runtime_callback.py +3 -3
  19. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/trainer/signal_connector.py +6 -4
  20. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/_useful_types.py +11 -2
  21. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/config/dtype.py +46 -43
  22. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/path.py +3 -2
  23. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/README.md +0 -0
  24. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/.nshconfig.generated.json +0 -0
  25. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/__init__.py +0 -0
  26. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/_callback.py +0 -0
  27. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  28. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/_checkpoint/saver.py +0 -0
  29. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/_directory.py +0 -0
  30. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/_experimental/__init__.py +0 -0
  31. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/_hf_hub.py +0 -0
  32. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/__init__.py +0 -0
  33. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  34. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  35. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  36. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  37. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  38. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/checkpoint/time_checkpoint.py +0 -0
  39. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  40. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  41. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  42. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/ema.py +0 -0
  43. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  44. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  45. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/interval.py +0 -0
  46. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  47. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
  48. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  49. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/print_table.py +0 -0
  50. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  51. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/timer.py +0 -0
  52. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  53. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  54. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
  55. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
  56. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/_directory/__init__.py +0 -0
  57. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
  58. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
  59. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
  60. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
  61. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
  62. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
  63. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
  64. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
  65. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
  66. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
  67. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
  68. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
  69. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
  70. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
  71. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
  72. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
  73. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
  74. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
  75. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
  76. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
  77. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
  78. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
  79. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/loggers/__init__.py +0 -0
  80. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/loggers/_base/__init__.py +0 -0
  81. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
  82. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
  83. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
  84. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
  85. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
  86. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -0
  87. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
  88. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
  89. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/metrics/__init__.py +0 -0
  90. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
  91. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/nn/__init__.py +0 -0
  92. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
  93. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
  94. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
  95. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/profiler/__init__.py +0 -0
  96. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
  97. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
  98. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
  99. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
  100. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
  101. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/util/__init__.py +0 -0
  102. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
  103. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/util/config/__init__.py +0 -0
  104. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
  105. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
  106. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/data/__init__.py +0 -0
  107. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  108. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/data/datamodule.py +0 -0
  109. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/data/transform.py +0 -0
  110. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/loggers/_base.py +0 -0
  111. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/loggers/actsave.py +0 -0
  112. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/loggers/csv.py +0 -0
  113. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/loggers/tensorboard.py +0 -0
  114. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/loggers/wandb.py +0 -0
  115. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  116. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  117. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  118. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/metrics/__init__.py +0 -0
  119. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/metrics/_config.py +0 -0
  120. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/model/__init__.py +0 -0
  121. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/model/base.py +0 -0
  122. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/model/mixins/debug.py +0 -0
  123. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/model/mixins/logger.py +0 -0
  124. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/nn/__init__.py +0 -0
  125. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/nn/mlp.py +0 -0
  126. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/nn/module_dict.py +0 -0
  127. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/nn/module_list.py +0 -0
  128. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/nn/nonlinearity.py +0 -0
  129. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/profiler/_base.py +0 -0
  130. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/profiler/advanced.py +0 -0
  131. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/profiler/pytorch.py +0 -0
  132. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/profiler/simple.py +0 -0
  133. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/trainer/__init__.py +0 -0
  134. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/trainer/trainer.py +0 -0
  135. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/_environment_info.py +0 -0
  136. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/bf16.py +0 -0
  137. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/config/__init__.py +0 -0
  138. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/config/duration.py +0 -0
  139. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/environment.py +0 -0
  140. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/seed.py +0 -0
  141. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/slurm.py +0 -0
  142. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/typed.py +0 -0
  143. {nshtrainer-1.0.0b26 → nshtrainer-1.0.0b27}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 1.0.0b26
3
+ Version: 1.0.0b27
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "1.0.0-beta26"
3
+ version = "1.0.0-beta27"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -4,12 +4,12 @@ import contextlib
4
4
  from pathlib import Path
5
5
  from typing import Literal
6
6
 
7
- from typing_extensions import TypeAlias, override
7
+ from typing_extensions import TypeAliasType, override
8
8
 
9
9
  from .._callback import NTCallbackBase
10
10
  from .base import CallbackConfigBase
11
11
 
12
- Stage: TypeAlias = Literal["train", "validation", "test", "predict"]
12
+ Stage = TypeAliasType("Stage", Literal["train", "validation", "test", "predict"])
13
13
 
14
14
 
15
15
  class ActSaveConfig(CallbackConfigBase):
@@ -4,11 +4,11 @@ from abc import ABC, abstractmethod
4
4
  from collections import Counter
5
5
  from collections.abc import Iterable
6
6
  from dataclasses import dataclass
7
- from typing import TYPE_CHECKING, ClassVar, TypeAlias
7
+ from typing import TYPE_CHECKING, ClassVar
8
8
 
9
9
  import nshconfig as C
10
10
  from lightning.pytorch import Callback
11
- from typing_extensions import TypedDict, Unpack
11
+ from typing_extensions import TypeAliasType, TypedDict, Unpack
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  from ..trainer._config import TrainerConfig
@@ -30,7 +30,9 @@ class CallbackWithMetadata:
30
30
  metadata: CallbackMetadataConfig
31
31
 
32
32
 
33
- ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
33
+ ConstructedCallback = TypeAliasType(
34
+ "ConstructedCallback", Callback | CallbackWithMetadata
35
+ )
34
36
 
35
37
 
36
38
  class CallbackConfigBase(C.Config, ABC):
@@ -2,12 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  from collections.abc import Iterable
5
- from typing import Literal, Protocol, TypeAlias, runtime_checkable
5
+ from typing import Literal, Protocol, runtime_checkable
6
6
 
7
7
  import torch.nn as nn
8
8
  from lightning.pytorch import LightningModule, Trainer
9
9
  from lightning.pytorch.callbacks import Callback
10
- from typing_extensions import override
10
+ from typing_extensions import TypeAliasType, override
11
11
 
12
12
  from .base import CallbackConfigBase
13
13
 
@@ -34,7 +34,9 @@ class SharedParametersCallbackConfig(CallbackConfigBase):
34
34
  yield SharedParametersCallback(self)
35
35
 
36
36
 
37
- SharedParametersList: TypeAlias = list[tuple[nn.Parameter, int | float]]
37
+ SharedParametersList = TypeAliasType(
38
+ "SharedParametersList", list[tuple[nn.Parameter, int | float]]
39
+ )
38
40
 
39
41
 
40
42
  @runtime_checkable
@@ -46,6 +46,9 @@ from nshtrainer.callbacks import (
46
46
  from nshtrainer.callbacks import (
47
47
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
48
48
  )
49
+ from nshtrainer.callbacks import (
50
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
51
+ )
49
52
  from nshtrainer.callbacks import (
50
53
  WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
51
54
  )
@@ -217,6 +220,7 @@ __all__ = [
217
220
  "SwishNonlinearityConfig",
218
221
  "TanhNonlinearityConfig",
219
222
  "TensorboardLoggerConfig",
223
+ "TimeCheckpointCallbackConfig",
220
224
  "TrainerConfig",
221
225
  "WandbLoggerConfig",
222
226
  "WandbUploadCodeCallbackConfig",
@@ -38,6 +38,9 @@ from nshtrainer.callbacks import (
38
38
  from nshtrainer.callbacks import (
39
39
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
40
40
  )
41
+ from nshtrainer.callbacks import (
42
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
43
+ )
41
44
  from nshtrainer.callbacks import (
42
45
  WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
43
46
  )
@@ -95,6 +98,7 @@ __all__ = [
95
98
  "PrintTableMetricsCallbackConfig",
96
99
  "RLPSanityChecksCallbackConfig",
97
100
  "SharedParametersCallbackConfig",
101
+ "TimeCheckpointCallbackConfig",
98
102
  "WandbUploadCodeCallbackConfig",
99
103
  "WandbWatchCallbackConfig",
100
104
  "actsave",
@@ -11,6 +11,9 @@ from nshtrainer.callbacks.checkpoint import (
11
11
  from nshtrainer.callbacks.checkpoint import (
12
12
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
13
13
  )
14
+ from nshtrainer.callbacks.checkpoint import (
15
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
16
+ )
14
17
  from nshtrainer.callbacks.checkpoint._base import (
15
18
  BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
16
19
  )
@@ -26,6 +29,7 @@ from . import _base as _base
26
29
  from . import best_checkpoint as best_checkpoint
27
30
  from . import last_checkpoint as last_checkpoint
28
31
  from . import on_exception_checkpoint as on_exception_checkpoint
32
+ from . import time_checkpoint as time_checkpoint
29
33
 
30
34
  __all__ = [
31
35
  "BaseCheckpointCallbackConfig",
@@ -35,8 +39,10 @@ __all__ = [
35
39
  "LastCheckpointCallbackConfig",
36
40
  "MetricConfig",
37
41
  "OnExceptionCheckpointCallbackConfig",
42
+ "TimeCheckpointCallbackConfig",
38
43
  "_base",
39
44
  "best_checkpoint",
40
45
  "last_checkpoint",
41
46
  "on_exception_checkpoint",
47
+ "time_checkpoint",
42
48
  ]
@@ -0,0 +1,19 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.callbacks.checkpoint.time_checkpoint import (
6
+ BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
7
+ )
8
+ from nshtrainer.callbacks.checkpoint.time_checkpoint import (
9
+ CheckpointMetadata as CheckpointMetadata,
10
+ )
11
+ from nshtrainer.callbacks.checkpoint.time_checkpoint import (
12
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
13
+ )
14
+
15
+ __all__ = [
16
+ "BaseCheckpointCallbackConfig",
17
+ "CheckpointMetadata",
18
+ "TimeCheckpointCallbackConfig",
19
+ ]
@@ -48,6 +48,9 @@ from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
48
48
  from nshtrainer.trainer._config import (
49
49
  TensorboardLoggerConfig as TensorboardLoggerConfig,
50
50
  )
51
+ from nshtrainer.trainer._config import (
52
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
53
+ )
51
54
  from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
52
55
 
53
56
  from . import _config as _config
@@ -79,6 +82,7 @@ __all__ = [
79
82
  "SharedParametersCallbackConfig",
80
83
  "StrategyConfigBase",
81
84
  "TensorboardLoggerConfig",
85
+ "TimeCheckpointCallbackConfig",
82
86
  "TrainerConfig",
83
87
  "WandbLoggerConfig",
84
88
  "_config",
@@ -47,6 +47,9 @@ from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
47
47
  from nshtrainer.trainer._config import (
48
48
  TensorboardLoggerConfig as TensorboardLoggerConfig,
49
49
  )
50
+ from nshtrainer.trainer._config import (
51
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
52
+ )
50
53
  from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
51
54
  from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
52
55
 
@@ -76,6 +79,7 @@ __all__ = [
76
79
  "SharedParametersCallbackConfig",
77
80
  "StrategyConfigBase",
78
81
  "TensorboardLoggerConfig",
82
+ "TimeCheckpointCallbackConfig",
79
83
  "TrainerConfig",
80
84
  "WandbLoggerConfig",
81
85
  ]
@@ -1,8 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Annotated, TypeAlias
3
+ from typing import Annotated
4
4
 
5
5
  import nshconfig as C
6
+ from typing_extensions import TypeAliasType
6
7
 
7
8
  from ._base import BaseLoggerConfig as BaseLoggerConfig
8
9
  from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
@@ -10,7 +11,13 @@ from .csv import CSVLoggerConfig as CSVLoggerConfig
10
11
  from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
11
12
  from .wandb import WandbLoggerConfig as WandbLoggerConfig
12
13
 
13
- LoggerConfig: TypeAlias = Annotated[
14
- CSVLoggerConfig | TensorboardLoggerConfig | WandbLoggerConfig | ActSaveLoggerConfig,
15
- C.Field(discriminator="name"),
16
- ]
14
+ LoggerConfig = TypeAliasType(
15
+ "LoggerConfig",
16
+ Annotated[
17
+ CSVLoggerConfig
18
+ | TensorboardLoggerConfig
19
+ | WandbLoggerConfig
20
+ | ActSaveLoggerConfig,
21
+ C.Field(discriminator="name"),
22
+ ],
23
+ )
@@ -1,8 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Annotated, TypeAlias
3
+ from typing import Annotated
4
4
 
5
5
  import nshconfig as C
6
+ from typing_extensions import TypeAliasType
6
7
 
7
8
  from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
8
9
  from ._base import LRSchedulerMetadata as LRSchedulerMetadata
@@ -15,7 +16,10 @@ from .linear_warmup_cosine import (
15
16
  from .reduce_lr_on_plateau import ReduceLROnPlateau as ReduceLROnPlateau
16
17
  from .reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
17
18
 
18
- LRSchedulerConfig: TypeAlias = Annotated[
19
- LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
20
- C.Field(discriminator="name"),
21
- ]
19
+ LRSchedulerConfig = TypeAliasType(
20
+ "LRSchedulerConfig",
21
+ Annotated[
22
+ LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
23
+ C.Field(discriminator="name"),
24
+ ],
25
+ )
@@ -2,18 +2,20 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  from collections.abc import Callable, Iterable, Sequence
5
- from typing import Any, TypeAlias, cast
5
+ from typing import Any, cast
6
6
 
7
7
  from lightning.pytorch import Callback, LightningModule
8
- from typing_extensions import override
8
+ from typing_extensions import TypeAliasType, override
9
9
 
10
10
  from ..._callback import NTCallbackBase
11
11
  from ...util.typing_utils import mixin_base_type
12
12
 
13
13
  log = logging.getLogger(__name__)
14
14
 
15
- _Callback = Callback | NTCallbackBase
16
- CallbackFn: TypeAlias = Callable[[], _Callback | Iterable[_Callback] | None]
15
+ _Callback = TypeAliasType("_Callback", Callback | NTCallbackBase)
16
+ CallbackFn = TypeAliasType(
17
+ "CallbackFn", Callable[[], _Callback | Iterable[_Callback] | None]
18
+ )
17
19
 
18
20
 
19
21
  class CallbackRegistrarModuleMixin:
@@ -2,12 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from collections.abc import Iterable
5
- from typing import Annotated, Any, Literal, TypeAlias
5
+ from typing import Annotated, Any, Literal
6
6
 
7
7
  import nshconfig as C
8
8
  import torch.nn as nn
9
9
  from torch.optim import Optimizer
10
- from typing_extensions import override
10
+ from typing_extensions import TypeAliasType, override
11
11
 
12
12
 
13
13
  class OptimizerConfigBase(C.Config, ABC):
@@ -57,4 +57,6 @@ class AdamWConfig(OptimizerConfigBase):
57
57
  )
58
58
 
59
59
 
60
- OptimizerConfig: TypeAlias = Annotated[AdamWConfig, C.Field(discriminator="name")]
60
+ OptimizerConfig = TypeAliasType(
61
+ "OptimizerConfig", Annotated[AdamWConfig, C.Field(discriminator="name")]
62
+ )
@@ -1,15 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Annotated, TypeAlias
3
+ from typing import Annotated
4
4
 
5
5
  import nshconfig as C
6
+ from typing_extensions import TypeAliasType
6
7
 
7
8
  from ._base import BaseProfilerConfig as BaseProfilerConfig
8
9
  from .advanced import AdvancedProfilerConfig as AdvancedProfilerConfig
9
10
  from .pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
10
11
  from .simple import SimpleProfilerConfig as SimpleProfilerConfig
11
12
 
12
- ProfilerConfig: TypeAlias = Annotated[
13
- SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
14
- C.Field(discriminator="name"),
15
- ]
13
+ ProfilerConfig = TypeAliasType(
14
+ "ProfilerConfig",
15
+ Annotated[
16
+ SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
17
+ C.Field(discriminator="name"),
18
+ ],
19
+ )
@@ -14,7 +14,6 @@ from typing import (
14
14
  Any,
15
15
  ClassVar,
16
16
  Literal,
17
- TypeAlias,
18
17
  )
19
18
 
20
19
  import nshconfig as C
@@ -101,47 +100,53 @@ class StrategyConfigBase(C.Config, ABC):
101
100
  strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
102
101
 
103
102
 
104
- AcceleratorLiteral: TypeAlias = Literal[
105
- "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
106
- ]
107
-
108
- StrategyLiteral: TypeAlias = Literal[
109
- "auto",
110
- "ddp",
111
- "ddp_find_unused_parameters_false",
112
- "ddp_find_unused_parameters_true",
113
- "ddp_spawn",
114
- "ddp_spawn_find_unused_parameters_false",
115
- "ddp_spawn_find_unused_parameters_true",
116
- "ddp_fork",
117
- "ddp_fork_find_unused_parameters_false",
118
- "ddp_fork_find_unused_parameters_true",
119
- "ddp_notebook",
120
- "dp",
121
- "deepspeed",
122
- "deepspeed_stage_1",
123
- "deepspeed_stage_1_offload",
124
- "deepspeed_stage_2",
125
- "deepspeed_stage_2_offload",
126
- "deepspeed_stage_3",
127
- "deepspeed_stage_3_offload",
128
- "deepspeed_stage_3_offload_nvme",
129
- "fsdp",
130
- "fsdp_cpu_offload",
131
- "single_xla",
132
- "xla_fsdp",
133
- "xla",
134
- "single_tpu",
135
- ]
136
-
137
-
138
- CheckpointCallbackConfig: TypeAlias = Annotated[
139
- BestCheckpointCallbackConfig
140
- | LastCheckpointCallbackConfig
141
- | OnExceptionCheckpointCallbackConfig
142
- | TimeCheckpointCallbackConfig,
143
- C.Field(discriminator="name"),
144
- ]
103
+ AcceleratorLiteral = TypeAliasType(
104
+ "AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
105
+ )
106
+
107
+ StrategyLiteral = TypeAliasType(
108
+ "StrategyLiteral",
109
+ Literal[
110
+ "auto",
111
+ "ddp",
112
+ "ddp_find_unused_parameters_false",
113
+ "ddp_find_unused_parameters_true",
114
+ "ddp_spawn",
115
+ "ddp_spawn_find_unused_parameters_false",
116
+ "ddp_spawn_find_unused_parameters_true",
117
+ "ddp_fork",
118
+ "ddp_fork_find_unused_parameters_false",
119
+ "ddp_fork_find_unused_parameters_true",
120
+ "ddp_notebook",
121
+ "dp",
122
+ "deepspeed",
123
+ "deepspeed_stage_1",
124
+ "deepspeed_stage_1_offload",
125
+ "deepspeed_stage_2",
126
+ "deepspeed_stage_2_offload",
127
+ "deepspeed_stage_3",
128
+ "deepspeed_stage_3_offload",
129
+ "deepspeed_stage_3_offload_nvme",
130
+ "fsdp",
131
+ "fsdp_cpu_offload",
132
+ "single_xla",
133
+ "xla_fsdp",
134
+ "xla",
135
+ "single_tpu",
136
+ ],
137
+ )
138
+
139
+
140
+ CheckpointCallbackConfig = TypeAliasType(
141
+ "CheckpointCallbackConfig",
142
+ Annotated[
143
+ BestCheckpointCallbackConfig
144
+ | LastCheckpointCallbackConfig
145
+ | OnExceptionCheckpointCallbackConfig
146
+ | TimeCheckpointCallbackConfig,
147
+ C.Field(discriminator="name"),
148
+ ],
149
+ )
145
150
 
146
151
 
147
152
  class CheckpointSavingConfig(CallbackConfigBase):
@@ -4,14 +4,14 @@ import datetime
4
4
  import logging
5
5
  import time
6
6
  from dataclasses import dataclass
7
- from typing import Any, Literal, TypeAlias
7
+ from typing import Any, Literal
8
8
 
9
9
  from lightning.pytorch.callbacks.callback import Callback
10
- from typing_extensions import override
10
+ from typing_extensions import TypeAliasType, override
11
11
 
12
12
  log = logging.getLogger(__name__)
13
13
 
14
- Stage: TypeAlias = Literal["train", "validate", "test", "predict"]
14
+ Stage = TypeAliasType("Stage", Literal["train", "validate", "test", "predict"])
15
15
  ALL_STAGES = ("train", "validate", "test", "predict")
16
16
 
17
17
 
@@ -12,7 +12,7 @@ from collections import defaultdict
12
12
  from collections.abc import Callable
13
13
  from pathlib import Path
14
14
  from types import FrameType
15
- from typing import Any, TypeAlias
15
+ from typing import Any
16
16
 
17
17
  import nshrunner as nr
18
18
  import torch.utils.data
@@ -22,12 +22,14 @@ from lightning.pytorch.trainer.connectors.signal_connector import _HandlersCompo
22
22
  from lightning.pytorch.trainer.connectors.signal_connector import (
23
23
  _SignalConnector as _LightningSignalConnector,
24
24
  )
25
- from typing_extensions import override
25
+ from typing_extensions import TypeAliasType, override
26
26
 
27
27
  log = logging.getLogger(__name__)
28
28
 
29
- _SIGNUM = int | signal.Signals
30
- _HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
29
+ _SIGNUM = TypeAliasType("_SIGNUM", int | signal.Signals)
30
+ _HANDLER = TypeAliasType(
31
+ "_HANDLER", Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
32
+ )
31
33
  _IS_WINDOWS = platform.system() == "Windows"
32
34
 
33
35
 
@@ -7,7 +7,14 @@ from collections.abc import Set as AbstractSet
7
7
  from os import PathLike
8
8
  from typing import Any, TypeVar, overload
9
9
 
10
- from typing_extensions import Buffer, Literal, Protocol, SupportsIndex, TypeAlias
10
+ from typing_extensions import (
11
+ Buffer,
12
+ Literal,
13
+ Protocol,
14
+ SupportsIndex,
15
+ TypeAlias,
16
+ TypeAliasType,
17
+ )
11
18
 
12
19
  _KT = TypeVar("_KT")
13
20
  _KT_co = TypeVar("_KT_co", covariant=True)
@@ -60,7 +67,9 @@ class SupportsAllComparisons(
60
67
  ): ...
61
68
 
62
69
 
63
- SupportsRichComparison: TypeAlias = SupportsDunderLT[Any] | SupportsDunderGT[Any]
70
+ SupportsRichComparison = TypeAliasType(
71
+ "SupportsRichComparison", SupportsDunderLT[Any] | SupportsDunderGT[Any]
72
+ )
64
73
  SupportsRichComparisonT = TypeVar(
65
74
  "SupportsRichComparisonT", bound=SupportsRichComparison
66
75
  )
@@ -1,57 +1,60 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING, Literal, TypeAlias
3
+ from typing import TYPE_CHECKING, Literal
4
4
 
5
5
  import nshconfig as C
6
6
  import torch
7
- from typing_extensions import assert_never
7
+ from typing_extensions import TypeAliasType, assert_never
8
8
 
9
9
  from ..bf16 import is_bf16_supported_no_emulation
10
10
 
11
11
  if TYPE_CHECKING:
12
12
  from ...trainer._config import TrainerConfig
13
13
 
14
- DTypeName: TypeAlias = Literal[
15
- "float32",
16
- "float",
17
- "float64",
18
- "double",
19
- "float16",
20
- "bfloat16",
21
- "float8_e4m3fn",
22
- "float8_e4m3fnuz",
23
- "float8_e5m2",
24
- "float8_e5m2fnuz",
25
- "half",
26
- "uint8",
27
- "uint16",
28
- "uint32",
29
- "uint64",
30
- "int8",
31
- "int16",
32
- "short",
33
- "int32",
34
- "int",
35
- "int64",
36
- "long",
37
- "complex32",
38
- "complex64",
39
- "chalf",
40
- "cfloat",
41
- "complex128",
42
- "cdouble",
43
- "quint8",
44
- "qint8",
45
- "qint32",
46
- "bool",
47
- "quint4x2",
48
- "quint2x4",
49
- "bits1x8",
50
- "bits2x4",
51
- "bits4x2",
52
- "bits8",
53
- "bits16",
54
- ]
14
+ DTypeName = TypeAliasType(
15
+ "DTypeName",
16
+ Literal[
17
+ "float32",
18
+ "float",
19
+ "float64",
20
+ "double",
21
+ "float16",
22
+ "bfloat16",
23
+ "float8_e4m3fn",
24
+ "float8_e4m3fnuz",
25
+ "float8_e5m2",
26
+ "float8_e5m2fnuz",
27
+ "half",
28
+ "uint8",
29
+ "uint16",
30
+ "uint32",
31
+ "uint64",
32
+ "int8",
33
+ "int16",
34
+ "short",
35
+ "int32",
36
+ "int",
37
+ "int64",
38
+ "long",
39
+ "complex32",
40
+ "complex64",
41
+ "chalf",
42
+ "cfloat",
43
+ "complex128",
44
+ "cdouble",
45
+ "quint8",
46
+ "qint8",
47
+ "qint32",
48
+ "bool",
49
+ "quint4x2",
50
+ "quint2x4",
51
+ "bits1x8",
52
+ "bits2x4",
53
+ "bits4x2",
54
+ "bits8",
55
+ "bits16",
56
+ ],
57
+ )
55
58
 
56
59
 
57
60
  class DTypeConfig(C.Config):
@@ -6,11 +6,12 @@ import os
6
6
  import platform
7
7
  import shutil
8
8
  from pathlib import Path
9
- from typing import TypeAlias
9
+
10
+ from typing_extensions import TypeAliasType
10
11
 
11
12
  log = logging.getLogger(__name__)
12
13
 
13
- _Path: TypeAlias = str | Path | os.PathLike
14
+ _Path = TypeAliasType("_Path", str | Path)
14
15
 
15
16
 
16
17
  def get_relative_path(source: _Path, destination: _Path):
File without changes