keras-nightly 3.14.0.dev2026012704__py3-none-any.whl → 3.14.0.dev2026012904__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/ops/__init__.py +1 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +1 -0
- keras/_tf_keras/keras/quantizers/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/ops/__init__.py +1 -0
- keras/ops/numpy/__init__.py +1 -0
- keras/quantizers/__init__.py +3 -0
- keras/src/backend/jax/core.py +12 -2
- keras/src/backend/jax/numpy.py +5 -0
- keras/src/backend/numpy/numpy.py +5 -0
- keras/src/backend/openvino/numpy.py +6 -0
- keras/src/backend/tensorflow/numpy.py +21 -0
- keras/src/backend/torch/numpy.py +10 -0
- keras/src/callbacks/orbax_checkpoint.py +41 -8
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +80 -1
- keras/src/layers/core/dense.py +278 -95
- keras/src/layers/core/einsum_dense.py +350 -181
- keras/src/layers/core/embedding.py +236 -49
- keras/src/layers/core/reversible_embedding.py +177 -35
- keras/src/layers/preprocessing/discretization.py +30 -1
- keras/src/ops/numpy.py +54 -0
- keras/src/quantizers/__init__.py +6 -0
- keras/src/quantizers/quantization_config.py +98 -4
- keras/src/quantizers/quantizers.py +262 -32
- keras/src/saving/file_editor.py +7 -1
- keras/src/saving/saving_api.py +66 -2
- keras/src/saving/saving_lib.py +46 -47
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/RECORD +34 -34
- {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/WHEEL +0 -0
- {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/top_level.txt +0 -0
keras/src/layers/core/einsum_dense.py

@@ -17,6 +17,7 @@ from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
 from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantization_config import get_block_size_for_layer
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
 from keras.src.saving import serialization_lib

@@ -238,11 +239,13 @@ class EinsumDense(Layer):

         # Handle int4 unpacking cases in one place
         if is_int4:
+            # unpack [rows, ceil(columns/2)] to [rows, columns]
             kernel = quantizers.unpack_int4(
                 kernel,
-                self.
-
+                self._int4_unpacked_column_size,
+                axis=-1,
             )
+            kernel = ops.reshape(kernel, self.original_kernel_shape)
         elif is_gptq and gptq_calibrated and gptq_bits == 4:
             kernel = quantizers.unpack_int4(
                 self.quantized_kernel,
@@ -306,16 +309,11 @@ class EinsumDense(Layer):
         self._tracker.unlock()
         # Determine the appropriate (unpacked) kernel shape for LoRA.
         if self.quantization_mode == "int4":
-            #
-            #
-            #
-            #
-            kernel_shape_for_lora =
-            pack_axis = getattr(self, "_int4_pack_axis", 0)
-            orig_len = getattr(self, "_orig_length_along_pack_axis", None)
-            if orig_len is not None:
-                kernel_shape_for_lora[pack_axis] = orig_len
-            kernel_shape_for_lora = tuple(kernel_shape_for_lora)
+            # INT4 weights are stored in a flattened 2D layout that loses
+            # the original N-dimensional structure required by the einsum
+            # equation. We use `original_kernel_shape`` to ensure LoRA adapters
+            # operate in the correct logical dimension space.
+            kernel_shape_for_lora = tuple(self.original_kernel_shape)
         else:
             kernel_shape_for_lora = self.kernel.shape

@@ -327,7 +325,7 @@ class EinsumDense(Layer):
         )
         self.lora_kernel_b = self.add_weight(
             name="lora_kernel_b",
-            shape=(rank,
+            shape=(rank, kernel_shape_for_lora[-1]),
             initializer=initializers.get(b_initializer),
             regularizer=self.kernel_regularizer,
         )
@@ -345,15 +343,27 @@ class EinsumDense(Layer):
         if mode not in self.variable_serialization_spec:
            raise self._quantization_mode_error(mode)

-        # Kernel plus optional merged LoRA-aware scale (returns
-        # for None/gptq)
-        kernel_value, merged_kernel_scale =
+        # Kernel plus optional merged LoRA-aware scale/zero (returns
+        # (kernel, None, None) for None/gptq)
+        kernel_value, merged_kernel_scale, merged_kernel_zero = (
+            self._get_kernel_with_merged_lora()
+        )
         idx = 0
         for name in self.variable_serialization_spec[mode]:
             if name == "kernel":
                 store[str(idx)] = kernel_value
             elif name == "bias" and self.bias is None:
                 continue
+            elif name == "kernel_zero":
+                if merged_kernel_zero is None:
+                    # kernel_zero only exists for sub-channel int4 quantization
+                    continue
+                store[str(idx)] = merged_kernel_zero
+            elif name == "g_idx":
+                if not hasattr(self, "g_idx"):
+                    # g_idx only exists for sub-channel int4 quantization
+                    continue
+                store[str(idx)] = self.g_idx
             elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
@@ -382,6 +392,12 @@ class EinsumDense(Layer):
                 self._kernel.assign(store[str(idx)])
             elif name == "bias" and self.bias is None:
                 continue
+            elif name == "kernel_zero" and not hasattr(self, "kernel_zero"):
+                # kernel_zero only exists for sub-channel int4 quantization
+                continue
+            elif name == "g_idx" and not hasattr(self, "g_idx"):
+                # g_idx only exists for sub-channel int4 quantization
+                continue
             else:
                 getattr(self, name).assign(store[str(idx)])
             idx += 1
@@ -452,6 +468,8 @@ class EinsumDense(Layer):
                 "kernel",
                 "bias",
                 "kernel_scale",
+                "kernel_zero",
+                "g_idx",
             ],
             "float8": [
                 "kernel",
@@ -761,60 +779,81 @@ class EinsumDense(Layer):
     def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.

-        The
-
-
-
-
+        The kernel is flattened to 2D [rows, columns]
+        and packed along last axis to [rows, ceil(columns/2)].
+
+        Args:
+            kernel_shape: Original kernel shape (may be N-dimensional).
+            config: Optional quantization config specifying block_size.
         """
         self._set_quantization_info()

-        # Quantizer for the inputs (per the reduced axes)
         self.inputs_quantizer = (
-            QuantizationConfig.activation_quantizer_or_default(
-                config,
-                quantizers.AbsMaxQuantizer(),
-            )
+            QuantizationConfig.activation_quantizer_or_default(config, None)
         )
-        # If the config provided a default AbsMaxQuantizer, we need to
-        # override the axis to match the equation's reduction axes.
         self.quantization_axis = tuple(self._input_reduced_axes)
+        self.original_kernel_shape = kernel_shape

-        #
-
-
-
-
-
-
-
-
-        # Packed length (ceil division by 2). Note: assumes static integer.
-        packed_len = (self._orig_length_along_pack_axis + 1) // 2
+        # Flatten kernel to 2D: rows = reduced dims, columns = non-reduced dims
+        rows = 1
+        columns = 1
+        for i, dim in enumerate(kernel_shape):
+            if i in self._kernel_reduced_axes:
+                rows *= dim
+            else:
+                columns *= dim

-
-
-
-
+        block_size = get_block_size_for_layer(self, config)
+        use_grouped = block_size is not None and block_size != -1
+        self._int4_block_size = block_size if use_grouped else None
+        self._int4_unpacked_column_size = columns
+        self._int4_rows = rows

-        #
+        # Kernel packed along last axis (columns)
+        # Stored shape: [rows, ceil(columns/2)]
+        packed_cols = (columns + 1) // 2
         self._kernel = self.add_weight(
             name="kernel",
-            shape=
+            shape=(rows, packed_cols),
             initializer="zeros",
             dtype="int8",
             trainable=False,
         )

-
-
+        if use_grouped:
+            # Sub-channel: [n_groups, columns]
+            n_groups = math.ceil(rows / block_size)
+            scale_shape = (n_groups, columns)
+        else:
+            scale_shape = (columns,)
+
         self.kernel_scale = self.add_weight(
             name="kernel_scale",
-            shape=
+            shape=scale_shape,
             initializer="ones",
             trainable=False,
         )

+        # Sub-channel quantization uses asymmetric quantization with zero point
+        if use_grouped:
+            self.kernel_zero = self.add_weight(
+                name="kernel_zero",
+                shape=scale_shape,
+                initializer="zeros",
+                dtype="int8",
+                trainable=False,
+            )
+            self.g_idx = self.add_weight(
+                name="g_idx",
+                shape=(rows,),
+                initializer="zeros",
+                dtype="float32",
+                trainable=False,
+            )
+            self.g_idx.assign(
+                ops.floor_divide(ops.arange(rows, dtype="float32"), block_size)
+            )
+
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy

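The `_int4_build` hunk above flattens the kernel to [rows, columns] (rows = product of the einsum-reduced axes, columns = product of the remaining axes) and stores two signed int4 values per int8 byte along the column axis, so the packed kernel has shape [rows, ceil(columns/2)]. A minimal NumPy sketch of that storage layout; the nibble order is illustrative only, the real packing is done by keras.src.quantizers.pack_int4:

    import math
    import numpy as np

    def packed_int4_shape(kernel_shape, reduced_axes):
        # Mirror of the rows/columns bookkeeping in _int4_build.
        rows, columns = 1, 1
        for i, dim in enumerate(kernel_shape):
            if i in reduced_axes:
                rows *= dim
            else:
                columns *= dim
        return rows, (columns + 1) // 2  # ceil division by 2

    def pack_pairs(q):
        # q: int4 values stored in int8, range [-8, 7], shape [rows, columns].
        if q.shape[1] % 2:  # pad an odd column count with a zero nibble
            q = np.concatenate([q, np.zeros((q.shape[0], 1), np.int8)], axis=1)
        u = q.astype(np.uint8) & 0x0F            # two's-complement nibbles
        packed = (u[:, 1::2] << 4) | u[:, 0::2]  # assumed nibble order
        return packed.astype(np.int8)            # [rows, ceil(columns/2)]

    # Kernel (32, 64) with reduced axis {0}: rows=32, columns=64, packed (32, 32).
    q = np.random.randint(-8, 8, size=(32, 64), dtype=np.int8)
    assert pack_pairs(q).shape == packed_int4_shape((32, 64), reduced_axes={0})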
@@ -948,98 +987,136 @@ class EinsumDense(Layer):
         return x

     def _int4_call(self, inputs, training=None):
-        """Forward pass for int4 quantized
+        """Forward pass for int4 quantized EinsumDense.

-
-
+        Uses custom gradients to handle quantized weights since autodiff
+        cannot differentiate through int4 operations.
+        """
+        block_size = getattr(self, "_int4_block_size", None)

-
-        def einsum_with_inputs_gradient(inputs, packed_kernel, kernel_scale):
-            """Performs int4 quantized einsum with a custom gradient.
+        if block_size is None or block_size == -1:

-
-
+            @ops.custom_gradient
+            def einsum_per_channel_with_inputs_gradient(
+                inputs, packed_kernel, kernel_scale
+            ):
+                """Per-channel int4 forward pass with custom gradient."""
+                # Unpack: stored as [rows, ceil(columns/2)],
+                # unpack along last axis
+                unpacked_kernel = quantizers.unpack_int4(
+                    packed_kernel,
+                    self._int4_unpacked_column_size,
+                    axis=-1,
+                    dtype="int8",
+                )

-
-
+                def _dequantize_kernel(unpacked, scale):
+                    # kernel is [rows, columns], scale is [columns]
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked, dtype=self.compute_dtype),
+                        scale,
+                    )
+                    return ops.reshape(float_kernel, self.original_kernel_shape)

-
-
-
-
+                def grad_fn(*args, upstream=None):
+                    if upstream is None:
+                        (upstream,) = args
+                    float_kernel = _dequantize_kernel(
+                        unpacked_kernel, kernel_scale
+                    )
+                    inputs_grad = ops.einsum(
+                        self._custom_gradient_equation, upstream, float_kernel
+                    )
+                    return (inputs_grad, None, None)

-
-
-
-
-
+                if self.inputs_quantizer:
+                    # Per-channel with input quantization
+                    float_kernel = _dequantize_kernel(
+                        unpacked_kernel, kernel_scale
+                    )
+                    inputs_q, inputs_scale = self.inputs_quantizer(
+                        inputs, axis=self.quantization_axis
+                    )
+                    inputs_scale = self._adjust_scale_for_quant(
+                        inputs_scale, "input"
+                    )
+                    # Cast inputs to float for einsum. This is a workaround
+                    # for PyTorch's einsum which doesn't support
+                    # mixed-precision inputs (int8 input, float kernel).
+                    if backend.backend() == "torch":
+                        x = ops.einsum(
+                            self.equation,
+                            ops.cast(inputs_q, self.compute_dtype),
+                            float_kernel,
+                        )
+                        x = ops.divide(x, inputs_scale)
+                    else:
+                        x = ops.einsum(self.equation, inputs_q, float_kernel)
+                        x = ops.cast(x, self.compute_dtype)
+                        x = ops.divide(x, inputs_scale)
+                else:
+                    # Weight-only per-channel quantization
+                    float_kernel = _dequantize_kernel(
+                        unpacked_kernel, kernel_scale
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                return x, grad_fn

-
-
-
-
-            unpacked_kernel = quantizers.unpack_int4(
-                packed_kernel, orig_len, axis=pack_axis
+            x = einsum_per_channel_with_inputs_gradient(
+                inputs,
+                ops.convert_to_tensor(self._kernel),
+                ops.convert_to_tensor(self.kernel_scale),
             )
+        else:

-
-
-
-
-
-
-
-
-
-
-
-
-                    self._custom_gradient_equation, upstream, float_kernel
+            @ops.custom_gradient
+            def einsum_sub_channel_with_inputs_gradient(
+                inputs, packed_kernel, kernel_scale, kernel_zero, g_idx
+            ):
+                """Sub-channel int4 forward pass with custom gradient."""
+                # Unpack: stored as [rows, ceil(columns/2)],
+                # unpack along last axis
+                unpacked_kernel = quantizers.unpack_int4(
+                    packed_kernel,
+                    self._int4_unpacked_column_size,
+                    axis=-1,
+                    dtype="int8",
                 )
-                return (inputs_grad, None, None)

-
-
-
-
-
-                # Align `inputs_scale` axes with the output
-                # for correct broadcasting
-                inputs_scale = self._adjust_scale_for_quant(
-                    inputs_scale, "input"
-                )
-                x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
-                # De-scale outputs
-                x = ops.cast(x, self.compute_dtype)
-                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
-            else:
-                # Weight-only quantization: dequantize kernel and use float
-                # einsum. This is a workaround for PyTorch's einsum which
-                # doesn't support mixed-precision inputs (float input,
-                # int4 kernel).
-                if backend.backend() == "torch":
-                    # Align `kernel_scale` to the same layout as
-                    # `unpacked_kernel`.
-                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
-                    float_kernel = ops.divide(
-                        ops.cast(unpacked_kernel, dtype=self.compute_dtype),
-                        kernel_scale,
+                def _dequantize_kernel(unpacked, scale, zero, g_idx_t):
+                    # Dequantize with group_axis=0 since
+                    # scale is [n_groups, columns]
+                    float_kernel = dequantize_with_sz_map(
+                        unpacked, scale, zero, g_idx_t, group_axis=0
                     )
-
-
-
-
-
-
-
+                    float_kernel = ops.cast(float_kernel, self.compute_dtype)
+                    return ops.reshape(float_kernel, self.original_kernel_shape)
+
+                def grad_fn(*args, upstream=None):
+                    if upstream is None:
+                        (upstream,) = args
+                    float_kernel = _dequantize_kernel(
+                        unpacked_kernel, kernel_scale, kernel_zero, g_idx
+                    )
+                    inputs_grad = ops.einsum(
+                        self._custom_gradient_equation, upstream, float_kernel
+                    )
+                    return (inputs_grad, None, None, None, None)

-
-
-
-
-
+                float_kernel = _dequantize_kernel(
+                    unpacked_kernel, kernel_scale, kernel_zero, g_idx
+                )
+                x = ops.einsum(self.equation, inputs, float_kernel)
+                return x, grad_fn
+
+            x = einsum_sub_channel_with_inputs_gradient(
+                inputs,
+                ops.convert_to_tensor(self._kernel),
+                ops.convert_to_tensor(self.kernel_scale),
+                ops.convert_to_tensor(self.kernel_zero),
+                ops.convert_to_tensor(self.g_idx),
+            )

-        # Add LoRA contribution if enabled
         if self.lora_enabled:
             lora_x = ops.einsum(self.equation, inputs, self.lora_kernel_a)
             lora_x = ops.matmul(lora_x, self.lora_kernel_b)
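For the per-channel path in `_int4_call` above, dequantization is an elementwise divide of the unpacked int4 kernel by its per-column scale followed by a reshape back to the original N-D kernel shape, after which the regular einsum runs. A shape-level NumPy sketch with made-up sizes (the bit unpacking and the custom gradient are omitted):

    import numpy as np

    # Hypothetical layer: equation "abc,cde->abde", kernel (c, d, e) = (16, 4, 8),
    # reduced kernel axis {0}, so rows = 16 and columns = 4 * 8 = 32.
    orig_kernel_shape = (16, 4, 8)
    rows, columns = 16, 32

    unpacked = np.random.randint(-8, 8, size=(rows, columns)).astype(np.float32)
    kernel_scale = np.random.uniform(1.0, 4.0, size=(columns,)).astype(np.float32)

    # The scale was defined as a multiplier at quantize time, so dequantization
    # divides by it, then restores the original kernel layout.
    float_kernel = (unpacked / kernel_scale).reshape(orig_kernel_shape)

    inputs = np.random.normal(size=(2, 3, 16)).astype(np.float32)  # "abc"
    outputs = np.einsum("abc,cde->abde", inputs, float_kernel)
    print(outputs.shape)  # (2, 3, 4, 8)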
@@ -1167,24 +1244,50 @@ class EinsumDense(Layer):
             kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")
             del self._kernel
         elif mode == "int4":
-
-
-                self.quantization_config,
-                quantizers.AbsMaxQuantizer(
-                    axis=self._kernel_reduced_axes,
-                    value_range=(-8, 7),
-                    output_dtype="int8",
-                ),
+            from keras.src.quantizers.quantization_config import (
+                Int4QuantizationConfig,
             )
-            kernel_value_int4, kernel_scale = weight_quantizer(
-                self._kernel, to_numpy=True
-            )
-            kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")

-
-
+            block_size = None
+            if isinstance(self.quantization_config, Int4QuantizationConfig):
+                block_size = self.quantization_config.block_size
+
+            use_grouped = block_size is not None and block_size != -1
+
+            # Flatten kernel to 2D: rows = reduced dims, columns = non-reduced
+            rows = 1
+            columns = 1
+            for i, dim in enumerate(kernel_shape):
+                if i in self._kernel_reduced_axes:
+                    rows *= dim
+                else:
+                    columns *= dim
+
+            flat_kernel = ops.reshape(self._kernel, (rows, columns))
+
+            if not use_grouped:
+                # Per-channel quantization
+                kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
+                    flat_kernel,
+                    axis=0,
+                    value_range=(-8, 7),
+                    dtype="int8",
+                    to_numpy=True,
+                )
+                kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            else:
+                # Sub-channel quantization with asymmetric zero point
+                # Returns kernel [rows, columns], scale [n_groups, columns]
+                kernel_value_int4, kernel_scale, kernel_zero = (
+                    quantizers.abs_max_quantize_grouped_with_zero_point(
+                        flat_kernel, block_size=block_size, to_numpy=True
+                    )
+                )
+
+            # Pack two int4 values per int8 byte along last axis
+            # Stored as [rows, ceil(columns/2)]
             packed_kernel_value, _, _ = quantizers.pack_int4(
-                kernel_value_int4, axis
+                kernel_value_int4, axis=-1
             )
             kernel_value = packed_kernel_value
             del self._kernel
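In the per-channel branch above, abs_max_quantize with value_range=(-8, 7) scales each column so that its largest magnitude maps onto the int4 limit. A rough NumPy equivalent; the exact epsilon and rounding behaviour live in keras.src.quantizers.abs_max_quantize, so treat this as an approximation of the convention (the scale is a multiplier, dequantization divides by it):

    import numpy as np

    def abs_max_int4_per_column(flat_kernel):
        # flat_kernel: [rows, columns] float, quantized along axis 0 (rows).
        abs_max = np.max(np.abs(flat_kernel), axis=0, keepdims=True)  # [1, columns]
        scale = 7.0 / np.maximum(abs_max, 1e-7)
        q = np.clip(np.round(flat_kernel * scale), -8, 7).astype(np.int8)
        return q, np.squeeze(scale, axis=0)  # as in ops.squeeze(kernel_scale, axis=0)

    w = np.random.normal(size=(32, 64)).astype(np.float32)
    q, scale = abs_max_int4_per_column(w)
    w_hat = q / scale                         # dequantize, as in _int4_call
    print(np.max(np.abs(w - w_hat)))          # small per-column quantization error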
@@ -1194,42 +1297,69 @@ class EinsumDense(Layer):
         if mode in ("int8", "int4"):
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
+            # Assign zero point for sub-channel int4 quantization
+            if mode == "int4" and use_grouped:
+                self.kernel_zero.assign(kernel_zero)

         # Set new dtype policy
         if self.dtype_policy.quantization_mode is None:
             policy_name = mode
-            if mode
-                policy_name = self.quantization_config.dtype_policy_string()
-            elif mode == "awq":
+            if mode in ("gptq", "awq"):
                 policy_name = self.quantization_config.dtype_policy_string()
+            elif mode == "int4":
+                # Include block_size in policy name for sub-channel quantization
+                block_size = get_block_size_for_layer(self, config)
+                # Use -1 for per-channel, otherwise use block_size
+                block_size_value = -1 if block_size is None else block_size
+                policy_name = f"int4/{block_size_value}"
             policy = dtype_policies.get(
                 f"{policy_name}_from_{self.dtype_policy.name}"
             )
             self.dtype_policy = policy

-    def _get_kernel_scale_shape(self, kernel_shape):
+    def _get_kernel_scale_shape(self, kernel_shape, block_size=None):
         """Get the shape of the kernel scale tensor.

         The kernel scale tensor is used to scale the kernel tensor.
         The shape of the kernel scale tensor is the same as the shape of the
-        kernel tensor, but with the reduced axes set to 1
-
+        kernel tensor, but with the reduced axes set to 1 (for per-channel)
+        or n_groups (for grouped quantization), and the transpose axes set
+        to the original axes.

         Args:
             kernel_shape: The shape of the kernel tensor.
+            block_size: If provided and positive, use grouped quantization
+                along the reduced axes with the specified block size.

         Returns:
             The shape of the kernel scale tensor.
         """
-
-
-
-
-
-
-
-
-
+        if block_size is not None and block_size > 0:
+            # Grouped quantization: use simple 2D scale shape
+            # (n_groups, non_reduced) - matches dequantize_grouped format
+            total_reduced_dim = 1
+            for ax in self._kernel_reduced_axes:
+                total_reduced_dim *= kernel_shape[ax]
+            n_groups = math.ceil(total_reduced_dim / block_size)
+
+            total_non_reduced = 1
+            for i, dim in enumerate(kernel_shape):
+                if i not in self._kernel_reduced_axes:
+                    total_non_reduced *= dim
+
+            return (n_groups, total_non_reduced)
+        else:
+            # Per-channel quantization: use the original transformation logic
+            kernel_scale_shape = np.array(kernel_shape)
+            kernel_scale_shape[self._kernel_reduced_axes] = 1
+
+            kernel_scale_shape = kernel_scale_shape[self._kernel_transpose_axes]
+            kernel_scale_shape = kernel_scale_shape.tolist()
+            for a in sorted(self._kernel_expand_axes):
+                kernel_scale_shape.insert(a, 1)
+            for a in sorted(self._kernel_squeeze_axes, reverse=True):
+                kernel_scale_shape.pop(a)
+            return kernel_scale_shape

     def _get_kernel_with_merged_lora(self):
         """Returns the kernel with LoRA matrices merged, for serialization.
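As a concrete instance of the grouped branch of `_get_kernel_scale_shape` above: for a kernel of shape (32, 64) with reduced axis 0 and block_size=8, the reduced ("rows") dimension is 32, so n_groups = ceil(32 / 8) = 4 and the scale/zero shape is (4, 64), while g_idx maps each of the 32 rows to its group:

    import math
    import numpy as np

    kernel_shape, reduced_axes, block_size = (32, 64), (0,), 8

    rows = math.prod(kernel_shape[a] for a in reduced_axes)                  # 32
    columns = math.prod(
        d for i, d in enumerate(kernel_shape) if i not in reduced_axes
    )                                                                        # 64
    n_groups = math.ceil(rows / block_size)                                  # 4

    scale_shape = (n_groups, columns)                                        # (4, 64)
    g_idx = np.floor_divide(np.arange(rows, dtype=np.float32), block_size)   # 0,0,...,3,3
    print(scale_shape, g_idx[:10])  # (4, 64) [0. 0. 0. 0. 0. 0. 0. 0. 1. 1.]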
@@ -1258,33 +1388,53 @@ class EinsumDense(Layer):
         without modification.

         Returns:
-            A tuple `(kernel_value, kernel_scale)`:
+            A tuple `(kernel_value, kernel_scale, kernel_zero)`:
             `kernel_value`: The merged kernel. A quantized tensor if
                 quantization is active, otherwise a high precision tensor.
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
+            `kernel_zero`: The zero point for sub-channel int4 quantization.
+                This is `None` for per-channel or non-int4 modes.
         """
         # If not a quantized layer, return the full-precision kernel directly.
         if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
-            return self.kernel, None
+            return self.kernel, None, None
+
+        kernel_zero = getattr(self, "kernel_zero", None)

         # If quantized but LoRA is not enabled, return the original quantized
         # kernel.
         if not self.lora_enabled:
-            return self._kernel, self.kernel_scale
+            return self._kernel, self.kernel_scale, kernel_zero

         # Dequantize, Merge, and Re-quantize

         # 1. Dequantize the kernel
         if self.quantization_mode == "int4":
+            # Unpack [rows, ceil(columns/2)] to [rows, columns]
             unpacked_kernel = quantizers.unpack_int4(
                 self._kernel,
-                self.
-                axis
+                self._int4_unpacked_column_size,
+                axis=-1,
             )
-
-
-
+            block_size = getattr(self, "_int4_block_size", None)
+            if block_size is not None and block_size != -1:
+                # Grouped dequantization with group_axis=0
+                kernel_fp = dequantize_with_sz_map(
+                    unpacked_kernel,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                    group_axis=0,
+                )
+            else:
+                # Per-channel dequantization:
+                # kernel [rows, columns], scale [columns]
+                kernel_fp = ops.divide(
+                    ops.cast(unpacked_kernel, self.compute_dtype),
+                    self.kernel_scale,
+                )
+            kernel_fp = ops.reshape(kernel_fp, self.original_kernel_shape)
         elif self.quantization_mode == "int8":
             adjusted_scale = self._adjust_scale_for_dequant(self.kernel_scale)
             kernel_fp = ops.divide(self._kernel, adjusted_scale)
@@ -1297,32 +1447,51 @@ class EinsumDense(Layer):
         lora_update = (self.lora_alpha / self.lora_rank) * ops.matmul(
             self.lora_kernel_a, self.lora_kernel_b
         )
-
+        merged_kernel = ops.add(kernel_fp, lora_update)

         # 3. Re-quantize the merged float kernel back to the target format
         if self.quantization_mode == "int4":
-
-
-
-
-
-
-
-
-
-
-
+            block_size = getattr(self, "_int4_block_size", None)
+            rows = self._int4_rows
+            columns = self._int4_unpacked_column_size
+
+            # Flatten to 2D [rows, columns]
+            flat_kernel = ops.reshape(merged_kernel, (rows, columns))
+
+            if block_size is not None and block_size != -1:
+                # Use abs_max_quantize_grouped_with_zero_point for proper
+                # signed quantization (same as quantize() method)
+                # Returns kernel [rows, columns], scale [n_groups, columns]
+                kernel_quant, new_scale, new_zero = (
+                    quantizers.abs_max_quantize_grouped_with_zero_point(
+                        flat_kernel, block_size=block_size, to_numpy=True
+                    )
+                )
+                kernel_zero = new_zero
+            else:
+                # Per-channel: quantize along rows axis
+                kernel_quant, new_scale = quantizers.abs_max_quantize(
+                    flat_kernel,
+                    axis=0,
+                    value_range=(-8, 7),
+                    dtype="int8",
+                    to_numpy=True,
+                )
+                new_scale = ops.squeeze(new_scale, axis=0)
+                kernel_zero = None
+
+            # Pack along last axis
+            new_kernel, _, _ = quantizers.pack_int4(kernel_quant, axis=-1)
         elif self.quantization_mode == "int8":
             new_kernel, new_scale = quantizers.abs_max_quantize(
-
+                merged_kernel,
                 axis=self._kernel_reduced_axes,
                 to_numpy=True,
             )
+            new_scale = self._adjust_scale_for_quant(new_scale, "kernel")
+            kernel_zero = None

-
-        new_scale = self._adjust_scale_for_quant(new_scale, "kernel")
-
-        return new_kernel, new_scale
+        return new_kernel, new_scale, kernel_zero

     def _adjust_scale_for_dequant(self, scale):
         """Adjusts scale tensor layout for dequantization.
|