keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/quantizers/gptq.py
CHANGED
@@ -1,14 +1,282 @@
+import types
+
 from keras.src import ops
+from keras.src import quantizers
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
-from keras.src.
+from keras.src.ops import linalg
+from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.quantizers import GPTQQuantizer
+from keras.src.quantizers.quantizers import compute_quantization_parameters
+from keras.src.quantizers.quantizers import dequantize_with_zero_point
+from keras.src.quantizers.quantizers import quantize_with_zero_point
+
+
+def _stable_permutation(metric):
+    """Return a stable permutation that sorts `metric` in descending order.
+    Uses an index-based jitter to break ties deterministically."""
+    n = ops.shape(metric)[0]
+    idx = ops.arange(0, n, dtype="int32")
+    # tiny jitter = (idx / n) * 1e-12 so it never flips a real strict ordering
+    jitter = ops.divide(ops.cast(idx, "float32"), ops.cast(n, "float32"))
+    metric_jittered = ops.add(metric, ops.multiply(jitter, 1e-12))
+    # argsort by negative to get descending
+    return ops.argsort(ops.negative(metric_jittered))
+
+
+def gptq_quantize_matrix(
+    weights_transpose,
+    inv_hessian,
+    *,
+    blocksize=128,
+    group_size=-1,
+    activation_order=False,
+    order_metric=None,
+    compute_scale_zero=compute_quantization_parameters,
+):
+    """
+    Implements the GPTQ error correction updates.
+
+    For a single column update (column j):
+        e = invH[j, j] * (w_j - q_j)
+        W[:, j+1:] -= e * invH[j, j+1:]
+    where:
+        - w_j is the original column,
+        - q_j is the quantized column,
+        - invH is the inverse Hessian,
+        - e is the propagated error term.
+
+    Across entire blocks:
+        W[:, future] -= E_block * invH[block, future]
+    where:
+        - E_block is the quantization error accumulated for the current block,
+        - invH[block, future] denotes the cross-block slice of the inverse Hessian,
+        - W[:, future] are the columns yet to be quantized.
+
+    Args:
+        weights_transpose: Transposed weight matrix [out_features, in_features]
+            to quantize.
+        inv_hessian: Inverse Hessian matrix [in_features, in_features] for
+            error propagation.
+        blocksize: Size of the blocks to process (default: 128).
+        group_size: Size of the groups for parameter reuse
+            (default: -1, no grouping).
+        activation_order: Whether to apply activation-order permutation
+            (default: False).
+        order_metric: Metric for ordering features
+            (default: None, uses 1 / diag(invH)).
+        compute_scale_zero: Function to compute scale and zero for
+            quantization.
+
+    Returns:
+        quantized_weights: Quantized weight matrix [out_features, in_features].
+        scale: float32. Scale parameters for quantization
+            [out_features, num_groups].
+        zero: Zero-point parameters for quantization [out_features, num_groups].
+        g_idx: int32. Group indices for each feature [in_features].
+    """
+    in_features = ops.shape(weights_transpose)[1]
+
+    if activation_order:
+        # Use 1 / diag(inverse_hessian) as importance proxy by default.
+        if order_metric is None:
+            order_metric = ops.reciprocal(
+                ops.add(ops.diagonal(inv_hessian), 1e-12)
+            )
+        else:
+            # sanitize provided metric
+            order_metric = ops.cast(order_metric, "float32")
+            order_metric = ops.where(
+                ops.isfinite(order_metric),
+                order_metric,
+                ops.zeros_like(order_metric),
+            )
+        # Sort in descending order by importance
+        perm = _stable_permutation(order_metric)
+        inv_perm = ops.argsort(perm)
+
+        weights_transpose = ops.take(weights_transpose, perm, axis=1)
+        inv_hessian = ops.take(
+            ops.take(inv_hessian, perm, axis=0), perm, axis=1
+        )
+    else:
+        perm = inv_perm = None
+
+    # weights_buffer: [out_features, in_features]
+    weights_buffer = weights_transpose
+    # Buffer for the final quantized matrix: [out_features, in_features]
+    quantized_weights_buffer = ops.zeros_like(weights_transpose, dtype="int32")
+
+    scale_chunks = []
+    zero_chunks = []
+
+    # Compute effective group size
+    effective_group = in_features if group_size == -1 else group_size
+
+    # Process features in blocks
+    for block_start in range(0, in_features, blocksize):
+        block_end = min(block_start + blocksize, in_features)
+        block_size = block_end - block_start
+
+        # Block views
+        # block_weights: [out_features, block_size]
+        block_weights = weights_buffer[:, block_start:block_end]
+        # block_error: [out_features, block_size]
+        block_error = ops.zeros_like(block_weights)
+        # block_inv_hessian: [block_size, block_size]
+        block_inv_hessian = inv_hessian[
+            block_start:block_end, block_start:block_end
+        ]
+
+        # Per-group cached params for reuse within the group
+        cached_scale = None
+        cached_zero = None
+        cached_maxq = None
+        cached_group_start = -1
+
+        for block_idx in range(block_size):
+            # Current global column index, represents the original column
+            # in the weight matrix
+            global_idx = block_start + block_idx
+            # weight_column: [out_features,]
+            weight_column = block_weights[:, block_idx]
+            # Group-wise parameter reuse (compute once per group)
+            if not effective_group == in_features:  # group_size != -1
+                # Determine the group start index for the current column
+                group_start = (global_idx // effective_group) * effective_group
+                if group_start != cached_group_start:
+                    # New group encountered, compute & cache params
+                    # for this group
+                    group_end = min(group_start + effective_group, in_features)
+                    group_slice = weights_buffer[:, group_start:group_end]
+                    cached_scale, cached_zero, cached_maxq = compute_scale_zero(
+                        group_slice
+                    )
+                    # Store params once per group (in the order encountered).
+                    scale_chunks.append(cached_scale)
+                    zero_chunks.append(cached_zero)
+                    cached_group_start = group_start
+                scale, zero, maxq = cached_scale, cached_zero, cached_maxq
+            else:
+                # Single global group covering all columns.
+                if cached_scale is None:
+                    cached_scale, cached_zero, cached_maxq = compute_scale_zero(
+                        weights_buffer
+                    )
+                    scale_chunks.append(cached_scale)
+                    zero_chunks.append(cached_zero)
+                    cached_group_start = 0
+                scale, zero, maxq = cached_scale, cached_zero, cached_maxq
+
+            # Quantize column and store it.
+            # quantized_column: [out_features, 1]
+            quantized_column = quantize_with_zero_point(
+                ops.expand_dims(weight_column, 1), scale, zero, maxq
+            )
+
+            # Store quantized column in the buffer.
+            quantized_weights_buffer = ops.slice_update(
+                quantized_weights_buffer,
+                (0, global_idx),
+                ops.cast(quantized_column, "int32"),
+            )
+            # Dequantize column to compute error.
+            # dequantized_col: [out_features,]
+            dequantized_col = dequantize_with_zero_point(
+                quantized_column, scale, zero
+            )[:, 0]
+            # Error feedback for remaining columns within the block
+            # block_inv_hessian_diag: scalar
+            current_block_influence = block_inv_hessian[block_idx, block_idx]
+            # We divide by current_block_influence to get the
+            # correct scaling of the error term.
+            err = ops.divide(
+                ops.subtract(weight_column, dequantized_col),
+                current_block_influence,
+            )
+            # Record error for propagation to future blocks
+            block_error = ops.slice_update(
+                block_error, (0, block_idx), ops.expand_dims(err, 1)
+            )
+
+            # Update remaining columns in the current block
+            # (those before the current column have already been quantized)
+            # Propagate error to remaining columns in the block.
+            if block_idx < block_size - 1:
+                # update: [out_features, block_size - block_idx - 1]
+                update = ops.matmul(
+                    ops.expand_dims(err, 1),
+                    ops.expand_dims(
+                        block_inv_hessian[block_idx, block_idx + 1 :], 0
+                    ),
+                )
+                # tail is a view of the remaining columns in the block
+                # to be updated
+                # tail: [out_features, block_size - block_idx - 1]
+                tail = block_weights[:, block_idx + 1 :]
+                block_weights = ops.slice_update(
+                    block_weights,
+                    (0, block_idx + 1),
+                    ops.subtract(tail, update),
+                )
+
+        # Propagate block errors to future features (beyond the block)
+        if block_end < in_features:
+            # Total update for all future columns, based on the
+            # accumulated error in this block. This is calculated
+            # as the matrix product of the block_error and the
+            # relevant slice of the inverse Hessian.
+            # total_update: [out_features, in_features - block_end]
+            total_update = ops.matmul(
+                block_error, inv_hessian[block_start:block_end, block_end:]
+            )
+            # Update the remaining weights in the buffer. This is done
+            # by subtracting the total_update from the remaining columns.
+            weights_buffer = ops.concatenate(
+                [
+                    weights_buffer[:, :block_end],
+                    ops.subtract(weights_buffer[:, block_end:], total_update),
+                ],
+                axis=1,
+            )
+
+    # Build group indices for each (possibly permuted) column
+    # base_group = effective_group (int)
+    base_group = effective_group
+
+    # g_idx in permuted domain
+    g_idx = ops.arange(0, in_features, dtype="int32")
+    g_idx = ops.divide(g_idx, base_group)
+    g_idx = ops.cast(g_idx, "float32")
+
+    # Map group indices and quantized weights back to original column order
+    if activation_order:
+        g_idx = ops.take(g_idx, inv_perm, axis=0)
+        quantized_weights_buffer = ops.take(
+            quantized_weights_buffer, inv_perm, axis=1
+        )
+
+    # Concatenate recorded group params
+    if len(scale_chunks) == 0:
+        # Edge case: no groups recorded (empty input); fall back to whole matrix
+        s, z, _ = compute_scale_zero(weights_transpose)
+        scale = s
+        zero = z
+    else:
+        scale = ops.concatenate(scale_chunks, axis=1)
+        zero = ops.concatenate(zero_chunks, axis=1)
+
+    return quantized_weights_buffer, scale, zero, g_idx
 
 
 class GPTQ:
-    def __init__(self, layer):
+    def __init__(self, layer, config=GPTQConfig(tokenizer=None, dataset=None)):
         self.original_layer = layer
         self.num_samples = 0
-        self.
+        self.config = config
+        self.quantizer = GPTQQuantizer(
+            config, compute_dtype=layer.variable_dtype
+        )
 
         # Explicitly handle each supported layer type
         if isinstance(layer, Dense) or (
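Illustrative aside (not part of the package diff): `gptq_quantize_matrix` above permutes columns by importance and later undoes the permutation with `inv_perm = ops.argsort(perm)`; the inverse of a permutation produced by argsort is itself an argsort. A minimal NumPy sketch of that round trip, using float64 toy values so the 1e-12 tie-breaking jitter is actually representable:

import numpy as np

# Toy importance metric with a tie between columns 0 and 2.
metric = np.array([0.5, 0.9, 0.5, 0.1])
n = metric.size
# Index-based jitter, as in _stable_permutation: deterministic tie-breaking.
jitter = (np.arange(n) / n) * 1e-12
perm = np.argsort(-(metric + jitter))   # descending by importance -> [1 2 0 3]
inv_perm = np.argsort(perm)             # inverse permutation       -> [2 0 1 3]

# Permuting columns and then applying the inverse restores the original order,
# which is how the quantized weights and g_idx are mapped back.
W = np.arange(12.0).reshape(3, 4)
assert np.array_equal(W[:, perm][:, inv_perm], W)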
@@ -16,21 +284,18 @@ class GPTQ:
         ):
             # For a standard Dense layer, the dimensions are straightforward.
             self.kernel_shape = layer.kernel.shape
-
-            self.
-
+            # rows: [input_features]
+            self.rows = self.kernel_shape[0]
+            # columns: [output_features]
+            self.columns = self.kernel_shape[1]
+            self.layer = layer
 
         # Handle 3D EinsumDense layers (typically from attention blocks).
         elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:
             # For EinsumDense, we determine the effective 2D dimensions.
             self.kernel_shape = layer.kernel.shape
             shape = list(self.kernel_shape)
-
-                d_model_dim_index = shape.index(max(shape))
-            except ValueError:
-                raise TypeError(
-                    f"Could not determine hidden dimension from shape {shape}"
-                )
+            d_model_dim_index = shape.index(max(shape))
 
             if d_model_dim_index == 0:  # QKV projection case
                 in_features, heads, head_dim = shape
@@ -47,17 +312,9 @@ class GPTQ:
 
             # Create a temporary object that holds a reshaped
             # 2D version of the kernel.
-            self.layer =
-
-
-                {
-                    "kernel": ops.reshape(
-                        layer.kernel, (self.rows, self.columns)
-                    ),
-                    "bias": layer.bias,
-                },
-            )()
-
+            self.layer = types.SimpleNamespace(
+                kernel=ops.reshape(layer.kernel, (self.rows, self.columns)),
+            )
         else:
             # Raise an error if the layer is not supported.
             raise TypeError(f"Unsupported layer type for GPTQ: {type(layer)}")
@@ -86,53 +343,55 @@ class GPTQ:
         pre-initialized Hessian matrix `self.hessian`.
         """
         if input_batch is None:
-            raise ValueError("Input tensor
+            raise ValueError("Input tensor cannot be None.")
 
         if len(input_batch.shape) < 2:
             raise ValueError(
-
-                f"(
-                f"{len(input_batch.shape)}."
+                "Input tensor must have rank >= 2 "
+                f"(got rank {len(input_batch.shape)})."
             )
         if ops.size(input_batch) == 0:
-            raise ValueError("Input tensor
-
+            raise ValueError("Input tensor cannot be empty.")
         if len(input_batch.shape) > 2:
+            # [batch, features]
             input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))
-
+        x = ops.cast(input_batch, "float32")
+
+        num_new_samples = ops.shape(x)[0]
+        num_prev_samples = self.num_samples
+        total_samples = ops.add(num_prev_samples, num_new_samples)
 
-        if self.hessian
+        if ops.shape(self.hessian)[0] != ops.shape(x)[-1]:
             raise ValueError(
-                f"Hessian dimensions ({self.hessian
-                "match input features ({
+                f"Hessian dimensions ({ops.shape(self.hessian)[0]}) do not "
+                f"match input features ({ops.shape(x)[-1]})."
             )
 
-
-
+        # gram_matrix: [features, features]
+        gram_matrix = ops.matmul(ops.transpose(x), x)
+        # Ensures numerical stability and symmetry in case of large floating
+        # point activations.
+        gram_matrix = ops.divide(
+            ops.add(gram_matrix, ops.transpose(gram_matrix)), 2.0
         )
 
-
-
-
-
-
-        current_hessian_weight = ops.divide(
-            input_batch.shape[0], total_samples
+        # Decay previous mean and add current per-sample contribution
+        # (factor 2/N)
+        if self.num_samples > 0:
+            self.hessian = ops.multiply(
+                self.hessian, ops.divide(num_prev_samples, total_samples)
             )
 
-
-
-
-
+        self.hessian = ops.add(
+            self.hessian,
+            ops.multiply(ops.divide(2.0, total_samples), gram_matrix),
+        )
 
-        self.num_samples =
+        self.num_samples = self.num_samples + ops.shape(x)[0] or 0
 
-    def
+    def quantize_and_correct_layer(
         self,
         blocksize=128,
-        hessian_damping=0.01,
-        group_size=-1,
-        activation_order=False,
     ):
         """
         Performs GPTQ quantization and correction on the layer's weights.
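Illustrative aside (not part of the package diff): the rewritten Hessian accumulation above maintains a running mean, decaying the previous estimate by n_prev / (n_prev + b) and adding (2 / (n_prev + b)) * X^T X for each calibration batch of b rows, so after all batches it equals (2 / N) * sum of the per-batch Gram matrices. A small NumPy check of that equivalence, with toy batch shapes chosen here for illustration:

import numpy as np

rng = np.random.default_rng(0)
batches = [rng.normal(size=(16, 8)) for _ in range(5)]

# Streaming update, mirroring the diff: decay the previous mean, then add
# the current batch's (2 / total)-scaled Gram matrix.
H = np.zeros((8, 8))
n_prev = 0
for X in batches:
    total = n_prev + X.shape[0]
    gram = X.T @ X
    gram = (gram + gram.T) / 2.0          # symmetrize, as in the diff
    if n_prev > 0:
        H *= n_prev / total
    H += (2.0 / total) * gram
    n_prev = total

# One-shot reference: 2/N times the Gram matrix of all calibration rows.
X_all = np.concatenate(batches, axis=0)
H_ref = 2.0 * (X_all.T @ X_all) / X_all.shape[0]
print(np.allclose(H, H_ref))              # True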
@@ -143,66 +402,42 @@ class GPTQ:
         quantization error by updating the remaining weights.
 
         The algorithm follows these main steps:
-        1.
+        1. Initialization: It optionally reorders the weight columns based
         on activation magnitudes (`activation_order=True`) to protect more
         salient
         weights.
-        2.
+        2. Hessian Modification: The Hessian matrix, pre-computed from
         calibration data, is dampened to ensure its invertibility and
         stability.
-        3.
+        3. Iterative Quantization: The function iterates through the
         weight columns in blocks (`blocksize`). In each iteration, it:
             a. Quantizes one column.
             b. Calculates the quantization error.
             c. Updates the remaining weights in the *current* block by
             distributing the error, using the inverse Hessian.
-        4.
+        4. Block-wise Correction: After a block is quantized, the total
         error from that block is propagated to the *next* block of weights
         to be processed.
-        5.
+        5. Finalization: The quantized weights are reordered back if
         `activation_order` was used, and the layer's weights are updated.
-
         This implementation is based on the official GPTQ paper and repository.
         For more details, see:
         - Paper: https://arxiv.org/abs/2210.17323
         - Original Code: https://github.com/IST-DASLab/gptq
 
+
         Args:
             blocksize: (int, optional) The size of the weight block to process
                 at a time. Defaults to 128.
-            hessian_damping: (float, optional) The percentage of dampening to
-                add the
-                Hessian's diagonal. A value of 0.01 is recommended.
-                Defaults to 0.01.
-            group_size: (int, optional) The number of weights that share the
-                same quantization parameters (scale and zero-point).
-                A value of -1 indicates per-channel quantization.
-            activation_order: (bool, optional) If True, reorders weight columns
-                based
-                on their activation's second-order information.
         """
-
-        weights_matrix = ops.transpose(ops.cast(self.layer.kernel, "float32"))
-        hessian_matrix = ops.cast(self.hessian, "float32")
-
-        if activation_order:
-            permutation = ops.argsort(
-                ops.negative(ops.diagonal(hessian_matrix))
-            )
-            weights_matrix = ops.take(weights_matrix, permutation, axis=1)
-            hessian_matrix = ops.take(
-                ops.take(hessian_matrix, permutation, axis=0),
-                permutation,
-                axis=1,
-            )
-            inverse_permutation = ops.argsort(permutation)
+        weights_matrix = ops.transpose(self.layer.kernel)
 
         # Dampen the Hessian for Stability
-        hessian_diagonal = ops.diagonal(
+        hessian_diagonal = ops.diagonal(self.hessian)
         dead_diagonal = ops.equal(hessian_diagonal, 0.0)
         hessian_diagonal = ops.where(dead_diagonal, 1.0, hessian_diagonal)
         hessian_matrix = ops.add(
-
+            self.hessian,
             ops.diag(
                 ops.where(dead_diagonal, 1.0, ops.zeros_like(hessian_diagonal))
             ),
@@ -210,7 +445,7 @@ class GPTQ:
 
         # Add dampening factor to the Hessian diagonal
         damping_factor = ops.multiply(
-            hessian_damping, ops.mean(hessian_diagonal)
+            self.config.hessian_damping, ops.mean(hessian_diagonal)
        )
         hessian_diagonal = ops.add(hessian_diagonal, damping_factor)
         hessian_matrix = ops.add(
@@ -221,130 +456,34 @@ class GPTQ:
         )
 
         # Compute the inverse Hessian, which is used for error correction
-        inverse_hessian =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # Process one column at a time within the block
-            for col_idx in range(block_size):
-                weight_column = block_weights[:, col_idx]
-                diagonal_element = block_inverse_hessian[col_idx, col_idx]
-
-                if group_size != -1:
-                    if ops.mod(ops.add(block_start, col_idx), group_size) == 0:
-                        self.quantizer.find_params(
-                            weights_matrix[
-                                :,
-                                (ops.add(block_start, col_idx)) : (
-                                    ops.add(
-                                        ops.add(block_start, col_idx),
-                                        group_size,
-                                    )
-                                ),
-                            ],
-                            weight=True,
-                        )
-                else:
-                    self.quantizer.find_params(
-                        ops.expand_dims(weight_column, 1), weight=True
-                    )
-
-                # Quantize the current weight column
-                quantized_column = dequantize(
-                    ops.expand_dims(weight_column, 1),
-                    self.quantizer.scale,
-                    self.quantizer.zero,
-                    self.quantizer.maxq,
-                )[:, 0]
-
-                block_quantized = ops.slice_update(
-                    block_quantized,
-                    (0, col_idx),
-                    ops.expand_dims(quantized_column, axis=1),
-                )
-                quantization_error = ops.divide(
-                    ops.subtract(weight_column, quantized_column),
-                    diagonal_element,
-                )
-                block_errors = ops.slice_update(
-                    block_errors,
-                    (0, col_idx),
-                    ops.expand_dims(quantization_error, axis=1),
-                )
-
-                if ops.less(col_idx, ops.subtract(block_size, 1)):
-                    error_update = ops.matmul(
-                        ops.expand_dims(quantization_error, 1),
-                        ops.expand_dims(
-                            block_inverse_hessian[
-                                col_idx, ops.add(col_idx, 1) :
-                            ],
-                            0,
-                        ),
-                    )
-
-                    # Efficiently update the remaining part of the
-                    # block_weights tensor.
-                    slice_to_update = block_weights[:, ops.add(col_idx, 1) :]
-                    updated_slice = ops.subtract(slice_to_update, error_update)
-                    block_weights = ops.slice_update(
-                        block_weights, (0, ops.add(col_idx, 1)), updated_slice
-                    )
-
-            # Update the full quantized matrix with the processed block
-            quantized_weights = ops.concatenate(
-                [
-                    quantized_weights[:, :block_start],
-                    block_quantized,
-                    quantized_weights[:, block_end:],
-                ],
-                axis=1,
-            )
-
-            if block_end < self.rows:
-                total_error_update = ops.matmul(
-                    block_errors,
-                    inverse_hessian[block_start:block_end, block_end:],
-                )
-                weights_matrix = ops.concatenate(
-                    [
-                        weights_matrix[:, :block_end],
-                        ops.subtract(
-                            weights_matrix[:, block_end:], total_error_update
-                        ),
-                    ],
-                    axis=1,
-                )
-
-        if activation_order:
-            quantized_weights = ops.take(
-                quantized_weights, inverse_permutation, axis=1
-            )
-
-        quantized_weights = ops.transpose(quantized_weights)
+        inverse_hessian = linalg.inv(hessian_matrix)
+
+        quantized, scale, zero, g_idx = gptq_quantize_matrix(
+            weights_matrix,
+            inv_hessian=inverse_hessian,
+            blocksize=blocksize,
+            group_size=self.config.group_size,
+            activation_order=self.config.activation_order,
+            order_metric=ops.diagonal(hessian_matrix),
+            compute_scale_zero=self.quantizer.find_params,
+        )
+        quantized = ops.cast(
+            quantized, self.original_layer.quantized_kernel.dtype
+        )
 
-        if
-
-
+        if self.config.weight_bits == 4:
+            # For 4-bit weights, we need to pack them into bytes
+            quantized, _, _ = quantizers.pack_int4(
+                quantized, axis=0, dtype="uint8"
             )
 
-
-
-
-
-
-        self.original_layer.
+        del self.original_layer._kernel
+        self.original_layer.quantized_kernel.assign(quantized)
+        self.original_layer.kernel_scale.assign(scale)
+        self.original_layer.kernel_zero.assign(zero)
+        self.original_layer.g_idx.assign(g_idx)
+        self.original_layer.is_gptq_calibrated = True
 
     def free(self):
-        self.hessian
+        del self.hessian
+        del self.layer