nmn 0.1.6.tar.gz → 0.1.7.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nmn
-Version: 0.1.6
+Version: 0.1.7
 Summary: a neuron that matter
 Project-URL: Homepage, https://github.com/mlnomadpy/nmn
 Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "nmn"
-version = "0.1.6"
+version = "0.1.7"
 authors = [
   { name="Taha Bouhsine", email="yat@mlnomads.com" },
 ]
@@ -29,6 +29,7 @@ from flax.typing import (
 from nmn.nnx.nmn import YatNMN
 from jax import Array
 
+from nmn.nnx.squashers import softermax
 def yat_attention_weights(
     query: Array,
     key: Array,
@@ -42,6 +43,7 @@ def yat_attention_weights(
     precision: PrecisionLike = None,
     module: Optional[Module] = None,
     epsilon: float = 1e-5,
+    use_softermax: bool = False,
 ):
   """Computes attention weights using YatNMN distance-based calculation."""
   query, key = promote_dtype((query, key), dtype=dtype)
@@ -86,7 +88,10 @@ def yat_attention_weights(
     attn_weights = jnp.where(mask, attn_weights, big_neg)
 
   # normalize the attention weights
-  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+  if use_softermax:
+    attn_weights = softermax(attn_weights).astype(dtype)
+  else:
+    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
 
   if module:
     module.sow(nnx.Intermediate, 'attention_weights', attn_weights)
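
The new use_softermax flag switches the final normalization from jax.nn.softmax to the softermax squasher imported from nmn.nnx.squashers. The diff does not include softermax's definition, so the sketch below is only a hypothetical stand-in under that caveat: it illustrates one common "softer" normalization (a stability-shifted power ratio rather than an exponential), and the names softermax_sketch, n, and epsilon are illustrative, not the package's API.

# Hypothetical stand-in only: the real softermax lives in nmn.nnx.squashers and
# its exact formula is not shown in this diff.
import jax.numpy as jnp
from jax import Array

def softermax_sketch(x: Array, n: float = 1.0, epsilon: float = 1e-12, axis: int = -1) -> Array:
  """Power-based normalization over `axis` (illustrative, not the package's definition)."""
  # Shift so the smallest score maps to zero and all inputs are non-negative before the power.
  shifted = x - jnp.min(x, axis=axis, keepdims=True)
  powered = shifted ** n
  # epsilon keeps the division finite when every score in the slice is identical.
  return powered / (jnp.sum(powered, axis=axis, keepdims=True) + epsilon)

scores = jnp.array([[1.0, 2.0, 4.0]])
print(softermax_sketch(scores))  # weights are non-negative and sum to ~1 along the last axis
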
@@ -120,6 +125,7 @@ def yat_attention(
     precision: PrecisionLike = None,
     module: Optional[Module] = None,
     epsilon: float = 1e-5,
+    use_softermax: bool = False,
 ):
   """Computes attention using YatNMN distance-based calculation."""
   query, key, value = promote_dtype((query, key, value), dtype=dtype)
@@ -147,6 +153,7 @@ def yat_attention(
     precision,
     module,
     epsilon,
+    use_softermax,
   )
 
   # return weighted sum over values for each query position
@@ -362,6 +369,7 @@ class MultiHeadAttention(Module):
     out_dot_general_cls: Any = None,
     rngs: rnglib.Rngs,
     epsilon: float = 1e-5,
+    use_softermax: bool = False,
   ):
     self.num_heads = num_heads
     self.in_features = in_features
@@ -390,6 +398,7 @@ class MultiHeadAttention(Module):
     self.qkv_dot_general_cls = qkv_dot_general_cls
     self.out_dot_general_cls = out_dot_general_cls
     self.epsilon = epsilon
+    self.use_softermax = use_softermax
     self.use_alpha = use_alpha
     self.alpha_init = alpha_init
     self.use_dropconnect = use_dropconnect
@@ -621,6 +630,7 @@ class MultiHeadAttention(Module):
       precision=self.precision,
       module=self if sow_weights else None,
       epsilon=self.epsilon,  # Pass epsilon to yat_attention
+      use_softermax=self.use_softermax,
     )
     # Reshape attention output back to original embedding dimension
     # from [batch..., length, num_heads, head_dim] to [batch..., length, qkv_features]
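
In MultiHeadAttention the flag is stored on the module and forwarded into yat_attention, so callers can opt in per layer. A minimal usage sketch, assuming the class is importable from nmn.nnx.yatattention and that the remaining constructor and call arguments mirror flax.nnx.MultiHeadAttention; the import path, decode argument, and shapes below are assumptions rather than facts from this diff.

import jax.numpy as jnp
from flax import nnx
from nmn.nnx.yatattention import MultiHeadAttention  # import path assumed, not shown in this diff

# Build the layer with the flag introduced in 0.1.7.
attn = MultiHeadAttention(
  num_heads=4,
  in_features=64,
  rngs=nnx.Rngs(0),
  use_softermax=True,  # normalize attention scores with softermax instead of jax.nn.softmax
)

x = jnp.ones((2, 16, 64))   # [batch, length, features]
y = attn(x, decode=False)   # self-attention; decode handling assumed to follow flax.nnx
print(y.shape)              # expected: (2, 16, 64)
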
All other files: no changes.