nshtrainer 0.42.0__py3-none-any.whl → 0.43.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. nshtrainer/__init__.py +2 -0
  2. nshtrainer/_callback.py +2 -0
  3. nshtrainer/_checkpoint/loader.py +2 -0
  4. nshtrainer/_checkpoint/metadata.py +2 -0
  5. nshtrainer/_checkpoint/saver.py +2 -0
  6. nshtrainer/_directory.py +4 -2
  7. nshtrainer/_experimental/__init__.py +2 -0
  8. nshtrainer/_hf_hub.py +2 -0
  9. nshtrainer/callbacks/__init__.py +45 -29
  10. nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
  11. nshtrainer/callbacks/actsave.py +2 -0
  12. nshtrainer/callbacks/base.py +2 -0
  13. nshtrainer/callbacks/checkpoint/__init__.py +6 -2
  14. nshtrainer/callbacks/checkpoint/_base.py +2 -0
  15. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  16. nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
  17. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
  18. nshtrainer/callbacks/debug_flag.py +2 -0
  19. nshtrainer/callbacks/directory_setup.py +4 -2
  20. nshtrainer/callbacks/early_stopping.py +6 -4
  21. nshtrainer/callbacks/ema.py +5 -3
  22. nshtrainer/callbacks/finite_checks.py +3 -1
  23. nshtrainer/callbacks/gradient_skipping.py +6 -4
  24. nshtrainer/callbacks/interval.py +2 -0
  25. nshtrainer/callbacks/log_epoch.py +13 -1
  26. nshtrainer/callbacks/norm_logging.py +4 -2
  27. nshtrainer/callbacks/print_table.py +3 -1
  28. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  29. nshtrainer/callbacks/shared_parameters.py +4 -2
  30. nshtrainer/callbacks/throughput_monitor.py +2 -0
  31. nshtrainer/callbacks/timer.py +5 -3
  32. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  33. nshtrainer/callbacks/wandb_watch.py +4 -2
  34. nshtrainer/config/__init__.py +130 -90
  35. nshtrainer/config/_checkpoint/loader/__init__.py +10 -8
  36. nshtrainer/config/_checkpoint/metadata/__init__.py +6 -4
  37. nshtrainer/config/_directory/__init__.py +9 -3
  38. nshtrainer/config/_hf_hub/__init__.py +6 -4
  39. nshtrainer/config/callbacks/__init__.py +82 -42
  40. nshtrainer/config/callbacks/actsave/__init__.py +4 -2
  41. nshtrainer/config/callbacks/base/__init__.py +2 -0
  42. nshtrainer/config/callbacks/checkpoint/__init__.py +6 -4
  43. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +6 -4
  44. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +2 -0
  45. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +6 -4
  46. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -4
  47. nshtrainer/config/callbacks/debug_flag/__init__.py +6 -4
  48. nshtrainer/config/callbacks/directory_setup/__init__.py +7 -5
  49. nshtrainer/config/callbacks/early_stopping/__init__.py +9 -7
  50. nshtrainer/config/callbacks/ema/__init__.py +5 -3
  51. nshtrainer/config/callbacks/finite_checks/__init__.py +7 -5
  52. nshtrainer/config/callbacks/gradient_skipping/__init__.py +7 -5
  53. nshtrainer/config/callbacks/norm_logging/__init__.py +9 -5
  54. nshtrainer/config/callbacks/print_table/__init__.py +7 -5
  55. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +7 -5
  56. nshtrainer/config/callbacks/shared_parameters/__init__.py +7 -5
  57. nshtrainer/config/callbacks/throughput_monitor/__init__.py +6 -4
  58. nshtrainer/config/callbacks/timer/__init__.py +9 -5
  59. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +7 -5
  60. nshtrainer/config/callbacks/wandb_watch/__init__.py +9 -5
  61. nshtrainer/config/loggers/__init__.py +18 -10
  62. nshtrainer/config/loggers/_base/__init__.py +2 -0
  63. nshtrainer/config/loggers/csv/__init__.py +2 -0
  64. nshtrainer/config/loggers/tensorboard/__init__.py +2 -0
  65. nshtrainer/config/loggers/wandb/__init__.py +18 -10
  66. nshtrainer/config/lr_scheduler/__init__.py +2 -0
  67. nshtrainer/config/lr_scheduler/_base/__init__.py +2 -0
  68. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +2 -0
  69. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -4
  70. nshtrainer/config/metrics/__init__.py +2 -0
  71. nshtrainer/config/metrics/_config/__init__.py +2 -0
  72. nshtrainer/config/model/__init__.py +8 -6
  73. nshtrainer/config/model/base/__init__.py +4 -2
  74. nshtrainer/config/model/config/__init__.py +8 -6
  75. nshtrainer/config/model/mixins/logger/__init__.py +2 -0
  76. nshtrainer/config/nn/__init__.py +16 -14
  77. nshtrainer/config/nn/mlp/__init__.py +2 -0
  78. nshtrainer/config/nn/nonlinearity/__init__.py +26 -24
  79. nshtrainer/config/optimizer/__init__.py +2 -0
  80. nshtrainer/config/profiler/__init__.py +2 -0
  81. nshtrainer/config/profiler/_base/__init__.py +2 -0
  82. nshtrainer/config/profiler/advanced/__init__.py +6 -4
  83. nshtrainer/config/profiler/pytorch/__init__.py +6 -4
  84. nshtrainer/config/profiler/simple/__init__.py +6 -4
  85. nshtrainer/config/runner/__init__.py +2 -0
  86. nshtrainer/config/trainer/_config/__init__.py +43 -39
  87. nshtrainer/config/trainer/checkpoint_connector/__init__.py +2 -0
  88. nshtrainer/config/util/_environment_info/__init__.py +20 -18
  89. nshtrainer/config/util/config/__init__.py +2 -0
  90. nshtrainer/config/util/config/dtype/__init__.py +2 -0
  91. nshtrainer/config/util/config/duration/__init__.py +2 -0
  92. nshtrainer/data/__init__.py +2 -0
  93. nshtrainer/data/balanced_batch_sampler.py +2 -0
  94. nshtrainer/data/datamodule.py +2 -0
  95. nshtrainer/data/transform.py +2 -0
  96. nshtrainer/ll/__init__.py +2 -0
  97. nshtrainer/ll/_experimental.py +2 -0
  98. nshtrainer/ll/actsave.py +2 -0
  99. nshtrainer/ll/callbacks.py +2 -0
  100. nshtrainer/ll/config.py +2 -0
  101. nshtrainer/ll/data.py +2 -0
  102. nshtrainer/ll/log.py +2 -0
  103. nshtrainer/ll/lr_scheduler.py +2 -0
  104. nshtrainer/ll/model.py +2 -0
  105. nshtrainer/ll/nn.py +2 -0
  106. nshtrainer/ll/optimizer.py +2 -0
  107. nshtrainer/ll/runner.py +2 -0
  108. nshtrainer/ll/snapshot.py +2 -0
  109. nshtrainer/ll/snoop.py +2 -0
  110. nshtrainer/ll/trainer.py +2 -0
  111. nshtrainer/ll/typecheck.py +2 -0
  112. nshtrainer/ll/util.py +2 -0
  113. nshtrainer/loggers/__init__.py +2 -0
  114. nshtrainer/loggers/_base.py +2 -0
  115. nshtrainer/loggers/csv.py +2 -0
  116. nshtrainer/loggers/tensorboard.py +2 -0
  117. nshtrainer/loggers/wandb.py +6 -4
  118. nshtrainer/lr_scheduler/__init__.py +2 -0
  119. nshtrainer/lr_scheduler/_base.py +2 -0
  120. nshtrainer/lr_scheduler/linear_warmup_cosine.py +2 -0
  121. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +2 -0
  122. nshtrainer/metrics/__init__.py +2 -0
  123. nshtrainer/metrics/_config.py +2 -0
  124. nshtrainer/model/__init__.py +2 -0
  125. nshtrainer/model/base.py +2 -0
  126. nshtrainer/model/config.py +2 -0
  127. nshtrainer/model/mixins/callback.py +2 -0
  128. nshtrainer/model/mixins/logger.py +2 -0
  129. nshtrainer/nn/__init__.py +2 -0
  130. nshtrainer/nn/mlp.py +2 -0
  131. nshtrainer/nn/module_dict.py +2 -0
  132. nshtrainer/nn/module_list.py +2 -0
  133. nshtrainer/nn/nonlinearity.py +2 -0
  134. nshtrainer/optimizer.py +2 -0
  135. nshtrainer/profiler/__init__.py +2 -0
  136. nshtrainer/profiler/_base.py +2 -0
  137. nshtrainer/profiler/advanced.py +2 -0
  138. nshtrainer/profiler/pytorch.py +2 -0
  139. nshtrainer/profiler/simple.py +2 -0
  140. nshtrainer/runner.py +2 -0
  141. nshtrainer/scripts/find_packages.py +2 -0
  142. nshtrainer/trainer/__init__.py +2 -0
  143. nshtrainer/trainer/_config.py +16 -13
  144. nshtrainer/trainer/_runtime_callback.py +2 -0
  145. nshtrainer/trainer/checkpoint_connector.py +2 -0
  146. nshtrainer/trainer/signal_connector.py +2 -0
  147. nshtrainer/trainer/trainer.py +2 -0
  148. nshtrainer/util/_environment_info.py +2 -0
  149. nshtrainer/util/bf16.py +2 -0
  150. nshtrainer/util/config/__init__.py +2 -0
  151. nshtrainer/util/config/dtype.py +2 -0
  152. nshtrainer/util/config/duration.py +2 -0
  153. nshtrainer/util/environment.py +2 -0
  154. nshtrainer/util/path.py +2 -0
  155. nshtrainer/util/seed.py +2 -0
  156. nshtrainer/util/slurm.py +3 -0
  157. nshtrainer/util/typed.py +2 -0
  158. nshtrainer/util/typing_utils.py +2 -0
  159. {nshtrainer-0.42.0.dist-info → nshtrainer-0.43.0.dist-info}/METADATA +1 -1
  160. nshtrainer-0.43.0.dist-info/RECORD +162 -0
  161. nshtrainer-0.42.0.dist-info/RECORD +0 -162
  162. {nshtrainer-0.42.0.dist-info → nshtrainer-0.43.0.dist-info}/WHEEL +0 -0
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -8,8 +10,12 @@ if TYPE_CHECKING:
8
10
  from nshtrainer.loggers.wandb import BaseLoggerConfig as BaseLoggerConfig
