nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. nshtrainer/__init__.py +6 -3
  2. nshtrainer/_callback.py +297 -2
  3. nshtrainer/_checkpoint/loader.py +23 -30
  4. nshtrainer/_checkpoint/metadata.py +22 -18
  5. nshtrainer/_experimental/__init__.py +0 -2
  6. nshtrainer/_hf_hub.py +25 -26
  7. nshtrainer/callbacks/__init__.py +1 -3
  8. nshtrainer/callbacks/actsave.py +22 -20
  9. nshtrainer/callbacks/base.py +7 -7
  10. nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  11. nshtrainer/callbacks/checkpoint/_base.py +8 -5
  12. nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  13. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  14. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  15. nshtrainer/callbacks/debug_flag.py +14 -19
  16. nshtrainer/callbacks/directory_setup.py +6 -11
  17. nshtrainer/callbacks/early_stopping.py +3 -3
  18. nshtrainer/callbacks/ema.py +1 -1
  19. nshtrainer/callbacks/finite_checks.py +1 -1
  20. nshtrainer/callbacks/gradient_skipping.py +1 -1
  21. nshtrainer/callbacks/log_epoch.py +1 -1
  22. nshtrainer/callbacks/norm_logging.py +1 -1
  23. nshtrainer/callbacks/print_table.py +1 -1
  24. nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  25. nshtrainer/callbacks/shared_parameters.py +1 -1
  26. nshtrainer/callbacks/timer.py +1 -1
  27. nshtrainer/callbacks/wandb_upload_code.py +1 -1
  28. nshtrainer/callbacks/wandb_watch.py +1 -1
  29. nshtrainer/config/__init__.py +189 -189
  30. nshtrainer/config/_checkpoint/__init__.py +70 -0
  31. nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  32. nshtrainer/config/_directory/__init__.py +2 -2
  33. nshtrainer/config/_hf_hub/__init__.py +2 -2
  34. nshtrainer/config/callbacks/__init__.py +44 -44
  35. nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  36. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  37. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  38. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  39. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  40. nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  41. nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  42. nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  43. nshtrainer/config/callbacks/ema/__init__.py +2 -2
  44. nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  45. nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  46. nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
  47. nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  48. nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  49. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  50. nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  51. nshtrainer/config/callbacks/timer/__init__.py +4 -4
  52. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  53. nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  54. nshtrainer/config/loggers/__init__.py +10 -6
  55. nshtrainer/config/loggers/actsave/__init__.py +29 -0
  56. nshtrainer/config/loggers/csv/__init__.py +2 -2
  57. nshtrainer/config/loggers/wandb/__init__.py +6 -6
  58. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  59. nshtrainer/config/nn/__init__.py +18 -18
  60. nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  61. nshtrainer/config/optimizer/__init__.py +2 -2
  62. nshtrainer/config/profiler/__init__.py +2 -2
  63. nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  64. nshtrainer/config/profiler/simple/__init__.py +4 -4
  65. nshtrainer/config/trainer/__init__.py +180 -0
  66. nshtrainer/config/trainer/_config/__init__.py +59 -36
  67. nshtrainer/config/trainer/trainer/__init__.py +27 -0
  68. nshtrainer/config/util/__init__.py +109 -0
  69. nshtrainer/config/util/_environment_info/__init__.py +20 -20
  70. nshtrainer/config/util/config/__init__.py +2 -2
  71. nshtrainer/data/datamodule.py +52 -2
  72. nshtrainer/loggers/__init__.py +2 -1
  73. nshtrainer/loggers/_base.py +5 -2
  74. nshtrainer/loggers/actsave.py +59 -0
  75. nshtrainer/loggers/csv.py +5 -5
  76. nshtrainer/loggers/tensorboard.py +5 -5
  77. nshtrainer/loggers/wandb.py +17 -16
  78. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  79. nshtrainer/model/__init__.py +0 -4
  80. nshtrainer/model/base.py +64 -347
  81. nshtrainer/model/mixins/callback.py +24 -5
  82. nshtrainer/model/mixins/debug.py +86 -0
  83. nshtrainer/model/mixins/logger.py +142 -145
  84. nshtrainer/profiler/_base.py +2 -2
  85. nshtrainer/profiler/advanced.py +4 -4
  86. nshtrainer/profiler/pytorch.py +4 -4
  87. nshtrainer/profiler/simple.py +4 -4
  88. nshtrainer/trainer/__init__.py +1 -0
  89. nshtrainer/trainer/_config.py +164 -17
  90. nshtrainer/trainer/checkpoint_connector.py +23 -8
  91. nshtrainer/trainer/trainer.py +194 -76
  92. nshtrainer/util/_environment_info.py +21 -13
  93. nshtrainer/util/config/dtype.py +4 -4
  94. nshtrainer/util/typing_utils.py +1 -1
  95. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
  96. nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
  97. nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  98. nshtrainer/callbacks/throughput_monitor.py +0 -58
  99. nshtrainer/config/model/__init__.py +0 -41
  100. nshtrainer/config/model/base/__init__.py +0 -25
  101. nshtrainer/config/model/config/__init__.py +0 -37
  102. nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  103. nshtrainer/config/runner/__init__.py +0 -22
  104. nshtrainer/ll/__init__.py +0 -59
  105. nshtrainer/ll/_experimental.py +0 -3
  106. nshtrainer/ll/actsave.py +0 -6
  107. nshtrainer/ll/callbacks.py +0 -3
  108. nshtrainer/ll/config.py +0 -6
  109. nshtrainer/ll/data.py +0 -3
  110. nshtrainer/ll/log.py +0 -5
  111. nshtrainer/ll/lr_scheduler.py +0 -3
  112. nshtrainer/ll/model.py +0 -21
  113. nshtrainer/ll/nn.py +0 -3
  114. nshtrainer/ll/optimizer.py +0 -3
  115. nshtrainer/ll/runner.py +0 -5
  116. nshtrainer/ll/snapshot.py +0 -3
  117. nshtrainer/ll/snoop.py +0 -3
  118. nshtrainer/ll/trainer.py +0 -3
  119. nshtrainer/ll/typecheck.py +0 -3
  120. nshtrainer/ll/util.py +0 -3
  121. nshtrainer/model/config.py +0 -218
  122. nshtrainer/runner.py +0 -101
  123. nshtrainer-0.44.1.dist-info/RECORD +0 -162
  124. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
