lalamo 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/model_import/__init__.py +8 -0
- lalamo/model_import/common.py +111 -0
- lalamo/model_import/configs/__init__.py +23 -0
- lalamo/model_import/configs/common.py +62 -0
- lalamo/model_import/configs/executorch.py +166 -0
- lalamo/model_import/configs/huggingface/__init__.py +18 -0
- lalamo/model_import/configs/huggingface/common.py +72 -0
- lalamo/model_import/configs/huggingface/gemma2.py +122 -0
- lalamo/model_import/configs/huggingface/gemma3.py +187 -0
- lalamo/model_import/configs/huggingface/llama.py +155 -0
- lalamo/model_import/configs/huggingface/mistral.py +132 -0
- lalamo/model_import/configs/huggingface/qwen2.py +144 -0
- lalamo/model_import/configs/huggingface/qwen3.py +142 -0
- lalamo/model_import/loaders/__init__.py +7 -0
- lalamo/model_import/loaders/common.py +45 -0
- lalamo/model_import/loaders/executorch.py +223 -0
- lalamo/model_import/loaders/huggingface.py +304 -0
- lalamo/model_import/model_specs/__init__.py +38 -0
- lalamo/model_import/model_specs/common.py +118 -0
- lalamo/model_import/model_specs/deepseek.py +28 -0
- lalamo/model_import/model_specs/gemma.py +76 -0
- lalamo/model_import/model_specs/huggingface.py +28 -0
- lalamo/model_import/model_specs/llama.py +101 -0
- lalamo/model_import/model_specs/mistral.py +59 -0
- lalamo/model_import/model_specs/pleias.py +28 -0
- lalamo/model_import/model_specs/polaris.py +22 -0
- lalamo/model_import/model_specs/qwen.py +336 -0
- lalamo/model_import/model_specs/reka.py +28 -0
- lalamo/modules/__init__.py +85 -0
- lalamo/modules/activations.py +30 -0
- lalamo/modules/attention.py +326 -0
- lalamo/modules/common.py +133 -0
- lalamo/modules/decoder.py +244 -0
- lalamo/modules/decoder_layer.py +240 -0
- lalamo/modules/embedding.py +299 -0
- lalamo/modules/kv_cache.py +196 -0
- lalamo/modules/linear.py +603 -0
- lalamo/modules/mlp.py +79 -0
- lalamo/modules/normalization.py +77 -0
- lalamo/modules/rope.py +255 -0
- lalamo/modules/utils.py +13 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/METADATA +1 -1
- lalamo-0.2.2.dist-info/RECORD +53 -0
- lalamo-0.2.1.dist-info/RECORD +0 -12
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/WHEEL +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/entry_points.txt +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/top_level.txt +0 -0
lalamo/modules/attention.py
ADDED
@@ -0,0 +1,326 @@
from dataclasses import dataclass
from typing import NamedTuple

import equinox as eqx
import jax
from einops import einsum, rearrange, repeat
from jax import numpy as jnp
from jax import vmap
from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray

from lalamo.common import ParameterDict
from lalamo.modules.normalization import RMSNorm, RMSNormConfig

from .common import AttentionType, LalamoModule, WeightLayout
from .kv_cache import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer
from .linear import LinearBase, LinearConfig
from .rope import PositionalEmbeddings
from .utils import apply_soft_capping

__all__ = [
    "Attention",
    "AttentionConfig",
]


def _repeat_kv(
    keys_or_values: Float[Array, "tokens groups channels"],
    group_size: int,
) -> Float[Array, "tokens groups*group_size channels"]:
    return repeat(
        keys_or_values,
        "tokens groups channels -> tokens (groups group_size) channels",
        group_size=group_size,
    )


def _soft_capped_attention_kernel(
    queries: Float[Array, "dst_tokens heads head_channels"],
    keys: Float[Array, "src_tokens groups head_channels"],
    values: Float[Array, "src_tokens groups head_channels"],
    mask: Bool[Array, "dst_tokens src_tokens"] | None,
    scale: float | None,
    logit_soft_cap: float,
) -> Float[Array, "dst_tokens heads head_channels"]:
    dst_length, num_heads, head_dim = queries.shape
    src_length, num_groups, _ = keys.shape
    if scale is None:
        scale = head_dim**-0.5
    group_size = num_heads // num_groups
    keys = _repeat_kv(keys, group_size)
    values = _repeat_kv(values, group_size)
    queries_head_first = rearrange(queries, "dst_tokens heads channels -> heads dst_tokens channels")
    keys_head_first = rearrange(keys, "src_tokens heads channels -> heads src_tokens channels")
    attention_logits = einsum(
        queries_head_first,
        keys_head_first,
        "heads dst_tokens channels, heads src_tokens channels -> heads dst_tokens src_tokens",
    )
    if mask is not None:
        attention_logits = jnp.where(mask, attention_logits, jnp.array(float("-inf"), dtype=attention_logits.dtype))

    attention_logits = attention_logits * scale
    attention_logits = apply_soft_capping(attention_logits, logit_soft_cap)
    attention_weights = jax.nn.softmax(attention_logits, axis=-1)
    return einsum(
        attention_weights,
        values,
        "heads dst_tokens src_tokens, src_tokens heads channels -> dst_tokens heads channels",
    )


class AttentionResult(NamedTuple):
    outputs: Float[Array, "suffix_tokens channels"]
    kv_cache: KVCacheLayer | None = None


@dataclass(frozen=True)
class AttentionConfig:
    qkv_projection_config: LinearConfig
    out_projection_config: LinearConfig

    query_norm_config: RMSNormConfig | None
    key_norm_config: RMSNormConfig | None

    logit_soft_cap: float | None
    has_qkv_biases: bool
    has_out_biases: bool

    def random_init(
        self,
        model_dim: int,
        num_heads: int,
        num_groups: int,
        head_dim: int,
        is_causal: bool,
        scale: float | None,
        sliding_window_size: int | None,
        *,
        key: PRNGKeyArray,
    ) -> "Attention":
        qkv_key, out_key = jax.random.split(key)
        qkv_projection = self.qkv_projection_config.random_init(
            input_dim=model_dim,
            output_dims=(
                num_heads * head_dim,
                num_groups * head_dim,
                num_groups * head_dim,
            ),
            has_biases=self.has_qkv_biases,
            key=qkv_key,
        )
        out_projection = self.out_projection_config.random_init(
            num_heads * head_dim,
            (model_dim,),
            has_biases=self.has_out_biases,
            key=out_key,
        )

        if self.query_norm_config is not None:
            query_norm = self.query_norm_config.init(
                channels=head_dim,
            )
        else:
            query_norm = None

        if self.key_norm_config is not None:
            key_norm = self.key_norm_config.init(
                channels=head_dim,
            )
        else:
            key_norm = None

        return Attention(
            self,
            qkv_projection=qkv_projection,
            out_projection=out_projection,
            query_norm=query_norm,
            key_norm=key_norm,
            num_heads=num_heads,
            num_groups=num_groups,
            head_dim=head_dim,
            is_causal=is_causal,
            scale=scale,
            sliding_window_size=sliding_window_size,
        )


class Attention(LalamoModule[AttentionConfig]):
    qkv_projection: LinearBase
    out_projection: LinearBase

    query_norm: RMSNorm | None
    key_norm: RMSNorm | None

    num_heads: int = eqx.field(static=True)
    num_groups: int = eqx.field(static=True)
    head_dim: int = eqx.field(static=True)

    is_causal: bool = eqx.field(static=True)

    scale: float | None = eqx.field(static=True)
    sliding_window_size: int | None = eqx.field(static=True)

    @property
    def activation_precision(self) -> DTypeLike:
        return self.qkv_projection.activation_precision

    @property
    def model_dim(self) -> int:
        return self.qkv_projection.input_dim

    @property
    def group_size(self) -> int:
        return self.num_heads // self.num_groups

    @property
    def use_sliding_window(self) -> bool:
        return self.sliding_window_size is not None

    @property
    def attention_type(self) -> AttentionType:
        return AttentionType.SLIDING_WINDOW if self.sliding_window_size is not None else AttentionType.GLOBAL

    def __post_init__(self) -> None:
        if self.qkv_projection.has_biases != self.config.has_qkv_biases:
            raise ValueError(
                f"QKV projection has_biases {self.qkv_projection.has_biases} does not match"
                f" the specified config has_qkv_biases {self.config.has_qkv_biases}",
            )
        if self.out_projection.has_biases != self.config.has_out_biases:
            raise ValueError(
                f"Output projection has_biases {self.out_projection.has_biases} does not match"
                f" the specified config has_out_biases {self.config.has_out_biases}",
            )
        if self.query_norm is not None and self.query_norm.input_dim != self.head_dim:
            raise ValueError(
                f"Query normalization input dimension must match head_dim ({self.head_dim}),"
                f" got {self.query_norm.input_dim}",
            )
        if self.key_norm is not None and self.key_norm.input_dim != self.head_dim:
            raise ValueError(
                f"Key normalization input dimension must match head_dim ({self.head_dim}),"
                f" got {self.key_norm.input_dim}",
            )
        if self.num_heads % self.num_groups != 0:
            raise ValueError(
                "Number of heads must be divisible by the number of groups,"
                f" got {self.num_heads} heads and {self.num_groups} groups",
            )
        if self.out_projection.input_dim != self.num_heads * self.head_dim:
            raise ValueError(
                f"Output projection input dimension must be num_heads * head_dim"
                f" ({self.num_heads} * {self.head_dim} = {self.num_heads * self.head_dim}),"
                f" got {self.out_projection.input_dim}",
            )
        q_output_dim, k_output_dim, v_output_dim = self.qkv_projection.output_dims
        if q_output_dim != self.num_heads * self.head_dim:
            raise ValueError(
                f"Query projection output dimension must be num_heads * head_dim"
                f" ({self.num_heads} * {self.head_dim} = {self.num_heads * self.head_dim}),"
                f" got {q_output_dim}",
            )
        if k_output_dim != self.num_groups * self.head_dim:
            raise ValueError(
                f"Key projection output dimension must be num_groups * head_dim"
                f" ({self.num_groups} * {self.head_dim} = {self.num_groups * self.head_dim}),"
                f" got {k_output_dim}",
            )
        if v_output_dim != self.num_groups * self.head_dim:
            raise ValueError(
                f"Value projection output dimension must be num_groups * head_dim"
                f" ({self.num_groups} * {self.head_dim} = {self.num_groups * self.head_dim}),"
                f" got {v_output_dim}",
            )

    def __call__(
        self,
        inputs: Float[Array, "suffix_tokens channels"],
        positional_embeddings: PositionalEmbeddings,
        kv_cache: KVCacheLayer | None = None,
        return_updated_kv_cache: bool = False,
        length_without_padding: Int[Array, ""] | int | None = None,
    ) -> AttentionResult:
        queries, keys, values = vmap(self.qkv_projection, in_axes=0)(inputs)
        queries = rearrange(
            queries,
            "tokens (heads head_channels) -> tokens heads head_channels",
            heads=self.num_heads,
            head_channels=self.head_dim,
        )
        keys = rearrange(
            keys,
            "tokens (groups head_channels) -> tokens groups head_channels",
            groups=self.num_groups,
            head_channels=self.head_dim,
        )
        values = rearrange(
            values,
            "tokens (groups head_channels) -> tokens groups head_channels",
            groups=self.num_groups,
            head_channels=self.head_dim,
        )

        if self.query_norm is not None:
            queries = vmap(vmap(self.query_norm))(queries)
        if self.key_norm is not None:
            keys = vmap(vmap(self.key_norm))(keys)

        apply_positional_embeddings = vmap(positional_embeddings.apply, in_axes=1, out_axes=1)
        queries = apply_positional_embeddings(queries)
        keys = apply_positional_embeddings(keys)

        if kv_cache is None:
            updated_kv_cache = DynamicKVCacheLayer.init(keys, values, length=length_without_padding)
        else:
            updated_kv_cache = kv_cache.extend(keys, values, added_length=length_without_padding)

        num_suffix_tokens, _, _ = queries.shape
        mask = updated_kv_cache.attention_mask(num_suffix_tokens, self.is_causal, self.sliding_window_size)

        if self.config.logit_soft_cap is not None:
            attention_output = _soft_capped_attention_kernel(
                queries,
                updated_kv_cache.keys,
                updated_kv_cache.values,
                mask=mask,
                scale=self.scale,
                logit_soft_cap=self.config.logit_soft_cap,
            )
        else:
            attention_output = jax.nn.dot_product_attention(
                queries,
                updated_kv_cache.keys,
                updated_kv_cache.values,
                mask=mask,
                scale=self.scale,
            )
        attention_output = rearrange(
            attention_output,
            "tokens heads head_channels -> tokens (heads head_channels)",
            heads=self.num_heads,
            head_channels=self.head_dim,
        )
        (result,) = vmap(self.out_projection, in_axes=0)(attention_output)

        if not return_updated_kv_cache:
            updated_kv_cache = None

        return AttentionResult(
            outputs=result,
            kv_cache=updated_kv_cache,
        )

    def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
        return StaticKVCacheLayer.empty(capacity, self.num_groups, self.head_dim, self.activation_precision)

    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
        result = ParameterDict(
            qkv_projection=self.qkv_projection.export_weights(weight_layout),
            out_projection=self.out_projection.export_weights(weight_layout),
        )
        if self.query_norm is not None:
            result["query_norm"] = self.query_norm.export_weights(weight_layout)
        if self.key_norm is not None:
            result["key_norm"] = self.key_norm.export_weights(weight_layout)
        return result
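The kernel above combines grouped-query attention (K/V heads repeated to match the query heads via _repeat_kv) with logit soft capping before the softmax. The following standalone sketch, not part of lalamo, reproduces just that arithmetic; the tanh-based soft_cap_logits helper is an assumption standing in for apply_soft_capping from lalamo/modules/utils.py, which is not shown in this diff.

import jax
import jax.numpy as jnp
from einops import einsum, rearrange, repeat


def soft_cap_logits(logits, soft_cap):
    # Assumed soft-capping behaviour: squash logits into (-soft_cap, soft_cap).
    return soft_cap * jnp.tanh(logits / soft_cap)


def grouped_soft_capped_attention(queries, keys, values, soft_cap=50.0):
    # queries: [dst_tokens, heads, dim]; keys/values: [src_tokens, groups, dim]
    _, num_heads, head_dim = queries.shape
    _, num_groups, _ = keys.shape
    group_size = num_heads // num_groups
    # Repeat each K/V group so every query head has a matching K/V head.
    keys = repeat(keys, "s g d -> s (g group_size) d", group_size=group_size)
    values = repeat(values, "s g d -> s (g group_size) d", group_size=group_size)
    logits = einsum(
        rearrange(queries, "t h d -> h t d"),
        rearrange(keys, "s h d -> h s d"),
        "h t d, h s d -> h t s",
    ) * head_dim**-0.5
    # Cap the scaled logits, then normalize as usual.
    weights = jax.nn.softmax(soft_cap_logits(logits, soft_cap), axis=-1)
    return einsum(weights, values, "h t s, s h d -> t h d")


queries = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 16))
keys = jax.random.normal(jax.random.PRNGKey(1), (6, 2, 16))
values = jax.random.normal(jax.random.PRNGKey(2), (6, 2, 16))
print(grouped_soft_capped_attention(queries, keys, values).shape)  # (4, 8, 16)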
lalamo/modules/common.py
ADDED
@@ -0,0 +1,133 @@
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from types import UnionType

import equinox as eqx
from cattrs import Converter
from jax import numpy as jnp
from jaxtyping import DTypeLike

from lalamo.common import ParameterDict

__all__ = [
    "AttentionType",
    "DummyUnionMember",
    "LalamoModule",
    "config_converter",
    "register_config_union",
]


class WeightLayout(Enum):
    AUTO = "auto"
    INPUT_OUTPUT = "input_output"
    OUTPUT_INPUT = "output_input"

    def __str__(self) -> str:
        match self:
            case WeightLayout.AUTO:
                return "auto"
            case WeightLayout.INPUT_OUTPUT:
                return "(input, output)"
            case WeightLayout.OUTPUT_INPUT:
                return "(output, input)"


class AttentionType(Enum):
    GLOBAL = "global"
    SLIDING_WINDOW = "sliding_window"


class LalamoModule[ConfigT](eqx.Module):
    config: ConfigT = eqx.field(static=True)

    @property
    @abstractmethod
    def activation_precision(self) -> DTypeLike: ...

    @abstractmethod
    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict: ...


def _dtype_to_str(dtype: DTypeLike) -> str:
    if dtype == jnp.bfloat16:
        return "bfloat16"
    try:
        return str(dtype.dtype)  # type: ignore
    except AttributeError:
        return str(dtype)


def _str_to_dtype(dtype_str: str) -> jnp.dtype:
    return {
        "int4": jnp.int4,
        "int8": jnp.int8,
        "int16": jnp.int16,
        "int32": jnp.int32,
        "int64": jnp.int64,
        "bfloat16": jnp.bfloat16,
        "float16": jnp.float16,
        "float32": jnp.float32,
        "float64": jnp.float64,
    }[dtype_str]


config_converter = Converter()


config_converter.register_unstructure_hook_func(
    lambda t: t in [jnp.dtype, DTypeLike],
    _dtype_to_str,
)

config_converter.register_structure_hook_func(
    lambda t: t in [jnp.dtype, DTypeLike],
    lambda s, _: _str_to_dtype(s),
)


def register_config_union(union_type: UnionType) -> None:
    union_members = union_type.__args__
    name_to_type = {m.__name__: m for m in union_members}

    def unstructure(obj: object) -> dict | None:
        if obj is None:
            return None
        return {
            "type": obj.__class__.__name__,
            **config_converter.unstructure(obj),
        }

    config_converter.register_unstructure_hook(
        union_type,
        unstructure,
    )

    config_converter.register_unstructure_hook(
        union_type | None,
        unstructure,
    )

    def structure[T](config: dict | None, _: type[T]) -> T | None:
        if config is None:
            return None
        new_config = dict(config)
        type_name = new_config.pop("type")
        target_type = name_to_type[type_name]
        return name_to_type[type_name](**config_converter.structure(new_config, target_type))

    config_converter.register_structure_hook(
        union_type,
        structure,
    )

    config_converter.register_structure_hook(
        union_type | None,
        structure,
    )


@dataclass
class DummyUnionMember:
    pass
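register_config_union gives polymorphic config fields a tagged-dict representation: the unstructure hook injects a "type" key carrying the concrete class name, and the structure hook pops that key to rebuild the matching config class. A hypothetical round trip, assuming the cattrs hooks registered above dispatch on the union type as written (FooConfig and BarConfig are invented here purely for illustration):

from dataclasses import dataclass

from lalamo.modules.common import config_converter, register_config_union


@dataclass(frozen=True)
class FooConfig:
    scale: float


@dataclass(frozen=True)
class BarConfig:
    width: int


register_config_union(FooConfig | BarConfig)

# Unstructuring through the union adds a "type" tag naming the concrete class.
serialized = config_converter.unstructure(FooConfig(scale=0.5), unstructure_as=FooConfig | BarConfig)
# -> {"type": "FooConfig", "scale": 0.5}

# Structuring pops the tag and rebuilds the matching config class.
restored = config_converter.structure(serialized, FooConfig | BarConfig)
assert restored == FooConfig(scale=0.5)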
lalamo/modules/decoder.py
ADDED
@@ -0,0 +1,244 @@
from dataclasses import dataclass

import equinox as eqx
import jax
from jax import vmap
from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

from lalamo.common import ParameterDict

from .common import AttentionType, LalamoModule, WeightLayout
from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerResult
from .embedding import EmbeddingBase, EmbeddingConfig
from .kv_cache import KVCache
from .normalization import RMSNorm, RMSNormConfig
from .rope import PositionalEmbeddings, RoPE, RoPEConfig

__all__ = [
    "Decoder",
    "DecoderActivationTrace",
    "DecoderConfig",
    "DecoderResult",
]


class DecoderActivationTrace(eqx.Module):
    token_ids: Int[Array, " suffix_tokens"]
    token_positions: Int[Array, " suffix_tokens"]
    kv_cache: KVCache | None

    local_positional_embeddings: PositionalEmbeddings
    global_positional_embeddings: PositionalEmbeddings

    layer_results: tuple[DecoderLayerResult, ...]

    output_norm: Float[Array, "suffix_tokens channels"]

    def export(self) -> ParameterDict:
        result = ParameterDict(
            token_ids=self.token_ids,
            token_positions=self.token_positions,
            local_positional_embeddings=self.local_positional_embeddings.export(),
            global_positional_embeddings=self.global_positional_embeddings.export(),
            layer_results=[layer_result.export() for layer_result in self.layer_results],
            output_norm=self.output_norm,
        )
        if self.kv_cache is not None:
            result["kv_cache"] = [kv_cache_layer_slice.export() for kv_cache_layer_slice in self.kv_cache]
        return result


class DecoderResult(eqx.Module):
    logits: Float[Array, "suffix_tokens channels"]
    updated_kv_cache: KVCache | None = None
    activation_trace: DecoderActivationTrace | None = None

    def export(self) -> ParameterDict:
        result = ParameterDict(
            logits=self.logits,
        )
        if self.updated_kv_cache is not None:
            result["updated_kv_cache"] = [
                kv_cache_layer_slice.export() for kv_cache_layer_slice in self.updated_kv_cache
            ]
        if self.activation_trace is not None:
            result["activation_trace"] = self.activation_trace.export()
        return result


@dataclass(frozen=True)
class DecoderConfig:
    embedding_config: EmbeddingConfig
    global_rope_config: RoPEConfig
    local_rope_config: RoPEConfig | None
    layer_config: DecoderLayerConfig
    output_norm_config: RMSNormConfig

    vocab_size: int
    model_dim: int
    hidden_dim: int
    num_heads: int
    num_groups: int
    head_dim: int
    attention_scale: float | None
    num_layers: int
    sliding_window_sizes: tuple[int | None, ...] | None
    context_length: int

    def __post_init__(self) -> None:
        if self.local_rope_config is not None and self.sliding_window_sizes is None:
            raise ValueError("Sliding window sizes must be provided when using local RoPE")
        if self.sliding_window_sizes is None:
            return
        if len(self.sliding_window_sizes) != self.num_layers:
            raise ValueError(
                f"Number of sliding window sizes {len(self.sliding_window_sizes)} does not match"
                f" the number of layers {self.num_layers}",
            )

    def random_init(
        self,
        *,
        key: PRNGKeyArray,
    ) -> "Decoder":
        embedding_key, layers_key = jax.random.split(key)
        embedding = self.embedding_config.random_init(
            vocab_size=self.vocab_size,
            model_dim=self.model_dim,
            key=embedding_key,
        )
        global_rope = self.global_rope_config.init(
            head_dim=self.head_dim,
            num_timesteps=self.context_length,
        )

        if self.local_rope_config:
            assert self.sliding_window_sizes is not None
            max_sliding_window_size = max(
                window_size for window_size in self.sliding_window_sizes if window_size is not None
            )
            local_rope = self.local_rope_config.init(
                head_dim=self.head_dim,
                num_timesteps=max(max_sliding_window_size, self.context_length),
            )
        else:
            local_rope = None

        if self.sliding_window_sizes is None:
            sliding_window_sizes = [None] * self.num_layers
        else:
            sliding_window_sizes = self.sliding_window_sizes
        layers_keys = jax.random.split(layers_key, self.num_layers)
        layers = tuple(
            self.layer_config.random_init(
                model_dim=self.model_dim,
                hidden_dim=self.hidden_dim,
                num_heads=self.num_heads,
                num_groups=self.num_groups,
                head_dim=self.head_dim,
                attention_scale=self.attention_scale,
                sliding_window_size=sliding_window_size,
                key=key,
            )
            for sliding_window_size, key in zip(sliding_window_sizes, layers_keys, strict=True)
        )
        output_norm = self.output_norm_config.init(self.model_dim)
        return Decoder(
            self,
            embedding=embedding,
            global_rope=global_rope,
            local_rope=local_rope,
            layers=layers,
            output_norm=output_norm,
        )


class Decoder(LalamoModule[DecoderConfig]):
    embedding: EmbeddingBase
    global_rope: RoPE
    local_rope: RoPE | None
    layers: tuple[DecoderLayer, ...]
    output_norm: RMSNorm

    @property
    def activation_precision(self) -> DTypeLike:
        return self.embedding.activation_precision

    def __call__(
        self,
        token_ids: Int[Array, " suffix_tokens"],
        token_positions: Int[Array, " suffix_tokens"],
        kv_cache: KVCache | None = None,
        return_updated_kv_cache: bool = False,
        return_activation_trace: bool = False,
        length_without_padding: Int[Array, ""] | int | None = None,
    ) -> DecoderResult:
        maybe_kv_cache = kv_cache or ([None] * len(self.layers))
        inner_features = self.embedding.embed(token_ids)

        global_positional_embeddings = self.global_rope(token_positions)
        if self.local_rope is not None:
            local_positional_embeddings = self.local_rope(token_positions)
        else:
            local_positional_embeddings = global_positional_embeddings

        updated_kv_cache_layers = []
        layer_results = []
        for layer, kv_cache_slice in zip(self.layers, maybe_kv_cache, strict=True):
            if layer.attention_type == AttentionType.SLIDING_WINDOW:
                positional_embeddings_to_use = local_positional_embeddings
            else:
                positional_embeddings_to_use = global_positional_embeddings

            layer_result = layer(
                inner_features,
                positional_embeddings_to_use,
                kv_cache=kv_cache_slice,
                return_updated_kv_cache=return_updated_kv_cache,
                return_activation_trace=return_activation_trace,
                length_without_padding=length_without_padding,
            )
            inner_features = layer_result.outputs
            layer_results.append(layer_result)
            updated_kv_cache_layers.append(layer_result.updated_kv_cache)

        normalized_outputs = vmap(self.output_norm, in_axes=0)(inner_features)
        logits = vmap(self.embedding.readout, in_axes=0)(normalized_outputs)

        if return_activation_trace:
            activation_trace = DecoderActivationTrace(
                token_ids=token_ids,
                token_positions=token_positions,
                kv_cache=kv_cache,
                global_positional_embeddings=global_positional_embeddings,
                local_positional_embeddings=local_positional_embeddings,
                layer_results=tuple(layer_results),
                output_norm=normalized_outputs,
            )
        else:
            activation_trace = None

        if return_updated_kv_cache:
            updated_kv_cache = KVCache(updated_kv_cache_layers)
        else:
            updated_kv_cache = None

        return DecoderResult(
            logits=logits,
            updated_kv_cache=updated_kv_cache,
            activation_trace=activation_trace,
        )

    def init_static_kv_cache(self, capacity: int) -> KVCache:
        return KVCache(layer.init_static_kv_cache(capacity) for layer in self.layers)

    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
        result = ParameterDict(
            embedding=self.embedding.export_weights(weight_layout),
            global_rope=self.global_rope.export_weights(weight_layout),
            layers=[layer.export_weights(weight_layout) for layer in self.layers],
            output_norm=self.output_norm.export_weights(weight_layout),
        )
        if self.local_rope:
            result["local_rope"] = self.local_rope.export_weights(weight_layout)
        return result
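The Decoder call signature above (token_ids, token_positions, optional kv_cache, return_updated_kv_cache) is what a sampling loop builds on: prefill the prompt once, then feed one token at a time while threading the returned KV cache back in. A minimal greedy-decoding sketch against that interface, assuming decoder is an already-constructed lalamo Decoder instance (model loading lives in lalamo.model_import and is not shown in this diff):

import jax.numpy as jnp


def greedy_generate(decoder, prompt_ids: jnp.ndarray, max_new_tokens: int) -> jnp.ndarray:
    # Prefill: run the whole prompt once and keep the per-layer KV cache.
    token_ids = prompt_ids
    result = decoder(
        token_ids,
        jnp.arange(token_ids.shape[0]),
        return_updated_kv_cache=True,
    )
    kv_cache = result.updated_kv_cache

    for _ in range(max_new_tokens):
        # Greedy pick from the logits of the last position.
        next_id = jnp.argmax(result.logits[-1])
        token_ids = jnp.concatenate([token_ids, next_id[None]])
        # Decode a single token, reusing and extending the cache.
        result = decoder(
            next_id[None],
            jnp.array([token_ids.shape[0] - 1]),
            kv_cache=kv_cache,
            return_updated_kv_cache=True,
        )
        kv_cache = result.updated_kv_cache
    return token_ids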