keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/core/dense.py
CHANGED
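The updated class docstring in the diff below recommends disabling the bias when a `Dense` layer feeds directly into `BatchNormalization`, since the normalization layer learns its own offset. A minimal sketch of that pattern, using only the stable public Keras API:

```python
import keras

# Dense feeding into BatchNormalization: the Dense bias would be redundant,
# because BatchNormalization learns its own beta offset.
model = keras.Sequential(
    [
        keras.Input(shape=(64,)),
        keras.layers.Dense(128, use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.Activation("relu"),
        keras.layers.Dense(10),
    ]
)
model.summary()
```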
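`Dense.__init__` now validates `units` up front instead of failing later in `build`. With this change, an invalid value raises immediately; for example:

```python
import keras

try:
    keras.layers.Dense(units=0)
except ValueError as e:
    # Expected (per the diff below):
    # Received an invalid value for `units`, expected a positive integer.
    # Received: units=0
    print(e)
```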
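The save/load hunks replace hard-coded variable ordering with an ordered `variable_serialization_spec` per quantization mode, keyed in the store by a running integer index and skipping `bias` when the layer has none. A plain-Python sketch of that pattern; the `spec`, `layer_vars`, and `store` names here are stand-ins for illustration, not the real Keras objects:

```python
# Stand-ins: an ordered spec per quantization mode and the layer's variables.
spec = {
    None: ["kernel", "bias"],
    "int8": ["kernel", "bias", "kernel_scale"],
}
layer_vars = {"kernel": [[1.0, 2.0]], "bias": None, "kernel_scale": [0.5]}

def save_own_variables(mode, store):
    idx = 0
    for name in spec[mode]:
        if name == "bias" and layer_vars["bias"] is None:
            # No bias variable on this layer: skip it without burning an index.
            continue
        store[str(idx)] = layer_vars[name]
        idx += 1

store = {}
save_own_variables("int8", store)
print(store)  # {'0': [[1.0, 2.0]], '1': [0.5]} -- bias skipped, indices stay dense
```

Loading walks the same spec in the same order, which is why the saved keys stay stable across quantization modes.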
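Several hunks below pack two 4-bit weight values into one byte, which is why the packed leading dimension is `(dim + 1) // 2`. A self-contained NumPy sketch of that nibble packing and unpacking; it mirrors the idea behind the internal `pack_int4`/`unpack_int4` helpers but is not the library code:

```python
import numpy as np

def pack_nibbles(q):
    """Pack values in [0, 15] pairwise along axis 0: two 4-bit codes per byte."""
    if q.shape[0] % 2:  # pad an odd leading dim, hence ceil(n / 2) packed rows
        q = np.concatenate([q, np.zeros((1,) + q.shape[1:], dtype=q.dtype)])
    lo, hi = q[0::2], q[1::2]
    return (lo | (hi << 4)).astype(np.uint8)

def unpack_nibbles(packed, orig_len):
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    out = np.empty((packed.shape[0] * 2,) + packed.shape[1:], dtype=np.uint8)
    out[0::2], out[1::2] = lo, hi
    return out[:orig_len]

q = np.random.randint(0, 16, size=(5, 3)).astype(np.uint8)
packed = pack_nibbles(q)                    # shape (3, 3): ceil(5 / 2) rows
assert np.array_equal(unpack_nibbles(packed, 5), q)
```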
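In the `_int8_call`/`_int4_call` hunks, the activation scale is folded into a single `output_scale = kernel_scale * inputs_scale`, and the integer matmul result is divided once by that product. A NumPy sketch of the abs-max idea under simplified assumptions (per-row input scales, per-column kernel scales; not the exact `AbsMaxQuantizer` implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16)).astype("float32")   # activations
w = rng.normal(size=(16, 8)).astype("float32")   # float kernel

# Abs-max quantization: scale = 127 / max|.|, stored values are int8.
inputs_scale = 127.0 / np.max(np.abs(x), axis=-1, keepdims=True)  # per row
kernel_scale = 127.0 / np.max(np.abs(w), axis=0, keepdims=True)   # per column
x_q = np.round(x * inputs_scale).astype(np.int8)
w_q = np.round(w * kernel_scale).astype(np.int8)

# Integer matmul, then a single de-scale by the combined output scale.
acc = x_q.astype(np.int32) @ w_q.astype(np.int32)
output_scale = inputs_scale * kernel_scale        # broadcasts to (4, 8)
y = acc.astype("float32") / output_scale

print(np.max(np.abs(y - x @ w)))   # small quantization error
```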
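In the new `_awq_call`, the dequantized kernel is divided by the per-input-channel `awq_scales` because those scales were multiplied into the weights before 4-bit quantization (boosting salient channels so they survive rounding). A simplified, symmetric NumPy sketch of that round trip; the actual path in the diff uses grouped uint8 quantization with zero points via `dequantize_with_sz_map`:

```python
import numpy as np

rng = np.random.default_rng(1)
w = rng.normal(size=(16, 8)).astype("float32")                 # [input_dim, units]
awq_scales = rng.uniform(1.0, 4.0, size=(16,)).astype("float32")

# Calibration side: scale up per-input-channel, then fake-quantize to 4 bits.
w_scaled = w * awq_scales[:, None]
q_scale = np.max(np.abs(w_scaled), axis=0, keepdims=True) / 7.0
w_q = np.clip(np.round(w_scaled / q_scale), -8, 7)

# Inference side (mirrors the diff): dequantize, then divide the scales out.
w_deq = (w_q * q_scale) / awq_scales[:, None]

print(np.max(np.abs(w_deq - w)))   # error is concentrated on low-scale channels
```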
@@ -11,7 +11,9 @@ from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.saving import serialization_lib


 @keras_export("keras.layers.Dense")
@@ -23,7 +25,9 @@ class Dense(Layer):
     where `activation` is the element-wise activation function
     passed as the `activation` argument, `kernel` is a weights matrix
     created by the layer, and `bias` is a bias vector created by the layer
-    (only applicable if `use_bias` is `True`).
+    (only applicable if `use_bias` is `True`). When this layer is
+    followed by a `BatchNormalization` layer, it is recommended to set
+    `use_bias=False` as `BatchNormalization` has its own bias term.

     Note: If the input to the layer has a rank greater than 2, `Dense`
     computes the dot product between the `inputs` and the `kernel` along the
@@ -90,8 +94,15 @@ class Dense(Layer):
         bias_constraint=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
+        if not isinstance(units, int) or units <= 0:
+            raise ValueError(
+                "Received an invalid value for `units`, expected a positive "
+                f"integer. Received: units={units}"
+            )
+
         super().__init__(activity_regularizer=activity_regularizer, **kwargs)
         self.units = units
         self.activation = activations.get(activation)
@@ -105,14 +116,19 @@ class Dense(Layer):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.quantization_config = quantization_config
         self.input_spec = InputSpec(min_ndim=2)
         self.supports_masking = True

     def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
-            self.quantized_build(
-
+            self.quantized_build(
+                kernel_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8 or int4, `self._kernel` will be
             # added in `self._int8_build` or `_int4_build`. Therefore, we skip
             # it here.
@@ -149,15 +165,17 @@ class Dense(Layer):

         mode = self.quantization_mode
         is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
         is_int4 = mode == "int4"
-
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
         gptq_bits = (
             gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
         )

         # Decide the source tensor first (packed vs already-quantized vs plain
         # kernel)
-        if is_gptq and
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
             # calibrated GPTQ, not 4-bit, no unpacking needed
             kernel = self.quantized_kernel
         else:
@@ -167,7 +185,15 @@ class Dense(Layer):
             # Handle int4 unpacking cases in one place
             if is_int4:
                 kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
-            elif is_gptq and
+            elif is_gptq and gptq_calibrated and gptq_bits == 4:
+                kernel = quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+            elif is_awq and awq_calibrated:
+                # AWQ always uses 4-bit quantization
                 kernel = quantizers.unpack_int4(
                     self.quantized_kernel,
                     orig_len=self.units,
@@ -258,25 +284,25 @@ class Dense(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
-
-
-
-
-
-
-
-            if name == "kernel_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
-                store[
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[
+                store[str(idx)] = getattr(self, name)
+            idx += 1

     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -285,39 +311,22 @@ class Dense(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

-        #
-
-
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
+        self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"

-
-
-
-
-        self.bias
-
-
-
-
-        self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
-
-    def _legacy_load_own_variables(self, store):
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        mode = self.quantization_mode
-        targets = []
-        if mode != "gptq":
-            targets.append(self._kernel)
-        if self.bias is not None:
-            targets.append(self.bias)
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
@@ -338,59 +347,51 @@ class Dense(Layer):
             "bias_regularizer": regularizers.serialize(self.bias_regularizer),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}

-
-
-
-
-
-
-                "and thus it doesn't have any variables. "
-                f"However the weights file lists {len(store.keys())} "
-                "variables for this layer.\n"
-                "In most cases, this error indicates that either:\n\n"
-                "1. The layer is owned by a parent layer that "
-                "implements a `build()` method, but calling the "
-                "parent's `build()` method did NOT create the state of "
-                f"the child layer '{self.name}'. A `build()` method "
-                "must create ALL state for the layer, including "
-                "the state of any children layers.\n\n"
-                "2. You need to implement "
-                "the `def build_from_config(self, config)` method "
-                f"on layer '{self.name}', to specify how to rebuild "
-                "it during loading. "
-                "In this case, you might also want to implement the "
-                "method that generates the build config at saving time, "
-                "`def get_build_config(self)`. "
-                "The method `build_from_config()` is meant "
-                "to create the state "
-                "of the layer (i.e. its variables) upon deserialization.",
-            )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
+        )
+        return super().from_config(config)

     @property
-    def
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.

         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
         """
         return {
-            None: [
-
-
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
             "float8": [
+                "kernel",
+                "bias",
                 "inputs_scale",
                 "inputs_amax_history",
                 "kernel_scale",
@@ -399,28 +400,44 @@ class Dense(Layer):
                 "outputs_grad_amax_history",
             ],
             "gptq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "g_idx",
+            ],
+            "awq": [
+                "bias",
                 "quantized_kernel",
                 "kernel_scale",
                 "kernel_zero",
+                "awq_scales",
                 "g_idx",
             ],
         }

     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
             self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def _int8_build(self, kernel_shape):
-        self.inputs_quantizer =
+    def _int8_build(self, kernel_shape, config=None):
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer()
+            )
+        )
+
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
@@ -519,7 +536,98 @@ class Dense(Layer):
         y = self.activation(y)
         return y

-    def
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+        self.kernel_shape = kernel_shape
+
+        # For 4-bit weights, we pack two values per byte.
+        units = (kernel_shape[1] + 1) // 2
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(units, kernel_shape[0]),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = (
+            1 if group_size == -1 else math.ceil(kernel_shape[0] / group_size)
+        )
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="kernel_zero",
+            shape=(self.units, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(kernel_shape[0],),
+            initializer="ones",
+            trainable=False,
+        )
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(kernel_shape[0],),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
+            )
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, units]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+        y = ops.matmul(inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.

         `kernel_shape` is the *original* float32 kernel shape
@@ -528,8 +636,10 @@ class Dense(Layer):
         int8 byte.
         """
         # Per-channel int8 quantizer for the last axis (features).
-        self.inputs_quantizer =
-
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer()
+            )
         )
         input_dim, output_dim = kernel_shape
         packed_rows = (input_dim + 1) // 2  # ceil for odd dims
@@ -618,11 +728,15 @@ class Dense(Layer):
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)

-
+            output_scale = kernel_scale
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, kernel)
             # De-scale outputs
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x,
+            x = ops.divide(x, output_scale)
             return x, grad_fn

         x = matmul_with_inputs_gradient(
@@ -669,10 +783,15 @@ class Dense(Layer):
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)

-
+            output_scale = kernel_scale
+
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, unpacked_kernel)
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x,
+            x = ops.divide(x, output_scale)
             return x, grad_fn

         x = matmul_with_inputs_gradient(
@@ -784,30 +903,37 @@ class Dense(Layer):
         x = self.activation(x)
         return x

-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Dense):
             raise self._not_implemented_error(self.quantize)

+        self.quantization_config = config
+
         kernel_shape = self._kernel.shape
         if mode == "int8":
-
-            self.
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config, quantizers.AbsMaxQuantizer(axis=0)
+            )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
             # Build variables for int8 mode
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "int4":
             # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
-
-            self.
-
-
-
-
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=0, value_range=(-8, 7), output_dtype="int8"
+                ),
+            )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             # 2. Pack two int4 values into a single int8 byte.
@@ -815,12 +941,14 @@ class Dense(Layer):
             del self._kernel
             # Build variables using the original kernel shape; _int4_build will
             # compute the packed shape internally.
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
             # Assign packed values.
             self._kernel.assign(packed_kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "gptq":
-            self.quantized_build(kernel_shape, mode,
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
+        elif mode == "awq":
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "float8":
             self.quantized_build(kernel_shape, mode)
         else:
@@ -832,7 +960,9 @@ class Dense(Layer):

         policy_name = mode
         if mode == "gptq":
-            policy_name =
+            policy_name = self.quantization_config.dtype_policy_string()
+        elif mode == "awq":
+            policy_name = self.quantization_config.dtype_policy_string()
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
@@ -867,7 +997,7 @@ class Dense(Layer):
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None

         kernel_value = self._kernel