9
11
  from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
10
12
  from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
11
- from nshtrainer.loggers.wandb import WandbUploadCodeConfig as WandbUploadCodeConfig
12
- from nshtrainer.loggers.wandb import WandbWatchConfig as WandbWatchConfig
13
+ from nshtrainer.loggers.wandb import (
14
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
15
+ )
16
+ from nshtrainer.loggers.wandb import (
17
+ WandbWatchCallbackConfig as WandbWatchCallbackConfig,
18
+ )
13
19
  else:
14
20
 
15
21
  def __getattr__(name):
@@ -17,20 +23,22 @@ else:
17
23
 
18
24
  if name in globals():
19
25
  return globals()[name]
20
- if name == "CallbackConfigBase":
21
- return importlib.import_module(
22
- "nshtrainer.loggers.wandb"
23
- ).CallbackConfigBase
24
26
  if name == "WandbLoggerConfig":
25
27
  return importlib.import_module("nshtrainer.loggers.wandb").WandbLoggerConfig
26
- if name == "WandbUploadCodeConfig":
28
+ if name == "WandbUploadCodeCallbackConfig":
27
29
  return importlib.import_module(
28
30
  "nshtrainer.loggers.wandb"
29
- ).WandbUploadCodeConfig
30
- if name == "WandbWatchConfig":
31
- return importlib.import_module("nshtrainer.loggers.wandb").WandbWatchConfig
31
+ ).WandbUploadCodeCallbackConfig
32
+ if name == "WandbWatchCallbackConfig":
33
+ return importlib.import_module(
34
+ "nshtrainer.loggers.wandb"
35
+ ).WandbWatchCallbackConfig
32
36
  if name == "BaseLoggerConfig":
