nshtrainer 1.0.0b14__py3-none-any.whl → 1.0.0b16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. nshtrainer/configs/__init__.py +134 -405
  2. nshtrainer/configs/_checkpoint/__init__.py +2 -25
  3. nshtrainer/configs/_checkpoint/metadata/__init__.py +2 -25
  4. nshtrainer/configs/_directory/__init__.py +5 -28
  5. nshtrainer/configs/_hf_hub/__init__.py +5 -28
  6. nshtrainer/configs/callbacks/__init__.py +52 -161
  7. nshtrainer/configs/callbacks/actsave/__init__.py +2 -23
  8. nshtrainer/configs/callbacks/base/__init__.py +1 -20
  9. nshtrainer/configs/callbacks/checkpoint/__init__.py +19 -64
  10. nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +9 -36
  11. nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +10 -43
  12. nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +9 -36
  13. nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -29
  14. nshtrainer/configs/callbacks/debug_flag/__init__.py +4 -27
  15. nshtrainer/configs/callbacks/directory_setup/__init__.py +6 -29
  16. nshtrainer/configs/callbacks/early_stopping/__init__.py +5 -34
  17. nshtrainer/configs/callbacks/ema/__init__.py +2 -23
  18. nshtrainer/configs/callbacks/finite_checks/__init__.py +4 -29
  19. nshtrainer/configs/callbacks/gradient_skipping/__init__.py +6 -29
  20. nshtrainer/configs/callbacks/log_epoch/__init__.py +4 -27
  21. nshtrainer/configs/callbacks/lr_monitor/__init__.py +4 -27
  22. nshtrainer/configs/callbacks/norm_logging/__init__.py +4 -29
  23. nshtrainer/configs/callbacks/print_table/__init__.py +4 -29
  24. nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +6 -29
  25. nshtrainer/configs/callbacks/shared_parameters/__init__.py +6 -29
  26. nshtrainer/configs/callbacks/timer/__init__.py +4 -27
  27. nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +6 -29
  28. nshtrainer/configs/callbacks/wandb_watch/__init__.py +4 -29
  29. nshtrainer/configs/loggers/__init__.py +13 -52
  30. nshtrainer/configs/loggers/_base/__init__.py +1 -18
  31. nshtrainer/configs/loggers/actsave/__init__.py +2 -25
  32. nshtrainer/configs/loggers/csv/__init__.py +2 -21
  33. nshtrainer/configs/loggers/tensorboard/__init__.py +4 -27
  34. nshtrainer/configs/loggers/wandb/__init__.py +9 -40
  35. nshtrainer/configs/lr_scheduler/__init__.py +10 -51
  36. nshtrainer/configs/lr_scheduler/_base/__init__.py +1 -22
  37. nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +9 -36
  38. nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +7 -36
  39. nshtrainer/configs/metrics/__init__.py +1 -18
  40. nshtrainer/configs/metrics/_config/__init__.py +1 -18
  41. nshtrainer/configs/nn/__init__.py +19 -70
  42. nshtrainer/configs/nn/mlp/__init__.py +3 -24
  43. nshtrainer/configs/nn/nonlinearity/__init__.py +30 -121
  44. nshtrainer/configs/optimizer/__init__.py +3 -24
  45. nshtrainer/configs/profiler/__init__.py +5 -30
  46. nshtrainer/configs/profiler/_base/__init__.py +1 -20
  47. nshtrainer/configs/profiler/advanced/__init__.py +4 -27
  48. nshtrainer/configs/profiler/pytorch/__init__.py +2 -27
  49. nshtrainer/configs/profiler/simple/__init__.py +2 -25
  50. nshtrainer/configs/trainer/__init__.py +50 -169
  51. nshtrainer/configs/trainer/_config/__init__.py +50 -169
  52. nshtrainer/configs/trainer/trainer/__init__.py +2 -23
  53. nshtrainer/configs/util/__init__.py +33 -102
  54. nshtrainer/configs/util/_environment_info/__init__.py +29 -90
  55. nshtrainer/configs/util/config/__init__.py +4 -27
  56. nshtrainer/configs/util/config/dtype/__init__.py +1 -18
  57. nshtrainer/configs/util/config/duration/__init__.py +3 -30
  58. nshtrainer/trainer/_config.py +42 -10
  59. {nshtrainer-1.0.0b14.dist-info → nshtrainer-1.0.0b16.dist-info}/METADATA +1 -1
  60. {nshtrainer-1.0.0b14.dist-info → nshtrainer-1.0.0b16.dist-info}/RECORD +61 -61
  61. {nshtrainer-1.0.0b14.dist-info → nshtrainer-1.0.0b16.dist-info}/WHEEL +0 -0
@@ -2,30 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.loggers.tensorboard import BaseLoggerConfig as BaseLoggerConfig
11
- from nshtrainer.loggers.tensorboard import (
12
- TensorboardLoggerConfig as TensorboardLoggerConfig,
13
- )
14
- else:
15
-
16
- def __getattr__(name):
17
- import importlib
18
-
19
- if name in globals():
20
- return globals()[name]
21
- if name == "BaseLoggerConfig":
22
- return importlib.import_module(
23
- "nshtrainer.loggers.tensorboard"
24
- ).BaseLoggerConfig
25
- if name == "TensorboardLoggerConfig":
26
- return importlib.import_module(
27
- "nshtrainer.loggers.tensorboard"
28
- ).TensorboardLoggerConfig
29
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
-
31
- # Submodule exports
5
+ from nshtrainer.loggers.tensorboard import BaseLoggerConfig as BaseLoggerConfig
6
+ from nshtrainer.loggers.tensorboard import (
7
+ TensorboardLoggerConfig as TensorboardLoggerConfig,
8
+ )
@@ -2,43 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.loggers.wandb import BaseLoggerConfig as BaseLoggerConfig
11
- from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
12
- from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
13
- from nshtrainer.loggers.wandb import (
14
- WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
15
- )
16
- from nshtrainer.loggers.wandb import (
17
- WandbWatchCallbackConfig as WandbWatchCallbackConfig,
18
- )
19
- else:
20
-
21
- def __getattr__(name):
22
- import importlib
23
-
24
- if name in globals():
25
- return globals()[name]
26
- if name == "BaseLoggerConfig":
27
- return importlib.import_module("nshtrainer.loggers.wandb").BaseLoggerConfig
28
- if name == "CallbackConfigBase":
29
- return importlib.import_module(
30
- "nshtrainer.loggers.wandb"
31
- ).CallbackConfigBase
32
- if name == "WandbLoggerConfig":
33
- return importlib.import_module("nshtrainer.loggers.wandb").WandbLoggerConfig
34
- if name == "WandbUploadCodeCallbackConfig":
35
- return importlib.import_module(
36
- "nshtrainer.loggers.wandb"
37
- ).WandbUploadCodeCallbackConfig
38
- if name == "WandbWatchCallbackConfig":
39
- return importlib.import_module(
40
- "nshtrainer.loggers.wandb"
41
- ).WandbWatchCallbackConfig
42
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
43
-
44
- # Submodule exports
5
+ from nshtrainer.loggers.wandb import BaseLoggerConfig as BaseLoggerConfig
6
+ from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
7
+ from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
8
+ from nshtrainer.loggers.wandb import (
9
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
10
+ )
11
+ from nshtrainer.loggers.wandb import (
12
+ WandbWatchCallbackConfig as WandbWatchCallbackConfig,
13
+ )
@@ -2,58 +2,17 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
5
+ from nshtrainer.lr_scheduler import (
6
+ LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
7
+ )
8
+ from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
9
+ from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
10
+ from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
11
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
12
+ DurationConfig as DurationConfig,
13
+ )
14
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
6
15
 
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.lr_scheduler import (
11
- LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
12
- )
13
- from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
14
- from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
15
- from nshtrainer.lr_scheduler import (
16
- ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
17
- )
18
- from nshtrainer.lr_scheduler.linear_warmup_cosine import (
19
- DurationConfig as DurationConfig,
20
- )
21
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
22
- MetricConfig as MetricConfig,
23
- )
24
- else:
25
-
26
- def __getattr__(name):
27
- import importlib
28
-
29
- if name in globals():
30
- return globals()[name]
31
- if name == "LRSchedulerConfigBase":
32
- return importlib.import_module(
33
- "nshtrainer.lr_scheduler"
34
- ).LRSchedulerConfigBase
35
- if name == "LinearWarmupCosineDecayLRSchedulerConfig":
36
- return importlib.import_module(
37
- "nshtrainer.lr_scheduler"
38
- ).LinearWarmupCosineDecayLRSchedulerConfig
39
- if name == "MetricConfig":
40
- return importlib.import_module(
41
- "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
42
- ).MetricConfig
43
- if name == "ReduceLROnPlateauConfig":
44
- return importlib.import_module(
45
- "nshtrainer.lr_scheduler"
46
- ).ReduceLROnPlateauConfig
47
- if name == "DurationConfig":
48
- return importlib.import_module(
49
- "nshtrainer.lr_scheduler.linear_warmup_cosine"
50
- ).DurationConfig
51
- if name == "LRSchedulerConfig":
52
- return importlib.import_module("nshtrainer.lr_scheduler").LRSchedulerConfig
53
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
54
-
55
-
56
- # Submodule exports
57
16
  from . import _base as _base
58
17
  from . import linear_warmup_cosine as linear_warmup_cosine
59
18
  from . import reduce_lr_on_plateau as reduce_lr_on_plateau
@@ -2,25 +2,4 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.lr_scheduler._base import (
11
- LRSchedulerConfigBase as LRSchedulerConfigBase,
12
- )
13
- else:
14
-
15
- def __getattr__(name):
16
- import importlib
17
-
18
- if name in globals():
19
- return globals()[name]
20
- if name == "LRSchedulerConfigBase":
21
- return importlib.import_module(
22
- "nshtrainer.lr_scheduler._base"
23
- ).LRSchedulerConfigBase
24
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
25
-
26
- # Submodule exports
5
+ from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
@@ -2,39 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.lr_scheduler.linear_warmup_cosine import (
11
- DurationConfig as DurationConfig,
12
- )
13
- from nshtrainer.lr_scheduler.linear_warmup_cosine import (
14
- LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
15
- )
16
- from nshtrainer.lr_scheduler.linear_warmup_cosine import (
17
- LRSchedulerConfigBase as LRSchedulerConfigBase,
18
- )
19
- else:
20
-
21
- def __getattr__(name):
22
- import importlib
23
-
24
- if name in globals():
25
- return globals()[name]
26
- if name == "LRSchedulerConfigBase":
27
- return importlib.import_module(
28
- "nshtrainer.lr_scheduler.linear_warmup_cosine"
29
- ).LRSchedulerConfigBase
30
- if name == "LinearWarmupCosineDecayLRSchedulerConfig":
31
- return importlib.import_module(
32
- "nshtrainer.lr_scheduler.linear_warmup_cosine"
33
- ).LinearWarmupCosineDecayLRSchedulerConfig
34
- if name == "DurationConfig":
35
- return importlib.import_module(
36
- "nshtrainer.lr_scheduler.linear_warmup_cosine"
37
- ).DurationConfig
38
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
39
-
40
- # Submodule exports
5
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
6
+ DurationConfig as DurationConfig,
7
+ )
8
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
9
+ LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
10
+ )
11
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
12
+ LRSchedulerConfigBase as LRSchedulerConfigBase,
13
+ )
@@ -2,39 +2,10 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
11
- LRSchedulerConfigBase as LRSchedulerConfigBase,
12
- )
13
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
14
- MetricConfig as MetricConfig,
15
- )
16
- from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
17
- ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
18
- )
19
- else:
20
-
21
- def __getattr__(name):
22
- import importlib
23
-
24
- if name in globals():
25
- return globals()[name]
26
- if name == "LRSchedulerConfigBase":
27
- return importlib.import_module(
28
- "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
29
- ).LRSchedulerConfigBase
30
- if name == "MetricConfig":
31
- return importlib.import_module(
32
- "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
33
- ).MetricConfig
34
- if name == "ReduceLROnPlateauConfig":
35
- return importlib.import_module(
36
- "nshtrainer.lr_scheduler.reduce_lr_on_plateau"
37
- ).ReduceLROnPlateauConfig
38
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
39
-
40
- # Submodule exports
5
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
6
+ LRSchedulerConfigBase as LRSchedulerConfigBase,
7
+ )
8
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
9
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
10
+ ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
11
+ )
@@ -2,23 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
5
+ from nshtrainer.metrics import MetricConfig as MetricConfig
6
6
 
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.metrics import MetricConfig as MetricConfig
11
- else:
12
-
13
- def __getattr__(name):
14
- import importlib
15
-
16
- if name in globals():
17
- return globals()[name]
18
- if name == "MetricConfig":
19
- return importlib.import_module("nshtrainer.metrics").MetricConfig
20
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
21
-
22
-
23
- # Submodule exports
24
7
  from . import _config as _config
@@ -2,21 +2,4 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.metrics._config import MetricConfig as MetricConfig
11
- else:
12
-
13
- def __getattr__(name):
14
- import importlib
15
-
16
- if name in globals():
17
- return globals()[name]
18
- if name == "MetricConfig":
19
- return importlib.import_module("nshtrainer.metrics._config").MetricConfig
20
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
21
-
22
- # Submodule exports
5
+ from nshtrainer.metrics._config import MetricConfig as MetricConfig
@@ -2,76 +2,25 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
5
+ from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
6
+ from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
7
+ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
8
+ from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
9
+ from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
10
+ from nshtrainer.nn import MLPConfig as MLPConfig
11
+ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
12
+ from nshtrainer.nn import PReLUConfig as PReLUConfig
13
+ from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
14
+ from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
15
+ from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
16
+ from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
17
+ from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
18
+ from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
19
+ from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
20
+ from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
21
+ from nshtrainer.nn.nonlinearity import (
22
+ SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
23
+ )
6
24
 
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
11
- from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
12
- from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
13
- from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
14
- from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
15
- from nshtrainer.nn import MLPConfig as MLPConfig
16
- from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
17
- from nshtrainer.nn import PReLUConfig as PReLUConfig
18
- from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
19
- from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
20
- from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
21
- from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
22
- from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
23
- from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
24
- from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
25
- from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
26
- from nshtrainer.nn.nonlinearity import (
27
- SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
28
- )
29
- else:
30
-
31
- def __getattr__(name):
32
- import importlib
33
-
34
- if name in globals():
35
- return globals()[name]
36
- if name == "BaseNonlinearityConfig":
37
- return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
38
- if name == "ELUNonlinearityConfig":
39
- return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
40
- if name == "GELUNonlinearityConfig":
41
- return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
42
- if name == "LeakyReLUNonlinearityConfig":
43
- return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
44
- if name == "MLPConfig":
45
- return importlib.import_module("nshtrainer.nn").MLPConfig
46
- if name == "MishNonlinearityConfig":
47
- return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
48
- if name == "PReLUConfig":
49
- return importlib.import_module("nshtrainer.nn").PReLUConfig
50
- if name == "ReLUNonlinearityConfig":
51
- return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
52
- if name == "SiLUNonlinearityConfig":
53
- return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
54
- if name == "SigmoidNonlinearityConfig":
55
- return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
56
- if name == "SoftmaxNonlinearityConfig":
57
- return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
58
- if name == "SoftplusNonlinearityConfig":
59
- return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
60
- if name == "SoftsignNonlinearityConfig":
61
- return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
62
- if name == "SwiGLUNonlinearityConfig":
63
- return importlib.import_module(
64
- "nshtrainer.nn.nonlinearity"
65
- ).SwiGLUNonlinearityConfig
66
- if name == "SwishNonlinearityConfig":
67
- return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
68
- if name == "TanhNonlinearityConfig":
69
- return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
70
- if name == "NonlinearityConfig":
71
- return importlib.import_module("nshtrainer.nn").NonlinearityConfig
72
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
73
-
74
-
75
- # Submodule exports
76
25
  from . import mlp as mlp
77
26
  from . import nonlinearity as nonlinearity
@@ -2,27 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
11
- from nshtrainer.nn.mlp import MLPConfig as MLPConfig
12
- from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
13
- else:
14
-
15
- def __getattr__(name):
16
- import importlib
17
-
18
- if name in globals():
19
- return globals()[name]
20
- if name == "BaseNonlinearityConfig":
21
- return importlib.import_module("nshtrainer.nn.mlp").BaseNonlinearityConfig
22
- if name == "MLPConfig":
23
- return importlib.import_module("nshtrainer.nn.mlp").MLPConfig
24
- if name == "NonlinearityConfig":
25
- return importlib.import_module("nshtrainer.nn.mlp").NonlinearityConfig
26
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
27
-
28
- # Submodule exports
5
+ from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
6
+ from nshtrainer.nn.mlp import MLPConfig as MLPConfig
7
+ from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
@@ -2,124 +2,33 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.nn.nonlinearity import (
11
- BaseNonlinearityConfig as BaseNonlinearityConfig,
12
- )
13
- from nshtrainer.nn.nonlinearity import (
14
- ELUNonlinearityConfig as ELUNonlinearityConfig,
15
- )
16
- from nshtrainer.nn.nonlinearity import (
17
- GELUNonlinearityConfig as GELUNonlinearityConfig,
18
- )
19
- from nshtrainer.nn.nonlinearity import (
20
- LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig,
21
- )
22
- from nshtrainer.nn.nonlinearity import (
23
- MishNonlinearityConfig as MishNonlinearityConfig,
24
- )
25
- from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
26
- from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
27
- from nshtrainer.nn.nonlinearity import (
28
- ReLUNonlinearityConfig as ReLUNonlinearityConfig,
29
- )
30
- from nshtrainer.nn.nonlinearity import (
31
- SigmoidNonlinearityConfig as SigmoidNonlinearityConfig,
32
- )
33
- from nshtrainer.nn.nonlinearity import (
34
- SiLUNonlinearityConfig as SiLUNonlinearityConfig,
35
- )
36
- from nshtrainer.nn.nonlinearity import (
37
- SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig,
38
- )
39
- from nshtrainer.nn.nonlinearity import (
40
- SoftplusNonlinearityConfig as SoftplusNonlinearityConfig,
41
- )
42
- from nshtrainer.nn.nonlinearity import (
43
- SoftsignNonlinearityConfig as SoftsignNonlinearityConfig,
44
- )
45
- from nshtrainer.nn.nonlinearity import (
46
- SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
47
- )
48
- from nshtrainer.nn.nonlinearity import (
49
- SwishNonlinearityConfig as SwishNonlinearityConfig,
50
- )
51
- from nshtrainer.nn.nonlinearity import (
52
- TanhNonlinearityConfig as TanhNonlinearityConfig,
53
- )
54
- else:
55
-
56
- def __getattr__(name):
57
- import importlib
58
-
59
- if name in globals():
60
- return globals()[name]
61
- if name == "BaseNonlinearityConfig":
62
- return importlib.import_module(
63
- "nshtrainer.nn.nonlinearity"
64
- ).BaseNonlinearityConfig
65
- if name == "ELUNonlinearityConfig":
66
- return importlib.import_module(
67
- "nshtrainer.nn.nonlinearity"
68
- ).ELUNonlinearityConfig
69
- if name == "GELUNonlinearityConfig":
70
- return importlib.import_module(
71
- "nshtrainer.nn.nonlinearity"
72
- ).GELUNonlinearityConfig
73
- if name == "LeakyReLUNonlinearityConfig":
74
- return importlib.import_module(
75
- "nshtrainer.nn.nonlinearity"
76
- ).LeakyReLUNonlinearityConfig
77
- if name == "MishNonlinearityConfig":
78
- return importlib.import_module(
79
- "nshtrainer.nn.nonlinearity"
80
- ).MishNonlinearityConfig
81
- if name == "PReLUConfig":
82
- return importlib.import_module("nshtrainer.nn.nonlinearity").PReLUConfig
83
- if name == "ReLUNonlinearityConfig":
84
- return importlib.import_module(
85
- "nshtrainer.nn.nonlinearity"
86
- ).ReLUNonlinearityConfig
87
- if name == "SiLUNonlinearityConfig":
88
- return importlib.import_module(
89
- "nshtrainer.nn.nonlinearity"
90
- ).SiLUNonlinearityConfig
91
- if name == "SigmoidNonlinearityConfig":
92
- return importlib.import_module(
93
- "nshtrainer.nn.nonlinearity"
94
- ).SigmoidNonlinearityConfig
95
- if name == "SoftmaxNonlinearityConfig":
96
- return importlib.import_module(
97
- "nshtrainer.nn.nonlinearity"
98
- ).SoftmaxNonlinearityConfig
99
- if name == "SoftplusNonlinearityConfig":
100
- return importlib.import_module(
101
- "nshtrainer.nn.nonlinearity"
102
- ).SoftplusNonlinearityConfig
103
- if name == "SoftsignNonlinearityConfig":
104
- return importlib.import_module(
105
- "nshtrainer.nn.nonlinearity"
106
- ).SoftsignNonlinearityConfig
107
- if name == "SwiGLUNonlinearityConfig":
108
- return importlib.import_module(
109
- "nshtrainer.nn.nonlinearity"
110
- ).SwiGLUNonlinearityConfig
111
- if name == "SwishNonlinearityConfig":
112
- return importlib.import_module(
113
- "nshtrainer.nn.nonlinearity"
114
- ).SwishNonlinearityConfig
115
- if name == "TanhNonlinearityConfig":
116
- return importlib.import_module(
117
- "nshtrainer.nn.nonlinearity"
118
- ).TanhNonlinearityConfig
119
- if name == "NonlinearityConfig":
120
- return importlib.import_module(
121
- "nshtrainer.nn.nonlinearity"
122
- ).NonlinearityConfig
123
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
124
-
125
- # Submodule exports
5
+ from nshtrainer.nn.nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
6
+ from nshtrainer.nn.nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
7
+ from nshtrainer.nn.nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
8
+ from nshtrainer.nn.nonlinearity import (
9
+ LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig,
10
+ )
11
+ from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
12
+ from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
13
+ from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
14
+ from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
15
+ from nshtrainer.nn.nonlinearity import (
16
+ SigmoidNonlinearityConfig as SigmoidNonlinearityConfig,
17
+ )
18
+ from nshtrainer.nn.nonlinearity import SiLUNonlinearityConfig as SiLUNonlinearityConfig
19
+ from nshtrainer.nn.nonlinearity import (
20
+ SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig,
21
+ )
22
+ from nshtrainer.nn.nonlinearity import (
23
+ SoftplusNonlinearityConfig as SoftplusNonlinearityConfig,
24
+ )
25
+ from nshtrainer.nn.nonlinearity import (
26
+ SoftsignNonlinearityConfig as SoftsignNonlinearityConfig,
27
+ )
28
+ from nshtrainer.nn.nonlinearity import (
29
+ SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
30
+ )
31
+ from nshtrainer.nn.nonlinearity import (
32
+ SwishNonlinearityConfig as SwishNonlinearityConfig,
33
+ )
34
+ from nshtrainer.nn.nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
@@ -2,27 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- # Config/alias imports
8
-
9
- if TYPE_CHECKING:
10
- from nshtrainer.optimizer import AdamWConfig as AdamWConfig
11
- from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
12
- from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
13
- else:
14
-
15
- def __getattr__(name):
16
- import importlib
17
-
18
- if name in globals():
19
- return globals()[name]
20
- if name == "AdamWConfig":
21
- return importlib.import_module("nshtrainer.optimizer").AdamWConfig
22
- if name == "OptimizerConfigBase":
23
- return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
24
- if name == "OptimizerConfig":
25
- return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
26
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
27
-
28
- # Submodule exports
5
+ from nshtrainer.optimizer import AdamWConfig as AdamWConfig
6
+ from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
7
+ from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase