nshtrainer-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/model/modules/shared_parameters.py ADDED
@@ -0,0 +1,72 @@
+ from collections.abc import Sequence
+ from logging import getLogger
+ from typing import cast
+
+ import torch.nn as nn
+ from lightning.pytorch import LightningModule, Trainer
+ from typing_extensions import override
+
+ from ...util.typing_utils import mixin_base_type
+ from ..config import BaseConfig
+ from .callback import CallbackRegistrarModuleMixin
+
+ log = getLogger(__name__)
+
+
+ def _parameters_to_names(parameters: Sequence[nn.Parameter], model: nn.Module):
+     mapping = {id(p): n for n, p in model.named_parameters()}
+     return [mapping[id(p)] for p in parameters]
+
+
+ class SharedParametersModuleMixin(mixin_base_type(CallbackRegistrarModuleMixin)):
+     @override
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         self.shared_parameters: list[tuple[nn.Parameter, int | float]] = []
+         self._warned_shared_parameters = False
+
+         def on_after_backward(_trainer: Trainer, pl_module: LightningModule):
+             nonlocal self
+
+             config = cast(BaseConfig, pl_module.hparams)
+             if not config.trainer.supports_shared_parameters:
+                 return
+
+             log.debug(f"Scaling {len(self.shared_parameters)} shared parameters...")
+             no_grad_parameters: list[nn.Parameter] = []
+             for p, factor in self.shared_parameters:
+                 if not hasattr(p, "grad") or p.grad is None:
+                     no_grad_parameters.append(p)
+                     continue
+
+                 _ = p.grad.data.div_(factor)
+
+             if no_grad_parameters and not self._warned_shared_parameters:
+                 no_grad_parameters_str = ", ".join(
+                     _parameters_to_names(no_grad_parameters, pl_module)
+                 )
+                 log.warning(
+                     "The following parameters were marked as shared, but had no gradients: "
+                     f"{no_grad_parameters_str}"
+                 )
+                 self._warned_shared_parameters = True
+
+             log.debug(
+                 f"Done scaling shared parameters. (len={len(self.shared_parameters)})"
+             )
+
+         self.register_callback(on_after_backward=on_after_backward)
+
+     def register_shared_parameters(
+         self, parameters: list[tuple[nn.Parameter, int | float]]
+     ):
+         for parameter, factor in parameters:
+             if not isinstance(parameter, nn.Parameter):
+                 raise ValueError("Shared parameters must be PyTorch parameters")
+             if not isinstance(factor, (int, float)):
+                 raise ValueError("Factor must be an integer or float")
+
+             self.shared_parameters.append((parameter, factor))
+
+         log.info(f"Registered {len(parameters)} shared parameters")
nshtrainer/nn/__init__.py ADDED
@@ -0,0 +1,19 @@
+ from .mlp import MLP as MLP
+ from .mlp import ResidualSequential as ResidualSequential
+ from .module_dict import TypedModuleDict as TypedModuleDict
+ from .module_list import TypedModuleList as TypedModuleList
+ from .nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
+ from .nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
+ from .nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
+ from .nonlinearity import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
+ from .nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
+ from .nonlinearity import NonlinearityConfig as NonlinearityConfig
+ from .nonlinearity import PReLUConfig as PReLUConfig
+ from .nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
+ from .nonlinearity import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
+ from .nonlinearity import SiLUNonlinearityConfig as SiLUNonlinearityConfig
+ from .nonlinearity import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
+ from .nonlinearity import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
+ from .nonlinearity import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
+ from .nonlinearity import SwishNonlinearityConfig as SwishNonlinearityConfig
+ from .nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
nshtrainer/nn/mlp.py ADDED
@@ -0,0 +1,106 @@
+ import copy
+ from collections.abc import Callable, Sequence
+ from typing import Literal, Protocol, runtime_checkable
+
+ import torch
+ import torch.nn as nn
+ from typing_extensions import override
+
+ from .nonlinearity import BaseNonlinearityConfig
+
+
+ @runtime_checkable
+ class LinearModuleConstructor(Protocol):
+     def __call__(
+         self, in_features: int, out_features: int, bias: bool = True
+     ) -> nn.Module: ...
+
+
+ class ResidualSequential(nn.Sequential):
+     @override
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         return input + super().forward(input)
+
+
+ def MLP(
+     dims: Sequence[int],
+     activation: BaseNonlinearityConfig
+     | nn.Module
+     | Callable[[], nn.Module]
+     | None = None,
+     nonlinearity: BaseNonlinearityConfig
+     | nn.Module
+     | Callable[[], nn.Module]
+     | None = None,
+     bias: bool = True,
+     no_bias_scalar: bool = True,
+     ln: bool | Literal["pre", "post"] = False,
+     dropout: float | None = None,
+     residual: bool = False,
+     pre_layers: Sequence[nn.Module] = [],
+     post_layers: Sequence[nn.Module] = [],
+     linear_cls: LinearModuleConstructor = nn.Linear,
+ ):
+     """
+     Constructs a multi-layer perceptron (MLP) with the given dimensions and activation function.
+
+     Args:
+         dims (Sequence[int]): List of integers representing the dimensions of the MLP.
+         activation (BaseNonlinearityConfig | nn.Module | Callable[[], nn.Module] | None, optional): Activation to place between layers. Defaults to None.
+         nonlinearity (BaseNonlinearityConfig | nn.Module | Callable[[], nn.Module] | None, optional): Alias for ``activation``; used only when ``activation`` is not provided.
+         bias (bool, optional): Whether to include bias terms in the linear layers. Defaults to True.
+         no_bias_scalar (bool, optional): Whether to exclude the bias term when the output dimension is 1. Defaults to True.
+         ln (bool | Literal["pre", "post"], optional): Whether to apply layer normalization before ("pre") or after ("post") the linear layers. Defaults to False.
+         dropout (float | None, optional): Dropout probability to apply after each linear layer. Defaults to None.
+         residual (bool, optional): Whether to wrap the layers in a residual connection. Defaults to False.
+         pre_layers (Sequence[nn.Module], optional): Layers to insert before the linear layers. Defaults to [].
+         post_layers (Sequence[nn.Module], optional): Layers to insert after the linear layers. Defaults to [].
+
+     Returns:
+         nn.Sequential: The constructed MLP.
+     """
+
+     if activation is None:
+         activation = nonlinearity
+
+     if len(dims) < 2:
+         raise ValueError("MLP requires at least 2 dimensions")
+     if ln is True:
+         ln = "pre"
+     elif isinstance(ln, str) and ln not in ("pre", "post"):
+         raise ValueError("ln must be a boolean or 'pre' or 'post'")
+
+     layers: list[nn.Module] = []
+     if ln == "pre":
+         layers.append(nn.LayerNorm(dims[0]))
+
+     layers.extend(pre_layers)
+
+     for i in range(len(dims) - 1):
+         in_features = dims[i]
+         out_features = dims[i + 1]
+         bias_ = bias and not (no_bias_scalar and out_features == 1)
+         layers.append(linear_cls(in_features, out_features, bias=bias_))
+         if dropout is not None:
+             layers.append(nn.Dropout(dropout))
+         if i < len(dims) - 2:
+             match activation:
+                 case BaseNonlinearityConfig():
+                     layers.append(activation.create_module())
+                 case nn.Module():
+                     # Deep-copy the module to avoid sharing parameters (if any) between layers.
+                     layers.append(copy.deepcopy(activation))
+                 case Callable():
+                     layers.append(activation())
+                 case _:
+                     raise ValueError(
+                         "Either `nonlinearity` or `activation` must be provided"
+                     )
+
+     layers.extend(post_layers)
+
+     if ln == "post":
+         layers.append(nn.LayerNorm(dims[-1]))
+
+     cls = ResidualSequential if residual else nn.Sequential
+     return cls(*layers)
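
A short usage sketch for the `MLP` factory (assuming the package is installed; `nshtrainer.nn` re-exports `MLP` per the `__init__.py` above):

    import torch
    from nshtrainer.nn import MLP

    # 64 -> 128 -> 1 with a pre-LayerNorm, ReLU between the hidden layers, and
    # dropout after each linear layer. With the default no_bias_scalar=True,
    # the final scalar-output layer is created without a bias term.
    mlp = MLP([64, 128, 1], activation=torch.nn.ReLU(), ln="pre", dropout=0.1)
    out = mlp(torch.randn(8, 64))
    print(out.shape)     # torch.Size([8, 1])

Passing an `nn.Module` instance works because the factory deep-copies it for each layer; a config object or a zero-argument callable is instantiated per layer instead.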
nshtrainer/nn/module_dict.py ADDED
@@ -0,0 +1,66 @@
+ from collections.abc import Iterable, Mapping
+ from typing import Generic, cast
+
+ import torch.nn as nn
+ from typing_extensions import TypeVar
+
+ TModule = TypeVar("TModule", bound=nn.Module, infer_variance=True)
+
+
+ class TypedModuleDict(nn.Module, Generic[TModule]):
+     def __init__(
+         self,
+         modules: Mapping[str, TModule],
+         key_prefix: str = "_typed_moduledict_",
+         # we use a key prefix to avoid attribute name collisions
+         # (which is a common issue in nn.ModuleDict as it uses `__setattr__` to set the modules)
+     ):
+         super().__init__()
+
+         self.key_prefix = key_prefix
+         self._module_dict = nn.ModuleDict(
+             {self._with_prefix(k): v for k, v in modules.items()}
+         )
+
+     def _with_prefix(self, key: str) -> str:
+         return f"{self.key_prefix}{key}"
+
+     def _remove_prefix(self, key: str) -> str:
+         assert key.startswith(
+             self.key_prefix
+         ), f"{key} does not start with {self.key_prefix}"
+         return key[len(self.key_prefix) :]
+
+     def __setitem__(self, key: str, module: TModule) -> None:
+         key = self._with_prefix(key)
+         return self._module_dict.__setitem__(key, module)
+
+     def __getitem__(self, key: str) -> TModule:
+         key = self._with_prefix(key)
+         return self._module_dict.__getitem__(key)  # type: ignore
+
+     def update(self, modules: Mapping[str, TModule]) -> None:
+         return self._module_dict.update(
+             {self._with_prefix(k): v for k, v in modules.items()}
+         )
+
+     def get(self, key: str) -> TModule | None:
+         key = self._with_prefix(key)
+         if (value := self._module_dict._modules.get(key)) is None:
+             return None
+         return cast(TModule, value)
+
+     def keys(self) -> Iterable[str]:
+         r"""Return an iterable of the ModuleDict keys."""
+         return [self._remove_prefix(k) for k in self._module_dict.keys()]
+
+     def items(self) -> Iterable[tuple[str, TModule]]:
+         r"""Return an iterable of the ModuleDict key/value pairs."""
+         return [
+             (self._remove_prefix(k), cast(TModule, v))
+             for k, v in self._module_dict.items()
+         ]
+
+     def values(self) -> Iterable[TModule]:
+         r"""Return an iterable of the ModuleDict values."""
+         return cast(Iterable[TModule], self._module_dict.values())
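
An illustrative sketch of `TypedModuleDict` usage (the key prefix is invisible to callers):

    import torch
    import torch.nn as nn
    from nshtrainer.nn import TypedModuleDict

    heads = TypedModuleDict({"energy": nn.Linear(16, 1), "forces": nn.Linear(16, 3)})
    x = torch.randn(4, 16)

    print(sorted(heads.keys()))     # ['energy', 'forces'] -- the internal prefix is stripped
    energy = heads["energy"](x)     # __getitem__ is typed as the value type (nn.Linear here)
    print(heads.get("stress"))      # None for unknown keys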
nshtrainer/nn/module_list.py ADDED
@@ -0,0 +1,50 @@
+ from collections.abc import Iterable, Iterator
+ from typing import Generic, TypeVar, overload
+
+ import torch.nn as nn
+ from typing_extensions import override
+
+ TModule = TypeVar("TModule", bound=nn.Module)
+
+
+ class TypedModuleList(nn.ModuleList, Generic[TModule]):
+     def __init__(self, modules: Iterable[TModule] | None = None) -> None:
+         super().__init__(modules)
+
+     @overload
+     def __getitem__(self, idx: int) -> TModule: ...
+
+     @overload
+     def __getitem__(self, idx: slice) -> "TypedModuleList[TModule]": ...
+
+     @override
+     def __getitem__(self, idx: int | slice) -> TModule | "TypedModuleList[TModule]":
+         return super().__getitem__(idx)  # type: ignore
+
+     @override
+     def __setitem__(self, idx: int, module: TModule) -> None:  # type: ignore
+         return super().__setitem__(idx, module)
+
+     @override
+     def __iter__(self) -> Iterator[TModule]:
+         return super().__iter__()  # type: ignore
+
+     @override
+     def __iadd__(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]":  # type: ignore
+         return super().__iadd__(modules)  # type: ignore
+
+     @override
+     def __add__(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]":  # type: ignore
+         return super().__add__(modules)  # type: ignore
+
+     @override
+     def insert(self, idx: int, module: TModule) -> None:  # type: ignore
+         return super().insert(idx, module)
+
+     @override
+     def append(self, module: TModule) -> "TypedModuleList[TModule]":  # type: ignore
+         return super().append(module)  # type: ignore
+
+     @override
+     def extend(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]":  # type: ignore
+         return super().extend(modules)  # type: ignore
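
`TypedModuleList` only narrows the types its overrides report to the type checker; runtime behavior is plain `nn.ModuleList`. A brief sketch:

    import torch
    import torch.nn as nn
    from nshtrainer.nn import TypedModuleList

    blocks = TypedModuleList([nn.Linear(32, 32) for _ in range(3)])
    x = torch.randn(2, 32)
    for block in blocks:            # iteration yields nn.Linear, not bare nn.Module
        x = block(x)
    first: nn.Linear = blocks[0]    # indexing preserves the element type
    print(x.shape)                  # torch.Size([2, 32])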
nshtrainer/nn/nonlinearity.py ADDED
@@ -0,0 +1,157 @@
+ from abc import ABC, abstractmethod
+ from typing import Annotated, Literal
+
+ import torch
+ import torch.nn as nn
+ from typing_extensions import override
+
+ from ..config import Field, TypedConfig
+
+
+ class BaseNonlinearityConfig(TypedConfig, ABC):
+     @abstractmethod
+     def create_module(self) -> nn.Module:
+         pass
+
+
+ class ReLUNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["relu"] = "relu"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.ReLU()
+
+
+ class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["sigmoid"] = "sigmoid"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.Sigmoid()
+
+
+ class TanhNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["tanh"] = "tanh"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.Tanh()
+
+
+ class SoftmaxNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["softmax"] = "softmax"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.Softmax(dim=1)
+
+
+ class SoftplusNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["softplus"] = "softplus"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.Softplus()
+
+
+ class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["softsign"] = "softsign"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.Softsign()
+
+
+ class ELUNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["elu"] = "elu"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.ELU()
+
+
+ class LeakyReLUNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["leaky_relu"] = "leaky_relu"
+
+     negative_slope: float | None = None
+
+     @override
+     def create_module(self) -> nn.Module:
+         kwargs = {}
+         if self.negative_slope is not None:
+             kwargs["negative_slope"] = self.negative_slope
+         return nn.LeakyReLU(**kwargs)
+
+
+ class PReLUConfig(BaseNonlinearityConfig):
+     name: Literal["prelu"] = "prelu"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.PReLU()
+
+
+ class GELUNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["gelu"] = "gelu"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.GELU()
+
+
+ class SwishNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["swish"] = "swish"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.SiLU()
+
+
+ class SiLUNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["silu"] = "silu"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.SiLU()
+
+
+ class MishNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["mish"] = "mish"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return nn.Mish()
+
+
+ class SwiGLU(nn.SiLU):
+     @override
+     def forward(self, input: torch.Tensor):
+         input, gate = input.chunk(2, dim=-1)
+         return input * super().forward(gate)
+
+
+ class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
+     name: Literal["swiglu"] = "swiglu"
+
+     @override
+     def create_module(self) -> nn.Module:
+         return SwiGLU()
+
+
+ NonlinearityConfig = Annotated[
+     ReLUNonlinearityConfig
+     | SigmoidNonlinearityConfig
+     | TanhNonlinearityConfig
+     | SoftmaxNonlinearityConfig
+     | SoftplusNonlinearityConfig
+     | SoftsignNonlinearityConfig
+     | ELUNonlinearityConfig
+     | LeakyReLUNonlinearityConfig
+     | PReLUConfig
+     | GELUNonlinearityConfig
+     | SwishNonlinearityConfig
+     | SiLUNonlinearityConfig
+     | MishNonlinearityConfig
+     | SwiGLUNonlinearityConfig,
+     Field(discriminator="name"),
+ ]
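
Each config above is a small factory for its `nn.Module`. A hedged sketch of using one directly and via `MLP` (which accepts a `BaseNonlinearityConfig` for `activation` and calls `create_module()` once per hidden transition):

    import torch
    from nshtrainer.nn import MLP, GELUNonlinearityConfig

    act = GELUNonlinearityConfig()
    print(act.create_module())              # GELU()

    mlp = MLP([32, 64, 64, 8], activation=act)
    print(mlp(torch.randn(5, 32)).shape)    # torch.Size([5, 8])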
nshtrainer/optimizer.py ADDED
@@ -0,0 +1,62 @@
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterable
+ from typing import Annotated, Any, Literal, TypeAlias
+
+ import torch.nn as nn
+ from torch.optim import Optimizer
+ from typing_extensions import override
+
+ from .config import Field, TypedConfig
+
+
+ class OptimizerConfigBase(TypedConfig, ABC):
+     @abstractmethod
+     def create_optimizer(
+         self,
+         parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+     ) -> Optimizer: ...
+
+
+ class AdamWConfig(OptimizerConfigBase):
+     name: Literal["adamw"] = "adamw"
+
+     lr: float
+     """Learning rate for the optimizer."""
+
+     weight_decay: float = 1.0e-2
+     """Weight decay (L2 penalty) for the optimizer."""
+
+     betas: tuple[float, float] = (0.9, 0.999)
+     """
+     Betas for the optimizer:
+     (beta1, beta2) are the coefficients used for computing running averages of
+     gradient and its square.
+     """
+
+     eps: float = 1e-8
+     """Term added to the denominator to improve numerical stability."""
+
+     amsgrad: bool = False
+     """Whether to use the AMSGrad variant of this algorithm."""
+
+     @override
+     def create_optimizer(
+         self,
+         parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+     ):
+         from torch.optim import AdamW
+
+         return AdamW(
+             parameters,
+             lr=self.lr,
+             weight_decay=self.weight_decay,
+             betas=self.betas,
+             eps=self.eps,
+             amsgrad=self.amsgrad,
+         )
+
+
+ OptimizerConfig: TypeAlias = Annotated[
+     AdamWConfig,
+     Field(discriminator="name"),
+ ]
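
A sketch of how `AdamWConfig` is meant to be consumed (assuming `TypedConfig` supports keyword construction, as is typical for pydantic-style configs):

    import torch.nn as nn
    from nshtrainer.optimizer import AdamWConfig

    model = nn.Linear(10, 1)
    config = AdamWConfig(lr=1.0e-3, weight_decay=0.01)

    # create_optimizer forwards the stored hyperparameters to torch.optim.AdamW.
    optimizer = config.create_optimizer(model.parameters())
    print(type(optimizer).__name__)         # AdamW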
nshtrainer/runner.py ADDED
@@ -0,0 +1,21 @@
+ from dataclasses import dataclass
+ from typing import Generic
+
+ from nshrunner import Runner as _Runner
+ from typing_extensions import Concatenate, TypeVar, TypeVarTuple, Unpack, override
+
+ from .model.config import BaseConfig
+
+ TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True)
+ TArguments = TypeVarTuple("TArguments")
+ TReturn = TypeVar("TReturn", infer_variance=True)
+
+
+ @dataclass(frozen=True)
+ class Runner(
+     _Runner[Unpack[tuple[TConfig, Unpack[TArguments]]], TReturn],
+     Generic[TConfig, Unpack[TArguments], TReturn],
+ ):
+     @override
+     def default_validate_fn():
+         pass
nshtrainer/scripts/check_env.py ADDED
@@ -0,0 +1,41 @@
+ REQUIRED_PACKAGES = [
+     "beartype",
+     "cloudpickle",
+     "jaxtyping",
+     "lightning",
+     "lightning_fabric",
+     "lightning_utilities",
+     "lovely_numpy",
+     "lovely_tensors",
+     "numpy",
+     "psutil",
+     "pydantic",
+     "pydantic_core",
+     "pysnooper",
+     "rich",
+     "tabulate",
+     "torch",
+     "torchmetrics",
+     "tqdm",
+     "typing_extensions",
+     "wrapt",
+     "yaml",
+ ]
+
+
+ def main():
+     import importlib.util
+     import sys
+
+     missing_packages: list[str] = []
+     for package_name in REQUIRED_PACKAGES:
+         spec = importlib.util.find_spec(package_name)
+         if spec is None:
+             missing_packages.append(package_name)
+
+     if missing_packages:
+         sys.exit(f"Error: Missing required packages: {', '.join(missing_packages)}")
+
+
+ if __name__ == "__main__":
+     main()
nshtrainer/scripts/find_packages.py ADDED
@@ -0,0 +1,51 @@
+ import argparse
+ import ast
+ import glob
+ import sys
+ from pathlib import Path
+
+
+ def get_imports(file_path: Path):
+     with open(file_path, "r") as file:
+         try:
+             tree = ast.parse(file.read())
+         except SyntaxError:
+             print(f"Syntax error in file: {file_path}", file=sys.stderr)
+             return set()
+
+     imports = set()
+     for node in ast.walk(tree):
+         if isinstance(node, ast.Import):
+             for alias in node.names:
+                 imports.add(alias.name.split(".")[0])
+         elif isinstance(node, ast.ImportFrom):
+             if node.level == 0 and node.module:  # Absolute import
+                 imports.add(node.module.split(".")[0])
+     return imports
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Find unique Python packages used in files."
+     )
+     parser.add_argument("glob_pattern", help="Glob pattern to match files")
+     parser.add_argument(
+         "--exclude-std", action="store_true", help="Exclude Python standard libraries"
+     )
+     args = parser.parse_args()
+
+     all_imports = set()
+     for file_path in glob.glob(args.glob_pattern, recursive=True):
+         all_imports.update(get_imports(Path(file_path)))
+
+     if args.exclude_std:
+         std_libs = set(sys.stdlib_module_names)
+         std_libs.update({"pkg_resources"})
+         all_imports = all_imports - std_libs
+
+     for package in sorted(all_imports):
+         print(package)
+
+
+ if __name__ == "__main__":
+     main()
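
The script can also be driven programmatically; a small sketch calling `get_imports` on this package's own `nn/mlp.py` (path assumed relative to the repository root):

    from pathlib import Path
    from nshtrainer.scripts.find_packages import get_imports

    # Relative imports (level > 0) are skipped by get_imports, so only the
    # top-level absolute imports of mlp.py are reported.
    print(sorted(get_imports(Path("nshtrainer/nn/mlp.py"))))
    # ['collections', 'copy', 'torch', 'typing', 'typing_extensions']

From the command line, the same logic runs over a glob, e.g. `python -m nshtrainer.scripts.find_packages "nshtrainer/**/*.py" --exclude-std`.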
nshtrainer/trainer/__init__.py ADDED
@@ -0,0 +1 @@
+ from .trainer import Trainer as Trainer