nshtrainer 0.42.0__py3-none-any.whl → 0.44.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. nshtrainer/__init__.py +2 -0
  2. nshtrainer/_callback.py +2 -0
  3. nshtrainer/_checkpoint/loader.py +2 -0
  4. nshtrainer/_checkpoint/metadata.py +2 -0
  5. nshtrainer/_checkpoint/saver.py +2 -0
  6. nshtrainer/_directory.py +4 -2
  7. nshtrainer/_experimental/__init__.py +2 -0
  8. nshtrainer/_hf_hub.py +2 -0
  9. nshtrainer/callbacks/__init__.py +45 -29
  10. nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
  11. nshtrainer/callbacks/actsave.py +2 -0
  12. nshtrainer/callbacks/base.py +2 -0
  13. nshtrainer/callbacks/checkpoint/__init__.py +6 -2
  14. nshtrainer/callbacks/checkpoint/_base.py +2 -0
  15. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  16. nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
  17. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
  18. nshtrainer/callbacks/debug_flag.py +2 -0
  19. nshtrainer/callbacks/directory_setup.py +4 -2
  20. nshtrainer/callbacks/early_stopping.py +6 -4
  21. nshtrainer/callbacks/ema.py +5 -3
  22. nshtrainer/callbacks/finite_checks.py +3 -1
  23. nshtrainer/callbacks/gradient_skipping.py +6 -4
  24. nshtrainer/callbacks/interval.py +2 -0
  25. nshtrainer/callbacks/log_epoch.py +13 -1
  26. nshtrainer/callbacks/norm_logging.py +4 -2
  27. nshtrainer/callbacks/print_table.py +3 -1
  28. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  29. nshtrainer/callbacks/shared_parameters.py +4 -2
  30. nshtrainer/callbacks/throughput_monitor.py +2 -0
  31. nshtrainer/callbacks/timer.py +5 -3
  32. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  33. nshtrainer/callbacks/wandb_watch.py +4 -2
  34. nshtrainer/config/__init__.py +130 -90
  35. nshtrainer/config/_checkpoint/loader/__init__.py +10 -8
  36. nshtrainer/config/_checkpoint/metadata/__init__.py +6 -4
  37. nshtrainer/config/_directory/__init__.py +9 -3
  38. nshtrainer/config/_hf_hub/__init__.py +6 -4
  39. nshtrainer/config/callbacks/__init__.py +82 -42
  40. nshtrainer/config/callbacks/actsave/__init__.py +4 -2
  41. nshtrainer/config/callbacks/base/__init__.py +2 -0
  42. nshtrainer/config/callbacks/checkpoint/__init__.py +6 -4
  43. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +6 -4
  44. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +2 -0
  45. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +6 -4
  46. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -4
  47. nshtrainer/config/callbacks/debug_flag/__init__.py +6 -4
  48. nshtrainer/config/callbacks/directory_setup/__init__.py +7 -5
  49. nshtrainer/config/callbacks/early_stopping/__init__.py +9 -7
  50. nshtrainer/config/callbacks/ema/__init__.py +5 -3
  51. nshtrainer/config/callbacks/finite_checks/__init__.py +7 -5
  52. nshtrainer/config/callbacks/gradient_skipping/__init__.py +7 -5
  53. nshtrainer/config/callbacks/norm_logging/__init__.py +9 -5
  54. nshtrainer/config/callbacks/print_table/__init__.py +7 -5
  55. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +7 -5
  56. nshtrainer/config/callbacks/shared_parameters/__init__.py +7 -5
  57. nshtrainer/config/callbacks/throughput_monitor/__init__.py +6 -4
  58. nshtrainer/config/callbacks/timer/__init__.py +9 -5
  59. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +7 -5
  60. nshtrainer/config/callbacks/wandb_watch/__init__.py +9 -5
  61. nshtrainer/config/loggers/__init__.py +18 -10
  62. nshtrainer/config/loggers/_base/__init__.py +2 -0
  63. nshtrainer/config/loggers/csv/__init__.py +2 -0
  64. nshtrainer/config/loggers/tensorboard/__init__.py +2 -0
  65. nshtrainer/config/loggers/wandb/__init__.py +18 -10
  66. nshtrainer/config/lr_scheduler/__init__.py +2 -0
  67. nshtrainer/config/lr_scheduler/_base/__init__.py +2 -0
  68. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +2 -0
  69. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -4
  70. nshtrainer/config/metrics/__init__.py +2 -0
  71. nshtrainer/config/metrics/_config/__init__.py +2 -0
  72. nshtrainer/config/model/__init__.py +8 -6
  73. nshtrainer/config/model/base/__init__.py +4 -2
  74. nshtrainer/config/model/config/__init__.py +8 -6
  75. nshtrainer/config/model/mixins/logger/__init__.py +2 -0
  76. nshtrainer/config/nn/__init__.py +16 -14
  77. nshtrainer/config/nn/mlp/__init__.py +2 -0
  78. nshtrainer/config/nn/nonlinearity/__init__.py +26 -24
  79. nshtrainer/config/optimizer/__init__.py +2 -0
  80. nshtrainer/config/profiler/__init__.py +2 -0
  81. nshtrainer/config/profiler/_base/__init__.py +2 -0
  82. nshtrainer/config/profiler/advanced/__init__.py +6 -4
  83. nshtrainer/config/profiler/pytorch/__init__.py +6 -4
  84. nshtrainer/config/profiler/simple/__init__.py +6 -4
  85. nshtrainer/config/runner/__init__.py +2 -0
  86. nshtrainer/config/trainer/_config/__init__.py +43 -39
  87. nshtrainer/config/trainer/checkpoint_connector/__init__.py +2 -0
  88. nshtrainer/config/util/_environment_info/__init__.py +20 -18
  89. nshtrainer/config/util/config/__init__.py +2 -0
  90. nshtrainer/config/util/config/dtype/__init__.py +2 -0
  91. nshtrainer/config/util/config/duration/__init__.py +2 -0
  92. nshtrainer/data/__init__.py +2 -0
  93. nshtrainer/data/balanced_batch_sampler.py +2 -0
  94. nshtrainer/data/datamodule.py +2 -0
  95. nshtrainer/data/transform.py +2 -0
  96. nshtrainer/ll/__init__.py +2 -0
  97. nshtrainer/ll/_experimental.py +2 -0
  98. nshtrainer/ll/actsave.py +2 -0
  99. nshtrainer/ll/callbacks.py +2 -0
  100. nshtrainer/ll/config.py +2 -0
  101. nshtrainer/ll/data.py +2 -0
  102. nshtrainer/ll/log.py +2 -0
  103. nshtrainer/ll/lr_scheduler.py +2 -0
  104. nshtrainer/ll/model.py +2 -0
  105. nshtrainer/ll/nn.py +2 -0
  106. nshtrainer/ll/optimizer.py +2 -0
  107. nshtrainer/ll/runner.py +2 -0
  108. nshtrainer/ll/snapshot.py +2 -0
  109. nshtrainer/ll/snoop.py +2 -0
  110. nshtrainer/ll/trainer.py +2 -0
  111. nshtrainer/ll/typecheck.py +2 -0
  112. nshtrainer/ll/util.py +2 -0
  113. nshtrainer/loggers/__init__.py +2 -0
  114. nshtrainer/loggers/_base.py +2 -0
  115. nshtrainer/loggers/csv.py +2 -0
  116. nshtrainer/loggers/tensorboard.py +2 -0
  117. nshtrainer/loggers/wandb.py +6 -4
  118. nshtrainer/lr_scheduler/__init__.py +2 -0
  119. nshtrainer/lr_scheduler/_base.py +8 -11
  120. nshtrainer/lr_scheduler/linear_warmup_cosine.py +18 -17
  121. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +8 -6
  122. nshtrainer/metrics/__init__.py +2 -0
  123. nshtrainer/metrics/_config.py +2 -0
  124. nshtrainer/model/__init__.py +2 -0
  125. nshtrainer/model/base.py +2 -0
  126. nshtrainer/model/config.py +2 -0
  127. nshtrainer/model/mixins/callback.py +2 -0
  128. nshtrainer/model/mixins/logger.py +2 -0
  129. nshtrainer/nn/__init__.py +2 -0
  130. nshtrainer/nn/mlp.py +2 -0
  131. nshtrainer/nn/module_dict.py +2 -0
  132. nshtrainer/nn/module_list.py +2 -0
  133. nshtrainer/nn/nonlinearity.py +2 -0
  134. nshtrainer/optimizer.py +2 -0
  135. nshtrainer/profiler/__init__.py +2 -0
  136. nshtrainer/profiler/_base.py +2 -0
  137. nshtrainer/profiler/advanced.py +2 -0
  138. nshtrainer/profiler/pytorch.py +2 -0
  139. nshtrainer/profiler/simple.py +2 -0
  140. nshtrainer/runner.py +2 -0
  141. nshtrainer/scripts/find_packages.py +2 -0
  142. nshtrainer/trainer/__init__.py +2 -0
  143. nshtrainer/trainer/_config.py +16 -13
  144. nshtrainer/trainer/_runtime_callback.py +2 -0
  145. nshtrainer/trainer/checkpoint_connector.py +2 -0
  146. nshtrainer/trainer/signal_connector.py +2 -0
  147. nshtrainer/trainer/trainer.py +2 -0
  148. nshtrainer/util/_environment_info.py +2 -0
  149. nshtrainer/util/bf16.py +2 -0
  150. nshtrainer/util/config/__init__.py +2 -0
  151. nshtrainer/util/config/dtype.py +2 -0
  152. nshtrainer/util/config/duration.py +2 -0
  153. nshtrainer/util/environment.py +2 -0
  154. nshtrainer/util/path.py +2 -0
  155. nshtrainer/util/seed.py +2 -0
  156. nshtrainer/util/slurm.py +3 -0
  157. nshtrainer/util/typed.py +2 -0
  158. nshtrainer/util/typing_utils.py +2 -0
  159. {nshtrainer-0.42.0.dist-info → nshtrainer-0.44.0.dist-info}/METADATA +1 -1
  160. nshtrainer-0.44.0.dist-info/RECORD +162 -0
  161. nshtrainer-0.42.0.dist-info/RECORD +0 -162
  162. {nshtrainer-0.42.0.dist-info → nshtrainer-0.44.0.dist-info}/WHEEL +0 -0
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -30,25 +32,47 @@ if TYPE_CHECKING:
30
32
  )
