jaxonlayers 0.1.4__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/.pre-commit-config.yaml +2 -2
  2. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/PKG-INFO +1 -1
  3. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/masking.py +21 -0
  4. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/utils.py +1 -1
  5. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/__init__.py +12 -0
  6. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/attention.py +3 -0
  7. jaxonlayers-0.2.0/jaxonlayers/layers/transformer.py +728 -0
  8. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/pyproject.toml +1 -1
  9. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_efficientnet_layers.py +3 -5
  10. jaxonlayers-0.2.0/tests/test_transformer.py +503 -0
  11. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/.gitignore +0 -0
  12. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/.python-version +0 -0
  13. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/README.md +0 -0
  14. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/__init__.py +0 -0
  15. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/__init__.py +0 -0
  16. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/attention.py +0 -0
  17. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/embedding.py +0 -0
  18. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/initialization.py +0 -0
  19. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/normalization.py +0 -0
  20. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/regularization.py +0 -0
  21. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/state_space.py +0 -0
  22. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/abstract.py +0 -0
  23. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/convolution.py +0 -0
  24. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/normalization.py +0 -0
  25. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/regularization.py +0 -0
  26. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/sequential.py +0 -0
  27. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/state_space.py +0 -0
  28. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/__init__.py +0 -0
  29. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_attention.py +0 -0
  30. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_batch_norm.py +0 -0
  31. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_layernorm.py +0 -0
  32. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_local_response_normalisation.py +0 -0
  33. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/tests/test_mha.py +0 -0
  34. {jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/uv.lock +0 -0
{jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/.pre-commit-config.yaml
@@ -1,12 +1,12 @@
 repos:
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.13.1
+  rev: v0.14.7
   hooks:
   - id: ruff
     args: [--fix]
   - id: ruff-format
 - repo: https://github.com/RobertCraigie/pyright-python
-  rev: v1.1.405
+  rev: v1.1.407
   hooks:
   - id: pyright
     additional_dependencies:
{jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: jaxonlayers
-Version: 0.1.4
+Version: 0.2.0
 Summary: Additional layers and functions that extend Equinox
 Requires-Python: >=3.13
 Requires-Dist: beartype>=0.21.0
{jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/masking.py
@@ -1,4 +1,5 @@
 import jax.numpy as jnp
+from jaxtyping import Array, Bool
 
 
 def canonical_mask(
@@ -50,7 +51,27 @@ def canonical_attn_mask(attn_mask, query_dtype=jnp.float32):
     )
 
 
+def make_causal_mask(seq_len: int) -> Bool[Array, "seq_len seq_len"]:
+    """
+    Returns a boolean mask.
+
+    Example:
+        [[True, False, False],
+         [True, True,  False],
+         [True, True,  True ]]
+    """
+    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
+
+
 def build_attention_mask(context_length: int):
+    """
+    Returns a numerical matrix with 0 and -inf.
+
+    Example:
+        [[ 0, -inf, -inf],
+         [ 0,    0, -inf],
+         [ 0,    0,    0 ]]
+    """
     mask = jnp.tril(jnp.zeros((context_length, context_length)))
     upper = jnp.triu(jnp.full((context_length, context_length), float("-inf")), k=1)
 
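For reference, a minimal sketch of how the two masking helpers relate, assuming both are importable from jaxonlayers.functions.masking as the paths above suggest: make_causal_mask returns the boolean form, while build_attention_mask returns the additive 0 / -inf form described in its new docstring, so the two encodings should mark the same positions.

    import jax.numpy as jnp
    from jaxonlayers.functions.masking import build_attention_mask, make_causal_mask

    bool_mask = make_causal_mask(3)          # True where attention is allowed
    additive_mask = build_attention_mask(3)  # 0.0 where allowed, -inf above the diagonal

    # The encodings should agree: zeros sit exactly where the boolean mask is True.
    assert bool(jnp.all((additive_mask == 0) == bool_mask))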
{jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/functions/utils.py
@@ -5,7 +5,7 @@ from jaxtyping import PyTree
 
 
 def default_floating_dtype():
-    if jax.config.jax_enable_x64:  # pyright: ignore
+    if jax.config.read("jax_enable_x64"):  # pyright: ignore
         return jnp.float64
     else:
         return jnp.float32
{jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/__init__.py
@@ -5,6 +5,13 @@ from .normalization import BatchNorm, LayerNorm, LocalResponseNormalization
 from .regularization import StochasticDepth
 from .sequential import BatchedLinear
 from .state_space import SelectiveStateSpace
+from .transformer import (
+    Transformer,
+    TransformerDecoder,
+    TransformerDecoderLayer,
+    TransformerEncoder,
+    TransformerEncoderLayer,
+)
 
 __all__ = [
     "BatchNorm",
@@ -18,4 +25,9 @@ __all__ = [
     "AbstractNormStateful",
     "AbstractNorm",
     "BatchedLinear",
+    "Transformer",
+    "TransformerDecoder",
+    "TransformerDecoderLayer",
+    "TransformerEncoder",
+    "TransformerEncoderLayer",
 ]
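The new transformer stack is re-exported at the package's layers namespace; the names below come straight from the __all__ additions above. Constructor signatures live in the new transformer.py (+728 lines), whose contents this section does not show, so no arguments are assumed here.

    from jaxonlayers.layers import (
        Transformer,
        TransformerDecoder,
        TransformerDecoderLayer,
        TransformerEncoder,
        TransformerEncoderLayer,
    )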
{jaxonlayers-0.1.4 → jaxonlayers-0.2.0}/jaxonlayers/layers/attention.py
@@ -131,6 +131,7 @@ class MultiheadAttention(eqx.Module):
         attn_mask: Array | None = None,
         average_attn_weights: bool = True,
         is_causal: bool = False,
+        dropout_key: PRNGKeyArray | None = None,
     ) -> tuple[Array, Array | None]:
         key_padding_mask = canonical_mask(
             mask=key_padding_mask,
@@ -171,6 +172,7 @@ class MultiheadAttention(eqx.Module):
                 v_proj_weight=self.v_proj_weight,
                 average_attn_weights=average_attn_weights,
                 is_causal=is_causal,
+                dropout_key=dropout_key,
             )
         else:
             attn_output, attn_output_weights = multi_head_attention_forward(
@@ -193,6 +195,7 @@ class MultiheadAttention(eqx.Module):
                 attn_mask=attn_mask,
                 average_attn_weights=average_attn_weights,
                 is_causal=is_causal,
+                dropout_key=dropout_key,
             )
 
         return attn_output, attn_output_weights
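A hedged usage sketch for the new dropout_key keyword: only the __call__ keywords visible in this hunk (attn_mask, average_attn_weights, is_causal, dropout_key) are taken from the diff; the constructor arguments and the (seq_len, embed_dim) input layout are illustrative assumptions.

    import jax
    import jax.numpy as jnp
    from jaxonlayers.layers.attention import MultiheadAttention

    init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
    # Constructor arguments below are assumed for illustration, not taken from this diff.
    mha = MultiheadAttention(embed_dim=64, num_heads=8, dropout=0.1, key=init_key)

    x = jnp.ones((16, 64))  # assumed (seq_len, embed_dim) layout
    # dropout_key=None (the new default) keeps the call deterministic; passing a PRNG
    # key presumably enables attention dropout, since it is forwarded to
    # multi_head_attention_forward in both branches above.
    out, attn_weights = mha(x, x, x, is_causal=True, dropout_key=dropout_key)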