nshtrainer 1.0.0b29__py3-none-any.whl → 1.0.0b31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. nshtrainer/__init__.py +2 -0
  2. nshtrainer/configs/__init__.py +95 -3
  3. nshtrainer/configs/trainer/__init__.py +103 -3
  4. nshtrainer/configs/trainer/_config/__init__.py +10 -6
  5. nshtrainer/configs/trainer/accelerator/__init__.py +25 -0
  6. nshtrainer/configs/trainer/plugin/__init__.py +98 -0
  7. nshtrainer/configs/trainer/plugin/base/__init__.py +13 -0
  8. nshtrainer/configs/trainer/plugin/environment/__init__.py +41 -0
  9. nshtrainer/configs/trainer/plugin/io/__init__.py +23 -0
  10. nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +15 -0
  11. nshtrainer/configs/trainer/plugin/precision/__init__.py +43 -0
  12. nshtrainer/configs/trainer/strategy/__init__.py +11 -0
  13. nshtrainer/configs/trainer/trainer/__init__.py +2 -0
  14. nshtrainer/data/datamodule.py +2 -0
  15. nshtrainer/model/base.py +2 -0
  16. nshtrainer/trainer/__init__.py +2 -0
  17. nshtrainer/trainer/_config.py +3 -47
  18. nshtrainer/trainer/accelerator.py +86 -0
  19. nshtrainer/trainer/plugin/__init__.py +10 -0
  20. nshtrainer/trainer/plugin/base.py +33 -0
  21. nshtrainer/trainer/plugin/environment.py +128 -0
  22. nshtrainer/trainer/plugin/io.py +62 -0
  23. nshtrainer/trainer/plugin/layer_sync.py +25 -0
  24. nshtrainer/trainer/plugin/precision.py +163 -0
  25. nshtrainer/trainer/strategy.py +51 -0
  26. nshtrainer/trainer/trainer.py +8 -9
  27. nshtrainer/util/hparams.py +17 -0
  28. {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b31.dist-info}/METADATA +1 -1
  29. {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b31.dist-info}/RECORD +30 -13
  30. {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b31.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py CHANGED
@@ -14,6 +14,8 @@ from .metrics import MetricConfig as MetricConfig
14
14
  from .model import LightningModuleBase as LightningModuleBase
15
15
  from .trainer import Trainer as Trainer
16
16
  from .trainer import TrainerConfig as TrainerConfig
17
+ from .trainer import accelerator_registry as accelerator_registry
18
+ from .trainer import plugin_registry as plugin_registry
17
19
 
18
20
  try:
19
21
  from . import configs as configs
@@ -4,6 +4,8 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer import MetricConfig as MetricConfig
6
6
  from nshtrainer import TrainerConfig as TrainerConfig
7
+ from nshtrainer import accelerator_registry as accelerator_registry
8
+ from nshtrainer import plugin_registry as plugin_registry
7
9
  from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
8
10
  from nshtrainer._directory import DirectoryConfig as DirectoryConfig
9
11
  from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
@@ -97,7 +99,7 @@ from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
97
99
  from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
98
100
  from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
99
101
  from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
100
- from nshtrainer.trainer._config import AcceleratorConfigBase as AcceleratorConfigBase
102
+ from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
101
103
  from nshtrainer.trainer._config import (
102
104
  CheckpointCallbackConfig as CheckpointCallbackConfig,
103
105
  )
@@ -107,9 +109,71 @@ from nshtrainer.trainer._config import GradientClippingConfig as GradientClippin
107
109
  from nshtrainer.trainer._config import (
108
110
  LearningRateMonitorConfig as LearningRateMonitorConfig,
109
111
  )
110
- from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
111
112
  from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
112
- from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
113
+ from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
114
+ from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
115
+ from nshtrainer.trainer.accelerator import (
116
+ CUDAAcceleratorConfig as CUDAAcceleratorConfig,
117
+ )
118
+ from nshtrainer.trainer.accelerator import MPSAcceleratorConfig as MPSAcceleratorConfig
119
+ from nshtrainer.trainer.accelerator import XLAAcceleratorConfig as XLAAcceleratorConfig
120
+ from nshtrainer.trainer.plugin import PluginConfig as PluginConfig
121
+ from nshtrainer.trainer.plugin import PluginConfigBase as PluginConfigBase
122
+ from nshtrainer.trainer.plugin.environment import (
123
+ KubeflowEnvironmentPlugin as KubeflowEnvironmentPlugin,
124
+ )
125
+ from nshtrainer.trainer.plugin.environment import (
126
+ LightningEnvironmentPlugin as LightningEnvironmentPlugin,
127
+ )
128
+ from nshtrainer.trainer.plugin.environment import (
129
+ LSFEnvironmentPlugin as LSFEnvironmentPlugin,
130
+ )
131
+ from nshtrainer.trainer.plugin.environment import (
132
+ MPIEnvironmentPlugin as MPIEnvironmentPlugin,
133
+ )
134
+ from nshtrainer.trainer.plugin.environment import (
135
+ SLURMEnvironmentPlugin as SLURMEnvironmentPlugin,
136
+ )
137
+ from nshtrainer.trainer.plugin.environment import (
138
+ TorchElasticEnvironmentPlugin as TorchElasticEnvironmentPlugin,
139
+ )
140
+ from nshtrainer.trainer.plugin.environment import (
141
+ XLAEnvironmentPlugin as XLAEnvironmentPlugin,
142
+ )
143
+ from nshtrainer.trainer.plugin.io import (
144
+ AsyncCheckpointIOPlugin as AsyncCheckpointIOPlugin,
145
+ )
146
+ from nshtrainer.trainer.plugin.io import (
147
+ TorchCheckpointIOPlugin as TorchCheckpointIOPlugin,
148
+ )
149
+ from nshtrainer.trainer.plugin.io import XLACheckpointIOPlugin as XLACheckpointIOPlugin
150
+ from nshtrainer.trainer.plugin.layer_sync import (
151
+ TorchSyncBatchNormPlugin as TorchSyncBatchNormPlugin,
152
+ )
153
+ from nshtrainer.trainer.plugin.precision import (
154
+ BitsandbytesPluginConfig as BitsandbytesPluginConfig,
155
+ )
156
+ from nshtrainer.trainer.plugin.precision import (
157
+ DeepSpeedPluginConfig as DeepSpeedPluginConfig,
158
+ )
159
+ from nshtrainer.trainer.plugin.precision import (
160
+ DoublePrecisionPluginConfig as DoublePrecisionPluginConfig,
161
+ )
162
+ from nshtrainer.trainer.plugin.precision import (
163
+ FSDPPrecisionPluginConfig as FSDPPrecisionPluginConfig,
164
+ )
165
+ from nshtrainer.trainer.plugin.precision import (
166
+ HalfPrecisionPluginConfig as HalfPrecisionPluginConfig,
167
+ )
168
+ from nshtrainer.trainer.plugin.precision import (
169
+ MixedPrecisionPluginConfig as MixedPrecisionPluginConfig,
170
+ )
171
+ from nshtrainer.trainer.plugin.precision import (
172
+ TransformerEnginePluginConfig as TransformerEnginePluginConfig,
173
+ )
174
+ from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
175
+ from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
176
+ from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
113
177
  from nshtrainer.util._environment_info import (
114
178
  EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
115
179
  )
@@ -157,17 +221,22 @@ from . import trainer as trainer
157
221
  from . import util as util
158
222
 
159
223
  __all__ = [
224
+ "AcceleratorConfig",
160
225
  "AcceleratorConfigBase",
161
226
  "ActSaveConfig",
162
227
  "ActSaveLoggerConfig",
163
228
  "AdamWConfig",
164
229
  "AdvancedProfilerConfig",
230
+ "AsyncCheckpointIOPlugin",
165
231
  "BaseCheckpointCallbackConfig",
166
232
  "BaseLoggerConfig",
167
233
  "BaseNonlinearityConfig",
168
234
  "BaseProfilerConfig",
169
235
  "BestCheckpointCallbackConfig",
236
+ "BitsandbytesPluginConfig",
237
+ "CPUAcceleratorConfig",
170
238
  "CSVLoggerConfig",
239
+ "CUDAAcceleratorConfig",
171
240
  "CallbackConfig",
172
241
  "CallbackConfigBase",
173
242
  "CheckpointCallbackConfig",
@@ -175,8 +244,10 @@ __all__ = [
175
244
  "CheckpointSavingConfig",
176
245
  "DTypeConfig",
177
246
  "DebugFlagCallbackConfig",
247
+ "DeepSpeedPluginConfig",
178
248
  "DirectoryConfig",
179
249
  "DirectorySetupCallbackConfig",
250
+ "DoublePrecisionPluginConfig",
180
251
  "DurationConfig",
181
252
  "ELUNonlinearityConfig",
182
253
  "EMACallbackConfig",
@@ -193,30 +264,39 @@ __all__ = [
193
264
  "EnvironmentSnapshotConfig",
194
265
  "EpochTimerCallbackConfig",
195
266
  "EpochsConfig",
267
+ "FSDPPrecisionPluginConfig",
196
268
  "FiniteChecksCallbackConfig",
197
269
  "GELUNonlinearityConfig",
198
270
  "GitRepositoryConfig",
199
271
  "GradientClippingConfig",
200
272
  "GradientSkippingCallbackConfig",
273
+ "HalfPrecisionPluginConfig",
201
274
  "HuggingFaceHubAutoCreateConfig",
202
275
  "HuggingFaceHubConfig",
276
+ "KubeflowEnvironmentPlugin",
203
277
  "LRSchedulerConfig",
204
278
  "LRSchedulerConfigBase",
279
+ "LSFEnvironmentPlugin",
205
280
  "LastCheckpointCallbackConfig",
206
281
  "LeakyReLUNonlinearityConfig",
207
282
  "LearningRateMonitorConfig",
283
+ "LightningEnvironmentPlugin",
208
284
  "LinearWarmupCosineDecayLRSchedulerConfig",
209
285
  "LogEpochCallbackConfig",
210
286
  "LoggerConfig",
211
287
  "MLPConfig",
288
+ "MPIEnvironmentPlugin",
289
+ "MPSAcceleratorConfig",
212
290
  "MetricConfig",
213
291
  "MishNonlinearityConfig",
292
+ "MixedPrecisionPluginConfig",
214
293
  "NonlinearityConfig",
215
294
  "NormLoggingCallbackConfig",
216
295
  "OnExceptionCheckpointCallbackConfig",
217
296
  "OptimizerConfig",
218
297
  "OptimizerConfigBase",
219
298
  "PReLUConfig",
299
+ "PluginConfig",
220
300
  "PluginConfigBase",
221
301
  "PrintTableMetricsCallbackConfig",
222
302
  "ProfilerConfig",
@@ -224,6 +304,7 @@ __all__ = [
224
304
  "RLPSanityChecksCallbackConfig",
225
305
  "ReLUNonlinearityConfig",
226
306
  "ReduceLROnPlateauConfig",
307
+ "SLURMEnvironmentPlugin",
227
308
  "SanityCheckingConfig",
228
309
  "SharedParametersCallbackConfig",
229
310
  "SiLUNonlinearityConfig",
@@ -233,25 +314,36 @@ __all__ = [
233
314
  "SoftplusNonlinearityConfig",
234
315
  "SoftsignNonlinearityConfig",
235
316
  "StepsConfig",
317
+ "StrategyConfig",
236
318
  "StrategyConfigBase",
237
319
  "SwiGLUNonlinearityConfig",
238
320
  "SwishNonlinearityConfig",
239
321
  "TanhNonlinearityConfig",
240
322
  "TensorboardLoggerConfig",
241
323
  "TimeCheckpointCallbackConfig",
324
+ "TorchCheckpointIOPlugin",
325
+ "TorchElasticEnvironmentPlugin",
326
+ "TorchSyncBatchNormPlugin",
242
327
  "TrainerConfig",
328
+ "TransformerEnginePluginConfig",
243
329
  "WandbLoggerConfig",
244
330
  "WandbUploadCodeCallbackConfig",
245
331
  "WandbWatchCallbackConfig",
332
+ "XLAAcceleratorConfig",
333
+ "XLACheckpointIOPlugin",
334
+ "XLAEnvironmentPlugin",
335
+ "XLAPluginConfig",
246
336
  "_checkpoint",
247
337
  "_directory",
248
338
  "_hf_hub",
339
+ "accelerator_registry",
249
340
  "callbacks",
250
341
  "loggers",
251
342
  "lr_scheduler",
252
343
  "metrics",
253
344
  "nn",
254
345
  "optimizer",
346
+ "plugin_registry",
255
347
  "profiler",
256
348
  "trainer",
257
349
  "util",
@@ -3,7 +3,9 @@ from __future__ import annotations
3
3
  __codegen__ = True
4
4
 
5
5
  from nshtrainer.trainer import TrainerConfig as TrainerConfig
6
- from nshtrainer.trainer._config import AcceleratorConfigBase as AcceleratorConfigBase
6
+ from nshtrainer.trainer import accelerator_registry as accelerator_registry
7
+ from nshtrainer.trainer import plugin_registry as plugin_registry
8
+ from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
7
9
  from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
8
10
  from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
9
11
  from nshtrainer.trainer._config import (
@@ -41,7 +43,6 @@ from nshtrainer.trainer._config import (
41
43
  from nshtrainer.trainer._config import (
42
44
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
43
45
  )
44
- from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
45
46
  from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
46
47
  from nshtrainer.trainer._config import (
47
48
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
@@ -50,7 +51,7 @@ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingCon
50
51
  from nshtrainer.trainer._config import (
51
52
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
52
53
  )
53
- from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
54
+ from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
54
55
  from nshtrainer.trainer._config import (
55
56
  TensorboardLoggerConfig as TensorboardLoggerConfig,
56
57
  )
@@ -58,43 +59,142 @@ from nshtrainer.trainer._config import (
58
59
  TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
59
60
  )
60
61
  from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
62
+ from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
63
+ from nshtrainer.trainer.accelerator import (
64
+ CUDAAcceleratorConfig as CUDAAcceleratorConfig,
65
+ )
66
+ from nshtrainer.trainer.accelerator import MPSAcceleratorConfig as MPSAcceleratorConfig
67
+ from nshtrainer.trainer.accelerator import XLAAcceleratorConfig as XLAAcceleratorConfig
68
+ from nshtrainer.trainer.plugin import PluginConfig as PluginConfig
69
+ from nshtrainer.trainer.plugin import PluginConfigBase as PluginConfigBase
70
+ from nshtrainer.trainer.plugin.environment import (
71
+ KubeflowEnvironmentPlugin as KubeflowEnvironmentPlugin,
72
+ )
73
+ from nshtrainer.trainer.plugin.environment import (
74
+ LightningEnvironmentPlugin as LightningEnvironmentPlugin,
75
+ )
76
+ from nshtrainer.trainer.plugin.environment import (
77
+ LSFEnvironmentPlugin as LSFEnvironmentPlugin,
78
+ )
79
+ from nshtrainer.trainer.plugin.environment import (
80
+ MPIEnvironmentPlugin as MPIEnvironmentPlugin,
81
+ )
82
+ from nshtrainer.trainer.plugin.environment import (
83
+ SLURMEnvironmentPlugin as SLURMEnvironmentPlugin,
84
+ )
85
+ from nshtrainer.trainer.plugin.environment import (
86
+ TorchElasticEnvironmentPlugin as TorchElasticEnvironmentPlugin,
87
+ )
88
+ from nshtrainer.trainer.plugin.environment import (
89
+ XLAEnvironmentPlugin as XLAEnvironmentPlugin,
90
+ )
91
+ from nshtrainer.trainer.plugin.io import (
92
+ AsyncCheckpointIOPlugin as AsyncCheckpointIOPlugin,
93
+ )
94
+ from nshtrainer.trainer.plugin.io import (
95
+ TorchCheckpointIOPlugin as TorchCheckpointIOPlugin,
96
+ )
97
+ from nshtrainer.trainer.plugin.io import XLACheckpointIOPlugin as XLACheckpointIOPlugin
98
+ from nshtrainer.trainer.plugin.layer_sync import (
99
+ TorchSyncBatchNormPlugin as TorchSyncBatchNormPlugin,
100
+ )
101
+ from nshtrainer.trainer.plugin.precision import (
102
+ BitsandbytesPluginConfig as BitsandbytesPluginConfig,
103
+ )
104
+ from nshtrainer.trainer.plugin.precision import (
105
+ DeepSpeedPluginConfig as DeepSpeedPluginConfig,
106
+ )
107
+ from nshtrainer.trainer.plugin.precision import (
108
+ DoublePrecisionPluginConfig as DoublePrecisionPluginConfig,
109
+ )
110
+ from nshtrainer.trainer.plugin.precision import DTypeConfig as DTypeConfig
111
+ from nshtrainer.trainer.plugin.precision import (
112
+ FSDPPrecisionPluginConfig as FSDPPrecisionPluginConfig,
113
+ )
114
+ from nshtrainer.trainer.plugin.precision import (
115
+ HalfPrecisionPluginConfig as HalfPrecisionPluginConfig,
116
+ )
117
+ from nshtrainer.trainer.plugin.precision import (
118
+ MixedPrecisionPluginConfig as MixedPrecisionPluginConfig,
119
+ )
120
+ from nshtrainer.trainer.plugin.precision import (
121
+ TransformerEnginePluginConfig as TransformerEnginePluginConfig,
122
+ )
123
+ from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
124
+ from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
125
+ from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
61
126
 
62
127
  from . import _config as _config
128
+ from . import accelerator as accelerator
129
+ from . import plugin as plugin
130
+ from . import strategy as strategy
63
131
  from . import trainer as trainer
64
132
 
65
133
  __all__ = [
134
+ "AcceleratorConfig",
66
135
  "AcceleratorConfigBase",
67
136
  "ActSaveLoggerConfig",
137
+ "AsyncCheckpointIOPlugin",
68
138
  "BaseLoggerConfig",
69
139
  "BestCheckpointCallbackConfig",
140
+ "BitsandbytesPluginConfig",
141
+ "CPUAcceleratorConfig",
70
142
  "CSVLoggerConfig",
143
+ "CUDAAcceleratorConfig",
71
144
  "CallbackConfig",
72
145
  "CallbackConfigBase",
73
146
  "CheckpointCallbackConfig",
74
147
  "CheckpointSavingConfig",
148
+ "DTypeConfig",
75
149
  "DebugFlagCallbackConfig",
150
+ "DeepSpeedPluginConfig",
76
151
  "DirectoryConfig",
152
+ "DoublePrecisionPluginConfig",
77
153
  "EarlyStoppingCallbackConfig",
78
154
  "EnvironmentConfig",
155
+ "FSDPPrecisionPluginConfig",
79
156
  "GradientClippingConfig",
157
+ "HalfPrecisionPluginConfig",
80
158
  "HuggingFaceHubConfig",
159
+ "KubeflowEnvironmentPlugin",
160
+ "LSFEnvironmentPlugin",
81
161
  "LastCheckpointCallbackConfig",
82
162
  "LearningRateMonitorConfig",
163
+ "LightningEnvironmentPlugin",
83
164
  "LogEpochCallbackConfig",
84
165
  "LoggerConfig",
166
+ "MPIEnvironmentPlugin",
167
+ "MPSAcceleratorConfig",
85
168
  "MetricConfig",
169
+ "MixedPrecisionPluginConfig",
86
170
  "NormLoggingCallbackConfig",
87
171
  "OnExceptionCheckpointCallbackConfig",
172
+ "PluginConfig",
88
173
  "PluginConfigBase",
89
174
  "ProfilerConfig",
90
175
  "RLPSanityChecksCallbackConfig",
176
+ "SLURMEnvironmentPlugin",
91
177
  "SanityCheckingConfig",
92
178
  "SharedParametersCallbackConfig",
179
+ "StrategyConfig",
93
180
  "StrategyConfigBase",
94
181
  "TensorboardLoggerConfig",
95
182
  "TimeCheckpointCallbackConfig",
183
+ "TorchCheckpointIOPlugin",
184
+ "TorchElasticEnvironmentPlugin",
185
+ "TorchSyncBatchNormPlugin",
96
186
  "TrainerConfig",
187
+ "TransformerEnginePluginConfig",
97
188
  "WandbLoggerConfig",
189
+ "XLAAcceleratorConfig",
190
+ "XLACheckpointIOPlugin",
191
+ "XLAEnvironmentPlugin",
192
+ "XLAPluginConfig",
98
193
  "_config",
194
+ "accelerator",
195
+ "accelerator_registry",
196
+ "plugin",
197
+ "plugin_registry",
198
+ "strategy",
99
199
  "trainer",
100
200
  ]
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.trainer._config import AcceleratorConfigBase as AcceleratorConfigBase
5
+ from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
6
6
  from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
7
7
  from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
8
8
  from nshtrainer.trainer._config import (
@@ -40,7 +40,7 @@ from nshtrainer.trainer._config import (
40
40
  from nshtrainer.trainer._config import (
41
41
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
42
42
  )
43
- from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
43
+ from nshtrainer.trainer._config import PluginConfig as PluginConfig
44
44
  from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
45
45
  from nshtrainer.trainer._config import (
46
46
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
@@ -49,7 +49,7 @@ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingCon
49
49
  from nshtrainer.trainer._config import (
50
50
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
51
51
  )
52
- from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
52
+ from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
53
53
  from nshtrainer.trainer._config import (
54
54
  TensorboardLoggerConfig as TensorboardLoggerConfig,
55
55
  )
@@ -58,9 +58,11 @@ from nshtrainer.trainer._config import (
58
58
  )
59
59
  from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
60
60
  from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
61
+ from nshtrainer.trainer._config import accelerator_registry as accelerator_registry
62
+ from nshtrainer.trainer._config import plugin_registry as plugin_registry
61
63
 
62
64
  __all__ = [
63
- "AcceleratorConfigBase",
65
+ "AcceleratorConfig",
64
66
  "ActSaveLoggerConfig",
65
67
  "BaseLoggerConfig",
66
68
  "BestCheckpointCallbackConfig",
@@ -82,14 +84,16 @@ __all__ = [
82
84
  "MetricConfig",
83
85
  "NormLoggingCallbackConfig",
84
86
  "OnExceptionCheckpointCallbackConfig",
85
- "PluginConfigBase",
87
+ "PluginConfig",
86
88
  "ProfilerConfig",
87
89
  "RLPSanityChecksCallbackConfig",
88
90
  "SanityCheckingConfig",
89
91
  "SharedParametersCallbackConfig",
90
- "StrategyConfigBase",
92
+ "StrategyConfig",
91
93
  "TensorboardLoggerConfig",
92
94
  "TimeCheckpointCallbackConfig",
93
95
  "TrainerConfig",
94
96
  "WandbLoggerConfig",
97
+ "accelerator_registry",
98
+ "plugin_registry",
95
99
  ]
@@ -0,0 +1,25 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.trainer.accelerator import AcceleratorConfig as AcceleratorConfig
6
+ from nshtrainer.trainer.accelerator import (
7
+ AcceleratorConfigBase as AcceleratorConfigBase,
8
+ )
9
+ from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
10
+ from nshtrainer.trainer.accelerator import (
11
+ CUDAAcceleratorConfig as CUDAAcceleratorConfig,
12
+ )
13
+ from nshtrainer.trainer.accelerator import MPSAcceleratorConfig as MPSAcceleratorConfig
14
+ from nshtrainer.trainer.accelerator import XLAAcceleratorConfig as XLAAcceleratorConfig
15
+ from nshtrainer.trainer.accelerator import accelerator_registry as accelerator_registry
16
+
17
+ __all__ = [
18
+ "AcceleratorConfig",
19
+ "AcceleratorConfigBase",
20
+ "CPUAcceleratorConfig",
21
+ "CUDAAcceleratorConfig",
22
+ "MPSAcceleratorConfig",
23
+ "XLAAcceleratorConfig",
24
+ "accelerator_registry",
25
+ ]
@@ -0,0 +1,98 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.trainer.plugin import PluginConfig as PluginConfig
6
+ from nshtrainer.trainer.plugin import PluginConfigBase as PluginConfigBase
7
+ from nshtrainer.trainer.plugin import plugin_registry as plugin_registry
8
+ from nshtrainer.trainer.plugin.environment import (
9
+ KubeflowEnvironmentPlugin as KubeflowEnvironmentPlugin,
10
+ )
11
+ from nshtrainer.trainer.plugin.environment import (
12
+ LightningEnvironmentPlugin as LightningEnvironmentPlugin,
13
+ )
14
+ from nshtrainer.trainer.plugin.environment import (
15
+ LSFEnvironmentPlugin as LSFEnvironmentPlugin,
16
+ )
17
+ from nshtrainer.trainer.plugin.environment import (
18
+ MPIEnvironmentPlugin as MPIEnvironmentPlugin,
19
+ )
20
+ from nshtrainer.trainer.plugin.environment import (
21
+ SLURMEnvironmentPlugin as SLURMEnvironmentPlugin,
22
+ )
23
+ from nshtrainer.trainer.plugin.environment import (
24
+ TorchElasticEnvironmentPlugin as TorchElasticEnvironmentPlugin,
25
+ )
26
+ from nshtrainer.trainer.plugin.environment import (
27
+ XLAEnvironmentPlugin as XLAEnvironmentPlugin,
28
+ )
29
+ from nshtrainer.trainer.plugin.io import (
30
+ AsyncCheckpointIOPlugin as AsyncCheckpointIOPlugin,
31
+ )
32
+ from nshtrainer.trainer.plugin.io import (
33
+ TorchCheckpointIOPlugin as TorchCheckpointIOPlugin,
34
+ )
35
+ from nshtrainer.trainer.plugin.io import XLACheckpointIOPlugin as XLACheckpointIOPlugin
36
+ from nshtrainer.trainer.plugin.layer_sync import (
37
+ TorchSyncBatchNormPlugin as TorchSyncBatchNormPlugin,
38
+ )
39
+ from nshtrainer.trainer.plugin.precision import (
40
+ BitsandbytesPluginConfig as BitsandbytesPluginConfig,
41
+ )
42
+ from nshtrainer.trainer.plugin.precision import (
43
+ DeepSpeedPluginConfig as DeepSpeedPluginConfig,
44
+ )
45
+ from nshtrainer.trainer.plugin.precision import (
46
+ DoublePrecisionPluginConfig as DoublePrecisionPluginConfig,
47
+ )
48
+ from nshtrainer.trainer.plugin.precision import DTypeConfig as DTypeConfig
49
+ from nshtrainer.trainer.plugin.precision import (
50
+ FSDPPrecisionPluginConfig as FSDPPrecisionPluginConfig,
51
+ )
52
+ from nshtrainer.trainer.plugin.precision import (
53
+ HalfPrecisionPluginConfig as HalfPrecisionPluginConfig,
54
+ )
55
+ from nshtrainer.trainer.plugin.precision import (
56
+ MixedPrecisionPluginConfig as MixedPrecisionPluginConfig,
57
+ )
58
+ from nshtrainer.trainer.plugin.precision import (
59
+ TransformerEnginePluginConfig as TransformerEnginePluginConfig,
60
+ )
61
+ from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
62
+
63
+ from . import base as base
64
+ from . import environment as environment
65
+ from . import io as io
66
+ from . import layer_sync as layer_sync
67
+ from . import precision as precision
68
+
69
+ __all__ = [
70
+ "AsyncCheckpointIOPlugin",
71
+ "BitsandbytesPluginConfig",
72
+ "DTypeConfig",
73
+ "DeepSpeedPluginConfig",
74
+ "DoublePrecisionPluginConfig",
75
+ "FSDPPrecisionPluginConfig",
76
+ "HalfPrecisionPluginConfig",
77
+ "KubeflowEnvironmentPlugin",
78
+ "LSFEnvironmentPlugin",
79
+ "LightningEnvironmentPlugin",
80
+ "MPIEnvironmentPlugin",
81
+ "MixedPrecisionPluginConfig",
82
+ "PluginConfig",
83
+ "PluginConfigBase",
84
+ "SLURMEnvironmentPlugin",
85
+ "TorchCheckpointIOPlugin",
86
+ "TorchElasticEnvironmentPlugin",
87
+ "TorchSyncBatchNormPlugin",
88
+ "TransformerEnginePluginConfig",
89
+ "XLACheckpointIOPlugin",
90
+ "XLAEnvironmentPlugin",
91
+ "XLAPluginConfig",
92
+ "base",
93
+ "environment",
94
+ "io",
95
+ "layer_sync",
96
+ "plugin_registry",
97
+ "precision",
98
+ ]
@@ -0,0 +1,13 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.trainer.plugin.base import PluginConfig as PluginConfig
6
+ from nshtrainer.trainer.plugin.base import PluginConfigBase as PluginConfigBase
7
+ from nshtrainer.trainer.plugin.base import plugin_registry as plugin_registry
8
+
9
+ __all__ = [
10
+ "PluginConfig",
11
+ "PluginConfigBase",
12
+ "plugin_registry",
13
+ ]
@@ -0,0 +1,41 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.trainer.plugin.environment import DTypeConfig as DTypeConfig
6
+ from nshtrainer.trainer.plugin.environment import (
7
+ KubeflowEnvironmentPlugin as KubeflowEnvironmentPlugin,
8
+ )
9
+ from nshtrainer.trainer.plugin.environment import (
10
+ LightningEnvironmentPlugin as LightningEnvironmentPlugin,
11
+ )
12
+ from nshtrainer.trainer.plugin.environment import (
13
+ LSFEnvironmentPlugin as LSFEnvironmentPlugin,
14
+ )
15
+ from nshtrainer.trainer.plugin.environment import (
16
+ MPIEnvironmentPlugin as MPIEnvironmentPlugin,
17
+ )
18
+ from nshtrainer.trainer.plugin.environment import PluginConfigBase as PluginConfigBase
19
+ from nshtrainer.trainer.plugin.environment import (
20
+ SLURMEnvironmentPlugin as SLURMEnvironmentPlugin,
21
+ )
22
+ from nshtrainer.trainer.plugin.environment import (
23
+ TorchElasticEnvironmentPlugin as TorchElasticEnvironmentPlugin,
24
+ )
25
+ from nshtrainer.trainer.plugin.environment import (
26
+ XLAEnvironmentPlugin as XLAEnvironmentPlugin,
27
+ )
28
+ from nshtrainer.trainer.plugin.environment import plugin_registry as plugin_registry
29
+
30
+ __all__ = [
31
+ "DTypeConfig",
32
+ "KubeflowEnvironmentPlugin",
33
+ "LSFEnvironmentPlugin",
34
+ "LightningEnvironmentPlugin",
35
+ "MPIEnvironmentPlugin",
36
+ "PluginConfigBase",
37
+ "SLURMEnvironmentPlugin",
38
+ "TorchElasticEnvironmentPlugin",
39
+ "XLAEnvironmentPlugin",
40
+ "plugin_registry",
41
+ ]
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.trainer.plugin.io import (
6
+ AsyncCheckpointIOPlugin as AsyncCheckpointIOPlugin,
7
+ )
8
+ from nshtrainer.trainer.plugin.io import PluginConfig as PluginConfig
9
+ from nshtrainer.trainer.plugin.io import PluginConfigBase as PluginConfigBase
10
+ from nshtrainer.trainer.plugin.io import (
11
+ TorchCheckpointIOPlugin as TorchCheckpointIOPlugin,
12
+ )
13
+ from nshtrainer.trainer.plugin.io import XLACheckpointIOPlugin as XLACheckpointIOPlugin
14
+ from nshtrainer.trainer.plugin.io import plugin_registry as plugin_registry
15
+
16
+ __all__ = [
17
+ "AsyncCheckpointIOPlugin",
18
+ "PluginConfig",
19
+ "PluginConfigBase",
20
+ "TorchCheckpointIOPlugin",
21
+ "XLACheckpointIOPlugin",
22
+ "plugin_registry",
23
+ ]
@@ -0,0 +1,15 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.trainer.plugin.layer_sync import PluginConfigBase as PluginConfigBase
6
+ from nshtrainer.trainer.plugin.layer_sync import (
7
+ TorchSyncBatchNormPlugin as TorchSyncBatchNormPlugin,
8
+ )
9
+ from nshtrainer.trainer.plugin.layer_sync import plugin_registry as plugin_registry
10
+
11
+ __all__ = [
12
+ "PluginConfigBase",
13
+ "TorchSyncBatchNormPlugin",
14
+ "plugin_registry",
15
+ ]