nshtrainer 0.41.1__py3-none-any.whl → 0.43.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. nshtrainer/__init__.py +2 -0
  2. nshtrainer/_callback.py +2 -0
  3. nshtrainer/_checkpoint/loader.py +2 -0
  4. nshtrainer/_checkpoint/metadata.py +2 -0
  5. nshtrainer/_checkpoint/saver.py +2 -0
  6. nshtrainer/_directory.py +4 -2
  7. nshtrainer/_experimental/__init__.py +2 -0
  8. nshtrainer/_hf_hub.py +2 -0
  9. nshtrainer/callbacks/__init__.py +45 -29
  10. nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
  11. nshtrainer/callbacks/actsave.py +2 -0
  12. nshtrainer/callbacks/base.py +2 -0
  13. nshtrainer/callbacks/checkpoint/__init__.py +6 -2
  14. nshtrainer/callbacks/checkpoint/_base.py +2 -0
  15. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  16. nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
  17. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
  18. nshtrainer/callbacks/debug_flag.py +2 -0
  19. nshtrainer/callbacks/directory_setup.py +4 -2
  20. nshtrainer/callbacks/early_stopping.py +6 -4
  21. nshtrainer/callbacks/ema.py +5 -3
  22. nshtrainer/callbacks/finite_checks.py +3 -1
  23. nshtrainer/callbacks/gradient_skipping.py +6 -4
  24. nshtrainer/callbacks/interval.py +2 -0
  25. nshtrainer/callbacks/log_epoch.py +13 -1
  26. nshtrainer/callbacks/norm_logging.py +4 -2
  27. nshtrainer/callbacks/print_table.py +3 -1
  28. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  29. nshtrainer/callbacks/shared_parameters.py +4 -2
  30. nshtrainer/callbacks/throughput_monitor.py +2 -0
  31. nshtrainer/callbacks/timer.py +5 -3
  32. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  33. nshtrainer/callbacks/wandb_watch.py +4 -2
  34. nshtrainer/config/__init__.py +445 -94
  35. nshtrainer/config/_checkpoint/loader/__init__.py +56 -12
  36. nshtrainer/config/_checkpoint/metadata/__init__.py +23 -7
  37. nshtrainer/config/_directory/__init__.py +26 -8
  38. nshtrainer/config/_hf_hub/__init__.py +26 -8
  39. nshtrainer/config/callbacks/__init__.py +154 -29
  40. nshtrainer/config/callbacks/actsave/__init__.py +21 -7
  41. nshtrainer/config/callbacks/base/__init__.py +18 -6
  42. nshtrainer/config/callbacks/checkpoint/__init__.py +63 -12
  43. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +34 -8
  44. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +41 -9
  45. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +34 -8
  46. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +27 -7
  47. nshtrainer/config/callbacks/debug_flag/__init__.py +25 -7
  48. nshtrainer/config/callbacks/directory_setup/__init__.py +27 -7
  49. nshtrainer/config/callbacks/early_stopping/__init__.py +32 -8
  50. nshtrainer/config/callbacks/ema/__init__.py +21 -7
  51. nshtrainer/config/callbacks/finite_checks/__init__.py +27 -7
  52. nshtrainer/config/callbacks/gradient_skipping/__init__.py +27 -7
  53. nshtrainer/config/callbacks/norm_logging/__init__.py +27 -7
  54. nshtrainer/config/callbacks/print_table/__init__.py +27 -7
  55. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +27 -7
  56. nshtrainer/config/callbacks/shared_parameters/__init__.py +27 -7
  57. nshtrainer/config/callbacks/throughput_monitor/__init__.py +27 -7
  58. nshtrainer/config/callbacks/timer/__init__.py +25 -7
  59. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +27 -7
  60. nshtrainer/config/callbacks/wandb_watch/__init__.py +27 -7
  61. nshtrainer/config/loggers/__init__.py +49 -14
  62. nshtrainer/config/loggers/_base/__init__.py +16 -6
  63. nshtrainer/config/loggers/csv/__init__.py +19 -7
  64. nshtrainer/config/loggers/tensorboard/__init__.py +25 -7
  65. nshtrainer/config/loggers/wandb/__init__.py +38 -10
  66. nshtrainer/config/lr_scheduler/__init__.py +50 -11
  67. nshtrainer/config/lr_scheduler/_base/__init__.py +20 -6
  68. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +34 -8
  69. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +34 -8
  70. nshtrainer/config/metrics/__init__.py +17 -6
  71. nshtrainer/config/metrics/_config/__init__.py +16 -6
  72. nshtrainer/config/model/__init__.py +32 -11
  73. nshtrainer/config/model/base/__init__.py +19 -7
  74. nshtrainer/config/model/config/__init__.py +31 -11
  75. nshtrainer/config/model/mixins/logger/__init__.py +16 -6
  76. nshtrainer/config/nn/__init__.py +70 -23
  77. nshtrainer/config/nn/mlp/__init__.py +22 -8
  78. nshtrainer/config/nn/nonlinearity/__init__.py +119 -21
  79. nshtrainer/config/optimizer/__init__.py +22 -8
  80. nshtrainer/config/profiler/__init__.py +29 -10
  81. nshtrainer/config/profiler/_base/__init__.py +18 -6
  82. nshtrainer/config/profiler/advanced/__init__.py +25 -7
  83. nshtrainer/config/profiler/pytorch/__init__.py +25 -7
  84. nshtrainer/config/profiler/simple/__init__.py +23 -7
  85. nshtrainer/config/runner/__init__.py +16 -6
  86. nshtrainer/config/trainer/_config/__init__.py +147 -29
  87. nshtrainer/config/trainer/checkpoint_connector/__init__.py +20 -6
  88. nshtrainer/config/util/_environment_info/__init__.py +88 -16
  89. nshtrainer/config/util/config/__init__.py +26 -9
  90. nshtrainer/config/util/config/dtype/__init__.py +16 -6
  91. nshtrainer/config/util/config/duration/__init__.py +28 -8
  92. nshtrainer/data/__init__.py +2 -0
  93. nshtrainer/data/balanced_batch_sampler.py +2 -0
  94. nshtrainer/data/datamodule.py +2 -0
  95. nshtrainer/data/transform.py +2 -0
  96. nshtrainer/ll/__init__.py +2 -0
  97. nshtrainer/ll/_experimental.py +2 -0
  98. nshtrainer/ll/actsave.py +2 -0
  99. nshtrainer/ll/callbacks.py +2 -0
  100. nshtrainer/ll/config.py +2 -0
  101. nshtrainer/ll/data.py +2 -0
  102. nshtrainer/ll/log.py +2 -0
  103. nshtrainer/ll/lr_scheduler.py +2 -0
  104. nshtrainer/ll/model.py +2 -0
  105. nshtrainer/ll/nn.py +2 -0
  106. nshtrainer/ll/optimizer.py +2 -0
  107. nshtrainer/ll/runner.py +2 -0
  108. nshtrainer/ll/snapshot.py +2 -0
  109. nshtrainer/ll/snoop.py +2 -0
  110. nshtrainer/ll/trainer.py +2 -0
  111. nshtrainer/ll/typecheck.py +2 -0
  112. nshtrainer/ll/util.py +2 -0
  113. nshtrainer/loggers/__init__.py +2 -0
  114. nshtrainer/loggers/_base.py +2 -0
  115. nshtrainer/loggers/csv.py +2 -0
  116. nshtrainer/loggers/tensorboard.py +2 -0
  117. nshtrainer/loggers/wandb.py +6 -4
  118. nshtrainer/lr_scheduler/__init__.py +2 -0
  119. nshtrainer/lr_scheduler/_base.py +2 -0
  120. nshtrainer/lr_scheduler/linear_warmup_cosine.py +2 -0
  121. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +2 -0
  122. nshtrainer/metrics/__init__.py +2 -0
  123. nshtrainer/metrics/_config.py +2 -0
  124. nshtrainer/model/__init__.py +2 -0
  125. nshtrainer/model/base.py +2 -0
  126. nshtrainer/model/config.py +2 -0
  127. nshtrainer/model/mixins/callback.py +2 -0
  128. nshtrainer/model/mixins/logger.py +2 -0
  129. nshtrainer/nn/__init__.py +2 -0
  130. nshtrainer/nn/mlp.py +2 -0
  131. nshtrainer/nn/module_dict.py +2 -0
  132. nshtrainer/nn/module_list.py +2 -0
  133. nshtrainer/nn/nonlinearity.py +2 -0
  134. nshtrainer/optimizer.py +2 -0
  135. nshtrainer/profiler/__init__.py +2 -0
  136. nshtrainer/profiler/_base.py +2 -0
  137. nshtrainer/profiler/advanced.py +2 -0
  138. nshtrainer/profiler/pytorch.py +2 -0
  139. nshtrainer/profiler/simple.py +2 -0
  140. nshtrainer/runner.py +2 -0
  141. nshtrainer/scripts/find_packages.py +2 -0
  142. nshtrainer/trainer/__init__.py +2 -0
  143. nshtrainer/trainer/_config.py +16 -13
  144. nshtrainer/trainer/_runtime_callback.py +2 -0
  145. nshtrainer/trainer/checkpoint_connector.py +2 -0
  146. nshtrainer/trainer/signal_connector.py +2 -0
  147. nshtrainer/trainer/trainer.py +2 -0
  148. nshtrainer/util/_environment_info.py +2 -0
  149. nshtrainer/util/bf16.py +2 -0
  150. nshtrainer/util/config/__init__.py +2 -0
  151. nshtrainer/util/config/dtype.py +2 -0
  152. nshtrainer/util/config/duration.py +2 -0
  153. nshtrainer/util/environment.py +2 -0
  154. nshtrainer/util/path.py +2 -0
  155. nshtrainer/util/seed.py +2 -0
  156. nshtrainer/util/slurm.py +3 -0
  157. nshtrainer/util/typed.py +2 -0
  158. nshtrainer/util/typing_utils.py +2 -0
  159. {nshtrainer-0.41.1.dist-info → nshtrainer-0.43.0.dist-info}/METADATA +1 -1
  160. nshtrainer-0.43.0.dist-info/RECORD +162 -0
  161. nshtrainer-0.41.1.dist-info/RECORD +0 -162
  162. {nshtrainer-0.41.1.dist-info → nshtrainer-0.43.0.dist-info}/WHEEL +0 -0
