lalamo 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/model_import/__init__.py +8 -0
- lalamo/model_import/common.py +111 -0
- lalamo/model_import/configs/__init__.py +23 -0
- lalamo/model_import/configs/common.py +62 -0
- lalamo/model_import/configs/executorch.py +166 -0
- lalamo/model_import/configs/huggingface/__init__.py +18 -0
- lalamo/model_import/configs/huggingface/common.py +72 -0
- lalamo/model_import/configs/huggingface/gemma2.py +122 -0
- lalamo/model_import/configs/huggingface/gemma3.py +187 -0
- lalamo/model_import/configs/huggingface/llama.py +155 -0
- lalamo/model_import/configs/huggingface/mistral.py +132 -0
- lalamo/model_import/configs/huggingface/qwen2.py +144 -0
- lalamo/model_import/configs/huggingface/qwen3.py +142 -0
- lalamo/model_import/loaders/__init__.py +7 -0
- lalamo/model_import/loaders/common.py +45 -0
- lalamo/model_import/loaders/executorch.py +223 -0
- lalamo/model_import/loaders/huggingface.py +304 -0
- lalamo/model_import/model_specs/__init__.py +38 -0
- lalamo/model_import/model_specs/common.py +118 -0
- lalamo/model_import/model_specs/deepseek.py +28 -0
- lalamo/model_import/model_specs/gemma.py +76 -0
- lalamo/model_import/model_specs/huggingface.py +28 -0
- lalamo/model_import/model_specs/llama.py +101 -0
- lalamo/model_import/model_specs/mistral.py +59 -0
- lalamo/model_import/model_specs/pleias.py +28 -0
- lalamo/model_import/model_specs/polaris.py +22 -0
- lalamo/model_import/model_specs/qwen.py +336 -0
- lalamo/model_import/model_specs/reka.py +28 -0
- lalamo/modules/__init__.py +85 -0
- lalamo/modules/activations.py +30 -0
- lalamo/modules/attention.py +326 -0
- lalamo/modules/common.py +133 -0
- lalamo/modules/decoder.py +244 -0
- lalamo/modules/decoder_layer.py +240 -0
- lalamo/modules/embedding.py +299 -0
- lalamo/modules/kv_cache.py +196 -0
- lalamo/modules/linear.py +603 -0
- lalamo/modules/mlp.py +79 -0
- lalamo/modules/normalization.py +77 -0
- lalamo/modules/rope.py +255 -0
- lalamo/modules/utils.py +13 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/METADATA +1 -1
- lalamo-0.2.2.dist-info/RECORD +53 -0
- lalamo-0.2.1.dist-info/RECORD +0 -12
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/WHEEL +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/entry_points.txt +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/top_level.txt +0 -0
lalamo/modules/linear.py
ADDED
@@ -0,0 +1,603 @@
import math
from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import NamedTuple

import equinox as eqx
import jax
from einops import rearrange
from jax import numpy as jnp
from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

from lalamo.common import ParameterDict
from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights

from .common import LalamoModule, WeightLayout, register_config_union

__all__ = [
    "FullPrecisionLinear",
    "FullPrecisionLinearConfig",
    "GroupQuantizedLinear",
    "GroupQuantizedLinearConfig",
    "LinearBase",
    "LinearConfig",
    "QLoRALinear",
    "QLoRALinearConfig",
]


class LinearBase[ConfigT: LinearConfigBase](LalamoModule[ConfigT]):
    output_dims: tuple[int, ...] = eqx.field(static=True)

    @property
    @abstractmethod
    def input_dim(self) -> int: ...

    @property
    def num_outputs(self) -> int:
        return len(self.output_dims)

    @property
    @abstractmethod
    def has_biases(self) -> bool: ...

    @abstractmethod
    def __call__(
        self,
        inputs: Float[Array, " in_channels"],
    ) -> tuple[Float[Array, " out_channels"], ...]: ...

    @classmethod
    def _default_weight_layout(cls) -> WeightLayout:
        return WeightLayout.INPUT_OUTPUT

    @classmethod
    def _into_layout(
        cls,
        weights: Float[Array, "in_channels out_channels"],
        layout: WeightLayout,
    ) -> Float[Array, "in_channels out_channels"] | Float[Array, "out_channels in_channels"]:
        if layout == WeightLayout.AUTO:
            layout = cls._default_weight_layout()
        match layout:
            case WeightLayout.OUTPUT_INPUT:
                return weights
            case WeightLayout.INPUT_OUTPUT:
                return rearrange(
                    weights,
                    "total_out_channels in_channels -> in_channels total_out_channels",
                )
        raise ValueError(f"Unsupported weight layout: {layout}")

    @classmethod
    def _get_split_points(cls, output_dims: Sequence[int]) -> tuple[int, ...]:
        result = []
        last_split_point = 0
        for dim in output_dims[:-1]:
            last_split_point += dim
            result.append(last_split_point)
        return tuple(result)


@dataclass(frozen=True)
class LinearConfigBase:
    @abstractmethod
    def random_init(
        self,
        input_dim: int,
        output_dims: tuple[int, ...],
        has_biases: bool,
        *,
        key: PRNGKeyArray,
    ) -> LinearBase: ...


@dataclass(frozen=True)
class FullPrecisionLinearConfig(LinearConfigBase):
    precision: DTypeLike

    def random_init(
        self,
        input_dim: int,
        output_dims: tuple[int, ...],
        has_biases: bool,
        *,
        key: PRNGKeyArray,
    ) -> LinearBase:
        scale = 1 / math.sqrt(input_dim)
        weights = jax.random.uniform(
            key,
            (sum(output_dims), input_dim),
            minval=-scale,
            maxval=scale,
            dtype=self.precision,
        )
        if has_biases:
            biases = jnp.zeros((sum(output_dims),), dtype=self.precision)
        else:
            biases = None

        return FullPrecisionLinear(
            config=self,
            output_dims=output_dims,
            weights=weights,
            biases=biases,
        )


class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
    weights: Float[Array, "total_out_channels in_channels"]
    biases: Float[Array, " total_out_channels"] | None

    @property
    def activation_precision(self) -> DTypeLike:
        return self.config.precision

    @property
    def input_dim(self) -> int:
        _, input_dim = self.weights.shape
        return input_dim

    @property
    def has_biases(self) -> bool:
        return self.biases is not None

    def __post_init__(self) -> None:
        if self.weights.dtype != self.config.precision:
            raise ValueError(
                f"Weight dtype ({self.weights.dtype}) is not equal to specified precision ({self.config.precision}).",
            )
        w_output_dim, w_input_dim = self.weights.shape
        if w_output_dim != sum(self.output_dims):
            raise ValueError(
                f"Number of output channels in weights ({w_output_dim}) is not"
                f" equal to sum of output dims ({sum(self.output_dims)}).",
            )
        if self.biases is None:
            return
        (b_output_dim,) = self.biases.shape
        if w_output_dim != b_output_dim:
            raise ValueError(
                f"Number of output channels in weights ({w_output_dim}) is not"
                f" equal to number of output channels in biases ({b_output_dim}).",
            )
        if self.biases.dtype != self.config.precision:
            raise ValueError(
                f"Bias dtype ({self.biases.dtype}) is not equal to specified precision ({self.config.precision}).",
            )

    def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
        result = self.weights @ inputs
        if self.biases is not None:
            result = result + self.biases
        return tuple(jnp.split(result, self._get_split_points(self.output_dims)))

    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
        result = ParameterDict(weights=self._into_layout(self.weights, weight_layout))
        if self.biases is not None:
            result["biases"] = self.biases
        return result


@dataclass(frozen=True)
class GroupQuantizedLinearConfig(LinearConfigBase):
    group_size: int
    weight_quantization_mode: QuantizationMode
    activation_quantization_mode: QuantizationMode | None
    activation_precision: DTypeLike

    def random_init(
        self,
        input_dim: int,
        output_dims: tuple[int, ...],
        has_biases: bool,
        *,
        key: PRNGKeyArray,
    ) -> LinearBase:
        min_val, max_val = self.weight_quantization_mode.range
        weights = jax.random.uniform(
            key,
            (sum(output_dims), input_dim),
            minval=min_val - 1,
            maxval=max_val + 1,
            dtype=self.activation_precision,
        )
        num_groups = input_dim // self.group_size
        scale = 1 / ((max_val - min_val) / 2 * math.sqrt(input_dim))
        scales = scale * jnp.ones((sum(output_dims), num_groups), dtype=self.activation_precision)

        if has_biases:
            biases = jnp.zeros((sum(output_dims),), dtype=self.activation_precision)
        else:
            biases = None

        zero_point = min_val + 2 ** (self.weight_quantization_mode.bits - 1)
        zero_points = zero_point * jnp.ones((sum(output_dims), num_groups), dtype=self.activation_precision)

        return GroupQuantizedLinear(
            config=self,
            output_dims=output_dims,
            weights=weights,
            scales=scales,
            zero_points=zero_points,
            biases=biases,
        )


class RequantizedWeights(NamedTuple):
    weights: Int[Array, "total_out_channels in_channels"]
    zero_points: Int[Array, "groups in_channels"]
    scales: Float[Array, "groups in_channels"]


class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[ConfigT]):
    weights: Float[Array, "total_out_channels in_channels"]
    scales: Float[Array, "total_out_channels groups"]
    zero_points: Float[Array, "total_out_channels groups"]
    biases: Float[Array, " total_out_channels"] | None

    @property
    def activation_precision(self) -> DTypeLike:
        return self.config.activation_precision

    @property
    def input_dim(self) -> int:
        _, input_dim = self.weights.shape
        return input_dim

    @property
    def has_biases(self) -> bool:
        return self.biases is not None

    @property
    def num_groups(self) -> int:
        return self.input_dim // self.config.group_size

    @property
    def int_weights(self) -> Int[Array, "out_channels (groups in_channels)"]:
        result = quantize_weights(self.weights, self.config.weight_quantization_mode)
        return result.astype(self.config.weight_quantization_mode.dtype)

    @property
    def int_zero_points(self) -> Int[Array, "out_channels (groups in_channels)"]:
        result = quantize_weights(self.zero_points, self.config.weight_quantization_mode)
        return result.astype(self.config.weight_quantization_mode.dtype)

    def __post_init__(self) -> None:
        if self.weights.dtype != self.config.activation_precision:
            raise ValueError(
                f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
                f" ({self.config.activation_precision}).",
                " Quantized layers require parameter dtypes to be equal to the activation precision.",
            )
        w_output_dim, w_input_dim = self.weights.shape
        if w_output_dim != sum(self.output_dims):
            raise ValueError(
                f"Number of output channels in weights ({w_output_dim}) is not"
                f" equal to sum of output dims ({sum(self.output_dims)}).",
            )

        if self.scales.dtype != self.config.activation_precision:
            raise ValueError(
                f"Scale dtype ({self.scales.dtype}) is not equal to specified activation precision"
                f" ({self.config.activation_precision}).",
                " Quantized layers require parameter dtypes to be equal to the activation precision.",
            )
        s_output_dim, s_num_groups = self.scales.shape
        if w_output_dim != s_output_dim:
            raise ValueError(
                f"Number of output channels in weights ({w_output_dim}) is not"
                f" equal to number of output channels in scales ({s_output_dim}).",
            )
        if s_num_groups != self.num_groups:
            raise ValueError(
                f"Number of groups in scales ({s_num_groups}) is incompatible with"
                f" the specified group size ({self.config.group_size}).",
            )

        if self.zero_points.dtype != self.config.activation_precision:
            raise ValueError(
                f"Zero point dtype ({self.zero_points.dtype}) is not equal to specified activation precision"
                f" ({self.config.activation_precision}).",
                " Quantized layers require parameter dtypes to be equal to the activation precision.",
            )
        (zp_output_dim, zp_num_groups) = self.zero_points.shape
        if w_output_dim != zp_output_dim:
            raise ValueError(
                f"Number of output channels in weights ({w_output_dim}) is not"
                f" equal to number of output channels in zero points ({zp_output_dim}).",
            )
        if self.num_groups != zp_num_groups:
            raise ValueError(
                f"Number of groups in zero points ({zp_num_groups}) is incompatible with"
                f" the specified group size ({self.config.group_size}).",
            )

        if self.biases is not None:
            if self.biases.dtype != self.config.activation_precision:
                raise ValueError(
                    f"Bias dtype ({self.biases.dtype}) is not equal to specified activation precision"
                    f" ({self.config.activation_precision}).",
                    " Quantized layers require parameter dtypes to be equal to the activation precision.",
                )
            (b_output_dim,) = self.biases.shape
            if w_output_dim != b_output_dim:
                raise ValueError(
                    f"Number of output channels in weights ({w_output_dim}) is not"
                    f" equal to number of output channels in biases ({b_output_dim}).",
                )

    def _prepare_scaled_weights(self) -> Float[Array, "total_out_channels in_channels"]:
        quantized_weights = quantize_weights(self.weights, self.config.weight_quantization_mode)
        grouped_weights = rearrange(
            quantized_weights,
            "total_out_channels (groups group_channels) -> total_out_channels groups group_channels",
            groups=self.num_groups,
        )

        zero_points = rearrange(self.zero_points, "total_out_channels groups -> total_out_channels groups 1")
        grouped_weights = grouped_weights - zero_points

        scales = rearrange(self.scales, "total_out_channels groups -> total_out_channels groups 1")
        scaled_grouped_weights = grouped_weights * scales
        result = rearrange(
            scaled_grouped_weights,
            "total_out_channels groups group_channels -> total_out_channels (groups group_channels)",
        )
        return result

    def _apply_weights(self, inputs: Float[Array, " in_channels"]) -> Float[Array, " total_out_channels"]:
        if self.config.activation_quantization_mode is not None:
            inputs = dynamically_quantize_activations(inputs, self.config.activation_quantization_mode)
        return self._prepare_scaled_weights() @ inputs

    def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
        result = self._apply_weights(inputs)
        if self.biases is not None:
            result = result + self.biases
        return tuple(jnp.split(result, self._get_split_points(self.output_dims)))

    def requantize_weights(self, weights, zero_points, scales):
        """
        Requantize weights from [20, 6144] grouping to [2560, 48] grouping.

        Args:
            weights: uint4 array of shape [M, N]
            zero_points: uint4 array of shape [M//group_size_0, N//group_size_1]
            scales: float16 array of shape [M//group_size_0, N//group_size_1]

        Returns:
            new_weights: uint4 array of shape [M, N]
            new_zero_points: uint4 array of shape [M, N//128]
            new_scales: float16 array of shape [M, N//128]
        """
        # Get dimensions
        M, N = weights.shape
        old_groups_0, old_groups_1 = zero_points.shape

        # Calculate old group sizes
        old_group_size_0 = M // old_groups_0  # 2560 // 20 = 128
        old_group_size_1 = N // old_groups_1  # 6144 // 6144 = 1

        # New group sizes
        new_group_size_0 = 1  # 2560 // 2560 = 1
        new_group_size_1 = self.config.group_size  # 6144 // 48 = 128

        # Step 1: Dequantize with original parameters
        # Expand zero_points and scales to match weights shape
        zp_expanded = jnp.repeat(jnp.repeat(zero_points, old_group_size_0, axis=0), old_group_size_1, axis=1)
        scales_expanded = jnp.repeat(jnp.repeat(scales, old_group_size_0, axis=0), old_group_size_1, axis=1)

        # Dequantize (convert to float for computation)
        weights_float = weights.astype(jnp.float32)
        zp_float = zp_expanded.astype(jnp.float32)
        dequantized = (weights_float - zp_float) * scales_expanded.astype(jnp.float32)

        # Step 2: Requantize with new group structure [2560, 48]
        # Reshape for new groups
        dequantized_reshaped = dequantized.reshape(
            M // new_group_size_0,
            new_group_size_0,
            N // new_group_size_1,
            new_group_size_1,
        )

        # Compute new scales and zero points per group
        # Move group dimensions to the end for reduction
        dequantized_groups = dequantized_reshaped.transpose(0, 2, 1, 3)  # [2560, 48, 1, 128]

        # Find min and max per group
        group_min = dequantized_groups.min(axis=(2, 3), keepdims=True)
        group_max = dequantized_groups.max(axis=(2, 3), keepdims=True)

        # Compute scales (with small epsilon to avoid division by zero)
        eps = 1e-6
        new_scales = ((group_max - group_min) / 15.0 + eps).astype(scales.dtype)
        new_scales = new_scales.squeeze(axis=(2, 3))  # [2560, 48]

        # Compute zero points (quantize to uint4 range 0-15)
        new_zero_points = jnp.round(-group_min.squeeze(axis=(2, 3)) / new_scales).astype(jnp.uint4)
        new_zero_points = jnp.clip(new_zero_points, 0, 15)

        # Quantize with new parameters
        scales_expanded_new = jnp.repeat(new_scales, new_group_size_1, axis=1).reshape(M, N)
        zp_expanded_new = jnp.repeat(new_zero_points, new_group_size_1, axis=1).reshape(M, N)

        new_weights = jnp.round(
            dequantized / scales_expanded_new.astype(jnp.float32) + zp_expanded_new.astype(jnp.float32),
        )
        new_weights = jnp.clip(new_weights, 0, 15).astype(jnp.uint4)

        return new_weights, new_zero_points, new_scales

    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
        exported_weights = self._into_layout(self.int_weights, weight_layout)

        exported_zero_points = self._into_layout(self.int_zero_points, weight_layout)

        exported_scales = self._into_layout(self.scales, weight_layout)

        # CRIMINAL HACK!!!
        exported_weights, exported_zero_points, exported_scales = self.requantize_weights(
            exported_weights,
            exported_zero_points,
            exported_scales,
        )

        result = ParameterDict(
            weights=exported_weights,
            zero_points=exported_zero_points,
            scales=exported_scales,
        )
        if self.biases is not None:
            result["biases"] = self.biases
        return result


class GroupQuantizedLinear(GroupQuantizedLinearBase[GroupQuantizedLinearConfig]):
    pass


@dataclass(frozen=True)
class QLoRALinearConfig(GroupQuantizedLinearConfig):
    lora_rank: int
    lora_scale: float
    activation_precision: DTypeLike

    def random_init(
        self,
        input_dim: int,
        output_dims: tuple[int, ...],
        has_biases: bool,
        *,
        key: PRNGKeyArray,
    ) -> LinearBase:
        base_key, derived_key = jax.random.split(key)
        group_quantized_linear = super().random_init(input_dim, output_dims, has_biases, key=base_key)
        assert isinstance(group_quantized_linear, GroupQuantizedLinear)

        down_key, up_key_root = jax.random.split(derived_key)
        hidden_lora_rank = len(output_dims) * self.lora_rank
        max_down_abs_value = 1 / math.sqrt(input_dim)
        lora_down_weights = jax.random.uniform(
            down_key,
            (hidden_lora_rank, input_dim),
            minval=-max_down_abs_value,
            maxval=max_down_abs_value,
            dtype=self.activation_precision,
        )

        up_keys = jax.random.split(up_key_root, len(output_dims))
        max_up_abs_value = 1 / math.sqrt(hidden_lora_rank)
        lora_up_weights = tuple(
            jax.random.uniform(
                up_key,
                (output_dim, self.lora_rank),
                minval=-max_up_abs_value,
                maxval=max_up_abs_value,
                dtype=self.activation_precision,
            )
            for up_key, output_dim in zip(up_keys, output_dims, strict=True)
        )

        return QLoRALinear(
            config=self,
            output_dims=output_dims,
            weights=group_quantized_linear.weights,
            scales=group_quantized_linear.scales,
            biases=group_quantized_linear.biases,
            zero_points=group_quantized_linear.zero_points,
            lora_down_weights=lora_down_weights,
            lora_up_weights=lora_up_weights,
        )


class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
    lora_down_weights: Float[Array, "total_lora_channels in_channels"]
    lora_up_weights: tuple[Float[Array, "out_channels lora_channels"], ...]

    def _split_biases(self) -> tuple[Float[Array, " out_channels"] | None, ...]:
        if self.biases is not None:
            return tuple(jnp.split(self.biases, self._get_split_points(self.output_dims)))
        return (None,) * len(self.output_dims)

    def __post_init__(self) -> None:
        super().__post_init__()
        if self.lora_down_weights.dtype != self.config.activation_precision:
            raise ValueError(
                f"LORA down weight dtype ({self.lora_down_weights.dtype}) is not equal to the"
                f" specified activation precision ({self.config.activation_precision}).",
                " Quantized layers require parameter dtypes to be equal to the activation precision.",
            )
        lora_down_output_dim, lora_down_input_dim = self.lora_down_weights.shape
        if lora_down_output_dim != self.config.lora_rank * self.num_outputs:
            raise ValueError(
                f"Number of output channels in LORA down weights ({lora_down_output_dim}) is not"
                f" equal to lora_rank * num_outputs ({self.config.lora_rank * self.num_outputs}).",
            )
        if lora_down_input_dim != self.input_dim:
            raise ValueError(
                f"Number of input channels in LORA down weights ({lora_down_input_dim}) is not"
                f" equal to input_dim ({self.input_dim}).",
            )
        if len(self.lora_up_weights) != self.num_outputs:
            raise ValueError(
                f"Expected {self.num_outputs} LORA up weights, got {len(self.lora_up_weights)}.",
            )
        for lora_up_weight, output_dim in zip(self.lora_up_weights, self.output_dims, strict=True):
            if lora_up_weight.dtype != self.config.activation_precision:
                raise ValueError(
                    f"LORA up weight dtype ({lora_up_weight.dtype}) is not equal to specified activation precision"
                    f" ({self.config.activation_precision}).",
                    " Quantized layers require parameter dtypes to be equal to the activation precision.",
                )
            lora_up_output_dim, lora_up_input_dim = lora_up_weight.shape
            if lora_up_output_dim != output_dim:
                raise ValueError(
                    f"Number of output channels in LORA up weights ({lora_up_output_dim}) is not"
                    f" equal to number of output dims ({self.output_dims}).",
                )
            if lora_up_input_dim != self.config.lora_rank:
                raise ValueError(
                    f"Number of input channels in LORA up weights ({lora_up_input_dim}) is not"
                    f" equal to lora_rank ({self.config.lora_rank}).",
                )

    def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
        joint_q_out = self._apply_weights(inputs)
        q_outs = jnp.split(joint_q_out, self._get_split_points(self.output_dims))

        joint_lora_hidden = self.lora_down_weights @ inputs
        lora_hiddens = jnp.split(joint_lora_hidden, self._get_split_points([self.config.lora_rank] * self.num_outputs))
        lora_outs = [
            lora_up_weight @ lora_hidden
            for lora_up_weight, lora_hidden in zip(self.lora_up_weights, lora_hiddens, strict=True)
        ]

        results = []
        for q_out, lora_out, bias in zip(q_outs, lora_outs, self._split_biases(), strict=True):
            result = q_out + self.config.lora_scale * lora_out
            if bias is not None:
                result = result + bias
            results.append(result)

        return tuple(results)

    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
        quantized_linear_weights = super().export_weights()
        exported_lora_down_weights = self._into_layout(self.lora_down_weights, weight_layout)
        exported_lora_up_weights = tuple(
            self._into_layout(lora_up_weight, weight_layout) for lora_up_weight in self.lora_up_weights
        )
        return ParameterDict(
            **quantized_linear_weights,
            down_weights=exported_lora_down_weights,
            up_weights=exported_lora_up_weights,
        )


LinearConfig = FullPrecisionLinearConfig | GroupQuantizedLinearConfig | QLoRALinearConfig


register_config_union(LinearConfig)
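For orientation, a minimal usage sketch of the linear-layer API added above. It is not part of the diff; the dimensions (512/128) and dtype are illustrative only, and it sticks to FullPrecisionLinearConfig since the quantized configs depend on QuantizationMode values not shown in this file.

# Hypothetical sketch, not from the package: a fused projection with three
# logical outputs built from FullPrecisionLinearConfig.random_init.
import jax
from jax import numpy as jnp

from lalamo.modules.linear import FullPrecisionLinearConfig

config = FullPrecisionLinearConfig(precision=jnp.bfloat16)
layer = config.random_init(
    input_dim=512,
    output_dims=(512, 128, 128),  # one weight matrix, split into three outputs
    has_biases=False,
    key=jax.random.PRNGKey(0),
)
q, k, v = layer(jnp.ones(512, dtype=jnp.bfloat16))  # split at _get_split_points(output_dims)
params = layer.export_weights()  # ParameterDict with "weights" (and "biases" when present)

The same random_init/__call__/export_weights contract is what GroupQuantizedLinear and QLoRALinear implement, with extra exported entries (scales, zero_points, LoRA down/up weights).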
lalamo/modules/mlp.py
ADDED
@@ -0,0 +1,79 @@
from dataclasses import dataclass

import jax
from jaxtyping import Array, DTypeLike, Float, PRNGKeyArray

from lalamo.common import ParameterDict

from .activations import Activation
from .common import LalamoModule, WeightLayout
from .linear import LinearBase, LinearConfig

__all__ = ["MLP", "MLPConfig"]


@dataclass(frozen=True)
class MLPConfig:
    linear_config: LinearConfig
    activation: Activation

    def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> "MLP":
        up_key, down_key = jax.random.split(key)
        return MLP(
            self,
            up_projection=self.linear_config.random_init(
                model_dim,
                (hidden_dim, hidden_dim),
                has_biases=False,
                key=up_key,
            ),
            down_projection=self.linear_config.random_init(
                hidden_dim,
                (model_dim,),
                has_biases=False,
                key=down_key,
            ),
        )


class MLP(LalamoModule):
    up_projection: LinearBase
    down_projection: LinearBase

    @property
    def activation_precision(self) -> DTypeLike:
        return self.up_projection.activation_precision

    @property
    def model_dim(self) -> int:
        return self.up_projection.input_dim

    @property
    def hidden_dim(self) -> int:
        return self.down_projection.input_dim

    def __post_init__(self) -> None:
        up_output_dim, gate_output_dim = self.up_projection.output_dims
        if up_output_dim != gate_output_dim:
            raise ValueError(
                f"Up projection output dimension {up_output_dim} does not match"
                f" the gate output dimension {gate_output_dim}",
            )
        (down_output_dim,) = self.down_projection.output_dims
        if self.up_projection.input_dim != down_output_dim:
            raise ValueError(
                f"Down projection input dimension {down_output_dim} does not match"
                f" the up projection output dimension {self.up_projection.input_dim}",
            )

    def __call__(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
        up_proj, gate = self.up_projection(inputs)
        gate = self.config.activation(gate)
        (result,) = self.down_projection(up_proj * gate)
        return result

    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
        return ParameterDict(
            up_projection=self.up_projection.export_weights(weight_layout),
            down_projection=self.down_projection.export_weights(weight_layout),
        )
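A brief sketch of how this gated MLP might be wired up, again not taken from the diff: the Activation member name (Activation.SILU) is an assumption, since activations.py is not shown here, and the dimensions are illustrative.

# Hypothetical sketch: the fused up/gate projection feeds a gated activation,
# then the down projection maps back to model_dim.
import jax
from jax import numpy as jnp

from lalamo.modules.activations import Activation
from lalamo.modules.linear import FullPrecisionLinearConfig
from lalamo.modules.mlp import MLPConfig

config = MLPConfig(
    linear_config=FullPrecisionLinearConfig(precision=jnp.bfloat16),
    activation=Activation.SILU,  # assumed member name, not confirmed by this diff
)
mlp = config.random_init(model_dim=512, hidden_dim=2048, key=jax.random.PRNGKey(0))
out = mlp(jnp.ones(512, dtype=jnp.bfloat16))  # computes down(up(x) * act(gate(x)))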