nshtrainer 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,7 +11,7 @@ import numpy as np
11
11
  import torch
12
12
 
13
13
  from ..util._environment_info import EnvironmentConfig
14
- from ..util.path import get_relative_path
14
+ from ..util.path import compute_file_checksum, get_relative_path
15
15
 
16
16
  if TYPE_CHECKING:
17
17
  from ..model import BaseConfig, LightningModuleBase
@@ -28,6 +28,7 @@ class CheckpointMetadata(C.Config):
28
28
 
29
29
  checkpoint_path: Path
30
30
  checkpoint_filename: str
31
+ checkpoint_checksum: str
31
32
 
32
33
  run_id: str
33
34
  name: str
@@ -81,6 +82,7 @@ def _generate_checkpoint_metadata(
81
82
  # moving the checkpoint directory
82
83
  checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
83
84
  checkpoint_filename=checkpoint_path.name,
85
+ checkpoint_checksum=compute_file_checksum(checkpoint_path),
84
86
  run_id=config.id,
85
87
  name=config.run_name,
86
88
  project=config.project,
nshtrainer/_hf_hub.py CHANGED
@@ -1,7 +1,9 @@
1
- import io
1
+ import contextlib
2
2
  import logging
3
3
  import os
4
4
  import re
5
+ from dataclasses import dataclass
6
+ from functools import cached_property
5
7
  from pathlib import Path
6
8
  from typing import TYPE_CHECKING, Any, cast
7
9
 
@@ -16,7 +18,6 @@ if TYPE_CHECKING:
16
18
  from huggingface_hub import HfApi # noqa: F401
17
19
 
18
20
  from .model.base import BaseConfig
19
- from .trainer.trainer import Trainer
20
21
 
21
22
 
22
23
  log = logging.getLogger(__name__)
@@ -108,36 +109,6 @@ def _api(token: str | None = None):
108
109
  return api
109
110
 
110
111
 
111
- def _enabled_and_valid(
112
- trainer: "Trainer",
113
- config: HuggingFaceHubConfig,
114
- *,
115
- rank_zero_only: bool,
116
- ):
117
- # Make sure this is enabled and the config is valid
118
- if not config:
119
- return None
120
-
121
- # If `rank_zero_only` and this is not rank 0, stop here.
122
- if rank_zero_only and not trainer.is_global_zero:
123
- return
124
-
125
- # Make sure that `huggingface_hub` is installed
126
- try:
127
- import huggingface_hub # noqa: F401
128
- except ImportError:
129
- log.exception(
130
- "Could not import `huggingface_hub`. Please install it using `pip install huggingface_hub`."
131
- )
132
- return None
133
-
134
- # Create and authenticate the API instance
135
- if (api := getattr(trainer, "_hf_hub_api", None)) is None:
136
- api = _api(config.token)
137
- setattr(trainer, "_hf_hub_api", api)
138
- return cast(huggingface_hub.HfApi, api)
139
-
140
-
141
112
  def _repo_name(api: "HfApi", root_config: "BaseConfig"):
142
113
  username = None
143
114
  if (ac := root_config.trainer.hf_hub.auto_create) and ac.namespace:
@@ -173,346 +144,210 @@ def _repo_name(api: "HfApi", root_config: "BaseConfig"):
173
144
  return f"{username}/{repo_name}"
174
145
 
175
146
 
176
- def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
177
- config = root_config.trainer.hf_hub
178
- if (
179
- api := _enabled_and_valid(
180
- trainer,
181
- config,
182
- rank_zero_only=True,
183
- )
184
- ) is None or not config.auto_create:
185
- return
186
-
187
- from huggingface_hub.utils import RepositoryNotFoundError
188
-
189
- # Resolve the repository name
190
- repo_name = _repo_name(api, root_config)
191
-
192
- # Create the repository, if it doesn't exist
193
- try:
194
- # Check if the repository exists
195
- api.repo_info(repo_id=repo_name, repo_type="model")
196
- log.info(f"Repository '{repo_name}' already exists.")
197
- except RepositoryNotFoundError:
198
- # Repository doesn't exist, so create it
199
- try:
200
- api.create_repo(
201
- repo_id=repo_name,
202
- repo_type="model",
203
- private=config.auto_create.private,
204
- exist_ok=True,
205
- )
206
- log.info(f"Created new repository '{repo_name}'.")
207
- except Exception:
208
- log.exception(f"Failed to create repository '{repo_name}'")
209
- except Exception:
210
- log.exception(f"Error checking repository '{repo_name}'")
211
-
212
- # Upload the config
213
- _save_config(root_config, trainer=trainer)
214
-
215
- # Upload the code
216
- _save_code(repo_name, config=config, trainer=trainer)
217
-
218
-
219
- def _save_code(
220
- repo_name: str,
221
- *,
222
- config: HuggingFaceHubConfig,
223
- trainer: "Trainer",
224
- ):
225
- if (
226
- api := _enabled_and_valid(
227
- trainer,
228
- config,
229
- rank_zero_only=True,
230
- )
231
- ) is None or not config.save_code:
232
- return
233
-
234
- # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
235
- # then upload all contents within the snapshot directory to the repository.
236
- snapshot_dir = os.environ.get(SNAPSHOT_DIR)
237
- if not snapshot_dir:
238
- log.info("No snapshot directory found. Skipping upload.")
239
- return
240
-
241
- snapshot_path = Path(snapshot_dir)
242
- if not snapshot_path.exists() or not snapshot_path.is_dir():
243
- log.warning(
244
- f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
245
- )
246
- return
247
-
248
- try:
249
- api.upload_folder(
250
- folder_path=str(snapshot_path),
251
- repo_id=repo_name,
252
- repo_type="model",
253
- path_in_repo="code", # Prefix with "code" folder
254
- run_as_future=cast(Any, config.save_in_background),
255
- )
256
- log.info(
257
- f"Uploaded snapshot contents to repository '{repo_name}' under 'code' folder."
258
- )
259
- except Exception:
260
- log.exception(
261
- f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder."
262
- )
263
-
264
-
265
- def _save_config(
266
- root_config: "BaseConfig",
267
- *,
268
- trainer: "Trainer",
269
- ):
270
- config = root_config.trainer.hf_hub
271
- if (
272
- api := _enabled_and_valid(
273
- trainer,
274
- config,
275
- rank_zero_only=True,
147
+ @dataclass
148
+ class _Upload:
149
+ local_path: Path
150
+ path_in_repo: Path
151
+
152
+ @classmethod
153
+ def from_local_path(
154
+ cls,
155
+ local_path: Path,
156
+ root_config: "BaseConfig",
157
+ ):
158
+ # Resolve the checkpoint directory
159
+ checkpoint_dir = root_config.directory.resolve_subdirectory(
160
+ root_config.id, "checkpoint"
276
161
  )
277
- ) is None or not config.save_config:
278
- return
279
-
280
- # Convert the root config to a JSON string
281
- # NOTE: This is a utf-8 string.
282
- config_json = root_config.model_dump_json(indent=4)
283
162
 
284
- # Resolve the repository name
285
- repo_name = _repo_name(api, root_config)
286
-
287
- # Upload the config file to the repository
288
- try:
289
- api.upload_file(
290
- path_or_fileobj=config_json.encode("utf-8"),
291
- path_in_repo="config.json",
292
- repo_id=repo_name,
293
- repo_type="model",
294
- run_as_future=cast(Any, config.save_in_background),
295
- )
296
- log.info(f"Uploaded config.json to repository '{repo_name}'.")
297
- except Exception:
298
- log.exception(f"Failed to upload config.json to repository '{repo_name}'.")
299
-
300
-
301
- def _save_checkpoint_files(
302
- trainer: "Trainer",
303
- paths: list[Path],
304
- *,
305
- root_config: "BaseConfig",
306
- ):
307
- config = root_config.trainer.hf_hub
308
- if (
309
- api := _enabled_and_valid(trainer, config, rank_zero_only=True)
310
- ) is None or not config.save_checkpoints:
311
- return
312
-
313
- # Resolve the checkpoint directory
314
- checkpoint_dir = root_config.directory.resolve_subdirectory(
315
- root_config.id, "checkpoint"
316
- )
317
-
318
- # Resolve the repository name
319
- repo_name = _repo_name(api, root_config)
320
-
321
- # Let's read all the files to memory right now,
322
- # in case they get used/removed by other processes.
323
- # Read all the files to memory
324
- file_contents: list[bytes | None] = []
325
- for p in paths:
326
- assert not p.is_symlink(), f"Path {p} is a symlink."
327
- assert p.is_file(), f"Path {p} is not a file."
328
- try:
329
- with open(p, "rb") as f:
330
- file_contents.append(f.read())
331
- except IOError:
332
- log.exception(f"Failed to read checkpoint file {p}.")
333
- file_contents.append(None)
334
-
335
- # Remove the paths that failed to read
336
- file_contents_and_paths = [
337
- (contents, p)
338
- for contents, p in zip(file_contents, paths)
339
- if contents is not None
340
- ]
341
-
342
- # Upload the checkpoint files to the repository
343
- for contents, p in file_contents_and_paths:
344
163
  try:
345
- relative_path = p.relative_to(checkpoint_dir)
164
+ relative_path = local_path.relative_to(checkpoint_dir)
346
165
  except ValueError:
347
- log.warning(
348
- f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
166
+ raise ValueError(
167
+ f"Checkpoint path {local_path} is not within the checkpoint directory {checkpoint_dir}."
349
168
  )
350
- continue
351
169
 
352
170
  # Prefix the path in repo with "checkpoints"
353
171
  path_in_repo = Path("checkpoints") / relative_path
354
172
 
355
- # Upload the checkpoint file to the repository
356
- try:
357
- api.upload_file(
358
- path_or_fileobj=io.BytesIO(contents),
359
- path_in_repo=str(path_in_repo),
360
- repo_id=repo_name,
361
- repo_type="model",
362
- run_as_future=cast(Any, config.save_in_background),
363
- )
364
- log.info(
365
- f"Uploaded checkpoint file {relative_path} to repository '{repo_name}'."
366
- )
367
- except Exception:
368
- log.exception(
369
- f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}'."
370
- )
371
-
372
- log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
373
-
173
+ return cls(local_path=local_path, path_in_repo=path_in_repo)
374
174
 
375
- def _save_checkpoint_symlinks(
376
- trainer: "Trainer",
377
- paths: list[Path],
378
- *,
379
- root_config: "BaseConfig",
380
- ):
381
- config = root_config.trainer.hf_hub
382
- if (
383
- api := _enabled_and_valid(trainer, config, rank_zero_only=True)
384
- ) is None or not config.save_checkpoints:
385
- return
386
175
 
387
- # Resolve the checkpoint directory
388
- checkpoint_dir = root_config.directory.resolve_subdirectory(
389
- root_config.id, "checkpoint"
390
- )
176
+ class HFHubCallback(NTCallbackBase):
177
+ @contextlib.contextmanager
178
+ def _with_error_handling(self, opeartion: str):
179
+ try:
180
+ yield
181
+ except Exception:
182
+ log.exception(f"Failed to {opeartion}, repo_id={self._repo_id}")
183
+ else:
184
+ log.debug(f"Successfully {opeartion}, repo_id={self._repo_id}")
391
185
 
392
- # Resolve the repository name
393
- repo_name = _repo_name(api, root_config)
186
+ def __init__(self, config: HuggingFaceHubConfig):
187
+ super().__init__()
394
188
 
395
- # Create a commit for copying the files
396
- from huggingface_hub.hf_api import CommitOperation, CommitOperationCopy
189
+ self.config = config
397
190
 
398
- commits: list[CommitOperation] = []
399
- for p in paths:
400
- assert p.is_symlink(), f"Path {p} is not a symlink."
191
+ self._repo_id = None
192
+ self._checksum_to_path_in_repo: dict[str, Path] = {}
401
193
 
402
- try:
403
- dest_relative_path = p.relative_to(checkpoint_dir)
404
- except ValueError:
405
- log.warning(
406
- f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
407
- )
408
- continue
194
+ @override
195
+ def setup(self, trainer, pl_module, stage):
196
+ from .trainer.trainer import Trainer
409
197
 
410
- try:
411
- source_relative_path = p.resolve().relative_to(checkpoint_dir)
412
- except ValueError:
413
- log.warning(
414
- f"Checkpoint symlink target {p.resolve()} is not within the checkpoint directory {checkpoint_dir}."
198
+ if not isinstance(trainer, Trainer):
199
+ raise ValueError(
200
+ f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
415
201
  )
416
- continue
417
202
 
418
- # Prefix the path in repo with "checkpoints"
419
- dest_path_in_repo = Path("checkpoints") / dest_relative_path
420
- source_path_in_repo = Path("checkpoints") / source_relative_path
203
+ root_config = cast("BaseConfig", pl_module.hparams)
421
204
 
422
- # Create and append a CommitOperationCopy for copying the symlink
423
- copy_op = CommitOperationCopy(
424
- src_path_in_repo=str(source_path_in_repo),
425
- path_in_repo=str(dest_path_in_repo),
205
+ # Create the repository, if it doesn't exist
206
+ self._repo_id = self.api.create_repo(
207
+ repo_id=_repo_name(self.api, root_config),
208
+ repo_type="model",
209
+ private=self.config.auto_create.private,
210
+ exist_ok=True,
426
211
  )
427
- commits.append(copy_op)
428
212
 
429
- log.info(f"Creating a commit with the following operations: {commits}")
213
+ # Upload the config and code
214
+ self._save_config(root_config)
215
+ self._save_code()
430
216
 
431
- try:
432
- api.create_commit(
433
- repo_id=repo_name,
434
- repo_type="model",
435
- commit_message="Copy checkpoint symlinks",
436
- operations=commits,
437
- run_as_future=cast(Any, config.save_in_background),
438
- )
439
- log.info(
440
- f"Created commit to copy checkpoint symlinks to repository '{repo_name}'."
441
- )
442
- except Exception:
443
- log.exception(
444
- f"Failed to create commit to copy checkpoint symlinks to repository '{repo_name}'"
445
- )
217
+ @override
218
+ def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
219
+ root_config = cast("BaseConfig", pl_module.hparams)
446
220
 
447
- log.info(f"Completed copying checkpoint symlinks to repository '{repo_name}'.")
221
+ # If HF Hub is enabled, then we upload
222
+ if self.config and trainer.is_global_zero:
223
+ with self._with_error_handling("save checkpoints"):
224
+ self._save_checkpoint(
225
+ _Upload.from_local_path(ckpt_path, root_config),
226
+ _Upload.from_local_path(metadata_path, root_config)
227
+ if metadata_path is not None
228
+ else None,
229
+ )
448
230
 
231
+ @cached_property
232
+ def api(self):
233
+ # Create and authenticate the API instance
234
+ if (api := _api(self.config.token)) is None:
235
+ raise ValueError("Failed to create Hugging Face Hub API instance.")
236
+ return api
237
+
238
+ @property
239
+ def repo_id(self):
240
+ if self._repo_id is None:
241
+ raise ValueError("Repository id has not been initialized.")
242
+ return self._repo_id
243
+
244
+ def _save_config(self, root_config: "BaseConfig"):
245
+ with self._with_error_handling("upload config"):
246
+ self.api.upload_file(
247
+ path_or_fileobj=root_config.model_dump_json(indent=4).encode("utf-8"),
248
+ path_in_repo="config.json",
249
+ repo_id=self.repo_id,
250
+ repo_type="model",
251
+ run_as_future=cast(Any, self.config.save_in_background),
252
+ )
449
253
 
450
- def _save_checkpoint_directory(trainer: "Trainer", *, root_config: "BaseConfig"):
451
- config = root_config.trainer.hf_hub
452
- if (
453
- api := _enabled_and_valid(trainer, config, rank_zero_only=True)
454
- ) is None or not config.save_checkpoints:
455
- return
254
+ def _save_code(self):
255
+ # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
256
+ # then upload all contents within the snapshot directory to the repository.
257
+ if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
258
+ log.debug("No snapshot directory found. Skipping upload.")
259
+ return
260
+
261
+ with self._with_error_handling("save code"):
262
+ snapshot_dir = Path(snapshot_dir)
263
+ if not snapshot_dir.exists() or not snapshot_dir.is_dir():
264
+ log.warning(
265
+ f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
266
+ )
267
+ return
456
268
 
457
- # Resolve the checkpoint directory
458
- checkpoint_dir = root_config.directory.resolve_subdirectory(
459
- root_config.id, "checkpoint"
460
- )
269
+ self.api.upload_folder(
270
+ folder_path=str(snapshot_dir),
271
+ repo_id=self.repo_id,
272
+ repo_type="model",
273
+ path_in_repo="code", # Prefix with "code" folder
274
+ run_as_future=cast(Any, self.config.save_in_background),
275
+ )
461
276
 
462
- # Resolve the repository name
463
- repo_name = _repo_name(api, root_config)
277
+ def _save_file(self, p: _Upload):
278
+ with self._with_error_handling("save file"):
279
+ # Upload the checkpoint files to the repository
280
+ self.api.upload_file(
281
+ path_or_fileobj=p.local_path,
282
+ path_in_repo=str(p.path_in_repo),
283
+ repo_id=self.repo_id,
284
+ repo_type="model",
285
+ run_as_future=cast(Any, self.config.save_in_background),
286
+ )
464
287
 
465
- # Upload the checkpoint directory to the repository
466
- try:
467
- api.upload_folder(
468
- folder_path=str(checkpoint_dir),
469
- repo_id=repo_name,
470
- repo_type="model",
471
- path_in_repo="checkpoints",
472
- run_as_future=cast(Any, config.save_in_background),
473
- )
474
- log.info(f"Uploaded checkpoint directory to repository '{repo_name}'.")
475
- except Exception:
476
- log.exception(
477
- f"Failed to upload checkpoint directory to repository '{repo_name}'."
478
- )
288
+ def _copy_file(self, source_path_in_repo: Path, dest_path_in_repo: Path):
289
+ # Create a commit for copying the files
290
+ from huggingface_hub.hf_api import CommitOperationCopy
479
291
 
480
- log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
292
+ with self._with_error_handling("copy file"):
293
+ copy_op = CommitOperationCopy(
294
+ src_path_in_repo=str(source_path_in_repo),
295
+ path_in_repo=str(dest_path_in_repo),
296
+ )
481
297
 
298
+ self.api.create_commit(
299
+ repo_id=self.repo_id,
300
+ repo_type="model",
301
+ commit_message="Copy checkpoint file",
302
+ operations=[copy_op],
303
+ run_as_future=cast(Any, self.config.save_in_background),
304
+ )
482
305
 
483
- class HFHubCallback(NTCallbackBase):
484
- def __init__(self, config: HuggingFaceHubConfig):
485
- super().__init__()
486
- self.config = config
306
+ def _save_checkpoint(self, path: _Upload, metadata_path: _Upload | None):
307
+ if not self.config.save_checkpoints:
308
+ return
487
309
 
488
- @override
489
- def setup(self, trainer, pl_module, stage):
490
- from .trainer.trainer import Trainer
310
+ # If no metadata, just save regularly.
311
+ if metadata_path is None:
312
+ self._save_file(path)
313
+ return
491
314
 
492
- if not isinstance(trainer, Trainer):
493
- raise ValueError(
494
- f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
315
+ # Otherwise, let's check to see if we've already uploaded the metadata.
316
+ # If so, we can just copy the checkpoint file.
317
+ from ._checkpoint.metadata import CheckpointMetadata
318
+
319
+ metadata = CheckpointMetadata.from_file(metadata_path.local_path)
320
+ if (
321
+ existing_ckpt_path := self._checksum_to_path_in_repo.get(
322
+ metadata.checkpoint_checksum
323
+ )
324
+ ) is not None:
325
+ self._copy_file(existing_ckpt_path, path.path_in_repo)
326
+ else:
327
+ # Otherwise, we save the checkpoint & keep the checksum so we don't
328
+ # re-upload the same file again.
329
+ self._save_file(path)
330
+ self._checksum_to_path_in_repo[metadata.checkpoint_checksum] = (
331
+ path.path_in_repo
495
332
  )
496
333
 
497
- root_config = cast("BaseConfig", pl_module.hparams)
498
- _init(trainer=trainer, root_config=root_config)
334
+ # Save the metadata file
335
+ # NOTE: This file is fairly small, so we can just upload it directly.
336
+ # No need to copy.
337
+ self._save_file(metadata_path)
499
338
 
500
339
  @override
501
- def teardown(self, trainer, pl_module, stage):
502
- if hasattr(trainer, "_hf_hub_api"):
503
- delattr(trainer, "_hf_hub_api")
340
+ def state_dict(self):
341
+ return {
342
+ "repo_id": self._repo_id,
343
+ "checksum_to_path_in_repo": {
344
+ k: str(v) for k, v in self._checksum_to_path_in_repo.items()
345
+ },
346
+ }
504
347
 
505
348
  @override
506
- def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
507
- root_config = cast("BaseConfig", pl_module.hparams)
508
-
509
- # If HF Hub is enabled, then we upload
510
- if root_config.trainer.hf_hub and trainer.is_global_zero:
511
- # Upload the regular files first, then the symlinks
512
- all_paths = [p for p in (ckpt_path, metadata_path) if p is not None]
513
- if regular_paths := [p for p in all_paths if not p.is_symlink()]:
514
- _save_checkpoint_files(trainer, regular_paths, root_config=root_config)
515
- if symlink_paths := [p for p in all_paths if p.is_symlink()]:
516
- _save_checkpoint_symlinks(
517
- trainer, symlink_paths, root_config=root_config
518
- )
349
+ def load_state_dict(self, state_dict):
350
+ self._repo_id = state_dict["repo_id"]
351
+ self._checksum_to_path_in_repo = {
352
+ k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
353
+ }
@@ -11,7 +11,6 @@ from typing_extensions import TypeVar, override
11
11
 
12
12
  from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
13
13
  from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
14
- from ...util.path import find_symlinks
15
14
  from ..base import CallbackConfigBase
16
15
 
17
16
  if TYPE_CHECKING:
@@ -117,47 +116,9 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
117
116
  )
118
117
  continue
119
118
 
120
- self._remove_checkpoint_with_link_support(
121
- trainer, old_ckpt_path, metadata=True
122
- )
119
+ _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
123
120
  log.debug(f"Removed old checkpoint: {old_ckpt_path}")
124
121
 
125
- def _remove_checkpoint_with_link_support(
126
- self,
127
- trainer: Trainer,
128
- path: Path,
129
- metadata: bool,
130
- ):
131
- # Find all the symlinks to the checkpoint
132
- ckpt_callbacks: list[CheckpointBase] = [
133
- callback
134
- for callback in trainer.checkpoint_callbacks
135
- if isinstance(callback, CheckpointBase) and callback is not self
136
- ]
137
- symlink_paths = find_symlinks(
138
- path,
139
- *[callback.dirpath for callback in ckpt_callbacks],
140
- glob_pattern=f"*.{self.extension()}",
141
- )
142
-
143
- # If there are no symlinks, just remove the checkpoint
144
- if not symlink_paths:
145
- _remove_checkpoint(trainer, path, metadata=metadata)
146
- return
147
-
148
- log.debug(
149
- f"Removing checkpoint with symlinks: {path}, symlinks: {symlink_paths}"
150
- )
151
-
152
- # For the first symlink, we can just move the checkpoint file
153
- # to the symlink path. For the rest, we need to make new symlinks.
154
- new_target = symlink_paths.pop(0)
155
- path.rename(new_target)
156
- log.debug(f"New symlink target: {new_target}")
157
-
158
- for symlink_path in symlink_paths:
159
- _link_checkpoint(new_target, symlink_path, metadata=False)
160
-
161
122
  def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
162
123
  current_metrics: dict[str, Any] = {
163
124
  "epoch": trainer.current_epoch,
@@ -195,6 +156,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
195
156
  filepath,
196
157
  self.config.save_weights_only,
197
158
  use_checkpoint_cache=None,
159
+ ckpt_cache_use_symlink=False,
198
160
  )
199
161
 
200
162
  if trainer.is_global_zero:
@@ -70,5 +70,5 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
70
70
 
71
71
  # Events
72
72
  @override
73
- def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
73
+ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
74
74
  self.save_checkpoints(trainer)
@@ -39,5 +39,5 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
39
39
  return True
40
40
 
41
41
  @override
42
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
42
+ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
43
43
  self.save_checkpoints(trainer)
@@ -1,5 +1,6 @@
1
1
  import logging
2
2
  import os
3
+ import shutil
3
4
  from collections.abc import Sequence
4
5
  from pathlib import Path
5
6
  from typing import TYPE_CHECKING, Any, cast
@@ -416,11 +417,12 @@ class Trainer(LightningTrainer):
416
417
  weights_only: bool = False,
417
418
  storage_options: Any | None = None,
418
419
  use_checkpoint_cache: bool | None = None,
420
+ ckpt_cache_use_symlink: bool = False,
419
421
  ):
420
422
  lm = self._base_module
421
- hparams = cast(BaseConfig, lm.hparams)
423
+ root_config = cast(BaseConfig, lm.hparams)
422
424
  if use_checkpoint_cache is None:
423
- use_checkpoint_cache = hparams.trainer.use_checkpoint_cache
425
+ use_checkpoint_cache = root_config.trainer.use_checkpoint_cache
424
426
 
425
427
  filepath = Path(filepath)
426
428
 
@@ -440,7 +442,10 @@ class Trainer(LightningTrainer):
440
442
  # If we have a cached path, then we symlink it to the new path.
441
443
  log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
442
444
  if self.is_global_zero:
443
- _link_checkpoint(cached_path, filepath, metadata=False)
445
+ if ckpt_cache_use_symlink:
446
+ _link_checkpoint(cached_path, filepath, metadata=False)
447
+ else:
448
+ shutil.copy(cached_path, filepath)
444
449
  self.strategy.barrier("Trainer.save_checkpoint")
445
450
  else:
446
451
  super().save_checkpoint(filepath, weights_only, storage_options)
@@ -454,7 +459,7 @@ class Trainer(LightningTrainer):
454
459
 
455
460
  # Save the checkpoint metadata
456
461
  metadata_path = None
457
- if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
462
+ if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
458
463
  # Generate the metadata and write to disk
459
464
  if (
460
465
  metadata_path := _write_checkpoint_metadata(self, lm, filepath)
nshtrainer/util/path.py CHANGED
@@ -1,3 +1,4 @@
1
+ import hashlib
1
2
  import os
2
3
  from pathlib import Path
3
4
  from typing import TypeAlias
@@ -50,3 +51,20 @@ def find_symlinks(
50
51
  pass
51
52
 
52
53
  return symlinks
54
+
55
+
56
+ def compute_file_checksum(file_path: Path) -> str:
57
+ """
58
+ Calculate the SHA256 checksum of a file.
59
+
60
+ Args:
61
+ file_path (Path): The path to the file.
62
+
63
+ Returns:
64
+ str: The hexadecimal representation of the file's SHA256 checksum.
65
+ """
66
+ sha256_hash = hashlib.sha256()
67
+ with file_path.open("rb") as f:
68
+ for byte_block in iter(lambda: f.read(4096), b""):
69
+ sha256_hash.update(byte_block)
70
+ return sha256_hash.hexdigest()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.23.0
3
+ Version: 0.25.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,18 +1,18 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
2
  nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
3
3
  nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
4
- nshtrainer/_checkpoint/metadata.py,sha256=E4tfiGzhnn65X95P0Y6K2d_YfPWqvHZoF0FF1-smEJc,5221
4
+ nshtrainer/_checkpoint/metadata.py,sha256=BpxC3VGrgSDHvZpc40A2icjBEDvRNGEsTis9YkFY8Kc,5341
5
5
  nshtrainer/_checkpoint/saver.py,sha256=fvRKGI5aeXtsHBOIO4cwGe__wmO-6DiD0-744VASYA4,1500
6
6
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
7
- nshtrainer/_hf_hub.py,sha256=iqhXH54RhSqmot_K3UCVcHTC_TC81_YY7cwvHGHXXlw,16782
7
+ nshtrainer/_hf_hub.py,sha256=0K3uWa8hd2KyGuUYM7OXARcA7vuUiWWGSlP2USysY7o,12066
8
8
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
9
9
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
11
11
  nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
13
- nshtrainer/callbacks/checkpoint/_base.py,sha256=zKN-6n61aGze-Hf8MBY1Surh6B-xDwNSApqQJtPcTUs,8048
14
- nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
15
- nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
13
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=pLhLj7tvfY5czGY_vT0xRfWHzGJYC4iOBRLokFVq0mE,6733
14
+ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
15
+ nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
17
17
  nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
18
18
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
@@ -78,15 +78,15 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
78
78
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
79
79
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
80
80
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
81
- nshtrainer/trainer/trainer.py,sha256=Leh3ADxoYsRWlJFIW20netohLcKx0XxUrRhD9LM4jws,19201
81
+ nshtrainer/trainer/trainer.py,sha256=Zwdcqfmrr7yuonsp4VrNOget8wkaZY9lf-_yeJ94lkk,19397
82
82
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
83
83
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
84
84
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
85
- nshtrainer/util/path.py,sha256=WbPWXpu5LIDocQihQC3-72qxN1sa6-d1kPOmKDR-NC8,1520
85
+ nshtrainer/util/path.py,sha256=RUkIOrlj9b8zPPXE3JLhdihBNitJSUWw1whZ33u-2Yk,2005
86
86
  nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.23.0.dist-info/METADATA,sha256=wkbqsz6A4d0h1u-8CCZwfYYmqLm7YjirdnS-fTA-mkI,935
91
- nshtrainer-0.23.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.23.0.dist-info/RECORD,,
90
+ nshtrainer-0.25.0.dist-info/METADATA,sha256=Rqdeh2yp2AhZ_nOHlD47v5YPDrLc2fHN6WGwqJnDv04,935
91
+ nshtrainer-0.25.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.25.0.dist-info/RECORD,,