nshtrainer 0.41.0__py3-none-any.whl → 0.42.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/config/__init__.py +406 -95
  3. nshtrainer/config/_checkpoint/loader/__init__.py +55 -13
  4. nshtrainer/config/_checkpoint/metadata/__init__.py +22 -8
  5. nshtrainer/config/_directory/__init__.py +21 -9
  6. nshtrainer/config/_hf_hub/__init__.py +25 -9
  7. nshtrainer/config/callbacks/__init__.py +114 -29
  8. nshtrainer/config/callbacks/actsave/__init__.py +20 -8
  9. nshtrainer/config/callbacks/base/__init__.py +17 -7
  10. nshtrainer/config/callbacks/checkpoint/__init__.py +62 -13
  11. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +33 -9
  12. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +40 -10
  13. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +33 -9
  14. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +26 -8
  15. nshtrainer/config/callbacks/debug_flag/__init__.py +24 -8
  16. nshtrainer/config/callbacks/directory_setup/__init__.py +26 -8
  17. nshtrainer/config/callbacks/early_stopping/__init__.py +31 -9
  18. nshtrainer/config/callbacks/ema/__init__.py +20 -8
  19. nshtrainer/config/callbacks/finite_checks/__init__.py +26 -8
  20. nshtrainer/config/callbacks/gradient_skipping/__init__.py +26 -8
  21. nshtrainer/config/callbacks/norm_logging/__init__.py +24 -8
  22. nshtrainer/config/callbacks/print_table/__init__.py +26 -8
  23. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +26 -8
  24. nshtrainer/config/callbacks/shared_parameters/__init__.py +26 -8
  25. nshtrainer/config/callbacks/throughput_monitor/__init__.py +26 -8
  26. nshtrainer/config/callbacks/timer/__init__.py +22 -8
  27. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +26 -8
  28. nshtrainer/config/callbacks/wandb_watch/__init__.py +24 -8
  29. nshtrainer/config/loggers/__init__.py +41 -14
  30. nshtrainer/config/loggers/_base/__init__.py +15 -7
  31. nshtrainer/config/loggers/csv/__init__.py +18 -8
  32. nshtrainer/config/loggers/tensorboard/__init__.py +24 -8
  33. nshtrainer/config/loggers/wandb/__init__.py +31 -11
  34. nshtrainer/config/lr_scheduler/__init__.py +49 -12
  35. nshtrainer/config/lr_scheduler/_base/__init__.py +19 -7
  36. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +33 -9
  37. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +33 -9
  38. nshtrainer/config/metrics/__init__.py +16 -7
  39. nshtrainer/config/metrics/_config/__init__.py +15 -7
  40. nshtrainer/config/model/__init__.py +31 -12
  41. nshtrainer/config/model/base/__init__.py +18 -8
  42. nshtrainer/config/model/config/__init__.py +30 -12
  43. nshtrainer/config/model/mixins/logger/__init__.py +15 -7
  44. nshtrainer/config/nn/__init__.py +68 -23
  45. nshtrainer/config/nn/mlp/__init__.py +21 -9
  46. nshtrainer/config/nn/nonlinearity/__init__.py +118 -22
  47. nshtrainer/config/optimizer/__init__.py +21 -9
  48. nshtrainer/config/profiler/__init__.py +28 -11
  49. nshtrainer/config/profiler/_base/__init__.py +17 -7
  50. nshtrainer/config/profiler/advanced/__init__.py +24 -8
  51. nshtrainer/config/profiler/pytorch/__init__.py +24 -8
  52. nshtrainer/config/profiler/simple/__init__.py +22 -8
  53. nshtrainer/config/runner/__init__.py +15 -7
  54. nshtrainer/config/trainer/_config/__init__.py +144 -30
  55. nshtrainer/config/trainer/checkpoint_connector/__init__.py +19 -7
  56. nshtrainer/config/util/_environment_info/__init__.py +87 -17
  57. nshtrainer/config/util/config/__init__.py +25 -10
  58. nshtrainer/config/util/config/dtype/__init__.py +15 -7
  59. nshtrainer/config/util/config/duration/__init__.py +27 -9
  60. {nshtrainer-0.41.0.dist-info → nshtrainer-0.42.0.dist-info}/METADATA +1 -1
  61. {nshtrainer-0.41.0.dist-info → nshtrainer-0.42.0.dist-info}/RECORD +62 -62
  62. {nshtrainer-0.41.0.dist-info → nshtrainer-0.42.0.dist-info}/WHEEL +0 -0
@@ -1,16 +1,36 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.loggers.wandb import BaseLoggerConfig as BaseLoggerConfig
9
- from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
10
- from nshtrainer.loggers.wandb import WandbUploadCodeConfig as WandbUploadCodeConfig
11
- from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
12
- from nshtrainer.loggers.wandb import WandbWatchConfig as WandbWatchConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.loggers.wandb import BaseLoggerConfig as BaseLoggerConfig
9
+ from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
10
+ from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
11
+ from nshtrainer.loggers.wandb import WandbUploadCodeConfig as WandbUploadCodeConfig
12
+ from nshtrainer.loggers.wandb import WandbWatchConfig as WandbWatchConfig
13
+ else:
14
+
15
+ def __getattr__(name):
16
+ import importlib
13
17
 
14
- # Type aliases
18
+ if name in globals():
19
+ return globals()[name]
20
+ if name == "CallbackConfigBase":
21
+ return importlib.import_module(
22
+ "nshtrainer.loggers.wandb"
23
+ ).CallbackConfigBase
24
+ if name == "WandbLoggerConfig":
25
+ return importlib.import_module("nshtrainer.loggers.wandb").WandbLoggerConfig
26
+ if name == "WandbUploadCodeConfig":
27
+ return importlib.import_module(
28
+ "nshtrainer.loggers.wandb"
29
+ ).WandbUploadCodeConfig
30
+ if name == "WandbWatchConfig":
31
+ return importlib.import_module("nshtrainer.loggers.wandb").WandbWatchConfig
32
+ if name == "BaseLoggerConfig":
33
+ return importlib.import_module("nshtrainer.loggers.wandb").BaseLoggerConfig
34
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
15
35
 
16
36
  # Submodule exports
@@ -1,18 +1,55 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
9
- from nshtrainer.lr_scheduler import LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig
10
- from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
11
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.lr_scheduler import (
9
+ LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
10
+ )
11
+ from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
12
+ from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
13
+ from nshtrainer.lr_scheduler import (
14
+ ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
15
+ )
16
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
17
+ DurationConfig as DurationConfig,
18
+ )
19
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
20
+ MetricConfig as MetricConfig,
21
+ )
22
+ else:
23
+
24
+ def __getattr__(name):
25
+ import importlib
26
+
27
+ if name in globals():
28
+ return globals()[name]
29
+ if name == "LRSchedulerConfigBase":
30
+ return importlib.import_module(
31
+ "nshtrainer.lr_scheduler"
32
+ ).LRSchedulerConfigBase
33
+ if name == "LinearWarmupCosineDecayLRSchedulerConfig":
34
+ return importlib.import_module(
35
+ "nshtrainer.lr_scheduler"
36
+ ).LinearWarmupCosineDecayLRSchedulerConfig
37
+ if name == "MetricConfig":
38
+ return importlib.import_module(
39
+ "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
40
+ ).MetricConfig
41
+ if name == "ReduceLROnPlateauConfig":
42
+ return importlib.import_module(
43
+ "nshtrainer.lr_scheduler"
44
+ ).ReduceLROnPlateauConfig
45
+ if name == "DurationConfig":
46
+ return importlib.import_module(
47
+ "nshtrainer.lr_scheduler.linear_warmup_cosine"
48
+ ).DurationConfig
49
+ if name == "LRSchedulerConfig":
50
+ return importlib.import_module("nshtrainer.lr_scheduler").LRSchedulerConfig
51
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
52
 
13
- # Type aliases
14
- from nshtrainer.lr_scheduler.linear_warmup_cosine import DurationConfig as DurationConfig
15
- from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
16
53
 
17
54
  # Submodule exports
18
55
  from . import _base as _base
@@ -1,12 +1,24 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.lr_scheduler._base import (
9
+ LRSchedulerConfigBase as LRSchedulerConfigBase,
10
+ )
11
+ else:
12
+
13
+ def __getattr__(name):
14
+ import importlib
9
15
 
10
- # Type aliases
16
+ if name in globals():
17
+ return globals()[name]
18
+ if name == "LRSchedulerConfigBase":
19
+ return importlib.import_module(
20
+ "nshtrainer.lr_scheduler._base"
21
+ ).LRSchedulerConfigBase
22
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
11
23
 
12
24
  # Submodule exports
@@ -1,14 +1,38 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.lr_scheduler.linear_warmup_cosine import LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig
9
- from nshtrainer.lr_scheduler.linear_warmup_cosine import LRSchedulerConfigBase as LRSchedulerConfigBase
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
9
+ DurationConfig as DurationConfig,
10
+ )
11
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
12
+ LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
13
+ )
14
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
15
+ LRSchedulerConfigBase as LRSchedulerConfigBase,
16
+ )
17
+ else:
18
+
19
+ def __getattr__(name):
20
+ import importlib
10
21
 
11
- # Type aliases
12
- from nshtrainer.lr_scheduler.linear_warmup_cosine import DurationConfig as DurationConfig
22
+ if name in globals():
23
+ return globals()[name]
24
+ if name == "LinearWarmupCosineDecayLRSchedulerConfig":
25
+ return importlib.import_module(
26
+ "nshtrainer.lr_scheduler.linear_warmup_cosine"
27
+ ).LinearWarmupCosineDecayLRSchedulerConfig
28
+ if name == "LRSchedulerConfigBase":
29
+ return importlib.import_module(
30
+ "nshtrainer.lr_scheduler.linear_warmup_cosine"
31
+ ).LRSchedulerConfigBase
32
+ if name == "DurationConfig":
33
+ return importlib.import_module(
34
+ "nshtrainer.lr_scheduler.linear_warmup_cosine"
35
+ ).DurationConfig
36
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
13
37
 
14
38
  # Submodule exports
@@ -1,14 +1,38 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
9
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import LRSchedulerConfigBase as LRSchedulerConfigBase
10
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
9
+ LRSchedulerConfigBase as LRSchedulerConfigBase,
10
+ )
11
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
12
+ MetricConfig as MetricConfig,
13
+ )
14
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
15
+ ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
16
+ )
17
+ else:
18
+
19
+ def __getattr__(name):
20
+ import importlib
11
21
 
12
- # Type aliases
22
+ if name in globals():
23
+ return globals()[name]
24
+ if name == "MetricConfig":
25
+ return importlib.import_module(
26
+ "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
27
+ ).MetricConfig
28
+ if name == "ReduceLROnPlateauConfig":
29
+ return importlib.import_module(
30
+ "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
31
+ ).ReduceLROnPlateauConfig
32
+ if name == "LRSchedulerConfigBase":
33
+ return importlib.import_module(
34
+ "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
35
+ ).LRSchedulerConfigBase
36
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
13
37
 
14
38
  # Submodule exports
@@ -1,13 +1,22 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.metrics import MetricConfig as MetricConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.metrics import MetricConfig as MetricConfig
9
+ else:
10
+
11
+ def __getattr__(name):
12
+ import importlib
13
+
14
+ if name in globals():
15
+ return globals()[name]
16
+ if name == "MetricConfig":
17
+ return importlib.import_module("nshtrainer.metrics").MetricConfig
18
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
9
19
 
10
- # Type aliases
11
20
 
12
21
  # Submodule exports
13
22
  from . import _config as _config
@@ -1,12 +1,20 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.metrics._config import MetricConfig as MetricConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.metrics._config import MetricConfig as MetricConfig
9
+ else:
10
+
11
+ def __getattr__(name):
12
+ import importlib
9
13
 
10
- # Type aliases
14
+ if name in globals():
15
+ return globals()[name]
16
+ if name == "MetricConfig":
17
+ return importlib.import_module("nshtrainer.metrics._config").MetricConfig
18
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
11
19
 
12
20
  # Submodule exports
@@ -1,18 +1,37 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.model import MetricConfig as MetricConfig
9
- from nshtrainer.model import BaseConfig as BaseConfig
10
- from nshtrainer.model import DirectoryConfig as DirectoryConfig
11
- from nshtrainer.model import TrainerConfig as TrainerConfig
12
- from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
13
- from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.model import BaseConfig as BaseConfig
9
+ from nshtrainer.model import DirectoryConfig as DirectoryConfig
10
+ from nshtrainer.model import MetricConfig as MetricConfig
11
+ from nshtrainer.model import TrainerConfig as TrainerConfig
12
+ from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
13
+ from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
14
+ else:
15
+
16
+ def __getattr__(name):
17
+ import importlib
18
+
19
+ if name in globals():
20
+ return globals()[name]
21
+ if name == "CallbackConfigBase":
22
+ return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
23
+ if name == "TrainerConfig":
24
+ return importlib.import_module("nshtrainer.model").TrainerConfig
25
+ if name == "EnvironmentConfig":
26
+ return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
27
+ if name == "BaseConfig":
28
+ return importlib.import_module("nshtrainer.model").BaseConfig
29
+ if name == "MetricConfig":
30
+ return importlib.import_module("nshtrainer.model").MetricConfig
31
+ if name == "DirectoryConfig":
32
+ return importlib.import_module("nshtrainer.model").DirectoryConfig
33
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
14
34
 
15
- # Type aliases
16
35
 
17
36
  # Submodule exports
18
37
  from . import base as base
@@ -1,13 +1,23 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
9
- from nshtrainer.model.base import BaseConfig as BaseConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.model.base import BaseConfig as BaseConfig
9
+ from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
10
+ else:
11
+
12
+ def __getattr__(name):
13
+ import importlib
10
14
 
11
- # Type aliases
15
+ if name in globals():
16
+ return globals()[name]
17
+ if name == "EnvironmentConfig":
18
+ return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
19
+ if name == "BaseConfig":
20
+ return importlib.import_module("nshtrainer.model.base").BaseConfig
21
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
12
22
 
13
23
  # Submodule exports
@@ -1,17 +1,35 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.model.config import MetricConfig as MetricConfig
9
- from nshtrainer.model.config import BaseConfig as BaseConfig
10
- from nshtrainer.model.config import DirectoryConfig as DirectoryConfig
11
- from nshtrainer.model.config import TrainerConfig as TrainerConfig
12
- from nshtrainer.model.config import EnvironmentConfig as EnvironmentConfig
13
- from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.model.config import BaseConfig as BaseConfig
9
+ from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
10
+ from nshtrainer.model.config import DirectoryConfig as DirectoryConfig
11
+ from nshtrainer.model.config import EnvironmentConfig as EnvironmentConfig
12
+ from nshtrainer.model.config import MetricConfig as MetricConfig
13
+ from nshtrainer.model.config import TrainerConfig as TrainerConfig
14
+ else:
15
+
16
+ def __getattr__(name):
17
+ import importlib
14
18
 
15
- # Type aliases
19
+ if name in globals():
20
+ return globals()[name]
21
+ if name == "CallbackConfigBase":
22
+ return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
23
+ if name == "TrainerConfig":
24
+ return importlib.import_module("nshtrainer.model.config").TrainerConfig
25
+ if name == "EnvironmentConfig":
26
+ return importlib.import_module("nshtrainer.model.config").EnvironmentConfig
27
+ if name == "BaseConfig":
28
+ return importlib.import_module("nshtrainer.model.config").BaseConfig
29
+ if name == "MetricConfig":
30
+ return importlib.import_module("nshtrainer.model.config").MetricConfig
31
+ if name == "DirectoryConfig":
32
+ return importlib.import_module("nshtrainer.model.config").DirectoryConfig
33
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
16
34
 
17
35
  # Submodule exports
@@ -1,12 +1,20 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.model.mixins.logger import BaseConfig as BaseConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.model.mixins.logger import BaseConfig as BaseConfig
9
+ else:
10
+
11
+ def __getattr__(name):
12
+ import importlib
9
13
 
10
- # Type aliases
14
+ if name in globals():
15
+ return globals()[name]
16
+ if name == "BaseConfig":
17
+ return importlib.import_module("nshtrainer.model.mixins.logger").BaseConfig
18
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
11
19
 
12
20
  # Submodule exports
@@ -1,29 +1,74 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.nn import MLPConfig as MLPConfig
9
- from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
10
- from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
11
- from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
12
- from nshtrainer.nn import PReLUConfig as PReLUConfig
13
- from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
14
- from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
15
- from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
16
- from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
17
- from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
18
- from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
19
- from nshtrainer.nn.nonlinearity import SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig
20
- from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
21
- from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
22
- from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
23
- from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
9
+ from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
10
+ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
11
+ from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
12
+ from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
13
+ from nshtrainer.nn import MLPConfig as MLPConfig
14
+ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
15
+ from nshtrainer.nn import PReLUConfig as PReLUConfig
16
+ from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
17
+ from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
18
+ from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
19
+ from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
20
+ from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
21
+ from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
22
+ from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
23
+ from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
24
+ from nshtrainer.nn.nonlinearity import (
25
+ SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
26
+ )
27
+ else:
28
+
29
+ def __getattr__(name):
30
+ import importlib
31
+
32
+ if name in globals():
33
+ return globals()[name]
34
+ if name == "BaseNonlinearityConfig":
35
+ return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
36
+ if name == "MLPConfig":
37
+ return importlib.import_module("nshtrainer.nn").MLPConfig
38
+ if name == "SwiGLUNonlinearityConfig":
39
+ return importlib.import_module(
40
+ "nshtrainer.nn.nonlinearity"
41
+ ).SwiGLUNonlinearityConfig
42
+ if name == "ReLUNonlinearityConfig":
43
+ return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
44
+ if name == "SiLUNonlinearityConfig":
45
+ return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
46
+ if name == "ELUNonlinearityConfig":
47
+ return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
48
+ if name == "GELUNonlinearityConfig":
49
+ return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
50
+ if name == "SoftplusNonlinearityConfig":
51
+ return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
52
+ if name == "SoftsignNonlinearityConfig":
53
+ return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
54
+ if name == "SwishNonlinearityConfig":
55
+ return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
56
+ if name == "SoftmaxNonlinearityConfig":
57
+ return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
58
+ if name == "MishNonlinearityConfig":
59
+ return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
60
+ if name == "SigmoidNonlinearityConfig":
61
+ return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
62
+ if name == "TanhNonlinearityConfig":
63
+ return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
64
+ if name == "PReLUConfig":
65
+ return importlib.import_module("nshtrainer.nn").PReLUConfig
66
+ if name == "LeakyReLUNonlinearityConfig":
67
+ return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
68
+ if name == "NonlinearityConfig":
69
+ return importlib.import_module("nshtrainer.nn").NonlinearityConfig
70
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
24
71
 
25
- # Type aliases
26
- from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
27
72
 
28
73
  # Submodule exports
29
74
  from . import mlp as mlp
@@ -1,14 +1,26 @@
1
- # fmt: off
2
- # ruff: noqa
3
- # type: ignore
4
-
5
1
  __codegen__ = True
6
2
 
7
- # Config classes
8
- from nshtrainer.nn.mlp import MLPConfig as MLPConfig
9
- from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Config/alias imports
6
+
7
+ if TYPE_CHECKING:
8
+ from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
9
+ from nshtrainer.nn.mlp import MLPConfig as MLPConfig
10
+ from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
11
+ else:
12
+
13
+ def __getattr__(name):
14
+ import importlib
10
15
 
11
- # Type aliases
12
- from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
16
+ if name in globals():
17
+ return globals()[name]
18
+ if name == "BaseNonlinearityConfig":
19
+ return importlib.import_module("nshtrainer.nn.mlp").BaseNonlinearityConfig
20
+ if name == "MLPConfig":
21
+ return importlib.import_module("nshtrainer.nn.mlp").MLPConfig
22
+ if name == "NonlinearityConfig":
23
+ return importlib.import_module("nshtrainer.nn.mlp").NonlinearityConfig
24
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
13
25
 
14
26
  # Submodule exports