nshtrainer 0.20.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_callback.py +40 -0
- nshtrainer/_checkpoint/metadata.py +8 -4
- nshtrainer/_checkpoint/saver.py +6 -7
- nshtrainer/_hf_hub.py +160 -31
- nshtrainer/callbacks/base.py +22 -19
- nshtrainer/callbacks/checkpoint/_base.py +23 -8
- nshtrainer/model/config.py +2 -0
- nshtrainer/trainer/trainer.py +56 -12
- nshtrainer/util/path.py +29 -0
- {nshtrainer-0.20.0.dist-info → nshtrainer-0.22.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.20.0.dist-info → nshtrainer-0.22.0.dist-info}/RECORD +12 -10
- {nshtrainer-0.20.0.dist-info → nshtrainer-0.22.0.dist-info}/WHEEL +0 -0
nshtrainer/_callback.py
ADDED
@@ -0,0 +1,40 @@
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from lightning.pytorch.callbacks import Callback as _LightningCallback
+
+if TYPE_CHECKING:
+    from .model import LightningModuleBase
+    from .trainer import Trainer
+
+
+class NTCallbackBase(_LightningCallback):
+    def on_checkpoint_saved(
+        self,
+        ckpt_path: Path,
+        metadata_path: Path | None,
+        trainer: "Trainer",
+        pl_module: "LightningModuleBase",
+    ) -> None:
+        """Called after a checkpoint is saved."""
+        pass
+
+
+def _call_on_checkpoint_saved(
+    trainer: "Trainer",
+    ckpt_path: str | Path,
+    metadata_path: str | Path | None,
+):
+    ckpt_path = Path(ckpt_path)
+    metadata_path = Path(metadata_path) if metadata_path else None
+
+    for callback in trainer.callbacks:
+        if not isinstance(callback, NTCallbackBase):
+            continue
+
+        callback.on_checkpoint_saved(
+            ckpt_path,
+            metadata_path,
+            trainer,
+            trainer._base_module,
+        )
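The new `NTCallbackBase` hook is the extension point the rest of this release builds on. For orientation, a hypothetical downstream callback (the class name and prints are illustrative, not part of the package) would plug in like this:

    from nshtrainer._callback import NTCallbackBase


    class LogCheckpoints(NTCallbackBase):
        """Hypothetical callback: report every checkpoint the trainer writes."""

        def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
            # ckpt_path is always a resolved Path; metadata_path may be None
            # when metadata saving is disabled in the trainer config.
            print(f"checkpoint written: {ckpt_path}")
            if metadata_path is not None:
                print(f"  metadata sidecar: {metadata_path}")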
nshtrainer/_checkpoint/metadata.py
CHANGED
@@ -96,13 +96,17 @@ def _generate_checkpoint_metadata(
     )
 
 
+def _metadata_path(checkpoint_path: Path):
+    return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
+
+
 def _write_checkpoint_metadata(
     trainer: "Trainer",
     model: "LightningModuleBase",
     checkpoint_path: Path,
 ):
     config = cast("BaseConfig", model.config)
-    metadata_path = checkpoint_path
+    metadata_path = _metadata_path(checkpoint_path)
     metadata = _generate_checkpoint_metadata(
         config, trainer, checkpoint_path, metadata_path
     )
@@ -119,7 +123,7 @@ def _write_checkpoint_metadata(
 
 
 def _remove_checkpoint_metadata(checkpoint_path: Path):
-    path = checkpoint_path
+    path = _metadata_path(checkpoint_path)
     try:
         path.unlink(missing_ok=True)
     except Exception:
@@ -133,8 +137,8 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
     _remove_checkpoint_metadata(linked_checkpoint_path)
 
     # Link the metadata files to the new checkpoint
-    path = checkpoint_path
-    linked_path = linked_checkpoint_path
+    path = _metadata_path(checkpoint_path)
+    linked_path = _metadata_path(linked_checkpoint_path)
     try:
         try:
             # linked_path.symlink_to(path)
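Deriving the sidecar path via `Path.with_suffix` is simple but worth seeing concretely. The actual value of `CheckpointMetadata.PATH_SUFFIX` is not visible in this diff, so the suffix below is an assumed stand-in:

    from pathlib import Path

    # Assumed stand-in; the real CheckpointMetadata.PATH_SUFFIX is not shown here.
    PATH_SUFFIX = ".metadata.json"

    ckpt = Path("checkpoints/epoch=3-step=1200.ckpt")
    # with_suffix swaps the trailing ".ckpt" for the metadata suffix.
    print(ckpt.with_suffix(PATH_SUFFIX))
    # -> checkpoints/epoch=3-step=1200.metadata.json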
nshtrainer/_checkpoint/saver.py
CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
 
 from lightning.pytorch import Trainer
 
+from ..util.path import get_relative_path
 from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
 
 
@@ -14,10 +15,8 @@ def _link_checkpoint(
     metadata: bool,
     remove_existing: bool = True,
 ):
-    if not isinstance(filepath, Path):
-        filepath = Path(filepath)
-    if not isinstance(linkpath, Path):
-        linkpath = Path(linkpath)
+    filepath = Path(filepath)
+    linkpath = Path(linkpath)
 
     if remove_existing:
         if linkpath.exists():
@@ -30,7 +29,7 @@ def _link_checkpoint(
            _remove_checkpoint_metadata(linkpath)
 
     try:
-        linkpath.symlink_to(filepath)
+        linkpath.symlink_to(get_relative_path(linkpath, filepath))
    except OSError:
        # on Windows, special permissions are required to create symbolic links as a regular user
        # fall back to copying the file
@@ -46,9 +45,9 @@ def _remove_checkpoint(
     *,
     metadata: bool,
 ):
-
-    filepath = Path(filepath)
+    filepath = Path(filepath)
 
     trainer.strategy.remove_checkpoint(filepath)
+
     if metadata:
         _remove_checkpoint_metadata(filepath)
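The `get_relative_path` call makes `latest`-style links relative instead of absolute, so they keep resolving if the run directory is moved or re-mounted under a different root. A minimal sketch of the same idea using only the stdlib:

    import os
    import tempfile
    from pathlib import Path

    root = Path(tempfile.mkdtemp())
    (root / "ckpt").mkdir()
    target = root / "ckpt" / "epoch0.ckpt"
    target.write_bytes(b"payload")

    link = root / "ckpt" / "latest.ckpt"
    # Relative target ("epoch0.ckpt") instead of an absolute one: the link
    # stays valid even if the whole ckpt/ tree is relocated.
    link.symlink_to(os.path.relpath(target, link.parent))
    print(os.readlink(link))                   # -> epoch0.ckpt
    print(link.resolve() == target.resolve())  # -> True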
nshtrainer/_hf_hub.py
CHANGED
@@ -6,22 +6,19 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
 
 import nshconfig as C
-from lightning.pytorch import Callback
-from lightning.pytorch.trainer import Trainer
 from nshrunner._env import SNAPSHOT_DIR
 from typing_extensions import override
 
-from .callbacks.base import (
-    CallbackConfigBase,
-    CallbackMetadataConfig,
-    CallbackWithMetadata,
-)
+from ._callback import NTCallbackBase
+from .callbacks.base import CallbackConfigBase
 
 if TYPE_CHECKING:
     from huggingface_hub import HfApi  # noqa: F401
 
     from .model.base import BaseConfig
     from .trainer.trainer import Trainer
+
+
 log = logging.getLogger(__name__)
 
 
@@ -80,10 +77,7 @@ class HuggingFaceHubConfig(CallbackConfigBase):
 
     @override
     def create_callbacks(self, root_config):
-        yield CallbackWithMetadata(
-            HFHubCallback(self),
-            CallbackMetadataConfig(ignore_if_exists=True),
-        )
+        yield self.with_metadata(HFHubCallback(self), ignore_if_exists=True)
 
 
 def _api(token: str | None = None):
@@ -102,9 +96,9 @@ def _api(token: str | None = None):
 
         # Verify authentication
        api.whoami()
-    except Exception as e:
+    except Exception:
        log.exception(
-
+            "Authentication failed for Hugging Face Hub. "
            "Please make sure you are logged in using `huggingface-cli login`, "
            "by setting the HUGGING_FACE_HUB_TOKEN environment variable, "
            "or by providing a valid token in the configuration."
@@ -210,10 +204,10 @@ def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
                exist_ok=True,
            )
            log.info(f"Created new repository '{repo_name}'.")
-        except Exception as e:
-            log.exception(f"Failed to create repository '{repo_name}': {str(e)}")
-    except Exception as e:
-        log.exception(f"Error checking repository '{repo_name}': {str(e)}")
+        except Exception:
+            log.exception(f"Failed to create repository '{repo_name}'")
+    except Exception:
+        log.exception(f"Error checking repository '{repo_name}'")
 
     # Upload the config
     _save_config(root_config, trainer=trainer)
@@ -262,9 +256,9 @@ def _save_code(
        log.info(
            f"Uploaded snapshot contents to repository '{repo_name}' under 'code' folder."
        )
-    except Exception as e:
+    except Exception:
        log.exception(
-            f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder: {str(e)}"
+            f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder."
        )
 
 
@@ -300,10 +294,8 @@ def _save_config(
            run_as_future=cast(Any, config.save_in_background),
        )
        log.info(f"Uploaded config.json to repository '{repo_name}'.")
-    except Exception as e:
-        log.exception(
-            f"Failed to upload config.json to repository '{repo_name}': {str(e)}"
-        )
+    except Exception:
+        log.exception(f"Failed to upload config.json to repository '{repo_name}'.")
 
 
 def _save_checkpoint_files(
@@ -331,17 +323,24 @@ def _save_checkpoint_files(
     # Read all the files to memory
     file_contents: list[bytes | None] = []
     for p in paths:
+        assert not p.is_symlink(), f"Path {p} is a symlink."
+        assert p.is_file(), f"Path {p} is not a file."
        try:
            with open(p, "rb") as f:
                file_contents.append(f.read())
-        except IOError as e:
-            log.exception(f"Failed to read checkpoint file {p}: {str(e)}")
+        except IOError:
+            log.exception(f"Failed to read checkpoint file {p}.")
            file_contents.append(None)
 
-
-
-
+    # Remove the paths that failed to read
+    file_contents_and_paths = [
+        (contents, p)
+        for contents, p in zip(file_contents, paths)
+        if contents is not None
+    ]
 
+    # Upload the checkpoint files to the repository
+    for contents, p in file_contents_and_paths:
        try:
            relative_path = p.relative_to(checkpoint_dir)
        except ValueError:
@@ -365,21 +364,136 @@ def _save_checkpoint_files(
            log.info(
                f"Uploaded checkpoint file {relative_path} to repository '{repo_name}'."
            )
-        except Exception as e:
+        except Exception:
            log.exception(
-                f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}': {str(e)}"
+                f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}'."
            )
 
     log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
 
 
-class HFHubCallback(Callback):
+def _save_checkpoint_symlinks(
+    trainer: "Trainer",
+    paths: list[Path],
+    *,
+    root_config: "BaseConfig",
+):
+    config = root_config.trainer.hf_hub
+    if (
+        api := _enabled_and_valid(trainer, config, rank_zero_only=True)
+    ) is None or not config.save_checkpoints:
+        return
+
+    # Resolve the checkpoint directory
+    checkpoint_dir = root_config.directory.resolve_subdirectory(
+        root_config.id, "checkpoint"
+    )
+
+    # Resolve the repository name
+    repo_name = _repo_name(api, root_config)
+
+    # Create a commit for copying the files
+    from huggingface_hub.hf_api import CommitOperation, CommitOperationCopy
+
+    commits: list[CommitOperation] = []
+    for p in paths:
+        assert p.is_symlink(), f"Path {p} is not a symlink."
+
+        try:
+            dest_relative_path = p.relative_to(checkpoint_dir)
+        except ValueError:
+            log.warning(
+                f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
+            )
+            continue
+
+        try:
+            source_relative_path = p.resolve().relative_to(checkpoint_dir)
+        except ValueError:
+            log.warning(
+                f"Checkpoint symlink target {p.resolve()} is not within the checkpoint directory {checkpoint_dir}."
+            )
+            continue
+
+        # Prefix the path in repo with "checkpoints"
+        dest_path_in_repo = Path("checkpoints") / dest_relative_path
+        source_path_in_repo = Path("checkpoints") / source_relative_path
+
+        # Create and append a CommitOperationCopy for copying the symlink
+        copy_op = CommitOperationCopy(
+            src_path_in_repo=str(source_path_in_repo),
+            path_in_repo=str(dest_path_in_repo),
+        )
+        commits.append(copy_op)
+
+    log.info(f"Creating a commit with the following operations: {commits}")
+
+    try:
+        api.create_commit(
+            repo_id=repo_name,
+            repo_type="model",
+            commit_message="Copy checkpoint symlinks",
+            operations=commits,
+            run_as_future=cast(Any, config.save_in_background),
+        )
+        log.info(
+            f"Created commit to copy checkpoint symlinks to repository '{repo_name}'."
+        )
+    except Exception:
+        log.exception(
+            f"Failed to create commit to copy checkpoint symlinks to repository '{repo_name}'"
+        )
+
+    log.info(f"Completed copying checkpoint symlinks to repository '{repo_name}'.")
+
+
+def _save_checkpoint_directory(trainer: "Trainer", *, root_config: "BaseConfig"):
+    config = root_config.trainer.hf_hub
+    if (
+        api := _enabled_and_valid(trainer, config, rank_zero_only=True)
+    ) is None or not config.save_checkpoints:
+        return
+
+    # Resolve the checkpoint directory
+    checkpoint_dir = root_config.directory.resolve_subdirectory(
+        root_config.id, "checkpoint"
+    )
+
+    # Resolve the repository name
+    repo_name = _repo_name(api, root_config)
+
+    # Upload the checkpoint directory to the repository
+    try:
+        api.upload_folder(
+            folder_path=str(checkpoint_dir),
+            repo_id=repo_name,
+            repo_type="model",
+            path_in_repo="checkpoints",
+            run_as_future=cast(Any, config.save_in_background),
+        )
+        log.info(f"Uploaded checkpoint directory to repository '{repo_name}'.")
+    except Exception:
+        log.exception(
+            f"Failed to upload checkpoint directory to repository '{repo_name}'."
+        )
+
+    log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
+
+
+class HFHubCallback(NTCallbackBase):
     def __init__(self, config: HuggingFaceHubConfig):
         super().__init__()
         self.config = config
 
     @override
     def setup(self, trainer, pl_module, stage):
+        from .trainer.trainer import Trainer
+
+        if not isinstance(trainer, Trainer):
+            raise ValueError(
+                f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
+            )
+
         root_config = cast("BaseConfig", pl_module.hparams)
         _init(trainer=trainer, root_config=root_config)
 
@@ -387,3 +501,18 @@ class HFHubCallback(Callback):
     def teardown(self, trainer, pl_module, stage):
         if hasattr(trainer, "_hf_hub_api"):
             delattr(trainer, "_hf_hub_api")
+
+    @override
+    def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
+        root_config = cast("BaseConfig", pl_module.hparams)
+
+        # If HF Hub is enabled, then we upload
+        if root_config.trainer.hf_hub and trainer.is_global_zero:
+            # Upload the regular files first, then the symlinks
+            all_paths = [p for p in (ckpt_path, metadata_path) if p is not None]
+            if regular_paths := [p for p in all_paths if not p.is_symlink()]:
+                _save_checkpoint_files(trainer, regular_paths, root_config=root_config)
+            if symlink_paths := [p for p in all_paths if p.is_symlink()]:
+                _save_checkpoint_symlinks(
+                    trainer, symlink_paths, root_config=root_config
+                )
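Translating local symlinks into `CommitOperationCopy` operations means duplicate checkpoints are deduplicated server-side instead of being re-uploaded. A standalone sketch of that Hub call (the repo id and paths are placeholders; a valid token with write access is assumed):

    from huggingface_hub.hf_api import CommitOperationCopy
    from huggingface_hub import HfApi

    api = HfApi()
    # Copy a file that already exists in the repo rather than uploading it again.
    api.create_commit(
        repo_id="user/my-model",  # placeholder repo id
        repo_type="model",
        commit_message="Point latest.ckpt at the epoch3 checkpoint",
        operations=[
            CommitOperationCopy(
                src_path_in_repo="checkpoints/epoch3.ckpt",
                path_in_repo="checkpoints/latest.ckpt",
            )
        ],
    )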
nshtrainer/callbacks/base.py
CHANGED
@@ -2,29 +2,24 @@ from abc import ABC, abstractmethod
 from collections import Counter
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, TypeAlias
+from typing import TYPE_CHECKING, ClassVar, TypeAlias
 
 import nshconfig as C
 from lightning.pytorch import Callback
+from typing_extensions import TypedDict, Unpack
 
 if TYPE_CHECKING:
     from ..model.config import BaseConfig
 
 
-class
+class CallbackMetadataConfig(TypedDict, total=False):
     ignore_if_exists: bool
-    """If `True`, the callback will not be added if another callback with the same class already exists."""
+    """If `True`, the callback will not be added if another callback with the same class already exists.
+    Default is `False`."""
 
     priority: int
-    """Priority of the callback. Callbacks with higher priority will be loaded first."""
-
-
-class CallbackMetadataConfig(C.Config):
-    ignore_if_exists: bool = False
-    """If `True`, the callback will not be added if another callback with the same class already exists."""
-
-    priority: int = 0
-    """Priority of the callback. Callbacks with higher priority will be loaded first."""
+    """Priority of the callback. Callbacks with higher priority will be loaded first.
+    Default is `0`."""
 
 
 @dataclass(frozen=True)
@@ -37,13 +32,18 @@ ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
 
 
 class CallbackConfigBase(C.Config, ABC):
-    metadata: CallbackMetadataConfig = CallbackMetadataConfig()
+    metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig()
     """Metadata for the callback."""
 
-
-
-
-
+    @classmethod
+    def with_metadata(
+        cls, callback: Callback, **kwargs: Unpack[CallbackMetadataConfig]
+    ):
+        metadata: CallbackMetadataConfig = {}
+        metadata.update(cls.metadata)
+        metadata.update(kwargs)
+
+        return CallbackWithMetadata(callback=callback, metadata=metadata)
 
     @abstractmethod
     def create_callbacks(
@@ -73,7 +73,7 @@ def _filter_ignore_if_exists(callbacks: list[CallbackWithMetadata]):
     for callback in callbacks:
         # If `ignore_if_exists` is `True` and there is already a callback of the same class, skip this callback
         if (
-            callback.metadata.ignore_if_exists
+            callback.metadata.get("ignore_if_exists", False)
             and callback_classes[callback.callback.__class__] > 1
         ):
             continue
@@ -89,7 +89,10 @@ def _process_and_filter_callbacks(
     callbacks = list(callbacks)
 
     # Sort by priority (higher priority first)
-    callbacks.sort(key=lambda callback: callback.metadata.priority, reverse=True)
+    callbacks.sort(
+        key=lambda callback: callback.metadata.get("priority", 0),
+        reverse=True,
+    )
 
     # Process `ignore_if_exists`
     callbacks = _filter_ignore_if_exists(callbacks)
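Because `CallbackMetadataConfig` is now a `TypedDict`, metadata is a plain dict, which is why the consumers switch from attribute access to `.get(...)` with explicit defaults. A self-contained sketch of the merge-then-sort behavior `with_metadata` relies on:

    from typing import TypedDict


    class CallbackMetadataConfig(TypedDict, total=False):
        ignore_if_exists: bool
        priority: int


    class_default: CallbackMetadataConfig = {"priority": 10}
    call_kwargs: CallbackMetadataConfig = {"ignore_if_exists": True}

    # with_metadata() layers per-call overrides on top of class-level defaults.
    merged: CallbackMetadataConfig = {}
    merged.update(class_default)
    merged.update(call_kwargs)
    print(merged)  # {'priority': 10, 'ignore_if_exists': True}

    # Missing keys fall back to defaults via .get(), e.g. when sorting:
    items = [merged, CallbackMetadataConfig()]
    items.sort(key=lambda m: m.get("priority", 0), reverse=True)
    print([m.get("priority", 0) for m in items])  # [10, 0]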
nshtrainer/callbacks/checkpoint/_base.py
CHANGED
@@ -9,7 +9,7 @@ from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import Checkpoint
 from typing_extensions import TypeVar, override
 
-from ..._checkpoint.metadata import CheckpointMetadata
+from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
 from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
 from ..base import CallbackConfigBase
 
@@ -65,8 +65,6 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
         self.dirpath.mkdir(parents=True, exist_ok=True)
         self.symlink_dirpath = dirpath
 
-        self._last_global_step_saved = 0
-
     @abstractmethod
     def default_filename(self) -> str: ...
 
@@ -144,9 +142,21 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
         if self._should_skip_saving_checkpoint(trainer):
             return
 
+        from ...trainer import Trainer as NTTrainer
+
+        if not isinstance(trainer, NTTrainer):
+            raise TypeError(
+                f"Trainer must be an instance of {NTTrainer.__name__}, "
+                f"but got {type(trainer).__name__}"
+            )
+
         # Save the new checkpoint
         filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
-        trainer.save_checkpoint(filepath, self.config.save_weights_only)
+        trainer._nshtrainer_save_checkpoint(
+            filepath,
+            self.config.save_weights_only,
+            use_checkpoint_cache=None,
+        )
 
         if trainer.is_global_zero:
             # Create the latest symlink
@@ -162,8 +172,15 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
         # deleted the old checkpoints, and created the symlink before continuing
         trainer.strategy.barrier()
 
-        #
-        self.
+        # Call the on save checkpoint callback for the symlink (if it exists)
+        if (symlink_filename := self.symlink_path()) is not None:
+            from ... import _callback
+
+            symlink_path = self.dirpath / symlink_filename
+            symlink_metadata_path = _metadata_path(symlink_path)
+            _callback._call_on_checkpoint_saved(
+                trainer, symlink_path, symlink_metadata_path
+            )
 
     def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
         from lightning.pytorch.trainer.states import TrainerFn
@@ -175,6 +192,4 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
             or trainer.state.fn
             != TrainerFn.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
-            or self._last_global_step_saved
-            == trainer.global_step  # already saved at the last step
         )
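Since the save path now calls the nshtrainer-specific `_nshtrainer_save_checkpoint`, the callback rejects vanilla Lightning trainers up front. A sketch of that fail-fast guard in isolation (both stand-in classes are hypothetical):

    class NTTrainer:
        """Stand-in for nshtrainer.Trainer (illustrative only)."""


    class PlainTrainer:
        """Stand-in for a vanilla lightning.pytorch.Trainer."""


    def ensure_nshtrainer(trainer: object) -> NTTrainer:
        # Fail fast with a readable message instead of an AttributeError
        # later, when _nshtrainer_save_checkpoint would be looked up.
        if not isinstance(trainer, NTTrainer):
            raise TypeError(
                f"Trainer must be an instance of {NTTrainer.__name__}, "
                f"but got {type(trainer).__name__}"
            )
        return trainer


    ensure_nshtrainer(NTTrainer())  # ok
    try:
        ensure_nshtrainer(PlainTrainer())
    except TypeError as err:
        print(err)  # Trainer must be an instance of NTTrainer, but got PlainTrainer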
nshtrainer/model/config.py
CHANGED
@@ -1012,6 +1012,8 @@ class TrainerConfig(C.Config):
     """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
     save_checkpoint_metadata: bool = True
     """If enabled, will save additional metadata whenever a checkpoint is saved."""
+    use_checkpoint_cache: bool = True
+    """If enabled, will optimize the saving of duplicate checkpoints by creating symlinks instead of copying the file."""
 
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
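The new flag defaults to `True`, so existing configs opt in to the symlink optimization automatically. A small sketch of toggling it, using a trimmed stand-in for `TrainerConfig` (assuming nshconfig's pydantic-style keyword constructor):

    import nshconfig as C


    class TrainerConfig(C.Config):
        # Trimmed stand-in: only the two checkpoint flags from this diff.
        save_checkpoint_metadata: bool = True
        use_checkpoint_cache: bool = True


    cfg = TrainerConfig()
    print(cfg.use_checkpoint_cache)  # True by default

    # Opt out for a run, forcing every checkpoint callback to write a full file:
    cfg = TrainerConfig(use_checkpoint_cache=False)
    print(cfg.use_checkpoint_cache)  # False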
nshtrainer/trainer/trainer.py
CHANGED
@@ -17,6 +17,7 @@ from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
 from typing_extensions import Unpack, assert_never, override
 
 from .._checkpoint.metadata import _write_checkpoint_metadata
+from .._checkpoint.saver import _link_checkpoint
 from ..callbacks.base import resolve_all_callbacks
 from ..model.config import (
     AcceleratorConfigProtocol,
@@ -285,6 +286,8 @@ class Trainer(LightningTrainer):
         /,
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
+        self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
+
         self._pre_init(config)
 
         kwargs = self._update_kwargs(config, kwargs)
@@ -407,29 +410,70 @@ class Trainer(LightningTrainer):
 
         return super()._run(model, ckpt_path)
 
-    @override
-    def save_checkpoint(
+    def _nshtrainer_save_checkpoint(
         self,
         filepath: str | Path,
         weights_only: bool = False,
         storage_options: Any | None = None,
+        use_checkpoint_cache: bool | None = None,
     ):
+        lm = self._base_module
+        hparams = cast(BaseConfig, lm.hparams)
+        if use_checkpoint_cache is None:
+            use_checkpoint_cache = hparams.trainer.use_checkpoint_cache
+
         filepath = Path(filepath)
-        super().save_checkpoint(filepath, weights_only, storage_options)
+
+        # List of files that we should upload to HF
+        written_files: list[Path] = [filepath]
+
+        cached_path = None
+        if (
+            use_checkpoint_cache
+            and (
+                cached_path := self._nshtrainer_checkpoint_cache.get(
+                    (self.current_epoch, self.global_step)
+                )
+            )
+            is not None
+        ):
+            # If we have a cached path, then we symlink it to the new path.
+            log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
+            _link_checkpoint(cached_path, filepath, metadata=False)
+        else:
+            super().save_checkpoint(filepath, weights_only, storage_options)
+
+        # If we are using the cache but we don't have a cached path, then we save the checkpoint to the cache.
+        if use_checkpoint_cache and cached_path is None:
+            self._nshtrainer_checkpoint_cache[
+                (self.current_epoch, self.global_step)
+            ] = filepath
+            log.debug(f"Checkpoint saved to cache: {filepath}")
 
         # Save the checkpoint metadata
-        lm = self._base_module
-        hparams = cast(BaseConfig, lm.hparams)
         metadata_path = None
         if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
-            metadata_path = _write_checkpoint_metadata(self, lm, filepath)
+            if (
+                metadata_path := _write_checkpoint_metadata(self, lm, filepath)
+            ) is not None:
+                written_files.append(metadata_path)
 
-        #
-
-        from .._hf_hub import _save_checkpoint_files
+        # Call the `on_checkpoint_saved` method on all callbacks
+        from .. import _callback
 
-
-        _save_checkpoint_files(self, files, root_config=hparams)
+        _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
 
-
+    @override
+    def save_checkpoint(
+        self,
+        filepath: str | Path,
+        weights_only: bool = False,
+        storage_options: Any | None = None,
+    ):
+        return self._nshtrainer_save_checkpoint(
+            filepath=filepath,
+            weights_only=weights_only,
+            storage_options=storage_options,
+            use_checkpoint_cache=False,
+        )
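The cache keyed on `(current_epoch, global_step)` is what lets several checkpoint callbacks that fire at the same step (best, last, periodic) share one file on disk. A standalone sketch of the mechanism, with small files standing in for real checkpoints:

    import os
    from pathlib import Path

    cache: dict[tuple[int, int], Path] = {}


    def save_checkpoint(epoch: int, step: int, filepath: Path) -> None:
        key = (epoch, step)
        if (cached := cache.get(key)) is not None:
            # Same training state already on disk: create a relative symlink
            # instead of writing the (potentially multi-GB) file again.
            # (On Windows this may raise OSError; the real saver falls back
            # to copying in that case.)
            filepath.symlink_to(os.path.relpath(cached, filepath.parent))
            return
        filepath.write_bytes(b"fake checkpoint payload")  # stand-in for torch.save
        cache[key] = filepath


    ckpt_dir = Path("ckpts")
    ckpt_dir.mkdir(exist_ok=True)
    save_checkpoint(3, 1200, ckpt_dir / "best.ckpt")  # real write
    save_checkpoint(3, 1200, ckpt_dir / "last.ckpt")  # symlink to best.ckpt
    print((ckpt_dir / "last.ckpt").is_symlink())  # True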
nshtrainer/util/path.py
ADDED
@@ -0,0 +1,29 @@
+import os
+from pathlib import Path
+from typing import TypeAlias
+
+_Path: TypeAlias = str | Path | os.PathLike
+
+
+def get_relative_path(source: _Path, destination: _Path):
+    # Get the absolute paths
+    source = os.path.abspath(source)
+    destination = os.path.abspath(destination)
+
+    # Split the paths into components
+    source_parts = source.split(os.sep)
+    destination_parts = destination.split(os.sep)
+
+    # Find the point where the paths diverge
+    i = 0
+    for i in range(min(len(source_parts), len(destination_parts))):
+        if source_parts[i] != destination_parts[i]:
+            break
+    else:
+        i += 1
+
+    # Build the relative path
+    up = os.sep.join([".." for _ in range(len(source_parts) - i - 1)])
+    down = os.sep.join(destination_parts[i:])
+
+    return Path(os.path.normpath(os.path.join(up, down)))
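In effect, `get_relative_path(source, destination)` answers "how do I reach `destination` from the directory containing `source`", which is exactly what a symlink target needs. It matches the stdlib formulation below:

    import os

    source = "/runs/ckpt/last.ckpt"         # where the symlink will live
    destination = "/runs/ckpt/epoch3.ckpt"  # what it should point at

    # Equivalent: destination relative to source's parent directory.
    print(os.path.relpath(destination, os.path.dirname(source)))
    # -> epoch3.ckpt

    print(os.path.relpath("/runs/a/deep/file", os.path.dirname("/runs/b/link")))
    # -> ../a/deep/file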
{nshtrainer-0.20.0.dist-info → nshtrainer-0.22.0.dist-info}/RECORD
CHANGED
@@ -1,15 +1,16 @@
 nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
+nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
 nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
-nshtrainer/_checkpoint/metadata.py,sha256=…
-nshtrainer/_checkpoint/saver.py,sha256=…
+nshtrainer/_checkpoint/metadata.py,sha256=TLAt7yR3KhSRbXCtomLMxcMvOiAju873A1ZRo8VWNwA,5179
+nshtrainer/_checkpoint/saver.py,sha256=6W-Rbc3QGuhcF_mcwN8v31uEjLQCsZvt8CPuqPs4m5g,1342
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
-nshtrainer/_hf_hub.py,sha256=…
+nshtrainer/_hf_hub.py,sha256=iqhXH54RhSqmot_K3UCVcHTC_TC81_YY7cwvHGHXXlw,16782
 nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
-nshtrainer/callbacks/base.py,sha256=…
+nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
-nshtrainer/callbacks/checkpoint/_base.py,sha256=…
+nshtrainer/callbacks/checkpoint/_base.py,sha256=r6IPpl3sGUmxBNv80y9r326lTrPAIVSU3Fu-3LrYH2s,6691
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -57,7 +58,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
 nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
 nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
-nshtrainer/model/config.py,sha256=…
+nshtrainer/model/config.py,sha256=22_xIcdEO2pJzXgrFaqGFtk3PQEiwKiMZY1cjhoyWaA,43486
 nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
 nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -77,14 +78,15 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
 nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
 nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
-nshtrainer/trainer/trainer.py,sha256=…
+nshtrainer/trainer/trainer.py,sha256=DNKA4mcW083i7qLk0fi3j5-Qj4KNBtlLuyIsNxykebw,19100
 nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
 nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
+nshtrainer/util/path.py,sha256=A_Ocag3_hbwns_zAxFDlH-5eVHWFlcy2DKxHQ7jddvk,837
 nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.20.0.dist-info/METADATA,sha256=…
-nshtrainer-0.20.0.dist-info/WHEEL,sha256=…
-nshtrainer-0.20.0.dist-info/RECORD,,
+nshtrainer-0.22.0.dist-info/METADATA,sha256=sdjt9S4X3xiIGgD6FNF06yIyC1tJA89B9Qm9mxy29tc,935
+nshtrainer-0.22.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.22.0.dist-info/RECORD,,

{nshtrainer-0.20.0.dist-info → nshtrainer-0.22.0.dist-info}/WHEEL
File without changes
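For reference, each RECORD row is `path,sha256=<digest>,size`, where the digest is the unpadded urlsafe-base64 SHA-256 of the file (per PEP 376/PEP 427). The new hashes above can be reproduced from an unpacked wheel:

    import base64
    import hashlib
    from pathlib import Path


    def record_entry(path: str) -> str:
        data = Path(path).read_bytes()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest())
        # RECORD strips the base64 padding ('=') from the digest.
        return f"{path},sha256={digest.rstrip(b'=').decode()},{len(data)}"


    # e.g. run from an unpacked nshtrainer-0.22.0 wheel:
    print(record_entry("nshtrainer/_callback.py"))
    # -> nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013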