nshtrainer 0.10.8__tar.gz → 0.10.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/PKG-INFO +1 -1
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/pyproject.toml +5 -5
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/early_stopping.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/finite_checks.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/gradient_skipping.py +13 -12
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/model_checkpoint.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/norm_logging.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/throughput_monitor.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/wandb_watch.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/data/balanced_batch_sampler.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/data/transform.py +14 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/_environment.py +5 -1
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/base.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/config.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/callback.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/debug.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/rlp_sanity_checks.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/shared_parameters.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/environment.py +2 -2
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/seed.py +2 -2
- nshtrainer-0.10.8/src/nshtrainer/scripts/check_env.py +0 -41
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/README.md +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/util/typing_utils.py +0 -0

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.10.8"
+version = "0.10.10"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -11,14 +11,14 @@ nshrunner = "*"
 nshconfig = "*"
 nshutils = "*"
 psutil = "*"
+numpy = "*"
 torch = "*"
 typing-extensions = "*"
 lightning = "*"
 pytorch-lightning = "*"
-torchmetrics = "*"
-
-
-GitPython = "*"
+torchmetrics = { version = "*", optional = true }
+wrapt = { version = "*", optional = true }
+GitPython = { version = "*", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 pyright = "^1.1.372"

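Note: this release makes torchmetrics, wrapt, and GitPython optional and adds numpy as a required dependency. As a quick, hypothetical check (not part of nshtrainer) of which now-optional packages are present in an environment:

```python
# Hypothetical helper, not shipped with nshtrainer: report which of the
# now-optional dependencies are importable in the current environment.
import importlib.util

for module_name in ("torchmetrics", "wrapt", "git"):  # GitPython installs the "git" module
    present = importlib.util.find_spec(module_name) is not None
    print(f"{module_name}: {'installed' if present else 'missing'}")
```
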
{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/early_stopping.py

@@ -1,5 +1,5 @@
+import logging
 import math
-from logging import getLogger
 
 from lightning.fabric.utilities.rank_zero import _get_rank
 from lightning.pytorch import Trainer
@@ -7,7 +7,7 @@ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
 from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
 from typing_extensions import override
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 class EarlyStopping(_EarlyStopping):

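The same logging change repeats across most of the files below: `from logging import getLogger` is dropped in favor of `import logging`, and module loggers are created with `logging.getLogger(__name__)`. A minimal sketch of the resulting pattern (the function name is illustrative):

```python
# Minimal sketch of the module-level logger pattern used after this change.
import logging

log = logging.getLogger(__name__)  # same behavior as the previous getLogger(__name__)


def example() -> None:
    log.info("message routed through the stdlib logging module")
```
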
{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/finite_checks.py

@@ -1,4 +1,4 @@
-from logging import getLogger
+import logging
 from typing import Literal
 
 import torch
@@ -7,7 +7,7 @@ from typing_extensions import override
 
 from .base import CallbackConfigBase
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def finite_checks(

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/gradient_skipping.py

@@ -1,8 +1,8 @@
-from logging import getLogger
-from typing import Literal, Protocol, runtime_checkable
+import importlib.util
+import logging
+from typing import Any, Literal, Protocol, runtime_checkable
 
 import torch
-import torchmetrics
 from lightning.pytorch import Callback, LightningModule, Trainer
 from torch.optim import Optimizer
 from typing_extensions import override
@@ -10,23 +10,29 @@ from typing_extensions import override
 from .base import CallbackConfigBase
 from .norm_logging import compute_norm
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 @runtime_checkable
 class HasGradSkippedSteps(Protocol):
-    grad_skipped_steps: torchmetrics.SumMetric
+    grad_skipped_steps: Any
 
 
 class GradientSkipping(Callback):
     def __init__(self, config: "GradientSkippingConfig"):
-        super().__init__()
+        if importlib.util.find_spec("torchmetrics") is not None:
+            raise ImportError(
+                "To use the GradientSkipping callback, please install torchmetrics: pip install torchmetrics"
+            )
 
+        super().__init__()
         self.config = config
 
     @override
     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
         if not isinstance(pl_module, HasGradSkippedSteps):
+            import torchmetrics  # type: ignore
+
             pl_module.grad_skipped_steps = torchmetrics.SumMetric()
 
     @override
@@ -47,12 +53,7 @@ class GradientSkipping(Callback):
         ):
             return
 
-        norm = compute_norm(
-            pl_module,
-            optimizer,
-            self.config.norm_type,
-            grad=True,
-        )
+        norm = compute_norm(pl_module, optimizer, self.config.norm_type, grad=True)
 
         # If the norm is NaN/Inf, we don't want to skip the step
         # beacuse AMP checks for NaN/Inf grads to adjust the loss scale.

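The callback now defers `import torchmetrics` into `setup()` and probes for the package in the constructor via `importlib.util.find_spec`. For reference, a standalone sketch (not nshtrainer's code) of a fail-fast guard for an optional dependency, written with the conventional `is None` test:

```python
# Standalone sketch of a fail-fast guard for an optional dependency.
import importlib.util


def require_torchmetrics() -> None:
    """Raise a helpful error if torchmetrics is not importable."""
    if importlib.util.find_spec("torchmetrics") is None:
        raise ImportError(
            "To use the GradientSkipping callback, please install torchmetrics: "
            "pip install torchmetrics"
        )
```
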
{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/model_checkpoint.py

@@ -1,6 +1,6 @@
+import logging
 import re
 from datetime import timedelta
-from logging import getLogger
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal
 
@@ -15,7 +15,7 @@ from .base import CallbackConfigBase
 if TYPE_CHECKING:
     from ..model.config import BaseConfig
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def _convert_string(input_string: str):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/norm_logging.py

@@ -1,4 +1,4 @@
-from logging import getLogger
+import logging
 from typing import Literal, cast
 
 import torch
@@ -9,7 +9,7 @@ from typing_extensions import override
 
 from .base import CallbackConfigBase
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def grad_norm(

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/throughput_monitor.py

@@ -1,4 +1,4 @@
-from logging import getLogger
+import logging
 from typing import Any, Literal, Protocol, TypedDict, cast, runtime_checkable
 
 from typing_extensions import NotRequired, override
@@ -6,7 +6,7 @@ from typing_extensions import NotRequired, override
 from ._throughput_monitor_callback import ThroughputMonitor as _ThroughputMonitor
 from .base import CallbackConfigBase
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 class ThroughputMonitorBatchStats(TypedDict):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/callbacks/wandb_watch.py

@@ -1,4 +1,4 @@
-from logging import getLogger
+import logging
 from typing import Literal, Protocol, cast, runtime_checkable
 
 import torch.nn as nn
@@ -9,7 +9,7 @@ from typing_extensions import override
 
 from .base import CallbackConfigBase
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 @runtime_checkable

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/data/balanced_batch_sampler.py

@@ -1,6 +1,6 @@
 import heapq
+import logging
 from functools import cached_property
-from logging import getLogger
 from typing import Any, Protocol, runtime_checkable
 
 import numpy as np
@@ -10,7 +10,7 @@ from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
 from torch.utils.data import BatchSampler, Dataset, DistributedSampler
 from typing_extensions import override
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def _all_gather(tensor: torch.Tensor, device: torch.device | None = None):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/data/transform.py

@@ -22,7 +22,13 @@ def transform(
         deepcopy: Whether to deep copy each item before applying the transform.
     """
 
-    import wrapt
+    try:
+        import wrapt
+    except ImportError:
+        raise ImportError(
+            "wrapt is not installed. wrapt is required for the transform function."
+            "Please install it using 'pip install wrapt'"
+        )
 
     class _TransformedDataset(wrapt.ObjectProxy):
         def __getitem__(self, idx):
@@ -52,7 +58,13 @@ def transform_with_index(
         deepcopy: Whether to deep copy each item before applying the transform.
    """
 
-    import wrapt
+    try:
+        import wrapt
+    except ImportError:
+        raise ImportError(
+            "wrapt is not installed. wrapt is required for the transform function."
+            "Please install it using 'pip install wrapt'"
+        )
 
     class _TransformedWithIndexDataset(wrapt.ObjectProxy):
         def __getitem__(self, idx: int):

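Both transform helpers now import wrapt lazily and raise with an install hint when it is missing. For context, here is an illustrative sketch of how a `wrapt.ObjectProxy` subclass such as `_TransformedDataset` can apply a per-item transform while forwarding everything else to the wrapped dataset (class and attribute names below are illustrative, and wrapt must be installed):

```python
# Illustrative sketch of a wrapt.ObjectProxy-based dataset wrapper.
import wrapt


class TransformedDataset(wrapt.ObjectProxy):
    def __init__(self, dataset, transform_fn):
        super().__init__(dataset)  # delegate all other attributes to `dataset`
        # wrapt stores attributes prefixed with `_self_` on the proxy itself.
        self._self_transform_fn = transform_fn

    def __getitem__(self, idx):
        return self._self_transform_fn(self.__wrapped__[idx])
```
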
{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/_environment.py

@@ -9,7 +9,6 @@ from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
 
-import git
 import nshconfig as C
 import psutil
 import torch
@@ -618,6 +617,11 @@ class GitRepositoryConfig(C.Config):
 
     @classmethod
     def from_current_directory(cls):
+        try:
+            import git
+        except ImportError:
+            return cls()
+
         draft = cls.draft()
         try:
             repo = git.Repo(os.getcwd(), search_parent_directories=True)

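`GitRepositoryConfig.from_current_directory` now degrades gracefully when GitPython is not installed. A self-contained sketch of the same pattern (the function name and return value are illustrative, not nshtrainer API):

```python
# Illustrative: optional GitPython import with a graceful fallback.
def current_git_commit() -> str | None:
    try:
        import git  # provided by the optional GitPython package
    except ImportError:
        return None  # no GitPython -> no git metadata, but no crash either

    try:
        repo = git.Repo(search_parent_directories=True)
        return repo.head.commit.hexsha
    except Exception:
        return None  # not inside a git repository, or the repo is unreadable
```
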
{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/base.py

@@ -1,7 +1,7 @@
 import inspect
+import logging
 from abc import ABC, abstractmethod
 from collections.abc import MutableMapping
-from logging import getLogger
 from typing import IO, TYPE_CHECKING, Any, Generic, cast
 
 import torch
@@ -21,7 +21,7 @@ from .modules.profiler import ProfilerMixin
 from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
 from .modules.shared_parameters import SharedParametersModuleMixin
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)
 

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/config.py

@@ -1,4 +1,5 @@
 import copy
+import logging
 import os
 import string
 import time
@@ -6,7 +7,6 @@ import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
 from datetime import timedelta
-from logging import getLogger
 from pathlib import Path
 from typing import (
     Annotated,
@@ -46,7 +46,7 @@ from ..callbacks.base import CallbackConfigBase
 from ..metrics import MetricConfig
 from ._environment import EnvironmentConfig
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 class IdSeedWarning(Warning):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/callback.py

@@ -1,6 +1,6 @@
+import logging
 from collections import abc
 from collections.abc import Callable, Iterable
-from logging import getLogger
 from typing import Any, TypeAlias, cast, final
 
 from lightning.pytorch import Callback, LightningModule
@@ -9,7 +9,7 @@ from typing_extensions import override
 
 from ...util.typing_utils import mixin_base_type
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
 

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/rlp_sanity_checks.py

@@ -1,5 +1,5 @@
+import logging
 from collections.abc import Mapping
-from logging import getLogger
 from typing import cast
 
 import torch
@@ -14,7 +14,7 @@ from ...util.typing_utils import mixin_base_type
 from ..config import BaseConfig
 from .callback import CallbackModuleMixin
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule):

{nshtrainer-0.10.8 → nshtrainer-0.10.10}/src/nshtrainer/model/modules/shared_parameters.py

@@ -1,5 +1,5 @@
+import logging
 from collections.abc import Sequence
-from logging import getLogger
 from typing import cast
 
 import torch.nn as nn
@@ -10,7 +10,7 @@ from ...util.typing_utils import mixin_base_type
 from ..config import BaseConfig
 from .callback import CallbackRegistrarModuleMixin
 
-log = getLogger(__name__)
+log = logging.getLogger(__name__)
 
 
 def _parameters_to_names(parameters: Sequence[nn.Parameter], model: nn.Module):

nshtrainer-0.10.8/src/nshtrainer/scripts/check_env.py (deleted)

@@ -1,41 +0,0 @@
-REQUIRED_PACKAGES = [
-    "beartype",
-    "cloudpickle",
-    "jaxtyping",
-    "lightning",
-    "lightning_fabric",
-    "lightning_utilities",
-    "lovely_numpy",
-    "lovely_tensors",
-    "numpy",
-    "psutil",
-    "pydantic",
-    "pydantic_core",
-    "pysnooper",
-    "rich",
-    "tabulate",
-    "torch",
-    "torchmetrics",
-    "tqdm",
-    "typing_extensions",
-    "wrapt",
-    "yaml",
-]
-
-
-def main():
-    import importlib.util
-    import sys
-
-    missing_packages: list[str] = []
-    for package_name in REQUIRED_PACKAGES:
-        spec = importlib.util.find_spec(package_name)
-        if spec is None:
-            missing_packages.append(package_name)
-
-    if missing_packages:
-        sys.exit(f"Error: Missing required packages: {', '.join(missing_packages)}")
-
-
-if __name__ == "__main__":
-    main()

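The standalone `check_env.py` script, which verified a hard-coded list of required packages, is removed in 0.10.10. If an equivalent check is still wanted, a small ad-hoc script along these lines covers it (the package list here is illustrative; adjust it to your environment):

```python
# Illustrative stand-in for the removed check_env script.
import importlib.util
import sys


def check_packages(package_names: list[str]) -> None:
    missing = [name for name in package_names if importlib.util.find_spec(name) is None]
    if missing:
        sys.exit(f"Error: Missing required packages: {', '.join(missing)}")


if __name__ == "__main__":
    check_packages(["torch", "lightning", "numpy", "psutil", "nshconfig"])
```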