nshtrainer 1.0.0b41__py3-none-any.whl → 1.0.0b43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/metadata.py +7 -16
- nshtrainer/_checkpoint/saver.py +10 -7
- nshtrainer/callbacks/checkpoint/_base.py +3 -3
- nshtrainer/trainer/trainer.py +2 -2
- nshtrainer/util/path.py +10 -1
- {nshtrainer-1.0.0b41.dist-info → nshtrainer-1.0.0b43.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b41.dist-info → nshtrainer-1.0.0b43.dist-info}/RECORD +8 -9
- nshtrainer/nn/tests/test_mlp.py +0 -55
- {nshtrainer-1.0.0b41.dist-info → nshtrainer-1.0.0b43.dist-info}/WHEEL +0 -0
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|
3
3
|
import copy
|
4
4
|
import datetime
|
5
5
|
import logging
|
6
|
-
from collections.abc import Callable
|
7
6
|
from pathlib import Path
|
8
7
|
from typing import TYPE_CHECKING, Any, ClassVar
|
9
8
|
|
@@ -115,7 +114,7 @@ def _metadata_path(checkpoint_path: Path):
|
|
115
114
|
return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
|
116
115
|
|
117
116
|
|
118
|
-
def
|
117
|
+
def write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
|
119
118
|
metadata_path = _metadata_path(checkpoint_path)
|
120
119
|
metadata = _generate_checkpoint_metadata(trainer, checkpoint_path, metadata_path)
|
121
120
|
|
@@ -130,7 +129,7 @@ def _write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
|
|
130
129
|
return metadata_path
|
131
130
|
|
132
131
|
|
133
|
-
def
|
132
|
+
def remove_checkpoint_metadata(checkpoint_path: Path):
|
134
133
|
path = _metadata_path(checkpoint_path)
|
135
134
|
try:
|
136
135
|
path.unlink(missing_ok=True)
|
@@ -140,23 +139,15 @@ def _remove_checkpoint_metadata(checkpoint_path: Path):
|
|
140
139
|
log.debug(f"Removed {path}")
|
141
140
|
|
142
141
|
|
143
|
-
def
|
142
|
+
def link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
|
144
143
|
# First, remove any existing metadata files
|
145
|
-
|
144
|
+
remove_checkpoint_metadata(linked_checkpoint_path)
|
146
145
|
|
147
146
|
# Link the metadata files to the new checkpoint
|
148
147
|
path = _metadata_path(checkpoint_path)
|
149
148
|
linked_path = _metadata_path(linked_checkpoint_path)
|
150
|
-
try_symlink_or_copy(path, linked_path)
|
151
149
|
|
150
|
+
if not path.exists():
|
151
|
+
raise FileNotFoundError(f"Checkpoint path does not exist: {checkpoint_path}")
|
152
152
|
|
153
|
-
|
154
|
-
checkpoint_paths: list[Path],
|
155
|
-
key: Callable[[CheckpointMetadata, Path], Any],
|
156
|
-
reverse: bool = False,
|
157
|
-
):
|
158
|
-
return sorted(
|
159
|
-
[(CheckpointMetadata.from_ckpt_path(path), path) for path in checkpoint_paths],
|
160
|
-
key=lambda args_tuple: key(*args_tuple),
|
161
|
-
reverse=reverse,
|
162
|
-
)
|
153
|
+
try_symlink_or_copy(path, linked_path)
|
nshtrainer/_checkpoint/saver.py
CHANGED
@@ -8,12 +8,12 @@ from pathlib import Path
|
|
8
8
|
from lightning.pytorch import Trainer
|
9
9
|
|
10
10
|
from ..util.path import try_symlink_or_copy
|
11
|
-
from .metadata import
|
11
|
+
from .metadata import link_checkpoint_metadata, remove_checkpoint_metadata
|
12
12
|
|
13
13
|
log = logging.getLogger(__name__)
|
14
14
|
|
15
15
|
|
16
|
-
def
|
16
|
+
def link_checkpoint(
|
17
17
|
filepath: str | Path | os.PathLike,
|
18
18
|
linkpath: str | Path | os.PathLike,
|
19
19
|
*,
|
@@ -25,7 +25,10 @@ def _link_checkpoint(
|
|
25
25
|
|
26
26
|
if remove_existing:
|
27
27
|
try:
|
28
|
-
if linkpath.exists():
|
28
|
+
if linkpath.exists(follow_symlinks=False):
|
29
|
+
# follow_symlinks=False is EXTREMELY important here
|
30
|
+
# Otherwise, we've already deleted the file that the symlink
|
31
|
+
# used to point to, so this always returns False
|
29
32
|
if linkpath.is_dir():
|
30
33
|
shutil.rmtree(linkpath)
|
31
34
|
else:
|
@@ -36,14 +39,14 @@ def _link_checkpoint(
|
|
36
39
|
log.debug(f"Removed {linkpath=}")
|
37
40
|
|
38
41
|
if metadata:
|
39
|
-
|
42
|
+
remove_checkpoint_metadata(linkpath)
|
40
43
|
|
41
44
|
try_symlink_or_copy(filepath, linkpath)
|
42
45
|
if metadata:
|
43
|
-
|
46
|
+
link_checkpoint_metadata(filepath, linkpath)
|
44
47
|
|
45
48
|
|
46
|
-
def
|
49
|
+
def remove_checkpoint(
|
47
50
|
trainer: Trainer,
|
48
51
|
filepath: str | Path | os.PathLike,
|
49
52
|
*,
|
@@ -54,4 +57,4 @@ def _remove_checkpoint(
|
|
54
57
|
trainer.strategy.remove_checkpoint(filepath)
|
55
58
|
|
56
59
|
if metadata:
|
57
|
-
|
60
|
+
remove_checkpoint_metadata(filepath)
|
@@ -12,7 +12,7 @@ from lightning.pytorch.callbacks import Checkpoint
|
|
12
12
|
from typing_extensions import TypeVar, override
|
13
13
|
|
14
14
|
from ..._checkpoint.metadata import CheckpointMetadata
|
15
|
-
from ..._checkpoint.saver import
|
15
|
+
from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
|
16
16
|
from ..base import CallbackConfigBase
|
17
17
|
|
18
18
|
if TYPE_CHECKING:
|
@@ -122,7 +122,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
122
122
|
)
|
123
123
|
continue
|
124
124
|
|
125
|
-
|
125
|
+
remove_checkpoint(trainer, old_ckpt_path, metadata=True)
|
126
126
|
log.debug(f"Removed old checkpoint: {old_ckpt_path}")
|
127
127
|
|
128
128
|
def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
|
@@ -167,7 +167,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
167
167
|
# Create the latest symlink
|
168
168
|
if (symlink_filename := self.symlink_path()) is not None:
|
169
169
|
symlink_path = self.dirpath / symlink_filename
|
170
|
-
|
170
|
+
link_checkpoint(filepath, symlink_path, metadata=True)
|
171
171
|
log.debug(f"Created latest symlink: {symlink_path}")
|
172
172
|
|
173
173
|
# Barrier to ensure all processes have saved the checkpoint,
|
nshtrainer/trainer/trainer.py
CHANGED
@@ -18,7 +18,7 @@ from lightning.pytorch.trainer.states import TrainerFn
|
|
18
18
|
from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
|
19
19
|
from typing_extensions import Never, Unpack, assert_never, deprecated, override
|
20
20
|
|
21
|
-
from .._checkpoint.metadata import
|
21
|
+
from .._checkpoint.metadata import write_checkpoint_metadata
|
22
22
|
from ..callbacks.base import resolve_all_callbacks
|
23
23
|
from ..util._environment_info import EnvironmentConfig
|
24
24
|
from ..util.bf16 import is_bf16_supported_no_emulation
|
@@ -478,7 +478,7 @@ class Trainer(LightningTrainer):
|
|
478
478
|
metadata_path = None
|
479
479
|
if self.hparams.save_checkpoint_metadata and self.is_global_zero:
|
480
480
|
# Generate the metadata and write to disk
|
481
|
-
metadata_path =
|
481
|
+
metadata_path = write_checkpoint_metadata(self, filepath)
|
482
482
|
|
483
483
|
# Call the `on_checkpoint_saved` method on all callbacks
|
484
484
|
from .. import _callback
|
nshtrainer/util/path.py
CHANGED
@@ -81,18 +81,27 @@ def compute_file_checksum(file_path: Path) -> str:
|
|
81
81
|
def try_symlink_or_copy(
|
82
82
|
file_path: Path,
|
83
83
|
link_path: Path,
|
84
|
+
*,
|
84
85
|
target_is_directory: bool = False,
|
85
86
|
relative: bool = True,
|
86
87
|
remove_existing: bool = True,
|
88
|
+
throw_on_invalid_target: bool = False,
|
87
89
|
):
|
88
90
|
"""
|
89
91
|
Symlinks on Unix, copies on Windows.
|
90
92
|
"""
|
91
93
|
|
94
|
+
# Check if the target file exists
|
95
|
+
if throw_on_invalid_target and not file_path.exists():
|
96
|
+
raise FileNotFoundError(f"File not found: {file_path}")
|
97
|
+
|
92
98
|
# If the link already exists, remove it
|
93
99
|
if remove_existing:
|
94
100
|
try:
|
95
|
-
if link_path.exists():
|
101
|
+
if link_path.exists(follow_symlinks=False):
|
102
|
+
# follow_symlinks=False is EXTREMELY important here
|
103
|
+
# Otherwise, we've already deleted the file that the symlink
|
104
|
+
# used to point to, so this always returns False
|
96
105
|
if link_path.is_dir():
|
97
106
|
shutil.rmtree(link_path)
|
98
107
|
else:
|
@@ -1,8 +1,8 @@
|
|
1
1
|
nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
|
2
2
|
nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
|
3
3
|
nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
|
4
|
-
nshtrainer/_checkpoint/metadata.py,sha256=
|
5
|
-
nshtrainer/_checkpoint/saver.py,sha256=
|
4
|
+
nshtrainer/_checkpoint/metadata.py,sha256=LQZ8g50rKxQQx-FqiW3n8EWmal9qSWRouOpIIn6NJJY,4758
|
5
|
+
nshtrainer/_checkpoint/saver.py,sha256=rWl4d2lCTMU4_wt8yZFL2pFQaP9hj5sPgqHMPQ4zuyI,1584
|
6
6
|
nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
|
7
7
|
nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
|
8
8
|
nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
|
@@ -10,7 +10,7 @@ nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTw
|
|
10
10
|
nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
|
11
11
|
nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
|
13
|
-
nshtrainer/callbacks/checkpoint/_base.py,sha256=
|
13
|
+
nshtrainer/callbacks/checkpoint/_base.py,sha256=wCJBRI0pQYZc3GBu0b-aUBlBDhd39AdL82VvFgKmv3k,6300
|
14
14
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
|
15
15
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
|
16
16
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
|
@@ -122,7 +122,6 @@ nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101
|
|
122
122
|
nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
|
123
123
|
nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
|
124
124
|
nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
|
125
|
-
nshtrainer/nn/tests/test_mlp.py,sha256=xBPiHlBvOCn67EbpzzKL-2FU7ikGxHT3i6CMSp1wk7M,1840
|
126
125
|
nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
|
127
126
|
nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
|
128
127
|
nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
|
@@ -141,18 +140,18 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
|
|
141
140
|
nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
|
142
141
|
nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
|
143
142
|
nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
|
144
|
-
nshtrainer/trainer/trainer.py,sha256=
|
143
|
+
nshtrainer/trainer/trainer.py,sha256=ed_Pn-yQCb9BqaHXo2wVhkt2CSfGNEzMAM6RsDoTo-I,20834
|
145
144
|
nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
|
146
145
|
nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
|
147
146
|
nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
|
148
147
|
nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
|
149
148
|
nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
|
150
149
|
nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
|
151
|
-
nshtrainer/util/path.py,sha256=
|
150
|
+
nshtrainer/util/path.py,sha256=9fIjE3S78pPL6wjAgEJUYfIJQAPdKOQqIYvTS9lWTUk,3959
|
152
151
|
nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
153
152
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
154
153
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
155
154
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
156
|
-
nshtrainer-1.0.
|
157
|
-
nshtrainer-1.0.
|
158
|
-
nshtrainer-1.0.
|
155
|
+
nshtrainer-1.0.0b43.dist-info/METADATA,sha256=ZE3l6CN34ptFgx3SDPfKIgjdV2s3J8qdP729eb58vzo,988
|
156
|
+
nshtrainer-1.0.0b43.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
|
157
|
+
nshtrainer-1.0.0b43.dist-info/RECORD,,
|
nshtrainer/nn/tests/test_mlp.py
DELETED
@@ -1,55 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from typing import cast
|
4
|
-
|
5
|
-
import pytest
|
6
|
-
import torch
|
7
|
-
|
8
|
-
from nshtrainer.nn.mlp import MLP
|
9
|
-
|
10
|
-
|
11
|
-
def test_mlp_seed_reproducibility():
|
12
|
-
"""Test that the seed parameter in MLP ensures reproducible weights."""
|
13
|
-
|
14
|
-
# Test dimensions
|
15
|
-
dims = [10, 20, 5]
|
16
|
-
|
17
|
-
# Create two MLPs with the same seed
|
18
|
-
seed1 = 42
|
19
|
-
mlp1 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
|
20
|
-
mlp2 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
|
21
|
-
|
22
|
-
# Create an MLP with a different seed
|
23
|
-
seed2 = 123
|
24
|
-
mlp3 = MLP(dims, activation=torch.nn.ReLU(), seed=seed2)
|
25
|
-
|
26
|
-
# Check first layer weights
|
27
|
-
layer1_weights1 = cast(torch.Tensor, mlp1[0].weight)
|
28
|
-
layer1_weights2 = cast(torch.Tensor, mlp2[0].weight)
|
29
|
-
layer1_weights3 = cast(torch.Tensor, mlp3[0].weight)
|
30
|
-
|
31
|
-
# Same seed should produce identical weights
|
32
|
-
assert torch.allclose(layer1_weights1, layer1_weights2)
|
33
|
-
|
34
|
-
# Different seeds should produce different weights
|
35
|
-
assert not torch.allclose(layer1_weights1, layer1_weights3)
|
36
|
-
|
37
|
-
# Check second layer weights
|
38
|
-
layer2_weights1 = cast(torch.Tensor, mlp1[2].weight)
|
39
|
-
layer2_weights2 = cast(torch.Tensor, mlp2[2].weight)
|
40
|
-
layer2_weights3 = cast(torch.Tensor, mlp3[2].weight)
|
41
|
-
|
42
|
-
# Same seed should produce identical weights for all layers
|
43
|
-
assert torch.allclose(layer2_weights1, layer2_weights2)
|
44
|
-
|
45
|
-
# Different seeds should produce different weights for all layers
|
46
|
-
assert not torch.allclose(layer2_weights1, layer2_weights3)
|
47
|
-
|
48
|
-
# Test that not providing a seed gives different results each time
|
49
|
-
mlp4 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
|
50
|
-
mlp5 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
|
51
|
-
|
52
|
-
# Without seeds, weights should be different
|
53
|
-
assert not torch.allclose(
|
54
|
-
cast(torch.Tensor, mlp4[0].weight), cast(torch.Tensor, mlp5[0].weight)
|
55
|
-
)
|
File without changes
|