nshtrainer-1.0.0b25-py3-none-any.whl → nshtrainer-1.0.0b27-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. nshtrainer/.nshconfig.generated.json +6 -0
  2. nshtrainer/_checkpoint/metadata.py +1 -1
  3. nshtrainer/callbacks/__init__.py +3 -0
  4. nshtrainer/callbacks/actsave.py +2 -2
  5. nshtrainer/callbacks/base.py +5 -3
  6. nshtrainer/callbacks/checkpoint/__init__.py +4 -0
  7. nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -2
  8. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -2
  9. nshtrainer/callbacks/checkpoint/time_checkpoint.py +114 -0
  10. nshtrainer/callbacks/print_table.py +2 -2
  11. nshtrainer/callbacks/shared_parameters.py +5 -3
  12. nshtrainer/configs/__init__.py +99 -10
  13. nshtrainer/configs/_checkpoint/__init__.py +6 -0
  14. nshtrainer/configs/_checkpoint/metadata/__init__.py +5 -0
  15. nshtrainer/configs/_directory/__init__.py +5 -1
  16. nshtrainer/configs/_hf_hub/__init__.py +6 -0
  17. nshtrainer/configs/callbacks/__init__.py +48 -1
  18. nshtrainer/configs/callbacks/actsave/__init__.py +5 -0
  19. nshtrainer/configs/callbacks/base/__init__.py +4 -0
  20. nshtrainer/configs/callbacks/checkpoint/__init__.py +20 -0
  21. nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +6 -0
  22. nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +7 -0
  23. nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +6 -0
  24. nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +5 -0
  25. nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +19 -0
  26. nshtrainer/configs/callbacks/debug_flag/__init__.py +5 -0
  27. nshtrainer/configs/callbacks/directory_setup/__init__.py +5 -0
  28. nshtrainer/configs/callbacks/early_stopping/__init__.py +6 -0
  29. nshtrainer/configs/callbacks/ema/__init__.py +5 -0
  30. nshtrainer/configs/callbacks/finite_checks/__init__.py +5 -0
  31. nshtrainer/configs/callbacks/gradient_skipping/__init__.py +5 -0
  32. nshtrainer/configs/callbacks/log_epoch/__init__.py +5 -0
  33. nshtrainer/configs/callbacks/lr_monitor/__init__.py +5 -0
  34. nshtrainer/configs/callbacks/norm_logging/__init__.py +5 -0
  35. nshtrainer/configs/callbacks/print_table/__init__.py +5 -0
  36. nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +5 -0
  37. nshtrainer/configs/callbacks/shared_parameters/__init__.py +5 -0
  38. nshtrainer/configs/callbacks/timer/__init__.py +5 -0
  39. nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +5 -0
  40. nshtrainer/configs/callbacks/wandb_watch/__init__.py +5 -0
  41. nshtrainer/configs/loggers/__init__.py +16 -1
  42. nshtrainer/configs/loggers/_base/__init__.py +4 -0
  43. nshtrainer/configs/loggers/actsave/__init__.py +5 -0
  44. nshtrainer/configs/loggers/csv/__init__.py +5 -0
  45. nshtrainer/configs/loggers/tensorboard/__init__.py +5 -0
  46. nshtrainer/configs/loggers/wandb/__init__.py +8 -0
  47. nshtrainer/configs/lr_scheduler/__init__.py +10 -4
  48. nshtrainer/configs/lr_scheduler/_base/__init__.py +4 -0
  49. nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +5 -3
  50. nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -0
  51. nshtrainer/configs/metrics/__init__.py +5 -0
  52. nshtrainer/configs/metrics/_config/__init__.py +4 -0
  53. nshtrainer/configs/nn/__init__.py +21 -1
  54. nshtrainer/configs/nn/mlp/__init__.py +5 -1
  55. nshtrainer/configs/nn/nonlinearity/__init__.py +18 -1
  56. nshtrainer/configs/optimizer/__init__.py +5 -1
  57. nshtrainer/configs/profiler/__init__.py +11 -1
  58. nshtrainer/configs/profiler/_base/__init__.py +4 -0
  59. nshtrainer/configs/profiler/advanced/__init__.py +5 -0
  60. nshtrainer/configs/profiler/pytorch/__init__.py +5 -0
  61. nshtrainer/configs/profiler/simple/__init__.py +5 -0
  62. nshtrainer/configs/trainer/__init__.py +39 -6
  63. nshtrainer/configs/trainer/_config/__init__.py +37 -6
  64. nshtrainer/configs/trainer/trainer/__init__.py +9 -0
  65. nshtrainer/configs/util/__init__.py +19 -1
  66. nshtrainer/configs/util/_environment_info/__init__.py +14 -0
  67. nshtrainer/configs/util/config/__init__.py +8 -1
  68. nshtrainer/configs/util/config/dtype/__init__.py +4 -0
  69. nshtrainer/configs/util/config/duration/__init__.py +5 -1
  70. nshtrainer/loggers/__init__.py +12 -5
  71. nshtrainer/lr_scheduler/__init__.py +9 -5
  72. nshtrainer/model/mixins/callback.py +6 -4
  73. nshtrainer/optimizer.py +5 -3
  74. nshtrainer/profiler/__init__.py +9 -5
  75. nshtrainer/trainer/_config.py +85 -61
  76. nshtrainer/trainer/_runtime_callback.py +3 -3
  77. nshtrainer/trainer/signal_connector.py +6 -4
  78. nshtrainer/trainer/trainer.py +4 -4
  79. nshtrainer/util/_useful_types.py +11 -2
  80. nshtrainer/util/config/dtype.py +46 -43
  81. nshtrainer/util/path.py +3 -2
  82. {nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b27.dist-info}/METADATA +2 -1
  83. nshtrainer-1.0.0b27.dist-info/RECORD +143 -0
  84. {nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b27.dist-info}/WHEEL +1 -1
  85. nshtrainer-1.0.0b25.dist-info/RECORD +0 -140
nshtrainer/configs/util/_environment_info/__init__.py CHANGED
@@ -31,3 +31,17 @@ from nshtrainer.util._environment_info import (
     EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
 )
 from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
+
+__all__ = [
+    "EnvironmentCUDAConfig",
+    "EnvironmentClassInformationConfig",
+    "EnvironmentConfig",
+    "EnvironmentGPUConfig",
+    "EnvironmentHardwareConfig",
+    "EnvironmentLSFInformationConfig",
+    "EnvironmentLinuxEnvironmentConfig",
+    "EnvironmentPackageConfig",
+    "EnvironmentSLURMInformationConfig",
+    "EnvironmentSnapshotConfig",
+    "GitRepositoryConfig",
+]
nshtrainer/configs/util/config/__init__.py CHANGED
@@ -3,9 +3,16 @@ from __future__ import annotations
 __codegen__ = True
 
 from nshtrainer.util.config import DTypeConfig as DTypeConfig
-from nshtrainer.util.config import DurationConfig as DurationConfig
 from nshtrainer.util.config import EpochsConfig as EpochsConfig
 from nshtrainer.util.config import StepsConfig as StepsConfig
 
 from . import dtype as dtype
 from . import duration as duration
+
+__all__ = [
+    "DTypeConfig",
+    "EpochsConfig",
+    "StepsConfig",
+    "dtype",
+    "duration",
+]
nshtrainer/configs/util/config/dtype/__init__.py CHANGED
@@ -3,3 +3,7 @@ from __future__ import annotations
 __codegen__ = True
 
 from nshtrainer.util.config.dtype import DTypeConfig as DTypeConfig
+
+__all__ = [
+    "DTypeConfig",
+]
nshtrainer/configs/util/config/duration/__init__.py CHANGED
@@ -2,6 +2,10 @@ from __future__ import annotations
 
 __codegen__ = True
 
-from nshtrainer.util.config.duration import DurationConfig as DurationConfig
 from nshtrainer.util.config.duration import EpochsConfig as EpochsConfig
 from nshtrainer.util.config.duration import StepsConfig as StepsConfig
+
+__all__ = [
+    "EpochsConfig",
+    "StepsConfig",
+]
nshtrainer/loggers/__init__.py CHANGED
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
-from typing import Annotated, TypeAlias
+from typing import Annotated
 
 import nshconfig as C
+from typing_extensions import TypeAliasType
 
 from ._base import BaseLoggerConfig as BaseLoggerConfig
 from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
@@ -10,7 +11,13 @@ from .csv import CSVLoggerConfig as CSVLoggerConfig
 from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
 from .wandb import WandbLoggerConfig as WandbLoggerConfig
 
-LoggerConfig: TypeAlias = Annotated[
-    CSVLoggerConfig | TensorboardLoggerConfig | WandbLoggerConfig | ActSaveLoggerConfig,
-    C.Field(discriminator="name"),
-]
+LoggerConfig = TypeAliasType(
+    "LoggerConfig",
+    Annotated[
+        CSVLoggerConfig
+        | TensorboardLoggerConfig
+        | WandbLoggerConfig
+        | ActSaveLoggerConfig,
+        C.Field(discriminator="name"),
+    ],
+)
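
The recurring change in b25 → b27 is the move from PEP 613 TypeAlias annotations to typing_extensions.TypeAliasType, the PEP 695 alias object that also works before Python 3.12. A minimal sketch of the pattern, using stand-in config classes rather than nshtrainer's real logger configs:

from typing import Annotated, Literal

import nshconfig as C
from typing_extensions import TypeAliasType


class CSVConfig(C.Config):  # stand-in, not nshtrainer's real CSVLoggerConfig
    name: Literal["csv"] = "csv"


class WandbConfig(C.Config):  # stand-in
    name: Literal["wandb"] = "wandb"


# Before: LoggerConfig: TypeAlias = Annotated[CSVConfig | WandbConfig, ...]
# After: the alias is a real runtime object with its own name, and type
# checkers treat it as an opaque alias instead of eagerly expanding the union.
LoggerConfig = TypeAliasType(
    "LoggerConfig",
    Annotated[CSVConfig | WandbConfig, C.Field(discriminator="name")],
)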
nshtrainer/lr_scheduler/__init__.py CHANGED
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
-from typing import Annotated, TypeAlias
+from typing import Annotated
 
 import nshconfig as C
+from typing_extensions import TypeAliasType
 
 from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
 from ._base import LRSchedulerMetadata as LRSchedulerMetadata
@@ -15,7 +16,10 @@ from .linear_warmup_cosine import (
 from .reduce_lr_on_plateau import ReduceLROnPlateau as ReduceLROnPlateau
 from .reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
 
-LRSchedulerConfig: TypeAlias = Annotated[
-    LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
-    C.Field(discriminator="name"),
-]
+LRSchedulerConfig = TypeAliasType(
+    "LRSchedulerConfig",
+    Annotated[
+        LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
+        C.Field(discriminator="name"),
+    ],
+)
nshtrainer/model/mixins/callback.py CHANGED
@@ -2,18 +2,20 @@ from __future__ import annotations
 
 import logging
 from collections.abc import Callable, Iterable, Sequence
-from typing import Any, TypeAlias, cast
+from typing import Any, cast
 
 from lightning.pytorch import Callback, LightningModule
-from typing_extensions import override
+from typing_extensions import TypeAliasType, override
 
 from ..._callback import NTCallbackBase
 from ...util.typing_utils import mixin_base_type
 
 log = logging.getLogger(__name__)
 
-_Callback = Callback | NTCallbackBase
-CallbackFn: TypeAlias = Callable[[], _Callback | Iterable[_Callback] | None]
+_Callback = TypeAliasType("_Callback", Callback | NTCallbackBase)
+CallbackFn = TypeAliasType(
+    "CallbackFn", Callable[[], _Callback | Iterable[_Callback] | None]
+)
 
 
 class CallbackRegistrarModuleMixin:
nshtrainer/optimizer.py CHANGED
@@ -2,12 +2,12 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Annotated, Any, Literal, TypeAlias
+from typing import Annotated, Any, Literal
 
 import nshconfig as C
 import torch.nn as nn
 from torch.optim import Optimizer
-from typing_extensions import override
+from typing_extensions import TypeAliasType, override
 
 
 class OptimizerConfigBase(C.Config, ABC):
@@ -57,4 +57,6 @@ class AdamWConfig(OptimizerConfigBase):
     )
 
 
-OptimizerConfig: TypeAlias = Annotated[AdamWConfig, C.Field(discriminator="name")]
+OptimizerConfig = TypeAliasType(
+    "OptimizerConfig", Annotated[AdamWConfig, C.Field(discriminator="name")]
+)
nshtrainer/profiler/__init__.py CHANGED
@@ -1,15 +1,19 @@
 from __future__ import annotations
 
-from typing import Annotated, TypeAlias
+from typing import Annotated
 
 import nshconfig as C
+from typing_extensions import TypeAliasType
 
 from ._base import BaseProfilerConfig as BaseProfilerConfig
 from .advanced import AdvancedProfilerConfig as AdvancedProfilerConfig
 from .pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
 from .simple import SimpleProfilerConfig as SimpleProfilerConfig
 
-ProfilerConfig: TypeAlias = Annotated[
-    SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
-    C.Field(discriminator="name"),
-]
+ProfilerConfig = TypeAliasType(
+    "ProfilerConfig",
+    Annotated[
+        SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
+        C.Field(discriminator="name"),
+    ],
+)
nshtrainer/trainer/_config.py CHANGED
@@ -5,6 +5,7 @@ import logging
 import os
 import string
 import time
+from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
 from datetime import timedelta
 from pathlib import Path
@@ -13,9 +14,6 @@ from typing import (
     Any,
     ClassVar,
     Literal,
-    Protocol,
-    TypeAlias,
-    runtime_checkable,
 )
 
 import nshconfig as C
@@ -30,7 +28,7 @@ from lightning.pytorch.plugins.layer_sync import LayerSync
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.strategies.strategy import Strategy
-from typing_extensions import TypedDict, TypeVar, override
+from typing_extensions import TypeAliasType, TypedDict, override
 
 from .._directory import DirectoryConfig
 from .._hf_hub import HuggingFaceHubConfig
@@ -43,6 +41,7 @@ from ..callbacks import (
     OnExceptionCheckpointCallbackConfig,
 )
 from ..callbacks.base import CallbackConfigBase
+from ..callbacks.checkpoint.time_checkpoint import TimeCheckpointCallbackConfig
 from ..callbacks.debug_flag import DebugFlagCallbackConfig
 from ..callbacks.log_epoch import LogEpochCallbackConfig
 from ..callbacks.lr_monitor import LearningRateMonitorConfig
@@ -72,71 +71,82 @@ class GradientClippingConfig(C.Config):
     """Norm type to use for gradient clipping."""
 
 
-TPlugin = TypeVar(
-    "TPlugin",
-    Precision,
-    ClusterEnvironment,
-    CheckpointIO,
-    LayerSync,
-    infer_variance=True,
+Plugin = TypeAliasType(
+    "Plugin", Precision | ClusterEnvironment | CheckpointIO | LayerSync
 )
 
 
-@runtime_checkable
-class PluginConfigProtocol(Protocol[TPlugin]):
-    def create_plugin(self) -> TPlugin: ...
+class PluginConfigBase(C.Config, ABC):
+    @abstractmethod
+    def create_plugin(self) -> Plugin: ...
 
 
-@runtime_checkable
-class AcceleratorConfigProtocol(Protocol):
+plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
+
+
+class AcceleratorConfigBase(C.Config, ABC):
+    @abstractmethod
     def create_accelerator(self) -> Accelerator: ...
 
 
-@runtime_checkable
-class StrategyConfigProtocol(Protocol):
+accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
+
+
+class StrategyConfigBase(C.Config, ABC):
+    @abstractmethod
     def create_strategy(self) -> Strategy: ...
 
 
-AcceleratorLiteral: TypeAlias = Literal[
-    "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
-]
-
-StrategyLiteral: TypeAlias = Literal[
-    "auto",
-    "ddp",
-    "ddp_find_unused_parameters_false",
-    "ddp_find_unused_parameters_true",
-    "ddp_spawn",
-    "ddp_spawn_find_unused_parameters_false",
-    "ddp_spawn_find_unused_parameters_true",
-    "ddp_fork",
-    "ddp_fork_find_unused_parameters_false",
-    "ddp_fork_find_unused_parameters_true",
-    "ddp_notebook",
-    "dp",
-    "deepspeed",
-    "deepspeed_stage_1",
-    "deepspeed_stage_1_offload",
-    "deepspeed_stage_2",
-    "deepspeed_stage_2_offload",
-    "deepspeed_stage_3",
-    "deepspeed_stage_3_offload",
-    "deepspeed_stage_3_offload_nvme",
-    "fsdp",
-    "fsdp_cpu_offload",
-    "single_xla",
-    "xla_fsdp",
-    "xla",
-    "single_tpu",
-]
-
-
-CheckpointCallbackConfig: TypeAlias = Annotated[
-    BestCheckpointCallbackConfig
-    | LastCheckpointCallbackConfig
-    | OnExceptionCheckpointCallbackConfig,
-    C.Field(discriminator="name"),
-]
+strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
+
+
+AcceleratorLiteral = TypeAliasType(
+    "AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
+)
+
+StrategyLiteral = TypeAliasType(
+    "StrategyLiteral",
+    Literal[
+        "auto",
+        "ddp",
+        "ddp_find_unused_parameters_false",
+        "ddp_find_unused_parameters_true",
+        "ddp_spawn",
+        "ddp_spawn_find_unused_parameters_false",
+        "ddp_spawn_find_unused_parameters_true",
+        "ddp_fork",
+        "ddp_fork_find_unused_parameters_false",
+        "ddp_fork_find_unused_parameters_true",
+        "ddp_notebook",
+        "dp",
+        "deepspeed",
+        "deepspeed_stage_1",
+        "deepspeed_stage_1_offload",
+        "deepspeed_stage_2",
+        "deepspeed_stage_2_offload",
+        "deepspeed_stage_3",
+        "deepspeed_stage_3_offload",
+        "deepspeed_stage_3_offload_nvme",
+        "fsdp",
+        "fsdp_cpu_offload",
+        "single_xla",
+        "xla_fsdp",
+        "xla",
+        "single_tpu",
+    ],
+)
+
+
+CheckpointCallbackConfig = TypeAliasType(
+    "CheckpointCallbackConfig",
+    Annotated[
+        BestCheckpointCallbackConfig
+        | LastCheckpointCallbackConfig
+        | OnExceptionCheckpointCallbackConfig
+        | TimeCheckpointCallbackConfig,
+        C.Field(discriminator="name"),
+    ],
+)
 
 
 class CheckpointSavingConfig(CallbackConfigBase):
@@ -147,6 +157,7 @@ class CheckpointSavingConfig(CallbackConfigBase):
         BestCheckpointCallbackConfig(throw_on_no_metric=False),
         LastCheckpointCallbackConfig(),
         OnExceptionCheckpointCallbackConfig(),
+        TimeCheckpointCallbackConfig(interval=timedelta(hours=12)),
     ]
     """Checkpoint callback configurations."""
 
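The hunk above wires the new TimeCheckpointCallbackConfig (added in nshtrainer/callbacks/checkpoint/time_checkpoint.py, file 9 in the list) into the default checkpoint set with a 12-hour interval. A hedged usage sketch; only the interval field is confirmed by this diff:

from datetime import timedelta

from nshtrainer.callbacks.checkpoint.time_checkpoint import (
    TimeCheckpointCallbackConfig,
)

# Save a time-based checkpoint every 2 hours instead of the default 12.
time_ckpt = TimeCheckpointCallbackConfig(interval=timedelta(hours=2))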
@@ -420,6 +431,9 @@ class SanityCheckingConfig(C.Config):
     """
 
 
+@plugin_registry.rebuild_on_registers
+@strategy_registry.rebuild_on_registers
+@accelerator_registry.rebuild_on_registers
 class TrainerConfig(C.Config):
     # region Active Run Configuration
     id: str = C.Field(default_factory=lambda: TrainerConfig.generate_id())
@@ -564,7 +578,9 @@
     Default: ``False``.
     """
 
-    plugins: list[PluginConfigProtocol] | None = None
+    plugins: (
+        list[Annotated[PluginConfigBase, plugin_registry.DynamicResolution()]] | None
+    ) = None
     """
     Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
    Default: ``None``.
@@ -724,13 +740,21 @@
     Default: ``True``.
     """
 
-    accelerator: AcceleratorConfigProtocol | AcceleratorLiteral | None = None
+    accelerator: (
+        Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()]
+        | AcceleratorLiteral
+        | None
+    ) = None
     """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
     as well as custom accelerator instances.
     Default: ``"auto"``.
     """
 
-    strategy: StrategyConfigProtocol | StrategyLiteral | None = None
+    strategy: (
+        Annotated[StrategyConfigBase, strategy_registry.DynamicResolution()]
+        | StrategyLiteral
+        | None
+    ) = None
     """Supports different training strategies with aliases as well custom strategies.
     Default: ``"auto"``.
     """
nshtrainer/trainer/_runtime_callback.py CHANGED
@@ -4,14 +4,14 @@ import datetime
 import logging
 import time
 from dataclasses import dataclass
-from typing import Any, Literal, TypeAlias
+from typing import Any, Literal
 
 from lightning.pytorch.callbacks.callback import Callback
-from typing_extensions import override
+from typing_extensions import TypeAliasType, override
 
 log = logging.getLogger(__name__)
 
-Stage: TypeAlias = Literal["train", "validate", "test", "predict"]
+Stage = TypeAliasType("Stage", Literal["train", "validate", "test", "predict"])
 ALL_STAGES = ("train", "validate", "test", "predict")
 
nshtrainer/trainer/signal_connector.py CHANGED
@@ -12,7 +12,7 @@ from collections import defaultdict
 from collections.abc import Callable
 from pathlib import Path
 from types import FrameType
-from typing import Any, TypeAlias
+from typing import Any
 
 import nshrunner as nr
 import torch.utils.data
@@ -22,12 +22,14 @@ from lightning.pytorch.trainer.connectors.signal_connector import _HandlersCompo
 from lightning.pytorch.trainer.connectors.signal_connector import (
     _SignalConnector as _LightningSignalConnector,
 )
-from typing_extensions import override
+from typing_extensions import TypeAliasType, override
 
 log = logging.getLogger(__name__)
 
-_SIGNUM = int | signal.Signals
-_HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
+_SIGNUM = TypeAliasType("_SIGNUM", int | signal.Signals)
+_HANDLER = TypeAliasType(
+    "_HANDLER", Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
+)
 _IS_WINDOWS = platform.system() == "Windows"
 
 
nshtrainer/trainer/trainer.py CHANGED
@@ -23,9 +23,9 @@ from ..callbacks.base import resolve_all_callbacks
 from ..util._environment_info import EnvironmentConfig
 from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import (
-    AcceleratorConfigProtocol,
+    AcceleratorConfigBase,
     LightningTrainerKwargs,
-    StrategyConfigProtocol,
+    StrategyConfigBase,
     TrainerConfig,
 )
 from ._runtime_callback import RuntimeTrackerCallback, Stage
@@ -171,12 +171,12 @@ class Trainer(LightningTrainer):
         _update_kwargs(use_distributed_sampler=use_distributed_sampler)
 
         if (accelerator := hparams.accelerator) is not None:
-            if isinstance(accelerator, AcceleratorConfigProtocol):
+            if isinstance(accelerator, AcceleratorConfigBase):
                 accelerator = accelerator.create_accelerator()
             _update_kwargs(accelerator=accelerator)
 
         if (strategy := hparams.strategy) is not None:
-            if isinstance(strategy, StrategyConfigProtocol):
+            if isinstance(strategy, StrategyConfigBase):
                 strategy = strategy.create_strategy()
             _update_kwargs(strategy=strategy)
 
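The isinstance checks above keep working because the Protocol-to-ABC change makes them nominal rather than structural: an object that merely has the right method no longer passes. A small illustration (DuckAccelerator is hypothetical):

from nshtrainer.trainer._config import AcceleratorConfigBase


class DuckAccelerator:
    """Structurally matches the old Protocol, but is not a subclass."""

    def create_accelerator(self): ...


# Under the old runtime-checkable Protocol this isinstance check passed;
# against the ABC it is False, so only real config subclasses are accepted.
assert not isinstance(DuckAccelerator(), AcceleratorConfigBase)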
nshtrainer/util/_useful_types.py CHANGED
@@ -7,7 +7,14 @@ from collections.abc import Set as AbstractSet
 from os import PathLike
 from typing import Any, TypeVar, overload
 
-from typing_extensions import Buffer, Literal, Protocol, SupportsIndex, TypeAlias
+from typing_extensions import (
+    Buffer,
+    Literal,
+    Protocol,
+    SupportsIndex,
+    TypeAlias,
+    TypeAliasType,
+)
 
 _KT = TypeVar("_KT")
 _KT_co = TypeVar("_KT_co", covariant=True)
@@ -60,7 +67,9 @@ class SupportsAllComparisons(
 ): ...
 
 
-SupportsRichComparison: TypeAlias = SupportsDunderLT[Any] | SupportsDunderGT[Any]
+SupportsRichComparison = TypeAliasType(
+    "SupportsRichComparison", SupportsDunderLT[Any] | SupportsDunderGT[Any]
+)
 SupportsRichComparisonT = TypeVar(
     "SupportsRichComparisonT", bound=SupportsRichComparison
 )
nshtrainer/util/config/dtype.py CHANGED
@@ -1,57 +1,60 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Literal, TypeAlias
+from typing import TYPE_CHECKING, Literal
 
 import nshconfig as C
 import torch
-from typing_extensions import assert_never
+from typing_extensions import TypeAliasType, assert_never
 
 from ..bf16 import is_bf16_supported_no_emulation
 
 if TYPE_CHECKING:
     from ...trainer._config import TrainerConfig
 
-DTypeName: TypeAlias = Literal[
-    "float32",
-    "float",
-    "float64",
-    "double",
-    "float16",
-    "bfloat16",
-    "float8_e4m3fn",
-    "float8_e4m3fnuz",
-    "float8_e5m2",
-    "float8_e5m2fnuz",
-    "half",
-    "uint8",
-    "uint16",
-    "uint32",
-    "uint64",
-    "int8",
-    "int16",
-    "short",
-    "int32",
-    "int",
-    "int64",
-    "long",
-    "complex32",
-    "complex64",
-    "chalf",
-    "cfloat",
-    "complex128",
-    "cdouble",
-    "quint8",
-    "qint8",
-    "qint32",
-    "bool",
-    "quint4x2",
-    "quint2x4",
-    "bits1x8",
-    "bits2x4",
-    "bits4x2",
-    "bits8",
-    "bits16",
-]
+DTypeName = TypeAliasType(
+    "DTypeName",
+    Literal[
+        "float32",
+        "float",
+        "float64",
+        "double",
+        "float16",
+        "bfloat16",
+        "float8_e4m3fn",
+        "float8_e4m3fnuz",
+        "float8_e5m2",
+        "float8_e5m2fnuz",
+        "half",
+        "uint8",
+        "uint16",
+        "uint32",
+        "uint64",
+        "int8",
+        "int16",
+        "short",
+        "int32",
+        "int",
+        "int64",
+        "long",
+        "complex32",
+        "complex64",
+        "chalf",
+        "cfloat",
+        "complex128",
+        "cdouble",
+        "quint8",
+        "qint8",
+        "qint32",
+        "bool",
+        "quint4x2",
+        "quint2x4",
+        "bits1x8",
+        "bits2x4",
+        "bits4x2",
+        "bits8",
+        "bits16",
+    ],
+)
 
 
 class DTypeConfig(C.Config):
nshtrainer/util/path.py CHANGED
@@ -6,11 +6,12 @@ import os
 import platform
 import shutil
 from pathlib import Path
-from typing import TypeAlias
+
+from typing_extensions import TypeAliasType
 
 log = logging.getLogger(__name__)
 
-_Path: TypeAlias = str | Path | os.PathLike
+_Path = TypeAliasType("_Path", str | Path)
 
 
 def get_relative_path(source: _Path, destination: _Path):
{nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b27.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 1.0.0b25
+Version: 1.0.0b27
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -9,6 +9,7 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Provides-Extra: extra
 Requires-Dist: GitPython ; extra == "extra"
 Requires-Dist: huggingface-hub ; extra == "extra"