nshtrainer-1.0.0b13-py3-none-any.whl → nshtrainer-1.0.0b15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +3 -3
  2. nshtrainer/configs/__init__.py +134 -440
  3. nshtrainer/configs/_checkpoint/__init__.py +2 -64
  4. nshtrainer/configs/_checkpoint/metadata/__init__.py +2 -25
  5. nshtrainer/configs/_directory/__init__.py +5 -28
  6. nshtrainer/configs/_hf_hub/__init__.py +5 -28
  7. nshtrainer/configs/callbacks/__init__.py +52 -161
  8. nshtrainer/configs/callbacks/actsave/__init__.py +2 -23
  9. nshtrainer/configs/callbacks/base/__init__.py +1 -20
  10. nshtrainer/configs/callbacks/checkpoint/__init__.py +19 -64
  11. nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +9 -36
  12. nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +10 -43
  13. nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +9 -36
  14. nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -29
  15. nshtrainer/configs/callbacks/debug_flag/__init__.py +4 -27
  16. nshtrainer/configs/callbacks/directory_setup/__init__.py +6 -29
  17. nshtrainer/configs/callbacks/early_stopping/__init__.py +5 -34
  18. nshtrainer/configs/callbacks/ema/__init__.py +2 -23
  19. nshtrainer/configs/callbacks/finite_checks/__init__.py +4 -29
  20. nshtrainer/configs/callbacks/gradient_skipping/__init__.py +6 -29
  21. nshtrainer/configs/callbacks/log_epoch/__init__.py +4 -27
  22. nshtrainer/configs/callbacks/lr_monitor/__init__.py +4 -27
  23. nshtrainer/configs/callbacks/norm_logging/__init__.py +4 -29
  24. nshtrainer/configs/callbacks/print_table/__init__.py +4 -29
  25. nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +6 -29
  26. nshtrainer/configs/callbacks/shared_parameters/__init__.py +6 -29
  27. nshtrainer/configs/callbacks/timer/__init__.py +4 -27
  28. nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +6 -29
  29. nshtrainer/configs/callbacks/wandb_watch/__init__.py +4 -29
  30. nshtrainer/configs/loggers/__init__.py +13 -52
  31. nshtrainer/configs/loggers/_base/__init__.py +1 -18
  32. nshtrainer/configs/loggers/actsave/__init__.py +2 -25
  33. nshtrainer/configs/loggers/csv/__init__.py +2 -21
  34. nshtrainer/configs/loggers/tensorboard/__init__.py +4 -27
  35. nshtrainer/configs/loggers/wandb/__init__.py +9 -40
  36. nshtrainer/configs/lr_scheduler/__init__.py +10 -51
  37. nshtrainer/configs/lr_scheduler/_base/__init__.py +1 -22
  38. nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +9 -36
  39. nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +7 -36
  40. nshtrainer/configs/metrics/__init__.py +1 -18
  41. nshtrainer/configs/metrics/_config/__init__.py +1 -18
  42. nshtrainer/configs/nn/__init__.py +19 -70
  43. nshtrainer/configs/nn/mlp/__init__.py +3 -24
  44. nshtrainer/configs/nn/nonlinearity/__init__.py +30 -121
  45. nshtrainer/configs/optimizer/__init__.py +3 -24
  46. nshtrainer/configs/profiler/__init__.py +5 -30
  47. nshtrainer/configs/profiler/_base/__init__.py +1 -20
  48. nshtrainer/configs/profiler/advanced/__init__.py +4 -27
  49. nshtrainer/configs/profiler/pytorch/__init__.py +2 -27
  50. nshtrainer/configs/profiler/simple/__init__.py +2 -25
  51. nshtrainer/configs/trainer/__init__.py +50 -177
  52. nshtrainer/configs/trainer/_config/__init__.py +50 -176
  53. nshtrainer/configs/trainer/trainer/__init__.py +2 -23
  54. nshtrainer/configs/util/__init__.py +33 -102
  55. nshtrainer/configs/util/_environment_info/__init__.py +29 -90
  56. nshtrainer/configs/util/config/__init__.py +4 -27
  57. nshtrainer/configs/util/config/dtype/__init__.py +1 -18
  58. nshtrainer/configs/util/config/duration/__init__.py +3 -30
  59. nshtrainer/trainer/_config.py +0 -7
  60. nshtrainer/trainer/trainer.py +1 -11
  61. {nshtrainer-1.0.0b13.dist-info → nshtrainer-1.0.0b15.dist-info}/METADATA +1 -1
  62. {nshtrainer-1.0.0b13.dist-info → nshtrainer-1.0.0b15.dist-info}/RECORD +63 -67
  63. nshtrainer/_checkpoint/loader.py +0 -387
  64. nshtrainer/configs/_checkpoint/loader/__init__.py +0 -62
  65. nshtrainer/configs/trainer/checkpoint_connector/__init__.py +0 -26
  66. nshtrainer/trainer/checkpoint_connector.py +0 -86
  67. {nshtrainer-1.0.0b13.dist-info → nshtrainer-1.0.0b15.dist-info}/WHEEL +0 -0
@@ -96,8 +96,8 @@ class OnExceptionCheckpointCallback(_OnExceptionCheckpoint):
96
96
  def on_exception(self, trainer: LightningTrainer, *args: Any, **kwargs: Any):
97
97
  # Monkey-patch the strategy instance to make the barrier operation a no-op.
98
98
  # We do this because `save_checkpoint` calls `barrier`. This is okay in most
99
- # cases, but when we want to save a checkpoint in the case of an exception,
100
- # `barrier` causes a deadlock. So we monkey-patch the strategy instance to
101
- # make the barrier operation a no-op.
99
+ # cases, but when we want to save a checkpoint in the case of an exception,
100
+ # `barrier` causes a deadlock. So we monkey-patch the strategy instance to
101
+ # make the barrier operation a no-op.
102
102
  with _monkey_patch_disable_barrier(trainer):
103
103
  return super().on_exception(trainer, *args, **kwargs)
@@ -2,447 +2,141 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
5
+ from nshtrainer import MetricConfig as MetricConfig
6
+ from nshtrainer import TrainerConfig as TrainerConfig
7
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
8
+ from nshtrainer._directory import DirectoryConfig as DirectoryConfig
9
+ from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
10
+ from nshtrainer._hf_hub import (
11
+ HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
12
+ )
13
+ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
14
+ from nshtrainer.callbacks import (
15
+ BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
16
+ )
17
+ from nshtrainer.callbacks import CallbackConfig as CallbackConfig
18
+ from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
19
+ from nshtrainer.callbacks import (
20
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
21
+ )
22
+ from nshtrainer.callbacks import (
23
+ EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
24
+ )
25
+ from nshtrainer.callbacks import EMACallbackConfig as EMACallbackConfig
26
+ from nshtrainer.callbacks import EpochTimerCallbackConfig as EpochTimerCallbackConfig
27
+ from nshtrainer.callbacks import (
28
+ FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
29
+ )
30
+ from nshtrainer.callbacks import (
31
+ GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
32
+ )
33
+ from nshtrainer.callbacks import (
34
+ LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
35
+ )
36
+ from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
37
+ from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
38
+ from nshtrainer.callbacks import (
39
+ OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
40
+ )
41
+ from nshtrainer.callbacks import (
42
+ PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
43
+ )
44
+ from nshtrainer.callbacks import (
45
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
46
+ )
47
+ from nshtrainer.callbacks import (
48
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
49
+ )
50
+ from nshtrainer.callbacks import (
51
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
52
+ )
53
+ from nshtrainer.callbacks import WandbWatchCallbackConfig as WandbWatchCallbackConfig
54
+ from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
55
+ from nshtrainer.callbacks.checkpoint._base import (
56
+ BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
57
+ )
58
+ from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
59
+ from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
60
+ from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
61
+ from nshtrainer.loggers import LoggerConfig as LoggerConfig
62
+ from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
63
+ from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
64
+ from nshtrainer.lr_scheduler import (
65
+ LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
66
+ )
67
+ from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
68
+ from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
69
+ from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
70
+ from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
71
+ from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
72
+ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
73
+ from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
74
+ from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
75
+ from nshtrainer.nn import MLPConfig as MLPConfig
76
+ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
77
+ from nshtrainer.nn import PReLUConfig as PReLUConfig
78
+ from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
79
+ from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
80
+ from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
81
+ from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
82
+ from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
83
+ from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
84
+ from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
85
+ from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
86
+ from nshtrainer.nn.nonlinearity import (
87
+ SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
88
+ )
89
+ from nshtrainer.optimizer import AdamWConfig as AdamWConfig
90
+ from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
91
+ from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
92
+ from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
93
+ from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
94
+ from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
95
+ from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
96
+ from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
97
+ from nshtrainer.trainer._config import (
98
+ CheckpointCallbackConfig as CheckpointCallbackConfig,
99
+ )
100
+ from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
101
+ from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
102
+ from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
103
+ from nshtrainer.trainer._config import (
104
+ LearningRateMonitorConfig as LearningRateMonitorConfig,
105
+ )
106
+ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
107
+ from nshtrainer.util._environment_info import (
108
+ EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
109
+ )
110
+ from nshtrainer.util._environment_info import (
111
+ EnvironmentCUDAConfig as EnvironmentCUDAConfig,
112
+ )
113
+ from nshtrainer.util._environment_info import (
114
+ EnvironmentGPUConfig as EnvironmentGPUConfig,
115
+ )
116
+ from nshtrainer.util._environment_info import (
117
+ EnvironmentHardwareConfig as EnvironmentHardwareConfig,
118
+ )
119
+ from nshtrainer.util._environment_info import (
120
+ EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
121
+ )
122
+ from nshtrainer.util._environment_info import (
123
+ EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
124
+ )
125
+ from nshtrainer.util._environment_info import (
126
+ EnvironmentPackageConfig as EnvironmentPackageConfig,
127
+ )
128
+ from nshtrainer.util._environment_info import (
129
+ EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
130
+ )
131
+ from nshtrainer.util._environment_info import (
132
+ EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
133
+ )
134
+ from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
135
+ from nshtrainer.util.config import DTypeConfig as DTypeConfig
136
+ from nshtrainer.util.config import DurationConfig as DurationConfig
137
+ from nshtrainer.util.config import EpochsConfig as EpochsConfig
138
+ from nshtrainer.util.config import StepsConfig as StepsConfig
6
139
 
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer import MetricConfig as MetricConfig
11
- from nshtrainer import TrainerConfig as TrainerConfig
12
- from nshtrainer._checkpoint.loader import (
13
- BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
14
- )
15
- from nshtrainer._checkpoint.loader import (
16
- CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
17
- )
18
- from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
19
- from nshtrainer._checkpoint.loader import (
20
- LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
21
- )
22
- from nshtrainer._checkpoint.loader import (
23
- UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
24
- )
25
- from nshtrainer._directory import DirectoryConfig as DirectoryConfig
26
- from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
27
- from nshtrainer._hf_hub import (
28
- HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
29
- )
30
- from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
31
- from nshtrainer.callbacks import (
32
- BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
33
- )
34
- from nshtrainer.callbacks import CallbackConfig as CallbackConfig
35
- from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
36
- from nshtrainer.callbacks import (
37
- DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
38
- )
39
- from nshtrainer.callbacks import (
40
- EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
41
- )
42
- from nshtrainer.callbacks import EMACallbackConfig as EMACallbackConfig
43
- from nshtrainer.callbacks import (
44
- EpochTimerCallbackConfig as EpochTimerCallbackConfig,
45
- )
46
- from nshtrainer.callbacks import (
47
- FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
48
- )
49
- from nshtrainer.callbacks import (
50
- GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
51
- )
52
- from nshtrainer.callbacks import (
53
- LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
54
- )
55
- from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
56
- from nshtrainer.callbacks import (
57
- NormLoggingCallbackConfig as NormLoggingCallbackConfig,
58
- )
59
- from nshtrainer.callbacks import (
60
- OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
61
- )
62
- from nshtrainer.callbacks import (
63
- PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
64
- )
65
- from nshtrainer.callbacks import (
66
- RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
67
- )
68
- from nshtrainer.callbacks import (
69
- SharedParametersCallbackConfig as SharedParametersCallbackConfig,
70
- )
71
- from nshtrainer.callbacks import (
72
- WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
73
- )
74
- from nshtrainer.callbacks import (
75
- WandbWatchCallbackConfig as WandbWatchCallbackConfig,
76
- )
77
- from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
78
- from nshtrainer.callbacks.checkpoint._base import (
79
- BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
80
- )
81
- from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
82
- from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
83
- from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
84
- from nshtrainer.loggers import LoggerConfig as LoggerConfig
85
- from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
86
- from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
87
- from nshtrainer.lr_scheduler import (
88
- LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
89
- )
90
- from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
91
- from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
92
- from nshtrainer.lr_scheduler import (
93
- ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
94
- )
95
- from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
96
- from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
97
- from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
98
- from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
99
- from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
100
- from nshtrainer.nn import MLPConfig as MLPConfig
101
- from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
102
- from nshtrainer.nn import PReLUConfig as PReLUConfig
103
- from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
104
- from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
105
- from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
106
- from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
107
- from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
108
- from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
109
- from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
110
- from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
111
- from nshtrainer.nn.nonlinearity import (
112
- SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
113
- )
114
- from nshtrainer.optimizer import AdamWConfig as AdamWConfig
115
- from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
116
- from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
117
- from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
118
- from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
119
- from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
120
- from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
121
- from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
122
- from nshtrainer.trainer._config import (
123
- CheckpointCallbackConfig as CheckpointCallbackConfig,
124
- )
125
- from nshtrainer.trainer._config import (
126
- CheckpointLoadingConfig as CheckpointLoadingConfig,
127
- )
128
- from nshtrainer.trainer._config import (
129
- CheckpointSavingConfig as CheckpointSavingConfig,
130
- )
131
- from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
132
- from nshtrainer.trainer._config import (
133
- GradientClippingConfig as GradientClippingConfig,
134
- )
135
- from nshtrainer.trainer._config import (
136
- LearningRateMonitorConfig as LearningRateMonitorConfig,
137
- )
138
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
139
- from nshtrainer.util._environment_info import (
140
- EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
141
- )
142
- from nshtrainer.util._environment_info import (
143
- EnvironmentCUDAConfig as EnvironmentCUDAConfig,
144
- )
145
- from nshtrainer.util._environment_info import (
146
- EnvironmentGPUConfig as EnvironmentGPUConfig,
147
- )
148
- from nshtrainer.util._environment_info import (
149
- EnvironmentHardwareConfig as EnvironmentHardwareConfig,
150
- )
151
- from nshtrainer.util._environment_info import (
152
- EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
153
- )
154
- from nshtrainer.util._environment_info import (
155
- EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
156
- )
157
- from nshtrainer.util._environment_info import (
158
- EnvironmentPackageConfig as EnvironmentPackageConfig,
159
- )
160
- from nshtrainer.util._environment_info import (
161
- EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
162
- )
163
- from nshtrainer.util._environment_info import (
164
- EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
165
- )
166
- from nshtrainer.util._environment_info import (
167
- GitRepositoryConfig as GitRepositoryConfig,
168
- )
169
- from nshtrainer.util.config import DTypeConfig as DTypeConfig
170
- from nshtrainer.util.config import DurationConfig as DurationConfig
171
- from nshtrainer.util.config import EpochsConfig as EpochsConfig
172
- from nshtrainer.util.config import StepsConfig as StepsConfig
173
- else:
174
-
175
- def __getattr__(name):
176
- import importlib
177
-
178
- if name in globals():
179
- return globals()[name]
180
- if name == "ActSaveConfig":
181
- return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
182
- if name == "ActSaveLoggerConfig":
183
- return importlib.import_module("nshtrainer.loggers").ActSaveLoggerConfig
184
- if name == "AdamWConfig":
185
- return importlib.import_module("nshtrainer.optimizer").AdamWConfig
186
- if name == "AdvancedProfilerConfig":
187
- return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
188
- if name == "BaseCheckpointCallbackConfig":
189
- return importlib.import_module(
190
- "nshtrainer.callbacks.checkpoint._base"
191
- ).BaseCheckpointCallbackConfig
192
- if name == "BaseLoggerConfig":
193
- return importlib.import_module("nshtrainer.loggers").BaseLoggerConfig
194
- if name == "BaseNonlinearityConfig":
195
- return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
196
- if name == "BaseProfilerConfig":
197
- return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
198
- if name == "BestCheckpointCallbackConfig":
199
- return importlib.import_module(
200
- "nshtrainer.callbacks"
201
- ).BestCheckpointCallbackConfig
202
- if name == "BestCheckpointStrategyConfig":
203
- return importlib.import_module(
204
- "nshtrainer._checkpoint.loader"
205
- ).BestCheckpointStrategyConfig
206
- if name == "CSVLoggerConfig":
207
- return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
208
- if name == "CallbackConfigBase":
209
- return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
210
- if name == "CheckpointLoadingConfig":
211
- return importlib.import_module(
212
- "nshtrainer.trainer._config"
213
- ).CheckpointLoadingConfig
214
- if name == "CheckpointMetadata":
215
- return importlib.import_module(
216
- "nshtrainer._checkpoint.loader"
217
- ).CheckpointMetadata
218
- if name == "CheckpointSavingConfig":
219
- return importlib.import_module(
220
- "nshtrainer.trainer._config"
221
- ).CheckpointSavingConfig
222
- if name == "DTypeConfig":
223
- return importlib.import_module("nshtrainer.util.config").DTypeConfig
224
- if name == "DebugFlagCallbackConfig":
225
- return importlib.import_module(
226
- "nshtrainer.callbacks"
227
- ).DebugFlagCallbackConfig
228
- if name == "DirectoryConfig":
229
- return importlib.import_module("nshtrainer._directory").DirectoryConfig
230
- if name == "DirectorySetupCallbackConfig":
231
- return importlib.import_module(
232
- "nshtrainer.callbacks"
233
- ).DirectorySetupCallbackConfig
234
- if name == "ELUNonlinearityConfig":
235
- return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
236
- if name == "EMACallbackConfig":
237
- return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
238
- if name == "EarlyStoppingCallbackConfig":
239
- return importlib.import_module(
240
- "nshtrainer.callbacks"
241
- ).EarlyStoppingCallbackConfig
242
- if name == "EnvironmentCUDAConfig":
243
- return importlib.import_module(
244
- "nshtrainer.util._environment_info"
245
- ).EnvironmentCUDAConfig
246
- if name == "EnvironmentClassInformationConfig":
247
- return importlib.import_module(
248
- "nshtrainer.util._environment_info"
249
- ).EnvironmentClassInformationConfig
250
- if name == "EnvironmentConfig":
251
- return importlib.import_module(
252
- "nshtrainer.trainer._config"
253
- ).EnvironmentConfig
254
- if name == "EnvironmentGPUConfig":
255
- return importlib.import_module(
256
- "nshtrainer.util._environment_info"
257
- ).EnvironmentGPUConfig
258
- if name == "EnvironmentHardwareConfig":
259
- return importlib.import_module(
260
- "nshtrainer.util._environment_info"
261
- ).EnvironmentHardwareConfig
262
- if name == "EnvironmentLSFInformationConfig":
263
- return importlib.import_module(
264
- "nshtrainer.util._environment_info"
265
- ).EnvironmentLSFInformationConfig
266
- if name == "EnvironmentLinuxEnvironmentConfig":
267
- return importlib.import_module(
268
- "nshtrainer.util._environment_info"
269
- ).EnvironmentLinuxEnvironmentConfig
270
- if name == "EnvironmentPackageConfig":
271
- return importlib.import_module(
272
- "nshtrainer.util._environment_info"
273
- ).EnvironmentPackageConfig
274
- if name == "EnvironmentSLURMInformationConfig":
275
- return importlib.import_module(
276
- "nshtrainer.util._environment_info"
277
- ).EnvironmentSLURMInformationConfig
278
- if name == "EnvironmentSnapshotConfig":
279
- return importlib.import_module(
280
- "nshtrainer.util._environment_info"
281
- ).EnvironmentSnapshotConfig
282
- if name == "EpochTimerCallbackConfig":
283
- return importlib.import_module(
284
- "nshtrainer.callbacks"
285
- ).EpochTimerCallbackConfig
286
- if name == "EpochsConfig":
287
- return importlib.import_module("nshtrainer.util.config").EpochsConfig
288
- if name == "FiniteChecksCallbackConfig":
289
- return importlib.import_module(
290
- "nshtrainer.callbacks"
291
- ).FiniteChecksCallbackConfig
292
- if name == "GELUNonlinearityConfig":
293
- return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
294
- if name == "GitRepositoryConfig":
295
- return importlib.import_module(
296
- "nshtrainer.util._environment_info"
297
- ).GitRepositoryConfig
298
- if name == "GradientClippingConfig":
299
- return importlib.import_module(
300
- "nshtrainer.trainer._config"
301
- ).GradientClippingConfig
302
- if name == "GradientSkippingCallbackConfig":
303
- return importlib.import_module(
304
- "nshtrainer.callbacks"
305
- ).GradientSkippingCallbackConfig
306
- if name == "HuggingFaceHubAutoCreateConfig":
307
- return importlib.import_module(
308
- "nshtrainer._hf_hub"
309
- ).HuggingFaceHubAutoCreateConfig
310
- if name == "HuggingFaceHubConfig":
311
- return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
312
- if name == "LRSchedulerConfigBase":
313
- return importlib.import_module(
314
- "nshtrainer.lr_scheduler"
315
- ).LRSchedulerConfigBase
316
- if name == "LastCheckpointCallbackConfig":
317
- return importlib.import_module(
318
- "nshtrainer.callbacks"
319
- ).LastCheckpointCallbackConfig
320
- if name == "LastCheckpointStrategyConfig":
321
- return importlib.import_module(
322
- "nshtrainer._checkpoint.loader"
323
- ).LastCheckpointStrategyConfig
324
- if name == "LeakyReLUNonlinearityConfig":
325
- return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
326
- if name == "LearningRateMonitorConfig":
327
- return importlib.import_module(
328
- "nshtrainer.trainer._config"
329
- ).LearningRateMonitorConfig
330
- if name == "LinearWarmupCosineDecayLRSchedulerConfig":
331
- return importlib.import_module(
332
- "nshtrainer.lr_scheduler"
333
- ).LinearWarmupCosineDecayLRSchedulerConfig
334
- if name == "LogEpochCallbackConfig":
335
- return importlib.import_module(
336
- "nshtrainer.callbacks"
337
- ).LogEpochCallbackConfig
338
- if name == "MLPConfig":
339
- return importlib.import_module("nshtrainer.nn").MLPConfig
340
- if name == "MetricConfig":
341
- return importlib.import_module("nshtrainer").MetricConfig
342
- if name == "MishNonlinearityConfig":
343
- return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
344
- if name == "NormLoggingCallbackConfig":
345
- return importlib.import_module(
346
- "nshtrainer.callbacks"
347
- ).NormLoggingCallbackConfig
348
- if name == "OnExceptionCheckpointCallbackConfig":
349
- return importlib.import_module(
350
- "nshtrainer.callbacks"
351
- ).OnExceptionCheckpointCallbackConfig
352
- if name == "OptimizerConfigBase":
353
- return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
354
- if name == "PReLUConfig":
355
- return importlib.import_module("nshtrainer.nn").PReLUConfig
356
- if name == "PrintTableMetricsCallbackConfig":
357
- return importlib.import_module(
358
- "nshtrainer.callbacks"
359
- ).PrintTableMetricsCallbackConfig
360
- if name == "PyTorchProfilerConfig":
361
- return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
362
- if name == "RLPSanityChecksCallbackConfig":
363
- return importlib.import_module(
364
- "nshtrainer.callbacks"
365
- ).RLPSanityChecksCallbackConfig
366
- if name == "ReLUNonlinearityConfig":
367
- return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
368
- if name == "ReduceLROnPlateauConfig":
369
- return importlib.import_module(
370
- "nshtrainer.lr_scheduler"
371
- ).ReduceLROnPlateauConfig
372
- if name == "SanityCheckingConfig":
373
- return importlib.import_module(
374
- "nshtrainer.trainer._config"
375
- ).SanityCheckingConfig
376
- if name == "SharedParametersCallbackConfig":
377
- return importlib.import_module(
378
- "nshtrainer.callbacks"
379
- ).SharedParametersCallbackConfig
380
- if name == "SiLUNonlinearityConfig":
381
- return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
382
- if name == "SigmoidNonlinearityConfig":
383
- return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
384
- if name == "SimpleProfilerConfig":
385
- return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
386
- if name == "SoftmaxNonlinearityConfig":
387
- return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
388
- if name == "SoftplusNonlinearityConfig":
389
- return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
390
- if name == "SoftsignNonlinearityConfig":
391
- return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
392
- if name == "StepsConfig":
393
- return importlib.import_module("nshtrainer.util.config").StepsConfig
394
- if name == "SwiGLUNonlinearityConfig":
395
- return importlib.import_module(
396
- "nshtrainer.nn.nonlinearity"
397
- ).SwiGLUNonlinearityConfig
398
- if name == "SwishNonlinearityConfig":
399
- return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
400
- if name == "TanhNonlinearityConfig":
401
- return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
402
- if name == "TensorboardLoggerConfig":
403
- return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
404
- if name == "TrainerConfig":
405
- return importlib.import_module("nshtrainer").TrainerConfig
406
- if name == "UserProvidedPathCheckpointStrategyConfig":
407
- return importlib.import_module(
408
- "nshtrainer._checkpoint.loader"
409
- ).UserProvidedPathCheckpointStrategyConfig
410
- if name == "WandbLoggerConfig":
411
- return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
412
- if name == "WandbUploadCodeCallbackConfig":
413
- return importlib.import_module(
414
- "nshtrainer.callbacks"
415
- ).WandbUploadCodeCallbackConfig
416
- if name == "WandbWatchCallbackConfig":
417
- return importlib.import_module(
418
- "nshtrainer.callbacks"
419
- ).WandbWatchCallbackConfig
420
- if name == "CallbackConfig":
421
- return importlib.import_module("nshtrainer.callbacks").CallbackConfig
422
- if name == "CheckpointCallbackConfig":
423
- return importlib.import_module(
424
- "nshtrainer.trainer._config"
425
- ).CheckpointCallbackConfig
426
- if name == "CheckpointLoadingStrategyConfig":
427
- return importlib.import_module(
428
- "nshtrainer._checkpoint.loader"
429
- ).CheckpointLoadingStrategyConfig
430
- if name == "DurationConfig":
431
- return importlib.import_module("nshtrainer.util.config").DurationConfig
432
- if name == "LRSchedulerConfig":
433
- return importlib.import_module("nshtrainer.lr_scheduler").LRSchedulerConfig
434
- if name == "LoggerConfig":
435
- return importlib.import_module("nshtrainer.loggers").LoggerConfig
436
- if name == "NonlinearityConfig":
437
- return importlib.import_module("nshtrainer.nn").NonlinearityConfig
438
- if name == "OptimizerConfig":
439
- return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
440
- if name == "ProfilerConfig":
441
- return importlib.import_module("nshtrainer.profiler").ProfilerConfig
442
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
443
-
444
-
445
- # Submodule exports
446
140
  from . import _checkpoint as _checkpoint
