nshtrainer 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/metadata.py +3 -1
- nshtrainer/_hf_hub.py +171 -336
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/trainer/trainer.py +0 -11
- nshtrainer/util/path.py +18 -0
- {nshtrainer-0.24.0.dist-info → nshtrainer-0.25.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.24.0.dist-info → nshtrainer-0.25.0.dist-info}/RECORD +9 -9
- {nshtrainer-0.24.0.dist-info → nshtrainer-0.25.0.dist-info}/WHEEL +0 -0
|
@@ -11,7 +11,7 @@ import numpy as np
|
|
|
11
11
|
import torch
|
|
12
12
|
|
|
13
13
|
from ..util._environment_info import EnvironmentConfig
|
|
14
|
-
from ..util.path import get_relative_path
|
|
14
|
+
from ..util.path import compute_file_checksum, get_relative_path
|
|
15
15
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
17
17
|
from ..model import BaseConfig, LightningModuleBase
|
|
@@ -28,6 +28,7 @@ class CheckpointMetadata(C.Config):
|
|
|
28
28
|
|
|
29
29
|
checkpoint_path: Path
|
|
30
30
|
checkpoint_filename: str
|
|
31
|
+
checkpoint_checksum: str
|
|
31
32
|
|
|
32
33
|
run_id: str
|
|
33
34
|
name: str
|
|
@@ -81,6 +82,7 @@ def _generate_checkpoint_metadata(
|
|
|
81
82
|
# moving the checkpoint directory
|
|
82
83
|
checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
|
|
83
84
|
checkpoint_filename=checkpoint_path.name,
|
|
85
|
+
checkpoint_checksum=compute_file_checksum(checkpoint_path),
|
|
84
86
|
run_id=config.id,
|
|
85
87
|
name=config.run_name,
|
|
86
88
|
project=config.project,
|
nshtrainer/_hf_hub.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
|
-
import
|
|
1
|
+
import contextlib
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
4
|
import re
|
|
5
|
+
from dataclasses import dataclass
|
|
5
6
|
from functools import cached_property
|
|
6
7
|
from pathlib import Path
|
|
7
8
|
from typing import TYPE_CHECKING, Any, cast
|
|
@@ -17,7 +18,6 @@ if TYPE_CHECKING:
|
|
|
17
18
|
from huggingface_hub import HfApi # noqa: F401
|
|
18
19
|
|
|
19
20
|
from .model.base import BaseConfig
|
|
20
|
-
from .trainer.trainer import Trainer
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
log = logging.getLogger(__name__)
|
|
@@ -109,24 +109,6 @@ def _api(token: str | None = None):
|
|
|
109
109
|
return api
|
|
110
110
|
|
|
111
111
|
|
|
112
|
-
def _enabled_and_valid(
|
|
113
|
-
trainer: "Trainer",
|
|
114
|
-
callback: "HFHubCallback",
|
|
115
|
-
config: HuggingFaceHubConfig,
|
|
116
|
-
*,
|
|
117
|
-
rank_zero_only: bool,
|
|
118
|
-
):
|
|
119
|
-
# Make sure this is enabled and the config is valid
|
|
120
|
-
if not config:
|
|
121
|
-
return None
|
|
122
|
-
|
|
123
|
-
# If `rank_zero_only` and this is not rank 0, stop here.
|
|
124
|
-
if rank_zero_only and not trainer.is_global_zero:
|
|
125
|
-
return None
|
|
126
|
-
|
|
127
|
-
return callback._hf_hub_api
|
|
128
|
-
|
|
129
|
-
|
|
130
112
|
def _repo_name(api: "HfApi", root_config: "BaseConfig"):
|
|
131
113
|
username = None
|
|
132
114
|
if (ac := root_config.trainer.hf_hub.auto_create) and ac.namespace:
|
|
@@ -162,357 +144,210 @@ def _repo_name(api: "HfApi", root_config: "BaseConfig"):
|
|
|
162
144
|
return f"{username}/{repo_name}"
|
|
163
145
|
|
|
164
146
|
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
147
|
+
@dataclass
|
|
148
|
+
class _Upload:
|
|
149
|
+
local_path: Path
|
|
150
|
+
path_in_repo: Path
|
|
151
|
+
|
|
152
|
+
@classmethod
|
|
153
|
+
def from_local_path(
|
|
154
|
+
cls,
|
|
155
|
+
local_path: Path,
|
|
156
|
+
root_config: "BaseConfig",
|
|
157
|
+
):
|
|
158
|
+
# Resolve the checkpoint directory
|
|
159
|
+
checkpoint_dir = root_config.directory.resolve_subdirectory(
|
|
160
|
+
root_config.id, "checkpoint"
|
|
173
161
|
)
|
|
174
|
-
) is None or not config.auto_create:
|
|
175
|
-
return
|
|
176
|
-
|
|
177
|
-
from huggingface_hub.utils import RepositoryNotFoundError
|
|
178
|
-
|
|
179
|
-
# Resolve the repository name
|
|
180
|
-
repo_name = _repo_name(api, root_config)
|
|
181
162
|
|
|
182
|
-
# Create the repository, if it doesn't exist
|
|
183
|
-
try:
|
|
184
|
-
# Check if the repository exists
|
|
185
|
-
api.repo_info(repo_id=repo_name, repo_type="model")
|
|
186
|
-
log.info(f"Repository '{repo_name}' already exists.")
|
|
187
|
-
except RepositoryNotFoundError:
|
|
188
|
-
# Repository doesn't exist, so create it
|
|
189
163
|
try:
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
exist_ok=True,
|
|
164
|
+
relative_path = local_path.relative_to(checkpoint_dir)
|
|
165
|
+
except ValueError:
|
|
166
|
+
raise ValueError(
|
|
167
|
+
f"Checkpoint path {local_path} is not within the checkpoint directory {checkpoint_dir}."
|
|
195
168
|
)
|
|
196
|
-
log.info(f"Created new repository '{repo_name}'.")
|
|
197
|
-
except Exception:
|
|
198
|
-
log.exception(f"Failed to create repository '{repo_name}'")
|
|
199
|
-
except Exception:
|
|
200
|
-
log.exception(f"Error checking repository '{repo_name}'")
|
|
201
|
-
|
|
202
|
-
# Upload the config
|
|
203
|
-
_save_config(root_config, trainer=trainer, callback=callback)
|
|
204
|
-
|
|
205
|
-
# Upload the code
|
|
206
|
-
_save_code(repo_name, config=config, trainer=trainer, callback=callback)
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
def _save_code(
|
|
210
|
-
repo_name: str,
|
|
211
|
-
*,
|
|
212
|
-
config: HuggingFaceHubConfig,
|
|
213
|
-
trainer: "Trainer",
|
|
214
|
-
callback: "HFHubCallback",
|
|
215
|
-
):
|
|
216
|
-
if (
|
|
217
|
-
api := _enabled_and_valid(
|
|
218
|
-
trainer,
|
|
219
|
-
callback,
|
|
220
|
-
config,
|
|
221
|
-
rank_zero_only=True,
|
|
222
|
-
)
|
|
223
|
-
) is None or not config.save_code:
|
|
224
|
-
return
|
|
225
|
-
|
|
226
|
-
# If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
|
|
227
|
-
# then upload all contents within the snapshot directory to the repository.
|
|
228
|
-
snapshot_dir = os.environ.get(SNAPSHOT_DIR)
|
|
229
|
-
if not snapshot_dir:
|
|
230
|
-
log.info("No snapshot directory found. Skipping upload.")
|
|
231
|
-
return
|
|
232
|
-
|
|
233
|
-
snapshot_path = Path(snapshot_dir)
|
|
234
|
-
if not snapshot_path.exists() or not snapshot_path.is_dir():
|
|
235
|
-
log.warning(
|
|
236
|
-
f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
|
|
237
|
-
)
|
|
238
|
-
return
|
|
239
169
|
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
folder_path=str(snapshot_path),
|
|
243
|
-
repo_id=repo_name,
|
|
244
|
-
repo_type="model",
|
|
245
|
-
path_in_repo="code", # Prefix with "code" folder
|
|
246
|
-
run_as_future=cast(Any, config.save_in_background),
|
|
247
|
-
)
|
|
248
|
-
log.info(
|
|
249
|
-
f"Uploaded snapshot contents to repository '{repo_name}' under 'code' folder."
|
|
250
|
-
)
|
|
251
|
-
except Exception:
|
|
252
|
-
log.exception(
|
|
253
|
-
f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder."
|
|
254
|
-
)
|
|
170
|
+
# Prefix the path in repo with "checkpoints"
|
|
171
|
+
path_in_repo = Path("checkpoints") / relative_path
|
|
255
172
|
|
|
173
|
+
return cls(local_path=local_path, path_in_repo=path_in_repo)
|
|
256
174
|
|
|
257
|
-
def _save_config(
|
|
258
|
-
root_config: "BaseConfig",
|
|
259
|
-
*,
|
|
260
|
-
trainer: "Trainer",
|
|
261
|
-
callback: "HFHubCallback",
|
|
262
|
-
):
|
|
263
|
-
config = root_config.trainer.hf_hub
|
|
264
|
-
if (
|
|
265
|
-
api := _enabled_and_valid(
|
|
266
|
-
trainer,
|
|
267
|
-
callback,
|
|
268
|
-
config,
|
|
269
|
-
rank_zero_only=True,
|
|
270
|
-
)
|
|
271
|
-
) is None or not config.save_config:
|
|
272
|
-
return
|
|
273
175
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
176
|
+
class HFHubCallback(NTCallbackBase):
|
|
177
|
+
@contextlib.contextmanager
|
|
178
|
+
def _with_error_handling(self, opeartion: str):
|
|
179
|
+
try:
|
|
180
|
+
yield
|
|
181
|
+
except Exception:
|
|
182
|
+
log.exception(f"Failed to {opeartion}, repo_id={self._repo_id}")
|
|
183
|
+
else:
|
|
184
|
+
log.debug(f"Successfully {opeartion}, repo_id={self._repo_id}")
|
|
277
185
|
|
|
278
|
-
|
|
279
|
-
|
|
186
|
+
def __init__(self, config: HuggingFaceHubConfig):
|
|
187
|
+
super().__init__()
|
|
280
188
|
|
|
281
|
-
|
|
282
|
-
try:
|
|
283
|
-
api.upload_file(
|
|
284
|
-
path_or_fileobj=config_json.encode("utf-8"),
|
|
285
|
-
path_in_repo="config.json",
|
|
286
|
-
repo_id=repo_name,
|
|
287
|
-
repo_type="model",
|
|
288
|
-
run_as_future=cast(Any, config.save_in_background),
|
|
289
|
-
)
|
|
290
|
-
log.info(f"Uploaded config.json to repository '{repo_name}'.")
|
|
291
|
-
except Exception:
|
|
292
|
-
log.exception(f"Failed to upload config.json to repository '{repo_name}'.")
|
|
189
|
+
self.config = config
|
|
293
190
|
|
|
191
|
+
self._repo_id = None
|
|
192
|
+
self._checksum_to_path_in_repo: dict[str, Path] = {}
|
|
294
193
|
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
194
|
+
@override
|
|
195
|
+
def setup(self, trainer, pl_module, stage):
|
|
196
|
+
from .trainer.trainer import Trainer
|
|
298
197
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
log.info(f"Failed to check if path {p} is a symlink.", exc_info=True)
|
|
198
|
+
if not isinstance(trainer, Trainer):
|
|
199
|
+
raise ValueError(
|
|
200
|
+
f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
|
|
201
|
+
)
|
|
304
202
|
|
|
305
|
-
|
|
203
|
+
root_config = cast("BaseConfig", pl_module.hparams)
|
|
306
204
|
|
|
205
|
+
# Create the repository, if it doesn't exist
|
|
206
|
+
self._repo_id = self.api.create_repo(
|
|
207
|
+
repo_id=_repo_name(self.api, root_config),
|
|
208
|
+
repo_type="model",
|
|
209
|
+
private=self.config.auto_create.private,
|
|
210
|
+
exist_ok=True,
|
|
211
|
+
)
|
|
307
212
|
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
213
|
+
# Upload the config and code
|
|
214
|
+
self._save_config(root_config)
|
|
215
|
+
self._save_code()
|
|
311
216
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
resolved := trainer._nshtrainer_checkpoint_link_dict.get(link_path)
|
|
316
|
-
) is not None:
|
|
317
|
-
return resolved
|
|
318
|
-
except Exception:
|
|
319
|
-
log.info(f"Failed to resolve symlink for path {p}.", exc_info=True)
|
|
320
|
-
|
|
321
|
-
return None
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
def _save_checkpoint_files(
|
|
325
|
-
trainer: "Trainer",
|
|
326
|
-
callback: "HFHubCallback",
|
|
327
|
-
paths: list[Path],
|
|
328
|
-
*,
|
|
329
|
-
root_config: "BaseConfig",
|
|
330
|
-
):
|
|
331
|
-
config = root_config.trainer.hf_hub
|
|
332
|
-
if (
|
|
333
|
-
api := _enabled_and_valid(trainer, callback, config, rank_zero_only=True)
|
|
334
|
-
) is None or not config.save_checkpoints:
|
|
335
|
-
return
|
|
336
|
-
|
|
337
|
-
# Resolve the checkpoint directory
|
|
338
|
-
checkpoint_dir = root_config.directory.resolve_subdirectory(
|
|
339
|
-
root_config.id, "checkpoint"
|
|
340
|
-
)
|
|
341
|
-
|
|
342
|
-
# Resolve the repository name
|
|
343
|
-
repo_name = _repo_name(api, root_config)
|
|
344
|
-
|
|
345
|
-
# Let's read all the files to memory right now,
|
|
346
|
-
# in case they get used/removed by other processes.
|
|
347
|
-
# Read all the files to memory
|
|
348
|
-
file_contents: list[bytes | None] = []
|
|
349
|
-
for p in paths:
|
|
350
|
-
assert not _is_link(p, trainer=trainer), f"Path {p} is a symlink."
|
|
351
|
-
assert p.is_file(), f"Path {p} is not a file."
|
|
352
|
-
try:
|
|
353
|
-
with open(p, "rb") as f:
|
|
354
|
-
file_contents.append(f.read())
|
|
355
|
-
except IOError:
|
|
356
|
-
log.exception(f"Failed to read checkpoint file {p}.")
|
|
357
|
-
file_contents.append(None)
|
|
358
|
-
|
|
359
|
-
# Remove the paths that failed to read
|
|
360
|
-
file_contents_and_paths = [
|
|
361
|
-
(contents, p)
|
|
362
|
-
for contents, p in zip(file_contents, paths)
|
|
363
|
-
if contents is not None
|
|
364
|
-
]
|
|
365
|
-
|
|
366
|
-
# Upload the checkpoint files to the repository
|
|
367
|
-
for contents, p in file_contents_and_paths:
|
|
368
|
-
try:
|
|
369
|
-
relative_path = p.relative_to(checkpoint_dir)
|
|
370
|
-
except ValueError:
|
|
371
|
-
log.warning(
|
|
372
|
-
f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
|
|
373
|
-
)
|
|
374
|
-
continue
|
|
217
|
+
@override
|
|
218
|
+
def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
|
|
219
|
+
root_config = cast("BaseConfig", pl_module.hparams)
|
|
375
220
|
|
|
376
|
-
#
|
|
377
|
-
|
|
221
|
+
# If HF Hub is enabled, then we upload
|
|
222
|
+
if self.config and trainer.is_global_zero:
|
|
223
|
+
with self._with_error_handling("save checkpoints"):
|
|
224
|
+
self._save_checkpoint(
|
|
225
|
+
_Upload.from_local_path(ckpt_path, root_config),
|
|
226
|
+
_Upload.from_local_path(metadata_path, root_config)
|
|
227
|
+
if metadata_path is not None
|
|
228
|
+
else None,
|
|
229
|
+
)
|
|
378
230
|
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
231
|
+
@cached_property
|
|
232
|
+
def api(self):
|
|
233
|
+
# Create and authenticate the API instance
|
|
234
|
+
if (api := _api(self.config.token)) is None:
|
|
235
|
+
raise ValueError("Failed to create Hugging Face Hub API instance.")
|
|
236
|
+
return api
|
|
237
|
+
|
|
238
|
+
@property
|
|
239
|
+
def repo_id(self):
|
|
240
|
+
if self._repo_id is None:
|
|
241
|
+
raise ValueError("Repository id has not been initialized.")
|
|
242
|
+
return self._repo_id
|
|
243
|
+
|
|
244
|
+
def _save_config(self, root_config: "BaseConfig"):
|
|
245
|
+
with self._with_error_handling("upload config"):
|
|
246
|
+
self.api.upload_file(
|
|
247
|
+
path_or_fileobj=root_config.model_dump_json(indent=4).encode("utf-8"),
|
|
248
|
+
path_in_repo="config.json",
|
|
249
|
+
repo_id=self.repo_id,
|
|
385
250
|
repo_type="model",
|
|
386
|
-
run_as_future=cast(Any, config.save_in_background),
|
|
251
|
+
run_as_future=cast(Any, self.config.save_in_background),
|
|
387
252
|
)
|
|
388
|
-
log.info(
|
|
389
|
-
f"Uploaded checkpoint file {relative_path} to repository '{repo_name}'."
|
|
390
|
-
)
|
|
391
|
-
except Exception:
|
|
392
|
-
log.exception(
|
|
393
|
-
f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}'."
|
|
394
|
-
)
|
|
395
|
-
|
|
396
|
-
log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
def _save_checkpoint_symlinks(
|
|
400
|
-
trainer: "Trainer",
|
|
401
|
-
callback: "HFHubCallback",
|
|
402
|
-
paths: list[Path],
|
|
403
|
-
*,
|
|
404
|
-
root_config: "BaseConfig",
|
|
405
|
-
):
|
|
406
|
-
config = root_config.trainer.hf_hub
|
|
407
|
-
if (
|
|
408
|
-
api := _enabled_and_valid(trainer, callback, config, rank_zero_only=True)
|
|
409
|
-
) is None or not config.save_checkpoints:
|
|
410
|
-
return
|
|
411
|
-
|
|
412
|
-
# Resolve the checkpoint directory
|
|
413
|
-
checkpoint_dir = root_config.directory.resolve_subdirectory(
|
|
414
|
-
root_config.id, "checkpoint"
|
|
415
|
-
)
|
|
416
|
-
|
|
417
|
-
# Resolve the repository name
|
|
418
|
-
repo_name = _repo_name(api, root_config)
|
|
419
|
-
|
|
420
|
-
# Create a commit for copying the files
|
|
421
|
-
from huggingface_hub.hf_api import CommitOperation, CommitOperationCopy
|
|
422
253
|
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
254
|
+
def _save_code(self):
|
|
255
|
+
# If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
|
|
256
|
+
# then upload all contents within the snapshot directory to the repository.
|
|
257
|
+
if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
|
|
258
|
+
log.debug("No snapshot directory found. Skipping upload.")
|
|
259
|
+
return
|
|
260
|
+
|
|
261
|
+
with self._with_error_handling("save code"):
|
|
262
|
+
snapshot_dir = Path(snapshot_dir)
|
|
263
|
+
if not snapshot_dir.exists() or not snapshot_dir.is_dir():
|
|
264
|
+
log.warning(
|
|
265
|
+
f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
|
|
266
|
+
)
|
|
267
|
+
return
|
|
426
268
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
269
|
+
self.api.upload_folder(
|
|
270
|
+
folder_path=str(snapshot_dir),
|
|
271
|
+
repo_id=self.repo_id,
|
|
272
|
+
repo_type="model",
|
|
273
|
+
path_in_repo="code", # Prefix with "code" folder
|
|
274
|
+
run_as_future=cast(Any, self.config.save_in_background),
|
|
432
275
|
)
|
|
433
|
-
continue
|
|
434
276
|
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
277
|
+
def _save_file(self, p: _Upload):
|
|
278
|
+
with self._with_error_handling("save file"):
|
|
279
|
+
# Upload the checkpoint files to the repository
|
|
280
|
+
self.api.upload_file(
|
|
281
|
+
path_or_fileobj=p.local_path,
|
|
282
|
+
path_in_repo=str(p.path_in_repo),
|
|
283
|
+
repo_id=self.repo_id,
|
|
284
|
+
repo_type="model",
|
|
285
|
+
run_as_future=cast(Any, self.config.save_in_background),
|
|
444
286
|
)
|
|
445
|
-
continue
|
|
446
287
|
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
# Create and append a CommitOperationCopy for copying the symlink
|
|
452
|
-
copy_op = CommitOperationCopy(
|
|
453
|
-
src_path_in_repo=str(source_path_in_repo),
|
|
454
|
-
path_in_repo=str(dest_path_in_repo),
|
|
455
|
-
)
|
|
456
|
-
commits.append(copy_op)
|
|
457
|
-
|
|
458
|
-
log.info(f"Creating a commit with the following operations: {commits}")
|
|
288
|
+
def _copy_file(self, source_path_in_repo: Path, dest_path_in_repo: Path):
|
|
289
|
+
# Create a commit for copying the files
|
|
290
|
+
from huggingface_hub.hf_api import CommitOperationCopy
|
|
459
291
|
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
operations=commits,
|
|
466
|
-
run_as_future=cast(Any, config.save_in_background),
|
|
467
|
-
)
|
|
468
|
-
log.info(
|
|
469
|
-
f"Created commit to copy checkpoint symlinks to repository '{repo_name}'."
|
|
470
|
-
)
|
|
471
|
-
except Exception:
|
|
472
|
-
log.exception(
|
|
473
|
-
f"Failed to create commit to copy checkpoint symlinks to repository '{repo_name}'"
|
|
474
|
-
)
|
|
292
|
+
with self._with_error_handling("copy file"):
|
|
293
|
+
copy_op = CommitOperationCopy(
|
|
294
|
+
src_path_in_repo=str(source_path_in_repo),
|
|
295
|
+
path_in_repo=str(dest_path_in_repo),
|
|
296
|
+
)
|
|
475
297
|
|
|
476
|
-
|
|
298
|
+
self.api.create_commit(
|
|
299
|
+
repo_id=self.repo_id,
|
|
300
|
+
repo_type="model",
|
|
301
|
+
commit_message="Copy checkpoint file",
|
|
302
|
+
operations=[copy_op],
|
|
303
|
+
run_as_future=cast(Any, self.config.save_in_background),
|
|
304
|
+
)
|
|
477
305
|
|
|
306
|
+
def _save_checkpoint(self, path: _Upload, metadata_path: _Upload | None):
|
|
307
|
+
if not self.config.save_checkpoints:
|
|
308
|
+
return
|
|
478
309
|
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
310
|
+
# If no metadata, just save regularly.
|
|
311
|
+
if metadata_path is None:
|
|
312
|
+
self._save_file(path)
|
|
313
|
+
return
|
|
483
314
|
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
from .
|
|
315
|
+
# Otherwise, let's check to see if we've already uploaded the metadata.
|
|
316
|
+
# If so, we can just copy the checkpoint file.
|
|
317
|
+
from ._checkpoint.metadata import CheckpointMetadata
|
|
487
318
|
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
319
|
+
metadata = CheckpointMetadata.from_file(metadata_path.local_path)
|
|
320
|
+
if (
|
|
321
|
+
existing_ckpt_path := self._checksum_to_path_in_repo.get(
|
|
322
|
+
metadata.checkpoint_checksum
|
|
323
|
+
)
|
|
324
|
+
) is not None:
|
|
325
|
+
self._copy_file(existing_ckpt_path, path.path_in_repo)
|
|
326
|
+
else:
|
|
327
|
+
# Otherwise, we save the checkpoint & keep the checksum so we don't
|
|
328
|
+
# re-upload the same file again.
|
|
329
|
+
self._save_file(path)
|
|
330
|
+
self._checksum_to_path_in_repo[metadata.checkpoint_checksum] = (
|
|
331
|
+
path.path_in_repo
|
|
491
332
|
)
|
|
492
333
|
|
|
493
|
-
|
|
494
|
-
|
|
334
|
+
# Save the metadata file
|
|
335
|
+
# NOTE: This file is fairly small, so we can just upload it directly.
|
|
336
|
+
# No need to copy.
|
|
337
|
+
self._save_file(metadata_path)
|
|
495
338
|
|
|
496
339
|
@override
|
|
497
|
-
def
|
|
498
|
-
|
|
340
|
+
def state_dict(self):
|
|
341
|
+
return {
|
|
342
|
+
"repo_id": self._repo_id,
|
|
343
|
+
"checksum_to_path_in_repo": {
|
|
344
|
+
k: str(v) for k, v in self._checksum_to_path_in_repo.items()
|
|
345
|
+
},
|
|
346
|
+
}
|
|
499
347
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
]:
|
|
507
|
-
_save_checkpoint_files(
|
|
508
|
-
trainer, self, regular_paths, root_config=root_config
|
|
509
|
-
)
|
|
510
|
-
if symlink_paths := [p for p in all_paths if _is_link(p, trainer=trainer)]:
|
|
511
|
-
_save_checkpoint_symlinks(
|
|
512
|
-
trainer, self, symlink_paths, root_config=root_config
|
|
513
|
-
)
|
|
514
|
-
|
|
515
|
-
@cached_property
|
|
516
|
-
def _hf_hub_api(self):
|
|
517
|
-
# Create and authenticate the API instance
|
|
518
|
-
return _api(self.config.token)
|
|
348
|
+
@override
|
|
349
|
+
def load_state_dict(self, state_dict):
|
|
350
|
+
self._repo_id = state_dict["repo_id"]
|
|
351
|
+
self._checksum_to_path_in_repo = {
|
|
352
|
+
k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
|
|
353
|
+
}
|
|
@@ -70,5 +70,5 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
|
|
|
70
70
|
|
|
71
71
|
# Events
|
|
72
72
|
@override
|
|
73
|
-
def
|
|
73
|
+
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
|
|
74
74
|
self.save_checkpoints(trainer)
|
|
@@ -39,5 +39,5 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
|
|
|
39
39
|
return True
|
|
40
40
|
|
|
41
41
|
@override
|
|
42
|
-
def
|
|
42
|
+
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
|
|
43
43
|
self.save_checkpoints(trainer)
|
nshtrainer/trainer/trainer.py
CHANGED
|
@@ -280,13 +280,6 @@ class Trainer(LightningTrainer):
|
|
|
280
280
|
if TYPE_CHECKING:
|
|
281
281
|
callbacks: list[Callback]
|
|
282
282
|
|
|
283
|
-
def _nshtrainer_ckpt_link(self, ckpt_path: Path):
|
|
284
|
-
root_config = cast(BaseConfig, self._base_module.hparams)
|
|
285
|
-
ckpt_dir = root_config.directory.resolve_subdirectory(
|
|
286
|
-
root_config.id, "checkpoint"
|
|
287
|
-
)
|
|
288
|
-
return str(ckpt_path.absolute().relative_to(ckpt_dir))
|
|
289
|
-
|
|
290
283
|
@override
|
|
291
284
|
def __init__(
|
|
292
285
|
self,
|
|
@@ -295,7 +288,6 @@ class Trainer(LightningTrainer):
|
|
|
295
288
|
**kwargs: Unpack[LightningTrainerKwargs],
|
|
296
289
|
):
|
|
297
290
|
self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
|
|
298
|
-
self._nshtrainer_checkpoint_link_dict = dict[str, Path]()
|
|
299
291
|
|
|
300
292
|
self._pre_init(config)
|
|
301
293
|
|
|
@@ -454,9 +446,6 @@ class Trainer(LightningTrainer):
|
|
|
454
446
|
_link_checkpoint(cached_path, filepath, metadata=False)
|
|
455
447
|
else:
|
|
456
448
|
shutil.copy(cached_path, filepath)
|
|
457
|
-
self._nshtrainer_checkpoint_link_dict[
|
|
458
|
-
self._nshtrainer_ckpt_link(filepath)
|
|
459
|
-
] = cached_path
|
|
460
449
|
self.strategy.barrier("Trainer.save_checkpoint")
|
|
461
450
|
else:
|
|
462
451
|
super().save_checkpoint(filepath, weights_only, storage_options)
|
nshtrainer/util/path.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import hashlib
|
|
1
2
|
import os
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
from typing import TypeAlias
|
|
@@ -50,3 +51,20 @@ def find_symlinks(
|
|
|
50
51
|
pass
|
|
51
52
|
|
|
52
53
|
return symlinks
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def compute_file_checksum(file_path: Path) -> str:
|
|
57
|
+
"""
|
|
58
|
+
Calculate the SHA256 checksum of a file.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
file_path (Path): The path to the file.
|
|
62
|
+
|
|
63
|
+
Returns:
|
|
64
|
+
str: The hexadecimal representation of the file's SHA256 checksum.
|
|
65
|
+
"""
|
|
66
|
+
sha256_hash = hashlib.sha256()
|
|
67
|
+
with file_path.open("rb") as f:
|
|
68
|
+
for byte_block in iter(lambda: f.read(4096), b""):
|
|
69
|
+
sha256_hash.update(byte_block)
|
|
70
|
+
return sha256_hash.hexdigest()
|
|
@@ -1,18 +1,18 @@
|
|
|
1
1
|
nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
|
|
2
2
|
nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
|
|
3
3
|
nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
|
|
4
|
-
nshtrainer/_checkpoint/metadata.py,sha256=
|
|
4
|
+
nshtrainer/_checkpoint/metadata.py,sha256=BpxC3VGrgSDHvZpc40A2icjBEDvRNGEsTis9YkFY8Kc,5341
|
|
5
5
|
nshtrainer/_checkpoint/saver.py,sha256=fvRKGI5aeXtsHBOIO4cwGe__wmO-6DiD0-744VASYA4,1500
|
|
6
6
|
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
|
|
7
|
-
nshtrainer/_hf_hub.py,sha256=
|
|
7
|
+
nshtrainer/_hf_hub.py,sha256=0K3uWa8hd2KyGuUYM7OXARcA7vuUiWWGSlP2USysY7o,12066
|
|
8
8
|
nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
|
|
9
9
|
nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
|
|
10
10
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
11
11
|
nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
|
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
|
|
13
13
|
nshtrainer/callbacks/checkpoint/_base.py,sha256=pLhLj7tvfY5czGY_vT0xRfWHzGJYC4iOBRLokFVq0mE,6733
|
|
14
|
-
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=
|
|
15
|
-
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=
|
|
14
|
+
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
|
|
15
|
+
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
|
|
16
16
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
|
|
17
17
|
nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
|
|
18
18
|
nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
|
|
@@ -78,15 +78,15 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
|
|
|
78
78
|
nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
|
|
79
79
|
nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
|
|
80
80
|
nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
|
|
81
|
-
nshtrainer/trainer/trainer.py,sha256=
|
|
81
|
+
nshtrainer/trainer/trainer.py,sha256=Zwdcqfmrr7yuonsp4VrNOget8wkaZY9lf-_yeJ94lkk,19397
|
|
82
82
|
nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
|
|
83
83
|
nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
|
|
84
84
|
nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
|
|
85
|
-
nshtrainer/util/path.py,sha256=
|
|
85
|
+
nshtrainer/util/path.py,sha256=RUkIOrlj9b8zPPXE3JLhdihBNitJSUWw1whZ33u-2Yk,2005
|
|
86
86
|
nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
87
87
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
88
88
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
89
89
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
90
|
-
nshtrainer-0.
|
|
91
|
-
nshtrainer-0.
|
|
92
|
-
nshtrainer-0.
|
|
90
|
+
nshtrainer-0.25.0.dist-info/METADATA,sha256=Rqdeh2yp2AhZ_nOHlD47v5YPDrLc2fHN6WGwqJnDv04,935
|
|
91
|
+
nshtrainer-0.25.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
92
|
+
nshtrainer-0.25.0.dist-info/RECORD,,
|
|
File without changes
|