nshtrainer 0.10.19__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/actsave.py +1 -15
- nshtrainer/model/modules/logger.py +1 -9
- {nshtrainer-0.10.19.dist-info → nshtrainer-0.11.0.dist-info}/METADATA +7 -4
- {nshtrainer-0.10.19.dist-info → nshtrainer-0.11.0.dist-info}/RECORD +5 -5
- {nshtrainer-0.10.19.dist-info → nshtrainer-0.11.0.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/actsave.py
CHANGED
|
@@ -4,15 +4,11 @@ from typing import Literal
|
|
|
4
4
|
|
|
5
5
|
from lightning.pytorch import LightningModule, Trainer
|
|
6
6
|
from lightning.pytorch.callbacks.callback import Callback
|
|
7
|
+
from nshutils import ActSave
|
|
7
8
|
from typing_extensions import TypeAlias, override
|
|
8
9
|
|
|
9
10
|
from .base import CallbackConfigBase
|
|
10
11
|
|
|
11
|
-
try:
|
|
12
|
-
from nshutils import ActSave # type: ignore
|
|
13
|
-
except ImportError:
|
|
14
|
-
ActSave = None
|
|
15
|
-
|
|
16
12
|
Stage: TypeAlias = Literal["train", "validation", "test", "predict"]
|
|
17
13
|
|
|
18
14
|
|
|
@@ -51,11 +47,6 @@ class ActSaveCallback(Callback):
|
|
|
51
47
|
if not self.config:
|
|
52
48
|
return
|
|
53
49
|
|
|
54
|
-
if ActSave is None:
|
|
55
|
-
raise ImportError(
|
|
56
|
-
"ActSave is not installed. Please install nshutils to use the ActSaveCallback."
|
|
57
|
-
)
|
|
58
|
-
|
|
59
50
|
context = ActSave.enabled(self.save_dir)
|
|
60
51
|
context.__enter__()
|
|
61
52
|
self._enabled_context = context
|
|
@@ -77,11 +68,6 @@ class ActSaveCallback(Callback):
|
|
|
77
68
|
if not self.config:
|
|
78
69
|
return
|
|
79
70
|
|
|
80
|
-
if ActSave is None:
|
|
81
|
-
raise ImportError(
|
|
82
|
-
"ActSave is not installed. Please install nshutils to use the ActSaveCallback."
|
|
83
|
-
)
|
|
84
|
-
|
|
85
71
|
# If we have an active context manager for this stage, exit it
|
|
86
72
|
if active_contexts := self._active_contexts.get(stage):
|
|
87
73
|
active_contexts.__exit__(None, None, None)
|
|
@@ -9,16 +9,12 @@ import torchmetrics
|
|
|
9
9
|
from lightning.pytorch import LightningDataModule, LightningModule
|
|
10
10
|
from lightning.pytorch.utilities.types import _METRIC
|
|
11
11
|
from lightning_utilities.core.rank_zero import rank_zero_warn
|
|
12
|
+
from nshutils import ActSave
|
|
12
13
|
from typing_extensions import override
|
|
13
14
|
|
|
14
15
|
from ...util.typing_utils import mixin_base_type
|
|
15
16
|
from ..config import BaseConfig
|
|
16
17
|
|
|
17
|
-
try:
|
|
18
|
-
from nshutils import ActSave # type: ignore
|
|
19
|
-
except ImportError:
|
|
20
|
-
ActSave = None
|
|
21
|
-
|
|
22
18
|
|
|
23
19
|
@dataclass(frozen=True, kw_only=True)
|
|
24
20
|
class _LogContext:
|
|
@@ -162,10 +158,6 @@ class LoggerLightningModuleMixin(LoggerModuleMixin, mixin_base_type(LightningMod
|
|
|
162
158
|
if not hparams.trainer.logging.actsave_logged_metrics:
|
|
163
159
|
return
|
|
164
160
|
|
|
165
|
-
if ActSave is None:
|
|
166
|
-
rank_zero_warn("ActSave is not available, skipping logging of metrics")
|
|
167
|
-
return
|
|
168
|
-
|
|
169
161
|
ActSave.save(
|
|
170
162
|
lambda: {
|
|
171
163
|
f"logger.{name}": lambda: value.compute()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: nshtrainer
|
|
3
|
-
Version: 0.10.19
|
|
3
|
+
Version: 0.11.0
|
|
4
4
|
Summary:
|
|
5
5
|
Author: Nima Shoghi
|
|
6
6
|
Author-email: nimashoghi@gmail.com
|
|
@@ -9,7 +9,8 @@ Classifier: Programming Language :: Python :: 3
|
|
|
9
9
|
Classifier: Programming Language :: Python :: 3.10
|
|
10
10
|
Classifier: Programming Language :: Python :: 3.11
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
-
|
|
12
|
+
Provides-Extra: extra
|
|
13
|
+
Requires-Dist: GitPython ; extra == "extra"
|
|
13
14
|
Requires-Dist: lightning
|
|
14
15
|
Requires-Dist: nshconfig
|
|
15
16
|
Requires-Dist: nshrunner
|
|
@@ -17,10 +18,12 @@ Requires-Dist: nshutils
|
|
|
17
18
|
Requires-Dist: numpy
|
|
18
19
|
Requires-Dist: psutil
|
|
19
20
|
Requires-Dist: pytorch-lightning
|
|
21
|
+
Requires-Dist: tensorboard ; extra == "extra"
|
|
20
22
|
Requires-Dist: torch
|
|
21
|
-
Requires-Dist: torchmetrics
|
|
23
|
+
Requires-Dist: torchmetrics ; extra == "extra"
|
|
22
24
|
Requires-Dist: typing-extensions
|
|
23
|
-
Requires-Dist:
|
|
25
|
+
Requires-Dist: wandb ; extra == "extra"
|
|
26
|
+
Requires-Dist: wrapt ; extra == "extra"
|
|
24
27
|
Description-Content-Type: text/markdown
|
|
25
28
|
|
|
26
29
|
|
|
@@ -8,7 +8,7 @@ nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPUL
|
|
|
8
8
|
nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
|
|
9
9
|
nshtrainer/callbacks/__init__.py,sha256=ifXQRwtccznl4lMKwKLSuuAQC4bKFBgfzQ4rx9gOqjE,2345
|
|
10
10
|
nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
|
|
11
|
-
nshtrainer/callbacks/actsave.py,sha256=
|
|
11
|
+
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
12
12
|
nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
|
|
13
13
|
nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
|
|
14
14
|
nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
|
|
@@ -56,7 +56,7 @@ nshtrainer/model/config.py,sha256=4vze6tpMFAgpk532T33jmSH1lHfolHK1vAJEaa2-Vxs,54
|
|
|
56
56
|
nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
|
|
57
57
|
nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
|
|
58
58
|
nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
|
|
59
|
-
nshtrainer/model/modules/logger.py,sha256=
|
|
59
|
+
nshtrainer/model/modules/logger.py,sha256=CJWSmNT8SV5GLtfml-qGYenqRPXcNOMsJRGEavAd8Hw,5464
|
|
60
60
|
nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
|
|
61
61
|
nshtrainer/model/modules/rlp_sanity_checks.py,sha256=I_ralr2ThQ-D_FkVQTwbdXLLlgHJEr7-s01I5wSDjps,8893
|
|
62
62
|
nshtrainer/model/modules/shared_parameters.py,sha256=ZiRKkZXr6RwdwLCdZCJPl3dXe7bnT8Z9yTeRK5bXBGk,2687
|
|
@@ -79,6 +79,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
79
79
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
80
80
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
81
81
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
82
|
-
nshtrainer-0.
|
|
83
|
-
nshtrainer-0.
|
|
84
|
-
nshtrainer-0.10.19.dist-info/RECORD,,
|
|
82
|
+
nshtrainer-0.11.0.dist-info/METADATA,sha256=IE_faJS_HOMLLu67UPzCAO1aBPZvjWZT3YBCaQ_YpC0,860
|
|
83
|
+
nshtrainer-0.11.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
84
|
+
nshtrainer-0.11.0.dist-info/RECORD,,
|
|
File without changes
|