noxton 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
noxton-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Artur A. Galstyan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
noxton-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.4
2
+ Name: noxton
3
+ Version: 0.1.0
4
+ Summary: Add your description here
5
+ Requires-Python: >=3.13
6
+ Description-Content-Type: text/markdown
7
+ License-File: LICENSE
8
+ Requires-Dist: equinox>=0.13.4
9
+ Requires-Dist: jax>=0.9.0.1
10
+ Requires-Dist: statedict2pytree>=2.0.1
11
+ Dynamic: license-file
noxton-0.1.0/README.md ADDED
File without changes
noxton-0.1.0/noxton/functions/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ from .activation import swiglu
2
+ from .attention import (
3
+ create_attn_mask,
4
+ multi_head_attention_forward,
5
+ shifted_window_attention,
6
+ )
7
+ from .embedding import sinusoidal_embedding
8
+ from .masking import (
9
+ build_attention_mask,
10
+ canonical_attn_mask,
11
+ canonical_key_padding_mask,
12
+ canonical_mask,
13
+ make_causal_mask,
14
+ )
15
+ from .normalization import normalize
16
+ from .regularization import dropout, stochastic_depth
17
+
18
+ __all__ = [
19
+ "swiglu",
20
+ "multi_head_attention_forward",
21
+ "shifted_window_attention",
22
+ "create_attn_mask",
23
+ "sinusoidal_embedding",
24
+ "dropout",
25
+ "stochastic_depth",
26
+ "normalize",
27
+ "build_attention_mask",
28
+ "canonical_attn_mask",
29
+ "canonical_key_padding_mask",
30
+ "canonical_mask",
31
+ "make_causal_mask",
32
+ ]
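
A brief usage note, not part of the package contents: given the layout listed in noxton.egg-info/SOURCES.txt, these re-exports should be importable from the noxton.functions package, e.g.:

from noxton.functions import make_causal_mask, swiglu
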
noxton-0.1.0/noxton/functions/activation.py ADDED
@@ -0,0 +1,34 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from jaxtyping import Array
4
+
5
+
6
+ def swiglu(x: Array, axis: int = -1) -> Array:
7
+ """Apply the SwiGLU activation function.
8
+
9
+ Splits the input array into two halves along the specified axis, applies
10
+ the Swish activation to the first half, and multiplies it element-wise
11
+ with the second half. This gated activation is commonly used in
12
+ transformer feed-forward blocks.
13
+
14
+ Args:
15
+ x: Input array. Its size along ``axis`` must be even.
16
+ axis: Axis along which to split the input into two halves.
17
+ Defaults to ``-1`` (last axis).
18
+
19
+ Returns:
20
+ Array of the same dtype as ``x`` with size halved along ``axis``.
21
+
22
+ Example:
23
+ >>> import jax.numpy as jnp
24
+ >>> x = jnp.array([1.0, 2.0, -1.0, 0.5])
25
+ >>> swiglu(x) # splits into [1., 2.] and [-1., 0.5]
26
+ Array([-0.7310586, 0.8807971], dtype=float32)
27
+
28
+ >>> # 2-D input, split along last axis (default)
29
+ >>> x2d = jnp.ones((3, 4))
30
+ >>> swiglu(x2d).shape
31
+ (3, 2)
32
+ """
33
+ a, b = jnp.split(x, 2, axis=axis)
34
+ return jax.nn.swish(a) * b
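
As a hedged illustration of the docstring's note about feed-forward blocks, here is a minimal sketch (not part of the package) wiring swiglu into a gated MLP; the function name, weight names, and sizes are invented for the example:

import jax
import jax.numpy as jnp

def gated_mlp(x, w_in, w_out):
    # w_in projects to an even hidden size so swiglu can split it in half
    h = swiglu(x @ w_in)
    return h @ w_out

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
d_model, d_hidden = 16, 32
x = jax.random.normal(k1, (d_model,))
w_in = 0.02 * jax.random.normal(k2, (d_model, 2 * d_hidden))
w_out = 0.02 * jax.random.normal(k3, (d_hidden, d_model))
print(gated_mlp(x, w_in, w_out).shape)  # (16,)
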
noxton-0.1.0/noxton/functions/attention.py ADDED
@@ -0,0 +1,594 @@
1
+ import functools
2
+
3
+ import equinox as eqx
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from beartype.typing import Any
7
+ from jaxtyping import Array, Bool, Float, PRNGKeyArray
8
+
9
+ from noxton.functions.normalization import normalize
10
+ from noxton.functions.regularization import dropout as dropout_fn
11
+ from noxton.utils.utils import default_floating_dtype
12
+
13
+
14
+ def multi_head_attention_forward(
15
+ query: Float[Array, "tgt_len d_model"],
16
+ key: Float[Array, "src_len d_model"],
17
+ value: Float[Array, "src_len d_model"],
18
+ embed_dim_to_check: int,
19
+ num_heads: int,
20
+ in_proj_weight: Float[Array, "3*d_model d_model"] | None = None,
21
+ in_proj_bias: Float[Array, "3*d_model"] | None = None,
22
+ bias_k: Float[Array, "1 d_model"] | None = None,
23
+ bias_v: Float[Array, "1 d_model"] | None = None,
24
+ add_zero_attn: bool = False,
25
+ dropout_p: float = 0.0,
26
+ out_proj_weight: Float[Array, "d_model d_model"] | None = None,
27
+ out_proj_bias: Float[Array, "d_model"] | None = None,
28
+ inference: bool = False,
29
+ key_padding_mask: Float[Array, "src_len"] | Bool[Array, "src_len"] | None = None,
30
+ attn_mask: Float[Array, "tgt_len src_len"] | None = None,
31
+ need_weights: bool = True,
32
+ use_separate_proj_weight: bool = False,
33
+ q_proj_weight: Float[Array, "d_model d_model"] | None = None,
34
+ k_proj_weight: Float[Array, "d_model d_model"] | None = None,
35
+ v_proj_weight: Float[Array, "d_model d_model"] | None = None,
36
+ static_k: Float[Array, "src_len d_model"] | None = None,
37
+ static_v: Float[Array, "src_len d_model"] | None = None,
38
+ average_attn_weights: bool = True,
39
+ is_causal: bool = False,
40
+ dropout_key: PRNGKeyArray | None = None,
41
+ ) -> tuple[
42
+ Float[Array, "tgt_len d_model"],
43
+ Float[Array, "num_heads tgt_len src_len"]
44
+ | Float[Array, "tgt_len src_len"]
45
+ | Float[Array, "tgt_len src_len+1"]
46
+ | None,
47
+ ]:
48
+ """Compute scaled dot-product multi-head attention.
49
+
50
+ This is the functional core of multi-head attention as described in
51
+ *Attention Is All You Need* (Vaswani et al., 2017). It accepts
52
+ pre-allocated weight matrices and returns the attended output along with
53
+ optional attention weights.
54
+
55
+ This function is a 1:1 mapping from the PyTorch implementation.
56
+
57
+ The function supports several advanced options:
58
+
59
+ - **Separate projection weights** (``use_separate_proj_weight=True``):
60
+ pass individual ``q_proj_weight``, ``k_proj_weight``, and
61
+ ``v_proj_weight`` instead of a single fused ``in_proj_weight``.
62
+ - **Static key/value**: bypass the key/value projections by providing
63
+ pre-computed ``static_k`` / ``static_v`` tensors.
64
+ - **Extra key/value bias tokens** (``bias_k``, ``bias_v``): append
65
+ learnable bias tokens to the key and value sequences.
66
+ - **Zero-attention slot** (``add_zero_attn``): append a zero vector to
67
+ keys and values, giving the model an option to attend to nothing.
68
+ - **Causal masking** (``is_causal``): apply an upper-triangular ``-inf``
69
+ mask so each query can only attend to earlier or equal positions.
70
+ - **Key-padding mask** / **additive attention mask**: arbitrary masking
71
+ via ``key_padding_mask`` and ``attn_mask``.
72
+
73
+ Args:
74
+ query: Query tensor of shape ``(tgt_len, d_model)``.
75
+ key: Key tensor of shape ``(src_len, d_model)``.
76
+ value: Value tensor of shape ``(src_len, d_model)``.
77
+ embed_dim_to_check: Expected embedding dimension; asserted to equal
78
+ ``d_model``.
79
+ num_heads: Number of attention heads. Must divide ``d_model`` evenly.
80
+ in_proj_weight: Fused input projection weight of shape
81
+ ``(3*d_model, d_model)``. Required when
82
+ ``use_separate_proj_weight=False``.
83
+ in_proj_bias: Optional bias for the fused projection of shape
84
+ ``(3*d_model,)``.
85
+ bias_k: Optional learnable key bias of shape ``(1, d_model)`` appended
86
+ to the key sequence.
87
+ bias_v: Optional learnable value bias of shape ``(1, d_model)``
88
+ appended to the value sequence. Must be provided together with
89
+ ``bias_k``.
90
+ add_zero_attn: If ``True``, append a zero token to keys and values.
91
+ Defaults to ``False``.
92
+ dropout_p: Dropout probability applied to attention weights during
93
+ training. Defaults to ``0.0``.
94
+ out_proj_weight: Output projection weight of shape
95
+ ``(d_model, d_model)``. Required.
96
+ out_proj_bias: Optional output projection bias of shape ``(d_model,)``.
97
+ inference: If ``True``, disable dropout (eval mode).
98
+ Defaults to ``False``.
99
+ key_padding_mask: Optional mask of shape ``(src_len,)`` marking padded
100
+ key positions. ``True`` / non-zero values are masked out (set to
101
+ ``-inf`` in logits).
102
+ attn_mask: Optional additive attention mask of shape
103
+ ``(tgt_len, src_len)``. Values are added to attention logits
104
+ before softmax.
105
+ need_weights: If ``True``, also return the attention weight matrix.
106
+ Defaults to ``True``.
107
+ use_separate_proj_weight: If ``True``, use ``q_proj_weight``,
108
+ ``k_proj_weight``, ``v_proj_weight`` for projections instead of
109
+ ``in_proj_weight``. Defaults to ``False``.
110
+ q_proj_weight: Query projection weight ``(d_model, d_model)``.
111
+ Required when ``use_separate_proj_weight=True``.
112
+ k_proj_weight: Key projection weight ``(d_model, d_model)``.
113
+ Required when ``use_separate_proj_weight=True``.
114
+ v_proj_weight: Value projection weight ``(d_model, d_model)``.
115
+ Required when ``use_separate_proj_weight=True``.
116
+ static_k: Pre-computed key of shape ``(src_len, d_model)``. When
117
+ provided, bypasses the key projection.
118
+ static_v: Pre-computed value of shape ``(src_len, d_model)``. When
119
+ provided, bypasses the value projection.
120
+ average_attn_weights: If ``True``, average attention weights over
121
+ heads before returning. Defaults to ``True``.
122
+ is_causal: If ``True``, apply a causal mask to prevent attending to
123
+ future positions. Defaults to ``False``.
124
+ dropout_key: JAX PRNG key for attention dropout. Required when
125
+ ``dropout_p > 0.0`` and ``inference=False``.
126
+
127
+ Returns:
128
+ A tuple ``(attn_output, attn_weights)`` where:
129
+
130
+ - ``attn_output``: Attended output of shape ``(tgt_len, d_model)``.
131
+ - ``attn_weights``: Attention weights or ``None`` when
132
+ ``need_weights=False``. Shape depends on ``average_attn_weights``:
133
+ ``(tgt_len, src_len)`` when averaged, or
134
+ ``(num_heads, tgt_len, src_len)`` otherwise.
135
+
136
+ Example:
137
+ >>> import jax
138
+ >>> import jax.numpy as jnp
139
+ >>> key = jax.random.PRNGKey(0)
140
+ >>> d_model, num_heads, tgt_len, src_len = 8, 2, 4, 6
141
+ >>> q = jax.random.normal(key, (tgt_len, d_model))
142
+ >>> k = jax.random.normal(key, (src_len, d_model))
143
+ >>> v = jax.random.normal(key, (src_len, d_model))
144
+ >>> W = jax.random.normal(key, (3 * d_model, d_model))
145
+ >>> W_out = jax.random.normal(key, (d_model, d_model))
146
+ >>> out, weights = multi_head_attention_forward(
147
+ ... q, k, v,
148
+ ... embed_dim_to_check=d_model,
149
+ ... num_heads=num_heads,
150
+ ... in_proj_weight=W,
151
+ ... out_proj_weight=W_out,
152
+ ... inference=True,
153
+ ... )
154
+ >>> out.shape
155
+ (4, 8)
156
+ >>> weights.shape # averaged over heads by default
157
+ (4, 6)
158
+ """
159
+ tgt_len, d_model = query.shape
160
+ src_len, k_dim = key.shape
161
+ value_len, v_dim = value.shape
162
+
163
+ assert d_model == k_dim == v_dim == embed_dim_to_check, (
164
+ "Embedding dimensions must match"
165
+ )
166
+
167
+ assert src_len == value_len, "Key and value must have the same sequence length"
168
+
169
+ head_dim = d_model // num_heads
170
+ assert head_dim * num_heads == d_model, "embed_dim must be divisible by num_heads"
171
+
172
+ if dropout_p > 0.0:
173
+ assert dropout_key is not None, (
174
+ "dropout_key must be provided if dropout_p > 0.0"
175
+ )
176
+
177
+ if use_separate_proj_weight:
178
+ # When using separate projection weights for q, k, v
179
+ assert q_proj_weight is not None, (
180
+ "q_proj_weight should not be None when use_separate_proj_weight=True"
181
+ )
182
+ assert k_proj_weight is not None, (
183
+ "k_proj_weight should not be None when use_separate_proj_weight=True"
184
+ )
185
+ assert v_proj_weight is not None, (
186
+ "v_proj_weight should not be None when use_separate_proj_weight=True"
187
+ )
188
+
189
+ q = query @ q_proj_weight.T
190
+
191
+ if static_k is None:
192
+ k = key @ k_proj_weight.T
193
+ else:
194
+ k = static_k
195
+ src_len, _ = k.shape
196
+
197
+ if static_v is None:
198
+ v = value @ v_proj_weight.T
199
+ else:
200
+ v = static_v
201
+ value_len, _ = v.shape
202
+
203
+ if in_proj_bias is not None:
204
+ q_bias, k_bias, v_bias = jnp.split(in_proj_bias, 3)
205
+ q = q + q_bias
206
+ k = k + k_bias
207
+ v = v + v_bias
208
+
209
+ else:
210
+ assert in_proj_weight is not None, (
211
+ "in_proj_weight should not be None when use_separate_proj_weight=False"
212
+ )
213
+
214
+ q_proj_weight_part, k_proj_weight_part, v_proj_weight_part = jnp.split(
215
+ in_proj_weight, 3
216
+ )
217
+
218
+ q = query @ q_proj_weight_part.T
219
+
220
+ if static_k is None:
221
+ k = key @ k_proj_weight_part.T
222
+ else:
223
+ k = static_k
224
+ src_len, _ = static_k.shape
225
+
226
+ if static_v is None:
227
+ v = value @ v_proj_weight_part.T
228
+ else:
229
+ v = static_v
230
+ value_len, _ = static_v.shape
231
+
232
+ if in_proj_bias is not None:
233
+ q_bias, k_bias, v_bias = jnp.split(in_proj_bias, 3)
234
+ q = q + q_bias
235
+ k = k + k_bias
236
+ v = v + v_bias
237
+
238
+ assert src_len == value_len
239
+
240
+ q = q.reshape(tgt_len, num_heads, head_dim)
241
+ k = k.reshape(src_len, num_heads, head_dim)
242
+ v = v.reshape(src_len, num_heads, head_dim)
243
+
244
+ if add_zero_attn:
245
+ zero_attn_shape = (1, num_heads, head_dim)
246
+ k_zeros = jnp.zeros(zero_attn_shape)
247
+ v_zeros = jnp.zeros(zero_attn_shape)
248
+
249
+ k = jnp.concatenate([k, k_zeros], axis=0)
250
+ v = jnp.concatenate([v, v_zeros], axis=0)
251
+
252
+ src_len += 1
253
+ value_len += 1
254
+
255
+ if bias_k is not None and bias_v is not None:
256
+ bias_k = bias_k.reshape(1, num_heads, head_dim)
257
+ bias_v = bias_v.reshape(1, num_heads, head_dim)
258
+
259
+ k = jnp.concatenate([k, bias_k], axis=0)
260
+ v = jnp.concatenate([v, bias_v], axis=0)
261
+
262
+ src_len += 1
263
+ value_len += 1
264
+
265
+ assert src_len == value_len
266
+
267
+ # [tgt_len, num_heads, head_dim] → [num_heads, tgt_len, head_dim]
268
+ q = jnp.transpose(q, (1, 0, 2))
269
+
270
+ # [src_len, num_heads, head_dim] → [num_heads, src_len, head_dim]
271
+ k = jnp.transpose(k, (1, 0, 2))
272
+ v = jnp.transpose(v, (1, 0, 2))
273
+
274
+ scale = jnp.sqrt(head_dim)
275
+ attn_output_weights = jnp.matmul(q, jnp.transpose(k, (0, 2, 1))) / scale
276
+
277
+ if key_padding_mask is not None:
278
+ padding_mask = key_padding_mask.reshape(1, 1, src_len)
279
+ padding_mask = jnp.repeat(padding_mask, num_heads, axis=0)
280
+ padding_mask = jnp.repeat(padding_mask, tgt_len, axis=1)
281
+ attn_output_weights = jnp.where(
282
+ padding_mask, float("-inf"), attn_output_weights
283
+ )
284
+
285
+ if attn_mask is not None:
286
+ # [tgt_len, src_len] -> [num_heads, tgt_len, src_len]
287
+ mask = attn_mask.reshape(1, tgt_len, src_len)
288
+ mask = jnp.repeat(mask, num_heads, axis=0)
289
+ attn_output_weights = attn_output_weights + mask
290
+
291
+ if is_causal:
292
+ causal_mask = jnp.triu(jnp.ones((tgt_len, src_len)), k=1)
293
+ causal_mask = (causal_mask == 1).reshape(1, tgt_len, src_len)
294
+ causal_mask = jnp.repeat(causal_mask, num_heads, axis=0)
295
+ attn_output_weights = jnp.where(causal_mask, float("-inf"), attn_output_weights)
296
+
297
+ # [num_heads, tgt_len, src_len]
298
+ attn_output_weights = jax.nn.softmax(attn_output_weights, axis=-1)
299
+
300
+ if dropout_p > 0.0 and not inference:
301
+ assert dropout_key is not None, (
302
+ "dropout_key required because dropout_p > 0.0 and training"
303
+ )
304
+ dropout_mask = jax.random.bernoulli(
305
+ dropout_key, 1 - dropout_p, attn_output_weights.shape
306
+ )
307
+ scale = 1.0 / (1.0 - dropout_p)
308
+ attn_output_weights = attn_output_weights * dropout_mask * scale
309
+
310
+ attn_output = jnp.matmul(attn_output_weights, v)
311
+ attn_output = jnp.transpose(attn_output, (1, 0, 2))
312
+ attn_output = attn_output.reshape(tgt_len, d_model)
313
+
314
+ assert out_proj_weight is not None, "out_proj_weight must be provided"
315
+ attn_output = attn_output @ out_proj_weight.T
316
+
317
+ if out_proj_bias is not None:
318
+ attn_output = attn_output + out_proj_bias
319
+
320
+ if need_weights:
321
+ if average_attn_weights:
322
+ attn_output_weights = attn_output_weights.mean(axis=0)
323
+ return attn_output, attn_output_weights
324
+ else:
325
+ return attn_output, None
326
+
327
+
328
+ def create_attn_mask(
329
+ pad_H: int,
330
+ pad_W: int,
331
+ window_size: list[int],
332
+ shift_size: list[int],
333
+ dtype: Any | None = None,
334
+ ) -> Array:
335
+ """Build the region-index mask used by shifted-window attention.
336
+
337
+ Assigns each spatial position in a ``(pad_H, pad_W)`` feature map an
338
+ integer *region index* that encodes which window it belongs to after a
339
+ cyclic shift. Positions in the same window share the same region index;
340
+ positions in different windows have different indices. The mask is used
341
+ downstream to zero out cross-window attention scores.
342
+
343
+ The region indices are computed by partitioning the height axis at
344
+ ``pad_H - window_size[0]`` and ``pad_H - shift_size[0]``, and the width
345
+ axis at ``pad_W - window_size[1]`` and ``pad_W - shift_size[1]``, then
346
+ assigning a unique integer to each ``(row_region, col_region)`` cell.
347
+
348
+ Args:
349
+ pad_H: Padded feature map height (must be a multiple of
350
+ ``window_size[0]``).
351
+ pad_W: Padded feature map width (must be a multiple of
352
+ ``window_size[1]``).
353
+ window_size: Local attention window size as ``[window_H, window_W]``.
354
+ shift_size: Cyclic shift amounts as ``[shift_H, shift_W]``.
355
+ dtype: Output array dtype. Defaults to the project's default floating
356
+ dtype when ``None``.
357
+
358
+ Returns:
359
+ Integer-region-index map of shape ``(pad_H, pad_W)`` cast to
360
+ ``dtype``.
361
+
362
+ Example:
363
+ >>> mask = create_attn_mask(8, 8, window_size=[4, 4], shift_size=[2, 2])
364
+ >>> mask.shape
365
+ (8, 8)
366
+ """
367
+ if dtype is None:
368
+ dtype = default_floating_dtype()
369
+ assert dtype is not None
370
+ h_boundaries = jnp.array([pad_H - window_size[0], pad_H - shift_size[0]])
371
+ w_boundaries = jnp.array([pad_W - window_size[1], pad_W - shift_size[1]])
372
+
373
+ h_boundaries = jnp.sort(h_boundaries)
374
+ w_boundaries = jnp.sort(w_boundaries)
375
+
376
+ ii, jj = jnp.indices((pad_H, pad_W)) # ii for rows, jj for columns
377
+
378
+ row_region_idx = jnp.searchsorted(h_boundaries, ii, side="right")
379
+ col_region_idx = jnp.searchsorted(w_boundaries, jj, side="right")
380
+
381
+ num_col_regions = len(w_boundaries) + 1
382
+ attn_mask = row_region_idx * num_col_regions + col_region_idx
383
+
384
+ return attn_mask.astype(dtype)
385
+
386
+
387
+ def shifted_window_attention(
388
+ x: Float[Array, "H W C"],
389
+ qkv_weight: Float[Array, "3*C C"],
390
+ proj_weight: Float[Array, "C C"],
391
+ relative_position_bias: Array,
392
+ window_size: list[int],
393
+ num_heads: int,
394
+ shift_size: list[int],
395
+ attention_dropout: float = 0.0,
396
+ dropout: float = 0.0,
397
+ qkv_bias: Array | None = None,
398
+ proj_bias: Array | None = None,
399
+ logit_scale: Array | None = None,
400
+ inference: bool = False,
401
+ key: PRNGKeyArray | None = None,
402
+ ) -> Float[Array, "H W C"]:
403
+ """Apply Shifted-Window Multi-Head Self-Attention (Swin Attention).
404
+
405
+ Implements the window-based self-attention mechanism from the Swin
406
+ Transformer (Liu et al., 2021). The input feature map is partitioned
407
+ into non-overlapping local windows; attention is computed independently
408
+ within each window. A cyclic spatial shift (``shift_size``) is applied
409
+ before partitioning to create cross-window connections, and a region-index
410
+ mask is used to prevent attending across window boundaries introduced by
411
+ the padding/shift.
412
+
413
+ Supports two attention score variants:
414
+
415
+ - **Scaled dot-product** (``logit_scale=None``): standard
416
+ ``q @ k.T / sqrt(head_dim)``.
417
+ - **Cosine attention** (when ``logit_scale`` is provided): L2-normalised
418
+ ``q`` and ``k`` are used, and scores are multiplied by a learnable
419
+ (but bounded) temperature ``exp(min(logit_scale, log(100)))``.
420
+
421
+ Relative position biases are added to the attention logits before softmax.
422
+
423
+ Args:
424
+ x: Input feature map of shape ``(H, W, C)``.
425
+ qkv_weight: Fused QKV projection weight of shape
426
+ ``(3 * C, C)``, stored as ``(out_features, in_features)`` so that ``x @ weight.T`` applies the projection.
427
+ proj_weight: Output projection weight of shape ``(C, C)``.
428
+ relative_position_bias: Relative position bias tensor added to
429
+ attention logits; shape must broadcast with
430
+ ``(num_windows, num_heads, window_H*window_W, window_H*window_W)``.
431
+ window_size: Local window size as ``[window_H, window_W]``.
432
+ num_heads: Number of attention heads. Must divide ``C`` evenly.
433
+ shift_size: Cyclic shift amounts as ``[shift_H, shift_W]``.
434
+ Use ``[0, 0]`` to disable shifting (regular window attention).
435
+ attention_dropout: Dropout probability applied to attention weights.
436
+ Defaults to ``0.0``.
437
+ dropout: Dropout probability applied to the output projection.
438
+ Defaults to ``0.0``.
439
+ qkv_bias: Optional bias for the QKV projection of shape
440
+ ``(3 * C,)``. When ``logit_scale`` is also provided, the key
441
+ bias component is zeroed out (as in Swin V2).
442
+ proj_bias: Optional bias for the output projection of shape ``(C,)``.
443
+ logit_scale: Learnable log-scale scalar for cosine attention. When
444
+ ``None``, standard scaled dot-product attention is used.
445
+ inference: If ``True``, disable dropout. Defaults to ``False``.
446
+ key: JAX PRNG key required when ``inference=False``.
447
+
448
+ Returns:
449
+ Output feature map of shape ``(H, W, C)``.
450
+
451
+ Raises:
452
+ ValueError: If ``inference=False`` and ``key`` is ``None``.
453
+
454
+ Example:
455
+ >>> import jax
456
+ >>> import jax.numpy as jnp
457
+ >>> H, W, C, num_heads = 8, 8, 16, 2
458
+ >>> window_size, shift_size = [4, 4], [2, 2]
459
+ >>> key = jax.random.PRNGKey(0)
460
+ >>> x = jax.random.normal(key, (H, W, C))
461
+ >>> qkv_w = jax.random.normal(key, (3 * C, C))
462
+ >>> proj_w = jax.random.normal(key, (C, C))
463
+ >>> win_tokens = window_size[0] * window_size[1]
464
+ >>> num_wins = (H // window_size[0]) * (W // window_size[1])
465
+ >>> rpb = jnp.zeros((num_wins, num_heads, win_tokens, win_tokens))
466
+ >>> out = shifted_window_attention(
467
+ ... x, qkv_w, proj_w, rpb,
468
+ ... window_size=window_size, num_heads=num_heads,
469
+ ... shift_size=shift_size, inference=True,
470
+ ... )
471
+ >>> out.shape
472
+ (8, 8, 16)
473
+ """
474
+ if not inference and key is None:
475
+ raise ValueError("Need key when in training mode")
476
+ H, W, C = x.shape
477
+ to_pad_W = (window_size[1] - W % window_size[1]) % window_size[1]
478
+ to_pad_H = (window_size[0] - H % window_size[0]) % window_size[0]
479
+ x = jnp.pad(x, ((0, to_pad_H), (0, to_pad_W), (0, 0)))
480
+ pad_H, pad_W, _ = x.shape
481
+
482
+ shift_size = shift_size.copy()
483
+ if window_size[0] >= pad_H:
484
+ shift_size[0] = 0
485
+ if window_size[1] >= pad_W:
486
+ shift_size[1] = 0
487
+
488
+ # cyclic shift
489
+ if sum(shift_size) > 0:
490
+ x = jnp.roll(x, shift=(-shift_size[0], -shift_size[1]), axis=(0, 1))
491
+
492
+ # partition windows
493
+ num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
494
+ x = jnp.reshape(
495
+ x,
496
+ (
497
+ pad_H // window_size[0],
498
+ window_size[0],
499
+ pad_W // window_size[1],
500
+ window_size[1],
501
+ C,
502
+ ),
503
+ )
504
+ x = jnp.transpose(x, (0, 2, 1, 3, 4)).reshape(
505
+ num_windows, window_size[0] * window_size[1], C
506
+ )
507
+
508
+ # multi-head attention
509
+ if logit_scale is not None and qkv_bias is not None:
510
+ length = qkv_bias.size // 3
511
+ qkv_bias = qkv_bias.at[length : 2 * length].set(0.0)
512
+
513
+ def linear(x: Array, weight: Array, bias: Array | None):
514
+ output = x @ jnp.transpose(weight)  # (..., in) @ (out, in).T -> (..., out)
515
+ if bias is not None:
516
+ output = output + bias
517
+ return output
518
+
519
+ linear_pt = functools.partial(linear, weight=qkv_weight, bias=qkv_bias)
520
+
521
+ qkv = eqx.filter_vmap(eqx.filter_vmap(linear_pt))(x)
522
+ win_size, patches, _ = qkv.shape
523
+ qkv = jnp.transpose(
524
+ qkv.reshape(win_size, patches, 3, num_heads, C // num_heads), (2, 0, 3, 1, 4)
525
+ )
526
+ q, k, v = qkv[0], qkv[1], qkv[2]
527
+ if logit_scale is not None:
528
+ # cosine attention
529
+ attn = normalize(q, axis=-1) @ jnp.transpose(
530
+ normalize(k, axis=-1), (0, 1, 3, 2)
531
+ )
532
+ # Clamp the logit scale exponent for stability
533
+ logit_scale = jnp.exp(jnp.minimum(logit_scale, jnp.log(jnp.array(100.0))))
534
+ attn = attn * logit_scale
535
+ else:
536
+ q = q * (C // num_heads) ** -0.5
537
+ attn = q @ jnp.transpose(k, (0, 1, 3, 2))  # scaled dot-product: q @ k.T
539
+
540
+ # add relative position bias
541
+ attn = attn + relative_position_bias
542
+
543
+ if sum(shift_size) > 0:
544
+ attn_mask = create_attn_mask(pad_H, pad_W, window_size, shift_size)
545
+ attn_mask = attn_mask.reshape(
546
+ pad_H // window_size[0],
547
+ window_size[0],
548
+ pad_W // window_size[1],
549
+ window_size[1],
550
+ )
551
+ attn_mask = jnp.transpose(attn_mask, (0, 2, 1, 3)).reshape(
552
+ num_windows, window_size[0] * window_size[1]
553
+ )
554
+ attn_mask = jnp.expand_dims(attn_mask, axis=1) - jnp.expand_dims(
555
+ attn_mask, axis=2
556
+ )
557
+ attn_mask = jnp.where(attn_mask == 0, 0.0, -100.0)
558
+
559
+ attn = attn + attn_mask[:, None, :, :]
560
+
561
+ attn = jax.nn.softmax(attn, axis=-1)
562
+ if not inference:
563
+ assert key is not None, "key must be given if not inference"
564
+ key, subkey = jax.random.split(key)
565
+ attn = dropout_fn(attn, p=attention_dropout, inference=inference, key=subkey)
566
+
567
+ x = jnp.transpose(attn @ v, (0, 2, 1, 3)).reshape(
568
+ num_windows, window_size[0] * window_size[1], C
569
+ )
570
+ linear_pt_proj = functools.partial(linear, weight=proj_weight, bias=proj_bias)
571
+
572
+ x = eqx.filter_vmap(eqx.filter_vmap(linear_pt_proj))(x)
573
+ if not inference:
574
+ assert key is not None, "key must be given if not inference"
575
+ key, subkey = jax.random.split(key)
576
+ x = dropout_fn(x, p=dropout, inference=inference, key=subkey)
577
+
578
+ # reverse windows
579
+ x = x.reshape(
580
+ pad_H // window_size[0],
581
+ pad_W // window_size[1],
582
+ window_size[0],
583
+ window_size[1],
584
+ C,
585
+ )
586
+ x = jnp.transpose(x, (0, 2, 1, 3, 4)).reshape(pad_H, pad_W, C)
587
+
588
+ # reverse cyclic shift
589
+ if sum(shift_size) > 0:
590
+ x = jnp.roll(x, shift=(shift_size[0], shift_size[1]), axis=(0, 1))
591
+
592
+ # unpad features
593
+ x = x[:H, :W, :]
594
+ return x
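
A hedged, self-contained sketch (not part of the package) exercising multi_head_attention_forward with is_causal=True; the random weights are only there to check shapes and the lower-triangular structure of the averaged attention weights:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
wk, xk = jax.random.split(key)
d_model, num_heads, seq_len = 16, 4, 5
x = jax.random.normal(xk, (seq_len, d_model))
in_w = 0.1 * jax.random.normal(wk, (3 * d_model, d_model))
out_w = jnp.eye(d_model)

out, weights = multi_head_attention_forward(
    x, x, x,
    embed_dim_to_check=d_model,
    num_heads=num_heads,
    in_proj_weight=in_w,
    out_proj_weight=out_w,
    inference=True,
    is_causal=True,
)
print(out.shape, weights.shape)  # (5, 16) (5, 5)
# Each query attends only to itself and earlier keys, so the upper triangle is zero.
print(bool(jnp.allclose(jnp.triu(weights, k=1), 0.0)))  # True
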
noxton-0.1.0/noxton/functions/embedding.py ADDED
@@ -0,0 +1,64 @@
1
+ import jax.numpy as jnp
2
+ from beartype.typing import Any
3
+ from jaxtyping import Array, Float, Int
4
+
5
+ from noxton.utils import default_floating_dtype
6
+
7
+
8
+ def sinusoidal_embedding(
9
+ t: Int[Array, ""], embedding_size: int, dtype: Any | None = None
10
+ ) -> Float[Array, " embedding_size"]:
11
+ """Compute a sinusoidal positional embedding for a scalar timestep.
12
+
13
+ Encodes a single integer timestep ``t`` into a fixed-size vector using
14
+ alternating sine and cosine functions at geometrically spaced frequencies,
15
+ following the scheme from *Attention Is All You Need* (Vaswani et al., 2017)
16
+ and commonly used in diffusion model timestep conditioning.
17
+
18
+ The embedding frequencies are::
19
+
20
+ freq_i = exp(-log(10000) * i / (embedding_size / 2)) for i in [0, half_dim)
21
+
22
+ The output is the concatenation of ``sin(t * freq)`` and ``cos(t * freq)``.
23
+
24
+ Args:
25
+ t: Scalar integer timestep (0-d array).
26
+ embedding_size: Length of the output embedding vector. Must be even.
27
+ dtype: Floating-point dtype for the output. Defaults to the project's
28
+ default floating dtype when ``None``.
29
+
30
+ Returns:
31
+ 1-D array of shape ``(embedding_size,)`` containing the sinusoidal
32
+ embedding for timestep ``t``.
33
+
34
+ Raises:
35
+ ValueError: If ``embedding_size`` is odd.
36
+
37
+ Example:
38
+ >>> import jax.numpy as jnp
39
+ >>> t = jnp.array(100)
40
+ >>> emb = sinusoidal_embedding(t, embedding_size=16)
41
+ >>> emb.shape
42
+ (16,)
43
+
44
+ >>> # Embeddings for different timesteps are distinct
45
+ >>> emb0 = sinusoidal_embedding(jnp.array(0), embedding_size=8)
46
+ >>> emb1 = sinusoidal_embedding(jnp.array(1), embedding_size=8)
47
+ >>> bool(jnp.any(emb0 != emb1))
48
+ True
49
+ """
50
+ if dtype is None:
51
+ dtype = default_floating_dtype()
52
+ assert dtype is not None
53
+ if embedding_size % 2 != 0:
54
+ raise ValueError(f"Embedding size must be even, but got {embedding_size}")
55
+
56
+ half_dim = embedding_size // 2
57
+ embedding_freqs = jnp.exp(
58
+ -jnp.log(10000) * jnp.arange(start=0, stop=half_dim, dtype=dtype) / half_dim
59
+ )
60
+
61
+ time_args = t * embedding_freqs
62
+ embedding = jnp.concatenate([jnp.sin(time_args), jnp.cos(time_args)], axis=-1)
63
+
64
+ return embedding
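
A hedged usage sketch (not part of the package): since sinusoidal_embedding takes a scalar timestep, a whole batch of timesteps can be embedded by vmapping it, as is common for diffusion-model conditioning:

import jax
import jax.numpy as jnp

timesteps = jnp.arange(4)  # four integer timesteps
embed_batch = jax.vmap(lambda t: sinusoidal_embedding(t, embedding_size=16))
print(embed_batch(timesteps).shape)  # (4, 16)
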
noxton-0.1.0/noxton/functions/masking.py ADDED
@@ -0,0 +1,198 @@
1
+ import jax.numpy as jnp
2
+ from jaxtyping import Array, Bool
3
+
4
+
5
+ def canonical_mask(
6
+ mask,
7
+ mask_name,
8
+ other_name="",
9
+ other_type=None,
10
+ target_type=jnp.float32,
11
+ other_mask=None,
12
+ check_other=True,
13
+ ):
14
+ """Convert an arbitrary mask tensor into a canonical additive float mask.
15
+
16
+ Accepts boolean, integer, or floating-point masks and normalises them to
17
+ a floating-point additive mask (``0.0`` for positions to attend to,
18
+ ``-inf`` for positions to ignore) that can be directly added to attention
19
+ logits.
20
+
21
+ - **Boolean masks**: ``True`` → ``-inf``, ``False`` → ``0.0``
22
+ - **Integer / float masks**: cast to ``target_type`` without modification.
23
+ - ``None`` input returns ``None`` (no mask).
24
+
25
+ Args:
26
+ mask: The mask to canonicalise. May be a boolean, integer, or
27
+ floating-point JAX array, or ``None``.
28
+ mask_name: Human-readable name of this mask used in error messages.
29
+ other_name: Human-readable name of the secondary mask (used in error
30
+ messages only). Defaults to ``""``.
31
+ other_type: Expected dtype of ``other_mask`` (currently unused in the
32
+ implementation but kept for API compatibility). Defaults to
33
+ ``None``.
34
+ target_type: Target floating-point dtype for the output. Defaults to
35
+ ``jnp.float32``.
36
+ other_mask: A secondary mask for cross-validation purposes (currently
37
+ unused). Defaults to ``None``.
38
+ check_other: Whether to perform cross-mask validation (currently
39
+ unused). Defaults to ``True``.
40
+
41
+ Returns:
42
+ A floating-point array of dtype ``target_type`` with the same shape
43
+ as ``mask``, or ``None`` if ``mask`` is ``None``.
44
+
45
+ Raises:
46
+ TypeError: If ``mask`` has an unsupported dtype (not bool, int, or
47
+ float).
48
+
49
+ Example:
50
+ >>> import jax.numpy as jnp
51
+ >>> bool_mask = jnp.array([True, False, True])
52
+ >>> canonical_mask(bool_mask, "attn_mask")
53
+ Array([-inf, 0., -inf], dtype=float32)
54
+
55
+ >>> float_mask = jnp.array([0.0, -1e9, 0.0])
56
+ >>> canonical_mask(float_mask, "attn_mask")
57
+ Array([ 0.e+00, -1.e+09, 0.e+00], dtype=float32)
58
+ """
59
+ if mask is None:
60
+ return None
61
+ if mask.dtype == bool:
62
+ additive_mask = jnp.where(mask, -jnp.inf, 0.0).astype(target_type)
63
+ return additive_mask
64
+ elif jnp.issubdtype(mask.dtype, jnp.integer) or jnp.issubdtype(
65
+ mask.dtype, jnp.floating
66
+ ):
67
+ return mask.astype(target_type)
68
+ else:
69
+ raise TypeError(
70
+ f"{mask_name} must be bool, int, or float tensor, but got {mask.dtype}"
71
+ )
72
+
73
+
74
+ def canonical_key_padding_mask(
75
+ key_padding_mask, attn_mask=None, query_dtype=jnp.float32
76
+ ):
77
+ """Convert a key-padding mask to a canonical additive float mask.
78
+
79
+ A convenience wrapper around :func:`canonical_mask` that applies the
80
+ correct argument names for key-padding masks. The resulting mask can be
81
+ directly added to attention logits: padded positions get ``-inf`` (boolean
82
+ ``True`` input) while real positions get ``0.0``.
83
+
84
+ Args:
85
+ key_padding_mask: Mask of shape ``(src_len,)`` indicating which key
86
+ positions are padding. May be boolean (``True`` = padded),
87
+ integer, or float, or ``None``.
88
+ attn_mask: The attention mask that will be combined with this mask
89
+ (used for error reporting only). Defaults to ``None``.
90
+ query_dtype: Target floating-point dtype for the output. Defaults to
91
+ ``jnp.float32``.
92
+
93
+ Returns:
94
+ A floating-point array of dtype ``query_dtype`` or ``None``.
95
+
96
+ Example:
97
+ >>> import jax.numpy as jnp
98
+ >>> kpm = jnp.array([False, False, True]) # last token is padding
99
+ >>> canonical_key_padding_mask(kpm)
100
+ Array([ 0., 0., -inf], dtype=float32)
101
+ """
102
+ return canonical_mask(
103
+ mask=key_padding_mask,
104
+ mask_name="key_padding_mask",
105
+ other_name="attn_mask",
106
+ other_mask=attn_mask,
107
+ target_type=query_dtype,
108
+ )
109
+
110
+
111
+ def canonical_attn_mask(attn_mask, query_dtype=jnp.float32):
112
+ """Convert an attention mask to a canonical additive float mask.
113
+
114
+ A convenience wrapper around :func:`canonical_mask` that applies the
115
+ correct argument names for attention masks. Boolean masks are converted
116
+ so that ``True`` positions are masked out (``-inf``) and ``False``
117
+ positions are kept (``0.0``). Numeric masks are cast to ``query_dtype``
118
+ without modification, allowing pre-built additive masks (e.g. causal
119
+ masks with ``0`` and ``-inf``) to be passed through unchanged.
120
+
121
+ Args:
122
+ attn_mask: Attention mask of shape ``(tgt_len, src_len)``. May be
123
+ boolean, integer, or floating-point, or ``None``.
124
+ query_dtype: Target floating-point dtype for the output. Defaults to
125
+ ``jnp.float32``.
126
+
127
+ Returns:
128
+ A floating-point array of dtype ``query_dtype`` or ``None``.
129
+
130
+ Example:
131
+ >>> import jax.numpy as jnp
132
+ >>> mask = jnp.array([[False, True], [False, False]])
133
+ >>> canonical_attn_mask(mask)
134
+ Array([[ 0., -inf],
135
+ [ 0., 0.]], dtype=float32)
136
+ """
137
+ return canonical_mask(
138
+ mask=attn_mask,
139
+ mask_name="attn_mask",
140
+ other_type=None,
141
+ other_name="",
142
+ target_type=query_dtype,
143
+ check_other=False,
144
+ )
145
+
146
+
147
+ def make_causal_mask(seq_len: int) -> Bool[Array, "seq_len seq_len"]:
148
+ """Create a boolean lower-triangular causal mask.
149
+
150
+ Position ``[i, j]`` is ``True`` when query position ``i`` is allowed to attend
151
+ to key position ``j`` (i.e. ``j <= i``), and ``False`` otherwise. Use this
152
+ mask to prevent tokens from attending to future positions.
153
+
154
+ Args:
155
+ seq_len: Sequence length; the output is a square matrix of shape
156
+ ``(seq_len, seq_len)``.
157
+
158
+ Returns:
159
+ Boolean array of shape ``(seq_len, seq_len)`` where the lower
160
+ triangle (including the diagonal) is ``True`` and the upper triangle
161
+ is ``False``.
162
+
163
+ Example:
164
+ >>> make_causal_mask(3)
165
+ Array([[ True, False, False],
166
+ [ True, True, False],
167
+ [ True, True, True]], dtype=bool)
168
+ """
169
+ return jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
170
+
171
+
172
+ def build_attention_mask(context_length: int) -> Array:
173
+ """Create an additive causal attention mask with ``0`` and ``-inf``.
174
+
175
+ Produces a float32 lower-triangular matrix where attended positions
176
+ carry ``0.0`` (no change to logits) and future positions carry ``-inf``
177
+ (zeroed out after softmax). The mask can be added directly to attention
178
+ logit matrices.
179
+
180
+ Args:
181
+ context_length: Sequence length; the output is a square matrix of
182
+ shape ``(context_length, context_length)``.
183
+
184
+ Returns:
185
+ Float32 array of shape ``(context_length, context_length)`` with
186
+ ``0.0`` on and below the diagonal and ``-inf`` above the diagonal.
187
+
188
+ Example:
189
+ >>> build_attention_mask(3)
190
+ Array([[ 0., -inf, -inf],
191
+ [ 0., 0., -inf],
192
+ [ 0., 0., 0.]], dtype=float32)
193
+ """
194
+ mask = jnp.tril(jnp.zeros((context_length, context_length)))
195
+ upper = jnp.triu(jnp.full((context_length, context_length), float("-inf")), k=1)
196
+
197
+ mask = mask + upper
198
+ return mask
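
A hedged consistency check (not part of the package) tying the helpers together: make_causal_mask marks allowed positions with True, so it has to be inverted before canonicalisation, after which it matches the additive mask from build_attention_mask:

import jax.numpy as jnp

n = 4
allowed = make_causal_mask(n)             # True where attention is allowed
additive = canonical_attn_mask(~allowed)  # True (= blocked) becomes -inf
print(bool(jnp.array_equal(additive, build_attention_mask(n))))  # True
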
noxton-0.1.0/noxton/functions/normalization.py ADDED
@@ -0,0 +1,39 @@
1
+ import jax.numpy as jnp
2
+ from jaxtyping import Array
3
+
4
+
5
+ def normalize(x: Array, p: int = 2, axis: int = 1, eps: float = 1e-12) -> Array:
6
+ """Normalize an array along an axis using an Lp norm.
7
+
8
+ Computes ``x / max(||x||_p, eps)`` along ``axis``, where the denominator is
9
+ clamped to at least ``eps`` to prevent division by zero.
10
+
11
+ Args:
12
+ x: Input array to normalize.
13
+ p: Order of the norm. ``p=2`` gives the Euclidean (L2) norm,
14
+ ``p=1`` gives the Manhattan (L1) norm, etc. Defaults to ``2``.
15
+ axis: Axis along which to compute the norm and normalize.
16
+ Defaults to ``1``.
17
+ eps: Small constant added to the denominator for numerical stability.
18
+ Defaults to ``1e-12``.
19
+
20
+ Returns:
21
+ Array of the same shape and dtype as ``x`` with unit Lp norm along
22
+ ``axis``; vectors whose norm is below ``eps`` are divided by ``eps`` instead (so a zero vector stays zero).
23
+
24
+ Example:
25
+ >>> import jax.numpy as jnp
26
+ >>> x = jnp.array([[3.0, 4.0], [0.0, 0.0]])
27
+ >>> normalize(x) # L2 norm along axis=1
28
+ Array([[0.6, 0.8],
29
+ [0. , 0. ]], dtype=float32)
30
+
31
+ >>> normalize(x, p=1) # L1 norm along axis=1
32
+ Array([[0.42857143, 0.5714286 ],
33
+ [0. , 0. ]], dtype=float32)
34
+ """
35
+ norm = jnp.linalg.norm(x, ord=p, axis=axis, keepdims=True)
36
+ norm = jnp.maximum(norm, eps)
37
+ output = x / norm
38
+
39
+ return output
noxton-0.1.0/noxton/functions/regularization.py ADDED
@@ -0,0 +1,125 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from jaxtyping import Array, PRNGKeyArray
4
+
5
+
6
+ def stochastic_depth(
7
+ input: Array,
8
+ p: float,
9
+ mode: str,
10
+ inference: bool,
11
+ key: PRNGKeyArray,
12
+ ) -> Array:
13
+ """Apply Stochastic Depth regularization (DropPath).
14
+
15
+ During training, randomly drops entire samples (``mode="batch"``) or
16
+ individual rows (``mode="row"``). At inference time or when ``p=0``, the
17
+ input is returned unchanged. Surviving elements are scaled by
18
+ ``1 / (1 - p)`` to preserve the expected value.
19
+
20
+ Args:
21
+ input: Input array to apply stochastic depth to.
22
+ p: Drop probability in ``[0, 1]``. The probability that a sample/row
23
+ is zeroed out. ``p=0`` disables the operation.
24
+ mode: Dropping granularity. One of:
25
+ - ``"batch"``: a single binary mask is broadcast over the whole
26
+ batch (all samples share the same fate).
27
+ - ``"row"``: each sample in the batch gets its own independent
28
+ mask.
29
+ inference: If ``True``, skip stochastic depth and return ``input``
30
+ unchanged (equivalent to test/eval mode).
31
+ key: JAX PRNG key used to sample the Bernoulli noise.
32
+
33
+ Returns:
34
+ Array of the same shape and dtype as ``input``.
35
+
36
+ Raises:
37
+ ValueError: If ``p`` is outside ``[0, 1]``.
38
+ ValueError: If ``mode`` is not ``"batch"`` or ``"row"``.
39
+
40
+ Example:
41
+ >>> import jax
42
+ >>> import jax.numpy as jnp
43
+ >>> key = jax.random.PRNGKey(0)
44
+ >>> x = jnp.ones((4, 8))
45
+ >>> # row mode: each of the 4 samples may be dropped independently
46
+ >>> out = stochastic_depth(x, p=0.5, mode="row", inference=False, key=key)
47
+ >>> out.shape
48
+ (4, 8)
49
+
50
+ >>> # inference mode always returns the input unchanged
51
+ >>> stochastic_depth(x, p=0.9, mode="batch", inference=True, key=key) is x
52
+ True
53
+ """
54
+ if p < 0.0 or p > 1.0:
55
+ raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
56
+ if mode not in ["batch", "row"]:
57
+ raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
58
+ if inference or p == 0.0:
59
+ return input
60
+ survival_rate = 1.0 - p
61
+ if mode == "row":
62
+ size = [input.shape[0]] + [1] * (input.ndim - 1)
63
+ else:
64
+ size = [1] * input.ndim
65
+ noise = jax.random.bernoulli(key, p=survival_rate, shape=size).astype(input.dtype)
66
+ if survival_rate > 0.0:
67
+ noise = noise / survival_rate
68
+ return input * noise
69
+
70
+
71
+ def dropout(
72
+ x: Array,
73
+ p: float,
74
+ inference: bool,
75
+ key: PRNGKeyArray | None = None,
76
+ ) -> Array:
77
+ """Apply dropout regularization to an array.
78
+
79
+ During training, each element is independently zeroed with probability
80
+ ``p``. Surviving elements are scaled by ``1 / (1 - p)`` so the expected
81
+ sum is preserved (inverted dropout). At inference time or when ``p=0``
82
+ the input is returned unchanged.
83
+
84
+ Args:
85
+ x: Input array.
86
+ p: Probability of an element being zeroed. Must be in ``[0, 1)``.
87
+ When ``p=0`` the function is a no-op regardless of ``inference``.
88
+ inference: If ``True``, return ``x`` unchanged (eval/test mode).
89
+ key: JAX PRNG key. Required when ``inference=False`` and ``p > 0``.
90
+ Defaults to ``None``.
91
+
92
+ Returns:
93
+ Array of the same shape and dtype as ``x``.
94
+
95
+ Raises:
96
+ RuntimeError: If ``inference=False``, ``p > 0``, and ``key`` is
97
+ ``None``.
98
+
99
+ Example:
100
+ >>> import jax
101
+ >>> import jax.numpy as jnp
102
+ >>> key = jax.random.PRNGKey(42)
103
+ >>> x = jnp.ones((3, 4))
104
+ >>> out = dropout(x, p=0.5, inference=False, key=key)
105
+ >>> out.shape
106
+ (3, 4)
107
+
108
+ >>> # inference mode: output equals input
109
+ >>> import jax.numpy as jnp
110
+ >>> x = jnp.array([1.0, 2.0, 3.0])
111
+ >>> dropout(x, p=0.5, inference=True)
112
+ Array([1., 2., 3.], dtype=float32)
113
+ """
114
+ if isinstance(p, (int, float)) and p == 0:
115
+ inference = True
116
+ if inference:
117
+ return x
118
+ elif key is None:
119
+ raise RuntimeError(
120
+ "Dropout requires a key when running in non-deterministic mode."
121
+ )
122
+ else:
123
+ q = 1 - jax.lax.stop_gradient(p)
124
+ mask = jax.random.bernoulli(key, q, x.shape)
125
+ return jnp.where(mask, x / q, 0)
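
A hedged sketch (not part of the package) of the usual DropPath pattern with stochastic_depth: the residual branch is dropped per sample during training and always kept at inference. residual_block and the toy branch are illustrative names only:

import jax
import jax.numpy as jnp

def residual_block(x, branch, *, p, inference, key):
    # Randomly skip the branch per sample ("row" mode); survivors are rescaled.
    return x + stochastic_depth(branch(x), p, mode="row", inference=inference, key=key)

key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8))          # 4 samples, 8 features
branch = lambda h: 0.1 * h    # stand-in for an attention/MLP branch
print(residual_block(x, branch, p=0.5, inference=False, key=key).shape)  # (4, 8)
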
noxton-0.1.0/noxton/nn/__init__.py ADDED
File without changes
noxton-0.1.0/noxton/nn/abstract.py ADDED
@@ -0,0 +1,16 @@
1
+ import abc
2
+
3
+ import equinox as eqx
4
+ from jaxtyping import Array
5
+
6
+
7
+ class AbstractNorm(eqx.Module):
8
+ @abc.abstractmethod
9
+ def __call__(self, x: Array, *_, **__) -> Array: ...
10
+
11
+
12
+ class AbstractNormStateful(eqx.nn.StatefulLayer):
13
+ @abc.abstractmethod
14
+ def __call__(
15
+ self, x: Array, state: eqx.nn.State, *_, **__
16
+ ) -> tuple[Array, eqx.nn.State]: ...
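
Since the abstract interfaces ship without an example, here is a hypothetical subclass (not part of noxton) showing how AbstractNorm is presumably meant to be implemented; RMSNorm and its fields are invented for illustration:

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

class RMSNorm(AbstractNorm):
    weight: Array
    eps: float = eqx.field(static=True)

    def __init__(self, dim: int, eps: float = 1e-6):
        self.weight = jnp.ones((dim,))
        self.eps = eps

    def __call__(self, x: Array, *_, **__) -> Array:
        # Scale by the root-mean-square over the last axis.
        rms = jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps)
        return self.weight * x / rms

print(RMSNorm(8)(jnp.ones((8,))).shape)  # (8,)
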
noxton-0.1.0/noxton/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .utils import default_floating_dtype
2
+
3
+ __all__ = [
4
+ "default_floating_dtype",
5
+ ]
noxton-0.1.0/noxton/utils/utils.py ADDED
@@ -0,0 +1,59 @@
1
+ import equinox as eqx
2
+ import jax
3
+ import jax.numpy as jnp
4
+ from beartype.typing import Any
5
+ from jaxtyping import PyTree
6
+
7
+
8
+ def default_floating_dtype() -> Any:
9
+ if jax.config.read("jax_enable_x64"):
10
+ return jnp.float64
11
+ else:
12
+ return jnp.float32
13
+
14
+
15
+ def summarize_model(model: PyTree) -> str:
+ """Return a plain-text table of per-leaf parameter counts for an Equinox model."""
16
+ params, _ = eqx.partition(model, eqx.is_array)
17
+
18
+ param_counts = {}
19
+ total_params = 0
20
+
21
+ def count_params(pytree, name=""):
22
+ nonlocal total_params
23
+ count = 0
24
+ if isinstance(pytree, jnp.ndarray):
25
+ count = pytree.size
26
+ total_params += count
27
+ if name:
28
+ param_counts[name] = count
29
+ elif hasattr(pytree, "__dict__"):
30
+ for key, value in pytree.__dict__.items():
31
+ subname = f"{name}.{key}" if name else key
32
+ count += count_params(value, subname)
33
+ elif isinstance(pytree, (list, tuple)):
34
+ for i, value in enumerate(pytree):
35
+ subname = f"{name}[{i}]" if name else f"[{i}]"
36
+ count += count_params(value, subname)
37
+ elif isinstance(pytree, dict):
38
+ for key, value in pytree.items():
39
+ subname = f"{name}.{key}" if name else str(key)
40
+ count += count_params(value, subname)
41
+ return count
42
+
43
+ count_params(params)
44
+
45
+ # Display as table
46
+ lines = []
47
+ lines.append("Model Parameter Summary")
48
+ lines.append("=" * 50)
49
+ lines.append(f"{'Parameter Name':<30} {'Count':<15}")
50
+ lines.append("-" * 50)
51
+
52
+ for name, count in param_counts.items():
53
+ lines.append(f"{name:<30} {count:<15,}")
54
+
55
+ lines.append("-" * 50)
56
+ lines.append(f"{'Total Parameters':<30} {total_params:<15,}")
57
+ lines.append("=" * 50)
58
+
59
+ return "\n".join(lines)
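
A hedged usage sketch (not part of the package) of summarize_model, printing the parameter table for a small stock Equinox MLP:

import equinox as eqx
import jax

key = jax.random.PRNGKey(0)
mlp = eqx.nn.MLP(in_size=4, out_size=2, width_size=8, depth=2, key=key)
print(summarize_model(mlp))  # parameter-count table, one row per weight/bias
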
noxton-0.1.0/noxton.egg-info/PKG-INFO ADDED
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.4
2
+ Name: noxton
3
+ Version: 0.1.0
4
+ Summary: Add your description here
5
+ Requires-Python: >=3.13
6
+ Description-Content-Type: text/markdown
7
+ License-File: LICENSE
8
+ Requires-Dist: equinox>=0.13.4
9
+ Requires-Dist: jax>=0.9.0.1
10
+ Requires-Dist: statedict2pytree>=2.0.1
11
+ Dynamic: license-file
noxton-0.1.0/noxton.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,19 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ noxton.egg-info/PKG-INFO
5
+ noxton.egg-info/SOURCES.txt
6
+ noxton.egg-info/dependency_links.txt
7
+ noxton.egg-info/requires.txt
8
+ noxton.egg-info/top_level.txt
9
+ noxton/functions/__init__.py
10
+ noxton/functions/activation.py
11
+ noxton/functions/attention.py
12
+ noxton/functions/embedding.py
13
+ noxton/functions/masking.py
14
+ noxton/functions/normalization.py
15
+ noxton/functions/regularization.py
16
+ noxton/nn/__init__.py
17
+ noxton/nn/abstract.py
18
+ noxton/utils/__init__.py
19
+ noxton/utils/utils.py
noxton-0.1.0/noxton.egg-info/requires.txt ADDED
@@ -0,0 +1,3 @@
1
+ equinox>=0.13.4
2
+ jax>=0.9.0.1
3
+ statedict2pytree>=2.0.1
noxton-0.1.0/noxton.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ noxton
noxton-0.1.0/pyproject.toml ADDED
@@ -0,0 +1,11 @@
1
+ [project]
2
+ name = "noxton"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.13"
7
+ dependencies = [
8
+ "equinox>=0.13.4",
9
+ "jax>=0.9.0.1",
10
+ "statedict2pytree>=2.0.1",
11
+ ]
noxton-0.1.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+