nshtrainer 0.23.0__tar.gz → 0.25.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/PKG-INFO +1 -1
  2. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/pyproject.toml +1 -1
  3. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/_checkpoint/metadata.py +3 -1
  4. nshtrainer-0.25.0/src/nshtrainer/_hf_hub.py +353 -0
  5. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/_base.py +2 -40
  6. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -1
  7. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  8. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/trainer/trainer.py +9 -4
  9. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/path.py +18 -0
  10. nshtrainer-0.23.0/src/nshtrainer/_hf_hub.py +0 -518
  11. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/README.md +0 -0
  12. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/__init__.py +0 -0
  13. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/_callback.py +0 -0
  14. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
  15. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  16. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  17. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/__init__.py +0 -0
  18. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  19. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  20. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/base.py +0 -0
  21. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  22. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  23. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  24. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/ema.py +0 -0
  25. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  26. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  27. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/interval.py +0 -0
  28. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  29. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  30. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  31. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  32. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/timer.py +0 -0
  33. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  34. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/data/__init__.py +0 -0
  35. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  36. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/data/transform.py +0 -0
  37. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/__init__.py +0 -0
  38. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/_experimental.py +0 -0
  39. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/actsave.py +0 -0
  40. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/callbacks.py +0 -0
  41. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/config.py +0 -0
  42. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/data.py +0 -0
  43. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/log.py +0 -0
  44. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  45. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/model.py +0 -0
  46. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/nn.py +0 -0
  47. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/optimizer.py +0 -0
  48. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/runner.py +0 -0
  49. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/snapshot.py +0 -0
  50. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/snoop.py +0 -0
  51. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/trainer.py +0 -0
  52. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/typecheck.py +0 -0
  53. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/ll/util.py +0 -0
  54. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/loggers/__init__.py +0 -0
  55. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/loggers/_base.py +0 -0
  56. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/loggers/csv.py +0 -0
  57. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  58. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/loggers/wandb.py +0 -0
  59. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  60. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  61. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  62. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  63. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/metrics/__init__.py +0 -0
  64. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/metrics/_config.py +0 -0
  65. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/__init__.py +0 -0
  66. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/base.py +0 -0
  67. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/config.py +0 -0
  68. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/callback.py +0 -0
  69. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/debug.py +0 -0
  70. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/distributed.py +0 -0
  71. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/logger.py +0 -0
  72. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/profiler.py +0 -0
  73. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  74. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  75. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/nn/__init__.py +0 -0
  76. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/nn/mlp.py +0 -0
  77. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/nn/module_dict.py +0 -0
  78. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/nn/module_list.py +0 -0
  79. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  80. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/optimizer.py +0 -0
  81. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/runner.py +0 -0
  82. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/scripts/find_packages.py +0 -0
  83. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/trainer/__init__.py +0 -0
  84. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  85. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  86. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  87. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/_environment_info.py +0 -0
  88. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/_useful_types.py +0 -0
  89. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/environment.py +0 -0
  90. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/seed.py +0 -0
  91. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/slurm.py +0 -0
  92. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/typed.py +0 -0
  93. {nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.23.0 → nshtrainer-0.25.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.23.0
+Version: 0.25.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-0.23.0 → nshtrainer-0.25.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.23.0"
+version = "0.25.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
{nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/_checkpoint/metadata.py
@@ -11,7 +11,7 @@ import numpy as np
 import torch
 
 from ..util._environment_info import EnvironmentConfig
-from ..util.path import get_relative_path
+from ..util.path import compute_file_checksum, get_relative_path
 
 if TYPE_CHECKING:
     from ..model import BaseConfig, LightningModuleBase
@@ -28,6 +28,7 @@ class CheckpointMetadata(C.Config):
 
     checkpoint_path: Path
     checkpoint_filename: str
+    checkpoint_checksum: str
 
     run_id: str
     name: str
@@ -81,6 +82,7 @@ def _generate_checkpoint_metadata(
        # moving the checkpoint directory
        checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
        checkpoint_filename=checkpoint_path.name,
+       checkpoint_checksum=compute_file_checksum(checkpoint_path),
        run_id=config.id,
        name=config.run_name,
        project=config.project,
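
Together, these hunks mean every saved checkpoint now carries a SHA256 digest in its metadata sidecar. A minimal verification sketch, assuming hypothetical file paths (`CheckpointMetadata.from_file`, `checkpoint_checksum`, and `compute_file_checksum` are all present in this release; the filenames below are placeholders):

    from pathlib import Path

    from nshtrainer._checkpoint.metadata import CheckpointMetadata
    from nshtrainer.util.path import compute_file_checksum

    # Hypothetical paths; real filenames are produced by the checkpoint callbacks.
    ckpt_path = Path("checkpoints/last/epoch10.ckpt")
    meta_path = Path("checkpoints/last/epoch10.metadata.json")

    metadata = CheckpointMetadata.from_file(meta_path)

    # The digest was computed at save time, so re-hashing the file detects
    # corruption or accidental replacement.
    assert compute_file_checksum(ckpt_path) == metadata.checkpoint_checksum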
nshtrainer-0.25.0/src/nshtrainer/_hf_hub.py
@@ -0,0 +1,353 @@
+import contextlib
+import logging
+import os
+import re
+from dataclasses import dataclass
+from functools import cached_property
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, cast
+
+import nshconfig as C
+from nshrunner._env import SNAPSHOT_DIR
+from typing_extensions import override
+
+from ._callback import NTCallbackBase
+from .callbacks.base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from huggingface_hub import HfApi  # noqa: F401
+
+    from .model.base import BaseConfig
+
+
+log = logging.getLogger(__name__)
+
+
+class HuggingFaceHubAutoCreateConfig(C.Config):
+    enabled: bool = True
+    """Enable automatic repository creation on the Hugging Face Hub."""
+
+    private: bool = True
+    """Whether to create the repository as private."""
+
+    namespace: str | None = None
+    """The namespace to create the repository in. If `None`, the repository will be created in the user's namespace."""
+
+    def __bool__(self):
+        return self.enabled
+
+
+class HuggingFaceHubConfig(CallbackConfigBase):
+    """Configuration options for Hugging Face Hub integration."""
+
+    enabled: bool = False
+    """Enable Hugging Face Hub integration."""
+
+    token: str | None = None
+    """Hugging Face Hub API token. If `None`, the token will be read from the current environment.
+    This needs to either be set using `huggingface-cli login` or by setting the `HUGGING_FACE_HUB_TOKEN`
+    environment variable."""
+
+    auto_create: HuggingFaceHubAutoCreateConfig = HuggingFaceHubAutoCreateConfig()
+    """Automatic repository creation configuration options."""
+
+    save_config: bool = True
+    """Whether to save the model configuration to the Hugging Face Hub."""
+
+    save_checkpoints: bool = True
+    """Whether to save checkpoints to the Hugging Face Hub."""
+
+    save_code: bool = True
+    """Whether to save code to the Hugging Face Hub.
+    This is only supported if `nshsnap` is installed and snapshotting is enabled."""
+
+    save_in_background: bool = True
+    """Whether to save to the Hugging Face Hub in the background.
+    This corresponds to setting `run_as_future=True` in the HfApi upload methods."""
+
+    def enable_(self):
+        self.enabled = True
+        return self
+
+    def disable_(self):
+        self.enabled = False
+        return self
+
+    def __bool__(self):
+        return self.enabled
+
+    @override
+    def create_callbacks(self, root_config):
+        yield self.with_metadata(HFHubCallback(self), ignore_if_exists=True)
+
+
+def _api(token: str | None = None):
+    # Make sure that `huggingface_hub` is installed
+    try:
+        import huggingface_hub  # noqa: F401
+    except ImportError:
+        log.exception(
+            "Could not import `huggingface_hub`. Please install it using `pip install huggingface_hub`."
+        )
+        return None
+
+    # Create and authenticate the API instance
+    try:
+        api = huggingface_hub.HfApi(token=token)
+
+        # Verify authentication
+        api.whoami()
+    except Exception:
+        log.exception(
+            "Authentication failed for Hugging Face Hub. "
+            "Please make sure you are logged in using `huggingface-cli login`, "
+            "by setting the HUGGING_FACE_HUB_TOKEN environment variable, "
+            "or by providing a valid token in the configuration."
+        )
+        return None
+
+    return api
+
+
+def _repo_name(api: "HfApi", root_config: "BaseConfig"):
+    username = None
+    if (ac := root_config.trainer.hf_hub.auto_create) and ac.namespace:
+        username = ac.namespace
+    elif (username := api.whoami().get("name", None)) is None:
+        raise ValueError("Could not get username from Hugging Face Hub.")
+
+    # Sanitize the project (if it exists), run_name, and id
+    parts = []
+    if root_config.project:
+        parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.project))
+    parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.run_name))
+    parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.id))
+
+    # Combine parts and ensure it starts and ends with alphanumeric characters
+    repo_name = "-".join(parts)
+    repo_name = repo_name.strip("-")
+    repo_name = re.sub(
+        r"-+", "-", repo_name
+    )  # Replace multiple dashes with a single dash
+
+    # Ensure the name is not longer than 96 characters (excluding username)
+    if len(repo_name) > 96:
+        repo_name = repo_name[:96].rstrip("-")
+
+    # Ensure the repo name starts with an alphanumeric character
+    repo_name = re.sub(r"^[^a-zA-Z0-9]+", "", repo_name)
+
+    # If the repo_name is empty after all sanitization, use a default name
+    if not repo_name:
+        repo_name = "default-repo-name"
+
+    return f"{username}/{repo_name}"
+
+
+@dataclass
+class _Upload:
+    local_path: Path
+    path_in_repo: Path
+
+    @classmethod
+    def from_local_path(
+        cls,
+        local_path: Path,
+        root_config: "BaseConfig",
+    ):
+        # Resolve the checkpoint directory
+        checkpoint_dir = root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+
+        try:
+            relative_path = local_path.relative_to(checkpoint_dir)
+        except ValueError:
+            raise ValueError(
+                f"Checkpoint path {local_path} is not within the checkpoint directory {checkpoint_dir}."
+            )
+
+        # Prefix the path in repo with "checkpoints"
+        path_in_repo = Path("checkpoints") / relative_path
+
+        return cls(local_path=local_path, path_in_repo=path_in_repo)
+
+
+class HFHubCallback(NTCallbackBase):
+    @contextlib.contextmanager
+    def _with_error_handling(self, operation: str):
+        try:
+            yield
+        except Exception:
+            log.exception(f"Failed to {operation}, repo_id={self._repo_id}")
+        else:
+            log.debug(f"Successfully {operation}, repo_id={self._repo_id}")
+
+    def __init__(self, config: HuggingFaceHubConfig):
+        super().__init__()
+
+        self.config = config
+
+        self._repo_id = None
+        self._checksum_to_path_in_repo: dict[str, Path] = {}
+
+    @override
+    def setup(self, trainer, pl_module, stage):
+        from .trainer.trainer import Trainer
+
+        if not isinstance(trainer, Trainer):
+            raise ValueError(
+                f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
+            )
+
+        root_config = cast("BaseConfig", pl_module.hparams)
+
+        # Create the repository, if it doesn't exist
+        self._repo_id = self.api.create_repo(
+            repo_id=_repo_name(self.api, root_config),
+            repo_type="model",
+            private=self.config.auto_create.private,
+            exist_ok=True,
+        )
+
+        # Upload the config and code
+        self._save_config(root_config)
+        self._save_code()
+
+    @override
+    def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
+        root_config = cast("BaseConfig", pl_module.hparams)
+
+        # If HF Hub is enabled, then we upload
+        if self.config and trainer.is_global_zero:
+            with self._with_error_handling("save checkpoints"):
+                self._save_checkpoint(
+                    _Upload.from_local_path(ckpt_path, root_config),
+                    _Upload.from_local_path(metadata_path, root_config)
+                    if metadata_path is not None
+                    else None,
+                )
+
+    @cached_property
+    def api(self):
+        # Create and authenticate the API instance
+        if (api := _api(self.config.token)) is None:
+            raise ValueError("Failed to create Hugging Face Hub API instance.")
+        return api
+
+    @property
+    def repo_id(self):
+        if self._repo_id is None:
+            raise ValueError("Repository id has not been initialized.")
+        return self._repo_id
+
+    def _save_config(self, root_config: "BaseConfig"):
+        with self._with_error_handling("upload config"):
+            self.api.upload_file(
+                path_or_fileobj=root_config.model_dump_json(indent=4).encode("utf-8"),
+                path_in_repo="config.json",
+                repo_id=self.repo_id,
+                repo_type="model",
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _save_code(self):
+        # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
+        # then upload all contents within the snapshot directory to the repository.
+        if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
+            log.debug("No snapshot directory found. Skipping upload.")
+            return
+
+        with self._with_error_handling("save code"):
+            snapshot_dir = Path(snapshot_dir)
+            if not snapshot_dir.exists() or not snapshot_dir.is_dir():
+                log.warning(
+                    f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
+                )
+                return
+
+            self.api.upload_folder(
+                folder_path=str(snapshot_dir),
+                repo_id=self.repo_id,
+                repo_type="model",
+                path_in_repo="code",  # Prefix with "code" folder
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _save_file(self, p: _Upload):
+        with self._with_error_handling("save file"):
+            # Upload the checkpoint files to the repository
+            self.api.upload_file(
+                path_or_fileobj=p.local_path,
+                path_in_repo=str(p.path_in_repo),
+                repo_id=self.repo_id,
+                repo_type="model",
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _copy_file(self, source_path_in_repo: Path, dest_path_in_repo: Path):
+        # Create a commit for copying the files
+        from huggingface_hub.hf_api import CommitOperationCopy
+
+        with self._with_error_handling("copy file"):
+            copy_op = CommitOperationCopy(
+                src_path_in_repo=str(source_path_in_repo),
+                path_in_repo=str(dest_path_in_repo),
+            )
+
+            self.api.create_commit(
+                repo_id=self.repo_id,
+                repo_type="model",
+                commit_message="Copy checkpoint file",
+                operations=[copy_op],
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _save_checkpoint(self, path: _Upload, metadata_path: _Upload | None):
+        if not self.config.save_checkpoints:
+            return
+
+        # If no metadata, just save regularly.
+        if metadata_path is None:
+            self._save_file(path)
+            return
+
+        # Otherwise, let's check to see if we've already uploaded the metadata.
+        # If so, we can just copy the checkpoint file.
+        from ._checkpoint.metadata import CheckpointMetadata
+
+        metadata = CheckpointMetadata.from_file(metadata_path.local_path)
+        if (
+            existing_ckpt_path := self._checksum_to_path_in_repo.get(
+                metadata.checkpoint_checksum
+            )
+        ) is not None:
+            self._copy_file(existing_ckpt_path, path.path_in_repo)
+        else:
+            # Otherwise, we save the checkpoint & keep the checksum so we don't
+            # re-upload the same file again.
+            self._save_file(path)
+            self._checksum_to_path_in_repo[metadata.checkpoint_checksum] = (
+                path.path_in_repo
+            )
+
+        # Save the metadata file
+        # NOTE: This file is fairly small, so we can just upload it directly.
+        # No need to copy.
+        self._save_file(metadata_path)
+
+    @override
+    def state_dict(self):
+        return {
+            "repo_id": self._repo_id,
+            "checksum_to_path_in_repo": {
+                k: str(v) for k, v in self._checksum_to_path_in_repo.items()
+            },
+        }
+
+    @override
+    def load_state_dict(self, state_dict):
+        self._repo_id = state_dict["repo_id"]
+        self._checksum_to_path_in_repo = {
+            k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
+        }
{nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/_base.py
@@ -11,7 +11,6 @@ from typing_extensions import TypeVar, override
 
 from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
 from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
-from ...util.path import find_symlinks
 from ..base import CallbackConfigBase
 
 if TYPE_CHECKING:
@@ -117,47 +116,9 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
                )
                continue
 
-            self._remove_checkpoint_with_link_support(
-                trainer, old_ckpt_path, metadata=True
-            )
+            _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
            log.debug(f"Removed old checkpoint: {old_ckpt_path}")
 
-    def _remove_checkpoint_with_link_support(
-        self,
-        trainer: Trainer,
-        path: Path,
-        metadata: bool,
-    ):
-        # Find all the symlinks to the checkpoint
-        ckpt_callbacks: list[CheckpointBase] = [
-            callback
-            for callback in trainer.checkpoint_callbacks
-            if isinstance(callback, CheckpointBase) and callback is not self
-        ]
-        symlink_paths = find_symlinks(
-            path,
-            *[callback.dirpath for callback in ckpt_callbacks],
-            glob_pattern=f"*.{self.extension()}",
-        )
-
-        # If there are no symlinks, just remove the checkpoint
-        if not symlink_paths:
-            _remove_checkpoint(trainer, path, metadata=metadata)
-            return
-
-        log.debug(
-            f"Removing checkpoint with symlinks: {path}, symlinks: {symlink_paths}"
-        )
-
-        # For the first symlink, we can just move the checkpoint file
-        # to the symlink path. For the rest, we need to make new symlinks.
-        new_target = symlink_paths.pop(0)
-        path.rename(new_target)
-        log.debug(f"New symlink target: {new_target}")
-
-        for symlink_path in symlink_paths:
-            _link_checkpoint(new_target, symlink_path, metadata=False)
-
     def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
         current_metrics: dict[str, Any] = {
             "epoch": trainer.current_epoch,
@@ -195,6 +156,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
            filepath,
            self.config.save_weights_only,
            use_checkpoint_cache=None,
+           ckpt_cache_use_symlink=False,
        )
 
        if trainer.is_global_zero:
{nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
@@ -70,5 +70,5 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
 
     # Events
     @override
-    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
+    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
         self.save_checkpoints(trainer)
{nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
@@ -39,5 +39,5 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
         return True
 
     @override
-    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
         self.save_checkpoints(trainer)
{nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/trainer/trainer.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import shutil
 from collections.abc import Sequence
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
@@ -416,11 +417,12 @@ class Trainer(LightningTrainer):
        weights_only: bool = False,
        storage_options: Any | None = None,
        use_checkpoint_cache: bool | None = None,
+       ckpt_cache_use_symlink: bool = False,
    ):
        lm = self._base_module
-       hparams = cast(BaseConfig, lm.hparams)
+       root_config = cast(BaseConfig, lm.hparams)
        if use_checkpoint_cache is None:
-           use_checkpoint_cache = hparams.trainer.use_checkpoint_cache
+           use_checkpoint_cache = root_config.trainer.use_checkpoint_cache
 
        filepath = Path(filepath)
 
@@ -440,7 +442,10 @@ class Trainer(LightningTrainer):
            # If we have a cached path, then we symlink it to the new path.
            log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
            if self.is_global_zero:
-               _link_checkpoint(cached_path, filepath, metadata=False)
+               if ckpt_cache_use_symlink:
+                   _link_checkpoint(cached_path, filepath, metadata=False)
+               else:
+                   shutil.copy(cached_path, filepath)
            self.strategy.barrier("Trainer.save_checkpoint")
        else:
            super().save_checkpoint(filepath, weights_only, storage_options)
@@ -454,7 +459,7 @@ class Trainer(LightningTrainer):
 
        # Save the checkpoint metadata
        metadata_path = None
-       if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
+       if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
            # Generate the metadata and write to disk
            if (
                metadata_path := _write_checkpoint_metadata(self, lm, filepath)
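
Checkpoint-cache hits are now materialized with `shutil.copy` by default rather than a symlink, presumably so every checkpoint path points at a regular file (duplicate-upload cost on the Hub side is already avoided by the checksum map and `CommitOperationCopy` above). A usage sketch; the keyword argument is from this diff, the path is a placeholder:

    # Opt back into the pre-0.25 symlink behavior for a single save:
    trainer.save_checkpoint(
        "checkpoints/last/epoch10.ckpt",  # hypothetical path
        ckpt_cache_use_symlink=True,
    )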
{nshtrainer-0.23.0 → nshtrainer-0.25.0}/src/nshtrainer/util/path.py
@@ -1,3 +1,4 @@
+import hashlib
 import os
 from pathlib import Path
 from typing import TypeAlias
@@ -50,3 +51,20 @@ def find_symlinks(
            pass
 
    return symlinks
+
+
+def compute_file_checksum(file_path: Path) -> str:
+    """
+    Calculate the SHA256 checksum of a file.
+
+    Args:
+        file_path (Path): The path to the file.
+
+    Returns:
+        str: The hexadecimal representation of the file's SHA256 checksum.
+    """
+    sha256_hash = hashlib.sha256()
+    with file_path.open("rb") as f:
+        for byte_block in iter(lambda: f.read(4096), b""):
+            sha256_hash.update(byte_block)
+    return sha256_hash.hexdigest()
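
`compute_file_checksum` streams the file in 4 KiB blocks, so memory use stays flat even for multi-gigabyte checkpoints. A quick usage sketch with a placeholder path; on Python 3.11+ the stdlib's `hashlib.file_digest` produces the same digest in one call:

    import hashlib
    from pathlib import Path

    from nshtrainer.util.path import compute_file_checksum

    path = Path("model.ckpt")  # placeholder file
    digest = compute_file_checksum(path)

    # Cross-check against the stdlib helper (Python >= 3.11).
    with path.open("rb") as f:
        assert hashlib.file_digest(f, "sha256").hexdigest() == digest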