nshtrainer 0.41.1__py3-none-any.whl → 0.42.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. nshtrainer/config/__init__.py +406 -95
  2. nshtrainer/config/_checkpoint/loader/__init__.py +55 -13
  3. nshtrainer/config/_checkpoint/metadata/__init__.py +22 -8
  4. nshtrainer/config/_directory/__init__.py +21 -9
  5. nshtrainer/config/_hf_hub/__init__.py +25 -9
  6. nshtrainer/config/callbacks/__init__.py +114 -29
  7. nshtrainer/config/callbacks/actsave/__init__.py +20 -8
  8. nshtrainer/config/callbacks/base/__init__.py +17 -7
  9. nshtrainer/config/callbacks/checkpoint/__init__.py +62 -13
  10. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +33 -9
  11. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +40 -10
  12. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +33 -9
  13. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +26 -8
  14. nshtrainer/config/callbacks/debug_flag/__init__.py +24 -8
  15. nshtrainer/config/callbacks/directory_setup/__init__.py +26 -8
  16. nshtrainer/config/callbacks/early_stopping/__init__.py +31 -9
  17. nshtrainer/config/callbacks/ema/__init__.py +20 -8
  18. nshtrainer/config/callbacks/finite_checks/__init__.py +26 -8
  19. nshtrainer/config/callbacks/gradient_skipping/__init__.py +26 -8
  20. nshtrainer/config/callbacks/norm_logging/__init__.py +24 -8
  21. nshtrainer/config/callbacks/print_table/__init__.py +26 -8
  22. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +26 -8
  23. nshtrainer/config/callbacks/shared_parameters/__init__.py +26 -8
  24. nshtrainer/config/callbacks/throughput_monitor/__init__.py +26 -8
  25. nshtrainer/config/callbacks/timer/__init__.py +22 -8
  26. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +26 -8
  27. nshtrainer/config/callbacks/wandb_watch/__init__.py +24 -8
  28. nshtrainer/config/loggers/__init__.py +41 -14
  29. nshtrainer/config/loggers/_base/__init__.py +15 -7
  30. nshtrainer/config/loggers/csv/__init__.py +18 -8
  31. nshtrainer/config/loggers/tensorboard/__init__.py +24 -8
  32. nshtrainer/config/loggers/wandb/__init__.py +31 -11
  33. nshtrainer/config/lr_scheduler/__init__.py +49 -12
  34. nshtrainer/config/lr_scheduler/_base/__init__.py +19 -7
  35. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +33 -9
  36. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +33 -9
  37. nshtrainer/config/metrics/__init__.py +16 -7
  38. nshtrainer/config/metrics/_config/__init__.py +15 -7
  39. nshtrainer/config/model/__init__.py +31 -12
  40. nshtrainer/config/model/base/__init__.py +18 -8
  41. nshtrainer/config/model/config/__init__.py +30 -12
  42. nshtrainer/config/model/mixins/logger/__init__.py +15 -7
  43. nshtrainer/config/nn/__init__.py +68 -23
  44. nshtrainer/config/nn/mlp/__init__.py +21 -9
  45. nshtrainer/config/nn/nonlinearity/__init__.py +118 -22
  46. nshtrainer/config/optimizer/__init__.py +21 -9
  47. nshtrainer/config/profiler/__init__.py +28 -11
  48. nshtrainer/config/profiler/_base/__init__.py +17 -7
  49. nshtrainer/config/profiler/advanced/__init__.py +24 -8
  50. nshtrainer/config/profiler/pytorch/__init__.py +24 -8
  51. nshtrainer/config/profiler/simple/__init__.py +22 -8
  52. nshtrainer/config/runner/__init__.py +15 -7
  53. nshtrainer/config/trainer/_config/__init__.py +144 -30
  54. nshtrainer/config/trainer/checkpoint_connector/__init__.py +19 -7
  55. nshtrainer/config/util/_environment_info/__init__.py +87 -17
  56. nshtrainer/config/util/config/__init__.py +25 -10
  57. nshtrainer/config/util/config/dtype/__init__.py +15 -7
  58. nshtrainer/config/util/config/duration/__init__.py +27 -9
  59. {nshtrainer-0.41.1.dist-info → nshtrainer-0.42.0.dist-info}/METADATA +1 -1
  60. {nshtrainer-0.41.1.dist-info → nshtrainer-0.42.0.dist-info}/RECORD +61 -61
  61. {nshtrainer-0.41.1.dist-info → nshtrainer-0.42.0.dist-info}/WHEEL +0 -0