31
33
  from nshtrainer.callbacks import CallbackConfig as CallbackConfig
32
34
  from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
33
- from nshtrainer.callbacks import DirectorySetupConfig as DirectorySetupConfig
34
- from nshtrainer.callbacks import EarlyStoppingConfig as EarlyStoppingConfig
35
- from nshtrainer.callbacks import EMAConfig as EMAConfig
36
- from nshtrainer.callbacks import EpochTimerConfig as EpochTimerConfig
37
- from nshtrainer.callbacks import FiniteChecksConfig as FiniteChecksConfig
38
- from nshtrainer.callbacks import GradientSkippingConfig as GradientSkippingConfig
35
+ from nshtrainer.callbacks import (
36
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
37
+ )
38
+ from nshtrainer.callbacks import (
39
+ EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
40
+ )
41
+ from nshtrainer.callbacks import EMACallbackConfig as EMACallbackConfig
42
+ from nshtrainer.callbacks import (
43
+ EpochTimerCallbackConfig as EpochTimerCallbackConfig,
44
+ )
45
+ from nshtrainer.callbacks import (
46
+ FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
47
+ )
48
+ from nshtrainer.callbacks import (
49
+ GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
50
+ )
39
51
  from nshtrainer.callbacks import (
40
52
  LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
41
53
  )
