lalamo 0.2.7__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/common.py +79 -29
- lalamo/language_model.py +106 -83
- lalamo/main.py +91 -18
- lalamo/message_processor.py +170 -0
- lalamo/model_import/common.py +159 -43
- lalamo/model_import/{configs → decoder_configs}/__init__.py +0 -1
- lalamo/model_import/{configs → decoder_configs}/common.py +11 -10
- lalamo/model_import/{configs → decoder_configs}/huggingface/common.py +9 -4
- lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py +2 -2
- lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py +2 -2
- lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py +1 -1
- lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py +1 -1
- lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py +1 -1
- lalamo/model_import/huggingface_generation_config.py +44 -0
- lalamo/model_import/huggingface_tokenizer_config.py +85 -0
- lalamo/model_import/loaders/common.py +2 -1
- lalamo/model_import/loaders/huggingface.py +12 -10
- lalamo/model_import/model_specs/__init__.py +3 -2
- lalamo/model_import/model_specs/common.py +32 -34
- lalamo/model_import/model_specs/deepseek.py +1 -10
- lalamo/model_import/model_specs/gemma.py +2 -25
- lalamo/model_import/model_specs/huggingface.py +2 -12
- lalamo/model_import/model_specs/llama.py +2 -58
- lalamo/model_import/model_specs/mistral.py +9 -19
- lalamo/model_import/model_specs/pleias.py +3 -13
- lalamo/model_import/model_specs/polaris.py +5 -7
- lalamo/model_import/model_specs/qwen.py +12 -111
- lalamo/model_import/model_specs/reka.py +4 -13
- lalamo/modules/__init__.py +2 -1
- lalamo/modules/attention.py +90 -10
- lalamo/modules/common.py +51 -4
- lalamo/modules/decoder.py +90 -8
- lalamo/modules/decoder_layer.py +85 -8
- lalamo/modules/embedding.py +95 -29
- lalamo/modules/kv_cache.py +3 -3
- lalamo/modules/linear.py +170 -130
- lalamo/modules/mlp.py +40 -7
- lalamo/modules/normalization.py +24 -6
- lalamo/modules/rope.py +24 -6
- lalamo/sampling.py +99 -0
- lalamo/utils.py +86 -1
- {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/METADATA +6 -6
- lalamo-0.3.0.dist-info/RECORD +58 -0
- lalamo-0.2.7.dist-info/RECORD +0 -54
- /lalamo/model_import/{configs → decoder_configs}/executorch.py +0 -0
- /lalamo/model_import/{configs → decoder_configs}/huggingface/__init__.py +0 -0
- /lalamo/model_import/{configs → decoder_configs}/huggingface/gemma2.py +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/WHEEL +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/top_level.txt +0 -0
lalamo/model_import/model_specs/qwen.py
CHANGED

```diff
@@ -1,14 +1,7 @@
-from lalamo.model_import.configs import HFQwen2Config, HFQwen3Config
+from lalamo.model_import.decoder_configs import HFQwen2Config, HFQwen3Config
 from lalamo.quantization import QuantizationMode
 
-from .common import (
-    HUGGINFACE_GENERATION_CONFIG_FILE,
-    HUGGINGFACE_TOKENIZER_FILES,
-    ModelSpec,
-    UseCase,
-    WeightsType,
-    huggingface_weight_files,
-)
+from .common import ModelSpec, UseCase, WeightsType
 
 __all__ = ["QWEN_MODELS"]
 
@@ -22,11 +15,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-0.5B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -36,11 +24,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-1.5B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -50,11 +33,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-3B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(2),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -64,11 +42,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-7B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(4),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -78,11 +51,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-14B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(8),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -92,11 +60,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-32B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(17),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
 ]
 
@@ -110,10 +73,6 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-0.5B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=(UseCase.CODE,),
     ),
     ModelSpec(
@@ -124,10 +83,6 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-1.5B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=(UseCase.CODE,),
     ),
     ModelSpec(
@@ -138,10 +93,6 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-3B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(2),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=(UseCase.CODE,),
     ),
     ModelSpec(
@@ -152,10 +103,6 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-7B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(4),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=(UseCase.CODE,),
     ),
     ModelSpec(
@@ -166,10 +113,6 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-14B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(6),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=(UseCase.CODE,),
     ),
     ModelSpec(
@@ -180,10 +123,16 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-32B-Instruct",
         config_type=HFQwen2Config,
-
-
-
-
+        use_cases=(UseCase.CODE,),
+    ),
+    ModelSpec(
+        vendor="Alibaba",
+        family="Qwen2.5-Coder",
+        name="Qwen2.5-Coder-32B-Instruct",
+        size="32B",
+        quantization=None,
+        repo="Qwen/Qwen2.5-Coder-32B-Instruct",
+        config_type=HFQwen2Config,
         use_cases=(UseCase.CODE,),
     ),
 ]
 
@@ -198,11 +147,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-0.6B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -212,10 +156,7 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-1.7B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(2),
         weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=tuple(),
     ),
     ModelSpec(
@@ -226,11 +167,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-4B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(3),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -240,11 +176,6 @@ QWEN3 = [
         quantization=QuantizationMode.UINT4,
         repo="Qwen/Qwen3-4B-AWQ",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -254,11 +185,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-8B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(5),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -268,11 +194,6 @@ QWEN3 = [
         quantization=QuantizationMode.UINT4,
         repo="Qwen/Qwen3-8B-AWQ",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(2),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -282,11 +203,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-14B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(8),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
    ),
     ModelSpec(
         vendor="Alibaba",
@@ -296,11 +212,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-14B-AWQ",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(2),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -310,11 +221,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-32B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(17),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -324,11 +230,6 @@ QWEN3 = [
         quantization=QuantizationMode.UINT4,
         repo="Qwen/Qwen3-32B-AWQ",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(4),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
 ]
 
```
lalamo/model_import/model_specs/reka.py
CHANGED

```diff
@@ -1,12 +1,6 @@
-from lalamo.model_import.configs import HFLlamaConfig
+from lalamo.model_import.decoder_configs import HFLlamaConfig
 
-from .common import (
-    HUGGINFACE_GENERATION_CONFIG_FILE,
-    HUGGINGFACE_TOKENIZER_FILES,
-    ModelSpec,
-    WeightsType,
-    huggingface_weight_files,
-)
+from .common import ModelSpec
 
 __all__ = ["REKA_MODELS"]
 
@@ -19,10 +13,7 @@ REKA_MODELS = [
         quantization=None,
         repo="RekaAI/reka-flash-3.1",
         config_type=HFLlamaConfig,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(9),  # Model has 9 shards
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        user_role_name="human",
         use_cases=tuple(),
     ),
-]
+]
```
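
Across the `model_specs/*.py` hunks, the per-model boilerplate (`config_file_name`, explicit weight shard lists, `weights_type`, tokenizer file tuples, empty `use_cases`) disappears from every entry, presumably becoming defaults on `ModelSpec` itself (note the `model_specs/common.py` change of +32 -34 in the file list), while specs gain per-model options such as `user_role_name` for Reka. A rough sketch of a 0.3.0-style entry; the list name and the `model_specs.common` import path are assumptions, and the field values simply mirror one of the entries above:

```python
# Hypothetical 0.3.0-style registry entry: only model-specific fields are
# spelled out; config file name, shard list, weights type, and tokenizer
# files presumably fall back to ModelSpec defaults.
from lalamo.model_import.decoder_configs import HFQwen2Config
from lalamo.model_import.model_specs.common import ModelSpec, UseCase  # assumed path

MY_MODELS = [
    ModelSpec(
        vendor="Alibaba",
        family="Qwen2.5-Coder",
        name="Qwen2.5-Coder-0.5B-Instruct",
        size="0.5B",
        quantization=None,
        repo="Qwen/Qwen2.5-Coder-0.5B-Instruct",
        config_type=HFQwen2Config,
        use_cases=(UseCase.CODE,),
    ),
]
```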
lalamo/modules/__init__.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 from .activations import Activation
 from .attention import Attention, AttentionConfig
-from .common import WeightLayout, config_converter
+from .common import LalamoModule, WeightLayout, config_converter
 from .decoder import Decoder, DecoderActivationTrace, DecoderConfig, DecoderResult
 from .decoder_layer import DecoderLayer, DecoderLayerActivationTrace, DecoderLayerConfig, DecoderLayerResult
 from .embedding import (
@@ -58,6 +58,7 @@ __all__ = [
     "GroupQuantizedLinearConfig",
     "KVCache",
     "KVCacheLayer",
+    "LalamoModule",
     "LinearBase",
     "LinearConfig",
     "LinearScalingRoPEConfig",
```
lalamo/modules/attention.py
CHANGED
```diff
@@ -1,5 +1,6 @@
-from dataclasses import dataclass
-from typing import NamedTuple
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from typing import NamedTuple, Self
 
 import equinox as eqx
 import jax
@@ -8,10 +9,9 @@ from jax import numpy as jnp
 from jax import vmap
 from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterDict
 from lalamo.modules.normalization import RMSNorm, RMSNormConfig
 
-from .common import AttentionType, LalamoModule, WeightLayout
+from .common import AttentionType, LalamoModule, ParameterTree, WeightLayout
 from .kv_cache import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer
 from .linear import LinearBase, LinearConfig
 from .rope import PositionalEmbeddings
@@ -42,8 +42,8 @@ def _soft_capped_attention_kernel(
     scale: float | None,
     logit_soft_cap: float,
 ) -> Float[Array, "dst_tokens heads head_channels"]:
-
-
+    _, num_heads, head_dim = queries.shape
+    _, num_groups, _ = keys.shape
     if scale is None:
         scale = head_dim**-0.5
     group_size = num_heads // num_groups
@@ -118,14 +118,67 @@ class AttentionConfig:
 
         if self.query_norm_config is not None:
             query_norm = self.query_norm_config.init(
-
+                input_dim=head_dim,
             )
         else:
             query_norm = None
 
         if self.key_norm_config is not None:
             key_norm = self.key_norm_config.init(
-
+                input_dim=head_dim,
+            )
+        else:
+            key_norm = None
+
+        return Attention(
+            self,
+            qkv_projection=qkv_projection,
+            out_projection=out_projection,
+            query_norm=query_norm,
+            key_norm=key_norm,
+            num_heads=num_heads,
+            num_groups=num_groups,
+            head_dim=head_dim,
+            is_causal=is_causal,
+            scale=scale,
+            sliding_window_size=sliding_window_size,
+        )
+
+    def empty(
+        self,
+        model_dim: int,
+        num_heads: int,
+        num_groups: int,
+        head_dim: int,
+        is_causal: bool,
+        scale: float | None,
+        sliding_window_size: int | None,
+    ) -> "Attention":
+        qkv_projection = self.qkv_projection_config.empty(
+            input_dim=model_dim,
+            output_dims=(
+                num_heads * head_dim,
+                num_groups * head_dim,
+                num_groups * head_dim,
+            ),
+            has_biases=self.has_qkv_biases,
+        )
+        out_projection = self.out_projection_config.empty(
+            num_heads * head_dim,
+            (model_dim,),
+            has_biases=self.has_out_biases,
+        )
+
+        if self.query_norm_config is not None:
+            query_norm = self.query_norm_config.empty(
+                input_dim=head_dim,
+            )
+        else:
+            query_norm = None
+
+        if self.key_norm_config is not None:
+            key_norm = self.key_norm_config.empty(
+                input_dim=head_dim,
             )
         else:
             key_norm = None
@@ -233,6 +286,7 @@ class Attention(LalamoModule[AttentionConfig]):
                 f" got {v_output_dim}",
             )
 
+    @eqx.filter_jit
     def __call__(
         self,
         inputs: Float[Array, "suffix_tokens channels"],
@@ -314,8 +368,8 @@ class Attention(LalamoModule[AttentionConfig]):
     def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
         return StaticKVCacheLayer.empty(capacity, self.num_groups, self.head_dim, self.activation_precision)
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
-        result = ParameterDict(
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        result = dict(
            qkv_projection=self.qkv_projection.export_weights(weight_layout),
            out_projection=self.out_projection.export_weights(weight_layout),
         )
@@ -324,3 +378,29 @@ class Attention(LalamoModule[AttentionConfig]):
         if self.key_norm is not None:
             result["key_norm"] = self.key_norm.export_weights(weight_layout)
         return result
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["qkv_projection"], Mapping)
+        assert isinstance(weights["out_projection"], Mapping)
+        if self.query_norm is not None:
+            assert isinstance(weights["query_norm"], Mapping)
+            query_norm = self.query_norm.import_weights(weights["query_norm"], weight_layout)
+        else:
+            query_norm = None
+        if self.key_norm is not None:
+            assert isinstance(weights["key_norm"], Mapping)
+            key_norm = self.key_norm.import_weights(weights["key_norm"], weight_layout)
+        else:
+            key_norm = None
+        return replace(
+            self,
+            qkv_projection=self.qkv_projection.import_weights(weights["qkv_projection"], weight_layout),
+            out_projection=self.out_projection.import_weights(weights["out_projection"], weight_layout),
+            query_norm=query_norm,
+            key_norm=key_norm,
+        )
```
lalamo/modules/common.py
CHANGED
```diff
@@ -2,19 +2,23 @@ from abc import abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from types import UnionType
+from typing import Self
 
 import equinox as eqx
 from cattrs import Converter
+from einops import rearrange
 from jax import numpy as jnp
-from jaxtyping import DTypeLike
+from jaxtyping import Array, DTypeLike, Float
 
-from lalamo.common import ParameterDict
+from lalamo.common import ParameterTree
 
 __all__ = [
     "AttentionType",
     "DummyUnionMember",
     "LalamoModule",
     "config_converter",
+    "from_layout",
+    "into_layout",
     "register_config_union",
 ]
 
@@ -34,6 +38,42 @@ class WeightLayout(Enum):
         return "(output, input)"
 
 
+_DEFAULT_WEIGHT_LAYOUT = WeightLayout.INPUT_OUTPUT
+
+
+def into_layout(
+    weights: Float[Array, "in_channels out_channels"],
+    layout: WeightLayout,
+) -> Float[Array, "in_channels out_channels"] | Float[Array, "out_channels in_channels"]:
+    if layout == WeightLayout.AUTO:
+        layout = _DEFAULT_WEIGHT_LAYOUT
+    match layout:
+        case WeightLayout.OUTPUT_INPUT:
+            return weights
+        case WeightLayout.INPUT_OUTPUT:
+            return rearrange(
+                weights,
+                "total_out_channels in_channels -> in_channels total_out_channels",
+            )
+
+
+def from_layout(
+    weights: ParameterTree | Array,
+    layout: WeightLayout,
+) -> Array:
+    assert isinstance(weights, Array)
+    if layout == WeightLayout.AUTO:
+        layout = _DEFAULT_WEIGHT_LAYOUT
+    match layout:
+        case WeightLayout.OUTPUT_INPUT:
+            return weights
+        case WeightLayout.INPUT_OUTPUT:
+            return rearrange(
+                weights,
+                "in_channels total_out_channels -> total_out_channels in_channels",
+            )
+
+
 class AttentionType(Enum):
     GLOBAL = "global"
     SLIDING_WINDOW = "sliding_window"
```
```diff
@@ -47,7 +87,14 @@ class LalamoModule[ConfigT](eqx.Module):
     def activation_precision(self) -> DTypeLike: ...
 
     @abstractmethod
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict: ...
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree[Array]: ...
+
+    @abstractmethod
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self: ...
 
 
 def _dtype_to_str(dtype: DTypeLike) -> str:
@@ -115,7 +162,7 @@ def register_config_union(union_type: UnionType) -> None:
         new_config = dict(config)
         type_name = new_config.pop("type")
         target_type = name_to_type[type_name]
-        return
+        return config_converter.structure(new_config, target_type)
 
     config_converter.register_structure_hook(
         union_type,
```