nmn-0.1.5-py3-none-any.whl → nmn-0.1.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nmn/nnx/loss/__init__.py +0 -0
- nmn/nnx/squashers/__init__.py +9 -0
- nmn/nnx/squashers/soft_tanh.py +29 -0
- nmn/nnx/squashers/softer_sigmoid.py +29 -0
- nmn/nnx/squashers/softermax.py +38 -0
- nmn/nnx/yatattention.py +36 -110
- {nmn-0.1.5.dist-info → nmn-0.1.6.dist-info}/METADATA +1 -1
- nmn-0.1.6.dist-info/RECORD +19 -0
- nmn-0.1.5.dist-info/RECORD +0 -14
- {nmn-0.1.5.dist-info → nmn-0.1.6.dist-info}/WHEEL +0 -0
- {nmn-0.1.5.dist-info → nmn-0.1.6.dist-info}/licenses/LICENSE +0 -0
nmn/nnx/loss/__init__.py
ADDED
File without changes
nmn/nnx/squashers/soft_tanh.py
ADDED
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+from jax import Array
+
+def soft_tanh(
+    x: Array,
+    n: float = 1.0,
+) -> Array:
+    """
+    Maps a non-negative score to the range [-1, 1) using the soft-tanh function.
+
+    The soft-tanh function is defined as:
+    .. math::
+        \\text{soft-tanh}_n(x) = \\frac{x^n - 1}{1 + x^n}
+
+    The power `n` controls the transition sharpness: higher `n` makes the
+    function approach 1 more quickly for large `x`.
+
+    Args:
+        x (Array): A JAX array of non-negative scores (x >= 0).
+        n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+    Returns:
+        Array: The mapped scores in the range [-1, 1).
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    return (x_n - 1.0) / (1.0 + x_n)
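For reference, a minimal usage sketch of the new squasher (hypothetical example; it assumes soft_tanh is re-exported from nmn.nnx.squashers, as the RECORD entries below suggest):

import jax.numpy as jnp
from nmn.nnx.squashers import soft_tanh  # assumed re-export

scores = jnp.array([0.0, 0.5, 1.0, 4.0])
print(soft_tanh(scores))         # maps [0, inf) into [-1, 1); soft_tanh(1) == 0
print(soft_tanh(scores, n=2.0))  # larger n sharpens the transition around x = 1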
nmn/nnx/squashers/softer_sigmoid.py
ADDED
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+from jax import Array
+
+def softer_sigmoid(
+    x: Array,
+    n: float = 1.0,
+) -> Array:
+    """
+    Squashes a non-negative score into the range [0, 1) using the soft-sigmoid function.
+
+    The soft-sigmoid function is defined as:
+    .. math::
+        \\text{soft-sigmoid}_n(x) = \\frac{x^n}{1 + x^n}
+
+    The power `n` modulates the softness: higher `n` makes the function approach
+    1 faster for large `x`, while `n < 1` makes the transition more gradual.
+
+    Args:
+        x (Array): A JAX array of non-negative scores (x >= 0).
+        n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+    Returns:
+        Array: The squashed scores in the range [0, 1).
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    return x_n / (1.0 + x_n)
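A quick sanity check of the squashing behaviour (hypothetical usage, same assumed import path):

import jax.numpy as jnp
from nmn.nnx.squashers import softer_sigmoid  # assumed re-export

scores = jnp.array([0.0, 1.0, 9.0])
print(softer_sigmoid(scores))         # [0.0, 0.5, 0.9]; values stay in [0, 1)
print(softer_sigmoid(scores, n=2.0))  # larger n pushes large scores toward 1 faster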
nmn/nnx/squashers/softermax.py
ADDED
@@ -0,0 +1,38 @@
+import jax.numpy as jnp
+from jax import Array
+from typing import Optional
+
+def softermax(
+    x: Array,
+    n: float = 1.0,
+    epsilon: float = 1e-12,
+    axis: Optional[int] = -1,
+) -> Array:
+    """
+    Normalizes a set of non-negative scores using the Softermax function.
+
+    The Softermax function is defined as:
+    .. math::
+        \\text{softermax}_n(x_k, \\{x_i\\}) = \\frac{x_k^n}{\\epsilon + \\sum_i x_i^n}
+
+    The power `n` controls the sharpness of the distribution: `n=1` recovers
+    the original Softermax, while `n > 1` makes the distribution harder (more
+    peaked), and `0 < n < 1` makes it softer.
+
+    Args:
+        x (Array): A JAX array of non-negative scores.
+        n (float, optional): The power to raise each score to. Defaults to 1.0.
+        epsilon (float, optional): A small constant for numerical stability.
+            Defaults to 1e-12.
+        axis (Optional[int], optional): The axis to perform the sum over.
+            Defaults to -1.
+
+    Returns:
+        Array: The normalized scores.
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    sum_x_n = jnp.sum(x_n, axis=axis, keepdims=True)
+    return x_n / (epsilon + sum_x_n)
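And a short sketch of how softermax turns raw non-negative scores into a distribution (hypothetical usage, same assumed import path):

import jax.numpy as jnp
from nmn.nnx.squashers import softermax  # assumed re-export

scores = jnp.array([[1.0, 2.0, 4.0]])
probs = softermax(scores)         # ~[0.143, 0.286, 0.571], each row sums to ~1
sharp = softermax(scores, n=2.0)  # ~[0.048, 0.190, 0.762], more peaked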
nmn/nnx/yatattention.py
CHANGED
@@ -26,7 +26,8 @@ from flax.typing import (
   DotGeneralT,
 )
 
-
+from nmn.nnx.nmn import YatNMN
+from jax import Array
 
 def yat_attention_weights(
   query: Array,
@@ -153,91 +154,6 @@ def yat_attention(
     '...hqk,...khd->...qhd', attn_weights, value, precision=precision
   )
 
-Array = jax.Array
-
-# Add YatNMN class implementation
-default_bias_init = initializers.zeros_init()
-default_alpha_init = initializers.ones_init()
-
-class YatNMN(Module):
-  """A linear transformation with custom distance-based computation."""
-
-  def __init__(
-    self,
-    in_features: int,
-    out_features: int,
-    *,
-    use_bias: bool = True,
-    use_alpha: bool = True,
-    dtype: Optional[Dtype] = None,
-    param_dtype: Dtype = jnp.float32,
-    precision: PrecisionLike = None,
-    kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = default_bias_init,
-    alpha_init: Initializer = default_alpha_init,
-    dot_general: DotGeneralT = lax.dot_general,
-    rngs: rnglib.Rngs,
-    epsilon: float = 1e-5,
-  ):
-
-    kernel_key = rngs.params()
-    self.kernel = nnx.Param(
-      kernel_init(kernel_key, (in_features, out_features), param_dtype)
-    )
-    self.bias: nnx.Param[jax.Array] | None
-    if use_bias:
-      bias_key = rngs.params()
-      self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
-    else:
-      self.bias = None
-
-    self.alpha: nnx.Param[jax.Array] | None
-    if use_alpha:
-      alpha_key = rngs.params()
-      self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
-    else:
-      self.alpha = None
-
-    self.in_features = in_features
-    self.out_features = out_features
-    self.use_bias = use_bias
-    self.use_alpha = use_alpha
-    self.dtype = dtype
-    self.param_dtype = param_dtype
-    self.precision = precision
-    self.kernel_init = kernel_init
-    self.bias_init = bias_init
-    self.dot_general = dot_general
-    self.epsilon = epsilon
-
-  def __call__(self, inputs: Array) -> Array:
-    """Applies YatNMN transformation to inputs."""
-    kernel = self.kernel.value
-    bias = self.bias.value if self.bias is not None else None
-    alpha = self.alpha.value if self.alpha is not None else None
-
-    y = self.dot_general(
-      inputs,
-      kernel,
-      (((inputs.ndim - 1,), (0,)), ((), ())),
-      precision=self.precision,
-    )
-
-    inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
-    kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
-    distances = inputs_squared_sum + kernel_squared_sum - 2 * y
-
-    # Element-wise operation
-    y = y ** 2 / (distances + self.epsilon)
-
-    if bias is not None:
-      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
-
-    if alpha is not None:
-      scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
-      y = y * scale
-
-    return y
 
 
 def dot_product_attention_weights(
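The block removed above is the inline YatNMN definition; 0.1.6 imports it from nmn.nnx.nmn instead (see the new import in the first hunk). For readers skimming the diff, its core step expands the squared Euclidean distance between each input row and each kernel column; a standalone illustration of that computation, outside any module:

import jax.numpy as jnp

inputs = jnp.ones((2, 3))                # (batch, in_features)
kernel = jnp.arange(12.0).reshape(3, 4)  # (in_features, out_features)

y = inputs @ kernel                      # dot products, as dot_general produces above
distances = (
    jnp.sum(inputs**2, axis=-1, keepdims=True)
    + jnp.sum(kernel**2, axis=0, keepdims=True)
    - 2 * y
)                                        # ||a - w||^2 = ||a||^2 + ||w||^2 - 2 a.w
out = y**2 / (distances + 1e-5)          # the element-wise YatNMN nonlinearity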
@@ -435,6 +351,10 @@ class MultiHeadAttention(Module):
     attention_fn: Callable[..., Array] = yat_attention,
     decode: bool | None = None,
     normalize_qk: bool = False,
+    use_alpha: bool = True,
+    alpha_init: Initializer = initializers.ones_init(),
+    use_dropconnect: bool = False,
+    dropconnect_rate: float = 0.0,
     # Deprecated, will be removed.
     qkv_dot_general: DotGeneralT | None = None,
     out_dot_general: DotGeneralT | None = None,
@@ -470,6 +390,10 @@ class MultiHeadAttention(Module):
     self.qkv_dot_general_cls = qkv_dot_general_cls
     self.out_dot_general_cls = out_dot_general_cls
     self.epsilon = epsilon
+    self.use_alpha = use_alpha
+    self.alpha_init = alpha_init
+    self.use_dropconnect = use_dropconnect
+    self.dropconnect_rate = dropconnect_rate
 
     if self.qkv_features % self.num_heads != 0:
       raise ValueError(
@@ -491,6 +415,10 @@ class MultiHeadAttention(Module):
       use_bias=self.use_bias,
       precision=self.precision,
       epsilon=self.epsilon,
+      use_alpha=self.use_alpha,
+      alpha_init=self.alpha_init,
+      use_dropconnect=self.use_dropconnect,
+      drop_rate=self.dropconnect_rate,
     )
 
     # project inputs_q to multi-headed q/k/v
@@ -590,10 +518,23 @@ class MultiHeadAttention(Module):
         f'but module expects {self.in_features}.'
       )
 
+    is_deterministic: bool = False
+    if self.dropout_rate > 0.0 or (
+      self.use_dropconnect and self.dropconnect_rate > 0.0
+    ):
+      is_deterministic = first_from(
+        deterministic,
+        self.deterministic,
+        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
+          as either a __call__ argument, class attribute, or nnx.flag.""",
+      )
+    else:
+      is_deterministic = True
+
     # Apply YatNMN transformations and reshape to multi-head format
-    query =
-    key =
-    value =
+    query = self.query(inputs_q, deterministic=is_deterministic)
+    key = self.key(inputs_k, deterministic=is_deterministic)
+    value = self.value(inputs_v, deterministic=is_deterministic)
 
     # Reshape from [batch..., length, qkv_features] to [batch..., length, num_heads, head_dim]
     query = query.reshape(query.shape[:-1] + (self.num_heads, self.head_dim))
@@ -660,26 +601,11 @@ class MultiHeadAttention(Module):
       ),
     )
 
-
-
-
-
-
-        self.deterministic,
-        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
-          as either a __call__ argument, class attribute, or nnx.flag.""",
-      )
-      if not deterministic:
-        if rngs is None:
-          raise ValueError(
-            "'rngs' must be provided if 'dropout_rng' is not given."
-          )
-        dropout_rng = rngs.dropout()
-      else:
-        dropout_rng = None
-    else:
-      deterministic = True
-      dropout_rng = None
+    dropout_rng = None
+    if self.dropout_rate > 0.0 and not is_deterministic:
+      if rngs is None:
+        raise ValueError("'rngs' must be provided for dropout.")
+      dropout_rng = rngs.dropout()
 
     # apply attention with epsilon parameter for YatNMN
     x = self.attention_fn(
@@ -690,7 +616,7 @@ class MultiHeadAttention(Module):
       dropout_rng=dropout_rng,
       dropout_rate=self.dropout_rate,
       broadcast_dropout=self.broadcast_dropout,
-      deterministic=
+      deterministic=is_deterministic,
       dtype=self.dtype,
       precision=self.precision,
       module=self if sow_weights else None,
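Net effect of the changes above: MultiHeadAttention can now toggle alpha scaling and DropConnect on its internal YatNMN projections, and the deterministic flag is resolved once for both dropout and DropConnect. A hypothetical construction sketch (argument names other than the four new flags are assumed to mirror flax.nnx.MultiHeadAttention and may differ in the actual package):

import jax.numpy as jnp
from flax import nnx
from nmn.nnx.yatattention import MultiHeadAttention

layer = MultiHeadAttention(
  num_heads=4,            # assumed name, per self.num_heads in the diff
  in_features=64,         # assumed name, per self.in_features in the diff
  qkv_features=64,        # assumed name, per self.qkv_features in the diff
  use_alpha=True,         # new in 0.1.6
  use_dropconnect=True,   # new in 0.1.6
  dropconnect_rate=0.1,   # new in 0.1.6
  rngs=nnx.Rngs(0),
)
x = jnp.ones((1, 8, 64))
y = layer(x, deterministic=True)  # `deterministic` is a __call__ argument per the hunks above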
nmn-0.1.6.dist-info/RECORD
ADDED
@@ -0,0 +1,19 @@
+nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
+nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
+nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
+nmn/nnx/nmn.py,sha256=tPNUtF8Lmv_B1TgMoVXfMQ9x0IPGKjSyAP6HnZ-YBsM,5651
+nmn/nnx/yatattention.py,sha256=i6XfCGHISyb2P6KrgYFnhhdzqSTWAyshFhy1XEeuEWc,24642
+nmn/nnx/yatconv.py,sha256=EOAAWfuv5QA-QTru-JyYKYNoGqxcklu7ph9a-CtmYsA,13123
+nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
+nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
+nmn/nnx/loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nmn/nnx/squashers/__init__.py,sha256=zXYPa3yzqMXxkIPvNHiaV6pcZRDOdVrzaVdYVDGALTY,180
+nmn/nnx/squashers/soft_tanh.py,sha256=WSJkxD6L9WU1eqPwsK2AW4V6OJbw5pSWYjKwkiWtLdo,812
+nmn/nnx/squashers/softer_sigmoid.py,sha256=vE6IWorZdBb2cww6fskARnwzdjTcWB2kKohuaJWVGNs,845
+nmn/nnx/squashers/softermax.py,sha256=NfxEDbogLUysyTvtVCTpDt27PplYvKRQLTZbYCL-Wfg,1226
+nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
+nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
+nmn-0.1.6.dist-info/METADATA,sha256=Y9MByC16wz1MGYVZRmZA0wJQATB6Kj6w6TOL5lPzl0Q,8800
+nmn-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+nmn-0.1.6.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
+nmn-0.1.6.dist-info/RECORD,,
nmn-0.1.5.dist-info/RECORD
DELETED
@@ -1,14 +0,0 @@
-nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
-nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
-nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
-nmn/nnx/nmn.py,sha256=tPNUtF8Lmv_B1TgMoVXfMQ9x0IPGKjSyAP6HnZ-YBsM,5651
-nmn/nnx/yatattention.py,sha256=chjtUKJtaR7ROPnNqkicbvMs7hzZKE0fIo_8cTNiju8,26601
-nmn/nnx/yatconv.py,sha256=EOAAWfuv5QA-QTru-JyYKYNoGqxcklu7ph9a-CtmYsA,13123
-nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
-nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
-nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
-nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
-nmn-0.1.5.dist-info/METADATA,sha256=7gvXle6Hgdgyj_tJk1DGdkOh03BOsfSks-ZHPOIEwHQ,8800
-nmn-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-nmn-0.1.5.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
-nmn-0.1.5.dist-info/RECORD,,
{nmn-0.1.5.dist-info → nmn-0.1.6.dist-info}/WHEEL
File without changes

{nmn-0.1.5.dist-info → nmn-0.1.6.dist-info}/licenses/LICENSE
File without changes