@@ -26,7 +26,7 @@ class WandbUploadCodeCallbackConfig(CallbackConfigBase):
26
26
  return self.enabled
27
27
 
28
28
  @override
29
- def create_callbacks(self, root_config):
29
+ def create_callbacks(self, trainer_config):
30
30
  if not self:
31
31
  return
32
32
 
@@ -33,7 +33,7 @@ class WandbWatchCallbackConfig(CallbackConfigBase):
33
33
  return self.enabled
34
34
 
35
35
  @override
36
- def create_callbacks(self, root_config):
36
+ def create_callbacks(self, trainer_config):
37
37
  yield WandbWatchCallback(self)
38
38
 
39
39
 
@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING
7
7
  # Config/alias imports
8
8
 
9
9
  if TYPE_CHECKING:
10
- from nshtrainer import BaseConfig as BaseConfig
11
10
  from nshtrainer import MetricConfig as MetricConfig
11
+ from nshtrainer import TrainerConfig as TrainerConfig
12
12
  from nshtrainer._checkpoint.loader import (
13
13
  BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
14
14
  )
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
22
22
  from nshtrainer._checkpoint.loader import (
23
23
  UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
24
24
  )
25
+ from nshtrainer._directory import DirectoryConfig as DirectoryConfig
25
26
  from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
