lalamo 0.4.1-py3-none-any.whl → 0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/language_model.py +22 -23
- lalamo/main.py +2 -16
- lalamo/model_import/common.py +24 -6
- lalamo/model_import/decoder_configs/__init__.py +2 -0
- lalamo/model_import/decoder_configs/common.py +4 -4
- lalamo/model_import/decoder_configs/executorch.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
- lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
- lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
- lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
- lalamo/model_import/loaders/executorch.py +5 -4
- lalamo/model_import/loaders/huggingface.py +321 -69
- lalamo/model_import/model_specs/__init__.py +2 -0
- lalamo/model_import/model_specs/common.py +16 -5
- lalamo/model_import/model_specs/llamba.py +40 -0
- lalamo/model_import/model_specs/qwen.py +29 -1
- lalamo/modules/__init__.py +33 -6
- lalamo/modules/activations.py +9 -2
- lalamo/modules/common.py +10 -5
- lalamo/modules/decoder.py +93 -97
- lalamo/modules/decoder_layer.py +85 -103
- lalamo/modules/embedding.py +279 -5
- lalamo/modules/linear.py +335 -30
- lalamo/modules/mlp.py +6 -7
- lalamo/modules/mlx_interop.py +19 -0
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +30 -0
- lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
- lalamo/modules/token_mixers/common.py +78 -0
- lalamo/modules/token_mixers/mamba.py +553 -0
- lalamo/modules/token_mixers/state/__init__.py +12 -0
- lalamo/modules/token_mixers/state/common.py +26 -0
- lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
- lalamo/modules/token_mixers/state/mamba_state.py +51 -0
- lalamo/utils.py +24 -2
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
- lalamo-0.5.0.dist-info/RECORD +80 -0
- lalamo-0.4.1.dist-info/RECORD +0 -71
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ lalamo/modules/token_mixers/mamba.py
@@ -0,0 +1,553 @@
+import math
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from typing import NamedTuple, Self
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from einops import einsum, rearrange
+from jax import vmap
+from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+from lalamo.common import ParameterTree, dummy_array
+from lalamo.modules.activations import Activation
+from lalamo.modules.common import LalamoModule, PositionalEmbeddingSelector
+from lalamo.modules.linear import LinearBase, LinearConfig
+from lalamo.modules.rope import PositionalEmbeddings
+
+from .common import TokenMixerBase, TokenMixerConfigBase, TokenMixerResult
+from .state import Mamba2StateLayer
+
+__all__ = [
+    "Mamba2",
+    "Mamba2Config",
+    "Mamba2Result",
+    "SeparableCausalConv",
+    "SeparableCausalConvConfig",
+]
+
+
+Mamba2Result = TokenMixerResult[Mamba2StateLayer]
+
+
+class CausalConvResult(NamedTuple):
+    outputs: Float[Array, "*batch suffix_tokens channels"]
+    state: Float[Array, "*batch tokens channels"] | None = None
+
+
+@dataclass(frozen=True)
+class SeparableCausalConvConfig:
+    precision: DTypeLike
+    has_biases: bool
+
+    def random_init(
+        self,
+        input_dim: int,
+        kernel_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "SeparableCausalConv":
+        scale = 1 / math.sqrt(kernel_size * input_dim)
+        weights = jax.random.uniform(
+            key,
+            (input_dim, kernel_size),
+            minval=-scale,
+            maxval=scale,
+            dtype=self.precision,
+        )
+        if self.has_biases:
+            biases = jnp.zeros((input_dim,), dtype=self.precision)
+        else:
+            biases = None
+        return SeparableCausalConv(self, weights=weights, biases=biases)
+
+    def empty(
+        self,
+        input_dim: int,
+        kernel_size: int,
+    ) -> "SeparableCausalConv":
+        weights = dummy_array(
+            (input_dim, kernel_size),
+            dtype=self.precision,
+        )
+        if self.has_biases:
+            biases = dummy_array((input_dim,), dtype=self.precision)
+        else:
+            biases = None
+        return SeparableCausalConv(self, weights=weights, biases=biases)
+
+
+class SeparableCausalConv(LalamoModule[SeparableCausalConvConfig]):
+    weights: Float[Array, "channels kernel"]
+    biases: Float[Array, " channels"] | None
+
+    def __post_init__(self) -> None:
+        input_dim, _ = self.weights.shape
+        if self.biases is not None:
+            (output_dim,) = self.biases.shape
+            if output_dim != input_dim:
+                raise ValueError(
+                    f"Output dimension of biases ({output_dim}) must match input dimension ({input_dim})",
+                )
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.precision
+
+    @property
+    def input_dim(self) -> int:
+        input_dim, _ = self.weights.shape
+        return input_dim
+
+    @property
+    def kernel_size(self) -> int:
+        _, kernel_size = self.weights.shape
+        return kernel_size
+
+    @property
+    def has_biases(self) -> bool:
+        return self.biases is not None
+
+    def __call__(
+        self,
+        inputs: Float[Array, "suffix_tokens channels"],
+        state: Float[Array, "prefix_tokens channels"] | None = None,
+        return_updated_state: bool = False,
+    ) -> CausalConvResult:
+        num_suffix_tokens, input_dim = inputs.shape
+
+        if state is None:
+            state = jnp.zeros((self.kernel_size - 1, input_dim), dtype=inputs.dtype)
+
+        required_context = num_suffix_tokens + self.kernel_size - 1
+
+        inputs_with_history = jnp.concatenate([state, inputs], axis=0)
+        conv_outputs = jax.lax.conv_general_dilated(
+            inputs_with_history[None, -required_context:, :],
+            self.weights[:, :, None],
+            window_strides=(1,),
+            feature_group_count=input_dim,
+            padding="VALID",
+            dimension_numbers=("NTC", "OTI", "NTC"),
+        )
+
+        results = conv_outputs.squeeze(0)
+        if self.biases is not None:
+            results = results + self.biases
+
+        return CausalConvResult(
+            results,
+            (inputs_with_history if return_updated_state else None),
+        )
+
+    def export_weights(self) -> ParameterTree:
+        result: dict[str, Array] = {"weights": self.weights}
+        if self.biases is not None:
+            result["biases"] = self.biases
+        return result
+
+    def import_weights(self, weights: ParameterTree[Array]) -> "SeparableCausalConv":
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["weights"], Array)
+        if self.biases is not None:
+            assert isinstance(weights["biases"], Array)
+            biases = weights["biases"]
+        else:
+            biases = None
+        return replace(
+            self,
+            weights=weights["weights"],
+            biases=biases,
+        )
+
+
+@dataclass(frozen=True)
+class Mamba2Config(TokenMixerConfigBase):
+    in_projection_config: LinearConfig
+    out_projection_config: LinearConfig
+    conv_config: SeparableCausalConvConfig
+    activation: Activation
+
+    kernel_size: int
+    num_heads: int
+    num_groups: int
+    head_dim: int
+    state_dim: int
+    expansion_factor: int
+
+    has_in_biases: bool
+    has_out_biases: bool
+
+    @property
+    def inner_dim(self) -> int:
+        return self.num_heads * self.head_dim
+
+    @property
+    def rope_dim(self) -> int:
+        return self.head_dim
+
+    def random_init(
+        self,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "Mamba2":
+        in_key, out_key, conv_key, skip_key = jax.random.split(key, 4)
+
+        in_projection = self.in_projection_config.random_init(
+            input_dim=model_dim,
+            output_dims=(
+                self.inner_dim + 2 * self.num_groups * self.state_dim,
+                self.inner_dim,
+                self.num_heads,
+            ),
+            has_biases=self.has_in_biases,
+            key=in_key,
+        )
+
+        out_projection = self.out_projection_config.random_init(
+            self.inner_dim,
+            (model_dim,),
+            has_biases=self.has_out_biases,
+            key=out_key,
+        )
+
+        conv_channels = self.inner_dim + 2 * self.num_groups * self.state_dim
+        conv = self.conv_config.random_init(conv_channels, self.kernel_size, key=conv_key)
+
+        skip_connection_weight = jax.random.normal(
+            skip_key,
+            (self.num_heads,),
+            dtype=in_projection.activation_precision,
+        )
+
+        gate_bias = jnp.zeros((self.inner_dim,), dtype=in_projection.activation_precision)
+
+        return Mamba2(
+            self,
+            in_projection=in_projection,
+            conv=conv,
+            out_projection=out_projection,
+            skip_connection_weight=skip_connection_weight,
+            gate_bias=gate_bias,
+            num_heads=self.num_heads,
+            num_groups=self.num_groups,
+            head_dim=self.head_dim,
+            state_dim=self.state_dim,
+        )
+
+    def empty(
+        self,
+        model_dim: int,
+    ) -> "Mamba2":
+        in_projection = self.in_projection_config.empty(
+            input_dim=model_dim,
+            output_dims=(
+                self.inner_dim + 2 * self.num_groups * self.state_dim,
+                self.inner_dim,
+                self.num_heads,
+            ),
+            has_biases=self.has_in_biases,
+        )
+
+        out_projection = self.out_projection_config.empty(
+            self.inner_dim,
+            (model_dim,),
+            has_biases=self.has_out_biases,
+        )
+
+        conv_channels = self.inner_dim + 2 * self.num_groups * self.state_dim
+        conv = self.conv_config.empty(conv_channels, self.kernel_size)
+
+        skip_connection_weight = dummy_array((self.num_heads,), in_projection.activation_precision)
+        gate_bias = dummy_array((self.inner_dim,), in_projection.activation_precision)
+
+        return Mamba2(
+            self,
+            in_projection=in_projection,
+            conv=conv,
+            out_projection=out_projection,
+            skip_connection_weight=skip_connection_weight,
+            gate_bias=gate_bias,
+            num_heads=self.num_heads,
+            num_groups=self.num_groups,
+            head_dim=self.head_dim,
+            state_dim=self.state_dim,
+        )
+
+
+class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
+    in_projection: LinearBase
+    conv: SeparableCausalConv
+    out_projection: LinearBase
+
+    skip_connection_weight: Float[Array, " heads"]
+    gate_bias: Float[Array, " inner_channels"]
+
+    num_heads: int = eqx.field(static=True)
+    num_groups: int = eqx.field(static=True)
+    head_dim: int = eqx.field(static=True)
+    state_dim: int = eqx.field(static=True)
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.in_projection.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        return self.in_projection.input_dim
+
+    @property
+    def inner_dim(self) -> int:
+        return self.num_heads * self.head_dim
+
+    @property
+    def positional_embedding_selector(self) -> PositionalEmbeddingSelector:
+        return PositionalEmbeddingSelector.NONE
+
+    def __post_init__(self) -> None:
+        if self.skip_connection_weight.shape != (self.num_heads,):
+            raise ValueError(
+                f"Skip connection weight must have shape (num_heads,) = ({self.num_heads},), "
+                f"got {self.skip_connection_weight.shape}",
+            )
+        if self.gate_bias.shape != (self.inner_dim,):
+            raise ValueError(
+                f"Gate bias must have shape (inner_dim,) = ({self.inner_dim},), got {self.gate_bias.shape}",
+            )
+        if self.num_heads % self.num_groups != 0:
+            raise ValueError(
+                f"Number of value heads ({self.num_heads}) must be divisible by number of groups ({self.num_groups})",
+            )
+
+    def _scan(
+        self,
+        hidden_states: Float[Array, "suffix_tokens heads head_channels"],
+        input_projection: Float[Array, "suffix_tokens groups state_channels"],
+        output_projection: Float[Array, "suffix_tokens groups state_channels"],
+        time_delta_log: Float[Array, "suffix_tokens heads"],
+        initial_state: Float[Array, "heads head_channels state_channels"],
+        num_steps: Int[Array, ""] | int,
+    ) -> tuple[
+        Float[Array, "suffix_tokens heads head_channels"],
+        Float[Array, "heads head_channels state_channels"],
+    ]:
+        def scan_fn(
+            index_and_carry_state: tuple[Int[Array, ""], Float[Array, "heads head_channels state_channels"]],
+            step_inputs: tuple[
+                Float[Array, "heads head_channels"],
+                Float[Array, "groups state_channels"],
+                Float[Array, "groups state_channels"],
+                Float[Array, " heads"],
+            ],
+        ) -> tuple[
+            tuple[Int[Array, ""], Float[Array, "heads head_channels state_channels"]],
+            Float[Array, "heads head_channels"],
+        ]:
+            index, carry_state = index_and_carry_state
+            hidden_state_t, input_proj_t, output_proj_t, time_delta_log_t = step_inputs
+            dt = jax.nn.softplus(time_delta_log_t)[:, None]
+            heads_per_group = self.num_heads // self.num_groups
+
+            hidden_grouped = rearrange(
+                hidden_state_t,
+                "(groups heads) head_channels -> groups heads head_channels",
+                groups=self.num_groups,
+                heads=heads_per_group,
+            )
+            x_norm_grouped = hidden_grouped / (
+                dt.reshape(self.num_heads)[
+                    rearrange(
+                        jnp.arange(self.num_heads),
+                        "(groups heads)-> groups heads",
+                        groups=self.num_groups,
+                        heads=heads_per_group,
+                    )
+                ][:, :, None]
+                + 1e-8
+            )
+
+            decay = jnp.exp(-dt)[:, :, None]
+            mix = dt[:, :, None]
+            decay_group = rearrange(
+                decay,
+                "(groups heads) 1 1 -> groups heads 1 1",
+                groups=self.num_groups,
+                heads=heads_per_group,
+            )
+            mix_group = rearrange(
+                mix,
+                "(groups heads) 1 1 -> groups heads 1 1",
+                groups=self.num_groups,
+                heads=heads_per_group,
+            )
+
+            input_contribution_group = mix_group * x_norm_grouped[:, :, :, None] * input_proj_t[:, None, None, :]
+            carry_state_group = rearrange(
+                carry_state,
+                "(groups heads) head_channels state_channels -> groups heads head_channels state_channels",
+                groups=self.num_groups,
+                heads=heads_per_group,
+            )
+            updated_state_group = decay_group * carry_state_group + input_contribution_group
+
+            output_group = einsum(
+                updated_state_group,
+                output_proj_t,
+                "groups heads head_channels state_channels, groups state_channels -> groups heads head_channels",
+            )
+            updated_state = rearrange(
+                updated_state_group,
+                "groups heads head_channels state_channels -> (groups heads) head_channels state_channels",
+            )
+            output_t = rearrange(output_group, "groups heads head_channels -> (groups heads) head_channels")
+
+            propagated_state = jax.lax.cond(index < num_steps, lambda: updated_state, lambda: carry_state)
+
+            return (index + 1, propagated_state), output_t
+
+        (_, final_state), outputs = jax.lax.scan(
+            scan_fn,
+            (jnp.zeros((), dtype=jnp.int32), initial_state),
+            (hidden_states, input_projection, output_projection, time_delta_log),
+        )
+
+        return outputs, final_state
+
+    @eqx.filter_jit
+    def __call__(
+        self,
+        inputs: Float[Array, "suffix_tokens channels"],
+        positional_embeddings: PositionalEmbeddings | None,
+        state: Mamba2StateLayer | None = None,
+        return_updated_state: bool = False,
+        length_without_padding: Int[Array, ""] | int | None = None,
+    ) -> Mamba2Result:
+        if positional_embeddings is not None:
+            raise ValueError("Positional embeddings are not supported for Mamba2.")
+
+        conv_inputs, gate_values, time_delta_log = vmap(self.in_projection)(inputs)
+
+        if state is None:
+            state = Mamba2StateLayer.init(
+                self.config.kernel_size,
+                self.inner_dim,
+                self.num_heads,
+                self.num_groups,
+                self.head_dim,
+                self.state_dim,
+                self.activation_precision,
+            )
+
+        conv_output, updated_conv_state = self.conv(
+            conv_inputs,
+            state.conv_state,
+            return_updated_state=return_updated_state,
+        )
+        conv_activated = self.config.activation(conv_output)
+
+        x_channels, input_proj_channels, output_proj_channels = jnp.split(
+            conv_activated,
+            [
+                self.inner_dim,
+                self.inner_dim + self.num_groups * self.state_dim,
+            ],
+            axis=-1,
+        )
+
+        hidden_states = rearrange(
+            x_channels,
+            "suffix_tokens (heads head_channels) -> suffix_tokens heads head_channels",
+            heads=self.num_heads,
+        )
+        input_projection = rearrange(
+            input_proj_channels,
+            "suffix_tokens (groups state_channels) -> suffix_tokens groups state_channels",
+            groups=self.num_groups,
+        )
+        output_projection = rearrange(
+            output_proj_channels,
+            "suffix_tokens (groups state_channels) -> suffix_tokens groups state_channels",
+            groups=self.num_groups,
+        )
+        time_delta_log = rearrange(
+            time_delta_log,
+            "suffix_tokens heads -> suffix_tokens heads",
+            heads=self.num_heads,
+        )
+
+        if length_without_padding is None:
+            length_without_padding, _ = inputs.shape
+
+        ssm_outputs, final_ssm_state = self._scan(
+            hidden_states,
+            input_projection,
+            output_projection,
+            time_delta_log,
+            state.ssm_state,
+            length_without_padding,
+        )
+
+        skip_contribution = self.skip_connection_weight[None, :, None] * hidden_states
+        ssm_outputs = ssm_outputs + skip_contribution
+
+        ssm_outputs = rearrange(
+            ssm_outputs,
+            "suffix_tokens heads head_channels -> suffix_tokens (heads head_channels)",
+        )
+
+        gated_outputs = ssm_outputs * jax.nn.silu(gate_values + self.gate_bias)
+
+        (outputs,) = vmap(self.out_projection)(gated_outputs)
+
+        if return_updated_state:
+            assert updated_conv_state is not None
+            updated_state = Mamba2StateLayer(updated_conv_state, final_ssm_state)
+        else:
+            updated_state = None
+
+        return Mamba2Result(
+            outputs=outputs,
+            state=updated_state,
+        )
+
+    def init_static_state(self, capacity: int) -> Mamba2StateLayer:  # noqa: ARG002
+        return Mamba2StateLayer.init(
+            self.config.kernel_size,
+            self.inner_dim,
+            self.num_heads,
+            self.num_groups,
+            self.head_dim,
+            self.state_dim,
+            self.activation_precision,
+        )
+
+    def export_weights(self) -> ParameterTree:
+        return {
+            "in_projection": self.in_projection.export_weights(),
+            "out_projection": self.out_projection.export_weights(),
+            "conv": self.conv.export_weights(),
+            "skip_connection_weight": self.skip_connection_weight,
+            "gate_bias": self.gate_bias,
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["in_projection"], Mapping)
+        assert isinstance(weights["out_projection"], Mapping)
+        assert isinstance(weights["conv"], Mapping)
+        assert isinstance(weights["skip_connection_weight"], Array)
+        assert isinstance(weights["gate_bias"], Array)
+
+        return replace(
+            self,
+            in_projection=self.in_projection.import_weights(weights["in_projection"]),
+            out_projection=self.out_projection.import_weights(weights["out_projection"]),
+            conv=self.conv.import_weights(weights["conv"]),
+            skip_connection_weight=weights["skip_connection_weight"],
+            gate_bias=weights["gate_bias"],
+        )
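Not part of the package diff: the per-token recurrence that `Mamba2._scan` implements above is easy to lose inside the grouped `rearrange` calls, so here is a minimal single-head, single-group reference in plain JAX. It assumes one head and one group so the grouping reshapes drop out; up to the `1e-8` stabilizer, the `dt * (x / dt) * B` product in `_scan` reduces to `x * B`, which is what this sketch uses. All names here are illustrative, not lalamo API.

```python
import jax
import jax.numpy as jnp


def reference_step(state, step_inputs):
    """One SSM step: decay the carried state, mix in the new token, read it out."""
    x_t, b_t, c_t, dt_log_t = step_inputs  # (head_dim,), (state_dim,), (state_dim,), ()
    dt = jax.nn.softplus(dt_log_t)         # same time-step transform as _scan
    decay = jnp.exp(-dt)                   # per-step state decay, exp(-dt)
    # dt * (x / dt) outer B reduces to x outer B (up to the 1e-8 stabilizer in _scan)
    new_state = decay * state + jnp.outer(x_t, b_t)
    y_t = new_state @ c_t                  # readout against C; the D * x skip is added outside the scan
    return new_state, y_t


def reference_scan(xs, bs, cs, dt_logs):
    init = jnp.zeros((xs.shape[-1], bs.shape[-1]))  # (head_dim, state_dim) carried state
    final_state, ys = jax.lax.scan(reference_step, init, (xs, bs, cs, dt_logs))
    return ys, final_state
```

In the module above, the same recurrence runs per head, the `skip_connection_weight * hidden_states` term plays the role of `D * x`, and the result is gated by `silu(gate_values + gate_bias)` before the output projection.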
--- /dev/null
+++ lalamo/modules/token_mixers/state/__init__.py
@@ -0,0 +1,12 @@
+from .common import State, StateLayerBase
+from .kv_cache import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer
+from .mamba_state import Mamba2StateLayer
+
+__all__ = [
+    "DynamicKVCacheLayer",
+    "KVCacheLayer",
+    "Mamba2StateLayer",
+    "State",
+    "StateLayerBase",
+    "StaticKVCacheLayer",
+]
--- /dev/null
+++ lalamo/modules/token_mixers/state/common.py
@@ -0,0 +1,26 @@
+from abc import abstractmethod
+from typing import Self
+
+import equinox as eqx
+from jax.tree_util import register_pytree_node_class
+
+from lalamo.common import ParameterTree
+
+__all__ = ["State", "StateLayerBase"]
+
+
+class StateLayerBase(eqx.Module):
+    @abstractmethod
+    def export(self) -> ParameterTree: ...
+
+
+@register_pytree_node_class
+class State(tuple[StateLayerBase, ...]):
+    __slots__ = ()
+
+    def tree_flatten(self) -> tuple[tuple[StateLayerBase, ...], None]:
+        return (tuple(self), None)
+
+    @classmethod
+    def tree_unflatten(cls, aux_data: None, children: tuple[StateLayerBase, ...]) -> Self:  # noqa: ARG003
+        return cls(children)
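Not part of the diff: `State` is a plain tuple registered as a JAX pytree node, so per-layer states of different kinds (KV caches, Mamba2 states) can be carried together through `jit` and `scan`. A small sketch of that round trip, using a hypothetical `ToyStateLayer` defined only for illustration and assuming the new module path from the file list above:

```python
import jax
import jax.numpy as jnp
from jaxtyping import Array

from lalamo.modules.token_mixers.state import State, StateLayerBase


class ToyStateLayer(StateLayerBase):  # hypothetical layer, for illustration only
    value: Array

    def export(self):
        return {"value": self.value}


state = State((ToyStateLayer(jnp.zeros(4)), ToyStateLayer(jnp.ones(4))))

# Because State is a registered pytree node, it flattens to its array leaves
# and reconstructs back into a State, so it can cross jit/scan boundaries.
leaves, treedef = jax.tree_util.tree_flatten(state)
assert isinstance(jax.tree_util.tree_unflatten(treedef, leaves), State)
```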
--- lalamo/modules/kv_cache.py
+++ lalamo/modules/token_mixers/state/kv_cache.py
@@ -4,15 +4,16 @@ from typing import Self
 import equinox as eqx
 import jax.numpy as jnp
 from jax.lax import dynamic_update_slice_in_dim
-from jax.tree_util import register_pytree_node_class
 from jaxtyping import Array, Bool, DTypeLike, Float, Int
 
 from lalamo.common import ParameterTree
 
-
+from .common import StateLayerBase
 
+__all__ = ["DynamicKVCacheLayer", "KVCacheLayer", "StaticKVCacheLayer"]
 
-class KVCacheLayer(eqx.Module):
+
+class KVCacheLayer(StateLayerBase):
     has_sinks: bool = eqx.field(static=True)
     keys: Float[Array, "*batch tokens groups head_channels"]
     values: Float[Array, "*batch tokens groups head_channels"]
@@ -58,18 +59,6 @@ class KVCacheLayer(eqx.Module):
         )
 
 
-@register_pytree_node_class
-class KVCache(tuple[KVCacheLayer, ...]):
-    __slots__ = ()
-
-    def tree_flatten(self) -> tuple[tuple[KVCacheLayer, ...], None]:
-        return (tuple(self), None)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data: None, children: tuple[KVCacheLayer, ...]) -> Self:  # noqa: ARG003
-        return cls(children)
-
-
 class DynamicKVCacheLayer(KVCacheLayer):
     padding_mask: Bool[Array, " tokens"] | None = None
 
@@ -224,7 +213,7 @@ class StaticKVCacheLayer(KVCacheLayer):
         )
 
     @classmethod
-    def
+    def init(cls, has_sinks: bool, capacity: int, num_groups: int, head_dim: int, dtype: DTypeLike) -> Self:
         return cls(
             has_sinks=has_sinks,
             keys=jnp.zeros((capacity, num_groups, head_dim), dtype=dtype),
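Not part of the diff: with `KVCache` removed, per-layer caches are wrapped in the generic `State` container from `state/common.py`, and `StaticKVCacheLayer.init` now takes `has_sinks` as its leading argument. A hedged sketch of the updated call, assuming the new import path; the concrete sizes are placeholders:

```python
import jax.numpy as jnp

from lalamo.modules.token_mixers.state import State, StaticKVCacheLayer

# Zero-filled static cache for one layer; argument names follow the new
# init signature shown in the hunk above.
cache_layer = StaticKVCacheLayer.init(
    has_sinks=False,  # new leading argument in 0.5.0
    capacity=1024,
    num_groups=8,
    head_dim=128,
    dtype=jnp.bfloat16,
)

# Per-layer states now travel in the generic State tuple instead of KVCache.
state = State((cache_layer,))
```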
--- /dev/null
+++ lalamo/modules/token_mixers/state/mamba_state.py
@@ -0,0 +1,51 @@
+from typing import Self
+
+import jax.numpy as jnp
+from jaxtyping import Array, DTypeLike, Float
+
+from lalamo.common import ParameterTree
+
+from .common import StateLayerBase
+
+__all__ = ["Mamba2StateLayer"]
+
+
+class Mamba2StateLayer(StateLayerBase):
+    conv_state: Float[Array, "*batch tokens conv_channels"]
+    ssm_state: Float[Array, "*batch groups head_channels state_channels"]
+
+    def __post_init__(self) -> None:
+        if self.conv_state.ndim not in (2, 3):
+            raise ValueError(
+                f"Conv state must have 2 or 3 dimensions: [batch], tokens, conv_channels,"
+                f" got shape {self.conv_state.shape}",
+            )
+        if self.ssm_state.ndim not in (3, 4):
+            raise ValueError(
+                f"SSM state must have 3 or 4 dimensions: [batch], groups, head_channels, state_channels,"
+                f" got shape {self.ssm_state.shape}",
+            )
+        if self.conv_state.dtype != self.ssm_state.dtype:
+            raise ValueError("Conv state and SSM state must have the same dtype")
+
+    @classmethod
+    def init(
+        cls,
+        kernel_size: int,
+        inner_dim: int,
+        num_heads: int,
+        num_groups: int,
+        head_dim: int,
+        state_dim: int,
+        dtype: DTypeLike,
+    ) -> Self:
+        return cls(
+            conv_state=jnp.zeros((kernel_size - 1, inner_dim + 2 * num_groups * state_dim), dtype=dtype),
+            ssm_state=jnp.zeros((num_heads, head_dim, state_dim), dtype=dtype),
+        )
+
+    def export(self) -> ParameterTree:
+        return dict(
+            conv_state=self.conv_state,
+            ssm_state=self.ssm_state,
+        )
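Not part of the diff: a quick shape check of the zero-initialized Mamba2 state defined above, with illustrative sizes chosen so that `inner_dim == num_heads * head_dim`, as `Mamba2Config.inner_dim` requires. The import path follows the new file layout and is an assumption.

```python
import jax.numpy as jnp

from lalamo.modules.token_mixers.state import Mamba2StateLayer

layer = Mamba2StateLayer.init(
    kernel_size=4,
    inner_dim=256,  # num_heads * head_dim
    num_heads=8,
    num_groups=1,
    head_dim=32,
    state_dim=128,
    dtype=jnp.float32,
)

# conv_state rolls the last kernel_size - 1 tokens of the conv input channels;
# ssm_state holds one (head_dim, state_dim) matrix per head.
assert layer.conv_state.shape == (3, 256 + 2 * 1 * 128)
assert layer.ssm_state.shape == (8, 32, 128)
assert layer.export().keys() == {"conv_state", "ssm_state"}
```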