jaxonlayers 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxonlayers/functions/embedding.py +21 -0
- jaxonlayers/functions/state_space.py +64 -15
- jaxonlayers/functions/utils.py +49 -0
- jaxonlayers/layers/attention.py +2 -2
- jaxonlayers/layers/state_space.py +34 -19
- jaxonlayers-0.1.2.dist-info/METADATA +16 -0
- {jaxonlayers-0.1.0.dist-info → jaxonlayers-0.1.2.dist-info}/RECORD +8 -8
- {jaxonlayers-0.1.0.dist-info → jaxonlayers-0.1.2.dist-info}/WHEEL +1 -2
- jaxonlayers-0.1.0.dist-info/METADATA +0 -10
- jaxonlayers-0.1.0.dist-info/top_level.txt +0 -1
jaxonlayers/functions/embedding.py
ADDED

@@ -0,0 +1,21 @@
+import jax.numpy as jnp
+from jaxtyping import Array, Float, Int
+
+
+def sinusoidal_embedding(
+    t: Int[Array, ""], embedding_size: int
+) -> Float[Array, " embedding_size"]:
+    if embedding_size % 2 != 0:
+        raise ValueError(f"Embedding size must be even, but got {embedding_size}")
+
+    half_dim = embedding_size // 2
+    embedding_freqs = jnp.exp(
+        -jnp.log(10000)
+        * jnp.arange(start=0, stop=half_dim, dtype=jnp.float32)
+        / half_dim
+    )
+
+    time_args = t * embedding_freqs
+    embedding = jnp.concatenate([jnp.sin(time_args), jnp.cos(time_args)])
+
+    return embedding
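The new sinusoidal_embedding follows the standard transformer-style recipe: half_dim geometrically decaying frequencies, with the sine and cosine halves concatenated into one vector. A minimal usage sketch (the timestep value is illustrative, and it assumes the function is imported from its module path):

    import jax.numpy as jnp

    from jaxonlayers.functions.embedding import sinusoidal_embedding

    t = jnp.asarray(10)  # scalar integer timestep
    emb = sinusoidal_embedding(t, embedding_size=64)
    # First 32 entries are sines, last 32 are cosines.
    assert emb.shape == (64,)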
jaxonlayers/functions/state_space.py
CHANGED

@@ -4,27 +4,76 @@ from jaxtyping import Array, Float
 
 
 def selective_scan(
-    x: Float[Array, "seq_length d_inner"],
+    u: Float[Array, "seq_length d_inner"],
     delta: Float[Array, "seq_length d_inner"],
     A: Float[Array, "d_inner d_state"],
-    B: Float[Array, "seq_length d_state"],
-    C: Float[Array, "seq_length d_state"],
+    B: Float[Array, "seq_length d_inner d_state"],
+    C: Float[Array, "seq_length d_inner d_state"],
     D: Float[Array, " d_inner"],
+    chunk_size: int = 128,
 ) -> Float[Array, "seq_length d_inner"]:
-
-
-    delta_A = jnp.exp(jnp.einsum("l d,d n -> l d n", delta, A))
-    delta_B_u = jnp.einsum("l d,l n,l d -> l d n", delta, B, x)
+    deltaA = jnp.exp(jnp.einsum("l d, d n -> l d n", delta, A))
+    deltaB_u = jnp.einsum("l d, l d n, l d -> l d n", delta, B, u)
 
-
+    seq_len, d_inner = u.shape
+    d_state = A.shape[1]
 
-
-
+    num_chunks = (seq_len + chunk_size - 1) // chunk_size
+    padded_len = num_chunks * chunk_size
 
-
-
+    pad_len = padded_len - seq_len
+    deltaA_padded = jnp.pad(deltaA, ((0, pad_len), (0, 0), (0, 0)))
+    deltaB_u_padded = jnp.pad(deltaB_u, ((0, pad_len), (0, 0), (0, 0)))
+    C_padded = jnp.pad(C, ((0, pad_len), (0, 0), (0, 0)))
 
-
+    deltaA_chunked = deltaA_padded.reshape(num_chunks, chunk_size, d_inner, d_state)
+    deltaB_u_chunked = deltaB_u_padded.reshape(num_chunks, chunk_size, d_inner, d_state)
+    C_chunked = C_padded.reshape(num_chunks, chunk_size, d_inner, d_state)
 
-
-
+    def intra_chunk_step(h_prev, scan_inputs):
+        deltaA_i, deltaB_u_i, C_i = scan_inputs
+        h_i = deltaA_i * h_prev + deltaB_u_i
+        y_i = jnp.einsum("d n, d n -> d", h_i, C_i)
+        return h_i, y_i
+
+    h0 = jnp.zeros((d_inner, d_state))
+
+    _, y_chunks = jax.vmap(jax.lax.scan, in_axes=(None, None, 0))(
+        intra_chunk_step, h0, (deltaA_chunked, deltaB_u_chunked, C_chunked)
+    )
+
+    def inter_chunk_step(carry_prev, scan_inputs):
+        A_prev, h_prev = carry_prev
+        deltaA_i, deltaB_u_i = scan_inputs
+
+        A_new = deltaA_i * A_prev
+        h_new = deltaA_i * h_prev + deltaB_u_i
+
+        return (A_new, h_new), (A_new, h_new)
+
+    A_carry_initial = jnp.ones((d_inner, d_state))
+    h_carry_initial = jnp.zeros((d_inner, d_state))
+    initial_carry = (A_carry_initial, h_carry_initial)
+
+    scan_inputs = (deltaA_chunked[:, -1], deltaB_u_chunked[:, -1])
+
+    _, (A_carry, h_carry) = jax.lax.scan(inter_chunk_step, initial_carry, scan_inputs)
+
+    A_carry = jnp.roll(A_carry, 1, axis=0)
+    h_carry = jnp.roll(h_carry, 1, axis=0)
+    A_carry = A_carry.at[0].set(jnp.ones_like(A_carry[0]))
+    h_carry = h_carry.at[0].set(jnp.zeros_like(h_carry[0]))
+
+    h_carry_broadcast = jnp.expand_dims(h_carry, axis=1)
+    h_correction = deltaA_chunked * h_carry_broadcast
+    y_carry = jnp.einsum("csdn, csdn -> csd", C_chunked, h_correction)
+
+    y_final = y_chunks + y_carry
+
+    y_final = y_final.reshape(padded_len, d_inner)
+
+    y_unpadded = y_final[:seq_len]
+
+    output = y_unpadded.real + u * D
+
+    return output
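The rewritten selective_scan makes B and C input- and channel-dependent (hence their new seq_length d_inner d_state shape) and evaluates the recurrence h_t = deltaA_t * h_{t-1} + deltaB_u_t, y_t = sum_n(C_t * h_t) in fixed-size chunks: a vmapped inner jax.lax.scan runs every chunk from a zero state in parallel, an outer jax.lax.scan propagates a carry built from each chunk's final step, and a correction term folds the rolled carry back into every chunk's outputs before the padding is stripped. A minimal shape-checking sketch (random inputs; all sizes are illustrative):

    import jax
    import jax.numpy as jnp

    from jaxonlayers.functions.state_space import selective_scan

    L, d_inner, d_state = 200, 8, 4  # L deliberately not a multiple of chunk_size
    ks = jax.random.split(jax.random.PRNGKey(0), 5)
    u = jax.random.normal(ks[0], (L, d_inner))
    delta = jax.nn.softplus(jax.random.normal(ks[1], (L, d_inner)))  # positive step sizes
    A = -jnp.exp(jax.random.normal(ks[2], (d_inner, d_state)))       # negative for stability
    B = jax.random.normal(ks[3], (L, d_inner, d_state))
    C = jax.random.normal(ks[4], (L, d_inner, d_state))
    D = jnp.ones(d_inner)

    y = selective_scan(u, delta, A, B, C, D, chunk_size=128)
    assert y.shape == (L, d_inner)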
jaxonlayers/functions/utils.py
CHANGED
@@ -1,5 +1,7 @@
+import equinox as eqx
 import jax
 import jax.numpy as jnp
+from jaxtyping import PyTree
 
 
 def default_floating_dtype():
@@ -7,3 +9,50 @@ def default_floating_dtype():
         return jnp.float64
     else:
         return jnp.float32
+
+
+def summarize_model(model: PyTree) -> str:
+    params, _ = eqx.partition(model, eqx.is_array)
+
+    param_counts = {}
+    total_params = 0
+
+    def count_params(pytree, name=""):
+        nonlocal total_params
+        count = 0
+        if isinstance(pytree, jnp.ndarray):
+            count = pytree.size
+            total_params += count
+            if name:
+                param_counts[name] = count
+        elif hasattr(pytree, "__dict__"):
+            for key, value in pytree.__dict__.items():
+                subname = f"{name}.{key}" if name else key
+                count += count_params(value, subname)
+        elif isinstance(pytree, (list, tuple)):
+            for i, value in enumerate(pytree):
+                subname = f"{name}[{i}]" if name else f"[{i}]"
+                count += count_params(value, subname)
+        elif isinstance(pytree, dict):
+            for key, value in pytree.items():
+                subname = f"{name}.{key}" if name else str(key)
+                count += count_params(value, subname)
+        return count
+
+    count_params(params)
+
+    # Display as table
+    lines = []
+    lines.append("Model Parameter Summary")
+    lines.append("=" * 50)
+    lines.append(f"{'Parameter Name':<30} {'Count':<15}")
+    lines.append("-" * 50)
+
+    for name, count in param_counts.items():
+        lines.append(f"{name:<30} {count:<15,}")
+
+    lines.append("-" * 50)
+    lines.append(f"{'Total Parameters':<30} {total_params:<15,}")
+    lines.append("=" * 50)
+
+    return "\n".join(lines)
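The new summarize_model partitions the model with eqx.partition(model, eqx.is_array), recurses through attributes, sequence indices, and dict keys to name each array leaf, and renders a fixed-width table of per-leaf and total parameter counts. A usage sketch (the MLP is just an illustrative Equinox model, not part of this package):

    import equinox as eqx
    import jax

    from jaxonlayers.functions.utils import summarize_model

    model = eqx.nn.MLP(
        in_size=4, out_size=2, width_size=8, depth=2, key=jax.random.PRNGKey(0)
    )
    print(summarize_model(model))  # one row per array leaf, then the total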
jaxonlayers/layers/state_space.py
CHANGED

@@ -11,14 +11,17 @@ class SelectiveStateSpace(eqx.Module):
     input_proj: eqx.nn.Linear
     delta_proj: eqx.nn.Linear
     A_log: Float[Array, "d_inner d_state"]
-    D: Float[Array, " d_inner"]
+    D: Float[Array, "d_inner"]
+    out_proj: eqx.nn.Linear
 
     d_inner: int = eqx.field(static=True)
     dt_rank: int = eqx.field(static=True)
     d_state: int = eqx.field(static=True)
+    d_model: int = eqx.field(static=True)
 
     def __init__(
         self,
+        d_model: int,
         d_inner: int,
         dt_rank: int,
         d_state: int,
@@ -31,19 +34,19 @@ class SelectiveStateSpace(eqx.Module):
         if dtype is None:
             dtype = default_floating_dtype()
         assert dtype is not None
+
+        self.d_model = d_model
         self.d_inner = d_inner
         self.dt_rank = dt_rank
         self.d_state = d_state
-
-
-
-            delta_proj_key,
-        ) = jax.random.split(key, 3)
+
+        keys = jax.random.split(key, 4)
+        proj_dim = self.dt_rank + 2 * self.d_inner * self.d_state
         self.input_proj = eqx.nn.Linear(
-
-
+            self.d_model,
+            proj_dim,
             use_bias=use_input_proj_bias,
-            key=
+            key=keys[0],
             dtype=dtype,
         )
 
@@ -51,25 +54,37 @@ class SelectiveStateSpace(eqx.Module):
             dt_rank,
             d_inner,
             use_bias=use_delta_proj_bias,
-            key=
+            key=keys[1],
             dtype=dtype,
         )
-
+
+        A = jnp.arange(1, d_state + 1, dtype=jnp.float32)
+        A = jnp.tile(A, (d_inner, 1))
         self.A_log = jnp.log(A)
+
         self.D = jnp.ones(d_inner, dtype=dtype)
 
+        self.out_proj = eqx.nn.Linear(
+            d_inner, d_model, use_bias=False, key=keys[2], dtype=dtype
+        )
+
     def __call__(self, x: Float[Array, "seq_length d_inner"]):
-
-
+        L, _ = x.shape
+        A = -jnp.exp(self.A_log.astype(jnp.float32))
+        D = self.D.astype(jnp.float32)
 
         delta_b_c = jax.vmap(self.input_proj)(x)
+        delta, B, C = jnp.split(
+            delta_b_c,
+            [self.dt_rank, self.dt_rank + self.d_inner * self.d_state],
+            axis=-1,
+        )
+
+        B = B.reshape(L, self.d_inner, self.d_state)
+        C = C.reshape(L, self.d_inner, self.d_state)
 
-        split_indices = [
-            self.dt_rank,
-            self.dt_rank + self.d_state,
-        ]
-        delta, B, C = jnp.split(delta_b_c, split_indices, axis=-1)
         delta = jax.nn.softplus(jax.vmap(self.delta_proj)(delta))
 
         y = selective_scan(x, delta, A, B, C, D)
-
+
+        return jax.vmap(self.out_proj)(y)
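The layer now takes d_model up front, fuses delta, B, and C into one input projection of width dt_rank + 2 * d_inner * d_state, and adds an out_proj mapping d_inner back to d_model. A construction sketch; the keyword-style call and defaulted bias/dtype arguments are assumptions, and d_model is set equal to d_inner because the __call__ above feeds x straight into selective_scan, which expects width d_inner:

    import jax

    from jaxonlayers.layers.state_space import SelectiveStateSpace

    layer = SelectiveStateSpace(
        d_model=16,   # width of the residual stream
        d_inner=16,   # kept equal to d_model, matching the __call__ above
        dt_rank=4,
        d_state=8,
        key=jax.random.PRNGKey(0),
    )

    x = jax.random.normal(jax.random.PRNGKey(1), (32, 16))  # (seq_length, d_model)
    y = layer(x)  # (32, 16): out_proj returns to d_model
    assert y.shape == x.shape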
jaxonlayers-0.1.2.dist-info/METADATA
ADDED

@@ -0,0 +1,16 @@
+Metadata-Version: 2.4
+Name: jaxonlayers
+Version: 0.1.2
+Summary: Add your description here
+Requires-Python: >=3.13
+Requires-Dist: beartype>=0.21.0
+Requires-Dist: equinox>=0.13.0
+Requires-Dist: jax>=0.8.0
+Requires-Dist: jaxtyping>=0.3.2
+Description-Content-Type: text/markdown
+
+# jaxonlayers
+
+This library provides some utility function and useful layers that extend the [Equinox](https://github.com/patrick-kidger/equinox) library.
+
+The aim was to create them to be the PyTorch equivalent, i.e. to match their PyTorch counterpart's output.
{jaxonlayers-0.1.0.dist-info → jaxonlayers-0.1.2.dist-info}/RECORD
RENAMED

@@ -1,21 +1,21 @@
 jaxonlayers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 jaxonlayers/functions/__init__.py,sha256=lxMNSnEn2pJ1XLLyQTTFXiCcJTdrzfE1RlorqD9lEog,711
 jaxonlayers/functions/attention.py,sha256=AyaAeA2yo5Cgljk2rU6JlVttQeG0FMBLE-f7285PdM0,12945
+jaxonlayers/functions/embedding.py,sha256=g76Wg_MSyqatwjWrhDSfAxDaMyr2s3A5AoST_cRa3Q8,601
 jaxonlayers/functions/initialization.py,sha256=h7uzdPl-rL7faT9hbRs5aN6EeaX0r70y88ad1cwZvmY,1161
 jaxonlayers/functions/masking.py,sha256=5VeHMuoVeKxAHtzGb74GqvKHfib4wJAETTAn1oaOgLM,1531
 jaxonlayers/functions/normalization.py,sha256=e2vNNbDz-Y6j5bgQshd4MshnzLcDOjUnGSinpRXtPtA,206
 jaxonlayers/functions/regularization.py,sha256=ltFVeJZFhdNrHsH0CjDl9Y3dlAQmh1ABntCjpYUBgXM,1317
-jaxonlayers/functions/state_space.py,sha256=
-jaxonlayers/functions/utils.py,sha256=
+jaxonlayers/functions/state_space.py,sha256=_PdykvDSRZWmIQR5KZTUN6shLhPGElio9THnoVpLJ_g,2747
+jaxonlayers/functions/utils.py,sha256=M1uc01yOADy2ig6YsIyn0JBf_zdFuJS-rxin5RVub-A,1749
 jaxonlayers/layers/__init__.py,sha256=gjH0QUOCplii5KNYshWNehPfHZLVjEeYk8EhhY7FHRE,480
 jaxonlayers/layers/abstract.py,sha256=uyPKGsH5DETnjiU3PUiF98tk9boNZW96amxrM5JQlZY,366
-jaxonlayers/layers/attention.py,sha256=
+jaxonlayers/layers/attention.py,sha256=NdjDnA3yYFpvXD2_SjBwF-8Pdk_GeRLPuPekTR-iMwE,7784
 jaxonlayers/layers/convolution.py,sha256=k0dMFBDjzycB7UNuyHqKihJtBa6u93V6OLxyUUyipN4,3247
 jaxonlayers/layers/normalization.py,sha256=3aGzNzDN05A72ZHLUM2w9WpicLtGsjzj1l0jhuyn63U,8379
 jaxonlayers/layers/regularization.py,sha256=ZrvtBJPH84xuxrxEbZc7TBxjp8OvKEv4ecan5s8F9zs,563
 jaxonlayers/layers/sequential.py,sha256=Tw98hNZiXMC-CYZD6h_pi7eAxkgHeQAUvZF2I9H0d8Y,2833
-jaxonlayers/layers/state_space.py,sha256=
-jaxonlayers-0.1.
-jaxonlayers-0.1.
-jaxonlayers-0.1.
-jaxonlayers-0.1.0.dist-info/RECORD,,
+jaxonlayers/layers/state_space.py,sha256=oDVRbduNtU48Q4rLd-XywZcqVN0QYTlq1UUhOXcGLoo,2537
+jaxonlayers-0.1.2.dist-info/METADATA,sha256=NGl7HoPwh1sC8JmgOeBD3kVEJgy5tmTR7jhlKM9um7k,539
+jaxonlayers-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+jaxonlayers-0.1.2.dist-info/RECORD,,
jaxonlayers-0.1.0.dist-info/METADATA
REMOVED

@@ -1,10 +0,0 @@
-Metadata-Version: 2.4
-Name: jaxonlayers
-Version: 0.1.0
-Summary: Add your description here
-Requires-Python: >=3.13
-Description-Content-Type: text/markdown
-Requires-Dist: beartype>=0.21.0
-Requires-Dist: equinox>=0.13.0
-Requires-Dist: jax>=0.7.2
-Requires-Dist: jaxtyping>=0.3.2
jaxonlayers-0.1.0.dist-info/top_level.txt
REMOVED

@@ -1 +0,0 @@
-jaxonlayers
|