nshtrainer 1.0.0b42__py3-none-any.whl → 1.0.0b43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/metadata.py +1 -2
- nshtrainer/_checkpoint/saver.py +4 -1
- nshtrainer/trainer/trainer.py +2 -2
- nshtrainer/util/path.py +5 -0
- {nshtrainer-1.0.0b42.dist-info → nshtrainer-1.0.0b43.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b42.dist-info → nshtrainer-1.0.0b43.dist-info}/RECORD +7 -8
- nshtrainer/nn/tests/test_mlp.py +0 -120
- {nshtrainer-1.0.0b42.dist-info → nshtrainer-1.0.0b43.dist-info}/WHEEL +0 -0
nshtrainer/_checkpoint/metadata.py
CHANGED

```diff
@@ -3,7 +3,6 @@ from __future__ import annotations
 import copy
 import datetime
 import logging
-from collections.abc import Callable
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar
 
@@ -115,7 +114,7 @@ def _metadata_path(checkpoint_path: Path):
     return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
 
 
-def
+def write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
     metadata_path = _metadata_path(checkpoint_path)
     metadata = _generate_checkpoint_metadata(trainer, checkpoint_path, metadata_path)
 
```
nshtrainer/_checkpoint/saver.py
CHANGED

```diff
@@ -25,7 +25,10 @@ def link_checkpoint(
 
     if remove_existing:
         try:
-            if linkpath.exists():
+            if linkpath.exists(follow_symlinks=False):
+                # follow_symlinks=False is EXTREMELY important here
+                # Otherwise, we've already deleted the file that the symlink
+                # used to point to, so this always returns False
                 if linkpath.is_dir():
                     shutil.rmtree(linkpath)
                 else:
```
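As context for the comment above: once a symlink's target has been deleted, a plain `Path.exists()` follows the link and returns `False`, so the stale link would never be cleaned up. A minimal standalone sketch of the pitfall, with illustrative paths (`Path.exists(follow_symlinks=...)` requires Python 3.12+; `os.path.lexists` is the pre-3.12 equivalent):

```python
import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "checkpoint.ckpt"
    link = Path(tmp) / "latest.ckpt"
    target.write_text("weights")
    link.symlink_to(target)
    target.unlink()  # the symlink now dangles

    print(link.exists())                       # False: follows the dead link
    print(link.exists(follow_symlinks=False))  # True: the link file itself exists
    print(os.path.lexists(link))               # True: equivalent on older Pythons
```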
nshtrainer/trainer/trainer.py
CHANGED

```diff
@@ -18,7 +18,7 @@ from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
 from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
-from .._checkpoint.metadata import
+from .._checkpoint.metadata import write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
 from ..util._environment_info import EnvironmentConfig
 from ..util.bf16 import is_bf16_supported_no_emulation
@@ -478,7 +478,7 @@ class Trainer(LightningTrainer):
         metadata_path = None
         if self.hparams.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
-            metadata_path =
+            metadata_path = write_checkpoint_metadata(self, filepath)
 
         # Call the `on_checkpoint_saved` method on all callbacks
         from .. import _callback
```
nshtrainer/util/path.py
CHANGED

```diff
@@ -85,11 +85,16 @@ def try_symlink_or_copy(
     target_is_directory: bool = False,
     relative: bool = True,
     remove_existing: bool = True,
+    throw_on_invalid_target: bool = False,
 ):
     """
     Symlinks on Unix, copies on Windows.
     """
 
+    # Check if the target file exists
+    if throw_on_invalid_target and not file_path.exists():
+        raise FileNotFoundError(f"File not found: {file_path}")
+
     # If the link already exists, remove it
     if remove_existing:
         try:
```
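A hypothetical call showing the new opt-in check. The positional arguments are assumptions based on the visible body (`file_path` is the source file, followed by the link path), and the default `throw_on_invalid_target=False` preserves the old silent behavior:

```python
from pathlib import Path

from nshtrainer.util.path import try_symlink_or_copy

ckpt = Path("checkpoints/epoch3.ckpt")   # illustrative paths
link = Path("checkpoints/latest.ckpt")

try:
    # With the flag set, a missing source now raises instead of
    # quietly creating a dangling symlink.
    try_symlink_or_copy(ckpt, link, throw_on_invalid_target=True)
except FileNotFoundError as exc:
    print(exc)  # File not found: checkpoints/epoch3.ckpt
```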
{nshtrainer-1.0.0b42.dist-info → nshtrainer-1.0.0b43.dist-info}/RECORD
CHANGED

```diff
@@ -1,8 +1,8 @@
 nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
 nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
 nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
-nshtrainer/_checkpoint/metadata.py,sha256=
-nshtrainer/_checkpoint/saver.py,sha256=
+nshtrainer/_checkpoint/metadata.py,sha256=LQZ8g50rKxQQx-FqiW3n8EWmal9qSWRouOpIIn6NJJY,4758
+nshtrainer/_checkpoint/saver.py,sha256=rWl4d2lCTMU4_wt8yZFL2pFQaP9hj5sPgqHMPQ4zuyI,1584
 nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
@@ -122,7 +122,6 @@ nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101
 nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
 nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
 nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
-nshtrainer/nn/tests/test_mlp.py,sha256=COw88Oc4FMvj_mGdQf6F2UtgZ39FE2lbQjTmMAwtCWE,4031
 nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
 nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
 nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
@@ -141,18 +140,18 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
 nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=ed_Pn-yQCb9BqaHXo2wVhkt2CSfGNEzMAM6RsDoTo-I,20834
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
 nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
 nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
 nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
-nshtrainer/util/path.py,sha256=
+nshtrainer/util/path.py,sha256=9fIjE3S78pPL6wjAgEJUYfIJQAPdKOQqIYvTS9lWTUk,3959
 nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.
-nshtrainer-1.0.
-nshtrainer-1.0.
+nshtrainer-1.0.0b43.dist-info/METADATA,sha256=ZE3l6CN34ptFgx3SDPfKIgjdV2s3J8qdP729eb58vzo,988
+nshtrainer-1.0.0b43.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+nshtrainer-1.0.0b43.dist-info/RECORD,,
```
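The `sha256=` and size fields above follow the standard wheel RECORD format: the urlsafe base64 encoding of the file's SHA-256 digest, without `=` padding, plus the file size in bytes. A sketch of how any row can be reproduced, using a hypothetical local path:

```python
import base64
import hashlib
from pathlib import Path

def record_entry(path: Path) -> str:
    """Build a wheel-RECORD-style row: path,sha256=<digest>,<size>."""
    data = path.read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"

print(record_entry(Path("nshtrainer/util/path.py")))
```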
nshtrainer/nn/tests/test_mlp.py
DELETED

```diff
@@ -1,120 +0,0 @@
-from __future__ import annotations
-
-from typing import cast
-
-import pytest
-import torch
-
-from nshtrainer.nn.mlp import MLP, custom_seed_context
-
-
-def test_mlp_seed_reproducibility():
-    """Test that the seed parameter in MLP ensures reproducible weights."""
-
-    # Test dimensions
-    dims = [10, 20, 5]
-
-    # Create two MLPs with the same seed
-    seed1 = 42
-    mlp1 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
-    mlp2 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
-
-    # Create an MLP with a different seed
-    seed2 = 123
-    mlp3 = MLP(dims, activation=torch.nn.ReLU(), seed=seed2)
-
-    # Check first layer weights
-    layer1_weights1 = cast(torch.Tensor, mlp1[0].weight)
-    layer1_weights2 = cast(torch.Tensor, mlp2[0].weight)
-    layer1_weights3 = cast(torch.Tensor, mlp3[0].weight)
-
-    # Same seed should produce identical weights
-    assert torch.allclose(layer1_weights1, layer1_weights2)
-
-    # Different seeds should produce different weights
-    assert not torch.allclose(layer1_weights1, layer1_weights3)
-
-    # Check second layer weights
-    layer2_weights1 = cast(torch.Tensor, mlp1[2].weight)
-    layer2_weights2 = cast(torch.Tensor, mlp2[2].weight)
-    layer2_weights3 = cast(torch.Tensor, mlp3[2].weight)
-
-    # Same seed should produce identical weights for all layers
-    assert torch.allclose(layer2_weights1, layer2_weights2)
-
-    # Different seeds should produce different weights for all layers
-    assert not torch.allclose(layer2_weights1, layer2_weights3)
-
-    # Test that not providing a seed gives different results each time
-    mlp4 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
-    mlp5 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
-
-    # Without seeds, weights should be different
-    assert not torch.allclose(
-        cast(torch.Tensor, mlp4[0].weight), cast(torch.Tensor, mlp5[0].weight)
-    )
-
-
-def test_custom_seed_context():
-    """Test that custom_seed_context properly controls random number generation."""
-
-    # Test that the same seed produces the same random numbers
-    with custom_seed_context(42):
-        tensor1 = torch.randn(10)
-
-    with custom_seed_context(42):
-        tensor2 = torch.randn(10)
-
-    # Same seed should produce identical random tensors
-    assert torch.allclose(tensor1, tensor2)
-
-    # Test that different seeds produce different random numbers
-    with custom_seed_context(123):
-        tensor3 = torch.randn(10)
-
-    # Different seeds should produce different random tensors
-    assert not torch.allclose(tensor1, tensor3)
-
-
-def test_custom_seed_context_preserves_state():
-    """Test that custom_seed_context preserves the original random state."""
-
-    # Set a known seed for the test
-    original_seed = 789
-    torch.manual_seed(original_seed)
-
-    # Generate a tensor with the original seed
-    original_tensor = torch.randn(10)
-
-    # Use a different seed in the context
-    with custom_seed_context(42):
-        # This should use the temporary seed
-        context_tensor = torch.randn(10)
-
-    # After exiting the context, we should be back to the original seed state
-    # Reset the generator to get the same sequence again
-    torch.manual_seed(original_seed)
-    expected_tensor = torch.randn(10)
-
-    # The tensor generated after the context should match what we would get
-    # if we had just set the original seed again
-    assert torch.allclose(original_tensor, expected_tensor)
-
-    # And it should be different from the tensor generated inside the context
-    assert not torch.allclose(original_tensor, context_tensor)
-
-
-def test_custom_seed_context_with_none():
-    """Test that custom_seed_context with None seed doesn't affect randomization."""
-
-    # Set a known seed
-    torch.manual_seed(555)
-    expected_tensor = torch.randn(10)
-
-    # Reset and use None seed in context
-    torch.manual_seed(555)
-    with custom_seed_context(None):
-        actual_tensor = torch.randn(10)
-
-    # With None seed, the context should not affect the random state
-    assert torch.allclose(expected_tensor, actual_tensor)
```
{nshtrainer-1.0.0b42.dist-info → nshtrainer-1.0.0b43.dist-info}/WHEEL
File without changes