lalamo 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/model_import/__init__.py +8 -0
  3. lalamo/model_import/common.py +111 -0
  4. lalamo/model_import/configs/__init__.py +23 -0
  5. lalamo/model_import/configs/common.py +62 -0
  6. lalamo/model_import/configs/executorch.py +166 -0
  7. lalamo/model_import/configs/huggingface/__init__.py +18 -0
  8. lalamo/model_import/configs/huggingface/common.py +72 -0
  9. lalamo/model_import/configs/huggingface/gemma2.py +122 -0
  10. lalamo/model_import/configs/huggingface/gemma3.py +187 -0
  11. lalamo/model_import/configs/huggingface/llama.py +155 -0
  12. lalamo/model_import/configs/huggingface/mistral.py +132 -0
  13. lalamo/model_import/configs/huggingface/qwen2.py +144 -0
  14. lalamo/model_import/configs/huggingface/qwen3.py +142 -0
  15. lalamo/model_import/loaders/__init__.py +7 -0
  16. lalamo/model_import/loaders/common.py +45 -0
  17. lalamo/model_import/loaders/executorch.py +223 -0
  18. lalamo/model_import/loaders/huggingface.py +304 -0
  19. lalamo/model_import/model_specs/__init__.py +38 -0
  20. lalamo/model_import/model_specs/common.py +118 -0
  21. lalamo/model_import/model_specs/deepseek.py +28 -0
  22. lalamo/model_import/model_specs/gemma.py +76 -0
  23. lalamo/model_import/model_specs/huggingface.py +28 -0
  24. lalamo/model_import/model_specs/llama.py +101 -0
  25. lalamo/model_import/model_specs/mistral.py +59 -0
  26. lalamo/model_import/model_specs/pleias.py +28 -0
  27. lalamo/model_import/model_specs/polaris.py +22 -0
  28. lalamo/model_import/model_specs/qwen.py +336 -0
  29. lalamo/model_import/model_specs/reka.py +28 -0
  30. lalamo/modules/__init__.py +85 -0
  31. lalamo/modules/activations.py +30 -0
  32. lalamo/modules/attention.py +326 -0
  33. lalamo/modules/common.py +133 -0
  34. lalamo/modules/decoder.py +244 -0
  35. lalamo/modules/decoder_layer.py +240 -0
  36. lalamo/modules/embedding.py +299 -0
  37. lalamo/modules/kv_cache.py +196 -0
  38. lalamo/modules/linear.py +603 -0
  39. lalamo/modules/mlp.py +79 -0
  40. lalamo/modules/normalization.py +77 -0
  41. lalamo/modules/rope.py +255 -0
  42. lalamo/modules/utils.py +13 -0
  43. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/METADATA +1 -1
  44. lalamo-0.2.2.dist-info/RECORD +53 -0
  45. lalamo-0.2.1.dist-info/RECORD +0 -12
  46. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/WHEEL +0 -0
  47. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/top_level.txt +0 -0
lalamo/modules/kv_cache.py
@@ -0,0 +1,196 @@
+from abc import abstractmethod
+from typing import Self
+
+import equinox as eqx
+import jax.numpy as jnp
+from jax.lax import dynamic_update_slice_in_dim
+from jax.tree_util import register_pytree_node_class
+from jaxtyping import Array, Bool, DTypeLike, Float, Int
+
+from lalamo.common import ParameterDict
+
+__all__ = ["DynamicKVCacheLayer", "KVCache", "KVCacheLayer", "StaticKVCacheLayer"]
+
+
+class KVCacheLayer(eqx.Module):
+    keys: Float[Array, "tokens groups head_channels"]
+    values: Float[Array, "tokens groups head_channels"]
+
+    def __post_init__(self) -> None:
+        if self.keys.ndim != 3:
+            raise ValueError(
+                f"Key and value buffers must have 3 dimensions: capacity, groups, head_channels,"
+                f" got shape {self.keys.shape}",
+            )
+        if self.keys.shape != self.values.shape:
+            raise ValueError("Keys and values buffers must have the same shape")
+        if self.keys.dtype != self.values.dtype:
+            raise ValueError("Keys and values buffers must have the same dtype")
+
+    @abstractmethod
+    def attention_mask(
+        self,
+        suffix_length: int,
+        is_causal: bool,
+        sliding_window_size: int | None = None,
+    ) -> Bool[Array, "suffix_tokens tokens"]: ...
+
+    @abstractmethod
+    def extend(
+        self,
+        added_keys: Float[Array, "new_tokens groups head_channels"],
+        added_values: Float[Array, "new_tokens groups head_channels"],
+        added_length: Int[Array, ""] | int | None = None,
+    ) -> Self: ...
+
+    def export(self) -> ParameterDict:
+        return ParameterDict(
+            keys=self.keys,
+            values=self.values,
+        )
+
+
+@register_pytree_node_class
+class KVCache(tuple[KVCacheLayer, ...]):
+    __slots__ = ()
+
+    def tree_flatten(self) -> tuple[tuple[KVCacheLayer, ...], None]:
+        return (tuple(self), None)
+
+    @classmethod
+    def tree_unflatten(cls, aux_data: None, children: tuple[KVCacheLayer, ...]) -> Self:  # noqa: ARG003
+        return cls(children)
+
+
+class DynamicKVCacheLayer(KVCacheLayer):
+    padding_mask: Bool[Array, " tokens"] | None = None
+
+    @classmethod
+    def init(
+        cls,
+        keys: Float[Array, "tokens groups head_channels"],
+        values: Float[Array, "tokens groups head_channels"],
+        length: Int[Array, ""] | int | None = None,
+    ) -> "DynamicKVCacheLayer":
+        num_tokens, _, _ = keys.shape
+        if length is None:
+            padding_mask = None
+        else:
+            padding_mask = jnp.arange(num_tokens, dtype=jnp.int32) < length
+        return cls(keys, values, padding_mask)
+
+    def attention_mask(
+        self,
+        suffix_length: int,
+        is_causal: bool,
+        sliding_window_size: int | None = None,
+    ) -> Bool[Array, "suffix_tokens tokens"]:
+        total_num_tokens, _, _ = self.keys.shape
+        result = jnp.ones((suffix_length, total_num_tokens), dtype=jnp.bool)
+        if is_causal:
+            result = jnp.tril(result, k=total_num_tokens - suffix_length)
+        if sliding_window_size is not None:
+            result = jnp.triu(result, k=1 - sliding_window_size)
+        if self.padding_mask is not None:
+            result = result & self.padding_mask[None, :]
+        return result
+
+    def extend(
+        self,
+        added_keys: Float[Array, "new_tokens groups head_channels"],
+        added_values: Float[Array, "new_tokens groups head_channels"],
+        added_length: Int[Array, ""] | int | None = None,
+    ) -> "DynamicKVCacheLayer":
+        updated_keys = jnp.concatenate([self.keys, added_keys], axis=0)
+        updated_values = jnp.concatenate([self.values, added_values], axis=0)
+
+        added_padded_length, _, _ = added_keys.shape
+        if self.padding_mask is None and added_length is None:
+            return DynamicKVCacheLayer(updated_keys, updated_values)
+        if added_length is None:
+            added_length = added_padded_length
+
+        if self.padding_mask is not None:
+            old_padding_mask = self.padding_mask
+        else:
+            old_num_tokens, _, _ = self.keys.shape
+            old_padding_mask = jnp.ones(old_num_tokens, dtype=jnp.bool)
+
+        added_padding_mask = jnp.arange(added_padded_length, dtype=jnp.int32) < added_length
+        updated_padding_mask = jnp.concatenate([old_padding_mask, added_padding_mask], axis=0)
+        return DynamicKVCacheLayer(updated_keys, updated_values, updated_padding_mask)
+
+
+class StaticKVCacheLayer(KVCacheLayer):
+    current_length: Int[Array, ""]
+
+    def attention_mask(
+        self,
+        suffix_length: int,
+        is_causal: bool,
+        sliding_window_size: int | None = None,
+    ) -> Bool[Array, "suffix_tokens tokens"]:
+        if is_causal:
+            query_offsets = jnp.arange(-suffix_length, 0, dtype=jnp.int32)
+        else:
+            query_offsets = jnp.zeros(suffix_length, dtype=jnp.int32)
+
+        query_indices = self.current_length + query_offsets
+        key_indices = jnp.arange(self.capacity, dtype=jnp.int32)
+
+        result = query_indices[:, None] >= key_indices[None, :]
+        if sliding_window_size is not None:
+            swa_mask = query_indices[:, None] < (key_indices[None, :] + sliding_window_size)
+            result = result & swa_mask
+
+        return result
+
+    @property
+    def padding_mask(self) -> Bool[Array, " tokens"] | None:
+        return jnp.arange(self.capacity, dtype=jnp.int32) < self.current_length
+
+    @property
+    def capacity(self) -> int:
+        result, _, _ = self.keys.shape
+        return result
+
+    def extend(
+        self,
+        added_keys: Float[Array, "tokens groups head_channels"],
+        added_values: Float[Array, "tokens groups head_channels"],
+        added_length: Int[Array, ""] | int | None = None,
+    ) -> "StaticKVCacheLayer":
+        if added_keys.shape != added_values.shape:
+            raise ValueError("Keys and values must have the same shape")
+        num_added_tokens, new_num_groups, new_head_dim = added_keys.shape
+        _, old_num_groups, old_head_dim = self.keys.shape
+        if new_num_groups != old_num_groups or new_head_dim != old_head_dim:
+            raise ValueError("New keys and values must have the same number of groups and head dimensions")
+
+        if added_length is None:
+            added_length = num_added_tokens
+
+        updated_keys = dynamic_update_slice_in_dim(
+            self.keys,
+            added_keys,
+            self.current_length,
+            0,
+            allow_negative_indices=False,
+        )
+        updated_values = dynamic_update_slice_in_dim(
+            self.values,
+            added_values,
+            self.current_length,
+            0,
+            allow_negative_indices=False,
+        )
+        updated_sequence_length = self.current_length + added_length
+        return StaticKVCacheLayer(keys=updated_keys, values=updated_values, current_length=updated_sequence_length)
+
+    @classmethod
+    def empty(cls, capacity: int, num_groups: int, head_dim: int, dtype: DTypeLike) -> Self:
+        return cls(
+            keys=jnp.empty((capacity, num_groups, head_dim), dtype=dtype),
+            values=jnp.empty((capacity, num_groups, head_dim), dtype=dtype),
+            current_length=jnp.array(0, dtype=jnp.int32),
+        )
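To illustrate what the new kv_cache module provides, here is a minimal usage sketch. It is not taken from the package's documentation or tests; it only exercises the classes shown in the diff above and assumes they are importable from lalamo.modules.kv_cache as listed in the files-changed table.

```python
# Hypothetical usage sketch for the KV cache layers added in 0.2.2.
# Assumes the import path lalamo.modules.kv_cache from the file list above.
import jax.numpy as jnp

from lalamo.modules.kv_cache import DynamicKVCacheLayer, StaticKVCacheLayer

# Static cache: preallocate a fixed-capacity buffer of 8 token slots,
# 2 KV groups, 4 channels per head.
layer = StaticKVCacheLayer.empty(capacity=8, num_groups=2, head_dim=4, dtype=jnp.float32)

# Write 3 new tokens' keys/values into the preallocated buffers in place.
new_keys = jnp.ones((3, 2, 4), dtype=jnp.float32)
new_values = jnp.ones((3, 2, 4), dtype=jnp.float32)
layer = layer.extend(new_keys, new_values)

# Causal mask for the 3 freshly appended query positions over all 8 slots.
mask = layer.attention_mask(suffix_length=3, is_causal=True)
print(mask.shape)  # (3, 8)

# Dynamic cache: buffers grow by concatenation instead of in-place writes.
dyn = DynamicKVCacheLayer.init(new_keys, new_values)
dyn = dyn.extend(new_keys, new_values)
print(dyn.keys.shape)  # (6, 2, 4)
```

The two layer types trade off differently: StaticKVCacheLayer keeps array shapes fixed (convenient under jit and for export), while DynamicKVCacheLayer lets the sequence length grow and tracks validity with an optional padding mask.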