nshtrainer-0.42.0-py3-none-any.whl → nshtrainer-0.44.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. nshtrainer/__init__.py +2 -0
  2. nshtrainer/_callback.py +2 -0
  3. nshtrainer/_checkpoint/loader.py +2 -0
  4. nshtrainer/_checkpoint/metadata.py +2 -0
  5. nshtrainer/_checkpoint/saver.py +2 -0
  6. nshtrainer/_directory.py +4 -2
  7. nshtrainer/_experimental/__init__.py +2 -0
  8. nshtrainer/_hf_hub.py +2 -0
  9. nshtrainer/callbacks/__init__.py +45 -29
  10. nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
  11. nshtrainer/callbacks/actsave.py +2 -0
  12. nshtrainer/callbacks/base.py +2 -0
  13. nshtrainer/callbacks/checkpoint/__init__.py +6 -2
  14. nshtrainer/callbacks/checkpoint/_base.py +2 -0
  15. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  16. nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
  17. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
  18. nshtrainer/callbacks/debug_flag.py +2 -0
  19. nshtrainer/callbacks/directory_setup.py +4 -2
  20. nshtrainer/callbacks/early_stopping.py +6 -4
  21. nshtrainer/callbacks/ema.py +5 -3
  22. nshtrainer/callbacks/finite_checks.py +3 -1
  23. nshtrainer/callbacks/gradient_skipping.py +6 -4
  24. nshtrainer/callbacks/interval.py +2 -0
  25. nshtrainer/callbacks/log_epoch.py +13 -1
  26. nshtrainer/callbacks/norm_logging.py +4 -2
  27. nshtrainer/callbacks/print_table.py +3 -1
  28. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  29. nshtrainer/callbacks/shared_parameters.py +4 -2
  30. nshtrainer/callbacks/throughput_monitor.py +2 -0
  31. nshtrainer/callbacks/timer.py +5 -3
  32. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  33. nshtrainer/callbacks/wandb_watch.py +4 -2
  34. nshtrainer/config/__init__.py +130 -90
  35. nshtrainer/config/_checkpoint/loader/__init__.py +10 -8
  36. nshtrainer/config/_checkpoint/metadata/__init__.py +6 -4
  37. nshtrainer/config/_directory/__init__.py +9 -3
  38. nshtrainer/config/_hf_hub/__init__.py +6 -4
  39. nshtrainer/config/callbacks/__init__.py +82 -42
  40. nshtrainer/config/callbacks/actsave/__init__.py +4 -2
  41. nshtrainer/config/callbacks/base/__init__.py +2 -0
  42. nshtrainer/config/callbacks/checkpoint/__init__.py +6 -4
  43. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +6 -4
  44. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +2 -0
  45. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +6 -4
  46. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -4
  47. nshtrainer/config/callbacks/debug_flag/__init__.py +6 -4
  48. nshtrainer/config/callbacks/directory_setup/__init__.py +7 -5
  49. nshtrainer/config/callbacks/early_stopping/__init__.py +9 -7
  50. nshtrainer/config/callbacks/ema/__init__.py +5 -3
  51. nshtrainer/config/callbacks/finite_checks/__init__.py +7 -5
  52. nshtrainer/config/callbacks/gradient_skipping/__init__.py +7 -5
  53. nshtrainer/config/callbacks/norm_logging/__init__.py +9 -5
  54. nshtrainer/config/callbacks/print_table/__init__.py +7 -5
  55. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +7 -5
  56. nshtrainer/config/callbacks/shared_parameters/__init__.py +7 -5
  57. nshtrainer/config/callbacks/throughput_monitor/__init__.py +6 -4
  58. nshtrainer/config/callbacks/timer/__init__.py +9 -5
  59. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +7 -5
  60. nshtrainer/config/callbacks/wandb_watch/__init__.py +9 -5
  61. nshtrainer/config/loggers/__init__.py +18 -10
  62. nshtrainer/config/loggers/_base/__init__.py +2 -0
  63. nshtrainer/config/loggers/csv/__init__.py +2 -0
  64. nshtrainer/config/loggers/tensorboard/__init__.py +2 -0
  65. nshtrainer/config/loggers/wandb/__init__.py +18 -10
  66. nshtrainer/config/lr_scheduler/__init__.py +2 -0
  67. nshtrainer/config/lr_scheduler/_base/__init__.py +2 -0
  68. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +2 -0
  69. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -4
  70. nshtrainer/config/metrics/__init__.py +2 -0
  71. nshtrainer/config/metrics/_config/__init__.py +2 -0
  72. nshtrainer/config/model/__init__.py +8 -6
  73. nshtrainer/config/model/base/__init__.py +4 -2
  74. nshtrainer/config/model/config/__init__.py +8 -6
  75. nshtrainer/config/model/mixins/logger/__init__.py +2 -0
  76. nshtrainer/config/nn/__init__.py +16 -14
  77. nshtrainer/config/nn/mlp/__init__.py +2 -0
  78. nshtrainer/config/nn/nonlinearity/__init__.py +26 -24
  79. nshtrainer/config/optimizer/__init__.py +2 -0
  80. nshtrainer/config/profiler/__init__.py +2 -0
  81. nshtrainer/config/profiler/_base/__init__.py +2 -0
  82. nshtrainer/config/profiler/advanced/__init__.py +6 -4
  83. nshtrainer/config/profiler/pytorch/__init__.py +6 -4
  84. nshtrainer/config/profiler/simple/__init__.py +6 -4
  85. nshtrainer/config/runner/__init__.py +2 -0
  86. nshtrainer/config/trainer/_config/__init__.py +43 -39
  87. nshtrainer/config/trainer/checkpoint_connector/__init__.py +2 -0
  88. nshtrainer/config/util/_environment_info/__init__.py +20 -18
  89. nshtrainer/config/util/config/__init__.py +2 -0
  90. nshtrainer/config/util/config/dtype/__init__.py +2 -0
  91. nshtrainer/config/util/config/duration/__init__.py +2 -0
  92. nshtrainer/data/__init__.py +2 -0
  93. nshtrainer/data/balanced_batch_sampler.py +2 -0
  94. nshtrainer/data/datamodule.py +2 -0
  95. nshtrainer/data/transform.py +2 -0
  96. nshtrainer/ll/__init__.py +2 -0
  97. nshtrainer/ll/_experimental.py +2 -0
  98. nshtrainer/ll/actsave.py +2 -0
  99. nshtrainer/ll/callbacks.py +2 -0
  100. nshtrainer/ll/config.py +2 -0
  101. nshtrainer/ll/data.py +2 -0
  102. nshtrainer/ll/log.py +2 -0
  103. nshtrainer/ll/lr_scheduler.py +2 -0
  104. nshtrainer/ll/model.py +2 -0
  105. nshtrainer/ll/nn.py +2 -0
  106. nshtrainer/ll/optimizer.py +2 -0
  107. nshtrainer/ll/runner.py +2 -0
  108. nshtrainer/ll/snapshot.py +2 -0
  109. nshtrainer/ll/snoop.py +2 -0
  110. nshtrainer/ll/trainer.py +2 -0
  111. nshtrainer/ll/typecheck.py +2 -0
  112. nshtrainer/ll/util.py +2 -0
  113. nshtrainer/loggers/__init__.py +2 -0
  114. nshtrainer/loggers/_base.py +2 -0
  115. nshtrainer/loggers/csv.py +2 -0
  116. nshtrainer/loggers/tensorboard.py +2 -0
  117. nshtrainer/loggers/wandb.py +6 -4
  118. nshtrainer/lr_scheduler/__init__.py +2 -0
  119. nshtrainer/lr_scheduler/_base.py +8 -11
  120. nshtrainer/lr_scheduler/linear_warmup_cosine.py +18 -17
  121. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +8 -6
  122. nshtrainer/metrics/__init__.py +2 -0
  123. nshtrainer/metrics/_config.py +2 -0
  124. nshtrainer/model/__init__.py +2 -0
  125. nshtrainer/model/base.py +2 -0
  126. nshtrainer/model/config.py +2 -0
  127. nshtrainer/model/mixins/callback.py +2 -0
  128. nshtrainer/model/mixins/logger.py +2 -0
  129. nshtrainer/nn/__init__.py +2 -0
  130. nshtrainer/nn/mlp.py +2 -0
  131. nshtrainer/nn/module_dict.py +2 -0
  132. nshtrainer/nn/module_list.py +2 -0
  133. nshtrainer/nn/nonlinearity.py +2 -0
  134. nshtrainer/optimizer.py +2 -0
  135. nshtrainer/profiler/__init__.py +2 -0
  136. nshtrainer/profiler/_base.py +2 -0
  137. nshtrainer/profiler/advanced.py +2 -0
  138. nshtrainer/profiler/pytorch.py +2 -0
  139. nshtrainer/profiler/simple.py +2 -0
  140. nshtrainer/runner.py +2 -0
  141. nshtrainer/scripts/find_packages.py +2 -0
  142. nshtrainer/trainer/__init__.py +2 -0
  143. nshtrainer/trainer/_config.py +16 -13
  144. nshtrainer/trainer/_runtime_callback.py +2 -0
  145. nshtrainer/trainer/checkpoint_connector.py +2 -0
  146. nshtrainer/trainer/signal_connector.py +2 -0
  147. nshtrainer/trainer/trainer.py +2 -0
  148. nshtrainer/util/_environment_info.py +2 -0
  149. nshtrainer/util/bf16.py +2 -0
  150. nshtrainer/util/config/__init__.py +2 -0
  151. nshtrainer/util/config/dtype.py +2 -0
  152. nshtrainer/util/config/duration.py +2 -0
  153. nshtrainer/util/environment.py +2 -0
  154. nshtrainer/util/path.py +2 -0
  155. nshtrainer/util/seed.py +2 -0
  156. nshtrainer/util/slurm.py +3 -0
  157. nshtrainer/util/typed.py +2 -0
  158. nshtrainer/util/typing_utils.py +2 -0
  159. {nshtrainer-0.42.0.dist-info → nshtrainer-0.44.0.dist-info}/METADATA +1 -1
  160. nshtrainer-0.44.0.dist-info/RECORD +162 -0
  161. nshtrainer-0.42.0.dist-info/RECORD +0 -162
  162. {nshtrainer-0.42.0.dist-info → nshtrainer-0.44.0.dist-info}/WHEEL +0 -0
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -14,12 +16,12 @@ else:
14
16
 
