lalamo 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
- lalamo/__init__.py +20 -5
- lalamo/data/__init__.py +8 -0
- lalamo/data/huggingface_message.py +38 -0
- lalamo/data/lalamo_completions.py +43 -0
- lalamo/data/utils.py +8 -0
- lalamo/language_model.py +152 -69
- lalamo/main.py +271 -43
- lalamo/message_processor.py +11 -1
- lalamo/model_import/common.py +10 -6
- lalamo/model_import/decoder_configs/__init__.py +3 -0
- lalamo/model_import/decoder_configs/executorch.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
- lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
- lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
- lalamo/model_import/huggingface_tokenizer_config.py +1 -3
- lalamo/model_import/loaders/executorch.py +10 -9
- lalamo/model_import/loaders/huggingface.py +104 -9
- lalamo/model_import/loaders/utils.py +92 -0
- lalamo/model_import/model_specs/__init__.py +4 -1
- lalamo/model_import/model_specs/common.py +15 -12
- lalamo/model_import/model_specs/gpt_oss.py +21 -0
- lalamo/modules/__init__.py +35 -7
- lalamo/modules/activations.py +24 -14
- lalamo/modules/attention.py +73 -20
- lalamo/modules/common.py +8 -57
- lalamo/modules/decoder.py +48 -34
- lalamo/modules/decoder_layer.py +57 -43
- lalamo/modules/embedding.py +13 -19
- lalamo/modules/kv_cache.py +53 -16
- lalamo/modules/linear.py +260 -79
- lalamo/modules/mlp.py +395 -23
- lalamo/modules/normalization.py +2 -3
- lalamo/modules/rope.py +32 -21
- lalamo/modules/utils.py +10 -0
- lalamo/speculator/__init__.py +11 -0
- lalamo/speculator/common.py +22 -0
- lalamo/speculator/inference.py +75 -0
- lalamo/speculator/ngram.py +154 -0
- lalamo/speculator/utils.py +52 -0
- lalamo/utils.py +27 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
- lalamo-0.4.0.dist-info/RECORD +71 -0
- lalamo-0.3.4.dist-info/RECORD +0 -59
- {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0
lalamo/modules/linear.py
CHANGED
@@ -1,8 +1,8 @@
 import math
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass, replace
-from typing import
+from typing import Self

 import equinox as eqx
 import jax
@@ -15,9 +15,6 @@ from lalamo.quantization import QuantizationMode, dynamically_quantize_activatio

 from .common import (
     LalamoModule,
-    WeightLayout,
-    from_layout,
-    into_layout,
     register_config_union,
 )

@@ -36,6 +33,10 @@ __all__ = [
 class LinearBase[ConfigT: LinearConfigBase](LalamoModule[ConfigT]):
     output_dims: tuple[int, ...] = eqx.field(static=True)

+    @property
+    @abstractmethod
+    def mixture_size(self) -> int | None: ...
+
     @property
     @abstractmethod
     def input_dim(self) -> int: ...
@@ -68,7 +69,7 @@ class LinearBase[ConfigT: LinearConfigBase](LalamoModule[ConfigT]):


 @dataclass(frozen=True)
-class LinearConfigBase:
+class LinearConfigBase(ABC):
     @abstractmethod
     def random_init(
         self,
@@ -79,6 +80,17 @@ class LinearConfigBase:
         key: PRNGKeyArray,
     ) -> LinearBase: ...

+    @abstractmethod
+    def random_init_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase: ...
+
     @abstractmethod
     def empty(
         self,
@@ -87,6 +99,15 @@ class LinearConfigBase:
         has_biases: bool,
     ) -> LinearBase: ...

+    @abstractmethod
+    def empty_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase: ...
+

 @dataclass(frozen=True)
 class FullPrecisionLinearConfig(LinearConfigBase):
@@ -120,18 +141,31 @@ class FullPrecisionLinearConfig(LinearConfigBase):
             biases=biases,
         )

-    def
+    def random_init_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase:
+        subkeys = jax.random.split(key, mixture_size)
+        return eqx.filter_vmap(lambda key: self.random_init(input_dim, output_dims, has_biases, key=key))(subkeys)
+
+    def _empty_general(
         self,
+        leading_dims: tuple[int, ...],
         input_dim: int,
         output_dims: tuple[int, ...],
         has_biases: bool,
     ) -> "FullPrecisionLinear":
         weights = dummy_array(
-            (sum(output_dims), input_dim),
+            (*leading_dims, sum(output_dims), input_dim),
             dtype=self.precision,
         )
         if has_biases:
-            biases = dummy_array((sum(output_dims)
+            biases = dummy_array((*leading_dims, sum(output_dims)), dtype=self.precision)
         else:
             biases = None

@@ -142,10 +176,35 @@ class FullPrecisionLinearConfig(LinearConfigBase):
             biases=biases,
         )

+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> "FullPrecisionLinear":
+        return self._empty_general((), input_dim, output_dims, has_biases)
+
+    def empty_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> "FullPrecisionLinear":
+        return self._empty_general((mixture_size,), input_dim, output_dims, has_biases)
+

 class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
-    weights: Float[Array, "total_out_channels in_channels"]
-    biases: Float[Array, " total_out_channels"] | None
+    weights: Float[Array, "*components total_out_channels in_channels"]
+    biases: Float[Array, "*components total_out_channels"] | None
+
+    @property
+    def mixture_size(self) -> int | None:
+        match self.weights.shape:
+            case [num_components, _, _]:
+                return num_components
+            case _:
+                return None

     @property
     def activation_precision(self) -> DTypeLike:
@@ -153,7 +212,7 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):

     @property
     def input_dim(self) -> int:
-        _, input_dim = self.weights.shape
+        *_, _, input_dim = self.weights.shape
         return input_dim

     @property
@@ -165,7 +224,7 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified precision ({self.config.precision}).",
             )
-        w_output_dim, _ = self.weights.shape
+        *w_num_components, w_output_dim, _ = self.weights.shape
         if w_output_dim != sum(self.output_dims):
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
@@ -173,7 +232,7 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
             )
         if self.biases is None:
             return
-
+        *b_num_components, b_output_dim = self.biases.shape
         if w_output_dim != b_output_dim:
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
@@ -183,16 +242,26 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
             raise ValueError(
                 f"Bias dtype ({self.biases.dtype}) is not equal to specified precision ({self.config.precision}).",
             )
+        if b_num_components != w_num_components:
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in biases ({b_num_components}).",
+            )

     @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
+        if self.mixture_size is not None:
+            raise ValueError(
+                "Mixtures of linear layers cannot be called directly."
+                "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
+            )
         result = self.weights @ inputs
         if self.biases is not None:
             result = result + self.biases
         return tuple(jnp.split(result, self._get_split_points(self.output_dims)))

-    def export_weights(self
-        result = dict(weights=
+    def export_weights(self) -> ParameterTree:
+        result = dict(weights=self.weights)
         if self.biases is not None:
             result["biases"] = self.biases
         return result
@@ -200,12 +269,11 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         assert isinstance(weights, Mapping)
         return replace(
             self,
-            weights=
+            weights=weights["weights"],
             biases=weights["biases"] if self.has_biases else None,
         )

@@ -254,24 +322,37 @@ class GroupQuantizedLinearConfig(LinearConfigBase):
             biases=biases,
         )

-    def
+    def random_init_mixture(
         self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase:
+        subkeys = jax.random.split(key, mixture_size)
+        return eqx.filter_vmap(lambda key: self.random_init(input_dim, output_dims, has_biases, key=key))(subkeys)
+
+    def _empty_general(
+        self,
+        leading_dims: tuple[int, ...],
         input_dim: int,
         output_dims: tuple[int, ...],
         has_biases: bool,
     ) -> LinearBase:
         weights = dummy_array(
-            (sum(output_dims), input_dim),
+            (*leading_dims, sum(output_dims), input_dim),
             dtype=self.activation_precision,
         )
         num_groups = input_dim // self.group_size
-        scales = dummy_array((sum(output_dims), num_groups), dtype=self.activation_precision)
+        scales = dummy_array((*leading_dims, sum(output_dims), num_groups), dtype=self.activation_precision)

         if has_biases:
-            biases = dummy_array((sum(output_dims)
+            biases = dummy_array((*leading_dims, sum(output_dims)), dtype=self.activation_precision)
         else:
             biases = None
-        zero_points = dummy_array((sum(output_dims), num_groups), dtype=self.activation_precision)
+        zero_points = dummy_array((*leading_dims, sum(output_dims), num_groups), dtype=self.activation_precision)

         return GroupQuantizedLinear(
             config=self,
@@ -282,18 +363,37 @@ class GroupQuantizedLinearConfig(LinearConfigBase):
             biases=biases,
         )

+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        return self._empty_general((), input_dim, output_dims, has_biases)

-
-
-
-
+    def empty_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        return self._empty_general((mixture_size,), input_dim, output_dims, has_biases)


 class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[ConfigT]):
-    weights: Float[Array, "total_out_channels in_channels"]
-    scales: Float[Array, "total_out_channels groups"]
-    zero_points: Float[Array, "total_out_channels groups"]
-    biases: Float[Array, " total_out_channels"] | None
+    weights: Float[Array, "*components total_out_channels in_channels"]
+    scales: Float[Array, "*components total_out_channels groups"]
+    zero_points: Float[Array, "*components total_out_channels groups"]
+    biases: Float[Array, "*components total_out_channels"] | None
+
+    @property
+    def mixture_size(self) -> int | None:
+        match self.weights.shape:
+            case [num_components, _, _]:
+                return num_components
+            case _:
+                return None

     @property
     def activation_precision(self) -> DTypeLike:
@@ -301,7 +401,7 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C

     @property
     def input_dim(self) -> int:
-        _, input_dim = self.weights.shape
+        *_, _, input_dim = self.weights.shape
         return input_dim

     @property
@@ -313,23 +413,23 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
         return self.input_dim // self.config.group_size

     @property
-    def int_weights(self) -> Int[Array, "
+    def int_weights(self) -> Int[Array, "*components in_channels out_channels"]:
         result = quantize_weights(self.weights, self.config.weight_quantization_mode)
         return result.astype(self.config.weight_quantization_mode.dtype)

     @property
-    def int_zero_points(self) -> Int[Array, "
+    def int_zero_points(self) -> Int[Array, "*components groups out_channels"]:
         result = quantize_weights(self.zero_points, self.config.weight_quantization_mode)
         return result.astype(self.config.weight_quantization_mode.dtype)

-    def __post_init__(self) -> None:
+    def __post_init__(self) -> None:  # noqa: PLR0912
         if self.weights.dtype != self.config.activation_precision:
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
                 f" ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        w_output_dim, _ = self.weights.shape
+        *w_num_components, w_output_dim, _ = self.weights.shape
         if w_output_dim != sum(self.output_dims):
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
@@ -342,12 +442,17 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
                 f" ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        s_output_dim, s_num_groups = self.scales.shape
+        *s_num_components, s_output_dim, s_num_groups = self.scales.shape
         if w_output_dim != s_output_dim:
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
                 f" equal to number of output channels in scales ({s_output_dim}).",
             )
+        if tuple(s_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in scales ({s_num_components}).",
+            )
         if s_num_groups != self.num_groups:
             raise ValueError(
                 f"Number of groups in scales ({s_num_groups}) is incompatible with"
@@ -360,12 +465,17 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
                 f" ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-
+        *zp_num_components, zp_output_dim, zp_num_groups = self.zero_points.shape
         if w_output_dim != zp_output_dim:
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
                 f" equal to number of output channels in zero points ({zp_output_dim}).",
             )
+        if tuple(zp_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in zero points ({zp_num_components}).",
+            )
         if self.num_groups != zp_num_groups:
             raise ValueError(
                 f"Number of groups in zero points ({zp_num_groups}) is incompatible with"
@@ -379,29 +489,34 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
                 f" ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-
+        *b_num_components, b_output_dim = self.biases.shape
         if w_output_dim != b_output_dim:
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
                 f" equal to number of output channels in biases ({b_output_dim}).",
             )
+        if tuple(b_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in biases ({b_num_components}).",
+            )

-    def _prepare_scaled_weights(self) -> Float[Array, "
+    def _prepare_scaled_weights(self) -> Float[Array, "*components in_channels total_out_channels"]:
         quantized_weights = quantize_weights(self.weights, self.config.weight_quantization_mode)
         grouped_weights = rearrange(
             quantized_weights,
-            "total_out_channels (groups group_channels) -> total_out_channels groups group_channels",
+            "... total_out_channels (groups group_channels) -> ... total_out_channels groups group_channels",
             groups=self.num_groups,
         )

-        zero_points = rearrange(self.zero_points, "total_out_channels groups -> total_out_channels groups 1")
+        zero_points = rearrange(self.zero_points, "... total_out_channels groups -> ... total_out_channels groups 1")
         grouped_weights = grouped_weights - zero_points

-        scales = rearrange(self.scales, "total_out_channels groups -> total_out_channels groups 1")
+        scales = rearrange(self.scales, "... total_out_channels groups -> ... total_out_channels groups 1")
         scaled_grouped_weights = grouped_weights * scales
         result = rearrange(
             scaled_grouped_weights,
-            "total_out_channels groups group_channels -> total_out_channels (groups group_channels)",
+            "... total_out_channels groups group_channels -> ... total_out_channels (groups group_channels)",
         )
         return result

@@ -412,21 +527,21 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C

     @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
+        if self.mixture_size is not None:
+            raise ValueError(
+                "Mixtures of linear layers cannot be called directly."
+                "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
+            )
         result = self._apply_weights(inputs)
         if self.biases is not None:
             result = result + self.biases
         return tuple(jnp.split(result, self._get_split_points(self.output_dims)))

-    def export_weights(self
-        expected_weight_layout = WeightLayout.OUTPUT_INPUT
-        exported_weights = into_layout(self.int_weights, expected_weight_layout)
-        exported_zero_points = into_layout(self.int_zero_points, expected_weight_layout)
-        exported_scales = into_layout(self.scales, expected_weight_layout)
-
+    def export_weights(self) -> ParameterTree:
         result = dict(
-            weights=
-            zero_points=
-            scales=
+            weights=self.int_weights,
+            zero_points=self.int_zero_points,
+            scales=self.scales,
         )
         if self.biases is not None:
             result["biases"] = self.biases
@@ -435,15 +550,15 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["weights"], Array)
+        assert isinstance(weights["zero_points"], Array)
         return replace(
             self,
-            weights=
-            scales=
-            zero_points=
+            weights=weights["weights"].astype(self.weights.dtype),
+            scales=weights["scales"],
+            zero_points=weights["zero_points"].astype(self.zero_points.dtype),
             biases=weights["biases"] if self.has_biases else None,
         )

@@ -475,7 +590,7 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
         max_down_abs_value = 1 / math.sqrt(input_dim)
         lora_down_weights = jax.random.uniform(
             down_key,
-            (
+            (input_dim, hidden_lora_rank),
             minval=-max_down_abs_value,
             maxval=max_down_abs_value,
             dtype=self.activation_precision,
@@ -486,7 +601,7 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
         lora_up_weights = tuple(
             jax.random.uniform(
                 up_key,
-                (
+                (self.lora_rank, output_dim),
                 minval=-max_up_abs_value,
                 maxval=max_up_abs_value,
                 dtype=self.activation_precision,
@@ -505,6 +620,18 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
             lora_up_weights=lora_up_weights,
         )

+    def random_init_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase:
+        subkeys = jax.random.split(key, mixture_size)
+        return eqx.filter_vmap(lambda k: self.random_init(input_dim, output_dims, has_biases, key=k))(subkeys)
+
     def empty(
         self,
         input_dim: int,
@@ -515,12 +642,46 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
         assert isinstance(group_quantized_linear, GroupQuantizedLinear)
         hidden_lora_rank = len(output_dims) * self.lora_rank
         lora_down_weights = dummy_array(
-            (
+            (input_dim, hidden_lora_rank),
+            dtype=self.activation_precision,
+        )
+        lora_up_weights = tuple(
+            dummy_array(
+                (self.lora_rank, output_dim),
+                dtype=self.activation_precision,
+            )
+            for output_dim in output_dims
+        )
+
+        return QLoRALinear(
+            config=self,
+            output_dims=output_dims,
+            weights=group_quantized_linear.weights,
+            scales=group_quantized_linear.scales,
+            biases=group_quantized_linear.biases,
+            zero_points=group_quantized_linear.zero_points,
+            lora_down_weights=lora_down_weights,
+            lora_up_weights=lora_up_weights,
+        )
+
+    def _empty_general(
+        self,
+        leading_dims: tuple[int, ...],
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        group_quantized_linear = super().empty(input_dim, output_dims, has_biases)
+        assert isinstance(group_quantized_linear, GroupQuantizedLinear)
+
+        hidden_lora_rank = len(output_dims) * self.lora_rank
+        lora_down_weights = dummy_array(
+            (*leading_dims, input_dim, hidden_lora_rank),
             dtype=self.activation_precision,
         )
         lora_up_weights = tuple(
             dummy_array(
-                (
+                (*leading_dims, self.lora_rank, output_dim),
                 dtype=self.activation_precision,
             )
             for output_dim in output_dims
@@ -537,12 +698,21 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
             lora_up_weights=lora_up_weights,
         )

+    def empty_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        return self._empty_general((mixture_size,), input_dim, output_dims, has_biases)
+

 class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
-    lora_down_weights: Float[Array, "
-    lora_up_weights: tuple[Float[Array, "
+    lora_down_weights: Float[Array, "*components in_channels total_lora_channels"]
+    lora_up_weights: tuple[Float[Array, "*components lora_channels out_channels"], ...]

-    def _split_biases(self) -> tuple[Float[Array, " out_channels"] | None, ...]:
+    def _split_biases(self) -> tuple[Float[Array, "*components out_channels"] | None, ...]:
         if self.biases is not None:
             return tuple(jnp.split(self.biases, self._get_split_points(self.output_dims)))
         return (None,) * len(self.output_dims)
@@ -555,7 +725,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
                 f" specified activation precision ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-
+        *ld_num_components, lora_down_input_dim, lora_down_output_dim = self.lora_down_weights.shape
         if lora_down_output_dim != self.config.lora_rank * self.num_outputs:
             raise ValueError(
                 f"Number of output channels in LORA down weights ({lora_down_output_dim}) is not"
@@ -566,6 +736,12 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
                 f"Number of input channels in LORA down weights ({lora_down_input_dim}) is not"
                 f" equal to input_dim ({self.input_dim}).",
             )
+        *w_num_components, _, _ = self.weights.shape
+        if tuple(ld_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in LORA down weights ({ld_num_components}) is not"
+                f" equal to number of mixture components in base weights ({w_num_components}).",
+            )
         if len(self.lora_up_weights) != self.num_outputs:
             raise ValueError(
                 f"Expected {self.num_outputs} LORA up weights, got {len(self.lora_up_weights)}.",
@@ -577,7 +753,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
                 f" ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-
+            *lu_num_components, lora_up_input_dim, lora_up_output_dim = lora_up_weight.shape
             if lora_up_output_dim != output_dim:
                 raise ValueError(
                     f"Number of output channels in LORA up weights ({lora_up_output_dim}) is not"
@@ -588,16 +764,26 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
                     f"Number of input channels in LORA up weights ({lora_up_input_dim}) is not"
                     f" equal to lora_rank ({self.config.lora_rank}).",
                 )
+            if tuple(lu_num_components) != tuple(w_num_components):
+                raise ValueError(
+                    f"Number of mixture components in LORA up weights ({lu_num_components}) is not"
+                    f" equal to number of mixture components in base weights ({w_num_components}).",
+                )

     @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
+        if self.mixture_size is not None:
+            raise ValueError(
+                "Mixtures of linear layers cannot be called directly."
+                "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
+            )
         joint_q_out = self._apply_weights(inputs)
         q_outs = jnp.split(joint_q_out, self._get_split_points(self.output_dims))

-        joint_lora_hidden = self.lora_down_weights
+        joint_lora_hidden = inputs @ self.lora_down_weights
         lora_hiddens = jnp.split(joint_lora_hidden, self._get_split_points([self.config.lora_rank] * self.num_outputs))
         lora_outs = [
-
+            lora_hidden @ lora_up_weight
             for lora_up_weight, lora_hidden in zip(self.lora_up_weights, lora_hiddens, strict=True)
         ]

@@ -610,30 +796,25 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):

         return tuple(results)

-    def export_weights(self
+    def export_weights(self) -> ParameterTree:
         quantized_linear_weights: dict[str, ParameterTree] = super().export_weights()  # type: ignore
-        exported_lora_down_weights = into_layout(self.lora_down_weights, weight_layout)
-        exported_lora_up_weights = [
-            into_layout(lora_up_weight, weight_layout) for lora_up_weight in self.lora_up_weights
-        ]
         return dict(
-            down_weights=
-            up_weights=
+            down_weights=self.lora_down_weights,
+            up_weights=self.lora_up_weights,
             **quantized_linear_weights,
         )

     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
-        base = super().import_weights(weights
+        base = super().import_weights(weights)
         assert isinstance(weights, Mapping)
         assert isinstance(weights["up_weights"], Sequence)
         return replace(
             base,
-            lora_down_weights=
-            lora_up_weights=tuple(
+            lora_down_weights=weights["down_weights"],
+            lora_up_weights=tuple(up_weights for up_weights in weights["up_weights"]),
         )

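The new `mixture_size` accessors and the `random_init_mixture`/`empty_mixture` constructors stack linear layers along a leading `*components` axis, and the added `__call__` guard points callers to `eqx.filter_vmap` or `lax.scan` instead of a direct call. Below is a minimal sketch of that vmap pattern using plain `eqx.nn.Linear` modules as stand-ins for lalamo's `LinearBase` classes; the lalamo-specific config constructors are not assumed here.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

mixture_size, in_dim, out_dim = 4, 8, 16
keys = jax.random.split(jax.random.PRNGKey(0), mixture_size)

# Stack `mixture_size` layers into one module whose arrays carry a leading
# component axis, analogous to the "*components" axis introduced in this diff.
mixture = eqx.filter_vmap(lambda k: eqx.nn.Linear(in_dim, out_dim, key=k))(keys)

# A direct call would mix the component axis into the matmul, so map one
# input per component across the stacked weights instead.
xs = jnp.ones((mixture_size, in_dim))
ys = eqx.filter_vmap(lambda layer, x: layer(x))(mixture, xs)
print(ys.shape)  # (4, 16)
```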