@@ -1,101 +1,412 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer import MetricConfig as MetricConfig
9
- from nshtrainer import BaseConfig as BaseConfig
10
- from nshtrainer._hf_hub import HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig
11
- from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
12
- from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
13
- from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
14
- from nshtrainer.optimizer import AdamWConfig as AdamWConfig
15
- from nshtrainer.model import DirectoryConfig as DirectoryConfig
16
- from nshtrainer.callbacks import DirectorySetupConfig as DirectorySetupConfig
17
- from nshtrainer.model import TrainerConfig as TrainerConfig
18
- from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
19
- from nshtrainer.nn import MLPConfig as MLPConfig
20
- from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
21
- from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
22
- from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
23
- from nshtrainer.nn import PReLUConfig as PReLUConfig
24
- from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
25
- from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
26
- from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
27
- from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
28
- from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
29
- from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
30
- from nshtrainer.nn.nonlinearity import SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig
31
- from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
32
- from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
33
- from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
34
- from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
35
- from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
36
- from nshtrainer.lr_scheduler import LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig
37
- from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
38
- from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
39
- from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
40
- from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
41
- from nshtrainer.callbacks import WandbUploadCodeConfig as WandbUploadCodeConfig
42
- from nshtrainer.callbacks import WandbWatchConfig as WandbWatchConfig
43
- from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
44
- from nshtrainer.util._environment_info import EnvironmentSnapshotConfig as EnvironmentSnapshotConfig
45
- from nshtrainer.util._environment_info import EnvironmentHardwareConfig as EnvironmentHardwareConfig
46
- from nshtrainer.util._environment_info import EnvironmentPackageConfig as EnvironmentPackageConfig
47
- from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
48
- from nshtrainer.util._environment_info import EnvironmentGPUConfig as EnvironmentGPUConfig
49
- from nshtrainer.util._environment_info import EnvironmentCUDAConfig as EnvironmentCUDAConfig
50
- from nshtrainer.util._environment_info import EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig
51
- from nshtrainer.util._environment_info import EnvironmentClassInformationConfig as EnvironmentClassInformationConfig
52
- from nshtrainer.util._environment_info import EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig
53
- from nshtrainer.util._environment_info import EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig
54
- from nshtrainer.util.config import EpochsConfig as EpochsConfig
55
- from nshtrainer.util.config import StepsConfig as StepsConfig
56
- from nshtrainer.util.config import DTypeConfig as DTypeConfig
57
- from nshtrainer.trainer._config import CheckpointLoadingConfig as CheckpointLoadingConfig
58
- from nshtrainer.callbacks import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
59
- from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
60
- from nshtrainer.callbacks import SharedParametersConfig as SharedParametersConfig
61
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
62
- from nshtrainer.trainer._config import ReproducibilityConfig as ReproducibilityConfig
63
- from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
64
- from nshtrainer.callbacks import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
65
- from nshtrainer.callbacks import OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig
66
- from nshtrainer.callbacks import RLPSanityChecksConfig as RLPSanityChecksConfig
67
- from nshtrainer.callbacks import EarlyStoppingConfig as EarlyStoppingConfig
68
- from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
69
- from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
70
- from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
71
- from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
72
- from nshtrainer._checkpoint.loader import UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig
73
- from nshtrainer._checkpoint.loader import BestCheckpointStrategyConfig as BestCheckpointStrategyConfig
74
- from nshtrainer._checkpoint.loader import LastCheckpointStrategyConfig as LastCheckpointStrategyConfig
75
- from nshtrainer.callbacks import PrintTableMetricsConfig as PrintTableMetricsConfig
76
- from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
77
- from nshtrainer.callbacks import GradientSkippingConfig as GradientSkippingConfig
78
- from nshtrainer.callbacks import EMAConfig as EMAConfig
79
- from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
80
- from nshtrainer.callbacks import FiniteChecksConfig as FiniteChecksConfig
81
- from nshtrainer.callbacks import NormLoggingConfig as NormLoggingConfig
82
- from nshtrainer.callbacks import EpochTimerConfig as EpochTimerConfig
83
- from nshtrainer.callbacks.checkpoint._base import BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig
84
- from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
85
- from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
86
- from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
87
- from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer import BaseConfig as BaseConfig
9
+ from nshtrainer import MetricConfig as MetricConfig
10
+ from nshtrainer._checkpoint.loader import (
11
+ BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
12
+ )
13
+ from nshtrainer._checkpoint.loader import (
14
+ CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
15
+ )
16
+ from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
17
+ from nshtrainer._checkpoint.loader import (
18
+ LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
19
+ )
20
+ from nshtrainer._checkpoint.loader import (
21
+ UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
22
+ )
23
+ from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
24
+ from nshtrainer._hf_hub import (
25
+ HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
26
+ )
27
+ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
28
+ from nshtrainer.callbacks import (
29
+ BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
30
+ )
31
+ from nshtrainer.callbacks import CallbackConfig as CallbackConfig
32
+ from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
33
+ from nshtrainer.callbacks import DirectorySetupConfig as DirectorySetupConfig
34
+ from nshtrainer.callbacks import EarlyStoppingConfig as EarlyStoppingConfig
35
+ from nshtrainer.callbacks import EMAConfig as EMAConfig
36
+ from nshtrainer.callbacks import EpochTimerConfig as EpochTimerConfig
37
+ from nshtrainer.callbacks import FiniteChecksConfig as FiniteChecksConfig
38
+ from nshtrainer.callbacks import GradientSkippingConfig as GradientSkippingConfig
39
+ from nshtrainer.callbacks import (
40
+ LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
41
+ )
42
+ from nshtrainer.callbacks import NormLoggingConfig as NormLoggingConfig
43
+ from nshtrainer.callbacks import (
44
+ OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
45
+ )
46
+ from nshtrainer.callbacks import PrintTableMetricsConfig as PrintTableMetricsConfig
47
+ from nshtrainer.callbacks import RLPSanityChecksConfig as RLPSanityChecksConfig
48
+ from nshtrainer.callbacks import SharedParametersConfig as SharedParametersConfig
49
+ from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
50
+ from nshtrainer.callbacks import WandbUploadCodeConfig as WandbUploadCodeConfig
51
+ from nshtrainer.callbacks import WandbWatchConfig as WandbWatchConfig
52
+ from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
53
+ from nshtrainer.callbacks.checkpoint._base import (
54
+ BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
55
+ )
56
+ from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
57
+ from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
58
+ from nshtrainer.loggers import LoggerConfig as LoggerConfig
59
+ from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
60
+ from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
61
+ from nshtrainer.lr_scheduler import (
62
+ LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
63
+ )
64
+ from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
65
+ from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
66
+ from nshtrainer.lr_scheduler import (
67
+ ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
68
+ )
69
+ from nshtrainer.model import DirectoryConfig as DirectoryConfig
70
+ from nshtrainer.model import TrainerConfig as TrainerConfig
71
+ from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
72
+ from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
73
+ from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
74
+ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
75
+ from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
76
+ from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
77
+ from nshtrainer.nn import MLPConfig as MLPConfig
78
+ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
79
+ from nshtrainer.nn import PReLUConfig as PReLUConfig
80
+ from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
81
+ from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
82
+ from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
83
+ from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
84
+ from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
85
+ from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
86
+ from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
87
+ from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
88
+ from nshtrainer.nn.nonlinearity import (
89
+ SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
90
+ )
91
+ from nshtrainer.optimizer import AdamWConfig as AdamWConfig
92
+ from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
93
+ from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
94
+ from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
95
+ from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
96
+ from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
97
+ from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
98
+ from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
99
+ from nshtrainer.trainer._config import (
100
+ CheckpointCallbackConfig as CheckpointCallbackConfig,
101
+ )
102
+ from nshtrainer.trainer._config import (
103
+ CheckpointLoadingConfig as CheckpointLoadingConfig,
104
+ )
105
+ from nshtrainer.trainer._config import (
106
+ CheckpointSavingConfig as CheckpointSavingConfig,
107
+ )
108
+ from nshtrainer.trainer._config import (
109
+ GradientClippingConfig as GradientClippingConfig,
110
+ )
111
+ from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
112
+ from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
113
+ from nshtrainer.trainer._config import (
114
+ ReproducibilityConfig as ReproducibilityConfig,
115
+ )
116
+ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
117
+ from nshtrainer.util._environment_info import (
118
+ EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
119
+ )
120
+ from nshtrainer.util._environment_info import (
121
+ EnvironmentCUDAConfig as EnvironmentCUDAConfig,
122
+ )
123
+ from nshtrainer.util._environment_info import (
124
+ EnvironmentGPUConfig as EnvironmentGPUConfig,
125
+ )
126
+ from nshtrainer.util._environment_info import (
127
+ EnvironmentHardwareConfig as EnvironmentHardwareConfig,
128
+ )
129
+ from nshtrainer.util._environment_info import (
130
+ EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
131
+ )
132
+ from nshtrainer.util._environment_info import (
133
+ EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
134
+ )
135
+ from nshtrainer.util._environment_info import (
136
+ EnvironmentPackageConfig as EnvironmentPackageConfig,
137
+ )
138
+ from nshtrainer.util._environment_info import (
139
+ EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
140
+ )
141
+ from nshtrainer.util._environment_info import (
142
+ EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
143
+ )
144
+ from nshtrainer.util._environment_info import (
145
+ GitRepositoryConfig as GitRepositoryConfig,
146
+ )
147
+ from nshtrainer.util.config import DTypeConfig as DTypeConfig
148
+ from nshtrainer.util.config import DurationConfig as DurationConfig
149
+ from nshtrainer.util.config import EpochsConfig as EpochsConfig
150
+ from nshtrainer.util.config import StepsConfig as StepsConfig
151
+ else:
152
+
153
+ def __getattr__(name):
154
+ import importlib
155
+
156
+ if name in globals():
157
+ return globals()[name]
158
+ if name == "BaseConfig":
159
+ return importlib.import_module("nshtrainer").BaseConfig
160
+ if name == "MetricConfig":
161
+ return importlib.import_module("nshtrainer").MetricConfig
162
+ if name == "CallbackConfigBase":
163
+ return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
164
+ if name == "HuggingFaceHubConfig":
165
+ return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
166
+ if name == "HuggingFaceHubAutoCreateConfig":
167
+ return importlib.import_module(
168
+ "nshtrainer._hf_hub"
169
+ ).HuggingFaceHubAutoCreateConfig
170
+ if name == "OptimizerConfigBase":
171
+ return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
172
+ if name == "AdamWConfig":
173
+ return importlib.import_module("nshtrainer.optimizer").AdamWConfig
174
+ if name == "DirectoryConfig":
175
+ return importlib.import_module("nshtrainer.model").DirectoryConfig
176
+ if name == "DirectorySetupConfig":
177
+ return importlib.import_module("nshtrainer.callbacks").DirectorySetupConfig
178
+ if name == "TrainerConfig":
179
+ return importlib.import_module("nshtrainer.model").TrainerConfig
180
+ if name == "EnvironmentConfig":
181
+ return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
182
+ if name == "BaseNonlinearityConfig":
183
+ return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
184
+ if name == "MLPConfig":
185
+ return importlib.import_module("nshtrainer.nn").MLPConfig
186
+ if name == "SwiGLUNonlinearityConfig":
187
+ return importlib.import_module(
188
+ "nshtrainer.nn.nonlinearity"
189
+ ).SwiGLUNonlinearityConfig
190
+ if name == "ReLUNonlinearityConfig":
191
+ return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
192
+ if name == "SiLUNonlinearityConfig":
193
+ return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
194
+ if name == "ELUNonlinearityConfig":
195
+ return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
196
+ if name == "GELUNonlinearityConfig":
197
+ return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
198
+ if name == "SoftplusNonlinearityConfig":
199
+ return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
200
+ if name == "SoftsignNonlinearityConfig":
201
+ return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
202
+ if name == "SwishNonlinearityConfig":
203
+ return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
204
+ if name == "SoftmaxNonlinearityConfig":
205
+ return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
206
+ if name == "MishNonlinearityConfig":
207
+ return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
208
+ if name == "SigmoidNonlinearityConfig":
209
+ return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
210
+ if name == "TanhNonlinearityConfig":
211
+ return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
212
+ if name == "PReLUConfig":
213
+ return importlib.import_module("nshtrainer.nn").PReLUConfig
214
+ if name == "LeakyReLUNonlinearityConfig":
215
+ return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
216
+ if name == "LRSchedulerConfigBase":
217
+ return importlib.import_module(
218
+ "nshtrainer.lr_scheduler"
219
+ ).LRSchedulerConfigBase
220
+ if name == "LinearWarmupCosineDecayLRSchedulerConfig":
221
+ return importlib.import_module(
222
+ "nshtrainer.lr_scheduler"
223
+ ).LinearWarmupCosineDecayLRSchedulerConfig
224
+ if name == "ReduceLROnPlateauConfig":
225
+ return importlib.import_module(
226
+ "nshtrainer.lr_scheduler"
227
+ ).ReduceLROnPlateauConfig
228
+ if name == "BaseLoggerConfig":
229
+ return importlib.import_module("nshtrainer.loggers").BaseLoggerConfig
230
+ if name == "TensorboardLoggerConfig":
231
+ return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
232
+ if name == "WandbLoggerConfig":
233
+ return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
234
+ if name == "WandbUploadCodeConfig":
235
+ return importlib.import_module("nshtrainer.callbacks").WandbUploadCodeConfig
236
+ if name == "WandbWatchConfig":
237
+ return importlib.import_module("nshtrainer.callbacks").WandbWatchConfig
238
+ if name == "CSVLoggerConfig":
239
+ return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
240
+ if name == "EnvironmentPackageConfig":
241
+ return importlib.import_module(
242
+ "nshtrainer.util._environment_info"
243
+ ).EnvironmentPackageConfig
244
+ if name == "EnvironmentSnapshotConfig":
245
+ return importlib.import_module(
246
+ "nshtrainer.util._environment_info"
247
+ ).EnvironmentSnapshotConfig
248
+ if name == "EnvironmentLSFInformationConfig":
249
+ return importlib.import_module(
250
+ "nshtrainer.util._environment_info"
251
+ ).EnvironmentLSFInformationConfig
252
+ if name == "EnvironmentLinuxEnvironmentConfig":
253
+ return importlib.import_module(
254
+ "nshtrainer.util._environment_info"
255
+ ).EnvironmentLinuxEnvironmentConfig
256
+ if name == "EnvironmentSLURMInformationConfig":
257
+ return importlib.import_module(
258
+ "nshtrainer.util._environment_info"
259
+ ).EnvironmentSLURMInformationConfig
260
+ if name == "EnvironmentClassInformationConfig":
261
+ return importlib.import_module(
262
+ "nshtrainer.util._environment_info"
263
+ ).EnvironmentClassInformationConfig
264
+ if name == "GitRepositoryConfig":
265
+ return importlib.import_module(
266
+ "nshtrainer.util._environment_info"
267
+ ).GitRepositoryConfig
268
+ if name == "EnvironmentCUDAConfig":
269
+ return importlib.import_module(
270
+ "nshtrainer.util._environment_info"
271
+ ).EnvironmentCUDAConfig
272
+ if name == "EnvironmentGPUConfig":
273
+ return importlib.import_module(
274
+ "nshtrainer.util._environment_info"
275
+ ).EnvironmentGPUConfig
276
+ if name == "EnvironmentHardwareConfig":
277
+ return importlib.import_module(
278
+ "nshtrainer.util._environment_info"
279
+ ).EnvironmentHardwareConfig
280
+ if name == "EpochsConfig":
281
+ return importlib.import_module("nshtrainer.util.config").EpochsConfig
282
+ if name == "StepsConfig":
283
+ return importlib.import_module("nshtrainer.util.config").StepsConfig
284
+ if name == "DTypeConfig":
285
+ return importlib.import_module("nshtrainer.util.config").DTypeConfig
286
+ if name == "CheckpointLoadingConfig":
287
+ return importlib.import_module(
288
+ "nshtrainer.trainer._config"
289
+ ).CheckpointLoadingConfig
290
+ if name == "OptimizationConfig":
291
+ return importlib.import_module(
292
+ "nshtrainer.trainer._config"
293
+ ).OptimizationConfig
294
+ if name == "GradientClippingConfig":
295
+ return importlib.import_module(
296
+ "nshtrainer.trainer._config"
297
+ ).GradientClippingConfig
298
+ if name == "LastCheckpointCallbackConfig":
299
+ return importlib.import_module(
300
+ "nshtrainer.callbacks"
301
+ ).LastCheckpointCallbackConfig
302
+ if name == "OnExceptionCheckpointCallbackConfig":
303
+ return importlib.import_module(
304
+ "nshtrainer.callbacks"
305
+ ).OnExceptionCheckpointCallbackConfig
306
+ if name == "RLPSanityChecksConfig":
307
+ return importlib.import_module("nshtrainer.callbacks").RLPSanityChecksConfig
308
+ if name == "EarlyStoppingConfig":
309
+ return importlib.import_module("nshtrainer.callbacks").EarlyStoppingConfig
310
+ if name == "DebugFlagCallbackConfig":
311
+ return importlib.import_module(
312
+ "nshtrainer.callbacks"
313
+ ).DebugFlagCallbackConfig
314
+ if name == "CheckpointSavingConfig":
315
+ return importlib.import_module(
316
+ "nshtrainer.trainer._config"
317
+ ).CheckpointSavingConfig
318
+ if name == "BestCheckpointCallbackConfig":
319
+ return importlib.import_module(
320
+ "nshtrainer.callbacks"
321
+ ).BestCheckpointCallbackConfig
322
+ if name == "LoggingConfig":
323
+ return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
324
+ if name == "SanityCheckingConfig":
325
+ return importlib.import_module(
326
+ "nshtrainer.trainer._config"
327
+ ).SanityCheckingConfig
328
+ if name == "SharedParametersConfig":
329
+ return importlib.import_module(
330
+ "nshtrainer.callbacks"
331
+ ).SharedParametersConfig
332
+ if name == "ReproducibilityConfig":
333
+ return importlib.import_module(
334
+ "nshtrainer.trainer._config"
335
+ ).ReproducibilityConfig
336
+ if name == "CheckpointMetadata":
337
+ return importlib.import_module(
338
+ "nshtrainer._checkpoint.loader"
339
+ ).CheckpointMetadata
340
+ if name == "BestCheckpointStrategyConfig":
341
+ return importlib.import_module(
342
+ "nshtrainer._checkpoint.loader"
343
+ ).BestCheckpointStrategyConfig
344
+ if name == "LastCheckpointStrategyConfig":
345
+ return importlib.import_module(
346
+ "nshtrainer._checkpoint.loader"
347
+ ).LastCheckpointStrategyConfig
348
+ if name == "UserProvidedPathCheckpointStrategyConfig":
349
+ return importlib.import_module(
350
+ "nshtrainer._checkpoint.loader"
351
+ ).UserProvidedPathCheckpointStrategyConfig
352
+ if name == "PrintTableMetricsConfig":
353
+ return importlib.import_module(
354
+ "nshtrainer.callbacks"
355
+ ).PrintTableMetricsConfig
356
+ if name == "ThroughputMonitorConfig":
357
+ return importlib.import_module(
358
+ "nshtrainer.callbacks"
359
+ ).ThroughputMonitorConfig
360
+ if name == "GradientSkippingConfig":
361
+ return importlib.import_module(
362
+ "nshtrainer.callbacks"
363
+ ).GradientSkippingConfig
364
+ if name == "EMAConfig":
365
+ return importlib.import_module("nshtrainer.callbacks").EMAConfig
366
+ if name == "ActSaveConfig":
367
+ return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
368
+ if name == "FiniteChecksConfig":
369
+ return importlib.import_module("nshtrainer.callbacks").FiniteChecksConfig
370
+ if name == "NormLoggingConfig":
371
+ return importlib.import_module("nshtrainer.callbacks").NormLoggingConfig
372
+ if name == "EpochTimerConfig":
373
+ return importlib.import_module("nshtrainer.callbacks").EpochTimerConfig
374
+ if name == "BaseCheckpointCallbackConfig":
375
+ return importlib.import_module(
376
+ "nshtrainer.callbacks.checkpoint._base"
377
+ ).BaseCheckpointCallbackConfig
378
+ if name == "BaseProfilerConfig":
379
+ return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
380
+ if name == "PyTorchProfilerConfig":
381
+ return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
382
+ if name == "AdvancedProfilerConfig":
383
+ return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
384
+ if name == "SimpleProfilerConfig":
385
+ return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
386
+ if name == "OptimizerConfig":
387
+ return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
388
+ if name == "LoggerConfig":
389
+ return importlib.import_module("nshtrainer.loggers").LoggerConfig
390
+ if name == "NonlinearityConfig":
391
+ return importlib.import_module("nshtrainer.nn").NonlinearityConfig
392
+ if name == "DurationConfig":
393
+ return importlib.import_module("nshtrainer.util.config").DurationConfig
394
+ if name == "LRSchedulerConfig":
395
+ return importlib.import_module("nshtrainer.lr_scheduler").LRSchedulerConfig
396
+ if name == "CallbackConfig":
397
+ return importlib.import_module("nshtrainer.callbacks").CallbackConfig
398
+ if name == "CheckpointCallbackConfig":
399
+ return importlib.import_module(
400
+ "nshtrainer.trainer._config"
401
+ ).CheckpointCallbackConfig
402
+ if name == "ProfilerConfig":
403
+ return importlib.import_module("nshtrainer.profiler").ProfilerConfig
404
+ if name == "CheckpointLoadingStrategyConfig":
405
+ return importlib.import_module(
406
+ "nshtrainer._checkpoint.loader"
407
+ ).CheckpointLoadingStrategyConfig
408
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
88
409
 