15
17
  if name in globals():
16
18
  return globals()[name]
19
+ if name == "ActSaveConfig":
20
+ return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
17
21
  if name == "CallbackConfigBase":
18
22
  return importlib.import_module(
19
23
  "nshtrainer.callbacks.actsave"
20
24
  ).CallbackConfigBase
21
- if name == "ActSaveConfig":
22
- return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
23
25
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
24
26
 
25
27
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -33,10 +35,6 @@ else:
33
35
 
34
36
  if name in globals():
35
37
  return globals()[name]
36
- if name == "LastCheckpointCallbackConfig":
37
- return importlib.import_module(
38
- "nshtrainer.callbacks.checkpoint"
39
- ).LastCheckpointCallbackConfig
40
38
  if name == "CheckpointMetadata":
41
39
  return importlib.import_module(
42
40
  "nshtrainer.callbacks.checkpoint._base"
@@ -45,6 +43,10 @@ else:
45
43
  return importlib.import_module(
46
44
  "nshtrainer.callbacks.checkpoint._base"
47
45
  ).BaseCheckpointCallbackConfig
46
+ if name == "LastCheckpointCallbackConfig":
47
+ return importlib.import_module(
48
+ "nshtrainer.callbacks.checkpoint"
49
+ ).LastCheckpointCallbackConfig
48
50
  if name == "CallbackConfigBase":
49
51
  return importlib.import_module(
50
52
  "nshtrainer.callbacks.checkpoint._base"
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -21,10 +23,6 @@ else:
21
23
 
22
24
  if name in globals():
23
25
  return globals()[name]
24
- if name == "CallbackConfigBase":
25
- return importlib.import_module(
26
- "nshtrainer.callbacks.checkpoint._base"
27
- ).CallbackConfigBase
28
26
  if name == "CheckpointMetadata":
29
27
  return importlib.import_module(
30
28
  "nshtrainer.callbacks.checkpoint._base"
@@ -33,6 +31,10 @@ else:
33
31
  return importlib.import_module(
34
32
  "nshtrainer.callbacks.checkpoint._base"
35
33
  ).BaseCheckpointCallbackConfig
34
+ if name == "CallbackConfigBase":
35
+ return importlib.import_module(
36
+ "nshtrainer.callbacks.checkpoint._base"
37
+ ).CallbackConfigBase
36
38
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
37
39
 
38
40
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -21,10 +23,6 @@ else:
21
23
 
22
24
  if name in globals():
23
25
  return globals()[name]
24
- if name == "LastCheckpointCallbackConfig":
25
- return importlib.import_module(
26
- "nshtrainer.callbacks.checkpoint.last_checkpoint"
27
- ).LastCheckpointCallbackConfig
28
26
  if name == "CheckpointMetadata":
29
27
  return importlib.import_module(
30
28
  "nshtrainer.callbacks.checkpoint.last_checkpoint"
@@ -33,6 +31,10 @@ else:
33
31
  return importlib.import_module(
34
32
  "nshtrainer.callbacks.checkpoint.last_checkpoint"
35
33
  ).BaseCheckpointCallbackConfig
34
+ if name == "LastCheckpointCallbackConfig":
35
+ return importlib.import_module(
36
+ "nshtrainer.callbacks.checkpoint.last_checkpoint"
37
+ ).LastCheckpointCallbackConfig
36
38
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
37
39
 
38
40
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -18,14 +20,14 @@ else:
18
20
 
19
21
  if name in globals():
20
22
  return globals()[name]
21
- if name == "CallbackConfigBase":
22
- return importlib.import_module(
23
- "nshtrainer.callbacks.checkpoint.on_exception_checkpoint"
24
- ).CallbackConfigBase
25
23
  if name == "OnExceptionCheckpointCallbackConfig":
26
24
  return importlib.import_module(
27
25
  "nshtrainer.callbacks.checkpoint.on_exception_checkpoint"
28
26
  ).OnExceptionCheckpointCallbackConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.checkpoint.on_exception_checkpoint"
30
+ ).CallbackConfigBase
29
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
32
 
31
33
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -16,14 +18,14 @@ else:
16
18
 
17
19
  if name in globals():
18
20
  return globals()[name]
19
- if name == "CallbackConfigBase":
20
- return importlib.import_module(
21
- "nshtrainer.callbacks.debug_flag"
22
- ).CallbackConfigBase
23
21
  if name == "DebugFlagCallbackConfig":
24
22
  return importlib.import_module(
25
23
  "nshtrainer.callbacks.debug_flag"
26
24
  ).DebugFlagCallbackConfig
25
+ if name == "CallbackConfigBase":
26
+ return importlib.import_module(
27
+ "nshtrainer.callbacks.debug_flag"
28
+ ).CallbackConfigBase
27
29
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
28
30
 
29
31
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -9,7 +11,7 @@ if TYPE_CHECKING:
9
11
  CallbackConfigBase as CallbackConfigBase,
10
12
  )
11
13
  from nshtrainer.callbacks.directory_setup import (
12
- DirectorySetupConfig as DirectorySetupConfig,
14
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
13
15
  )
14
16
  else:
15
17
 
@@ -18,14 +20,14 @@ else:
18
20
 
19
21
  if name in globals():
20
22
  return globals()[name]
21
- if name == "CallbackConfigBase":
23
+ if name == "DirectorySetupCallbackConfig":
22
24
  return importlib.import_module(
23
25
  "nshtrainer.callbacks.directory_setup"
24
- ).CallbackConfigBase
25
- if name == "DirectorySetupConfig":
26
+ ).DirectorySetupCallbackConfig
27
+ if name == "CallbackConfigBase":
26
28
  return importlib.import_module(
27
29
  "nshtrainer.callbacks.directory_setup"
28
- ).DirectorySetupConfig
30
+ ).CallbackConfigBase
29
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
32
 
