nmn 0.1.5__tar.gz → 0.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nmn
-Version: 0.1.5
+Version: 0.1.7
 Summary: a neuron that matter
 Project-URL: Homepage, https://github.com/mlnomadpy/nmn
 Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "nmn"
-version = "0.1.5"
+version = "0.1.7"
 authors = [
   { name="Taha Bouhsine", email="yat@mlnomads.com" },
 ]
@@ -0,0 +1,9 @@
+from .softermax import softermax
+from .softer_sigmoid import softer_sigmoid
+from .soft_tanh import soft_tanh
+
+__all__ = [
+    "softermax",
+    "softer_sigmoid",
+    "soft_tanh",
+]
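The new nmn.nnx.squashers package re-exports all three squashing functions from one path, which is the import used later in attention.py. A minimal usage sketch (assuming only that jax and the nmn package are installed):

import jax.numpy as jnp
from nmn.nnx.squashers import softermax, softer_sigmoid, soft_tanh

scores = jnp.array([0.5, 1.0, 2.0])
print(softermax(scores))       # sums to ~1 along the last axis
print(softer_sigmoid(scores))  # each score squashed into [0, 1)
print(soft_tanh(scores))       # each score mapped into [-1, 1)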
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+from jax import Array
+
+def soft_tanh(
+    x: Array,
+    n: float = 1.0,
+) -> Array:
+    """
+    Maps a non-negative score to the range [-1, 1) using the soft-tanh function.
+
+    The soft-tanh function is defined as:
+    .. math::
+        \\text{soft-tanh}_n(x) = \\frac{x^n - 1}{1 + x^n}
+
+    The power `n` controls the transition sharpness: higher `n` makes the
+    function approach 1 more quickly for large `x`.
+
+    Args:
+        x (Array): A JAX array of non-negative scores (x >= 0).
+        n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+    Returns:
+        Array: The mapped scores in the range [-1, 1).
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    return (x_n - 1.0) / (1.0 + x_n)
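A quick sanity check of the endpoints, as a hedged sketch rather than a test from the package: soft_tanh(0) = -1, soft_tanh(1) = 0, and large scores approach (but never reach) 1, faster for larger n.

import jax.numpy as jnp
from nmn.nnx.squashers import soft_tanh

x = jnp.array([0.0, 1.0, 3.0, 100.0])
print(soft_tanh(x))         # [-1.0, 0.0, 0.5, ~0.98]
print(soft_tanh(x, n=2.0))  # sharper transition: [-1.0, 0.0, 0.8, ~0.9998]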
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+from jax import Array
+
+def softer_sigmoid(
+    x: Array,
+    n: float = 1.0,
+) -> Array:
+    """
+    Squashes a non-negative score into the range [0, 1) using the soft-sigmoid function.
+
+    The soft-sigmoid function is defined as:
+    .. math::
+        \\text{soft-sigmoid}_n(x) = \\frac{x^n}{1 + x^n}
+
+    The power `n` modulates the softness: higher `n` makes the function approach
+    one faster for large `x`, while `n < 1` makes the approach slower.
+
+    Args:
+        x (Array): A JAX array of non-negative scores (x >= 0).
+        n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+    Returns:
+        Array: The squashed scores in the range [0, 1).
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    return x_n / (1.0 + x_n)
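Likewise for softer_sigmoid, a hedged sketch of its behaviour: it is 0 at x = 0, 0.5 at x = 1, and approaches 1 for large scores, faster for larger n.

import jax.numpy as jnp
from nmn.nnx.squashers import softer_sigmoid

x = jnp.array([0.0, 1.0, 3.0, 100.0])
print(softer_sigmoid(x))         # [0.0, 0.5, 0.75, ~0.99]
print(softer_sigmoid(x, n=2.0))  # sharper: [0.0, 0.5, 0.9, ~0.9999]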
@@ -0,0 +1,38 @@
+import jax.numpy as jnp
+from jax import Array
+from typing import Optional
+
+def softermax(
+    x: Array,
+    n: float = 1.0,
+    epsilon: float = 1e-12,
+    axis: Optional[int] = -1,
+) -> Array:
+    """
+    Normalizes a set of non-negative scores using the Softermax function.
+
+    The Softermax function is defined as:
+    .. math::
+        \\text{softermax}_n(x_k, \\{x_i\\}) = \\frac{x_k^n}{\\epsilon + \\sum_i x_i^n}
+
+    The power `n` controls the sharpness of the distribution: `n=1` recovers
+    the original Softermax, while `n > 1` makes the distribution harder (more
+    peaked), and `0 < n < 1` makes it softer.
+
+    Args:
+        x (Array): A JAX array of non-negative scores.
+        n (float, optional): The power to raise each score to. Defaults to 1.0.
+        epsilon (float, optional): A small constant for numerical stability.
+            Defaults to 1e-12.
+        axis (Optional[int], optional): The axis to perform the sum over.
+            Defaults to -1.
+
+    Returns:
+        Array: The normalized scores.
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    sum_x_n = jnp.sum(x_n, axis=axis, keepdims=True)
+    return x_n / (epsilon + sum_x_n)
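Unlike softmax, softermax normalizes the raw (non-negative) scores directly, so for n = 1 the outputs stay proportional to the inputs; larger n sharpens the distribution. A hedged sketch:

import jax.numpy as jnp
from nmn.nnx.squashers import softermax

scores = jnp.array([[1.0, 2.0, 4.0]])
print(softermax(scores))         # [[~0.14, ~0.29, ~0.57]], proportional to the raw scores
print(softermax(scores, n=2.0))  # [[~0.05, ~0.19, ~0.76]], more peaked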
@@ -26,8 +26,10 @@ from flax.typing import (
   DotGeneralT,
 )
 
+from nmn.nnx.nmn import YatNMN
+from jax import Array
 
-
+from nmn.nnx.squashers import softermax
 def yat_attention_weights(
   query: Array,
   key: Array,
@@ -41,6 +43,7 @@ def yat_attention_weights(
   precision: PrecisionLike = None,
   module: Optional[Module] = None,
   epsilon: float = 1e-5,
+  use_softermax: bool = False,
 ):
   """Computes attention weights using YatNMN distance-based calculation."""
   query, key = promote_dtype((query, key), dtype=dtype)
@@ -85,7 +88,10 @@ def yat_attention_weights(
     attn_weights = jnp.where(mask, attn_weights, big_neg)
 
   # normalize the attention weights
-  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+  if use_softermax:
+    attn_weights = softermax(attn_weights).astype(dtype)
+  else:
+    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
 
   if module:
     module.sow(nnx.Intermediate, 'attention_weights', attn_weights)
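With use_softermax=True the attention scores are normalized by softermax instead of jax.nn.softmax. The practical difference on a row of non-negative scores, as an illustrative sketch (not taken from the package's tests):

import jax
import jax.numpy as jnp
from nmn.nnx.squashers import softermax

attn_scores = jnp.array([[0.5, 1.0, 2.0]])
print(jax.nn.softmax(attn_scores, axis=-1))  # exponentiates first: [[~0.14, ~0.23, ~0.63]]
print(softermax(attn_scores, axis=-1))       # ratio of raw scores:  [[~0.14, ~0.29, ~0.57]]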
@@ -119,6 +125,7 @@ def yat_attention(
   precision: PrecisionLike = None,
   module: Optional[Module] = None,
   epsilon: float = 1e-5,
+  use_softermax: bool = False,
 ):
   """Computes attention using YatNMN distance-based calculation."""
   query, key, value = promote_dtype((query, key, value), dtype=dtype)
@@ -146,6 +153,7 @@ def yat_attention(
     precision,
     module,
     epsilon,
+    use_softermax,
   )
 
   # return weighted sum over values for each query position
@@ -153,91 +161,6 @@
     '...hqk,...khd->...qhd', attn_weights, value, precision=precision
   )
 
-Array = jax.Array
-
-# Add YatNMN class implementation
-default_bias_init = initializers.zeros_init()
-default_alpha_init = initializers.ones_init()
-
-class YatNMN(Module):
-  """A linear transformation with custom distance-based computation."""
-
-  def __init__(
-    self,
-    in_features: int,
-    out_features: int,
-    *,
-    use_bias: bool = True,
-    use_alpha: bool = True,
-    dtype: Optional[Dtype] = None,
-    param_dtype: Dtype = jnp.float32,
-    precision: PrecisionLike = None,
-    kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = default_bias_init,
-    alpha_init: Initializer = default_alpha_init,
-    dot_general: DotGeneralT = lax.dot_general,
-    rngs: rnglib.Rngs,
-    epsilon: float = 1e-5,
-  ):
-
-    kernel_key = rngs.params()
-    self.kernel = nnx.Param(
-      kernel_init(kernel_key, (in_features, out_features), param_dtype)
-    )
-    self.bias: nnx.Param[jax.Array] | None
-    if use_bias:
-      bias_key = rngs.params()
-      self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
-    else:
-      self.bias = None
-
-    self.alpha: nnx.Param[jax.Array] | None
-    if use_alpha:
-      alpha_key = rngs.params()
-      self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
-    else:
-      self.alpha = None
-
-    self.in_features = in_features
-    self.out_features = out_features
-    self.use_bias = use_bias
-    self.use_alpha = use_alpha
-    self.dtype = dtype
-    self.param_dtype = param_dtype
-    self.precision = precision
-    self.kernel_init = kernel_init
-    self.bias_init = bias_init
-    self.dot_general = dot_general
-    self.epsilon = epsilon
-
-  def __call__(self, inputs: Array) -> Array:
-    """Applies YatNMN transformation to inputs."""
-    kernel = self.kernel.value
-    bias = self.bias.value if self.bias is not None else None
-    alpha = self.alpha.value if self.alpha is not None else None
-
-    y = self.dot_general(
-      inputs,
-      kernel,
-      (((inputs.ndim - 1,), (0,)), ((), ())),
-      precision=self.precision,
-    )
-
-    inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
-    kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
-    distances = inputs_squared_sum + kernel_squared_sum - 2 * y
-
-    # Element-wise operation
-    y = y ** 2 / (distances + self.epsilon)
-
-    if bias is not None:
-      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
-
-    if alpha is not None:
-      scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
-      y = y * scale
-
-    return y
 
 
 def dot_product_attention_weights(
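The YatNMN class is deleted from attention.py and imported from nmn.nnx.nmn instead (see the import added at the top of the file). For reference, the core computation the class performs can be sketched in plain jax.numpy, assuming a kernel of shape (in_features, out_features) and leaving out the optional bias and alpha scaling:

import jax.numpy as jnp

def yat_transform_sketch(inputs, kernel, epsilon=1e-5):
    # dot product between the inputs and each output neuron's weight vector
    y = inputs @ kernel
    # squared Euclidean distance via ||x - w||^2 = ||x||^2 + ||w||^2 - 2 * x.w
    distances = (
        jnp.sum(inputs**2, axis=-1, keepdims=True)
        + jnp.sum(kernel**2, axis=0, keepdims=True)
        - 2 * y
    )
    # distance-normalized squared similarity, as in YatNMN.__call__
    return y**2 / (distances + epsilon)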
@@ -435,6 +358,10 @@ class MultiHeadAttention(Module):
     attention_fn: Callable[..., Array] = yat_attention,
     decode: bool | None = None,
     normalize_qk: bool = False,
+    use_alpha: bool = True,
+    alpha_init: Initializer = initializers.ones_init(),
+    use_dropconnect: bool = False,
+    dropconnect_rate: float = 0.0,
     # Deprecated, will be removed.
     qkv_dot_general: DotGeneralT | None = None,
     out_dot_general: DotGeneralT | None = None,
@@ -442,6 +369,7 @@ class MultiHeadAttention(Module):
     out_dot_general_cls: Any = None,
     rngs: rnglib.Rngs,
     epsilon: float = 1e-5,
+    use_softermax: bool = False,
   ):
     self.num_heads = num_heads
     self.in_features = in_features
@@ -470,6 +398,11 @@ class MultiHeadAttention(Module):
     self.qkv_dot_general_cls = qkv_dot_general_cls
     self.out_dot_general_cls = out_dot_general_cls
     self.epsilon = epsilon
+    self.use_softermax = use_softermax
+    self.use_alpha = use_alpha
+    self.alpha_init = alpha_init
+    self.use_dropconnect = use_dropconnect
+    self.dropconnect_rate = dropconnect_rate
 
     if self.qkv_features % self.num_heads != 0:
       raise ValueError(
@@ -491,6 +424,10 @@ class MultiHeadAttention(Module):
       use_bias=self.use_bias,
       precision=self.precision,
       epsilon=self.epsilon,
+      use_alpha=self.use_alpha,
+      alpha_init=self.alpha_init,
+      use_dropconnect=self.use_dropconnect,
+      drop_rate=self.dropconnect_rate,
     )
 
     # project inputs_q to multi-headed q/k/v
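The new constructor flags thread through to the YatNMN q/k/v projections (use_alpha, alpha_init, use_dropconnect, drop_rate) and to the attention normalization (use_softermax). A hedged construction sketch; names not shown in these hunks (the module path nmn.nnx.attention, the exact positional signature) are assumptions based on the surrounding flax.nnx-style code, not confirmed by this diff:

from flax import nnx
from nmn.nnx.attention import MultiHeadAttention  # assumed module path

layer = MultiHeadAttention(
    num_heads=8,
    in_features=256,
    qkv_features=256,          # must be divisible by num_heads
    use_softermax=True,        # softermax instead of softmax over attention scores
    use_alpha=True,            # learnable alpha scale in the YatNMN projections
    use_dropconnect=True,      # DropConnect on the projection weights
    dropconnect_rate=0.1,
    rngs=nnx.Rngs(0),
)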
@@ -590,10 +527,23 @@ class MultiHeadAttention(Module):
         f'but module expects {self.in_features}.'
       )
 
+    is_deterministic: bool = False
+    if self.dropout_rate > 0.0 or (
+      self.use_dropconnect and self.dropconnect_rate > 0.0
+    ):
+      is_deterministic = first_from(
+        deterministic,
+        self.deterministic,
+        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
+          as either a __call__ argument, class attribute, or nnx.flag.""",
+      )
+    else:
+      is_deterministic = True
+
     # Apply YatNMN transformations and reshape to multi-head format
-    query = squash(self.query(inputs_q))
-    key = squash(self.key(inputs_k))
-    value = squash(self.value(inputs_v))
+    query = self.query(inputs_q, deterministic=is_deterministic)
+    key = self.key(inputs_k, deterministic=is_deterministic)
+    value = self.value(inputs_v, deterministic=is_deterministic)
 
     # Reshape from [batch..., length, qkv_features] to [batch..., length, num_heads, head_dim]
     query = query.reshape(query.shape[:-1] + (self.num_heads, self.head_dim))
@@ -660,26 +610,11 @@ class MultiHeadAttention(Module):
         ),
       )
 
-    if (
-      self.dropout_rate > 0.0
-    ):  # Require `deterministic` only if using dropout.
-      deterministic = first_from(
-        deterministic,
-        self.deterministic,
-        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
-          as either a __call__ argument, class attribute, or nnx.flag.""",
-      )
-      if not deterministic:
-        if rngs is None:
-          raise ValueError(
-            "'rngs' must be provided if 'dropout_rng' is not given."
-          )
-        dropout_rng = rngs.dropout()
-      else:
-        dropout_rng = None
-    else:
-      deterministic = True
-      dropout_rng = None
+    dropout_rng = None
+    if self.dropout_rate > 0.0 and not is_deterministic:
+      if rngs is None:
+        raise ValueError("'rngs' must be provided for dropout.")
+      dropout_rng = rngs.dropout()
 
     # apply attention with epsilon parameter for YatNMN
     x = self.attention_fn(
@@ -690,11 +625,12 @@ class MultiHeadAttention(Module):
       dropout_rng=dropout_rng,
      dropout_rate=self.dropout_rate,
       broadcast_dropout=self.broadcast_dropout,
-      deterministic=deterministic,
+      deterministic=is_deterministic,
       dtype=self.dtype,
       precision=self.precision,
       module=self if sow_weights else None,
       epsilon=self.epsilon,  # Pass epsilon to yat_attention
+      use_softermax=self.use_softermax,
     )
     # Reshape attention output back to original embedding dimension
     # from [batch..., length, num_heads, head_dim] to [batch..., length, qkv_features]