26
27
  from nshtrainer._hf_hub import (
27
28
  HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
51
52
  from nshtrainer.callbacks import (
52
53
  LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
53
54
  )
55
+ from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
54
56
  from nshtrainer.callbacks import (
55
57
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
56
58
  )
@@ -66,7 +68,6 @@ if TYPE_CHECKING:
66
68
  from nshtrainer.callbacks import (
67
69
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
68
70
  )
69
- from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
70
71
  from nshtrainer.callbacks import (
71
72
  WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
72
73
  )
@@ -77,6 +78,7 @@ if TYPE_CHECKING:
77
78
  from nshtrainer.callbacks.checkpoint._base import (
78
79
  BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
79
80
  )
81
+ from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
80
82
  from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
81
83
  from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
82
84
  from nshtrainer.loggers import LoggerConfig as LoggerConfig
@@ -90,9 +92,6 @@ if TYPE_CHECKING:
90
92
  from nshtrainer.lr_scheduler import (
91
93
  ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
92
94
  )
93
- from nshtrainer.model import DirectoryConfig as DirectoryConfig
94
- from nshtrainer.model import TrainerConfig as TrainerConfig
95
- from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
96
95
  from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
97
96
  from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
98
97
  from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
@@ -129,6 +128,7 @@ if TYPE_CHECKING:
129
128
  from nshtrainer.trainer._config import (
130
129
  CheckpointSavingConfig as CheckpointSavingConfig,
131
130
  )
131
+ from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
132
132
  from nshtrainer.trainer._config import (
133
133
  GradientClippingConfig as GradientClippingConfig,
134
134
  )
@@ -179,272 +179,274 @@ else:
179
179
 
180
180
  if name in globals():
181
181
  return globals()[name]
182
- if name == "MetricConfig":
183
- return importlib.import_module("nshtrainer").MetricConfig
184
- if name == "BaseConfig":
185
- return importlib.import_module("nshtrainer").BaseConfig
186
- if name == "HuggingFaceHubAutoCreateConfig":
187
- return importlib.import_module(
188
- "nshtrainer._hf_hub"
189
- ).HuggingFaceHubAutoCreateConfig
190
- if name == "HuggingFaceHubConfig":
191
- return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
192
- if name == "CallbackConfigBase":
193
- return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
194
- if name == "OptimizerConfigBase":
195
- return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
182
+ if name == "ActSaveConfig":
183
+ return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
184
+ if name == "ActSaveLoggerConfig":
185
+ return importlib.import_module("nshtrainer.loggers").ActSaveLoggerConfig
196
186
  if name == "AdamWConfig":
197
187
  return importlib.import_module("nshtrainer.optimizer").AdamWConfig
198
- if name == "DirectorySetupCallbackConfig":
188
+ if name == "AdvancedProfilerConfig":
189
+ return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
190
+ if name == "BaseCheckpointCallbackConfig":
199
191
  return importlib.import_module(
200
- "nshtrainer.callbacks"
201
- ).DirectorySetupCallbackConfig
202
- if name == "DirectoryConfig":
203
- return importlib.import_module("nshtrainer.model").DirectoryConfig
204
- if name == "TrainerConfig":
205
- return importlib.import_module("nshtrainer.model").TrainerConfig
206
- if name == "EnvironmentConfig":
207
- return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
192
+ "nshtrainer.callbacks.checkpoint._base"
193
+ ).BaseCheckpointCallbackConfig
194
+ if name == "BaseLoggerConfig":
195
+ return importlib.import_module("nshtrainer.loggers").BaseLoggerConfig
208
196
  if name == "BaseNonlinearityConfig":
209
197
  return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
