nshtrainer 0.22.1__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,4 @@
1
+ import logging
1
2
  import os
2
3
  import shutil
3
4
  from pathlib import Path
@@ -7,6 +8,8 @@ from lightning.pytorch import Trainer
7
8
  from ..util.path import get_relative_path
8
9
  from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
9
10
 
11
+ log = logging.getLogger(__name__)
12
+
10
13
 
11
14
  def _link_checkpoint(
12
15
  filepath: str | Path | os.PathLike,
@@ -19,11 +22,14 @@ def _link_checkpoint(
19
22
  linkpath = Path(linkpath)
20
23
 
21
24
  if remove_existing:
22
- if linkpath.exists():
23
- if linkpath.is_symlink() or linkpath.is_file():
24
- linkpath.unlink()
25
- elif linkpath.is_dir():
26
- shutil.rmtree(linkpath)
25
+ try:
26
+ if linkpath.exists():
27
+ if linkpath.is_dir():
28
+ shutil.rmtree(linkpath, ignore_errors=True)
29
+ else:
30
+ linkpath.unlink(missing_ok=True)
31
+ except Exception:
32
+ log.exception(f"Failed to remove {linkpath}")
27
33
 
28
34
  if metadata:
29
35
  _remove_checkpoint_metadata(linkpath)
@@ -11,6 +11,7 @@ from typing_extensions import TypeVar, override
11
11
 
12
12
  from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
13
13
  from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
14
+ from ...util.path import find_symlinks
14
15
  from ..base import CallbackConfigBase
15
16
 
16
17
  if TYPE_CHECKING:
@@ -116,9 +117,47 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
116
117
  )
117
118
  continue
118
119
 
119
- _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
120
+ self._remove_checkpoint_with_link_support(
121
+ trainer, old_ckpt_path, metadata=True
122
+ )
120
123
  log.debug(f"Removed old checkpoint: {old_ckpt_path}")
121
124
 
125
def _remove_checkpoint_with_link_support(
    self,
    trainer: Trainer,
    path: Path,
    metadata: bool,
):
    """Remove a checkpoint without breaking symlinks other callbacks hold to it.

    The directories of every *other* ``CheckpointBase`` callback registered
    on ``trainer`` are searched for symlinks that resolve to ``path``.

    * No symlink found: the checkpoint is removed via ``_remove_checkpoint``.
    * Symlinks found: the checkpoint file is moved onto the first symlink
      (which thereby becomes the real file) and every remaining symlink is
      re-created so that it points at the new location.

    Args:
        trainer: Trainer whose ``checkpoint_callbacks`` are inspected.
        path: Checkpoint file that should go away.
        metadata: Forwarded to ``_remove_checkpoint`` when no symlink exists.
    """
    # This callback's own directory is deliberately not searched
    # (``callback is not self``).
    other_callbacks: list[CheckpointBase] = [
        callback
        for callback in trainer.checkpoint_callbacks
        if isinstance(callback, CheckpointBase) and callback is not self
    ]
    symlink_paths = find_symlinks(
        path,
        *[callback.dirpath for callback in other_callbacks],
        glob_pattern=f"*.{self.extension()}",
    )

    # If there are no symlinks, just remove the checkpoint
    if not symlink_paths:
        _remove_checkpoint(trainer, path, metadata=metadata)
        return

    log.debug(
        f"Removing checkpoint with symlinks: {path}, symlinks: {symlink_paths}"
    )

    # NOTE(review): ``metadata`` is not consulted on this branch, so any
    # metadata stored next to ``path`` is left where it is -- confirm that
    # this is intended.

    # The first symlink is replaced by the checkpoint file itself.
    # ``Path.replace`` overwrites the existing link on every platform, whereas
    # ``Path.rename`` raises ``FileExistsError`` on Windows when the
    # destination already exists (and here it always does).
    new_target, *remaining_links = symlink_paths
    path.replace(new_target)
    log.debug(f"New symlink target: {new_target}")

    for symlink_path in remaining_links:
        # The same link can be reported twice when two callbacks share a
        # directory; re-linking it would operate on the file just moved there.
        if symlink_path == new_target:
            continue
        _link_checkpoint(new_target, symlink_path, metadata=False)
160
+
122
161
  def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
123
162
  current_metrics: dict[str, Any] = {
124
163
  "epoch": trainer.current_epoch,
@@ -441,6 +441,7 @@ class Trainer(LightningTrainer):
441
441
  log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
442
442
  if self.is_global_zero:
443
443
  _link_checkpoint(cached_path, filepath, metadata=False)
444
+ self.strategy.barrier("Trainer.save_checkpoint")
444
445
  else:
445
446
  super().save_checkpoint(filepath, weights_only, storage_options)
446
447
 
nshtrainer/util/path.py CHANGED
@@ -27,3 +27,26 @@ def get_relative_path(source: _Path, destination: _Path):
27
27
  down = os.sep.join(destination_parts[i:])
28
28
 
29
29
  return Path(os.path.normpath(os.path.join(up, down)))
30
+
31
+
32
def find_symlinks(
    target_file: _Path,
    *search_directories: _Path,
    glob_pattern: str = "*",
) -> list[Path]:
    """Return the symlinks below ``search_directories`` that point to ``target_file``.

    Every directory is searched recursively with ``Path.rglob(glob_pattern)``.
    An entry is reported when it is a symlink whose resolved target is the
    same file as ``target_file``.

    Args:
        target_file: File the symlinks must resolve to.
        *search_directories: Directories to search recursively.
        glob_pattern: Pattern handed to ``Path.rglob``; defaults to ``"*"``.

    Returns:
        The matching symlink paths in discovery order, each link at most once
        even when the search directories repeat or overlap.
    """
    target_file = Path(target_file).resolve()
    symlinks: list[Path] = []
    # Identity of links already reported: resolved parent directory + link
    # name, so that different spellings of a directory collapse to one entry
    # without following the link itself.
    seen: set[Path] = set()

    for search_directory in search_directories:
        search_directory = Path(search_directory)
        for path in search_directory.rglob(glob_pattern):
            if not path.is_symlink():
                continue
            try:
                link_id = path.parent.resolve() / path.name
                if link_id in seen:
                    continue
                points_to_target = path.resolve().samefile(target_file)
            except (OSError, RuntimeError):
                # Broken symlinks raise FileNotFoundError (an OSError).
                # Symlink loops raise RuntimeError from resolve() on older
                # Python versions and OSError from samefile() on newer ones.
                # Neither can point at the target, so skip them.
                continue
            if points_to_target:
                seen.add(link_id)
                symlinks.append(path)

    return symlinks
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.22.1
3
+ Version: 0.23.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -2,7 +2,7 @@ nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
2
  nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
3
3
  nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=E4tfiGzhnn65X95P0Y6K2d_YfPWqvHZoF0FF1-smEJc,5221
5
- nshtrainer/_checkpoint/saver.py,sha256=6W-Rbc3QGuhcF_mcwN8v31uEjLQCsZvt8CPuqPs4m5g,1342
5
+ nshtrainer/_checkpoint/saver.py,sha256=fvRKGI5aeXtsHBOIO4cwGe__wmO-6DiD0-744VASYA4,1500
6
6
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
7
7
  nshtrainer/_hf_hub.py,sha256=iqhXH54RhSqmot_K3UCVcHTC_TC81_YY7cwvHGHXXlw,16782
8
8
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
@@ -10,7 +10,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
11
11
  nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
13
- nshtrainer/callbacks/checkpoint/_base.py,sha256=r6IPpl3sGUmxBNv80y9r326lTrPAIVSU3Fu-3LrYH2s,6691
13
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=zKN-6n61aGze-Hf8MBY1Surh6B-xDwNSApqQJtPcTUs,8048
14
14
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
15
15
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -78,15 +78,15 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
78
78
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
79
79
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
80
80
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
81
- nshtrainer/trainer/trainer.py,sha256=KXsvAhgVgYjmYfoqzH_qoQXqd6nVx7-vs9ObQJpwbIk,19140
81
+ nshtrainer/trainer/trainer.py,sha256=Leh3ADxoYsRWlJFIW20netohLcKx0XxUrRhD9LM4jws,19201
82
82
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
83
83
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
84
84
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
85
- nshtrainer/util/path.py,sha256=A_Ocag3_hbwns_zAxFDlH-5eVHWFlcy2DKxHQ7jddvk,837
85
+ nshtrainer/util/path.py,sha256=WbPWXpu5LIDocQihQC3-72qxN1sa6-d1kPOmKDR-NC8,1520
86
86
  nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.22.1.dist-info/METADATA,sha256=XF3QXKbeAN7I5vYHNbjExlV_6CF8QgPqPYFsCxs52rA,935
91
- nshtrainer-0.22.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.22.1.dist-info/RECORD,,
90
+ nshtrainer-0.23.0.dist-info/METADATA,sha256=wkbqsz6A4d0h1u-8CCZwfYYmqLm7YjirdnS-fTA-mkI,935
91
+ nshtrainer-0.23.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.23.0.dist-info/RECORD,,