33
37
  return importlib.import_module("nshtrainer.loggers.wandb").BaseLoggerConfig
38
+ if name == "CallbackConfigBase":
39
+ return importlib.import_module(
40
+ "nshtrainer.loggers.wandb"
41
+ ).CallbackConfigBase
34
42
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
35
43
 
36
44
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -21,6 +23,10 @@ else:
21
23
 
22
24
  if name in globals():
23
25
  return globals()[name]
26
+ if name == "LRSchedulerConfigBase":
27
+ return importlib.import_module(
28
+ "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
29
+ ).LRSchedulerConfigBase
24
30
  if name == "MetricConfig":
25
31
  return importlib.import_module(
26
32
  "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
@@ -29,10 +35,6 @@ else:
29
35
  return importlib.import_module(
30
36
  "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
31
37
  ).ReduceLROnPlateauConfig
32
- if name == "LRSchedulerConfigBase":
33
- return importlib.import_module(
34
- "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
35
- ).LRSchedulerConfigBase
36
38
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
37
39
 
38
40
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -18,18 +20,18 @@ else:
18
20
 
19
21
  if name in globals():
20
22
  return globals()[name]
21
- if name == "CallbackConfigBase":
22
- return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
23
+ if name == "MetricConfig":
24
+ return importlib.import_module("nshtrainer.model").MetricConfig
23
25
  if name == "TrainerConfig":
24
26
  return importlib.import_module("nshtrainer.model").TrainerConfig
25
- if name == "EnvironmentConfig":
26
- return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
27
27
  if name == "BaseConfig":
28
28
  return importlib.import_module("nshtrainer.model").BaseConfig
29
- if name == "MetricConfig":
30
- return importlib.import_module("nshtrainer.model").MetricConfig
29
+ if name == "EnvironmentConfig":
30
+ return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
31
31
  if name == "DirectoryConfig":
32
32
  return importlib.import_module("nshtrainer.model").DirectoryConfig
33
+ if name == "CallbackConfigBase":
34
+ return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
33
35
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
34
36
 
35
37
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -14,10 +16,10 @@ else:
14
16
 
15
17
  if name in globals():
16
18
  return globals()[name]
17
- if name == "EnvironmentConfig":
18
- return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
19
19
  if name == "BaseConfig":
20
20
  return importlib.import_module("nshtrainer.model.base").BaseConfig
21
+ if name == "EnvironmentConfig":
22
+ return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
21
23
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
22
24
 
23
25
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -18,18 +20,18 @@ else:
18
20
 
19
21
  if name in globals():
20
22
  return globals()[name]
21
- if name == "CallbackConfigBase":
22
- return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
23
+ if name == "MetricConfig":
24
+ return importlib.import_module("nshtrainer.model.config").MetricConfig
23
25
  if name == "TrainerConfig":
24
26
  return importlib.import_module("nshtrainer.model.config").TrainerConfig
25
- if name == "EnvironmentConfig":
26
- return importlib.import_module("nshtrainer.model.config").EnvironmentConfig
27
27
  if name == "BaseConfig":
28
28
  return importlib.import_module("nshtrainer.model.config").BaseConfig
29
- if name == "MetricConfig":
30
- return importlib.import_module("nshtrainer.model.config").MetricConfig
29
+ if name == "EnvironmentConfig":
30
+ return importlib.import_module("nshtrainer.model.config").EnvironmentConfig
31
31
  if name == "DirectoryConfig":
32
32
  return importlib.import_module("nshtrainer.model.config").DirectoryConfig
33
+ if name == "CallbackConfigBase":
34
+ return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
33
35
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
34
36
 
35
37
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -35,36 +37,36 @@ else:
35
37
  return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
36
38
  if name == "MLPConfig":
37
39
  return importlib.import_module("nshtrainer.nn").MLPConfig
40
+ if name == "PReLUConfig":
41
+ return importlib.import_module("nshtrainer.nn").PReLUConfig
42
+ if name == "LeakyReLUNonlinearityConfig":
43
+ return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
38
44
  if name == "SwiGLUNonlinearityConfig":
39
45
  return importlib.import_module(
40
46
  "nshtrainer.nn.nonlinearity"
41
47
  ).SwiGLUNonlinearityConfig
42
- if name == "ReLUNonlinearityConfig":
43
- return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
48
+ if name == "SoftsignNonlinearityConfig":
49
+ return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
44
50
  if name == "SiLUNonlinearityConfig":
45
51
  return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
52
+ if name == "SigmoidNonlinearityConfig":
53
+ return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
54
+ if name == "SoftplusNonlinearityConfig":
55
+ return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
46
56
  if name == "ELUNonlinearityConfig":
47
57
  return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
58
+ if name == "SoftmaxNonlinearityConfig":
59
+ return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
48
60
  if name == "GELUNonlinearityConfig":
49
61
  return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
50
- if name == "SoftplusNonlinearityConfig":
51
- return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
52
- if name == "SoftsignNonlinearityConfig":
53
- return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
54
62
  if name == "SwishNonlinearityConfig":
55
63
  return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
56
- if name == "SoftmaxNonlinearityConfig":
57
- return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
58
64
  if name == "MishNonlinearityConfig":
59
65
  return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
60
- if name == "SigmoidNonlinearityConfig":
61
- return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
62
66
  if name == "TanhNonlinearityConfig":
63
67
  return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
64
- if name == "PReLUConfig":
65
- return importlib.import_module("nshtrainer.nn").PReLUConfig
66
- if name == "LeakyReLUNonlinearityConfig":
67
- return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
68
+ if name == "ReLUNonlinearityConfig":
69
+ return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
68
70
  if name == "NonlinearityConfig":
69
71
  return importlib.import_module("nshtrainer.nn").NonlinearityConfig
70
72
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -56,64 +58,64 @@ else:
56
58
 
57
59
  if name in globals():
58
60
  return globals()[name]
61
+ if name == "PReLUConfig":
62
+ return importlib.import_module("nshtrainer.nn.nonlinearity").PReLUConfig
63
+ if name == "LeakyReLUNonlinearityConfig":
64
+ return importlib.import_module(
65
+ "nshtrainer.nn.nonlinearity"
66
+ ).LeakyReLUNonlinearityConfig
59
67
  if name == "SwiGLUNonlinearityConfig":
60
68
  return importlib.import_module(
61
69
  "nshtrainer.nn.nonlinearity"
62
70
  ).SwiGLUNonlinearityConfig
63
- if name == "ReLUNonlinearityConfig":
71
+ if name == "SoftsignNonlinearityConfig":
64
72
  return importlib.import_module(
65
73
  "nshtrainer.nn.nonlinearity"
66
- ).ReLUNonlinearityConfig
74
+ ).SoftsignNonlinearityConfig
67
75
  if name == "SiLUNonlinearityConfig":
