keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/quantizers/quantization_config.py (new file)

```diff
@@ -0,0 +1,246 @@
+from keras.src.api_export import keras_export
+from keras.src.dtype_policies import QUANTIZATION_MODES
+from keras.src.saving import serialization_lib
+
+
+@keras_export("keras.quantizers.QuantizationConfig")
+class QuantizationConfig:
+    """Base class for quantization configs.
+
+    Subclasses must implement the `mode` property and the `get_config` and
+    `from_config` class methods.
+
+    Args:
+        weight_quantizer: Quantizer for weights.
+        activation_quantizer: Quantizer for activations.
+    """
+
+    def __init__(self, weight_quantizer=None, activation_quantizer=None):
+        self.weight_quantizer = weight_quantizer
+        self.activation_quantizer = activation_quantizer
+
+    @property
+    def mode(self):
+        raise NotImplementedError(
+            "Subclasses must implement this property. Do not instantiate "
+            "QuantizationConfig directly."
+        )
+
+    def get_config(self):
+        return {
+            "weight_quantizer": serialization_lib.serialize_keras_object(
+                self.weight_quantizer
+            ),
+            "activation_quantizer": serialization_lib.serialize_keras_object(
+                self.activation_quantizer
+            ),
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        weight_quantizer = serialization_lib.deserialize_keras_object(
+            config.get("weight_quantizer")
+        )
+        activation_quantizer = serialization_lib.deserialize_keras_object(
+            config.get("activation_quantizer")
+        )
+        return cls(
+            weight_quantizer=weight_quantizer,
+            activation_quantizer=activation_quantizer,
+        )
+
+    @staticmethod
+    def weight_quantizer_or_default(config, default):
+        if config is not None and config.weight_quantizer is not None:
+            return config.weight_quantizer
+        return default
+
+    @staticmethod
+    def activation_quantizer_or_default(config, default):
+        if config is not None:
+            return config.activation_quantizer
+        return default
+
+
+@keras_export("keras.quantizers.Int8QuantizationConfig")
+class Int8QuantizationConfig(QuantizationConfig):
+    """Int8 quantization config.
+
+    Args:
+        weight_quantizer: Quantizer for weights.
+        activation_quantizer: Quantizer for activations. If "default", uses
+            AbsMaxQuantizer with axis=-1.
+    """
+
+    def __init__(self, weight_quantizer=None, activation_quantizer="default"):
+        from keras.src.quantizers.quantizers import AbsMaxQuantizer
+
+        if activation_quantizer == "default":
+            activation_quantizer = AbsMaxQuantizer()
+        super().__init__(weight_quantizer, activation_quantizer)
+        if self.weight_quantizer is not None:
+            if self.weight_quantizer.output_dtype != "int8":
+                raise ValueError(
+                    "Int8QuantizationConfig requires a weight_quantizer "
+                    "with output_dtype='int8'. Received: "
+                    f"output_dtype={self.weight_quantizer.output_dtype}"
+                )
+
+    @property
+    def mode(self):
+        return "int8"
+
+
+@keras_export("keras.quantizers.Int4QuantizationConfig")
+class Int4QuantizationConfig(QuantizationConfig):
+    """Int4 quantization config.
+
+    Args:
+        weight_quantizer: Quantizer for weights.
+        activation_quantizer: Quantizer for activations. If "default", uses
+            AbsMaxQuantizer with axis=-1.
+    """
+
+    def __init__(self, weight_quantizer=None, activation_quantizer="default"):
+        from keras.src.quantizers.quantizers import AbsMaxQuantizer
+
+        if activation_quantizer == "default":
+            activation_quantizer = AbsMaxQuantizer()
+        super().__init__(weight_quantizer, activation_quantizer)
+        if self.weight_quantizer is not None:
+            if self.weight_quantizer.value_range != (-8, 7):
+                raise ValueError(
+                    "Int4QuantizationConfig requires a weight_quantizer "
+                    "with value_range=(-8, 7). Received: "
+                    f"value_range={self.weight_quantizer.value_range}"
+                )
+
+            if self.weight_quantizer.output_dtype != "int8":
+                raise ValueError(
+                    "Int4QuantizationConfig requires a weight_quantizer "
+                    "with output_dtype='int8'. Received: "
+                    f"output_dtype={self.weight_quantizer.output_dtype}"
+                )
+
+    @property
+    def mode(self):
+        return "int4"
+
+
+@keras_export("keras.quantizers.Float8QuantizationConfig")
+class Float8QuantizationConfig(QuantizationConfig):
+    """FP8 quantization config.
+
+    FP8 mixed-precision training does not support user defined quantizers.
+    This config is only used to indicate that FP8 mixed-precision training
+    should be used.
+    """
+
+    def __init__(self):
+        super().__init__(None, None)
+
+    @property
+    def mode(self):
+        return "float8"
+
+    def get_config(self):
+        return {}
+
+    @classmethod
+    def from_config(cls, config):
+        return cls()
+
+
+def validate_and_resolve_config(mode, config):
+    """Validate and resolve quantization config.
+
+    This function validates the quantization config and resolves the mode.
+    If mode is not provided, it is inferred from the config.
+    If config is not provided, a default config is inferred from the mode.
+
+    Args:
+        mode: Quantization mode.
+        config: Quantization config.
+    """
+    # 1. Backwards Compatibility: Handle string shortcuts.
+    if isinstance(config, str):
+        mode = config
+        config = None
+
+    _validate_mode(mode)
+
+    # 2. Resolve "mode" into a Config object.
+    if config is None:
+        if mode == "int8":
+            config = Int8QuantizationConfig()
+        elif mode == "int4":
+            config = Int4QuantizationConfig()
+        elif mode == "float8":
+            config = Float8QuantizationConfig()
+        elif mode == "gptq":
+            raise ValueError(
+                "For GPTQ, you must pass a `GPTQConfig` object in the "
+                "`config` argument."
+            )
+        elif mode == "awq":
+            raise ValueError(
+                "For AWQ, you must pass an `AWQConfig` object in the "
+                "`config` argument."
+            )
+        else:
+            if mode is not None:
+                raise ValueError(
+                    f"Invalid quantization mode. Received: mode={mode}"
+                )
+            raise ValueError(
+                "You must provide either `mode` or `config` to `quantize`."
+            )
+    else:
+        if not isinstance(config, QuantizationConfig):
+            raise ValueError(
+                "Argument `config` must be an instance of "
+                "`QuantizationConfig`. "
+                f"Received: config={config} (of type {type(config)})"
+            )
+
+        # 3. Validation: Prevent contradictions.
+        if mode is not None and config.mode != mode:
+            raise ValueError(
+                f"Contradictory arguments: mode='{mode}' but "
+                f"config.mode='{config.mode}'"
+            )
+
+    # Ensure mode is consistent.
+    mode = config.mode
+
+    # Ensure the mode derived from the config is valid.
+    _validate_mode(mode)
+
+    if mode == "gptq":
+        from keras.src.quantizers.gptq_config import GPTQConfig
+
+        if not isinstance(config, GPTQConfig):
+            raise ValueError(
+                "Mode 'gptq' requires a valid `config` argument of type "
+                f"`GPTQConfig`. Received: {type(config)}"
+            )
+
+    if mode == "awq":
+        from keras.src.quantizers.awq_config import AWQConfig
+
+        if not isinstance(config, AWQConfig):
+            raise ValueError(
+                "Mode 'awq' requires a valid `config` argument of type "
+                f"`AWQConfig`. Received: {type(config)}"
+            )
+
+    return config
+
+
+def _validate_mode(mode):
+    """Validates quantization mode."""
+    if mode is not None and mode not in QUANTIZATION_MODES:
+        raise ValueError(
+            "Invalid quantization mode. "
+            f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
+        )
```
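For orientation, a minimal usage sketch of the new config objects. It assumes the `keras.quantizers.*` exports added above and that `Model.quantize` accepts the `config` argument which `validate_and_resolve_config` resolves (the call site itself is not part of this hunk):

```python
# Sketch only; assumes Model.quantize(config=...) routes through
# validate_and_resolve_config as described in this file.
import keras
from keras import quantizers

model = keras.Sequential([keras.Input((16,)), keras.layers.Dense(8)])

# Explicit config object instead of the older string shortcut "int8".
config = quantizers.Int8QuantizationConfig()  # default AbsMaxQuantizer for activations
model.quantize(config=config)  # assumed signature; see model.py changes in this diff

# The string shortcut keeps working via the backwards-compatibility branch:
# model.quantize("int8") resolves to Int8QuantizationConfig() internally.
```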
keras/src/quantizers/quantizers.py

```diff
@@ -73,6 +73,23 @@ def abs_max_quantize(
     epsilon=backend.epsilon(),
     to_numpy=False,
 ):
+    """
+    Quantizes the input tensor using the absolute maximum quantization scheme.
+
+    Args:
+        inputs: Input tensor to quantize.
+        axis: Axis along which to compute the quantization range.
+        value_range: Tuple of the minimum and maximum values of the quantization
+            range.
+        dtype: Data type of the quantized output.
+        epsilon: Small value to avoid division by zero.
+        to_numpy: Whether to perform the quantization in numpy. This performs
+            the computation on the host CPU and can be useful for saving memory
+            on the device. If False, the computation is performed on the device.
+
+    Returns:
+        A tuple of the quantized tensor and the scale.
+    """
     if to_numpy:
         # Save memory on the device using numpy
         original_dtype = backend.standardize_dtype(inputs.dtype)
```
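As a rough illustration of what the documented scheme computes (a hand-rolled NumPy sketch, not the library implementation):

```python
import numpy as np

def abs_max_quantize_sketch(x, axis=-1, value_range=(-127, 127), epsilon=1e-7):
    # Scale so that the largest |x| along `axis` maps to value_range[1].
    abs_max = np.max(np.abs(x), axis=axis, keepdims=True)
    scale = value_range[1] / (abs_max + epsilon)
    q = np.clip(np.round(x * scale), value_range[0], value_range[1])
    return q.astype("int8"), scale  # dequantize with q / scale

x = np.array([[0.02, -1.5, 0.75]], dtype="float32")
q, scale = abs_max_quantize_sketch(x)
# q -> [[2, -127, 64]], scale -> [[84.67]] (approximately)
```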
```diff
@@ -105,31 +122,69 @@ def abs_max_quantize(
 class AbsMaxQuantizer(Quantizer):
     def __init__(
         self,
-        axis,
+        axis=None,  # Deprecated, provide axis in __call__ instead.
         value_range=(-127, 127),
         epsilon=backend.epsilon(),
         output_dtype="int8",
     ):
         Quantizer.__init__(self, output_dtype=output_dtype)
-        if isinstance(axis, int):
-            axis = (axis,)
-        self.axis = tuple(axis)
+        if axis is not None:
+            if isinstance(axis, int):
+                axis = (axis,)
+            self.axis = tuple(axis)
+        else:
+            self.axis = None
         self.value_range = value_range
         self.epsilon = epsilon
+        if output_dtype == "int8":
+            if value_range[0] < -128 or value_range[1] > 127:
+                raise ValueError(
+                    f"Quantizer with output_dtype='int8' requires value_range "
+                    f"to be within the interval [-128, 127]. Received: "
+                    f"value_range={value_range}"
+                )
 
-    def __call__(self, x):
+    def __call__(self, x, axis=None, to_numpy=False):
+        """
+        Quantizes the input tensor.
+
+        Args:
+            x: Input tensor to quantize.
+            axis: Axis along which to compute the quantization range. If None,
+                uses the axis specified in the constructor. If None and no axis
+                was specified in the constructor, defaults to -1.
+            to_numpy: Whether to perform the quantization in numpy. This
+                performs the computation on the host CPU and can be useful for
+                saving memory on the device. If False, the computation is
+                performed on the device.
+
+        Returns:
+            A tuple of the quantized tensor and the scale.
+        """
+        if axis is None:
+            axis = self.axis
+        if axis is None:
+            # Default to -1 if no axis is specified
+            axis = -1
         quantized_x, scale = abs_max_quantize(
-            x, self.axis, self.value_range, self.output_dtype, self.epsilon
+            x,
+            axis,
+            self.value_range,
+            self.output_dtype,
+            self.epsilon,
+            to_numpy,
         )
         return quantized_x, scale
 
     def get_config(self):
-        return {
-            "axis": self.axis,
+        config = {
             "value_range": self.value_range,
             "epsilon": self.epsilon,
             "output_dtype": self.output_dtype,
         }
+        if self.axis is not None:
+            config["axis"] = self.axis
+        return config
 
 
 def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
```
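A short sketch of the revised call pattern, assuming the signatures shown in this hunk (`axis` is now optional at construction time and can be supplied per call):

```python
from keras import ops
from keras.quantizers import AbsMaxQuantizer

quantizer = AbsMaxQuantizer()  # no axis required here anymore
x = ops.convert_to_tensor([[0.1, -2.0, 0.5], [1.0, 0.0, -0.25]])

# axis and to_numpy are now per-call options; axis defaults to -1.
q, scale = quantizer(x, axis=-1)
x_approx = ops.divide(ops.cast(q, "float32"), scale)  # rough dequantization
```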
```diff
@@ -281,7 +336,7 @@ def fake_quant_with_min_max_vars(
             ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
         )
         x_clamped = ops.clip(
-            x, nudged_min, nudged_max
+            ops.cast(x, nudged_min.dtype), nudged_min, nudged_max
         )
         x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
         result = ops.multiply(
```

```diff
@@ -318,6 +373,7 @@ def fake_quant_with_min_max_vars(
             grad_min = ops.sum(grad_min, axis=axes)
         else:
             grad_min = ops.sum(grad_min)
+        grad_min = ops.reshape(grad_min, ops.shape(min_val))
 
         # Gradient for max_val
         # When x is clipped to max, the gradient flows to max_val
```

```diff
@@ -327,6 +383,7 @@ def fake_quant_with_min_max_vars(
             grad_max = ops.sum(grad_max, axis=axes)
         else:
             grad_max = ops.sum(grad_max)
+        grad_max = ops.reshape(grad_max, ops.shape(max_val))
 
         return dx, grad_min, grad_max
 
```
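The two added `ops.reshape` calls address a shape mismatch: for per-tensor ranges, `ops.sum` over all axes collapses the gradient to a scalar, while `min_val`/`max_val` may have shape `(1,)`. A minimal sketch of the effect (hypothetical shapes, not the library's test case):

```python
from keras import ops

min_val = ops.zeros((1,))              # per-tensor range variable
grad = ops.ones((4, 4))                # upstream gradient
grad_min = ops.sum(grad)               # scalar, shape ()
grad_min = ops.reshape(grad_min, ops.shape(min_val))  # back to shape (1,)
```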
```diff
@@ -596,11 +653,14 @@ def unpack_int4(packed, orig_len, axis=0, dtype="int8"):
     )
 
     def to_signed(x):
-        """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7]."""
+        """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7].
+
+        Uses a branchless XOR approach: (x ^ 8) - 8
+        This maps: 0->0, 1->1, ..., 7->7, 8->-8, 9->-7, ..., 15->-1
+        """
         dtype_x = backend.standardize_dtype(x.dtype)
         eight = ops.cast(8, dtype_x)
-        sixteen = ops.cast(16, dtype_x)
-        return ops.where(x < eight, x, x - sixteen)
+        return ops.subtract(ops.bitwise_xor(x, eight), eight)
 
     rank = getattr(packed.shape, "rank", None) or len(packed.shape)
     if axis < 0:
```
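The new docstring's mapping is easy to verify on plain integers; a quick standalone check of the branchless conversion:

```python
# (x ^ 8) - 8 reinterprets an unsigned nibble as a two's-complement int4:
# 0..7 stay unchanged, 8..15 wrap around to -8..-1.
signed = [(x ^ 8) - 8 for x in range(16)]
assert signed == list(range(8)) + list(range(-8, 0))
```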
```diff
@@ -691,7 +751,7 @@ class GPTQQuantizer(Quantizer):
         self.zero = None
         self.maxq = None
 
-    def find_params(self, input_tensor, weight=False):
+    def find_params(self, input_tensor):
         """Finds quantization parameters (scale and zero) for a given tensor."""
         self.scale, self.zero, self.maxq = compute_quantization_parameters(
             input_tensor,
```

```diff
@@ -699,7 +759,6 @@ class GPTQQuantizer(Quantizer):
             symmetric=self.symmetric,
             per_channel=self.per_channel,
             group_size=self.group_size,
-            weight=weight,
             compute_dtype=self.compute_dtype,
         )
         return self.scale, self.zero, self.maxq
```
```diff
@@ -736,98 +795,105 @@ def compute_quantization_parameters(
     symmetric=False,
     per_channel=False,
     group_size=-1,
-    weight=False,
     compute_dtype="float32",
 ):
     """
-    Computes the scale and zero-point for
+    Computes the scale and zero-point for quantizing weight tensors.
 
     This function calculates the scale and zero-point required for quantizing
-    a given tensor `x` based on the specified parameters. It supports
-    per-channel, per-tensor, symmetric, and asymmetric quantization
-
+    a given weight tensor `x` based on the specified parameters. It supports
+    grouped, per-channel, per-tensor, symmetric, and asymmetric quantization.
+
+    For grouped quantization (per_channel=True, group_size > 0), the output
+    shapes are [out_features, n_groups] where n_groups is the number of groups
+    along the in_features dimension.
 
     Args:
-        x: KerasTensor. The
+        x: KerasTensor. The weight tensor to quantize with shape
+            [out_features, in_features].
         bits: int. The number of bits to quantize to (e.g., 4).
         symmetric: bool. Whether to use symmetric quantization.
         per_channel: bool. Whether to quantize per channel.
-        group_size: int. The group size for quantization.
-
+        group_size: int. The group size for quantization. -1 means no grouping.
+        compute_dtype: str. The dtype for computation. Defaults to "float32".
 
     Returns:
         scale: KerasTensor. The scale tensor for quantization.
         zero: KerasTensor. The zero tensor for quantization.
         maxq: scalar. The maximum quantization value.
     """
+    # Input validation
     if x is None:
         raise ValueError(f"Input tensor {x} cannot be None.")
-
-    # For weights, we typically expect at least a 2D tensor.
-    if weight and len(x.shape) < 2:
+    if len(x.shape) < 2:
         raise ValueError(
             f"Input weight tensor {x} must have a rank of at "
             f"least 2, but got rank {len(x.shape)}."
         )
-
     if ops.size(x) == 0:
         raise ValueError("Input tensor 'x' cannot be empty.")
 
-
-
-    if per_channel:
-        if weight:
-            if group_size != -1:
-                input_reshaped = ops.reshape(x, [-1, group_size])
-            else:
-                input_reshaped = ops.reshape(x, [original_shape[0], -1])
-    else:  # per-tensor
-        input_reshaped = ops.reshape(x, [1, -1])
+    out_features, in_features = x.shape[0], x.shape[1]
 
-    #
-
-
+    # Determine number of groups for quantization
+    if per_channel and group_size > 0:
+        n_groups = (in_features + group_size - 1) // group_size
+    else:
+        n_groups = 1
+
+    # Compute min/max values based on quantization mode
+    if n_groups > 1:
+        # Grouped quantization: output shape [out_features, n_groups]
+        remainder = in_features % group_size
+        if remainder != 0:
+            pad_size = group_size - remainder
+            x = ops.pad(x, [[0, 0], [0, pad_size]], constant_values=0.0)
+
+        x_grouped = ops.reshape(x, [out_features, n_groups, group_size])
+        min_values = ops.min(x_grouped, axis=2)
+        max_values = ops.max(x_grouped, axis=2)
+    else:
+        # Per-channel or per-tensor: compute stats along rows
+        reduction_shape = [out_features, -1] if per_channel else [1, -1]
+        x_reshaped = ops.reshape(x, reduction_shape)
+        min_values = ops.min(x_reshaped, axis=1)
+        max_values = ops.max(x_reshaped, axis=1)
 
-    #
+    # Symmetric quantization: make range symmetric around zero
     if symmetric:
-
+        max_abs = ops.maximum(ops.abs(min_values), max_values)
         min_values = ops.where(
-            ops.less(min_values, 0), ops.negative(
+            ops.less(min_values, 0), ops.negative(max_abs), min_values
         )
+        max_values = max_abs
 
-    # Ensure range
+    # Ensure non-zero range to avoid division errors
    zero_range = ops.equal(min_values, max_values)
    min_values = ops.where(zero_range, ops.subtract(min_values, 1), min_values)
    max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)
 
+    # Compute scale and zero-point
    maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype)
-
-    # Calculate scale and zero-point
    scale = ops.divide(ops.subtract(max_values, min_values), maxq)
+    scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
+
    if symmetric:
        zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
    else:
        zero = ops.round(ops.divide(ops.negative(min_values), scale))
 
-    #
-
-
-
-    # Per-channel, non-grouped case: simple reshape is correct.
-    if per_channel and group_size == -1:
-        scale = ops.reshape(scale, [-1, 1])
-        zero = ops.reshape(zero, [-1, 1])
-    elif not per_channel:
-        num_rows = original_shape[0]
-        scale = ops.tile(ops.reshape(scale, (1, 1)), (num_rows, 1))
-        zero = ops.tile(ops.reshape(zero, (1, 1)), (num_rows, 1))
-    if per_channel:
+    # Reshape output to [out_features, n_groups] or [out_features, 1]
+    if n_groups > 1:
+        pass  # Already [out_features, n_groups]
+    elif per_channel:
        scale = ops.reshape(scale, [-1, 1])
        zero = ops.reshape(zero, [-1, 1])
+    else:
+        # Per-tensor: tile single value to [out_features, 1]
+        scale = ops.tile(ops.reshape(scale, (1, 1)), (out_features, 1))
+        zero = ops.tile(ops.reshape(zero, (1, 1)), (out_features, 1))
 
-
-
-    return scale, zero, maxq
+    return scale, ops.cast(zero, "uint8"), maxq
 
 
 def quantize_with_zero_point(input_tensor, scale, zero, maxq):
```
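To make the grouped shapes concrete, here is the shape arithmetic implied by the new code (plain Python, using hypothetical dimensions):

```python
out_features, in_features, group_size = 16, 300, 128

# Ceiling division, as in the new code; the last group is zero-padded.
n_groups = (in_features + group_size - 1) // group_size          # 3
pad_size = (group_size - in_features % group_size) % group_size  # 84

# Per the updated docstring, scale and zero come back with shape
# [out_features, n_groups] in the grouped case:
assert (out_features, n_groups) == (16, 3)
```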
keras/src/quantizers/utils.py (new file)

```diff
@@ -0,0 +1,23 @@
+import re
+
+
+def should_quantize_layer(layer, filters):
+    """Determines if a layer should be quantized based on filters.
+
+    Args:
+        layer: The layer to check.
+        filters: A regex string, a list of regex strings, or a callable.
+            If None, returns True.
+
+    Returns:
+        True if the layer should be quantized, False otherwise.
+    """
+    if filters is None:
+        return True
+    if isinstance(filters, str):
+        return bool(re.search(filters, layer.name))
+    if isinstance(filters, (list, tuple)):
+        return any(re.search(pat, layer.name) for pat in filters)
+    if callable(filters):
+        return filters(layer)
+    return True
```
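A brief sketch of how the `filters` argument behaves, using hypothetical layer names (the import path follows this diff's file location):

```python
import keras
from keras.src.quantizers.utils import should_quantize_layer  # path per this diff

layer = keras.layers.Dense(4, name="decoder_block_3_ffn")

should_quantize_layer(layer, None)                  # True: no filter
should_quantize_layer(layer, r"ffn")                # True: regex matches the name
should_quantize_layer(layer, [r"attn", r"ffn"])     # True: any pattern may match
should_quantize_layer(layer, lambda l: "embed" not in l.name)  # True: callable predicate
```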
keras/src/random/seed_generator.py

```diff
@@ -8,6 +8,8 @@ from keras.src.backend.common import global_state
 from keras.src.utils import jax_utils
 from keras.src.utils.naming import auto_name
 
+GLOBAL_SEED_GENERATOR = "global_seed_generator"
+
 
 @keras_export("keras.random.SeedGenerator")
 class SeedGenerator:
```

```diff
@@ -27,7 +29,7 @@ class SeedGenerator:
     a local `StateGenerator` with either a deterministic or random initial
     state.
 
-    Remark concerning the JAX
+    Remark concerning the JAX backend: Note that the use of a local
     `StateGenerator` as seed argument is required for JIT compilation of
     RNG with the JAX backend, because the use of global state is not
     supported.
```

```diff
@@ -109,7 +111,7 @@ class SeedGenerator:
         return new_seed_value
 
     def get_config(self):
-        return {"seed": self._initial_seed}
+        return {"seed": self._initial_seed, "name": self.name}
 
     @classmethod
     def from_config(cls, config):
```

```diff
@@ -133,10 +135,10 @@ def global_seed_generator():
         "out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n"
         "```"
     )
-    gen = global_state.get_global_attribute("global_seed_generator")
+    gen = global_state.get_global_attribute(GLOBAL_SEED_GENERATOR)
     if gen is None:
         gen = SeedGenerator()
-        global_state.set_global_attribute("global_seed_generator", gen)
+        global_state.set_global_attribute(GLOBAL_SEED_GENERATOR, gen)
     return gen
 
 
```
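A short round-trip sketch of the updated `get_config` (the `name` entry is new; this assumes `SeedGenerator` accepts a `name` constructor argument, which the `from_config(cls, config)` path implies):

```python
import keras

gen = keras.random.SeedGenerator(seed=42, name="dropout_seed")
config = gen.get_config()            # now includes both "seed" and "name"
restored = keras.random.SeedGenerator.from_config(config)
assert restored.name == "dropout_seed"
```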