nshtrainer 1.0.0b26__py3-none-any.whl → 1.0.0b28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/actsave.py +2 -2
- nshtrainer/callbacks/base.py +5 -3
- nshtrainer/callbacks/shared_parameters.py +5 -3
- nshtrainer/configs/__init__.py +22 -0
- nshtrainer/configs/_directory/__init__.py +2 -0
- nshtrainer/configs/callbacks/__init__.py +6 -0
- nshtrainer/configs/callbacks/checkpoint/__init__.py +6 -0
- nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +19 -0
- nshtrainer/configs/loggers/__init__.py +2 -0
- nshtrainer/configs/lr_scheduler/__init__.py +6 -0
- nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +4 -0
- nshtrainer/configs/nn/__init__.py +2 -0
- nshtrainer/configs/nn/mlp/__init__.py +2 -0
- nshtrainer/configs/nn/nonlinearity/__init__.py +2 -0
- nshtrainer/configs/optimizer/__init__.py +2 -0
- nshtrainer/configs/profiler/__init__.py +2 -0
- nshtrainer/configs/trainer/__init__.py +14 -0
- nshtrainer/configs/trainer/_config/__init__.py +14 -0
- nshtrainer/configs/util/__init__.py +2 -0
- nshtrainer/configs/util/config/__init__.py +2 -0
- nshtrainer/configs/util/config/duration/__init__.py +2 -0
- nshtrainer/loggers/__init__.py +12 -5
- nshtrainer/lr_scheduler/__init__.py +9 -5
- nshtrainer/model/mixins/callback.py +6 -4
- nshtrainer/optimizer.py +5 -3
- nshtrainer/profiler/__init__.py +9 -5
- nshtrainer/trainer/_config.py +47 -42
- nshtrainer/trainer/_runtime_callback.py +3 -3
- nshtrainer/trainer/signal_connector.py +6 -4
- nshtrainer/util/_useful_types.py +11 -2
- nshtrainer/util/config/dtype.py +46 -43
- nshtrainer/util/path.py +3 -2
- {nshtrainer-1.0.0b26.dist-info → nshtrainer-1.0.0b28.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b26.dist-info → nshtrainer-1.0.0b28.dist-info}/RECORD +35 -34
- {nshtrainer-1.0.0b26.dist-info → nshtrainer-1.0.0b28.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/actsave.py
CHANGED
@@ -4,12 +4,12 @@ import contextlib
|
|
4
4
|
from pathlib import Path
|
5
5
|
from typing import Literal
|
6
6
|
|
7
|
-
from typing_extensions import
|
7
|
+
from typing_extensions import TypeAliasType, override
|
8
8
|
|
9
9
|
from .._callback import NTCallbackBase
|
10
10
|
from .base import CallbackConfigBase
|
11
11
|
|
12
|
-
Stage
|
12
|
+
Stage = TypeAliasType("Stage", Literal["train", "validation", "test", "predict"])
|
13
13
|
|
14
14
|
|
15
15
|
class ActSaveConfig(CallbackConfigBase):
|
nshtrainer/callbacks/base.py
CHANGED
@@ -4,11 +4,11 @@ from abc import ABC, abstractmethod
|
|
4
4
|
from collections import Counter
|
5
5
|
from collections.abc import Iterable
|
6
6
|
from dataclasses import dataclass
|
7
|
-
from typing import TYPE_CHECKING, ClassVar
|
7
|
+
from typing import TYPE_CHECKING, ClassVar
|
8
8
|
|
9
9
|
import nshconfig as C
|
10
10
|
from lightning.pytorch import Callback
|
11
|
-
from typing_extensions import TypedDict, Unpack
|
11
|
+
from typing_extensions import TypeAliasType, TypedDict, Unpack
|
12
12
|
|
13
13
|
if TYPE_CHECKING:
|
14
14
|
from ..trainer._config import TrainerConfig
|
@@ -30,7 +30,9 @@ class CallbackWithMetadata:
|
|
30
30
|
metadata: CallbackMetadataConfig
|
31
31
|
|
32
32
|
|
33
|
-
ConstructedCallback
|
33
|
+
ConstructedCallback = TypeAliasType(
|
34
|
+
"ConstructedCallback", Callback | CallbackWithMetadata
|
35
|
+
)
|
34
36
|
|
35
37
|
|
36
38
|
class CallbackConfigBase(C.Config, ABC):
|
@@ -2,12 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
from collections.abc import Iterable
|
5
|
-
from typing import Literal, Protocol,
|
5
|
+
from typing import Literal, Protocol, runtime_checkable
|
6
6
|
|
7
7
|
import torch.nn as nn
|
8
8
|
from lightning.pytorch import LightningModule, Trainer
|
9
9
|
from lightning.pytorch.callbacks import Callback
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import TypeAliasType, override
|
11
11
|
|
12
12
|
from .base import CallbackConfigBase
|
13
13
|
|
@@ -34,7 +34,9 @@ class SharedParametersCallbackConfig(CallbackConfigBase):
|
|
34
34
|
yield SharedParametersCallback(self)
|
35
35
|
|
36
36
|
|
37
|
-
SharedParametersList
|
37
|
+
SharedParametersList = TypeAliasType(
|
38
|
+
"SharedParametersList", list[tuple[nn.Parameter, int | float]]
|
39
|
+
)
|
38
40
|
|
39
41
|
|
40
42
|
@runtime_checkable
|
nshtrainer/configs/__init__.py
CHANGED
@@ -14,6 +14,7 @@ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
|
|
14
14
|
from nshtrainer.callbacks import (
|
15
15
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
16
16
|
)
|
17
|
+
from nshtrainer.callbacks import CallbackConfig as CallbackConfig
|
17
18
|
from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
18
19
|
from nshtrainer.callbacks import (
|
19
20
|
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
@@ -46,6 +47,9 @@ from nshtrainer.callbacks import (
|
|
46
47
|
from nshtrainer.callbacks import (
|
47
48
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
48
49
|
)
|
50
|
+
from nshtrainer.callbacks import (
|
51
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
52
|
+
)
|
49
53
|
from nshtrainer.callbacks import (
|
50
54
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
51
55
|
)
|
@@ -57,11 +61,13 @@ from nshtrainer.callbacks.checkpoint._base import (
|
|
57
61
|
from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
|
58
62
|
from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
|
59
63
|
from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
|
64
|
+
from nshtrainer.loggers import LoggerConfig as LoggerConfig
|
60
65
|
from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
|
61
66
|
from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
|
62
67
|
from nshtrainer.lr_scheduler import (
|
63
68
|
LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
|
64
69
|
)
|
70
|
+
from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
|
65
71
|
from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
|
66
72
|
from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
|
67
73
|
from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
|
@@ -70,6 +76,7 @@ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
|
|
70
76
|
from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
|
71
77
|
from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
|
72
78
|
from nshtrainer.nn import MLPConfig as MLPConfig
|
79
|
+
from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
|
73
80
|
from nshtrainer.nn import PReLUConfig as PReLUConfig
|
74
81
|
from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
75
82
|
from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
|
@@ -83,12 +90,17 @@ from nshtrainer.nn.nonlinearity import (
|
|
83
90
|
SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
|
84
91
|
)
|
85
92
|
from nshtrainer.optimizer import AdamWConfig as AdamWConfig
|
93
|
+
from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
|
86
94
|
from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
|
87
95
|
from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
|
88
96
|
from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
|
97
|
+
from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
|
89
98
|
from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
|
90
99
|
from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
|
91
100
|
from nshtrainer.trainer._config import AcceleratorConfigBase as AcceleratorConfigBase
|
101
|
+
from nshtrainer.trainer._config import (
|
102
|
+
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
103
|
+
)
|
92
104
|
from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
|
93
105
|
from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
|
94
106
|
from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
|
@@ -127,6 +139,7 @@ from nshtrainer.util._environment_info import (
|
|
127
139
|
)
|
128
140
|
from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
|
129
141
|
from nshtrainer.util.config import DTypeConfig as DTypeConfig
|
142
|
+
from nshtrainer.util.config import DurationConfig as DurationConfig
|
130
143
|
from nshtrainer.util.config import EpochsConfig as EpochsConfig
|
131
144
|
from nshtrainer.util.config import StepsConfig as StepsConfig
|
132
145
|
|
@@ -155,13 +168,16 @@ __all__ = [
|
|
155
168
|
"BaseProfilerConfig",
|
156
169
|
"BestCheckpointCallbackConfig",
|
157
170
|
"CSVLoggerConfig",
|
171
|
+
"CallbackConfig",
|
158
172
|
"CallbackConfigBase",
|
173
|
+
"CheckpointCallbackConfig",
|
159
174
|
"CheckpointMetadata",
|
160
175
|
"CheckpointSavingConfig",
|
161
176
|
"DTypeConfig",
|
162
177
|
"DebugFlagCallbackConfig",
|
163
178
|
"DirectoryConfig",
|
164
179
|
"DirectorySetupCallbackConfig",
|
180
|
+
"DurationConfig",
|
165
181
|
"ELUNonlinearityConfig",
|
166
182
|
"EMACallbackConfig",
|
167
183
|
"EarlyStoppingCallbackConfig",
|
@@ -184,21 +200,26 @@ __all__ = [
|
|
184
200
|
"GradientSkippingCallbackConfig",
|
185
201
|
"HuggingFaceHubAutoCreateConfig",
|
186
202
|
"HuggingFaceHubConfig",
|
203
|
+
"LRSchedulerConfig",
|
187
204
|
"LRSchedulerConfigBase",
|
188
205
|
"LastCheckpointCallbackConfig",
|
189
206
|
"LeakyReLUNonlinearityConfig",
|
190
207
|
"LearningRateMonitorConfig",
|
191
208
|
"LinearWarmupCosineDecayLRSchedulerConfig",
|
192
209
|
"LogEpochCallbackConfig",
|
210
|
+
"LoggerConfig",
|
193
211
|
"MLPConfig",
|
194
212
|
"MetricConfig",
|
195
213
|
"MishNonlinearityConfig",
|
214
|
+
"NonlinearityConfig",
|
196
215
|
"NormLoggingCallbackConfig",
|
197
216
|
"OnExceptionCheckpointCallbackConfig",
|
217
|
+
"OptimizerConfig",
|
198
218
|
"OptimizerConfigBase",
|
199
219
|
"PReLUConfig",
|
200
220
|
"PluginConfigBase",
|
201
221
|
"PrintTableMetricsCallbackConfig",
|
222
|
+
"ProfilerConfig",
|
202
223
|
"PyTorchProfilerConfig",
|
203
224
|
"RLPSanityChecksCallbackConfig",
|
204
225
|
"ReLUNonlinearityConfig",
|
@@ -217,6 +238,7 @@ __all__ = [
|
|
217
238
|
"SwishNonlinearityConfig",
|
218
239
|
"TanhNonlinearityConfig",
|
219
240
|
"TensorboardLoggerConfig",
|
241
|
+
"TimeCheckpointCallbackConfig",
|
220
242
|
"TrainerConfig",
|
221
243
|
"WandbLoggerConfig",
|
222
244
|
"WandbUploadCodeCallbackConfig",
|
@@ -6,8 +6,10 @@ from nshtrainer._directory import DirectoryConfig as DirectoryConfig
|
|
6
6
|
from nshtrainer._directory import (
|
7
7
|
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
8
8
|
)
|
9
|
+
from nshtrainer._directory import LoggerConfig as LoggerConfig
|
9
10
|
|
10
11
|
__all__ = [
|
11
12
|
"DirectoryConfig",
|
12
13
|
"DirectorySetupCallbackConfig",
|
14
|
+
"LoggerConfig",
|
13
15
|
]
|
@@ -5,6 +5,7 @@ __codegen__ = True
|
|
5
5
|
from nshtrainer.callbacks import (
|
6
6
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
7
7
|
)
|
8
|
+
from nshtrainer.callbacks import CallbackConfig as CallbackConfig
|
8
9
|
from nshtrainer.callbacks import CallbackConfigBase as CallbackConfigBase
|
9
10
|
from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
10
11
|
from nshtrainer.callbacks import (
|
@@ -38,6 +39,9 @@ from nshtrainer.callbacks import (
|
|
38
39
|
from nshtrainer.callbacks import (
|
39
40
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
40
41
|
)
|
42
|
+
from nshtrainer.callbacks import (
|
43
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
44
|
+
)
|
41
45
|
from nshtrainer.callbacks import (
|
42
46
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
43
47
|
)
|
@@ -77,6 +81,7 @@ __all__ = [
|
|
77
81
|
"ActSaveConfig",
|
78
82
|
"BaseCheckpointCallbackConfig",
|
79
83
|
"BestCheckpointCallbackConfig",
|
84
|
+
"CallbackConfig",
|
80
85
|
"CallbackConfigBase",
|
81
86
|
"CheckpointMetadata",
|
82
87
|
"DebugFlagCallbackConfig",
|
@@ -95,6 +100,7 @@ __all__ = [
|
|
95
100
|
"PrintTableMetricsCallbackConfig",
|
96
101
|
"RLPSanityChecksCallbackConfig",
|
97
102
|
"SharedParametersCallbackConfig",
|
103
|
+
"TimeCheckpointCallbackConfig",
|
98
104
|
"WandbUploadCodeCallbackConfig",
|
99
105
|
"WandbWatchCallbackConfig",
|
100
106
|
"actsave",
|
@@ -11,6 +11,9 @@ from nshtrainer.callbacks.checkpoint import (
|
|
11
11
|
from nshtrainer.callbacks.checkpoint import (
|
12
12
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
13
13
|
)
|
14
|
+
from nshtrainer.callbacks.checkpoint import (
|
15
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
16
|
+
)
|
14
17
|
from nshtrainer.callbacks.checkpoint._base import (
|
15
18
|
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
16
19
|
)
|
@@ -26,6 +29,7 @@ from . import _base as _base
|
|
26
29
|
from . import best_checkpoint as best_checkpoint
|
27
30
|
from . import last_checkpoint as last_checkpoint
|
28
31
|
from . import on_exception_checkpoint as on_exception_checkpoint
|
32
|
+
from . import time_checkpoint as time_checkpoint
|
29
33
|
|
30
34
|
__all__ = [
|
31
35
|
"BaseCheckpointCallbackConfig",
|
@@ -35,8 +39,10 @@ __all__ = [
|
|
35
39
|
"LastCheckpointCallbackConfig",
|
36
40
|
"MetricConfig",
|
37
41
|
"OnExceptionCheckpointCallbackConfig",
|
42
|
+
"TimeCheckpointCallbackConfig",
|
38
43
|
"_base",
|
39
44
|
"best_checkpoint",
|
40
45
|
"last_checkpoint",
|
41
46
|
"on_exception_checkpoint",
|
47
|
+
"time_checkpoint",
|
42
48
|
]
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
6
|
+
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
7
|
+
)
|
8
|
+
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
9
|
+
CheckpointMetadata as CheckpointMetadata,
|
10
|
+
)
|
11
|
+
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
12
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
13
|
+
)
|
14
|
+
|
15
|
+
__all__ = [
|
16
|
+
"BaseCheckpointCallbackConfig",
|
17
|
+
"CheckpointMetadata",
|
18
|
+
"TimeCheckpointCallbackConfig",
|
19
|
+
]
|
@@ -5,6 +5,7 @@ __codegen__ = True
|
|
5
5
|
from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
|
6
6
|
from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
|
7
7
|
from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
|
8
|
+
from nshtrainer.loggers import LoggerConfig as LoggerConfig
|
8
9
|
from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
|
9
10
|
from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
|
10
11
|
from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
|
@@ -26,6 +27,7 @@ __all__ = [
|
|
26
27
|
"BaseLoggerConfig",
|
27
28
|
"CSVLoggerConfig",
|
28
29
|
"CallbackConfigBase",
|
30
|
+
"LoggerConfig",
|
29
31
|
"TensorboardLoggerConfig",
|
30
32
|
"WandbLoggerConfig",
|
31
33
|
"WandbUploadCodeCallbackConfig",
|
@@ -5,8 +5,12 @@ __codegen__ = True
|
|
5
5
|
from nshtrainer.lr_scheduler import (
|
6
6
|
LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
|
7
7
|
)
|
8
|
+
from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
|
8
9
|
from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
|
9
10
|
from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
|
11
|
+
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
12
|
+
DurationConfig as DurationConfig,
|
13
|
+
)
|
10
14
|
from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
|
11
15
|
|
12
16
|
from . import _base as _base
|
@@ -14,6 +18,8 @@ from . import linear_warmup_cosine as linear_warmup_cosine
|
|
14
18
|
from . import reduce_lr_on_plateau as reduce_lr_on_plateau
|
15
19
|
|
16
20
|
__all__ = [
|
21
|
+
"DurationConfig",
|
22
|
+
"LRSchedulerConfig",
|
17
23
|
"LRSchedulerConfigBase",
|
18
24
|
"LinearWarmupCosineDecayLRSchedulerConfig",
|
19
25
|
"MetricConfig",
|
@@ -2,6 +2,9 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
|
+
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
6
|
+
DurationConfig as DurationConfig,
|
7
|
+
)
|
5
8
|
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
6
9
|
LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
|
7
10
|
)
|
@@ -10,6 +13,7 @@ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
|
10
13
|
)
|
11
14
|
|
12
15
|
__all__ = [
|
16
|
+
"DurationConfig",
|
13
17
|
"LRSchedulerConfigBase",
|
14
18
|
"LinearWarmupCosineDecayLRSchedulerConfig",
|
15
19
|
]
|
@@ -8,6 +8,7 @@ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
|
|
8
8
|
from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
|
9
9
|
from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
|
10
10
|
from nshtrainer.nn import MLPConfig as MLPConfig
|
11
|
+
from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
|
11
12
|
from nshtrainer.nn import PReLUConfig as PReLUConfig
|
12
13
|
from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
13
14
|
from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
|
@@ -31,6 +32,7 @@ __all__ = [
|
|
31
32
|
"LeakyReLUNonlinearityConfig",
|
32
33
|
"MLPConfig",
|
33
34
|
"MishNonlinearityConfig",
|
35
|
+
"NonlinearityConfig",
|
34
36
|
"PReLUConfig",
|
35
37
|
"ReLUNonlinearityConfig",
|
36
38
|
"SiLUNonlinearityConfig",
|
@@ -4,8 +4,10 @@ __codegen__ = True
|
|
4
4
|
|
5
5
|
from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
|
6
6
|
from nshtrainer.nn.mlp import MLPConfig as MLPConfig
|
7
|
+
from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
|
7
8
|
|
8
9
|
__all__ = [
|
9
10
|
"BaseNonlinearityConfig",
|
10
11
|
"MLPConfig",
|
12
|
+
"NonlinearityConfig",
|
11
13
|
]
|
@@ -9,6 +9,7 @@ from nshtrainer.nn.nonlinearity import (
|
|
9
9
|
LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig,
|
10
10
|
)
|
11
11
|
from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
|
12
|
+
from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
|
12
13
|
from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
|
13
14
|
from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
14
15
|
from nshtrainer.nn.nonlinearity import (
|
@@ -38,6 +39,7 @@ __all__ = [
|
|
38
39
|
"GELUNonlinearityConfig",
|
39
40
|
"LeakyReLUNonlinearityConfig",
|
40
41
|
"MishNonlinearityConfig",
|
42
|
+
"NonlinearityConfig",
|
41
43
|
"PReLUConfig",
|
42
44
|
"ReLUNonlinearityConfig",
|
43
45
|
"SiLUNonlinearityConfig",
|
@@ -3,9 +3,11 @@ from __future__ import annotations
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
5
|
from nshtrainer.optimizer import AdamWConfig as AdamWConfig
|
6
|
+
from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
|
6
7
|
from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
|
7
8
|
|
8
9
|
__all__ = [
|
9
10
|
"AdamWConfig",
|
11
|
+
"OptimizerConfig",
|
10
12
|
"OptimizerConfigBase",
|
11
13
|
]
|
@@ -4,6 +4,7 @@ __codegen__ = True
|
|
4
4
|
|
5
5
|
from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
|
6
6
|
from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
|
7
|
+
from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
|
7
8
|
from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
|
8
9
|
from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
|
9
10
|
|
@@ -15,6 +16,7 @@ from . import simple as simple
|
|
15
16
|
__all__ = [
|
16
17
|
"AdvancedProfilerConfig",
|
17
18
|
"BaseProfilerConfig",
|
19
|
+
"ProfilerConfig",
|
18
20
|
"PyTorchProfilerConfig",
|
19
21
|
"SimpleProfilerConfig",
|
20
22
|
"_base",
|
@@ -9,7 +9,11 @@ from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
|
|
9
9
|
from nshtrainer.trainer._config import (
|
10
10
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
11
11
|
)
|
12
|
+
from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
|
12
13
|
from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
|
14
|
+
from nshtrainer.trainer._config import (
|
15
|
+
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
16
|
+
)
|
13
17
|
from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
|
14
18
|
from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
|
15
19
|
from nshtrainer.trainer._config import (
|
@@ -29,6 +33,7 @@ from nshtrainer.trainer._config import (
|
|
29
33
|
LearningRateMonitorConfig as LearningRateMonitorConfig,
|
30
34
|
)
|
31
35
|
from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
|
36
|
+
from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
|
32
37
|
from nshtrainer.trainer._config import MetricConfig as MetricConfig
|
33
38
|
from nshtrainer.trainer._config import (
|
34
39
|
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
@@ -37,6 +42,7 @@ from nshtrainer.trainer._config import (
|
|
37
42
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
38
43
|
)
|
39
44
|
from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
|
45
|
+
from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
|
40
46
|
from nshtrainer.trainer._config import (
|
41
47
|
RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
|
42
48
|
)
|
@@ -48,6 +54,9 @@ from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
|
|
48
54
|
from nshtrainer.trainer._config import (
|
49
55
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
50
56
|
)
|
57
|
+
from nshtrainer.trainer._config import (
|
58
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
59
|
+
)
|
51
60
|
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
52
61
|
|
53
62
|
from . import _config as _config
|
@@ -59,7 +68,9 @@ __all__ = [
|
|
59
68
|
"BaseLoggerConfig",
|
60
69
|
"BestCheckpointCallbackConfig",
|
61
70
|
"CSVLoggerConfig",
|
71
|
+
"CallbackConfig",
|
62
72
|
"CallbackConfigBase",
|
73
|
+
"CheckpointCallbackConfig",
|
63
74
|
"CheckpointSavingConfig",
|
64
75
|
"DebugFlagCallbackConfig",
|
65
76
|
"DirectoryConfig",
|
@@ -70,15 +81,18 @@ __all__ = [
|
|
70
81
|
"LastCheckpointCallbackConfig",
|
71
82
|
"LearningRateMonitorConfig",
|
72
83
|
"LogEpochCallbackConfig",
|
84
|
+
"LoggerConfig",
|
73
85
|
"MetricConfig",
|
74
86
|
"NormLoggingCallbackConfig",
|
75
87
|
"OnExceptionCheckpointCallbackConfig",
|
76
88
|
"PluginConfigBase",
|
89
|
+
"ProfilerConfig",
|
77
90
|
"RLPSanityChecksCallbackConfig",
|
78
91
|
"SanityCheckingConfig",
|
79
92
|
"SharedParametersCallbackConfig",
|
80
93
|
"StrategyConfigBase",
|
81
94
|
"TensorboardLoggerConfig",
|
95
|
+
"TimeCheckpointCallbackConfig",
|
82
96
|
"TrainerConfig",
|
83
97
|
"WandbLoggerConfig",
|
84
98
|
"_config",
|
@@ -8,7 +8,11 @@ from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
|
|
8
8
|
from nshtrainer.trainer._config import (
|
9
9
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
10
10
|
)
|
11
|
+
from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
|
11
12
|
from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
|
13
|
+
from nshtrainer.trainer._config import (
|
14
|
+
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
15
|
+
)
|
12
16
|
from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
|
13
17
|
from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
|
14
18
|
from nshtrainer.trainer._config import (
|
@@ -28,6 +32,7 @@ from nshtrainer.trainer._config import (
|
|
28
32
|
LearningRateMonitorConfig as LearningRateMonitorConfig,
|
29
33
|
)
|
30
34
|
from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
|
35
|
+
from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
|
31
36
|
from nshtrainer.trainer._config import MetricConfig as MetricConfig
|
32
37
|
from nshtrainer.trainer._config import (
|
33
38
|
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
@@ -36,6 +41,7 @@ from nshtrainer.trainer._config import (
|
|
36
41
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
37
42
|
)
|
38
43
|
from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
|
44
|
+
from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
|
39
45
|
from nshtrainer.trainer._config import (
|
40
46
|
RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
|
41
47
|
)
|
@@ -47,6 +53,9 @@ from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
|
|
47
53
|
from nshtrainer.trainer._config import (
|
48
54
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
49
55
|
)
|
56
|
+
from nshtrainer.trainer._config import (
|
57
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
58
|
+
)
|
50
59
|
from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
|
51
60
|
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
52
61
|
|
@@ -56,7 +65,9 @@ __all__ = [
|
|
56
65
|
"BaseLoggerConfig",
|
57
66
|
"BestCheckpointCallbackConfig",
|
58
67
|
"CSVLoggerConfig",
|
68
|
+
"CallbackConfig",
|
59
69
|
"CallbackConfigBase",
|
70
|
+
"CheckpointCallbackConfig",
|
60
71
|
"CheckpointSavingConfig",
|
61
72
|
"DebugFlagCallbackConfig",
|
62
73
|
"DirectoryConfig",
|
@@ -67,15 +78,18 @@ __all__ = [
|
|
67
78
|
"LastCheckpointCallbackConfig",
|
68
79
|
"LearningRateMonitorConfig",
|
69
80
|
"LogEpochCallbackConfig",
|
81
|
+
"LoggerConfig",
|
70
82
|
"MetricConfig",
|
71
83
|
"NormLoggingCallbackConfig",
|
72
84
|
"OnExceptionCheckpointCallbackConfig",
|
73
85
|
"PluginConfigBase",
|
86
|
+
"ProfilerConfig",
|
74
87
|
"RLPSanityChecksCallbackConfig",
|
75
88
|
"SanityCheckingConfig",
|
76
89
|
"SharedParametersCallbackConfig",
|
77
90
|
"StrategyConfigBase",
|
78
91
|
"TensorboardLoggerConfig",
|
92
|
+
"TimeCheckpointCallbackConfig",
|
79
93
|
"TrainerConfig",
|
80
94
|
"WandbLoggerConfig",
|
81
95
|
]
|
@@ -32,6 +32,7 @@ from nshtrainer.util._environment_info import (
|
|
32
32
|
)
|
33
33
|
from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
|
34
34
|
from nshtrainer.util.config import DTypeConfig as DTypeConfig
|
35
|
+
from nshtrainer.util.config import DurationConfig as DurationConfig
|
35
36
|
from nshtrainer.util.config import EpochsConfig as EpochsConfig
|
36
37
|
from nshtrainer.util.config import StepsConfig as StepsConfig
|
37
38
|
|
@@ -40,6 +41,7 @@ from . import config as config
|
|
40
41
|
|
41
42
|
__all__ = [
|
42
43
|
"DTypeConfig",
|
44
|
+
"DurationConfig",
|
43
45
|
"EnvironmentCUDAConfig",
|
44
46
|
"EnvironmentClassInformationConfig",
|
45
47
|
"EnvironmentConfig",
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
5
|
from nshtrainer.util.config import DTypeConfig as DTypeConfig
|
6
|
+
from nshtrainer.util.config import DurationConfig as DurationConfig
|
6
7
|
from nshtrainer.util.config import EpochsConfig as EpochsConfig
|
7
8
|
from nshtrainer.util.config import StepsConfig as StepsConfig
|
8
9
|
|
@@ -11,6 +12,7 @@ from . import duration as duration
|
|
11
12
|
|
12
13
|
__all__ = [
|
13
14
|
"DTypeConfig",
|
15
|
+
"DurationConfig",
|
14
16
|
"EpochsConfig",
|
15
17
|
"StepsConfig",
|
16
18
|
"dtype",
|
@@ -2,10 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
|
+
from nshtrainer.util.config.duration import DurationConfig as DurationConfig
|
5
6
|
from nshtrainer.util.config.duration import EpochsConfig as EpochsConfig
|
6
7
|
from nshtrainer.util.config.duration import StepsConfig as StepsConfig
|
7
8
|
|
8
9
|
__all__ = [
|
10
|
+
"DurationConfig",
|
9
11
|
"EpochsConfig",
|
10
12
|
"StepsConfig",
|
11
13
|
]
|
nshtrainer/loggers/__init__.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import Annotated
|
3
|
+
from typing import Annotated
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
|
+
from typing_extensions import TypeAliasType
|
6
7
|
|
7
8
|
from ._base import BaseLoggerConfig as BaseLoggerConfig
|
8
9
|
from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
|
@@ -10,7 +11,13 @@ from .csv import CSVLoggerConfig as CSVLoggerConfig
|
|
10
11
|
from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
|
11
12
|
from .wandb import WandbLoggerConfig as WandbLoggerConfig
|
12
13
|
|
13
|
-
LoggerConfig
|
14
|
-
|
15
|
-
|
16
|
-
|
14
|
+
LoggerConfig = TypeAliasType(
|
15
|
+
"LoggerConfig",
|
16
|
+
Annotated[
|
17
|
+
CSVLoggerConfig
|
18
|
+
| TensorboardLoggerConfig
|
19
|
+
| WandbLoggerConfig
|
20
|
+
| ActSaveLoggerConfig,
|
21
|
+
C.Field(discriminator="name"),
|
22
|
+
],
|
23
|
+
)
|
@@ -1,8 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import Annotated
|
3
|
+
from typing import Annotated
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
|
+
from typing_extensions import TypeAliasType
|
6
7
|
|
7
8
|
from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
|
8
9
|
from ._base import LRSchedulerMetadata as LRSchedulerMetadata
|
@@ -15,7 +16,10 @@ from .linear_warmup_cosine import (
|
|
15
16
|
from .reduce_lr_on_plateau import ReduceLROnPlateau as ReduceLROnPlateau
|
16
17
|
from .reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
|
17
18
|
|
18
|
-
LRSchedulerConfig
|
19
|
-
|
20
|
-
|
21
|
-
|
19
|
+
LRSchedulerConfig = TypeAliasType(
|
20
|
+
"LRSchedulerConfig",
|
21
|
+
Annotated[
|
22
|
+
LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
|
23
|
+
C.Field(discriminator="name"),
|
24
|
+
],
|
25
|
+
)
|
@@ -2,18 +2,20 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
from collections.abc import Callable, Iterable, Sequence
|
5
|
-
from typing import Any,
|
5
|
+
from typing import Any, cast
|
6
6
|
|
7
7
|
from lightning.pytorch import Callback, LightningModule
|
8
|
-
from typing_extensions import override
|
8
|
+
from typing_extensions import TypeAliasType, override
|
9
9
|
|
10
10
|
from ..._callback import NTCallbackBase
|
11
11
|
from ...util.typing_utils import mixin_base_type
|
12
12
|
|
13
13
|
log = logging.getLogger(__name__)
|
14
14
|
|
15
|
-
_Callback = Callback | NTCallbackBase
|
16
|
-
CallbackFn
|
15
|
+
_Callback = TypeAliasType("_Callback", Callback | NTCallbackBase)
|
16
|
+
CallbackFn = TypeAliasType(
|
17
|
+
"CallbackFn", Callable[[], _Callback | Iterable[_Callback] | None]
|
18
|
+
)
|
17
19
|
|
18
20
|
|
19
21
|
class CallbackRegistrarModuleMixin:
|
nshtrainer/optimizer.py
CHANGED
@@ -2,12 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from collections.abc import Iterable
|
5
|
-
from typing import Annotated, Any, Literal
|
5
|
+
from typing import Annotated, Any, Literal
|
6
6
|
|
7
7
|
import nshconfig as C
|
8
8
|
import torch.nn as nn
|
9
9
|
from torch.optim import Optimizer
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import TypeAliasType, override
|
11
11
|
|
12
12
|
|
13
13
|
class OptimizerConfigBase(C.Config, ABC):
|
@@ -57,4 +57,6 @@ class AdamWConfig(OptimizerConfigBase):
|
|
57
57
|
)
|
58
58
|
|
59
59
|
|
60
|
-
OptimizerConfig
|
60
|
+
OptimizerConfig = TypeAliasType(
|
61
|
+
"OptimizerConfig", Annotated[AdamWConfig, C.Field(discriminator="name")]
|
62
|
+
)
|
nshtrainer/profiler/__init__.py
CHANGED
@@ -1,15 +1,19 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import Annotated
|
3
|
+
from typing import Annotated
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
|
+
from typing_extensions import TypeAliasType
|
6
7
|
|
7
8
|
from ._base import BaseProfilerConfig as BaseProfilerConfig
|
8
9
|
from .advanced import AdvancedProfilerConfig as AdvancedProfilerConfig
|
9
10
|
from .pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
|
10
11
|
from .simple import SimpleProfilerConfig as SimpleProfilerConfig
|
11
12
|
|
12
|
-
ProfilerConfig
|
13
|
-
|
14
|
-
|
15
|
-
|
13
|
+
ProfilerConfig = TypeAliasType(
|
14
|
+
"ProfilerConfig",
|
15
|
+
Annotated[
|
16
|
+
SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
|
17
|
+
C.Field(discriminator="name"),
|
18
|
+
],
|
19
|
+
)
|
nshtrainer/trainer/_config.py
CHANGED
@@ -14,7 +14,6 @@ from typing import (
|
|
14
14
|
Any,
|
15
15
|
ClassVar,
|
16
16
|
Literal,
|
17
|
-
TypeAlias,
|
18
17
|
)
|
19
18
|
|
20
19
|
import nshconfig as C
|
@@ -101,47 +100,53 @@ class StrategyConfigBase(C.Config, ABC):
|
|
101
100
|
strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
|
102
101
|
|
103
102
|
|
104
|
-
AcceleratorLiteral
|
105
|
-
"cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
|
106
|
-
|
107
|
-
|
108
|
-
StrategyLiteral
|
109
|
-
"
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
103
|
+
AcceleratorLiteral = TypeAliasType(
|
104
|
+
"AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
|
105
|
+
)
|
106
|
+
|
107
|
+
StrategyLiteral = TypeAliasType(
|
108
|
+
"StrategyLiteral",
|
109
|
+
Literal[
|
110
|
+
"auto",
|
111
|
+
"ddp",
|
112
|
+
"ddp_find_unused_parameters_false",
|
113
|
+
"ddp_find_unused_parameters_true",
|
114
|
+
"ddp_spawn",
|
115
|
+
"ddp_spawn_find_unused_parameters_false",
|
116
|
+
"ddp_spawn_find_unused_parameters_true",
|
117
|
+
"ddp_fork",
|
118
|
+
"ddp_fork_find_unused_parameters_false",
|
119
|
+
"ddp_fork_find_unused_parameters_true",
|
120
|
+
"ddp_notebook",
|
121
|
+
"dp",
|
122
|
+
"deepspeed",
|
123
|
+
"deepspeed_stage_1",
|
124
|
+
"deepspeed_stage_1_offload",
|
125
|
+
"deepspeed_stage_2",
|
126
|
+
"deepspeed_stage_2_offload",
|
127
|
+
"deepspeed_stage_3",
|
128
|
+
"deepspeed_stage_3_offload",
|
129
|
+
"deepspeed_stage_3_offload_nvme",
|
130
|
+
"fsdp",
|
131
|
+
"fsdp_cpu_offload",
|
132
|
+
"single_xla",
|
133
|
+
"xla_fsdp",
|
134
|
+
"xla",
|
135
|
+
"single_tpu",
|
136
|
+
],
|
137
|
+
)
|
138
|
+
|
139
|
+
|
140
|
+
CheckpointCallbackConfig = TypeAliasType(
|
141
|
+
"CheckpointCallbackConfig",
|
142
|
+
Annotated[
|
143
|
+
BestCheckpointCallbackConfig
|
144
|
+
| LastCheckpointCallbackConfig
|
145
|
+
| OnExceptionCheckpointCallbackConfig
|
146
|
+
| TimeCheckpointCallbackConfig,
|
147
|
+
C.Field(discriminator="name"),
|
148
|
+
],
|
149
|
+
)
|
145
150
|
|
146
151
|
|
147
152
|
class CheckpointSavingConfig(CallbackConfigBase):
|
@@ -4,14 +4,14 @@ import datetime
|
|
4
4
|
import logging
|
5
5
|
import time
|
6
6
|
from dataclasses import dataclass
|
7
|
-
from typing import Any, Literal
|
7
|
+
from typing import Any, Literal
|
8
8
|
|
9
9
|
from lightning.pytorch.callbacks.callback import Callback
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import TypeAliasType, override
|
11
11
|
|
12
12
|
log = logging.getLogger(__name__)
|
13
13
|
|
14
|
-
Stage
|
14
|
+
Stage = TypeAliasType("Stage", Literal["train", "validate", "test", "predict"])
|
15
15
|
ALL_STAGES = ("train", "validate", "test", "predict")
|
16
16
|
|
17
17
|
|
@@ -12,7 +12,7 @@ from collections import defaultdict
|
|
12
12
|
from collections.abc import Callable
|
13
13
|
from pathlib import Path
|
14
14
|
from types import FrameType
|
15
|
-
from typing import Any
|
15
|
+
from typing import Any
|
16
16
|
|
17
17
|
import nshrunner as nr
|
18
18
|
import torch.utils.data
|
@@ -22,12 +22,14 @@ from lightning.pytorch.trainer.connectors.signal_connector import _HandlersCompo
|
|
22
22
|
from lightning.pytorch.trainer.connectors.signal_connector import (
|
23
23
|
_SignalConnector as _LightningSignalConnector,
|
24
24
|
)
|
25
|
-
from typing_extensions import override
|
25
|
+
from typing_extensions import TypeAliasType, override
|
26
26
|
|
27
27
|
log = logging.getLogger(__name__)
|
28
28
|
|
29
|
-
_SIGNUM = int | signal.Signals
|
30
|
-
_HANDLER
|
29
|
+
_SIGNUM = TypeAliasType("_SIGNUM", int | signal.Signals)
|
30
|
+
_HANDLER = TypeAliasType(
|
31
|
+
"_HANDLER", Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
|
32
|
+
)
|
31
33
|
_IS_WINDOWS = platform.system() == "Windows"
|
32
34
|
|
33
35
|
|
nshtrainer/util/_useful_types.py
CHANGED
@@ -7,7 +7,14 @@ from collections.abc import Set as AbstractSet
|
|
7
7
|
from os import PathLike
|
8
8
|
from typing import Any, TypeVar, overload
|
9
9
|
|
10
|
-
from typing_extensions import
|
10
|
+
from typing_extensions import (
|
11
|
+
Buffer,
|
12
|
+
Literal,
|
13
|
+
Protocol,
|
14
|
+
SupportsIndex,
|
15
|
+
TypeAlias,
|
16
|
+
TypeAliasType,
|
17
|
+
)
|
11
18
|
|
12
19
|
_KT = TypeVar("_KT")
|
13
20
|
_KT_co = TypeVar("_KT_co", covariant=True)
|
@@ -60,7 +67,9 @@ class SupportsAllComparisons(
|
|
60
67
|
): ...
|
61
68
|
|
62
69
|
|
63
|
-
SupportsRichComparison
|
70
|
+
SupportsRichComparison = TypeAliasType(
|
71
|
+
"SupportsRichComparison", SupportsDunderLT[Any] | SupportsDunderGT[Any]
|
72
|
+
)
|
64
73
|
SupportsRichComparisonT = TypeVar(
|
65
74
|
"SupportsRichComparisonT", bound=SupportsRichComparison
|
66
75
|
)
|
nshtrainer/util/config/dtype.py
CHANGED
@@ -1,57 +1,60 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import TYPE_CHECKING, Literal
|
3
|
+
from typing import TYPE_CHECKING, Literal
|
4
4
|
|
5
5
|
import nshconfig as C
|
6
6
|
import torch
|
7
|
-
from typing_extensions import assert_never
|
7
|
+
from typing_extensions import TypeAliasType, assert_never
|
8
8
|
|
9
9
|
from ..bf16 import is_bf16_supported_no_emulation
|
10
10
|
|
11
11
|
if TYPE_CHECKING:
|
12
12
|
from ...trainer._config import TrainerConfig
|
13
13
|
|
14
|
-
DTypeName
|
15
|
-
"
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
14
|
+
DTypeName = TypeAliasType(
|
15
|
+
"DTypeName",
|
16
|
+
Literal[
|
17
|
+
"float32",
|
18
|
+
"float",
|
19
|
+
"float64",
|
20
|
+
"double",
|
21
|
+
"float16",
|
22
|
+
"bfloat16",
|
23
|
+
"float8_e4m3fn",
|
24
|
+
"float8_e4m3fnuz",
|
25
|
+
"float8_e5m2",
|
26
|
+
"float8_e5m2fnuz",
|
27
|
+
"half",
|
28
|
+
"uint8",
|
29
|
+
"uint16",
|
30
|
+
"uint32",
|
31
|
+
"uint64",
|
32
|
+
"int8",
|
33
|
+
"int16",
|
34
|
+
"short",
|
35
|
+
"int32",
|
36
|
+
"int",
|
37
|
+
"int64",
|
38
|
+
"long",
|
39
|
+
"complex32",
|
40
|
+
"complex64",
|
41
|
+
"chalf",
|
42
|
+
"cfloat",
|
43
|
+
"complex128",
|
44
|
+
"cdouble",
|
45
|
+
"quint8",
|
46
|
+
"qint8",
|
47
|
+
"qint32",
|
48
|
+
"bool",
|
49
|
+
"quint4x2",
|
50
|
+
"quint2x4",
|
51
|
+
"bits1x8",
|
52
|
+
"bits2x4",
|
53
|
+
"bits4x2",
|
54
|
+
"bits8",
|
55
|
+
"bits16",
|
56
|
+
],
|
57
|
+
)
|
55
58
|
|
56
59
|
|
57
60
|
class DTypeConfig(C.Config):
|
nshtrainer/util/path.py
CHANGED
@@ -6,11 +6,12 @@ import os
|
|
6
6
|
import platform
|
7
7
|
import shutil
|
8
8
|
from pathlib import Path
|
9
|
-
|
9
|
+
|
10
|
+
from typing_extensions import TypeAliasType
|
10
11
|
|
11
12
|
log = logging.getLogger(__name__)
|
12
13
|
|
13
|
-
_Path
|
14
|
+
_Path = TypeAliasType("_Path", str | Path)
|
14
15
|
|
15
16
|
|
16
17
|
def get_relative_path(source: _Path, destination: _Path):
|
@@ -7,8 +7,8 @@ nshtrainer/_directory.py,sha256=p2uk1FnISFEpMqlDevKhoWhQsCEtvHUPg459K-86QA8,3053
|
|
7
7
|
nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
|
8
8
|
nshtrainer/_hf_hub.py,sha256=xj91CEKcuvpaiipZbVZovZiU_fdsTkPXOkQ-3Xb-FhU,14183
|
9
9
|
nshtrainer/callbacks/__init__.py,sha256=z2uDaUVzNXG79UOAYGpwC0hxk__LTbEBVhXJ3W-Srws,3911
|
10
|
-
nshtrainer/callbacks/actsave.py,sha256=
|
11
|
-
nshtrainer/callbacks/base.py,sha256=
|
10
|
+
nshtrainer/callbacks/actsave.py,sha256=bnqS3y9sit1hfUvqPx-WWRug2OeTHPM0PDEJaex7f3Q,3776
|
11
|
+
nshtrainer/callbacks/base.py,sha256=AFqvKNzH710xviYQ7X0hw0M7ETWRhBAWv3PN9WnrZw0,3608
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=09cwgAJawTFLdzZwYhg_jVgbW-1d09hwjHdI-PQRck0,797
|
13
13
|
nshtrainer/callbacks/checkpoint/_base.py,sha256=ZVEUVl5kjCSSe69Q0rMUbKBNNUog0pxBwWkeyuxG2w0,6304
|
14
14
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=hcvtpNvuyWxxP36XiGyFYwVsSwmHphKD-1IA95DinME,2602
|
@@ -27,23 +27,24 @@ nshtrainer/callbacks/lr_monitor.py,sha256=IyFZoXaxJoTBSkdLu1iEZ1qI8_UFNJwafR_xTV
|
|
27
27
|
nshtrainer/callbacks/norm_logging.py,sha256=C44Mvt73gqQEpCFd0j3qYg6NY7sL2jm3X1qJVY_XLfI,6329
|
28
28
|
nshtrainer/callbacks/print_table.py,sha256=lS49Hz0OLcv3VPxEfLBguwe57y2nmKg0pMF6HJxuJio,2974
|
29
29
|
nshtrainer/callbacks/rlp_sanity_checks.py,sha256=kWl2dYOXn2L8k6ub_012jNkqOxtyea1yr1qWRNG6UW4,9990
|
30
|
-
nshtrainer/callbacks/shared_parameters.py,sha256=
|
30
|
+
nshtrainer/callbacks/shared_parameters.py,sha256=ggMI1krkqN7sGOrjK_I96IsTMYMXHoVQm7W90LZb9so,3015
|
31
31
|
nshtrainer/callbacks/timer.py,sha256=BB-M7tV4QNYOwY_Su6j9P7IILxVRae_upmDq4qsxiao,4670
|
32
32
|
nshtrainer/callbacks/wandb_upload_code.py,sha256=PTqNE1QB5U8NR5zhbiQZrmQuugX2UV7B12UdMpo9aV0,2353
|
33
33
|
nshtrainer/callbacks/wandb_watch.py,sha256=tTTcFzxd2Ia9xu8tCogQ5CLJZBq1ne5JlpGVE75vKYs,2976
|
34
|
-
nshtrainer/configs/__init__.py,sha256=
|
34
|
+
nshtrainer/configs/__init__.py,sha256=zyo4lV9ObB3T3_hhBhzWGNb6MRma4h7QHD3OrypxqEw,10582
|
35
35
|
nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
|
36
36
|
nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
|
37
|
-
nshtrainer/configs/_directory/__init__.py,sha256=
|
37
|
+
nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
|
38
38
|
nshtrainer/configs/_hf_hub/__init__.py,sha256=VUgQnyEI2ekBxBIV15L09tKdrfGt7eWxnf30DiCLaso,416
|
39
|
-
nshtrainer/configs/callbacks/__init__.py,sha256=
|
39
|
+
nshtrainer/configs/callbacks/__init__.py,sha256=APpF2jmafqbS4CoMmDvFADi0wdmXJ_BvFw4QnnQpok0,4353
|
40
40
|
nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JJg9d8iNGpO-9M1LsK4h1cu3NYWniyIyLQ4SauFCzOs,272
|
41
41
|
nshtrainer/configs/callbacks/base/__init__.py,sha256=V694hzF_ubnA-hwTps30PeFbgDSm3I_UIMTnljM3_OI,176
|
42
|
-
nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=
|
42
|
+
nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=zPUItxoYWrMT9i1TxOvVhIeTa0NEFg-nDE5FjHfkP-A,1564
|
43
43
|
nshtrainer/configs/callbacks/checkpoint/_base/__init__.py,sha256=5jl6A5Gv6arZXmHV6lz5dQ8DL6PdJIfJqHLP4acClKQ,479
|
44
44
|
nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py,sha256=_KTkF3_Yx0WiwOvRf2s1KRvod2dryeGJITVkx10YqBE,648
|
45
45
|
nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py,sha256=sH4aJyovCeT6h4xz9r5WfVA0eviJur4zZt7R0hQsyAk,539
|
46
46
|
nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py,sha256=iUoCrTJpvDGEYCYfNCpkficH6D7yq119hyxy1qeFSGU,410
|
47
|
+
nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py,sha256=ZQz5gXKSHz7xvnZbgU_bauNhAJDh9rblBA0ldkSsdwg,539
|
47
48
|
nshtrainer/configs/callbacks/debug_flag/__init__.py,sha256=gPC3EAqzuyP2hAcCf3s09sPDe7q_02S1eUCWvRTNKrI,317
|
48
49
|
nshtrainer/configs/callbacks/directory_setup/__init__.py,sha256=25551zMMctAkzcLEGBN7HSeQUIrBtbBq7whgPZjtepY,351
|
49
50
|
nshtrainer/configs/callbacks/early_stopping/__init__.py,sha256=Q-hAAIcucLNriw5PXIshgBv7Yr5sMDk4GwjTIDdWBxo,434
|
@@ -59,46 +60,46 @@ nshtrainer/configs/callbacks/shared_parameters/__init__.py,sha256=Ivef5jk3RMYQDe
|
|
59
60
|
nshtrainer/configs/callbacks/timer/__init__.py,sha256=RHOQoREp4NxS_AvKNdc0UuUlS0PnqCxxsuOz5D8h7iM,310
|
60
61
|
nshtrainer/configs/callbacks/wandb_upload_code/__init__.py,sha256=WM9hCGFl2LXDUOgkIGaV3tkdnXnVBasrhIILjbIeFUo,358
|
61
62
|
nshtrainer/configs/callbacks/wandb_watch/__init__.py,sha256=MW-ANrF529DxBhopovPjYEQ7nANX9ttd1K4_bJnKXks,322
|
62
|
-
nshtrainer/configs/loggers/__init__.py,sha256=
|
63
|
+
nshtrainer/configs/loggers/__init__.py,sha256=5wTekL79mQxit8f1K3AMllvb0mKertTzOKfC3gpE2Zk,1251
|
63
64
|
nshtrainer/configs/loggers/_base/__init__.py,sha256=HxPPPePsEjlNuhnjsMgYIl0rwj_iqNKKOBTEk_zIOsM,169
|
64
65
|
nshtrainer/configs/loggers/actsave/__init__.py,sha256=2lZQ4bpbjwd4MuUE_Z_PGbmQjjGtWCZUCtXqKO4dTSc,280
|
65
66
|
nshtrainer/configs/loggers/csv/__init__.py,sha256=M3QGF5GKiRGENy3re6LJKpa4A4RThy1FlmaFuR4cPyo,260
|
66
67
|
nshtrainer/configs/loggers/tensorboard/__init__.py,sha256=FbkYXnSohIX6JN5XyI-9y91IJv_T3VB3IwmpagXAnM4,309
|
67
68
|
nshtrainer/configs/loggers/wandb/__init__.py,sha256=76qb0HhWojf0Ub1x9OkMjtzeXxE67KysBGa-MBbJyC4,651
|
68
|
-
nshtrainer/configs/lr_scheduler/__init__.py,sha256=
|
69
|
+
nshtrainer/configs/lr_scheduler/__init__.py,sha256=8ORO-QC12SjZ2F_reMoDgr8-O8nxZxX0IKU4fl-cC3A,1023
|
69
70
|
nshtrainer/configs/lr_scheduler/_base/__init__.py,sha256=fvGjkUJ1K2RVXjXror22QOtEa-xWFJz2Cz3HrBC5XfA,189
|
70
|
-
nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=
|
71
|
+
nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=i8LeZh0c4wqtZ1ehZb2LCq7kwOL0OyswMMOnwyI6R04,533
|
71
72
|
nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=lpXEFZY4cM3znZqYG9IZ1xNNtzttt8VVspSuOz0fb-k,467
|
72
73
|
nshtrainer/configs/metrics/__init__.py,sha256=mK_xgXJDAyGY4K_x_Zo4aj36kjT45d850keuUe3U1rY,200
|
73
74
|
nshtrainer/configs/metrics/_config/__init__.py,sha256=XDDvDPWULd_vd3lrgF2KGAVR2LVDuhdQvy-fF2ImarI,159
|
74
|
-
nshtrainer/configs/nn/__init__.py,sha256=
|
75
|
-
nshtrainer/configs/nn/mlp/__init__.py,sha256=
|
76
|
-
nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=
|
77
|
-
nshtrainer/configs/optimizer/__init__.py,sha256=
|
78
|
-
nshtrainer/configs/profiler/__init__.py,sha256=
|
75
|
+
nshtrainer/configs/nn/__init__.py,sha256=3hVc81Gs9AJYVkrwJkQ_ye7tLU2HOLdBj-mMkXx2c_I,1957
|
76
|
+
nshtrainer/configs/nn/mlp/__init__.py,sha256=eMECrgz-My9mFS7lpWVI3dj1ApB-E7xwfmNc37hUsPI,347
|
77
|
+
nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=Gjr2HCx8jJTcfu7sLgn54o2ucGKaBea4encm4AWpKNY,2040
|
78
|
+
nshtrainer/configs/optimizer/__init__.py,sha256=IMEsEbiVFXSkj6WmDjNjmKQuRspphs5xZnYZ2gYE39Y,344
|
79
|
+
nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
|
79
80
|
nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
|
80
81
|
nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
|
81
82
|
nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
|
82
83
|
nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
|
83
|
-
nshtrainer/configs/trainer/__init__.py,sha256=
|
84
|
-
nshtrainer/configs/trainer/_config/__init__.py,sha256=
|
84
|
+
nshtrainer/configs/trainer/__init__.py,sha256=KIDYjJsc-WYXKiH2RNzAZJD5MKOTdO9wdtu_vWDNPxU,3936
|
85
|
+
nshtrainer/configs/trainer/_config/__init__.py,sha256=1_Ad5uTvXdVuHMJB3s8s-0EraDwNZssg3sXBmVouF9w,3847
|
85
86
|
nshtrainer/configs/trainer/trainer/__init__.py,sha256=DDuBRx0kVNMW0z_sqKTUt8-Ql7bOpargi4KcHHvDu_c,486
|
86
|
-
nshtrainer/configs/util/__init__.py,sha256=
|
87
|
+
nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
|
87
88
|
nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
|
88
|
-
nshtrainer/configs/util/config/__init__.py,sha256=
|
89
|
+
nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
|
89
90
|
nshtrainer/configs/util/config/dtype/__init__.py,sha256=PmGF-O4r6SXqEaagVsQ5YxEqhdVdcU0dgJW1Ljzpp6k,158
|
90
|
-
nshtrainer/configs/util/config/duration/__init__.py,sha256=
|
91
|
+
nshtrainer/configs/util/config/duration/__init__.py,sha256=44lS2irOIPVfgshMTfnZM2jC6l0Pjst9w2M_lJoS_MU,353
|
91
92
|
nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,260
|
92
93
|
nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
|
93
94
|
nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk,4138
|
94
95
|
nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
|
95
|
-
nshtrainer/loggers/__init__.py,sha256
|
96
|
+
nshtrainer/loggers/__init__.py,sha256=-y8B-9TF6vJdZUQewJNDcZ2aOv04FEUFtKwaiDobIO0,670
|
96
97
|
nshtrainer/loggers/_base.py,sha256=nw4AZzJP3Z-fljgQlgq7FkuMkPmYKTsXj7OfJJSmtXI,811
|
97
98
|
nshtrainer/loggers/actsave.py,sha256=23Kre-mq-Y9Iw1SRyGmHnHK1bc_0gTWFXJViv9bkVz0,1324
|
98
99
|
nshtrainer/loggers/csv.py,sha256=Deh5gm3oROJbQzigV4SHni5JRwSrBdm-4YD3yrcGnHo,1104
|
99
100
|
nshtrainer/loggers/tensorboard.py,sha256=jP9V4nlq_MXUaoD6xv1Cws2ioft83Lm8yUJhhGhuUrQ,2268
|
100
101
|
nshtrainer/loggers/wandb.py,sha256=EjKQQznLSUSCWO7uIviz9g0dVW4ZLxb_8UVhY4vR7r0,6800
|
101
|
-
nshtrainer/lr_scheduler/__init__.py,sha256=
|
102
|
+
nshtrainer/lr_scheduler/__init__.py,sha256=BGnO-okUTZOtF15-UmQ05U4oEatSF5VNs3YeidNEWn4,853
|
102
103
|
nshtrainer/lr_scheduler/_base.py,sha256=EhA2f_WiZ79RcXL2nJbwCwNK620c8ugEVUmJ8CcVpsE,3723
|
103
104
|
nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=gvUuv031lvWdXboDeH7iAF3ZgNPQK40bQwfmqb11TNk,5492
|
104
105
|
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=vXH5S26ESHO_LPPqW8aDC3S5NGoZYkXeFjAOgttaUX8,2870
|
@@ -106,7 +107,7 @@ nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMk
|
|
106
107
|
nshtrainer/metrics/_config.py,sha256=kR9OJLZXIaZpG8UpHG3EQwOL36HB8yPk6afYjoIM0XM,1324
|
107
108
|
nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
|
108
109
|
nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,10356
|
109
|
-
nshtrainer/model/mixins/callback.py,sha256=
|
110
|
+
nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
|
110
111
|
nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
|
111
112
|
nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
|
112
113
|
nshtrainer/nn/__init__.py,sha256=sANhrZpeN5syLKOsmXMwhaFl2SBFPWcLaEe1EH22TWQ,1463
|
@@ -114,29 +115,29 @@ nshtrainer/nn/mlp.py,sha256=2W8bzE96DzCMzGm6WPiPhNFQfhqaoG3GXPn_oKBnlUM,5988
|
|
114
115
|
nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
|
115
116
|
nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
|
116
117
|
nshtrainer/nn/nonlinearity.py,sha256=mp5XvXRHURB6jwuZ0YyTj5ZoHJYNJNgO2aLtUY1D-2Y,6114
|
117
|
-
nshtrainer/optimizer.py,sha256=
|
118
|
-
nshtrainer/profiler/__init__.py,sha256=
|
118
|
+
nshtrainer/optimizer.py,sha256=wmSRpSoU59rstj2RBoifQ15ZwRInYpm0tDBQZ1gqOfE,1596
|
119
|
+
nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
|
119
120
|
nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
|
120
121
|
nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
|
121
122
|
nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
|
122
123
|
nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
|
123
124
|
nshtrainer/trainer/__init__.py,sha256=MmoydVS6aYeav7zgDAUHxAQrV_PMQsbnZTCuPnLH9Wk,128
|
124
|
-
nshtrainer/trainer/_config.py,sha256=
|
125
|
-
nshtrainer/trainer/_runtime_callback.py,sha256=
|
126
|
-
nshtrainer/trainer/signal_connector.py,sha256=
|
125
|
+
nshtrainer/trainer/_config.py,sha256=Mz9J2ZFqxTlttnRA1eScGRgSAuf3-o3i9-xjN7eTm-k,35256
|
126
|
+
nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
|
127
|
+
nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
|
127
128
|
nshtrainer/trainer/trainer.py,sha256=HHqT83zWtYY9g5yD6X9aWrVh5VSpILW8PhoE6fp4snE,20734
|
128
129
|
nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
|
129
|
-
nshtrainer/util/_useful_types.py,sha256=
|
130
|
+
nshtrainer/util/_useful_types.py,sha256=7yd1ajSmjwfmZdBPlHVrIG3iXl1-T3n83JI53N8C7as,8080
|
130
131
|
nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
|
131
132
|
nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
|
132
|
-
nshtrainer/util/config/dtype.py,sha256=
|
133
|
+
nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
|
133
134
|
nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
|
134
135
|
nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
|
135
|
-
nshtrainer/util/path.py,sha256=
|
136
|
+
nshtrainer/util/path.py,sha256=L-Nh9tlXSUfoP19TFbQq8I0AfS5ugCfGYTYFeddDHcs,3516
|
136
137
|
nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
137
138
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
138
139
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
139
140
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
140
|
-
nshtrainer-1.0.
|
141
|
-
nshtrainer-1.0.
|
142
|
-
nshtrainer-1.0.
|
141
|
+
nshtrainer-1.0.0b28.dist-info/METADATA,sha256=1MJi65pa7HEVmtDR64Y32SwDe_bv1AZHSgyo6gIBmzo,988
|
142
|
+
nshtrainer-1.0.0b28.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
143
|
+
nshtrainer-1.0.0b28.dist-info/RECORD,,
|
File without changes
|