nshtrainer 0.41.1__py3-none-any.whl → 0.43.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. nshtrainer/__init__.py +2 -0
  2. nshtrainer/_callback.py +2 -0
  3. nshtrainer/_checkpoint/loader.py +2 -0
  4. nshtrainer/_checkpoint/metadata.py +2 -0
  5. nshtrainer/_checkpoint/saver.py +2 -0
  6. nshtrainer/_directory.py +4 -2
  7. nshtrainer/_experimental/__init__.py +2 -0
  8. nshtrainer/_hf_hub.py +2 -0
  9. nshtrainer/callbacks/__init__.py +45 -29
  10. nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
  11. nshtrainer/callbacks/actsave.py +2 -0
  12. nshtrainer/callbacks/base.py +2 -0
  13. nshtrainer/callbacks/checkpoint/__init__.py +6 -2
  14. nshtrainer/callbacks/checkpoint/_base.py +2 -0
  15. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  16. nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
  17. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
  18. nshtrainer/callbacks/debug_flag.py +2 -0
  19. nshtrainer/callbacks/directory_setup.py +4 -2
  20. nshtrainer/callbacks/early_stopping.py +6 -4
  21. nshtrainer/callbacks/ema.py +5 -3
  22. nshtrainer/callbacks/finite_checks.py +3 -1
  23. nshtrainer/callbacks/gradient_skipping.py +6 -4
  24. nshtrainer/callbacks/interval.py +2 -0
  25. nshtrainer/callbacks/log_epoch.py +13 -1
  26. nshtrainer/callbacks/norm_logging.py +4 -2
  27. nshtrainer/callbacks/print_table.py +3 -1
  28. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  29. nshtrainer/callbacks/shared_parameters.py +4 -2
  30. nshtrainer/callbacks/throughput_monitor.py +2 -0
  31. nshtrainer/callbacks/timer.py +5 -3
  32. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  33. nshtrainer/callbacks/wandb_watch.py +4 -2
  34. nshtrainer/config/__init__.py +445 -94
  35. nshtrainer/config/_checkpoint/loader/__init__.py +56 -12
  36. nshtrainer/config/_checkpoint/metadata/__init__.py +23 -7
  37. nshtrainer/config/_directory/__init__.py +26 -8
  38. nshtrainer/config/_hf_hub/__init__.py +26 -8
  39. nshtrainer/config/callbacks/__init__.py +154 -29
  40. nshtrainer/config/callbacks/actsave/__init__.py +21 -7
  41. nshtrainer/config/callbacks/base/__init__.py +18 -6
  42. nshtrainer/config/callbacks/checkpoint/__init__.py +63 -12
  43. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +34 -8
  44. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +41 -9
  45. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +34 -8
  46. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +27 -7
  47. nshtrainer/config/callbacks/debug_flag/__init__.py +25 -7
  48. nshtrainer/config/callbacks/directory_setup/__init__.py +27 -7
  49. nshtrainer/config/callbacks/early_stopping/__init__.py +32 -8
  50. nshtrainer/config/callbacks/ema/__init__.py +21 -7
  51. nshtrainer/config/callbacks/finite_checks/__init__.py +27 -7
  52. nshtrainer/config/callbacks/gradient_skipping/__init__.py +27 -7
  53. nshtrainer/config/callbacks/norm_logging/__init__.py +27 -7
  54. nshtrainer/config/callbacks/print_table/__init__.py +27 -7
  55. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +27 -7
  56. nshtrainer/config/callbacks/shared_parameters/__init__.py +27 -7
  57. nshtrainer/config/callbacks/throughput_monitor/__init__.py +27 -7
  58. nshtrainer/config/callbacks/timer/__init__.py +25 -7
  59. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +27 -7
  60. nshtrainer/config/callbacks/wandb_watch/__init__.py +27 -7
  61. nshtrainer/config/loggers/__init__.py +49 -14
  62. nshtrainer/config/loggers/_base/__init__.py +16 -6
  63. nshtrainer/config/loggers/csv/__init__.py +19 -7
  64. nshtrainer/config/loggers/tensorboard/__init__.py +25 -7
  65. nshtrainer/config/loggers/wandb/__init__.py +38 -10
  66. nshtrainer/config/lr_scheduler/__init__.py +50 -11
  67. nshtrainer/config/lr_scheduler/_base/__init__.py +20 -6
  68. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +34 -8
  69. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +34 -8
  70. nshtrainer/config/metrics/__init__.py +17 -6
  71. nshtrainer/config/metrics/_config/__init__.py +16 -6
  72. nshtrainer/config/model/__init__.py +32 -11
  73. nshtrainer/config/model/base/__init__.py +19 -7
  74. nshtrainer/config/model/config/__init__.py +31 -11
  75. nshtrainer/config/model/mixins/logger/__init__.py +16 -6
  76. nshtrainer/config/nn/__init__.py +70 -23
  77. nshtrainer/config/nn/mlp/__init__.py +22 -8
  78. nshtrainer/config/nn/nonlinearity/__init__.py +119 -21
  79. nshtrainer/config/optimizer/__init__.py +22 -8
  80. nshtrainer/config/profiler/__init__.py +29 -10
  81. nshtrainer/config/profiler/_base/__init__.py +18 -6
  82. nshtrainer/config/profiler/advanced/__init__.py +25 -7
  83. nshtrainer/config/profiler/pytorch/__init__.py +25 -7
  84. nshtrainer/config/profiler/simple/__init__.py +23 -7
  85. nshtrainer/config/runner/__init__.py +16 -6
  86. nshtrainer/config/trainer/_config/__init__.py +147 -29
  87. nshtrainer/config/trainer/checkpoint_connector/__init__.py +20 -6
  88. nshtrainer/config/util/_environment_info/__init__.py +88 -16
  89. nshtrainer/config/util/config/__init__.py +26 -9
  90. nshtrainer/config/util/config/dtype/__init__.py +16 -6
  91. nshtrainer/config/util/config/duration/__init__.py +28 -8
  92. nshtrainer/data/__init__.py +2 -0
  93. nshtrainer/data/balanced_batch_sampler.py +2 -0
  94. nshtrainer/data/datamodule.py +2 -0
  95. nshtrainer/data/transform.py +2 -0
  96. nshtrainer/ll/__init__.py +2 -0
  97. nshtrainer/ll/_experimental.py +2 -0
  98. nshtrainer/ll/actsave.py +2 -0
  99. nshtrainer/ll/callbacks.py +2 -0
  100. nshtrainer/ll/config.py +2 -0
  101. nshtrainer/ll/data.py +2 -0
  102. nshtrainer/ll/log.py +2 -0
  103. nshtrainer/ll/lr_scheduler.py +2 -0
  104. nshtrainer/ll/model.py +2 -0
  105. nshtrainer/ll/nn.py +2 -0
  106. nshtrainer/ll/optimizer.py +2 -0
  107. nshtrainer/ll/runner.py +2 -0
  108. nshtrainer/ll/snapshot.py +2 -0
  109. nshtrainer/ll/snoop.py +2 -0
  110. nshtrainer/ll/trainer.py +2 -0
  111. nshtrainer/ll/typecheck.py +2 -0
  112. nshtrainer/ll/util.py +2 -0
  113. nshtrainer/loggers/__init__.py +2 -0
  114. nshtrainer/loggers/_base.py +2 -0
  115. nshtrainer/loggers/csv.py +2 -0
  116. nshtrainer/loggers/tensorboard.py +2 -0
  117. nshtrainer/loggers/wandb.py +6 -4
  118. nshtrainer/lr_scheduler/__init__.py +2 -0
  119. nshtrainer/lr_scheduler/_base.py +2 -0
  120. nshtrainer/lr_scheduler/linear_warmup_cosine.py +2 -0
  121. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +2 -0
  122. nshtrainer/metrics/__init__.py +2 -0
  123. nshtrainer/metrics/_config.py +2 -0
  124. nshtrainer/model/__init__.py +2 -0
  125. nshtrainer/model/base.py +2 -0
  126. nshtrainer/model/config.py +2 -0
  127. nshtrainer/model/mixins/callback.py +2 -0
  128. nshtrainer/model/mixins/logger.py +2 -0
  129. nshtrainer/nn/__init__.py +2 -0
  130. nshtrainer/nn/mlp.py +2 -0
  131. nshtrainer/nn/module_dict.py +2 -0
  132. nshtrainer/nn/module_list.py +2 -0
  133. nshtrainer/nn/nonlinearity.py +2 -0
  134. nshtrainer/optimizer.py +2 -0
  135. nshtrainer/profiler/__init__.py +2 -0
  136. nshtrainer/profiler/_base.py +2 -0
  137. nshtrainer/profiler/advanced.py +2 -0
  138. nshtrainer/profiler/pytorch.py +2 -0
  139. nshtrainer/profiler/simple.py +2 -0
  140. nshtrainer/runner.py +2 -0
  141. nshtrainer/scripts/find_packages.py +2 -0
  142. nshtrainer/trainer/__init__.py +2 -0
  143. nshtrainer/trainer/_config.py +16 -13
  144. nshtrainer/trainer/_runtime_callback.py +2 -0
  145. nshtrainer/trainer/checkpoint_connector.py +2 -0
  146. nshtrainer/trainer/signal_connector.py +2 -0
  147. nshtrainer/trainer/trainer.py +2 -0
  148. nshtrainer/util/_environment_info.py +2 -0
  149. nshtrainer/util/bf16.py +2 -0
  150. nshtrainer/util/config/__init__.py +2 -0
  151. nshtrainer/util/config/dtype.py +2 -0
  152. nshtrainer/util/config/duration.py +2 -0
  153. nshtrainer/util/environment.py +2 -0
  154. nshtrainer/util/path.py +2 -0
  155. nshtrainer/util/seed.py +2 -0
  156. nshtrainer/util/slurm.py +3 -0
  157. nshtrainer/util/typed.py +2 -0
  158. nshtrainer/util/typing_utils.py +2 -0
  159. {nshtrainer-0.41.1.dist-info → nshtrainer-0.43.0.dist-info}/METADATA +1 -1
  160. nshtrainer-0.43.0.dist-info/RECORD +162 -0
  161. nshtrainer-0.41.1.dist-info/RECORD +0 -162
  162. {nshtrainer-0.41.1.dist-info → nshtrainer-0.43.0.dist-info}/WHEEL +0 -0