89
- # Type aliases
90
- from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
91
- from nshtrainer.loggers import LoggerConfig as LoggerConfig
92
- from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
93
- from nshtrainer.util.config import DurationConfig as DurationConfig
94
- from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
95
- from nshtrainer.callbacks import CallbackConfig as CallbackConfig
96
- from nshtrainer.trainer._config import CheckpointCallbackConfig as CheckpointCallbackConfig
97
- from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
98
- from nshtrainer._checkpoint.loader import CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig
99
410
 
100
411
  # Submodule exports
101
412
  from . import _checkpoint as _checkpoint
@@ -1,18 +1,60 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer._checkpoint.loader import UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig
9
- from nshtrainer._checkpoint.loader import MetricConfig as MetricConfig
10
- from nshtrainer._checkpoint.loader import CheckpointLoadingConfig as CheckpointLoadingConfig
11
- from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
12
- from nshtrainer._checkpoint.loader import BestCheckpointStrategyConfig as BestCheckpointStrategyConfig
13
- from nshtrainer._checkpoint.loader import LastCheckpointStrategyConfig as LastCheckpointStrategyConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer._checkpoint.loader import (
9
+ BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
10
+ )
11
+ from nshtrainer._checkpoint.loader import (
12
+ CheckpointLoadingConfig as CheckpointLoadingConfig,
13
+ )
14
+ from nshtrainer._checkpoint.loader import (
15
+ CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
16
+ )
17
+ from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
18
+ from nshtrainer._checkpoint.loader import (
19
+ LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
20
+ )
21
+ from nshtrainer._checkpoint.loader import MetricConfig as MetricConfig
22
+ from nshtrainer._checkpoint.loader import (
23
+ UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
24
+ )
25
+ else:
26
+
27
+ def __getattr__(name):
28
+ import importlib
14
29
 
