jaxonlayers 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import equinox as eqx
import jax.numpy as jnp
from beartype.typing import Any
from jaxtyping import Array, PRNGKeyArray

from jaxonlayers.functions import default_floating_dtype
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class EmbeddingWithPadding(eqx.Module):
    """Embedding lookup whose ``padding_idx`` positions are returned as zeros.

    Positions of the input equal to ``padding_idx`` yield a zero vector, so
    the corresponding table row receives no gradient from those positions.
    This class does not zero the padding row at construction time; the row
    is only masked when looked up.
    """

    # Wrapped lookup table; rows are read from ``embed.weight``.
    embed: eqx.nn.Embedding
    # Index whose lookups are masked to zero (static: not a traced leaf).
    padding_idx: int = eqx.field(static=True)

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int = 0,
        *,
        key: PRNGKeyArray,
        dtype: Any | None = None,
    ):
        """Build the lookup table.

        Args:
            num_embeddings: Number of rows in the table.
            embedding_dim: Size of each embedding vector.
            padding_idx: Index whose lookups are zeroed. It is compared
                directly against the raw input indices (no negative-index
                normalisation is performed).
            key: PRNG key used to initialise the table.
            dtype: Floating dtype of the table; ``None`` means
                ``default_floating_dtype()``.
        """
        if dtype is None:
            dtype = default_floating_dtype()
        assert dtype is not None
        self.embed = eqx.nn.Embedding(
            num_embeddings, embedding_dim, key=key, dtype=dtype
        )
        self.padding_idx = padding_idx

    def __call__(self, x: Array, *, key: PRNGKeyArray | None = None) -> Array:
        """Look up embeddings for ``x`` and zero out padding positions.

        Args:
            x: Integer index array of any shape; a scalar (including a
                Python ``int``) is accepted.
            key: Unused. Accepted so the layer can be called with a ``key``
                keyword like other Equinox layers.

        Returns:
            The looked-up rows, with one trailing embedding axis appended to
            ``x``'s shape; rows where ``x == padding_idx`` are all zeros.
        """
        del key  # The lookup is deterministic.
        x = jnp.asarray(x)
        # Index the table directly: eqx.nn.Embedding.__call__ expects a
        # scalar index, whereas the masking below handles any index shape.
        out = self.embed.weight[x]
        keep = x != self.padding_idx
        # jnp.where (instead of multiplying by a 0/1 mask) prevents inf/NaN
        # values in the padding row from leaking through as NaN.
        return jnp.where(keep[..., None], out, 0)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class EmbeddingBag(eqx.Module):
    """Sum of padded embedding lookups over each bag of indices.

    Every slice of the input along its last axis is one "bag"; the output is
    the sum of that bag's embeddings. Positions equal to ``padding_idx``
    contribute zero to the sum (they are masked by ``EmbeddingWithPadding``).
    """

    # Per-index lookup that zeroes padding positions.
    embed: EmbeddingWithPadding

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int = 0,
        *,
        key: PRNGKeyArray,
        dtype: Any | None = None,
    ):
        """Build the underlying padded lookup table.

        Args:
            num_embeddings: Number of rows in the table.
            embedding_dim: Size of each embedding vector.
            padding_idx: Index whose lookups contribute zero to bag sums.
            key: PRNG key used to initialise the table.
            dtype: Floating dtype of the table; ``None`` means
                ``default_floating_dtype()``.
        """
        if dtype is None:
            dtype = default_floating_dtype()
        assert dtype is not None
        self.embed = EmbeddingWithPadding(
            num_embeddings, embedding_dim, padding_idx=padding_idx, key=key, dtype=dtype
        )

    def __call__(self, x: Array, *, key: PRNGKeyArray | None = None) -> Array:
        """Sum the embeddings of each bag in ``x``.

        Args:
            x: Integer index array of shape ``(..., bag_size)``; each slice
                along the last axis is one bag. A 1-D array is a single bag.
            key: Unused. Accepted so the layer can be called with a ``key``
                keyword like other Equinox layers.

        Returns:
            Array of shape ``(..., embedding_dim)`` holding per-bag sums.
        """
        del key  # The lookup is deterministic.
        # Look up one scalar index at a time so this works whether or not the
        # wrapped lookup accepts batched indices, then restore x's shape.
        flat = x.reshape(-1)
        looked_up = eqx.filter_vmap(self.embed)(flat)
        looked_up = looked_up.reshape(*x.shape, looked_up.shape[-1])
        # For 1-D input axis -2 is axis 0, i.e. the whole bag.
        return looked_up.sum(axis=-2)
|
|
@@ -13,11 +13,12 @@ jaxonlayers/layers/__init__.py,sha256=aSYMojd0wKdGLxrPxZi_x1c4n8JTJvjAr2vNXC1zYP
|
|
|
13
13
|
jaxonlayers/layers/abstract.py,sha256=uyPKGsH5DETnjiU3PUiF98tk9boNZW96amxrM5JQlZY,366
|
|
14
14
|
jaxonlayers/layers/attention.py,sha256=RgtpzBPxJ4tDcUjiq_Wh_7GJndmBY6UKtbEuHGLA1PY,7915
|
|
15
15
|
jaxonlayers/layers/convolution.py,sha256=k0dMFBDjzycB7UNuyHqKihJtBa6u93V6OLxyUUyipN4,3247
|
|
16
|
+
jaxonlayers/layers/embedding.py,sha256=G0LyBZpKee4S6HqhgxdH6rCJNZwMomKNeDmYqRMIPGI,1509
|
|
16
17
|
jaxonlayers/layers/normalization.py,sha256=3aGzNzDN05A72ZHLUM2w9WpicLtGsjzj1l0jhuyn63U,8379
|
|
17
18
|
jaxonlayers/layers/regularization.py,sha256=ZrvtBJPH84xuxrxEbZc7TBxjp8OvKEv4ecan5s8F9zs,563
|
|
18
19
|
jaxonlayers/layers/sequential.py,sha256=xBZavhSra0oZUZjSlThzIGER5xR62n46mgTkPqRV2Y0,2843
|
|
19
20
|
jaxonlayers/layers/state_space.py,sha256=oDVRbduNtU48Q4rLd-XywZcqVN0QYTlq1UUhOXcGLoo,2537
|
|
20
21
|
jaxonlayers/layers/transformer.py,sha256=Syuh_kN-I-Gg-C20trQcNio1WEG88OHDRl3ShFaQNqQ,21848
|
|
21
|
-
jaxonlayers-0.2.
|
|
22
|
-
jaxonlayers-0.2.
|
|
23
|
-
jaxonlayers-0.2.
|
|
22
|
+
jaxonlayers-0.2.5.dist-info/METADATA,sha256=TaYvkI9lsjDOwKl-rXCnYXZlof9xh1AgxdANXyWzn2o,565
|
|
23
|
+
jaxonlayers-0.2.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
24
|
+
jaxonlayers-0.2.5.dist-info/RECORD,,
|
|
File without changes
|