lalamo 0.2.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. lalamo/__init__.py +1 -1
  2. lalamo/common.py +79 -29
  3. lalamo/language_model.py +106 -83
  4. lalamo/main.py +91 -18
  5. lalamo/message_processor.py +170 -0
  6. lalamo/model_import/common.py +159 -43
  7. lalamo/model_import/{configs → decoder_configs}/__init__.py +0 -1
  8. lalamo/model_import/{configs → decoder_configs}/common.py +11 -10
  9. lalamo/model_import/{configs → decoder_configs}/huggingface/common.py +9 -4
  10. lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py +2 -2
  11. lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py +2 -2
  12. lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py +1 -1
  13. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py +1 -1
  14. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py +1 -1
  15. lalamo/model_import/huggingface_generation_config.py +44 -0
  16. lalamo/model_import/huggingface_tokenizer_config.py +85 -0
  17. lalamo/model_import/loaders/common.py +2 -1
  18. lalamo/model_import/loaders/huggingface.py +12 -10
  19. lalamo/model_import/model_specs/__init__.py +3 -2
  20. lalamo/model_import/model_specs/common.py +31 -32
  21. lalamo/model_import/model_specs/deepseek.py +1 -10
  22. lalamo/model_import/model_specs/gemma.py +2 -25
  23. lalamo/model_import/model_specs/huggingface.py +2 -12
  24. lalamo/model_import/model_specs/llama.py +2 -58
  25. lalamo/model_import/model_specs/mistral.py +9 -19
  26. lalamo/model_import/model_specs/pleias.py +3 -13
  27. lalamo/model_import/model_specs/polaris.py +5 -7
  28. lalamo/model_import/model_specs/qwen.py +12 -111
  29. lalamo/model_import/model_specs/reka.py +4 -13
  30. lalamo/modules/__init__.py +2 -1
  31. lalamo/modules/attention.py +90 -10
  32. lalamo/modules/common.py +51 -4
  33. lalamo/modules/decoder.py +90 -8
  34. lalamo/modules/decoder_layer.py +85 -8
  35. lalamo/modules/embedding.py +95 -29
  36. lalamo/modules/kv_cache.py +3 -3
  37. lalamo/modules/linear.py +170 -130
  38. lalamo/modules/mlp.py +40 -7
  39. lalamo/modules/normalization.py +24 -6
  40. lalamo/modules/rope.py +24 -6
  41. lalamo/sampling.py +99 -0
  42. lalamo/utils.py +86 -1
  43. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/METADATA +6 -6
  44. lalamo-0.3.1.dist-info/RECORD +58 -0
  45. lalamo-0.2.7.dist-info/RECORD +0 -54
  46. /lalamo/model_import/{configs → decoder_configs}/executorch.py +0 -0
  47. /lalamo/model_import/{configs → decoder_configs}/huggingface/__init__.py +0 -0
  48. /lalamo/model_import/{configs → decoder_configs}/huggingface/gemma2.py +0 -0
  49. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/WHEEL +0 -0
  50. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/entry_points.txt +0 -0
  51. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/licenses/LICENSE +0 -0
  52. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/top_level.txt +0 -0
lalamo/modules/linear.py CHANGED
@@ -1,19 +1,25 @@
 import math
 from abc import abstractmethod
-from collections.abc import Sequence
-from dataclasses import dataclass
-from typing import NamedTuple
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass, replace
+from typing import NamedTuple, Self
 
 import equinox as eqx
 import jax
+import jax.numpy as jnp
 from einops import rearrange
-from jax import numpy as jnp
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterDict
+from lalamo.common import ParameterTree, dummy_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
 
-from .common import LalamoModule, WeightLayout, register_config_union
+from .common import (
+    LalamoModule,
+    WeightLayout,
+    from_layout,
+    into_layout,
+    register_config_union,
+)
 
 __all__ = [
     "FullPrecisionLinear",
@@ -48,30 +54,11 @@ class LinearBase[ConfigT: LinearConfigBase](LalamoModule[ConfigT]):
         inputs: Float[Array, " in_channels"],
     ) -> tuple[Float[Array, " out_channels"], ...]: ...
 
-    @classmethod
-    def _default_weight_layout(cls) -> WeightLayout:
-        return WeightLayout.INPUT_OUTPUT
-
-    @classmethod
-    def _into_layout(
-        cls,
-        weights: Float[Array, "in_channels out_channels"],
-        layout: WeightLayout,
-    ) -> Float[Array, "in_channels out_channels"] | Float[Array, "out_channels in_channels"]:
-        if layout == WeightLayout.AUTO:
-            layout = cls._default_weight_layout()
-        match layout:
-            case WeightLayout.OUTPUT_INPUT:
-                return weights
-            case WeightLayout.INPUT_OUTPUT:
-                return rearrange(
-                    weights,
-                    "total_out_channels in_channels -> in_channels total_out_channels",
-                )
-        raise ValueError(f"Unsupported weight layout: {layout}")
+    def __post_init__(self) -> None:
+        assert isinstance(self.output_dims, tuple)
 
-    @classmethod
-    def _get_split_points(cls, output_dims: Sequence[int]) -> tuple[int, ...]:
+    @staticmethod
+    def _get_split_points(output_dims: Sequence[int]) -> tuple[int, ...]:
         result = []
         last_split_point = 0
         for dim in output_dims[:-1]:
@@ -92,6 +79,14 @@ class LinearConfigBase:
         key: PRNGKeyArray,
     ) -> LinearBase: ...
 
+    @abstractmethod
+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase: ...
+
 
 @dataclass(frozen=True)
 class FullPrecisionLinearConfig(LinearConfigBase):
@@ -104,7 +99,7 @@ class FullPrecisionLinearConfig(LinearConfigBase):
         has_biases: bool,
         *,
         key: PRNGKeyArray,
-    ) -> LinearBase:
+    ) -> "FullPrecisionLinear":
        scale = 1 / math.sqrt(input_dim)
        weights = jax.random.uniform(
            key,
@@ -125,6 +120,28 @@ class FullPrecisionLinearConfig(LinearConfigBase):
             biases=biases,
         )
 
+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> "FullPrecisionLinear":
+        weights = dummy_array(
+            (sum(output_dims), input_dim),
+            dtype=self.precision,
+        )
+        if has_biases:
+            biases = dummy_array((sum(output_dims),), dtype=self.precision)
+        else:
+            biases = None
+
+        return FullPrecisionLinear(
+            config=self,
+            output_dims=output_dims,
+            weights=weights,
+            biases=biases,
+        )
+
 
 class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
     weights: Float[Array, "total_out_channels in_channels"]
@@ -148,7 +165,7 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified precision ({self.config.precision}).",
             )
-        w_output_dim, w_input_dim = self.weights.shape
+        w_output_dim, _ = self.weights.shape
         if w_output_dim != sum(self.output_dims):
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
@@ -167,18 +184,31 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
                 f"Bias dtype ({self.biases.dtype}) is not equal to specified precision ({self.config.precision}).",
             )
 
+    @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
         result = self.weights @ inputs
         if self.biases is not None:
             result = result + self.biases
         return tuple(jnp.split(result, self._get_split_points(self.output_dims)))
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
-        result = ParameterDict(weights=self._into_layout(self.weights, weight_layout))
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        result = dict(weights=into_layout(self.weights, weight_layout))
         if self.biases is not None:
             result["biases"] = self.biases
         return result
 
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        return replace(
+            self,
+            weights=from_layout(weights["weights"], weight_layout),
+            biases=weights["biases"] if self.has_biases else None,
+        )
+
 
 @dataclass(frozen=True)
 class GroupQuantizedLinearConfig(LinearConfigBase):
@@ -224,6 +254,34 @@ class GroupQuantizedLinearConfig(LinearConfigBase):
             biases=biases,
         )
 
+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        weights = dummy_array(
+            (sum(output_dims), input_dim),
+            dtype=self.activation_precision,
+        )
+        num_groups = input_dim // self.group_size
+        scales = dummy_array((sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        if has_biases:
+            biases = dummy_array((sum(output_dims),), dtype=self.activation_precision)
+        else:
+            biases = None
+        zero_points = dummy_array((sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        return GroupQuantizedLinear(
+            config=self,
+            output_dims=output_dims,
+            weights=weights,
+            scales=scales,
+            zero_points=zero_points,
+            biases=biases,
+        )
+
 
 class RequantizedWeights(NamedTuple):
     weights: Int[Array, "total_out_channels in_channels"]
@@ -271,7 +329,7 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
                 f" ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        w_output_dim, w_input_dim = self.weights.shape
+        w_output_dim, _ = self.weights.shape
         if w_output_dim != sum(self.output_dims):
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
@@ -352,100 +410,20 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
         inputs = dynamically_quantize_activations(inputs, self.config.activation_quantization_mode)
         return self._prepare_scaled_weights() @ inputs
 
+    @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
         result = self._apply_weights(inputs)
         if self.biases is not None:
             result = result + self.biases
         return tuple(jnp.split(result, self._get_split_points(self.output_dims)))
 
-    def requantize_weights(self, weights, zero_points, scales):
-        """
-        Requantize weights from [20, 6144] grouping to [2560, 48] grouping.
-
-        Args:
-            weights: uint4 array of shape [M, N]
-            zero_points: uint4 array of shape [M//group_size_0, N//group_size_1]
-            scales: float16 array of shape [M//group_size_0, N//group_size_1]
-
-        Returns:
-            new_weights: uint4 array of shape [M, N]
-            new_zero_points: uint4 array of shape [M, N//128]
-            new_scales: float16 array of shape [M, N//128]
-        """
-        # Get dimensions
-        M, N = weights.shape
-        old_groups_0, old_groups_1 = zero_points.shape
-
-        # Calculate old group sizes
-        old_group_size_0 = M // old_groups_0  # 2560 // 20 = 128
-        old_group_size_1 = N // old_groups_1  # 6144 // 6144 = 1
-
-        # New group sizes
-        new_group_size_0 = 1  # 2560 // 2560 = 1
-        new_group_size_1 = self.config.group_size  # 6144 // 48 = 128
-
-        # Step 1: Dequantize with original parameters
-        # Expand zero_points and scales to match weights shape
-        zp_expanded = jnp.repeat(jnp.repeat(zero_points, old_group_size_0, axis=0), old_group_size_1, axis=1)
-        scales_expanded = jnp.repeat(jnp.repeat(scales, old_group_size_0, axis=0), old_group_size_1, axis=1)
-
-        # Dequantize (convert to float for computation)
-        weights_float = weights.astype(jnp.float32)
-        zp_float = zp_expanded.astype(jnp.float32)
-        dequantized = (weights_float - zp_float) * scales_expanded.astype(jnp.float32)
-
-        # Step 2: Requantize with new group structure [2560, 48]
-        # Reshape for new groups
-        dequantized_reshaped = dequantized.reshape(
-            M // new_group_size_0,
-            new_group_size_0,
-            N // new_group_size_1,
-            new_group_size_1,
-        )
-
-        # Compute new scales and zero points per group
-        # Move group dimensions to the end for reduction
-        dequantized_groups = dequantized_reshaped.transpose(0, 2, 1, 3)  # [2560, 48, 1, 128]
-
-        # Find min and max per group
-        group_min = dequantized_groups.min(axis=(2, 3), keepdims=True)
-        group_max = dequantized_groups.max(axis=(2, 3), keepdims=True)
-
-        # Compute scales (with small epsilon to avoid division by zero)
-        eps = 1e-6
-        new_scales = ((group_max - group_min) / 15.0 + eps).astype(scales.dtype)
-        new_scales = new_scales.squeeze(axis=(2, 3))  # [2560, 48]
-
-        # Compute zero points (quantize to uint4 range 0-15)
-        new_zero_points = jnp.round(-group_min.squeeze(axis=(2, 3)) / new_scales).astype(jnp.uint4)
-        new_zero_points = jnp.clip(new_zero_points, 0, 15)
-
-        # Quantize with new parameters
-        scales_expanded_new = jnp.repeat(new_scales, new_group_size_1, axis=1).reshape(M, N)
-        zp_expanded_new = jnp.repeat(new_zero_points, new_group_size_1, axis=1).reshape(M, N)
-
-        new_weights = jnp.round(
-            dequantized / scales_expanded_new.astype(jnp.float32) + zp_expanded_new.astype(jnp.float32),
-        )
-        new_weights = jnp.clip(new_weights, 0, 15).astype(jnp.uint4)
-
-        return new_weights, new_zero_points, new_scales
-
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
-        exported_weights = self._into_layout(self.int_weights, weight_layout)
-
-        exported_zero_points = self._into_layout(self.int_zero_points, weight_layout)
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        expected_weight_layout = WeightLayout.OUTPUT_INPUT
+        exported_weights = into_layout(self.int_weights, expected_weight_layout)
+        exported_zero_points = into_layout(self.int_zero_points, expected_weight_layout)
+        exported_scales = into_layout(self.scales, expected_weight_layout)
 
-        exported_scales = self._into_layout(self.scales, weight_layout)
-
-        # CRIMINAL HACK!!!
-        exported_weights, exported_zero_points, exported_scales = self.requantize_weights(
-            exported_weights,
-            exported_zero_points,
-            exported_scales,
-        )
-
-        result = ParameterDict(
+        result = dict(
             weights=exported_weights,
             zero_points=exported_zero_points,
             scales=exported_scales,
@@ -454,6 +432,21 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
             result["biases"] = self.biases
         return result
 
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["weights"], Array)
+        return replace(
+            self,
+            weights=from_layout(weights["weights"].astype(self.weights.dtype), weight_layout),
+            scales=from_layout(weights["scales"], weight_layout),
+            zero_points=from_layout(weights["zero_points"], weight_layout).astype(self.zero_points.dtype),
+            biases=weights["biases"] if self.has_biases else None,
+        )
+
 
 class GroupQuantizedLinear(GroupQuantizedLinearBase[GroupQuantizedLinearConfig]):
     pass
@@ -512,6 +505,38 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
             lora_up_weights=lora_up_weights,
         )
 
+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        group_quantized_linear = super().empty(input_dim, output_dims, has_biases)
+        assert isinstance(group_quantized_linear, GroupQuantizedLinear)
+        hidden_lora_rank = len(output_dims) * self.lora_rank
+        lora_down_weights = dummy_array(
+            (hidden_lora_rank, input_dim),
+            dtype=self.activation_precision,
+        )
+        lora_up_weights = tuple(
+            dummy_array(
+                (output_dim, self.lora_rank),
+                dtype=self.activation_precision,
+            )
+            for output_dim in output_dims
+        )
+
+        return QLoRALinear(
+            config=self,
+            output_dims=output_dims,
+            weights=group_quantized_linear.weights,
+            scales=group_quantized_linear.scales,
+            biases=group_quantized_linear.biases,
+            zero_points=group_quantized_linear.zero_points,
+            lora_down_weights=lora_down_weights,
+            lora_up_weights=lora_up_weights,
+        )
+
 
 class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
     lora_down_weights: Float[Array, "total_lora_channels in_channels"]
@@ -564,6 +589,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
             f" equal to lora_rank ({self.config.lora_rank}).",
         )
 
+    @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
         joint_q_out = self._apply_weights(inputs)
         q_outs = jnp.split(joint_q_out, self._get_split_points(self.output_dims))
@@ -584,16 +610,30 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
 
         return tuple(results)
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
-        quantized_linear_weights = super().export_weights()
-        exported_lora_down_weights = self._into_layout(self.lora_down_weights, weight_layout)
-        exported_lora_up_weights = tuple(
-            self._into_layout(lora_up_weight, weight_layout) for lora_up_weight in self.lora_up_weights
-        )
-        return ParameterDict(
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        quantized_linear_weights: dict[str, ParameterTree] = super().export_weights()  # type: ignore
+        exported_lora_down_weights = into_layout(self.lora_down_weights, weight_layout)
+        exported_lora_up_weights = [
+            into_layout(lora_up_weight, weight_layout) for lora_up_weight in self.lora_up_weights
+        ]
+        return dict(
+            down_weights=into_layout(exported_lora_down_weights, weight_layout),
+            up_weights=[into_layout(w, weight_layout) for w in exported_lora_up_weights],
             **quantized_linear_weights,
-            down_weights=exported_lora_down_weights,
-            up_weights=exported_lora_up_weights,
+        )
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        base = super().import_weights(weights, weight_layout)
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["up_weights"], Sequence)
+        return replace(
+            base,
+            lora_down_weights=from_layout(weights["down_weights"], weight_layout),
+            lora_up_weights=tuple(from_layout(up_weights, weight_layout) for up_weights in weights["up_weights"]),
         )
 
 
lalamo/modules/mlp.py CHANGED
@@ -1,9 +1,12 @@
-from dataclasses import dataclass
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from typing import Self
 
+import equinox as eqx
 import jax
 from jaxtyping import Array, DTypeLike, Float, PRNGKeyArray
 
-from lalamo.common import ParameterDict
+from lalamo.common import ParameterTree
 
 from .activations import Activation
 from .common import LalamoModule, WeightLayout
@@ -35,8 +38,23 @@ class MLPConfig:
             ),
         )
 
+    def empty(self, model_dim: int, hidden_dim: int) -> "MLP":
+        return MLP(
+            self,
+            up_projection=self.linear_config.empty(
+                model_dim,
+                (hidden_dim, hidden_dim),
+                has_biases=False,
+            ),
+            down_projection=self.linear_config.empty(
+                hidden_dim,
+                (model_dim,),
+                has_biases=False,
+            ),
+        )
 
-class MLP(LalamoModule):
+
+class MLP(LalamoModule[MLPConfig]):
     up_projection: LinearBase
     down_projection: LinearBase
 
@@ -66,14 +84,29 @@ class MLP(LalamoModule):
                 f" the up projection output dimension {self.up_projection.input_dim}",
             )
 
+    @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
         up_proj, gate = self.up_projection(inputs)
         gate = self.config.activation(gate)
         (result,) = self.down_projection(up_proj * gate)
         return result
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
-        return ParameterDict(
-            up_projection=self.up_projection.export_weights(weight_layout),
-            down_projection=self.down_projection.export_weights(weight_layout),
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        return {
+            "up_projection": self.up_projection.export_weights(weight_layout),
+            "down_projection": self.down_projection.export_weights(weight_layout),
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["up_projection"], Mapping)
+        assert isinstance(weights["down_projection"], Mapping)
+        return replace(
+            self,
+            up_projection=self.up_projection.import_weights(weights["up_projection"], weight_layout),
+            down_projection=self.down_projection.import_weights(weights["down_projection"], weight_layout),
        )
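MLPConfig gains a matching empty(model_dim, hidden_dim) constructor, and MLP.import_weights() consumes the same nested mapping that export_weights() emits. A hedged sketch of that nested ParameterTree; the load_mlp helper is hypothetical and the shapes are only illustrative.

from jaxtyping import Array


# Hypothetical helper, not part of the package: it only illustrates the nested
# Mapping structure MLP.import_weights() expects, mirroring export_weights().
def load_mlp(mlp_config, model_dim: int, hidden_dim: int, up_w: Array, down_w: Array):
    params = {
        # Fused gate/up projection: empty() uses output_dims=(hidden_dim, hidden_dim).
        "up_projection": {"weights": up_w},      # e.g. (2 * hidden_dim, model_dim)
        "down_projection": {"weights": down_w},  # e.g. (model_dim, hidden_dim)
    }
    return mlp_config.empty(model_dim, hidden_dim).import_weights(params)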
lalamo/modules/normalization.py CHANGED
@@ -1,11 +1,14 @@
-from dataclasses import dataclass
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
 from enum import Enum
+from typing import Self
 
+import equinox as eqx
 import jax
 from jax import numpy as jnp
 from jaxtyping import Array, DTypeLike, Float
 
-from lalamo.common import ParameterDict
+from lalamo.common import ParameterTree, dummy_array
 
 from .common import LalamoModule, WeightLayout
 
@@ -29,10 +32,16 @@ class RMSNormConfig:
     scale_offset: float | None
     upcast_mode: UpcastMode
 
-    def init(self, channels: int) -> "RMSNorm":
-        scales = jnp.ones(channels, dtype=self.scale_precision)
+    def init(self, input_dim: int) -> "RMSNorm":
+        scales = jnp.ones(input_dim, dtype=self.scale_precision)
         return RMSNorm(self, scales=scales)
 
+    def empty(self, input_dim: int) -> "RMSNorm":
+        return RMSNorm(
+            config=self,
+            scales=dummy_array(input_dim, dtype=self.scale_precision),
+        )
+
 
 class RMSNorm(LalamoModule[RMSNormConfig]):
     scales: Float[Array, " channels"]
@@ -53,6 +62,7 @@ class RMSNorm(LalamoModule[RMSNormConfig]):
                 f" specified precision {self.config.scale_precision}",
             )
 
+    @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
         upcasted_inputs = inputs.astype(self.config.accumulation_precision)
 
@@ -73,5 +83,13 @@ class RMSNorm(LalamoModule[RMSNormConfig]):
         result = normalized_x * adjusted_scales
         return result.astype(inputs.dtype)
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
-        return ParameterDict(scales=self.scales)
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:  # noqa: ARG002
+        return {"scales": self.scales}
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        return replace(self, scales=weights["scales"])
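RMSNormConfig.init() now takes input_dim instead of channels, and the new empty()/import_weights() pair mirrors the other modules: only a {"scales": ...} mapping is needed and the layout argument is ignored. A small hedged sketch; load_rms_norm is hypothetical and norm_config is an RMSNormConfig assumed to be built elsewhere.

import jax


# Hypothetical helper: rebuilds an RMSNorm from a bare scales vector.
def load_rms_norm(norm_config, scales: jax.Array):
    # import_weights() ignores weight_layout here (noqa: ARG002 above),
    # so the "scales" entry is all that matters.
    return norm_config.empty(scales.shape[0]).import_weights({"scales": scales})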
lalamo/modules/rope.py CHANGED
@@ -16,13 +16,14 @@
 # limitations under the License.
 
 import math
-from dataclasses import dataclass
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
 
 import equinox as eqx
 from jax import numpy as jnp
 from jaxtyping import Array, DTypeLike, Float, Int
 
-from lalamo.common import ParameterDict
+from lalamo.common import ParameterTree
 
 from .common import LalamoModule, WeightLayout, register_config_union
 
@@ -53,8 +54,8 @@ class PositionalEmbeddings(eqx.Module):
     def apply(self, heads: Float[Array, "tokens head_channels"]) -> Float[Array, "tokens head_channels"]:
         return heads * self.cosines + self.rotate_half(heads) * self.sines
 
-    def export(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
-        return ParameterDict(
+    def export(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:  # noqa: ARG002
+        return dict(
             cosines=self.cosines,
             sines=self.sines,
         )
@@ -103,6 +104,11 @@ class RoPE(LalamoModule[RoPEConfigBase]):
         return self.config.precision
 
     def __post_init__(self) -> None:
+        num_tokens, _ = self.sines.shape
+        if num_tokens != self.config.max_sequence_length:
+            raise ValueError(
+                f"{num_tokens} does not match the specified max sequence length {self.config.max_sequence_length}",
+            )
         if self.cosines.dtype != self.config.precision:
             raise ValueError(
                 f"Cosines dtype {self.cosines.dtype} does not match the specified precision {self.config.precision}",
@@ -127,14 +133,26 @@
         result, _ = self.sines.shape
         return result
 
+    @eqx.filter_jit
     def __call__(self, timesteps: Int[Array, " tokens"]) -> PositionalEmbeddings:
         return PositionalEmbeddings(
             cosines=self.cosines[timesteps],
             sines=self.sines[timesteps],
         )
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
-        return ParameterDict(cosines=self.cosines, sines=self.sines)
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree[Array]:  # noqa: ARG002
+        return {
+            "cosines": self.cosines,
+            "sines": self.sines,
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
+    ) -> "RoPE":
+        assert isinstance(weights, Mapping)
+        return replace(self, cosines=weights["cosines"], sines=weights["sines"])
 
 
 class UnscaledRoPEConfig(RoPEConfigBase):
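RoPE's __post_init__ now checks the precomputed tables against config.max_sequence_length, and its import/export round trip uses plain dicts keyed "cosines" and "sines". A brief hedged sketch; roundtrip_rope is hypothetical and rope is an existing RoPE module.

# Hypothetical helper: exports and re-imports a RoPE table pair unchanged.
def roundtrip_rope(rope):
    tables = rope.export_weights()  # {"cosines": ..., "sines": ...}
    # Re-importing arrays with the same number of rows keeps the new
    # max_sequence_length check in __post_init__ satisfied.
    return rope.import_weights(tables)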