nshtrainer 1.0.0b41__py3-none-any.whl → 1.0.0b43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,6 @@ from __future__ import annotations
3
3
  import copy
4
4
  import datetime
5
5
  import logging
6
- from collections.abc import Callable
7
6
  from pathlib import Path
8
7
  from typing import TYPE_CHECKING, Any, ClassVar
9
8
 
@@ -115,7 +114,7 @@ def _metadata_path(checkpoint_path: Path):
115
114
  return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
116
115
 
117
116
 
118
- def _write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
117
+ def write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
119
118
  metadata_path = _metadata_path(checkpoint_path)
120
119
  metadata = _generate_checkpoint_metadata(trainer, checkpoint_path, metadata_path)
121
120
 
@@ -130,7 +129,7 @@ def _write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
130
129
  return metadata_path
131
130
 
132
131
 
133
- def _remove_checkpoint_metadata(checkpoint_path: Path):
132
+ def remove_checkpoint_metadata(checkpoint_path: Path):
134
133
  path = _metadata_path(checkpoint_path)
135
134
  try:
136
135
  path.unlink(missing_ok=True)
@@ -140,23 +139,15 @@ def _remove_checkpoint_metadata(checkpoint_path: Path):
140
139
  log.debug(f"Removed {path}")
141
140
 
142
141
 
143
- def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
142
+ def link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
144
143
  # First, remove any existing metadata files
145
- _remove_checkpoint_metadata(linked_checkpoint_path)
144
+ remove_checkpoint_metadata(linked_checkpoint_path)
146
145
 
147
146
  # Link the metadata files to the new checkpoint
148
147
  path = _metadata_path(checkpoint_path)
149
148
  linked_path = _metadata_path(linked_checkpoint_path)
150
- try_symlink_or_copy(path, linked_path)
151
149
 
150
+ if not path.exists():
151
+ raise FileNotFoundError(f"Checkpoint path does not exist: {checkpoint_path}")
152
152
 
153
- def _sort_ckpts_by_metadata(
154
- checkpoint_paths: list[Path],
155
- key: Callable[[CheckpointMetadata, Path], Any],
156
- reverse: bool = False,
157
- ):
158
- return sorted(
159
- [(CheckpointMetadata.from_ckpt_path(path), path) for path in checkpoint_paths],
160
- key=lambda args_tuple: key(*args_tuple),
161
- reverse=reverse,
162
- )
153
+ try_symlink_or_copy(path, linked_path)
@@ -8,12 +8,12 @@ from pathlib import Path
8
8
  from lightning.pytorch import Trainer
9
9
 
10
10
  from ..util.path import try_symlink_or_copy
11
- from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
11
+ from .metadata import link_checkpoint_metadata, remove_checkpoint_metadata
12
12
 
13
13
  log = logging.getLogger(__name__)
14
14
 
15
15
 