42
- from nshtrainer.callbacks import NormLoggingConfig as NormLoggingConfig
54
+ from nshtrainer.callbacks import (
55
+ NormLoggingCallbackConfig as NormLoggingCallbackConfig,
56
+ )
43
57
  from nshtrainer.callbacks import (
44
58
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
45
59
  )
46
- from nshtrainer.callbacks import PrintTableMetricsConfig as PrintTableMetricsConfig
47
- from nshtrainer.callbacks import RLPSanityChecksConfig as RLPSanityChecksConfig
48
- from nshtrainer.callbacks import SharedParametersConfig as SharedParametersConfig
60
+ from nshtrainer.callbacks import (
61
+ PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
62
+ )
63
+ from nshtrainer.callbacks import (
64
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
65
+ )
66
+ from nshtrainer.callbacks import (
67
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
68
+ )
49
69
  from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
50
- from nshtrainer.callbacks import WandbUploadCodeConfig as WandbUploadCodeConfig
51
- from nshtrainer.callbacks import WandbWatchConfig as WandbWatchConfig
70
+ from nshtrainer.callbacks import (
71
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
72
+ )
73
+ from nshtrainer.callbacks import (
74
+ WandbWatchCallbackConfig as WandbWatchCallbackConfig,
75
+ )
52
76
  from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
