nshtrainer 0.41.1__py3-none-any.whl → 0.43.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. nshtrainer/__init__.py +2 -0
  2. nshtrainer/_callback.py +2 -0
  3. nshtrainer/_checkpoint/loader.py +2 -0
  4. nshtrainer/_checkpoint/metadata.py +2 -0
  5. nshtrainer/_checkpoint/saver.py +2 -0
  6. nshtrainer/_directory.py +4 -2
  7. nshtrainer/_experimental/__init__.py +2 -0
  8. nshtrainer/_hf_hub.py +2 -0
  9. nshtrainer/callbacks/__init__.py +45 -29
  10. nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
  11. nshtrainer/callbacks/actsave.py +2 -0
  12. nshtrainer/callbacks/base.py +2 -0
  13. nshtrainer/callbacks/checkpoint/__init__.py +6 -2
  14. nshtrainer/callbacks/checkpoint/_base.py +2 -0
  15. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  16. nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
  17. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
  18. nshtrainer/callbacks/debug_flag.py +2 -0
  19. nshtrainer/callbacks/directory_setup.py +4 -2
  20. nshtrainer/callbacks/early_stopping.py +6 -4
  21. nshtrainer/callbacks/ema.py +5 -3
  22. nshtrainer/callbacks/finite_checks.py +3 -1
  23. nshtrainer/callbacks/gradient_skipping.py +6 -4
  24. nshtrainer/callbacks/interval.py +2 -0
  25. nshtrainer/callbacks/log_epoch.py +13 -1
  26. nshtrainer/callbacks/norm_logging.py +4 -2
  27. nshtrainer/callbacks/print_table.py +3 -1
  28. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  29. nshtrainer/callbacks/shared_parameters.py +4 -2
  30. nshtrainer/callbacks/throughput_monitor.py +2 -0
  31. nshtrainer/callbacks/timer.py +5 -3
  32. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  33. nshtrainer/callbacks/wandb_watch.py +4 -2
  34. nshtrainer/config/__init__.py +445 -94
  35. nshtrainer/config/_checkpoint/loader/__init__.py +56 -12
  36. nshtrainer/config/_checkpoint/metadata/__init__.py +23 -7
  37. nshtrainer/config/_directory/__init__.py +26 -8
  38. nshtrainer/config/_hf_hub/__init__.py +26 -8
  39. nshtrainer/config/callbacks/__init__.py +154 -29
  40. nshtrainer/config/callbacks/actsave/__init__.py +21 -7
  41. nshtrainer/config/callbacks/base/__init__.py +18 -6
  42. nshtrainer/config/callbacks/checkpoint/__init__.py +63 -12
  43. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +34 -8
  44. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +41 -9
  45. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +34 -8
  46. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +27 -7
  47. nshtrainer/config/callbacks/debug_flag/__init__.py +25 -7
  48. nshtrainer/config/callbacks/directory_setup/__init__.py +27 -7
  49. nshtrainer/config/callbacks/early_stopping/__init__.py +32 -8
  50. nshtrainer/config/callbacks/ema/__init__.py +21 -7
  51. nshtrainer/config/callbacks/finite_checks/__init__.py +27 -7
  52. nshtrainer/config/callbacks/gradient_skipping/__init__.py +27 -7
  53. nshtrainer/config/callbacks/norm_logging/__init__.py +27 -7
  54. nshtrainer/config/callbacks/print_table/__init__.py +27 -7
  55. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +27 -7
  56. nshtrainer/config/callbacks/shared_parameters/__init__.py +27 -7
  57. nshtrainer/config/callbacks/throughput_monitor/__init__.py +27 -7
  58. nshtrainer/config/callbacks/timer/__init__.py +25 -7
  59. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +27 -7
  60. nshtrainer/config/callbacks/wandb_watch/__init__.py +27 -7
  61. nshtrainer/config/loggers/__init__.py +49 -14
  62. nshtrainer/config/loggers/_base/__init__.py +16 -6
  63. nshtrainer/config/loggers/csv/__init__.py +19 -7
  64. nshtrainer/config/loggers/tensorboard/__init__.py +25 -7
  65. nshtrainer/config/loggers/wandb/__init__.py +38 -10
  66. nshtrainer/config/lr_scheduler/__init__.py +50 -11
  67. nshtrainer/config/lr_scheduler/_base/__init__.py +20 -6
  68. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +34 -8
  69. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +34 -8
  70. nshtrainer/config/metrics/__init__.py +17 -6
  71. nshtrainer/config/metrics/_config/__init__.py +16 -6
  72. nshtrainer/config/model/__init__.py +32 -11
  73. nshtrainer/config/model/base/__init__.py +19 -7
  74. nshtrainer/config/model/config/__init__.py +31 -11
  75. nshtrainer/config/model/mixins/logger/__init__.py +16 -6
  76. nshtrainer/config/nn/__init__.py +70 -23
  77. nshtrainer/config/nn/mlp/__init__.py +22 -8
  78. nshtrainer/config/nn/nonlinearity/__init__.py +119 -21
  79. nshtrainer/config/optimizer/__init__.py +22 -8
  80. nshtrainer/config/profiler/__init__.py +29 -10
  81. nshtrainer/config/profiler/_base/__init__.py +18 -6
  82. nshtrainer/config/profiler/advanced/__init__.py +25 -7
  83. nshtrainer/config/profiler/pytorch/__init__.py +25 -7
  84. nshtrainer/config/profiler/simple/__init__.py +23 -7
  85. nshtrainer/config/runner/__init__.py +16 -6
  86. nshtrainer/config/trainer/_config/__init__.py +147 -29
  87. nshtrainer/config/trainer/checkpoint_connector/__init__.py +20 -6
  88. nshtrainer/config/util/_environment_info/__init__.py +88 -16
  89. nshtrainer/config/util/config/__init__.py +26 -9
  90. nshtrainer/config/util/config/dtype/__init__.py +16 -6
  91. nshtrainer/config/util/config/duration/__init__.py +28 -8
  92. nshtrainer/data/__init__.py +2 -0
  93. nshtrainer/data/balanced_batch_sampler.py +2 -0
  94. nshtrainer/data/datamodule.py +2 -0
  95. nshtrainer/data/transform.py +2 -0
  96. nshtrainer/ll/__init__.py +2 -0
  97. nshtrainer/ll/_experimental.py +2 -0
  98. nshtrainer/ll/actsave.py +2 -0
  99. nshtrainer/ll/callbacks.py +2 -0
  100. nshtrainer/ll/config.py +2 -0
  101. nshtrainer/ll/data.py +2 -0
  102. nshtrainer/ll/log.py +2 -0
  103. nshtrainer/ll/lr_scheduler.py +2 -0
  104. nshtrainer/ll/model.py +2 -0
  105. nshtrainer/ll/nn.py +2 -0
  106. nshtrainer/ll/optimizer.py +2 -0
  107. nshtrainer/ll/runner.py +2 -0
  108. nshtrainer/ll/snapshot.py +2 -0
  109. nshtrainer/ll/snoop.py +2 -0
  110. nshtrainer/ll/trainer.py +2 -0
  111. nshtrainer/ll/typecheck.py +2 -0
  112. nshtrainer/ll/util.py +2 -0
  113. nshtrainer/loggers/__init__.py +2 -0
  114. nshtrainer/loggers/_base.py +2 -0
  115. nshtrainer/loggers/csv.py +2 -0
  116. nshtrainer/loggers/tensorboard.py +2 -0
  117. nshtrainer/loggers/wandb.py +6 -4
  118. nshtrainer/lr_scheduler/__init__.py +2 -0
  119. nshtrainer/lr_scheduler/_base.py +2 -0
  120. nshtrainer/lr_scheduler/linear_warmup_cosine.py +2 -0
  121. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +2 -0
  122. nshtrainer/metrics/__init__.py +2 -0
  123. nshtrainer/metrics/_config.py +2 -0
  124. nshtrainer/model/__init__.py +2 -0
  125. nshtrainer/model/base.py +2 -0
  126. nshtrainer/model/config.py +2 -0
  127. nshtrainer/model/mixins/callback.py +2 -0
  128. nshtrainer/model/mixins/logger.py +2 -0
  129. nshtrainer/nn/__init__.py +2 -0
  130. nshtrainer/nn/mlp.py +2 -0
  131. nshtrainer/nn/module_dict.py +2 -0
  132. nshtrainer/nn/module_list.py +2 -0
  133. nshtrainer/nn/nonlinearity.py +2 -0
  134. nshtrainer/optimizer.py +2 -0
  135. nshtrainer/profiler/__init__.py +2 -0
  136. nshtrainer/profiler/_base.py +2 -0
  137. nshtrainer/profiler/advanced.py +2 -0
  138. nshtrainer/profiler/pytorch.py +2 -0
  139. nshtrainer/profiler/simple.py +2 -0
  140. nshtrainer/runner.py +2 -0
  141. nshtrainer/scripts/find_packages.py +2 -0
  142. nshtrainer/trainer/__init__.py +2 -0
  143. nshtrainer/trainer/_config.py +16 -13
  144. nshtrainer/trainer/_runtime_callback.py +2 -0
  145. nshtrainer/trainer/checkpoint_connector.py +2 -0
  146. nshtrainer/trainer/signal_connector.py +2 -0
  147. nshtrainer/trainer/trainer.py +2 -0
  148. nshtrainer/util/_environment_info.py +2 -0
  149. nshtrainer/util/bf16.py +2 -0
  150. nshtrainer/util/config/__init__.py +2 -0
  151. nshtrainer/util/config/dtype.py +2 -0
  152. nshtrainer/util/config/duration.py +2 -0
  153. nshtrainer/util/environment.py +2 -0
  154. nshtrainer/util/path.py +2 -0
  155. nshtrainer/util/seed.py +2 -0
  156. nshtrainer/util/slurm.py +3 -0
  157. nshtrainer/util/typed.py +2 -0
  158. nshtrainer/util/typing_utils.py +2 -0
  159. {nshtrainer-0.41.1.dist-info → nshtrainer-0.43.0.dist-info}/METADATA +1 -1
  160. nshtrainer-0.43.0.dist-info/RECORD +162 -0
  161. nshtrainer-0.41.1.dist-info/RECORD +0 -162
  162. {nshtrainer-0.41.1.dist-info → nshtrainer-0.43.0.dist-info}/WHEEL +0 -0
