nshtrainer 0.41.0__py3-none-any.whl → 0.42.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/config/__init__.py +406 -95
  3. nshtrainer/config/_checkpoint/loader/__init__.py +55 -13
  4. nshtrainer/config/_checkpoint/metadata/__init__.py +22 -8
  5. nshtrainer/config/_directory/__init__.py +21 -9
  6. nshtrainer/config/_hf_hub/__init__.py +25 -9
  7. nshtrainer/config/callbacks/__init__.py +114 -29
  8. nshtrainer/config/callbacks/actsave/__init__.py +20 -8
  9. nshtrainer/config/callbacks/base/__init__.py +17 -7
  10. nshtrainer/config/callbacks/checkpoint/__init__.py +62 -13
  11. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +33 -9
  12. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +40 -10
  13. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +33 -9
  14. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +26 -8
  15. nshtrainer/config/callbacks/debug_flag/__init__.py +24 -8
  16. nshtrainer/config/callbacks/directory_setup/__init__.py +26 -8
  17. nshtrainer/config/callbacks/early_stopping/__init__.py +31 -9
  18. nshtrainer/config/callbacks/ema/__init__.py +20 -8
  19. nshtrainer/config/callbacks/finite_checks/__init__.py +26 -8
  20. nshtrainer/config/callbacks/gradient_skipping/__init__.py +26 -8
  21. nshtrainer/config/callbacks/norm_logging/__init__.py +24 -8
  22. nshtrainer/config/callbacks/print_table/__init__.py +26 -8
  23. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +26 -8
  24. nshtrainer/config/callbacks/shared_parameters/__init__.py +26 -8
  25. nshtrainer/config/callbacks/throughput_monitor/__init__.py +26 -8
  26. nshtrainer/config/callbacks/timer/__init__.py +22 -8
  27. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +26 -8
  28. nshtrainer/config/callbacks/wandb_watch/__init__.py +24 -8
  29. nshtrainer/config/loggers/__init__.py +41 -14
  30. nshtrainer/config/loggers/_base/__init__.py +15 -7
  31. nshtrainer/config/loggers/csv/__init__.py +18 -8
  32. nshtrainer/config/loggers/tensorboard/__init__.py +24 -8
  33. nshtrainer/config/loggers/wandb/__init__.py +31 -11
  34. nshtrainer/config/lr_scheduler/__init__.py +49 -12
  35. nshtrainer/config/lr_scheduler/_base/__init__.py +19 -7
  36. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +33 -9
  37. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +33 -9
  38. nshtrainer/config/metrics/__init__.py +16 -7
  39. nshtrainer/config/metrics/_config/__init__.py +15 -7
  40. nshtrainer/config/model/__init__.py +31 -12
  41. nshtrainer/config/model/base/__init__.py +18 -8
  42. nshtrainer/config/model/config/__init__.py +30 -12
  43. nshtrainer/config/model/mixins/logger/__init__.py +15 -7
  44. nshtrainer/config/nn/__init__.py +68 -23
  45. nshtrainer/config/nn/mlp/__init__.py +21 -9
  46. nshtrainer/config/nn/nonlinearity/__init__.py +118 -22
  47. nshtrainer/config/optimizer/__init__.py +21 -9
  48. nshtrainer/config/profiler/__init__.py +28 -11
  49. nshtrainer/config/profiler/_base/__init__.py +17 -7
  50. nshtrainer/config/profiler/advanced/__init__.py +24 -8
  51. nshtrainer/config/profiler/pytorch/__init__.py +24 -8
  52. nshtrainer/config/profiler/simple/__init__.py +22 -8
  53. nshtrainer/config/runner/__init__.py +15 -7
  54. nshtrainer/config/trainer/_config/__init__.py +144 -30
  55. nshtrainer/config/trainer/checkpoint_connector/__init__.py +19 -7
  56. nshtrainer/config/util/_environment_info/__init__.py +87 -17
  57. nshtrainer/config/util/config/__init__.py +25 -10
  58. nshtrainer/config/util/config/dtype/__init__.py +15 -7
  59. nshtrainer/config/util/config/duration/__init__.py +27 -9
  60. {nshtrainer-0.41.0.dist-info → nshtrainer-0.42.0.dist-info}/METADATA +1 -1
  61. {nshtrainer-0.41.0.dist-info → nshtrainer-0.42.0.dist-info}/RECORD +62 -62
  62. {nshtrainer-0.41.0.dist-info → nshtrainer-0.42.0.dist-info}/WHEEL +0 -0
