nshtrainer 0.9.1__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. nshtrainer/__init__.py +2 -1
  2. nshtrainer/_checkpoint/loader.py +319 -0
  3. nshtrainer/_checkpoint/metadata.py +102 -0
  4. nshtrainer/callbacks/__init__.py +17 -1
  5. nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
  6. nshtrainer/callbacks/base.py +7 -5
  7. nshtrainer/callbacks/ema.py +1 -1
  8. nshtrainer/callbacks/finite_checks.py +1 -1
  9. nshtrainer/callbacks/gradient_skipping.py +1 -1
  10. nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
  11. nshtrainer/callbacks/model_checkpoint.py +187 -0
  12. nshtrainer/callbacks/norm_logging.py +1 -1
  13. nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
  14. nshtrainer/callbacks/print_table.py +1 -1
  15. nshtrainer/callbacks/throughput_monitor.py +1 -1
  16. nshtrainer/callbacks/timer.py +1 -1
  17. nshtrainer/callbacks/wandb_watch.py +1 -1
  18. nshtrainer/ll/__init__.py +0 -1
  19. nshtrainer/ll/actsave.py +2 -1
  20. nshtrainer/metrics/__init__.py +1 -0
  21. nshtrainer/metrics/_config.py +37 -0
  22. nshtrainer/model/__init__.py +11 -11
  23. nshtrainer/model/_environment.py +777 -0
  24. nshtrainer/model/base.py +5 -114
  25. nshtrainer/model/config.py +49 -501
  26. nshtrainer/model/modules/logger.py +11 -6
  27. nshtrainer/runner.py +3 -6
  28. nshtrainer/trainer/_runtime_callback.py +120 -0
  29. nshtrainer/trainer/checkpoint_connector.py +63 -0
  30. nshtrainer/trainer/signal_connector.py +12 -9
  31. nshtrainer/trainer/trainer.py +111 -31
  32. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/METADATA +3 -1
  33. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/RECORD +34 -27
  34. nshtrainer/actsave/__init__.py +0 -3
  35. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/latest_epoch_checkpoint.py CHANGED
@@ -1,35 +1,54 @@
 import logging
 from pathlib import Path
+from typing import Literal
 
-from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Checkpoint
 from typing_extensions import override
 
+from .base import CallbackConfigBase
+
 log = logging.getLogger(__name__)
 
 
+class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
+    kind: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
+
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
+    """Checkpoint filename. This must not include the extension."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    latest_symlink_filename: str | None = "latest.ckpt"
+    """Filename for the latest symlink. If None, no symlink will be created."""
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+        dirpath = Path(dirpath)
+
+        yield LatestEpochCheckpoint(self, dirpath)
+
+
 class LatestEpochCheckpoint(Checkpoint):
-    DEFAULT_FILENAME = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
-
-    def __init__(
-        self,
-        dirpath: _PATH,
-        filename: str | None = None,
-        save_weights_only: bool = False,
-    ):
+    def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
         super().__init__()
 
-        self._dirpath = Path(dirpath)
-        self._filename = filename or self.DEFAULT_FILENAME
-        self._save_weights_only = save_weights_only
+        self.config = config
+        self.dirpath = dirpath
 
         # Also, we hold a reference to the last checkpoint path
         # to be able to remove it when a new checkpoint is saved.
         self._last_ckpt_path: Path | None = None
 
     def _ckpt_path(self, trainer: Trainer):
-        return self._dirpath / self._filename.format(
+        return self.dirpath / self.config.filename.format(
            epoch=trainer.current_epoch, step=trainer.global_step
         )
 
@@ -41,5 +60,22 @@ class LatestEpochCheckpoint(Checkpoint):
 
         # Save the new checkpoint
         filepath = self._ckpt_path(trainer)
-        trainer.save_checkpoint(filepath, self._save_weights_only)
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)
         self._last_ckpt_path = filepath
+
+        # Create the latest symlink
+        if (symlink_filename := self.config.latest_symlink_filename) is not None:
+            symlink_path = self.dirpath / symlink_filename
+            if symlink_path.exists():
+                symlink_path.unlink()
+            symlink_path.symlink_to(filepath.name)
+            log.info(f"Created latest symlink: {symlink_path}")
+
+    def latest_checkpoint(self):
+        if (symlink_filename := self.config.latest_symlink_filename) is None:
+            return None
+
+        if not (symlink_path := self.dirpath / symlink_filename).exists():
+            return None
+
+        return symlink_path
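The callback is now constructed from a config object rather than keyword arguments. A minimal usage sketch, assuming CallbackConfigBase adds no other required fields and using a hypothetical stand-in for the root config (only the attributes touched by create_callbacks are stubbed):

from pathlib import Path
from types import SimpleNamespace

from nshtrainer.callbacks.latest_epoch_checkpoint import (
    LatestEpochCheckpoint,
    LatestEpochCheckpointCallbackConfig,
)

# Hypothetical stand-in for the runner's root config; a real run would pass
# the project's BaseConfig instead.
root_config = SimpleNamespace(id="run-001", directory=None)

config = LatestEpochCheckpointCallbackConfig(
    dirpath=Path("checkpoints/run-001"),  # explicit dirpath skips directory resolution
    latest_symlink_filename="latest.ckpt",
)

# create_callbacks() yields LatestEpochCheckpoint(config, dirpath) instances
# that would normally be handed to the Lightning Trainer's callbacks list.
callbacks: list[LatestEpochCheckpoint] = list(config.create_callbacks(root_config))

Holding the config on the callback keeps the filename template and symlink behaviour in one serializable place instead of scattering them across constructor arguments.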
nshtrainer/callbacks/model_checkpoint.py ADDED
@@ -0,0 +1,187 @@
+import re
+from datetime import timedelta
+from logging import getLogger
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal
+
+from lightning.pytorch.callbacks.model_checkpoint import (
+    ModelCheckpoint as _ModelCheckpoint,
+)
+from typing_extensions import override
+
+from ..metrics import MetricConfig
+from .base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
+log = getLogger(__name__)
+
+
+def _convert_string(input_string: str):
+    # Find all variables enclosed in curly braces
+    variables = re.findall(r"\{(.*?)\}", input_string)
+
+    # Replace each variable with its corresponding key-value pair
+    output_string = input_string
+    for variable in variables:
+        # If the name is something like {variable:format}, we shouldn't process the format.
+        key_name = variable
+        if ":" in variable:
+            key_name, _ = variable.split(":", 1)
+            continue
+
+        # Replace '/' with '_' in the key name
+        key_name = key_name.replace("/", "_")
+        output_string = output_string.replace(
+            f"{{{variable}}}", f"{key_name}={{{variable}}}"
+        )
+
+    return output_string
+
+
+class ModelCheckpointCallbackConfig(CallbackConfigBase):
+    """Arguments for the ModelCheckpoint callback."""
+
+    kind: Literal["model_checkpoint"] = "model_checkpoint"
+
+    dirpath: str | Path | None = None
+    """
+    Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
+    """
+
+    filename: str | None = None
+    """
+    Checkpoint filename.
+    If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
+    """
+
+    metric: MetricConfig | None = None
+    """
+    Metric to monitor for saving checkpoints.
+    If None, the primary metric of the runner will be used, if available.
+    """
+
+    verbose: bool = False
+    """Verbosity mode. If True, print additional information about checkpoints."""
+
+    save_last: Literal[True, False, "link"] | None = "link"
+    """
+    Whether to save the last checkpoint.
+    If True, saves a copy of the last checkpoint separately.
+    If "link", creates a symbolic link to the last checkpoint.
+    """
+
+    save_top_k: int = 1
+    """
+    Number of best models to save.
+    If -1, all models are saved.
+    If 0, no models are saved.
+    """
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    auto_insert_metric_name: bool = True
+    """Whether to automatically insert the metric name in the checkpoint filename."""
+
+    every_n_train_steps: int | None = None
+    """
+    Number of training steps between checkpoints.
+    If None or 0, no checkpoints are saved during training.
+    """
+
+    train_time_interval: timedelta | None = None
+    """
+    Time interval between checkpoints during training.
+    If None, no checkpoints are saved during training based on time.
+    """
+
+    every_n_epochs: int | None = None
+    """
+    Number of epochs between checkpoints.
+    If None or 0, no checkpoints are saved at the end of epochs.
+    """
+
+    save_on_train_epoch_end: bool | None = None
+    """
+    Whether to run checkpointing at the end of the training epoch.
+    If False, checkpointing runs at the end of the validation.
+    """
+
+    enable_version_counter: bool = True
+    """Whether to append a version to the existing file name."""
+
+    auto_append_metric: bool = True
+    """If enabled, this will automatically add "-{monitor}" to the filename."""
+
+    def metric_or_default(self, root_config: "BaseConfig"):
+        if self.metric is not None:
+            return self.metric
+        if root_config.primary_metric is not None:
+            return root_config.primary_metric
+        raise ValueError("Primary metric must be provided if metric is not specified.")
+
+    def resolve_filename(self, root_config: "BaseConfig"):
+        metric = self.metric_or_default(root_config)
+
+        filename = self.filename
+        if not filename:
+            filename = "{epoch}-{step}"
+        if self.auto_append_metric:
+            filename = f"{filename}-{{{metric.validation_monitor}}}"
+
+        if self.auto_insert_metric_name and filename:
+            new_filename = _convert_string(filename)
+            log.critical(
+                f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
+            )
+            filename = new_filename
+
+        return filename
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+
+        metric = self.metric_or_default(root_config)
+        filename = self.resolve_filename(root_config)
+
+        yield ModelCheckpoint(
+            self,
+            dirpath=Path(dirpath),
+            filename=filename,
+            metric=metric,
+        )
+
+
+class ModelCheckpoint(_ModelCheckpoint):
+    @override
+    def __init__(
+        self,
+        config: ModelCheckpointCallbackConfig,
+        dirpath: Path,
+        filename: str,
+        metric: MetricConfig,
+    ):
+        self.config = config
+        del config
+
+        super().__init__(
+            dirpath=dirpath,
+            filename=filename,
+            monitor=metric.validation_monitor,
+            mode=metric.mode,
+            verbose=self.config.verbose,
+            save_last=self.config.save_last,
+            save_top_k=self.config.save_top_k,
+            save_weights_only=self.config.save_weights_only,
+            auto_insert_metric_name=False,
+            every_n_train_steps=self.config.every_n_train_steps,
+            train_time_interval=self.config.train_time_interval,
+            every_n_epochs=self.config.every_n_epochs,
+            save_on_train_epoch_end=self.config.save_on_train_epoch_end,
+            enable_version_counter=self.config.enable_version_counter,
+        )
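The filename rewriting above is easiest to see with a concrete input. A small worked example of the private _convert_string helper added in this file (the template strings are illustrative):

from nshtrainer.callbacks.model_checkpoint import _convert_string

# Bare "{name}" placeholders get a sanitized "key=" prefix so the metric name
# shows up in the checkpoint filename; "/" in the key becomes "_".
print(_convert_string("{epoch}-{step}-{val/loss}"))
# epoch={epoch}-step={step}-val_loss={val/loss}

# Placeholders with a format spec (a ":" inside the braces) are left untouched.
print(_convert_string("{epoch:02d}-{val/loss}"))
# {epoch:02d}-val_loss={val/loss}

This is also why the callback passes auto_insert_metric_name=False to the Lightning ModelCheckpoint: the metric name is inserted up front, on the template, rather than at save time.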
nshtrainer/callbacks/norm_logging.py CHANGED
@@ -180,7 +180,7 @@ class NormLoggingConfig(CallbackConfigBase):
         )
 
     @override
-    def construct_callbacks(self, root_config):
+    def create_callbacks(self, root_config):
         if not self:
             return
 
nshtrainer/callbacks/on_exception_checkpoint.py CHANGED
@@ -1,16 +1,82 @@
+import contextlib
 import datetime
 import logging
 import os
-from typing import Any
+from pathlib import Path
+from typing import Any, Literal
 
-from lightning.pytorch import Trainer
+from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import OnExceptionCheckpoint as _OnExceptionCheckpoint
 from typing_extensions import override
 
+from .base import CallbackConfigBase
+
 log = logging.getLogger(__name__)
 
 
+@contextlib.contextmanager
+def _monkey_patch_disable_barrier(trainer: LightningTrainer):
+    """
+    Monkey-patch the strategy instance to make the barrier operation a no-op.
+
+    We do this because `save_checkpoint` calls `barrier`. This is okay in most
+    cases, but when we want to save a checkpoint in the case of an exception,
+    `barrier` causes a deadlock. So we monkey-patch the strategy instance to
+    make the barrier operation a no-op.
+    """
+
+    # We monkey-patch the barrier method to do nothing.
+    original_barrier = trainer.strategy.barrier
+
+    def new_barrier(*args, **kwargs):
+        log.warning("Monkey-patched no-op barrier.")
+        pass
+
+    trainer.strategy.barrier = new_barrier
+    log.warning("Monkey-patched barrier to no-op.")
+
+    try:
+        yield
+    finally:
+        trainer.strategy.barrier = original_barrier
+        log.warning("Reverted monkey-patched barrier.")
+
+
+class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
+    kind: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"
+
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str | None = None
+    """Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""
+
+    @override
+    def create_callbacks(self, root_config):
+        from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint
+
+        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+
+        if not (filename := self.filename):
+            filename = f"on_exception_{root_config.id}"
+        yield OnExceptionCheckpoint(self, dirpath=Path(dirpath), filename=filename)
+
+
 class OnExceptionCheckpoint(_OnExceptionCheckpoint):
+    @override
+    def __init__(
+        self,
+        config: OnExceptionCheckpointCallbackConfig,
+        dirpath: Path,
+        filename: str,
+    ):
+        self.config = config
+        del config
+
+        super().__init__(dirpath, filename)
+
     @property
     @override
     def ckpt_path(self) -> str:
@@ -22,23 +88,11 @@ class OnExceptionCheckpoint(_OnExceptionCheckpoint):
         return f"{ckpt_path}_{timestamp}{ext}"
 
     @override
-    def on_exception(self, trainer: Trainer, *_: Any, **__: Any) -> None:
-        # We override this to checkpoint the model manually,
-        # without calling the dist barrier.
-
-        # trainer.save_checkpoint(self.ckpt_path)
-
-        if trainer.model is None:
-            raise AttributeError(
-                "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
-                " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
-            )
-        checkpoint = trainer._checkpoint_connector.dump_checkpoint(weights_only=False)
-        trainer.strategy.save_checkpoint(
-            checkpoint, self.ckpt_path, storage_options=None
-        )
-        # self.strategy.barrier("Trainer.save_checkpoint")  # <-- This is disabled
-
-    @override
-    def teardown(self, trainer: Trainer, *_: Any, **__: Any) -> None:
-        trainer.strategy.remove_checkpoint(self.ckpt_path)
+    def on_exception(self, trainer: LightningTrainer, *args: Any, **kwargs: Any):
+        # Monkey-patch the strategy instance to make the barrier operation a no-op.
+        # We do this because `save_checkpoint` calls `barrier`. This is okay in most
+        # cases, but when we want to save a checkpoint in the case of an exception,
+        # `barrier` causes a deadlock. So we monkey-patch the strategy instance to
+        # make the barrier operation a no-op.
+        with _monkey_patch_disable_barrier(trainer):
+            return super().on_exception(trainer, *args, **kwargs)
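The barrier workaround is a plain "swap an attribute, restore it in finally" context manager. A self-contained toy sketch of the same pattern; the Strategy class here is a hypothetical stand-in, not a Lightning type:

import contextlib


class Strategy:
    """Hypothetical stand-in for trainer.strategy, only for illustration."""

    def barrier(self, name: str | None = None) -> None:
        print(f"real barrier: {name}")


@contextlib.contextmanager
def disable_barrier(strategy: Strategy):
    original = strategy.barrier
    strategy.barrier = lambda *args, **kwargs: None  # no-op while the block runs
    try:
        yield
    finally:
        strategy.barrier = original  # restored even if checkpointing raised


strategy = Strategy()
with disable_barrier(strategy):
    strategy.barrier("save_checkpoint")  # silently skipped
strategy.barrier("teardown")  # prints "real barrier: teardown"

Compared to 0.9.1, which re-implemented the save logic to sidestep the barrier, this keeps the upstream on_exception and teardown implementations and only neutralizes the barrier while the checkpoint is written.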
nshtrainer/callbacks/print_table.py CHANGED
@@ -86,5 +86,5 @@ class PrintTableMetricsConfig(CallbackConfigBase):
     """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
 
     @override
-    def construct_callbacks(self, root_config):
+    def create_callbacks(self, root_config):
         yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
nshtrainer/callbacks/throughput_monitor.py CHANGED
@@ -52,5 +52,5 @@ class ThroughputMonitorConfig(CallbackConfigBase):
     """Number of batches to use for a rolling average."""
 
     @override
-    def construct_callbacks(self, root_config):
+    def create_callbacks(self, root_config):
         yield ThroughputMonitor(window_size=self.window_size)
nshtrainer/callbacks/timer.py CHANGED
@@ -153,5 +153,5 @@ class EpochTimerConfig(CallbackConfigBase):
     name: Literal["epoch_timer"] = "epoch_timer"
 
     @override
-    def construct_callbacks(self, root_config):
+    def create_callbacks(self, root_config):
         yield EpochTimer()
nshtrainer/callbacks/wandb_watch.py CHANGED
@@ -99,5 +99,5 @@ class WandbWatchConfig(CallbackConfigBase):
         return self.enabled
 
     @override
-    def construct_callbacks(self, root_config):
+    def create_callbacks(self, root_config):
         yield WandbWatchCallback(self)
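All of the configs above now expose the same create_callbacks hook (renamed from construct_callbacks). A hedged sketch of a user-defined config following the same pattern; MyCallback, the "my_callback" kind, and the assumption that CallbackConfigBase requires nothing beyond this are illustrative:

from typing import Literal

from lightning.pytorch.callbacks import Callback
from typing_extensions import override

from nshtrainer.callbacks.base import CallbackConfigBase


class MyCallback(Callback):
    """Hypothetical callback; does nothing, shown only to illustrate the hook."""


class MyCallbackConfig(CallbackConfigBase):
    kind: Literal["my_callback"] = "my_callback"

    @override
    def create_callbacks(self, root_config):
        # A generator, so one config can contribute several callbacks to the trainer.
        yield MyCallback()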
nshtrainer/ll/__init__.py CHANGED
@@ -21,7 +21,6 @@ from .log import init_python_logging as init_python_logging
 from .log import lovely as lovely
 from .log import pretty as pretty
 from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
-from .model import ActSaveConfig as ActSaveConfig
 from .model import Base as Base
 from .model import BaseConfig as BaseConfig
 from .model import BaseLoggerConfig as BaseLoggerConfig
nshtrainer/ll/actsave.py CHANGED
@@ -1,3 +1,4 @@
 from nshutils.actsave import *  # type: ignore # noqa: F403
 
-from nshtrainer.actsave import *  # type: ignore # noqa: F403
+from nshtrainer.callbacks.actsave import ActSaveCallback as ActSaveCallback
+from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
nshtrainer/metrics/__init__.py ADDED
@@ -0,0 +1 @@
+from ._config import MetricConfig as MetricConfig
nshtrainer/metrics/_config.py ADDED
@@ -0,0 +1,37 @@
+import builtins
+from typing import Literal
+
+import nshconfig as C
+
+
+class MetricConfig(C.Config):
+    name: str
+    """The name of the primary metric."""
+
+    mode: Literal["min", "max"]
+    """
+    The mode of the primary metric:
+    - "min" for metrics that should be minimized (e.g., loss)
+    - "max" for metrics that should be maximized (e.g., accuracy)
+    """
+
+    @property
+    def validation_monitor(self) -> str:
+        return f"val/{self.name}"
+
+    def __post_init__(self):
+        for split in ("train", "val", "test", "predict"):
+            if self.name.startswith(f"{split}/"):
+                raise ValueError(
+                    f"Primary metric name should not start with '{split}/'. "
+                    f"Just use '{self.name[len(split) + 1:]}' instead. "
+                    "The split name is automatically added depending on the context."
+                )
+
+    @classmethod
+    def loss(cls, mode: Literal["min", "max"] = "min"):
+        return cls(name="loss", mode=mode)
+
+    @property
+    def best(self):
+        return builtins.min if self.mode == "min" else builtins.max
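A short sketch of how the new MetricConfig behaves, following the definitions above (the accuracy metric and the values are illustrative):

from nshtrainer.metrics import MetricConfig

metric = MetricConfig(name="accuracy", mode="max")

print(metric.validation_monitor)  # val/accuracy -- the key checkpoint callbacks monitor
print(metric.best([0.71, 0.84, 0.79]))  # 0.84, since mode="max" maps best to builtins.max

# The loss() convenience constructor defaults to mode="min".
print(MetricConfig.loss().validation_monitor)  # val/loss

# Per __post_init__ above, names that already carry a split prefix are rejected:
# MetricConfig(name="val/accuracy", mode="max")  # ValueError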
nshtrainer/model/__init__.py CHANGED
@@ -1,8 +1,18 @@
 from typing_extensions import TypeAlias
 
+from ._environment import (
+    EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
+)
+from ._environment import EnvironmentConfig as EnvironmentConfig
+from ._environment import (
+    EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
+)
+from ._environment import (
+    EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
+)
+from ._environment import EnvironmentSnapshotConfig as EnvironmentSnapshotConfig
 from .base import Base as Base
 from .base import LightningModuleBase as LightningModuleBase
-from .config import ActSaveConfig as ActSaveConfig
 from .config import BaseConfig as BaseConfig
 from .config import BaseLoggerConfig as BaseLoggerConfig
 from .config import BaseProfilerConfig as BaseProfilerConfig
@@ -10,16 +20,6 @@ from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
 from .config import CheckpointSavingConfig as CheckpointSavingConfig
 from .config import DirectoryConfig as DirectoryConfig
 from .config import EarlyStoppingConfig as EarlyStoppingConfig
-from .config import (
-    EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
-)
-from .config import EnvironmentConfig as EnvironmentConfig
-from .config import (
-    EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
-)
-from .config import (
-    EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
-)
 from .config import GradientClippingConfig as GradientClippingConfig
 from .config import (
     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,