16
- def _link_checkpoint(
16
+ def link_checkpoint(
17
17
  filepath: str | Path | os.PathLike,
18
18
  linkpath: str | Path | os.PathLike,
19
19
  *,
@@ -25,7 +25,10 @@ def _link_checkpoint(
25
25
 
26
26
  if remove_existing:
27
27
  try:
28
- if linkpath.exists():
28
+ if linkpath.exists(follow_symlinks=False):
29
+ # follow_symlinks=False is EXTREMELY important here
30
+ # Otherwise, we've already deleted the file that the symlink
31
+ # used to point to, so this always returns False
29
32
  if linkpath.is_dir():
30
33
  shutil.rmtree(linkpath)
31
34
  else:
@@ -36,14 +39,14 @@ def _link_checkpoint(
36
39
  log.debug(f"Removed {linkpath=}")
37
40
 
38
41
  if metadata:
39
- _remove_checkpoint_metadata(linkpath)
42
+ remove_checkpoint_metadata(linkpath)
40
43
 
41
44
  try_symlink_or_copy(filepath, linkpath)
42
45
  if metadata:
43
- _link_checkpoint_metadata(filepath, linkpath)
46
+ link_checkpoint_metadata(filepath, linkpath)
44
47
 
45
48
 
46
- def _remove_checkpoint(
49
+ def remove_checkpoint(
47
50
  trainer: Trainer,
48
51
  filepath: str | Path | os.PathLike,
49
52
  *,
@@ -54,4 +57,4 @@ def _remove_checkpoint(
54
57
  trainer.strategy.remove_checkpoint(filepath)
55
58
 
56
59
  if metadata:
57
- _remove_checkpoint_metadata(filepath)
60
+ remove_checkpoint_metadata(filepath)
@@ -12,7 +12,7 @@ from lightning.pytorch.callbacks import Checkpoint
12
12
  from typing_extensions import TypeVar, override
13
13
 
14
14
  from ..._checkpoint.metadata import CheckpointMetadata
15
- from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
15
+ from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
16
16
  from ..base import CallbackConfigBase
17
17
 
18
18
  if TYPE_CHECKING:
@@ -122,7 +122,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
122
122
  )
123
123
  continue
124
124
 
125
- _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
125
+ remove_checkpoint(trainer, old_ckpt_path, metadata=True)
126
126
  log.debug(f"Removed old checkpoint: {old_ckpt_path}")
127
127
 
128
128
  def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
@@ -167,7 +167,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
167
167
  # Create the latest symlink
168
168
  if (symlink_filename := self.symlink_path()) is not None:
169
169
  symlink_path = self.dirpath / symlink_filename
170
- _link_checkpoint(filepath, symlink_path, metadata=True)
170
+ link_checkpoint(filepath, symlink_path, metadata=True)
171
171
  log.debug(f"Created latest symlink: {symlink_path}")
172
172
 
173
173
  # Barrier to ensure all processes have saved the checkpoint,
@@ -18,7 +18,7 @@ from lightning.pytorch.trainer.states import TrainerFn
18
18
  from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
19
19
  from typing_extensions import Never, Unpack, assert_never, deprecated, override
20
20
 
21
- from .._checkpoint.metadata import _write_checkpoint_metadata
21
+ from .._checkpoint.metadata import write_checkpoint_metadata
22
22
  from ..callbacks.base import resolve_all_callbacks
23
23
  from ..util._environment_info import EnvironmentConfig
24
24
  from ..util.bf16 import is_bf16_supported_no_emulation
@@ -478,7 +478,7 @@ class Trainer(LightningTrainer):
478
478
  metadata_path = None
479
479
  if self.hparams.save_checkpoint_metadata and self.is_global_zero:
480
480
  # Generate the metadata and write to disk
481
- metadata_path = _write_checkpoint_metadata(self, filepath)
481
+ metadata_path = write_checkpoint_metadata(self, filepath)
482
482
 
483
483
  # Call the `on_checkpoint_saved` method on all callbacks
484
484
  from .. import _callback
nshtrainer/util/path.py CHANGED
@@ -81,18 +81,27 @@ def compute_file_checksum(file_path: Path) -> str:
81
81
  def try_symlink_or_copy(
82
82
  file_path: Path,
83
83
  link_path: Path,
84
+ *,
84
85
  target_is_directory: bool = False,
85
86
  relative: bool = True,
86
87
  remove_existing: bool = True,
88
+ throw_on_invalid_target: bool = False,
87
89
  ):
88
90
  """
89
91
  Symlinks on Unix, copies on Windows.
90
92
  """
91
93
 
94
+ # Check if the target file exists
95
+ if throw_on_invalid_target and not file_path.exists():
96
+ raise FileNotFoundError(f"File not found: {file_path}")
97
+
92
98
  # If the link already exists, remove it
93
99
  if remove_existing:
94
100
  try:
95
- if link_path.exists():
101
+ if link_path.exists(follow_symlinks=False):
102
+ # follow_symlinks=False is EXTREMELY important here
103
+ # Otherwise, we've already deleted the file that the symlink
104
+ # used to point to, so this always returns False
96
105
  if link_path.is_dir():
97
106
  shutil.rmtree(link_path)
98
107
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.0.0b41
3
+ Version: 1.0.0b43
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,8 +1,8 @@
1
1
  nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
2
2
  nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
3
3
  nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
4
- nshtrainer/_checkpoint/metadata.py,sha256=PHy-54Cg-o3OtCffAqrVv6ZVMU7zhRo_-sZiSEEno1Y,5019
5
- nshtrainer/_checkpoint/saver.py,sha256=LOP8jjKF0Dw9x9H-BKrLMWlEp1XTan2DUK0zQUCWw5U,1360
4
+ nshtrainer/_checkpoint/metadata.py,sha256=LQZ8g50rKxQQx-FqiW3n8EWmal9qSWRouOpIIn6NJJY,4758
5
+ nshtrainer/_checkpoint/saver.py,sha256=rWl4d2lCTMU4_wt8yZFL2pFQaP9hj5sPgqHMPQ4zuyI,1584
6
6
  nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
7
7
  nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
8
8
  nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
@@ -10,7 +10,7 @@ nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTw
10
10
  nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
11
11
  nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
13
- nshtrainer/callbacks/checkpoint/_base.py,sha256=ZVEUVl5kjCSSe69Q0rMUbKBNNUog0pxBwWkeyuxG2w0,6304
13
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=wCJBRI0pQYZc3GBu0b-aUBlBDhd39AdL82VvFgKmv3k,6300
14
14
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
15
15
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
@@ -122,7 +122,6 @@ nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101
122
122
  nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
123
123
  nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
124
124
  nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
125
- nshtrainer/nn/tests/test_mlp.py,sha256=xBPiHlBvOCn67EbpzzKL-2FU7ikGxHT3i6CMSp1wk7M,1840
126
125
  nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
127
126
  nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
128
127
  nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
@@ -141,18 +140,18 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
141
140
  nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
142
141
  nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
143
142
  nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
144
- nshtrainer/trainer/trainer.py,sha256=l2kJs27v4IHZnzxExr0zX0sVex0wukgiD2Wn_0wiGJg,20836
143
+ nshtrainer/trainer/trainer.py,sha256=ed_Pn-yQCb9BqaHXo2wVhkt2CSfGNEzMAM6RsDoTo-I,20834
145
144
  nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
146
145
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
147
146
  nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
148
147
  nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
149
148
  nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
150
149
  nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
151
- nshtrainer/util/path.py,sha256=L-Nh9tlXSUfoP19TFbQq8I0AfS5ugCfGYTYFeddDHcs,3516
150
+ nshtrainer/util/path.py,sha256=9fIjE3S78pPL6wjAgEJUYfIJQAPdKOQqIYvTS9lWTUk,3959
152
151
  nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
153
152
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
154
153
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
155
154
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
156
- nshtrainer-1.0.0b41.dist-info/METADATA,sha256=DL9HgN6RP8X8v0sCdTr2IjRSwIBY96NZXe15m5V4y4c,988
157
- nshtrainer-1.0.0b41.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
158
- nshtrainer-1.0.0b41.dist-info/RECORD,,
155
+ nshtrainer-1.0.0b43.dist-info/METADATA,sha256=ZE3l6CN34ptFgx3SDPfKIgjdV2s3J8qdP729eb58vzo,988
156
+ nshtrainer-1.0.0b43.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
157
+ nshtrainer-1.0.0b43.dist-info/RECORD,,
@@ -1,55 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import cast
4
-
5
- import pytest
6
- import torch
7
-
8
- from nshtrainer.nn.mlp import MLP
9
-
10
-
11
- def test_mlp_seed_reproducibility():
12
- """Test that the seed parameter in MLP ensures reproducible weights."""
13
-
14
- # Test dimensions
15
- dims = [10, 20, 5]
16
-
17
- # Create two MLPs with the same seed
18
- seed1 = 42
19
- mlp1 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
20
- mlp2 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
21
-
22
- # Create an MLP with a different seed
23
- seed2 = 123
24
- mlp3 = MLP(dims, activation=torch.nn.ReLU(), seed=seed2)
25
-
26
- # Check first layer weights
27
- layer1_weights1 = cast(torch.Tensor, mlp1[0].weight)
28
- layer1_weights2 = cast(torch.Tensor, mlp2[0].weight)
29
- layer1_weights3 = cast(torch.Tensor, mlp3[0].weight)
30
-
31
- # Same seed should produce identical weights
32
- assert torch.allclose(layer1_weights1, layer1_weights2)
33
-
34
- # Different seeds should produce different weights
35
- assert not torch.allclose(layer1_weights1, layer1_weights3)
36
-
37
- # Check second layer weights
38
- layer2_weights1 = cast(torch.Tensor, mlp1[2].weight)
39
- layer2_weights2 = cast(torch.Tensor, mlp2[2].weight)
40
- layer2_weights3 = cast(torch.Tensor, mlp3[2].weight)
41
-
42
- # Same seed should produce identical weights for all layers
43
- assert torch.allclose(layer2_weights1, layer2_weights2)
44
-
45
- # Different seeds should produce different weights for all layers
46
- assert not torch.allclose(layer2_weights1, layer2_weights3)
47
-
48
- # Test that not providing a seed gives different results each time
49
- mlp4 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
50
- mlp5 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
51
-
52
- # Without seeds, weights should be different
53
- assert not torch.allclose(
54
- cast(torch.Tensor, mlp4[0].weight), cast(torch.Tensor, mlp5[0].weight)
55
- )