210
- if name == "MLPConfig":
211
- return importlib.import_module("nshtrainer.nn").MLPConfig
212
- if name == "PReLUConfig":
213
- return importlib.import_module("nshtrainer.nn").PReLUConfig
214
- if name == "LeakyReLUNonlinearityConfig":
215
- return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
216
- if name == "SwiGLUNonlinearityConfig":
198
+ if name == "BaseProfilerConfig":
199
+ return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
200
+ if name == "BestCheckpointCallbackConfig":
217
201
  return importlib.import_module(
218
- "nshtrainer.nn.nonlinearity"
219
- ).SwiGLUNonlinearityConfig
220
- if name == "SoftsignNonlinearityConfig":
221
- return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
222
- if name == "SiLUNonlinearityConfig":
223
- return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
224
- if name == "SigmoidNonlinearityConfig":
225
- return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
226
- if name == "SoftplusNonlinearityConfig":
227
- return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
228
- if name == "ELUNonlinearityConfig":
229
- return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
230
- if name == "SoftmaxNonlinearityConfig":
231
- return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
232
- if name == "GELUNonlinearityConfig":
233
- return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
234
- if name == "SwishNonlinearityConfig":
235
- return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
236
- if name == "MishNonlinearityConfig":
237
- return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
238
- if name == "TanhNonlinearityConfig":
239
- return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
240
- if name == "ReLUNonlinearityConfig":
241
- return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
242
- if name == "LRSchedulerConfigBase":
202
+ "nshtrainer.callbacks"
203
+ ).BestCheckpointCallbackConfig
204
+ if name == "BestCheckpointStrategyConfig":
243
205
  return importlib.import_module(
244
- "nshtrainer.lr_scheduler"
245
- ).LRSchedulerConfigBase
246
- if name == "LinearWarmupCosineDecayLRSchedulerConfig":
206
+ "nshtrainer._checkpoint.loader"
207
+ ).BestCheckpointStrategyConfig
208
+ if name == "CSVLoggerConfig":
209
+ return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
210
+ if name == "CallbackConfigBase":
211
+ return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
212
+ if name == "CheckpointLoadingConfig":
247
213
  return importlib.import_module(
248
- "nshtrainer.lr_scheduler"
249
- ).LinearWarmupCosineDecayLRSchedulerConfig
250
- if name == "ReduceLROnPlateauConfig":
214
+ "nshtrainer.trainer._config"
215
+ ).CheckpointLoadingConfig
216
+ if name == "CheckpointMetadata":
251
217
  return importlib.import_module(
252
- "nshtrainer.lr_scheduler"
253
- ).ReduceLROnPlateauConfig
254
- if name == "BaseLoggerConfig":
255
- return importlib.import_module("nshtrainer.loggers").BaseLoggerConfig
256
- if name == "TensorboardLoggerConfig":
257
- return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
258
- if name == "WandbLoggerConfig":
259
- return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
260
- if name == "WandbUploadCodeCallbackConfig":
218
+ "nshtrainer._checkpoint.loader"
219
+ ).CheckpointMetadata
220
+ if name == "CheckpointSavingConfig":
221
+ return importlib.import_module(
222
+ "nshtrainer.trainer._config"
223
+ ).CheckpointSavingConfig
224
+ if name == "DTypeConfig":
225
+ return importlib.import_module("nshtrainer.util.config").DTypeConfig
226
+ if name == "DebugFlagCallbackConfig":
261
227
  return importlib.import_module(
262
228
  "nshtrainer.callbacks"
263
- ).WandbUploadCodeCallbackConfig
264
- if name == "WandbWatchCallbackConfig":
229
+ ).DebugFlagCallbackConfig
230
+ if name == "DirectoryConfig":
231
+ return importlib.import_module("nshtrainer._directory").DirectoryConfig
232
+ if name == "DirectorySetupCallbackConfig":
265
233
  return importlib.import_module(
266
234
  "nshtrainer.callbacks"
267
- ).WandbWatchCallbackConfig
268
- if name == "CSVLoggerConfig":
269
- return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
270
- if name == "EnvironmentLinuxEnvironmentConfig":
235
+ ).DirectorySetupCallbackConfig
236
+ if name == "ELUNonlinearityConfig":
237
+ return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
238
+ if name == "EMACallbackConfig":
239
+ return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
240
+ if name == "EarlyStoppingCallbackConfig":
241
+ return importlib.import_module(
242
+ "nshtrainer.callbacks"
243
+ ).EarlyStoppingCallbackConfig
244
+ if name == "EnvironmentCUDAConfig":
271
245
  return importlib.import_module(
272
246
  "nshtrainer.util._environment_info"
273
- ).EnvironmentLinuxEnvironmentConfig
274
- if name == "EnvironmentLSFInformationConfig":
247
+ ).EnvironmentCUDAConfig
248
+ if name == "EnvironmentClassInformationConfig":
275
249
  return importlib.import_module(
276
250
  "nshtrainer.util._environment_info"
277
- ).EnvironmentLSFInformationConfig
251
+ ).EnvironmentClassInformationConfig
252
+ if name == "EnvironmentConfig":
253
+ return importlib.import_module(
254
+ "nshtrainer.trainer._config"
255
+ ).EnvironmentConfig
278
256
  if name == "EnvironmentGPUConfig":
279
257
  return importlib.import_module(
280
258
  "nshtrainer.util._environment_info"
281
259
  ).EnvironmentGPUConfig
282
- if name == "EnvironmentPackageConfig":
283
- return importlib.import_module(
284
- "nshtrainer.util._environment_info"
285
- ).EnvironmentPackageConfig
286
260
  if name == "EnvironmentHardwareConfig":
287
261
  return importlib.import_module(
288
262
  "nshtrainer.util._environment_info"
289
263
  ).EnvironmentHardwareConfig
290
- if name == "EnvironmentSnapshotConfig":
291
- return importlib.import_module(
292
- "nshtrainer.util._environment_info"
293
- ).EnvironmentSnapshotConfig
294
- if name == "EnvironmentClassInformationConfig":
264
+ if name == "EnvironmentLSFInformationConfig":
295
265
  return importlib.import_module(
296
266
  "nshtrainer.util._environment_info"
297
- ).EnvironmentClassInformationConfig
298
- if name == "GitRepositoryConfig":
267
+ ).EnvironmentLSFInformationConfig
268
+ if name == "EnvironmentLinuxEnvironmentConfig":
299
269
  return importlib.import_module(
300
270
  "nshtrainer.util._environment_info"
301
- ).GitRepositoryConfig
302
- if name == "EnvironmentCUDAConfig":
271
+ ).EnvironmentLinuxEnvironmentConfig
272
+ if name == "EnvironmentPackageConfig":
303
273
  return importlib.import_module(
304
274
  "nshtrainer.util._environment_info"
305
- ).EnvironmentCUDAConfig
275
+ ).EnvironmentPackageConfig
306
276
  if name == "EnvironmentSLURMInformationConfig":
