nshtrainer 0.5.3__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/nn/__init__.py CHANGED
@@ -1,4 +1,6 @@
1
1
  from .mlp import MLP as MLP
2
+ from .mlp import MLPConfig as MLPConfig
3
+ from .mlp import MLPConfigDict as MLPConfigDict
2
4
  from .mlp import ResidualSequential as ResidualSequential
3
5
  from .module_dict import TypedModuleDict as TypedModuleDict
4
6
  from .module_list import TypedModuleList as TypedModuleList
nshtrainer/nn/mlp.py CHANGED
@@ -2,9 +2,10 @@ import copy
2
2
  from collections.abc import Callable, Sequence
3
3
  from typing import Literal, Protocol, runtime_checkable
4
4
 
5
+ import nshconfig as C
5
6
  import torch
6
7
  import torch.nn as nn
7
- from typing_extensions import override
8
+ from typing_extensions import TypedDict, override
8
9
 
9
10
  from .nonlinearity import BaseNonlinearityConfig
10
11
 
@@ -22,6 +23,71 @@ class ResidualSequential(nn.Sequential):
22
23
  return input + super().forward(input)
23
24
 
24
25
 
26
class MLPConfigDict(TypedDict):
    """Keyword-argument view of an MLP configuration.

    The keys match the dictionary produced by ``MLPConfig.to_kwargs``.
    """

    bias: bool
    """If true, the linear layers carry bias terms."""

    no_bias_scalar: bool
    """If true, bias terms are dropped for a layer whose output dimension is 1."""

    nonlinearity: BaseNonlinearityConfig | None
    """Activation placed between layers (``None`` for no activation config)."""

    ln: bool | Literal["pre", "post"]
    """Layer-normalization setting: a flag, or ``"pre"``/``"post"`` relative to the linear layers."""

    dropout: float | None
    """Dropout probability used between layers (``None`` to leave it unset)."""

    residual: bool
    """If true, residual connections are used between layers."""
44
+
45
+
46
class MLPConfig(C.Config):
    """Configuration object for building an ``MLP``.

    The fields mirror ``MLPConfigDict``. ``to_kwargs`` turns them into keyword
    arguments, and ``create_module`` forwards those to ``MLP``.
    """

    bias: bool = True
    """Whether to include bias terms in the linear layers."""

    no_bias_scalar: bool = True
    """Whether to exclude bias terms when the output dimension is 1."""

    nonlinearity: BaseNonlinearityConfig | None = None
    """Activation function to use between layers."""

    ln: bool | Literal["pre", "post"] = False
    """Whether to apply layer normalization before or after the linear layers."""

    dropout: float | None = None
    """Dropout probability to apply between layers."""

    residual: bool = False
    """Whether to use residual connections between layers."""

    def to_kwargs(self) -> MLPConfigDict:
        """Return this config's fields as a keyword-argument dict.

        Values are returned as-is (no copying), so ``"nonlinearity"`` refers
        to the same object held by this instance.
        """
        return {
            "bias": self.bias,
            "no_bias_scalar": self.no_bias_scalar,
            "nonlinearity": self.nonlinearity,
            "ln": self.ln,
            "dropout": self.dropout,
            "residual": self.residual,
        }

    def create_module(
        self,
        dims: Sequence[int],
        pre_layers: Sequence[nn.Module] | None = None,
        post_layers: Sequence[nn.Module] | None = None,
        linear_cls: LinearModuleConstructor = nn.Linear,
    ):
        """Build an ``MLP`` from this config.

        Args:
            dims: Layer dimensions, forwarded to ``MLP`` unchanged.
            pre_layers: Modules forwarded to ``MLP`` as ``pre_layers``.
                ``None`` (the default) means a fresh empty list.
            post_layers: Modules forwarded to ``MLP`` as ``post_layers``.
                ``None`` (the default) means a fresh empty list.
            linear_cls: Constructor used for the linear layers
                (default ``nn.Linear``).

        Returns:
            Whatever ``MLP`` returns for these arguments.
        """
        # None sentinels instead of ``[]`` defaults: a list literal default is
        # one object shared by every call, which is unsafe if it is ever mutated.
        if pre_layers is None:
            pre_layers = []
        if post_layers is None:
            post_layers = []

        # NOTE(review): MLP's signature begins ``dims, activation, ...``; confirm
        # it also accepts every key emitted by ``to_kwargs`` (notably
        # ``nonlinearity``), otherwise this call raises TypeError.
        return MLP(
            dims,
            **self.to_kwargs(),
            pre_layers=pre_layers,
            post_layers=post_layers,
            linear_cls=linear_cls,
        )
89
+
90
+
25
91
  def MLP(
26
92
  dims: Sequence[int],
27
93
  activation: BaseNonlinearityConfig
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.5.3
3
+ Version: 0.6.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -12,9 +12,9 @@ Classifier: Programming Language :: Python :: 3.12
12
12
  Requires-Dist: lightning
13
13
  Requires-Dist: lovely-numpy (>=0.2.13,<0.3.0)
14
14
  Requires-Dist: lovely-tensors (>=0.1.16,<0.2.0)
15
- Requires-Dist: nshconfig (>=0.2.0,<0.3.0)
16
- Requires-Dist: nshrunner (>=0.6.1,<0.7.0)
17
- Requires-Dist: nshutils (>=0.3.0,<0.4.0)
15
+ Requires-Dist: nshconfig (>=0,<1)
16
+ Requires-Dist: nshrunner (>=0,<1)
17
+ Requires-Dist: nshutils (>=0,<1)
18
18
  Requires-Dist: numpy
19
19
  Requires-Dist: pytorch-lightning
20
20
  Requires-Dist: rich
@@ -41,8 +41,8 @@ nshtrainer/model/modules/logger.py,sha256=XEeo3QrplTNKZqfl6iWZf3fze3R4YOeOvs-RKV
41
41
  nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
42
42
  nshtrainer/model/modules/rlp_sanity_checks.py,sha256=o6gUceFwsuDHmL8eLOYuT3JGXFzq_qc4awl2RWaBygU,8900
43
43
  nshtrainer/model/modules/shared_parameters.py,sha256=mD5wrlBE3c025vzVdTpnSyC8yxzuI-aUWMmPhqPT0a0,2694
44
- nshtrainer/nn/__init__.py,sha256=57LPaP3G-BBGD2eGxbBUABNgYl3s_oASwrtOSS4bzTs,1339
45
- nshtrainer/nn/mlp.py,sha256=i-dHk0tomO_XlU6cKN4CC4HxTaYb-ukBCAgY1ySXl4I,3963
44
+ nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
45
+ nshtrainer/nn/mlp.py,sha256=tX1VdtdzB0dyMkV0oEGCcob9hsYrXEIeusXfb-bJ5lQ,5940
46
46
  nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
47
47
  nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
48
48
  nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
@@ -59,6 +59,6 @@ nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
59
59
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
60
60
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
61
61
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
62
- nshtrainer-0.5.3.dist-info/METADATA,sha256=WbSdHGLe7sAKHKZWi5C7KjG-MwmWbcxTiNL67yqTwFs,812
63
- nshtrainer-0.5.3.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
64
- nshtrainer-0.5.3.dist-info/RECORD,,
62
+ nshtrainer-0.6.1.dist-info/METADATA,sha256=A99lygdq2iZY6oebqY5iOFcU71GX0t_cQN3cwKUlCVg,788
63
+ nshtrainer-0.6.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
64
+ nshtrainer-0.6.1.dist-info/RECORD,,