68
76
  return importlib.import_module(
69
77
  "nshtrainer.nn.nonlinearity"
70
78
  ).SiLUNonlinearityConfig
71
- if name == "ELUNonlinearityConfig":
72
- return importlib.import_module(
73
- "nshtrainer.nn.nonlinearity"
74
- ).ELUNonlinearityConfig
75
- if name == "GELUNonlinearityConfig":
79
+ if name == "SigmoidNonlinearityConfig":
76
80
  return importlib.import_module(
77
81
  "nshtrainer.nn.nonlinearity"
78
- ).GELUNonlinearityConfig
82
+ ).SigmoidNonlinearityConfig
79
83
  if name == "SoftplusNonlinearityConfig":
80
84
  return importlib.import_module(
81
85
  "nshtrainer.nn.nonlinearity"
82
86
  ).SoftplusNonlinearityConfig
83
- if name == "SoftsignNonlinearityConfig":
84
- return importlib.import_module(
85
- "nshtrainer.nn.nonlinearity"
86
- ).SoftsignNonlinearityConfig
87
- if name == "SwishNonlinearityConfig":
87
+ if name == "ELUNonlinearityConfig":
88
88
  return importlib.import_module(
89
89
  "nshtrainer.nn.nonlinearity"
90
- ).SwishNonlinearityConfig
90
+ ).ELUNonlinearityConfig
91
91
  if name == "SoftmaxNonlinearityConfig":
