jaxonlayers 0.1.4__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/.pre-commit-config.yaml +2 -2
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/PKG-INFO +1 -1
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/masking.py +21 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/utils.py +1 -1
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/__init__.py +12 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/attention.py +3 -0
- jaxonlayers-0.2.0/jaxonlayers/layers/transformer.py +728 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/pyproject.toml +1 -1
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_efficientnet_layers.py +3 -5
- jaxonlayers-0.2.0/tests/test_transformer.py +503 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/.gitignore +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/.python-version +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/README.md +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/__init__.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/__init__.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/attention.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/embedding.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/initialization.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/normalization.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/regularization.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/state_space.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/abstract.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/convolution.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/normalization.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/regularization.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/sequential.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/state_space.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/__init__.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_attention.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_batch_norm.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_layernorm.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_local_response_normalisation.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_mha.py +0 -0
- {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/uv.lock +0 -0
{jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/.pre-commit-config.yaml
@@ -1,12 +1,12 @@
 repos:
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.
+  rev: v0.14.7
   hooks:
   - id: ruff
     args: [--fix]
   - id: ruff-format
 - repo: https://github.com/RobertCraigie/pyright-python
-  rev: v1.1.
+  rev: v1.1.407
   hooks:
   - id: pyright
     additional_dependencies:
{jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/masking.py
@@ -1,4 +1,5 @@
 import jax.numpy as jnp
+from jaxtyping import Array, Bool
 
 
 def canonical_mask(
@@ -50,7 +51,27 @@ def canonical_attn_mask(attn_mask, query_dtype=jnp.float32):
     )
 
 
+def make_causal_mask(seq_len: int) -> Bool[Array, "seq_len seq_len"]:
+    """
+    Returns a boolean mask.
+
+    Example:
+        [[True, False, False],
+         [True, True, False],
+         [True, True, True ]]
+    """
+    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
+
+
 def build_attention_mask(context_length: int):
+    """
+    Returns a numerical matrix with 0 and -inf.
+
+    Example:
+        [[ 0, -inf, -inf],
+         [ 0, 0, -inf],
+         [ 0, 0, 0 ]]
+    """
     mask = jnp.tril(jnp.zeros((context_length, context_length)))
     upper = jnp.triu(jnp.full((context_length, context_length), float("-inf")), k=1)
 
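For context, a minimal usage sketch of the two mask helpers above. The import path follows the file location in this diff; the final equivalence check assumes build_attention_mask returns the 0/-inf matrix shown in its new docstring, since the function body is only partially visible here.

import jax.numpy as jnp
from jaxonlayers.functions.masking import build_attention_mask, make_causal_mask

# Boolean mask: True marks positions a query may attend to.
bool_mask = make_causal_mask(3)
# Additive mask: 0 where attention is allowed, -inf where it is blocked.
add_mask = build_attention_mask(3)

# Converting the boolean mask to additive form should reproduce the
# docstring example (0 on and below the diagonal, -inf strictly above).
converted = jnp.where(bool_mask, 0.0, float("-inf"))
print(jnp.array_equal(converted, add_mask))  # expected: True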
{jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/__init__.py
@@ -5,6 +5,13 @@ from .normalization import BatchNorm, LayerNorm, LocalResponseNormalization
 from .regularization import StochasticDepth
 from .sequential import BatchedLinear
 from .state_space import SelectiveStateSpace
+from .transformer import (
+    Transformer,
+    TransformerDecoder,
+    TransformerDecoderLayer,
+    TransformerEncoder,
+    TransformerEncoderLayer,
+)
 
 __all__ = [
     "BatchNorm",
@@ -18,4 +25,9 @@ __all__ = [
     "AbstractNormStateful",
     "AbstractNorm",
     "BatchedLinear",
+    "Transformer",
+    "TransformerDecoder",
+    "TransformerDecoderLayer",
+    "TransformerEncoder",
+    "TransformerEncoderLayer",
 ]
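With the re-exports above, the five new transformer layers become importable from the package's top-level layers namespace. A minimal sketch follows; constructor arguments live in the new jaxonlayers/layers/transformer.py and are not visible in this diff, so none are assumed here:

from jaxonlayers.layers import (
    Transformer,
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)

# Construction details are defined in the new 728-line transformer.py;
# tests/test_transformer.py in this release exercises them concretely.
print(Transformer, TransformerEncoder, TransformerDecoder)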
{jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/attention.py
@@ -131,6 +131,7 @@ class MultiheadAttention(eqx.Module):
         attn_mask: Array | None = None,
         average_attn_weights: bool = True,
         is_causal: bool = False,
+        dropout_key: PRNGKeyArray | None = None,
     ) -> tuple[Array, Array | None]:
         key_padding_mask = canonical_mask(
             mask=key_padding_mask,
@@ -171,6 +172,7 @@ class MultiheadAttention(eqx.Module):
                 v_proj_weight=self.v_proj_weight,
                 average_attn_weights=average_attn_weights,
                 is_causal=is_causal,
+                dropout_key=dropout_key,
             )
         else:
             attn_output, attn_output_weights = multi_head_attention_forward(
@@ -193,6 +195,7 @@ class MultiheadAttention(eqx.Module):
                 attn_mask=attn_mask,
                 average_attn_weights=average_attn_weights,
                 is_causal=is_causal,
+                dropout_key=dropout_key,
             )
 
         return attn_output, attn_output_weights
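The behavioural change here is the new optional dropout_key argument on MultiheadAttention.__call__, which threads explicit PRNG state into attention dropout instead of relying on implicit randomness. A hedged usage sketch, assuming MultiheadAttention is exported from jaxonlayers.layers; the constructor arguments (embed_dim, num_heads, key) and the unbatched (seq_len, embed_dim) input shape are assumptions about the surrounding API, not taken from this diff:

import jax
import jax.numpy as jnp
from jaxonlayers.layers import MultiheadAttention  # assumed export

init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
# Assumed constructor signature, for illustration only.
mha = MultiheadAttention(embed_dim=64, num_heads=4, key=init_key)

x = jnp.ones((10, 64))  # assumed unbatched (seq_len, embed_dim) input
# dropout_key defaults to None, so omitting it keeps the call deterministic;
# passing a key enables stochastic attention dropout during training.
out, weights = mha(x, x, x, is_causal=True, dropout_key=dropout_key)  # query, key, value (assumed order)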