nshtrainer 1.0.0b29__py3-none-any.whl → 1.0.0b31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -0
- nshtrainer/configs/__init__.py +95 -3
- nshtrainer/configs/trainer/__init__.py +103 -3
- nshtrainer/configs/trainer/_config/__init__.py +10 -6
- nshtrainer/configs/trainer/accelerator/__init__.py +25 -0
- nshtrainer/configs/trainer/plugin/__init__.py +98 -0
- nshtrainer/configs/trainer/plugin/base/__init__.py +13 -0
- nshtrainer/configs/trainer/plugin/environment/__init__.py +41 -0
- nshtrainer/configs/trainer/plugin/io/__init__.py +23 -0
- nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +15 -0
- nshtrainer/configs/trainer/plugin/precision/__init__.py +43 -0
- nshtrainer/configs/trainer/strategy/__init__.py +11 -0
- nshtrainer/configs/trainer/trainer/__init__.py +2 -0
- nshtrainer/data/datamodule.py +2 -0
- nshtrainer/model/base.py +2 -0
- nshtrainer/trainer/__init__.py +2 -0
- nshtrainer/trainer/_config.py +3 -47
- nshtrainer/trainer/accelerator.py +86 -0
- nshtrainer/trainer/plugin/__init__.py +10 -0
- nshtrainer/trainer/plugin/base.py +33 -0
- nshtrainer/trainer/plugin/environment.py +128 -0
- nshtrainer/trainer/plugin/io.py +62 -0
- nshtrainer/trainer/plugin/layer_sync.py +25 -0
- nshtrainer/trainer/plugin/precision.py +163 -0
- nshtrainer/trainer/strategy.py +51 -0
- nshtrainer/trainer/trainer.py +8 -9
- nshtrainer/util/hparams.py +17 -0
- {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b31.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b31.dist-info}/RECORD +30 -13
- {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b31.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py
CHANGED
@@ -14,6 +14,8 @@ from .metrics import MetricConfig as MetricConfig
|
|
14
14
|
from .model import LightningModuleBase as LightningModuleBase
|
15
15
|
from .trainer import Trainer as Trainer
|
16
16
|
from .trainer import TrainerConfig as TrainerConfig
|
17
|
+
from .trainer import accelerator_registry as accelerator_registry
|
18
|
+
from .trainer import plugin_registry as plugin_registry
|
17
19
|
|
18
20
|
try:
|
19
21
|
from . import configs as configs
|
nshtrainer/configs/__init__.py
CHANGED
@@ -4,6 +4,8 @@ __codegen__ = True
|
|
4
4
|
|
5
5
|
from nshtrainer import MetricConfig as MetricConfig
|
6
6
|
from nshtrainer import TrainerConfig as TrainerConfig
|
7
|
+
from nshtrainer import accelerator_registry as accelerator_registry
|
8
|
+
from nshtrainer import plugin_registry as plugin_registry
|
7
9
|
from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
|
8
10
|
from nshtrainer._directory import DirectoryConfig as DirectoryConfig
|
9
11
|
from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
|
@@ -97,7 +99,7 @@ from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
|
|
97
99
|
from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
|
98
100
|
from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
|
99
101
|
from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
|
100
|
-
from nshtrainer.trainer._config import
|
102
|
+
from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
|
101
103
|
from nshtrainer.trainer._config import (
|
102
104
|
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
103
105
|
)
|
@@ -107,9 +109,71 @@ from nshtrainer.trainer._config import GradientClippingConfig as GradientClippin
|
|
107
109
|
from nshtrainer.trainer._config import (
|
108
110
|
LearningRateMonitorConfig as LearningRateMonitorConfig,
|
109
111
|
)
|
110
|
-
from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
|
111
112
|
from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
|
112
|
-
from nshtrainer.trainer._config import
|
113
|
+
from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
|
114
|
+
from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
|
115
|
+
from nshtrainer.trainer.accelerator import (
|
116
|
+
CUDAAcceleratorConfig as CUDAAcceleratorConfig,
|
117
|
+
)
|
118
|
+
from nshtrainer.trainer.accelerator import MPSAcceleratorConfig as MPSAcceleratorConfig
|
119
|
+
from nshtrainer.trainer.accelerator import XLAAcceleratorConfig as XLAAcceleratorConfig
|
120
|
+
from nshtrainer.trainer.plugin import PluginConfig as PluginConfig
|
121
|
+
from nshtrainer.trainer.plugin import PluginConfigBase as PluginConfigBase
|
122
|
+
from nshtrainer.trainer.plugin.environment import (
|
123
|
+
KubeflowEnvironmentPlugin as KubeflowEnvironmentPlugin,
|
124
|
+
)
|
125
|
+
from nshtrainer.trainer.plugin.environment import (
|
126
|
+
LightningEnvironmentPlugin as LightningEnvironmentPlugin,
|
127
|
+
)
|
128
|
+
from nshtrainer.trainer.plugin.environment import (
|
129
|
+
LSFEnvironmentPlugin as LSFEnvironmentPlugin,
|
130
|
+
)
|
131
|
+
from nshtrainer.trainer.plugin.environment import (
|
132
|
+
MPIEnvironmentPlugin as MPIEnvironmentPlugin,
|
133
|
+
)
|
134
|
+
from nshtrainer.trainer.plugin.environment import (
|
135
|
+
SLURMEnvironmentPlugin as SLURMEnvironmentPlugin,
|
136
|
+
)
|
137
|
+
from nshtrainer.trainer.plugin.environment import (
|
138
|
+
TorchElasticEnvironmentPlugin as TorchElasticEnvironmentPlugin,
|
139
|
+
)
|
140
|
+
from nshtrainer.trainer.plugin.environment import (
|
141
|
+
XLAEnvironmentPlugin as XLAEnvironmentPlugin,
|
142
|
+
)
|
143
|
+
from nshtrainer.trainer.plugin.io import (
|
144
|
+
AsyncCheckpointIOPlugin as AsyncCheckpointIOPlugin,
|
145
|
+
)
|
146
|
+
from nshtrainer.trainer.plugin.io import (
|
147
|
+
TorchCheckpointIOPlugin as TorchCheckpointIOPlugin,
|
148
|
+
)
|
149
|
+
from nshtrainer.trainer.plugin.io import XLACheckpointIOPlugin as XLACheckpointIOPlugin
|
150
|
+
from nshtrainer.trainer.plugin.layer_sync import (
|
151
|
+
TorchSyncBatchNormPlugin as TorchSyncBatchNormPlugin,
|
152
|
+
)
|
153
|
+
from nshtrainer.trainer.plugin.precision import (
|
154
|
+
BitsandbytesPluginConfig as BitsandbytesPluginConfig,
|
155
|
+
)
|
156
|
+
from nshtrainer.trainer.plugin.precision import (
|
157
|
+
DeepSpeedPluginConfig as DeepSpeedPluginConfig,
|
158
|
+
)
|
159
|
+
from nshtrainer.trainer.plugin.precision import (
|
160
|
+
DoublePrecisionPluginConfig as DoublePrecisionPluginConfig,
|
161
|
+
)
|
162
|
+
from nshtrainer.trainer.plugin.precision import (
|
163
|
+
FSDPPrecisionPluginConfig as FSDPPrecisionPluginConfig,
|
164
|
+
)
|
165
|
+
from nshtrainer.trainer.plugin.precision import (
|
166
|
+
HalfPrecisionPluginConfig as HalfPrecisionPluginConfig,
|
167
|
+
)
|
168
|
+
from nshtrainer.trainer.plugin.precision import (
|
169
|
+
MixedPrecisionPluginConfig as MixedPrecisionPluginConfig,
|
170
|
+
)
|
171
|
+
from nshtrainer.trainer.plugin.precision import (
|
172
|
+
TransformerEnginePluginConfig as TransformerEnginePluginConfig,
|
173
|
+
)
|
174
|
+
from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
|
175
|
+
from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
|
176
|
+
from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
|
113
177
|
from nshtrainer.util._environment_info import (
|
114
178
|
EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
|
115
179
|
)
|
@@ -157,17 +221,22 @@ from . import trainer as trainer
|
|
157
221
|
from . import util as util
|
158
222
|
|
159
223
|
__all__ = [
|
224
|
+
"AcceleratorConfig",
|
160
225
|
"AcceleratorConfigBase",
|
161
226
|
"ActSaveConfig",
|
162
227
|
"ActSaveLoggerConfig",
|
163
228
|
"AdamWConfig",
|
164
229
|
"AdvancedProfilerConfig",
|
230
|
+
"AsyncCheckpointIOPlugin",
|
165
231
|
"BaseCheckpointCallbackConfig",
|
166
232
|
"BaseLoggerConfig",
|
167
233
|
"BaseNonlinearityConfig",
|
168
234
|
"BaseProfilerConfig",
|
169
235
|
"BestCheckpointCallbackConfig",
|
236
|
+
"BitsandbytesPluginConfig",
|
237
|
+
"CPUAcceleratorConfig",
|
170
238
|
"CSVLoggerConfig",
|
239
|
+
"CUDAAcceleratorConfig",
|
171
240
|
"CallbackConfig",
|
172
241
|
"CallbackConfigBase",
|
173
242
|
"CheckpointCallbackConfig",
|
@@ -175,8 +244,10 @@ __all__ = [
|
|
175
244
|
"CheckpointSavingConfig",
|
176
245
|
"DTypeConfig",
|
177
246
|
"DebugFlagCallbackConfig",
|
247
|
+
"DeepSpeedPluginConfig",
|
178
248
|
"DirectoryConfig",
|
179
249
|
"DirectorySetupCallbackConfig",
|
250
|
+
"DoublePrecisionPluginConfig",
|
180
251
|
"DurationConfig",
|
181
252
|
"ELUNonlinearityConfig",
|
182
253
|
"EMACallbackConfig",
|
@@ -193,30 +264,39 @@ __all__ = [
|
|
193
264
|
"EnvironmentSnapshotConfig",
|
194
265
|
"EpochTimerCallbackConfig",
|
195
266
|
"EpochsConfig",
|
267
|
+
"FSDPPrecisionPluginConfig",
|
196
268
|
"FiniteChecksCallbackConfig",
|
197
269
|
"GELUNonlinearityConfig",
|
198
270
|
"GitRepositoryConfig",
|
199
271
|
"GradientClippingConfig",
|
200
272
|
"GradientSkippingCallbackConfig",
|
273
|
+
"HalfPrecisionPluginConfig",
|
201
274
|
"HuggingFaceHubAutoCreateConfig",
|
202
275
|
"HuggingFaceHubConfig",
|
276
|
+
"KubeflowEnvironmentPlugin",
|
203
277
|
"LRSchedulerConfig",
|
204
278
|
"LRSchedulerConfigBase",
|
279
|
+
"LSFEnvironmentPlugin",
|
205
280
|
"LastCheckpointCallbackConfig",
|
206
281
|
"LeakyReLUNonlinearityConfig",
|
207
282
|
"LearningRateMonitorConfig",
|
283
|
+
"LightningEnvironmentPlugin",
|
208
284
|
"LinearWarmupCosineDecayLRSchedulerConfig",
|
209
285
|
"LogEpochCallbackConfig",
|
210
286
|
"LoggerConfig",
|
211
287
|
"MLPConfig",
|
288
|
+
"MPIEnvironmentPlugin",
|
289
|
+
"MPSAcceleratorConfig",
|
212
290
|
"MetricConfig",
|
213
291
|
"MishNonlinearityConfig",
|
292
|
+
"MixedPrecisionPluginConfig",
|
214
293
|
"NonlinearityConfig",
|
215
294
|
"NormLoggingCallbackConfig",
|
216
295
|
"OnExceptionCheckpointCallbackConfig",
|
217
296
|
"OptimizerConfig",
|
218
297
|
"OptimizerConfigBase",
|
219
298
|
"PReLUConfig",
|
299
|
+
"PluginConfig",
|
220
300
|
"PluginConfigBase",
|
221
301
|
"PrintTableMetricsCallbackConfig",
|
222
302
|
"ProfilerConfig",
|
@@ -224,6 +304,7 @@ __all__ = [
|
|
224
304
|
"RLPSanityChecksCallbackConfig",
|
225
305
|
"ReLUNonlinearityConfig",
|
226
306
|
"ReduceLROnPlateauConfig",
|
307
|
+
"SLURMEnvironmentPlugin",
|
227
308
|
"SanityCheckingConfig",
|
228
309
|
"SharedParametersCallbackConfig",
|
229
310
|
"SiLUNonlinearityConfig",
|
@@ -233,25 +314,36 @@ __all__ = [
|
|
233
314
|
"SoftplusNonlinearityConfig",
|
234
315
|
"SoftsignNonlinearityConfig",
|
235
316
|
"StepsConfig",
|
317
|
+
"StrategyConfig",
|
236
318
|
"StrategyConfigBase",
|
237
319
|
"SwiGLUNonlinearityConfig",
|
238
320
|
"SwishNonlinearityConfig",
|
239
321
|
"TanhNonlinearityConfig",
|
240
322
|
"TensorboardLoggerConfig",
|
241
323
|
"TimeCheckpointCallbackConfig",
|
324
|
+
"TorchCheckpointIOPlugin",
|
325
|
+
"TorchElasticEnvironmentPlugin",
|
326
|
+
"TorchSyncBatchNormPlugin",
|
242
327
|
"TrainerConfig",
|
328
|
+
"TransformerEnginePluginConfig",
|
243
329
|
"WandbLoggerConfig",
|
244
330
|
"WandbUploadCodeCallbackConfig",
|
245
331
|
"WandbWatchCallbackConfig",
|
332
|
+
"XLAAcceleratorConfig",
|
333
|
+
"XLACheckpointIOPlugin",
|
334
|
+
"XLAEnvironmentPlugin",
|
335
|
+
"XLAPluginConfig",
|
246
336
|
"_checkpoint",
|
247
337
|
"_directory",
|
248
338
|
"_hf_hub",
|
339
|
+
"accelerator_registry",
|
249
340
|
"callbacks",
|
250
341
|
"loggers",
|
251
342
|
"lr_scheduler",
|
252
343
|
"metrics",
|
253
344
|
"nn",
|
254
345
|
"optimizer",
|
346
|
+
"plugin_registry",
|
255
347
|
"profiler",
|
256
348
|
"trainer",
|
257
349
|
"util",
|
@@ -3,7 +3,9 @@ from __future__ import annotations
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
5
|
from nshtrainer.trainer import TrainerConfig as TrainerConfig
|
6
|
-
from nshtrainer.trainer
|
6
|
+
from nshtrainer.trainer import accelerator_registry as accelerator_registry
|
7
|
+
from nshtrainer.trainer import plugin_registry as plugin_registry
|
8
|
+
from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
|
7
9
|
from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
|
8
10
|
from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
|
9
11
|
from nshtrainer.trainer._config import (
|
@@ -41,7 +43,6 @@ from nshtrainer.trainer._config import (
|
|
41
43
|
from nshtrainer.trainer._config import (
|
42
44
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
43
45
|
)
|
44
|
-
from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
|
45
46
|
from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
|
46
47
|
from nshtrainer.trainer._config import (
|
47
48
|
RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
|
@@ -50,7 +51,7 @@ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingCon
|
|
50
51
|
from nshtrainer.trainer._config import (
|
51
52
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
52
53
|
)
|
53
|
-
from nshtrainer.trainer._config import
|
54
|
+
from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
|
54
55
|
from nshtrainer.trainer._config import (
|
55
56
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
56
57
|
)
|
@@ -58,43 +59,142 @@ from nshtrainer.trainer._config import (
|
|
58
59
|
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
59
60
|
)
|
60
61
|
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
62
|
+
from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
|
63
|
+
from nshtrainer.trainer.accelerator import (
|
64
|
+
CUDAAcceleratorConfig as CUDAAcceleratorConfig,
|
65
|
+
)
|
66
|
+
from nshtrainer.trainer.accelerator import MPSAcceleratorConfig as MPSAcceleratorConfig
|
67
|
+
from nshtrainer.trainer.accelerator import XLAAcceleratorConfig as XLAAcceleratorConfig
|
68
|
+
from nshtrainer.trainer.plugin import PluginConfig as PluginConfig
|
69
|
+
from nshtrainer.trainer.plugin import PluginConfigBase as PluginConfigBase
|
70
|
+
from nshtrainer.trainer.plugin.environment import (
|
71
|
+
KubeflowEnvironmentPlugin as KubeflowEnvironmentPlugin,
|
72
|
+
)
|
73
|
+
from nshtrainer.trainer.plugin.environment import (
|
74
|
+
LightningEnvironmentPlugin as LightningEnvironmentPlugin,
|
75
|
+
)
|
76
|
+
from nshtrainer.trainer.plugin.environment import (
|
77
|
+
LSFEnvironmentPlugin as LSFEnvironmentPlugin,
|
78
|
+
)
|
79
|
+
from nshtrainer.trainer.plugin.environment import (
|
80
|
+
MPIEnvironmentPlugin as MPIEnvironmentPlugin,
|
81
|
+
)
|
82
|
+
from nshtrainer.trainer.plugin.environment import (
|
83
|
+
SLURMEnvironmentPlugin as SLURMEnvironmentPlugin,
|
84
|
+
)
|
85
|
+
from nshtrainer.trainer.plugin.environment import (
|
86
|
+
TorchElasticEnvironmentPlugin as TorchElasticEnvironmentPlugin,
|
87
|
+
)
|
88
|
+
from nshtrainer.trainer.plugin.environment import (
|
89
|
+
XLAEnvironmentPlugin as XLAEnvironmentPlugin,
|
90
|
+
)
|
91
|
+
from nshtrainer.trainer.plugin.io import (
|
92
|
+
AsyncCheckpointIOPlugin as AsyncCheckpointIOPlugin,
|
93
|
+
)
|
94
|
+
from nshtrainer.trainer.plugin.io import (
|
95
|
+
TorchCheckpointIOPlugin as TorchCheckpointIOPlugin,
|
96
|
+
)
|
97
|
+
from nshtrainer.trainer.plugin.io import XLACheckpointIOPlugin as XLACheckpointIOPlugin
|
98
|
+
from nshtrainer.trainer.plugin.layer_sync import (
|
99
|
+
TorchSyncBatchNormPlugin as TorchSyncBatchNormPlugin,
|
100
|
+
)
|
101
|
+
from nshtrainer.trainer.plugin.precision import (
|
102
|
+
BitsandbytesPluginConfig as BitsandbytesPluginConfig,
|
103
|
+
)
|
104
|
+
from nshtrainer.trainer.plugin.precision import (
|
105
|
+
DeepSpeedPluginConfig as DeepSpeedPluginConfig,
|
106
|
+
)
|
107
|
+
from nshtrainer.trainer.plugin.precision import (
|
108
|
+
DoublePrecisionPluginConfig as DoublePrecisionPluginConfig,
|
109
|
+
)
|
110
|
+
from nshtrainer.trainer.plugin.precision import DTypeConfig as DTypeConfig
|
111
|
+
from nshtrainer.trainer.plugin.precision import (
|
112
|
+
FSDPPrecisionPluginConfig as FSDPPrecisionPluginConfig,
|
113
|
+
)
|
114
|
+
from nshtrainer.trainer.plugin.precision import (
|
115
|
+
HalfPrecisionPluginConfig as HalfPrecisionPluginConfig,
|
116
|
+
)
|
117
|
+
from nshtrainer.trainer.plugin.precision import (
|
118
|
+
MixedPrecisionPluginConfig as MixedPrecisionPluginConfig,
|
119
|
+
)
|
120
|
+
from nshtrainer.trainer.plugin.precision import (
|
121
|
+
TransformerEnginePluginConfig as TransformerEnginePluginConfig,
|
122
|
+
)
|
123
|
+
from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
|
124
|
+
from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
|
125
|
+
from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
|
61
126
|
|
62
127
|
from . import _config as _config
|
128
|
+
from . import accelerator as accelerator
|
129
|
+
from . import plugin as plugin
|
130
|
+
from . import strategy as strategy
|
63
131
|
from . import trainer as trainer
|
64
132
|
|
65
133
|
__all__ = [
|
134
|
+
"AcceleratorConfig",
|
66
135
|
"AcceleratorConfigBase",
|
67
136
|
"ActSaveLoggerConfig",
|
137
|
+
"AsyncCheckpointIOPlugin",
|
68
138
|
"BaseLoggerConfig",
|
69
139
|
"BestCheckpointCallbackConfig",
|
140
|
+
"BitsandbytesPluginConfig",
|
141
|
+
"CPUAcceleratorConfig",
|
70
142
|
"CSVLoggerConfig",
|
143
|
+
"CUDAAcceleratorConfig",
|
71
144
|
"CallbackConfig",
|
72
145
|
"CallbackConfigBase",
|
73
146
|
"CheckpointCallbackConfig",
|
74
147
|
"CheckpointSavingConfig",
|
148
|
+
"DTypeConfig",
|
75
149
|
"DebugFlagCallbackConfig",
|
150
|
+
"DeepSpeedPluginConfig",
|
76
151
|
"DirectoryConfig",
|
152
|
+
"DoublePrecisionPluginConfig",
|
77
153
|
"EarlyStoppingCallbackConfig",
|
78
154
|
"EnvironmentConfig",
|
155
|
+
"FSDPPrecisionPluginConfig",
|
79
156
|
"GradientClippingConfig",
|
157
|
+
"HalfPrecisionPluginConfig",
|
80
158
|
"HuggingFaceHubConfig",
|
159
|
+
"KubeflowEnvironmentPlugin",
|
160
|
+
"LSFEnvironmentPlugin",
|
81
161
|
"LastCheckpointCallbackConfig",
|
82
162
|
"LearningRateMonitorConfig",
|
163
|
+
"LightningEnvironmentPlugin",
|
83
164
|
"LogEpochCallbackConfig",
|
84
165
|
"LoggerConfig",
|
166
|
+
"MPIEnvironmentPlugin",
|
167
|
+
"MPSAcceleratorConfig",
|
85
168
|
"MetricConfig",
|
169
|
+
"MixedPrecisionPluginConfig",
|
86
170
|
"NormLoggingCallbackConfig",
|
87
171
|
"OnExceptionCheckpointCallbackConfig",
|
172
|
+
"PluginConfig",
|
88
173
|
"PluginConfigBase",
|
89
174
|
"ProfilerConfig",
|
90
175
|
"RLPSanityChecksCallbackConfig",
|
176
|
+
"SLURMEnvironmentPlugin",
|
91
177
|
"SanityCheckingConfig",
|
92
178
|
"SharedParametersCallbackConfig",
|
179
|
+
"StrategyConfig",
|
93
180
|
"StrategyConfigBase",
|
94
181
|
"TensorboardLoggerConfig",
|
95
182
|
"TimeCheckpointCallbackConfig",
|
183
|
+
"TorchCheckpointIOPlugin",
|
184
|
+
"TorchElasticEnvironmentPlugin",
|
185
|
+
"TorchSyncBatchNormPlugin",
|
96
186
|
"TrainerConfig",
|
187
|
+
"TransformerEnginePluginConfig",
|
97
188
|
"WandbLoggerConfig",
|
189
|
+
"XLAAcceleratorConfig",
|
190
|
+
"XLACheckpointIOPlugin",
|
191
|
+
"XLAEnvironmentPlugin",
|
192
|
+
"XLAPluginConfig",
|
98
193
|
"_config",
|
194
|
+
"accelerator",
|
195
|
+
"accelerator_registry",
|
196
|
+
"plugin",
|
197
|
+
"plugin_registry",
|
198
|
+
"strategy",
|
99
199
|
"trainer",
|
100
200
|
]
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
|
-
from nshtrainer.trainer._config import
|
5
|
+
from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
|
6
6
|
from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
|
7
7
|
from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
|
8
8
|
from nshtrainer.trainer._config import (
|
@@ -40,7 +40,7 @@ from nshtrainer.trainer._config import (
|
|
40
40
|
from nshtrainer.trainer._config import (
|
41
41
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
42
42
|
)
|
43
|
-
from nshtrainer.trainer._config import
|
43
|
+
from nshtrainer.trainer._config import PluginConfig as PluginConfig
|
44
44
|
from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
|
45
45
|
from nshtrainer.trainer._config import (
|
46
46
|
RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
|
@@ -49,7 +49,7 @@ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingCon
|
|
49
49
|
from nshtrainer.trainer._config import (
|
50
50
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
51
51
|
)
|
52
|
-
from nshtrainer.trainer._config import
|
52
|
+
from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
|
53
53
|
from nshtrainer.trainer._config import (
|
54
54
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
55
55
|
)
|
@@ -58,9 +58,11 @@ from nshtrainer.trainer._config import (
|
|
58
58
|
)
|
59
59
|
from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
|
60
60
|
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
61
|
+
from nshtrainer.trainer._config import accelerator_registry as accelerator_registry
|
62
|
+
from nshtrainer.trainer._config import plugin_registry as plugin_registry
|
61
63
|
|
62
64
|
__all__ = [
|
63
|
-
"
|
65
|
+
"AcceleratorConfig",
|
64
66
|
"ActSaveLoggerConfig",
|
65
67
|
"BaseLoggerConfig",
|
66
68
|
"BestCheckpointCallbackConfig",
|
@@ -82,14 +84,16 @@ __all__ = [
|
|
82
84
|
"MetricConfig",
|
83
85
|
"NormLoggingCallbackConfig",
|
84
86
|
"OnExceptionCheckpointCallbackConfig",
|
85
|
-
"
|
87
|
+
"PluginConfig",
|
86
88
|
"ProfilerConfig",
|
87
89
|
"RLPSanityChecksCallbackConfig",
|
88
90
|
"SanityCheckingConfig",
|
89
91
|
"SharedParametersCallbackConfig",
|
90
|
-
"
|
92
|
+
"StrategyConfig",
|
91
93
|
"TensorboardLoggerConfig",
|
92
94
|
"TimeCheckpointCallbackConfig",
|
93
95
|
"TrainerConfig",
|
94
96
|
"WandbLoggerConfig",
|
97
|
+
"accelerator_registry",
|
98
|
+
"plugin_registry",
|
95
99
|
]
|
@@ -0,0 +1,25 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from nshtrainer.trainer.accelerator import AcceleratorConfig as AcceleratorConfig
|
6
|
+
from nshtrainer.trainer.accelerator import (
|
7
|
+
AcceleratorConfigBase as AcceleratorConfigBase,
|
8
|
+
)
|
9
|
+
from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
|
10
|
+
from nshtrainer.trainer.accelerator import (
|
11
|
+
CUDAAcceleratorConfig as CUDAAcceleratorConfig,
|
12
|
+
)
|
13
|
+
from nshtrainer.trainer.accelerator import MPSAcceleratorConfig as MPSAcceleratorConfig
|
14
|
+
from nshtrainer.trainer.accelerator import XLAAcceleratorConfig as XLAAcceleratorConfig
|
15
|
+
from nshtrainer.trainer.accelerator import accelerator_registry as accelerator_registry
|
16
|
+
|
17
|
+
__all__ = [
|
18
|
+
"AcceleratorConfig",
|
19
|
+
"AcceleratorConfigBase",
|
20
|
+
"CPUAcceleratorConfig",
|
21
|
+
"CUDAAcceleratorConfig",
|
22
|
+
"MPSAcceleratorConfig",
|
23
|
+
"XLAAcceleratorConfig",
|
24
|
+
"accelerator_registry",
|
25
|
+
]
|
@@ -0,0 +1,98 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from nshtrainer.trainer.plugin import PluginConfig as PluginConfig
|
6
|
+
from nshtrainer.trainer.plugin import PluginConfigBase as PluginConfigBase
|
7
|
+
from nshtrainer.trainer.plugin import plugin_registry as plugin_registry
|
8
|
+
from nshtrainer.trainer.plugin.environment import (
|
9
|
+
KubeflowEnvironmentPlugin as KubeflowEnvironmentPlugin,
|
10
|
+
)
|
11
|
+
from nshtrainer.trainer.plugin.environment import (
|
12
|
+
LightningEnvironmentPlugin as LightningEnvironmentPlugin,
|
13
|
+
)
|
14
|
+
from nshtrainer.trainer.plugin.environment import (
|
15
|
+
LSFEnvironmentPlugin as LSFEnvironmentPlugin,
|
16
|
+
)
|
17
|
+
from nshtrainer.trainer.plugin.environment import (
|
18
|
+
MPIEnvironmentPlugin as MPIEnvironmentPlugin,
|
19
|
+
)
|
20
|
+
from nshtrainer.trainer.plugin.environment import (
|
21
|
+
SLURMEnvironmentPlugin as SLURMEnvironmentPlugin,
|
22
|
+
)
|
23
|
+
from nshtrainer.trainer.plugin.environment import (
|
24
|
+
TorchElasticEnvironmentPlugin as TorchElasticEnvironmentPlugin,
|
25
|
+
)
|
26
|
+
from nshtrainer.trainer.plugin.environment import (
|
27
|
+
XLAEnvironmentPlugin as XLAEnvironmentPlugin,
|
28
|
+
)
|
29
|
+
from nshtrainer.trainer.plugin.io import (
|
30
|
+
AsyncCheckpointIOPlugin as AsyncCheckpointIOPlugin,
|
31
|
+
)
|
32
|
+
from nshtrainer.trainer.plugin.io import (
|
33
|
+
TorchCheckpointIOPlugin as TorchCheckpointIOPlugin,
|
34
|
+
)
|
35
|
+
from nshtrainer.trainer.plugin.io import XLACheckpointIOPlugin as XLACheckpointIOPlugin
|
36
|
+
from nshtrainer.trainer.plugin.layer_sync import (
|
37
|
+
TorchSyncBatchNormPlugin as TorchSyncBatchNormPlugin,
|
38
|
+
)
|
39
|
+
from nshtrainer.trainer.plugin.precision import (
|
40
|
+
BitsandbytesPluginConfig as BitsandbytesPluginConfig,
|
41
|
+
)
|
42
|
+
from nshtrainer.trainer.plugin.precision import (
|
43
|
+
DeepSpeedPluginConfig as DeepSpeedPluginConfig,
|
44
|
+
)
|
45
|
+
from nshtrainer.trainer.plugin.precision import (
|
46
|
+
DoublePrecisionPluginConfig as DoublePrecisionPluginConfig,
|
47
|
+
)
|
48
|
+
from nshtrainer.trainer.plugin.precision import DTypeConfig as DTypeConfig
|
49
|
+
from nshtrainer.trainer.plugin.precision import (
|
50
|
+
FSDPPrecisionPluginConfig as FSDPPrecisionPluginConfig,
|
51
|
+
)
|
52
|
+
from nshtrainer.trainer.plugin.precision import (
|
53
|
+
HalfPrecisionPluginConfig as HalfPrecisionPluginConfig,
|
54
|
+
)
|
55
|
+
from nshtrainer.trainer.plugin.precision import (
|
56
|
+
MixedPrecisionPluginConfig as MixedPrecisionPluginConfig,
|
57
|
+
)
|
58
|
+
from nshtrainer.trainer.plugin.precision import (
|
59
|
+
TransformerEnginePluginConfig as TransformerEnginePluginConfig,
|
60
|
+
)
|
61
|
+
from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
|
62
|
+
|
63
|
+
from . import base as base
|
64
|
+
from . import environment as environment
|
65
|
+
from . import io as io
|
66
|
+
from . import layer_sync as layer_sync
|
67
|
+
from . import precision as precision
|
68
|
+
|
69
|
+
__all__ = [
|
70
|
+
"AsyncCheckpointIOPlugin",
|
71
|
+
"BitsandbytesPluginConfig",
|
72
|
+
"DTypeConfig",
|
73
|
+
"DeepSpeedPluginConfig",
|
74
|
+
"DoublePrecisionPluginConfig",
|
75
|
+
"FSDPPrecisionPluginConfig",
|
76
|
+
"HalfPrecisionPluginConfig",
|
77
|
+
"KubeflowEnvironmentPlugin",
|
78
|
+
"LSFEnvironmentPlugin",
|
79
|
+
"LightningEnvironmentPlugin",
|
80
|
+
"MPIEnvironmentPlugin",
|
81
|
+
"MixedPrecisionPluginConfig",
|
82
|
+
"PluginConfig",
|
83
|
+
"PluginConfigBase",
|
84
|
+
"SLURMEnvironmentPlugin",
|
85
|
+
"TorchCheckpointIOPlugin",
|
86
|
+
"TorchElasticEnvironmentPlugin",
|
87
|
+
"TorchSyncBatchNormPlugin",
|
88
|
+
"TransformerEnginePluginConfig",
|
89
|
+
"XLACheckpointIOPlugin",
|
90
|
+
"XLAEnvironmentPlugin",
|
91
|
+
"XLAPluginConfig",
|
92
|
+
"base",
|
93
|
+
"environment",
|
94
|
+
"io",
|
95
|
+
"layer_sync",
|
96
|
+
"plugin_registry",
|
97
|
+
"precision",
|
98
|
+
]
|
@@ -0,0 +1,13 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from nshtrainer.trainer.plugin.base import PluginConfig as PluginConfig
|
6
|
+
from nshtrainer.trainer.plugin.base import PluginConfigBase as PluginConfigBase
|
7
|
+
from nshtrainer.trainer.plugin.base import plugin_registry as plugin_registry
|
8
|
+
|
9
|
+
__all__ = [
|
10
|
+
"PluginConfig",
|
11
|
+
"PluginConfigBase",
|
12
|
+
"plugin_registry",
|
13
|
+
]
|
@@ -0,0 +1,41 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from nshtrainer.trainer.plugin.environment import DTypeConfig as DTypeConfig
|
6
|
+
from nshtrainer.trainer.plugin.environment import (
|
7
|
+
KubeflowEnvironmentPlugin as KubeflowEnvironmentPlugin,
|
8
|
+
)
|
9
|
+
from nshtrainer.trainer.plugin.environment import (
|
10
|
+
LightningEnvironmentPlugin as LightningEnvironmentPlugin,
|
11
|
+
)
|
12
|
+
from nshtrainer.trainer.plugin.environment import (
|
13
|
+
LSFEnvironmentPlugin as LSFEnvironmentPlugin,
|
14
|
+
)
|
15
|
+
from nshtrainer.trainer.plugin.environment import (
|
16
|
+
MPIEnvironmentPlugin as MPIEnvironmentPlugin,
|
17
|
+
)
|
18
|
+
from nshtrainer.trainer.plugin.environment import PluginConfigBase as PluginConfigBase
|
19
|
+
from nshtrainer.trainer.plugin.environment import (
|
20
|
+
SLURMEnvironmentPlugin as SLURMEnvironmentPlugin,
|
21
|
+
)
|
22
|
+
from nshtrainer.trainer.plugin.environment import (
|
23
|
+
TorchElasticEnvironmentPlugin as TorchElasticEnvironmentPlugin,
|
24
|
+
)
|
25
|
+
from nshtrainer.trainer.plugin.environment import (
|
26
|
+
XLAEnvironmentPlugin as XLAEnvironmentPlugin,
|
27
|
+
)
|
28
|
+
from nshtrainer.trainer.plugin.environment import plugin_registry as plugin_registry
|
29
|
+
|
30
|
+
__all__ = [
|
31
|
+
"DTypeConfig",
|
32
|
+
"KubeflowEnvironmentPlugin",
|
33
|
+
"LSFEnvironmentPlugin",
|
34
|
+
"LightningEnvironmentPlugin",
|
35
|
+
"MPIEnvironmentPlugin",
|
36
|
+
"PluginConfigBase",
|
37
|
+
"SLURMEnvironmentPlugin",
|
38
|
+
"TorchElasticEnvironmentPlugin",
|
39
|
+
"XLAEnvironmentPlugin",
|
40
|
+
"plugin_registry",
|
41
|
+
]
|
@@ -0,0 +1,23 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from nshtrainer.trainer.plugin.io import (
|
6
|
+
AsyncCheckpointIOPlugin as AsyncCheckpointIOPlugin,
|
7
|
+
)
|
8
|
+
from nshtrainer.trainer.plugin.io import PluginConfig as PluginConfig
|
9
|
+
from nshtrainer.trainer.plugin.io import PluginConfigBase as PluginConfigBase
|
10
|
+
from nshtrainer.trainer.plugin.io import (
|
11
|
+
TorchCheckpointIOPlugin as TorchCheckpointIOPlugin,
|
12
|
+
)
|
13
|
+
from nshtrainer.trainer.plugin.io import XLACheckpointIOPlugin as XLACheckpointIOPlugin
|
14
|
+
from nshtrainer.trainer.plugin.io import plugin_registry as plugin_registry
|
15
|
+
|
16
|
+
__all__ = [
|
17
|
+
"AsyncCheckpointIOPlugin",
|
18
|
+
"PluginConfig",
|
19
|
+
"PluginConfigBase",
|
20
|
+
"TorchCheckpointIOPlugin",
|
21
|
+
"XLACheckpointIOPlugin",
|
22
|
+
"plugin_registry",
|
23
|
+
]
|
@@ -0,0 +1,15 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from nshtrainer.trainer.plugin.layer_sync import PluginConfigBase as PluginConfigBase
|
6
|
+
from nshtrainer.trainer.plugin.layer_sync import (
|
7
|
+
TorchSyncBatchNormPlugin as TorchSyncBatchNormPlugin,
|
8
|
+
)
|
9
|
+
from nshtrainer.trainer.plugin.layer_sync import plugin_registry as plugin_registry
|
10
|
+
|
11
|
+
__all__ = [
|
12
|
+
"PluginConfigBase",
|
13
|
+
"TorchSyncBatchNormPlugin",
|
14
|
+
"plugin_registry",
|
15
|
+
]
|