@@ -1,101 +1,452 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer import MetricConfig as MetricConfig
9
- from nshtrainer import BaseConfig as BaseConfig
10
- from nshtrainer._hf_hub import HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig
11
- from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
12
- from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
13
- from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
14
- from nshtrainer.optimizer import AdamWConfig as AdamWConfig
15
- from nshtrainer.model import DirectoryConfig as DirectoryConfig
16
- from nshtrainer.callbacks import DirectorySetupConfig as DirectorySetupConfig
17
- from nshtrainer.model import TrainerConfig as TrainerConfig
18
- from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
19
- from nshtrainer.nn import MLPConfig as MLPConfig
20
- from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
21
- from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
22
- from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
23
- from nshtrainer.nn import PReLUConfig as PReLUConfig
24
- from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
25
- from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
26
- from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
27
- from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
28
- from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
29
- from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
30
- from nshtrainer.nn.nonlinearity import SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig
31
- from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
32
- from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
33
- from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
34
- from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
35
- from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
36
- from nshtrainer.lr_scheduler import LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig
37
- from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
38
- from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
39
- from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
40
- from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
41
- from nshtrainer.callbacks import WandbUploadCodeConfig as WandbUploadCodeConfig
42
- from nshtrainer.callbacks import WandbWatchConfig as WandbWatchConfig
43
- from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
44
- from nshtrainer.util._environment_info import EnvironmentSnapshotConfig as EnvironmentSnapshotConfig
45
- from nshtrainer.util._environment_info import EnvironmentHardwareConfig as EnvironmentHardwareConfig
46
- from nshtrainer.util._environment_info import EnvironmentPackageConfig as EnvironmentPackageConfig
47
- from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
48
- from nshtrainer.util._environment_info import EnvironmentGPUConfig as EnvironmentGPUConfig
49
- from nshtrainer.util._environment_info import EnvironmentCUDAConfig as EnvironmentCUDAConfig
50
- from nshtrainer.util._environment_info import EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig
51
- from nshtrainer.util._environment_info import EnvironmentClassInformationConfig as EnvironmentClassInformationConfig
52
- from nshtrainer.util._environment_info import EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig
53
- from nshtrainer.util._environment_info import EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig
54
- from nshtrainer.util.config import EpochsConfig as EpochsConfig
55
- from nshtrainer.util.config import StepsConfig as StepsConfig
56
- from nshtrainer.util.config import DTypeConfig as DTypeConfig
57
- from nshtrainer.trainer._config import CheckpointLoadingConfig as CheckpointLoadingConfig
58
- from nshtrainer.callbacks import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
59
- from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
60
- from nshtrainer.callbacks import SharedParametersConfig as SharedParametersConfig
61
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
62
- from nshtrainer.trainer._config import ReproducibilityConfig as ReproducibilityConfig
63
- from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
64
- from nshtrainer.callbacks import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
65
- from nshtrainer.callbacks import OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig
66
- from nshtrainer.callbacks import RLPSanityChecksConfig as RLPSanityChecksConfig
67
- from nshtrainer.callbacks import EarlyStoppingConfig as EarlyStoppingConfig
68
- from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
69
- from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
70
- from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
71
- from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
72
- from nshtrainer._checkpoint.loader import UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig
73
- from nshtrainer._checkpoint.loader import BestCheckpointStrategyConfig as BestCheckpointStrategyConfig
74
- from nshtrainer._checkpoint.loader import LastCheckpointStrategyConfig as LastCheckpointStrategyConfig
75
- from nshtrainer.callbacks import PrintTableMetricsConfig as PrintTableMetricsConfig
76
- from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
77
- from nshtrainer.callbacks import GradientSkippingConfig as GradientSkippingConfig
78
- from nshtrainer.callbacks import EMAConfig as EMAConfig
79
- from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
80
- from nshtrainer.callbacks import FiniteChecksConfig as FiniteChecksConfig
81
- from nshtrainer.callbacks import NormLoggingConfig as NormLoggingConfig
82
- from nshtrainer.callbacks import EpochTimerConfig as EpochTimerConfig
83
- from nshtrainer.callbacks.checkpoint._base import BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig
84
- from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
85
- from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
86
- from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
87
- from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer import BaseConfig as BaseConfig
11
+ from nshtrainer import MetricConfig as MetricConfig
12
+ from nshtrainer._checkpoint.loader import (
13
+ BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
14
+ )
15
+ from nshtrainer._checkpoint.loader import (
16
+ CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
17
+ )
18
+ from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
19
+ from nshtrainer._checkpoint.loader import (
20
+ LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
21
+ )
22
+ from nshtrainer._checkpoint.loader import (
23
+ UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
24
+ )
25
+ from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
26
+ from nshtrainer._hf_hub import (
27
+ HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
28
+ )
29
+ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
30
+ from nshtrainer.callbacks import (
31
+ BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
32
+ )
33
+ from nshtrainer.callbacks import CallbackConfig as CallbackConfig
34
+ from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
35
+ from nshtrainer.callbacks import (
36
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
37
+ )
38
+ from nshtrainer.callbacks import (
39
+ EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
40
+ )
41
+ from nshtrainer.callbacks import EMACallbackConfig as EMACallbackConfig
42
+ from nshtrainer.callbacks import (
43
+ EpochTimerCallbackConfig as EpochTimerCallbackConfig,
44
+ )
45
+ from nshtrainer.callbacks import (
46
+ FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
47
+ )
48
+ from nshtrainer.callbacks import (
49
+ GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
50
+ )
51
+ from nshtrainer.callbacks import (
52
+ LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
53
+ )
54
+ from nshtrainer.callbacks import (
55
+ NormLoggingCallbackConfig as NormLoggingCallbackConfig,
56
+ )
57
+ from nshtrainer.callbacks import (
58
+ OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
59
+ )
60
+ from nshtrainer.callbacks import (
61
+ PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
62
+ )
63
+ from nshtrainer.callbacks import (
64
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
65
+ )
66
+ from nshtrainer.callbacks import (
67
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
68
+ )
69
+ from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
70
+ from nshtrainer.callbacks import (
71
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
72
+ )
73
+ from nshtrainer.callbacks import (
74
+ WandbWatchCallbackConfig as WandbWatchCallbackConfig,
75
+ )
76
+ from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
77
+ from nshtrainer.callbacks.checkpoint._base import (
78
+ BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
79
+ )
80
+ from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
81
+ from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
82
+ from nshtrainer.loggers import LoggerConfig as LoggerConfig
83
+ from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
84
+ from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
85
+ from nshtrainer.lr_scheduler import (
86
+ LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
87
+ )
88
+ from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
89
+ from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
90
+ from nshtrainer.lr_scheduler import (
91
+ ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
92
+ )
93
+ from nshtrainer.model import DirectoryConfig as DirectoryConfig
94
+ from nshtrainer.model import TrainerConfig as TrainerConfig
95
+ from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
96
+ from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
97
+ from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
98
+ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
99
+ from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
100
+ from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
101
+ from nshtrainer.nn import MLPConfig as MLPConfig
102
+ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
103
+ from nshtrainer.nn import PReLUConfig as PReLUConfig
104
+ from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
105
+ from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
106
+ from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
107
+ from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
108
+ from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
109
+ from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
110
+ from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
111
+ from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
112
+ from nshtrainer.nn.nonlinearity import (
113
+ SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
114
+ )
115
+ from nshtrainer.optimizer import AdamWConfig as AdamWConfig
116
+ from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
117
+ from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
118
+ from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
119
+ from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
120
+ from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
121
+ from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
122
+ from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
123
+ from nshtrainer.trainer._config import (
124
+ CheckpointCallbackConfig as CheckpointCallbackConfig,
125
+ )
126
+ from nshtrainer.trainer._config import (
127
+ CheckpointLoadingConfig as CheckpointLoadingConfig,
128
+ )
129
+ from nshtrainer.trainer._config import (
130
+ CheckpointSavingConfig as CheckpointSavingConfig,
131
+ )
132
+ from nshtrainer.trainer._config import (
133
+ GradientClippingConfig as GradientClippingConfig,
134
+ )
135
+ from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
136
+ from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
137
+ from nshtrainer.trainer._config import (
138
+ ReproducibilityConfig as ReproducibilityConfig,
139
+ )
140
+ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
141
+ from nshtrainer.util._environment_info import (
142
+ EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
143
+ )
144
+ from nshtrainer.util._environment_info import (
145
+ EnvironmentCUDAConfig as EnvironmentCUDAConfig,
146
+ )
147
+ from nshtrainer.util._environment_info import (
148
+ EnvironmentGPUConfig as EnvironmentGPUConfig,
149
+ )
150
+ from nshtrainer.util._environment_info import (
151
+ EnvironmentHardwareConfig as EnvironmentHardwareConfig,
152
+ )
153
+ from nshtrainer.util._environment_info import (
154
+ EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
155
+ )
156
+ from nshtrainer.util._environment_info import (
157
+ EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
158
+ )
159
+ from nshtrainer.util._environment_info import (
160
+ EnvironmentPackageConfig as EnvironmentPackageConfig,
161
+ )
162
+ from nshtrainer.util._environment_info import (
163
+ EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
164
+ )
165
+ from nshtrainer.util._environment_info import (
166
+ EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
167
+ )
168
+ from nshtrainer.util._environment_info import (
169
+ GitRepositoryConfig as GitRepositoryConfig,
170
+ )
171
+ from nshtrainer.util.config import DTypeConfig as DTypeConfig
172
+ from nshtrainer.util.config import DurationConfig as DurationConfig
173
+ from nshtrainer.util.config import EpochsConfig as EpochsConfig
174
+ from nshtrainer.util.config import StepsConfig as StepsConfig
175
+ else:
176
+
177
+ def __getattr__(name):
178
+ import importlib
179
+
180
+ if name in globals():
181
+ return globals()[name]
182
+ if name == "MetricConfig":
183
+ return importlib.import_module("nshtrainer").MetricConfig
184
+ if name == "BaseConfig":
185
+ return importlib.import_module("nshtrainer").BaseConfig
186
+ if name == "HuggingFaceHubAutoCreateConfig":
187
+ return importlib.import_module(
188
+ "nshtrainer._hf_hub"
189
+ ).HuggingFaceHubAutoCreateConfig
190
+ if name == "HuggingFaceHubConfig":
191
+ return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
192
+ if name == "CallbackConfigBase":
193
+ return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
194
+ if name == "OptimizerConfigBase":
195
+ return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
196
+ if name == "AdamWConfig":
197
+ return importlib.import_module("nshtrainer.optimizer").AdamWConfig
198
+ if name == "DirectorySetupCallbackConfig":
199
+ return importlib.import_module(
200
+ "nshtrainer.callbacks"
201
+ ).DirectorySetupCallbackConfig
202
+ if name == "DirectoryConfig":
203
+ return importlib.import_module("nshtrainer.model").DirectoryConfig
204
+ if name == "TrainerConfig":
205
+ return importlib.import_module("nshtrainer.model").TrainerConfig
206
+ if name == "EnvironmentConfig":
207
+ return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
208
+ if name == "BaseNonlinearityConfig":
209
+ return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
210
+ if name == "MLPConfig":
211
+ return importlib.import_module("nshtrainer.nn").MLPConfig
212
+ if name == "PReLUConfig":
213
+ return importlib.import_module("nshtrainer.nn").PReLUConfig
214
+ if name == "LeakyReLUNonlinearityConfig":
215
+ return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
216
+ if name == "SwiGLUNonlinearityConfig":
217
+ return importlib.import_module(
218
+ "nshtrainer.nn.nonlinearity"
219
+ ).SwiGLUNonlinearityConfig
220
+ if name == "SoftsignNonlinearityConfig":
221
+ return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
222
+ if name == "SiLUNonlinearityConfig":
223
+ return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
224
+ if name == "SigmoidNonlinearityConfig":
225
+ return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
226
+ if name == "SoftplusNonlinearityConfig":
227
+ return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
228
+ if name == "ELUNonlinearityConfig":
229
+ return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
230
+ if name == "SoftmaxNonlinearityConfig":
231
+ return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
232
+ if name == "GELUNonlinearityConfig":
233
+ return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
234
+ if name == "SwishNonlinearityConfig":
235
+ return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
236
+ if name == "MishNonlinearityConfig":
237
+ return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
238
+ if name == "TanhNonlinearityConfig":
239
+ return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
240
+ if name == "ReLUNonlinearityConfig":
241
+ return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
242
+ if name == "LRSchedulerConfigBase":
243
+ return importlib.import_module(
244
+ "nshtrainer.lr_scheduler"
245
+ ).LRSchedulerConfigBase
246
+ if name == "LinearWarmupCosineDecayLRSchedulerConfig":
247
+ return importlib.import_module(
248
+ "nshtrainer.lr_scheduler"
249
+ ).LinearWarmupCosineDecayLRSchedulerConfig
250
+ if name == "ReduceLROnPlateauConfig":
251
+ return importlib.import_module(
252
+ "nshtrainer.lr_scheduler"
253
+ ).ReduceLROnPlateauConfig
254
+ if name == "BaseLoggerConfig":
255
+ return importlib.import_module("nshtrainer.loggers").BaseLoggerConfig
256
+ if name == "TensorboardLoggerConfig":
257
+ return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
258
+ if name == "WandbLoggerConfig":
259
+ return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
260
+ if name == "WandbUploadCodeCallbackConfig":
261
+ return importlib.import_module(
262
+ "nshtrainer.callbacks"
263
+ ).WandbUploadCodeCallbackConfig
264
+ if name == "WandbWatchCallbackConfig":
265
+ return importlib.import_module(
266
+ "nshtrainer.callbacks"
267
+ ).WandbWatchCallbackConfig
268
+ if name == "CSVLoggerConfig":
269
+ return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
270
+ if name == "EnvironmentLinuxEnvironmentConfig":
271
+ return importlib.import_module(
272
+ "nshtrainer.util._environment_info"
273
+ ).EnvironmentLinuxEnvironmentConfig
274
+ if name == "EnvironmentLSFInformationConfig":
275
+ return importlib.import_module(
276
+ "nshtrainer.util._environment_info"
277
+ ).EnvironmentLSFInformationConfig
278
+ if name == "EnvironmentGPUConfig":
279
+ return importlib.import_module(
280
+ "nshtrainer.util._environment_info"
281
+ ).EnvironmentGPUConfig
282
+ if name == "EnvironmentPackageConfig":
283
+ return importlib.import_module(
284
+ "nshtrainer.util._environment_info"
285
+ ).EnvironmentPackageConfig
286
+ if name == "EnvironmentHardwareConfig":
287
+ return importlib.import_module(
288
+ "nshtrainer.util._environment_info"
289
+ ).EnvironmentHardwareConfig
290
+ if name == "EnvironmentSnapshotConfig":
291
+ return importlib.import_module(
292
+ "nshtrainer.util._environment_info"
293
+ ).EnvironmentSnapshotConfig
294
+ if name == "EnvironmentClassInformationConfig":
295
+ return importlib.import_module(
296
+ "nshtrainer.util._environment_info"
297
+ ).EnvironmentClassInformationConfig
298
+ if name == "GitRepositoryConfig":
299
+ return importlib.import_module(
300
+ "nshtrainer.util._environment_info"
301
+ ).GitRepositoryConfig
302
+ if name == "EnvironmentCUDAConfig":
303
+ return importlib.import_module(
304
+ "nshtrainer.util._environment_info"
305
+ ).EnvironmentCUDAConfig
306
+ if name == "EnvironmentSLURMInformationConfig":
307
+ return importlib.import_module(
308
+ "nshtrainer.util._environment_info"
309
+ ).EnvironmentSLURMInformationConfig
310
+ if name == "EpochsConfig":
311
+ return importlib.import_module("nshtrainer.util.config").EpochsConfig
312
+ if name == "StepsConfig":
313
+ return importlib.import_module("nshtrainer.util.config").StepsConfig
314
+ if name == "DTypeConfig":
315
+ return importlib.import_module("nshtrainer.util.config").DTypeConfig
316
+ if name == "CheckpointLoadingConfig":
317
+ return importlib.import_module(
318
+ "nshtrainer.trainer._config"
319
+ ).CheckpointLoadingConfig
320
+ if name == "SanityCheckingConfig":
321
+ return importlib.import_module(
322
+ "nshtrainer.trainer._config"
323
+ ).SanityCheckingConfig
324
+ if name == "OnExceptionCheckpointCallbackConfig":
325
+ return importlib.import_module(
326
+ "nshtrainer.callbacks"
327
+ ).OnExceptionCheckpointCallbackConfig
328
+ if name == "GradientClippingConfig":
329
+ return importlib.import_module(
330
+ "nshtrainer.trainer._config"
331
+ ).GradientClippingConfig
332
+ if name == "LoggingConfig":
333
+ return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
334
+ if name == "RLPSanityChecksCallbackConfig":
335
+ return importlib.import_module(
336
+ "nshtrainer.callbacks"
337
+ ).RLPSanityChecksCallbackConfig
338
+ if name == "CheckpointSavingConfig":
339
+ return importlib.import_module(
340
+ "nshtrainer.trainer._config"
341
+ ).CheckpointSavingConfig
342
+ if name == "DebugFlagCallbackConfig":
343
+ return importlib.import_module(
344
+ "nshtrainer.callbacks"
345
+ ).DebugFlagCallbackConfig
346
+ if name == "LastCheckpointCallbackConfig":
347
+ return importlib.import_module(
348
+ "nshtrainer.callbacks"
349
+ ).LastCheckpointCallbackConfig
350
+ if name == "SharedParametersCallbackConfig":
351
+ return importlib.import_module(
352
+ "nshtrainer.callbacks"
353
+ ).SharedParametersCallbackConfig
354
+ if name == "ReproducibilityConfig":
355
+ return importlib.import_module(
356
+ "nshtrainer.trainer._config"
357
+ ).ReproducibilityConfig
358
+ if name == "EarlyStoppingCallbackConfig":
359
+ return importlib.import_module(
360
+ "nshtrainer.callbacks"
361
+ ).EarlyStoppingCallbackConfig
362
+ if name == "OptimizationConfig":
363
+ return importlib.import_module(
364
+ "nshtrainer.trainer._config"
365
+ ).OptimizationConfig
366
+ if name == "BestCheckpointCallbackConfig":
367
+ return importlib.import_module(
368
+ "nshtrainer.callbacks"
369
+ ).BestCheckpointCallbackConfig
370
+ if name == "CheckpointMetadata":
371
+ return importlib.import_module(
372
+ "nshtrainer._checkpoint.loader"
373
+ ).CheckpointMetadata
374
+ if name == "BestCheckpointStrategyConfig":
375
+ return importlib.import_module(
376
+ "nshtrainer._checkpoint.loader"
377
+ ).BestCheckpointStrategyConfig
378
+ if name == "LastCheckpointStrategyConfig":
379
+ return importlib.import_module(
380
+ "nshtrainer._checkpoint.loader"
381
+ ).LastCheckpointStrategyConfig
382
+ if name == "UserProvidedPathCheckpointStrategyConfig":
383
+ return importlib.import_module(
384
+ "nshtrainer._checkpoint.loader"
385
+ ).UserProvidedPathCheckpointStrategyConfig
386
+ if name == "PrintTableMetricsCallbackConfig":
387
+ return importlib.import_module(
388
+ "nshtrainer.callbacks"
389
+ ).PrintTableMetricsCallbackConfig
390
+ if name == "ThroughputMonitorConfig":
391
+ return importlib.import_module(
392
+ "nshtrainer.callbacks"
393
+ ).ThroughputMonitorConfig
394
+ if name == "GradientSkippingCallbackConfig":
395
+ return importlib.import_module(
396
+ "nshtrainer.callbacks"
397
+ ).GradientSkippingCallbackConfig
398
+ if name == "EMACallbackConfig":
399
+ return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
400
+ if name == "ActSaveConfig":
401
+ return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
402
+ if name == "FiniteChecksCallbackConfig":
403
+ return importlib.import_module(
404
+ "nshtrainer.callbacks"
405
+ ).FiniteChecksCallbackConfig
406
+ if name == "NormLoggingCallbackConfig":
407
+ return importlib.import_module(
408
+ "nshtrainer.callbacks"
409
+ ).NormLoggingCallbackConfig
410
+ if name == "EpochTimerCallbackConfig":
411
+ return importlib.import_module(
412
+ "nshtrainer.callbacks"
413
+ ).EpochTimerCallbackConfig
414
+ if name == "BaseCheckpointCallbackConfig":
415
+ return importlib.import_module(
416
+ "nshtrainer.callbacks.checkpoint._base"
417
+ ).BaseCheckpointCallbackConfig
418
+ if name == "BaseProfilerConfig":
419
+ return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
420
+ if name == "PyTorchProfilerConfig":
421
+ return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
422
+ if name == "AdvancedProfilerConfig":
423
+ return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
424
+ if name == "SimpleProfilerConfig":
425
+ return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
426
+ if name == "OptimizerConfig":
427
+ return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
428
+ if name == "LoggerConfig":
429
+ return importlib.import_module("nshtrainer.loggers").LoggerConfig
430
+ if name == "NonlinearityConfig":
431
+ return importlib.import_module("nshtrainer.nn").NonlinearityConfig
432
+ if name == "DurationConfig":
433
+ return importlib.import_module("nshtrainer.util.config").DurationConfig
434
+ if name == "LRSchedulerConfig":
435
+ return importlib.import_module("nshtrainer.lr_scheduler").LRSchedulerConfig
436
+ if name == "CallbackConfig":
437
+ return importlib.import_module("nshtrainer.callbacks").CallbackConfig
438
+ if name == "CheckpointCallbackConfig":
439
+ return importlib.import_module(
440
+ "nshtrainer.trainer._config"
441
+ ).CheckpointCallbackConfig
442
+ if name == "ProfilerConfig":
443
+ return importlib.import_module("nshtrainer.profiler").ProfilerConfig
444
+ if name == "CheckpointLoadingStrategyConfig":
445
+ return importlib.import_module(
446
+ "nshtrainer._checkpoint.loader"
447
+ ).CheckpointLoadingStrategyConfig
448
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
88
449
 
