nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. nshtrainer/__init__.py +6 -3
  2. nshtrainer/_callback.py +297 -2
  3. nshtrainer/_checkpoint/loader.py +23 -30
  4. nshtrainer/_checkpoint/metadata.py +22 -18
  5. nshtrainer/_experimental/__init__.py +0 -2
  6. nshtrainer/_hf_hub.py +25 -26
  7. nshtrainer/callbacks/__init__.py +1 -3
  8. nshtrainer/callbacks/actsave.py +22 -20
  9. nshtrainer/callbacks/base.py +7 -7
  10. nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  11. nshtrainer/callbacks/checkpoint/_base.py +8 -5
  12. nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  13. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  14. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  15. nshtrainer/callbacks/debug_flag.py +14 -19
  16. nshtrainer/callbacks/directory_setup.py +6 -11
  17. nshtrainer/callbacks/early_stopping.py +3 -3
  18. nshtrainer/callbacks/ema.py +1 -1
  19. nshtrainer/callbacks/finite_checks.py +1 -1
  20. nshtrainer/callbacks/gradient_skipping.py +1 -1
  21. nshtrainer/callbacks/log_epoch.py +1 -1
  22. nshtrainer/callbacks/norm_logging.py +1 -1
  23. nshtrainer/callbacks/print_table.py +1 -1
  24. nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  25. nshtrainer/callbacks/shared_parameters.py +1 -1
  26. nshtrainer/callbacks/timer.py +1 -1
  27. nshtrainer/callbacks/wandb_upload_code.py +1 -1
  28. nshtrainer/callbacks/wandb_watch.py +1 -1
  29. nshtrainer/config/__init__.py +189 -189
  30. nshtrainer/config/_checkpoint/__init__.py +70 -0
  31. nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  32. nshtrainer/config/_directory/__init__.py +2 -2
  33. nshtrainer/config/_hf_hub/__init__.py +2 -2
  34. nshtrainer/config/callbacks/__init__.py +44 -44
  35. nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  36. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  37. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  38. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  39. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  40. nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  41. nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  42. nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  43. nshtrainer/config/callbacks/ema/__init__.py +2 -2
  44. nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  45. nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  46. nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
  47. nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  48. nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  49. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  50. nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  51. nshtrainer/config/callbacks/timer/__init__.py +4 -4
  52. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  53. nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  54. nshtrainer/config/loggers/__init__.py +10 -6
  55. nshtrainer/config/loggers/actsave/__init__.py +29 -0
  56. nshtrainer/config/loggers/csv/__init__.py +2 -2
  57. nshtrainer/config/loggers/wandb/__init__.py +6 -6
  58. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  59. nshtrainer/config/nn/__init__.py +18 -18
  60. nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  61. nshtrainer/config/optimizer/__init__.py +2 -2
  62. nshtrainer/config/profiler/__init__.py +2 -2
  63. nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  64. nshtrainer/config/profiler/simple/__init__.py +4 -4
  65. nshtrainer/config/trainer/__init__.py +180 -0
  66. nshtrainer/config/trainer/_config/__init__.py +59 -36
  67. nshtrainer/config/trainer/trainer/__init__.py +27 -0
  68. nshtrainer/config/util/__init__.py +109 -0
  69. nshtrainer/config/util/_environment_info/__init__.py +20 -20
  70. nshtrainer/config/util/config/__init__.py +2 -2
  71. nshtrainer/data/datamodule.py +52 -2
  72. nshtrainer/loggers/__init__.py +2 -1
  73. nshtrainer/loggers/_base.py +5 -2
  74. nshtrainer/loggers/actsave.py +59 -0
  75. nshtrainer/loggers/csv.py +5 -5
  76. nshtrainer/loggers/tensorboard.py +5 -5
  77. nshtrainer/loggers/wandb.py +17 -16
  78. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  79. nshtrainer/model/__init__.py +0 -4
  80. nshtrainer/model/base.py +64 -347
  81. nshtrainer/model/mixins/callback.py +24 -5
  82. nshtrainer/model/mixins/debug.py +86 -0
  83. nshtrainer/model/mixins/logger.py +142 -145
  84. nshtrainer/profiler/_base.py +2 -2
  85. nshtrainer/profiler/advanced.py +4 -4
  86. nshtrainer/profiler/pytorch.py +4 -4
  87. nshtrainer/profiler/simple.py +4 -4
  88. nshtrainer/trainer/__init__.py +1 -0
  89. nshtrainer/trainer/_config.py +164 -17
  90. nshtrainer/trainer/checkpoint_connector.py +23 -8
  91. nshtrainer/trainer/trainer.py +194 -76
  92. nshtrainer/util/_environment_info.py +21 -13
  93. nshtrainer/util/config/dtype.py +4 -4
  94. nshtrainer/util/typing_utils.py +1 -1
  95. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
  96. nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
  97. nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  98. nshtrainer/callbacks/throughput_monitor.py +0 -58
  99. nshtrainer/config/model/__init__.py +0 -41
  100. nshtrainer/config/model/base/__init__.py +0 -25
  101. nshtrainer/config/model/config/__init__.py +0 -37
  102. nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  103. nshtrainer/config/runner/__init__.py +0 -22
  104. nshtrainer/ll/__init__.py +0 -59
  105. nshtrainer/ll/_experimental.py +0 -3
  106. nshtrainer/ll/actsave.py +0 -6
  107. nshtrainer/ll/callbacks.py +0 -3
  108. nshtrainer/ll/config.py +0 -6
  109. nshtrainer/ll/data.py +0 -3
  110. nshtrainer/ll/log.py +0 -5
  111. nshtrainer/ll/lr_scheduler.py +0 -3
  112. nshtrainer/ll/model.py +0 -21
  113. nshtrainer/ll/nn.py +0 -3
  114. nshtrainer/ll/optimizer.py +0 -3
  115. nshtrainer/ll/runner.py +0 -5
  116. nshtrainer/ll/snapshot.py +0 -3
  117. nshtrainer/ll/snoop.py +0 -3
  118. nshtrainer/ll/trainer.py +0 -3
  119. nshtrainer/ll/typecheck.py +0 -3
  120. nshtrainer/ll/util.py +0 -3
  121. nshtrainer/model/config.py +0 -218
  122. nshtrainer/runner.py +0 -101
  123. nshtrainer-0.44.1.dist-info/RECORD +0 -162
  124. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
