nshtrainer 0.9.1__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. nshtrainer/__init__.py +2 -1
  2. nshtrainer/_checkpoint/loader.py +319 -0
  3. nshtrainer/_checkpoint/metadata.py +102 -0
  4. nshtrainer/callbacks/__init__.py +17 -1
  5. nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
  6. nshtrainer/callbacks/base.py +7 -5
  7. nshtrainer/callbacks/ema.py +1 -1
  8. nshtrainer/callbacks/finite_checks.py +1 -1
  9. nshtrainer/callbacks/gradient_skipping.py +1 -1
  10. nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
  11. nshtrainer/callbacks/model_checkpoint.py +187 -0
  12. nshtrainer/callbacks/norm_logging.py +1 -1
  13. nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
  14. nshtrainer/callbacks/print_table.py +1 -1
  15. nshtrainer/callbacks/throughput_monitor.py +1 -1
  16. nshtrainer/callbacks/timer.py +1 -1
  17. nshtrainer/callbacks/wandb_watch.py +1 -1
  18. nshtrainer/ll/__init__.py +0 -1
  19. nshtrainer/ll/actsave.py +2 -1
  20. nshtrainer/metrics/__init__.py +1 -0
  21. nshtrainer/metrics/_config.py +37 -0
  22. nshtrainer/model/__init__.py +11 -11
  23. nshtrainer/model/_environment.py +777 -0
  24. nshtrainer/model/base.py +5 -114
  25. nshtrainer/model/config.py +49 -501
  26. nshtrainer/model/modules/logger.py +11 -6
  27. nshtrainer/runner.py +3 -6
  28. nshtrainer/trainer/_runtime_callback.py +120 -0
  29. nshtrainer/trainer/checkpoint_connector.py +63 -0
  30. nshtrainer/trainer/signal_connector.py +12 -9
  31. nshtrainer/trainer/trainer.py +111 -31
  32. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/METADATA +3 -1
  33. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/RECORD +34 -27
  34. nshtrainer/actsave/__init__.py +0 -3
  35. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py CHANGED
@@ -2,13 +2,14 @@ from . import _experimental as _experimental
  from . import callbacks as callbacks
  from . import data as data
  from . import lr_scheduler as lr_scheduler
+ from . import metrics as metrics
  from . import model as model
  from . import nn as nn
  from . import optimizer as optimizer
+ from .metrics import MetricConfig as MetricConfig
  from .model import Base as Base
  from .model import BaseConfig as BaseConfig
  from .model import ConfigList as ConfigList
  from .model import LightningModuleBase as LightningModuleBase
- from .model import MetricConfig as MetricConfig
  from .runner import Runner as Runner
  from .trainer import Trainer as Trainer
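The only user-visible change here is that `MetricConfig` now lives in the new `nshtrainer.metrics` package rather than `nshtrainer.model`; it is still re-exported at the top level. A one-line migration sketch (import paths follow directly from the re-exports above):

# nshtrainer 0.9.x
from nshtrainer.model import MetricConfig

# nshtrainer 0.10.x
from nshtrainer.metrics import MetricConfig
# ...or, as before, via the top-level re-export:
from nshtrainer import MetricConfig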
nshtrainer/_checkpoint/loader.py ADDED
@@ -0,0 +1,319 @@
+ import logging
+ from collections.abc import Iterable, Sequence
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, overload
+
+ import nshconfig as C
+ from lightning.pytorch import Trainer as LightningTrainer
+ from lightning.pytorch.trainer.states import TrainerFn
+ from typing_extensions import assert_never
+
+ from ..metrics._config import MetricConfig
+ from .metadata import METADATA_PATH_SUFFIX, CheckpointMetadata
+
+ if TYPE_CHECKING:
+     from ..model.config import BaseConfig
+
+ log = logging.getLogger(__name__)
+
+
+ class BestCheckpointStrategyConfig(C.Config):
+     name: Literal["best"] = "best"
+
+     metric: MetricConfig | None = None
+     """The metric to use for selecting the best checkpoint. If `None`, the primary metric will be used."""
+
+     additional_candidates: Iterable[Path] = []
+     """Additional checkpoint candidates to consider when selecting the last checkpoint."""
+
+
+ class UserProvidedPathCheckpointStrategyConfig(C.Config):
+     name: Literal["user_provided_path"] = "user_provided_path"
+
+     path: Path
+     """The path to the checkpoint to load."""
+
+     on_error: Literal["warn", "raise"] = "warn"
+     """The behavior when the checkpoint does not belong to the current run.
+
+     - `warn`: Log a warning and skip the checkpoint.
+     - `raise`: Raise an error.
+     """
+
+
+ class LastCheckpointStrategyConfig(C.Config):
+     name: Literal["last"] = "last"
+
+     criterion: Literal["global_step", "runtime"] = "global_step"
+     """The criterion to use for selecting the last checkpoint.
+
+     - `global_step`: The checkpoint with the highest global step will be selected.
+     - `runtime`: The checkpoint with the highest runtime will be selected.
+     """
+
+     additional_candidates: Iterable[Path] = []
+     """Additional checkpoint candidates to consider when selecting the last checkpoint."""
+
+
+ CheckpointLoadingStrategyConfig: TypeAlias = Annotated[
+     BestCheckpointStrategyConfig
+     | LastCheckpointStrategyConfig
+     | UserProvidedPathCheckpointStrategyConfig,
+     C.Field(discriminator="name"),
+ ]
+
+
+ class CheckpointLoadingConfig(C.Config):
+     strategies: Sequence[CheckpointLoadingStrategyConfig]
+     """The strategies to use for loading checkpoints.
+
+     The order of the strategies determines the priority of the strategies.
+     The first strategy that resolves a checkpoint will be used.
+     """
+
+     include_hpc: bool
+     """Whether to include checkpoints from HPC pre-emption."""
+
+     @classmethod
+     def _auto_train(cls, ckpt: Literal["best", "last"] | str | Path | None):
+         if ckpt is None:
+             ckpt = "last"
+         match ckpt:
+             case "best":
+                 return cls(
+                     strategies=[BestCheckpointStrategyConfig()],
+                     include_hpc=True,
+                 )
+             case "last":
+                 return cls(
+                     strategies=[LastCheckpointStrategyConfig()],
+                     include_hpc=True,
+                 )
+             case Path() | str():
+                 ckpt = Path(ckpt)
+                 return cls(
+                     strategies=[
+                         LastCheckpointStrategyConfig(additional_candidates=[ckpt]),
+                         UserProvidedPathCheckpointStrategyConfig(path=ckpt),
+                     ],
+                     include_hpc=True,
+                 )
+             case _:
+                 assert_never(ckpt)
+
+     @classmethod
+     def _auto_eval(cls, ckpt: Literal["best", "last"] | str | Path | None):
+         if ckpt is None:
+             raise ValueError("Checkpoint path must be provided for evaluation.")
+
+         match ckpt:
+             case "best":
+                 return cls(
+                     strategies=[BestCheckpointStrategyConfig()],
+                     include_hpc=False,
+                 )
+             case "last":
+                 return cls(
+                     strategies=[LastCheckpointStrategyConfig()],
+                     include_hpc=False,
+                 )
+             case Path() | str():
+                 ckpt = Path(ckpt)
+                 return cls(
+                     strategies=[UserProvidedPathCheckpointStrategyConfig(path=ckpt)],
+                     include_hpc=False,
+                 )
+             case _:
+                 assert_never(ckpt)
+
+     @classmethod
+     def auto(
+         cls,
+         ckpt: Literal["best", "last"] | str | Path | None,
+         trainer_mode: TrainerFn,
+     ):
+         match trainer_mode:
+             case TrainerFn.FITTING:
+                 return cls._auto_train(ckpt)
+             case TrainerFn.VALIDATING | TrainerFn.TESTING | TrainerFn.PREDICTING:
+                 return cls._auto_eval(ckpt)
+             case _:
+                 assert_never(trainer_mode)
+
+
+ @dataclass
+ class _CkptCandidate:
+     meta: CheckpointMetadata
+     meta_path: Path
+
+     @property
+     def ckpt_path(self):
+         return self.meta_path.with_name(self.meta.checkpoint_filename)
+
+
+ @overload
+ def _load_ckpt_meta(
+     path: Path,
+     root_config: "BaseConfig",
+     on_error: Literal["warn"] = "warn",
+ ) -> _CkptCandidate | None: ...
+ @overload
+ def _load_ckpt_meta(
+     path: Path,
+     root_config: "BaseConfig",
+     on_error: Literal["raise"],
+ ) -> _CkptCandidate: ...
+ def _load_ckpt_meta(
+     path: Path,
+     root_config: "BaseConfig",
+     on_error: Literal["warn", "raise"] = "warn",
+ ):
+     meta = CheckpointMetadata.from_file(path)
+     if root_config.id != meta.run_id:
+         error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
+         match on_error:
+             case "warn":
+                 log.warn(error_msg)
+             case "raise":
+                 raise ValueError(error_msg)
+             case _:
+                 assert_never(on_error)
+         return None
+     return _CkptCandidate(meta, path)
+
+
+ def _checkpoint_candidates(
+     root_config: "BaseConfig",
+     trainer: LightningTrainer,
+     *,
+     include_hpc: bool = True,
+ ):
+     # Load the checkpoint directory, and throw if it doesn't exist.
+     # This indicates a non-standard setup, and we don't want to guess
+     # where the checkpoints are.
+     ckpt_dir = root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+     if not ckpt_dir.is_dir():
+         raise FileNotFoundError(
+             f"Checkpoint directory {ckpt_dir} not found. "
+             "Please ensure that the checkpoint directory exists."
+         )
+
+     # Load all checkpoints in the directory.
+     # We can do this by looking for metadata files.
+     for path in ckpt_dir.glob(f"*{METADATA_PATH_SUFFIX}"):
+         if (meta := _load_ckpt_meta(path, root_config)) is not None:
+             yield meta
+
+     # If we have a pre-empted checkpoint, load it
+     if include_hpc and (hpc_path := trainer._checkpoint_connector._hpc_resume_path):
+         hpc_meta_path = Path(hpc_path).with_suffix(METADATA_PATH_SUFFIX)
+         if (meta := _load_ckpt_meta(hpc_meta_path, root_config)) is not None:
+             yield meta
+
+
+ def _additional_candidates(
+     additional_candidates: Iterable[Path], root_config: "BaseConfig"
+ ):
+     for path in additional_candidates:
+         if (
+             meta := _load_ckpt_meta(path.with_suffix(METADATA_PATH_SUFFIX), root_config)
+         ) is None:
+             continue
+         yield meta
+
+
+ def _resolve_checkpoint(
+     config: CheckpointLoadingConfig,
+     root_config: "BaseConfig",
+     trainer: LightningTrainer,
+ ):
+     # We lazily load the checkpoint candidates to avoid loading them
+     # if they are not needed.
+     _ckpt_candidates: list[_CkptCandidate] | None = None
+
+     def ckpt_candidates():
+         nonlocal _ckpt_candidates, root_config, trainer
+
+         if _ckpt_candidates is None:
+             _ckpt_candidates = list(
+                 _checkpoint_candidates(
+                     root_config, trainer, include_hpc=config.include_hpc
+                 )
+             )
+         return _ckpt_candidates
+
+     # Iterate over the strategies and try to resolve the checkpoint.
+     for strategy in config.strategies:
+         match strategy:
+             case UserProvidedPathCheckpointStrategyConfig():
+                 meta = _load_ckpt_meta(
+                     strategy.path.with_suffix(METADATA_PATH_SUFFIX),
+                     root_config,
+                     on_error=strategy.on_error,
+                 )
+                 if meta is None:
+                     continue
+                 return meta.ckpt_path
+             case BestCheckpointStrategyConfig():
+                 candidates = [
+                     *ckpt_candidates(),
+                     *_additional_candidates(
+                         strategy.additional_candidates, root_config
+                     ),
+                 ]
+                 if not candidates:
+                     log.warn(
+                         "No checkpoint candidates found for `best` checkpoint strategy."
+                     )
+                     continue
+
+                 if (metric := strategy.metric or root_config.primary_metric) is None:
+                     log.warn(
+                         "No metric specified for `best` checkpoint strategy, "
+                         "and no primary metric is set in the configuration. "
+                         "Skipping strategy."
+                     )
+                     continue
+
+                 # Find the best checkpoint based on the metric.
+                 def metric_value(ckpt: _CkptCandidate):
+                     assert metric is not None
+                     if (
+                         value := ckpt.meta.metrics.get(metric.validation_monitor)
+                     ) is None:
+                         raise ValueError(
+                             f"Metric {metric.validation_monitor} not found in checkpoint metadata. "
+                             f"Available metrics: {ckpt.meta.metrics.keys()}"
+                         )
+                     return value
+
+                 best_candidate = metric.best(candidates, key=metric_value)
+                 return best_candidate.ckpt_path
+             case LastCheckpointStrategyConfig():
+                 candidates = [
+                     *ckpt_candidates(),
+                     *_additional_candidates(
+                         strategy.additional_candidates, root_config
+                     ),
+                 ]
+                 if not candidates:
+                     log.warn(
+                         "No checkpoint candidates found for `last` checkpoint strategy."
+                     )
+                     continue
+
+                 # Find the last checkpoint based on the criterion.
+                 def criterion_value(ckpt: _CkptCandidate):
+                     match strategy.criterion:
+                         case "global_step":
+                             return ckpt.meta.global_step
+                         case "runtime":
+                             return ckpt.meta.training_time.total_seconds()
+                         case _:
+                             assert_never(strategy.criterion)
+
+                 last_candidate = max(candidates, key=criterion_value)
+                 return last_candidate.ckpt_path
+             case _:
+                 assert_never(strategy)
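To make the strategy machinery above concrete, here is a small sketch of what `CheckpointLoadingConfig.auto` produces for common inputs (the checkpoint path is hypothetical; the behaviour follows directly from `_auto_train` above):

from pathlib import Path

from lightning.pytorch.trainer.states import TrainerFn

from nshtrainer._checkpoint.loader import CheckpointLoadingConfig

# During training, "last" (or None) resolves to a single last-checkpoint strategy
# that also considers HPC pre-emption checkpoints.
config = CheckpointLoadingConfig.auto("last", TrainerFn.FITTING)
assert [s.name for s in config.strategies] == ["last"]
assert config.include_hpc

# An explicit path is tried both as a "last" candidate and as a user-provided path.
config = CheckpointLoadingConfig.auto(Path("ckpt/epoch=3.ckpt"), TrainerFn.FITTING)
assert [s.name for s in config.strategies] == ["last", "user_provided_path"]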
nshtrainer/_checkpoint/metadata.py ADDED
@@ -0,0 +1,102 @@
+ import copy
+ import datetime
+ import logging
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, cast
+
+ import nshconfig as C
+ import numpy as np
+ import torch
+
+ from ..model._environment import EnvironmentConfig
+
+ if TYPE_CHECKING:
+     from ..model import BaseConfig, LightningModuleBase
+     from ..trainer.trainer import Trainer
+
+ log = logging.getLogger(__name__)
+
+
+ METADATA_PATH_SUFFIX = ".metadata.json"
+ HPARAMS_PATH_SUFFIX = ".hparams.json"
+
+
+ class CheckpointMetadata(C.Config):
+     checkpoint_path: Path
+     checkpoint_filename: str
+
+     run_id: str
+     name: str
+     project: str | None
+     checkpoint_timestamp: datetime.datetime
+     start_timestamp: datetime.datetime | None
+
+     epoch: int
+     global_step: int
+     training_time: datetime.timedelta
+     metrics: dict[str, Any]
+     environment: EnvironmentConfig
+
+     @classmethod
+     def from_file(cls, path: Path):
+         return cls.model_validate_json(path.read_text())
+
+
+ def _generate_checkpoint_metadata(
+     config: "BaseConfig", trainer: "Trainer", checkpoint_path: Path
+ ):
+     checkpoint_timestamp = datetime.datetime.now()
+     start_timestamp = trainer.start_time()
+     training_time = trainer.time_elapsed()
+
+     metrics: dict[str, Any] = {}
+     for name, metric in copy.deepcopy(trainer.callback_metrics).items():
+         match metric:
+             case torch.Tensor() | np.ndarray():
+                 metrics[name] = metric.detach().cpu().item()
+             case _:
+                 metrics[name] = metric
+
+     return CheckpointMetadata(
+         checkpoint_path=checkpoint_path,
+         checkpoint_filename=checkpoint_path.name,
+         run_id=config.id,
+         name=config.run_name,
+         project=config.project,
+         checkpoint_timestamp=checkpoint_timestamp,
+         start_timestamp=start_timestamp.datetime
+         if start_timestamp is not None
+         else None,
+         epoch=trainer.current_epoch,
+         global_step=trainer.global_step,
+         training_time=training_time,
+         metrics=metrics,
+         environment=config.environment,
+     )
+
+
+ def _write_checkpoint_metadata(
+     trainer: "Trainer",
+     model: "LightningModuleBase",
+     checkpoint_path: Path,
+ ):
+     config = cast("BaseConfig", model.config)
+     metadata = _generate_checkpoint_metadata(config, trainer, checkpoint_path)
+
+     # Write the metadata to the checkpoint directory
+     try:
+         metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+         metadata_path.write_text(metadata.model_dump_json(indent=4))
+     except Exception as e:
+         log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
+     else:
+         log.info(f"Checkpoint metadata written to {checkpoint_path}")
+
+     # Write the hparams to the checkpoint directory
+     try:
+         hparams_path = checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX)
+         hparams_path.write_text(config.model_dump_json(indent=4))
+     except Exception as e:
+         log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
+     else:
+         log.info(f"Checkpoint metadata written to {checkpoint_path}")
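Each checkpoint therefore gets JSON sidecars next to it, named by swapping its suffix. A rough sketch of reading one back (the file name is hypothetical; only the suffix constant and `from_file` come from the module above):

from pathlib import Path

from nshtrainer._checkpoint.metadata import METADATA_PATH_SUFFIX, CheckpointMetadata

ckpt_path = Path("checkpoints/epoch=3-step=1000.ckpt")
meta_path = ckpt_path.with_suffix(METADATA_PATH_SUFFIX)  # .../epoch=3-step=1000.metadata.json

if meta_path.exists():
    meta = CheckpointMetadata.from_file(meta_path)
    print(meta.run_id, meta.epoch, meta.global_step, meta.metrics)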
nshtrainer/callbacks/__init__.py CHANGED
@@ -14,15 +14,27 @@ from .interval import EpochIntervalCallback as EpochIntervalCallback
  from .interval import IntervalCallback as IntervalCallback
  from .interval import StepIntervalCallback as StepIntervalCallback
  from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
+ from .latest_epoch_checkpoint import (
+     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
+ )
  from .log_epoch import LogEpochCallback as LogEpochCallback
+ from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
+ from .model_checkpoint import (
+     ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
+ )
  from .norm_logging import NormLoggingCallback as NormLoggingCallback
  from .norm_logging import NormLoggingConfig as NormLoggingConfig
  from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+ from .on_exception_checkpoint import (
+     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+ )
  from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
  from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
  from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
  from .timer import EpochTimer as EpochTimer
  from .timer import EpochTimerConfig as EpochTimerConfig
+ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
+ from .wandb_watch import WandbWatchConfig as WandbWatchConfig

  CallbackConfig = Annotated[
      ThroughputMonitorConfig
@@ -31,6 +43,10 @@ CallbackConfig = Annotated[
      | FiniteChecksConfig
      | NormLoggingConfig
      | GradientSkippingConfig
-     | EMAConfig,
+     | EMAConfig
+     | ModelCheckpointCallbackConfig
+     | LatestEpochCheckpointCallbackConfig
+     | OnExceptionCheckpointCallbackConfig
+     | WandbWatchConfig,
      C.Field(discriminator="name"),
  ]
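The new config classes simply join the existing `CallbackConfig` discriminated union, which dispatches on each config's `name` literal. For readers unfamiliar with that pattern, a standalone pydantic sketch of how such a union selects a model (the class and field names below are illustrative, not nshtrainer's):

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter

class FooConfig(BaseModel):
    name: Literal["foo"] = "foo"
    decay: float = 0.999

class BarConfig(BaseModel):
    name: Literal["bar"] = "bar"
    log_freq: int = 100

ExampleConfig = Annotated[Union[FooConfig, BarConfig], Field(discriminator="name")]

# The "name" field decides which model a raw dict is validated against.
cfg = TypeAdapter(ExampleConfig).validate_python({"name": "bar", "log_freq": 10})
assert isinstance(cfg, BarConfig)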
nshtrainer/{actsave/_callback.py → callbacks/actsave.py} RENAMED
@@ -1,28 +1,87 @@
  import contextlib
- from typing import TYPE_CHECKING, Literal, cast
+ from pathlib import Path
+ from typing import Literal

  from lightning.pytorch import LightningModule, Trainer
  from lightning.pytorch.callbacks.callback import Callback
- from nshutils.actsave import ActSave
  from typing_extensions import TypeAlias, override

- if TYPE_CHECKING:
-     from ..model.config import BaseConfig
+ from .base import CallbackConfigBase
+
+ try:
+     from nshutils import ActSave  # type: ignore
+ except ImportError:
+     ActSave = None

  Stage: TypeAlias = Literal["train", "validation", "test", "predict"]


+ class ActSaveConfig(CallbackConfigBase):
+     enabled: bool = True
+     """Enable activation saving."""
+
+     save_dir: Path | None = None
+     """Directory to save activations to. If None, will use the activation directory set in `config.directory`."""
+
+     def __bool__(self):
+         return self.enabled
+
+     @override
+     def create_callbacks(self, root_config):
+         yield ActSaveCallback(
+             self,
+             self.save_dir
+             or root_config.directory.resolve_subdirectory(root_config.id, "activation"),
+         )
+
+
  class ActSaveCallback(Callback):
-     def __init__(self):
+     def __init__(self, config: ActSaveConfig, save_dir: Path):
          super().__init__()

+         self.config = config
+         self.save_dir = save_dir
+         self._enabled_context: contextlib._GeneratorContextManager | None = None
          self._active_contexts: dict[Stage, contextlib._GeneratorContextManager] = {}

+     @override
+     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+         super().setup(trainer, pl_module, stage)
+
+         if not self.config:
+             return
+
+         if ActSave is None:
+             raise ImportError(
+                 "ActSave is not installed. Please install nshutils to use the ActSaveCallback."
+             )
+
+         context = ActSave.enabled(self.save_dir)
+         context.__enter__()
+         self._enabled_context = context
+
+     @override
+     def teardown(
+         self, trainer: Trainer, pl_module: LightningModule, stage: str
+     ) -> None:
+         super().teardown(trainer, pl_module, stage)
+
+         if not self.config:
+             return
+
+         if self._enabled_context is not None:
+             self._enabled_context.__exit__(None, None, None)
+             self._enabled_context = None
+
      def _on_start(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
-         hparams = cast("BaseConfig", pl_module.hparams)
-         if not hparams.trainer.actsave:
+         if not self.config:
              return

+         if ActSave is None:
+             raise ImportError(
+                 "ActSave is not installed. Please install nshutils to use the ActSaveCallback."
+             )
+
          # If we have an active context manager for this stage, exit it
          if active_contexts := self._active_contexts.get(stage):
              active_contexts.__exit__(None, None, None)
@@ -33,12 +92,11 @@ class ActSaveCallback(Callback):
          self._active_contexts[stage] = context

      def _on_end(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
-         hparams = cast("BaseConfig", pl_module.hparams)
-         if not hparams.trainer.actsave:
+         if not self.config:
              return

          # If we have an active context manager for this stage, exit it
-         if active_contexts := self._active_contexts.get(stage):
+         if (active_contexts := self._active_contexts.pop(stage, None)) is not None:
              active_contexts.__exit__(None, None, None)

      @override
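Activation saving is now driven by an `ActSaveConfig` callback config rather than a flag on the trainer config. A minimal sketch of the new config (passing `None` for the root config is only viable here because `save_dir` is set explicitly, so the run's activation directory is never consulted):

from pathlib import Path

from nshtrainer.callbacks.actsave import ActSaveConfig

config = ActSaveConfig(save_dir=Path("./activations"))
assert bool(config)  # `enabled=True` by default; a falsy config makes the callback a no-op

# With an explicit `save_dir`, `create_callbacks` never touches the root config.
(callback,) = config.create_callbacks(root_config=None)  # type: ignore[arg-type]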
nshtrainer/callbacks/base.py CHANGED
@@ -46,16 +46,16 @@ class CallbackConfigBase(C.Config, ABC):
      )

      @abstractmethod
-     def construct_callbacks(
+     def create_callbacks(
          self, root_config: "BaseConfig"
      ) -> Iterable[Callback | CallbackWithMetadata]: ...


  # region Config resolution helpers
- def _construct_callbacks_with_metadata(
+ def _create_callbacks_with_metadata(
      config: CallbackConfigBase, root_config: "BaseConfig"
  ) -> Iterable[CallbackWithMetadata]:
-     for callback in config.construct_callbacks(root_config):
+     for callback in config.create_callbacks(root_config):
          if isinstance(callback, CallbackWithMetadata):
              yield callback
              continue
@@ -99,12 +99,14 @@ def _process_and_filter_callbacks(


  def resolve_all_callbacks(root_config: "BaseConfig"):
      callback_configs = [
-         config for config in root_config.ll_all_callback_configs() if config is not None
+         config
+         for config in root_config._nshtrainer_all_callback_configs()
+         if config is not None
      ]
      callbacks = _process_and_filter_callbacks(
          callback
          for callback_config in callback_configs
-         for callback in _construct_callbacks_with_metadata(callback_config, root_config)
+         for callback in _create_callbacks_with_metadata(callback_config, root_config)
      )
      return callbacks
nshtrainer/callbacks/ema.py CHANGED
@@ -374,7 +374,7 @@ class EMAConfig(CallbackConfigBase):
      """Offload weights to CPU."""

      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          yield EMA(
              decay=self.decay,
              validate_original_weights=self.validate_original_weights,
nshtrainer/callbacks/finite_checks.py CHANGED
@@ -68,7 +68,7 @@ class FiniteChecksConfig(CallbackConfigBase):
      """Whether to check for None gradients"""

      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          yield FiniteChecksCallback(
              nonfinite_grads=self.nonfinite_grads,
              none_grads=self.none_grads,
nshtrainer/callbacks/gradient_skipping.py CHANGED
@@ -99,5 +99,5 @@ class GradientSkippingConfig(CallbackConfigBase):
      """

      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          yield GradientSkipping(self)
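The `construct_callbacks` → `create_callbacks` rename in the hunks above applies to any custom callback config as well. A minimal sketch of a subclass under the new name (the callback itself and the `name` literal are made up for illustration):

from typing import Literal

from lightning.pytorch.callbacks.callback import Callback

from nshtrainer.callbacks.base import CallbackConfigBase

class HelloCallback(Callback):
    def on_fit_start(self, trainer, pl_module):
        print("Starting fit")

class HelloCallbackConfig(CallbackConfigBase):
    name: Literal["hello"] = "hello"

    def create_callbacks(self, root_config):  # was `construct_callbacks` in 0.9.x
        yield HelloCallback()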