nshtrainer 0.17.0__tar.gz → 0.18.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/PKG-INFO +2 -1
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/pyproject.toml +10 -2
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/_checkpoint/metadata.py +5 -3
- nshtrainer-0.18.0/src/nshtrainer/_hf_hub.py +347 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/model/__init__.py +1 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/model/base.py +1 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/model/config.py +5 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/trainer/trainer.py +9 -1
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/README.md +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/util/typing_utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: nshtrainer
|
|
3
|
-
Version: 0.17.0
|
|
3
|
+
Version: 0.18.0
|
|
4
4
|
Summary:
|
|
5
5
|
Author: Nima Shoghi
|
|
6
6
|
Author-email: nimashoghi@gmail.com
|
|
@@ -11,6 +11,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.12
|
|
12
12
|
Provides-Extra: extra
|
|
13
13
|
Requires-Dist: GitPython ; extra == "extra"
|
|
14
|
+
Requires-Dist: huggingface-hub ; extra == "extra"
|
|
14
15
|
Requires-Dist: lightning
|
|
15
16
|
Requires-Dist: nshconfig
|
|
16
17
|
Requires-Dist: nshrunner
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "nshtrainer"
|
|
3
|
-
version = "0.17.0"
|
|
3
|
+
version = "0.18.0"
|
|
4
4
|
description = ""
|
|
5
5
|
authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
|
|
6
6
|
readme = "README.md"
|
|
@@ -22,6 +22,7 @@ wrapt = { version = "*", optional = true }
|
|
|
22
22
|
GitPython = { version = "*", optional = true }
|
|
23
23
|
wandb = { version = "*", optional = true }
|
|
24
24
|
tensorboard = { version = "*", optional = true }
|
|
25
|
+
huggingface-hub = { version = "*", optional = true }
|
|
25
26
|
|
|
26
27
|
[tool.poetry.group.dev.dependencies]
|
|
27
28
|
pyright = "^1.1.372"
|
|
@@ -45,4 +46,11 @@ reportPrivateImportUsage = false
|
|
|
45
46
|
ignore = ["F722", "F821", "E731", "E741"]
|
|
46
47
|
|
|
47
48
|
[tool.poetry.extras]
|
|
48
|
-
extra = ["torchmetrics", "wrapt", "GitPython", "wandb", "tensorboard"]
|
|
49
|
+
extra = [
|
|
50
|
+
"torchmetrics",
|
|
51
|
+
"wrapt",
|
|
52
|
+
"GitPython",
|
|
53
|
+
"wandb",
|
|
54
|
+
"tensorboard",
|
|
55
|
+
"huggingface-hub",
|
|
56
|
+
]
|
|
@@ -44,7 +44,7 @@ class CheckpointMetadata(C.Config):
|
|
|
44
44
|
|
|
45
45
|
@classmethod
|
|
46
46
|
def from_file(cls, path: Path):
|
|
47
|
-
return cls.model_validate_json(path.read_text())
|
|
47
|
+
return cls.model_validate_json(path.read_text(encoding="utf-8"))
|
|
48
48
|
|
|
49
49
|
@classmethod
|
|
50
50
|
def from_ckpt_path(cls, checkpoint_path: Path):
|
|
@@ -112,8 +112,10 @@ def _write_checkpoint_metadata(
|
|
|
112
112
|
metadata_path.write_text(metadata.model_dump_json(indent=4), encoding="utf-8")
|
|
113
113
|
except Exception:
|
|
114
114
|
log.exception(f"Failed to write metadata to {checkpoint_path}")
|
|
115
|
-
|
|
116
|
-
|
|
115
|
+
return None
|
|
116
|
+
|
|
117
|
+
log.debug(f"Checkpoint metadata written to {checkpoint_path}")
|
|
118
|
+
return checkpoint_path
|
|
117
119
|
|
|
118
120
|
|
|
119
121
|
def _remove_checkpoint_metadata(checkpoint_path: Path):
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
5
|
+
|
|
6
|
+
import nshconfig as C
|
|
7
|
+
from lightning.pytorch import Callback
|
|
8
|
+
from lightning.pytorch.trainer import Trainer
|
|
9
|
+
from nshrunner._env import SNAPSHOT_DIR
|
|
10
|
+
from typing_extensions import override
|
|
11
|
+
|
|
12
|
+
from .callbacks.base import (
|
|
13
|
+
CallbackConfigBase,
|
|
14
|
+
CallbackMetadataConfig,
|
|
15
|
+
CallbackWithMetadata,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from huggingface_hub import HfApi # noqa: F401
|
|
20
|
+
|
|
21
|
+
from .model.base import BaseConfig
|
|
22
|
+
from .trainer.trainer import Trainer
|
|
23
|
+
log = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class HuggingFaceHubAutoCreateConfig(C.Config):
|
|
27
|
+
enabled: bool = True
|
|
28
|
+
"""Enable automatic repository creation on the Hugging Face Hub."""
|
|
29
|
+
|
|
30
|
+
private: bool = True
|
|
31
|
+
"""Whether to create the repository as private."""
|
|
32
|
+
|
|
33
|
+
namespace: str | None = None
|
|
34
|
+
"""The namespace to create the repository in. If `None`, the repository will be created in the user's namespace."""
|
|
35
|
+
|
|
36
|
+
def __bool__(self):
|
|
37
|
+
return self.enabled
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class HuggingFaceHubConfig(CallbackConfigBase):
|
|
41
|
+
"""Configuration options for Hugging Face Hub integration."""
|
|
42
|
+
|
|
43
|
+
enabled: bool = False
|
|
44
|
+
"""Enable Hugging Face Hub integration."""
|
|
45
|
+
|
|
46
|
+
token: str | None = None
|
|
47
|
+
"""Hugging Face Hub API token. If `None`, the token will be read from the current environment.
|
|
48
|
+
This needs to either be set using `huggingface-cli login` or by setting the `HUGGINGFACE_TOKEN`
|
|
49
|
+
environment variable."""
|
|
50
|
+
|
|
51
|
+
auto_create: HuggingFaceHubAutoCreateConfig = HuggingFaceHubAutoCreateConfig()
|
|
52
|
+
"""Automatic repository creation configuration options."""
|
|
53
|
+
|
|
54
|
+
save_config: bool = True
|
|
55
|
+
"""Whether to save the model configuration to the Hugging Face Hub."""
|
|
56
|
+
|
|
57
|
+
save_checkpoints: bool = True
|
|
58
|
+
"""Whether to save checkpoints to the Hugging Face Hub."""
|
|
59
|
+
|
|
60
|
+
save_code: bool = True
|
|
61
|
+
"""Whether to save code to the Hugging Face Hub.
|
|
62
|
+
This is only supported if `nshsnap` is installed and snapshotting is enabled."""
|
|
63
|
+
|
|
64
|
+
save_in_background: bool = True
|
|
65
|
+
"""Whether to save to the Hugging Face Hub in the background.
|
|
66
|
+
This corresponds to setting `run_as_future=True` in the HFApi upload methods."""
|
|
67
|
+
|
|
68
|
+
def enable_(self):
|
|
69
|
+
self.enabled = True
|
|
70
|
+
return self
|
|
71
|
+
|
|
72
|
+
def disable_(self):
|
|
73
|
+
self.enabled = False
|
|
74
|
+
return self
|
|
75
|
+
|
|
76
|
+
def __bool__(self):
|
|
77
|
+
return self.enabled
|
|
78
|
+
|
|
79
|
+
@override
|
|
80
|
+
def create_callbacks(self, root_config):
|
|
81
|
+
yield CallbackWithMetadata(
|
|
82
|
+
HFHubCallback(self),
|
|
83
|
+
CallbackMetadataConfig(ignore_if_exists=True),
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _api(token: str | None = None):
|
|
88
|
+
# Make sure that `huggingface_hub` is installed
|
|
89
|
+
try:
|
|
90
|
+
import huggingface_hub # noqa: F401
|
|
91
|
+
except ImportError:
|
|
92
|
+
log.exception(
|
|
93
|
+
"Could not import `huggingface_hub`. Please install it using `pip install huggingface_hub`."
|
|
94
|
+
)
|
|
95
|
+
return None
|
|
96
|
+
|
|
97
|
+
# Create and authenticate the API instance
|
|
98
|
+
try:
|
|
99
|
+
api = huggingface_hub.HfApi(token=token)
|
|
100
|
+
|
|
101
|
+
# Verify authentication
|
|
102
|
+
api.whoami()
|
|
103
|
+
except Exception as e:
|
|
104
|
+
log.exception(
|
|
105
|
+
f"Authentication failed for Hugging Face Hub: {str(e)}. "
|
|
106
|
+
"Please make sure you are logged in using `huggingface-cli login`, "
|
|
107
|
+
"by setting the HUGGING_FACE_HUB_TOKEN environment variable, "
|
|
108
|
+
"or by providing a valid token in the configuration."
|
|
109
|
+
)
|
|
110
|
+
return None
|
|
111
|
+
|
|
112
|
+
return api
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _enabled_and_valid(
|
|
116
|
+
trainer: "Trainer",
|
|
117
|
+
config: HuggingFaceHubConfig,
|
|
118
|
+
*,
|
|
119
|
+
rank_zero_only: bool,
|
|
120
|
+
):
|
|
121
|
+
# Make sure this is enabled and the config is valid
|
|
122
|
+
if not config:
|
|
123
|
+
return None
|
|
124
|
+
|
|
125
|
+
# If `rank_zero_only` and this is not rank 0, stop here.
|
|
126
|
+
if rank_zero_only and not trainer.is_global_zero:
|
|
127
|
+
return
|
|
128
|
+
|
|
129
|
+
# Make sure that `huggingface_hub` is installed
|
|
130
|
+
try:
|
|
131
|
+
import huggingface_hub # noqa: F401
|
|
132
|
+
except ImportError:
|
|
133
|
+
log.exception(
|
|
134
|
+
"Could not import `huggingface_hub`. Please install it using `pip install huggingface_hub`."
|
|
135
|
+
)
|
|
136
|
+
return None
|
|
137
|
+
|
|
138
|
+
# Create and authenticate the API instance
|
|
139
|
+
if (api := getattr(trainer, "_hf_hub_api", None)) is None:
|
|
140
|
+
api = _api(config.token)
|
|
141
|
+
setattr(trainer, "_hf_hub_api", api)
|
|
142
|
+
return cast(huggingface_hub.HfApi, api)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _repo_name(api: "HfApi", root_config: "BaseConfig"):
|
|
146
|
+
username = None
|
|
147
|
+
if (ac := root_config.trainer.hf_hub.auto_create) and ac.namespace:
|
|
148
|
+
username = ac.namespace
|
|
149
|
+
elif (username := api.whoami().get("name", None)) is None:
|
|
150
|
+
raise ValueError("Could not get username from Hugging Face Hub.")
|
|
151
|
+
|
|
152
|
+
return f"{username}/{root_config.project}-{root_config.run_name}-{root_config.id}"
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
|
|
156
|
+
config = root_config.trainer.hf_hub
|
|
157
|
+
if (
|
|
158
|
+
api := _enabled_and_valid(
|
|
159
|
+
trainer,
|
|
160
|
+
config,
|
|
161
|
+
rank_zero_only=True,
|
|
162
|
+
)
|
|
163
|
+
) is None or not config.auto_create:
|
|
164
|
+
return
|
|
165
|
+
|
|
166
|
+
from huggingface_hub.utils import RepositoryNotFoundError
|
|
167
|
+
|
|
168
|
+
# Resolve the repository name
|
|
169
|
+
repo_name = _repo_name(api, root_config)
|
|
170
|
+
|
|
171
|
+
# Create the repository, if it doesn't exist
|
|
172
|
+
try:
|
|
173
|
+
# Check if the repository exists
|
|
174
|
+
api.repo_info(repo_id=repo_name, repo_type="model")
|
|
175
|
+
log.info(f"Repository '{repo_name}' already exists.")
|
|
176
|
+
except RepositoryNotFoundError:
|
|
177
|
+
# Repository doesn't exist, so create it
|
|
178
|
+
try:
|
|
179
|
+
api.create_repo(
|
|
180
|
+
repo_id=repo_name,
|
|
181
|
+
repo_type="model",
|
|
182
|
+
private=config.auto_create.private,
|
|
183
|
+
exist_ok=True,
|
|
184
|
+
)
|
|
185
|
+
log.info(f"Created new repository '{repo_name}'.")
|
|
186
|
+
except Exception as e:
|
|
187
|
+
log.exception(f"Failed to create repository '{repo_name}': {str(e)}")
|
|
188
|
+
except Exception as e:
|
|
189
|
+
log.exception(f"Error checking repository '{repo_name}': {str(e)}")
|
|
190
|
+
|
|
191
|
+
# Upload the config
|
|
192
|
+
_save_config(root_config, trainer=trainer)
|
|
193
|
+
|
|
194
|
+
# Upload the code
|
|
195
|
+
_save_code(repo_name, config=config, trainer=trainer)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _save_code(
|
|
199
|
+
repo_name: str,
|
|
200
|
+
*,
|
|
201
|
+
config: HuggingFaceHubConfig,
|
|
202
|
+
trainer: "Trainer",
|
|
203
|
+
):
|
|
204
|
+
if (
|
|
205
|
+
api := _enabled_and_valid(
|
|
206
|
+
trainer,
|
|
207
|
+
config,
|
|
208
|
+
rank_zero_only=True,
|
|
209
|
+
)
|
|
210
|
+
) is None or not config.save_code:
|
|
211
|
+
return
|
|
212
|
+
|
|
213
|
+
# If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
|
|
214
|
+
# then upload all contents within the snapshot directory to the repository.
|
|
215
|
+
snapshot_dir = os.environ.get(SNAPSHOT_DIR)
|
|
216
|
+
if not snapshot_dir:
|
|
217
|
+
log.info("No snapshot directory found. Skipping upload.")
|
|
218
|
+
return
|
|
219
|
+
|
|
220
|
+
snapshot_path = Path(snapshot_dir)
|
|
221
|
+
if not snapshot_path.exists() or not snapshot_path.is_dir():
|
|
222
|
+
log.warning(
|
|
223
|
+
f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
|
|
224
|
+
)
|
|
225
|
+
return
|
|
226
|
+
|
|
227
|
+
try:
|
|
228
|
+
api.upload_folder(
|
|
229
|
+
folder_path=str(snapshot_path),
|
|
230
|
+
repo_id=repo_name,
|
|
231
|
+
repo_type="model",
|
|
232
|
+
path_in_repo="code", # Prefix with "code" folder
|
|
233
|
+
run_as_future=cast(Any, config.save_in_background),
|
|
234
|
+
)
|
|
235
|
+
log.info(
|
|
236
|
+
f"Uploaded snapshot contents to repository '{repo_name}' under 'code' folder."
|
|
237
|
+
)
|
|
238
|
+
except Exception as e:
|
|
239
|
+
log.exception(
|
|
240
|
+
f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder: {str(e)}"
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _save_config(
|
|
245
|
+
root_config: "BaseConfig",
|
|
246
|
+
*,
|
|
247
|
+
trainer: "Trainer",
|
|
248
|
+
):
|
|
249
|
+
config = root_config.trainer.hf_hub
|
|
250
|
+
if (
|
|
251
|
+
api := _enabled_and_valid(
|
|
252
|
+
trainer,
|
|
253
|
+
config,
|
|
254
|
+
rank_zero_only=True,
|
|
255
|
+
)
|
|
256
|
+
) is None or not config.save_config:
|
|
257
|
+
return
|
|
258
|
+
|
|
259
|
+
# Convert the root config to a JSON string
|
|
260
|
+
# NOTE: This is a utf-8 string.
|
|
261
|
+
config_json = root_config.model_dump_json(indent=4)
|
|
262
|
+
|
|
263
|
+
# Resolve the repository name
|
|
264
|
+
repo_name = _repo_name(api, root_config)
|
|
265
|
+
|
|
266
|
+
# Upload the config file to the repository
|
|
267
|
+
try:
|
|
268
|
+
api.upload_file(
|
|
269
|
+
path_or_fileobj=config_json.encode("utf-8"),
|
|
270
|
+
path_in_repo="config.json",
|
|
271
|
+
repo_id=repo_name,
|
|
272
|
+
repo_type="model",
|
|
273
|
+
run_as_future=cast(Any, config.save_in_background),
|
|
274
|
+
)
|
|
275
|
+
log.info(f"Uploaded config.json to repository '{repo_name}'.")
|
|
276
|
+
except Exception as e:
|
|
277
|
+
log.exception(
|
|
278
|
+
f"Failed to upload config.json to repository '{repo_name}': {str(e)}"
|
|
279
|
+
)
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _save_checkpoint_files(
|
|
283
|
+
trainer: "Trainer",
|
|
284
|
+
paths: list[Path],
|
|
285
|
+
*,
|
|
286
|
+
root_config: "BaseConfig",
|
|
287
|
+
):
|
|
288
|
+
config = root_config.trainer.hf_hub
|
|
289
|
+
if (
|
|
290
|
+
api := _enabled_and_valid(trainer, config, rank_zero_only=True)
|
|
291
|
+
) is None or not config.save_checkpoints:
|
|
292
|
+
return
|
|
293
|
+
|
|
294
|
+
# Resolve the checkpoint directory
|
|
295
|
+
checkpoint_dir = root_config.directory.resolve_subdirectory(
|
|
296
|
+
root_config.id, "checkpoint"
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
# Resolve the repository name
|
|
300
|
+
repo_name = _repo_name(api, root_config)
|
|
301
|
+
|
|
302
|
+
for p in paths:
|
|
303
|
+
try:
|
|
304
|
+
relative_path = p.relative_to(checkpoint_dir)
|
|
305
|
+
except ValueError:
|
|
306
|
+
log.warning(
|
|
307
|
+
f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
|
|
308
|
+
)
|
|
309
|
+
continue
|
|
310
|
+
|
|
311
|
+
# Prefix the path in repo with "checkpoints"
|
|
312
|
+
path_in_repo = Path("checkpoints") / relative_path
|
|
313
|
+
|
|
314
|
+
# Upload the checkpoint file to the repository
|
|
315
|
+
try:
|
|
316
|
+
api.upload_file(
|
|
317
|
+
path_or_fileobj=str(p.resolve().absolute()),
|
|
318
|
+
path_in_repo=str(path_in_repo),
|
|
319
|
+
repo_id=repo_name,
|
|
320
|
+
repo_type="model",
|
|
321
|
+
run_as_future=cast(Any, config.save_in_background),
|
|
322
|
+
)
|
|
323
|
+
log.info(
|
|
324
|
+
f"Uploaded checkpoint file {relative_path} to repository '{repo_name}'."
|
|
325
|
+
)
|
|
326
|
+
except Exception as e:
|
|
327
|
+
log.exception(
|
|
328
|
+
f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}': {str(e)}"
|
|
329
|
+
)
|
|
330
|
+
|
|
331
|
+
log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
class HFHubCallback(Callback):
|
|
335
|
+
def __init__(self, config: HuggingFaceHubConfig):
|
|
336
|
+
super().__init__()
|
|
337
|
+
self.config = config
|
|
338
|
+
|
|
339
|
+
@override
|
|
340
|
+
def setup(self, trainer, pl_module, stage):
|
|
341
|
+
root_config = cast("BaseConfig", pl_module.hparams)
|
|
342
|
+
_init(trainer=trainer, root_config=root_config)
|
|
343
|
+
|
|
344
|
+
@override
|
|
345
|
+
def teardown(self, trainer, pl_module, stage):
|
|
346
|
+
if hasattr(trainer, "_hf_hub_api"):
|
|
347
|
+
delattr(trainer, "_hf_hub_api")
|
|
@@ -10,6 +10,7 @@ from .config import CheckpointSavingConfig as CheckpointSavingConfig
|
|
|
10
10
|
from .config import DirectoryConfig as DirectoryConfig
|
|
11
11
|
from .config import EarlyStoppingConfig as EarlyStoppingConfig
|
|
12
12
|
from .config import GradientClippingConfig as GradientClippingConfig
|
|
13
|
+
from .config import HuggingFaceHubConfig as HuggingFaceHubConfig
|
|
13
14
|
from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
|
|
14
15
|
from .config import LoggingConfig as LoggingConfig
|
|
15
16
|
from .config import MetricConfig as MetricConfig
|
|
@@ -192,6 +192,7 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
|
|
|
192
192
|
hparams = self.config_cls().model_validate(hparams)
|
|
193
193
|
hparams.environment = EnvironmentConfig.from_current_environment(hparams, self)
|
|
194
194
|
hparams = self.pre_init_update_hparams(hparams)
|
|
195
|
+
|
|
195
196
|
super().__init__(hparams)
|
|
196
197
|
|
|
197
198
|
self.save_hyperparameters(hparams)
|
|
@@ -33,6 +33,7 @@ from lightning.pytorch.strategies.strategy import Strategy
|
|
|
33
33
|
from typing_extensions import Self, TypedDict, TypeVar, override
|
|
34
34
|
|
|
35
35
|
from .._checkpoint.loader import CheckpointLoadingConfig
|
|
36
|
+
from .._hf_hub import HuggingFaceHubConfig
|
|
36
37
|
from ..callbacks import (
|
|
37
38
|
BestCheckpointCallbackConfig,
|
|
38
39
|
CallbackConfig,
|
|
@@ -819,6 +820,9 @@ class TrainerConfig(C.Config):
|
|
|
819
820
|
checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
|
|
820
821
|
"""Checkpoint saving configuration options."""
|
|
821
822
|
|
|
823
|
+
hf_hub: HuggingFaceHubConfig = HuggingFaceHubConfig()
|
|
824
|
+
"""Hugging Face Hub configuration options."""
|
|
825
|
+
|
|
822
826
|
logging: LoggingConfig = LoggingConfig()
|
|
823
827
|
"""Logging/experiment tracking (e.g., WandB) configuration options."""
|
|
824
828
|
|
|
@@ -1213,4 +1217,5 @@ class BaseConfig(C.Config):
|
|
|
1213
1217
|
yield self.trainer.checkpoint_saving
|
|
1214
1218
|
yield self.trainer.logging
|
|
1215
1219
|
yield self.trainer.optimizer
|
|
1220
|
+
yield self.trainer.hf_hub
|
|
1216
1221
|
yield from self.trainer.callbacks
|
|
@@ -420,8 +420,16 @@ class Trainer(LightningTrainer):
|
|
|
420
420
|
# Save the checkpoint metadata
|
|
421
421
|
lm = self._base_module
|
|
422
422
|
hparams = cast(BaseConfig, lm.hparams)
|
|
423
|
+
metadata_path = None
|
|
423
424
|
if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
|
|
424
425
|
# Generate the metadata and write to disk
|
|
425
|
-
_write_checkpoint_metadata(self, lm, filepath)
|
|
426
|
+
metadata_path = _write_checkpoint_metadata(self, lm, filepath)
|
|
427
|
+
|
|
428
|
+
# If HF Hub is enabled, then we upload
|
|
429
|
+
if hparams.trainer.hf_hub:
|
|
430
|
+
from .._hf_hub import _save_checkpoint_files
|
|
431
|
+
|
|
432
|
+
files = [f for f in (filepath, metadata_path) if f is not None]
|
|
433
|
+
_save_checkpoint_files(self, files, root_config=hparams)
|
|
426
434
|
|
|
427
435
|
return ret_val
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
RENAMED
|
File without changes
|
{nshtrainer-0.17.0 → nshtrainer-0.18.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|