@@ -1,27 +1,123 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.nn.nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
9
- from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
10
- from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
11
- from nshtrainer.nn.nonlinearity import SiLUNonlinearityConfig as SiLUNonlinearityConfig
12
- from nshtrainer.nn.nonlinearity import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
13
- from nshtrainer.nn.nonlinearity import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
14
- from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
15
- from nshtrainer.nn.nonlinearity import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
16
- from nshtrainer.nn.nonlinearity import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
17
- from nshtrainer.nn.nonlinearity import SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig
18
- from nshtrainer.nn.nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
19
- from nshtrainer.nn.nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
20
- from nshtrainer.nn.nonlinearity import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
21
- from nshtrainer.nn.nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
22
- from nshtrainer.nn.nonlinearity import SwishNonlinearityConfig as SwishNonlinearityConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.nn.nonlinearity import (
9
+ BaseNonlinearityConfig as BaseNonlinearityConfig,
10
+ )
11
+ from nshtrainer.nn.nonlinearity import (
12
+ ELUNonlinearityConfig as ELUNonlinearityConfig,
13
+ )
14
+ from nshtrainer.nn.nonlinearity import (
15
+ GELUNonlinearityConfig as GELUNonlinearityConfig,
16
+ )
17
+ from nshtrainer.nn.nonlinearity import (
18
+ LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig,
19
+ )
20
+ from nshtrainer.nn.nonlinearity import (
21
+ MishNonlinearityConfig as MishNonlinearityConfig,
22
+ )
23
+ from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
24
+ from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
25
+ from nshtrainer.nn.nonlinearity import (
26
+ ReLUNonlinearityConfig as ReLUNonlinearityConfig,
27
+ )
28
+ from nshtrainer.nn.nonlinearity import (
29
+ SigmoidNonlinearityConfig as SigmoidNonlinearityConfig,
30
+ )
31
+ from nshtrainer.nn.nonlinearity import (
32
+ SiLUNonlinearityConfig as SiLUNonlinearityConfig,
33
+ )
34
+ from nshtrainer.nn.nonlinearity import (
35
+ SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig,
36
+ )
37
+ from nshtrainer.nn.nonlinearity import (
38
+ SoftplusNonlinearityConfig as SoftplusNonlinearityConfig,
39
+ )
40
+ from nshtrainer.nn.nonlinearity import (
41
+ SoftsignNonlinearityConfig as SoftsignNonlinearityConfig,
42
+ )
43
+ from nshtrainer.nn.nonlinearity import (
44
+ SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
45
+ )
46
+ from nshtrainer.nn.nonlinearity import (
47
+ SwishNonlinearityConfig as SwishNonlinearityConfig,
48
+ )
49
+ from nshtrainer.nn.nonlinearity import (
50
+ TanhNonlinearityConfig as TanhNonlinearityConfig,
51
+ )
52
+ else:
53
+
54
+ def __getattr__(name):
55
+ import importlib
23
56
 
