lalamo 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/model_import/__init__.py +8 -0
  3. lalamo/model_import/common.py +111 -0
  4. lalamo/model_import/configs/__init__.py +23 -0
  5. lalamo/model_import/configs/common.py +62 -0
  6. lalamo/model_import/configs/executorch.py +166 -0
  7. lalamo/model_import/configs/huggingface/__init__.py +18 -0
  8. lalamo/model_import/configs/huggingface/common.py +72 -0
  9. lalamo/model_import/configs/huggingface/gemma2.py +122 -0
  10. lalamo/model_import/configs/huggingface/gemma3.py +187 -0
  11. lalamo/model_import/configs/huggingface/llama.py +155 -0
  12. lalamo/model_import/configs/huggingface/mistral.py +132 -0
  13. lalamo/model_import/configs/huggingface/qwen2.py +144 -0
  14. lalamo/model_import/configs/huggingface/qwen3.py +142 -0
  15. lalamo/model_import/loaders/__init__.py +7 -0
  16. lalamo/model_import/loaders/common.py +45 -0
  17. lalamo/model_import/loaders/executorch.py +223 -0
  18. lalamo/model_import/loaders/huggingface.py +304 -0
  19. lalamo/model_import/model_specs/__init__.py +38 -0
  20. lalamo/model_import/model_specs/common.py +118 -0
  21. lalamo/model_import/model_specs/deepseek.py +28 -0
  22. lalamo/model_import/model_specs/gemma.py +76 -0
  23. lalamo/model_import/model_specs/huggingface.py +28 -0
  24. lalamo/model_import/model_specs/llama.py +101 -0
  25. lalamo/model_import/model_specs/mistral.py +59 -0
  26. lalamo/model_import/model_specs/pleias.py +28 -0
  27. lalamo/model_import/model_specs/polaris.py +22 -0
  28. lalamo/model_import/model_specs/qwen.py +336 -0
  29. lalamo/model_import/model_specs/reka.py +28 -0
  30. lalamo/modules/__init__.py +85 -0
  31. lalamo/modules/activations.py +30 -0
  32. lalamo/modules/attention.py +326 -0
  33. lalamo/modules/common.py +133 -0
  34. lalamo/modules/decoder.py +244 -0
  35. lalamo/modules/decoder_layer.py +240 -0
  36. lalamo/modules/embedding.py +299 -0
  37. lalamo/modules/kv_cache.py +196 -0
  38. lalamo/modules/linear.py +603 -0
  39. lalamo/modules/mlp.py +79 -0
  40. lalamo/modules/normalization.py +77 -0
  41. lalamo/modules/rope.py +255 -0
  42. lalamo/modules/utils.py +13 -0
  43. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/METADATA +1 -1
  44. lalamo-0.2.2.dist-info/RECORD +53 -0
  45. lalamo-0.2.1.dist-info/RECORD +0 -12
  46. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/WHEEL +0 -0
  47. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,59 @@
1
+ from dataclasses import replace
2
+
3
+ from lalamo.model_import.configs import HFMistralConfig
4
+
5
+ from .common import (
6
+ HUGGINFACE_GENERATION_CONFIG_FILE,
7
+ HUGGINGFACE_TOKENIZER_FILES,
8
+ ModelSpec,
9
+ TokenizerFileSpec,
10
+ UseCase,
11
+ WeightsType,
12
+ huggingface_weight_files,
13
+ )
14
+
15
+ __all__ = ["MISTRAL_MODELS"]
16
+
17
+ CODESTRAL = [
18
+ ModelSpec(
19
+ vendor="Mistral",
20
+ family="Codestral",
21
+ name="Codestral-22B-v0.1",
22
+ size="22B",
23
+ quantization=None,
24
+ repo="mistral-community/Codestral-22B-v0.1",
25
+ config_type=HFMistralConfig,
26
+ config_file_name="config.json",
27
+ weights_file_names=huggingface_weight_files(9),
28
+ weights_type=WeightsType.SAFETENSORS,
29
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
30
+ use_cases=(UseCase.CODE,),
31
+ ),
32
+ ]
33
+
34
+
35
def _tokenizer_files_from_another_repo(repo: str) -> tuple[TokenizerFileSpec, ...]:
    """Return the standard HF tokenizer + generation-config file specs, re-pointed at *repo*.

    Used for models (e.g. Devstral) that ship weights in one repository but
    reuse the tokenizer assets of another.
    """
    base_specs = (*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE)
    return tuple(replace(spec, repo=repo) for spec in base_specs)