nshtrainer/_hf_hub.py CHANGED
@@ -7,19 +7,19 @@ import re
7
7
  from dataclasses import dataclass
8
8
  from functools import cached_property
9
9
  from pathlib import Path
10
- from typing import TYPE_CHECKING, Any, Literal, cast
10
+ from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
11
11
 
12
12
  import nshconfig as C
13
13
  from nshrunner._env import SNAPSHOT_DIR
14
14
  from typing_extensions import assert_never, override
15
15
 
16
16
  from ._callback import NTCallbackBase
17
- from .callbacks.base import CallbackConfigBase
17
+ from .callbacks.base import CallbackConfigBase, CallbackMetadataConfig
18
18
 
19
19
  if TYPE_CHECKING:
20
20
  from huggingface_hub import HfApi # noqa: F401
21
21
 
22
- from .model.base import BaseConfig
22
+ from .trainer._config import TrainerConfig
23
23
 
24
24
 
25
25
  log = logging.getLogger(__name__)
@@ -42,6 +42,8 @@ class HuggingFaceHubAutoCreateConfig(C.Config):
42
42
  class HuggingFaceHubConfig(CallbackConfigBase):
43
43
  """Configuration options for Hugging Face Hub integration."""
44
44
 
45
+ metadata: ClassVar[CallbackMetadataConfig] = {"ignore_if_exists": True}
46
+
45
47
  enabled: bool = False
46
48
  """Enable Hugging Face Hub integration."""
47
49
 
@@ -82,7 +84,7 @@ class HuggingFaceHubConfig(CallbackConfigBase):
82
84
  return self.enabled
83
85
 
84
86
  @override
85
- def create_callbacks(self, root_config):
87
+ def create_callbacks(self, trainer_config):
86
88
  # Attempt to login. If it fails, we'll log a warning or error based on the configuration.
87
89
  try:
88
90
  api = _api(self.token)
@@ -107,7 +109,7 @@ class HuggingFaceHubConfig(CallbackConfigBase):
107
109
  case _:
108
110
  assert_never(self.on_login_error)
109
111
 
110
- yield self.with_metadata(HFHubCallback(self), ignore_if_exists=True)
112
+ yield HFHubCallback(self)
111
113
 
112
114
 
113
115
  def _api(token: str | None = None):
@@ -138,19 +140,20 @@ def _api(token: str | None = None):
138
140
  return api
139
141
 
140
142
 
141
- def _repo_name(api: "HfApi", root_config: "BaseConfig"):
143
+ def _repo_name(api: HfApi, trainer_config: TrainerConfig):
142
144
  username = None
143
- if (ac := root_config.trainer.hf_hub.auto_create) and ac.namespace:
145
+ if (ac := trainer_config.hf_hub.auto_create) and ac.namespace:
144
146
  username = ac.namespace
145
147
  elif (username := api.whoami().get("name", None)) is None:
146
148
  raise ValueError("Could not get username from Hugging Face Hub.")
147
149
 
148
150
  # Sanitize the project (if it exists), run_name, and id
149
151
  parts = []
150
- if root_config.project:
151
- parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.project))
152
- parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.run_name))
153
- parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.id))
152
+ if trainer_config.project:
153
+ parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", trainer_config.project))
154
+ if trainer_config.full_name:
155
+ parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", trainer_config.full_name))
156
+ parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", trainer_config.id))
154
157
 
155
158
  # Combine parts and ensure it starts and ends with alphanumeric characters
156
159
  repo_name = "-".join(parts)
@@ -179,14 +182,10 @@ class _Upload:
179
182
  path_in_repo: Path
180
183
 
181
184
  @classmethod
182
- def from_local_path(
183
- cls,
184
- local_path: Path,
185
- root_config: "BaseConfig",
186
- ):
185
+ def from_local_path(cls, local_path: Path, trainer_config: TrainerConfig):
187
186
  # Resolve the checkpoint directory
188
- checkpoint_dir = root_config.directory.resolve_subdirectory(
189
- root_config.id, "checkpoint"
187
+ checkpoint_dir = trainer_config.directory.resolve_subdirectory(
188
+ trainer_config.id, "checkpoint"
190
189
  )
191
190
 
192
191
  try:
@@ -224,8 +223,7 @@ class HFHubCallback(NTCallbackBase):
224
223
 
225
224
  @override
226
225
  def setup(self, trainer, pl_module, stage):
227
- root_config = cast("BaseConfig", pl_module.hparams)
228
- self._repo_id = _repo_name(self.api, root_config)
226
+ self._repo_id = _repo_name(self.api, trainer.hparams)
229
227
 
230
228
  if not self.config or not trainer.is_global_zero:
231
229
  return
@@ -234,7 +232,7 @@ class HFHubCallback(NTCallbackBase):
234
232
  self._create_repo_if_not_exists()
235
233
 
236
234
  # Upload the config and code
237
- self._save_config(root_config)
235
+ self._save_config(trainer.hparams)
238
236
  self._save_code()
239
237
 
240
238
  @override
@@ -248,10 +246,9 @@ class HFHubCallback(NTCallbackBase):
248
246
  return
249
247
 
250
248
  with self._with_error_handling("save checkpoints"):
251
- root_config = cast("BaseConfig", pl_module.hparams)
252
249
  self._save_checkpoint(
253
- _Upload.from_local_path(ckpt_path, root_config),
254
- _Upload.from_local_path(metadata_path, root_config)
250
+ _Upload.from_local_path(ckpt_path, trainer.hparams),
251
+ _Upload.from_local_path(metadata_path, trainer.hparams)
255
252
  if metadata_path is not None
256
253
  else None,
257
254
  )
@@ -300,10 +297,12 @@ class HFHubCallback(NTCallbackBase):
300
297
  f"Error checking repository '{self.repo_id}'", exc_info=True
301
298
  )
302
299
 
303
- def _save_config(self, root_config: "BaseConfig"):
300
+ def _save_config(self, trainer_config: TrainerConfig):
304
301
  with self._with_error_handling("upload config"):
305
302
  self.api.upload_file(
306
- path_or_fileobj=root_config.model_dump_json(indent=4).encode("utf-8"),
303
+ path_or_fileobj=trainer_config.model_dump_json(indent=4).encode(
304
+ "utf-8"
305
+ ),
307
306
  path_in_repo="config.json",
308
307
  repo_id=self.repo_id,
309
308
  repo_type="model",
@@ -6,7 +6,7 @@ import nshconfig as C
6
6
 
7
7
  from . import checkpoint as checkpoint
8
8
  from .base import CallbackConfigBase as CallbackConfigBase
9
- from .checkpoint import BestCheckpoint as BestCheckpoint
9
+ from .checkpoint import BestCheckpointCallback as BestCheckpointCallback
10
10
  from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
11
11
  from .checkpoint import LastCheckpointCallback as LastCheckpointCallback
12
12
  from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
@@ -49,7 +49,6 @@ from .shared_parameters import SharedParametersCallback as SharedParametersCallb
49
49
  from .shared_parameters import (
50
50
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
51
51
  )
52
- from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
53
52
  from .timer import EpochTimerCallback as EpochTimerCallback
54
53
  from .timer import EpochTimerCallbackConfig as EpochTimerCallbackConfig
55
54
  from .wandb_upload_code import WandbUploadCodeCallback as WandbUploadCodeCallback
@@ -62,7 +61,6 @@ from .wandb_watch import WandbWatchCallbackConfig as WandbWatchCallbackConfig
62
61
  CallbackConfig = Annotated[
63
62
  DebugFlagCallbackConfig
64
63
  | EarlyStoppingCallbackConfig
65
- | ThroughputMonitorConfig
66
64
  | EpochTimerCallbackConfig
67
65
  | PrintTableMetricsCallbackConfig
68
66
  | FiniteChecksCallbackConfig
@@ -4,11 +4,9 @@ import contextlib
4
4
  from pathlib import Path
5
5
  from typing import Literal
6
6
 
7
- from lightning.pytorch import LightningModule, Trainer
8
- from lightning.pytorch.callbacks.callback import Callback
9
- from nshutils import ActSave
10
7
  from typing_extensions import TypeAlias, override
11
8
 
9
+ from .._callback import NTCallbackBase
12
10
  from .base import CallbackConfigBase
13
11
 
14
12
  Stage: TypeAlias = Literal["train", "validation", "test", "predict"]
@@ -25,15 +23,17 @@ class ActSaveConfig(CallbackConfigBase):
25
23
  return self.enabled
26
24
 
27
25
  @override
28
- def create_callbacks(self, root_config):
26
+ def create_callbacks(self, trainer_config):
29
27
  yield ActSaveCallback(
30
28
  self,
31
29
  self.save_dir
32
- or root_config.directory.resolve_subdirectory(root_config.id, "activation"),
30
+ or trainer_config.directory.resolve_subdirectory(
31
+ trainer_config.id, "activation"
32
+ ),
33
33
  )
34
34
 
35
35
 
36
- class ActSaveCallback(Callback):
36
+ class ActSaveCallback(NTCallbackBase):
37
37
  def __init__(self, config: ActSaveConfig, save_dir: Path):
38
38
  super().__init__()
39
39
 
@@ -43,20 +43,20 @@ class ActSaveCallback(Callback):
43
43
  self._active_contexts: dict[Stage, contextlib._GeneratorContextManager] = {}
44
44
 
45
45
  @override
46
- def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
46
+ def setup(self, trainer, pl_module, stage) -> None:
47
47
  super().setup(trainer, pl_module, stage)
48
48
 
49
49
  if not self.config:
50
50
  return
51
51
 
52
+ from nshutils import ActSave
53
+
52
54
  context = ActSave.enabled(self.save_dir)
53
55
  context.__enter__()
54
56
  self._enabled_context = context
55
57
 
56
58
  @override
57
- def teardown(
58
- self, trainer: Trainer, pl_module: LightningModule, stage: str
59
- ) -> None:
59
+ def teardown(self, trainer, pl_module, stage) -> None:
60
60
  super().teardown(trainer, pl_module, stage)
61
61
 
62
62
  if not self.config:
@@ -66,10 +66,12 @@ class ActSaveCallback(Callback):
66
66
  self._enabled_context.__exit__(None, None, None)
67
67
  self._enabled_context = None
68
68
 
69
- def _on_start(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
69
+ def _on_start(self, stage: Stage, trainer, pl_module):
70
70
  if not self.config:
71
71
  return
72
72
 
73
+ from nshutils import ActSave
74
+
73
75
  # If we have an active context manager for this stage, exit it
74
76
  if active_contexts := self._active_contexts.get(stage):
75
77
  active_contexts.__exit__(None, None, None)
@@ -79,7 +81,7 @@ class ActSaveCallback(Callback):
79
81
  context.__enter__()
80
82
  self._active_contexts[stage] = context
81
83
 
82
- def _on_end(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
84
+ def _on_end(self, stage: Stage, trainer, pl_module):
83
85
  if not self.config:
84
86
  return
85
87
 
@@ -88,33 +90,33 @@ class ActSaveCallback(Callback):
88
90
  active_contexts.__exit__(None, None, None)
89
91
 
90
92
  @override
91
- def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
93
+ def on_train_epoch_start(self, trainer, pl_module):
92
94
  return self._on_start("train", trainer, pl_module)
93
95
 
94
96
  @override
95
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
97
+ def on_train_epoch_end(self, trainer, pl_module):
96
98
  return self._on_end("train", trainer, pl_module)
97
99
 
98
100
  @override
99
- def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
101
+ def on_validation_epoch_start(self, trainer, pl_module):
100
102
  return self._on_start("validation", trainer, pl_module)
101
103
 
102
104
  @override
103
- def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
105
+ def on_validation_epoch_end(self, trainer, pl_module):
104
106
  return self._on_end("validation", trainer, pl_module)
105
107
 
106
108
  @override
107
- def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
109
+ def on_test_epoch_start(self, trainer, pl_module):
108
110
  return self._on_start("test", trainer, pl_module)
109
111
 
110
112
  @override
111
- def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
113
+ def on_test_epoch_end(self, trainer, pl_module):
112
114
  return self._on_end("test", trainer, pl_module)
113
115
 
114
116
  @override
115
- def on_predict_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
117
+ def on_predict_epoch_start(self, trainer, pl_module):
116
118
  return self._on_start("predict", trainer, pl_module)
117
119
 
118
120
  @override
119
- def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
121
+ def on_predict_epoch_end(self, trainer, pl_module):
120
122
  return self._on_end("predict", trainer, pl_module)
@@ -11,7 +11,7 @@ from lightning.pytorch import Callback
11
11
  from typing_extensions import TypedDict, Unpack
12
12
 
13
13
  if TYPE_CHECKING:
14
- from ..model.config import BaseConfig
14
+ from ..trainer._config import TrainerConfig
15
15
 
16
16
 
17
17
  class CallbackMetadataConfig(TypedDict, total=False):
@@ -49,15 +49,15 @@ class CallbackConfigBase(C.Config, ABC):
49
49
 
50
50
  @abstractmethod
51
51
  def create_callbacks(
52
- self, root_config: "BaseConfig"
52
+ self, trainer_config: TrainerConfig
53
53
  ) -> Iterable[Callback | CallbackWithMetadata]: ...
54
54
 
55
55
 
56
56
  # region Config resolution helpers
57
57
  def _create_callbacks_with_metadata(
58
- config: CallbackConfigBase, root_config: "BaseConfig"
58
+ config: CallbackConfigBase, trainer_config: TrainerConfig
59
59
  ) -> Iterable[CallbackWithMetadata]:
60
- for callback in config.create_callbacks(root_config):
60
+ for callback in config.create_callbacks(trainer_config):
61
61
  if isinstance(callback, CallbackWithMetadata):
62
62
  yield callback
63
63
  continue
@@ -102,16 +102,16 @@ def _process_and_filter_callbacks(
102
102
  return [callback.callback for callback in callbacks]
103
103
 
104
104
 
105
- def resolve_all_callbacks(root_config: "BaseConfig"):
105
+ def resolve_all_callbacks(trainer_config: TrainerConfig):
106
106
  callback_configs = [
107
107
  config
108
- for config in root_config._nshtrainer_all_callback_configs()
108
+ for config in trainer_config._nshtrainer_all_callback_configs()
109
109
  if config is not None
110
110
  ]
111
111
  callbacks = _process_and_filter_callbacks(
112
112
  callback
113
113
  for callback_config in callback_configs
114
- for callback in _create_callbacks_with_metadata(callback_config, root_config)
114
+ for callback in _create_callbacks_with_metadata(callback_config, trainer_config)
115
115
  )
116
116
  return callbacks
117
117
 
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from .best_checkpoint import BestCheckpoint as BestCheckpoint
3
+ from .best_checkpoint import BestCheckpointCallback as BestCheckpointCallback
4
4
  from .best_checkpoint import (
5
5
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
6
6
  )
@@ -16,7 +16,8 @@ from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
16
16
  from ..base import CallbackConfigBase
17
17
 
18
18
  if TYPE_CHECKING:
19
- from ...model.config import BaseConfig
19
+ from ...trainer._config import TrainerConfig
20
+
20
21
 
21
22
  log = logging.getLogger(__name__)
22
23
 
@@ -41,18 +42,20 @@ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
41
42
  @abstractmethod
42
43
  def create_checkpoint(
43
44
  self,
44
- root_config: "BaseConfig",
45
+ trainer_config: TrainerConfig,
45
46
  dirpath: Path,
46
47
  ) -> "CheckpointBase | None": ...
47
48
 
48
49
  @override
49
- def create_callbacks(self, root_config):
50
+ def create_callbacks(self, trainer_config):
50
51
  dirpath = Path(
51
52
  self.dirpath
52
- or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
53
+ or trainer_config.directory.resolve_subdirectory(
54
+ trainer_config.id, "checkpoint"
55
+ )
53
56
  )
54
57
 
55
- if (callback := self.create_checkpoint(root_config, dirpath)) is not None:
58
+ if (callback := self.create_checkpoint(trainer_config, dirpath)) is not None:
56
59
  yield callback
57
60
 
58
61
 
@@ -28,10 +28,10 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
28
28
  """
29
29
 
30
30
  @override
31
- def create_checkpoint(self, root_config, dirpath):
31
+ def create_checkpoint(self, trainer_config, dirpath):
32
32
  # Resolve metric
33
33
  if (metric := self.metric) is None and (
34
- metric := root_config.primary_metric
34
+ metric := trainer_config.primary_metric
35
35
  ) is None:
36
36
  error_msg = (
37
37
  "No metric provided and no primary metric found in the root config. "
@@ -43,11 +43,11 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
43
43
  log.warning(error_msg)
44
44
  return None
45
45
 
46
- return BestCheckpoint(self, dirpath, metric)
46
+ return BestCheckpointCallback(self, dirpath, metric)
47
47
 
48
48
 
49
49
  @final
50
- class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
50
+ class BestCheckpointCallback(CheckpointBase[BestCheckpointCallbackConfig]):
51
51
  @property
52
52
  def _metric_name_normalized(self):
53
53
  return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")
@@ -18,7 +18,7 @@ class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
18
18
  name: Literal["last_checkpoint"] = "last_checkpoint"
19
19
 
20
20
  @override
21
- def create_checkpoint(self, root_config, dirpath):
21
+ def create_checkpoint(self, trainer_config, dirpath):
22
22
  return LastCheckpointCallback(self, dirpath)
23
23
 
24
24
 
@@ -54,13 +54,13 @@ class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
54
54
  """Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""
55
55
 
56
56
  @override
57
- def create_callbacks(self, root_config):
58
- dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
59
- root_config.id, "checkpoint"
57
+ def create_callbacks(self, trainer_config):
58
+ dirpath = self.dirpath or trainer_config.directory.resolve_subdirectory(
59
+ trainer_config.id, "checkpoint"
60
60
  )
61
61
 
62
62
  if not (filename := self.filename):
63
- filename = f"on_exception_{root_config.id}"
63
+ filename = f"on_exception_{trainer_config.id}"
64
64
  yield OnExceptionCheckpointCallback(
65
65
  self, dirpath=Path(dirpath), filename=filename
66
66
  )
@@ -1,17 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from typing import TYPE_CHECKING, Literal, cast
4
+ from typing import Literal
5
5
 
6
- from lightning.pytorch import LightningModule, Trainer
7
- from lightning.pytorch.callbacks import Callback
8
6
  from typing_extensions import override
9
7
 
8
+ from .._callback import NTCallbackBase
10
9
  from .base import CallbackConfigBase
11
10
 
12
- if TYPE_CHECKING:
13
- from ..model.config import BaseConfig
14
-
15
11
  log = logging.getLogger(__name__)
16
12
 
17
13
 
@@ -25,14 +21,14 @@ class DebugFlagCallbackConfig(CallbackConfigBase):
25
21
  return self.enabled
26
22
 
27
23
  @override
28
- def create_callbacks(self, root_config):
24
+ def create_callbacks(self, trainer_config):
29
25
  if not self:
30
26
  return
31
27
 
32
28
  yield DebugFlagCallback(self)
33
29
 
34
30
 
35
- class DebugFlagCallback(Callback):
31
+ class DebugFlagCallback(NTCallbackBase):
36
32
  """
37
33
  Sets the debug flag to true in the following circumstances:
38
34
  - fast_dev_run is enabled
@@ -46,27 +42,26 @@ class DebugFlagCallback(Callback):
46
42
  self.config = config
47
43
  del config
48
44
 
45
+ self._debug = False
46
+
49
47
  @override
50
- def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
48
+ def setup(self, trainer, pl_module, stage):
51
49
  if not getattr(trainer, "fast_dev_run", False):
52
50
  return
53
51
 
54
- hparams = cast("BaseConfig", pl_module.hparams)
55
- if not hparams.debug:
52
+ if not trainer.debug:
56
53
  log.critical("Fast dev run detected, setting debug flag to True.")
57
- hparams.debug = True
54
+ trainer.debug = True
58
55
 
59
56
  @override
60
- def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
61
- hparams = cast("BaseConfig", pl_module.hparams)
62
- self._debug = hparams.debug
57
+ def on_sanity_check_start(self, trainer, pl_module):
58
+ self._debug = trainer.debug
63
59
  if not self._debug:
64
60
  log.critical("Enabling debug flag during sanity check routine.")
65
- hparams.debug = True
61
+ trainer.debug = True
66
62
 
67
63
  @override
68
- def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
69
- hparams = cast("BaseConfig", pl_module.hparams)
64
+ def on_sanity_check_end(self, trainer, pl_module):
70
65
  if not self._debug:
71
66
  log.critical("Sanity check routine complete, disabling debug flag.")
72
- hparams.debug = self._debug
67
+ trainer.debug = self._debug
@@ -5,9 +5,9 @@ import os
5
5
  from pathlib import Path
6
6
  from typing import Literal
7
7
 
8
- from lightning.pytorch import Callback
9
8
  from typing_extensions import override
10
9
 
10
+ from .._callback import NTCallbackBase
11
11
  from .base import CallbackConfigBase
12
12
 
13
13
  log = logging.getLogger(__name__)
@@ -55,14 +55,14 @@ class DirectorySetupCallbackConfig(CallbackConfigBase):
55
55
  def __bool__(self):
56
56
  return self.enabled
57
57
 
58
- def create_callbacks(self, root_config):
58
+ def create_callbacks(self, trainer_config):
59
59
  if not self:
60
60
  return
61
61
 
62
62
  yield DirectorySetupCallback(self)
63
63
 
64
64
 
65
- class DirectorySetupCallback(Callback):
65
+ class DirectorySetupCallback(NTCallbackBase):
66
66
  @override
67
67
  def __init__(self, config: DirectorySetupCallbackConfig):
68
68
  super().__init__()
@@ -76,12 +76,7 @@ class DirectorySetupCallback(Callback):
76
76
 
77
77
  # Create a symlink to the root folder for the Runner
78
78
  if self.config.create_symlink_to_nshrunner_root:
79
- # Resolve the base dir
80
- from ..model.config import BaseConfig
81
-
82
- assert isinstance(
83
- config := pl_module.hparams, BaseConfig
84
- ), f"Expected a BaseConfig, got {type(config)}"
85
-
86
- base_dir = config.directory.resolve_run_root_directory(config.id)
79
+ base_dir = trainer.hparams.directory.resolve_run_root_directory(
80
+ trainer.hparams.id
81
+ )
87
82
  _create_symlink_to_nshrunner(base_dir)
@@ -48,12 +48,12 @@ class EarlyStoppingCallbackConfig(CallbackConfigBase):
48
48
  """
49
49
 
50
50
  @override
51
- def create_callbacks(self, root_config):
51
+ def create_callbacks(self, trainer_config):
52
52
  if (metric := self.metric) is None and (
53
- metric := root_config.primary_metric
53
+ metric := trainer_config.primary_metric
54
54
  ) is None:
55
55
  raise ValueError(
56
- "Either `metric` or `root_config.primary_metric` must be set to use EarlyStopping."
56
+ "Either `metric` or `trainer_config.primary_metric` must be set to use EarlyStopping."
57
57
  )
58
58
 
59
59
  yield EarlyStoppingCallback(self, metric)
@@ -376,7 +376,7 @@ class EMACallbackConfig(CallbackConfigBase):
376
376
  """Offload weights to CPU."""
377
377
 
378
378
  @override
379
- def create_callbacks(self, root_config):
379
+ def create_callbacks(self, trainer_config):
380
380
  yield EMACallback(
381
381
  decay=self.decay,
382
382
  validate_original_weights=self.validate_original_weights,
@@ -70,7 +70,7 @@ class FiniteChecksCallbackConfig(CallbackConfigBase):
70
70
  """Whether to check for None gradients"""
71
71
 
72
72
  @override
73
- def create_callbacks(self, root_config):
73
+ def create_callbacks(self, trainer_config):
74
74
  yield FiniteChecksCallback(
75
75
  nonfinite_grads=self.nonfinite_grads,
76
76
  none_grads=self.none_grads,
@@ -95,5 +95,5 @@ class GradientSkippingCallbackConfig(CallbackConfigBase):
95
95
  """
96
96
 
97
97
  @override
98
- def create_callbacks(self, root_config):
98
+ def create_callbacks(self, trainer_config):
99
99
  yield GradientSkippingCallback(self)
@@ -17,7 +17,7 @@ class LogEpochCallbackConfig(CallbackConfigBase):
17
17
  name: Literal["log_epoch"] = "log_epoch"
18
18
 
19
19
  @override
20
- def create_callbacks(self, root_config):
20
+ def create_callbacks(self, trainer_config):
21
21
  yield LogEpochCallback()
22
22
 
23
23
 
@@ -182,7 +182,7 @@ class NormLoggingCallbackConfig(CallbackConfigBase):
182
182
  )
183
183
 
184
184
  @override
185
- def create_callbacks(self, root_config):
185
+ def create_callbacks(self, trainer_config):
186
186
  if not self:
187
187
  return
188
188
 
@@ -88,5 +88,5 @@ class PrintTableMetricsCallbackConfig(CallbackConfigBase):
88
88
  """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
89
89
 
90
90
  @override
91
- def create_callbacks(self, root_config):
91
+ def create_callbacks(self, trainer_config):
92
92
  yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
@@ -36,7 +36,7 @@ class RLPSanityChecksCallbackConfig(CallbackConfigBase):
36
36
  def __bool__(self):
37
37
  return self.enabled
38
38
 
39
- def create_callbacks(self, root_config):
39
+ def create_callbacks(self, trainer_config):
40
40
  if not self:
41
41
  return
42
42
 
@@ -30,7 +30,7 @@ class SharedParametersCallbackConfig(CallbackConfigBase):
30
30
  name: Literal["shared_parameters"] = "shared_parameters"
31
31
 
32
32
  @override
33
- def create_callbacks(self, root_config):
33
+ def create_callbacks(self, trainer_config):
34
34
  yield SharedParametersCallback(self)
35
35
 
36
36
 
@@ -155,5 +155,5 @@ class EpochTimerCallbackConfig(CallbackConfigBase):
155
155
  name: Literal["epoch_timer"] = "epoch_timer"
156
156
 
157
157
  @override
158
- def create_callbacks(self, root_config):
158
+ def create_callbacks(self, trainer_config):
159
159
  yield EpochTimerCallback()