keras-nightly 3.14.0.dev2026010104__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/ops/__init__.py +2 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +2 -0
- keras/_tf_keras/keras/quantizers/__init__.py +1 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/ops/__init__.py +2 -0
- keras/ops/numpy/__init__.py +2 -0
- keras/quantizers/__init__.py +1 -0
- keras/src/backend/jax/nn.py +26 -9
- keras/src/backend/jax/numpy.py +10 -0
- keras/src/backend/numpy/numpy.py +15 -0
- keras/src/backend/openvino/numpy.py +338 -17
- keras/src/backend/tensorflow/numpy.py +24 -1
- keras/src/backend/tensorflow/rnn.py +17 -7
- keras/src/backend/torch/numpy.py +26 -0
- keras/src/backend/torch/rnn.py +28 -11
- keras/src/callbacks/orbax_checkpoint.py +75 -42
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/layers/core/dense.py +122 -6
- keras/src/layers/core/einsum_dense.py +151 -7
- keras/src/layers/core/embedding.py +1 -1
- keras/src/layers/core/reversible_embedding.py +10 -1
- keras/src/layers/layer.py +5 -0
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/losses/losses.py +24 -0
- keras/src/models/model.py +18 -9
- keras/src/ops/image.py +106 -93
- keras/src/ops/numpy.py +138 -0
- keras/src/quantizers/__init__.py +2 -0
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +1 -2
- keras/src/quantizers/gptq_core.py +1 -1
- keras/src/quantizers/quantization_config.py +14 -0
- keras/src/quantizers/quantizers.py +61 -52
- keras/src/random/seed_generator.py +2 -2
- keras/src/saving/orbax_util.py +50 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/utils/jax_layer.py +69 -31
- keras/src/utils/module_utils.py +11 -0
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +52 -48
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0

keras/src/quantizers/awq.py
@@ -0,0 +1,361 @@
+"""AWQ (Activation-aware Weight Quantization) algorithm implementation.
+
+AWQ protects salient weights by finding optimal per-channel scales based on
+activation magnitudes, then applies those scales before quantization.
+
+Reference: https://arxiv.org/abs/2306.00978
+"""
+
+import types
+
+from keras.src import ops
+from keras.src.layers import Dense
+from keras.src.layers import EinsumDense
+from keras.src.quantizers.quantizers import compute_quantization_parameters
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.quantizers.quantizers import dequantize_with_zero_point
+from keras.src.quantizers.quantizers import quantize_with_sz_map
+from keras.src.quantizers.quantizers import quantize_with_zero_point
+
+
+def awq_search_optimal_scales(
+    weights,
+    activation_magnitudes,
+    *,
+    num_grid_points=20,
+    group_size=-1,
+):
+    """Search for optimal AWQ scales using grid search.
+
+    The AWQ algorithm finds scaling factors that protect salient weights.
+    For each channel, we search for an optimal ratio in [0, 1] that minimizes
+    the activation-weighted quantization error.
+
+    The key insight: we MULTIPLY weights by scales before quantization to
+    expand salient weights. This ensures quantization noise is small relative
+    to the expanded weight magnitude. During inference, we divide by scales
+    to restore the original magnitude.
+
+    Scale formula: scales = x_max.pow(ratio).clamp(min=1e-4)
+    Loss function: Activation-weighted MSE (approximates output error)
+
+    Args:
+        weights: Weight tensor [out_features, in_features] (transposed kernel).
+        activation_magnitudes: Per-channel activation magnitudes [in_features].
+        num_grid_points: Number of grid search points. Defaults to 20.
+        group_size: Group size for quantization (-1 for per-channel).
+
+    Returns:
+        best_scales: Optimal per-channel scales [in_features].
+    """
+    in_features = ops.shape(weights)[1]
+
+    # Compute per-channel activation magnitudes (x_max)
+    # activations should already be per-channel max magnitudes
+    x_max = ops.cast(activation_magnitudes, "float32")
+    # Avoid zero or very small values
+    x_max = ops.where(ops.less(x_max, 1e-8), ops.ones_like(x_max), x_max)
+
+    best_loss = None
+    best_scales = ops.ones((in_features,), dtype="float32")
+
+    # Grid search over ratio values from 0 to 1
+    for i in range(num_grid_points + 1):
+        ratio = i / num_grid_points
+
+        # Compute scales: x_max^ratio (clipped to avoid numerical issues)
+        if ratio == 0:
+            scales = ops.ones_like(x_max)
+        else:
+            scales = ops.power(x_max, ratio)
+            scales = ops.maximum(scales, 1e-4)
+
+        # Normalize scales to avoid extreme values
+        scale_mean = ops.sqrt(ops.multiply(ops.max(scales), ops.min(scales)))
+        scale_mean = ops.maximum(scale_mean, 1e-8)
+        scales = ops.divide(scales, scale_mean)
+
+        # Apply scales to weights by MULTIPLYING (expand salient weights)
+        # weights_scaled: [out_features, in_features]
+        weights_scaled = ops.multiply(weights, scales)
+
+        if group_size == -1:
+            # Per-channel quantization (no grouping)
+            scale_q, zero_q, maxq = compute_quantization_parameters(
+                weights_scaled,
+                bits=4,
+                symmetric=False,
+                per_channel=True,
+                group_size=-1,
+                compute_dtype="float32",
+            )
+
+            # Quantize and dequantize
+            quantized = quantize_with_zero_point(
+                weights_scaled, scale_q, zero_q, maxq
+            )
+            dequantized = dequantize_with_zero_point(quantized, scale_q, zero_q)
+        else:
+            # Grouped quantization - use proper per-row grouping
+            scale_q, zero_q, maxq = compute_quantization_parameters(
+                weights_scaled,
+                bits=4,
+                symmetric=False,
+                per_channel=True,
+                group_size=group_size,
+                compute_dtype="float32",
+            )
+
+            # Compute group indices: maps each input feature to its group
+            g_idx = ops.cast(ops.arange(0, in_features) // group_size, "int32")
+
+            # Quantize and dequantize using group index mapping
+            quantized = quantize_with_sz_map(
+                weights_scaled, scale_q, zero_q, g_idx, maxq
+            )
+            dequantized = dequantize_with_sz_map(
+                quantized, scale_q, zero_q, g_idx
+            )
+
+        # Scale back down by DIVIDING to restore original magnitude
+        reconstructed = ops.divide(dequantized, scales)
+
+        # Compute activation-weighted MSE loss
+        # This approximates the output error: ||W*X - W_hat*X||^2
+        # by weighting each channel's error by x_max^2
+        weight_error = ops.square(ops.subtract(weights, reconstructed))
+        # Weight by activation magnitudes squared (broadcast over out_features)
+        weighted_error = ops.multiply(weight_error, ops.square(x_max))
+        loss = ops.mean(weighted_error)
+
+        # Track best
+        if best_loss is None:
+            best_loss = loss
+            best_scales = scales
+        else:
+            is_better = ops.less(loss, best_loss)
+            if is_better:
+                best_loss = loss
+                best_scales = scales
+
+    return best_scales
+
+
+def awq_quantize_matrix(
+    weights_transpose,
+    activation_magnitudes,
+    *,
+    num_grid_points=20,
+    group_size=-1,
+):
+    """Quantize a weight matrix using AWQ.
+
+    This function performs the complete AWQ quantization process:
+    1. Find optimal per-channel scales via grid search
+    2. Apply scales to weights
+    3. Compute quantization parameters
+    4. Quantize weights
+
+    Args:
+        weights_transpose: Weight matrix [out_features, in_features].
+        activation_magnitudes: Per-channel activation magnitudes [in_features].
+        num_grid_points: Number of grid search points.
+        group_size: Group size for quantization.
+
+    Returns:
+        quantized_weights: Quantized weights [out_features, in_features].
+        scales: Quantization scales [out_features, num_groups].
+        zeros: Zero points [out_features, num_groups].
+        awq_scales: AWQ per-channel scales [in_features].
+        g_idx: Group indices [in_features].
+    """
+    in_features = ops.shape(weights_transpose)[1]
+
+    # Step 1: Find optimal AWQ scales via grid search
+    awq_scales = awq_search_optimal_scales(
+        weights_transpose,
+        activation_magnitudes,
+        num_grid_points=num_grid_points,
+        group_size=group_size,
+    )
+
+    # Step 2: Apply AWQ scales by MULTIPLYING (expand salient weights)
+    # weights_scaled: [out_features, in_features]
+    weights_scaled = ops.multiply(weights_transpose, awq_scales)
+
+    if group_size == -1:
+        # Per-channel quantization (no grouping)
+        scale_q, zero_q, maxq = compute_quantization_parameters(
+            weights_scaled,
+            bits=4,
+            symmetric=False,
+            per_channel=True,
+            group_size=-1,
+            compute_dtype="float32",
+        )
+
+        # Quantize
+        quantized = quantize_with_zero_point(
+            weights_scaled, scale_q, zero_q, maxq
+        )
+
+        # Build group indices (all 0s for per-channel)
+        g_idx = ops.zeros((in_features,), dtype="float32")
+    else:
+        # Grouped quantization - use proper per-row grouping
+        scale_q, zero_q, maxq = compute_quantization_parameters(
+            weights_scaled,
+            bits=4,
+            symmetric=False,
+            per_channel=True,
+            group_size=group_size,
+            compute_dtype="float32",
+        )
+
+        # Compute group indices: maps each input feature to its group
+        g_idx = ops.cast(ops.arange(0, in_features) // group_size, "int32")
+
+        # Quantize using group index mapping
+        quantized = quantize_with_sz_map(
+            weights_scaled, scale_q, zero_q, g_idx, maxq
+        )
+
+        # Convert g_idx to float for storage
+        g_idx = ops.cast(g_idx, "float32")
+
+    return quantized, scale_q, zero_q, awq_scales, g_idx
+
+
+class AWQ:
+    """AWQ quantizer for a single layer.
+
+    This class accumulates activation statistics during calibration and
+    performs AWQ quantization on layer weights.
+
+    The AWQ algorithm works by:
+    1. Collecting per-channel maximum activation magnitudes
+    2. Using activation magnitudes to determine weight saliency
+    3. Finding optimal per-channel scales via grid search
+    4. Applying scales before quantization to protect salient weights
+
+    Args:
+        layer: The layer to quantize (Dense or EinsumDense).
+        config: AWQConfig instance with quantization parameters.
+    """
+
+    def __init__(self, layer, config=None):
+        from keras.src.quantizers.awq_config import AWQConfig
+
+        self.original_layer = layer
+        self.config = config or AWQConfig(dataset=None, tokenizer=None)
+        self.num_samples = 0
+
+        # Handle Dense and EinsumDense layers
+        if isinstance(layer, Dense) or (
+            isinstance(layer, EinsumDense) and layer.kernel.ndim == 2
+        ):
+            self.kernel_shape = layer.kernel.shape
+            self.rows = self.kernel_shape[0]  # in_features
+            self.columns = self.kernel_shape[1]  # out_features
+            self.layer = layer
+        elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:
+            # Handle 3D EinsumDense layers (typically from attention blocks)
+            self.kernel_shape = layer.kernel.shape
+            shape = list(self.kernel_shape)
+            d_model_dim_index = shape.index(max(shape))
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                self.rows = in_features
+                self.columns = heads * head_dim
+            elif d_model_dim_index in [1, 2]:  # Attention Output case
+                heads, head_dim, out_features = shape
+                self.rows = heads * head_dim
+                self.columns = out_features
+            else:
+                raise ValueError(
+                    f"Cannot determine dimensions for EinsumDense kernel "
+                    f"shape {shape}"
+                )
+
+            # Create a temporary object that holds a reshaped 2D version
+            self.layer = types.SimpleNamespace(
+                kernel=ops.reshape(layer.kernel, (self.rows, self.columns)),
+            )
+        else:
+            raise TypeError(f"Unsupported layer type for AWQ: {type(layer)}")
+
+        # Initialize activation magnitude accumulator (per-channel max)
+        self.activation_magnitudes = ops.zeros((self.rows,), dtype="float32")
+
+    def update_activation_magnitudes(self, input_batch):
+        """Update per-channel activation magnitude statistics.
+
+        This method tracks the maximum absolute activation value for each
+        input channel across all calibration batches.
+
+        Args:
+            input_batch: Input activations tensor [batch, ..., in_features].
+        """
+        if input_batch is None:
+            raise ValueError("Input tensor cannot be None.")
+        if ops.size(input_batch) == 0:
+            raise ValueError("Input tensor cannot be empty.")
+
+        # Flatten to [batch_samples, in_features]
+        if len(input_batch.shape) > 2:
+            input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))
+
+        x = ops.cast(input_batch, "float32")
+
+        # Compute per-channel max absolute value for this batch
+        batch_max = ops.max(ops.abs(x), axis=0)
+
+        # Update running max
+        self.activation_magnitudes = ops.maximum(
+            self.activation_magnitudes, batch_max
+        )
+        self.num_samples = self.num_samples + int(ops.shape(x)[0])
+
+    def quantize_layer(self):
+        """Perform AWQ quantization on the layer.
+
+        This method:
+        1. Runs the AWQ grid search to find optimal scales
+        2. Quantizes the layer weights
+        3. Updates the layer's quantized variables
+        """
+        from keras.src import quantizers
+
+        weights_matrix = ops.transpose(self.layer.kernel)
+
+        # Perform AWQ quantization
+        quantized, scale, zero, awq_scales, g_idx = awq_quantize_matrix(
+            weights_matrix,
+            self.activation_magnitudes,
+            num_grid_points=self.config.num_grid_points,
+            group_size=self.config.group_size,
+        )
+
+        # Cast to uint8 for storage
+        # quantized is already [out_features, in_features]
+        quantized = ops.cast(quantized, "uint8")
+
+        # Pack to 4-bit along axis 0 (output features)
+        quantized_packed, _, _ = quantizers.pack_int4(
+            quantized, axis=0, dtype="uint8"
+        )
+
+        # Assign to layer variables
+        del self.original_layer._kernel
+        self.original_layer.quantized_kernel.assign(quantized_packed)
+        self.original_layer.kernel_scale.assign(scale)
+        self.original_layer.kernel_zero.assign(zero)
+        self.original_layer.awq_scales.assign(awq_scales)
+        self.original_layer.g_idx.assign(g_idx)
+        self.original_layer.is_awq_calibrated = True
+
+    def free(self):
+        """Free memory used by the quantizer."""
+        del self.activation_magnitudes
+        del self.layer
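
To make the grid search above concrete, here is a minimal standalone sketch of the same idea in NumPy. The `fake_quant` and `search_scales` helpers are illustrative stand-ins, not the Keras API: `fake_quant` mimics a per-output-channel asymmetric 4-bit quantize/dequantize round trip in place of `compute_quantization_parameters` / `quantize_with_zero_point`, while the scale formula and activation-weighted loss follow the `awq_search_optimal_scales` docstring.

```python
import numpy as np


def fake_quant(w, n_bits=4):
    """Per-output-channel asymmetric quantize/dequantize round trip."""
    maxq = 2**n_bits - 1
    w_min = w.min(axis=1, keepdims=True)
    w_max = w.max(axis=1, keepdims=True)
    scale = np.maximum((w_max - w_min) / maxq, 1e-8)
    zero = np.round(-w_min / scale)
    q = np.clip(np.round(w / scale) + zero, 0, maxq)
    return (q - zero) * scale


def search_scales(weights, x_max, num_grid_points=20):
    """Grid-search per-channel scales minimizing activation-weighted MSE."""
    best_loss, best_scales = None, np.ones_like(x_max)
    for i in range(num_grid_points + 1):
        ratio = i / num_grid_points
        scales = np.maximum(x_max**ratio, 1e-4)
        scales = scales / np.sqrt(scales.max() * scales.min())  # normalize
        # Expand salient channels, fake-quantize, then undo the scaling.
        reconstructed = fake_quant(weights * scales) / scales
        # Activation-weighted MSE, as described in the docstring above.
        loss = np.mean((weights - reconstructed) ** 2 * x_max**2)
        if best_loss is None or loss < best_loss:
            best_loss, best_scales = loss, scales
    return best_scales


rng = np.random.default_rng(0)
weights = rng.normal(size=(8, 16))            # [out_features, in_features]
x_max = np.abs(rng.normal(size=(16,))) + 0.1  # per-channel activation maxima
print(search_scales(weights, x_max).shape)    # -> (16,)
```

Running this on random data returns a `(16,)` vector of per-channel scales; channels with larger activation maxima receive larger scales at higher ratios, which shrinks their relative quantization error, matching the intent of the Keras implementation above.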
keras/src/quantizers/awq_config.py
@@ -0,0 +1,140 @@
+from keras.src.api_export import keras_export
+from keras.src.quantizers.quantization_config import QuantizationConfig
+
+
+@keras_export("keras.quantizers.AWQConfig")
+class AWQConfig(QuantizationConfig):
+    """Configuration class for AWQ (Activation-aware Weight Quantization).
+
+    AWQ is a post-training quantization method that identifies and protects
+    salient weights based on activation magnitudes. It applies per-channel
+    scaling before quantization to minimize accuracy loss.
+
+    Methodology:
+    1. Collects activation statistics from calibration data
+    2. Identifies salient weight channels based on activation magnitudes
+    3. Searches for optimal per-channel scaling factors via grid search
+    4. Applies scaling before quantization to protect important weights
+
+    References:
+    - Original AWQ paper: "AWQ: Activation-aware Weight Quantization for
+      LLM Compression and Acceleration" (https://arxiv.org/abs/2306.00978)
+    - Reference implementation: https://github.com/mit-han-lab/llm-awq
+
+    Args:
+        dataset: The calibration dataset. It can be an iterable that yields
+            strings or pre-tokenized numerical tensors (e.g., a list of
+            strings, a generator, or a NumPy array). This data is used to
+            analyze activation patterns.
+        tokenizer: A tokenizer instance (or a similar callable) that is used
+            to process the `dataset`.
+        weight_bits: The number of bits for weight quantization. AWQ presently
+            only supports 4-bit quantization. Defaults to 4.
+        num_samples: The number of calibration data samples to use from the
+            dataset. Defaults to 128.
+        sequence_length: The sequence length to use for each calibration
+            sample. Defaults to 512.
+        group_size: The size of weight groups to quantize together. A
+            `group_size` of -1 indicates per-channel quantization.
+            Defaults to 128.
+        num_grid_points: The number of grid search points for finding optimal
+            per-channel scales. Higher values may find better scales but
+            take longer. Defaults to 20.
+        quantization_layer_structure: A dictionary defining the model's
+            quantization structure. It should contain:
+            - "pre_block_layers": list of layers to run before the first
+              block (e.g., embedding layer).
+            - "sequential_blocks": list of transformer blocks to quantize
+              sequentially.
+            If not provided, the model must implement
+            `get_quantization_layer_structure`.
+
+    Example:
+    ```python
+    from keras.quantizers import AWQConfig
+
+    # Create configuration for 4-bit AWQ quantization
+    config = AWQConfig(
+        dataset=calibration_data,  # Your calibration dataset
+        tokenizer=your_tokenizer,  # Tokenizer for text data
+        num_samples=128,  # Number of calibration samples
+        sequence_length=512,  # Sequence length for each sample
+        group_size=128,  # Weight grouping for quantization
+        num_grid_points=20,  # Grid search points for scale search
+    )
+
+    # Apply quantization to your model
+    model.quantize("awq", config=config)
+    ```
+
+    """
+
+    def __init__(
+        self,
+        dataset,
+        tokenizer,
+        *,
+        weight_bits: int = 4,
+        num_samples: int = 128,
+        sequence_length: int = 512,
+        group_size: int = 128,
+        num_grid_points: int = 20,
+        quantization_layer_structure: dict = None,
+    ):
+        super().__init__()
+        # AWQ only supports 4-bit quantization
+        if weight_bits != 4:
+            raise ValueError(
+                f"AWQ only supports 4-bit quantization. "
+                f"Received weight_bits={weight_bits}."
+            )
+        if num_samples <= 0:
+            raise ValueError("num_samples must be a positive integer.")
+        if sequence_length <= 0:
+            raise ValueError("sequence_length must be a positive integer.")
+        if group_size < -1 or group_size == 0:
+            raise ValueError(
+                "Invalid group_size. Supported values are -1 (per-channel) "
+                f"or a positive integer, but got {group_size}."
+            )
+        if num_grid_points <= 0:
+            raise ValueError("num_grid_points must be a positive integer.")
+
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.weight_bits = weight_bits
+        self.num_samples = num_samples
+        self.sequence_length = sequence_length
+        self.group_size = group_size
+        self.num_grid_points = num_grid_points
+        self.quantization_layer_structure = quantization_layer_structure
+
+    @property
+    def mode(self):
+        return "awq"
+
+    def dtype_policy_string(self):
+        """Returns the dtype policy string for this configuration.
+
+        Returns:
+            A string representing the dtype policy, e.g. "awq/4/128".
+        """
+        return f"awq/{self.weight_bits}/{self.group_size}"
+
+    def get_config(self):
+        return {
+            # Dataset and Tokenizer are only required for one-time
+            # calibration and are not saved in the config.
+            "dataset": None,
+            "tokenizer": None,
+            "weight_bits": self.weight_bits,
+            "num_samples": self.num_samples,
+            "sequence_length": self.sequence_length,
+            "group_size": self.group_size,
+            "num_grid_points": self.num_grid_points,
+            "quantization_layer_structure": self.quantization_layer_structure,
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
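
For reference, a small usage sketch against the `AWQConfig` added above, assuming a keras-nightly build that includes this diff. The one-item `dataset` list and `tokenizer=None` placeholder are illustrative only; a real calibration run would supply a tokenizer and calibration dataset as described in the docstring.

```python
from keras.quantizers import AWQConfig

config = AWQConfig(
    dataset=["example calibration text"],
    tokenizer=None,  # placeholder; pass a real tokenizer for calibration
    group_size=128,
    num_grid_points=20,
)

print(config.mode)                   # "awq"
print(config.dtype_policy_string())  # "awq/4/128"

# get_config() intentionally drops the dataset and tokenizer, so a
# round-tripped config only carries the numeric hyperparameters.
restored = AWQConfig.from_config(config.get_config())
print(restored.group_size, restored.num_grid_points)  # 128 20
```

The quantized model itself would then be produced with `model.quantize("awq", config=config)`, as shown in the docstring example above.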