nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +6 -3
- nshtrainer/_callback.py +297 -2
- nshtrainer/_checkpoint/loader.py +23 -30
- nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer/_hf_hub.py +25 -26
- nshtrainer/callbacks/__init__.py +1 -3
- nshtrainer/callbacks/actsave.py +22 -20
- nshtrainer/callbacks/base.py +7 -7
- nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +8 -5
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- nshtrainer/callbacks/debug_flag.py +14 -19
- nshtrainer/callbacks/directory_setup.py +6 -11
- nshtrainer/callbacks/early_stopping.py +3 -3
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/log_epoch.py +1 -1
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- nshtrainer/callbacks/shared_parameters.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_upload_code.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/config/__init__.py +189 -189
- nshtrainer/config/_checkpoint/__init__.py +70 -0
- nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- nshtrainer/config/_directory/__init__.py +2 -2
- nshtrainer/config/_hf_hub/__init__.py +2 -2
- nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- nshtrainer/config/callbacks/ema/__init__.py +2 -2
- nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
- nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- nshtrainer/config/callbacks/timer/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -2
- nshtrainer/config/loggers/wandb/__init__.py +6 -6
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- nshtrainer/config/nn/__init__.py +18 -18
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- nshtrainer/config/optimizer/__init__.py +2 -2
- nshtrainer/config/profiler/__init__.py +2 -2
- nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- nshtrainer/config/profiler/simple/__init__.py +4 -4
- nshtrainer/config/trainer/__init__.py +180 -0
- nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer/config/util/__init__.py +109 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -20
- nshtrainer/config/util/config/__init__.py +2 -2
- nshtrainer/data/datamodule.py +52 -2
- nshtrainer/loggers/__init__.py +2 -1
- nshtrainer/loggers/_base.py +5 -2
- nshtrainer/loggers/actsave.py +59 -0
- nshtrainer/loggers/csv.py +5 -5
- nshtrainer/loggers/tensorboard.py +5 -5
- nshtrainer/loggers/wandb.py +17 -16
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +64 -347
- nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer/model/mixins/logger.py +142 -145
- nshtrainer/profiler/_base.py +2 -2
- nshtrainer/profiler/advanced.py +4 -4
- nshtrainer/profiler/pytorch.py +4 -4
- nshtrainer/profiler/simple.py +4 -4
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/_config.py +164 -17
- nshtrainer/trainer/checkpoint_connector.py +23 -8
- nshtrainer/trainer/trainer.py +194 -76
- nshtrainer/util/_environment_info.py +21 -13
- nshtrainer/util/config/dtype.py +4 -4
- nshtrainer/util/typing_utils.py +1 -1
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
- nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer/config/model/__init__.py +0 -41
- nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer/ll/__init__.py +0 -59
- nshtrainer/ll/_experimental.py +0 -3
- nshtrainer/ll/actsave.py +0 -6
- nshtrainer/ll/callbacks.py +0 -3
- nshtrainer/ll/config.py +0 -6
- nshtrainer/ll/data.py +0 -3
- nshtrainer/ll/log.py +0 -5
- nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer/ll/model.py +0 -21
- nshtrainer/ll/nn.py +0 -3
- nshtrainer/ll/optimizer.py +0 -3
- nshtrainer/ll/runner.py +0 -5
- nshtrainer/ll/snapshot.py +0 -3
- nshtrainer/ll/snoop.py +0 -3
- nshtrainer/ll/trainer.py +0 -3
- nshtrainer/ll/typecheck.py +0 -3
- nshtrainer/ll/util.py +0 -3
- nshtrainer/model/config.py +0 -218
- nshtrainer/runner.py +0 -101
- nshtrainer-0.44.1.dist-info/RECORD +0 -162
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
|
|
7
7
|
# Config/alias imports
|
8
8
|
|
9
9
|
if TYPE_CHECKING:
|
10
|
+
from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
|
10
11
|
from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
|
11
12
|
from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
|
12
13
|
from nshtrainer.loggers import LoggerConfig as LoggerConfig
|
@@ -26,8 +27,16 @@ else:
|
|
26
27
|
|
27
28
|
if name in globals():
|
28
29
|
return globals()[name]
|
30
|
+
if name == "ActSaveLoggerConfig":
|
31
|
+
return importlib.import_module("nshtrainer.loggers").ActSaveLoggerConfig
|
29
32
|
if name == "BaseLoggerConfig":
|
30
33
|
return importlib.import_module("nshtrainer.loggers").BaseLoggerConfig
|
34
|
+
if name == "CSVLoggerConfig":
|
35
|
+
return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
|
36
|
+
if name == "CallbackConfigBase":
|
37
|
+
return importlib.import_module(
|
38
|
+
"nshtrainer.loggers.wandb"
|
39
|
+
).CallbackConfigBase
|
31
40
|
if name == "TensorboardLoggerConfig":
|
32
41
|
return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
|
33
42
|
if name == "WandbLoggerConfig":
|
@@ -40,12 +49,6 @@ else:
|
|
40
49
|
return importlib.import_module(
|
41
50
|
"nshtrainer.loggers.wandb"
|
42
51
|
).WandbWatchCallbackConfig
|
43
|
-
if name == "CallbackConfigBase":
|
44
|
-
return importlib.import_module(
|
45
|
-
"nshtrainer.loggers.wandb"
|
46
|
-
).CallbackConfigBase
|
47
|
-
if name == "CSVLoggerConfig":
|
48
|
-
return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
|
49
52
|
if name == "LoggerConfig":
|
50
53
|
return importlib.import_module("nshtrainer.loggers").LoggerConfig
|
51
54
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
@@ -53,6 +56,7 @@ else:
|
|
53
56
|
|
54
57
|
# Submodule exports
|
55
58
|
from . import _base as _base
|
59
|
+
from . import actsave as actsave
|
56
60
|
from . import csv as csv
|
57
61
|
from . import tensorboard as tensorboard
|
58
62
|
from . import wandb as wandb
|
@@ -0,0 +1,29 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from typing import TYPE_CHECKING
|
6
|
+
|
7
|
+
# Config/alias imports
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from nshtrainer.loggers.actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
|
11
|
+
from nshtrainer.loggers.actsave import BaseLoggerConfig as BaseLoggerConfig
|
12
|
+
else:
|
13
|
+
|
14
|
+
def __getattr__(name):
|
15
|
+
import importlib
|
16
|
+
|
17
|
+
if name in globals():
|
18
|
+
return globals()[name]
|
19
|
+
if name == "ActSaveLoggerConfig":
|
20
|
+
return importlib.import_module(
|
21
|
+
"nshtrainer.loggers.actsave"
|
22
|
+
).ActSaveLoggerConfig
|
23
|
+
if name == "BaseLoggerConfig":
|
24
|
+
return importlib.import_module(
|
25
|
+
"nshtrainer.loggers.actsave"
|
26
|
+
).BaseLoggerConfig
|
27
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
28
|
+
|
29
|
+
# Submodule exports
|
@@ -16,10 +16,10 @@ else:
|
|
16
16
|
|
17
17
|
if name in globals():
|
18
18
|
return globals()[name]
|
19
|
-
if name == "CSVLoggerConfig":
|
20
|
-
return importlib.import_module("nshtrainer.loggers.csv").CSVLoggerConfig
|
21
19
|
if name == "BaseLoggerConfig":
|
22
20
|
return importlib.import_module("nshtrainer.loggers.csv").BaseLoggerConfig
|
21
|
+
if name == "CSVLoggerConfig":
|
22
|
+
return importlib.import_module("nshtrainer.loggers.csv").CSVLoggerConfig
|
23
23
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
24
24
|
|
25
25
|
# Submodule exports
|
@@ -23,6 +23,12 @@ else:
|
|
23
23
|
|
24
24
|
if name in globals():
|
25
25
|
return globals()[name]
|
26
|
+
if name == "BaseLoggerConfig":
|
27
|
+
return importlib.import_module("nshtrainer.loggers.wandb").BaseLoggerConfig
|
28
|
+
if name == "CallbackConfigBase":
|
29
|
+
return importlib.import_module(
|
30
|
+
"nshtrainer.loggers.wandb"
|
31
|
+
).CallbackConfigBase
|
26
32
|
if name == "WandbLoggerConfig":
|
27
33
|
return importlib.import_module("nshtrainer.loggers.wandb").WandbLoggerConfig
|
28
34
|
if name == "WandbUploadCodeCallbackConfig":
|
@@ -33,12 +39,6 @@ else:
|
|
33
39
|
return importlib.import_module(
|
34
40
|
"nshtrainer.loggers.wandb"
|
35
41
|
).WandbWatchCallbackConfig
|
36
|
-
if name == "BaseLoggerConfig":
|
37
|
-
return importlib.import_module("nshtrainer.loggers.wandb").BaseLoggerConfig
|
38
|
-
if name == "CallbackConfigBase":
|
39
|
-
return importlib.import_module(
|
40
|
-
"nshtrainer.loggers.wandb"
|
41
|
-
).CallbackConfigBase
|
42
42
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
43
43
|
|
44
44
|
# Submodule exports
|
@@ -23,14 +23,14 @@ else:
|
|
23
23
|
|
24
24
|
if name in globals():
|
25
25
|
return globals()[name]
|
26
|
-
if name == "LinearWarmupCosineDecayLRSchedulerConfig":
|
27
|
-
return importlib.import_module(
|
28
|
-
"nshtrainer.lr_scheduler.linear_warmup_cosine"
|
29
|
-
).LinearWarmupCosineDecayLRSchedulerConfig
|
30
26
|
if name == "LRSchedulerConfigBase":
|
31
27
|
return importlib.import_module(
|
32
28
|
"nshtrainer.lr_scheduler.linear_warmup_cosine"
|
33
29
|
).LRSchedulerConfigBase
|
30
|
+
if name == "LinearWarmupCosineDecayLRSchedulerConfig":
|
31
|
+
return importlib.import_module(
|
32
|
+
"nshtrainer.lr_scheduler.linear_warmup_cosine"
|
33
|
+
).LinearWarmupCosineDecayLRSchedulerConfig
|
34
34
|
if name == "DurationConfig":
|
35
35
|
return importlib.import_module(
|
36
36
|
"nshtrainer.lr_scheduler.linear_warmup_cosine"
|
nshtrainer/config/nn/__init__.py
CHANGED
@@ -35,38 +35,38 @@ else:
|
|
35
35
|
return globals()[name]
|
36
36
|
if name == "BaseNonlinearityConfig":
|
37
37
|
return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
|
38
|
+
if name == "ELUNonlinearityConfig":
|
39
|
+
return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
|
40
|
+
if name == "GELUNonlinearityConfig":
|
41
|
+
return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
|
42
|
+
if name == "LeakyReLUNonlinearityConfig":
|
43
|
+
return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
|
38
44
|
if name == "MLPConfig":
|
39
45
|
return importlib.import_module("nshtrainer.nn").MLPConfig
|
46
|
+
if name == "MishNonlinearityConfig":
|
47
|
+
return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
|
40
48
|
if name == "PReLUConfig":
|
41
49
|
return importlib.import_module("nshtrainer.nn").PReLUConfig
|
42
|
-
if name == "
|
43
|
-
return importlib.import_module("nshtrainer.nn").
|
44
|
-
if name == "SwiGLUNonlinearityConfig":
|
45
|
-
return importlib.import_module(
|
46
|
-
"nshtrainer.nn.nonlinearity"
|
47
|
-
).SwiGLUNonlinearityConfig
|
48
|
-
if name == "SoftsignNonlinearityConfig":
|
49
|
-
return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
|
50
|
+
if name == "ReLUNonlinearityConfig":
|
51
|
+
return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
|
50
52
|
if name == "SiLUNonlinearityConfig":
|
51
53
|
return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
|
52
54
|
if name == "SigmoidNonlinearityConfig":
|
53
55
|
return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
|
54
|
-
if name == "SoftplusNonlinearityConfig":
|
55
|
-
return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
|
56
|
-
if name == "ELUNonlinearityConfig":
|
57
|
-
return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
|
58
56
|
if name == "SoftmaxNonlinearityConfig":
|
59
57
|
return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
|
60
|
-
if name == "
|
61
|
-
return importlib.import_module("nshtrainer.nn").
|
58
|
+
if name == "SoftplusNonlinearityConfig":
|
59
|
+
return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
|
60
|
+
if name == "SoftsignNonlinearityConfig":
|
61
|
+
return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
|
62
|
+
if name == "SwiGLUNonlinearityConfig":
|
63
|
+
return importlib.import_module(
|
64
|
+
"nshtrainer.nn.nonlinearity"
|
65
|
+
).SwiGLUNonlinearityConfig
|
62
66
|
if name == "SwishNonlinearityConfig":
|
63
67
|
return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
|
64
|
-
if name == "MishNonlinearityConfig":
|
65
|
-
return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
|
66
68
|
if name == "TanhNonlinearityConfig":
|
67
69
|
return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
|
68
|
-
if name == "ReLUNonlinearityConfig":
|
69
|
-
return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
|
70
70
|
if name == "NonlinearityConfig":
|
71
71
|
return importlib.import_module("nshtrainer.nn").NonlinearityConfig
|
72
72
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
@@ -58,20 +58,32 @@ else:
|
|
58
58
|
|
59
59
|
if name in globals():
|
60
60
|
return globals()[name]
|
61
|
-
if name == "
|
62
|
-
return importlib.import_module(
|
61
|
+
if name == "BaseNonlinearityConfig":
|
62
|
+
return importlib.import_module(
|
63
|
+
"nshtrainer.nn.nonlinearity"
|
64
|
+
).BaseNonlinearityConfig
|
65
|
+
if name == "ELUNonlinearityConfig":
|
66
|
+
return importlib.import_module(
|
67
|
+
"nshtrainer.nn.nonlinearity"
|
68
|
+
).ELUNonlinearityConfig
|
69
|
+
if name == "GELUNonlinearityConfig":
|
70
|
+
return importlib.import_module(
|
71
|
+
"nshtrainer.nn.nonlinearity"
|
72
|
+
).GELUNonlinearityConfig
|
63
73
|
if name == "LeakyReLUNonlinearityConfig":
|
64
74
|
return importlib.import_module(
|
65
75
|
"nshtrainer.nn.nonlinearity"
|
66
76
|
).LeakyReLUNonlinearityConfig
|
67
|
-
if name == "
|
77
|
+
if name == "MishNonlinearityConfig":
|
68
78
|
return importlib.import_module(
|
69
79
|
"nshtrainer.nn.nonlinearity"
|
70
|
-
).
|
71
|
-
if name == "
|
80
|
+
).MishNonlinearityConfig
|
81
|
+
if name == "PReLUConfig":
|
82
|
+
return importlib.import_module("nshtrainer.nn.nonlinearity").PReLUConfig
|
83
|
+
if name == "ReLUNonlinearityConfig":
|
72
84
|
return importlib.import_module(
|
73
85
|
"nshtrainer.nn.nonlinearity"
|
74
|
-
).
|
86
|
+
).ReLUNonlinearityConfig
|
75
87
|
if name == "SiLUNonlinearityConfig":
|
76
88
|
return importlib.import_module(
|
77
89
|
"nshtrainer.nn.nonlinearity"
|
@@ -80,42 +92,30 @@ else:
|
|
80
92
|
return importlib.import_module(
|
81
93
|
"nshtrainer.nn.nonlinearity"
|
82
94
|
).SigmoidNonlinearityConfig
|
83
|
-
if name == "
|
95
|
+
if name == "SoftmaxNonlinearityConfig":
|
84
96
|
return importlib.import_module(
|
85
97
|
"nshtrainer.nn.nonlinearity"
|
86
|
-
).
|
87
|
-
if name == "
|
98
|
+
).SoftmaxNonlinearityConfig
|
99
|
+
if name == "SoftplusNonlinearityConfig":
|
88
100
|
return importlib.import_module(
|
89
101
|
"nshtrainer.nn.nonlinearity"
|
90
|
-
).
|
91
|
-
if name == "
|
102
|
+
).SoftplusNonlinearityConfig
|
103
|
+
if name == "SoftsignNonlinearityConfig":
|
92
104
|
return importlib.import_module(
|
93
105
|
"nshtrainer.nn.nonlinearity"
|
94
|
-
).
|
95
|
-
if name == "
|
106
|
+
).SoftsignNonlinearityConfig
|
107
|
+
if name == "SwiGLUNonlinearityConfig":
|
96
108
|
return importlib.import_module(
|
97
109
|
"nshtrainer.nn.nonlinearity"
|
98
|
-
).
|
110
|
+
).SwiGLUNonlinearityConfig
|
99
111
|
if name == "SwishNonlinearityConfig":
|
100
112
|
return importlib.import_module(
|
101
113
|
"nshtrainer.nn.nonlinearity"
|
102
114
|
).SwishNonlinearityConfig
|
103
|
-
if name == "MishNonlinearityConfig":
|
104
|
-
return importlib.import_module(
|
105
|
-
"nshtrainer.nn.nonlinearity"
|
106
|
-
).MishNonlinearityConfig
|
107
115
|
if name == "TanhNonlinearityConfig":
|
108
116
|
return importlib.import_module(
|
109
117
|
"nshtrainer.nn.nonlinearity"
|
110
118
|
).TanhNonlinearityConfig
|
111
|
-
if name == "ReLUNonlinearityConfig":
|
112
|
-
return importlib.import_module(
|
113
|
-
"nshtrainer.nn.nonlinearity"
|
114
|
-
).ReLUNonlinearityConfig
|
115
|
-
if name == "BaseNonlinearityConfig":
|
116
|
-
return importlib.import_module(
|
117
|
-
"nshtrainer.nn.nonlinearity"
|
118
|
-
).BaseNonlinearityConfig
|
119
119
|
if name == "NonlinearityConfig":
|
120
120
|
return importlib.import_module(
|
121
121
|
"nshtrainer.nn.nonlinearity"
|
@@ -17,10 +17,10 @@ else:
|
|
17
17
|
|
18
18
|
if name in globals():
|
19
19
|
return globals()[name]
|
20
|
-
if name == "OptimizerConfigBase":
|
21
|
-
return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
|
22
20
|
if name == "AdamWConfig":
|
23
21
|
return importlib.import_module("nshtrainer.optimizer").AdamWConfig
|
22
|
+
if name == "OptimizerConfigBase":
|
23
|
+
return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
|
24
24
|
if name == "OptimizerConfig":
|
25
25
|
return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
|
26
26
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
@@ -19,12 +19,12 @@ else:
|
|
19
19
|
|
20
20
|
if name in globals():
|
21
21
|
return globals()[name]
|
22
|
+
if name == "AdvancedProfilerConfig":
|
23
|
+
return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
|
22
24
|
if name == "BaseProfilerConfig":
|
23
25
|
return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
|
24
26
|
if name == "PyTorchProfilerConfig":
|
25
27
|
return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
|
26
|
-
if name == "AdvancedProfilerConfig":
|
27
|
-
return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
|
28
28
|
if name == "SimpleProfilerConfig":
|
29
29
|
return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
|
30
30
|
if name == "ProfilerConfig":
|
@@ -18,14 +18,14 @@ else:
|
|
18
18
|
|
19
19
|
if name in globals():
|
20
20
|
return globals()[name]
|
21
|
-
if name == "PyTorchProfilerConfig":
|
22
|
-
return importlib.import_module(
|
23
|
-
"nshtrainer.profiler.pytorch"
|
24
|
-
).PyTorchProfilerConfig
|
25
21
|
if name == "BaseProfilerConfig":
|
26
22
|
return importlib.import_module(
|
27
23
|
"nshtrainer.profiler.pytorch"
|
28
24
|
).BaseProfilerConfig
|
25
|
+
if name == "PyTorchProfilerConfig":
|
26
|
+
return importlib.import_module(
|
27
|
+
"nshtrainer.profiler.pytorch"
|
28
|
+
).PyTorchProfilerConfig
|
29
29
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
30
30
|
|
31
31
|
# Submodule exports
|
@@ -16,14 +16,14 @@ else:
|
|
16
16
|
|
17
17
|
if name in globals():
|
18
18
|
return globals()[name]
|
19
|
-
if name == "SimpleProfilerConfig":
|
20
|
-
return importlib.import_module(
|
21
|
-
"nshtrainer.profiler.simple"
|
22
|
-
).SimpleProfilerConfig
|
23
19
|
if name == "BaseProfilerConfig":
|
24
20
|
return importlib.import_module(
|
25
21
|
"nshtrainer.profiler.simple"
|
26
22
|
).BaseProfilerConfig
|
23
|
+
if name == "SimpleProfilerConfig":
|
24
|
+
return importlib.import_module(
|
25
|
+
"nshtrainer.profiler.simple"
|
26
|
+
).SimpleProfilerConfig
|
27
27
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
28
28
|
|
29
29
|
# Submodule exports
|
@@ -0,0 +1,180 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from typing import TYPE_CHECKING
|
6
|
+
|
7
|
+
# Config/alias imports
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from nshtrainer.trainer import TrainerConfig as TrainerConfig
|
11
|
+
from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
|
12
|
+
from nshtrainer.trainer._config import (
|
13
|
+
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
14
|
+
)
|
15
|
+
from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
|
16
|
+
from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
|
17
|
+
from nshtrainer.trainer._config import (
|
18
|
+
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
19
|
+
)
|
20
|
+
from nshtrainer.trainer._config import (
|
21
|
+
CheckpointLoadingConfig as CheckpointLoadingConfig,
|
22
|
+
)
|
23
|
+
from nshtrainer.trainer._config import (
|
24
|
+
CheckpointSavingConfig as CheckpointSavingConfig,
|
25
|
+
)
|
26
|
+
from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
|
27
|
+
from nshtrainer.trainer._config import (
|
28
|
+
DebugFlagCallbackConfig as DebugFlagCallbackConfig,
|
29
|
+
)
|
30
|
+
from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
|
31
|
+
from nshtrainer.trainer._config import (
|
32
|
+
EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
|
33
|
+
)
|
34
|
+
from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
|
35
|
+
from nshtrainer.trainer._config import (
|
36
|
+
GradientClippingConfig as GradientClippingConfig,
|
37
|
+
)
|
38
|
+
from nshtrainer.trainer._config import HuggingFaceHubConfig as HuggingFaceHubConfig
|
39
|
+
from nshtrainer.trainer._config import (
|
40
|
+
LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
|
41
|
+
)
|
42
|
+
from nshtrainer.trainer._config import (
|
43
|
+
LogEpochCallbackConfig as LogEpochCallbackConfig,
|
44
|
+
)
|
45
|
+
from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
|
46
|
+
from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
|
47
|
+
from nshtrainer.trainer._config import MetricConfig as MetricConfig
|
48
|
+
from nshtrainer.trainer._config import (
|
49
|
+
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
50
|
+
)
|
51
|
+
from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
|
52
|
+
from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
|
53
|
+
from nshtrainer.trainer._config import (
|
54
|
+
ReproducibilityConfig as ReproducibilityConfig,
|
55
|
+
)
|
56
|
+
from nshtrainer.trainer._config import (
|
57
|
+
RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
|
58
|
+
)
|
59
|
+
from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
|
60
|
+
from nshtrainer.trainer._config import (
|
61
|
+
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
62
|
+
)
|
63
|
+
from nshtrainer.trainer._config import (
|
64
|
+
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
65
|
+
)
|
66
|
+
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
67
|
+
else:
|
68
|
+
|
69
|
+
def __getattr__(name):
|
70
|
+
import importlib
|
71
|
+
|
72
|
+
if name in globals():
|
73
|
+
return globals()[name]
|
74
|
+
if name == "ActSaveLoggerConfig":
|
75
|
+
return importlib.import_module(
|
76
|
+
"nshtrainer.trainer._config"
|
77
|
+
).ActSaveLoggerConfig
|
78
|
+
if name == "BestCheckpointCallbackConfig":
|
79
|
+
return importlib.import_module(
|
80
|
+
"nshtrainer.trainer._config"
|
81
|
+
).BestCheckpointCallbackConfig
|
82
|
+
if name == "CSVLoggerConfig":
|
83
|
+
return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
|
84
|
+
if name == "CallbackConfigBase":
|
85
|
+
return importlib.import_module(
|
86
|
+
"nshtrainer.trainer._config"
|
87
|
+
).CallbackConfigBase
|
88
|
+
if name == "CheckpointLoadingConfig":
|
89
|
+
return importlib.import_module(
|
90
|
+
"nshtrainer.trainer._config"
|
91
|
+
).CheckpointLoadingConfig
|
92
|
+
if name == "CheckpointSavingConfig":
|
93
|
+
return importlib.import_module(
|
94
|
+
"nshtrainer.trainer._config"
|
95
|
+
).CheckpointSavingConfig
|
96
|
+
if name == "DebugFlagCallbackConfig":
|
97
|
+
return importlib.import_module(
|
98
|
+
"nshtrainer.trainer._config"
|
99
|
+
).DebugFlagCallbackConfig
|
100
|
+
if name == "DirectoryConfig":
|
101
|
+
return importlib.import_module("nshtrainer.trainer._config").DirectoryConfig
|
102
|
+
if name == "EarlyStoppingCallbackConfig":
|
103
|
+
return importlib.import_module(
|
104
|
+
"nshtrainer.trainer._config"
|
105
|
+
).EarlyStoppingCallbackConfig
|
106
|
+
if name == "EnvironmentConfig":
|
107
|
+
return importlib.import_module(
|
108
|
+
"nshtrainer.trainer._config"
|
109
|
+
).EnvironmentConfig
|
110
|
+
if name == "GradientClippingConfig":
|
111
|
+
return importlib.import_module(
|
112
|
+
"nshtrainer.trainer._config"
|
113
|
+
).GradientClippingConfig
|
114
|
+
if name == "HuggingFaceHubConfig":
|
115
|
+
return importlib.import_module(
|
116
|
+
"nshtrainer.trainer._config"
|
117
|
+
).HuggingFaceHubConfig
|
118
|
+
if name == "LastCheckpointCallbackConfig":
|
119
|
+
return importlib.import_module(
|
120
|
+
"nshtrainer.trainer._config"
|
121
|
+
).LastCheckpointCallbackConfig
|
122
|
+
if name == "LogEpochCallbackConfig":
|
123
|
+
return importlib.import_module(
|
124
|
+
"nshtrainer.trainer._config"
|
125
|
+
).LogEpochCallbackConfig
|
126
|
+
if name == "LoggingConfig":
|
127
|
+
return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
|
128
|
+
if name == "MetricConfig":
|
129
|
+
return importlib.import_module("nshtrainer.trainer._config").MetricConfig
|
130
|
+
if name == "OnExceptionCheckpointCallbackConfig":
|
131
|
+
return importlib.import_module(
|
132
|
+
"nshtrainer.trainer._config"
|
133
|
+
).OnExceptionCheckpointCallbackConfig
|
134
|
+
if name == "OptimizationConfig":
|
135
|
+
return importlib.import_module(
|
136
|
+
"nshtrainer.trainer._config"
|
137
|
+
).OptimizationConfig
|
138
|
+
if name == "RLPSanityChecksCallbackConfig":
|
139
|
+
return importlib.import_module(
|
140
|
+
"nshtrainer.trainer._config"
|
141
|
+
).RLPSanityChecksCallbackConfig
|
142
|
+
if name == "ReproducibilityConfig":
|
143
|
+
return importlib.import_module(
|
144
|
+
"nshtrainer.trainer._config"
|
145
|
+
).ReproducibilityConfig
|
146
|
+
if name == "SanityCheckingConfig":
|
147
|
+
return importlib.import_module(
|
148
|
+
"nshtrainer.trainer._config"
|
149
|
+
).SanityCheckingConfig
|
150
|
+
if name == "SharedParametersCallbackConfig":
|
151
|
+
return importlib.import_module(
|
152
|
+
"nshtrainer.trainer._config"
|
153
|
+
).SharedParametersCallbackConfig
|
154
|
+
if name == "TensorboardLoggerConfig":
|
155
|
+
return importlib.import_module(
|
156
|
+
"nshtrainer.trainer._config"
|
157
|
+
).TensorboardLoggerConfig
|
158
|
+
if name == "TrainerConfig":
|
159
|
+
return importlib.import_module("nshtrainer.trainer").TrainerConfig
|
160
|
+
if name == "WandbLoggerConfig":
|
161
|
+
return importlib.import_module(
|
162
|
+
"nshtrainer.trainer._config"
|
163
|
+
).WandbLoggerConfig
|
164
|
+
if name == "CallbackConfig":
|
165
|
+
return importlib.import_module("nshtrainer.trainer._config").CallbackConfig
|
166
|
+
if name == "CheckpointCallbackConfig":
|
167
|
+
return importlib.import_module(
|
168
|
+
"nshtrainer.trainer._config"
|
169
|
+
).CheckpointCallbackConfig
|
170
|
+
if name == "LoggerConfig":
|
171
|
+
return importlib.import_module("nshtrainer.trainer._config").LoggerConfig
|
172
|
+
if name == "ProfilerConfig":
|
173
|
+
return importlib.import_module("nshtrainer.trainer._config").ProfilerConfig
|
174
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
175
|
+
|
176
|
+
|
177
|
+
# Submodule exports
|
178
|
+
from . import _config as _config
|
179
|
+
from . import checkpoint_connector as checkpoint_connector
|
180
|
+
from . import trainer as trainer
|