92
92
  return importlib.import_module(
93
93
  "nshtrainer.nn.nonlinearity"
94
94
  ).SoftmaxNonlinearityConfig
95
- if name == "MishNonlinearityConfig":
95
+ if name == "GELUNonlinearityConfig":
96
96
  return importlib.import_module(
97
97
  "nshtrainer.nn.nonlinearity"
98
- ).MishNonlinearityConfig
99
- if name == "SigmoidNonlinearityConfig":
98
+ ).GELUNonlinearityConfig
99
+ if name == "SwishNonlinearityConfig":
100
100
  return importlib.import_module(
101
101
  "nshtrainer.nn.nonlinearity"
102
- ).SigmoidNonlinearityConfig
102
+ ).SwishNonlinearityConfig
103
+ if name == "MishNonlinearityConfig":
104
+ return importlib.import_module(
105
+ "nshtrainer.nn.nonlinearity"
106
+ ).MishNonlinearityConfig
103
107
  if name == "TanhNonlinearityConfig":
104
108
  return importlib.import_module(
105
109
  "nshtrainer.nn.nonlinearity"
106
110
  ).TanhNonlinearityConfig
107
- if name == "BaseNonlinearityConfig":
111
+ if name == "ReLUNonlinearityConfig":
108
112
  return importlib.import_module(
109
113
  "nshtrainer.nn.nonlinearity"
110
- ).BaseNonlinearityConfig
111
- if name == "PReLUConfig":
112
- return importlib.import_module("nshtrainer.nn.nonlinearity").PReLUConfig
113
- if name == "LeakyReLUNonlinearityConfig":
114
+ ).ReLUNonlinearityConfig
115
+ if name == "BaseNonlinearityConfig":
114
116
  return importlib.import_module(
115
117
  "nshtrainer.nn.nonlinearity"
116
- ).LeakyReLUNonlinearityConfig
118
+ ).BaseNonlinearityConfig
117
119
  if name == "NonlinearityConfig":
