jaxonlayers 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,4 +5,4 @@ from jaxtyping import Array
5
5
 
6
6
  def swiglu(x: Array, axis=-1):
7
7
  a, b = jnp.split(x, 2, axis=axis)
8
- return a * jax.nn.swish(b)
8
+ return jax.nn.swish(a) * b
@@ -0,0 +1,56 @@
1
+ import equinox as eqx
2
+ from beartype.typing import Any
3
+ from jaxtyping import Array, PRNGKeyArray
4
+
5
+ from jaxonlayers.functions import default_floating_dtype
6
+
7
+
8
+ class EmbeddingWithPadding(eqx.Module):
9
+ embed: eqx.nn.Embedding
10
+ padding_idx: int = eqx.field(static=True)
11
+
12
+ def __init__(
13
+ self,
14
+ num_embeddings: int,
15
+ embedding_dim: int,
16
+ padding_idx: int = 0,
17
+ *,
18
+ key: PRNGKeyArray,
19
+ dtype: Any | None = None,
20
+ ):
21
+ if dtype is None:
22
+ dtype = default_floating_dtype()
23
+ assert dtype is not None
24
+ self.embed = eqx.nn.Embedding(
25
+ num_embeddings, embedding_dim, key=key, dtype=dtype
26
+ )
27
+ self.padding_idx = padding_idx
28
+
29
+ def __call__(self, x: Array):
30
+ out = self.embed(x)
31
+ mask = (x != self.padding_idx).astype(out.dtype)
32
+ return out * mask[..., None]
33
+
34
+
35
+ class EmbeddingBag(eqx.Module):
36
+ embed: EmbeddingWithPadding
37
+
38
+ def __init__(
39
+ self,
40
+ num_embeddings: int,
41
+ embedding_dim: int,
42
+ padding_idx: int = 0,
43
+ *,
44
+ key: PRNGKeyArray,
45
+ dtype: Any | None = None,
46
+ ):
47
+ if dtype is None:
48
+ dtype = default_floating_dtype()
49
+ assert dtype is not None
50
+ self.embed = EmbeddingWithPadding(
51
+ num_embeddings, embedding_dim, padding_idx=padding_idx, key=key, dtype=dtype
52
+ )
53
+
54
+ def __call__(self, x):
55
+ looked_up = eqx.filter_vmap(self.embed)(x)
56
+ return looked_up.sum(axis=0)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxonlayers
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: Additional layers and functions that extend Equinox
5
5
  Requires-Python: >=3.13
6
6
  Requires-Dist: beartype>=0.21.0
@@ -1,6 +1,6 @@
1
1
  jaxonlayers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  jaxonlayers/functions/__init__.py,sha256=lxMNSnEn2pJ1XLLyQTTFXiCcJTdrzfE1RlorqD9lEog,711
3
- jaxonlayers/functions/activation.py,sha256=cjwDEadbJECt1AE1uEtKV_p_wBK_pJ7Bdud0ebm_GNQ,165
3
+ jaxonlayers/functions/activation.py,sha256=cVR1Obcv5JE7GiUNYlxHCcf5bpqZ9OQq8zYwSBJ-svY,165
4
4
  jaxonlayers/functions/attention.py,sha256=AyaAeA2yo5Cgljk2rU6JlVttQeG0FMBLE-f7285PdM0,12945
5
5
  jaxonlayers/functions/embedding.py,sha256=GsXxpB3G72aLY1SHdZk7gxKbNftlEZ3HRDy-7_5z4HM,610
6
6
  jaxonlayers/functions/initialization.py,sha256=h7uzdPl-rL7faT9hbRs5aN6EeaX0r70y88ad1cwZvmY,1161
@@ -13,11 +13,12 @@ jaxonlayers/layers/__init__.py,sha256=aSYMojd0wKdGLxrPxZi_x1c4n8JTJvjAr2vNXC1zYP
13
13
  jaxonlayers/layers/abstract.py,sha256=uyPKGsH5DETnjiU3PUiF98tk9boNZW96amxrM5JQlZY,366
14
14
  jaxonlayers/layers/attention.py,sha256=RgtpzBPxJ4tDcUjiq_Wh_7GJndmBY6UKtbEuHGLA1PY,7915
15
15
  jaxonlayers/layers/convolution.py,sha256=k0dMFBDjzycB7UNuyHqKihJtBa6u93V6OLxyUUyipN4,3247
16
+ jaxonlayers/layers/embedding.py,sha256=G0LyBZpKee4S6HqhgxdH6rCJNZwMomKNeDmYqRMIPGI,1509
16
17
  jaxonlayers/layers/normalization.py,sha256=3aGzNzDN05A72ZHLUM2w9WpicLtGsjzj1l0jhuyn63U,8379
17
18
  jaxonlayers/layers/regularization.py,sha256=ZrvtBJPH84xuxrxEbZc7TBxjp8OvKEv4ecan5s8F9zs,563
18
19
  jaxonlayers/layers/sequential.py,sha256=xBZavhSra0oZUZjSlThzIGER5xR62n46mgTkPqRV2Y0,2843
19
20
  jaxonlayers/layers/state_space.py,sha256=oDVRbduNtU48Q4rLd-XywZcqVN0QYTlq1UUhOXcGLoo,2537
20
21
  jaxonlayers/layers/transformer.py,sha256=Syuh_kN-I-Gg-C20trQcNio1WEG88OHDRl3ShFaQNqQ,21848
21
- jaxonlayers-0.2.3.dist-info/METADATA,sha256=-Pde3UQPCm39v_dIWtAZ4FjKE8CRgl8WbVbsM5-qbAM,565
22
- jaxonlayers-0.2.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
23
- jaxonlayers-0.2.3.dist-info/RECORD,,
22
+ jaxonlayers-0.2.5.dist-info/METADATA,sha256=TaYvkI9lsjDOwKl-rXCnYXZlof9xh1AgxdANXyWzn2o,565
23
+ jaxonlayers-0.2.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
24
+ jaxonlayers-0.2.5.dist-info/RECORD,,