keras-nightly 3.14.0.dev2026012704__py3-none-any.whl → 3.14.0.dev2026012904__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +1 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +1 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +3 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +1 -0
  7. keras/ops/numpy/__init__.py +1 -0
  8. keras/quantizers/__init__.py +3 -0
  9. keras/src/backend/jax/core.py +12 -2
  10. keras/src/backend/jax/numpy.py +5 -0
  11. keras/src/backend/numpy/numpy.py +5 -0
  12. keras/src/backend/openvino/numpy.py +6 -0
  13. keras/src/backend/tensorflow/numpy.py +21 -0
  14. keras/src/backend/torch/numpy.py +10 -0
  15. keras/src/callbacks/orbax_checkpoint.py +41 -8
  16. keras/src/dtype_policies/__init__.py +2 -0
  17. keras/src/dtype_policies/dtype_policy.py +80 -1
  18. keras/src/layers/core/dense.py +278 -95
  19. keras/src/layers/core/einsum_dense.py +350 -181
  20. keras/src/layers/core/embedding.py +236 -49
  21. keras/src/layers/core/reversible_embedding.py +177 -35
  22. keras/src/layers/preprocessing/discretization.py +30 -1
  23. keras/src/ops/numpy.py +54 -0
  24. keras/src/quantizers/__init__.py +6 -0
  25. keras/src/quantizers/quantization_config.py +98 -4
  26. keras/src/quantizers/quantizers.py +262 -32
  27. keras/src/saving/file_editor.py +7 -1
  28. keras/src/saving/saving_api.py +66 -2
  29. keras/src/saving/saving_lib.py +46 -47
  30. keras/src/version.py +1 -1
  31. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/METADATA +1 -1
  32. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/RECORD +34 -34
  33. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/WHEEL +0 -0
  34. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from keras.src.api_export import keras_export
  from keras.src.layers.input_spec import InputSpec
  from keras.src.layers.layer import Layer
  from keras.src.quantizers.quantization_config import QuantizationConfig
+ from keras.src.quantizers.quantization_config import get_block_size_for_layer
  from keras.src.quantizers.quantizers import dequantize_with_sz_map
  from keras.src.saving import serialization_lib

@@ -238,11 +239,13 @@ class EinsumDense(Layer):

  # Handle int4 unpacking cases in one place
  if is_int4:
+ # unpack [rows, ceil(columns/2)] to [rows, columns]
  kernel = quantizers.unpack_int4(
  kernel,
- self._orig_length_along_pack_axis,
- self._int4_pack_axis,
+ self._int4_unpacked_column_size,
+ axis=-1,
  )
+ kernel = ops.reshape(kernel, self.original_kernel_shape)
  elif is_gptq and gptq_calibrated and gptq_bits == 4:
  kernel = quantizers.unpack_int4(
  self.quantized_kernel,
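The int4 kernel is now stored as a 2D `[rows, ceil(columns/2)]` int8 tensor and unpacked back to `[rows, columns]` before being reshaped to `original_kernel_shape`. A minimal NumPy sketch of this two-nibbles-per-byte scheme, with illustrative helpers (not the actual `quantizers.pack_int4`/`unpack_int4` implementations, whose nibble order may differ):

```python
import numpy as np

def pack_int4_last_axis(values):
    """Pack signed int4 values in [-8, 7] pairwise along the last axis.

    [rows, columns] int8 -> [rows, ceil(columns / 2)] int8 (two nibbles per byte).
    """
    rows, columns = values.shape
    if columns % 2:  # pad an odd column count with a zero nibble
        values = np.concatenate(
            [values, np.zeros((rows, 1), dtype=values.dtype)], axis=-1
        )
    low = (values[:, 0::2] & 0x0F).astype(np.uint8)   # first of each pair
    high = (values[:, 1::2] & 0x0F).astype(np.uint8)  # second of each pair
    return ((high << 4) | low).view(np.int8)

def unpack_int4_last_axis(packed, orig_columns):
    """Inverse of pack_int4_last_axis: recover [rows, orig_columns] int8 values."""
    p = packed.view(np.uint8)
    low = (p & 0x0F).astype(np.int8)
    high = ((p >> 4) & 0x0F).astype(np.int8)
    # Sign-extend the 4-bit two's-complement nibbles back to [-8, 7].
    low = np.where(low >= 8, low - 16, low)
    high = np.where(high >= 8, high - 16, high)
    out = np.empty((p.shape[0], p.shape[1] * 2), dtype=np.int8)
    out[:, 0::2] = low
    out[:, 1::2] = high
    return out[:, :orig_columns]

w = np.random.randint(-8, 8, size=(3, 5), dtype=np.int8)
assert np.array_equal(unpack_int4_last_axis(pack_int4_last_axis(w), 5), w)
```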
@@ -306,16 +309,11 @@ class EinsumDense(Layer):
  self._tracker.unlock()
  # Determine the appropriate (unpacked) kernel shape for LoRA.
  if self.quantization_mode == "int4":
- # When int4-quantized, `self._kernel` is packed along
- # `self._int4_pack_axis` and its length equals
- # `(orig_len + 1) // 2`. Recover the original length so that
- # the LoRA matrices operate in the full-precision space.
- kernel_shape_for_lora = list(self._kernel.shape)
- pack_axis = getattr(self, "_int4_pack_axis", 0)
- orig_len = getattr(self, "_orig_length_along_pack_axis", None)
- if orig_len is not None:
- kernel_shape_for_lora[pack_axis] = orig_len
- kernel_shape_for_lora = tuple(kernel_shape_for_lora)
+ # INT4 weights are stored in a flattened 2D layout that loses
+ # the original N-dimensional structure required by the einsum
+ # equation. We use `original_kernel_shape` to ensure LoRA adapters
+ # operate in the correct logical dimension space.
+ kernel_shape_for_lora = tuple(self.original_kernel_shape)
  else:
  kernel_shape_for_lora = self.kernel.shape

@@ -327,7 +325,7 @@ class EinsumDense(Layer):
  )
  self.lora_kernel_b = self.add_weight(
  name="lora_kernel_b",
- shape=(rank, self.kernel.shape[-1]),
+ shape=(rank, kernel_shape_for_lora[-1]),
  initializer=initializers.get(b_initializer),
  regularizer=self.kernel_regularizer,
  )
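Because the packed int4 kernel is only a 2D storage layout, the LoRA adapters are shaped from the logical (unpacked, N-dimensional) kernel shape. A small illustration with made-up shapes; the `lora_kernel_a` layout is the usual EinsumDense convention and an assumption here, since only the `lora_kernel_b` shape is visible in the hunks above:

```python
# Hypothetical EinsumDense kernel: logical shape (64, 8, 32), reduced axis 0.
original_kernel_shape = (64, 8, 32)   # what the einsum equation expects
packed_storage_shape = (64, 128)      # [rows, ceil(columns / 2)] with columns = 8 * 32

rank = 4
# LoRA operates in the logical space, never on the packed storage:
lora_kernel_a_shape = original_kernel_shape[:-1] + (rank,)   # (64, 8, 4)
lora_kernel_b_shape = (rank, original_kernel_shape[-1])      # (4, 32)
```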
@@ -345,15 +343,27 @@ class EinsumDense(Layer):
  if mode not in self.variable_serialization_spec:
  raise self._quantization_mode_error(mode)

- # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
- # for None/gptq)
- kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
+ # Kernel plus optional merged LoRA-aware scale/zero (returns
+ # (kernel, None, None) for None/gptq)
+ kernel_value, merged_kernel_scale, merged_kernel_zero = (
+ self._get_kernel_with_merged_lora()
+ )
  idx = 0
  for name in self.variable_serialization_spec[mode]:
  if name == "kernel":
  store[str(idx)] = kernel_value
  elif name == "bias" and self.bias is None:
  continue
+ elif name == "kernel_zero":
+ if merged_kernel_zero is None:
+ # kernel_zero only exists for sub-channel int4 quantization
+ continue
+ store[str(idx)] = merged_kernel_zero
+ elif name == "g_idx":
+ if not hasattr(self, "g_idx"):
+ # g_idx only exists for sub-channel int4 quantization
+ continue
+ store[str(idx)] = self.g_idx
  elif name == "kernel_scale" and mode in ("int4", "int8"):
  # For int4/int8, the merged LoRA scale (if any) comes from
  # `_get_kernel_with_merged_lora()`
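The spec-driven save path above assigns consecutive string indices only to variables that are actually present, so optional entries (`bias`, `kernel_zero`, `g_idx`) are skipped without consuming an index. A toy sketch of that indexing behaviour (plain Python, names illustrative):

```python
spec = ["kernel", "bias", "kernel_scale", "kernel_zero", "g_idx"]
# Pretend the layer has no bias and is per-channel int4 (no zero point / g_idx).
present = {"kernel": "W_packed", "kernel_scale": "scale"}

store = {}
idx = 0
for name in spec:
    if name not in present:  # optional variable absent on this layer: skip it
        continue
    store[str(idx)] = present[name]
    idx += 1

print(store)  # {'0': 'W_packed', '1': 'scale'}
```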
@@ -382,6 +392,12 @@ class EinsumDense(Layer):
  self._kernel.assign(store[str(idx)])
  elif name == "bias" and self.bias is None:
  continue
+ elif name == "kernel_zero" and not hasattr(self, "kernel_zero"):
+ # kernel_zero only exists for sub-channel int4 quantization
+ continue
+ elif name == "g_idx" and not hasattr(self, "g_idx"):
+ # g_idx only exists for sub-channel int4 quantization
+ continue
  else:
  getattr(self, name).assign(store[str(idx)])
  idx += 1
@@ -452,6 +468,8 @@ class EinsumDense(Layer):
  "kernel",
  "bias",
  "kernel_scale",
+ "kernel_zero",
+ "g_idx",
  ],
  "float8": [
  "kernel",
@@ -761,60 +779,81 @@ class EinsumDense(Layer):
  def _int4_build(self, kernel_shape, config=None):
  """Build variables for int4 quantization.

- The packed int4 kernel stores two int4 values within a single int8
- byte. Packing is performed along the first axis contained in
- `self._kernel_reduced_axes` (which is the axis that gets reduced in
- the einsum and thus analogous to the input-dim axis of a `Dense`
- layer).
+ The kernel is flattened to 2D [rows, columns]
+ and packed along last axis to [rows, ceil(columns/2)].
+
+ Args:
+ kernel_shape: Original kernel shape (may be N-dimensional).
+ config: Optional quantization config specifying block_size.
  """
  self._set_quantization_info()

- # Quantizer for the inputs (per the reduced axes)
  self.inputs_quantizer = (
- QuantizationConfig.activation_quantizer_or_default(
- config,
- quantizers.AbsMaxQuantizer(),
- )
+ QuantizationConfig.activation_quantizer_or_default(config, None)
  )
- # If the config provided a default AbsMaxQuantizer, we need to
- # override the axis to match the equation's reduction axes.
  self.quantization_axis = tuple(self._input_reduced_axes)
+ self.original_kernel_shape = kernel_shape

- # Choose the axis to perform int4 packing - use the first reduced axis
- # for the kernel (analogous to the input dimension of a Dense layer).
- self._int4_pack_axis = (
- self._kernel_reduced_axes[0] if self._kernel_reduced_axes else 0
- )
-
- # Original length along the packing axis (needed for unpacking).
- self._orig_length_along_pack_axis = kernel_shape[self._int4_pack_axis]
-
- # Packed length (ceil division by 2). Note: assumes static integer.
- packed_len = (self._orig_length_along_pack_axis + 1) // 2
+ # Flatten kernel to 2D: rows = reduced dims, columns = non-reduced dims
+ rows = 1
+ columns = 1
+ for i, dim in enumerate(kernel_shape):
+ if i in self._kernel_reduced_axes:
+ rows *= dim
+ else:
+ columns *= dim

- # Derive packed kernel shape by replacing the pack axis dimension.
- packed_kernel_shape = list(kernel_shape)
- packed_kernel_shape[self._int4_pack_axis] = packed_len
- packed_kernel_shape = tuple(packed_kernel_shape)
+ block_size = get_block_size_for_layer(self, config)
+ use_grouped = block_size is not None and block_size != -1
+ self._int4_block_size = block_size if use_grouped else None
+ self._int4_unpacked_column_size = columns
+ self._int4_rows = rows

- # Add packed int4 kernel variable (stored as int8 dtype).
+ # Kernel packed along last axis (columns)
+ # Stored shape: [rows, ceil(columns/2)]
+ packed_cols = (columns + 1) // 2
  self._kernel = self.add_weight(
  name="kernel",
- shape=packed_kernel_shape,
+ shape=(rows, packed_cols),
  initializer="zeros",
  dtype="int8",
  trainable=False,
  )

- # Kernel scale
- kernel_scale_shape = self._get_kernel_scale_shape(kernel_shape)
+ if use_grouped:
+ # Sub-channel: [n_groups, columns]
+ n_groups = math.ceil(rows / block_size)
+ scale_shape = (n_groups, columns)
+ else:
+ scale_shape = (columns,)
+
  self.kernel_scale = self.add_weight(
  name="kernel_scale",
- shape=kernel_scale_shape,
+ shape=scale_shape,
  initializer="ones",
  trainable=False,
  )

+ # Sub-channel quantization uses asymmetric quantization with zero point
+ if use_grouped:
+ self.kernel_zero = self.add_weight(
+ name="kernel_zero",
+ shape=scale_shape,
+ initializer="zeros",
+ dtype="int8",
+ trainable=False,
+ )
+ self.g_idx = self.add_weight(
+ name="g_idx",
+ shape=(rows,),
+ initializer="zeros",
+ dtype="float32",
+ trainable=False,
+ )
+ self.g_idx.assign(
+ ops.floor_divide(ops.arange(rows, dtype="float32"), block_size)
+ )
+
  def _float8_build(self):
  from keras.src.dtype_policies import QuantizedFloat8DTypePolicy

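To make the `_int4_build` bookkeeping concrete, here is the same arithmetic on a hypothetical kernel shape (all values below are assumptions for illustration):

```python
import math
import numpy as np

kernel_shape = (64, 8, 32)   # hypothetical EinsumDense kernel
kernel_reduced_axes = (0,)   # axes reduced by the einsum equation
block_size = 32              # sub-channel group size along the reduced dimension

rows = math.prod(kernel_shape[a] for a in kernel_reduced_axes)        # 64
columns = math.prod(
    d for i, d in enumerate(kernel_shape) if i not in kernel_reduced_axes
)                                                                     # 8 * 32 = 256

packed_cols = (columns + 1) // 2         # packed kernel shape: (64, 128)
n_groups = math.ceil(rows / block_size)  # scale/zero shape: (n_groups, columns) = (2, 256)
g_idx = np.floor_divide(np.arange(rows, dtype="float32"), block_size)  # row -> group id
```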
@@ -948,98 +987,136 @@ class EinsumDense(Layer):
  return x

  def _int4_call(self, inputs, training=None):
- """Forward pass for int4 quantized `EinsumDense`."""
+ """Forward pass for int4 quantized EinsumDense.

- pack_axis = getattr(self, "_int4_pack_axis", 0)
- orig_len = getattr(self, "_orig_length_along_pack_axis", None)
+ Uses custom gradients to handle quantized weights since autodiff
+ cannot differentiate through int4 operations.
+ """
+ block_size = getattr(self, "_int4_block_size", None)

- @ops.custom_gradient
- def einsum_with_inputs_gradient(inputs, packed_kernel, kernel_scale):
- """Performs int4 quantized einsum with a custom gradient.
+ if block_size is None or block_size == -1:

- Computes the einsum operation with quantized inputs and a quantized
- kernel, then de-quantizes the result.
+ @ops.custom_gradient
+ def einsum_per_channel_with_inputs_gradient(
+ inputs, packed_kernel, kernel_scale
+ ):
+ """Per-channel int4 forward pass with custom gradient."""
+ # Unpack: stored as [rows, ceil(columns/2)],
+ # unpack along last axis
+ unpacked_kernel = quantizers.unpack_int4(
+ packed_kernel,
+ self._int4_unpacked_column_size,
+ axis=-1,
+ dtype="int8",
+ )

- Also computes the gradient with respect to the original,
- full-precision inputs by using a de-quantized kernel.
+ def _dequantize_kernel(unpacked, scale):
+ # kernel is [rows, columns], scale is [columns]
+ float_kernel = ops.divide(
+ ops.cast(unpacked, dtype=self.compute_dtype),
+ scale,
+ )
+ return ops.reshape(float_kernel, self.original_kernel_shape)

- Args:
- inputs: The full-precision input tensor.
- packed_kernel: The int4-packed kernel tensor.
- kernel_scale: The float32 scale factor for the kernel.
+ def grad_fn(*args, upstream=None):
+ if upstream is None:
+ (upstream,) = args
+ float_kernel = _dequantize_kernel(
+ unpacked_kernel, kernel_scale
+ )
+ inputs_grad = ops.einsum(
+ self._custom_gradient_equation, upstream, float_kernel
+ )
+ return (inputs_grad, None, None)

- Returns:
- A tuple `(output, grad_fn)`:
- `output`: The de-quantized result of the einsum operation.
- `grad_fn`: The custom gradient function for the backward
- pass.
+ if self.inputs_quantizer:
+ # Per-channel with input quantization
+ float_kernel = _dequantize_kernel(
+ unpacked_kernel, kernel_scale
+ )
+ inputs_q, inputs_scale = self.inputs_quantizer(
+ inputs, axis=self.quantization_axis
+ )
+ inputs_scale = self._adjust_scale_for_quant(
+ inputs_scale, "input"
+ )
+ # Cast inputs to float for einsum. This is a workaround
+ # for PyTorch's einsum which doesn't support
+ # mixed-precision inputs (int8 input, float kernel).
+ if backend.backend() == "torch":
+ x = ops.einsum(
+ self.equation,
+ ops.cast(inputs_q, self.compute_dtype),
+ float_kernel,
+ )
+ x = ops.divide(x, inputs_scale)
+ else:
+ x = ops.einsum(self.equation, inputs_q, float_kernel)
+ x = ops.cast(x, self.compute_dtype)
+ x = ops.divide(x, inputs_scale)
+ else:
+ # Weight-only per-channel quantization
+ float_kernel = _dequantize_kernel(
+ unpacked_kernel, kernel_scale
+ )
+ x = ops.einsum(self.equation, inputs, float_kernel)
+ return x, grad_fn

- Raises:
- ValueError: If the quantization mode is not supported.
- """
- # Unpack the int4-packed kernel back to int8 values [-8, 7].
- unpacked_kernel = quantizers.unpack_int4(
- packed_kernel, orig_len, axis=pack_axis
+ x = einsum_per_channel_with_inputs_gradient(
+ inputs,
+ ops.convert_to_tensor(self._kernel),
+ ops.convert_to_tensor(self.kernel_scale),
  )
+ else:

- def grad_fn(*args, upstream=None):
- if upstream is None:
- (upstream,) = args
- # Align `kernel_scale` to the same layout as `unpacked_kernel`.
- _kernel_scale = kernel_scale
- _kernel_scale = self._adjust_scale_for_dequant(_kernel_scale)
-
- float_kernel = ops.divide(
- ops.cast(unpacked_kernel, dtype=self.compute_dtype),
- _kernel_scale,
- )
- inputs_grad = ops.einsum(
- self._custom_gradient_equation, upstream, float_kernel
+ @ops.custom_gradient
+ def einsum_sub_channel_with_inputs_gradient(
+ inputs, packed_kernel, kernel_scale, kernel_zero, g_idx
+ ):
+ """Sub-channel int4 forward pass with custom gradient."""
+ # Unpack: stored as [rows, ceil(columns/2)],
+ # unpack along last axis
+ unpacked_kernel = quantizers.unpack_int4(
+ packed_kernel,
+ self._int4_unpacked_column_size,
+ axis=-1,
+ dtype="int8",
  )
- return (inputs_grad, None, None)

- # Quantize inputs per `self.inputs_quantizer`.
- if self.inputs_quantizer:
- inputs_q, inputs_scale = self.inputs_quantizer(
- inputs, axis=self.quantization_axis
- )
- # Align `inputs_scale` axes with the output
- # for correct broadcasting
- inputs_scale = self._adjust_scale_for_quant(
- inputs_scale, "input"
- )
- x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
- # De-scale outputs
- x = ops.cast(x, self.compute_dtype)
- x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
- else:
- # Weight-only quantization: dequantize kernel and use float
- # einsum. This is a workaround for PyTorch's einsum which
- # doesn't support mixed-precision inputs (float input,
- # int4 kernel).
- if backend.backend() == "torch":
- # Align `kernel_scale` to the same layout as
- # `unpacked_kernel`.
- kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
- float_kernel = ops.divide(
- ops.cast(unpacked_kernel, dtype=self.compute_dtype),
- kernel_scale,
+ def _dequantize_kernel(unpacked, scale, zero, g_idx_t):
+ # Dequantize with group_axis=0 since
+ # scale is [n_groups, columns]
+ float_kernel = dequantize_with_sz_map(
+ unpacked, scale, zero, g_idx_t, group_axis=0
  )
- x = ops.einsum(self.equation, inputs, float_kernel)
- else:
- x = ops.einsum(self.equation, inputs, unpacked_kernel)
- # De-scale outputs
- x = ops.cast(x, self.compute_dtype)
- x = ops.divide(x, kernel_scale)
- return x, grad_fn
+ float_kernel = ops.cast(float_kernel, self.compute_dtype)
+ return ops.reshape(float_kernel, self.original_kernel_shape)
+
+ def grad_fn(*args, upstream=None):
+ if upstream is None:
+ (upstream,) = args
+ float_kernel = _dequantize_kernel(
+ unpacked_kernel, kernel_scale, kernel_zero, g_idx
+ )
+ inputs_grad = ops.einsum(
+ self._custom_gradient_equation, upstream, float_kernel
+ )
+ return (inputs_grad, None, None, None, None)

- x = einsum_with_inputs_gradient(
- inputs,
- ops.convert_to_tensor(self._kernel),
- ops.convert_to_tensor(self.kernel_scale),
- )
+ float_kernel = _dequantize_kernel(
+ unpacked_kernel, kernel_scale, kernel_zero, g_idx
+ )
+ x = ops.einsum(self.equation, inputs, float_kernel)
+ return x, grad_fn
+
+ x = einsum_sub_channel_with_inputs_gradient(
+ inputs,
+ ops.convert_to_tensor(self._kernel),
+ ops.convert_to_tensor(self.kernel_scale),
+ ops.convert_to_tensor(self.kernel_zero),
+ ops.convert_to_tensor(self.g_idx),
+ )

- # Add LoRA contribution if enabled
  if self.lora_enabled:
  lora_x = ops.einsum(self.equation, inputs, self.lora_kernel_a)
  lora_x = ops.matmul(lora_x, self.lora_kernel_b)
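The per-channel branch above boils down to: unpack, dequantize by dividing by the per-column scale, reshape back to the logical kernel shape, and run the original einsum. A NumPy sketch of the weight-only case under those assumptions (the equation and shapes are made up):

```python
import numpy as np

equation = "ab,bcd->acd"              # hypothetical EinsumDense equation
original_kernel_shape = (64, 8, 32)   # b=64 is the reduced axis
rows, columns = 64, 8 * 32

rng = np.random.default_rng(0)
q = rng.integers(-8, 8, size=(rows, columns), dtype=np.int8)        # unpacked int4 kernel
scale = rng.uniform(1.0, 10.0, size=(columns,)).astype(np.float32)  # per-column scale
inputs = rng.normal(size=(2, 64)).astype(np.float32)

# Dequantize (divide-by-scale convention), restore the logical shape, then einsum.
float_kernel = (q.astype(np.float32) / scale).reshape(original_kernel_shape)
outputs = np.einsum(equation, inputs, float_kernel)   # shape (2, 8, 32)
```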
@@ -1167,24 +1244,50 @@ class EinsumDense(Layer):
  kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")
  del self._kernel
  elif mode == "int4":
- # Quantize to int4 values (stored in int8 dtype, range [-8, 7])
- weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
- self.quantization_config,
- quantizers.AbsMaxQuantizer(
- axis=self._kernel_reduced_axes,
- value_range=(-8, 7),
- output_dtype="int8",
- ),
+ from keras.src.quantizers.quantization_config import (
+ Int4QuantizationConfig,
  )
- kernel_value_int4, kernel_scale = weight_quantizer(
- self._kernel, to_numpy=True
- )
- kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")

- # Pack along the first kernel-reduced axis.
- pack_axis = self._kernel_reduced_axes[0]
+ block_size = None
+ if isinstance(self.quantization_config, Int4QuantizationConfig):
+ block_size = self.quantization_config.block_size
+
+ use_grouped = block_size is not None and block_size != -1
+
+ # Flatten kernel to 2D: rows = reduced dims, columns = non-reduced
+ rows = 1
+ columns = 1
+ for i, dim in enumerate(kernel_shape):
+ if i in self._kernel_reduced_axes:
+ rows *= dim
+ else:
+ columns *= dim
+
+ flat_kernel = ops.reshape(self._kernel, (rows, columns))
+
+ if not use_grouped:
+ # Per-channel quantization
+ kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
+ flat_kernel,
+ axis=0,
+ value_range=(-8, 7),
+ dtype="int8",
+ to_numpy=True,
+ )
+ kernel_scale = ops.squeeze(kernel_scale, axis=0)
+ else:
+ # Sub-channel quantization with asymmetric zero point
+ # Returns kernel [rows, columns], scale [n_groups, columns]
+ kernel_value_int4, kernel_scale, kernel_zero = (
+ quantizers.abs_max_quantize_grouped_with_zero_point(
+ flat_kernel, block_size=block_size, to_numpy=True
+ )
+ )
+
+ # Pack two int4 values per int8 byte along last axis
+ # Stored as [rows, ceil(columns/2)]
  packed_kernel_value, _, _ = quantizers.pack_int4(
- kernel_value_int4, axis=pack_axis
+ kernel_value_int4, axis=-1
  )
  kernel_value = packed_kernel_value
  del self._kernel
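For reference, per-channel abs-max int4 quantization of the flattened `[rows, columns]` kernel can be sketched as follows, using the divide-by-scale convention seen elsewhere in this file; this is an illustrative sketch, not the `quantizers.abs_max_quantize` implementation:

```python
import numpy as np

def abs_max_quantize_int4_per_channel(flat_kernel, eps=1e-7):
    """float [rows, columns] -> (int8 values in [-8, 7], per-column scale [columns])."""
    abs_max = np.maximum(np.abs(flat_kernel).max(axis=0), eps)  # one max per column
    scale = 7.0 / abs_max                                       # dequantize is q / scale
    q = np.clip(np.round(flat_kernel * scale), -8, 7).astype(np.int8)
    return q, scale

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 256)).astype(np.float32)
q, scale = abs_max_quantize_int4_per_channel(w)
w_hat = q.astype(np.float32) / scale   # dequantized approximation of w
```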
@@ -1194,42 +1297,69 @@ class EinsumDense(Layer):
  if mode in ("int8", "int4"):
  self._kernel.assign(kernel_value)
  self.kernel_scale.assign(kernel_scale)
+ # Assign zero point for sub-channel int4 quantization
+ if mode == "int4" and use_grouped:
+ self.kernel_zero.assign(kernel_zero)

  # Set new dtype policy
  if self.dtype_policy.quantization_mode is None:
  policy_name = mode
- if mode == "gptq":
- policy_name = self.quantization_config.dtype_policy_string()
- elif mode == "awq":
+ if mode in ("gptq", "awq"):
  policy_name = self.quantization_config.dtype_policy_string()
+ elif mode == "int4":
+ # Include block_size in policy name for sub-channel quantization
+ block_size = get_block_size_for_layer(self, config)
+ # Use -1 for per-channel, otherwise use block_size
+ block_size_value = -1 if block_size is None else block_size
+ policy_name = f"int4/{block_size_value}"
  policy = dtype_policies.get(
  f"{policy_name}_from_{self.dtype_policy.name}"
  )
  self.dtype_policy = policy

- def _get_kernel_scale_shape(self, kernel_shape):
+ def _get_kernel_scale_shape(self, kernel_shape, block_size=None):
  """Get the shape of the kernel scale tensor.

  The kernel scale tensor is used to scale the kernel tensor.
  The shape of the kernel scale tensor is the same as the shape of the
- kernel tensor, but with the reduced axes set to 1 and the transpose
- axes set to the original axes
+ kernel tensor, but with the reduced axes set to 1 (for per-channel)
+ or n_groups (for grouped quantization), and the transpose axes set
+ to the original axes.

  Args:
  kernel_shape: The shape of the kernel tensor.
+ block_size: If provided and positive, use grouped quantization
+ along the reduced axes with the specified block size.

  Returns:
  The shape of the kernel scale tensor.
  """
- kernel_scale_shape = np.array(kernel_shape)
- kernel_scale_shape[self._kernel_reduced_axes] = 1
- kernel_scale_shape = kernel_scale_shape[self._kernel_transpose_axes]
- kernel_scale_shape = kernel_scale_shape.tolist()
- for a in sorted(self._kernel_expand_axes):
- kernel_scale_shape.insert(a, 1)
- for a in sorted(self._kernel_squeeze_axes, reverse=True):
- kernel_scale_shape.pop(a)
- return kernel_scale_shape
+ if block_size is not None and block_size > 0:
+ # Grouped quantization: use simple 2D scale shape
+ # (n_groups, non_reduced) - matches dequantize_grouped format
+ total_reduced_dim = 1
+ for ax in self._kernel_reduced_axes:
+ total_reduced_dim *= kernel_shape[ax]
+ n_groups = math.ceil(total_reduced_dim / block_size)
+
+ total_non_reduced = 1
+ for i, dim in enumerate(kernel_shape):
+ if i not in self._kernel_reduced_axes:
+ total_non_reduced *= dim
+
+ return (n_groups, total_non_reduced)
+ else:
+ # Per-channel quantization: use the original transformation logic
+ kernel_scale_shape = np.array(kernel_shape)
+ kernel_scale_shape[self._kernel_reduced_axes] = 1
+
+ kernel_scale_shape = kernel_scale_shape[self._kernel_transpose_axes]
+ kernel_scale_shape = kernel_scale_shape.tolist()
+ for a in sorted(self._kernel_expand_axes):
+ kernel_scale_shape.insert(a, 1)
+ for a in sorted(self._kernel_squeeze_axes, reverse=True):
+ kernel_scale_shape.pop(a)
+ return kernel_scale_shape

  def _get_kernel_with_merged_lora(self):
  """Returns the kernel with LoRA matrices merged, for serialization.
@@ -1258,33 +1388,53 @@ class EinsumDense(Layer):
  without modification.

  Returns:
- A tuple `(kernel_value, kernel_scale)`:
+ A tuple `(kernel_value, kernel_scale, kernel_zero)`:
  `kernel_value`: The merged kernel. A quantized tensor if
  quantization is active, otherwise a high precision tensor.
  `kernel_scale`: The quantization scale for the merged kernel.
  This is `None` if the layer is not quantized.
+ `kernel_zero`: The zero point for sub-channel int4 quantization.
+ This is `None` for per-channel or non-int4 modes.
  """
  # If not a quantized layer, return the full-precision kernel directly.
  if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
- return self.kernel, None
+ return self.kernel, None, None
+
+ kernel_zero = getattr(self, "kernel_zero", None)

  # If quantized but LoRA is not enabled, return the original quantized
  # kernel.
  if not self.lora_enabled:
- return self._kernel, self.kernel_scale
+ return self._kernel, self.kernel_scale, kernel_zero

  # Dequantize, Merge, and Re-quantize

  # 1. Dequantize the kernel
  if self.quantization_mode == "int4":
+ # Unpack [rows, ceil(columns/2)] to [rows, columns]
  unpacked_kernel = quantizers.unpack_int4(
  self._kernel,
- self._orig_length_along_pack_axis,
- axis=self._int4_pack_axis,
+ self._int4_unpacked_column_size,
+ axis=-1,
  )
- # Adjust scale for dequantization (reverse the transformations).
- adjusted_scale = self._adjust_scale_for_dequant(self.kernel_scale)
- kernel_fp = ops.divide(unpacked_kernel, adjusted_scale)
+ block_size = getattr(self, "_int4_block_size", None)
+ if block_size is not None and block_size != -1:
+ # Grouped dequantization with group_axis=0
+ kernel_fp = dequantize_with_sz_map(
+ unpacked_kernel,
+ self.kernel_scale,
+ self.kernel_zero,
+ self.g_idx,
+ group_axis=0,
+ )
+ else:
+ # Per-channel dequantization:
+ # kernel [rows, columns], scale [columns]
+ kernel_fp = ops.divide(
+ ops.cast(unpacked_kernel, self.compute_dtype),
+ self.kernel_scale,
+ )
+ kernel_fp = ops.reshape(kernel_fp, self.original_kernel_shape)
  elif self.quantization_mode == "int8":
  adjusted_scale = self._adjust_scale_for_dequant(self.kernel_scale)
  kernel_fp = ops.divide(self._kernel, adjusted_scale)
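The grouped branch delegates to `dequantize_with_sz_map` with `group_axis=0`. A plausible NumPy sketch of what such a scale/zero map amounts to, assuming the common `(q - zero) / scale` convention; the actual Keras helper may differ in its details:

```python
import numpy as np

def dequantize_grouped(q, scale, zero, g_idx):
    """Sub-channel int4 dequantization sketch.

    q:     int8  [rows, columns]    unpacked int4 values
    scale: float [n_groups, columns]
    zero:  int8  [n_groups, columns]
    g_idx: float [rows]             maps each row to its group index
    """
    g = g_idx.astype(np.int64)              # row -> group lookup
    row_scale = scale[g]                    # gathered to [rows, columns]
    row_zero = zero[g].astype(np.float32)
    return (q.astype(np.float32) - row_zero) / row_scale

rng = np.random.default_rng(0)
rows, columns, block_size = 64, 256, 32
n_groups = -(-rows // block_size)           # ceil division -> 2
q = rng.integers(-8, 8, size=(rows, columns), dtype=np.int8)
scale = rng.uniform(1.0, 10.0, size=(n_groups, columns)).astype(np.float32)
zero = rng.integers(-8, 8, size=(n_groups, columns), dtype=np.int8)
g_idx = np.floor_divide(np.arange(rows, dtype="float32"), block_size)
w_hat = dequantize_grouped(q, scale, zero, g_idx)   # [rows, columns] float
```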
@@ -1297,32 +1447,51 @@ class EinsumDense(Layer):
  lora_update = (self.lora_alpha / self.lora_rank) * ops.matmul(
  self.lora_kernel_a, self.lora_kernel_b
  )
- merged_kernel_fp = ops.add(kernel_fp, lora_update)
+ merged_kernel = ops.add(kernel_fp, lora_update)

  # 3. Re-quantize the merged float kernel back to the target format
  if self.quantization_mode == "int4":
- kernel_quant, new_scale = quantizers.abs_max_quantize(
- merged_kernel_fp,
- axis=self._kernel_reduced_axes,
- value_range=(-8, 7),
- dtype="int8",
- to_numpy=True,
- )
- # Pack back to int4
- new_kernel, _, _ = quantizers.pack_int4(
- kernel_quant, axis=self._int4_pack_axis
- )
+ block_size = getattr(self, "_int4_block_size", None)
+ rows = self._int4_rows
+ columns = self._int4_unpacked_column_size
+
+ # Flatten to 2D [rows, columns]
+ flat_kernel = ops.reshape(merged_kernel, (rows, columns))
+
+ if block_size is not None and block_size != -1:
+ # Use abs_max_quantize_grouped_with_zero_point for proper
+ # signed quantization (same as quantize() method)
+ # Returns kernel [rows, columns], scale [n_groups, columns]
+ kernel_quant, new_scale, new_zero = (
+ quantizers.abs_max_quantize_grouped_with_zero_point(
+ flat_kernel, block_size=block_size, to_numpy=True
+ )
+ )
+ kernel_zero = new_zero
+ else:
+ # Per-channel: quantize along rows axis
+ kernel_quant, new_scale = quantizers.abs_max_quantize(
+ flat_kernel,
+ axis=0,
+ value_range=(-8, 7),
+ dtype="int8",
+ to_numpy=True,
+ )
+ new_scale = ops.squeeze(new_scale, axis=0)
+ kernel_zero = None
+
+ # Pack along last axis
+ new_kernel, _, _ = quantizers.pack_int4(kernel_quant, axis=-1)
  elif self.quantization_mode == "int8":
  new_kernel, new_scale = quantizers.abs_max_quantize(
- merged_kernel_fp,
+ merged_kernel,
  axis=self._kernel_reduced_axes,
  to_numpy=True,
  )
+ new_scale = self._adjust_scale_for_quant(new_scale, "kernel")
+ kernel_zero = None

- # Adjust the new scale tensor to the required layout.
- new_scale = self._adjust_scale_for_quant(new_scale, "kernel")
-
- return new_kernel, new_scale
+ return new_kernel, new_scale, kernel_zero

  def _adjust_scale_for_dequant(self, scale):
  """Adjusts scale tensor layout for dequantization.