nshtrainer 0.26.0__py3-none-any.whl → 0.26.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,6 @@
1
1
  import copy
2
2
  import datetime
3
3
  import logging
4
- import shutil
5
4
  from collections.abc import Callable
6
5
  from pathlib import Path
7
6
  from typing import TYPE_CHECKING, Any, ClassVar, cast
@@ -11,7 +10,7 @@ import numpy as np
11
10
  import torch
12
11
 
13
12
  from ..util._environment_info import EnvironmentConfig
14
- from ..util.path import compute_file_checksum, get_relative_path
13
+ from ..util.path import compute_file_checksum, try_symlink_or_copy
15
14
 
16
15
  if TYPE_CHECKING:
17
16
  from ..model import BaseConfig, LightningModuleBase
@@ -142,21 +141,7 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
142
141
  # Link the metadata files to the new checkpoint
143
142
  path = _metadata_path(checkpoint_path)
144
143
  linked_path = _metadata_path(linked_checkpoint_path)
145
- try:
146
- try:
147
- # linked_path.symlink_to(path)
148
- # We should store the path as a relative path
149
- # to the metadata file to avoid issues with
150
- # moving the checkpoint directory
151
- linked_path.symlink_to(get_relative_path(linked_path, path))
152
- except OSError:
153
- # on Windows, special permissions are required to create symbolic links as a regular user
154
- # fall back to copying the file
155
- shutil.copy(path, linked_path)
156
- except Exception:
157
- log.exception(f"Failed to link {path} to {linked_path}")
158
- else:
159
- log.debug(f"Linked {path} to {linked_path}")
144
+ try_symlink_or_copy(path, linked_path)
160
145
 
161
146
 
