nshtrainer 0.10.8__tar.gz → 0.10.10__tar.gz

This diff shows the changes between two publicly released versions of the package, as published to the public registry. It is provided for informational purposes only.
Files changed (84)
  1. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/PKG-INFO +1 -1
  2. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/pyproject.toml +5 -5
  3. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/early_stopping.py +2 -2
  4. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/finite_checks.py +2 -2
  5. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/gradient_skipping.py +13 -12
  6. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/model_checkpoint.py +2 -2
  7. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/norm_logging.py +2 -2
  8. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/throughput_monitor.py +2 -2
  9. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/wandb_watch.py +2 -2
  10. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/data/balanced_batch_sampler.py +2 -2
  11. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/data/transform.py +14 -2
  12. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/_environment.py +5 -1
  13. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/base.py +2 -2
  14. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/config.py +2 -2
  15. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/callback.py +2 -2
  16. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/debug.py +2 -2
  17. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/rlp_sanity_checks.py +2 -2
  18. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/shared_parameters.py +2 -2
  19. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/environment.py +2 -2
  20. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/seed.py +2 -2
  21. nshtrainer-0.10.8/src/nshtrainer/scripts/check_env.py +0 -41
  22. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/README.md +0 -0
  23. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/__init__.py +0 -0
  24. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_checkpoint/loader.py +0 -0
  25. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  26. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_experimental/__init__.py +0 -0
  27. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  28. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  29. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  30. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/__init__.py +0 -0
  31. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  32. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/actsave.py +0 -0
  33. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/base.py +0 -0
  34. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/ema.py +0 -0
  35. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/interval.py +0 -0
  36. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +0 -0
  37. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  38. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
  39. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/print_table.py +0 -0
  40. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/timer.py +0 -0
  41. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/data/__init__.py +0 -0
  42. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/__init__.py +0 -0
  43. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/_experimental.py +0 -0
  44. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/actsave.py +0 -0
  45. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/callbacks.py +0 -0
  46. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/config.py +0 -0
  47. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/data.py +0 -0
  48. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/log.py +0 -0
  49. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  50. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/model.py +0 -0
  51. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/nn.py +0 -0
  52. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/optimizer.py +0 -0
  53. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/runner.py +0 -0
  54. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/snapshot.py +0 -0
  55. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/snoop.py +0 -0
  56. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/trainer.py +0 -0
  57. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/typecheck.py +0 -0
  58. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/util.py +0 -0
  59. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  60. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  61. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  62. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  63. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/metrics/__init__.py +0 -0
  64. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/metrics/_config.py +0 -0
  65. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/__init__.py +0 -0
  66. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/distributed.py +0 -0
  67. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/logger.py +0 -0
  68. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/profiler.py +0 -0
  69. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/nn/__init__.py +0 -0
  70. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/nn/mlp.py +0 -0
  71. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/nn/module_dict.py +0 -0
  72. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/nn/module_list.py +0 -0
  73. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/nn/nonlinearity.py +0 -0
  74. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/optimizer.py +0 -0
  75. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/runner.py +0 -0
  76. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/scripts/find_packages.py +0 -0
  77. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/trainer/__init__.py +0 -0
  78. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  79. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  80. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/trainer/signal_connector.py +0 -0
  81. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/trainer/trainer.py +0 -0
  82. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/slurm.py +0 -0
  83. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/typed.py +0 -0
  84. {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/typing_utils.py +0 -0

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.10.8
+ Version: 0.10.10
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "nshtrainer"
- version = "0.10.8"
+ version = "0.10.10"
  description = ""
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
  readme = "README.md"
@@ -11,14 +11,14 @@ nshrunner = "*"
  nshconfig = "*"
  nshutils = "*"
  psutil = "*"
+ numpy = "*"
  torch = "*"
  typing-extensions = "*"
  lightning = "*"
  pytorch-lightning = "*"
- torchmetrics = "*"
- numpy = "*"
- wrapt = "*"
- GitPython = "*"
+ torchmetrics = { version = "*", optional = true }
+ wrapt = { version = "*", optional = true }
+ GitPython = { version = "*", optional = true }

  [tool.poetry.group.dev.dependencies]
  pyright = "^1.1.372"
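
With this change numpy becomes a required dependency, while torchmetrics, wrapt, and GitPython are now optional. Code that consumes these extras has to probe for them at runtime; below is a minimal sketch of such a probe. The optional_features helper and the script wrapper are hypothetical, not part of nshtrainer, but the package names come straight from the pyproject.toml hunk above.

import importlib.util


def optional_features() -> dict[str, bool]:
    # Map distribution names to import names; GitPython installs the
    # importable module "git".
    import_names = {"torchmetrics": "torchmetrics", "wrapt": "wrapt", "GitPython": "git"}
    # find_spec returns None when the module cannot be located, so a
    # non-None result means the optional dependency is installed.
    return {dist: importlib.util.find_spec(mod) is not None for dist, mod in import_names.items()}


if __name__ == "__main__":
    for dist, available in optional_features().items():
        print(f"{dist}: {'installed' if available else 'missing (optional)'}")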

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/early_stopping.py
@@ -1,5 +1,5 @@
+ import logging
  import math
- from logging import getLogger

  from lightning.fabric.utilities.rank_zero import _get_rank
  from lightning.pytorch import Trainer
@@ -7,7 +7,7 @@ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
  from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
  from typing_extensions import override

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  class EarlyStopping(_EarlyStopping):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/finite_checks.py
@@ -1,4 +1,4 @@
- from logging import getLogger
+ import logging
  from typing import Literal

  import torch
@@ -7,7 +7,7 @@ from typing_extensions import override

  from .base import CallbackConfigBase

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def finite_checks(

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/gradient_skipping.py
@@ -1,8 +1,8 @@
- from logging import getLogger
- from typing import Literal, Protocol, runtime_checkable
+ import importlib.util
+ import logging
+ from typing import Any, Literal, Protocol, runtime_checkable

  import torch
- import torchmetrics
  from lightning.pytorch import Callback, LightningModule, Trainer
  from torch.optim import Optimizer
  from typing_extensions import override
@@ -10,23 +10,29 @@ from typing_extensions import override
  from .base import CallbackConfigBase
  from .norm_logging import compute_norm

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  @runtime_checkable
  class HasGradSkippedSteps(Protocol):
-     grad_skipped_steps: torchmetrics.SumMetric
+     grad_skipped_steps: Any


  class GradientSkipping(Callback):
      def __init__(self, config: "GradientSkippingConfig"):
-         super().__init__()
+         if importlib.util.find_spec("torchmetrics") is None:
+             raise ImportError(
+                 "To use the GradientSkipping callback, please install torchmetrics: pip install torchmetrics"
+             )

+         super().__init__()
          self.config = config

      @override
      def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
          if not isinstance(pl_module, HasGradSkippedSteps):
+             import torchmetrics  # type: ignore
+
              pl_module.grad_skipped_steps = torchmetrics.SumMetric()

      @override
@@ -47,12 +53,7 @@ class GradientSkipping(Callback):
          ):
              return

-         norm = compute_norm(
-             pl_module,
-             optimizer,
-             self.config.norm_type,
-             grad=True,
-         )
+         norm = compute_norm(pl_module, optimizer, self.config.norm_type, grad=True)

          # If the norm is NaN/Inf, we don't want to skip the step
          # because AMP checks for NaN/Inf grads to adjust the loss scale.
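
The net effect of the gradient_skipping.py changes is a two-stage guard around the now-optional torchmetrics dependency: the constructor probes importlib.util.find_spec and raises an actionable ImportError when the package is missing, while the actual import is deferred into setup so the module itself stays importable without torchmetrics. A stripped-down sketch of that pattern follows, assuming torchmetrics' public SumMetric class; the OptionalDepCallback name is illustrative, not nshtrainer's API.

import importlib.util


class OptionalDepCallback:
    """Illustrative callback that only needs torchmetrics when it actually runs."""

    def __init__(self) -> None:
        # Fail fast at construction time if the optional dependency is absent.
        if importlib.util.find_spec("torchmetrics") is None:
            raise ImportError(
                "OptionalDepCallback requires torchmetrics: pip install torchmetrics"
            )

    def setup(self, module) -> None:
        # Defer the import so this file never imports torchmetrics at module scope.
        import torchmetrics

        if not hasattr(module, "grad_skipped_steps"):
            module.grad_skipped_steps = torchmetrics.SumMetric()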

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/model_checkpoint.py
@@ -1,6 +1,6 @@
+ import logging
  import re
  from datetime import timedelta
- from logging import getLogger
  from pathlib import Path
  from typing import TYPE_CHECKING, Literal

@@ -15,7 +15,7 @@ from .base import CallbackConfigBase
  if TYPE_CHECKING:
      from ..model.config import BaseConfig

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def _convert_string(input_string: str):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/norm_logging.py
@@ -1,4 +1,4 @@
- from logging import getLogger
+ import logging
  from typing import Literal, cast

  import torch
@@ -9,7 +9,7 @@ from typing_extensions import override

  from .base import CallbackConfigBase

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def grad_norm(

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/throughput_monitor.py
@@ -1,4 +1,4 @@
- from logging import getLogger
+ import logging
  from typing import Any, Literal, Protocol, TypedDict, cast, runtime_checkable

  from typing_extensions import NotRequired, override
@@ -6,7 +6,7 @@ from typing_extensions import NotRequired, override
  from ._throughput_monitor_callback import ThroughputMonitor as _ThroughputMonitor
  from .base import CallbackConfigBase

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  class ThroughputMonitorBatchStats(TypedDict):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/wandb_watch.py
@@ -1,4 +1,4 @@
- from logging import getLogger
+ import logging
  from typing import Literal, Protocol, cast, runtime_checkable

  import torch.nn as nn
@@ -9,7 +9,7 @@ from typing_extensions import override

  from .base import CallbackConfigBase

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  @runtime_checkable

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/data/balanced_batch_sampler.py
@@ -1,6 +1,6 @@
  import heapq
+ import logging
  from functools import cached_property
- from logging import getLogger
  from typing import Any, Protocol, runtime_checkable

  import numpy as np
@@ -10,7 +10,7 @@ from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
  from torch.utils.data import BatchSampler, Dataset, DistributedSampler
  from typing_extensions import override

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def _all_gather(tensor: torch.Tensor, device: torch.device | None = None):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/data/transform.py
@@ -22,7 +22,13 @@
          deepcopy: Whether to deep copy each item before applying the transform.
      """

-     import wrapt
+     try:
+         import wrapt
+     except ImportError:
+         raise ImportError(
+             "wrapt is not installed. wrapt is required for the transform function."
+             "Please install it using 'pip install wrapt'"
+         )

      class _TransformedDataset(wrapt.ObjectProxy):
          def __getitem__(self, idx):
@@ -52,7 +58,13 @@
          deepcopy: Whether to deep copy each item before applying the transform.
      """

-     import wrapt
+     try:
+         import wrapt
+     except ImportError:
+         raise ImportError(
+             "wrapt is not installed. wrapt is required for the transform function."
+             "Please install it using 'pip install wrapt'"
+         )

      class _TransformedWithIndexDataset(wrapt.ObjectProxy):
          def __getitem__(self, idx: int):
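
Both transform.py hunks wrap the lazy import wrapt in a try/except so that a missing optional dependency surfaces as an ImportError naming the fix rather than a bare ModuleNotFoundError deep inside data loading. For context, here is a minimal sketch of the wrapt.ObjectProxy wrapping that these functions are built on; wrap_dataset and the toy list dataset are illustrative only, not the package's API.

try:
    import wrapt
except ImportError as err:  # optional dependency, mirroring the library's guard
    raise ImportError("This sketch needs wrapt: pip install wrapt") from err


def wrap_dataset(dataset, fn):
    # ObjectProxy forwards every other attribute and method (len, metadata,
    # custom helpers) to the wrapped dataset; only __getitem__ is overridden.
    class _Transformed(wrapt.ObjectProxy):
        def __getitem__(self, idx):
            return fn(self.__wrapped__[idx])

    return _Transformed(dataset)


items = wrap_dataset([1, 2, 3], lambda x: x * 10)
assert items[1] == 20
assert len(items) == 3  # forwarded to the underlying list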

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/_environment.py
@@ -9,7 +9,6 @@ from datetime import timedelta
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, cast

- import git
  import nshconfig as C
  import psutil
  import torch
@@ -618,6 +617,11 @@ class GitRepositoryConfig(C.Config):

      @classmethod
      def from_current_directory(cls):
+         try:
+             import git
+         except ImportError:
+             return cls()
+
          draft = cls.draft()
          try:
              repo = git.Repo(os.getcwd(), search_parent_directories=True)
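
In _environment.py the module-level import git is gone; from_current_directory now imports GitPython lazily and simply returns an empty config when it is unavailable, so environment capture degrades instead of breaking every import of the module. A small sketch of the same graceful-degradation pattern using GitPython's Repo API; the describe_repo helper and its return shape are made up for illustration.

import os


def describe_repo() -> dict[str, str | None]:
    """Best-effort snapshot of the current git checkout.

    Returns empty fields when GitPython is missing or the working
    directory is not inside a repository.
    """
    try:
        import git  # GitPython, optional
    except ImportError:
        return {"commit": None, "branch": None}

    try:
        repo = git.Repo(os.getcwd(), search_parent_directories=True)
    except git.InvalidGitRepositoryError:
        return {"commit": None, "branch": None}

    branch = None if repo.head.is_detached else repo.active_branch.name
    return {"commit": repo.head.commit.hexsha, "branch": branch}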

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/base.py
@@ -1,7 +1,7 @@
  import inspect
+ import logging
  from abc import ABC, abstractmethod
  from collections.abc import MutableMapping
- from logging import getLogger
  from typing import IO, TYPE_CHECKING, Any, Generic, cast

  import torch
@@ -21,7 +21,7 @@ from .modules.profiler import ProfilerMixin
  from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
  from .modules.shared_parameters import SharedParametersModuleMixin

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)

  THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)


{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/config.py
@@ -1,4 +1,5 @@
  import copy
+ import logging
  import os
  import string
  import time
@@ -6,7 +7,6 @@ import warnings
  from abc import ABC, abstractmethod
  from collections.abc import Iterable, Sequence
  from datetime import timedelta
- from logging import getLogger
  from pathlib import Path
  from typing import (
      Annotated,
@@ -46,7 +46,7 @@ from ..callbacks.base import CallbackConfigBase
  from ..metrics import MetricConfig
  from ._environment import EnvironmentConfig

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  class IdSeedWarning(Warning):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/callback.py
@@ -1,6 +1,6 @@
+ import logging
  from collections import abc
  from collections.abc import Callable, Iterable
- from logging import getLogger
  from typing import Any, TypeAlias, cast, final

  from lightning.pytorch import Callback, LightningModule
@@ -9,7 +9,7 @@ from typing_extensions import override

  from ...util.typing_utils import mixin_base_type

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)

  CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]


{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/debug.py
@@ -1,9 +1,9 @@
- from logging import getLogger
+ import logging

  import torch
  import torch.distributed

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  class DebugModuleMixin:

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/rlp_sanity_checks.py
@@ -1,5 +1,5 @@
+ import logging
  from collections.abc import Mapping
- from logging import getLogger
  from typing import cast

  import torch
@@ -14,7 +14,7 @@ from ...util.typing_utils import mixin_base_type
  from ..config import BaseConfig
  from .callback import CallbackModuleMixin

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/shared_parameters.py
@@ -1,5 +1,5 @@
+ import logging
  from collections.abc import Sequence
- from logging import getLogger
  from typing import cast

  import torch.nn as nn
@@ -10,7 +10,7 @@ from ...util.typing_utils import mixin_base_type
  from ..config import BaseConfig
  from .callback import CallbackRegistrarModuleMixin

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def _parameters_to_names(parameters: Sequence[nn.Parameter], model: nn.Module):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/environment.py
@@ -1,8 +1,8 @@
+ import logging
  import os
  from contextlib import contextmanager
- from logging import getLogger

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  @contextmanager

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/seed.py
@@ -1,8 +1,8 @@
- from logging import getLogger
+ import logging

  import lightning.fabric.utilities.seed as LS

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def seed_everything(seed: int | None, *, workers: bool = False):

nshtrainer-0.10.8/src/nshtrainer/scripts/check_env.py
@@ -1,41 +0,0 @@
- REQUIRED_PACKAGES = [
-     "beartype",
-     "cloudpickle",
-     "jaxtyping",
-     "lightning",
-     "lightning_fabric",
-     "lightning_utilities",
-     "lovely_numpy",
-     "lovely_tensors",
-     "numpy",
-     "psutil",
-     "pydantic",
-     "pydantic_core",
-     "pysnooper",
-     "rich",
-     "tabulate",
-     "torch",
-     "torchmetrics",
-     "tqdm",
-     "typing_extensions",
-     "wrapt",
-     "yaml",
- ]
-
-
- def main():
-     import importlib.util
-     import sys
-
-     missing_packages: list[str] = []
-     for package_name in REQUIRED_PACKAGES:
-         spec = importlib.util.find_spec(package_name)
-         if spec is None:
-             missing_packages.append(package_name)
-
-     if missing_packages:
-         sys.exit(f"Error: Missing required packages: {', '.join(missing_packages)}")
-
-
- if __name__ == "__main__":
-     main()