lalamo 0.2.7__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/common.py +79 -29
- lalamo/language_model.py +106 -83
- lalamo/main.py +91 -18
- lalamo/message_processor.py +170 -0
- lalamo/model_import/common.py +159 -43
- lalamo/model_import/{configs → decoder_configs}/__init__.py +0 -1
- lalamo/model_import/{configs → decoder_configs}/common.py +11 -10
- lalamo/model_import/{configs → decoder_configs}/huggingface/common.py +9 -4
- lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py +2 -2
- lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py +2 -2
- lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py +1 -1
- lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py +1 -1
- lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py +1 -1
- lalamo/model_import/huggingface_generation_config.py +44 -0
- lalamo/model_import/huggingface_tokenizer_config.py +85 -0
- lalamo/model_import/loaders/common.py +2 -1
- lalamo/model_import/loaders/huggingface.py +12 -10
- lalamo/model_import/model_specs/__init__.py +3 -2
- lalamo/model_import/model_specs/common.py +32 -34
- lalamo/model_import/model_specs/deepseek.py +1 -10
- lalamo/model_import/model_specs/gemma.py +2 -25
- lalamo/model_import/model_specs/huggingface.py +2 -12
- lalamo/model_import/model_specs/llama.py +2 -58
- lalamo/model_import/model_specs/mistral.py +9 -19
- lalamo/model_import/model_specs/pleias.py +3 -13
- lalamo/model_import/model_specs/polaris.py +5 -7
- lalamo/model_import/model_specs/qwen.py +12 -111
- lalamo/model_import/model_specs/reka.py +4 -13
- lalamo/modules/__init__.py +2 -1
- lalamo/modules/attention.py +90 -10
- lalamo/modules/common.py +51 -4
- lalamo/modules/decoder.py +90 -8
- lalamo/modules/decoder_layer.py +85 -8
- lalamo/modules/embedding.py +95 -29
- lalamo/modules/kv_cache.py +3 -3
- lalamo/modules/linear.py +170 -130
- lalamo/modules/mlp.py +40 -7
- lalamo/modules/normalization.py +24 -6
- lalamo/modules/rope.py +24 -6
- lalamo/sampling.py +99 -0
- lalamo/utils.py +86 -1
- {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/METADATA +6 -6
- lalamo-0.3.0.dist-info/RECORD +58 -0
- lalamo-0.2.7.dist-info/RECORD +0 -54
- /lalamo/model_import/{configs → decoder_configs}/executorch.py +0 -0
- /lalamo/model_import/{configs → decoder_configs}/huggingface/__init__.py +0 -0
- /lalamo/model_import/{configs → decoder_configs}/huggingface/gemma2.py +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/WHEEL +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/top_level.txt +0 -0
lalamo/modules/linear.py
CHANGED

@@ -1,19 +1,25 @@
 import math
 from abc import abstractmethod
-from collections.abc import Sequence
-from dataclasses import dataclass
-from typing import NamedTuple
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass, replace
+from typing import NamedTuple, Self
 
 import equinox as eqx
 import jax
+import jax.numpy as jnp
 from einops import rearrange
-from jax import numpy as jnp
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import
+from lalamo.common import ParameterTree, dummy_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
 
-from .common import
+from .common import (
+    LalamoModule,
+    WeightLayout,
+    from_layout,
+    into_layout,
+    register_config_union,
+)
 
 __all__ = [
     "FullPrecisionLinear",
@@ -48,30 +54,11 @@ class LinearBase[ConfigT: LinearConfigBase](LalamoModule[ConfigT]):
         inputs: Float[Array, " in_channels"],
     ) -> tuple[Float[Array, " out_channels"], ...]: ...
 
-
-
-        return WeightLayout.INPUT_OUTPUT
-
-    @classmethod
-    def _into_layout(
-        cls,
-        weights: Float[Array, "in_channels out_channels"],
-        layout: WeightLayout,
-    ) -> Float[Array, "in_channels out_channels"] | Float[Array, "out_channels in_channels"]:
-        if layout == WeightLayout.AUTO:
-            layout = cls._default_weight_layout()
-        match layout:
-            case WeightLayout.OUTPUT_INPUT:
-                return weights
-            case WeightLayout.INPUT_OUTPUT:
-                return rearrange(
-                    weights,
-                    "total_out_channels in_channels -> in_channels total_out_channels",
-                )
-        raise ValueError(f"Unsupported weight layout: {layout}")
+    def __post_init__(self) -> None:
+        assert isinstance(self.output_dims, tuple)
 
-    @
-    def _get_split_points(
+    @staticmethod
+    def _get_split_points(output_dims: Sequence[int]) -> tuple[int, ...]:
         result = []
         last_split_point = 0
         for dim in output_dims[:-1]:
@@ -92,6 +79,14 @@ class LinearConfigBase:
         key: PRNGKeyArray,
     ) -> LinearBase: ...
 
+    @abstractmethod
+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase: ...
+
 
 @dataclass(frozen=True)
 class FullPrecisionLinearConfig(LinearConfigBase):
@@ -104,7 +99,7 @@ class FullPrecisionLinearConfig(LinearConfigBase):
         has_biases: bool,
         *,
         key: PRNGKeyArray,
-    ) ->
+    ) -> "FullPrecisionLinear":
         scale = 1 / math.sqrt(input_dim)
         weights = jax.random.uniform(
             key,
@@ -125,6 +120,28 @@ class FullPrecisionLinearConfig(LinearConfigBase):
             biases=biases,
         )
 
+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> "FullPrecisionLinear":
+        weights = dummy_array(
+            (sum(output_dims), input_dim),
+            dtype=self.precision,
+        )
+        if has_biases:
+            biases = dummy_array((sum(output_dims),), dtype=self.precision)
+        else:
+            biases = None
+
+        return FullPrecisionLinear(
+            config=self,
+            output_dims=output_dims,
+            weights=weights,
+            biases=biases,
+        )
+
 
 class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
     weights: Float[Array, "total_out_channels in_channels"]
@@ -148,7 +165,7 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified precision ({self.config.precision}).",
             )
-        w_output_dim,
+        w_output_dim, _ = self.weights.shape
         if w_output_dim != sum(self.output_dims):
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
@@ -167,18 +184,31 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
                 f"Bias dtype ({self.biases.dtype}) is not equal to specified precision ({self.config.precision}).",
             )
 
+    @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
         result = self.weights @ inputs
         if self.biases is not None:
             result = result + self.biases
         return tuple(jnp.split(result, self._get_split_points(self.output_dims)))
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) ->
-        result =
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        result = dict(weights=into_layout(self.weights, weight_layout))
         if self.biases is not None:
             result["biases"] = self.biases
         return result
 
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        return replace(
+            self,
+            weights=from_layout(weights["weights"], weight_layout),
+            biases=weights["biases"] if self.has_biases else None,
+        )
+
 
 @dataclass(frozen=True)
 class GroupQuantizedLinearConfig(LinearConfigBase):
@@ -224,6 +254,34 @@ class GroupQuantizedLinearConfig(LinearConfigBase):
             biases=biases,
         )
 
+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        weights = dummy_array(
+            (sum(output_dims), input_dim),
+            dtype=self.activation_precision,
+        )
+        num_groups = input_dim // self.group_size
+        scales = dummy_array((sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        if has_biases:
+            biases = dummy_array((sum(output_dims),), dtype=self.activation_precision)
+        else:
+            biases = None
+        zero_points = dummy_array((sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        return GroupQuantizedLinear(
+            config=self,
+            output_dims=output_dims,
+            weights=weights,
+            scales=scales,
+            zero_points=zero_points,
+            biases=biases,
+        )
+
 
 class RequantizedWeights(NamedTuple):
     weights: Int[Array, "total_out_channels in_channels"]
@@ -271,7 +329,7 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
                 f" ({self.config.activation_precision}).",
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        w_output_dim,
+        w_output_dim, _ = self.weights.shape
         if w_output_dim != sum(self.output_dims):
             raise ValueError(
                 f"Number of output channels in weights ({w_output_dim}) is not"
@@ -352,100 +410,20 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
         inputs = dynamically_quantize_activations(inputs, self.config.activation_quantization_mode)
         return self._prepare_scaled_weights() @ inputs
 
+    @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
         result = self._apply_weights(inputs)
         if self.biases is not None:
             result = result + self.biases
         return tuple(jnp.split(result, self._get_split_points(self.output_dims)))
 
-    def
-
-
-
-
-            weights: uint4 array of shape [M, N]
-            zero_points: uint4 array of shape [M//group_size_0, N//group_size_1]
-            scales: float16 array of shape [M//group_size_0, N//group_size_1]
-
-        Returns:
-            new_weights: uint4 array of shape [M, N]
-            new_zero_points: uint4 array of shape [M, N//128]
-            new_scales: float16 array of shape [M, N//128]
-        """
-        # Get dimensions
-        M, N = weights.shape
-        old_groups_0, old_groups_1 = zero_points.shape
-
-        # Calculate old group sizes
-        old_group_size_0 = M // old_groups_0  # 2560 // 20 = 128
-        old_group_size_1 = N // old_groups_1  # 6144 // 6144 = 1
-
-        # New group sizes
-        new_group_size_0 = 1  # 2560 // 2560 = 1
-        new_group_size_1 = self.config.group_size  # 6144 // 48 = 128
-
-        # Step 1: Dequantize with original parameters
-        # Expand zero_points and scales to match weights shape
-        zp_expanded = jnp.repeat(jnp.repeat(zero_points, old_group_size_0, axis=0), old_group_size_1, axis=1)
-        scales_expanded = jnp.repeat(jnp.repeat(scales, old_group_size_0, axis=0), old_group_size_1, axis=1)
-
-        # Dequantize (convert to float for computation)
-        weights_float = weights.astype(jnp.float32)
-        zp_float = zp_expanded.astype(jnp.float32)
-        dequantized = (weights_float - zp_float) * scales_expanded.astype(jnp.float32)
-
-        # Step 2: Requantize with new group structure [2560, 48]
-        # Reshape for new groups
-        dequantized_reshaped = dequantized.reshape(
-            M // new_group_size_0,
-            new_group_size_0,
-            N // new_group_size_1,
-            new_group_size_1,
-        )
-
-        # Compute new scales and zero points per group
-        # Move group dimensions to the end for reduction
-        dequantized_groups = dequantized_reshaped.transpose(0, 2, 1, 3)  # [2560, 48, 1, 128]
-
-        # Find min and max per group
-        group_min = dequantized_groups.min(axis=(2, 3), keepdims=True)
-        group_max = dequantized_groups.max(axis=(2, 3), keepdims=True)
-
-        # Compute scales (with small epsilon to avoid division by zero)
-        eps = 1e-6
-        new_scales = ((group_max - group_min) / 15.0 + eps).astype(scales.dtype)
-        new_scales = new_scales.squeeze(axis=(2, 3))  # [2560, 48]
-
-        # Compute zero points (quantize to uint4 range 0-15)
-        new_zero_points = jnp.round(-group_min.squeeze(axis=(2, 3)) / new_scales).astype(jnp.uint4)
-        new_zero_points = jnp.clip(new_zero_points, 0, 15)
-
-        # Quantize with new parameters
-        scales_expanded_new = jnp.repeat(new_scales, new_group_size_1, axis=1).reshape(M, N)
-        zp_expanded_new = jnp.repeat(new_zero_points, new_group_size_1, axis=1).reshape(M, N)
-
-        new_weights = jnp.round(
-            dequantized / scales_expanded_new.astype(jnp.float32) + zp_expanded_new.astype(jnp.float32),
-        )
-        new_weights = jnp.clip(new_weights, 0, 15).astype(jnp.uint4)
-
-        return new_weights, new_zero_points, new_scales
-
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
-        exported_weights = self._into_layout(self.int_weights, weight_layout)
-
-        exported_zero_points = self._into_layout(self.int_zero_points, weight_layout)
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        expected_weight_layout = WeightLayout.OUTPUT_INPUT
+        exported_weights = into_layout(self.int_weights, expected_weight_layout)
+        exported_zero_points = into_layout(self.int_zero_points, expected_weight_layout)
+        exported_scales = into_layout(self.scales, expected_weight_layout)
 
-
-
-        # CRIMINAL HACK!!!
-        exported_weights, exported_zero_points, exported_scales = self.requantize_weights(
-            exported_weights,
-            exported_zero_points,
-            exported_scales,
-        )
-
-        result = ParameterDict(
+        result = dict(
             weights=exported_weights,
             zero_points=exported_zero_points,
            scales=exported_scales,
@@ -454,6 +432,21 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
             result["biases"] = self.biases
         return result
 
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["weights"], Array)
+        return replace(
+            self,
+            weights=from_layout(weights["weights"].astype(self.weights.dtype), weight_layout),
+            scales=from_layout(weights["scales"], weight_layout),
+            zero_points=from_layout(weights["zero_points"], weight_layout).astype(self.zero_points.dtype),
+            biases=weights["biases"] if self.has_biases else None,
+        )
+
 
 class GroupQuantizedLinear(GroupQuantizedLinearBase[GroupQuantizedLinearConfig]):
     pass
@@ -512,6 +505,38 @@ class QLoRALinearConfig(GroupQuantizedLinearConfig):
             lora_up_weights=lora_up_weights,
         )
 
+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        group_quantized_linear = super().empty(input_dim, output_dims, has_biases)
+        assert isinstance(group_quantized_linear, GroupQuantizedLinear)
+        hidden_lora_rank = len(output_dims) * self.lora_rank
+        lora_down_weights = dummy_array(
+            (hidden_lora_rank, input_dim),
+            dtype=self.activation_precision,
+        )
+        lora_up_weights = tuple(
+            dummy_array(
+                (output_dim, self.lora_rank),
+                dtype=self.activation_precision,
+            )
+            for output_dim in output_dims
+        )
+
+        return QLoRALinear(
+            config=self,
+            output_dims=output_dims,
+            weights=group_quantized_linear.weights,
+            scales=group_quantized_linear.scales,
+            biases=group_quantized_linear.biases,
+            zero_points=group_quantized_linear.zero_points,
+            lora_down_weights=lora_down_weights,
+            lora_up_weights=lora_up_weights,
+        )
+
 
 class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
     lora_down_weights: Float[Array, "total_lora_channels in_channels"]
@@ -564,6 +589,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
                 f" equal to lora_rank ({self.config.lora_rank}).",
             )
 
+    @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
         joint_q_out = self._apply_weights(inputs)
         q_outs = jnp.split(joint_q_out, self._get_split_points(self.output_dims))
@@ -584,16 +610,30 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
 
         return tuple(results)
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) ->
-        quantized_linear_weights = super().export_weights()
-        exported_lora_down_weights =
-        exported_lora_up_weights =
-
-
-        return
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        quantized_linear_weights: dict[str, ParameterTree] = super().export_weights()  # type: ignore
+        exported_lora_down_weights = into_layout(self.lora_down_weights, weight_layout)
+        exported_lora_up_weights = [
+            into_layout(lora_up_weight, weight_layout) for lora_up_weight in self.lora_up_weights
+        ]
+        return dict(
+            down_weights=into_layout(exported_lora_down_weights, weight_layout),
+            up_weights=[into_layout(w, weight_layout) for w in exported_lora_up_weights],
             **quantized_linear_weights,
-
-
+        )
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        base = super().import_weights(weights, weight_layout)
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["up_weights"], Sequence)
+        return replace(
+            base,
+            lora_down_weights=from_layout(weights["down_weights"], weight_layout),
+            lora_up_weights=tuple(from_layout(up_weights, weight_layout) for up_weights in weights["up_weights"]),
         )
 
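The main API change in this file is that every LinearConfigBase now exposes empty() and every LinearBase gains import_weights() to pair with export_weights(). A minimal sketch of how these could be used together, assuming only the signatures visible in the diff above; the helper name rebuild_linear and its arguments are hypothetical, not part of the package:

from jaxtyping import Array

from lalamo.common import ParameterTree
from lalamo.modules.linear import LinearBase, LinearConfigBase


def rebuild_linear(
    config: LinearConfigBase,
    input_dim: int,
    output_dims: tuple[int, ...],
    checkpoint: ParameterTree[Array],
    has_biases: bool = False,
) -> LinearBase:
    # Allocate the layer with placeholder tensors (dummy_array) instead of random init,
    # then swap in real parameters from a checkpoint mapping, e.g.
    # {"weights": ..., "scales": ..., "zero_points": ...} for the quantized variants.
    layer = config.empty(input_dim, output_dims, has_biases)
    return layer.import_weights(checkpoint)

Because export_weights() and import_weights() use the same key names ("weights", "scales", "zero_points", "biases", plus "down_weights"/"up_weights" for QLoRA), exporting a layer and importing the result back should round-trip its parameters.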
lalamo/modules/mlp.py
CHANGED

@@ -1,9 +1,12 @@
-from
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from typing import Self
 
+import equinox as eqx
 import jax
 from jaxtyping import Array, DTypeLike, Float, PRNGKeyArray
 
-from lalamo.common import
+from lalamo.common import ParameterTree
 
 from .activations import Activation
 from .common import LalamoModule, WeightLayout
@@ -35,8 +38,23 @@ class MLPConfig:
             ),
         )
 
+    def empty(self, model_dim: int, hidden_dim: int) -> "MLP":
+        return MLP(
+            self,
+            up_projection=self.linear_config.empty(
+                model_dim,
+                (hidden_dim, hidden_dim),
+                has_biases=False,
+            ),
+            down_projection=self.linear_config.empty(
+                hidden_dim,
+                (model_dim,),
+                has_biases=False,
+            ),
+        )
 
-class MLP(LalamoModule):
+
+class MLP(LalamoModule[MLPConfig]):
     up_projection: LinearBase
     down_projection: LinearBase
 
@@ -66,14 +84,29 @@ class MLP(LalamoModule):
                 f" the up projection output dimension {self.up_projection.input_dim}",
             )
 
+    @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
         up_proj, gate = self.up_projection(inputs)
         gate = self.config.activation(gate)
         (result,) = self.down_projection(up_proj * gate)
         return result
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) ->
-        return
-            up_projection
-            down_projection
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        return {
+            "up_projection": self.up_projection.export_weights(weight_layout),
+            "down_projection": self.down_projection.export_weights(weight_layout),
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["up_projection"], Mapping)
+        assert isinstance(weights["down_projection"], Mapping)
+        return replace(
+            self,
+            up_projection=self.up_projection.import_weights(weights["up_projection"], weight_layout),
+            down_projection=self.down_projection.import_weights(weights["down_projection"], weight_layout),
         )
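A small sketch of the MLP weight tree in 0.3.0, based only on the keys visible in this diff; `mlp` is assumed to be an already constructed lalamo.modules.mlp.MLP instance:

from lalamo.modules.mlp import MLP


def roundtrip_mlp(mlp: MLP) -> MLP:
    # export_weights() now returns a nested mapping keyed by submodule:
    # {"up_projection": {...}, "down_projection": {...}}
    tree = mlp.export_weights()
    # import_weights() consumes the same structure, delegating each entry to the
    # corresponding LinearBase.import_weights() of the two projections.
    return mlp.import_weights(tree)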
lalamo/modules/normalization.py
CHANGED

@@ -1,11 +1,14 @@
-from
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
 from enum import Enum
+from typing import Self
 
+import equinox as eqx
 import jax
 from jax import numpy as jnp
 from jaxtyping import Array, DTypeLike, Float
 
-from lalamo.common import
+from lalamo.common import ParameterTree, dummy_array
 
 from .common import LalamoModule, WeightLayout
 
@@ -29,10 +32,16 @@ class RMSNormConfig:
     scale_offset: float | None
     upcast_mode: UpcastMode
 
-    def init(self,
-        scales = jnp.ones(
+    def init(self, input_dim: int) -> "RMSNorm":
+        scales = jnp.ones(input_dim, dtype=self.scale_precision)
         return RMSNorm(self, scales=scales)
 
+    def empty(self, input_dim: int) -> "RMSNorm":
+        return RMSNorm(
+            config=self,
+            scales=dummy_array(input_dim, dtype=self.scale_precision),
+        )
+
 
 class RMSNorm(LalamoModule[RMSNormConfig]):
     scales: Float[Array, " channels"]
@@ -53,6 +62,7 @@ class RMSNorm(LalamoModule[RMSNormConfig]):
                 f" specified precision {self.config.scale_precision}",
             )
 
+    @eqx.filter_jit
     def __call__(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
         upcasted_inputs = inputs.astype(self.config.accumulation_precision)
 
@@ -73,5 +83,13 @@ class RMSNorm(LalamoModule[RMSNormConfig]):
         result = normalized_x * adjusted_scales
         return result.astype(inputs.dtype)
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) ->
-        return
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:  # noqa: ARG002
+        return {"scales": self.scales}
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        return replace(self, scales=weights["scales"])
lalamo/modules/rope.py
CHANGED

@@ -16,13 +16,14 @@
 # limitations under the License.
 
 import math
-from
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
 
 import equinox as eqx
 from jax import numpy as jnp
 from jaxtyping import Array, DTypeLike, Float, Int
 
-from lalamo.common import
+from lalamo.common import ParameterTree
 
 from .common import LalamoModule, WeightLayout, register_config_union
 
@@ -53,8 +54,8 @@ class PositionalEmbeddings(eqx.Module):
     def apply(self, heads: Float[Array, "tokens head_channels"]) -> Float[Array, "tokens head_channels"]:
         return heads * self.cosines + self.rotate_half(heads) * self.sines
 
-    def export(self, weight_layout: WeightLayout = WeightLayout.AUTO) ->
-        return
+    def export(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:  # noqa: ARG002
+        return dict(
             cosines=self.cosines,
             sines=self.sines,
         )
@@ -103,6 +104,11 @@ class RoPE(LalamoModule[RoPEConfigBase]):
         return self.config.precision
 
     def __post_init__(self) -> None:
+        num_tokens, _ = self.sines.shape
+        if num_tokens != self.config.max_sequence_length:
+            raise ValueError(
+                f"{num_tokens} does not match the specified max sequence length {self.config.max_sequence_length}",
+            )
         if self.cosines.dtype != self.config.precision:
             raise ValueError(
                 f"Cosines dtype {self.cosines.dtype} does not match the specified precision {self.config.precision}",
@@ -127,14 +133,26 @@ class RoPE(LalamoModule[RoPEConfigBase]):
         result, _ = self.sines.shape
         return result
 
+    @eqx.filter_jit
     def __call__(self, timesteps: Int[Array, " tokens"]) -> PositionalEmbeddings:
         return PositionalEmbeddings(
             cosines=self.cosines[timesteps],
             sines=self.sines[timesteps],
         )
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) ->
-        return
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree[Array]:  # noqa: ARG002
+        return {
+            "cosines": self.cosines,
+            "sines": self.sines,
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
+    ) -> "RoPE":
+        assert isinstance(weights, Mapping)
+        return replace(self, cosines=weights["cosines"], sines=weights["sines"])
 
 
 class UnscaledRoPEConfig(RoPEConfigBase):