keras-nightly 3.14.0.dev2026011304__py3-none-any.whl → 3.14.0.dev2026011504__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,6 +7,9 @@ since your modifications would be overwritten.
 from keras.src.dtype_policies import deserialize as deserialize
 from keras.src.dtype_policies import get as get
 from keras.src.dtype_policies import serialize as serialize
+from keras.src.dtype_policies.dtype_policy import (
+    AWQDTypePolicy as AWQDTypePolicy,
+)
 from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy
 from keras.src.dtype_policies.dtype_policy import (
     FloatDTypePolicy as FloatDTypePolicy,
@@ -7,6 +7,7 @@ since your modifications would be overwritten.
 from keras.src.quantizers import deserialize as deserialize
 from keras.src.quantizers import get as get
 from keras.src.quantizers import serialize as serialize
+from keras.src.quantizers.awq_config import AWQConfig as AWQConfig
 from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
 from keras.src.quantizers.quantization_config import (
     Float8QuantizationConfig as Float8QuantizationConfig,
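Together, these stubs expose the new AWQ surface next to the existing GPTQ one. A minimal sketch of the resulting public imports (the `AWQConfig` constructor arguments are not shown in this diff, so it is only imported here):

```python
# Sketch of the new public API surface added by this release.
from keras.quantizers import AWQConfig           # new quantization config
from keras.dtype_policies import AWQDTypePolicy  # new dtype policy

# The policy's mode-string grammar is defined later in this diff:
# "awq/<weight_bits>/<group_size>", e.g.
policy = AWQDTypePolicy(mode="awq/4/128", source_name="float32")
```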
@@ -1126,22 +1126,28 @@ def expm1(x):
 
 def flip(x, axis=None):
     x_node = get_ov_output(x)
-    ndim = x.ndim
+
+    # Using OpenVINO tensor shape
+    ndim = len(x_node.get_partial_shape())
     if ndim is None:
         raise ValueError(
-            "The `flip` operation does not support tensors with dynamic rank"
+            "The `flip` operation does not support tensors with dynamic rank "
             "for the OpenVINO backend."
         )
+
     if axis is None:
         axis = list(range(ndim))
     elif isinstance(axis, int):
         axis = [axis]
+
     axis = [a + ndim if a < 0 else a for a in axis]
+
     begin = [0] * ndim
     end = [0] * ndim
     strides = [1] * ndim
     for a in axis:
         strides[a] = -1
+
     all_ones_mask = [1] * ndim
     result = ov_opset.strided_slice(
         data=x_node,
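The rewritten `flip` derives the rank from the OpenVINO partial shape instead of the Keras tensor, then reverses the selected axes with a `strided_slice` whose stride is -1. A minimal NumPy sketch of the slicing identity the graph op encodes (illustrative only, not the backend code):

```python
import numpy as np

x = np.arange(6).reshape(2, 3)

# stride -1 on an axis reverses it, which is exactly what flip computes
assert (x[::-1, :] == np.flip(x, axis=0)).all()

# with axis=None every axis gets stride -1
assert (x[::-1, ::-1] == np.flip(x)).all()
```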
@@ -1154,6 +1160,61 @@ def flip(x, axis=None):
     return OpenVINOKerasTensor(result.output(0))
 
 
+def rot90(array, k=1, axes=(0, 1)):
+    """Rotate an array by 90 degrees in the plane specified by axes."""
+    array = get_ov_output(array)
+
+    if not isinstance(axes, (tuple, list)) or len(axes) != 2:
+        raise ValueError("axes must be a tuple of length 2")
+
+    shape = array.get_partial_shape()
+    ndim = shape.rank.get_length()
+    if ndim is None:
+        raise ValueError(
+            "`rot90` does not support tensors with dynamic rank "
+            "for the OpenVINO backend."
+        )
+
+    axis1 = canonicalize_axis(axes[0], ndim)
+    axis2 = canonicalize_axis(axes[1], ndim)
+
+    if axis1 == axis2:
+        raise ValueError("axes must be different")
+
+    k = k % 4
+    if k == 0:
+        return OpenVINOKerasTensor(array)
+
+    result = array
+
+    for _ in range(k):
+        # 1. Transpose axis1 <-> axis2
+        perm = list(range(ndim))
+        perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
+        perm_const = ov_opset.constant(perm, Type.i32).output(0)
+        result = ov_opset.transpose(result, perm_const).output(0)
+
+        # 2. Reverse along axis1 using StridedSlice
+        begin = [0] * ndim
+        end = [0] * ndim
+        strides = [1] * ndim
+        strides[axis1] = -1
+
+        begin_mask = [1] * ndim
+        end_mask = [1] * ndim
+
+        result = ov_opset.strided_slice(
+            data=result,
+            begin=begin,
+            end=end,
+            strides=strides,
+            begin_mask=begin_mask,
+            end_mask=end_mask,
+        ).output(0)
+
+    return OpenVINOKerasTensor(result)
+
+
 def floor(x):
     x = get_ov_output(x)
     x_type = x.get_element_type()
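Each 90-degree step in the new `rot90` is a transpose of the two plane axes followed by a reversal of the first of them, repeated `k % 4` times. The same decomposition checked in NumPy (illustrative only):

```python
import numpy as np

x = np.arange(6).reshape(2, 3)

# one step: swap the plane axes, then reverse the first of them
step = np.swapaxes(x, 0, 1)[::-1, :]
assert (step == np.rot90(x, k=1, axes=(0, 1))).all()
```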
@@ -1394,7 +1455,66 @@ def isreal(x):
 
 
 def kron(x1, x2):
-    raise NotImplementedError("`kron` is not supported with openvino backend")
+    x1 = get_ov_output(x1)
+    x2 = get_ov_output(x2)
+    x1, x2 = _align_operand_types(x1, x2, "kron()")
+    x1_shape = x1.get_partial_shape()
+    x2_shape = x2.get_partial_shape()
+    if x1_shape.rank.is_dynamic or x2_shape.rank.is_dynamic:
+        raise ValueError(
+            "`kron` does not support tensors with dynamic rank for "
+            "the OpenVINO backend."
+        )
+    ndim1 = x1_shape.rank.get_length()
+    ndim2 = x2_shape.rank.get_length()
+    if ndim1 < ndim2:
+        axes = ov_opset.range(
+            ov_opset.constant(0, Type.i32),
+            ov_opset.constant(ndim2 - ndim1, Type.i32),
+            ov_opset.constant(1, Type.i32),
+        )
+        x1 = ov_opset.unsqueeze(x1, axes)
+        ndim1 = ndim2
+    elif ndim2 < ndim1:
+        axes = ov_opset.range(
+            ov_opset.constant(0, Type.i32),
+            ov_opset.constant(ndim1 - ndim2, Type.i32),
+            ov_opset.constant(1, Type.i32),
+        )
+        x2 = ov_opset.unsqueeze(x2, axes)
+        ndim2 = ndim1
+    shape1 = ov_opset.shape_of(x1, Type.i32)
+    shape2 = ov_opset.shape_of(x2, Type.i32)
+    ones = ov_opset.broadcast(
+        ov_opset.constant(1, Type.i32), ov_opset.constant([ndim1], Type.i32)
+    )
+    axis = ov_opset.constant(1, Type.i32)
+    flatten = ov_opset.constant([-1], Type.i32)
+    unsqueezed_ones = ov_opset.unsqueeze(ones, axis)
+    x1_new_shape = ov_opset.reshape(
+        ov_opset.concat(
+            [ov_opset.unsqueeze(shape1, axis), unsqueezed_ones],
+            axis=1,
+        ),
+        flatten,
+        False,
+    )
+    x2_new_shape = ov_opset.reshape(
+        ov_opset.concat(
+            [unsqueezed_ones, ov_opset.unsqueeze(shape2, axis)],
+            axis=1,
+        ),
+        flatten,
+        False,
+    )
+    result = ov_opset.multiply(
+        ov_opset.reshape(x1, x1_new_shape, False),
+        ov_opset.reshape(x2, x2_new_shape, False),
+    )
+    result = ov_opset.reshape(
+        result, ov_opset.multiply(shape1, shape2), False
+    ).output(0)
+    return OpenVINOKerasTensor(result)
 
 
 def lcm(x1, x2):
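The new `kron` uses the standard reshape/broadcast construction: the two shape vectors are interleaved with ones (`d1, 1, d2, 1, ...` against `1, e1, 1, e2, ...`), the broadcasted product is taken, and the result is reshaped to the elementwise product of the input shapes. The same identity in NumPy (illustrative only):

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[0, 1], [1, 0]])

a_r = a.reshape(2, 1, 2, 1)         # shape1 interleaved with ones
b_r = b.reshape(1, 2, 1, 2)         # ones interleaved with shape2
manual = (a_r * b_r).reshape(4, 4)  # merge to shape1 * shape2

assert (manual == np.kron(a, b)).all()
```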
@@ -2226,7 +2346,14 @@ def sinh(x):
 
 
 def size(x):
-    raise NotImplementedError("`size` is not supported with openvino backend")
+    x = get_ov_output(x)
+    shape_tensor = ov_opset.shape_of(x, output_type=Type.i64)
+    final_size = ov_opset.reduce_prod(
+        shape_tensor,
+        ov_opset.constant([0], Type.i64),
+        keep_dims=False,
+    )
+    return OpenVINOKerasTensor(final_size.output(0))
 
 
 def sort(x, axis=-1):
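`size` now reduces the dynamic shape vector with a product instead of raising; in NumPy terms (illustrative only):

```python
import numpy as np

x = np.zeros((2, 3, 4))
assert np.prod(x.shape) == x.size == 24  # size == product of the shape
```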
@@ -2368,9 +2495,20 @@ def std(x, axis=None, keepdims=False):
 
 
 def swapaxes(x, axis1, axis2):
-    raise NotImplementedError(
-        "`swapaxes` is not supported with openvino backend"
-    )
+    x = get_ov_output(x)
+    x_shape = x.get_partial_shape()
+    if x_shape.rank.is_dynamic:
+        raise ValueError(
+            "`swapaxes` does not support tensors with dynamic rank for the "
+            "OpenVINO backend."
+        )
+    rank = x_shape.rank.get_length()
+    axis1 = canonicalize_axis(axis1, rank)
+    axis2 = canonicalize_axis(axis2, rank)
+    axes = list(range(rank))
+    axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
+    result = ov_opset.transpose(x, ov_opset.constant(axes, Type.i32))
+    return OpenVINOKerasTensor(result.output(0))
 
 
 def take(x, indices, axis=None):
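`swapaxes` becomes a single transpose whose permutation exchanges the two canonicalized axes. The equivalent check in NumPy (illustrative only):

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

perm = list(range(x.ndim))
perm[0], perm[2] = perm[2], perm[0]  # exchange axes 0 and 2

assert (np.transpose(x, perm) == np.swapaxes(x, 0, 2)).all()
```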
@@ -2,6 +2,7 @@ from keras.src import backend
 from keras.src.api_export import keras_export
 from keras.src.dtype_policies import dtype_policy
 from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES
+from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
 from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
 from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
@@ -10,6 +11,7 @@ from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
 from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 
 ALL_OBJECTS = {
+    AWQDTypePolicy,
     DTypePolicy,
     FloatDTypePolicy,
     QuantizedDTypePolicy,
@@ -3,7 +3,7 @@ from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 
-QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")
+QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq", "awq")
 
 
 @keras_export(
@@ -376,6 +376,93 @@ class GPTQDTypePolicy(QuantizedDTypePolicy):
         return config
 
 
+@keras_export("keras.dtype_policies.AWQDTypePolicy")
+class AWQDTypePolicy(QuantizedDTypePolicy):
+    """Quantized dtype policy for AWQ quantization.
+
+    This policy helps propagate quantization settings for AWQ
+    when loading an AWQ quantized model in Keras format.
+
+    Args:
+        mode: The quantization mode. This should be a string in the format
+            `"awq/<weight_bits>/<group_size>"`.
+            - `"awq"`: The identifier for the quantization algorithm.
+            - `<weight_bits>`: Number of bits to quantize weights to.
+                AWQ presently only supports 4-bit quantization.
+            - `<group_size>`: The group size for quantization. Supported
+                values are -1 (for per-channel quantization) or any
+                positive integer.
+            Example: `"awq/4/128"`.
+        source_name: The source dtype policy name, e.g. "float32".
+    """
+
+    def __init__(
+        self,
+        mode,
+        source_name=None,
+    ):
+        parts = mode.split("/")
+        expected_format = "'awq/<weight_bits>/<group_size>'"
+
+        # Validate format.
+        if len(parts) != 3 or parts[0] != "awq":
+            raise ValueError(
+                "Invalid mode for AWQDTypePolicy. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # Validate and cast weight_bits and group_size.
+        try:
+            weight_bits = int(parts[1])
+            group_size = int(parts[2])
+        except ValueError:
+            raise ValueError(
+                "Invalid mode for AWQDTypePolicy. <weight_bits> and "
+                "<group_size> must be integers. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # AWQ presently only supports 4-bit quantization.
+        if weight_bits != 4:
+            raise ValueError(
+                "Invalid weight_bits in mode. AWQ only supports 4-bit "
+                f"quantization, but got {weight_bits} from '{mode}'."
+            )
+
+        if group_size < -1 or group_size == 0:
+            raise ValueError(
+                "Invalid group_size in mode. Supported values are "
+                "-1 (per-channel) or a positive integer, "
+                f"but got {group_size} from '{mode}'."
+            )
+
+        base_mode = parts[0]
+        super().__init__(
+            mode=base_mode,
+            source_name=source_name,
+        )
+
+        self._name = f"{mode}_from_{source_name}"
+        self.mode = base_mode
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+
+    def __eq__(self, other):
+        if super().__eq__(other) is False:
+            return False
+        return (
+            self.weight_bits == other.weight_bits
+            and self.group_size == other.group_size
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        # Reconstruct the full mode string for serialization
+        mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
+        config.update({"mode": mode})
+        return config
+
+
 @keras_export(
     [
         "keras.config.set_dtype_policy",
@@ -442,6 +529,8 @@ def _get_quantized_dtype_policy_by_str(policy):
         return QuantizedDTypePolicy(mode, source_name)
     elif policy.startswith("gptq"):
         return GPTQDTypePolicy(mode, source_name)
+    elif policy.startswith("awq"):
+        return AWQDTypePolicy(mode, source_name)
     elif policy.startswith("float8"):
         return QuantizedFloat8DTypePolicy(mode, source_name)
     else:
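With the `startswith("awq")` branch in place, string-based policy lookup should resolve AWQ policies the same way it resolves GPTQ ones; a hedged sketch, assuming `keras.dtype_policies.get` splits the `_from_` suffix as it does for the other quantized modes:

```python
from keras import dtype_policies

# Dense builds policy names as f"{policy_name}_from_{source}" (see below).
policy = dtype_policies.get("awq/4/128_from_float32")
assert type(policy).__name__ == "AWQDTypePolicy"
```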
@@ -128,7 +128,7 @@ class Dense(Layer):
             mode=self.quantization_mode,
             config=self.quantization_config,
         )
-        if self.quantization_mode not in ("int8", "int4", "gptq"):
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
            # If the layer is quantized to int8 or int4, `self._kernel` will be
            # added in `self._int8_build` or `_int4_build`. Therefore, we skip
            # it here.
@@ -165,15 +165,17 @@ class Dense(Layer):
 
         mode = self.quantization_mode
         is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
         is_int4 = mode == "int4"
-        calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
         gptq_bits = (
             gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
         )
 
         # Decide the source tensor first (packed vs already-quantized vs plain
         # kernel)
-        if is_gptq and calibrated and gptq_bits != 4:
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
             # calibrated GPTQ, not 4-bit, no unpacking needed
             kernel = self.quantized_kernel
         else:
@@ -183,7 +185,15 @@ class Dense(Layer):
         # Handle int4 unpacking cases in one place
         if is_int4:
             kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
-        elif is_gptq and calibrated and gptq_bits == 4:
+        elif is_gptq and gptq_calibrated and gptq_bits == 4:
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+        elif is_awq and awq_calibrated:
+            # AWQ always uses 4-bit quantization
             kernel = quantizers.unpack_int4(
                 self.quantized_kernel,
                 orig_len=self.units,
@@ -304,8 +314,9 @@ class Dense(Layer):
         if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # A saved GPTQ quantized model will always be calibrated.
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
         self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
 
         idx = 0
         for name in self.variable_serialization_spec[mode]:
@@ -395,6 +406,14 @@ class Dense(Layer):
             "kernel_zero",
             "g_idx",
         ],
+        "awq": [
+            "bias",
+            "quantized_kernel",
+            "kernel_scale",
+            "kernel_zero",
+            "awq_scales",
+            "g_idx",
+        ],
     }
 
     def quantized_build(self, kernel_shape, mode, config=None):
@@ -406,6 +425,8 @@ class Dense(Layer):
             self._float8_build()
         elif mode == "gptq":
             self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
@@ -515,6 +536,97 @@ class Dense(Layer):
         y = self.activation(y)
         return y
 
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+        self.kernel_shape = kernel_shape
+
+        # For 4-bit weights, we pack two values per byte.
+        units = (kernel_shape[1] + 1) // 2
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(units, kernel_shape[0]),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = (
+            1 if group_size == -1 else math.ceil(kernel_shape[0] / group_size)
+        )
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="kernel_zero",
+            shape=(self.units, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(kernel_shape[0],),
+            initializer="ones",
+            trainable=False,
+        )
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(kernel_shape[0],),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
+            )
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, units]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+        y = ops.matmul(inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+
     def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
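The dequantization in `_awq_call` is an affine zero-point/scale map per group, followed by undoing the per-input-channel AWQ scaling. A NumPy sketch of that math, assuming `dequantize_with_sz_map` computes `(q - zero) * scale` with `g_idx` mapping each input channel to its group (the helper itself is not shown in this diff):

```python
import numpy as np

units, input_dim, group_size = 8, 16, 4
num_groups = input_dim // group_size

q = np.random.randint(0, 16, size=(units, input_dim)).astype(np.float32)
scale = np.random.rand(units, num_groups).astype(np.float32)
zero = np.random.randint(0, 16, size=(units, num_groups)).astype(np.float32)
g_idx = np.arange(input_dim) // group_size  # input channel -> group index
awq_scales = np.random.rand(input_dim).astype(np.float32) + 0.5

# assumed behavior of dequantize_with_sz_map: (q - zero) * scale per group
W = (q - zero[:, g_idx]) * scale[:, g_idx]  # [units, input_dim]

# transpose to [input_dim, units], then divide out the AWQ scales
W = W.T / awq_scales[:, None]

y = np.random.rand(2, input_dim).astype(np.float32) @ W
print(y.shape)  # (2, units)
```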
@@ -835,6 +947,8 @@ class Dense(Layer):
             self.kernel_scale.assign(kernel_scale)
         elif mode == "gptq":
             self.quantized_build(kernel_shape, mode, self.quantization_config)
+        elif mode == "awq":
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "float8":
             self.quantized_build(kernel_shape, mode)
         else:
@@ -847,6 +961,8 @@ class Dense(Layer):
         policy_name = mode
         if mode == "gptq":
             policy_name = self.quantization_config.dtype_policy_string()
+        elif mode == "awq":
+            policy_name = self.quantization_config.dtype_policy_string()
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
@@ -881,7 +997,7 @@ class Dense(Layer):
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None
 
         kernel_value = self._kernel