nshtrainer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0

nshtrainer/model/modules/shared_parameters.py
ADDED
@@ -0,0 +1,72 @@
+from collections.abc import Sequence
+from logging import getLogger
+from typing import cast
+
+import torch.nn as nn
+from lightning.pytorch import LightningModule, Trainer
+from typing_extensions import override
+
+from ...util.typing_utils import mixin_base_type
+from ..config import BaseConfig
+from .callback import CallbackRegistrarModuleMixin
+
+log = getLogger(__name__)
+
+
+def _parameters_to_names(parameters: Sequence[nn.Parameter], model: nn.Module):
+    mapping = {id(p): n for n, p in model.named_parameters()}
+    return [mapping[id(p)] for p in parameters]
+
+
+class SharedParametersModuleMixin(mixin_base_type(CallbackRegistrarModuleMixin)):
+    @override
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.shared_parameters: list[tuple[nn.Parameter, int | float]] = []
+        self._warned_shared_parameters = False
+
+        def on_after_backward(_trainer: Trainer, pl_module: LightningModule):
+            nonlocal self
+
+            config = cast(BaseConfig, pl_module.hparams)
+            if not config.trainer.supports_shared_parameters:
+                return
+
+            log.debug(f"Scaling {len(self.shared_parameters)} shared parameters...")
+            no_grad_parameters: list[nn.Parameter] = []
+            for p, factor in self.shared_parameters:
+                if not hasattr(p, "grad") or p.grad is None:
+                    no_grad_parameters.append(p)
+                    continue
+
+                _ = p.grad.data.div_(factor)
+
+            if no_grad_parameters and not self._warned_shared_parameters:
+                no_grad_parameters_str = ", ".join(
+                    _parameters_to_names(no_grad_parameters, pl_module)
+                )
+                log.warning(
+                    "The following parameters were marked as shared, but had no gradients: "
+                    f"{no_grad_parameters_str}"
+                )
+                self._warned_shared_parameters = True
+
+            log.debug(
+                f"Done scaling shared parameters. (len={len(self.shared_parameters)})"
+            )
+
+        self.register_callback(on_after_backward=on_after_backward)
+
+    def register_shared_parameters(
+        self, parameters: list[tuple[nn.Parameter, int | float]]
+    ):
+        for parameter, factor in parameters:
+            if not isinstance(parameter, nn.Parameter):
+                raise ValueError("Shared parameters must be PyTorch parameters")
+            if not isinstance(factor, (int, float)):
+                raise ValueError("Factor must be an integer or float")
+
+            self.shared_parameters.append((parameter, factor))
+
+        log.info(f"Registered {len(parameters)} shared parameters")
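
The public entry point of this mixin is `register_shared_parameters`, which records (parameter, factor) pairs; after every backward pass each registered gradient is divided by its factor, the usual correction when a tied parameter accumulates several gradient contributions per step. A minimal sketch of the intended call pattern, where `MyModelBase` stands in for a hypothetical base class that already mixes in SharedParametersModuleMixin along with the rest of the package's module plumbing (not shown in this excerpt):

import torch.nn as nn

class TiedLM(MyModelBase):  # hypothetical base providing SharedParametersModuleMixin
    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(1000, 128)
        self.proj = nn.Linear(128, 1000, bias=False)
        self.proj.weight = self.embed.weight  # weight tying: two gradient contributions per step
        # Divide the tied weight's gradient by 2 after each backward pass.
        self.register_shared_parameters([(self.embed.weight, 2)])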

nshtrainer/nn/__init__.py
ADDED
@@ -0,0 +1,19 @@
+from .mlp import MLP as MLP
+from .mlp import ResidualSequential as ResidualSequential
+from .module_dict import TypedModuleDict as TypedModuleDict
+from .module_list import TypedModuleList as TypedModuleList
+from .nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
+from .nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
+from .nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
+from .nonlinearity import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
+from .nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
+from .nonlinearity import NonlinearityConfig as NonlinearityConfig
+from .nonlinearity import PReLUConfig as PReLUConfig
+from .nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
+from .nonlinearity import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
+from .nonlinearity import SiLUNonlinearityConfig as SiLUNonlinearityConfig
+from .nonlinearity import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
+from .nonlinearity import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
+from .nonlinearity import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
+from .nonlinearity import SwishNonlinearityConfig as SwishNonlinearityConfig
+from .nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig

nshtrainer/nn/mlp.py
ADDED
@@ -0,0 +1,106 @@
+import copy
+from collections.abc import Callable, Sequence
+from typing import Literal, Protocol, runtime_checkable
+
+import torch
+import torch.nn as nn
+from typing_extensions import override
+
+from .nonlinearity import BaseNonlinearityConfig
+
+
+@runtime_checkable
+class LinearModuleConstructor(Protocol):
+    def __call__(
+        self, in_features: int, out_features: int, bias: bool = True
+    ) -> nn.Module: ...
+
+
+class ResidualSequential(nn.Sequential):
+    @override
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return input + super().forward(input)
+
+
+def MLP(
+    dims: Sequence[int],
+    activation: BaseNonlinearityConfig
+    | nn.Module
+    | Callable[[], nn.Module]
+    | None = None,
+    nonlinearity: BaseNonlinearityConfig
+    | nn.Module
+    | Callable[[], nn.Module]
+    | None = None,
+    bias: bool = True,
+    no_bias_scalar: bool = True,
+    ln: bool | Literal["pre", "post"] = False,
+    dropout: float | None = None,
+    residual: bool = False,
+    pre_layers: Sequence[nn.Module] = [],
+    post_layers: Sequence[nn.Module] = [],
+    linear_cls: LinearModuleConstructor = nn.Linear,
+):
+    """
+    Constructs a multi-layer perceptron (MLP) with the given dimensions and activation function.
+
+    Args:
+        dims (Sequence[int]): List of integers representing the dimensions of the MLP.
+        nonlinearity (Callable[[], nn.Module]): Activation function to use between layers.
+        activation (Callable[[], nn.Module]): Activation function to use between layers.
+        bias (bool, optional): Whether to include bias terms in the linear layers. Defaults to True.
+        no_bias_scalar (bool, optional): Whether to exclude bias terms when the output dimension is 1. Defaults to True.
+        ln (bool | Literal["pre", "post"], optional): Whether to apply layer normalization before or after the linear layers. Defaults to False.
+        dropout (float | None, optional): Dropout probability to apply between layers. Defaults to None.
+        residual (bool, optional): Whether to use residual connections between layers. Defaults to False.
+        pre_layers (Sequence[nn.Module], optional): List of layers to insert before the linear layers. Defaults to [].
+        post_layers (Sequence[nn.Module], optional): List of layers to insert after the linear layers. Defaults to [].
+
+    Returns:
+        nn.Sequential: The constructed MLP.
+    """
+
+    if activation is None:
+        activation = nonlinearity
+
+    if len(dims) < 2:
+        raise ValueError("mlp requires at least 2 dimensions")
+    if ln is True:
+        ln = "pre"
+    elif isinstance(ln, str) and ln not in ("pre", "post"):
+        raise ValueError("ln must be a boolean or 'pre' or 'post'")
+
+    layers: list[nn.Module] = []
+    if ln == "pre":
+        layers.append(nn.LayerNorm(dims[0]))
+
+    layers.extend(pre_layers)
+
+    for i in range(len(dims) - 1):
+        in_features = dims[i]
+        out_features = dims[i + 1]
+        bias_ = bias and not (no_bias_scalar and out_features == 1)
+        layers.append(linear_cls(in_features, out_features, bias=bias_))
+        if dropout is not None:
+            layers.append(nn.Dropout(dropout))
+        if i < len(dims) - 2:
+            match activation:
+                case BaseNonlinearityConfig():
+                    layers.append(activation.create_module())
+                case nn.Module():
+                    # In this case, we create a deep copy of the module to avoid sharing parameters (if any).
+                    layers.append(copy.deepcopy(activation))
+                case Callable():
+                    layers.append(activation())
+                case _:
+                    raise ValueError(
+                        "Either `nonlinearity` or `activation` must be provided"
+                    )
+
+    layers.extend(post_layers)
+
+    if ln == "post":
+        layers.append(nn.LayerNorm(dims[-1]))
+
+    cls = ResidualSequential if residual else nn.Sequential
+    return cls(*layers)
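
`MLP` returns an ordinary `nn.Sequential` (or `ResidualSequential` when `residual=True`), so the result drops in anywhere a module is expected. A small usage sketch, assuming `MLP` and `GELUNonlinearityConfig` are imported from `nshtrainer.nn` and that the config classes construct with their defaults; the dimensions here are arbitrary:

import torch
import torch.nn as nn

# Activation given as a zero-argument callable; the scalar output layer
# drops its bias because no_bias_scalar=True by default.
head = MLP([64, 128, 1], activation=nn.SiLU, ln="pre", dropout=0.1)

# Activation given as a nonlinearity config; residual=True wraps the stack
# in ResidualSequential, so all dims must match.
block = MLP([64, 64, 64], activation=GELUNonlinearityConfig(), residual=True)

x = torch.randn(8, 64)
print(head(x).shape)   # torch.Size([8, 1])
print(block(x).shape)  # torch.Size([8, 64])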

nshtrainer/nn/module_dict.py
ADDED
@@ -0,0 +1,66 @@
+from collections.abc import Iterable, Mapping
+from typing import Generic, cast
+
+import torch.nn as nn
+from typing_extensions import TypeVar
+
+TModule = TypeVar("TModule", bound=nn.Module, infer_variance=True)
+
+
+class TypedModuleDict(nn.Module, Generic[TModule]):
+    def __init__(
+        self,
+        modules: Mapping[str, TModule],
+        key_prefix: str = "_typed_moduledict_",
+        # we use a key prefix to avoid attribute name collisions
+        # (which is a common issue in nn.ModuleDict as it uses `__setattr__` to set the modules)
+    ):
+        super().__init__()
+
+        self.key_prefix = key_prefix
+        self._module_dict = nn.ModuleDict(
+            {self._with_prefix(k): v for k, v in modules.items()}
+        )
+
+    def _with_prefix(self, key: str) -> str:
+        return f"{self.key_prefix}{key}"
+
+    def _remove_prefix(self, key: str) -> str:
+        assert key.startswith(
+            self.key_prefix
+        ), f"{key} does not start with {self.key_prefix}"
+        return key[len(self.key_prefix) :]
+
+    def __setitem__(self, key: str, module: TModule) -> None:
+        key = self._with_prefix(key)
+        return self._module_dict.__setitem__(key, module)
+
+    def __getitem__(self, key: str) -> TModule:
+        key = self._with_prefix(key)
+        return self._module_dict.__getitem__(key)  # type: ignore
+
+    def update(self, modules: Mapping[str, TModule]) -> None:
+        return self._module_dict.update(
+            {self._with_prefix(k): v for k, v in modules.items()}
+        )
+
+    def get(self, key: str) -> TModule | None:
+        key = self._with_prefix(key)
+        if (value := self._module_dict._modules.get(key)) is None:
+            return None
+        return cast(TModule, value)
+
+    def keys(self) -> Iterable[str]:
+        r"""Return an iterable of the ModuleDict keys."""
+        return [self._remove_prefix(k) for k in self._module_dict.keys()]
+
+    def items(self) -> Iterable[tuple[str, TModule]]:
+        r"""Return an iterable of the ModuleDict key/value pairs."""
+        return [
+            (self._remove_prefix(k), cast(TModule, v))
+            for k, v in self._module_dict.items()
+        ]
+
+    def values(self) -> Iterable[TModule]:
+        r"""Return an iterable of the ModuleDict values."""
+        return cast(Iterable[TModule], self._module_dict.values())
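
Compared to a plain `nn.ModuleDict`, this wrapper does two things: values keep their concrete type for static checkers, and every key is stored under `key_prefix`, so keys that would collide with `nn.Module` attribute names remain safe. A short sketch using only what is defined above:

import torch
import torch.nn as nn

heads: TypedModuleDict[nn.Linear] = TypedModuleDict(
    {"energy": nn.Linear(16, 1), "forces": nn.Linear(16, 3)}
)

x = torch.randn(4, 16)
energy = heads["energy"](x)    # type checkers see nn.Linear, not nn.Module
missing = heads.get("stress")  # None instead of a KeyError
print(sorted(heads.keys()))    # ['energy', 'forces'] -- the prefix is stripped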

nshtrainer/nn/module_list.py
ADDED
@@ -0,0 +1,50 @@
+from collections.abc import Iterable, Iterator
+from typing import Generic, TypeVar, overload
+
+import torch.nn as nn
+from typing_extensions import override
+
+TModule = TypeVar("TModule", bound=nn.Module)
+
+
+class TypedModuleList(nn.ModuleList, Generic[TModule]):
+    def __init__(self, modules: Iterable[TModule] | None = None) -> None:
+        super().__init__(modules)
+
+    @overload
+    def __getitem__(self, idx: int) -> TModule: ...
+
+    @overload
+    def __getitem__(self, idx: slice) -> "TypedModuleList[TModule]": ...
+
+    @override
+    def __getitem__(self, idx: int | slice) -> TModule | "TypedModuleList[TModule]":
+        return super().__getitem__(idx)  # type: ignore
+
+    @override
+    def __setitem__(self, idx: int, module: TModule) -> None:  # type: ignore
+        return super().__setitem__(idx, module)
+
+    @override
+    def __iter__(self) -> Iterator[TModule]:
+        return super().__iter__()  # type: ignore
+
+    @override
+    def __iadd__(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]":  # type: ignore
+        return super().__iadd__(modules)  # type: ignore
+
+    @override
+    def __add__(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]":  # type: ignore
+        return super().__add__(modules)  # type: ignore
+
+    @override
+    def insert(self, idx: int, module: TModule) -> None:  # type: ignore
+        return super().insert(idx, module)
+
+    @override
+    def append(self, module: TModule) -> "TypedModuleList[TModule]":  # type: ignore
+        return super().append(module)  # type: ignore
+
+    @override
+    def extend(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]":  # type: ignore
+        return super().extend(modules)  # type: ignore
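
The list variant follows the same pattern: runtime behavior is exactly that of `nn.ModuleList`, and the overrides only exist so that indexing, slicing, and iteration preserve the element type. A brief sketch:

import torch
import torch.nn as nn

blocks: TypedModuleList[nn.Linear] = TypedModuleList(
    [nn.Linear(32, 32) for _ in range(4)]
)

x = torch.randn(2, 32)
for block in blocks:          # each `block` is typed as nn.Linear
    x = block(x)
first: nn.Linear = blocks[0]  # no cast needed
tail = blocks[1:]             # slicing returns a TypedModuleList[nn.Linear]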

nshtrainer/nn/nonlinearity.py
ADDED
@@ -0,0 +1,157 @@
+from abc import ABC, abstractmethod
+from typing import Annotated, Literal
+
+import torch
+import torch.nn as nn
+from typing_extensions import override
+
+from ..config import Field, TypedConfig
+
+
+class BaseNonlinearityConfig(TypedConfig, ABC):
+    @abstractmethod
+    def create_module(self) -> nn.Module:
+        pass
+
+
+class ReLUNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["relu"] = "relu"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.ReLU()
+
+
+class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["sigmoid"] = "sigmoid"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.Sigmoid()
+
+
+class TanhNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["tanh"] = "tanh"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.Tanh()
+
+
+class SoftmaxNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["softmax"] = "softmax"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.Softmax(dim=1)
+
+
+class SoftplusNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["softplus"] = "softplus"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.Softplus()
+
+
+class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["softsign"] = "softsign"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.Softsign()
+
+
+class ELUNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["elu"] = "elu"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.ELU()
+
+
+class LeakyReLUNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["leaky_relu"] = "leaky_relu"
+
+    negative_slope: float | None = None
+
+    @override
+    def create_module(self) -> nn.Module:
+        kwargs = {}
+        if self.negative_slope is not None:
+            kwargs["negative_slope"] = self.negative_slope
+        return nn.LeakyReLU(**kwargs)
+
+
+class PReLUConfig(BaseNonlinearityConfig):
+    name: Literal["prelu"] = "prelu"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.PReLU()
+
+
+class GELUNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["gelu"] = "gelu"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.GELU()
+
+
+class SwishNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["swish"] = "swish"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.SiLU()
+
+
+class SiLUNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["silu"] = "silu"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.SiLU()
+
+
+class MishNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["mish"] = "mish"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return nn.Mish()
+
+
+class SwiGLU(nn.SiLU):
+    @override
+    def forward(self, input: torch.Tensor):
+        input, gate = input.chunk(2, dim=-1)
+        return input * super().forward(gate)
+
+
+class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
+    name: Literal["swiglu"] = "swiglu"
+
+    @override
+    def create_module(self) -> nn.Module:
+        return SwiGLU()
+
+
+NonlinearityConfig = Annotated[
+    ReLUNonlinearityConfig
+    | SigmoidNonlinearityConfig
+    | TanhNonlinearityConfig
+    | SoftmaxNonlinearityConfig
+    | SoftplusNonlinearityConfig
+    | SoftsignNonlinearityConfig
+    | ELUNonlinearityConfig
+    | LeakyReLUNonlinearityConfig
+    | PReLUConfig
+    | GELUNonlinearityConfig
+    | SwishNonlinearityConfig
+    | SiLUNonlinearityConfig
+    | MishNonlinearityConfig
+    | SwiGLUNonlinearityConfig,
+    Field(discriminator="name"),
+]
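
`NonlinearityConfig` is a discriminated union keyed on the `name` field, so a mapping such as {"name": "leaky_relu", "negative_slope": 0.1} can be validated into the matching config class by whatever machinery `TypedConfig` provides (that class lives in nshtrainer.config and is not shown here; the `Field(discriminator="name")` annotation suggests it is pydantic-based). Using the config classes directly, and assuming keyword construction works as in pydantic:

import torch

# Build the activation module from a config object.
act = LeakyReLUNonlinearityConfig(negative_slope=0.1).create_module()
print(act)  # LeakyReLU(negative_slope=0.1)

# SwiGLU splits the last dimension into (value, gate), halving it.
swiglu = SwiGLUNonlinearityConfig().create_module()
print(swiglu(torch.randn(8, 64)).shape)  # torch.Size([8, 32])
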
nshtrainer/optimizer.py
ADDED
@@ -0,0 +1,62 @@
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from typing import Annotated, Any, Literal, TypeAlias
+
+import torch.nn as nn
+from torch.optim import Optimizer
+from typing_extensions import override
+
+from .config import Field, TypedConfig
+
+
+class OptimizerConfigBase(TypedConfig, ABC):
+    @abstractmethod
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ) -> Optimizer: ...
+
+
+class AdamWConfig(OptimizerConfigBase):
+    name: Literal["adamw"] = "adamw"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    weight_decay: float = 1.0e-2
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """
+    Betas for the optimizer:
+    (beta1, beta2) are the coefficients used for computing running averages of
+    gradient and its square.
+    """
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    amsgrad: bool = False
+    """Whether to use the AMSGrad variant of this algorithm."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import AdamW
+
+        return AdamW(
+            parameters,
+            lr=self.lr,
+            weight_decay=self.weight_decay,
+            betas=self.betas,
+            eps=self.eps,
+            amsgrad=self.amsgrad,
+        )
+
+
+OptimizerConfig: TypeAlias = Annotated[
+    AdamWConfig,
+    Field(discriminator="name"),
+]
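
The optimizer is configured declaratively and instantiated against a module's parameters; only AdamW ships in this release, which is why `OptimizerConfig` is an Annotated union with a single member. A sketch, again assuming keyword construction of the config works as in pydantic:

import torch.nn as nn

model = nn.Linear(16, 1)
opt_config = AdamWConfig(lr=3e-4, weight_decay=1e-2)
optimizer = opt_config.create_optimizer(model.parameters())
print(type(optimizer).__name__)  # AdamW
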
nshtrainer/runner.py
ADDED
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import Generic
+
+from nshrunner import Runner as _Runner
+from typing_extensions import Concatenate, TypeVar, TypeVarTuple, Unpack, override
+
+from .model.config import BaseConfig
+
+TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True)
+TArguments = TypeVarTuple("TArguments")
+TReturn = TypeVar("TReturn", infer_variance=True)
+
+
+@dataclass(frozen=True)
+class Runner(
+    _Runner[Unpack[tuple[TConfig, Unpack[TArguments]]], TReturn],
+    Generic[TConfig, Unpack[TArguments], TReturn],
+):
+    @override
+    def default_validate_fn():
+        pass

nshtrainer/scripts/check_env.py
ADDED
@@ -0,0 +1,41 @@
+REQUIRED_PACKAGES = [
+    "beartype",
+    "cloudpickle",
+    "jaxtyping",
+    "lightning",
+    "lightning_fabric",
+    "lightning_utilities",
+    "lovely_numpy",
+    "lovely_tensors",
+    "numpy",
+    "psutil",
+    "pydantic",
+    "pydantic_core",
+    "pysnooper",
+    "rich",
+    "tabulate",
+    "torch",
+    "torchmetrics",
+    "tqdm",
+    "typing_extensions",
+    "wrapt",
+    "yaml",
+]
+
+
+def main():
+    import importlib.util
+    import sys
+
+    missing_packages: list[str] = []
+    for package_name in REQUIRED_PACKAGES:
+        spec = importlib.util.find_spec(package_name)
+        if spec is None:
+            missing_packages.append(package_name)
+
+    if missing_packages:
+        sys.exit(f"Error: Missing required packages: {', '.join(missing_packages)}")
+
+
+if __name__ == "__main__":
+    main()

nshtrainer/scripts/find_packages.py
ADDED
@@ -0,0 +1,51 @@
+import argparse
+import ast
+import glob
+import sys
+from pathlib import Path
+
+
+def get_imports(file_path: Path):
+    with open(file_path, "r") as file:
+        try:
+            tree = ast.parse(file.read())
+        except SyntaxError:
+            print(f"Syntax error in file: {file_path}", file=sys.stderr)
+            return set()
+
+    imports = set()
+    for node in ast.walk(tree):
+        if isinstance(node, ast.Import):
+            for alias in node.names:
+                imports.add(alias.name.split(".")[0])
+        elif isinstance(node, ast.ImportFrom):
+            if node.level == 0 and node.module:  # Absolute import
+                imports.add(node.module.split(".")[0])
+    return imports
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Find unique Python packages used in files."
+    )
+    parser.add_argument("glob_pattern", help="Glob pattern to match files")
+    parser.add_argument(
+        "--exclude-std", action="store_true", help="Exclude Python standard libraries"
+    )
+    args = parser.parse_args()
+
+    all_imports = set()
+    for file_path in glob.glob(args.glob_pattern, recursive=True):
+        all_imports.update(get_imports(Path(file_path)))
+
+    if args.exclude_std:
+        std_libs = set(sys.stdlib_module_names)
+        std_libs.update({"pkg_resources"})
+        all_imports = all_imports - std_libs
+
+    for package in sorted(all_imports):
+        print(package)
+
+
+if __name__ == "__main__":
+    main()
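
`get_imports` collects only top-level module names from absolute imports (relative imports are skipped via the `node.level == 0` check), so it can also be used programmatically; the CLI entry point simply prints the sorted union over every file matched by the glob. A short programmatic sketch, run from the package root:

import sys
from pathlib import Path

imports = get_imports(Path("nshtrainer/nn/mlp.py"))
third_party = sorted(imports - set(sys.stdlib_module_names))
print(third_party)  # ['torch', 'typing_extensions']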

nshtrainer/trainer/__init__.py
ADDED
@@ -0,0 +1 @@
+from .trainer import Trainer as Trainer