jaxonlayers 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,27 @@
+ from .attention import multi_head_attention_forward, shifted_window_attention
+ from .initialization import kaiming_init_conv2d
+ from .masking import (
+     build_attention_mask,
+     canonical_attn_mask,
+     canonical_key_padding_mask,
+     canonical_mask,
+ )
+ from .regularization import dropout, stochastic_depth
+ from .state_space import selective_scan
+ from .utils import (
+     default_floating_dtype,
+ )
+
+ __all__ = [
+     "multi_head_attention_forward",
+     "kaiming_init_conv2d",
+     "build_attention_mask",
+     "canonical_attn_mask",
+     "canonical_key_padding_mask",
+     "canonical_mask",
+     "stochastic_depth",
+     "selective_scan",
+     "dropout",
+     "default_floating_dtype",
+     "shifted_window_attention",
+ ]
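These re-exports define the package's functional API. A minimal import sketch, assuming the file above is jaxonlayers/functions/__init__.py (the absolute imports in the attention module below point at that package):

import jax
from jaxonlayers.functions import dropout

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (16, 32))
y = dropout(x, p=0.1, inference=False, key=key)  # zeroes ~10% of entries, rescales the rest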
@@ -0,0 +1,374 @@
+ import functools
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ from beartype.typing import Any
+ from jaxtyping import Array, Bool, Float, PRNGKeyArray
+
+ from jaxonlayers.functions.normalization import normalize
+ from jaxonlayers.functions.regularization import dropout as dropout_fn
+ from jaxonlayers.functions.utils import default_floating_dtype
+
+
+ def multi_head_attention_forward(
+     query: Float[Array, "tgt_len d_model"],
+     key: Float[Array, "src_len d_model"],
+     value: Float[Array, "src_len d_model"],
+     embed_dim_to_check: int,
+     num_heads: int,
+     in_proj_weight: Float[Array, "3*d_model d_model"] | None = None,
+     in_proj_bias: Float[Array, "3*d_model"] | None = None,
+     bias_k: Float[Array, "1 d_model"] | None = None,
+     bias_v: Float[Array, "1 d_model"] | None = None,
+     add_zero_attn: bool = False,
+     dropout_p: float = 0.0,
+     out_proj_weight: Float[Array, "d_model d_model"] | None = None,
+     out_proj_bias: Float[Array, "d_model"] | None = None,
+     inference: bool = False,
+     key_padding_mask: Float[Array, "src_len"] | Bool[Array, "src_len"] | None = None,
+     attn_mask: Float[Array, "tgt_len src_len"] | None = None,
+     need_weights: bool = True,
+     use_separate_proj_weight: bool = False,
+     q_proj_weight: Float[Array, "d_model d_model"] | None = None,
+     k_proj_weight: Float[Array, "d_model d_model"] | None = None,
+     v_proj_weight: Float[Array, "d_model d_model"] | None = None,
+     static_k: Float[Array, "src_len d_model"] | None = None,
+     static_v: Float[Array, "src_len d_model"] | None = None,
+     average_attn_weights: bool = True,
+     is_causal: bool = False,
+     dropout_key: PRNGKeyArray | None = None,
+ ) -> tuple[
+     Float[Array, "tgt_len d_model"],
+     Float[Array, "num_heads tgt_len src_len"]
+     | Float[Array, "tgt_len src_len"]
+     | Float[Array, "tgt_len src_len+1"]
+     | None,
+ ]:
+     tgt_len, d_model = query.shape
+     src_len, k_dim = key.shape
+     value_len, v_dim = value.shape
+
+     assert d_model == k_dim == v_dim == embed_dim_to_check, (
+         "Embedding dimensions must match"
+     )
+
+     assert src_len == value_len, "Key and value must have the same sequence length"
+
+     head_dim = d_model // num_heads
+     assert head_dim * num_heads == d_model, "embed_dim must be divisible by num_heads"
+
+     if dropout_p > 0.0:
+         assert dropout_key is not None, (
+             "dropout_key must be provided if dropout_p > 0.0"
+         )
+
+     if use_separate_proj_weight:
+         # When using separate projection weights for q, k, v
+         assert q_proj_weight is not None, (
+             "q_proj_weight should not be None when use_separate_proj_weight=True"
+         )
+         assert k_proj_weight is not None, (
+             "k_proj_weight should not be None when use_separate_proj_weight=True"
+         )
+         assert v_proj_weight is not None, (
+             "v_proj_weight should not be None when use_separate_proj_weight=True"
+         )
+
+         q = query @ q_proj_weight.T
+
+         if static_k is None:
+             k = key @ k_proj_weight.T
+         else:
+             k = static_k
+             src_len, _ = k.shape
+
+         if static_v is None:
+             v = value @ v_proj_weight.T
+         else:
+             v = static_v
+             value_len, _ = v.shape
+
+         if in_proj_bias is not None:
+             q_bias, k_bias, v_bias = jnp.split(in_proj_bias, 3)
+             q = q + q_bias
+             k = k + k_bias
+             v = v + v_bias
+
+     else:
+         assert in_proj_weight is not None, (
+             "in_proj_weight should not be None when use_separate_proj_weight=False"
+         )
+
+         q_proj_weight_part, k_proj_weight_part, v_proj_weight_part = jnp.split(
+             in_proj_weight, 3
+         )
+
+         q = query @ q_proj_weight_part.T
+
+         if static_k is None:
+             k = key @ k_proj_weight_part.T
+         else:
+             k = static_k
+             src_len, _ = static_k.shape
+
+         if static_v is None:
+             v = value @ v_proj_weight_part.T
+         else:
+             v = static_v
+             value_len, _ = static_v.shape
+
+         if in_proj_bias is not None:
+             q_bias, k_bias, v_bias = jnp.split(in_proj_bias, 3)
+             q = q + q_bias
+             k = k + k_bias
+             v = v + v_bias
+
+     assert src_len == value_len
+
+     q = q.reshape(tgt_len, num_heads, head_dim)
+     k = k.reshape(src_len, num_heads, head_dim)
+     v = v.reshape(src_len, num_heads, head_dim)
+
+     if add_zero_attn:
+         zero_attn_shape = (1, num_heads, head_dim)
+         k_zeros = jnp.zeros(zero_attn_shape)
+         v_zeros = jnp.zeros(zero_attn_shape)
+
+         k = jnp.concatenate([k, k_zeros], axis=0)
+         v = jnp.concatenate([v, v_zeros], axis=0)
+
+         src_len += 1
+         value_len += 1
+
+     if bias_k is not None and bias_v is not None:
+         bias_k = bias_k.reshape(1, num_heads, head_dim)
+         bias_v = bias_v.reshape(1, num_heads, head_dim)
+
+         k = jnp.concatenate([k, bias_k], axis=0)
+         v = jnp.concatenate([v, bias_v], axis=0)
+
+         src_len += 1
+         value_len += 1
+
+     assert src_len == value_len
+
+     # [tgt_len, num_heads, head_dim] → [num_heads, tgt_len, head_dim]
+     q = jnp.transpose(q, (1, 0, 2))
+
+     # [src_len, num_heads, head_dim] → [num_heads, src_len, head_dim]
+     k = jnp.transpose(k, (1, 0, 2))
+     v = jnp.transpose(v, (1, 0, 2))
+
+     scale = jnp.sqrt(head_dim)
+     attn_output_weights = jnp.matmul(q, jnp.transpose(k, (0, 2, 1))) / scale
+
+     if key_padding_mask is not None:
+         # bias_k / add_zero_attn extend the key sequence above, so pad the mask
+         # with unmasked entries to keep its length equal to the current src_len.
+         pad_amount = src_len - key_padding_mask.shape[0]
+         if pad_amount > 0:
+             pad = jnp.zeros((pad_amount,), dtype=key_padding_mask.dtype)
+             key_padding_mask = jnp.concatenate([key_padding_mask, pad])
+         padding_mask = key_padding_mask.reshape(1, 1, src_len)
+         padding_mask = jnp.repeat(padding_mask, num_heads, axis=0)
+         padding_mask = jnp.repeat(padding_mask, tgt_len, axis=1)
+         attn_output_weights = jnp.where(
+             padding_mask, float("-inf"), attn_output_weights
+         )
+
+     if attn_mask is not None:
+         # [tgt_len, src_len] -> [num_heads, tgt_len, src_len]
+         mask = attn_mask.reshape(1, tgt_len, src_len)
+         mask = jnp.repeat(mask, num_heads, axis=0)
+         attn_output_weights = attn_output_weights + mask
+
+     if is_causal:
+         causal_mask = jnp.triu(jnp.ones((tgt_len, src_len)), k=1)
+         causal_mask = (causal_mask == 1).reshape(1, tgt_len, src_len)
+         causal_mask = jnp.repeat(causal_mask, num_heads, axis=0)
+         attn_output_weights = jnp.where(causal_mask, float("-inf"), attn_output_weights)
+
+     # [num_heads, tgt_len, src_len]
+     attn_output_weights = jax.nn.softmax(attn_output_weights, axis=-1)
+
+     if dropout_p > 0.0 and not inference:
+         assert dropout_key is not None, (
+             "dropout_key required because dropout_p > 0.0 and training"
+         )
+         dropout_mask = jax.random.bernoulli(
+             dropout_key, 1 - dropout_p, attn_output_weights.shape
+         )
+         scale = 1.0 / (1.0 - dropout_p)
+         attn_output_weights = attn_output_weights * dropout_mask * scale
+
+     attn_output = jnp.matmul(attn_output_weights, v)
+     attn_output = jnp.transpose(attn_output, (1, 0, 2))
+     attn_output = attn_output.reshape(tgt_len, d_model)
+
+     assert out_proj_weight is not None, "out_proj_weight must be provided"
+     attn_output = attn_output @ out_proj_weight.T
+
+     if out_proj_bias is not None:
+         attn_output = attn_output + out_proj_bias
+
+     if need_weights:
+         if average_attn_weights:
+             attn_output_weights = attn_output_weights.mean(axis=0)
+         return attn_output, attn_output_weights
+     else:
+         return attn_output, None
+
+
+ def create_attn_mask(pad_H, pad_W, window_size, shift_size, dtype: Any | None = None):
+     if dtype is None:
+         dtype = default_floating_dtype()
+     assert dtype is not None
+     h_boundaries = jnp.array([pad_H - window_size[0], pad_H - shift_size[0]])
+     w_boundaries = jnp.array([pad_W - window_size[1], pad_W - shift_size[1]])
+
+     h_boundaries = jnp.sort(h_boundaries)
+     w_boundaries = jnp.sort(w_boundaries)
+
+     ii, jj = jnp.indices((pad_H, pad_W))  # ii for rows, jj for columns
+
+     row_region_idx = jnp.searchsorted(h_boundaries, ii, side="right")
+     col_region_idx = jnp.searchsorted(w_boundaries, jj, side="right")
+
+     num_col_regions = len(w_boundaries) + 1
+     attn_mask = row_region_idx * num_col_regions + col_region_idx
+
+     return attn_mask.astype(dtype)
+
+
+ def shifted_window_attention(
+     x: Float[Array, "H W C"],
+     qkv_weight: Float[Array, "in_dim out_dim"],
+     proj_weight: Float[Array, "out_dim out_dim"],
+     relative_position_bias: Array,
+     window_size: list[int],
+     num_heads: int,
+     shift_size: list[int],
+     attention_dropout: float = 0.0,
+     dropout: float = 0.0,
+     qkv_bias: Array | None = None,
+     proj_bias: Array | None = None,
+     logit_scale: Array | None = None,
+     inference: bool = False,
+     key: PRNGKeyArray | None = None,
+ ) -> Float[Array, "H W C"]:
+     if not inference and key is None:
+         raise ValueError("Need key when in training mode")
+     H, W, C = x.shape
+     to_pad_W = (window_size[1] - W % window_size[1]) % window_size[1]
+     to_pad_H = (window_size[0] - H % window_size[0]) % window_size[0]
+     x = jnp.pad(x, ((0, to_pad_H), (0, to_pad_W), (0, 0)))
+     pad_H, pad_W, _ = x.shape
+
+     shift_size = shift_size.copy()
+     if window_size[0] >= pad_H:
+         shift_size[0] = 0
+     if window_size[1] >= pad_W:
+         shift_size[1] = 0
+
+     # cyclic shift
+     if sum(shift_size) > 0:
+         x = jnp.roll(x, shift=(-shift_size[0], -shift_size[1]), axis=(0, 1))
+
+     # partition windows
+     num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
+     x = jnp.reshape(
+         x,
+         (
+             pad_H // window_size[0],
+             window_size[0],
+             pad_W // window_size[1],
+             window_size[1],
+             C,
+         ),
+     )
+     x = jnp.transpose(x, (0, 2, 1, 3, 4)).reshape(
+         num_windows, window_size[0] * window_size[1], C
+     )
+
+     # multi-head attention
+     if logit_scale is not None and qkv_bias is not None:
+         length = qkv_bias.size // 3
+         qkv_bias = qkv_bias.at[length : 2 * length].set(0.0)
+
+     def linear(x: Array, weight: Array, bias: Array | None):
+         output = x @ jnp.transpose(weight)  # (in,) @ (in, out) -> (out,)
+         if bias is not None:
+             output = output + bias
+         return output
+
+     linear_pt = functools.partial(linear, weight=qkv_weight, bias=qkv_bias)
+
+     qkv = eqx.filter_vmap(eqx.filter_vmap(linear_pt))(x)
+     win_size, patches, _ = qkv.shape
+     qkv = jnp.transpose(
+         qkv.reshape(win_size, patches, 3, num_heads, C // num_heads), (2, 0, 3, 1, 4)
+     )
+     q, k, v = qkv[0], qkv[1], qkv[2]
+     if logit_scale is not None:
+         # cosine attention
+         attn = normalize(q, axis=-1) @ jnp.transpose(
+             normalize(k, axis=-1), (0, 1, 3, 2)
+         )
+         # Clamp the logit scale exponent for stability
+         logit_scale = jnp.exp(jnp.minimum(logit_scale, jnp.log(jnp.array(100.0))))
+         attn = attn * logit_scale
+     else:
+         q = q * (C // num_heads) ** -0.5
+         # attn = q @ (jnp.transpose(normalize(k, axis=-1), (0, 1, 3, 2)))  # Incorrect
+         attn = q @ jnp.transpose(k, (0, 1, 3, 2))  # Corrected: q @ k.T
+
+     # add relative position bias
+     attn = attn + relative_position_bias
+
+     if sum(shift_size) > 0:
+         attn_mask = create_attn_mask(pad_H, pad_W, window_size, shift_size)
+         attn_mask = attn_mask.reshape(
+             pad_H // window_size[0],
+             window_size[0],
+             pad_W // window_size[1],
+             window_size[1],
+         )
+         attn_mask = jnp.transpose(attn_mask, (0, 2, 1, 3)).reshape(
+             num_windows, window_size[0] * window_size[1]
+         )
+         attn_mask = jnp.expand_dims(attn_mask, axis=1) - jnp.expand_dims(
+             attn_mask, axis=2
+         )
+         attn_mask = jnp.where(attn_mask == 0, 0.0, -100.0)
+
+         attn = attn + attn_mask[:, None, :, :]
+
+     attn = jax.nn.softmax(attn, axis=-1)
+     if not inference:
+         assert key is not None, "key must be given if not inference"
+         key, subkey = jax.random.split(key)
+         attn = dropout_fn(attn, p=attention_dropout, inference=inference, key=subkey)
+
+     x = jnp.transpose(attn @ v, (0, 2, 1, 3)).reshape(
+         num_windows, window_size[0] * window_size[1], C
+     )
+     linear_pt_proj = functools.partial(linear, weight=proj_weight, bias=proj_bias)
+
+     x = eqx.filter_vmap(eqx.filter_vmap(linear_pt_proj))(x)
+     if not inference:
+         assert key is not None, "key must be given if not inference"
+         key, subkey = jax.random.split(key)
+         x = dropout_fn(x, p=dropout, inference=inference, key=subkey)
+
+     # reverse windows
+     x = x.reshape(
+         pad_H // window_size[0],
+         pad_W // window_size[1],
+         window_size[0],
+         window_size[1],
+         C,
+     )
+     x = jnp.transpose(x, (0, 2, 1, 3, 4)).reshape(pad_H, pad_W, C)
+
+     # reverse cyclic shift
+     if sum(shift_size) > 0:
+         x = jnp.roll(x, shift=(shift_size[0], shift_size[1]), axis=(0, 1))
+
+     # unpad features
+     x = x[:H, :W, :]
+     return x
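A minimal call sketch for multi_head_attention_forward, with randomly initialised projection weights purely to illustrate the expected shapes; the packed in_proj_weight layout (q, k, v stacked along the first axis) follows the jnp.split above:

import jax
import jax.numpy as jnp

d_model, num_heads, tgt_len, src_len = 16, 4, 5, 7
ks = jax.random.split(jax.random.PRNGKey(0), 5)
query = jax.random.normal(ks[0], (tgt_len, d_model))
key_seq = jax.random.normal(ks[1], (src_len, d_model))
value = jax.random.normal(ks[2], (src_len, d_model))
in_proj_weight = 0.02 * jax.random.normal(ks[3], (3 * d_model, d_model))
out_proj_weight = 0.02 * jax.random.normal(ks[4], (d_model, d_model))

out, weights = multi_head_attention_forward(
    query, key_seq, value,
    embed_dim_to_check=d_model,
    num_heads=num_heads,
    in_proj_weight=in_proj_weight,
    out_proj_weight=out_proj_weight,
)
# out: (tgt_len, d_model); weights: (tgt_len, src_len) because need_weights and
# average_attn_weights both default to True.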
@@ -0,0 +1,21 @@
+ import jax.numpy as jnp
+ from jaxtyping import Array, Float, Int
+
+
+ def sinusoidal_embedding(
+     t: Int[Array, ""], embedding_size: int
+ ) -> Float[Array, " embedding_size"]:
+     if embedding_size % 2 != 0:
+         raise ValueError(f"Embedding size must be even, but got {embedding_size}")
+
+     half_dim = embedding_size // 2
+     embedding_freqs = jnp.exp(
+         -jnp.log(10000)
+         * jnp.arange(start=0, stop=half_dim, dtype=jnp.float32)
+         / half_dim
+     )
+
+     time_args = t * embedding_freqs
+     embedding = jnp.concatenate([jnp.sin(time_args), jnp.cos(time_args)])
+
+     return embedding
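Usage sketch for sinusoidal_embedding as defined above; since it takes a scalar timestep, a batch of timesteps goes through jax.vmap:

import jax
import jax.numpy as jnp

emb = sinusoidal_embedding(jnp.asarray(42), embedding_size=32)
# emb.shape == (32,): 16 sine components followed by 16 cosine components

batched = jax.vmap(lambda t: sinusoidal_embedding(t, 32))(jnp.arange(10))
# batched.shape == (10, 32)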
@@ -0,0 +1,34 @@
+ import equinox as eqx
+ import jax
+ from jaxtyping import PRNGKeyArray, PyTree
+
+
+ def kaiming_init_conv2d(model: PyTree, state: eqx.nn.State, key: PRNGKeyArray):
+     """Applies Kaiming He normal initialization to Conv2d weights."""
+     # Filter function to identify Conv2d layers
+     is_conv2d = lambda x: isinstance(x, eqx.nn.Conv2d)
+
+     # Function to get weights (leaves) based on the filter
+     def get_weights(model):
+         return [
+             x.weight for x in jax.tree.leaves(model, is_leaf=is_conv2d) if is_conv2d(x)
+         ]
+
+     # Get the list of current weights
+     weights = get_weights(model)
+     if not weights:  # If no Conv2d layers found
+         return model, state
+
+     # Create new weights using He initializer
+     initializer = jax.nn.initializers.he_normal()
+     # Split key for each weight tensor
+     subkeys = jax.random.split(key, len(weights))
+     new_weights = [
+         initializer(subkeys[i], w.shape, w.dtype)  # Use original weight's dtype
+         for i, w in enumerate(weights)
+     ]
+
+     # Replace old weights with new weights in the model pytree
+     model = eqx.tree_at(get_weights, model, new_weights)
+
+     return model, state
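A usage sketch with a single eqx.nn.Conv2d; eqx.nn.make_with_state is assumed as the entry point for obtaining an eqx.nn.State, which this function only threads through unchanged:

import equinox as eqx
import jax

model_key, init_key = jax.random.split(jax.random.PRNGKey(0))
model, state = eqx.nn.make_with_state(eqx.nn.Conv2d)(3, 16, kernel_size=3, key=model_key)
model, state = kaiming_init_conv2d(model, state, init_key)
# model.weight is now drawn from jax.nn.initializers.he_normal(); state is returned unchanged.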
@@ -0,0 +1,58 @@
+ import jax.numpy as jnp
+
+
+ def canonical_mask(
+     mask,
+     mask_name,
+     other_name="",
+     other_type=None,
+     target_type=jnp.float32,
+     other_mask=None,
+     check_other=True,
+ ):
+     if mask is None:
+         return None
+     if mask.dtype == bool:
+         additive_mask = jnp.where(mask, -jnp.inf, 0.0).astype(target_type)
+         return additive_mask
+     elif jnp.issubdtype(mask.dtype, jnp.integer) or jnp.issubdtype(
+         mask.dtype, jnp.floating
+     ):
+         return mask.astype(target_type)
+     else:
+         raise TypeError(
+             f"{mask_name} must be bool, int, or float tensor, but got {mask.dtype}"
+         )
+
+
+ def canonical_key_padding_mask(
+     key_padding_mask, attn_mask=None, query_dtype=jnp.float32
+ ):
+     """Wrapper for canonicalizing key_padding_mask"""
+     return canonical_mask(
+         mask=key_padding_mask,
+         mask_name="key_padding_mask",
+         other_name="attn_mask",
+         other_mask=attn_mask,
+         target_type=query_dtype,
+     )
+
+
+ def canonical_attn_mask(attn_mask, query_dtype=jnp.float32):
+     """Wrapper for canonicalizing attn_mask"""
+     return canonical_mask(
+         mask=attn_mask,
+         mask_name="attn_mask",
+         other_type=None,
+         other_name="",
+         target_type=query_dtype,
+         check_other=False,
+     )
+
+
+ def build_attention_mask(context_length: int):
+     mask = jnp.tril(jnp.zeros((context_length, context_length)))
+     upper = jnp.triu(jnp.full((context_length, context_length), float("-inf")), k=1)
+
+     mask = mask + upper
+     return mask
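For reference, a short sketch of how these helpers behave with the definitions above:

import jax.numpy as jnp

bool_mask = jnp.array([False, True, False])
canonical_attn_mask(bool_mask)
# -> [0., -inf, 0.]  (boolean masks become additive masks)

build_attention_mask(3)
# -> [[  0., -inf, -inf],
#     [  0.,   0., -inf],
#     [  0.,   0.,   0.]]  (causal additive mask)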
@@ -0,0 +1,9 @@
+ import jax.numpy as jnp
+
+
+ def normalize(x, p=2, axis=1, eps=1e-12):
+     norm = jnp.linalg.norm(x, ord=p, axis=axis, keepdims=True)
+     norm = jnp.maximum(norm, eps)
+     output = x / norm
+
+     return output
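A quick sketch of the behaviour; the defaults (p=2, axis=1, eps=1e-12) appear to mirror torch.nn.functional.normalize:

import jax.numpy as jnp

x = jnp.array([[3.0, 4.0], [0.0, 0.0]])
normalize(x)
# -> [[0.6, 0.8],
#     [0. , 0. ]]  (each row scaled to unit L2 norm; zero rows stay zero via eps)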
@@ -0,0 +1,47 @@
+ import jax
+ import jax.numpy as jnp
+ from jaxtyping import Array, PRNGKeyArray
+
+
+ def stochastic_depth(
+     input: Array,
+     p: float,
+     mode: str,
+     inference: bool,
+     key: PRNGKeyArray,
+ ) -> Array:
+     if p < 0.0 or p > 1.0:
+         raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
+     if mode not in ["batch", "row"]:
+         raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
+     if inference or p == 0.0:
+         return input
+     survival_rate = 1.0 - p
+     if mode == "row":
+         size = [input.shape[0]] + [1] * (input.ndim - 1)
+     else:
+         size = [1] * input.ndim
+     noise = jax.random.bernoulli(key, p=survival_rate, shape=size).astype(input.dtype)
+     if survival_rate > 0.0:
+         noise = noise / survival_rate
+     return input * noise
+
+
+ def dropout(
+     x: Array,
+     p: float,
+     inference: bool,
+     key: PRNGKeyArray | None = None,
+ ) -> Array:
+     if isinstance(p, (int, float)) and p == 0:
+         inference = True
+     if inference:
+         return x
+     elif key is None:
+         raise RuntimeError(
+             "Dropout requires a key when running in non-deterministic mode."
+         )
+     else:
+         q = 1 - jax.lax.stop_gradient(p)
+         mask = jax.random.bernoulli(key, q, x.shape)
+         return jnp.where(mask, x / q, 0)
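Usage sketch for both helpers as defined above:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8))

# Inverted dropout: surviving entries are rescaled by 1 / (1 - p).
y = dropout(x, p=0.5, inference=False, key=key)

# Row-mode stochastic depth: each of the 4 leading-axis rows is either zeroed
# or rescaled by 1 / (1 - p) as a whole.
z = stochastic_depth(x, p=0.2, mode="row", inference=False, key=key)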
@@ -0,0 +1,79 @@
+ import jax
+ import jax.numpy as jnp
+ from jaxtyping import Array, Float
+
+
+ def selective_scan(
+     u: Float[Array, "seq_length d_inner"],
+     delta: Float[Array, "seq_length d_inner"],
+     A: Float[Array, "d_inner d_state"],
+     B: Float[Array, "seq_length d_inner d_state"],
+     C: Float[Array, "seq_length d_inner d_state"],
+     D: Float[Array, " d_inner"],
+     chunk_size: int = 128,
+ ) -> Float[Array, "seq_length d_inner"]:
+     deltaA = jnp.exp(jnp.einsum("l d, d n -> l d n", delta, A))
+     deltaB_u = jnp.einsum("l d, l d n, l d -> l d n", delta, B, u)
+
+     seq_len, d_inner = u.shape
+     d_state = A.shape[1]
+
+     num_chunks = (seq_len + chunk_size - 1) // chunk_size
+     padded_len = num_chunks * chunk_size
+
+     pad_len = padded_len - seq_len
+     deltaA_padded = jnp.pad(deltaA, ((0, pad_len), (0, 0), (0, 0)))
+     deltaB_u_padded = jnp.pad(deltaB_u, ((0, pad_len), (0, 0), (0, 0)))
+     C_padded = jnp.pad(C, ((0, pad_len), (0, 0), (0, 0)))
+
+     deltaA_chunked = deltaA_padded.reshape(num_chunks, chunk_size, d_inner, d_state)
+     deltaB_u_chunked = deltaB_u_padded.reshape(num_chunks, chunk_size, d_inner, d_state)
+     C_chunked = C_padded.reshape(num_chunks, chunk_size, d_inner, d_state)
+
+     # Cumulative product of deltaA inside each chunk; entry s is the total decay
+     # from the start of the chunk through position s, and the last entry is the
+     # decay applied across the whole chunk.
+     deltaA_cumprod = jnp.cumprod(deltaA_chunked, axis=1)
+
+     def intra_chunk_step(h_prev, scan_inputs):
+         deltaA_i, deltaB_u_i, C_i = scan_inputs
+         h_i = deltaA_i * h_prev + deltaB_u_i
+         y_i = jnp.einsum("d n, d n -> d", h_i, C_i)
+         return h_i, y_i
+
+     h0 = jnp.zeros((d_inner, d_state))
+
+     # Run the recurrence inside each chunk from a zero state. h_last is each
+     # chunk's final (zero-initialised) state; y_chunks is the chunk-local output.
+     h_last, y_chunks = jax.vmap(jax.lax.scan, in_axes=(None, None, 0))(
+         intra_chunk_step, h0, (deltaA_chunked, deltaB_u_chunked, C_chunked)
+     )
+
+     def inter_chunk_step(h_prev, scan_inputs):
+         A_chunk_i, h_last_i = scan_inputs
+         # State at the end of this chunk = decayed incoming state + local state.
+         h_new = A_chunk_i * h_prev + h_last_i
+         return h_new, h_new
+
+     h_carry_initial = jnp.zeros((d_inner, d_state))
+     scan_inputs = (deltaA_cumprod[:, -1], h_last)
+
+     _, h_carry = jax.lax.scan(inter_chunk_step, h_carry_initial, scan_inputs)
+
+     # h_carry[c] is the state at the end of chunk c; shift it so each chunk sees
+     # the state entering it (zero for the first chunk).
+     h_carry = jnp.roll(h_carry, 1, axis=0)
+     h_carry = h_carry.at[0].set(jnp.zeros_like(h_carry[0]))
+
+     # Propagate the incoming state through each chunk via the cumulative decay
+     # and add its contribution to the chunk-local outputs.
+     h_carry_broadcast = jnp.expand_dims(h_carry, axis=1)
+     h_correction = deltaA_cumprod * h_carry_broadcast
+     y_carry = jnp.einsum("csdn, csdn -> csd", C_chunked, h_correction)
+
+     y_final = y_chunks + y_carry
+
+     y_final = y_final.reshape(padded_len, d_inner)
+
+     y_unpadded = y_final[:seq_len]
+
+     output = y_unpadded.real + u * D
+
+     return output
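A shape-level usage sketch with random inputs, purely illustrative; A is chosen negative so that exp(delta * A) decays, as is conventional for Mamba-style selective state-space layers:

import jax
import jax.numpy as jnp

L, d_inner, d_state = 300, 8, 4
ks = jax.random.split(jax.random.PRNGKey(0), 4)
u = jax.random.normal(ks[0], (L, d_inner))
delta = jax.nn.softplus(jax.random.normal(ks[1], (L, d_inner)))
A = -jnp.exp(jax.random.normal(ks[2], (d_inner, d_state)))
B = jax.random.normal(ks[3], (L, d_inner, d_state))
C = jnp.ones((L, d_inner, d_state))
D = jnp.ones((d_inner,))

y = selective_scan(u, delta, A, B, C, D, chunk_size=128)
# y.shape == (300, 8); the sequence is padded to 3 chunks of 128 internally and unpadded at the end.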