@@ -1,13 +1,31 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.loggers.tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
9
- from nshtrainer.loggers.tensorboard import BaseLoggerConfig as BaseLoggerConfig
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.loggers.tensorboard import BaseLoggerConfig as BaseLoggerConfig
11
+ from nshtrainer.loggers.tensorboard import (
12
+ TensorboardLoggerConfig as TensorboardLoggerConfig,
13
+ )
14
+ else:
15
+
16
+ def __getattr__(name):
17
+ import importlib
18
+
19
+ if name in globals():
20
+ return globals()[name]
21
+ if name == "BaseLoggerConfig":
22
+ return importlib.import_module(
23
+ "nshtrainer.loggers.tensorboard"
24
+ ).BaseLoggerConfig
25
+ if name == "TensorboardLoggerConfig":
26
+ return importlib.import_module(
27
+ "nshtrainer.loggers.tensorboard"
28
+ ).TensorboardLoggerConfig
29
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
30
 
13
31
  # Submodule exports
@@ -1,16 +1,44 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.loggers.wandb import BaseLoggerConfig as BaseLoggerConfig
9
- from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
10
- from nshtrainer.loggers.wandb import WandbUploadCodeConfig as WandbUploadCodeConfig
11
- from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
12
- from nshtrainer.loggers.wandb import WandbWatchConfig as WandbWatchConfig
5
+ from typing import TYPE_CHECKING
13
6
 
