nshtrainer 0.25.0__py3-none-any.whl → 0.26.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,6 @@
1
1
  import copy
2
2
  import datetime
3
3
  import logging
4
- import shutil
5
4
  from collections.abc import Callable
6
5
  from pathlib import Path
7
6
  from typing import TYPE_CHECKING, Any, ClassVar, cast
@@ -11,7 +10,7 @@ import numpy as np
11
10
  import torch
12
11
 
13
12
  from ..util._environment_info import EnvironmentConfig
14
- from ..util.path import compute_file_checksum, get_relative_path
13
+ from ..util.path import compute_file_checksum, try_symlink_or_copy
15
14
 
16
15
  if TYPE_CHECKING:
17
16
  from ..model import BaseConfig, LightningModuleBase
@@ -142,21 +141,7 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
142
141
  # Link the metadata files to the new checkpoint
143
142
  path = _metadata_path(checkpoint_path)
144
143
  linked_path = _metadata_path(linked_checkpoint_path)
145
- try:
146
- try:
147
- # linked_path.symlink_to(path)
148
- # We should store the path as a relative path
149
- # to the metadata file to avoid issues with
150
- # moving the checkpoint directory
151
- linked_path.symlink_to(get_relative_path(linked_path, path))
152
- except OSError:
153
- # on Windows, special permissions are required to create symbolic links as a regular user
154
- # fall back to copying the file
155
- shutil.copy(path, linked_path)
156
- except Exception:
157
- log.exception(f"Failed to link {path} to {linked_path}")
158
- else:
159
- log.debug(f"Linked {path} to {linked_path}")
144
+ try_symlink_or_copy(path, linked_path)
160
145
 
161
146
 