24
- # Type aliases
25
- from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
57
+ if name in globals():
58
+ return globals()[name]
59
+ if name == "SwiGLUNonlinearityConfig":
60
+ return importlib.import_module(
61
+ "nshtrainer.nn.nonlinearity"
62
+ ).SwiGLUNonlinearityConfig
63
+ if name == "ReLUNonlinearityConfig":
64
+ return importlib.import_module(
65
+ "nshtrainer.nn.nonlinearity"
66
+ ).ReLUNonlinearityConfig
67
+ if name == "SiLUNonlinearityConfig":
68
+ return importlib.import_module(
69
+ "nshtrainer.nn.nonlinearity"
70
+ ).SiLUNonlinearityConfig
71
+ if name == "ELUNonlinearityConfig":
72
+ return importlib.import_module(
73
+ "nshtrainer.nn.nonlinearity"
74
+ ).ELUNonlinearityConfig
75
+ if name == "GELUNonlinearityConfig":
76
+ return importlib.import_module(
77
+ "nshtrainer.nn.nonlinearity"
78
+ ).GELUNonlinearityConfig
79
+ if name == "SoftplusNonlinearityConfig":
80
+ return importlib.import_module(
81
+ "nshtrainer.nn.nonlinearity"
82
+ ).SoftplusNonlinearityConfig
83
+ if name == "SoftsignNonlinearityConfig":
84
+ return importlib.import_module(
85
+ "nshtrainer.nn.nonlinearity"
86
+ ).SoftsignNonlinearityConfig
87
+ if name == "SwishNonlinearityConfig":
88
+ return importlib.import_module(
89
+ "nshtrainer.nn.nonlinearity"
90
+ ).SwishNonlinearityConfig
91
+ if name == "SoftmaxNonlinearityConfig":
92
+ return importlib.import_module(
93
+ "nshtrainer.nn.nonlinearity"
94
+ ).SoftmaxNonlinearityConfig
95
+ if name == "MishNonlinearityConfig":
96
+ return importlib.import_module(
97
+ "nshtrainer.nn.nonlinearity"
98
+ ).MishNonlinearityConfig
99
+ if name == "SigmoidNonlinearityConfig":
100
+ return importlib.import_module(
101
+ "nshtrainer.nn.nonlinearity"
102
+ ).SigmoidNonlinearityConfig
103
+ if name == "TanhNonlinearityConfig":
104
+ return importlib.import_module(
105
+ "nshtrainer.nn.nonlinearity"
106
+ ).TanhNonlinearityConfig
107
+ if name == "BaseNonlinearityConfig":
108
+ return importlib.import_module(
109
+ "nshtrainer.nn.nonlinearity"
110
+ ).BaseNonlinearityConfig
111
+ if name == "PReLUConfig":
112
+ return importlib.import_module("nshtrainer.nn.nonlinearity").PReLUConfig
113
+ if name == "LeakyReLUNonlinearityConfig":
114
+ return importlib.import_module(
115
+ "nshtrainer.nn.nonlinearity"
116
+ ).LeakyReLUNonlinearityConfig
117
+ if name == "NonlinearityConfig":
118
+ return importlib.import_module(
119
+ "nshtrainer.nn.nonlinearity"
120
+ ).NonlinearityConfig
121
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
26
122
 
27
123
  # Submodule exports
@@ -1,14 +1,26 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
9
- from nshtrainer.optimizer import AdamWConfig as AdamWConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.optimizer import AdamWConfig as AdamWConfig
9
+ from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
10
+ from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
11
+ else:
12
+
13
+ def __getattr__(name):
14
+ import importlib
10
15
 
11
- # Type aliases
12
- from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
16
+ if name in globals():
17
+ return globals()[name]
18
+ if name == "OptimizerConfigBase":
19
+ return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
20
+ if name == "AdamWConfig":
21
+ return importlib.import_module("nshtrainer.optimizer").AdamWConfig
22
+ if name == "OptimizerConfig":
23
+ return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
24
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
13
25
 
14
26
  # Submodule exports
@@ -1,17 +1,34 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
9
- from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
10
- from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
11
- from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
9
+ from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
10
+ from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
11
+ from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
12
+ from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
13
+ else:
14
+
15
+ def __getattr__(name):
16
+ import importlib
17
+
18
+ if name in globals():
19
+ return globals()[name]
20
+ if name == "BaseProfilerConfig":
21
+ return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
22
+ if name == "PyTorchProfilerConfig":
23
+ return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
24
+ if name == "AdvancedProfilerConfig":
25
+ return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
26
+ if name == "SimpleProfilerConfig":
27
+ return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
28
+ if name == "ProfilerConfig":
29
+ return importlib.import_module("nshtrainer.profiler").ProfilerConfig
30
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
31
 