89
- # Type aliases
90
- from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
91
- from nshtrainer.loggers import LoggerConfig as LoggerConfig
92
- from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
93
- from nshtrainer.util.config import DurationConfig as DurationConfig
94
- from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
95
- from nshtrainer.callbacks import CallbackConfig as CallbackConfig
96
- from nshtrainer.trainer._config import CheckpointCallbackConfig as CheckpointCallbackConfig
97
- from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
98
- from nshtrainer._checkpoint.loader import CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig
99
450
 
100
451
  # Submodule exports
101
452
  from . import _checkpoint as _checkpoint
@@ -1,18 +1,62 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer._checkpoint.loader import UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig
9
- from nshtrainer._checkpoint.loader import MetricConfig as MetricConfig
10
- from nshtrainer._checkpoint.loader import CheckpointLoadingConfig as CheckpointLoadingConfig
11
- from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
12
- from nshtrainer._checkpoint.loader import BestCheckpointStrategyConfig as BestCheckpointStrategyConfig
13
- from nshtrainer._checkpoint.loader import LastCheckpointStrategyConfig as LastCheckpointStrategyConfig
5
+ from typing import TYPE_CHECKING
14
6
 
15
- # Type aliases
16
- from nshtrainer._checkpoint.loader import CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer._checkpoint.loader import (
11
+ BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
12
+ )
13
+ from nshtrainer._checkpoint.loader import (
14
+ CheckpointLoadingConfig as CheckpointLoadingConfig,
15
+ )
16
+ from nshtrainer._checkpoint.loader import (
17
+ CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
18
+ )
19
+ from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
20
+ from nshtrainer._checkpoint.loader import (
21
+ LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
22
+ )
23
+ from nshtrainer._checkpoint.loader import MetricConfig as MetricConfig
24
+ from nshtrainer._checkpoint.loader import (
25
+ UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
26
+ )
27
+ else:
28
+
29
+ def __getattr__(name):
30
+ import importlib
31
+
32
+ if name in globals():
33
+ return globals()[name]
34
+ if name == "MetricConfig":
35
+ return importlib.import_module("nshtrainer._checkpoint.loader").MetricConfig
36
+ if name == "BestCheckpointStrategyConfig":
37
+ return importlib.import_module(
38
+ "nshtrainer._checkpoint.loader"
39
+ ).BestCheckpointStrategyConfig
40
+ if name == "CheckpointMetadata":
41
+ return importlib.import_module(
42
+ "nshtrainer._checkpoint.loader"
43
+ ).CheckpointMetadata
44
+ if name == "LastCheckpointStrategyConfig":
45
+ return importlib.import_module(
46
+ "nshtrainer._checkpoint.loader"
47
+ ).LastCheckpointStrategyConfig
48
+ if name == "CheckpointLoadingConfig":
49
+ return importlib.import_module(
50
+ "nshtrainer._checkpoint.loader"
51
+ ).CheckpointLoadingConfig
52
+ if name == "UserProvidedPathCheckpointStrategyConfig":
53
+ return importlib.import_module(
54
+ "nshtrainer._checkpoint.loader"
55
+ ).UserProvidedPathCheckpointStrategyConfig
56
+ if name == "CheckpointLoadingStrategyConfig":
57
+ return importlib.import_module(
58
+ "nshtrainer._checkpoint.loader"
59
+ ).CheckpointLoadingStrategyConfig
60
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
17
61
 
18
62
  # Submodule exports
@@ -1,13 +1,29 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
9
- from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
11
+ from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
12
+ else:
13
+
14
+ def __getattr__(name):
15
+ import importlib
16
+
17
+ if name in globals():
18
+ return globals()[name]
19
+ if name == "CheckpointMetadata":
20
+ return importlib.import_module(
21
+ "nshtrainer._checkpoint.metadata"
22
+ ).CheckpointMetadata
23
+ if name == "EnvironmentConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer._checkpoint.metadata"
26
+ ).EnvironmentConfig
27
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
28
 
13
29
  # Submodule exports