nshtrainer 0.22.1__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,4 @@
1
+ import logging
1
2
  import os
2
3
  import shutil
3
4
  from pathlib import Path
@@ -7,6 +8,8 @@ from lightning.pytorch import Trainer
7
8
  from ..util.path import get_relative_path
8
9
  from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
9
10
 
11
+ log = logging.getLogger(__name__)
12
+
10
13
 
11
14
  def _link_checkpoint(
12
15
  filepath: str | Path | os.PathLike,
@@ -19,11 +22,14 @@ def _link_checkpoint(
19
22
  linkpath = Path(linkpath)
20
23
 
21
24
  if remove_existing:
22
- if linkpath.exists():
23
- if linkpath.is_symlink() or linkpath.is_file():
24
- linkpath.unlink()
25
- elif linkpath.is_dir():
26
- shutil.rmtree(linkpath)
25
+ try:
26
+ if linkpath.exists():
27
+ if linkpath.is_dir():
28
+ shutil.rmtree(linkpath, ignore_errors=True)
29
+ else:
30
+ linkpath.unlink(missing_ok=True)
31
+ except Exception:
32
+ log.exception(f"Failed to remove {linkpath}")
27
33
 
28
34
  if metadata:
29
35
  _remove_checkpoint_metadata(linkpath)
nshtrainer/_hf_hub.py CHANGED
@@ -2,6 +2,7 @@ import io
2
2
  import logging
3
3
  import os
4
4
  import re
5
+ from functools import cached_property
5
6
  from pathlib import Path
6
7
  from typing import TYPE_CHECKING, Any, cast
7
8
 
@@ -110,6 +111,7 @@ def _api(token: str | None = None):
110
111
 
111
112
  def _enabled_and_valid(
112
113
  trainer: "Trainer",
114
+ callback: "HFHubCallback",
113
115
  config: HuggingFaceHubConfig,
114
116
  *,
115
117
  rank_zero_only: bool,
@@ -120,22 +122,9 @@ def _enabled_and_valid(
120
122
 
121
123
  # If `rank_zero_only` and this is not rank 0, stop here.
122
124
  if rank_zero_only and not trainer.is_global_zero:
123
- return
124
-
125
- # Make sure that `huggingface_hub` is installed
126
- try:
127
- import huggingface_hub # noqa: F401
128
- except ImportError:
129
- log.exception(
130
- "Could not import `huggingface_hub`. Please install it using `pip install huggingface_hub`."
131
- )
132
125
  return None
133
126
 
134
- # Create and authenticate the API instance
135
- if (api := getattr(trainer, "_hf_hub_api", None)) is None:
136
- api = _api(config.token)
137
- setattr(trainer, "_hf_hub_api", api)
138
- return cast(huggingface_hub.HfApi, api)
127
+ return callback._hf_hub_api
139
128
 
140
129
 
141
130
  def _repo_name(api: "HfApi", root_config: "BaseConfig"):
@@ -173,11 +162,12 @@ def _repo_name(api: "HfApi", root_config: "BaseConfig"):
173
162
  return f"{username}/{repo_name}"
174
163
 
175
164
 
176
- def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
165
+ def _init(*, trainer: "Trainer", callback: "HFHubCallback", root_config: "BaseConfig"):
177
166
  config = root_config.trainer.hf_hub
178
167
  if (
179
168
  api := _enabled_and_valid(
180
169
  trainer,
170
+ callback,
181
171
  config,
182
172
  rank_zero_only=True,
183
173
  )
@@ -210,10 +200,10 @@ def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
210
200
  log.exception(f"Error checking repository '{repo_name}'")
211
201
 
212
202
  # Upload the config
213
- _save_config(root_config, trainer=trainer)
203
+ _save_config(root_config, trainer=trainer, callback=callback)
214
204
 
215
205
  # Upload the code
216
- _save_code(repo_name, config=config, trainer=trainer)
206
+ _save_code(repo_name, config=config, trainer=trainer, callback=callback)
217
207
 
218
208
 
219
209
  def _save_code(
@@ -221,10 +211,12 @@ def _save_code(
221
211
  *,
222
212
  config: HuggingFaceHubConfig,
223
213
  trainer: "Trainer",
214
+ callback: "HFHubCallback",
224
215
  ):
225
216
  if (
226
217
  api := _enabled_and_valid(
227
218
  trainer,
219
+ callback,
228
220
  config,
229
221
  rank_zero_only=True,
230
222
  )
@@ -266,11 +258,13 @@ def _save_config(
266
258
  root_config: "BaseConfig",
267
259
  *,
268
260
  trainer: "Trainer",
261
+ callback: "HFHubCallback",
269
262
  ):
270
263
  config = root_config.trainer.hf_hub
271
264
  if (
272
265
  api := _enabled_and_valid(
273
266
  trainer,
267
+ callback,
274
268
  config,
275
269
  rank_zero_only=True,
276
270
  )
@@ -298,15 +292,45 @@ def _save_config(
298
292
  log.exception(f"Failed to upload config.json to repository '{repo_name}'.")
299
293
 
300
294
 
295
+ def _is_link(p: Path, trainer: "Trainer"):
296
+ if p.is_symlink():
297
+ return True
298
+
299
+ try:
300
+ link_path = trainer._nshtrainer_ckpt_link(p)
301
+ return link_path in trainer._nshtrainer_checkpoint_link_dict
302
+ except Exception:
303
+ log.info(f"Failed to check if path {p} is a symlink.", exc_info=True)
304
+
305
+ return False
306
+
307
+
308
+ def _resolve_link(p: Path, trainer: "Trainer"):
309
+ if p.is_symlink():
310
+ return p.resolve()
311
+
312
+ try:
313
+ link_path = trainer._nshtrainer_ckpt_link(p)
314
+ if (
315
+ resolved := trainer._nshtrainer_checkpoint_link_dict.get(link_path)
316
+ ) is not None:
317
+ return resolved
318
+ except Exception:
319
+ log.info(f"Failed to resolve symlink for path {p}.", exc_info=True)
320
+
321
+ return None
322
+
323
+
301
324
  def _save_checkpoint_files(
302
325
  trainer: "Trainer",
326
+ callback: "HFHubCallback",
303
327
  paths: list[Path],
304
328
  *,
305
329
  root_config: "BaseConfig",
306
330
  ):
307
331
  config = root_config.trainer.hf_hub
308
332
  if (
309
- api := _enabled_and_valid(trainer, config, rank_zero_only=True)
333
+ api := _enabled_and_valid(trainer, callback, config, rank_zero_only=True)
310
334
  ) is None or not config.save_checkpoints:
311
335
  return
312
336
 
@@ -323,7 +347,7 @@ def _save_checkpoint_files(
323
347
  # Read all the files to memory
324
348
  file_contents: list[bytes | None] = []
325
349
  for p in paths:
326
- assert not p.is_symlink(), f"Path {p} is a symlink."
350
+ assert not _is_link(p, trainer=trainer), f"Path {p} is a symlink."
327
351
  assert p.is_file(), f"Path {p} is not a file."
328
352
  try:
329
353
  with open(p, "rb") as f:
@@ -374,13 +398,14 @@ def _save_checkpoint_files(
374
398
 
375
399
  def _save_checkpoint_symlinks(
376
400
  trainer: "Trainer",
401
+ callback: "HFHubCallback",
377
402
  paths: list[Path],
378
403
  *,
379
404
  root_config: "BaseConfig",
380
405
  ):
381
406
  config = root_config.trainer.hf_hub
382
407
  if (
383
- api := _enabled_and_valid(trainer, config, rank_zero_only=True)
408
+ api := _enabled_and_valid(trainer, callback, config, rank_zero_only=True)
384
409
  ) is None or not config.save_checkpoints:
385
410
  return
386
411
 
@@ -397,7 +422,7 @@ def _save_checkpoint_symlinks(
397
422
 
398
423
  commits: list[CommitOperation] = []
399
424
  for p in paths:
400
- assert p.is_symlink(), f"Path {p} is not a symlink."
425
+ assert _is_link(p, trainer=trainer), f"Path {p} is not a symlink."
401
426
 
402
427
  try:
403
428
  dest_relative_path = p.relative_to(checkpoint_dir)
@@ -407,11 +432,15 @@ def _save_checkpoint_symlinks(
407
432
  )
408
433
  continue
409
434
 
435
+ if (p_resolved := _resolve_link(p, trainer=trainer)) is None:
436
+ log.warning(f"Failed to resolve symlink for path {p}.")
437
+ continue
438
+
410
439
  try:
411
- source_relative_path = p.resolve().relative_to(checkpoint_dir)
440
+ source_relative_path = p_resolved.relative_to(checkpoint_dir)
412
441
  except ValueError:
413
442
  log.warning(
414
- f"Checkpoint symlink target {p.resolve()} is not within the checkpoint directory {checkpoint_dir}."
443
+ f"Checkpoint symlink target {p_resolved} is not within the checkpoint directory {checkpoint_dir}."
415
444
  )
416
445
  continue
417
446
 
@@ -447,39 +476,6 @@ def _save_checkpoint_symlinks(
447
476
  log.info(f"Completed copying checkpoint symlinks to repository '{repo_name}'.")
448
477
 
449
478
 
450
- def _save_checkpoint_directory(trainer: "Trainer", *, root_config: "BaseConfig"):
451
- config = root_config.trainer.hf_hub
452
- if (
453
- api := _enabled_and_valid(trainer, config, rank_zero_only=True)
454
- ) is None or not config.save_checkpoints:
455
- return
456
-
457
- # Resolve the checkpoint directory
458
- checkpoint_dir = root_config.directory.resolve_subdirectory(
459
- root_config.id, "checkpoint"
460
- )
461
-
462
- # Resolve the repository name
463
- repo_name = _repo_name(api, root_config)
464
-
465
- # Upload the checkpoint directory to the repository
466
- try:
467
- api.upload_folder(
468
- folder_path=str(checkpoint_dir),
469
- repo_id=repo_name,
470
- repo_type="model",
471
- path_in_repo="checkpoints",
472
- run_as_future=cast(Any, config.save_in_background),
473
- )
474
- log.info(f"Uploaded checkpoint directory to repository '{repo_name}'.")
475
- except Exception:
476
- log.exception(
477
- f"Failed to upload checkpoint directory to repository '{repo_name}'."
478
- )
479
-
480
- log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
481
-
482
-
483
479
  class HFHubCallback(NTCallbackBase):
484
480
  def __init__(self, config: HuggingFaceHubConfig):
485
481
  super().__init__()
@@ -495,12 +491,7 @@ class HFHubCallback(NTCallbackBase):
495
491
  )
496
492
 
497
493
  root_config = cast("BaseConfig", pl_module.hparams)
498
- _init(trainer=trainer, root_config=root_config)
499
-
500
- @override
501
- def teardown(self, trainer, pl_module, stage):
502
- if hasattr(trainer, "_hf_hub_api"):
503
- delattr(trainer, "_hf_hub_api")
494
+ _init(trainer=trainer, callback=self, root_config=root_config)
504
495
 
505
496
  @override
506
497
  def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
@@ -510,9 +501,18 @@ class HFHubCallback(NTCallbackBase):
510
501
  if root_config.trainer.hf_hub and trainer.is_global_zero:
511
502
  # Upload the regular files first, then the symlinks
512
503
  all_paths = [p for p in (ckpt_path, metadata_path) if p is not None]
513
- if regular_paths := [p for p in all_paths if not p.is_symlink()]:
514
- _save_checkpoint_files(trainer, regular_paths, root_config=root_config)
515
- if symlink_paths := [p for p in all_paths if p.is_symlink()]:
504
+ if regular_paths := [
505
+ p for p in all_paths if not _is_link(p, trainer=trainer)
506
+ ]:
507
+ _save_checkpoint_files(
508
+ trainer, self, regular_paths, root_config=root_config
509
+ )
510
+ if symlink_paths := [p for p in all_paths if _is_link(p, trainer=trainer)]:
516
511
  _save_checkpoint_symlinks(
517
- trainer, symlink_paths, root_config=root_config
512
+ trainer, self, symlink_paths, root_config=root_config
518
513
  )
514
+
515
+ @cached_property
516
+ def _hf_hub_api(self):
517
+ # Create and authenticate the API instance
518
+ return _api(self.config.token)
@@ -156,6 +156,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
156
156
  filepath,
157
157
  self.config.save_weights_only,
158
158
  use_checkpoint_cache=None,
159
+ ckpt_cache_use_symlink=False,
159
160
  )
160
161
 
161
162
  if trainer.is_global_zero:
@@ -1,5 +1,6 @@
1
1
  import logging
2
2
  import os
3
+ import shutil
3
4
  from collections.abc import Sequence
4
5
  from pathlib import Path
5
6
  from typing import TYPE_CHECKING, Any, cast
@@ -279,6 +280,13 @@ class Trainer(LightningTrainer):
279
280
  if TYPE_CHECKING:
280
281
  callbacks: list[Callback]
281
282
 
283
+ def _nshtrainer_ckpt_link(self, ckpt_path: Path):
284
+ root_config = cast(BaseConfig, self._base_module.hparams)
285
+ ckpt_dir = root_config.directory.resolve_subdirectory(
286
+ root_config.id, "checkpoint"
287
+ )
288
+ return str(ckpt_path.absolute().relative_to(ckpt_dir))
289
+
282
290
  @override
283
291
  def __init__(
284
292
  self,
@@ -287,6 +295,7 @@ class Trainer(LightningTrainer):
287
295
  **kwargs: Unpack[LightningTrainerKwargs],
288
296
  ):
289
297
  self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
298
+ self._nshtrainer_checkpoint_link_dict = dict[str, Path]()
290
299
 
291
300
  self._pre_init(config)
292
301
 
@@ -416,11 +425,12 @@ class Trainer(LightningTrainer):
416
425
  weights_only: bool = False,
417
426
  storage_options: Any | None = None,
418
427
  use_checkpoint_cache: bool | None = None,
428
+ ckpt_cache_use_symlink: bool = False,
419
429
  ):
420
430
  lm = self._base_module
421
- hparams = cast(BaseConfig, lm.hparams)
431
+ root_config = cast(BaseConfig, lm.hparams)
422
432
  if use_checkpoint_cache is None:
423
- use_checkpoint_cache = hparams.trainer.use_checkpoint_cache
433
+ use_checkpoint_cache = root_config.trainer.use_checkpoint_cache
424
434
 
425
435
  filepath = Path(filepath)
426
436
 
@@ -440,7 +450,14 @@ class Trainer(LightningTrainer):
440
450
  # If we have a cached path, then we symlink it to the new path.
441
451
  log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
442
452
  if self.is_global_zero:
443
- _link_checkpoint(cached_path, filepath, metadata=False)
453
+ if ckpt_cache_use_symlink:
454
+ _link_checkpoint(cached_path, filepath, metadata=False)
455
+ else:
456
+ shutil.copy(cached_path, filepath)
457
+ self._nshtrainer_checkpoint_link_dict[
458
+ self._nshtrainer_ckpt_link(filepath)
459
+ ] = cached_path
460
+ self.strategy.barrier("Trainer.save_checkpoint")
444
461
  else:
445
462
  super().save_checkpoint(filepath, weights_only, storage_options)
446
463
 
@@ -453,7 +470,7 @@ class Trainer(LightningTrainer):
453
470
 
454
471
  # Save the checkpoint metadata
455
472
  metadata_path = None
456
- if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
473
+ if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
457
474
  # Generate the metadata and write to disk
458
475
  if (
459
476
  metadata_path := _write_checkpoint_metadata(self, lm, filepath)
nshtrainer/util/path.py CHANGED
@@ -27,3 +27,26 @@ def get_relative_path(source: _Path, destination: _Path):
27
27
  down = os.sep.join(destination_parts[i:])
28
28
 
29
29
  return Path(os.path.normpath(os.path.join(up, down)))
30
+
31
+
32
+ def find_symlinks(
33
+ target_file: _Path,
34
+ *search_directories: _Path,
35
+ glob_pattern: str = "*",
36
+ ):
37
+ target_file = Path(target_file).resolve()
38
+ symlinks: list[Path] = []
39
+
40
+ for search_directory in search_directories:
41
+ search_directory = Path(search_directory)
42
+ for path in search_directory.rglob(glob_pattern):
43
+ if path.is_symlink():
44
+ try:
45
+ link_target = path.resolve()
46
+ if link_target.samefile(target_file):
47
+ symlinks.append(path)
48
+ except FileNotFoundError:
49
+ # Handle broken symlinks
50
+ pass
51
+
52
+ return symlinks
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.22.1
3
+ Version: 0.24.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -2,15 +2,15 @@ nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
2
  nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
3
3
  nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=E4tfiGzhnn65X95P0Y6K2d_YfPWqvHZoF0FF1-smEJc,5221
5
- nshtrainer/_checkpoint/saver.py,sha256=6W-Rbc3QGuhcF_mcwN8v31uEjLQCsZvt8CPuqPs4m5g,1342
5
+ nshtrainer/_checkpoint/saver.py,sha256=fvRKGI5aeXtsHBOIO4cwGe__wmO-6DiD0-744VASYA4,1500
6
6
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
7
- nshtrainer/_hf_hub.py,sha256=iqhXH54RhSqmot_K3UCVcHTC_TC81_YY7cwvHGHXXlw,16782
7
+ nshtrainer/_hf_hub.py,sha256=Ac4y7jmuAMEQOJPJgoYmiaIGlZvgyUcqpipb6fPuHSE,16587
8
8
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
9
9
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
11
11
  nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
13
- nshtrainer/callbacks/checkpoint/_base.py,sha256=r6IPpl3sGUmxBNv80y9r326lTrPAIVSU3Fu-3LrYH2s,6691
13
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=pLhLj7tvfY5czGY_vT0xRfWHzGJYC4iOBRLokFVq0mE,6733
14
14
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
15
15
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -78,15 +78,15 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
78
78
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
79
79
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
80
80
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
81
- nshtrainer/trainer/trainer.py,sha256=KXsvAhgVgYjmYfoqzH_qoQXqd6nVx7-vs9ObQJpwbIk,19140
81
+ nshtrainer/trainer/trainer.py,sha256=vswxAhyLqTL99kRJvU4Q3uEyQT80eM3mN74yMhsyn_I,19905
82
82
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
83
83
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
84
84
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
85
- nshtrainer/util/path.py,sha256=A_Ocag3_hbwns_zAxFDlH-5eVHWFlcy2DKxHQ7jddvk,837
85
+ nshtrainer/util/path.py,sha256=WbPWXpu5LIDocQihQC3-72qxN1sa6-d1kPOmKDR-NC8,1520
86
86
  nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.22.1.dist-info/METADATA,sha256=XF3QXKbeAN7I5vYHNbjExlV_6CF8QgPqPYFsCxs52rA,935
91
- nshtrainer-0.22.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.22.1.dist-info/RECORD,,
90
+ nshtrainer-0.24.0.dist-info/METADATA,sha256=sYEAgF1cJ2a3Mjp137ienL4cvHswQHH76lm2GTj-Nb8,935
91
+ nshtrainer-0.24.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.24.0.dist-info/RECORD,,