nshtrainer 0.23.0__tar.gz → 0.25.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/PKG-INFO +1 -1
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/pyproject.toml +1 -1
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/_checkpoint/metadata.py +3 -1
- nshtrainer-0.25.0/src/nshtrainer/_hf_hub.py +353 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/_base.py +2 -40
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -1
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/trainer/trainer.py +9 -4
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/path.py +18 -0
- nshtrainer-0.23.0/src/nshtrainer/_hf_hub.py +0 -518
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/README.md +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/config.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/typing_utils.py +0 -0
|
@@ -11,7 +11,7 @@ import numpy as np
|
|
|
11
11
|
import torch
|
|
12
12
|
|
|
13
13
|
from ..util._environment_info import EnvironmentConfig
|
|
14
|
-
from ..util.path import get_relative_path
|
|
14
|
+
from ..util.path import compute_file_checksum, get_relative_path
|
|
15
15
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
17
17
|
from ..model import BaseConfig, LightningModuleBase
|
|
@@ -28,6 +28,7 @@ class CheckpointMetadata(C.Config):
|
|
|
28
28
|
|
|
29
29
|
checkpoint_path: Path
|
|
30
30
|
checkpoint_filename: str
|
|
31
|
+
checkpoint_checksum: str
|
|
31
32
|
|
|
32
33
|
run_id: str
|
|
33
34
|
name: str
|
|
@@ -81,6 +82,7 @@ def _generate_checkpoint_metadata(
|
|
|
81
82
|
# moving the checkpoint directory
|
|
82
83
|
checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
|
|
83
84
|
checkpoint_filename=checkpoint_path.name,
|
|
85
|
+
checkpoint_checksum=compute_file_checksum(checkpoint_path),
|
|
84
86
|
run_id=config.id,
|
|
85
87
|
name=config.run_name,
|
|
86
88
|
project=config.project,
|
|
@@ -0,0 +1,353 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from functools import cached_property
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
9
|
+
|
|
10
|
+
import nshconfig as C
|
|
11
|
+
from nshrunner._env import SNAPSHOT_DIR
|
|
12
|
+
from typing_extensions import override
|
|
13
|
+
|
|
14
|
+
from ._callback import NTCallbackBase
|
|
15
|
+
from .callbacks.base import CallbackConfigBase
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from huggingface_hub import HfApi # noqa: F401
|
|
19
|
+
|
|
20
|
+
from .model.base import BaseConfig
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
log = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class HuggingFaceHubAutoCreateConfig(C.Config):
    """Options governing automatic creation of the Hugging Face Hub repository.

    Instances are truthy exactly when `enabled` is set, so callers can
    write `if config.auto_create:`.
    """

    enabled: bool = True
    """Enable automatic repository creation on the Hugging Face Hub."""

    private: bool = True
    """Whether to create the repository as private."""

    namespace: str | None = None
    """The namespace to create the repository in. If `None`, the repository will be created in the user's namespace."""

    def __bool__(self) -> bool:
        # Truthiness is delegated to the `enabled` flag.
        return self.enabled
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class HuggingFaceHubConfig(CallbackConfigBase):
    """Configuration options for Hugging Face Hub integration.

    Instances are truthy exactly when `enabled` is set.
    """

    enabled: bool = False
    """Enable Hugging Face Hub integration."""

    token: str | None = None
    """Hugging Face Hub API token. If `None`, the token will be read from the current environment.
    This needs to either be set using `huggingface-cli login` or by setting the `HF_TOKEN`
    environment variable (`HUGGING_FACE_HUB_TOKEN` is the legacy name for the same variable)."""

    auto_create: HuggingFaceHubAutoCreateConfig = HuggingFaceHubAutoCreateConfig()
    """Automatic repository creation configuration options."""

    save_config: bool = True
    """Whether to save the model configuration to the Hugging Face Hub."""

    save_checkpoints: bool = True
    """Whether to save checkpoints to the Hugging Face Hub."""

    save_code: bool = True
    """Whether to save code to the Hugging Face Hub.
    This is only supported if `nshsnap` is installed and snapshotting is enabled."""

    save_in_background: bool = True
    """Whether to save to the Hugging Face Hub in the background.
    This corresponds to setting `run_as_future=True` in the HFApi upload methods."""

    def enable_(self):
        """Turn the integration on in place and return `self` for chaining."""
        self.enabled = True
        return self

    def disable_(self):
        """Turn the integration off in place and return `self` for chaining."""
        self.enabled = False
        return self

    def __bool__(self) -> bool:
        # Truthiness is delegated to the `enabled` flag.
        return self.enabled

    @override
    def create_callbacks(self, root_config):
        """Yield the `HFHubCallback` built from this configuration.

        `root_config` is accepted to satisfy the base-class signature and is
        not used here.
        """
        yield self.with_metadata(HFHubCallback(self), ignore_if_exists=True)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _api(token: str | None = None):
    """Return an authenticated `huggingface_hub.HfApi`, or `None` on failure.

    Two things can go wrong, and both are logged (with traceback) instead of
    being raised: `huggingface_hub` is not importable, or the `whoami()`
    call used to verify the credentials fails.

    Args:
        token: Token handed to `HfApi`; `None` leaves token discovery to
            `huggingface_hub`.
    """
    # Step 1: the optional dependency has to be importable.
    try:
        import huggingface_hub  # noqa: F401
    except ImportError:
        log.exception(
            "Could not import `huggingface_hub`. Please install it using `pip install huggingface_hub`."
        )
        return None

    # Step 2: build the client and make one authenticated request so that bad
    # credentials surface here rather than on the first upload.
    auth_failure_message = (
        "Authentication failed for Hugging Face Hub. "
        "Please make sure you are logged in using `huggingface-cli login`, "
        "by setting the HUGGING_FACE_HUB_TOKEN environment variable, "
        "or by providing a valid token in the configuration."
    )
    try:
        client = huggingface_hub.HfApi(token=token)
        client.whoami()
    except Exception:
        log.exception(auth_failure_message)
        return None
    else:
        return client
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _repo_name(api: "HfApi", root_config: "BaseConfig"):
|
|
113
|
+
username = None
|
|
114
|
+
if (ac := root_config.trainer.hf_hub.auto_create) and ac.namespace:
|
|
115
|
+
username = ac.namespace
|
|
116
|
+
elif (username := api.whoami().get("name", None)) is None:
|
|
117
|
+
raise ValueError("Could not get username from Hugging Face Hub.")
|
|
118
|
+
|
|
119
|
+
# Sanitize the project (if it exists), run_name, and id
|
|
120
|
+
parts = []
|
|
121
|
+
if root_config.project:
|
|
122
|
+
parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.project))
|
|
123
|
+
parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.run_name))
|
|
124
|
+
parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.id))
|
|
125
|
+
|
|
126
|
+
# Combine parts and ensure it starts and ends with alphanumeric characters
|
|
127
|
+
repo_name = "-".join(parts)
|
|
128
|
+
repo_name = repo_name.strip("-")
|
|
129
|
+
repo_name = re.sub(
|
|
130
|
+
r"-+", "-", repo_name
|
|
131
|
+
) # Replace multiple dashes with a single dash
|
|
132
|
+
|
|
133
|
+
# Ensure the name is not longer than 96 characters (excluding username)
|
|
134
|
+
if len(repo_name) > 96:
|
|
135
|
+
repo_name = repo_name[:96].rstrip("-")
|
|
136
|
+
|
|
137
|
+
# Ensure the repo name starts with an alphanumeric character
|
|
138
|
+
repo_name = re.sub(r"^[^a-zA-Z0-9]+", "", repo_name)
|
|
139
|
+
|
|
140
|
+
# If the repo_name is empty after all sanitization, use a default name
|
|
141
|
+
if not repo_name:
|
|
142
|
+
repo_name = "default-repo-name"
|
|
143
|
+
|
|
144
|
+
return f"{username}/{repo_name}"
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@dataclass
|
|
148
|
+
class _Upload:
|
|
149
|
+
local_path: Path
|
|
150
|
+
path_in_repo: Path
|
|
151
|
+
|
|
152
|
+
@classmethod
|
|
153
|
+
def from_local_path(
|
|
154
|
+
cls,
|
|
155
|
+
local_path: Path,
|
|
156
|
+
root_config: "BaseConfig",
|
|
157
|
+
):
|
|
158
|
+
# Resolve the checkpoint directory
|
|
159
|
+
checkpoint_dir = root_config.directory.resolve_subdirectory(
|
|
160
|
+
root_config.id, "checkpoint"
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
try:
|
|
164
|
+
relative_path = local_path.relative_to(checkpoint_dir)
|
|
165
|
+
except ValueError:
|
|
166
|
+
raise ValueError(
|
|
167
|
+
f"Checkpoint path {local_path} is not within the checkpoint directory {checkpoint_dir}."
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
# Prefix the path in repo with "checkpoints"
|
|
171
|
+
path_in_repo = Path("checkpoints") / relative_path
|
|
172
|
+
|
|
173
|
+
return cls(local_path=local_path, path_in_repo=path_in_repo)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class HFHubCallback(NTCallbackBase):
    """Callback that mirrors a run's artifacts to a Hugging Face Hub model repo.

    During `setup` the repository id is resolved (and the repository created
    when auto-creation is enabled), then the run configuration and the code
    snapshot are uploaded. Each saved checkpoint is uploaded afterwards; a
    checkpoint whose checksum was already uploaded is copied inside the
    repository instead of being uploaded again.

    Upload failures are logged and swallowed by `_with_error_handling`.
    """

    @contextlib.contextmanager
    def _with_error_handling(self, operation: str):
        """Log any exception raised in the `with` body instead of propagating it.

        Args:
            operation: Short verb phrase used in the log messages
                (e.g. "upload config").
        """
        try:
            yield
        except Exception:
            log.exception(f"Failed to {operation}, repo_id={self._repo_id}")
        else:
            # NOTE(review): with `save_in_background` enabled the calls are made
            # with `run_as_future=True`, so reaching this branch presumably only
            # means the job was submitted, not that it finished — confirm.
            log.debug(f"Successfully {operation}, repo_id={self._repo_id}")

    def __init__(self, config: HuggingFaceHubConfig):
        super().__init__()

        self.config = config

        # Repository id in "namespace/name" form; populated in `setup` or
        # restored by `load_state_dict`.
        self._repo_id: str | None = None
        # Checkpoint checksum -> repo path of the copy that was uploaded.
        self._checksum_to_path_in_repo: dict[str, Path] = {}

    @override
    def setup(self, trainer, pl_module, stage):
        """Resolve the repository and upload the config and code snapshot.

        Raises:
            ValueError: If `trainer` is not an `nshtrainer` `Trainer`, or if
                the Hub API client cannot be created.
        """
        from .trainer.trainer import Trainer

        if not isinstance(trainer, Trainer):
            raise ValueError(
                f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
            )

        # Mirror the guard used in `on_checkpoint_saved`: do nothing when the
        # integration is disabled, and only let rank zero talk to the Hub so
        # the config/code are not uploaded once per process.
        if not self.config or not trainer.is_global_zero:
            return

        root_config = cast("BaseConfig", pl_module.hparams)
        repo_name = _repo_name(self.api, root_config)

        if self.config.auto_create:
            # Create the repository, if it doesn't exist
            repo_url = self.api.create_repo(
                repo_id=repo_name,
                repo_type="model",
                private=self.config.auto_create.private,
                exist_ok=True,
            )
            # `create_repo` returns a `RepoUrl` (a `str` holding the full URL);
            # the upload APIs need the "namespace/name" id instead.
            self._repo_id = repo_url.repo_id
        else:
            # Auto-creation is disabled: use the computed id and assume the
            # repository already exists.
            self._repo_id = repo_name

        # Upload the config and code
        self._save_config(root_config)
        self._save_code()

    @override
    def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
        """Upload a freshly saved checkpoint (and its metadata file, if any)."""
        root_config = cast("BaseConfig", pl_module.hparams)

        # If HF Hub is enabled, then we upload
        if self.config and trainer.is_global_zero:
            with self._with_error_handling("save checkpoints"):
                self._save_checkpoint(
                    _Upload.from_local_path(ckpt_path, root_config),
                    _Upload.from_local_path(metadata_path, root_config)
                    if metadata_path is not None
                    else None,
                )

    @cached_property
    def api(self):
        """Authenticated Hub API client, created on first access.

        Raises:
            ValueError: If `_api` could not build a client.
        """
        # Create and authenticate the API instance
        if (api := _api(self.config.token)) is None:
            raise ValueError("Failed to create Hugging Face Hub API instance.")
        return api

    @property
    def repo_id(self):
        """Repository id, raising `ValueError` if it has not been set yet."""
        if self._repo_id is None:
            raise ValueError("Repository id has not been initialized.")
        return self._repo_id

    def _save_config(self, root_config: "BaseConfig"):
        """Upload the run configuration as `config.json` (if `save_config` is set)."""
        if not self.config.save_config:
            return

        with self._with_error_handling("upload config"):
            self.api.upload_file(
                path_or_fileobj=root_config.model_dump_json(indent=4).encode("utf-8"),
                path_in_repo="config.json",
                repo_id=self.repo_id,
                repo_type="model",
                run_as_future=cast(Any, self.config.save_in_background),
            )

    def _save_code(self):
        """Upload the code snapshot directory under `code/` (if `save_code` is set)."""
        if not self.config.save_code:
            return

        # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
        # then upload all contents within the snapshot directory to the repository.
        if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
            log.debug("No snapshot directory found. Skipping upload.")
            return

        with self._with_error_handling("save code"):
            snapshot_dir = Path(snapshot_dir)
            if not snapshot_dir.exists() or not snapshot_dir.is_dir():
                log.warning(
                    f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
                )
                return

            self.api.upload_folder(
                folder_path=str(snapshot_dir),
                repo_id=self.repo_id,
                repo_type="model",
                path_in_repo="code",  # Prefix with "code" folder
                run_as_future=cast(Any, self.config.save_in_background),
            )

    def _save_file(self, p: _Upload):
        """Upload a single local file to its destination path in the repo."""
        with self._with_error_handling("save file"):
            # Upload the checkpoint files to the repository
            self.api.upload_file(
                path_or_fileobj=p.local_path,
                path_in_repo=str(p.path_in_repo),
                repo_id=self.repo_id,
                repo_type="model",
                run_as_future=cast(Any, self.config.save_in_background),
            )

    def _copy_file(self, source_path_in_repo: Path, dest_path_in_repo: Path):
        """Copy a file that already exists in the repo to a new repo path."""
        # Create a commit for copying the files
        from huggingface_hub.hf_api import CommitOperationCopy

        with self._with_error_handling("copy file"):
            copy_op = CommitOperationCopy(
                src_path_in_repo=str(source_path_in_repo),
                path_in_repo=str(dest_path_in_repo),
            )

            self.api.create_commit(
                repo_id=self.repo_id,
                repo_type="model",
                commit_message="Copy checkpoint file",
                operations=[copy_op],
                run_as_future=cast(Any, self.config.save_in_background),
            )

    def _save_checkpoint(self, path: _Upload, metadata_path: _Upload | None):
        """Upload a checkpoint, de-duplicating by checksum when metadata exists.

        Args:
            path: The checkpoint file to upload.
            metadata_path: The checkpoint's metadata file, or `None` when no
                metadata was written.
        """
        if not self.config.save_checkpoints:
            return

        # If no metadata, just save regularly.
        if metadata_path is None:
            self._save_file(path)
            return

        # Otherwise, let's check to see if we've already uploaded the metadata.
        # If so, we can just copy the checkpoint file.
        from ._checkpoint.metadata import CheckpointMetadata

        metadata = CheckpointMetadata.from_file(metadata_path.local_path)
        if (
            existing_ckpt_path := self._checksum_to_path_in_repo.get(
                metadata.checkpoint_checksum
            )
        ) is not None:
            self._copy_file(existing_ckpt_path, path.path_in_repo)
        else:
            # Otherwise, we save the checkpoint & keep the checksum so we don't
            # re-upload the same file again.
            self._save_file(path)
            self._checksum_to_path_in_repo[metadata.checkpoint_checksum] = (
                path.path_in_repo
            )

        # Save the metadata file
        # NOTE: This file is fairly small, so we can just upload it directly.
        # No need to copy.
        self._save_file(metadata_path)

    @override
    def state_dict(self):
        """Return the repo id and the checksum map, with paths as strings."""
        return {
            "repo_id": self._repo_id,
            "checksum_to_path_in_repo": {
                k: str(v) for k, v in self._checksum_to_path_in_repo.items()
            },
        }

    @override
    def load_state_dict(self, state_dict):
        """Restore the repo id and the checksum map written by `state_dict`."""
        self._repo_id = state_dict["repo_id"]
        self._checksum_to_path_in_repo = {
            k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
        }
|
|
@@ -11,7 +11,6 @@ from typing_extensions import TypeVar, override
|
|
|
11
11
|
|
|
12
12
|
from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
|
|
13
13
|
from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
|
|
14
|
-
from ...util.path import find_symlinks
|
|
15
14
|
from ..base import CallbackConfigBase
|
|
16
15
|
|
|
17
16
|
if TYPE_CHECKING:
|
|
@@ -117,47 +116,9 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
117
116
|
)
|
|
118
117
|
continue
|
|
119
118
|
|
|
120
|
-
|
|
121
|
-
trainer, old_ckpt_path, metadata=True
|
|
122
|
-
)
|
|
119
|
+
_remove_checkpoint(trainer, old_ckpt_path, metadata=True)
|
|
123
120
|
log.debug(f"Removed old checkpoint: {old_ckpt_path}")
|
|
124
121
|
|
|
125
|
-
def _remove_checkpoint_with_link_support(
|
|
126
|
-
self,
|
|
127
|
-
trainer: Trainer,
|
|
128
|
-
path: Path,
|
|
129
|
-
metadata: bool,
|
|
130
|
-
):
|
|
131
|
-
# Find all the symlinks to the checkpoint
|
|
132
|
-
ckpt_callbacks: list[CheckpointBase] = [
|
|
133
|
-
callback
|
|
134
|
-
for callback in trainer.checkpoint_callbacks
|
|
135
|
-
if isinstance(callback, CheckpointBase) and callback is not self
|
|
136
|
-
]
|
|
137
|
-
symlink_paths = find_symlinks(
|
|
138
|
-
path,
|
|
139
|
-
*[callback.dirpath for callback in ckpt_callbacks],
|
|
140
|
-
glob_pattern=f"*.{self.extension()}",
|
|
141
|
-
)
|
|
142
|
-
|
|
143
|
-
# If there are no symlinks, just remove the checkpoint
|
|
144
|
-
if not symlink_paths:
|
|
145
|
-
_remove_checkpoint(trainer, path, metadata=metadata)
|
|
146
|
-
return
|
|
147
|
-
|
|
148
|
-
log.debug(
|
|
149
|
-
f"Removing checkpoint with symlinks: {path}, symlinks: {symlink_paths}"
|
|
150
|
-
)
|
|
151
|
-
|
|
152
|
-
# For the first symlink, we can just move the checkpoint file
|
|
153
|
-
# to the symlink path. For the rest, we need to make new symlinks.
|
|
154
|
-
new_target = symlink_paths.pop(0)
|
|
155
|
-
path.rename(new_target)
|
|
156
|
-
log.debug(f"New symlink target: {new_target}")
|
|
157
|
-
|
|
158
|
-
for symlink_path in symlink_paths:
|
|
159
|
-
_link_checkpoint(new_target, symlink_path, metadata=False)
|
|
160
|
-
|
|
161
122
|
def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
|
|
162
123
|
current_metrics: dict[str, Any] = {
|
|
163
124
|
"epoch": trainer.current_epoch,
|
|
@@ -195,6 +156,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
195
156
|
filepath,
|
|
196
157
|
self.config.save_weights_only,
|
|
197
158
|
use_checkpoint_cache=None,
|
|
159
|
+
ckpt_cache_use_symlink=False,
|
|
198
160
|
)
|
|
199
161
|
|
|
200
162
|
if trainer.is_global_zero:
|
{nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
RENAMED
|
@@ -70,5 +70,5 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
|
|
|
70
70
|
|
|
71
71
|
# Events
|
|
72
72
|
@override
|
|
73
|
-
def
|
|
73
|
+
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
|
|
74
74
|
self.save_checkpoints(trainer)
|
{nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
RENAMED
|
@@ -39,5 +39,5 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
|
|
|
39
39
|
return True
|
|
40
40
|
|
|
41
41
|
@override
|
|
42
|
-
def
|
|
42
|
+
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
|
|
43
43
|
self.save_checkpoints(trainer)
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
+
import shutil
|
|
3
4
|
from collections.abc import Sequence
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from typing import TYPE_CHECKING, Any, cast
|
|
@@ -416,11 +417,12 @@ class Trainer(LightningTrainer):
|
|
|
416
417
|
weights_only: bool = False,
|
|
417
418
|
storage_options: Any | None = None,
|
|
418
419
|
use_checkpoint_cache: bool | None = None,
|
|
420
|
+
ckpt_cache_use_symlink: bool = False,
|
|
419
421
|
):
|
|
420
422
|
lm = self._base_module
|
|
421
|
-
|
|
423
|
+
root_config = cast(BaseConfig, lm.hparams)
|
|
422
424
|
if use_checkpoint_cache is None:
|
|
423
|
-
use_checkpoint_cache =
|
|
425
|
+
use_checkpoint_cache = root_config.trainer.use_checkpoint_cache
|
|
424
426
|
|
|
425
427
|
filepath = Path(filepath)
|
|
426
428
|
|
|
@@ -440,7 +442,10 @@ class Trainer(LightningTrainer):
|
|
|
440
442
|
# If we have a cached path, then we symlink it to the new path.
|
|
441
443
|
log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
|
|
442
444
|
if self.is_global_zero:
|
|
443
|
-
|
|
445
|
+
if ckpt_cache_use_symlink:
|
|
446
|
+
_link_checkpoint(cached_path, filepath, metadata=False)
|
|
447
|
+
else:
|
|
448
|
+
shutil.copy(cached_path, filepath)
|
|
444
449
|
self.strategy.barrier("Trainer.save_checkpoint")
|
|
445
450
|
else:
|
|
446
451
|
super().save_checkpoint(filepath, weights_only, storage_options)
|
|
@@ -454,7 +459,7 @@ class Trainer(LightningTrainer):
|
|
|
454
459
|
|
|
455
460
|
# Save the checkpoint metadata
|
|
456
461
|
metadata_path = None
|
|
457
|
-
if
|
|
462
|
+
if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
|
|
458
463
|
# Generate the metadata and write to disk
|
|
459
464
|
if (
|
|
460
465
|
metadata_path := _write_checkpoint_metadata(self, lm, filepath)
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import hashlib
|
|
1
2
|
import os
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
from typing import TypeAlias
|
|
@@ -50,3 +51,20 @@ def find_symlinks(
|
|
|
50
51
|
pass
|
|
51
52
|
|
|
52
53
|
return symlinks
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def compute_file_checksum(file_path: Path, chunk_size: int = 1024 * 1024) -> str:
    """
    Calculate the SHA256 checksum of a file.

    The file is streamed in fixed-size chunks, so it is never loaded into
    memory in full.

    Args:
        file_path (Path): The path to the file. Plain strings and other
            path-like objects are accepted as well.
        chunk_size (int): Number of bytes read per iteration. Defaults to
            1 MiB, which keeps the number of Python-level read calls low for
            multi-gigabyte checkpoint files.

    Returns:
        str: The hexadecimal representation of the file's SHA256 checksum.

    Raises:
        ValueError: If `chunk_size` is not a positive integer.
    """
    # A zero-sized read would end the loop immediately and silently yield the
    # hash of an empty file, so reject non-positive sizes up front.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}.")

    sha256_hash = hashlib.sha256()
    with Path(file_path).open("rb") as f:
        while byte_block := f.read(chunk_size):
            sha256_hash.update(byte_block)
    return sha256_hash.hexdigest()
|