nshtrainer 0.11.2__tar.gz → 0.11.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/PKG-INFO +1 -1
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/pyproject.toml +1 -1
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/_checkpoint/metadata.py +16 -29
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/__init__.py +3 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/checkpoint/__init__.py +4 -0
- nshtrainer-0.11.4/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +171 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +4 -4
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/metrics/_config.py +5 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/model/config.py +4 -1
- nshtrainer-0.11.4/src/nshtrainer/util/_useful_types.py +307 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/README.md +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/checkpoint/model_checkpoint.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/util/typing_utils.py +0 -0

{nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/_checkpoint/metadata.py  (+16 -29)

@@ -43,6 +43,16 @@ class CheckpointMetadata(C.Config):
     def from_file(cls, path: Path):
         return cls.model_validate_json(path.read_text())
 
+    @classmethod
+    def from_ckpt_path(cls, checkpoint_path: Path):
+        if not (
+            metadata_path := checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+        ).exists():
+            raise FileNotFoundError(
+                f"Metadata file not found for checkpoint: {checkpoint_path}"
+            )
+        return cls.from_file(metadata_path)
+
 
 def _generate_checkpoint_metadata(
     config: "BaseConfig", trainer: "Trainer", checkpoint_path: Path

@@ -136,36 +146,13 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
     log.debug(f"Linked {path} to {linked_path}")
 
 
-def _checkpoint_sort_key_fn(key: Callable[[CheckpointMetadata, Path], Any]):
-    def sort_key_fn(checkpoint_path: Path):
-        if not (p := checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)).exists():
-            raise FileNotFoundError(f"Metadata file not found: {p}")
-
-        nonlocal key
-        return key(CheckpointMetadata.from_file(p), p)
-
-    return sort_key_fn
-
-
 def _sort_ckpts_by_metadata(
     checkpoint_paths: list[Path],
     key: Callable[[CheckpointMetadata, Path], Any],
-
+    reverse: bool = False,
 ):
-
-
-
-
-
-            continue
-
-        no_metadata_paths.append(path)
-
-    if no_metadata_paths:
-        log.warning(
-            f"Metadata file not found on {len(no_metadata_paths)} checkpoints: {no_metadata_paths}\n"
-            "Falling back to sorting by last modified time."
-        )
-        return sorted(checkpoint_paths, key=fallback_key)
-
-    return sorted(checkpoint_paths, key=_checkpoint_sort_key_fn(key))
+    return sorted(
+        [(CheckpointMetadata.from_ckpt_path(path), path) for path in checkpoint_paths],
+        key=lambda args_tuple: key(*args_tuple),
+        reverse=reverse,
+    )
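
For orientation, a minimal usage sketch of the reworked helper (not part of the diff; the "val/loss" metric name and the checkpoints directory are illustrative, while the (metadata, path) tuple return value and the key/reverse parameters follow the new signature above):

from pathlib import Path

from nshtrainer._checkpoint.metadata import _sort_ckpts_by_metadata

ckpt_paths = list(Path("checkpoints").glob("*.ckpt"))  # assumed directory

# `key` receives the parsed CheckpointMetadata plus the checkpoint Path.
sorted_ckpts = _sort_ckpts_by_metadata(
    ckpt_paths,
    key=lambda meta, path: meta.metrics.get("val/loss", float("inf")),
    reverse=False,  # ascending: lowest loss first
)
for metadata, path in sorted_ckpts:
    print(path, metadata.epoch, metadata.global_step)

Note the behavioural change: a checkpoint without a metadata file now raises FileNotFoundError via CheckpointMetadata.from_ckpt_path, instead of the old fallback of sorting by last modified time.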

{nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/__init__.py  (+3 -0)

@@ -4,6 +4,8 @@ import nshconfig as C
 
 from . import checkpoint as checkpoint
 from .base import CallbackConfigBase as CallbackConfigBase
+from .checkpoint import BestCheckpoint as BestCheckpoint
+from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
 from .checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
 from .checkpoint import (
     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,

@@ -43,6 +45,7 @@ CallbackConfig = Annotated[
     | NormLoggingConfig
     | GradientSkippingConfig
     | EMAConfig
+    | BestCheckpointCallbackConfig
     | ModelCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig

{nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/checkpoint/__init__.py  (+4 -0)

@@ -1,3 +1,7 @@
+from .best_checkpoint import BestCheckpoint as BestCheckpoint
+from .best_checkpoint import (
+    BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
+)
 from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
 from .latest_epoch_checkpoint import (
     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,

nshtrainer-0.11.4/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py  (new file, +171 -0)

@@ -0,0 +1,171 @@
+import logging
+from pathlib import Path
+from typing import Any, Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import override
+
+from ..._checkpoint.metadata import _sort_ckpts_by_metadata
+from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from ...metrics._config import MetricConfig
+from ..base import CallbackConfigBase
+
+log = logging.getLogger(__name__)
+
+
+class BestCheckpointCallbackConfig(CallbackConfigBase):
+    name: Literal["best_checkpoint"] = "best_checkpoint"
+
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str = "epoch{epoch:02d}_step{step:04d}"
+    """Checkpoint filename. This must not include the extension."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    metric: MetricConfig | None = None
+    """Metric to monitor, or `None` to use the default metric."""
+
+    best_symlink_filename: str | None = "best"
+    """Filename for the best symlink. If None, no symlink will be created."""
+
+    save_top_k: int | Literal["all"] = 1
+    """The number of best checkpoints to keep."""
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = Path(
+            self.dirpath
+            or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+        )
+
+        # Resolve metric
+        if (metric := self.metric) is None and (
+            metric := root_config.primary_metric
+        ) is None:
+            raise ValueError(
+                "No metric provided and no primary metric found in the root config"
+            )
+
+        yield BestCheckpoint(self, metric, dirpath)
+
+    @property
+    def _save_top_k_value(self):
+        return float("inf" if self.save_top_k == "all" else self.save_top_k)
+
+
+class BestCheckpoint(Checkpoint):
+    PREFIX = "best_"
+    EXTENSION = ".ckpt"
+
+    def __init__(
+        self,
+        config: BestCheckpointCallbackConfig,
+        metric: MetricConfig,
+        dirpath: Path,
+    ):
+        super().__init__()
+        self.config = config
+        self.metric = metric
+        self.dirpath = dirpath
+
+    @override
+    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
+        self._save_best_checkpoint(trainer)
+
+    def _best_symlink_filename(self):
+        if (filename := self.config.best_symlink_filename) is None:
+            return None
+        return f"{filename}{self.EXTENSION}"
+
+    def _ckpt_path(self, trainer: Trainer):
+        filename = self.config.filename.format(
+            epoch=trainer.current_epoch, step=trainer.global_step
+        )
+        filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
+        return self.dirpath / filename
+
+    def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
+        for ckpt_path in ckpt_paths:
+            _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
+
+    def _get_metric_value(self, metrics: dict[str, Any]):
+        return metrics.get(
+            self.metric.validation_monitor,
+            float("-inf" if self.metric.mode == "max" else "inf"),
+        )
+
+    def _sorted_ckpts(self):
+        ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
+        return _sort_ckpts_by_metadata(
+            ckpt_paths,
+            key=lambda meta, _: self._get_metric_value(meta.metrics),
+            reverse=(self.metric.mode == "min"),
+        )
+
+    def _create_symlink(self, trainer: Trainer, best_ckpt_path: Path):
+        # Resolve the symlink filename
+        if (symlink_filename := self._best_symlink_filename()) is None:
+            return
+
+        # If the symlink already exists and points to the best checkpoint,
+        # then we don't need to create a new symlink.
+        symlink_path = self.dirpath / symlink_filename
+        if symlink_path.exists() and symlink_path.resolve() == best_ckpt_path:
+            return
+
+        _link_checkpoint(
+            trainer,
+            best_ckpt_path,
+            symlink_path,
+            metadata=True,
+            barrier=False,
+        )
+        log.debug(f"Created best symlink: {symlink_path}")
+
+    def _save_best_checkpoint(self, trainer: Trainer):
+        if (current := self._get_metric_value(trainer.callback_metrics)) is None:
+            log.warning(
+                f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
+            )
+            return
+
+        # Get sorted checkpoints
+        sorted_ckpts = self._sorted_ckpts()
+
+        # If the current model is worse than the worst checkpoint,
+        # and we have already saved the maximum number of checkpoints,
+        # then don't save the current model.
+        if len(
+            sorted_ckpts
+        ) >= self.config._save_top_k_value and not self.metric.is_better(
+            current,
+            self._get_metric_value(sorted_ckpts[-1][0].metrics),
+        ):
+            return
+
+        # Save the current model
+        filepath = self._ckpt_path(trainer)
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)
+
+        # Remove worst checkpoint if we've reached save_top_k
+        # NOTE: We add 1 to save_top_k here because we have just saved a new checkpoint
+        if len(sorted_ckpts) + 1 > self.config._save_top_k_value:
+            # Get the sorted checkpoints again because now we have added a new checkpoint.
+            # We could optimize this by adding the new checkpoint to the sorted list,
+            # and then sorting it in place, but this is simpler.
+            sorted_ckpts = self._sorted_ckpts()
+            self._remove_checkpoints(
+                trainer, [p for _, p in sorted_ckpts[self.config.save_top_k :]]
+            )
+
+        # Create symlink to best model
+        _, best_ckpt_path = sorted_ckpts[0]
+        self._create_symlink(trainer, best_ckpt_path)
+        log.debug(f"Saved best checkpoint: {filepath}")
+
+        # Barrier to ensure all processes have saved the checkpoint before continuing
+        trainer.strategy.barrier()
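
A hedged configuration sketch for the new callback (not from the package): the field names follow the config class above, while the "val/loss" metric and the idea that MetricConfig only needs `name` and `mode` here are assumptions.

from nshtrainer.callbacks import BestCheckpointCallbackConfig
from nshtrainer.metrics._config import MetricConfig

best_ckpt = BestCheckpointCallbackConfig(
    # Assumed: MetricConfig(name=..., mode=...) is sufficient; "val/loss" is illustrative.
    metric=MetricConfig(name="val/loss", mode="min"),
    save_top_k=3,                   # keep the three best checkpoints on disk
    best_symlink_filename="best",   # also maintain a best.ckpt symlink
)

On each validation end the callback saves a best_*.ckpt file when the monitored value improves on the current top-k (or fewer than k exist), prunes checkpoints beyond save_top_k, and refreshes the best symlink.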

{nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py  (+4 -4)

@@ -86,15 +86,15 @@ class LatestEpochCheckpoint(Checkpoint):
         ckpt_paths = [p for p in ckpt_paths if p.name != latest_symlink_filename]
 
         # Sort by epoch, then step, then last modified
-
+        metadata_and_ckpt_paths = _sort_ckpts_by_metadata(
             ckpt_paths,
             key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
-
-            # ^ Called if metadata is not found on all checkpoints
+            reverse=True,
         )
 
         # Remove all but the latest k checkpoints
-
+        ckpts_to_remove = metadata_and_ckpt_paths[latest_k:]
+        self._remove_checkpoints(trainer, [p for _, p in ckpts_to_remove])
 
     def _save_new_checkpoint(self, trainer: Trainer):
         # Remove old checkpoints

{nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/metrics/_config.py  (+5 -0)

@@ -3,6 +3,8 @@ from typing import Literal
 
 import nshconfig as C
 
+from ..util._useful_types import SupportsRichComparisonT
+
 
 class MetricConfig(C.Config):
     name: str

@@ -35,3 +37,6 @@ class MetricConfig(C.Config):
     @property
     def best(self):
         return builtins.min if self.mode == "min" else builtins.max
+
+    def is_better(self, a: SupportsRichComparisonT, b: SupportsRichComparisonT) -> bool:
+        return self.best(a, b) == a
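
The new is_better helper simply asks whether best(a, b) picks a. A small sketch of the semantics with illustrative values (assuming `name` and `mode` are enough to construct a MetricConfig):

from nshtrainer.metrics._config import MetricConfig

loss = MetricConfig(name="val/loss", mode="min")
loss.is_better(0.12, 0.34)  # True: in "min" mode the smaller value wins
loss.is_better(0.50, 0.34)  # False

acc = MetricConfig(name="val/acc", mode="max")
acc.is_better(0.91, 0.88)   # True: in "max" mode the larger value wins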

{nshtrainer-0.11.2 → nshtrainer-0.11.4}/src/nshtrainer/model/config.py  (+4 -1)

@@ -37,6 +37,7 @@ from typing_extensions import Self, TypedDict, TypeVar, override
 
 from .._checkpoint.loader import CheckpointLoadingConfig
 from ..callbacks import (
+    BestCheckpointCallbackConfig,
     CallbackConfig,
     LatestEpochCheckpointCallbackConfig,
     ModelCheckpointCallbackConfig,

@@ -771,6 +772,7 @@ class ReproducibilityConfig(C.Config):
 
 CheckpointCallbackConfig: TypeAlias = Annotated[
     ModelCheckpointCallbackConfig
+    | BestCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig,
     C.Field(discriminator="name"),

@@ -782,7 +784,8 @@ class CheckpointSavingConfig(CallbackConfigBase):
     """Enable checkpoint saving."""
 
     checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
-        ModelCheckpointCallbackConfig(),
+        # ModelCheckpointCallbackConfig(),
+        BestCheckpointCallbackConfig(),
         LatestEpochCheckpointCallbackConfig(),
         OnExceptionCheckpointCallbackConfig(),
     ]
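
The default checkpoint_callbacks list now ships BestCheckpointCallbackConfig in place of ModelCheckpointCallbackConfig. A hedged sketch of opting back into the old default (constructing CheckpointSavingConfig directly like this, and its import path, are assumptions; the callback config names are as re-exported by nshtrainer.callbacks):

from nshtrainer.callbacks import (
    LatestEpochCheckpointCallbackConfig,
    ModelCheckpointCallbackConfig,
    OnExceptionCheckpointCallbackConfig,
)
from nshtrainer.model.config import CheckpointSavingConfig  # assumed import path

# Hypothetical override restoring the pre-0.11.4 callback set.
saving = CheckpointSavingConfig(
    checkpoint_callbacks=[
        ModelCheckpointCallbackConfig(),
        LatestEpochCheckpointCallbackConfig(),
        OnExceptionCheckpointCallbackConfig(),
    ]
)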

nshtrainer-0.11.4/src/nshtrainer/util/_useful_types.py  (new file, +307 -0)

@@ -0,0 +1,307 @@
+"""Credit to useful-types from https://github.com/hauntsaninja/useful_types"""
+
+from __future__ import annotations
+
+from collections.abc import Awaitable, Iterable, Iterator, Sequence, Sized
+from collections.abc import Set as AbstractSet
+from os import PathLike
+from typing import Any, TypeVar, overload
+
+from typing_extensions import Buffer, Literal, Protocol, SupportsIndex, TypeAlias
+
+_KT = TypeVar("_KT")
+_KT_co = TypeVar("_KT_co", covariant=True)
+_KT_contra = TypeVar("_KT_contra", contravariant=True)
+_VT = TypeVar("_VT")
+_VT_co = TypeVar("_VT_co", covariant=True)
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+_T_contra = TypeVar("_T_contra", contravariant=True)
+
+# For partially known annotations. Usually, fields where type annotations
+# haven't been added are left unannotated, but in some situations this
+# isn't possible or a type is already partially known. In cases like these,
+# use Incomplete instead of Any as a marker. For example, use
+# "Incomplete | None" instead of "Any | None".
+Incomplete: TypeAlias = Any
+
+
+class IdentityFunction(Protocol):
+    def __call__(self, __x: _T) -> _T: ...
+
+
+# ====================
+# Comparison protocols
+# ====================
+
+
+class SupportsDunderLT(Protocol[_T_contra]):
+    def __lt__(self, __other: _T_contra) -> bool: ...
+
+
+class SupportsDunderGT(Protocol[_T_contra]):
+    def __gt__(self, __other: _T_contra) -> bool: ...
+
+
+class SupportsDunderLE(Protocol[_T_contra]):
+    def __le__(self, __other: _T_contra) -> bool: ...
+
+
+class SupportsDunderGE(Protocol[_T_contra]):
+    def __ge__(self, __other: _T_contra) -> bool: ...
+
+
+class SupportsAllComparisons(
+    SupportsDunderLT[Any],
+    SupportsDunderGT[Any],
+    SupportsDunderLE[Any],
+    SupportsDunderGE[Any],
+    Protocol,
+): ...
+
+
+SupportsRichComparison: TypeAlias = SupportsDunderLT[Any] | SupportsDunderGT[Any]
+SupportsRichComparisonT = TypeVar(
+    "SupportsRichComparisonT", bound=SupportsRichComparison
+)
+
+# ====================
+# Dunder protocols
+# ====================
+
+
+class SupportsNext(Protocol[_T_co]):
+    def __next__(self) -> _T_co: ...
+
+
+class SupportsAnext(Protocol[_T_co]):
+    def __anext__(self) -> Awaitable[_T_co]: ...
+
+
+class SupportsAdd(Protocol[_T_contra, _T_co]):
+    def __add__(self, __x: _T_contra) -> _T_co: ...
+
+
+class SupportsRAdd(Protocol[_T_contra, _T_co]):
+    def __radd__(self, __x: _T_contra) -> _T_co: ...
+
+
+class SupportsSub(Protocol[_T_contra, _T_co]):
+    def __sub__(self, __x: _T_contra) -> _T_co: ...
+
+
+class SupportsRSub(Protocol[_T_contra, _T_co]):
+    def __rsub__(self, __x: _T_contra) -> _T_co: ...
+
+
+class SupportsDivMod(Protocol[_T_contra, _T_co]):
+    def __divmod__(self, __other: _T_contra) -> _T_co: ...
+
+
+class SupportsRDivMod(Protocol[_T_contra, _T_co]):
+    def __rdivmod__(self, __other: _T_contra) -> _T_co: ...
+
+
+# This protocol is generic over the iterator type, while Iterable is
+# generic over the type that is iterated over.
+class SupportsIter(Protocol[_T_co]):
+    def __iter__(self) -> _T_co: ...
+
+
+# This protocol is generic over the iterator type, while AsyncIterable is
+# generic over the type that is iterated over.
+class SupportsAiter(Protocol[_T_co]):
+    def __aiter__(self) -> _T_co: ...
+
+
+class SupportsLenAndGetItem(Protocol[_T_co]):
+    def __len__(self) -> int: ...
+    def __getitem__(self, __k: int) -> _T_co: ...
+
+
+class SupportsTrunc(Protocol):
+    def __trunc__(self) -> int: ...
+
+
+# ====================
+# Mapping-like protocols
+# ====================
+
+
+class SupportsItems(Protocol[_KT_co, _VT_co]):
+    def items(self) -> AbstractSet[tuple[_KT_co, _VT_co]]: ...
+
+
+class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
+    def keys(self) -> Iterable[_KT]: ...
+    def __getitem__(self, __key: _KT) -> _VT_co: ...
+
+
+class SupportsGetItem(Protocol[_KT_contra, _VT_co]):
+    def __contains__(self, __x: Any) -> bool: ...
+    def __getitem__(self, __key: _KT_contra) -> _VT_co: ...
+
+
+class SupportsItemAccess(SupportsGetItem[_KT_contra, _VT], Protocol[_KT_contra, _VT]):
+    def __setitem__(self, __key: _KT_contra, __value: _VT) -> None: ...
+    def __delitem__(self, __key: _KT_contra) -> None: ...
+
+
+# ====================
+# File handling
+# ====================
+
+StrPath: TypeAlias = str | PathLike[str]
+BytesPath: TypeAlias = bytes | PathLike[bytes]
+StrOrBytesPath: TypeAlias = str | bytes | PathLike[str] | PathLike[bytes]
+
+OpenTextModeUpdating: TypeAlias = Literal[
+    "r+",
+    "+r",
+    "rt+",
+    "r+t",
+    "+rt",
+    "tr+",
+    "t+r",
+    "+tr",
+    "w+",
+    "+w",
+    "wt+",
+    "w+t",
+    "+wt",
+    "tw+",
+    "t+w",
+    "+tw",
+    "a+",
+    "+a",
+    "at+",
+    "a+t",
+    "+at",
+    "ta+",
+    "t+a",
+    "+ta",
+    "x+",
+    "+x",
+    "xt+",
+    "x+t",
+    "+xt",
+    "tx+",
+    "t+x",
+    "+tx",
+]
+OpenTextModeWriting: TypeAlias = Literal[
+    "w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"
+]
+OpenTextModeReading: TypeAlias = Literal[
+    "r", "rt", "tr", "U", "rU", "Ur", "rtU", "rUt", "Urt", "trU", "tUr", "Utr"
+]
+OpenTextMode: TypeAlias = (
+    OpenTextModeUpdating | OpenTextModeWriting | OpenTextModeReading
+)
+OpenBinaryModeUpdating: TypeAlias = Literal[
+    "rb+",
+    "r+b",
+    "+rb",
+    "br+",
+    "b+r",
+    "+br",
+    "wb+",
+    "w+b",
+    "+wb",
+    "bw+",
+    "b+w",
+    "+bw",
+    "ab+",
+    "a+b",
+    "+ab",
+    "ba+",
+    "b+a",
+    "+ba",
+    "xb+",
+    "x+b",
+    "+xb",
+    "bx+",
+    "b+x",
+    "+bx",
+]
+OpenBinaryModeWriting: TypeAlias = Literal["wb", "bw", "ab", "ba", "xb", "bx"]
+OpenBinaryModeReading: TypeAlias = Literal[
+    "rb", "br", "rbU", "rUb", "Urb", "brU", "bUr", "Ubr"
+]
+OpenBinaryMode: TypeAlias = (
+    OpenBinaryModeUpdating | OpenBinaryModeReading | OpenBinaryModeWriting
+)
+
+
+class HasFileno(Protocol):
+    def fileno(self) -> int: ...
+
+
+FileDescriptor: TypeAlias = int
+FileDescriptorLike: TypeAlias = int | HasFileno
+FileDescriptorOrPath: TypeAlias = int | StrOrBytesPath
+
+
+class SupportsRead(Protocol[_T_co]):
+    def read(self, __length: int = ...) -> _T_co: ...
+
+
+class SupportsReadline(Protocol[_T_co]):
+    def readline(self, __length: int = ...) -> _T_co: ...
+
+
+class SupportsNoArgReadline(Protocol[_T_co]):
+    def readline(self) -> _T_co: ...
+
+
+class SupportsWrite(Protocol[_T_contra]):
+    def write(self, __s: _T_contra) -> object: ...
+
+
+# ====================
+# Buffer protocols
+# ====================
+
+# Unfortunately PEP 688 does not allow us to distinguish read-only
+# from writable buffers. We use these aliases for readability for now.
+# Perhaps a future extension of the buffer protocol will allow us to
+# distinguish these cases in the type system.
+ReadOnlyBuffer: TypeAlias = Buffer
+# Anything that implements the read-write buffer interface.
+WriteableBuffer: TypeAlias = Buffer
+# Same as WriteableBuffer, but also includes read-only buffer types (like bytes).
+ReadableBuffer: TypeAlias = Buffer
+
+
+class SliceableBuffer(Buffer, Protocol):
+    def __getitem__(self, __slice: slice) -> Sequence[int]: ...
+
+
+class IndexableBuffer(Buffer, Protocol):
+    def __getitem__(self, __i: int) -> int: ...
+
+
+class SupportsGetItemBuffer(SliceableBuffer, IndexableBuffer, Protocol):
+    def __contains__(self, __x: Any) -> bool: ...
+    @overload
+    def __getitem__(self, __slice: slice) -> Sequence[int]: ...
+    @overload
+    def __getitem__(self, __i: int) -> int: ...
+
+
+class SizedBuffer(Sized, Buffer, Protocol): ...
+
+
+# Source from https://github.com/python/typing/issues/256#issuecomment-1442633430
+# This works because str.__contains__ does not accept object (either in typeshed or at runtime)
+class SequenceNotStr(Protocol[_T_co]):
+    @overload
+    def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
+    @overload
+    def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
+    def __contains__(self, value: object, /) -> bool: ...
+    def __len__(self) -> int: ...
+    def __iter__(self) -> Iterator[_T_co]: ...
+    def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ...
+    def count(self, value: Any, /) -> int: ...
+    def __reversed__(self) -> Iterator[_T_co]: ...
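
The vendored module is pulled in mainly so that SupportsRichComparisonT (the same bound typeshed uses for min/max/sorted) can annotate helpers like MetricConfig.is_better. A small illustrative sketch, not from the package:

from nshtrainer.util._useful_types import SupportsRichComparisonT

def best_of(a: SupportsRichComparisonT, b: SupportsRichComparisonT) -> SupportsRichComparisonT:
    # min() is typed against the same bound, so any orderable type works
    # and the return type stays tied to the inputs.
    return min(a, b)

best_of(0.12, 0.34)  # -> 0.12
best_of("a", "b")    # -> "a"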