nshtrainer 0.22.1__tar.gz → 0.23.0__tar.gz
This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/PKG-INFO +1 -1
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/pyproject.toml +1 -1
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/_checkpoint/saver.py +11 -5
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/checkpoint/_base.py +40 -1
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/trainer/trainer.py +1 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/util/path.py +23 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/README.md +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/model/config.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/util/typing_utils.py +0 -0
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import os
|
|
2
3
|
import shutil
|
|
3
4
|
from pathlib import Path
|
|
@@ -7,6 +8,8 @@ from lightning.pytorch import Trainer
|
|
|
7
8
|
from ..util.path import get_relative_path
|
|
8
9
|
from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
|
|
9
10
|
|
|
11
|
+
log = logging.getLogger(__name__)
|
|
12
|
+
|
|
10
13
|
|
|
11
14
|
def _link_checkpoint(
|
|
12
15
|
filepath: str | Path | os.PathLike,
|
|
@@ -19,11 +22,14 @@ def _link_checkpoint(
|
|
|
19
22
|
linkpath = Path(linkpath)
|
|
20
23
|
|
|
21
24
|
if remove_existing:
|
|
22
|
-
|
|
23
|
-
if linkpath.
|
|
24
|
-
linkpath.
|
|
25
|
-
|
|
26
|
-
|
|
25
|
+
try:
|
|
26
|
+
if linkpath.exists():
|
|
27
|
+
if linkpath.is_dir():
|
|
28
|
+
shutil.rmtree(linkpath, ignore_errors=True)
|
|
29
|
+
else:
|
|
30
|
+
linkpath.unlink(missing_ok=True)
|
|
31
|
+
except Exception:
|
|
32
|
+
log.exception(f"Failed to remove {linkpath}")
|
|
27
33
|
|
|
28
34
|
if metadata:
|
|
29
35
|
_remove_checkpoint_metadata(linkpath)
|
|
@@ -11,6 +11,7 @@ from typing_extensions import TypeVar, override
|
|
|
11
11
|
|
|
12
12
|
from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
|
|
13
13
|
from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
|
|
14
|
+
from ...util.path import find_symlinks
|
|
14
15
|
from ..base import CallbackConfigBase
|
|
15
16
|
|
|
16
17
|
if TYPE_CHECKING:
|
|
@@ -116,9 +117,47 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
116
117
|
)
|
|
117
118
|
continue
|
|
118
119
|
|
|
119
|
-
|
|
120
|
+
self._remove_checkpoint_with_link_support(
|
|
121
|
+
trainer, old_ckpt_path, metadata=True
|
|
122
|
+
)
|
|
120
123
|
log.debug(f"Removed old checkpoint: {old_ckpt_path}")
|
|
121
124
|
|
|
125
|
+
def _remove_checkpoint_with_link_support(
|
|
126
|
+
self,
|
|
127
|
+
trainer: Trainer,
|
|
128
|
+
path: Path,
|
|
129
|
+
metadata: bool,
|
|
130
|
+
):
|
|
131
|
+
# Find all the symlinks to the checkpoint
|
|
132
|
+
ckpt_callbacks: list[CheckpointBase] = [
|
|
133
|
+
callback
|
|
134
|
+
for callback in trainer.checkpoint_callbacks
|
|
135
|
+
if isinstance(callback, CheckpointBase) and callback is not self
|
|
136
|
+
]
|
|
137
|
+
symlink_paths = find_symlinks(
|
|
138
|
+
path,
|
|
139
|
+
*[callback.dirpath for callback in ckpt_callbacks],
|
|
140
|
+
glob_pattern=f"*.{self.extension()}",
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# If there are no symlinks, just remove the checkpoint
|
|
144
|
+
if not symlink_paths:
|
|
145
|
+
_remove_checkpoint(trainer, path, metadata=metadata)
|
|
146
|
+
return
|
|
147
|
+
|
|
148
|
+
log.debug(
|
|
149
|
+
f"Removing checkpoint with symlinks: {path}, symlinks: {symlink_paths}"
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
# For the first symlink, we can just move the checkpoint file
|
|
153
|
+
# to the symlink path. For the rest, we need to make new symlinks.
|
|
154
|
+
new_target = symlink_paths.pop(0)
|
|
155
|
+
path.rename(new_target)
|
|
156
|
+
log.debug(f"New symlink target: {new_target}")
|
|
157
|
+
|
|
158
|
+
for symlink_path in symlink_paths:
|
|
159
|
+
_link_checkpoint(new_target, symlink_path, metadata=False)
|
|
160
|
+
|
|
122
161
|
def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
|
|
123
162
|
current_metrics: dict[str, Any] = {
|
|
124
163
|
"epoch": trainer.current_epoch,
|
|
@@ -441,6 +441,7 @@ class Trainer(LightningTrainer):
|
|
|
441
441
|
log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
|
|
442
442
|
if self.is_global_zero:
|
|
443
443
|
_link_checkpoint(cached_path, filepath, metadata=False)
|
|
444
|
+
self.strategy.barrier("Trainer.save_checkpoint")
|
|
444
445
|
else:
|
|
445
446
|
super().save_checkpoint(filepath, weights_only, storage_options)
|
|
446
447
|
|
|
@@ -27,3 +27,26 @@ def get_relative_path(source: _Path, destination: _Path):
|
|
|
27
27
|
down = os.sep.join(destination_parts[i:])
|
|
28
28
|
|
|
29
29
|
return Path(os.path.normpath(os.path.join(up, down)))
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def find_symlinks(
|
|
33
|
+
target_file: _Path,
|
|
34
|
+
*search_directories: _Path,
|
|
35
|
+
glob_pattern: str = "*",
|
|
36
|
+
):
|
|
37
|
+
target_file = Path(target_file).resolve()
|
|
38
|
+
symlinks: list[Path] = []
|
|
39
|
+
|
|
40
|
+
for search_directory in search_directories:
|
|
41
|
+
search_directory = Path(search_directory)
|
|
42
|
+
for path in search_directory.rglob(glob_pattern):
|
|
43
|
+
if path.is_symlink():
|
|
44
|
+
try:
|
|
45
|
+
link_target = path.resolve()
|
|
46
|
+
if link_target.samefile(target_file):
|
|
47
|
+
symlinks.append(path)
|
|
48
|
+
except FileNotFoundError:
|
|
49
|
+
# Handle broken symlinks
|
|
50
|
+
pass
|
|
51
|
+
|
|
52
|
+
return symlinks
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
RENAMED
|
File without changes
|
{nshtrainer-0.22.1 → nshtrainer-0.23.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|