jaxonlayers 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
jaxonlayers/functions/embedding.py (added)
@@ -0,0 +1,21 @@
+ import jax.numpy as jnp
+ from jaxtyping import Array, Float, Int
+
+
+ def sinusoidal_embedding(
+     t: Int[Array, ""], embedding_size: int
+ ) -> Float[Array, " embedding_size"]:
+     if embedding_size % 2 != 0:
+         raise ValueError(f"Embedding size must be even, but got {embedding_size}")
+
+     half_dim = embedding_size // 2
+     embedding_freqs = jnp.exp(
+         -jnp.log(10000)
+         * jnp.arange(start=0, stop=half_dim, dtype=jnp.float32)
+         / half_dim
+     )
+
+     time_args = t * embedding_freqs
+     embedding = jnp.concatenate([jnp.sin(time_args), jnp.cos(time_args)])
+
+     return embedding
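For orientation, a minimal usage sketch of the new function; the import path follows the package layout recorded in RECORD below, and the timestep and embedding size are illustrative:

```python
import jax.numpy as jnp

from jaxonlayers.functions.embedding import sinusoidal_embedding

t = jnp.asarray(42)                 # scalar integer timestep
emb = sinusoidal_embedding(t, 128)  # -> shape (128,): sin half then cos half
```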
jaxonlayers/functions/state_space.py
@@ -4,27 +4,76 @@ from jaxtyping import Array, Float
 
 
  def selective_scan(
-     x: Float[Array, "seq_length d_inner"],
+     u: Float[Array, "seq_length d_inner"],
      delta: Float[Array, "seq_length d_inner"],
      A: Float[Array, "d_inner d_state"],
-     B: Float[Array, "seq_length d_state"],
-     C: Float[Array, "seq_length d_state"],
+     B: Float[Array, "seq_length d_inner d_state"],
+     C: Float[Array, "seq_length d_inner d_state"],
      D: Float[Array, " d_inner"],
+     chunk_size: int = 128,
  ) -> Float[Array, "seq_length d_inner"]:
-     L, d_inner = x.shape
-     _, d_state = A.shape
-     delta_A = jnp.exp(jnp.einsum("l d,d n -> l d n", delta, A))
-     delta_B_u = jnp.einsum("l d,l n,l d -> l d n", delta, B, x)
+     deltaA = jnp.exp(jnp.einsum("l d, d n -> l d n", delta, A))
+     deltaB_u = jnp.einsum("l d, l d n, l d -> l d n", delta, B, u)
 
-     x_res = jnp.zeros(shape=(d_inner, d_state))
+     seq_len, d_inner = u.shape
+     d_state = A.shape[1]
 
-     def step(x, i):
-         x = delta_A[i] * x + delta_B_u[i]
+     num_chunks = (seq_len + chunk_size - 1) // chunk_size
+     padded_len = num_chunks * chunk_size
 
-         y = jnp.einsum("d n,n -> d", x, C[i, :])
-         return x, y
+     pad_len = padded_len - seq_len
+     deltaA_padded = jnp.pad(deltaA, ((0, pad_len), (0, 0), (0, 0)))
+     deltaB_u_padded = jnp.pad(deltaB_u, ((0, pad_len), (0, 0), (0, 0)))
+     C_padded = jnp.pad(C, ((0, pad_len), (0, 0), (0, 0)))
 
-     _, ys = jax.lax.scan(step, x_res, jnp.arange(L))
+     deltaA_chunked = deltaA_padded.reshape(num_chunks, chunk_size, d_inner, d_state)
+     deltaB_u_chunked = deltaB_u_padded.reshape(num_chunks, chunk_size, d_inner, d_state)
+     C_chunked = C_padded.reshape(num_chunks, chunk_size, d_inner, d_state)
 
-     ys = ys + x * D
-     return ys
+     def intra_chunk_step(h_prev, scan_inputs):
+         deltaA_i, deltaB_u_i, C_i = scan_inputs
+         h_i = deltaA_i * h_prev + deltaB_u_i
+         y_i = jnp.einsum("d n, d n -> d", h_i, C_i)
+         return h_i, y_i
+
+     h0 = jnp.zeros((d_inner, d_state))
+
+     _, y_chunks = jax.vmap(jax.lax.scan, in_axes=(None, None, 0))(
+         intra_chunk_step, h0, (deltaA_chunked, deltaB_u_chunked, C_chunked)
+     )
+
+     def inter_chunk_step(carry_prev, scan_inputs):
+         A_prev, h_prev = carry_prev
+         deltaA_i, deltaB_u_i = scan_inputs
+
+         A_new = deltaA_i * A_prev
+         h_new = deltaA_i * h_prev + deltaB_u_i
+
+         return (A_new, h_new), (A_new, h_new)
+
+     A_carry_initial = jnp.ones((d_inner, d_state))
+     h_carry_initial = jnp.zeros((d_inner, d_state))
+     initial_carry = (A_carry_initial, h_carry_initial)
+
+     scan_inputs = (deltaA_chunked[:, -1], deltaB_u_chunked[:, -1])
+
+     _, (A_carry, h_carry) = jax.lax.scan(inter_chunk_step, initial_carry, scan_inputs)
+
+     A_carry = jnp.roll(A_carry, 1, axis=0)
+     h_carry = jnp.roll(h_carry, 1, axis=0)
+     A_carry = A_carry.at[0].set(jnp.ones_like(A_carry[0]))
+     h_carry = h_carry.at[0].set(jnp.zeros_like(h_carry[0]))
+
+     h_carry_broadcast = jnp.expand_dims(h_carry, axis=1)
+     h_correction = deltaA_chunked * h_carry_broadcast
+     y_carry = jnp.einsum("csdn, csdn -> csd", C_chunked, h_correction)
+
+     y_final = y_chunks + y_carry
+
+     y_final = y_final.reshape(padded_len, d_inner)
+
+     y_unpadded = y_final[:seq_len]
+
+     output = y_unpadded.real + u * D
+
+     return output
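The rewrite replaces the single per-timestep `jax.lax.scan` with a chunked scan: each chunk is scanned from a zero state in parallel (`jax.vmap` over `jax.lax.scan`), a second scan propagates a carry across chunk boundaries, and the carried state is folded back into each chunk's outputs as a correction term; padding handles sequence lengths that are not a multiple of `chunk_size`. A call sketch with assumed toy shapes, matching the new `(seq_length, d_inner, d_state)` layout of `B` and `C`:

```python
import jax
import jax.numpy as jnp

from jaxonlayers.functions.state_space import selective_scan

# Toy shapes; L is deliberately not a multiple of chunk_size to exercise padding.
L, d_inner, d_state = 200, 4, 8
ks = jax.random.split(jax.random.key(0), 5)
u = jax.random.normal(ks[0], (L, d_inner))
delta = jax.nn.softplus(jax.random.normal(ks[1], (L, d_inner)))
A = -jnp.exp(jax.random.normal(ks[2], (d_inner, d_state)))  # negative, i.e. stable
B = jax.random.normal(ks[3], (L, d_inner, d_state))
C = jax.random.normal(ks[4], (L, d_inner, d_state))
D = jnp.ones(d_inner)

y = selective_scan(u, delta, A, B, C, D, chunk_size=64)  # -> (200, 4)
```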
jaxonlayers/functions/utils.py
@@ -1,5 +1,7 @@
+ import equinox as eqx
  import jax
  import jax.numpy as jnp
+ from jaxtyping import PyTree
 
 
  def default_floating_dtype():
@@ -7,3 +9,50 @@ def default_floating_dtype():
          return jnp.float64
      else:
          return jnp.float32
+
+
+ def summarize_model(model: PyTree) -> str:
+     params, _ = eqx.partition(model, eqx.is_array)
+
+     param_counts = {}
+     total_params = 0
+
+     def count_params(pytree, name=""):
+         nonlocal total_params
+         count = 0
+         if isinstance(pytree, jnp.ndarray):
+             count = pytree.size
+             total_params += count
+             if name:
+                 param_counts[name] = count
+         elif hasattr(pytree, "__dict__"):
+             for key, value in pytree.__dict__.items():
+                 subname = f"{name}.{key}" if name else key
+                 count += count_params(value, subname)
+         elif isinstance(pytree, (list, tuple)):
+             for i, value in enumerate(pytree):
+                 subname = f"{name}[{i}]" if name else f"[{i}]"
+                 count += count_params(value, subname)
+         elif isinstance(pytree, dict):
+             for key, value in pytree.items():
+                 subname = f"{name}.{key}" if name else str(key)
+                 count += count_params(value, subname)
+         return count
+
+     count_params(params)
+
+     # Display as table
+     lines = []
+     lines.append("Model Parameter Summary")
+     lines.append("=" * 50)
+     lines.append(f"{'Parameter Name':<30} {'Count':<15}")
+     lines.append("-" * 50)
+
+     for name, count in param_counts.items():
+         lines.append(f"{name:<30} {count:<15,}")
+
+     lines.append("-" * 50)
+     lines.append(f"{'Total Parameters':<30} {total_params:<15,}")
+     lines.append("=" * 50)
+
+     return "\n".join(lines)
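A quick sketch of the new helper applied to a toy Equinox model; the MLP below is illustrative, not part of the package:

```python
import equinox as eqx
import jax

from jaxonlayers.functions.utils import summarize_model

model = eqx.nn.MLP(in_size=4, out_size=2, width_size=8, depth=2,
                   key=jax.random.key(0))
print(summarize_model(model))  # per-leaf parameter counts plus a total row
```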
jaxonlayers/layers/attention.py
@@ -38,8 +38,8 @@ class MultiheadAttention(eqx.Module):
 
      def __init__(
          self,
-         embed_dim,
-         num_heads,
+         embed_dim: int,
+         num_heads: int,
          dropout=0.0,
          bias=True,
          add_bias_kv=False,
jaxonlayers/layers/state_space.py
@@ -11,14 +11,17 @@ class SelectiveStateSpace(eqx.Module):
      input_proj: eqx.nn.Linear
      delta_proj: eqx.nn.Linear
      A_log: Float[Array, "d_inner d_state"]
-     D: Float[Array, " d_inner"]
+     D: Float[Array, "d_inner"]
+     out_proj: eqx.nn.Linear
 
      d_inner: int = eqx.field(static=True)
      dt_rank: int = eqx.field(static=True)
      d_state: int = eqx.field(static=True)
+     d_model: int = eqx.field(static=True)
 
      def __init__(
          self,
+         d_model: int,
          d_inner: int,
         dt_rank: int,
         d_state: int,
@@ -31,19 +34,19 @@ class SelectiveStateSpace(eqx.Module):
          if dtype is None:
              dtype = default_floating_dtype()
          assert dtype is not None
+
+         self.d_model = d_model
          self.d_inner = d_inner
          self.dt_rank = dt_rank
          self.d_state = d_state
-         (
-             key,
-             input_proj_key,
-             delta_proj_key,
-         ) = jax.random.split(key, 3)
+
+         keys = jax.random.split(key, 4)
+         proj_dim = self.dt_rank + 2 * self.d_inner * self.d_state
          self.input_proj = eqx.nn.Linear(
-             d_inner,
-             dt_rank + d_state * 2,
+             self.d_model,
+             proj_dim,
              use_bias=use_input_proj_bias,
-             key=input_proj_key,
+             key=keys[0],
              dtype=dtype,
          )
 
@@ -51,25 +54,37 @@ class SelectiveStateSpace(eqx.Module):
              dt_rank,
              d_inner,
              use_bias=use_delta_proj_bias,
-             key=delta_proj_key,
+             key=keys[1],
              dtype=dtype,
          )
-         A = jnp.repeat(jnp.arange(1, d_state + 1), d_inner).reshape(d_inner, d_state)
+
+         A = jnp.arange(1, d_state + 1, dtype=jnp.float32)
+         A = jnp.tile(A, (d_inner, 1))
          self.A_log = jnp.log(A)
+
          self.D = jnp.ones(d_inner, dtype=dtype)
 
+         self.out_proj = eqx.nn.Linear(
+             d_inner, d_model, use_bias=False, key=keys[2], dtype=dtype
+         )
+
      def __call__(self, x: Float[Array, "seq_length d_inner"]):
-         A = -jnp.exp(self.A_log)
-         D = self.D
+         L, _ = x.shape
+         A = -jnp.exp(self.A_log.astype(jnp.float32))
+         D = self.D.astype(jnp.float32)
 
          delta_b_c = jax.vmap(self.input_proj)(x)
+         delta, B, C = jnp.split(
+             delta_b_c,
+             [self.dt_rank, self.dt_rank + self.d_inner * self.d_state],
+             axis=-1,
+         )
+
+         B = B.reshape(L, self.d_inner, self.d_state)
+         C = C.reshape(L, self.d_inner, self.d_state)
 
-         split_indices = [
-             self.dt_rank,
-             self.dt_rank + self.d_state,
-         ]
-         delta, B, C = jnp.split(delta_b_c, split_indices, axis=-1)
          delta = jax.nn.softplus(jax.vmap(self.delta_proj)(delta))
 
          y = selective_scan(x, delta, A, B, C, D)
-         return y
+
+         return jax.vmap(self.out_proj)(y)
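A construction sketch for the updated layer, assuming the bias and dtype arguments keep their defaults; since `__call__` feeds `x` directly into `selective_scan`, the example keeps `d_model` equal to `d_inner`:

```python
import jax

from jaxonlayers.layers.state_space import SelectiveStateSpace

ssm = SelectiveStateSpace(
    d_model=64,
    d_inner=64,
    dt_rank=4,
    d_state=16,
    key=jax.random.key(0),
)
x = jax.random.normal(jax.random.key(1), (32, 64))  # (seq_length, d_model)
y = ssm(x)                                          # (32, 64) after out_proj
```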
jaxonlayers-0.1.2.dist-info/METADATA (added)
@@ -0,0 +1,16 @@
+ Metadata-Version: 2.4
+ Name: jaxonlayers
+ Version: 0.1.2
+ Summary: Add your description here
+ Requires-Python: >=3.13
+ Requires-Dist: beartype>=0.21.0
+ Requires-Dist: equinox>=0.13.0
+ Requires-Dist: jax>=0.8.0
+ Requires-Dist: jaxtyping>=0.3.2
+ Description-Content-Type: text/markdown
+
+ # jaxonlayers
+
+ This library provides utility functions and useful layers that extend the [Equinox](https://github.com/patrick-kidger/equinox) library.
+
+ The aim is for each layer to match the output of its PyTorch counterpart.
jaxonlayers-0.1.2.dist-info/RECORD
@@ -1,21 +1,21 @@
  jaxonlayers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  jaxonlayers/functions/__init__.py,sha256=lxMNSnEn2pJ1XLLyQTTFXiCcJTdrzfE1RlorqD9lEog,711
  jaxonlayers/functions/attention.py,sha256=AyaAeA2yo5Cgljk2rU6JlVttQeG0FMBLE-f7285PdM0,12945
+ jaxonlayers/functions/embedding.py,sha256=g76Wg_MSyqatwjWrhDSfAxDaMyr2s3A5AoST_cRa3Q8,601
  jaxonlayers/functions/initialization.py,sha256=h7uzdPl-rL7faT9hbRs5aN6EeaX0r70y88ad1cwZvmY,1161
  jaxonlayers/functions/masking.py,sha256=5VeHMuoVeKxAHtzGb74GqvKHfib4wJAETTAn1oaOgLM,1531
  jaxonlayers/functions/normalization.py,sha256=e2vNNbDz-Y6j5bgQshd4MshnzLcDOjUnGSinpRXtPtA,206
  jaxonlayers/functions/regularization.py,sha256=ltFVeJZFhdNrHsH0CjDl9Y3dlAQmh1ABntCjpYUBgXM,1317
- jaxonlayers/functions/state_space.py,sha256=_aIk4hXHZRtU2117z9qlfYjo-iAwA_dCkn9OpUEBtEk,831
- jaxonlayers/functions/utils.py,sha256=ahNtu8wf5A3wVjRUg6z5kPktDbEHqsak4sezRu6ap0E,184
+ jaxonlayers/functions/state_space.py,sha256=_PdykvDSRZWmIQR5KZTUN6shLhPGElio9THnoVpLJ_g,2747
+ jaxonlayers/functions/utils.py,sha256=M1uc01yOADy2ig6YsIyn0JBf_zdFuJS-rxin5RVub-A,1749
  jaxonlayers/layers/__init__.py,sha256=gjH0QUOCplii5KNYshWNehPfHZLVjEeYk8EhhY7FHRE,480
  jaxonlayers/layers/abstract.py,sha256=uyPKGsH5DETnjiU3PUiF98tk9boNZW96amxrM5JQlZY,366
- jaxonlayers/layers/attention.py,sha256=lAOIhzutQv1TcdECBJF8G95VoGxmPMBMe52S24H0nIo,7774
+ jaxonlayers/layers/attention.py,sha256=NdjDnA3yYFpvXD2_SjBwF-8Pdk_GeRLPuPekTR-iMwE,7784
  jaxonlayers/layers/convolution.py,sha256=k0dMFBDjzycB7UNuyHqKihJtBa6u93V6OLxyUUyipN4,3247
  jaxonlayers/layers/normalization.py,sha256=3aGzNzDN05A72ZHLUM2w9WpicLtGsjzj1l0jhuyn63U,8379
  jaxonlayers/layers/regularization.py,sha256=ZrvtBJPH84xuxrxEbZc7TBxjp8OvKEv4ecan5s8F9zs,563
  jaxonlayers/layers/sequential.py,sha256=Tw98hNZiXMC-CYZD6h_pi7eAxkgHeQAUvZF2I9H0d8Y,2833
- jaxonlayers/layers/state_space.py,sha256=Nesj2Ts3mCCqE-u7PeB8roJbQXUql7rG0AIpqUVMqvg,2131
- jaxonlayers-0.1.0.dist-info/METADATA,sha256=02mfySFIYtuSbEYqaZUZqTY7op6jWHdICubD77Tr_Cg,275
- jaxonlayers-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- jaxonlayers-0.1.0.dist-info/top_level.txt,sha256=n5UHFDErh3dJY77ypkEKlwFOQffKBnpGH9nUrwCinto,12
- jaxonlayers-0.1.0.dist-info/RECORD,,
+ jaxonlayers/layers/state_space.py,sha256=oDVRbduNtU48Q4rLd-XywZcqVN0QYTlq1UUhOXcGLoo,2537
+ jaxonlayers-0.1.2.dist-info/METADATA,sha256=NGl7HoPwh1sC8JmgOeBD3kVEJgy5tmTR7jhlKM9um7k,539
+ jaxonlayers-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ jaxonlayers-0.1.2.dist-info/RECORD,,
jaxonlayers-0.1.2.dist-info/WHEEL
@@ -1,5 +1,4 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.9.0)
+ Generator: hatchling 1.27.0
  Root-Is-Purelib: true
  Tag: py3-none-any
-
jaxonlayers-0.1.0.dist-info/METADATA (removed)
@@ -1,10 +0,0 @@
- Metadata-Version: 2.4
- Name: jaxonlayers
- Version: 0.1.0
- Summary: Add your description here
- Requires-Python: >=3.13
- Description-Content-Type: text/markdown
- Requires-Dist: beartype>=0.21.0
- Requires-Dist: equinox>=0.13.0
- Requires-Dist: jax>=0.7.2
- Requires-Dist: jaxtyping>=0.3.2
jaxonlayers-0.1.0.dist-info/top_level.txt (removed)
@@ -1 +0,0 @@
- jaxonlayers