nmn-0.1.6-py3-none-any.whl → nmn-0.1.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nmn/nnx/TODO +2 -0
- nmn/nnx/yatattention.py +11 -5
- nmn/torch/conv.py +2105 -0
- {nmn-0.1.6.dist-info → nmn-0.1.8.dist-info}/METADATA +1 -1
- {nmn-0.1.6.dist-info → nmn-0.1.8.dist-info}/RECORD +7 -5
- {nmn-0.1.6.dist-info → nmn-0.1.8.dist-info}/WHEEL +0 -0
- {nmn-0.1.6.dist-info → nmn-0.1.8.dist-info}/licenses/LICENSE +0 -0
nmn/nnx/TODO: ADDED (+2 lines; file contents not shown in this diff)

nmn/nnx/yatattention.py: CHANGED (+11 -5)
@@ -29,6 +29,7 @@ from flax.typing import (
 from nmn.nnx.nmn import YatNMN
 from jax import Array
 
+from nmn.nnx.squashers import softermax
 def yat_attention_weights(
   query: Array,
   key: Array,
@@ -42,6 +43,7 @@ def yat_attention_weights(
   precision: PrecisionLike = None,
   module: Optional[Module] = None,
   epsilon: float = 1e-5,
+  use_softermax: bool = False,
 ):
   """Computes attention weights using YatNMN distance-based calculation."""
   query, key = promote_dtype((query, key), dtype=dtype)
@@ -86,7 +88,10 @@
     attn_weights = jnp.where(mask, attn_weights, big_neg)
 
   # normalize the attention weights
-  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+  if use_softermax:
+    attn_weights = softermax(attn_weights).astype(dtype)
+  else:
+    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
 
   if module:
     module.sow(nnx.Intermediate, 'attention_weights', attn_weights)
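The hunks above add a `use_softermax` flag to `yat_attention_weights` and, when it is set, normalize `attn_weights` with the `softermax` squasher imported from `nmn.nnx.squashers` instead of `jax.nn.softmax`. The squasher's implementation is not part of this diff; purely as a hedged illustration, a "softer" normalization in the off-by-one-softmax style could look like the sketch below. The real `nmn.nnx.squashers.softermax` may be defined differently.

```python
import jax.numpy as jnp

def softermax_sketch(x, axis: int = -1, n: float = 1.0):
    """Illustrative only -- not the actual nmn.nnx.squashers.softermax.

    An "off-by-one" softmax: exp(x_i) / (n + sum_j exp(x_j)), computed in a
    numerically stable way by subtracting the per-axis maximum. Unlike
    jax.nn.softmax, the weights along `axis` can sum to less than 1.
    """
    x_max = jnp.max(x, axis=axis, keepdims=True)
    unnormalized = jnp.exp(x - x_max)
    denom = n * jnp.exp(-x_max) + unnormalized.sum(axis=axis, keepdims=True)
    return unnormalized / denom
```

Whichever normalizer is selected, the surrounding logic is unchanged: masking, the optional `module.sow(nnx.Intermediate, 'attention_weights', ...)`, and the weighted sum over values behave exactly as before.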
@@ -120,6 +125,7 @@ def yat_attention(
   precision: PrecisionLike = None,
   module: Optional[Module] = None,
   epsilon: float = 1e-5,
+  use_softermax: bool = False,
 ):
   """Computes attention using YatNMN distance-based calculation."""
   query, key, value = promote_dtype((query, key, value), dtype=dtype)
@@ -147,6 +153,7 @@ def yat_attention(
     precision,
     module,
     epsilon,
+    use_softermax,
   )
 
   # return weighted sum over values for each query position
@@ -362,6 +369,7 @@ class MultiHeadAttention(Module):
     out_dot_general_cls: Any = None,
     rngs: rnglib.Rngs,
     epsilon: float = 1e-5,
+    use_softermax: bool = False,
   ):
     self.num_heads = num_heads
     self.in_features = in_features
@@ -390,6 +398,7 @@ class MultiHeadAttention(Module):
     self.qkv_dot_general_cls = qkv_dot_general_cls
     self.out_dot_general_cls = out_dot_general_cls
     self.epsilon = epsilon
+    self.use_softermax = use_softermax
     self.use_alpha = use_alpha
     self.alpha_init = alpha_init
     self.use_dropconnect = use_dropconnect
@@ -451,8 +460,6 @@ class MultiHeadAttention(Module):
       self.key_ln = None
 
     # Remove the output layer - no more self.out
-    self.rngs = rngs if dropout_rate > 0.0 else None
-
     self.cached_key: nnx.Cache[Array] | None = None
     self.cached_value: nnx.Cache[Array] | None = None
     self.cache_index: nnx.Cache[Array] | None = None
@@ -498,8 +505,6 @@ class MultiHeadAttention(Module):
     Returns:
       output of shape `[batch_sizes..., length, features]`.
     """
-    if rngs is None:
-      rngs = self.rngs
 
     if inputs_k is None:
       if inputs_v is not None:
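The last two hunks above drop the cached `self.rngs` (previously stored only when `dropout_rate > 0.0`) and the `if rngs is None: rngs = self.rngs` fallback in `__call__`, so dropout RNGs now have to be supplied at call time. A minimal sketch of that calling pattern, assuming arguments beyond those visible in this diff (`num_heads`, `in_features`, `dropout_rate`, `epsilon`, `rngs`, `use_softermax`) follow the usual `flax.nnx.MultiHeadAttention` conventions:

```python
import jax.numpy as jnp
from flax import nnx
from nmn.nnx.yatattention import MultiHeadAttention  # module path per this diff

# Hypothetical shapes and argument values, for illustration only.
layer = MultiHeadAttention(
    num_heads=4,
    in_features=64,
    dropout_rate=0.1,
    rngs=nnx.Rngs(0),  # used for parameter init; no longer cached for dropout
)

x = jnp.ones((2, 16, 64))  # [batch, length, features]
# With the fallback removed, dropout keys are passed explicitly on each call.
y = layer(x, rngs=nnx.Rngs(dropout=1))
```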
@@ -621,6 +626,7 @@ class MultiHeadAttention(Module):
       precision=self.precision,
       module=self if sow_weights else None,
       epsilon=self.epsilon, # Pass epsilon to yat_attention
+      use_softermax=self.use_softermax,
     )
     # Reshape attention output back to original embedding dimension
     # from [batch..., length, num_heads, head_dim] to [batch..., length, qkv_features]
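Taken together, 0.1.8 threads a single new boolean from the public constructor down to the normalizer: `MultiHeadAttention` stores `use_softermax`, `__call__` forwards it to `yat_attention`, which passes it on to `yat_attention_weights`. A hedged construction sketch, using only the arguments visible in this diff and assuming defaults for everything else:

```python
from flax import nnx
from nmn.nnx.yatattention import MultiHeadAttention

# use_softermax=True selects softermax instead of jax.nn.softmax when the
# attention weights are normalized; epsilon is forwarded to yat_attention
# exactly as in 0.1.6.
attn = MultiHeadAttention(
    num_heads=8,
    in_features=128,
    rngs=nnx.Rngs(0),
    epsilon=1e-5,
    use_softermax=True,
)
```

Because the flag defaults to `False` at every level, existing 0.1.6 call sites keep the plain softmax behavior.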