118
120
  return importlib.import_module(
119
121
  "nshtrainer.nn.nonlinearity"
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -16,14 +18,14 @@ else:
16
18
 
17
19
  if name in globals():
18
20
  return globals()[name]
19
- if name == "BaseProfilerConfig":
20
- return importlib.import_module(
21
- "nshtrainer.profiler.advanced"
22
- ).BaseProfilerConfig
23
21
  if name == "AdvancedProfilerConfig":
24
22
  return importlib.import_module(
25
23
  "nshtrainer.profiler.advanced"
26
24
  ).AdvancedProfilerConfig
25
+ if name == "BaseProfilerConfig":
26
+ return importlib.import_module(
27
+ "nshtrainer.profiler.advanced"
28
+ ).BaseProfilerConfig
27
29
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
28
30
 
29
31
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -16,14 +18,14 @@ else:
16
18
 
17
19
  if name in globals():
18
20
  return globals()[name]
19
- if name == "BaseProfilerConfig":
20
- return importlib.import_module(
21
- "nshtrainer.profiler.pytorch"
22
- ).BaseProfilerConfig
23
21
  if name == "PyTorchProfilerConfig":
24
22
  return importlib.import_module(
25
23
  "nshtrainer.profiler.pytorch"
26
24
  ).PyTorchProfilerConfig
25
+ if name == "BaseProfilerConfig":
26
+ return importlib.import_module(
27
+ "nshtrainer.profiler.pytorch"
28
+ ).BaseProfilerConfig
27
29
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
28
30
 
29
31
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -14,14 +16,14 @@ else:
14
16
 
15
17
  if name in globals():
16
18
  return globals()[name]
17
- if name == "BaseProfilerConfig":
18
- return importlib.import_module(
19
- "nshtrainer.profiler.simple"
20
- ).BaseProfilerConfig
21
19
  if name == "SimpleProfilerConfig":
22
20
  return importlib.import_module(
23
21
  "nshtrainer.profiler.simple"
24
22
  ).SimpleProfilerConfig
23
+ if name == "BaseProfilerConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.profiler.simple"
26
+ ).BaseProfilerConfig
25
27
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
26
28
 
27
29
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -23,7 +25,9 @@ if TYPE_CHECKING:
23
25
  from nshtrainer.trainer._config import (
24
26
  DebugFlagCallbackConfig as DebugFlagCallbackConfig,
25
27
  )
26
- from nshtrainer.trainer._config import EarlyStoppingConfig as EarlyStoppingConfig
28
+ from nshtrainer.trainer._config import (
29
+ EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
30
+ )
27
31
  from nshtrainer.trainer._config import (
28
32
  GradientClippingConfig as GradientClippingConfig,
29
33
  )
@@ -42,11 +46,11 @@ if TYPE_CHECKING:
42
46
  ReproducibilityConfig as ReproducibilityConfig,
43
47
  )
44
48
  from nshtrainer.trainer._config import (
45
- RLPSanityChecksConfig as RLPSanityChecksConfig,
49
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
46
50
  )
47
51
  from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
48
52
  from nshtrainer.trainer._config import (
49
- SharedParametersConfig as SharedParametersConfig,
53
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
50
54
  )
51
55
  from nshtrainer.trainer._config import (
52
56
  TensorboardLoggerConfig as TensorboardLoggerConfig,
@@ -60,80 +64,80 @@ else:
60
64
 
61
65
  if name in globals():
62
66
  return globals()[name]
63
- if name == "HuggingFaceHubConfig":
64
- return importlib.import_module(
65
- "nshtrainer.trainer._config"
66
- ).HuggingFaceHubConfig
67
- if name == "OptimizationConfig":
67
+ if name == "SanityCheckingConfig":
68
68
  return importlib.import_module(
69
69
  "nshtrainer.trainer._config"
70
- ).OptimizationConfig
70
+ ).SanityCheckingConfig
71
71
  if name == "TrainerConfig":
72
72
  return importlib.import_module("nshtrainer.trainer._config").TrainerConfig
73
- if name == "TensorboardLoggerConfig":
73
+ if name == "OnExceptionCheckpointCallbackConfig":
74
74
  return importlib.import_module(
75
75
  "nshtrainer.trainer._config"
76
- ).TensorboardLoggerConfig
76
+ ).OnExceptionCheckpointCallbackConfig
77
77
  if name == "GradientClippingConfig":
78
78
  return importlib.import_module(
79
79
  "nshtrainer.trainer._config"
80
80
  ).GradientClippingConfig
