nshtrainer 0.24.0__tar.gz → 0.26.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/PKG-INFO +2 -2
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/pyproject.toml +3 -10
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/_checkpoint/metadata.py +3 -1
- nshtrainer-0.26.0/src/nshtrainer/_hf_hub.py +353 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -1
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/gradient_skipping.py +1 -8
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/trainer/trainer.py +0 -11
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/path.py +18 -0
- nshtrainer-0.24.0/src/nshtrainer/_hf_hub.py +0 -518
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/README.md +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/config.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.24.0 → nshtrainer-0.26.0}/PKG-INFO
RENAMED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.24.0
+Version: 0.26.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -22,7 +22,7 @@ Requires-Dist: psutil
 Requires-Dist: pytorch-lightning
 Requires-Dist: tensorboard ; extra == "extra"
 Requires-Dist: torch
-Requires-Dist: torchmetrics ; extra == "extra"
+Requires-Dist: torchmetrics
 Requires-Dist: typing-extensions
 Requires-Dist: wandb ; extra == "extra"
 Requires-Dist: wrapt ; extra == "extra"
```
{nshtrainer-0.24.0 → nshtrainer-0.26.0}/pyproject.toml
RENAMED

```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.24.0"
+version = "0.26.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -17,7 +17,7 @@ typing-extensions = "*"
 packaging = "*"
 lightning = "*"
 pytorch-lightning = "*"
-torchmetrics = { version = "*", optional = true }
+torchmetrics = "*"
 wrapt = { version = "*", optional = true }
 GitPython = { version = "*", optional = true }
 wandb = { version = "*", optional = true }
@@ -46,11 +46,4 @@ reportPrivateImportUsage = false
 ignore = ["F722", "F821", "E731", "E741"]
 
 [tool.poetry.extras]
-extra = [
-    "torchmetrics",
-    "wrapt",
-    "GitPython",
-    "wandb",
-    "tensorboard",
-    "huggingface-hub",
-]
+extra = ["wrapt", "GitPython", "wandb", "tensorboard", "huggingface-hub"]
```
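The substantive dependency change: `torchmetrics` moves out of the optional `extra` group and into the core dependencies, while the remaining extras are unchanged. A hedged sketch of checking what landed where after an upgrade (import names are the canonical ones, which differ from some distribution names, e.g. GitPython imports as `git`):

```python
import importlib.util

# torchmetrics is a core dependency as of 0.26.0, so this always holds after
# a plain `pip install nshtrainer`:
assert importlib.util.find_spec("torchmetrics") is not None

# The rest of the "extra" group stays optional and requires
# `pip install "nshtrainer[extra]"`:
for module in ("wrapt", "git", "wandb", "tensorboard", "huggingface_hub"):
    print(module, "available:", importlib.util.find_spec(module) is not None)
```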
{nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/_checkpoint/metadata.py
RENAMED

```diff
@@ -11,7 +11,7 @@ import numpy as np
 import torch
 
 from ..util._environment_info import EnvironmentConfig
-from ..util.path import get_relative_path
+from ..util.path import compute_file_checksum, get_relative_path
 
 if TYPE_CHECKING:
     from ..model import BaseConfig, LightningModuleBase
@@ -28,6 +28,7 @@ class CheckpointMetadata(C.Config):
 
     checkpoint_path: Path
     checkpoint_filename: str
+    checkpoint_checksum: str
 
     run_id: str
     name: str
@@ -81,6 +82,7 @@ def _generate_checkpoint_metadata(
         # moving the checkpoint directory
         checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
+        checkpoint_checksum=compute_file_checksum(checkpoint_path),
         run_id=config.id,
         name=config.run_name,
         project=config.project,
```
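With the new `checkpoint_checksum` field, a checkpoint can be verified against its sidecar metadata after the fact. A minimal sketch, assuming a metadata file that nshtrainer wrote next to a checkpoint (the filename below is illustrative):

```python
from pathlib import Path

from nshtrainer._checkpoint.metadata import CheckpointMetadata
from nshtrainer.util.path import compute_file_checksum

metadata_path = Path("checkpoints/last.metadata.json")  # hypothetical path
metadata = CheckpointMetadata.from_file(metadata_path)

# `checkpoint_path` is stored relative to the metadata file's directory
# (see the diff above), so resolve it the same way before re-hashing:
ckpt_path = metadata_path.parent / metadata.checkpoint_path
assert compute_file_checksum(ckpt_path) == metadata.checkpoint_checksum
```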
nshtrainer-0.26.0/src/nshtrainer/_hf_hub.py
ADDED

```diff
@@ -0,0 +1,353 @@
+import contextlib
+import logging
+import os
+import re
+from dataclasses import dataclass
+from functools import cached_property
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, cast
+
+import nshconfig as C
+from nshrunner._env import SNAPSHOT_DIR
+from typing_extensions import override
+
+from ._callback import NTCallbackBase
+from .callbacks.base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from huggingface_hub import HfApi  # noqa: F401
+
+    from .model.base import BaseConfig
+
+
+log = logging.getLogger(__name__)
+
+
+class HuggingFaceHubAutoCreateConfig(C.Config):
+    enabled: bool = True
+    """Enable automatic repository creation on the Hugging Face Hub."""
+
+    private: bool = True
+    """Whether to create the repository as private."""
+
+    namespace: str | None = None
+    """The namespace to create the repository in. If `None`, the repository will be created in the user's namespace."""
+
+    def __bool__(self):
+        return self.enabled
+
+
+class HuggingFaceHubConfig(CallbackConfigBase):
+    """Configuration options for Hugging Face Hub integration."""
+
+    enabled: bool = False
+    """Enable Hugging Face Hub integration."""
+
+    token: str | None = None
+    """Hugging Face Hub API token. If `None`, the token will be read from the current environment.
+    This needs to either be set using `huggingface-cli login` or by setting the `HUGGINGFACE_TOKEN`
+    environment variable."""
+
+    auto_create: HuggingFaceHubAutoCreateConfig = HuggingFaceHubAutoCreateConfig()
+    """Automatic repository creation configuration options."""
+
+    save_config: bool = True
+    """Whether to save the model configuration to the Hugging Face Hub."""
+
+    save_checkpoints: bool = True
+    """Whether to save checkpoints to the Hugging Face Hub."""
+
+    save_code: bool = True
+    """Whether to save code to the Hugging Face Hub.
+    This is only supported if `nshsnap` is installed and snapshotting is enabled."""
+
+    save_in_background: bool = True
+    """Whether to save to the Hugging Face Hub in the background.
+    This corresponds to setting `run_as_future=True` in the HFApi upload methods."""
+
+    def enable_(self):
+        self.enabled = True
+        return self
+
+    def disable_(self):
+        self.enabled = False
+        return self
+
+    def __bool__(self):
+        return self.enabled
+
+    @override
+    def create_callbacks(self, root_config):
+        yield self.with_metadata(HFHubCallback(self), ignore_if_exists=True)
+
+
+def _api(token: str | None = None):
+    # Make sure that `huggingface_hub` is installed
+    try:
+        import huggingface_hub  # noqa: F401
+    except ImportError:
+        log.exception(
+            "Could not import `huggingface_hub`. Please install it using `pip install huggingface_hub`."
+        )
+        return None
+
+    # Create and authenticate the API instance
+    try:
+        api = huggingface_hub.HfApi(token=token)
+
+        # Verify authentication
+        api.whoami()
+    except Exception:
+        log.exception(
+            "Authentication failed for Hugging Face Hub. "
+            "Please make sure you are logged in using `huggingface-cli login`, "
+            "by setting the HUGGING_FACE_HUB_TOKEN environment variable, "
+            "or by providing a valid token in the configuration."
+        )
+        return None
+
+    return api
+
+
+def _repo_name(api: "HfApi", root_config: "BaseConfig"):
+    username = None
+    if (ac := root_config.trainer.hf_hub.auto_create) and ac.namespace:
+        username = ac.namespace
+    elif (username := api.whoami().get("name", None)) is None:
+        raise ValueError("Could not get username from Hugging Face Hub.")
+
+    # Sanitize the project (if it exists), run_name, and id
+    parts = []
+    if root_config.project:
+        parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.project))
+    parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.run_name))
+    parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.id))
+
+    # Combine parts and ensure it starts and ends with alphanumeric characters
+    repo_name = "-".join(parts)
+    repo_name = repo_name.strip("-")
+    repo_name = re.sub(
+        r"-+", "-", repo_name
+    )  # Replace multiple dashes with a single dash
+
+    # Ensure the name is not longer than 96 characters (excluding username)
+    if len(repo_name) > 96:
+        repo_name = repo_name[:96].rstrip("-")
+
+    # Ensure the repo name starts with an alphanumeric character
+    repo_name = re.sub(r"^[^a-zA-Z0-9]+", "", repo_name)
+
+    # If the repo_name is empty after all sanitization, use a default name
+    if not repo_name:
+        repo_name = "default-repo-name"
+
+    return f"{username}/{repo_name}"
+
+
+@dataclass
+class _Upload:
+    local_path: Path
+    path_in_repo: Path
+
+    @classmethod
+    def from_local_path(
+        cls,
+        local_path: Path,
+        root_config: "BaseConfig",
+    ):
+        # Resolve the checkpoint directory
+        checkpoint_dir = root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+
+        try:
+            relative_path = local_path.relative_to(checkpoint_dir)
+        except ValueError:
+            raise ValueError(
+                f"Checkpoint path {local_path} is not within the checkpoint directory {checkpoint_dir}."
+            )
+
+        # Prefix the path in repo with "checkpoints"
+        path_in_repo = Path("checkpoints") / relative_path
+
+        return cls(local_path=local_path, path_in_repo=path_in_repo)
+
+
+class HFHubCallback(NTCallbackBase):
+    @contextlib.contextmanager
+    def _with_error_handling(self, opeartion: str):
+        try:
+            yield
+        except Exception:
+            log.exception(f"Failed to {opeartion}, repo_id={self._repo_id}")
+        else:
+            log.debug(f"Successfully {opeartion}, repo_id={self._repo_id}")
+
+    def __init__(self, config: HuggingFaceHubConfig):
+        super().__init__()
+
+        self.config = config
+
+        self._repo_id = None
+        self._checksum_to_path_in_repo: dict[str, Path] = {}
+
+    @override
+    def setup(self, trainer, pl_module, stage):
+        from .trainer.trainer import Trainer
+
+        if not isinstance(trainer, Trainer):
+            raise ValueError(
+                f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
+            )
+
+        root_config = cast("BaseConfig", pl_module.hparams)
+
+        # Create the repository, if it doesn't exist
+        self._repo_id = self.api.create_repo(
+            repo_id=_repo_name(self.api, root_config),
+            repo_type="model",
+            private=self.config.auto_create.private,
+            exist_ok=True,
+        )
+
+        # Upload the config and code
+        self._save_config(root_config)
+        self._save_code()
+
+    @override
+    def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
+        root_config = cast("BaseConfig", pl_module.hparams)
+
+        # If HF Hub is enabled, then we upload
+        if self.config and trainer.is_global_zero:
+            with self._with_error_handling("save checkpoints"):
+                self._save_checkpoint(
+                    _Upload.from_local_path(ckpt_path, root_config),
+                    _Upload.from_local_path(metadata_path, root_config)
+                    if metadata_path is not None
+                    else None,
+                )
+
+    @cached_property
+    def api(self):
+        # Create and authenticate the API instance
+        if (api := _api(self.config.token)) is None:
+            raise ValueError("Failed to create Hugging Face Hub API instance.")
+        return api
+
+    @property
+    def repo_id(self):
+        if self._repo_id is None:
+            raise ValueError("Repository id has not been initialized.")
+        return self._repo_id
+
+    def _save_config(self, root_config: "BaseConfig"):
+        with self._with_error_handling("upload config"):
+            self.api.upload_file(
+                path_or_fileobj=root_config.model_dump_json(indent=4).encode("utf-8"),
+                path_in_repo="config.json",
+                repo_id=self.repo_id,
+                repo_type="model",
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _save_code(self):
+        # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
+        # then upload all contents within the snapshot directory to the repository.
+        if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
+            log.debug("No snapshot directory found. Skipping upload.")
+            return
+
+        with self._with_error_handling("save code"):
+            snapshot_dir = Path(snapshot_dir)
+            if not snapshot_dir.exists() or not snapshot_dir.is_dir():
+                log.warning(
+                    f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
+                )
+                return
+
+            self.api.upload_folder(
+                folder_path=str(snapshot_dir),
+                repo_id=self.repo_id,
+                repo_type="model",
+                path_in_repo="code",  # Prefix with "code" folder
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _save_file(self, p: _Upload):
+        with self._with_error_handling("save file"):
+            # Upload the checkpoint files to the repository
+            self.api.upload_file(
+                path_or_fileobj=p.local_path,
+                path_in_repo=str(p.path_in_repo),
+                repo_id=self.repo_id,
+                repo_type="model",
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _copy_file(self, source_path_in_repo: Path, dest_path_in_repo: Path):
+        # Create a commit for copying the files
+        from huggingface_hub.hf_api import CommitOperationCopy
+
+        with self._with_error_handling("copy file"):
+            copy_op = CommitOperationCopy(
+                src_path_in_repo=str(source_path_in_repo),
+                path_in_repo=str(dest_path_in_repo),
+            )
+
+            self.api.create_commit(
+                repo_id=self.repo_id,
+                repo_type="model",
+                commit_message="Copy checkpoint file",
+                operations=[copy_op],
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _save_checkpoint(self, path: _Upload, metadata_path: _Upload | None):
+        if not self.config.save_checkpoints:
+            return
+
+        # If no metadata, just save regularly.
+        if metadata_path is None:
+            self._save_file(path)
+            return
+
+        # Otherwise, let's check to see if we've already uploaded the metadata.
+        # If so, we can just copy the checkpoint file.
+        from ._checkpoint.metadata import CheckpointMetadata
+
+        metadata = CheckpointMetadata.from_file(metadata_path.local_path)
+        if (
+            existing_ckpt_path := self._checksum_to_path_in_repo.get(
+                metadata.checkpoint_checksum
+            )
+        ) is not None:
+            self._copy_file(existing_ckpt_path, path.path_in_repo)
+        else:
+            # Otherwise, we save the checkpoint & keep the checksum so we don't
+            # re-upload the same file again.
+            self._save_file(path)
+            self._checksum_to_path_in_repo[metadata.checkpoint_checksum] = (
+                path.path_in_repo
+            )
+
+        # Save the metadata file
+        # NOTE: This file is fairly small, so we can just upload it directly.
+        # No need to copy.
+        self._save_file(metadata_path)
+
+    @override
+    def state_dict(self):
+        return {
+            "repo_id": self._repo_id,
+            "checksum_to_path_in_repo": {
+                k: str(v) for k, v in self._checksum_to_path_in_repo.items()
+            },
+        }
+
+    @override
+    def load_state_dict(self, state_dict):
+        self._repo_id = state_dict["repo_id"]
+        self._checksum_to_path_in_repo = {
+            k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
+        }
```
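The rewritten module is driven entirely by `HuggingFaceHubConfig`: `create_callbacks` yields an `HFHubCallback`, `setup` creates the repository and uploads the config plus snapshotted code, and `_save_checkpoint` deduplicates uploads by the checksum recorded in the checkpoint metadata, turning byte-identical checkpoints into server-side copy commits instead of re-uploads. A minimal configuration sketch using only the fields defined above:

```python
from nshtrainer._hf_hub import HuggingFaceHubConfig

hf_hub = HuggingFaceHubConfig().enable_()  # integration is off by default
hf_hub.token = None                 # fall back to `huggingface-cli login` / env token
hf_hub.auto_create.private = True   # auto-created repos stay private (the default)
hf_hub.save_in_background = False   # upload synchronously instead of run_as_future

# The trainer consumes this through create_callbacks(root_config), which
# yields the HFHubCallback defined in this file.
```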
{nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
RENAMED
```diff
@@ -70,5 +70,5 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
 
     # Events
     @override
-    def
+    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
         self.save_checkpoints(trainer)
```
{nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
RENAMED
```diff
@@ -39,5 +39,5 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
         return True
 
     @override
-    def
+    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
         self.save_checkpoints(trainer)
```
{nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/gradient_skipping.py
RENAMED

```diff
@@ -1,8 +1,8 @@
-import importlib.util
 import logging
 from typing import Any, Literal, Protocol, runtime_checkable
 
 import torch
+import torchmetrics
 from lightning.pytorch import Callback, LightningModule, Trainer
 from torch.optim import Optimizer
 from typing_extensions import override
@@ -20,19 +20,12 @@ class HasGradSkippedSteps(Protocol):
 
 class GradientSkipping(Callback):
     def __init__(self, config: "GradientSkippingConfig"):
-        if importlib.util.find_spec("torchmetrics") is not None:
-            raise ImportError(
-                "To use the GradientSkipping callback, please install torchmetrics: pip install torchmetrics"
-            )
-
         super().__init__()
         self.config = config
 
     @override
     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
         if not isinstance(pl_module, HasGradSkippedSteps):
-            import torchmetrics  # type: ignore
-
             pl_module.grad_skipped_steps = torchmetrics.SumMetric()
 
     @override
```
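Worth noting about the removed guard: its condition was inverted — it raised `ImportError` when `torchmetrics` *was* installed — so making torchmetrics a hard dependency and importing it at module level fixes the callback as well as simplifying it. What survives is a plain running count of skipped optimizer steps; a toy sketch of the metric it attaches:

```python
import torchmetrics

# GradientSkipping assigns a SumMetric to the LightningModule as
# `grad_skipped_steps`; the metric is just a running sum.
grad_skipped_steps = torchmetrics.SumMetric()
grad_skipped_steps.update(1)  # one skipped step
grad_skipped_steps.update(1)  # another
print(grad_skipped_steps.compute())  # tensor(2.)
```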
{nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/trainer/trainer.py
RENAMED

```diff
@@ -280,13 +280,6 @@ class Trainer(LightningTrainer):
     if TYPE_CHECKING:
         callbacks: list[Callback]
 
-    def _nshtrainer_ckpt_link(self, ckpt_path: Path):
-        root_config = cast(BaseConfig, self._base_module.hparams)
-        ckpt_dir = root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-        return str(ckpt_path.absolute().relative_to(ckpt_dir))
-
     @override
     def __init__(
         self,
@@ -295,7 +288,6 @@ class Trainer(LightningTrainer):
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
         self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
-        self._nshtrainer_checkpoint_link_dict = dict[str, Path]()
 
         self._pre_init(config)
 
@@ -454,9 +446,6 @@ class Trainer(LightningTrainer):
                 _link_checkpoint(cached_path, filepath, metadata=False)
             else:
                 shutil.copy(cached_path, filepath)
-            self._nshtrainer_checkpoint_link_dict[
-                self._nshtrainer_ckpt_link(filepath)
-            ] = cached_path
             self.strategy.barrier("Trainer.save_checkpoint")
         else:
             super().save_checkpoint(filepath, weights_only, storage_options)
```
{nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/path.py
RENAMED

```diff
@@ -1,3 +1,4 @@
+import hashlib
 import os
 from pathlib import Path
 from typing import TypeAlias
@@ -50,3 +51,20 @@ def find_symlinks(
             pass
 
     return symlinks
+
+
+def compute_file_checksum(file_path: Path) -> str:
+    """
+    Calculate the SHA256 checksum of a file.
+
+    Args:
+        file_path (Path): The path to the file.
+
+    Returns:
+        str: The hexadecimal representation of the file's SHA256 checksum.
+    """
+    sha256_hash = hashlib.sha256()
+    with file_path.open("rb") as f:
+        for byte_block in iter(lambda: f.read(4096), b""):
+            sha256_hash.update(byte_block)
+    return sha256_hash.hexdigest()
```
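`compute_file_checksum` hashes the file in 4096-byte blocks, so large checkpoints never need to fit in memory, and the digest matches a one-shot hash of the same bytes. A small self-check:

```python
import hashlib
import tempfile
from pathlib import Path

from nshtrainer.util.path import compute_file_checksum

with tempfile.TemporaryDirectory() as tmp:
    p = Path(tmp) / "blob.bin"
    p.write_bytes(b"x" * 10_000)  # spans multiple 4096-byte blocks

    # Streaming digest equals the one-shot digest of the same contents.
    assert compute_file_checksum(p) == hashlib.sha256(p.read_bytes()).hexdigest()
```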