keras-nightly 3.12.0.dev2025082103__py3-none-any.whl → 3.12.0.dev2025082303__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. keras/_tf_keras/keras/ops/__init__.py +1 -0
  2. keras/_tf_keras/keras/ops/numpy/__init__.py +1 -0
  3. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  4. keras/ops/__init__.py +1 -0
  5. keras/ops/numpy/__init__.py +1 -0
  6. keras/quantizers/__init__.py +1 -0
  7. keras/src/applications/convnext.py +20 -20
  8. keras/src/applications/densenet.py +21 -21
  9. keras/src/applications/efficientnet.py +16 -16
  10. keras/src/applications/efficientnet_v2.py +28 -28
  11. keras/src/applications/inception_resnet_v2.py +7 -7
  12. keras/src/applications/inception_v3.py +5 -5
  13. keras/src/applications/mobilenet_v2.py +13 -20
  14. keras/src/applications/mobilenet_v3.py +15 -15
  15. keras/src/applications/nasnet.py +7 -8
  16. keras/src/applications/resnet.py +32 -32
  17. keras/src/applications/xception.py +10 -10
  18. keras/src/backend/common/dtypes.py +8 -3
  19. keras/src/backend/common/variables.py +3 -1
  20. keras/src/backend/jax/export.py +1 -1
  21. keras/src/backend/jax/numpy.py +6 -0
  22. keras/src/backend/jax/trainer.py +1 -1
  23. keras/src/backend/numpy/numpy.py +28 -0
  24. keras/src/backend/openvino/numpy.py +5 -1
  25. keras/src/backend/tensorflow/numpy.py +22 -0
  26. keras/src/backend/tensorflow/trainer.py +19 -1
  27. keras/src/backend/torch/core.py +6 -9
  28. keras/src/backend/torch/nn.py +1 -2
  29. keras/src/backend/torch/numpy.py +16 -0
  30. keras/src/backend/torch/trainer.py +1 -1
  31. keras/src/callbacks/backup_and_restore.py +2 -2
  32. keras/src/callbacks/csv_logger.py +1 -1
  33. keras/src/callbacks/model_checkpoint.py +1 -1
  34. keras/src/callbacks/tensorboard.py +6 -6
  35. keras/src/constraints/constraints.py +9 -7
  36. keras/src/datasets/boston_housing.py +1 -1
  37. keras/src/datasets/california_housing.py +1 -1
  38. keras/src/datasets/cifar10.py +1 -1
  39. keras/src/datasets/cifar100.py +2 -2
  40. keras/src/datasets/imdb.py +2 -2
  41. keras/src/datasets/mnist.py +1 -1
  42. keras/src/datasets/reuters.py +2 -2
  43. keras/src/dtype_policies/dtype_policy.py +1 -1
  44. keras/src/dtype_policies/dtype_policy_map.py +1 -1
  45. keras/src/export/tf2onnx_lib.py +1 -3
  46. keras/src/initializers/constant_initializers.py +9 -5
  47. keras/src/layers/input_spec.py +6 -6
  48. keras/src/layers/layer.py +1 -1
  49. keras/src/layers/preprocessing/category_encoding.py +3 -3
  50. keras/src/layers/preprocessing/data_layer.py +159 -0
  51. keras/src/layers/preprocessing/discretization.py +3 -3
  52. keras/src/layers/preprocessing/feature_space.py +4 -4
  53. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +7 -4
  54. keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py +3 -0
  55. keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py +2 -2
  56. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +1 -1
  57. keras/src/layers/preprocessing/image_preprocessing/cut_mix.py +6 -3
  58. keras/src/layers/preprocessing/image_preprocessing/equalization.py +1 -1
  59. keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py +3 -0
  60. keras/src/layers/preprocessing/image_preprocessing/mix_up.py +7 -4
  61. keras/src/layers/preprocessing/image_preprocessing/rand_augment.py +3 -1
  62. keras/src/layers/preprocessing/image_preprocessing/random_brightness.py +1 -1
  63. keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py +3 -0
  64. keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py +3 -0
  65. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +1 -1
  66. keras/src/layers/preprocessing/image_preprocessing/random_crop.py +1 -1
  67. keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py +3 -0
  68. keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +6 -3
  69. keras/src/layers/preprocessing/image_preprocessing/random_flip.py +1 -1
  70. keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py +3 -0
  71. keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +1 -1
  72. keras/src/layers/preprocessing/image_preprocessing/random_hue.py +3 -0
  73. keras/src/layers/preprocessing/image_preprocessing/random_invert.py +3 -0
  74. keras/src/layers/preprocessing/image_preprocessing/random_perspective.py +3 -0
  75. keras/src/layers/preprocessing/image_preprocessing/random_posterization.py +3 -0
  76. keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +1 -1
  77. keras/src/layers/preprocessing/image_preprocessing/random_saturation.py +3 -0
  78. keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py +3 -0
  79. keras/src/layers/preprocessing/image_preprocessing/random_shear.py +3 -0
  80. keras/src/layers/preprocessing/image_preprocessing/random_translation.py +3 -3
  81. keras/src/layers/preprocessing/image_preprocessing/random_zoom.py +3 -3
  82. keras/src/layers/preprocessing/image_preprocessing/resizing.py +3 -3
  83. keras/src/layers/preprocessing/image_preprocessing/solarization.py +3 -0
  84. keras/src/layers/preprocessing/mel_spectrogram.py +29 -25
  85. keras/src/layers/preprocessing/normalization.py +5 -2
  86. keras/src/layers/preprocessing/rescaling.py +3 -3
  87. keras/src/layers/rnn/bidirectional.py +4 -4
  88. keras/src/legacy/backend.py +9 -23
  89. keras/src/legacy/preprocessing/image.py +11 -22
  90. keras/src/legacy/preprocessing/text.py +1 -1
  91. keras/src/models/functional.py +2 -2
  92. keras/src/models/model.py +21 -3
  93. keras/src/ops/function.py +1 -1
  94. keras/src/ops/numpy.py +49 -5
  95. keras/src/ops/operation.py +3 -2
  96. keras/src/optimizers/base_optimizer.py +3 -4
  97. keras/src/optimizers/schedules/learning_rate_schedule.py +16 -9
  98. keras/src/quantizers/gptq.py +350 -0
  99. keras/src/quantizers/gptq_config.py +169 -0
  100. keras/src/quantizers/gptq_core.py +335 -0
  101. keras/src/quantizers/gptq_quant.py +133 -0
  102. keras/src/saving/file_editor.py +22 -20
  103. keras/src/saving/object_registration.py +1 -1
  104. keras/src/saving/saving_lib.py +4 -4
  105. keras/src/saving/serialization_lib.py +3 -5
  106. keras/src/trainers/compile_utils.py +1 -1
  107. keras/src/trainers/data_adapters/array_data_adapter.py +9 -3
  108. keras/src/trainers/data_adapters/data_adapter_utils.py +15 -5
  109. keras/src/trainers/data_adapters/generator_data_adapter.py +2 -0
  110. keras/src/trainers/data_adapters/grain_dataset_adapter.py +8 -2
  111. keras/src/trainers/data_adapters/tf_dataset_adapter.py +4 -2
  112. keras/src/trainers/data_adapters/torch_data_loader_adapter.py +3 -1
  113. keras/src/tree/dmtree_impl.py +19 -3
  114. keras/src/tree/optree_impl.py +3 -3
  115. keras/src/tree/tree_api.py +5 -2
  116. keras/src/utils/file_utils.py +13 -5
  117. keras/src/utils/io_utils.py +1 -1
  118. keras/src/utils/model_visualization.py +1 -1
  119. keras/src/utils/progbar.py +5 -5
  120. keras/src/utils/summary_utils.py +4 -4
  121. keras/src/version.py +1 -1
  122. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/METADATA +1 -1
  123. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/RECORD +125 -121
  124. keras/src/layers/preprocessing/tf_data_layer.py +0 -78
  125. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/WHEEL +0 -0
  126. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/top_level.txt +0 -0
keras/src/quantizers/gptq.py
@@ -0,0 +1,350 @@
+ from keras.src import ops
+ from keras.src.layers import Dense
+ from keras.src.layers import EinsumDense
+ from keras.src.quantizers.gptq_quant import dequantize
+
+
+ class GPTQ:
+     def __init__(self, layer):
+         self.original_layer = layer
+         self.num_samples = 0
+         self.quantizer = None
+
+         # Explicitly handle each supported layer type
+         if isinstance(layer, Dense) or (
+             isinstance(layer, EinsumDense) and layer.kernel.ndim == 2
+         ):
+             # For a standard Dense layer, the dimensions are straightforward.
+             self.kernel_shape = layer.kernel.shape
+             self.rows = self.kernel_shape[0]  # Input features
+             self.columns = self.kernel_shape[1]  # Output features
+             self.layer = layer  # The layer itself can be used directly.
+
+         # Handle 3D EinsumDense layers (typically from attention blocks).
+         elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:
+             # For EinsumDense, we determine the effective 2D dimensions.
+             self.kernel_shape = layer.kernel.shape
+             shape = list(self.kernel_shape)
+             try:
+                 d_model_dim_index = shape.index(max(shape))
+             except ValueError:
+                 raise TypeError(
+                     f"Could not determine hidden dimension from shape {shape}"
+                 )
+
+             if d_model_dim_index == 0:  # QKV projection case
+                 in_features, heads, head_dim = shape
+                 self.rows, self.columns = (
+                     in_features,
+                     ops.multiply(heads, head_dim),
+                 )
+             elif d_model_dim_index in [1, 2]:  # Attention Output case
+                 heads, head_dim, out_features = shape
+                 self.rows, self.columns = (
+                     ops.multiply(heads, head_dim),
+                     out_features,
+                 )
+
+             # Create a temporary object that holds a reshaped
+             # 2D version of the kernel.
+             self.layer = type(
+                 "temp",
+                 (object,),
+                 {
+                     "kernel": ops.reshape(
+                         layer.kernel, (self.rows, self.columns)
+                     ),
+                     "bias": layer.bias,
+                 },
+             )()
+
+         else:
+             # Raise an error if the layer is not supported.
+             raise TypeError(f"Unsupported layer type for GPTQ: {type(layer)}")
+         self.hessian = ops.zeros((self.rows, self.rows), dtype="float32")
+
+     def update_hessian_with_batch(self, input_batch):
+         """
+         Updates the running average of the Hessian matrix with a new batch.
+
+         This method computes the Hessian matrix for a given batch of input
+         activations and updates the accumulated Hessian (`self.hessian`) using a
+         numerically stable running average. This allows the Hessian to be
+         computed over a large dataset without loading all samples into memory
+         at once.
+
+         The input tensor is first reshaped into a 2D matrix [num_samples,
+         num_features] before the Hessian is calculated.
+
+         Args:
+             input_batch: A 2D or higher-dimensional tensor of input activations
+                 from a calibration batch.
+
+         Raises:
+             ValueError: If the feature dimension of the input tensor
+                 `input_batch` does not match the dimensions of the
+                 pre-initialized Hessian matrix `self.hessian`.
+         """
+         if input_batch is None:
+             raise ValueError("Input tensor 'input_batch' cannot be None.")
+
+         if len(input_batch.shape) < 2:
+             raise ValueError(
+                 f"Input tensor 'input_batch' must have a rank of at least 2 "
+                 f"(e.g., [batch, features]), but got rank "
+                 f"{len(input_batch.shape)}."
+             )
+         if ops.size(input_batch) == 0:
+             raise ValueError("Input tensor 'input_batch' cannot be empty.")
+
+         if len(input_batch.shape) > 2:
+             input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))
+         input_batch = ops.cast(input_batch, "float32")
+
+         if self.hessian.shape[0] != input_batch.shape[-1]:
+             raise ValueError(
+                 f"Hessian dimensions ({self.hessian.shape[0]}) do not "
+                 f"match input features ({input_batch.shape[-1]})."
+             )
+
+         current_hessian = ops.multiply(
+             2, ops.matmul(ops.transpose(input_batch), input_batch)
+         )
+
+         if self.num_samples == 0:
+             self.hessian = current_hessian
+         else:
+             total_samples = ops.add(self.num_samples, input_batch.shape[0])
+             old_hessian_weight = ops.divide(self.num_samples, total_samples)
+             current_hessian_weight = ops.divide(
+                 input_batch.shape[0], total_samples
+             )
+
+             # Update the accumulated Hessian
+             old_term = ops.multiply(self.hessian, old_hessian_weight)
+             current_term = ops.multiply(current_hessian, current_hessian_weight)
+             self.hessian = ops.add(old_term, current_term)
+
+         self.num_samples = ops.add(self.num_samples, input_batch.shape[0])
+
+     def quantize_and_correct_block(
+         self,
+         blocksize=128,
+         hessian_damping=0.01,
+         group_size=-1,
+         activation_order=False,
+     ):
+         """
+         Performs GPTQ quantization and correction on the layer's weights.
+
+         This method implements the core logic of the Optimal Brain
+         Quantization (OBQ) method, as applied by GPTQ, to quantize the weights
+         of a single layer. It iteratively quantizes blocks of weights and
+         corrects for the quantization error by updating the remaining weights.
+
+         The algorithm follows these main steps:
+         1. **Initialization**: It optionally reorders the weight columns
+            based on activation magnitudes (`activation_order=True`) to
+            protect more salient weights.
+         2. **Hessian Modification**: The Hessian matrix, pre-computed from
+            calibration data, is dampened to ensure its invertibility and
+            stability.
+         3. **Iterative Quantization**: The function iterates through the
+            weight columns in blocks (`blocksize`). In each iteration, it:
+            a. Quantizes one column.
+            b. Calculates the quantization error.
+            c. Updates the remaining weights in the *current* block by
+               distributing the error, using the inverse Hessian.
+         4. **Block-wise Correction**: After a block is quantized, the total
+            error from that block is propagated to the *next* block of weights
+            to be processed.
+         5. **Finalization**: The quantized weights are reordered back if
+            `activation_order` was used, and the layer's weights are updated.
+
+         This implementation is based on the official GPTQ paper and repository.
+         For more details, see:
+         - Paper: https://arxiv.org/abs/2210.17323
+         - Original Code: https://github.com/IST-DASLab/gptq
+
+         Args:
+             blocksize: (int, optional) The size of the weight block to process
+                 at a time. Defaults to 128.
+             hessian_damping: (float, optional) The fraction of dampening to
+                 add to the Hessian's diagonal, relative to the mean of the
+                 diagonal. A value of 0.01 is recommended.
+                 Defaults to 0.01.
+             group_size: (int, optional) The number of weights that share the
+                 same quantization parameters (scale and zero-point).
+                 A value of -1 indicates per-channel quantization.
+             activation_order: (bool, optional) If True, reorders weight
+                 columns based on their activation's second-order information.
+         """
+
+         weights_matrix = ops.transpose(ops.cast(self.layer.kernel, "float32"))
+         hessian_matrix = ops.cast(self.hessian, "float32")
+
+         if activation_order:
+             permutation = ops.argsort(
+                 ops.negative(ops.diagonal(hessian_matrix))
+             )
+             weights_matrix = ops.take(weights_matrix, permutation, axis=1)
+             hessian_matrix = ops.take(
+                 ops.take(hessian_matrix, permutation, axis=0),
+                 permutation,
+                 axis=1,
+             )
+             inverse_permutation = ops.argsort(permutation)
+
+         # Dampen the Hessian for stability
+         hessian_diagonal = ops.diagonal(hessian_matrix)
+         dead_diagonal = ops.equal(hessian_diagonal, 0.0)
+         hessian_diagonal = ops.where(dead_diagonal, 1.0, hessian_diagonal)
+         hessian_matrix = ops.add(
+             hessian_matrix,
+             ops.diag(
+                 ops.where(dead_diagonal, 1.0, ops.zeros_like(hessian_diagonal))
+             ),
+         )
+
+         # Add dampening factor to the Hessian diagonal
+         damping_factor = ops.multiply(
+             hessian_damping, ops.mean(hessian_diagonal)
+         )
+         hessian_diagonal = ops.add(hessian_diagonal, damping_factor)
+         hessian_matrix = ops.add(
+             ops.subtract(
+                 hessian_matrix, ops.diag(ops.diagonal(hessian_matrix))
+             ),
+             ops.diag(hessian_diagonal),
+         )
+
+         # Compute the inverse Hessian, which is used for error correction
+         inverse_hessian = ops.linalg.inv(hessian_matrix)
+         quantized_weights = ops.zeros_like(weights_matrix)
+
+         for block_start in range(0, self.rows, blocksize):
+             block_end = min(ops.add(block_start, blocksize), self.rows)
+             block_size = ops.subtract(block_end, block_start)
+             # Extract the current block of weights and its corresponding
+             # Hessian
+             block_weights = weights_matrix[:, block_start:block_end]
+             block_quantized = ops.zeros_like(block_weights)
+             block_errors = ops.zeros_like(block_weights)
+             block_inverse_hessian = inverse_hessian[
+                 block_start:block_end, block_start:block_end
+             ]
+
+             # Process one column at a time within the block
+             for col_idx in range(block_size):
+                 weight_column = block_weights[:, col_idx]
+                 diagonal_element = block_inverse_hessian[col_idx, col_idx]
+
+                 if group_size != -1:
+                     if ops.mod(ops.add(block_start, col_idx), group_size) == 0:
+                         self.quantizer.find_params(
+                             weights_matrix[
+                                 :,
+                                 (ops.add(block_start, col_idx)) : (
+                                     ops.add(
+                                         ops.add(block_start, col_idx),
+                                         group_size,
+                                     )
+                                 ),
+                             ],
+                             weight=True,
+                         )
+                 else:
+                     self.quantizer.find_params(
+                         ops.expand_dims(weight_column, 1), weight=True
+                     )
+
+                 # Quantize the current weight column
+                 quantized_column = dequantize(
+                     ops.expand_dims(weight_column, 1),
+                     self.quantizer.scale,
+                     self.quantizer.zero,
+                     self.quantizer.maxq,
+                 )[:, 0]
+
+                 block_quantized = ops.slice_update(
+                     block_quantized,
+                     (0, col_idx),
+                     ops.expand_dims(quantized_column, axis=1),
+                 )
+                 quantization_error = ops.divide(
+                     ops.subtract(weight_column, quantized_column),
+                     diagonal_element,
+                 )
+                 block_errors = ops.slice_update(
+                     block_errors,
+                     (0, col_idx),
+                     ops.expand_dims(quantization_error, axis=1),
+                 )
+
+                 if ops.less(col_idx, ops.subtract(block_size, 1)):
+                     error_update = ops.matmul(
+                         ops.expand_dims(quantization_error, 1),
+                         ops.expand_dims(
+                             block_inverse_hessian[
+                                 col_idx, ops.add(col_idx, 1) :
+                             ],
+                             0,
+                         ),
+                     )
+
+                     # Efficiently update the remaining part of the
+                     # block_weights tensor.
+                     slice_to_update = block_weights[:, ops.add(col_idx, 1) :]
+                     updated_slice = ops.subtract(slice_to_update, error_update)
+                     block_weights = ops.slice_update(
+                         block_weights, (0, ops.add(col_idx, 1)), updated_slice
+                     )
+
+             # Update the full quantized matrix with the processed block
+             quantized_weights = ops.concatenate(
+                 [
+                     quantized_weights[:, :block_start],
+                     block_quantized,
+                     quantized_weights[:, block_end:],
+                 ],
+                 axis=1,
+             )
+
+             if block_end < self.rows:
+                 total_error_update = ops.matmul(
+                     block_errors,
+                     inverse_hessian[block_start:block_end, block_end:],
+                 )
+                 weights_matrix = ops.concatenate(
+                     [
+                         weights_matrix[:, :block_end],
+                         ops.subtract(
+                             weights_matrix[:, block_end:], total_error_update
+                         ),
+                     ],
+                     axis=1,
+                 )
+
+         if activation_order:
+             quantized_weights = ops.take(
+                 quantized_weights, inverse_permutation, axis=1
+             )
+
+         quantized_weights = ops.transpose(quantized_weights)
+
+         if isinstance(self.original_layer, EinsumDense):
+             quantized_weights = ops.reshape(
+                 quantized_weights, self.kernel_shape
+             )
+
+         # Set the new quantized weights in the original layer
+         new_weights = [ops.convert_to_numpy(quantized_weights)]
+         if self.original_layer.bias is not None:
+             new_weights.append(ops.convert_to_numpy(self.original_layer.bias))
+
+         self.original_layer.set_weights(new_weights)
+
+     def free(self):
+         self.hessian = None
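For readers following the new GPTQ class above, the two numerical updates it performs with `keras.ops` can be summarized in plain NumPy. The sketch below is illustrative only and is not part of the package; the function names are hypothetical, and the simple round-to-grid quantizer stands in for the `self.quantizer` object used in the diff.

```python
import numpy as np

def running_hessian(hessian, num_samples, batch):
    # Running average of H = 2 * X^T X over calibration batches,
    # mirroring GPTQ.update_hessian_with_batch.
    batch = batch.reshape(-1, batch.shape[-1]).astype("float32")
    current = 2.0 * batch.T @ batch
    total = num_samples + batch.shape[0]
    if num_samples == 0:
        return current, total
    return (hessian * num_samples + current * batch.shape[0]) / total, total

def quantize_block(weights, inv_hessian, scale):
    # Per-column quantize-then-correct step for a single block (no grouping),
    # mirroring the inner loop of quantize_and_correct_block.
    w = weights.copy()  # shape: [output_features, input_features]
    q = np.zeros_like(w)
    for j in range(w.shape[1]):
        q[:, j] = np.round(w[:, j] / scale) * scale  # quantize column j
        err = (w[:, j] - q[:, j]) / inv_hessian[j, j]  # scaled quantization error
        # Spread the error onto the not-yet-quantized columns.
        w[:, j + 1 :] -= np.outer(err, inv_hessian[j, j + 1 :])
    return q
```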
keras/src/quantizers/gptq_config.py
@@ -0,0 +1,169 @@
+ from absl import logging
+
+ from keras.src.api_export import keras_export
+ from keras.src.quantizers.gptq_core import quantize_model
+
+
+ @keras_export("keras.quantizers.GPTQConfig")
+ class GPTQConfig:
+     """Configuration class for the GPTQ post-training quantization algorithm.
+
+     GPTQ is a post-training quantization method that quantizes neural network
+     weights to lower precision (e.g., 4-bit) while minimizing the impact on
+     model accuracy. It works by analyzing the Hessian matrix of the loss
+     function with respect to the weights and applying optimal quantization
+     that preserves the most important weight values.
+
+     **When to use GPTQ:**
+     - You want to reduce model size and memory usage
+     - You need faster inference on hardware that supports low-precision
+       operations
+     - You want to maintain model accuracy as much as possible
+     - You have a pre-trained model that you want to quantize without
+       retraining
+
+     **How it works:**
+     1. Uses calibration data to compute the Hessian matrix for each layer
+     2. Applies iterative quantization with error correction
+     3. Reorders weights based on activation importance (optional)
+     4. Quantizes weights while minimizing quantization error
+
+     **Example usage:**
+     ```python
+     from keras.quantizers import GPTQConfig
+     from keras import Model
+
+     # Create configuration for 4-bit quantization
+     config = GPTQConfig(
+         dataset=calibration_data,  # Your calibration dataset
+         tokenizer=your_tokenizer,  # Tokenizer for text data
+         weight_bits=4,  # Quantize to 4 bits
+         num_samples=128,  # Number of calibration samples
+         sequence_length=512,  # Sequence length for each sample
+         hessian_damping=0.01,  # Hessian stabilization factor
+         group_size=128,  # Weight grouping for quantization
+         symmetric=False,  # Use asymmetric quantization
+         activation_order=True  # Reorder weights by importance
+     )
+
+     # Apply quantization to your model
+     model = Model(...)  # Your pre-trained model
+     model.quantize("gptq", config=config)
+
+     # The model now has quantized weights and can be used for inference
+     ```
+
+     **Benefits:**
+     - **Memory reduction**: 4-bit quantization reduces memory by ~8x compared
+       to float32
+     - **Faster inference**: Lower precision operations are faster on supported
+       hardware
+     - **Accuracy preservation**: Minimizes accuracy loss through optimal
+       quantization
+     - **No retraining required**: Works with pre-trained models
+
+     **Advanced usage examples:**
+
+     **Per-channel quantization (recommended for most cases):**
+     ```python
+     config = GPTQConfig(
+         dataset=calibration_data,
+         tokenizer=tokenizer,
+         weight_bits=4,
+         group_size=-1,  # -1 enables per-channel quantization
+         symmetric=False
+     )
+     ```
+
+     **Grouped quantization (for specific hardware requirements):**
+     ```python
+     config = GPTQConfig(
+         dataset=calibration_data,
+         tokenizer=tokenizer,
+         weight_bits=4,
+         group_size=64,  # 64 weights share the same scale factor
+         symmetric=True  # Use symmetric quantization
+     )
+     ```
+
+     **High-accuracy quantization with activation ordering:**
+     ```python
+     config = GPTQConfig(
+         dataset=calibration_data,
+         tokenizer=tokenizer,
+         weight_bits=4,
+         activation_order=True,  # Reorder weights by importance
+         hessian_damping=0.005,  # Lower damping for more precise
+         # quantization
+         num_samples=256  # More samples for better accuracy
+     )
+     ```
+
+     **References:**
+     - Original GPTQ paper: "GPTQ: Accurate Post-Training Quantization
+       for Generative Pre-trained Transformers"
+     - Implementation based on: https://github.com/IST-DASLab/gptq
+     - Suitable for: Transformer models, large language models, and other
+       deep neural networks
+
+     **Note:** The quality of quantization depends heavily on the calibration
+     dataset. Use representative data that covers the expected input
+     distribution for best results.
+
+     Args:
+         dataset: The calibration dataset. It can be an iterable that yields
+             strings or pre-tokenized numerical tensors (e.g., a list of
+             strings, a generator, or a NumPy array). This data is used to
+             analyze the model's activations.
+         tokenizer: A `keras_nlp.Tokenizer` instance (or a similar callable)
+             that is used to process the `dataset` if it contains strings.
+         weight_bits: (int, optional) The number of bits to quantize weights to.
+             Defaults to 4.
+         num_samples: (int, optional) The number of calibration data samples to
+             use from the dataset. Defaults to 128.
+         sequence_length: (int, optional) The sequence length to use for each
+             calibration sample. Defaults to 512.
+         hessian_damping: (float, optional) The fraction of Hessian damping
+             used to stabilize the inverse calculation. Defaults to 0.01.
+         group_size: (int, optional) The size of weight groups to quantize
+             together. A `group_size` of -1 indicates per-channel quantization.
+             Defaults to 128.
+         symmetric: (bool, optional) If `True`, uses symmetric quantization.
+             If `False`, uses asymmetric quantization. Defaults to `False`.
+         activation_order: (bool, optional) If `True`, reorders weight columns
+             based on activation magnitude, which can improve quantization
+             accuracy. Defaults to `False`.
+     """
+
+     def __init__(
+         self,
+         dataset,
+         tokenizer,
+         weight_bits: int = 4,
+         num_samples: int = 128,
+         sequence_length: int = 512,
+         hessian_damping: float = 0.01,
+         group_size: int = 128,
+         symmetric: bool = False,
+         activation_order: bool = False,
+     ):
+         self.dataset = dataset
+         self.tokenizer = tokenizer
+         self.num_samples = num_samples
+         self.sequence_length = sequence_length
+         self.hessian_damping = hessian_damping
+         self.weight_bits = weight_bits
+         self.group_size = group_size
+         self.symmetric = symmetric
+         self.activation_order = activation_order
+
+     def quantize(self, model):
+         """
+         Applies GPTQ quantization to the provided model using this
+         configuration.
+         """
+         logging.info("Initiating quantization from GPTQConfig...")
+         # The core logic is delegated to `gptq_core.quantize_model`, which
+         # handles the dynamic imports and data loading.
+         quantize_model(model=model, config=self)
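As a rough check on the "~8x" memory-reduction figure quoted in the `GPTQConfig` docstring, the arithmetic below packs an assumed 4096x4096 weight matrix at 4 bits with one float16 scale and zero-point per group of 128 input features per output channel. The layer shape and the float16 metadata format are illustrative assumptions, not details taken from this package, and the exact overhead depends on how a backend packs the weights.

```python
# Hypothetical sizes for a back-of-envelope estimate.
in_features, out_features = 4096, 4096
weight_bits, group_size = 4, 128

fp32_bytes = in_features * out_features * 4
packed_bytes = in_features * out_features * weight_bits / 8
# One (scale, zero-point) pair per group and output channel, stored as float16.
meta_bytes = (in_features / group_size) * out_features * 2 * 2

print(fp32_bytes / (packed_bytes + meta_bytes))  # ~7.5, close to the quoted ~8x
```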