nshtrainer 0.42.0__py3-none-any.whl → 0.44.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -0
- nshtrainer/_callback.py +2 -0
- nshtrainer/_checkpoint/loader.py +2 -0
- nshtrainer/_checkpoint/metadata.py +2 -0
- nshtrainer/_checkpoint/saver.py +2 -0
- nshtrainer/_directory.py +4 -2
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_hf_hub.py +2 -0
- nshtrainer/callbacks/__init__.py +45 -29
- nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
- nshtrainer/callbacks/actsave.py +2 -0
- nshtrainer/callbacks/base.py +2 -0
- nshtrainer/callbacks/checkpoint/__init__.py +6 -2
- nshtrainer/callbacks/checkpoint/_base.py +2 -0
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
- nshtrainer/callbacks/debug_flag.py +2 -0
- nshtrainer/callbacks/directory_setup.py +4 -2
- nshtrainer/callbacks/early_stopping.py +6 -4
- nshtrainer/callbacks/ema.py +5 -3
- nshtrainer/callbacks/finite_checks.py +3 -1
- nshtrainer/callbacks/gradient_skipping.py +6 -4
- nshtrainer/callbacks/interval.py +2 -0
- nshtrainer/callbacks/log_epoch.py +13 -1
- nshtrainer/callbacks/norm_logging.py +4 -2
- nshtrainer/callbacks/print_table.py +3 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
- nshtrainer/callbacks/shared_parameters.py +4 -2
- nshtrainer/callbacks/throughput_monitor.py +2 -0
- nshtrainer/callbacks/timer.py +5 -3
- nshtrainer/callbacks/wandb_upload_code.py +4 -2
- nshtrainer/callbacks/wandb_watch.py +4 -2
- nshtrainer/config/__init__.py +130 -90
- nshtrainer/config/_checkpoint/loader/__init__.py +10 -8
- nshtrainer/config/_checkpoint/metadata/__init__.py +6 -4
- nshtrainer/config/_directory/__init__.py +9 -3
- nshtrainer/config/_hf_hub/__init__.py +6 -4
- nshtrainer/config/callbacks/__init__.py +82 -42
- nshtrainer/config/callbacks/actsave/__init__.py +4 -2
- nshtrainer/config/callbacks/base/__init__.py +2 -0
- nshtrainer/config/callbacks/checkpoint/__init__.py +6 -4
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +6 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +2 -0
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +6 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +6 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +7 -5
- nshtrainer/config/callbacks/early_stopping/__init__.py +9 -7
- nshtrainer/config/callbacks/ema/__init__.py +5 -3
- nshtrainer/config/callbacks/finite_checks/__init__.py +7 -5
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +7 -5
- nshtrainer/config/callbacks/norm_logging/__init__.py +9 -5
- nshtrainer/config/callbacks/print_table/__init__.py +7 -5
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +7 -5
- nshtrainer/config/callbacks/shared_parameters/__init__.py +7 -5
- nshtrainer/config/callbacks/throughput_monitor/__init__.py +6 -4
- nshtrainer/config/callbacks/timer/__init__.py +9 -5
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +7 -5
- nshtrainer/config/callbacks/wandb_watch/__init__.py +9 -5
- nshtrainer/config/loggers/__init__.py +18 -10
- nshtrainer/config/loggers/_base/__init__.py +2 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -0
- nshtrainer/config/loggers/tensorboard/__init__.py +2 -0
- nshtrainer/config/loggers/wandb/__init__.py +18 -10
- nshtrainer/config/lr_scheduler/__init__.py +2 -0
- nshtrainer/config/lr_scheduler/_base/__init__.py +2 -0
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +2 -0
- nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -4
- nshtrainer/config/metrics/__init__.py +2 -0
- nshtrainer/config/metrics/_config/__init__.py +2 -0
- nshtrainer/config/model/__init__.py +8 -6
- nshtrainer/config/model/base/__init__.py +4 -2
- nshtrainer/config/model/config/__init__.py +8 -6
- nshtrainer/config/model/mixins/logger/__init__.py +2 -0
- nshtrainer/config/nn/__init__.py +16 -14
- nshtrainer/config/nn/mlp/__init__.py +2 -0
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -24
- nshtrainer/config/optimizer/__init__.py +2 -0
- nshtrainer/config/profiler/__init__.py +2 -0
- nshtrainer/config/profiler/_base/__init__.py +2 -0
- nshtrainer/config/profiler/advanced/__init__.py +6 -4
- nshtrainer/config/profiler/pytorch/__init__.py +6 -4
- nshtrainer/config/profiler/simple/__init__.py +6 -4
- nshtrainer/config/runner/__init__.py +2 -0
- nshtrainer/config/trainer/_config/__init__.py +43 -39
- nshtrainer/config/trainer/checkpoint_connector/__init__.py +2 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -18
- nshtrainer/config/util/config/__init__.py +2 -0
- nshtrainer/config/util/config/dtype/__init__.py +2 -0
- nshtrainer/config/util/config/duration/__init__.py +2 -0
- nshtrainer/data/__init__.py +2 -0
- nshtrainer/data/balanced_batch_sampler.py +2 -0
- nshtrainer/data/datamodule.py +2 -0
- nshtrainer/data/transform.py +2 -0
- nshtrainer/ll/__init__.py +2 -0
- nshtrainer/ll/_experimental.py +2 -0
- nshtrainer/ll/actsave.py +2 -0
- nshtrainer/ll/callbacks.py +2 -0
- nshtrainer/ll/config.py +2 -0
- nshtrainer/ll/data.py +2 -0
- nshtrainer/ll/log.py +2 -0
- nshtrainer/ll/lr_scheduler.py +2 -0
- nshtrainer/ll/model.py +2 -0
- nshtrainer/ll/nn.py +2 -0
- nshtrainer/ll/optimizer.py +2 -0
- nshtrainer/ll/runner.py +2 -0
- nshtrainer/ll/snapshot.py +2 -0
- nshtrainer/ll/snoop.py +2 -0
- nshtrainer/ll/trainer.py +2 -0
- nshtrainer/ll/typecheck.py +2 -0
- nshtrainer/ll/util.py +2 -0
- nshtrainer/loggers/__init__.py +2 -0
- nshtrainer/loggers/_base.py +2 -0
- nshtrainer/loggers/csv.py +2 -0
- nshtrainer/loggers/tensorboard.py +2 -0
- nshtrainer/loggers/wandb.py +6 -4
- nshtrainer/lr_scheduler/__init__.py +2 -0
- nshtrainer/lr_scheduler/_base.py +8 -11
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +18 -17
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +8 -6
- nshtrainer/metrics/__init__.py +2 -0
- nshtrainer/metrics/_config.py +2 -0
- nshtrainer/model/__init__.py +2 -0
- nshtrainer/model/base.py +2 -0
- nshtrainer/model/config.py +2 -0
- nshtrainer/model/mixins/callback.py +2 -0
- nshtrainer/model/mixins/logger.py +2 -0
- nshtrainer/nn/__init__.py +2 -0
- nshtrainer/nn/mlp.py +2 -0
- nshtrainer/nn/module_dict.py +2 -0
- nshtrainer/nn/module_list.py +2 -0
- nshtrainer/nn/nonlinearity.py +2 -0
- nshtrainer/optimizer.py +2 -0
- nshtrainer/profiler/__init__.py +2 -0
- nshtrainer/profiler/_base.py +2 -0
- nshtrainer/profiler/advanced.py +2 -0
- nshtrainer/profiler/pytorch.py +2 -0
- nshtrainer/profiler/simple.py +2 -0
- nshtrainer/runner.py +2 -0
- nshtrainer/scripts/find_packages.py +2 -0
- nshtrainer/trainer/__init__.py +2 -0
- nshtrainer/trainer/_config.py +16 -13
- nshtrainer/trainer/_runtime_callback.py +2 -0
- nshtrainer/trainer/checkpoint_connector.py +2 -0
- nshtrainer/trainer/signal_connector.py +2 -0
- nshtrainer/trainer/trainer.py +2 -0
- nshtrainer/util/_environment_info.py +2 -0
- nshtrainer/util/bf16.py +2 -0
- nshtrainer/util/config/__init__.py +2 -0
- nshtrainer/util/config/dtype.py +2 -0
- nshtrainer/util/config/duration.py +2 -0
- nshtrainer/util/environment.py +2 -0
- nshtrainer/util/path.py +2 -0
- nshtrainer/util/seed.py +2 -0
- nshtrainer/util/slurm.py +3 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +2 -0
- {nshtrainer-0.42.0.dist-info → nshtrainer-0.44.0.dist-info}/METADATA +1 -1
- nshtrainer-0.44.0.dist-info/RECORD +162 -0
- nshtrainer-0.42.0.dist-info/RECORD +0 -162
- {nshtrainer-0.42.0.dist-info → nshtrainer-0.44.0.dist-info}/WHEEL +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -8,8 +10,12 @@ if TYPE_CHECKING:
|
|
|
8
10
|
from nshtrainer.loggers.wandb import BaseLoggerConfig as BaseLoggerConfig
|
|
9
11
|
from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
|
|
10
12
|
from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
|
|
11
|
-
from nshtrainer.loggers.wandb import
|
|
12
|
-
|
|
13
|
+
from nshtrainer.loggers.wandb import (
|
|
14
|
+
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
|
15
|
+
)
|
|
16
|
+
from nshtrainer.loggers.wandb import (
|
|
17
|
+
WandbWatchCallbackConfig as WandbWatchCallbackConfig,
|
|
18
|
+
)
|
|
13
19
|
else:
|
|
14
20
|
|
|
15
21
|
def __getattr__(name):
|
|
@@ -17,20 +23,22 @@ else:
|
|
|
17
23
|
|
|
18
24
|
if name in globals():
|
|
19
25
|
return globals()[name]
|
|
20
|
-
if name == "CallbackConfigBase":
|
|
21
|
-
return importlib.import_module(
|
|
22
|
-
"nshtrainer.loggers.wandb"
|
|
23
|
-
).CallbackConfigBase
|
|
24
26
|
if name == "WandbLoggerConfig":
|
|
25
27
|
return importlib.import_module("nshtrainer.loggers.wandb").WandbLoggerConfig
|
|
26
|
-
if name == "
|
|
28
|
+
if name == "WandbUploadCodeCallbackConfig":
|
|
27
29
|
return importlib.import_module(
|
|
28
30
|
"nshtrainer.loggers.wandb"
|
|
29
|
-
).
|
|
30
|
-
if name == "
|
|
31
|
-
return importlib.import_module(
|
|
31
|
+
).WandbUploadCodeCallbackConfig
|
|
32
|
+
if name == "WandbWatchCallbackConfig":
|
|
33
|
+
return importlib.import_module(
|
|
34
|
+
"nshtrainer.loggers.wandb"
|
|
35
|
+
).WandbWatchCallbackConfig
|
|
32
36
|
if name == "BaseLoggerConfig":
|
|
33
37
|
return importlib.import_module("nshtrainer.loggers.wandb").BaseLoggerConfig
|
|
38
|
+
if name == "CallbackConfigBase":
|
|
39
|
+
return importlib.import_module(
|
|
40
|
+
"nshtrainer.loggers.wandb"
|
|
41
|
+
).CallbackConfigBase
|
|
34
42
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
35
43
|
|
|
36
44
|
# Submodule exports
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -21,6 +23,10 @@ else:
|
|
|
21
23
|
|
|
22
24
|
if name in globals():
|
|
23
25
|
return globals()[name]
|
|
26
|
+
if name == "LRSchedulerConfigBase":
|
|
27
|
+
return importlib.import_module(
|
|
28
|
+
"nshtrainer.lr_scheduler.reduce_lr_on_plateau"
|
|
29
|
+
).LRSchedulerConfigBase
|
|
24
30
|
if name == "MetricConfig":
|
|
25
31
|
return importlib.import_module(
|
|
26
32
|
"nshtrainer.lr_scheduler.reduce_lr_on_plateau"
|
|
@@ -29,10 +35,6 @@ else:
|
|
|
29
35
|
return importlib.import_module(
|
|
30
36
|
"nshtrainer.lr_scheduler.reduce_lr_on_plateau"
|
|
31
37
|
).ReduceLROnPlateauConfig
|
|
32
|
-
if name == "LRSchedulerConfigBase":
|
|
33
|
-
return importlib.import_module(
|
|
34
|
-
"nshtrainer.lr_scheduler.reduce_lr_on_plateau"
|
|
35
|
-
).LRSchedulerConfigBase
|
|
36
38
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
37
39
|
|
|
38
40
|
# Submodule exports
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -18,18 +20,18 @@ else:
|
|
|
18
20
|
|
|
19
21
|
if name in globals():
|
|
20
22
|
return globals()[name]
|
|
21
|
-
if name == "
|
|
22
|
-
return importlib.import_module("nshtrainer.model
|
|
23
|
+
if name == "MetricConfig":
|
|
24
|
+
return importlib.import_module("nshtrainer.model").MetricConfig
|
|
23
25
|
if name == "TrainerConfig":
|
|
24
26
|
return importlib.import_module("nshtrainer.model").TrainerConfig
|
|
25
|
-
if name == "EnvironmentConfig":
|
|
26
|
-
return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
|
|
27
27
|
if name == "BaseConfig":
|
|
28
28
|
return importlib.import_module("nshtrainer.model").BaseConfig
|
|
29
|
-
if name == "
|
|
30
|
-
return importlib.import_module("nshtrainer.model").
|
|
29
|
+
if name == "EnvironmentConfig":
|
|
30
|
+
return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
|
|
31
31
|
if name == "DirectoryConfig":
|
|
32
32
|
return importlib.import_module("nshtrainer.model").DirectoryConfig
|
|
33
|
+
if name == "CallbackConfigBase":
|
|
34
|
+
return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
|
|
33
35
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
34
36
|
|
|
35
37
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -14,10 +16,10 @@ else:
|
|
|
14
16
|
|
|
15
17
|
if name in globals():
|
|
16
18
|
return globals()[name]
|
|
17
|
-
if name == "EnvironmentConfig":
|
|
18
|
-
return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
|
|
19
19
|
if name == "BaseConfig":
|
|
20
20
|
return importlib.import_module("nshtrainer.model.base").BaseConfig
|
|
21
|
+
if name == "EnvironmentConfig":
|
|
22
|
+
return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
|
|
21
23
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
22
24
|
|
|
23
25
|
# Submodule exports
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -18,18 +20,18 @@ else:
|
|
|
18
20
|
|
|
19
21
|
if name in globals():
|
|
20
22
|
return globals()[name]
|
|
21
|
-
if name == "
|
|
22
|
-
return importlib.import_module("nshtrainer.model.config").
|
|
23
|
+
if name == "MetricConfig":
|
|
24
|
+
return importlib.import_module("nshtrainer.model.config").MetricConfig
|
|
23
25
|
if name == "TrainerConfig":
|
|
24
26
|
return importlib.import_module("nshtrainer.model.config").TrainerConfig
|
|
25
|
-
if name == "EnvironmentConfig":
|
|
26
|
-
return importlib.import_module("nshtrainer.model.config").EnvironmentConfig
|
|
27
27
|
if name == "BaseConfig":
|
|
28
28
|
return importlib.import_module("nshtrainer.model.config").BaseConfig
|
|
29
|
-
if name == "
|
|
30
|
-
return importlib.import_module("nshtrainer.model.config").
|
|
29
|
+
if name == "EnvironmentConfig":
|
|
30
|
+
return importlib.import_module("nshtrainer.model.config").EnvironmentConfig
|
|
31
31
|
if name == "DirectoryConfig":
|
|
32
32
|
return importlib.import_module("nshtrainer.model.config").DirectoryConfig
|
|
33
|
+
if name == "CallbackConfigBase":
|
|
34
|
+
return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
|
|
33
35
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
34
36
|
|
|
35
37
|
# Submodule exports
|
nshtrainer/config/nn/__init__.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -35,36 +37,36 @@ else:
|
|
|
35
37
|
return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
|
|
36
38
|
if name == "MLPConfig":
|
|
37
39
|
return importlib.import_module("nshtrainer.nn").MLPConfig
|
|
40
|
+
if name == "PReLUConfig":
|
|
41
|
+
return importlib.import_module("nshtrainer.nn").PReLUConfig
|
|
42
|
+
if name == "LeakyReLUNonlinearityConfig":
|
|
43
|
+
return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
|
|
38
44
|
if name == "SwiGLUNonlinearityConfig":
|
|
39
45
|
return importlib.import_module(
|
|
40
46
|
"nshtrainer.nn.nonlinearity"
|
|
41
47
|
).SwiGLUNonlinearityConfig
|
|
42
|
-
if name == "
|
|
43
|
-
return importlib.import_module("nshtrainer.nn").
|
|
48
|
+
if name == "SoftsignNonlinearityConfig":
|
|
49
|
+
return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
|
|
44
50
|
if name == "SiLUNonlinearityConfig":
|
|
45
51
|
return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
|
|
52
|
+
if name == "SigmoidNonlinearityConfig":
|
|
53
|
+
return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
|
|
54
|
+
if name == "SoftplusNonlinearityConfig":
|
|
55
|
+
return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
|
|
46
56
|
if name == "ELUNonlinearityConfig":
|
|
47
57
|
return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
|
|
58
|
+
if name == "SoftmaxNonlinearityConfig":
|
|
59
|
+
return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
|
|
48
60
|
if name == "GELUNonlinearityConfig":
|
|
49
61
|
return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
|
|
50
|
-
if name == "SoftplusNonlinearityConfig":
|
|
51
|
-
return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
|
|
52
|
-
if name == "SoftsignNonlinearityConfig":
|
|
53
|
-
return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
|
|
54
62
|
if name == "SwishNonlinearityConfig":
|
|
55
63
|
return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
|
|
56
|
-
if name == "SoftmaxNonlinearityConfig":
|
|
57
|
-
return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
|
|
58
64
|
if name == "MishNonlinearityConfig":
|
|
59
65
|
return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
|
|
60
|
-
if name == "SigmoidNonlinearityConfig":
|
|
61
|
-
return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
|
|
62
66
|
if name == "TanhNonlinearityConfig":
|
|
63
67
|
return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
|
|
64
|
-
if name == "
|
|
65
|
-
return importlib.import_module("nshtrainer.nn").
|
|
66
|
-
if name == "LeakyReLUNonlinearityConfig":
|
|
67
|
-
return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
|
|
68
|
+
if name == "ReLUNonlinearityConfig":
|
|
69
|
+
return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
|
|
68
70
|
if name == "NonlinearityConfig":
|
|
69
71
|
return importlib.import_module("nshtrainer.nn").NonlinearityConfig
|
|
70
72
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -56,64 +58,64 @@ else:
|
|
|
56
58
|
|
|
57
59
|
if name in globals():
|
|
58
60
|
return globals()[name]
|
|
61
|
+
if name == "PReLUConfig":
|
|
62
|
+
return importlib.import_module("nshtrainer.nn.nonlinearity").PReLUConfig
|
|
63
|
+
if name == "LeakyReLUNonlinearityConfig":
|
|
64
|
+
return importlib.import_module(
|
|
65
|
+
"nshtrainer.nn.nonlinearity"
|
|
66
|
+
).LeakyReLUNonlinearityConfig
|
|
59
67
|
if name == "SwiGLUNonlinearityConfig":
|
|
60
68
|
return importlib.import_module(
|
|
61
69
|
"nshtrainer.nn.nonlinearity"
|
|
62
70
|
).SwiGLUNonlinearityConfig
|
|
63
|
-
if name == "
|
|
71
|
+
if name == "SoftsignNonlinearityConfig":
|
|
64
72
|
return importlib.import_module(
|
|
65
73
|
"nshtrainer.nn.nonlinearity"
|
|
66
|
-
).
|
|
74
|
+
).SoftsignNonlinearityConfig
|
|
67
75
|
if name == "SiLUNonlinearityConfig":
|
|
68
76
|
return importlib.import_module(
|
|
69
77
|
"nshtrainer.nn.nonlinearity"
|
|
70
78
|
).SiLUNonlinearityConfig
|
|
71
|
-
if name == "
|
|
72
|
-
return importlib.import_module(
|
|
73
|
-
"nshtrainer.nn.nonlinearity"
|
|
74
|
-
).ELUNonlinearityConfig
|
|
75
|
-
if name == "GELUNonlinearityConfig":
|
|
79
|
+
if name == "SigmoidNonlinearityConfig":
|
|
76
80
|
return importlib.import_module(
|
|
77
81
|
"nshtrainer.nn.nonlinearity"
|
|
78
|
-
).
|
|
82
|
+
).SigmoidNonlinearityConfig
|
|
79
83
|
if name == "SoftplusNonlinearityConfig":
|
|
80
84
|
return importlib.import_module(
|
|
81
85
|
"nshtrainer.nn.nonlinearity"
|
|
82
86
|
).SoftplusNonlinearityConfig
|
|
83
|
-
if name == "
|
|
84
|
-
return importlib.import_module(
|
|
85
|
-
"nshtrainer.nn.nonlinearity"
|
|
86
|
-
).SoftsignNonlinearityConfig
|
|
87
|
-
if name == "SwishNonlinearityConfig":
|
|
87
|
+
if name == "ELUNonlinearityConfig":
|
|
88
88
|
return importlib.import_module(
|
|
89
89
|
"nshtrainer.nn.nonlinearity"
|
|
90
|
-
).
|
|
90
|
+
).ELUNonlinearityConfig
|
|
91
91
|
if name == "SoftmaxNonlinearityConfig":
|
|
92
92
|
return importlib.import_module(
|
|
93
93
|
"nshtrainer.nn.nonlinearity"
|
|
94
94
|
).SoftmaxNonlinearityConfig
|
|
95
|
-
if name == "
|
|
95
|
+
if name == "GELUNonlinearityConfig":
|
|
96
96
|
return importlib.import_module(
|
|
97
97
|
"nshtrainer.nn.nonlinearity"
|
|
98
|
-
).
|
|
99
|
-
if name == "
|
|
98
|
+
).GELUNonlinearityConfig
|
|
99
|
+
if name == "SwishNonlinearityConfig":
|
|
100
100
|
return importlib.import_module(
|
|
101
101
|
"nshtrainer.nn.nonlinearity"
|
|
102
|
-
).
|
|
102
|
+
).SwishNonlinearityConfig
|
|
103
|
+
if name == "MishNonlinearityConfig":
|
|
104
|
+
return importlib.import_module(
|
|
105
|
+
"nshtrainer.nn.nonlinearity"
|
|
106
|
+
).MishNonlinearityConfig
|
|
103
107
|
if name == "TanhNonlinearityConfig":
|
|
104
108
|
return importlib.import_module(
|
|
105
109
|
"nshtrainer.nn.nonlinearity"
|
|
106
110
|
).TanhNonlinearityConfig
|
|
107
|
-
if name == "
|
|
111
|
+
if name == "ReLUNonlinearityConfig":
|
|
108
112
|
return importlib.import_module(
|
|
109
113
|
"nshtrainer.nn.nonlinearity"
|
|
110
|
-
).
|
|
111
|
-
if name == "
|
|
112
|
-
return importlib.import_module("nshtrainer.nn.nonlinearity").PReLUConfig
|
|
113
|
-
if name == "LeakyReLUNonlinearityConfig":
|
|
114
|
+
).ReLUNonlinearityConfig
|
|
115
|
+
if name == "BaseNonlinearityConfig":
|
|
114
116
|
return importlib.import_module(
|
|
115
117
|
"nshtrainer.nn.nonlinearity"
|
|
116
|
-
).
|
|
118
|
+
).BaseNonlinearityConfig
|
|
117
119
|
if name == "NonlinearityConfig":
|
|
118
120
|
return importlib.import_module(
|
|
119
121
|
"nshtrainer.nn.nonlinearity"
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -16,14 +18,14 @@ else:
|
|
|
16
18
|
|
|
17
19
|
if name in globals():
|
|
18
20
|
return globals()[name]
|
|
19
|
-
if name == "BaseProfilerConfig":
|
|
20
|
-
return importlib.import_module(
|
|
21
|
-
"nshtrainer.profiler.advanced"
|
|
22
|
-
).BaseProfilerConfig
|
|
23
21
|
if name == "AdvancedProfilerConfig":
|
|
24
22
|
return importlib.import_module(
|
|
25
23
|
"nshtrainer.profiler.advanced"
|
|
26
24
|
).AdvancedProfilerConfig
|
|
25
|
+
if name == "BaseProfilerConfig":
|
|
26
|
+
return importlib.import_module(
|
|
27
|
+
"nshtrainer.profiler.advanced"
|
|
28
|
+
).BaseProfilerConfig
|
|
27
29
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
28
30
|
|
|
29
31
|
# Submodule exports
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -16,14 +18,14 @@ else:
|
|
|
16
18
|
|
|
17
19
|
if name in globals():
|
|
18
20
|
return globals()[name]
|
|
19
|
-
if name == "BaseProfilerConfig":
|
|
20
|
-
return importlib.import_module(
|
|
21
|
-
"nshtrainer.profiler.pytorch"
|
|
22
|
-
).BaseProfilerConfig
|
|
23
21
|
if name == "PyTorchProfilerConfig":
|
|
24
22
|
return importlib.import_module(
|
|
25
23
|
"nshtrainer.profiler.pytorch"
|
|
26
24
|
).PyTorchProfilerConfig
|
|
25
|
+
if name == "BaseProfilerConfig":
|
|
26
|
+
return importlib.import_module(
|
|
27
|
+
"nshtrainer.profiler.pytorch"
|
|
28
|
+
).BaseProfilerConfig
|
|
27
29
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
28
30
|
|
|
29
31
|
# Submodule exports
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -14,14 +16,14 @@ else:
|
|
|
14
16
|
|
|
15
17
|
if name in globals():
|
|
16
18
|
return globals()[name]
|
|
17
|
-
if name == "BaseProfilerConfig":
|
|
18
|
-
return importlib.import_module(
|
|
19
|
-
"nshtrainer.profiler.simple"
|
|
20
|
-
).BaseProfilerConfig
|
|
21
19
|
if name == "SimpleProfilerConfig":
|
|
22
20
|
return importlib.import_module(
|
|
23
21
|
"nshtrainer.profiler.simple"
|
|
24
22
|
).SimpleProfilerConfig
|
|
23
|
+
if name == "BaseProfilerConfig":
|
|
24
|
+
return importlib.import_module(
|
|
25
|
+
"nshtrainer.profiler.simple"
|
|
26
|
+
).BaseProfilerConfig
|
|
25
27
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
26
28
|
|
|
27
29
|
# Submodule exports
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
__codegen__ = True
|
|
2
4
|
|
|
3
5
|
from typing import TYPE_CHECKING
|
|
@@ -23,7 +25,9 @@ if TYPE_CHECKING:
|
|
|
23
25
|
from nshtrainer.trainer._config import (
|
|
24
26
|
DebugFlagCallbackConfig as DebugFlagCallbackConfig,
|
|
25
27
|
)
|
|
26
|
-
from nshtrainer.trainer._config import
|
|
28
|
+
from nshtrainer.trainer._config import (
|
|
29
|
+
EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
|
|
30
|
+
)
|
|
27
31
|
from nshtrainer.trainer._config import (
|
|
28
32
|
GradientClippingConfig as GradientClippingConfig,
|
|
29
33
|
)
|
|
@@ -42,11 +46,11 @@ if TYPE_CHECKING:
|
|
|
42
46
|
ReproducibilityConfig as ReproducibilityConfig,
|
|
43
47
|
)
|
|
44
48
|
from nshtrainer.trainer._config import (
|
|
45
|
-
|
|
49
|
+
RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
|
|
46
50
|
)
|
|
47
51
|
from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
|
|
48
52
|
from nshtrainer.trainer._config import (
|
|
49
|
-
|
|
53
|
+
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
|
50
54
|
)
|
|
51
55
|
from nshtrainer.trainer._config import (
|
|
52
56
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
|
@@ -60,80 +64,80 @@ else:
|
|
|
60
64
|
|
|
61
65
|
if name in globals():
|
|
62
66
|
return globals()[name]
|
|
63
|
-
if name == "
|
|
64
|
-
return importlib.import_module(
|
|
65
|
-
"nshtrainer.trainer._config"
|
|
66
|
-
).HuggingFaceHubConfig
|
|
67
|
-
if name == "OptimizationConfig":
|
|
67
|
+
if name == "SanityCheckingConfig":
|
|
68
68
|
return importlib.import_module(
|
|
69
69
|
"nshtrainer.trainer._config"
|
|
70
|
-
).
|
|
70
|
+
).SanityCheckingConfig
|
|
71
71
|
if name == "TrainerConfig":
|
|
72
72
|
return importlib.import_module("nshtrainer.trainer._config").TrainerConfig
|
|
73
|
-
if name == "
|
|
73
|
+
if name == "OnExceptionCheckpointCallbackConfig":
|
|
74
74
|
return importlib.import_module(
|
|
75
75
|
"nshtrainer.trainer._config"
|
|
76
|
-
).
|
|
76
|
+
).OnExceptionCheckpointCallbackConfig
|
|
77
77
|
if name == "GradientClippingConfig":
|
|
78
78
|
return importlib.import_module(
|
|
79
79
|
"nshtrainer.trainer._config"
|
|
80
80
|
).GradientClippingConfig
|
|
81
|
-
if name == "
|
|
81
|
+
if name == "WandbLoggerConfig":
|
|
82
82
|
return importlib.import_module(
|
|
83
83
|
"nshtrainer.trainer._config"
|
|
84
|
-
).
|
|
85
|
-
if name == "
|
|
86
|
-
return importlib.import_module("nshtrainer.trainer._config").
|
|
87
|
-
if name == "
|
|
84
|
+
).WandbLoggerConfig
|
|
85
|
+
if name == "LoggingConfig":
|
|
86
|
+
return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
|
|
87
|
+
if name == "TensorboardLoggerConfig":
|
|
88
88
|
return importlib.import_module(
|
|
89
89
|
"nshtrainer.trainer._config"
|
|
90
|
-
).
|
|
91
|
-
if name == "
|
|
90
|
+
).TensorboardLoggerConfig
|
|
91
|
+
if name == "RLPSanityChecksCallbackConfig":
|
|
92
92
|
return importlib.import_module(
|
|
93
93
|
"nshtrainer.trainer._config"
|
|
94
|
-
).
|
|
95
|
-
if name == "
|
|
94
|
+
).RLPSanityChecksCallbackConfig
|
|
95
|
+
if name == "CheckpointSavingConfig":
|
|
96
96
|
return importlib.import_module(
|
|
97
97
|
"nshtrainer.trainer._config"
|
|
98
|
-
).
|
|
99
|
-
if name == "
|
|
98
|
+
).CheckpointSavingConfig
|
|
99
|
+
if name == "CSVLoggerConfig":
|
|
100
|
+
return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
|
|
101
|
+
if name == "HuggingFaceHubConfig":
|
|
102
|
+
return importlib.import_module(
|
|
103
|
+
"nshtrainer.trainer._config"
|
|
104
|
+
).HuggingFaceHubConfig
|
|
105
|
+
if name == "CheckpointLoadingConfig":
|
|
100
106
|
return importlib.import_module(
|
|
101
107
|
"nshtrainer.trainer._config"
|
|
102
|
-
).
|
|
108
|
+
).CheckpointLoadingConfig
|
|
103
109
|
if name == "DebugFlagCallbackConfig":
|
|
104
110
|
return importlib.import_module(
|
|
105
111
|
"nshtrainer.trainer._config"
|
|
106
112
|
).DebugFlagCallbackConfig
|
|
107
|
-
if name == "
|
|
113
|
+
if name == "CallbackConfigBase":
|
|
108
114
|
return importlib.import_module(
|
|
109
115
|
"nshtrainer.trainer._config"
|
|
110
|
-
).
|
|
111
|
-
if name == "
|
|
116
|
+
).CallbackConfigBase
|
|
117
|
+
if name == "LastCheckpointCallbackConfig":
|
|
112
118
|
return importlib.import_module(
|
|
113
119
|
"nshtrainer.trainer._config"
|
|
114
|
-
).
|
|
115
|
-
if name == "
|
|
120
|
+
).LastCheckpointCallbackConfig
|
|
121
|
+
if name == "SharedParametersCallbackConfig":
|
|
116
122
|
return importlib.import_module(
|
|
117
123
|
"nshtrainer.trainer._config"
|
|
118
|
-
).
|
|
119
|
-
if name == "
|
|
124
|
+
).SharedParametersCallbackConfig
|
|
125
|
+
if name == "ReproducibilityConfig":
|
|
120
126
|
return importlib.import_module(
|
|
121
127
|
"nshtrainer.trainer._config"
|
|
122
|
-
).
|
|
123
|
-
if name == "
|
|
124
|
-
return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
|
|
125
|
-
if name == "SanityCheckingConfig":
|
|
128
|
+
).ReproducibilityConfig
|
|
129
|
+
if name == "EarlyStoppingCallbackConfig":
|
|
126
130
|
return importlib.import_module(
|
|
127
131
|
"nshtrainer.trainer._config"
|
|
128
|
-
).
|
|
129
|
-
if name == "
|
|
132
|
+
).EarlyStoppingCallbackConfig
|
|
133
|
+
if name == "OptimizationConfig":
|
|
130
134
|
return importlib.import_module(
|
|
131
135
|
"nshtrainer.trainer._config"
|
|
132
|
-
).
|
|
133
|
-
if name == "
|
|
136
|
+
).OptimizationConfig
|
|
137
|
+
if name == "BestCheckpointCallbackConfig":
|
|
134
138
|
return importlib.import_module(
|
|
135
139
|
"nshtrainer.trainer._config"
|
|
136
|
-
).
|
|
140
|
+
).BestCheckpointCallbackConfig
|
|
137
141
|
if name == "CallbackConfig":
|
|
138
142
|
return importlib.import_module("nshtrainer.trainer._config").CallbackConfig
|
|
139
143
|
if name == "CheckpointCallbackConfig":
|