lalamo-0.2.7-py3-none-any.whl → lalamo-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. lalamo/__init__.py +1 -1
  2. lalamo/common.py +79 -29
  3. lalamo/language_model.py +106 -83
  4. lalamo/main.py +91 -18
  5. lalamo/message_processor.py +170 -0
  6. lalamo/model_import/common.py +159 -43
  7. lalamo/model_import/{configs → decoder_configs}/__init__.py +0 -1
  8. lalamo/model_import/{configs → decoder_configs}/common.py +11 -10
  9. lalamo/model_import/{configs → decoder_configs}/huggingface/common.py +9 -4
  10. lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py +2 -2
  11. lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py +2 -2
  12. lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py +1 -1
  13. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py +1 -1
  14. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py +1 -1
  15. lalamo/model_import/huggingface_generation_config.py +44 -0
  16. lalamo/model_import/huggingface_tokenizer_config.py +85 -0
  17. lalamo/model_import/loaders/common.py +2 -1
  18. lalamo/model_import/loaders/huggingface.py +12 -10
  19. lalamo/model_import/model_specs/__init__.py +3 -2
  20. lalamo/model_import/model_specs/common.py +32 -34
  21. lalamo/model_import/model_specs/deepseek.py +1 -10
  22. lalamo/model_import/model_specs/gemma.py +2 -25
  23. lalamo/model_import/model_specs/huggingface.py +2 -12
  24. lalamo/model_import/model_specs/llama.py +2 -58
  25. lalamo/model_import/model_specs/mistral.py +9 -19
  26. lalamo/model_import/model_specs/pleias.py +3 -13
  27. lalamo/model_import/model_specs/polaris.py +5 -7
  28. lalamo/model_import/model_specs/qwen.py +12 -111
  29. lalamo/model_import/model_specs/reka.py +4 -13
  30. lalamo/modules/__init__.py +2 -1
  31. lalamo/modules/attention.py +90 -10
  32. lalamo/modules/common.py +51 -4
  33. lalamo/modules/decoder.py +90 -8
  34. lalamo/modules/decoder_layer.py +85 -8
  35. lalamo/modules/embedding.py +95 -29
  36. lalamo/modules/kv_cache.py +3 -3
  37. lalamo/modules/linear.py +170 -130
  38. lalamo/modules/mlp.py +40 -7
  39. lalamo/modules/normalization.py +24 -6
  40. lalamo/modules/rope.py +24 -6
  41. lalamo/sampling.py +99 -0
  42. lalamo/utils.py +86 -1
  43. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/METADATA +6 -6
  44. lalamo-0.3.0.dist-info/RECORD +58 -0
  45. lalamo-0.2.7.dist-info/RECORD +0 -54
  46. /lalamo/model_import/{configs → decoder_configs}/executorch.py +0 -0
  47. /lalamo/model_import/{configs → decoder_configs}/huggingface/__init__.py +0 -0
  48. /lalamo/model_import/{configs → decoder_configs}/huggingface/gemma2.py +0 -0
  49. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/WHEEL +0 -0
  50. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/entry_points.txt +0 -0
  51. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/licenses/LICENSE +0 -0
  52. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/top_level.txt +0 -0
lalamo/model_import/model_specs/qwen.py CHANGED
@@ -1,14 +1,7 @@
-from lalamo.model_import.configs import HFQwen2Config, HFQwen3Config
+from lalamo.model_import.decoder_configs import HFQwen2Config, HFQwen3Config
 from lalamo.quantization import QuantizationMode
 
-from .common import (
-    HUGGINFACE_GENERATION_CONFIG_FILE,
-    HUGGINGFACE_TOKENIZER_FILES,
-    ModelSpec,
-    UseCase,
-    WeightsType,
-    huggingface_weight_files,
-)
+from .common import ModelSpec, UseCase, WeightsType
 
 __all__ = ["QWEN_MODELS"]
 
@@ -22,11 +15,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-0.5B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -36,11 +24,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-1.5B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -50,11 +33,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-3B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(2),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -64,11 +42,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-7B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(4),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -78,11 +51,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-14B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(8),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -92,11 +60,6 @@ QWEN25 = [
         quantization=None,
         repo="Qwen/Qwen2.5-32B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(17),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
 ]
 
@@ -110,10 +73,6 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-0.5B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=(UseCase.CODE,),
     ),
     ModelSpec(
@@ -124,10 +83,6 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-1.5B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=(UseCase.CODE,),
     ),
     ModelSpec(
@@ -138,10 +93,6 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-3B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(2),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=(UseCase.CODE,),
     ),
     ModelSpec(
@@ -152,10 +103,6 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-7B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(4),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=(UseCase.CODE,),
     ),
     ModelSpec(
@@ -166,10 +113,6 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-14B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(6),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=(UseCase.CODE,),
     ),
     ModelSpec(
@@ -180,10 +123,16 @@ QWEN25_CODER = [
         quantization=None,
         repo="Qwen/Qwen2.5-Coder-32B-Instruct",
         config_type=HFQwen2Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(14),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        use_cases=(UseCase.CODE,),
+    ),
+    ModelSpec(
+        vendor="Alibaba",
+        family="Qwen2.5-Coder",
+        name="Qwen2.5-Coder-32B-Instruct",
+        size="32B",
+        quantization=None,
+        repo="Qwen/Qwen2.5-Coder-32B-Instruct",
+        config_type=HFQwen2Config,
         use_cases=(UseCase.CODE,),
     ),
 ]
@@ -198,11 +147,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-0.6B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -212,10 +156,7 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-1.7B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(2),
         weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=tuple(),
     ),
     ModelSpec(
@@ -226,11 +167,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-4B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(3),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -240,11 +176,6 @@ QWEN3 = [
         quantization=QuantizationMode.UINT4,
         repo="Qwen/Qwen3-4B-AWQ",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(1),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -254,11 +185,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-8B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(5),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -268,11 +194,6 @@ QWEN3 = [
         quantization=QuantizationMode.UINT4,
         repo="Qwen/Qwen3-8B-AWQ",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(2),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
    ),
     ModelSpec(
         vendor="Alibaba",
@@ -282,11 +203,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-14B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(8),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -296,11 +212,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-14B-AWQ",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(2),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -310,11 +221,6 @@ QWEN3 = [
         quantization=None,
         repo="Qwen/Qwen3-32B",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(17),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
     ModelSpec(
         vendor="Alibaba",
@@ -324,11 +230,6 @@ QWEN3 = [
         quantization=QuantizationMode.UINT4,
         repo="Qwen/Qwen3-32B-AWQ",
         config_type=HFQwen3Config,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(4),
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-        use_cases=tuple(),
     ),
 ]
 
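Every entry above sheds the same boilerplate fields (config_file_name, weights_file_names, weights_type, tokenizer_files, use_cases). Since lalamo/model_import/model_specs/common.py also changed in this release (+32 -34 in the file list), those fields have presumably become defaulted on ModelSpec rather than required per entry. A minimal sketch of the resulting shape, with hypothetical default values; only the field names come from this diff:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ModelSpecSketch:
        # Required identity fields, as in the entries above.
        vendor: str
        family: str
        name: str
        size: str
        quantization: object | None  # QuantizationMode | None in the real code
        repo: str
        config_type: type
        # Fields removed from every entry, presumably now defaulted:
        config_file_name: str = "config.json"
        use_cases: tuple = ()

    spec = ModelSpecSketch(
        vendor="Alibaba",
        family="Qwen2.5",
        name="Qwen2.5-0.5B-Instruct",
        size="0.5B",
        quantization=None,
        repo="Qwen/Qwen2.5-0.5B-Instruct",
        config_type=object,  # HFQwen2Config in the real code
    )
    assert spec.config_file_name == "config.json"
    assert spec.use_cases == ()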
lalamo/model_import/model_specs/reka.py CHANGED
@@ -1,12 +1,6 @@
-from lalamo.model_import.configs import HFLlamaConfig
+from lalamo.model_import.decoder_configs import HFLlamaConfig
 
-from .common import (
-    HUGGINFACE_GENERATION_CONFIG_FILE,
-    HUGGINGFACE_TOKENIZER_FILES,
-    ModelSpec,
-    WeightsType,
-    huggingface_weight_files,
-)
+from .common import ModelSpec
 
 __all__ = ["REKA_MODELS"]
 
@@ -19,10 +13,7 @@ REKA_MODELS = [
         quantization=None,
         repo="RekaAI/reka-flash-3.1",
         config_type=HFLlamaConfig,
-        config_file_name="config.json",
-        weights_file_names=huggingface_weight_files(9),  # Model has 9 shards
-        weights_type=WeightsType.SAFETENSORS,
-        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        user_role_name="human",
         use_cases=tuple(),
     ),
-]
+]
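The user_role_name="human" field is the one behavioral addition here; it suggests ModelSpec now carries chat-role naming, most plausibly consumed by the new lalamo/message_processor.py (+170 in the file list). A hedged illustration of how such an override could be applied when rendering chat messages; the function and template below are hypothetical, only the field name and value come from the diff:

    def render_message(role: str, text: str, user_role_name: str = "user") -> str:
        # Map the canonical "user" role to the model's expected name, e.g.
        # "human" for reka-flash-3.1, before filling the chat template.
        effective_role = user_role_name if role == "user" else role
        return f"<|{effective_role}|>{text}"

    assert render_message("user", "hi", user_role_name="human") == "<|human|>hi"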
lalamo/modules/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from .activations import Activation
 from .attention import Attention, AttentionConfig
-from .common import WeightLayout, config_converter
+from .common import LalamoModule, WeightLayout, config_converter
 from .decoder import Decoder, DecoderActivationTrace, DecoderConfig, DecoderResult
 from .decoder_layer import DecoderLayer, DecoderLayerActivationTrace, DecoderLayerConfig, DecoderLayerResult
 from .embedding import (
@@ -58,6 +58,7 @@ __all__ = [
     "GroupQuantizedLinearConfig",
     "KVCache",
     "KVCacheLayer",
+    "LalamoModule",
     "LinearBase",
     "LinearConfig",
     "LinearScalingRoPEConfig",
lalamo/modules/attention.py CHANGED
@@ -1,5 +1,6 @@
-from dataclasses import dataclass
-from typing import NamedTuple
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from typing import NamedTuple, Self
 
 import equinox as eqx
 import jax
@@ -8,10 +9,9 @@ from jax import numpy as jnp
 from jax import vmap
 from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterDict
 from lalamo.modules.normalization import RMSNorm, RMSNormConfig
 
-from .common import AttentionType, LalamoModule, WeightLayout
+from .common import AttentionType, LalamoModule, ParameterTree, WeightLayout
 from .kv_cache import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer
 from .linear import LinearBase, LinearConfig
 from .rope import PositionalEmbeddings
@@ -42,8 +42,8 @@ def _soft_capped_attention_kernel(
     scale: float | None,
     logit_soft_cap: float,
 ) -> Float[Array, "dst_tokens heads head_channels"]:
-    dst_length, num_heads, head_dim = queries.shape
-    src_length, num_groups, _ = keys.shape
+    _, num_heads, head_dim = queries.shape
+    _, num_groups, _ = keys.shape
     if scale is None:
         scale = head_dim**-0.5
     group_size = num_heads // num_groups
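The change above only discards the unused dst_length/src_length bindings. For context, a self-contained sketch of what a soft-capped attention kernel of this shape computes; the tanh form of the cap is an assumption based on the usual definition (as in Gemma-style models), and the sketch ignores the grouped-query handling (num_groups) that the real kernel performs:

    import jax.numpy as jnp
    from jax.nn import softmax

    def soft_capped_scores(queries, keys, logit_soft_cap, scale=None):
        # queries: (dst_tokens, heads, head_channels)
        # keys:    (src_tokens, heads, head_channels)
        _, _, head_dim = queries.shape
        if scale is None:
            scale = head_dim**-0.5
        logits = jnp.einsum("dhc,shc->hds", queries, keys) * scale
        # Soft capping squashes logits smoothly into (-cap, cap) instead of clipping.
        logits = logit_soft_cap * jnp.tanh(logits / logit_soft_cap)
        return softmax(logits, axis=-1)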
@@ -118,14 +118,67 @@ class AttentionConfig:
 
         if self.query_norm_config is not None:
             query_norm = self.query_norm_config.init(
-                channels=head_dim,
+                input_dim=head_dim,
             )
         else:
             query_norm = None
 
         if self.key_norm_config is not None:
             key_norm = self.key_norm_config.init(
-                channels=head_dim,
+                input_dim=head_dim,
+            )
+        else:
+            key_norm = None
+
+        return Attention(
+            self,
+            qkv_projection=qkv_projection,
+            out_projection=out_projection,
+            query_norm=query_norm,
+            key_norm=key_norm,
+            num_heads=num_heads,
+            num_groups=num_groups,
+            head_dim=head_dim,
+            is_causal=is_causal,
+            scale=scale,
+            sliding_window_size=sliding_window_size,
+        )
+
+    def empty(
+        self,
+        model_dim: int,
+        num_heads: int,
+        num_groups: int,
+        head_dim: int,
+        is_causal: bool,
+        scale: float | None,
+        sliding_window_size: int | None,
+    ) -> "Attention":
+        qkv_projection = self.qkv_projection_config.empty(
+            input_dim=model_dim,
+            output_dims=(
+                num_heads * head_dim,
+                num_groups * head_dim,
+                num_groups * head_dim,
+            ),
+            has_biases=self.has_qkv_biases,
+        )
+        out_projection = self.out_projection_config.empty(
+            num_heads * head_dim,
+            (model_dim,),
+            has_biases=self.has_out_biases,
+        )
+
+        if self.query_norm_config is not None:
+            query_norm = self.query_norm_config.empty(
+                input_dim=head_dim,
+            )
+        else:
+            query_norm = None
+
+        if self.key_norm_config is not None:
+            key_norm = self.key_norm_config.empty(
+                input_dim=head_dim,
             )
         else:
             key_norm = None
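empty() mirrors init() but builds a module skeleton without randomly initialized parameters, which pairs with the new import_weights below for loading pretrained checkpoints. A usage sketch with illustrative dimensions; attention_config and pretrained_weights are hypothetical names, the signature itself is taken from the diff:

    attention = attention_config.empty(
        model_dim=2048,
        num_heads=16,
        num_groups=4,             # grouped-query attention: 16 heads share 4 KV groups
        head_dim=128,
        is_causal=True,
        scale=None,               # None falls back to head_dim**-0.5 in the kernel
        sliding_window_size=None,
    )
    attention = attention.import_weights(pretrained_weights, WeightLayout.AUTO)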
@@ -233,6 +286,7 @@ class Attention(LalamoModule[AttentionConfig]):
                 f" got {v_output_dim}",
             )
 
+    @eqx.filter_jit
     def __call__(
         self,
         inputs: Float[Array, "suffix_tokens channels"],
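eqx.filter_jit JIT-compiles __call__ with equinox's filtering semantics: JAX arrays in the module and its arguments are traced, while non-array leaves (config values, bools, ints) are held static. A self-contained toy example of the decorator's behavior:

    import equinox as eqx
    import jax.numpy as jnp

    class Scale(eqx.Module):
        factor: jnp.ndarray

        @eqx.filter_jit  # arrays are traced; non-array leaves are static
        def __call__(self, x):
            return self.factor * x

    print(Scale(jnp.array(2.0))(jnp.arange(3.0)))  # [0. 2. 4.]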
@@ -314,8 +368,8 @@ class Attention(LalamoModule[AttentionConfig]):
     def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
         return StaticKVCacheLayer.empty(capacity, self.num_groups, self.head_dim, self.activation_precision)
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
-        result = ParameterDict(
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        result = dict(
             qkv_projection=self.qkv_projection.export_weights(weight_layout),
             out_projection=self.out_projection.export_weights(weight_layout),
         )
@@ -324,3 +378,29 @@ class Attention(LalamoModule[AttentionConfig]):
         if self.key_norm is not None:
             result["key_norm"] = self.key_norm.export_weights(weight_layout)
         return result
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["qkv_projection"], Mapping)
+        assert isinstance(weights["out_projection"], Mapping)
+        if self.query_norm is not None:
+            assert isinstance(weights["query_norm"], Mapping)
+            query_norm = self.query_norm.import_weights(weights["query_norm"], weight_layout)
+        else:
+            query_norm = None
+        if self.key_norm is not None:
+            assert isinstance(weights["key_norm"], Mapping)
+            key_norm = self.key_norm.import_weights(weights["key_norm"], weight_layout)
+        else:
+            key_norm = None
+        return replace(
+            self,
+            qkv_projection=self.qkv_projection.import_weights(weights["qkv_projection"], weight_layout),
+            out_projection=self.out_projection.import_weights(weights["out_projection"], weight_layout),
+            query_norm=query_norm,
+            key_norm=key_norm,
+        )
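export_weights and import_weights are designed as inverses: export produces a nested Mapping (a ParameterTree) keyed by submodule, and import walks the same structure, rebuilding the module functionally via dataclasses.replace. A round-trip sketch, where module stands for any Attention or other LalamoModule instance:

    tree = module.export_weights(WeightLayout.INPUT_OUTPUT)
    qkv_weights = tree["qkv_projection"]  # per the asserts above, itself a Mapping
    restored = module.import_weights(tree, WeightLayout.INPUT_OUTPUT)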
lalamo/modules/common.py CHANGED
@@ -2,19 +2,23 @@ from abc import abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from types import UnionType
+from typing import Self
 
 import equinox as eqx
 from cattrs import Converter
+from einops import rearrange
 from jax import numpy as jnp
-from jaxtyping import DTypeLike
+from jaxtyping import Array, DTypeLike, Float
 
-from lalamo.common import ParameterDict
+from lalamo.common import ParameterTree
 
 __all__ = [
     "AttentionType",
     "DummyUnionMember",
     "LalamoModule",
     "config_converter",
+    "from_layout",
+    "into_layout",
     "register_config_union",
 ]
 
@@ -34,6 +38,42 @@ class WeightLayout(Enum):
         return "(output, input)"
 
 
+_DEFAULT_WEIGHT_LAYOUT = WeightLayout.INPUT_OUTPUT
+
+
+def into_layout(
+    weights: Float[Array, "in_channels out_channels"],
+    layout: WeightLayout,
+) -> Float[Array, "in_channels out_channels"] | Float[Array, "out_channels in_channels"]:
+    if layout == WeightLayout.AUTO:
+        layout = _DEFAULT_WEIGHT_LAYOUT
+    match layout:
+        case WeightLayout.OUTPUT_INPUT:
+            return weights
+        case WeightLayout.INPUT_OUTPUT:
+            return rearrange(
+                weights,
+                "total_out_channels in_channels -> in_channels total_out_channels",
+            )
+
+
+def from_layout(
+    weights: ParameterTree | Array,
+    layout: WeightLayout,
+) -> Array:
+    assert isinstance(weights, Array)
+    if layout == WeightLayout.AUTO:
+        layout = _DEFAULT_WEIGHT_LAYOUT
+    match layout:
+        case WeightLayout.OUTPUT_INPUT:
+            return weights
+        case WeightLayout.INPUT_OUTPUT:
+            return rearrange(
+                weights,
+                "in_channels total_out_channels -> total_out_channels in_channels",
+            )
+
+
 class AttentionType(Enum):
     GLOBAL = "global"
     SLIDING_WINDOW = "sliding_window"
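Judging by the OUTPUT_INPUT branches returning weights unchanged, parameters are stored output-major in memory; into_layout transposes them into the requested export layout (INPUT_OUTPUT by default, per _DEFAULT_WEIGHT_LAYOUT) and from_layout undoes it on import. A worked example of the two rearranges:

    import jax.numpy as jnp
    from einops import rearrange

    w = jnp.arange(6).reshape(3, 2)         # stored as (out_channels=3, in_channels=2)
    w_io = rearrange(w, "o i -> i o")       # into_layout(..., INPUT_OUTPUT): shape (2, 3)
    w_back = rearrange(w_io, "i o -> o i")  # from_layout(..., INPUT_OUTPUT) restores (3, 2)
    assert (w_back == w).all()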
@@ -47,7 +87,14 @@ class LalamoModule[ConfigT](eqx.Module):
     def activation_precision(self) -> DTypeLike: ...
 
     @abstractmethod
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict: ...
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree[Array]: ...
+
+    @abstractmethod
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self: ...
 
 
 def _dtype_to_str(dtype: DTypeLike) -> str:
@@ -115,7 +162,7 @@ def register_config_union(union_type: UnionType) -> None:
         new_config = dict(config)
         type_name = new_config.pop("type")
         target_type = name_to_type[type_name]
-        return name_to_type[type_name](**config_converter.structure(new_config, target_type))
+        return config_converter.structure(new_config, target_type)
 
     config_converter.register_structure_hook(
         union_type,
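The one-line change in this hunk returns cattrs' structured result directly; 0.2.7 re-invoked the target constructor on the already-structured config. A miniature of the tagged-union dispatch this hook implements; class and mapping names below are illustrative, not lalamo's:

    from dataclasses import dataclass
    from cattrs import Converter

    converter = Converter()

    @dataclass
    class RMSNormConfigSketch:
        epsilon: float

    name_to_type = {"rms_norm": RMSNormConfigSketch}

    def structure_tagged(config: dict) -> object:
        new_config = dict(config)
        type_name = new_config.pop("type")  # the "type" key selects the union member
        target_type = name_to_type[type_name]
        return converter.structure(new_config, target_type)

    assert structure_tagged({"type": "rms_norm", "epsilon": 1e-6}) == RMSNormConfigSketch(1e-6)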