noxton 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- noxton-0.1.0/LICENSE +21 -0
- noxton-0.1.0/PKG-INFO +11 -0
- noxton-0.1.0/README.md +0 -0
- noxton-0.1.0/noxton/functions/__init__.py +32 -0
- noxton-0.1.0/noxton/functions/activation.py +34 -0
- noxton-0.1.0/noxton/functions/attention.py +594 -0
- noxton-0.1.0/noxton/functions/embedding.py +64 -0
- noxton-0.1.0/noxton/functions/masking.py +198 -0
- noxton-0.1.0/noxton/functions/normalization.py +39 -0
- noxton-0.1.0/noxton/functions/regularization.py +125 -0
- noxton-0.1.0/noxton/nn/__init__.py +0 -0
- noxton-0.1.0/noxton/nn/abstract.py +16 -0
- noxton-0.1.0/noxton/utils/__init__.py +5 -0
- noxton-0.1.0/noxton/utils/utils.py +59 -0
- noxton-0.1.0/noxton.egg-info/PKG-INFO +11 -0
- noxton-0.1.0/noxton.egg-info/SOURCES.txt +19 -0
- noxton-0.1.0/noxton.egg-info/dependency_links.txt +1 -0
- noxton-0.1.0/noxton.egg-info/requires.txt +3 -0
- noxton-0.1.0/noxton.egg-info/top_level.txt +1 -0
- noxton-0.1.0/pyproject.toml +11 -0
- noxton-0.1.0/setup.cfg +4 -0
noxton-0.1.0/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 Artur A. Galstyan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
noxton-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,11 @@
Metadata-Version: 2.4
Name: noxton
Version: 0.1.0
Summary: Add your description here
Requires-Python: >=3.13
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: equinox>=0.13.4
Requires-Dist: jax>=0.9.0.1
Requires-Dist: statedict2pytree>=2.0.1
Dynamic: license-file
noxton-0.1.0/README.md
ADDED
File without changes
noxton-0.1.0/noxton/functions/__init__.py
ADDED
@@ -0,0 +1,32 @@
from .activation import swiglu
from .attention import (
    create_attn_mask,
    multi_head_attention_forward,
    shifted_window_attention,
)
from .embedding import sinusoidal_embedding
from .masking import (
    build_attention_mask,
    canonical_attn_mask,
    canonical_key_padding_mask,
    canonical_mask,
    make_causal_mask,
)
from .normalization import normalize
from .regularization import dropout, stochastic_depth

__all__ = [
    "swiglu",
    "multi_head_attention_forward",
    "shifted_window_attention",
    "create_attn_mask",
    "sinusoidal_embedding",
    "dropout",
    "stochastic_depth",
    "normalize",
    "build_attention_mask",
    "canonical_attn_mask",
    "canonical_key_padding_mask",
    "canonical_mask",
    "make_causal_mask",
]
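The re-exports above define the package's public functional API. A minimal usage sketch (not part of the package) of a few of these helpers, assuming the released wheel installs under the top-level name ``noxton`` declared in top_level.txt:

import jax.numpy as jnp

from noxton.functions import build_attention_mask, make_causal_mask, swiglu

x = jnp.ones((3, 8))            # the last axis must be even for swiglu
h = swiglu(x)                   # -> shape (3, 4)

keep = make_causal_mask(3)      # boolean mask, True where attending is allowed
bias = build_attention_mask(3)  # additive mask, 0.0 on/below the diagonal, -inf above
print(h.shape, keep.shape, bias.shape)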
noxton-0.1.0/noxton/functions/activation.py
ADDED
@@ -0,0 +1,34 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array


def swiglu(x: Array, axis: int = -1) -> Array:
    """Apply the SwiGLU activation function.

    Splits the input array into two halves along the specified axis, applies
    the Swish activation to the first half, and multiplies it element-wise
    with the second half. This gated activation is commonly used in
    transformer feed-forward blocks.

    Args:
        x: Input array. Its size along ``axis`` must be even.
        axis: Axis along which to split the input into two halves.
            Defaults to ``-1`` (last axis).

    Returns:
        Array of the same dtype as ``x`` with size halved along ``axis``.

    Example:
        >>> import jax.numpy as jnp
        >>> x = jnp.array([1.0, 2.0, -1.0, 0.5])
        >>> swiglu(x)  # splits into [1., 2.] and [-1., 0.5]
        Array([-0.7310586,  0.8807971], dtype=float32)

        >>> # 2-D input, split along last axis (default)
        >>> x2d = jnp.ones((3, 4))
        >>> swiglu(x2d).shape
        (3, 2)
    """
    a, b = jnp.split(x, 2, axis=axis)
    return jax.nn.swish(a) * b
noxton-0.1.0/noxton/functions/attention.py
ADDED
@@ -0,0 +1,594 @@
import functools

import equinox as eqx
import jax
import jax.numpy as jnp
from beartype.typing import Any
from jaxtyping import Array, Bool, Float, PRNGKeyArray

from noxton.functions.normalization import normalize
from noxton.functions.regularization import dropout as dropout_fn
from noxton.utils.utils import default_floating_dtype


def multi_head_attention_forward(
    query: Float[Array, "tgt_len d_model"],
    key: Float[Array, "src_len d_model"],
    value: Float[Array, "src_len d_model"],
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Float[Array, "3*d_model d_model"] | None = None,
    in_proj_bias: Float[Array, "3*d_model"] | None = None,
    bias_k: Float[Array, "1 d_model"] | None = None,
    bias_v: Float[Array, "1 d_model"] | None = None,
    add_zero_attn: bool = False,
    dropout_p: float = 0.0,
    out_proj_weight: Float[Array, "d_model d_model"] | None = None,
    out_proj_bias: Float[Array, "d_model"] | None = None,
    inference: bool = False,
    key_padding_mask: Float[Array, "src_len"] | Bool[Array, "src_len"] | None = None,
    attn_mask: Float[Array, "tgt_len src_len"] | None = None,
    need_weights: bool = True,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Float[Array, "d_model d_model"] | None = None,
    k_proj_weight: Float[Array, "d_model d_model"] | None = None,
    v_proj_weight: Float[Array, "d_model d_model"] | None = None,
    static_k: Float[Array, "src_len d_model"] | None = None,
    static_v: Float[Array, "src_len d_model"] | None = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
    dropout_key: PRNGKeyArray | None = None,
) -> tuple[
    Float[Array, "tgt_len d_model"],
    Float[Array, "num_heads tgt_len src_len"]
    | Float[Array, "tgt_len src_len"]
    | Float[Array, "tgt_len src_len+1"]
    | None,
]:
    """Compute scaled dot-product multi-head attention.

    This is the functional core of multi-head attention as described in
    *Attention Is All You Need* (Vaswani et al., 2017). It accepts
    pre-allocated weight matrices and returns the attended output along with
    optional attention weights.

    This function is a 1:1 mapping from the PyTorch implementation.

    The function supports several advanced options:

    - **Separate projection weights** (``use_separate_proj_weight=True``):
      pass individual ``q_proj_weight``, ``k_proj_weight``, and
      ``v_proj_weight`` instead of a single fused ``in_proj_weight``.
    - **Static key/value**: bypass the key/value projections by providing
      pre-computed ``static_k`` / ``static_v`` tensors.
    - **Extra key/value bias tokens** (``bias_k``, ``bias_v``): append
      learnable bias tokens to the key and value sequences.
    - **Zero-attention slot** (``add_zero_attn``): append a zero vector to
      keys and values, giving the model an option to attend to nothing.
    - **Causal masking** (``is_causal``): apply an upper-triangular ``-inf``
      mask so each query can only attend to earlier or equal positions.
    - **Key-padding mask** / **additive attention mask**: arbitrary masking
      via ``key_padding_mask`` and ``attn_mask``.

    Args:
        query: Query tensor of shape ``(tgt_len, d_model)``.
        key: Key tensor of shape ``(src_len, d_model)``.
        value: Value tensor of shape ``(src_len, d_model)``.
        embed_dim_to_check: Expected embedding dimension; asserted to equal
            ``d_model``.
        num_heads: Number of attention heads. Must divide ``d_model`` evenly.
        in_proj_weight: Fused input projection weight of shape
            ``(3*d_model, d_model)``. Required when
            ``use_separate_proj_weight=False``.
        in_proj_bias: Optional bias for the fused projection of shape
            ``(3*d_model,)``.
        bias_k: Optional learnable key bias of shape ``(1, d_model)`` appended
            to the key sequence.
        bias_v: Optional learnable value bias of shape ``(1, d_model)``
            appended to the value sequence. Must be provided together with
            ``bias_k``.
        add_zero_attn: If ``True``, append a zero token to keys and values.
            Defaults to ``False``.
        dropout_p: Dropout probability applied to attention weights during
            training. Defaults to ``0.0``.
        out_proj_weight: Output projection weight of shape
            ``(d_model, d_model)``. Required.
        out_proj_bias: Optional output projection bias of shape ``(d_model,)``.
        inference: If ``True``, disable dropout (eval mode).
            Defaults to ``False``.
        key_padding_mask: Optional mask of shape ``(src_len,)`` marking padded
            key positions. ``True`` / non-zero values are masked out (set to
            ``-inf`` in logits).
        attn_mask: Optional additive attention mask of shape
            ``(tgt_len, src_len)``. Values are added to attention logits
            before softmax.
        need_weights: If ``True``, also return the attention weight matrix.
            Defaults to ``True``.
        use_separate_proj_weight: If ``True``, use ``q_proj_weight``,
            ``k_proj_weight``, ``v_proj_weight`` for projections instead of
            ``in_proj_weight``. Defaults to ``False``.
        q_proj_weight: Query projection weight ``(d_model, d_model)``.
            Required when ``use_separate_proj_weight=True``.
        k_proj_weight: Key projection weight ``(d_model, d_model)``.
            Required when ``use_separate_proj_weight=True``.
        v_proj_weight: Value projection weight ``(d_model, d_model)``.
            Required when ``use_separate_proj_weight=True``.
        static_k: Pre-computed key of shape ``(src_len, d_model)``. When
            provided, bypasses the key projection.
        static_v: Pre-computed value of shape ``(src_len, d_model)``. When
            provided, bypasses the value projection.
        average_attn_weights: If ``True``, average attention weights over
            heads before returning. Defaults to ``True``.
        is_causal: If ``True``, apply a causal mask to prevent attending to
            future positions. Defaults to ``False``.
        dropout_key: JAX PRNG key for attention dropout. Required when
            ``dropout_p > 0.0`` and ``inference=False``.

    Returns:
        A tuple ``(attn_output, attn_weights)`` where:

        - ``attn_output``: Attended output of shape ``(tgt_len, d_model)``.
        - ``attn_weights``: Attention weights or ``None`` when
          ``need_weights=False``. Shape depends on ``average_attn_weights``:
          ``(tgt_len, src_len)`` when averaged, or
          ``(num_heads, tgt_len, src_len)`` otherwise.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> key = jax.random.PRNGKey(0)
        >>> d_model, num_heads, tgt_len, src_len = 8, 2, 4, 6
        >>> q = jax.random.normal(key, (tgt_len, d_model))
        >>> k = jax.random.normal(key, (src_len, d_model))
        >>> v = jax.random.normal(key, (src_len, d_model))
        >>> W = jax.random.normal(key, (3 * d_model, d_model))
        >>> W_out = jax.random.normal(key, (d_model, d_model))
        >>> out, weights = multi_head_attention_forward(
        ...     q, k, v,
        ...     embed_dim_to_check=d_model,
        ...     num_heads=num_heads,
        ...     in_proj_weight=W,
        ...     out_proj_weight=W_out,
        ...     inference=True,
        ... )
        >>> out.shape
        (4, 8)
        >>> weights.shape  # averaged over heads by default
        (4, 6)
    """
    tgt_len, d_model = query.shape
    src_len, k_dim = key.shape
    value_len, v_dim = value.shape

    assert d_model == k_dim == v_dim == embed_dim_to_check, (
        "Embedding dimensions must match"
    )

    assert src_len == value_len, "Key and value must have the same sequence length"

    head_dim = d_model // num_heads
    assert head_dim * num_heads == d_model, "embed_dim must be divisible by num_heads"

    if dropout_p > 0.0:
        assert dropout_key is not None, (
            "dropout_key must be provided if dropout_p > 0.0"
        )

    if use_separate_proj_weight:
        # When using separate projection weights for q, k, v
        assert q_proj_weight is not None, (
            "q_proj_weight should not be None when use_separate_proj_weight=True"
        )
        assert k_proj_weight is not None, (
            "k_proj_weight should not be None when use_separate_proj_weight=True"
        )
        assert v_proj_weight is not None, (
            "v_proj_weight should not be None when use_separate_proj_weight=True"
        )

        q = query @ q_proj_weight.T

        if static_k is None:
            k = key @ k_proj_weight.T
        else:
            k = static_k
            src_len, _ = k.shape

        if static_v is None:
            v = value @ v_proj_weight.T
        else:
            v = static_v
            value_len, _ = v.shape

        if in_proj_bias is not None:
            q_bias, k_bias, v_bias = jnp.split(in_proj_bias, 3)
            q = q + q_bias
            k = k + k_bias
            v = v + v_bias

    else:
        assert in_proj_weight is not None, (
            "in_proj_weight should not be None when use_separate_proj_weight=False"
        )

        q_proj_weight_part, k_proj_weight_part, v_proj_weight_part = jnp.split(
            in_proj_weight, 3
        )

        q = query @ q_proj_weight_part.T

        if static_k is None:
            k = key @ k_proj_weight_part.T
        else:
            k = static_k
            src_len, _ = static_k.shape

        if static_v is None:
            v = value @ v_proj_weight_part.T
        else:
            v = static_v
            value_len, _ = static_v.shape

        if in_proj_bias is not None:
            q_bias, k_bias, v_bias = jnp.split(in_proj_bias, 3)
            q = q + q_bias
            k = k + k_bias
            v = v + v_bias

    assert src_len == value_len

    q = q.reshape(tgt_len, num_heads, head_dim)
    k = k.reshape(src_len, num_heads, head_dim)
    v = v.reshape(src_len, num_heads, head_dim)

    if add_zero_attn:
        zero_attn_shape = (1, num_heads, head_dim)
        k_zeros = jnp.zeros(zero_attn_shape)
        v_zeros = jnp.zeros(zero_attn_shape)

        k = jnp.concatenate([k, k_zeros], axis=0)
        v = jnp.concatenate([v, v_zeros], axis=0)

        src_len += 1
        value_len += 1

    if bias_k is not None and bias_v is not None:
        bias_k = bias_k.reshape(1, num_heads, head_dim)
        bias_v = bias_v.reshape(1, num_heads, head_dim)

        k = jnp.concatenate([k, bias_k], axis=0)
        v = jnp.concatenate([v, bias_v], axis=0)

        src_len += 1
        value_len += 1

    assert src_len == value_len

    # [tgt_len, num_heads, head_dim] → [num_heads, tgt_len, head_dim]
    q = jnp.transpose(q, (1, 0, 2))

    # [src_len, num_heads, head_dim] → [num_heads, src_len, head_dim]
    k = jnp.transpose(k, (1, 0, 2))
    v = jnp.transpose(v, (1, 0, 2))

    scale = jnp.sqrt(head_dim)
    attn_output_weights = jnp.matmul(q, jnp.transpose(k, (0, 2, 1))) / scale

    if key_padding_mask is not None:
        padding_mask = key_padding_mask.reshape(1, 1, src_len)
        padding_mask = jnp.repeat(padding_mask, num_heads, axis=0)
        padding_mask = jnp.repeat(padding_mask, tgt_len, axis=1)
        attn_output_weights = jnp.where(
            padding_mask, float("-inf"), attn_output_weights
        )

    if attn_mask is not None:
        # [tgt_len, src_len] -> [num_heads, tgt_len, src_len]
        mask = attn_mask.reshape(1, tgt_len, src_len)
        mask = jnp.repeat(mask, num_heads, axis=0)
        attn_output_weights = attn_output_weights + mask

    if is_causal:
        causal_mask = jnp.triu(jnp.ones((tgt_len, src_len)), k=1)
        causal_mask = (causal_mask == 1).reshape(1, tgt_len, src_len)
        causal_mask = jnp.repeat(causal_mask, num_heads, axis=0)
        attn_output_weights = jnp.where(causal_mask, float("-inf"), attn_output_weights)

    # [num_heads, tgt_len, src_len]
    attn_output_weights = jax.nn.softmax(attn_output_weights, axis=-1)

    if dropout_p > 0.0 and not inference:
        assert dropout_key is not None, (
            "dropout_key required because dropout_p > 0.0 and training"
        )
        dropout_mask = jax.random.bernoulli(
            dropout_key, 1 - dropout_p, attn_output_weights.shape
        )
        scale = 1.0 / (1.0 - dropout_p)
        attn_output_weights = attn_output_weights * dropout_mask * scale

    attn_output = jnp.matmul(attn_output_weights, v)
    attn_output = jnp.transpose(attn_output, (1, 0, 2))
    attn_output = attn_output.reshape(tgt_len, d_model)

    assert out_proj_weight is not None, "out_proj_weight must be provided"
    attn_output = attn_output @ out_proj_weight.T

    if out_proj_bias is not None:
        attn_output = attn_output + out_proj_bias

    if need_weights:
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(axis=0)
        return attn_output, attn_output_weights
    else:
        return attn_output, None


def create_attn_mask(
    pad_H: int,
    pad_W: int,
    window_size: list[int],
    shift_size: list[int],
    dtype: Any | None = None,
) -> Array:
    """Build the region-index mask used by shifted-window attention.

    Assigns each spatial position in a ``(pad_H, pad_W)`` feature map an
    integer *region index* that encodes which window it belongs to after a
    cyclic shift. Positions in the same window share the same region index;
    positions in different windows have different indices. The mask is used
    downstream to zero out cross-window attention scores.

    The region indices are computed by partitioning the height axis at
    ``pad_H - window_size[0]`` and ``pad_H - shift_size[0]``, and the width
    axis at ``pad_W - window_size[1]`` and ``pad_W - shift_size[1]``, then
    assigning a unique integer to each ``(row_region, col_region)`` cell.

    Args:
        pad_H: Padded feature map height (must be a multiple of
            ``window_size[0]``).
        pad_W: Padded feature map width (must be a multiple of
            ``window_size[1]``).
        window_size: Local attention window size as ``[window_H, window_W]``.
        shift_size: Cyclic shift amounts as ``[shift_H, shift_W]``.
        dtype: Output array dtype. Defaults to the project's default floating
            dtype when ``None``.

    Returns:
        Integer-region-index map of shape ``(pad_H, pad_W)`` cast to
        ``dtype``.

    Example:
        >>> mask = create_attn_mask(8, 8, window_size=[4, 4], shift_size=[2, 2])
        >>> mask.shape
        (8, 8)
    """
    if dtype is None:
        dtype = default_floating_dtype()
    assert dtype is not None
    h_boundaries = jnp.array([pad_H - window_size[0], pad_H - shift_size[0]])
    w_boundaries = jnp.array([pad_W - window_size[1], pad_W - shift_size[1]])

    h_boundaries = jnp.sort(h_boundaries)
    w_boundaries = jnp.sort(w_boundaries)

    ii, jj = jnp.indices((pad_H, pad_W))  # ii for rows, jj for columns

    row_region_idx = jnp.searchsorted(h_boundaries, ii, side="right")
    col_region_idx = jnp.searchsorted(w_boundaries, jj, side="right")

    num_col_regions = len(w_boundaries) + 1
    attn_mask = row_region_idx * num_col_regions + col_region_idx

    return attn_mask.astype(dtype)


def shifted_window_attention(
    x: Float[Array, "H W C"],
    qkv_weight: Float[Array, "in_dim out_dim"],
    proj_weight: Float[Array, "out_dim out_dim"],
    relative_position_bias: Array,
    window_size: list[int],
    num_heads: int,
    shift_size: list[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Array | None = None,
    proj_bias: Array | None = None,
    logit_scale: Array | None = None,
    inference: bool = False,
    key: PRNGKeyArray | None = None,
) -> Float[Array, "H W C"]:
    """Apply Shifted-Window Multi-Head Self-Attention (Swin Attention).

    Implements the window-based self-attention mechanism from the Swin
    Transformer (Liu et al., 2021). The input feature map is partitioned
    into non-overlapping local windows; attention is computed independently
    within each window. A cyclic spatial shift (``shift_size``) is applied
    before partitioning to create cross-window connections, and a region-index
    mask is used to prevent attending across window boundaries introduced by
    the padding/shift.

    Supports two attention score variants:

    - **Scaled dot-product** (``logit_scale=None``): standard
      ``q @ k.T / sqrt(head_dim)``.
    - **Cosine attention** (when ``logit_scale`` is provided): L2-normalised
      ``q`` and ``k`` are used, and scores are multiplied by a learnable
      (but bounded) temperature ``exp(min(logit_scale, log(100)))``.

    Relative position biases are added to the attention logits before softmax.

    Args:
        x: Input feature map of shape ``(H, W, C)``.
        qkv_weight: Fused QKV projection weight of shape
            ``(in_dim, out_dim)`` where ``out_dim = 3 * C``.
        proj_weight: Output projection weight of shape ``(C, C)``.
        relative_position_bias: Relative position bias tensor added to
            attention logits; shape must broadcast with
            ``(num_windows, num_heads, window_H*window_W, window_H*window_W)``.
        window_size: Local window size as ``[window_H, window_W]``.
        num_heads: Number of attention heads. Must divide ``C`` evenly.
        shift_size: Cyclic shift amounts as ``[shift_H, shift_W]``.
            Use ``[0, 0]`` to disable shifting (regular window attention).
        attention_dropout: Dropout probability applied to attention weights.
            Defaults to ``0.0``.
        dropout: Dropout probability applied to the output projection.
            Defaults to ``0.0``.
        qkv_bias: Optional bias for the QKV projection of shape
            ``(3 * C,)``. When ``logit_scale`` is also provided, the key
            bias component is zeroed out (as in Swin V2).
        proj_bias: Optional bias for the output projection of shape ``(C,)``.
        logit_scale: Learnable log-scale scalar for cosine attention. When
            ``None``, standard scaled dot-product attention is used.
        inference: If ``True``, disable dropout. Defaults to ``False``.
        key: JAX PRNG key required when ``inference=False``.

    Returns:
        Output feature map of shape ``(H, W, C)``.

    Raises:
        ValueError: If ``inference=False`` and ``key`` is ``None``.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> H, W, C, num_heads = 8, 8, 16, 2
        >>> window_size, shift_size = [4, 4], [2, 2]
        >>> key = jax.random.PRNGKey(0)
        >>> x = jax.random.normal(key, (H, W, C))
        >>> qkv_w = jax.random.normal(key, (C, 3 * C))
        >>> proj_w = jax.random.normal(key, (C, C))
        >>> win_tokens = window_size[0] * window_size[1]
        >>> num_wins = (H // window_size[0]) * (W // window_size[1])
        >>> rpb = jnp.zeros((num_wins, num_heads, win_tokens, win_tokens))
        >>> out = shifted_window_attention(
        ...     x, qkv_w, proj_w, rpb,
        ...     window_size=window_size, num_heads=num_heads,
        ...     shift_size=shift_size, inference=True,
        ... )
        >>> out.shape
        (8, 8, 16)
    """
    if not inference and key is None:
        raise ValueError("Need key when in training mode")
    H, W, C = x.shape
    to_pad_W = (window_size[1] - W % window_size[1]) % window_size[1]
    to_pad_H = (window_size[0] - H % window_size[0]) % window_size[0]
    x = jnp.pad(x, ((0, to_pad_H), (0, to_pad_W), (0, 0)))
    pad_H, pad_W, _ = x.shape

    shift_size = shift_size.copy()
    if window_size[0] >= pad_H:
        shift_size[0] = 0
    if window_size[1] >= pad_W:
        shift_size[1] = 0

    # cyclic shift
    if sum(shift_size) > 0:
        x = jnp.roll(x, shift=(-shift_size[0], -shift_size[1]), axis=(0, 1))

    # partition windows
    num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
    x = jnp.reshape(
        x,
        (
            pad_H // window_size[0],
            window_size[0],
            pad_W // window_size[1],
            window_size[1],
            C,
        ),
    )
    x = jnp.transpose(x, (0, 2, 1, 3, 4)).reshape(
        num_windows, window_size[0] * window_size[1], C
    )

    # multi-head attention
    if logit_scale is not None and qkv_bias is not None:
        length = qkv_bias.size // 3
        qkv_bias = qkv_bias.at[length : 2 * length].set(0.0)

    def linear(x: Array, weight: Array, bias: Array | None):
        output = x @ jnp.transpose(weight)  # (in,) @ (in, out) -> (out,)
        if bias is not None:
            output = output + bias
        return output

    linear_pt = functools.partial(linear, weight=qkv_weight, bias=qkv_bias)

    qkv = eqx.filter_vmap(eqx.filter_vmap(linear_pt))(x)
    win_size, patches, _ = qkv.shape
    qkv = jnp.transpose(
        qkv.reshape(win_size, patches, 3, num_heads, C // num_heads), (2, 0, 3, 1, 4)
    )
    q, k, v = qkv[0], qkv[1], qkv[2]
    if logit_scale is not None:
        # cosine attention
        attn = normalize(q, axis=-1) @ jnp.transpose(
            normalize(k, axis=-1), (0, 1, 3, 2)
        )
        # Clamp the logit scale exponent for stability
        logit_scale = jnp.exp(jnp.minimum(logit_scale, jnp.log(jnp.array(100.0))))
        attn = attn * logit_scale
    else:
        q = q * (C // num_heads) ** -0.5
        # attn = q @ (jnp.transpose(normalize(k, axis=-1), (0, 1, 3, 2)))  # Incorrect
        attn = q @ jnp.transpose(k, (0, 1, 3, 2))  # Corrected: q @ k.T

    # add relative position bias
    attn = attn + relative_position_bias

    if sum(shift_size) > 0:
        attn_mask = create_attn_mask(pad_H, pad_W, window_size, shift_size)
        attn_mask = attn_mask.reshape(
            pad_H // window_size[0],
            window_size[0],
            pad_W // window_size[1],
            window_size[1],
        )
        attn_mask = jnp.transpose(attn_mask, (0, 2, 1, 3)).reshape(
            num_windows, window_size[0] * window_size[1]
        )
        attn_mask = jnp.expand_dims(attn_mask, axis=1) - jnp.expand_dims(
            attn_mask, axis=2
        )
        attn_mask = jnp.where(attn_mask == 0, 0.0, -100.0)

        attn = attn + attn_mask[:, None, :, :]

    attn = jax.nn.softmax(attn, axis=-1)
    if not inference:
        assert key is not None, "key must be given if not inference"
        key, subkey = jax.random.split(key)
        attn = dropout_fn(attn, p=attention_dropout, inference=inference, key=subkey)

    x = jnp.transpose(attn @ v, (0, 2, 1, 3)).reshape(
        num_windows, window_size[0] * window_size[1], C
    )
    linear_pt_proj = functools.partial(linear, weight=proj_weight, bias=proj_bias)

    x = eqx.filter_vmap(eqx.filter_vmap(linear_pt_proj))(x)
    if not inference:
        assert key is not None, "key must be given if not inference"
        key, subkey = jax.random.split(key)
        x = dropout_fn(x, p=dropout, inference=inference, key=subkey)

    # reverse windows
    x = x.reshape(
        pad_H // window_size[0],
        pad_W // window_size[1],
        window_size[0],
        window_size[1],
        C,
    )
    x = jnp.transpose(x, (0, 2, 1, 3, 4)).reshape(pad_H, pad_W, C)

    # reverse cyclic shift
    if sum(shift_size) > 0:
        x = jnp.roll(x, shift=(shift_size[0], shift_size[1]), axis=(0, 1))

    # unpad features
    x = x[:H, :W, :]
    return x
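The two attention entry points above operate on unbatched inputs. A minimal sketch (not part of the package) of batching ``multi_head_attention_forward`` over a leading batch axis with ``jax.vmap``; the weight shapes follow the docstring above, and ``noxton`` is assumed to be installed:

import functools

import jax

from noxton.functions import multi_head_attention_forward

d_model, num_heads, batch, tgt_len, src_len = 8, 2, 5, 4, 6
kq, kk, kv, kw, ko = jax.random.split(jax.random.PRNGKey(0), 5)

q = jax.random.normal(kq, (batch, tgt_len, d_model))
k = jax.random.normal(kk, (batch, src_len, d_model))
v = jax.random.normal(kv, (batch, src_len, d_model))
W_in = jax.random.normal(kw, (3 * d_model, d_model))
W_out = jax.random.normal(ko, (d_model, d_model))

# Bind the shared (non-batched) arguments, then map over the leading axis of q, k, v.
attn = functools.partial(
    multi_head_attention_forward,
    embed_dim_to_check=d_model,
    num_heads=num_heads,
    in_proj_weight=W_in,
    out_proj_weight=W_out,
    inference=True,
)
out, weights = jax.vmap(attn)(q, k, v)
print(out.shape)      # (5, 4, 8)
print(weights.shape)  # (5, 4, 6) -- averaged over heads by default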
noxton-0.1.0/noxton/functions/embedding.py
ADDED
@@ -0,0 +1,64 @@
import jax.numpy as jnp
from beartype.typing import Any
from jaxtyping import Array, Float, Int

from noxton.utils import default_floating_dtype


def sinusoidal_embedding(
    t: Int[Array, ""], embedding_size: int, dtype: Any | None = None
) -> Float[Array, " embedding_size"]:
    """Compute a sinusoidal positional embedding for a scalar timestep.

    Encodes a single integer timestep ``t`` into a fixed-size vector using
    alternating sine and cosine functions at geometrically spaced frequencies,
    following the scheme from *Attention Is All You Need* (Vaswani et al., 2017)
    and commonly used in diffusion model timestep conditioning.

    The embedding frequencies are::

        freq_i = exp(-log(10000) * i / (embedding_size / 2))  for i in [0, half_dim)

    The output is the concatenation of ``sin(t * freq)`` and ``cos(t * freq)``.

    Args:
        t: Scalar integer timestep (0-d array).
        embedding_size: Length of the output embedding vector. Must be even.
        dtype: Floating-point dtype for the output. Defaults to the project's
            default floating dtype when ``None``.

    Returns:
        1-D array of shape ``(embedding_size,)`` containing the sinusoidal
        embedding for timestep ``t``.

    Raises:
        ValueError: If ``embedding_size`` is odd.

    Example:
        >>> import jax.numpy as jnp
        >>> t = jnp.array(100)
        >>> emb = sinusoidal_embedding(t, embedding_size=16)
        >>> emb.shape
        (16,)

        >>> # Embeddings for different timesteps are distinct
        >>> emb0 = sinusoidal_embedding(jnp.array(0), embedding_size=8)
        >>> emb1 = sinusoidal_embedding(jnp.array(1), embedding_size=8)
        >>> bool(jnp.any(emb0 != emb1))
        True
    """
    if dtype is None:
        dtype = default_floating_dtype()
    assert dtype is not None
    if embedding_size % 2 != 0:
        raise ValueError(f"Embedding size must be even, but got {embedding_size}")

    half_dim = embedding_size // 2
    embedding_freqs = jnp.exp(
        -jnp.log(10000) * jnp.arange(start=0, stop=half_dim, dtype=dtype) / half_dim
    )

    time_args = t * embedding_freqs
    embedding = jnp.concatenate([jnp.sin(time_args), jnp.cos(time_args)], axis=-1)

    return embedding
noxton-0.1.0/noxton/functions/masking.py
ADDED
@@ -0,0 +1,198 @@
import jax.numpy as jnp
from jaxtyping import Array, Bool


def canonical_mask(
    mask,
    mask_name,
    other_name="",
    other_type=None,
    target_type=jnp.float32,
    other_mask=None,
    check_other=True,
):
    """Convert an arbitrary mask tensor into a canonical additive float mask.

    Accepts boolean, integer, or floating-point masks and normalises them to
    a floating-point additive mask (``0.0`` for positions to attend to,
    ``-inf`` for positions to ignore) that can be directly added to attention
    logits.

    - **Boolean masks**: ``True`` → ``-inf``, ``False`` → ``0.0``
    - **Integer / float masks**: cast to ``target_type`` without modification.
    - ``None`` input returns ``None`` (no mask).

    Args:
        mask: The mask to canonicalise. May be a boolean, integer, or
            floating-point JAX array, or ``None``.
        mask_name: Human-readable name of this mask used in error messages.
        other_name: Human-readable name of the secondary mask (used in error
            messages only). Defaults to ``""``.
        other_type: Expected dtype of ``other_mask`` (currently unused in the
            implementation but kept for API compatibility). Defaults to
            ``None``.
        target_type: Target floating-point dtype for the output. Defaults to
            ``jnp.float32``.
        other_mask: A secondary mask for cross-validation purposes (currently
            unused). Defaults to ``None``.
        check_other: Whether to perform cross-mask validation (currently
            unused). Defaults to ``True``.

    Returns:
        A floating-point array of dtype ``target_type`` with the same shape
        as ``mask``, or ``None`` if ``mask`` is ``None``.

    Raises:
        TypeError: If ``mask`` has an unsupported dtype (not bool, int, or
            float).

    Example:
        >>> import jax.numpy as jnp
        >>> bool_mask = jnp.array([True, False, True])
        >>> canonical_mask(bool_mask, "attn_mask")
        Array([-inf, 0., -inf], dtype=float32)

        >>> float_mask = jnp.array([0.0, -1e9, 0.0])
        >>> canonical_mask(float_mask, "attn_mask")
        Array([ 0.e+00, -1.e+09, 0.e+00], dtype=float32)
    """
    if mask is None:
        return None
    if mask.dtype == bool:
        additive_mask = jnp.where(mask, -jnp.inf, 0.0).astype(target_type)
        return additive_mask
    elif jnp.issubdtype(mask.dtype, jnp.integer) or jnp.issubdtype(
        mask.dtype, jnp.floating
    ):
        return mask.astype(target_type)
    else:
        raise TypeError(
            f"{mask_name} must be bool, int, or float tensor, but got {mask.dtype}"
        )


def canonical_key_padding_mask(
    key_padding_mask, attn_mask=None, query_dtype=jnp.float32
):
    """Convert a key-padding mask to a canonical additive float mask.

    A convenience wrapper around :func:`canonical_mask` that applies the
    correct argument names for key-padding masks. The resulting mask can be
    directly added to attention logits: padded positions get ``-inf`` (boolean
    ``True`` input) while real positions get ``0.0``.

    Args:
        key_padding_mask: Mask of shape ``(src_len,)`` indicating which key
            positions are padding. May be boolean (``True`` = padded),
            integer, or float, or ``None``.
        attn_mask: The attention mask that will be combined with this mask
            (used for error reporting only). Defaults to ``None``.
        query_dtype: Target floating-point dtype for the output. Defaults to
            ``jnp.float32``.

    Returns:
        A floating-point array of dtype ``query_dtype`` or ``None``.

    Example:
        >>> import jax.numpy as jnp
        >>> kpm = jnp.array([False, False, True])  # last token is padding
        >>> canonical_key_padding_mask(kpm)
        Array([ 0., 0., -inf], dtype=float32)
    """
    return canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_name="attn_mask",
        other_mask=attn_mask,
        target_type=query_dtype,
    )


def canonical_attn_mask(attn_mask, query_dtype=jnp.float32):
    """Convert an attention mask to a canonical additive float mask.

    A convenience wrapper around :func:`canonical_mask` that applies the
    correct argument names for attention masks. Boolean masks are converted
    so that ``True`` positions are masked out (``-inf``) and ``False``
    positions are kept (``0.0``). Numeric masks are cast to ``query_dtype``
    without modification, allowing pre-built additive masks (e.g. causal
    masks with ``0`` and ``-inf``) to be passed through unchanged.

    Args:
        attn_mask: Attention mask of shape ``(tgt_len, src_len)``. May be
            boolean, integer, or floating-point, or ``None``.
        query_dtype: Target floating-point dtype for the output. Defaults to
            ``jnp.float32``.

    Returns:
        A floating-point array of dtype ``query_dtype`` or ``None``.

    Example:
        >>> import jax.numpy as jnp
        >>> mask = jnp.array([[False, True], [False, False]])
        >>> canonical_attn_mask(mask)
        Array([[ 0., -inf],
               [ 0., 0.]], dtype=float32)
    """
    return canonical_mask(
        mask=attn_mask,
        mask_name="attn_mask",
        other_type=None,
        other_name="",
        target_type=query_dtype,
        check_other=False,
    )


def make_causal_mask(seq_len: int) -> Bool[Array, "seq_len seq_len"]:
    """Create a boolean lower-triangular causal mask.

    Position ``[i, j]`` is ``True`` when position ``i`` is allowed to attend
    to position ``j`` (i.e. ``j <= i``), and ``False`` otherwise. Use this
    mask to prevent tokens from attending to future positions.

    Args:
        seq_len: Sequence length; the output is a square matrix of shape
            ``(seq_len, seq_len)``.

    Returns:
        Boolean array of shape ``(seq_len, seq_len)`` where the lower
        triangle (including the diagonal) is ``True`` and the upper triangle
        is ``False``.

    Example:
        >>> make_causal_mask(3)
        Array([[ True, False, False],
               [ True,  True, False],
               [ True,  True,  True]], dtype=bool)
    """
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))


def build_attention_mask(context_length: int) -> Array:
    """Create an additive causal attention mask with ``0`` and ``-inf``.

    Produces a float32 lower-triangular matrix where attended positions
    carry ``0.0`` (no change to logits) and future positions carry ``-inf``
    (zeroed out after softmax). The mask can be added directly to attention
    logit matrices.

    Args:
        context_length: Sequence length; the output is a square matrix of
            shape ``(context_length, context_length)``.

    Returns:
        Float32 array of shape ``(context_length, context_length)`` with
        ``0.0`` on and below the diagonal and ``-inf`` above the diagonal.

    Example:
        >>> build_attention_mask(3)
        Array([[ 0., -inf, -inf],
               [ 0., 0., -inf],
               [ 0., 0., 0.]], dtype=float32)
    """
    mask = jnp.tril(jnp.zeros((context_length, context_length)))
    upper = jnp.triu(jnp.full((context_length, context_length), float("-inf")), k=1)

    mask = mask + upper
    return mask
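A short sketch (not part of the package) of how the masking helpers above relate to one another: inverting the boolean mask from ``make_causal_mask`` and canonicalising it reproduces the additive causal mask built by ``build_attention_mask``. Assumes ``noxton`` is importable:

import jax.numpy as jnp

from noxton.functions import (
    build_attention_mask,
    canonical_attn_mask,
    make_causal_mask,
)

n = 4
keep = make_causal_mask(n)             # bool: True where attending is allowed
additive = canonical_attn_mask(~keep)  # bool True (masked out) -> -inf, False -> 0.0
print(jnp.array_equal(additive, build_attention_mask(n)))  # True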
noxton-0.1.0/noxton/functions/normalization.py
ADDED
@@ -0,0 +1,39 @@
import jax.numpy as jnp
from jaxtyping import Array


def normalize(x: Array, p: int = 2, axis: int = 1, eps: float = 1e-12) -> Array:
    """Normalize an array along an axis using an Lp norm.

    Computes ``x / max(||x||_p, eps)`` along ``axis``, where the denominator is
    clamped to at least ``eps`` to prevent division by zero.

    Args:
        x: Input array to normalize.
        p: Order of the norm. ``p=2`` gives the Euclidean (L2) norm,
            ``p=1`` gives the Manhattan (L1) norm, etc. Defaults to ``2``.
        axis: Axis along which to compute the norm and normalize.
            Defaults to ``1``.
        eps: Small constant used as a lower bound on the denominator for
            numerical stability. Defaults to ``1e-12``.

    Returns:
        Array of the same shape and dtype as ``x`` with unit Lp norm along
        ``axis`` (or norm equal to ``eps`` when the input norm is smaller).

    Example:
        >>> import jax.numpy as jnp
        >>> x = jnp.array([[3.0, 4.0], [0.0, 0.0]])
        >>> normalize(x)  # L2 norm along axis=1
        Array([[0.6, 0.8],
               [0. , 0. ]], dtype=float32)

        >>> normalize(x, p=1)  # L1 norm along axis=1
        Array([[0.42857143, 0.5714286 ],
               [0.        , 0.        ]], dtype=float32)
    """
    norm = jnp.linalg.norm(x, ord=p, axis=axis, keepdims=True)
    norm = jnp.maximum(norm, eps)
    output = x / norm

    return output
noxton-0.1.0/noxton/functions/regularization.py
ADDED
@@ -0,0 +1,125 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray


def stochastic_depth(
    input: Array,
    p: float,
    mode: str,
    inference: bool,
    key: PRNGKeyArray,
) -> Array:
    """Apply Stochastic Depth regularization (DropPath).

    During training, randomly drops entire samples (``mode="batch"``) or
    individual rows (``mode="row"``). At inference time or when ``p=0``, the
    input is returned unchanged. Surviving elements are scaled by
    ``1 / (1 - p)`` to preserve the expected value.

    Args:
        input: Input array to apply stochastic depth to.
        p: Drop probability in ``[0, 1]``. The probability that a sample/row
            is zeroed out. ``p=0`` disables the operation.
        mode: Dropping granularity. One of:
            - ``"batch"``: a single binary mask is broadcast over the whole
              batch (all samples share the same fate).
            - ``"row"``: each sample in the batch gets its own independent
              mask.
        inference: If ``True``, skip stochastic depth and return ``input``
            unchanged (equivalent to test/eval mode).
        key: JAX PRNG key used to sample the Bernoulli noise.

    Returns:
        Array of the same shape and dtype as ``input``.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``.
        ValueError: If ``mode`` is not ``"batch"`` or ``"row"``.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> key = jax.random.PRNGKey(0)
        >>> x = jnp.ones((4, 8))
        >>> # row mode: each of the 4 samples may be dropped independently
        >>> out = stochastic_depth(x, p=0.5, mode="row", inference=False, key=key)
        >>> out.shape
        (4, 8)

        >>> # inference mode always returns the input unchanged
        >>> stochastic_depth(x, p=0.9, mode="batch", inference=True, key=key) is x
        True
    """
    if p < 0.0 or p > 1.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
    if mode not in ["batch", "row"]:
        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
    if inference or p == 0.0:
        return input
    survival_rate = 1.0 - p
    if mode == "row":
        size = [input.shape[0]] + [1] * (input.ndim - 1)
    else:
        size = [1] * input.ndim
    noise = jax.random.bernoulli(key, p=survival_rate, shape=size).astype(input.dtype)
    if survival_rate > 0.0:
        noise = noise / survival_rate
    return input * noise


def dropout(
    x: Array,
    p: float,
    inference: bool,
    key: PRNGKeyArray | None = None,
) -> Array:
    """Apply dropout regularization to an array.

    During training, each element is independently zeroed with probability
    ``p``. Surviving elements are scaled by ``1 / (1 - p)`` so the expected
    sum is preserved (inverted dropout). At inference time or when ``p=0``
    the input is returned unchanged.

    Args:
        x: Input array.
        p: Probability of an element being zeroed. Must be in ``[0, 1)``.
            When ``p=0`` the function is a no-op regardless of ``inference``.
        inference: If ``True``, return ``x`` unchanged (eval/test mode).
        key: JAX PRNG key. Required when ``inference=False`` and ``p > 0``.
            Defaults to ``None``.

    Returns:
        Array of the same shape and dtype as ``x``.

    Raises:
        RuntimeError: If ``inference=False``, ``p > 0``, and ``key`` is
            ``None``.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> key = jax.random.PRNGKey(42)
        >>> x = jnp.ones((3, 4))
        >>> out = dropout(x, p=0.5, inference=False, key=key)
        >>> out.shape
        (3, 4)

        >>> # inference mode: output equals input
        >>> import jax.numpy as jnp
        >>> x = jnp.array([1.0, 2.0, 3.0])
        >>> dropout(x, p=0.5, inference=True)
        Array([1., 2., 3.], dtype=float32)
    """
    if isinstance(p, (int, float)) and p == 0:
        inference = True
    if inference:
        return x
    elif key is None:
        raise RuntimeError(
            "Dropout requires a key when running in non-deterministic mode."
        )
    else:
        q = 1 - jax.lax.stop_gradient(p)
        mask = jax.random.bernoulli(key, q, x.shape)
        return jnp.where(mask, x / q, 0)
noxton-0.1.0/noxton/nn/__init__.py
ADDED
File without changes
noxton-0.1.0/noxton/nn/abstract.py
ADDED
@@ -0,0 +1,16 @@
import abc

import equinox as eqx
from jaxtyping import Array


class AbstractNorm(eqx.Module):
    @abc.abstractmethod
    def __call__(self, x: Array, *_, **__) -> Array: ...


class AbstractNormStateful(eqx.nn.StatefulLayer):
    @abc.abstractmethod
    def __call__(
        self, x: Array, state: eqx.nn.State, *_, **__
    ) -> tuple[Array, eqx.nn.State]: ...
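A minimal sketch (not part of the package) of implementing the ``AbstractNorm`` interface with a hypothetical RMS-norm layer; the layer name and its fields are illustrative only:

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

from noxton.nn.abstract import AbstractNorm


class RMSNorm(AbstractNorm):
    # Hypothetical example layer; not shipped with noxton.
    weight: Array
    eps: float = eqx.field(static=True)

    def __init__(self, dim: int, eps: float = 1e-6):
        self.weight = jnp.ones((dim,))
        self.eps = eps

    def __call__(self, x: Array, *_, **__) -> Array:
        # Divide by the root-mean-square of the last axis, then apply a learned gain.
        rms = jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps)
        return self.weight * x / rms


norm = RMSNorm(8)
print(norm(jnp.arange(8.0)).shape)  # (8,)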
noxton-0.1.0/noxton/utils/utils.py
ADDED
@@ -0,0 +1,59 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from beartype.typing import Any
from jaxtyping import PyTree


def default_floating_dtype() -> Any:
    if jax.config.read("jax_enable_x64"):
        return jnp.float64
    else:
        return jnp.float32


def summarize_model(model: PyTree) -> str:
    params, _ = eqx.partition(model, eqx.is_array)

    param_counts = {}
    total_params = 0

    def count_params(pytree, name=""):
        nonlocal total_params
        count = 0
        if isinstance(pytree, jnp.ndarray):
            count = pytree.size
            total_params += count
            if name:
                param_counts[name] = count
        elif hasattr(pytree, "__dict__"):
            for key, value in pytree.__dict__.items():
                subname = f"{name}.{key}" if name else key
                count += count_params(value, subname)
        elif isinstance(pytree, (list, tuple)):
            for i, value in enumerate(pytree):
                subname = f"{name}[{i}]" if name else f"[{i}]"
                count += count_params(value, subname)
        elif isinstance(pytree, dict):
            for key, value in pytree.items():
                subname = f"{name}.{key}" if name else str(key)
                count += count_params(value, subname)
        return count

    count_params(params)

    # Display as table
    lines = []
    lines.append("Model Parameter Summary")
    lines.append("=" * 50)
    lines.append(f"{'Parameter Name':<30} {'Count':<15}")
    lines.append("-" * 50)

    for name, count in param_counts.items():
        lines.append(f"{name:<30} {count:<15,}")

    lines.append("-" * 50)
    lines.append(f"{'Total Parameters':<30} {total_params:<15,}")
    lines.append("=" * 50)

    return "\n".join(lines)
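``summarize_model`` walks an Equinox PyTree and tabulates per-array parameter counts. A minimal sketch (not part of the package) using Equinox's built-in ``eqx.nn.MLP`` as a stand-in model, assuming ``noxton`` is installed:

import equinox as eqx
import jax

from noxton.utils.utils import summarize_model

model = eqx.nn.MLP(
    in_size=4, out_size=2, width_size=8, depth=2, key=jax.random.PRNGKey(0)
)
print(summarize_model(model))  # per-parameter counts plus the total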
noxton-0.1.0/noxton.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,11 @@
Metadata-Version: 2.4
Name: noxton
Version: 0.1.0
Summary: Add your description here
Requires-Python: >=3.13
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: equinox>=0.13.4
Requires-Dist: jax>=0.9.0.1
Requires-Dist: statedict2pytree>=2.0.1
Dynamic: license-file
noxton-0.1.0/noxton.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,19 @@
LICENSE
README.md
pyproject.toml
noxton.egg-info/PKG-INFO
noxton.egg-info/SOURCES.txt
noxton.egg-info/dependency_links.txt
noxton.egg-info/requires.txt
noxton.egg-info/top_level.txt
noxton/functions/__init__.py
noxton/functions/activation.py
noxton/functions/attention.py
noxton/functions/embedding.py
noxton/functions/masking.py
noxton/functions/normalization.py
noxton/functions/regularization.py
noxton/nn/__init__.py
noxton/nn/abstract.py
noxton/utils/__init__.py
noxton/utils/utils.py
noxton-0.1.0/noxton.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
(single blank line)
noxton-0.1.0/noxton.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
noxton
noxton-0.1.0/setup.cfg
ADDED