nshtrainer 0.8.7__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. nshtrainer/__init__.py +2 -1
  2. nshtrainer/callbacks/__init__.py +17 -1
  3. nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
  4. nshtrainer/callbacks/base.py +7 -5
  5. nshtrainer/callbacks/ema.py +1 -1
  6. nshtrainer/callbacks/finite_checks.py +1 -1
  7. nshtrainer/callbacks/gradient_skipping.py +1 -1
  8. nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
  9. nshtrainer/callbacks/model_checkpoint.py +187 -0
  10. nshtrainer/callbacks/norm_logging.py +1 -1
  11. nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
  12. nshtrainer/callbacks/print_table.py +1 -1
  13. nshtrainer/callbacks/throughput_monitor.py +1 -1
  14. nshtrainer/callbacks/timer.py +1 -1
  15. nshtrainer/callbacks/wandb_watch.py +1 -1
  16. nshtrainer/ll/__init__.py +0 -1
  17. nshtrainer/ll/actsave.py +2 -1
  18. nshtrainer/metrics/__init__.py +1 -0
  19. nshtrainer/metrics/_config.py +37 -0
  20. nshtrainer/model/__init__.py +11 -11
  21. nshtrainer/model/_environment.py +777 -0
  22. nshtrainer/model/base.py +5 -114
  23. nshtrainer/model/config.py +92 -507
  24. nshtrainer/model/modules/logger.py +11 -6
  25. nshtrainer/runner.py +3 -6
  26. nshtrainer/trainer/_checkpoint_metadata.py +102 -0
  27. nshtrainer/trainer/_checkpoint_resolver.py +319 -0
  28. nshtrainer/trainer/_runtime_callback.py +120 -0
  29. nshtrainer/trainer/checkpoint_connector.py +63 -0
  30. nshtrainer/trainer/signal_connector.py +12 -9
  31. nshtrainer/trainer/trainer.py +111 -31
  32. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
  33. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
  34. nshtrainer/actsave/__init__.py +0 -3
  35. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py CHANGED
@@ -2,13 +2,14 @@ from . import _experimental as _experimental
  from . import callbacks as callbacks
  from . import data as data
  from . import lr_scheduler as lr_scheduler
+ from . import metrics as metrics
  from . import model as model
  from . import nn as nn
  from . import optimizer as optimizer
+ from .metrics import MetricConfig as MetricConfig
  from .model import Base as Base
  from .model import BaseConfig as BaseConfig
  from .model import ConfigList as ConfigList
  from .model import LightningModuleBase as LightningModuleBase
- from .model import MetricConfig as MetricConfig
  from .runner import Runner as Runner
  from .trainer import Trainer as Trainer
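MetricConfig moves from nshtrainer.model to the new nshtrainer.metrics module and is re-exported at the package root. An import-path sketch (illustrative only; whether the old nshtrainer.model path still resolves is not shown in this diff):

    # nshtrainer 0.8.7
    from nshtrainer.model import MetricConfig

    # nshtrainer 0.10.0
    from nshtrainer.metrics import MetricConfig
    # or, via the root re-export:
    from nshtrainer import MetricConfig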
nshtrainer/callbacks/__init__.py CHANGED
@@ -14,15 +14,27 @@ from .interval import EpochIntervalCallback as EpochIntervalCallback
  from .interval import IntervalCallback as IntervalCallback
  from .interval import StepIntervalCallback as StepIntervalCallback
  from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
+ from .latest_epoch_checkpoint import (
+     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
+ )
  from .log_epoch import LogEpochCallback as LogEpochCallback
+ from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
+ from .model_checkpoint import (
+     ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
+ )
  from .norm_logging import NormLoggingCallback as NormLoggingCallback
  from .norm_logging import NormLoggingConfig as NormLoggingConfig
  from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+ from .on_exception_checkpoint import (
+     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+ )
  from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
  from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
  from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
  from .timer import EpochTimer as EpochTimer
  from .timer import EpochTimerConfig as EpochTimerConfig
+ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
+ from .wandb_watch import WandbWatchConfig as WandbWatchConfig
 
  CallbackConfig = Annotated[
      ThroughputMonitorConfig
@@ -31,6 +43,10 @@ CallbackConfig = Annotated[
      | FiniteChecksConfig
      | NormLoggingConfig
      | GradientSkippingConfig
-     | EMAConfig,
+     | EMAConfig
+     | ModelCheckpointCallbackConfig
+     | LatestEpochCheckpointCallbackConfig
+     | OnExceptionCheckpointCallbackConfig
+     | WandbWatchConfig,
      C.Field(discriminator="name"),
  ]
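Besides extending the CallbackConfig union, these re-exports make the new checkpoint and wandb-watch classes importable directly from nshtrainer.callbacks:

    from nshtrainer.callbacks import (
        LatestEpochCheckpointCallbackConfig,
        ModelCheckpoint,
        ModelCheckpointCallbackConfig,
        OnExceptionCheckpointCallbackConfig,
        WandbWatchCallback,
        WandbWatchConfig,
    )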
nshtrainer/{actsave/_callback.py → callbacks/actsave.py} RENAMED
@@ -1,28 +1,87 @@
  import contextlib
- from typing import TYPE_CHECKING, Literal, cast
+ from pathlib import Path
+ from typing import Literal
 
  from lightning.pytorch import LightningModule, Trainer
  from lightning.pytorch.callbacks.callback import Callback
- from nshutils.actsave import ActSave
  from typing_extensions import TypeAlias, override
 
- if TYPE_CHECKING:
-     from ..model.config import BaseConfig
+ from .base import CallbackConfigBase
+
+ try:
+     from nshutils import ActSave  # type: ignore
+ except ImportError:
+     ActSave = None
 
  Stage: TypeAlias = Literal["train", "validation", "test", "predict"]
 
 
+ class ActSaveConfig(CallbackConfigBase):
+     enabled: bool = True
+     """Enable activation saving."""
+
+     save_dir: Path | None = None
+     """Directory to save activations to. If None, will use the activation directory set in `config.directory`."""
+
+     def __bool__(self):
+         return self.enabled
+
+     @override
+     def create_callbacks(self, root_config):
+         yield ActSaveCallback(
+             self,
+             self.save_dir
+             or root_config.directory.resolve_subdirectory(root_config.id, "activation"),
+         )
+
+
  class ActSaveCallback(Callback):
-     def __init__(self):
+     def __init__(self, config: ActSaveConfig, save_dir: Path):
          super().__init__()
 
+         self.config = config
+         self.save_dir = save_dir
+         self._enabled_context: contextlib._GeneratorContextManager | None = None
          self._active_contexts: dict[Stage, contextlib._GeneratorContextManager] = {}
 
+     @override
+     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+         super().setup(trainer, pl_module, stage)
+
+         if not self.config:
+             return
+
+         if ActSave is None:
+             raise ImportError(
+                 "ActSave is not installed. Please install nshutils to use the ActSaveCallback."
+             )
+
+         context = ActSave.enabled(self.save_dir)
+         context.__enter__()
+         self._enabled_context = context
+
+     @override
+     def teardown(
+         self, trainer: Trainer, pl_module: LightningModule, stage: str
+     ) -> None:
+         super().teardown(trainer, pl_module, stage)
+
+         if not self.config:
+             return
+
+         if self._enabled_context is not None:
+             self._enabled_context.__exit__(None, None, None)
+             self._enabled_context = None
+
      def _on_start(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
-         hparams = cast("BaseConfig", pl_module.hparams)
-         if not hparams.trainer.actsave:
+         if not self.config:
              return
 
+         if ActSave is None:
+             raise ImportError(
+                 "ActSave is not installed. Please install nshutils to use the ActSaveCallback."
+             )
+
          # If we have an active context manager for this stage, exit it
          if active_contexts := self._active_contexts.get(stage):
              active_contexts.__exit__(None, None, None)
@@ -33,12 +92,11 @@ class ActSaveCallback(Callback):
          self._active_contexts[stage] = context
 
      def _on_end(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
-         hparams = cast("BaseConfig", pl_module.hparams)
-         if not hparams.trainer.actsave:
+         if not self.config:
              return
 
          # If we have an active context manager for this stage, exit it
-         if active_contexts := self._active_contexts.get(stage):
+         if (active_contexts := self._active_contexts.pop(stage, None)) is not None:
              active_contexts.__exit__(None, None, None)
 
      @override
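A minimal usage sketch of the new ActSaveConfig (not from the diff; the directory path is illustrative, and nshutils only needs to be installed once the callback's setup runs):

    from pathlib import Path
    from nshtrainer.callbacks.actsave import ActSaveConfig

    config = ActSaveConfig(enabled=True, save_dir=Path("./activations"))
    # With save_dir set explicitly, create_callbacks never consults the root config's
    # directory helper, so passing None is enough for this sketch.
    (callback,) = config.create_callbacks(None)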
nshtrainer/callbacks/base.py CHANGED
@@ -46,16 +46,16 @@ class CallbackConfigBase(C.Config, ABC):
      )
 
      @abstractmethod
-     def construct_callbacks(
+     def create_callbacks(
          self, root_config: "BaseConfig"
      ) -> Iterable[Callback | CallbackWithMetadata]: ...
 
 
  # region Config resolution helpers
- def _construct_callbacks_with_metadata(
+ def _create_callbacks_with_metadata(
      config: CallbackConfigBase, root_config: "BaseConfig"
  ) -> Iterable[CallbackWithMetadata]:
-     for callback in config.construct_callbacks(root_config):
+     for callback in config.create_callbacks(root_config):
          if isinstance(callback, CallbackWithMetadata):
              yield callback
              continue
@@ -99,12 +99,14 @@ def _process_and_filter_callbacks(
 
  def resolve_all_callbacks(root_config: "BaseConfig"):
      callback_configs = [
-         config for config in root_config.ll_all_callback_configs() if config is not None
+         config
+         for config in root_config._nshtrainer_all_callback_configs()
+         if config is not None
      ]
      callbacks = _process_and_filter_callbacks(
          callback
          for callback_config in callback_configs
-         for callback in _construct_callbacks_with_metadata(callback_config, root_config)
+         for callback in _create_callbacks_with_metadata(callback_config, root_config)
      )
      return callbacks
 
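The construct_callbacks → create_callbacks rename propagates to every bundled callback config in the hunks below. A minimal sketch of a custom config against the renamed hook (PrintStartConfig and its message field are hypothetical, not part of nshtrainer):

    from lightning.pytorch.callbacks import LambdaCallback
    from typing_extensions import override

    from nshtrainer.callbacks.base import CallbackConfigBase

    class PrintStartConfig(CallbackConfigBase):
        message: str = "training started"

        @override
        def create_callbacks(self, root_config):  # was construct_callbacks in 0.8.7
            # Yield any Lightning callback(s); here a simple on_train_start hook.
            yield LambdaCallback(
                on_train_start=lambda trainer, pl_module: print(self.message)
            )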
nshtrainer/callbacks/ema.py CHANGED
@@ -374,7 +374,7 @@ class EMAConfig(CallbackConfigBase):
      """Offload weights to CPU."""
 
      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          yield EMA(
              decay=self.decay,
              validate_original_weights=self.validate_original_weights,
nshtrainer/callbacks/finite_checks.py CHANGED
@@ -68,7 +68,7 @@ class FiniteChecksConfig(CallbackConfigBase):
      """Whether to check for None gradients"""
 
      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          yield FiniteChecksCallback(
              nonfinite_grads=self.nonfinite_grads,
              none_grads=self.none_grads,
nshtrainer/callbacks/gradient_skipping.py CHANGED
@@ -99,5 +99,5 @@ class GradientSkippingConfig(CallbackConfigBase):
      """
 
      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          yield GradientSkipping(self)
nshtrainer/callbacks/latest_epoch_checkpoint.py CHANGED
@@ -1,35 +1,54 @@
  import logging
  from pathlib import Path
+ from typing import Literal
 
- from lightning.fabric.utilities.types import _PATH
  from lightning.pytorch import LightningModule, Trainer
  from lightning.pytorch.callbacks import Checkpoint
  from typing_extensions import override
 
+ from .base import CallbackConfigBase
+
  log = logging.getLogger(__name__)
 
 
+ class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
+     kind: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
+
+     dirpath: str | Path | None = None
+     """Directory path to save the checkpoint file."""
+
+     filename: str = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
+     """Checkpoint filename. This must not include the extension."""
+
+     save_weights_only: bool = False
+     """Whether to save only the model's weights or the entire model object."""
+
+     latest_symlink_filename: str | None = "latest.ckpt"
+     """Filename for the latest symlink. If None, no symlink will be created."""
+
+     @override
+     def create_callbacks(self, root_config):
+         dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+             root_config.id, "checkpoint"
+         )
+         dirpath = Path(dirpath)
+
+         yield LatestEpochCheckpoint(self, dirpath)
+
+
  class LatestEpochCheckpoint(Checkpoint):
-     DEFAULT_FILENAME = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
-
-     def __init__(
-         self,
-         dirpath: _PATH,
-         filename: str | None = None,
-         save_weights_only: bool = False,
-     ):
+     def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
          super().__init__()
 
-         self._dirpath = Path(dirpath)
-         self._filename = filename or self.DEFAULT_FILENAME
-         self._save_weights_only = save_weights_only
+         self.config = config
+         self.dirpath = dirpath
 
          # Also, we hold a reference to the last checkpoint path
          # to be able to remove it when a new checkpoint is saved.
          self._last_ckpt_path: Path | None = None
 
      def _ckpt_path(self, trainer: Trainer):
-         return self._dirpath / self._filename.format(
+         return self.dirpath / self.config.filename.format(
              epoch=trainer.current_epoch, step=trainer.global_step
          )
 
@@ -41,5 +60,22 @@ class LatestEpochCheckpoint(Checkpoint):
 
          # Save the new checkpoint
          filepath = self._ckpt_path(trainer)
-         trainer.save_checkpoint(filepath, self._save_weights_only)
+         trainer.save_checkpoint(filepath, self.config.save_weights_only)
          self._last_ckpt_path = filepath
+
+         # Create the latest symlink
+         if (symlink_filename := self.config.latest_symlink_filename) is not None:
+             symlink_path = self.dirpath / symlink_filename
+             if symlink_path.exists():
+                 symlink_path.unlink()
+             symlink_path.symlink_to(filepath.name)
+             log.info(f"Created latest symlink: {symlink_path}")
+
+     def latest_checkpoint(self):
+         if (symlink_filename := self.config.latest_symlink_filename) is None:
+             return None
+
+         if not (symlink_path := self.dirpath / symlink_filename).exists():
+             return None
+
+         return symlink_path
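A short illustration of the default filename template and the symlink behavior defined above (paths are illustrative):

    filename = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
    print(filename.format(epoch=3, step=250))  # latest_epoch03_step0250.ckpt
    # After every save, the "latest.ckpt" symlink in the same directory is re-pointed
    # at the newest file (unless latest_symlink_filename is None), and
    # LatestEpochCheckpoint.latest_checkpoint() returns that symlink path.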
nshtrainer/callbacks/model_checkpoint.py ADDED
@@ -0,0 +1,187 @@
+ import re
+ from datetime import timedelta
+ from logging import getLogger
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Literal
+
+ from lightning.pytorch.callbacks.model_checkpoint import (
+     ModelCheckpoint as _ModelCheckpoint,
+ )
+ from typing_extensions import override
+
+ from ..metrics import MetricConfig
+ from .base import CallbackConfigBase
+
+ if TYPE_CHECKING:
+     from ..model.config import BaseConfig
+
+ log = getLogger(__name__)
+
+
+ def _convert_string(input_string: str):
+     # Find all variables enclosed in curly braces
+     variables = re.findall(r"\{(.*?)\}", input_string)
+
+     # Replace each variable with its corresponding key-value pair
+     output_string = input_string
+     for variable in variables:
+         # If the name is something like {variable:format}, we shouldn't process the format.
+         key_name = variable
+         if ":" in variable:
+             key_name, _ = variable.split(":", 1)
+             continue
+
+         # Replace '/' with '_' in the key name
+         key_name = key_name.replace("/", "_")
+         output_string = output_string.replace(
+             f"{{{variable}}}", f"{key_name}={{{variable}}}"
+         )
+
+     return output_string
+
+
+ class ModelCheckpointCallbackConfig(CallbackConfigBase):
+     """Arguments for the ModelCheckpoint callback."""
+
+     kind: Literal["model_checkpoint"] = "model_checkpoint"
+
+     dirpath: str | Path | None = None
+     """
+     Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
+     """
+
+     filename: str | None = None
+     """
+     Checkpoint filename.
+     If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
+     """
+
+     metric: MetricConfig | None = None
+     """
+     Metric to monitor for saving checkpoints.
+     If None, the primary metric of the runner will be used, if available.
+     """
+
+     verbose: bool = False
+     """Verbosity mode. If True, print additional information about checkpoints."""
+
+     save_last: Literal[True, False, "link"] | None = "link"
+     """
+     Whether to save the last checkpoint.
+     If True, saves a copy of the last checkpoint separately.
+     If "link", creates a symbolic link to the last checkpoint.
+     """
+
+     save_top_k: int = 1
+     """
+     Number of best models to save.
+     If -1, all models are saved.
+     If 0, no models are saved.
+     """
+
+     save_weights_only: bool = False
+     """Whether to save only the model's weights or the entire model object."""
+
+     auto_insert_metric_name: bool = True
+     """Whether to automatically insert the metric name in the checkpoint filename."""
+
+     every_n_train_steps: int | None = None
+     """
+     Number of training steps between checkpoints.
+     If None or 0, no checkpoints are saved during training.
+     """
+
+     train_time_interval: timedelta | None = None
+     """
+     Time interval between checkpoints during training.
+     If None, no checkpoints are saved during training based on time.
+     """
+
+     every_n_epochs: int | None = None
+     """
+     Number of epochs between checkpoints.
+     If None or 0, no checkpoints are saved at the end of epochs.
+     """
+
+     save_on_train_epoch_end: bool | None = None
+     """
+     Whether to run checkpointing at the end of the training epoch.
+     If False, checkpointing runs at the end of the validation.
+     """
+
+     enable_version_counter: bool = True
+     """Whether to append a version to the existing file name."""
+
+     auto_append_metric: bool = True
+     """If enabled, this will automatically add "-{monitor}" to the filename."""
+
+     def metric_or_default(self, root_config: "BaseConfig"):
+         if self.metric is not None:
+             return self.metric
+         if root_config.primary_metric is not None:
+             return root_config.primary_metric
+         raise ValueError("Primary metric must be provided if metric is not specified.")
+
+     def resolve_filename(self, root_config: "BaseConfig"):
+         metric = self.metric_or_default(root_config)
+
+         filename = self.filename
+         if not filename:
+             filename = "{epoch}-{step}"
+         if self.auto_append_metric:
+             filename = f"{filename}-{{{metric.validation_monitor}}}"
+
+         if self.auto_insert_metric_name and filename:
+             new_filename = _convert_string(filename)
+             log.critical(
+                 f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
+             )
+             filename = new_filename
+
+         return filename
+
+     @override
+     def create_callbacks(self, root_config):
+         dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+             root_config.id, "checkpoint"
+         )
+
+         metric = self.metric_or_default(root_config)
+         filename = self.resolve_filename(root_config)
+
+         yield ModelCheckpoint(
+             self,
+             dirpath=Path(dirpath),
+             filename=filename,
+             metric=metric,
+         )
+
+
+ class ModelCheckpoint(_ModelCheckpoint):
+     @override
+     def __init__(
+         self,
+         config: ModelCheckpointCallbackConfig,
+         dirpath: Path,
+         filename: str,
+         metric: MetricConfig,
+     ):
+         self.config = config
+         del config
+
+         super().__init__(
+             dirpath=dirpath,
+             filename=filename,
+             monitor=metric.validation_monitor,
+             mode=metric.mode,
+             verbose=self.config.verbose,
+             save_last=self.config.save_last,
+             save_top_k=self.config.save_top_k,
+             save_weights_only=self.config.save_weights_only,
+             auto_insert_metric_name=False,
+             every_n_train_steps=self.config.every_n_train_steps,
+             train_time_interval=self.config.train_time_interval,
+             every_n_epochs=self.config.every_n_epochs,
+             save_on_train_epoch_end=self.config.save_on_train_epoch_end,
+             enable_version_counter=self.config.enable_version_counter,
+         )
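To illustrate the filename handling above: with auto_insert_metric_name enabled, _convert_string prefixes each bare placeholder with a sanitized key name and leaves placeholders that carry a format spec untouched. Assuming a metric whose validation_monitor is "val/loss":

    _convert_string("{epoch}-{step}-{val/loss}")
    # -> "epoch={epoch}-step={step}-val_loss={val/loss}"

    _convert_string("{val/loss:.2f}")
    # -> "{val/loss:.2f}"  (placeholders with a format spec are skipped, not renamed)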
nshtrainer/callbacks/norm_logging.py CHANGED
@@ -180,7 +180,7 @@ class NormLoggingConfig(CallbackConfigBase):
      )
 
      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          if not self:
              return
 
nshtrainer/callbacks/on_exception_checkpoint.py CHANGED
@@ -1,16 +1,82 @@
+ import contextlib
  import datetime
  import logging
  import os
- from typing import Any
+ from pathlib import Path
+ from typing import Any, Literal
 
- from lightning.pytorch import Trainer
+ from lightning.pytorch import Trainer as LightningTrainer
  from lightning.pytorch.callbacks import OnExceptionCheckpoint as _OnExceptionCheckpoint
  from typing_extensions import override
 
+ from .base import CallbackConfigBase
+
  log = logging.getLogger(__name__)
 
 
+ @contextlib.contextmanager
+ def _monkey_patch_disable_barrier(trainer: LightningTrainer):
+     """
+     Monkey-patch the strategy instance to make the barrier operation a no-op.
+
+     We do this because `save_checkpoint` calls `barrier`. This is okay in most
+     cases, but when we want to save a checkpoint in the case of an exception,
+     `barrier` causes a deadlock. So we monkey-patch the strategy instance to
+     make the barrier operation a no-op.
+     """
+
+     # We monkey-patch the barrier method to do nothing.
+     original_barrier = trainer.strategy.barrier
+
+     def new_barrier(*args, **kwargs):
+         log.warning("Monkey-patched no-op barrier.")
+         pass
+
+     trainer.strategy.barrier = new_barrier
+     log.warning("Monkey-patched barrier to no-op.")
+
+     try:
+         yield
+     finally:
+         trainer.strategy.barrier = original_barrier
+         log.warning("Reverted monkey-patched barrier.")
+
+
+ class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
+     kind: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"
+
+     dirpath: str | Path | None = None
+     """Directory path to save the checkpoint file."""
+
+     filename: str | None = None
+     """Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""
+
+     @override
+     def create_callbacks(self, root_config):
+         from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint
+
+         dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+             root_config.id, "checkpoint"
+         )
+
+         if not (filename := self.filename):
+             filename = f"on_exception_{root_config.id}"
+         yield OnExceptionCheckpoint(self, dirpath=Path(dirpath), filename=filename)
+
+
  class OnExceptionCheckpoint(_OnExceptionCheckpoint):
+     @override
+     def __init__(
+         self,
+         config: OnExceptionCheckpointCallbackConfig,
+         dirpath: Path,
+         filename: str,
+     ):
+         self.config = config
+         del config
+
+         super().__init__(dirpath, filename)
+
      @property
      @override
      def ckpt_path(self) -> str:
@@ -22,23 +88,11 @@ class OnExceptionCheckpoint(_OnExceptionCheckpoint):
          return f"{ckpt_path}_{timestamp}{ext}"
 
      @override
-     def on_exception(self, trainer: Trainer, *_: Any, **__: Any) -> None:
-         # We override this to checkpoint the model manually,
-         # without calling the dist barrier.
-
-         # trainer.save_checkpoint(self.ckpt_path)
-
-         if trainer.model is None:
-             raise AttributeError(
-                 "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
-                 " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
-             )
-         checkpoint = trainer._checkpoint_connector.dump_checkpoint(weights_only=False)
-         trainer.strategy.save_checkpoint(
-             checkpoint, self.ckpt_path, storage_options=None
-         )
-         # self.strategy.barrier("Trainer.save_checkpoint")  # <-- This is disabled
-
-     @override
-     def teardown(self, trainer: Trainer, *_: Any, **__: Any) -> None:
-         trainer.strategy.remove_checkpoint(self.ckpt_path)
+     def on_exception(self, trainer: LightningTrainer, *args: Any, **kwargs: Any):
+         # Monkey-patch the strategy instance to make the barrier operation a no-op.
+         # We do this because `save_checkpoint` calls `barrier`. This is okay in most
+         # cases, but when we want to save a checkpoint in the case of an exception,
+         # `barrier` causes a deadlock. So we monkey-patch the strategy instance to
+         # make the barrier operation a no-op.
+         with _monkey_patch_disable_barrier(trainer):
+             return super().on_exception(trainer, *args, **kwargs)
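The same barrier-suppression pattern could in principle be reused around any checkpoint save that must not synchronize across ranks; a minimal sketch, not from the diff (trainer is a hypothetical, already-constructed Lightning Trainer):

    with _monkey_patch_disable_barrier(trainer):
        # strategy.barrier() is a no-op inside this block, so the save cannot
        # deadlock if other ranks are already unwinding from an exception.
        trainer.save_checkpoint("on_exception_manual.ckpt")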
nshtrainer/callbacks/print_table.py CHANGED
@@ -86,5 +86,5 @@ class PrintTableMetricsConfig(CallbackConfigBase):
      """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
 
      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
nshtrainer/callbacks/throughput_monitor.py CHANGED
@@ -52,5 +52,5 @@ class ThroughputMonitorConfig(CallbackConfigBase):
      """Number of batches to use for a rolling average."""
 
      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          yield ThroughputMonitor(window_size=self.window_size)
nshtrainer/callbacks/timer.py CHANGED
@@ -153,5 +153,5 @@ class EpochTimerConfig(CallbackConfigBase):
      name: Literal["epoch_timer"] = "epoch_timer"
 
      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          yield EpochTimer()
nshtrainer/callbacks/wandb_watch.py CHANGED
@@ -99,5 +99,5 @@ class WandbWatchConfig(CallbackConfigBase):
          return self.enabled
 
      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          yield WandbWatchCallback(self)
nshtrainer/ll/__init__.py CHANGED
@@ -21,7 +21,6 @@ from .log import init_python_logging as init_python_logging
  from .log import lovely as lovely
  from .log import pretty as pretty
  from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
- from .model import ActSaveConfig as ActSaveConfig
  from .model import Base as Base
  from .model import BaseConfig as BaseConfig
  from .model import BaseLoggerConfig as BaseLoggerConfig