13
- # Type aliases
14
- from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
15
32
 
16
33
  # Submodule exports
17
34
  from . import _base as _base
@@ -1,12 +1,22 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.profiler._base import BaseProfilerConfig as BaseProfilerConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.profiler._base import BaseProfilerConfig as BaseProfilerConfig
9
+ else:
10
+
11
+ def __getattr__(name):
12
+ import importlib
9
13
 
10
- # Type aliases
14
+ if name in globals():
15
+ return globals()[name]
16
+ if name == "BaseProfilerConfig":
17
+ return importlib.import_module(
18
+ "nshtrainer.profiler._base"
19
+ ).BaseProfilerConfig
20
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
11
21
 
12
22
  # Submodule exports
@@ -1,13 +1,29 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.profiler.advanced import AdvancedProfilerConfig as AdvancedProfilerConfig
9
- from nshtrainer.profiler.advanced import BaseProfilerConfig as BaseProfilerConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.profiler.advanced import (
9
+ AdvancedProfilerConfig as AdvancedProfilerConfig,
10
+ )
11
+ from nshtrainer.profiler.advanced import BaseProfilerConfig as BaseProfilerConfig
12
+ else:
13
+
14
+ def __getattr__(name):
15
+ import importlib
10
16
 
11
- # Type aliases
17
+ if name in globals():
18
+ return globals()[name]
19
+ if name == "BaseProfilerConfig":
20
+ return importlib.import_module(
21
+ "nshtrainer.profiler.advanced"
22
+ ).BaseProfilerConfig
23
+ if name == "AdvancedProfilerConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.profiler.advanced"
26
+ ).AdvancedProfilerConfig
27
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
28
 
13
29
  # Submodule exports
@@ -1,13 +1,29 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.profiler.pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
9
- from nshtrainer.profiler.pytorch import BaseProfilerConfig as BaseProfilerConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.profiler.pytorch import BaseProfilerConfig as BaseProfilerConfig
9
+ from nshtrainer.profiler.pytorch import (
10
+ PyTorchProfilerConfig as PyTorchProfilerConfig,
11
+ )
12
+ else:
13
+
14
+ def __getattr__(name):
15
+ import importlib
10
16
 
11
- # Type aliases
17
+ if name in globals():
18
+ return globals()[name]
19
+ if name == "BaseProfilerConfig":
20
+ return importlib.import_module(
21
+ "nshtrainer.profiler.pytorch"
22
+ ).BaseProfilerConfig
23
+ if name == "PyTorchProfilerConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.profiler.pytorch"
26
+ ).PyTorchProfilerConfig
27
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
28
 
13
29
  # Submodule exports
@@ -1,13 +1,27 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.profiler.simple import BaseProfilerConfig as BaseProfilerConfig
9
- from nshtrainer.profiler.simple import SimpleProfilerConfig as SimpleProfilerConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.profiler.simple import BaseProfilerConfig as BaseProfilerConfig
9
+ from nshtrainer.profiler.simple import SimpleProfilerConfig as SimpleProfilerConfig
10
+ else:
11
+
12
+ def __getattr__(name):
13
+ import importlib
10
14
 
11
- # Type aliases
15
+ if name in globals():
16
+ return globals()[name]
17
+ if name == "BaseProfilerConfig":
18
+ return importlib.import_module(
19
+ "nshtrainer.profiler.simple"
20
+ ).BaseProfilerConfig
21
+ if name == "SimpleProfilerConfig":
22
+ return importlib.import_module(
23
+ "nshtrainer.profiler.simple"
24
+ ).SimpleProfilerConfig
25
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
26
 
13
27
  # Submodule exports
@@ -1,12 +1,20 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.runner import BaseConfig as BaseConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.runner import BaseConfig as BaseConfig
9
+ else:
10
+
11
+ def __getattr__(name):
12
+ import importlib
9
13
 
10
- # Type aliases
14
+ if name in globals():
15
+ return globals()[name]
16
+ if name == "BaseConfig":
17
+ return importlib.import_module("nshtrainer.runner").BaseConfig
18
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
11
19
 
12
20
  # Submodule exports
@@ -1,35 +1,149 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.trainer._config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
9
- from nshtrainer.trainer._config import TensorboardLoggerConfig as TensorboardLoggerConfig
10
- from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
11
- from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
12
- from nshtrainer.trainer._config import SharedParametersConfig as SharedParametersConfig
13
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
14
- from nshtrainer.trainer._config import ReproducibilityConfig as ReproducibilityConfig
15
- from nshtrainer.trainer._config import HuggingFaceHubConfig as HuggingFaceHubConfig
16
- from nshtrainer.trainer._config import CheckpointLoadingConfig as CheckpointLoadingConfig
17
- from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
18
- from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
19
- from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
20
- from nshtrainer.trainer._config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
21
- from nshtrainer.trainer._config import OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig
22
- from nshtrainer.trainer._config import RLPSanityChecksConfig as RLPSanityChecksConfig
23
- from nshtrainer.trainer._config import EarlyStoppingConfig as EarlyStoppingConfig
24
- from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
25
- from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
26
- from nshtrainer.trainer._config import DebugFlagCallbackConfig as DebugFlagCallbackConfig
27
- from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.trainer._config import (
9
+ BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
10
+ )
11
+ from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
12
+ from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
13
+ from nshtrainer.trainer._config import (
14
+ CheckpointCallbackConfig as CheckpointCallbackConfig,
15
+ )
16
+ from nshtrainer.trainer._config import (
17
+ CheckpointLoadingConfig as CheckpointLoadingConfig,
18
+ )
19
+ from nshtrainer.trainer._config import (
20
+ CheckpointSavingConfig as CheckpointSavingConfig,
21
+ )
22
+ from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
23
+ from nshtrainer.trainer._config import (
24
+ DebugFlagCallbackConfig as DebugFlagCallbackConfig,
25
+ )
26
+ from nshtrainer.trainer._config import EarlyStoppingConfig as EarlyStoppingConfig
27
+ from nshtrainer.trainer._config import (
28
+ GradientClippingConfig as GradientClippingConfig,
29
+ )
30
+ from nshtrainer.trainer._config import HuggingFaceHubConfig as HuggingFaceHubConfig
31
+ from nshtrainer.trainer._config import (
32
+ LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
33
+ )
34
+ from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
35
+ from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
36
+ from nshtrainer.trainer._config import (
37
+ OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
38
+ )
39
+ from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
40
+ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
41
+ from nshtrainer.trainer._config import (
42
+ ReproducibilityConfig as ReproducibilityConfig,
43
+ )
44
+ from nshtrainer.trainer._config import (
45
+ RLPSanityChecksConfig as RLPSanityChecksConfig,
46
+ )
47
+ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
48
+ from nshtrainer.trainer._config import (
49
+ SharedParametersConfig as SharedParametersConfig,
50
+ )
51
+ from nshtrainer.trainer._config import (
52
+ TensorboardLoggerConfig as TensorboardLoggerConfig,
53
+ )
54
+ from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
55
+ from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
56
+ else:
57
+
58
+ def __getattr__(name):
59
+ import importlib
28
60
 
29
- # Type aliases
30
- from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
31
- from nshtrainer.trainer._config import CheckpointCallbackConfig as CheckpointCallbackConfig
32
- from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
33
- from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
61
+ if name in globals():
62
+ return globals()[name]
63
+ if name == "HuggingFaceHubConfig":
64
+ return importlib.import_module(
65
+ "nshtrainer.trainer._config"
66
+ ).HuggingFaceHubConfig
67
+ if name == "OptimizationConfig":
68
+ return importlib.import_module(
69
+ "nshtrainer.trainer._config"
70
+ ).OptimizationConfig
71
+ if name == "TrainerConfig":
72
+ return importlib.import_module("nshtrainer.trainer._config").TrainerConfig
73
+ if name == "TensorboardLoggerConfig":
74
+ return importlib.import_module(
75
+ "nshtrainer.trainer._config"
76
+ ).TensorboardLoggerConfig
77
+ if name == "GradientClippingConfig":
78
+ return importlib.import_module(
79
+ "nshtrainer.trainer._config"
80
+ ).GradientClippingConfig
81
+ if name == "CallbackConfigBase":
82
+ return importlib.import_module(
83
+ "nshtrainer.trainer._config"
84
+ ).CallbackConfigBase
85
+ if name == "CSVLoggerConfig":
86
+ return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
87
+ if name == "LastCheckpointCallbackConfig":
88
+ return importlib.import_module(
89
+ "nshtrainer.trainer._config"
90
+ ).LastCheckpointCallbackConfig
91
+ if name == "OnExceptionCheckpointCallbackConfig":
92
+ return importlib.import_module(
93
+ "nshtrainer.trainer._config"
94
+ ).OnExceptionCheckpointCallbackConfig
95
+ if name == "RLPSanityChecksConfig":
96
+ return importlib.import_module(
97
+ "nshtrainer.trainer._config"
98
+ ).RLPSanityChecksConfig
99
+ if name == "EarlyStoppingConfig":
100
+ return importlib.import_module(
101
+ "nshtrainer.trainer._config"
102
+ ).EarlyStoppingConfig
103
+ if name == "DebugFlagCallbackConfig":
104
+ return importlib.import_module(
105
+ "nshtrainer.trainer._config"
106
+ ).DebugFlagCallbackConfig
107
+ if name == "WandbLoggerConfig":
108
+ return importlib.import_module(
109
+ "nshtrainer.trainer._config"
110
+ ).WandbLoggerConfig
111
+ if name == "CheckpointSavingConfig":
112
+ return importlib.import_module(
113
+ "nshtrainer.trainer._config"
114
+ ).CheckpointSavingConfig
115
+ if name == "CheckpointLoadingConfig":
116
+ return importlib.import_module(
117
+ "nshtrainer.trainer._config"
118
+ ).CheckpointLoadingConfig
119
+ if name == "BestCheckpointCallbackConfig":
120
+ return importlib.import_module(
121
+ "nshtrainer.trainer._config"
122
+ ).BestCheckpointCallbackConfig
123
+ if name == "LoggingConfig":
124
+ return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
125
+ if name == "SanityCheckingConfig":
126
+ return importlib.import_module(
127
+ "nshtrainer.trainer._config"
128
+ ).SanityCheckingConfig
129
+ if name == "SharedParametersConfig":
130
+ return importlib.import_module(
131
+ "nshtrainer.trainer._config"
132
+ ).SharedParametersConfig
133
+ if name == "ReproducibilityConfig":
134
+ return importlib.import_module(
135
+ "nshtrainer.trainer._config"
136
+ ).ReproducibilityConfig
137
+ if name == "CallbackConfig":
138
+ return importlib.import_module("nshtrainer.trainer._config").CallbackConfig
139
+ if name == "CheckpointCallbackConfig":
140
+ return importlib.import_module(
141
+ "nshtrainer.trainer._config"
142
+ ).CheckpointCallbackConfig
143
+ if name == "LoggerConfig":
144
+ return importlib.import_module("nshtrainer.trainer._config").LoggerConfig
145
+ if name == "ProfilerConfig":
146
+ return importlib.import_module("nshtrainer.trainer._config").ProfilerConfig
147
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
34
148
 
35
149
  # Submodule exports
@@ -1,12 +1,24 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.trainer.checkpoint_connector import CheckpointLoadingConfig as CheckpointLoadingConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.trainer.checkpoint_connector import (
9
+ CheckpointLoadingConfig as CheckpointLoadingConfig,
10
+ )
11
+ else:
12
+
13
+ def __getattr__(name):
14
+ import importlib
9
15
 
10
- # Type aliases
16
+ if name in globals():
17
+ return globals()[name]
18
+ if name == "CheckpointLoadingConfig":
19
+ return importlib.import_module(
20
+ "nshtrainer.trainer.checkpoint_connector"
21
+ ).CheckpointLoadingConfig
22
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
11
23
 
12
24
  # Submodule exports