307
277
  return importlib.import_module(
308
278
  "nshtrainer.util._environment_info"
309
279
  ).EnvironmentSLURMInformationConfig
310
- if name == "EpochsConfig":
311
- return importlib.import_module("nshtrainer.util.config").EpochsConfig
312
- if name == "StepsConfig":
313
- return importlib.import_module("nshtrainer.util.config").StepsConfig
314
- if name == "DTypeConfig":
315
- return importlib.import_module("nshtrainer.util.config").DTypeConfig
316
- if name == "CheckpointLoadingConfig":
280
+ if name == "EnvironmentSnapshotConfig":
317
281
  return importlib.import_module(
318
- "nshtrainer.trainer._config"
319
- ).CheckpointLoadingConfig
320
- if name == "SanityCheckingConfig":
282
+ "nshtrainer.util._environment_info"
283
+ ).EnvironmentSnapshotConfig
284
+ if name == "EpochTimerCallbackConfig":
321
285
  return importlib.import_module(
322
- "nshtrainer.trainer._config"
323
- ).SanityCheckingConfig
324
- if name == "OnExceptionCheckpointCallbackConfig":
286
+ "nshtrainer.callbacks"
287
+ ).EpochTimerCallbackConfig
288
+ if name == "EpochsConfig":
289
+ return importlib.import_module("nshtrainer.util.config").EpochsConfig
290
+ if name == "FiniteChecksCallbackConfig":
325
291
  return importlib.import_module(
326
292
  "nshtrainer.callbacks"
327
- ).OnExceptionCheckpointCallbackConfig
293
+ ).FiniteChecksCallbackConfig
294
+ if name == "GELUNonlinearityConfig":
295
+ return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
296
+ if name == "GitRepositoryConfig":
297
+ return importlib.import_module(
298
+ "nshtrainer.util._environment_info"
299
+ ).GitRepositoryConfig
328
300
  if name == "GradientClippingConfig":
329
301
  return importlib.import_module(
330
302
  "nshtrainer.trainer._config"
331
303
  ).GradientClippingConfig
332
- if name == "LoggingConfig":
333
- return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
334
- if name == "RLPSanityChecksCallbackConfig":
304
+ if name == "GradientSkippingCallbackConfig":
335
305
  return importlib.import_module(
336
306
  "nshtrainer.callbacks"
337
- ).RLPSanityChecksCallbackConfig
338
- if name == "CheckpointSavingConfig":
307
+ ).GradientSkippingCallbackConfig
308
+ if name == "HuggingFaceHubAutoCreateConfig":
339
309
  return importlib.import_module(
340
- "nshtrainer.trainer._config"
341
- ).CheckpointSavingConfig
342
- if name == "DebugFlagCallbackConfig":
310
+ "nshtrainer._hf_hub"
311
+ ).HuggingFaceHubAutoCreateConfig
312
+ if name == "HuggingFaceHubConfig":
313
+ return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
314
+ if name == "LRSchedulerConfigBase":
343
315
  return importlib.import_module(
344
- "nshtrainer.callbacks"
345
- ).DebugFlagCallbackConfig
316
+ "nshtrainer.lr_scheduler"
317
+ ).LRSchedulerConfigBase
346
318
  if name == "LastCheckpointCallbackConfig":
347
319
  return importlib.import_module(
348
320
  "nshtrainer.callbacks"
349
321
  ).LastCheckpointCallbackConfig
