nshtrainer 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/metadata.py +3 -1
- nshtrainer/_hf_hub.py +171 -336
- nshtrainer/callbacks/checkpoint/_base.py +2 -40
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/trainer/trainer.py +9 -4
- nshtrainer/util/path.py +18 -0
- {nshtrainer-0.23.0.dist-info → nshtrainer-0.25.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.23.0.dist-info → nshtrainer-0.25.0.dist-info}/RECORD +10 -10
- {nshtrainer-0.23.0.dist-info → nshtrainer-0.25.0.dist-info}/WHEEL +0 -0
|
@@ -11,7 +11,7 @@ import numpy as np
|
|
|
11
11
|
import torch
|
|
12
12
|
|
|
13
13
|
from ..util._environment_info import EnvironmentConfig
|
|
14
|
-
from ..util.path import get_relative_path
|
|
14
|
+
from ..util.path import compute_file_checksum, get_relative_path
|
|
15
15
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
17
17
|
from ..model import BaseConfig, LightningModuleBase
|
|
@@ -28,6 +28,7 @@ class CheckpointMetadata(C.Config):
|
|
|
28
28
|
|
|
29
29
|
checkpoint_path: Path
|
|
30
30
|
checkpoint_filename: str
|
|
31
|
+
checkpoint_checksum: str
|
|
31
32
|
|
|
32
33
|
run_id: str
|
|
33
34
|
name: str
|
|
@@ -81,6 +82,7 @@ def _generate_checkpoint_metadata(
|
|
|
81
82
|
# moving the checkpoint directory
|
|
82
83
|
checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
|
|
83
84
|
checkpoint_filename=checkpoint_path.name,
|
|
85
|
+
checkpoint_checksum=compute_file_checksum(checkpoint_path),
|
|
84
86
|
run_id=config.id,
|
|
85
87
|
name=config.run_name,
|
|
86
88
|
project=config.project,
|
nshtrainer/_hf_hub.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
|
-
import
|
|
1
|
+
import contextlib
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
4
|
import re
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from functools import cached_property
|
|
5
7
|
from pathlib import Path
|
|
6
8
|
from typing import TYPE_CHECKING, Any, cast
|
|
7
9
|
|
|
@@ -16,7 +18,6 @@ if TYPE_CHECKING:
|
|
|
16
18
|
from huggingface_hub import HfApi # noqa: F401
|
|
17
19
|
|
|
18
20
|
from .model.base import BaseConfig
|
|
19
|
-
from .trainer.trainer import Trainer
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
log = logging.getLogger(__name__)
|
|
@@ -108,36 +109,6 @@ def _api(token: str | None = None):
|
|
|
108
109
|
return api
|
|
109
110
|
|
|
110
111
|
|
|
111
|
-
def _enabled_and_valid(
|
|
112
|
-
trainer: "Trainer",
|
|
113
|
-
config: HuggingFaceHubConfig,
|
|
114
|
-
*,
|
|
115
|
-
rank_zero_only: bool,
|
|
116
|
-
):
|
|
117
|
-
# Make sure this is enabled and the config is valid
|
|
118
|
-
if not config:
|
|
119
|
-
return None
|
|
120
|
-
|
|
121
|
-
# If `rank_zero_only` and this is not rank 0, stop here.
|
|
122
|
-
if rank_zero_only and not trainer.is_global_zero:
|
|
123
|
-
return
|
|
124
|
-
|
|
125
|
-
# Make sure that `huggingface_hub` is installed
|
|
126
|
-
try:
|
|
127
|
-
import huggingface_hub # noqa: F401
|
|
128
|
-
except ImportError:
|
|
129
|
-
log.exception(
|
|
130
|
-
"Could not import `huggingface_hub`. Please install it using `pip install huggingface_hub`."
|
|
131
|
-
)
|
|
132
|
-
return None
|
|
133
|
-
|
|
134
|
-
# Create and authenticate the API instance
|
|
135
|
-
if (api := getattr(trainer, "_hf_hub_api", None)) is None:
|
|
136
|
-
api = _api(config.token)
|
|
137
|
-
setattr(trainer, "_hf_hub_api", api)
|
|
138
|
-
return cast(huggingface_hub.HfApi, api)
|
|
139
|
-
|
|
140
|
-
|
|
141
112
|
def _repo_name(api: "HfApi", root_config: "BaseConfig"):
|
|
142
113
|
username = None
|
|
143
114
|
if (ac := root_config.trainer.hf_hub.auto_create) and ac.namespace:
|
|
@@ -173,346 +144,210 @@ def _repo_name(api: "HfApi", root_config: "BaseConfig"):
|
|
|
173
144
|
return f"{username}/{repo_name}"
|
|
174
145
|
|
|
175
146
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
repo_name = _repo_name(api, root_config)
|
|
191
|
-
|
|
192
|
-
# Create the repository, if it doesn't exist
|
|
193
|
-
try:
|
|
194
|
-
# Check if the repository exists
|
|
195
|
-
api.repo_info(repo_id=repo_name, repo_type="model")
|
|
196
|
-
log.info(f"Repository '{repo_name}' already exists.")
|
|
197
|
-
except RepositoryNotFoundError:
|
|
198
|
-
# Repository doesn't exist, so create it
|
|
199
|
-
try:
|
|
200
|
-
api.create_repo(
|
|
201
|
-
repo_id=repo_name,
|
|
202
|
-
repo_type="model",
|
|
203
|
-
private=config.auto_create.private,
|
|
204
|
-
exist_ok=True,
|
|
205
|
-
)
|
|
206
|
-
log.info(f"Created new repository '{repo_name}'.")
|
|
207
|
-
except Exception:
|
|
208
|
-
log.exception(f"Failed to create repository '{repo_name}'")
|
|
209
|
-
except Exception:
|
|
210
|
-
log.exception(f"Error checking repository '{repo_name}'")
|
|
211
|
-
|
|
212
|
-
# Upload the config
|
|
213
|
-
_save_config(root_config, trainer=trainer)
|
|
214
|
-
|
|
215
|
-
# Upload the code
|
|
216
|
-
_save_code(repo_name, config=config, trainer=trainer)
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
def _save_code(
|
|
220
|
-
repo_name: str,
|
|
221
|
-
*,
|
|
222
|
-
config: HuggingFaceHubConfig,
|
|
223
|
-
trainer: "Trainer",
|
|
224
|
-
):
|
|
225
|
-
if (
|
|
226
|
-
api := _enabled_and_valid(
|
|
227
|
-
trainer,
|
|
228
|
-
config,
|
|
229
|
-
rank_zero_only=True,
|
|
230
|
-
)
|
|
231
|
-
) is None or not config.save_code:
|
|
232
|
-
return
|
|
233
|
-
|
|
234
|
-
# If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
|
|
235
|
-
# then upload all contents within the snapshot directory to the repository.
|
|
236
|
-
snapshot_dir = os.environ.get(SNAPSHOT_DIR)
|
|
237
|
-
if not snapshot_dir:
|
|
238
|
-
log.info("No snapshot directory found. Skipping upload.")
|
|
239
|
-
return
|
|
240
|
-
|
|
241
|
-
snapshot_path = Path(snapshot_dir)
|
|
242
|
-
if not snapshot_path.exists() or not snapshot_path.is_dir():
|
|
243
|
-
log.warning(
|
|
244
|
-
f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
|
|
245
|
-
)
|
|
246
|
-
return
|
|
247
|
-
|
|
248
|
-
try:
|
|
249
|
-
api.upload_folder(
|
|
250
|
-
folder_path=str(snapshot_path),
|
|
251
|
-
repo_id=repo_name,
|
|
252
|
-
repo_type="model",
|
|
253
|
-
path_in_repo="code", # Prefix with "code" folder
|
|
254
|
-
run_as_future=cast(Any, config.save_in_background),
|
|
255
|
-
)
|
|
256
|
-
log.info(
|
|
257
|
-
f"Uploaded snapshot contents to repository '{repo_name}' under 'code' folder."
|
|
258
|
-
)
|
|
259
|
-
except Exception:
|
|
260
|
-
log.exception(
|
|
261
|
-
f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder."
|
|
262
|
-
)
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
def _save_config(
|
|
266
|
-
root_config: "BaseConfig",
|
|
267
|
-
*,
|
|
268
|
-
trainer: "Trainer",
|
|
269
|
-
):
|
|
270
|
-
config = root_config.trainer.hf_hub
|
|
271
|
-
if (
|
|
272
|
-
api := _enabled_and_valid(
|
|
273
|
-
trainer,
|
|
274
|
-
config,
|
|
275
|
-
rank_zero_only=True,
|
|
147
|
+
@dataclass
|
|
148
|
+
class _Upload:
|
|
149
|
+
local_path: Path
|
|
150
|
+
path_in_repo: Path
|
|
151
|
+
|
|
152
|
+
@classmethod
|
|
153
|
+
def from_local_path(
|
|
154
|
+
cls,
|
|
155
|
+
local_path: Path,
|
|
156
|
+
root_config: "BaseConfig",
|
|
157
|
+
):
|
|
158
|
+
# Resolve the checkpoint directory
|
|
159
|
+
checkpoint_dir = root_config.directory.resolve_subdirectory(
|
|
160
|
+
root_config.id, "checkpoint"
|
|
276
161
|
)
|
|
277
|
-
) is None or not config.save_config:
|
|
278
|
-
return
|
|
279
|
-
|
|
280
|
-
# Convert the root config to a JSON string
|
|
281
|
-
# NOTE: This is a utf-8 string.
|
|
282
|
-
config_json = root_config.model_dump_json(indent=4)
|
|
283
162
|
|
|
284
|
-
# Resolve the repository name
|
|
285
|
-
repo_name = _repo_name(api, root_config)
|
|
286
|
-
|
|
287
|
-
# Upload the config file to the repository
|
|
288
|
-
try:
|
|
289
|
-
api.upload_file(
|
|
290
|
-
path_or_fileobj=config_json.encode("utf-8"),
|
|
291
|
-
path_in_repo="config.json",
|
|
292
|
-
repo_id=repo_name,
|
|
293
|
-
repo_type="model",
|
|
294
|
-
run_as_future=cast(Any, config.save_in_background),
|
|
295
|
-
)
|
|
296
|
-
log.info(f"Uploaded config.json to repository '{repo_name}'.")
|
|
297
|
-
except Exception:
|
|
298
|
-
log.exception(f"Failed to upload config.json to repository '{repo_name}'.")
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
def _save_checkpoint_files(
|
|
302
|
-
trainer: "Trainer",
|
|
303
|
-
paths: list[Path],
|
|
304
|
-
*,
|
|
305
|
-
root_config: "BaseConfig",
|
|
306
|
-
):
|
|
307
|
-
config = root_config.trainer.hf_hub
|
|
308
|
-
if (
|
|
309
|
-
api := _enabled_and_valid(trainer, config, rank_zero_only=True)
|
|
310
|
-
) is None or not config.save_checkpoints:
|
|
311
|
-
return
|
|
312
|
-
|
|
313
|
-
# Resolve the checkpoint directory
|
|
314
|
-
checkpoint_dir = root_config.directory.resolve_subdirectory(
|
|
315
|
-
root_config.id, "checkpoint"
|
|
316
|
-
)
|
|
317
|
-
|
|
318
|
-
# Resolve the repository name
|
|
319
|
-
repo_name = _repo_name(api, root_config)
|
|
320
|
-
|
|
321
|
-
# Let's read all the files to memory right now,
|
|
322
|
-
# in case they get used/removed by other processes.
|
|
323
|
-
# Read all the files to memory
|
|
324
|
-
file_contents: list[bytes | None] = []
|
|
325
|
-
for p in paths:
|
|
326
|
-
assert not p.is_symlink(), f"Path {p} is a symlink."
|
|
327
|
-
assert p.is_file(), f"Path {p} is not a file."
|
|
328
|
-
try:
|
|
329
|
-
with open(p, "rb") as f:
|
|
330
|
-
file_contents.append(f.read())
|
|
331
|
-
except IOError:
|
|
332
|
-
log.exception(f"Failed to read checkpoint file {p}.")
|
|
333
|
-
file_contents.append(None)
|
|
334
|
-
|
|
335
|
-
# Remove the paths that failed to read
|
|
336
|
-
file_contents_and_paths = [
|
|
337
|
-
(contents, p)
|
|
338
|
-
for contents, p in zip(file_contents, paths)
|
|
339
|
-
if contents is not None
|
|
340
|
-
]
|
|
341
|
-
|
|
342
|
-
# Upload the checkpoint files to the repository
|
|
343
|
-
for contents, p in file_contents_and_paths:
|
|
344
163
|
try:
|
|
345
|
-
relative_path =
|
|
164
|
+
relative_path = local_path.relative_to(checkpoint_dir)
|
|
346
165
|
except ValueError:
|
|
347
|
-
|
|
348
|
-
f"Checkpoint path {
|
|
166
|
+
raise ValueError(
|
|
167
|
+
f"Checkpoint path {local_path} is not within the checkpoint directory {checkpoint_dir}."
|
|
349
168
|
)
|
|
350
|
-
continue
|
|
351
169
|
|
|
352
170
|
# Prefix the path in repo with "checkpoints"
|
|
353
171
|
path_in_repo = Path("checkpoints") / relative_path
|
|
354
172
|
|
|
355
|
-
|
|
356
|
-
try:
|
|
357
|
-
api.upload_file(
|
|
358
|
-
path_or_fileobj=io.BytesIO(contents),
|
|
359
|
-
path_in_repo=str(path_in_repo),
|
|
360
|
-
repo_id=repo_name,
|
|
361
|
-
repo_type="model",
|
|
362
|
-
run_as_future=cast(Any, config.save_in_background),
|
|
363
|
-
)
|
|
364
|
-
log.info(
|
|
365
|
-
f"Uploaded checkpoint file {relative_path} to repository '{repo_name}'."
|
|
366
|
-
)
|
|
367
|
-
except Exception:
|
|
368
|
-
log.exception(
|
|
369
|
-
f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}'."
|
|
370
|
-
)
|
|
371
|
-
|
|
372
|
-
log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
|
|
373
|
-
|
|
173
|
+
return cls(local_path=local_path, path_in_repo=path_in_repo)
|
|
374
174
|
|
|
375
|
-
def _save_checkpoint_symlinks(
|
|
376
|
-
trainer: "Trainer",
|
|
377
|
-
paths: list[Path],
|
|
378
|
-
*,
|
|
379
|
-
root_config: "BaseConfig",
|
|
380
|
-
):
|
|
381
|
-
config = root_config.trainer.hf_hub
|
|
382
|
-
if (
|
|
383
|
-
api := _enabled_and_valid(trainer, config, rank_zero_only=True)
|
|
384
|
-
) is None or not config.save_checkpoints:
|
|
385
|
-
return
|
|
386
175
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
176
|
+
class HFHubCallback(NTCallbackBase):
|
|
177
|
+
@contextlib.contextmanager
|
|
178
|
+
def _with_error_handling(self, opeartion: str):
|
|
179
|
+
try:
|
|
180
|
+
yield
|
|
181
|
+
except Exception:
|
|
182
|
+
log.exception(f"Failed to {opeartion}, repo_id={self._repo_id}")
|
|
183
|
+
else:
|
|
184
|
+
log.debug(f"Successfully {opeartion}, repo_id={self._repo_id}")
|
|
391
185
|
|
|
392
|
-
|
|
393
|
-
|
|
186
|
+
def __init__(self, config: HuggingFaceHubConfig):
|
|
187
|
+
super().__init__()
|
|
394
188
|
|
|
395
|
-
|
|
396
|
-
from huggingface_hub.hf_api import CommitOperation, CommitOperationCopy
|
|
189
|
+
self.config = config
|
|
397
190
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
assert p.is_symlink(), f"Path {p} is not a symlink."
|
|
191
|
+
self._repo_id = None
|
|
192
|
+
self._checksum_to_path_in_repo: dict[str, Path] = {}
|
|
401
193
|
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
log.warning(
|
|
406
|
-
f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
|
|
407
|
-
)
|
|
408
|
-
continue
|
|
194
|
+
@override
|
|
195
|
+
def setup(self, trainer, pl_module, stage):
|
|
196
|
+
from .trainer.trainer import Trainer
|
|
409
197
|
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
log.warning(
|
|
414
|
-
f"Checkpoint symlink target {p.resolve()} is not within the checkpoint directory {checkpoint_dir}."
|
|
198
|
+
if not isinstance(trainer, Trainer):
|
|
199
|
+
raise ValueError(
|
|
200
|
+
f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
|
|
415
201
|
)
|
|
416
|
-
continue
|
|
417
202
|
|
|
418
|
-
|
|
419
|
-
dest_path_in_repo = Path("checkpoints") / dest_relative_path
|
|
420
|
-
source_path_in_repo = Path("checkpoints") / source_relative_path
|
|
203
|
+
root_config = cast("BaseConfig", pl_module.hparams)
|
|
421
204
|
|
|
422
|
-
# Create
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
205
|
+
# Create the repository, if it doesn't exist
|
|
206
|
+
self._repo_id = self.api.create_repo(
|
|
207
|
+
repo_id=_repo_name(self.api, root_config),
|
|
208
|
+
repo_type="model",
|
|
209
|
+
private=self.config.auto_create.private,
|
|
210
|
+
exist_ok=True,
|
|
426
211
|
)
|
|
427
|
-
commits.append(copy_op)
|
|
428
212
|
|
|
429
|
-
|
|
213
|
+
# Upload the config and code
|
|
214
|
+
self._save_config(root_config)
|
|
215
|
+
self._save_code()
|
|
430
216
|
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
repo_type="model",
|
|
435
|
-
commit_message="Copy checkpoint symlinks",
|
|
436
|
-
operations=commits,
|
|
437
|
-
run_as_future=cast(Any, config.save_in_background),
|
|
438
|
-
)
|
|
439
|
-
log.info(
|
|
440
|
-
f"Created commit to copy checkpoint symlinks to repository '{repo_name}'."
|
|
441
|
-
)
|
|
442
|
-
except Exception:
|
|
443
|
-
log.exception(
|
|
444
|
-
f"Failed to create commit to copy checkpoint symlinks to repository '{repo_name}'"
|
|
445
|
-
)
|
|
217
|
+
@override
|
|
218
|
+
def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
|
|
219
|
+
root_config = cast("BaseConfig", pl_module.hparams)
|
|
446
220
|
|
|
447
|
-
|
|
221
|
+
# If HF Hub is enabled, then we upload
|
|
222
|
+
if self.config and trainer.is_global_zero:
|
|
223
|
+
with self._with_error_handling("save checkpoints"):
|
|
224
|
+
self._save_checkpoint(
|
|
225
|
+
_Upload.from_local_path(ckpt_path, root_config),
|
|
226
|
+
_Upload.from_local_path(metadata_path, root_config)
|
|
227
|
+
if metadata_path is not None
|
|
228
|
+
else None,
|
|
229
|
+
)
|
|
448
230
|
|
|
231
|
+
@cached_property
|
|
232
|
+
def api(self):
|
|
233
|
+
# Create and authenticate the API instance
|
|
234
|
+
if (api := _api(self.config.token)) is None:
|
|
235
|
+
raise ValueError("Failed to create Hugging Face Hub API instance.")
|
|
236
|
+
return api
|
|
237
|
+
|
|
238
|
+
@property
|
|
239
|
+
def repo_id(self):
|
|
240
|
+
if self._repo_id is None:
|
|
241
|
+
raise ValueError("Repository id has not been initialized.")
|
|
242
|
+
return self._repo_id
|
|
243
|
+
|
|
244
|
+
def _save_config(self, root_config: "BaseConfig"):
|
|
245
|
+
with self._with_error_handling("upload config"):
|
|
246
|
+
self.api.upload_file(
|
|
247
|
+
path_or_fileobj=root_config.model_dump_json(indent=4).encode("utf-8"),
|
|
248
|
+
path_in_repo="config.json",
|
|
249
|
+
repo_id=self.repo_id,
|
|
250
|
+
repo_type="model",
|
|
251
|
+
run_as_future=cast(Any, self.config.save_in_background),
|
|
252
|
+
)
|
|
449
253
|
|
|
450
|
-
def
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
254
|
+
def _save_code(self):
|
|
255
|
+
# If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
|
|
256
|
+
# then upload all contents within the snapshot directory to the repository.
|
|
257
|
+
if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
|
|
258
|
+
log.debug("No snapshot directory found. Skipping upload.")
|
|
259
|
+
return
|
|
260
|
+
|
|
261
|
+
with self._with_error_handling("save code"):
|
|
262
|
+
snapshot_dir = Path(snapshot_dir)
|
|
263
|
+
if not snapshot_dir.exists() or not snapshot_dir.is_dir():
|
|
264
|
+
log.warning(
|
|
265
|
+
f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
|
|
266
|
+
)
|
|
267
|
+
return
|
|
456
268
|
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
269
|
+
self.api.upload_folder(
|
|
270
|
+
folder_path=str(snapshot_dir),
|
|
271
|
+
repo_id=self.repo_id,
|
|
272
|
+
repo_type="model",
|
|
273
|
+
path_in_repo="code", # Prefix with "code" folder
|
|
274
|
+
run_as_future=cast(Any, self.config.save_in_background),
|
|
275
|
+
)
|
|
461
276
|
|
|
462
|
-
|
|
463
|
-
|
|
277
|
+
def _save_file(self, p: _Upload):
|
|
278
|
+
with self._with_error_handling("save file"):
|
|
279
|
+
# Upload the checkpoint files to the repository
|
|
280
|
+
self.api.upload_file(
|
|
281
|
+
path_or_fileobj=p.local_path,
|
|
282
|
+
path_in_repo=str(p.path_in_repo),
|
|
283
|
+
repo_id=self.repo_id,
|
|
284
|
+
repo_type="model",
|
|
285
|
+
run_as_future=cast(Any, self.config.save_in_background),
|
|
286
|
+
)
|
|
464
287
|
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
folder_path=str(checkpoint_dir),
|
|
469
|
-
repo_id=repo_name,
|
|
470
|
-
repo_type="model",
|
|
471
|
-
path_in_repo="checkpoints",
|
|
472
|
-
run_as_future=cast(Any, config.save_in_background),
|
|
473
|
-
)
|
|
474
|
-
log.info(f"Uploaded checkpoint directory to repository '{repo_name}'.")
|
|
475
|
-
except Exception:
|
|
476
|
-
log.exception(
|
|
477
|
-
f"Failed to upload checkpoint directory to repository '{repo_name}'."
|
|
478
|
-
)
|
|
288
|
+
def _copy_file(self, source_path_in_repo: Path, dest_path_in_repo: Path):
|
|
289
|
+
# Create a commit for copying the files
|
|
290
|
+
from huggingface_hub.hf_api import CommitOperationCopy
|
|
479
291
|
|
|
480
|
-
|
|
292
|
+
with self._with_error_handling("copy file"):
|
|
293
|
+
copy_op = CommitOperationCopy(
|
|
294
|
+
src_path_in_repo=str(source_path_in_repo),
|
|
295
|
+
path_in_repo=str(dest_path_in_repo),
|
|
296
|
+
)
|
|
481
297
|
|
|
298
|
+
self.api.create_commit(
|
|
299
|
+
repo_id=self.repo_id,
|
|
300
|
+
repo_type="model",
|
|
301
|
+
commit_message="Copy checkpoint file",
|
|
302
|
+
operations=[copy_op],
|
|
303
|
+
run_as_future=cast(Any, self.config.save_in_background),
|
|
304
|
+
)
|
|
482
305
|
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
self.config = config
|
|
306
|
+
def _save_checkpoint(self, path: _Upload, metadata_path: _Upload | None):
|
|
307
|
+
if not self.config.save_checkpoints:
|
|
308
|
+
return
|
|
487
309
|
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
310
|
+
# If no metadata, just save regularly.
|
|
311
|
+
if metadata_path is None:
|
|
312
|
+
self._save_file(path)
|
|
313
|
+
return
|
|
491
314
|
|
|
492
|
-
if
|
|
493
|
-
|
|
494
|
-
|
|
315
|
+
# Otherwise, let's check to see if we've already uploaded the metadata.
|
|
316
|
+
# If so, we can just copy the checkpoint file.
|
|
317
|
+
from ._checkpoint.metadata import CheckpointMetadata
|
|
318
|
+
|
|
319
|
+
metadata = CheckpointMetadata.from_file(metadata_path.local_path)
|
|
320
|
+
if (
|
|
321
|
+
existing_ckpt_path := self._checksum_to_path_in_repo.get(
|
|
322
|
+
metadata.checkpoint_checksum
|
|
323
|
+
)
|
|
324
|
+
) is not None:
|
|
325
|
+
self._copy_file(existing_ckpt_path, path.path_in_repo)
|
|
326
|
+
else:
|
|
327
|
+
# Otherwise, we save the checkpoint & keep the checksum so we don't
|
|
328
|
+
# re-upload the same file again.
|
|
329
|
+
self._save_file(path)
|
|
330
|
+
self._checksum_to_path_in_repo[metadata.checkpoint_checksum] = (
|
|
331
|
+
path.path_in_repo
|
|
495
332
|
)
|
|
496
333
|
|
|
497
|
-
|
|
498
|
-
|
|
334
|
+
# Save the metadata file
|
|
335
|
+
# NOTE: This file is fairly small, so we can just upload it directly.
|
|
336
|
+
# No need to copy.
|
|
337
|
+
self._save_file(metadata_path)
|
|
499
338
|
|
|
500
339
|
@override
|
|
501
|
-
def
|
|
502
|
-
|
|
503
|
-
|
|
340
|
+
def state_dict(self):
|
|
341
|
+
return {
|
|
342
|
+
"repo_id": self._repo_id,
|
|
343
|
+
"checksum_to_path_in_repo": {
|
|
344
|
+
k: str(v) for k, v in self._checksum_to_path_in_repo.items()
|
|
345
|
+
},
|
|
346
|
+
}
|
|
504
347
|
|
|
505
348
|
@override
|
|
506
|
-
def
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
# Upload the regular files first, then the symlinks
|
|
512
|
-
all_paths = [p for p in (ckpt_path, metadata_path) if p is not None]
|
|
513
|
-
if regular_paths := [p for p in all_paths if not p.is_symlink()]:
|
|
514
|
-
_save_checkpoint_files(trainer, regular_paths, root_config=root_config)
|
|
515
|
-
if symlink_paths := [p for p in all_paths if p.is_symlink()]:
|
|
516
|
-
_save_checkpoint_symlinks(
|
|
517
|
-
trainer, symlink_paths, root_config=root_config
|
|
518
|
-
)
|
|
349
|
+
def load_state_dict(self, state_dict):
|
|
350
|
+
self._repo_id = state_dict["repo_id"]
|
|
351
|
+
self._checksum_to_path_in_repo = {
|
|
352
|
+
k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
|
|
353
|
+
}
|
|
@@ -11,7 +11,6 @@ from typing_extensions import TypeVar, override
|
|
|
11
11
|
|
|
12
12
|
from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
|
|
13
13
|
from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
|
|
14
|
-
from ...util.path import find_symlinks
|
|
15
14
|
from ..base import CallbackConfigBase
|
|
16
15
|
|
|
17
16
|
if TYPE_CHECKING:
|
|
@@ -117,47 +116,9 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
117
116
|
)
|
|
118
117
|
continue
|
|
119
118
|
|
|
120
|
-
|
|
121
|
-
trainer, old_ckpt_path, metadata=True
|
|
122
|
-
)
|
|
119
|
+
_remove_checkpoint(trainer, old_ckpt_path, metadata=True)
|
|
123
120
|
log.debug(f"Removed old checkpoint: {old_ckpt_path}")
|
|
124
121
|
|
|
125
|
-
def _remove_checkpoint_with_link_support(
|
|
126
|
-
self,
|
|
127
|
-
trainer: Trainer,
|
|
128
|
-
path: Path,
|
|
129
|
-
metadata: bool,
|
|
130
|
-
):
|
|
131
|
-
# Find all the symlinks to the checkpoint
|
|
132
|
-
ckpt_callbacks: list[CheckpointBase] = [
|
|
133
|
-
callback
|
|
134
|
-
for callback in trainer.checkpoint_callbacks
|
|
135
|
-
if isinstance(callback, CheckpointBase) and callback is not self
|
|
136
|
-
]
|
|
137
|
-
symlink_paths = find_symlinks(
|
|
138
|
-
path,
|
|
139
|
-
*[callback.dirpath for callback in ckpt_callbacks],
|
|
140
|
-
glob_pattern=f"*.{self.extension()}",
|
|
141
|
-
)
|
|
142
|
-
|
|
143
|
-
# If there are no symlinks, just remove the checkpoint
|
|
144
|
-
if not symlink_paths:
|
|
145
|
-
_remove_checkpoint(trainer, path, metadata=metadata)
|
|
146
|
-
return
|
|
147
|
-
|
|
148
|
-
log.debug(
|
|
149
|
-
f"Removing checkpoint with symlinks: {path}, symlinks: {symlink_paths}"
|
|
150
|
-
)
|
|
151
|
-
|
|
152
|
-
# For the first symlink, we can just move the checkpoint file
|
|
153
|
-
# to the symlink path. For the rest, we need to make new symlinks.
|
|
154
|
-
new_target = symlink_paths.pop(0)
|
|
155
|
-
path.rename(new_target)
|
|
156
|
-
log.debug(f"New symlink target: {new_target}")
|
|
157
|
-
|
|
158
|
-
for symlink_path in symlink_paths:
|
|
159
|
-
_link_checkpoint(new_target, symlink_path, metadata=False)
|
|
160
|
-
|
|
161
122
|
def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
|
|
162
123
|
current_metrics: dict[str, Any] = {
|
|
163
124
|
"epoch": trainer.current_epoch,
|
|
@@ -195,6 +156,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
195
156
|
filepath,
|
|
196
157
|
self.config.save_weights_only,
|
|
197
158
|
use_checkpoint_cache=None,
|
|
159
|
+
ckpt_cache_use_symlink=False,
|
|
198
160
|
)
|
|
199
161
|
|
|
200
162
|
if trainer.is_global_zero:
|
|
@@ -70,5 +70,5 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
|
|
|
70
70
|
|
|
71
71
|
# Events
|
|
72
72
|
@override
|
|
73
|
-
def
|
|
73
|
+
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
|
|
74
74
|
self.save_checkpoints(trainer)
|
|
@@ -39,5 +39,5 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
|
|
|
39
39
|
return True
|
|
40
40
|
|
|
41
41
|
@override
|
|
42
|
-
def
|
|
42
|
+
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
|
|
43
43
|
self.save_checkpoints(trainer)
|
nshtrainer/trainer/trainer.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
+
import shutil
|
|
3
4
|
from collections.abc import Sequence
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from typing import TYPE_CHECKING, Any, cast
|
|
@@ -416,11 +417,12 @@ class Trainer(LightningTrainer):
|
|
|
416
417
|
weights_only: bool = False,
|
|
417
418
|
storage_options: Any | None = None,
|
|
418
419
|
use_checkpoint_cache: bool | None = None,
|
|
420
|
+
ckpt_cache_use_symlink: bool = False,
|
|
419
421
|
):
|
|
420
422
|
lm = self._base_module
|
|
421
|
-
|
|
423
|
+
root_config = cast(BaseConfig, lm.hparams)
|
|
422
424
|
if use_checkpoint_cache is None:
|
|
423
|
-
use_checkpoint_cache =
|
|
425
|
+
use_checkpoint_cache = root_config.trainer.use_checkpoint_cache
|
|
424
426
|
|
|
425
427
|
filepath = Path(filepath)
|
|
426
428
|
|
|
@@ -440,7 +442,10 @@ class Trainer(LightningTrainer):
|
|
|
440
442
|
# If we have a cached path, then we symlink it to the new path.
|
|
441
443
|
log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
|
|
442
444
|
if self.is_global_zero:
|
|
443
|
-
|
|
445
|
+
if ckpt_cache_use_symlink:
|
|
446
|
+
_link_checkpoint(cached_path, filepath, metadata=False)
|
|
447
|
+
else:
|
|
448
|
+
shutil.copy(cached_path, filepath)
|
|
444
449
|
self.strategy.barrier("Trainer.save_checkpoint")
|
|
445
450
|
else:
|
|
446
451
|
super().save_checkpoint(filepath, weights_only, storage_options)
|
|
@@ -454,7 +459,7 @@ class Trainer(LightningTrainer):
|
|
|
454
459
|
|
|
455
460
|
# Save the checkpoint metadata
|
|
456
461
|
metadata_path = None
|
|
457
|
-
if
|
|
462
|
+
if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
|
|
458
463
|
# Generate the metadata and write to disk
|
|
459
464
|
if (
|
|
460
465
|
metadata_path := _write_checkpoint_metadata(self, lm, filepath)
|
nshtrainer/util/path.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import hashlib
|
|
1
2
|
import os
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
from typing import TypeAlias
|
|
@@ -50,3 +51,20 @@ def find_symlinks(
|
|
|
50
51
|
pass
|
|
51
52
|
|
|
52
53
|
return symlinks
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def compute_file_checksum(file_path: Path) -> str:
|
|
57
|
+
"""
|
|
58
|
+
Calculate the SHA256 checksum of a file.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
file_path (Path): The path to the file.
|
|
62
|
+
|
|
63
|
+
Returns:
|
|
64
|
+
str: The hexadecimal representation of the file's SHA256 checksum.
|
|
65
|
+
"""
|
|
66
|
+
sha256_hash = hashlib.sha256()
|
|
67
|
+
with file_path.open("rb") as f:
|
|
68
|
+
for byte_block in iter(lambda: f.read(4096), b""):
|
|
69
|
+
sha256_hash.update(byte_block)
|
|
70
|
+
return sha256_hash.hexdigest()
|
|
@@ -1,18 +1,18 @@
|
|
|
1
1
|
nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
|
|
2
2
|
nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
|
|
3
3
|
nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
|
|
4
|
-
nshtrainer/_checkpoint/metadata.py,sha256=
|
|
4
|
+
nshtrainer/_checkpoint/metadata.py,sha256=BpxC3VGrgSDHvZpc40A2icjBEDvRNGEsTis9YkFY8Kc,5341
|
|
5
5
|
nshtrainer/_checkpoint/saver.py,sha256=fvRKGI5aeXtsHBOIO4cwGe__wmO-6DiD0-744VASYA4,1500
|
|
6
6
|
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
|
|
7
|
-
nshtrainer/_hf_hub.py,sha256=
|
|
7
|
+
nshtrainer/_hf_hub.py,sha256=0K3uWa8hd2KyGuUYM7OXARcA7vuUiWWGSlP2USysY7o,12066
|
|
8
8
|
nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
|
|
9
9
|
nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
|
|
10
10
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
11
11
|
nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
|
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
|
|
13
|
-
nshtrainer/callbacks/checkpoint/_base.py,sha256=
|
|
14
|
-
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=
|
|
15
|
-
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=
|
|
13
|
+
nshtrainer/callbacks/checkpoint/_base.py,sha256=pLhLj7tvfY5czGY_vT0xRfWHzGJYC4iOBRLokFVq0mE,6733
|
|
14
|
+
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
|
|
15
|
+
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
|
|
16
16
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
|
|
17
17
|
nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
|
|
18
18
|
nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
|
|
@@ -78,15 +78,15 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
|
|
|
78
78
|
nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
|
|
79
79
|
nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
|
|
80
80
|
nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
|
|
81
|
-
nshtrainer/trainer/trainer.py,sha256=
|
|
81
|
+
nshtrainer/trainer/trainer.py,sha256=Zwdcqfmrr7yuonsp4VrNOget8wkaZY9lf-_yeJ94lkk,19397
|
|
82
82
|
nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
|
|
83
83
|
nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
|
|
84
84
|
nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
|
|
85
|
-
nshtrainer/util/path.py,sha256=
|
|
85
|
+
nshtrainer/util/path.py,sha256=RUkIOrlj9b8zPPXE3JLhdihBNitJSUWw1whZ33u-2Yk,2005
|
|
86
86
|
nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
87
87
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
88
88
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
89
89
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
90
|
-
nshtrainer-0.
|
|
91
|
-
nshtrainer-0.
|
|
92
|
-
nshtrainer-0.
|
|
90
|
+
nshtrainer-0.25.0.dist-info/METADATA,sha256=Rqdeh2yp2AhZ_nOHlD47v5YPDrLc2fHN6WGwqJnDv04,935
|
|
91
|
+
nshtrainer-0.25.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
92
|
+
nshtrainer-0.25.0.dist-info/RECORD,,
|
|
File without changes
|