39
+
40
+
41
+ DEVSTRAL = [
42
+ ModelSpec(
43
+ vendor="Mistral",
44
+ family="Devstral",
45
+ name="Devstral-Small-2505",
46
+ size="24B",
47
+ quantization=None,
48
+ repo="mistralai/Devstral-Small-2505",
49
+ config_type=HFMistralConfig,
50
+ config_file_name="config.json",
51
+ weights_file_names=huggingface_weight_files(10),
52
+ weights_type=WeightsType.SAFETENSORS,
53
+ tokenizer_files=_tokenizer_files_from_another_repo("mistralai/Mistral-Small-3.1-24B-Base-2503"),
54
+ use_cases=(UseCase.CODE,),
55
+ ),
56
+ ]
57
+
58
+
59
+ MISTRAL_MODELS = CODESTRAL + DEVSTRAL
@@ -0,0 +1,28 @@
1
+ from lalamo.model_import.configs import HFLlamaConfig
2
+
3
+ from .common import (
4
+ HUGGINFACE_GENERATION_CONFIG_FILE,
5
+ HUGGINGFACE_TOKENIZER_FILES,
6
+ ModelSpec,
7
+ WeightsType,
8
+ huggingface_weight_files,
9
+ )
10
+
11
+ __all__ = ["PLEIAS_MODELS"]
12
+
13
+ PLEIAS_MODELS = [
14
+ ModelSpec(
15
+ vendor="PleIAs",
16
+ family="Pleias-RAG",
17
+ name="Pleias-RAG-1B",
18
+ size="1B",
19
+ quantization=None,
20
+ repo="PleIAs/Pleias-RAG-1B",
21
+ config_type=HFLlamaConfig,
22
+ config_file_name="config.json",
23
+ weights_file_names=huggingface_weight_files(1),
24
+ weights_type=WeightsType.SAFETENSORS,
25
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
26
+ use_cases=tuple(),
27
+ ),
28
+ ]
@@ -0,0 +1,22 @@
1
+ from lalamo.model_import.configs import HFQwen3Config
2
+
3
+ from .common import HUGGINGFACE_TOKENIZER_FILES, ModelSpec, TokenizerFileSpec, WeightsType, huggingface_weight_files
4
+
5
+ __all__ = ["POLARIS_MODELS"]
6
+
7
+ POLARIS_MODELS = [
8
+ ModelSpec(
9
+ vendor="POLARIS-Project",
10
+ family="Polaris-Preview",
11
+ name="Polaris-4B-Preview",
12
+ size="4B",
13
+ quantization=None,
14
+ repo="POLARIS-Project/Polaris-4B-Preview",
15
+ config_type=HFQwen3Config,
16
+ config_file_name="config.json",
17
+ weights_file_names=huggingface_weight_files(2),
18
+ weights_type=WeightsType.SAFETENSORS,
19
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, TokenizerFileSpec(repo=None, filename="chat_template.jinja")),
20
+ use_cases=tuple(),
21
+ ),
22
+ ]
@@ -0,0 +1,336 @@
1
+ from lalamo.model_import.configs import HFQwen2Config, HFQwen3Config
2
+ from lalamo.quantization import QuantizationMode
3
+
4
+ from .common import (
5
+ HUGGINFACE_GENERATION_CONFIG_FILE,
6
+ HUGGINGFACE_TOKENIZER_FILES,
7
+ ModelSpec,
8
+ UseCase,
9
+ WeightsType,
10
+ huggingface_weight_files,
11
+ )
12
+
13
+ __all__ = ["QWEN_MODELS"]
14
+
15
+
16
+ QWEN25 = [
17
+ ModelSpec(
18
+ vendor="Alibaba",
19
+ family="Qwen2.5",
20
+ name="Qwen2.5-0.5B-Instruct",
21
+ size="0.5B",
22
+ quantization=None,
23
+ repo="Qwen/Qwen2.5-0.5B-Instruct",
24
+ config_type=HFQwen2Config,
25
+ config_file_name="config.json",
26
+ weights_file_names=huggingface_weight_files(1),
27
+ weights_type=WeightsType.SAFETENSORS,
28
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
29
+ use_cases=tuple(),
30
+ ),
31
+ ModelSpec(
32
+ vendor="Alibaba",
33
+ family="Qwen2.5",
34
+ name="Qwen2.5-1.5B-Instruct",
35
+ size="1.5B",
36
+ quantization=None,
37
+ repo="Qwen/Qwen2.5-1.5B-Instruct",
38
+ config_type=HFQwen2Config,
39
+ config_file_name="config.json",
40
+ weights_file_names=huggingface_weight_files(1),
41
+ weights_type=WeightsType.SAFETENSORS,
42
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
43
+ use_cases=tuple(),
44
+ ),
45
+ ModelSpec(
46
+ vendor="Alibaba",
47
+ family="Qwen2.5",
48
+ name="Qwen2.5-3B-Instruct",
49
+ size="3B",
50
+ quantization=None,
51
+ repo="Qwen/Qwen2.5-3B-Instruct",
52
+ config_type=HFQwen2Config,
53
+ config_file_name="config.json",
54
+ weights_file_names=huggingface_weight_files(2),
55
+ weights_type=WeightsType.SAFETENSORS,
56
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
57
+ use_cases=tuple(),
58
+ ),
59
+ ModelSpec(
60
+ vendor="Alibaba",
61
+ family="Qwen2.5",
62
+ name="Qwen2.5-7B-Instruct",
63
+ size="7B",
64
+ quantization=None,
65
+ repo="Qwen/Qwen2.5-7B-Instruct",
66
+ config_type=HFQwen2Config,
67
+ config_file_name="config.json",
68
+ weights_file_names=huggingface_weight_files(4),
69
+ weights_type=WeightsType.SAFETENSORS,
70
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
71
+ use_cases=tuple(),
72
+ ),
73
+ ModelSpec(
74
+ vendor="Alibaba",
75
+ family="Qwen2.5",
76
+ name="Qwen2.5-14B-Instruct",
77
+ size="14B",
78
+ quantization=None,
79
+ repo="Qwen/Qwen2.5-14B-Instruct",
80
+ config_type=HFQwen2Config,
81
+ config_file_name="config.json",
82
+ weights_file_names=huggingface_weight_files(8),
83
+ weights_type=WeightsType.SAFETENSORS,
84
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
85
+ use_cases=tuple(),
86
+ ),
87
+ ModelSpec(
88
+ vendor="Alibaba",
89
+ family="Qwen2.5",
90
+ name="Qwen2.5-32B-Instruct",
91
+ size="32B",
92
+ quantization=None,
93
+ repo="Qwen/Qwen2.5-32B-Instruct",
94
+ config_type=HFQwen2Config,
95
+ config_file_name="config.json",
96
+ weights_file_names=huggingface_weight_files(17),
97
+ weights_type=WeightsType.SAFETENSORS,
98
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
99
+ use_cases=tuple(),
100
+ ),
101
+ ]
102
+
103
+
104
+ QWEN25_CODER = [
105
+ ModelSpec(
106
+ vendor="Alibaba",
107
+ family="Qwen2.5-Coder",
108
+ name="Qwen2.5-Coder-0.5B-Instruct",
109
+ size="0.5B",
110
+ quantization=None,
111
+ repo="Qwen/Qwen2.5-Coder-0.5B-Instruct",
112
+ config_type=HFQwen2Config,
113
+ config_file_name="config.json",
114
+ weights_file_names=huggingface_weight_files(1),
115
+ weights_type=WeightsType.SAFETENSORS,
116
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
117
+ use_cases=(UseCase.CODE,),
118
+ ),
119
+ ModelSpec(
120
+ vendor="Alibaba",
121
+ family="Qwen2.5-Coder",
122
+ name="Qwen2.5-Coder-1.5B-Instruct",
123
+ size="1.5B",
124
+ quantization=None,
125
+ repo="Qwen/Qwen2.5-Coder-1.5B-Instruct",
126
+ config_type=HFQwen2Config,
127
+ config_file_name="config.json",
128
+ weights_file_names=huggingface_weight_files(1),
129
+ weights_type=WeightsType.SAFETENSORS,
130
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
131
+ use_cases=(UseCase.CODE,),
132
+ ),
133
+ ModelSpec(
134
+ vendor="Alibaba",
135
+ family="Qwen2.5-Coder",
136
+ name="Qwen2.5-Coder-3B-Instruct",
137
+ size="3B",
138
+ quantization=None,
139
+ repo="Qwen/Qwen2.5-Coder-3B-Instruct",
140
+ config_type=HFQwen2Config,
141
+ config_file_name="config.json",
142
+ weights_file_names=huggingface_weight_files(2),
143
+ weights_type=WeightsType.SAFETENSORS,
144
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
145
+ use_cases=(UseCase.CODE,),
146
+ ),
147
+ ModelSpec(
148
+ vendor="Alibaba",
149
+ family="Qwen2.5-Coder",
150
+ name="Qwen2.5-Coder-7B-Instruct",
151
+ size="7B",
152
+ quantization=None,
153
+ repo="Qwen/Qwen2.5-Coder-7B-Instruct",
154
+ config_type=HFQwen2Config,
155
+ config_file_name="config.json",
156
+ weights_file_names=huggingface_weight_files(4),
157
+ weights_type=WeightsType.SAFETENSORS,
158
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
159
+ use_cases=(UseCase.CODE,),
160
+ ),
161
+ ModelSpec(
162
+ vendor="Alibaba",
163
+ family="Qwen2.5-Coder",
164
+ name="Qwen2.5-Coder-14B-Instruct",
165
+ size="14B",
166
+ quantization=None,
167
+ repo="Qwen/Qwen2.5-Coder-14B-Instruct",
168
+ config_type=HFQwen2Config,
169
+ config_file_name="config.json",
170
+ weights_file_names=huggingface_weight_files(6),
171
+ weights_type=WeightsType.SAFETENSORS,
172
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
173
+ use_cases=(UseCase.CODE,),
174
+ ),
175
+ ModelSpec(
176
+ vendor="Alibaba",
177
+ family="Qwen2.5-Coder",
178
+ name="Qwen2.5-Coder-32B-Instruct",
179
+ size="32B",
180
+ quantization=None,
181
+ repo="Qwen/Qwen2.5-Coder-32B-Instruct",
182
+ config_type=HFQwen2Config,
183
+ config_file_name="config.json",
184
+ weights_file_names=huggingface_weight_files(14),
185
+ weights_type=WeightsType.SAFETENSORS,
186
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
187
+ use_cases=(UseCase.CODE,),
188
+ ),
189
+ ]
190
+
191
+
192
+ QWEN3 = [
193
+ ModelSpec(
194
+ vendor="Alibaba",
195
+ family="Qwen3",
196
+ name="Qwen3-0.6B",
197
+ size="0.6B",
198
+ quantization=None,
199
+ repo="Qwen/Qwen3-0.6B",
200
+ config_type=HFQwen3Config,
201
+ config_file_name="config.json",
202
+ weights_file_names=huggingface_weight_files(1),
203
+ weights_type=WeightsType.SAFETENSORS,
204
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
205
+ use_cases=tuple(),
206
+ ),
207
+ ModelSpec(
208
+ vendor="Alibaba",
209
+ family="Qwen3",
210
+ name="Qwen3-1.7B",
211
+ size="1.7B",
212
+ quantization=None,
213
+ repo="Qwen/Qwen3-1.7B",
214
+ config_type=HFQwen3Config,
215
+ config_file_name="config.json",
216
+ weights_file_names=huggingface_weight_files(2),
217
+ weights_type=WeightsType.SAFETENSORS,
218
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
219
+ use_cases=tuple(),
220
+ ),
221
+ ModelSpec(
222
+ vendor="Alibaba",
223
+ family="Qwen3",
224
+ name="Qwen3-4B",
225
+ size="4B",
226
+ quantization=None,
227
+ repo="Qwen/Qwen3-4B",
228
+ config_type=HFQwen3Config,
229
+ config_file_name="config.json",
230
+ weights_file_names=huggingface_weight_files(3),
231
+ weights_type=WeightsType.SAFETENSORS,
232
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
233
+ use_cases=tuple(),
234
+ ),
235
+ ModelSpec(
236
+ vendor="Alibaba",
237
+ family="Qwen3",
238
+ name="Qwen3-4B-AWQ",
239
+ size="4B",
240
+ quantization=QuantizationMode.UINT4,
241
+ repo="Qwen/Qwen3-4B-AWQ",
242
+ config_type=HFQwen3Config,
243
+ config_file_name="config.json",
244
+ weights_file_names=huggingface_weight_files(1),
245
+ weights_type=WeightsType.SAFETENSORS,
246
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
247
+ use_cases=tuple(),
248
+ ),
249
+ ModelSpec(
250
+ vendor="Alibaba",
251
+ family="Qwen3",
252
+ name="Qwen3-8B",
253
+ size="8B",
254
+ quantization=None,
255
+ repo="Qwen/Qwen3-8B",
256
+ config_type=HFQwen3Config,
257
+ config_file_name="config.json",
258
+ weights_file_names=huggingface_weight_files(5),
259
+ weights_type=WeightsType.SAFETENSORS,
260
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
261
+ use_cases=tuple(),
262
+ ),
263
+ ModelSpec(
264
+ vendor="Alibaba",
265
+ family="Qwen3",
266
+ name="Qwen3-8B-AWQ",
267
+ size="8B",
268
+ quantization=QuantizationMode.UINT4,
269
+ repo="Qwen/Qwen3-8B-AWQ",
270
+ config_type=HFQwen3Config,
271
+ config_file_name="config.json",
272
+ weights_file_names=huggingface_weight_files(2),
273
+ weights_type=WeightsType.SAFETENSORS,
274
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
275
+ use_cases=tuple(),
276
+ ),
277
+ ModelSpec(
278
+ vendor="Alibaba",
279
+ family="Qwen3",
280
+ name="Qwen3-14B",
281
+ size="14B",
282
+ quantization=None,
283
+ repo="Qwen/Qwen3-14B",
284
+ config_type=HFQwen3Config,
285
+ config_file_name="config.json",
286
+ weights_file_names=huggingface_weight_files(8),
287
+ weights_type=WeightsType.SAFETENSORS,
288
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
289
+ use_cases=tuple(),
290
+ ),
291
+ ModelSpec(
292
+ vendor="Alibaba",
293
+ family="Qwen3",
294
+ name="Qwen3-14B-AWQ",
295
+ size="14B",
296
+ quantization=None,
297
+ repo="Qwen/Qwen3-14B-AWQ",
298
+ config_type=HFQwen3Config,
299
+ config_file_name="config.json",
300
+ weights_file_names=huggingface_weight_files(2),
301
+ weights_type=WeightsType.SAFETENSORS,
302
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
303
+ use_cases=tuple(),
304
+ ),
305
+ ModelSpec(
306
+ vendor="Alibaba",
307
+ family="Qwen3",
308
+ name="Qwen3-32B",
309
+ size="32B",
310
+ quantization=None,
311
+ repo="Qwen/Qwen3-32B",
312
+ config_type=HFQwen3Config,
313
+ config_file_name="config.json",
314
+ weights_file_names=huggingface_weight_files(17),
315
+ weights_type=WeightsType.SAFETENSORS,
316
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
317
+ use_cases=tuple(),
318
+ ),
319
+ ModelSpec(
320
+ vendor="Alibaba",
321
+ family="Qwen3",
322
+ name="Qwen3-32B-AWQ",
323
+ size="32B",
324
+ quantization=QuantizationMode.UINT4,
325
+ repo="Qwen/Qwen3-32B-AWQ",
326
+ config_type=HFQwen3Config,
327
+ config_file_name="config.json",
328
+ weights_file_names=huggingface_weight_files(4),
329
+ weights_type=WeightsType.SAFETENSORS,
330
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
331
+ use_cases=tuple(),
332
+ ),
333
+ ]
334
+
335
+
336
+ QWEN_MODELS = QWEN25 + QWEN25_CODER + QWEN3
@@ -0,0 +1,28 @@
1
+ from lalamo.model_import.configs import HFLlamaConfig
2
+
3
+ from .common import (
4
+ HUGGINFACE_GENERATION_CONFIG_FILE,
5
+ HUGGINGFACE_TOKENIZER_FILES,
6
+ ModelSpec,
7
+ WeightsType,
8
+ huggingface_weight_files,
9
+ )
10
+
11
+ __all__ = ["REKA_MODELS"]
12
+
13
+ REKA_MODELS = [
14
+ ModelSpec(
15
+ vendor="Reka",
16
+ family="Reka-Flash",
17
+ name="Reka-Flash-3.1",
18
+ size="21B",
19
+ quantization=None,
20
+ repo="RekaAI/reka-flash-3.1",
21
+ config_type=HFLlamaConfig,
22
+ config_file_name="config.json",
23
+ weights_file_names=huggingface_weight_files(9), # Model has 9 shards
24
+ weights_type=WeightsType.SAFETENSORS,
25
+ tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
26
+ use_cases=tuple(),
27
+ ),
28
+ ]
@@ -0,0 +1,85 @@
1
+ from .activations import Activation
2
+ from .attention import Attention, AttentionConfig
3
+ from .common import WeightLayout, config_converter
4
+ from .decoder import Decoder, DecoderActivationTrace, DecoderConfig, DecoderResult
5
+ from .decoder_layer import DecoderLayer, DecoderLayerActivationTrace, DecoderLayerConfig, DecoderLayerResult
6
+ from .embedding import (
7
+ EmbeddingBase,
8
+ EmbeddingConfig,
9
+ QuantizedTiedEmbedding,
10
+ QuantizedTiedEmbeddingConfig,
11
+ TiedEmbedding,
12
+ TiedEmbeddingConfig,
13
+ UntiedEmbedding,
14
+ UntiedEmbeddingConfig,
15
+ )
16
+ from .kv_cache import DynamicKVCacheLayer, KVCache, KVCacheLayer, StaticKVCacheLayer
17
+ from .linear import (
18
+ FullPrecisionLinear,
19
+ FullPrecisionLinearConfig,
20
+ GroupQuantizedLinear,
21
+ GroupQuantizedLinearConfig,
22
+ LinearBase,
23
+ LinearConfig,
24
+ QLoRALinear,
25
+ QLoRALinearConfig,
26
+ )
27
+ from .mlp import MLP, MLPConfig
28
+ from .normalization import RMSNorm, RMSNormConfig, UpcastMode
29
+ from .rope import (
30
+ LinearScalingRoPEConfig,
31
+ LlamaRoPEConfig,
32
+ PositionalEmbeddings,
33
+ RoPE,
34
+ RoPEConfig,
35
+ UnscaledRoPEConfig,
36
+ YARNRoPEConfig,
37
+ )
38
+
39
+ __all__ = [
40
+ "MLP",
41
+ "Activation",
42
+ "Attention",
43
+ "AttentionConfig",
44
+ "Decoder",
45
+ "DecoderActivationTrace",
46
+ "DecoderConfig",
47
+ "DecoderLayer",
48
+ "DecoderLayerActivationTrace",
49
+ "DecoderLayerConfig",
50
+ "DecoderLayerResult",
51
+ "DecoderResult",
52
+ "DynamicKVCacheLayer",
53
+ "EmbeddingBase",
54
+ "EmbeddingConfig",
55
+ "FullPrecisionLinear",
56
+ "FullPrecisionLinearConfig",
57
+ "GroupQuantizedLinear",
58
+ "GroupQuantizedLinearConfig",
59
+ "KVCache",
60
+ "KVCacheLayer",
61
+ "LinearBase",
62
+ "LinearConfig",
63
+ "LinearScalingRoPEConfig",
64
+ "LlamaRoPEConfig",
65
+ "MLPConfig",
66
+ "PositionalEmbeddings",
67
+ "QLoRALinear",
68
+ "QLoRALinearConfig",
69
+ "QuantizedTiedEmbedding",
70
+ "QuantizedTiedEmbeddingConfig",
71
+ "RMSNorm",
72
+ "RMSNormConfig",
73
+ "RoPE",
74
+ "RoPEConfig",
75
+ "StaticKVCacheLayer",
76
+ "TiedEmbedding",
77
+ "TiedEmbeddingConfig",
78
+ "UnscaledRoPEConfig",
79
+ "UntiedEmbedding",
80
+ "UntiedEmbeddingConfig",
81
+ "UpcastMode",
82
+ "WeightLayout",
83
+ "YARNRoPEConfig",
84
+ "config_converter",
85
+ ]
@@ -0,0 +1,30 @@
1
+ from enum import Enum
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from jax import jit
6
+ from jaxtyping import Array, Float
7
+
8
+ __all__ = [
9
+ "Activation",
10
+ "silu",
11
+ ]
12
+
13
+
14
@jit
def silu(x: Float[Array, "*dims"]) -> Float[Array, "*dims"]:
    """Sigmoid-weighted linear unit: ``silu(x) = x * sigmoid(x)``.

    Implemented via ``jax.nn.sigmoid`` instead of the naive
    ``x / (1 + jnp.exp(-x))``: the naive form overflows ``exp(-x)`` to inf
    for large-magnitude negative inputs, producing overflow warnings and
    relying on x/inf division to recover the correct limit of 0.
    """
    return x * jax.nn.sigmoid(x)


class Activation(Enum):
    """Closed set of supported activation functions; members are callable."""

    SILU = "silu"
    GELU = "gelu"

    def __call__(self, x: Float[Array, "*dims"]) -> Float[Array, "*dims"]:
        """Apply the activation named by this member to ``x`` elementwise."""
        return ACTIVATION_FUNCTIONS[self](x)


# Dispatch table mapping each enum member to its implementation.
ACTIVATION_FUNCTIONS = {
    Activation.SILU: silu,
    Activation.GELU: jax.nn.gelu,
}