162
147
  def _sort_ckpts_by_metadata(
@@ -5,7 +5,7 @@ from pathlib import Path
5
5
 
6
6
  from lightning.pytorch import Trainer
7
7
 
8
- from ..util.path import get_relative_path
8
+ from ..util.path import try_symlink_or_copy
9
9
  from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
10
10
 
11
11
  log = logging.getLogger(__name__)
@@ -34,13 +34,7 @@ def _link_checkpoint(
34
34
  if metadata:
35
35
  _remove_checkpoint_metadata(linkpath)
36
36
 
37
- try:
38
- linkpath.symlink_to(get_relative_path(linkpath, filepath))
39
- except OSError:
40
- # on Windows, special permissions are required to create symbolic links as a regular user
41
- # fall back to copying the file
42
- shutil.copy(filepath, linkpath)
43
-
37
+ try_symlink_or_copy(filepath, linkpath)
44
38
  if metadata:
45
39
  _link_checkpoint_metadata(filepath, linkpath)
46
40
 
nshtrainer/_hf_hub.py CHANGED
@@ -188,27 +188,19 @@ class HFHubCallback(NTCallbackBase):
188
188
 
189
189
  self.config = config
190
190
 
191
- self._repo_id = None
191
+ self._repo_id: str | None = None
192
192
  self._checksum_to_path_in_repo: dict[str, Path] = {}
193
193
 
194
194
  @override
195
195
  def setup(self, trainer, pl_module, stage):
196
- from .trainer.trainer import Trainer
197
-
198
- if not isinstance(trainer, Trainer):
199
- raise ValueError(
200
- f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
201
- )
202
-
203
196
  root_config = cast("BaseConfig", pl_module.hparams)
197
+ self._repo_id = _repo_name(self.api, root_config)
198
+
199
+ if not self.config or not trainer.is_global_zero:
200
+ return
204
201
 
205
202
  # Create the repository, if it doesn't exist
206
- self._repo_id = self.api.create_repo(
207
- repo_id=_repo_name(self.api, root_config),
208
- repo_type="model",
209
- private=self.config.auto_create.private,
210
- exist_ok=True,
211
- )
203
+ self._create_repo_if_not_exists()
212
204
 
213
205
  # Upload the config and code
214
206
  self._save_config(root_config)
@@ -216,17 +208,22 @@ class HFHubCallback(NTCallbackBase):
216
208
 
217
209
  @override
218
210
  def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
219
- root_config = cast("BaseConfig", pl_module.hparams)
220
-
221
211
  # If HF Hub is enabled, then we upload
222
- if self.config and trainer.is_global_zero:
223
- with self._with_error_handling("save checkpoints"):
224
- self._save_checkpoint(
225
- _Upload.from_local_path(ckpt_path, root_config),
226
- _Upload.from_local_path(metadata_path, root_config)
227
- if metadata_path is not None
228
- else None,
229
- )
212
+ if (
213
+ not self.config
214
+ or not self.config.save_checkpoints
215
+ or not trainer.is_global_zero
216
+ ):
217
+ return
218
+
219
+ with self._with_error_handling("save checkpoints"):
220
+ root_config = cast("BaseConfig", pl_module.hparams)
221
+ self._save_checkpoint(
222
+ _Upload.from_local_path(ckpt_path, root_config),
223
+ _Upload.from_local_path(metadata_path, root_config)
224
+ if metadata_path is not None
225
+ else None,
226
+ )
230
227
 
231
228
  @cached_property
232
229
  def api(self):
@@ -241,6 +238,33 @@ class HFHubCallback(NTCallbackBase):
241
238
  raise ValueError("Repository id has not been initialized.")
242
239
  return self._repo_id
243
240
 
241
+ def _create_repo_if_not_exists(self):
242
+ if not self.config or not self.config.auto_create:
243
+ return
244
+
245
+ # Create the repository, if it doesn't exist
246
+ with self._with_error_handling("create repository"):
247
+ from huggingface_hub.utils import RepositoryNotFoundError
248
+
249
+ try:
250
+ # Check if the repository exists
251
+ self.api.repo_info(repo_id=self.repo_id, repo_type="model")
252
+ log.info(f"Repository '{self.repo_id}' already exists.")
253
+ except RepositoryNotFoundError:
254
+ # Repository doesn't exist, so create it
255
+ try:
256
+ self.api.create_repo(
257
+ repo_id=self.repo_id,
258
+ repo_type="model",
259
+ private=self.config.auto_create.private,
260
+ exist_ok=True,
261
+ )
262
+ log.info(f"Created new repository '{self.repo_id}'.")
263
+ except Exception:
264
+ log.exception(f"Failed to create repository '{self.repo_id}'")
265
+ except Exception:
266
+ log.exception(f"Error checking repository '{self.repo_id}'")
267
+
244
268
  def _save_config(self, root_config: "BaseConfig"):
245
269
  with self._with_error_handling("upload config"):
246
270
  self.api.upload_file(
nshtrainer/util/path.py CHANGED
@@ -1,8 +1,13 @@
1
1
  import hashlib
2
+ import logging
2
3
  import os
4
+ import platform
5
+ import shutil
3
6
  from pathlib import Path
4
7
  from typing import TypeAlias
5
8
 
9
+ log = logging.getLogger(__name__)
10
+
6
11
  _Path: TypeAlias = str | Path | os.PathLike
7
12
 
8
13
 
@@ -68,3 +73,32 @@ def compute_file_checksum(file_path: Path) -> str:
68
73
  for byte_block in iter(lambda: f.read(4096), b""):
69
74
  sha256_hash.update(byte_block)
70
75
  return sha256_hash.hexdigest()
76
+
77
+
78
+ def try_symlink_or_copy(
79
+ file_path: Path,
80
+ link_path: Path,
81
+ target_is_directory: bool = False,
82
+ relative: bool = True,
83
+ ):
84
+ """
85
+ Symlinks on Unix, copies on Windows.
86
+ """
87
+
88
+ symlink_target = get_relative_path(link_path, file_path) if relative else file_path
89
+ try:
90
+ if platform.system() == "Windows":
91
+ if target_is_directory:
92
+ shutil.copytree(file_path, link_path)
93
+ else:
94
+ shutil.copy(file_path, link_path)
95
+ else:
96
+ link_path.symlink_to(
97
+ symlink_target, target_is_directory=target_is_directory
98
+ )
99
+ except Exception:
100
+ log.exception(f"Failed to create symlink or copy {file_path} to {link_path}")
101
+ return False
102
+ else:
103
+ log.debug(f"Created symlink or copied {file_path} to {link_path}")
104
+ return True
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.26.0
3
+ Version: 0.26.2
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,10 +1,10 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
2
  nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
3
3
  nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
4
- nshtrainer/_checkpoint/metadata.py,sha256=BpxC3VGrgSDHvZpc40A2icjBEDvRNGEsTis9YkFY8Kc,5341
5
- nshtrainer/_checkpoint/saver.py,sha256=fvRKGI5aeXtsHBOIO4cwGe__wmO-6DiD0-744VASYA4,1500
4
+ nshtrainer/_checkpoint/metadata.py,sha256=hxZwwsUKVbBtt4wjqcKZbObx0PuO-qCdF3BTdnyqaQo,4711
5
+ nshtrainer/_checkpoint/saver.py,sha256=1loCDYDy_Cay37uKs_wvxnkwvr41WMmga85qefct80Q,1271
6
6
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
7
- nshtrainer/_hf_hub.py,sha256=0K3uWa8hd2KyGuUYM7OXARcA7vuUiWWGSlP2USysY7o,12066
7
+ nshtrainer/_hf_hub.py,sha256=v1DV3Vn6Pdwr2KYI7yL_Xv1dyJLaG7yZhKFWZFy3QFk,13087
8
8
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
9
9
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
@@ -82,11 +82,11 @@ nshtrainer/trainer/trainer.py,sha256=Zwdcqfmrr7yuonsp4VrNOget8wkaZY9lf-_yeJ94lkk
82
82
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
83
83
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
84
84
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
85
- nshtrainer/util/path.py,sha256=RUkIOrlj9b8zPPXE3JLhdihBNitJSUWw1whZ33u-2Yk,2005
85
+ nshtrainer/util/path.py,sha256=jAEjF1qp8Aii32L5lWG4UFgVyQAFkHOMYEc_TC2hDx8,2947
86
86
  nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.26.0.dist-info/METADATA,sha256=YBlbpalQ3BX8UBF_5SHk_F7v9Nq3JMsqVf6MoqH8KzU,916
91
- nshtrainer-0.26.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.26.0.dist-info/RECORD,,
90
+ nshtrainer-0.26.2.dist-info/METADATA,sha256=d3vWjdB9FT6fbWJPkyBI-4M18ekg2WcJmJiKasExchM,916
91
+ nshtrainer-0.26.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.26.2.dist-info/RECORD,,