nmn-0.1.5-py3-none-any.whl → nmn-0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nmn/nnx/squashers/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .softermax import softermax
+ from .softer_sigmoid import softer_sigmoid
+ from .soft_tanh import soft_tanh
+
+ __all__ = [
+     "softermax",
+     "softer_sigmoid",
+     "soft_tanh",
+ ]
nmn/nnx/squashers/soft_tanh.py ADDED
@@ -0,0 +1,29 @@
+ import jax.numpy as jnp
+ from jax import Array
+
+ def soft_tanh(
+     x: Array,
+     n: float = 1.0,
+ ) -> Array:
+     """
+     Maps a non-negative score to the range [-1, 1) using the soft-tanh function.
+
+     The soft-tanh function is defined as:
+     .. math::
+         \\text{soft-tanh}_n(x) = \\frac{x^n - 1}{1 + x^n}
+
+     The power `n` controls the transition sharpness: higher `n` makes the
+     function approach 1 more quickly for large `x`.
+
+     Args:
+         x (Array): A JAX array of non-negative scores (x >= 0).
+         n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+     Returns:
+         Array: The mapped scores in the range [-1, 1).
+     """
+     if n <= 0:
+         raise ValueError("Power 'n' must be positive.")
+
+     x_n = jnp.power(x, n)
+     return (x_n - 1.0) / (1.0 + x_n)
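Usage sketch (illustrative; the import path follows the nmn/nnx/squashers entries in the RECORD below, and the commented values follow from the formula above):

    import jax.numpy as jnp
    from nmn.nnx.squashers import soft_tanh

    x = jnp.array([0.0, 1.0, 3.0])
    y1 = soft_tanh(x)         # n=1: (x - 1) / (1 + x)     -> [-1.0, 0.0, 0.5]
    y2 = soft_tanh(x, n=2.0)  # n=2: (x^2 - 1) / (1 + x^2) -> [-1.0, 0.0, 0.8]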
nmn/nnx/squashers/softer_sigmoid.py ADDED
@@ -0,0 +1,29 @@
+ import jax.numpy as jnp
+ from jax import Array
+
+ def softer_sigmoid(
+     x: Array,
+     n: float = 1.0,
+ ) -> Array:
+     """
+     Squashes a non-negative score into the range [0, 1) using the soft-sigmoid function.
+
+     The soft-sigmoid function is defined as:
+     .. math::
+         \\text{soft-sigmoid}_n(x) = \\frac{x^n}{1 + x^n}
+
+     The power `n` modulates the softness: higher `n` makes the function approach
+     1 more quickly for large `x`, while `n < 1` makes the transition more gradual.
+
+     Args:
+         x (Array): A JAX array of non-negative scores (x >= 0).
+         n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+     Returns:
+         Array: The squashed scores in the range [0, 1).
+     """
+     if n <= 0:
+         raise ValueError("Power 'n' must be positive.")
+
+     x_n = jnp.power(x, n)
+     return x_n / (1.0 + x_n)
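Usage sketch (illustrative; same assumptions as the soft_tanh example above):

    import jax.numpy as jnp
    from nmn.nnx.squashers import softer_sigmoid

    x = jnp.array([0.0, 1.0, 3.0])
    g1 = softer_sigmoid(x)         # n=1: x / (1 + x)     -> [0.0, 0.5, 0.75]
    g2 = softer_sigmoid(x, n=2.0)  # n=2: x^2 / (1 + x^2) -> [0.0, 0.5, 0.9]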
nmn/nnx/squashers/softermax.py ADDED
@@ -0,0 +1,38 @@
+ import jax.numpy as jnp
+ from jax import Array
+ from typing import Optional
+
+ def softermax(
+     x: Array,
+     n: float = 1.0,
+     epsilon: float = 1e-12,
+     axis: Optional[int] = -1,
+ ) -> Array:
+     """
+     Normalizes a set of non-negative scores using the Softermax function.
+
+     The Softermax function is defined as:
+     .. math::
+         \\text{softermax}_n(x_k, \\{x_i\\}) = \\frac{x_k^n}{\\epsilon + \\sum_i x_i^n}
+
+     The power `n` controls the sharpness of the distribution: `n = 1` recovers
+     the original Softermax, while `n > 1` makes the distribution harder (more
+     peaked), and `0 < n < 1` makes it softer.
+
+     Args:
+         x (Array): A JAX array of non-negative scores.
+         n (float, optional): The power to raise each score to. Defaults to 1.0.
+         epsilon (float, optional): A small constant for numerical stability.
+             Defaults to 1e-12.
+         axis (Optional[int], optional): The axis to perform the sum over.
+             Defaults to -1.
+
+     Returns:
+         Array: The normalized scores.
+     """
+     if n <= 0:
+         raise ValueError("Power 'n' must be positive.")
+
+     x_n = jnp.power(x, n)
+     sum_x_n = jnp.sum(x_n, axis=axis, keepdims=True)
+     return x_n / (epsilon + sum_x_n)
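Usage sketch (illustrative; values follow from the formula above with the default epsilon, which only matters for an all-zero row):

    import jax.numpy as jnp
    from nmn.nnx.squashers import softermax

    scores = jnp.array([[1.0, 2.0, 4.0]])
    p1 = softermax(scores)         # n=1: [1/7, 2/7, 4/7]     ~ [0.143, 0.286, 0.571]
    p2 = softermax(scores, n=2.0)  # n=2: [1/21, 4/21, 16/21] -- harder (more peaked)
    # Each row sums to ~1, up to the epsilon term in the denominator.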
nmn/nnx/yatattention.py CHANGED
@@ -26,7 +26,8 @@ from flax.typing import (
   DotGeneralT,
 )
 
-
+ from nmn.nnx.nmn import YatNMN
+ from jax import Array
 
 def yat_attention_weights(
   query: Array,
@@ -153,91 +154,6 @@ def yat_attention(
     '...hqk,...khd->...qhd', attn_weights, value, precision=precision
   )
 
- Array = jax.Array
-
- # Add YatNMN class implementation
- default_bias_init = initializers.zeros_init()
- default_alpha_init = initializers.ones_init()
-
- class YatNMN(Module):
-   """A linear transformation with custom distance-based computation."""
-
-   def __init__(
-     self,
-     in_features: int,
-     out_features: int,
-     *,
-     use_bias: bool = True,
-     use_alpha: bool = True,
-     dtype: Optional[Dtype] = None,
-     param_dtype: Dtype = jnp.float32,
-     precision: PrecisionLike = None,
-     kernel_init: Initializer = default_kernel_init,
-     bias_init: Initializer = default_bias_init,
-     alpha_init: Initializer = default_alpha_init,
-     dot_general: DotGeneralT = lax.dot_general,
-     rngs: rnglib.Rngs,
-     epsilon: float = 1e-5,
-   ):
-
-     kernel_key = rngs.params()
-     self.kernel = nnx.Param(
-       kernel_init(kernel_key, (in_features, out_features), param_dtype)
-     )
-     self.bias: nnx.Param[jax.Array] | None
-     if use_bias:
-       bias_key = rngs.params()
-       self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
-     else:
-       self.bias = None
-
-     self.alpha: nnx.Param[jax.Array] | None
-     if use_alpha:
-       alpha_key = rngs.params()
-       self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
-     else:
-       self.alpha = None
-
-     self.in_features = in_features
-     self.out_features = out_features
-     self.use_bias = use_bias
-     self.use_alpha = use_alpha
-     self.dtype = dtype
-     self.param_dtype = param_dtype
-     self.precision = precision
-     self.kernel_init = kernel_init
-     self.bias_init = bias_init
-     self.dot_general = dot_general
-     self.epsilon = epsilon
-
-   def __call__(self, inputs: Array) -> Array:
-     """Applies YatNMN transformation to inputs."""
-     kernel = self.kernel.value
-     bias = self.bias.value if self.bias is not None else None
-     alpha = self.alpha.value if self.alpha is not None else None
-
-     y = self.dot_general(
-       inputs,
-       kernel,
-       (((inputs.ndim - 1,), (0,)), ((), ())),
-       precision=self.precision,
-     )
-
-     inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
-     kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
-     distances = inputs_squared_sum + kernel_squared_sum - 2 * y
-
-     # Element-wise operation
-     y = y ** 2 / (distances + self.epsilon)
-
-     if bias is not None:
-       y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
-
-     if alpha is not None:
-       scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
-       y = y * scale
-
-     return y
 
 
  def dot_product_attention_weights(
@@ -435,6 +351,10 @@ class MultiHeadAttention(Module):
     attention_fn: Callable[..., Array] = yat_attention,
     decode: bool | None = None,
     normalize_qk: bool = False,
+    use_alpha: bool = True,
+    alpha_init: Initializer = initializers.ones_init(),
+    use_dropconnect: bool = False,
+    dropconnect_rate: float = 0.0,
     # Deprecated, will be removed.
     qkv_dot_general: DotGeneralT | None = None,
     out_dot_general: DotGeneralT | None = None,
@@ -470,6 +390,10 @@
     self.qkv_dot_general_cls = qkv_dot_general_cls
     self.out_dot_general_cls = out_dot_general_cls
     self.epsilon = epsilon
+    self.use_alpha = use_alpha
+    self.alpha_init = alpha_init
+    self.use_dropconnect = use_dropconnect
+    self.dropconnect_rate = dropconnect_rate
 
     if self.qkv_features % self.num_heads != 0:
       raise ValueError(
@@ -491,6 +415,10 @@
       use_bias=self.use_bias,
       precision=self.precision,
       epsilon=self.epsilon,
+      use_alpha=self.use_alpha,
+      alpha_init=self.alpha_init,
+      use_dropconnect=self.use_dropconnect,
+      drop_rate=self.dropconnect_rate,
     )
 
     # project inputs_q to multi-headed q/k/v
@@ -590,10 +518,23 @@
         f'but module expects {self.in_features}.'
       )
 
+    is_deterministic: bool = False
+    if self.dropout_rate > 0.0 or (
+      self.use_dropconnect and self.dropconnect_rate > 0.0
+    ):
+      is_deterministic = first_from(
+        deterministic,
+        self.deterministic,
+        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
+          as either a __call__ argument, class attribute, or nnx.flag.""",
+      )
+    else:
+      is_deterministic = True
+
     # Apply YatNMN transformations and reshape to multi-head format
-    query = squash(self.query(inputs_q))
-    key = squash(self.key(inputs_k))
-    value = squash(self.value(inputs_v))
+    query = self.query(inputs_q, deterministic=is_deterministic)
+    key = self.key(inputs_k, deterministic=is_deterministic)
+    value = self.value(inputs_v, deterministic=is_deterministic)
 
     # Reshape from [batch..., length, qkv_features] to [batch..., length, num_heads, head_dim]
     query = query.reshape(query.shape[:-1] + (self.num_heads, self.head_dim))
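Call sketch for the new determinism handling (hypothetical: assumes the __call__ signature keeps flax.nnx-style `deterministic` and `rngs` keywords and self-attention when only one input is given; reuses `attn` from the construction sketch above):

    import jax.numpy as jnp
    from flax import nnx

    x = jnp.ones((2, 16, 256))  # (batch, length, features) -- illustrative shapes

    # Training: dropout/DropConnect are active, so an RNG stream is required.
    y_train = attn(x, deterministic=False, rngs=nnx.Rngs(dropout=0))

    # Evaluation: deterministic=True disables both dropout and DropConnect.
    y_eval = attn(x, deterministic=True)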
@@ -660,26 +601,11 @@
       ),
     )
 
-    if (
-      self.dropout_rate > 0.0
-    ):  # Require `deterministic` only if using dropout.
-      deterministic = first_from(
-        deterministic,
-        self.deterministic,
-        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
-          as either a __call__ argument, class attribute, or nnx.flag.""",
-      )
-      if not deterministic:
-        if rngs is None:
-          raise ValueError(
-            "'rngs' must be provided if 'dropout_rng' is not given."
-          )
-        dropout_rng = rngs.dropout()
-      else:
-        dropout_rng = None
-    else:
-      deterministic = True
-      dropout_rng = None
+    dropout_rng = None
+    if self.dropout_rate > 0.0 and not is_deterministic:
+      if rngs is None:
+        raise ValueError("'rngs' must be provided for dropout.")
+      dropout_rng = rngs.dropout()
 
     # apply attention with epsilon parameter for YatNMN
     x = self.attention_fn(
@@ -690,7 +616,7 @@
       dropout_rng=dropout_rng,
       dropout_rate=self.dropout_rate,
       broadcast_dropout=self.broadcast_dropout,
-      deterministic=deterministic,
+      deterministic=is_deterministic,
       dtype=self.dtype,
       precision=self.precision,
       module=self if sow_weights else None,
nmn-0.1.5.dist-info/METADATA → nmn-0.1.6.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nmn
- Version: 0.1.5
+ Version: 0.1.6
  Summary: a neuron that matter
  Project-URL: Homepage, https://github.com/mlnomadpy/nmn
  Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
nmn-0.1.6.dist-info/RECORD ADDED
@@ -0,0 +1,19 @@
+ nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
+ nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
+ nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
+ nmn/nnx/nmn.py,sha256=tPNUtF8Lmv_B1TgMoVXfMQ9x0IPGKjSyAP6HnZ-YBsM,5651
+ nmn/nnx/yatattention.py,sha256=i6XfCGHISyb2P6KrgYFnhhdzqSTWAyshFhy1XEeuEWc,24642
+ nmn/nnx/yatconv.py,sha256=EOAAWfuv5QA-QTru-JyYKYNoGqxcklu7ph9a-CtmYsA,13123
+ nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
+ nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
+ nmn/nnx/loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ nmn/nnx/squashers/__init__.py,sha256=zXYPa3yzqMXxkIPvNHiaV6pcZRDOdVrzaVdYVDGALTY,180
+ nmn/nnx/squashers/soft_tanh.py,sha256=WSJkxD6L9WU1eqPwsK2AW4V6OJbw5pSWYjKwkiWtLdo,812
+ nmn/nnx/squashers/softer_sigmoid.py,sha256=vE6IWorZdBb2cww6fskARnwzdjTcWB2kKohuaJWVGNs,845
+ nmn/nnx/squashers/softermax.py,sha256=NfxEDbogLUysyTvtVCTpDt27PplYvKRQLTZbYCL-Wfg,1226
+ nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
+ nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
+ nmn-0.1.6.dist-info/METADATA,sha256=Y9MByC16wz1MGYVZRmZA0wJQATB6Kj6w6TOL5lPzl0Q,8800
+ nmn-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ nmn-0.1.6.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
+ nmn-0.1.6.dist-info/RECORD,,
nmn-0.1.5.dist-info/RECORD DELETED
@@ -1,14 +0,0 @@
- nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
- nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
- nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
- nmn/nnx/nmn.py,sha256=tPNUtF8Lmv_B1TgMoVXfMQ9x0IPGKjSyAP6HnZ-YBsM,5651
- nmn/nnx/yatattention.py,sha256=chjtUKJtaR7ROPnNqkicbvMs7hzZKE0fIo_8cTNiju8,26601
- nmn/nnx/yatconv.py,sha256=EOAAWfuv5QA-QTru-JyYKYNoGqxcklu7ph9a-CtmYsA,13123
- nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
- nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
- nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
- nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
- nmn-0.1.5.dist-info/METADATA,sha256=7gvXle6Hgdgyj_tJk1DGdkOh03BOsfSks-ZHPOIEwHQ,8800
- nmn-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- nmn-0.1.5.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
- nmn-0.1.5.dist-info/RECORD,,