nshtrainer 1.0.0b27__py3-none-any.whl → 1.0.0b29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/configs/__init__.py +18 -0
- nshtrainer/configs/_directory/__init__.py +2 -0
- nshtrainer/configs/callbacks/__init__.py +2 -0
- nshtrainer/configs/loggers/__init__.py +2 -0
- nshtrainer/configs/lr_scheduler/__init__.py +6 -0
- nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +4 -0
- nshtrainer/configs/nn/__init__.py +2 -0
- nshtrainer/configs/nn/mlp/__init__.py +2 -0
- nshtrainer/configs/nn/nonlinearity/__init__.py +2 -0
- nshtrainer/configs/optimizer/__init__.py +2 -0
- nshtrainer/configs/profiler/__init__.py +2 -0
- nshtrainer/configs/trainer/__init__.py +10 -0
- nshtrainer/configs/trainer/_config/__init__.py +10 -0
- nshtrainer/configs/util/__init__.py +2 -0
- nshtrainer/configs/util/config/__init__.py +2 -0
- nshtrainer/configs/util/config/duration/__init__.py +2 -0
- nshtrainer/metrics/_config.py +2 -4
- nshtrainer/trainer/_config.py +24 -25
- {nshtrainer-1.0.0b27.dist-info → nshtrainer-1.0.0b29.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b27.dist-info → nshtrainer-1.0.0b29.dist-info}/RECORD +21 -22
- nshtrainer/util/_useful_types.py +0 -316
- {nshtrainer-1.0.0b27.dist-info → nshtrainer-1.0.0b29.dist-info}/WHEEL +0 -0
nshtrainer/configs/__init__.py
CHANGED
@@ -14,6 +14,7 @@ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
|
|
14
14
|
from nshtrainer.callbacks import (
|
15
15
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
16
16
|
)
|
17
|
+
from nshtrainer.callbacks import CallbackConfig as CallbackConfig
|
17
18
|
from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
18
19
|
from nshtrainer.callbacks import (
|
19
20
|
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
@@ -60,11 +61,13 @@ from nshtrainer.callbacks.checkpoint._base import (
|
|
60
61
|
from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
|
61
62
|
from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
|
62
63
|
from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
|
64
|
+
from nshtrainer.loggers import LoggerConfig as LoggerConfig
|
63
65
|
from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
|
64
66
|
from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
|
65
67
|
from nshtrainer.lr_scheduler import (
|
66
68
|
LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
|
67
69
|
)
|
70
|
+
from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
|
68
71
|
from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
|
69
72
|
from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
|
70
73
|
from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
|
@@ -73,6 +76,7 @@ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
|
|
73
76
|
from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
|
74
77
|
from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
|
75
78
|
from nshtrainer.nn import MLPConfig as MLPConfig
|
79
|
+
from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
|
76
80
|
from nshtrainer.nn import PReLUConfig as PReLUConfig
|
77
81
|
from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
78
82
|
from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
|
@@ -86,12 +90,17 @@ from nshtrainer.nn.nonlinearity import (
|
|
86
90
|
SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
|
87
91
|
)
|
88
92
|
from nshtrainer.optimizer import AdamWConfig as AdamWConfig
|
93
|
+
from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
|
89
94
|
from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
|
90
95
|
from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
|
91
96
|
from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
|
97
|
+
from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
|
92
98
|
from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
|
93
99
|
from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
|
94
100
|
from nshtrainer.trainer._config import AcceleratorConfigBase as AcceleratorConfigBase
|
101
|
+
from nshtrainer.trainer._config import (
|
102
|
+
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
103
|
+
)
|
95
104
|
from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
|
96
105
|
from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
|
97
106
|
from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
|
@@ -130,6 +139,7 @@ from nshtrainer.util._environment_info import (
|
|
130
139
|
)
|
131
140
|
from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
|
132
141
|
from nshtrainer.util.config import DTypeConfig as DTypeConfig
|
142
|
+
from nshtrainer.util.config import DurationConfig as DurationConfig
|
133
143
|
from nshtrainer.util.config import EpochsConfig as EpochsConfig
|
134
144
|
from nshtrainer.util.config import StepsConfig as StepsConfig
|
135
145
|
|
@@ -158,13 +168,16 @@ __all__ = [
|
|
158
168
|
"BaseProfilerConfig",
|
159
169
|
"BestCheckpointCallbackConfig",
|
160
170
|
"CSVLoggerConfig",
|
171
|
+
"CallbackConfig",
|
161
172
|
"CallbackConfigBase",
|
173
|
+
"CheckpointCallbackConfig",
|
162
174
|
"CheckpointMetadata",
|
163
175
|
"CheckpointSavingConfig",
|
164
176
|
"DTypeConfig",
|
165
177
|
"DebugFlagCallbackConfig",
|
166
178
|
"DirectoryConfig",
|
167
179
|
"DirectorySetupCallbackConfig",
|
180
|
+
"DurationConfig",
|
168
181
|
"ELUNonlinearityConfig",
|
169
182
|
"EMACallbackConfig",
|
170
183
|
"EarlyStoppingCallbackConfig",
|
@@ -187,21 +200,26 @@ __all__ = [
|
|
187
200
|
"GradientSkippingCallbackConfig",
|
188
201
|
"HuggingFaceHubAutoCreateConfig",
|
189
202
|
"HuggingFaceHubConfig",
|
203
|
+
"LRSchedulerConfig",
|
190
204
|
"LRSchedulerConfigBase",
|
191
205
|
"LastCheckpointCallbackConfig",
|
192
206
|
"LeakyReLUNonlinearityConfig",
|
193
207
|
"LearningRateMonitorConfig",
|
194
208
|
"LinearWarmupCosineDecayLRSchedulerConfig",
|
195
209
|
"LogEpochCallbackConfig",
|
210
|
+
"LoggerConfig",
|
196
211
|
"MLPConfig",
|
197
212
|
"MetricConfig",
|
198
213
|
"MishNonlinearityConfig",
|
214
|
+
"NonlinearityConfig",
|
199
215
|
"NormLoggingCallbackConfig",
|
200
216
|
"OnExceptionCheckpointCallbackConfig",
|
217
|
+
"OptimizerConfig",
|
201
218
|
"OptimizerConfigBase",
|
202
219
|
"PReLUConfig",
|
203
220
|
"PluginConfigBase",
|
204
221
|
"PrintTableMetricsCallbackConfig",
|
222
|
+
"ProfilerConfig",
|
205
223
|
"PyTorchProfilerConfig",
|
206
224
|
"RLPSanityChecksCallbackConfig",
|
207
225
|
"ReLUNonlinearityConfig",
|
@@ -6,8 +6,10 @@ from nshtrainer._directory import DirectoryConfig as DirectoryConfig
|
|
6
6
|
from nshtrainer._directory import (
|
7
7
|
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
8
8
|
)
|
9
|
+
from nshtrainer._directory import LoggerConfig as LoggerConfig
|
9
10
|
|
10
11
|
__all__ = [
|
11
12
|
"DirectoryConfig",
|
12
13
|
"DirectorySetupCallbackConfig",
|
14
|
+
"LoggerConfig",
|
13
15
|
]
|
@@ -5,6 +5,7 @@ __codegen__ = True
|
|
5
5
|
from nshtrainer.callbacks import (
|
6
6
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
7
7
|
)
|
8
|
+
from nshtrainer.callbacks import CallbackConfig as CallbackConfig
|
8
9
|
from nshtrainer.callbacks import CallbackConfigBase as CallbackConfigBase
|
9
10
|
from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
10
11
|
from nshtrainer.callbacks import (
|
@@ -80,6 +81,7 @@ __all__ = [
|
|
80
81
|
"ActSaveConfig",
|
81
82
|
"BaseCheckpointCallbackConfig",
|
82
83
|
"BestCheckpointCallbackConfig",
|
84
|
+
"CallbackConfig",
|
83
85
|
"CallbackConfigBase",
|
84
86
|
"CheckpointMetadata",
|
85
87
|
"DebugFlagCallbackConfig",
|
@@ -5,6 +5,7 @@ __codegen__ = True
|
|
5
5
|
from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
|
6
6
|
from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
|
7
7
|
from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
|
8
|
+
from nshtrainer.loggers import LoggerConfig as LoggerConfig
|
8
9
|
from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
|
9
10
|
from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
|
10
11
|
from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
|
@@ -26,6 +27,7 @@ __all__ = [
|
|
26
27
|
"BaseLoggerConfig",
|
27
28
|
"CSVLoggerConfig",
|
28
29
|
"CallbackConfigBase",
|
30
|
+
"LoggerConfig",
|
29
31
|
"TensorboardLoggerConfig",
|
30
32
|
"WandbLoggerConfig",
|
31
33
|
"WandbUploadCodeCallbackConfig",
|
@@ -5,8 +5,12 @@ __codegen__ = True
|
|
5
5
|
from nshtrainer.lr_scheduler import (
|
6
6
|
LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
|
7
7
|
)
|
8
|
+
from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
|
8
9
|
from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
|
9
10
|
from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
|
11
|
+
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
12
|
+
DurationConfig as DurationConfig,
|
13
|
+
)
|
10
14
|
from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
|
11
15
|
|
12
16
|
from . import _base as _base
|
@@ -14,6 +18,8 @@ from . import linear_warmup_cosine as linear_warmup_cosine
|
|
14
18
|
from . import reduce_lr_on_plateau as reduce_lr_on_plateau
|
15
19
|
|
16
20
|
__all__ = [
|
21
|
+
"DurationConfig",
|
22
|
+
"LRSchedulerConfig",
|
17
23
|
"LRSchedulerConfigBase",
|
18
24
|
"LinearWarmupCosineDecayLRSchedulerConfig",
|
19
25
|
"MetricConfig",
|
@@ -2,6 +2,9 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
|
+
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
6
|
+
DurationConfig as DurationConfig,
|
7
|
+
)
|
5
8
|
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
6
9
|
LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
|
7
10
|
)
|
@@ -10,6 +13,7 @@ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
|
10
13
|
)
|
11
14
|
|
12
15
|
__all__ = [
|
16
|
+
"DurationConfig",
|
13
17
|
"LRSchedulerConfigBase",
|
14
18
|
"LinearWarmupCosineDecayLRSchedulerConfig",
|
15
19
|
]
|
@@ -8,6 +8,7 @@ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
|
|
8
8
|
from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
|
9
9
|
from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
|
10
10
|
from nshtrainer.nn import MLPConfig as MLPConfig
|
11
|
+
from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
|
11
12
|
from nshtrainer.nn import PReLUConfig as PReLUConfig
|
12
13
|
from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
13
14
|
from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
|
@@ -31,6 +32,7 @@ __all__ = [
|
|
31
32
|
"LeakyReLUNonlinearityConfig",
|
32
33
|
"MLPConfig",
|
33
34
|
"MishNonlinearityConfig",
|
35
|
+
"NonlinearityConfig",
|
34
36
|
"PReLUConfig",
|
35
37
|
"ReLUNonlinearityConfig",
|
36
38
|
"SiLUNonlinearityConfig",
|
@@ -4,8 +4,10 @@ __codegen__ = True
|
|
4
4
|
|
5
5
|
from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
|
6
6
|
from nshtrainer.nn.mlp import MLPConfig as MLPConfig
|
7
|
+
from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
|
7
8
|
|
8
9
|
__all__ = [
|
9
10
|
"BaseNonlinearityConfig",
|
10
11
|
"MLPConfig",
|
12
|
+
"NonlinearityConfig",
|
11
13
|
]
|
@@ -9,6 +9,7 @@ from nshtrainer.nn.nonlinearity import (
|
|
9
9
|
LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig,
|
10
10
|
)
|
11
11
|
from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
|
12
|
+
from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
|
12
13
|
from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
|
13
14
|
from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
14
15
|
from nshtrainer.nn.nonlinearity import (
|
@@ -38,6 +39,7 @@ __all__ = [
|
|
38
39
|
"GELUNonlinearityConfig",
|
39
40
|
"LeakyReLUNonlinearityConfig",
|
40
41
|
"MishNonlinearityConfig",
|
42
|
+
"NonlinearityConfig",
|
41
43
|
"PReLUConfig",
|
42
44
|
"ReLUNonlinearityConfig",
|
43
45
|
"SiLUNonlinearityConfig",
|
@@ -3,9 +3,11 @@ from __future__ import annotations
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
5
|
from nshtrainer.optimizer import AdamWConfig as AdamWConfig
|
6
|
+
from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
|
6
7
|
from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
|
7
8
|
|
8
9
|
__all__ = [
|
9
10
|
"AdamWConfig",
|
11
|
+
"OptimizerConfig",
|
10
12
|
"OptimizerConfigBase",
|
11
13
|
]
|
@@ -4,6 +4,7 @@ __codegen__ = True
|
|
4
4
|
|
5
5
|
from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
|
6
6
|
from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
|
7
|
+
from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
|
7
8
|
from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
|
8
9
|
from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
|
9
10
|
|
@@ -15,6 +16,7 @@ from . import simple as simple
|
|
15
16
|
__all__ = [
|
16
17
|
"AdvancedProfilerConfig",
|
17
18
|
"BaseProfilerConfig",
|
19
|
+
"ProfilerConfig",
|
18
20
|
"PyTorchProfilerConfig",
|
19
21
|
"SimpleProfilerConfig",
|
20
22
|
"_base",
|
@@ -9,7 +9,11 @@ from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
|
|
9
9
|
from nshtrainer.trainer._config import (
|
10
10
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
11
11
|
)
|
12
|
+
from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
|
12
13
|
from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
|
14
|
+
from nshtrainer.trainer._config import (
|
15
|
+
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
16
|
+
)
|
13
17
|
from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
|
14
18
|
from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
|
15
19
|
from nshtrainer.trainer._config import (
|
@@ -29,6 +33,7 @@ from nshtrainer.trainer._config import (
|
|
29
33
|
LearningRateMonitorConfig as LearningRateMonitorConfig,
|
30
34
|
)
|
31
35
|
from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
|
36
|
+
from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
|
32
37
|
from nshtrainer.trainer._config import MetricConfig as MetricConfig
|
33
38
|
from nshtrainer.trainer._config import (
|
34
39
|
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
@@ -37,6 +42,7 @@ from nshtrainer.trainer._config import (
|
|
37
42
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
38
43
|
)
|
39
44
|
from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
|
45
|
+
from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
|
40
46
|
from nshtrainer.trainer._config import (
|
41
47
|
RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
|
42
48
|
)
|
@@ -62,7 +68,9 @@ __all__ = [
|
|
62
68
|
"BaseLoggerConfig",
|
63
69
|
"BestCheckpointCallbackConfig",
|
64
70
|
"CSVLoggerConfig",
|
71
|
+
"CallbackConfig",
|
65
72
|
"CallbackConfigBase",
|
73
|
+
"CheckpointCallbackConfig",
|
66
74
|
"CheckpointSavingConfig",
|
67
75
|
"DebugFlagCallbackConfig",
|
68
76
|
"DirectoryConfig",
|
@@ -73,10 +81,12 @@ __all__ = [
|
|
73
81
|
"LastCheckpointCallbackConfig",
|
74
82
|
"LearningRateMonitorConfig",
|
75
83
|
"LogEpochCallbackConfig",
|
84
|
+
"LoggerConfig",
|
76
85
|
"MetricConfig",
|
77
86
|
"NormLoggingCallbackConfig",
|
78
87
|
"OnExceptionCheckpointCallbackConfig",
|
79
88
|
"PluginConfigBase",
|
89
|
+
"ProfilerConfig",
|
80
90
|
"RLPSanityChecksCallbackConfig",
|
81
91
|
"SanityCheckingConfig",
|
82
92
|
"SharedParametersCallbackConfig",
|
@@ -8,7 +8,11 @@ from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
|
|
8
8
|
from nshtrainer.trainer._config import (
|
9
9
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
10
10
|
)
|
11
|
+
from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
|
11
12
|
from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
|
13
|
+
from nshtrainer.trainer._config import (
|
14
|
+
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
15
|
+
)
|
12
16
|
from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
|
13
17
|
from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
|
14
18
|
from nshtrainer.trainer._config import (
|
@@ -28,6 +32,7 @@ from nshtrainer.trainer._config import (
|
|
28
32
|
LearningRateMonitorConfig as LearningRateMonitorConfig,
|
29
33
|
)
|
30
34
|
from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
|
35
|
+
from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
|
31
36
|
from nshtrainer.trainer._config import MetricConfig as MetricConfig
|
32
37
|
from nshtrainer.trainer._config import (
|
33
38
|
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
@@ -36,6 +41,7 @@ from nshtrainer.trainer._config import (
|
|
36
41
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
37
42
|
)
|
38
43
|
from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
|
44
|
+
from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
|
39
45
|
from nshtrainer.trainer._config import (
|
40
46
|
RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
|
41
47
|
)
|
@@ -59,7 +65,9 @@ __all__ = [
|
|
59
65
|
"BaseLoggerConfig",
|
60
66
|
"BestCheckpointCallbackConfig",
|
61
67
|
"CSVLoggerConfig",
|
68
|
+
"CallbackConfig",
|
62
69
|
"CallbackConfigBase",
|
70
|
+
"CheckpointCallbackConfig",
|
63
71
|
"CheckpointSavingConfig",
|
64
72
|
"DebugFlagCallbackConfig",
|
65
73
|
"DirectoryConfig",
|
@@ -70,10 +78,12 @@ __all__ = [
|
|
70
78
|
"LastCheckpointCallbackConfig",
|
71
79
|
"LearningRateMonitorConfig",
|
72
80
|
"LogEpochCallbackConfig",
|
81
|
+
"LoggerConfig",
|
73
82
|
"MetricConfig",
|
74
83
|
"NormLoggingCallbackConfig",
|
75
84
|
"OnExceptionCheckpointCallbackConfig",
|
76
85
|
"PluginConfigBase",
|
86
|
+
"ProfilerConfig",
|
77
87
|
"RLPSanityChecksCallbackConfig",
|
78
88
|
"SanityCheckingConfig",
|
79
89
|
"SharedParametersCallbackConfig",
|
@@ -32,6 +32,7 @@ from nshtrainer.util._environment_info import (
|
|
32
32
|
)
|
33
33
|
from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
|
34
34
|
from nshtrainer.util.config import DTypeConfig as DTypeConfig
|
35
|
+
from nshtrainer.util.config import DurationConfig as DurationConfig
|
35
36
|
from nshtrainer.util.config import EpochsConfig as EpochsConfig
|
36
37
|
from nshtrainer.util.config import StepsConfig as StepsConfig
|
37
38
|
|
@@ -40,6 +41,7 @@ from . import config as config
|
|
40
41
|
|
41
42
|
__all__ = [
|
42
43
|
"DTypeConfig",
|
44
|
+
"DurationConfig",
|
43
45
|
"EnvironmentCUDAConfig",
|
44
46
|
"EnvironmentClassInformationConfig",
|
45
47
|
"EnvironmentConfig",
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
5
|
from nshtrainer.util.config import DTypeConfig as DTypeConfig
|
6
|
+
from nshtrainer.util.config import DurationConfig as DurationConfig
|
6
7
|
from nshtrainer.util.config import EpochsConfig as EpochsConfig
|
7
8
|
from nshtrainer.util.config import StepsConfig as StepsConfig
|
8
9
|
|
@@ -11,6 +12,7 @@ from . import duration as duration
|
|
11
12
|
|
12
13
|
__all__ = [
|
13
14
|
"DTypeConfig",
|
15
|
+
"DurationConfig",
|
14
16
|
"EpochsConfig",
|
15
17
|
"StepsConfig",
|
16
18
|
"dtype",
|
@@ -2,10 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
|
+
from nshtrainer.util.config.duration import DurationConfig as DurationConfig
|
5
6
|
from nshtrainer.util.config.duration import EpochsConfig as EpochsConfig
|
6
7
|
from nshtrainer.util.config.duration import StepsConfig as StepsConfig
|
7
8
|
|
8
9
|
__all__ = [
|
10
|
+
"DurationConfig",
|
9
11
|
"EpochsConfig",
|
10
12
|
"StepsConfig",
|
11
13
|
]
|
nshtrainer/metrics/_config.py
CHANGED
@@ -1,12 +1,10 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import builtins
|
4
|
-
from typing import Literal
|
4
|
+
from typing import Any, Literal
|
5
5
|
|
6
6
|
import nshconfig as C
|
7
7
|
|
8
|
-
from ..util._useful_types import SupportsRichComparisonT
|
9
|
-
|
10
8
|
|
11
9
|
class MetricConfig(C.Config):
|
12
10
|
name: str
|
@@ -40,5 +38,5 @@ class MetricConfig(C.Config):
|
|
40
38
|
def best(self):
|
41
39
|
return builtins.min if self.mode == "min" else builtins.max
|
42
40
|
|
43
|
-
def is_better(self, a: SupportsRichComparisonT, b: SupportsRichComparisonT):
|
41
|
+
def is_better(self, a: Any, b: Any):
|
44
42
|
return self.best(a, b) == a
|
nshtrainer/trainer/_config.py
CHANGED
@@ -82,6 +82,13 @@ class PluginConfigBase(C.Config, ABC):
|
|
82
82
|
|
83
83
|
|
84
84
|
plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
|
85
|
+
PluginConfig = TypeAliasType(
|
86
|
+
"PluginConfig", Annotated[PluginConfigBase, plugin_registry.DynamicResolution()]
|
87
|
+
)
|
88
|
+
|
89
|
+
AcceleratorLiteral = TypeAliasType(
|
90
|
+
"AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
|
91
|
+
)
|
85
92
|
|
86
93
|
|
87
94
|
class AcceleratorConfigBase(C.Config, ABC):
|
@@ -90,18 +97,9 @@ class AcceleratorConfigBase(C.Config, ABC):
|
|
90
97
|
|
91
98
|
|
92
99
|
accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
@abstractmethod
|
97
|
-
def create_strategy(self) -> Strategy: ...
|
98
|
-
|
99
|
-
|
100
|
-
strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
|
101
|
-
|
102
|
-
|
103
|
-
AcceleratorLiteral = TypeAliasType(
|
104
|
-
"AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
|
100
|
+
AcceleratorConfig = TypeAliasType(
|
101
|
+
"AcceleratorConfig",
|
102
|
+
Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()],
|
105
103
|
)
|
106
104
|
|
107
105
|
StrategyLiteral = TypeAliasType(
|
@@ -137,6 +135,17 @@ StrategyLiteral = TypeAliasType(
|
|
137
135
|
)
|
138
136
|
|
139
137
|
|
138
|
+
class StrategyConfigBase(C.Config, ABC):
|
139
|
+
@abstractmethod
|
140
|
+
def create_strategy(self) -> Strategy: ...
|
141
|
+
|
142
|
+
|
143
|
+
strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
|
144
|
+
StrategyConfig = TypeAliasType(
|
145
|
+
"StrategyConfig",
|
146
|
+
Annotated[StrategyConfigBase, strategy_registry.DynamicResolution()],
|
147
|
+
)
|
148
|
+
|
140
149
|
CheckpointCallbackConfig = TypeAliasType(
|
141
150
|
"CheckpointCallbackConfig",
|
142
151
|
Annotated[
|
@@ -578,9 +587,7 @@ class TrainerConfig(C.Config):
|
|
578
587
|
Default: ``False``.
|
579
588
|
"""
|
580
589
|
|
581
|
-
plugins:
|
582
|
-
list[Annotated[PluginConfigBase, plugin_registry.DynamicResolution()]] | None
|
583
|
-
) = None
|
590
|
+
plugins: list[PluginConfig] | None = None
|
584
591
|
"""
|
585
592
|
Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
|
586
593
|
Default: ``None``.
|
@@ -740,21 +747,13 @@ class TrainerConfig(C.Config):
|
|
740
747
|
Default: ``True``.
|
741
748
|
"""
|
742
749
|
|
743
|
-
accelerator:
|
744
|
-
Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()]
|
745
|
-
| AcceleratorLiteral
|
746
|
-
| None
|
747
|
-
) = None
|
750
|
+
accelerator: AcceleratorConfig | AcceleratorLiteral | None = None
|
748
751
|
"""Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
|
749
752
|
as well as custom accelerator instances.
|
750
753
|
Default: ``"auto"``.
|
751
754
|
"""
|
752
755
|
|
753
|
-
strategy:
|
754
|
-
Annotated[StrategyConfigBase, strategy_registry.DynamicResolution()]
|
755
|
-
| StrategyLiteral
|
756
|
-
| None
|
757
|
-
) = None
|
756
|
+
strategy: StrategyConfig | StrategyLiteral | None = None
|
758
757
|
"""Supports different training strategies with aliases as well custom strategies.
|
759
758
|
Default: ``"auto"``.
|
760
759
|
"""
|
@@ -31,12 +31,12 @@ nshtrainer/callbacks/shared_parameters.py,sha256=ggMI1krkqN7sGOrjK_I96IsTMYMXHoV
|
|
31
31
|
nshtrainer/callbacks/timer.py,sha256=BB-M7tV4QNYOwY_Su6j9P7IILxVRae_upmDq4qsxiao,4670
|
32
32
|
nshtrainer/callbacks/wandb_upload_code.py,sha256=PTqNE1QB5U8NR5zhbiQZrmQuugX2UV7B12UdMpo9aV0,2353
|
33
33
|
nshtrainer/callbacks/wandb_watch.py,sha256=tTTcFzxd2Ia9xu8tCogQ5CLJZBq1ne5JlpGVE75vKYs,2976
|
34
|
-
nshtrainer/configs/__init__.py,sha256=
|
34
|
+
nshtrainer/configs/__init__.py,sha256=zyo4lV9ObB3T3_hhBhzWGNb6MRma4h7QHD3OrypxqEw,10582
|
35
35
|
nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
|
36
36
|
nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
|
37
|
-
nshtrainer/configs/_directory/__init__.py,sha256=
|
37
|
+
nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
|
38
38
|
nshtrainer/configs/_hf_hub/__init__.py,sha256=VUgQnyEI2ekBxBIV15L09tKdrfGt7eWxnf30DiCLaso,416
|
39
|
-
nshtrainer/configs/callbacks/__init__.py,sha256=
|
39
|
+
nshtrainer/configs/callbacks/__init__.py,sha256=APpF2jmafqbS4CoMmDvFADi0wdmXJ_BvFw4QnnQpok0,4353
|
40
40
|
nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JJg9d8iNGpO-9M1LsK4h1cu3NYWniyIyLQ4SauFCzOs,272
|
41
41
|
nshtrainer/configs/callbacks/base/__init__.py,sha256=V694hzF_ubnA-hwTps30PeFbgDSm3I_UIMTnljM3_OI,176
|
42
42
|
nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=zPUItxoYWrMT9i1TxOvVhIeTa0NEFg-nDE5FjHfkP-A,1564
|
@@ -60,35 +60,35 @@ nshtrainer/configs/callbacks/shared_parameters/__init__.py,sha256=Ivef5jk3RMYQDe
|
|
60
60
|
nshtrainer/configs/callbacks/timer/__init__.py,sha256=RHOQoREp4NxS_AvKNdc0UuUlS0PnqCxxsuOz5D8h7iM,310
|
61
61
|
nshtrainer/configs/callbacks/wandb_upload_code/__init__.py,sha256=WM9hCGFl2LXDUOgkIGaV3tkdnXnVBasrhIILjbIeFUo,358
|
62
62
|
nshtrainer/configs/callbacks/wandb_watch/__init__.py,sha256=MW-ANrF529DxBhopovPjYEQ7nANX9ttd1K4_bJnKXks,322
|
63
|
-
nshtrainer/configs/loggers/__init__.py,sha256=
|
63
|
+
nshtrainer/configs/loggers/__init__.py,sha256=5wTekL79mQxit8f1K3AMllvb0mKertTzOKfC3gpE2Zk,1251
|
64
64
|
nshtrainer/configs/loggers/_base/__init__.py,sha256=HxPPPePsEjlNuhnjsMgYIl0rwj_iqNKKOBTEk_zIOsM,169
|
65
65
|
nshtrainer/configs/loggers/actsave/__init__.py,sha256=2lZQ4bpbjwd4MuUE_Z_PGbmQjjGtWCZUCtXqKO4dTSc,280
|
66
66
|
nshtrainer/configs/loggers/csv/__init__.py,sha256=M3QGF5GKiRGENy3re6LJKpa4A4RThy1FlmaFuR4cPyo,260
|
67
67
|
nshtrainer/configs/loggers/tensorboard/__init__.py,sha256=FbkYXnSohIX6JN5XyI-9y91IJv_T3VB3IwmpagXAnM4,309
|
68
68
|
nshtrainer/configs/loggers/wandb/__init__.py,sha256=76qb0HhWojf0Ub1x9OkMjtzeXxE67KysBGa-MBbJyC4,651
|
69
|
-
nshtrainer/configs/lr_scheduler/__init__.py,sha256=
|
69
|
+
nshtrainer/configs/lr_scheduler/__init__.py,sha256=8ORO-QC12SjZ2F_reMoDgr8-O8nxZxX0IKU4fl-cC3A,1023
|
70
70
|
nshtrainer/configs/lr_scheduler/_base/__init__.py,sha256=fvGjkUJ1K2RVXjXror22QOtEa-xWFJz2Cz3HrBC5XfA,189
|
71
|
-
nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=
|
71
|
+
nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=i8LeZh0c4wqtZ1ehZb2LCq7kwOL0OyswMMOnwyI6R04,533
|
72
72
|
nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=lpXEFZY4cM3znZqYG9IZ1xNNtzttt8VVspSuOz0fb-k,467
|
73
73
|
nshtrainer/configs/metrics/__init__.py,sha256=mK_xgXJDAyGY4K_x_Zo4aj36kjT45d850keuUe3U1rY,200
|
74
74
|
nshtrainer/configs/metrics/_config/__init__.py,sha256=XDDvDPWULd_vd3lrgF2KGAVR2LVDuhdQvy-fF2ImarI,159
|
75
|
-
nshtrainer/configs/nn/__init__.py,sha256=
|
76
|
-
nshtrainer/configs/nn/mlp/__init__.py,sha256=
|
77
|
-
nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=
|
78
|
-
nshtrainer/configs/optimizer/__init__.py,sha256=
|
79
|
-
nshtrainer/configs/profiler/__init__.py,sha256=
|
75
|
+
nshtrainer/configs/nn/__init__.py,sha256=3hVc81Gs9AJYVkrwJkQ_ye7tLU2HOLdBj-mMkXx2c_I,1957
|
76
|
+
nshtrainer/configs/nn/mlp/__init__.py,sha256=eMECrgz-My9mFS7lpWVI3dj1ApB-E7xwfmNc37hUsPI,347
|
77
|
+
nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=Gjr2HCx8jJTcfu7sLgn54o2ucGKaBea4encm4AWpKNY,2040
|
78
|
+
nshtrainer/configs/optimizer/__init__.py,sha256=IMEsEbiVFXSkj6WmDjNjmKQuRspphs5xZnYZ2gYE39Y,344
|
79
|
+
nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
|
80
80
|
nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
|
81
81
|
nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
|
82
82
|
nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
|
83
83
|
nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
|
84
|
-
nshtrainer/configs/trainer/__init__.py,sha256=
|
85
|
-
nshtrainer/configs/trainer/_config/__init__.py,sha256=
|
84
|
+
nshtrainer/configs/trainer/__init__.py,sha256=KIDYjJsc-WYXKiH2RNzAZJD5MKOTdO9wdtu_vWDNPxU,3936
|
85
|
+
nshtrainer/configs/trainer/_config/__init__.py,sha256=1_Ad5uTvXdVuHMJB3s8s-0EraDwNZssg3sXBmVouF9w,3847
|
86
86
|
nshtrainer/configs/trainer/trainer/__init__.py,sha256=DDuBRx0kVNMW0z_sqKTUt8-Ql7bOpargi4KcHHvDu_c,486
|
87
|
-
nshtrainer/configs/util/__init__.py,sha256=
|
87
|
+
nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
|
88
88
|
nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
|
89
|
-
nshtrainer/configs/util/config/__init__.py,sha256=
|
89
|
+
nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
|
90
90
|
nshtrainer/configs/util/config/dtype/__init__.py,sha256=PmGF-O4r6SXqEaagVsQ5YxEqhdVdcU0dgJW1Ljzpp6k,158
|
91
|
-
nshtrainer/configs/util/config/duration/__init__.py,sha256=
|
91
|
+
nshtrainer/configs/util/config/duration/__init__.py,sha256=44lS2irOIPVfgshMTfnZM2jC6l0Pjst9w2M_lJoS_MU,353
|
92
92
|
nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,260
|
93
93
|
nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
|
94
94
|
nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk,4138
|
@@ -104,7 +104,7 @@ nshtrainer/lr_scheduler/_base.py,sha256=EhA2f_WiZ79RcXL2nJbwCwNK620c8ugEVUmJ8CcV
|
|
104
104
|
nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=gvUuv031lvWdXboDeH7iAF3ZgNPQK40bQwfmqb11TNk,5492
|
105
105
|
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=vXH5S26ESHO_LPPqW8aDC3S5NGoZYkXeFjAOgttaUX8,2870
|
106
106
|
nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
|
107
|
-
nshtrainer/metrics/_config.py,sha256=
|
107
|
+
nshtrainer/metrics/_config.py,sha256=XIRokFM8PHrhBa3w2R6BM6a4es3ncsoBqE_LqXQFsFE,1223
|
108
108
|
nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
|
109
109
|
nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,10356
|
110
110
|
nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
|
@@ -122,12 +122,11 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
|
|
122
122
|
nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
|
123
123
|
nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
|
124
124
|
nshtrainer/trainer/__init__.py,sha256=MmoydVS6aYeav7zgDAUHxAQrV_PMQsbnZTCuPnLH9Wk,128
|
125
|
-
nshtrainer/trainer/_config.py,sha256=
|
125
|
+
nshtrainer/trainer/_config.py,sha256=VD0DfdS-pyQ2nFG83c4u5AUkSAHODmXLX5s2qtvS_to,35400
|
126
126
|
nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
|
127
127
|
nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
|
128
128
|
nshtrainer/trainer/trainer.py,sha256=HHqT83zWtYY9g5yD6X9aWrVh5VSpILW8PhoE6fp4snE,20734
|
129
129
|
nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
|
130
|
-
nshtrainer/util/_useful_types.py,sha256=7yd1ajSmjwfmZdBPlHVrIG3iXl1-T3n83JI53N8C7as,8080
|
131
130
|
nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
|
132
131
|
nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
|
133
132
|
nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
|
@@ -138,6 +137,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
|
138
137
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
139
138
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
140
139
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
141
|
-
nshtrainer-1.0.
|
142
|
-
nshtrainer-1.0.
|
143
|
-
nshtrainer-1.0.
|
140
|
+
nshtrainer-1.0.0b29.dist-info/METADATA,sha256=YRehZvU9svmmfAmwFrdmu-Tzxgi_EHbFwrn-ewD8W9c,988
|
141
|
+
nshtrainer-1.0.0b29.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
142
|
+
nshtrainer-1.0.0b29.dist-info/RECORD,,
|
nshtrainer/util/_useful_types.py
DELETED
@@ -1,316 +0,0 @@
|
|
1
|
-
"""Credit to useful-types from https://github.com/hauntsaninja/useful_types"""
|
2
|
-
|
3
|
-
from __future__ import annotations
|
4
|
-
|
5
|
-
from collections.abc import Awaitable, Iterable, Iterator, Sequence, Sized
|
6
|
-
from collections.abc import Set as AbstractSet
|
7
|
-
from os import PathLike
|
8
|
-
from typing import Any, TypeVar, overload
|
9
|
-
|
10
|
-
from typing_extensions import (
|
11
|
-
Buffer,
|
12
|
-
Literal,
|
13
|
-
Protocol,
|
14
|
-
SupportsIndex,
|
15
|
-
TypeAlias,
|
16
|
-
TypeAliasType,
|
17
|
-
)
|
18
|
-
|
19
|
-
_KT = TypeVar("_KT")
|
20
|
-
_KT_co = TypeVar("_KT_co", covariant=True)
|
21
|
-
_KT_contra = TypeVar("_KT_contra", contravariant=True)
|
22
|
-
_VT = TypeVar("_VT")
|
23
|
-
_VT_co = TypeVar("_VT_co", covariant=True)
|
24
|
-
_T = TypeVar("_T")
|
25
|
-
_T_co = TypeVar("_T_co", covariant=True)
|
26
|
-
_T_contra = TypeVar("_T_contra", contravariant=True)
|
27
|
-
|
28
|
-
# For partially known annotations. Usually, fields where type annotations
|
29
|
-
# haven't been added are left unannotated, but in some situations this
|
30
|
-
# isn't possible or a type is already partially known. In cases like these,
|
31
|
-
# use Incomplete instead of Any as a marker. For example, use
|
32
|
-
# "Incomplete | None" instead of "Any | None".
|
33
|
-
Incomplete: TypeAlias = Any
|
34
|
-
|
35
|
-
|
36
|
-
class IdentityFunction(Protocol):
|
37
|
-
def __call__(self, __x: _T) -> _T: ...
|
38
|
-
|
39
|
-
|
40
|
-
# ====================
|
41
|
-
# Comparison protocols
|
42
|
-
# ====================
|
43
|
-
|
44
|
-
|
45
|
-
class SupportsDunderLT(Protocol[_T_contra]):
|
46
|
-
def __lt__(self, __other: _T_contra) -> bool: ...
|
47
|
-
|
48
|
-
|
49
|
-
class SupportsDunderGT(Protocol[_T_contra]):
|
50
|
-
def __gt__(self, __other: _T_contra) -> bool: ...
|
51
|
-
|
52
|
-
|
53
|
-
class SupportsDunderLE(Protocol[_T_contra]):
|
54
|
-
def __le__(self, __other: _T_contra) -> bool: ...
|
55
|
-
|
56
|
-
|
57
|
-
class SupportsDunderGE(Protocol[_T_contra]):
|
58
|
-
def __ge__(self, __other: _T_contra) -> bool: ...
|
59
|
-
|
60
|
-
|
61
|
-
class SupportsAllComparisons(
|
62
|
-
SupportsDunderLT[Any],
|
63
|
-
SupportsDunderGT[Any],
|
64
|
-
SupportsDunderLE[Any],
|
65
|
-
SupportsDunderGE[Any],
|
66
|
-
Protocol,
|
67
|
-
): ...
|
68
|
-
|
69
|
-
|
70
|
-
SupportsRichComparison = TypeAliasType(
|
71
|
-
"SupportsRichComparison", SupportsDunderLT[Any] | SupportsDunderGT[Any]
|
72
|
-
)
|
73
|
-
SupportsRichComparisonT = TypeVar(
|
74
|
-
"SupportsRichComparisonT", bound=SupportsRichComparison
|
75
|
-
)
|
76
|
-
|
77
|
-
# ====================
|
78
|
-
# Dunder protocols
|
79
|
-
# ====================
|
80
|
-
|
81
|
-
|
82
|
-
class SupportsNext(Protocol[_T_co]):
|
83
|
-
def __next__(self) -> _T_co: ...
|
84
|
-
|
85
|
-
|
86
|
-
class SupportsAnext(Protocol[_T_co]):
|
87
|
-
def __anext__(self) -> Awaitable[_T_co]: ...
|
88
|
-
|
89
|
-
|
90
|
-
class SupportsAdd(Protocol[_T_contra, _T_co]):
|
91
|
-
def __add__(self, __x: _T_contra) -> _T_co: ...
|
92
|
-
|
93
|
-
|
94
|
-
class SupportsRAdd(Protocol[_T_contra, _T_co]):
|
95
|
-
def __radd__(self, __x: _T_contra) -> _T_co: ...
|
96
|
-
|
97
|
-
|
98
|
-
class SupportsSub(Protocol[_T_contra, _T_co]):
|
99
|
-
def __sub__(self, __x: _T_contra) -> _T_co: ...
|
100
|
-
|
101
|
-
|
102
|
-
class SupportsRSub(Protocol[_T_contra, _T_co]):
|
103
|
-
def __rsub__(self, __x: _T_contra) -> _T_co: ...
|
104
|
-
|
105
|
-
|
106
|
-
class SupportsDivMod(Protocol[_T_contra, _T_co]):
|
107
|
-
def __divmod__(self, __other: _T_contra) -> _T_co: ...
|
108
|
-
|
109
|
-
|
110
|
-
class SupportsRDivMod(Protocol[_T_contra, _T_co]):
|
111
|
-
def __rdivmod__(self, __other: _T_contra) -> _T_co: ...
|
112
|
-
|
113
|
-
|
114
|
-
# This protocol is generic over the iterator type, while Iterable is
|
115
|
-
# generic over the type that is iterated over.
|
116
|
-
class SupportsIter(Protocol[_T_co]):
|
117
|
-
def __iter__(self) -> _T_co: ...
|
118
|
-
|
119
|
-
|
120
|
-
# This protocol is generic over the iterator type, while AsyncIterable is
|
121
|
-
# generic over the type that is iterated over.
|
122
|
-
class SupportsAiter(Protocol[_T_co]):
|
123
|
-
def __aiter__(self) -> _T_co: ...
|
124
|
-
|
125
|
-
|
126
|
-
class SupportsLenAndGetItem(Protocol[_T_co]):
|
127
|
-
def __len__(self) -> int: ...
|
128
|
-
def __getitem__(self, __k: int) -> _T_co: ...
|
129
|
-
|
130
|
-
|
131
|
-
class SupportsTrunc(Protocol):
|
132
|
-
def __trunc__(self) -> int: ...
|
133
|
-
|
134
|
-
|
135
|
-
# ====================
|
136
|
-
# Mapping-like protocols
|
137
|
-
# ====================
|
138
|
-
|
139
|
-
|
140
|
-
class SupportsItems(Protocol[_KT_co, _VT_co]):
|
141
|
-
def items(self) -> AbstractSet[tuple[_KT_co, _VT_co]]: ...
|
142
|
-
|
143
|
-
|
144
|
-
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
|
145
|
-
def keys(self) -> Iterable[_KT]: ...
|
146
|
-
def __getitem__(self, __key: _KT) -> _VT_co: ...
|
147
|
-
|
148
|
-
|
149
|
-
class SupportsGetItem(Protocol[_KT_contra, _VT_co]):
|
150
|
-
def __contains__(self, __x: Any) -> bool: ...
|
151
|
-
def __getitem__(self, __key: _KT_contra) -> _VT_co: ...
|
152
|
-
|
153
|
-
|
154
|
-
class SupportsItemAccess(SupportsGetItem[_KT_contra, _VT], Protocol[_KT_contra, _VT]):
|
155
|
-
def __setitem__(self, __key: _KT_contra, __value: _VT) -> None: ...
|
156
|
-
def __delitem__(self, __key: _KT_contra) -> None: ...
|
157
|
-
|
158
|
-
|
159
|
-
# ====================
|
160
|
-
# File handling
|
161
|
-
# ====================
|
162
|
-
|
163
|
-
StrPath: TypeAlias = str | PathLike[str]
|
164
|
-
BytesPath: TypeAlias = bytes | PathLike[bytes]
|
165
|
-
StrOrBytesPath: TypeAlias = str | bytes | PathLike[str] | PathLike[bytes]
|
166
|
-
|
167
|
-
OpenTextModeUpdating: TypeAlias = Literal[
|
168
|
-
"r+",
|
169
|
-
"+r",
|
170
|
-
"rt+",
|
171
|
-
"r+t",
|
172
|
-
"+rt",
|
173
|
-
"tr+",
|
174
|
-
"t+r",
|
175
|
-
"+tr",
|
176
|
-
"w+",
|
177
|
-
"+w",
|
178
|
-
"wt+",
|
179
|
-
"w+t",
|
180
|
-
"+wt",
|
181
|
-
"tw+",
|
182
|
-
"t+w",
|
183
|
-
"+tw",
|
184
|
-
"a+",
|
185
|
-
"+a",
|
186
|
-
"at+",
|
187
|
-
"a+t",
|
188
|
-
"+at",
|
189
|
-
"ta+",
|
190
|
-
"t+a",
|
191
|
-
"+ta",
|
192
|
-
"x+",
|
193
|
-
"+x",
|
194
|
-
"xt+",
|
195
|
-
"x+t",
|
196
|
-
"+xt",
|
197
|
-
"tx+",
|
198
|
-
"t+x",
|
199
|
-
"+tx",
|
200
|
-
]
|
201
|
-
OpenTextModeWriting: TypeAlias = Literal[
|
202
|
-
"w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"
|
203
|
-
]
|
204
|
-
OpenTextModeReading: TypeAlias = Literal[
|
205
|
-
"r", "rt", "tr", "U", "rU", "Ur", "rtU", "rUt", "Urt", "trU", "tUr", "Utr"
|
206
|
-
]
|
207
|
-
OpenTextMode: TypeAlias = (
|
208
|
-
OpenTextModeUpdating | OpenTextModeWriting | OpenTextModeReading
|
209
|
-
)
|
210
|
-
OpenBinaryModeUpdating: TypeAlias = Literal[
|
211
|
-
"rb+",
|
212
|
-
"r+b",
|
213
|
-
"+rb",
|
214
|
-
"br+",
|
215
|
-
"b+r",
|
216
|
-
"+br",
|
217
|
-
"wb+",
|
218
|
-
"w+b",
|
219
|
-
"+wb",
|
220
|
-
"bw+",
|
221
|
-
"b+w",
|
222
|
-
"+bw",
|
223
|
-
"ab+",
|
224
|
-
"a+b",
|
225
|
-
"+ab",
|
226
|
-
"ba+",
|
227
|
-
"b+a",
|
228
|
-
"+ba",
|
229
|
-
"xb+",
|
230
|
-
"x+b",
|
231
|
-
"+xb",
|
232
|
-
"bx+",
|
233
|
-
"b+x",
|
234
|
-
"+bx",
|
235
|
-
]
|
236
|
-
OpenBinaryModeWriting: TypeAlias = Literal["wb", "bw", "ab", "ba", "xb", "bx"]
|
237
|
-
OpenBinaryModeReading: TypeAlias = Literal[
|
238
|
-
"rb", "br", "rbU", "rUb", "Urb", "brU", "bUr", "Ubr"
|
239
|
-
]
|
240
|
-
OpenBinaryMode: TypeAlias = (
|
241
|
-
OpenBinaryModeUpdating | OpenBinaryModeReading | OpenBinaryModeWriting
|
242
|
-
)
|
243
|
-
|
244
|
-
|
245
|
-
class HasFileno(Protocol):
|
246
|
-
def fileno(self) -> int: ...
|
247
|
-
|
248
|
-
|
249
|
-
FileDescriptor: TypeAlias = int
|
250
|
-
FileDescriptorLike: TypeAlias = int | HasFileno
|
251
|
-
FileDescriptorOrPath: TypeAlias = int | StrOrBytesPath
|
252
|
-
|
253
|
-
|
254
|
-
class SupportsRead(Protocol[_T_co]):
|
255
|
-
def read(self, __length: int = ...) -> _T_co: ...
|
256
|
-
|
257
|
-
|
258
|
-
class SupportsReadline(Protocol[_T_co]):
|
259
|
-
def readline(self, __length: int = ...) -> _T_co: ...
|
260
|
-
|
261
|
-
|
262
|
-
class SupportsNoArgReadline(Protocol[_T_co]):
|
263
|
-
def readline(self) -> _T_co: ...
|
264
|
-
|
265
|
-
|
266
|
-
class SupportsWrite(Protocol[_T_contra]):
|
267
|
-
def write(self, __s: _T_contra) -> object: ...
|
268
|
-
|
269
|
-
|
270
|
-
# ====================
|
271
|
-
# Buffer protocols
|
272
|
-
# ====================
|
273
|
-
|
274
|
-
# Unfortunately PEP 688 does not allow us to distinguish read-only
|
275
|
-
# from writable buffers. We use these aliases for readability for now.
|
276
|
-
# Perhaps a future extension of the buffer protocol will allow us to
|
277
|
-
# distinguish these cases in the type system.
|
278
|
-
ReadOnlyBuffer: TypeAlias = Buffer
|
279
|
-
# Anything that implements the read-write buffer interface.
|
280
|
-
WriteableBuffer: TypeAlias = Buffer
|
281
|
-
# Same as WriteableBuffer, but also includes read-only buffer types (like bytes).
|
282
|
-
ReadableBuffer: TypeAlias = Buffer
|
283
|
-
|
284
|
-
|
285
|
-
class SliceableBuffer(Buffer, Protocol):
|
286
|
-
def __getitem__(self, __slice: slice) -> Sequence[int]: ...
|
287
|
-
|
288
|
-
|
289
|
-
class IndexableBuffer(Buffer, Protocol):
|
290
|
-
def __getitem__(self, __i: int) -> int: ...
|
291
|
-
|
292
|
-
|
293
|
-
class SupportsGetItemBuffer(SliceableBuffer, IndexableBuffer, Protocol):
|
294
|
-
def __contains__(self, __x: Any) -> bool: ...
|
295
|
-
@overload
|
296
|
-
def __getitem__(self, __slice: slice) -> Sequence[int]: ...
|
297
|
-
@overload
|
298
|
-
def __getitem__(self, __i: int) -> int: ...
|
299
|
-
|
300
|
-
|
301
|
-
class SizedBuffer(Sized, Buffer, Protocol): ...
|
302
|
-
|
303
|
-
|
304
|
-
# Source from https://github.com/python/typing/issues/256#issuecomment-1442633430
|
305
|
-
# This works because str.__contains__ does not accept object (either in typeshed or at runtime)
|
306
|
-
class SequenceNotStr(Protocol[_T_co]):
|
307
|
-
@overload
|
308
|
-
def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
|
309
|
-
@overload
|
310
|
-
def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
|
311
|
-
def __contains__(self, value: object, /) -> bool: ...
|
312
|
-
def __len__(self) -> int: ...
|
313
|
-
def __iter__(self) -> Iterator[_T_co]: ...
|
314
|
-
def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ...
|
315
|
-
def count(self, value: Any, /) -> int: ...
|
316
|
-
def __reversed__(self) -> Iterator[_T_co]: ...
|
File without changes
|