keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +12 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +12 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +1 -1
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +33 -16
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +485 -20
- keras/src/backend/jax/numpy.py +92 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +76 -7
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1030 -185
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +264 -54
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +84 -8
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +299 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +191 -172
- keras/src/layers/core/einsum_dense.py +235 -186
- keras/src/layers/core/embedding.py +83 -93
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +390 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +40 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/losses/loss.py +1 -1
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +156 -27
- keras/src/ops/image.py +184 -3
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +541 -43
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +12 -1
- keras/src/quantizers/gptq.py +8 -6
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +150 -78
- keras/src/quantizers/quantization_config.py +232 -0
- keras/src/quantizers/quantizers.py +114 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +4 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +14 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +187 -36
- keras/src/utils/module_utils.py +18 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/rng_utils.py +9 -1
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/layers/core/einsum_dense.py (+235 -186):

@@ -6,6 +6,7 @@ import ml_dtypes
 import numpy as np

 from keras.src import activations
+from keras.src import backend
 from keras.src import constraints
 from keras.src import dtype_policies
 from keras.src import initializers
@@ -13,12 +14,11 @@ from keras.src import ops
 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
-from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
-from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
-from keras.src.quantizers.
+from keras.src.quantizers.quantization_config import QuantizationConfig
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.saving import serialization_lib


 @keras_export("keras.layers.EinsumDense")
@@ -136,6 +136,8 @@ class EinsumDense(Layer):
         bias_constraint=None,
         lora_rank=None,
         lora_alpha=None,
+        gptq_unpacked_column_size=None,
+        quantization_config=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -155,6 +157,8 @@ class EinsumDense(Layer):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.gptq_unpacked_column_size = gptq_unpacked_column_size
+        self.quantization_config = quantization_config

     def build(self, input_shape):
         shape_data = _analyze_einsum_string(
@@ -170,6 +174,7 @@ class EinsumDense(Layer):
             self.quantized_build(
                 kernel_shape,
                 mode=self.quantization_mode,
+                config=self.quantization_config,
             )
         # Skip creating a duplicate kernel variable when the layer is already
         # quantized to int8 or int4, because `quantized_build` has created the
@@ -205,24 +210,51 @@ class EinsumDense(Layer):

     @property
     def kernel(self):
+        from keras.src.quantizers import gptq_core
+
         if not self.built:
             raise AttributeError(
                 "You must build the layer before accessing `kernel`."
             )
-… (10 removed lines not captured)
+
+        mode = self.quantization_mode
+        is_gptq = mode == "gptq"
+        is_int4 = mode == "int4"
+        calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        gptq_bits = (
+            gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
+        )
+
+        # Decide the source tensor first (packed vs already-quantized vs plain
+        # kernel)
+        if is_gptq and calibrated and gptq_bits != 4:
+            # calibrated GPTQ, not 4-bit, no unpacking needed
+            kernel = self.quantized_kernel
+        else:
+            # Start with the stored kernel
+            kernel = getattr(self, "_kernel", None)
+
+        # Handle int4 unpacking cases in one place
+        if is_int4:
+            kernel = quantizers.unpack_int4(
+                kernel,
+                self._orig_length_along_pack_axis,
+                self._int4_pack_axis,
+            )
+        elif is_gptq and calibrated and gptq_bits == 4:
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.gptq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )
+
+        # Apply LoRA if enabled
         if self.lora_enabled:
-… (1 removed line not captured)
+            kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
                 self.lora_kernel_a, self.lora_kernel_b
             )
+
         return kernel

     def compute_output_shape(self, _):
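The rewritten `kernel` property above ends by folding the LoRA update into whatever kernel it resolved (plain, int4-unpacked, or GPTQ). A minimal NumPy sketch of that merge, with made-up shapes and plain arrays standing in for the layer's variables:

    import numpy as np

    # Hypothetical sizes: base kernel (in_features, out_features), LoRA rank r.
    in_features, out_features, rank, alpha = 64, 128, 4, 8
    kernel = np.random.randn(in_features, out_features).astype("float32")
    lora_a = np.random.randn(in_features, rank).astype("float32")
    lora_b = np.random.randn(rank, out_features).astype("float32")

    # Same arithmetic as the property: kernel + (lora_alpha / lora_rank) * A @ B
    effective_kernel = kernel + (alpha / rank) * (lora_a @ lora_b)
    print(effective_kernel.shape)  # (64, 128)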
@@ -300,25 +332,25 @@ class EinsumDense(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
            raise self._quantization_mode_error(mode)

         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
-… (7 removed lines not captured)
-            if name == "kernel_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
-                store[
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[
+                store[str(idx)] = getattr(self, name)
+            idx += 1

     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -327,39 +359,21 @@ class EinsumDense(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

-        #
-… (1 removed line not captured)
-            return self._legacy_load_own_variables(store)
-… (1 removed line not captured)
-        # Load the variables using the name as the key.
-        if mode != "gptq":
-            self._kernel.assign(store["kernel"])
-            if self.bias is not None:
-                self.bias.assign(store["bias"])
-            for name in self.quantization_variable_spec[mode]:
-                getattr(self, name).assign(store[name])
-            if self.lora_enabled:
-                self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
-                self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
+        # A saved GPTQ quantized model will always be calibrated.
+        self.is_gptq_calibrated = mode == "gptq"

-… (9 removed lines not captured)
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
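Both `save_own_variables` and `load_own_variables` now walk `variable_serialization_spec[mode]` in order and key the store with a running integer index, skipping `bias` when the layer has none. A small standalone sketch of that indexing scheme, using an ordinary dict as a stand-in for the weight store (the names below mirror the spec but are otherwise illustrative):

    # Assumed spec for one mode, e.g. the "int8" entry shown below.
    spec = ["kernel", "bias", "kernel_scale"]

    def save_to_store(values, has_bias, store):
        # `values` maps variable name -> value; `store` is keyed "0", "1", ...
        idx = 0
        for name in spec:
            if name == "bias" and not has_bias:
                continue  # skipped entirely; the index does not advance
            store[str(idx)] = values[name]
            idx += 1

    store = {}
    save_to_store({"kernel": "K", "kernel_scale": "S"}, has_bias=False, store=store)
    print(store)  # {'0': 'K', '1': 'S'}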
@@ -384,59 +398,53 @@ class EinsumDense(Layer):
             ),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
+        if self.gptq_unpacked_column_size:
+            config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size
         return {**base_config, **config}

-… (6 removed lines not captured)
-                "and thus it doesn't have any variables. "
-                f"However the weights file lists {len(store.keys())} "
-                "variables for this layer.\n"
-                "In most cases, this error indicates that either:\n\n"
-                "1. The layer is owned by a parent layer that "
-                "implements a `build()` method, but calling the "
-                "parent's `build()` method did NOT create the state of "
-                f"the child layer '{self.name}'. A `build()` method "
-                "must create ALL state for the layer, including "
-                "the state of any children layers.\n\n"
-                "2. You need to implement "
-                "the `def build_from_config(self, config)` method "
-                f"on layer '{self.name}', to specify how to rebuild "
-                "it during loading. "
-                "In this case, you might also want to implement the "
-                "method that generates the build config at saving time, "
-                "`def get_build_config(self)`. "
-                "The method `build_from_config()` is meant "
-                "to create the state "
-                "of the layer (i.e. its variables) upon deserialization.",
-            )
-        raise ValueError(
-            f"Layer '{self.name}' expected {len(all_vars)} variables, "
-            "but received "
-            f"{len(store.keys())} variables during loading. "
-            f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
+        )
+        return super().from_config(config)

     @property
-    def quantization_variable_spec(self):
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.

         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
         """
         return {
-            None: [
-… (2 removed lines not captured)
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
             "float8": [
+                "kernel",
+                "bias",
                 "inputs_scale",
                 "inputs_amax_history",
                 "kernel_scale",
@@ -445,6 +453,7 @@ class EinsumDense(Layer):
                 "outputs_grad_amax_history",
             ],
             "gptq": [
+                "bias",
                 "quantized_kernel",
                 "kernel_scale",
                 "kernel_zero",
@@ -454,9 +463,9 @@ class EinsumDense(Layer):

     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
@@ -465,11 +474,17 @@ class EinsumDense(Layer):
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def _int8_build(self, kernel_shape):
+    def _int8_build(self, kernel_shape, config=None):
         self._set_quantization_info()
-        self.inputs_quantizer =
-… (1 removed line not captured)
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(),
+            )
         )
+        # If the config provided a default AbsMaxQuantizer, we need to
+        # override the axis to match the equation's reduction axes.
+        self.quantization_axis = tuple(self._input_reduced_axes)
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
@@ -495,6 +510,8 @@ class EinsumDense(Layer):
             group_size: int; contiguous input-group size for quantization
                 (=-1 means per-output-channel with no grouping).
         """
+        from keras.src.quantizers import gptq_core
+
         # Ensures the forward pass uses the original high-precision kernel
         # until calibration has been performed.
         self.is_gptq_calibrated = False
@@ -505,12 +522,7 @@ class EinsumDense(Layer):
             columns = kernel_shape[1]
         elif len(kernel_shape) == 3:
             shape = list(self.original_kernel_shape)
-            try:
-                d_model_dim_index = shape.index(max(shape))
-            except ValueError:
-                raise TypeError(
-                    f"Could not determine hidden dimension from shape {shape}"
-                )
+            d_model_dim_index = shape.index(max(shape))

             if d_model_dim_index == 0:  # QKV projection case
                 in_features, heads, head_dim = shape
@@ -527,18 +539,20 @@ class EinsumDense(Layer):
             else:
                 raise ValueError("Could not determine row/column split.")

-        group_size =
-        if group_size == -1
-            n_groups = 1
-        else:
-            n_groups = math.ceil(rows / group_size)
+        group_size = gptq_core.get_group_size_for_layer(self, config)
+        n_groups = 1 if group_size == -1 else math.ceil(rows / group_size)

-… (2 removed lines not captured)
+        self.gptq_unpacked_column_size = columns
+
+        weight_bits = gptq_core.get_weight_bits_for_layer(self, config)
+        # For 4-bit weights, we pack two values per byte.
+        kernel_columns = (columns + 1) // 2 if weight_bits == 4 else columns
+
+        self._set_quantization_info()

         self.quantized_kernel = self.add_weight(
             name="kernel",
-            shape=(
+            shape=(kernel_columns, rows),
             initializer="zeros",
             dtype="uint8",
             trainable=False,
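The `kernel_columns = (columns + 1) // 2` line reflects that two 4-bit values share one stored byte. A rough NumPy illustration of that nibble packing; the actual bit layout used by `quantizers.unpack_int4` is not visible in this diff, so treat this as a generic example rather than the exact Keras format:

    import numpy as np

    values = np.array([1, 7, 3, 12, 5], dtype=np.uint8)  # five unsigned 4-bit codes
    padded = np.append(values, 0) if len(values) % 2 else values

    # Pack pairs: even index -> low nibble, odd index -> high nibble.
    packed = (padded[0::2] | (padded[1::2] << 4)).astype(np.uint8)
    print(packed.size)  # 3 == ceil(5 / 2) bytes

    # Unpack back to 4-bit codes and drop the padding element.
    unpacked = np.empty(packed.size * 2, dtype=np.uint8)
    unpacked[0::2] = packed & 0x0F
    unpacked[1::2] = packed >> 4
    print(unpacked[: len(values)])  # [ 1  7  3 12  5]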
@@ -567,11 +581,26 @@ class EinsumDense(Layer):
         )

     def _gptq_call(self, inputs, training=False):
+        from keras.src.quantizers import gptq_core
+
         if not self.is_gptq_calibrated:
             W = self._kernel
         else:
+            should_unpack = (
+                gptq_core.get_weight_bits_for_layer(self, config=None) == 4
+            )
+            W = (
+                quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.gptq_unpacked_column_size,
+                    axis=0,
+                    dtype="uint8",
+                )
+                if should_unpack
+                else self.quantized_kernel
+            )
             W = dequantize_with_sz_map(
-… (1 removed line not captured)
+                W,
                 self.kernel_scale,
                 self.kernel_zero,
                 self.g_idx,
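`dequantize_with_sz_map` reconstructs a float kernel from integer codes using per-group scales and zero points, with `g_idx` mapping each row to its group. Its exact signature and layout are not shown in this diff; the sketch below only illustrates the usual scale/zero-point arithmetic such a map implies, with toy shapes:

    import numpy as np

    # Assumed toy sizes: 8 rows, 4 columns, group size 4 -> 2 groups.
    q = np.random.randint(0, 16, size=(8, 4)).astype(np.float32)  # 4-bit codes
    scale = np.random.rand(2, 4).astype(np.float32) + 0.1         # per (group, column)
    zero = np.full((2, 4), 8.0, dtype=np.float32)                 # per-group zero point
    g_idx = np.repeat(np.arange(2), 4)                            # row -> group id

    # W[i, j] ~ scale[g_idx[i], j] * (q[i, j] - zero[g_idx[i], j])
    w = scale[g_idx] * (q - zero[g_idx])
    print(w.shape)  # (8, 4)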
@@ -587,7 +616,7 @@ class EinsumDense(Layer):
             y = self.activation(y)
         return y

-    def _int4_build(self, kernel_shape):
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.

         The packed int4 kernel stores two int4 values within a single int8
@@ -599,9 +628,15 @@ class EinsumDense(Layer):
         self._set_quantization_info()

         # Quantizer for the inputs (per the reduced axes)
-        self.inputs_quantizer =
-… (1 removed line not captured)
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(),
+            )
         )
+        # If the config provided a default AbsMaxQuantizer, we need to
+        # override the axis to match the equation's reduction axes.
+        self.quantization_axis = tuple(self._input_reduced_axes)

         # Choose the axis to perform int4 packing - use the first reduced axis
         # for the kernel (analogous to the input dimension of a Dense layer).
@@ -723,13 +758,36 @@ class EinsumDense(Layer):
                 )
                 return (inputs_grad, None, None)

-… (7 removed lines not captured)
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(
+                    inputs, axis=self.quantization_axis
+                )
+                # Align `inputs_scale` axes with the output
+                # for correct broadcasting
+                inputs_scale = self._adjust_scale_for_quant(
+                    inputs_scale, "input"
+                )
+                x = ops.einsum(self.equation, inputs, kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            else:
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int8 kernel).
+                if backend.backend() == "torch":
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, kernel_scale)
             return x, grad_fn

         x = einsum_with_inputs_gradient(
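The de-scaling at the end of the quantized einsum follows from how inputs and kernel were quantized: if x_q ≈ x * s_x and w_q ≈ w * s_w, then einsum(x_q, w_q) ≈ einsum(x, w) * s_x * s_w, so dividing by s_x * s_w recovers the float result. A small NumPy check of that identity with abs-max style scales (illustrative only, not the layer's code path):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 8)).astype(np.float32)
    w = rng.standard_normal((8, 4)).astype(np.float32)

    # Abs-max scales mapping each tensor into the int8 range.
    s_x = 127.0 / np.max(np.abs(x))
    s_w = 127.0 / np.max(np.abs(w))
    x_q = np.round(x * s_x)
    w_q = np.round(w * s_w)

    # Integer-domain matmul, then divide by the product of scales.
    y = (x_q @ w_q) / (s_x * s_w)
    print(np.max(np.abs(y - x @ w)))  # small quantization error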
@@ -799,17 +857,38 @@ class EinsumDense(Layer):
                 return (inputs_grad, None, None)

             # Quantize inputs per `self.inputs_quantizer`.
-… (11 removed lines not captured)
+            if self.inputs_quantizer:
+                inputs_q, inputs_scale = self.inputs_quantizer(
+                    inputs, axis=self.quantization_axis
+                )
+                # Align `inputs_scale` axes with the output
+                # for correct broadcasting
+                inputs_scale = self._adjust_scale_for_quant(
+                    inputs_scale, "input"
+                )
+                x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            else:
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int4 kernel).
+                if backend.backend() == "torch":
+                    # Align `kernel_scale` to the same layout as
+                    # `unpacked_kernel`.
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, unpacked_kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, kernel_scale)
             return x, grad_fn

         x = einsum_with_inputs_gradient(
@@ -923,30 +1002,40 @@ class EinsumDense(Layer):
             x = self.activation(x)
         return x

-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not EinsumDense):
             raise self._not_implemented_error(self.quantize)

+        self.quantization_config = config
+
         kernel_shape = self._kernel.shape
         if mode in ("int8", "int4", "gptq"):
             self._set_quantization_info()

         if mode == "int8":
             # Quantize `self._kernel` to int8 and compute corresponding scale
-… (1 removed line not captured)
-            self.
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(axis=self._kernel_reduced_axes),
+            )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")
             del self._kernel
         elif mode == "int4":
             # Quantize to int4 values (stored in int8 dtype, range [-8, 7])
-… (1 removed line not captured)
-            self.
-… (4 removed lines not captured)
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=self._kernel_reduced_axes,
+                    value_range=(-8, 7),
+                    output_dtype="int8",
+                ),
+            )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")

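With the new signature, `quantize()` records the passed `config` on the layer before rebuilding its variables; when `config` is `None` it falls back to the abs-max quantizers shown above. A minimal usage sketch of that default path through the standard public API (no custom `QuantizationConfig`):

    import numpy as np
    from keras import layers

    layer = layers.EinsumDense("ab,bc->ac", output_shape=8)
    layer.build((None, 16))

    layer.quantize("int8")  # config defaults to None -> abs-max quantizers
    y = layer(np.random.rand(2, 16).astype("float32"))
    print(y.shape)  # (2, 8)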
@@ -957,7 +1046,7 @@ class EinsumDense(Layer):
             )
             kernel_value = packed_kernel_value
             del self._kernel
-        self.quantized_build(kernel_shape, mode,
+        self.quantized_build(kernel_shape, mode, self.quantization_config)

         # Assign values to the newly created variables.
         if mode in ("int8", "int4"):
@@ -968,7 +1057,7 @@ class EinsumDense(Layer):
         if self.dtype_policy.quantization_mode is None:
             policy_name = mode
             if mode == "gptq":
-                policy_name =
+                policy_name = self.quantization_config.dtype_policy_string()
             policy = dtype_policies.get(
                 f"{policy_name}_from_{self.dtype_policy.name}"
             )
@@ -1165,46 +1254,6 @@ class EinsumDense(Layer):
             self._kernel_reverse_transpose_axes,
         ) = _analyze_quantization_info(self.equation, self.input_spec.ndim)

-    def _get_gptq_group_size(self, config):
-        """Determine the group size for GPTQ quantization.
-
-        The group size can be specified either through the `config` argument
-        or through the `dtype_policy` if it is of type `GPTQDTypePolicy`.
-
-        The config argument is usually available when quantizing the layer
-        via the `quantize` method. If the layer was deserialized from a
-        saved model, the group size should be specified in the `dtype_policy`.
-
-        Args:
-            config: An optional configuration object that may contain the
-                `group_size` attribute.
-        Returns:
-            int. The determined group size for GPTQ quantization.
-        Raises:
-            ValueError: If the group size is not specified in either the
-                `config` or the `dtype_policy`.
-        """
-        if config and isinstance(config, GPTQConfig):
-            return config.group_size
-        elif isinstance(self.dtype_policy, GPTQDTypePolicy):
-            return self.dtype_policy.group_size
-        elif isinstance(self.dtype_policy, DTypePolicyMap):
-            policy = self.dtype_policy[self.path]
-            if not isinstance(policy, GPTQDTypePolicy):
-                # This should never happen based on how we set the
-                # quantization mode, but we check just in case.
-                raise ValueError(
-                    "Expected a `dtype_policy` of type `GPTQDTypePolicy`."
-                    f"Got: {type(policy)}"
-                )
-            return policy.group_size
-        else:
-            raise ValueError(
-                "For GPTQ quantization, the group_size must be specified"
-                "either through a `dtype_policy` of type "
-                "`GPTQDTypePolicy` or the `config` argument."
-            )
-

 def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
     """Parses an einsum string to determine the shapes of the weights.