nshtrainer 0.20.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,40 @@
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ from lightning.pytorch.callbacks import Callback as _LightningCallback
+
+ if TYPE_CHECKING:
+     from .model import LightningModuleBase
+     from .trainer import Trainer
+
+
+ class NTCallbackBase(_LightningCallback):
+     def on_checkpoint_saved(
+         self,
+         ckpt_path: Path,
+         metadata_path: Path | None,
+         trainer: "Trainer",
+         pl_module: "LightningModuleBase",
+     ) -> None:
+         """Called after a checkpoint is saved."""
+         pass
+
+
+ def _call_on_checkpoint_saved(
+     trainer: "Trainer",
+     ckpt_path: str | Path,
+     metadata_path: str | Path | None,
+ ):
+     ckpt_path = Path(ckpt_path)
+     metadata_path = Path(metadata_path) if metadata_path else None
+
+     for callback in trainer.callbacks:
+         if not isinstance(callback, NTCallbackBase):
+             continue
+
+         callback.on_checkpoint_saved(
+             ckpt_path,
+             metadata_path,
+             trainer,
+             trainer._base_module,
+         )
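
The new `NTCallbackBase` hook gives nshtrainer callbacks a post-save notification that plain Lightning callbacks do not receive. A minimal sketch of a user-defined callback built on it, assuming the module is importable as `nshtrainer._callback` (the printing behaviour is illustrative only):

```python
from nshtrainer._callback import NTCallbackBase


class LogCheckpointPaths(NTCallbackBase):
    """Illustrative callback: record every checkpoint (and metadata file) written to disk."""

    def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module) -> None:
        # `ckpt_path` is always a Path; `metadata_path` is None when metadata saving is disabled.
        print(f"checkpoint written: {ckpt_path}")
        if metadata_path is not None:
            print(f"metadata written: {metadata_path}")
```

Because `_call_on_checkpoint_saved` filters on `isinstance(callback, NTCallbackBase)`, other callbacks attached to the trainer are simply skipped.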
@@ -96,13 +96,17 @@ def _generate_checkpoint_metadata(
      )
 
 
+ def _metadata_path(checkpoint_path: Path):
+     return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
+
+
  def _write_checkpoint_metadata(
      trainer: "Trainer",
      model: "LightningModuleBase",
      checkpoint_path: Path,
  ):
      config = cast("BaseConfig", model.config)
-     metadata_path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
+     metadata_path = _metadata_path(checkpoint_path)
      metadata = _generate_checkpoint_metadata(
          config, trainer, checkpoint_path, metadata_path
      )
@@ -119,7 +123,7 @@ def _write_checkpoint_metadata(
 
 
  def _remove_checkpoint_metadata(checkpoint_path: Path):
-     path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
+     path = _metadata_path(checkpoint_path)
      try:
          path.unlink(missing_ok=True)
      except Exception:
@@ -133,8 +137,8 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
      _remove_checkpoint_metadata(linked_checkpoint_path)
 
      # Link the metadata files to the new checkpoint
-     path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
-     linked_path = linked_checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
+     path = _metadata_path(checkpoint_path)
+     linked_path = _metadata_path(linked_checkpoint_path)
      try:
          try:
              # linked_path.symlink_to(path)
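
`_metadata_path` centralizes the suffix swap that was previously repeated throughout this module. A small sketch of the behaviour, assuming `CheckpointMetadata.PATH_SUFFIX` is a dotted suffix string; the `".metadata.json"` value below is only a placeholder, not the real constant:

```python
from pathlib import Path

PATH_SUFFIX = ".metadata.json"  # placeholder for CheckpointMetadata.PATH_SUFFIX

ckpt = Path("/runs/abc/checkpoint/epoch3-step1200.ckpt")
# Path.with_suffix swaps only the final ".ckpt" suffix, keeping the rest of the name.
print(ckpt.with_suffix(PATH_SUFFIX))
# /runs/abc/checkpoint/epoch3-step1200.metadata.json
```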
@@ -4,6 +4,7 @@ from pathlib import Path
 
  from lightning.pytorch import Trainer
 
+ from ..util.path import get_relative_path
  from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
 
 
@@ -14,10 +15,8 @@ def _link_checkpoint(
      metadata: bool,
      remove_existing: bool = True,
  ):
-     if not isinstance(filepath, Path):
-         filepath = Path(filepath)
-     if not isinstance(linkpath, Path):
-         linkpath = Path(linkpath)
+     filepath = Path(filepath)
+     linkpath = Path(linkpath)
 
      if remove_existing:
          if linkpath.exists():
@@ -30,7 +29,7 @@ def _link_checkpoint(
          _remove_checkpoint_metadata(linkpath)
 
      try:
-         linkpath.symlink_to(filepath.relative_to(linkpath.parent))
+         linkpath.symlink_to(get_relative_path(linkpath, filepath))
      except OSError:
          # on Windows, special permissions are required to create symbolic links as a regular user
          # fall back to copying the file
@@ -46,9 +45,9 @@ def _remove_checkpoint(
      *,
      metadata: bool,
  ):
-     if not isinstance(filepath, Path):
-         filepath = Path(filepath)
+     filepath = Path(filepath)
 
      trainer.strategy.remove_checkpoint(filepath)
+
      if metadata:
          _remove_checkpoint_metadata(filepath)
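
The switch from `filepath.relative_to(linkpath.parent)` to `get_relative_path(linkpath, filepath)` matters when the checkpoint does not live directly under the symlink's parent directory: `Path.relative_to` only walks downward and raises `ValueError` otherwise, while the new helper (added later in this diff) can emit `..` components. A quick illustration with hypothetical paths:

```python
from pathlib import Path

link = Path("/runs/abc/checkpoint/last/latest.ckpt")
target = Path("/runs/abc/checkpoint/best/epoch3.ckpt")

# The old approach only works when `target` sits beneath the link's parent directory.
try:
    target.relative_to(link.parent)
except ValueError:
    # An os.path.relpath-style computation gives the answer the symlink actually needs.
    import os
    print(os.path.relpath(target, link.parent))  # ../best/epoch3.ckpt
```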
nshtrainer/_hf_hub.py CHANGED
@@ -6,22 +6,19 @@ from pathlib import Path
  from typing import TYPE_CHECKING, Any, cast
 
  import nshconfig as C
- from lightning.pytorch import Callback
- from lightning.pytorch.trainer import Trainer
  from nshrunner._env import SNAPSHOT_DIR
  from typing_extensions import override
 
- from .callbacks.base import (
-     CallbackConfigBase,
-     CallbackMetadataConfig,
-     CallbackWithMetadata,
- )
+ from ._callback import NTCallbackBase
+ from .callbacks.base import CallbackConfigBase
 
  if TYPE_CHECKING:
      from huggingface_hub import HfApi  # noqa: F401
 
      from .model.base import BaseConfig
      from .trainer.trainer import Trainer
+
+
  log = logging.getLogger(__name__)
 
 
@@ -80,10 +77,7 @@ class HuggingFaceHubConfig(CallbackConfigBase):
 
      @override
      def create_callbacks(self, root_config):
-         yield CallbackWithMetadata(
-             HFHubCallback(self),
-             CallbackMetadataConfig(ignore_if_exists=True),
-         )
+         yield self.with_metadata(HFHubCallback(self), ignore_if_exists=True)
 
 
  def _api(token: str | None = None):
@@ -102,9 +96,9 @@ def _api(token: str | None = None):
 
          # Verify authentication
          api.whoami()
-     except Exception as e:
+     except Exception:
          log.exception(
-             f"Authentication failed for Hugging Face Hub: {str(e)}. "
+             "Authentication failed for Hugging Face Hub. "
              "Please make sure you are logged in using `huggingface-cli login`, "
              "by setting the HUGGING_FACE_HUB_TOKEN environment variable, "
              "or by providing a valid token in the configuration."
@@ -210,10 +204,10 @@ def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
                  exist_ok=True,
              )
              log.info(f"Created new repository '{repo_name}'.")
-         except Exception as e:
-             log.exception(f"Failed to create repository '{repo_name}': {str(e)}")
-     except Exception as e:
-         log.exception(f"Error checking repository '{repo_name}': {str(e)}")
+         except Exception:
+             log.exception(f"Failed to create repository '{repo_name}'")
+     except Exception:
+         log.exception(f"Error checking repository '{repo_name}'")
 
      # Upload the config
      _save_config(root_config, trainer=trainer)
@@ -262,9 +256,9 @@ def _save_code(
          log.info(
              f"Uploaded snapshot contents to repository '{repo_name}' under 'code' folder."
          )
-     except Exception as e:
+     except Exception:
          log.exception(
-             f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder: {str(e)}"
+             f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder."
          )
 
 
@@ -300,10 +294,8 @@ def _save_config(
              run_as_future=cast(Any, config.save_in_background),
          )
          log.info(f"Uploaded config.json to repository '{repo_name}'.")
-     except Exception as e:
-         log.exception(
-             f"Failed to upload config.json to repository '{repo_name}': {str(e)}"
-         )
+     except Exception:
+         log.exception(f"Failed to upload config.json to repository '{repo_name}'.")
 
 
  def _save_checkpoint_files(
@@ -331,17 +323,24 @@ def _save_checkpoint_files(
      # Read all the files to memory
      file_contents: list[bytes | None] = []
      for p in paths:
+         assert not p.is_symlink(), f"Path {p} is a symlink."
+         assert p.is_file(), f"Path {p} is not a file."
          try:
              with open(p, "rb") as f:
                  file_contents.append(f.read())
-         except IOError as e:
-             log.warning(f"Failed to read checkpoint file {p}: {str(e)}")
+         except IOError:
+             log.exception(f"Failed to read checkpoint file {p}.")
              file_contents.append(None)
 
-     for p, contents in zip(paths, file_contents):
-         if contents is None:
-             continue
+     # Remove the paths that failed to read
+     file_contents_and_paths = [
+         (contents, p)
+         for contents, p in zip(file_contents, paths)
+         if contents is not None
+     ]
 
+     # Upload the checkpoint files to the repository
+     for contents, p in file_contents_and_paths:
          try:
              relative_path = p.relative_to(checkpoint_dir)
          except ValueError:
@@ -365,21 +364,136 @@ def _save_checkpoint_files(
              log.info(
                  f"Uploaded checkpoint file {relative_path} to repository '{repo_name}'."
              )
-         except Exception as e:
+         except Exception:
              log.exception(
-                 f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}': {str(e)}"
+                 f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}'."
              )
 
      log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
 
 
- class HFHubCallback(Callback):
+ def _save_checkpoint_symlinks(
+     trainer: "Trainer",
+     paths: list[Path],
+     *,
+     root_config: "BaseConfig",
+ ):
+     config = root_config.trainer.hf_hub
+     if (
+         api := _enabled_and_valid(trainer, config, rank_zero_only=True)
+     ) is None or not config.save_checkpoints:
+         return
+
+     # Resolve the checkpoint directory
+     checkpoint_dir = root_config.directory.resolve_subdirectory(
+         root_config.id, "checkpoint"
+     )
+
+     # Resolve the repository name
+     repo_name = _repo_name(api, root_config)
+
+     # Create a commit for copying the files
+     from huggingface_hub.hf_api import CommitOperation, CommitOperationCopy
+
+     commits: list[CommitOperation] = []
+     for p in paths:
+         assert p.is_symlink(), f"Path {p} is not a symlink."
+
+         try:
+             dest_relative_path = p.relative_to(checkpoint_dir)
+         except ValueError:
+             log.warning(
+                 f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
+             )
+             continue
+
+         try:
+             source_relative_path = p.resolve().relative_to(checkpoint_dir)
+         except ValueError:
+             log.warning(
+                 f"Checkpoint symlink target {p.resolve()} is not within the checkpoint directory {checkpoint_dir}."
+             )
+             continue
+
+         # Prefix the path in repo with "checkpoints"
+         dest_path_in_repo = Path("checkpoints") / dest_relative_path
+         source_path_in_repo = Path("checkpoints") / source_relative_path
+
+         # Create and append a CommitOperationCopy for copying the symlink
+         copy_op = CommitOperationCopy(
+             src_path_in_repo=str(source_path_in_repo),
+             path_in_repo=str(dest_path_in_repo),
+         )
+         commits.append(copy_op)
+
+     log.info(f"Creating a commit with the following operations: {commits}")
+
+     try:
+         api.create_commit(
+             repo_id=repo_name,
+             repo_type="model",
+             commit_message="Copy checkpoint symlinks",
+             operations=commits,
+             run_as_future=cast(Any, config.save_in_background),
+         )
+         log.info(
+             f"Created commit to copy checkpoint symlinks to repository '{repo_name}'."
+         )
+     except Exception:
+         log.exception(
+             f"Failed to create commit to copy checkpoint symlinks to repository '{repo_name}'"
+         )
+
+     log.info(f"Completed copying checkpoint symlinks to repository '{repo_name}'.")
+
+
+ def _save_checkpoint_directory(trainer: "Trainer", *, root_config: "BaseConfig"):
+     config = root_config.trainer.hf_hub
+     if (
+         api := _enabled_and_valid(trainer, config, rank_zero_only=True)
+     ) is None or not config.save_checkpoints:
+         return
+
+     # Resolve the checkpoint directory
+     checkpoint_dir = root_config.directory.resolve_subdirectory(
+         root_config.id, "checkpoint"
+     )
+
+     # Resolve the repository name
+     repo_name = _repo_name(api, root_config)
+
+     # Upload the checkpoint directory to the repository
+     try:
+         api.upload_folder(
+             folder_path=str(checkpoint_dir),
+             repo_id=repo_name,
+             repo_type="model",
+             path_in_repo="checkpoints",
+             run_as_future=cast(Any, config.save_in_background),
+         )
+         log.info(f"Uploaded checkpoint directory to repository '{repo_name}'.")
+     except Exception:
+         log.exception(
+             f"Failed to upload checkpoint directory to repository '{repo_name}'."
+         )
+
+     log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
+
+
+ class HFHubCallback(NTCallbackBase):
      def __init__(self, config: HuggingFaceHubConfig):
          super().__init__()
          self.config = config
 
      @override
      def setup(self, trainer, pl_module, stage):
+         from .trainer.trainer import Trainer
+
+         if not isinstance(trainer, Trainer):
+             raise ValueError(
+                 f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
+             )
+
          root_config = cast("BaseConfig", pl_module.hparams)
          _init(trainer=trainer, root_config=root_config)
 
@@ -387,3 +501,18 @@ class HFHubCallback(Callback):
      def teardown(self, trainer, pl_module, stage):
          if hasattr(trainer, "_hf_hub_api"):
              delattr(trainer, "_hf_hub_api")
+
+     @override
+     def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
+         root_config = cast("BaseConfig", pl_module.hparams)
+
+         # If HF Hub is enabled, then we upload
+         if root_config.trainer.hf_hub and trainer.is_global_zero:
+             # Upload the regular files first, then the symlinks
+             all_paths = [p for p in (ckpt_path, metadata_path) if p is not None]
+             if regular_paths := [p for p in all_paths if not p.is_symlink()]:
+                 _save_checkpoint_files(trainer, regular_paths, root_config=root_config)
+             if symlink_paths := [p for p in all_paths if p.is_symlink()]:
+                 _save_checkpoint_symlinks(
+                     trainer, symlink_paths, root_config=root_config
+                 )
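
`_save_checkpoint_symlinks` avoids re-uploading bytes that already exist in the repository: for each local symlink it asks the Hub to copy the previously uploaded target to the symlink's own path under `checkpoints/`. A standalone sketch of the same idea using `huggingface_hub` directly; the repository name and file paths are hypothetical:

```python
from huggingface_hub import HfApi
from huggingface_hub.hf_api import CommitOperationCopy

api = HfApi()

# Server-side copy: "checkpoints/last.ckpt" becomes a duplicate of an
# already-uploaded file, without transferring the checkpoint again.
operation = CommitOperationCopy(
    src_path_in_repo="checkpoints/epoch3-step1200.ckpt",
    path_in_repo="checkpoints/last.ckpt",
)
api.create_commit(
    repo_id="someuser/some-run",  # hypothetical repository
    repo_type="model",
    operations=[operation],
    commit_message="Copy checkpoint symlinks",
)
```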
@@ -2,29 +2,24 @@ from abc import ABC, abstractmethod
  from collections import Counter
  from collections.abc import Iterable
  from dataclasses import dataclass
- from typing import TYPE_CHECKING, TypeAlias, TypedDict
+ from typing import TYPE_CHECKING, ClassVar, TypeAlias
 
  import nshconfig as C
  from lightning.pytorch import Callback
+ from typing_extensions import TypedDict, Unpack
 
  if TYPE_CHECKING:
      from ..model.config import BaseConfig
 
 
- class CallbackMetadataDict(TypedDict, total=False):
+ class CallbackMetadataConfig(TypedDict, total=False):
      ignore_if_exists: bool
-     """If `True`, the callback will not be added if another callback with the same class already exists."""
+     """If `True`, the callback will not be added if another callback with the same class already exists.
+     Default is `False`."""
 
      priority: int
-     """Priority of the callback. Callbacks with higher priority will be loaded first."""
-
-
- class CallbackMetadataConfig(C.Config):
-     ignore_if_exists: bool = False
-     """If `True`, the callback will not be added if another callback with the same class already exists."""
-
-     priority: int = 0
-     """Priority of the callback. Callbacks with higher priority will be loaded first."""
+     """Priority of the callback. Callbacks with higher priority will be loaded first.
+     Default is `0`."""
 
 
  @dataclass(frozen=True)
@@ -37,13 +32,18 @@ ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
 
 
  class CallbackConfigBase(C.Config, ABC):
-     metadata: CallbackMetadataConfig = CallbackMetadataConfig()
+     metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig()
      """Metadata for the callback."""
 
-     def with_metadata(self, callback: Callback, **metadata: CallbackMetadataDict):
-         return CallbackWithMetadata(
-             callback=callback, metadata=self.metadata.model_copy(update=metadata)
-         )
+     @classmethod
+     def with_metadata(
+         cls, callback: Callback, **kwargs: Unpack[CallbackMetadataConfig]
+     ):
+         metadata: CallbackMetadataConfig = {}
+         metadata.update(cls.metadata)
+         metadata.update(kwargs)
+
+         return CallbackWithMetadata(callback=callback, metadata=metadata)
 
      @abstractmethod
      def create_callbacks(
@@ -73,7 +73,7 @@ def _filter_ignore_if_exists(callbacks: list[CallbackWithMetadata]):
      for callback in callbacks:
          # If `ignore_if_exists` is `True` and there is already a callback of the same class, skip this callback
          if (
-             callback.metadata.ignore_if_exists
+             callback.metadata.get("ignore_if_exists", False)
              and callback_classes[callback.callback.__class__] > 1
          ):
              continue
@@ -89,7 +89,10 @@ def _process_and_filter_callbacks(
      callbacks = list(callbacks)
 
      # Sort by priority (higher priority first)
-     callbacks.sort(key=lambda callback: callback.metadata.priority, reverse=True)
+     callbacks.sort(
+         key=lambda callback: callback.metadata.get("priority", 0),
+         reverse=True,
+     )
 
      # Process `ignore_if_exists`
      callbacks = _filter_ignore_if_exists(callbacks)
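
With `CallbackMetadataConfig` now a `TypedDict`, `with_metadata` is a classmethod that merges class-level defaults with per-call overrides, and the filtering/sorting code reads the keys with `.get(...)` defaults. A sketch of how a config subclass might use it; the callback class and priority value here are made up for illustration:

```python
from lightning.pytorch.callbacks import LearningRateMonitor

from nshtrainer.callbacks.base import CallbackConfigBase


class LRMonitorConfig(CallbackConfigBase):
    logging_interval: str = "step"

    def create_callbacks(self, root_config):
        # Higher-priority callbacks sort first; ignore_if_exists drops this
        # instance if another LearningRateMonitor is already registered.
        yield self.with_metadata(
            LearningRateMonitor(logging_interval=self.logging_interval),
            priority=10,
            ignore_if_exists=True,
        )
```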
@@ -9,7 +9,7 @@ from lightning.pytorch import Trainer
  from lightning.pytorch.callbacks import Checkpoint
  from typing_extensions import TypeVar, override
 
- from ..._checkpoint.metadata import CheckpointMetadata
+ from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
  from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
  from ..base import CallbackConfigBase
 
@@ -65,8 +65,6 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
          self.dirpath.mkdir(parents=True, exist_ok=True)
          self.symlink_dirpath = dirpath
 
-         self._last_global_step_saved = 0
-
      @abstractmethod
      def default_filename(self) -> str: ...
 
@@ -144,9 +142,21 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
          if self._should_skip_saving_checkpoint(trainer):
              return
 
+         from ...trainer import Trainer as NTTrainer
+
+         if not isinstance(trainer, NTTrainer):
+             raise TypeError(
+                 f"Trainer must be an instance of {NTTrainer.__name__}, "
+                 f"but got {type(trainer).__name__}"
+             )
+
          # Save the new checkpoint
          filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
-         trainer.save_checkpoint(filepath, self.config.save_weights_only)
+         trainer._nshtrainer_save_checkpoint(
+             filepath,
+             self.config.save_weights_only,
+             use_checkpoint_cache=None,
+         )
 
          if trainer.is_global_zero:
              # Create the latest symlink
@@ -162,8 +172,15 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
          # deleted the old checkpoints, and created the symlink before continuing
          trainer.strategy.barrier()
 
-         # Set the last global step saved
-         self._last_global_step_saved = trainer.global_step
+         # Call the on save checkpoint callback for the symlink (if it exists)
+         if (symlink_filename := self.symlink_path()) is not None:
+             from ... import _callback
+
+             symlink_path = self.dirpath / symlink_filename
+             symlink_metadata_path = _metadata_path(symlink_path)
+             _callback._call_on_checkpoint_saved(
+                 trainer, symlink_path, symlink_metadata_path
+             )
 
      def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
          from lightning.pytorch.trainer.states import TrainerFn
@@ -175,6 +192,4 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
              or trainer.state.fn
              != TrainerFn.FITTING  # don't save anything during non-fit
              or trainer.sanity_checking  # don't save anything during sanity check
-             or self._last_global_step_saved
-             == trainer.global_step  # already saved at the last step
          )
@@ -1012,6 +1012,8 @@ class TrainerConfig(C.Config):
      """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
      save_checkpoint_metadata: bool = True
      """If enabled, will save additional metadata whenever a checkpoint is saved."""
+     use_checkpoint_cache: bool = True
+     """If enabled, will optimize the saving of duplicate checkpoints by creating symlinks instead of copying the file."""
 
      lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
      """
@@ -17,6 +17,7 @@ from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
  from typing_extensions import Unpack, assert_never, override
 
  from .._checkpoint.metadata import _write_checkpoint_metadata
+ from .._checkpoint.saver import _link_checkpoint
  from ..callbacks.base import resolve_all_callbacks
  from ..model.config import (
      AcceleratorConfigProtocol,
@@ -285,6 +286,8 @@ class Trainer(LightningTrainer):
          /,
          **kwargs: Unpack[LightningTrainerKwargs],
      ):
+         self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
+
          self._pre_init(config)
 
          kwargs = self._update_kwargs(config, kwargs)
@@ -407,29 +410,70 @@ class Trainer(LightningTrainer):
 
          return super()._run(model, ckpt_path)
 
-     @override
-     def save_checkpoint(
+     def _nshtrainer_save_checkpoint(
          self,
          filepath: str | Path,
          weights_only: bool = False,
          storage_options: Any | None = None,
+         use_checkpoint_cache: bool | None = None,
      ):
+         lm = self._base_module
+         hparams = cast(BaseConfig, lm.hparams)
+         if use_checkpoint_cache is None:
+             use_checkpoint_cache = hparams.trainer.use_checkpoint_cache
+
          filepath = Path(filepath)
-         ret_val = super().save_checkpoint(filepath, weights_only, storage_options)
+
+         # List of files that we should upload to HF
+         written_files: list[Path] = [filepath]
+
+         cached_path = None
+         if (
+             use_checkpoint_cache
+             and (
+                 cached_path := self._nshtrainer_checkpoint_cache.get(
+                     (self.current_epoch, self.global_step)
+                 )
+             )
+             is not None
+         ):
+             # If we have a cached path, then we symlink it to the new path.
+             log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
+             _link_checkpoint(cached_path, filepath, metadata=False)
+         else:
+             super().save_checkpoint(filepath, weights_only, storage_options)
+
+         # If we are using the cache but we don't have a cached path, then we save the checkpoint to the cache.
+         if use_checkpoint_cache and cached_path is None:
+             self._nshtrainer_checkpoint_cache[
+                 (self.current_epoch, self.global_step)
+             ] = filepath
+             log.debug(f"Checkpoint saved to cache: {filepath}")
 
          # Save the checkpoint metadata
-         lm = self._base_module
-         hparams = cast(BaseConfig, lm.hparams)
          metadata_path = None
          if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
              # Generate the metadata and write to disk
-             metadata_path = _write_checkpoint_metadata(self, lm, filepath)
+             if (
+                 metadata_path := _write_checkpoint_metadata(self, lm, filepath)
+             ) is not None:
+                 written_files.append(metadata_path)
 
-         # If HF Hub is enabled, then we upload
-         if hparams.trainer.hf_hub and self.is_global_zero:
-             from .._hf_hub import _save_checkpoint_files
+         # Call the `on_checkpoint_saved` method on all callbacks
+         from .. import _callback
 
-             files = [f for f in (filepath, metadata_path) if f is not None]
-             _save_checkpoint_files(self, files, root_config=hparams)
+         _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
 
-         return ret_val
+     @override
+     def save_checkpoint(
+         self,
+         filepath: str | Path,
+         weights_only: bool = False,
+         storage_options: Any | None = None,
+     ):
+         return self._nshtrainer_save_checkpoint(
+             filepath=filepath,
+             weights_only=weights_only,
+             storage_options=storage_options,
+             use_checkpoint_cache=False,
+         )
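
`_nshtrainer_save_checkpoint` keeps a per-run cache keyed by `(current_epoch, global_step)`: the first save at a given step serializes the model and records the path, and any further save request for the same step (for example a "last" checkpoint written right after a "best" checkpoint) is satisfied with a symlink instead of a second serialization. The `save_checkpoint` override keeps the stock Lightning entry point cache-free. A toy sketch of the caching pattern in isolation; the names here are illustrative, not the nshtrainer API:

```python
from pathlib import Path
from typing import Callable

_cache: dict[tuple[int, int], Path] = {}


def save(epoch: int, step: int, filepath: Path, write_fn: Callable[[Path], None]) -> None:
    key = (epoch, step)
    if (cached := _cache.get(key)) is not None:
        # This training step was already serialized: point the new name at it.
        filepath.symlink_to(cached)
        return
    write_fn(filepath)      # expensive: actually serialize the checkpoint
    _cache[key] = filepath  # remember it for later saves at the same step
```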
@@ -0,0 +1,29 @@
+ import os
+ from pathlib import Path
+ from typing import TypeAlias
+
+ _Path: TypeAlias = str | Path | os.PathLike
+
+
+ def get_relative_path(source: _Path, destination: _Path):
+     # Get the absolute paths
+     source = os.path.abspath(source)
+     destination = os.path.abspath(destination)
+
+     # Split the paths into components
+     source_parts = source.split(os.sep)
+     destination_parts = destination.split(os.sep)
+
+     # Find the point where the paths diverge
+     i = 0
+     for i in range(min(len(source_parts), len(destination_parts))):
+         if source_parts[i] != destination_parts[i]:
+             break
+     else:
+         i += 1
+
+     # Build the relative path
+     up = os.sep.join([".." for _ in range(len(source_parts) - i - 1)])
+     down = os.sep.join(destination_parts[i:])
+
+     return Path(os.path.normpath(os.path.join(up, down)))
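
The result of `get_relative_path` is the path to `destination` as seen from the directory that contains `source`, much like `os.path.relpath(destination, os.path.dirname(source))`. Two quick examples with hypothetical paths:

```python
from nshtrainer.util.path import get_relative_path

# Sibling directories: one ".." is needed to get from last/ to best/.
print(get_relative_path(
    "/runs/abc/checkpoint/last/latest.ckpt",
    "/runs/abc/checkpoint/best/epoch3.ckpt",
))  # ../best/epoch3.ckpt

# Same directory: no ".." components at all.
print(get_relative_path(
    "/runs/abc/checkpoint/latest.ckpt",
    "/runs/abc/checkpoint/epoch3.ckpt",
))  # epoch3.ckpt
```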
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.20.0
+ Version: 0.22.0
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com
@@ -1,15 +1,16 @@
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
+ nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
  nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
- nshtrainer/_checkpoint/metadata.py,sha256=p5e7dhVPpOGrXeuesq_7Y_RHi5lguzDAR_UXtMJXzWU,5175
- nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
+ nshtrainer/_checkpoint/metadata.py,sha256=TLAt7yR3KhSRbXCtomLMxcMvOiAju873A1ZRo8VWNwA,5179
+ nshtrainer/_checkpoint/saver.py,sha256=6W-Rbc3QGuhcF_mcwN8v31uEjLQCsZvt8CPuqPs4m5g,1342
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
- nshtrainer/_hf_hub.py,sha256=To3BnnGWbMNNMBdzVtgrNOcNU2fi1dQpwwuclusFAbI,12169
+ nshtrainer/_hf_hub.py,sha256=iqhXH54RhSqmot_K3UCVcHTC_TC81_YY7cwvHGHXXlw,16782
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
- nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
+ nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
- nshtrainer/callbacks/checkpoint/_base.py,sha256=YT_V-oihO9iB4ETl46CGYTCQjIYl-CpV7TMViTn07Lk,6144
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=r6IPpl3sGUmxBNv80y9r326lTrPAIVSU3Fu-3LrYH2s,6691
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -57,7 +58,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
  nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
  nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
- nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
+ nshtrainer/model/config.py,sha256=22_xIcdEO2pJzXgrFaqGFtk3PQEiwKiMZY1cjhoyWaA,43486
  nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -77,14 +78,15 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
- nshtrainer/trainer/trainer.py,sha256=TTtVkgSB_ekgDlHg24d58Vzddtkpp6ZHOTVprXdXMH0,17503
+ nshtrainer/trainer/trainer.py,sha256=DNKA4mcW083i7qLk0fi3j5-Qj4KNBtlLuyIsNxykebw,19100
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
+ nshtrainer/util/path.py,sha256=A_Ocag3_hbwns_zAxFDlH-5eVHWFlcy2DKxHQ7jddvk,837
  nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
- nshtrainer-0.20.0.dist-info/METADATA,sha256=BCzgQYVMH8_7VHpAcAEuJqlQ0oJOERSbBop4bOebYZ4,935
- nshtrainer-0.20.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- nshtrainer-0.20.0.dist-info/RECORD,,
+ nshtrainer-0.22.0.dist-info/METADATA,sha256=sdjt9S4X3xiIGgD6FNF06yIyC1tJA89B9Qm9mxy29tc,935
+ nshtrainer-0.22.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ nshtrainer-0.22.0.dist-info/RECORD,,