keras-nightly 3.14.0.dev2026012804__py3-none-any.whl → 3.14.0.dev2026013004__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,6 +12,7 @@ from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
 from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantization_config import get_block_size_for_layer
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
 from keras.src.saving import serialization_lib
 
@@ -184,7 +185,10 @@ class Dense(Layer):
 
         # Handle int4 unpacking cases in one place
         if is_int4:
-            kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
+            # unpack [in, ceil(out/2)] to [in, out]
+            kernel = quantizers.unpack_int4(
+                kernel, self._orig_output_dim, axis=-1
+            )
         elif is_gptq and gptq_calibrated and gptq_bits == 4:
             kernel = quantizers.unpack_int4(
                 self.quantized_kernel,
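
Editor's note: the int4 kernel is now packed along the last (output) axis, so unpacking needs the original output dimension and `axis=-1`. A minimal NumPy sketch of two-nibbles-per-byte packing along the last axis (illustrative only; `pack_int4_last_axis`/`unpack_int4_last_axis` are hypothetical stand-ins, not the `keras.src.quantizers` API):

    import numpy as np

    def pack_int4_last_axis(q):
        # q: int4 values in [-8, 7] stored as int8, shape [in, out].
        # Result: int8 array of shape [in, ceil(out / 2)], two values per byte.
        q = np.asarray(q, dtype=np.int8)
        if q.shape[-1] % 2:                                   # pad odd output dims
            q = np.pad(q, [(0, 0)] * (q.ndim - 1) + [(0, 1)])
        nib = q.astype(np.uint8) & 0x0F                       # two's-complement nibbles
        return (nib[..., 0::2] | (nib[..., 1::2] << 4)).astype(np.int8)

    def unpack_int4_last_axis(packed, orig_len):
        # Inverse of the above; orig_len restores an odd output dimension.
        u = np.asarray(packed).astype(np.uint8)
        low = ((u & 0x0F) ^ 0x08).astype(np.int8) - 8         # sign-extend low nibble
        high = ((u >> 4) ^ 0x08).astype(np.int8) - 8          # sign-extend high nibble
        out = np.stack([low, high], axis=-1).reshape(*u.shape[:-1], -1)
        return out[..., :orig_len]

    q = np.array([[1, -2, 7]], dtype=np.int8)
    assert (unpack_int4_last_axis(pack_int4_last_axis(q), 3) == q).all()
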
@@ -287,15 +291,27 @@ class Dense(Layer):
         if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
-        # for None/gptq)
-        kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
+        # Kernel plus optional merged LoRA-aware scale/zero (returns
+        # (kernel, None, None) for None/gptq/awq)
+        kernel_value, merged_kernel_scale, merged_kernel_zero = (
+            self._get_kernel_with_merged_lora()
+        )
         idx = 0
         for name in self.variable_serialization_spec[mode]:
             if name == "kernel":
                 store[str(idx)] = kernel_value
             elif name == "bias" and self.bias is None:
                 continue
+            elif name == "kernel_zero":
+                if merged_kernel_zero is None:
+                    # kernel_zero only exists for sub-channel int4 quantization
+                    continue
+                store[str(idx)] = merged_kernel_zero
+            elif name == "g_idx":
+                if not hasattr(self, "g_idx"):
+                    # g_idx only exists for sub-channel int4 quantization
+                    continue
+                store[str(idx)] = self.g_idx
             elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
@@ -324,6 +340,12 @@ class Dense(Layer):
                 self._kernel.assign(store[str(idx)])
             elif name == "bias" and self.bias is None:
                 continue
+            elif name == "kernel_zero" and not hasattr(self, "kernel_zero"):
+                # kernel_zero only exists for sub-channel int4 quantization
+                continue
+            elif name == "g_idx" and not hasattr(self, "g_idx"):
+                # g_idx only exists for sub-channel int4 quantization
+                continue
             else:
                 getattr(self, name).assign(store[str(idx)])
             idx += 1
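
Editor's note: `kernel_zero` and `g_idx` appear in the int4 spec below but only exist when sub-channel quantization is active, so both the save and load loops skip them, without advancing the store index, when they are absent. A toy sketch of that index walk (hypothetical helper, not the layer code):

    def pack_store(spec, present):
        # Entries missing from this layer (e.g. per-channel int4 has no
        # kernel_zero/g_idx) are skipped and do not consume an index.
        store, idx = {}, 0
        for name in spec:
            if name in present:
                store[str(idx)] = present[name]
                idx += 1
        return store

    spec = ["kernel", "bias", "kernel_scale", "kernel_zero", "g_idx"]
    per_channel = {"kernel": "K", "bias": "b", "kernel_scale": "s"}
    assert pack_store(spec, per_channel) == {"0": "K", "1": "b", "2": "s"}
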
@@ -388,6 +410,8 @@ class Dense(Layer):
                "kernel",
                "bias",
                "kernel_scale",
+               "kernel_zero",
+               "g_idx",
            ],
            "float8": [
                "kernel",
@@ -630,37 +654,75 @@ class Dense(Layer):
     def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
-        `kernel_shape` is the *original* float32 kernel shape
-        `(input_dim, units)`. We allocate the stored kernel with rows
-        `ceil(input_dim/2)` because two int4 values are packed into a single
-        int8 byte.
+        The kernel is packed along the last axis,
+        resulting in shape `(input_dim, ceil(units/2))`.
+
+        Args:
+            kernel_shape: The original float32 kernel shape
+                `(input_dim, units)`.
+            config: Optional quantization config specifying block_size.
         """
-        # Per-channel int8 quantizer for the last axis (features).
         self.inputs_quantizer = (
-            QuantizationConfig.activation_quantizer_or_default(
-                config, quantizers.AbsMaxQuantizer()
-            )
+            QuantizationConfig.activation_quantizer_or_default(config, None)
         )
         input_dim, output_dim = kernel_shape
-        packed_rows = (input_dim + 1) // 2  # ceil for odd dims
 
-        # Kernel is stored *packed*: each int8 byte contains two int4 values.
+        # kernel is packed along last axis (output dimension)
+        # Stored shape: [input_dim, ceil(output_dim/2)]
+        packed_cols = (output_dim + 1) // 2
+
         self._kernel = self.add_weight(
             name="kernel",
-            shape=(packed_rows, output_dim),
+            shape=(input_dim, packed_cols),
             initializer="zeros",
             dtype="int8",
             trainable=False,
         )
-        # One scale per output unit (per-channel).
+
+        block_size = get_block_size_for_layer(self, config)
+        self._int4_block_size = block_size
+
+        if block_size is None or block_size == -1:
+            # Per-channel: one scale per output unit
+            scale_shape = (self.units,)
+        else:
+            # Sub-channel: [n_groups, out_features]
+            n_groups = math.ceil(input_dim / block_size)
+            scale_shape = (n_groups, self.units)
+
         self.kernel_scale = self.add_weight(
             name="kernel_scale",
-            shape=(self.units,),
+            shape=scale_shape,
             initializer="ones",
             trainable=False,
         )
-        # Record original input_dim for unpacking at runtime.
+
+        # Sub-channel quantization uses asymmetric quantization
+        if block_size is not None and block_size > 0:
+
+            def idx_initializer(shape, dtype):
+                return ops.floor_divide(
+                    ops.arange(input_dim, dtype=dtype), block_size
+                )
+
+            self.kernel_zero = self.add_weight(
+                name="kernel_zero",
+                shape=scale_shape,
+                initializer="zeros",
+                dtype="int8",
+                trainable=False,
+            )
+            self.g_idx = self.add_weight(
+                name="g_idx",
+                shape=(input_dim,),
+                initializer=idx_initializer,
+                dtype="float32",
+                trainable=False,
+            )
+
+        # Record dimensions for unpacking and reshaping at runtime.
         self._orig_input_dim = input_dim
+        self._orig_output_dim = output_dim
 
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
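
Editor's note: with a positive `block_size`, scales and zero points are stored per group of `block_size` input rows rather than only per output channel, and `g_idx` maps each input row to its group. A short sketch of the resulting shapes (illustrative values, assuming the shapes shown in the hunk above):

    import math
    import numpy as np

    input_dim, units, block_size = 96, 64, 32
    n_groups = math.ceil(input_dim / block_size)     # 3 groups of input rows

    scale_shape = (n_groups, units)                  # kernel_scale: one per (group, unit)
    zero_shape = (n_groups, units)                   # kernel_zero: same shape (asymmetric)
    g_idx = np.arange(input_dim) // block_size       # row -> group id: [0]*32 + [1]*32 + [2]*32

    assert g_idx.shape == (input_dim,) and g_idx.max() == n_groups - 1
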
@@ -755,57 +817,108 @@ class Dense(Layer):
         return x
 
     def _int4_call(self, inputs, training=None):
-        """Forward pass for int4 quantized Dense layer."""
+        """Forward pass for int4 quantized Dense layer.
 
-        @ops.custom_gradient
-        def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
-            """Custom gradient function for int4 quantized weights.
-
-            Automatic differentiation will not know how to handle the
-            int4 quantized weights. So a custom gradient function is needed
-            to handle the int4 quantized weights.
-
-            The custom gradient function will use the dequantized kernel to
-            compute the gradient.
-            """
+        Uses custom gradients to handle quantized weights since autodiff
+        cannot differentiate through int4 operations.
+        """
+        block_size = getattr(self, "_int4_block_size", None)
+
+        if block_size is None or block_size == -1:
+            # Per-channel: symmetric quantization (no zero point needed)
+            @ops.custom_gradient
+            def matmul_per_channel_with_inputs_gradient(
+                inputs, kernel, kernel_scale
+            ):
+                """Per-channel int4 forward pass with custom gradient."""
+                # Unpack: stored as [in, ceil(out/2)], unpack along last axis
+                unpacked_kernel = quantizers.unpack_int4(
+                    kernel, self._orig_output_dim, axis=-1
+                )
 
-            unpacked_kernel = quantizers.unpack_int4(
-                kernel, self._orig_input_dim
+                def grad_fn(*args, upstream=None):
+                    if upstream is None:
+                        (upstream,) = args
+                    # Per-channel: unpacked is [in, out]
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    inputs_grad = ops.matmul(
+                        upstream, ops.transpose(float_kernel)
+                    )
+                    return (inputs_grad, None, None)
+
+                # Forward pass: per-channel dequantization
+                output_scale = kernel_scale
+                if self.inputs_quantizer:
+                    inputs, inputs_scale = self.inputs_quantizer(
+                        inputs, axis=-1
+                    )
+                    output_scale = ops.multiply(output_scale, inputs_scale)
+
+                x = ops.matmul(inputs, unpacked_kernel)
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, output_scale)
+                return x, grad_fn
+
+            x = matmul_per_channel_with_inputs_gradient(
+                inputs,
+                ops.convert_to_tensor(self._kernel),
+                ops.convert_to_tensor(self.kernel_scale),
             )
+        else:
+            # Sub-channel: asymmetric quantization (with zero point)
+            @ops.custom_gradient
+            def matmul_sub_channel_with_inputs_gradient(
+                inputs, kernel, kernel_scale, kernel_zero, g_idx
+            ):
+                """Sub-channel int4 forward pass with custom gradient."""
+                # Unpack: stored as [in, ceil(out/2)], unpack along last axis
+                unpacked_kernel = quantizers.unpack_int4(
+                    kernel, self._orig_output_dim, axis=-1
+                )
 
-            def grad_fn(*args, upstream=None):
-                if upstream is None:
-                    (upstream,) = args
-                float_kernel = ops.divide(
-                    ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                def grad_fn(*args, upstream=None):
+                    if upstream is None:
+                        (upstream,) = args
+                    float_kernel = dequantize_with_sz_map(
+                        unpacked_kernel,
+                        kernel_scale,
+                        kernel_zero,
+                        g_idx,
+                        group_axis=0,
+                    )
+                    float_kernel = ops.cast(float_kernel, self.compute_dtype)
+                    inputs_grad = ops.matmul(
+                        upstream, ops.transpose(float_kernel)
+                    )
+                    return (inputs_grad, None, None, None, None)
+
+                float_kernel = dequantize_with_sz_map(
+                    unpacked_kernel,
                     kernel_scale,
+                    kernel_zero,
+                    g_idx,
+                    group_axis=0,
                 )
-                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
-                return (inputs_grad, None, None)
-
-            output_scale = kernel_scale
+                float_kernel = ops.cast(float_kernel, self.compute_dtype)
+                x = ops.matmul(inputs, float_kernel)
+                return x, grad_fn
 
-            if self.inputs_quantizer:
-                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
-                output_scale = ops.multiply(output_scale, inputs_scale)
-
-            x = ops.matmul(inputs, unpacked_kernel)
-            x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, output_scale)
-            return x, grad_fn
-
-        x = matmul_with_inputs_gradient(
-            inputs,
-            ops.convert_to_tensor(self._kernel),
-            ops.convert_to_tensor(self.kernel_scale),
-        )
+            x = matmul_sub_channel_with_inputs_gradient(
+                inputs,
+                ops.convert_to_tensor(self._kernel),
+                ops.convert_to_tensor(self.kernel_scale),
+                ops.convert_to_tensor(self.kernel_zero),
+                ops.convert_to_tensor(self.g_idx),
+            )
 
         if self.lora_enabled:
             lora_x = ops.matmul(inputs, self.lora_kernel_a)
             lora_x = ops.matmul(lora_x, self.lora_kernel_b)
             x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
 
-        # Add bias and activation
         if self.bias is not None:
             x = ops.add(x, self.bias)
         if self.activation is not None:
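
Editor's note: in the sub-channel branch the whole kernel is dequantized with the per-group scale/zero maps before the matmul. One plausible reading of that mapping, as a hedged sketch (the actual arithmetic lives in `keras.src.quantizers.dequantize_with_sz_map` and may differ in convention, e.g. whether the scale multiplies or divides):

    import numpy as np

    def dequantize_grouped_sketch(q, scale, zero, g_idx):
        # q: [in, out] int4 values; scale/zero: [n_groups, out]; g_idx: [in].
        # group_axis=0 means groups run along the input (row) dimension.
        g = np.asarray(g_idx, dtype=np.int64)
        return (q.astype(np.float32) - zero[g]) * scale[g]   # assumed convention
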
@@ -925,26 +1038,49 @@ class Dense(Layer):
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "int4":
-            # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
-            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
-                self.quantization_config,
-                quantizers.AbsMaxQuantizer(
-                    axis=0, value_range=(-8, 7), output_dtype="int8"
-                ),
+            from keras.src.quantizers.quantization_config import (
+                Int4QuantizationConfig,
             )
-            kernel_value_int4, kernel_scale = weight_quantizer(
-                self._kernel, to_numpy=True
+
+            block_size = None
+            if isinstance(self.quantization_config, Int4QuantizationConfig):
+                block_size = self.quantization_config.block_size
+
+            if block_size is None or block_size == -1:
+                # Per-channel quantization
+                weight_quantizer = (
+                    QuantizationConfig.weight_quantizer_or_default(
+                        self.quantization_config,
+                        quantizers.AbsMaxQuantizer(
+                            axis=0, value_range=(-8, 7), output_dtype="int8"
+                        ),
+                    )
+                )
+                kernel_value_int4, kernel_scale = weight_quantizer(
+                    self._kernel, to_numpy=True
+                )
+                kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            else:
+                # Sub-channel quantization with asymmetric zero point
+                # Returns kernel [in, out], scale [n_groups, out], zero
+                # [n_groups, out]
+                kernel_value_int4, kernel_scale, kernel_zero = (
+                    quantizers.abs_max_quantize_grouped_with_zero_point(
+                        self._kernel, block_size=block_size, to_numpy=True
+                    )
+                )
+
+            # Pack two int4 values per int8 byte along last axis
+            # Stored as [in, ceil(out/2)]
+            packed_kernel_value, _, _ = quantizers.pack_int4(
+                kernel_value_int4, axis=-1
             )
-            kernel_scale = ops.squeeze(kernel_scale, axis=0)
-            # 2. Pack two int4 values into a single int8 byte.
-            packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
             del self._kernel
-            # Build variables using the original kernel shape; _int4_build will
-            # compute the packed shape internally.
             self.quantized_build(kernel_shape, mode, self.quantization_config)
-            # Assign packed values.
             self._kernel.assign(packed_kernel_value)
             self.kernel_scale.assign(kernel_scale)
+            if block_size is not None and block_size > 0:
+                self.kernel_zero.assign(kernel_zero)
         elif mode == "gptq":
             self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "awq":
@@ -959,10 +1095,14 @@ class Dense(Layer):
         from keras.src import dtype_policies  # local import to avoid cycle
 
         policy_name = mode
-        if mode == "gptq":
-            policy_name = self.quantization_config.dtype_policy_string()
-        elif mode == "awq":
+        if mode in ("gptq", "awq"):
             policy_name = self.quantization_config.dtype_policy_string()
+        elif mode == "int4":
+            # Include block_size in policy name for sub-channel quantization
+            block_size = get_block_size_for_layer(self, config)
+            # Use -1 for per-channel, otherwise use block_size
+            block_size_value = -1 if block_size is None else block_size
+            policy_name = f"int4/{block_size_value}"
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
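
Editor's note: the int4 dtype-policy string now carries the block size, with -1 standing for per-channel quantization. A quick illustration of the strings this branch produces (the `_from_<policy>` suffix is appended by the unchanged code just above):

    def int4_policy_name(block_size):
        block_size_value = -1 if block_size is None else block_size
        return f"int4/{block_size_value}"

    assert int4_policy_name(None) == "int4/-1"   # per-channel
    assert int4_policy_name(32) == "int4/32"     # sub-channel, 32 input rows per group
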
@@ -991,32 +1131,49 @@ class Dense(Layer):
         without modification.
 
         Returns:
-            A tuple `(kernel_value, kernel_scale)`:
+            A tuple `(kernel_value, kernel_scale, kernel_zero)`:
             `kernel_value`: The merged kernel. A quantized tensor if
                 quantization is active, otherwise a high precision tensor.
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
+            `kernel_zero`: The zero point for sub-channel int4 quantization.
+                This is `None` for per-channel or non-int4 modes.
         """
         if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
-            return self.kernel, None
+            return self.kernel, None, None
 
         kernel_value = self._kernel
         kernel_scale = self.kernel_scale
+        kernel_zero = getattr(self, "kernel_zero", None)
 
         if not self.lora_enabled:
-            return kernel_value, kernel_scale
+            return kernel_value, kernel_scale, kernel_zero
 
         # Dequantize, Merge, and Re-quantize
+        block_size = getattr(self, "_int4_block_size", None)
 
-        # Dequantize kernel to float
+        # Step 1: Dequantize kernel to float
         if self.quantization_mode == "int4":
+            # Unpack along last axis ([in, out])
             unpacked_kernel = quantizers.unpack_int4(
-                kernel_value, self._orig_input_dim
-            )
-            float_kernel = ops.divide(
-                ops.cast(unpacked_kernel, self.compute_dtype),
-                kernel_scale,
+                kernel_value, self._orig_output_dim, axis=-1
             )
+            if block_size is None or block_size == -1:
+                # Per-channel: kernel [in, out], scale [out]
+                float_kernel = ops.divide(
+                    ops.cast(unpacked_kernel, self.compute_dtype),
+                    kernel_scale,
+                )
+            else:
+                # Sub-channel: scale/zero are [n_groups, out]
+                float_kernel = dequantize_with_sz_map(
+                    unpacked_kernel,
+                    kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                    group_axis=0,
+                )
+                float_kernel = ops.cast(float_kernel, self.compute_dtype)
             quant_range = (-8, 7)
         elif self.quantization_mode == "int8":
             float_kernel = ops.divide(
@@ -1028,25 +1185,51 @@ class Dense(Layer):
                 f"Unsupported quantization mode: {self.quantization_mode}"
             )
 
-        # Merge LoRA weights in float domain
+        # Step 2: Merge LoRA weights in float domain
         lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
             self.lora_kernel_a, self.lora_kernel_b
         )
         merged_float_kernel = ops.add(float_kernel, lora_delta)
 
-        # Requantize
-        requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
-            merged_float_kernel,
-            axis=0,
-            value_range=quant_range,
-            dtype="int8",
-            to_numpy=True,
-        )
-        kernel_scale = ops.squeeze(kernel_scale, axis=0)
+        # Step 3: Re-quantize the merged kernel
+        if (
+            self.quantization_mode == "int4"
+            and block_size is not None
+            and block_size != -1
+        ):
+            # Sub-channel: returns kernel [in, out], scale [n_groups, out]
+            requantized_kernel, kernel_scale, kernel_zero = (
+                quantizers.abs_max_quantize_grouped_with_zero_point(
+                    merged_float_kernel, block_size=block_size, to_numpy=True
+                )
+            )
+        elif self.quantization_mode == "int4":
+            # Per-channel: quantize along input axis (axis=0)
+            requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+                merged_float_kernel,
+                axis=0,
+                value_range=quant_range,
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            kernel_zero = None
+        else:
+            requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+                merged_float_kernel,
+                axis=0,
+                value_range=quant_range,
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            kernel_zero = None
 
-        # Pack if int4
         if self.quantization_mode == "int4":
-            kernel_value, _, _ = quantizers.pack_int4(requantized_kernel)
+            # Pack along last axis
+            kernel_value, _, _ = quantizers.pack_int4(
+                requantized_kernel, axis=-1
+            )
         else:
             kernel_value = requantized_kernel
-        return kernel_value, kernel_scale
+        return kernel_value, kernel_scale, kernel_zero
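
Editor's note: the LoRA merge keeps its dequantize -> add delta -> re-quantize shape, now returning a third `kernel_zero` element. A runnable toy of the per-channel int4 path with a simplified stand-in quantizer (abs-max style; not the keras quantizer):

    import numpy as np

    rng = np.random.default_rng(0)
    w = rng.standard_normal((8, 4)).astype(np.float32)        # dequantized kernel [in, out]
    lora_a = rng.standard_normal((8, 2)).astype(np.float32)   # [in, rank]
    lora_b = rng.standard_normal((2, 4)).astype(np.float32)   # [rank, out]
    lora_alpha, lora_rank = 2.0, 2

    # Step 2: merge the LoRA delta in the float domain.
    merged = w + (lora_alpha / lora_rank) * (lora_a @ lora_b)

    # Step 3: re-quantize per output column into [-8, 7] (simplified abs-max).
    scale = 7.0 / np.max(np.abs(merged), axis=0)
    q = np.clip(np.round(merged * scale), -8, 7).astype(np.int8)
    # The layer then packs q two-per-byte along the last axis and returns
    # (packed_kernel, kernel_scale, kernel_zero), with kernel_zero None here.
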