@@ -1,13 +1,33 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.directory_setup import CallbackConfigBase as CallbackConfigBase
9
- from nshtrainer.callbacks.directory_setup import DirectorySetupConfig as DirectorySetupConfig
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.directory_setup import (
11
+ CallbackConfigBase as CallbackConfigBase,
12
+ )
13
+ from nshtrainer.callbacks.directory_setup import (
14
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
15
+ )
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "DirectorySetupCallbackConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.callbacks.directory_setup"
26
+ ).DirectorySetupCallbackConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.directory_setup"
30
+ ).CallbackConfigBase
31
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
32
 
13
33
  # Submodule exports
@@ -1,14 +1,38 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
9
- from nshtrainer.callbacks.early_stopping import CallbackConfigBase as CallbackConfigBase
10
- from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
5
+ from typing import TYPE_CHECKING
11
6
 
12
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.early_stopping import (
11
+ CallbackConfigBase as CallbackConfigBase,
12
+ )
13
+ from nshtrainer.callbacks.early_stopping import (
14
+ EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
15
+ )
16
+ from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
17
+ else:
18
+
19
+ def __getattr__(name):
20
+ import importlib
21
+
22
+ if name in globals():
23
+ return globals()[name]
24
+ if name == "MetricConfig":
25
+ return importlib.import_module(
26
+ "nshtrainer.callbacks.early_stopping"
27
+ ).MetricConfig
28
+ if name == "EarlyStoppingCallbackConfig":
29
+ return importlib.import_module(
30
+ "nshtrainer.callbacks.early_stopping"
31
+ ).EarlyStoppingCallbackConfig
32
+ if name == "CallbackConfigBase":
33
+ return importlib.import_module(
34
+ "nshtrainer.callbacks.early_stopping"
35
+ ).CallbackConfigBase
36
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
13
37
 
14
38
  # Submodule exports
@@ -1,13 +1,27 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.ema import EMAConfig as EMAConfig
9
- from nshtrainer.callbacks.ema import CallbackConfigBase as CallbackConfigBase
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.ema import CallbackConfigBase as CallbackConfigBase
11
+ from nshtrainer.callbacks.ema import EMACallbackConfig as EMACallbackConfig
12
+ else:
13
+
14
+ def __getattr__(name):
15
+ import importlib
16
+
17
+ if name in globals():
18
+ return globals()[name]
19
+ if name == "EMACallbackConfig":
20
+ return importlib.import_module("nshtrainer.callbacks.ema").EMACallbackConfig
21
+ if name == "CallbackConfigBase":
22
+ return importlib.import_module(
23
+ "nshtrainer.callbacks.ema"
24
+ ).CallbackConfigBase
25
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
26
 
13
27
  # Submodule exports
@@ -1,13 +1,33 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.finite_checks import CallbackConfigBase as CallbackConfigBase
9
- from nshtrainer.callbacks.finite_checks import FiniteChecksConfig as FiniteChecksConfig
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.finite_checks import (
11
+ CallbackConfigBase as CallbackConfigBase,
12
+ )
13
+ from nshtrainer.callbacks.finite_checks import (
14
+ FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
15
+ )
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "FiniteChecksCallbackConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.callbacks.finite_checks"
26
+ ).FiniteChecksCallbackConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.finite_checks"
30
+ ).CallbackConfigBase
31
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
32
 
13
33
  # Submodule exports
@@ -1,13 +1,33 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.gradient_skipping import CallbackConfigBase as CallbackConfigBase
9
- from nshtrainer.callbacks.gradient_skipping import GradientSkippingConfig as GradientSkippingConfig
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.gradient_skipping import (
11
+ CallbackConfigBase as CallbackConfigBase,
12
+ )
13
+ from nshtrainer.callbacks.gradient_skipping import (
14
+ GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
15
+ )
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "GradientSkippingCallbackConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.callbacks.gradient_skipping"
26
+ ).GradientSkippingCallbackConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.gradient_skipping"
30
+ ).CallbackConfigBase
31
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
32
 
13
33
  # Submodule exports
@@ -1,13 +1,33 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.norm_logging import CallbackConfigBase as CallbackConfigBase
9
- from nshtrainer.callbacks.norm_logging import NormLoggingConfig as NormLoggingConfig
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.norm_logging import (
11
+ CallbackConfigBase as CallbackConfigBase,
12
+ )
13
+ from nshtrainer.callbacks.norm_logging import (
14
+ NormLoggingCallbackConfig as NormLoggingCallbackConfig,
15
+ )
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "NormLoggingCallbackConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.callbacks.norm_logging"
26
+ ).NormLoggingCallbackConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.norm_logging"
30
+ ).CallbackConfigBase
31
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
32
 
13
33
  # Submodule exports
@@ -1,13 +1,33 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
9
- from nshtrainer.callbacks.print_table import CallbackConfigBase as CallbackConfigBase
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.print_table import (
11
+ CallbackConfigBase as CallbackConfigBase,
12
+ )
13
+ from nshtrainer.callbacks.print_table import (
14
+ PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
15
+ )
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "PrintTableMetricsCallbackConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.callbacks.print_table"
26
+ ).PrintTableMetricsCallbackConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.print_table"
30
+ ).CallbackConfigBase
31
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
32
 
13
33
  # Submodule exports
@@ -1,13 +1,33 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.rlp_sanity_checks import RLPSanityChecksConfig as RLPSanityChecksConfig
9
- from nshtrainer.callbacks.rlp_sanity_checks import CallbackConfigBase as CallbackConfigBase
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.rlp_sanity_checks import (
11
+ CallbackConfigBase as CallbackConfigBase,
12
+ )
13
+ from nshtrainer.callbacks.rlp_sanity_checks import (
14
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
15
+ )
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "RLPSanityChecksCallbackConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.callbacks.rlp_sanity_checks"
26
+ ).RLPSanityChecksCallbackConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.rlp_sanity_checks"
30
+ ).CallbackConfigBase
31
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
32
 
13
33
  # Submodule exports
@@ -1,13 +1,33 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.shared_parameters import SharedParametersConfig as SharedParametersConfig
9
- from nshtrainer.callbacks.shared_parameters import CallbackConfigBase as CallbackConfigBase
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.shared_parameters import (
11
+ CallbackConfigBase as CallbackConfigBase,
12
+ )
13
+ from nshtrainer.callbacks.shared_parameters import (
14
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
15
+ )
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "SharedParametersCallbackConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.callbacks.shared_parameters"
26
+ ).SharedParametersCallbackConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.shared_parameters"
30
+ ).CallbackConfigBase
31
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
32
 
13
33
  # Submodule exports
@@ -1,13 +1,33 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.throughput_monitor import CallbackConfigBase as CallbackConfigBase
9
- from nshtrainer.callbacks.throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.throughput_monitor import (
11
+ CallbackConfigBase as CallbackConfigBase,
12
+ )
13
+ from nshtrainer.callbacks.throughput_monitor import (
14
+ ThroughputMonitorConfig as ThroughputMonitorConfig,
15
+ )
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "ThroughputMonitorConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.callbacks.throughput_monitor"
26
+ ).ThroughputMonitorConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.throughput_monitor"
30
+ ).CallbackConfigBase
31
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
32
 
13
33
  # Submodule exports
@@ -1,13 +1,31 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.timer import CallbackConfigBase as CallbackConfigBase
9
- from nshtrainer.callbacks.timer import EpochTimerConfig as EpochTimerConfig
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.timer import CallbackConfigBase as CallbackConfigBase
11
+ from nshtrainer.callbacks.timer import (
12
+ EpochTimerCallbackConfig as EpochTimerCallbackConfig,
13
+ )
14
+ else:
15
+
16
+ def __getattr__(name):
17
+ import importlib
18
+
19
+ if name in globals():
20
+ return globals()[name]
21
+ if name == "EpochTimerCallbackConfig":
22
+ return importlib.import_module(
23
+ "nshtrainer.callbacks.timer"
24
+ ).EpochTimerCallbackConfig
25
+ if name == "CallbackConfigBase":
26
+ return importlib.import_module(
27
+ "nshtrainer.callbacks.timer"
28
+ ).CallbackConfigBase
29
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
30
 
13
31
  # Submodule exports
@@ -1,13 +1,33 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.wandb_upload_code import WandbUploadCodeConfig as WandbUploadCodeConfig
9
- from nshtrainer.callbacks.wandb_upload_code import CallbackConfigBase as CallbackConfigBase
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.wandb_upload_code import (
11
+ CallbackConfigBase as CallbackConfigBase,
12
+ )
13
+ from nshtrainer.callbacks.wandb_upload_code import (
14
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
15
+ )
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "WandbUploadCodeCallbackConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.callbacks.wandb_upload_code"
26
+ ).WandbUploadCodeCallbackConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.wandb_upload_code"
30
+ ).CallbackConfigBase
31
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
32
 
13
33
  # Submodule exports
@@ -1,13 +1,33 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.callbacks.wandb_watch import CallbackConfigBase as CallbackConfigBase
9
- from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.callbacks.wandb_watch import (
11
+ CallbackConfigBase as CallbackConfigBase,
12
+ )
13
+ from nshtrainer.callbacks.wandb_watch import (
14
+ WandbWatchCallbackConfig as WandbWatchCallbackConfig,
15
+ )
16
+ else:
17
+
18
+ def __getattr__(name):
19
+ import importlib
20
+
21
+ if name in globals():
22
+ return globals()[name]
23
+ if name == "WandbWatchCallbackConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.callbacks.wandb_watch"
26
+ ).WandbWatchCallbackConfig
27
+ if name == "CallbackConfigBase":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.wandb_watch"
30
+ ).CallbackConfigBase
31
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
32
 
13
33
  # Submodule exports
@@ -1,20 +1,55 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
9
- from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
10
- from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
11
- from nshtrainer.loggers.wandb import WandbUploadCodeConfig as WandbUploadCodeConfig
12
- from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
13
- from nshtrainer.loggers.wandb import WandbWatchConfig as WandbWatchConfig
14
- from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
15
-
16
- # Type aliases
17
- from nshtrainer.loggers import LoggerConfig as LoggerConfig
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
11
+ from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
12
+ from nshtrainer.loggers import LoggerConfig as LoggerConfig
13
+ from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
14
+ from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
15
+ from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
16
+ from nshtrainer.loggers.wandb import (
17
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
18
+ )
19
+ from nshtrainer.loggers.wandb import (
20
+ WandbWatchCallbackConfig as WandbWatchCallbackConfig,
21
+ )
22
+ else:
23
+
24
+ def __getattr__(name):
25
+ import importlib
26
+
27
+ if name in globals():
28
+ return globals()[name]
29
+ if name == "BaseLoggerConfig":
30
+ return importlib.import_module("nshtrainer.loggers").BaseLoggerConfig
31
+ if name == "TensorboardLoggerConfig":
32
+ return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
33
+ if name == "WandbLoggerConfig":
34
+ return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
35
+ if name == "WandbUploadCodeCallbackConfig":
36
+ return importlib.import_module(
37
+ "nshtrainer.loggers.wandb"
38
+ ).WandbUploadCodeCallbackConfig
39
+ if name == "WandbWatchCallbackConfig":
40
+ return importlib.import_module(
41
+ "nshtrainer.loggers.wandb"
42
+ ).WandbWatchCallbackConfig
43
+ if name == "CallbackConfigBase":
44
+ return importlib.import_module(
45
+ "nshtrainer.loggers.wandb"
46
+ ).CallbackConfigBase
47
+ if name == "CSVLoggerConfig":
48
+ return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
49
+ if name == "LoggerConfig":
50
+ return importlib.import_module("nshtrainer.loggers").LoggerConfig
51
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
52
+
18
53
 
19
54
  # Submodule exports
20
55
  from . import _base as _base
@@ -1,12 +1,22 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
5
+ from typing import TYPE_CHECKING
9
6
 
10
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
11
+ else:
12
+
13
+ def __getattr__(name):
14
+ import importlib
15
+
16
+ if name in globals():
17
+ return globals()[name]
18
+ if name == "BaseLoggerConfig":
19
+ return importlib.import_module("nshtrainer.loggers._base").BaseLoggerConfig
20
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
11
21
 
12
22
  # Submodule exports
@@ -1,13 +1,25 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
1
+ from __future__ import annotations
4
2
 
5
3
  __codegen__ = True
6
4
 
7
- # Config classes
8
- from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
9
- from nshtrainer.loggers.csv import BaseLoggerConfig as BaseLoggerConfig
5
+ from typing import TYPE_CHECKING
10
6
 
11
- # Type aliases
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.loggers.csv import BaseLoggerConfig as BaseLoggerConfig
11
+ from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
12
+ else:
13
+
14
+ def __getattr__(name):
15
+ import importlib
16
+
17
+ if name in globals():
18
+ return globals()[name]
19
+ if name == "CSVLoggerConfig":
20
+ return importlib.import_module("nshtrainer.loggers.csv").CSVLoggerConfig
21
+ if name == "BaseLoggerConfig":
22
+ return importlib.import_module("nshtrainer.loggers.csv").BaseLoggerConfig
23
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
24
 
13
25
  # Submodule exports