nshtrainer 0.20.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_callback.py ADDED
@@ -0,0 +1,40 @@
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from lightning.pytorch.callbacks import Callback as _LightningCallback
+
+if TYPE_CHECKING:
+    from .model import LightningModuleBase
+    from .trainer import Trainer
+
+
+class NTCallbackBase(_LightningCallback):
+    def on_checkpoint_saved(
+        self,
+        ckpt_path: Path,
+        metadata_path: Path | None,
+        trainer: "Trainer",
+        pl_module: "LightningModuleBase",
+    ) -> None:
+        """Called after a checkpoint is saved."""
+        pass
+
+
+def _call_on_checkpoint_saved(
+    trainer: "Trainer",
+    ckpt_path: str | Path,
+    metadata_path: str | Path | None,
+):
+    ckpt_path = Path(ckpt_path)
+    metadata_path = Path(metadata_path) if metadata_path else None
+
+    for callback in trainer.callbacks:
+        if not isinstance(callback, NTCallbackBase):
+            continue
+
+        callback.on_checkpoint_saved(
+            ckpt_path,
+            metadata_path,
+            trainer,
+            trainer._base_module,
+        )
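
For orientation, a minimal user-side sketch (not part of this diff) of how downstream code could hook into the new `on_checkpoint_saved` hook; the subclass name is made up, and registration is assumed to follow the usual Lightning `callbacks=[...]` pattern:

from pathlib import Path

from nshtrainer._callback import NTCallbackBase


class CheckpointLogger(NTCallbackBase):
    def on_checkpoint_saved(self, ckpt_path: Path, metadata_path: Path | None, trainer, pl_module) -> None:
        # Fired from Trainer._nshtrainer_save_checkpoint (and, per the checkpoint
        # callback changes below, once more for the "latest"/"best" symlink aliases).
        print(f"saved {ckpt_path} (metadata: {metadata_path})")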

nshtrainer/_checkpoint/metadata.py CHANGED
@@ -96,13 +96,17 @@ def _generate_checkpoint_metadata(
     )
 
 
+def _metadata_path(checkpoint_path: Path):
+    return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
+
+
 def _write_checkpoint_metadata(
     trainer: "Trainer",
     model: "LightningModuleBase",
     checkpoint_path: Path,
 ):
     config = cast("BaseConfig", model.config)
-    metadata_path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
+    metadata_path = _metadata_path(checkpoint_path)
     metadata = _generate_checkpoint_metadata(
         config, trainer, checkpoint_path, metadata_path
     )
@@ -119,7 +123,7 @@ def _write_checkpoint_metadata(
 
 
 def _remove_checkpoint_metadata(checkpoint_path: Path):
-    path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
+    path = _metadata_path(checkpoint_path)
     try:
         path.unlink(missing_ok=True)
     except Exception:
@@ -133,8 +137,8 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
     _remove_checkpoint_metadata(linked_checkpoint_path)
 
     # Link the metadata files to the new checkpoint
-    path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
-    linked_path = linked_checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
+    path = _metadata_path(checkpoint_path)
+    linked_path = _metadata_path(linked_checkpoint_path)
     try:
         try:
             # linked_path.symlink_to(path)

nshtrainer/_checkpoint/saver.py CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
 
 from lightning.pytorch import Trainer
 
+from ..util.path import get_relative_path
 from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
 
 
@@ -14,10 +15,8 @@ def _link_checkpoint(
     metadata: bool,
     remove_existing: bool = True,
 ):
-    if not isinstance(filepath, Path):
-        filepath = Path(filepath)
-    if not isinstance(linkpath, Path):
-        linkpath = Path(linkpath)
+    filepath = Path(filepath)
+    linkpath = Path(linkpath)
 
     if remove_existing:
         if linkpath.exists():
@@ -30,7 +29,7 @@ def _link_checkpoint(
             _remove_checkpoint_metadata(linkpath)
 
     try:
-        linkpath.symlink_to(filepath.relative_to(linkpath.parent))
+        linkpath.symlink_to(get_relative_path(linkpath, filepath))
     except OSError:
         # on Windows, special permissions are required to create symbolic links as a regular user
        # fall back to copying the file
@@ -46,9 +45,9 @@ def _remove_checkpoint(
     *,
     metadata: bool,
 ):
-    if not isinstance(filepath, Path):
-        filepath = Path(filepath)
+    filepath = Path(filepath)
 
     trainer.strategy.remove_checkpoint(filepath)
+
     if metadata:
         _remove_checkpoint_metadata(filepath)
nshtrainer/_hf_hub.py CHANGED
@@ -6,11 +6,10 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
 
 import nshconfig as C
-from lightning.pytorch import Callback
-from lightning.pytorch.trainer import Trainer
 from nshrunner._env import SNAPSHOT_DIR
 from typing_extensions import override
 
+from ._callback import NTCallbackBase
 from .callbacks.base import (
     CallbackConfigBase,
     CallbackMetadataConfig,
@@ -22,6 +21,8 @@ if TYPE_CHECKING:
 
     from .model.base import BaseConfig
     from .trainer.trainer import Trainer
+
+
 log = logging.getLogger(__name__)
 
 
@@ -102,9 +103,9 @@ def _api(token: str | None = None):
 
         # Verify authentication
         api.whoami()
-    except Exception as e:
+    except Exception:
         log.exception(
-            f"Authentication failed for Hugging Face Hub: {str(e)}. "
+            "Authentication failed for Hugging Face Hub. "
             "Please make sure you are logged in using `huggingface-cli login`, "
             "by setting the HUGGING_FACE_HUB_TOKEN environment variable, "
             "or by providing a valid token in the configuration."
@@ -210,10 +211,10 @@ def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
                 exist_ok=True,
             )
             log.info(f"Created new repository '{repo_name}'.")
-        except Exception as e:
-            log.exception(f"Failed to create repository '{repo_name}': {str(e)}")
-    except Exception as e:
-        log.exception(f"Error checking repository '{repo_name}': {str(e)}")
+        except Exception:
+            log.exception(f"Failed to create repository '{repo_name}'")
+    except Exception:
+        log.exception(f"Error checking repository '{repo_name}'")
 
     # Upload the config
     _save_config(root_config, trainer=trainer)
@@ -262,9 +263,9 @@ def _save_code(
         log.info(
             f"Uploaded snapshot contents to repository '{repo_name}' under 'code' folder."
         )
-    except Exception as e:
+    except Exception:
         log.exception(
-            f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder: {str(e)}"
+            f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder."
         )
 
 
@@ -300,10 +301,8 @@ def _save_config(
             run_as_future=cast(Any, config.save_in_background),
         )
         log.info(f"Uploaded config.json to repository '{repo_name}'.")
-    except Exception as e:
-        log.exception(
-            f"Failed to upload config.json to repository '{repo_name}': {str(e)}"
-        )
+    except Exception:
+        log.exception(f"Failed to upload config.json to repository '{repo_name}'.")
 
 
 def _save_checkpoint_files(
@@ -331,17 +330,24 @@ def _save_checkpoint_files(
     # Read all the files to memory
     file_contents: list[bytes | None] = []
     for p in paths:
+        assert not p.is_symlink(), f"Path {p} is a symlink."
+        assert p.is_file(), f"Path {p} is not a file."
         try:
             with open(p, "rb") as f:
                 file_contents.append(f.read())
-        except IOError as e:
-            log.warning(f"Failed to read checkpoint file {p}: {str(e)}")
+        except IOError:
+            log.exception(f"Failed to read checkpoint file {p}.")
             file_contents.append(None)
 
-    for p, contents in zip(paths, file_contents):
-        if contents is None:
-            continue
+    # Remove the paths that failed to read
+    file_contents_and_paths = [
+        (contents, p)
+        for contents, p in zip(file_contents, paths)
+        if contents is not None
+    ]
 
+    # Upload the checkpoint files to the repository
+    for contents, p in file_contents_and_paths:
         try:
             relative_path = p.relative_to(checkpoint_dir)
         except ValueError:
@@ -365,21 +371,136 @@ def _save_checkpoint_files(
             log.info(
                 f"Uploaded checkpoint file {relative_path} to repository '{repo_name}'."
             )
-        except Exception as e:
+        except Exception:
             log.exception(
-                f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}': {str(e)}"
+                f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}'."
             )
 
     log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
 
 
-class HFHubCallback(Callback):
+def _save_checkpoint_symlinks(
+    trainer: "Trainer",
+    paths: list[Path],
+    *,
+    root_config: "BaseConfig",
+):
+    config = root_config.trainer.hf_hub
+    if (
+        api := _enabled_and_valid(trainer, config, rank_zero_only=True)
+    ) is None or not config.save_checkpoints:
+        return
+
+    # Resolve the checkpoint directory
+    checkpoint_dir = root_config.directory.resolve_subdirectory(
+        root_config.id, "checkpoint"
+    )
+
+    # Resolve the repository name
+    repo_name = _repo_name(api, root_config)
+
+    # Create a commit for copying the files
+    from huggingface_hub.hf_api import CommitOperation, CommitOperationCopy
+
+    commits: list[CommitOperation] = []
+    for p in paths:
+        assert p.is_symlink(), f"Path {p} is not a symlink."
+
+        try:
+            dest_relative_path = p.relative_to(checkpoint_dir)
+        except ValueError:
+            log.warning(
+                f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
+            )
+            continue
+
+        try:
+            source_relative_path = p.resolve().relative_to(checkpoint_dir)
+        except ValueError:
+            log.warning(
+                f"Checkpoint symlink target {p.resolve()} is not within the checkpoint directory {checkpoint_dir}."
+            )
+            continue
+
+        # Prefix the path in repo with "checkpoints"
+        dest_path_in_repo = Path("checkpoints") / dest_relative_path
+        source_path_in_repo = Path("checkpoints") / source_relative_path
+
+        # Create and append a CommitOperationCopy for copying the symlink
+        copy_op = CommitOperationCopy(
+            src_path_in_repo=str(source_path_in_repo),
+            path_in_repo=str(dest_path_in_repo),
+        )
+        commits.append(copy_op)
+
+    log.info(f"Creating a commit with the following operations: {commits}")
+
+    try:
+        api.create_commit(
+            repo_id=repo_name,
+            repo_type="model",
+            commit_message="Copy checkpoint symlinks",
+            operations=commits,
+            run_as_future=cast(Any, config.save_in_background),
+        )
+        log.info(
+            f"Created commit to copy checkpoint symlinks to repository '{repo_name}'."
+        )
+    except Exception:
+        log.exception(
+            f"Failed to create commit to copy checkpoint symlinks to repository '{repo_name}'"
+        )
+
+    log.info(f"Completed copying checkpoint symlinks to repository '{repo_name}'.")
+
+
+def _save_checkpoint_directory(trainer: "Trainer", *, root_config: "BaseConfig"):
+    config = root_config.trainer.hf_hub
+    if (
+        api := _enabled_and_valid(trainer, config, rank_zero_only=True)
+    ) is None or not config.save_checkpoints:
+        return
+
+    # Resolve the checkpoint directory
+    checkpoint_dir = root_config.directory.resolve_subdirectory(
+        root_config.id, "checkpoint"
+    )
+
+    # Resolve the repository name
+    repo_name = _repo_name(api, root_config)
+
+    # Upload the checkpoint directory to the repository
+    try:
+        api.upload_folder(
+            folder_path=str(checkpoint_dir),
+            repo_id=repo_name,
+            repo_type="model",
+            path_in_repo="checkpoints",
+            run_as_future=cast(Any, config.save_in_background),
+        )
+        log.info(f"Uploaded checkpoint directory to repository '{repo_name}'.")
+    except Exception:
+        log.exception(
+            f"Failed to upload checkpoint directory to repository '{repo_name}'."
+        )
+
+    log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
+
+
+class HFHubCallback(NTCallbackBase):
     def __init__(self, config: HuggingFaceHubConfig):
         super().__init__()
         self.config = config
 
     @override
     def setup(self, trainer, pl_module, stage):
+        from .trainer.trainer import Trainer
+
+        if not isinstance(trainer, Trainer):
+            raise ValueError(
+                f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
+            )
+
         root_config = cast("BaseConfig", pl_module.hparams)
         _init(trainer=trainer, root_config=root_config)
 
@@ -387,3 +508,18 @@ class HFHubCallback(Callback):
     def teardown(self, trainer, pl_module, stage):
         if hasattr(trainer, "_hf_hub_api"):
             delattr(trainer, "_hf_hub_api")
+
+    @override
+    def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
+        root_config = cast("BaseConfig", pl_module.hparams)
+
+        # If HF Hub is enabled, then we upload
+        if root_config.trainer.hf_hub and trainer.is_global_zero:
+            # Upload the regular files first, then the symlinks
+            all_paths = [p for p in (ckpt_path, metadata_path) if p is not None]
+            if regular_paths := [p for p in all_paths if not p.is_symlink()]:
+                _save_checkpoint_files(trainer, regular_paths, root_config=root_config)
+            if symlink_paths := [p for p in all_paths if p.is_symlink()]:
+                _save_checkpoint_symlinks(
+                    trainer, symlink_paths, root_config=root_config
+                )

nshtrainer/callbacks/checkpoint/_base.py CHANGED
@@ -9,7 +9,7 @@ from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import Checkpoint
 from typing_extensions import TypeVar, override
 
-from ..._checkpoint.metadata import CheckpointMetadata
+from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
 from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
 from ..base import CallbackConfigBase
 
@@ -65,8 +65,6 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
         self.dirpath.mkdir(parents=True, exist_ok=True)
         self.symlink_dirpath = dirpath
 
-        self._last_global_step_saved = 0
-
     @abstractmethod
     def default_filename(self) -> str: ...
 
@@ -144,9 +142,21 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
         if self._should_skip_saving_checkpoint(trainer):
             return
 
+        from ...trainer import Trainer as NTTrainer
+
+        if not isinstance(trainer, NTTrainer):
+            raise TypeError(
+                f"Trainer must be an instance of {NTTrainer.__name__}, "
+                f"but got {type(trainer).__name__}"
+            )
+
         # Save the new checkpoint
         filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
-        trainer.save_checkpoint(filepath, self.config.save_weights_only)
+        trainer._nshtrainer_save_checkpoint(
+            filepath,
+            self.config.save_weights_only,
+            use_checkpoint_cache=None,
+        )
 
         if trainer.is_global_zero:
             # Create the latest symlink
@@ -162,8 +172,15 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
         # deleted the old checkpoints, and created the symlink before continuing
         trainer.strategy.barrier()
 
-        # Set the last global step saved
-        self._last_global_step_saved = trainer.global_step
+        # Call the on save checkpoint callback for the symlink (if it exists)
+        if (symlink_filename := self.symlink_path()) is not None:
+            from ... import _callback
+
+            symlink_path = self.dirpath / symlink_filename
+            symlink_metadata_path = _metadata_path(symlink_path)
+            _callback._call_on_checkpoint_saved(
+                trainer, symlink_path, symlink_metadata_path
+            )
 
     def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
         from lightning.pytorch.trainer.states import TrainerFn
@@ -175,6 +192,4 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
             or trainer.state.fn
             != TrainerFn.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
-            or self._last_global_step_saved
-            == trainer.global_step  # already saved at the last step
         )

nshtrainer/model/config.py CHANGED
@@ -1012,6 +1012,8 @@ class TrainerConfig(C.Config):
     """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
     save_checkpoint_metadata: bool = True
     """If enabled, will save additional metadata whenever a checkpoint is saved."""
+    use_checkpoint_cache: bool = True
+    """If enabled, will optimize the saving of duplicate checkpoints by creating symlinks instead of copying the file."""
 
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """

nshtrainer/trainer/trainer.py CHANGED
@@ -17,6 +17,7 @@ from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
 from typing_extensions import Unpack, assert_never, override
 
 from .._checkpoint.metadata import _write_checkpoint_metadata
+from .._checkpoint.saver import _link_checkpoint
 from ..callbacks.base import resolve_all_callbacks
 from ..model.config import (
     AcceleratorConfigProtocol,
@@ -285,6 +286,8 @@ class Trainer(LightningTrainer):
         /,
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
+        self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
+
         self._pre_init(config)
 
         kwargs = self._update_kwargs(config, kwargs)
@@ -407,29 +410,70 @@
 
         return super()._run(model, ckpt_path)
 
-    @override
-    def save_checkpoint(
+    def _nshtrainer_save_checkpoint(
         self,
         filepath: str | Path,
         weights_only: bool = False,
         storage_options: Any | None = None,
+        use_checkpoint_cache: bool | None = None,
     ):
+        lm = self._base_module
+        hparams = cast(BaseConfig, lm.hparams)
+        if use_checkpoint_cache is None:
+            use_checkpoint_cache = hparams.trainer.use_checkpoint_cache
+
         filepath = Path(filepath)
-        ret_val = super().save_checkpoint(filepath, weights_only, storage_options)
+
+        # List of files that we should upload to HF
+        written_files: list[Path] = [filepath]
+
+        cached_path = None
+        if (
+            use_checkpoint_cache
+            and (
+                cached_path := self._nshtrainer_checkpoint_cache.get(
+                    (self.current_epoch, self.global_step)
+                )
+            )
+            is not None
+        ):
+            # If we have a cached path, then we symlink it to the new path.
+            log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
+            _link_checkpoint(cached_path, filepath, metadata=False)
+        else:
+            super().save_checkpoint(filepath, weights_only, storage_options)
+
+        # If we are using the cache but we don't have a cached path, then we save the checkpoint to the cache.
+        if use_checkpoint_cache and cached_path is None:
+            self._nshtrainer_checkpoint_cache[
+                (self.current_epoch, self.global_step)
+            ] = filepath
+            log.debug(f"Checkpoint saved to cache: {filepath}")
 
         # Save the checkpoint metadata
-        lm = self._base_module
-        hparams = cast(BaseConfig, lm.hparams)
         metadata_path = None
         if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
-            metadata_path = _write_checkpoint_metadata(self, lm, filepath)
+            if (
+                metadata_path := _write_checkpoint_metadata(self, lm, filepath)
+            ) is not None:
+                written_files.append(metadata_path)
 
-        # If HF Hub is enabled, then we upload
-        if hparams.trainer.hf_hub and self.is_global_zero:
-            from .._hf_hub import _save_checkpoint_files
+        # Call the `on_checkpoint_saved` method on all callbacks
+        from .. import _callback
 
-            files = [f for f in (filepath, metadata_path) if f is not None]
-            _save_checkpoint_files(self, files, root_config=hparams)
+        _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
 
-        return ret_val
+    @override
+    def save_checkpoint(
+        self,
+        filepath: str | Path,
+        weights_only: bool = False,
+        storage_options: Any | None = None,
+    ):
+        return self._nshtrainer_save_checkpoint(
+            filepath=filepath,
+            weights_only=weights_only,
+            storage_options=storage_options,
+            use_checkpoint_cache=False,
+        )
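
The cache keyed on `(current_epoch, global_step)` is the core of this change: the first checkpoint callback to save at a given step writes a real file, and every later save request for the same step is materialized as a relative symlink to it (with a copy fallback on Windows via `_link_checkpoint`). A standalone toy sketch of that idea, not nshtrainer code:

import os
from pathlib import Path

_cache: dict[tuple[int, int], Path] = {}

def save(epoch: int, step: int, filepath: Path, payload: bytes) -> None:
    key = (epoch, step)
    if (cached := _cache.get(key)) is not None:
        # Same training step already saved: link the new name to the existing file.
        filepath.symlink_to(os.path.relpath(cached, filepath.parent))
        return
    filepath.write_bytes(payload)
    _cache[key] = filepath

save(0, 100, Path("last.ckpt"), b"...")  # writes a real file
save(0, 100, Path("best.ckpt"), b"...")  # becomes a symlink to last.ckpt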

nshtrainer/util/path.py ADDED
@@ -0,0 +1,29 @@
+import os
+from pathlib import Path
+from typing import TypeAlias
+
+_Path: TypeAlias = str | Path | os.PathLike
+
+
+def get_relative_path(source: _Path, destination: _Path):
+    # Get the absolute paths
+    source = os.path.abspath(source)
+    destination = os.path.abspath(destination)
+
+    # Split the paths into components
+    source_parts = source.split(os.sep)
+    destination_parts = destination.split(os.sep)
+
+    # Find the point where the paths diverge
+    i = 0
+    for i in range(min(len(source_parts), len(destination_parts))):
+        if source_parts[i] != destination_parts[i]:
+            break
+    else:
+        i += 1
+
+    # Build the relative path
+    up = os.sep.join([".." for _ in range(len(source_parts) - i - 1)])
+    down = os.sep.join(destination_parts[i:])
+
+    return Path(os.path.normpath(os.path.join(up, down)))
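
A quick illustration of the new helper (paths are made up, POSIX separators shown); for simple cases it behaves like `os.path.relpath(destination, start=os.path.dirname(source))`, which is how `_link_checkpoint` above uses it to build relative symlink targets:

from nshtrainer.util.path import get_relative_path

# Link living at .../checkpoint/last.ckpt, pointing into a sibling subdirectory:
print(get_relative_path("/runs/abc/checkpoint/last.ckpt", "/runs/abc/checkpoint/best/epoch3.ckpt"))
# -> best/epoch3.ckpt

# Link living one level deeper than its target:
print(get_relative_path("/runs/abc/checkpoint/best/last.ckpt", "/runs/abc/checkpoint/epoch3.ckpt"))
# -> ../epoch3.ckpt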

nshtrainer-0.20.0.dist-info/METADATA → nshtrainer-0.21.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.20.0
+Version: 0.21.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -26,6 +26,7 @@ Requires-Dist: torchmetrics ; extra == "extra"
 Requires-Dist: typing-extensions
 Requires-Dist: wandb ; extra == "extra"
 Requires-Dist: wrapt ; extra == "extra"
+Requires-Dist: zstandard ; extra == "extra"
 Description-Content-Type: text/markdown
 
 

nshtrainer-0.20.0.dist-info/RECORD → nshtrainer-0.21.0.dist-info/RECORD
@@ -1,15 +1,16 @@
 nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
+nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
 nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
-nshtrainer/_checkpoint/metadata.py,sha256=p5e7dhVPpOGrXeuesq_7Y_RHi5lguzDAR_UXtMJXzWU,5175
-nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
+nshtrainer/_checkpoint/metadata.py,sha256=TLAt7yR3KhSRbXCtomLMxcMvOiAju873A1ZRo8VWNwA,5179
+nshtrainer/_checkpoint/saver.py,sha256=6W-Rbc3QGuhcF_mcwN8v31uEjLQCsZvt8CPuqPs4m5g,1342
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
-nshtrainer/_hf_hub.py,sha256=To3BnnGWbMNNMBdzVtgrNOcNU2fi1dQpwwuclusFAbI,12169
+nshtrainer/_hf_hub.py,sha256=0bOhJNyIjQGJsMRaW7qQJc1oTnUMHj08auuztzTQvZ0,16906
 nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
-nshtrainer/callbacks/checkpoint/_base.py,sha256=YT_V-oihO9iB4ETl46CGYTCQjIYl-CpV7TMViTn07Lk,6144
+nshtrainer/callbacks/checkpoint/_base.py,sha256=r6IPpl3sGUmxBNv80y9r326lTrPAIVSU3Fu-3LrYH2s,6691
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -57,7 +58,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
 nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
 nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
-nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
+nshtrainer/model/config.py,sha256=22_xIcdEO2pJzXgrFaqGFtk3PQEiwKiMZY1cjhoyWaA,43486
 nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
 nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -77,14 +78,15 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
 nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
 nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
-nshtrainer/trainer/trainer.py,sha256=TTtVkgSB_ekgDlHg24d58Vzddtkpp6ZHOTVprXdXMH0,17503
+nshtrainer/trainer/trainer.py,sha256=DNKA4mcW083i7qLk0fi3j5-Qj4KNBtlLuyIsNxykebw,19100
 nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
 nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
+nshtrainer/util/path.py,sha256=A_Ocag3_hbwns_zAxFDlH-5eVHWFlcy2DKxHQ7jddvk,837
 nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.20.0.dist-info/METADATA,sha256=BCzgQYVMH8_7VHpAcAEuJqlQ0oJOERSbBop4bOebYZ4,935
-nshtrainer-0.20.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.20.0.dist-info/RECORD,,
+nshtrainer-0.21.0.dist-info/METADATA,sha256=7QfSX_yXi-Up6uxOVFfDPn4ieGK5b3UgQfO_KFsNzXk,979
+nshtrainer-0.21.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.21.0.dist-info/RECORD,,