keras-nightly 3.12.0.dev2025082103__py3-none-any.whl → 3.12.0.dev2025082303__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/ops/__init__.py +1 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +1 -0
- keras/_tf_keras/keras/quantizers/__init__.py +1 -0
- keras/ops/__init__.py +1 -0
- keras/ops/numpy/__init__.py +1 -0
- keras/quantizers/__init__.py +1 -0
- keras/src/applications/convnext.py +20 -20
- keras/src/applications/densenet.py +21 -21
- keras/src/applications/efficientnet.py +16 -16
- keras/src/applications/efficientnet_v2.py +28 -28
- keras/src/applications/inception_resnet_v2.py +7 -7
- keras/src/applications/inception_v3.py +5 -5
- keras/src/applications/mobilenet_v2.py +13 -20
- keras/src/applications/mobilenet_v3.py +15 -15
- keras/src/applications/nasnet.py +7 -8
- keras/src/applications/resnet.py +32 -32
- keras/src/applications/xception.py +10 -10
- keras/src/backend/common/dtypes.py +8 -3
- keras/src/backend/common/variables.py +3 -1
- keras/src/backend/jax/export.py +1 -1
- keras/src/backend/jax/numpy.py +6 -0
- keras/src/backend/jax/trainer.py +1 -1
- keras/src/backend/numpy/numpy.py +28 -0
- keras/src/backend/openvino/numpy.py +5 -1
- keras/src/backend/tensorflow/numpy.py +22 -0
- keras/src/backend/tensorflow/trainer.py +19 -1
- keras/src/backend/torch/core.py +6 -9
- keras/src/backend/torch/nn.py +1 -2
- keras/src/backend/torch/numpy.py +16 -0
- keras/src/backend/torch/trainer.py +1 -1
- keras/src/callbacks/backup_and_restore.py +2 -2
- keras/src/callbacks/csv_logger.py +1 -1
- keras/src/callbacks/model_checkpoint.py +1 -1
- keras/src/callbacks/tensorboard.py +6 -6
- keras/src/constraints/constraints.py +9 -7
- keras/src/datasets/boston_housing.py +1 -1
- keras/src/datasets/california_housing.py +1 -1
- keras/src/datasets/cifar10.py +1 -1
- keras/src/datasets/cifar100.py +2 -2
- keras/src/datasets/imdb.py +2 -2
- keras/src/datasets/mnist.py +1 -1
- keras/src/datasets/reuters.py +2 -2
- keras/src/dtype_policies/dtype_policy.py +1 -1
- keras/src/dtype_policies/dtype_policy_map.py +1 -1
- keras/src/export/tf2onnx_lib.py +1 -3
- keras/src/initializers/constant_initializers.py +9 -5
- keras/src/layers/input_spec.py +6 -6
- keras/src/layers/layer.py +1 -1
- keras/src/layers/preprocessing/category_encoding.py +3 -3
- keras/src/layers/preprocessing/data_layer.py +159 -0
- keras/src/layers/preprocessing/discretization.py +3 -3
- keras/src/layers/preprocessing/feature_space.py +4 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +7 -4
- keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/center_crop.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/cut_mix.py +6 -3
- keras/src/layers/preprocessing/image_preprocessing/equalization.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/mix_up.py +7 -4
- keras/src/layers/preprocessing/image_preprocessing/rand_augment.py +3 -1
- keras/src/layers/preprocessing/image_preprocessing/random_brightness.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_crop.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +6 -3
- keras/src/layers/preprocessing/image_preprocessing/random_flip.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_hue.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_invert.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_perspective.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_posterization.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_saturation.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_shear.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_translation.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/random_zoom.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/solarization.py +3 -0
- keras/src/layers/preprocessing/mel_spectrogram.py +29 -25
- keras/src/layers/preprocessing/normalization.py +5 -2
- keras/src/layers/preprocessing/rescaling.py +3 -3
- keras/src/layers/rnn/bidirectional.py +4 -4
- keras/src/legacy/backend.py +9 -23
- keras/src/legacy/preprocessing/image.py +11 -22
- keras/src/legacy/preprocessing/text.py +1 -1
- keras/src/models/functional.py +2 -2
- keras/src/models/model.py +21 -3
- keras/src/ops/function.py +1 -1
- keras/src/ops/numpy.py +49 -5
- keras/src/ops/operation.py +3 -2
- keras/src/optimizers/base_optimizer.py +3 -4
- keras/src/optimizers/schedules/learning_rate_schedule.py +16 -9
- keras/src/quantizers/gptq.py +350 -0
- keras/src/quantizers/gptq_config.py +169 -0
- keras/src/quantizers/gptq_core.py +335 -0
- keras/src/quantizers/gptq_quant.py +133 -0
- keras/src/saving/file_editor.py +22 -20
- keras/src/saving/object_registration.py +1 -1
- keras/src/saving/saving_lib.py +4 -4
- keras/src/saving/serialization_lib.py +3 -5
- keras/src/trainers/compile_utils.py +1 -1
- keras/src/trainers/data_adapters/array_data_adapter.py +9 -3
- keras/src/trainers/data_adapters/data_adapter_utils.py +15 -5
- keras/src/trainers/data_adapters/generator_data_adapter.py +2 -0
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +8 -2
- keras/src/trainers/data_adapters/tf_dataset_adapter.py +4 -2
- keras/src/trainers/data_adapters/torch_data_loader_adapter.py +3 -1
- keras/src/tree/dmtree_impl.py +19 -3
- keras/src/tree/optree_impl.py +3 -3
- keras/src/tree/tree_api.py +5 -2
- keras/src/utils/file_utils.py +13 -5
- keras/src/utils/io_utils.py +1 -1
- keras/src/utils/model_visualization.py +1 -1
- keras/src/utils/progbar.py +5 -5
- keras/src/utils/summary_utils.py +4 -4
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/METADATA +1 -1
- {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/RECORD +125 -121
- keras/src/layers/preprocessing/tf_data_layer.py +0 -78
- {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/top_level.txt +0 -0
keras/src/quantizers/gptq.py
@@ -0,0 +1,350 @@
+from keras.src import ops
+from keras.src.layers import Dense
+from keras.src.layers import EinsumDense
+from keras.src.quantizers.gptq_quant import dequantize
+
+
+class GPTQ:
+    def __init__(self, layer):
+        self.original_layer = layer
+        self.num_samples = 0
+        self.quantizer = None
+
+        # Explicitly handle each supported layer type
+        if isinstance(layer, Dense) or (
+            isinstance(layer, EinsumDense) and layer.kernel.ndim == 2
+        ):
+            # For a standard Dense layer, the dimensions are straightforward.
+            self.kernel_shape = layer.kernel.shape
+            self.rows = self.kernel_shape[0]  # Input features
+            self.columns = self.kernel_shape[1]  # Output features
+            self.layer = layer  # The layer itself can be used directly.
+
+        # Handle 3D EinsumDense layers (typically from attention blocks).
+        elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:
+            # For EinsumDense, we determine the effective 2D dimensions.
+            self.kernel_shape = layer.kernel.shape
+            shape = list(self.kernel_shape)
+            try:
+                d_model_dim_index = shape.index(max(shape))
+            except ValueError:
+                raise TypeError(
+                    f"Could not determine hidden dimension from shape {shape}"
+                )
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                self.rows, self.columns = (
+                    in_features,
+                    ops.multiply(heads, head_dim),
+                )
+            elif d_model_dim_index in [1, 2]:  # Attention Output case
+                heads, head_dim, out_features = shape
+                self.rows, self.columns = (
+                    ops.multiply(heads, head_dim),
+                    out_features,
+                )
+
+            # Create a temporary object that holds a reshaped
+            # 2D version of the kernel.
+            self.layer = type(
+                "temp",
+                (object,),
+                {
+                    "kernel": ops.reshape(
+                        layer.kernel, (self.rows, self.columns)
+                    ),
+                    "bias": layer.bias,
+                },
+            )()
+
+        else:
+            # Raise an error if the layer is not supported.
+            raise TypeError(f"Unsupported layer type for GPTQ: {type(layer)}")
+        self.hessian = ops.zeros((self.rows, self.rows), dtype="float32")
+
+    def update_hessian_with_batch(self, input_batch):
+        """
+        Updates the running average of the Hessian matrix with a new batch.
+
+        This method computes the Hessian matrix for a given batch of input
+        activations and updates the accumulated Hessian (`self.hessian`)
+        using a numerically stable running average. This allows the Hessian
+        to be computed over a large dataset without loading all samples into
+        memory at once.
+
+        The input tensor is first reshaped into a 2D matrix [num_samples,
+        num_features] before the Hessian is calculated.
+
+        Args:
+            input_batch: A 2D or higher-dimensional tensor of input
+                activations from a calibration batch.
+
+        Raises:
+            ValueError: If the feature dimension of the input tensor
+                `input_batch` does not match the dimensions of the
+                pre-initialized Hessian matrix `self.hessian`.
+        """
+        if input_batch is None:
+            raise ValueError("Input tensor 'input_batch' cannot be None.")
+
+        if len(input_batch.shape) < 2:
+            raise ValueError(
+                f"Input tensor 'input_batch' must have a rank of at least 2 "
+                f"(e.g., [batch, features]), but got rank "
+                f"{len(input_batch.shape)}."
+            )
+        if ops.size(input_batch) == 0:
+            raise ValueError("Input tensor 'input_batch' cannot be empty.")
+
+        if len(input_batch.shape) > 2:
+            input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))
+        input_batch = ops.cast(input_batch, "float32")
+
+        if self.hessian.shape[0] != input_batch.shape[-1]:
+            raise ValueError(
+                f"Hessian dimensions ({self.hessian.shape[0]}) do not "
+                f"match input features ({input_batch.shape[-1]})."
+            )
+
+        current_hessian = ops.multiply(
+            2, ops.matmul(ops.transpose(input_batch), input_batch)
+        )
+
+        if self.num_samples == 0:
+            self.hessian = current_hessian
+        else:
+            total_samples = ops.add(self.num_samples, input_batch.shape[0])
+            old_hessian_weight = ops.divide(self.num_samples, total_samples)
+            current_hessian_weight = ops.divide(
+                input_batch.shape[0], total_samples
+            )
+
+            # Update the accumulated Hessian
+            old_term = ops.multiply(self.hessian, old_hessian_weight)
+            current_term = ops.multiply(current_hessian, current_hessian_weight)
+            self.hessian = ops.add(old_term, current_term)
+
+        self.num_samples = ops.add(self.num_samples, input_batch.shape[0])
+
+    def quantize_and_correct_block(
+        self,
+        blocksize=128,
+        hessian_damping=0.01,
+        group_size=-1,
+        activation_order=False,
+    ):
+        """
+        Performs GPTQ quantization and correction on the layer's weights.
+
+        This method implements the core logic of the Optimal Brain
+        Quantization (OBQ) method, as applied by GPTQ, to quantize the
+        weights of a single layer. It iteratively quantizes blocks of
+        weights and corrects for the quantization error by updating the
+        remaining weights.
+
+        The algorithm follows these main steps:
+        1. **Initialization**: It optionally reorders the weight columns
+           based on activation magnitudes (`activation_order=True`) to
+           protect more salient weights.
+        2. **Hessian Modification**: The Hessian matrix, pre-computed from
+           calibration data, is dampened to ensure its invertibility and
+           stability.
+        3. **Iterative Quantization**: The function iterates through the
+           weight columns in blocks (`blocksize`). In each iteration, it:
+           a. Quantizes one column.
+           b. Calculates the quantization error.
+           c. Updates the remaining weights in the *current* block by
+              distributing the error, using the inverse Hessian.
+        4. **Block-wise Correction**: After a block is quantized, the total
+           error from that block is propagated to the *next* block of weights
+           to be processed.
+        5. **Finalization**: The quantized weights are reordered back if
+           `activation_order` was used, and the layer's weights are updated.
+
+        This implementation is based on the official GPTQ paper and repository.
+        For more details, see:
+        - Paper: https://arxiv.org/abs/2210.17323
+        - Original Code: https://github.com/IST-DASLab/gptq
+
+        Args:
+            blocksize: (int, optional) The size of the weight block to process
+                at a time. Defaults to 128.
+            hessian_damping: (float, optional) The percentage of dampening to
+                add to the Hessian's diagonal. A value of 0.01 is recommended.
+                Defaults to 0.01.
+            group_size: (int, optional) The number of weights that share the
+                same quantization parameters (scale and zero-point).
+                A value of -1 indicates per-channel quantization.
+            activation_order: (bool, optional) If True, reorders weight
+                columns based on their activation's second-order information.
+        """
+
+        weights_matrix = ops.transpose(ops.cast(self.layer.kernel, "float32"))
+        hessian_matrix = ops.cast(self.hessian, "float32")
+
+        if activation_order:
+            permutation = ops.argsort(
+                ops.negative(ops.diagonal(hessian_matrix))
+            )
+            weights_matrix = ops.take(weights_matrix, permutation, axis=1)
+            hessian_matrix = ops.take(
+                ops.take(hessian_matrix, permutation, axis=0),
+                permutation,
+                axis=1,
+            )
+            inverse_permutation = ops.argsort(permutation)
+
+        # Dampen the Hessian for stability
+        hessian_diagonal = ops.diagonal(hessian_matrix)
+        dead_diagonal = ops.equal(hessian_diagonal, 0.0)
+        hessian_diagonal = ops.where(dead_diagonal, 1.0, hessian_diagonal)
+        hessian_matrix = ops.add(
+            hessian_matrix,
+            ops.diag(
+                ops.where(dead_diagonal, 1.0, ops.zeros_like(hessian_diagonal))
+            ),
+        )
+
+        # Add dampening factor to the Hessian diagonal
+        damping_factor = ops.multiply(
+            hessian_damping, ops.mean(hessian_diagonal)
+        )
+        hessian_diagonal = ops.add(hessian_diagonal, damping_factor)
+        hessian_matrix = ops.add(
+            ops.subtract(
+                hessian_matrix, ops.diag(ops.diagonal(hessian_matrix))
+            ),
+            ops.diag(hessian_diagonal),
+        )
+
+        # Compute the inverse Hessian, which is used for error correction
+        inverse_hessian = ops.linalg.inv(hessian_matrix)
+        quantized_weights = ops.zeros_like(weights_matrix)
+
+        for block_start in range(0, self.rows, blocksize):
+            block_end = min(ops.add(block_start, blocksize), self.rows)
+            block_size = ops.subtract(block_end, block_start)
+            # Extract the current block of weights and its corresponding
+            # Hessian
+            block_weights = weights_matrix[:, block_start:block_end]
+            block_quantized = ops.zeros_like(block_weights)
+            block_errors = ops.zeros_like(block_weights)
+            block_inverse_hessian = inverse_hessian[
+                block_start:block_end, block_start:block_end
+            ]
+
+            # Process one column at a time within the block
+            for col_idx in range(block_size):
+                weight_column = block_weights[:, col_idx]
+                diagonal_element = block_inverse_hessian[col_idx, col_idx]
+
+                if group_size != -1:
+                    if ops.mod(ops.add(block_start, col_idx), group_size) == 0:
+                        self.quantizer.find_params(
+                            weights_matrix[
+                                :,
+                                (ops.add(block_start, col_idx)) : (
+                                    ops.add(
+                                        ops.add(block_start, col_idx),
+                                        group_size,
+                                    )
+                                ),
+                            ],
+                            weight=True,
+                        )
+                else:
+                    self.quantizer.find_params(
+                        ops.expand_dims(weight_column, 1), weight=True
+                    )
+
+                # Quantize the current weight column
+                quantized_column = dequantize(
+                    ops.expand_dims(weight_column, 1),
+                    self.quantizer.scale,
+                    self.quantizer.zero,
+                    self.quantizer.maxq,
+                )[:, 0]
+
+                block_quantized = ops.slice_update(
+                    block_quantized,
+                    (0, col_idx),
+                    ops.expand_dims(quantized_column, axis=1),
+                )
+                quantization_error = ops.divide(
+                    ops.subtract(weight_column, quantized_column),
+                    diagonal_element,
+                )
+                block_errors = ops.slice_update(
+                    block_errors,
+                    (0, col_idx),
+                    ops.expand_dims(quantization_error, axis=1),
+                )
+
+                if ops.less(col_idx, ops.subtract(block_size, 1)):
+                    error_update = ops.matmul(
+                        ops.expand_dims(quantization_error, 1),
+                        ops.expand_dims(
+                            block_inverse_hessian[
+                                col_idx, ops.add(col_idx, 1) :
+                            ],
+                            0,
+                        ),
+                    )
+
+                    # Efficiently update the remaining part of the
+                    # block_weights tensor.
+                    slice_to_update = block_weights[:, ops.add(col_idx, 1) :]
+                    updated_slice = ops.subtract(slice_to_update, error_update)
+                    block_weights = ops.slice_update(
+                        block_weights, (0, ops.add(col_idx, 1)), updated_slice
+                    )
+
+            # Update the full quantized matrix with the processed block
+            quantized_weights = ops.concatenate(
+                [
+                    quantized_weights[:, :block_start],
+                    block_quantized,
+                    quantized_weights[:, block_end:],
+                ],
+                axis=1,
+            )
+
+            if block_end < self.rows:
+                total_error_update = ops.matmul(
+                    block_errors,
+                    inverse_hessian[block_start:block_end, block_end:],
+                )
+                weights_matrix = ops.concatenate(
+                    [
+                        weights_matrix[:, :block_end],
+                        ops.subtract(
+                            weights_matrix[:, block_end:], total_error_update
+                        ),
+                    ],
+                    axis=1,
+                )
+
+        if activation_order:
+            quantized_weights = ops.take(
+                quantized_weights, inverse_permutation, axis=1
+            )
+
+        quantized_weights = ops.transpose(quantized_weights)
+
+        if isinstance(self.original_layer, EinsumDense):
+            quantized_weights = ops.reshape(
+                quantized_weights, self.kernel_shape
+            )
+
+        # Set the new quantized weights in the original layer
+        new_weights = [ops.convert_to_numpy(quantized_weights)]
+        if self.original_layer.bias is not None:
+            new_weights.append(ops.convert_to_numpy(self.original_layer.bias))
+
+        self.original_layer.set_weights(new_weights)
+
+    def free(self):
+        self.hessian = None
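The `update_hessian_with_batch` method above accumulates `2 * X^T X` over calibration batches as a running average. As a rough sanity check of that arithmetic, here is a minimal standalone NumPy sketch; the `accumulate_hessian` helper is illustrative only and not part of the package. With equal-sized batches the accumulated result works out to the mean of the per-batch Hessians:

```python
import numpy as np

def accumulate_hessian(hessian, num_samples, batch):
    """Fold one calibration batch [n, features] into the running Hessian,
    mirroring the weighting used in GPTQ.update_hessian_with_batch."""
    batch = batch.reshape(-1, batch.shape[-1]).astype("float32")
    current = 2.0 * batch.T @ batch  # this batch's (unnormalized) Hessian
    if num_samples == 0:
        return current, batch.shape[0]
    total = num_samples + batch.shape[0]
    mixed = hessian * (num_samples / total) + current * (batch.shape[0] / total)
    return mixed, total

rng = np.random.default_rng(0)
activations = rng.normal(size=(96, 16)).astype("float32")

hessian, seen = np.zeros((16, 16), dtype="float32"), 0
for chunk in np.split(activations, 3):  # three equal batches of 32 samples
    hessian, seen = accumulate_hessian(hessian, seen, chunk)

# For equal batch sizes, the running average equals the mean of the
# per-batch Hessians, i.e. (2 / num_batches) * X^T X over all samples.
expected = 2.0 * activations.T @ activations / 3
assert np.allclose(hessian, expected, atol=1e-3)
```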
keras/src/quantizers/gptq_config.py
@@ -0,0 +1,169 @@
+from absl import logging
+
+from keras.src.api_export import keras_export
+from keras.src.quantizers.gptq_core import quantize_model
+
+
+@keras_export("keras.quantizers.GPTQConfig")
+class GPTQConfig:
+    """Configuration class for the GPTQ (Gradient-based Post-Training
+    Quantization) algorithm.
+
+    GPTQ is a post-training quantization method that quantizes neural network
+    weights to lower precision (e.g., 4-bit) while minimizing the impact on
+    model accuracy. It works by analyzing the Hessian matrix of the loss
+    function with respect to the weights and applying optimal quantization
+    that preserves the most important weight values.
+
+    **When to use GPTQ:**
+    - You want to reduce model size and memory usage
+    - You need faster inference on hardware that supports low-precision
+      operations
+    - You want to maintain model accuracy as much as possible
+    - You have a pre-trained model that you want to quantize without
+      retraining
+
+    **How it works:**
+    1. Uses calibration data to compute the Hessian matrix for each layer
+    2. Applies iterative quantization with error correction
+    3. Reorders weights based on activation importance (optional)
+    4. Quantizes weights while minimizing quantization error
+
+    **Example usage:**
+    ```python
+    from keras.quantizers import GPTQConfig
+    from keras import Model
+
+    # Create configuration for 4-bit quantization
+    config = GPTQConfig(
+        dataset=calibration_data,  # Your calibration dataset
+        tokenizer=your_tokenizer,  # Tokenizer for text data
+        weight_bits=4,             # Quantize to 4 bits
+        num_samples=128,           # Number of calibration samples
+        sequence_length=512,       # Sequence length for each sample
+        hessian_damping=0.01,      # Hessian stabilization factor
+        group_size=128,            # Weight grouping for quantization
+        symmetric=False,           # Use asymmetric quantization
+        activation_order=True      # Reorder weights by importance
+    )
+
+    # Apply quantization to your model
+    model = Model(...)  # Your pre-trained model
+    model.quantize("gptq", config=config)
+
+    # The model now has quantized weights and can be used for inference
+    ```
+
+    **Benefits:**
+    - **Memory reduction**: 4-bit quantization reduces memory by ~8x compared
+      to float32
+    - **Faster inference**: Lower precision operations are faster on supported
+      hardware
+    - **Accuracy preservation**: Minimizes accuracy loss through optimal
+      quantization
+    - **No retraining required**: Works with pre-trained models
+
+    **Advanced usage examples:**
+
+    **Per-channel quantization (recommended for most cases):**
+    ```python
+    config = GPTQConfig(
+        dataset=calibration_data,
+        tokenizer=tokenizer,
+        weight_bits=4,
+        group_size=-1,  # -1 enables per-channel quantization
+        symmetric=False
+    )
+    ```
+
+    **Grouped quantization (for specific hardware requirements):**
+    ```python
+    config = GPTQConfig(
+        dataset=calibration_data,
+        tokenizer=tokenizer,
+        weight_bits=4,
+        group_size=64,  # 64 weights share the same scale factor
+        symmetric=True  # Use symmetric quantization
+    )
+    ```
+
+    **High-accuracy quantization with activation ordering:**
+    ```python
+    config = GPTQConfig(
+        dataset=calibration_data,
+        tokenizer=tokenizer,
+        weight_bits=4,
+        activation_order=True,  # Reorder weights by importance
+        hessian_damping=0.005,  # Lower damping for more precise
+                                # quantization
+        num_samples=256         # More samples for better accuracy
+    )
+    ```
+
+    **References:**
+    - Original GPTQ paper: "GPTQ: Accurate Post-Training Quantization
+      for Generative Pre-trained Transformers"
+    - Implementation based on: https://github.com/IST-DASLab/gptq
+    - Suitable for: Transformer models, large language models, and other
+      deep neural networks
+
+    **Note:** The quality of quantization depends heavily on the calibration
+    dataset. Use representative data that covers the expected input
+    distribution for best results.
+
+    Args:
+        dataset: The calibration dataset. It can be an iterable that yields
+            strings or pre-tokenized numerical tensors (e.g., a list of
+            strings, a generator, or a NumPy array). This data is used to
+            analyze the model's activations.
+        tokenizer: A `keras_nlp.Tokenizer` instance (or a similar callable)
+            that is used to process the `dataset` if it contains strings.
+        weight_bits: (int, optional) The number of bits to quantize weights to.
+            Defaults to 4.
+        num_samples: (int, optional) The number of calibration data samples to
+            use from the dataset. Defaults to 128.
+        sequence_length: (int, optional) The sequence length to use for each
+            calibration sample. Defaults to 512.
+        hessian_damping: (float, optional) The % of Hessian damping to use for
+            stabilization during inverse calculation. Defaults to 0.01.
+        group_size: (int, optional) The size of weight groups to quantize
+            together. A `group_size` of -1 indicates per-channel quantization.
+            Defaults to 128.
+        symmetric: (bool, optional) If `True`, uses symmetric quantization.
+            If `False`, uses asymmetric quantization. Defaults to `False`.
+        activation_order: (bool, optional) If `True`, reorders weight columns
+            based on activation magnitude, which can improve quantization
+            accuracy. Defaults to `False`.
+    """
+
+    def __init__(
+        self,
+        dataset,
+        tokenizer,
+        weight_bits: int = 4,
+        num_samples: int = 128,
+        sequence_length: int = 512,
+        hessian_damping: float = 0.01,
+        group_size: int = 128,
+        symmetric: bool = False,
+        activation_order: bool = False,
+    ):
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.num_samples = num_samples
+        self.sequence_length = sequence_length
+        self.hessian_damping = hessian_damping
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.symmetric = symmetric
+        self.activation_order = activation_order
+
+    def quantize(self, model):
+        """
+        Applies GPTQ quantization to the provided model using this
+        configuration.
+        """
+        logging.info("Initiating quantization from GPTQConfig...")
+        # The core logic is now delegated to gptqutils, which will handle
+        # the dynamic imports and data loading.
+        quantize_model(model=model, config=self)
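For readers unfamiliar with the `weight_bits`, `group_size`, and `symmetric` options documented above, the sketch below illustrates what those parameters control under plain uniform (affine) quantization. It is a generic NumPy illustration, not the `gptq_quant` implementation shipped in this wheel; `fake_quantize_group` is a hypothetical helper:

```python
import numpy as np

def fake_quantize_group(weights, weight_bits=4, symmetric=False):
    """Round-trip one group of weights through a uniform quantization grid."""
    maxq = 2**weight_bits - 1  # largest integer code, e.g. 15 for 4 bits
    if symmetric:
        # Symmetric: grid centred on zero, zero-point fixed at mid-grid.
        scale = 2 * np.abs(weights).max() / maxq
        zero = (maxq + 1) / 2
    else:
        # Asymmetric: grid spans the group's [min, max] range exactly.
        scale = (weights.max() - weights.min()) / maxq
        zero = np.round(-weights.min() / scale)
    q = np.clip(np.round(weights / scale) + zero, 0, maxq)
    return scale * (q - zero)  # dequantized (reconstructed) weights

rng = np.random.default_rng(0)
w = rng.normal(scale=0.05, size=256)  # weights feeding one output unit
group_size = 64                       # as in GPTQConfig(group_size=64)

# Each group of 64 weights gets its own scale and zero-point.
w_hat = np.concatenate(
    [
        fake_quantize_group(group, weight_bits=4, symmetric=False)
        for group in np.split(w, w.size // group_size)
    ]
)
print("max abs reconstruction error:", np.abs(w - w_hat).max())
```

Smaller groups (or `group_size=-1`, one set of parameters per channel/column) track local weight statistics more closely at the cost of storing more scales and zero-points, which is the trade-off the docstring's grouped and per-channel examples are describing.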