53
77
  from nshtrainer.callbacks.checkpoint._base import (
54
78
  BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
@@ -155,26 +179,28 @@ else:
155
179
 
156
180
  if name in globals():
157
181
  return globals()[name]
158
- if name == "BaseConfig":
159
- return importlib.import_module("nshtrainer").BaseConfig
160
182
  if name == "MetricConfig":
161
183
  return importlib.import_module("nshtrainer").MetricConfig
162
- if name == "CallbackConfigBase":
163
- return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
164
- if name == "HuggingFaceHubConfig":
165
- return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
184
+ if name == "BaseConfig":
185
+ return importlib.import_module("nshtrainer").BaseConfig
166
186
  if name == "HuggingFaceHubAutoCreateConfig":
167
187
  return importlib.import_module(
168
188
  "nshtrainer._hf_hub"
169
189
  ).HuggingFaceHubAutoCreateConfig
190
+ if name == "HuggingFaceHubConfig":
191
+ return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
192
+ if name == "CallbackConfigBase":
193
+ return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
170
194
  if name == "OptimizerConfigBase":
171
195
  return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
172
196
  if name == "AdamWConfig":
173
197
  return importlib.import_module("nshtrainer.optimizer").AdamWConfig
198
+ if name == "DirectorySetupCallbackConfig":
199
+ return importlib.import_module(
200
+ "nshtrainer.callbacks"
201
+ ).DirectorySetupCallbackConfig
174
202
  if name == "DirectoryConfig":
175
203
  return importlib.import_module("nshtrainer.model").DirectoryConfig
176
- if name == "DirectorySetupConfig":
177
- return importlib.import_module("nshtrainer.callbacks").DirectorySetupConfig
178
204
  if name == "TrainerConfig":
179
205
  return importlib.import_module("nshtrainer.model").TrainerConfig
180
206
  if name == "EnvironmentConfig":
@@ -183,36 +209,36 @@ else:
183
209
  return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
184
210
  if name == "MLPConfig":
185
211
  return importlib.import_module("nshtrainer.nn").MLPConfig
212
+ if name == "PReLUConfig":
213
+ return importlib.import_module("nshtrainer.nn").PReLUConfig
214
+ if name == "LeakyReLUNonlinearityConfig":
215
+ return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
186
216
  if name == "SwiGLUNonlinearityConfig":
187
217
  return importlib.import_module(
188
218
  "nshtrainer.nn.nonlinearity"
189
219
  ).SwiGLUNonlinearityConfig
190
- if name == "ReLUNonlinearityConfig":
191
- return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
220
+ if name == "SoftsignNonlinearityConfig":
221
+ return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
192
222
  if name == "SiLUNonlinearityConfig":
193
223
  return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
224
+ if name == "SigmoidNonlinearityConfig":
225
+ return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
226
+ if name == "SoftplusNonlinearityConfig":
227
+ return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
194
228
  if name == "ELUNonlinearityConfig":
195
229
  return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
230
+ if name == "SoftmaxNonlinearityConfig":
231
+ return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
196
232
  if name == "GELUNonlinearityConfig":
197
233
  return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
198
- if name == "SoftplusNonlinearityConfig":
199
- return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
200
- if name == "SoftsignNonlinearityConfig":
201
- return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
202
234
  if name == "SwishNonlinearityConfig":
203
235
  return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
204
- if name == "SoftmaxNonlinearityConfig":
205
- return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
206
236
  if name == "MishNonlinearityConfig":
207
237
  return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
208
- if name == "SigmoidNonlinearityConfig":
209
- return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
210
238
  if name == "TanhNonlinearityConfig":
211
239
  return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
212
- if name == "PReLUConfig":
213
- return importlib.import_module("nshtrainer.nn").PReLUConfig
214
- if name == "LeakyReLUNonlinearityConfig":
215
- return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
240
+ if name == "ReLUNonlinearityConfig":
241
+ return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
216
242
  if name == "LRSchedulerConfigBase":
217
243
  return importlib.import_module(
218
244
  "nshtrainer.lr_scheduler"
@@ -231,32 +257,40 @@ else:
231
257
  return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
232
258
  if name == "WandbLoggerConfig":
233
259
  return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
234
- if name == "WandbUploadCodeConfig":
235
- return importlib.import_module("nshtrainer.callbacks").WandbUploadCodeConfig
236
- if name == "WandbWatchConfig":
237
- return importlib.import_module("nshtrainer.callbacks").WandbWatchConfig
260
+ if name == "WandbUploadCodeCallbackConfig":
261
+ return importlib.import_module(
262
+ "nshtrainer.callbacks"
263
+ ).WandbUploadCodeCallbackConfig
264
+ if name == "WandbWatchCallbackConfig":
265
+ return importlib.import_module(
266
+ "nshtrainer.callbacks"
267
+ ).WandbWatchCallbackConfig
238
268
  if name == "CSVLoggerConfig":
239
269
  return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
240
- if name == "EnvironmentPackageConfig":
241
- return importlib.import_module(
242
- "nshtrainer.util._environment_info"
243
- ).EnvironmentPackageConfig
244
- if name == "EnvironmentSnapshotConfig":
270
+ if name == "EnvironmentLinuxEnvironmentConfig":
245
271
  return importlib.import_module(
246
272
  "nshtrainer.util._environment_info"
247
- ).EnvironmentSnapshotConfig
273
+ ).EnvironmentLinuxEnvironmentConfig
248
274
  if name == "EnvironmentLSFInformationConfig":
249
275
  return importlib.import_module(
250
276
  "nshtrainer.util._environment_info"
251
277
  ).EnvironmentLSFInformationConfig
252
- if name == "EnvironmentLinuxEnvironmentConfig":
278
+ if name == "EnvironmentGPUConfig":
253
279
  return importlib.import_module(
254
280
  "nshtrainer.util._environment_info"
255
- ).EnvironmentLinuxEnvironmentConfig
256
- if name == "EnvironmentSLURMInformationConfig":
281
+ ).EnvironmentGPUConfig
282
+ if name == "EnvironmentPackageConfig":
257
283
  return importlib.import_module(
258
284
  "nshtrainer.util._environment_info"
259
- ).EnvironmentSLURMInformationConfig
285
+ ).EnvironmentPackageConfig
286
+ if name == "EnvironmentHardwareConfig":
287
+ return importlib.import_module(
288
+ "nshtrainer.util._environment_info"
289
+ ).EnvironmentHardwareConfig
290
+ if name == "EnvironmentSnapshotConfig":
291
+ return importlib.import_module(
292
+ "nshtrainer.util._environment_info"
293
+ ).EnvironmentSnapshotConfig
260
294
  if name == "EnvironmentClassInformationConfig":
