nshtrainer 0.22.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/metadata.py +2 -1
- nshtrainer/_checkpoint/saver.py +11 -5
- nshtrainer/callbacks/checkpoint/_base.py +40 -1
- nshtrainer/trainer/trainer.py +3 -1
- nshtrainer/util/path.py +23 -0
- {nshtrainer-0.22.0.dist-info → nshtrainer-0.23.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.22.0.dist-info → nshtrainer-0.23.0.dist-info}/RECORD +8 -8
- {nshtrainer-0.22.0.dist-info → nshtrainer-0.23.0.dist-info}/WHEEL +0 -0
nshtrainer/_checkpoint/metadata.py
CHANGED
|
@@ -11,6 +11,7 @@ import numpy as np
|
|
|
11
11
|
import torch
|
|
12
12
|
|
|
13
13
|
from ..util._environment_info import EnvironmentConfig
|
|
14
|
+
from ..util.path import get_relative_path
|
|
14
15
|
|
|
15
16
|
if TYPE_CHECKING:
|
|
16
17
|
from ..model import BaseConfig, LightningModuleBase
|
|
@@ -145,7 +146,7 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
|
|
|
145
146
|
# We should store the path as a relative path
|
|
146
147
|
# to the metadata file to avoid issues with
|
|
147
148
|
# moving the checkpoint directory
|
|
148
|
-
linked_path.symlink_to(
|
|
149
|
+
linked_path.symlink_to(get_relative_path(linked_path, path))
|
|
149
150
|
except OSError:
|
|
150
151
|
# on Windows, special permissions are required to create symbolic links as a regular user
|
|
151
152
|
# fall back to copying the file
|
nshtrainer/_checkpoint/saver.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import os
|
|
2
3
|
import shutil
|
|
3
4
|
from pathlib import Path
|
|
@@ -7,6 +8,8 @@ from lightning.pytorch import Trainer
|
|
|
7
8
|
from ..util.path import get_relative_path
|
|
8
9
|
from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
|
|
9
10
|
|
|
11
|
+
log = logging.getLogger(__name__)
|
|
12
|
+
|
|
10
13
|
|
|
11
14
|
def _link_checkpoint(
|
|
12
15
|
filepath: str | Path | os.PathLike,
|
|
@@ -19,11 +22,14 @@ def _link_checkpoint(
|
|
|
19
22
|
linkpath = Path(linkpath)
|
|
20
23
|
|
|
21
24
|
if remove_existing:
|
|
22
|
-
|
|
23
|
-
if linkpath.
|
|
24
|
-
linkpath.
|
|
25
|
-
|
|
26
|
-
|
|
25
|
+
try:
|
|
26
|
+
if linkpath.exists():
|
|
27
|
+
if linkpath.is_dir():
|
|
28
|
+
shutil.rmtree(linkpath, ignore_errors=True)
|
|
29
|
+
else:
|
|
30
|
+
linkpath.unlink(missing_ok=True)
|
|
31
|
+
except Exception:
|
|
32
|
+
log.exception(f"Failed to remove {linkpath}")
|
|
27
33
|
|
|
28
34
|
if metadata:
|
|
29
35
|
_remove_checkpoint_metadata(linkpath)
|
nshtrainer/callbacks/checkpoint/_base.py
CHANGED
|
@@ -11,6 +11,7 @@ from typing_extensions import TypeVar, override
|
|
|
11
11
|
|
|
12
12
|
from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
|
|
13
13
|
from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
|
|
14
|
+
from ...util.path import find_symlinks
|
|
14
15
|
from ..base import CallbackConfigBase
|
|
15
16
|
|
|
16
17
|
if TYPE_CHECKING:
|
|
@@ -116,9 +117,47 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
116
117
|
)
|
|
117
118
|
continue
|
|
118
119
|
|
|
119
|
-
|
|
120
|
+
self._remove_checkpoint_with_link_support(
|
|
121
|
+
trainer, old_ckpt_path, metadata=True
|
|
122
|
+
)
|
|
120
123
|
log.debug(f"Removed old checkpoint: {old_ckpt_path}")
|
|
121
124
|
|
|
125
|
+
def _remove_checkpoint_with_link_support(
|
|
126
|
+
self,
|
|
127
|
+
trainer: Trainer,
|
|
128
|
+
path: Path,
|
|
129
|
+
metadata: bool,
|
|
130
|
+
):
|
|
131
|
+
# Find all the symlinks to the checkpoint
|
|
132
|
+
ckpt_callbacks: list[CheckpointBase] = [
|
|
133
|
+
callback
|
|
134
|
+
for callback in trainer.checkpoint_callbacks
|
|
135
|
+
if isinstance(callback, CheckpointBase) and callback is not self
|
|
136
|
+
]
|
|
137
|
+
symlink_paths = find_symlinks(
|
|
138
|
+
path,
|
|
139
|
+
*[callback.dirpath for callback in ckpt_callbacks],
|
|
140
|
+
glob_pattern=f"*.{self.extension()}",
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# If there are no symlinks, just remove the checkpoint
|
|
144
|
+
if not symlink_paths:
|
|
145
|
+
_remove_checkpoint(trainer, path, metadata=metadata)
|
|
146
|
+
return
|
|
147
|
+
|
|
148
|
+
log.debug(
|
|
149
|
+
f"Removing checkpoint with symlinks: {path}, symlinks: {symlink_paths}"
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
# For the first symlink, we can just move the checkpoint file
|
|
153
|
+
# to the symlink path. For the rest, we need to make new symlinks.
|
|
154
|
+
new_target = symlink_paths.pop(0)
|
|
155
|
+
path.rename(new_target)
|
|
156
|
+
log.debug(f"New symlink target: {new_target}")
|
|
157
|
+
|
|
158
|
+
for symlink_path in symlink_paths:
|
|
159
|
+
_link_checkpoint(new_target, symlink_path, metadata=False)
|
|
160
|
+
|
|
122
161
|
def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
|
|
123
162
|
current_metrics: dict[str, Any] = {
|
|
124
163
|
"epoch": trainer.current_epoch,
|
nshtrainer/trainer/trainer.py
CHANGED
|
@@ -439,7 +439,9 @@ class Trainer(LightningTrainer):
|
|
|
439
439
|
):
|
|
440
440
|
# If we have a cached path, then we symlink it to the new path.
|
|
441
441
|
log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
|
|
442
|
-
|
|
442
|
+
if self.is_global_zero:
|
|
443
|
+
_link_checkpoint(cached_path, filepath, metadata=False)
|
|
444
|
+
self.strategy.barrier("Trainer.save_checkpoint")
|
|
443
445
|
else:
|
|
444
446
|
super().save_checkpoint(filepath, weights_only, storage_options)
|
|
445
447
|
|
nshtrainer/util/path.py
CHANGED
|
@@ -27,3 +27,26 @@ def get_relative_path(source: _Path, destination: _Path):
|
|
|
27
27
|
down = os.sep.join(destination_parts[i:])
|
|
28
28
|
|
|
29
29
|
return Path(os.path.normpath(os.path.join(up, down)))
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def find_symlinks(
|
|
33
|
+
target_file: _Path,
|
|
34
|
+
*search_directories: _Path,
|
|
35
|
+
glob_pattern: str = "*",
|
|
36
|
+
):
|
|
37
|
+
target_file = Path(target_file).resolve()
|
|
38
|
+
symlinks: list[Path] = []
|
|
39
|
+
|
|
40
|
+
for search_directory in search_directories:
|
|
41
|
+
search_directory = Path(search_directory)
|
|
42
|
+
for path in search_directory.rglob(glob_pattern):
|
|
43
|
+
if path.is_symlink():
|
|
44
|
+
try:
|
|
45
|
+
link_target = path.resolve()
|
|
46
|
+
if link_target.samefile(target_file):
|
|
47
|
+
symlinks.append(path)
|
|
48
|
+
except FileNotFoundError:
|
|
49
|
+
# Handle broken symlinks
|
|
50
|
+
pass
|
|
51
|
+
|
|
52
|
+
return symlinks
|
{nshtrainer-0.22.0.dist-info → nshtrainer-0.23.0.dist-info}/RECORD
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
|
|
2
2
|
nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
|
|
3
3
|
nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
|
|
4
|
-
nshtrainer/_checkpoint/metadata.py,sha256=
|
|
5
|
-
nshtrainer/_checkpoint/saver.py,sha256=
|
|
4
|
+
nshtrainer/_checkpoint/metadata.py,sha256=E4tfiGzhnn65X95P0Y6K2d_YfPWqvHZoF0FF1-smEJc,5221
|
|
5
|
+
nshtrainer/_checkpoint/saver.py,sha256=fvRKGI5aeXtsHBOIO4cwGe__wmO-6DiD0-744VASYA4,1500
|
|
6
6
|
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
|
|
7
7
|
nshtrainer/_hf_hub.py,sha256=iqhXH54RhSqmot_K3UCVcHTC_TC81_YY7cwvHGHXXlw,16782
|
|
8
8
|
nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
|
|
@@ -10,7 +10,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
|
|
|
10
10
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
11
11
|
nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
|
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
|
|
13
|
-
nshtrainer/callbacks/checkpoint/_base.py,sha256=
|
|
13
|
+
nshtrainer/callbacks/checkpoint/_base.py,sha256=zKN-6n61aGze-Hf8MBY1Surh6B-xDwNSApqQJtPcTUs,8048
|
|
14
14
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
|
|
15
15
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
|
|
16
16
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
|
|
@@ -78,15 +78,15 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
|
|
|
78
78
|
nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
|
|
79
79
|
nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
|
|
80
80
|
nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
|
|
81
|
-
nshtrainer/trainer/trainer.py,sha256=
|
|
81
|
+
nshtrainer/trainer/trainer.py,sha256=Leh3ADxoYsRWlJFIW20netohLcKx0XxUrRhD9LM4jws,19201
|
|
82
82
|
nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
|
|
83
83
|
nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
|
|
84
84
|
nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
|
|
85
|
-
nshtrainer/util/path.py,sha256=
|
|
85
|
+
nshtrainer/util/path.py,sha256=WbPWXpu5LIDocQihQC3-72qxN1sa6-d1kPOmKDR-NC8,1520
|
|
86
86
|
nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
87
87
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
88
88
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
89
89
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
90
|
-
nshtrainer-0.
|
|
91
|
-
nshtrainer-0.
|
|
92
|
-
nshtrainer-0.
|
|
90
|
+
nshtrainer-0.23.0.dist-info/METADATA,sha256=wkbqsz6A4d0h1u-8CCZwfYYmqLm7YjirdnS-fTA-mkI,935
|
|
91
|
+
nshtrainer-0.23.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
92
|
+
nshtrainer-0.23.0.dist-info/RECORD,,
|
|
{nshtrainer-0.22.0.dist-info → nshtrainer-0.23.0.dist-info}/WHEEL
File without changes
|