nshtrainer 1.0.0b40__py3-none-any.whl → 1.0.0b42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/metadata.py +6 -14
- nshtrainer/_checkpoint/saver.py +6 -6
- nshtrainer/callbacks/checkpoint/_base.py +3 -3
- nshtrainer/nn/__init__.py +1 -0
- nshtrainer/nn/mlp.py +13 -7
- nshtrainer/nn/tests/test_mlp.py +66 -1
- nshtrainer/util/path.py +5 -1
- {nshtrainer-1.0.0b40.dist-info → nshtrainer-1.0.0b42.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b40.dist-info → nshtrainer-1.0.0b42.dist-info}/RECORD +10 -10
- {nshtrainer-1.0.0b40.dist-info → nshtrainer-1.0.0b42.dist-info}/WHEEL +0 -0
@@ -130,7 +130,7 @@ def _write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
|
|
130
130
|
return metadata_path
|
131
131
|
|
132
132
|
|
133
|
-
def
|
133
|
+
def remove_checkpoint_metadata(checkpoint_path: Path):
|
134
134
|
path = _metadata_path(checkpoint_path)
|
135
135
|
try:
|
136
136
|
path.unlink(missing_ok=True)
|
@@ -140,23 +140,15 @@ def _remove_checkpoint_metadata(checkpoint_path: Path):
|
|
140
140
|
log.debug(f"Removed {path}")
|
141
141
|
|
142
142
|
|
143
|
-
def
|
143
|
+
def link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
|
144
144
|
# First, remove any existing metadata files
|
145
|
-
|
145
|
+
remove_checkpoint_metadata(linked_checkpoint_path)
|
146
146
|
|
147
147
|
# Link the metadata files to the new checkpoint
|
148
148
|
path = _metadata_path(checkpoint_path)
|
149
149
|
linked_path = _metadata_path(linked_checkpoint_path)
|
150
|
-
try_symlink_or_copy(path, linked_path)
|
151
150
|
|
151
|
+
if not path.exists():
|
152
|
+
raise FileNotFoundError(f"Checkpoint path does not exist: {checkpoint_path}")
|
152
153
|
|
153
|
-
|
154
|
-
checkpoint_paths: list[Path],
|
155
|
-
key: Callable[[CheckpointMetadata, Path], Any],
|
156
|
-
reverse: bool = False,
|
157
|
-
):
|
158
|
-
return sorted(
|
159
|
-
[(CheckpointMetadata.from_ckpt_path(path), path) for path in checkpoint_paths],
|
160
|
-
key=lambda args_tuple: key(*args_tuple),
|
161
|
-
reverse=reverse,
|
162
|
-
)
|
154
|
+
try_symlink_or_copy(path, linked_path)
|
nshtrainer/_checkpoint/saver.py
CHANGED
@@ -8,12 +8,12 @@ from pathlib import Path
|
|
8
8
|
from lightning.pytorch import Trainer
|
9
9
|
|
10
10
|
from ..util.path import try_symlink_or_copy
|
11
|
-
from .metadata import
|
11
|
+
from .metadata import link_checkpoint_metadata, remove_checkpoint_metadata
|
12
12
|
|
13
13
|
log = logging.getLogger(__name__)
|
14
14
|
|
15
15
|
|
16
|
-
def
|
16
|
+
def link_checkpoint(
|
17
17
|
filepath: str | Path | os.PathLike,
|
18
18
|
linkpath: str | Path | os.PathLike,
|
19
19
|
*,
|
@@ -36,14 +36,14 @@ def _link_checkpoint(
|
|
36
36
|
log.debug(f"Removed {linkpath=}")
|
37
37
|
|
38
38
|
if metadata:
|
39
|
-
|
39
|
+
remove_checkpoint_metadata(linkpath)
|
40
40
|
|
41
41
|
try_symlink_or_copy(filepath, linkpath)
|
42
42
|
if metadata:
|
43
|
-
|
43
|
+
link_checkpoint_metadata(filepath, linkpath)
|
44
44
|
|
45
45
|
|
46
|
-
def
|
46
|
+
def remove_checkpoint(
|
47
47
|
trainer: Trainer,
|
48
48
|
filepath: str | Path | os.PathLike,
|
49
49
|
*,
|
@@ -54,4 +54,4 @@ def _remove_checkpoint(
|
|
54
54
|
trainer.strategy.remove_checkpoint(filepath)
|
55
55
|
|
56
56
|
if metadata:
|
57
|
-
|
57
|
+
remove_checkpoint_metadata(filepath)
|
@@ -12,7 +12,7 @@ from lightning.pytorch.callbacks import Checkpoint
|
|
12
12
|
from typing_extensions import TypeVar, override
|
13
13
|
|
14
14
|
from ..._checkpoint.metadata import CheckpointMetadata
|
15
|
-
from ..._checkpoint.saver import
|
15
|
+
from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
|
16
16
|
from ..base import CallbackConfigBase
|
17
17
|
|
18
18
|
if TYPE_CHECKING:
|
@@ -122,7 +122,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
122
122
|
)
|
123
123
|
continue
|
124
124
|
|
125
|
-
|
125
|
+
remove_checkpoint(trainer, old_ckpt_path, metadata=True)
|
126
126
|
log.debug(f"Removed old checkpoint: {old_ckpt_path}")
|
127
127
|
|
128
128
|
def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
|
@@ -167,7 +167,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
167
167
|
# Create the latest symlink
|
168
168
|
if (symlink_filename := self.symlink_path()) is not None:
|
169
169
|
symlink_path = self.dirpath / symlink_filename
|
170
|
-
|
170
|
+
link_checkpoint(filepath, symlink_path, metadata=True)
|
171
171
|
log.debug(f"Created latest symlink: {symlink_path}")
|
172
172
|
|
173
173
|
# Barrier to ensure all processes have saved the checkpoint,
|
nshtrainer/nn/__init__.py
CHANGED
@@ -4,6 +4,7 @@ from .mlp import MLP as MLP
|
|
4
4
|
from .mlp import MLPConfig as MLPConfig
|
5
5
|
from .mlp import MLPConfigDict as MLPConfigDict
|
6
6
|
from .mlp import ResidualSequential as ResidualSequential
|
7
|
+
from .mlp import custom_seed_context as custom_seed_context
|
7
8
|
from .module_dict import TypedModuleDict as TypedModuleDict
|
8
9
|
from .module_list import TypedModuleList as TypedModuleList
|
9
10
|
from .nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
|
nshtrainer/nn/mlp.py
CHANGED
@@ -99,6 +99,18 @@ class MLPConfig(C.Config):
|
|
99
99
|
)
|
100
100
|
|
101
101
|
|
102
|
+
@contextlib.contextmanager
|
103
|
+
def custom_seed_context(seed: int | None):
|
104
|
+
with contextlib.ExitStack() as stack:
|
105
|
+
if seed is not None:
|
106
|
+
stack.enter_context(
|
107
|
+
torch.random.fork_rng(devices=range(torch.cuda.device_count()))
|
108
|
+
)
|
109
|
+
torch.manual_seed(seed)
|
110
|
+
|
111
|
+
yield
|
112
|
+
|
113
|
+
|
102
114
|
def MLP(
|
103
115
|
dims: Sequence[int],
|
104
116
|
activation: NonlinearityConfigBase
|
@@ -140,13 +152,7 @@ def MLP(
|
|
140
152
|
nn.Sequential: The constructed MLP.
|
141
153
|
"""
|
142
154
|
|
143
|
-
with
|
144
|
-
if seed is not None:
|
145
|
-
stack.enter_context(
|
146
|
-
torch.random.fork_rng(devices=range(torch.cuda.device_count()))
|
147
|
-
)
|
148
|
-
torch.manual_seed(seed)
|
149
|
-
|
155
|
+
with custom_seed_context(seed):
|
150
156
|
if activation is None:
|
151
157
|
activation = nonlinearity
|
152
158
|
|
nshtrainer/nn/tests/test_mlp.py
CHANGED
@@ -5,7 +5,7 @@ from typing import cast
|
|
5
5
|
import pytest
|
6
6
|
import torch
|
7
7
|
|
8
|
-
from nshtrainer.nn.mlp import MLP
|
8
|
+
from nshtrainer.nn.mlp import MLP, custom_seed_context
|
9
9
|
|
10
10
|
|
11
11
|
def test_mlp_seed_reproducibility():
|
@@ -53,3 +53,68 @@ def test_mlp_seed_reproducibility():
|
|
53
53
|
assert not torch.allclose(
|
54
54
|
cast(torch.Tensor, mlp4[0].weight), cast(torch.Tensor, mlp5[0].weight)
|
55
55
|
)
|
56
|
+
|
57
|
+
|
58
|
+
def test_custom_seed_context():
|
59
|
+
"""Test that custom_seed_context properly controls random number generation."""
|
60
|
+
|
61
|
+
# Test that the same seed produces the same random numbers
|
62
|
+
with custom_seed_context(42):
|
63
|
+
tensor1 = torch.randn(10)
|
64
|
+
|
65
|
+
with custom_seed_context(42):
|
66
|
+
tensor2 = torch.randn(10)
|
67
|
+
|
68
|
+
# Same seed should produce identical random tensors
|
69
|
+
assert torch.allclose(tensor1, tensor2)
|
70
|
+
|
71
|
+
# Test that different seeds produce different random numbers
|
72
|
+
with custom_seed_context(123):
|
73
|
+
tensor3 = torch.randn(10)
|
74
|
+
|
75
|
+
# Different seeds should produce different random tensors
|
76
|
+
assert not torch.allclose(tensor1, tensor3)
|
77
|
+
|
78
|
+
|
79
|
+
def test_custom_seed_context_preserves_state():
|
80
|
+
"""Test that custom_seed_context preserves the original random state."""
|
81
|
+
|
82
|
+
# Set a known seed for the test
|
83
|
+
original_seed = 789
|
84
|
+
torch.manual_seed(original_seed)
|
85
|
+
|
86
|
+
# Generate a tensor with the original seed
|
87
|
+
original_tensor = torch.randn(10)
|
88
|
+
|
89
|
+
# Use a different seed in the context
|
90
|
+
with custom_seed_context(42):
|
91
|
+
# This should use the temporary seed
|
92
|
+
context_tensor = torch.randn(10)
|
93
|
+
|
94
|
+
# After exiting the context, we should be back to the original seed state
|
95
|
+
# Reset the generator to get the same sequence again
|
96
|
+
torch.manual_seed(original_seed)
|
97
|
+
expected_tensor = torch.randn(10)
|
98
|
+
|
99
|
+
# The tensor generated after the context should match what we would get
|
100
|
+
# if we had just set the original seed again
|
101
|
+
assert torch.allclose(original_tensor, expected_tensor)
|
102
|
+
|
103
|
+
# And it should be different from the tensor generated inside the context
|
104
|
+
assert not torch.allclose(original_tensor, context_tensor)
|
105
|
+
|
106
|
+
|
107
|
+
def test_custom_seed_context_with_none():
|
108
|
+
"""Test that custom_seed_context with None seed doesn't affect randomization."""
|
109
|
+
|
110
|
+
# Set a known seed
|
111
|
+
torch.manual_seed(555)
|
112
|
+
expected_tensor = torch.randn(10)
|
113
|
+
|
114
|
+
# Reset and use None seed in context
|
115
|
+
torch.manual_seed(555)
|
116
|
+
with custom_seed_context(None):
|
117
|
+
actual_tensor = torch.randn(10)
|
118
|
+
|
119
|
+
# With None seed, the context should not affect the random state
|
120
|
+
assert torch.allclose(expected_tensor, actual_tensor)
|
nshtrainer/util/path.py
CHANGED
@@ -81,6 +81,7 @@ def compute_file_checksum(file_path: Path) -> str:
|
|
81
81
|
def try_symlink_or_copy(
|
82
82
|
file_path: Path,
|
83
83
|
link_path: Path,
|
84
|
+
*,
|
84
85
|
target_is_directory: bool = False,
|
85
86
|
relative: bool = True,
|
86
87
|
remove_existing: bool = True,
|
@@ -92,7 +93,10 @@ def try_symlink_or_copy(
|
|
92
93
|
# If the link already exists, remove it
|
93
94
|
if remove_existing:
|
94
95
|
try:
|
95
|
-
if link_path.exists():
|
96
|
+
if link_path.exists(follow_symlinks=False):
|
97
|
+
# follow_symlinks=False is EXTREMELY important here
|
98
|
+
# Otherwise, we've already deleted the file that the symlink
|
99
|
+
# used to point to, so this always returns False
|
96
100
|
if link_path.is_dir():
|
97
101
|
shutil.rmtree(link_path)
|
98
102
|
else:
|
@@ -1,8 +1,8 @@
|
|
1
1
|
nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
|
2
2
|
nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
|
3
3
|
nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
|
4
|
-
nshtrainer/_checkpoint/metadata.py,sha256=
|
5
|
-
nshtrainer/_checkpoint/saver.py,sha256=
|
4
|
+
nshtrainer/_checkpoint/metadata.py,sha256=Kp1BK1iHCsOzjjy_B0Rbs0yIyd_yV9gZl4uym7rNQ_E,4796
|
5
|
+
nshtrainer/_checkpoint/saver.py,sha256=qG8dEetyqwLJS0MkdWDz4DCAInaSyafW8-Oe8x0UVnE,1353
|
6
6
|
nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
|
7
7
|
nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
|
8
8
|
nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
|
@@ -10,7 +10,7 @@ nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTw
|
|
10
10
|
nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
|
11
11
|
nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
|
13
|
-
nshtrainer/callbacks/checkpoint/_base.py,sha256=
|
13
|
+
nshtrainer/callbacks/checkpoint/_base.py,sha256=wCJBRI0pQYZc3GBu0b-aUBlBDhd39AdL82VvFgKmv3k,6300
|
14
14
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
|
15
15
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
|
16
16
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
|
@@ -117,12 +117,12 @@ nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,1035
|
|
117
117
|
nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
|
118
118
|
nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
|
119
119
|
nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
|
120
|
-
nshtrainer/nn/__init__.py,sha256=
|
121
|
-
nshtrainer/nn/mlp.py,sha256=
|
120
|
+
nshtrainer/nn/__init__.py,sha256=0FgeoaLYtRiSLT8fdPigLD8t-d8DKR8IQDw16JA9lT4,1523
|
121
|
+
nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101
|
122
122
|
nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
|
123
123
|
nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
|
124
124
|
nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
|
125
|
-
nshtrainer/nn/tests/test_mlp.py,sha256=
|
125
|
+
nshtrainer/nn/tests/test_mlp.py,sha256=COw88Oc4FMvj_mGdQf6F2UtgZ39FE2lbQjTmMAwtCWE,4031
|
126
126
|
nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
|
127
127
|
nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
|
128
128
|
nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
|
@@ -148,11 +148,11 @@ nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDW
|
|
148
148
|
nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
|
149
149
|
nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
|
150
150
|
nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
|
151
|
-
nshtrainer/util/path.py,sha256=
|
151
|
+
nshtrainer/util/path.py,sha256=IL8MS29MzP2aBdQhqMl2Eh4RvyuZbzaIbAVZzsdLDSk,3754
|
152
152
|
nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
153
153
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
154
154
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
155
155
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
156
|
-
nshtrainer-1.0.
|
157
|
-
nshtrainer-1.0.
|
158
|
-
nshtrainer-1.0.
|
156
|
+
nshtrainer-1.0.0b42.dist-info/METADATA,sha256=m8vNeWow5AfufFEd6uCjm6w-0NOCTTzCFjXDed0eq6U,988
|
157
|
+
nshtrainer-1.0.0b42.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
|
158
|
+
nshtrainer-1.0.0b42.dist-info/RECORD,,
|
File without changes
|