lalamo 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +20 -5
- lalamo/data/__init__.py +8 -0
- lalamo/data/huggingface_message.py +38 -0
- lalamo/data/lalamo_completions.py +43 -0
- lalamo/data/utils.py +8 -0
- lalamo/language_model.py +152 -69
- lalamo/main.py +271 -43
- lalamo/message_processor.py +11 -1
- lalamo/model_import/common.py +17 -7
- lalamo/model_import/decoder_configs/__init__.py +3 -0
- lalamo/model_import/decoder_configs/executorch.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
- lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
- lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
- lalamo/model_import/huggingface_tokenizer_config.py +1 -4
- lalamo/model_import/loaders/executorch.py +10 -9
- lalamo/model_import/loaders/huggingface.py +104 -9
- lalamo/model_import/loaders/utils.py +92 -0
- lalamo/model_import/model_specs/__init__.py +4 -1
- lalamo/model_import/model_specs/common.py +15 -12
- lalamo/model_import/model_specs/gpt_oss.py +21 -0
- lalamo/modules/__init__.py +35 -7
- lalamo/modules/activations.py +24 -14
- lalamo/modules/attention.py +73 -20
- lalamo/modules/common.py +8 -57
- lalamo/modules/decoder.py +48 -34
- lalamo/modules/decoder_layer.py +57 -43
- lalamo/modules/embedding.py +13 -19
- lalamo/modules/kv_cache.py +53 -16
- lalamo/modules/linear.py +260 -79
- lalamo/modules/mlp.py +395 -23
- lalamo/modules/normalization.py +2 -3
- lalamo/modules/rope.py +32 -21
- lalamo/modules/utils.py +10 -0
- lalamo/speculator/__init__.py +11 -0
- lalamo/speculator/common.py +22 -0
- lalamo/speculator/inference.py +75 -0
- lalamo/speculator/ngram.py +154 -0
- lalamo/speculator/utils.py +52 -0
- lalamo/utils.py +27 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
- lalamo-0.4.0.dist-info/RECORD +71 -0
- lalamo-0.3.3.dist-info/RECORD +0 -59
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0
lalamo/modules/attention.py
CHANGED

```diff
@@ -9,9 +9,10 @@ from jax import numpy as jnp
 from jax import vmap
 from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray

+from lalamo.common import dummy_array
 from lalamo.modules.normalization import RMSNorm, RMSNormConfig

-from .common import AttentionType, LalamoModule, ParameterTree
+from .common import AttentionType, LalamoModule, ParameterTree
 from .kv_cache import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer
 from .linear import LinearBase, LinearConfig
 from .rope import PositionalEmbeddings
@@ -44,8 +45,6 @@ def _soft_capped_attention_kernel(
 ) -> Float[Array, "dst_tokens heads head_channels"]:
     _, num_heads, head_dim = queries.shape
     _, num_groups, _ = keys.shape
-    if scale is None:
-        scale = head_dim**-0.5
     group_size = num_heads // num_groups
     keys = _repeat_kv(keys, group_size)
     values = _repeat_kv(values, group_size)
@@ -59,7 +58,11 @@ def _soft_capped_attention_kernel(
     if mask is not None:
         attention_logits = jnp.where(mask, attention_logits, jnp.array(float("-inf"), dtype=attention_logits.dtype))

-
+    if scale is None:
+        scale_val = head_dim**-0.5
+    else:
+        scale_val = float(scale)
+    attention_logits = attention_logits * scale_val
     attention_logits = apply_soft_capping(attention_logits, logit_soft_cap)
     attention_weights = jax.nn.softmax(attention_logits, axis=-1)
     return einsum(
@@ -70,7 +73,7 @@ def _soft_capped_attention_kernel(


 class AttentionResult(NamedTuple):
-    outputs: Float[Array, "suffix_tokens channels"]
+    outputs: Float[Array, "*batch suffix_tokens channels"]
     kv_cache: KVCacheLayer | None = None


@@ -83,6 +86,7 @@ class AttentionConfig:
     key_norm_config: RMSNormConfig | None

     logit_soft_cap: float | None
+    has_sinks: bool
     has_qkv_biases: bool
     has_out_biases: bool

@@ -130,12 +134,18 @@ class AttentionConfig:
         else:
             key_norm = None

+        if self.has_sinks:
+            sinks = jnp.zeros((num_heads,), dtype=qkv_projection.activation_precision)
+        else:
+            sinks = None
+
         return Attention(
             self,
             qkv_projection=qkv_projection,
             out_projection=out_projection,
             query_norm=query_norm,
             key_norm=key_norm,
+            sinks=sinks,
             num_heads=num_heads,
             num_groups=num_groups,
             head_dim=head_dim,
@@ -183,12 +193,18 @@ class AttentionConfig:
         else:
             key_norm = None

+        if self.has_sinks:
+            sinks = dummy_array(num_heads, qkv_projection.activation_precision)
+        else:
+            sinks = None
+
         return Attention(
             self,
             qkv_projection=qkv_projection,
             out_projection=out_projection,
             query_norm=query_norm,
             key_norm=key_norm,
+            sinks=sinks,
             num_heads=num_heads,
             num_groups=num_groups,
             head_dim=head_dim,
@@ -205,6 +221,8 @@ class Attention(LalamoModule[AttentionConfig]):
     query_norm: RMSNorm | None
     key_norm: RMSNorm | None

+    sinks: Float[Array, " heads"] | None
+
     num_heads: int = eqx.field(static=True)
     num_groups: int = eqx.field(static=True)
     head_dim: int = eqx.field(static=True)
@@ -234,6 +252,10 @@ class Attention(LalamoModule[AttentionConfig]):
     def attention_type(self) -> AttentionType:
         return AttentionType.SLIDING_WINDOW if self.sliding_window_size is not None else AttentionType.GLOBAL

+    @property
+    def has_sinks(self) -> bool:
+        return self.sinks is not None
+
     def __post_init__(self) -> None:
         if self.qkv_projection.has_biases != self.config.has_qkv_biases:
             raise ValueError(
@@ -285,6 +307,12 @@ class Attention(LalamoModule[AttentionConfig]):
                 f" ({self.num_groups} * {self.head_dim} = {self.num_groups * self.head_dim}),"
                 f" got {v_output_dim}",
             )
+        if self.sinks is not None:
+            (num_sink_heads,) = self.sinks.shape
+            if num_sink_heads != self.num_heads:
+                raise ValueError(
+                    f"Number of sink heads must be equal to number of heads ({self.num_heads}), got {num_sink_heads}",
+                )

     @eqx.filter_jit
     def __call__(
@@ -325,12 +353,22 @@ class Attention(LalamoModule[AttentionConfig]):
             keys = apply_positional_embeddings(keys)

         if kv_cache is None:
-            updated_kv_cache = DynamicKVCacheLayer.init(keys, values, length=length_without_padding)
+            updated_kv_cache = DynamicKVCacheLayer.init(self.has_sinks, keys, values, length=length_without_padding)
         else:
             updated_kv_cache = kv_cache.extend(keys, values, added_length=length_without_padding)

         num_suffix_tokens, _, _ = queries.shape
-        mask = updated_kv_cache.attention_mask(
+        mask = updated_kv_cache.attention_mask(
+            num_suffix_tokens,
+            self.is_causal,
+            length_without_padding,
+            self.sliding_window_size,
+        )
+        if self.sinks is not None:
+            sink_bias = jnp.zeros((self.num_heads, *mask.shape), dtype=queries.dtype)
+            sink_bias = sink_bias.at[:, :, 0].set(self.sinks[:, None])
+        else:
+            sink_bias = None

         if self.config.logit_soft_cap is not None:
             attention_output = _soft_capped_attention_kernel(
@@ -346,6 +384,7 @@ class Attention(LalamoModule[AttentionConfig]):
                 queries,
                 updated_kv_cache.keys,
                 updated_kv_cache.values,
+                bias=sink_bias,
                 mask=mask,
                 scale=self.scale,
             )
@@ -366,41 +405,55 @@ class Attention(LalamoModule[AttentionConfig]):
         )

     def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
-        return StaticKVCacheLayer.empty(
-
-
-
-
-
+        return StaticKVCacheLayer.empty(
+            self.has_sinks,
+            capacity,
+            self.num_groups,
+            self.head_dim,
+            self.activation_precision,
         )
+
+    def export_weights(self) -> ParameterTree:
+        result: dict[str, ParameterTree | Array] = {
+            "qkv_projection": self.qkv_projection.export_weights(),
+            "out_projection": self.out_projection.export_weights(),
+        }
         if self.query_norm is not None:
-            result["query_norm"] = self.query_norm.export_weights(
+            result["query_norm"] = self.query_norm.export_weights()
         if self.key_norm is not None:
-            result["key_norm"] = self.key_norm.export_weights(
+            result["key_norm"] = self.key_norm.export_weights()
+        if self.sinks is not None:
+            assert isinstance(self.sinks, Array)
+            result["sinks"] = self.sinks
         return result

     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["qkv_projection"], Mapping)
         assert isinstance(weights["out_projection"], Mapping)
         if self.query_norm is not None:
             assert isinstance(weights["query_norm"], Mapping)
-            query_norm = self.query_norm.import_weights(weights["query_norm"]
+            query_norm = self.query_norm.import_weights(weights["query_norm"])
         else:
             query_norm = None
         if self.key_norm is not None:
             assert isinstance(weights["key_norm"], Mapping)
-            key_norm = self.key_norm.import_weights(weights["key_norm"]
+            key_norm = self.key_norm.import_weights(weights["key_norm"])
         else:
             key_norm = None
+        if self.sinks is not None:
+            assert isinstance(weights["sinks"], Array)
+            sinks = weights["sinks"]
+        else:
+            sinks = None
         return replace(
             self,
-            qkv_projection=self.qkv_projection.import_weights(weights["qkv_projection"]
-            out_projection=self.out_projection.import_weights(weights["out_projection"]
+            qkv_projection=self.qkv_projection.import_weights(weights["qkv_projection"]),
+            out_projection=self.out_projection.import_weights(weights["out_projection"]),
             query_norm=query_norm,
             key_norm=key_norm,
+            sinks=sinks,
         )
```
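The new `sinks` field adds a learned per-head logit attached to the first cached position, so the softmax can park attention mass on a dedicated sink slot instead of real tokens. Below is a minimal, self-contained sketch of that effect; the shapes and sink values are illustrative only and this is not lalamo's real kernel signature (in the diff above, the sink values reach the kernel through the `bias=` argument):

```python
import jax
import jax.numpy as jnp

# Illustrative shapes, not lalamo's actual kernel interface.
num_heads, num_keys, head_dim = 4, 8, 16
queries = jax.random.normal(jax.random.PRNGKey(0), (num_heads, head_dim))
keys = jax.random.normal(jax.random.PRNGKey(1), (num_heads, num_keys, head_dim))
sinks = jnp.full((num_heads,), 2.0)  # learned per-head sink logits (initialized to zeros in the diff)

logits = jnp.einsum("hd,hkd->hk", queries, keys) * head_dim**-0.5

# Sink bias: add each head's sink value to the logit of position 0 only,
# mirroring `sink_bias.at[:, :, 0].set(self.sinks[:, None])` in the diff.
sink_bias = jnp.zeros_like(logits).at[:, 0].add(sinks)

weights_plain = jax.nn.softmax(logits, axis=-1)
weights_sinked = jax.nn.softmax(logits + sink_bias, axis=-1)

# Position 0 absorbs probability mass, damping attention paid to the remaining keys.
print(weights_plain[:, 0])
print(weights_sinked[:, 0])
```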
lalamo/modules/common.py
CHANGED

```diff
@@ -6,79 +6,31 @@ from typing import Self

 import equinox as eqx
 from cattrs import Converter
-from einops import rearrange
 from jax import numpy as jnp
-from jaxtyping import Array, DTypeLike
+from jaxtyping import Array, DTypeLike

 from lalamo.common import ParameterTree

 __all__ = [
     "AttentionType",
     "DummyUnionMember",
+    "ForwardPassMode",
     "LalamoModule",
     "config_converter",
-    "from_layout",
-    "into_layout",
     "register_config_union",
 ]


-class WeightLayout(Enum):
-    AUTO = "auto"
-    INPUT_OUTPUT = "input_output"
-    OUTPUT_INPUT = "output_input"
-
-    def __str__(self) -> str:
-        match self:
-            case WeightLayout.AUTO:
-                return "auto"
-            case WeightLayout.INPUT_OUTPUT:
-                return "(input, output)"
-            case WeightLayout.OUTPUT_INPUT:
-                return "(output, input)"
-
-
-_DEFAULT_WEIGHT_LAYOUT = WeightLayout.INPUT_OUTPUT
-
-
-def into_layout(
-    weights: Float[Array, "in_channels out_channels"],
-    layout: WeightLayout,
-) -> Float[Array, "in_channels out_channels"] | Float[Array, "out_channels in_channels"]:
-    if layout == WeightLayout.AUTO:
-        layout = _DEFAULT_WEIGHT_LAYOUT
-    match layout:
-        case WeightLayout.OUTPUT_INPUT:
-            return weights
-        case WeightLayout.INPUT_OUTPUT:
-            return rearrange(
-                weights,
-                "total_out_channels in_channels -> in_channels total_out_channels",
-            )
-
-
-def from_layout(
-    weights: ParameterTree | Array,
-    layout: WeightLayout,
-) -> Array:
-    assert isinstance(weights, Array)
-    if layout == WeightLayout.AUTO:
-        layout = _DEFAULT_WEIGHT_LAYOUT
-    match layout:
-        case WeightLayout.OUTPUT_INPUT:
-            return weights
-        case WeightLayout.INPUT_OUTPUT:
-            return rearrange(
-                weights,
-                "in_channels total_out_channels -> total_out_channels in_channels",
-            )
-
-
 class AttentionType(Enum):
     GLOBAL = "global"
     SLIDING_WINDOW = "sliding_window"


+class ForwardPassMode(Enum):
+    MULTI_TOKEN = "multi_token"
+    SINGLE_TOKEN = "single_token"
+
+
 class LalamoModule[ConfigT](eqx.Module):
     config: ConfigT = eqx.field(static=True)

@@ -87,13 +39,12 @@ class LalamoModule[ConfigT](eqx.Module):
     def activation_precision(self) -> DTypeLike: ...

     @abstractmethod
-    def export_weights(self
+    def export_weights(self) -> ParameterTree[Array]: ...

     @abstractmethod
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self: ...


```
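For readers tracking the dropped `WeightLayout` machinery: the removed `into_layout`/`from_layout` helpers reduced to an optional transpose between the two linear-weight layouts, and the 0.4.0 `import_weights`/`export_weights` signatures no longer expose a layout argument. A condensed sketch of the behaviour that was removed, using the same `einops.rearrange` calls as the deleted code (toy shapes only):

```python
import jax.numpy as jnp
from einops import rearrange

# Toy weight stored as (out_channels, in_channels).
stored = jnp.arange(12.0).reshape(3, 4)

# Old into_layout(..., WeightLayout.INPUT_OUTPUT): transpose to (in_channels, out_channels) on export.
exported = rearrange(stored, "total_out_channels in_channels -> in_channels total_out_channels")

# Old from_layout(..., WeightLayout.INPUT_OUTPUT): transpose back on import.
restored = rearrange(exported, "in_channels total_out_channels -> total_out_channels in_channels")

assert bool((restored == stored).all())
# With WeightLayout.OUTPUT_INPUT the helpers returned the array unchanged;
# AUTO resolved to the INPUT_OUTPUT default. Weights are now always exchanged in a single layout.
```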
lalamo/modules/decoder.py
CHANGED

```diff
@@ -8,9 +8,10 @@ from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

 from lalamo.common import ParameterTree
+from lalamo.modules.utils import vmap_twice

-from .common import AttentionType,
-from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerResult
+from .common import AttentionType, ForwardPassMode, LalamoModule
+from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerForwardPassConfig, DecoderLayerResult
 from .embedding import EmbeddingBase, EmbeddingConfig
 from .kv_cache import KVCache
 from .normalization import RMSNorm, RMSNormConfig
@@ -20,13 +21,17 @@ __all__ = [
     "Decoder",
     "DecoderActivationTrace",
     "DecoderConfig",
+    "DecoderForwardPassConfig",
     "DecoderResult",
 ]


+type DecoderForwardPassConfig = DecoderLayerForwardPassConfig
+
+
 class DecoderActivationTrace(eqx.Module):
-    token_ids: Int[Array, " suffix_tokens"]
-    token_positions: Int[Array, " suffix_tokens"]
+    token_ids: Int[Array, "batch suffix_tokens"]
+    token_positions: Int[Array, "batch suffix_tokens"]
     kv_cache: KVCache | None

     local_positional_embeddings: PositionalEmbeddings
@@ -34,7 +39,7 @@ class DecoderActivationTrace(eqx.Module):

     layer_results: tuple[DecoderLayerResult, ...]

-    output_norm: Float[Array, "suffix_tokens channels"]
+    output_norm: Float[Array, "batch suffix_tokens channels"]

     def export(self) -> ParameterTree:
         result = dict(
@@ -51,7 +56,7 @@ class DecoderActivationTrace(eqx.Module):


 class DecoderResult(eqx.Module):
-    logits: Float[Array, "suffix_tokens channels"]
+    logits: Float[Array, "batch suffix_tokens channels"]
     updated_kv_cache: KVCache | None = None
     activation_trace: DecoderActivationTrace | None = None

@@ -167,13 +172,9 @@ class DecoderConfig:
         )

         if self.local_rope_config:
-            assert self.sliding_window_sizes is not None
-            max_sliding_window_size = max(
-                window_size for window_size in self.sliding_window_sizes if window_size is not None
-            )
             local_rope = self.local_rope_config.init(
                 head_dim=self.head_dim,
-                num_timesteps=
+                num_timesteps=self.context_length,
             )
         else:
             local_rope = None
@@ -219,19 +220,31 @@ class Decoder(LalamoModule[DecoderConfig]):
     @eqx.filter_jit
     def __call__(
         self,
-        token_ids: Int[Array, " suffix_tokens"],
-        token_positions: Int[Array, " suffix_tokens"],
+        token_ids: Int[Array, "batch suffix_tokens"],
+        token_positions: Int[Array, "batch suffix_tokens"],
         kv_cache: KVCache | None = None,
         return_updated_kv_cache: bool = False,
         return_activation_trace: bool = False,
-
+        lengths_without_padding: Int[Array, " batch"] | None = None,
+        forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
+        forward_pass_config: DecoderForwardPassConfig | None = None,
     ) -> DecoderResult:
+        if token_ids.ndim != 2:
+            raise ValueError(
+                f"token_ids must be a 2D arrays of size (batch_size, sequence_length), got {token_ids.shape}",
+            )
+        if token_positions.ndim != 2:
+            raise ValueError(
+                "token_positions must be a 2D arrays of size (batch_size, sequence_length),"
+                f" got {token_positions.shape}",
+            )
+
         maybe_kv_cache = kv_cache or ([None] * len(self.layers))
-        inner_features = self.embedding.embed(token_ids)
+        inner_features = vmap(self.embedding.embed)(token_ids)

-        global_positional_embeddings = self.global_rope(token_positions)
+        global_positional_embeddings = vmap(self.global_rope)(token_positions)
         if self.local_rope is not None:
-            local_positional_embeddings = self.local_rope(token_positions)
+            local_positional_embeddings = vmap(self.local_rope)(token_positions)
         else:
             local_positional_embeddings = global_positional_embeddings

@@ -249,14 +262,16 @@ class Decoder(LalamoModule[DecoderConfig]):
                 kv_cache=kv_cache_slice,
                 return_updated_kv_cache=return_updated_kv_cache,
                 return_activation_trace=return_activation_trace,
-
+                lengths_without_padding=lengths_without_padding,
+                forward_pass_mode=forward_pass_mode,
+                forward_pass_config=forward_pass_config,
             )
             inner_features = layer_result.outputs
             layer_results.append(layer_result)
             updated_kv_cache_layers.append(layer_result.updated_kv_cache)

-        normalized_outputs =
-        logits =
+        normalized_outputs = vmap_twice(self.output_norm)(inner_features)
+        logits = vmap_twice(self.embedding.readout)(normalized_outputs)

         if return_activation_trace:
             activation_trace = DecoderActivationTrace(
@@ -282,24 +297,23 @@ class Decoder(LalamoModule[DecoderConfig]):
             activation_trace=activation_trace,
         )

-    def init_static_kv_cache(self, capacity: int) -> KVCache:
-        return KVCache(layer.init_static_kv_cache(capacity) for layer in self.layers)
+    def init_static_kv_cache(self, batch_size: int, capacity: int) -> KVCache:
+        return KVCache(layer.init_static_kv_cache(batch_size, capacity) for layer in self.layers)

-    def export_weights(self
+    def export_weights(self) -> ParameterTree:
         result = dict(
-            embedding=self.embedding.export_weights(
-            global_rope=self.global_rope.export_weights(
-            layers=[layer.export_weights(
-            output_norm=self.output_norm.export_weights(
+            embedding=self.embedding.export_weights(),
+            global_rope=self.global_rope.export_weights(),
+            layers=[layer.export_weights() for layer in self.layers],
+            output_norm=self.output_norm.export_weights(),
         )
         if self.local_rope:
-            result["local_rope"] = self.local_rope.export_weights(
+            result["local_rope"] = self.local_rope.export_weights()
         return result

     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["embedding"], Mapping)
@@ -308,19 +322,19 @@ class Decoder(LalamoModule[DecoderConfig]):
         assert isinstance(weights["output_norm"], Mapping)
         if self.local_rope:
             assert isinstance(weights["local_rope"], Mapping)
-            local_rope = self.local_rope.import_weights(weights["local_rope"]
+            local_rope = self.local_rope.import_weights(weights["local_rope"])
         else:
             local_rope = None

         layers = []
         for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
             assert isinstance(layer_weights, Mapping)
-            layers.append(layer.import_weights(layer_weights
+            layers.append(layer.import_weights(layer_weights))
         return replace(
             self,
-            embedding=self.embedding.import_weights(weights["embedding"]
-            global_rope=self.global_rope.import_weights(weights["global_rope"]
+            embedding=self.embedding.import_weights(weights["embedding"]),
+            global_rope=self.global_rope.import_weights(weights["global_rope"]),
             layers=tuple(layers),
-            output_norm=self.output_norm.import_weights(weights["output_norm"]
+            output_norm=self.output_norm.import_weights(weights["output_norm"]),
             local_rope=local_rope,
         )
```
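`Decoder.__call__` now takes batched `(batch, suffix_tokens)` inputs and maps the per-sequence and per-token modules over them with `vmap` and the new `vmap_twice` helper from `lalamo/modules/utils.py`. That helper's implementation is not shown in this diff; the sketch below assumes it is simply a nested `jax.vmap`, with a toy normalisation standing in for `self.output_norm`:

```python
import jax
import jax.numpy as jnp


def vmap_twice(fn):
    # Assumed equivalent of lalamo.modules.utils.vmap_twice: vectorise over the batch
    # axis and then the token axis, so a per-token function only ever sees (channels,).
    return jax.vmap(jax.vmap(fn))


def toy_rms_norm(x):
    # Stand-in for self.output_norm: per-token RMS normalisation.
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + 1e-6)


activations = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 8))  # (batch, suffix_tokens, channels)
normalized = vmap_twice(toy_rms_norm)(activations)
print(normalized.shape)  # (2, 5, 8)
```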