nshtrainer 1.0.0b14__py3-none-any.whl → 1.0.0b16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. nshtrainer/configs/__init__.py +134 -405
  2. nshtrainer/configs/_checkpoint/__init__.py +2 -25
  3. nshtrainer/configs/_checkpoint/metadata/__init__.py +2 -25
  4. nshtrainer/configs/_directory/__init__.py +5 -28
  5. nshtrainer/configs/_hf_hub/__init__.py +5 -28
  6. nshtrainer/configs/callbacks/__init__.py +52 -161
  7. nshtrainer/configs/callbacks/actsave/__init__.py +2 -23
  8. nshtrainer/configs/callbacks/base/__init__.py +1 -20
  9. nshtrainer/configs/callbacks/checkpoint/__init__.py +19 -64
  10. nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +9 -36
  11. nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +10 -43
  12. nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +9 -36
  13. nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -29
  14. nshtrainer/configs/callbacks/debug_flag/__init__.py +4 -27
  15. nshtrainer/configs/callbacks/directory_setup/__init__.py +6 -29
  16. nshtrainer/configs/callbacks/early_stopping/__init__.py +5 -34
  17. nshtrainer/configs/callbacks/ema/__init__.py +2 -23
  18. nshtrainer/configs/callbacks/finite_checks/__init__.py +4 -29
  19. nshtrainer/configs/callbacks/gradient_skipping/__init__.py +6 -29
  20. nshtrainer/configs/callbacks/log_epoch/__init__.py +4 -27
  21. nshtrainer/configs/callbacks/lr_monitor/__init__.py +4 -27
  22. nshtrainer/configs/callbacks/norm_logging/__init__.py +4 -29
  23. nshtrainer/configs/callbacks/print_table/__init__.py +4 -29
  24. nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +6 -29
  25. nshtrainer/configs/callbacks/shared_parameters/__init__.py +6 -29
  26. nshtrainer/configs/callbacks/timer/__init__.py +4 -27
  27. nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +6 -29
  28. nshtrainer/configs/callbacks/wandb_watch/__init__.py +4 -29
  29. nshtrainer/configs/loggers/__init__.py +13 -52
  30. nshtrainer/configs/loggers/_base/__init__.py +1 -18
  31. nshtrainer/configs/loggers/actsave/__init__.py +2 -25
  32. nshtrainer/configs/loggers/csv/__init__.py +2 -21
  33. nshtrainer/configs/loggers/tensorboard/__init__.py +4 -27
  34. nshtrainer/configs/loggers/wandb/__init__.py +9 -40
  35. nshtrainer/configs/lr_scheduler/__init__.py +10 -51
  36. nshtrainer/configs/lr_scheduler/_base/__init__.py +1 -22
  37. nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +9 -36
  38. nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +7 -36
  39. nshtrainer/configs/metrics/__init__.py +1 -18
  40. nshtrainer/configs/metrics/_config/__init__.py +1 -18
  41. nshtrainer/configs/nn/__init__.py +19 -70
  42. nshtrainer/configs/nn/mlp/__init__.py +3 -24
  43. nshtrainer/configs/nn/nonlinearity/__init__.py +30 -121
  44. nshtrainer/configs/optimizer/__init__.py +3 -24
  45. nshtrainer/configs/profiler/__init__.py +5 -30
  46. nshtrainer/configs/profiler/_base/__init__.py +1 -20
  47. nshtrainer/configs/profiler/advanced/__init__.py +4 -27
  48. nshtrainer/configs/profiler/pytorch/__init__.py +2 -27
  49. nshtrainer/configs/profiler/simple/__init__.py +2 -25
  50. nshtrainer/configs/trainer/__init__.py +50 -169
  51. nshtrainer/configs/trainer/_config/__init__.py +50 -169
  52. nshtrainer/configs/trainer/trainer/__init__.py +2 -23
  53. nshtrainer/configs/util/__init__.py +33 -102
  54. nshtrainer/configs/util/_environment_info/__init__.py +29 -90
  55. nshtrainer/configs/util/config/__init__.py +4 -27
  56. nshtrainer/configs/util/config/dtype/__init__.py +1 -18
  57. nshtrainer/configs/util/config/duration/__init__.py +3 -30
  58. nshtrainer/trainer/_config.py +42 -10
  59. {nshtrainer-1.0.0b14.dist-info → nshtrainer-1.0.0b16.dist-info}/METADATA +1 -1
  60. {nshtrainer-1.0.0b14.dist-info → nshtrainer-1.0.0b16.dist-info}/RECORD +61 -61
  61. {nshtrainer-1.0.0b14.dist-info → nshtrainer-1.0.0b16.dist-info}/WHEEL +0 -0
