keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
```diff
--- a/keras/src/layers/core/einsum_dense.py
+++ b/keras/src/layers/core/einsum_dense.py
@@ -6,6 +6,7 @@ import ml_dtypes
 import numpy as np

 from keras.src import activations
+from keras.src import backend
 from keras.src import constraints
 from keras.src import dtype_policies
 from keras.src import initializers
@@ -15,7 +16,9 @@ from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.saving import serialization_lib


 @keras_export("keras.layers.EinsumDense")
@@ -134,6 +137,7 @@ class EinsumDense(Layer):
         lora_rank=None,
         lora_alpha=None,
         gptq_unpacked_column_size=None,
+        quantization_config=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -154,6 +158,7 @@ class EinsumDense(Layer):
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
         self.gptq_unpacked_column_size = gptq_unpacked_column_size
+        self.quantization_config = quantization_config

     def build(self, input_shape):
         shape_data = _analyze_einsum_string(
@@ -169,12 +174,13 @@ class EinsumDense(Layer):
             self.quantized_build(
                 kernel_shape,
                 mode=self.quantization_mode,
+                config=self.quantization_config,
             )
         # Skip creating a duplicate kernel variable when the layer is already
         # quantized to int8 or int4, because `quantized_build` has created the
         # appropriate kernel variable. For other modes (e.g., float8 or no
         # quantization), we still need the floating-point kernel.
-        if self.quantization_mode not in ("int8", "int4", "gptq"):
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8, `self._kernel` will be added
             # in `self._int8_build`. Therefore, we skip it here.
             self._kernel = self.add_weight(
@@ -213,15 +219,17 @@ class EinsumDense(Layer):

         mode = self.quantization_mode
         is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
         is_int4 = mode == "int4"
-
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
         gptq_bits = (
             gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
         )

         # Decide the source tensor first (packed vs already-quantized vs plain
         # kernel)
-        if is_gptq and
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
             # calibrated GPTQ, not 4-bit, no unpacking needed
             kernel = self.quantized_kernel
         else:
@@ -235,13 +243,21 @@ class EinsumDense(Layer):
                 self._orig_length_along_pack_axis,
                 self._int4_pack_axis,
             )
-        elif is_gptq and
+        elif is_gptq and gptq_calibrated and gptq_bits == 4:
             kernel = quantizers.unpack_int4(
                 self.quantized_kernel,
                 orig_len=self.gptq_unpacked_column_size,
                 axis=0,
                 dtype="uint8",
             )
+        elif is_awq and awq_calibrated:
+            # AWQ always uses 4-bit quantization
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.awq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )

         # Apply LoRA if enabled
         if self.lora_enabled:
@@ -326,25 +342,25 @@ class EinsumDense(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
-
-
-
-
-
-
-
-            if name == "kernel_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
-                store[
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[
+                store[str(idx)] = getattr(self, name)
+            idx += 1

     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -353,39 +369,22 @@ class EinsumDense(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

-        #
-
-
-
-        # Load the variables using the name as the key.
-        if mode != "gptq":
-            self._kernel.assign(store["kernel"])
-        if self.bias is not None:
-            self.bias.assign(store["bias"])
-        for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
-        if self.lora_enabled:
-            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
-            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
+        self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"

-
-
-
-
-
-
-
-
-
-            targets.extend(
-                getattr(self, name)
-                for name in self.quantization_variable_spec[mode]
-            )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
@@ -410,6 +409,9 @@ class EinsumDense(Layer):
             ),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
@@ -418,53 +420,42 @@ class EinsumDense(Layer):
         config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size
         return {**base_config, **config}

-
-
-
-
-
-
-                "and thus it doesn't have any variables. "
-                f"However the weights file lists {len(store.keys())} "
-                "variables for this layer.\n"
-                "In most cases, this error indicates that either:\n\n"
-                "1. The layer is owned by a parent layer that "
-                "implements a `build()` method, but calling the "
-                "parent's `build()` method did NOT create the state of "
-                f"the child layer '{self.name}'. A `build()` method "
-                "must create ALL state for the layer, including "
-                "the state of any children layers.\n\n"
-                "2. You need to implement "
-                "the `def build_from_config(self, config)` method "
-                f"on layer '{self.name}', to specify how to rebuild "
-                "it during loading. "
-                "In this case, you might also want to implement the "
-                "method that generates the build config at saving time, "
-                "`def get_build_config(self)`. "
-                "The method `build_from_config()` is meant "
-                "to create the state "
-                "of the layer (i.e. its variables) upon deserialization.",
-            )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
+        )
+        return super().from_config(config)

     @property
-    def
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.

         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
         """
         return {
-            None: [
-
-
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
             "float8": [
+                "kernel",
+                "bias",
                 "inputs_scale",
                 "inputs_amax_history",
                 "kernel_scale",
@@ -473,31 +464,48 @@ class EinsumDense(Layer):
                 "outputs_grad_amax_history",
             ],
             "gptq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "g_idx",
+            ],
+            "awq": [
+                "bias",
                 "quantized_kernel",
                 "kernel_scale",
                 "kernel_zero",
+                "awq_scales",
                 "g_idx",
             ],
         }

     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
             self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def _int8_build(self, kernel_shape):
+    def _int8_build(self, kernel_shape, config=None):
         self._set_quantization_info()
-        self.inputs_quantizer =
-
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(),
+            )
         )
+        # If the config provided a default AbsMaxQuantizer, we need to
+        # override the axis to match the equation's reduction axes.
+        self.quantization_axis = tuple(self._input_reduced_axes)
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
```
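The `variable_serialization_spec` property introduced above drives both `save_own_variables` and `load_own_variables`: variables are stored under stringified positional keys (`"0"`, `"1"`, ...) in the order the spec lists for the active quantization mode, and a missing bias is skipped without consuming an index. A minimal stand-in sketch of that ordering rule (plain Python; the spec/variable/store objects here are placeholders, not the Keras classes):

```python
def save_in_spec_order(spec, variables):
    """Map each present variable to a stringified index key, in spec order."""
    store = {}
    idx = 0
    for name in spec:
        value = variables.get(name)
        if name == "bias" and value is None:
            # Optional bias: skipped entirely, no index is consumed.
            continue
        store[str(idx)] = value
        idx += 1
    return store


# Example: an int8-quantized layer without a bias.
spec_int8 = ["kernel", "bias", "kernel_scale"]
store = save_in_spec_order(
    spec_int8, {"kernel": "W_q", "bias": None, "kernel_scale": "s"}
)
assert store == {"0": "W_q", "1": "s"}
```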
```diff
@@ -535,12 +543,7 @@ class EinsumDense(Layer):
             columns = kernel_shape[1]
         elif len(kernel_shape) == 3:
             shape = list(self.original_kernel_shape)
-
-                d_model_dim_index = shape.index(max(shape))
-            except ValueError:
-                raise TypeError(
-                    f"Could not determine hidden dimension from shape {shape}"
-                )
+            d_model_dim_index = shape.index(max(shape))

             if d_model_dim_index == 0:  # QKV projection case
                 in_features, heads, head_dim = shape
@@ -566,8 +569,7 @@ class EinsumDense(Layer):
         # For 4-bit weights, we pack two values per byte.
         kernel_columns = (columns + 1) // 2 if weight_bits == 4 else columns

-
-        self._set_quantization_info()
+        self._set_quantization_info()

         self.quantized_kernel = self.add_weight(
             name="kernel",
@@ -635,7 +637,128 @@ class EinsumDense(Layer):
         y = self.activation(y)
         return y

-    def
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+
+        self.original_kernel_shape = kernel_shape
+        if len(kernel_shape) == 2:
+            rows = kernel_shape[0]
+            columns = kernel_shape[1]
+        elif len(kernel_shape) == 3:
+            shape = list(self.original_kernel_shape)
+            d_model_dim_index = shape.index(max(shape))
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                rows, columns = (
+                    in_features,
+                    heads * head_dim,
+                )
+            elif d_model_dim_index in [1, 2]:  # Attention Output case
+                heads, head_dim, out_features = shape
+                rows, columns = (
+                    heads * head_dim,
+                    out_features,
+                )
+            else:
+                raise ValueError("Could not determine row/column split.")
+        else:
+            raise ValueError("AWQ quantization only supports 2D or 3D kernels.")
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = 1 if group_size == -1 else math.ceil(rows / group_size)
+
+        self.awq_unpacked_column_size = columns
+
+        # For 4-bit weights, we pack two values per byte.
+        kernel_columns = (columns + 1) // 2
+
+        self._set_quantization_info()
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(kernel_columns, rows),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(columns, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="zero_point",
+            shape=(columns, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(rows,),
+            initializer="ones",
+            trainable=False,
+        )
+
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(rows,),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.awq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = dequantize_with_sz_map(
+                W,
+                self.kernel_scale,
+                self.kernel_zero,
+                self.g_idx,
+            )
+            W = ops.transpose(W)
+
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, out_dim]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+            W = ops.reshape(W, self.original_kernel_shape)
+
+        y = ops.einsum(self.equation, inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.

         The packed int4 kernel stores two int4 values within a single int8
```
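In `_awq_call` above, the layer unpacks the 4-bit kernel, dequantizes it with the per-group scale/zero map, and then divides by the per-input-channel `awq_scales`, undoing the scaling applied before quantization to protect salient weights. A rough NumPy stand-in for that round trip (per-tensor uniform 4-bit quantization with made-up shapes, unlike the grouped scheme in the diff):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(6, 4)).astype("float32")  # [input_dim, output_dim]
awq_scales = np.array([1.0, 2.0, 0.5, 1.5, 1.0, 4.0], dtype="float32")

# Scale salient input channels up before quantizing, as AWQ does.
W_scaled = W * awq_scales[:, None]

# Coarse 4-bit uniform quantization, per tensor, just for illustration.
lo, hi = W_scaled.min(), W_scaled.max()
step = (hi - lo) / 15.0
W_q = np.clip(np.round((W_scaled - lo) / step), 0, 15).astype(np.uint8)

# Dequantize, then divide by the AWQ scales to restore the original
# magnitudes, mirroring ops.divide(W, ops.expand_dims(self.awq_scales, -1)).
W_dq = W_q.astype("float32") * step + lo
W_restored = W_dq / awq_scales[:, None]

print(float(np.abs(W_restored - W).max()))  # small quantization error
```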
```diff
@@ -647,9 +770,15 @@ class EinsumDense(Layer):
         self._set_quantization_info()

         # Quantizer for the inputs (per the reduced axes)
-        self.inputs_quantizer =
-
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(),
+            )
         )
+        # If the config provided a default AbsMaxQuantizer, we need to
+        # override the axis to match the equation's reduction axes.
+        self.quantization_axis = tuple(self._input_reduced_axes)

         # Choose the axis to perform int4 packing - use the first reduced axis
         # for the kernel (analogous to the input dimension of a Dense layer).
@@ -771,13 +900,36 @@ class EinsumDense(Layer):
                )
                return (inputs_grad, None, None)

-
-
-
-
-
-
-
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(
+                    inputs, axis=self.quantization_axis
+                )
+                # Align `inputs_scale` axes with the output
+                # for correct broadcasting
+                inputs_scale = self._adjust_scale_for_quant(
+                    inputs_scale, "input"
+                )
+                x = ops.einsum(self.equation, inputs, kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            else:
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int8 kernel).
+                if backend.backend() == "torch":
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, kernel_scale)
            return x, grad_fn

        x = einsum_with_inputs_gradient(
@@ -847,17 +999,38 @@ class EinsumDense(Layer):
                return (inputs_grad, None, None)

            # Quantize inputs per `self.inputs_quantizer`.
-
-
-
-
-
-
-
-
-
-
-
+            if self.inputs_quantizer:
+                inputs_q, inputs_scale = self.inputs_quantizer(
+                    inputs, axis=self.quantization_axis
+                )
+                # Align `inputs_scale` axes with the output
+                # for correct broadcasting
+                inputs_scale = self._adjust_scale_for_quant(
+                    inputs_scale, "input"
+                )
+                x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            else:
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int4 kernel).
+                if backend.backend() == "torch":
+                    # Align `kernel_scale` to the same layout as
+                    # `unpacked_kernel`.
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, unpacked_kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, kernel_scale)
            return x, grad_fn

        x = einsum_with_inputs_gradient(
```
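Both quantized call paths above now branch on whether an input quantizer is configured: with one, the inputs are quantized as well and the einsum output is divided by the product of the input and kernel scales; without one (weight-only quantization), the kernel is dequantized to float first, which per the comments in the diff is also the workaround used on the torch backend because its einsum does not accept mixed float/int operands. A NumPy stand-in for the two branches, with made-up shapes and a single per-tensor scale:

```python
import numpy as np

equation = "ab,bc->ac"  # stand-in einsum equation
inputs = np.random.rand(2, 3).astype("float32")
kernel_scale = np.float32(127.0)
kernel_q = np.random.randint(-127, 128, size=(3, 4), dtype=np.int8)

# Weight-only branch: dequantize the kernel, then a plain float einsum.
x_weight_only = np.einsum(
    equation, inputs, kernel_q.astype("float32") / kernel_scale
)

# Activation-quantized branch: quantize the inputs too, einsum on the
# quantized values, then de-scale by both scales at once.
inputs_scale = np.float32(127.0) / np.abs(inputs).max()
inputs_q = np.round(inputs * inputs_scale)
x_full = np.einsum(equation, inputs_q, kernel_q.astype("float32"))
x_full = x_full / (inputs_scale * kernel_scale)

print(float(np.abs(x_full - x_weight_only).max()))  # equal up to rounding
```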
```diff
@@ -971,30 +1144,40 @@ class EinsumDense(Layer):
             x = self.activation(x)
         return x

-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not EinsumDense):
             raise self._not_implemented_error(self.quantize)

+        self.quantization_config = config
+
         kernel_shape = self._kernel.shape
-        if mode in ("int8", "int4", "gptq"):
+        if mode in ("int8", "int4", "gptq", "awq"):
             self._set_quantization_info()

         if mode == "int8":
             # Quantize `self._kernel` to int8 and compute corresponding scale
-
-            self.
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(axis=self._kernel_reduced_axes),
+            )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")
             del self._kernel
         elif mode == "int4":
             # Quantize to int4 values (stored in int8 dtype, range [-8, 7])
-
-            self.
-
-
-
-
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=self._kernel_reduced_axes,
+                    value_range=(-8, 7),
+                    output_dtype="int8",
+                ),
+            )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")

```
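`quantize()` above now records the optional `config` on the layer and asks `QuantizationConfig` for a weight quantizer (and, in the build methods, an activation quantizer), falling back to `AbsMaxQuantizer` defaults when no config is given. A short usage sketch of the default int8 path on a standalone layer; only the long-standing `quantize("int8")` entry point is exercised here, and the new config argument is left at its default:

```python
import numpy as np
import keras

# Build a small EinsumDense layer and quantize it post-training with the
# default AbsMax quantizers (config=None falls back to these per the diff).
layer = keras.layers.EinsumDense("ab,bc->ac", output_shape=4, bias_axes="c")
layer.build((None, 8))
layer.quantize("int8")

y = layer(np.random.rand(2, 8).astype("float32"))
print(y.shape)  # (2, 4)
```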
```diff
@@ -1005,7 +1188,7 @@ class EinsumDense(Layer):
             )
             kernel_value = packed_kernel_value
             del self._kernel
-        self.quantized_build(kernel_shape, mode,
+        self.quantized_build(kernel_shape, mode, self.quantization_config)

         # Assign values to the newly created variables.
         if mode in ("int8", "int4"):
@@ -1016,7 +1199,9 @@ class EinsumDense(Layer):
         if self.dtype_policy.quantization_mode is None:
             policy_name = mode
             if mode == "gptq":
-                policy_name =
+                policy_name = self.quantization_config.dtype_policy_string()
+            elif mode == "awq":
+                policy_name = self.quantization_config.dtype_policy_string()
             policy = dtype_policies.get(
                 f"{policy_name}_from_{self.dtype_policy.name}"
             )
@@ -1080,7 +1265,7 @@ class EinsumDense(Layer):
         This is `None` if the layer is not quantized.
         """
         # If not a quantized layer, return the full-precision kernel directly.
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None

         # If quantized but LoRA is not enabled, return the original quantized
```
|