350
- if name == "SharedParametersCallbackConfig":
322
+ if name == "LastCheckpointStrategyConfig":
323
+ return importlib.import_module(
324
+ "nshtrainer._checkpoint.loader"
325
+ ).LastCheckpointStrategyConfig
326
+ if name == "LeakyReLUNonlinearityConfig":
327
+ return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
328
+ if name == "LinearWarmupCosineDecayLRSchedulerConfig":
329
+ return importlib.import_module(
330
+ "nshtrainer.lr_scheduler"
331
+ ).LinearWarmupCosineDecayLRSchedulerConfig
332
+ if name == "LogEpochCallbackConfig":
351
333
  return importlib.import_module(
352
334
  "nshtrainer.callbacks"
353
- ).SharedParametersCallbackConfig
354
- if name == "ReproducibilityConfig":
335
+ ).LogEpochCallbackConfig
336
+ if name == "LoggingConfig":
337
+ return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
338
+ if name == "MLPConfig":
339
+ return importlib.import_module("nshtrainer.nn").MLPConfig
340
+ if name == "MetricConfig":
341
+ return importlib.import_module("nshtrainer").MetricConfig
342
+ if name == "MishNonlinearityConfig":
343
+ return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
344
+ if name == "NormLoggingCallbackConfig":
355
345
  return importlib.import_module(
356
- "nshtrainer.trainer._config"
357
- ).ReproducibilityConfig
358
- if name == "EarlyStoppingCallbackConfig":
346
+ "nshtrainer.callbacks"
347
+ ).NormLoggingCallbackConfig
348
+ if name == "OnExceptionCheckpointCallbackConfig":
359
349
  return importlib.import_module(
360
350
  "nshtrainer.callbacks"
361
- ).EarlyStoppingCallbackConfig
351
+ ).OnExceptionCheckpointCallbackConfig
362
352
  if name == "OptimizationConfig":
363
353
  return importlib.import_module(
364
354
  "nshtrainer.trainer._config"
365
355
  ).OptimizationConfig
