nshtrainer 1.0.0b42__py3-none-any.whl → 1.0.0b43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_checkpoint/metadata.py CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
 import copy
 import datetime
 import logging
-from collections.abc import Callable
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar
 
@@ -115,7 +114,7 @@ def _metadata_path(checkpoint_path: Path):
     return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
 
 
-def _write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
+def write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
     metadata_path = _metadata_path(checkpoint_path)
     metadata = _generate_checkpoint_metadata(trainer, checkpoint_path, metadata_path)
 
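The hunk above drops the leading underscore from _write_checkpoint_metadata, promoting it to a public helper; the trainer.py hunks below update the only in-package caller. A minimal usage sketch, assuming the function returns the metadata path as the call site in trainer/trainer.py suggests (the wrapper and its argument names are hypothetical):

    from pathlib import Path

    from nshtrainer._checkpoint.metadata import write_checkpoint_metadata

    def save_metadata_sidecar(trainer, ckpt_path: Path) -> Path:
        # `trainer` is an nshtrainer Trainer that has already saved a
        # checkpoint to `ckpt_path`; the metadata file lands next to the
        # checkpoint, with CheckpointMetadata.PATH_SUFFIX as its suffix.
        return write_checkpoint_metadata(trainer, ckpt_path)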
nshtrainer/_checkpoint/saver.py CHANGED
@@ -25,7 +25,10 @@ def link_checkpoint(
 
     if remove_existing:
         try:
-            if linkpath.exists():
+            if linkpath.exists(follow_symlinks=False):
+                # follow_symlinks=False is EXTREMELY important here
+                # Otherwise, we've already deleted the file that the symlink
+                # used to point to, so this always returns False
                 if linkpath.is_dir():
                     shutil.rmtree(linkpath)
                 else:
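The comment added in this hunk is worth unpacking: Path.exists() resolves symlinks, so once the link's target has been deleted, a dangling linkpath reports exists() == False and the stale link is never cleaned up; follow_symlinks=False checks the link itself. A standalone repro sketch (note that the follow_symlinks parameter of Path.exists() requires Python 3.12; os.path.lexists is the portable equivalent on older versions):

    import os
    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "checkpoint.ckpt"
        link = Path(tmp) / "latest.ckpt"
        target.write_text("weights")
        link.symlink_to(target)

        target.unlink()  # delete the file the symlink points to

        print(link.exists())                       # False: resolves the dangling link
        print(link.exists(follow_symlinks=False))  # True: the link itself still exists
        print(os.path.lexists(link))               # True: pre-3.12 equivalent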
nshtrainer/trainer/trainer.py CHANGED
@@ -18,7 +18,7 @@ from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
 from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
-from .._checkpoint.metadata import _write_checkpoint_metadata
+from .._checkpoint.metadata import write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
 from ..util._environment_info import EnvironmentConfig
 from ..util.bf16 import is_bf16_supported_no_emulation
@@ -478,7 +478,7 @@ class Trainer(LightningTrainer):
         metadata_path = None
         if self.hparams.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
-            metadata_path = _write_checkpoint_metadata(self, filepath)
+            metadata_path = write_checkpoint_metadata(self, filepath)
 
         # Call the `on_checkpoint_saved` method on all callbacks
         from .. import _callback
nshtrainer/util/path.py CHANGED
@@ -85,11 +85,16 @@ def try_symlink_or_copy(
     target_is_directory: bool = False,
     relative: bool = True,
     remove_existing: bool = True,
+    throw_on_invalid_target: bool = False,
 ):
     """
     Symlinks on Unix, copies on Windows.
     """
 
+    # Check if the target file exists
+    if throw_on_invalid_target and not file_path.exists():
+        raise FileNotFoundError(f"File not found: {file_path}")
+
     # If the link already exists, remove it
     if remove_existing:
         try:
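The new throw_on_invalid_target flag lets callers fail fast when the symlink target is missing, instead of silently creating a dangling link. A hedged usage sketch (the leading positional parameters, source file then link location, are assumed from the file_path reference in the hunk; the full signature is outside the diff context):

    from pathlib import Path

    from nshtrainer.util.path import try_symlink_or_copy

    # Hypothetical paths for illustration.
    checkpoint = Path("checkpoints/epoch=4-step=500.ckpt")
    latest = Path("checkpoints/latest.ckpt")

    # Raises FileNotFoundError if the checkpoint does not exist; the default
    # (False) preserves the old best-effort behavior.
    try_symlink_or_copy(checkpoint, latest, throw_on_invalid_target=True)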
{nshtrainer-1.0.0b42.dist-info → nshtrainer-1.0.0b43.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.0.0b42
+Version: 1.0.0b43
 Summary: 
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-1.0.0b42.dist-info → nshtrainer-1.0.0b43.dist-info}/RECORD RENAMED
@@ -1,8 +1,8 @@
 nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
 nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
 nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
-nshtrainer/_checkpoint/metadata.py,sha256=Kp1BK1iHCsOzjjy_B0Rbs0yIyd_yV9gZl4uym7rNQ_E,4796
-nshtrainer/_checkpoint/saver.py,sha256=qG8dEetyqwLJS0MkdWDz4DCAInaSyafW8-Oe8x0UVnE,1353
+nshtrainer/_checkpoint/metadata.py,sha256=LQZ8g50rKxQQx-FqiW3n8EWmal9qSWRouOpIIn6NJJY,4758
+nshtrainer/_checkpoint/saver.py,sha256=rWl4d2lCTMU4_wt8yZFL2pFQaP9hj5sPgqHMPQ4zuyI,1584
 nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
@@ -122,7 +122,6 @@ nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101
 nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
 nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
 nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
-nshtrainer/nn/tests/test_mlp.py,sha256=COw88Oc4FMvj_mGdQf6F2UtgZ39FE2lbQjTmMAwtCWE,4031
 nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
 nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
 nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
@@ -141,18 +140,18 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
 nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=l2kJs27v4IHZnzxExr0zX0sVex0wukgiD2Wn_0wiGJg,20836
+nshtrainer/trainer/trainer.py,sha256=ed_Pn-yQCb9BqaHXo2wVhkt2CSfGNEzMAM6RsDoTo-I,20834
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
 nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
 nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
 nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
-nshtrainer/util/path.py,sha256=IL8MS29MzP2aBdQhqMl2Eh4RvyuZbzaIbAVZzsdLDSk,3754
+nshtrainer/util/path.py,sha256=9fIjE3S78pPL6wjAgEJUYfIJQAPdKOQqIYvTS9lWTUk,3959
 nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.0b42.dist-info/METADATA,sha256=m8vNeWow5AfufFEd6uCjm6w-0NOCTTzCFjXDed0eq6U,988
-nshtrainer-1.0.0b42.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
-nshtrainer-1.0.0b42.dist-info/RECORD,,
+nshtrainer-1.0.0b43.dist-info/METADATA,sha256=ZE3l6CN34ptFgx3SDPfKIgjdV2s3J8qdP729eb58vzo,988
+nshtrainer-1.0.0b43.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+nshtrainer-1.0.0b43.dist-info/RECORD,,
nshtrainer/nn/tests/test_mlp.py DELETED
@@ -1,120 +0,0 @@
-from __future__ import annotations
-
-from typing import cast
-
-import pytest
-import torch
-
-from nshtrainer.nn.mlp import MLP, custom_seed_context
-
-
-def test_mlp_seed_reproducibility():
-    """Test that the seed parameter in MLP ensures reproducible weights."""
-
-    # Test dimensions
-    dims = [10, 20, 5]
-
-    # Create two MLPs with the same seed
-    seed1 = 42
-    mlp1 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
-    mlp2 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
-
-    # Create an MLP with a different seed
-    seed2 = 123
-    mlp3 = MLP(dims, activation=torch.nn.ReLU(), seed=seed2)
-
-    # Check first layer weights
-    layer1_weights1 = cast(torch.Tensor, mlp1[0].weight)
-    layer1_weights2 = cast(torch.Tensor, mlp2[0].weight)
-    layer1_weights3 = cast(torch.Tensor, mlp3[0].weight)
-
-    # Same seed should produce identical weights
-    assert torch.allclose(layer1_weights1, layer1_weights2)
-
-    # Different seeds should produce different weights
-    assert not torch.allclose(layer1_weights1, layer1_weights3)
-
-    # Check second layer weights
-    layer2_weights1 = cast(torch.Tensor, mlp1[2].weight)
-    layer2_weights2 = cast(torch.Tensor, mlp2[2].weight)
-    layer2_weights3 = cast(torch.Tensor, mlp3[2].weight)
-
-    # Same seed should produce identical weights for all layers
-    assert torch.allclose(layer2_weights1, layer2_weights2)
-
-    # Different seeds should produce different weights for all layers
-    assert not torch.allclose(layer2_weights1, layer2_weights3)
-
-    # Test that not providing a seed gives different results each time
-    mlp4 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
-    mlp5 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
-
-    # Without seeds, weights should be different
-    assert not torch.allclose(
-        cast(torch.Tensor, mlp4[0].weight), cast(torch.Tensor, mlp5[0].weight)
-    )
-
-
-def test_custom_seed_context():
-    """Test that custom_seed_context properly controls random number generation."""
-
-    # Test that the same seed produces the same random numbers
-    with custom_seed_context(42):
-        tensor1 = torch.randn(10)
-
-    with custom_seed_context(42):
-        tensor2 = torch.randn(10)
-
-    # Same seed should produce identical random tensors
-    assert torch.allclose(tensor1, tensor2)
-
-    # Test that different seeds produce different random numbers
-    with custom_seed_context(123):
-        tensor3 = torch.randn(10)
-
-    # Different seeds should produce different random tensors
-    assert not torch.allclose(tensor1, tensor3)
-
-
-def test_custom_seed_context_preserves_state():
-    """Test that custom_seed_context preserves the original random state."""
-
-    # Set a known seed for the test
-    original_seed = 789
-    torch.manual_seed(original_seed)
-
-    # Generate a tensor with the original seed
-    original_tensor = torch.randn(10)
-
-    # Use a different seed in the context
-    with custom_seed_context(42):
-        # This should use the temporary seed
-        context_tensor = torch.randn(10)
-
-    # After exiting the context, we should be back to the original seed state
-    # Reset the generator to get the same sequence again
-    torch.manual_seed(original_seed)
-    expected_tensor = torch.randn(10)
-
-    # The tensor generated after the context should match what we would get
-    # if we had just set the original seed again
-    assert torch.allclose(original_tensor, expected_tensor)
-
-    # And it should be different from the tensor generated inside the context
-    assert not torch.allclose(original_tensor, context_tensor)
-
-
-def test_custom_seed_context_with_none():
-    """Test that custom_seed_context with None seed doesn't affect randomization."""
-
-    # Set a known seed
-    torch.manual_seed(555)
-    expected_tensor = torch.randn(10)
-
-    # Reset and use None seed in context
-    torch.manual_seed(555)
-    with custom_seed_context(None):
-        actual_tensor = torch.randn(10)
-
-    # With None seed, the context should not affect the random state
-    assert torch.allclose(expected_tensor, actual_tensor)