nshtrainer 0.10.11__py3-none-any.whl → 0.10.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,7 +10,7 @@ import nshconfig as C
10
10
  import numpy as np
11
11
  import torch
12
12
 
13
- from ..model._environment import EnvironmentConfig
13
+ from ..util._environment_info import EnvironmentConfig
14
14
 
15
15
  if TYPE_CHECKING:
16
16
  from ..model import BaseConfig, LightningModuleBase
@@ -36,7 +36,8 @@ def _link_checkpoint(
36
36
  # fall back to copying the file
37
37
  shutil.copy(filepath, linkpath)
38
38
 
39
- _link_checkpoint_metadata(filepath, linkpath)
39
+ if metadata:
40
+ _link_checkpoint_metadata(filepath, linkpath)
40
41
  if barrier:
41
42
  trainer.strategy.barrier()
42
43
 
@@ -44,9 +45,17 @@ def _link_checkpoint(
44
45
  def _remove_checkpoint(
45
46
  trainer: Trainer,
46
47
  filepath: str | Path | os.PathLike,
47
- remove_metadata: bool = True,
48
+ *,
49
+ metadata: bool,
50
+ barrier: bool,
48
51
  ):
49
52
  if not isinstance(filepath, Path):
50
53
  filepath = Path(filepath)
51
- trainer.strategy.remove_checkpoint(filepath)
52
- _remove_checkpoint_metadata(filepath)
54
+
55
+ if trainer.is_global_zero:
56
+ trainer.strategy.remove_checkpoint(filepath)
57
+ if metadata:
58
+ _remove_checkpoint_metadata(filepath)
59
+
60
+ if barrier:
61
+ trainer.strategy.barrier()
@@ -69,7 +69,7 @@ class LatestEpochCheckpoint(Checkpoint):
69
69
 
70
70
  def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
71
71
  for ckpt_path in ckpt_paths:
72
- _remove_checkpoint(trainer, ckpt_path, remove_metadata=True)
72
+ _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
73
73
 
74
74
  def _remove_old_checkpoints(self, trainer: Trainer):
75
75
  if (latest_k := self.config.latest_k) == "all":
@@ -202,4 +202,4 @@ class ModelCheckpoint(_ModelCheckpoint):
202
202
 
203
203
  @override
204
204
  def _remove_checkpoint(self, trainer: Trainer, filepath: str):
205
- return _remove_checkpoint(trainer, filepath, remove_metadata=True)
205
+ return _remove_checkpoint(trainer, filepath, metadata=True, barrier=False)
nshtrainer/ll/model.py CHANGED
@@ -1 +1,12 @@
1
1
  from nshtrainer.model import * # noqa: F403
2
+
3
+ from ..util._environment_info import (
4
+ EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
5
+ )
6
+ from ..util._environment_info import EnvironmentConfig as EnvironmentConfig
7
+ from ..util._environment_info import (
8
+ EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
9
+ )
10
+ from ..util._environment_info import (
11
+ EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
12
+ )
@@ -1,16 +1,5 @@
1
1
  from typing_extensions import TypeAlias
2
2
 
3
- from ._environment import (
4
- EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
5
- )
6
- from ._environment import EnvironmentConfig as EnvironmentConfig
7
- from ._environment import (
8
- EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
9
- )
10
- from ._environment import (
11
- EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
12
- )
13
- from ._environment import EnvironmentSnapshotConfig as EnvironmentSnapshotConfig
14
3
  from .base import Base as Base
15
4
  from .base import LightningModuleBase as LightningModuleBase
16
5
  from .config import BaseConfig as BaseConfig
nshtrainer/model/base.py CHANGED
@@ -11,7 +11,7 @@ from lightning.pytorch.callbacks import Callback
11
11
  from lightning.pytorch.utilities.types import STEP_OUTPUT
12
12
  from typing_extensions import Self, TypeVar, override
13
13
 
14
- from ._environment import EnvironmentConfig
14
+ from ..util._environment_info import EnvironmentConfig
15
15
  from .config import BaseConfig
16
16
  from .modules.callback import CallbackModuleMixin
17
17
  from .modules.debug import DebugModuleMixin
@@ -44,7 +44,7 @@ from ..callbacks import (
44
44
  )
45
45
  from ..callbacks.base import CallbackConfigBase
46
46
  from ..metrics import MetricConfig
47
- from ._environment import EnvironmentConfig
47
+ from ..util._environment_info import EnvironmentConfig
48
48
 
49
49
  log = logging.getLogger(__name__)
50
50
 
@@ -14,11 +14,11 @@ import psutil
14
14
  import torch
15
15
  from typing_extensions import Self
16
16
 
17
- from ..util.slurm import parse_slurm_node_list
17
+ from .slurm import parse_slurm_node_list
18
18
 
19
19
  if TYPE_CHECKING:
20
- from .base import LightningModuleBase
21
- from .config import BaseConfig
20
+ from ..model.base import LightningModuleBase
21
+ from ..model.config import BaseConfig
22
22
 
23
23
 
24
24
  log = logging.getLogger(__name__)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.10.11
3
+ Version: 0.10.13
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,7 +1,7 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
2
  nshtrainer/_checkpoint/loader.py,sha256=48flPr1XgQHOgIPaCrRqOEvRuG0SZuV3cQ1vgHLqFqI,11025
3
- nshtrainer/_checkpoint/metadata.py,sha256=B6kPmWsq2TQh0gTzBx-1pLIwTVEs_Qw5v0nHEeTBdO4,5636
4
- nshtrainer/_checkpoint/saver.py,sha256=KZp9ITUVHwj2Ttu81zXKdlS_h-fKkHearspwuAijDpM,1501
3
+ nshtrainer/_checkpoint/metadata.py,sha256=soK9tXVs6EOpzhlnIxTEF51KmdkaCDUj0Rdyid3uREk,5640
4
+ nshtrainer/_checkpoint/saver.py,sha256=z_c7a91O4Bh4lZZjqJgxT3w25qFlJsOopV3cpJtkHk8,1655
5
5
  nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
6
6
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
7
7
  nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
@@ -15,9 +15,9 @@ nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,1
15
15
  nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
16
16
  nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
17
17
  nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
18
- nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=t4vWa4PvJDO3rKXKZbuegm7iLl7xCEd17wNif0Bp-BA,4138
18
+ nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=zUeYAGfeQby0R6IwQBJH3lng-MD0vkckdX4aIOm-VIc,4146
19
19
  nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
20
- nshtrainer/callbacks/model_checkpoint.py,sha256=MaDkD8Ismcj8u6l2flCFlqJR3-k1Tc4xzhxNWNux4n0,6556
20
+ nshtrainer/callbacks/model_checkpoint.py,sha256=8D0wWLhr_KiksAA1fjfIuby42Mq6XokCvAnVUhjADd8,6564
21
21
  nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
22
22
  nshtrainer/callbacks/on_exception_checkpoint.py,sha256=x42BYZ2ejf2rhqPLCmT5nyWKhA9qBEosiV8ZNhhZ6lI,3355
23
23
  nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
@@ -35,7 +35,7 @@ nshtrainer/ll/config.py,sha256=fKumJf42HY2FITX1QUM1OTXkYD6U2np2ciyd4PFRPZ8,145
35
35
  nshtrainer/ll/data.py,sha256=zRG0FRje-jtSHximVzkHIHzpwsyQxpHCoACFihNKLPM,44
36
36
  nshtrainer/ll/log.py,sha256=d4BB3TyM8imK65EXOiOeUTF0zFM1ropbe7Vq3DeB0xU,140
37
37
  nshtrainer/ll/lr_scheduler.py,sha256=7xjhN6L69BCUzFhcy33NtMtPuCzHiB611zVWFg92lQ0,52
38
- nshtrainer/ll/model.py,sha256=6I9gQjEFT2Veer-UmcPy05Pt3mnvaxYX1b3sOdaF96A,45
38
+ nshtrainer/ll/model.py,sha256=cxFQfFc-2mAYBGwDpP8m5tjQBs7M47cZ6JoPXksPaoI,473
39
39
  nshtrainer/ll/nn.py,sha256=8qiRDFwojIxkB7-LtNWk4mLL2tJbaskHYofDsOIHiNg,42
40
40
  nshtrainer/ll/optimizer.py,sha256=3T-VZtT73jVvwCNJGDjgGEbzs-1LFTzMQH-SB_58mSo,49
41
41
  nshtrainer/ll/runner.py,sha256=B0m5VEhNKIjF1aFmqPkonkQxDoRL2jeHZGsV3zwhSVE,117
@@ -50,10 +50,9 @@ nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9b
50
50
  nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
51
51
  nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
52
52
  nshtrainer/metrics/_config.py,sha256=hWWS4IXENRyH3RmJ7z1Wx1n3Lt1sNMlGOrcU6PW15o0,1104
53
- nshtrainer/model/__init__.py,sha256=TbexTxiE20WHYg5q3L88Hysk4LlHeKk_isv33aSBREA,1918
54
- nshtrainer/model/_environment.py,sha256=oTtecQeF5oY2RV7UkkSLnzDy3clz4AUkf9oocD6-e54,23115
55
- nshtrainer/model/base.py,sha256=WtCj0-nLWeW04Tu2TWVjIq0D-jW_kMN2hg--4VWVnvE,17505
56
- nshtrainer/model/config.py,sha256=orGBrp8TXnHksfAzXxNJVDdo0X_iIn_nda6BZDS9N70,53349
53
+ nshtrainer/model/__init__.py,sha256=NpvyQHmGaHB8xdraHmm8l7kDHLmvJSgBNQKkfYqtgyI,1454
54
+ nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
55
+ nshtrainer/model/config.py,sha256=65UDzt3ZZFUQaHMlK7f9wzwyGH3cDyHGtjZ2eOjHvVo,53360
57
56
  nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
58
57
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
59
58
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -74,11 +73,12 @@ nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2f
74
73
  nshtrainer/trainer/checkpoint_connector.py,sha256=xoqI2dcPnlNFPPLVIU6dBOvRPC9PtfX5qu__xV1lx0Y,2124
75
74
  nshtrainer/trainer/signal_connector.py,sha256=llwc8pdKAWxREFpjdi14Bpy8rGVMEJsmJx_s2p4gI8E,10689
76
75
  nshtrainer/trainer/trainer.py,sha256=tFyzIsF8c-FABTH6wwDOR9y8kydVJqeVO7PDNFMvhSU,16950
76
+ nshtrainer/util/_environment_info.py,sha256=yPtAbgjCY4tkvh5wp9sjNsF0Z45TYwzEAM_N2_b5BbY,23123
77
77
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
78
78
  nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
79
79
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
80
80
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
81
81
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
82
- nshtrainer-0.10.11.dist-info/METADATA,sha256=9WAsp25_csjDcchr5X22g7ocQpQ-d-ewB3gS9EAZSE8,696
83
- nshtrainer-0.10.11.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
- nshtrainer-0.10.11.dist-info/RECORD,,
82
+ nshtrainer-0.10.13.dist-info/METADATA,sha256=HpGl8_E6q2l2nQrIzU5ibNEyUXj8adF8cxMzouUSpAg,696
83
+ nshtrainer-0.10.13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
+ nshtrainer-0.10.13.dist-info/RECORD,,