14
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.loggers.wandb import BaseLoggerConfig as BaseLoggerConfig
11
+ from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
12
+ from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
13
+ from nshtrainer.loggers.wandb import (
14
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
15
+ )
16
+ from nshtrainer.loggers.wandb import (
17
+ WandbWatchCallbackConfig as WandbWatchCallbackConfig,
18
+ )
19
+ else:
20
+
21
+ def __getattr__(name):
22
+ import importlib
23
+
24
+ if name in globals():
25
+ return globals()[name]
26
+ if name == "WandbLoggerConfig":
27
+ return importlib.import_module("nshtrainer.loggers.wandb").WandbLoggerConfig
28
+ if name == "WandbUploadCodeCallbackConfig":
29
+ return importlib.import_module(
30
+ "nshtrainer.loggers.wandb"
31
+ ).WandbUploadCodeCallbackConfig
32
+ if name == "WandbWatchCallbackConfig":
33
+ return importlib.import_module(
34
+ "nshtrainer.loggers.wandb"
35
+ ).WandbWatchCallbackConfig
36
+ if name == "BaseLoggerConfig":
37
+ return importlib.import_module("nshtrainer.loggers.wandb").BaseLoggerConfig
38
+ if name == "CallbackConfigBase":
39
+ return importlib.import_module(
40
+ "nshtrainer.loggers.wandb"
41
+ ).CallbackConfigBase
42
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
15
43
 