15
- # Type aliases
16
- from nshtrainer._checkpoint.loader import CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig
30
+ if name in globals():
31
+ return globals()[name]
32
+ if name == "CheckpointLoadingConfig":
33
+ return importlib.import_module(
34
+ "nshtrainer._checkpoint.loader"
35
+ ).CheckpointLoadingConfig
36
+ if name == "CheckpointMetadata":
37
+ return importlib.import_module(
38
+ "nshtrainer._checkpoint.loader"
39
+ ).CheckpointMetadata
40
+ if name == "MetricConfig":
41
+ return importlib.import_module("nshtrainer._checkpoint.loader").MetricConfig
42
+ if name == "BestCheckpointStrategyConfig":
43
+ return importlib.import_module(
44
+ "nshtrainer._checkpoint.loader"
45
+ ).BestCheckpointStrategyConfig
46
+ if name == "LastCheckpointStrategyConfig":
47
+ return importlib.import_module(
48
+ "nshtrainer._checkpoint.loader"
49
+ ).LastCheckpointStrategyConfig
50
+ if name == "UserProvidedPathCheckpointStrategyConfig":
51
+ return importlib.import_module(
52
+ "nshtrainer._checkpoint.loader"
53
+ ).UserProvidedPathCheckpointStrategyConfig
54
+ if name == "CheckpointLoadingStrategyConfig":
55
+ return importlib.import_module(
56
+ "nshtrainer._checkpoint.loader"
57
+ ).CheckpointLoadingStrategyConfig
58
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
17
59
 
18
60
  # Submodule exports
@@ -1,13 +1,27 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
9
- from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
9
+ from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
10
+ else:
11
+
12
+ def __getattr__(name):
13
+ import importlib
10
14
 
11
- # Type aliases
15
+ if name in globals():
16
+ return globals()[name]
17
+ if name == "EnvironmentConfig":
18
+ return importlib.import_module(
19
+ "nshtrainer._checkpoint.metadata"
20
+ ).EnvironmentConfig
21
+ if name == "CheckpointMetadata":
22
+ return importlib.import_module(
23
+ "nshtrainer._checkpoint.metadata"
24
+ ).CheckpointMetadata
25
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
26
 
13
27
  # Submodule exports