nmn-0.1.8-py3-none-any.whl → nmn-0.1.10-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
nmn/nnx/squashers/softermax.py CHANGED
@@ -1,7 +1,12 @@
+from functools import partial
+from typing import Optional
+
+import jax
 import jax.numpy as jnp
 from jax import Array
-from typing import Optional
 
+
+@partial(jax.jit, static_argnames=("n", "axis"))
 def softermax(
     x: Array,
     n: float = 1.0,
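For context: the wheel diff shows only softermax's new jit decorator and the start of its signature, not its body. Below is a minimal sketch of a power-based normalizer consistent with that signature; the body, and the epsilon and axis parameters, are assumptions, not the package's actual implementation. It also illustrates why n and axis are listed in static_argnames: changing either one changes the traced computation, so jit must treat them as compile-time constants.

from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp
from jax import Array


@partial(jax.jit, static_argnames=("n", "axis"))
def softermax(
    x: Array,
    n: float = 1.0,
    epsilon: float = 1e-12,
    axis: Optional[int] = -1,
) -> Array:
    # Hypothetical body: power-based normalization in place of exp().
    # Assumes non-negative scores, as the distance-based yat attention
    # produces; larger n concentrates mass on the largest score.
    powered = jnp.power(x, n)
    return powered / (jnp.sum(powered, axis=axis, keepdims=True) + epsilon)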
nmn/nnx/yatattention.py CHANGED
@@ -44,6 +44,8 @@ def yat_attention_weights(
   module: Optional[Module] = None,
   epsilon: float = 1e-5,
   use_softermax: bool = False,
+  power: float = 1.0,
+
 ):
   """Computes attention weights using YatNMN distance-based calculation."""
   query, key = promote_dtype((query, key), dtype=dtype)
@@ -89,7 +91,7 @@ def yat_attention_weights(
 
   # normalize the attention weights
   if use_softermax:
-    attn_weights = softermax(attn_weights).astype(dtype)
+    attn_weights = softermax(attn_weights, n=power).astype(dtype)
   else:
     attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
 
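This one-line change is the functional payoff of the new parameter: in 0.1.8 the configured power never reached softermax, which always ran at its default n=1.0. A quick illustration using the hypothetical sketch above (exact numbers depend on the real implementation):

import jax.numpy as jnp

scores = jnp.array([2.0, 1.0, 0.5])  # non-negative attention scores

softermax(scores)         # n=1.0: proportional weights, ~[0.57, 0.29, 0.14]
softermax(scores, n=2.0)  # n=2.0: sharper weights,      ~[0.76, 0.19, 0.05]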
@@ -126,6 +128,7 @@ def yat_attention(
   module: Optional[Module] = None,
   epsilon: float = 1e-5,
   use_softermax: bool = False,
+  power: float = 1.0,
 ):
   """Computes attention using YatNMN distance-based calculation."""
   query, key, value = promote_dtype((query, key, value), dtype=dtype)
@@ -154,6 +157,7 @@ def yat_attention(
     module,
     epsilon,
     use_softermax,
+    power,
   )
 
   # return weighted sum over values for each query position
@@ -370,6 +374,7 @@ class MultiHeadAttention(Module):
     rngs: rnglib.Rngs,
     epsilon: float = 1e-5,
     use_softermax: bool = False,
+    power: float = 1.0,
   ):
     self.num_heads = num_heads
     self.in_features = in_features
@@ -399,6 +404,7 @@ class MultiHeadAttention(Module):
     self.out_dot_general_cls = out_dot_general_cls
     self.epsilon = epsilon
     self.use_softermax = use_softermax
+    self.power = power
     self.use_alpha = use_alpha
     self.alpha_init = alpha_init
     self.use_dropconnect = use_dropconnect
@@ -627,6 +633,7 @@ class MultiHeadAttention(Module):
       module=self if sow_weights else None,
       epsilon=self.epsilon,  # Pass epsilon to yat_attention
       use_softermax=self.use_softermax,
+      power=self.power,
     )
     # Reshape attention output back to original embedding dimension
     # from [batch..., length, num_heads, head_dim] to [batch..., length, qkv_features]
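Taken together, the yatattention.py hunks simply thread the new knob from the public constructor down to the weight normalization. A hedged usage sketch follows; num_heads, in_features, and rngs appear in the diff, while everything else about the call convention is assumed to follow the flax.nnx MultiHeadAttention style this module mirrors:

import jax.numpy as jnp
from flax import nnx
from nmn.nnx.yatattention import MultiHeadAttention

layer = MultiHeadAttention(
    num_heads=4,
    in_features=64,
    use_softermax=True,   # normalize with softermax instead of softmax
    power=2.0,            # new in 0.1.10: forwarded to softermax as n
    rngs=nnx.Rngs(0),
)

x = jnp.ones((1, 16, 64))  # [batch, length, features]
y = layer(x)               # self-attention; weights normalized with n=2.0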
nmn-0.1.10.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nmn
-Version: 0.1.8
+Version: 0.1.10
 Summary: a neuron that matter
 Project-URL: Homepage, https://github.com/mlnomadpy/nmn
 Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
nmn-0.1.10.dist-info/RECORD CHANGED
@@ -3,7 +3,7 @@ nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
 nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
 nmn/nnx/TODO,sha256=U1WV51Eqij5igMjWLcbCjAZPONwIoPUQsMFKYHC6C8g,68
 nmn/nnx/nmn.py,sha256=tPNUtF8Lmv_B1TgMoVXfMQ9x0IPGKjSyAP6HnZ-YBsM,5651
-nmn/nnx/yatattention.py,sha256=4WyL9JW5wG05YURaLt76wA0zu2Bu2rMWtyFnVo9Gybo,24864
+nmn/nnx/yatattention.py,sha256=qEWiG_FIgr-TslYCbm2pcBi1myXJLC84nT6k1tMQcr4,25001
 nmn/nnx/yatconv.py,sha256=EOAAWfuv5QA-QTru-JyYKYNoGqxcklu7ph9a-CtmYsA,13123
 nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
 nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
@@ -11,11 +11,11 @@ nmn/nnx/loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nmn/nnx/squashers/__init__.py,sha256=zXYPa3yzqMXxkIPvNHiaV6pcZRDOdVrzaVdYVDGALTY,180
 nmn/nnx/squashers/soft_tanh.py,sha256=WSJkxD6L9WU1eqPwsK2AW4V6OJbw5pSWYjKwkiWtLdo,812
 nmn/nnx/squashers/softer_sigmoid.py,sha256=vE6IWorZdBb2cww6fskARnwzdjTcWB2kKohuaJWVGNs,845
-nmn/nnx/squashers/softermax.py,sha256=NfxEDbogLUysyTvtVCTpDt27PplYvKRQLTZbYCL-Wfg,1226
+nmn/nnx/squashers/softermax.py,sha256=ggg0mHMFyk7b5xs31o-inNvWDzEvghD6YO3mtPlnkW4,1318
 nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
 nmn/torch/conv.py,sha256=g5YxStk1p85WkvfecqbzRZaWaAJahOSArpMcqxWAWKc,83413
 nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
-nmn-0.1.8.dist-info/METADATA,sha256=8DvDHl3Tkp3HSge5rYHZdKaKTso3j-H1lDgF3c4owrI,8800
-nmn-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-nmn-0.1.8.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
-nmn-0.1.8.dist-info/RECORD,,
+nmn-0.1.10.dist-info/METADATA,sha256=o-wLjeO-n2h56-cvw-AqrRiio5UFaerm58w03XkdHQY,8801
+nmn-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+nmn-0.1.10.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
+nmn-0.1.10.dist-info/RECORD,,