nshtrainer 0.10.9__tar.gz → 0.10.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/PKG-INFO +1 -1
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/pyproject.toml +1 -1
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/_checkpoint/metadata.py +73 -0
- nshtrainer-0.10.11/src/nshtrainer/_checkpoint/saver.py +52 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/early_stopping.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/finite_checks.py +2 -2
- nshtrainer-0.10.11/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +114 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/model_checkpoint.py +20 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/norm_logging.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/throughput_monitor.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/wandb_watch.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/data/balanced_batch_sampler.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/model/base.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/model/config.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/callback.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/debug.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/rlp_sanity_checks.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/shared_parameters.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/util/environment.py +2 -2
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/util/seed.py +2 -2
- nshtrainer-0.10.9/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +0 -74
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/README.md +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/model/_environment.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.10.9 → nshtrainer-0.10.11}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,8 @@
 import copy
 import datetime
 import logging
+import shutil
+from collections.abc import Callable
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
 
@@ -100,3 +102,74 @@ def _write_checkpoint_metadata(
         log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
     else:
         log.info(f"Checkpoint metadata written to {checkpoint_path}")
+
+
+def _remove_checkpoint_metadata(checkpoint_path: Path):
+    for path in (
+        checkpoint_path.with_suffix(METADATA_PATH_SUFFIX),
+        checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX),
+    ):
+        try:
+            path.unlink(missing_ok=True)
+        except Exception as e:
+            log.warning(f"Failed to remove {path}: {e}")
+        else:
+            log.info(f"Removed {path}")
+
+
+def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
+    # First, remove any existing metadata files
+    _remove_checkpoint_metadata(linked_checkpoint_path)
+
+    # Link the metadata files to the new checkpoint
+    for path in (
+        checkpoint_path.with_suffix(METADATA_PATH_SUFFIX),
+        checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX),
+    ):
+        linked_path = linked_checkpoint_path.with_suffix(path.suffix)
+        try:
+            try:
+                linked_path.symlink_to(path)
+            except OSError:
+                # on Windows, special permissions are required to create symbolic links as a regular user
+                # fall back to copying the file
+                shutil.copy(path, linked_path)
+        except Exception as e:
+            log.warning(f"Failed to link {path} to {linked_path}: {e}")
+        else:
+            log.info(f"Linked {path} to {linked_path}")
+
+
+def _checkpoint_sort_key_fn(key: Callable[[CheckpointMetadata, Path], Any]):
+    def sort_key_fn(checkpoint_path: Path):
+        if not (p := checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)).exists():
+            raise FileNotFoundError(f"Metadata file not found: {p}")
+
+        nonlocal key
+        return key(CheckpointMetadata.from_file(p), p)
+
+    return sort_key_fn
+
+
+def _sort_ckpts_by_metadata(
+    checkpoint_paths: list[Path],
+    key: Callable[[CheckpointMetadata, Path], Any],
+    fallback_key: Callable[[Path], Any],
+):
+    # First, let's make sure all the metadata files exist.
+    # If not, use the fallback function to sort the checkpoints.
+    no_metadata_paths: list[Path] = []
+    for path in checkpoint_paths:
+        if (path.with_suffix(METADATA_PATH_SUFFIX)).exists():
+            continue
+
+        no_metadata_paths.append(path)
+
+    if no_metadata_paths:
+        log.warning(
+            f"Metadata file not found on {len(no_metadata_paths)} checkpoints: {no_metadata_paths}\n"
+            "Falling back to sorting by last modified time."
+        )
+        return sorted(checkpoint_paths, key=fallback_key)
+
+    return sorted(checkpoint_paths, key=_checkpoint_sort_key_fn(key))
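The sorting helpers above rank checkpoint files by their sidecar metadata, falling back to modification time when any metadata file is missing. A minimal usage sketch, assuming `CheckpointMetadata` exposes `epoch` and `global_step` fields (as the callback code later in this diff implies); the directory name is hypothetical:

```python
from pathlib import Path

from nshtrainer._checkpoint.metadata import _sort_ckpts_by_metadata

# Sort checkpoints oldest -> newest by (epoch, global_step).
ckpt_paths = list(Path("checkpoints").glob("latest_*.ckpt"))
sorted_paths = _sort_ckpts_by_metadata(
    ckpt_paths,
    key=lambda meta, p: (meta.epoch, meta.global_step),
    fallback_key=lambda p: p.stat().st_mtime,  # used if any metadata file is missing
)
newest = sorted_paths[-1] if sorted_paths else None
```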
@@ -0,0 +1,52 @@
+import os
+import shutil
+from pathlib import Path
+
+from lightning.pytorch import Trainer
+
+from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
+
+
+def _link_checkpoint(
+    trainer: Trainer,
+    filepath: str | Path | os.PathLike,
+    linkpath: str | Path | os.PathLike,
+    *,
+    barrier: bool,
+    metadata: bool,
+):
+    if not isinstance(filepath, Path):
+        filepath = Path(filepath)
+    if not isinstance(linkpath, Path):
+        linkpath = Path(linkpath)
+
+    if trainer.is_global_zero:
+        if linkpath.exists():
+            if linkpath.is_symlink() or linkpath.is_file():
+                linkpath.unlink()
+            elif linkpath.is_dir():
+                shutil.rmtree(linkpath)
+        _remove_checkpoint_metadata(linkpath)
+
+        try:
+            target_path = filepath.relative_to(linkpath.parent)
+            linkpath.symlink_to(target_path)
+        except OSError:
+            # on Windows, special permissions are required to create symbolic links as a regular user
+            # fall back to copying the file
+            shutil.copy(filepath, linkpath)
+
+        _link_checkpoint_metadata(filepath, linkpath)
+    if barrier:
+        trainer.strategy.barrier()
+
+
+def _remove_checkpoint(
+    trainer: Trainer,
+    filepath: str | Path | os.PathLike,
+    remove_metadata: bool = True,
+):
+    if not isinstance(filepath, Path):
+        filepath = Path(filepath)
+    trainer.strategy.remove_checkpoint(filepath)
+    _remove_checkpoint_metadata(filepath)
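The new saver centralizes the symlink-or-copy pattern used throughout this release: create a relative symlink where the platform allows it, copy where symlink creation fails (e.g., unprivileged Windows users). The pattern in isolation, as a minimal sketch (the helper name is mine, not part of the package):

```python
import shutil
from pathlib import Path


def link_or_copy(target: Path, link: Path) -> None:
    """Point `link` at `target` via a relative symlink; copy if symlinks fail.

    Assumes both paths live under the same parent directory, matching how
    _link_checkpoint resolves filepath.relative_to(linkpath.parent).
    """
    link.unlink(missing_ok=True)
    try:
        # A relative target keeps the link valid if the parent directory moves.
        link.symlink_to(target.relative_to(link.parent))
    except OSError:
        # e.g. Windows without symlink privileges
        shutil.copy(target, link)
```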
@@ -1,5 +1,5 @@
+import logging
 import math
-from logging import getLogger
 
 from lightning.fabric.utilities.rank_zero import _get_rank
 from lightning.pytorch import Trainer
@@ -7,7 +7,7 @@ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
 from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
 from typing_extensions import override
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 class EarlyStopping(_EarlyStopping):
@@ -1,4 +1,4 @@
-
+import logging
 from typing import Literal
 
 import torch
@@ -7,7 +7,7 @@ from typing_extensions import override
 
 from .base import CallbackConfigBase
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def finite_checks(
@@ -0,0 +1,114 @@
+import logging
+from pathlib import Path
+from typing import Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import override
+
+from .._checkpoint.metadata import _sort_ckpts_by_metadata
+from .._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from .base import CallbackConfigBase
+
+log = logging.getLogger(__name__)
+
+
+class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
+    name: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
+
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str = "epoch{epoch:02d}_step{step:04d}"
+    """Checkpoint filename. This must not include the extension."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    latest_symlink_filename: str | None = "latest"
+    """Filename for the latest symlink. If None, no symlink will be created."""
+
+    latest_k: int | Literal["all"] = 1
+    """Number of latest checkpoints to keep. If "all", all checkpoints are kept."""
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+        dirpath = Path(dirpath)
+
+        yield LatestEpochCheckpoint(self, dirpath)
+
+
+class LatestEpochCheckpoint(Checkpoint):
+    PREFIX = "latest_"
+    EXTENSION = ".ckpt"
+
+    def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
+        super().__init__()
+
+        self.config = config
+        self.dirpath = dirpath
+
+    @override
+    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+        self._save_new_checkpoint(trainer)
+
+    def _latest_symlink_filename(self):
+        if (filename := self.config.latest_symlink_filename) is None:
+            return None
+        return f"{filename}{self.EXTENSION}"
+
+    def _ckpt_path(self, trainer: Trainer):
+        filename = self.config.filename.format(
+            epoch=trainer.current_epoch, step=trainer.global_step
+        )
+        filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
+        return self.dirpath / filename
+
+    def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
+        for ckpt_path in ckpt_paths:
+            _remove_checkpoint(trainer, ckpt_path, remove_metadata=True)
+
+    def _remove_old_checkpoints(self, trainer: Trainer):
+        if (latest_k := self.config.latest_k) == "all":
+            return
+
+        # Get all checkpoint files
+        ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
+        # Ignore the latest symlink
+        if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
+            ckpt_paths = [p for p in ckpt_paths if p.name != latest_symlink_filename]
+
+        # Sort by epoch, then step, then last modified
+        ckpt_paths = _sort_ckpts_by_metadata(
+            ckpt_paths,
+            key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
+            fallback_key=lambda p: p.stat().st_mtime,
+            # ^ Called if metadata is not found on all checkpoints
+        )
+
+        # Remove all but the latest k checkpoints
+        ckpts_to_remove = ckpt_paths[:-latest_k]
+        self._remove_checkpoints(trainer, ckpts_to_remove)
+
+    def _save_new_checkpoint(self, trainer: Trainer):
+        # Remove old checkpoints
+        self._remove_old_checkpoints(trainer)
+
+        # Save the new checkpoint
+        filepath = self._ckpt_path(trainer)
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)
+
+        # Create the latest symlink
+        if (symlink_filename := self._latest_symlink_filename()) is not None:
+            symlink_path = self.dirpath / symlink_filename
+            _link_checkpoint(
+                trainer,
+                filepath,
+                symlink_path,
+                barrier=True,
+                metadata=True,
+            )
+            log.info(f"Created latest symlink: {symlink_path}")
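A hedged sketch of how the rewritten callback could be configured; the field names come straight from the diff, while the instantiation style and the `root_config` wiring that `create_callbacks` expects are assumptions:

```python
from nshtrainer.callbacks.latest_epoch_checkpoint import (
    LatestEpochCheckpointCallbackConfig,
)

# Keep the two most recent per-epoch checkpoints and maintain a
# "latest.ckpt" symlink pointing at the newest one.
config = LatestEpochCheckpointCallbackConfig(
    dirpath="checkpoints",
    latest_k=2,
    latest_symlink_filename="latest",
)
```

Note the pruning arithmetic: with five sorted checkpoints and `latest_k=2`, `ckpt_paths[:-2]` selects the three oldest for removal, and pruning runs before each new save.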
@@ -1,21 +1,23 @@
+import logging
 import re
 from datetime import timedelta
-from logging import getLogger
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal
 
+from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks.model_checkpoint import (
     ModelCheckpoint as _ModelCheckpoint,
 )
 from typing_extensions import override
 
+from .._checkpoint.saver import _link_checkpoint, _remove_checkpoint
 from ..metrics import MetricConfig
 from .base import CallbackConfigBase
 
 if TYPE_CHECKING:
     from ..model.config import BaseConfig
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def _convert_string(input_string: str):
@@ -158,6 +160,8 @@ class ModelCheckpointCallbackConfig(CallbackConfigBase):
 
 
 class ModelCheckpoint(_ModelCheckpoint):
+    CHECKPOINT_NAME_LAST = "best"
+
     @override
     def __init__(
         self,
@@ -185,3 +189,17 @@ class ModelCheckpoint(_ModelCheckpoint):
             save_on_train_epoch_end=self.config.save_on_train_epoch_end,
             enable_version_counter=self.config.enable_version_counter,
         )
+
+    @override
+    def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str):  # pyright: ignore[reportIncompatibleMethodOverride]
+        return _link_checkpoint(
+            trainer,
+            filepath,
+            linkpath,
+            barrier=True,
+            metadata=True,
+        )
+
+    @override
+    def _remove_checkpoint(self, trainer: Trainer, filepath: str):
+        return _remove_checkpoint(trainer, filepath, remove_metadata=True)
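`CHECKPOINT_NAME_LAST` is an upstream Lightning class attribute that names the `save_last` checkpoint; the subclass repoints it from `"last"` to `"best"`, and the two overrides route linking and removal through the metadata-aware helpers added above. A quick check of the upstream default:

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Upstream default is "last"; nshtrainer's subclass overrides it to "best",
# so the save_last link is written as best.ckpt instead of last.ckpt.
print(ModelCheckpoint.CHECKPOINT_NAME_LAST)  # -> "last"
```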
@@ -1,4 +1,4 @@
-
+import logging
 from typing import Literal, cast
 
 import torch
@@ -9,7 +9,7 @@ from typing_extensions import override
 
 from .base import CallbackConfigBase
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def grad_norm(
@@ -1,4 +1,4 @@
-
+import logging
 from typing import Any, Literal, Protocol, TypedDict, cast, runtime_checkable
 
 from typing_extensions import NotRequired, override
@@ -6,7 +6,7 @@ from typing_extensions import NotRequired, override
 from ._throughput_monitor_callback import ThroughputMonitor as _ThroughputMonitor
 from .base import CallbackConfigBase
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 class ThroughputMonitorBatchStats(TypedDict):
@@ -1,4 +1,4 @@
-
+import logging
 from typing import Literal, Protocol, cast, runtime_checkable
 
 import torch.nn as nn
@@ -9,7 +9,7 @@ from typing_extensions import override
 
 from .base import CallbackConfigBase
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 @runtime_checkable
@@ -1,6 +1,6 @@
 import heapq
+import logging
 from functools import cached_property
-from logging import getLogger
 from typing import Any, Protocol, runtime_checkable
 
 import numpy as np
@@ -10,7 +10,7 @@ from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
 from torch.utils.data import BatchSampler, Dataset, DistributedSampler
 from typing_extensions import override
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def _all_gather(tensor: torch.Tensor, device: torch.device | None = None):
@@ -1,7 +1,7 @@
 import inspect
+import logging
 from abc import ABC, abstractmethod
 from collections.abc import MutableMapping
-from logging import getLogger
 from typing import IO, TYPE_CHECKING, Any, Generic, cast
 
 import torch
@@ -21,7 +21,7 @@ from .modules.profiler import ProfilerMixin
 from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
 from .modules.shared_parameters import SharedParametersModuleMixin
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)
 
@@ -1,4 +1,5 @@
 import copy
+import logging
 import os
 import string
 import time
@@ -6,7 +7,6 @@ import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
 from datetime import timedelta
-from logging import getLogger
 from pathlib import Path
 from typing import (
     Annotated,
@@ -46,7 +46,7 @@ from ..callbacks.base import CallbackConfigBase
 from ..metrics import MetricConfig
 from ._environment import EnvironmentConfig
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 class IdSeedWarning(Warning):
@@ -1,6 +1,6 @@
+import logging
 from collections import abc
 from collections.abc import Callable, Iterable
-from logging import getLogger
 from typing import Any, TypeAlias, cast, final
 
 from lightning.pytorch import Callback, LightningModule
@@ -9,7 +9,7 @@ from typing_extensions import override
 
 from ...util.typing_utils import mixin_base_type
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
 
@@ -1,5 +1,5 @@
+import logging
 from collections.abc import Mapping
-from logging import getLogger
 from typing import cast
 
 import torch
@@ -14,7 +14,7 @@ from ...util.typing_utils import mixin_base_type
 from ..config import BaseConfig
 from .callback import CallbackModuleMixin
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule):
@@ -1,5 +1,5 @@
+import logging
 from collections.abc import Sequence
-from logging import getLogger
 from typing import cast
 
 import torch.nn as nn
@@ -10,7 +10,7 @@ from ...util.typing_utils import mixin_base_type
 from ..config import BaseConfig
 from .callback import CallbackRegistrarModuleMixin
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def _parameters_to_names(parameters: Sequence[nn.Parameter], model: nn.Module):
@@ -1,74 +0,0 @@
-import logging
-from pathlib import Path
-from typing import Literal
-
-from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.callbacks import Checkpoint
-from typing_extensions import override
-
-from .base import CallbackConfigBase
-
-log = logging.getLogger(__name__)
-
-
-class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
-    name: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
-    """Checkpoint filename. This must not include the extension."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    latest_symlink_filename: str | None = "latest.ckpt"
-    """Filename for the latest symlink. If None, no symlink will be created."""
-
-    @override
-    def create_callbacks(self, root_config):
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-        dirpath = Path(dirpath)
-
-        yield LatestEpochCheckpoint(self, dirpath)
-
-
-class LatestEpochCheckpoint(Checkpoint):
-    def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
-        super().__init__()
-
-        self.config = config
-        self.dirpath = dirpath
-
-    def _ckpt_path(self, trainer: Trainer):
-        return self.dirpath / self.config.filename.format(
-            epoch=trainer.current_epoch, step=trainer.global_step
-        )
-
-    @override
-    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
-        # Save the new checkpoint
-        filepath = self._ckpt_path(trainer)
-        trainer.save_checkpoint(filepath, self.config.save_weights_only)
-
-        # Create the latest symlink
-        if (
-            trainer.is_global_zero
-            and (symlink_filename := self.config.latest_symlink_filename) is not None
-        ):
-            symlink_path = self.dirpath / symlink_filename
-            symlink_path.unlink(missing_ok=True)
-            symlink_path.symlink_to(filepath.name)
-            log.info(f"Created latest symlink: {symlink_path}")
-
-    def latest_checkpoint(self):
-        if (symlink_filename := self.config.latest_symlink_filename) is None:
-            return None
-
-        if not (symlink_path := self.dirpath / symlink_filename).exists():
-            return None
-
-        return symlink_path