nmn-0.1.7-py3-none-any.whl → nmn-0.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nmn/nnx/TODO ADDED
@@ -0,0 +1,2 @@
+ - add support to masked kernels
+ - explain attention [directed graph]
nmn/nnx/yatattention.py CHANGED
@@ -44,6 +44,8 @@ def yat_attention_weights(
   module: Optional[Module] = None,
   epsilon: float = 1e-5,
   use_softermax: bool = False,
+  power: float = 1.0,
+
 ):
   """Computes attention weights using YatNMN distance-based calculation."""
   query, key = promote_dtype((query, key), dtype=dtype)
@@ -89,7 +91,7 @@ def yat_attention_weights(
 
   # normalize the attention weights
   if use_softermax:
-    attn_weights = softermax(attn_weights).astype(dtype)
+    attn_weights = softermax(attn_weights, n=power).astype(dtype)
   else:
     attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
 
@@ -126,6 +128,7 @@ def yat_attention(
   module: Optional[Module] = None,
   epsilon: float = 1e-5,
   use_softermax: bool = False,
+  power: float = 1.0,
 ):
   """Computes attention using YatNMN distance-based calculation."""
   query, key, value = promote_dtype((query, key, value), dtype=dtype)
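The functional core of these hunks is that `softermax` now receives an explicit exponent, `softermax(attn_weights, n=power)`, with `power` threaded through `yat_attention_weights` and `yat_attention`. The package's `softermax` implementation is not part of this diff; purely as a hypothetical illustration of how an exponent `n` can control the sharpness of a normalization over non-negative attention scores (an assumption about the scores, not nmn's actual code), a power-based normalization could look like this:

import jax.numpy as jnp

def power_normalize(scores, n=1.0, axis=-1, eps=1e-9):
    # Hypothetical sketch, not nmn's softermax: raise non-negative scores to the
    # n-th power and renormalize; larger n concentrates weight on the largest score.
    powered = jnp.power(scores, n)
    return powered / (jnp.sum(powered, axis=axis, keepdims=True) + eps)

# n=1.0 keeps the relative weights; n=2.0 sharpens the distribution.
weights = power_normalize(jnp.array([1.0, 2.0, 3.0]), n=2.0)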
@@ -154,6 +157,7 @@ def yat_attention(
     module,
     epsilon,
     use_softermax,
+    power,
   )
 
   # return weighted sum over values for each query position
@@ -370,6 +374,7 @@ class MultiHeadAttention(Module):
     rngs: rnglib.Rngs,
     epsilon: float = 1e-5,
     use_softermax: bool = False,
+    power: float = 1.0,
   ):
     self.num_heads = num_heads
     self.in_features = in_features
@@ -399,6 +404,7 @@ class MultiHeadAttention(Module):
     self.out_dot_general_cls = out_dot_general_cls
     self.epsilon = epsilon
     self.use_softermax = use_softermax
+    self.power = power
     self.use_alpha = use_alpha
     self.alpha_init = alpha_init
     self.use_dropconnect = use_dropconnect
@@ -460,8 +466,6 @@ class MultiHeadAttention(Module):
     self.key_ln = None
 
     # Remove the output layer - no more self.out
-    self.rngs = rngs if dropout_rate > 0.0 else None
-
     self.cached_key: nnx.Cache[Array] | None = None
     self.cached_value: nnx.Cache[Array] | None = None
     self.cache_index: nnx.Cache[Array] | None = None
@@ -507,8 +511,6 @@ class MultiHeadAttention(Module):
     Returns:
       output of shape `[batch_sizes..., length, features]`.
     """
-    if rngs is None:
-      rngs = self.rngs
 
     if inputs_k is None:
       if inputs_v is not None:
@@ -631,6 +633,7 @@ class MultiHeadAttention(Module):
       module=self if sow_weights else None,
       epsilon=self.epsilon, # Pass epsilon to yat_attention
       use_softermax=self.use_softermax,
+      power= self.power,
     )
     # Reshape attention output back to original embedding dimension
     # from [batch..., length, num_heads, head_dim] to [batch..., length, qkv_features]
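Taken together, 0.1.9 threads a `power` argument from `MultiHeadAttention.__init__` down to the `softermax` call and removes the stored `self.rngs` fallback, so dropout RNGs come from the `rngs` passed at call time. A hypothetical usage sketch follows; only `num_heads`, `in_features`, `rngs`, `epsilon`, `use_softermax`, and `power` are visible in this diff, and the remaining constructor defaults and the exact `__call__` signature are assumptions:

import jax.numpy as jnp
from flax import nnx
from nmn.nnx.yatattention import MultiHeadAttention

layer = MultiHeadAttention(
    num_heads=8,
    in_features=64,
    rngs=nnx.Rngs(0),
    use_softermax=True,  # route normalization through softermax
    power=2.0,           # new in 0.1.9: forwarded as softermax(attn_weights, n=power)
)

x = jnp.ones((2, 16, 64))               # [batch, length, in_features]
y = layer(x, rngs=nnx.Rngs(dropout=1))  # rngs supplied per call; the self.rngs fallback was removed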