261
295
  return importlib.import_module(
262
296
  "nshtrainer.util._environment_info"
@@ -269,14 +303,10 @@ else:
269
303
  return importlib.import_module(
270
304
  "nshtrainer.util._environment_info"
271
305
  ).EnvironmentCUDAConfig
272
- if name == "EnvironmentGPUConfig":
273
- return importlib.import_module(
274
- "nshtrainer.util._environment_info"
275
- ).EnvironmentGPUConfig
276
- if name == "EnvironmentHardwareConfig":
306
+ if name == "EnvironmentSLURMInformationConfig":
277
307
  return importlib.import_module(
278
308
  "nshtrainer.util._environment_info"
279
- ).EnvironmentHardwareConfig
309
+ ).EnvironmentSLURMInformationConfig
280
310
  if name == "EpochsConfig":
281
311
  return importlib.import_module("nshtrainer.util.config").EpochsConfig
282
312
  if name == "StepsConfig":
@@ -287,52 +317,56 @@ else:
287
317
  return importlib.import_module(
288
318
  "nshtrainer.trainer._config"
289
319
  ).CheckpointLoadingConfig
290
- if name == "OptimizationConfig":
320
+ if name == "SanityCheckingConfig":
291
321
  return importlib.import_module(
292
322
  "nshtrainer.trainer._config"
293
- ).OptimizationConfig
323
+ ).SanityCheckingConfig
324
+ if name == "OnExceptionCheckpointCallbackConfig":
325
+ return importlib.import_module(
326
+ "nshtrainer.callbacks"
327
+ ).OnExceptionCheckpointCallbackConfig
294
328
  if name == "GradientClippingConfig":
295
329
  return importlib.import_module(
296
330
  "nshtrainer.trainer._config"
297
331
  ).GradientClippingConfig
298
- if name == "LastCheckpointCallbackConfig":
332
+ if name == "LoggingConfig":
333
+ return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
334
+ if name == "RLPSanityChecksCallbackConfig":
299
335
  return importlib.import_module(
300
336
  "nshtrainer.callbacks"
301
- ).LastCheckpointCallbackConfig
302
- if name == "OnExceptionCheckpointCallbackConfig":
337
+ ).RLPSanityChecksCallbackConfig
338
+ if name == "CheckpointSavingConfig":
303
339
  return importlib.import_module(
304
- "nshtrainer.callbacks"
305
- ).OnExceptionCheckpointCallbackConfig
306
- if name == "RLPSanityChecksConfig":
307
- return importlib.import_module("nshtrainer.callbacks").RLPSanityChecksConfig
308
- if name == "EarlyStoppingConfig":
309
- return importlib.import_module("nshtrainer.callbacks").EarlyStoppingConfig
340
+ "nshtrainer.trainer._config"
341
+ ).CheckpointSavingConfig
310
342
  if name == "DebugFlagCallbackConfig":
311
343
  return importlib.import_module(
312
344
  "nshtrainer.callbacks"
313
345
  ).DebugFlagCallbackConfig
314
- if name == "CheckpointSavingConfig":
346
+ if name == "LastCheckpointCallbackConfig":
315
347
  return importlib.import_module(
316
- "nshtrainer.trainer._config"
317
- ).CheckpointSavingConfig
318
- if name == "BestCheckpointCallbackConfig":
348
+ "nshtrainer.callbacks"
349
+ ).LastCheckpointCallbackConfig
350
+ if name == "SharedParametersCallbackConfig":
319
351
  return importlib.import_module(
320
352
  "nshtrainer.callbacks"
321
- ).BestCheckpointCallbackConfig
322
- if name == "LoggingConfig":
323
- return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
324
- if name == "SanityCheckingConfig":
353
+ ).SharedParametersCallbackConfig
354
+ if name == "ReproducibilityConfig":
325
355
  return importlib.import_module(
326
356
  "nshtrainer.trainer._config"
327
- ).SanityCheckingConfig
328
- if name == "SharedParametersConfig":
357
+ ).ReproducibilityConfig
358
+ if name == "EarlyStoppingCallbackConfig":
329
359
  return importlib.import_module(
330
360
  "nshtrainer.callbacks"
331
- ).SharedParametersConfig
332
- if name == "ReproducibilityConfig":
361
+ ).EarlyStoppingCallbackConfig
362
+ if name == "OptimizationConfig":
333
363
  return importlib.import_module(
334
364
  "nshtrainer.trainer._config"
335
- ).ReproducibilityConfig
365
+ ).OptimizationConfig
366
+ if name == "BestCheckpointCallbackConfig":
367
+ return importlib.import_module(
368
+ "nshtrainer.callbacks"
369
+ ).BestCheckpointCallbackConfig
336
370
  if name == "CheckpointMetadata":
