nshtrainer 0.8.7__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. nshtrainer/__init__.py +2 -1
  2. nshtrainer/callbacks/__init__.py +17 -1
  3. nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
  4. nshtrainer/callbacks/base.py +7 -5
  5. nshtrainer/callbacks/ema.py +1 -1
  6. nshtrainer/callbacks/finite_checks.py +1 -1
  7. nshtrainer/callbacks/gradient_skipping.py +1 -1
  8. nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
  9. nshtrainer/callbacks/model_checkpoint.py +187 -0
  10. nshtrainer/callbacks/norm_logging.py +1 -1
  11. nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
  12. nshtrainer/callbacks/print_table.py +1 -1
  13. nshtrainer/callbacks/throughput_monitor.py +1 -1
  14. nshtrainer/callbacks/timer.py +1 -1
  15. nshtrainer/callbacks/wandb_watch.py +1 -1
  16. nshtrainer/ll/__init__.py +0 -1
  17. nshtrainer/ll/actsave.py +2 -1
  18. nshtrainer/metrics/__init__.py +1 -0
  19. nshtrainer/metrics/_config.py +37 -0
  20. nshtrainer/model/__init__.py +11 -11
  21. nshtrainer/model/_environment.py +777 -0
  22. nshtrainer/model/base.py +5 -114
  23. nshtrainer/model/config.py +92 -507
  24. nshtrainer/model/modules/logger.py +11 -6
  25. nshtrainer/runner.py +3 -6
  26. nshtrainer/trainer/_checkpoint_metadata.py +102 -0
  27. nshtrainer/trainer/_checkpoint_resolver.py +319 -0
  28. nshtrainer/trainer/_runtime_callback.py +120 -0
  29. nshtrainer/trainer/checkpoint_connector.py +63 -0
  30. nshtrainer/trainer/signal_connector.py +12 -9
  31. nshtrainer/trainer/trainer.py +111 -31
  32. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
  33. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
  34. nshtrainer/actsave/__init__.py +0 -3
  35. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
nshtrainer/ll/actsave.py CHANGED
@@ -1,3 +1,4 @@
1
1
  from nshutils.actsave import * # type: ignore # noqa: F403
2
2
 
3
- from nshtrainer.actsave import * # type: ignore # noqa: F403
3
+ from nshtrainer.callbacks.actsave import ActSaveCallback as ActSaveCallback
4
+ from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
nshtrainer/metrics/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from ._config import MetricConfig as MetricConfig
nshtrainer/metrics/_config.py ADDED
@@ -0,0 +1,37 @@
1
+ import builtins
2
+ from typing import Literal
3
+
4
+ import nshconfig as C
5
+
6
+
7
+ class MetricConfig(C.Config):
8
+ name: str
9
+ """The name of the primary metric."""
10
+
11
+ mode: Literal["min", "max"]
12
+ """
13
+ The mode of the primary metric:
14
+ - "min" for metrics that should be minimized (e.g., loss)
15
+ - "max" for metrics that should be maximized (e.g., accuracy)
16
+ """
17
+
18
+ @property
19
+ def validation_monitor(self) -> str:
20
+ return f"val/{self.name}"
21
+
22
+ def __post_init__(self):
23
+ for split in ("train", "val", "test", "predict"):
24
+ if self.name.startswith(f"{split}/"):
25
+ raise ValueError(
26
+ f"Primary metric name should not start with '{split}/'. "
27
+ f"Just use '{self.name[len(split) + 1:]}' instead. "
28
+ "The split name is automatically added depending on the context."
29
+ )
30
+
31
+ @classmethod
32
+ def loss(cls, mode: Literal["min", "max"] = "min"):
33
+ return cls(name="loss", mode=mode)
34
+
35
+ @property
36
+ def best(self):
37
+ return builtins.min if self.mode == "min" else builtins.max
nshtrainer/model/__init__.py CHANGED
@@ -1,8 +1,18 @@
1
1
  from typing_extensions import TypeAlias
2
2
 
3
+ from ._environment import (
4
+ EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
5
+ )
6
+ from ._environment import EnvironmentConfig as EnvironmentConfig
7
+ from ._environment import (
8
+ EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
9
+ )
10
+ from ._environment import (
11
+ EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
12
+ )
13
+ from ._environment import EnvironmentSnapshotConfig as EnvironmentSnapshotConfig
3
14
  from .base import Base as Base
4
15
  from .base import LightningModuleBase as LightningModuleBase
5
- from .config import ActSaveConfig as ActSaveConfig
6
16
  from .config import BaseConfig as BaseConfig
7
17
  from .config import BaseLoggerConfig as BaseLoggerConfig
8
18
  from .config import BaseProfilerConfig as BaseProfilerConfig
@@ -10,16 +20,6 @@ from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
10
20
  from .config import CheckpointSavingConfig as CheckpointSavingConfig
11
21
  from .config import DirectoryConfig as DirectoryConfig
12
22
  from .config import EarlyStoppingConfig as EarlyStoppingConfig
13
- from .config import (
14
- EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
15
- )
16
- from .config import EnvironmentConfig as EnvironmentConfig
17
- from .config import (
18
- EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
19
- )
20
- from .config import (
21
- EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
22
- )
23
23
  from .config import GradientClippingConfig as GradientClippingConfig
24
24
  from .config import (
25
25
  LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,