16
44
  # Submodule exports
@@ -1,18 +1,57 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
9
- from nshtrainer.lr_scheduler import LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig
10
- from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
11
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.lr_scheduler import (
11
+ LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
12
+ )
13
+ from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
14
+ from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
15
+ from nshtrainer.lr_scheduler import (
16
+ ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
17
+ )
18
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
19
+ DurationConfig as DurationConfig,
20
+ )
21
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
22
+ MetricConfig as MetricConfig,
23
+ )
24
+ else:
25
+
26
+ def __getattr__(name):
27
+ import importlib
28
+
29
+ if name in globals():
30
+ return globals()[name]
31
+ if name == "LRSchedulerConfigBase":
32
+ return importlib.import_module(
33
+ "nshtrainer.lr_scheduler"
34
+ ).LRSchedulerConfigBase
35
+ if name == "LinearWarmupCosineDecayLRSchedulerConfig":
36
+ return importlib.import_module(
37
+ "nshtrainer.lr_scheduler"
38
+ ).LinearWarmupCosineDecayLRSchedulerConfig
39
+ if name == "MetricConfig":
40
+ return importlib.import_module(
41
+ "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
42
+ ).MetricConfig
43
+ if name == "ReduceLROnPlateauConfig":
44
+ return importlib.import_module(
45
+ "nshtrainer.lr_scheduler"
46
+ ).ReduceLROnPlateauConfig
47
+ if name == "DurationConfig":
48
+ return importlib.import_module(
49
+ "nshtrainer.lr_scheduler.linear_warmup_cosine"
50
+ ).DurationConfig
51
+ if name == "LRSchedulerConfig":
52
+ return importlib.import_module("nshtrainer.lr_scheduler").LRSchedulerConfig
53
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
54
 
13
- # Type aliases
14
- from nshtrainer.lr_scheduler.linear_warmup_cosine import DurationConfig as DurationConfig
15
- from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
16
55
 
17
56
  # Submodule exports
18
57
  from . import _base as _base
@@ -1,12 +1,26 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
5
+ from typing import TYPE_CHECKING
9
6
 
10
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.lr_scheduler._base import (
11
+ LRSchedulerConfigBase as LRSchedulerConfigBase,
12
+ )
13
+ else:
14
+
15
+ def __getattr__(name):
16
+ import importlib
17
+
18
+ if name in globals():
19
+ return globals()[name]
20
+ if name == "LRSchedulerConfigBase":
21
+ return importlib.import_module(
22
+ "nshtrainer.lr_scheduler._base"
23
+ ).LRSchedulerConfigBase
24
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
11
25
 
12
26
  # Submodule exports
@@ -1,14 +1,40 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.lr_scheduler.linear_warmup_cosine import LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig
9
- from nshtrainer.lr_scheduler.linear_warmup_cosine import LRSchedulerConfigBase as LRSchedulerConfigBase
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
12
- from nshtrainer.lr_scheduler.linear_warmup_cosine import DurationConfig as DurationConfig
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
11
+ DurationConfig as DurationConfig,
12
+ )
13
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
14
+ LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
15
+ )
16
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
17
+ LRSchedulerConfigBase as LRSchedulerConfigBase,
18
+ )
19
+ else:
20
+
21
+ def __getattr__(name):
22
+ import importlib
23
+
24
+ if name in globals():
25
+ return globals()[name]
26
+ if name == "LinearWarmupCosineDecayLRSchedulerConfig":
27
+ return importlib.import_module(
28
+ "nshtrainer.lr_scheduler.linear_warmup_cosine"
29
+ ).LinearWarmupCosineDecayLRSchedulerConfig
30
+ if name == "LRSchedulerConfigBase":
31
+ return importlib.import_module(
32
+ "nshtrainer.lr_scheduler.linear_warmup_cosine"
33
+ ).LRSchedulerConfigBase
34
+ if name == "DurationConfig":
35
+ return importlib.import_module(
36
+ "nshtrainer.lr_scheduler.linear_warmup_cosine"
37
+ ).DurationConfig
38
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
13
39
 
14
40
  # Submodule exports
@@ -1,14 +1,40 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
9
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import LRSchedulerConfigBase as LRSchedulerConfigBase
10
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
5
+ from typing import TYPE_CHECKING
11
6
 
12
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
11
+ LRSchedulerConfigBase as LRSchedulerConfigBase,
12
+ )
13
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
14
+ MetricConfig as MetricConfig,
15
+ )
16
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
17
+ ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
18
+ )
19
+ else:
20
+
21
+ def __getattr__(name):
22
+ import importlib
23
+
24
+ if name in globals():
25
+ return globals()[name]
26
+ if name == "LRSchedulerConfigBase":
27
+ return importlib.import_module(
28
+ "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
29
+ ).LRSchedulerConfigBase
30
+ if name == "MetricConfig":
31
+ return importlib.import_module(
32
+ "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
33
+ ).MetricConfig
34
+ if name == "ReduceLROnPlateauConfig":
35
+ return importlib.import_module(
36
+ "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
37
+ ).ReduceLROnPlateauConfig
38
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
13
39
 
14
40
  # Submodule exports
@@ -1,13 +1,24 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.metrics import MetricConfig as MetricConfig
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.metrics import MetricConfig as MetricConfig
11
+ else:
12
+
13
+ def __getattr__(name):
14
+ import importlib
15
+
16
+ if name in globals():
17
+ return globals()[name]
18
+ if name == "MetricConfig":
19
+ return importlib.import_module("nshtrainer.metrics").MetricConfig
20
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
9
21
 
10
- # Type aliases
11
22
 
12
23
  # Submodule exports
13
24
  from . import _config as _config
@@ -1,12 +1,22 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.metrics._config import MetricConfig as MetricConfig
5
+ from typing import TYPE_CHECKING
9
6
 
10
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.metrics._config import MetricConfig as MetricConfig
11
+ else:
12
+
13
+ def __getattr__(name):
14
+ import importlib
15
+
16
+ if name in globals():
17
+ return globals()[name]
18
+ if name == "MetricConfig":
19
+ return importlib.import_module("nshtrainer.metrics._config").MetricConfig
20
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
11
21
 
12
22
  # Submodule exports
@@ -1,18 +1,39 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.model import MetricConfig as MetricConfig
9
- from nshtrainer.model import BaseConfig as BaseConfig
10
- from nshtrainer.model import DirectoryConfig as DirectoryConfig
11
- from nshtrainer.model import TrainerConfig as TrainerConfig
12
- from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
13
- from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.model import BaseConfig as BaseConfig
11
+ from nshtrainer.model import DirectoryConfig as DirectoryConfig
12
+ from nshtrainer.model import MetricConfig as MetricConfig
13
+ from nshtrainer.model import TrainerConfig as TrainerConfig
14
+ from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
15
+ from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "MetricConfig":
24
+ return importlib.import_module("nshtrainer.model").MetricConfig
25
+ if name == "TrainerConfig":
26
+ return importlib.import_module("nshtrainer.model").TrainerConfig
27
+ if name == "BaseConfig":
28
+ return importlib.import_module("nshtrainer.model").BaseConfig
29
+ if name == "EnvironmentConfig":
30
+ return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
31
+ if name == "DirectoryConfig":
32
+ return importlib.import_module("nshtrainer.model").DirectoryConfig
33
+ if name == "CallbackConfigBase":
34
+ return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
35
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
14
36
 
15
- # Type aliases
16
37
 
17
38
  # Submodule exports
18
39
  from . import base as base
@@ -1,13 +1,25 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
9
- from nshtrainer.model.base import BaseConfig as BaseConfig
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.model.base import BaseConfig as BaseConfig
11
+ from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
12
+ else:
13
+
14
+ def __getattr__(name):
15
+ import importlib
16
+
17
+ if name in globals():
18
+ return globals()[name]
19
+ if name == "BaseConfig":
20
+ return importlib.import_module("nshtrainer.model.base").BaseConfig
21
+ if name == "EnvironmentConfig":
22
+ return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
23
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
24
 
13
25
  # Submodule exports
@@ -1,17 +1,37 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.model.config import MetricConfig as MetricConfig
9
- from nshtrainer.model.config import BaseConfig as BaseConfig
10
- from nshtrainer.model.config import DirectoryConfig as DirectoryConfig
11
- from nshtrainer.model.config import TrainerConfig as TrainerConfig
12
- from nshtrainer.model.config import EnvironmentConfig as EnvironmentConfig
13
- from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
5
+ from typing import TYPE_CHECKING
14
6
 
15
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.model.config import BaseConfig as BaseConfig
11
+ from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
12
+ from nshtrainer.model.config import DirectoryConfig as DirectoryConfig
13
+ from nshtrainer.model.config import EnvironmentConfig as EnvironmentConfig
14
+ from nshtrainer.model.config import MetricConfig as MetricConfig
15
+ from nshtrainer.model.config import TrainerConfig as TrainerConfig
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "MetricConfig":
24
+ return importlib.import_module("nshtrainer.model.config").MetricConfig
25
+ if name == "TrainerConfig":
26
+ return importlib.import_module("nshtrainer.model.config").TrainerConfig
27
+ if name == "BaseConfig":
28
+ return importlib.import_module("nshtrainer.model.config").BaseConfig
29
+ if name == "EnvironmentConfig":
30
+ return importlib.import_module("nshtrainer.model.config").EnvironmentConfig
31
+ if name == "DirectoryConfig":
32
+ return importlib.import_module("nshtrainer.model.config").DirectoryConfig
33
+ if name == "CallbackConfigBase":
34
+ return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
35
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
16
36
 
17
37
  # Submodule exports
@@ -1,12 +1,22 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.model.mixins.logger import BaseConfig as BaseConfig
5
+ from typing import TYPE_CHECKING
9
6
 
10
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.model.mixins.logger import BaseConfig as BaseConfig
11
+ else:
12
+
13
+ def __getattr__(name):
14
+ import importlib
15
+
16
+ if name in globals():
17
+ return globals()[name]
18
+ if name == "BaseConfig":
19
+ return importlib.import_module("nshtrainer.model.mixins.logger").BaseConfig
20
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
11
21
 
12
22
  # Submodule exports
@@ -1,29 +1,76 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.nn import MLPConfig as MLPConfig
9
- from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
10
- from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
11
- from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
12
- from nshtrainer.nn import PReLUConfig as PReLUConfig
13
- from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
14
- from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
15
- from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
16
- from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
17
- from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
18
- from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
19
- from nshtrainer.nn.nonlinearity import SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig
20
- from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
21
- from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
22
- from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
23
- from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
24
-
25
- # Type aliases
26
- from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
11
+ from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
12
+ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
13
+ from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
14
+ from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
15
+ from nshtrainer.nn import MLPConfig as MLPConfig
16
+ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
17
+ from nshtrainer.nn import PReLUConfig as PReLUConfig
18
+ from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
19
+ from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
20
+ from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
21
+ from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
22
+ from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
23
+ from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
24
+ from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
25
+ from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
26
+ from nshtrainer.nn.nonlinearity import (
27
+ SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
28
+ )
29
+ else:
30
+
31
+ def __getattr__(name):
32
+ import importlib
33
+
34
+ if name in globals():
35
+ return globals()[name]
36
+ if name == "BaseNonlinearityConfig":
37
+ return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
38
+ if name == "MLPConfig":
39
+ return importlib.import_module("nshtrainer.nn").MLPConfig
40
+ if name == "PReLUConfig":
41
+ return importlib.import_module("nshtrainer.nn").PReLUConfig
42
+ if name == "LeakyReLUNonlinearityConfig":
43
+ return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
44
+ if name == "SwiGLUNonlinearityConfig":
45
+ return importlib.import_module(
46
+ "nshtrainer.nn.nonlinearity"
47
+ ).SwiGLUNonlinearityConfig
48
+ if name == "SoftsignNonlinearityConfig":
49
+ return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
50
+ if name == "SiLUNonlinearityConfig":
51
+ return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
52
+ if name == "SigmoidNonlinearityConfig":
53
+ return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
54
+ if name == "SoftplusNonlinearityConfig":
55
+ return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
56
+ if name == "ELUNonlinearityConfig":
57
+ return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
58
+ if name == "SoftmaxNonlinearityConfig":
59
+ return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
60
+ if name == "GELUNonlinearityConfig":
61
+ return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
62
+ if name == "SwishNonlinearityConfig":
63
+ return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
64
+ if name == "MishNonlinearityConfig":
65
+ return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
66
+ if name == "TanhNonlinearityConfig":
67
+ return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
68
+ if name == "ReLUNonlinearityConfig":
69
+ return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
70
+ if name == "NonlinearityConfig":
71
+ return importlib.import_module("nshtrainer.nn").NonlinearityConfig
72
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
73
+
27
74
 
28
75
  # Submodule exports
29
76
  from . import mlp as mlp
@@ -1,14 +1,28 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.nn.mlp import MLPConfig as MLPConfig
9
- from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
12
- from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
11
+ from nshtrainer.nn.mlp import MLPConfig as MLPConfig
12
+ from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
13
+ else:
14
+
15
+ def __getattr__(name):
16
+ import importlib
17
+
18
+ if name in globals():
19
+ return globals()[name]
20
+ if name == "BaseNonlinearityConfig":
21
+ return importlib.import_module("nshtrainer.nn.mlp").BaseNonlinearityConfig
22
+ if name == "MLPConfig":
23
+ return importlib.import_module("nshtrainer.nn.mlp").MLPConfig
24
+ if name == "NonlinearityConfig":
25
+ return importlib.import_module("nshtrainer.nn.mlp").NonlinearityConfig
26
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
13
27
 
14
28
  # Submodule exports