337
371
  return importlib.import_module(
338
372
  "nshtrainer._checkpoint.loader"
@@ -349,28 +383,34 @@ else:
349
383
  return importlib.import_module(
350
384
  "nshtrainer._checkpoint.loader"
351
385
  ).UserProvidedPathCheckpointStrategyConfig
352
- if name == "PrintTableMetricsConfig":
386
+ if name == "PrintTableMetricsCallbackConfig":
353
387
  return importlib.import_module(
354
388
  "nshtrainer.callbacks"
355
- ).PrintTableMetricsConfig
389
+ ).PrintTableMetricsCallbackConfig
356
390
  if name == "ThroughputMonitorConfig":
357
391
  return importlib.import_module(
358
392
  "nshtrainer.callbacks"
359
393
  ).ThroughputMonitorConfig
360
- if name == "GradientSkippingConfig":
394
+ if name == "GradientSkippingCallbackConfig":
361
395
  return importlib.import_module(
362
396
  "nshtrainer.callbacks"
363
- ).GradientSkippingConfig
364
- if name == "EMAConfig":
365
- return importlib.import_module("nshtrainer.callbacks").EMAConfig
397
+ ).GradientSkippingCallbackConfig
398
+ if name == "EMACallbackConfig":
399
+ return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
366
400
  if name == "ActSaveConfig":
367
401
  return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
368
- if name == "FiniteChecksConfig":
369
- return importlib.import_module("nshtrainer.callbacks").FiniteChecksConfig
370
- if name == "NormLoggingConfig":
371
- return importlib.import_module("nshtrainer.callbacks").NormLoggingConfig
372
- if name == "EpochTimerConfig":
373
- return importlib.import_module("nshtrainer.callbacks").EpochTimerConfig
402
+ if name == "FiniteChecksCallbackConfig":
403
+ return importlib.import_module(
404
+ "nshtrainer.callbacks"
405
+ ).FiniteChecksCallbackConfig
406
+ if name == "NormLoggingCallbackConfig":
407
+ return importlib.import_module(
408
+ "nshtrainer.callbacks"
409
+ ).NormLoggingCallbackConfig
410
+ if name == "EpochTimerCallbackConfig":
411
+ return importlib.import_module(
412
+ "nshtrainer.callbacks"
413
+ ).EpochTimerCallbackConfig
374
414
  if name == "BaseCheckpointCallbackConfig":
375
415
  return importlib.import_module(
376
416
  "nshtrainer.callbacks.checkpoint._base"
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -29,24 +31,24 @@ else:
29
31
 
30
32
  if name in globals():
31
33
  return globals()[name]
32
- if name == "CheckpointLoadingConfig":
33
- return importlib.import_module(
34
- "nshtrainer._checkpoint.loader"
35
- ).CheckpointLoadingConfig
36
- if name == "CheckpointMetadata":
37
- return importlib.import_module(
38
- "nshtrainer._checkpoint.loader"
39
- ).CheckpointMetadata
40
34
  if name == "MetricConfig":
41
35
  return importlib.import_module("nshtrainer._checkpoint.loader").MetricConfig
42
36
  if name == "BestCheckpointStrategyConfig":
43
37
  return importlib.import_module(
44
38
  "nshtrainer._checkpoint.loader"
45
39
  ).BestCheckpointStrategyConfig
40
+ if name == "CheckpointMetadata":
41
+ return importlib.import_module(
42
+ "nshtrainer._checkpoint.loader"
43
+ ).CheckpointMetadata
46
44
  if name == "LastCheckpointStrategyConfig":
47
45
  return importlib.import_module(
48
46
  "nshtrainer._checkpoint.loader"
49
47
  ).LastCheckpointStrategyConfig
48
+ if name == "CheckpointLoadingConfig":
49
+ return importlib.import_module(
50
+ "nshtrainer._checkpoint.loader"
51
+ ).CheckpointLoadingConfig
50
52
  if name == "UserProvidedPathCheckpointStrategyConfig":
51
53
  return importlib.import_module(
52
54
  "nshtrainer._checkpoint.loader"
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -14,14 +16,14 @@ else:
14
16
 
15
17
  if name in globals():
16
18
  return globals()[name]
17
- if name == "EnvironmentConfig":
18
- return importlib.import_module(
19
- "nshtrainer._checkpoint.metadata"
20
- ).EnvironmentConfig
21
19
  if name == "CheckpointMetadata":
22
20
  return importlib.import_module(
23
21
  "nshtrainer._checkpoint.metadata"
24
22
  ).CheckpointMetadata
23
+ if name == "EnvironmentConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer._checkpoint.metadata"
26
+ ).EnvironmentConfig
25
27
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
26
28
 
27
29
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -6,7 +8,9 @@ from typing import TYPE_CHECKING
6
8
 
7
9
  if TYPE_CHECKING:
8
10
  from nshtrainer._directory import DirectoryConfig as DirectoryConfig
9
- from nshtrainer._directory import DirectorySetupConfig as DirectorySetupConfig
11
+ from nshtrainer._directory import (
12
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
13
+ )
10
14
  from nshtrainer._directory import LoggerConfig as LoggerConfig
11
15
  else:
12
16
 
@@ -15,10 +19,12 @@ else:
15
19
 
16
20
  if name in globals():
17
21
  return globals()[name]
22
+ if name == "DirectorySetupCallbackConfig":
23
+ return importlib.import_module(
24
+ "nshtrainer._directory"
25
+ ).DirectorySetupCallbackConfig
18
26
  if name == "DirectoryConfig":
19
27
  return importlib.import_module("nshtrainer._directory").DirectoryConfig
20
- if name == "DirectorySetupConfig":
21
- return importlib.import_module("nshtrainer._directory").DirectorySetupConfig
22
28
  if name == "LoggerConfig":
23
29
  return importlib.import_module("nshtrainer._directory").LoggerConfig
24
30
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -17,14 +19,14 @@ else:
17
19
 
18
20
  if name in globals():
19
21
  return globals()[name]
20
- if name == "CallbackConfigBase":
21
- return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
22
- if name == "HuggingFaceHubConfig":
23
- return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
24
22
  if name == "HuggingFaceHubAutoCreateConfig":
25
23
  return importlib.import_module(
26
24
  "nshtrainer._hf_hub"
27
25
  ).HuggingFaceHubAutoCreateConfig
26
+ if name == "HuggingFaceHubConfig":
27
+ return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
28
+ if name == "CallbackConfigBase":
29
+ return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
28
30
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
29
31
 
30
32
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -11,25 +13,47 @@ if TYPE_CHECKING:
11
13
  from nshtrainer.callbacks import CallbackConfig as CallbackConfig
12
14
  from nshtrainer.callbacks import CallbackConfigBase as CallbackConfigBase
13
15
  from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
14
- from nshtrainer.callbacks import DirectorySetupConfig as DirectorySetupConfig
15
- from nshtrainer.callbacks import EarlyStoppingConfig as EarlyStoppingConfig
16
- from nshtrainer.callbacks import EMAConfig as EMAConfig
17
- from nshtrainer.callbacks import EpochTimerConfig as EpochTimerConfig
18
- from nshtrainer.callbacks import FiniteChecksConfig as FiniteChecksConfig
19
- from nshtrainer.callbacks import GradientSkippingConfig as GradientSkippingConfig
16
+ from nshtrainer.callbacks import (
17
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
18
+ )
19
+ from nshtrainer.callbacks import (
20
+ EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
21
+ )
22
+ from nshtrainer.callbacks import EMACallbackConfig as EMACallbackConfig
23
+ from nshtrainer.callbacks import (
24
+ EpochTimerCallbackConfig as EpochTimerCallbackConfig,
25
+ )
26
+ from nshtrainer.callbacks import (
27
+ FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
28
+ )
29
+ from nshtrainer.callbacks import (
30
+ GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
31
+ )
20
32
  from nshtrainer.callbacks import (
21
33
  LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
22
34
  )
23
- from nshtrainer.callbacks import NormLoggingConfig as NormLoggingConfig
35
+ from nshtrainer.callbacks import (
36
+ NormLoggingCallbackConfig as NormLoggingCallbackConfig,
37
+ )
24
38
  from nshtrainer.callbacks import (
25
39
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
26
40
  )
27
- from nshtrainer.callbacks import PrintTableMetricsConfig as PrintTableMetricsConfig
28
- from nshtrainer.callbacks import RLPSanityChecksConfig as RLPSanityChecksConfig
29
- from nshtrainer.callbacks import SharedParametersConfig as SharedParametersConfig
41
+ from nshtrainer.callbacks import (
42
+ PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
43
+ )
44
+ from nshtrainer.callbacks import (
45
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
46
+ )
47
+ from nshtrainer.callbacks import (
48
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
49
+ )
30
50
  from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
31
- from nshtrainer.callbacks import WandbUploadCodeConfig as WandbUploadCodeConfig
32
- from nshtrainer.callbacks import WandbWatchConfig as WandbWatchConfig
51
+ from nshtrainer.callbacks import (
52
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
53
+ )
54
+ from nshtrainer.callbacks import (
55
+ WandbWatchCallbackConfig as WandbWatchCallbackConfig,
56
+ )
33
57
  from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
34
58
  from nshtrainer.callbacks.checkpoint._base import (
35
59
  BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
@@ -45,12 +69,12 @@ else:
45
69
 
46
70
  if name in globals():
47
71
  return globals()[name]
48
- if name == "CallbackConfigBase":
49
- return importlib.import_module("nshtrainer.callbacks").CallbackConfigBase
50
- if name == "PrintTableMetricsConfig":
72
+ if name == "PrintTableMetricsCallbackConfig":
51
73
  return importlib.import_module(
52
74
  "nshtrainer.callbacks"
53
- ).PrintTableMetricsConfig
75
+ ).PrintTableMetricsCallbackConfig
76
+ if name == "CallbackConfigBase":
77
+ return importlib.import_module("nshtrainer.callbacks").CallbackConfigBase
54
78
  if name == "DebugFlagCallbackConfig":
55
79
  return importlib.import_module(
56
80
  "nshtrainer.callbacks"
@@ -59,50 +83,66 @@ else:
59
83
  return importlib.import_module(
60
84
  "nshtrainer.callbacks"
61
85
  ).ThroughputMonitorConfig
62
- if name == "GradientSkippingConfig":
86
+ if name == "GradientSkippingCallbackConfig":
87
+ return importlib.import_module(
88
+ "nshtrainer.callbacks"
89
+ ).GradientSkippingCallbackConfig
90
+ if name == "RLPSanityChecksCallbackConfig":
63
91
  return importlib.import_module(
64
92
  "nshtrainer.callbacks"
65
- ).GradientSkippingConfig
66
- if name == "RLPSanityChecksConfig":
67
- return importlib.import_module("nshtrainer.callbacks").RLPSanityChecksConfig
68
- if name == "WandbUploadCodeConfig":
69
- return importlib.import_module("nshtrainer.callbacks").WandbUploadCodeConfig
93
+ ).RLPSanityChecksCallbackConfig
94
+ if name == "WandbUploadCodeCallbackConfig":
95
+ return importlib.import_module(
96
+ "nshtrainer.callbacks"
97
+ ).WandbUploadCodeCallbackConfig
70
98
  if name == "MetricConfig":
71
99
  return importlib.import_module(
72
100
  "nshtrainer.callbacks.early_stopping"
73
101
  ).MetricConfig
74
- if name == "EarlyStoppingConfig":
75
- return importlib.import_module("nshtrainer.callbacks").EarlyStoppingConfig
76
- if name == "WandbWatchConfig":
77
- return importlib.import_module("nshtrainer.callbacks").WandbWatchConfig
78
- if name == "EMAConfig":
79
- return importlib.import_module("nshtrainer.callbacks").EMAConfig
80
- if name == "DirectorySetupConfig":
81
- return importlib.import_module("nshtrainer.callbacks").DirectorySetupConfig
102
+ if name == "EarlyStoppingCallbackConfig":
103
+ return importlib.import_module(
104
+ "nshtrainer.callbacks"
105
+ ).EarlyStoppingCallbackConfig
106
+ if name == "WandbWatchCallbackConfig":
107
+ return importlib.import_module(
108
+ "nshtrainer.callbacks"
109
+ ).WandbWatchCallbackConfig
110
+ if name == "EMACallbackConfig":
111
+ return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
112
+ if name == "DirectorySetupCallbackConfig":
113
+ return importlib.import_module(
114
+ "nshtrainer.callbacks"
115
+ ).DirectorySetupCallbackConfig
82
116
  if name == "ActSaveConfig":
83
117
  return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
84
- if name == "FiniteChecksConfig":
85
- return importlib.import_module("nshtrainer.callbacks").FiniteChecksConfig
86
- if name == "NormLoggingConfig":
87
- return importlib.import_module("nshtrainer.callbacks").NormLoggingConfig
88
- if name == "LastCheckpointCallbackConfig":
118
+ if name == "FiniteChecksCallbackConfig":
89
119
  return importlib.import_module(
90
120
  "nshtrainer.callbacks"
91
- ).LastCheckpointCallbackConfig
92
- if name == "BestCheckpointCallbackConfig":
121
+ ).FiniteChecksCallbackConfig
122
+ if name == "NormLoggingCallbackConfig":
93
123
  return importlib.import_module(
94
124
  "nshtrainer.callbacks"
95
- ).BestCheckpointCallbackConfig
125
+ ).NormLoggingCallbackConfig
126
+ if name == "EpochTimerCallbackConfig":
127
+ return importlib.import_module(
128
+ "nshtrainer.callbacks"
129
+ ).EpochTimerCallbackConfig
96
130
  if name == "OnExceptionCheckpointCallbackConfig":
97
131
  return importlib.import_module(
98
132
  "nshtrainer.callbacks"
99
133
  ).OnExceptionCheckpointCallbackConfig
100
- if name == "EpochTimerConfig":
101
- return importlib.import_module("nshtrainer.callbacks").EpochTimerConfig
102
- if name == "SharedParametersConfig":
134
+ if name == "LastCheckpointCallbackConfig":
135
+ return importlib.import_module(
136
+ "nshtrainer.callbacks"
137
+ ).LastCheckpointCallbackConfig
138
+ if name == "SharedParametersCallbackConfig":
139
+ return importlib.import_module(
140
+ "nshtrainer.callbacks"
141
+ ).SharedParametersCallbackConfig
142
+ if name == "BestCheckpointCallbackConfig":
103
143
  return importlib.import_module(
104
144
  "nshtrainer.callbacks"
105
- ).SharedParametersConfig
145
+ ).BestCheckpointCallbackConfig
106
146
  if name == "CheckpointMetadata":
107
147
  return importlib.import_module(
108
148
  "nshtrainer.callbacks.checkpoint._base"