nshtrainer-0.41.1-py3-none-any.whl → nshtrainer-0.42.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/config/__init__.py +406 -95
- nshtrainer/config/_checkpoint/loader/__init__.py +55 -13
- nshtrainer/config/_checkpoint/metadata/__init__.py +22 -8
- nshtrainer/config/_directory/__init__.py +21 -9
- nshtrainer/config/_hf_hub/__init__.py +25 -9
- nshtrainer/config/callbacks/__init__.py +114 -29
- nshtrainer/config/callbacks/actsave/__init__.py +20 -8
- nshtrainer/config/callbacks/base/__init__.py +17 -7
- nshtrainer/config/callbacks/checkpoint/__init__.py +62 -13
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +33 -9
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +40 -10
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +33 -9
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +26 -8
- nshtrainer/config/callbacks/debug_flag/__init__.py +24 -8
- nshtrainer/config/callbacks/directory_setup/__init__.py +26 -8
- nshtrainer/config/callbacks/early_stopping/__init__.py +31 -9
- nshtrainer/config/callbacks/ema/__init__.py +20 -8
- nshtrainer/config/callbacks/finite_checks/__init__.py +26 -8
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +26 -8
- nshtrainer/config/callbacks/norm_logging/__init__.py +24 -8
- nshtrainer/config/callbacks/print_table/__init__.py +26 -8
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +26 -8
- nshtrainer/config/callbacks/shared_parameters/__init__.py +26 -8
- nshtrainer/config/callbacks/throughput_monitor/__init__.py +26 -8
- nshtrainer/config/callbacks/timer/__init__.py +22 -8
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +26 -8
- nshtrainer/config/callbacks/wandb_watch/__init__.py +24 -8
- nshtrainer/config/loggers/__init__.py +41 -14
- nshtrainer/config/loggers/_base/__init__.py +15 -7
- nshtrainer/config/loggers/csv/__init__.py +18 -8
- nshtrainer/config/loggers/tensorboard/__init__.py +24 -8
- nshtrainer/config/loggers/wandb/__init__.py +31 -11
- nshtrainer/config/lr_scheduler/__init__.py +49 -12
- nshtrainer/config/lr_scheduler/_base/__init__.py +19 -7
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +33 -9
- nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +33 -9
- nshtrainer/config/metrics/__init__.py +16 -7
- nshtrainer/config/metrics/_config/__init__.py +15 -7
- nshtrainer/config/model/__init__.py +31 -12
- nshtrainer/config/model/base/__init__.py +18 -8
- nshtrainer/config/model/config/__init__.py +30 -12
- nshtrainer/config/model/mixins/logger/__init__.py +15 -7
- nshtrainer/config/nn/__init__.py +68 -23
- nshtrainer/config/nn/mlp/__init__.py +21 -9
- nshtrainer/config/nn/nonlinearity/__init__.py +118 -22
- nshtrainer/config/optimizer/__init__.py +21 -9
- nshtrainer/config/profiler/__init__.py +28 -11
- nshtrainer/config/profiler/_base/__init__.py +17 -7
- nshtrainer/config/profiler/advanced/__init__.py +24 -8
- nshtrainer/config/profiler/pytorch/__init__.py +24 -8
- nshtrainer/config/profiler/simple/__init__.py +22 -8
- nshtrainer/config/runner/__init__.py +15 -7
- nshtrainer/config/trainer/_config/__init__.py +144 -30
- nshtrainer/config/trainer/checkpoint_connector/__init__.py +19 -7
- nshtrainer/config/util/_environment_info/__init__.py +87 -17
- nshtrainer/config/util/config/__init__.py +25 -10
- nshtrainer/config/util/config/dtype/__init__.py +15 -7
- nshtrainer/config/util/config/duration/__init__.py +27 -9
- {nshtrainer-0.41.1.dist-info → nshtrainer-0.42.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.41.1.dist-info → nshtrainer-0.42.0.dist-info}/RECORD +61 -61
- {nshtrainer-0.41.1.dist-info → nshtrainer-0.42.0.dist-info}/WHEEL +0 -0
|
@@ -1,16 +1,36 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
from nshtrainer.loggers.wandb import
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.loggers.wandb import BaseLoggerConfig as BaseLoggerConfig
|
|
9
|
+
from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
|
|
10
|
+
from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
|
|
11
|
+
from nshtrainer.loggers.wandb import WandbUploadCodeConfig as WandbUploadCodeConfig
|
|
12
|
+
from nshtrainer.loggers.wandb import WandbWatchConfig as WandbWatchConfig
|
|
13
|
+
else:
|
|
14
|
+
|
|
15
|
+
def __getattr__(name):
|
|
16
|
+
import importlib
|
|
13
17
|
|
|
14
|
-
|
|
18
|
+
if name in globals():
|
|
19
|
+
return globals()[name]
|
|
20
|
+
if name == "CallbackConfigBase":
|
|
21
|
+
return importlib.import_module(
|
|
22
|
+
"nshtrainer.loggers.wandb"
|
|
23
|
+
).CallbackConfigBase
|
|
24
|
+
if name == "WandbLoggerConfig":
|
|
25
|
+
return importlib.import_module("nshtrainer.loggers.wandb").WandbLoggerConfig
|
|
26
|
+
if name == "WandbUploadCodeConfig":
|
|
27
|
+
return importlib.import_module(
|
|
28
|
+
"nshtrainer.loggers.wandb"
|
|
29
|
+
).WandbUploadCodeConfig
|
|
30
|
+
if name == "WandbWatchConfig":
|
|
31
|
+
return importlib.import_module("nshtrainer.loggers.wandb").WandbWatchConfig
|
|
32
|
+
if name == "BaseLoggerConfig":
|
|
33
|
+
return importlib.import_module("nshtrainer.loggers.wandb").BaseLoggerConfig
|
|
34
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
15
35
|
|
|
16
36
|
# Submodule exports
|
|
@@ -1,18 +1,55 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.lr_scheduler import (
|
|
9
|
+
LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
|
|
10
|
+
)
|
|
11
|
+
from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
|
|
12
|
+
from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
|
|
13
|
+
from nshtrainer.lr_scheduler import (
|
|
14
|
+
ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
|
|
15
|
+
)
|
|
16
|
+
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
|
17
|
+
DurationConfig as DurationConfig,
|
|
18
|
+
)
|
|
19
|
+
from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
|
|
20
|
+
MetricConfig as MetricConfig,
|
|
21
|
+
)
|
|
22
|
+
else:
|
|
23
|
+
|
|
24
|
+
def __getattr__(name):
|
|
25
|
+
import importlib
|
|
26
|
+
|
|
27
|
+
if name in globals():
|
|
28
|
+
return globals()[name]
|
|
29
|
+
if name == "LRSchedulerConfigBase":
|
|
30
|
+
return importlib.import_module(
|
|
31
|
+
"nshtrainer.lr_scheduler"
|
|
32
|
+
).LRSchedulerConfigBase
|
|
33
|
+
if name == "LinearWarmupCosineDecayLRSchedulerConfig":
|
|
34
|
+
return importlib.import_module(
|
|
35
|
+
"nshtrainer.lr_scheduler"
|
|
36
|
+
).LinearWarmupCosineDecayLRSchedulerConfig
|
|
37
|
+
if name == "MetricConfig":
|
|
38
|
+
return importlib.import_module(
|
|
39
|
+
"nshtrainer.lr_scheduler.reduce_lr_on_plateau"
|
|
40
|
+
).MetricConfig
|
|
41
|
+
if name == "ReduceLROnPlateauConfig":
|
|
42
|
+
return importlib.import_module(
|
|
43
|
+
"nshtrainer.lr_scheduler"
|
|
44
|
+
).ReduceLROnPlateauConfig
|
|
45
|
+
if name == "DurationConfig":
|
|
46
|
+
return importlib.import_module(
|
|
47
|
+
"nshtrainer.lr_scheduler.linear_warmup_cosine"
|
|
48
|
+
).DurationConfig
|
|
49
|
+
if name == "LRSchedulerConfig":
|
|
50
|
+
return importlib.import_module("nshtrainer.lr_scheduler").LRSchedulerConfig
|
|
51
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
12
52
|
|
|
13
|
-
# Type aliases
|
|
14
|
-
from nshtrainer.lr_scheduler.linear_warmup_cosine import DurationConfig as DurationConfig
|
|
15
|
-
from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
|
|
16
53
|
|
|
17
54
|
# Submodule exports
|
|
18
55
|
from . import _base as _base
|
|
@@ -1,12 +1,24 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.lr_scheduler._base import (
|
|
9
|
+
LRSchedulerConfigBase as LRSchedulerConfigBase,
|
|
10
|
+
)
|
|
11
|
+
else:
|
|
12
|
+
|
|
13
|
+
def __getattr__(name):
|
|
14
|
+
import importlib
|
|
9
15
|
|
|
10
|
-
|
|
16
|
+
if name in globals():
|
|
17
|
+
return globals()[name]
|
|
18
|
+
if name == "LRSchedulerConfigBase":
|
|
19
|
+
return importlib.import_module(
|
|
20
|
+
"nshtrainer.lr_scheduler._base"
|
|
21
|
+
).LRSchedulerConfigBase
|
|
22
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
11
23
|
|
|
12
24
|
# Submodule exports
|
|
@@ -1,14 +1,38 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
|
9
|
+
DurationConfig as DurationConfig,
|
|
10
|
+
)
|
|
11
|
+
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
|
12
|
+
LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
|
|
13
|
+
)
|
|
14
|
+
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
|
15
|
+
LRSchedulerConfigBase as LRSchedulerConfigBase,
|
|
16
|
+
)
|
|
17
|
+
else:
|
|
18
|
+
|
|
19
|
+
def __getattr__(name):
|
|
20
|
+
import importlib
|
|
10
21
|
|
|
11
|
-
|
|
12
|
-
|
|
22
|
+
if name in globals():
|
|
23
|
+
return globals()[name]
|
|
24
|
+
if name == "LinearWarmupCosineDecayLRSchedulerConfig":
|
|
25
|
+
return importlib.import_module(
|
|
26
|
+
"nshtrainer.lr_scheduler.linear_warmup_cosine"
|
|
27
|
+
).LinearWarmupCosineDecayLRSchedulerConfig
|
|
28
|
+
if name == "LRSchedulerConfigBase":
|
|
29
|
+
return importlib.import_module(
|
|
30
|
+
"nshtrainer.lr_scheduler.linear_warmup_cosine"
|
|
31
|
+
).LRSchedulerConfigBase
|
|
32
|
+
if name == "DurationConfig":
|
|
33
|
+
return importlib.import_module(
|
|
34
|
+
"nshtrainer.lr_scheduler.linear_warmup_cosine"
|
|
35
|
+
).DurationConfig
|
|
36
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
13
37
|
|
|
14
38
|
# Submodule exports
|
|
@@ -1,14 +1,38 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
|
|
9
|
+
LRSchedulerConfigBase as LRSchedulerConfigBase,
|
|
10
|
+
)
|
|
11
|
+
from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
|
|
12
|
+
MetricConfig as MetricConfig,
|
|
13
|
+
)
|
|
14
|
+
from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
|
|
15
|
+
ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
|
|
16
|
+
)
|
|
17
|
+
else:
|
|
18
|
+
|
|
19
|
+
def __getattr__(name):
|
|
20
|
+
import importlib
|
|
11
21
|
|
|
12
|
-
|
|
22
|
+
if name in globals():
|
|
23
|
+
return globals()[name]
|
|
24
|
+
if name == "MetricConfig":
|
|
25
|
+
return importlib.import_module(
|
|
26
|
+
"nshtrainer.lr_scheduler.reduce_lr_on_plateau"
|
|
27
|
+
).MetricConfig
|
|
28
|
+
if name == "ReduceLROnPlateauConfig":
|
|
29
|
+
return importlib.import_module(
|
|
30
|
+
"nshtrainer.lr_scheduler.reduce_lr_on_plateau"
|
|
31
|
+
).ReduceLROnPlateauConfig
|
|
32
|
+
if name == "LRSchedulerConfigBase":
|
|
33
|
+
return importlib.import_module(
|
|
34
|
+
"nshtrainer.lr_scheduler.reduce_lr_on_plateau"
|
|
35
|
+
).LRSchedulerConfigBase
|
|
36
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
13
37
|
|
|
14
38
|
# Submodule exports
|
|
@@ -1,13 +1,22 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.metrics import MetricConfig as MetricConfig
|
|
9
|
+
else:
|
|
10
|
+
|
|
11
|
+
def __getattr__(name):
|
|
12
|
+
import importlib
|
|
13
|
+
|
|
14
|
+
if name in globals():
|
|
15
|
+
return globals()[name]
|
|
16
|
+
if name == "MetricConfig":
|
|
17
|
+
return importlib.import_module("nshtrainer.metrics").MetricConfig
|
|
18
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
9
19
|
|
|
10
|
-
# Type aliases
|
|
11
20
|
|
|
12
21
|
# Submodule exports
|
|
13
22
|
from . import _config as _config
|
|
@@ -1,12 +1,20 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.metrics._config import MetricConfig as MetricConfig
|
|
9
|
+
else:
|
|
10
|
+
|
|
11
|
+
def __getattr__(name):
|
|
12
|
+
import importlib
|
|
9
13
|
|
|
10
|
-
|
|
14
|
+
if name in globals():
|
|
15
|
+
return globals()[name]
|
|
16
|
+
if name == "MetricConfig":
|
|
17
|
+
return importlib.import_module("nshtrainer.metrics._config").MetricConfig
|
|
18
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
11
19
|
|
|
12
20
|
# Submodule exports
|
|
@@ -1,18 +1,37 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
from nshtrainer.model
|
|
13
|
-
from nshtrainer.model
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.model import BaseConfig as BaseConfig
|
|
9
|
+
from nshtrainer.model import DirectoryConfig as DirectoryConfig
|
|
10
|
+
from nshtrainer.model import MetricConfig as MetricConfig
|
|
11
|
+
from nshtrainer.model import TrainerConfig as TrainerConfig
|
|
12
|
+
from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
|
|
13
|
+
from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
|
|
14
|
+
else:
|
|
15
|
+
|
|
16
|
+
def __getattr__(name):
|
|
17
|
+
import importlib
|
|
18
|
+
|
|
19
|
+
if name in globals():
|
|
20
|
+
return globals()[name]
|
|
21
|
+
if name == "CallbackConfigBase":
|
|
22
|
+
return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
|
|
23
|
+
if name == "TrainerConfig":
|
|
24
|
+
return importlib.import_module("nshtrainer.model").TrainerConfig
|
|
25
|
+
if name == "EnvironmentConfig":
|
|
26
|
+
return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
|
|
27
|
+
if name == "BaseConfig":
|
|
28
|
+
return importlib.import_module("nshtrainer.model").BaseConfig
|
|
29
|
+
if name == "MetricConfig":
|
|
30
|
+
return importlib.import_module("nshtrainer.model").MetricConfig
|
|
31
|
+
if name == "DirectoryConfig":
|
|
32
|
+
return importlib.import_module("nshtrainer.model").DirectoryConfig
|
|
33
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
14
34
|
|
|
15
|
-
# Type aliases
|
|
16
35
|
|
|
17
36
|
# Submodule exports
|
|
18
37
|
from . import base as base
|
|
@@ -1,13 +1,23 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.model.base import BaseConfig as BaseConfig
|
|
9
|
+
from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
|
|
10
|
+
else:
|
|
11
|
+
|
|
12
|
+
def __getattr__(name):
|
|
13
|
+
import importlib
|
|
10
14
|
|
|
11
|
-
|
|
15
|
+
if name in globals():
|
|
16
|
+
return globals()[name]
|
|
17
|
+
if name == "EnvironmentConfig":
|
|
18
|
+
return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
|
|
19
|
+
if name == "BaseConfig":
|
|
20
|
+
return importlib.import_module("nshtrainer.model.base").BaseConfig
|
|
21
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
12
22
|
|
|
13
23
|
# Submodule exports
|
|
@@ -1,17 +1,35 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
from nshtrainer.model.config import
|
|
13
|
-
from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.model.config import BaseConfig as BaseConfig
|
|
9
|
+
from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
|
|
10
|
+
from nshtrainer.model.config import DirectoryConfig as DirectoryConfig
|
|
11
|
+
from nshtrainer.model.config import EnvironmentConfig as EnvironmentConfig
|
|
12
|
+
from nshtrainer.model.config import MetricConfig as MetricConfig
|
|
13
|
+
from nshtrainer.model.config import TrainerConfig as TrainerConfig
|
|
14
|
+
else:
|
|
15
|
+
|
|
16
|
+
def __getattr__(name):
|
|
17
|
+
import importlib
|
|
14
18
|
|
|
15
|
-
|
|
19
|
+
if name in globals():
|
|
20
|
+
return globals()[name]
|
|
21
|
+
if name == "CallbackConfigBase":
|
|
22
|
+
return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
|
|
23
|
+
if name == "TrainerConfig":
|
|
24
|
+
return importlib.import_module("nshtrainer.model.config").TrainerConfig
|
|
25
|
+
if name == "EnvironmentConfig":
|
|
26
|
+
return importlib.import_module("nshtrainer.model.config").EnvironmentConfig
|
|
27
|
+
if name == "BaseConfig":
|
|
28
|
+
return importlib.import_module("nshtrainer.model.config").BaseConfig
|
|
29
|
+
if name == "MetricConfig":
|
|
30
|
+
return importlib.import_module("nshtrainer.model.config").MetricConfig
|
|
31
|
+
if name == "DirectoryConfig":
|
|
32
|
+
return importlib.import_module("nshtrainer.model.config").DirectoryConfig
|
|
33
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
16
34
|
|
|
17
35
|
# Submodule exports
|
|
@@ -1,12 +1,20 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.model.mixins.logger import BaseConfig as BaseConfig
|
|
9
|
+
else:
|
|
10
|
+
|
|
11
|
+
def __getattr__(name):
|
|
12
|
+
import importlib
|
|
9
13
|
|
|
10
|
-
|
|
14
|
+
if name in globals():
|
|
15
|
+
return globals()[name]
|
|
16
|
+
if name == "BaseConfig":
|
|
17
|
+
return importlib.import_module("nshtrainer.model.mixins.logger").BaseConfig
|
|
18
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
11
19
|
|
|
12
20
|
# Submodule exports
|
nshtrainer/config/nn/__init__.py
CHANGED
|
@@ -1,29 +1,74 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
from nshtrainer.nn import
|
|
13
|
-
from nshtrainer.nn import
|
|
14
|
-
from nshtrainer.nn import
|
|
15
|
-
from nshtrainer.nn import
|
|
16
|
-
from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
|
|
17
|
-
from nshtrainer.nn import
|
|
18
|
-
from nshtrainer.nn import
|
|
19
|
-
from nshtrainer.nn
|
|
20
|
-
from nshtrainer.nn import
|
|
21
|
-
from nshtrainer.nn import
|
|
22
|
-
from nshtrainer.nn import
|
|
23
|
-
from nshtrainer.nn import
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
|
|
9
|
+
from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
|
|
10
|
+
from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
|
|
11
|
+
from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
|
|
12
|
+
from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
|
|
13
|
+
from nshtrainer.nn import MLPConfig as MLPConfig
|
|
14
|
+
from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
|
|
15
|
+
from nshtrainer.nn import PReLUConfig as PReLUConfig
|
|
16
|
+
from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
|
17
|
+
from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
|
|
18
|
+
from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
|
|
19
|
+
from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
|
|
20
|
+
from nshtrainer.nn import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
|
|
21
|
+
from nshtrainer.nn import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
|
|
22
|
+
from nshtrainer.nn import SwishNonlinearityConfig as SwishNonlinearityConfig
|
|
23
|
+
from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
|
|
24
|
+
from nshtrainer.nn.nonlinearity import (
|
|
25
|
+
SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
|
|
26
|
+
)
|
|
27
|
+
else:
|
|
28
|
+
|
|
29
|
+
def __getattr__(name):
|
|
30
|
+
import importlib
|
|
31
|
+
|
|
32
|
+
if name in globals():
|
|
33
|
+
return globals()[name]
|
|
34
|
+
if name == "BaseNonlinearityConfig":
|
|
35
|
+
return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
|
|
36
|
+
if name == "MLPConfig":
|
|
37
|
+
return importlib.import_module("nshtrainer.nn").MLPConfig
|
|
38
|
+
if name == "SwiGLUNonlinearityConfig":
|
|
39
|
+
return importlib.import_module(
|
|
40
|
+
"nshtrainer.nn.nonlinearity"
|
|
41
|
+
).SwiGLUNonlinearityConfig
|
|
42
|
+
if name == "ReLUNonlinearityConfig":
|
|
43
|
+
return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
|
|
44
|
+
if name == "SiLUNonlinearityConfig":
|
|
45
|
+
return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
|
|
46
|
+
if name == "ELUNonlinearityConfig":
|
|
47
|
+
return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
|
|
48
|
+
if name == "GELUNonlinearityConfig":
|
|
49
|
+
return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
|
|
50
|
+
if name == "SoftplusNonlinearityConfig":
|
|
51
|
+
return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
|
|
52
|
+
if name == "SoftsignNonlinearityConfig":
|
|
53
|
+
return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
|
|
54
|
+
if name == "SwishNonlinearityConfig":
|
|
55
|
+
return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
|
|
56
|
+
if name == "SoftmaxNonlinearityConfig":
|
|
57
|
+
return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
|
|
58
|
+
if name == "MishNonlinearityConfig":
|
|
59
|
+
return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
|
|
60
|
+
if name == "SigmoidNonlinearityConfig":
|
|
61
|
+
return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
|
|
62
|
+
if name == "TanhNonlinearityConfig":
|
|
63
|
+
return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
|
|
64
|
+
if name == "PReLUConfig":
|
|
65
|
+
return importlib.import_module("nshtrainer.nn").PReLUConfig
|
|
66
|
+
if name == "LeakyReLUNonlinearityConfig":
|
|
67
|
+
return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
|
|
68
|
+
if name == "NonlinearityConfig":
|
|
69
|
+
return importlib.import_module("nshtrainer.nn").NonlinearityConfig
|
|
70
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
24
71
|
|
|
25
|
-
# Type aliases
|
|
26
|
-
from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
|
|
27
72
|
|
|
28
73
|
# Submodule exports
|
|
29
74
|
from . import mlp as mlp
|
|
@@ -1,14 +1,26 @@
|
|
|
1
|
-
# fmt: off
|
|
2
|
-
# ruff: noqa
|
|
3
|
-
# type: ignore
|
|
4
|
-
|
|
5
1
|
__codegen__ = True
|
|
6
2
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
# Config/alias imports
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
|
|
9
|
+
from nshtrainer.nn.mlp import MLPConfig as MLPConfig
|
|
10
|
+
from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
|
|
11
|
+
else:
|
|
12
|
+
|
|
13
|
+
def __getattr__(name):
|
|
14
|
+
import importlib
|
|
10
15
|
|
|
11
|
-
|
|
12
|
-
|
|
16
|
+
if name in globals():
|
|
17
|
+
return globals()[name]
|
|
18
|
+
if name == "BaseNonlinearityConfig":
|
|
19
|
+
return importlib.import_module("nshtrainer.nn.mlp").BaseNonlinearityConfig
|
|
20
|
+
if name == "MLPConfig":
|
|
21
|
+
return importlib.import_module("nshtrainer.nn.mlp").MLPConfig
|
|
22
|
+
if name == "NonlinearityConfig":
|
|
23
|
+
return importlib.import_module("nshtrainer.nn.mlp").NonlinearityConfig
|
|
24
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
|
13
25
|
|
|
14
26
|
# Submodule exports
|