31
33
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -9,7 +11,7 @@ if TYPE_CHECKING:
9
11
  CallbackConfigBase as CallbackConfigBase,
10
12
  )
11
13
  from nshtrainer.callbacks.early_stopping import (
12
- EarlyStoppingConfig as EarlyStoppingConfig,
14
+ EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
13
15
  )
14
16
  from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
15
17
  else:
@@ -19,18 +21,18 @@ else:
19
21
 
20
22
  if name in globals():
21
23
  return globals()[name]
22
- if name == "CallbackConfigBase":
23
- return importlib.import_module(
24
- "nshtrainer.callbacks.early_stopping"
25
- ).CallbackConfigBase
26
24
  if name == "MetricConfig":
27
25
  return importlib.import_module(
28
26
  "nshtrainer.callbacks.early_stopping"
29
27
  ).MetricConfig
30
- if name == "EarlyStoppingConfig":
28
+ if name == "EarlyStoppingCallbackConfig":
31
29
  return importlib.import_module(
32
30
  "nshtrainer.callbacks.early_stopping"
33
- ).EarlyStoppingConfig
31
+ ).EarlyStoppingCallbackConfig
32
+ if name == "CallbackConfigBase":
33
+ return importlib.import_module(
34
+ "nshtrainer.callbacks.early_stopping"
35
+ ).CallbackConfigBase
34
36
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
35
37
 
