nshtrainer 1.0.0b41__py3-none-any.whl → 1.0.0b42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/metadata.py +6 -14
- nshtrainer/_checkpoint/saver.py +6 -6
- nshtrainer/callbacks/checkpoint/_base.py +3 -3
- nshtrainer/nn/tests/test_mlp.py +66 -1
- nshtrainer/util/path.py +5 -1
- {nshtrainer-1.0.0b41.dist-info → nshtrainer-1.0.0b42.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b41.dist-info → nshtrainer-1.0.0b42.dist-info}/RECORD +8 -8
- {nshtrainer-1.0.0b41.dist-info → nshtrainer-1.0.0b42.dist-info}/WHEEL +0 -0
@@ -130,7 +130,7 @@ def _write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
|
|
130
130
|
return metadata_path
|
131
131
|
|
132
132
|
|
133
|
-
def _remove_checkpoint_metadata(checkpoint_path: Path):
|
133
|
+
def remove_checkpoint_metadata(checkpoint_path: Path):
|
134
134
|
path = _metadata_path(checkpoint_path)
|
135
135
|
try:
|
136
136
|
path.unlink(missing_ok=True)
|
@@ -140,23 +140,15 @@ def _remove_checkpoint_metadata(checkpoint_path: Path):
|
|
140
140
|
log.debug(f"Removed {path}")
|
141
141
|
|
142
142
|
|
143
|
-
def
|
143
|
+
def link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
|
144
144
|
# First, remove any existing metadata files
|
145
|
-
|
145
|
+
remove_checkpoint_metadata(linked_checkpoint_path)
|
146
146
|
|
147
147
|
# Link the metadata files to the new checkpoint
|
148
148
|
path = _metadata_path(checkpoint_path)
|
149
149
|
linked_path = _metadata_path(linked_checkpoint_path)
|
150
|
-
try_symlink_or_copy(path, linked_path)
|
151
150
|
|
151
|
+
if not path.exists():
|
152
|
+
raise FileNotFoundError(f"Checkpoint path does not exist: {checkpoint_path}")
|
152
153
|
|
153
|
-
|
154
|
-
checkpoint_paths: list[Path],
|
155
|
-
key: Callable[[CheckpointMetadata, Path], Any],
|
156
|
-
reverse: bool = False,
|
157
|
-
):
|
158
|
-
return sorted(
|
159
|
-
[(CheckpointMetadata.from_ckpt_path(path), path) for path in checkpoint_paths],
|
160
|
-
key=lambda args_tuple: key(*args_tuple),
|
161
|
-
reverse=reverse,
|
162
|
-
)
|
154
|
+
try_symlink_or_copy(path, linked_path)
|
nshtrainer/_checkpoint/saver.py
CHANGED
@@ -8,12 +8,12 @@ from pathlib import Path
|
|
8
8
|
from lightning.pytorch import Trainer
|
9
9
|
|
10
10
|
from ..util.path import try_symlink_or_copy
|
11
|
-
from .metadata import
|
11
|
+
from .metadata import link_checkpoint_metadata, remove_checkpoint_metadata
|
12
12
|
|
13
13
|
log = logging.getLogger(__name__)
|
14
14
|
|
15
15
|
|
16
|
-
def _link_checkpoint(
|
16
|
+
def link_checkpoint(
|
17
17
|
filepath: str | Path | os.PathLike,
|
18
18
|
linkpath: str | Path | os.PathLike,
|
19
19
|
*,
|
@@ -36,14 +36,14 @@ def _link_checkpoint(
|
|
36
36
|
log.debug(f"Removed {linkpath=}")
|
37
37
|
|
38
38
|
if metadata:
|
39
|
-
|
39
|
+
remove_checkpoint_metadata(linkpath)
|
40
40
|
|
41
41
|
try_symlink_or_copy(filepath, linkpath)
|
42
42
|
if metadata:
|
43
|
-
|
43
|
+
link_checkpoint_metadata(filepath, linkpath)
|
44
44
|
|
45
45
|
|
46
|
-
def _remove_checkpoint(
|
46
|
+
def remove_checkpoint(
|
47
47
|
trainer: Trainer,
|
48
48
|
filepath: str | Path | os.PathLike,
|
49
49
|
*,
|
@@ -54,4 +54,4 @@ def _remove_checkpoint(
|
|
54
54
|
trainer.strategy.remove_checkpoint(filepath)
|
55
55
|
|
56
56
|
if metadata:
|
57
|
-
|
57
|
+
remove_checkpoint_metadata(filepath)
|
@@ -12,7 +12,7 @@ from lightning.pytorch.callbacks import Checkpoint
|
|
12
12
|
from typing_extensions import TypeVar, override
|
13
13
|
|
14
14
|
from ..._checkpoint.metadata import CheckpointMetadata
|
15
|
-
from ..._checkpoint.saver import
|
15
|
+
from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
|
16
16
|
from ..base import CallbackConfigBase
|
17
17
|
|
18
18
|
if TYPE_CHECKING:
|
@@ -122,7 +122,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
122
122
|
)
|
123
123
|
continue
|
124
124
|
|
125
|
-
|
125
|
+
remove_checkpoint(trainer, old_ckpt_path, metadata=True)
|
126
126
|
log.debug(f"Removed old checkpoint: {old_ckpt_path}")
|
127
127
|
|
128
128
|
def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
|
@@ -167,7 +167,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
167
167
|
# Create the latest symlink
|
168
168
|
if (symlink_filename := self.symlink_path()) is not None:
|
169
169
|
symlink_path = self.dirpath / symlink_filename
|
170
|
-
|
170
|
+
link_checkpoint(filepath, symlink_path, metadata=True)
|
171
171
|
log.debug(f"Created latest symlink: {symlink_path}")
|
172
172
|
|
173
173
|
# Barrier to ensure all processes have saved the checkpoint,
|
nshtrainer/nn/tests/test_mlp.py
CHANGED
@@ -5,7 +5,7 @@ from typing import cast
|
|
5
5
|
import pytest
|
6
6
|
import torch
|
7
7
|
|
8
|
-
from nshtrainer.nn.mlp import MLP
|
8
|
+
from nshtrainer.nn.mlp import MLP, custom_seed_context
|
9
9
|
|
10
10
|
|
11
11
|
def test_mlp_seed_reproducibility():
|
@@ -53,3 +53,68 @@ def test_mlp_seed_reproducibility():
|
|
53
53
|
assert not torch.allclose(
|
54
54
|
cast(torch.Tensor, mlp4[0].weight), cast(torch.Tensor, mlp5[0].weight)
|
55
55
|
)
|
56
|
+
|
57
|
+
|
58
|
+
def test_custom_seed_context():
|
59
|
+
"""Test that custom_seed_context properly controls random number generation."""
|
60
|
+
|
61
|
+
# Test that the same seed produces the same random numbers
|
62
|
+
with custom_seed_context(42):
|
63
|
+
tensor1 = torch.randn(10)
|
64
|
+
|
65
|
+
with custom_seed_context(42):
|
66
|
+
tensor2 = torch.randn(10)
|
67
|
+
|
68
|
+
# Same seed should produce identical random tensors
|
69
|
+
assert torch.allclose(tensor1, tensor2)
|
70
|
+
|
71
|
+
# Test that different seeds produce different random numbers
|
72
|
+
with custom_seed_context(123):
|
73
|
+
tensor3 = torch.randn(10)
|
74
|
+
|
75
|
+
# Different seeds should produce different random tensors
|
76
|
+
assert not torch.allclose(tensor1, tensor3)
|
77
|
+
|
78
|
+
|
79
|
+
def test_custom_seed_context_preserves_state():
|
80
|
+
"""Test that custom_seed_context preserves the original random state."""
|
81
|
+
|
82
|
+
# Set a known seed for the test
|
83
|
+
original_seed = 789
|
84
|
+
torch.manual_seed(original_seed)
|
85
|
+
|
86
|
+
# Generate a tensor with the original seed
|
87
|
+
original_tensor = torch.randn(10)
|
88
|
+
|
89
|
+
# Use a different seed in the context
|
90
|
+
with custom_seed_context(42):
|
91
|
+
# This should use the temporary seed
|
92
|
+
context_tensor = torch.randn(10)
|
93
|
+
|
94
|
+
# After exiting the context, we should be back to the original seed state
|
95
|
+
# Reset the generator to get the same sequence again
|
96
|
+
torch.manual_seed(original_seed)
|
97
|
+
expected_tensor = torch.randn(10)
|
98
|
+
|
99
|
+
# The tensor generated after the context should match what we would get
|
100
|
+
# if we had just set the original seed again
|
101
|
+
assert torch.allclose(original_tensor, expected_tensor)
|
102
|
+
|
103
|
+
# And it should be different from the tensor generated inside the context
|
104
|
+
assert not torch.allclose(original_tensor, context_tensor)
|
105
|
+
|
106
|
+
|
107
|
+
def test_custom_seed_context_with_none():
|
108
|
+
"""Test that custom_seed_context with None seed doesn't affect randomization."""
|
109
|
+
|
110
|
+
# Set a known seed
|
111
|
+
torch.manual_seed(555)
|
112
|
+
expected_tensor = torch.randn(10)
|
113
|
+
|
114
|
+
# Reset and use None seed in context
|
115
|
+
torch.manual_seed(555)
|
116
|
+
with custom_seed_context(None):
|
117
|
+
actual_tensor = torch.randn(10)
|
118
|
+
|
119
|
+
# With None seed, the context should not affect the random state
|
120
|
+
assert torch.allclose(expected_tensor, actual_tensor)
|
nshtrainer/util/path.py
CHANGED
@@ -81,6 +81,7 @@ def compute_file_checksum(file_path: Path) -> str:
|
|
81
81
|
def try_symlink_or_copy(
|
82
82
|
file_path: Path,
|
83
83
|
link_path: Path,
|
84
|
+
*,
|
84
85
|
target_is_directory: bool = False,
|
85
86
|
relative: bool = True,
|
86
87
|
remove_existing: bool = True,
|
@@ -92,7 +93,10 @@ def try_symlink_or_copy(
|
|
92
93
|
# If the link already exists, remove it
|
93
94
|
if remove_existing:
|
94
95
|
try:
|
95
|
-
if link_path.exists():
|
96
|
+
if link_path.exists(follow_symlinks=False):
|
97
|
+
# follow_symlinks=False is EXTREMELY important here
|
98
|
+
# Otherwise, we've already deleted the file that the symlink
|
99
|
+
# used to point to, so this always returns False
|
96
100
|
if link_path.is_dir():
|
97
101
|
shutil.rmtree(link_path)
|
98
102
|
else:
|
@@ -1,8 +1,8 @@
|
|
1
1
|
nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
|
2
2
|
nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
|
3
3
|
nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
|
4
|
-
nshtrainer/_checkpoint/metadata.py,sha256=
|
5
|
-
nshtrainer/_checkpoint/saver.py,sha256=
|
4
|
+
nshtrainer/_checkpoint/metadata.py,sha256=Kp1BK1iHCsOzjjy_B0Rbs0yIyd_yV9gZl4uym7rNQ_E,4796
|
5
|
+
nshtrainer/_checkpoint/saver.py,sha256=qG8dEetyqwLJS0MkdWDz4DCAInaSyafW8-Oe8x0UVnE,1353
|
6
6
|
nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
|
7
7
|
nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
|
8
8
|
nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
|
@@ -10,7 +10,7 @@ nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTw
|
|
10
10
|
nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
|
11
11
|
nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
|
13
|
-
nshtrainer/callbacks/checkpoint/_base.py,sha256=
|
13
|
+
nshtrainer/callbacks/checkpoint/_base.py,sha256=wCJBRI0pQYZc3GBu0b-aUBlBDhd39AdL82VvFgKmv3k,6300
|
14
14
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
|
15
15
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
|
16
16
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
|
@@ -122,7 +122,7 @@ nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101
|
|
122
122
|
nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
|
123
123
|
nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
|
124
124
|
nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
|
125
|
-
nshtrainer/nn/tests/test_mlp.py,sha256=
|
125
|
+
nshtrainer/nn/tests/test_mlp.py,sha256=COw88Oc4FMvj_mGdQf6F2UtgZ39FE2lbQjTmMAwtCWE,4031
|
126
126
|
nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
|
127
127
|
nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
|
128
128
|
nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
|
@@ -148,11 +148,11 @@ nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDW
|
|
148
148
|
nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
|
149
149
|
nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
|
150
150
|
nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
|
151
|
-
nshtrainer/util/path.py,sha256=
|
151
|
+
nshtrainer/util/path.py,sha256=IL8MS29MzP2aBdQhqMl2Eh4RvyuZbzaIbAVZzsdLDSk,3754
|
152
152
|
nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
153
153
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
154
154
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
155
155
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
156
|
-
nshtrainer-1.0.
|
157
|
-
nshtrainer-1.0.
|
158
|
-
nshtrainer-1.0.
|
156
|
+
nshtrainer-1.0.0b42.dist-info/METADATA,sha256=m8vNeWow5AfufFEd6uCjm6w-0NOCTTzCFjXDed0eq6U,988
|
157
|
+
nshtrainer-1.0.0b42.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
|
158
|
+
nshtrainer-1.0.0b42.dist-info/RECORD,,
|
File without changes
|