keras-nightly 3.14.0.dev2026012804__py3-none-any.whl → 3.14.0.dev2026012904__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/quantizers/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/quantizers/__init__.py +3 -0
- keras/src/backend/jax/core.py +12 -2
- keras/src/callbacks/orbax_checkpoint.py +41 -8
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +80 -1
- keras/src/layers/core/dense.py +278 -95
- keras/src/layers/core/einsum_dense.py +350 -181
- keras/src/layers/core/embedding.py +236 -49
- keras/src/layers/core/reversible_embedding.py +177 -35
- keras/src/layers/preprocessing/discretization.py +30 -1
- keras/src/quantizers/__init__.py +6 -0
- keras/src/quantizers/quantization_config.py +98 -4
- keras/src/quantizers/quantizers.py +262 -32
- keras/src/saving/saving_api.py +66 -2
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/RECORD +22 -22
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/WHEEL +0 -0
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/top_level.txt +0 -0
keras/src/layers/core/dense.py
CHANGED
@@ -12,6 +12,7 @@ from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
 from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantization_config import get_block_size_for_layer
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
 from keras.src.saving import serialization_lib
 
@@ -184,7 +185,10 @@ class Dense(Layer):
 
         # Handle int4 unpacking cases in one place
         if is_int4:
-
+            # unpack [in, ceil(out/2)] to [in, out]
+            kernel = quantizers.unpack_int4(
+                kernel, self._orig_output_dim, axis=-1
+            )
         elif is_gptq and gptq_calibrated and gptq_bits == 4:
             kernel = quantizers.unpack_int4(
                 self.quantized_kernel,
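The unpacking change above (mirrored in `_int4_build` further down) reflects the new int4 storage layout: two 4-bit values are packed into one int8 byte along the last (output) axis, so a `(input_dim, units)` kernel is stored as `(input_dim, ceil(units / 2))` and unpacked with `axis=-1`. A minimal NumPy sketch of that layout, assuming two's-complement nibbles with even columns in the low nibble; it only illustrates the idea and is not Keras's `pack_int4`/`unpack_int4` implementation:

```python
import numpy as np

def pack_int4_last_axis(w):
    """Pack int4 values in [-8, 7] pairwise along the last axis."""
    if w.shape[-1] % 2:  # pad odd output dims with a zero column
        w = np.concatenate([w, np.zeros_like(w[..., :1])], axis=-1)
    lo = w[..., 0::2].astype(np.uint8) & 0x0F  # even columns -> low nibble
    hi = w[..., 1::2].astype(np.uint8) & 0x0F  # odd columns -> high nibble
    return ((hi << 4) | lo).view(np.int8)

def unpack_int4_last_axis(packed, orig_cols):
    """Recover the int4 values, trimming the padding column if any."""
    u = packed.view(np.uint8)
    lo = (((u & 0x0F) ^ 0x08).astype(np.int16) - 8).astype(np.int8)
    hi = (((u >> 4) ^ 0x08).astype(np.int16) - 8).astype(np.int8)
    out = np.empty(u.shape[:-1] + (u.shape[-1] * 2,), dtype=np.int8)
    out[..., 0::2], out[..., 1::2] = lo, hi
    return out[..., :orig_cols]

w = np.random.randint(-8, 8, size=(3, 5), dtype=np.int8)  # (input_dim, units)
packed = pack_int4_last_axis(w)                           # shape (3, 3)
assert np.array_equal(unpack_int4_last_axis(packed, 5), w)
```

Note that packing along the output axis leaves the input axis untouched, which is the axis the new per-group bookkeeping (`g_idx`, below) indexes.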
@@ -287,15 +291,27 @@ class Dense(Layer):
         if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # Kernel plus optional merged LoRA-aware scale (returns
-        # for None/gptq)
-        kernel_value, merged_kernel_scale =
+        # Kernel plus optional merged LoRA-aware scale/zero (returns
+        # (kernel, None, None) for None/gptq/awq)
+        kernel_value, merged_kernel_scale, merged_kernel_zero = (
+            self._get_kernel_with_merged_lora()
+        )
         idx = 0
         for name in self.variable_serialization_spec[mode]:
             if name == "kernel":
                 store[str(idx)] = kernel_value
             elif name == "bias" and self.bias is None:
                 continue
+            elif name == "kernel_zero":
+                if merged_kernel_zero is None:
+                    # kernel_zero only exists for sub-channel int4 quantization
+                    continue
+                store[str(idx)] = merged_kernel_zero
+            elif name == "g_idx":
+                if not hasattr(self, "g_idx"):
+                    # g_idx only exists for sub-channel int4 quantization
+                    continue
+                store[str(idx)] = self.g_idx
             elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
@@ -324,6 +340,12 @@ class Dense(Layer):
                 self._kernel.assign(store[str(idx)])
             elif name == "bias" and self.bias is None:
                 continue
+            elif name == "kernel_zero" and not hasattr(self, "kernel_zero"):
+                # kernel_zero only exists for sub-channel int4 quantization
+                continue
+            elif name == "g_idx" and not hasattr(self, "g_idx"):
+                # g_idx only exists for sub-channel int4 quantization
+                continue
             else:
                 getattr(self, name).assign(store[str(idx)])
             idx += 1
@@ -388,6 +410,8 @@ class Dense(Layer):
                 "kernel",
                 "bias",
                 "kernel_scale",
+                "kernel_zero",
+                "g_idx",
             ],
             "float8": [
                 "kernel",
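`kernel_zero` and `g_idx` are appended to the int4 entry of `variable_serialization_spec`, and the save/load loops above only advance the store index for variables that are actually written, so per-channel checkpoints (which have neither variable) keep their existing keys. A standalone illustration of that indexing contract; `save_by_spec` and the store layout are illustrative, not the actual Keras checkpoint format:

```python
def save_by_spec(spec, variables):
    """Write present variables under consecutive string keys, skipping absent ones."""
    store, idx = {}, 0
    for name in spec:
        if name not in variables:  # optional entries (e.g. kernel_zero) are skipped
            continue
        store[str(idx)] = variables[name]
        idx += 1
    return store

spec = ["kernel", "bias", "kernel_scale", "kernel_zero", "g_idx"]
per_channel = {"kernel": "K", "bias": "B", "kernel_scale": "S"}
sub_channel = {**per_channel, "kernel_zero": "Z", "g_idx": "G"}
print(save_by_spec(spec, per_channel))  # {'0': 'K', '1': 'B', '2': 'S'}
print(save_by_spec(spec, sub_channel))  # adds {'3': 'Z', '4': 'G'}
```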
@@ -630,37 +654,75 @@ class Dense(Layer):
     def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
-
-                `(input_dim, units)`.
-
-
+        The kernel is packed along the last axis,
+        resulting in shape `(input_dim, ceil(units/2))`.
+
+        Args:
+            kernel_shape: The original float32 kernel shape
+                `(input_dim, units)`.
+            config: Optional quantization config specifying block_size.
         """
-        # Per-channel int8 quantizer for the last axis (features).
         self.inputs_quantizer = (
-            QuantizationConfig.activation_quantizer_or_default(
-                config, quantizers.AbsMaxQuantizer()
-            )
+            QuantizationConfig.activation_quantizer_or_default(config, None)
         )
         input_dim, output_dim = kernel_shape
-        packed_rows = (input_dim + 1) // 2  # ceil for odd dims
 
-        #
+        # kernel is packed along last axis (output dimension)
+        # Stored shape: [input_dim, ceil(output_dim/2)]
+        packed_cols = (output_dim + 1) // 2
+
         self._kernel = self.add_weight(
             name="kernel",
-            shape=(
+            shape=(input_dim, packed_cols),
             initializer="zeros",
             dtype="int8",
             trainable=False,
         )
-
+
+        block_size = get_block_size_for_layer(self, config)
+        self._int4_block_size = block_size
+
+        if block_size is None or block_size == -1:
+            # Per-channel: one scale per output unit
+            scale_shape = (self.units,)
+        else:
+            # Sub-channel: [n_groups, out_features]
+            n_groups = math.ceil(input_dim / block_size)
+            scale_shape = (n_groups, self.units)
+
         self.kernel_scale = self.add_weight(
             name="kernel_scale",
-            shape=
+            shape=scale_shape,
             initializer="ones",
             trainable=False,
         )
-
+
+        # Sub-channel quantization uses asymmetric quantization
+        if block_size is not None and block_size > 0:
+
+            def idx_initializer(shape, dtype):
+                return ops.floor_divide(
+                    ops.arange(input_dim, dtype=dtype), block_size
+                )
+
+            self.kernel_zero = self.add_weight(
+                name="kernel_zero",
+                shape=scale_shape,
+                initializer="zeros",
+                dtype="int8",
+                trainable=False,
+            )
+            self.g_idx = self.add_weight(
+                name="g_idx",
+                shape=(input_dim,),
+                initializer=idx_initializer,
+                dtype="float32",
+                trainable=False,
+            )
+
+        # Record dimensions for unpacking and reshaping at runtime.
         self._orig_input_dim = input_dim
+        self._orig_output_dim = output_dim
 
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
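To make the new shape bookkeeping in `_int4_build` concrete: with sub-channel quantization, the scale and zero-point tensors get one row per group of `block_size` input features, and `g_idx` maps each input row to its group, mirroring the `floor_divide(arange(input_dim), block_size)` initializer above. A small NumPy sketch (the helper name is hypothetical):

```python
import math
import numpy as np

def int4_variable_shapes(input_dim, units, block_size=None):
    """Return (packed_kernel_shape, scale_shape, g_idx or None)."""
    packed_kernel_shape = (input_dim, (units + 1) // 2)  # two int4 values per byte
    if block_size is None or block_size == -1:
        return packed_kernel_shape, (units,), None  # per-channel
    n_groups = math.ceil(input_dim / block_size)
    g_idx = np.arange(input_dim) // block_size  # input row -> group id
    return packed_kernel_shape, (n_groups, units), g_idx

print(int4_variable_shapes(input_dim=10, units=7))
# ((10, 4), (7,), None)
packed_shape, scale_shape, g_idx = int4_variable_shapes(10, 7, block_size=4)
print(packed_shape, scale_shape, g_idx)
# (10, 4) (3, 7) [0 0 0 0 1 1 1 1 2 2]
```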
@@ -755,57 +817,108 @@ class Dense(Layer):
         return x
 
     def _int4_call(self, inputs, training=None):
-        """Forward pass for int4 quantized Dense layer.
+        """Forward pass for int4 quantized Dense layer.
 
-
-
-
-
-
-
-
-
-
-
-
+        Uses custom gradients to handle quantized weights since autodiff
+        cannot differentiate through int4 operations.
+        """
+        block_size = getattr(self, "_int4_block_size", None)
+
+        if block_size is None or block_size == -1:
+            # Per-channel: symmetric quantization (no zero point needed)
+            @ops.custom_gradient
+            def matmul_per_channel_with_inputs_gradient(
+                inputs, kernel, kernel_scale
+            ):
+                """Per-channel int4 forward pass with custom gradient."""
+                # Unpack: stored as [in, ceil(out/2)], unpack along last axis
+                unpacked_kernel = quantizers.unpack_int4(
+                    kernel, self._orig_output_dim, axis=-1
+                )
 
-
-
+                def grad_fn(*args, upstream=None):
+                    if upstream is None:
+                        (upstream,) = args
+                    # Per-channel: unpacked is [in, out]
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    inputs_grad = ops.matmul(
+                        upstream, ops.transpose(float_kernel)
+                    )
+                    return (inputs_grad, None, None)
+
+                # Forward pass: per-channel dequantization
+                output_scale = kernel_scale
+                if self.inputs_quantizer:
+                    inputs, inputs_scale = self.inputs_quantizer(
+                        inputs, axis=-1
+                    )
+                    output_scale = ops.multiply(output_scale, inputs_scale)
+
+                x = ops.matmul(inputs, unpacked_kernel)
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, output_scale)
+                return x, grad_fn
+
+            x = matmul_per_channel_with_inputs_gradient(
+                inputs,
+                ops.convert_to_tensor(self._kernel),
+                ops.convert_to_tensor(self.kernel_scale),
             )
+        else:
+            # Sub-channel: asymmetric quantization (with zero point)
+            @ops.custom_gradient
+            def matmul_sub_channel_with_inputs_gradient(
+                inputs, kernel, kernel_scale, kernel_zero, g_idx
+            ):
+                """Sub-channel int4 forward pass with custom gradient."""
+                # Unpack: stored as [in, ceil(out/2)], unpack along last axis
+                unpacked_kernel = quantizers.unpack_int4(
+                    kernel, self._orig_output_dim, axis=-1
+                )
 
-
-
-
-
-
+                def grad_fn(*args, upstream=None):
+                    if upstream is None:
+                        (upstream,) = args
+                    float_kernel = dequantize_with_sz_map(
+                        unpacked_kernel,
+                        kernel_scale,
+                        kernel_zero,
+                        g_idx,
+                        group_axis=0,
+                    )
+                    float_kernel = ops.cast(float_kernel, self.compute_dtype)
+                    inputs_grad = ops.matmul(
+                        upstream, ops.transpose(float_kernel)
+                    )
+                    return (inputs_grad, None, None, None, None)
+
+                float_kernel = dequantize_with_sz_map(
+                    unpacked_kernel,
                     kernel_scale,
+                    kernel_zero,
+                    g_idx,
+                    group_axis=0,
                 )
-
-
-
-            output_scale = kernel_scale
+                float_kernel = ops.cast(float_kernel, self.compute_dtype)
+                x = ops.matmul(inputs, float_kernel)
+                return x, grad_fn
 
-
-            inputs,
-
-
-
-
-
-            return x, grad_fn
-
-        x = matmul_with_inputs_gradient(
-            inputs,
-            ops.convert_to_tensor(self._kernel),
-            ops.convert_to_tensor(self.kernel_scale),
-        )
+            x = matmul_sub_channel_with_inputs_gradient(
+                inputs,
+                ops.convert_to_tensor(self._kernel),
+                ops.convert_to_tensor(self.kernel_scale),
+                ops.convert_to_tensor(self.kernel_zero),
+                ops.convert_to_tensor(self.g_idx),
+            )
 
         if self.lora_enabled:
             lora_x = ops.matmul(inputs, self.lora_kernel_a)
             lora_x = ops.matmul(lora_x, self.lora_kernel_b)
             x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
 
-        # Add bias and activation
         if self.bias is not None:
             x = ops.add(x, self.bias)
         if self.activation is not None:
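In the sub-channel branch above, both the forward pass and `grad_fn` dequantize through `dequantize_with_sz_map`, whose job (per the shapes in this diff) is to expand the `[n_groups, out]` scale and zero point back to `[in, out]` via `g_idx` before dequantizing. That helper's body lives in `keras/src/quantizers/quantizers.py` and is not shown here, so the sketch below uses the textbook asymmetric form `w ≈ (q - zero) * scale` purely to illustrate the per-row gather; Keras's exact formula and sign conventions may differ (its per-channel branch, for instance, divides by the stored scale):

```python
import numpy as np

def dequantize_grouped(q, scale, zero, g_idx):
    """q: [in, out] ints; scale/zero: [n_groups, out]; g_idx: [in] group ids."""
    scale_map = scale[g_idx, :]  # expand to [in, out]
    zero_map = zero[g_idx, :]    # expand to [in, out]
    return (q.astype(np.float32) - zero_map) * scale_map

q = np.array([[3, -2], [1, 0], [7, -8], [0, 4]], dtype=np.int8)  # [in=4, out=2]
scale = np.array([[0.5, 0.25], [0.1, 0.2]], dtype=np.float32)    # [n_groups=2, out=2]
zero = np.array([[0.0, 1.0], [-2.0, 0.0]], dtype=np.float32)
g_idx = np.array([0, 0, 1, 1])  # block_size = 2 over the input axis
print(dequantize_grouped(q, scale, zero, g_idx))
```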
@@ -925,26 +1038,49 @@ class Dense(Layer):
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "int4":
-
-
-                self.quantization_config,
-                quantizers.AbsMaxQuantizer(
-                    axis=0, value_range=(-8, 7), output_dtype="int8"
-                ),
+            from keras.src.quantizers.quantization_config import (
+                Int4QuantizationConfig,
             )
-
-
+
+            block_size = None
+            if isinstance(self.quantization_config, Int4QuantizationConfig):
+                block_size = self.quantization_config.block_size
+
+            if block_size is None or block_size == -1:
+                # Per-channel quantization
+                weight_quantizer = (
+                    QuantizationConfig.weight_quantizer_or_default(
+                        self.quantization_config,
+                        quantizers.AbsMaxQuantizer(
+                            axis=0, value_range=(-8, 7), output_dtype="int8"
+                        ),
+                    )
+                )
+                kernel_value_int4, kernel_scale = weight_quantizer(
+                    self._kernel, to_numpy=True
+                )
+                kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            else:
+                # Sub-channel quantization with asymmetric zero point
+                # Returns kernel [in, out], scale [n_groups, out], zero
+                # [n_groups, out]
+                kernel_value_int4, kernel_scale, kernel_zero = (
+                    quantizers.abs_max_quantize_grouped_with_zero_point(
+                        self._kernel, block_size=block_size, to_numpy=True
+                    )
+                )
+
+            # Pack two int4 values per int8 byte along last axis
+            # Stored as [in, ceil(out/2)]
+            packed_kernel_value, _, _ = quantizers.pack_int4(
+                kernel_value_int4, axis=-1
            )
-            kernel_scale = ops.squeeze(kernel_scale, axis=0)
-            # 2. Pack two int4 values into a single int8 byte.
-            packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
             del self._kernel
-            # Build variables using the original kernel shape; _int4_build will
-            # compute the packed shape internally.
             self.quantized_build(kernel_shape, mode, self.quantization_config)
-            # Assign packed values.
             self._kernel.assign(packed_kernel_value)
             self.kernel_scale.assign(kernel_scale)
+            if block_size is not None and block_size > 0:
+                self.kernel_zero.assign(kernel_zero)
         elif mode == "gptq":
             self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "awq":
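`abs_max_quantize_grouped_with_zero_point` itself is defined in `keras/src/quantizers/quantizers.py` (also touched in this release); only its interface is visible here: it takes the float kernel and a `block_size` and returns int4 values `[in, out]` plus a scale and zero point of shape `[n_groups, out]`. A hedged NumPy sketch of one way a grouped asymmetric int4 quantizer with that interface can be written; the min/max-based formula below is illustrative, not necessarily what Keras does:

```python
import numpy as np

def quantize_grouped_int4(w, block_size):
    """w: [in, out] floats -> (q in [-8, 7], scale [n_groups, out], zero [n_groups, out])."""
    in_dim, out_dim = w.shape
    n_groups = -(-in_dim // block_size)  # ceil division
    pad = n_groups * block_size - in_dim
    wp = np.pad(w, ((0, pad), (0, 0)))   # simplification: zero-pad a partial last group
    groups = wp.reshape(n_groups, block_size, out_dim)
    w_min, w_max = groups.min(axis=1), groups.max(axis=1)  # [n_groups, out]
    scale = np.maximum((w_max - w_min) / 15.0, 1e-8)       # 16 int4 levels
    zero = np.round(-8.0 - w_min / scale)                  # maps w_min to -8
    q = np.round(groups / scale[:, None, :] + zero[:, None, :])
    q = np.clip(q, -8, 7).reshape(-1, out_dim)[:in_dim].astype(np.int8)
    return q, scale, zero

w = np.random.randn(10, 4).astype(np.float32)
q, scale, zero = quantize_grouped_int4(w, block_size=4)
print(q.shape, scale.shape, zero.shape)  # (10, 4) (3, 4) (3, 4)
```

In this convention the round trip is `(q - zero) * scale ≈ w`, which is consistent with the shape contract the new `_int4_call` and LoRA-merge paths rely on.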
@@ -959,10 +1095,14 @@ class Dense(Layer):
         from keras.src import dtype_policies  # local import to avoid cycle
 
         policy_name = mode
-        if mode == "gptq":
-            policy_name = self.quantization_config.dtype_policy_string()
-        elif mode == "awq":
+        if mode in ("gptq", "awq"):
             policy_name = self.quantization_config.dtype_policy_string()
+        elif mode == "int4":
+            # Include block_size in policy name for sub-channel quantization
+            block_size = get_block_size_for_layer(self, config)
+            # Use -1 for per-channel, otherwise use block_size
+            block_size_value = -1 if block_size is None else block_size
+            policy_name = f"int4/{block_size_value}"
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
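The resulting dtype-policy strings therefore encode the block size, so per-channel and sub-channel int4 layers deserialize into distinct policies. Illustrative values inferred from the two f-strings above (the source dtype names are only examples):

```python
# policy_name = f"int4/{block_size_value}"; full name appends "_from_<source dtype policy>"
per_channel_policy = "int4/-1_from_float32"    # block_size is None -> -1
sub_channel_policy = "int4/128_from_bfloat16"  # block_size = 128
```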
@@ -991,32 +1131,49 @@ class Dense(Layer):
        without modification.
 
        Returns:
-            A tuple `(kernel_value, kernel_scale)`:
+            A tuple `(kernel_value, kernel_scale, kernel_zero)`:
                `kernel_value`: The merged kernel. A quantized tensor if
                    quantization is active, otherwise a high precision tensor.
                `kernel_scale`: The quantization scale for the merged kernel.
                    This is `None` if the layer is not quantized.
+                `kernel_zero`: The zero point for sub-channel int4 quantization.
+                    This is `None` for per-channel or non-int4 modes.
        """
        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
-            return self.kernel, None
+            return self.kernel, None, None
 
        kernel_value = self._kernel
        kernel_scale = self.kernel_scale
+        kernel_zero = getattr(self, "kernel_zero", None)
 
        if not self.lora_enabled:
-            return kernel_value, kernel_scale
+            return kernel_value, kernel_scale, kernel_zero
 
        # Dequantize, Merge, and Re-quantize
+        block_size = getattr(self, "_int4_block_size", None)
 
-        # Dequantize kernel to float
+        # Step 1: Dequantize kernel to float
        if self.quantization_mode == "int4":
+            # Unpack along last axis ([in, out])
            unpacked_kernel = quantizers.unpack_int4(
-                kernel_value, self.
-            )
-            float_kernel = ops.divide(
-                ops.cast(unpacked_kernel, self.compute_dtype),
-                kernel_scale,
+                kernel_value, self._orig_output_dim, axis=-1
            )
+            if block_size is None or block_size == -1:
+                # Per-channel: kernel [in, out], scale [out]
+                float_kernel = ops.divide(
+                    ops.cast(unpacked_kernel, self.compute_dtype),
+                    kernel_scale,
+                )
+            else:
+                # Sub-channel: scale/zero are [n_groups, out]
+                float_kernel = dequantize_with_sz_map(
+                    unpacked_kernel,
+                    kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                    group_axis=0,
+                )
+                float_kernel = ops.cast(float_kernel, self.compute_dtype)
            quant_range = (-8, 7)
        elif self.quantization_mode == "int8":
            float_kernel = ops.divide(
@@ -1028,25 +1185,51 @@ class Dense(Layer):
                f"Unsupported quantization mode: {self.quantization_mode}"
            )
 
-        # Merge LoRA weights in float domain
+        # Step 2: Merge LoRA weights in float domain
        lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
            self.lora_kernel_a, self.lora_kernel_b
        )
        merged_float_kernel = ops.add(float_kernel, lora_delta)
 
-        #
-
-
-
-
-
-
-
-
+        # Step 3: Re-quantize the merged kernel
+        if (
+            self.quantization_mode == "int4"
+            and block_size is not None
+            and block_size != -1
+        ):
+            # Sub-channel: returns kernel [in, out], scale [n_groups, out]
+            requantized_kernel, kernel_scale, kernel_zero = (
+                quantizers.abs_max_quantize_grouped_with_zero_point(
+                    merged_float_kernel, block_size=block_size, to_numpy=True
+                )
+            )
+        elif self.quantization_mode == "int4":
+            # Per-channel: quantize along input axis (axis=0)
+            requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+                merged_float_kernel,
+                axis=0,
+                value_range=quant_range,
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            kernel_zero = None
+        else:
+            requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+                merged_float_kernel,
+                axis=0,
+                value_range=quant_range,
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            kernel_zero = None
 
-        # Pack if int4
        if self.quantization_mode == "int4":
-
+            # Pack along last axis
+            kernel_value, _, _ = quantizers.pack_int4(
+                requantized_kernel, axis=-1
+            )
        else:
            kernel_value = requantized_kernel
-        return kernel_value, kernel_scale
+        return kernel_value, kernel_scale, kernel_zero
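For orientation, the merged-LoRA path above is a dequantize, add-delta, requantize loop: unpack and dequantize the stored kernel, fold `lora_kernel_a @ lora_kernel_b` in the float domain, then re-quantize (and, for int4, re-pack). A self-contained per-channel sketch of that flow using the divide-by-scale convention visible in the per-channel branch; the numerics are illustrative and edge cases (zero columns, dtype handling) are ignored:

```python
import numpy as np

def merge_lora_per_channel_int4(q, scale, lora_a, lora_b, alpha, rank):
    """Dequantize an int4 kernel, fold in the LoRA delta, requantize.

    q: [in, out] int4 values; scale: [out] with dequant = q / scale;
    lora_a: [in, r]; lora_b: [r, out].
    """
    # Step 1: dequantize (per-channel, divide-by-scale convention)
    float_kernel = q.astype(np.float32) / scale
    # Step 2: merge the low-rank update in the float domain
    merged = float_kernel + (alpha / rank) * (lora_a @ lora_b)
    # Step 3: requantize per output column with a fresh scale
    new_scale = 7.0 / np.maximum(np.abs(merged).max(axis=0), 1e-8)
    new_q = np.clip(np.round(merged * new_scale), -8, 7).astype(np.int8)
    return new_q, new_scale  # the real layer packs new_q along the last axis

rng = np.random.default_rng(0)
q = rng.integers(-8, 8, size=(16, 8)).astype(np.int8)
scale = rng.uniform(1.0, 4.0, size=(8,)).astype(np.float32)
a = rng.normal(size=(16, 2)).astype(np.float32)  # lora_kernel_a
b = rng.normal(size=(2, 8)).astype(np.float32)   # lora_kernel_b
new_q, new_scale = merge_lora_per_channel_int4(q, scale, a, b, alpha=1.0, rank=2)
```

The sub-channel case follows the same three steps but dequantizes through the scale/zero map and re-quantizes with the grouped quantizer, as the new code above shows.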