366
- if name == "BestCheckpointCallbackConfig":
356
+ if name == "OptimizerConfigBase":
357
+ return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
358
+ if name == "PReLUConfig":
359
+ return importlib.import_module("nshtrainer.nn").PReLUConfig
360
+ if name == "PrintTableMetricsCallbackConfig":
367
361
  return importlib.import_module(
368
362
  "nshtrainer.callbacks"
369
- ).BestCheckpointCallbackConfig
370
- if name == "CheckpointMetadata":
371
- return importlib.import_module(
372
- "nshtrainer._checkpoint.loader"
373
- ).CheckpointMetadata
374
- if name == "BestCheckpointStrategyConfig":
363
+ ).PrintTableMetricsCallbackConfig
364
+ if name == "PyTorchProfilerConfig":
365
+ return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
366
+ if name == "RLPSanityChecksCallbackConfig":
375
367
  return importlib.import_module(
376
- "nshtrainer._checkpoint.loader"
377
- ).BestCheckpointStrategyConfig
378
- if name == "LastCheckpointStrategyConfig":
368
+ "nshtrainer.callbacks"
369
+ ).RLPSanityChecksCallbackConfig
370
+ if name == "ReLUNonlinearityConfig":
371
+ return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
372
+ if name == "ReduceLROnPlateauConfig":
379
373
  return importlib.import_module(
380
- "nshtrainer._checkpoint.loader"
381
- ).LastCheckpointStrategyConfig
382
- if name == "UserProvidedPathCheckpointStrategyConfig":
374
+ "nshtrainer.lr_scheduler"
375
+ ).ReduceLROnPlateauConfig
376
+ if name == "ReproducibilityConfig":
383
377
  return importlib.import_module(
384
- "nshtrainer._checkpoint.loader"
385
- ).UserProvidedPathCheckpointStrategyConfig
386
- if name == "PrintTableMetricsCallbackConfig":
378
+ "nshtrainer.trainer._config"
379
+ ).ReproducibilityConfig
380
+ if name == "SanityCheckingConfig":
387
381
  return importlib.import_module(
388
- "nshtrainer.callbacks"
389
- ).PrintTableMetricsCallbackConfig
390
- if name == "ThroughputMonitorConfig":
382
+ "nshtrainer.trainer._config"
383
+ ).SanityCheckingConfig
384
+ if name == "SharedParametersCallbackConfig":
391
385
  return importlib.import_module(
392
386
  "nshtrainer.callbacks"
393
- ).ThroughputMonitorConfig
394
- if name == "GradientSkippingCallbackConfig":
387
+ ).SharedParametersCallbackConfig
388
+ if name == "SiLUNonlinearityConfig":
389
+ return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
390
+ if name == "SigmoidNonlinearityConfig":
391
+ return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
392
+ if name == "SimpleProfilerConfig":
393
+ return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
394
+ if name == "SoftmaxNonlinearityConfig":
395
+ return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
396
+ if name == "SoftplusNonlinearityConfig":
397
+ return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
398
+ if name == "SoftsignNonlinearityConfig":
399
+ return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
400
+ if name == "StepsConfig":
401
+ return importlib.import_module("nshtrainer.util.config").StepsConfig
402
+ if name == "SwiGLUNonlinearityConfig":
395
403
  return importlib.import_module(
396
- "nshtrainer.callbacks"
397
- ).GradientSkippingCallbackConfig
398
- if name == "EMACallbackConfig":
399
- return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
400
- if name == "ActSaveConfig":
401
- return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
402
- if name == "FiniteChecksCallbackConfig":
404
+ "nshtrainer.nn.nonlinearity"
405
+ ).SwiGLUNonlinearityConfig
406
+ if name == "SwishNonlinearityConfig":
407
+ return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
408
+ if name == "TanhNonlinearityConfig":
409
+ return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
410
+ if name == "TensorboardLoggerConfig":
411
+ return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
412
+ if name == "TrainerConfig":
413
+ return importlib.import_module("nshtrainer").TrainerConfig
414
+ if name == "UserProvidedPathCheckpointStrategyConfig":
403
415
  return importlib.import_module(
404
- "nshtrainer.callbacks"
405
- ).FiniteChecksCallbackConfig
406
- if name == "NormLoggingCallbackConfig":
416
+ "nshtrainer._checkpoint.loader"
417
+ ).UserProvidedPathCheckpointStrategyConfig
418
+ if name == "WandbLoggerConfig":
419
+ return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
420
+ if name == "WandbUploadCodeCallbackConfig":
407
421
  return importlib.import_module(
408
422
  "nshtrainer.callbacks"
409
- ).NormLoggingCallbackConfig
410
- if name == "EpochTimerCallbackConfig":
423
+ ).WandbUploadCodeCallbackConfig
424
+ if name == "WandbWatchCallbackConfig":
411
425
  return importlib.import_module(
412
426
  "nshtrainer.callbacks"
413
- ).EpochTimerCallbackConfig
414
- if name == "BaseCheckpointCallbackConfig":
415
- return importlib.import_module(
416
- "nshtrainer.callbacks.checkpoint._base"
417
- ).BaseCheckpointCallbackConfig
418
- if name == "BaseProfilerConfig":
419
- return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
420
- if name == "PyTorchProfilerConfig":
421
- return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
422
- if name == "AdvancedProfilerConfig":
423
- return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
424
- if name == "SimpleProfilerConfig":
425
- return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
426
- if name == "OptimizerConfig":
427
- return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
428
- if name == "LoggerConfig":
429
- return importlib.import_module("nshtrainer.loggers").LoggerConfig
430
- if name == "NonlinearityConfig":
431
- return importlib.import_module("nshtrainer.nn").NonlinearityConfig
432
- if name == "DurationConfig":
433
- return importlib.import_module("nshtrainer.util.config").DurationConfig
434
- if name == "LRSchedulerConfig":
435
- return importlib.import_module("nshtrainer.lr_scheduler").LRSchedulerConfig
427
+ ).WandbWatchCallbackConfig
436
428
  if name == "CallbackConfig":
437
429
  return importlib.import_module("nshtrainer.callbacks").CallbackConfig
438
430
  if name == "CheckpointCallbackConfig":
439
431
  return importlib.import_module(
440
432
  "nshtrainer.trainer._config"
441
433
  ).CheckpointCallbackConfig
442
- if name == "ProfilerConfig":
443
- return importlib.import_module("nshtrainer.profiler").ProfilerConfig
444
434
  if name == "CheckpointLoadingStrategyConfig":
445
435
  return importlib.import_module(
446
436
  "nshtrainer._checkpoint.loader"
447
437
  ).CheckpointLoadingStrategyConfig
438
+ if name == "DurationConfig":
439
+ return importlib.import_module("nshtrainer.util.config").DurationConfig
440
+ if name == "LRSchedulerConfig":
441
+ return importlib.import_module("nshtrainer.lr_scheduler").LRSchedulerConfig
442
+ if name == "LoggerConfig":
443
+ return importlib.import_module("nshtrainer.loggers").LoggerConfig
444
+ if name == "NonlinearityConfig":
445
+ return importlib.import_module("nshtrainer.nn").NonlinearityConfig
446
+ if name == "OptimizerConfig":
447
+ return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
448
+ if name == "ProfilerConfig":
449
+ return importlib.import_module("nshtrainer.profiler").ProfilerConfig
448
450
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
449
451
 