36
38
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -6,7 +8,7 @@ from typing import TYPE_CHECKING
6
8
 
7
9
  if TYPE_CHECKING:
8
10
  from nshtrainer.callbacks.ema import CallbackConfigBase as CallbackConfigBase
9
- from nshtrainer.callbacks.ema import EMAConfig as EMAConfig
11
+ from nshtrainer.callbacks.ema import EMACallbackConfig as EMACallbackConfig
10
12
  else:
11
13
 
12
14
  def __getattr__(name):
@@ -14,12 +16,12 @@ else:
14
16
 
15
17
  if name in globals():
16
18
  return globals()[name]
19
+ if name == "EMACallbackConfig":
20
+ return importlib.import_module("nshtrainer.callbacks.ema").EMACallbackConfig
17
21
  if name == "CallbackConfigBase":
18
22
  return importlib.import_module(
19
23
  "nshtrainer.callbacks.ema"
20
24
  ).CallbackConfigBase
21
- if name == "EMAConfig":
22
- return importlib.import_module("nshtrainer.callbacks.ema").EMAConfig
23
25
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
24
26
 
25
27
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -9,7 +11,7 @@ if TYPE_CHECKING:
9
11
  CallbackConfigBase as CallbackConfigBase,
10
12
  )
11
13
  from nshtrainer.callbacks.finite_checks import (
12
- FiniteChecksConfig as FiniteChecksConfig,
14
+ FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
13
15
  )
