nshtrainer 0.10.7__py3-none-any.whl → 0.10.9__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
nshtrainer/callbacks/gradient_skipping.py

@@ -1,8 +1,8 @@
-from logging import getLogger
-from typing import Literal, Protocol, runtime_checkable
+import importlib.util
+import logging
+from typing import Any, Literal, Protocol, runtime_checkable
 
 import torch
-import torchmetrics
 from lightning.pytorch import Callback, LightningModule, Trainer
 from torch.optim import Optimizer
 from typing_extensions import override
@@ -10,23 +10,29 @@ from typing_extensions import override
 from .base import CallbackConfigBase
 from .norm_logging import compute_norm
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 @runtime_checkable
 class HasGradSkippedSteps(Protocol):
-    grad_skipped_steps: torchmetrics.SumMetric
+    grad_skipped_steps: Any
 
 
 class GradientSkipping(Callback):
     def __init__(self, config: "GradientSkippingConfig"):
-        super().__init__()
+        if importlib.util.find_spec("torchmetrics") is None:
+            raise ImportError(
+                "To use the GradientSkipping callback, please install torchmetrics: pip install torchmetrics"
+            )
 
+        super().__init__()
         self.config = config
 
     @override
     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
         if not isinstance(pl_module, HasGradSkippedSteps):
+            import torchmetrics  # type: ignore
+
             pl_module.grad_skipped_steps = torchmetrics.SumMetric()
 
     @override
@@ -47,12 +53,7 @@ class GradientSkipping(Callback):
         ):
             return
 
-        norm = compute_norm(
-            pl_module,
-            optimizer,
-            self.config.norm_type,
-            grad=True,
-        )
+        norm = compute_norm(pl_module, optimizer, self.config.norm_type, grad=True)
 
         # If the norm is NaN/Inf, we don't want to skip the step
         # because AMP checks for NaN/Inf grads to adjust the loss scale.
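
A note on the pattern in the two hunks above: torchmetrics is now treated as an optional dependency. The constructor uses importlib.util.find_spec to fail fast with an install hint when the package is absent, while the actual import torchmetrics is deferred to setup(), so the module itself imports cleanly without it. A minimal sketch of that guard-plus-lazy-import pattern (the helper and class below are illustrative, not nshtrainer's API):

    import importlib.util


    def _require(package: str, hint: str) -> None:
        # Fail at construction time if an optional dependency is missing,
        # rather than deep inside a training run.
        if importlib.util.find_spec(package) is None:
            raise ImportError(f"{package} is required. Install it with '{hint}'.")


    class NeedsTorchmetrics:
        def __init__(self) -> None:
            _require("torchmetrics", "pip install torchmetrics")

        def setup(self) -> None:
            import torchmetrics  # deferred: only imported when actually used

            self.metric = torchmetrics.SumMetric()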

nshtrainer/data/transform.py

@@ -22,7 +22,13 @@ def transform(
         deepcopy: Whether to deep copy each item before applying the transform.
     """
 
-    import wrapt
+    try:
+        import wrapt
+    except ImportError:
+        raise ImportError(
+            "wrapt is not installed. wrapt is required for the transform function. "
+            "Please install it using 'pip install wrapt'"
+        )
 
     class _TransformedDataset(wrapt.ObjectProxy):
         def __getitem__(self, idx):
@@ -52,7 +58,13 @@ def transform_with_index(
         deepcopy: Whether to deep copy each item before applying the transform.
     """
 
-    import wrapt
+    try:
+        import wrapt
+    except ImportError:
+        raise ImportError(
+            "wrapt is not installed. wrapt is required for the transform function. "
+            "Please install it using 'pip install wrapt'"
+        )
 
     class _TransformedWithIndexDataset(wrapt.ObjectProxy):
        def __getitem__(self, idx: int):
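
For context, both transform and transform_with_index wrap the dataset in a wrapt.ObjectProxy subclass: the proxy transparently forwards every attribute and method of the wrapped dataset while overriding __getitem__ to apply the transform on the way out. A minimal sketch of the idea, assuming a map-style dataset (the helper name is illustrative):

    import copy

    import wrapt


    def transformed(dataset, fn, deepcopy: bool = False):
        # Behaves exactly like `dataset` (len(), attributes, methods), except
        # that every item is passed through `fn` before being returned.
        class _Transformed(wrapt.ObjectProxy):
            def __getitem__(self, idx):
                item = self.__wrapped__[idx]
                if deepcopy:
                    item = copy.deepcopy(item)
                return fn(item)

        return _Transformed(dataset)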

nshtrainer/model/_environment.py

@@ -9,7 +9,6 @@ from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
 
-import git
 import nshconfig as C
 import psutil
 import torch
@@ -618,6 +617,11 @@ class GitRepositoryConfig(C.Config):
 
     @classmethod
     def from_current_directory(cls):
+        try:
+            import git
+        except ImportError:
+            return cls()
+
         draft = cls.draft()
         try:
             repo = git.Repo(os.getcwd(), search_parent_directories=True)
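
Note that this hunk takes the opposite approach to the torchmetrics and wrapt changes: instead of raising with an install hint, a missing GitPython makes from_current_directory silently return a default config. The two strategies side by side, as a sketch (these helpers are illustrative, not part of the package):

    def import_or_raise(name: str, hint: str):
        # Strict: an absent optional dependency is a user error.
        try:
            return __import__(name)
        except ImportError:
            raise ImportError(f"{name} is required. Install it with '{hint}'.")


    def import_or_none(name: str):
        # Lenient: an absent optional dependency just disables a feature.
        try:
            return __import__(name)
        except ImportError:
            return None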

nshtrainer/model/config.py

@@ -286,6 +286,10 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
     offline: bool = False
     """Whether to run WandB in offline mode."""
 
+    def offline_(self):
+        self.offline = True
+        return self
+
     @override
     def create_logger(self, root_config):
         if not self.enabled:
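
The new offline_() method is a chainable setter: it flips the flag in place and returns self, so the config can be built inline. A usage sketch (assuming WandbLoggerConfig can be constructed with its defaults):

    config = WandbLoggerConfig().offline_()  # returns self, so calls can chain
    assert config.offline is True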

nshtrainer-0.10.9.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.10.7
+Version: 0.10.9
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com

nshtrainer-0.10.9.dist-info/RECORD

@@ -12,7 +12,7 @@ nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,
 nshtrainer/callbacks/early_stopping.py,sha256=jriSU761wf_qTJ9Bos0D3h5aDvZHYpRqK62Ne8aWp5I,3768
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
 nshtrainer/callbacks/finite_checks.py,sha256=AO5fa51uANAjAkeJfTquOjK6W_4RSU5Kky3f5jmAPlQ,2084
-nshtrainer/callbacks/gradient_skipping.py,sha256=fSJpjgHbztFKz7w3qFuCHZpmbEt9BCLAy-sU0B4xJQI,3474
+nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
 nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
 nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=zCRAUsqW-2PaoIwVKlXOqdh2uF_B_YUUTmQO1wSomR8,2489
 nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
@@ -25,7 +25,7 @@ nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50
 nshtrainer/callbacks/wandb_watch.py,sha256=0JKDEPSDmiKverFm00lFfufOvtvr49akFXScvdUQnqc,2930
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=bcJBcQjh1hB1yKF_xSlT9AtEWv0BJjYc1CuH2BF-ea8,4392
-nshtrainer/data/transform.py,sha256=JeGxvytQly8hougrsdMmKG8gJ6qvFPDglJCO4Tp6STk,1795
+nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
 nshtrainer/ll/__init__.py,sha256=dD0ISxHJ2lg1HLSM0b3db7TBlsPpQCtChnuYO-c2oqI,2635
 nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
 nshtrainer/ll/actsave.py,sha256=2lbiseSrjcwFT6AiyLNWarTWl1bnzliVWlu1iOfnP30,209
@@ -50,9 +50,9 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-
 nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
 nshtrainer/metrics/_config.py,sha256=hWWS4IXENRyH3RmJ7z1Wx1n3Lt1sNMlGOrcU6PW15o0,1104
 nshtrainer/model/__init__.py,sha256=TbexTxiE20WHYg5q3L88Hysk4LlHeKk_isv33aSBREA,1918
-nshtrainer/model/_environment.py,sha256=JCFxxwMhkviiMDkqIXJmiuepqiSYIlcoSQM7Y2H2KX4,23036
+nshtrainer/model/_environment.py,sha256=oTtecQeF5oY2RV7UkkSLnzDy3clz4AUkf9oocD6-e54,23115
 nshtrainer/model/base.py,sha256=Bmw-t70TydDbE9P0ee-lTibGoUhrCx5Qke-upa7FGVM,17512
-nshtrainer/model/config.py,sha256=6pAqDUk1eBloR3vZmtsWVdMrKeT2V3UvOn5UZ7YhZ_Q,53283
+nshtrainer/model/config.py,sha256=FRjn2pVbCzitBeIFcKXXPQ0TCmqvAwSylAvfedf1NHg,53356
 nshtrainer/model/modules/callback.py,sha256=JF59U9-CjJsAIspEhTJbVaGN0wGctZG7UquE3IS7R8A,6408
 nshtrainer/model/modules/debug.py,sha256=DTVty8cKnzj1GCULRyGx_sWTTsq9NLi30dzqjRTnuCU,1127
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -67,7 +67,6 @@ nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,
 nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
 nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
 nshtrainer/runner.py,sha256=6qfE5FBONzD79kVHuWYKEvK0J_Qi5dMBbHQhRMmnIhE,3649
-nshtrainer/scripts/check_env.py,sha256=IMl6dSqsLYppI0XuCsVq8lK4bYqXwY9KHJkzsShz4Kg,806
 nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwBfBeWeA,1467
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
@@ -79,6 +78,6 @@ nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.10.7.dist-info/METADATA,sha256=IQ6IEecsAvygnoV5P6_mkG9RjRGnb_cFuOf2Ic2HLIY,695
-nshtrainer-0.10.7.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.10.7.dist-info/RECORD,,
+nshtrainer-0.10.9.dist-info/METADATA,sha256=zpMnR1Jwc9kX7BhnHT6NA3nsPUwQzMAraBPOGsUSD1w,695
+nshtrainer-0.10.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.10.9.dist-info/RECORD,,

nshtrainer/scripts/check_env.py (deleted)

@@ -1,41 +0,0 @@
-REQUIRED_PACKAGES = [
-    "beartype",
-    "cloudpickle",
-    "jaxtyping",
-    "lightning",
-    "lightning_fabric",
-    "lightning_utilities",
-    "lovely_numpy",
-    "lovely_tensors",
-    "numpy",
-    "psutil",
-    "pydantic",
-    "pydantic_core",
-    "pysnooper",
-    "rich",
-    "tabulate",
-    "torch",
-    "torchmetrics",
-    "tqdm",
-    "typing_extensions",
-    "wrapt",
-    "yaml",
-]
-
-
-def main():
-    import importlib.util
-    import sys
-
-    missing_packages: list[str] = []
-    for package_name in REQUIRED_PACKAGES:
-        spec = importlib.util.find_spec(package_name)
-        if spec is None:
-            missing_packages.append(package_name)
-
-    if missing_packages:
-        sys.exit(f"Error: Missing required packages: {', '.join(missing_packages)}")
-
-
-if __name__ == "__main__":
-    main()