keras-nightly 3.14.0.dev2026011304__py3-none-any.whl → 3.14.0.dev2026011504__py3-none-any.whl

keras/src/layers/core/einsum_dense.py CHANGED
@@ -180,7 +180,7 @@ class EinsumDense(Layer):
         # quantized to int8 or int4, because `quantized_build` has created the
         # appropriate kernel variable. For other modes (e.g., float8 or no
         # quantization), we still need the floating-point kernel.
-        if self.quantization_mode not in ("int8", "int4", "gptq"):
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8, `self._kernel` will be added
             # in `self._int8_build`. Therefore, we skip it here.
             self._kernel = self.add_weight(
@@ -219,15 +219,17 @@ class EinsumDense(Layer):
 
         mode = self.quantization_mode
         is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
         is_int4 = mode == "int4"
-        calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
         gptq_bits = (
             gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
         )
 
         # Decide the source tensor first (packed vs already-quantized vs plain
         # kernel)
-        if is_gptq and calibrated and gptq_bits != 4:
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
             # calibrated GPTQ, not 4-bit, no unpacking needed
             kernel = self.quantized_kernel
         else:
@@ -241,13 +243,21 @@ class EinsumDense(Layer):
                 self._orig_length_along_pack_axis,
                 self._int4_pack_axis,
             )
-        elif is_gptq and calibrated and gptq_bits == 4:
+        elif is_gptq and gptq_calibrated and gptq_bits == 4:
             kernel = quantizers.unpack_int4(
                 self.quantized_kernel,
                 orig_len=self.gptq_unpacked_column_size,
                 axis=0,
                 dtype="uint8",
             )
+        elif is_awq and awq_calibrated:
+            # AWQ always uses 4-bit quantization
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.awq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )
 
         # Apply LoRA if enabled
         if self.lora_enabled:
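The new `awq` branch reuses the same int4 unpacking path as calibrated 4-bit GPTQ. As a reading aid, here is a minimal NumPy sketch of what two-per-byte unpacking does; the nibble order (low nibble first) and the padding trim are assumptions for illustration, not necessarily the exact layout `quantizers.unpack_int4` implements.

```python
import numpy as np

def unpack_int4_sketch(packed, orig_len, axis=0):
    # Each uint8 holds two 4-bit values: a low nibble and a high nibble.
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # Interleave low/high along `axis`, then trim any trailing padding nibble.
    stacked = np.stack([low, high], axis=axis + 1)
    new_shape = list(packed.shape)
    new_shape[axis] *= 2
    unpacked = stacked.reshape(new_shape)
    return np.take(unpacked, np.arange(orig_len), axis=axis)

packed = np.array([[0x21, 0x43]], dtype=np.uint8)  # nibbles (1, 2) and (3, 4)
print(unpack_int4_sketch(packed, orig_len=3, axis=1))  # [[1 2 3]]
```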
@@ -362,8 +372,9 @@ class EinsumDense(Layer):
         if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # A saved GPTQ quantized model will always be calibrated.
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
         self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
 
         idx = 0
         for name in self.variable_serialization_spec[mode]:
@@ -459,6 +470,14 @@ class EinsumDense(Layer):
                 "kernel_zero",
                 "g_idx",
             ],
+            "awq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "awq_scales",
+                "g_idx",
+            ],
         }
 
     def quantized_build(self, kernel_shape, mode, config=None):
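The spec above doubles as the on-disk variable order, which is why a loaded GPTQ/AWQ model is unconditionally marked calibrated two hunks earlier. A hypothetical sketch of that positional loading loop follows; `store` stands in for the flat list of saved arrays and is not a real Keras API.

```python
AWQ_SPEC = [
    "bias",
    "quantized_kernel",
    "kernel_scale",
    "kernel_zero",
    "awq_scales",
    "g_idx",
]

def assign_in_spec_order(layer, store):
    # Mirror of the `idx` / `for name in ...` loop shown in the hunk above.
    idx = 0
    for name in AWQ_SPEC:
        variable = getattr(layer, name, None)
        if variable is None:  # e.g. the layer was built without a bias
            continue
        variable.assign(store[idx])
        idx += 1
```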
@@ -470,6 +489,8 @@ class EinsumDense(Layer):
             self._float8_build()
         elif mode == "gptq":
             self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
@@ -616,6 +637,127 @@ class EinsumDense(Layer):
             y = self.activation(y)
         return y
 
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+
+        self.original_kernel_shape = kernel_shape
+        if len(kernel_shape) == 2:
+            rows = kernel_shape[0]
+            columns = kernel_shape[1]
+        elif len(kernel_shape) == 3:
+            shape = list(self.original_kernel_shape)
+            d_model_dim_index = shape.index(max(shape))
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                rows, columns = (
+                    in_features,
+                    heads * head_dim,
+                )
+            elif d_model_dim_index in [1, 2]:  # Attention output case
+                heads, head_dim, out_features = shape
+                rows, columns = (
+                    heads * head_dim,
+                    out_features,
+                )
+            else:
+                raise ValueError("Could not determine row/column split.")
+        else:
+            raise ValueError("AWQ quantization only supports 2D or 3D kernels.")
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = 1 if group_size == -1 else math.ceil(rows / group_size)
+
+        self.awq_unpacked_column_size = columns
+
+        # For 4-bit weights, we pack two values per byte.
+        kernel_columns = (columns + 1) // 2
+
+        self._set_quantization_info()
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(kernel_columns, rows),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(columns, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="zero_point",
+            shape=(columns, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes.
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(rows,),
+            initializer="ones",
+            trainable=False,
+        )
+
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(rows,),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for the AWQ-quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack the 4-bit weights.
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.awq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using the scale/zero maps.
+            W = dequantize_with_sz_map(
+                W,
+                self.kernel_scale,
+                self.kernel_zero,
+                self.g_idx,
+            )
+            W = ops.transpose(W)
+
+            # Apply AWQ scales by dividing to restore the original magnitude
+            # (we multiplied by the scales before quantization, so divide to
+            # undo). `awq_scales` has shape [input_dim] and `W` has shape
+            # [input_dim, out_dim]; expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+            W = ops.reshape(W, self.original_kernel_shape)
+
+        y = ops.einsum(self.equation, inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
     def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
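To make the shape bookkeeping in `_awq_build` and `_awq_call` concrete, here is a NumPy sketch with illustrative dimensions. It assumes `g_idx` holds the group index of each input row and inlines what `dequantize_with_sz_map` presumably computes; treat it as a reading aid, not the library's exact math.

```python
import math
import numpy as np

rows, columns, group_size = 4096, 11008, 128  # in/out features of a 2D kernel
num_groups = 1 if group_size == -1 else math.ceil(rows / group_size)  # 32
kernel_columns = (columns + 1) // 2  # 5504: two 4-bit values per uint8

# Variable shapes created by `_awq_build`:
#   quantized_kernel:         (kernel_columns, rows), uint8, packed nibbles
#   kernel_scale/kernel_zero: (columns, num_groups)
#   awq_scales, g_idx:        (rows,)

def dequantize_sketch(q, scale, zero, g_idx, awq_scales):
    # q: (columns, rows) uint8 after unpacking along axis 0.
    groups = g_idx.astype(np.int64)  # assumed: group id per input row
    w = (q.astype(np.float32) - zero[:, groups]) * scale[:, groups]
    w = w.T  # -> (rows, columns), matching ops.transpose in `_awq_call`
    # Undo the pre-quantization boost of salient input channels.
    return w / awq_scales[:, None]
```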
@@ -1010,7 +1152,7 @@ class EinsumDense(Layer):
         self.quantization_config = config
 
         kernel_shape = self._kernel.shape
-        if mode in ("int8", "int4", "gptq"):
+        if mode in ("int8", "int4", "gptq", "awq"):
            self._set_quantization_info()
 
         if mode == "int8":
@@ -1058,6 +1200,8 @@ class EinsumDense(Layer):
         policy_name = mode
         if mode == "gptq":
             policy_name = self.quantization_config.dtype_policy_string()
+        elif mode == "awq":
+            policy_name = self.quantization_config.dtype_policy_string()
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
@@ -1121,7 +1265,7 @@ class EinsumDense(Layer):
            This is `None` if the layer is not quantized.
         """
         # If not a quantized layer, return the full-precision kernel directly.
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None
 
         # If quantized but LoRA is not enabled, return the original quantized
keras/src/layers/core/embedding.py CHANGED
@@ -514,7 +514,7 @@ class Embedding(Layer):
            `embeddings_scale`: The quantization scale for the merged
                embeddings. This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.embeddings, None
 
         embeddings_value = self._embeddings
keras/src/layers/layer.py CHANGED
@@ -1337,6 +1337,8 @@ class Layer(BackendLayer, Operation):
             return self._int4_call(*args, **kwargs)
         elif self.quantization_mode == "gptq":
             return self._gptq_call(*args, **kwargs)
+        elif self.quantization_mode == "awq":
+            return self._awq_call(*args, **kwargs)
         else:
             raise self._quantization_mode_error(self.quantization_mode)
 
@@ -1352,6 +1354,9 @@ class Layer(BackendLayer, Operation):
     def _gptq_call(self, *args, **kwargs):
         raise self._not_implemented_error(self._gptq_call)
 
+    def _awq_call(self, *args, **kwargs):
+        raise self._not_implemented_error(self._awq_call)
+
     def _not_implemented_error(self, attr, msg=None):
         if callable(attr):
             attr_name = attr.__name__
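The two `layer.py` hunks follow the pattern already used for the other modes: `quantized_call` dispatches on `quantization_mode`, and the base class supplies a stub that raises until a concrete layer (here `EinsumDense`) overrides `_awq_call`. A condensed sketch of the pattern with illustrative names; the real code uses an explicit `elif` chain rather than `getattr` lookup.

```python
class QuantizedDispatchSketch:
    quantization_mode = "awq"

    def quantized_call(self, *args, **kwargs):
        # Route the call to the handler for the active quantization mode.
        handler = getattr(self, f"_{self.quantization_mode}_call", None)
        if handler is None:
            raise ValueError(f"Unsupported mode: {self.quantization_mode}")
        return handler(*args, **kwargs)

    def _awq_call(self, *args, **kwargs):
        # Base-class stub; concrete layers override this.
        raise NotImplementedError("_awq_call is not implemented for this layer.")
```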
keras/src/models/model.py CHANGED
@@ -9,6 +9,7 @@ from keras.src import utils
 from keras.src.api_export import keras_export
 from keras.src.layers.layer import Layer
 from keras.src.models.variable_mapping import map_saveable_variables
+from keras.src.quantizers.awq_core import awq_quantize
 from keras.src.quantizers.gptq_core import gptq_quantize
 from keras.src.quantizers.utils import should_quantize_layer
 from keras.src.saving import saving_api
@@ -547,7 +548,7 @@ class Model(Trainer, base_trainer.Trainer, Layer):
         except AttributeError:
             pass
 
-        if mode == "gptq":
+        if mode in ["gptq", "awq"]:
             # Resolve model structure.
             # 1. If quantization_layer_structure is provided inside the config,
             #    use that.
@@ -559,14 +560,17 @@ class Model(Trainer, base_trainer.Trainer, Layer):
 
             if structure is None:
                 raise ValueError(
-                    "For 'gptq' mode, a valid quantization structure must be "
+                    f"For {mode=}, a valid quantization structure must be "
                     "provided either via `config.quantization_layer_structure` "
                     "or by overriding "
                     "`model.get_quantization_layer_structure(mode)`. The "
                     "structure should be a dictionary with keys "
                     "'pre_block_layers' and 'sequential_blocks'."
                 )
-            gptq_quantize(config, structure, filters=filters)
+            if mode == "gptq":
+                gptq_quantize(config, structure, filters=filters)
+            elif mode == "awq":
+                awq_quantize(config, structure, filters=filters)
 
         # If any layer was changed, we must rebuild the execution functions.
         if graph_modified:
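End to end, AWQ plugs into the same entry point as GPTQ. A hypothetical usage sketch follows: `model.quantize`, `AWQConfig`, and the `pre_block_layers`/`sequential_blocks` keys are all confirmed by the hunks in this diff, but the `AWQConfig` constructor arguments shown here are assumptions inferred from the GPTQ flow, not a confirmed signature.

```python
from keras.src.quantizers.awq_config import AWQConfig

# Hypothetical: the exact AWQConfig arguments are assumptions.
config = AWQConfig(
    dataset=calibration_batches,  # activations used to find salient channels
    quantization_layer_structure={
        "pre_block_layers": [embedding_layer],
        "sequential_blocks": transformer_blocks,
    },
)
model.quantize("awq", config=config)
```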
keras/src/ops/numpy.py CHANGED
@@ -7802,6 +7802,15 @@ def correlate(x1, x2, mode="valid"):
 
     Returns:
         Output tensor, cross-correlation of `x1` and `x2`.
+
+    Notes:
+        Complex-valued inputs are currently not fully supported on the
+        TensorFlow and PyTorch backends. When complex tensors are passed,
+        they are cast to floating-point types and the imaginary component
+        is discarded.
+
+        This behavior is documented for clarity and may change in the
+        future. See discussion in issue #21617.
     """
     if any_symbolic_tensors((x1, x2)):
         return Correlate(mode=mode).symbolic_call(x1, x2)
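The new note only documents existing behavior: real-valued inputs work as before, while complex inputs on the TensorFlow and PyTorch backends silently lose their imaginary part. A quick real-valued check:

```python
from keras import ops

x1 = ops.convert_to_tensor([1.0, 2.0, 3.0])
x2 = ops.convert_to_tensor([0.0, 1.0, 0.5])
# "valid" mode with equal-length inputs yields a single sliding product:
# 1*0 + 2*1 + 3*0.5 = 3.5
print(ops.correlate(x1, x2, mode="valid"))
```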
keras/src/quantizers/__init__.py CHANGED
@@ -1,6 +1,7 @@
 import inspect
 
 from keras.src.api_export import keras_export
+from keras.src.quantizers.awq_config import AWQConfig
 from keras.src.quantizers.quantization_config import Float8QuantizationConfig
 from keras.src.quantizers.quantization_config import Int4QuantizationConfig
 from keras.src.quantizers.quantization_config import Int8QuantizationConfig
@@ -24,6 +25,7 @@ ALL_OBJECTS = {
     Int8QuantizationConfig,
     Int4QuantizationConfig,
     Float8QuantizationConfig,
+    AWQConfig,
 }
 ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
 ALL_OBJECTS_DICT.update(