14
16
  else:
15
17
 
@@ -18,14 +20,14 @@ else:
18
20
 
19
21
  if name in globals():
20
22
  return globals()[name]
21
- if name == "CallbackConfigBase":
23
+ if name == "FiniteChecksCallbackConfig":
22
24
  return importlib.import_module(
23
25
  "nshtrainer.callbacks.finite_checks"
24
- ).CallbackConfigBase
25
- if name == "FiniteChecksConfig":
26
+ ).FiniteChecksCallbackConfig
27
+ if name == "CallbackConfigBase":
26
28
  return importlib.import_module(
27
29
  "nshtrainer.callbacks.finite_checks"
28
- ).FiniteChecksConfig
30
+ ).CallbackConfigBase
29
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
32
 
31
33
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -9,7 +11,7 @@ if TYPE_CHECKING:
9
11
  CallbackConfigBase as CallbackConfigBase,
10
12
  )
11
13
  from nshtrainer.callbacks.gradient_skipping import (
12
- GradientSkippingConfig as GradientSkippingConfig,
14
+ GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
13
15
  )
14
16
  else:
15
17
 
@@ -18,14 +20,14 @@ else:
18
20
 
19
21
  if name in globals():
20
22
  return globals()[name]
21
- if name == "CallbackConfigBase":
23
+ if name == "GradientSkippingCallbackConfig":
22
24
  return importlib.import_module(
23
25
  "nshtrainer.callbacks.gradient_skipping"
24
- ).CallbackConfigBase
25
- if name == "GradientSkippingConfig":
26
+ ).GradientSkippingCallbackConfig
27
+ if name == "CallbackConfigBase":
26
28
  return importlib.import_module(
27
29
  "nshtrainer.callbacks.gradient_skipping"
28
- ).GradientSkippingConfig
30
+ ).CallbackConfigBase
29
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
32
 
31
33
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -8,7 +10,9 @@ if TYPE_CHECKING:
8
10
  from nshtrainer.callbacks.norm_logging import (
9
11
  CallbackConfigBase as CallbackConfigBase,
10
12
  )
11
- from nshtrainer.callbacks.norm_logging import NormLoggingConfig as NormLoggingConfig
13
+ from nshtrainer.callbacks.norm_logging import (
14
+ NormLoggingCallbackConfig as NormLoggingCallbackConfig,
15
+ )
12
16
  else:
13
17
 
14
18
  def __getattr__(name):
@@ -16,14 +20,14 @@ else:
16
20
 
17
21
  if name in globals():
18
22
  return globals()[name]
19
- if name == "CallbackConfigBase":
23
+ if name == "NormLoggingCallbackConfig":
20
24
  return importlib.import_module(
21
25
  "nshtrainer.callbacks.norm_logging"
22
- ).CallbackConfigBase
23
- if name == "NormLoggingConfig":
26
+ ).NormLoggingCallbackConfig
27
+ if name == "CallbackConfigBase":
24
28
  return importlib.import_module(
25
29
  "nshtrainer.callbacks.norm_logging"
26
- ).NormLoggingConfig
30
+ ).CallbackConfigBase
27
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
28
32
 
29
33
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -9,7 +11,7 @@ if TYPE_CHECKING:
9
11
  CallbackConfigBase as CallbackConfigBase,
10
12
  )
11
13
  from nshtrainer.callbacks.print_table import (
12
- PrintTableMetricsConfig as PrintTableMetricsConfig,
14
+ PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
13
15
  )
14
16
  else:
15
17
 
@@ -18,14 +20,14 @@ else:
18
20
 
19
21
  if name in globals():
20
22
  return globals()[name]
21
- if name == "CallbackConfigBase":
23
+ if name == "PrintTableMetricsCallbackConfig":
22
24
  return importlib.import_module(
23
25
  "nshtrainer.callbacks.print_table"
24
- ).CallbackConfigBase
25
- if name == "PrintTableMetricsConfig":
26
+ ).PrintTableMetricsCallbackConfig
27
+ if name == "CallbackConfigBase":
26
28
  return importlib.import_module(
27
29
  "nshtrainer.callbacks.print_table"
28
- ).PrintTableMetricsConfig
30
+ ).CallbackConfigBase
29
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
32
 
31
33
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -9,7 +11,7 @@ if TYPE_CHECKING:
9
11
  CallbackConfigBase as CallbackConfigBase,
10
12
  )
11
13
  from nshtrainer.callbacks.rlp_sanity_checks import (
12
- RLPSanityChecksConfig as RLPSanityChecksConfig,
14
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
13
15
  )
14
16
  else:
15
17
 
@@ -18,14 +20,14 @@ else:
18
20
 
19
21
  if name in globals():
20
22
  return globals()[name]
21
- if name == "CallbackConfigBase":
23
+ if name == "RLPSanityChecksCallbackConfig":
22
24
  return importlib.import_module(
23
25
  "nshtrainer.callbacks.rlp_sanity_checks"
24
- ).CallbackConfigBase
25
- if name == "RLPSanityChecksConfig":
26
+ ).RLPSanityChecksCallbackConfig
27
+ if name == "CallbackConfigBase":
26
28
  return importlib.import_module(
27
29
  "nshtrainer.callbacks.rlp_sanity_checks"
28
- ).RLPSanityChecksConfig
30
+ ).CallbackConfigBase
29
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
32
 
31
33
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -9,7 +11,7 @@ if TYPE_CHECKING:
9
11
  CallbackConfigBase as CallbackConfigBase,
10
12
  )
11
13
  from nshtrainer.callbacks.shared_parameters import (
12
- SharedParametersConfig as SharedParametersConfig,
14
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
13
15
  )
14
16
  else:
15
17
 
@@ -18,14 +20,14 @@ else:
18
20
 
19
21
  if name in globals():
20
22
  return globals()[name]
21
- if name == "CallbackConfigBase":
23
+ if name == "SharedParametersCallbackConfig":
22
24
  return importlib.import_module(
23
25
  "nshtrainer.callbacks.shared_parameters"
24
- ).CallbackConfigBase
25
- if name == "SharedParametersConfig":
26
+ ).SharedParametersCallbackConfig
27
+ if name == "CallbackConfigBase":
26
28
  return importlib.import_module(
27
29
  "nshtrainer.callbacks.shared_parameters"
28
- ).SharedParametersConfig
30
+ ).CallbackConfigBase
29
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
32
 
31
33
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -18,14 +20,14 @@ else:
18
20
 
19
21
  if name in globals():
20
22
  return globals()[name]
21
- if name == "CallbackConfigBase":
22
- return importlib.import_module(
23
- "nshtrainer.callbacks.throughput_monitor"
24
- ).CallbackConfigBase
25
23
  if name == "ThroughputMonitorConfig":
26
24
  return importlib.import_module(
27
25
  "nshtrainer.callbacks.throughput_monitor"
28
26
  ).ThroughputMonitorConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.throughput_monitor"
30
+ ).CallbackConfigBase
29
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
32
 
31
33
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -6,7 +8,9 @@ from typing import TYPE_CHECKING
6
8
 
7
9
  if TYPE_CHECKING:
8
10
  from nshtrainer.callbacks.timer import CallbackConfigBase as CallbackConfigBase
9
- from nshtrainer.callbacks.timer import EpochTimerConfig as EpochTimerConfig
11
+ from nshtrainer.callbacks.timer import (
12
+ EpochTimerCallbackConfig as EpochTimerCallbackConfig,
13
+ )
10
14
  else:
11
15
 
12
16
  def __getattr__(name):
@@ -14,14 +18,14 @@ else:
14
18
 
15
19
  if name in globals():
16
20
  return globals()[name]
17
- if name == "CallbackConfigBase":
21
+ if name == "EpochTimerCallbackConfig":
18
22
  return importlib.import_module(
19
23
  "nshtrainer.callbacks.timer"
20
- ).CallbackConfigBase
21
- if name == "EpochTimerConfig":
24
+ ).EpochTimerCallbackConfig
25
+ if name == "CallbackConfigBase":
22
26
  return importlib.import_module(
23
27
  "nshtrainer.callbacks.timer"
24
- ).EpochTimerConfig
28
+ ).CallbackConfigBase
25
29
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
26
30
 
27
31
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -9,7 +11,7 @@ if TYPE_CHECKING:
9
11
  CallbackConfigBase as CallbackConfigBase,
10
12
  )
11
13
  from nshtrainer.callbacks.wandb_upload_code import (
12
- WandbUploadCodeConfig as WandbUploadCodeConfig,
14
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
13
15
  )
14
16
  else:
15
17
 
@@ -18,14 +20,14 @@ else:
18
20
 
19
21
  if name in globals():
20
22
  return globals()[name]
21
- if name == "CallbackConfigBase":
23
+ if name == "WandbUploadCodeCallbackConfig":
22
24
  return importlib.import_module(
23
25
  "nshtrainer.callbacks.wandb_upload_code"
24
- ).CallbackConfigBase
25
- if name == "WandbUploadCodeConfig":
26
+ ).WandbUploadCodeCallbackConfig
27
+ if name == "CallbackConfigBase":
26
28
  return importlib.import_module(
27
29
  "nshtrainer.callbacks.wandb_upload_code"
28
- ).WandbUploadCodeConfig
30
+ ).CallbackConfigBase
29
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
32
 
31
33
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -8,7 +10,9 @@ if TYPE_CHECKING:
8
10
  from nshtrainer.callbacks.wandb_watch import (
9
11
  CallbackConfigBase as CallbackConfigBase,
10
12
  )
11
- from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
13
+ from nshtrainer.callbacks.wandb_watch import (
14
+ WandbWatchCallbackConfig as WandbWatchCallbackConfig,
15
+ )
12
16
  else:
13
17
 
14
18
  def __getattr__(name):
@@ -16,14 +20,14 @@ else:
16
20
 
17
21
  if name in globals():
18
22
  return globals()[name]
19
- if name == "CallbackConfigBase":
23
+ if name == "WandbWatchCallbackConfig":
20
24
  return importlib.import_module(
21
25
  "nshtrainer.callbacks.wandb_watch"
22
- ).CallbackConfigBase
23
- if name == "WandbWatchConfig":
26
+ ).WandbWatchCallbackConfig
27
+ if name == "CallbackConfigBase":
24
28
  return importlib.import_module(
25
29
  "nshtrainer.callbacks.wandb_watch"
26
- ).WandbWatchConfig
30
+ ).CallbackConfigBase
27
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
28
32
 
29
33
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -11,8 +13,12 @@ if TYPE_CHECKING:
11
13
  from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
12
14
  from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
13
15
  from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
14
- from nshtrainer.loggers.wandb import WandbUploadCodeConfig as WandbUploadCodeConfig
15
- from nshtrainer.loggers.wandb import WandbWatchConfig as WandbWatchConfig
16
+ from nshtrainer.loggers.wandb import (
17
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
18
+ )
19
+ from nshtrainer.loggers.wandb import (
20
+ WandbWatchCallbackConfig as WandbWatchCallbackConfig,
21
+ )
16
22
  else:
17
23
 
18
24
  def __getattr__(name):
@@ -24,18 +30,20 @@ else:
24
30
  return importlib.import_module("nshtrainer.loggers").BaseLoggerConfig
25
31
  if name == "TensorboardLoggerConfig":
26
32
  return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
27
- if name == "CallbackConfigBase":
28
- return importlib.import_module(
29
- "nshtrainer.loggers.wandb"
30
- ).CallbackConfigBase
31
33
  if name == "WandbLoggerConfig":
32
34
  return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
33
- if name == "WandbUploadCodeConfig":
35
+ if name == "WandbUploadCodeCallbackConfig":
34
36
  return importlib.import_module(
35
37
  "nshtrainer.loggers.wandb"
36
- ).WandbUploadCodeConfig
37
- if name == "WandbWatchConfig":
38
- return importlib.import_module("nshtrainer.loggers.wandb").WandbWatchConfig
38
+ ).WandbUploadCodeCallbackConfig
39
+ if name == "WandbWatchCallbackConfig":
40
+ return importlib.import_module(
41
+ "nshtrainer.loggers.wandb"
42
+ ).WandbWatchCallbackConfig
43
+ if name == "CallbackConfigBase":
44
+ return importlib.import_module(
45
+ "nshtrainer.loggers.wandb"
46
+ ).CallbackConfigBase
39
47
  if name == "CSVLoggerConfig":
40
48
  return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
41
49
  if name == "LoggerConfig":
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING