nshtrainer 0.10.18__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -92,7 +92,7 @@ def _write_checkpoint_metadata(
92
92
  except Exception as e:
93
93
  log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
94
94
  else:
95
- log.info(f"Checkpoint metadata written to {checkpoint_path}")
95
+ log.debug(f"Checkpoint metadata written to {checkpoint_path}")
96
96
 
97
97
  # Write the hparams to the checkpoint directory
98
98
  try:
@@ -101,7 +101,7 @@ def _write_checkpoint_metadata(
101
101
  except Exception as e:
102
102
  log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
103
103
  else:
104
- log.info(f"Checkpoint metadata written to {checkpoint_path}")
104
+ log.debug(f"Checkpoint metadata written to {checkpoint_path}")
105
105
 
106
106
 
107
107
  def _remove_checkpoint_metadata(checkpoint_path: Path):
@@ -112,7 +112,7 @@ def _remove_checkpoint_metadata(checkpoint_path: Path):
112
112
  except Exception as e:
113
113
  log.warning(f"Failed to remove {path}: {e}")
114
114
  else:
115
- log.info(f"Removed {path}")
115
+ log.debug(f"Removed {path}")
116
116
 
117
117
 
118
118
  def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
@@ -133,7 +133,7 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
133
133
  except Exception as e:
134
134
  log.warning(f"Failed to link {path} to {linked_path}: {e}")
135
135
  else:
136
- log.info(f"Linked {path} to {linked_path}")
136
+ log.debug(f"Linked {path} to {linked_path}")
137
137
 
138
138
 
139
139
  def _checkpoint_sort_key_fn(key: Callable[[CheckpointMetadata, Path], Any]):
@@ -4,15 +4,11 @@ from typing import Literal
4
4
 
5
5
  from lightning.pytorch import LightningModule, Trainer
6
6
  from lightning.pytorch.callbacks.callback import Callback
7
+ from nshutils import ActSave
7
8
  from typing_extensions import TypeAlias, override
8
9
 
9
10
  from .base import CallbackConfigBase
10
11
 
11
- try:
12
- from nshutils import ActSave # type: ignore
13
- except ImportError:
14
- ActSave = None
15
-
16
12
  Stage: TypeAlias = Literal["train", "validation", "test", "predict"]
17
13
 
18
14
 
@@ -51,11 +47,6 @@ class ActSaveCallback(Callback):
51
47
  if not self.config:
52
48
  return
53
49
 
54
- if ActSave is None:
55
- raise ImportError(
56
- "ActSave is not installed. Please install nshutils to use the ActSaveCallback."
57
- )
58
-
59
50
  context = ActSave.enabled(self.save_dir)
60
51
  context.__enter__()
61
52
  self._enabled_context = context
@@ -77,11 +68,6 @@ class ActSaveCallback(Callback):
77
68
  if not self.config:
78
69
  return
79
70
 
80
- if ActSave is None:
81
- raise ImportError(
82
- "ActSave is not installed. Please install nshutils to use the ActSaveCallback."
83
- )
84
-
85
71
  # If we have an active context manager for this stage, exit it
86
72
  if active_contexts := self._active_contexts.get(stage):
87
73
  active_contexts.__exit__(None, None, None)
@@ -9,16 +9,12 @@ import torchmetrics
9
9
  from lightning.pytorch import LightningDataModule, LightningModule
10
10
  from lightning.pytorch.utilities.types import _METRIC
11
11
  from lightning_utilities.core.rank_zero import rank_zero_warn
12
+ from nshutils import ActSave
12
13
  from typing_extensions import override
13
14
 
14
15
  from ...util.typing_utils import mixin_base_type
15
16
  from ..config import BaseConfig
16
17
 
17
- try:
18
- from nshutils import ActSave # type: ignore
19
- except ImportError:
20
- ActSave = None
21
-
22
18
 
23
19
  @dataclass(frozen=True, kw_only=True)
24
20
  class _LogContext:
@@ -162,10 +158,6 @@ class LoggerLightningModuleMixin(LoggerModuleMixin, mixin_base_type(LightningMod
162
158
  if not hparams.trainer.logging.actsave_logged_metrics:
163
159
  return
164
160
 
165
- if ActSave is None:
166
- rank_zero_warn("ActSave is not available, skipping logging of metrics")
167
- return
168
-
169
161
  ActSave.save(
170
162
  lambda: {
171
163
  f"logger.{name}": lambda: value.compute()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.10.18
3
+ Version: 0.11.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -9,7 +9,8 @@ Classifier: Programming Language :: Python :: 3
9
9
  Classifier: Programming Language :: Python :: 3.10
10
10
  Classifier: Programming Language :: Python :: 3.11
11
11
  Classifier: Programming Language :: Python :: 3.12
12
- Requires-Dist: GitPython
12
+ Provides-Extra: extra
13
+ Requires-Dist: GitPython ; extra == "extra"
13
14
  Requires-Dist: lightning
14
15
  Requires-Dist: nshconfig
15
16
  Requires-Dist: nshrunner
@@ -17,10 +18,12 @@ Requires-Dist: nshutils
17
18
  Requires-Dist: numpy
18
19
  Requires-Dist: psutil
19
20
  Requires-Dist: pytorch-lightning
21
+ Requires-Dist: tensorboard ; extra == "extra"
20
22
  Requires-Dist: torch
21
- Requires-Dist: torchmetrics
23
+ Requires-Dist: torchmetrics ; extra == "extra"
22
24
  Requires-Dist: typing-extensions
23
- Requires-Dist: wrapt
25
+ Requires-Dist: wandb ; extra == "extra"
26
+ Requires-Dist: wrapt ; extra == "extra"
24
27
  Description-Content-Type: text/markdown
25
28
 
26
29
 
@@ -1,6 +1,6 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
2
  nshtrainer/_checkpoint/loader.py,sha256=48flPr1XgQHOgIPaCrRqOEvRuG0SZuV3cQ1vgHLqFqI,11025
3
- nshtrainer/_checkpoint/metadata.py,sha256=GlhlAyJh5gcp3R8l2Y3eAUQtQzBnitFlB0xdx-khEUQ,5579
3
+ nshtrainer/_checkpoint/metadata.py,sha256=3yxGxHLIVwKh5K4L8LYOEK3GQ6HQXy89CGcy9zarApo,5583
4
4
  nshtrainer/_checkpoint/saver.py,sha256=z_c7a91O4Bh4lZZjqJgxT3w25qFlJsOopV3cpJtkHk8,1655
5
5
  nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
6
6
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
@@ -8,7 +8,7 @@ nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPUL
8
8
  nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
9
9
  nshtrainer/callbacks/__init__.py,sha256=ifXQRwtccznl4lMKwKLSuuAQC4bKFBgfzQ4rx9gOqjE,2345
10
10
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
11
- nshtrainer/callbacks/actsave.py,sha256=aY6T_NAzaFAVU8WMHOXnWL5wd2bi8eVxeU2S0iAs70c,4446
11
+ nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
12
12
  nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
13
13
  nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
14
14
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
@@ -56,7 +56,7 @@ nshtrainer/model/config.py,sha256=4vze6tpMFAgpk532T33jmSH1lHfolHK1vAJEaa2-Vxs,54
56
56
  nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
57
57
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
58
58
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
59
- nshtrainer/model/modules/logger.py,sha256=YYhehQysqTjuVFcd_EREYDh57CIlezidFBS2Ohp_xKo,5661
59
+ nshtrainer/model/modules/logger.py,sha256=CJWSmNT8SV5GLtfml-qGYenqRPXcNOMsJRGEavAd8Hw,5464
60
60
  nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
61
61
  nshtrainer/model/modules/rlp_sanity_checks.py,sha256=I_ralr2ThQ-D_FkVQTwbdXLLlgHJEr7-s01I5wSDjps,8893
62
62
  nshtrainer/model/modules/shared_parameters.py,sha256=ZiRKkZXr6RwdwLCdZCJPl3dXe7bnT8Z9yTeRK5bXBGk,2687
@@ -79,6 +79,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
79
79
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
80
80
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
81
81
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
82
- nshtrainer-0.10.18.dist-info/METADATA,sha256=-r5sz2eulZvKCATa6NFmUGkwT4FFrjetLGJawsakdyM,696
83
- nshtrainer-0.10.18.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
- nshtrainer-0.10.18.dist-info/RECORD,,
82
+ nshtrainer-0.11.0.dist-info/METADATA,sha256=IE_faJS_HOMLLu67UPzCAO1aBPZvjWZT3YBCaQ_YpC0,860
83
+ nshtrainer-0.11.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
+ nshtrainer-0.11.0.dist-info/RECORD,,