nshtrainer 1.0.0b41__py3-none-any.whl → 1.0.0b42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_checkpoint/metadata.py CHANGED
@@ -130,7 +130,7 @@ def _write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
     return metadata_path
 
 
-def _remove_checkpoint_metadata(checkpoint_path: Path):
+def remove_checkpoint_metadata(checkpoint_path: Path):
     path = _metadata_path(checkpoint_path)
     try:
         path.unlink(missing_ok=True)
@@ -140,23 +140,15 @@ def _remove_checkpoint_metadata(checkpoint_path: Path):
         log.debug(f"Removed {path}")
 
 
-def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
+def link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
     # First, remove any existing metadata files
-    _remove_checkpoint_metadata(linked_checkpoint_path)
+    remove_checkpoint_metadata(linked_checkpoint_path)
 
     # Link the metadata files to the new checkpoint
     path = _metadata_path(checkpoint_path)
     linked_path = _metadata_path(linked_checkpoint_path)
-    try_symlink_or_copy(path, linked_path)
 
+    if not path.exists():
+        raise FileNotFoundError(f"Checkpoint path does not exist: {checkpoint_path}")
 
-def _sort_ckpts_by_metadata(
-    checkpoint_paths: list[Path],
-    key: Callable[[CheckpointMetadata, Path], Any],
-    reverse: bool = False,
-):
-    return sorted(
-        [(CheckpointMetadata.from_ckpt_path(path), path) for path in checkpoint_paths],
-        key=lambda args_tuple: key(*args_tuple),
-        reverse=reverse,
-    )
+    try_symlink_or_copy(path, linked_path)
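With this release the metadata helpers lose their leading underscore and become part of the module's public surface, the unused _sort_ckpts_by_metadata helper is dropped, and link_checkpoint_metadata now raises FileNotFoundError when the source metadata file is missing instead of attempting the symlink anyway. A minimal usage sketch; the checkpoint paths below are illustrative, not taken from this diff:

from pathlib import Path

from nshtrainer._checkpoint.metadata import (
    link_checkpoint_metadata,
    remove_checkpoint_metadata,
)

ckpt = Path("checkpoints/epoch=3-step=400.ckpt")  # hypothetical paths
latest = Path("checkpoints/latest.ckpt")

# Point the metadata file that sits next to "latest" at the metadata of "ckpt".
# As of this version this raises FileNotFoundError if ckpt's metadata file
# does not exist.
link_checkpoint_metadata(ckpt, latest)

# Delete the metadata file next to a checkpoint that is being removed;
# a missing file is ignored (unlink(missing_ok=True) in the diff above).
remove_checkpoint_metadata(ckpt)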
nshtrainer/_checkpoint/saver.py CHANGED
@@ -8,12 +8,12 @@ from pathlib import Path
 from lightning.pytorch import Trainer
 
 from ..util.path import try_symlink_or_copy
-from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
+from .metadata import link_checkpoint_metadata, remove_checkpoint_metadata
 
 log = logging.getLogger(__name__)
 
 
-def _link_checkpoint(
+def link_checkpoint(
     filepath: str | Path | os.PathLike,
     linkpath: str | Path | os.PathLike,
     *,
@@ -36,14 +36,14 @@ def _link_checkpoint(
         log.debug(f"Removed {linkpath=}")
 
         if metadata:
-            _remove_checkpoint_metadata(linkpath)
+            remove_checkpoint_metadata(linkpath)
 
     try_symlink_or_copy(filepath, linkpath)
     if metadata:
-        _link_checkpoint_metadata(filepath, linkpath)
+        link_checkpoint_metadata(filepath, linkpath)
 
 
-def _remove_checkpoint(
+def remove_checkpoint(
     trainer: Trainer,
     filepath: str | Path | os.PathLike,
     *,
@@ -54,4 +54,4 @@ def _remove_checkpoint(
     trainer.strategy.remove_checkpoint(filepath)
 
     if metadata:
-        _remove_checkpoint_metadata(filepath)
+        remove_checkpoint_metadata(filepath)
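saver.py gets the same rename: link_checkpoint and remove_checkpoint are now public, and the checkpoint callbacks in the diff that follows import them under the new names. A rough sketch of how the two helpers fit together, with placeholder paths and hypothetical wrapper functions:

from pathlib import Path

from lightning.pytorch import Trainer
from nshtrainer._checkpoint.saver import link_checkpoint, remove_checkpoint


def refresh_last_link(filepath: Path, dirpath: Path) -> None:
    # Hypothetical helper: point a stable "last.ckpt" link (symlink, or a copy
    # on filesystems without symlink support) at the checkpoint that was just
    # written; metadata=True keeps the sidecar metadata file in sync as well.
    link_checkpoint(filepath, dirpath / "last.ckpt", metadata=True)


def drop_stale(trainer: Trainer, old_filepath: Path) -> None:
    # Hypothetical helper: delete a stale checkpoint through
    # trainer.strategy.remove_checkpoint and remove its metadata file.
    remove_checkpoint(trainer, old_filepath, metadata=True)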
nshtrainer/callbacks/checkpoint/_base.py CHANGED
@@ -12,7 +12,7 @@ from lightning.pytorch.callbacks import Checkpoint
 from typing_extensions import TypeVar, override
 
 from ..._checkpoint.metadata import CheckpointMetadata
-from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
 from ..base import CallbackConfigBase
 
 if TYPE_CHECKING:
@@ -122,7 +122,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
                 )
                 continue
 
-            _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
+            remove_checkpoint(trainer, old_ckpt_path, metadata=True)
             log.debug(f"Removed old checkpoint: {old_ckpt_path}")
 
     def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
@@ -167,7 +167,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
         # Create the latest symlink
         if (symlink_filename := self.symlink_path()) is not None:
             symlink_path = self.dirpath / symlink_filename
-            _link_checkpoint(filepath, symlink_path, metadata=True)
+            link_checkpoint(filepath, symlink_path, metadata=True)
             log.debug(f"Created latest symlink: {symlink_path}")
 
         # Barrier to ensure all processes have saved the checkpoint,
nshtrainer/nn/tests/test_mlp.py CHANGED
@@ -5,7 +5,7 @@ from typing import cast
 import pytest
 import torch
 
-from nshtrainer.nn.mlp import MLP
+from nshtrainer.nn.mlp import MLP, custom_seed_context
 
 
 def test_mlp_seed_reproducibility():
@@ -53,3 +53,68 @@ def test_mlp_seed_reproducibility():
     assert not torch.allclose(
         cast(torch.Tensor, mlp4[0].weight), cast(torch.Tensor, mlp5[0].weight)
     )
+
+
+def test_custom_seed_context():
+    """Test that custom_seed_context properly controls random number generation."""
+
+    # Test that the same seed produces the same random numbers
+    with custom_seed_context(42):
+        tensor1 = torch.randn(10)
+
+    with custom_seed_context(42):
+        tensor2 = torch.randn(10)
+
+    # Same seed should produce identical random tensors
+    assert torch.allclose(tensor1, tensor2)
+
+    # Test that different seeds produce different random numbers
+    with custom_seed_context(123):
+        tensor3 = torch.randn(10)
+
+    # Different seeds should produce different random tensors
+    assert not torch.allclose(tensor1, tensor3)
+
+
+def test_custom_seed_context_preserves_state():
+    """Test that custom_seed_context preserves the original random state."""
+
+    # Set a known seed for the test
+    original_seed = 789
+    torch.manual_seed(original_seed)
+
+    # Generate a tensor with the original seed
+    original_tensor = torch.randn(10)
+
+    # Use a different seed in the context
+    with custom_seed_context(42):
+        # This should use the temporary seed
+        context_tensor = torch.randn(10)
+
+    # After exiting the context, we should be back to the original seed state
+    # Reset the generator to get the same sequence again
+    torch.manual_seed(original_seed)
+    expected_tensor = torch.randn(10)
+
+    # The tensor generated after the context should match what we would get
+    # if we had just set the original seed again
+    assert torch.allclose(original_tensor, expected_tensor)
+
+    # And it should be different from the tensor generated inside the context
+    assert not torch.allclose(original_tensor, context_tensor)
+
+
+def test_custom_seed_context_with_none():
+    """Test that custom_seed_context with None seed doesn't affect randomization."""
+
+    # Set a known seed
+    torch.manual_seed(555)
+    expected_tensor = torch.randn(10)
+
+    # Reset and use None seed in context
+    torch.manual_seed(555)
+    with custom_seed_context(None):
+        actual_tensor = torch.randn(10)
+
+    # With None seed, the context should not affect the random state
+    assert torch.allclose(expected_tensor, actual_tensor)
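The new tests pin down the contract of custom_seed_context, which is imported from nshtrainer.nn.mlp but whose implementation is not shown in this diff: an integer seed makes draws inside the block reproducible, and None leaves the global RNG state alone. A minimal context manager satisfying these tests might look like the following sketch; this is an assumption for illustration, not nshtrainer's actual code:

import contextlib

import torch


@contextlib.contextmanager
def custom_seed_context(seed: int | None):
    if seed is None:
        # No seed given: leave the surrounding RNG state untouched.
        yield
        return
    state = torch.random.get_rng_state()  # save the global CPU RNG state
    try:
        torch.manual_seed(seed)  # deterministic draws inside the block
        yield
    finally:
        torch.random.set_rng_state(state)  # restore the caller's RNG state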
nshtrainer/util/path.py CHANGED
@@ -81,6 +81,7 @@ def compute_file_checksum(file_path: Path) -> str:
 def try_symlink_or_copy(
     file_path: Path,
     link_path: Path,
+    *,
     target_is_directory: bool = False,
     relative: bool = True,
     remove_existing: bool = True,
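The bare * makes target_is_directory, relative, and remove_existing keyword-only, so a call sketched like the one below keeps working while positional use of those flags now raises TypeError. The paths here are illustrative:

from pathlib import Path

from nshtrainer.util.path import try_symlink_or_copy

# Hypothetical call site: point "last.ckpt" at a freshly written checkpoint.
try_symlink_or_copy(
    Path("checkpoints/epoch=3-step=400.ckpt"),
    Path("checkpoints/last.ckpt"),
    remove_existing=True,  # optional flags must now be passed by keyword
)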
@@ -92,7 +93,10 @@
     # If the link already exists, remove it
     if remove_existing:
         try:
-            if link_path.exists():
+            if link_path.exists(follow_symlinks=False):
+                # follow_symlinks=False is EXTREMELY important here
+                # Otherwise, we've already deleted the file that the symlink
+                # used to point to, so this always returns False
                 if link_path.is_dir():
                     shutil.rmtree(link_path)
                 else:
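The second change is the actual fix: the old link is now checked with follow_symlinks=False, a Path.exists() parameter available since Python 3.12. When the file a symlink points to has already been deleted, a plain exists() follows the dead link and returns False, so the stale symlink would never be cleaned up; checking the link itself avoids that. A small standalone demonstration, illustrative only (os.path.lexists is the pre-3.12 equivalent):

import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp, "epoch=3.ckpt")
    link = Path(tmp, "last.ckpt")

    target.write_text("fake checkpoint")
    link.symlink_to(target)
    target.unlink()  # the link is now dangling

    print(link.exists())                       # False: follows the dead link
    print(link.exists(follow_symlinks=False))  # True: the link itself exists
    print(os.path.lexists(link))               # True: same check on older Pythons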
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.0.0b41
+Version: 1.0.0b42
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -1,8 +1,8 @@
 nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
 nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
 nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
-nshtrainer/_checkpoint/metadata.py,sha256=PHy-54Cg-o3OtCffAqrVv6ZVMU7zhRo_-sZiSEEno1Y,5019
-nshtrainer/_checkpoint/saver.py,sha256=LOP8jjKF0Dw9x9H-BKrLMWlEp1XTan2DUK0zQUCWw5U,1360
+nshtrainer/_checkpoint/metadata.py,sha256=Kp1BK1iHCsOzjjy_B0Rbs0yIyd_yV9gZl4uym7rNQ_E,4796
+nshtrainer/_checkpoint/saver.py,sha256=qG8dEetyqwLJS0MkdWDz4DCAInaSyafW8-Oe8x0UVnE,1353
 nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
@@ -10,7 +10,7 @@ nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTw
 nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
 nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
-nshtrainer/callbacks/checkpoint/_base.py,sha256=ZVEUVl5kjCSSe69Q0rMUbKBNNUog0pxBwWkeyuxG2w0,6304
+nshtrainer/callbacks/checkpoint/_base.py,sha256=wCJBRI0pQYZc3GBu0b-aUBlBDhd39AdL82VvFgKmv3k,6300
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
@@ -122,7 +122,7 @@ nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101
 nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
 nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
 nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
-nshtrainer/nn/tests/test_mlp.py,sha256=xBPiHlBvOCn67EbpzzKL-2FU7ikGxHT3i6CMSp1wk7M,1840
+nshtrainer/nn/tests/test_mlp.py,sha256=COw88Oc4FMvj_mGdQf6F2UtgZ39FE2lbQjTmMAwtCWE,4031
 nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
 nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
 nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
@@ -148,11 +148,11 @@ nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDW
 nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
 nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
 nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
-nshtrainer/util/path.py,sha256=L-Nh9tlXSUfoP19TFbQq8I0AfS5ugCfGYTYFeddDHcs,3516
+nshtrainer/util/path.py,sha256=IL8MS29MzP2aBdQhqMl2Eh4RvyuZbzaIbAVZzsdLDSk,3754
 nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.0b41.dist-info/METADATA,sha256=DL9HgN6RP8X8v0sCdTr2IjRSwIBY96NZXe15m5V4y4c,988
-nshtrainer-1.0.0b41.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
-nshtrainer-1.0.0b41.dist-info/RECORD,,
+nshtrainer-1.0.0b42.dist-info/METADATA,sha256=m8vNeWow5AfufFEd6uCjm6w-0NOCTTzCFjXDed0eq6U,988
+nshtrainer-1.0.0b42.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+nshtrainer-1.0.0b42.dist-info/RECORD,,