lalamo 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. lalamo/__init__.py +20 -5
  2. lalamo/data/__init__.py +8 -0
  3. lalamo/data/huggingface_message.py +38 -0
  4. lalamo/data/lalamo_completions.py +43 -0
  5. lalamo/data/utils.py +8 -0
  6. lalamo/language_model.py +152 -69
  7. lalamo/main.py +271 -43
  8. lalamo/message_processor.py +11 -1
  9. lalamo/model_import/common.py +17 -7
  10. lalamo/model_import/decoder_configs/__init__.py +3 -0
  11. lalamo/model_import/decoder_configs/executorch.py +12 -6
  12. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  13. lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
  14. lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
  15. lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
  16. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
  17. lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
  18. lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
  21. lalamo/model_import/huggingface_tokenizer_config.py +1 -4
  22. lalamo/model_import/loaders/executorch.py +10 -9
  23. lalamo/model_import/loaders/huggingface.py +104 -9
  24. lalamo/model_import/loaders/utils.py +92 -0
  25. lalamo/model_import/model_specs/__init__.py +4 -1
  26. lalamo/model_import/model_specs/common.py +15 -12
  27. lalamo/model_import/model_specs/gpt_oss.py +21 -0
  28. lalamo/modules/__init__.py +35 -7
  29. lalamo/modules/activations.py +24 -14
  30. lalamo/modules/attention.py +73 -20
  31. lalamo/modules/common.py +8 -57
  32. lalamo/modules/decoder.py +48 -34
  33. lalamo/modules/decoder_layer.py +57 -43
  34. lalamo/modules/embedding.py +13 -19
  35. lalamo/modules/kv_cache.py +53 -16
  36. lalamo/modules/linear.py +260 -79
  37. lalamo/modules/mlp.py +395 -23
  38. lalamo/modules/normalization.py +2 -3
  39. lalamo/modules/rope.py +32 -21
  40. lalamo/modules/utils.py +10 -0
  41. lalamo/speculator/__init__.py +11 -0
  42. lalamo/speculator/common.py +22 -0
  43. lalamo/speculator/inference.py +75 -0
  44. lalamo/speculator/ngram.py +154 -0
  45. lalamo/speculator/utils.py +52 -0
  46. lalamo/utils.py +27 -0
  47. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
  48. lalamo-0.4.0.dist-info/RECORD +71 -0
  49. lalamo-0.3.3.dist-info/RECORD +0 -59
  50. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
  51. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
  52. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
  53. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0
lalamo/modules/linear.py CHANGED
@@ -1,8 +1,8 @@
 import math
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass, replace
-from typing import NamedTuple, Self
+from typing import Self

 import equinox as eqx
 import jax
@@ -15,9 +15,6 @@ from lalamo.quantization import QuantizationMode, dynamically_quantize_activatio

 from .common import (
     LalamoModule,
-    WeightLayout,
-    from_layout,
-    into_layout,
     register_config_union,
 )

@@ -36,6 +33,10 @@ __all__ = [
 class LinearBase[ConfigT: LinearConfigBase](LalamoModule[ConfigT]):
     output_dims: tuple[int, ...] = eqx.field(static=True)

+    @property
+    @abstractmethod
+    def mixture_size(self) -> int | None: ...
+
     @property
     @abstractmethod
     def input_dim(self) -> int: ...
@@ -68,7 +69,7 @@ class LinearBase[ConfigT: LinearConfigBase](LalamoModule[ConfigT]):


 @dataclass(frozen=True)
-class LinearConfigBase:
+class LinearConfigBase(ABC):
     @abstractmethod
     def random_init(
         self,
@@ -79,6 +80,17 @@ class LinearConfigBase:
         key: PRNGKeyArray,
     ) -> LinearBase: ...

+    @abstractmethod
+    def random_init_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase: ...
+
     @abstractmethod
     def empty(
         self,
@@ -87,6 +99,15 @@ class LinearConfigBase:
         has_biases: bool,
     ) -> LinearBase: ...

+    @abstractmethod
+    def empty_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase: ...
+

 @dataclass(frozen=True)
 class FullPrecisionLinearConfig(LinearConfigBase):
@@ -120,18 +141,31 @@ class FullPrecisionLinearConfig(LinearConfigBase):
             biases=biases,
         )

-    def empty(
+    def random_init_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase:
+        subkeys = jax.random.split(key, mixture_size)
+        return eqx.filter_vmap(lambda key: self.random_init(input_dim, output_dims, has_biases, key=key))(subkeys)
+
+    def _empty_general(
         self,
+        leading_dims: tuple[int, ...],
         input_dim: int,
         output_dims: tuple[int, ...],
         has_biases: bool,
     ) -> "FullPrecisionLinear":
         weights = dummy_array(
-            (sum(output_dims), input_dim),
+            (*leading_dims, sum(output_dims), input_dim),
             dtype=self.precision,
         )
         if has_biases:
-            biases = dummy_array((sum(output_dims),), dtype=self.precision)
+            biases = dummy_array((*leading_dims, sum(output_dims)), dtype=self.precision)
         else:
             biases = None

@@ -142,10 +176,35 @@ class FullPrecisionLinearConfig(LinearConfigBase):
             biases=biases,
         )

+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> "FullPrecisionLinear":
+        return self._empty_general((), input_dim, output_dims, has_biases)
+
+    def empty_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> "FullPrecisionLinear":
+        return self._empty_general((mixture_size,), input_dim, output_dims, has_biases)
+

 class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
-    weights: Float[Array, "total_out_channels in_channels"]
-    biases: Float[Array, " total_out_channels"] | None
+    weights: Float[Array, "*components total_out_channels in_channels"]
+    biases: Float[Array, "*components total_out_channels"] | None
+
+    @property
+    def mixture_size(self) -> int | None:
+        match self.weights.shape:
+            case [num_components, _, _]:
+                return num_components
+            case _:
+                return None

     @property
     def activation_precision(self) -> DTypeLike:
@@ -153,7 +212,7 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):

     @property
     def input_dim(self) -> int:
-        _, input_dim = self.weights.shape
+        *_, _, input_dim = self.weights.shape
         return input_dim

     @property
@@ -165,7 +224,7 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified precision ({self.config.precision}).",
             )
-        w_output_dim, _ = self.weights.shape
+        *w_num_components, w_output_dim, _ = self.weights.shape
         if w_output_dim != sum(self.output_dims):
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
@@ -173,7 +232,7 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
             )
         if self.biases is None:
             return
-        (b_output_dim,) = self.biases.shape
+        *b_num_components, b_output_dim = self.biases.shape
         if w_output_dim != b_output_dim:
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
@@ -183,16 +242,26 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
             raise ValueError(
                 f"Bias dtype ({self.biases.dtype}) is not equal to specified precision ({self.config.precision}).",
             )
+        if b_num_components != w_num_components:
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in biases ({b_num_components}).",
+            )

     @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
+        if self.mixture_size is not None:
+            raise ValueError(
+                "Mixtures of linear layers cannot be called directly."
+                "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
+            )
         result = self.weights @ inputs
         if self.biases is not None:
             result = result + self.biases
         return tuple(jnp.split(result, self._get_split_points(self.output_dims)))

-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
-        result = dict(weights=into_layout(self.weights, weight_layout))
+    def export_weights(self) -> ParameterTree:
+        result = dict(weights=self.weights)
         if self.biases is not None:
             result["biases"] = self.biases
         return result
@@ -200,12 +269,11 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         assert isinstance(weights, Mapping)
         return replace(
             self,
-            weights=from_layout(weights["weights"], weight_layout),
+            weights=weights["weights"],
             biases=weights["biases"] if self.has_biases else None,
         )

@@ -254,24 +322,37 @@ class GroupQuantizedLinearConfig(LinearConfigBase):
             biases=biases,
         )

-    def empty(
+    def random_init_mixture(
         self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase:
+        subkeys = jax.random.split(key, mixture_size)
+        return eqx.filter_vmap(lambda key: self.random_init(input_dim, output_dims, has_biases, key=key))(subkeys)
+
+    def _empty_general(
+        self,
+        leading_dims: tuple[int, ...],
         input_dim: int,
         output_dims: tuple[int, ...],
         has_biases: bool,
     ) -> LinearBase:
         weights = dummy_array(
-            (sum(output_dims), input_dim),
+            (*leading_dims, sum(output_dims), input_dim),
             dtype=self.activation_precision,
         )
         num_groups = input_dim // self.group_size
-        scales = dummy_array((sum(output_dims), num_groups), dtype=self.activation_precision)
+        scales = dummy_array((*leading_dims, sum(output_dims), num_groups), dtype=self.activation_precision)

         if has_biases:
-            biases = dummy_array((sum(output_dims),), dtype=self.activation_precision)
+            biases = dummy_array((*leading_dims, sum(output_dims)), dtype=self.activation_precision)
         else:
             biases = None
-        zero_points = dummy_array((sum(output_dims), num_groups), dtype=self.activation_precision)
+        zero_points = dummy_array((*leading_dims, sum(output_dims), num_groups), dtype=self.activation_precision)

         return GroupQuantizedLinear(
             config=self,
@@ -282,18 +363,37 @@ class GroupQuantizedLinearConfig(LinearConfigBase):
             biases=biases,
         )

+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        return self._empty_general((), input_dim, output_dims, has_biases)

-class RequantizedWeights(NamedTuple):
-    weights: Int[Array, "total_out_channels in_channels"]
-    zero_points: Int[Array, "groups in_channels"]
-    scales: Float[Array, "groups in_channels"]
+    def empty_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        return self._empty_general((mixture_size,), input_dim, output_dims, has_biases)


 class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[ConfigT]):
-    weights: Float[Array, "total_out_channels in_channels"]
-    scales: Float[Array, "total_out_channels groups"]
-    zero_points: Float[Array, "total_out_channels groups"]
-    biases: Float[Array, " total_out_channels"] | None
+    weights: Float[Array, "*components total_out_channels in_channels"]
+    scales: Float[Array, "*components total_out_channels groups"]
+    zero_points: Float[Array, "*components total_out_channels groups"]
+    biases: Float[Array, "*components total_out_channels"] | None
+
+    @property
+    def mixture_size(self) -> int | None:
+        match self.weights.shape:
+            case [num_components, _, _]:
+                return num_components
+            case _:
+                return None

     @property
     def activation_precision(self) -> DTypeLike:
@@ -301,7 +401,7 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C

     @property
     def input_dim(self) -> int:
-        _, input_dim = self.weights.shape
+        *_, _, input_dim = self.weights.shape
         return input_dim

     @property
@@ -313,23 +413,23 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
         return self.input_dim // self.config.group_size

     @property
-    def int_weights(self) -> Int[Array, "out_channels (groups in_channels)"]:
+    def int_weights(self) -> Int[Array, "*components in_channels out_channels"]:
         result = quantize_weights(self.weights, self.config.weight_quantization_mode)
         return result.astype(self.config.weight_quantization_mode.dtype)

     @property
-    def int_zero_points(self) -> Int[Array, "out_channels (groups in_channels)"]:
+    def int_zero_points(self) -> Int[Array, "*components groups out_channels"]:
         result = quantize_weights(self.zero_points, self.config.weight_quantization_mode)
         return result.astype(self.config.weight_quantization_mode.dtype)

-    def __post_init__(self) -> None:
+    def __post_init__(self) -> None:  # noqa: PLR0912
         if self.weights.dtype != self.config.activation_precision:
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
                 f" ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        w_output_dim, _ = self.weights.shape
+        *w_num_components, w_output_dim, _ = self.weights.shape
         if w_output_dim != sum(self.output_dims):
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
@@ -342,12 +442,17 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
                 f" ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        s_output_dim, s_num_groups = self.scales.shape
+        *s_num_components, s_output_dim, s_num_groups = self.scales.shape
         if w_output_dim != s_output_dim:
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
                 f" equal to number of output channels in scales ({s_output_dim}).",
             )
+        if tuple(s_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in scales ({s_num_components}).",
+            )
         if s_num_groups != self.num_groups:
             raise ValueError(
                 f"Number of groups in scales ({s_num_groups}) is incompatible with"
@@ -360,12 +465,17 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
                 f" ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        (zp_output_dim, zp_num_groups) = self.zero_points.shape
+        *zp_num_components, zp_output_dim, zp_num_groups = self.zero_points.shape
         if w_output_dim != zp_output_dim:
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
                 f" equal to number of output channels in zero points ({zp_output_dim}).",
             )
+        if tuple(zp_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in zero points ({zp_num_components}).",
+            )
         if self.num_groups != zp_num_groups:
             raise ValueError(
                 f"Number of groups in zero points ({zp_num_groups}) is incompatible with"
@@ -379,29 +489,34 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
                 f" ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        (b_output_dim,) = self.biases.shape
+        *b_num_components, b_output_dim = self.biases.shape
         if w_output_dim != b_output_dim:
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
                 f" equal to number of output channels in biases ({b_output_dim}).",
             )
+        if tuple(b_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in biases ({b_num_components}).",
+            )

-    def _prepare_scaled_weights(self) -> Float[Array, "total_out_channels in_channels"]:
+    def _prepare_scaled_weights(self) -> Float[Array, "*components in_channels total_out_channels"]:
         quantized_weights = quantize_weights(self.weights, self.config.weight_quantization_mode)
         grouped_weights = rearrange(
             quantized_weights,
-            "total_out_channels (groups group_channels) -> total_out_channels groups group_channels",
+            "... total_out_channels (groups group_channels) -> ... total_out_channels groups group_channels",
             groups=self.num_groups,
         )

-        zero_points = rearrange(self.zero_points, "total_out_channels groups -> total_out_channels groups 1")
+        zero_points = rearrange(self.zero_points, "... total_out_channels groups -> ... total_out_channels groups 1")
         grouped_weights = grouped_weights - zero_points

-        scales = rearrange(self.scales, "total_out_channels groups -> total_out_channels groups 1")
+        scales = rearrange(self.scales, "... total_out_channels groups -> ... total_out_channels groups 1")
         scaled_grouped_weights = grouped_weights * scales
         result = rearrange(
             scaled_grouped_weights,
-            "total_out_channels groups group_channels -> total_out_channels (groups group_channels)",
+            "... total_out_channels groups group_channels -> ... total_out_channels (groups group_channels)",
         )
         return result

@@ -412,21 +527,21 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C

     @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
+        if self.mixture_size is not None:
+            raise ValueError(
+                "Mixtures of linear layers cannot be called directly."
+                "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
+            )
         result = self._apply_weights(inputs)
         if self.biases is not None:
             result = result + self.biases
         return tuple(jnp.split(result, self._get_split_points(self.output_dims)))

-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
-        expected_weight_layout = WeightLayout.OUTPUT_INPUT
-        exported_weights = into_layout(self.int_weights, expected_weight_layout)
-        exported_zero_points = into_layout(self.int_zero_points, expected_weight_layout)
-        exported_scales = into_layout(self.scales, expected_weight_layout)
-
+    def export_weights(self) -> ParameterTree:
         result = dict(
-            weights=exported_weights,
-            zero_points=exported_zero_points,
-            scales=exported_scales,
+            weights=self.int_weights,
+            zero_points=self.int_zero_points,
+            scales=self.scales,
         )
         if self.biases is not None:
             result["biases"] = self.biases
@@ -435,15 +550,15 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["weights"], Array)
+        assert isinstance(weights["zero_points"], Array)
         return replace(
             self,
-            weights=from_layout(weights["weights"].astype(self.weights.dtype), weight_layout),
-            scales=from_layout(weights["scales"], weight_layout),
-            zero_points=from_layout(weights["zero_points"], weight_layout).astype(self.zero_points.dtype),
+            weights=weights["weights"].astype(self.weights.dtype),
+            scales=weights["scales"],
+            zero_points=weights["zero_points"].astype(self.zero_points.dtype),
             biases=weights["biases"] if self.has_biases else None,
         )

@@ -475,7 +590,7 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
         max_down_abs_value = 1 / math.sqrt(input_dim)
         lora_down_weights = jax.random.uniform(
             down_key,
-            (hidden_lora_rank, input_dim),
+            (input_dim, hidden_lora_rank),
             minval=-max_down_abs_value,
             maxval=max_down_abs_value,
             dtype=self.activation_precision,
@@ -486,7 +601,7 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
         lora_up_weights = tuple(
             jax.random.uniform(
                 up_key,
-                (output_dim, self.lora_rank),
+                (self.lora_rank, output_dim),
                 minval=-max_up_abs_value,
                 maxval=max_up_abs_value,
                 dtype=self.activation_precision,
@@ -505,6 +620,18 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
             lora_up_weights=lora_up_weights,
         )

+    def random_init_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase:
+        subkeys = jax.random.split(key, mixture_size)
+        return eqx.filter_vmap(lambda k: self.random_init(input_dim, output_dims, has_biases, key=k))(subkeys)
+
     def empty(
         self,
         input_dim: int,
@@ -515,12 +642,46 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
         assert isinstance(group_quantized_linear, GroupQuantizedLinear)
         hidden_lora_rank = len(output_dims) * self.lora_rank
         lora_down_weights = dummy_array(
-            (hidden_lora_rank, input_dim),
+            (input_dim, hidden_lora_rank),
+            dtype=self.activation_precision,
+        )
+        lora_up_weights = tuple(
+            dummy_array(
+                (self.lora_rank, output_dim),
+                dtype=self.activation_precision,
+            )
+            for output_dim in output_dims
+        )
+
+        return QLoRALinear(
+            config=self,
+            output_dims=output_dims,
+            weights=group_quantized_linear.weights,
+            scales=group_quantized_linear.scales,
+            biases=group_quantized_linear.biases,
+            zero_points=group_quantized_linear.zero_points,
+            lora_down_weights=lora_down_weights,
+            lora_up_weights=lora_up_weights,
+        )
+
+    def _empty_general(
+        self,
+        leading_dims: tuple[int, ...],
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        group_quantized_linear = super().empty(input_dim, output_dims, has_biases)
+        assert isinstance(group_quantized_linear, GroupQuantizedLinear)
+
+        hidden_lora_rank = len(output_dims) * self.lora_rank
+        lora_down_weights = dummy_array(
+            (*leading_dims, input_dim, hidden_lora_rank),
             dtype=self.activation_precision,
         )
         lora_up_weights = tuple(
             dummy_array(
-                (output_dim, self.lora_rank),
+                (*leading_dims, self.lora_rank, output_dim),
                 dtype=self.activation_precision,
             )
             for output_dim in output_dims
@@ -537,12 +698,21 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
             lora_up_weights=lora_up_weights,
         )

+    def empty_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        return self._empty_general((mixture_size,), input_dim, output_dims, has_biases)
+

 class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
-    lora_down_weights: Float[Array, "total_lora_channels in_channels"]
-    lora_up_weights: tuple[Float[Array, "out_channels lora_channels"], ...]
+    lora_down_weights: Float[Array, "*components in_channels total_lora_channels"]
+    lora_up_weights: tuple[Float[Array, "*components lora_channels out_channels"], ...]

-    def _split_biases(self) -> tuple[Float[Array, " out_channels"] | None, ...]:
+    def _split_biases(self) -> tuple[Float[Array, "*components out_channels"] | None, ...]:
         if self.biases is not None:
             return tuple(jnp.split(self.biases, self._get_split_points(self.output_dims)))
         return (None,) * len(self.output_dims)
@@ -555,7 +725,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
                 f" specified activation precision ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        lora_down_output_dim, lora_down_input_dim = self.lora_down_weights.shape
+        *ld_num_components, lora_down_input_dim, lora_down_output_dim = self.lora_down_weights.shape
         if lora_down_output_dim != self.config.lora_rank * self.num_outputs:
             raise ValueError(
                 f"Number of output channels in LORA down weights ({lora_down_output_dim}) is not"
@@ -566,6 +736,12 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
                 f"Number of input channels in LORA down weights ({lora_down_input_dim}) is not"
                 f" equal to input_dim ({self.input_dim}).",
             )
+        *w_num_components, _, _ = self.weights.shape
+        if tuple(ld_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in LORA down weights ({ld_num_components}) is not"
+                f" equal to number of mixture components in base weights ({w_num_components}).",
+            )
         if len(self.lora_up_weights) != self.num_outputs:
             raise ValueError(
                 f"Expected {self.num_outputs} LORA up weights, got {len(self.lora_up_weights)}.",
@@ -577,7 +753,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
                     f" ({self.config.activation_precision}).",
                     " Quantized layers require parameter dtypes to be equal to the activation precision.",
                 )
-            lora_up_output_dim, lora_up_input_dim = lora_up_weight.shape
+            *lu_num_components, lora_up_input_dim, lora_up_output_dim = lora_up_weight.shape
             if lora_up_output_dim != output_dim:
                 raise ValueError(
                     f"Number of output channels in LORA up weights ({lora_up_output_dim}) is not"
@@ -588,16 +764,26 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
                     f"Number of input channels in LORA up weights ({lora_up_input_dim}) is not"
                     f" equal to lora_rank ({self.config.lora_rank}).",
                 )
+            if tuple(lu_num_components) != tuple(w_num_components):
+                raise ValueError(
+                    f"Number of mixture components in LORA up weights ({lu_num_components}) is not"
+                    f" equal to number of mixture components in base weights ({w_num_components}).",
+                )

     @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
+        if self.mixture_size is not None:
+            raise ValueError(
+                "Mixtures of linear layers cannot be called directly."
+                "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
+            )
         joint_q_out = self._apply_weights(inputs)
         q_outs = jnp.split(joint_q_out, self._get_split_points(self.output_dims))

-        joint_lora_hidden = self.lora_down_weights @ inputs
+        joint_lora_hidden = inputs @ self.lora_down_weights
         lora_hiddens = jnp.split(joint_lora_hidden, self._get_split_points([self.config.lora_rank] * self.num_outputs))
         lora_outs = [
-            lora_up_weight @ lora_hidden
+            lora_hidden @ lora_up_weight
             for lora_up_weight, lora_hidden in zip(self.lora_up_weights, lora_hiddens, strict=True)
         ]

@@ -610,30 +796,25 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):

         return tuple(results)

-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+    def export_weights(self) -> ParameterTree:
         quantized_linear_weights: dict[str, ParameterTree] = super().export_weights()  # type: ignore
-        exported_lora_down_weights = into_layout(self.lora_down_weights, weight_layout)
-        exported_lora_up_weights = [
-            into_layout(lora_up_weight, weight_layout) for lora_up_weight in self.lora_up_weights
-        ]
         return dict(
-            down_weights=into_layout(exported_lora_down_weights, weight_layout),
-            up_weights=[into_layout(w, weight_layout) for w in exported_lora_up_weights],
+            down_weights=self.lora_down_weights,
+            up_weights=self.lora_up_weights,
             **quantized_linear_weights,
         )

     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
-        base = super().import_weights(weights, weight_layout)
+        base = super().import_weights(weights)
         assert isinstance(weights, Mapping)
         assert isinstance(weights["up_weights"], Sequence)
         return replace(
             base,
-            lora_down_weights=from_layout(weights["down_weights"], weight_layout),
-            lora_up_weights=tuple(from_layout(up_weights, weight_layout) for up_weights in weights["up_weights"]),
+            lora_down_weights=weights["down_weights"],
+            lora_up_weights=tuple(up_weights for up_weights in weights["up_weights"]),
         )
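
Note on the new mixture API: the __call__ guards added above raise when mixture_size is not None and point callers at eqx.filter_vmap or lax.scan. The sketch below is a minimal, hypothetical illustration of that pattern using a stand-in TinyLinear module rather than lalamo's actual classes; random_init_mixture above builds its stacked parameters the same way, by vmapping a constructor over split PRNG keys.

import equinox as eqx
import jax
import jax.numpy as jnp


class TinyLinear(eqx.Module):
    weights: jax.Array  # shape (out_channels, in_channels)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.weights @ x


def make_component(key: jax.Array) -> TinyLinear:
    # One mixture component; hypothetical stand-in for config.random_init(...).
    return TinyLinear(jax.random.normal(key, (16, 8)))


# Vmapping the constructor stacks parameters along a leading component axis,
# giving weights of shape (4, 16, 8) -- the "*components" axis in the diff above.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
mixture = eqx.filter_vmap(make_component)(keys)

# A stacked mixture is not called directly; it is applied per component,
# here routing one input row to each component via another filter_vmap.
inputs = jnp.ones((4, 8))
outputs = eqx.filter_vmap(lambda component, x: component(x))(mixture, inputs)
print(outputs.shape)  # (4, 16)

Stacking components along a leading axis like this keeps the per-component matmuls batched instead of materializing a separate module per expert.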