nshtrainer 0.10.7__py3-none-any.whl → 0.10.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/gradient_skipping.py +13 -12
- nshtrainer/data/transform.py +14 -2
- nshtrainer/model/_environment.py +5 -1
- nshtrainer/model/config.py +4 -0
- {nshtrainer-0.10.7.dist-info → nshtrainer-0.10.9.dist-info}/METADATA +1 -1
- {nshtrainer-0.10.7.dist-info → nshtrainer-0.10.9.dist-info}/RECORD +7 -8
- nshtrainer/scripts/check_env.py +0 -41
- {nshtrainer-0.10.7.dist-info → nshtrainer-0.10.9.dist-info}/WHEEL +0 -0
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
import importlib.util
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Any, Literal, Protocol, runtime_checkable
|
|
3
4
|
|
|
4
5
|
import torch
|
|
5
|
-
import torchmetrics
|
|
6
6
|
from lightning.pytorch import Callback, LightningModule, Trainer
|
|
7
7
|
from torch.optim import Optimizer
|
|
8
8
|
from typing_extensions import override
|
|
@@ -10,23 +10,29 @@ from typing_extensions import override
|
|
|
10
10
|
from .base import CallbackConfigBase
|
|
11
11
|
from .norm_logging import compute_norm
|
|
12
12
|
|
|
13
|
-
log = getLogger(__name__)
|
|
13
|
+
log = logging.getLogger(__name__)
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
@runtime_checkable
|
|
17
17
|
class HasGradSkippedSteps(Protocol):
|
|
18
|
-
grad_skipped_steps:
|
|
18
|
+
grad_skipped_steps: Any
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class GradientSkipping(Callback):
|
|
22
22
|
def __init__(self, config: "GradientSkippingConfig"):
|
|
23
|
-
|
|
23
|
+
if importlib.util.find_spec("torchmetrics") is not None:
|
|
24
|
+
raise ImportError(
|
|
25
|
+
"To use the GradientSkipping callback, please install torchmetrics: pip install torchmetrics"
|
|
26
|
+
)
|
|
24
27
|
|
|
28
|
+
super().__init__()
|
|
25
29
|
self.config = config
|
|
26
30
|
|
|
27
31
|
@override
|
|
28
32
|
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
|
|
29
33
|
if not isinstance(pl_module, HasGradSkippedSteps):
|
|
34
|
+
import torchmetrics # type: ignore
|
|
35
|
+
|
|
30
36
|
pl_module.grad_skipped_steps = torchmetrics.SumMetric()
|
|
31
37
|
|
|
32
38
|
@override
|
|
@@ -47,12 +53,7 @@ class GradientSkipping(Callback):
|
|
|
47
53
|
):
|
|
48
54
|
return
|
|
49
55
|
|
|
50
|
-
norm = compute_norm(
|
|
51
|
-
pl_module,
|
|
52
|
-
optimizer,
|
|
53
|
-
self.config.norm_type,
|
|
54
|
-
grad=True,
|
|
55
|
-
)
|
|
56
|
+
norm = compute_norm(pl_module, optimizer, self.config.norm_type, grad=True)
|
|
56
57
|
|
|
57
58
|
# If the norm is NaN/Inf, we don't want to skip the step
|
|
58
59
|
# because AMP checks for NaN/Inf grads to adjust the loss scale.
|
nshtrainer/data/transform.py
CHANGED
|
@@ -22,7 +22,13 @@ def transform(
|
|
|
22
22
|
deepcopy: Whether to deep copy each item before applying the transform.
|
|
23
23
|
"""
|
|
24
24
|
|
|
25
|
-
|
|
25
|
+
try:
|
|
26
|
+
import wrapt
|
|
27
|
+
except ImportError:
|
|
28
|
+
raise ImportError(
|
|
29
|
+
"wrapt is not installed. wrapt is required for the transform function."
|
|
30
|
+
"Please install it using 'pip install wrapt'"
|
|
31
|
+
)
|
|
26
32
|
|
|
27
33
|
class _TransformedDataset(wrapt.ObjectProxy):
|
|
28
34
|
def __getitem__(self, idx):
|
|
@@ -52,7 +58,13 @@ def transform_with_index(
|
|
|
52
58
|
deepcopy: Whether to deep copy each item before applying the transform.
|
|
53
59
|
"""
|
|
54
60
|
|
|
55
|
-
|
|
61
|
+
try:
|
|
62
|
+
import wrapt
|
|
63
|
+
except ImportError:
|
|
64
|
+
raise ImportError(
|
|
65
|
+
"wrapt is not installed. wrapt is required for the transform function."
|
|
66
|
+
"Please install it using 'pip install wrapt'"
|
|
67
|
+
)
|
|
56
68
|
|
|
57
69
|
class _TransformedWithIndexDataset(wrapt.ObjectProxy):
|
|
58
70
|
def __getitem__(self, idx: int):
|
nshtrainer/model/_environment.py
CHANGED
|
@@ -9,7 +9,6 @@ from datetime import timedelta
|
|
|
9
9
|
from pathlib import Path
|
|
10
10
|
from typing import TYPE_CHECKING, Any, cast
|
|
11
11
|
|
|
12
|
-
import git
|
|
13
12
|
import nshconfig as C
|
|
14
13
|
import psutil
|
|
15
14
|
import torch
|
|
@@ -618,6 +617,11 @@ class GitRepositoryConfig(C.Config):
|
|
|
618
617
|
|
|
619
618
|
@classmethod
|
|
620
619
|
def from_current_directory(cls):
|
|
620
|
+
try:
|
|
621
|
+
import git
|
|
622
|
+
except ImportError:
|
|
623
|
+
return cls()
|
|
624
|
+
|
|
621
625
|
draft = cls.draft()
|
|
622
626
|
try:
|
|
623
627
|
repo = git.Repo(os.getcwd(), search_parent_directories=True)
|
nshtrainer/model/config.py
CHANGED
|
@@ -286,6 +286,10 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
|
|
|
286
286
|
offline: bool = False
|
|
287
287
|
"""Whether to run WandB in offline mode."""
|
|
288
288
|
|
|
289
|
+
def offline_(self):
|
|
290
|
+
self.offline = True
|
|
291
|
+
return self
|
|
292
|
+
|
|
289
293
|
@override
|
|
290
294
|
def create_logger(self, root_config):
|
|
291
295
|
if not self.enabled:
|
|
@@ -12,7 +12,7 @@ nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,
|
|
|
12
12
|
nshtrainer/callbacks/early_stopping.py,sha256=jriSU761wf_qTJ9Bos0D3h5aDvZHYpRqK62Ne8aWp5I,3768
|
|
13
13
|
nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
|
|
14
14
|
nshtrainer/callbacks/finite_checks.py,sha256=AO5fa51uANAjAkeJfTquOjK6W_4RSU5Kky3f5jmAPlQ,2084
|
|
15
|
-
nshtrainer/callbacks/gradient_skipping.py,sha256=
|
|
15
|
+
nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
|
|
16
16
|
nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
|
|
17
17
|
nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=zCRAUsqW-2PaoIwVKlXOqdh2uF_B_YUUTmQO1wSomR8,2489
|
|
18
18
|
nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
|
|
@@ -25,7 +25,7 @@ nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50
|
|
|
25
25
|
nshtrainer/callbacks/wandb_watch.py,sha256=0JKDEPSDmiKverFm00lFfufOvtvr49akFXScvdUQnqc,2930
|
|
26
26
|
nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
|
|
27
27
|
nshtrainer/data/balanced_batch_sampler.py,sha256=bcJBcQjh1hB1yKF_xSlT9AtEWv0BJjYc1CuH2BF-ea8,4392
|
|
28
|
-
nshtrainer/data/transform.py,sha256=
|
|
28
|
+
nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
|
|
29
29
|
nshtrainer/ll/__init__.py,sha256=dD0ISxHJ2lg1HLSM0b3db7TBlsPpQCtChnuYO-c2oqI,2635
|
|
30
30
|
nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
|
|
31
31
|
nshtrainer/ll/actsave.py,sha256=2lbiseSrjcwFT6AiyLNWarTWl1bnzliVWlu1iOfnP30,209
|
|
@@ -50,9 +50,9 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-
|
|
|
50
50
|
nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
|
|
51
51
|
nshtrainer/metrics/_config.py,sha256=hWWS4IXENRyH3RmJ7z1Wx1n3Lt1sNMlGOrcU6PW15o0,1104
|
|
52
52
|
nshtrainer/model/__init__.py,sha256=TbexTxiE20WHYg5q3L88Hysk4LlHeKk_isv33aSBREA,1918
|
|
53
|
-
nshtrainer/model/_environment.py,sha256=
|
|
53
|
+
nshtrainer/model/_environment.py,sha256=oTtecQeF5oY2RV7UkkSLnzDy3clz4AUkf9oocD6-e54,23115
|
|
54
54
|
nshtrainer/model/base.py,sha256=Bmw-t70TydDbE9P0ee-lTibGoUhrCx5Qke-upa7FGVM,17512
|
|
55
|
-
nshtrainer/model/config.py,sha256=
|
|
55
|
+
nshtrainer/model/config.py,sha256=FRjn2pVbCzitBeIFcKXXPQ0TCmqvAwSylAvfedf1NHg,53356
|
|
56
56
|
nshtrainer/model/modules/callback.py,sha256=JF59U9-CjJsAIspEhTJbVaGN0wGctZG7UquE3IS7R8A,6408
|
|
57
57
|
nshtrainer/model/modules/debug.py,sha256=DTVty8cKnzj1GCULRyGx_sWTTsq9NLi30dzqjRTnuCU,1127
|
|
58
58
|
nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
|
|
@@ -67,7 +67,6 @@ nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,
|
|
|
67
67
|
nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
|
|
68
68
|
nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
|
|
69
69
|
nshtrainer/runner.py,sha256=6qfE5FBONzD79kVHuWYKEvK0J_Qi5dMBbHQhRMmnIhE,3649
|
|
70
|
-
nshtrainer/scripts/check_env.py,sha256=IMl6dSqsLYppI0XuCsVq8lK4bYqXwY9KHJkzsShz4Kg,806
|
|
71
70
|
nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwBfBeWeA,1467
|
|
72
71
|
nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
|
|
73
72
|
nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
|
|
@@ -79,6 +78,6 @@ nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
|
|
|
79
78
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
80
79
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
81
80
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
82
|
-
nshtrainer-0.10.
|
|
83
|
-
nshtrainer-0.10.
|
|
84
|
-
nshtrainer-0.10.
|
|
81
|
+
nshtrainer-0.10.9.dist-info/METADATA,sha256=zpMnR1Jwc9kX7BhnHT6NA3nsPUwQzMAraBPOGsUSD1w,695
|
|
82
|
+
nshtrainer-0.10.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
83
|
+
nshtrainer-0.10.9.dist-info/RECORD,,
|
nshtrainer/scripts/check_env.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
REQUIRED_PACKAGES = [
|
|
2
|
-
"beartype",
|
|
3
|
-
"cloudpickle",
|
|
4
|
-
"jaxtyping",
|
|
5
|
-
"lightning",
|
|
6
|
-
"lightning_fabric",
|
|
7
|
-
"lightning_utilities",
|
|
8
|
-
"lovely_numpy",
|
|
9
|
-
"lovely_tensors",
|
|
10
|
-
"numpy",
|
|
11
|
-
"psutil",
|
|
12
|
-
"pydantic",
|
|
13
|
-
"pydantic_core",
|
|
14
|
-
"pysnooper",
|
|
15
|
-
"rich",
|
|
16
|
-
"tabulate",
|
|
17
|
-
"torch",
|
|
18
|
-
"torchmetrics",
|
|
19
|
-
"tqdm",
|
|
20
|
-
"typing_extensions",
|
|
21
|
-
"wrapt",
|
|
22
|
-
"yaml",
|
|
23
|
-
]
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def main():
|
|
27
|
-
import importlib.util
|
|
28
|
-
import sys
|
|
29
|
-
|
|
30
|
-
missing_packages: list[str] = []
|
|
31
|
-
for package_name in REQUIRED_PACKAGES:
|
|
32
|
-
spec = importlib.util.find_spec(package_name)
|
|
33
|
-
if spec is None:
|
|
34
|
-
missing_packages.append(package_name)
|
|
35
|
-
|
|
36
|
-
if missing_packages:
|
|
37
|
-
sys.exit(f"Error: Missing required packages: {', '.join(missing_packages)}")
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
if __name__ == "__main__":
|
|
41
|
-
main()
|
|
File without changes
|