lalamo 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/model_import/__init__.py +8 -0
  3. lalamo/model_import/common.py +111 -0
  4. lalamo/model_import/configs/__init__.py +23 -0
  5. lalamo/model_import/configs/common.py +62 -0
  6. lalamo/model_import/configs/executorch.py +166 -0
  7. lalamo/model_import/configs/huggingface/__init__.py +18 -0
  8. lalamo/model_import/configs/huggingface/common.py +72 -0
  9. lalamo/model_import/configs/huggingface/gemma2.py +122 -0
  10. lalamo/model_import/configs/huggingface/gemma3.py +187 -0
  11. lalamo/model_import/configs/huggingface/llama.py +155 -0
  12. lalamo/model_import/configs/huggingface/mistral.py +132 -0
  13. lalamo/model_import/configs/huggingface/qwen2.py +144 -0
  14. lalamo/model_import/configs/huggingface/qwen3.py +142 -0
  15. lalamo/model_import/loaders/__init__.py +7 -0
  16. lalamo/model_import/loaders/common.py +45 -0
  17. lalamo/model_import/loaders/executorch.py +223 -0
  18. lalamo/model_import/loaders/huggingface.py +304 -0
  19. lalamo/model_import/model_specs/__init__.py +38 -0
  20. lalamo/model_import/model_specs/common.py +118 -0
  21. lalamo/model_import/model_specs/deepseek.py +28 -0
  22. lalamo/model_import/model_specs/gemma.py +76 -0
  23. lalamo/model_import/model_specs/huggingface.py +28 -0
  24. lalamo/model_import/model_specs/llama.py +101 -0
  25. lalamo/model_import/model_specs/mistral.py +59 -0
  26. lalamo/model_import/model_specs/pleias.py +28 -0
  27. lalamo/model_import/model_specs/polaris.py +22 -0
  28. lalamo/model_import/model_specs/qwen.py +336 -0
  29. lalamo/model_import/model_specs/reka.py +28 -0
  30. lalamo/modules/__init__.py +85 -0
  31. lalamo/modules/activations.py +30 -0
  32. lalamo/modules/attention.py +326 -0
  33. lalamo/modules/common.py +133 -0
  34. lalamo/modules/decoder.py +244 -0
  35. lalamo/modules/decoder_layer.py +240 -0
  36. lalamo/modules/embedding.py +299 -0
  37. lalamo/modules/kv_cache.py +196 -0
  38. lalamo/modules/linear.py +603 -0
  39. lalamo/modules/mlp.py +79 -0
  40. lalamo/modules/normalization.py +77 -0
  41. lalamo/modules/rope.py +255 -0
  42. lalamo/modules/utils.py +13 -0
  43. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/METADATA +1 -1
  44. lalamo-0.2.2.dist-info/RECORD +53 -0
  45. lalamo-0.2.1.dist-info/RECORD +0 -12
  46. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/WHEEL +0 -0
  47. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/top_level.txt +0 -0
lalamo/modules/linear.py ADDED
@@ -0,0 +1,603 @@
+ import math
+ from abc import abstractmethod
+ from collections.abc import Sequence
+ from dataclasses import dataclass
+ from typing import NamedTuple
+
+ import equinox as eqx
+ import jax
+ from einops import rearrange
+ from jax import numpy as jnp
+ from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+ from lalamo.common import ParameterDict
+ from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
+
+ from .common import LalamoModule, WeightLayout, register_config_union
+
+ __all__ = [
+     "FullPrecisionLinear",
+     "FullPrecisionLinearConfig",
+     "GroupQuantizedLinear",
+     "GroupQuantizedLinearConfig",
+     "LinearBase",
+     "LinearConfig",
+     "QLoRALinear",
+     "QLoRALinearConfig",
+ ]
+
+
+ class LinearBase[ConfigT: LinearConfigBase](LalamoModule[ConfigT]):
+     output_dims: tuple[int, ...] = eqx.field(static=True)
+
+     @property
+     @abstractmethod
+     def input_dim(self) -> int: ...
+
+     @property
+     def num_outputs(self) -> int:
+         return len(self.output_dims)
+
+     @property
+     @abstractmethod
+     def has_biases(self) -> bool: ...
+
+     @abstractmethod
+     def __call__(
+         self,
+         inputs: Float[Array, " in_channels"],
+     ) -> tuple[Float[Array, " out_channels"], ...]: ...
+
+     @classmethod
+     def _default_weight_layout(cls) -> WeightLayout:
+         return WeightLayout.INPUT_OUTPUT
+
+     @classmethod
+     def _into_layout(
+         cls,
+         weights: Float[Array, "in_channels out_channels"],
+         layout: WeightLayout,
+     ) -> Float[Array, "in_channels out_channels"] | Float[Array, "out_channels in_channels"]:
+         if layout == WeightLayout.AUTO:
+             layout = cls._default_weight_layout()
+         match layout:
+             case WeightLayout.OUTPUT_INPUT:
+                 return weights
+             case WeightLayout.INPUT_OUTPUT:
+                 return rearrange(
+                     weights,
+                     "total_out_channels in_channels -> in_channels total_out_channels",
+                 )
+         raise ValueError(f"Unsupported weight layout: {layout}")
+
+     @classmethod
+     def _get_split_points(cls, output_dims: Sequence[int]) -> tuple[int, ...]:
+         result = []
+         last_split_point = 0
+         for dim in output_dims[:-1]:
+             last_split_point += dim
+             result.append(last_split_point)
+         return tuple(result)
+
+
+ @dataclass(frozen=True)
+ class LinearConfigBase:
+     @abstractmethod
+     def random_init(
+         self,
+         input_dim: int,
+         output_dims: tuple[int, ...],
+         has_biases: bool,
+         *,
+         key: PRNGKeyArray,
+     ) -> LinearBase: ...
+
+
+ @dataclass(frozen=True)
+ class FullPrecisionLinearConfig(LinearConfigBase):
+     precision: DTypeLike
+
+     def random_init(
+         self,
+         input_dim: int,
+         output_dims: tuple[int, ...],
+         has_biases: bool,
+         *,
+         key: PRNGKeyArray,
+     ) -> LinearBase:
+         scale = 1 / math.sqrt(input_dim)
+         weights = jax.random.uniform(
+             key,
+             (sum(output_dims), input_dim),
+             minval=-scale,
+             maxval=scale,
+             dtype=self.precision,
+         )
+         if has_biases:
+             biases = jnp.zeros((sum(output_dims),), dtype=self.precision)
+         else:
+             biases = None
+
+         return FullPrecisionLinear(
+             config=self,
+             output_dims=output_dims,
+             weights=weights,
+             biases=biases,
+         )
+
+
+ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
+     weights: Float[Array, "total_out_channels in_channels"]
+     biases: Float[Array, " total_out_channels"] | None
+
+     @property
+     def activation_precision(self) -> DTypeLike:
+         return self.config.precision
+
+     @property
+     def input_dim(self) -> int:
+         _, input_dim = self.weights.shape
+         return input_dim
+
+     @property
+     def has_biases(self) -> bool:
+         return self.biases is not None
+
+     def __post_init__(self) -> None:
+         if self.weights.dtype != self.config.precision:
+             raise ValueError(
+                 f"Weight dtype ({self.weights.dtype}) is not equal to specified precision ({self.config.precision}).",
+             )
+         w_output_dim, w_input_dim = self.weights.shape
+         if w_output_dim != sum(self.output_dims):
+             raise ValueError(
+                 f"Number of output channels in weights ({w_output_dim}) is not"
+                 f" equal to sum of output dims ({sum(self.output_dims)}).",
+             )
+         if self.biases is None:
+             return
+         (b_output_dim,) = self.biases.shape
+         if w_output_dim != b_output_dim:
+             raise ValueError(
+                 f"Number of output channels in weights ({w_output_dim}) is not"
+                 f" equal to number of output channels in biases ({b_output_dim}).",
+             )
+         if self.biases.dtype != self.config.precision:
+             raise ValueError(
+                 f"Bias dtype ({self.biases.dtype}) is not equal to specified precision ({self.config.precision}).",
+             )
+
+     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
+         result = self.weights @ inputs
+         if self.biases is not None:
+             result = result + self.biases
+         return tuple(jnp.split(result, self._get_split_points(self.output_dims)))
+
+     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
+         result = ParameterDict(weights=self._into_layout(self.weights, weight_layout))
+         if self.biases is not None:
+             result["biases"] = self.biases
+         return result
+
+
+ @dataclass(frozen=True)
+ class GroupQuantizedLinearConfig(LinearConfigBase):
+     group_size: int
+     weight_quantization_mode: QuantizationMode
+     activation_quantization_mode: QuantizationMode | None
+     activation_precision: DTypeLike
+
+     def random_init(
+         self,
+         input_dim: int,
+         output_dims: tuple[int, ...],
+         has_biases: bool,
+         *,
+         key: PRNGKeyArray,
+     ) -> LinearBase:
+         min_val, max_val = self.weight_quantization_mode.range
+         weights = jax.random.uniform(
+             key,
+             (sum(output_dims), input_dim),
+             minval=min_val - 1,
+             maxval=max_val + 1,
+             dtype=self.activation_precision,
+         )
+         num_groups = input_dim // self.group_size
+         scale = 1 / ((max_val - min_val) / 2 * math.sqrt(input_dim))
+         scales = scale * jnp.ones((sum(output_dims), num_groups), dtype=self.activation_precision)
+
+         if has_biases:
+             biases = jnp.zeros((sum(output_dims),), dtype=self.activation_precision)
+         else:
+             biases = None
+
+         zero_point = min_val + 2 ** (self.weight_quantization_mode.bits - 1)
+         zero_points = zero_point * jnp.ones((sum(output_dims), num_groups), dtype=self.activation_precision)
+
+         return GroupQuantizedLinear(
+             config=self,
+             output_dims=output_dims,
+             weights=weights,
+             scales=scales,
+             zero_points=zero_points,
+             biases=biases,
+         )
+
+
+ class RequantizedWeights(NamedTuple):
+     weights: Int[Array, "total_out_channels in_channels"]
+     zero_points: Int[Array, "groups in_channels"]
+     scales: Float[Array, "groups in_channels"]
+
+
+ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[ConfigT]):
+     weights: Float[Array, "total_out_channels in_channels"]
+     scales: Float[Array, "total_out_channels groups"]
+     zero_points: Float[Array, "total_out_channels groups"]
+     biases: Float[Array, " total_out_channels"] | None
+
+     @property
+     def activation_precision(self) -> DTypeLike:
+         return self.config.activation_precision
+
+     @property
+     def input_dim(self) -> int:
+         _, input_dim = self.weights.shape
+         return input_dim
+
+     @property
+     def has_biases(self) -> bool:
+         return self.biases is not None
+
+     @property
+     def num_groups(self) -> int:
+         return self.input_dim // self.config.group_size
+
+     @property
+     def int_weights(self) -> Int[Array, "out_channels (groups in_channels)"]:
+         result = quantize_weights(self.weights, self.config.weight_quantization_mode)
+         return result.astype(self.config.weight_quantization_mode.dtype)
+
+     @property
+     def int_zero_points(self) -> Int[Array, "out_channels (groups in_channels)"]:
+         result = quantize_weights(self.zero_points, self.config.weight_quantization_mode)
+         return result.astype(self.config.weight_quantization_mode.dtype)
+
+     def __post_init__(self) -> None:
+         if self.weights.dtype != self.config.activation_precision:
+             raise ValueError(
+                 f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
+                 f" ({self.config.activation_precision}).",
+                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
+             )
+         w_output_dim, w_input_dim = self.weights.shape
+         if w_output_dim != sum(self.output_dims):
+             raise ValueError(
+                 f"Number of output channels in weights ({w_output_dim}) is not"
+                 f" equal to sum of output dims ({sum(self.output_dims)}).",
+             )
+
+         if self.scales.dtype != self.config.activation_precision:
+             raise ValueError(
+                 f"Scale dtype ({self.scales.dtype}) is not equal to specified activation precision"
+                 f" ({self.config.activation_precision}).",
+                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
+             )
+         s_output_dim, s_num_groups = self.scales.shape
+         if w_output_dim != s_output_dim:
+             raise ValueError(
+                 f"Number of output channels in weights ({w_output_dim}) is not"
+                 f" equal to number of output channels in scales ({s_output_dim}).",
+             )
+         if s_num_groups != self.num_groups:
+             raise ValueError(
+                 f"Number of groups in scales ({s_num_groups}) is incompatible with"
+                 f" the specified group size ({self.config.group_size}).",
+             )
+
+         if self.zero_points.dtype != self.config.activation_precision:
+             raise ValueError(
+                 f"Zero point dtype ({self.zero_points.dtype}) is not equal to specified activation precision"
+                 f" ({self.config.activation_precision}).",
+                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
+             )
+         (zp_output_dim, zp_num_groups) = self.zero_points.shape
+         if w_output_dim != zp_output_dim:
+             raise ValueError(
+                 f"Number of output channels in weights ({w_output_dim}) is not"
+                 f" equal to number of output channels in zero points ({zp_output_dim}).",
+             )
+         if self.num_groups != zp_num_groups:
+             raise ValueError(
+                 f"Number of groups in zero points ({zp_num_groups}) is incompatible with"
+                 f" the specified group size ({self.config.group_size}).",
+             )
+
+         if self.biases is not None:
+             if self.biases.dtype != self.config.activation_precision:
+                 raise ValueError(
+                     f"Bias dtype ({self.biases.dtype}) is not equal to specified activation precision"
+                     f" ({self.config.activation_precision}).",
+                     " Quantized layers require parameter dtypes to be equal to the activation precision.",
+                 )
+             (b_output_dim,) = self.biases.shape
+             if w_output_dim != b_output_dim:
+                 raise ValueError(
+                     f"Number of output channels in weights ({w_output_dim}) is not"
+                     f" equal to number of output channels in biases ({b_output_dim}).",
+                 )
+
+     def _prepare_scaled_weights(self) -> Float[Array, "total_out_channels in_channels"]:
+         quantized_weights = quantize_weights(self.weights, self.config.weight_quantization_mode)
+         grouped_weights = rearrange(
+             quantized_weights,
+             "total_out_channels (groups group_channels) -> total_out_channels groups group_channels",
+             groups=self.num_groups,
+         )
+
+         zero_points = rearrange(self.zero_points, "total_out_channels groups -> total_out_channels groups 1")
+         grouped_weights = grouped_weights - zero_points
+
+         scales = rearrange(self.scales, "total_out_channels groups -> total_out_channels groups 1")
+         scaled_grouped_weights = grouped_weights * scales
+         result = rearrange(
+             scaled_grouped_weights,
+             "total_out_channels groups group_channels -> total_out_channels (groups group_channels)",
+         )
+         return result
+
+     def _apply_weights(self, inputs: Float[Array, " in_channels"]) -> Float[Array, " total_out_channels"]:
+         if self.config.activation_quantization_mode is not None:
+             inputs = dynamically_quantize_activations(inputs, self.config.activation_quantization_mode)
+         return self._prepare_scaled_weights() @ inputs
+
+     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
+         result = self._apply_weights(inputs)
+         if self.biases is not None:
+             result = result + self.biases
+         return tuple(jnp.split(result, self._get_split_points(self.output_dims)))
+
+     def requantize_weights(self, weights, zero_points, scales):
+         """
+         Requantize the exported weights so that each quantization group covers one row
+         and ``self.config.group_size`` consecutive columns of the exported weight matrix.
+
+         Args:
+             weights: integer-valued array of shape [M, N]
+             zero_points: integer-valued array of shape [M // old_group_size_0, N // old_group_size_1]
+             scales: floating-point array with the same shape as zero_points
+
+         Returns:
+             new_weights: uint4 array of shape [M, N]
+             new_zero_points: uint4 array of shape [M, N // group_size]
+             new_scales: floating-point array of shape [M, N // group_size]
+         """
+         # Get dimensions
+         M, N = weights.shape
+         old_groups_0, old_groups_1 = zero_points.shape
+
+         # Old group sizes, inferred from the shapes of the quantization parameters
+         old_group_size_0 = M // old_groups_0
+         old_group_size_1 = N // old_groups_1
+
+         # New group sizes: one row per group, group_size columns per group
+         new_group_size_0 = 1
+         new_group_size_1 = self.config.group_size
+
+         # Step 1: Dequantize with the original parameters.
+         # Expand zero_points and scales to match the weights' shape.
+         zp_expanded = jnp.repeat(jnp.repeat(zero_points, old_group_size_0, axis=0), old_group_size_1, axis=1)
+         scales_expanded = jnp.repeat(jnp.repeat(scales, old_group_size_0, axis=0), old_group_size_1, axis=1)
+
+         # Dequantize (convert to float for computation)
+         weights_float = weights.astype(jnp.float32)
+         zp_float = zp_expanded.astype(jnp.float32)
+         dequantized = (weights_float - zp_float) * scales_expanded.astype(jnp.float32)
+
+         # Step 2: Requantize with the new group structure.
+         # Reshape for the new groups
+         dequantized_reshaped = dequantized.reshape(
+             M // new_group_size_0,
+             new_group_size_0,
+             N // new_group_size_1,
+             new_group_size_1,
+         )
+
+         # Compute new scales and zero points per group.
+         # Move the group dimensions to the end for reduction.
+         dequantized_groups = dequantized_reshaped.transpose(0, 2, 1, 3)
+
+         # Find min and max per group
+         group_min = dequantized_groups.min(axis=(2, 3), keepdims=True)
+         group_max = dequantized_groups.max(axis=(2, 3), keepdims=True)
+
+         # Compute scales (with a small epsilon to avoid division by zero)
+         eps = 1e-6
+         new_scales = ((group_max - group_min) / 15.0 + eps).astype(scales.dtype)
+         new_scales = new_scales.squeeze(axis=(2, 3))
+
+         # Compute zero points, clipping to the uint4 range [0, 15] before the integer
+         # cast so that out-of-range values are not wrapped by the conversion
+         new_zero_points = jnp.round(-group_min.squeeze(axis=(2, 3)) / new_scales)
+         new_zero_points = jnp.clip(new_zero_points, 0, 15).astype(jnp.uint4)
+
+         # Quantize with the new parameters
+         scales_expanded_new = jnp.repeat(new_scales, new_group_size_1, axis=1).reshape(M, N)
+         zp_expanded_new = jnp.repeat(new_zero_points, new_group_size_1, axis=1).reshape(M, N)
+
+         new_weights = jnp.round(
+             dequantized / scales_expanded_new.astype(jnp.float32) + zp_expanded_new.astype(jnp.float32),
+         )
+         new_weights = jnp.clip(new_weights, 0, 15).astype(jnp.uint4)
+
+         return new_weights, new_zero_points, new_scales
+
+     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
+         exported_weights = self._into_layout(self.int_weights, weight_layout)
+
+         exported_zero_points = self._into_layout(self.int_zero_points, weight_layout)
+
+         exported_scales = self._into_layout(self.scales, weight_layout)
+
+         # HACK: re-group the exported quantization parameters on the fly; see requantize_weights above.
+         exported_weights, exported_zero_points, exported_scales = self.requantize_weights(
+             exported_weights,
+             exported_zero_points,
+             exported_scales,
+         )
+
+         result = ParameterDict(
+             weights=exported_weights,
+             zero_points=exported_zero_points,
+             scales=exported_scales,
+         )
+         if self.biases is not None:
+             result["biases"] = self.biases
+         return result
+
+
+ class GroupQuantizedLinear(GroupQuantizedLinearBase[GroupQuantizedLinearConfig]):
+     pass
+
+
+ @dataclass(frozen=True)
+ class QLoRALinearConfig(GroupQuantizedLinearConfig):
+     lora_rank: int
+     lora_scale: float
+     activation_precision: DTypeLike
+
+     def random_init(
+         self,
+         input_dim: int,
+         output_dims: tuple[int, ...],
+         has_biases: bool,
+         *,
+         key: PRNGKeyArray,
+     ) -> LinearBase:
+         base_key, derived_key = jax.random.split(key)
+         group_quantized_linear = super().random_init(input_dim, output_dims, has_biases, key=base_key)
+         assert isinstance(group_quantized_linear, GroupQuantizedLinear)
+
+         down_key, up_key_root = jax.random.split(derived_key)
+         hidden_lora_rank = len(output_dims) * self.lora_rank
+         max_down_abs_value = 1 / math.sqrt(input_dim)
+         lora_down_weights = jax.random.uniform(
+             down_key,
+             (hidden_lora_rank, input_dim),
+             minval=-max_down_abs_value,
+             maxval=max_down_abs_value,
+             dtype=self.activation_precision,
+         )
+
+         up_keys = jax.random.split(up_key_root, len(output_dims))
+         max_up_abs_value = 1 / math.sqrt(hidden_lora_rank)
+         lora_up_weights = tuple(
+             jax.random.uniform(
+                 up_key,
+                 (output_dim, self.lora_rank),
+                 minval=-max_up_abs_value,
+                 maxval=max_up_abs_value,
+                 dtype=self.activation_precision,
+             )
+             for up_key, output_dim in zip(up_keys, output_dims, strict=True)
+         )
+
+         return QLoRALinear(
+             config=self,
+             output_dims=output_dims,
+             weights=group_quantized_linear.weights,
+             scales=group_quantized_linear.scales,
+             biases=group_quantized_linear.biases,
+             zero_points=group_quantized_linear.zero_points,
+             lora_down_weights=lora_down_weights,
+             lora_up_weights=lora_up_weights,
+         )
+
+
+ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
+     lora_down_weights: Float[Array, "total_lora_channels in_channels"]
+     lora_up_weights: tuple[Float[Array, "out_channels lora_channels"], ...]
+
+     def _split_biases(self) -> tuple[Float[Array, " out_channels"] | None, ...]:
+         if self.biases is not None:
+             return tuple(jnp.split(self.biases, self._get_split_points(self.output_dims)))
+         return (None,) * len(self.output_dims)
+
+     def __post_init__(self) -> None:
+         super().__post_init__()
+         if self.lora_down_weights.dtype != self.config.activation_precision:
+             raise ValueError(
+                 f"LORA down weight dtype ({self.lora_down_weights.dtype}) is not equal to the"
+                 f" specified activation precision ({self.config.activation_precision}).",
+                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
+             )
+         lora_down_output_dim, lora_down_input_dim = self.lora_down_weights.shape
+         if lora_down_output_dim != self.config.lora_rank * self.num_outputs:
+             raise ValueError(
+                 f"Number of output channels in LORA down weights ({lora_down_output_dim}) is not"
+                 f" equal to lora_rank * num_outputs ({self.config.lora_rank * self.num_outputs}).",
+             )
+         if lora_down_input_dim != self.input_dim:
+             raise ValueError(
+                 f"Number of input channels in LORA down weights ({lora_down_input_dim}) is not"
+                 f" equal to input_dim ({self.input_dim}).",
+             )
+         if len(self.lora_up_weights) != self.num_outputs:
+             raise ValueError(
+                 f"Expected {self.num_outputs} LORA up weights, got {len(self.lora_up_weights)}.",
+             )
+         for lora_up_weight, output_dim in zip(self.lora_up_weights, self.output_dims, strict=True):
+             if lora_up_weight.dtype != self.config.activation_precision:
+                 raise ValueError(
+                     f"LORA up weight dtype ({lora_up_weight.dtype}) is not equal to specified activation precision"
+                     f" ({self.config.activation_precision}).",
+                     " Quantized layers require parameter dtypes to be equal to the activation precision.",
+                 )
+             lora_up_output_dim, lora_up_input_dim = lora_up_weight.shape
+             if lora_up_output_dim != output_dim:
+                 raise ValueError(
+                     f"Number of output channels in LORA up weights ({lora_up_output_dim}) is not"
+                     f" equal to number of output dims ({self.output_dims}).",
+                 )
+             if lora_up_input_dim != self.config.lora_rank:
+                 raise ValueError(
+                     f"Number of input channels in LORA up weights ({lora_up_input_dim}) is not"
+                     f" equal to lora_rank ({self.config.lora_rank}).",
+                 )
+
+     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
+         joint_q_out = self._apply_weights(inputs)
+         q_outs = jnp.split(joint_q_out, self._get_split_points(self.output_dims))
+
+         joint_lora_hidden = self.lora_down_weights @ inputs
+         lora_hiddens = jnp.split(joint_lora_hidden, self._get_split_points([self.config.lora_rank] * self.num_outputs))
+         lora_outs = [
+             lora_up_weight @ lora_hidden
+             for lora_up_weight, lora_hidden in zip(self.lora_up_weights, lora_hiddens, strict=True)
+         ]
+
+         results = []
+         for q_out, lora_out, bias in zip(q_outs, lora_outs, self._split_biases(), strict=True):
+             result = q_out + self.config.lora_scale * lora_out
+             if bias is not None:
+                 result = result + bias
+             results.append(result)
+
+         return tuple(results)
+
+     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
+         quantized_linear_weights = super().export_weights()
+         exported_lora_down_weights = self._into_layout(self.lora_down_weights, weight_layout)
+         exported_lora_up_weights = tuple(
+             self._into_layout(lora_up_weight, weight_layout) for lora_up_weight in self.lora_up_weights
+         )
+         return ParameterDict(
+             **quantized_linear_weights,
+             down_weights=exported_lora_down_weights,
+             up_weights=exported_lora_up_weights,
+         )
+
+
+ LinearConfig = FullPrecisionLinearConfig | GroupQuantizedLinearConfig | QLoRALinearConfig
+
+
+ register_config_union(LinearConfig)
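
Note: the following is a minimal usage sketch, not code from the package. It relies only on the signatures visible in the diff above (FullPrecisionLinearConfig.random_init, the layer's __call__, and export_weights); the dimensions and dtype are illustrative assumptions.

import jax
from jax import numpy as jnp

from lalamo.modules.linear import FullPrecisionLinearConfig

# Build a fused projection: one 512-channel input fanned out into three outputs (512, 128, 128).
config = FullPrecisionLinearConfig(precision=jnp.float32)
layer = config.random_init(512, (512, 128, 128), has_biases=False, key=jax.random.PRNGKey(0))

# __call__ returns one array per entry of output_dims, split from the fused matmul.
q, k, v = layer(jnp.ones((512,), dtype=jnp.float32))

# export_weights returns a ParameterDict; the default AUTO layout resolves to INPUT_OUTPUT.
params = layer.export_weights()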
lalamo/modules/mlp.py ADDED
@@ -0,0 +1,79 @@
+ from dataclasses import dataclass
+
+ import jax
+ from jaxtyping import Array, DTypeLike, Float, PRNGKeyArray
+
+ from lalamo.common import ParameterDict
+
+ from .activations import Activation
+ from .common import LalamoModule, WeightLayout
+ from .linear import LinearBase, LinearConfig
+
+ __all__ = ["MLP", "MLPConfig"]
+
+
+ @dataclass(frozen=True)
+ class MLPConfig:
+     linear_config: LinearConfig
+     activation: Activation
+
+     def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> "MLP":
+         up_key, down_key = jax.random.split(key)
+         return MLP(
+             self,
+             up_projection=self.linear_config.random_init(
+                 model_dim,
+                 (hidden_dim, hidden_dim),
+                 has_biases=False,
+                 key=up_key,
+             ),
+             down_projection=self.linear_config.random_init(
+                 hidden_dim,
+                 (model_dim,),
+                 has_biases=False,
+                 key=down_key,
+             ),
+         )
+
+
+ class MLP(LalamoModule):
+     up_projection: LinearBase
+     down_projection: LinearBase
+
+     @property
+     def activation_precision(self) -> DTypeLike:
+         return self.up_projection.activation_precision
+
+     @property
+     def model_dim(self) -> int:
+         return self.up_projection.input_dim
+
+     @property
+     def hidden_dim(self) -> int:
+         return self.down_projection.input_dim
+
+     def __post_init__(self) -> None:
+         up_output_dim, gate_output_dim = self.up_projection.output_dims
+         if up_output_dim != gate_output_dim:
+             raise ValueError(
+                 f"Up projection output dimension {up_output_dim} does not match"
+                 f" the gate output dimension {gate_output_dim}",
+             )
+         (down_output_dim,) = self.down_projection.output_dims
+         if self.up_projection.input_dim != down_output_dim:
+             raise ValueError(
+                 f"Down projection output dimension {down_output_dim} does not match"
+                 f" the up projection input dimension {self.up_projection.input_dim}",
+             )
+
+     def __call__(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
+         up_proj, gate = self.up_projection(inputs)
+         gate = self.config.activation(gate)
+         (result,) = self.down_projection(up_proj * gate)
+         return result
+
+     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
+         return ParameterDict(
+             up_projection=self.up_projection.export_weights(weight_layout),
+             down_projection=self.down_projection.export_weights(weight_layout),
+         )
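
Note: a minimal sketch of wiring MLPConfig to a linear config, not code from the package. "Activation.SILU" is a hypothetical member name used only for illustration; the real options live in lalamo/modules/activations.py. Dimensions and dtype are likewise assumptions.

import jax
from jax import numpy as jnp

from lalamo.modules.activations import Activation
from lalamo.modules.linear import FullPrecisionLinearConfig
from lalamo.modules.mlp import MLPConfig

mlp_config = MLPConfig(
    linear_config=FullPrecisionLinearConfig(precision=jnp.float32),
    activation=Activation.SILU,  # hypothetical member name, see lalamo/modules/activations.py
)
mlp = mlp_config.random_init(model_dim=512, hidden_dim=2048, key=jax.random.PRNGKey(0))

# Gated MLP: the up and gate projections are fused into one linear, then down-projected back to model_dim.
outputs = mlp(jnp.ones((512,), dtype=jnp.float32))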