nmn 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nmn/nnx/yatconv.py CHANGED
@@ -24,6 +24,7 @@ Array = jax.Array
24
24
  # Default initializers
25
25
  default_kernel_init = initializers.lecun_normal()
26
26
  default_bias_init = initializers.zeros_init()
27
+ default_alpha_init = initializers.ones_init()
27
28
 
28
29
  # Helper functions
29
30
  def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
@@ -138,13 +139,17 @@ class YatConv(Module):
138
139
  input_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
139
140
  kernel_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
140
141
  feature_group_count: int = 1,
142
+
141
143
  use_bias: bool = True,
144
+ use_alpha: bool = True,
145
+ kernel_init: Initializer = default_kernel_init,
146
+ bias_init: Initializer = default_bias_init,
147
+ alpha_init: Initializer = default_alpha_init,
148
+
142
149
  mask: tp.Optional[Array] = None,
143
150
  dtype: tp.Optional[Dtype] = None,
144
151
  param_dtype: Dtype = jnp.float32,
145
152
  precision: PrecisionLike = None,
146
- kernel_init: Initializer = default_kernel_init,
147
- bias_init: Initializer = default_bias_init,
148
153
  conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
149
154
  promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
150
155
  epsilon: float = 1e-5,
@@ -179,6 +184,8 @@ class YatConv(Module):
179
184
  self.kernel_dilation = kernel_dilation
180
185
  self.feature_group_count = feature_group_count
181
186
  self.use_bias = use_bias
187
+ self.use_alpha = use_alpha
188
+
182
189
  self.mask = mask
183
190
  self.dtype = dtype
184
191
  self.param_dtype = param_dtype
@@ -189,6 +196,13 @@ class YatConv(Module):
189
196
  self.promote_dtype = promote_dtype
190
197
  self.epsilon = epsilon
191
198
 
199
+ if use_alpha:
200
+ alpha_key = rngs.params()
201
+ self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
202
+ else:
203
+ self.alpha = None
204
+
205
+
192
206
  def __call__(self, inputs: Array) -> Array:
193
207
  assert isinstance(self.kernel_size, tuple)
194
208
 
@@ -257,6 +271,7 @@ class YatConv(Module):
257
271
  kernel_val *= current_mask
258
272
 
259
273
  bias_val = self.bias.value if self.bias is not None else None
274
+ alpha = self.alpha.value if self.alpha is not None else None
260
275
 
261
276
  inputs_promoted, kernel_promoted, bias_promoted = self.promote_dtype(
262
277
  (inputs_flat, kernel_val, bias_val), dtype=self.dtype
@@ -314,6 +329,11 @@ class YatConv(Module):
314
329
  bias_reshape_dims = (1,) * (y.ndim - 1) + (-1,)
315
330
  y += jnp.reshape(bias_val, bias_reshape_dims)
316
331
 
332
+ assert self.use_alpha == (alpha is not None)
333
+ if alpha is not None:
334
+ scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
335
+ y = y * scale
336
+
317
337
  if num_batch_dimensions != 1:
318
338
  output_shape = input_batch_shape + y.shape[1:]
319
339
  y = jnp.reshape(y, output_shape)
nmn/torch/nmn.py CHANGED
@@ -10,7 +10,8 @@ class YatNMN(nn.Module):
10
10
  Attributes:
11
11
  in_features (int): Size of each input sample
12
12
  out_features (int): Size of each output sample
13
- use_bias (bool): Whether to add a bias to the output
13
+ bias (bool): Whether to add a bias to the output
14
+ alpha (bool): Whether to multiply with alpha
14
15
  dtype (torch.dtype): Data type for computation
15
16
  epsilon (float): Small constant to avoid division by zero
16
17
  kernel_init (callable): Initializer for the weight matrix
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nmn
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Summary: a neuron that matters
5
5
  Project-URL: Homepage, https://github.com/mlnomadpy/nmn
6
6
  Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
@@ -34,7 +34,7 @@ Not the neurons we want, but the neurons we need
34
34
 
35
35
  Yat-Product:
36
36
  $$
37
- ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{((\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w})) + \epsilon}.
37
+ ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{(\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w}) + \epsilon}.
38
38
  $$
39
39
 
40
40
  **Explanation:**
@@ -0,0 +1,14 @@
1
+ nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
2
+ nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
3
+ nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
4
+ nmn/nnx/nmn.py,sha256=gWe8EL-aUm7be03M9O5R3XdBb92EpBEFsylrY6BA60c,4871
5
+ nmn/nnx/yatattention.py,sha256=chjtUKJtaR7ROPnNqkicbvMs7hzZKE0fIo_8cTNiju8,26601
6
+ nmn/nnx/yatconv.py,sha256=xUH9NBY1fIDZeTA9GdgmqR_DJiQJgwU2uDrgxqirKmU,12308
7
+ nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
8
+ nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
9
+ nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
10
+ nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
11
+ nmn-0.1.4.dist-info/METADATA,sha256=k28p055Dr6WWVQcb01uinFRiT5R-CAvdKz33fqZ85g4,5032
12
+ nmn-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
+ nmn-0.1.4.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
14
+ nmn-0.1.4.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
2
- nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
3
- nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
4
- nmn/nnx/nmn.py,sha256=hZDgMnGnSnBSqMbk-z7qUt8QsHEM-2o6CVWacXZfz3E,4870
5
- nmn/nnx/yatconv.py,sha256=EZx6g-KcuwrPNEVPl8YdQ16ZXkly_m0XvYCIoWVwFc0,11742
6
- nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
7
- nmn/torch/nmn.py,sha256=qOFOlH4_pCOQr_4ctGpEbnW3DAGQotijDTKu5aIEXaE,4609
8
- nmn-0.1.2.dist-info/METADATA,sha256=MxRIZIm8TIcvUAyW-5gYBu88g4hF-upahr3e2tfrWE8,5030
9
- nmn-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- nmn-0.1.2.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
11
- nmn-0.1.2.dist-info/RECORD,,
File without changes