447
141
  from . import _directory as _directory
448
142
  from . import _hf_hub as _hf_hub
@@ -2,69 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
5
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
6
+ from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
6
7
 
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer._checkpoint.loader import (
11
- BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
12
- )
13
- from nshtrainer._checkpoint.loader import (
14
- CheckpointLoadingConfig as CheckpointLoadingConfig,
15
- )
16
- from nshtrainer._checkpoint.loader import (
17
- CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
18
- )
19
- from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
20
- from nshtrainer._checkpoint.loader import (
21
- LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
22
- )
23
- from nshtrainer._checkpoint.loader import MetricConfig as MetricConfig
24
- from nshtrainer._checkpoint.loader import (
25
- UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
26
- )
27
- from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
28
- else:
29
-
30
- def __getattr__(name):
31
- import importlib
32
-
33
- if name in globals():
34
- return globals()[name]
35
- if name == "BestCheckpointStrategyConfig":
36
- return importlib.import_module(
37
- "nshtrainer._checkpoint.loader"
38
- ).BestCheckpointStrategyConfig
39
- if name == "CheckpointLoadingConfig":
40
- return importlib.import_module(
41
- "nshtrainer._checkpoint.loader"
42
- ).CheckpointLoadingConfig
43
- if name == "CheckpointMetadata":
44
- return importlib.import_module(
45
- "nshtrainer._checkpoint.loader"
46
- ).CheckpointMetadata
47
- if name == "EnvironmentConfig":
48
- return importlib.import_module(
49
- "nshtrainer._checkpoint.metadata"
50
- ).EnvironmentConfig
51
- if name == "LastCheckpointStrategyConfig":
52
- return importlib.import_module(
53
- "nshtrainer._checkpoint.loader"
54
- ).LastCheckpointStrategyConfig
55
- if name == "MetricConfig":
56
- return importlib.import_module("nshtrainer._checkpoint.loader").MetricConfig
57
- if name == "UserProvidedPathCheckpointStrategyConfig":
58
- return importlib.import_module(
59
- "nshtrainer._checkpoint.loader"
60
- ).UserProvidedPathCheckpointStrategyConfig
61
- if name == "CheckpointLoadingStrategyConfig":
62
- return importlib.import_module(
63
- "nshtrainer._checkpoint.loader"
64
- ).CheckpointLoadingStrategyConfig
65
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
66
-
67
-
68
- # Submodule exports
69
- from . import loader as loader
70
8
  from . import metadata as metadata