lalamo-0.2.1-py3-none-any.whl → lalamo-0.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/model_import/__init__.py +8 -0
- lalamo/model_import/common.py +111 -0
- lalamo/model_import/configs/__init__.py +23 -0
- lalamo/model_import/configs/common.py +62 -0
- lalamo/model_import/configs/executorch.py +166 -0
- lalamo/model_import/configs/huggingface/__init__.py +18 -0
- lalamo/model_import/configs/huggingface/common.py +72 -0
- lalamo/model_import/configs/huggingface/gemma2.py +122 -0
- lalamo/model_import/configs/huggingface/gemma3.py +187 -0
- lalamo/model_import/configs/huggingface/llama.py +155 -0
- lalamo/model_import/configs/huggingface/mistral.py +132 -0
- lalamo/model_import/configs/huggingface/qwen2.py +144 -0
- lalamo/model_import/configs/huggingface/qwen3.py +142 -0
- lalamo/model_import/loaders/__init__.py +7 -0
- lalamo/model_import/loaders/common.py +45 -0
- lalamo/model_import/loaders/executorch.py +223 -0
- lalamo/model_import/loaders/huggingface.py +304 -0
- lalamo/model_import/model_specs/__init__.py +38 -0
- lalamo/model_import/model_specs/common.py +118 -0
- lalamo/model_import/model_specs/deepseek.py +28 -0
- lalamo/model_import/model_specs/gemma.py +76 -0
- lalamo/model_import/model_specs/huggingface.py +28 -0
- lalamo/model_import/model_specs/llama.py +101 -0
- lalamo/model_import/model_specs/mistral.py +59 -0
- lalamo/model_import/model_specs/pleias.py +28 -0
- lalamo/model_import/model_specs/polaris.py +22 -0
- lalamo/model_import/model_specs/qwen.py +336 -0
- lalamo/model_import/model_specs/reka.py +28 -0
- lalamo/modules/__init__.py +85 -0
- lalamo/modules/activations.py +30 -0
- lalamo/modules/attention.py +326 -0
- lalamo/modules/common.py +133 -0
- lalamo/modules/decoder.py +244 -0
- lalamo/modules/decoder_layer.py +240 -0
- lalamo/modules/embedding.py +299 -0
- lalamo/modules/kv_cache.py +196 -0
- lalamo/modules/linear.py +603 -0
- lalamo/modules/mlp.py +79 -0
- lalamo/modules/normalization.py +77 -0
- lalamo/modules/rope.py +255 -0
- lalamo/modules/utils.py +13 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/METADATA +1 -1
- lalamo-0.2.2.dist-info/RECORD +53 -0
- lalamo-0.2.1.dist-info/RECORD +0 -12
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/WHEEL +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/entry_points.txt +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/top_level.txt +0 -0
lalamo/modules/kv_cache.py
@@ -0,0 +1,196 @@
+from abc import abstractmethod
+from typing import Self
+
+import equinox as eqx
+import jax.numpy as jnp
+from jax.lax import dynamic_update_slice_in_dim
+from jax.tree_util import register_pytree_node_class
+from jaxtyping import Array, Bool, DTypeLike, Float, Int
+
+from lalamo.common import ParameterDict
+
+__all__ = ["DynamicKVCacheLayer", "KVCache", "KVCacheLayer", "StaticKVCacheLayer"]
+
+
+class KVCacheLayer(eqx.Module):
+    keys: Float[Array, "tokens groups head_channels"]
+    values: Float[Array, "tokens groups head_channels"]
+
+    def __post_init__(self) -> None:
+        if self.keys.ndim != 3:
+            raise ValueError(
+                f"Key and value buffers must have 3 dimensions: capacity, groups, head_channels,"
+                f" got shape {self.keys.shape}",
+            )
+        if self.keys.shape != self.values.shape:
+            raise ValueError("Keys and values buffers must have the same shape")
+        if self.keys.dtype != self.values.dtype:
+            raise ValueError("Keys and values buffers must have the same dtype")
+
+    @abstractmethod
+    def attention_mask(
+        self,
+        suffix_length: int,
+        is_causal: bool,
+        sliding_window_size: int | None = None,
+    ) -> Bool[Array, "suffix_tokens tokens"]: ...
+
+    @abstractmethod
+    def extend(
+        self,
+        added_keys: Float[Array, "new_tokens groups head_channels"],
+        added_values: Float[Array, "new_tokens groups head_channels"],
+        added_length: Int[Array, ""] | int | None = None,
+    ) -> Self: ...
+
+    def export(self) -> ParameterDict:
+        return ParameterDict(
+            keys=self.keys,
+            values=self.values,
+        )
+
+
+@register_pytree_node_class
+class KVCache(tuple[KVCacheLayer, ...]):
+    __slots__ = ()
+
+    def tree_flatten(self) -> tuple[tuple[KVCacheLayer, ...], None]:
+        return (tuple(self), None)
+
+    @classmethod
+    def tree_unflatten(cls, aux_data: None, children: tuple[KVCacheLayer, ...]) -> Self:  # noqa: ARG003
+        return cls(children)
+
+
+class DynamicKVCacheLayer(KVCacheLayer):
+    padding_mask: Bool[Array, " tokens"] | None = None
+
+    @classmethod
+    def init(
+        cls,
+        keys: Float[Array, "tokens groups head_channels"],
+        values: Float[Array, "tokens groups head_channels"],
+        length: Int[Array, ""] | int | None = None,
+    ) -> "DynamicKVCacheLayer":
+        num_tokens, _, _ = keys.shape
+        if length is None:
+            padding_mask = None
+        else:
+            padding_mask = jnp.arange(num_tokens, dtype=jnp.int32) < length
+        return cls(keys, values, padding_mask)
+
+    def attention_mask(
+        self,
+        suffix_length: int,
+        is_causal: bool,
+        sliding_window_size: int | None = None,
+    ) -> Bool[Array, "suffix_tokens tokens"]:
+        total_num_tokens, _, _ = self.keys.shape
+        result = jnp.ones((suffix_length, total_num_tokens), dtype=jnp.bool)
+        if is_causal:
+            result = jnp.tril(result, k=total_num_tokens - suffix_length)
+        if sliding_window_size is not None:
+            result = jnp.triu(result, k=1 - sliding_window_size)
+        if self.padding_mask is not None:
+            result = result & self.padding_mask[None, :]
+        return result
+
+    def extend(
+        self,
+        added_keys: Float[Array, "new_tokens groups head_channels"],
+        added_values: Float[Array, "new_tokens groups head_channels"],
+        added_length: Int[Array, ""] | int | None = None,
+    ) -> "DynamicKVCacheLayer":
+        updated_keys = jnp.concatenate([self.keys, added_keys], axis=0)
+        updated_values = jnp.concatenate([self.values, added_values], axis=0)
+
+        added_padded_length, _, _ = added_keys.shape
+        if self.padding_mask is None and added_length is None:
+            return DynamicKVCacheLayer(updated_keys, updated_values)
+        if added_length is None:
+            added_length = added_padded_length
+
+        if self.padding_mask is not None:
+            old_padding_mask = self.padding_mask
+        else:
+            old_num_tokens, _, _ = self.keys.shape
+            old_padding_mask = jnp.ones(old_num_tokens, dtype=jnp.bool)
+
+        added_padding_mask = jnp.arange(added_padded_length, dtype=jnp.int32) < added_length
+        updated_padding_mask = jnp.concatenate([old_padding_mask, added_padding_mask], axis=0)
+        return DynamicKVCacheLayer(updated_keys, updated_values, updated_padding_mask)
+
+
+class StaticKVCacheLayer(KVCacheLayer):
+    current_length: Int[Array, ""]
+
+    def attention_mask(
+        self,
+        suffix_length: int,
+        is_causal: bool,
+        sliding_window_size: int | None = None,
+    ) -> Bool[Array, "suffix_tokens tokens"]:
+        if is_causal:
+            query_offsets = jnp.arange(-suffix_length, 0, dtype=jnp.int32)
+        else:
+            query_offsets = jnp.zeros(suffix_length, dtype=jnp.int32)
+
+        query_indices = self.current_length + query_offsets
+        key_indices = jnp.arange(self.capacity, dtype=jnp.int32)
+
+        result = query_indices[:, None] >= key_indices[None, :]
+        if sliding_window_size is not None:
+            swa_mask = query_indices[:, None] < (key_indices[None, :] + sliding_window_size)
+            result = result & swa_mask
+
+        return result
+
+    @property
+    def padding_mask(self) -> Bool[Array, " tokens"] | None:
+        return jnp.arange(self.capacity, dtype=jnp.int32) < self.current_length
+
+    @property
+    def capacity(self) -> int:
+        result, _, _ = self.keys.shape
+        return result
+
+    def extend(
+        self,
+        added_keys: Float[Array, "tokens groups head_channels"],
+        added_values: Float[Array, "tokens groups head_channels"],
+        added_length: Int[Array, ""] | int | None = None,
+    ) -> "StaticKVCacheLayer":
+        if added_keys.shape != added_values.shape:
+            raise ValueError("Keys and values must have the same shape")
+        num_added_tokens, new_num_groups, new_head_dim = added_keys.shape
+        _, old_num_groups, old_head_dim = self.keys.shape
+        if new_num_groups != old_num_groups or new_head_dim != old_head_dim:
+            raise ValueError("New keys and values must have the same number of groups and head dimensions")
+
+        if added_length is None:
+            added_length = num_added_tokens
+
+        updated_keys = dynamic_update_slice_in_dim(
+            self.keys,
+            added_keys,
+            self.current_length,
+            0,
+            allow_negative_indices=False,
+        )
+        updated_values = dynamic_update_slice_in_dim(
+            self.values,
+            added_values,
+            self.current_length,
+            0,
+            allow_negative_indices=False,
+        )
+        updated_sequence_length = self.current_length + added_length
+        return StaticKVCacheLayer(keys=updated_keys, values=updated_values, current_length=updated_sequence_length)
+
+    @classmethod
+    def empty(cls, capacity: int, num_groups: int, head_dim: int, dtype: DTypeLike) -> Self:
+        return cls(
+            keys=jnp.empty((capacity, num_groups, head_dim), dtype=dtype),
+            values=jnp.empty((capacity, num_groups, head_dim), dtype=dtype),
+            current_length=jnp.array(0, dtype=jnp.int32),
+        )
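
For orientation, here is a minimal usage sketch of the KV-cache classes added in the hunk above. It assumes the classes can be imported from lalamo.modules.kv_cache (the path shown in the file listing); the group count, head width, capacity, and dtype are illustrative values chosen for the example, not ones taken from the package.

import jax.numpy as jnp

from lalamo.modules.kv_cache import DynamicKVCacheLayer, KVCache, StaticKVCacheLayer

NUM_GROUPS, HEAD_DIM = 4, 64  # illustrative key/value layout

# Static layer: fixed-capacity buffers, updated in place at current_length.
layer = StaticKVCacheLayer.empty(capacity=128, num_groups=NUM_GROUPS, head_dim=HEAD_DIM, dtype=jnp.float32)
new_keys = jnp.zeros((8, NUM_GROUPS, HEAD_DIM), dtype=jnp.float32)
new_values = jnp.zeros((8, NUM_GROUPS, HEAD_DIM), dtype=jnp.float32)
layer = layer.extend(new_keys, new_values)  # returns a copy with the slice written and current_length advanced
mask = layer.attention_mask(suffix_length=8, is_causal=True, sliding_window_size=32)
assert mask.shape == (8, 128)  # one row per query token, one column per cache slot

# Dynamic layer: buffers grow by concatenation; validity is tracked by an optional padding mask.
dyn = DynamicKVCacheLayer.init(new_keys, new_values, length=5)  # only the first 5 tokens are valid
dyn = dyn.extend(new_keys, new_values)

# One layer per decoder block, wrapped in the pytree-registered KVCache tuple.
cache = KVCache([layer])

The static variant keeps buffer shapes constant across decode steps (writes go through dynamic_update_slice_in_dim against a preallocated capacity), while the dynamic variant trades that for simpler concatenation-based growth.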