162
147
  def _sort_ckpts_by_metadata(
@@ -5,7 +5,7 @@ from pathlib import Path
5
5
 
6
6
  from lightning.pytorch import Trainer
7
7
 
8
- from ..util.path import get_relative_path
8
+ from ..util.path import try_symlink_or_copy
9
9
  from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
10
10
 
11
11
  log = logging.getLogger(__name__)
@@ -34,13 +34,7 @@ def _link_checkpoint(
34
34
  if metadata:
35
35
  _remove_checkpoint_metadata(linkpath)
36
36
 
37
- try:
38
- linkpath.symlink_to(get_relative_path(linkpath, filepath))
39
- except OSError:
40
- # on Windows, special permissions are required to create symbolic links as a regular user
41
- # fall back to copying the file
42
- shutil.copy(filepath, linkpath)
43
-
37
+ try_symlink_or_copy(filepath, linkpath)
44
38
  if metadata:
45
39
  _link_checkpoint_metadata(filepath, linkpath)
46
40
 
@@ -1,8 +1,8 @@
1
- import importlib.util
2
1
  import logging
3
2
  from typing import Any, Literal, Protocol, runtime_checkable
4
3
 
5
4
  import torch
5
+ import torchmetrics
6
6
  from lightning.pytorch import Callback, LightningModule, Trainer
7
7
  from torch.optim import Optimizer
8
8
  from typing_extensions import override
@@ -20,19 +20,12 @@ class HasGradSkippedSteps(Protocol):
20
20
 
21
21
  class GradientSkipping(Callback):
22
22
  def __init__(self, config: "GradientSkippingConfig"):
23
- if importlib.util.find_spec("torchmetrics") is not None:
24
- raise ImportError(
25
- "To use the GradientSkipping callback, please install torchmetrics: pip install torchmetrics"
26
- )
27
-
28
23
  super().__init__()
29
24
  self.config = config
30
25
 
31
26
  @override
32
27
  def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
33
28
  if not isinstance(pl_module, HasGradSkippedSteps):
34
- import torchmetrics # type: ignore
35
-
36
29
  pl_module.grad_skipped_steps = torchmetrics.SumMetric()
37
30
 
38
31
  @override
nshtrainer/util/path.py CHANGED
@@ -1,8 +1,13 @@
1
1
  import hashlib
2
+ import logging
2
3
  import os
4
+ import platform
5
+ import shutil
3
6
  from pathlib import Path
4
7
  from typing import TypeAlias
5
8
 
9
+ log = logging.getLogger(__name__)
10
+
6
11
  _Path: TypeAlias = str | Path | os.PathLike
7
12
 
8
13
 
@@ -68,3 +73,32 @@ def compute_file_checksum(file_path: Path) -> str:
68
73
  for byte_block in iter(lambda: f.read(4096), b""):
69
74
  sha256_hash.update(byte_block)
70
75
  return sha256_hash.hexdigest()
76
+
77
+
78
+ def try_symlink_or_copy(
79
+ file_path: Path,
80
+ link_path: Path,
81
+ target_is_directory: bool = False,
82
+ relative: bool = True,
83
+ ):
84
+ """
85
+ Symlinks on Unix, copies on Windows.
86
+ """
87
+
88
+ symlink_target = get_relative_path(link_path, file_path) if relative else file_path
89
+ try:
90
+ if platform.system() == "Windows":
91
+ if target_is_directory:
92
+ shutil.copytree(file_path, link_path)
93
+ else:
94
+ shutil.copy(file_path, link_path)
95
+ else:
96
+ link_path.symlink_to(
97
+ symlink_target, target_is_directory=target_is_directory
98
+ )
99
+ except Exception:
100
+ log.exception(f"Failed to create symlink or copy {file_path} to {link_path}")
101
+ return False
102
+ else:
103
+ log.debug(f"Created symlink or copied {file_path} to {link_path}")
104
+ return True
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.25.0
3
+ Version: 0.26.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -22,7 +22,7 @@ Requires-Dist: psutil
22
22
  Requires-Dist: pytorch-lightning
23
23
  Requires-Dist: tensorboard ; extra == "extra"
24
24
  Requires-Dist: torch
25
- Requires-Dist: torchmetrics ; extra == "extra"
25
+ Requires-Dist: torchmetrics
26
26
  Requires-Dist: typing-extensions
27
27
  Requires-Dist: wandb ; extra == "extra"
28
28
  Requires-Dist: wrapt ; extra == "extra"
@@ -1,8 +1,8 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
2
  nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
3
3
  nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
4
- nshtrainer/_checkpoint/metadata.py,sha256=BpxC3VGrgSDHvZpc40A2icjBEDvRNGEsTis9YkFY8Kc,5341
5
- nshtrainer/_checkpoint/saver.py,sha256=fvRKGI5aeXtsHBOIO4cwGe__wmO-6DiD0-744VASYA4,1500
4
+ nshtrainer/_checkpoint/metadata.py,sha256=hxZwwsUKVbBtt4wjqcKZbObx0PuO-qCdF3BTdnyqaQo,4711
5
+ nshtrainer/_checkpoint/saver.py,sha256=1loCDYDy_Cay37uKs_wvxnkwvr41WMmga85qefct80Q,1271
6
6
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
7
7
  nshtrainer/_hf_hub.py,sha256=0K3uWa8hd2KyGuUYM7OXARcA7vuUiWWGSlP2USysY7o,12066
8
8
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
@@ -17,7 +17,7 @@ nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6
17
17
  nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
18
18
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
19
19
  nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
20
- nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
20
+ nshtrainer/callbacks/gradient_skipping.py,sha256=EBNkANDnD3BTszWjnG-jwY8FEj-iRqhE3e1x5LQF6M8,3393
21
21
  nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
22
22
  nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
23
23
  nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
@@ -82,11 +82,11 @@ nshtrainer/trainer/trainer.py,sha256=Zwdcqfmrr7yuonsp4VrNOget8wkaZY9lf-_yeJ94lkk
82
82
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
83
83
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
84
84
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
85
- nshtrainer/util/path.py,sha256=RUkIOrlj9b8zPPXE3JLhdihBNitJSUWw1whZ33u-2Yk,2005
85
+ nshtrainer/util/path.py,sha256=jAEjF1qp8Aii32L5lWG4UFgVyQAFkHOMYEc_TC2hDx8,2947
86
86
  nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.25.0.dist-info/METADATA,sha256=Rqdeh2yp2AhZ_nOHlD47v5YPDrLc2fHN6WGwqJnDv04,935
91
- nshtrainer-0.25.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.25.0.dist-info/RECORD,,
90
+ nshtrainer-0.26.1.dist-info/METADATA,sha256=tMMpyg1BTKec5d69ziW6XBxDXaI0gSK5tDMPCmj7VCQ,916
91
+ nshtrainer-0.26.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.26.1.dist-info/RECORD,,