keras-nightly 3.14.0.dev2026011304__py3-none-any.whl → 3.14.0.dev2026011504__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/quantizers/__init__.py +1 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/quantizers/__init__.py +1 -0
- keras/src/backend/openvino/numpy.py +145 -7
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/layers/core/dense.py +122 -6
- keras/src/layers/core/einsum_dense.py +151 -7
- keras/src/layers/core/embedding.py +1 -1
- keras/src/layers/layer.py +5 -0
- keras/src/models/model.py +7 -3
- keras/src/ops/numpy.py +9 -0
- keras/src/quantizers/__init__.py +2 -0
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +1 -2
- keras/src/quantizers/gptq_core.py +1 -1
- keras/src/quantizers/quantization_config.py +14 -0
- keras/src/quantizers/quantizers.py +61 -52
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026011304.dist-info → keras_nightly-3.14.0.dev2026011504.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026011304.dist-info → keras_nightly-3.14.0.dev2026011504.dist-info}/RECORD +26 -23
- {keras_nightly-3.14.0.dev2026011304.dist-info → keras_nightly-3.14.0.dev2026011504.dist-info}/WHEEL +0 -0
- {keras_nightly-3.14.0.dev2026011304.dist-info → keras_nightly-3.14.0.dev2026011504.dist-info}/top_level.txt +0 -0
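
The headline change is new Activation-aware Weight Quantization (AWQ) support: an `AWQConfig`, an `AWQDTypePolicy`, the `awq.py`/`awq_config.py`/`awq_core.py` modules, and wiring in `Dense` and `EinsumDense`. The OpenVINO backend also gains `rot90`, `kron`, `size`, and `swapaxes` implementations plus a `flip` fix. Judging by the `__init__.py` diffs below, the new symbols should be importable from the public namespaces; a minimal smoke test (import paths inferred from the re-exports in this diff):

```python
from keras.quantizers import AWQConfig
from keras.dtype_policies import AWQDTypePolicy
```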
keras/_tf_keras/keras/dtype_policies/__init__.py
CHANGED
@@ -7,6 +7,9 @@ since your modifications would be overwritten.
 from keras.src.dtype_policies import deserialize as deserialize
 from keras.src.dtype_policies import get as get
 from keras.src.dtype_policies import serialize as serialize
+from keras.src.dtype_policies.dtype_policy import (
+    AWQDTypePolicy as AWQDTypePolicy,
+)
 from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy
 from keras.src.dtype_policies.dtype_policy import (
     FloatDTypePolicy as FloatDTypePolicy,
keras/_tf_keras/keras/quantizers/__init__.py
CHANGED
@@ -7,6 +7,7 @@ since your modifications would be overwritten.
 from keras.src.quantizers import deserialize as deserialize
 from keras.src.quantizers import get as get
 from keras.src.quantizers import serialize as serialize
+from keras.src.quantizers.awq_config import AWQConfig as AWQConfig
 from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
 from keras.src.quantizers.quantization_config import (
     Float8QuantizationConfig as Float8QuantizationConfig,
keras/dtype_policies/__init__.py
CHANGED
@@ -7,6 +7,9 @@ since your modifications would be overwritten.
 from keras.src.dtype_policies import deserialize as deserialize
 from keras.src.dtype_policies import get as get
 from keras.src.dtype_policies import serialize as serialize
+from keras.src.dtype_policies.dtype_policy import (
+    AWQDTypePolicy as AWQDTypePolicy,
+)
 from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy
 from keras.src.dtype_policies.dtype_policy import (
     FloatDTypePolicy as FloatDTypePolicy,
keras/quantizers/__init__.py
CHANGED
@@ -7,6 +7,7 @@ since your modifications would be overwritten.
 from keras.src.quantizers import deserialize as deserialize
 from keras.src.quantizers import get as get
 from keras.src.quantizers import serialize as serialize
+from keras.src.quantizers.awq_config import AWQConfig as AWQConfig
 from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
 from keras.src.quantizers.quantization_config import (
     Float8QuantizationConfig as Float8QuantizationConfig,
keras/src/backend/openvino/numpy.py
CHANGED
@@ -1126,22 +1126,28 @@ def expm1(x):
 
 def flip(x, axis=None):
     x_node = get_ov_output(x)
-
+
+    # Using OpenVINO tensor shape
+    ndim = len(x_node.get_partial_shape())
     if ndim is None:
         raise ValueError(
-            "The `flip` operation does not support tensors with dynamic rank"
+            "The `flip` operation does not support tensors with dynamic rank "
            "for the OpenVINO backend."
         )
+
     if axis is None:
         axis = list(range(ndim))
     elif isinstance(axis, int):
         axis = [axis]
+
     axis = [a + ndim if a < 0 else a for a in axis]
+
     begin = [0] * ndim
     end = [0] * ndim
     strides = [1] * ndim
     for a in axis:
         strides[a] = -1
+
     all_ones_mask = [1] * ndim
     result = ov_opset.strided_slice(
         data=x_node,
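
For context, the rewritten `flip` reads the rank from the partial shape and reverses each requested axis with a stride of -1. A minimal NumPy sketch of the same strided-reversal idea (illustrative only, not the backend code):

```python
import numpy as np

def flip_via_strides(x, axis=None):
    # Normalize `axis` exactly like the backend: default to all axes,
    # accept a bare int, and wrap negatives.
    ndim = x.ndim
    if axis is None:
        axis = list(range(ndim))
    elif isinstance(axis, int):
        axis = [axis]
    axis = [a + ndim if a < 0 else a for a in axis]

    # A step of -1 on an axis reverses it; other axes are left alone.
    index = tuple(
        slice(None, None, -1) if a in axis else slice(None)
        for a in range(ndim)
    )
    return x[index]

x = np.arange(6).reshape(2, 3)
assert (flip_via_strides(x, -1) == np.flip(x, -1)).all()
```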
@@ -1154,6 +1160,61 @@ def flip(x, axis=None):
     return OpenVINOKerasTensor(result.output(0))
 
 
+def rot90(array, k=1, axes=(0, 1)):
+    """Rotate an array by 90 degrees in the plane specified by axes."""
+    array = get_ov_output(array)
+
+    if not isinstance(axes, (tuple, list)) or len(axes) != 2:
+        raise ValueError("axes must be a tuple of length 2")
+
+    shape = array.get_partial_shape()
+    ndim = shape.rank.get_length()
+    if ndim is None:
+        raise ValueError(
+            "`rot90` does not support tensors with dynamic rank "
+            "for the OpenVINO backend."
+        )
+
+    axis1 = canonicalize_axis(axes[0], ndim)
+    axis2 = canonicalize_axis(axes[1], ndim)
+
+    if axis1 == axis2:
+        raise ValueError("axes must be different")
+
+    k = k % 4
+    if k == 0:
+        return OpenVINOKerasTensor(array)
+
+    result = array
+
+    for _ in range(k):
+        # 1. Transpose axis1 <-> axis2
+        perm = list(range(ndim))
+        perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
+        perm_const = ov_opset.constant(perm, Type.i32).output(0)
+        result = ov_opset.transpose(result, perm_const).output(0)
+
+        # 2. Reverse along axis1 using StridedSlice
+        begin = [0] * ndim
+        end = [0] * ndim
+        strides = [1] * ndim
+        strides[axis1] = -1
+
+        begin_mask = [1] * ndim
+        end_mask = [1] * ndim
+
+        result = ov_opset.strided_slice(
+            data=result,
+            begin=begin,
+            end=end,
+            strides=strides,
+            begin_mask=begin_mask,
+            end_mask=end_mask,
+        ).output(0)
+
+    return OpenVINOKerasTensor(result)
+
+
 def floor(x):
     x = get_ov_output(x)
     x_type = x.get_element_type()
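
Each quarter turn above is "swap the two plane axes, then reverse the first of them". A NumPy sketch checking that this decomposition matches `np.rot90` (illustrative only):

```python
import numpy as np

def rot90_by_steps(m, k=1, axes=(0, 1)):
    a1, a2 = axes
    for _ in range(k % 4):
        perm = list(range(m.ndim))
        perm[a1], perm[a2] = perm[a2], perm[a1]
        m = np.transpose(m, perm)  # swap the plane's axes
        m = np.flip(m, axis=a1)    # then reverse along the first one
    return m

m = np.arange(6).reshape(2, 3)
for k in range(4):
    assert (rot90_by_steps(m, k) == np.rot90(m, k)).all()
```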
@@ -1394,7 +1455,66 @@ def isreal(x):
 
 
 def kron(x1, x2):
-
+    x1 = get_ov_output(x1)
+    x2 = get_ov_output(x2)
+    x1, x2 = _align_operand_types(x1, x2, "kron()")
+    x1_shape = x1.get_partial_shape()
+    x2_shape = x2.get_partial_shape()
+    if x1_shape.rank.is_dynamic or x2_shape.rank.is_dynamic:
+        raise ValueError(
+            "`kron` does not support tensors with dynamic rank for "
+            "the OpenVINO backend."
+        )
+    ndim1 = x1_shape.rank.get_length()
+    ndim2 = x2_shape.rank.get_length()
+    if ndim1 < ndim2:
+        axes = ov_opset.range(
+            ov_opset.constant(0, Type.i32),
+            ov_opset.constant(ndim2 - ndim1, Type.i32),
+            ov_opset.constant(1, Type.i32),
+        )
+        x1 = ov_opset.unsqueeze(x1, axes)
+        ndim1 = ndim2
+    elif ndim2 < ndim1:
+        axes = ov_opset.range(
+            ov_opset.constant(0, Type.i32),
+            ov_opset.constant(ndim1 - ndim2, Type.i32),
+            ov_opset.constant(1, Type.i32),
+        )
+        x2 = ov_opset.unsqueeze(x2, axes)
+        ndim2 = ndim1
+    shape1 = ov_opset.shape_of(x1, Type.i32)
+    shape2 = ov_opset.shape_of(x2, Type.i32)
+    ones = ov_opset.broadcast(
+        ov_opset.constant(1, Type.i32), ov_opset.constant([ndim1], Type.i32)
+    )
+    axis = ov_opset.constant(1, Type.i32)
+    flatten = ov_opset.constant([-1], Type.i32)
+    unsqueezed_ones = ov_opset.unsqueeze(ones, axis)
+    x1_new_shape = ov_opset.reshape(
+        ov_opset.concat(
+            [ov_opset.unsqueeze(shape1, axis), unsqueezed_ones],
+            axis=1,
+        ),
+        flatten,
+        False,
+    )
+    x2_new_shape = ov_opset.reshape(
+        ov_opset.concat(
+            [unsqueezed_ones, ov_opset.unsqueeze(shape2, axis)],
+            axis=1,
+        ),
+        flatten,
+        False,
+    )
+    result = ov_opset.multiply(
+        ov_opset.reshape(x1, x1_new_shape, False),
+        ov_opset.reshape(x2, x2_new_shape, False),
+    )
+    result = ov_opset.reshape(
+        result, ov_opset.multiply(shape1, shape2), False
+    ).output(0)
+    return OpenVINOKerasTensor(result)
 
 
 def lcm(x1, x2):
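
The `kron` lowering interleaves each operand's dimensions with singleton dimensions ([d0, 1, d1, 1, ...] against [1, e0, 1, e1, ...]), broadcast-multiplies, and reshapes to the elementwise product of the two shapes. The same trick in NumPy (a sketch, not the backend code):

```python
import numpy as np

def kron_via_broadcast(a, b):
    assert a.ndim == b.ndim  # operands are rank-aligned first
    # Interleave dims with 1s so broadcasting forms every pairwise
    # product, then collapse each (d_i, e_i) pair into one axis.
    a_r = a.reshape([d for s in a.shape for d in (s, 1)])
    b_r = b.reshape([d for s in b.shape for d in (1, s)])
    out_shape = np.multiply(a.shape, b.shape)
    return (a_r * b_r).reshape(out_shape)

a = np.array([[1, 2], [3, 4]])
b = np.array([[0, 10], [20, 30]])
assert (kron_via_broadcast(a, b) == np.kron(a, b)).all()
```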
@@ -2226,7 +2346,14 @@ def sinh(x):
 
 
 def size(x):
-
+    x = get_ov_output(x)
+    shape_tensor = ov_opset.shape_of(x, output_type=Type.i64)
+    final_size = ov_opset.reduce_prod(
+        shape_tensor,
+        ov_opset.constant([0], Type.i64),
+        keep_dims=False,
+    )
+    return OpenVINOKerasTensor(final_size.output(0))
 
 
 def sort(x, axis=-1):
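
`size` is simply the product of the shape vector, which is what the ShapeOf + ReduceProd graph above computes; in NumPy terms:

```python
import numpy as np

x = np.zeros((2, 3, 4))
assert int(np.prod(x.shape)) == x.size == 24
```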
@@ -2368,9 +2495,20 @@ def std(x, axis=None, keepdims=False):
 
 
 def swapaxes(x, axis1, axis2):
-
-
-
+    x = get_ov_output(x)
+    x_shape = x.get_partial_shape()
+    if x_shape.rank.is_dynamic:
+        raise ValueError(
+            "`swapaxes` does not support tensors with dynamic rank for the "
+            "OpenVINO backend."
+        )
+    rank = x_shape.rank.get_length()
+    axis1 = canonicalize_axis(axis1, rank)
+    axis2 = canonicalize_axis(axis2, rank)
+    axes = list(range(rank))
+    axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
+    result = ov_opset.transpose(x, ov_opset.constant(axes, Type.i32))
+    return OpenVINOKerasTensor(result.output(0))
 
 
 def take(x, indices, axis=None):
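
`swapaxes` becomes a single Transpose whose permutation is the identity with the two entries exchanged; a NumPy sketch of the same permutation (illustrative only):

```python
import numpy as np

def swapaxes_via_transpose(x, axis1, axis2):
    # Identity permutation with axis1/axis2 exchanged -- the same
    # constant handed to the Transpose op above.
    axes = list(range(x.ndim))
    axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
    return np.transpose(x, axes)

x = np.arange(24).reshape(2, 3, 4)
assert (swapaxes_via_transpose(x, 0, 2) == np.swapaxes(x, 0, 2)).all()
```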
keras/src/dtype_policies/__init__.py
CHANGED
@@ -2,6 +2,7 @@ from keras.src import backend
 from keras.src.api_export import keras_export
 from keras.src.dtype_policies import dtype_policy
 from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES
+from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
 from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
 from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
@@ -10,6 +11,7 @@ from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
 from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 
 ALL_OBJECTS = {
+    AWQDTypePolicy,
     DTypePolicy,
     FloatDTypePolicy,
     QuantizedDTypePolicy,
keras/src/dtype_policies/dtype_policy.py
CHANGED
@@ -3,7 +3,7 @@ from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 
-QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")
+QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq", "awq")
 
 
 @keras_export(
@@ -376,6 +376,93 @@ class GPTQDTypePolicy(QuantizedDTypePolicy):
         return config
 
 
+@keras_export("keras.dtype_policies.AWQDTypePolicy")
+class AWQDTypePolicy(QuantizedDTypePolicy):
+    """Quantized dtype policy for AWQ quantization.
+
+    This policy helps propagate quantization settings for AWQ
+    when loading an AWQ quantized model in Keras format.
+
+    Args:
+        mode: The quantization mode. This should be a string in the format
+            `"awq/<weight_bits>/<group_size>"`.
+            - `"awq"`: The identifier for the quantization algorithm.
+            - `<weight_bits>`: Number of bits to quantize weights to.
+                AWQ presently only supports 4-bit quantization.
+            - `<group_size>`: The group size for quantization. Supported
+                values are -1 (for per-channel quantization) or any
+                positive integer.
+                Example: `"awq/4/128"`.
+        source_name: The source dtype policy name, e.g. "float32".
+    """
+
+    def __init__(
+        self,
+        mode,
+        source_name=None,
+    ):
+        parts = mode.split("/")
+        expected_format = "'awq/<weight_bits>/<group_size>'"
+
+        # Validate format.
+        if len(parts) != 3 or parts[0] != "awq":
+            raise ValueError(
+                "Invalid mode for AWQDTypePolicy. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # Validate and cast weight_bits and group_size.
+        try:
+            weight_bits = int(parts[1])
+            group_size = int(parts[2])
+        except ValueError:
+            raise ValueError(
+                "Invalid mode for AWQDTypePolicy. <weight_bits> and "
+                "<group_size> must be integers. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # AWQ presently only supports 4-bit quantization.
+        if weight_bits != 4:
+            raise ValueError(
+                "Invalid weight_bits in mode. AWQ only supports 4-bit "
+                f"quantization, but got {weight_bits} from '{mode}'."
+            )
+
+        if group_size < -1 or group_size == 0:
+            raise ValueError(
+                "Invalid group_size in mode. Supported values are "
+                "-1 (per-channel) or a positive integer, "
+                f"but got {group_size} from '{mode}'."
+            )
+
+        base_mode = parts[0]
+        super().__init__(
+            mode=base_mode,
+            source_name=source_name,
+        )
+
+        self._name = f"{mode}_from_{source_name}"
+        self.mode = base_mode
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+
+    def __eq__(self, other):
+        if super().__eq__(other) is False:
+            return False
+        return (
+            self.weight_bits == other.weight_bits
+            and self.group_size == other.group_size
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        # Reconstruct the full mode string for serialization
+        mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
+        config.update({"mode": mode})
+        return config
+
+
 @keras_export(
     [
         "keras.config.set_dtype_policy",
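
Based on the constructor above, a policy string like `"awq/4/128"` parses into a base mode plus the bit width and group size; a short usage sketch:

```python
from keras.dtype_policies import AWQDTypePolicy

policy = AWQDTypePolicy("awq/4/128", source_name="float32")
assert policy.mode == "awq"        # base mode after parsing
assert policy.weight_bits == 4     # AWQ is 4-bit only
assert policy.group_size == 128
assert policy.get_config()["mode"] == "awq/4/128"

# Malformed strings are rejected up front, e.g.:
# AWQDTypePolicy("awq/8/128")  -> ValueError (only 4-bit supported)
# AWQDTypePolicy("awq/4/0")    -> ValueError (group_size is -1 or positive)
```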
@@ -442,6 +529,8 @@ def _get_quantized_dtype_policy_by_str(policy):
         return QuantizedDTypePolicy(mode, source_name)
     elif policy.startswith("gptq"):
         return GPTQDTypePolicy(mode, source_name)
+    elif policy.startswith("awq"):
+        return AWQDTypePolicy(mode, source_name)
     elif policy.startswith("float8"):
         return QuantizedFloat8DTypePolicy(mode, source_name)
     else:
keras/src/layers/core/dense.py
CHANGED
@@ -128,7 +128,7 @@ class Dense(Layer):
             mode=self.quantization_mode,
             config=self.quantization_config,
         )
-        if self.quantization_mode not in ("int8", "int4", "gptq"):
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8 or int4, `self._kernel` will be
             # added in `self._int8_build` or `_int4_build`. Therefore, we skip
             # it here.
@@ -165,15 +165,17 @@ class Dense(Layer):
 
         mode = self.quantization_mode
         is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
         is_int4 = mode == "int4"
-
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
         gptq_bits = (
             gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
         )
 
         # Decide the source tensor first (packed vs already-quantized vs plain
         # kernel)
-        if is_gptq and
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
             # calibrated GPTQ, not 4-bit, no unpacking needed
             kernel = self.quantized_kernel
         else:
@@ -183,7 +185,15 @@ class Dense(Layer):
         # Handle int4 unpacking cases in one place
         if is_int4:
             kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
-        elif is_gptq and
+        elif is_gptq and gptq_calibrated and gptq_bits == 4:
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+        elif is_awq and awq_calibrated:
+            # AWQ always uses 4-bit quantization
             kernel = quantizers.unpack_int4(
                 self.quantized_kernel,
                 orig_len=self.units,
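
Both the GPTQ 4-bit and AWQ branches store two 4-bit values per byte and rely on `unpack_int4` with `orig_len` to trim the padding nibble. A NumPy sketch of that nibble packing round trip (an illustration of the idea, not necessarily Keras's exact bit layout):

```python
import numpy as np

def pack_nibbles(q):
    # q: uint8 codes in [0, 15]; pair rows two-per-byte (low nibble
    # first), padding odd lengths -- hence the (n + 1) // 2 storage.
    if q.shape[0] % 2:
        q = np.concatenate([q, np.zeros((1,) + q.shape[1:], np.uint8)])
    return (q[0::2] | (q[1::2] << 4)).astype(np.uint8)

def unpack_nibbles(packed, orig_len):
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    out = np.empty((2 * packed.shape[0],) + packed.shape[1:], np.uint8)
    out[0::2], out[1::2] = low, high
    return out[:orig_len]  # trim the padding, like orig_len above

q = np.random.randint(0, 16, size=(5, 3)).astype(np.uint8)
assert (unpack_nibbles(pack_nibbles(q), orig_len=5) == q).all()
```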
@@ -304,8 +314,9 @@ class Dense(Layer):
         if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # A saved GPTQ quantized model will always be calibrated.
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
         self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
 
         idx = 0
         for name in self.variable_serialization_spec[mode]:
@@ -395,6 +406,14 @@ class Dense(Layer):
             "kernel_zero",
             "g_idx",
         ],
+        "awq": [
+            "bias",
+            "quantized_kernel",
+            "kernel_scale",
+            "kernel_zero",
+            "awq_scales",
+            "g_idx",
+        ],
     }
 
     def quantized_build(self, kernel_shape, mode, config=None):
@@ -406,6 +425,8 @@ class Dense(Layer):
             self._float8_build()
         elif mode == "gptq":
             self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
@@ -515,6 +536,97 @@ class Dense(Layer):
             y = self.activation(y)
         return y
 
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+        self.kernel_shape = kernel_shape
+
+        # For 4-bit weights, we pack two values per byte.
+        units = (kernel_shape[1] + 1) // 2
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(units, kernel_shape[0]),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = (
+            1 if group_size == -1 else math.ceil(kernel_shape[0] / group_size)
+        )
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="kernel_zero",
+            shape=(self.units, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(kernel_shape[0],),
+            initializer="ones",
+            trainable=False,
+        )
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(kernel_shape[0],),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
+            )
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, units]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+        y = ops.matmul(inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
     def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
@@ -835,6 +947,8 @@ class Dense(Layer):
             self.kernel_scale.assign(kernel_scale)
         elif mode == "gptq":
             self.quantized_build(kernel_shape, mode, self.quantization_config)
+        elif mode == "awq":
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "float8":
             self.quantized_build(kernel_shape, mode)
         else:
@@ -847,6 +961,8 @@ class Dense(Layer):
         policy_name = mode
         if mode == "gptq":
             policy_name = self.quantization_config.dtype_policy_string()
+        elif mode == "awq":
+            policy_name = self.quantization_config.dtype_policy_string()
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
@@ -881,7 +997,7 @@ class Dense(Layer):
         `kernel_scale`: The quantization scale for the merged kernel.
             This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None
 
         kernel_value = self._kernel