keras-nightly 3.12.0.dev2025090203__py3-none-any.whl → 3.12.0.dev2025090403__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/src/backend/jax/layer.py +3 -1
- keras/src/models/model.py +2 -2
- keras/src/quantizers/gptq.py +284 -192
- keras/src/quantizers/gptq_config.py +3 -13
- keras/src/quantizers/gptq_core.py +211 -158
- keras/src/quantizers/quantizers.py +200 -0
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.12.0.dev2025090403.dist-info}/METADATA +1 -1
- {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.12.0.dev2025090403.dist-info}/RECORD +11 -12
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.12.0.dev2025090403.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.12.0.dev2025090403.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/layer.py
CHANGED
@@ -3,7 +3,9 @@ from keras.src.backend.config import is_nnx_enabled
 if is_nnx_enabled():
     from flax import nnx
 
-    BaseLayer
+    class BaseLayer(nnx.Module):
+        def __init_subclass__(cls, **kwargs):
+            super().__init_subclass__(pytree=False, **kwargs)
 else:
     BaseLayer = object
 
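
Note on the `keras/src/backend/jax/layer.py` change: rather than aliasing `BaseLayer` directly, the new code defines a small `nnx.Module` subclass whose `__init_subclass__` forwards `pytree=False` to the parent, so every layer class created on the JAX backend picks up that flag automatically. The snippet below is a minimal, nnx-free sketch of this `__init_subclass__` forwarding pattern; `PytreeAwareBase`, `_pytree`, and `MyLayer` are illustrative names only, not part of Keras or Flax.

```python
class PytreeAwareBase:
    # Stand-in for a framework base class that accepts a class-creation flag.
    def __init_subclass__(cls, pytree=True, **kwargs):
        cls._pytree = pytree
        super().__init_subclass__(**kwargs)


class BaseLayer(PytreeAwareBase):
    def __init_subclass__(cls, **kwargs):
        # Force the flag off for every subclass, mirroring the diff above.
        super().__init_subclass__(pytree=False, **kwargs)


class MyLayer(BaseLayer):
    pass


assert MyLayer._pytree is False  # injected without MyLayer passing it
```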
keras/src/models/model.py
CHANGED
@@ -9,6 +9,7 @@ from keras.src.api_export import keras_export
 from keras.src.layers.layer import Layer
 from keras.src.models.variable_mapping import map_saveable_variables
 from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.gptq_core import gptq_quantize
 from keras.src.saving import saving_api
 from keras.src.trainers import trainer as base_trainer
 from keras.src.utils import summary_utils
@@ -440,8 +441,7 @@ class Model(Trainer, base_trainer.Trainer, Layer):
                     "The `config` argument must be of type "
                     "`keras.quantizers.GPTQConfig`."
                 )
-
-            config.quantize(self)
+            gptq_quantize(self, config)
             return
 
         # For all other modes, verify that a config object was not passed.
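
Note on the `keras/src/models/model.py` change: `Model.quantize("gptq", ...)` no longer calls `config.quantize(self)`; it now hands the model and config to `gptq_quantize` from `keras.src.quantizers.gptq_core`, so `GPTQConfig` becomes a pure configuration object. The sketch below shows the unchanged caller-side pattern under that assumption; the dataset and tokenizer values are placeholders rather than a working calibration setup, which is why the final call is left commented out.

```python
import keras
from keras.quantizers import GPTQConfig

model = keras.Sequential(
    [keras.layers.Dense(16, activation="relu"), keras.layers.Dense(4)]
)
model.build((None, 32))

# Placeholder calibration inputs; a real run needs actual calibration text
# and a tokenizer, as described by GPTQConfig.
config = GPTQConfig(dataset=["calibration sample"], tokenizer=str.split)

# Internally this now dispatches to gptq_quantize(model, config):
# model.quantize("gptq", config=config)
```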
keras/src/quantizers/gptq.py
CHANGED
@@ -1,14 +1,235 @@
+import types
+
 from keras.src import ops
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
-from keras.src.
+from keras.src.ops import linalg
+from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.quantizers import GPTQQuantizer
+from keras.src.quantizers.quantizers import compute_quantization_parameters
+from keras.src.quantizers.quantizers import dequantize_with_zero_point
+from keras.src.quantizers.quantizers import quantize_with_zero_point
+
+
+def _stable_permutation(metric):
+    """Return a stable permutation that sorts `metric` in descending order.
+    Uses an index-based jitter to break ties deterministically."""
+    n = ops.shape(metric)[0]
+    idx = ops.arange(0, n, dtype="int32")
+
+    # tiny jitter = (idx / n) * 1e-12 so it never flips a real strict ordering
+    jitter = ops.divide(ops.cast(idx, "float32"), ops.cast(n, "float32"))
+    metric_jittered = ops.add(metric, ops.multiply(jitter, 1e-12))
+
+    # argsort by negative to get descending
+    return ops.argsort(ops.negative(metric_jittered))
+
+
+def gptq_quantize_matrix(
+    weights_transpose,
+    inv_hessian,
+    *,
+    blocksize=128,
+    group_size=-1,
+    activation_order=False,
+    order_metric=None,
+    compute_scale_zero=compute_quantization_parameters,
+):
+    """
+    Implements the GPTQ error correction updates.
+
+    For a single column update (column j):
+        e = invH[j, j] * (w_j - q_j)
+        W[:, j+1:] -= e * invH[j, j+1:]
+    where:
+        - w_j is the original column,
+        - q_j is the quantized column,
+        - invH is the inverse Hessian,
+        - e is the propagated error term.
+
+    Across entire blocks:
+        W[:, future] -= E_block * invH[block, future]
+    where:
+        - E_block is the quantization error accumulated for the current block,
+        - invH[block, future] denotes the cross-block slice of the inverse Hessian,
+        - W[:, future] are the columns yet to be quantized.
+
+    Args:
+        weights_transpose: Transposed weight matrix [out_features, in_features]
+            to quantize.
+        inv_hessian: Inverse Hessian matrix [in_features, in_features] for
+            error propagation.
+        blocksize: Size of the blocks to process (default: 128).
+        group_size: Size of the groups for parameter reuse
+            (default: -1, no grouping).
+        activation_order: Whether to apply activation-order permutation
+            (default: False).
+        order_metric: Metric for ordering features
+            (default: None, uses 1 / diag(invH)).
+        compute_scale_zero: Function to compute scale and zero for
+            quantization.
+    """
+    in_features = ops.shape(weights_transpose)[1]
+
+    # Optional activation-order permutation on feature axis (axis=1)
+    if activation_order:
+        if order_metric is None:
+            # Use 1 / diag(inverse_hessian) as importance proxy if H not
+            # available.
+            order_metric = ops.reciprocal(
+                ops.add(ops.diagonal(inv_hessian), 1e-12)
+            )
+        else:
+            # sanitize provided metric
+            order_metric = ops.cast(order_metric, "float32")
+            order_metric = ops.where(
+                ops.isfinite(order_metric),
+                order_metric,
+                ops.zeros_like(order_metric),
+            )
+
+        # Sort in descending order by importance
+        perm = _stable_permutation(order_metric)
+        inv_perm = ops.argsort(perm)
+
+        weights_transpose = ops.take(weights_transpose, perm, axis=1)
+        inv_hessian = ops.take(
+            ops.take(inv_hessian, perm, axis=0), perm, axis=1
+        )
+    else:
+        perm = inv_perm = None
+
+    # weights_buffer: [out_features, in_features]
+    weights_buffer = weights_transpose
+    # quantized_weights_buffer: [out_features, in_features]
+    quantized_weights_buffer = ops.zeros_like(weights_buffer)
+
+    # Process features in blocks
+    for block_start in range(0, in_features, blocksize):
+        block_end = min(block_start + blocksize, in_features)
+        block_size = block_end - block_start
+
+        # Block views
+        # block_weights: [out_features, bsize]
+        block_weights = weights_buffer[:, block_start:block_end]
+        # block_weights_quantized: [out_features, bsize]
+        block_weights_quantized = ops.zeros_like(block_weights)
+        # block_error: [out_features, bsize]
+        block_error = ops.zeros_like(block_weights)
+        # block_inv_hessian: [bsize, bsize]
+        block_inv_hessian = inv_hessian[
+            block_start:block_end, block_start:block_end
+        ]
+
+        # group cache for per-group s/z/maxq reuse
+        cached_scale = None
+        cached_zero = None
+        cached_maxq = None
+        cached_group_start = -1
+
+        for block_idx in range(block_size):
+            global_idx = block_start + block_idx
+            # weight_column: [out_features,]
+            weight_column = block_weights[:, block_idx]
+
+            # Group-wise parameter reuse (compute once per group)
+            if group_size != -1:
+                # Determine group boundaries
+                group_start = (global_idx // group_size) * group_size
+                if group_start != cached_group_start:
+                    group_end = min(group_start + group_size, in_features)
+                    # group_slice: [out_features, group_len]
+                    group_slice = weights_buffer[:, group_start:group_end]
+                    cached_scale, cached_zero, cached_maxq = compute_scale_zero(
+                        group_slice, weight=True
+                    )
+                    cached_group_start = group_start
+                scale, zero, maxq = cached_scale, cached_zero, cached_maxq
+            else:
+                # Per-column params
+                scale, zero, maxq = compute_scale_zero(
+                    ops.expand_dims(weight_column, 1), weight=True
+                )
+
+            # Quantize one column
+            # quantized_column: [out_features,]
+            quantized_column = quantize_with_zero_point(
+                ops.expand_dims(weight_column, 1), scale, zero, maxq
+            )
+            quantized_column = dequantize_with_zero_point(
+                quantized_column, scale, zero
+            )[:, 0]
+            block_weights_quantized = ops.slice_update(
+                block_weights_quantized,
+                (0, block_idx),
+                ops.expand_dims(quantized_column, 1),
+            )
+
+            # Error feedback for remaining columns within the block
+            # diag: [out_features,]
+            diag = block_inv_hessian[block_idx, block_idx]
+            # error = (col - quantized_col) / block_inv_hessian[idx, idx]
+            # error: [out_features,]
+            error = ops.divide(
+                ops.subtract(weight_column, quantized_column), diag
+            )
+            # block_error: [out_features, bsize]
+            block_error = ops.slice_update(
+                block_error, (0, block_idx), ops.expand_dims(error, 1)
+            )
+
+            if block_idx < block_size - 1:
+                update = ops.matmul(
+                    ops.expand_dims(error, 1),
+                    ops.expand_dims(
+                        block_inv_hessian[block_idx, block_idx + 1 :], 0
+                    ),
+                )
+                tail = block_weights[:, block_idx + 1 :]
+                block_weights = ops.slice_update(
+                    block_weights,
+                    (0, block_idx + 1),
+                    ops.subtract(tail, update),
+                )
+
+        # Write block's quantized columns into result
+        left = quantized_weights_buffer[:, :block_start]
+        right = quantized_weights_buffer[:, block_end:]
+        quantized_weights_buffer = ops.concatenate(
+            [left, block_weights_quantized, right], axis=1
+        )
+
+        # Propagate block errors to *future* features (beyond the block)
+        if block_end < in_features:
+            # weights_buffer[:, block_end:] -=
+            # block_error @ invH[block_start:block_end, block_end:]
+            # total_update: [out_features, bsize]
+            total_update = ops.matmul(
+                block_error, inv_hessian[block_start:block_end, block_end:]
+            )
+            weights_buffer = ops.concatenate(
+                [
+                    weights_buffer[:, :block_end],
+                    ops.subtract(weights_buffer[:, block_end:], total_update),
+                ],
+                axis=1,
+            )
+
+    # Undo permutation if used
+    if activation_order:
+        quantized_weights_buffer = ops.take(
+            quantized_weights_buffer, inv_perm, axis=1
+        )
+
+    return quantized_weights_buffer
 
 
 class GPTQ:
-    def __init__(self, layer):
+    def __init__(self, layer, config=GPTQConfig(tokenizer=None, dataset=None)):
         self.original_layer = layer
         self.num_samples = 0
-        self.
+        self.config = config
+        self.quantizer = GPTQQuantizer(config)
 
         # Explicitly handle each supported layer type
         if isinstance(layer, Dense) or (
@@ -16,9 +237,11 @@ class GPTQ:
         ):
             # For a standard Dense layer, the dimensions are straightforward.
             self.kernel_shape = layer.kernel.shape
-
-            self.
-
+            # rows: [input_features]
+            self.rows = self.kernel_shape[0]
+            # columns: [output_features]
+            self.columns = self.kernel_shape[1]
+            self.layer = layer
 
         # Handle 3D EinsumDense layers (typically from attention blocks).
         elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:
@@ -47,16 +270,10 @@ class GPTQ:
 
             # Create a temporary object that holds a reshaped
             # 2D version of the kernel.
-            self.layer =
-
-
-
-                    "kernel": ops.reshape(
-                        layer.kernel, (self.rows, self.columns)
-                    ),
-                    "bias": layer.bias,
-                },
-            )()
+            self.layer = types.SimpleNamespace(
+                kernel=ops.reshape(layer.kernel, (self.rows, self.columns)),
+                bias=layer.bias,
+            )
 
         else:
             # Raise an error if the layer is not supported.
@@ -86,53 +303,54 @@ class GPTQ:
         pre-initialized Hessian matrix `self.hessian`.
         """
         if input_batch is None:
-            raise ValueError("Input tensor
-
+            raise ValueError("Input tensor cannot be None.")
         if len(input_batch.shape) < 2:
             raise ValueError(
-
-                f"(
-                f"{len(input_batch.shape)}."
+                "Input tensor must have rank >= 2 "
+                f"(got rank {len(input_batch.shape)})."
             )
         if ops.size(input_batch) == 0:
-            raise ValueError("Input tensor
+            raise ValueError("Input tensor cannot be empty.")
 
         if len(input_batch.shape) > 2:
+            # [batch, features]
             input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))
-
+        x = ops.cast(input_batch, "float32")
 
-
+        num_new_samples = ops.shape(x)[0]
+        num_prev_samples = self.num_samples
+        total_samples = ops.add(num_prev_samples, num_new_samples)
+
+        if ops.shape(self.hessian)[0] != ops.shape(x)[-1]:
             raise ValueError(
-                f"Hessian dimensions ({self.hessian
-                "match input features ({
+                f"Hessian dimensions ({ops.shape(self.hessian)[0]}) do not "
+                f"match input features ({ops.shape(x)[-1]})."
             )
 
-
-
+        # gram_matrix: [features, features]
+        gram_matrix = ops.matmul(ops.transpose(x), x)
+        # Ensures numerical stability and symmetry in case of large floating
+        # point activations.
+        gram_matrix = ops.divide(
+            ops.add(gram_matrix, ops.transpose(gram_matrix)), 2.0
         )
 
-
-
-
-
-
-            current_hessian_weight = ops.divide(
-                input_batch.shape[0], total_samples
+        # Decay previous mean and add current per-sample contribution
+        # (factor 2/N)
+        if self.num_samples > 0:
+            self.hessian = ops.multiply(
+                self.hessian, ops.divide(num_prev_samples, total_samples)
             )
+        self.hessian = ops.add(
+            self.hessian,
+            ops.multiply(ops.divide(2.0, total_samples), gram_matrix),
+        )
 
-
-        old_term = ops.multiply(self.hessian, old_hessian_weight)
-        current_term = ops.multiply(current_hessian, current_hessian_weight)
-        self.hessian = ops.add(old_term, current_term)
-
-        self.num_samples = ops.add(self.num_samples, input_batch.shape[0])
+        self.num_samples = self.num_samples + ops.shape(x)[0] or 0
 
-    def
+    def quantize_and_correct_layer(
         self,
         blocksize=128,
-        hessian_damping=0.01,
-        group_size=-1,
-        activation_order=False,
     ):
         """
         Performs GPTQ quantization and correction on the layer's weights.
@@ -143,23 +361,23 @@ class GPTQ:
         quantization error by updating the remaining weights.
 
         The algorithm follows these main steps:
-        1.
+        1. Initialization: It optionally reorders the weight columns based
         on activation magnitudes (`activation_order=True`) to protect more
         salient
         weights.
-        2.
+        2. Hessian Modification: The Hessian matrix, pre-computed from
         calibration data, is dampened to ensure its invertibility and
         stability.
-        3.
+        3. Iterative Quantization: The function iterates through the
         weight columns in blocks (`blocksize`). In each iteration, it:
         a. Quantizes one column.
         b. Calculates the quantization error.
         c. Updates the remaining weights in the *current* block by
         distributing the error, using the inverse Hessian.
-        4.
+        4. Block-wise Correction: After a block is quantized, the total
         error from that block is propagated to the *next* block of weights
         to be processed.
-        5.
+        5. Finalization: The quantized weights are reordered back if
         `activation_order` was used, and the layer's weights are updated.
 
         This implementation is based on the official GPTQ paper and repository.
@@ -170,39 +388,16 @@ class GPTQ:
         Args:
             blocksize: (int, optional) The size of the weight block to process
                 at a time. Defaults to 128.
-            hessian_damping: (float, optional) The percentage of dampening to
-                add the
-                Hessian's diagonal. A value of 0.01 is recommended.
-                Defaults to 0.01.
-            group_size: (int, optional) The number of weights that share the
-                same quantization parameters (scale and zero-point).
-                A value of -1 indicates per-channel quantization.
-            activation_order: (bool, optional) If True, reorders weight columns
-                based
-                on their activation's second-order information.
         """
 
-        weights_matrix = ops.transpose(
-        hessian_matrix = ops.cast(self.hessian, "float32")
-
-        if activation_order:
-            permutation = ops.argsort(
-                ops.negative(ops.diagonal(hessian_matrix))
-            )
-            weights_matrix = ops.take(weights_matrix, permutation, axis=1)
-            hessian_matrix = ops.take(
-                ops.take(hessian_matrix, permutation, axis=0),
-                permutation,
-                axis=1,
-            )
-            inverse_permutation = ops.argsort(permutation)
+        weights_matrix = ops.transpose(self.layer.kernel)
 
         # Dampen the Hessian for Stability
-        hessian_diagonal = ops.diagonal(
+        hessian_diagonal = ops.diagonal(self.hessian)
         dead_diagonal = ops.equal(hessian_diagonal, 0.0)
         hessian_diagonal = ops.where(dead_diagonal, 1.0, hessian_diagonal)
         hessian_matrix = ops.add(
-
+            self.hessian,
             ops.diag(
                 ops.where(dead_diagonal, 1.0, ops.zeros_like(hessian_diagonal))
             ),
@@ -210,7 +405,7 @@ class GPTQ:
 
         # Add dampening factor to the Hessian diagonal
         damping_factor = ops.multiply(
-            hessian_damping, ops.mean(hessian_diagonal)
+            self.config.hessian_damping, ops.mean(hessian_diagonal)
         )
         hessian_diagonal = ops.add(hessian_diagonal, damping_factor)
         hessian_matrix = ops.add(
@@ -221,116 +416,17 @@ class GPTQ:
         )
 
         # Compute the inverse Hessian, which is used for error correction
-        inverse_hessian =
-
-
-
-
-
-
-
-
-
-
-            block_inverse_hessian = inverse_hessian[
-                block_start:block_end, block_start:block_end
-            ]
-
-            # Process one column at a time within the block
-            for col_idx in range(block_size):
-                weight_column = block_weights[:, col_idx]
-                diagonal_element = block_inverse_hessian[col_idx, col_idx]
-
-                if group_size != -1:
-                    if ops.mod(ops.add(block_start, col_idx), group_size) == 0:
-                        self.quantizer.find_params(
-                            weights_matrix[
-                                :,
-                                (ops.add(block_start, col_idx)) : (
-                                    ops.add(
-                                        ops.add(block_start, col_idx),
-                                        group_size,
-                                    )
-                                ),
-                            ],
-                            weight=True,
-                        )
-                    else:
-                        self.quantizer.find_params(
-                            ops.expand_dims(weight_column, 1), weight=True
-                        )
-
-                # Quantize the current weight column
-                quantized_column = dequantize(
-                    ops.expand_dims(weight_column, 1),
-                    self.quantizer.scale,
-                    self.quantizer.zero,
-                    self.quantizer.maxq,
-                )[:, 0]
-
-                block_quantized = ops.slice_update(
-                    block_quantized,
-                    (0, col_idx),
-                    ops.expand_dims(quantized_column, axis=1),
-                )
-                quantization_error = ops.divide(
-                    ops.subtract(weight_column, quantized_column),
-                    diagonal_element,
-                )
-                block_errors = ops.slice_update(
-                    block_errors,
-                    (0, col_idx),
-                    ops.expand_dims(quantization_error, axis=1),
-                )
-
-                if ops.less(col_idx, ops.subtract(block_size, 1)):
-                    error_update = ops.matmul(
-                        ops.expand_dims(quantization_error, 1),
-                        ops.expand_dims(
-                            block_inverse_hessian[
-                                col_idx, ops.add(col_idx, 1) :
-                            ],
-                            0,
-                        ),
-                    )
-
-                    # Efficiently update the remaining part of the
-                    # block_weights tensor.
-                    slice_to_update = block_weights[:, ops.add(col_idx, 1) :]
-                    updated_slice = ops.subtract(slice_to_update, error_update)
-                    block_weights = ops.slice_update(
-                        block_weights, (0, ops.add(col_idx, 1)), updated_slice
-                    )
-
-            # Update the full quantized matrix with the processed block
-            quantized_weights = ops.concatenate(
-                [
-                    quantized_weights[:, :block_start],
-                    block_quantized,
-                    quantized_weights[:, block_end:],
-                ],
-                axis=1,
-            )
-
-            if block_end < self.rows:
-                total_error_update = ops.matmul(
-                    block_errors,
-                    inverse_hessian[block_start:block_end, block_end:],
-                )
-                weights_matrix = ops.concatenate(
-                    [
-                        weights_matrix[:, :block_end],
-                        ops.subtract(
-                            weights_matrix[:, block_end:], total_error_update
-                        ),
-                    ],
-                    axis=1,
-                )
-
-            if activation_order:
-                quantized_weights = ops.take(
-                    quantized_weights, inverse_permutation, axis=1
-                )
+        inverse_hessian = linalg.inv(hessian_matrix)
+
+        quantized_weights = gptq_quantize_matrix(
+            weights_matrix,
+            inv_hessian=inverse_hessian,
+            blocksize=blocksize,
+            group_size=self.config.group_size,
+            activation_order=self.config.activation_order,
+            order_metric=ops.diagonal(hessian_matrix),
+            compute_scale_zero=self.quantizer.find_params,
+        )
 
         quantized_weights = ops.transpose(quantized_weights)
 
@@ -340,11 +436,7 @@ class GPTQ:
         )
 
         # Set the new quantized weights in the original layer
-
-        if self.original_layer.bias is not None:
-            new_weights.append(ops.convert_to_numpy(self.original_layer.bias))
-
-        self.original_layer.set_weights(new_weights)
+        self.original_layer._kernel.assign(quantized_weights)
 
     def free(self):
         self.hessian = None
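
Taken together, the `gptq.py` changes move the column-by-column GPTQ loop into the standalone `gptq_quantize_matrix` helper and switch the Hessian accumulation to a running mean of `(2 / N) * X^T X`. The two sketches below restate that arithmetic in plain NumPy; they are illustrative reconstructions under stated assumptions, not the Keras implementation, and all names (`RunningHessian`, `gptq_column_sweep`, `quantize_column`) are invented for the example.

```python
import numpy as np


class RunningHessian:
    """Running estimate of H = (2 / N) * sum_i x_i x_i^T over calibration batches."""

    def __init__(self, num_features):
        self.hessian = np.zeros((num_features, num_features))
        self.num_samples = 0

    def update(self, x):
        # x: [batch, num_features] calibration activations.
        x = x.astype(np.float64)
        total = self.num_samples + x.shape[0]
        gram = x.T @ x
        gram = (gram + gram.T) / 2.0  # enforce symmetry
        # Decay the previous mean, then add the new per-sample contribution.
        self.hessian = self.hessian * (self.num_samples / total)
        self.hessian = self.hessian + (2.0 / total) * gram
        self.num_samples = total
```

The error-feedback sweep that `gptq_quantize_matrix` performs, shown here without blocking, grouping, or activation ordering:

```python
import numpy as np


def quantize_column(w, scale, zero, maxq):
    """Uniform affine quantize/dequantize of one weight column."""
    q = np.clip(np.round(w / scale) + zero, 0, maxq)
    return scale * (q - zero)


def gptq_column_sweep(weights, inv_hessian, bits=4):
    """weights: [out_features, in_features]; inv_hessian: [in_features, in_features]."""
    W = weights.astype(np.float64).copy()
    Q = np.zeros_like(W)
    maxq = 2**bits - 1
    for j in range(W.shape[1]):
        w = W[:, j]
        # Per-column asymmetric parameters (stand-in for the quantizer's
        # compute_quantization_parameters / find_params step).
        lo, hi = min(w.min(), 0.0), max(w.max(), 0.0)
        scale = max(hi - lo, 1e-8) / maxq
        zero = np.round(-lo / scale)
        q = quantize_column(w, scale, zero, maxq)
        Q[:, j] = q
        # Error scaled by the diagonal of the inverse Hessian ...
        err = (w - q) / inv_hessian[j, j]
        # ... and propagated to every not-yet-quantized column.
        W[:, j + 1 :] -= np.outer(err, inv_hessian[j, j + 1 :])
    return Q
```

With `inv_hessian = np.linalg.inv(H + damping * np.eye(H.shape[0]))`, this mirrors the damped-inverse call path in `quantize_and_correct_layer`, while the blocked version in `gptq_quantize_matrix` additionally batches the cross-block update `W[:, future] -= E_block @ invH[block, future]`.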
keras/src/quantizers/gptq_config.py
CHANGED
@@ -1,7 +1,4 @@
-from absl import logging
-
 from keras.src.api_export import keras_export
-from keras.src.quantizers.gptq_core import quantize_model
 
 
 @keras_export("keras.quantizers.GPTQConfig")
@@ -140,8 +137,10 @@ class GPTQConfig:
         self,
         dataset,
         tokenizer,
+        *,
         weight_bits: int = 4,
         num_samples: int = 128,
+        per_channel: bool = True,
         sequence_length: int = 512,
         hessian_damping: float = 0.01,
         group_size: int = 128,
@@ -151,19 +150,10 @@ class GPTQConfig:
         self.dataset = dataset
         self.tokenizer = tokenizer
         self.num_samples = num_samples
+        self.per_channel = per_channel
         self.sequence_length = sequence_length
        self.hessian_damping = hessian_damping
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.symmetric = symmetric
         self.activation_order = activation_order
-
-    def quantize(self, model):
-        """
-        Applies GPTQ quantization to the provided model using this
-        configuration.
-        """
-        logging.info("Initiating quantization from GPTQConfig...")
-        # The core logic is now delegated to gptqutils, which will handle
-        # the dynamic imports and data loading.
-        quantize_model(model=model, config=self)