nshtrainer-1.0.0b41-py3-none-any.whl → nshtrainer-1.0.0b43-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/metadata.py +7 -16
- nshtrainer/_checkpoint/saver.py +10 -7
- nshtrainer/callbacks/checkpoint/_base.py +3 -3
- nshtrainer/trainer/trainer.py +2 -2
- nshtrainer/util/path.py +10 -1
- {nshtrainer-1.0.0b41.dist-info → nshtrainer-1.0.0b43.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b41.dist-info → nshtrainer-1.0.0b43.dist-info}/RECORD +8 -9
- nshtrainer/nn/tests/test_mlp.py +0 -55
- {nshtrainer-1.0.0b41.dist-info → nshtrainer-1.0.0b43.dist-info}/WHEEL +0 -0
nshtrainer/_checkpoint/metadata.py
CHANGED

```diff
@@ -3,7 +3,6 @@ from __future__ import annotations
 import copy
 import datetime
 import logging
-from collections.abc import Callable
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar
 
@@ -115,7 +114,7 @@ def _metadata_path(checkpoint_path: Path):
     return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
 
 
-def _write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
+def write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
     metadata_path = _metadata_path(checkpoint_path)
     metadata = _generate_checkpoint_metadata(trainer, checkpoint_path, metadata_path)
 
@@ -130,7 +129,7 @@ def _write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
     return metadata_path
 
 
-def _remove_checkpoint_metadata(checkpoint_path: Path):
+def remove_checkpoint_metadata(checkpoint_path: Path):
     path = _metadata_path(checkpoint_path)
     try:
        path.unlink(missing_ok=True)
@@ -140,23 +139,15 @@ def _remove_checkpoint_metadata(checkpoint_path: Path):
        log.debug(f"Removed {path}")
 
 
-def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
+def link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
     # First, remove any existing metadata files
-    _remove_checkpoint_metadata(linked_checkpoint_path)
+    remove_checkpoint_metadata(linked_checkpoint_path)
 
     # Link the metadata files to the new checkpoint
     path = _metadata_path(checkpoint_path)
     linked_path = _metadata_path(linked_checkpoint_path)
-    try_symlink_or_copy(path, linked_path)
 
+    if not path.exists():
+        raise FileNotFoundError(f"Checkpoint path does not exist: {checkpoint_path}")
 
-def _sort_checkpoint_metadata(
-    checkpoint_paths: list[Path],
-    key: Callable[[CheckpointMetadata, Path], Any],
-    reverse: bool = False,
-):
-    return sorted(
-        [(CheckpointMetadata.from_ckpt_path(path), path) for path in checkpoint_paths],
-        key=lambda args_tuple: key(*args_tuple),
-        reverse=reverse,
-    )
+    try_symlink_or_copy(path, linked_path)
```
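The net effect: the metadata helpers are now public API (`write_checkpoint_metadata`, `remove_checkpoint_metadata`, `link_checkpoint_metadata`), and `link_checkpoint_metadata` fails fast when the source metadata file is missing instead of silently creating a dangling link. A minimal sketch of the new behavior; the checkpoint paths here are hypothetical:

```python
from pathlib import Path

from nshtrainer._checkpoint.metadata import link_checkpoint_metadata

# Hypothetical paths, for illustration only.
ckpt = Path("checkpoints/epoch=3-step=400.ckpt")
latest = Path("checkpoints/latest.ckpt")

try:
    # As of b43 this raises FileNotFoundError when the metadata file for
    # `ckpt` does not exist, rather than symlinking to a missing target.
    link_checkpoint_metadata(ckpt, latest)
except FileNotFoundError as e:
    print(f"refused to create a dangling metadata link: {e}")
```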
    
nshtrainer/_checkpoint/saver.py
CHANGED

```diff
@@ -8,12 +8,12 @@ from pathlib import Path
 from lightning.pytorch import Trainer
 
 from ..util.path import try_symlink_or_copy
-from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
+from .metadata import link_checkpoint_metadata, remove_checkpoint_metadata
 
 log = logging.getLogger(__name__)
 
 
-def _link_checkpoint(
+def link_checkpoint(
     filepath: str | Path | os.PathLike,
     linkpath: str | Path | os.PathLike,
     *,
@@ -25,7 +25,10 @@ def _link_checkpoint(
 
     if remove_existing:
         try:
-            if linkpath.exists():
+            if linkpath.exists(follow_symlinks=False):
+                # follow_symlinks=False is EXTREMELY important here
+                # Otherwise, we've already deleted the file that the symlink
+                # used to point to, so this always returns False
                 if linkpath.is_dir():
                     shutil.rmtree(linkpath)
                 else:
@@ -36,14 +39,14 @@ def _link_checkpoint(
         log.debug(f"Removed {linkpath=}")
 
     if metadata:
-        _remove_checkpoint_metadata(linkpath)
+        remove_checkpoint_metadata(linkpath)
 
     try_symlink_or_copy(filepath, linkpath)
     if metadata:
-        _link_checkpoint_metadata(filepath, linkpath)
+        link_checkpoint_metadata(filepath, linkpath)
 
 
-def _remove_checkpoint(
+def remove_checkpoint(
     trainer: Trainer,
     filepath: str | Path | os.PathLike,
     *,
@@ -54,4 +57,4 @@ def _remove_checkpoint(
     trainer.strategy.remove_checkpoint(filepath)
 
     if metadata:
-        _remove_checkpoint_metadata(filepath)
+        remove_checkpoint_metadata(filepath)
```
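The `follow_symlinks=False` change fixes a subtle ordering bug: by the time `link_checkpoint` refreshes `latest.ckpt`, the checkpoint the old symlink pointed to may already have been pruned, so a plain `exists()` (which follows symlinks) reports `False` and the stale link is never removed. A self-contained sketch of the pitfall; note that `Path.exists(follow_symlinks=...)` requires Python 3.12+, with `os.path.lexists` as the portable equivalent:

```python
import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "epoch=1.ckpt"
    link = Path(tmp) / "latest.ckpt"

    target.write_text("fake checkpoint")
    link.symlink_to(target)
    target.unlink()  # checkpoint pruned; `link` now dangles

    print(link.exists())                       # False -- follows the dead link
    print(os.path.lexists(link))               # True  -- the link itself exists
    print(link.exists(follow_symlinks=False))  # True  (Python 3.12+)

# With the old `if linkpath.exists():` guard, the dangling link was never
# unlinked, so re-linking would trip over the stale directory entry.
```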
nshtrainer/callbacks/checkpoint/_base.py
CHANGED

```diff
@@ -12,7 +12,7 @@ from lightning.pytorch.callbacks import Checkpoint
 from typing_extensions import TypeVar, override
 
 from ..._checkpoint.metadata import CheckpointMetadata
-from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
 from ..base import CallbackConfigBase
 
 if TYPE_CHECKING:
@@ -122,7 +122,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
                 )
                 continue
 
-            _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
+            remove_checkpoint(trainer, old_ckpt_path, metadata=True)
             log.debug(f"Removed old checkpoint: {old_ckpt_path}")
 
     def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
@@ -167,7 +167,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
             # Create the latest symlink
             if (symlink_filename := self.symlink_path()) is not None:
                 symlink_path = self.dirpath / symlink_filename
-                _link_checkpoint(filepath, symlink_path, metadata=True)
+                link_checkpoint(filepath, symlink_path, metadata=True)
                 log.debug(f"Created latest symlink: {symlink_path}")
 
         # Barrier to ensure all processes have saved the checkpoint,
```
    
nshtrainer/trainer/trainer.py
CHANGED

```diff
@@ -18,7 +18,7 @@ from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
 from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
-from .._checkpoint.metadata import _write_checkpoint_metadata
+from .._checkpoint.metadata import write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
 from ..util._environment_info import EnvironmentConfig
 from ..util.bf16 import is_bf16_supported_no_emulation
@@ -478,7 +478,7 @@ class Trainer(LightningTrainer):
         metadata_path = None
         if self.hparams.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
-            metadata_path = _write_checkpoint_metadata(self, filepath)
+            metadata_path = write_checkpoint_metadata(self, filepath)
 
         # Call the `on_checkpoint_saved` method on all callbacks
         from .. import _callback
```
    
nshtrainer/util/path.py
CHANGED

```diff
@@ -81,18 +81,27 @@ def compute_file_checksum(file_path: Path) -> str:
 def try_symlink_or_copy(
     file_path: Path,
     link_path: Path,
+    *,
     target_is_directory: bool = False,
     relative: bool = True,
     remove_existing: bool = True,
+    throw_on_invalid_target: bool = False,
 ):
     """
     Symlinks on Unix, copies on Windows.
     """
 
+    # Check if the target file exists
+    if throw_on_invalid_target and not file_path.exists():
+        raise FileNotFoundError(f"File not found: {file_path}")
+
     # If the link already exists, remove it
     if remove_existing:
         try:
-            if link_path.exists():
+            if link_path.exists(follow_symlinks=False):
+                # follow_symlinks=False is EXTREMELY important here
+                # Otherwise, we've already deleted the file that the symlink
+                # used to point to, so this always returns False
                 if link_path.is_dir():
                     shutil.rmtree(link_path)
                 else:
```
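Two API changes land here: the bare `*` makes the boolean flags keyword-only, and the new `throw_on_invalid_target` flag (default `False`) turns linking to a missing source into an error. A short usage sketch under those assumptions; the file names are invented:

```python
import tempfile
from pathlib import Path

from nshtrainer.util.path import try_symlink_or_copy

with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "epoch=3.ckpt"
    dst = Path(tmp) / "latest.ckpt"
    src.write_text("fake checkpoint")

    # Flags after the bare `*` must now be passed by keyword:
    try_symlink_or_copy(src, dst, relative=True, remove_existing=True)
    # try_symlink_or_copy(src, dst, False, True)  # now a TypeError

    # Opt-in guard: fail fast instead of linking to a missing source.
    src.unlink()
    try:
        try_symlink_or_copy(src, dst, throw_on_invalid_target=True)
    except FileNotFoundError as e:
        print(e)  # File not found: .../epoch=3.ckpt
```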
{nshtrainer-1.0.0b41.dist-info → nshtrainer-1.0.0b43.dist-info}/RECORD
CHANGED

```diff
@@ -1,8 +1,8 @@
 nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
 nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
 nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
-nshtrainer/_checkpoint/metadata.py,sha256=
-nshtrainer/_checkpoint/saver.py,sha256=
+nshtrainer/_checkpoint/metadata.py,sha256=LQZ8g50rKxQQx-FqiW3n8EWmal9qSWRouOpIIn6NJJY,4758
+nshtrainer/_checkpoint/saver.py,sha256=rWl4d2lCTMU4_wt8yZFL2pFQaP9hj5sPgqHMPQ4zuyI,1584
 nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
@@ -10,7 +10,7 @@ nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTw
 nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
 nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
-nshtrainer/callbacks/checkpoint/_base.py,sha256=
+nshtrainer/callbacks/checkpoint/_base.py,sha256=wCJBRI0pQYZc3GBu0b-aUBlBDhd39AdL82VvFgKmv3k,6300
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
@@ -122,7 +122,6 @@ nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101
 nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
 nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
 nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
-nshtrainer/nn/tests/test_mlp.py,sha256=xBPiHlBvOCn67EbpzzKL-2FU7ikGxHT3i6CMSp1wk7M,1840
 nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
 nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
 nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
@@ -141,18 +140,18 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
 nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=ed_Pn-yQCb9BqaHXo2wVhkt2CSfGNEzMAM6RsDoTo-I,20834
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
 nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
 nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
 nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
-nshtrainer/util/path.py,sha256=
+nshtrainer/util/path.py,sha256=9fIjE3S78pPL6wjAgEJUYfIJQAPdKOQqIYvTS9lWTUk,3959
 nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.0b41.dist-info/METADATA,sha256=
-nshtrainer-1.0.0b41.dist-info/WHEEL,sha256=
-nshtrainer-1.0.0b41.dist-info/RECORD,,
+nshtrainer-1.0.0b43.dist-info/METADATA,sha256=ZE3l6CN34ptFgx3SDPfKIgjdV2s3J8qdP729eb58vzo,988
+nshtrainer-1.0.0b43.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+nshtrainer-1.0.0b43.dist-info/RECORD,,
```
    
nshtrainer/nn/tests/test_mlp.py
DELETED

```diff
@@ -1,55 +0,0 @@
-from __future__ import annotations
-
-from typing import cast
-
-import pytest
-import torch
-
-from nshtrainer.nn.mlp import MLP
-
-
-def test_mlp_seed_reproducibility():
-    """Test that the seed parameter in MLP ensures reproducible weights."""
-
-    # Test dimensions
-    dims = [10, 20, 5]
-
-    # Create two MLPs with the same seed
-    seed1 = 42
-    mlp1 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
-    mlp2 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
-
-    # Create an MLP with a different seed
-    seed2 = 123
-    mlp3 = MLP(dims, activation=torch.nn.ReLU(), seed=seed2)
-
-    # Check first layer weights
-    layer1_weights1 = cast(torch.Tensor, mlp1[0].weight)
-    layer1_weights2 = cast(torch.Tensor, mlp2[0].weight)
-    layer1_weights3 = cast(torch.Tensor, mlp3[0].weight)
-
-    # Same seed should produce identical weights
-    assert torch.allclose(layer1_weights1, layer1_weights2)
-
-    # Different seeds should produce different weights
-    assert not torch.allclose(layer1_weights1, layer1_weights3)
-
-    # Check second layer weights
-    layer2_weights1 = cast(torch.Tensor, mlp1[2].weight)
-    layer2_weights2 = cast(torch.Tensor, mlp2[2].weight)
-    layer2_weights3 = cast(torch.Tensor, mlp3[2].weight)
-
-    # Same seed should produce identical weights for all layers
-    assert torch.allclose(layer2_weights1, layer2_weights2)
-
-    # Different seeds should produce different weights for all layers
-    assert not torch.allclose(layer2_weights1, layer2_weights3)
-
-    # Test that not providing a seed gives different results each time
-    mlp4 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
-    mlp5 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
-
-    # Without seeds, weights should be different
-    assert not torch.allclose(
-        cast(torch.Tensor, mlp4[0].weight), cast(torch.Tensor, mlp5[0].weight)
-    )
```
{nshtrainer-1.0.0b41.dist-info → nshtrainer-1.0.0b43.dist-info}/WHEEL
File without changes