nshtrainer 0.26.0__py3-none-any.whl → 0.26.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/metadata.py +2 -17
- nshtrainer/_checkpoint/saver.py +2 -8
- nshtrainer/_hf_hub.py +48 -24
- nshtrainer/util/path.py +34 -0
- {nshtrainer-0.26.0.dist-info → nshtrainer-0.26.2.dist-info}/METADATA +1 -1
- {nshtrainer-0.26.0.dist-info → nshtrainer-0.26.2.dist-info}/RECORD +7 -7
- {nshtrainer-0.26.0.dist-info → nshtrainer-0.26.2.dist-info}/WHEEL +0 -0
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
import datetime
|
|
3
3
|
import logging
|
|
4
|
-
import shutil
|
|
5
4
|
from collections.abc import Callable
|
|
6
5
|
from pathlib import Path
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
|
@@ -11,7 +10,7 @@ import numpy as np
|
|
|
11
10
|
import torch
|
|
12
11
|
|
|
13
12
|
from ..util._environment_info import EnvironmentConfig
|
|
14
|
-
from ..util.path import compute_file_checksum, get_relative_path
|
|
13
|
+
from ..util.path import compute_file_checksum, try_symlink_or_copy
|
|
15
14
|
|
|
16
15
|
if TYPE_CHECKING:
|
|
17
16
|
from ..model import BaseConfig, LightningModuleBase
|
|
@@ -142,21 +141,7 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
|
|
|
142
141
|
# Link the metadata files to the new checkpoint
|
|
143
142
|
path = _metadata_path(checkpoint_path)
|
|
144
143
|
linked_path = _metadata_path(linked_checkpoint_path)
|
|
145
|
-
|
|
146
|
-
try:
|
|
147
|
-
# linked_path.symlink_to(path)
|
|
148
|
-
# We should store the path as a relative path
|
|
149
|
-
# to the metadata file to avoid issues with
|
|
150
|
-
# moving the checkpoint directory
|
|
151
|
-
linked_path.symlink_to(get_relative_path(linked_path, path))
|
|
152
|
-
except OSError:
|
|
153
|
-
# on Windows, special permissions are required to create symbolic links as a regular user
|
|
154
|
-
# fall back to copying the file
|
|
155
|
-
shutil.copy(path, linked_path)
|
|
156
|
-
except Exception:
|
|
157
|
-
log.exception(f"Failed to link {path} to {linked_path}")
|
|
158
|
-
else:
|
|
159
|
-
log.debug(f"Linked {path} to {linked_path}")
|
|
144
|
+
try_symlink_or_copy(path, linked_path)
|
|
160
145
|
|
|
161
146
|
|
|
162
147
|
def _sort_ckpts_by_metadata(
|
nshtrainer/_checkpoint/saver.py
CHANGED
|
@@ -5,7 +5,7 @@ from pathlib import Path
|
|
|
5
5
|
|
|
6
6
|
from lightning.pytorch import Trainer
|
|
7
7
|
|
|
8
|
-
from ..util.path import get_relative_path
|
|
8
|
+
from ..util.path import try_symlink_or_copy
|
|
9
9
|
from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
|
|
10
10
|
|
|
11
11
|
log = logging.getLogger(__name__)
|
|
@@ -34,13 +34,7 @@ def _link_checkpoint(
|
|
|
34
34
|
if metadata:
|
|
35
35
|
_remove_checkpoint_metadata(linkpath)
|
|
36
36
|
|
|
37
|
-
|
|
38
|
-
linkpath.symlink_to(get_relative_path(linkpath, filepath))
|
|
39
|
-
except OSError:
|
|
40
|
-
# on Windows, special permissions are required to create symbolic links as a regular user
|
|
41
|
-
# fall back to copying the file
|
|
42
|
-
shutil.copy(filepath, linkpath)
|
|
43
|
-
|
|
37
|
+
try_symlink_or_copy(filepath, linkpath)
|
|
44
38
|
if metadata:
|
|
45
39
|
_link_checkpoint_metadata(filepath, linkpath)
|
|
46
40
|
|
nshtrainer/_hf_hub.py
CHANGED
|
@@ -188,27 +188,19 @@ class HFHubCallback(NTCallbackBase):
|
|
|
188
188
|
|
|
189
189
|
self.config = config
|
|
190
190
|
|
|
191
|
-
self._repo_id = None
|
|
191
|
+
self._repo_id: str | None = None
|
|
192
192
|
self._checksum_to_path_in_repo: dict[str, Path] = {}
|
|
193
193
|
|
|
194
194
|
@override
|
|
195
195
|
def setup(self, trainer, pl_module, stage):
|
|
196
|
-
from .trainer.trainer import Trainer
|
|
197
|
-
|
|
198
|
-
if not isinstance(trainer, Trainer):
|
|
199
|
-
raise ValueError(
|
|
200
|
-
f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
|
|
201
|
-
)
|
|
202
|
-
|
|
203
196
|
root_config = cast("BaseConfig", pl_module.hparams)
|
|
197
|
+
self._repo_id = _repo_name(self.api, root_config)
|
|
198
|
+
|
|
199
|
+
if not self.config or not trainer.is_global_zero:
|
|
200
|
+
return
|
|
204
201
|
|
|
205
202
|
# Create the repository, if it doesn't exist
|
|
206
|
-
self.api.create_repo(
|
|
207
|
-
repo_id=_repo_name(self.api, root_config),
|
|
208
|
-
repo_type="model",
|
|
209
|
-
private=self.config.auto_create.private,
|
|
210
|
-
exist_ok=True,
|
|
211
|
-
)
|
|
203
|
+
self._create_repo_if_not_exists()
|
|
212
204
|
|
|
213
205
|
# Upload the config and code
|
|
214
206
|
self._save_config(root_config)
|
|
@@ -216,17 +208,22 @@ class HFHubCallback(NTCallbackBase):
|
|
|
216
208
|
|
|
217
209
|
@override
|
|
218
210
|
def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
|
|
219
|
-
root_config = cast("BaseConfig", pl_module.hparams)
|
|
220
|
-
|
|
221
211
|
# If HF Hub is enabled, then we upload
|
|
222
|
-
if
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
212
|
+
if (
|
|
213
|
+
not self.config
|
|
214
|
+
or not self.config.save_checkpoints
|
|
215
|
+
or not trainer.is_global_zero
|
|
216
|
+
):
|
|
217
|
+
return
|
|
218
|
+
|
|
219
|
+
with self._with_error_handling("save checkpoints"):
|
|
220
|
+
root_config = cast("BaseConfig", pl_module.hparams)
|
|
221
|
+
self._save_checkpoint(
|
|
222
|
+
_Upload.from_local_path(ckpt_path, root_config),
|
|
223
|
+
_Upload.from_local_path(metadata_path, root_config)
|
|
224
|
+
if metadata_path is not None
|
|
225
|
+
else None,
|
|
226
|
+
)
|
|
230
227
|
|
|
231
228
|
@cached_property
|
|
232
229
|
def api(self):
|
|
@@ -241,6 +238,33 @@ class HFHubCallback(NTCallbackBase):
|
|
|
241
238
|
raise ValueError("Repository id has not been initialized.")
|
|
242
239
|
return self._repo_id
|
|
243
240
|
|
|
241
|
+
def _create_repo_if_not_exists(self):
|
|
242
|
+
if not self.config or not self.config.auto_create:
|
|
243
|
+
return
|
|
244
|
+
|
|
245
|
+
# Create the repository, if it doesn't exist
|
|
246
|
+
with self._with_error_handling("create repository"):
|
|
247
|
+
from huggingface_hub.utils import RepositoryNotFoundError
|
|
248
|
+
|
|
249
|
+
try:
|
|
250
|
+
# Check if the repository exists
|
|
251
|
+
self.api.repo_info(repo_id=self.repo_id, repo_type="model")
|
|
252
|
+
log.info(f"Repository '{self.repo_id}' already exists.")
|
|
253
|
+
except RepositoryNotFoundError:
|
|
254
|
+
# Repository doesn't exist, so create it
|
|
255
|
+
try:
|
|
256
|
+
self.api.create_repo(
|
|
257
|
+
repo_id=self.repo_id,
|
|
258
|
+
repo_type="model",
|
|
259
|
+
private=self.config.auto_create.private,
|
|
260
|
+
exist_ok=True,
|
|
261
|
+
)
|
|
262
|
+
log.info(f"Created new repository '{self.repo_id}'.")
|
|
263
|
+
except Exception:
|
|
264
|
+
log.exception(f"Failed to create repository '{self.repo_id}'")
|
|
265
|
+
except Exception:
|
|
266
|
+
log.exception(f"Error checking repository '{self.repo_id}'")
|
|
267
|
+
|
|
244
268
|
def _save_config(self, root_config: "BaseConfig"):
|
|
245
269
|
with self._with_error_handling("upload config"):
|
|
246
270
|
self.api.upload_file(
|
nshtrainer/util/path.py
CHANGED
|
@@ -1,8 +1,13 @@
|
|
|
1
1
|
import hashlib
|
|
2
|
+
import logging
|
|
2
3
|
import os
|
|
4
|
+
import platform
|
|
5
|
+
import shutil
|
|
3
6
|
from pathlib import Path
|
|
4
7
|
from typing import TypeAlias
|
|
5
8
|
|
|
9
|
+
log = logging.getLogger(__name__)
|
|
10
|
+
|
|
6
11
|
_Path: TypeAlias = str | Path | os.PathLike
|
|
7
12
|
|
|
8
13
|
|
|
@@ -68,3 +73,32 @@ def compute_file_checksum(file_path: Path) -> str:
|
|
|
68
73
|
for byte_block in iter(lambda: f.read(4096), b""):
|
|
69
74
|
sha256_hash.update(byte_block)
|
|
70
75
|
return sha256_hash.hexdigest()
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def try_symlink_or_copy(
|
|
79
|
+
file_path: Path,
|
|
80
|
+
link_path: Path,
|
|
81
|
+
target_is_directory: bool = False,
|
|
82
|
+
relative: bool = True,
|
|
83
|
+
):
|
|
84
|
+
"""
|
|
85
|
+
Symlinks on Unix, copies on Windows.
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
symlink_target = get_relative_path(link_path, file_path) if relative else file_path
|
|
89
|
+
try:
|
|
90
|
+
if platform.system() == "Windows":
|
|
91
|
+
if target_is_directory:
|
|
92
|
+
shutil.copytree(file_path, link_path)
|
|
93
|
+
else:
|
|
94
|
+
shutil.copy(file_path, link_path)
|
|
95
|
+
else:
|
|
96
|
+
link_path.symlink_to(
|
|
97
|
+
symlink_target, target_is_directory=target_is_directory
|
|
98
|
+
)
|
|
99
|
+
except Exception:
|
|
100
|
+
log.exception(f"Failed to create symlink or copy {file_path} to {link_path}")
|
|
101
|
+
return False
|
|
102
|
+
else:
|
|
103
|
+
log.debug(f"Created symlink or copied {file_path} to {link_path}")
|
|
104
|
+
return True
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
|
|
2
2
|
nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
|
|
3
3
|
nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
|
|
4
|
-
nshtrainer/_checkpoint/metadata.py,sha256=
|
|
5
|
-
nshtrainer/_checkpoint/saver.py,sha256=
|
|
4
|
+
nshtrainer/_checkpoint/metadata.py,sha256=hxZwwsUKVbBtt4wjqcKZbObx0PuO-qCdF3BTdnyqaQo,4711
|
|
5
|
+
nshtrainer/_checkpoint/saver.py,sha256=1loCDYDy_Cay37uKs_wvxnkwvr41WMmga85qefct80Q,1271
|
|
6
6
|
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
|
|
7
|
-
nshtrainer/_hf_hub.py,sha256=
|
|
7
|
+
nshtrainer/_hf_hub.py,sha256=v1DV3Vn6Pdwr2KYI7yL_Xv1dyJLaG7yZhKFWZFy3QFk,13087
|
|
8
8
|
nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
|
|
9
9
|
nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
|
|
10
10
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
@@ -82,11 +82,11 @@ nshtrainer/trainer/trainer.py,sha256=Zwdcqfmrr7yuonsp4VrNOget8wkaZY9lf-_yeJ94lkk
|
|
|
82
82
|
nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
|
|
83
83
|
nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
|
|
84
84
|
nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
|
|
85
|
-
nshtrainer/util/path.py,sha256=
|
|
85
|
+
nshtrainer/util/path.py,sha256=jAEjF1qp8Aii32L5lWG4UFgVyQAFkHOMYEc_TC2hDx8,2947
|
|
86
86
|
nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
87
87
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
88
88
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
89
89
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
90
|
-
nshtrainer-0.26.
|
|
91
|
-
nshtrainer-0.26.
|
|
92
|
-
nshtrainer-0.26.
|
|
90
|
+
nshtrainer-0.26.2.dist-info/METADATA,sha256=d3vWjdB9FT6fbWJPkyBI-4M18ekg2WcJmJiKasExchM,916
|
|
91
|
+
nshtrainer-0.26.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
92
|
+
nshtrainer-0.26.2.dist-info/RECORD,,
|
|
File without changes
|