81
- if name == "CallbackConfigBase":
81
+ if name == "WandbLoggerConfig":
82
82
  return importlib.import_module(
83
83
  "nshtrainer.trainer._config"
84
- ).CallbackConfigBase
85
- if name == "CSVLoggerConfig":
86
- return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
87
- if name == "LastCheckpointCallbackConfig":
84
+ ).WandbLoggerConfig
85
+ if name == "LoggingConfig":
86
+ return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
87
+ if name == "TensorboardLoggerConfig":
88
88
  return importlib.import_module(
89
89
  "nshtrainer.trainer._config"
90
- ).LastCheckpointCallbackConfig
91
- if name == "OnExceptionCheckpointCallbackConfig":
90
+ ).TensorboardLoggerConfig
91
+ if name == "RLPSanityChecksCallbackConfig":
92
92
  return importlib.import_module(
93
93
  "nshtrainer.trainer._config"
94
- ).OnExceptionCheckpointCallbackConfig
95
- if name == "RLPSanityChecksConfig":
94
+ ).RLPSanityChecksCallbackConfig
95
+ if name == "CheckpointSavingConfig":
96
96
  return importlib.import_module(
97
97
  "nshtrainer.trainer._config"
98
- ).RLPSanityChecksConfig
99
- if name == "EarlyStoppingConfig":
98
+ ).CheckpointSavingConfig
99
+ if name == "CSVLoggerConfig":
100
+ return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
101
+ if name == "HuggingFaceHubConfig":
102
+ return importlib.import_module(
103
+ "nshtrainer.trainer._config"
104
+ ).HuggingFaceHubConfig
105
+ if name == "CheckpointLoadingConfig":
100
106
  return importlib.import_module(
101
107
  "nshtrainer.trainer._config"
102
- ).EarlyStoppingConfig
108
+ ).CheckpointLoadingConfig
103
109
  if name == "DebugFlagCallbackConfig":
104
110
  return importlib.import_module(
105
111
  "nshtrainer.trainer._config"
106
112
  ).DebugFlagCallbackConfig
107
- if name == "WandbLoggerConfig":
113
+ if name == "CallbackConfigBase":
108
114
  return importlib.import_module(
109
115
  "nshtrainer.trainer._config"
110
- ).WandbLoggerConfig
111
- if name == "CheckpointSavingConfig":
116
+ ).CallbackConfigBase
117
+ if name == "LastCheckpointCallbackConfig":
112
118
  return importlib.import_module(
113
119
  "nshtrainer.trainer._config"
114
- ).CheckpointSavingConfig
115
- if name == "CheckpointLoadingConfig":
120
+ ).LastCheckpointCallbackConfig
121
+ if name == "SharedParametersCallbackConfig":
116
122
  return importlib.import_module(
117
123
  "nshtrainer.trainer._config"
118
- ).CheckpointLoadingConfig
119
- if name == "BestCheckpointCallbackConfig":
124
+ ).SharedParametersCallbackConfig
125
+ if name == "ReproducibilityConfig":
120
126
  return importlib.import_module(
121
127
  "nshtrainer.trainer._config"
122
- ).BestCheckpointCallbackConfig
123
- if name == "LoggingConfig":
124
- return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
125
- if name == "SanityCheckingConfig":
128
+ ).ReproducibilityConfig
129
+ if name == "EarlyStoppingCallbackConfig":
126
130
  return importlib.import_module(
127
131
  "nshtrainer.trainer._config"
128
- ).SanityCheckingConfig
129
- if name == "SharedParametersConfig":
132
+ ).EarlyStoppingCallbackConfig
133
+ if name == "OptimizationConfig":
130
134
  return importlib.import_module(
131
135
  "nshtrainer.trainer._config"
132
- ).SharedParametersConfig
133
- if name == "ReproducibilityConfig":
136
+ ).OptimizationConfig
137
+ if name == "BestCheckpointCallbackConfig":
134
138
  return importlib.import_module(
135
139
  "nshtrainer.trainer._config"
136
- ).ReproducibilityConfig
140
+ ).BestCheckpointCallbackConfig
137
141
  if name == "CallbackConfig":
138
142
  return importlib.import_module("nshtrainer.trainer._config").CallbackConfig
139
143
  if name == "CheckpointCallbackConfig":
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING