nshtrainer 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -1
- nshtrainer/callbacks/__init__.py +17 -1
- nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
- nshtrainer/callbacks/base.py +7 -5
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
- nshtrainer/callbacks/model_checkpoint.py +187 -0
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/throughput_monitor.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/ll/__init__.py +0 -1
- nshtrainer/ll/actsave.py +2 -1
- nshtrainer/metrics/__init__.py +1 -0
- nshtrainer/metrics/_config.py +37 -0
- nshtrainer/model/__init__.py +11 -11
- nshtrainer/model/_environment.py +777 -0
- nshtrainer/model/base.py +5 -114
- nshtrainer/model/config.py +49 -501
- nshtrainer/model/modules/logger.py +11 -6
- nshtrainer/runner.py +3 -6
- nshtrainer/trainer/_checkpoint_metadata.py +102 -0
- nshtrainer/trainer/_checkpoint_resolver.py +319 -0
- nshtrainer/trainer/_runtime_callback.py +120 -0
- nshtrainer/trainer/checkpoint_connector.py +63 -0
- nshtrainer/trainer/signal_connector.py +12 -9
- nshtrainer/trainer/trainer.py +111 -31
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
- nshtrainer/actsave/__init__.py +0 -3
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
nshtrainer/ll/actsave.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
1
|
from nshutils.actsave import * # type: ignore # noqa: F403
|
|
2
2
|
|
|
3
|
-
from nshtrainer.actsave import
|
|
3
|
+
from nshtrainer.callbacks.actsave import ActSaveCallback as ActSaveCallback
|
|
4
|
+
from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ._config import MetricConfig as MetricConfig
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import builtins
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
import nshconfig as C
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MetricConfig(C.Config):
|
|
8
|
+
name: str
|
|
9
|
+
"""The name of the primary metric."""
|
|
10
|
+
|
|
11
|
+
mode: Literal["min", "max"]
|
|
12
|
+
"""
|
|
13
|
+
The mode of the primary metric:
|
|
14
|
+
- "min" for metrics that should be minimized (e.g., loss)
|
|
15
|
+
- "max" for metrics that should be maximized (e.g., accuracy)
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
@property
|
|
19
|
+
def validation_monitor(self) -> str:
|
|
20
|
+
return f"val/{self.name}"
|
|
21
|
+
|
|
22
|
+
def __post_init__(self):
|
|
23
|
+
for split in ("train", "val", "test", "predict"):
|
|
24
|
+
if self.name.startswith(f"{split}/"):
|
|
25
|
+
raise ValueError(
|
|
26
|
+
f"Primary metric name should not start with '{split}/'. "
|
|
27
|
+
f"Just use '{self.name[len(split) + 1:]}' instead. "
|
|
28
|
+
"The split name is automatically added depending on the context."
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
@classmethod
|
|
32
|
+
def loss(cls, mode: Literal["min", "max"] = "min"):
|
|
33
|
+
return cls(name="loss", mode=mode)
|
|
34
|
+
|
|
35
|
+
@property
|
|
36
|
+
def best(self):
|
|
37
|
+
return builtins.min if self.mode == "min" else builtins.max
|
nshtrainer/model/__init__.py
CHANGED
|
@@ -1,8 +1,18 @@
|
|
|
1
1
|
from typing_extensions import TypeAlias
|
|
2
2
|
|
|
3
|
+
from ._environment import (
|
|
4
|
+
EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
|
|
5
|
+
)
|
|
6
|
+
from ._environment import EnvironmentConfig as EnvironmentConfig
|
|
7
|
+
from ._environment import (
|
|
8
|
+
EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
|
|
9
|
+
)
|
|
10
|
+
from ._environment import (
|
|
11
|
+
EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
|
|
12
|
+
)
|
|
13
|
+
from ._environment import EnvironmentSnapshotConfig as EnvironmentSnapshotConfig
|
|
3
14
|
from .base import Base as Base
|
|
4
15
|
from .base import LightningModuleBase as LightningModuleBase
|
|
5
|
-
from .config import ActSaveConfig as ActSaveConfig
|
|
6
16
|
from .config import BaseConfig as BaseConfig
|
|
7
17
|
from .config import BaseLoggerConfig as BaseLoggerConfig
|
|
8
18
|
from .config import BaseProfilerConfig as BaseProfilerConfig
|
|
@@ -10,16 +20,6 @@ from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
|
|
|
10
20
|
from .config import CheckpointSavingConfig as CheckpointSavingConfig
|
|
11
21
|
from .config import DirectoryConfig as DirectoryConfig
|
|
12
22
|
from .config import EarlyStoppingConfig as EarlyStoppingConfig
|
|
13
|
-
from .config import (
|
|
14
|
-
EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
|
|
15
|
-
)
|
|
16
|
-
from .config import EnvironmentConfig as EnvironmentConfig
|
|
17
|
-
from .config import (
|
|
18
|
-
EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
|
|
19
|
-
)
|
|
20
|
-
from .config import (
|
|
21
|
-
EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
|
|
22
|
-
)
|
|
23
23
|
from .config import GradientClippingConfig as GradientClippingConfig
|
|
24
24
|
from .config import (
|
|
25
25
|
LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
|