nshtrainer 0.10.8__tar.gz → 0.10.9__tar.gz

This diff compares the contents of two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
Files changed (84)
  1. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/PKG-INFO +1 -1
  2. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/pyproject.toml +5 -5
  3. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/gradient_skipping.py +13 -12
  4. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/data/transform.py +14 -2
  5. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/model/_environment.py +5 -1
  6. nshtrainer-0.10.8/src/nshtrainer/scripts/check_env.py +0 -41
  7. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/README.md +0 -0
  8. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/__init__.py +0 -0
  9. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/_checkpoint/loader.py +0 -0
  10. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  11. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/_experimental/__init__.py +0 -0
  12. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  13. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  14. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  15. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/__init__.py +0 -0
  16. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  17. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/actsave.py +0 -0
  18. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/base.py +0 -0
  19. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  20. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/ema.py +0 -0
  21. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  22. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/interval.py +0 -0
  23. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +0 -0
  24. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  25. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/model_checkpoint.py +0 -0
  26. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  27. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
  28. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/print_table.py +0 -0
  29. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  30. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/timer.py +0 -0
  31. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  32. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/data/__init__.py +0 -0
  33. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  34. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/__init__.py +0 -0
  35. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/_experimental.py +0 -0
  36. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/actsave.py +0 -0
  37. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/callbacks.py +0 -0
  38. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/config.py +0 -0
  39. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/data.py +0 -0
  40. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/log.py +0 -0
  41. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  42. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/model.py +0 -0
  43. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/nn.py +0 -0
  44. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/optimizer.py +0 -0
  45. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/runner.py +0 -0
  46. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/snapshot.py +0 -0
  47. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/snoop.py +0 -0
  48. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/trainer.py +0 -0
  49. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/typecheck.py +0 -0
  50. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/ll/util.py +0 -0
  51. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  52. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  53. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  54. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  55. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/metrics/__init__.py +0 -0
  56. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/metrics/_config.py +0 -0
  57. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/model/__init__.py +0 -0
  58. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/model/base.py +0 -0
  59. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/model/config.py +0 -0
  60. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/model/modules/callback.py +0 -0
  61. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/model/modules/debug.py +0 -0
  62. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/model/modules/distributed.py +0 -0
  63. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/model/modules/logger.py +0 -0
  64. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/model/modules/profiler.py +0 -0
  65. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  66. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  67. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/nn/__init__.py +0 -0
  68. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/nn/mlp.py +0 -0
  69. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/nn/module_dict.py +0 -0
  70. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/nn/module_list.py +0 -0
  71. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/nn/nonlinearity.py +0 -0
  72. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/optimizer.py +0 -0
  73. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/runner.py +0 -0
  74. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/scripts/find_packages.py +0 -0
  75. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/trainer/__init__.py +0 -0
  76. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  77. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  78. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/trainer/signal_connector.py +0 -0
  79. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/trainer/trainer.py +0 -0
  80. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/util/environment.py +0 -0
  81. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/util/seed.py +0 -0
  82. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/util/slurm.py +0 -0
  83. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/util/typed.py +0 -0
  84. {nshtrainer-0.10.8 → nshtrainer-0.10.9}/src/nshtrainer/util/typing_utils.py +0 -0
--- nshtrainer-0.10.8/PKG-INFO
+++ nshtrainer-0.10.9/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.10.8
+Version: 0.10.9
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
--- nshtrainer-0.10.8/pyproject.toml
+++ nshtrainer-0.10.9/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.10.8"
+version = "0.10.9"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -11,14 +11,14 @@ nshrunner = "*"
 nshconfig = "*"
 nshutils = "*"
 psutil = "*"
+numpy = "*"
 torch = "*"
 typing-extensions = "*"
 lightning = "*"
 pytorch-lightning = "*"
-torchmetrics = "*"
-numpy = "*"
-wrapt = "*"
-GitPython = "*"
+torchmetrics = { version = "*", optional = true }
+wrapt = { version = "*", optional = true }
+GitPython = { version = "*", optional = true }

 [tool.poetry.group.dev.dependencies]
 pyright = "^1.1.372"
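The substantive change here is that torchmetrics, wrapt, and GitPython move from hard dependencies to optional ones (numpy also moves up the list, unchanged), so importing nshtrainer no longer pulls them in. Note that under Poetry, dependencies marked optional = true only become installable once they are referenced from a [tool.poetry.extras] table, which this diff does not show. The code changes below replace the unconditional imports with guards; a minimal sketch of that guard pattern follows (the helper name _require and its message wording are illustrative, not from the package):

import importlib.util


def _require(module: str, feature: str) -> None:
    # Raise a helpful ImportError when an optional dependency is missing,
    # without actually importing it.
    if importlib.util.find_spec(module) is None:
        raise ImportError(
            f"{feature} requires the optional dependency '{module}'. "
            f"Install it with: pip install {module}"
        )


# e.g. before constructing a callback that records metrics:
_require("torchmetrics", "GradientSkipping")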
--- nshtrainer-0.10.8/src/nshtrainer/callbacks/gradient_skipping.py
+++ nshtrainer-0.10.9/src/nshtrainer/callbacks/gradient_skipping.py
@@ -1,8 +1,8 @@
-from logging import getLogger
-from typing import Literal, Protocol, runtime_checkable
+import importlib.util
+import logging
+from typing import Any, Literal, Protocol, runtime_checkable

 import torch
-import torchmetrics
 from lightning.pytorch import Callback, LightningModule, Trainer
 from torch.optim import Optimizer
 from typing_extensions import override
@@ -10,23 +10,29 @@ from typing_extensions import override
 from .base import CallbackConfigBase
 from .norm_logging import compute_norm

-log = getLogger(__name__)
+log = logging.getLogger(__name__)


 @runtime_checkable
 class HasGradSkippedSteps(Protocol):
-    grad_skipped_steps: torchmetrics.SumMetric
+    grad_skipped_steps: Any


 class GradientSkipping(Callback):
     def __init__(self, config: "GradientSkippingConfig"):
-        super().__init__()
+        if importlib.util.find_spec("torchmetrics") is None:
+            raise ImportError(
+                "To use the GradientSkipping callback, please install torchmetrics: pip install torchmetrics"
+            )

+        super().__init__()
         self.config = config

     @override
     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
         if not isinstance(pl_module, HasGradSkippedSteps):
+            import torchmetrics  # type: ignore
+
             pl_module.grad_skipped_steps = torchmetrics.SumMetric()

     @override
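The HasGradSkippedSteps protocol is runtime_checkable, so the isinstance check in setup() is structural: it passes for any module that already carries a grad_skipped_steps attribute, whatever its class, and the metric is therefore attached only once. A self-contained sketch of that pattern (HasCounter and Dummy are illustrative names):

from typing import Protocol, runtime_checkable


@runtime_checkable
class HasCounter(Protocol):
    counter: int


class Dummy:
    pass


obj = Dummy()
assert not isinstance(obj, HasCounter)  # no `counter` attribute yet
obj.counter = 0
assert isinstance(obj, HasCounter)  # structural check now passes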
@@ -47,12 +53,7 @@ class GradientSkipping(Callback):
         ):
             return

-        norm = compute_norm(
-            pl_module,
-            optimizer,
-            self.config.norm_type,
-            grad=True,
-        )
+        norm = compute_norm(pl_module, optimizer, self.config.norm_type, grad=True)

         # If the norm is NaN/Inf, we don't want to skip the step
         # because AMP checks for NaN/Inf grads to adjust the loss scale.
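The comment above captures the one subtlety of gradient skipping under AMP: a non-finite norm must not trigger a skip, because the GradScaler relies on observing NaN/Inf gradients to back off its loss scale. A condensed sketch of the skip criterion under that rule (the threshold argument stands in for what the real callback reads from GradientSkippingConfig):

import math

import torch


def should_skip_step(grad_norm: torch.Tensor, threshold: float) -> bool:
    value = float(grad_norm)
    # Never skip on NaN/Inf: AMP's GradScaler watches for non-finite
    # gradients to decide when to lower the loss scale, so those steps
    # must be allowed to reach the optimizer.
    if not math.isfinite(value):
        return False
    return value > threshold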
--- nshtrainer-0.10.8/src/nshtrainer/data/transform.py
+++ nshtrainer-0.10.9/src/nshtrainer/data/transform.py
@@ -22,7 +22,13 @@ def transform(
         deepcopy: Whether to deep copy each item before applying the transform.
     """

-    import wrapt
+    try:
+        import wrapt
+    except ImportError:
+        raise ImportError(
+            "wrapt is not installed. wrapt is required for the transform function. "
+            "Please install it using 'pip install wrapt'"
+        )

     class _TransformedDataset(wrapt.ObjectProxy):
         def __getitem__(self, idx):
@@ -52,7 +58,13 @@ def transform_with_index(
         deepcopy: Whether to deep copy each item before applying the transform.
     """

-    import wrapt
+    try:
+        import wrapt
+    except ImportError:
+        raise ImportError(
+            "wrapt is not installed. wrapt is required for the transform function. "
+            "Please install it using 'pip install wrapt'"
+        )

     class _TransformedWithIndexDataset(wrapt.ObjectProxy):
         def __getitem__(self, idx: int):
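Both transform functions wrap the dataset in a wrapt.ObjectProxy, which forwards every attribute and method to the wrapped object, so the result still behaves like the original dataset (same len, same custom attributes) except for the overridden __getitem__. A minimal sketch of the idea, assuming wrapt is installed (_Transformed is an illustrative name, not the package's class):

import wrapt


class _Transformed(wrapt.ObjectProxy):
    # Proxy a dataset, applying `fn` to each retrieved item.

    def __init__(self, wrapped, fn):
        super().__init__(wrapped)
        # ObjectProxy forwards plain attribute writes to the wrapped
        # object, so proxy-local state must use the _self_ prefix.
        self._self_fn = fn

    def __getitem__(self, idx):
        return self._self_fn(self.__wrapped__[idx])


data = _Transformed([1, 2, 3], lambda x: x * 10)
assert data[1] == 20
assert len(data) == 3  # forwarded to the underlying list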
--- nshtrainer-0.10.8/src/nshtrainer/model/_environment.py
+++ nshtrainer-0.10.9/src/nshtrainer/model/_environment.py
@@ -9,7 +9,6 @@ from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast

-import git
 import nshconfig as C
 import psutil
 import torch
@@ -618,6 +617,11 @@ class GitRepositoryConfig(C.Config):

     @classmethod
     def from_current_directory(cls):
+        try:
+            import git
+        except ImportError:
+            return cls()
+
         draft = cls.draft()
         try:
             repo = git.Repo(os.getcwd(), search_parent_directories=True)
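from_current_directory now degrades gracefully: when GitPython is absent it returns an empty config instead of failing at import time, matching the removal of the module-level import git above. For reference, a sketch of the repository metadata GitPython exposes (the printed fields are illustrative, not the actual GitRepositoryConfig schema; note that active_branch raises on a detached HEAD):

import os

import git  # GitPython, now an optional dependency

repo = git.Repo(os.getcwd(), search_parent_directories=True)
print(repo.head.commit.hexsha)  # current commit hash
print(repo.active_branch.name)  # checked-out branch
print(repo.is_dirty())          # uncommitted changes present?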
--- nshtrainer-0.10.8/src/nshtrainer/scripts/check_env.py
+++ /dev/null
@@ -1,41 +0,0 @@
-REQUIRED_PACKAGES = [
-    "beartype",
-    "cloudpickle",
-    "jaxtyping",
-    "lightning",
-    "lightning_fabric",
-    "lightning_utilities",
-    "lovely_numpy",
-    "lovely_tensors",
-    "numpy",
-    "psutil",
-    "pydantic",
-    "pydantic_core",
-    "pysnooper",
-    "rich",
-    "tabulate",
-    "torch",
-    "torchmetrics",
-    "tqdm",
-    "typing_extensions",
-    "wrapt",
-    "yaml",
-]
-
-
-def main():
-    import importlib.util
-    import sys
-
-    missing_packages: list[str] = []
-    for package_name in REQUIRED_PACKAGES:
-        spec = importlib.util.find_spec(package_name)
-        if spec is None:
-            missing_packages.append(package_name)
-
-    if missing_packages:
-        sys.exit(f"Error: Missing required packages: {', '.join(missing_packages)}")
-
-
-if __name__ == "__main__":
-    main()
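With torchmetrics, wrapt, and GitPython now optional, an up-front check against a fixed package list no longer describes a valid install, so the script is dropped in favor of the per-import guards added above. The same find_spec probe still works for an ad-hoc availability report (a sketch; the feature descriptions are illustrative):

import importlib.util

# Keys are import names, not distribution names ("git" is GitPython's module).
OPTIONAL = {
    "torchmetrics": "GradientSkipping callback",
    "wrapt": "dataset transform helpers (nshtrainer.data.transform)",
    "git": "git metadata in the environment snapshot",
}

for module, feature in OPTIONAL.items():
    status = "ok" if importlib.util.find_spec(module) is not None else "missing"
    print(f"{module:14s}{status:10s}{feature}")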