nshtrainer 0.42.0__py3-none-any.whl → 0.43.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -0
- nshtrainer/_callback.py +2 -0
- nshtrainer/_checkpoint/loader.py +2 -0
- nshtrainer/_checkpoint/metadata.py +2 -0
- nshtrainer/_checkpoint/saver.py +2 -0
- nshtrainer/_directory.py +4 -2
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_hf_hub.py +2 -0
- nshtrainer/callbacks/__init__.py +45 -29
- nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
- nshtrainer/callbacks/actsave.py +2 -0
- nshtrainer/callbacks/base.py +2 -0
- nshtrainer/callbacks/checkpoint/__init__.py +6 -2
- nshtrainer/callbacks/checkpoint/_base.py +2 -0
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
- nshtrainer/callbacks/debug_flag.py +2 -0
- nshtrainer/callbacks/directory_setup.py +4 -2
- nshtrainer/callbacks/early_stopping.py +6 -4
- nshtrainer/callbacks/ema.py +5 -3
- nshtrainer/callbacks/finite_checks.py +3 -1
- nshtrainer/callbacks/gradient_skipping.py +6 -4
- nshtrainer/callbacks/interval.py +2 -0
- nshtrainer/callbacks/log_epoch.py +13 -1
- nshtrainer/callbacks/norm_logging.py +4 -2
- nshtrainer/callbacks/print_table.py +3 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
- nshtrainer/callbacks/shared_parameters.py +4 -2
- nshtrainer/callbacks/throughput_monitor.py +2 -0
- nshtrainer/callbacks/timer.py +5 -3
- nshtrainer/callbacks/wandb_upload_code.py +4 -2
- nshtrainer/callbacks/wandb_watch.py +4 -2
- nshtrainer/config/__init__.py +130 -90
- nshtrainer/config/_checkpoint/loader/__init__.py +10 -8
- nshtrainer/config/_checkpoint/metadata/__init__.py +6 -4
- nshtrainer/config/_directory/__init__.py +9 -3
- nshtrainer/config/_hf_hub/__init__.py +6 -4
- nshtrainer/config/callbacks/__init__.py +82 -42
- nshtrainer/config/callbacks/actsave/__init__.py +4 -2
- nshtrainer/config/callbacks/base/__init__.py +2 -0
- nshtrainer/config/callbacks/checkpoint/__init__.py +6 -4
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +6 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +2 -0
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +6 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +6 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +7 -5
- nshtrainer/config/callbacks/early_stopping/__init__.py +9 -7
- nshtrainer/config/callbacks/ema/__init__.py +5 -3
- nshtrainer/config/callbacks/finite_checks/__init__.py +7 -5
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +7 -5
- nshtrainer/config/callbacks/norm_logging/__init__.py +9 -5
- nshtrainer/config/callbacks/print_table/__init__.py +7 -5
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +7 -5
- nshtrainer/config/callbacks/shared_parameters/__init__.py +7 -5
- nshtrainer/config/callbacks/throughput_monitor/__init__.py +6 -4
- nshtrainer/config/callbacks/timer/__init__.py +9 -5
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +7 -5
- nshtrainer/config/callbacks/wandb_watch/__init__.py +9 -5
- nshtrainer/config/loggers/__init__.py +18 -10
- nshtrainer/config/loggers/_base/__init__.py +2 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -0
- nshtrainer/config/loggers/tensorboard/__init__.py +2 -0
- nshtrainer/config/loggers/wandb/__init__.py +18 -10
- nshtrainer/config/lr_scheduler/__init__.py +2 -0
- nshtrainer/config/lr_scheduler/_base/__init__.py +2 -0
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +2 -0
- nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -4
- nshtrainer/config/metrics/__init__.py +2 -0
- nshtrainer/config/metrics/_config/__init__.py +2 -0
- nshtrainer/config/model/__init__.py +8 -6
- nshtrainer/config/model/base/__init__.py +4 -2
- nshtrainer/config/model/config/__init__.py +8 -6
- nshtrainer/config/model/mixins/logger/__init__.py +2 -0
- nshtrainer/config/nn/__init__.py +16 -14
- nshtrainer/config/nn/mlp/__init__.py +2 -0
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -24
- nshtrainer/config/optimizer/__init__.py +2 -0
- nshtrainer/config/profiler/__init__.py +2 -0
- nshtrainer/config/profiler/_base/__init__.py +2 -0
- nshtrainer/config/profiler/advanced/__init__.py +6 -4
- nshtrainer/config/profiler/pytorch/__init__.py +6 -4
- nshtrainer/config/profiler/simple/__init__.py +6 -4
- nshtrainer/config/runner/__init__.py +2 -0
- nshtrainer/config/trainer/_config/__init__.py +43 -39
- nshtrainer/config/trainer/checkpoint_connector/__init__.py +2 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -18
- nshtrainer/config/util/config/__init__.py +2 -0
- nshtrainer/config/util/config/dtype/__init__.py +2 -0
- nshtrainer/config/util/config/duration/__init__.py +2 -0
- nshtrainer/data/__init__.py +2 -0
- nshtrainer/data/balanced_batch_sampler.py +2 -0
- nshtrainer/data/datamodule.py +2 -0
- nshtrainer/data/transform.py +2 -0
- nshtrainer/ll/__init__.py +2 -0
- nshtrainer/ll/_experimental.py +2 -0
- nshtrainer/ll/actsave.py +2 -0
- nshtrainer/ll/callbacks.py +2 -0
- nshtrainer/ll/config.py +2 -0
- nshtrainer/ll/data.py +2 -0
- nshtrainer/ll/log.py +2 -0
- nshtrainer/ll/lr_scheduler.py +2 -0
- nshtrainer/ll/model.py +2 -0
- nshtrainer/ll/nn.py +2 -0
- nshtrainer/ll/optimizer.py +2 -0
- nshtrainer/ll/runner.py +2 -0
- nshtrainer/ll/snapshot.py +2 -0
- nshtrainer/ll/snoop.py +2 -0
- nshtrainer/ll/trainer.py +2 -0
- nshtrainer/ll/typecheck.py +2 -0
- nshtrainer/ll/util.py +2 -0
- nshtrainer/loggers/__init__.py +2 -0
- nshtrainer/loggers/_base.py +2 -0
- nshtrainer/loggers/csv.py +2 -0
- nshtrainer/loggers/tensorboard.py +2 -0
- nshtrainer/loggers/wandb.py +6 -4
- nshtrainer/lr_scheduler/__init__.py +2 -0
- nshtrainer/lr_scheduler/_base.py +2 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +2 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +2 -0
- nshtrainer/metrics/__init__.py +2 -0
- nshtrainer/metrics/_config.py +2 -0
- nshtrainer/model/__init__.py +2 -0
- nshtrainer/model/base.py +2 -0
- nshtrainer/model/config.py +2 -0
- nshtrainer/model/mixins/callback.py +2 -0
- nshtrainer/model/mixins/logger.py +2 -0
- nshtrainer/nn/__init__.py +2 -0
- nshtrainer/nn/mlp.py +2 -0
- nshtrainer/nn/module_dict.py +2 -0
- nshtrainer/nn/module_list.py +2 -0
- nshtrainer/nn/nonlinearity.py +2 -0
- nshtrainer/optimizer.py +2 -0
- nshtrainer/profiler/__init__.py +2 -0
- nshtrainer/profiler/_base.py +2 -0
- nshtrainer/profiler/advanced.py +2 -0
- nshtrainer/profiler/pytorch.py +2 -0
- nshtrainer/profiler/simple.py +2 -0
- nshtrainer/runner.py +2 -0
- nshtrainer/scripts/find_packages.py +2 -0
- nshtrainer/trainer/__init__.py +2 -0
- nshtrainer/trainer/_config.py +16 -13
- nshtrainer/trainer/_runtime_callback.py +2 -0
- nshtrainer/trainer/checkpoint_connector.py +2 -0
- nshtrainer/trainer/signal_connector.py +2 -0
- nshtrainer/trainer/trainer.py +2 -0
- nshtrainer/util/_environment_info.py +2 -0
- nshtrainer/util/bf16.py +2 -0
- nshtrainer/util/config/__init__.py +2 -0
- nshtrainer/util/config/dtype.py +2 -0
- nshtrainer/util/config/duration.py +2 -0
- nshtrainer/util/environment.py +2 -0
- nshtrainer/util/path.py +2 -0
- nshtrainer/util/seed.py +2 -0
- nshtrainer/util/slurm.py +3 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +2 -0
- {nshtrainer-0.42.0.dist-info → nshtrainer-0.43.0.dist-info}/METADATA +1 -1
- nshtrainer-0.43.0.dist-info/RECORD +162 -0
- nshtrainer-0.42.0.dist-info/RECORD +0 -162
- {nshtrainer-0.42.0.dist-info → nshtrainer-0.43.0.dist-info}/WHEEL +0 -0
nshtrainer/config/__init__.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -30,25 +32,47 @@ if TYPE_CHECKING:
|
|
|
30
32
|
)
|
|
31
33
|
from nshtrainer.callbacks import CallbackConfig as CallbackConfig
|
|
32
34
|
from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
|
33
|
-
from nshtrainer.callbacks import
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
from nshtrainer.callbacks import
|
|
37
|
-
|
|
38
|
-
|
|
35
|
+
from nshtrainer.callbacks import (
|
|
36
|
+
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
|
37
|
+
)
|
|
38
|
+
from nshtrainer.callbacks import (
|
|
39
|
+
EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
|
|
40
|
+
)
|
|
41
|
+
from nshtrainer.callbacks import EMACallbackConfig as EMACallbackConfig
|
|
42
|
+
from nshtrainer.callbacks import (
|
|
43
|
+
EpochTimerCallbackConfig as EpochTimerCallbackConfig,
|
|
44
|
+
)
|
|
45
|
+
from nshtrainer.callbacks import (
|
|
46
|
+
FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
|
|
47
|
+
)
|
|
48
|
+
from nshtrainer.callbacks import (
|
|
49
|
+
GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
|
|
50
|
+
)
|
|
39
51
|
from nshtrainer.callbacks import (
|
|
40
52
|
LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
|
|
41
53
|
)
|
|
42
|
-
from nshtrainer.callbacks import
|
|
54
|
+
from nshtrainer.callbacks import (
|
|
55
|
+
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
|
56
|
+
)
|
|
43
57
|
from nshtrainer.callbacks import (
|
|
44
58
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
|
45
59
|
)
|
|
46
|
-
from nshtrainer.callbacks import
|
|
47
|
-
|
|
48
|
-
|
|
60
|
+
from nshtrainer.callbacks import (
|
|
61
|
+
PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
|
|
62
|
+
)
|
|
63
|
+
from nshtrainer.callbacks import (
|
|
64
|
+
RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
|
|
65
|
+
)
|
|
66
|
+
from nshtrainer.callbacks import (
|
|
67
|
+
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
|
68
|
+
)
|
|
49
69
|
from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
|
|
50
|
-
from nshtrainer.callbacks import
|
|
51
|
-
|
|
70
|
+
from nshtrainer.callbacks import (
|
|
71
|
+
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
|
72
|
+
)
|
|
73
|
+
from nshtrainer.callbacks import (
|
|
74
|
+
WandbWatchCallbackConfig as WandbWatchCallbackConfig,
|
|
75
|
+
)
|
|
52
76
|
from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
|
|
53
77
|
from nshtrainer.callbacks.checkpoint._base import (
|
|
54
78
|
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
|
@@ -155,26 +179,28 @@ else:
|
|
|
155
179
|
|
|
156
180
|
if name in globals():
|
|
157
181
|
return globals()[name]
|
|
158
|
-
if name == "BaseConfig":
|
|
159
|
-
return importlib.import_module("nshtrainer").BaseConfig
|
|
160
182
|
if name == "MetricConfig":
|
|
161
183
|
return importlib.import_module("nshtrainer").MetricConfig
|
|
162
|
-
if name == "
|
|
163
|
-
return importlib.import_module("nshtrainer
|
|
164
|
-
if name == "HuggingFaceHubConfig":
|
|
165
|
-
return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
|
|
184
|
+
if name == "BaseConfig":
|
|
185
|
+
return importlib.import_module("nshtrainer").BaseConfig
|
|
166
186
|
if name == "HuggingFaceHubAutoCreateConfig":
|
|
167
187
|
return importlib.import_module(
|
|
168
188
|
"nshtrainer._hf_hub"
|
|
169
189
|
).HuggingFaceHubAutoCreateConfig
|
|
190
|
+
if name == "HuggingFaceHubConfig":
|
|
191
|
+
return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
|
|
192
|
+
if name == "CallbackConfigBase":
|
|
193
|
+
return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
|
|
170
194
|
if name == "OptimizerConfigBase":
|
|
171
195
|
return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
|
|
172
196
|
if name == "AdamWConfig":
|
|
173
197
|
return importlib.import_module("nshtrainer.optimizer").AdamWConfig
|
|
198
|
+
if name == "DirectorySetupCallbackConfig":
|
|
199
|
+
return importlib.import_module(
|
|
200
|
+
"nshtrainer.callbacks"
|
|
201
|
+
).DirectorySetupCallbackConfig
|
|
174
202
|
if name == "DirectoryConfig":
|
|
175
203
|
return importlib.import_module("nshtrainer.model").DirectoryConfig
|
|
176
|
-
if name == "DirectorySetupConfig":
|
|
177
|
-
return importlib.import_module("nshtrainer.callbacks").DirectorySetupConfig
|
|
178
204
|
if name == "TrainerConfig":
|
|
179
205
|
return importlib.import_module("nshtrainer.model").TrainerConfig
|
|
180
206
|
if name == "EnvironmentConfig":
|
|
@@ -183,36 +209,36 @@ else:
|
|
|
183
209
|
return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
|
|
184
210
|
if name == "MLPConfig":
|
|
185
211
|
return importlib.import_module("nshtrainer.nn").MLPConfig
|
|
212
|
+
if name == "PReLUConfig":
|
|
213
|
+
return importlib.import_module("nshtrainer.nn").PReLUConfig
|
|
214
|
+
if name == "LeakyReLUNonlinearityConfig":
|
|
215
|
+
return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
|
|
186
216
|
if name == "SwiGLUNonlinearityConfig":
|
|
187
217
|
return importlib.import_module(
|
|
188
218
|
"nshtrainer.nn.nonlinearity"
|
|
189
219
|
).SwiGLUNonlinearityConfig
|
|
190
|
-
if name == "
|
|
191
|
-
return importlib.import_module("nshtrainer.nn").
|
|
220
|
+
if name == "SoftsignNonlinearityConfig":
|
|
221
|
+
return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
|
|
192
222
|
if name == "SiLUNonlinearityConfig":
|
|
193
223
|
return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
|
|
224
|
+
if name == "SigmoidNonlinearityConfig":
|
|
225
|
+
return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
|
|
226
|
+
if name == "SoftplusNonlinearityConfig":
|
|
227
|
+
return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
|
|
194
228
|
if name == "ELUNonlinearityConfig":
|
|
195
229
|
return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
|
|
230
|
+
if name == "SoftmaxNonlinearityConfig":
|
|
231
|
+
return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
|
|
196
232
|
if name == "GELUNonlinearityConfig":
|
|
197
233
|
return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
|
|
198
|
-
if name == "SoftplusNonlinearityConfig":
|
|
199
|
-
return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
|
|
200
|
-
if name == "SoftsignNonlinearityConfig":
|
|
201
|
-
return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
|
|
202
234
|
if name == "SwishNonlinearityConfig":
|
|
203
235
|
return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
|
|
204
|
-
if name == "SoftmaxNonlinearityConfig":
|
|
205
|
-
return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
|
|
206
236
|
if name == "MishNonlinearityConfig":
|
|
207
237
|
return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
|
|
208
|
-
if name == "SigmoidNonlinearityConfig":
|
|
209
|
-
return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
|
|
210
238
|
if name == "TanhNonlinearityConfig":
|
|
211
239
|
return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
|
|
212
|
-
if name == "
|
|
213
|
-
return importlib.import_module("nshtrainer.nn").
|
|
214
|
-
if name == "LeakyReLUNonlinearityConfig":
|
|
215
|
-
return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
|
|
240
|
+
if name == "ReLUNonlinearityConfig":
|
|
241
|
+
return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
|
|
216
242
|
if name == "LRSchedulerConfigBase":
|
|
217
243
|
return importlib.import_module(
|
|
218
244
|
"nshtrainer.lr_scheduler"
|
|
@@ -231,32 +257,40 @@ else:
|
|
|
231
257
|
return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
|
|
232
258
|
if name == "WandbLoggerConfig":
|
|
233
259
|
return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
|
|
234
|
-
if name == "
|
|
235
|
-
return importlib.import_module(
|
|
236
|
-
|
|
237
|
-
|
|
260
|
+
if name == "WandbUploadCodeCallbackConfig":
|
|
261
|
+
return importlib.import_module(
|
|
262
|
+
"nshtrainer.callbacks"
|
|
263
|
+
).WandbUploadCodeCallbackConfig
|
|
264
|
+
if name == "WandbWatchCallbackConfig":
|
|
265
|
+
return importlib.import_module(
|
|
266
|
+
"nshtrainer.callbacks"
|
|
267
|
+
).WandbWatchCallbackConfig
|
|
238
268
|
if name == "CSVLoggerConfig":
|
|
239
269
|
return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
|
|
240
|
-
if name == "
|
|
241
|
-
return importlib.import_module(
|
|
242
|
-
"nshtrainer.util._environment_info"
|
|
243
|
-
).EnvironmentPackageConfig
|
|
244
|
-
if name == "EnvironmentSnapshotConfig":
|
|
270
|
+
if name == "EnvironmentLinuxEnvironmentConfig":
|
|
245
271
|
return importlib.import_module(
|
|
246
272
|
"nshtrainer.util._environment_info"
|
|
247
|
-
).
|
|
273
|
+
).EnvironmentLinuxEnvironmentConfig
|
|
248
274
|
if name == "EnvironmentLSFInformationConfig":
|
|
249
275
|
return importlib.import_module(
|
|
250
276
|
"nshtrainer.util._environment_info"
|
|
251
277
|
).EnvironmentLSFInformationConfig
|
|
252
|
-
if name == "
|
|
278
|
+
if name == "EnvironmentGPUConfig":
|
|
253
279
|
return importlib.import_module(
|
|
254
280
|
"nshtrainer.util._environment_info"
|
|
255
|
-
).
|
|
256
|
-
if name == "
|
|
281
|
+
).EnvironmentGPUConfig
|
|
282
|
+
if name == "EnvironmentPackageConfig":
|
|
257
283
|
return importlib.import_module(
|
|
258
284
|
"nshtrainer.util._environment_info"
|
|
259
|
-
).
|
|
285
|
+
).EnvironmentPackageConfig
|
|
286
|
+
if name == "EnvironmentHardwareConfig":
|
|
287
|
+
return importlib.import_module(
|
|
288
|
+
"nshtrainer.util._environment_info"
|
|
289
|
+
).EnvironmentHardwareConfig
|
|
290
|
+
if name == "EnvironmentSnapshotConfig":
|
|
291
|
+
return importlib.import_module(
|
|
292
|
+
"nshtrainer.util._environment_info"
|
|
293
|
+
).EnvironmentSnapshotConfig
|
|
260
294
|
if name == "EnvironmentClassInformationConfig":
|
|
261
295
|
return importlib.import_module(
|
|
262
296
|
"nshtrainer.util._environment_info"
|
|
@@ -269,14 +303,10 @@ else:
|
|
|
269
303
|
return importlib.import_module(
|
|
270
304
|
"nshtrainer.util._environment_info"
|
|
271
305
|
).EnvironmentCUDAConfig
|
|
272
|
-
if name == "
|
|
273
|
-
return importlib.import_module(
|
|
274
|
-
"nshtrainer.util._environment_info"
|
|
275
|
-
).EnvironmentGPUConfig
|
|
276
|
-
if name == "EnvironmentHardwareConfig":
|
|
306
|
+
if name == "EnvironmentSLURMInformationConfig":
|
|
277
307
|
return importlib.import_module(
|
|
278
308
|
"nshtrainer.util._environment_info"
|
|
279
|
-
).
|
|
309
|
+
).EnvironmentSLURMInformationConfig
|
|
280
310
|
if name == "EpochsConfig":
|
|
281
311
|
return importlib.import_module("nshtrainer.util.config").EpochsConfig
|
|
282
312
|
if name == "StepsConfig":
|
|
@@ -287,52 +317,56 @@ else:
|
|
|
287
317
|
return importlib.import_module(
|
|
288
318
|
"nshtrainer.trainer._config"
|
|
289
319
|
).CheckpointLoadingConfig
|
|
290
|
-
if name == "
|
|
320
|
+
if name == "SanityCheckingConfig":
|
|
291
321
|
return importlib.import_module(
|
|
292
322
|
"nshtrainer.trainer._config"
|
|
293
|
-
).
|
|
323
|
+
).SanityCheckingConfig
|
|
324
|
+
if name == "OnExceptionCheckpointCallbackConfig":
|
|
325
|
+
return importlib.import_module(
|
|
326
|
+
"nshtrainer.callbacks"
|
|
327
|
+
).OnExceptionCheckpointCallbackConfig
|
|
294
328
|
if name == "GradientClippingConfig":
|
|
295
329
|
return importlib.import_module(
|
|
296
330
|
"nshtrainer.trainer._config"
|
|
297
331
|
).GradientClippingConfig
|
|
298
|
-
if name == "
|
|
332
|
+
if name == "LoggingConfig":
|
|
333
|
+
return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
|
|
334
|
+
if name == "RLPSanityChecksCallbackConfig":
|
|
299
335
|
return importlib.import_module(
|
|
300
336
|
"nshtrainer.callbacks"
|
|
301
|
-
).
|
|
302
|
-
if name == "
|
|
337
|
+
).RLPSanityChecksCallbackConfig
|
|
338
|
+
if name == "CheckpointSavingConfig":
|
|
303
339
|
return importlib.import_module(
|
|
304
|
-
"nshtrainer.
|
|
305
|
-
).
|
|
306
|
-
if name == "RLPSanityChecksConfig":
|
|
307
|
-
return importlib.import_module("nshtrainer.callbacks").RLPSanityChecksConfig
|
|
308
|
-
if name == "EarlyStoppingConfig":
|
|
309
|
-
return importlib.import_module("nshtrainer.callbacks").EarlyStoppingConfig
|
|
340
|
+
"nshtrainer.trainer._config"
|
|
341
|
+
).CheckpointSavingConfig
|
|
310
342
|
if name == "DebugFlagCallbackConfig":
|
|
311
343
|
return importlib.import_module(
|
|
312
344
|
"nshtrainer.callbacks"
|
|
313
345
|
).DebugFlagCallbackConfig
|
|
314
|
-
if name == "
|
|
346
|
+
if name == "LastCheckpointCallbackConfig":
|
|
315
347
|
return importlib.import_module(
|
|
316
|
-
"nshtrainer.
|
|
317
|
-
).
|
|
318
|
-
if name == "
|
|
348
|
+
"nshtrainer.callbacks"
|
|
349
|
+
).LastCheckpointCallbackConfig
|
|
350
|
+
if name == "SharedParametersCallbackConfig":
|
|
319
351
|
return importlib.import_module(
|
|
320
352
|
"nshtrainer.callbacks"
|
|
321
|
-
).
|
|
322
|
-
if name == "
|
|
323
|
-
return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
|
|
324
|
-
if name == "SanityCheckingConfig":
|
|
353
|
+
).SharedParametersCallbackConfig
|
|
354
|
+
if name == "ReproducibilityConfig":
|
|
325
355
|
return importlib.import_module(
|
|
326
356
|
"nshtrainer.trainer._config"
|
|
327
|
-
).
|
|
328
|
-
if name == "
|
|
357
|
+
).ReproducibilityConfig
|
|
358
|
+
if name == "EarlyStoppingCallbackConfig":
|
|
329
359
|
return importlib.import_module(
|
|
330
360
|
"nshtrainer.callbacks"
|
|
331
|
-
).
|
|
332
|
-
if name == "
|
|
361
|
+
).EarlyStoppingCallbackConfig
|
|
362
|
+
if name == "OptimizationConfig":
|
|
333
363
|
return importlib.import_module(
|
|
334
364
|
"nshtrainer.trainer._config"
|
|
335
|
-
).
|
|
365
|
+
).OptimizationConfig
|
|
366
|
+
if name == "BestCheckpointCallbackConfig":
|
|
367
|
+
return importlib.import_module(
|
|
368
|
+
"nshtrainer.callbacks"
|
|
369
|
+
).BestCheckpointCallbackConfig
|
|
336
370
|
if name == "CheckpointMetadata":
|
|
337
371
|
return importlib.import_module(
|
|
338
372
|
"nshtrainer._checkpoint.loader"
|
|
@@ -349,28 +383,34 @@ else:
|
|
|
349
383
|
return importlib.import_module(
|
|
350
384
|
"nshtrainer._checkpoint.loader"
|
|
351
385
|
).UserProvidedPathCheckpointStrategyConfig
|
|
352
|
-
if name == "
|
|
386
|
+
if name == "PrintTableMetricsCallbackConfig":
|
|
353
387
|
return importlib.import_module(
|
|
354
388
|
"nshtrainer.callbacks"
|
|
355
|
-
).
|
|
389
|
+
).PrintTableMetricsCallbackConfig
|
|
356
390
|
if name == "ThroughputMonitorConfig":
|
|
357
391
|
return importlib.import_module(
|
|
358
392
|
"nshtrainer.callbacks"
|
|
359
393
|
).ThroughputMonitorConfig
|
|
360
|
-
if name == "
|
|
394
|
+
if name == "GradientSkippingCallbackConfig":
|
|
361
395
|
return importlib.import_module(
|
|
362
396
|
"nshtrainer.callbacks"
|
|
363
|
-
).
|
|
364
|
-
if name == "
|
|
365
|
-
return importlib.import_module("nshtrainer.callbacks").
|
|
397
|
+
).GradientSkippingCallbackConfig
|
|
398
|
+
if name == "EMACallbackConfig":
|
|
399
|
+
return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
|
|
366
400
|
if name == "ActSaveConfig":
|
|
367
401
|
return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
|
|
368
|
-
if name == "
|
|
369
|
-
return importlib.import_module(
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
if name == "
|
|
373
|
-
return importlib.import_module(
|
|
402
|
+
if name == "FiniteChecksCallbackConfig":
|
|
403
|
+
return importlib.import_module(
|
|
404
|
+
"nshtrainer.callbacks"
|
|
405
|
+
).FiniteChecksCallbackConfig
|
|
406
|
+
if name == "NormLoggingCallbackConfig":
|
|
407
|
+
return importlib.import_module(
|
|
408
|
+
"nshtrainer.callbacks"
|
|
409
|
+
).NormLoggingCallbackConfig
|
|
410
|
+
if name == "EpochTimerCallbackConfig":
|
|
411
|
+
return importlib.import_module(
|
|
412
|
+
"nshtrainer.callbacks"
|
|
413
|
+
).EpochTimerCallbackConfig
|
|
374
414
|
if name == "BaseCheckpointCallbackConfig":
|
|
375
415
|
return importlib.import_module(
|
|
376
416
|
"nshtrainer.callbacks.checkpoint._base"
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -29,24 +31,24 @@ else:
|
|
|
29
31
|
|
|
30
32
|
if name in globals():
|
|
31
33
|
return globals()[name]
|
|
32
|
-
if name == "CheckpointLoadingConfig":
|
|
33
|
-
return importlib.import_module(
|
|
34
|
-
"nshtrainer._checkpoint.loader"
|
|
35
|
-
).CheckpointLoadingConfig
|
|
36
|
-
if name == "CheckpointMetadata":
|
|
37
|
-
return importlib.import_module(
|
|
38
|
-
"nshtrainer._checkpoint.loader"
|
|
39
|
-
).CheckpointMetadata
|
|
40
34
|
if name == "MetricConfig":
|
|
41
35
|
return importlib.import_module("nshtrainer._checkpoint.loader").MetricConfig
|
|
42
36
|
if name == "BestCheckpointStrategyConfig":
|
|
43
37
|
return importlib.import_module(
|
|
44
38
|
"nshtrainer._checkpoint.loader"
|
|
45
39
|
).BestCheckpointStrategyConfig
|
|
40
|
+
if name == "CheckpointMetadata":
|
|
41
|
+
return importlib.import_module(
|
|
42
|
+
"nshtrainer._checkpoint.loader"
|
|
43
|
+
).CheckpointMetadata
|
|
46
44
|
if name == "LastCheckpointStrategyConfig":
|
|
47
45
|
return importlib.import_module(
|
|
48
46
|
"nshtrainer._checkpoint.loader"
|
|
49
47
|
).LastCheckpointStrategyConfig
|
|
48
|
+
if name == "CheckpointLoadingConfig":
|
|
49
|
+
return importlib.import_module(
|
|
50
|
+
"nshtrainer._checkpoint.loader"
|
|
51
|
+
).CheckpointLoadingConfig
|
|
50
52
|
if name == "UserProvidedPathCheckpointStrategyConfig":
|
|
51
53
|
return importlib.import_module(
|
|
52
54
|
"nshtrainer._checkpoint.loader"
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -14,14 +16,14 @@ else:
|
|
|
14
16
|
|
|
15
17
|
if name in globals():
|
|
16
18
|
return globals()[name]
|
|
17
|
-
if name == "EnvironmentConfig":
|
|
18
|
-
return importlib.import_module(
|
|
19
|
-
"nshtrainer._checkpoint.metadata"
|
|
20
|
-
).EnvironmentConfig
|
|
21
19
|
if name == "CheckpointMetadata":
|
|
22
20
|
return importlib.import_module(
|
|
23
21
|
"nshtrainer._checkpoint.metadata"
|
|
24
22
|
).CheckpointMetadata
|
|
23
|
+
if name == "EnvironmentConfig":
|
|
24
|
+
return importlib.import_module(
|
|
25
|
+
"nshtrainer._checkpoint.metadata"
|
|
26
|
+
).EnvironmentConfig
|
|
25
27
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
26
28
|
|
|
27
29
|
# Submodule exports
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -6,7 +8,9 @@ from typing import TYPE_CHECKING
|
|
|
6
8
|
|
|
7
9
|
if TYPE_CHECKING:
|
|
8
10
|
from nshtrainer._directory import DirectoryConfig as DirectoryConfig
|
|
9
|
-
from nshtrainer._directory import
|
|
11
|
+
from nshtrainer._directory import (
|
|
12
|
+
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
|
13
|
+
)
|
|
10
14
|
from nshtrainer._directory import LoggerConfig as LoggerConfig
|
|
11
15
|
else:
|
|
12
16
|
|
|
@@ -15,10 +19,12 @@ else:
|
|
|
15
19
|
|
|
16
20
|
if name in globals():
|
|
17
21
|
return globals()[name]
|
|
22
|
+
if name == "DirectorySetupCallbackConfig":
|
|
23
|
+
return importlib.import_module(
|
|
24
|
+
"nshtrainer._directory"
|
|
25
|
+
).DirectorySetupCallbackConfig
|
|
18
26
|
if name == "DirectoryConfig":
|
|
19
27
|
return importlib.import_module("nshtrainer._directory").DirectoryConfig
|
|
20
|
-
if name == "DirectorySetupConfig":
|
|
21
|
-
return importlib.import_module("nshtrainer._directory").DirectorySetupConfig
|
|
22
28
|
if name == "LoggerConfig":
|
|
23
29
|
return importlib.import_module("nshtrainer._directory").LoggerConfig
|
|
24
30
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -17,14 +19,14 @@ else:
|
|
|
17
19
|
|
|
18
20
|
if name in globals():
|
|
19
21
|
return globals()[name]
|
|
20
|
-
if name == "CallbackConfigBase":
|
|
21
|
-
return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
|
|
22
|
-
if name == "HuggingFaceHubConfig":
|
|
23
|
-
return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
|
|
24
22
|
if name == "HuggingFaceHubAutoCreateConfig":
|
|
25
23
|
return importlib.import_module(
|
|
26
24
|
"nshtrainer._hf_hub"
|
|
27
25
|
).HuggingFaceHubAutoCreateConfig
|
|
26
|
+
if name == "HuggingFaceHubConfig":
|
|
27
|
+
return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
|
|
28
|
+
if name == "CallbackConfigBase":
|
|
29
|
+
return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
|
|
28
30
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
29
31
|
|
|
30
32
|
# Submodule exports
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -11,25 +13,47 @@ if TYPE_CHECKING:
|
|
|
11
13
|
from nshtrainer.callbacks import CallbackConfig as CallbackConfig
|
|
12
14
|
from nshtrainer.callbacks import CallbackConfigBase as CallbackConfigBase
|
|
13
15
|
from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
|
14
|
-
from nshtrainer.callbacks import
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
from nshtrainer.callbacks import
|
|
18
|
-
|
|
19
|
-
|
|
16
|
+
from nshtrainer.callbacks import (
|
|
17
|
+
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
|
18
|
+
)
|
|
19
|
+
from nshtrainer.callbacks import (
|
|
20
|
+
EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
|
|
21
|
+
)
|
|
22
|
+
from nshtrainer.callbacks import EMACallbackConfig as EMACallbackConfig
|
|
23
|
+
from nshtrainer.callbacks import (
|
|
24
|
+
EpochTimerCallbackConfig as EpochTimerCallbackConfig,
|
|
25
|
+
)
|
|
26
|
+
from nshtrainer.callbacks import (
|
|
27
|
+
FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
|
|
28
|
+
)
|
|
29
|
+
from nshtrainer.callbacks import (
|
|
30
|
+
GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
|
|
31
|
+
)
|
|
20
32
|
from nshtrainer.callbacks import (
|
|
21
33
|
LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
|
|
22
34
|
)
|
|
23
|
-
from nshtrainer.callbacks import
|
|
35
|
+
from nshtrainer.callbacks import (
|
|
36
|
+
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
|
37
|
+
)
|
|
24
38
|
from nshtrainer.callbacks import (
|
|
25
39
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
|
26
40
|
)
|
|
27
|
-
from nshtrainer.callbacks import
|
|
28
|
-
|
|
29
|
-
|
|
41
|
+
from nshtrainer.callbacks import (
|
|
42
|
+
PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
|
|
43
|
+
)
|
|
44
|
+
from nshtrainer.callbacks import (
|
|
45
|
+
RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
|
|
46
|
+
)
|
|
47
|
+
from nshtrainer.callbacks import (
|
|
48
|
+
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
|
49
|
+
)
|
|
30
50
|
from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
|
|
31
|
-
from nshtrainer.callbacks import
|
|
32
|
-
|
|
51
|
+
from nshtrainer.callbacks import (
|
|
52
|
+
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
|
53
|
+
)
|
|
54
|
+
from nshtrainer.callbacks import (
|
|
55
|
+
WandbWatchCallbackConfig as WandbWatchCallbackConfig,
|
|
56
|
+
)
|
|
33
57
|
from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
|
|
34
58
|
from nshtrainer.callbacks.checkpoint._base import (
|
|
35
59
|
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
|
@@ -45,12 +69,12 @@ else:
|
|
|
45
69
|
|
|
46
70
|
if name in globals():
|
|
47
71
|
return globals()[name]
|
|
48
|
-
if name == "
|
|
49
|
-
return importlib.import_module("nshtrainer.callbacks").CallbackConfigBase
|
|
50
|
-
if name == "PrintTableMetricsConfig":
|
|
72
|
+
if name == "PrintTableMetricsCallbackConfig":
|
|
51
73
|
return importlib.import_module(
|
|
52
74
|
"nshtrainer.callbacks"
|
|
53
|
-
).
|
|
75
|
+
).PrintTableMetricsCallbackConfig
|
|
76
|
+
if name == "CallbackConfigBase":
|
|
77
|
+
return importlib.import_module("nshtrainer.callbacks").CallbackConfigBase
|
|
54
78
|
if name == "DebugFlagCallbackConfig":
|
|
55
79
|
return importlib.import_module(
|
|
56
80
|
"nshtrainer.callbacks"
|
|
@@ -59,50 +83,66 @@ else:
|
|
|
59
83
|
return importlib.import_module(
|
|
60
84
|
"nshtrainer.callbacks"
|
|
61
85
|
).ThroughputMonitorConfig
|
|
62
|
-
if name == "
|
|
86
|
+
if name == "GradientSkippingCallbackConfig":
|
|
87
|
+
return importlib.import_module(
|
|
88
|
+
"nshtrainer.callbacks"
|
|
89
|
+
).GradientSkippingCallbackConfig
|
|
90
|
+
if name == "RLPSanityChecksCallbackConfig":
|
|
63
91
|
return importlib.import_module(
|
|
64
92
|
"nshtrainer.callbacks"
|
|
65
|
-
).
|
|
66
|
-
if name == "
|
|
67
|
-
return importlib.import_module(
|
|
68
|
-
|
|
69
|
-
|
|
93
|
+
).RLPSanityChecksCallbackConfig
|
|
94
|
+
if name == "WandbUploadCodeCallbackConfig":
|
|
95
|
+
return importlib.import_module(
|
|
96
|
+
"nshtrainer.callbacks"
|
|
97
|
+
).WandbUploadCodeCallbackConfig
|
|
70
98
|
if name == "MetricConfig":
|
|
71
99
|
return importlib.import_module(
|
|
72
100
|
"nshtrainer.callbacks.early_stopping"
|
|
73
101
|
).MetricConfig
|
|
74
|
-
if name == "
|
|
75
|
-
return importlib.import_module(
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
if name == "
|
|
79
|
-
return importlib.import_module(
|
|
80
|
-
|
|
81
|
-
|
|
102
|
+
if name == "EarlyStoppingCallbackConfig":
|
|
103
|
+
return importlib.import_module(
|
|
104
|
+
"nshtrainer.callbacks"
|
|
105
|
+
).EarlyStoppingCallbackConfig
|
|
106
|
+
if name == "WandbWatchCallbackConfig":
|
|
107
|
+
return importlib.import_module(
|
|
108
|
+
"nshtrainer.callbacks"
|
|
109
|
+
).WandbWatchCallbackConfig
|
|
110
|
+
if name == "EMACallbackConfig":
|
|
111
|
+
return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
|
|
112
|
+
if name == "DirectorySetupCallbackConfig":
|
|
113
|
+
return importlib.import_module(
|
|
114
|
+
"nshtrainer.callbacks"
|
|
115
|
+
).DirectorySetupCallbackConfig
|
|
82
116
|
if name == "ActSaveConfig":
|
|
83
117
|
return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
|
|
84
|
-
if name == "
|
|
85
|
-
return importlib.import_module("nshtrainer.callbacks").FiniteChecksConfig
|
|
86
|
-
if name == "NormLoggingConfig":
|
|
87
|
-
return importlib.import_module("nshtrainer.callbacks").NormLoggingConfig
|
|
88
|
-
if name == "LastCheckpointCallbackConfig":
|
|
118
|
+
if name == "FiniteChecksCallbackConfig":
|
|
89
119
|
return importlib.import_module(
|
|
90
120
|
"nshtrainer.callbacks"
|
|
91
|
-
).
|
|
92
|
-
if name == "
|
|
121
|
+
).FiniteChecksCallbackConfig
|
|
122
|
+
if name == "NormLoggingCallbackConfig":
|
|
93
123
|
return importlib.import_module(
|
|
94
124
|
"nshtrainer.callbacks"
|
|
95
|
-
).
|
|
125
|
+
).NormLoggingCallbackConfig
|
|
126
|
+
if name == "EpochTimerCallbackConfig":
|
|
127
|
+
return importlib.import_module(
|
|
128
|
+
"nshtrainer.callbacks"
|
|
129
|
+
).EpochTimerCallbackConfig
|
|
96
130
|
if name == "OnExceptionCheckpointCallbackConfig":
|
|
97
131
|
return importlib.import_module(
|
|
98
132
|
"nshtrainer.callbacks"
|
|
99
133
|
).OnExceptionCheckpointCallbackConfig
|
|
100
|
-
if name == "
|
|
101
|
-
return importlib.import_module(
|
|
102
|
-
|
|
134
|
+
if name == "LastCheckpointCallbackConfig":
|
|
135
|
+
return importlib.import_module(
|
|
136
|
+
"nshtrainer.callbacks"
|
|
137
|
+
).LastCheckpointCallbackConfig
|
|
138
|
+
if name == "SharedParametersCallbackConfig":
|
|
139
|
+
return importlib.import_module(
|
|
140
|
+
"nshtrainer.callbacks"
|
|
141
|
+
).SharedParametersCallbackConfig
|
|
142
|
+
if name == "BestCheckpointCallbackConfig":
|
|
103
143
|
return importlib.import_module(
|
|
104
144
|
"nshtrainer.callbacks"
|
|
105
|
-
).
|
|
145
|
+
).BestCheckpointCallbackConfig
|
|
106
146
|
if name == "CheckpointMetadata":
|
|
107
147
|
return importlib.import_module(
|
|
108
148
|
"nshtrainer.callbacks.checkpoint._base"
|