nmn 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nmn/nnx/examples/language/mingpt.py +1650 -0
- nmn/nnx/examples/vision/cnn_cifar.py +1769 -0
- nmn/nnx/nmn.py +1 -1
- nmn/nnx/yatattention.py +764 -0
- nmn/nnx/yatconv.py +22 -2
- nmn/torch/nmn.py +2 -1
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/METADATA +2 -2
- nmn-0.1.4.dist-info/RECORD +14 -0
- nmn-0.1.2.dist-info/RECORD +0 -11
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/WHEEL +0 -0
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/licenses/LICENSE +0 -0
nmn/nnx/yatconv.py
CHANGED
@@ -24,6 +24,7 @@ Array = jax.Array
 # Default initializers
 default_kernel_init = initializers.lecun_normal()
 default_bias_init = initializers.zeros_init()
+default_alpha_init = initializers.ones_init()
 
 # Helper functions
 def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
@@ -138,13 +139,17 @@ class YatConv(Module):
     input_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
     kernel_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
     feature_group_count: int = 1,
+
     use_bias: bool = True,
+    use_alpha: bool = True,
+    kernel_init: Initializer = default_kernel_init,
+    bias_init: Initializer = default_bias_init,
+    alpha_init: Initializer = default_alpha_init,
+
     mask: tp.Optional[Array] = None,
     dtype: tp.Optional[Dtype] = None,
     param_dtype: Dtype = jnp.float32,
     precision: PrecisionLike = None,
-    kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = default_bias_init,
     conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
     promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
     epsilon: float = 1e-5,
@@ -179,6 +184,8 @@ class YatConv(Module):
     self.kernel_dilation = kernel_dilation
     self.feature_group_count = feature_group_count
     self.use_bias = use_bias
+    self.use_alpha = use_alpha
+
     self.mask = mask
     self.dtype = dtype
     self.param_dtype = param_dtype
@@ -189,6 +196,13 @@ class YatConv(Module):
     self.promote_dtype = promote_dtype
     self.epsilon = epsilon
 
+    if use_alpha:
+      alpha_key = rngs.params()
+      self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
+    else:
+      self.alpha = None
+
+
   def __call__(self, inputs: Array) -> Array:
     assert isinstance(self.kernel_size, tuple)
 
@@ -257,6 +271,7 @@ class YatConv(Module):
       kernel_val *= current_mask
 
     bias_val = self.bias.value if self.bias is not None else None
+    alpha = self.alpha.value if self.alpha is not None else None
 
     inputs_promoted, kernel_promoted, bias_promoted = self.promote_dtype(
       (inputs_flat, kernel_val, bias_val), dtype=self.dtype
@@ -314,6 +329,11 @@ class YatConv(Module):
       bias_reshape_dims = (1,) * (y.ndim - 1) + (-1,)
       y += jnp.reshape(bias_val, bias_reshape_dims)
 
+    assert self.use_alpha == (alpha is not None)
+    if alpha is not None:
+      scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
+      y = y * scale
+
     if num_batch_dimensions != 1:
       output_shape = input_batch_shape + y.shape[1:]
       y = jnp.reshape(y, output_shape)
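The functional change above is the new learnable alpha scaling: when use_alpha is enabled, YatConv stores a single-element nnx.Param (initialized to ones) and multiplies the convolution output by (sqrt(out_features) / log(1 + out_features)) ** alpha. A minimal jax.numpy sketch of that scale factor; the channel count and alpha value below are illustrative, not taken from the package:

import jax.numpy as jnp

def yat_alpha_scale(out_features: int, alpha: float) -> jnp.ndarray:
    # Same expression as the scaling added in YatConv.__call__:
    # (sqrt(out_features) / log(1 + out_features)) ** alpha
    return (jnp.sqrt(out_features) / jnp.log(1.0 + out_features)) ** alpha

# With the default alpha_init = ones_init(), alpha starts at 1.0,
# so a 64-channel layer would be scaled by sqrt(64) / log(65) ≈ 1.92.
print(yat_alpha_scale(64, 1.0))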
nmn/torch/nmn.py
CHANGED
@@ -10,7 +10,8 @@ class YatNMN(nn.Module):
     Attributes:
         in_features (int): Size of each input sample
         out_features (int): Size of each output sample
-
+        bias (bool): Whether to add a bias to the output
+        alpha (bool): Whether to multiply with alpha
         dtype (torch.dtype): Data type for computation
        epsilon (float): Small constant to avoid division by zero
        kernel_init (callable): Initializer for the weight matrix
{nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nmn
-Version: 0.1.2
+Version: 0.1.4
 Summary: a neuron that matter
 Project-URL: Homepage, https://github.com/mlnomadpy/nmn
 Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
@@ -34,7 +34,7 @@ Not the neurons we want, but the neurons we need
 
 Yat-Product:
 $$
-ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{((\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w})) + \epsilon}.
+ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{((\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w}))^2 + \epsilon}.
 $$
 
 **Explanation:**
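In its leading form, the Yat-Product shown above is ⟨w, x⟩² / (‖w − x‖² + ε). A minimal jax.numpy sketch of that leading form; the vectors and ε below are illustrative, not taken from the package code:

import jax.numpy as jnp

def yat_product(w, x, epsilon=1e-5):
    # <w, x>^2 / (||w - x||^2 + epsilon): squared dot product over
    # squared Euclidean distance, with epsilon guarding division by zero
    return jnp.dot(w, x) ** 2 / (jnp.sum((w - x) ** 2) + epsilon)

w = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.5, 1.0, -1.0])
print(yat_product(w, x))  # a scalar similarity score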
nmn-0.1.4.dist-info/RECORD
ADDED
@@ -0,0 +1,14 @@
+nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
+nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
+nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
+nmn/nnx/nmn.py,sha256=gWe8EL-aUm7be03M9O5R3XdBb92EpBEFsylrY6BA60c,4871
+nmn/nnx/yatattention.py,sha256=chjtUKJtaR7ROPnNqkicbvMs7hzZKE0fIo_8cTNiju8,26601
+nmn/nnx/yatconv.py,sha256=xUH9NBY1fIDZeTA9GdgmqR_DJiQJgwU2uDrgxqirKmU,12308
+nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
+nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
+nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
+nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
+nmn-0.1.4.dist-info/METADATA,sha256=k28p055Dr6WWVQcb01uinFRiT5R-CAvdKz33fqZ85g4,5032
+nmn-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+nmn-0.1.4.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
+nmn-0.1.4.dist-info/RECORD,,
nmn-0.1.2.dist-info/RECORD
DELETED
@@ -1,11 +0,0 @@
-nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
-nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
-nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
-nmn/nnx/nmn.py,sha256=hZDgMnGnSnBSqMbk-z7qUt8QsHEM-2o6CVWacXZfz3E,4870
-nmn/nnx/yatconv.py,sha256=EZx6g-KcuwrPNEVPl8YdQ16ZXkly_m0XvYCIoWVwFc0,11742
-nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
-nmn/torch/nmn.py,sha256=qOFOlH4_pCOQr_4ctGpEbnW3DAGQotijDTKu5aIEXaE,4609
-nmn-0.1.2.dist-info/METADATA,sha256=MxRIZIm8TIcvUAyW-5gYBu88g4hF-upahr3e2tfrWE8,5030
-nmn-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-nmn-0.1.2.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
-nmn-0.1.2.dist-info/RECORD,,
{nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/WHEEL
File without changes
{nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/licenses/LICENSE
File without changes