nmn-0.1.6.tar.gz → nmn-0.1.7.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nmn-0.1.6 → nmn-0.1.7}/PKG-INFO +1 -1
- {nmn-0.1.6 → nmn-0.1.7}/pyproject.toml +1 -1
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/nnx/yatattention.py +11 -1
- {nmn-0.1.6 → nmn-0.1.7}/.github/workflows/publish.yml +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/.gitignore +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/LICENSE +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/MANIFEST.in +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/PUBLISH.md +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/README.md +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/hatch.toml +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/__init__.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/keras/nmn.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/linen/nmn.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/nnx/examples/language/mingpt.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/nnx/examples/vision/cnn_cifar.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/nnx/loss/__init__.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/nnx/nmn.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/nnx/squashers/__init__.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/nnx/squashers/soft_tanh.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/nnx/squashers/softer_sigmoid.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/nnx/squashers/softermax.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/nnx/yatconv.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/tf/nmn.py +0 -0
- {nmn-0.1.6 → nmn-0.1.7}/src/nmn/torch/nmn.py +0 -0
{nmn-0.1.6 → nmn-0.1.7}/src/nmn/nnx/yatattention.py
RENAMED

```diff
@@ -29,6 +29,7 @@ from flax.typing import (
 from nmn.nnx.nmn import YatNMN
 from jax import Array
 
+from nmn.nnx.squashers import softermax
 def yat_attention_weights(
     query: Array,
     key: Array,
@@ -42,6 +43,7 @@ def yat_attention_weights(
     precision: PrecisionLike = None,
     module: Optional[Module] = None,
     epsilon: float = 1e-5,
+    use_softermax: bool = False,
 ):
     """Computes attention weights using YatNMN distance-based calculation."""
     query, key = promote_dtype((query, key), dtype=dtype)
@@ -86,7 +88,10 @@
         attn_weights = jnp.where(mask, attn_weights, big_neg)
 
     # normalize the attention weights
-    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+    if use_softermax:
+        attn_weights = softermax(attn_weights).astype(dtype)
+    else:
+        attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
 
     if module:
         module.sow(nnx.Intermediate, 'attention_weights', attn_weights)
@@ -120,6 +125,7 @@ def yat_attention(
     precision: PrecisionLike = None,
     module: Optional[Module] = None,
     epsilon: float = 1e-5,
+    use_softermax: bool = False,
 ):
     """Computes attention using YatNMN distance-based calculation."""
     query, key, value = promote_dtype((query, key, value), dtype=dtype)
@@ -147,6 +153,7 @@ def yat_attention(
         precision,
         module,
         epsilon,
+        use_softermax,
     )
 
     # return weighted sum over values for each query position
@@ -362,6 +369,7 @@ class MultiHeadAttention(Module):
         out_dot_general_cls: Any = None,
         rngs: rnglib.Rngs,
         epsilon: float = 1e-5,
+        use_softermax: bool = False,
     ):
         self.num_heads = num_heads
         self.in_features = in_features
@@ -390,6 +398,7 @@ class MultiHeadAttention(Module):
         self.qkv_dot_general_cls = qkv_dot_general_cls
         self.out_dot_general_cls = out_dot_general_cls
         self.epsilon = epsilon
+        self.use_softermax = use_softermax
         self.use_alpha = use_alpha
         self.alpha_init = alpha_init
         self.use_dropconnect = use_dropconnect
@@ -621,6 +630,7 @@ class MultiHeadAttention(Module):
             precision=self.precision,
             module=self if sow_weights else None,
             epsilon=self.epsilon,  # Pass epsilon to yat_attention
+            use_softermax=self.use_softermax,
         )
         # Reshape attention output back to original embedding dimension
         # from [batch..., length, num_heads, head_dim] to [batch..., length, qkv_features]
```