nshtrainer 0.9.1__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -1
- nshtrainer/_checkpoint/loader.py +319 -0
- nshtrainer/_checkpoint/metadata.py +102 -0
- nshtrainer/callbacks/__init__.py +17 -1
- nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
- nshtrainer/callbacks/base.py +7 -5
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
- nshtrainer/callbacks/model_checkpoint.py +187 -0
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/throughput_monitor.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/ll/__init__.py +0 -1
- nshtrainer/ll/actsave.py +2 -1
- nshtrainer/metrics/__init__.py +1 -0
- nshtrainer/metrics/_config.py +37 -0
- nshtrainer/model/__init__.py +11 -11
- nshtrainer/model/_environment.py +777 -0
- nshtrainer/model/base.py +5 -114
- nshtrainer/model/config.py +49 -501
- nshtrainer/model/modules/logger.py +11 -6
- nshtrainer/runner.py +3 -6
- nshtrainer/trainer/_runtime_callback.py +120 -0
- nshtrainer/trainer/checkpoint_connector.py +63 -0
- nshtrainer/trainer/signal_connector.py +12 -9
- nshtrainer/trainer/trainer.py +111 -31
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/METADATA +3 -1
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/RECORD +34 -27
- nshtrainer/actsave/__init__.py +0 -3
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/latest_epoch_checkpoint.py
CHANGED

```diff
@@ -1,35 +1,54 @@
 import logging
 from pathlib import Path
+from typing import Literal
 
-from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Checkpoint
 from typing_extensions import override
 
+from .base import CallbackConfigBase
+
 log = logging.getLogger(__name__)
 
 
+class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
+    kind: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
+
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
+    """Checkpoint filename. This must not include the extension."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    latest_symlink_filename: str | None = "latest.ckpt"
+    """Filename for the latest symlink. If None, no symlink will be created."""
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+        dirpath = Path(dirpath)
+
+        yield LatestEpochCheckpoint(self, dirpath)
+
+
 class LatestEpochCheckpoint(Checkpoint):
-
-
-    def __init__(
-        self,
-        dirpath: _PATH,
-        filename: str | None = None,
-        save_weights_only: bool = False,
-    ):
+    def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
         super().__init__()
 
-        self._dirpath = dirpath
-        self._filename = filename
-        self._save_weights_only = save_weights_only
+        self.config = config
+        self.dirpath = dirpath
 
         # Also, we hold a reference to the last checkpoint path
         # to be able to remove it when a new checkpoint is saved.
         self._last_ckpt_path: Path | None = None
 
     def _ckpt_path(self, trainer: Trainer):
-        return self.
+        return self.dirpath / self.config.filename.format(
             epoch=trainer.current_epoch, step=trainer.global_step
         )
 
@@ -41,5 +60,22 @@ class LatestEpochCheckpoint(Checkpoint):
 
         # Save the new checkpoint
         filepath = self._ckpt_path(trainer)
-        trainer.save_checkpoint(filepath, self._save_weights_only)
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)
         self._last_ckpt_path = filepath
+
+        # Create the latest symlink
+        if (symlink_filename := self.config.latest_symlink_filename) is not None:
+            symlink_path = self.dirpath / symlink_filename
+            if symlink_path.exists():
+                symlink_path.unlink()
+            symlink_path.symlink_to(filepath.name)
+            log.info(f"Created latest symlink: {symlink_path}")
+
+    def latest_checkpoint(self):
+        if (symlink_filename := self.config.latest_symlink_filename) is None:
+            return None
+
+        if not (symlink_path := self.dirpath / symlink_filename).exists():
+            return None
+
+        return symlink_path
```
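One detail worth noting in the hunk above: the symlink is created against `filepath.name` (the bare file name), not the absolute path, so `latest.ckpt` resolves relative to the checkpoint directory and keeps working if that directory is moved. A minimal standalone sketch of the same pathlib pattern; the directory and file names here are illustrative, not taken from nshtrainer:

```python
# Sketch of the relative-symlink pattern used by LatestEpochCheckpoint.
# Illustrative names only; stdlib pathlib, no nshtrainer required.
from pathlib import Path

ckpt_dir = Path("checkpoints")
ckpt_dir.mkdir(exist_ok=True)

# Stand-in for trainer.save_checkpoint(filepath, ...).
filepath = ckpt_dir / "latest_epoch03_step0120.ckpt"
filepath.touch()

symlink_path = ckpt_dir / "latest.ckpt"
if symlink_path.exists():
    symlink_path.unlink()
# Link to the bare file name: the target is resolved relative to the
# symlink's own directory, so the whole folder can be moved or re-mounted.
symlink_path.symlink_to(filepath.name)

print(symlink_path.resolve())  # .../checkpoints/latest_epoch03_step0120.ckpt
```

(On Windows, creating symlinks may require developer mode or elevated privileges.)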
nshtrainer/callbacks/model_checkpoint.py
ADDED

```diff
@@ -0,0 +1,187 @@
+import re
+from datetime import timedelta
+from logging import getLogger
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal
+
+from lightning.pytorch.callbacks.model_checkpoint import (
+    ModelCheckpoint as _ModelCheckpoint,
+)
+from typing_extensions import override
+
+from ..metrics import MetricConfig
+from .base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
+log = getLogger(__name__)
+
+
+def _convert_string(input_string: str):
+    # Find all variables enclosed in curly braces
+    variables = re.findall(r"\{(.*?)\}", input_string)
+
+    # Replace each variable with its corresponding key-value pair
+    output_string = input_string
+    for variable in variables:
+        # If the name is something like {variable:format}, we shouldn't process the format.
+        key_name = variable
+        if ":" in variable:
+            key_name, _ = variable.split(":", 1)
+            continue
+
+        # Replace '/' with '_' in the key name
+        key_name = key_name.replace("/", "_")
+        output_string = output_string.replace(
+            f"{{{variable}}}", f"{key_name}={{{variable}}}"
+        )
+
+    return output_string
+
+
+class ModelCheckpointCallbackConfig(CallbackConfigBase):
+    """Arguments for the ModelCheckpoint callback."""
+
+    kind: Literal["model_checkpoint"] = "model_checkpoint"
+
+    dirpath: str | Path | None = None
+    """
+    Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
+    """
+
+    filename: str | None = None
+    """
+    Checkpoint filename.
+    If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
+    """
+
+    metric: MetricConfig | None = None
+    """
+    Metric to monitor for saving checkpoints.
+    If None, the primary metric of the runner will be used, if available.
+    """
+
+    verbose: bool = False
+    """Verbosity mode. If True, print additional information about checkpoints."""
+
+    save_last: Literal[True, False, "link"] | None = "link"
+    """
+    Whether to save the last checkpoint.
+    If True, saves a copy of the last checkpoint separately.
+    If "link", creates a symbolic link to the last checkpoint.
+    """
+
+    save_top_k: int = 1
+    """
+    Number of best models to save.
+    If -1, all models are saved.
+    If 0, no models are saved.
+    """
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    auto_insert_metric_name: bool = True
+    """Whether to automatically insert the metric name in the checkpoint filename."""
+
+    every_n_train_steps: int | None = None
+    """
+    Number of training steps between checkpoints.
+    If None or 0, no checkpoints are saved during training.
+    """
+
+    train_time_interval: timedelta | None = None
+    """
+    Time interval between checkpoints during training.
+    If None, no checkpoints are saved during training based on time.
+    """
+
+    every_n_epochs: int | None = None
+    """
+    Number of epochs between checkpoints.
+    If None or 0, no checkpoints are saved at the end of epochs.
+    """
+
+    save_on_train_epoch_end: bool | None = None
+    """
+    Whether to run checkpointing at the end of the training epoch.
+    If False, checkpointing runs at the end of the validation.
+    """
+
+    enable_version_counter: bool = True
+    """Whether to append a version to the existing file name."""
+
+    auto_append_metric: bool = True
+    """If enabled, this will automatically add "-{monitor}" to the filename."""
+
+    def metric_or_default(self, root_config: "BaseConfig"):
+        if self.metric is not None:
+            return self.metric
+        if root_config.primary_metric is not None:
+            return root_config.primary_metric
+        raise ValueError("Primary metric must be provided if metric is not specified.")
+
+    def resolve_filename(self, root_config: "BaseConfig"):
+        metric = self.metric_or_default(root_config)
+
+        filename = self.filename
+        if not filename:
+            filename = "{epoch}-{step}"
+            if self.auto_append_metric:
+                filename = f"{filename}-{{{metric.validation_monitor}}}"
+
+        if self.auto_insert_metric_name and filename:
+            new_filename = _convert_string(filename)
+            log.critical(
+                f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
+            )
+            filename = new_filename
+
+        return filename
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+
+        metric = self.metric_or_default(root_config)
+        filename = self.resolve_filename(root_config)
+
+        yield ModelCheckpoint(
+            self,
+            dirpath=Path(dirpath),
+            filename=filename,
+            metric=metric,
+        )
+
+
+class ModelCheckpoint(_ModelCheckpoint):
+    @override
+    def __init__(
+        self,
+        config: ModelCheckpointCallbackConfig,
+        dirpath: Path,
+        filename: str,
+        metric: MetricConfig,
+    ):
+        self.config = config
+        del config
+
+        super().__init__(
+            dirpath=dirpath,
+            filename=filename,
+            monitor=metric.validation_monitor,
+            mode=metric.mode,
+            verbose=self.config.verbose,
+            save_last=self.config.save_last,
+            save_top_k=self.config.save_top_k,
+            save_weights_only=self.config.save_weights_only,
+            auto_insert_metric_name=False,
+            every_n_train_steps=self.config.every_n_train_steps,
+            train_time_interval=self.config.train_time_interval,
+            every_n_epochs=self.config.every_n_epochs,
+            save_on_train_epoch_end=self.config.save_on_train_epoch_end,
+            enable_version_counter=self.config.enable_version_counter,
+        )
```
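`resolve_filename` pipes the template through `_convert_string`, which prefixes every bare `{placeholder}` with its own key (mapping `/` to `_` in the key) and, because of the early `continue`, leaves any `{placeholder:format}` untouched. A quick check of that behavior, assuming nshtrainer 0.10.1 is installed; the format strings below are illustrative, and `_convert_string` is a private helper imported purely for demonstration:

```python
# Behavior check for the placeholder rewriting in model_checkpoint.py.
# Assumes nshtrainer 0.10.1; _convert_string is a private name.
from nshtrainer.callbacks.model_checkpoint import _convert_string

# Bare placeholders gain a "key=" prefix; '/' in the key becomes '_'.
assert _convert_string("{epoch}-{val/loss}") == "epoch={epoch}-val_loss={val/loss}"

# Placeholders carrying a format spec are skipped entirely by the early
# `continue`, so they come back unchanged.
assert _convert_string("{epoch:03d}") == "{epoch:03d}"
```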
nshtrainer/callbacks/on_exception_checkpoint.py
CHANGED

```diff
@@ -1,16 +1,82 @@
+import contextlib
 import datetime
 import logging
 import os
-from typing import Any
+from pathlib import Path
+from typing import Any, Literal
 
-from lightning.pytorch import Trainer
+from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import OnExceptionCheckpoint as _OnExceptionCheckpoint
 from typing_extensions import override
 
+from .base import CallbackConfigBase
+
 log = logging.getLogger(__name__)
 
 
+@contextlib.contextmanager
+def _monkey_patch_disable_barrier(trainer: LightningTrainer):
+    """
+    Monkey-patch the strategy instance to make the barrier operation a no-op.
+
+    We do this because `save_checkpoint` calls `barrier`. This is okay in most
+    cases, but when we want to save a checkpoint in the case of an exception,
+    `barrier` causes a deadlock. So we monkey-patch the strategy instance to
+    make the barrier operation a no-op.
+    """
+
+    # We monkey-patch the barrier method to do nothing.
+    original_barrier = trainer.strategy.barrier
+
+    def new_barrier(*args, **kwargs):
+        log.warning("Monkey-patched no-op barrier.")
+        pass
+
+    trainer.strategy.barrier = new_barrier
+    log.warning("Monkey-patched barrier to no-op.")
+
+    try:
+        yield
+    finally:
+        trainer.strategy.barrier = original_barrier
+        log.warning("Reverted monkey-patched barrier.")
+
+
+class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
+    kind: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"
+
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str | None = None
+    """Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""
+
+    @override
+    def create_callbacks(self, root_config):
+        from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint
+
+        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+
+        if not (filename := self.filename):
+            filename = f"on_exception_{root_config.id}"
+        yield OnExceptionCheckpoint(self, dirpath=Path(dirpath), filename=filename)
+
+
 class OnExceptionCheckpoint(_OnExceptionCheckpoint):
+    @override
+    def __init__(
+        self,
+        config: OnExceptionCheckpointCallbackConfig,
+        dirpath: Path,
+        filename: str,
+    ):
+        self.config = config
+        del config
+
+        super().__init__(dirpath, filename)
+
     @property
     @override
     def ckpt_path(self) -> str:

@@ -22,23 +88,11 @@ class OnExceptionCheckpoint(_OnExceptionCheckpoint):
         return f"{ckpt_path}_{timestamp}{ext}"
 
     @override
-    def on_exception(self, trainer: Trainer, *_: Any, **__: Any) -> None:
-        #
-        #
-
-        #
-
-
-                "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
-                " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
-            )
-        checkpoint = trainer._checkpoint_connector.dump_checkpoint(weights_only=False)
-        trainer.strategy.save_checkpoint(
-            checkpoint, self.ckpt_path, storage_options=None
-        )
-        # self.strategy.barrier("Trainer.save_checkpoint") # <-- This is disabled
-
-    @override
-    def teardown(self, trainer: Trainer, *_: Any, **__: Any) -> None:
-        trainer.strategy.remove_checkpoint(self.ckpt_path)
+    def on_exception(self, trainer: LightningTrainer, *args: Any, **kwargs: Any):
+        # Monkey-patch the strategy instance to make the barrier operation a no-op.
+        # We do this because `save_checkpoint` calls `barrier`. This is okay in most
+        # cases, but when we want to save a checkpoint in the case of an exception,
+        # `barrier` causes a deadlock. So we monkey-patch the strategy instance to
+        # make the barrier operation a no-op.
+        with _monkey_patch_disable_barrier(trainer):
+            return super().on_exception(trainer, *args, **kwargs)
```

(Several deleted lines in the second hunk, shown here as bare `-` or `#` stubs, were truncated in the source diff and cannot be recovered.)
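The `_monkey_patch_disable_barrier` helper above is a plain patch/restore context manager: swap the method out, yield, and restore it in `finally` so the original `barrier` comes back even if saving raises. A self-contained sketch of the same pattern on a stand-in strategy object; reproducing the actual deadlock needs a multi-process run, so every name here is hypothetical:

```python
# Patch/restore pattern behind _monkey_patch_disable_barrier, shown on a
# stand-in object (no Lightning required; all names are hypothetical).
import contextlib


class FakeStrategy:
    def barrier(self, name=None):
        print(f"barrier({name!r}) would block here")


@contextlib.contextmanager
def disable_barrier(strategy):
    original = strategy.barrier
    strategy.barrier = lambda *args, **kwargs: None  # no-op while saving
    try:
        yield
    finally:
        strategy.barrier = original  # restored even if the body raises


strategy = FakeStrategy()
with disable_barrier(strategy):
    strategy.barrier("Trainer.save_checkpoint")  # silently skipped
strategy.barrier("after")  # prints again
```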
nshtrainer/callbacks/print_table.py
CHANGED

```diff
@@ -86,5 +86,5 @@ class PrintTableMetricsConfig(CallbackConfigBase):
     """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
 
     @override
-    def
+    def create_callbacks(self, root_config):
         yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
```

nshtrainer/callbacks/throughput_monitor.py
CHANGED

```diff
@@ -52,5 +52,5 @@ class ThroughputMonitorConfig(CallbackConfigBase):
     """Number of batches to use for a rolling average."""
 
     @override
-    def
+    def create_callbacks(self, root_config):
         yield ThroughputMonitor(window_size=self.window_size)
```
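Both hunks above are the same rename: the config hook is now `create_callbacks`, a generator that yields zero or more configured callback instances. A hedged sketch of a custom config following that shape, assuming nshtrainer 0.10.1; the `print_start` kind and the callback class are hypothetical, not part of the package:

```python
# Hypothetical custom callback config mirroring the create_callbacks shape.
from typing import Literal

from lightning.pytorch import Callback
from typing_extensions import override

from nshtrainer.callbacks.base import CallbackConfigBase


class PrintStartCallback(Callback):
    @override
    def on_train_start(self, trainer, pl_module):
        print("training started")


class PrintStartConfig(CallbackConfigBase):
    kind: Literal["print_start"] = "print_start"

    @override
    def create_callbacks(self, root_config):
        # A generator: yield any number of callbacks (or none at all).
        yield PrintStartCallback()
```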
nshtrainer/callbacks/timer.py
CHANGED
nshtrainer/ll/__init__.py
CHANGED
```diff
@@ -21,7 +21,6 @@ from .log import init_python_logging as init_python_logging
 from .log import lovely as lovely
 from .log import pretty as pretty
 from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
-from .model import ActSaveConfig as ActSaveConfig
 from .model import Base as Base
 from .model import BaseConfig as BaseConfig
 from .model import BaseLoggerConfig as BaseLoggerConfig
```
nshtrainer/ll/actsave.py
CHANGED
```diff
@@ -1,3 +1,4 @@
 from nshutils.actsave import * # type: ignore # noqa: F403
 
-from nshtrainer.actsave import
+from nshtrainer.callbacks.actsave import ActSaveCallback as ActSaveCallback
+from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
```
nshtrainer/metrics/__init__.py
ADDED

```diff
@@ -0,0 +1 @@
+from ._config import MetricConfig as MetricConfig
```
nshtrainer/metrics/_config.py
ADDED

```diff
@@ -0,0 +1,37 @@
+import builtins
+from typing import Literal
+
+import nshconfig as C
+
+
+class MetricConfig(C.Config):
+    name: str
+    """The name of the primary metric."""
+
+    mode: Literal["min", "max"]
+    """
+    The mode of the primary metric:
+    - "min" for metrics that should be minimized (e.g., loss)
+    - "max" for metrics that should be maximized (e.g., accuracy)
+    """
+
+    @property
+    def validation_monitor(self) -> str:
+        return f"val/{self.name}"
+
+    def __post_init__(self):
+        for split in ("train", "val", "test", "predict"):
+            if self.name.startswith(f"{split}/"):
+                raise ValueError(
+                    f"Primary metric name should not start with '{split}/'. "
+                    f"Just use '{self.name[len(split) + 1:]}' instead. "
+                    "The split name is automatically added depending on the context."
+                )
+
+    @classmethod
+    def loss(cls, mode: Literal["min", "max"] = "min"):
+        return cls(name="loss", mode=mode)
+
+    @property
+    def best(self):
+        return builtins.min if self.mode == "min" else builtins.max
```
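`MetricConfig` is the single place where a monitor name and its direction live; `ModelCheckpointCallbackConfig.metric_or_default` above consumes it. A short usage sketch, assuming nshtrainer 0.10.1 is installed; the metric name is illustrative:

```python
# Usage sketch for the new MetricConfig (assumes nshtrainer 0.10.1).
from nshtrainer.metrics import MetricConfig

metric = MetricConfig(name="accuracy", mode="max")
print(metric.validation_monitor)  # "val/accuracy"

# `best` resolves to builtins.max for mode="max" (builtins.min for "min"),
# handy for picking the best score across runs.
print(metric.best(0.82, 0.91, 0.88))  # 0.91

# Shorthand for the common minimized-loss case:
print(MetricConfig.loss())  # name='loss' mode='min'
```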
nshtrainer/model/__init__.py
CHANGED
```diff
@@ -1,8 +1,18 @@
 from typing_extensions import TypeAlias
 
+from ._environment import (
+    EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
+)
+from ._environment import EnvironmentConfig as EnvironmentConfig
+from ._environment import (
+    EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
+)
+from ._environment import (
+    EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
+)
+from ._environment import EnvironmentSnapshotConfig as EnvironmentSnapshotConfig
 from .base import Base as Base
 from .base import LightningModuleBase as LightningModuleBase
-from .config import ActSaveConfig as ActSaveConfig
 from .config import BaseConfig as BaseConfig
 from .config import BaseLoggerConfig as BaseLoggerConfig
 from .config import BaseProfilerConfig as BaseProfilerConfig

@@ -10,16 +20,6 @@ from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
 from .config import CheckpointSavingConfig as CheckpointSavingConfig
 from .config import DirectoryConfig as DirectoryConfig
 from .config import EarlyStoppingConfig as EarlyStoppingConfig
-from .config import (
-    EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
-)
-from .config import EnvironmentConfig as EnvironmentConfig
-from .config import (
-    EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
-)
-from .config import (
-    EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
-)
 from .config import GradientClippingConfig as GradientClippingConfig
 from .config import (
     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
```