@@ -2,37 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
5
+ from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
6
+ from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
7
+ from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
8
+ from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
9
+ from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
6
10
 
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
11
- from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
12
- from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
13
- from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
14
- from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
15
- else:
16
-
17
- def __getattr__(name):
18
- import importlib
19
-
20
- if name in globals():
21
- return globals()[name]
22
- if name == "AdvancedProfilerConfig":
23
- return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
24
- if name == "BaseProfilerConfig":
25
- return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
26
- if name == "PyTorchProfilerConfig":
27
- return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
28
- if name == "SimpleProfilerConfig":
29
- return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
30
- if name == "ProfilerConfig":
31
- return importlib.import_module("nshtrainer.profiler").ProfilerConfig
32
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
33
-
34
-
35
- # Submodule exports
36
11
  from . import _base as _base
37
12
  from . import advanced as advanced
38
13
  from . import pytorch as pytorch
@@ -2,23 +2,4 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.profiler._base import BaseProfilerConfig as BaseProfilerConfig
11
- else:
12
-
13
- def __getattr__(name):
14
- import importlib
15
-
16
- if name in globals():
17
- return globals()[name]
18
- if name == "BaseProfilerConfig":
19
- return importlib.import_module(
20
- "nshtrainer.profiler._base"
21
- ).BaseProfilerConfig
22
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
23
-
24
- # Submodule exports
5
+ from nshtrainer.profiler._base import BaseProfilerConfig as BaseProfilerConfig
@@ -2,30 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.profiler.advanced import (
11
- AdvancedProfilerConfig as AdvancedProfilerConfig,
12
- )
13
- from nshtrainer.profiler.advanced import BaseProfilerConfig as BaseProfilerConfig
14
- else:
15
-
16
- def __getattr__(name):
17
- import importlib
18
-
19
- if name in globals():
20
- return globals()[name]
21
- if name == "AdvancedProfilerConfig":
22
- return importlib.import_module(
23
- "nshtrainer.profiler.advanced"
24
- ).AdvancedProfilerConfig
25
- if name == "BaseProfilerConfig":
26
- return importlib.import_module(
27
- "nshtrainer.profiler.advanced"
28
- ).BaseProfilerConfig
29
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
-
31
- # Submodule exports
5
+ from nshtrainer.profiler.advanced import (
6
+ AdvancedProfilerConfig as AdvancedProfilerConfig,
7
+ )
8
+ from nshtrainer.profiler.advanced import BaseProfilerConfig as BaseProfilerConfig
@@ -2,30 +2,5 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.profiler.pytorch import BaseProfilerConfig as BaseProfilerConfig
11
- from nshtrainer.profiler.pytorch import (
12
- PyTorchProfilerConfig as PyTorchProfilerConfig,
13
- )
14
- else:
15
-
16
- def __getattr__(name):
17
- import importlib
18
-
19
- if name in globals():
20
- return globals()[name]
21
- if name == "BaseProfilerConfig":
22
- return importlib.import_module(
23
- "nshtrainer.profiler.pytorch"
24
- ).BaseProfilerConfig
25
- if name == "PyTorchProfilerConfig":
26
- return importlib.import_module(
27
- "nshtrainer.profiler.pytorch"
28
- ).PyTorchProfilerConfig
29
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
-
31
- # Submodule exports
5
+ from nshtrainer.profiler.pytorch import BaseProfilerConfig as BaseProfilerConfig
6
+ from nshtrainer.profiler.pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
@@ -2,28 +2,5 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.profiler.simple import BaseProfilerConfig as BaseProfilerConfig
11
- from nshtrainer.profiler.simple import SimpleProfilerConfig as SimpleProfilerConfig
12
- else:
13
-
14
- def __getattr__(name):
15
- import importlib
16
-
17
- if name in globals():
18
- return globals()[name]
19
- if name == "BaseProfilerConfig":
20
- return importlib.import_module(
21
- "nshtrainer.profiler.simple"
22
- ).BaseProfilerConfig
23
- if name == "SimpleProfilerConfig":
24
- return importlib.import_module(
25
- "nshtrainer.profiler.simple"
26
- ).SimpleProfilerConfig
27
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
28
-
29
- # Submodule exports
5
+ from nshtrainer.profiler.simple import BaseProfilerConfig as BaseProfilerConfig
6
+ from nshtrainer.profiler.simple import SimpleProfilerConfig as SimpleProfilerConfig
@@ -2,175 +2,56 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
5
+ from nshtrainer.trainer import TrainerConfig as TrainerConfig
6
+ from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
7
+ from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
8
+ from nshtrainer.trainer._config import (
9
+ BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
10
+ )
11
+ from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
12
+ from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
13
+ from nshtrainer.trainer._config import (
14
+ CheckpointCallbackConfig as CheckpointCallbackConfig,
15
+ )
16
+ from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
17
+ from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
18
+ from nshtrainer.trainer._config import (
19
+ DebugFlagCallbackConfig as DebugFlagCallbackConfig,
20
+ )
21
+ from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
22
+ from nshtrainer.trainer._config import (
23
+ EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
24
+ )
25
+ from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
26
+ from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
27
+ from nshtrainer.trainer._config import HuggingFaceHubConfig as HuggingFaceHubConfig
28
+ from nshtrainer.trainer._config import (
29
+ LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
30
+ )
31
+ from nshtrainer.trainer._config import (
32
+ LearningRateMonitorConfig as LearningRateMonitorConfig,
33
+ )
34
+ from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
35
+ from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
36
+ from nshtrainer.trainer._config import MetricConfig as MetricConfig
37
+ from nshtrainer.trainer._config import (
38
+ NormLoggingCallbackConfig as NormLoggingCallbackConfig,
39
+ )
40
+ from nshtrainer.trainer._config import (
41
+ OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
42
+ )
43
+ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
44
+ from nshtrainer.trainer._config import (
45
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
46
+ )
47
+ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
48
+ from nshtrainer.trainer._config import (
49
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
50
+ )
51
+ from nshtrainer.trainer._config import (
52
+ TensorboardLoggerConfig as TensorboardLoggerConfig,
53
+ )
54
+ from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
6
55
 
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.trainer import TrainerConfig as TrainerConfig
11
- from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
12
- from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
13
- from nshtrainer.trainer._config import (
14
- BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
15
- )
16
- from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
17
- from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
18
- from nshtrainer.trainer._config import (
19
- CheckpointCallbackConfig as CheckpointCallbackConfig,
20
- )
21
- from nshtrainer.trainer._config import (
22
- CheckpointSavingConfig as CheckpointSavingConfig,
23
- )
24
- from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
25
- from nshtrainer.trainer._config import (
26
- DebugFlagCallbackConfig as DebugFlagCallbackConfig,
27
- )
28
- from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
29
- from nshtrainer.trainer._config import (
30
- EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
31
- )
32
- from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
33
- from nshtrainer.trainer._config import (
34
- GradientClippingConfig as GradientClippingConfig,
35
- )
36
- from nshtrainer.trainer._config import HuggingFaceHubConfig as HuggingFaceHubConfig
37
- from nshtrainer.trainer._config import (
38
- LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
39
- )
40
- from nshtrainer.trainer._config import (
41
- LearningRateMonitorConfig as LearningRateMonitorConfig,
42
- )
43
- from nshtrainer.trainer._config import (
44
- LogEpochCallbackConfig as LogEpochCallbackConfig,
45
- )
46
- from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
47
- from nshtrainer.trainer._config import MetricConfig as MetricConfig
48
- from nshtrainer.trainer._config import (
49
- NormLoggingCallbackConfig as NormLoggingCallbackConfig,
50
- )
51
- from nshtrainer.trainer._config import (
52
- OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
53
- )
54
- from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
55
- from nshtrainer.trainer._config import (
56
- RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
57
- )
58
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
59
- from nshtrainer.trainer._config import (
60
- SharedParametersCallbackConfig as SharedParametersCallbackConfig,
61
- )
62
- from nshtrainer.trainer._config import (
63
- TensorboardLoggerConfig as TensorboardLoggerConfig,
64
- )
65
- from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
66
- else:
67
-
68
- def __getattr__(name):
69
- import importlib
70
-
71
- if name in globals():
72
- return globals()[name]
73
- if name == "ActSaveLoggerConfig":
74
- return importlib.import_module(
75
- "nshtrainer.trainer._config"
76
- ).ActSaveLoggerConfig
77
- if name == "BaseLoggerConfig":
78
- return importlib.import_module(
79
- "nshtrainer.trainer._config"
80
- ).BaseLoggerConfig
81
- if name == "BestCheckpointCallbackConfig":
82
- return importlib.import_module(
83
- "nshtrainer.trainer._config"
84
- ).BestCheckpointCallbackConfig
85
- if name == "CSVLoggerConfig":
86
- return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
87
- if name == "CallbackConfigBase":
88
- return importlib.import_module(
89
- "nshtrainer.trainer._config"
90
- ).CallbackConfigBase
91
- if name == "CheckpointSavingConfig":
92
- return importlib.import_module(
93
- "nshtrainer.trainer._config"
94
- ).CheckpointSavingConfig
95
- if name == "DebugFlagCallbackConfig":
96
- return importlib.import_module(
97
- "nshtrainer.trainer._config"
98
- ).DebugFlagCallbackConfig
99
- if name == "DirectoryConfig":
100
- return importlib.import_module("nshtrainer.trainer._config").DirectoryConfig
101
- if name == "EarlyStoppingCallbackConfig":
102
- return importlib.import_module(
103
- "nshtrainer.trainer._config"
104
- ).EarlyStoppingCallbackConfig
105
- if name == "EnvironmentConfig":
106
- return importlib.import_module(
107
- "nshtrainer.trainer._config"
108
- ).EnvironmentConfig
109
- if name == "GradientClippingConfig":
110
- return importlib.import_module(
111
- "nshtrainer.trainer._config"
112
- ).GradientClippingConfig
113
- if name == "HuggingFaceHubConfig":
114
- return importlib.import_module(
115
- "nshtrainer.trainer._config"
116
- ).HuggingFaceHubConfig
117
- if name == "LastCheckpointCallbackConfig":
118
- return importlib.import_module(
119
- "nshtrainer.trainer._config"
120
- ).LastCheckpointCallbackConfig
121
- if name == "LearningRateMonitorConfig":
122
- return importlib.import_module(
123
- "nshtrainer.trainer._config"
124
- ).LearningRateMonitorConfig
125
- if name == "LogEpochCallbackConfig":
126
- return importlib.import_module(
127
- "nshtrainer.trainer._config"
128
- ).LogEpochCallbackConfig
129
- if name == "MetricConfig":
130
- return importlib.import_module("nshtrainer.trainer._config").MetricConfig
131
- if name == "NormLoggingCallbackConfig":
132
- return importlib.import_module(
133
- "nshtrainer.trainer._config"
134
- ).NormLoggingCallbackConfig
135
- if name == "OnExceptionCheckpointCallbackConfig":
136
- return importlib.import_module(
137
- "nshtrainer.trainer._config"
138
- ).OnExceptionCheckpointCallbackConfig
139
- if name == "RLPSanityChecksCallbackConfig":
140
- return importlib.import_module(
141
- "nshtrainer.trainer._config"
142
- ).RLPSanityChecksCallbackConfig
143
- if name == "SanityCheckingConfig":
144
- return importlib.import_module(
145
- "nshtrainer.trainer._config"
146
- ).SanityCheckingConfig
147
- if name == "SharedParametersCallbackConfig":
148
- return importlib.import_module(
149
- "nshtrainer.trainer._config"
150
- ).SharedParametersCallbackConfig
151
- if name == "TensorboardLoggerConfig":
152
- return importlib.import_module(
153
- "nshtrainer.trainer._config"
154
- ).TensorboardLoggerConfig
155
- if name == "TrainerConfig":
156
- return importlib.import_module("nshtrainer.trainer").TrainerConfig
157
- if name == "WandbLoggerConfig":
158
- return importlib.import_module(
159
- "nshtrainer.trainer._config"
160
- ).WandbLoggerConfig
161
- if name == "CallbackConfig":
162
- return importlib.import_module("nshtrainer.trainer._config").CallbackConfig
163
- if name == "CheckpointCallbackConfig":
164
- return importlib.import_module(
165
- "nshtrainer.trainer._config"
166
- ).CheckpointCallbackConfig
167
- if name == "LoggerConfig":
168
- return importlib.import_module("nshtrainer.trainer._config").LoggerConfig
169
- if name == "ProfilerConfig":
170
- return importlib.import_module("nshtrainer.trainer._config").ProfilerConfig
171
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
172
-
173
-
174
- # Submodule exports
175
56
  from . import _config as _config
