nshtrainer 0.41.0__py3-none-any.whl → 0.42.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -2
- nshtrainer/config/__init__.py +406 -95
- nshtrainer/config/_checkpoint/loader/__init__.py +55 -13
- nshtrainer/config/_checkpoint/metadata/__init__.py +22 -8
- nshtrainer/config/_directory/__init__.py +21 -9
- nshtrainer/config/_hf_hub/__init__.py +25 -9
- nshtrainer/config/callbacks/__init__.py +114 -29
- nshtrainer/config/callbacks/actsave/__init__.py +20 -8
- nshtrainer/config/callbacks/base/__init__.py +17 -7
- nshtrainer/config/callbacks/checkpoint/__init__.py +62 -13
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +33 -9
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +40 -10
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +33 -9
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +26 -8
- nshtrainer/config/callbacks/debug_flag/__init__.py +24 -8
- nshtrainer/config/callbacks/directory_setup/__init__.py +26 -8
- nshtrainer/config/callbacks/early_stopping/__init__.py +31 -9
- nshtrainer/config/callbacks/ema/__init__.py +20 -8
- nshtrainer/config/callbacks/finite_checks/__init__.py +26 -8
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +26 -8
- nshtrainer/config/callbacks/norm_logging/__init__.py +24 -8
- nshtrainer/config/callbacks/print_table/__init__.py +26 -8
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +26 -8
- nshtrainer/config/callbacks/shared_parameters/__init__.py +26 -8
- nshtrainer/config/callbacks/throughput_monitor/__init__.py +26 -8
- nshtrainer/config/callbacks/timer/__init__.py +22 -8
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +26 -8
- nshtrainer/config/callbacks/wandb_watch/__init__.py +24 -8
- nshtrainer/config/loggers/__init__.py +41 -14
- nshtrainer/config/loggers/_base/__init__.py +15 -7
- nshtrainer/config/loggers/csv/__init__.py +18 -8
- nshtrainer/config/loggers/tensorboard/__init__.py +24 -8
- nshtrainer/config/loggers/wandb/__init__.py +31 -11
- nshtrainer/config/lr_scheduler/__init__.py +49 -12
- nshtrainer/config/lr_scheduler/_base/__init__.py +19 -7
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +33 -9
- nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +33 -9
- nshtrainer/config/metrics/__init__.py +16 -7
- nshtrainer/config/metrics/_config/__init__.py +15 -7
- nshtrainer/config/model/__init__.py +31 -12
- nshtrainer/config/model/base/__init__.py +18 -8
- nshtrainer/config/model/config/__init__.py +30 -12
- nshtrainer/config/model/mixins/logger/__init__.py +15 -7
- nshtrainer/config/nn/__init__.py +68 -23
- nshtrainer/config/nn/mlp/__init__.py +21 -9
- nshtrainer/config/nn/nonlinearity/__init__.py +118 -22
- nshtrainer/config/optimizer/__init__.py +21 -9
- nshtrainer/config/profiler/__init__.py +28 -11
- nshtrainer/config/profiler/_base/__init__.py +17 -7
- nshtrainer/config/profiler/advanced/__init__.py +24 -8
- nshtrainer/config/profiler/pytorch/__init__.py +24 -8
- nshtrainer/config/profiler/simple/__init__.py +22 -8
- nshtrainer/config/runner/__init__.py +15 -7
- nshtrainer/config/trainer/_config/__init__.py +144 -30
- nshtrainer/config/trainer/checkpoint_connector/__init__.py +19 -7
- nshtrainer/config/util/_environment_info/__init__.py +87 -17
- nshtrainer/config/util/config/__init__.py +25 -10
- nshtrainer/config/util/config/dtype/__init__.py +15 -7
- nshtrainer/config/util/config/duration/__init__.py +27 -9
- {nshtrainer-0.41.0.dist-info → nshtrainer-0.42.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.41.0.dist-info → nshtrainer-0.42.0.dist-info}/RECORD +62 -62
- {nshtrainer-0.41.0.dist-info → nshtrainer-0.42.0.dist-info}/WHEEL +0 -0
|
@@ -1,27 +1,123 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
from nshtrainer.nn.nonlinearity import
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
from nshtrainer.nn.nonlinearity import
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
from nshtrainer.nn.nonlinearity import
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
from nshtrainer.nn.nonlinearity import
|
|
22
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.nn.nonlinearity import (
|
|
9
|
+
BaseNonlinearityConfig as BaseNonlinearityConfig,
|
|
10
|
+
)
|
|
11
|
+
from nshtrainer.nn.nonlinearity import (
|
|
12
|
+
ELUNonlinearityConfig as ELUNonlinearityConfig,
|
|
13
|
+
)
|
|
14
|
+
from nshtrainer.nn.nonlinearity import (
|
|
15
|
+
GELUNonlinearityConfig as GELUNonlinearityConfig,
|
|
16
|
+
)
|
|
17
|
+
from nshtrainer.nn.nonlinearity import (
|
|
18
|
+
LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig,
|
|
19
|
+
)
|
|
20
|
+
from nshtrainer.nn.nonlinearity import (
|
|
21
|
+
MishNonlinearityConfig as MishNonlinearityConfig,
|
|
22
|
+
)
|
|
23
|
+
from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
|
|
24
|
+
from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
|
|
25
|
+
from nshtrainer.nn.nonlinearity import (
|
|
26
|
+
ReLUNonlinearityConfig as ReLUNonlinearityConfig,
|
|
27
|
+
)
|
|
28
|
+
from nshtrainer.nn.nonlinearity import (
|
|
29
|
+
SigmoidNonlinearityConfig as SigmoidNonlinearityConfig,
|
|
30
|
+
)
|
|
31
|
+
from nshtrainer.nn.nonlinearity import (
|
|
32
|
+
SiLUNonlinearityConfig as SiLUNonlinearityConfig,
|
|
33
|
+
)
|
|
34
|
+
from nshtrainer.nn.nonlinearity import (
|
|
35
|
+
SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig,
|
|
36
|
+
)
|
|
37
|
+
from nshtrainer.nn.nonlinearity import (
|
|
38
|
+
SoftplusNonlinearityConfig as SoftplusNonlinearityConfig,
|
|
39
|
+
)
|
|
40
|
+
from nshtrainer.nn.nonlinearity import (
|
|
41
|
+
SoftsignNonlinearityConfig as SoftsignNonlinearityConfig,
|
|
42
|
+
)
|
|
43
|
+
from nshtrainer.nn.nonlinearity import (
|
|
44
|
+
SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
|
|
45
|
+
)
|
|
46
|
+
from nshtrainer.nn.nonlinearity import (
|
|
47
|
+
SwishNonlinearityConfig as SwishNonlinearityConfig,
|
|
48
|
+
)
|
|
49
|
+
from nshtrainer.nn.nonlinearity import (
|
|
50
|
+
TanhNonlinearityConfig as TanhNonlinearityConfig,
|
|
51
|
+
)
|
|
52
|
+
else:
|
|
53
|
+
|
|
54
|
+
def __getattr__(name):
|
|
55
|
+
import importlib
|
|
23
56
|
|
|
24
|
-
|
|
25
|
-
|
|
57
|
+
if name in globals():
|
|
58
|
+
return globals()[name]
|
|
59
|
+
if name == "SwiGLUNonlinearityConfig":
|
|
60
|
+
return importlib.import_module(
|
|
61
|
+
"nshtrainer.nn.nonlinearity"
|
|
62
|
+
).SwiGLUNonlinearityConfig
|
|
63
|
+
if name == "ReLUNonlinearityConfig":
|
|
64
|
+
return importlib.import_module(
|
|
65
|
+
"nshtrainer.nn.nonlinearity"
|
|
66
|
+
).ReLUNonlinearityConfig
|
|
67
|
+
if name == "SiLUNonlinearityConfig":
|
|
68
|
+
return importlib.import_module(
|
|
69
|
+
"nshtrainer.nn.nonlinearity"
|
|
70
|
+
).SiLUNonlinearityConfig
|
|
71
|
+
if name == "ELUNonlinearityConfig":
|
|
72
|
+
return importlib.import_module(
|
|
73
|
+
"nshtrainer.nn.nonlinearity"
|
|
74
|
+
).ELUNonlinearityConfig
|
|
75
|
+
if name == "GELUNonlinearityConfig":
|
|
76
|
+
return importlib.import_module(
|
|
77
|
+
"nshtrainer.nn.nonlinearity"
|
|
78
|
+
).GELUNonlinearityConfig
|
|
79
|
+
if name == "SoftplusNonlinearityConfig":
|
|
80
|
+
return importlib.import_module(
|
|
81
|
+
"nshtrainer.nn.nonlinearity"
|
|
82
|
+
).SoftplusNonlinearityConfig
|
|
83
|
+
if name == "SoftsignNonlinearityConfig":
|
|
84
|
+
return importlib.import_module(
|
|
85
|
+
"nshtrainer.nn.nonlinearity"
|
|
86
|
+
).SoftsignNonlinearityConfig
|
|
87
|
+
if name == "SwishNonlinearityConfig":
|
|
88
|
+
return importlib.import_module(
|
|
89
|
+
"nshtrainer.nn.nonlinearity"
|
|
90
|
+
).SwishNonlinearityConfig
|
|
91
|
+
if name == "SoftmaxNonlinearityConfig":
|
|
92
|
+
return importlib.import_module(
|
|
93
|
+
"nshtrainer.nn.nonlinearity"
|
|
94
|
+
).SoftmaxNonlinearityConfig
|
|
95
|
+
if name == "MishNonlinearityConfig":
|
|
96
|
+
return importlib.import_module(
|
|
97
|
+
"nshtrainer.nn.nonlinearity"
|
|
98
|
+
).MishNonlinearityConfig
|
|
99
|
+
if name == "SigmoidNonlinearityConfig":
|
|
100
|
+
return importlib.import_module(
|
|
101
|
+
"nshtrainer.nn.nonlinearity"
|
|
102
|
+
).SigmoidNonlinearityConfig
|
|
103
|
+
if name == "TanhNonlinearityConfig":
|
|
104
|
+
return importlib.import_module(
|
|
105
|
+
"nshtrainer.nn.nonlinearity"
|
|
106
|
+
).TanhNonlinearityConfig
|
|
107
|
+
if name == "BaseNonlinearityConfig":
|
|
108
|
+
return importlib.import_module(
|
|
109
|
+
"nshtrainer.nn.nonlinearity"
|
|
110
|
+
).BaseNonlinearityConfig
|
|
111
|
+
if name == "PReLUConfig":
|
|
112
|
+
return importlib.import_module("nshtrainer.nn.nonlinearity").PReLUConfig
|
|
113
|
+
if name == "LeakyReLUNonlinearityConfig":
|
|
114
|
+
return importlib.import_module(
|
|
115
|
+
"nshtrainer.nn.nonlinearity"
|
|
116
|
+
).LeakyReLUNonlinearityConfig
|
|
117
|
+
if name == "NonlinearityConfig":
|
|
118
|
+
return importlib.import_module(
|
|
119
|
+
"nshtrainer.nn.nonlinearity"
|
|
120
|
+
).NonlinearityConfig
|
|
121
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
26
122
|
|
|
27
123
|
# Submodule exports
|
|
@@ -1,14 +1,26 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.optimizer import AdamWConfig as AdamWConfig
|
|
9
|
+
from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
|
|
10
|
+
from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
|
|
11
|
+
else:
|
|
12
|
+
|
|
13
|
+
def __getattr__(name):
|
|
14
|
+
import importlib
|
|
10
15
|
|
|
11
|
-
|
|
12
|
-
|
|
16
|
+
if name in globals():
|
|
17
|
+
return globals()[name]
|
|
18
|
+
if name == "OptimizerConfigBase":
|
|
19
|
+
return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
|
|
20
|
+
if name == "AdamWConfig":
|
|
21
|
+
return importlib.import_module("nshtrainer.optimizer").AdamWConfig
|
|
22
|
+
if name == "OptimizerConfig":
|
|
23
|
+
return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
|
|
24
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
13
25
|
|
|
14
26
|
# Submodule exports
|
|
@@ -1,17 +1,34 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
|
|
9
|
+
from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
|
|
10
|
+
from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
|
|
11
|
+
from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
|
|
12
|
+
from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
|
|
13
|
+
else:
|
|
14
|
+
|
|
15
|
+
def __getattr__(name):
|
|
16
|
+
import importlib
|
|
17
|
+
|
|
18
|
+
if name in globals():
|
|
19
|
+
return globals()[name]
|
|
20
|
+
if name == "BaseProfilerConfig":
|
|
21
|
+
return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
|
|
22
|
+
if name == "PyTorchProfilerConfig":
|
|
23
|
+
return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
|
|
24
|
+
if name == "AdvancedProfilerConfig":
|
|
25
|
+
return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
|
|
26
|
+
if name == "SimpleProfilerConfig":
|
|
27
|
+
return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
|
|
28
|
+
if name == "ProfilerConfig":
|
|
29
|
+
return importlib.import_module("nshtrainer.profiler").ProfilerConfig
|
|
30
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
12
31
|
|
|
13
|
-
# Type aliases
|
|
14
|
-
from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
|
|
15
32
|
|
|
16
33
|
# Submodule exports
|
|
17
34
|
from . import _base as _base
|
|
@@ -1,12 +1,22 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.profiler._base import BaseProfilerConfig as BaseProfilerConfig
|
|
9
|
+
else:
|
|
10
|
+
|
|
11
|
+
def __getattr__(name):
|
|
12
|
+
import importlib
|
|
9
13
|
|
|
10
|
-
|
|
14
|
+
if name in globals():
|
|
15
|
+
return globals()[name]
|
|
16
|
+
if name == "BaseProfilerConfig":
|
|
17
|
+
return importlib.import_module(
|
|
18
|
+
"nshtrainer.profiler._base"
|
|
19
|
+
).BaseProfilerConfig
|
|
20
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
11
21
|
|
|
12
22
|
# Submodule exports
|
|
@@ -1,13 +1,29 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.profiler.advanced import (
|
|
9
|
+
AdvancedProfilerConfig as AdvancedProfilerConfig,
|
|
10
|
+
)
|
|
11
|
+
from nshtrainer.profiler.advanced import BaseProfilerConfig as BaseProfilerConfig
|
|
12
|
+
else:
|
|
13
|
+
|
|
14
|
+
def __getattr__(name):
|
|
15
|
+
import importlib
|
|
10
16
|
|
|
11
|
-
|
|
17
|
+
if name in globals():
|
|
18
|
+
return globals()[name]
|
|
19
|
+
if name == "BaseProfilerConfig":
|
|
20
|
+
return importlib.import_module(
|
|
21
|
+
"nshtrainer.profiler.advanced"
|
|
22
|
+
).BaseProfilerConfig
|
|
23
|
+
if name == "AdvancedProfilerConfig":
|
|
24
|
+
return importlib.import_module(
|
|
25
|
+
"nshtrainer.profiler.advanced"
|
|
26
|
+
).AdvancedProfilerConfig
|
|
27
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
12
28
|
|
|
13
29
|
# Submodule exports
|
|
@@ -1,13 +1,29 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.profiler.pytorch import BaseProfilerConfig as BaseProfilerConfig
|
|
9
|
+
from nshtrainer.profiler.pytorch import (
|
|
10
|
+
PyTorchProfilerConfig as PyTorchProfilerConfig,
|
|
11
|
+
)
|
|
12
|
+
else:
|
|
13
|
+
|
|
14
|
+
def __getattr__(name):
|
|
15
|
+
import importlib
|
|
10
16
|
|
|
11
|
-
|
|
17
|
+
if name in globals():
|
|
18
|
+
return globals()[name]
|
|
19
|
+
if name == "BaseProfilerConfig":
|
|
20
|
+
return importlib.import_module(
|
|
21
|
+
"nshtrainer.profiler.pytorch"
|
|
22
|
+
).BaseProfilerConfig
|
|
23
|
+
if name == "PyTorchProfilerConfig":
|
|
24
|
+
return importlib.import_module(
|
|
25
|
+
"nshtrainer.profiler.pytorch"
|
|
26
|
+
).PyTorchProfilerConfig
|
|
27
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
12
28
|
|
|
13
29
|
# Submodule exports
|
|
@@ -1,13 +1,27 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.profiler.simple import BaseProfilerConfig as BaseProfilerConfig
|
|
9
|
+
from nshtrainer.profiler.simple import SimpleProfilerConfig as SimpleProfilerConfig
|
|
10
|
+
else:
|
|
11
|
+
|
|
12
|
+
def __getattr__(name):
|
|
13
|
+
import importlib
|
|
10
14
|
|
|
11
|
-
|
|
15
|
+
if name in globals():
|
|
16
|
+
return globals()[name]
|
|
17
|
+
if name == "BaseProfilerConfig":
|
|
18
|
+
return importlib.import_module(
|
|
19
|
+
"nshtrainer.profiler.simple"
|
|
20
|
+
).BaseProfilerConfig
|
|
21
|
+
if name == "SimpleProfilerConfig":
|
|
22
|
+
return importlib.import_module(
|
|
23
|
+
"nshtrainer.profiler.simple"
|
|
24
|
+
).SimpleProfilerConfig
|
|
25
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
12
26
|
|
|
13
27
|
# Submodule exports
|
|
@@ -1,12 +1,20 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.runner import BaseConfig as BaseConfig
|
|
9
|
+
else:
|
|
10
|
+
|
|
11
|
+
def __getattr__(name):
|
|
12
|
+
import importlib
|
|
9
13
|
|
|
10
|
-
|
|
14
|
+
if name in globals():
|
|
15
|
+
return globals()[name]
|
|
16
|
+
if name == "BaseConfig":
|
|
17
|
+
return importlib.import_module("nshtrainer.runner").BaseConfig
|
|
18
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
11
19
|
|
|
12
20
|
# Submodule exports
|
|
@@ -1,35 +1,149 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
from nshtrainer.trainer._config import
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
from nshtrainer.trainer._config import
|
|
16
|
-
from nshtrainer.trainer._config import
|
|
17
|
-
from nshtrainer.trainer._config import
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
from nshtrainer.trainer._config import
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
from nshtrainer.trainer._config import
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
from nshtrainer.trainer._config import
|
|
27
|
-
from nshtrainer.trainer._config import
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.trainer._config import (
|
|
9
|
+
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
|
10
|
+
)
|
|
11
|
+
from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
|
|
12
|
+
from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
|
|
13
|
+
from nshtrainer.trainer._config import (
|
|
14
|
+
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
|
15
|
+
)
|
|
16
|
+
from nshtrainer.trainer._config import (
|
|
17
|
+
CheckpointLoadingConfig as CheckpointLoadingConfig,
|
|
18
|
+
)
|
|
19
|
+
from nshtrainer.trainer._config import (
|
|
20
|
+
CheckpointSavingConfig as CheckpointSavingConfig,
|
|
21
|
+
)
|
|
22
|
+
from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
|
|
23
|
+
from nshtrainer.trainer._config import (
|
|
24
|
+
DebugFlagCallbackConfig as DebugFlagCallbackConfig,
|
|
25
|
+
)
|
|
26
|
+
from nshtrainer.trainer._config import EarlyStoppingConfig as EarlyStoppingConfig
|
|
27
|
+
from nshtrainer.trainer._config import (
|
|
28
|
+
GradientClippingConfig as GradientClippingConfig,
|
|
29
|
+
)
|
|
30
|
+
from nshtrainer.trainer._config import HuggingFaceHubConfig as HuggingFaceHubConfig
|
|
31
|
+
from nshtrainer.trainer._config import (
|
|
32
|
+
LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
|
|
33
|
+
)
|
|
34
|
+
from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
|
|
35
|
+
from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
|
|
36
|
+
from nshtrainer.trainer._config import (
|
|
37
|
+
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
|
38
|
+
)
|
|
39
|
+
from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
|
|
40
|
+
from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
|
|
41
|
+
from nshtrainer.trainer._config import (
|
|
42
|
+
ReproducibilityConfig as ReproducibilityConfig,
|
|
43
|
+
)
|
|
44
|
+
from nshtrainer.trainer._config import (
|
|
45
|
+
RLPSanityChecksConfig as RLPSanityChecksConfig,
|
|
46
|
+
)
|
|
47
|
+
from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
|
|
48
|
+
from nshtrainer.trainer._config import (
|
|
49
|
+
SharedParametersConfig as SharedParametersConfig,
|
|
50
|
+
)
|
|
51
|
+
from nshtrainer.trainer._config import (
|
|
52
|
+
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
|
53
|
+
)
|
|
54
|
+
from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
|
|
55
|
+
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
|
56
|
+
else:
|
|
57
|
+
|
|
58
|
+
def __getattr__(name):
|
|
59
|
+
import importlib
|
|
28
60
|
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
61
|
+
if name in globals():
|
|
62
|
+
return globals()[name]
|
|
63
|
+
if name == "HuggingFaceHubConfig":
|
|
64
|
+
return importlib.import_module(
|
|
65
|
+
"nshtrainer.trainer._config"
|
|
66
|
+
).HuggingFaceHubConfig
|
|
67
|
+
if name == "OptimizationConfig":
|
|
68
|
+
return importlib.import_module(
|
|
69
|
+
"nshtrainer.trainer._config"
|
|
70
|
+
).OptimizationConfig
|
|
71
|
+
if name == "TrainerConfig":
|
|
72
|
+
return importlib.import_module("nshtrainer.trainer._config").TrainerConfig
|
|
73
|
+
if name == "TensorboardLoggerConfig":
|
|
74
|
+
return importlib.import_module(
|
|
75
|
+
"nshtrainer.trainer._config"
|
|
76
|
+
).TensorboardLoggerConfig
|
|
77
|
+
if name == "GradientClippingConfig":
|
|
78
|
+
return importlib.import_module(
|
|
79
|
+
"nshtrainer.trainer._config"
|
|
80
|
+
).GradientClippingConfig
|
|
81
|
+
if name == "CallbackConfigBase":
|
|
82
|
+
return importlib.import_module(
|
|
83
|
+
"nshtrainer.trainer._config"
|
|
84
|
+
).CallbackConfigBase
|
|
85
|
+
if name == "CSVLoggerConfig":
|
|
86
|
+
return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
|
|
87
|
+
if name == "LastCheckpointCallbackConfig":
|
|
88
|
+
return importlib.import_module(
|
|
89
|
+
"nshtrainer.trainer._config"
|
|
90
|
+
).LastCheckpointCallbackConfig
|
|
91
|
+
if name == "OnExceptionCheckpointCallbackConfig":
|
|
92
|
+
return importlib.import_module(
|
|
93
|
+
"nshtrainer.trainer._config"
|
|
94
|
+
).OnExceptionCheckpointCallbackConfig
|
|
95
|
+
if name == "RLPSanityChecksConfig":
|
|
96
|
+
return importlib.import_module(
|
|
97
|
+
"nshtrainer.trainer._config"
|
|
98
|
+
).RLPSanityChecksConfig
|
|
99
|
+
if name == "EarlyStoppingConfig":
|
|
100
|
+
return importlib.import_module(
|
|
101
|
+
"nshtrainer.trainer._config"
|
|
102
|
+
).EarlyStoppingConfig
|
|
103
|
+
if name == "DebugFlagCallbackConfig":
|
|
104
|
+
return importlib.import_module(
|
|
105
|
+
"nshtrainer.trainer._config"
|
|
106
|
+
).DebugFlagCallbackConfig
|
|
107
|
+
if name == "WandbLoggerConfig":
|
|
108
|
+
return importlib.import_module(
|
|
109
|
+
"nshtrainer.trainer._config"
|
|
110
|
+
).WandbLoggerConfig
|
|
111
|
+
if name == "CheckpointSavingConfig":
|
|
112
|
+
return importlib.import_module(
|
|
113
|
+
"nshtrainer.trainer._config"
|
|
114
|
+
).CheckpointSavingConfig
|
|
115
|
+
if name == "CheckpointLoadingConfig":
|
|
116
|
+
return importlib.import_module(
|
|
117
|
+
"nshtrainer.trainer._config"
|
|
118
|
+
).CheckpointLoadingConfig
|
|
119
|
+
if name == "BestCheckpointCallbackConfig":
|
|
120
|
+
return importlib.import_module(
|
|
121
|
+
"nshtrainer.trainer._config"
|
|
122
|
+
).BestCheckpointCallbackConfig
|
|
123
|
+
if name == "LoggingConfig":
|
|
124
|
+
return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
|
|
125
|
+
if name == "SanityCheckingConfig":
|
|
126
|
+
return importlib.import_module(
|
|
127
|
+
"nshtrainer.trainer._config"
|
|
128
|
+
).SanityCheckingConfig
|
|
129
|
+
if name == "SharedParametersConfig":
|
|
130
|
+
return importlib.import_module(
|
|
131
|
+
"nshtrainer.trainer._config"
|
|
132
|
+
).SharedParametersConfig
|
|
133
|
+
if name == "ReproducibilityConfig":
|
|
134
|
+
return importlib.import_module(
|
|
135
|
+
"nshtrainer.trainer._config"
|
|
136
|
+
).ReproducibilityConfig
|
|
137
|
+
if name == "CallbackConfig":
|
|
138
|
+
return importlib.import_module("nshtrainer.trainer._config").CallbackConfig
|
|
139
|
+
if name == "CheckpointCallbackConfig":
|
|
140
|
+
return importlib.import_module(
|
|
141
|
+
"nshtrainer.trainer._config"
|
|
142
|
+
).CheckpointCallbackConfig
|
|
143
|
+
if name == "LoggerConfig":
|
|
144
|
+
return importlib.import_module("nshtrainer.trainer._config").LoggerConfig
|
|
145
|
+
if name == "ProfilerConfig":
|
|
146
|
+
return importlib.import_module("nshtrainer.trainer._config").ProfilerConfig
|
|
147
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
34
148
|
|
|
35
149
|
# Submodule exports
|
|
@@ -1,12 +1,24 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.trainer.checkpoint_connector import (
|
|
9
|
+
CheckpointLoadingConfig as CheckpointLoadingConfig,
|
|
10
|
+
)
|
|
11
|
+
else:
|
|
12
|
+
|
|
13
|
+
def __getattr__(name):
|
|
14
|
+
import importlib
|
|
9
15
|
|
|
10
|
-
|
|
16
|
+
if name in globals():
|
|
17
|
+
return globals()[name]
|
|
18
|
+
if name == "CheckpointLoadingConfig":
|
|
19
|
+
return importlib.import_module(
|
|
20
|
+
"nshtrainer.trainer.checkpoint_connector"
|
|
21
|
+
).CheckpointLoadingConfig
|
|
22
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
11
23
|
|
|
12
24
|
# Submodule exports
|