nmn 0.1.5__tar.gz → 0.1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nmn-0.1.5 → nmn-0.1.7}/PKG-INFO +1 -1
- {nmn-0.1.5 → nmn-0.1.7}/pyproject.toml +1 -1
- nmn-0.1.7/src/nmn/nnx/loss/__init__.py +0 -0
- nmn-0.1.7/src/nmn/nnx/squashers/__init__.py +9 -0
- nmn-0.1.7/src/nmn/nnx/squashers/soft_tanh.py +29 -0
- nmn-0.1.7/src/nmn/nnx/squashers/softer_sigmoid.py +29 -0
- nmn-0.1.7/src/nmn/nnx/squashers/softermax.py +38 -0
- {nmn-0.1.5 → nmn-0.1.7}/src/nmn/nnx/yatattention.py +47 -111
- {nmn-0.1.5 → nmn-0.1.7}/.github/workflows/publish.yml +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/.gitignore +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/LICENSE +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/MANIFEST.in +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/PUBLISH.md +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/README.md +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/hatch.toml +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/src/nmn/__init__.py +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/src/nmn/keras/nmn.py +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/src/nmn/linen/nmn.py +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/src/nmn/nnx/examples/language/mingpt.py +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/src/nmn/nnx/examples/vision/cnn_cifar.py +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/src/nmn/nnx/nmn.py +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/src/nmn/nnx/yatconv.py +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/src/nmn/tf/nmn.py +0 -0
- {nmn-0.1.5 → nmn-0.1.7}/src/nmn/torch/nmn.py +0 -0
nmn-0.1.7/src/nmn/nnx/squashers/soft_tanh.py
ADDED
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+from jax import Array
+
+def soft_tanh(
+    x: Array,
+    n: float = 1.0,
+) -> Array:
+    """
+    Maps a non-negative score to the range [-1, 1) using the soft-tanh function.
+
+    The soft-tanh function is defined as:
+    .. math::
+        \\text{soft-tanh}_n(x) = \\frac{x^n - 1}{1 + x^n}
+
+    The power `n` controls the transition sharpness: higher `n` makes the
+    function approach 1 more quickly for large `x`.
+
+    Args:
+        x (Array): A JAX array of non-negative scores (x >= 0).
+        n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+    Returns:
+        Array: The mapped scores in the range [-1, 1).
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    return (x_n - 1.0) / (1.0 + x_n)
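A minimal usage sketch of the new squasher (the import path follows the file added above; the example scores are made up):

import jax.numpy as jnp

from nmn.nnx.squashers.soft_tanh import soft_tanh

# Non-negative scores, e.g. raw yat-attention weights before normalization.
scores = jnp.array([0.0, 0.5, 1.0, 4.0])

# n = 1 gives (x - 1) / (1 + x): 0 maps to -1, 1 maps to 0, large x approaches 1.
print(soft_tanh(scores))         # ≈ [-1.     -0.3333  0.      0.6   ]
# A larger n sharpens the transition around x = 1.
print(soft_tanh(scores, n=4.0))  # ≈ [-1.     -0.8824  0.      0.9922]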

nmn-0.1.7/src/nmn/nnx/squashers/softer_sigmoid.py
ADDED
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+from jax import Array
+
+def softer_sigmoid(
+    x: Array,
+    n: float = 1.0,
+) -> Array:
+    """
+    Squashes a non-negative score into the range [0, 1) using the soft-sigmoid function.
+
+    The soft-sigmoid function is defined as:
+    .. math::
+        \\text{soft-sigmoid}_n(x) = \\frac{x^n}{1 + x^n}
+
+    The power `n` modulates the softness: higher `n` makes the function approach
+    1 faster for large `x`, while `n < 1` makes the transition slower.
+
+    Args:
+        x (Array): A JAX array of non-negative scores (x >= 0).
+        n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+    Returns:
+        Array: The squashed scores in the range [0, 1).
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    return x_n / (1.0 + x_n)
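For comparison, a small sketch of softer_sigmoid on the same made-up scores (note that softer_sigmoid(x, n) equals (soft_tanh(x, n) + 1) / 2):

import jax.numpy as jnp

from nmn.nnx.squashers.softer_sigmoid import softer_sigmoid

scores = jnp.array([0.0, 0.5, 1.0, 4.0])

# n = 1 gives x / (1 + x): 0 maps to 0, 1 maps to 0.5, large x approaches 1.
print(softer_sigmoid(scores))         # ≈ [0.      0.3333  0.5     0.8   ]
# A larger n sharpens the transition around x = 1.
print(softer_sigmoid(scores, n=2.0))  # ≈ [0.      0.2     0.5     0.9412]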

nmn-0.1.7/src/nmn/nnx/squashers/softermax.py
ADDED
@@ -0,0 +1,38 @@
+import jax.numpy as jnp
+from jax import Array
+from typing import Optional
+
+def softermax(
+    x: Array,
+    n: float = 1.0,
+    epsilon: float = 1e-12,
+    axis: Optional[int] = -1,
+) -> Array:
+    """
+    Normalizes a set of non-negative scores using the Softermax function.
+
+    The Softermax function is defined as:
+    .. math::
+        \\text{softermax}_n(x_k, \\{x_i\\}) = \\frac{x_k^n}{\\epsilon + \\sum_i x_i^n}
+
+    The power `n` controls the sharpness of the distribution: `n=1` recovers
+    the original Softermax, while `n > 1` makes the distribution harder (more
+    peaked), and `0 < n < 1` makes it softer.
+
+    Args:
+        x (Array): A JAX array of non-negative scores.
+        n (float, optional): The power to raise each score to. Defaults to 1.0.
+        epsilon (float, optional): A small constant for numerical stability.
+            Defaults to 1e-12.
+        axis (Optional[int], optional): The axis to perform the sum over.
+            Defaults to -1.
+
+    Returns:
+        Array: The normalized scores.
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    sum_x_n = jnp.sum(x_n, axis=axis, keepdims=True)
+    return x_n / (epsilon + sum_x_n)
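A short sketch contrasting softermax with the exponential softmax (softermax is importable from nmn.nnx.squashers, as the yatattention.py import below shows; the scores are made up):

import jax
import jax.numpy as jnp

from nmn.nnx.squashers import softermax

# Non-negative scores, as produced by the yat attention computation.
scores = jnp.array([1.0, 2.0, 4.0])

# softermax normalizes by the sum of powered scores, with no exponentiation.
print(softermax(scores))         # ≈ [0.1429 0.2857 0.5714]  (1/7, 2/7, 4/7)
print(softermax(scores, n=2.0))  # ≈ [0.0476 0.1905 0.7619]  (more peaked)
# The usual softmax exponentiates first, so large scores dominate even more.
print(jax.nn.softmax(scores))    # ≈ [0.0420 0.1142 0.8438]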

{nmn-0.1.5 → nmn-0.1.7}/src/nmn/nnx/yatattention.py
RENAMED
@@ -26,8 +26,10 @@ from flax.typing import (
   DotGeneralT,
 )
 
+from nmn.nnx.nmn import YatNMN
+from jax import Array
 
-
+from nmn.nnx.squashers import softermax
 def yat_attention_weights(
   query: Array,
   key: Array,
@@ -41,6 +43,7 @@ def yat_attention_weights(
   precision: PrecisionLike = None,
   module: Optional[Module] = None,
   epsilon: float = 1e-5,
+  use_softermax: bool = False,
 ):
   """Computes attention weights using YatNMN distance-based calculation."""
   query, key = promote_dtype((query, key), dtype=dtype)
@@ -85,7 +88,10 @@ def yat_attention_weights(
     attn_weights = jnp.where(mask, attn_weights, big_neg)
 
   # normalize the attention weights
-  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+  if use_softermax:
+    attn_weights = softermax(attn_weights).astype(dtype)
+  else:
+    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
 
   if module:
     module.sow(nnx.Intermediate, 'attention_weights', attn_weights)
@@ -119,6 +125,7 @@ def yat_attention(
   precision: PrecisionLike = None,
   module: Optional[Module] = None,
   epsilon: float = 1e-5,
+  use_softermax: bool = False,
 ):
   """Computes attention using YatNMN distance-based calculation."""
   query, key, value = promote_dtype((query, key, value), dtype=dtype)
@@ -146,6 +153,7 @@ def yat_attention(
     precision,
     module,
     epsilon,
+    use_softermax,
   )
 
   # return weighted sum over values for each query position
@@ -153,91 +161,6 @@ def yat_attention(
     '...hqk,...khd->...qhd', attn_weights, value, precision=precision
   )
 
-Array = jax.Array
-
-# Add YatNMN class implementation
-default_bias_init = initializers.zeros_init()
-default_alpha_init = initializers.ones_init()
-
-class YatNMN(Module):
-  """A linear transformation with custom distance-based computation."""
-
-  def __init__(
-    self,
-    in_features: int,
-    out_features: int,
-    *,
-    use_bias: bool = True,
-    use_alpha: bool = True,
-    dtype: Optional[Dtype] = None,
-    param_dtype: Dtype = jnp.float32,
-    precision: PrecisionLike = None,
-    kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = default_bias_init,
-    alpha_init: Initializer = default_alpha_init,
-    dot_general: DotGeneralT = lax.dot_general,
-    rngs: rnglib.Rngs,
-    epsilon: float = 1e-5,
-  ):
-
-    kernel_key = rngs.params()
-    self.kernel = nnx.Param(
-      kernel_init(kernel_key, (in_features, out_features), param_dtype)
-    )
-    self.bias: nnx.Param[jax.Array] | None
-    if use_bias:
-      bias_key = rngs.params()
-      self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
-    else:
-      self.bias = None
-
-    self.alpha: nnx.Param[jax.Array] | None
-    if use_alpha:
-      alpha_key = rngs.params()
-      self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
-    else:
-      self.alpha = None
-
-    self.in_features = in_features
-    self.out_features = out_features
-    self.use_bias = use_bias
-    self.use_alpha = use_alpha
-    self.dtype = dtype
-    self.param_dtype = param_dtype
-    self.precision = precision
-    self.kernel_init = kernel_init
-    self.bias_init = bias_init
-    self.dot_general = dot_general
-    self.epsilon = epsilon
-
-  def __call__(self, inputs: Array) -> Array:
-    """Applies YatNMN transformation to inputs."""
-    kernel = self.kernel.value
-    bias = self.bias.value if self.bias is not None else None
-    alpha = self.alpha.value if self.alpha is not None else None
-
-    y = self.dot_general(
-      inputs,
-      kernel,
-      (((inputs.ndim - 1,), (0,)), ((), ())),
-      precision=self.precision,
-    )
-
-    inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
-    kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
-    distances = inputs_squared_sum + kernel_squared_sum - 2 * y
-
-    # Element-wise operation
-    y = y ** 2 / (distances + self.epsilon)
-
-    if bias is not None:
-      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
-
-    if alpha is not None:
-      scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
-      y = y * scale
-
-    return y
 
 
 def dot_product_attention_weights(
@@ -435,6 +358,10 @@ class MultiHeadAttention(Module):
     attention_fn: Callable[..., Array] = yat_attention,
     decode: bool | None = None,
     normalize_qk: bool = False,
+    use_alpha: bool = True,
+    alpha_init: Initializer = initializers.ones_init(),
+    use_dropconnect: bool = False,
+    dropconnect_rate: float = 0.0,
     # Deprecated, will be removed.
     qkv_dot_general: DotGeneralT | None = None,
     out_dot_general: DotGeneralT | None = None,
@@ -442,6 +369,7 @@ class MultiHeadAttention(Module):
     out_dot_general_cls: Any = None,
     rngs: rnglib.Rngs,
     epsilon: float = 1e-5,
+    use_softermax: bool = False,
   ):
     self.num_heads = num_heads
     self.in_features = in_features
@@ -470,6 +398,11 @@ class MultiHeadAttention(Module):
    self.qkv_dot_general_cls = qkv_dot_general_cls
    self.out_dot_general_cls = out_dot_general_cls
    self.epsilon = epsilon
+    self.use_softermax = use_softermax
+    self.use_alpha = use_alpha
+    self.alpha_init = alpha_init
+    self.use_dropconnect = use_dropconnect
+    self.dropconnect_rate = dropconnect_rate
 
    if self.qkv_features % self.num_heads != 0:
      raise ValueError(
@@ -491,6 +424,10 @@ class MultiHeadAttention(Module):
      use_bias=self.use_bias,
      precision=self.precision,
      epsilon=self.epsilon,
+      use_alpha=self.use_alpha,
+      alpha_init=self.alpha_init,
+      use_dropconnect=self.use_dropconnect,
+      drop_rate=self.dropconnect_rate,
    )
 
    # project inputs_q to multi-headed q/k/v
@@ -590,10 +527,23 @@ class MultiHeadAttention(Module):
        f'but module expects {self.in_features}.'
      )
 
+    is_deterministic: bool = False
+    if self.dropout_rate > 0.0 or (
+      self.use_dropconnect and self.dropconnect_rate > 0.0
+    ):
+      is_deterministic = first_from(
+        deterministic,
+        self.deterministic,
+        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
+          as either a __call__ argument, class attribute, or nnx.flag.""",
+      )
+    else:
+      is_deterministic = True
+
    # Apply YatNMN transformations and reshape to multi-head format
-    query = self.query(inputs_q)
-    key = self.key(inputs_k)
-    value = self.value(inputs_v)
+    query = self.query(inputs_q, deterministic=is_deterministic)
+    key = self.key(inputs_k, deterministic=is_deterministic)
+    value = self.value(inputs_v, deterministic=is_deterministic)
 
    # Reshape from [batch..., length, qkv_features] to [batch..., length, num_heads, head_dim]
    query = query.reshape(query.shape[:-1] + (self.num_heads, self.head_dim))
@@ -660,26 +610,11 @@ class MultiHeadAttention(Module):
      ),
    )
 
-    if (
-      self.dropout_rate > 0.0
-    ):
-      deterministic = first_from(
-        deterministic,
-        self.deterministic,
-        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
-          as either a __call__ argument, class attribute, or nnx.flag.""",
-      )
-      if not deterministic:
-        if rngs is None:
-          raise ValueError(
-            "'rngs' must be provided if 'dropout_rng' is not given."
-          )
-        dropout_rng = rngs.dropout()
-      else:
-        dropout_rng = None
-    else:
-      deterministic = True
-      dropout_rng = None
+    dropout_rng = None
+    if self.dropout_rate > 0.0 and not is_deterministic:
+      if rngs is None:
+        raise ValueError("'rngs' must be provided for dropout.")
+      dropout_rng = rngs.dropout()
 
    # apply attention with epsilon parameter for YatNMN
    x = self.attention_fn(
@@ -690,11 +625,12 @@ class MultiHeadAttention(Module):
      dropout_rng=dropout_rng,
      dropout_rate=self.dropout_rate,
      broadcast_dropout=self.broadcast_dropout,
-      deterministic=deterministic,
+      deterministic=is_deterministic,
      dtype=self.dtype,
      precision=self.precision,
      module=self if sow_weights else None,
      epsilon=self.epsilon,  # Pass epsilon to yat_attention
+      use_softermax=self.use_softermax,
    )
    # Reshape attention output back to original embedding dimension
    # from [batch..., length, num_heads, head_dim] to [batch..., length, qkv_features]
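To make the new use_softermax switch concrete, here is an illustrative sketch of the two normalization branches added to yat_attention_weights; the score matrix is made up, and only softermax and jax.nn.softmax come from the code above:

import jax
import jax.numpy as jnp

from nmn.nnx.squashers import softermax

# Hypothetical non-negative attention scores for one head: [query_len, key_len].
attn_weights = jnp.array([[1.0, 2.0, 4.0],
                          [0.5, 0.5, 9.0]])

# use_softermax=True: power-based normalization over the last axis.
print(softermax(attn_weights))
# use_softermax=False (default): the usual exponential softmax.
print(jax.nn.softmax(attn_weights))

Because the yat attention scores are already non-negative, softermax can normalize them directly by their (powered) sum instead of exponentiating them first.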