nshtrainer 1.0.0b25__py3-none-any.whl → 1.0.0b27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/.nshconfig.generated.json +6 -0
- nshtrainer/_checkpoint/metadata.py +1 -1
- nshtrainer/callbacks/__init__.py +3 -0
- nshtrainer/callbacks/actsave.py +2 -2
- nshtrainer/callbacks/base.py +5 -3
- nshtrainer/callbacks/checkpoint/__init__.py +4 -0
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -2
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -2
- nshtrainer/callbacks/checkpoint/time_checkpoint.py +114 -0
- nshtrainer/callbacks/print_table.py +2 -2
- nshtrainer/callbacks/shared_parameters.py +5 -3
- nshtrainer/configs/__init__.py +99 -10
- nshtrainer/configs/_checkpoint/__init__.py +6 -0
- nshtrainer/configs/_checkpoint/metadata/__init__.py +5 -0
- nshtrainer/configs/_directory/__init__.py +5 -1
- nshtrainer/configs/_hf_hub/__init__.py +6 -0
- nshtrainer/configs/callbacks/__init__.py +48 -1
- nshtrainer/configs/callbacks/actsave/__init__.py +5 -0
- nshtrainer/configs/callbacks/base/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/__init__.py +20 -0
- nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +6 -0
- nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +7 -0
- nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +6 -0
- nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +5 -0
- nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +19 -0
- nshtrainer/configs/callbacks/debug_flag/__init__.py +5 -0
- nshtrainer/configs/callbacks/directory_setup/__init__.py +5 -0
- nshtrainer/configs/callbacks/early_stopping/__init__.py +6 -0
- nshtrainer/configs/callbacks/ema/__init__.py +5 -0
- nshtrainer/configs/callbacks/finite_checks/__init__.py +5 -0
- nshtrainer/configs/callbacks/gradient_skipping/__init__.py +5 -0
- nshtrainer/configs/callbacks/log_epoch/__init__.py +5 -0
- nshtrainer/configs/callbacks/lr_monitor/__init__.py +5 -0
- nshtrainer/configs/callbacks/norm_logging/__init__.py +5 -0
- nshtrainer/configs/callbacks/print_table/__init__.py +5 -0
- nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +5 -0
- nshtrainer/configs/callbacks/shared_parameters/__init__.py +5 -0
- nshtrainer/configs/callbacks/timer/__init__.py +5 -0
- nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +5 -0
- nshtrainer/configs/callbacks/wandb_watch/__init__.py +5 -0
- nshtrainer/configs/loggers/__init__.py +16 -1
- nshtrainer/configs/loggers/_base/__init__.py +4 -0
- nshtrainer/configs/loggers/actsave/__init__.py +5 -0
- nshtrainer/configs/loggers/csv/__init__.py +5 -0
- nshtrainer/configs/loggers/tensorboard/__init__.py +5 -0
- nshtrainer/configs/loggers/wandb/__init__.py +8 -0
- nshtrainer/configs/lr_scheduler/__init__.py +10 -4
- nshtrainer/configs/lr_scheduler/_base/__init__.py +4 -0
- nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +5 -3
- nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -0
- nshtrainer/configs/metrics/__init__.py +5 -0
- nshtrainer/configs/metrics/_config/__init__.py +4 -0
- nshtrainer/configs/nn/__init__.py +21 -1
- nshtrainer/configs/nn/mlp/__init__.py +5 -1
- nshtrainer/configs/nn/nonlinearity/__init__.py +18 -1
- nshtrainer/configs/optimizer/__init__.py +5 -1
- nshtrainer/configs/profiler/__init__.py +11 -1
- nshtrainer/configs/profiler/_base/__init__.py +4 -0
- nshtrainer/configs/profiler/advanced/__init__.py +5 -0
- nshtrainer/configs/profiler/pytorch/__init__.py +5 -0
- nshtrainer/configs/profiler/simple/__init__.py +5 -0
- nshtrainer/configs/trainer/__init__.py +39 -6
- nshtrainer/configs/trainer/_config/__init__.py +37 -6
- nshtrainer/configs/trainer/trainer/__init__.py +9 -0
- nshtrainer/configs/util/__init__.py +19 -1
- nshtrainer/configs/util/_environment_info/__init__.py +14 -0
- nshtrainer/configs/util/config/__init__.py +8 -1
- nshtrainer/configs/util/config/dtype/__init__.py +4 -0
- nshtrainer/configs/util/config/duration/__init__.py +5 -1
- nshtrainer/loggers/__init__.py +12 -5
- nshtrainer/lr_scheduler/__init__.py +9 -5
- nshtrainer/model/mixins/callback.py +6 -4
- nshtrainer/optimizer.py +5 -3
- nshtrainer/profiler/__init__.py +9 -5
- nshtrainer/trainer/_config.py +85 -61
- nshtrainer/trainer/_runtime_callback.py +3 -3
- nshtrainer/trainer/signal_connector.py +6 -4
- nshtrainer/trainer/trainer.py +4 -4
- nshtrainer/util/_useful_types.py +11 -2
- nshtrainer/util/config/dtype.py +46 -43
- nshtrainer/util/path.py +3 -2
- {nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b27.dist-info}/METADATA +2 -1
- nshtrainer-1.0.0b27.dist-info/RECORD +143 -0
- {nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b27.dist-info}/WHEEL +1 -1
- nshtrainer-1.0.0b25.dist-info/RECORD +0 -140
@@ -31,3 +31,17 @@ from nshtrainer.util._environment_info import (
|
|
31
31
|
EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
|
32
32
|
)
|
33
33
|
from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
|
34
|
+
|
35
|
+
__all__ = [
|
36
|
+
"EnvironmentCUDAConfig",
|
37
|
+
"EnvironmentClassInformationConfig",
|
38
|
+
"EnvironmentConfig",
|
39
|
+
"EnvironmentGPUConfig",
|
40
|
+
"EnvironmentHardwareConfig",
|
41
|
+
"EnvironmentLSFInformationConfig",
|
42
|
+
"EnvironmentLinuxEnvironmentConfig",
|
43
|
+
"EnvironmentPackageConfig",
|
44
|
+
"EnvironmentSLURMInformationConfig",
|
45
|
+
"EnvironmentSnapshotConfig",
|
46
|
+
"GitRepositoryConfig",
|
47
|
+
]
|
@@ -3,9 +3,16 @@ from __future__ import annotations
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
5
|
from nshtrainer.util.config import DTypeConfig as DTypeConfig
|
6
|
-
from nshtrainer.util.config import DurationConfig as DurationConfig
|
7
6
|
from nshtrainer.util.config import EpochsConfig as EpochsConfig
|
8
7
|
from nshtrainer.util.config import StepsConfig as StepsConfig
|
9
8
|
|
10
9
|
from . import dtype as dtype
|
11
10
|
from . import duration as duration
|
11
|
+
|
12
|
+
__all__ = [
|
13
|
+
"DTypeConfig",
|
14
|
+
"EpochsConfig",
|
15
|
+
"StepsConfig",
|
16
|
+
"dtype",
|
17
|
+
"duration",
|
18
|
+
]
|
@@ -2,6 +2,10 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
|
-
from nshtrainer.util.config.duration import DurationConfig as DurationConfig
|
6
5
|
from nshtrainer.util.config.duration import EpochsConfig as EpochsConfig
|
7
6
|
from nshtrainer.util.config.duration import StepsConfig as StepsConfig
|
7
|
+
|
8
|
+
__all__ = [
|
9
|
+
"EpochsConfig",
|
10
|
+
"StepsConfig",
|
11
|
+
]
|
nshtrainer/loggers/__init__.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import Annotated
|
3
|
+
from typing import Annotated
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
|
+
from typing_extensions import TypeAliasType
|
6
7
|
|
7
8
|
from ._base import BaseLoggerConfig as BaseLoggerConfig
|
8
9
|
from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
|
@@ -10,7 +11,13 @@ from .csv import CSVLoggerConfig as CSVLoggerConfig
|
|
10
11
|
from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
|
11
12
|
from .wandb import WandbLoggerConfig as WandbLoggerConfig
|
12
13
|
|
13
|
-
LoggerConfig
|
14
|
-
|
15
|
-
|
16
|
-
|
14
|
+
LoggerConfig = TypeAliasType(
|
15
|
+
"LoggerConfig",
|
16
|
+
Annotated[
|
17
|
+
CSVLoggerConfig
|
18
|
+
| TensorboardLoggerConfig
|
19
|
+
| WandbLoggerConfig
|
20
|
+
| ActSaveLoggerConfig,
|
21
|
+
C.Field(discriminator="name"),
|
22
|
+
],
|
23
|
+
)
|
@@ -1,8 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import Annotated
|
3
|
+
from typing import Annotated
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
|
+
from typing_extensions import TypeAliasType
|
6
7
|
|
7
8
|
from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
|
8
9
|
from ._base import LRSchedulerMetadata as LRSchedulerMetadata
|
@@ -15,7 +16,10 @@ from .linear_warmup_cosine import (
|
|
15
16
|
from .reduce_lr_on_plateau import ReduceLROnPlateau as ReduceLROnPlateau
|
16
17
|
from .reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
|
17
18
|
|
18
|
-
LRSchedulerConfig
|
19
|
-
|
20
|
-
|
21
|
-
|
19
|
+
LRSchedulerConfig = TypeAliasType(
|
20
|
+
"LRSchedulerConfig",
|
21
|
+
Annotated[
|
22
|
+
LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
|
23
|
+
C.Field(discriminator="name"),
|
24
|
+
],
|
25
|
+
)
|
@@ -2,18 +2,20 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
from collections.abc import Callable, Iterable, Sequence
|
5
|
-
from typing import Any,
|
5
|
+
from typing import Any, cast
|
6
6
|
|
7
7
|
from lightning.pytorch import Callback, LightningModule
|
8
|
-
from typing_extensions import override
|
8
|
+
from typing_extensions import TypeAliasType, override
|
9
9
|
|
10
10
|
from ..._callback import NTCallbackBase
|
11
11
|
from ...util.typing_utils import mixin_base_type
|
12
12
|
|
13
13
|
log = logging.getLogger(__name__)
|
14
14
|
|
15
|
-
_Callback = Callback | NTCallbackBase
|
16
|
-
CallbackFn
|
15
|
+
_Callback = TypeAliasType("_Callback", Callback | NTCallbackBase)
|
16
|
+
CallbackFn = TypeAliasType(
|
17
|
+
"CallbackFn", Callable[[], _Callback | Iterable[_Callback] | None]
|
18
|
+
)
|
17
19
|
|
18
20
|
|
19
21
|
class CallbackRegistrarModuleMixin:
|
nshtrainer/optimizer.py
CHANGED
@@ -2,12 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from collections.abc import Iterable
|
5
|
-
from typing import Annotated, Any, Literal
|
5
|
+
from typing import Annotated, Any, Literal
|
6
6
|
|
7
7
|
import nshconfig as C
|
8
8
|
import torch.nn as nn
|
9
9
|
from torch.optim import Optimizer
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import TypeAliasType, override
|
11
11
|
|
12
12
|
|
13
13
|
class OptimizerConfigBase(C.Config, ABC):
|
@@ -57,4 +57,6 @@ class AdamWConfig(OptimizerConfigBase):
|
|
57
57
|
)
|
58
58
|
|
59
59
|
|
60
|
-
OptimizerConfig
|
60
|
+
OptimizerConfig = TypeAliasType(
|
61
|
+
"OptimizerConfig", Annotated[AdamWConfig, C.Field(discriminator="name")]
|
62
|
+
)
|
nshtrainer/profiler/__init__.py
CHANGED
@@ -1,15 +1,19 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import Annotated
|
3
|
+
from typing import Annotated
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
|
+
from typing_extensions import TypeAliasType
|
6
7
|
|
7
8
|
from ._base import BaseProfilerConfig as BaseProfilerConfig
|
8
9
|
from .advanced import AdvancedProfilerConfig as AdvancedProfilerConfig
|
9
10
|
from .pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
|
10
11
|
from .simple import SimpleProfilerConfig as SimpleProfilerConfig
|
11
12
|
|
12
|
-
ProfilerConfig
|
13
|
-
|
14
|
-
|
15
|
-
|
13
|
+
ProfilerConfig = TypeAliasType(
|
14
|
+
"ProfilerConfig",
|
15
|
+
Annotated[
|
16
|
+
SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
|
17
|
+
C.Field(discriminator="name"),
|
18
|
+
],
|
19
|
+
)
|
nshtrainer/trainer/_config.py
CHANGED
@@ -5,6 +5,7 @@ import logging
|
|
5
5
|
import os
|
6
6
|
import string
|
7
7
|
import time
|
8
|
+
from abc import ABC, abstractmethod
|
8
9
|
from collections.abc import Iterable, Sequence
|
9
10
|
from datetime import timedelta
|
10
11
|
from pathlib import Path
|
@@ -13,9 +14,6 @@ from typing import (
|
|
13
14
|
Any,
|
14
15
|
ClassVar,
|
15
16
|
Literal,
|
16
|
-
Protocol,
|
17
|
-
TypeAlias,
|
18
|
-
runtime_checkable,
|
19
17
|
)
|
20
18
|
|
21
19
|
import nshconfig as C
|
@@ -30,7 +28,7 @@ from lightning.pytorch.plugins.layer_sync import LayerSync
|
|
30
28
|
from lightning.pytorch.plugins.precision.precision import Precision
|
31
29
|
from lightning.pytorch.profilers import Profiler
|
32
30
|
from lightning.pytorch.strategies.strategy import Strategy
|
33
|
-
from typing_extensions import
|
31
|
+
from typing_extensions import TypeAliasType, TypedDict, override
|
34
32
|
|
35
33
|
from .._directory import DirectoryConfig
|
36
34
|
from .._hf_hub import HuggingFaceHubConfig
|
@@ -43,6 +41,7 @@ from ..callbacks import (
|
|
43
41
|
OnExceptionCheckpointCallbackConfig,
|
44
42
|
)
|
45
43
|
from ..callbacks.base import CallbackConfigBase
|
44
|
+
from ..callbacks.checkpoint.time_checkpoint import TimeCheckpointCallbackConfig
|
46
45
|
from ..callbacks.debug_flag import DebugFlagCallbackConfig
|
47
46
|
from ..callbacks.log_epoch import LogEpochCallbackConfig
|
48
47
|
from ..callbacks.lr_monitor import LearningRateMonitorConfig
|
@@ -72,71 +71,82 @@ class GradientClippingConfig(C.Config):
|
|
72
71
|
"""Norm type to use for gradient clipping."""
|
73
72
|
|
74
73
|
|
75
|
-
|
76
|
-
"
|
77
|
-
Precision,
|
78
|
-
ClusterEnvironment,
|
79
|
-
CheckpointIO,
|
80
|
-
LayerSync,
|
81
|
-
infer_variance=True,
|
74
|
+
Plugin = TypeAliasType(
|
75
|
+
"Plugin", Precision | ClusterEnvironment | CheckpointIO | LayerSync
|
82
76
|
)
|
83
77
|
|
84
78
|
|
85
|
-
|
86
|
-
|
87
|
-
def create_plugin(self) ->
|
79
|
+
class PluginConfigBase(C.Config, ABC):
|
80
|
+
@abstractmethod
|
81
|
+
def create_plugin(self) -> Plugin: ...
|
88
82
|
|
89
83
|
|
90
|
-
|
91
|
-
|
84
|
+
plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
|
85
|
+
|
86
|
+
|
87
|
+
class AcceleratorConfigBase(C.Config, ABC):
|
88
|
+
@abstractmethod
|
92
89
|
def create_accelerator(self) -> Accelerator: ...
|
93
90
|
|
94
91
|
|
95
|
-
|
96
|
-
|
92
|
+
accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
|
93
|
+
|
94
|
+
|
95
|
+
class StrategyConfigBase(C.Config, ABC):
|
96
|
+
@abstractmethod
|
97
97
|
def create_strategy(self) -> Strategy: ...
|
98
98
|
|
99
99
|
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
"
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
100
|
+
strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
|
101
|
+
|
102
|
+
|
103
|
+
AcceleratorLiteral = TypeAliasType(
|
104
|
+
"AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
|
105
|
+
)
|
106
|
+
|
107
|
+
StrategyLiteral = TypeAliasType(
|
108
|
+
"StrategyLiteral",
|
109
|
+
Literal[
|
110
|
+
"auto",
|
111
|
+
"ddp",
|
112
|
+
"ddp_find_unused_parameters_false",
|
113
|
+
"ddp_find_unused_parameters_true",
|
114
|
+
"ddp_spawn",
|
115
|
+
"ddp_spawn_find_unused_parameters_false",
|
116
|
+
"ddp_spawn_find_unused_parameters_true",
|
117
|
+
"ddp_fork",
|
118
|
+
"ddp_fork_find_unused_parameters_false",
|
119
|
+
"ddp_fork_find_unused_parameters_true",
|
120
|
+
"ddp_notebook",
|
121
|
+
"dp",
|
122
|
+
"deepspeed",
|
123
|
+
"deepspeed_stage_1",
|
124
|
+
"deepspeed_stage_1_offload",
|
125
|
+
"deepspeed_stage_2",
|
126
|
+
"deepspeed_stage_2_offload",
|
127
|
+
"deepspeed_stage_3",
|
128
|
+
"deepspeed_stage_3_offload",
|
129
|
+
"deepspeed_stage_3_offload_nvme",
|
130
|
+
"fsdp",
|
131
|
+
"fsdp_cpu_offload",
|
132
|
+
"single_xla",
|
133
|
+
"xla_fsdp",
|
134
|
+
"xla",
|
135
|
+
"single_tpu",
|
136
|
+
],
|
137
|
+
)
|
138
|
+
|
139
|
+
|
140
|
+
CheckpointCallbackConfig = TypeAliasType(
|
141
|
+
"CheckpointCallbackConfig",
|
142
|
+
Annotated[
|
143
|
+
BestCheckpointCallbackConfig
|
144
|
+
| LastCheckpointCallbackConfig
|
145
|
+
| OnExceptionCheckpointCallbackConfig
|
146
|
+
| TimeCheckpointCallbackConfig,
|
147
|
+
C.Field(discriminator="name"),
|
148
|
+
],
|
149
|
+
)
|
140
150
|
|
141
151
|
|
142
152
|
class CheckpointSavingConfig(CallbackConfigBase):
|
@@ -147,6 +157,7 @@ class CheckpointSavingConfig(CallbackConfigBase):
|
|
147
157
|
BestCheckpointCallbackConfig(throw_on_no_metric=False),
|
148
158
|
LastCheckpointCallbackConfig(),
|
149
159
|
OnExceptionCheckpointCallbackConfig(),
|
160
|
+
TimeCheckpointCallbackConfig(interval=timedelta(hours=12)),
|
150
161
|
]
|
151
162
|
"""Checkpoint callback configurations."""
|
152
163
|
|
@@ -420,6 +431,9 @@ class SanityCheckingConfig(C.Config):
|
|
420
431
|
"""
|
421
432
|
|
422
433
|
|
434
|
+
@plugin_registry.rebuild_on_registers
|
435
|
+
@strategy_registry.rebuild_on_registers
|
436
|
+
@accelerator_registry.rebuild_on_registers
|
423
437
|
class TrainerConfig(C.Config):
|
424
438
|
# region Active Run Configuration
|
425
439
|
id: str = C.Field(default_factory=lambda: TrainerConfig.generate_id())
|
@@ -564,7 +578,9 @@ class TrainerConfig(C.Config):
|
|
564
578
|
Default: ``False``.
|
565
579
|
"""
|
566
580
|
|
567
|
-
plugins:
|
581
|
+
plugins: (
|
582
|
+
list[Annotated[PluginConfigBase, plugin_registry.DynamicResolution()]] | None
|
583
|
+
) = None
|
568
584
|
"""
|
569
585
|
Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
|
570
586
|
Default: ``None``.
|
@@ -724,13 +740,21 @@ class TrainerConfig(C.Config):
|
|
724
740
|
Default: ``True``.
|
725
741
|
"""
|
726
742
|
|
727
|
-
accelerator:
|
743
|
+
accelerator: (
|
744
|
+
Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()]
|
745
|
+
| AcceleratorLiteral
|
746
|
+
| None
|
747
|
+
) = None
|
728
748
|
"""Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
|
729
749
|
as well as custom accelerator instances.
|
730
750
|
Default: ``"auto"``.
|
731
751
|
"""
|
732
752
|
|
733
|
-
strategy:
|
753
|
+
strategy: (
|
754
|
+
Annotated[StrategyConfigBase, strategy_registry.DynamicResolution()]
|
755
|
+
| StrategyLiteral
|
756
|
+
| None
|
757
|
+
) = None
|
734
758
|
"""Supports different training strategies with aliases as well custom strategies.
|
735
759
|
Default: ``"auto"``.
|
736
760
|
"""
|
@@ -4,14 +4,14 @@ import datetime
|
|
4
4
|
import logging
|
5
5
|
import time
|
6
6
|
from dataclasses import dataclass
|
7
|
-
from typing import Any, Literal
|
7
|
+
from typing import Any, Literal
|
8
8
|
|
9
9
|
from lightning.pytorch.callbacks.callback import Callback
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import TypeAliasType, override
|
11
11
|
|
12
12
|
log = logging.getLogger(__name__)
|
13
13
|
|
14
|
-
Stage
|
14
|
+
Stage = TypeAliasType("Stage", Literal["train", "validate", "test", "predict"])
|
15
15
|
ALL_STAGES = ("train", "validate", "test", "predict")
|
16
16
|
|
17
17
|
|
@@ -12,7 +12,7 @@ from collections import defaultdict
|
|
12
12
|
from collections.abc import Callable
|
13
13
|
from pathlib import Path
|
14
14
|
from types import FrameType
|
15
|
-
from typing import Any
|
15
|
+
from typing import Any
|
16
16
|
|
17
17
|
import nshrunner as nr
|
18
18
|
import torch.utils.data
|
@@ -22,12 +22,14 @@ from lightning.pytorch.trainer.connectors.signal_connector import _HandlersCompo
|
|
22
22
|
from lightning.pytorch.trainer.connectors.signal_connector import (
|
23
23
|
_SignalConnector as _LightningSignalConnector,
|
24
24
|
)
|
25
|
-
from typing_extensions import override
|
25
|
+
from typing_extensions import TypeAliasType, override
|
26
26
|
|
27
27
|
log = logging.getLogger(__name__)
|
28
28
|
|
29
|
-
_SIGNUM = int | signal.Signals
|
30
|
-
_HANDLER
|
29
|
+
_SIGNUM = TypeAliasType("_SIGNUM", int | signal.Signals)
|
30
|
+
_HANDLER = TypeAliasType(
|
31
|
+
"_HANDLER", Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
|
32
|
+
)
|
31
33
|
_IS_WINDOWS = platform.system() == "Windows"
|
32
34
|
|
33
35
|
|
nshtrainer/trainer/trainer.py
CHANGED
@@ -23,9 +23,9 @@ from ..callbacks.base import resolve_all_callbacks
|
|
23
23
|
from ..util._environment_info import EnvironmentConfig
|
24
24
|
from ..util.bf16 import is_bf16_supported_no_emulation
|
25
25
|
from ._config import (
|
26
|
-
|
26
|
+
AcceleratorConfigBase,
|
27
27
|
LightningTrainerKwargs,
|
28
|
-
|
28
|
+
StrategyConfigBase,
|
29
29
|
TrainerConfig,
|
30
30
|
)
|
31
31
|
from ._runtime_callback import RuntimeTrackerCallback, Stage
|
@@ -171,12 +171,12 @@ class Trainer(LightningTrainer):
|
|
171
171
|
_update_kwargs(use_distributed_sampler=use_distributed_sampler)
|
172
172
|
|
173
173
|
if (accelerator := hparams.accelerator) is not None:
|
174
|
-
if isinstance(accelerator,
|
174
|
+
if isinstance(accelerator, AcceleratorConfigBase):
|
175
175
|
accelerator = accelerator.create_accelerator()
|
176
176
|
_update_kwargs(accelerator=accelerator)
|
177
177
|
|
178
178
|
if (strategy := hparams.strategy) is not None:
|
179
|
-
if isinstance(strategy,
|
179
|
+
if isinstance(strategy, StrategyConfigBase):
|
180
180
|
strategy = strategy.create_strategy()
|
181
181
|
_update_kwargs(strategy=strategy)
|
182
182
|
|
nshtrainer/util/_useful_types.py
CHANGED
@@ -7,7 +7,14 @@ from collections.abc import Set as AbstractSet
|
|
7
7
|
from os import PathLike
|
8
8
|
from typing import Any, TypeVar, overload
|
9
9
|
|
10
|
-
from typing_extensions import
|
10
|
+
from typing_extensions import (
|
11
|
+
Buffer,
|
12
|
+
Literal,
|
13
|
+
Protocol,
|
14
|
+
SupportsIndex,
|
15
|
+
TypeAlias,
|
16
|
+
TypeAliasType,
|
17
|
+
)
|
11
18
|
|
12
19
|
_KT = TypeVar("_KT")
|
13
20
|
_KT_co = TypeVar("_KT_co", covariant=True)
|
@@ -60,7 +67,9 @@ class SupportsAllComparisons(
|
|
60
67
|
): ...
|
61
68
|
|
62
69
|
|
63
|
-
SupportsRichComparison
|
70
|
+
SupportsRichComparison = TypeAliasType(
|
71
|
+
"SupportsRichComparison", SupportsDunderLT[Any] | SupportsDunderGT[Any]
|
72
|
+
)
|
64
73
|
SupportsRichComparisonT = TypeVar(
|
65
74
|
"SupportsRichComparisonT", bound=SupportsRichComparison
|
66
75
|
)
|
nshtrainer/util/config/dtype.py
CHANGED
@@ -1,57 +1,60 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import TYPE_CHECKING, Literal
|
3
|
+
from typing import TYPE_CHECKING, Literal
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
6
|
import torch
|
7
|
-
from typing_extensions import assert_never
|
7
|
+
from typing_extensions import TypeAliasType, assert_never
|
8
8
|
|
9
9
|
from ..bf16 import is_bf16_supported_no_emulation
|
10
10
|
|
11
11
|
if TYPE_CHECKING:
|
12
12
|
from ...trainer._config import TrainerConfig
|
13
13
|
|
14
|
-
DTypeName
|
15
|
-
"
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
14
|
+
DTypeName = TypeAliasType(
|
15
|
+
"DTypeName",
|
16
|
+
Literal[
|
17
|
+
"float32",
|
18
|
+
"float",
|
19
|
+
"float64",
|
20
|
+
"double",
|
21
|
+
"float16",
|
22
|
+
"bfloat16",
|
23
|
+
"float8_e4m3fn",
|
24
|
+
"float8_e4m3fnuz",
|
25
|
+
"float8_e5m2",
|
26
|
+
"float8_e5m2fnuz",
|
27
|
+
"half",
|
28
|
+
"uint8",
|
29
|
+
"uint16",
|
30
|
+
"uint32",
|
31
|
+
"uint64",
|
32
|
+
"int8",
|
33
|
+
"int16",
|
34
|
+
"short",
|
35
|
+
"int32",
|
36
|
+
"int",
|
37
|
+
"int64",
|
38
|
+
"long",
|
39
|
+
"complex32",
|
40
|
+
"complex64",
|
41
|
+
"chalf",
|
42
|
+
"cfloat",
|
43
|
+
"complex128",
|
44
|
+
"cdouble",
|
45
|
+
"quint8",
|
46
|
+
"qint8",
|
47
|
+
"qint32",
|
48
|
+
"bool",
|
49
|
+
"quint4x2",
|
50
|
+
"quint2x4",
|
51
|
+
"bits1x8",
|
52
|
+
"bits2x4",
|
53
|
+
"bits4x2",
|
54
|
+
"bits8",
|
55
|
+
"bits16",
|
56
|
+
],
|
57
|
+
)
|
55
58
|
|
56
59
|
|
57
60
|
class DTypeConfig(C.Config):
|
nshtrainer/util/path.py
CHANGED
@@ -6,11 +6,12 @@ import os
|
|
6
6
|
import platform
|
7
7
|
import shutil
|
8
8
|
from pathlib import Path
|
9
|
-
|
9
|
+
|
10
|
+
from typing_extensions import TypeAliasType
|
10
11
|
|
11
12
|
log = logging.getLogger(__name__)
|
12
13
|
|
13
|
-
_Path
|
14
|
+
_Path = TypeAliasType("_Path", str | Path)
|
14
15
|
|
15
16
|
|
16
17
|
def get_relative_path(source: _Path, destination: _Path):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nshtrainer
|
3
|
-
Version: 1.0.
|
3
|
+
Version: 1.0.0b27
|
4
4
|
Summary:
|
5
5
|
Author: Nima Shoghi
|
6
6
|
Author-email: nimashoghi@gmail.com
|
@@ -9,6 +9,7 @@ Classifier: Programming Language :: Python :: 3
|
|
9
9
|
Classifier: Programming Language :: Python :: 3.10
|
10
10
|
Classifier: Programming Language :: Python :: 3.11
|
11
11
|
Classifier: Programming Language :: Python :: 3.12
|
12
|
+
Classifier: Programming Language :: Python :: 3.13
|
12
13
|
Provides-Extra: extra
|
13
14
|
Requires-Dist: GitPython ; extra == "extra"
|
14
15
|
Requires-Dist: huggingface-hub ; extra == "extra"
|