176
57
  from . import trainer as trainer
@@ -2,172 +2,53 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
11
- from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
12
- from nshtrainer.trainer._config import (
13
- BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
14
- )
15
- from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
16
- from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
17
- from nshtrainer.trainer._config import (
18
- CheckpointCallbackConfig as CheckpointCallbackConfig,
19
- )
20
- from nshtrainer.trainer._config import (
21
- CheckpointSavingConfig as CheckpointSavingConfig,
22
- )
23
- from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
24
- from nshtrainer.trainer._config import (
25
- DebugFlagCallbackConfig as DebugFlagCallbackConfig,
26
- )
27
- from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
28
- from nshtrainer.trainer._config import (
29
- EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
30
- )
31
- from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
32
- from nshtrainer.trainer._config import (
33
- GradientClippingConfig as GradientClippingConfig,
34
- )
35
- from nshtrainer.trainer._config import HuggingFaceHubConfig as HuggingFaceHubConfig
36
- from nshtrainer.trainer._config import (
37
- LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
38
- )
39
- from nshtrainer.trainer._config import (
40
- LearningRateMonitorConfig as LearningRateMonitorConfig,
41
- )
42
- from nshtrainer.trainer._config import (
43
- LogEpochCallbackConfig as LogEpochCallbackConfig,
44
- )
45
- from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
46
- from nshtrainer.trainer._config import MetricConfig as MetricConfig
47
- from nshtrainer.trainer._config import (
48
- NormLoggingCallbackConfig as NormLoggingCallbackConfig,
49
- )
50
- from nshtrainer.trainer._config import (
51
- OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
52
- )
53
- from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
54
- from nshtrainer.trainer._config import (
55
- RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
56
- )
57
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
58
- from nshtrainer.trainer._config import (
59
- SharedParametersCallbackConfig as SharedParametersCallbackConfig,
60
- )
61
- from nshtrainer.trainer._config import (
62
- TensorboardLoggerConfig as TensorboardLoggerConfig,
63
- )
64
- from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
65
- from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
66
- else:
67
-
68
- def __getattr__(name):
69
- import importlib
70
-
71
- if name in globals():
72
- return globals()[name]
73
- if name == "ActSaveLoggerConfig":
74
- return importlib.import_module(
75
- "nshtrainer.trainer._config"
76
- ).ActSaveLoggerConfig
77
- if name == "BaseLoggerConfig":
78
- return importlib.import_module(
79
- "nshtrainer.trainer._config"
80
- ).BaseLoggerConfig
81
- if name == "BestCheckpointCallbackConfig":
82
- return importlib.import_module(
83
- "nshtrainer.trainer._config"
84
- ).BestCheckpointCallbackConfig
85
- if name == "CSVLoggerConfig":
86
- return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
87
- if name == "CallbackConfigBase":
88
- return importlib.import_module(
89
- "nshtrainer.trainer._config"
90
- ).CallbackConfigBase
91
- if name == "CheckpointSavingConfig":
92
- return importlib.import_module(
93
- "nshtrainer.trainer._config"
94
- ).CheckpointSavingConfig
95
- if name == "DebugFlagCallbackConfig":
96
- return importlib.import_module(
97
- "nshtrainer.trainer._config"
98
- ).DebugFlagCallbackConfig
99
- if name == "DirectoryConfig":
100
- return importlib.import_module("nshtrainer.trainer._config").DirectoryConfig
101
- if name == "EarlyStoppingCallbackConfig":
102
- return importlib.import_module(
103
- "nshtrainer.trainer._config"
104
- ).EarlyStoppingCallbackConfig
105
- if name == "EnvironmentConfig":
106
- return importlib.import_module(
107
- "nshtrainer.trainer._config"
108
- ).EnvironmentConfig
109
- if name == "GradientClippingConfig":
110
- return importlib.import_module(
111
- "nshtrainer.trainer._config"
112
- ).GradientClippingConfig
113
- if name == "HuggingFaceHubConfig":
114
- return importlib.import_module(
115
- "nshtrainer.trainer._config"
116
- ).HuggingFaceHubConfig
117
- if name == "LastCheckpointCallbackConfig":
118
- return importlib.import_module(
119
- "nshtrainer.trainer._config"
120
- ).LastCheckpointCallbackConfig
121
- if name == "LearningRateMonitorConfig":
122
- return importlib.import_module(
123
- "nshtrainer.trainer._config"
124
- ).LearningRateMonitorConfig
125
- if name == "LogEpochCallbackConfig":
126
- return importlib.import_module(
127
- "nshtrainer.trainer._config"
128
- ).LogEpochCallbackConfig
129
- if name == "MetricConfig":
130
- return importlib.import_module("nshtrainer.trainer._config").MetricConfig
131
- if name == "NormLoggingCallbackConfig":
132
- return importlib.import_module(
133
- "nshtrainer.trainer._config"
134
- ).NormLoggingCallbackConfig
135
- if name == "OnExceptionCheckpointCallbackConfig":
136
- return importlib.import_module(
137
- "nshtrainer.trainer._config"
138
- ).OnExceptionCheckpointCallbackConfig
139
- if name == "RLPSanityChecksCallbackConfig":
140
- return importlib.import_module(
141
- "nshtrainer.trainer._config"
142
- ).RLPSanityChecksCallbackConfig
143
- if name == "SanityCheckingConfig":
144
- return importlib.import_module(
145
- "nshtrainer.trainer._config"
146
- ).SanityCheckingConfig
147
- if name == "SharedParametersCallbackConfig":
148
- return importlib.import_module(
149
- "nshtrainer.trainer._config"
150
- ).SharedParametersCallbackConfig
151
- if name == "TensorboardLoggerConfig":
152
- return importlib.import_module(
153
- "nshtrainer.trainer._config"
154
- ).TensorboardLoggerConfig
155
- if name == "TrainerConfig":
156
- return importlib.import_module("nshtrainer.trainer._config").TrainerConfig
157
- if name == "WandbLoggerConfig":
158
- return importlib.import_module(
159
- "nshtrainer.trainer._config"
160
- ).WandbLoggerConfig
161
- if name == "CallbackConfig":
162
- return importlib.import_module("nshtrainer.trainer._config").CallbackConfig
163
- if name == "CheckpointCallbackConfig":
164
- return importlib.import_module(
165
- "nshtrainer.trainer._config"
166
- ).CheckpointCallbackConfig
167
- if name == "LoggerConfig":
168
- return importlib.import_module("nshtrainer.trainer._config").LoggerConfig
169
- if name == "ProfilerConfig":
170
- return importlib.import_module("nshtrainer.trainer._config").ProfilerConfig
171
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
172
-
173
- # Submodule exports
5
+ from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
6
+ from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
7
+ from nshtrainer.trainer._config import (
8
+ BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
9
+ )
10
+ from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
11
+ from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
12
+ from nshtrainer.trainer._config import (
13
+ CheckpointCallbackConfig as CheckpointCallbackConfig,
14
+ )
15
+ from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
16
+ from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
17
+ from nshtrainer.trainer._config import (
18
+ DebugFlagCallbackConfig as DebugFlagCallbackConfig,
19
+ )
20
+ from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
21
+ from nshtrainer.trainer._config import (
22
+ EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
23
+ )
24
+ from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
25
+ from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
26
+ from nshtrainer.trainer._config import HuggingFaceHubConfig as HuggingFaceHubConfig
27
+ from nshtrainer.trainer._config import (
28
+ LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
29
+ )
30
+ from nshtrainer.trainer._config import (
31
+ LearningRateMonitorConfig as LearningRateMonitorConfig,
32
+ )
33
+ from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
34
+ from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
35
+ from nshtrainer.trainer._config import MetricConfig as MetricConfig
36
+ from nshtrainer.trainer._config import (
37
+ NormLoggingCallbackConfig as NormLoggingCallbackConfig,
38
+ )
39
+ from nshtrainer.trainer._config import (
40
+ OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
41
+ )
42
+ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
43
+ from nshtrainer.trainer._config import (
44
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
45
+ )
46
+ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
47
+ from nshtrainer.trainer._config import (
48
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
49
+ )
50
+ from nshtrainer.trainer._config import (
51
+ TensorboardLoggerConfig as TensorboardLoggerConfig,
52
+ )
53
+ from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
54
+ from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
@@ -2,26 +2,5 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
11
- from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig
12
- else:
13
-
14
- def __getattr__(name):
15
- import importlib
16
-
17
- if name in globals():
18
- return globals()[name]
19
- if name == "EnvironmentConfig":
20
- return importlib.import_module(
21
- "nshtrainer.trainer.trainer"
22
- ).EnvironmentConfig
23
- if name == "TrainerConfig":
24
- return importlib.import_module("nshtrainer.trainer.trainer").TrainerConfig
25
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
26
-
27
- # Submodule exports
5
+ from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
6
+ from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig