nshtrainer 0.44.0__py3-none-any.whl → 1.0.0b9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +6 -3
- nshtrainer/_callback.py +297 -2
- nshtrainer/_checkpoint/loader.py +23 -30
- nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer/_hf_hub.py +25 -26
- nshtrainer/callbacks/__init__.py +1 -3
- nshtrainer/callbacks/actsave.py +22 -20
- nshtrainer/callbacks/base.py +7 -7
- nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +8 -5
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- nshtrainer/callbacks/debug_flag.py +14 -19
- nshtrainer/callbacks/directory_setup.py +6 -11
- nshtrainer/callbacks/early_stopping.py +3 -3
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/log_epoch.py +1 -1
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- nshtrainer/callbacks/shared_parameters.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_upload_code.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/config/__init__.py +189 -189
- nshtrainer/config/_checkpoint/__init__.py +70 -0
- nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- nshtrainer/config/_directory/__init__.py +2 -2
- nshtrainer/config/_hf_hub/__init__.py +2 -2
- nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- nshtrainer/config/callbacks/ema/__init__.py +2 -2
- nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
- nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- nshtrainer/config/callbacks/timer/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -2
- nshtrainer/config/loggers/wandb/__init__.py +6 -6
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- nshtrainer/config/nn/__init__.py +18 -18
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- nshtrainer/config/optimizer/__init__.py +2 -2
- nshtrainer/config/profiler/__init__.py +2 -2
- nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- nshtrainer/config/profiler/simple/__init__.py +4 -4
- nshtrainer/config/trainer/__init__.py +180 -0
- nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer/config/util/__init__.py +109 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -20
- nshtrainer/config/util/config/__init__.py +2 -2
- nshtrainer/data/datamodule.py +51 -2
- nshtrainer/loggers/__init__.py +2 -1
- nshtrainer/loggers/_base.py +5 -2
- nshtrainer/loggers/actsave.py +59 -0
- nshtrainer/loggers/csv.py +5 -5
- nshtrainer/loggers/tensorboard.py +5 -5
- nshtrainer/loggers/wandb.py +17 -16
- nshtrainer/lr_scheduler/_base.py +2 -1
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +64 -347
- nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer/model/mixins/logger.py +142 -145
- nshtrainer/profiler/_base.py +2 -2
- nshtrainer/profiler/advanced.py +4 -4
- nshtrainer/profiler/pytorch.py +4 -4
- nshtrainer/profiler/simple.py +4 -4
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/_config.py +164 -17
- nshtrainer/trainer/checkpoint_connector.py +23 -8
- nshtrainer/trainer/trainer.py +194 -76
- nshtrainer/util/_environment_info.py +21 -13
- nshtrainer/util/config/dtype.py +4 -4
- nshtrainer/util/typing_utils.py +1 -1
- {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/METADATA +2 -2
- nshtrainer-1.0.0b9.dist-info/RECORD +143 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer/config/model/__init__.py +0 -41
- nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer/ll/__init__.py +0 -59
- nshtrainer/ll/_experimental.py +0 -3
- nshtrainer/ll/actsave.py +0 -6
- nshtrainer/ll/callbacks.py +0 -3
- nshtrainer/ll/config.py +0 -6
- nshtrainer/ll/data.py +0 -3
- nshtrainer/ll/log.py +0 -5
- nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer/ll/model.py +0 -21
- nshtrainer/ll/nn.py +0 -3
- nshtrainer/ll/optimizer.py +0 -3
- nshtrainer/ll/runner.py +0 -5
- nshtrainer/ll/snapshot.py +0 -3
- nshtrainer/ll/snoop.py +0 -3
- nshtrainer/ll/trainer.py +0 -3
- nshtrainer/ll/typecheck.py +0 -3
- nshtrainer/ll/util.py +0 -3
- nshtrainer/model/config.py +0 -218
- nshtrainer/runner.py +0 -101
- nshtrainer-0.44.0.dist-info/RECORD +0 -162
- {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/WHEEL +0 -0
nshtrainer/_hf_hub.py
CHANGED
@@ -7,19 +7,19 @@ import re
|
|
7
7
|
from dataclasses import dataclass
|
8
8
|
from functools import cached_property
|
9
9
|
from pathlib import Path
|
10
|
-
from typing import TYPE_CHECKING, Any, Literal, cast
|
10
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
|
11
11
|
|
12
12
|
import nshconfig as C
|
13
13
|
from nshrunner._env import SNAPSHOT_DIR
|
14
14
|
from typing_extensions import assert_never, override
|
15
15
|
|
16
16
|
from ._callback import NTCallbackBase
|
17
|
-
from .callbacks.base import CallbackConfigBase
|
17
|
+
from .callbacks.base import CallbackConfigBase, CallbackMetadataConfig
|
18
18
|
|
19
19
|
if TYPE_CHECKING:
|
20
20
|
from huggingface_hub import HfApi # noqa: F401
|
21
21
|
|
22
|
-
from .
|
22
|
+
from .trainer._config import TrainerConfig
|
23
23
|
|
24
24
|
|
25
25
|
log = logging.getLogger(__name__)
|
@@ -42,6 +42,8 @@ class HuggingFaceHubAutoCreateConfig(C.Config):
|
|
42
42
|
class HuggingFaceHubConfig(CallbackConfigBase):
|
43
43
|
"""Configuration options for Hugging Face Hub integration."""
|
44
44
|
|
45
|
+
metadata: ClassVar[CallbackMetadataConfig] = {"ignore_if_exists": True}
|
46
|
+
|
45
47
|
enabled: bool = False
|
46
48
|
"""Enable Hugging Face Hub integration."""
|
47
49
|
|
@@ -82,7 +84,7 @@ class HuggingFaceHubConfig(CallbackConfigBase):
|
|
82
84
|
return self.enabled
|
83
85
|
|
84
86
|
@override
|
85
|
-
def create_callbacks(self,
|
87
|
+
def create_callbacks(self, trainer_config):
|
86
88
|
# Attempt to login. If it fails, we'll log a warning or error based on the configuration.
|
87
89
|
try:
|
88
90
|
api = _api(self.token)
|
@@ -107,7 +109,7 @@ class HuggingFaceHubConfig(CallbackConfigBase):
|
|
107
109
|
case _:
|
108
110
|
assert_never(self.on_login_error)
|
109
111
|
|
110
|
-
yield
|
112
|
+
yield HFHubCallback(self)
|
111
113
|
|
112
114
|
|
113
115
|
def _api(token: str | None = None):
|
@@ -138,19 +140,20 @@ def _api(token: str | None = None):
|
|
138
140
|
return api
|
139
141
|
|
140
142
|
|
141
|
-
def _repo_name(api:
|
143
|
+
def _repo_name(api: HfApi, trainer_config: TrainerConfig):
|
142
144
|
username = None
|
143
|
-
if (ac :=
|
145
|
+
if (ac := trainer_config.hf_hub.auto_create) and ac.namespace:
|
144
146
|
username = ac.namespace
|
145
147
|
elif (username := api.whoami().get("name", None)) is None:
|
146
148
|
raise ValueError("Could not get username from Hugging Face Hub.")
|
147
149
|
|
148
150
|
# Sanitize the project (if it exists), run_name, and id
|
149
151
|
parts = []
|
150
|
-
if
|
151
|
-
parts.append(re.sub(r"[^a-zA-Z0-9-]", "-",
|
152
|
-
|
153
|
-
|
152
|
+
if trainer_config.project:
|
153
|
+
parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", trainer_config.project))
|
154
|
+
if trainer_config.full_name:
|
155
|
+
parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", trainer_config.full_name))
|
156
|
+
parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", trainer_config.id))
|
154
157
|
|
155
158
|
# Combine parts and ensure it starts and ends with alphanumeric characters
|
156
159
|
repo_name = "-".join(parts)
|
@@ -179,14 +182,10 @@ class _Upload:
|
|
179
182
|
path_in_repo: Path
|
180
183
|
|
181
184
|
@classmethod
|
182
|
-
def from_local_path(
|
183
|
-
cls,
|
184
|
-
local_path: Path,
|
185
|
-
root_config: "BaseConfig",
|
186
|
-
):
|
185
|
+
def from_local_path(cls, local_path: Path, trainer_config: TrainerConfig):
|
187
186
|
# Resolve the checkpoint directory
|
188
|
-
checkpoint_dir =
|
189
|
-
|
187
|
+
checkpoint_dir = trainer_config.directory.resolve_subdirectory(
|
188
|
+
trainer_config.id, "checkpoint"
|
190
189
|
)
|
191
190
|
|
192
191
|
try:
|
@@ -224,8 +223,7 @@ class HFHubCallback(NTCallbackBase):
|
|
224
223
|
|
225
224
|
@override
|
226
225
|
def setup(self, trainer, pl_module, stage):
|
227
|
-
|
228
|
-
self._repo_id = _repo_name(self.api, root_config)
|
226
|
+
self._repo_id = _repo_name(self.api, trainer.hparams)
|
229
227
|
|
230
228
|
if not self.config or not trainer.is_global_zero:
|
231
229
|
return
|
@@ -234,7 +232,7 @@ class HFHubCallback(NTCallbackBase):
|
|
234
232
|
self._create_repo_if_not_exists()
|
235
233
|
|
236
234
|
# Upload the config and code
|
237
|
-
self._save_config(
|
235
|
+
self._save_config(trainer.hparams)
|
238
236
|
self._save_code()
|
239
237
|
|
240
238
|
@override
|
@@ -248,10 +246,9 @@ class HFHubCallback(NTCallbackBase):
|
|
248
246
|
return
|
249
247
|
|
250
248
|
with self._with_error_handling("save checkpoints"):
|
251
|
-
root_config = cast("BaseConfig", pl_module.hparams)
|
252
249
|
self._save_checkpoint(
|
253
|
-
_Upload.from_local_path(ckpt_path,
|
254
|
-
_Upload.from_local_path(metadata_path,
|
250
|
+
_Upload.from_local_path(ckpt_path, trainer.hparams),
|
251
|
+
_Upload.from_local_path(metadata_path, trainer.hparams)
|
255
252
|
if metadata_path is not None
|
256
253
|
else None,
|
257
254
|
)
|
@@ -300,10 +297,12 @@ class HFHubCallback(NTCallbackBase):
|
|
300
297
|
f"Error checking repository '{self.repo_id}'", exc_info=True
|
301
298
|
)
|
302
299
|
|
303
|
-
def _save_config(self,
|
300
|
+
def _save_config(self, trainer_config: TrainerConfig):
|
304
301
|
with self._with_error_handling("upload config"):
|
305
302
|
self.api.upload_file(
|
306
|
-
path_or_fileobj=
|
303
|
+
path_or_fileobj=trainer_config.model_dump_json(indent=4).encode(
|
304
|
+
"utf-8"
|
305
|
+
),
|
307
306
|
path_in_repo="config.json",
|
308
307
|
repo_id=self.repo_id,
|
309
308
|
repo_type="model",
|
nshtrainer/callbacks/__init__.py
CHANGED
@@ -6,7 +6,7 @@ import nshconfig as C
|
|
6
6
|
|
7
7
|
from . import checkpoint as checkpoint
|
8
8
|
from .base import CallbackConfigBase as CallbackConfigBase
|
9
|
-
from .checkpoint import
|
9
|
+
from .checkpoint import BestCheckpointCallback as BestCheckpointCallback
|
10
10
|
from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
|
11
11
|
from .checkpoint import LastCheckpointCallback as LastCheckpointCallback
|
12
12
|
from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
|
@@ -49,7 +49,6 @@ from .shared_parameters import SharedParametersCallback as SharedParametersCallb
|
|
49
49
|
from .shared_parameters import (
|
50
50
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
51
51
|
)
|
52
|
-
from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
|
53
52
|
from .timer import EpochTimerCallback as EpochTimerCallback
|
54
53
|
from .timer import EpochTimerCallbackConfig as EpochTimerCallbackConfig
|
55
54
|
from .wandb_upload_code import WandbUploadCodeCallback as WandbUploadCodeCallback
|
@@ -62,7 +61,6 @@ from .wandb_watch import WandbWatchCallbackConfig as WandbWatchCallbackConfig
|
|
62
61
|
CallbackConfig = Annotated[
|
63
62
|
DebugFlagCallbackConfig
|
64
63
|
| EarlyStoppingCallbackConfig
|
65
|
-
| ThroughputMonitorConfig
|
66
64
|
| EpochTimerCallbackConfig
|
67
65
|
| PrintTableMetricsCallbackConfig
|
68
66
|
| FiniteChecksCallbackConfig
|
nshtrainer/callbacks/actsave.py
CHANGED
@@ -4,11 +4,9 @@ import contextlib
|
|
4
4
|
from pathlib import Path
|
5
5
|
from typing import Literal
|
6
6
|
|
7
|
-
from lightning.pytorch import LightningModule, Trainer
|
8
|
-
from lightning.pytorch.callbacks.callback import Callback
|
9
|
-
from nshutils import ActSave
|
10
7
|
from typing_extensions import TypeAlias, override
|
11
8
|
|
9
|
+
from .._callback import NTCallbackBase
|
12
10
|
from .base import CallbackConfigBase
|
13
11
|
|
14
12
|
Stage: TypeAlias = Literal["train", "validation", "test", "predict"]
|
@@ -25,15 +23,17 @@ class ActSaveConfig(CallbackConfigBase):
|
|
25
23
|
return self.enabled
|
26
24
|
|
27
25
|
@override
|
28
|
-
def create_callbacks(self,
|
26
|
+
def create_callbacks(self, trainer_config):
|
29
27
|
yield ActSaveCallback(
|
30
28
|
self,
|
31
29
|
self.save_dir
|
32
|
-
or
|
30
|
+
or trainer_config.directory.resolve_subdirectory(
|
31
|
+
trainer_config.id, "activation"
|
32
|
+
),
|
33
33
|
)
|
34
34
|
|
35
35
|
|
36
|
-
class ActSaveCallback(
|
36
|
+
class ActSaveCallback(NTCallbackBase):
|
37
37
|
def __init__(self, config: ActSaveConfig, save_dir: Path):
|
38
38
|
super().__init__()
|
39
39
|
|
@@ -43,20 +43,20 @@ class ActSaveCallback(Callback):
|
|
43
43
|
self._active_contexts: dict[Stage, contextlib._GeneratorContextManager] = {}
|
44
44
|
|
45
45
|
@override
|
46
|
-
def setup(self, trainer
|
46
|
+
def setup(self, trainer, pl_module, stage) -> None:
|
47
47
|
super().setup(trainer, pl_module, stage)
|
48
48
|
|
49
49
|
if not self.config:
|
50
50
|
return
|
51
51
|
|
52
|
+
from nshutils import ActSave
|
53
|
+
|
52
54
|
context = ActSave.enabled(self.save_dir)
|
53
55
|
context.__enter__()
|
54
56
|
self._enabled_context = context
|
55
57
|
|
56
58
|
@override
|
57
|
-
def teardown(
|
58
|
-
self, trainer: Trainer, pl_module: LightningModule, stage: str
|
59
|
-
) -> None:
|
59
|
+
def teardown(self, trainer, pl_module, stage) -> None:
|
60
60
|
super().teardown(trainer, pl_module, stage)
|
61
61
|
|
62
62
|
if not self.config:
|
@@ -66,10 +66,12 @@ class ActSaveCallback(Callback):
|
|
66
66
|
self._enabled_context.__exit__(None, None, None)
|
67
67
|
self._enabled_context = None
|
68
68
|
|
69
|
-
def _on_start(self, stage: Stage, trainer
|
69
|
+
def _on_start(self, stage: Stage, trainer, pl_module):
|
70
70
|
if not self.config:
|
71
71
|
return
|
72
72
|
|
73
|
+
from nshutils import ActSave
|
74
|
+
|
73
75
|
# If we have an active context manager for this stage, exit it
|
74
76
|
if active_contexts := self._active_contexts.get(stage):
|
75
77
|
active_contexts.__exit__(None, None, None)
|
@@ -79,7 +81,7 @@ class ActSaveCallback(Callback):
|
|
79
81
|
context.__enter__()
|
80
82
|
self._active_contexts[stage] = context
|
81
83
|
|
82
|
-
def _on_end(self, stage: Stage, trainer
|
84
|
+
def _on_end(self, stage: Stage, trainer, pl_module):
|
83
85
|
if not self.config:
|
84
86
|
return
|
85
87
|
|
@@ -88,33 +90,33 @@ class ActSaveCallback(Callback):
|
|
88
90
|
active_contexts.__exit__(None, None, None)
|
89
91
|
|
90
92
|
@override
|
91
|
-
def on_train_epoch_start(self, trainer
|
93
|
+
def on_train_epoch_start(self, trainer, pl_module):
|
92
94
|
return self._on_start("train", trainer, pl_module)
|
93
95
|
|
94
96
|
@override
|
95
|
-
def on_train_epoch_end(self, trainer
|
97
|
+
def on_train_epoch_end(self, trainer, pl_module):
|
96
98
|
return self._on_end("train", trainer, pl_module)
|
97
99
|
|
98
100
|
@override
|
99
|
-
def on_validation_epoch_start(self, trainer
|
101
|
+
def on_validation_epoch_start(self, trainer, pl_module):
|
100
102
|
return self._on_start("validation", trainer, pl_module)
|
101
103
|
|
102
104
|
@override
|
103
|
-
def on_validation_epoch_end(self, trainer
|
105
|
+
def on_validation_epoch_end(self, trainer, pl_module):
|
104
106
|
return self._on_end("validation", trainer, pl_module)
|
105
107
|
|
106
108
|
@override
|
107
|
-
def on_test_epoch_start(self, trainer
|
109
|
+
def on_test_epoch_start(self, trainer, pl_module):
|
108
110
|
return self._on_start("test", trainer, pl_module)
|
109
111
|
|
110
112
|
@override
|
111
|
-
def on_test_epoch_end(self, trainer
|
113
|
+
def on_test_epoch_end(self, trainer, pl_module):
|
112
114
|
return self._on_end("test", trainer, pl_module)
|
113
115
|
|
114
116
|
@override
|
115
|
-
def on_predict_epoch_start(self, trainer
|
117
|
+
def on_predict_epoch_start(self, trainer, pl_module):
|
116
118
|
return self._on_start("predict", trainer, pl_module)
|
117
119
|
|
118
120
|
@override
|
119
|
-
def on_predict_epoch_end(self, trainer
|
121
|
+
def on_predict_epoch_end(self, trainer, pl_module):
|
120
122
|
return self._on_end("predict", trainer, pl_module)
|
nshtrainer/callbacks/base.py
CHANGED
@@ -11,7 +11,7 @@ from lightning.pytorch import Callback
|
|
11
11
|
from typing_extensions import TypedDict, Unpack
|
12
12
|
|
13
13
|
if TYPE_CHECKING:
|
14
|
-
from ..
|
14
|
+
from ..trainer._config import TrainerConfig
|
15
15
|
|
16
16
|
|
17
17
|
class CallbackMetadataConfig(TypedDict, total=False):
|
@@ -49,15 +49,15 @@ class CallbackConfigBase(C.Config, ABC):
|
|
49
49
|
|
50
50
|
@abstractmethod
|
51
51
|
def create_callbacks(
|
52
|
-
self,
|
52
|
+
self, trainer_config: TrainerConfig
|
53
53
|
) -> Iterable[Callback | CallbackWithMetadata]: ...
|
54
54
|
|
55
55
|
|
56
56
|
# region Config resolution helpers
|
57
57
|
def _create_callbacks_with_metadata(
|
58
|
-
config: CallbackConfigBase,
|
58
|
+
config: CallbackConfigBase, trainer_config: TrainerConfig
|
59
59
|
) -> Iterable[CallbackWithMetadata]:
|
60
|
-
for callback in config.create_callbacks(
|
60
|
+
for callback in config.create_callbacks(trainer_config):
|
61
61
|
if isinstance(callback, CallbackWithMetadata):
|
62
62
|
yield callback
|
63
63
|
continue
|
@@ -102,16 +102,16 @@ def _process_and_filter_callbacks(
|
|
102
102
|
return [callback.callback for callback in callbacks]
|
103
103
|
|
104
104
|
|
105
|
-
def resolve_all_callbacks(
|
105
|
+
def resolve_all_callbacks(trainer_config: TrainerConfig):
|
106
106
|
callback_configs = [
|
107
107
|
config
|
108
|
-
for config in
|
108
|
+
for config in trainer_config._nshtrainer_all_callback_configs()
|
109
109
|
if config is not None
|
110
110
|
]
|
111
111
|
callbacks = _process_and_filter_callbacks(
|
112
112
|
callback
|
113
113
|
for callback_config in callback_configs
|
114
|
-
for callback in _create_callbacks_with_metadata(callback_config,
|
114
|
+
for callback in _create_callbacks_with_metadata(callback_config, trainer_config)
|
115
115
|
)
|
116
116
|
return callbacks
|
117
117
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from .best_checkpoint import
|
3
|
+
from .best_checkpoint import BestCheckpointCallback as BestCheckpointCallback
|
4
4
|
from .best_checkpoint import (
|
5
5
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
6
6
|
)
|
@@ -16,7 +16,8 @@ from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
|
|
16
16
|
from ..base import CallbackConfigBase
|
17
17
|
|
18
18
|
if TYPE_CHECKING:
|
19
|
-
from ...
|
19
|
+
from ...trainer._config import TrainerConfig
|
20
|
+
|
20
21
|
|
21
22
|
log = logging.getLogger(__name__)
|
22
23
|
|
@@ -41,18 +42,20 @@ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
|
|
41
42
|
@abstractmethod
|
42
43
|
def create_checkpoint(
|
43
44
|
self,
|
44
|
-
|
45
|
+
trainer_config: TrainerConfig,
|
45
46
|
dirpath: Path,
|
46
47
|
) -> "CheckpointBase | None": ...
|
47
48
|
|
48
49
|
@override
|
49
|
-
def create_callbacks(self,
|
50
|
+
def create_callbacks(self, trainer_config):
|
50
51
|
dirpath = Path(
|
51
52
|
self.dirpath
|
52
|
-
or
|
53
|
+
or trainer_config.directory.resolve_subdirectory(
|
54
|
+
trainer_config.id, "checkpoint"
|
55
|
+
)
|
53
56
|
)
|
54
57
|
|
55
|
-
if (callback := self.create_checkpoint(
|
58
|
+
if (callback := self.create_checkpoint(trainer_config, dirpath)) is not None:
|
56
59
|
yield callback
|
57
60
|
|
58
61
|
|
@@ -28,10 +28,10 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
|
|
28
28
|
"""
|
29
29
|
|
30
30
|
@override
|
31
|
-
def create_checkpoint(self,
|
31
|
+
def create_checkpoint(self, trainer_config, dirpath):
|
32
32
|
# Resolve metric
|
33
33
|
if (metric := self.metric) is None and (
|
34
|
-
metric :=
|
34
|
+
metric := trainer_config.primary_metric
|
35
35
|
) is None:
|
36
36
|
error_msg = (
|
37
37
|
"No metric provided and no primary metric found in the root config. "
|
@@ -43,11 +43,11 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
|
|
43
43
|
log.warning(error_msg)
|
44
44
|
return None
|
45
45
|
|
46
|
-
return
|
46
|
+
return BestCheckpointCallback(self, dirpath, metric)
|
47
47
|
|
48
48
|
|
49
49
|
@final
|
50
|
-
class
|
50
|
+
class BestCheckpointCallback(CheckpointBase[BestCheckpointCallbackConfig]):
|
51
51
|
@property
|
52
52
|
def _metric_name_normalized(self):
|
53
53
|
return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")
|
@@ -18,7 +18,7 @@ class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
|
|
18
18
|
name: Literal["last_checkpoint"] = "last_checkpoint"
|
19
19
|
|
20
20
|
@override
|
21
|
-
def create_checkpoint(self,
|
21
|
+
def create_checkpoint(self, trainer_config, dirpath):
|
22
22
|
return LastCheckpointCallback(self, dirpath)
|
23
23
|
|
24
24
|
|
@@ -54,13 +54,13 @@ class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
|
|
54
54
|
"""Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""
|
55
55
|
|
56
56
|
@override
|
57
|
-
def create_callbacks(self,
|
58
|
-
dirpath = self.dirpath or
|
59
|
-
|
57
|
+
def create_callbacks(self, trainer_config):
|
58
|
+
dirpath = self.dirpath or trainer_config.directory.resolve_subdirectory(
|
59
|
+
trainer_config.id, "checkpoint"
|
60
60
|
)
|
61
61
|
|
62
62
|
if not (filename := self.filename):
|
63
|
-
filename = f"on_exception_{
|
63
|
+
filename = f"on_exception_{trainer_config.id}"
|
64
64
|
yield OnExceptionCheckpointCallback(
|
65
65
|
self, dirpath=Path(dirpath), filename=filename
|
66
66
|
)
|
@@ -1,17 +1,13 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import logging
|
4
|
-
from typing import
|
4
|
+
from typing import Literal
|
5
5
|
|
6
|
-
from lightning.pytorch import LightningModule, Trainer
|
7
|
-
from lightning.pytorch.callbacks import Callback
|
8
6
|
from typing_extensions import override
|
9
7
|
|
8
|
+
from .._callback import NTCallbackBase
|
10
9
|
from .base import CallbackConfigBase
|
11
10
|
|
12
|
-
if TYPE_CHECKING:
|
13
|
-
from ..model.config import BaseConfig
|
14
|
-
|
15
11
|
log = logging.getLogger(__name__)
|
16
12
|
|
17
13
|
|
@@ -25,14 +21,14 @@ class DebugFlagCallbackConfig(CallbackConfigBase):
|
|
25
21
|
return self.enabled
|
26
22
|
|
27
23
|
@override
|
28
|
-
def create_callbacks(self,
|
24
|
+
def create_callbacks(self, trainer_config):
|
29
25
|
if not self:
|
30
26
|
return
|
31
27
|
|
32
28
|
yield DebugFlagCallback(self)
|
33
29
|
|
34
30
|
|
35
|
-
class DebugFlagCallback(
|
31
|
+
class DebugFlagCallback(NTCallbackBase):
|
36
32
|
"""
|
37
33
|
Sets the debug flag to true in the following circumstances:
|
38
34
|
- fast_dev_run is enabled
|
@@ -46,27 +42,26 @@ class DebugFlagCallback(Callback):
|
|
46
42
|
self.config = config
|
47
43
|
del config
|
48
44
|
|
45
|
+
self._debug = False
|
46
|
+
|
49
47
|
@override
|
50
|
-
def setup(self, trainer
|
48
|
+
def setup(self, trainer, pl_module, stage):
|
51
49
|
if not getattr(trainer, "fast_dev_run", False):
|
52
50
|
return
|
53
51
|
|
54
|
-
|
55
|
-
if not hparams.debug:
|
52
|
+
if not trainer.debug:
|
56
53
|
log.critical("Fast dev run detected, setting debug flag to True.")
|
57
|
-
|
54
|
+
trainer.debug = True
|
58
55
|
|
59
56
|
@override
|
60
|
-
def on_sanity_check_start(self, trainer
|
61
|
-
|
62
|
-
self._debug = hparams.debug
|
57
|
+
def on_sanity_check_start(self, trainer, pl_module):
|
58
|
+
self._debug = trainer.debug
|
63
59
|
if not self._debug:
|
64
60
|
log.critical("Enabling debug flag during sanity check routine.")
|
65
|
-
|
61
|
+
trainer.debug = True
|
66
62
|
|
67
63
|
@override
|
68
|
-
def on_sanity_check_end(self, trainer
|
69
|
-
hparams = cast("BaseConfig", pl_module.hparams)
|
64
|
+
def on_sanity_check_end(self, trainer, pl_module):
|
70
65
|
if not self._debug:
|
71
66
|
log.critical("Sanity check routine complete, disabling debug flag.")
|
72
|
-
|
67
|
+
trainer.debug = self._debug
|
@@ -5,9 +5,9 @@ import os
|
|
5
5
|
from pathlib import Path
|
6
6
|
from typing import Literal
|
7
7
|
|
8
|
-
from lightning.pytorch import Callback
|
9
8
|
from typing_extensions import override
|
10
9
|
|
10
|
+
from .._callback import NTCallbackBase
|
11
11
|
from .base import CallbackConfigBase
|
12
12
|
|
13
13
|
log = logging.getLogger(__name__)
|
@@ -55,14 +55,14 @@ class DirectorySetupCallbackConfig(CallbackConfigBase):
|
|
55
55
|
def __bool__(self):
|
56
56
|
return self.enabled
|
57
57
|
|
58
|
-
def create_callbacks(self,
|
58
|
+
def create_callbacks(self, trainer_config):
|
59
59
|
if not self:
|
60
60
|
return
|
61
61
|
|
62
62
|
yield DirectorySetupCallback(self)
|
63
63
|
|
64
64
|
|
65
|
-
class DirectorySetupCallback(
|
65
|
+
class DirectorySetupCallback(NTCallbackBase):
|
66
66
|
@override
|
67
67
|
def __init__(self, config: DirectorySetupCallbackConfig):
|
68
68
|
super().__init__()
|
@@ -76,12 +76,7 @@ class DirectorySetupCallback(Callback):
|
|
76
76
|
|
77
77
|
# Create a symlink to the root folder for the Runner
|
78
78
|
if self.config.create_symlink_to_nshrunner_root:
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
assert isinstance(
|
83
|
-
config := pl_module.hparams, BaseConfig
|
84
|
-
), f"Expected a BaseConfig, got {type(config)}"
|
85
|
-
|
86
|
-
base_dir = config.directory.resolve_run_root_directory(config.id)
|
79
|
+
base_dir = trainer.hparams.directory.resolve_run_root_directory(
|
80
|
+
trainer.hparams.id
|
81
|
+
)
|
87
82
|
_create_symlink_to_nshrunner(base_dir)
|
@@ -48,12 +48,12 @@ class EarlyStoppingCallbackConfig(CallbackConfigBase):
|
|
48
48
|
"""
|
49
49
|
|
50
50
|
@override
|
51
|
-
def create_callbacks(self,
|
51
|
+
def create_callbacks(self, trainer_config):
|
52
52
|
if (metric := self.metric) is None and (
|
53
|
-
metric :=
|
53
|
+
metric := trainer_config.primary_metric
|
54
54
|
) is None:
|
55
55
|
raise ValueError(
|
56
|
-
"Either `metric` or `
|
56
|
+
"Either `metric` or `trainer_config.primary_metric` must be set to use EarlyStopping."
|
57
57
|
)
|
58
58
|
|
59
59
|
yield EarlyStoppingCallback(self, metric)
|
nshtrainer/callbacks/ema.py
CHANGED
@@ -376,7 +376,7 @@ class EMACallbackConfig(CallbackConfigBase):
|
|
376
376
|
"""Offload weights to CPU."""
|
377
377
|
|
378
378
|
@override
|
379
|
-
def create_callbacks(self,
|
379
|
+
def create_callbacks(self, trainer_config):
|
380
380
|
yield EMACallback(
|
381
381
|
decay=self.decay,
|
382
382
|
validate_original_weights=self.validate_original_weights,
|
@@ -70,7 +70,7 @@ class FiniteChecksCallbackConfig(CallbackConfigBase):
|
|
70
70
|
"""Whether to check for None gradients"""
|
71
71
|
|
72
72
|
@override
|
73
|
-
def create_callbacks(self,
|
73
|
+
def create_callbacks(self, trainer_config):
|
74
74
|
yield FiniteChecksCallback(
|
75
75
|
nonfinite_grads=self.nonfinite_grads,
|
76
76
|
none_grads=self.none_grads,
|
@@ -88,5 +88,5 @@ class PrintTableMetricsCallbackConfig(CallbackConfigBase):
|
|
88
88
|
"""List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
|
89
89
|
|
90
90
|
@override
|
91
|
-
def create_callbacks(self,
|
91
|
+
def create_callbacks(self, trainer_config):
|
92
92
|
yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
|
@@ -30,7 +30,7 @@ class SharedParametersCallbackConfig(CallbackConfigBase):
|
|
30
30
|
name: Literal["shared_parameters"] = "shared_parameters"
|
31
31
|
|
32
32
|
@override
|
33
|
-
def create_callbacks(self,
|
33
|
+
def create_callbacks(self, trainer_config):
|
34
34
|
yield SharedParametersCallback(self)
|
35
35
|
|
36
36
|
|
nshtrainer/callbacks/timer.py
CHANGED