nshtrainer 1.0.0b26__py3-none-any.whl → 1.0.0b27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/actsave.py +2 -2
- nshtrainer/callbacks/base.py +5 -3
- nshtrainer/callbacks/shared_parameters.py +5 -3
- nshtrainer/configs/__init__.py +4 -0
- nshtrainer/configs/callbacks/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/__init__.py +6 -0
- nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +19 -0
- nshtrainer/configs/trainer/__init__.py +4 -0
- nshtrainer/configs/trainer/_config/__init__.py +4 -0
- nshtrainer/loggers/__init__.py +12 -5
- nshtrainer/lr_scheduler/__init__.py +9 -5
- nshtrainer/model/mixins/callback.py +6 -4
- nshtrainer/optimizer.py +5 -3
- nshtrainer/profiler/__init__.py +9 -5
- nshtrainer/trainer/_config.py +47 -42
- nshtrainer/trainer/_runtime_callback.py +3 -3
- nshtrainer/trainer/signal_connector.py +6 -4
- nshtrainer/util/_useful_types.py +11 -2
- nshtrainer/util/config/dtype.py +46 -43
- nshtrainer/util/path.py +3 -2
- {nshtrainer-1.0.0b26.dist-info → nshtrainer-1.0.0b27.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b26.dist-info → nshtrainer-1.0.0b27.dist-info}/RECORD +23 -22
- {nshtrainer-1.0.0b26.dist-info → nshtrainer-1.0.0b27.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/actsave.py
CHANGED
@@ -4,12 +4,12 @@ import contextlib
|
|
4
4
|
from pathlib import Path
|
5
5
|
from typing import Literal
|
6
6
|
|
7
|
-
from typing_extensions import
|
7
|
+
from typing_extensions import TypeAliasType, override
|
8
8
|
|
9
9
|
from .._callback import NTCallbackBase
|
10
10
|
from .base import CallbackConfigBase
|
11
11
|
|
12
|
-
Stage
|
12
|
+
Stage = TypeAliasType("Stage", Literal["train", "validation", "test", "predict"])
|
13
13
|
|
14
14
|
|
15
15
|
class ActSaveConfig(CallbackConfigBase):
|
nshtrainer/callbacks/base.py
CHANGED
@@ -4,11 +4,11 @@ from abc import ABC, abstractmethod
|
|
4
4
|
from collections import Counter
|
5
5
|
from collections.abc import Iterable
|
6
6
|
from dataclasses import dataclass
|
7
|
-
from typing import TYPE_CHECKING, ClassVar
|
7
|
+
from typing import TYPE_CHECKING, ClassVar
|
8
8
|
|
9
9
|
import nshconfig as C
|
10
10
|
from lightning.pytorch import Callback
|
11
|
-
from typing_extensions import TypedDict, Unpack
|
11
|
+
from typing_extensions import TypeAliasType, TypedDict, Unpack
|
12
12
|
|
13
13
|
if TYPE_CHECKING:
|
14
14
|
from ..trainer._config import TrainerConfig
|
@@ -30,7 +30,9 @@ class CallbackWithMetadata:
|
|
30
30
|
metadata: CallbackMetadataConfig
|
31
31
|
|
32
32
|
|
33
|
-
ConstructedCallback
|
33
|
+
ConstructedCallback = TypeAliasType(
|
34
|
+
"ConstructedCallback", Callback | CallbackWithMetadata
|
35
|
+
)
|
34
36
|
|
35
37
|
|
36
38
|
class CallbackConfigBase(C.Config, ABC):
|
@@ -2,12 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
from collections.abc import Iterable
|
5
|
-
from typing import Literal, Protocol,
|
5
|
+
from typing import Literal, Protocol, runtime_checkable
|
6
6
|
|
7
7
|
import torch.nn as nn
|
8
8
|
from lightning.pytorch import LightningModule, Trainer
|
9
9
|
from lightning.pytorch.callbacks import Callback
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import TypeAliasType, override
|
11
11
|
|
12
12
|
from .base import CallbackConfigBase
|
13
13
|
|
@@ -34,7 +34,9 @@ class SharedParametersCallbackConfig(CallbackConfigBase):
|
|
34
34
|
yield SharedParametersCallback(self)
|
35
35
|
|
36
36
|
|
37
|
-
SharedParametersList
|
37
|
+
SharedParametersList = TypeAliasType(
|
38
|
+
"SharedParametersList", list[tuple[nn.Parameter, int | float]]
|
39
|
+
)
|
38
40
|
|
39
41
|
|
40
42
|
@runtime_checkable
|
nshtrainer/configs/__init__.py
CHANGED
@@ -46,6 +46,9 @@ from nshtrainer.callbacks import (
|
|
46
46
|
from nshtrainer.callbacks import (
|
47
47
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
48
48
|
)
|
49
|
+
from nshtrainer.callbacks import (
|
50
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
51
|
+
)
|
49
52
|
from nshtrainer.callbacks import (
|
50
53
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
51
54
|
)
|
@@ -217,6 +220,7 @@ __all__ = [
|
|
217
220
|
"SwishNonlinearityConfig",
|
218
221
|
"TanhNonlinearityConfig",
|
219
222
|
"TensorboardLoggerConfig",
|
223
|
+
"TimeCheckpointCallbackConfig",
|
220
224
|
"TrainerConfig",
|
221
225
|
"WandbLoggerConfig",
|
222
226
|
"WandbUploadCodeCallbackConfig",
|
@@ -38,6 +38,9 @@ from nshtrainer.callbacks import (
|
|
38
38
|
from nshtrainer.callbacks import (
|
39
39
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
40
40
|
)
|
41
|
+
from nshtrainer.callbacks import (
|
42
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
43
|
+
)
|
41
44
|
from nshtrainer.callbacks import (
|
42
45
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
43
46
|
)
|
@@ -95,6 +98,7 @@ __all__ = [
|
|
95
98
|
"PrintTableMetricsCallbackConfig",
|
96
99
|
"RLPSanityChecksCallbackConfig",
|
97
100
|
"SharedParametersCallbackConfig",
|
101
|
+
"TimeCheckpointCallbackConfig",
|
98
102
|
"WandbUploadCodeCallbackConfig",
|
99
103
|
"WandbWatchCallbackConfig",
|
100
104
|
"actsave",
|
@@ -11,6 +11,9 @@ from nshtrainer.callbacks.checkpoint import (
|
|
11
11
|
from nshtrainer.callbacks.checkpoint import (
|
12
12
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
13
13
|
)
|
14
|
+
from nshtrainer.callbacks.checkpoint import (
|
15
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
16
|
+
)
|
14
17
|
from nshtrainer.callbacks.checkpoint._base import (
|
15
18
|
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
16
19
|
)
|
@@ -26,6 +29,7 @@ from . import _base as _base
|
|
26
29
|
from . import best_checkpoint as best_checkpoint
|
27
30
|
from . import last_checkpoint as last_checkpoint
|
28
31
|
from . import on_exception_checkpoint as on_exception_checkpoint
|
32
|
+
from . import time_checkpoint as time_checkpoint
|
29
33
|
|
30
34
|
__all__ = [
|
31
35
|
"BaseCheckpointCallbackConfig",
|
@@ -35,8 +39,10 @@ __all__ = [
|
|
35
39
|
"LastCheckpointCallbackConfig",
|
36
40
|
"MetricConfig",
|
37
41
|
"OnExceptionCheckpointCallbackConfig",
|
42
|
+
"TimeCheckpointCallbackConfig",
|
38
43
|
"_base",
|
39
44
|
"best_checkpoint",
|
40
45
|
"last_checkpoint",
|
41
46
|
"on_exception_checkpoint",
|
47
|
+
"time_checkpoint",
|
42
48
|
]
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
6
|
+
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
7
|
+
)
|
8
|
+
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
9
|
+
CheckpointMetadata as CheckpointMetadata,
|
10
|
+
)
|
11
|
+
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
12
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
13
|
+
)
|
14
|
+
|
15
|
+
__all__ = [
|
16
|
+
"BaseCheckpointCallbackConfig",
|
17
|
+
"CheckpointMetadata",
|
18
|
+
"TimeCheckpointCallbackConfig",
|
19
|
+
]
|
@@ -48,6 +48,9 @@ from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
|
|
48
48
|
from nshtrainer.trainer._config import (
|
49
49
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
50
50
|
)
|
51
|
+
from nshtrainer.trainer._config import (
|
52
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
53
|
+
)
|
51
54
|
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
52
55
|
|
53
56
|
from . import _config as _config
|
@@ -79,6 +82,7 @@ __all__ = [
|
|
79
82
|
"SharedParametersCallbackConfig",
|
80
83
|
"StrategyConfigBase",
|
81
84
|
"TensorboardLoggerConfig",
|
85
|
+
"TimeCheckpointCallbackConfig",
|
82
86
|
"TrainerConfig",
|
83
87
|
"WandbLoggerConfig",
|
84
88
|
"_config",
|
@@ -47,6 +47,9 @@ from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
|
|
47
47
|
from nshtrainer.trainer._config import (
|
48
48
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
49
49
|
)
|
50
|
+
from nshtrainer.trainer._config import (
|
51
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
52
|
+
)
|
50
53
|
from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
|
51
54
|
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
52
55
|
|
@@ -76,6 +79,7 @@ __all__ = [
|
|
76
79
|
"SharedParametersCallbackConfig",
|
77
80
|
"StrategyConfigBase",
|
78
81
|
"TensorboardLoggerConfig",
|
82
|
+
"TimeCheckpointCallbackConfig",
|
79
83
|
"TrainerConfig",
|
80
84
|
"WandbLoggerConfig",
|
81
85
|
]
|
nshtrainer/loggers/__init__.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import Annotated
|
3
|
+
from typing import Annotated
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
|
+
from typing_extensions import TypeAliasType
|
6
7
|
|
7
8
|
from ._base import BaseLoggerConfig as BaseLoggerConfig
|
8
9
|
from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
|
@@ -10,7 +11,13 @@ from .csv import CSVLoggerConfig as CSVLoggerConfig
|
|
10
11
|
from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
|
11
12
|
from .wandb import WandbLoggerConfig as WandbLoggerConfig
|
12
13
|
|
13
|
-
LoggerConfig
|
14
|
-
|
15
|
-
|
16
|
-
|
14
|
+
LoggerConfig = TypeAliasType(
|
15
|
+
"LoggerConfig",
|
16
|
+
Annotated[
|
17
|
+
CSVLoggerConfig
|
18
|
+
| TensorboardLoggerConfig
|
19
|
+
| WandbLoggerConfig
|
20
|
+
| ActSaveLoggerConfig,
|
21
|
+
C.Field(discriminator="name"),
|
22
|
+
],
|
23
|
+
)
|
@@ -1,8 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import Annotated
|
3
|
+
from typing import Annotated
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
|
+
from typing_extensions import TypeAliasType
|
6
7
|
|
7
8
|
from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
|
8
9
|
from ._base import LRSchedulerMetadata as LRSchedulerMetadata
|
@@ -15,7 +16,10 @@ from .linear_warmup_cosine import (
|
|
15
16
|
from .reduce_lr_on_plateau import ReduceLROnPlateau as ReduceLROnPlateau
|
16
17
|
from .reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
|
17
18
|
|
18
|
-
LRSchedulerConfig
|
19
|
-
|
20
|
-
|
21
|
-
|
19
|
+
LRSchedulerConfig = TypeAliasType(
|
20
|
+
"LRSchedulerConfig",
|
21
|
+
Annotated[
|
22
|
+
LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
|
23
|
+
C.Field(discriminator="name"),
|
24
|
+
],
|
25
|
+
)
|
@@ -2,18 +2,20 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
from collections.abc import Callable, Iterable, Sequence
|
5
|
-
from typing import Any,
|
5
|
+
from typing import Any, cast
|
6
6
|
|
7
7
|
from lightning.pytorch import Callback, LightningModule
|
8
|
-
from typing_extensions import override
|
8
|
+
from typing_extensions import TypeAliasType, override
|
9
9
|
|
10
10
|
from ..._callback import NTCallbackBase
|
11
11
|
from ...util.typing_utils import mixin_base_type
|
12
12
|
|
13
13
|
log = logging.getLogger(__name__)
|
14
14
|
|
15
|
-
_Callback = Callback | NTCallbackBase
|
16
|
-
CallbackFn
|
15
|
+
_Callback = TypeAliasType("_Callback", Callback | NTCallbackBase)
|
16
|
+
CallbackFn = TypeAliasType(
|
17
|
+
"CallbackFn", Callable[[], _Callback | Iterable[_Callback] | None]
|
18
|
+
)
|
17
19
|
|
18
20
|
|
19
21
|
class CallbackRegistrarModuleMixin:
|
nshtrainer/optimizer.py
CHANGED
@@ -2,12 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from collections.abc import Iterable
|
5
|
-
from typing import Annotated, Any, Literal
|
5
|
+
from typing import Annotated, Any, Literal
|
6
6
|
|
7
7
|
import nshconfig as C
|
8
8
|
import torch.nn as nn
|
9
9
|
from torch.optim import Optimizer
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import TypeAliasType, override
|
11
11
|
|
12
12
|
|
13
13
|
class OptimizerConfigBase(C.Config, ABC):
|
@@ -57,4 +57,6 @@ class AdamWConfig(OptimizerConfigBase):
|
|
57
57
|
)
|
58
58
|
|
59
59
|
|
60
|
-
OptimizerConfig
|
60
|
+
OptimizerConfig = TypeAliasType(
|
61
|
+
"OptimizerConfig", Annotated[AdamWConfig, C.Field(discriminator="name")]
|
62
|
+
)
|
nshtrainer/profiler/__init__.py
CHANGED
@@ -1,15 +1,19 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import Annotated
|
3
|
+
from typing import Annotated
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
|
+
from typing_extensions import TypeAliasType
|
6
7
|
|
7
8
|
from ._base import BaseProfilerConfig as BaseProfilerConfig
|
8
9
|
from .advanced import AdvancedProfilerConfig as AdvancedProfilerConfig
|
9
10
|
from .pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
|
10
11
|
from .simple import SimpleProfilerConfig as SimpleProfilerConfig
|
11
12
|
|
12
|
-
ProfilerConfig
|
13
|
-
|
14
|
-
|
15
|
-
|
13
|
+
ProfilerConfig = TypeAliasType(
|
14
|
+
"ProfilerConfig",
|
15
|
+
Annotated[
|
16
|
+
SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
|
17
|
+
C.Field(discriminator="name"),
|
18
|
+
],
|
19
|
+
)
|
nshtrainer/trainer/_config.py
CHANGED
@@ -14,7 +14,6 @@ from typing import (
|
|
14
14
|
Any,
|
15
15
|
ClassVar,
|
16
16
|
Literal,
|
17
|
-
TypeAlias,
|
18
17
|
)
|
19
18
|
|
20
19
|
import nshconfig as C
|
@@ -101,47 +100,53 @@ class StrategyConfigBase(C.Config, ABC):
|
|
101
100
|
strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
|
102
101
|
|
103
102
|
|
104
|
-
AcceleratorLiteral
|
105
|
-
"cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
|
106
|
-
|
107
|
-
|
108
|
-
StrategyLiteral
|
109
|
-
"
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
103
|
+
AcceleratorLiteral = TypeAliasType(
|
104
|
+
"AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
|
105
|
+
)
|
106
|
+
|
107
|
+
StrategyLiteral = TypeAliasType(
|
108
|
+
"StrategyLiteral",
|
109
|
+
Literal[
|
110
|
+
"auto",
|
111
|
+
"ddp",
|
112
|
+
"ddp_find_unused_parameters_false",
|
113
|
+
"ddp_find_unused_parameters_true",
|
114
|
+
"ddp_spawn",
|
115
|
+
"ddp_spawn_find_unused_parameters_false",
|
116
|
+
"ddp_spawn_find_unused_parameters_true",
|
117
|
+
"ddp_fork",
|
118
|
+
"ddp_fork_find_unused_parameters_false",
|
119
|
+
"ddp_fork_find_unused_parameters_true",
|
120
|
+
"ddp_notebook",
|
121
|
+
"dp",
|
122
|
+
"deepspeed",
|
123
|
+
"deepspeed_stage_1",
|
124
|
+
"deepspeed_stage_1_offload",
|
125
|
+
"deepspeed_stage_2",
|
126
|
+
"deepspeed_stage_2_offload",
|
127
|
+
"deepspeed_stage_3",
|
128
|
+
"deepspeed_stage_3_offload",
|
129
|
+
"deepspeed_stage_3_offload_nvme",
|
130
|
+
"fsdp",
|
131
|
+
"fsdp_cpu_offload",
|
132
|
+
"single_xla",
|
133
|
+
"xla_fsdp",
|
134
|
+
"xla",
|
135
|
+
"single_tpu",
|
136
|
+
],
|
137
|
+
)
|
138
|
+
|
139
|
+
|
140
|
+
CheckpointCallbackConfig = TypeAliasType(
|
141
|
+
"CheckpointCallbackConfig",
|
142
|
+
Annotated[
|
143
|
+
BestCheckpointCallbackConfig
|
144
|
+
| LastCheckpointCallbackConfig
|
145
|
+
| OnExceptionCheckpointCallbackConfig
|
146
|
+
| TimeCheckpointCallbackConfig,
|
147
|
+
C.Field(discriminator="name"),
|
148
|
+
],
|
149
|
+
)
|
145
150
|
|
146
151
|
|
147
152
|
class CheckpointSavingConfig(CallbackConfigBase):
|
@@ -4,14 +4,14 @@ import datetime
|
|
4
4
|
import logging
|
5
5
|
import time
|
6
6
|
from dataclasses import dataclass
|
7
|
-
from typing import Any, Literal
|
7
|
+
from typing import Any, Literal
|
8
8
|
|
9
9
|
from lightning.pytorch.callbacks.callback import Callback
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import TypeAliasType, override
|
11
11
|
|
12
12
|
log = logging.getLogger(__name__)
|
13
13
|
|
14
|
-
Stage
|
14
|
+
Stage = TypeAliasType("Stage", Literal["train", "validate", "test", "predict"])
|
15
15
|
ALL_STAGES = ("train", "validate", "test", "predict")
|
16
16
|
|
17
17
|
|
@@ -12,7 +12,7 @@ from collections import defaultdict
|
|
12
12
|
from collections.abc import Callable
|
13
13
|
from pathlib import Path
|
14
14
|
from types import FrameType
|
15
|
-
from typing import Any
|
15
|
+
from typing import Any
|
16
16
|
|
17
17
|
import nshrunner as nr
|
18
18
|
import torch.utils.data
|
@@ -22,12 +22,14 @@ from lightning.pytorch.trainer.connectors.signal_connector import _HandlersCompo
|
|
22
22
|
from lightning.pytorch.trainer.connectors.signal_connector import (
|
23
23
|
_SignalConnector as _LightningSignalConnector,
|
24
24
|
)
|
25
|
-
from typing_extensions import override
|
25
|
+
from typing_extensions import TypeAliasType, override
|
26
26
|
|
27
27
|
log = logging.getLogger(__name__)
|
28
28
|
|
29
|
-
_SIGNUM = int | signal.Signals
|
30
|
-
_HANDLER
|
29
|
+
_SIGNUM = TypeAliasType("_SIGNUM", int | signal.Signals)
|
30
|
+
_HANDLER = TypeAliasType(
|
31
|
+
"_HANDLER", Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
|
32
|
+
)
|
31
33
|
_IS_WINDOWS = platform.system() == "Windows"
|
32
34
|
|
33
35
|
|
nshtrainer/util/_useful_types.py
CHANGED
@@ -7,7 +7,14 @@ from collections.abc import Set as AbstractSet
|
|
7
7
|
from os import PathLike
|
8
8
|
from typing import Any, TypeVar, overload
|
9
9
|
|
10
|
-
from typing_extensions import
|
10
|
+
from typing_extensions import (
|
11
|
+
Buffer,
|
12
|
+
Literal,
|
13
|
+
Protocol,
|
14
|
+
SupportsIndex,
|
15
|
+
TypeAlias,
|
16
|
+
TypeAliasType,
|
17
|
+
)
|
11
18
|
|
12
19
|
_KT = TypeVar("_KT")
|
13
20
|
_KT_co = TypeVar("_KT_co", covariant=True)
|
@@ -60,7 +67,9 @@ class SupportsAllComparisons(
|
|
60
67
|
): ...
|
61
68
|
|
62
69
|
|
63
|
-
SupportsRichComparison
|
70
|
+
SupportsRichComparison = TypeAliasType(
|
71
|
+
"SupportsRichComparison", SupportsDunderLT[Any] | SupportsDunderGT[Any]
|
72
|
+
)
|
64
73
|
SupportsRichComparisonT = TypeVar(
|
65
74
|
"SupportsRichComparisonT", bound=SupportsRichComparison
|
66
75
|
)
|
nshtrainer/util/config/dtype.py
CHANGED
@@ -1,57 +1,60 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import TYPE_CHECKING, Literal
|
3
|
+
from typing import TYPE_CHECKING, Literal
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
6
|
import torch
|
7
|
-
from typing_extensions import assert_never
|
7
|
+
from typing_extensions import TypeAliasType, assert_never
|
8
8
|
|
9
9
|
from ..bf16 import is_bf16_supported_no_emulation
|
10
10
|
|
11
11
|
if TYPE_CHECKING:
|
12
12
|
from ...trainer._config import TrainerConfig
|
13
13
|
|
14
|
-
DTypeName
|
15
|
-
"
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
14
|
+
DTypeName = TypeAliasType(
|
15
|
+
"DTypeName",
|
16
|
+
Literal[
|
17
|
+
"float32",
|
18
|
+
"float",
|
19
|
+
"float64",
|
20
|
+
"double",
|
21
|
+
"float16",
|
22
|
+
"bfloat16",
|
23
|
+
"float8_e4m3fn",
|
24
|
+
"float8_e4m3fnuz",
|
25
|
+
"float8_e5m2",
|
26
|
+
"float8_e5m2fnuz",
|
27
|
+
"half",
|
28
|
+
"uint8",
|
29
|
+
"uint16",
|
30
|
+
"uint32",
|
31
|
+
"uint64",
|
32
|
+
"int8",
|
33
|
+
"int16",
|
34
|
+
"short",
|
35
|
+
"int32",
|
36
|
+
"int",
|
37
|
+
"int64",
|
38
|
+
"long",
|
39
|
+
"complex32",
|
40
|
+
"complex64",
|
41
|
+
"chalf",
|
42
|
+
"cfloat",
|
43
|
+
"complex128",
|
44
|
+
"cdouble",
|
45
|
+
"quint8",
|
46
|
+
"qint8",
|
47
|
+
"qint32",
|
48
|
+
"bool",
|
49
|
+
"quint4x2",
|
50
|
+
"quint2x4",
|
51
|
+
"bits1x8",
|
52
|
+
"bits2x4",
|
53
|
+
"bits4x2",
|
54
|
+
"bits8",
|
55
|
+
"bits16",
|
56
|
+
],
|
57
|
+
)
|
55
58
|
|
56
59
|
|
57
60
|
class DTypeConfig(C.Config):
|
nshtrainer/util/path.py
CHANGED
@@ -6,11 +6,12 @@ import os
|
|
6
6
|
import platform
|
7
7
|
import shutil
|
8
8
|
from pathlib import Path
|
9
|
-
|
9
|
+
|
10
|
+
from typing_extensions import TypeAliasType
|
10
11
|
|
11
12
|
log = logging.getLogger(__name__)
|
12
13
|
|
13
|
-
_Path
|
14
|
+
_Path = TypeAliasType("_Path", str | Path)
|
14
15
|
|
15
16
|
|
16
17
|
def get_relative_path(source: _Path, destination: _Path):
|
@@ -7,8 +7,8 @@ nshtrainer/_directory.py,sha256=p2uk1FnISFEpMqlDevKhoWhQsCEtvHUPg459K-86QA8,3053
|
|
7
7
|
nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
|
8
8
|
nshtrainer/_hf_hub.py,sha256=xj91CEKcuvpaiipZbVZovZiU_fdsTkPXOkQ-3Xb-FhU,14183
|
9
9
|
nshtrainer/callbacks/__init__.py,sha256=z2uDaUVzNXG79UOAYGpwC0hxk__LTbEBVhXJ3W-Srws,3911
|
10
|
-
nshtrainer/callbacks/actsave.py,sha256=
|
11
|
-
nshtrainer/callbacks/base.py,sha256=
|
10
|
+
nshtrainer/callbacks/actsave.py,sha256=bnqS3y9sit1hfUvqPx-WWRug2OeTHPM0PDEJaex7f3Q,3776
|
11
|
+
nshtrainer/callbacks/base.py,sha256=AFqvKNzH710xviYQ7X0hw0M7ETWRhBAWv3PN9WnrZw0,3608
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=09cwgAJawTFLdzZwYhg_jVgbW-1d09hwjHdI-PQRck0,797
|
13
13
|
nshtrainer/callbacks/checkpoint/_base.py,sha256=ZVEUVl5kjCSSe69Q0rMUbKBNNUog0pxBwWkeyuxG2w0,6304
|
14
14
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=hcvtpNvuyWxxP36XiGyFYwVsSwmHphKD-1IA95DinME,2602
|
@@ -27,23 +27,24 @@ nshtrainer/callbacks/lr_monitor.py,sha256=IyFZoXaxJoTBSkdLu1iEZ1qI8_UFNJwafR_xTV
|
|
27
27
|
nshtrainer/callbacks/norm_logging.py,sha256=C44Mvt73gqQEpCFd0j3qYg6NY7sL2jm3X1qJVY_XLfI,6329
|
28
28
|
nshtrainer/callbacks/print_table.py,sha256=lS49Hz0OLcv3VPxEfLBguwe57y2nmKg0pMF6HJxuJio,2974
|
29
29
|
nshtrainer/callbacks/rlp_sanity_checks.py,sha256=kWl2dYOXn2L8k6ub_012jNkqOxtyea1yr1qWRNG6UW4,9990
|
30
|
-
nshtrainer/callbacks/shared_parameters.py,sha256=
|
30
|
+
nshtrainer/callbacks/shared_parameters.py,sha256=ggMI1krkqN7sGOrjK_I96IsTMYMXHoVQm7W90LZb9so,3015
|
31
31
|
nshtrainer/callbacks/timer.py,sha256=BB-M7tV4QNYOwY_Su6j9P7IILxVRae_upmDq4qsxiao,4670
|
32
32
|
nshtrainer/callbacks/wandb_upload_code.py,sha256=PTqNE1QB5U8NR5zhbiQZrmQuugX2UV7B12UdMpo9aV0,2353
|
33
33
|
nshtrainer/callbacks/wandb_watch.py,sha256=tTTcFzxd2Ia9xu8tCogQ5CLJZBq1ne5JlpGVE75vKYs,2976
|
34
|
-
nshtrainer/configs/__init__.py,sha256=
|
34
|
+
nshtrainer/configs/__init__.py,sha256=crQ0IcNkmXx-dUEyKMg9mhzxWiBYppba6UOKzEgbWzo,9820
|
35
35
|
nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
|
36
36
|
nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
|
37
37
|
nshtrainer/configs/_directory/__init__.py,sha256=7H3fIh9c31ce0r8JpuzEY8bZptI7tiVLNwVtj729HAY,303
|
38
38
|
nshtrainer/configs/_hf_hub/__init__.py,sha256=VUgQnyEI2ekBxBIV15L09tKdrfGt7eWxnf30DiCLaso,416
|
39
|
-
nshtrainer/configs/callbacks/__init__.py,sha256=
|
39
|
+
nshtrainer/configs/callbacks/__init__.py,sha256=TwXL-GkDa1j3m1GEfIJ-YaBqazm9wm1uQpzUd6135cA,4265
|
40
40
|
nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JJg9d8iNGpO-9M1LsK4h1cu3NYWniyIyLQ4SauFCzOs,272
|
41
41
|
nshtrainer/configs/callbacks/base/__init__.py,sha256=V694hzF_ubnA-hwTps30PeFbgDSm3I_UIMTnljM3_OI,176
|
42
|
-
nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=
|
42
|
+
nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=zPUItxoYWrMT9i1TxOvVhIeTa0NEFg-nDE5FjHfkP-A,1564
|
43
43
|
nshtrainer/configs/callbacks/checkpoint/_base/__init__.py,sha256=5jl6A5Gv6arZXmHV6lz5dQ8DL6PdJIfJqHLP4acClKQ,479
|
44
44
|
nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py,sha256=_KTkF3_Yx0WiwOvRf2s1KRvod2dryeGJITVkx10YqBE,648
|
45
45
|
nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py,sha256=sH4aJyovCeT6h4xz9r5WfVA0eviJur4zZt7R0hQsyAk,539
|
46
46
|
nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py,sha256=iUoCrTJpvDGEYCYfNCpkficH6D7yq119hyxy1qeFSGU,410
|
47
|
+
nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py,sha256=ZQz5gXKSHz7xvnZbgU_bauNhAJDh9rblBA0ldkSsdwg,539
|
47
48
|
nshtrainer/configs/callbacks/debug_flag/__init__.py,sha256=gPC3EAqzuyP2hAcCf3s09sPDe7q_02S1eUCWvRTNKrI,317
|
48
49
|
nshtrainer/configs/callbacks/directory_setup/__init__.py,sha256=25551zMMctAkzcLEGBN7HSeQUIrBtbBq7whgPZjtepY,351
|
49
50
|
nshtrainer/configs/callbacks/early_stopping/__init__.py,sha256=Q-hAAIcucLNriw5PXIshgBv7Yr5sMDk4GwjTIDdWBxo,434
|
@@ -80,8 +81,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
|
|
80
81
|
nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
|
81
82
|
nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
|
82
83
|
nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
|
83
|
-
nshtrainer/configs/trainer/__init__.py,sha256=
|
84
|
-
nshtrainer/configs/trainer/_config/__init__.py,sha256=
|
84
|
+
nshtrainer/configs/trainer/__init__.py,sha256=jRbJylnfPa483iHzN-ZYDObInUfMxuql47gFMsBlJKU,3527
|
85
|
+
nshtrainer/configs/trainer/_config/__init__.py,sha256=tAWlUTtn2EeQ8xnKft4CA4gVwXn2nf9Yrs57em8jz70,3438
|
85
86
|
nshtrainer/configs/trainer/trainer/__init__.py,sha256=DDuBRx0kVNMW0z_sqKTUt8-Ql7bOpargi4KcHHvDu_c,486
|
86
87
|
nshtrainer/configs/util/__init__.py,sha256=gtYtZ4VGwEvF9_hByZl8CWOSeDpEOIkkcLtUwvNbSEQ,2014
|
87
88
|
nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
|
@@ -92,13 +93,13 @@ nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,2
|
|
92
93
|
nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
|
93
94
|
nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk,4138
|
94
95
|
nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
|
95
|
-
nshtrainer/loggers/__init__.py,sha256
|
96
|
+
nshtrainer/loggers/__init__.py,sha256=-y8B-9TF6vJdZUQewJNDcZ2aOv04FEUFtKwaiDobIO0,670
|
96
97
|
nshtrainer/loggers/_base.py,sha256=nw4AZzJP3Z-fljgQlgq7FkuMkPmYKTsXj7OfJJSmtXI,811
|
97
98
|
nshtrainer/loggers/actsave.py,sha256=23Kre-mq-Y9Iw1SRyGmHnHK1bc_0gTWFXJViv9bkVz0,1324
|
98
99
|
nshtrainer/loggers/csv.py,sha256=Deh5gm3oROJbQzigV4SHni5JRwSrBdm-4YD3yrcGnHo,1104
|
99
100
|
nshtrainer/loggers/tensorboard.py,sha256=jP9V4nlq_MXUaoD6xv1Cws2ioft83Lm8yUJhhGhuUrQ,2268
|
100
101
|
nshtrainer/loggers/wandb.py,sha256=EjKQQznLSUSCWO7uIviz9g0dVW4ZLxb_8UVhY4vR7r0,6800
|
101
|
-
nshtrainer/lr_scheduler/__init__.py,sha256=
|
102
|
+
nshtrainer/lr_scheduler/__init__.py,sha256=BGnO-okUTZOtF15-UmQ05U4oEatSF5VNs3YeidNEWn4,853
|
102
103
|
nshtrainer/lr_scheduler/_base.py,sha256=EhA2f_WiZ79RcXL2nJbwCwNK620c8ugEVUmJ8CcVpsE,3723
|
103
104
|
nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=gvUuv031lvWdXboDeH7iAF3ZgNPQK40bQwfmqb11TNk,5492
|
104
105
|
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=vXH5S26ESHO_LPPqW8aDC3S5NGoZYkXeFjAOgttaUX8,2870
|
@@ -106,7 +107,7 @@ nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMk
|
|
106
107
|
nshtrainer/metrics/_config.py,sha256=kR9OJLZXIaZpG8UpHG3EQwOL36HB8yPk6afYjoIM0XM,1324
|
107
108
|
nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
|
108
109
|
nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,10356
|
109
|
-
nshtrainer/model/mixins/callback.py,sha256=
|
110
|
+
nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
|
110
111
|
nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
|
111
112
|
nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
|
112
113
|
nshtrainer/nn/__init__.py,sha256=sANhrZpeN5syLKOsmXMwhaFl2SBFPWcLaEe1EH22TWQ,1463
|
@@ -114,29 +115,29 @@ nshtrainer/nn/mlp.py,sha256=2W8bzE96DzCMzGm6WPiPhNFQfhqaoG3GXPn_oKBnlUM,5988
|
|
114
115
|
nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
|
115
116
|
nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
|
116
117
|
nshtrainer/nn/nonlinearity.py,sha256=mp5XvXRHURB6jwuZ0YyTj5ZoHJYNJNgO2aLtUY1D-2Y,6114
|
117
|
-
nshtrainer/optimizer.py,sha256=
|
118
|
-
nshtrainer/profiler/__init__.py,sha256=
|
118
|
+
nshtrainer/optimizer.py,sha256=wmSRpSoU59rstj2RBoifQ15ZwRInYpm0tDBQZ1gqOfE,1596
|
119
|
+
nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
|
119
120
|
nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
|
120
121
|
nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
|
121
122
|
nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
|
122
123
|
nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
|
123
124
|
nshtrainer/trainer/__init__.py,sha256=MmoydVS6aYeav7zgDAUHxAQrV_PMQsbnZTCuPnLH9Wk,128
|
124
|
-
nshtrainer/trainer/_config.py,sha256=
|
125
|
-
nshtrainer/trainer/_runtime_callback.py,sha256=
|
126
|
-
nshtrainer/trainer/signal_connector.py,sha256=
|
125
|
+
nshtrainer/trainer/_config.py,sha256=Mz9J2ZFqxTlttnRA1eScGRgSAuf3-o3i9-xjN7eTm-k,35256
|
126
|
+
nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
|
127
|
+
nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
|
127
128
|
nshtrainer/trainer/trainer.py,sha256=HHqT83zWtYY9g5yD6X9aWrVh5VSpILW8PhoE6fp4snE,20734
|
128
129
|
nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
|
129
|
-
nshtrainer/util/_useful_types.py,sha256=
|
130
|
+
nshtrainer/util/_useful_types.py,sha256=7yd1ajSmjwfmZdBPlHVrIG3iXl1-T3n83JI53N8C7as,8080
|
130
131
|
nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
|
131
132
|
nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
|
132
|
-
nshtrainer/util/config/dtype.py,sha256=
|
133
|
+
nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
|
133
134
|
nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
|
134
135
|
nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
|
135
|
-
nshtrainer/util/path.py,sha256=
|
136
|
+
nshtrainer/util/path.py,sha256=L-Nh9tlXSUfoP19TFbQq8I0AfS5ugCfGYTYFeddDHcs,3516
|
136
137
|
nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
137
138
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
138
139
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
139
140
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
140
|
-
nshtrainer-1.0.
|
141
|
-
nshtrainer-1.0.
|
142
|
-
nshtrainer-1.0.
|
141
|
+
nshtrainer-1.0.0b27.dist-info/METADATA,sha256=-eLqorpTOufpf0XwVyHmk2_nsgI3NdETpNPYa3uhHy0,988
|
142
|
+
nshtrainer-1.0.0b27.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
143
|
+
nshtrainer-1.0.0b27.dist-info/RECORD,,
|
File without changes
|