keras-nightly 3.12.0.dev2025090303__py3-none-any.whl → 3.12.0.dev2025090403__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
keras/src/models/model.py CHANGED
@@ -9,6 +9,7 @@ from keras.src.api_export import keras_export
 from keras.src.layers.layer import Layer
 from keras.src.models.variable_mapping import map_saveable_variables
 from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.gptq_core import gptq_quantize
 from keras.src.saving import saving_api
 from keras.src.trainers import trainer as base_trainer
 from keras.src.utils import summary_utils
@@ -440,8 +441,7 @@ class Model(Trainer, base_trainer.Trainer, Layer):
                     "The `config` argument must be of type "
                     "`keras.quantizers.GPTQConfig`."
                 )
-            # The config object's own quantize method drives the process
-            config.quantize(self)
+            gptq_quantize(self, config)
             return
 
         # For all other modes, verify that a config object was not passed.
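The net effect in `Model.quantize`: the GPTQ path no longer round-trips through the config object's own method, it calls the functional entry point `gptq_quantize(model, config)` directly. A minimal usage sketch (illustrative only; the `model`, `tokenizer`, and calibration list are placeholders, and the `"gptq"` mode string is an assumption based on the config type check shown above):

    from keras.quantizers import GPTQConfig

    config = GPTQConfig(
        dataset=["calibration text one", "calibration text two"],  # placeholder data
        tokenizer=tokenizer,  # placeholder tokenizer object
    )
    # Model.quantize validates the config type, then hands off to
    # keras.src.quantizers.gptq_core.gptq_quantize(model, config).
    model.quantize("gptq", config=config)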
keras/src/quantizers/gptq.py CHANGED
@@ -1,14 +1,235 @@
+import types
+
 from keras.src import ops
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
-from keras.src.quantizers.gptq_quant import dequantize
+from keras.src.ops import linalg
+from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.quantizers import GPTQQuantizer
+from keras.src.quantizers.quantizers import compute_quantization_parameters
+from keras.src.quantizers.quantizers import dequantize_with_zero_point
+from keras.src.quantizers.quantizers import quantize_with_zero_point
+
+
+def _stable_permutation(metric):
+    """Return a stable permutation that sorts `metric` in descending order.
+    Uses an index-based jitter to break ties deterministically."""
+    n = ops.shape(metric)[0]
+    idx = ops.arange(0, n, dtype="int32")
+
+    # Tiny jitter = (idx / n) * 1e-12, so it never flips a real strict ordering
+    jitter = ops.divide(ops.cast(idx, "float32"), ops.cast(n, "float32"))
+    metric_jittered = ops.add(metric, ops.multiply(jitter, 1e-12))
+
+    # Argsort of the negated metric gives a descending order
+    return ops.argsort(ops.negative(metric_jittered))
+
+
+def gptq_quantize_matrix(
+    weights_transpose,
+    inv_hessian,
+    *,
+    blocksize=128,
+    group_size=-1,
+    activation_order=False,
+    order_metric=None,
+    compute_scale_zero=compute_quantization_parameters,
+):
+    """Implements the GPTQ error-correction updates.
+
+    For a single column update (column j):
+        e = (w_j - q_j) / invH[j, j]
+        W[:, j+1:] -= e * invH[j, j+1:]
+    where:
+        - w_j is the original column,
+        - q_j is the quantized column,
+        - invH is the inverse Hessian,
+        - e is the propagated error term.
+
+    Across entire blocks:
+        W[:, future] -= E_block @ invH[block, future]
+    where:
+        - E_block is the quantization error accumulated for the current
+          block,
+        - invH[block, future] denotes the cross-block slice of the inverse
+          Hessian,
+        - W[:, future] are the columns yet to be quantized.
+
+    Args:
+        weights_transpose: Transposed weight matrix
+            [out_features, in_features] to quantize.
+        inv_hessian: Inverse Hessian matrix [in_features, in_features] used
+            for error propagation.
+        blocksize: Size of the blocks to process (default: 128).
+        group_size: Size of the groups for quantization-parameter reuse
+            (default: -1, no grouping).
+        activation_order: Whether to apply activation-order permutation
+            (default: False).
+        order_metric: Metric for ordering features (default: None, which
+            uses 1 / diag(invH)).
+        compute_scale_zero: Function that computes the quantization scale
+            and zero point.
+    """
+    in_features = ops.shape(weights_transpose)[1]
+
+    # Optional activation-order permutation on the feature axis (axis=1)
+    if activation_order:
+        if order_metric is None:
+            # Use 1 / diag(inverse_hessian) as an importance proxy if H is
+            # not available.
+            order_metric = ops.reciprocal(
+                ops.add(ops.diagonal(inv_hessian), 1e-12)
+            )
+        else:
+            # Sanitize the provided metric
+            order_metric = ops.cast(order_metric, "float32")
+            order_metric = ops.where(
+                ops.isfinite(order_metric),
+                order_metric,
+                ops.zeros_like(order_metric),
+            )
+
+        # Sort in descending order of importance
+        perm = _stable_permutation(order_metric)
+        inv_perm = ops.argsort(perm)
+
+        weights_transpose = ops.take(weights_transpose, perm, axis=1)
+        inv_hessian = ops.take(
+            ops.take(inv_hessian, perm, axis=0), perm, axis=1
+        )
+    else:
+        perm = inv_perm = None
+
+    # weights_buffer: [out_features, in_features]
+    weights_buffer = weights_transpose
+    # quantized_weights_buffer: [out_features, in_features]
+    quantized_weights_buffer = ops.zeros_like(weights_buffer)
+
+    # Process features in blocks
+    for block_start in range(0, in_features, blocksize):
+        block_end = min(block_start + blocksize, in_features)
+        block_size = block_end - block_start
+
+        # Block views
+        # block_weights: [out_features, bsize]
+        block_weights = weights_buffer[:, block_start:block_end]
+        # block_weights_quantized: [out_features, bsize]
+        block_weights_quantized = ops.zeros_like(block_weights)
+        # block_error: [out_features, bsize]
+        block_error = ops.zeros_like(block_weights)
+        # block_inv_hessian: [bsize, bsize]
+        block_inv_hessian = inv_hessian[
+            block_start:block_end, block_start:block_end
+        ]
+
+        # Group cache for per-group scale/zero/maxq reuse
+        cached_scale = None
+        cached_zero = None
+        cached_maxq = None
+        cached_group_start = -1
+
+        for block_idx in range(block_size):
+            global_idx = block_start + block_idx
+            # weight_column: [out_features,]
+            weight_column = block_weights[:, block_idx]
+
+            # Group-wise parameter reuse (compute once per group)
+            if group_size != -1:
+                # Determine the group boundaries
+                group_start = (global_idx // group_size) * group_size
+                if group_start != cached_group_start:
+                    group_end = min(group_start + group_size, in_features)
+                    # group_slice: [out_features, group_len]
+                    group_slice = weights_buffer[:, group_start:group_end]
+                    cached_scale, cached_zero, cached_maxq = compute_scale_zero(
+                        group_slice, weight=True
+                    )
+                    cached_group_start = group_start
+                scale, zero, maxq = cached_scale, cached_zero, cached_maxq
+            else:
+                # Per-column parameters
+                scale, zero, maxq = compute_scale_zero(
+                    ops.expand_dims(weight_column, 1), weight=True
+                )
+
+            # Quantize one column
+            # quantized_column: [out_features,]
+            quantized_column = quantize_with_zero_point(
+                ops.expand_dims(weight_column, 1), scale, zero, maxq
+            )
+            quantized_column = dequantize_with_zero_point(
+                quantized_column, scale, zero
+            )[:, 0]
+            block_weights_quantized = ops.slice_update(
+                block_weights_quantized,
+                (0, block_idx),
+                ops.expand_dims(quantized_column, 1),
+            )
+
+            # Error feedback for the remaining columns within the block
+            # diag: scalar entry block_inv_hessian[block_idx, block_idx]
+            diag = block_inv_hessian[block_idx, block_idx]
+            # error = (col - quantized_col) / block_inv_hessian[idx, idx]
+            # error: [out_features,]
+            error = ops.divide(
+                ops.subtract(weight_column, quantized_column), diag
+            )
+            # block_error: [out_features, bsize]
+            block_error = ops.slice_update(
+                block_error, (0, block_idx), ops.expand_dims(error, 1)
+            )
+
+            if block_idx < block_size - 1:
+                update = ops.matmul(
+                    ops.expand_dims(error, 1),
+                    ops.expand_dims(
+                        block_inv_hessian[block_idx, block_idx + 1 :], 0
+                    ),
+                )
+                tail = block_weights[:, block_idx + 1 :]
+                block_weights = ops.slice_update(
+                    block_weights,
+                    (0, block_idx + 1),
+                    ops.subtract(tail, update),
+                )
+
+        # Write the block's quantized columns into the result
+        left = quantized_weights_buffer[:, :block_start]
+        right = quantized_weights_buffer[:, block_end:]
+        quantized_weights_buffer = ops.concatenate(
+            [left, block_weights_quantized, right], axis=1
+        )
+
+        # Propagate block errors to *future* features (beyond the block)
+        if block_end < in_features:
+            # weights_buffer[:, block_end:] -=
+            #     block_error @ invH[block_start:block_end, block_end:]
+            # total_update: [out_features, in_features - block_end]
+            total_update = ops.matmul(
+                block_error, inv_hessian[block_start:block_end, block_end:]
+            )
+            weights_buffer = ops.concatenate(
+                [
+                    weights_buffer[:, :block_end],
+                    ops.subtract(weights_buffer[:, block_end:], total_update),
+                ],
+                axis=1,
+            )
+
+    # Undo the permutation if one was applied
+    if activation_order:
+        quantized_weights_buffer = ops.take(
+            quantized_weights_buffer, inv_perm, axis=1
+        )
+
+    return quantized_weights_buffer
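The single-column rule from the docstring above is easy to check in isolation. A toy NumPy sketch (standalone, not part of the diff; `np.round` stands in for the real quantize/dequantize round trip):

    import numpy as np

    rng = np.random.default_rng(0)
    W = rng.normal(size=(4, 6))    # [out_features, in_features] (kernel transposed)
    X = rng.normal(size=(32, 6))   # calibration activations
    H = 2.0 / 32 * X.T @ X + 0.01 * np.eye(6)  # dampened Hessian proxy
    invH = np.linalg.inv(H)

    j = 0
    q_j = np.round(W[:, j])                       # toy one-step quantizer
    e = (W[:, j] - q_j) / invH[j, j]              # propagated error term
    W[:, j + 1:] -= np.outer(e, invH[j, j + 1:])  # correct the remaining columns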
 
 
 class GPTQ:
-    def __init__(self, layer):
+    def __init__(self, layer, config=GPTQConfig(tokenizer=None, dataset=None)):
         self.original_layer = layer
         self.num_samples = 0
-        self.quantizer = None
+        self.config = config
+        self.quantizer = GPTQQuantizer(config)
 
         # Explicitly handle each supported layer type
         if isinstance(layer, Dense) or (
@@ -16,9 +237,11 @@ class GPTQ:
         ):
             # For a standard Dense layer, the dimensions are straightforward.
             self.kernel_shape = layer.kernel.shape
-            self.rows = self.kernel_shape[0]  # Input features
-            self.columns = self.kernel_shape[1]  # Output features
-            self.layer = layer  # The layer itself can be used directly.
+            # rows: number of input features
+            self.rows = self.kernel_shape[0]
+            # columns: number of output features
+            self.columns = self.kernel_shape[1]
+            self.layer = layer
 
         # Handle 3D EinsumDense layers (typically from attention blocks).
         elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:
@@ -47,16 +270,10 @@ class GPTQ:
 
             # Create a temporary object that holds a reshaped
             # 2D version of the kernel.
-            self.layer = type(
-                "temp",
-                (object,),
-                {
-                    "kernel": ops.reshape(
-                        layer.kernel, (self.rows, self.columns)
-                    ),
-                    "bias": layer.bias,
-                },
-            )()
+            self.layer = types.SimpleNamespace(
+                kernel=ops.reshape(layer.kernel, (self.rows, self.columns)),
+                bias=layer.bias,
+            )
 
         else:
             # Raise an error if the layer is not supported.
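The `types.SimpleNamespace` swap above is behavior-preserving: the old code built a one-off class at runtime just to expose `kernel` and `bias` attributes. A quick standalone illustration (not from the diff):

    import types

    # Both forms expose the same attributes; SimpleNamespace is the idiomatic
    # stdlib way to build a throwaway attribute bag.
    old_style = type("temp", (object,), {"kernel": [1.0], "bias": None})()
    new_style = types.SimpleNamespace(kernel=[1.0], bias=None)
    assert old_style.kernel == new_style.kernel
    assert old_style.bias is new_style.bias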
@@ -86,53 +303,54 @@ class GPTQ:
         pre-initialized Hessian matrix `self.hessian`.
         """
         if input_batch is None:
-            raise ValueError("Input tensor 'input_batch' cannot be None.")
-
+            raise ValueError("Input tensor cannot be None.")
         if len(input_batch.shape) < 2:
             raise ValueError(
-                f"Input tensor 'input_batch' must have a rank of at least 2 "
-                f"(e.g., [batch, features]), but got rank "
-                f"{len(input_batch.shape)}."
+                "Input tensor must have rank >= 2 "
+                f"(got rank {len(input_batch.shape)})."
             )
         if ops.size(input_batch) == 0:
-            raise ValueError("Input tensor 'input_batch' cannot be empty.")
+            raise ValueError("Input tensor cannot be empty.")
 
         if len(input_batch.shape) > 2:
+            # Flatten leading dimensions to [batch, features]
            input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))
-        input_batch = ops.cast(input_batch, "float32")
+        x = ops.cast(input_batch, "float32")
 
-        if self.hessian.shape[0] != input_batch.shape[-1]:
+        num_new_samples = ops.shape(x)[0]
+        num_prev_samples = self.num_samples
+        total_samples = ops.add(num_prev_samples, num_new_samples)
+
+        if ops.shape(self.hessian)[0] != ops.shape(x)[-1]:
             raise ValueError(
-                f"Hessian dimensions ({self.hessian.shape[0]}) do not"
-                "match input features ({input_batch.shape[-1]})."
+                f"Hessian dimensions ({ops.shape(self.hessian)[0]}) do not "
+                f"match input features ({ops.shape(x)[-1]})."
             )
 
-        current_hessian = ops.multiply(
-            2, ops.matmul(ops.transpose(input_batch), input_batch)
+        # gram_matrix: [features, features]
+        gram_matrix = ops.matmul(ops.transpose(x), x)
+        # Symmetrize to ensure numerical stability in case of large
+        # floating-point activations.
+        gram_matrix = ops.divide(
+            ops.add(gram_matrix, ops.transpose(gram_matrix)), 2.0
        )
 
-        if self.num_samples == 0:
-            self.hessian = current_hessian
-        else:
-            total_samples = ops.add(self.num_samples, input_batch.shape[0])
-            old_hessian_weight = ops.divide(self.num_samples, total_samples)
-            current_hessian_weight = ops.divide(
-                input_batch.shape[0], total_samples
+        # Decay the previous mean and add the current per-sample
+        # contribution (factor 2/N)
+        if self.num_samples > 0:
+            self.hessian = ops.multiply(
+                self.hessian, ops.divide(num_prev_samples, total_samples)
             )
+        self.hessian = ops.add(
+            self.hessian,
+            ops.multiply(ops.divide(2.0, total_samples), gram_matrix),
+        )
 
-        # Update the accumulated Hessian
-        old_term = ops.multiply(self.hessian, old_hessian_weight)
-        current_term = ops.multiply(current_hessian, current_hessian_weight)
-        self.hessian = ops.add(old_term, current_term)
-
-        self.num_samples = ops.add(self.num_samples, input_batch.shape[0])
+        self.num_samples = total_samples
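The rewritten accumulation keeps the invariant that after N total calibration rows, `self.hessian` equals `(2 / N) * X^T X` for the concatenated input matrix X. A standalone NumPy check of that invariant (illustrative sketch, not package code):

    import numpy as np

    rng = np.random.default_rng(0)
    batches = [rng.normal(size=(8, 5)) for _ in range(3)]

    H, n = np.zeros((5, 5)), 0
    for x in batches:
        g = x.T @ x
        g = (g + g.T) / 2.0                      # symmetrize, as in add_batch
        total = n + x.shape[0]
        H = H * (n / total) + (2.0 / total) * g  # decay old mean, add new term
        n = total

    X = np.vstack(batches)
    assert np.allclose(H, 2.0 / X.shape[0] * X.T @ X)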
 
-    def quantize_and_correct_block(
+    def quantize_and_correct_layer(
         self,
         blocksize=128,
-        hessian_damping=0.01,
-        group_size=-1,
-        activation_order=False,
     ):
         """
         Performs GPTQ quantization and correction on the layer's weights.
@@ -143,23 +361,23 @@ class GPTQ:
         quantization error by updating the remaining weights.
 
         The algorithm follows these main steps:
-        1. **Initialization**: It optionally reorders the weight columns based
+        1. Initialization: It optionally reorders the weight columns based
            on activation magnitudes (`activation_order=True`) to protect more
            salient weights.
-        2. **Hessian Modification**: The Hessian matrix, pre-computed from
+        2. Hessian Modification: The Hessian matrix, pre-computed from
            calibration data, is dampened to ensure its invertibility and
            stability.
-        3. **Iterative Quantization**: The function iterates through the
+        3. Iterative Quantization: The function iterates through the
            weight columns in blocks (`blocksize`). In each iteration, it:
            a. Quantizes one column.
            b. Calculates the quantization error.
            c. Updates the remaining weights in the *current* block by
              distributing the error, using the inverse Hessian.
-        4. **Block-wise Correction**: After a block is quantized, the total
+        4. Block-wise Correction: After a block is quantized, the total
            error from that block is propagated to the *next* block of weights
            to be processed.
-        5. **Finalization**: The quantized weights are reordered back if
+        5. Finalization: The quantized weights are reordered back if
            `activation_order` was used, and the layer's weights are updated.
 
         This implementation is based on the official GPTQ paper and repository.
@@ -170,39 +388,16 @@ class GPTQ:
         Args:
             blocksize: (int, optional) The size of the weight block to process
                 at a time. Defaults to 128.
-            hessian_damping: (float, optional) The percentage of dampening to
-                add to the Hessian's diagonal. A value of 0.01 is recommended.
-                Defaults to 0.01.
-            group_size: (int, optional) The number of weights that share the
-                same quantization parameters (scale and zero-point).
-                A value of -1 indicates per-channel quantization.
-            activation_order: (bool, optional) If True, reorders weight
-                columns based on their activation's second-order information.
         """
 
-        weights_matrix = ops.transpose(ops.cast(self.layer.kernel, "float32"))
-        hessian_matrix = ops.cast(self.hessian, "float32")
-
-        if activation_order:
-            permutation = ops.argsort(
-                ops.negative(ops.diagonal(hessian_matrix))
-            )
-            weights_matrix = ops.take(weights_matrix, permutation, axis=1)
-            hessian_matrix = ops.take(
-                ops.take(hessian_matrix, permutation, axis=0),
-                permutation,
-                axis=1,
-            )
-            inverse_permutation = ops.argsort(permutation)
+        weights_matrix = ops.transpose(self.layer.kernel)
 
         # Dampen the Hessian for stability
-        hessian_diagonal = ops.diagonal(hessian_matrix)
+        hessian_diagonal = ops.diagonal(self.hessian)
         dead_diagonal = ops.equal(hessian_diagonal, 0.0)
         hessian_diagonal = ops.where(dead_diagonal, 1.0, hessian_diagonal)
         hessian_matrix = ops.add(
-            hessian_matrix,
+            self.hessian,
             ops.diag(
                 ops.where(dead_diagonal, 1.0, ops.zeros_like(hessian_diagonal))
             ),
@@ -210,7 +405,7 @@ class GPTQ:
 
         # Add dampening factor to the Hessian diagonal
         damping_factor = ops.multiply(
-            hessian_damping, ops.mean(hessian_diagonal)
+            self.config.hessian_damping, ops.mean(hessian_diagonal)
         )
         hessian_diagonal = ops.add(hessian_diagonal, damping_factor)
         hessian_matrix = ops.add(
@@ -221,116 +416,17 @@ class GPTQ:
         )
 
         # Compute the inverse Hessian, which is used for error correction
-        inverse_hessian = ops.linalg.inv(hessian_matrix)
-        quantized_weights = ops.zeros_like(weights_matrix)
-
-        for block_start in range(0, self.rows, blocksize):
-            block_end = min(ops.add(block_start, blocksize), self.rows)
-            block_size = ops.subtract(block_end, block_start)
-            # Extract the current block of weights and its corresponding
-            # Hessian
-            block_weights = weights_matrix[:, block_start:block_end]
-            block_quantized = ops.zeros_like(block_weights)
-            block_errors = ops.zeros_like(block_weights)
-            block_inverse_hessian = inverse_hessian[
-                block_start:block_end, block_start:block_end
-            ]
-
-            # Process one column at a time within the block
-            for col_idx in range(block_size):
-                weight_column = block_weights[:, col_idx]
-                diagonal_element = block_inverse_hessian[col_idx, col_idx]
-
-                if group_size != -1:
-                    if ops.mod(ops.add(block_start, col_idx), group_size) == 0:
-                        self.quantizer.find_params(
-                            weights_matrix[
-                                :,
-                                (ops.add(block_start, col_idx)) : (
-                                    ops.add(
-                                        ops.add(block_start, col_idx),
-                                        group_size,
-                                    )
-                                ),
-                            ],
-                            weight=True,
-                        )
-                else:
-                    self.quantizer.find_params(
-                        ops.expand_dims(weight_column, 1), weight=True
-                    )
-
-                # Quantize the current weight column
-                quantized_column = dequantize(
-                    ops.expand_dims(weight_column, 1),
-                    self.quantizer.scale,
-                    self.quantizer.zero,
-                    self.quantizer.maxq,
-                )[:, 0]
-
-                block_quantized = ops.slice_update(
-                    block_quantized,
-                    (0, col_idx),
-                    ops.expand_dims(quantized_column, axis=1),
-                )
-                quantization_error = ops.divide(
-                    ops.subtract(weight_column, quantized_column),
-                    diagonal_element,
-                )
-                block_errors = ops.slice_update(
-                    block_errors,
-                    (0, col_idx),
-                    ops.expand_dims(quantization_error, axis=1),
-                )
-
-                if ops.less(col_idx, ops.subtract(block_size, 1)):
-                    error_update = ops.matmul(
-                        ops.expand_dims(quantization_error, 1),
-                        ops.expand_dims(
-                            block_inverse_hessian[
-                                col_idx, ops.add(col_idx, 1) :
-                            ],
-                            0,
-                        ),
-                    )
-
-                    # Efficiently update the remaining part of the
-                    # block_weights tensor.
-                    slice_to_update = block_weights[:, ops.add(col_idx, 1) :]
-                    updated_slice = ops.subtract(slice_to_update, error_update)
-                    block_weights = ops.slice_update(
-                        block_weights, (0, ops.add(col_idx, 1)), updated_slice
-                    )
-
-            # Update the full quantized matrix with the processed block
-            quantized_weights = ops.concatenate(
-                [
-                    quantized_weights[:, :block_start],
-                    block_quantized,
-                    quantized_weights[:, block_end:],
-                ],
-                axis=1,
-            )
-
-            if block_end < self.rows:
-                total_error_update = ops.matmul(
-                    block_errors,
-                    inverse_hessian[block_start:block_end, block_end:],
-                )
-                weights_matrix = ops.concatenate(
-                    [
-                        weights_matrix[:, :block_end],
-                        ops.subtract(
-                            weights_matrix[:, block_end:], total_error_update
-                        ),
-                    ],
-                    axis=1,
-                )
-
-        if activation_order:
-            quantized_weights = ops.take(
-                quantized_weights, inverse_permutation, axis=1
-            )
+        inverse_hessian = linalg.inv(hessian_matrix)
+
+        quantized_weights = gptq_quantize_matrix(
+            weights_matrix,
+            inv_hessian=inverse_hessian,
+            blocksize=blocksize,
+            group_size=self.config.group_size,
+            activation_order=self.config.activation_order,
+            order_metric=ops.diagonal(hessian_matrix),
+            compute_scale_zero=self.quantizer.find_params,
+        )
 
         quantized_weights = ops.transpose(quantized_weights)
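Taken together, the dead-column handling, proportional dampening, and inversion reduce to a short pipeline. An illustrative NumPy rendering (a sketch under the assumption, consistent with the hunks above, that the damping term is added uniformly to the diagonal; it is not the package's code):

    import numpy as np

    def prepare_inverse_hessian(hessian, hessian_damping=0.01):
        diag = np.diagonal(hessian)
        dead = diag == 0.0                       # features that never activated
        hessian = hessian + np.diag(np.where(dead, 1.0, 0.0))
        diag = np.where(dead, 1.0, diag)
        damping = hessian_damping * diag.mean()  # proportional dampening factor
        hessian = hessian + damping * np.eye(hessian.shape[0])
        return np.linalg.inv(hessian)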
 
@@ -340,11 +436,7 @@ class GPTQ:
         )
 
         # Set the new quantized weights in the original layer
-        new_weights = [ops.convert_to_numpy(quantized_weights)]
-        if self.original_layer.bias is not None:
-            new_weights.append(ops.convert_to_numpy(self.original_layer.bias))
-
-        self.original_layer.set_weights(new_weights)
+        self.original_layer._kernel.assign(quantized_weights)
 
     def free(self):
         self.hessian = None
keras/src/quantizers/gptq_config.py CHANGED
@@ -1,7 +1,4 @@
-from absl import logging
-
 from keras.src.api_export import keras_export
-from keras.src.quantizers.gptq_core import quantize_model
 
 
 @keras_export("keras.quantizers.GPTQConfig")
@@ -140,8 +137,10 @@ class GPTQConfig:
         self,
         dataset,
         tokenizer,
+        *,
         weight_bits: int = 4,
         num_samples: int = 128,
+        per_channel: bool = True,
         sequence_length: int = 512,
         hessian_damping: float = 0.01,
         group_size: int = 128,
@@ -151,19 +150,10 @@ class GPTQConfig:
         self.dataset = dataset
         self.tokenizer = tokenizer
         self.num_samples = num_samples
+        self.per_channel = per_channel
         self.sequence_length = sequence_length
         self.hessian_damping = hessian_damping
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.symmetric = symmetric
         self.activation_order = activation_order
-
-    def quantize(self, model):
-        """
-        Applies GPTQ quantization to the provided model using this
-        configuration.
-        """
-        logging.info("Initiating quantization from GPTQConfig...")
-        # The core logic is now delegated to gptqutils, which will handle
-        # the dynamic imports and data loading.
-        quantize_model(model=model, config=self)
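With the `quantize` method removed, `GPTQConfig` is now a plain data holder, and the new bare `*` makes every argument after `tokenizer` keyword-only. A construction sketch (illustrative; the dataset and tokenizer values are placeholders):

    from keras.quantizers import GPTQConfig

    config = GPTQConfig(
        ["some calibration text", "more text"],  # placeholder dataset
        tokenizer=None,                          # placeholder tokenizer
        weight_bits=4,
        per_channel=True,  # new in this release
        group_size=128,
        activation_order=False,
    )
    # Positional use of the later arguments is now an error, e.g.
    # GPTQConfig(data, tok, 4) raises TypeError.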