nshtrainer 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,8 @@
1
- import importlib.util
2
1
  import logging
3
2
  from typing import Any, Literal, Protocol, runtime_checkable
4
3
 
5
4
  import torch
5
+ import torchmetrics
6
6
  from lightning.pytorch import Callback, LightningModule, Trainer
7
7
  from torch.optim import Optimizer
8
8
  from typing_extensions import override
@@ -20,19 +20,12 @@ class HasGradSkippedSteps(Protocol):
20
20
 
21
21
  class GradientSkipping(Callback):
22
22
  def __init__(self, config: "GradientSkippingConfig"):
23
- if importlib.util.find_spec("torchmetrics") is not None:
24
- raise ImportError(
25
- "To use the GradientSkipping callback, please install torchmetrics: pip install torchmetrics"
26
- )
27
-
28
23
  super().__init__()
29
24
  self.config = config
30
25
 
31
26
  @override
32
27
  def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
33
28
  if not isinstance(pl_module, HasGradSkippedSteps):
34
- import torchmetrics # type: ignore
35
-
36
29
  pl_module.grad_skipped_steps = torchmetrics.SumMetric()
37
30
 
38
31
  @override
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.25.0
3
+ Version: 0.26.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -22,7 +22,7 @@ Requires-Dist: psutil
22
22
  Requires-Dist: pytorch-lightning
23
23
  Requires-Dist: tensorboard ; extra == "extra"
24
24
  Requires-Dist: torch
25
- Requires-Dist: torchmetrics ; extra == "extra"
25
+ Requires-Dist: torchmetrics
26
26
  Requires-Dist: typing-extensions
27
27
  Requires-Dist: wandb ; extra == "extra"
28
28
  Requires-Dist: wrapt ; extra == "extra"
@@ -17,7 +17,7 @@ nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6
17
17
  nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
18
18
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
19
19
  nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
20
- nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
20
+ nshtrainer/callbacks/gradient_skipping.py,sha256=EBNkANDnD3BTszWjnG-jwY8FEj-iRqhE3e1x5LQF6M8,3393
21
21
  nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
22
22
  nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
23
23
  nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
@@ -87,6 +87,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.25.0.dist-info/METADATA,sha256=Rqdeh2yp2AhZ_nOHlD47v5YPDrLc2fHN6WGwqJnDv04,935
91
- nshtrainer-0.25.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.25.0.dist-info/RECORD,,
90
+ nshtrainer-0.26.0.dist-info/METADATA,sha256=YBlbpalQ3BX8UBF_5SHk_F7v9Nq3JMsqVf6MoqH8KzU,916
91
+ nshtrainer-0.26.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.26.0.dist-info/RECORD,,