nshtrainer-1.0.0b42-py3-none-any.whl → nshtrainer-1.0.0b43-py3-none-any.whl
This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/metadata.py +1 -2
- nshtrainer/_checkpoint/saver.py +4 -1
- nshtrainer/trainer/trainer.py +2 -2
- nshtrainer/util/path.py +5 -0
- {nshtrainer-1.0.0b42.dist-info → nshtrainer-1.0.0b43.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b42.dist-info → nshtrainer-1.0.0b43.dist-info}/RECORD +7 -8
- nshtrainer/nn/tests/test_mlp.py +0 -120
- {nshtrainer-1.0.0b42.dist-info → nshtrainer-1.0.0b43.dist-info}/WHEEL +0 -0
| @@ -3,7 +3,6 @@ from __future__ import annotations | |
| 3 3 | 
             
            import copy
         | 
| 4 4 | 
             
            import datetime
         | 
| 5 5 | 
             
            import logging
         | 
| 6 | 
            -
            from collections.abc import Callable
         | 
| 7 6 | 
             
            from pathlib import Path
         | 
| 8 7 | 
             
            from typing import TYPE_CHECKING, Any, ClassVar
         | 
| 9 8 |  | 
| @@ -115,7 +114,7 @@ def _metadata_path(checkpoint_path: Path): | |
| 115 114 | 
             
                return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
         | 
| 116 115 |  | 
| 117 116 |  | 
| 118 | 
            -
            def  | 
| 117 | 
            +
            def write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
         | 
| 119 118 | 
             
                metadata_path = _metadata_path(checkpoint_path)
         | 
| 120 119 | 
             
                metadata = _generate_checkpoint_metadata(trainer, checkpoint_path, metadata_path)
         | 
| 121 120 |  | 
    
        nshtrainer/_checkpoint/saver.py
    CHANGED
    
    | @@ -25,7 +25,10 @@ def link_checkpoint( | |
| 25 25 |  | 
| 26 26 | 
             
                if remove_existing:
         | 
| 27 27 | 
             
                    try:
         | 
| 28 | 
            -
                        if linkpath.exists():
         | 
| 28 | 
            +
                        if linkpath.exists(follow_symlinks=False):
         | 
| 29 | 
            +
                            # follow_symlinks=False is EXTREMELY important here
         | 
| 30 | 
            +
                            # Otherwise, we've already deleted the file that the symlink
         | 
| 31 | 
            +
                            # used to point to, so this always returns False
         | 
| 29 32 | 
             
                            if linkpath.is_dir():
         | 
| 30 33 | 
             
                                shutil.rmtree(linkpath)
         | 
| 31 34 | 
             
                            else:
         | 
    
        nshtrainer/trainer/trainer.py
    CHANGED
    
    | @@ -18,7 +18,7 @@ from lightning.pytorch.trainer.states import TrainerFn | |
| 18 18 | 
             
            from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
         | 
| 19 19 | 
             
            from typing_extensions import Never, Unpack, assert_never, deprecated, override
         | 
| 20 20 |  | 
| 21 | 
            -
            from .._checkpoint.metadata import  | 
| 21 | 
            +
            from .._checkpoint.metadata import write_checkpoint_metadata
         | 
| 22 22 | 
             
            from ..callbacks.base import resolve_all_callbacks
         | 
| 23 23 | 
             
            from ..util._environment_info import EnvironmentConfig
         | 
| 24 24 | 
             
            from ..util.bf16 import is_bf16_supported_no_emulation
         | 
| @@ -478,7 +478,7 @@ class Trainer(LightningTrainer): | |
| 478 478 | 
             
                    metadata_path = None
         | 
| 479 479 | 
             
                    if self.hparams.save_checkpoint_metadata and self.is_global_zero:
         | 
| 480 480 | 
             
                        # Generate the metadata and write to disk
         | 
| 481 | 
            -
                        metadata_path =  | 
| 481 | 
            +
                        metadata_path = write_checkpoint_metadata(self, filepath)
         | 
| 482 482 |  | 
| 483 483 | 
             
                    # Call the `on_checkpoint_saved` method on all callbacks
         | 
| 484 484 | 
             
                    from .. import _callback
         | 
    
        nshtrainer/util/path.py
    CHANGED
    
    | @@ -85,11 +85,16 @@ def try_symlink_or_copy( | |
| 85 85 | 
             
                target_is_directory: bool = False,
         | 
| 86 86 | 
             
                relative: bool = True,
         | 
| 87 87 | 
             
                remove_existing: bool = True,
         | 
| 88 | 
            +
                throw_on_invalid_target: bool = False,
         | 
| 88 89 | 
             
            ):
         | 
| 89 90 | 
             
                """
         | 
| 90 91 | 
             
                Symlinks on Unix, copies on Windows.
         | 
| 91 92 | 
             
                """
         | 
| 92 93 |  | 
| 94 | 
            +
                # Check if the target file exists
         | 
| 95 | 
            +
                if throw_on_invalid_target and not file_path.exists():
         | 
| 96 | 
            +
                    raise FileNotFoundError(f"File not found: {file_path}")
         | 
| 97 | 
            +
             | 
| 93 98 | 
             
                # If the link already exists, remove it
         | 
| 94 99 | 
             
                if remove_existing:
         | 
| 95 100 | 
             
                    try:
         | 
| @@ -1,8 +1,8 @@ | |
| 1 1 | 
             
            nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
         | 
| 2 2 | 
             
            nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
         | 
| 3 3 | 
             
            nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
         | 
| 4 | 
            -
            nshtrainer/_checkpoint/metadata.py,sha256= | 
| 5 | 
            -
            nshtrainer/_checkpoint/saver.py,sha256= | 
| 4 | 
            +
            nshtrainer/_checkpoint/metadata.py,sha256=LQZ8g50rKxQQx-FqiW3n8EWmal9qSWRouOpIIn6NJJY,4758
         | 
| 5 | 
            +
            nshtrainer/_checkpoint/saver.py,sha256=rWl4d2lCTMU4_wt8yZFL2pFQaP9hj5sPgqHMPQ4zuyI,1584
         | 
| 6 6 | 
             
            nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
         | 
| 7 7 | 
             
            nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
         | 
| 8 8 | 
             
            nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
         | 
| @@ -122,7 +122,6 @@ nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101 | |
| 122 122 | 
             
            nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
         | 
| 123 123 | 
             
            nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
         | 
| 124 124 | 
             
            nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
         | 
| 125 | 
            -
            nshtrainer/nn/tests/test_mlp.py,sha256=COw88Oc4FMvj_mGdQf6F2UtgZ39FE2lbQjTmMAwtCWE,4031
         | 
| 126 125 | 
             
            nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
         | 
| 127 126 | 
             
            nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
         | 
| 128 127 | 
             
            nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
         | 
| @@ -141,18 +140,18 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv | |
| 141 140 | 
             
            nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
         | 
| 142 141 | 
             
            nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
         | 
| 143 142 | 
             
            nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
         | 
| 144 | 
            -
            nshtrainer/trainer/trainer.py,sha256= | 
| 143 | 
            +
            nshtrainer/trainer/trainer.py,sha256=ed_Pn-yQCb9BqaHXo2wVhkt2CSfGNEzMAM6RsDoTo-I,20834
         | 
| 145 144 | 
             
            nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
         | 
| 146 145 | 
             
            nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
         | 
| 147 146 | 
             
            nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
         | 
| 148 147 | 
             
            nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
         | 
| 149 148 | 
             
            nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
         | 
| 150 149 | 
             
            nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
         | 
| 151 | 
            -
            nshtrainer/util/path.py,sha256= | 
| 150 | 
            +
            nshtrainer/util/path.py,sha256=9fIjE3S78pPL6wjAgEJUYfIJQAPdKOQqIYvTS9lWTUk,3959
         | 
| 152 151 | 
             
            nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
         | 
| 153 152 | 
             
            nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
         | 
| 154 153 | 
             
            nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
         | 
| 155 154 | 
             
            nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
         | 
| 156 | 
            -
            nshtrainer-1.0. | 
| 157 | 
            -
            nshtrainer-1.0. | 
| 158 | 
            -
            nshtrainer-1.0. | 
| 155 | 
            +
            nshtrainer-1.0.0b43.dist-info/METADATA,sha256=ZE3l6CN34ptFgx3SDPfKIgjdV2s3J8qdP729eb58vzo,988
         | 
| 156 | 
            +
            nshtrainer-1.0.0b43.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
         | 
| 157 | 
            +
            nshtrainer-1.0.0b43.dist-info/RECORD,,
         | 
    
        nshtrainer/nn/tests/test_mlp.py
    DELETED
    
    | @@ -1,120 +0,0 @@ | |
| 1 | 
            -
            from __future__ import annotations
         | 
| 2 | 
            -
             | 
| 3 | 
            -
            from typing import cast
         | 
| 4 | 
            -
             | 
| 5 | 
            -
            import pytest
         | 
| 6 | 
            -
            import torch
         | 
| 7 | 
            -
             | 
| 8 | 
            -
            from nshtrainer.nn.mlp import MLP, custom_seed_context
         | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
            def test_mlp_seed_reproducibility():
         | 
| 12 | 
            -
                """Test that the seed parameter in MLP ensures reproducible weights."""
         | 
| 13 | 
            -
             | 
| 14 | 
            -
                # Test dimensions
         | 
| 15 | 
            -
                dims = [10, 20, 5]
         | 
| 16 | 
            -
             | 
| 17 | 
            -
                # Create two MLPs with the same seed
         | 
| 18 | 
            -
                seed1 = 42
         | 
| 19 | 
            -
                mlp1 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
         | 
| 20 | 
            -
                mlp2 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
         | 
| 21 | 
            -
             | 
| 22 | 
            -
                # Create an MLP with a different seed
         | 
| 23 | 
            -
                seed2 = 123
         | 
| 24 | 
            -
                mlp3 = MLP(dims, activation=torch.nn.ReLU(), seed=seed2)
         | 
| 25 | 
            -
             | 
| 26 | 
            -
                # Check first layer weights
         | 
| 27 | 
            -
                layer1_weights1 = cast(torch.Tensor, mlp1[0].weight)
         | 
| 28 | 
            -
                layer1_weights2 = cast(torch.Tensor, mlp2[0].weight)
         | 
| 29 | 
            -
                layer1_weights3 = cast(torch.Tensor, mlp3[0].weight)
         | 
| 30 | 
            -
             | 
| 31 | 
            -
                # Same seed should produce identical weights
         | 
| 32 | 
            -
                assert torch.allclose(layer1_weights1, layer1_weights2)
         | 
| 33 | 
            -
             | 
| 34 | 
            -
                # Different seeds should produce different weights
         | 
| 35 | 
            -
                assert not torch.allclose(layer1_weights1, layer1_weights3)
         | 
| 36 | 
            -
             | 
| 37 | 
            -
                # Check second layer weights
         | 
| 38 | 
            -
                layer2_weights1 = cast(torch.Tensor, mlp1[2].weight)
         | 
| 39 | 
            -
                layer2_weights2 = cast(torch.Tensor, mlp2[2].weight)
         | 
| 40 | 
            -
                layer2_weights3 = cast(torch.Tensor, mlp3[2].weight)
         | 
| 41 | 
            -
             | 
| 42 | 
            -
                # Same seed should produce identical weights for all layers
         | 
| 43 | 
            -
                assert torch.allclose(layer2_weights1, layer2_weights2)
         | 
| 44 | 
            -
             | 
| 45 | 
            -
                # Different seeds should produce different weights for all layers
         | 
| 46 | 
            -
                assert not torch.allclose(layer2_weights1, layer2_weights3)
         | 
| 47 | 
            -
             | 
| 48 | 
            -
                # Test that not providing a seed gives different results each time
         | 
| 49 | 
            -
                mlp4 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
         | 
| 50 | 
            -
                mlp5 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
         | 
| 51 | 
            -
             | 
| 52 | 
            -
                # Without seeds, weights should be different
         | 
| 53 | 
            -
                assert not torch.allclose(
         | 
| 54 | 
            -
                    cast(torch.Tensor, mlp4[0].weight), cast(torch.Tensor, mlp5[0].weight)
         | 
| 55 | 
            -
                )
         | 
| 56 | 
            -
             | 
| 57 | 
            -
             | 
| 58 | 
            -
            def test_custom_seed_context():
         | 
| 59 | 
            -
                """Test that custom_seed_context properly controls random number generation."""
         | 
| 60 | 
            -
             | 
| 61 | 
            -
                # Test that the same seed produces the same random numbers
         | 
| 62 | 
            -
                with custom_seed_context(42):
         | 
| 63 | 
            -
                    tensor1 = torch.randn(10)
         | 
| 64 | 
            -
             | 
| 65 | 
            -
                with custom_seed_context(42):
         | 
| 66 | 
            -
                    tensor2 = torch.randn(10)
         | 
| 67 | 
            -
             | 
| 68 | 
            -
                # Same seed should produce identical random tensors
         | 
| 69 | 
            -
                assert torch.allclose(tensor1, tensor2)
         | 
| 70 | 
            -
             | 
| 71 | 
            -
                # Test that different seeds produce different random numbers
         | 
| 72 | 
            -
                with custom_seed_context(123):
         | 
| 73 | 
            -
                    tensor3 = torch.randn(10)
         | 
| 74 | 
            -
             | 
| 75 | 
            -
                # Different seeds should produce different random tensors
         | 
| 76 | 
            -
                assert not torch.allclose(tensor1, tensor3)
         | 
| 77 | 
            -
             | 
| 78 | 
            -
             | 
| 79 | 
            -
            def test_custom_seed_context_preserves_state():
         | 
| 80 | 
            -
                """Test that custom_seed_context preserves the original random state."""
         | 
| 81 | 
            -
             | 
| 82 | 
            -
                # Set a known seed for the test
         | 
| 83 | 
            -
                original_seed = 789
         | 
| 84 | 
            -
                torch.manual_seed(original_seed)
         | 
| 85 | 
            -
             | 
| 86 | 
            -
                # Generate a tensor with the original seed
         | 
| 87 | 
            -
                original_tensor = torch.randn(10)
         | 
| 88 | 
            -
             | 
| 89 | 
            -
                # Use a different seed in the context
         | 
| 90 | 
            -
                with custom_seed_context(42):
         | 
| 91 | 
            -
                    # This should use the temporary seed
         | 
| 92 | 
            -
                    context_tensor = torch.randn(10)
         | 
| 93 | 
            -
             | 
| 94 | 
            -
                # After exiting the context, we should be back to the original seed state
         | 
| 95 | 
            -
                # Reset the generator to get the same sequence again
         | 
| 96 | 
            -
                torch.manual_seed(original_seed)
         | 
| 97 | 
            -
                expected_tensor = torch.randn(10)
         | 
| 98 | 
            -
             | 
| 99 | 
            -
                # The tensor generated after the context should match what we would get
         | 
| 100 | 
            -
                # if we had just set the original seed again
         | 
| 101 | 
            -
                assert torch.allclose(original_tensor, expected_tensor)
         | 
| 102 | 
            -
             | 
| 103 | 
            -
                # And it should be different from the tensor generated inside the context
         | 
| 104 | 
            -
                assert not torch.allclose(original_tensor, context_tensor)
         | 
| 105 | 
            -
             | 
| 106 | 
            -
             | 
| 107 | 
            -
            def test_custom_seed_context_with_none():
         | 
| 108 | 
            -
                """Test that custom_seed_context with None seed doesn't affect randomization."""
         | 
| 109 | 
            -
             | 
| 110 | 
            -
                # Set a known seed
         | 
| 111 | 
            -
                torch.manual_seed(555)
         | 
| 112 | 
            -
                expected_tensor = torch.randn(10)
         | 
| 113 | 
            -
             | 
| 114 | 
            -
                # Reset and use None seed in context
         | 
| 115 | 
            -
                torch.manual_seed(555)
         | 
| 116 | 
            -
                with custom_seed_context(None):
         | 
| 117 | 
            -
                    actual_tensor = torch.randn(10)
         | 
| 118 | 
            -
             | 
| 119 | 
            -
                # With None seed, the context should not affect the random state
         | 
| 120 | 
            -
                assert torch.allclose(expected_tensor, actual_tensor)
         | 
| 
            File without changes
         |