nmn 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
nmn/nnx/TODO ADDED
@@ -0,0 +1,2 @@
+ - add support to masked kernels
+ - explain attention [directed graph]
nmn/nnx/yatattention.py CHANGED
@@ -29,6 +29,7 @@ from flax.typing import (
  from nmn.nnx.nmn import YatNMN
  from jax import Array

+ from nmn.nnx.squashers import softermax
  def yat_attention_weights(
      query: Array,
      key: Array,
@@ -42,6 +43,7 @@ def yat_attention_weights(
      precision: PrecisionLike = None,
      module: Optional[Module] = None,
      epsilon: float = 1e-5,
+     use_softermax: bool = False,
  ):
    """Computes attention weights using YatNMN distance-based calculation."""
    query, key = promote_dtype((query, key), dtype=dtype)
@@ -86,7 +88,10 @@ def yat_attention_weights(
    attn_weights = jnp.where(mask, attn_weights, big_neg)

    # normalize the attention weights
-   attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+   if use_softermax:
+     attn_weights = softermax(attn_weights).astype(dtype)
+   else:
+     attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

    if module:
      module.sow(nnx.Intermediate, 'attention_weights', attn_weights)
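
The hunk above is the heart of the change between 0.1.6 and 0.1.8: when use_softermax is set, the attention logits are normalized with softermax from nmn.nnx.squashers instead of jax.nn.softmax. The softermax implementation itself is not included in this diff, so the snippet below is only a minimal sketch of one common "softer" normalization (the denominator-plus-one variant), written to match the single-argument call used above; the real function in nmn.nnx.squashers may compute something different, for example it may take a temperature or power parameter.

    # Illustrative sketch only; NOT the actual nmn.nnx.squashers.softermax.
    # Assumes the "+1 in the denominator" variant, which lets a row of attention
    # weights sum to less than 1 when no key matches the query strongly.
    import jax
    import jax.numpy as jnp

    def softermax_sketch(logits: jax.Array, axis: int = -1) -> jax.Array:
        x_max = jnp.max(logits, axis=axis, keepdims=True)    # shift for numerical stability
        unnorm = jnp.exp(logits - x_max)                      # shifted exponentials
        denom = jnp.sum(unnorm, axis=axis, keepdims=True) + jnp.exp(-x_max)
        return unnorm / denom                                 # each row sums to <= 1

    weights = softermax_sketch(jnp.array([[2.0, 1.0, -1.0]]))  # drop-in for jax.nn.softmax
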
@@ -120,6 +125,7 @@ def yat_attention(
      precision: PrecisionLike = None,
      module: Optional[Module] = None,
      epsilon: float = 1e-5,
+     use_softermax: bool = False,
  ):
    """Computes attention using YatNMN distance-based calculation."""
    query, key, value = promote_dtype((query, key, value), dtype=dtype)
@@ -147,6 +153,7 @@ def yat_attention(
      precision,
      module,
      epsilon,
+     use_softermax,
    )

    # return weighted sum over values for each query position
@@ -362,6 +369,7 @@ class MultiHeadAttention(Module):
      out_dot_general_cls: Any = None,
      rngs: rnglib.Rngs,
      epsilon: float = 1e-5,
+     use_softermax: bool = False,
  ):
    self.num_heads = num_heads
    self.in_features = in_features
@@ -390,6 +398,7 @@ class MultiHeadAttention(Module):
    self.qkv_dot_general_cls = qkv_dot_general_cls
    self.out_dot_general_cls = out_dot_general_cls
    self.epsilon = epsilon
+   self.use_softermax = use_softermax
    self.use_alpha = use_alpha
    self.alpha_init = alpha_init
    self.use_dropconnect = use_dropconnect
@@ -451,8 +460,6 @@ class MultiHeadAttention(Module):
    self.key_ln = None

    # Remove the output layer - no more self.out
-   self.rngs = rngs if dropout_rate > 0.0 else None
-
    self.cached_key: nnx.Cache[Array] | None = None
    self.cached_value: nnx.Cache[Array] | None = None
    self.cache_index: nnx.Cache[Array] | None = None
@@ -498,8 +505,6 @@ class MultiHeadAttention(Module):
    Returns:
      output of shape `[batch_sizes..., length, features]`.
    """
-   if rngs is None:
-     rngs = self.rngs

    if inputs_k is None:
      if inputs_v is not None:
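
Together with the constructor hunk above, these removals mean the module no longer caches an rngs fallback for dropout; the rngs argument of __call__, whose rngs-is-None fallback is deleted here, presumably must now be supplied explicitly whenever dropout is active. A hypothetical call, where mha stands for an existing MultiHeadAttention instance built with a non-zero dropout rate and x for a [batch, length, features] array:

    from flax import nnx

    # mha and x are placeholders; the point is only the explicit per-call rngs,
    # which is an assumption based on the removed fallback above.
    out = mha(x, rngs=nnx.Rngs(dropout=0))
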
@@ -621,6 +626,7 @@ class MultiHeadAttention(Module):
      precision=self.precision,
      module=self if sow_weights else None,
      epsilon=self.epsilon, # Pass epsilon to yat_attention
+     use_softermax=self.use_softermax,
    )
    # Reshape attention output back to original embedding dimension
    # from [batch..., length, num_heads, head_dim] to [batch..., length, qkv_features]
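
For completeness, a hypothetical end-to-end use of the flag added in this release range. Only num_heads, in_features, epsilon, use_softermax, and rngs are visible in this diff; the sketch assumes the remaining constructor and call arguments have workable defaults, which is not confirmed here.

    import jax.numpy as jnp
    from flax import nnx
    from nmn.nnx.yatattention import MultiHeadAttention

    # Sketch under the assumptions above; not a verified 0.1.8 API listing.
    layer = MultiHeadAttention(
        num_heads=4,
        in_features=64,
        use_softermax=True,   # normalize attention with softermax instead of softmax
        rngs=nnx.Rngs(0),
    )
    x = jnp.ones((2, 16, 64))  # [batch, length, in_features]
    y = layer(x)               # self-attention; inputs_k / inputs_v default to inputs_q
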