nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +6 -3
- nshtrainer/_callback.py +297 -2
- nshtrainer/_checkpoint/loader.py +23 -30
- nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer/_hf_hub.py +25 -26
- nshtrainer/callbacks/__init__.py +1 -3
- nshtrainer/callbacks/actsave.py +22 -20
- nshtrainer/callbacks/base.py +7 -7
- nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +8 -5
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- nshtrainer/callbacks/debug_flag.py +14 -19
- nshtrainer/callbacks/directory_setup.py +6 -11
- nshtrainer/callbacks/early_stopping.py +3 -3
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/log_epoch.py +1 -1
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- nshtrainer/callbacks/shared_parameters.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_upload_code.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/config/__init__.py +189 -189
- nshtrainer/config/_checkpoint/__init__.py +70 -0
- nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- nshtrainer/config/_directory/__init__.py +2 -2
- nshtrainer/config/_hf_hub/__init__.py +2 -2
- nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- nshtrainer/config/callbacks/ema/__init__.py +2 -2
- nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
- nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- nshtrainer/config/callbacks/timer/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -2
- nshtrainer/config/loggers/wandb/__init__.py +6 -6
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- nshtrainer/config/nn/__init__.py +18 -18
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- nshtrainer/config/optimizer/__init__.py +2 -2
- nshtrainer/config/profiler/__init__.py +2 -2
- nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- nshtrainer/config/profiler/simple/__init__.py +4 -4
- nshtrainer/config/trainer/__init__.py +180 -0
- nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer/config/util/__init__.py +109 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -20
- nshtrainer/config/util/config/__init__.py +2 -2
- nshtrainer/data/datamodule.py +52 -2
- nshtrainer/loggers/__init__.py +2 -1
- nshtrainer/loggers/_base.py +5 -2
- nshtrainer/loggers/actsave.py +59 -0
- nshtrainer/loggers/csv.py +5 -5
- nshtrainer/loggers/tensorboard.py +5 -5
- nshtrainer/loggers/wandb.py +17 -16
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +64 -347
- nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer/model/mixins/logger.py +142 -145
- nshtrainer/profiler/_base.py +2 -2
- nshtrainer/profiler/advanced.py +4 -4
- nshtrainer/profiler/pytorch.py +4 -4
- nshtrainer/profiler/simple.py +4 -4
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/_config.py +164 -17
- nshtrainer/trainer/checkpoint_connector.py +23 -8
- nshtrainer/trainer/trainer.py +194 -76
- nshtrainer/util/_environment_info.py +21 -13
- nshtrainer/util/config/dtype.py +4 -4
- nshtrainer/util/typing_utils.py +1 -1
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
- nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer/config/model/__init__.py +0 -41
- nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer/ll/__init__.py +0 -59
- nshtrainer/ll/_experimental.py +0 -3
- nshtrainer/ll/actsave.py +0 -6
- nshtrainer/ll/callbacks.py +0 -3
- nshtrainer/ll/config.py +0 -6
- nshtrainer/ll/data.py +0 -3
- nshtrainer/ll/log.py +0 -5
- nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer/ll/model.py +0 -21
- nshtrainer/ll/nn.py +0 -3
- nshtrainer/ll/optimizer.py +0 -3
- nshtrainer/ll/runner.py +0 -5
- nshtrainer/ll/snapshot.py +0 -3
- nshtrainer/ll/snoop.py +0 -3
- nshtrainer/ll/trainer.py +0 -3
- nshtrainer/ll/typecheck.py +0 -3
- nshtrainer/ll/util.py +0 -3
- nshtrainer/model/config.py +0 -218
- nshtrainer/runner.py +0 -101
- nshtrainer-0.44.1.dist-info/RECORD +0 -162
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
nshtrainer/config/__init__.py
CHANGED
@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING
|
|
7
7
|
# Config/alias imports
|
8
8
|
|
9
9
|
if TYPE_CHECKING:
|
10
|
-
from nshtrainer import BaseConfig as BaseConfig
|
11
10
|
from nshtrainer import MetricConfig as MetricConfig
|
11
|
+
from nshtrainer import TrainerConfig as TrainerConfig
|
12
12
|
from nshtrainer._checkpoint.loader import (
|
13
13
|
BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
|
14
14
|
)
|
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
|
|
22
22
|
from nshtrainer._checkpoint.loader import (
|
23
23
|
UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
|
24
24
|
)
|
25
|
+
from nshtrainer._directory import DirectoryConfig as DirectoryConfig
|
25
26
|
from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
|
26
27
|
from nshtrainer._hf_hub import (
|
27
28
|
HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
|
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|
51
52
|
from nshtrainer.callbacks import (
|
52
53
|
LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
|
53
54
|
)
|
55
|
+
from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
|
54
56
|
from nshtrainer.callbacks import (
|
55
57
|
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
56
58
|
)
|
@@ -66,7 +68,6 @@ if TYPE_CHECKING:
|
|
66
68
|
from nshtrainer.callbacks import (
|
67
69
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
68
70
|
)
|
69
|
-
from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
|
70
71
|
from nshtrainer.callbacks import (
|
71
72
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
72
73
|
)
|
@@ -77,6 +78,7 @@ if TYPE_CHECKING:
|
|
77
78
|
from nshtrainer.callbacks.checkpoint._base import (
|
78
79
|
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
79
80
|
)
|
81
|
+
from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
|
80
82
|
from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
|
81
83
|
from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
|
82
84
|
from nshtrainer.loggers import LoggerConfig as LoggerConfig
|
@@ -90,9 +92,6 @@ if TYPE_CHECKING:
|
|
90
92
|
from nshtrainer.lr_scheduler import (
|
91
93
|
ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
|
92
94
|
)
|
93
|
-
from nshtrainer.model import DirectoryConfig as DirectoryConfig
|
94
|
-
from nshtrainer.model import TrainerConfig as TrainerConfig
|
95
|
-
from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
|
96
95
|
from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
|
97
96
|
from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
|
98
97
|
from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
|
@@ -129,6 +128,7 @@ if TYPE_CHECKING:
|
|
129
128
|
from nshtrainer.trainer._config import (
|
130
129
|
CheckpointSavingConfig as CheckpointSavingConfig,
|
131
130
|
)
|
131
|
+
from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
|
132
132
|
from nshtrainer.trainer._config import (
|
133
133
|
GradientClippingConfig as GradientClippingConfig,
|
134
134
|
)
|
@@ -179,272 +179,274 @@ else:
|
|
179
179
|
|
180
180
|
if name in globals():
|
181
181
|
return globals()[name]
|
182
|
-
if name == "
|
183
|
-
return importlib.import_module("nshtrainer").
|
184
|
-
if name == "
|
185
|
-
return importlib.import_module("nshtrainer").
|
186
|
-
if name == "HuggingFaceHubAutoCreateConfig":
|
187
|
-
return importlib.import_module(
|
188
|
-
"nshtrainer._hf_hub"
|
189
|
-
).HuggingFaceHubAutoCreateConfig
|
190
|
-
if name == "HuggingFaceHubConfig":
|
191
|
-
return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
|
192
|
-
if name == "CallbackConfigBase":
|
193
|
-
return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
|
194
|
-
if name == "OptimizerConfigBase":
|
195
|
-
return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
|
182
|
+
if name == "ActSaveConfig":
|
183
|
+
return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
|
184
|
+
if name == "ActSaveLoggerConfig":
|
185
|
+
return importlib.import_module("nshtrainer.loggers").ActSaveLoggerConfig
|
196
186
|
if name == "AdamWConfig":
|
197
187
|
return importlib.import_module("nshtrainer.optimizer").AdamWConfig
|
198
|
-
if name == "
|
188
|
+
if name == "AdvancedProfilerConfig":
|
189
|
+
return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
|
190
|
+
if name == "BaseCheckpointCallbackConfig":
|
199
191
|
return importlib.import_module(
|
200
|
-
"nshtrainer.callbacks"
|
201
|
-
).
|
202
|
-
if name == "
|
203
|
-
return importlib.import_module("nshtrainer.
|
204
|
-
if name == "TrainerConfig":
|
205
|
-
return importlib.import_module("nshtrainer.model").TrainerConfig
|
206
|
-
if name == "EnvironmentConfig":
|
207
|
-
return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
|
192
|
+
"nshtrainer.callbacks.checkpoint._base"
|
193
|
+
).BaseCheckpointCallbackConfig
|
194
|
+
if name == "BaseLoggerConfig":
|
195
|
+
return importlib.import_module("nshtrainer.loggers").BaseLoggerConfig
|
208
196
|
if name == "BaseNonlinearityConfig":
|
209
197
|
return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
|
210
|
-
if name == "
|
211
|
-
return importlib.import_module("nshtrainer.
|
212
|
-
if name == "
|
213
|
-
return importlib.import_module("nshtrainer.nn").PReLUConfig
|
214
|
-
if name == "LeakyReLUNonlinearityConfig":
|
215
|
-
return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
|
216
|
-
if name == "SwiGLUNonlinearityConfig":
|
198
|
+
if name == "BaseProfilerConfig":
|
199
|
+
return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
|
200
|
+
if name == "BestCheckpointCallbackConfig":
|
217
201
|
return importlib.import_module(
|
218
|
-
"nshtrainer.
|
219
|
-
).
|
220
|
-
if name == "
|
221
|
-
return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
|
222
|
-
if name == "SiLUNonlinearityConfig":
|
223
|
-
return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
|
224
|
-
if name == "SigmoidNonlinearityConfig":
|
225
|
-
return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
|
226
|
-
if name == "SoftplusNonlinearityConfig":
|
227
|
-
return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
|
228
|
-
if name == "ELUNonlinearityConfig":
|
229
|
-
return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
|
230
|
-
if name == "SoftmaxNonlinearityConfig":
|
231
|
-
return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
|
232
|
-
if name == "GELUNonlinearityConfig":
|
233
|
-
return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
|
234
|
-
if name == "SwishNonlinearityConfig":
|
235
|
-
return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
|
236
|
-
if name == "MishNonlinearityConfig":
|
237
|
-
return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
|
238
|
-
if name == "TanhNonlinearityConfig":
|
239
|
-
return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
|
240
|
-
if name == "ReLUNonlinearityConfig":
|
241
|
-
return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
|
242
|
-
if name == "LRSchedulerConfigBase":
|
202
|
+
"nshtrainer.callbacks"
|
203
|
+
).BestCheckpointCallbackConfig
|
204
|
+
if name == "BestCheckpointStrategyConfig":
|
243
205
|
return importlib.import_module(
|
244
|
-
"nshtrainer.
|
245
|
-
).
|
246
|
-
if name == "
|
206
|
+
"nshtrainer._checkpoint.loader"
|
207
|
+
).BestCheckpointStrategyConfig
|
208
|
+
if name == "CSVLoggerConfig":
|
209
|
+
return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
|
210
|
+
if name == "CallbackConfigBase":
|
211
|
+
return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
|
212
|
+
if name == "CheckpointLoadingConfig":
|
247
213
|
return importlib.import_module(
|
248
|
-
"nshtrainer.
|
249
|
-
).
|
250
|
-
if name == "
|
214
|
+
"nshtrainer.trainer._config"
|
215
|
+
).CheckpointLoadingConfig
|
216
|
+
if name == "CheckpointMetadata":
|
251
217
|
return importlib.import_module(
|
252
|
-
"nshtrainer.
|
253
|
-
).
|
254
|
-
if name == "
|
255
|
-
return importlib.import_module(
|
256
|
-
|
257
|
-
|
258
|
-
if name == "
|
259
|
-
return importlib.import_module("nshtrainer.
|
260
|
-
if name == "
|
218
|
+
"nshtrainer._checkpoint.loader"
|
219
|
+
).CheckpointMetadata
|
220
|
+
if name == "CheckpointSavingConfig":
|
221
|
+
return importlib.import_module(
|
222
|
+
"nshtrainer.trainer._config"
|
223
|
+
).CheckpointSavingConfig
|
224
|
+
if name == "DTypeConfig":
|
225
|
+
return importlib.import_module("nshtrainer.util.config").DTypeConfig
|
226
|
+
if name == "DebugFlagCallbackConfig":
|
261
227
|
return importlib.import_module(
|
262
228
|
"nshtrainer.callbacks"
|
263
|
-
).
|
264
|
-
if name == "
|
229
|
+
).DebugFlagCallbackConfig
|
230
|
+
if name == "DirectoryConfig":
|
231
|
+
return importlib.import_module("nshtrainer._directory").DirectoryConfig
|
232
|
+
if name == "DirectorySetupCallbackConfig":
|
265
233
|
return importlib.import_module(
|
266
234
|
"nshtrainer.callbacks"
|
267
|
-
).
|
268
|
-
if name == "
|
269
|
-
return importlib.import_module("nshtrainer.
|
270
|
-
if name == "
|
235
|
+
).DirectorySetupCallbackConfig
|
236
|
+
if name == "ELUNonlinearityConfig":
|
237
|
+
return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
|
238
|
+
if name == "EMACallbackConfig":
|
239
|
+
return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
|
240
|
+
if name == "EarlyStoppingCallbackConfig":
|
241
|
+
return importlib.import_module(
|
242
|
+
"nshtrainer.callbacks"
|
243
|
+
).EarlyStoppingCallbackConfig
|
244
|
+
if name == "EnvironmentCUDAConfig":
|
271
245
|
return importlib.import_module(
|
272
246
|
"nshtrainer.util._environment_info"
|
273
|
-
).
|
274
|
-
if name == "
|
247
|
+
).EnvironmentCUDAConfig
|
248
|
+
if name == "EnvironmentClassInformationConfig":
|
275
249
|
return importlib.import_module(
|
276
250
|
"nshtrainer.util._environment_info"
|
277
|
-
).
|
251
|
+
).EnvironmentClassInformationConfig
|
252
|
+
if name == "EnvironmentConfig":
|
253
|
+
return importlib.import_module(
|
254
|
+
"nshtrainer.trainer._config"
|
255
|
+
).EnvironmentConfig
|
278
256
|
if name == "EnvironmentGPUConfig":
|
279
257
|
return importlib.import_module(
|
280
258
|
"nshtrainer.util._environment_info"
|
281
259
|
).EnvironmentGPUConfig
|
282
|
-
if name == "EnvironmentPackageConfig":
|
283
|
-
return importlib.import_module(
|
284
|
-
"nshtrainer.util._environment_info"
|
285
|
-
).EnvironmentPackageConfig
|
286
260
|
if name == "EnvironmentHardwareConfig":
|
287
261
|
return importlib.import_module(
|
288
262
|
"nshtrainer.util._environment_info"
|
289
263
|
).EnvironmentHardwareConfig
|
290
|
-
if name == "
|
291
|
-
return importlib.import_module(
|
292
|
-
"nshtrainer.util._environment_info"
|
293
|
-
).EnvironmentSnapshotConfig
|
294
|
-
if name == "EnvironmentClassInformationConfig":
|
264
|
+
if name == "EnvironmentLSFInformationConfig":
|
295
265
|
return importlib.import_module(
|
296
266
|
"nshtrainer.util._environment_info"
|
297
|
-
).
|
298
|
-
if name == "
|
267
|
+
).EnvironmentLSFInformationConfig
|
268
|
+
if name == "EnvironmentLinuxEnvironmentConfig":
|
299
269
|
return importlib.import_module(
|
300
270
|
"nshtrainer.util._environment_info"
|
301
|
-
).
|
302
|
-
if name == "
|
271
|
+
).EnvironmentLinuxEnvironmentConfig
|
272
|
+
if name == "EnvironmentPackageConfig":
|
303
273
|
return importlib.import_module(
|
304
274
|
"nshtrainer.util._environment_info"
|
305
|
-
).
|
275
|
+
).EnvironmentPackageConfig
|
306
276
|
if name == "EnvironmentSLURMInformationConfig":
|
307
277
|
return importlib.import_module(
|
308
278
|
"nshtrainer.util._environment_info"
|
309
279
|
).EnvironmentSLURMInformationConfig
|
310
|
-
if name == "
|
311
|
-
return importlib.import_module("nshtrainer.util.config").EpochsConfig
|
312
|
-
if name == "StepsConfig":
|
313
|
-
return importlib.import_module("nshtrainer.util.config").StepsConfig
|
314
|
-
if name == "DTypeConfig":
|
315
|
-
return importlib.import_module("nshtrainer.util.config").DTypeConfig
|
316
|
-
if name == "CheckpointLoadingConfig":
|
280
|
+
if name == "EnvironmentSnapshotConfig":
|
317
281
|
return importlib.import_module(
|
318
|
-
"nshtrainer.
|
319
|
-
).
|
320
|
-
if name == "
|
282
|
+
"nshtrainer.util._environment_info"
|
283
|
+
).EnvironmentSnapshotConfig
|
284
|
+
if name == "EpochTimerCallbackConfig":
|
321
285
|
return importlib.import_module(
|
322
|
-
"nshtrainer.
|
323
|
-
).
|
324
|
-
if name == "
|
286
|
+
"nshtrainer.callbacks"
|
287
|
+
).EpochTimerCallbackConfig
|
288
|
+
if name == "EpochsConfig":
|
289
|
+
return importlib.import_module("nshtrainer.util.config").EpochsConfig
|
290
|
+
if name == "FiniteChecksCallbackConfig":
|
325
291
|
return importlib.import_module(
|
326
292
|
"nshtrainer.callbacks"
|
327
|
-
).
|
293
|
+
).FiniteChecksCallbackConfig
|
294
|
+
if name == "GELUNonlinearityConfig":
|
295
|
+
return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
|
296
|
+
if name == "GitRepositoryConfig":
|
297
|
+
return importlib.import_module(
|
298
|
+
"nshtrainer.util._environment_info"
|
299
|
+
).GitRepositoryConfig
|
328
300
|
if name == "GradientClippingConfig":
|
329
301
|
return importlib.import_module(
|
330
302
|
"nshtrainer.trainer._config"
|
331
303
|
).GradientClippingConfig
|
332
|
-
if name == "
|
333
|
-
return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
|
334
|
-
if name == "RLPSanityChecksCallbackConfig":
|
304
|
+
if name == "GradientSkippingCallbackConfig":
|
335
305
|
return importlib.import_module(
|
336
306
|
"nshtrainer.callbacks"
|
337
|
-
).
|
338
|
-
if name == "
|
307
|
+
).GradientSkippingCallbackConfig
|
308
|
+
if name == "HuggingFaceHubAutoCreateConfig":
|
339
309
|
return importlib.import_module(
|
340
|
-
"nshtrainer.
|
341
|
-
).
|
342
|
-
if name == "
|
310
|
+
"nshtrainer._hf_hub"
|
311
|
+
).HuggingFaceHubAutoCreateConfig
|
312
|
+
if name == "HuggingFaceHubConfig":
|
313
|
+
return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
|
314
|
+
if name == "LRSchedulerConfigBase":
|
343
315
|
return importlib.import_module(
|
344
|
-
"nshtrainer.
|
345
|
-
).
|
316
|
+
"nshtrainer.lr_scheduler"
|
317
|
+
).LRSchedulerConfigBase
|
346
318
|
if name == "LastCheckpointCallbackConfig":
|
347
319
|
return importlib.import_module(
|
348
320
|
"nshtrainer.callbacks"
|
349
321
|
).LastCheckpointCallbackConfig
|
350
|
-
if name == "
|
322
|
+
if name == "LastCheckpointStrategyConfig":
|
323
|
+
return importlib.import_module(
|
324
|
+
"nshtrainer._checkpoint.loader"
|
325
|
+
).LastCheckpointStrategyConfig
|
326
|
+
if name == "LeakyReLUNonlinearityConfig":
|
327
|
+
return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
|
328
|
+
if name == "LinearWarmupCosineDecayLRSchedulerConfig":
|
329
|
+
return importlib.import_module(
|
330
|
+
"nshtrainer.lr_scheduler"
|
331
|
+
).LinearWarmupCosineDecayLRSchedulerConfig
|
332
|
+
if name == "LogEpochCallbackConfig":
|
351
333
|
return importlib.import_module(
|
352
334
|
"nshtrainer.callbacks"
|
353
|
-
).
|
354
|
-
if name == "
|
335
|
+
).LogEpochCallbackConfig
|
336
|
+
if name == "LoggingConfig":
|
337
|
+
return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
|
338
|
+
if name == "MLPConfig":
|
339
|
+
return importlib.import_module("nshtrainer.nn").MLPConfig
|
340
|
+
if name == "MetricConfig":
|
341
|
+
return importlib.import_module("nshtrainer").MetricConfig
|
342
|
+
if name == "MishNonlinearityConfig":
|
343
|
+
return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
|
344
|
+
if name == "NormLoggingCallbackConfig":
|
355
345
|
return importlib.import_module(
|
356
|
-
"nshtrainer.
|
357
|
-
).
|
358
|
-
if name == "
|
346
|
+
"nshtrainer.callbacks"
|
347
|
+
).NormLoggingCallbackConfig
|
348
|
+
if name == "OnExceptionCheckpointCallbackConfig":
|
359
349
|
return importlib.import_module(
|
360
350
|
"nshtrainer.callbacks"
|
361
|
-
).
|
351
|
+
).OnExceptionCheckpointCallbackConfig
|
362
352
|
if name == "OptimizationConfig":
|
363
353
|
return importlib.import_module(
|
364
354
|
"nshtrainer.trainer._config"
|
365
355
|
).OptimizationConfig
|
366
|
-
if name == "
|
356
|
+
if name == "OptimizerConfigBase":
|
357
|
+
return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
|
358
|
+
if name == "PReLUConfig":
|
359
|
+
return importlib.import_module("nshtrainer.nn").PReLUConfig
|
360
|
+
if name == "PrintTableMetricsCallbackConfig":
|
367
361
|
return importlib.import_module(
|
368
362
|
"nshtrainer.callbacks"
|
369
|
-
).
|
370
|
-
if name == "
|
371
|
-
return importlib.import_module(
|
372
|
-
|
373
|
-
).CheckpointMetadata
|
374
|
-
if name == "BestCheckpointStrategyConfig":
|
363
|
+
).PrintTableMetricsCallbackConfig
|
364
|
+
if name == "PyTorchProfilerConfig":
|
365
|
+
return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
|
366
|
+
if name == "RLPSanityChecksCallbackConfig":
|
375
367
|
return importlib.import_module(
|
376
|
-
"nshtrainer.
|
377
|
-
).
|
378
|
-
if name == "
|
368
|
+
"nshtrainer.callbacks"
|
369
|
+
).RLPSanityChecksCallbackConfig
|
370
|
+
if name == "ReLUNonlinearityConfig":
|
371
|
+
return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
|
372
|
+
if name == "ReduceLROnPlateauConfig":
|
379
373
|
return importlib.import_module(
|
380
|
-
"nshtrainer.
|
381
|
-
).
|
382
|
-
if name == "
|
374
|
+
"nshtrainer.lr_scheduler"
|
375
|
+
).ReduceLROnPlateauConfig
|
376
|
+
if name == "ReproducibilityConfig":
|
383
377
|
return importlib.import_module(
|
384
|
-
"nshtrainer.
|
385
|
-
).
|
386
|
-
if name == "
|
378
|
+
"nshtrainer.trainer._config"
|
379
|
+
).ReproducibilityConfig
|
380
|
+
if name == "SanityCheckingConfig":
|
387
381
|
return importlib.import_module(
|
388
|
-
"nshtrainer.
|
389
|
-
).
|
390
|
-
if name == "
|
382
|
+
"nshtrainer.trainer._config"
|
383
|
+
).SanityCheckingConfig
|
384
|
+
if name == "SharedParametersCallbackConfig":
|
391
385
|
return importlib.import_module(
|
392
386
|
"nshtrainer.callbacks"
|
393
|
-
).
|
394
|
-
if name == "
|
387
|
+
).SharedParametersCallbackConfig
|
388
|
+
if name == "SiLUNonlinearityConfig":
|
389
|
+
return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
|
390
|
+
if name == "SigmoidNonlinearityConfig":
|
391
|
+
return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
|
392
|
+
if name == "SimpleProfilerConfig":
|
393
|
+
return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
|
394
|
+
if name == "SoftmaxNonlinearityConfig":
|
395
|
+
return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
|
396
|
+
if name == "SoftplusNonlinearityConfig":
|
397
|
+
return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
|
398
|
+
if name == "SoftsignNonlinearityConfig":
|
399
|
+
return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
|
400
|
+
if name == "StepsConfig":
|
401
|
+
return importlib.import_module("nshtrainer.util.config").StepsConfig
|
402
|
+
if name == "SwiGLUNonlinearityConfig":
|
395
403
|
return importlib.import_module(
|
396
|
-
"nshtrainer.
|
397
|
-
).
|
398
|
-
if name == "
|
399
|
-
return importlib.import_module("nshtrainer.
|
400
|
-
if name == "
|
401
|
-
return importlib.import_module("nshtrainer.
|
402
|
-
if name == "
|
404
|
+
"nshtrainer.nn.nonlinearity"
|
405
|
+
).SwiGLUNonlinearityConfig
|
406
|
+
if name == "SwishNonlinearityConfig":
|
407
|
+
return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
|
408
|
+
if name == "TanhNonlinearityConfig":
|
409
|
+
return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
|
410
|
+
if name == "TensorboardLoggerConfig":
|
411
|
+
return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
|
412
|
+
if name == "TrainerConfig":
|
413
|
+
return importlib.import_module("nshtrainer").TrainerConfig
|
414
|
+
if name == "UserProvidedPathCheckpointStrategyConfig":
|
403
415
|
return importlib.import_module(
|
404
|
-
"nshtrainer.
|
405
|
-
).
|
406
|
-
if name == "
|
416
|
+
"nshtrainer._checkpoint.loader"
|
417
|
+
).UserProvidedPathCheckpointStrategyConfig
|
418
|
+
if name == "WandbLoggerConfig":
|
419
|
+
return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
|
420
|
+
if name == "WandbUploadCodeCallbackConfig":
|
407
421
|
return importlib.import_module(
|
408
422
|
"nshtrainer.callbacks"
|
409
|
-
).
|
410
|
-
if name == "
|
423
|
+
).WandbUploadCodeCallbackConfig
|
424
|
+
if name == "WandbWatchCallbackConfig":
|
411
425
|
return importlib.import_module(
|
412
426
|
"nshtrainer.callbacks"
|
413
|
-
).
|
414
|
-
if name == "BaseCheckpointCallbackConfig":
|
415
|
-
return importlib.import_module(
|
416
|
-
"nshtrainer.callbacks.checkpoint._base"
|
417
|
-
).BaseCheckpointCallbackConfig
|
418
|
-
if name == "BaseProfilerConfig":
|
419
|
-
return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
|
420
|
-
if name == "PyTorchProfilerConfig":
|
421
|
-
return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
|
422
|
-
if name == "AdvancedProfilerConfig":
|
423
|
-
return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
|
424
|
-
if name == "SimpleProfilerConfig":
|
425
|
-
return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
|
426
|
-
if name == "OptimizerConfig":
|
427
|
-
return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
|
428
|
-
if name == "LoggerConfig":
|
429
|
-
return importlib.import_module("nshtrainer.loggers").LoggerConfig
|
430
|
-
if name == "NonlinearityConfig":
|
431
|
-
return importlib.import_module("nshtrainer.nn").NonlinearityConfig
|
432
|
-
if name == "DurationConfig":
|
433
|
-
return importlib.import_module("nshtrainer.util.config").DurationConfig
|
434
|
-
if name == "LRSchedulerConfig":
|
435
|
-
return importlib.import_module("nshtrainer.lr_scheduler").LRSchedulerConfig
|
427
|
+
).WandbWatchCallbackConfig
|
436
428
|
if name == "CallbackConfig":
|
437
429
|
return importlib.import_module("nshtrainer.callbacks").CallbackConfig
|
438
430
|
if name == "CheckpointCallbackConfig":
|
439
431
|
return importlib.import_module(
|
440
432
|
"nshtrainer.trainer._config"
|
441
433
|
).CheckpointCallbackConfig
|
442
|
-
if name == "ProfilerConfig":
|
443
|
-
return importlib.import_module("nshtrainer.profiler").ProfilerConfig
|
444
434
|
if name == "CheckpointLoadingStrategyConfig":
|
445
435
|
return importlib.import_module(
|
446
436
|
"nshtrainer._checkpoint.loader"
|
447
437
|
).CheckpointLoadingStrategyConfig
|
438
|
+
if name == "DurationConfig":
|
439
|
+
return importlib.import_module("nshtrainer.util.config").DurationConfig
|
440
|
+
if name == "LRSchedulerConfig":
|
441
|
+
return importlib.import_module("nshtrainer.lr_scheduler").LRSchedulerConfig
|
442
|
+
if name == "LoggerConfig":
|
443
|
+
return importlib.import_module("nshtrainer.loggers").LoggerConfig
|
444
|
+
if name == "NonlinearityConfig":
|
445
|
+
return importlib.import_module("nshtrainer.nn").NonlinearityConfig
|
446
|
+
if name == "OptimizerConfig":
|
447
|
+
return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
|
448
|
+
if name == "ProfilerConfig":
|
449
|
+
return importlib.import_module("nshtrainer.profiler").ProfilerConfig
|
448
450
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
449
451
|
|
450
452
|
|
@@ -456,10 +458,8 @@ from . import callbacks as callbacks
|
|
456
458
|
from . import loggers as loggers
|
457
459
|
from . import lr_scheduler as lr_scheduler
|
458
460
|
from . import metrics as metrics
|
459
|
-
from . import model as model
|
460
461
|
from . import nn as nn
|
461
462
|
from . import optimizer as optimizer
|
462
463
|
from . import profiler as profiler
|
463
|
-
from . import runner as runner
|
464
464
|
from . import trainer as trainer
|
465
465
|
from . import util as util
|
@@ -0,0 +1,70 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from typing import TYPE_CHECKING
|
6
|
+
|
7
|
+
# Config/alias imports
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from nshtrainer._checkpoint.loader import (
|
11
|
+
BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
|
12
|
+
)
|
13
|
+
from nshtrainer._checkpoint.loader import (
|
14
|
+
CheckpointLoadingConfig as CheckpointLoadingConfig,
|
15
|
+
)
|
16
|
+
from nshtrainer._checkpoint.loader import (
|
17
|
+
CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
|
18
|
+
)
|
19
|
+
from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
|
20
|
+
from nshtrainer._checkpoint.loader import (
|
21
|
+
LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
|
22
|
+
)
|
23
|
+
from nshtrainer._checkpoint.loader import MetricConfig as MetricConfig
|
24
|
+
from nshtrainer._checkpoint.loader import (
|
25
|
+
UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
|
26
|
+
)
|
27
|
+
from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
|
28
|
+
else:
|
29
|
+
|
30
|
+
def __getattr__(name):
|
31
|
+
import importlib
|
32
|
+
|
33
|
+
if name in globals():
|
34
|
+
return globals()[name]
|
35
|
+
if name == "BestCheckpointStrategyConfig":
|
36
|
+
return importlib.import_module(
|
37
|
+
"nshtrainer._checkpoint.loader"
|
38
|
+
).BestCheckpointStrategyConfig
|
39
|
+
if name == "CheckpointLoadingConfig":
|
40
|
+
return importlib.import_module(
|
41
|
+
"nshtrainer._checkpoint.loader"
|
42
|
+
).CheckpointLoadingConfig
|
43
|
+
if name == "CheckpointMetadata":
|
44
|
+
return importlib.import_module(
|
45
|
+
"nshtrainer._checkpoint.loader"
|
46
|
+
).CheckpointMetadata
|
47
|
+
if name == "EnvironmentConfig":
|
48
|
+
return importlib.import_module(
|
49
|
+
"nshtrainer._checkpoint.metadata"
|
50
|
+
).EnvironmentConfig
|
51
|
+
if name == "LastCheckpointStrategyConfig":
|
52
|
+
return importlib.import_module(
|
53
|
+
"nshtrainer._checkpoint.loader"
|
54
|
+
).LastCheckpointStrategyConfig
|
55
|
+
if name == "MetricConfig":
|
56
|
+
return importlib.import_module("nshtrainer._checkpoint.loader").MetricConfig
|
57
|
+
if name == "UserProvidedPathCheckpointStrategyConfig":
|
58
|
+
return importlib.import_module(
|
59
|
+
"nshtrainer._checkpoint.loader"
|
60
|
+
).UserProvidedPathCheckpointStrategyConfig
|
61
|
+
if name == "CheckpointLoadingStrategyConfig":
|
62
|
+
return importlib.import_module(
|
63
|
+
"nshtrainer._checkpoint.loader"
|
64
|
+
).CheckpointLoadingStrategyConfig
|
65
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
66
|
+
|
67
|
+
|
68
|
+
# Submodule exports
|
69
|
+
from . import loader as loader
|
70
|
+
from . import metadata as metadata
|
@@ -31,12 +31,14 @@ else:
|
|
31
31
|
|
32
32
|
if name in globals():
|
33
33
|
return globals()[name]
|
34
|
-
if name == "MetricConfig":
|
35
|
-
return importlib.import_module("nshtrainer._checkpoint.loader").MetricConfig
|
36
34
|
if name == "BestCheckpointStrategyConfig":
|
37
35
|
return importlib.import_module(
|
38
36
|
"nshtrainer._checkpoint.loader"
|
39
37
|
).BestCheckpointStrategyConfig
|
38
|
+
if name == "CheckpointLoadingConfig":
|
39
|
+
return importlib.import_module(
|
40
|
+
"nshtrainer._checkpoint.loader"
|
41
|
+
).CheckpointLoadingConfig
|
40
42
|
if name == "CheckpointMetadata":
|
41
43
|
return importlib.import_module(
|
42
44
|
"nshtrainer._checkpoint.loader"
|
@@ -45,10 +47,8 @@ else:
|
|
45
47
|
return importlib.import_module(
|
46
48
|
"nshtrainer._checkpoint.loader"
|
47
49
|
).LastCheckpointStrategyConfig
|
48
|
-
if name == "
|
49
|
-
return importlib.import_module(
|
50
|
-
"nshtrainer._checkpoint.loader"
|
51
|
-
).CheckpointLoadingConfig
|
50
|
+
if name == "MetricConfig":
|
51
|
+
return importlib.import_module("nshtrainer._checkpoint.loader").MetricConfig
|
52
52
|
if name == "UserProvidedPathCheckpointStrategyConfig":
|
53
53
|
return importlib.import_module(
|
54
54
|
"nshtrainer._checkpoint.loader"
|