nshtrainer 0.19.3__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_callback.py +40 -0
- nshtrainer/_checkpoint/loader.py +12 -4
- nshtrainer/_checkpoint/metadata.py +8 -4
- nshtrainer/_checkpoint/saver.py +6 -7
- nshtrainer/_hf_hub.py +158 -22
- nshtrainer/callbacks/checkpoint/_base.py +23 -8
- nshtrainer/model/config.py +9 -4
- nshtrainer/trainer/checkpoint_connector.py +8 -2
- nshtrainer/trainer/trainer.py +56 -12
- nshtrainer/util/path.py +29 -0
- {nshtrainer-0.19.3.dist-info → nshtrainer-0.21.0.dist-info}/METADATA +2 -1
- {nshtrainer-0.19.3.dist-info → nshtrainer-0.21.0.dist-info}/RECORD +13 -11
- {nshtrainer-0.19.3.dist-info → nshtrainer-0.21.0.dist-info}/WHEEL +0 -0
nshtrainer/_callback.py ADDED
@@ -0,0 +1,40 @@
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from lightning.pytorch.callbacks import Callback as _LightningCallback
+
+if TYPE_CHECKING:
+    from .model import LightningModuleBase
+    from .trainer import Trainer
+
+
+class NTCallbackBase(_LightningCallback):
+    def on_checkpoint_saved(
+        self,
+        ckpt_path: Path,
+        metadata_path: Path | None,
+        trainer: "Trainer",
+        pl_module: "LightningModuleBase",
+    ) -> None:
+        """Called after a checkpoint is saved."""
+        pass
+
+
+def _call_on_checkpoint_saved(
+    trainer: "Trainer",
+    ckpt_path: str | Path,
+    metadata_path: str | Path | None,
+):
+    ckpt_path = Path(ckpt_path)
+    metadata_path = Path(metadata_path) if metadata_path else None
+
+    for callback in trainer.callbacks:
+        if not isinstance(callback, NTCallbackBase):
+            continue
+
+        callback.on_checkpoint_saved(
+            ckpt_path,
+            metadata_path,
+            trainer,
+            trainer._base_module,
+        )
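The new `NTCallbackBase.on_checkpoint_saved` hook is dispatched by `_call_on_checkpoint_saved` to every registered callback once a checkpoint (and, optionally, its metadata file) is on disk. A minimal sketch of a consumer, using only what the file above defines; the callback name and print body are illustrative, not part of the package:

from pathlib import Path

from nshtrainer._callback import NTCallbackBase


class CheckpointLogger(NTCallbackBase):
    """Illustrative callback: report every checkpoint the trainer writes."""

    def on_checkpoint_saved(self, ckpt_path: Path, metadata_path, trainer, pl_module) -> None:
        # Invoked via _call_on_checkpoint_saved() after the files exist on disk.
        print(f"saved {ckpt_path} (metadata: {metadata_path})")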
nshtrainer/_checkpoint/loader.py CHANGED
@@ -76,7 +76,11 @@ class CheckpointLoadingConfig(C.Config):
    """Whether to include checkpoints from HPC pre-emption."""

    @classmethod
-    def
+    def none(cls, include_hpc: bool = False):
+        return cls(strategies=[], include_hpc=include_hpc)
+
+    @classmethod
+    def _auto_train(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
        if ckpt is None:
            ckpt = "last"
        match ckpt:
@@ -90,6 +94,8 @@ class CheckpointLoadingConfig(C.Config):
                    strategies=[LastCheckpointStrategyConfig()],
                    include_hpc=True,
                )
+            case "none":
+                return cls.none()
            case Path() | str():
                ckpt = Path(ckpt)
                return cls(
@@ -103,7 +109,7 @@ class CheckpointLoadingConfig(C.Config):
                assert_never(ckpt)

    @classmethod
-    def _auto_eval(cls, ckpt: Literal["best", "last"] | str | Path | None):
+    def _auto_eval(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
        if ckpt is None:
            log.warn("No checkpoint specified for evaluation. Defaulting to `last`.")
            ckpt = "last"
@@ -119,6 +125,8 @@ class CheckpointLoadingConfig(C.Config):
                    strategies=[LastCheckpointStrategyConfig()],
                    include_hpc=False,
                )
+            case "none":
+                return cls.none(include_hpc=False)
            case Path() | str():
                ckpt = Path(ckpt)
                return cls(
@@ -131,7 +139,7 @@ class CheckpointLoadingConfig(C.Config):
    @classmethod
    def auto(
        cls,
-        ckpt: Literal["best", "last"] | str | Path | None,
+        ckpt: Literal["best", "last", "none"] | str | Path | None,
        trainer_mode: TrainerFn,
    ):
        """
@@ -142,7 +150,7 @@ class CheckpointLoadingConfig(C.Config):

        Parameters:
        -----------
-        ckpt : Literal["best", "last"] | str | Path | None
+        ckpt : Literal["best", "last", "none"] | str | Path | None
            Specifies the checkpoint loading preference:
            - "best": Use the best checkpoint based on the primary metric.
            - "last": Use the most recent checkpoint.
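Net effect of the loader changes: "none" is now an accepted checkpoint preference that resolves to an empty strategy list. A minimal sketch, using only the classmethod added above:

from nshtrainer._checkpoint.loader import CheckpointLoadingConfig

# "none" maps to CheckpointLoadingConfig.none(): no loading strategies at all.
config = CheckpointLoadingConfig.none()
assert config.strategies == []
assert config.include_hpc is False  # default unless include_hpc=True is passed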
nshtrainer/_checkpoint/metadata.py CHANGED
@@ -96,13 +96,17 @@ def _generate_checkpoint_metadata(
    )


+def _metadata_path(checkpoint_path: Path):
+    return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
+
+
def _write_checkpoint_metadata(
    trainer: "Trainer",
    model: "LightningModuleBase",
    checkpoint_path: Path,
):
    config = cast("BaseConfig", model.config)
-    metadata_path = checkpoint_path
+    metadata_path = _metadata_path(checkpoint_path)
    metadata = _generate_checkpoint_metadata(
        config, trainer, checkpoint_path, metadata_path
    )
@@ -119,7 +123,7 @@ def _write_checkpoint_metadata(


def _remove_checkpoint_metadata(checkpoint_path: Path):
-    path = checkpoint_path
+    path = _metadata_path(checkpoint_path)
    try:
        path.unlink(missing_ok=True)
    except Exception:
@@ -133,8 +137,8 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
    _remove_checkpoint_metadata(linked_checkpoint_path)

    # Link the metadata files to the new checkpoint
-    path = checkpoint_path
-    linked_path = linked_checkpoint_path
+    path = _metadata_path(checkpoint_path)
+    linked_path = _metadata_path(linked_checkpoint_path)
    try:
        try:
            # linked_path.symlink_to(path)
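The `_metadata_path` helper centralizes where the sidecar metadata file lives relative to its checkpoint, and write, remove, and link now all derive it the same way. A small sketch of the relationship (nothing beyond the hunk above is assumed; the path is illustrative):

from pathlib import Path

from nshtrainer._checkpoint.metadata import CheckpointMetadata, _metadata_path

ckpt = Path("checkpoints/epoch=3-step=1000.ckpt")  # illustrative path
# The sidecar sits beside the checkpoint: the ".ckpt" suffix is swapped for
# CheckpointMetadata.PATH_SUFFIX (whatever that constant is defined as).
assert _metadata_path(ckpt) == ckpt.with_suffix(CheckpointMetadata.PATH_SUFFIX)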
nshtrainer/_checkpoint/saver.py CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path

from lightning.pytorch import Trainer

+from ..util.path import get_relative_path
from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata


@@ -14,10 +15,8 @@ def _link_checkpoint(
    metadata: bool,
    remove_existing: bool = True,
):
-
-
-    if not isinstance(linkpath, Path):
-        linkpath = Path(linkpath)
+    filepath = Path(filepath)
+    linkpath = Path(linkpath)

    if remove_existing:
        if linkpath.exists():
@@ -30,7 +29,7 @@ def _link_checkpoint(
        _remove_checkpoint_metadata(linkpath)

    try:
-        linkpath.symlink_to(
+        linkpath.symlink_to(get_relative_path(linkpath, filepath))
    except OSError:
        # on Windows, special permissions are required to create symbolic links as a regular user
        # fall back to copying the file
@@ -46,9 +45,9 @@ def _remove_checkpoint(
    *,
    metadata: bool,
):
-
-    filepath = Path(filepath)
+    filepath = Path(filepath)

    trainer.strategy.remove_checkpoint(filepath)
+
    if metadata:
        _remove_checkpoint_metadata(filepath)
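`_link_checkpoint` now symlinks to a relative target (computed by `get_relative_path`, added below in `nshtrainer/util/path.py`) instead of an absolute one, so a checkpoint directory keeps working when it is moved or uploaded as a whole. A hedged sketch of the effect, with illustrative paths:

from pathlib import Path

from nshtrainer.util.path import get_relative_path

real_ckpt = Path("/runs/abc/checkpoint/epoch=3-step=1000.ckpt")  # illustrative
symlink = Path("/runs/abc/checkpoint/last.ckpt")                 # illustrative

# The link target is expressed relative to the symlink's own directory:
print(get_relative_path(symlink, real_ckpt))  # epoch=3-step=1000.ckpt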
nshtrainer/_hf_hub.py CHANGED
@@ -6,11 +6,10 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

import nshconfig as C
-from lightning.pytorch import Callback
-from lightning.pytorch.trainer import Trainer
from nshrunner._env import SNAPSHOT_DIR
from typing_extensions import override

+from ._callback import NTCallbackBase
from .callbacks.base import (
    CallbackConfigBase,
    CallbackMetadataConfig,
@@ -22,6 +21,8 @@ if TYPE_CHECKING:

    from .model.base import BaseConfig
    from .trainer.trainer import Trainer
+
+
log = logging.getLogger(__name__)


@@ -102,9 +103,9 @@ def _api(token: str | None = None):

        # Verify authentication
        api.whoami()
-    except Exception
+    except Exception:
        log.exception(
-
+            "Authentication failed for Hugging Face Hub. "
            "Please make sure you are logged in using `huggingface-cli login`, "
            "by setting the HUGGING_FACE_HUB_TOKEN environment variable, "
            "or by providing a valid token in the configuration."
@@ -210,10 +211,10 @@ def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
            exist_ok=True,
        )
        log.info(f"Created new repository '{repo_name}'.")
-    except Exception
-        log.exception(f"Failed to create repository '{repo_name}'
-    except Exception
-        log.exception(f"Error checking repository '{repo_name}'
+    except Exception:
+        log.exception(f"Failed to create repository '{repo_name}'")
+    except Exception:
+        log.exception(f"Error checking repository '{repo_name}'")

    # Upload the config
    _save_config(root_config, trainer=trainer)
@@ -262,9 +263,9 @@ def _save_code(
        log.info(
            f"Uploaded snapshot contents to repository '{repo_name}' under 'code' folder."
        )
-    except Exception
+    except Exception:
        log.exception(
-            f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder
+            f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder."
        )

@@ -300,10 +301,8 @@ def _save_config(
            run_as_future=cast(Any, config.save_in_background),
        )
        log.info(f"Uploaded config.json to repository '{repo_name}'.")
-    except Exception
-        log.exception(
-            f"Failed to upload config.json to repository '{repo_name}': {str(e)}"
-        )
+    except Exception:
+        log.exception(f"Failed to upload config.json to repository '{repo_name}'.")


def _save_checkpoint_files(
@@ -331,17 +330,24 @@ def _save_checkpoint_files(
    # Read all the files to memory
    file_contents: list[bytes | None] = []
    for p in paths:
+        assert not p.is_symlink(), f"Path {p} is a symlink."
+        assert p.is_file(), f"Path {p} is not a file."
        try:
            with open(p, "rb") as f:
                file_contents.append(f.read())
-        except IOError
-            log.
+        except IOError:
+            log.exception(f"Failed to read checkpoint file {p}.")
            file_contents.append(None)

-
-
-
+    # Remove the paths that failed to read
+    file_contents_and_paths = [
+        (contents, p)
+        for contents, p in zip(file_contents, paths)
+        if contents is not None
+    ]

+    # Upload the checkpoint files to the repository
+    for contents, p in file_contents_and_paths:
        try:
            relative_path = p.relative_to(checkpoint_dir)
        except ValueError:
@@ -365,21 +371,136 @@ def _save_checkpoint_files(
            log.info(
                f"Uploaded checkpoint file {relative_path} to repository '{repo_name}'."
            )
-        except Exception
+        except Exception:
            log.exception(
-                f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}'
+                f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}'."
            )

    log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")


-
+def _save_checkpoint_symlinks(
+    trainer: "Trainer",
+    paths: list[Path],
+    *,
+    root_config: "BaseConfig",
+):
+    config = root_config.trainer.hf_hub
+    if (
+        api := _enabled_and_valid(trainer, config, rank_zero_only=True)
+    ) is None or not config.save_checkpoints:
+        return
+
+    # Resolve the checkpoint directory
+    checkpoint_dir = root_config.directory.resolve_subdirectory(
+        root_config.id, "checkpoint"
+    )
+
+    # Resolve the repository name
+    repo_name = _repo_name(api, root_config)
+
+    # Create a commit for copying the files
+    from huggingface_hub.hf_api import CommitOperation, CommitOperationCopy
+
+    commits: list[CommitOperation] = []
+    for p in paths:
+        assert p.is_symlink(), f"Path {p} is not a symlink."
+
+        try:
+            dest_relative_path = p.relative_to(checkpoint_dir)
+        except ValueError:
+            log.warning(
+                f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
+            )
+            continue
+
+        try:
+            source_relative_path = p.resolve().relative_to(checkpoint_dir)
+        except ValueError:
+            log.warning(
+                f"Checkpoint symlink target {p.resolve()} is not within the checkpoint directory {checkpoint_dir}."
+            )
+            continue
+
+        # Prefix the path in repo with "checkpoints"
+        dest_path_in_repo = Path("checkpoints") / dest_relative_path
+        source_path_in_repo = Path("checkpoints") / source_relative_path
+
+        # Create and append a CommitOperationCopy for copying the symlink
+        copy_op = CommitOperationCopy(
+            src_path_in_repo=str(source_path_in_repo),
+            path_in_repo=str(dest_path_in_repo),
+        )
+        commits.append(copy_op)
+
+    log.info(f"Creating a commit with the following operations: {commits}")
+
+    try:
+        api.create_commit(
+            repo_id=repo_name,
+            repo_type="model",
+            commit_message="Copy checkpoint symlinks",
+            operations=commits,
+            run_as_future=cast(Any, config.save_in_background),
+        )
+        log.info(
+            f"Created commit to copy checkpoint symlinks to repository '{repo_name}'."
+        )
+    except Exception:
+        log.exception(
+            f"Failed to create commit to copy checkpoint symlinks to repository '{repo_name}'"
+        )
+
+    log.info(f"Completed copying checkpoint symlinks to repository '{repo_name}'.")
+
+
+def _save_checkpoint_directory(trainer: "Trainer", *, root_config: "BaseConfig"):
+    config = root_config.trainer.hf_hub
+    if (
+        api := _enabled_and_valid(trainer, config, rank_zero_only=True)
+    ) is None or not config.save_checkpoints:
+        return
+
+    # Resolve the checkpoint directory
+    checkpoint_dir = root_config.directory.resolve_subdirectory(
+        root_config.id, "checkpoint"
+    )
+
+    # Resolve the repository name
+    repo_name = _repo_name(api, root_config)
+
+    # Upload the checkpoint directory to the repository
+    try:
+        api.upload_folder(
+            folder_path=str(checkpoint_dir),
+            repo_id=repo_name,
+            repo_type="model",
+            path_in_repo="checkpoints",
+            run_as_future=cast(Any, config.save_in_background),
+        )
+        log.info(f"Uploaded checkpoint directory to repository '{repo_name}'.")
+    except Exception:
+        log.exception(
+            f"Failed to upload checkpoint directory to repository '{repo_name}'."
+        )
+
+    log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
+
+
+class HFHubCallback(NTCallbackBase):
    def __init__(self, config: HuggingFaceHubConfig):
        super().__init__()
        self.config = config

    @override
    def setup(self, trainer, pl_module, stage):
+        from .trainer.trainer import Trainer
+
+        if not isinstance(trainer, Trainer):
+            raise ValueError(
+                f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
+            )
+
        root_config = cast("BaseConfig", pl_module.hparams)
        _init(trainer=trainer, root_config=root_config)

@@ -387,3 +508,18 @@ class HFHubCallback(Callback):
    def teardown(self, trainer, pl_module, stage):
        if hasattr(trainer, "_hf_hub_api"):
            delattr(trainer, "_hf_hub_api")
+
+    @override
+    def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
+        root_config = cast("BaseConfig", pl_module.hparams)
+
+        # If HF Hub is enabled, then we upload
+        if root_config.trainer.hf_hub and trainer.is_global_zero:
+            # Upload the regular files first, then the symlinks
+            all_paths = [p for p in (ckpt_path, metadata_path) if p is not None]
+            if regular_paths := [p for p in all_paths if not p.is_symlink()]:
+                _save_checkpoint_files(trainer, regular_paths, root_config=root_config)
+            if symlink_paths := [p for p in all_paths if p.is_symlink()]:
+                _save_checkpoint_symlinks(
+                    trainer, symlink_paths, root_config=root_config
+                )
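Symlinked checkpoints are no longer uploaded byte-for-byte; `_save_checkpoint_symlinks` turns each link into a server-side `CommitOperationCopy` pointing at the already-uploaded target. A minimal sketch of the operation built per symlink (the repo paths here are illustrative):

from huggingface_hub.hf_api import CommitOperationCopy

# "last.ckpt" is a symlink to the real epoch file, so the Hub copies the blob
# it already has instead of receiving a duplicate upload.
copy_op = CommitOperationCopy(
    src_path_in_repo="checkpoints/epoch=3-step=1000.ckpt",  # symlink target
    path_in_repo="checkpoints/last.ckpt",                   # symlink itself
)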
nshtrainer/callbacks/checkpoint/_base.py CHANGED
@@ -9,7 +9,7 @@ from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Checkpoint
from typing_extensions import TypeVar, override

-from ..._checkpoint.metadata import CheckpointMetadata
+from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
from ..base import CallbackConfigBase

@@ -65,8 +65,6 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
        self.dirpath.mkdir(parents=True, exist_ok=True)
        self.symlink_dirpath = dirpath

-        self._last_global_step_saved = 0
-
    @abstractmethod
    def default_filename(self) -> str: ...

@@ -144,9 +142,21 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
        if self._should_skip_saving_checkpoint(trainer):
            return

+        from ...trainer import Trainer as NTTrainer
+
+        if not isinstance(trainer, NTTrainer):
+            raise TypeError(
+                f"Trainer must be an instance of {NTTrainer.__name__}, "
+                f"but got {type(trainer).__name__}"
+            )
+
        # Save the new checkpoint
        filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
-        trainer.
+        trainer._nshtrainer_save_checkpoint(
+            filepath,
+            self.config.save_weights_only,
+            use_checkpoint_cache=None,
+        )

        if trainer.is_global_zero:
            # Create the latest symlink
@@ -162,8 +172,15 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
        # deleted the old checkpoints, and created the symlink before continuing
        trainer.strategy.barrier()

-        #
-        self.
+        # Call the on save checkpoint callback for the symlink (if it exists)
+        if (symlink_filename := self.symlink_path()) is not None:
+            from ... import _callback
+
+            symlink_path = self.dirpath / symlink_filename
+            symlink_metadata_path = _metadata_path(symlink_path)
+            _callback._call_on_checkpoint_saved(
+                trainer, symlink_path, symlink_metadata_path
+            )

    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
        from lightning.pytorch.trainer.states import TrainerFn
@@ -175,6 +192,4 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
            or trainer.state.fn
            != TrainerFn.FITTING  # don't save anything during non-fit
            or trainer.sanity_checking  # don't save anything during sanity check
-            or self._last_global_step_saved
-            == trainer.global_step  # already saved at the last step
        )
nshtrainer/model/config.py CHANGED
@@ -811,11 +811,14 @@ class SanityCheckingConfig(C.Config):


class TrainerConfig(C.Config):
-    ckpt_path: str | Path | None = None
-    """Path to a checkpoint to load and resume training from."""
+    ckpt_path: Literal["none"] | str | Path | None = None
+    """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""

-    checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
-    """Checkpoint loading configuration options.
+    checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
+    """Checkpoint loading configuration options.
+    `"auto"` will automatically determine the best checkpoint loading strategy based on the provided.
+    `"none"` will disable checkpoint loading.
+    """

    checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
    """Checkpoint saving configuration options."""
@@ -1009,6 +1012,8 @@ class TrainerConfig(C.Config):
    """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
    save_checkpoint_metadata: bool = True
    """If enabled, will save additional metadata whenever a checkpoint is saved."""
+    use_checkpoint_cache: bool = True
+    """If enabled, will optimize the saving of duplicate checkpoints by creating symlinks instead of copying the file."""

    lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
    """
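Taken together, `TrainerConfig` now exposes three related knobs: `ckpt_path="none"` and `checkpoint_loading="none"` to opt out of resuming, and `use_checkpoint_cache` for the symlink-based deduplication added in `trainer.py` below. A sketch, assuming the remaining `TrainerConfig` fields all have defaults (not verified here):

from nshtrainer.model.config import TrainerConfig

trainer_config = TrainerConfig(
    ckpt_path="none",            # do not resume from any checkpoint
    checkpoint_loading="none",   # skip all checkpoint loading strategies
    use_checkpoint_cache=False,  # always write files; never symlink duplicates
)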
nshtrainer/trainer/checkpoint_connector.py CHANGED
@@ -31,8 +31,14 @@ class _CheckpointConnector(_LightningCheckpointConnector):

        # Now, resolve the checkpoint loader config.
        root_config = cast("BaseConfig", trainer._base_module.config)
-
-
+        ckpt_loader_config = root_config.trainer.checkpoint_loading
+        match ckpt_loader_config:
+            case "auto":
+                ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
+            case "none":
+                ckpt_loader_config = CheckpointLoadingConfig.none()
+            case _:
+                pass
        log.debug(f"Checkpoint loader config: {ckpt_loader_config}")

        # Use the config to resolve the checkpoint.
nshtrainer/trainer/trainer.py CHANGED
@@ -17,6 +17,7 @@ from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
from typing_extensions import Unpack, assert_never, override

from .._checkpoint.metadata import _write_checkpoint_metadata
+from .._checkpoint.saver import _link_checkpoint
from ..callbacks.base import resolve_all_callbacks
from ..model.config import (
    AcceleratorConfigProtocol,
@@ -285,6 +286,8 @@ class Trainer(LightningTrainer):
        /,
        **kwargs: Unpack[LightningTrainerKwargs],
    ):
+        self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
+
        self._pre_init(config)

        kwargs = self._update_kwargs(config, kwargs)
@@ -407,29 +410,70 @@ class Trainer(LightningTrainer):

        return super()._run(model, ckpt_path)

-
-    def save_checkpoint(
+    def _nshtrainer_save_checkpoint(
        self,
        filepath: str | Path,
        weights_only: bool = False,
        storage_options: Any | None = None,
+        use_checkpoint_cache: bool | None = None,
    ):
+        lm = self._base_module
+        hparams = cast(BaseConfig, lm.hparams)
+        if use_checkpoint_cache is None:
+            use_checkpoint_cache = hparams.trainer.use_checkpoint_cache
+
        filepath = Path(filepath)
-
+
+        # List of files that we should upload to HF
+        written_files: list[Path] = [filepath]
+
+        cached_path = None
+        if (
+            use_checkpoint_cache
+            and (
+                cached_path := self._nshtrainer_checkpoint_cache.get(
+                    (self.current_epoch, self.global_step)
+                )
+            )
+            is not None
+        ):
+            # If we have a cached path, then we symlink it to the new path.
+            log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
+            _link_checkpoint(cached_path, filepath, metadata=False)
+        else:
+            super().save_checkpoint(filepath, weights_only, storage_options)
+
+        # If we are using the cache but we don't have a cached path, then we save the checkpoint to the cache.
+        if use_checkpoint_cache and cached_path is None:
+            self._nshtrainer_checkpoint_cache[
+                (self.current_epoch, self.global_step)
+            ] = filepath
+            log.debug(f"Checkpoint saved to cache: {filepath}")

        # Save the checkpoint metadata
-        lm = self._base_module
-        hparams = cast(BaseConfig, lm.hparams)
        metadata_path = None
        if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
            # Generate the metadata and write to disk
-
+            if (
+                metadata_path := _write_checkpoint_metadata(self, lm, filepath)
+            ) is not None:
+                written_files.append(metadata_path)

-        #
-
-        from .._hf_hub import _save_checkpoint_files
+        # Call the `on_checkpoint_saved` method on all callbacks
+        from .. import _callback

-
-        _save_checkpoint_files(self, files, root_config=hparams)
+        _callback._call_on_checkpoint_saved(self, filepath, metadata_path)

-
+    @override
+    def save_checkpoint(
+        self,
+        filepath: str | Path,
+        weights_only: bool = False,
+        storage_options: Any | None = None,
+    ):
+        return self._nshtrainer_save_checkpoint(
+            filepath=filepath,
+            weights_only=weights_only,
+            storage_options=storage_options,
+            use_checkpoint_cache=False,
+        )
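The checkpoint cache above is keyed by `(current_epoch, global_step)`: the first save at a given step writes a real file, and any further save requested at the same step is fulfilled with a symlink to it. A conceptual sketch of that bookkeeping (standalone names, not the package API):

from pathlib import Path

cache: dict[tuple[int, int], Path] = {}

def save_checkpoint(epoch: int, step: int, filepath: Path) -> None:
    if (cached := cache.get((epoch, step))) is not None:
        # Duplicate request for the same training step: reuse the existing file.
        filepath.symlink_to(cached)  # the real code uses a relative link target
        return
    filepath.write_bytes(b"<checkpoint>")  # stand-in for the actual save
    cache[(epoch, step)] = filepath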
nshtrainer/util/path.py ADDED
@@ -0,0 +1,29 @@
+import os
+from pathlib import Path
+from typing import TypeAlias
+
+_Path: TypeAlias = str | Path | os.PathLike
+
+
+def get_relative_path(source: _Path, destination: _Path):
+    # Get the absolute paths
+    source = os.path.abspath(source)
+    destination = os.path.abspath(destination)
+
+    # Split the paths into components
+    source_parts = source.split(os.sep)
+    destination_parts = destination.split(os.sep)
+
+    # Find the point where the paths diverge
+    i = 0
+    for i in range(min(len(source_parts), len(destination_parts))):
+        if source_parts[i] != destination_parts[i]:
+            break
+    else:
+        i += 1
+
+    # Build the relative path
+    up = os.sep.join([".." for _ in range(len(source_parts) - i - 1)])
+    down = os.sep.join(destination_parts[i:])
+
+    return Path(os.path.normpath(os.path.join(up, down)))
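`get_relative_path` walks both paths from the root, drops the shared prefix, and joins `..` segments with the remaining destination components. A usage sketch with illustrative paths:

from pathlib import Path

from nshtrainer.util.path import get_relative_path

a = Path("/runs/abc/checkpoint/best/model.ckpt")   # illustrative
b = Path("/runs/abc/checkpoint/all/epoch=5.ckpt")  # illustrative

# Relative to a's location, b is one directory up and under "all/":
print(get_relative_path(a, b))  # ../all/epoch=5.ckpt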
{nshtrainer-0.19.3.dist-info → nshtrainer-0.21.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: nshtrainer
-Version: 0.19.3
+Version: 0.21.0
Summary: 
Author: Nima Shoghi
Author-email: nimashoghi@gmail.com
@@ -26,6 +26,7 @@ Requires-Dist: torchmetrics ; extra == "extra"
Requires-Dist: typing-extensions
Requires-Dist: wandb ; extra == "extra"
Requires-Dist: wrapt ; extra == "extra"
+Requires-Dist: zstandard ; extra == "extra"
Description-Content-Type: text/markdown

{nshtrainer-0.19.3.dist-info → nshtrainer-0.21.0.dist-info}/RECORD CHANGED
@@ -1,15 +1,16 @@
nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
-nshtrainer/
-nshtrainer/_checkpoint/
-nshtrainer/_checkpoint/
+nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
+nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
+nshtrainer/_checkpoint/metadata.py,sha256=TLAt7yR3KhSRbXCtomLMxcMvOiAju873A1ZRo8VWNwA,5179
+nshtrainer/_checkpoint/saver.py,sha256=6W-Rbc3QGuhcF_mcwN8v31uEjLQCsZvt8CPuqPs4m5g,1342
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
-nshtrainer/_hf_hub.py,sha256=
+nshtrainer/_hf_hub.py,sha256=0bOhJNyIjQGJsMRaW7qQJc1oTnUMHj08auuztzTQvZ0,16906
nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
-nshtrainer/callbacks/checkpoint/_base.py,sha256=
+nshtrainer/callbacks/checkpoint/_base.py,sha256=r6IPpl3sGUmxBNv80y9r326lTrPAIVSU3Fu-3LrYH2s,6691
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -57,7 +58,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
-nshtrainer/model/config.py,sha256=
+nshtrainer/model/config.py,sha256=22_xIcdEO2pJzXgrFaqGFtk3PQEiwKiMZY1cjhoyWaA,43486
nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -75,16 +76,17 @@ nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
-nshtrainer/trainer/checkpoint_connector.py,sha256=
+nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
-nshtrainer/trainer/trainer.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=DNKA4mcW083i7qLk0fi3j5-Qj4KNBtlLuyIsNxykebw,19100
nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
+nshtrainer/util/path.py,sha256=A_Ocag3_hbwns_zAxFDlH-5eVHWFlcy2DKxHQ7jddvk,837
nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.
-nshtrainer-0.
-nshtrainer-0.
+nshtrainer-0.21.0.dist-info/METADATA,sha256=7QfSX_yXi-Up6uxOVFfDPn4ieGK5b3UgQfO_KFsNzXk,979
+nshtrainer-0.21.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.21.0.dist-info/RECORD,,

{nshtrainer-0.19.3.dist-info → nshtrainer-0.21.0.dist-info}/WHEEL
File without changes