450
452
 
@@ -456,10 +458,8 @@ from . import callbacks as callbacks
456
458
  from . import loggers as loggers
457
459
  from . import lr_scheduler as lr_scheduler
458
460
  from . import metrics as metrics
459
- from . import model as model
460
461
  from . import nn as nn
461
462
  from . import optimizer as optimizer
462
463
  from . import profiler as profiler
463
- from . import runner as runner
464
464
  from . import trainer as trainer
465
465
  from . import util as util
@@ -0,0 +1,70 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer._checkpoint.loader import (
11
+ BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
12
+ )
13
+ from nshtrainer._checkpoint.loader import (
14
+ CheckpointLoadingConfig as CheckpointLoadingConfig,
15
+ )
16
+ from nshtrainer._checkpoint.loader import (
17
+ CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
18
+ )
19
+ from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
20
+ from nshtrainer._checkpoint.loader import (
21
+ LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
22
+ )
23
+ from nshtrainer._checkpoint.loader import MetricConfig as MetricConfig
24
+ from nshtrainer._checkpoint.loader import (
25
+ UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
26
+ )
27
+ from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
28
+ else:
29
+
30
+ def __getattr__(name):
31
+ import importlib
32
+
33
+ if name in globals():
34
+ return globals()[name]
35
+ if name == "BestCheckpointStrategyConfig":
36
+ return importlib.import_module(
37
+ "nshtrainer._checkpoint.loader"
38
+ ).BestCheckpointStrategyConfig
39
+ if name == "CheckpointLoadingConfig":
40
+ return importlib.import_module(
41
+ "nshtrainer._checkpoint.loader"
42
+ ).CheckpointLoadingConfig
43
+ if name == "CheckpointMetadata":
44
+ return importlib.import_module(
45
+ "nshtrainer._checkpoint.loader"
46
+ ).CheckpointMetadata
47
+ if name == "EnvironmentConfig":
48
+ return importlib.import_module(
49
+ "nshtrainer._checkpoint.metadata"
50
+ ).EnvironmentConfig
51
+ if name == "LastCheckpointStrategyConfig":
52
+ return importlib.import_module(
53
+ "nshtrainer._checkpoint.loader"
54
+ ).LastCheckpointStrategyConfig
55
+ if name == "MetricConfig":
56
+ return importlib.import_module("nshtrainer._checkpoint.loader").MetricConfig
57
+ if name == "UserProvidedPathCheckpointStrategyConfig":
58
+ return importlib.import_module(
59
+ "nshtrainer._checkpoint.loader"
60
+ ).UserProvidedPathCheckpointStrategyConfig
61
+ if name == "CheckpointLoadingStrategyConfig":
62
+ return importlib.import_module(
63
+ "nshtrainer._checkpoint.loader"
64
+ ).CheckpointLoadingStrategyConfig
65
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
66
+
67
+
68
+ # Submodule exports
69
+ from . import loader as loader
70
+ from . import metadata as metadata
@@ -31,12 +31,14 @@ else:
31
31
 
32
32
  if name in globals():
33
33
  return globals()[name]
34
- if name == "MetricConfig":
35
- return importlib.import_module("nshtrainer._checkpoint.loader").MetricConfig
36
34
  if name == "BestCheckpointStrategyConfig":
37
35
  return importlib.import_module(
38
36
  "nshtrainer._checkpoint.loader"
39
37
  ).BestCheckpointStrategyConfig
38
+ if name == "CheckpointLoadingConfig":
39
+ return importlib.import_module(
40
+ "nshtrainer._checkpoint.loader"
41
+ ).CheckpointLoadingConfig
40
42
  if name == "CheckpointMetadata":
41
43
  return importlib.import_module(
42
44
  "nshtrainer._checkpoint.loader"
@@ -45,10 +47,8 @@ else:
45
47
  return importlib.import_module(
46
48
  "nshtrainer._checkpoint.loader"
47
49
  ).LastCheckpointStrategyConfig
48
- if name == "CheckpointLoadingConfig":
49
- return importlib.import_module(
50
- "nshtrainer._checkpoint.loader"
51
- ).CheckpointLoadingConfig
50
+ if name == "MetricConfig":
51
+ return importlib.import_module("nshtrainer._checkpoint.loader").MetricConfig
52
52
  if name == "UserProvidedPathCheckpointStrategyConfig":
53
53
  return importlib.import_module(
54
54
  "nshtrainer._checkpoint.loader"